Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2026-03-14 09:38:05 +00:00

Compare commits: v0.16.3 ... pysssss/no (1 commit)
Commit: ecec1310b2

.coderabbit.yaml: 127 lines removed
@@ -1,127 +0,0 @@
# yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json
language: "en-US"
early_access: false
tone_instructions: "Only comment on issues introduced by this PR's changes. Do not flag pre-existing problems in moved, re-indented, or reformatted code."

reviews:
  profile: "chill"
  request_changes_workflow: false
  high_level_summary: false
  poem: false
  review_status: false
  review_details: false
  commit_status: true
  collapse_walkthrough: true
  changed_files_summary: false
  sequence_diagrams: false
  estimate_code_review_effort: false
  assess_linked_issues: false
  related_issues: false
  related_prs: false
  suggested_labels: false
  auto_apply_labels: false
  suggested_reviewers: false
  auto_assign_reviewers: false
  in_progress_fortune: false
  enable_prompt_for_ai_agents: true

  path_filters:
    - "!comfy_api_nodes/apis/**"
    - "!**/generated/*.pyi"
    - "!.ci/**"
    - "!script_examples/**"
    - "!**/__pycache__/**"
    - "!**/*.ipynb"
    - "!**/*.png"
    - "!**/*.bat"

  path_instructions:
    - path: "**"
      instructions: |
        IMPORTANT: Only comment on issues directly introduced by this PR's code changes.
        Do NOT flag pre-existing issues in code that was merely moved, re-indented,
        de-indented, or reformatted without logic changes. If code appears in the diff
        only due to whitespace or structural reformatting (e.g., removing a `with:` block),
        treat it as unchanged. Contributors should not feel obligated to address
        pre-existing issues outside the scope of their contribution.
    - path: "comfy/**"
      instructions: |
        Core ML/diffusion engine. Focus on:
        - Backward compatibility (breaking changes affect all custom nodes)
        - Memory management and GPU resource handling
        - Performance implications in hot paths
        - Thread safety for concurrent execution
    - path: "comfy_api_nodes/**"
      instructions: |
        Third-party API integration nodes. Focus on:
        - No hardcoded API keys or secrets
        - Proper error handling for API failures (timeouts, rate limits, auth errors)
        - Correct Pydantic model usage
        - Security of user data passed to external APIs
    - path: "comfy_extras/**"
      instructions: |
        Community-contributed extra nodes. Focus on:
        - Consistency with node patterns (INPUT_TYPES, RETURN_TYPES, FUNCTION, CATEGORY)
        - No breaking changes to existing node interfaces
    - path: "comfy_execution/**"
      instructions: |
        Execution engine (graph execution, caching, jobs). Focus on:
        - Caching correctness
        - Concurrent execution safety
        - Graph validation edge cases
    - path: "nodes.py"
      instructions: |
        Core node definitions (2500+ lines). Focus on:
        - Backward compatibility of NODE_CLASS_MAPPINGS
        - Consistency of INPUT_TYPES return format
    - path: "alembic_db/**"
      instructions: |
        Database migrations. Focus on:
        - Migration safety and rollback support
        - Data preservation during schema changes

  auto_review:
    enabled: true
    auto_incremental_review: true
    drafts: false
    ignore_title_keywords:
      - "WIP"
      - "DO NOT REVIEW"
      - "DO NOT MERGE"

  finishing_touches:
    docstrings:
      enabled: false
    unit_tests:
      enabled: false

  tools:
    ruff:
      enabled: false
    pylint:
      enabled: false
    flake8:
      enabled: false
    gitleaks:
      enabled: true
    shellcheck:
      enabled: false
    markdownlint:
      enabled: false
    yamllint:
      enabled: false
    languagetool:
      enabled: false
    github-checks:
      enabled: true
      timeout_ms: 90000
    ast-grep:
      essential_rules: true

chat:
  auto_reply: true

knowledge_base:
  opt_out: false
  learnings:
    scope: "auto"
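Aside: the path_instructions above refer to the standard custom-node conventions (INPUT_TYPES, RETURN_TYPES, FUNCTION, CATEGORY). For readers unfamiliar with them, a minimal hedged sketch of a node following those conventions; the class name, category, and behavior are illustrative only, not part of this diff:

```python
class ExampleInvert:
    """Hypothetical minimal node; shows the conventions the config names."""

    @classmethod
    def INPUT_TYPES(cls):
        # Maps socket names to (type, options) tuples.
        return {"required": {"image": ("IMAGE",)}}

    RETURN_TYPES = ("IMAGE",)  # types of the output sockets
    FUNCTION = "invert"        # name of the method the executor calls
    CATEGORY = "image"         # menu placement (illustrative)

    def invert(self, image):
        # IMAGE tensors are float tensors in [0, 1]; outputs are tuples.
        return (1.0 - image,)
```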
.github/ISSUE_TEMPLATE/bug-report.yml (vendored): 2 lines changed
@@ -16,7 +16,7 @@ body:
 
         ## Very Important
 
-        Please make sure that you post ALL your ComfyUI logs in the bug report **even if there is no crash**. Just paste everything. The startup log (everything before "To see the GUI go to: ...") contains critical information to developers trying to help. For a performance issue or crash, paste everything from "got prompt" to the end, including the crash. More is better - always. A bug report without logs will likely be ignored.
+        Please make sure that you post ALL your ComfyUI logs in the bug report. A bug report without logs will likely be ignored.
   - type: checkboxes
     id: custom-nodes-test
     attributes:
.gitignore (vendored): 2 lines changed
@@ -11,7 +11,7 @@ extra_model_paths.yaml
 /.vs
 .vscode/
 .idea/
-venv*/
+venv/
 .venv/
 /web/extensions/*
 !/web/extensions/logging.js.example
README.md
@@ -189,6 +189,8 @@ The portable above currently comes with python 3.13 and pytorch cuda 13.0. Updat
 [Experimental portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z)
 
+[Portable with pytorch cuda 12.8 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu128.7z).
+
 [Portable with pytorch cuda 12.6 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu126.7z) (Supports Nvidia 10 series and older GPUs).
 
 #### How do I share models between another UI and ComfyUI?
@@ -227,9 +229,9 @@ AMD users can install rocm and pytorch with pip if you don't have it already ins
 
 ```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm7.1```
 
-This is the command to install the nightly with ROCm 7.2 which might have some performance improvements:
+This is the command to install the nightly with ROCm 7.1 which might have some performance improvements:
 
-```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.2```
+```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.1```
 
 ### AMD GPUs (Experimental: Windows and Linux), RDNA 3, 3.5 and 4 only.
app/frontend_management.py
@@ -17,7 +17,7 @@ from importlib.metadata import version
 import requests
 from typing_extensions import NotRequired
 
-from utils.install_util import get_missing_requirements_message, get_required_packages_versions
+from utils.install_util import get_missing_requirements_message, requirements_path
 
 from comfy.cli_args import DEFAULT_VERSION_STRING
 import app.logger
@@ -45,7 +45,25 @@ def get_installed_frontend_version():
 
 
 def get_required_frontend_version():
-    return get_required_packages_versions().get("comfyui-frontend-package", None)
+    """Get the required frontend version from requirements.txt."""
+    try:
+        with open(requirements_path, "r", encoding="utf-8") as f:
+            for line in f:
+                line = line.strip()
+                if line.startswith("comfyui-frontend-package=="):
+                    version_str = line.split("==")[-1]
+                    if not is_valid_version(version_str):
+                        logging.error(f"Invalid version format in requirements.txt: {version_str}")
+                        return None
+                    return version_str
+        logging.error("comfyui-frontend-package not found in requirements.txt")
+        return None
+    except FileNotFoundError:
+        logging.error("requirements.txt not found. Cannot determine required frontend version.")
+        return None
+    except Exception as e:
+        logging.error(f"Error reading requirements.txt: {e}")
+        return None
 
 
 def check_frontend_version():
@@ -199,7 +217,25 @@ class FrontendManager:
 
     @classmethod
     def get_required_templates_version(cls) -> str:
-        return get_required_packages_versions().get("comfyui-workflow-templates", None)
+        """Get the required workflow templates version from requirements.txt."""
+        try:
+            with open(requirements_path, "r", encoding="utf-8") as f:
+                for line in f:
+                    line = line.strip()
+                    if line.startswith("comfyui-workflow-templates=="):
+                        version_str = line.split("==")[-1]
+                        if not is_valid_version(version_str):
+                            logging.error(f"Invalid templates version format in requirements.txt: {version_str}")
+                            return None
+                        return version_str
+            logging.error("comfyui-workflow-templates not found in requirements.txt")
+            return None
+        except FileNotFoundError:
+            logging.error("requirements.txt not found. Cannot determine required templates version.")
+            return None
+        except Exception as e:
+            logging.error(f"Error reading requirements.txt: {e}")
+            return None
 
     @classmethod
     def default_frontend_path(cls) -> str:
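Aside: stripped of logging and version validation, the parsing pattern `get_required_frontend_version` now uses boils down to the sketch below. The path and package name follow the diff; `find_pinned_version` is a hypothetical helper name, not repo code:

```python
def find_pinned_version(requirements_path: str, package: str) -> str | None:
    """Return the version pinned as `package==X` in a requirements file, or None."""
    try:
        with open(requirements_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line.startswith(f"{package}=="):
                    return line.split("==")[-1]
    except OSError:
        return None
    return None

# e.g. find_pinned_version("requirements.txt", "comfyui-frontend-package")
```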
@@ -46,8 +46,6 @@ class NodeReplaceManager:
         connections: dict[str, list[tuple[str, str, int]]] = {}
         need_replacement: set[str] = set()
         for node_number, node_struct in prompt.items():
-            if "class_type" not in node_struct or "inputs" not in node_struct:
-                continue
             class_type = node_struct["class_type"]
             # need replacement if not in NODE_CLASS_MAPPINGS and has replacement
             if class_type not in nodes.NODE_CLASS_MAPPINGS.keys() and self.has_replacement(class_type):
@@ -53,7 +53,7 @@ class SubgraphManager:
         return entry_id, entry
 
     async def load_entry_data(self, entry: SubgraphEntry):
-        with open(entry['path'], 'r', encoding='utf-8') as f:
+        with open(entry['path'], 'r') as f:
             entry['data'] = f.read()
         return entry
 
@@ -1,44 +0,0 @@
#version 300 es
precision highp float;

uniform sampler2D u_image0;
uniform float u_float0; // Brightness slider -100..100
uniform float u_float1; // Contrast slider -100..100

in vec2 v_texCoord;
out vec4 fragColor;

const float MID_GRAY = 0.18; // 18% reflectance

// sRGB gamma 2.2 approximation
vec3 srgbToLinear(vec3 c) {
    return pow(max(c, 0.0), vec3(2.2));
}

vec3 linearToSrgb(vec3 c) {
    return pow(max(c, 0.0), vec3(1.0/2.2));
}

float mapBrightness(float b) {
    return clamp(b / 100.0, -1.0, 1.0);
}

float mapContrast(float c) {
    return clamp(c / 100.0 + 1.0, 0.0, 2.0);
}

void main() {
    vec4 orig = texture(u_image0, v_texCoord);

    float brightness = mapBrightness(u_float0);
    float contrast = mapContrast(u_float1);

    vec3 lin = srgbToLinear(orig.rgb);

    lin = (lin - MID_GRAY) * contrast + brightness + MID_GRAY;

    // Convert back to sRGB
    vec3 result = linearToSrgb(clamp(lin, 0.0, 1.0));

    fragColor = vec4(result, orig.a);
}
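Aside: the deleted shader above pivots contrast around 18% mid-gray after linearizing with a gamma-2.2 approximation. The same math as a NumPy sketch; the helper name is hypothetical, not from the repo:

```python
import numpy as np

MID_GRAY = 0.18  # 18% reflectance, as in the shader

def brightness_contrast(rgb: np.ndarray, brightness: float, contrast: float) -> np.ndarray:
    """rgb in [0, 1]; brightness/contrast are the -100..100 slider values."""
    b = np.clip(brightness / 100.0, -1.0, 1.0)
    c = np.clip(contrast / 100.0 + 1.0, 0.0, 2.0)
    lin = np.clip(rgb, 0.0, None) ** 2.2          # sRGB -> linear (gamma 2.2 approx)
    lin = (lin - MID_GRAY) * c + b + MID_GRAY     # pivot contrast around mid-gray
    return np.clip(lin, 0.0, 1.0) ** (1.0 / 2.2)  # back to sRGB
```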
@@ -1,72 +0,0 @@
#version 300 es
precision highp float;

uniform sampler2D u_image0;
uniform vec2 u_resolution;
uniform int u_int0;     // Mode
uniform float u_float0; // Amount (0 to 100)

in vec2 v_texCoord;
out vec4 fragColor;

const int MODE_LINEAR = 0;
const int MODE_RADIAL = 1;
const int MODE_BARREL = 2;
const int MODE_SWIRL = 3;
const int MODE_DIAGONAL = 4;

const float AMOUNT_SCALE = 0.0005;
const float RADIAL_MULT = 4.0;
const float BARREL_MULT = 8.0;
const float INV_SQRT2 = 0.70710678118;

void main() {
    vec2 uv = v_texCoord;
    vec4 original = texture(u_image0, uv);

    float amount = u_float0 * AMOUNT_SCALE;

    if (amount < 0.000001) {
        fragColor = original;
        return;
    }

    // Aspect-corrected coordinates for circular effects
    float aspect = u_resolution.x / u_resolution.y;
    vec2 centered = uv - 0.5;
    vec2 corrected = vec2(centered.x * aspect, centered.y);
    float r = length(corrected);
    vec2 dir = r > 0.0001 ? corrected / r : vec2(0.0);
    vec2 offset = vec2(0.0);

    if (u_int0 == MODE_LINEAR) {
        // Horizontal shift (no aspect correction needed)
        offset = vec2(amount, 0.0);
    }
    else if (u_int0 == MODE_RADIAL) {
        // Outward from center, stronger at edges
        offset = dir * r * amount * RADIAL_MULT;
        offset.x /= aspect; // Convert back to UV space
    }
    else if (u_int0 == MODE_BARREL) {
        // Lens distortion simulation (r² falloff)
        offset = dir * r * r * amount * BARREL_MULT;
        offset.x /= aspect; // Convert back to UV space
    }
    else if (u_int0 == MODE_SWIRL) {
        // Perpendicular to radial (rotational aberration)
        vec2 perp = vec2(-dir.y, dir.x);
        offset = perp * r * amount * RADIAL_MULT;
        offset.x /= aspect; // Convert back to UV space
    }
    else if (u_int0 == MODE_DIAGONAL) {
        // 45° offset (no aspect correction needed)
        offset = vec2(amount, amount) * INV_SQRT2;
    }

    float red = texture(u_image0, uv + offset).r;
    float green = original.g;
    float blue = texture(u_image0, uv - offset).b;

    fragColor = vec4(red, green, blue, original.a);
}
@@ -1,78 +0,0 @@
#version 300 es
precision highp float;

uniform sampler2D u_image0;
uniform float u_float0; // temperature (-100 to 100)
uniform float u_float1; // tint (-100 to 100)
uniform float u_float2; // vibrance (-100 to 100)
uniform float u_float3; // saturation (-100 to 100)

in vec2 v_texCoord;
out vec4 fragColor;

const float INPUT_SCALE = 0.01;
const float TEMP_TINT_PRIMARY = 0.3;
const float TEMP_TINT_SECONDARY = 0.15;
const float VIBRANCE_BOOST = 2.0;
const float SATURATION_BOOST = 2.0;
const float SKIN_PROTECTION = 0.5;
const float EPSILON = 0.001;
const vec3 LUMA_WEIGHTS = vec3(0.299, 0.587, 0.114);

void main() {
    vec4 tex = texture(u_image0, v_texCoord);
    vec3 color = tex.rgb;

    // Scale inputs: -100/100 → -1/1
    float temperature = u_float0 * INPUT_SCALE;
    float tint = u_float1 * INPUT_SCALE;
    float vibrance = u_float2 * INPUT_SCALE;
    float saturation = u_float3 * INPUT_SCALE;

    // Temperature (warm/cool): positive = warm, negative = cool
    color.r += temperature * TEMP_TINT_PRIMARY;
    color.b -= temperature * TEMP_TINT_PRIMARY;

    // Tint (green/magenta): positive = green, negative = magenta
    color.g += tint * TEMP_TINT_PRIMARY;
    color.r -= tint * TEMP_TINT_SECONDARY;
    color.b -= tint * TEMP_TINT_SECONDARY;

    // Single clamp after temperature/tint
    color = clamp(color, 0.0, 1.0);

    // Vibrance with skin protection
    if (vibrance != 0.0) {
        float maxC = max(color.r, max(color.g, color.b));
        float minC = min(color.r, min(color.g, color.b));
        float sat = maxC - minC;
        float gray = dot(color, LUMA_WEIGHTS);

        if (vibrance < 0.0) {
            // Desaturate: -100 → gray
            color = mix(vec3(gray), color, 1.0 + vibrance);
        } else {
            // Boost less saturated colors more
            float vibranceAmt = vibrance * (1.0 - sat);

            // Branchless skin tone protection
            float isWarmTone = step(color.b, color.g) * step(color.g, color.r);
            float warmth = (color.r - color.b) / max(maxC, EPSILON);
            float skinTone = isWarmTone * warmth * sat * (1.0 - sat);
            vibranceAmt *= (1.0 - skinTone * SKIN_PROTECTION);

            color = mix(vec3(gray), color, 1.0 + vibranceAmt * VIBRANCE_BOOST);
        }
    }

    // Saturation
    if (saturation != 0.0) {
        float gray = dot(color, LUMA_WEIGHTS);
        float satMix = saturation < 0.0
            ? 1.0 + saturation                     // -100 → gray
            : 1.0 + saturation * SATURATION_BOOST; // +100 → 3x boost
        color = mix(vec3(gray), color, satMix);
    }

    fragColor = vec4(clamp(color, 0.0, 1.0), tex.a);
}
@@ -1,94 +0,0 @@
#version 300 es
precision highp float;

uniform sampler2D u_image0;
uniform float u_float0; // Blur radius (0–20, default ~5)
uniform float u_float1; // Edge threshold (0–100, default ~30)
uniform int u_int0;     // Step size (0/1 = every pixel, 2+ = skip pixels)

in vec2 v_texCoord;
out vec4 fragColor;

const int MAX_RADIUS = 20;
const float EPSILON = 0.0001;

// Perceptual luminance
float getLuminance(vec3 rgb) {
    return dot(rgb, vec3(0.299, 0.587, 0.114));
}

vec4 bilateralFilter(vec2 uv, vec2 texelSize, int radius,
                     float sigmaSpatial, float sigmaColor)
{
    vec4 center = texture(u_image0, uv);
    vec3 centerRGB = center.rgb;

    float invSpatial2 = -0.5 / (sigmaSpatial * sigmaSpatial);
    float invColor2 = -0.5 / (sigmaColor * sigmaColor + EPSILON);

    vec3 sumRGB = vec3(0.0);
    float sumWeight = 0.0;

    int step = max(u_int0, 1);
    float radius2 = float(radius * radius);

    for (int dy = -MAX_RADIUS; dy <= MAX_RADIUS; dy++) {
        if (dy < -radius || dy > radius) continue;
        if (abs(dy) % step != 0) continue;

        for (int dx = -MAX_RADIUS; dx <= MAX_RADIUS; dx++) {
            if (dx < -radius || dx > radius) continue;
            if (abs(dx) % step != 0) continue;

            vec2 offset = vec2(float(dx), float(dy));
            float dist2 = dot(offset, offset);
            if (dist2 > radius2) continue;

            vec3 sampleRGB = texture(u_image0, uv + offset * texelSize).rgb;

            // Spatial Gaussian
            float spatialWeight = exp(dist2 * invSpatial2);

            // Perceptual color distance (weighted RGB)
            vec3 diff = sampleRGB - centerRGB;
            float colorDist = dot(diff * diff, vec3(0.299, 0.587, 0.114));
            float colorWeight = exp(colorDist * invColor2);

            float w = spatialWeight * colorWeight;
            sumRGB += sampleRGB * w;
            sumWeight += w;
        }
    }

    vec3 resultRGB = sumRGB / max(sumWeight, EPSILON);
    return vec4(resultRGB, center.a); // preserve center alpha
}

void main() {
    vec2 texelSize = 1.0 / vec2(textureSize(u_image0, 0));

    float radiusF = clamp(u_float0, 0.0, float(MAX_RADIUS));
    int radius = int(radiusF + 0.5);

    if (radius == 0) {
        fragColor = texture(u_image0, v_texCoord);
        return;
    }

    // Edge threshold → color sigma
    // Squared curve for better low-end control
    float t = clamp(u_float1, 0.0, 100.0) / 100.0;
    t *= t;
    float sigmaColor = mix(0.01, 0.5, t);

    // Spatial sigma tied to radius
    float sigmaSpatial = max(radiusF * 0.75, 0.5);

    fragColor = bilateralFilter(
        v_texCoord,
        texelSize,
        radius,
        sigmaSpatial,
        sigmaColor
    );
}
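Aside: the bilateral weight in the deleted shader is the product of a spatial Gaussian and a luma-weighted color-distance Gaussian. A per-sample NumPy sketch of that weight; the helper name is hypothetical:

```python
import numpy as np

LUMA = np.array([0.299, 0.587, 0.114])

def bilateral_weight(offset_px, sample_rgb, center_rgb, sigma_spatial, sigma_color):
    """One sample's weight: spatial Gaussian times color-distance Gaussian,
    with the color distance luma-weighted as in the shader."""
    offset_px = np.asarray(offset_px, dtype=float)
    dist2 = float(offset_px @ offset_px)
    spatial = np.exp(-0.5 * dist2 / sigma_spatial**2)
    diff = np.asarray(sample_rgb, dtype=float) - np.asarray(center_rgb, dtype=float)
    color_dist = float((diff * diff) @ LUMA)
    color = np.exp(-0.5 * color_dist / (sigma_color**2 + 1e-4))
    return spatial * color
```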
@@ -1,124 +0,0 @@
#version 300 es
precision highp float;

uniform sampler2D u_image0;
uniform vec2 u_resolution;
uniform float u_float0; // grain amount [0.0 – 1.0] typical: 0.2–0.8
uniform float u_float1; // grain size [0.3 – 3.0] lower = finer grain
uniform float u_float2; // color amount [0.0 – 1.0] 0 = monochrome, 1 = RGB grain
uniform float u_float3; // luminance bias [0.0 – 1.0] 0 = uniform, 1 = shadows only
uniform int u_int0;     // noise mode [0 or 1] 0 = smooth, 1 = grainy

in vec2 v_texCoord;
layout(location = 0) out vec4 fragColor0;

// High-quality integer hash (pcg-like)
uint pcg(uint v) {
    uint state = v * 747796405u + 2891336453u;
    uint word = ((state >> ((state >> 28u) + 4u)) ^ state) * 277803737u;
    return (word >> 22u) ^ word;
}

// 2D -> 1D hash input
uint hash2d(uvec2 p) {
    return pcg(p.x + pcg(p.y));
}

// Hash to float [0, 1]
float hashf(uvec2 p) {
    return float(hash2d(p)) / float(0xffffffffu);
}

// Hash to float with offset (for RGB channels)
float hashf(uvec2 p, uint offset) {
    return float(pcg(hash2d(p) + offset)) / float(0xffffffffu);
}

// Convert uniform [0,1] to roughly Gaussian distribution
// Using simple approximation: average of multiple samples
float toGaussian(uvec2 p) {
    float sum = hashf(p, 0u) + hashf(p, 1u) + hashf(p, 2u) + hashf(p, 3u);
    return (sum - 2.0) * 0.7; // Centered, scaled
}

float toGaussian(uvec2 p, uint offset) {
    float sum = hashf(p, offset) + hashf(p, offset + 1u)
              + hashf(p, offset + 2u) + hashf(p, offset + 3u);
    return (sum - 2.0) * 0.7;
}

// Smooth noise with better interpolation
float smoothNoise(vec2 p) {
    vec2 i = floor(p);
    vec2 f = fract(p);

    // Quintic interpolation (less banding than cubic)
    f = f * f * f * (f * (f * 6.0 - 15.0) + 10.0);

    uvec2 ui = uvec2(i);
    float a = toGaussian(ui);
    float b = toGaussian(ui + uvec2(1u, 0u));
    float c = toGaussian(ui + uvec2(0u, 1u));
    float d = toGaussian(ui + uvec2(1u, 1u));

    return mix(mix(a, b, f.x), mix(c, d, f.x), f.y);
}

float smoothNoise(vec2 p, uint offset) {
    vec2 i = floor(p);
    vec2 f = fract(p);

    f = f * f * f * (f * (f * 6.0 - 15.0) + 10.0);

    uvec2 ui = uvec2(i);
    float a = toGaussian(ui, offset);
    float b = toGaussian(ui + uvec2(1u, 0u), offset);
    float c = toGaussian(ui + uvec2(0u, 1u), offset);
    float d = toGaussian(ui + uvec2(1u, 1u), offset);

    return mix(mix(a, b, f.x), mix(c, d, f.x), f.y);
}

void main() {
    vec4 color = texture(u_image0, v_texCoord);

    // Luminance (Rec.709)
    float luma = dot(color.rgb, vec3(0.2126, 0.7152, 0.0722));

    // Grain UV (resolution-independent)
    vec2 grainUV = v_texCoord * u_resolution / max(u_float1, 0.01);
    uvec2 grainPixel = uvec2(grainUV);

    float g;
    vec3 grainRGB;

    if (u_int0 == 1) {
        // Grainy mode: pure hash noise (no interpolation = no banding)
        g = toGaussian(grainPixel);
        grainRGB = vec3(
            toGaussian(grainPixel, 100u),
            toGaussian(grainPixel, 200u),
            toGaussian(grainPixel, 300u)
        );
    } else {
        // Smooth mode: interpolated with quintic curve
        g = smoothNoise(grainUV);
        grainRGB = vec3(
            smoothNoise(grainUV, 100u),
            smoothNoise(grainUV, 200u),
            smoothNoise(grainUV, 300u)
        );
    }

    // Luminance weighting (less grain in highlights)
    float lumWeight = mix(1.0, 1.0 - luma, clamp(u_float3, 0.0, 1.0));

    // Strength
    float strength = u_float0 * 0.15;

    // Color vs monochrome grain
    vec3 grainColor = mix(vec3(g), grainRGB, clamp(u_float2, 0.0, 1.0));

    color.rgb += grainColor * strength * lumWeight;
    fragColor0 = vec4(clamp(color.rgb, 0.0, 1.0), color.a);
}
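Aside: the grain shader approximates Gaussian noise by summing four hashed uniforms (an Irwin-Hall-style approximation). A Python sketch of the same hash chain, mirroring the shader's constants; the function names are the shader's, but this standalone port is illustrative:

```python
def pcg(v: int) -> int:
    """PCG-style 32-bit integer hash, mirroring the shader."""
    v &= 0xFFFFFFFF
    state = (v * 747796405 + 2891336453) & 0xFFFFFFFF
    word = (((state >> ((state >> 28) + 4)) ^ state) * 277803737) & 0xFFFFFFFF
    return ((word >> 22) ^ word) & 0xFFFFFFFF

def hashf(x: int, y: int, offset: int = 0) -> float:
    """Hash a 2D integer coordinate (plus channel offset) to a float in [0, 1]."""
    return pcg(pcg(x + pcg(y)) + offset) / 0xFFFFFFFF

def to_gaussian(x: int, y: int, offset: int = 0) -> float:
    """Sum of four uniforms is roughly normal (Irwin-Hall); the shader
    centers it at 0 and scales by 0.7."""
    s = sum(hashf(x, y, offset + k) for k in range(4))
    return (s - 2.0) * 0.7
```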
@@ -1,133 +0,0 @@
#version 300 es
precision mediump float;

uniform sampler2D u_image0;
uniform vec2 u_resolution;
uniform int u_int0;     // Blend mode
uniform int u_int1;     // Color tint
uniform float u_float0; // Intensity
uniform float u_float1; // Radius
uniform float u_float2; // Threshold

in vec2 v_texCoord;
out vec4 fragColor;

const int BLEND_ADD = 0;
const int BLEND_SCREEN = 1;
const int BLEND_SOFT = 2;
const int BLEND_OVERLAY = 3;
const int BLEND_LIGHTEN = 4;

const float GOLDEN_ANGLE = 2.39996323;
const int MAX_SAMPLES = 48;
const vec3 LUMA = vec3(0.299, 0.587, 0.114);

float hash(vec2 p) {
    p = fract(p * vec2(123.34, 456.21));
    p += dot(p, p + 45.32);
    return fract(p.x * p.y);
}

vec3 hexToRgb(int h) {
    return vec3(
        float((h >> 16) & 255),
        float((h >> 8) & 255),
        float(h & 255)
    ) * (1.0 / 255.0);
}

vec3 blend(vec3 base, vec3 glow, int mode) {
    if (mode == BLEND_SCREEN) {
        return 1.0 - (1.0 - base) * (1.0 - glow);
    }
    if (mode == BLEND_SOFT) {
        return mix(
            base - (1.0 - 2.0 * glow) * base * (1.0 - base),
            base + (2.0 * glow - 1.0) * (sqrt(base) - base),
            step(0.5, glow)
        );
    }
    if (mode == BLEND_OVERLAY) {
        return mix(
            2.0 * base * glow,
            1.0 - 2.0 * (1.0 - base) * (1.0 - glow),
            step(0.5, base)
        );
    }
    if (mode == BLEND_LIGHTEN) {
        return max(base, glow);
    }
    return base + glow;
}

void main() {
    vec4 original = texture(u_image0, v_texCoord);

    float intensity = u_float0 * 0.05;
    float radius = u_float1 * u_float1 * 0.012;

    if (intensity < 0.001 || radius < 0.1) {
        fragColor = original;
        return;
    }

    float threshold = 1.0 - u_float2 * 0.01;
    float t0 = threshold - 0.15;
    float t1 = threshold + 0.15;

    vec2 texelSize = 1.0 / u_resolution;
    float radius2 = radius * radius;

    float sampleScale = clamp(radius * 0.75, 0.35, 1.0);
    int samples = int(float(MAX_SAMPLES) * sampleScale);

    float noise = hash(gl_FragCoord.xy);
    float angleOffset = noise * GOLDEN_ANGLE;
    float radiusJitter = 0.85 + noise * 0.3;

    float ca = cos(GOLDEN_ANGLE);
    float sa = sin(GOLDEN_ANGLE);
    vec2 dir = vec2(cos(angleOffset), sin(angleOffset));

    vec3 glow = vec3(0.0);
    float totalWeight = 0.0;

    // Center tap
    float centerMask = smoothstep(t0, t1, dot(original.rgb, LUMA));
    glow += original.rgb * centerMask * 2.0;
    totalWeight += 2.0;

    for (int i = 1; i < MAX_SAMPLES; i++) {
        if (i >= samples) break;

        float fi = float(i);
        float dist = sqrt(fi / float(samples)) * radius * radiusJitter;

        vec2 offset = dir * dist * texelSize;
        vec3 c = texture(u_image0, v_texCoord + offset).rgb;
        float mask = smoothstep(t0, t1, dot(c, LUMA));

        float w = 1.0 - (dist * dist) / (radius2 * 1.5);
        w = max(w, 0.0);
        w *= w;

        glow += c * mask * w;
        totalWeight += w;

        dir = vec2(
            dir.x * ca - dir.y * sa,
            dir.x * sa + dir.y * ca
        );
    }

    glow *= intensity / max(totalWeight, 0.001);

    if (u_int1 > 0) {
        glow *= hexToRgb(u_int1);
    }

    vec3 result = blend(original.rgb, glow, u_int0);
    result += (noise - 0.5) * (1.0 / 255.0);

    fragColor = vec4(clamp(result, 0.0, 1.0), original.a);
}
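Aside: the glow pass gathers samples along a golden-angle spiral, with distance growing as sqrt(i/n) so samples cover the disc with near-uniform density. The offsets it visits can be sketched directly in Python (hypothetical helper; the per-pixel jitter the shader adds is omitted):

```python
import math

GOLDEN_ANGLE = 2.39996323  # as in the shader

def spiral_offsets(samples: int, radius: float):
    """Yield (dx, dy) offsets on a golden-angle spiral of given radius."""
    angle = 0.0
    for i in range(1, samples):
        dist = math.sqrt(i / samples) * radius  # sqrt spacing: uniform disc density
        yield (dist * math.cos(angle), dist * math.sin(angle))
        angle += GOLDEN_ANGLE
```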
@@ -1,222 +0,0 @@
#version 300 es
precision highp float;

uniform sampler2D u_image0;
uniform int u_int0;     // Mode: 0=Master, 1=Reds, 2=Yellows, 3=Greens, 4=Cyans, 5=Blues, 6=Magentas, 7=Colorize
uniform int u_int1;     // Color Space: 0=HSL, 1=HSB/HSV
uniform float u_float0; // Hue (-180 to 180)
uniform float u_float1; // Saturation (-100 to 100)
uniform float u_float2; // Lightness/Brightness (-100 to 100)
uniform float u_float3; // Overlap (0 to 100) - feathering between adjacent color ranges

in vec2 v_texCoord;
out vec4 fragColor;

// Color range modes
const int MODE_MASTER = 0;
const int MODE_RED = 1;
const int MODE_YELLOW = 2;
const int MODE_GREEN = 3;
const int MODE_CYAN = 4;
const int MODE_BLUE = 5;
const int MODE_MAGENTA = 6;
const int MODE_COLORIZE = 7;

// Color space modes
const int COLORSPACE_HSL = 0;
const int COLORSPACE_HSB = 1;

const float EPSILON = 0.0001;

//=============================================================================
// RGB <-> HSL Conversions
//=============================================================================

vec3 rgb2hsl(vec3 c) {
    float maxC = max(max(c.r, c.g), c.b);
    float minC = min(min(c.r, c.g), c.b);
    float delta = maxC - minC;

    float h = 0.0;
    float s = 0.0;
    float l = (maxC + minC) * 0.5;

    if (delta > EPSILON) {
        s = l < 0.5
            ? delta / (maxC + minC)
            : delta / (2.0 - maxC - minC);

        if (maxC == c.r) {
            h = (c.g - c.b) / delta + (c.g < c.b ? 6.0 : 0.0);
        } else if (maxC == c.g) {
            h = (c.b - c.r) / delta + 2.0;
        } else {
            h = (c.r - c.g) / delta + 4.0;
        }
        h /= 6.0;
    }

    return vec3(h, s, l);
}

float hue2rgb(float p, float q, float t) {
    t = fract(t);
    if (t < 1.0/6.0) return p + (q - p) * 6.0 * t;
    if (t < 0.5) return q;
    if (t < 2.0/3.0) return p + (q - p) * (2.0/3.0 - t) * 6.0;
    return p;
}

vec3 hsl2rgb(vec3 hsl) {
    if (hsl.y < EPSILON) return vec3(hsl.z);

    float q = hsl.z < 0.5
        ? hsl.z * (1.0 + hsl.y)
        : hsl.z + hsl.y - hsl.z * hsl.y;
    float p = 2.0 * hsl.z - q;

    return vec3(
        hue2rgb(p, q, hsl.x + 1.0/3.0),
        hue2rgb(p, q, hsl.x),
        hue2rgb(p, q, hsl.x - 1.0/3.0)
    );
}

vec3 rgb2hsb(vec3 c) {
    float maxC = max(max(c.r, c.g), c.b);
    float minC = min(min(c.r, c.g), c.b);
    float delta = maxC - minC;

    float h = 0.0;
    float s = (maxC > EPSILON) ? delta / maxC : 0.0;
    float b = maxC;

    if (delta > EPSILON) {
        if (maxC == c.r) {
            h = (c.g - c.b) / delta + (c.g < c.b ? 6.0 : 0.0);
        } else if (maxC == c.g) {
            h = (c.b - c.r) / delta + 2.0;
        } else {
            h = (c.r - c.g) / delta + 4.0;
        }
        h /= 6.0;
    }

    return vec3(h, s, b);
}

vec3 hsb2rgb(vec3 hsb) {
    vec3 rgb = clamp(abs(mod(hsb.x * 6.0 + vec3(0.0, 4.0, 2.0), 6.0) - 3.0) - 1.0, 0.0, 1.0);
    return hsb.z * mix(vec3(1.0), rgb, hsb.y);
}

//=============================================================================
// Color Range Weight Calculation
//=============================================================================

float hueDistance(float a, float b) {
    float d = abs(a - b);
    return min(d, 1.0 - d);
}

float getHueWeight(float hue, float center, float overlap) {
    float baseWidth = 1.0 / 6.0;
    float feather = baseWidth * overlap;

    float d = hueDistance(hue, center);

    float inner = baseWidth * 0.5;
    float outer = inner + feather;

    return 1.0 - smoothstep(inner, outer, d);
}

float getModeWeight(float hue, int mode, float overlap) {
    if (mode == MODE_MASTER || mode == MODE_COLORIZE) return 1.0;

    if (mode == MODE_RED) {
        return max(
            getHueWeight(hue, 0.0, overlap),
            getHueWeight(hue, 1.0, overlap)
        );
    }

    float center = float(mode - 1) / 6.0;
    return getHueWeight(hue, center, overlap);
}

//=============================================================================
// Adjustment Functions
//=============================================================================

float adjustLightness(float l, float amount) {
    return amount > 0.0
        ? l + (1.0 - l) * amount
        : l + l * amount;
}

float adjustBrightness(float b, float amount) {
    return clamp(b + amount, 0.0, 1.0);
}

float adjustSaturation(float s, float amount) {
    return amount > 0.0
        ? s + (1.0 - s) * amount
        : s + s * amount;
}

vec3 colorize(vec3 rgb, float hue, float sat, float light) {
    float lum = dot(rgb, vec3(0.299, 0.587, 0.114));
    float l = adjustLightness(lum, light);

    vec3 hsl = vec3(fract(hue), clamp(sat, 0.0, 1.0), clamp(l, 0.0, 1.0));
    return hsl2rgb(hsl);
}

//=============================================================================
// Main
//=============================================================================

void main() {
    vec4 original = texture(u_image0, v_texCoord);

    float hueShift = u_float0 / 360.0;    // -180..180 -> -0.5..0.5
    float satAmount = u_float1 / 100.0;   // -100..100 -> -1..1
    float lightAmount = u_float2 / 100.0; // -100..100 -> -1..1
    float overlap = u_float3 / 100.0;     // 0..100 -> 0..1

    vec3 result;

    if (u_int0 == MODE_COLORIZE) {
        result = colorize(original.rgb, hueShift, satAmount, lightAmount);
        fragColor = vec4(result, original.a);
        return;
    }

    vec3 hsx = (u_int1 == COLORSPACE_HSL)
        ? rgb2hsl(original.rgb)
        : rgb2hsb(original.rgb);

    float weight = getModeWeight(hsx.x, u_int0, overlap);

    if (u_int0 != MODE_MASTER && hsx.y < EPSILON) {
        weight = 0.0;
    }

    if (weight > EPSILON) {
        float h = fract(hsx.x + hueShift * weight);
        float s = clamp(adjustSaturation(hsx.y, satAmount * weight), 0.0, 1.0);
        float v = (u_int1 == COLORSPACE_HSL)
            ? clamp(adjustLightness(hsx.z, lightAmount * weight), 0.0, 1.0)
            : clamp(adjustBrightness(hsx.z, lightAmount * weight), 0.0, 1.0);

        vec3 adjusted = vec3(h, s, v);
        result = (u_int1 == COLORSPACE_HSL)
            ? hsl2rgb(adjusted)
            : hsb2rgb(adjusted);
    } else {
        result = original.rgb;
    }

    fragColor = vec4(result, original.a);
}
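Aside: the shader's rgb2hsl follows the textbook HSL definition (hue in [0,1)), which can be spot-checked against Python's standard-library colorsys; note colorsys uses HLS ordering:

```python
import colorsys

def rgb2hsl(r: float, g: float, b: float):
    """Same H/S/L convention the shader uses, via colorsys (which returns h, l, s)."""
    h, l, s = colorsys.rgb_to_hls(r, g, b)
    return h, s, l

# Spot-check a saturated orange: hue should be about 30/360.
print(rgb2hsl(1.0, 0.5, 0.0))  # ~ (0.0833, 1.0, 0.5)
```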
@@ -1,111 +0,0 @@
#version 300 es
#pragma passes 2
precision highp float;

// Blur type constants
const int BLUR_GAUSSIAN = 0;
const int BLUR_BOX = 1;
const int BLUR_RADIAL = 2;

// Radial blur config
const int RADIAL_SAMPLES = 12;
const float RADIAL_STRENGTH = 0.0003;

uniform sampler2D u_image0;
uniform vec2 u_resolution;
uniform int u_int0;     // Blur type (BLUR_GAUSSIAN, BLUR_BOX, BLUR_RADIAL)
uniform float u_float0; // Blur radius/amount
uniform int u_pass;     // Pass index (0 = horizontal, 1 = vertical)

in vec2 v_texCoord;
layout(location = 0) out vec4 fragColor0;

float gaussian(float x, float sigma) {
    return exp(-(x * x) / (2.0 * sigma * sigma));
}

void main() {
    vec2 texelSize = 1.0 / u_resolution;
    float radius = max(u_float0, 0.0);

    // Radial (angular) blur - single pass, doesn't use separable
    if (u_int0 == BLUR_RADIAL) {
        // Only execute on first pass
        if (u_pass > 0) {
            fragColor0 = texture(u_image0, v_texCoord);
            return;
        }

        vec2 center = vec2(0.5);
        vec2 dir = v_texCoord - center;
        float dist = length(dir);

        if (dist < 1e-4) {
            fragColor0 = texture(u_image0, v_texCoord);
            return;
        }

        vec4 sum = vec4(0.0);
        float totalWeight = 0.0;
        float angleStep = radius * RADIAL_STRENGTH;

        dir /= dist;

        float cosStep = cos(angleStep);
        float sinStep = sin(angleStep);

        float negAngle = -float(RADIAL_SAMPLES) * angleStep;
        vec2 rotDir = vec2(
            dir.x * cos(negAngle) - dir.y * sin(negAngle),
            dir.x * sin(negAngle) + dir.y * cos(negAngle)
        );

        for (int i = -RADIAL_SAMPLES; i <= RADIAL_SAMPLES; i++) {
            vec2 uv = center + rotDir * dist;
            float w = 1.0 - abs(float(i)) / float(RADIAL_SAMPLES);
            sum += texture(u_image0, uv) * w;
            totalWeight += w;

            rotDir = vec2(
                rotDir.x * cosStep - rotDir.y * sinStep,
                rotDir.x * sinStep + rotDir.y * cosStep
            );
        }

        fragColor0 = sum / max(totalWeight, 0.001);
        return;
    }

    // Separable Gaussian / Box blur
    int samples = int(ceil(radius));

    if (samples == 0) {
        fragColor0 = texture(u_image0, v_texCoord);
        return;
    }

    // Direction: pass 0 = horizontal, pass 1 = vertical
    vec2 dir = (u_pass == 0) ? vec2(1.0, 0.0) : vec2(0.0, 1.0);

    vec4 color = vec4(0.0);
    float totalWeight = 0.0;
    float sigma = radius / 2.0;

    for (int i = -samples; i <= samples; i++) {
        vec2 offset = dir * float(i) * texelSize;
        vec4 sample_color = texture(u_image0, v_texCoord + offset);

        float weight;
        if (u_int0 == BLUR_GAUSSIAN) {
            weight = gaussian(float(i), sigma);
        } else {
            // BLUR_BOX
            weight = 1.0;
        }

        color += sample_color * weight;
        totalWeight += weight;
    }

    fragColor0 = color / totalWeight;
}
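Aside: the `#pragma passes 2` blur relies on the Gaussian kernel being separable: a horizontal 1D pass followed by a vertical 1D pass equals the full 2D convolution. A NumPy sketch on a 2D (grayscale) array, with sigma = radius/2 as in the shader; helper names are hypothetical:

```python
import numpy as np

def gaussian_kernel(radius: int, sigma: float) -> np.ndarray:
    """Normalized 1D Gaussian kernel of length 2*radius + 1."""
    x = np.arange(-radius, radius + 1, dtype=np.float64)
    k = np.exp(-(x * x) / (2.0 * sigma * sigma))
    return k / k.sum()

def separable_blur(img: np.ndarray, radius: float) -> np.ndarray:
    """Two 1D passes (horizontal then vertical), the same decomposition
    the shader performs across its two GPU passes."""
    samples = int(np.ceil(radius))
    if samples == 0:
        return img
    k = gaussian_kernel(samples, radius / 2.0)
    out = np.apply_along_axis(lambda r: np.convolve(r, k, mode="same"), 1, img)  # horizontal
    out = np.apply_along_axis(lambda c: np.convolve(c, k, mode="same"), 0, out)  # vertical
    return out
```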
@@ -1,19 +0,0 @@
#version 300 es
precision highp float;

uniform sampler2D u_image0;

in vec2 v_texCoord;
layout(location = 0) out vec4 fragColor0;
layout(location = 1) out vec4 fragColor1;
layout(location = 2) out vec4 fragColor2;
layout(location = 3) out vec4 fragColor3;

void main() {
    vec4 color = texture(u_image0, v_texCoord);
    // Output each channel as grayscale to separate render targets
    fragColor0 = vec4(vec3(color.r), 1.0); // Red channel
    fragColor1 = vec4(vec3(color.g), 1.0); // Green channel
    fragColor2 = vec4(vec3(color.b), 1.0); // Blue channel
    fragColor3 = vec4(vec3(color.a), 1.0); // Alpha channel
}
@@ -1,71 +0,0 @@
#version 300 es
precision highp float;

// Levels Adjustment
// u_int0:   channel (0=RGB, 1=R, 2=G, 3=B)  default: 0
// u_float0: input black (0-255)             default: 0
// u_float1: input white (0-255)             default: 255
// u_float2: gamma (0.01-9.99)               default: 1.0
// u_float3: output black (0-255)            default: 0
// u_float4: output white (0-255)            default: 255

uniform sampler2D u_image0;
uniform int u_int0;
uniform float u_float0;
uniform float u_float1;
uniform float u_float2;
uniform float u_float3;
uniform float u_float4;

in vec2 v_texCoord;
out vec4 fragColor;

vec3 applyLevels(vec3 color, float inBlack, float inWhite, float gamma, float outBlack, float outWhite) {
    float inRange = max(inWhite - inBlack, 0.0001);
    vec3 result = clamp((color - inBlack) / inRange, 0.0, 1.0);
    result = pow(result, vec3(1.0 / gamma));
    result = mix(vec3(outBlack), vec3(outWhite), result);
    return result;
}

float applySingleChannel(float value, float inBlack, float inWhite, float gamma, float outBlack, float outWhite) {
    float inRange = max(inWhite - inBlack, 0.0001);
    float result = clamp((value - inBlack) / inRange, 0.0, 1.0);
    result = pow(result, 1.0 / gamma);
    result = mix(outBlack, outWhite, result);
    return result;
}

void main() {
    vec4 texColor = texture(u_image0, v_texCoord);
    vec3 color = texColor.rgb;

    float inBlack = u_float0 / 255.0;
    float inWhite = u_float1 / 255.0;
    float gamma = u_float2;
    float outBlack = u_float3 / 255.0;
    float outWhite = u_float4 / 255.0;

    vec3 result;

    if (u_int0 == 0) {
        result = applyLevels(color, inBlack, inWhite, gamma, outBlack, outWhite);
    }
    else if (u_int0 == 1) {
        result = color;
        result.r = applySingleChannel(color.r, inBlack, inWhite, gamma, outBlack, outWhite);
    }
    else if (u_int0 == 2) {
        result = color;
        result.g = applySingleChannel(color.g, inBlack, inWhite, gamma, outBlack, outWhite);
    }
    else if (u_int0 == 3) {
        result = color;
        result.b = applySingleChannel(color.b, inBlack, inWhite, gamma, outBlack, outWhite);
    }
    else {
        result = color;
    }

    fragColor = vec4(result, texColor.a);
}
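Aside: the levels transfer function above is: normalize by the input range, apply 1/gamma, then remap to the output range. As a NumPy sketch (hypothetical helper; slider ranges as in the shader comments):

```python
import numpy as np

def apply_levels(x, in_black=0.0, in_white=255.0, gamma=1.0, out_black=0.0, out_white=255.0):
    """Photoshop-style levels on values in [0, 1]; sliders use 0-255 ranges."""
    ib, iw = in_black / 255.0, in_white / 255.0
    ob, ow = out_black / 255.0, out_white / 255.0
    t = np.clip((np.asarray(x, dtype=np.float64) - ib) / max(iw - ib, 1e-4), 0.0, 1.0)
    t = t ** (1.0 / gamma)       # gamma remaps midtones
    return ob + (ow - ob) * t    # remap to output range (the shader's mix)
```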
@@ -1,28 +0,0 @@
# GLSL Shader Sources

This folder contains the GLSL fragment shaders extracted from blueprint JSON files for easier editing and version control.

## File Naming Convention

`{Blueprint_Name}_{node_id}.frag`

- **Blueprint_Name**: The JSON filename with spaces/special chars replaced by underscores
- **node_id**: The GLSLShader node ID within the subgraph

## Usage

```bash
# Extract shaders from blueprint JSONs to this folder
python update_blueprints.py extract

# Patch edited shaders back into blueprint JSONs
python update_blueprints.py patch
```

## Workflow

1. Run `extract` to pull current shaders from JSONs
2. Edit `.frag` files
3. Run `patch` to update the blueprint JSONs
4. Test
5. Commit both `.frag` files and updated JSONs
@@ -1,28 +0,0 @@
#version 300 es
precision highp float;

uniform sampler2D u_image0;
uniform vec2 u_resolution;
uniform float u_float0; // strength [0.0 – 2.0] typical: 0.3–1.0

in vec2 v_texCoord;
layout(location = 0) out vec4 fragColor0;

void main() {
    vec2 texel = 1.0 / u_resolution;

    // Sample center and neighbors
    vec4 center = texture(u_image0, v_texCoord);
    vec4 top    = texture(u_image0, v_texCoord + vec2( 0.0, -texel.y));
    vec4 bottom = texture(u_image0, v_texCoord + vec2( 0.0,  texel.y));
    vec4 left   = texture(u_image0, v_texCoord + vec2(-texel.x, 0.0));
    vec4 right  = texture(u_image0, v_texCoord + vec2( texel.x, 0.0));

    // Edge enhancement (Laplacian)
    vec4 edges = center * 4.0 - top - bottom - left - right;

    // Add edges back scaled by strength
    vec4 sharpened = center + edges * u_float0;

    fragColor0 = vec4(clamp(sharpened.rgb, 0.0, 1.0), center.a);
}
@@ -1,61 +0,0 @@
#version 300 es
precision highp float;

uniform sampler2D u_image0;
uniform vec2 u_resolution;
uniform float u_float0; // amount [0.0 - 3.0] typical: 0.5-1.5
uniform float u_float1; // radius [0.5 - 10.0] blur radius in pixels
uniform float u_float2; // threshold [0.0 - 0.1] min difference to sharpen

in vec2 v_texCoord;
layout(location = 0) out vec4 fragColor0;

float gaussian(float x, float sigma) {
    return exp(-(x * x) / (2.0 * sigma * sigma));
}

float getLuminance(vec3 color) {
    return dot(color, vec3(0.2126, 0.7152, 0.0722));
}

void main() {
    vec2 texel = 1.0 / u_resolution;
    float radius = max(u_float1, 0.5);
    float amount = u_float0;
    float threshold = u_float2;

    vec4 original = texture(u_image0, v_texCoord);

    // Gaussian blur for the "unsharp" mask
    int samples = int(ceil(radius));
    float sigma = radius / 2.0;

    vec4 blurred = vec4(0.0);
    float totalWeight = 0.0;

    for (int x = -samples; x <= samples; x++) {
        for (int y = -samples; y <= samples; y++) {
            vec2 offset = vec2(float(x), float(y)) * texel;
            vec4 sample_color = texture(u_image0, v_texCoord + offset);

            float dist = length(vec2(float(x), float(y)));
            float weight = gaussian(dist, sigma);
            blurred += sample_color * weight;
            totalWeight += weight;
        }
    }
    blurred /= totalWeight;

    // Unsharp mask = original - blurred
    vec3 mask = original.rgb - blurred.rgb;

    // Luminance-based threshold with smooth falloff
    float lumaDelta = abs(getLuminance(original.rgb) - getLuminance(blurred.rgb));
    float thresholdScale = smoothstep(0.0, threshold, lumaDelta);
    mask *= thresholdScale;

    // Sharpen: original + mask * amount
    vec3 sharpened = original.rgb + mask * amount;

    fragColor0 = vec4(clamp(sharpened, 0.0, 1.0), original.a);
}
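Aside: unsharp masking adds back the difference between the image and its blur, gated here by a smooth luminance-difference threshold. A NumPy sketch of that final combine (hypothetical helper; expects HxWx3 float arrays in [0, 1], with the blur computed separately):

```python
import numpy as np

def unsharp_combine(img, blurred, amount=1.0, threshold=0.02):
    """Combine step of unsharp masking: mask = img - blurred, gated by a
    smoothstep over the luminance delta, then added back scaled by amount."""
    luma = np.array([0.2126, 0.7152, 0.0722])
    mask = img - blurred
    delta = np.abs((img - blurred) @ luma)           # per-pixel luminance delta
    t = np.clip(delta / max(threshold, 1e-6), 0.0, 1.0)
    gate = t * t * (3.0 - 2.0 * t)                   # smoothstep(0, threshold, delta)
    return np.clip(img + mask * gate[..., None] * amount, 0.0, 1.0)
```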
@@ -1,159 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
Shader Blueprint Updater
|
|
||||||
|
|
||||||
Syncs GLSL shader files between this folder and blueprint JSON files.
|
|
||||||
|
|
||||||
File naming convention:
|
|
||||||
{Blueprint Name}_{node_id}.frag
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
python update_blueprints.py extract # Extract shaders from JSONs to here
|
|
||||||
python update_blueprints.py patch # Patch shaders back into JSONs
|
|
||||||
python update_blueprints.py # Same as patch (default)
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
-import sys
-import re
-from pathlib import Path
-
-logging.basicConfig(level=logging.INFO, format='%(message)s')
-logger = logging.getLogger(__name__)
-
-GLSL_DIR = Path(__file__).parent
-BLUEPRINTS_DIR = GLSL_DIR.parent
-
-
-def get_blueprint_files():
-    """Get all blueprint JSON files."""
-    return sorted(BLUEPRINTS_DIR.glob("*.json"))
-
-
-def sanitize_filename(name):
-    """Convert blueprint name to safe filename."""
-    return re.sub(r'[^\w\-]', '_', name)
-
-
-def extract_shaders():
-    """Extract all shaders from blueprint JSONs to this folder."""
-    extracted = 0
-    for json_path in get_blueprint_files():
-        blueprint_name = json_path.stem
-
-        try:
-            with open(json_path, 'r') as f:
-                data = json.load(f)
-        except (json.JSONDecodeError, IOError) as e:
-            logger.warning("Skipping %s: %s", json_path.name, e)
-            continue
-
-        # Find GLSLShader nodes in subgraphs
-        for subgraph in data.get('definitions', {}).get('subgraphs', []):
-            for node in subgraph.get('nodes', []):
-                if node.get('type') == 'GLSLShader':
-                    node_id = node.get('id')
-                    widgets = node.get('widgets_values', [])
-
-                    # Find shader code (first string that looks like GLSL)
-                    for widget in widgets:
-                        if isinstance(widget, str) and widget.startswith('#version'):
-                            safe_name = sanitize_filename(blueprint_name)
-                            frag_name = f"{safe_name}_{node_id}.frag"
-                            frag_path = GLSL_DIR / frag_name
-
-                            with open(frag_path, 'w') as f:
-                                f.write(widget)
-
-                            logger.info(" Extracted: %s", frag_name)
-                            extracted += 1
-                            break
-
-    logger.info("\nExtracted %d shader(s)", extracted)
-
-
-def patch_shaders():
-    """Patch shaders from this folder back into blueprint JSONs."""
-    # Build lookup: blueprint_name -> [(node_id, shader_code), ...]
-    shader_updates = {}
-
-    for frag_path in sorted(GLSL_DIR.glob("*.frag")):
-        # Parse filename: {blueprint_name}_{node_id}.frag
-        parts = frag_path.stem.rsplit('_', 1)
-        if len(parts) != 2:
-            logger.warning("Skipping %s: invalid filename format", frag_path.name)
-            continue
-
-        blueprint_name, node_id_str = parts
-
-        try:
-            node_id = int(node_id_str)
-        except ValueError:
-            logger.warning("Skipping %s: invalid node_id", frag_path.name)
-            continue
-
-        with open(frag_path, 'r') as f:
-            shader_code = f.read()
-
-        if blueprint_name not in shader_updates:
-            shader_updates[blueprint_name] = []
-        shader_updates[blueprint_name].append((node_id, shader_code))
-
-    # Apply updates to JSON files
-    patched = 0
-    for json_path in get_blueprint_files():
-        blueprint_name = sanitize_filename(json_path.stem)
-
-        if blueprint_name not in shader_updates:
-            continue
-
-        try:
-            with open(json_path, 'r') as f:
-                data = json.load(f)
-        except (json.JSONDecodeError, IOError) as e:
-            logger.error("Error reading %s: %s", json_path.name, e)
-            continue
-
-        modified = False
-        for node_id, shader_code in shader_updates[blueprint_name]:
-            # Find the node and update
-            for subgraph in data.get('definitions', {}).get('subgraphs', []):
-                for node in subgraph.get('nodes', []):
-                    if node.get('id') == node_id and node.get('type') == 'GLSLShader':
-                        widgets = node.get('widgets_values', [])
-                        if len(widgets) > 0 and widgets[0] != shader_code:
-                            widgets[0] = shader_code
-                            modified = True
-                            logger.info(" Patched: %s (node %d)", json_path.name, node_id)
-                            patched += 1
-
-        if modified:
-            with open(json_path, 'w') as f:
-                json.dump(data, f)
-
-    if patched == 0:
-        logger.info("No changes to apply.")
-    else:
-        logger.info("\nPatched %d shader(s)", patched)
-
-
-def main():
-    if len(sys.argv) < 2:
-        command = "patch"
-    else:
-        command = sys.argv[1].lower()
-
-    if command == "extract":
-        logger.info("Extracting shaders from blueprints...")
-        extract_shaders()
-    elif command in ("patch", "update", "apply"):
-        logger.info("Patching shaders into blueprints...")
-        patch_shaders()
-    else:
-        logger.info(__doc__)
-        sys.exit(1)
-
-
-if __name__ == "__main__":
-    main()
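The deleted tool above is driven entirely by sys.argv and defaults to patching when run with no arguments. A minimal round-trip, assuming the script lives as (say) shader_tool.py inside the glsl folder (the actual file name is not visible in this diff):

    python shader_tool.py extract   # dump each GLSLShader widget to <blueprint>_<node_id>.frag
    python shader_tool.py patch     # write edited .frag files back into the blueprint JSONs

One detail worth noting: sanitize_filename() can itself introduce underscores, but the patch side parses filenames with rsplit('_', 1), so only the final underscore-separated component has to be the integer node id. A small illustration (hypothetical name, not from the diff):

import re
from pathlib import Path

stem = Path(re.sub(r'[^\w\-]', '_', "My Blueprint") + "_23.frag").stem
name, node_id = stem.rsplit("_", 1)   # splits only on the LAST underscore
assert (name, int(node_id)) == ("My_Blueprint", 23)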
File diff suppressed because one or more lines are too long
@@ -1 +0,0 @@
-{"revision": 0, "last_node_id": 29, "last_link_id": 0, "nodes": [{"id": 29, "type": "4c9d6ea4-b912-40e5-8766-6793a9758c53", "pos": [1970, -230], "size": [180, 86], "flags": {}, "order": 5, "mode": 0, "inputs": [{"label": "image", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": null}], "outputs": [{"label": "R", "localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": []}, {"label": "G", "localized_name": "IMAGE1", "name": "IMAGE1", "type": "IMAGE", "links": []}, {"label": "B", "localized_name": "IMAGE2", "name": "IMAGE2", "type": "IMAGE", "links": []}, {"label": "A", "localized_name": "IMAGE3", "name": "IMAGE3", "type": "IMAGE", "links": []}], "title": "Image Channels", "properties": {"proxyWidgets": []}, "widgets_values": []}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "4c9d6ea4-b912-40e5-8766-6793a9758c53", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 28, "lastLinkId": 39, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Image Channels", "inputNode": {"id": -10, "bounding": [1820, -185, 120, 60]}, "outputNode": {"id": -20, "bounding": [2460, -215, 120, 120]}, "inputs": [{"id": "3522932b-2d86-4a1f-a02a-cb29f3a9d7fe", "name": "images.image0", "type": "IMAGE", "linkIds": [39], "localized_name": "images.image0", "label": "image", "pos": [1920, -165]}], "outputs": [{"id": "605cb9c3-b065-4d9b-81d2-3ec331889b2b", "name": "IMAGE0", "type": "IMAGE", "linkIds": [26], "localized_name": "IMAGE0", "label": "R", "pos": [2480, -195]}, {"id": "fb44a77e-0522-43e9-9527-82e7465b3596", "name": "IMAGE1", "type": "IMAGE", "linkIds": [27], "localized_name": "IMAGE1", "label": "G", "pos": [2480, -175]}, {"id": "81460ee6-0131-402a-874f-6bf3001fc4ff", "name": "IMAGE2", "type": "IMAGE", "linkIds": [28], "localized_name": "IMAGE2", "label": "B", "pos": [2480, -155]}, {"id": "ae690246-80d4-4951-b1d9-9306d8a77417", "name": "IMAGE3", "type": "IMAGE", "linkIds": [29], "localized_name": "IMAGE3", "label": "A", "pos": [2480, -135]}], "widgets": [], "nodes": [{"id": 23, "type": "GLSLShader", "pos": [2000, -330], "size": [400, 172], "flags": {}, "order": 0, "mode": 0, "inputs": [{"label": "image", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": 39}, {"localized_name": "fragment_shader", "name": "fragment_shader", "type": "STRING", "widget": {"name": "fragment_shader"}, "link": null}, {"localized_name": "size_mode", "name": "size_mode", "type": "COMFY_DYNAMICCOMBO_V3", "widget": {"name": "size_mode"}, "link": null}, {"label": "image1", "localized_name": "images.image1", "name": "images.image1", "shape": 7, "type": "IMAGE", "link": null}], "outputs": [{"label": "R", "localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": [26]}, {"label": "G", "localized_name": "IMAGE1", "name": "IMAGE1", "type": "IMAGE", "links": [27]}, {"label": "B", "localized_name": "IMAGE2", "name": "IMAGE2", "type": "IMAGE", "links": [28]}, {"label": "A", "localized_name": "IMAGE3", "name": "IMAGE3", "type": "IMAGE", "links": [29]}], "properties": {"Node name for S&R": "GLSLShader"}, "widgets_values": ["#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\nlayout(location = 1) out vec4 fragColor1;\nlayout(location = 2) out vec4 fragColor2;\nlayout(location = 3) out vec4 fragColor3;\n\nvoid main() {\n vec4 color = texture(u_image0, v_texCoord);\n // Output each channel as grayscale to separate render targets\n fragColor0 = vec4(vec3(color.r), 1.0); // Red channel\n fragColor1 = vec4(vec3(color.g), 1.0); // Green channel\n fragColor2 = vec4(vec3(color.b), 1.0); // Blue channel\n fragColor3 = vec4(vec3(color.a), 1.0); // Alpha channel\n}\n", "from_input"]}], "groups": [], "links": [{"id": 39, "origin_id": -10, "origin_slot": 0, "target_id": 23, "target_slot": 0, "type": "IMAGE"}, {"id": 26, "origin_id": 23, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "IMAGE"}, {"id": 27, "origin_id": 23, "origin_slot": 1, "target_id": -20, "target_slot": 1, "type": "IMAGE"}, {"id": 28, "origin_id": 23, "origin_slot": 2, "target_id": -20, "target_slot": 2, "type": "IMAGE"}, {"id": 29, "origin_id": 23, "origin_slot": 3, "target_id": -20, "target_slot": 3, "type": "IMAGE"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Image Tools/Color adjust"}]}}
File diff suppressed because one or more lines are too long
@@ -1 +0,0 @@
-{"revision": 0, "last_node_id": 15, "last_link_id": 0, "nodes": [{"id": 15, "type": "24d8bbfd-39d4-4774-bff0-3de40cc7a471", "pos": [-1490, 2040], "size": [400, 260], "flags": {}, "order": 0, "mode": 0, "inputs": [{"name": "prompt", "type": "STRING", "widget": {"name": "prompt"}, "link": null}, {"label": "reference images", "name": "images", "type": "IMAGE", "link": null}], "outputs": [{"name": "STRING", "type": "STRING", "links": null}], "title": "Prompt Enhance", "properties": {"proxyWidgets": [["-1", "prompt"]], "cnr_id": "comfy-core", "ver": "0.14.1"}, "widgets_values": [""]}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "24d8bbfd-39d4-4774-bff0-3de40cc7a471", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 15, "lastLinkId": 14, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Prompt Enhance", "inputNode": {"id": -10, "bounding": [-2170, 2110, 138.876953125, 80]}, "outputNode": {"id": -20, "bounding": [-640, 2110, 120, 60]}, "inputs": [{"id": "aeab7216-00e0-4528-a09b-bba50845c5a6", "name": "prompt", "type": "STRING", "linkIds": [11], "pos": [-2051.123046875, 2130]}, {"id": "7b73fd36-aa31-4771-9066-f6c83879994b", "name": "images", "type": "IMAGE", "linkIds": [14], "label": "reference images", "pos": [-2051.123046875, 2150]}], "outputs": [{"id": "c7b0d930-68a1-48d1-b496-0519e5837064", "name": "STRING", "type": "STRING", "linkIds": [13], "pos": [-620, 2130]}], "widgets": [], "nodes": [{"id": 11, "type": "GeminiNode", "pos": [-1560, 1990], "size": [470, 470], "flags": {}, "order": 0, "mode": 0, "inputs": [{"localized_name": "images", "name": "images", "shape": 7, "type": "IMAGE", "link": 14}, {"localized_name": "audio", "name": "audio", "shape": 7, "type": "AUDIO", "link": null}, {"localized_name": "video", "name": "video", "shape": 7, "type": "VIDEO", "link": null}, {"localized_name": "files", "name": "files", "shape": 7, "type": "GEMINI_INPUT_FILES", "link": null}, {"localized_name": "prompt", "name": "prompt", "type": "STRING", "widget": {"name": "prompt"}, "link": 11}, {"localized_name": "model", "name": "model", "type": "COMBO", "widget": {"name": "model"}, "link": null}, {"localized_name": "seed", "name": "seed", "type": "INT", "widget": {"name": "seed"}, "link": null}, {"localized_name": "system_prompt", "name": "system_prompt", "shape": 7, "type": "STRING", "widget": {"name": "system_prompt"}, "link": null}], "outputs": [{"localized_name": "STRING", "name": "STRING", "type": "STRING", "links": [13]}], "properties": {"cnr_id": "comfy-core", "ver": "0.14.1", "Node name for S&R": "GeminiNode"}, "widgets_values": ["", "gemini-3-pro-preview", 42, "randomize", "You are an expert in prompt writing.\nBased on the input, rewrite the user's input into a detailed prompt.\nincluding camera settings, lighting, composition, and style.\nReturn the prompt only"], "color": "#432", "bgcolor": "#653"}], "groups": [], "links": [{"id": 11, "origin_id": -10, "origin_slot": 0, "target_id": 11, "target_slot": 4, "type": "STRING"}, {"id": 13, "origin_id": 11, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "STRING"}, {"id": 14, "origin_id": -10, "origin_slot": 1, "target_id": 11, "target_slot": 0, "type": "IMAGE"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Text generation/Prompt enhance"}]}, "extra": {}}
@@ -1 +0,0 @@
-{"revision": 0, "last_node_id": 25, "last_link_id": 0, "nodes": [{"id": 25, "type": "621ba4e2-22a8-482d-a369-023753198b7b", "pos": [4610, -790], "size": [230, 58], "flags": {}, "order": 4, "mode": 0, "inputs": [{"label": "image", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": null}], "outputs": [{"label": "IMAGE", "localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": []}], "title": "Sharpen", "properties": {"proxyWidgets": [["24", "value"]]}, "widgets_values": []}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "621ba4e2-22a8-482d-a369-023753198b7b", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 24, "lastLinkId": 36, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Sharpen", "inputNode": {"id": -10, "bounding": [4090, -825, 120, 60]}, "outputNode": {"id": -20, "bounding": [5150, -825, 120, 60]}, "inputs": [{"id": "37011fb7-14b7-4e0e-b1a0-6a02e8da1fd7", "name": "images.image0", "type": "IMAGE", "linkIds": [34], "localized_name": "images.image0", "label": "image", "pos": [4190, -805]}], "outputs": [{"id": "e9182b3f-635c-4cd4-a152-4b4be17ae4b9", "name": "IMAGE0", "type": "IMAGE", "linkIds": [35], "localized_name": "IMAGE0", "label": "IMAGE", "pos": [5170, -805]}], "widgets": [], "nodes": [{"id": 24, "type": "PrimitiveFloat", "pos": [4280, -1240], "size": [270, 58], "flags": {}, "order": 0, "mode": 0, "inputs": [{"label": "strength", "localized_name": "value", "name": "value", "type": "FLOAT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "FLOAT", "name": "FLOAT", "type": "FLOAT", "links": [36]}], "properties": {"Node name for S&R": "PrimitiveFloat", "min": 0, "max": 3, "precision": 2, "step": 0.05}, "widgets_values": [0.5]}, {"id": 23, "type": "GLSLShader", "pos": [4570, -1240], "size": [370, 192], "flags": {}, "order": 1, "mode": 0, "inputs": [{"label": "image0", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": 34}, {"label": "image1", "localized_name": "images.image1", "name": "images.image1", "shape": 7, "type": "IMAGE", "link": null}, {"label": "u_float0", "localized_name": "floats.u_float0", "name": "floats.u_float0", "shape": 7, "type": "FLOAT", "link": 36}, {"label": "u_float1", "localized_name": "floats.u_float1", "name": "floats.u_float1", "shape": 7, "type": "FLOAT", "link": null}, {"label": "u_int0", "localized_name": "ints.u_int0", "name": "ints.u_int0", "shape": 7, "type": "INT", "link": null}, {"localized_name": "fragment_shader", "name": "fragment_shader", "type": "STRING", "widget": {"name": "fragment_shader"}, "link": null}, {"localized_name": "size_mode", "name": "size_mode", "type": "COMFY_DYNAMICCOMBO_V3", "widget": {"name": "size_mode"}, "link": null}], "outputs": [{"localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": [35]}, {"localized_name": "IMAGE1", "name": "IMAGE1", "type": "IMAGE", "links": null}, {"localized_name": "IMAGE2", "name": "IMAGE2", "type": "IMAGE", "links": null}, {"localized_name": "IMAGE3", "name": "IMAGE3", "type": "IMAGE", "links": null}], "properties": {"Node name for S&R": "GLSLShader"}, "widgets_values": ["#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\nuniform vec2 u_resolution;\nuniform float u_float0; // strength [0.0 – 2.0] typical: 0.3–1.0\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\n\nvoid main() {\n vec2 texel = 1.0 / u_resolution;\n \n // Sample center and neighbors\n vec4 center = texture(u_image0, v_texCoord);\n vec4 top = texture(u_image0, v_texCoord + vec2( 0.0, -texel.y));\n vec4 bottom = texture(u_image0, v_texCoord + vec2( 0.0, texel.y));\n vec4 left = texture(u_image0, v_texCoord + vec2(-texel.x, 0.0));\n vec4 right = texture(u_image0, v_texCoord + vec2( texel.x, 0.0));\n \n // Edge enhancement (Laplacian)\n vec4 edges = center * 4.0 - top - bottom - left - right;\n \n // Add edges back scaled by strength\n vec4 sharpened = center + edges * u_float0;\n \n fragColor0 = vec4(clamp(sharpened.rgb, 0.0, 1.0), center.a);\n}", "from_input"]}], "groups": [], "links": [{"id": 36, "origin_id": 24, "origin_slot": 0, "target_id": 23, "target_slot": 2, "type": "FLOAT"}, {"id": 34, "origin_id": -10, "origin_slot": 0, "target_id": 23, "target_slot": 0, "type": "IMAGE"}, {"id": 35, "origin_id": 23, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "IMAGE"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Image Tools/Sharpen"}]}}
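The Sharpen blueprint above embeds a plain 4-neighbour Laplacian unsharp filter: sharpened = center + strength * (4*center - top - bottom - left - right). A rough NumPy equivalent of the same arithmetic, for orientation only (edge handling differs: np.roll wraps around where the shader's sampler clamps at the border):

import numpy as np

def laplacian_sharpen(img, strength=0.5):
    # img: HxWxC float array in [0, 1]; strength mirrors u_float0 in the shader
    top = np.roll(img, 1, axis=0)
    bottom = np.roll(img, -1, axis=0)
    left = np.roll(img, 1, axis=1)
    right = np.roll(img, -1, axis=1)
    edges = 4.0 * img - top - bottom - left - right  # Laplacian edge estimate
    return np.clip(img + strength * edges, 0.0, 1.0)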
File diff suppressed because one or more lines are too long
@@ -1 +0,0 @@
-{"revision": 0, "last_node_id": 13, "last_link_id": 0, "nodes": [{"id": 13, "type": "cf95b747-3e17-46cb-8097-cac60ff9b2e1", "pos": [1120, 330], "size": [240, 58], "flags": {}, "order": 3, "mode": 0, "inputs": [{"localized_name": "video", "name": "video", "type": "VIDEO", "link": null}, {"name": "model_name", "type": "COMBO", "widget": {"name": "model_name"}, "link": null}], "outputs": [{"localized_name": "VIDEO", "name": "VIDEO", "type": "VIDEO", "links": []}], "title": "Video Upscale(GAN x4)", "properties": {"proxyWidgets": [["-1", "model_name"]], "cnr_id": "comfy-core", "ver": "0.14.1"}, "widgets_values": ["RealESRGAN_x4plus.safetensors"]}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "cf95b747-3e17-46cb-8097-cac60ff9b2e1", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 13, "lastLinkId": 19, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Video Upscale(GAN x4)", "inputNode": {"id": -10, "bounding": [550, 460, 120, 80]}, "outputNode": {"id": -20, "bounding": [1490, 460, 120, 60]}, "inputs": [{"id": "666d633e-93e7-42dc-8d11-2b7b99b0f2a6", "name": "video", "type": "VIDEO", "linkIds": [10], "localized_name": "video", "pos": [650, 480]}, {"id": "2e23a087-caa8-4d65-99e6-662761aa905a", "name": "model_name", "type": "COMBO", "linkIds": [19], "pos": [650, 500]}], "outputs": [{"id": "0c1768ea-3ec2-412f-9af6-8e0fa36dae70", "name": "VIDEO", "type": "VIDEO", "linkIds": [15], "localized_name": "VIDEO", "pos": [1510, 480]}], "widgets": [], "nodes": [{"id": 2, "type": "ImageUpscaleWithModel", "pos": [1110, 450], "size": [320, 46], "flags": {}, "order": 1, "mode": 0, "inputs": [{"localized_name": "upscale_model", "name": "upscale_model", "type": "UPSCALE_MODEL", "link": 1}, {"localized_name": "image", "name": "image", "type": "IMAGE", "link": 14}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "links": [13]}], "properties": {"cnr_id": "comfy-core", "ver": "0.10.0", "Node name for S&R": "ImageUpscaleWithModel"}}, {"id": 11, "type": "CreateVideo", "pos": [1110, 550], "size": [320, 78], "flags": {}, "order": 3, "mode": 0, "inputs": [{"localized_name": "images", "name": "images", "type": "IMAGE", "link": 13}, {"localized_name": "audio", "name": "audio", "shape": 7, "type": "AUDIO", "link": 16}, {"localized_name": "fps", "name": "fps", "type": "FLOAT", "widget": {"name": "fps"}, "link": 12}], "outputs": [{"localized_name": "VIDEO", "name": "VIDEO", "type": "VIDEO", "links": [15]}], "properties": {"cnr_id": "comfy-core", "ver": "0.10.0", "Node name for S&R": "CreateVideo"}, "widgets_values": [30]}, {"id": 10, "type": "GetVideoComponents", "pos": [1110, 330], "size": [320, 70], "flags": {}, "order": 2, "mode": 0, "inputs": [{"localized_name": "video", "name": "video", "type": "VIDEO", "link": 10}], "outputs": [{"localized_name": "images", "name": "images", "type": "IMAGE", "links": [14]}, {"localized_name": "audio", "name": "audio", "type": "AUDIO", "links": [16]}, {"localized_name": "fps", "name": "fps", "type": "FLOAT", "links": [12]}], "properties": {"cnr_id": "comfy-core", "ver": "0.10.0", "Node name for S&R": "GetVideoComponents"}}, {"id": 1, "type": "UpscaleModelLoader", "pos": [750, 450], "size": [280, 60], "flags": {}, "order": 0, "mode": 0, "inputs": [{"localized_name": "model_name", "name": "model_name", "type": "COMBO", "widget": {"name": "model_name"}, "link": 19}], "outputs": [{"localized_name": "UPSCALE_MODEL", "name": "UPSCALE_MODEL", "type": "UPSCALE_MODEL", "links": [1]}], "properties": {"cnr_id": "comfy-core", "ver": "0.10.0", "Node name for S&R": "UpscaleModelLoader", "models": [{"name": "RealESRGAN_x4plus.safetensors", "url": "https://huggingface.co/Comfy-Org/Real-ESRGAN_repackaged/resolve/main/RealESRGAN_x4plus.safetensors", "directory": "upscale_models"}]}, "widgets_values": ["RealESRGAN_x4plus.safetensors"]}], "groups": [], "links": [{"id": 1, "origin_id": 1, "origin_slot": 0, "target_id": 2, "target_slot": 0, "type": "UPSCALE_MODEL"}, {"id": 14, "origin_id": 10, "origin_slot": 0, "target_id": 2, "target_slot": 1, "type": "IMAGE"}, {"id": 13, "origin_id": 2, "origin_slot": 0, "target_id": 11, "target_slot": 0, "type": "IMAGE"}, {"id": 16, "origin_id": 10, "origin_slot": 1, "target_id": 11, "target_slot": 1, "type": "AUDIO"}, {"id": 12, "origin_id": 10, "origin_slot": 2, "target_id": 11, "target_slot": 2, "type": "FLOAT"}, {"id": 10, "origin_id": -10, "origin_slot": 0, "target_id": 10, "target_slot": 0, "type": "VIDEO"}, {"id": 15, "origin_id": 11, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "VIDEO"}, {"id": 19, "origin_id": -10, "origin_slot": 1, "target_id": 1, "target_slot": 0, "type": "COMBO"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Video generation and editing/Enhance video"}]}, "extra": {}}
@@ -146,7 +146,6 @@ parser.add_argument("--reserve-vram", type=float, default=None, help="Set the am

 parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=None, metavar="NUM_STREAMS", help="Use async weight offloading. An optional argument controls the amount of offload streams. Default is 2. Enabled by default on Nvidia.")
 parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.")
-parser.add_argument("--disable-dynamic-vram", action="store_true", help="Disable dynamic VRAM and use estimate based model loading.")

 parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")

@@ -160,6 +159,7 @@ class PerformanceFeature(enum.Enum):
     Fp8MatrixMultiplication = "fp8_matrix_mult"
     CublasOps = "cublas_ops"
     AutoTune = "autotune"
+    DynamicVRAM = "dynamic_vram"

 parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))

@@ -260,4 +260,4 @@ else:
     args.fast = set(args.fast)

 def enables_dynamic_vram():
-    return not args.disable_dynamic_vram and not args.highvram and not args.gpu_only and not args.novram and not args.cpu
+    return PerformanceFeature.DynamicVRAM in args.fast and not args.highvram and not args.gpu_only
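Taken together, the three cli_args hunks above move dynamic VRAM from an opt-out flag (--disable-dynamic-vram) to an opt-in --fast feature on the + side; note the + version also drops the novram/cpu checks. A minimal sketch of the resulting gate, with the argument plumbing omitted:

import enum

class PerformanceFeature(enum.Enum):
    DynamicVRAM = "dynamic_vram"

fast = {PerformanceFeature.DynamicVRAM}  # what `--fast dynamic_vram` parses to
highvram = gpu_only = False

print(PerformanceFeature.DynamicVRAM in fast and not highvram and not gpu_only)  # True
print(PerformanceFeature.DynamicVRAM in set() and not highvram and not gpu_only)  # False without --fast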
@@ -176,8 +176,6 @@ class InputTypeOptions(TypedDict):
     """COMBO type only. Specifies the configuration for a multi-select widget.
     Available after ComfyUI frontend v1.13.4
     https://github.com/Comfy-Org/ComfyUI_frontend/pull/2987"""
-    gradient_stops: NotRequired[list[list[float]]]
-    """Gradient color stops for gradientslider display mode. Each stop is [offset, r, g, b] (``FLOAT``)."""


 class HiddenInputTypeDict(TypedDict):
@@ -4,25 +4,6 @@ import comfy.utils
 import logging


-def is_equal(x, y):
-    if torch.is_tensor(x) and torch.is_tensor(y):
-        return torch.equal(x, y)
-    elif isinstance(x, dict) and isinstance(y, dict):
-        if x.keys() != y.keys():
-            return False
-        return all(is_equal(x[k], y[k]) for k in x)
-    elif isinstance(x, (list, tuple)) and isinstance(y, (list, tuple)):
-        if type(x) is not type(y) or len(x) != len(y):
-            return False
-        return all(is_equal(a, b) for a, b in zip(x, y))
-    else:
-        try:
-            return x == y
-        except Exception:
-            logging.warning("comparison issue with COND")
-            return False
-
-
 class CONDRegular:
     def __init__(self, cond):
         self.cond = cond

@@ -103,7 +84,7 @@ class CONDConstant(CONDRegular):
         return self._copy_with(self.cond)

     def can_concat(self, other):
-        if not is_equal(self.cond, other.cond):
+        if self.cond != other.cond:
             return False
         return True

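For context on why is_equal() existed: == and != on tensors are element-wise in PyTorch, so the plain comparison that can_concat falls back to on the + side only works when cond has a well-defined scalar truth value. An illustration of the failure mode the helper guarded against:

import torch

a, b = torch.ones(2), torch.ones(2)
print(a != b)                 # tensor([False, False]) -- a tensor, not a bool
try:
    if a != b:                # truth value of a multi-element tensor is ambiguous
        pass
except RuntimeError as err:
    print(err)
print(torch.equal(a, b))      # True -- the tensor-aware check the removed helper used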
@@ -214,7 +214,7 @@ class IndexListContextHandler(ContextHandlerABC):
         mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep[0], rtol=0.0001)
         matches = torch.nonzero(mask)
         if torch.numel(matches) == 0:
-            return  # substep from multi-step sampler: keep self._step from the last full step
+            raise Exception("No sample_sigmas matched current timestep; something went wrong.")
         self._step = int(matches[0].item())

     def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:
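Both versions locate the current step by fuzzy-matching the active sigma against the sampler's schedule; they differ only on a miss (the - side returns and keeps the previous step index for multi-step sub-steps, the + side raises). The matching itself, with illustrative values:

import torch

sample_sigmas = torch.tensor([14.6, 7.0, 3.1, 1.2, 0.0])
timestep = torch.tensor([3.1])
mask = torch.isclose(sample_sigmas, timestep[0], rtol=0.0001)  # broadcasts the scalar
matches = torch.nonzero(mask)
if torch.numel(matches) != 0:
    print(int(matches[0].item()))  # 2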
@@ -776,10 +776,3 @@ class ChromaRadiance(LatentFormat):

     def process_out(self, latent):
         return latent
-
-
-class ZImagePixelSpace(ChromaRadiance):
-    """Pixel-space latent format for ZImage DCT variant.
-    No VAE encoding/decoding — the model operates directly on RGB pixels.
-    """
-    pass
@@ -179,8 +179,8 @@ class LLMAdapter(nn.Module):
         if source_attention_mask.ndim == 2:
             source_attention_mask = source_attention_mask.unsqueeze(1).unsqueeze(1)

+        x = self.in_proj(self.embed(target_input_ids))
         context = source_hidden_states
-        x = self.in_proj(self.embed(target_input_ids, out_dtype=context.dtype))
         position_ids = torch.arange(x.shape[1], device=x.device).unsqueeze(0)
         position_ids_context = torch.arange(context.shape[1], device=x.device).unsqueeze(0)
         position_embeddings = self.rotary_emb(x, position_ids)
@@ -152,7 +152,6 @@ class Chroma(nn.Module):
         transformer_options={},
         attn_mask: Tensor = None,
     ) -> Tensor:
-        transformer_options = transformer_options.copy()
         patches_replace = transformer_options.get("patches_replace", {})

         # running on sequences img

@@ -229,7 +228,6 @@ class Chroma(nn.Module):

         transformer_options["total_blocks"] = len(self.single_blocks)
         transformer_options["block_type"] = "single"
-        transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
         for i, block in enumerate(self.single_blocks):
             transformer_options["block_index"] = i
             if i not in self.skip_dit:

@@ -196,9 +196,6 @@ class DoubleStreamBlock(nn.Module):
         else:
             (img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec

-        transformer_patches = transformer_options.get("patches", {})
-        extra_options = transformer_options.copy()
-
         # prepare image for attention
         img_modulated = self.img_norm1(img)
         img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img)

@@ -227,12 +224,6 @@ class DoubleStreamBlock(nn.Module):
         attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
         del q, k, v

-        if "attn1_output_patch" in transformer_patches:
-            extra_options["img_slice"] = [txt.shape[1], attn.shape[1]]
-            patch = transformer_patches["attn1_output_patch"]
-            for p in patch:
-                attn = p(attn, extra_options)
-
         txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]

         # calculate the img bloks

@@ -312,9 +303,6 @@ class SingleStreamBlock(nn.Module):
         else:
             mod = vec

-        transformer_patches = transformer_options.get("patches", {})
-        extra_options = transformer_options.copy()
-
         qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim_first], dim=-1)

         q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)

@@ -324,12 +312,6 @@ class SingleStreamBlock(nn.Module):
         # compute attention
         attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
         del q, k, v
-
-        if "attn1_output_patch" in transformer_patches:
-            patch = transformer_patches["attn1_output_patch"]
-            for p in patch:
-                attn = p(attn, extra_options)
-
         # compute activation in mlp stream, cat again and run second linear layer
         if self.yak_mlp:
             mlp = self.mlp_act(mlp[..., self.mlp_hidden_dim_first // 2:]) * mlp[..., :self.mlp_hidden_dim_first // 2]
@@ -142,7 +142,6 @@ class Flux(nn.Module):
         attn_mask: Tensor = None,
     ) -> Tensor:

-        transformer_options = transformer_options.copy()
         patches = transformer_options.get("patches", {})
         patches_replace = transformer_options.get("patches_replace", {})
         if img.ndim != 3 or txt.ndim != 3:

@@ -232,7 +231,6 @@ class Flux(nn.Module):

         transformer_options["total_blocks"] = len(self.single_blocks)
         transformer_options["block_type"] = "single"
-        transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
         for i, block in enumerate(self.single_blocks):
             transformer_options["block_index"] = i
             if ("single_block", i) in blocks_replace:
@@ -304,7 +304,6 @@ class HunyuanVideo(nn.Module):
         control=None,
         transformer_options={},
     ) -> Tensor:
-        transformer_options = transformer_options.copy()
         patches_replace = transformer_options.get("patches_replace", {})

         initial_shape = list(img.shape)

@@ -417,7 +416,6 @@ class HunyuanVideo(nn.Module):

         transformer_options["total_blocks"] = len(self.single_blocks)
         transformer_options["block_type"] = "single"
-        transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
         for i, block in enumerate(self.single_blocks):
             transformer_options["block_index"] = i
             if ("single_block", i) in blocks_replace:
@@ -2,19 +2,13 @@ from typing import Tuple
 import torch
 import torch.nn as nn
 from comfy.ldm.lightricks.model import (
-    ADALN_BASE_PARAMS_COUNT,
-    ADALN_CROSS_ATTN_PARAMS_COUNT,
     CrossAttention,
     FeedForward,
     AdaLayerNormSingle,
     PixArtAlphaTextProjection,
-    NormSingleLinearTextProjection,
     LTXVModel,
-    apply_cross_attention_adaln,
-    compute_prompt_timestep,
 )
 from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
-from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
 import comfy.ldm.common_dit

 class CompressedTimestep:
@@ -92,8 +86,6 @@ class BasicAVTransformerBlock(nn.Module):
         v_context_dim=None,
         a_context_dim=None,
         attn_precision=None,
-        apply_gated_attention=False,
-        cross_attention_adaln=False,
         dtype=None,
         device=None,
         operations=None,

@@ -101,7 +93,6 @@ class BasicAVTransformerBlock(nn.Module):
         super().__init__()

         self.attn_precision = attn_precision
-        self.cross_attention_adaln = cross_attention_adaln

         self.attn1 = CrossAttention(
             query_dim=v_dim,

@@ -109,7 +100,6 @@ class BasicAVTransformerBlock(nn.Module):
             dim_head=vd_head,
             context_dim=None,
             attn_precision=self.attn_precision,
-            apply_gated_attention=apply_gated_attention,
             dtype=dtype,
             device=device,
             operations=operations,

@@ -120,7 +110,6 @@ class BasicAVTransformerBlock(nn.Module):
             dim_head=ad_head,
             context_dim=None,
             attn_precision=self.attn_precision,
-            apply_gated_attention=apply_gated_attention,
             dtype=dtype,
             device=device,
             operations=operations,

@@ -132,7 +121,6 @@ class BasicAVTransformerBlock(nn.Module):
             heads=v_heads,
             dim_head=vd_head,
             attn_precision=self.attn_precision,
-            apply_gated_attention=apply_gated_attention,
             dtype=dtype,
             device=device,
             operations=operations,

@@ -143,7 +131,6 @@ class BasicAVTransformerBlock(nn.Module):
             heads=a_heads,
             dim_head=ad_head,
             attn_precision=self.attn_precision,
-            apply_gated_attention=apply_gated_attention,
             dtype=dtype,
             device=device,
             operations=operations,

@@ -156,7 +143,6 @@ class BasicAVTransformerBlock(nn.Module):
             heads=a_heads,
             dim_head=ad_head,
             attn_precision=self.attn_precision,
-            apply_gated_attention=apply_gated_attention,
             dtype=dtype,
             device=device,
             operations=operations,

@@ -169,7 +155,6 @@ class BasicAVTransformerBlock(nn.Module):
             heads=a_heads,
             dim_head=ad_head,
             attn_precision=self.attn_precision,
-            apply_gated_attention=apply_gated_attention,
             dtype=dtype,
             device=device,
             operations=operations,
@@ -182,16 +167,11 @@ class BasicAVTransformerBlock(nn.Module):
             a_dim, dim_out=a_dim, glu=True, dtype=dtype, device=device, operations=operations
         )

-        num_ada_params = ADALN_CROSS_ATTN_PARAMS_COUNT if cross_attention_adaln else ADALN_BASE_PARAMS_COUNT
-        self.scale_shift_table = nn.Parameter(torch.empty(num_ada_params, v_dim, device=device, dtype=dtype))
+        self.scale_shift_table = nn.Parameter(torch.empty(6, v_dim, device=device, dtype=dtype))
         self.audio_scale_shift_table = nn.Parameter(
-            torch.empty(num_ada_params, a_dim, device=device, dtype=dtype)
+            torch.empty(6, a_dim, device=device, dtype=dtype)
         )

-        if cross_attention_adaln:
-            self.prompt_scale_shift_table = nn.Parameter(torch.empty(2, v_dim, device=device, dtype=dtype))
-            self.audio_prompt_scale_shift_table = nn.Parameter(torch.empty(2, a_dim, device=device, dtype=dtype))
-
         self.scale_shift_table_a2v_ca_audio = nn.Parameter(
             torch.empty(5, a_dim, device=device, dtype=dtype)
         )
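Background for the fixed 6-row tables on the + side: DiT-style AdaLN-single blocks conventionally keep six per-block modulation vectors (shift/scale/gate for self-attention, then shift/scale/gate for the MLP), while the cross-attention-AdaLN variant being removed grew the table further (see slice(6, 9) in the _apply_text_cross_attention method deleted below). A schematic of how such a table is consumed, not the model code itself, and with the row order an assumption taken from the usual DiT convention:

import torch

v_dim = 8
table = torch.zeros(6, v_dim)  # assumed rows: shift/scale/gate for MSA, then for MLP
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = table.unbind(0)

x = torch.randn(2, 16, v_dim)
norm_x = torch.nn.functional.layer_norm(x, (v_dim,))
modulated = norm_x * (1 + scale_msa) + shift_msa  # same pattern as norm_vx further down
x = x + gate_msa * modulated                      # gated residual update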
@@ -234,30 +214,10 @@ class BasicAVTransformerBlock(nn.Module):

         return (*scale_shift_ada_values, *gate_ada_values)

-    def _apply_text_cross_attention(
-        self, x, context, attn, scale_shift_table, prompt_scale_shift_table,
-        timestep, prompt_timestep, attention_mask, transformer_options,
-    ):
-        """Apply text cross-attention, with optional ADaLN modulation."""
-        if self.cross_attention_adaln:
-            shift_q, scale_q, gate = self.get_ada_values(
-                scale_shift_table, x.shape[0], timestep, slice(6, 9)
-            )
-            return apply_cross_attention_adaln(
-                x, context, attn, shift_q, scale_q, gate,
-                prompt_scale_shift_table, prompt_timestep,
-                attention_mask, transformer_options,
-            )
-        return attn(
-            comfy.ldm.common_dit.rms_norm(x), context=context,
-            mask=attention_mask, transformer_options=transformer_options,
-        )
-
     def forward(
         self, x: Tuple[torch.Tensor, torch.Tensor], v_context=None, a_context=None, attention_mask=None, v_timestep=None, a_timestep=None,
         v_pe=None, a_pe=None, v_cross_pe=None, a_cross_pe=None, v_cross_scale_shift_timestep=None, a_cross_scale_shift_timestep=None,
-        v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None, self_attention_mask=None,
-        v_prompt_timestep=None, a_prompt_timestep=None,
+        v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         run_vx = transformer_options.get("run_vx", True)
         run_ax = transformer_options.get("run_ax", True)
@@ -273,17 +233,13 @@ class BasicAVTransformerBlock(nn.Module):
             vshift_msa, vscale_msa = (self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 2)))
             norm_vx = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_msa) + vshift_msa
             del vshift_msa, vscale_msa
-            attn1_out = self.attn1(norm_vx, pe=v_pe, mask=self_attention_mask, transformer_options=transformer_options)
+            attn1_out = self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options)
             del norm_vx
             # video cross-attention
             vgate_msa = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(2, 3))[0]
             vx.addcmul_(attn1_out, vgate_msa)
             del vgate_msa, attn1_out
-            vx.add_(self._apply_text_cross_attention(
-                vx, v_context, self.attn2, self.scale_shift_table,
-                getattr(self, 'prompt_scale_shift_table', None),
-                v_timestep, v_prompt_timestep, attention_mask, transformer_options,)
-            )
+            vx.add_(self.attn2(comfy.ldm.common_dit.rms_norm(vx), context=v_context, mask=attention_mask, transformer_options=transformer_options))

         # audio
         if run_ax:

@@ -297,11 +253,7 @@ class BasicAVTransformerBlock(nn.Module):
             agate_msa = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(2, 3))[0]
             ax.addcmul_(attn1_out, agate_msa)
             del agate_msa, attn1_out
-            ax.add_(self._apply_text_cross_attention(
-                ax, a_context, self.audio_attn2, self.audio_scale_shift_table,
-                getattr(self, 'audio_prompt_scale_shift_table', None),
-                a_timestep, a_prompt_timestep, attention_mask, transformer_options,)
-            )
+            ax.add_(self.audio_attn2(comfy.ldm.common_dit.rms_norm(ax), context=a_context, mask=attention_mask, transformer_options=transformer_options))

         # video - audio cross attention.
         if run_a2v or run_v2a:
@@ -398,9 +350,6 @@ class LTXAVModel(LTXVModel):
         use_middle_indices_grid=False,
         timestep_scale_multiplier=1000.0,
         av_ca_timestep_scale_multiplier=1.0,
-        apply_gated_attention=False,
-        caption_proj_before_connector=False,
-        cross_attention_adaln=False,
         dtype=None,
         device=None,
         operations=None,

@@ -412,7 +361,6 @@ class LTXAVModel(LTXVModel):
         self.audio_attention_head_dim = audio_attention_head_dim
         self.audio_num_attention_heads = audio_num_attention_heads
         self.audio_positional_embedding_max_pos = audio_positional_embedding_max_pos
-        self.apply_gated_attention = apply_gated_attention

         # Calculate audio dimensions
         self.audio_inner_dim = audio_num_attention_heads * audio_attention_head_dim

@@ -437,8 +385,6 @@ class LTXAVModel(LTXVModel):
             vae_scale_factors=vae_scale_factors,
             use_middle_indices_grid=use_middle_indices_grid,
             timestep_scale_multiplier=timestep_scale_multiplier,
-            caption_proj_before_connector=caption_proj_before_connector,
-            cross_attention_adaln=cross_attention_adaln,
             dtype=dtype,
             device=device,
             operations=operations,
@@ -453,28 +399,14 @@ class LTXAVModel(LTXVModel):
         )

         # Audio-specific AdaLN
-        audio_embedding_coefficient = ADALN_CROSS_ATTN_PARAMS_COUNT if self.cross_attention_adaln else ADALN_BASE_PARAMS_COUNT
         self.audio_adaln_single = AdaLayerNormSingle(
             self.audio_inner_dim,
-            embedding_coefficient=audio_embedding_coefficient,
             use_additional_conditions=False,
             dtype=dtype,
             device=device,
             operations=self.operations,
         )

-        if self.cross_attention_adaln:
-            self.audio_prompt_adaln_single = AdaLayerNormSingle(
-                self.audio_inner_dim,
-                embedding_coefficient=2,
-                use_additional_conditions=False,
-                dtype=dtype,
-                device=device,
-                operations=self.operations,
-            )
-        else:
-            self.audio_prompt_adaln_single = None
-
         num_scale_shift_values = 4
         self.av_ca_video_scale_shift_adaln_single = AdaLayerNormSingle(
             self.inner_dim,
@@ -510,18 +442,6 @@ class LTXAVModel(LTXVModel):
         )

         # Audio caption projection
-        if self.caption_proj_before_connector:
-            if self.caption_projection_first_linear:
-                self.audio_caption_projection = NormSingleLinearTextProjection(
-                    in_features=self.caption_channels,
-                    hidden_size=self.audio_inner_dim,
-                    dtype=dtype,
-                    device=device,
-                    operations=self.operations,
-                )
-            else:
-                self.audio_caption_projection = lambda a: a
-        else:
         self.audio_caption_projection = PixArtAlphaTextProjection(
             in_features=self.caption_channels,
             hidden_size=self.audio_inner_dim,
@@ -530,55 +450,6 @@ class LTXAVModel(LTXVModel):
             operations=self.operations,
         )

-        connector_split_rope = kwargs.get("rope_type", "split") == "split"
-        connector_gated_attention = kwargs.get("connector_apply_gated_attention", False)
-        attention_head_dim = kwargs.get("connector_attention_head_dim", 128)
-        num_attention_heads = kwargs.get("connector_num_attention_heads", 30)
-        num_layers = kwargs.get("connector_num_layers", 2)
-
-        self.audio_embeddings_connector = Embeddings1DConnector(
-            attention_head_dim=kwargs.get("audio_connector_attention_head_dim", attention_head_dim),
-            num_attention_heads=kwargs.get("audio_connector_num_attention_heads", num_attention_heads),
-            num_layers=num_layers,
-            split_rope=connector_split_rope,
-            double_precision_rope=True,
-            apply_gated_attention=connector_gated_attention,
-            dtype=dtype,
-            device=device,
-            operations=self.operations,
-        )
-
-        self.video_embeddings_connector = Embeddings1DConnector(
-            attention_head_dim=attention_head_dim,
-            num_attention_heads=num_attention_heads,
-            num_layers=num_layers,
-            split_rope=connector_split_rope,
-            double_precision_rope=True,
-            apply_gated_attention=connector_gated_attention,
-            dtype=dtype,
-            device=device,
-            operations=self.operations,
-        )
-
-    def preprocess_text_embeds(self, context, unprocessed=False):
-        # LTXv2 fully processed context has dimension of self.caption_channels * 2
-        # LTXv2.3 fully processed context has dimension of self.cross_attention_dim + self.audio_cross_attention_dim
-        if not unprocessed:
-            if context.shape[-1] in (self.cross_attention_dim + self.audio_cross_attention_dim, self.caption_channels * 2):
-                return context
-        if context.shape[-1] == self.cross_attention_dim + self.audio_cross_attention_dim:
-            context_vid = context[:, :, :self.cross_attention_dim]
-            context_audio = context[:, :, self.cross_attention_dim:]
-        else:
-            context_vid = context
-            context_audio = context
-        if self.caption_proj_before_connector:
-            context_vid = self.caption_projection(context_vid)
-            context_audio = self.audio_caption_projection(context_audio)
-        out_vid = self.video_embeddings_connector(context_vid)[0]
-        out_audio = self.audio_embeddings_connector(context_audio)[0]
-        return torch.concat((out_vid, out_audio), dim=-1)
-
     def _init_transformer_blocks(self, device, dtype, **kwargs):
         """Initialize transformer blocks for LTXAV."""
         self.transformer_blocks = nn.ModuleList(
@@ -592,8 +463,6 @@ class LTXAVModel(LTXVModel):
                 ad_head=self.audio_attention_head_dim,
                 v_context_dim=self.cross_attention_dim,
                 a_context_dim=self.audio_cross_attention_dim,
-                apply_gated_attention=self.apply_gated_attention,
-                cross_attention_adaln=self.cross_attention_adaln,
                 dtype=dtype,
                 device=device,
                 operations=self.operations,

@@ -715,10 +584,6 @@ class LTXAVModel(LTXVModel):
         v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame)
         v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame)

-        v_prompt_timestep = compute_prompt_timestep(
-            self.prompt_adaln_single, timestep_scaled, batch_size, hidden_dtype
-        )
-
         # Prepare audio timestep
         a_timestep = kwargs.get("a_timestep")
         if a_timestep is not None:
@@ -729,25 +594,25 @@ class LTXAVModel(LTXVModel):

             # Cross-attention timesteps - compress these too
             av_ca_audio_scale_shift_timestep, _ = self.av_ca_audio_scale_shift_adaln_single(
-                timestep.max().expand_as(a_timestep_flat),
+                a_timestep_flat,
                 {"resolution": None, "aspect_ratio": None},
                 batch_size=batch_size,
                 hidden_dtype=hidden_dtype,
             )
             av_ca_video_scale_shift_timestep, _ = self.av_ca_video_scale_shift_adaln_single(
-                a_timestep.max().expand_as(timestep_flat),
+                timestep_flat,
                 {"resolution": None, "aspect_ratio": None},
                 batch_size=batch_size,
                 hidden_dtype=hidden_dtype,
             )
             av_ca_a2v_gate_noise_timestep, _ = self.av_ca_a2v_gate_adaln_single(
-                a_timestep.max().expand_as(timestep_flat) * av_ca_factor,
+                timestep_flat * av_ca_factor,
                 {"resolution": None, "aspect_ratio": None},
                 batch_size=batch_size,
                 hidden_dtype=hidden_dtype,
             )
             av_ca_v2a_gate_noise_timestep, _ = self.av_ca_v2a_gate_adaln_single(
-                timestep.max().expand_as(a_timestep_flat) * av_ca_factor,
+                a_timestep_flat * av_ca_factor,
                 {"resolution": None, "aspect_ratio": None},
                 batch_size=batch_size,
                 hidden_dtype=hidden_dtype,
@@ -771,40 +636,29 @@ class LTXAVModel(LTXVModel):
|
|||||||
# Audio timesteps
|
# Audio timesteps
|
||||||
a_timestep = a_timestep.view(batch_size, -1, a_timestep.shape[-1])
|
a_timestep = a_timestep.view(batch_size, -1, a_timestep.shape[-1])
|
||||||
a_embedded_timestep = a_embedded_timestep.view(batch_size, -1, a_embedded_timestep.shape[-1])
|
a_embedded_timestep = a_embedded_timestep.view(batch_size, -1, a_embedded_timestep.shape[-1])
|
||||||
|
|
||||||
a_prompt_timestep = compute_prompt_timestep(
|
|
||||||
self.audio_prompt_adaln_single, a_timestep_scaled, batch_size, hidden_dtype
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
a_timestep = timestep_scaled
|
a_timestep = timestep_scaled
|
||||||
a_embedded_timestep = kwargs.get("embedded_timestep")
|
a_embedded_timestep = kwargs.get("embedded_timestep")
|
||||||
cross_av_timestep_ss = []
|
cross_av_timestep_ss = []
|
||||||
a_prompt_timestep = None
|
|
||||||
|
|
||||||
return [v_timestep, a_timestep, cross_av_timestep_ss, v_prompt_timestep, a_prompt_timestep], [
|
return [v_timestep, a_timestep, cross_av_timestep_ss], [
|
||||||
v_embedded_timestep,
|
v_embedded_timestep,
|
||||||
a_embedded_timestep,
|
a_embedded_timestep,
|
||||||
], None
|
]
|
||||||
|
|
||||||
def _prepare_context(self, context, batch_size, x, attention_mask=None):
|
def _prepare_context(self, context, batch_size, x, attention_mask=None):
|
||||||
vx = x[0]
|
vx = x[0]
|
||||||
ax = x[1]
|
ax = x[1]
|
||||||
video_dim = vx.shape[-1]
|
|
||||||
audio_dim = ax.shape[-1]
|
|
||||||
|
|
||||||
v_context_dim = self.caption_channels if self.caption_proj_before_connector is False else video_dim
|
|
||||||
a_context_dim = self.caption_channels if self.caption_proj_before_connector is False else audio_dim
|
|
||||||
|
|
||||||
v_context, a_context = torch.split(
|
v_context, a_context = torch.split(
|
||||||
context, [v_context_dim, a_context_dim], len(context.shape) - 1
|
context, int(context.shape[-1] / 2), len(context.shape) - 1
|
||||||
)
|
)
|
||||||
|
|
||||||
v_context, attention_mask = super()._prepare_context(
|
v_context, attention_mask = super()._prepare_context(
|
||||||
v_context, batch_size, vx, attention_mask
|
v_context, batch_size, vx, attention_mask
|
||||||
)
|
)
|
||||||
if self.caption_proj_before_connector is False:
|
if self.audio_caption_projection is not None:
|
||||||
a_context = self.audio_caption_projection(a_context)
|
a_context = self.audio_caption_projection(a_context)
|
||||||
a_context = a_context.view(batch_size, -1, audio_dim)
|
a_context = a_context.view(batch_size, -1, ax.shape[-1])
|
||||||
|
|
||||||
return [v_context, a_context], attention_mask
|
return [v_context, a_context], attention_mask
|
||||||
|
|
||||||
@@ -848,7 +702,7 @@ class LTXAVModel(LTXVModel):
|
|||||||
return [(v_pe, av_cross_video_freq_cis), (a_pe, av_cross_audio_freq_cis)]
|
return [(v_pe, av_cross_video_freq_cis), (a_pe, av_cross_audio_freq_cis)]
|
||||||
|
|
||||||
def _process_transformer_blocks(
|
def _process_transformer_blocks(
|
||||||
self, x, context, attention_mask, timestep, pe, transformer_options={}, self_attention_mask=None, **kwargs
|
self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs
|
||||||
):
|
):
|
||||||
vx = x[0]
|
vx = x[0]
|
||||||
ax = x[1]
|
ax = x[1]
|
||||||
@@ -866,9 +720,6 @@ class LTXAVModel(LTXVModel):
|
|||||||
av_ca_v2a_gate_noise_timestep,
|
av_ca_v2a_gate_noise_timestep,
|
||||||
) = timestep[2]
|
) = timestep[2]
|
||||||
|
|
||||||
v_prompt_timestep = timestep[3]
|
|
||||||
a_prompt_timestep = timestep[4]
|
|
||||||
|
|
||||||
"""Process transformer blocks for LTXAV."""
|
"""Process transformer blocks for LTXAV."""
|
||||||
patches_replace = transformer_options.get("patches_replace", {})
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
blocks_replace = patches_replace.get("dit", {})
|
blocks_replace = patches_replace.get("dit", {})
|
||||||
@@ -895,9 +746,6 @@ class LTXAVModel(LTXVModel):
|
|||||||
v_cross_gate_timestep=args["v_cross_gate_timestep"],
|
v_cross_gate_timestep=args["v_cross_gate_timestep"],
|
||||||
a_cross_gate_timestep=args["a_cross_gate_timestep"],
|
a_cross_gate_timestep=args["a_cross_gate_timestep"],
|
||||||
transformer_options=args["transformer_options"],
|
transformer_options=args["transformer_options"],
|
||||||
self_attention_mask=args.get("self_attention_mask"),
|
|
||||||
v_prompt_timestep=args.get("v_prompt_timestep"),
|
|
||||||
a_prompt_timestep=args.get("a_prompt_timestep"),
|
|
||||||
)
|
)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
@@ -918,9 +766,6 @@ class LTXAVModel(LTXVModel):
|
|||||||
"v_cross_gate_timestep": av_ca_a2v_gate_noise_timestep,
|
"v_cross_gate_timestep": av_ca_a2v_gate_noise_timestep,
|
||||||
"a_cross_gate_timestep": av_ca_v2a_gate_noise_timestep,
|
"a_cross_gate_timestep": av_ca_v2a_gate_noise_timestep,
|
||||||
"transformer_options": transformer_options,
|
"transformer_options": transformer_options,
|
||||||
"self_attention_mask": self_attention_mask,
|
|
||||||
"v_prompt_timestep": v_prompt_timestep,
|
|
||||||
"a_prompt_timestep": a_prompt_timestep,
|
|
||||||
},
|
},
|
||||||
{"original_block": block_wrap},
|
{"original_block": block_wrap},
|
||||||
)
|
)
|
||||||
@@ -942,9 +787,6 @@ class LTXAVModel(LTXVModel):
|
|||||||
v_cross_gate_timestep=av_ca_a2v_gate_noise_timestep,
|
v_cross_gate_timestep=av_ca_a2v_gate_noise_timestep,
|
||||||
a_cross_gate_timestep=av_ca_v2a_gate_noise_timestep,
|
a_cross_gate_timestep=av_ca_v2a_gate_noise_timestep,
|
||||||
transformer_options=transformer_options,
|
transformer_options=transformer_options,
|
||||||
self_attention_mask=self_attention_mask,
|
|
||||||
v_prompt_timestep=v_prompt_timestep,
|
|
||||||
a_prompt_timestep=a_prompt_timestep,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return [vx, ax]
|
return [vx, ax]
|
||||||
|
|||||||
@@ -50,7 +50,6 @@ class BasicTransformerBlock1D(nn.Module):
 d_head,
 context_dim=None,
 attn_precision=None,
-apply_gated_attention=False,
 dtype=None,
 device=None,
 operations=None,

@@ -64,7 +63,6 @@ class BasicTransformerBlock1D(nn.Module):
 heads=n_heads,
 dim_head=d_head,
 context_dim=None,
-apply_gated_attention=apply_gated_attention,
 dtype=dtype,
 device=device,
 operations=operations,

@@ -123,7 +121,6 @@ class Embeddings1DConnector(nn.Module):
 positional_embedding_max_pos=[4096],
 causal_temporal_positioning=False,
 num_learnable_registers: Optional[int] = 128,
-apply_gated_attention=False,
 dtype=None,
 device=None,
 operations=None,

@@ -148,7 +145,6 @@ class Embeddings1DConnector(nn.Module):
 num_attention_heads,
 attention_head_dim,
 context_dim=cross_attention_dim,
-apply_gated_attention=apply_gated_attention,
 dtype=dtype,
 device=device,
 operations=operations,

@@ -161,9 +157,11 @@ class Embeddings1DConnector(nn.Module):
 self.num_learnable_registers = num_learnable_registers
 if self.num_learnable_registers:
 self.learnable_registers = nn.Parameter(
-torch.empty(
+torch.rand(
 self.num_learnable_registers, inner_dim, dtype=dtype, device=device
 )
+* 2.0
+- 1.0
 )

 def get_fractional_positions(self, indices_grid):

@@ -236,7 +234,7 @@ class Embeddings1DConnector(nn.Module):

 return indices

-def precompute_freqs_cis(self, indices_grid, spacing="exp", out_dtype=None):
+def precompute_freqs_cis(self, indices_grid, spacing="exp"):
 dim = self.inner_dim
 n_elem = 2 # 2 because of cos and sin
 freqs = self.precompute_freqs(indices_grid, spacing)

@@ -249,7 +247,7 @@ class Embeddings1DConnector(nn.Module):
 )
 else:
 cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem)
-return cos_freq.to(dtype=out_dtype), sin_freq.to(dtype=out_dtype), self.split_rope
+return cos_freq.to(self.dtype), sin_freq.to(self.dtype), self.split_rope

 def forward(
 self,

@@ -290,7 +288,7 @@ class Embeddings1DConnector(nn.Module):
 hidden_states.shape[1], dtype=torch.float32, device=hidden_states.device
 )
 indices_grid = indices_grid[None, None, :]
-freqs_cis = self.precompute_freqs_cis(indices_grid, out_dtype=hidden_states.dtype)
+freqs_cis = self.precompute_freqs_cis(indices_grid)

 # 2. Blocks
 for block_idx, block in enumerate(self.transformer_1d_blocks):
@@ -1,7 +1,6 @@
 from abc import ABC, abstractmethod
 from enum import Enum
 import functools
-import logging
 import math
 from typing import Dict, Optional, Tuple

@@ -15,8 +14,6 @@ import comfy.ldm.common_dit

 from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords

-logger = logging.getLogger(__name__)

 def _log_base(x, base):
 return np.log(x) / np.log(base)

@@ -275,30 +272,6 @@ class PixArtAlphaTextProjection(nn.Module):
 return hidden_states


-class NormSingleLinearTextProjection(nn.Module):
-"""Text projection for 20B models - single linear with RMSNorm (no activation)."""
-
-def __init__(
-self, in_features, hidden_size, dtype=None, device=None, operations=None
-):
-super().__init__()
-if operations is None:
-operations = comfy.ops.disable_weight_init
-self.in_norm = operations.RMSNorm(
-in_features, eps=1e-6, elementwise_affine=False
-)
-self.linear_1 = operations.Linear(
-in_features, hidden_size, bias=True, dtype=dtype, device=device
-)
-self.hidden_size = hidden_size
-self.in_features = in_features
-
-def forward(self, caption):
-caption = self.in_norm(caption)
-caption = caption * (self.hidden_size / self.in_features) ** 0.5
-return self.linear_1(caption)
-
-
 class GELU_approx(nn.Module):
 def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=None):
 super().__init__()

@@ -367,7 +340,6 @@ class CrossAttention(nn.Module):
 dim_head=64,
 dropout=0.0,
 attn_precision=None,
-apply_gated_attention=False,
 dtype=None,
 device=None,
 operations=None,

@@ -387,12 +359,6 @@ class CrossAttention(nn.Module):
 self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
 self.to_v = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)

-# Optional per-head gating
-if apply_gated_attention:
-self.to_gate_logits = operations.Linear(query_dim, heads, bias=True, dtype=dtype, device=device)
-else:
-self.to_gate_logits = None

 self.to_out = nn.Sequential(
 operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout)
 )

@@ -414,30 +380,16 @@ class CrossAttention(nn.Module):
 out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
 else:
 out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)

-# Apply per-head gating if enabled
-if self.to_gate_logits is not None:
-gate_logits = self.to_gate_logits(x) # (B, T, H)
-b, t, _ = out.shape
-out = out.view(b, t, self.heads, self.dim_head)
-gates = 2.0 * torch.sigmoid(gate_logits) # zero-init -> identity
-out = out * gates.unsqueeze(-1)
-out = out.view(b, t, self.heads * self.dim_head)

 return self.to_out(out)

-# 6 base ADaLN params (shift/scale/gate for MSA + MLP), +3 for cross-attention Q (shift/scale/gate)
-ADALN_BASE_PARAMS_COUNT = 6
-ADALN_CROSS_ATTN_PARAMS_COUNT = 9

 class BasicTransformerBlock(nn.Module):
 def __init__(
-self, dim, n_heads, d_head, context_dim=None, attn_precision=None, cross_attention_adaln=False, dtype=None, device=None, operations=None
+self, dim, n_heads, d_head, context_dim=None, attn_precision=None, dtype=None, device=None, operations=None
 ):
 super().__init__()

 self.attn_precision = attn_precision
-self.cross_attention_adaln = cross_attention_adaln
 self.attn1 = CrossAttention(
 query_dim=dim,
 heads=n_heads,

@@ -461,24 +413,17 @@ class BasicTransformerBlock(nn.Module):
 operations=operations,
 )

-num_ada_params = ADALN_CROSS_ATTN_PARAMS_COUNT if cross_attention_adaln else ADALN_BASE_PARAMS_COUNT
+self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))
-self.scale_shift_table = nn.Parameter(torch.empty(num_ada_params, dim, device=device, dtype=dtype))

-if cross_attention_adaln:
+def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}):
-self.prompt_scale_shift_table = nn.Parameter(torch.empty(2, dim, device=device, dtype=dtype))
+shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)

-def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}, self_attention_mask=None, prompt_timestep=None):
+attn1_input = comfy.ldm.common_dit.rms_norm(x)
-shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None, :6].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)[:, :, :6, :]).unbind(dim=2)
+attn1_input = torch.addcmul(attn1_input, attn1_input, scale_msa).add_(shift_msa)
+attn1_input = self.attn1(attn1_input, pe=pe, transformer_options=transformer_options)
+x.addcmul_(attn1_input, gate_msa)
+del attn1_input

-x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe, mask=self_attention_mask, transformer_options=transformer_options) * gate_msa
-
-if self.cross_attention_adaln:
-shift_q_mca, scale_q_mca, gate_mca = (self.scale_shift_table[None, None, 6:9].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)[:, :, 6:9, :]).unbind(dim=2)
-x += apply_cross_attention_adaln(
-x, context, self.attn2, shift_q_mca, scale_q_mca, gate_mca,
-self.prompt_scale_shift_table, prompt_timestep, attention_mask, transformer_options,
-)
-else:
 x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options)

 y = comfy.ldm.common_dit.rms_norm(x)

@@ -487,47 +432,6 @@ class BasicTransformerBlock(nn.Module):

 return x

-def compute_prompt_timestep(adaln_module, timestep_scaled, batch_size, hidden_dtype):
-"""Compute a single global prompt timestep for cross-attention ADaLN.
-
-Uses the max across tokens (matching JAX max_per_segment) and broadcasts
-over text tokens. Returns None when *adaln_module* is None.
-"""
-if adaln_module is None:
-return None
-ts_input = (
-timestep_scaled.max(dim=1, keepdim=True).values.flatten()
-if timestep_scaled.dim() > 1
-else timestep_scaled.flatten()
-)
-prompt_ts, _ = adaln_module(
-ts_input,
-{"resolution": None, "aspect_ratio": None},
-batch_size=batch_size,
-hidden_dtype=hidden_dtype,
-)
-return prompt_ts.view(batch_size, 1, prompt_ts.shape[-1])
-
-
-def apply_cross_attention_adaln(
-x, context, attn, q_shift, q_scale, q_gate,
-prompt_scale_shift_table, prompt_timestep,
-attention_mask=None, transformer_options={},
-):
-"""Apply cross-attention with ADaLN modulation (shift/scale/gate on Q and KV).
-
-Q params (q_shift, q_scale, q_gate) are pre-extracted by the caller so
-that both regular tensors and CompressedTimestep are supported.
-"""
-batch_size = x.shape[0]
-shift_kv, scale_kv = (
-prompt_scale_shift_table[None, None].to(device=x.device, dtype=x.dtype)
-+ prompt_timestep.reshape(batch_size, prompt_timestep.shape[1], 2, -1)
-).unbind(dim=2)
-attn_input = comfy.ldm.common_dit.rms_norm(x) * (1 + q_scale) + q_shift
-encoder_hidden_states = context * (1 + scale_kv) + shift_kv
-return attn(attn_input, context=encoder_hidden_states, mask=attention_mask, transformer_options=transformer_options) * q_gate
-
 def get_fractional_positions(indices_grid, max_pos):
 n_pos_dims = indices_grid.shape[1]
 assert n_pos_dims == len(max_pos), f'Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})'
@@ -649,9 +553,6 @@ class LTXBaseModel(torch.nn.Module, ABC):
 vae_scale_factors: tuple = (8, 32, 32),
 use_middle_indices_grid=False,
 timestep_scale_multiplier = 1000.0,
-caption_proj_before_connector=False,
-cross_attention_adaln=False,
-caption_projection_first_linear=True,
 dtype=None,
 device=None,
 operations=None,

@@ -678,9 +579,6 @@ class LTXBaseModel(torch.nn.Module, ABC):
 self.causal_temporal_positioning = causal_temporal_positioning
 self.operations = operations
 self.timestep_scale_multiplier = timestep_scale_multiplier
-self.caption_proj_before_connector = caption_proj_before_connector
-self.cross_attention_adaln = cross_attention_adaln
-self.caption_projection_first_linear = caption_projection_first_linear

 # Common dimensions
 self.inner_dim = num_attention_heads * attention_head_dim

@@ -708,30 +606,10 @@ class LTXBaseModel(torch.nn.Module, ABC):
 self.in_channels, self.inner_dim, bias=True, dtype=dtype, device=device
 )

-embedding_coefficient = ADALN_CROSS_ATTN_PARAMS_COUNT if self.cross_attention_adaln else ADALN_BASE_PARAMS_COUNT
 self.adaln_single = AdaLayerNormSingle(
-self.inner_dim, embedding_coefficient=embedding_coefficient, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations
+self.inner_dim, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations
 )

-if self.cross_attention_adaln:
-self.prompt_adaln_single = AdaLayerNormSingle(
-self.inner_dim, embedding_coefficient=2, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations
-)
-else:
-self.prompt_adaln_single = None
-
-if self.caption_proj_before_connector:
-if self.caption_projection_first_linear:
-self.caption_projection = NormSingleLinearTextProjection(
-in_features=self.caption_channels,
-hidden_size=self.inner_dim,
-dtype=dtype,
-device=device,
-operations=self.operations,
-)
-else:
-self.caption_projection = lambda a: a
-else:
 self.caption_projection = PixArtAlphaTextProjection(
 in_features=self.caption_channels,
 hidden_size=self.inner_dim,

@@ -760,16 +638,8 @@ class LTXBaseModel(torch.nn.Module, ABC):
 """Process input data. Must be implemented by subclasses."""
 pass

-def _build_guide_self_attention_mask(self, x, transformer_options, merged_args):
-"""Build self-attention mask for per-guide attention attenuation.
-
-Base implementation returns None (no attenuation). Subclasses that
-support guide-based attention control should override this.
-"""
-return None
-
 @abstractmethod
-def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, self_attention_mask=None, **kwargs):
+def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, **kwargs):
 """Process transformer blocks. Must be implemented by subclasses."""
 pass

@@ -784,9 +654,9 @@ class LTXBaseModel(torch.nn.Module, ABC):
 if grid_mask is not None:
 timestep = timestep[:, grid_mask]

-timestep_scaled = timestep * self.timestep_scale_multiplier
+timestep = timestep * self.timestep_scale_multiplier
 timestep, embedded_timestep = self.adaln_single(
-timestep_scaled.flatten(),
+timestep.flatten(),
 {"resolution": None, "aspect_ratio": None},
 batch_size=batch_size,
 hidden_dtype=hidden_dtype,

@@ -796,18 +666,14 @@ class LTXBaseModel(torch.nn.Module, ABC):
 timestep = timestep.view(batch_size, -1, timestep.shape[-1])
 embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.shape[-1])

-prompt_timestep = compute_prompt_timestep(
+return timestep, embedded_timestep
-self.prompt_adaln_single, timestep_scaled, batch_size, hidden_dtype
-)
-
-return timestep, embedded_timestep, prompt_timestep

 def _prepare_context(self, context, batch_size, x, attention_mask=None):
 """Prepare context for transformer blocks."""
-if self.caption_proj_before_connector is False:
+if self.caption_projection is not None:
 context = self.caption_projection(context)

 context = context.view(batch_size, -1, x.shape[-1])

 return context, attention_mask

 def _precompute_freqs_cis(

@@ -915,25 +781,16 @@ class LTXBaseModel(torch.nn.Module, ABC):
 merged_args.update(additional_args)

 # Prepare timestep and context
-timestep, embedded_timestep, prompt_timestep = self._prepare_timestep(timestep, batch_size, input_dtype, **merged_args)
+timestep, embedded_timestep = self._prepare_timestep(timestep, batch_size, input_dtype, **merged_args)
-merged_args["prompt_timestep"] = prompt_timestep
 context, attention_mask = self._prepare_context(context, batch_size, x, attention_mask)

 # Prepare attention mask and positional embeddings
 attention_mask = self._prepare_attention_mask(attention_mask, input_dtype)
 pe = self._prepare_positional_embeddings(pixel_coords, frame_rate, input_dtype)

-# Build self-attention mask for per-guide attenuation
-self_attention_mask = self._build_guide_self_attention_mask(
-x, transformer_options, merged_args
-)
-
 # Process transformer blocks
 x = self._process_transformer_blocks(
-x, context, attention_mask, timestep, pe,
+x, context, attention_mask, timestep, pe, transformer_options=transformer_options, **merged_args
-transformer_options=transformer_options,
-self_attention_mask=self_attention_mask,
-**merged_args,
 )

 # Process output

@@ -957,9 +814,7 @@ class LTXVModel(LTXBaseModel):
 causal_temporal_positioning=False,
 vae_scale_factors=(8, 32, 32),
 use_middle_indices_grid=False,
-timestep_scale_multiplier=1000.0,
+timestep_scale_multiplier = 1000.0,
-caption_proj_before_connector=False,
-cross_attention_adaln=False,
 dtype=None,
 device=None,
 operations=None,

@@ -978,8 +833,6 @@ class LTXVModel(LTXBaseModel):
 vae_scale_factors=vae_scale_factors,
 use_middle_indices_grid=use_middle_indices_grid,
 timestep_scale_multiplier=timestep_scale_multiplier,
-caption_proj_before_connector=caption_proj_before_connector,
-cross_attention_adaln=cross_attention_adaln,
 dtype=dtype,
 device=device,
 operations=operations,

@@ -988,6 +841,7 @@ class LTXVModel(LTXBaseModel):

 def _init_model_components(self, device, dtype, **kwargs):
 """Initialize LTXV-specific components."""
+# No additional components needed for LTXV beyond base class
 pass

 def _init_transformer_blocks(self, device, dtype, **kwargs):

@@ -999,7 +853,6 @@ class LTXVModel(LTXBaseModel):
 self.num_attention_heads,
 self.attention_head_dim,
 context_dim=self.cross_attention_dim,
-cross_attention_adaln=self.cross_attention_adaln,
 dtype=dtype,
 device=device,
 operations=self.operations,
@@ -1037,257 +890,26 @@ class LTXVModel(LTXBaseModel):
 pixel_coords = pixel_coords[:, :, grid_mask, ...]

 kf_grid_mask = grid_mask[-keyframe_idxs.shape[2]:]

-# Compute per-guide surviving token counts from guide_attention_entries.
-# Each entry tracks one guide reference; they are appended in order and
-# their pre_filter_counts partition the kf_grid_mask.
-guide_entries = kwargs.get("guide_attention_entries", None)
-if guide_entries:
-total_pfc = sum(e["pre_filter_count"] for e in guide_entries)
-if total_pfc != len(kf_grid_mask):
-raise ValueError(
-f"guide pre_filter_counts ({total_pfc}) != "
-f"keyframe grid mask length ({len(kf_grid_mask)})"
-)
-resolved_entries = []
-offset = 0
-for entry in guide_entries:
-pfc = entry["pre_filter_count"]
-entry_mask = kf_grid_mask[offset:offset + pfc]
-surviving = int(entry_mask.sum().item())
-resolved_entries.append({
-**entry,
-"surviving_count": surviving,
-})
-offset += pfc
-additional_args["resolved_guide_entries"] = resolved_entries
-
 keyframe_idxs = keyframe_idxs[..., kf_grid_mask, :]
 pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs

-# Total surviving guide tokens (all guides)
-additional_args["num_guide_tokens"] = keyframe_idxs.shape[2]
-
 x = self.patchify_proj(x)
 return x, pixel_coords, additional_args

-def _build_guide_self_attention_mask(self, x, transformer_options, merged_args):
+def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs):
-"""Build self-attention mask for per-guide attention attenuation.
-
-Reads resolved_guide_entries from merged_args (computed in _process_input)
-to build a log-space additive bias mask that attenuates noisy ↔ guide
-attention for each guide reference independently.
-
-Returns None if no attenuation is needed (all strengths == 1.0 and no
-spatial masks, or no guide tokens).
-"""
-if isinstance(x, list):
-# AV model: x = [vx, ax]; use vx for token count and device
-total_tokens = x[0].shape[1]
-device = x[0].device
-dtype = x[0].dtype
-else:
-total_tokens = x.shape[1]
-device = x.device
-dtype = x.dtype
-
-num_guide_tokens = merged_args.get("num_guide_tokens", 0)
-if num_guide_tokens == 0:
-return None
-
-resolved_entries = merged_args.get("resolved_guide_entries", None)
-if not resolved_entries:
-return None
-
-# Check if any attenuation is actually needed
-needs_attenuation = any(
-e["strength"] < 1.0 or e.get("pixel_mask") is not None
-for e in resolved_entries
-)
-if not needs_attenuation:
-return None
-
-# Build per-guide-token weights for all tracked guide tokens.
-# Guides are appended in order at the end of the sequence.
-guide_start = total_tokens - num_guide_tokens
-all_weights = []
-total_tracked = 0
-
-for entry in resolved_entries:
-surviving = entry["surviving_count"]
-if surviving == 0:
-continue
-
-strength = entry["strength"]
-pixel_mask = entry.get("pixel_mask")
-latent_shape = entry.get("latent_shape")
-
-if pixel_mask is not None and latent_shape is not None:
-f_lat, h_lat, w_lat = latent_shape
-per_token = self._downsample_mask_to_latent(
-pixel_mask.to(device=device, dtype=dtype),
-f_lat, h_lat, w_lat,
-)
-# per_token shape: (B, f_lat*h_lat*w_lat).
-# Collapse batch dim — the mask is assumed identical across the
-# batch; validate and take the first element to get (1, tokens).
-if per_token.shape[0] > 1:
-ref = per_token[0]
-for bi in range(1, per_token.shape[0]):
-if not torch.equal(ref, per_token[bi]):
-logger.warning(
-"pixel_mask differs across batch elements; "
-"using first element only."
-)
-break
-per_token = per_token[:1]
-# `surviving` is the post-grid_mask token count.
-# Clamp to surviving to handle any mismatch safely.
-n_weights = min(per_token.shape[1], surviving)
-weights = per_token[:, :n_weights] * strength # (1, n_weights)
-else:
-weights = torch.full(
-(1, surviving), strength, device=device, dtype=dtype
-)
-
-all_weights.append(weights)
-total_tracked += weights.shape[1]
-
-if not all_weights:
-return None
-
-# Concatenate per-token weights for all tracked guides
-tracked_weights = torch.cat(all_weights, dim=1) # (1, total_tracked)
-
-# Check if any weight is actually < 1.0 (otherwise no attenuation needed)
-if (tracked_weights >= 1.0).all():
-return None
-
-# Build the mask: guide tokens are at the end of the sequence.
-# Tracked guides come first (in order), untracked follow.
-return self._build_self_attention_mask(
-total_tokens, num_guide_tokens, total_tracked,
-tracked_weights, guide_start, device, dtype,
-)
-
-@staticmethod
-def _downsample_mask_to_latent(mask, f_lat, h_lat, w_lat):
-"""Downsample a pixel-space mask to per-token latent weights.
-
-Args:
-mask: (B, 1, F_pix, H_pix, W_pix) pixel-space mask with values in [0, 1].
-f_lat: Number of latent frames (pre-dilation original count).
-h_lat: Latent height (pre-dilation original height).
-w_lat: Latent width (pre-dilation original width).
-
-Returns:
-(B, F_lat * H_lat * W_lat) flattened per-token weights.
-"""
-b = mask.shape[0]
-f_pix = mask.shape[2]
-
-# Spatial downsampling: area interpolation per frame
-spatial_down = torch.nn.functional.interpolate(
-rearrange(mask, "b 1 f h w -> (b f) 1 h w"),
-size=(h_lat, w_lat),
-mode="area",
-)
-spatial_down = rearrange(spatial_down, "(b f) 1 h w -> b 1 f h w", b=b)
-
-# Temporal downsampling: first pixel frame maps to first latent frame,
-# remaining pixel frames are averaged in groups for causal temporal structure.
-first_frame = spatial_down[:, :, :1, :, :]
-if f_pix > 1 and f_lat > 1:
-remaining_pix = f_pix - 1
-remaining_lat = f_lat - 1
-t = remaining_pix // remaining_lat
-if t < 1:
-# Fewer pixel frames than latent frames — upsample by repeating
-# the available pixel frames via nearest interpolation.
-rest_flat = rearrange(
-spatial_down[:, :, 1:, :, :],
-"b 1 f h w -> (b h w) 1 f",
-)
-rest_up = torch.nn.functional.interpolate(
-rest_flat, size=remaining_lat, mode="nearest",
-)
-rest = rearrange(
-rest_up, "(b h w) 1 f -> b 1 f h w",
-b=b, h=h_lat, w=w_lat,
-)
-else:
-# Trim trailing pixel frames that don't fill a complete group
-usable = remaining_lat * t
-rest = rearrange(
-spatial_down[:, :, 1:1 + usable, :, :],
-"b 1 (f t) h w -> b 1 f t h w",
-t=t,
-)
-rest = rest.mean(dim=3)
-latent_mask = torch.cat([first_frame, rest], dim=2)
-elif f_lat > 1:
-# Single pixel frame but multiple latent frames — repeat the
-# single frame across all latent frames.
-latent_mask = first_frame.expand(-1, -1, f_lat, -1, -1)
-else:
-latent_mask = first_frame
-
-return rearrange(latent_mask, "b 1 f h w -> b (f h w)")
-
-@staticmethod
-def _build_self_attention_mask(total_tokens, num_guide_tokens, tracked_count,
-tracked_weights, guide_start, device, dtype):
-"""Build a log-space additive self-attention bias mask.
-
-Attenuates attention between noisy tokens and tracked guide tokens.
-Untracked guide tokens (at the end of the guide portion) keep full attention.
-
-Args:
-total_tokens: Total sequence length.
-num_guide_tokens: Total guide tokens (all guides) at end of sequence.
-tracked_count: Number of tracked guide tokens (first in the guide portion).
-tracked_weights: (1, tracked_count) tensor, values in [0, 1].
-guide_start: Index where guide tokens begin in the sequence.
-device: Target device.
-dtype: Target dtype.
-
-Returns:
-(1, 1, total_tokens, total_tokens) additive bias mask.
-0.0 = full attention, negative = attenuated, finfo.min = effectively fully masked.
-"""
-finfo = torch.finfo(dtype)
-mask = torch.zeros((1, 1, total_tokens, total_tokens), device=device, dtype=dtype)
-tracked_end = guide_start + tracked_count
-
-# Convert weights to log-space bias
-w = tracked_weights.to(device=device, dtype=dtype) # (1, tracked_count)
-log_w = torch.full_like(w, finfo.min)
-positive_mask = w > 0
-if positive_mask.any():
-log_w[positive_mask] = torch.log(w[positive_mask].clamp(min=finfo.tiny))
-
-# noisy → tracked guides: each noisy row gets the same per-guide weight
-mask[:, :, :guide_start, guide_start:tracked_end] = log_w.view(1, 1, 1, -1)
-# tracked guides → noisy: each guide row broadcasts its weight across noisy cols
-mask[:, :, guide_start:tracked_end, :guide_start] = log_w.view(1, 1, -1, 1)
-
-return mask
-
-def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, self_attention_mask=None, **kwargs):
 """Process transformer blocks for LTXV."""
 patches_replace = transformer_options.get("patches_replace", {})
 blocks_replace = patches_replace.get("dit", {})
-prompt_timestep = kwargs.get("prompt_timestep", None)

 for i, block in enumerate(self.transformer_blocks):
 if ("double_block", i) in blocks_replace:

 def block_wrap(args):
 out = {}
-out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"], self_attention_mask=args.get("self_attention_mask"), prompt_timestep=args.get("prompt_timestep"))
+out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"])
 return out

-out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options, "self_attention_mask": self_attention_mask, "prompt_timestep": prompt_timestep}, {"original_block": block_wrap})
+out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options}, {"original_block": block_wrap})
 x = out["img"]
 else:
 x = block(

@@ -1297,8 +919,6 @@ class LTXVModel(LTXBaseModel):
 timestep=timestep,
 pe=pe,
 transformer_options=transformer_options,
-self_attention_mask=self_attention_mask,
-prompt_timestep=prompt_timestep,
 )

 return x
@@ -13,7 +13,7 @@ from comfy.ldm.lightricks.vae.causal_audio_autoencoder import (
 CausalityAxis,
 CausalAudioAutoencoder,
 )
-from comfy.ldm.lightricks.vocoders.vocoder import Vocoder, VocoderWithBWE
+from comfy.ldm.lightricks.vocoders.vocoder import Vocoder

 LATENT_DOWNSAMPLE_FACTOR = 4

@@ -141,9 +141,6 @@ class AudioVAE(torch.nn.Module):
 vocoder_sd = utils.state_dict_prefix_replace(state_dict, {"vocoder.": ""}, filter_keys=True)

 self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder)
-if "bwe" in component_config.vocoder:
-self.vocoder = VocoderWithBWE(config=component_config.vocoder)
-else:
 self.vocoder = Vocoder(config=component_config.vocoder)

 self.autoencoder.load_state_dict(vae_sd, strict=False)
@@ -822,23 +822,26 @@ class CausalAudioAutoencoder(nn.Module):
 super().__init__()

 if config is None:
-config = self.get_default_config()
+config = self._guess_config()

+# Extract encoder and decoder configs from the new format
 model_config = config.get("model", {}).get("params", {})
+variables_config = config.get("variables", {})

-self.sampling_rate = model_config.get(
+self.sampling_rate = variables_config.get(
-"sampling_rate", config.get("sampling_rate", 16000)
+"sampling_rate",
+model_config.get("sampling_rate", config.get("sampling_rate", 16000)),
 )
 encoder_config = model_config.get("encoder", model_config.get("ddconfig", {}))
 decoder_config = model_config.get("decoder", encoder_config)

 # Load mel spectrogram parameters
 self.mel_bins = encoder_config.get("mel_bins", 64)
-self.mel_hop_length = config.get("preprocessing", {}).get("stft", {}).get("hop_length", 160)
+self.mel_hop_length = model_config.get("preprocessing", {}).get("stft", {}).get("hop_length", 160)
-self.n_fft = config.get("preprocessing", {}).get("stft", {}).get("filter_length", 1024)
+self.n_fft = model_config.get("preprocessing", {}).get("stft", {}).get("filter_length", 1024)

 # Store causality configuration at VAE level (not just in encoder internals)
-causality_axis_value = encoder_config.get("causality_axis", CausalityAxis.HEIGHT.value)
+causality_axis_value = encoder_config.get("causality_axis", CausalityAxis.WIDTH.value)
 self.causality_axis = CausalityAxis.str_to_enum(causality_axis_value)
 self.is_causal = self.causality_axis == CausalityAxis.HEIGHT

@@ -847,38 +850,44 @@ class CausalAudioAutoencoder(nn.Module):

 self.per_channel_statistics = processor()

-def get_default_config(self):
+def _guess_config(self):
-ddconfig = {
+encoder_config = {
-"double_z": True,
+# Required parameters - based on ltx-video-av-1679000 model metadata
-"mel_bins": 64,
-"z_channels": 8,
-"resolution": 256,
-"downsample_time": False,
-"in_channels": 2,
-"out_ch": 2,
 "ch": 128,
-"ch_mult": [1, 2, 4],
+"out_ch": 8,
+"ch_mult": [1, 2, 4], # Based on metadata: [1, 2, 4] not [1, 2, 4, 8]
 "num_res_blocks": 2,
-"attn_resolutions": [],
+"attn_resolutions": [], # Based on metadata: empty list, no attention
 "dropout": 0.0,
-"mid_block_add_attention": False,
+"resamp_with_conv": True,
+"in_channels": 2, # stereo
+"resolution": 256,
+"z_channels": 8,
+"double_z": True,
+"attn_type": "vanilla",
+"mid_block_add_attention": False, # Based on metadata: false
 "norm_type": "pixel",
-"causality_axis": "height",
+"causality_axis": "height", # Based on metadata
+"mel_bins": 64, # Based on metadata: mel_bins = 64
+}

+decoder_config = {
+# Inherits encoder config, can override specific params
+**encoder_config,
+"out_ch": 2, # Stereo audio output (2 channels)
+"give_pre_end": False,
+"tanh_out": False,
 }

 config = {
+"_class_name": "CausalAudioAutoencoder",
+"sampling_rate": 16000,
 "model": {
 "params": {
-"ddconfig": ddconfig,
+"encoder": encoder_config,
-"sampling_rate": 16000,
+"decoder": decoder_config,
 }
 },
-"preprocessing": {
-"stft": {
-"filter_length": 1024,
-"hop_length": 160,
-},
-},
 }

 return config
@@ -15,9 +15,6 @@ from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed
|
|||||||
|
|
||||||
ops = comfy.ops.disable_weight_init
|
ops = comfy.ops.disable_weight_init
|
||||||
|
|
||||||
def in_meta_context():
|
|
||||||
return torch.device("meta") == torch.empty(0).device
|
|
||||||
|
|
||||||
def mark_conv3d_ended(module):
|
def mark_conv3d_ended(module):
|
||||||
tid = threading.get_ident()
|
tid = threading.get_ident()
|
||||||
for _, m in module.named_modules():
|
for _, m in module.named_modules():
|
||||||
@@ -353,10 +350,6 @@ class Decoder(nn.Module):
|
|||||||
output_channel = output_channel * block_params.get("multiplier", 2)
|
output_channel = output_channel * block_params.get("multiplier", 2)
|
||||||
if block_name == "compress_all":
|
if block_name == "compress_all":
|
||||||
output_channel = output_channel * block_params.get("multiplier", 1)
|
output_channel = output_channel * block_params.get("multiplier", 1)
|
||||||
if block_name == "compress_space":
|
|
||||||
output_channel = output_channel * block_params.get("multiplier", 1)
|
|
||||||
if block_name == "compress_time":
|
|
||||||
output_channel = output_channel * block_params.get("multiplier", 1)
|
|
||||||
|
|
||||||
self.conv_in = make_conv_nd(
|
self.conv_in = make_conv_nd(
|
||||||
dims,
|
dims,
|
||||||
@@ -402,21 +395,17 @@ class Decoder(nn.Module):
|
|||||||
spatial_padding_mode=spatial_padding_mode,
|
spatial_padding_mode=spatial_padding_mode,
|
||||||
)
|
)
|
||||||
elif block_name == "compress_time":
|
elif block_name == "compress_time":
|
||||||
output_channel = output_channel // block_params.get("multiplier", 1)
|
|
||||||
block = DepthToSpaceUpsample(
|
block = DepthToSpaceUpsample(
|
||||||
dims=dims,
|
dims=dims,
|
||||||
in_channels=input_channel,
|
in_channels=input_channel,
|
||||||
stride=(2, 1, 1),
|
stride=(2, 1, 1),
|
||||||
out_channels_reduction_factor=block_params.get("multiplier", 1),
|
|
||||||
spatial_padding_mode=spatial_padding_mode,
|
spatial_padding_mode=spatial_padding_mode,
|
||||||
)
|
)
|
||||||
elif block_name == "compress_space":
|
elif block_name == "compress_space":
|
||||||
output_channel = output_channel // block_params.get("multiplier", 1)
|
|
||||||
block = DepthToSpaceUpsample(
|
block = DepthToSpaceUpsample(
|
||||||
dims=dims,
|
dims=dims,
|
||||||
in_channels=input_channel,
|
in_channels=input_channel,
|
||||||
stride=(1, 2, 2),
|
stride=(1, 2, 2),
|
||||||
out_channels_reduction_factor=block_params.get("multiplier", 1),
|
|
||||||
spatial_padding_mode=spatial_padding_mode,
|
spatial_padding_mode=spatial_padding_mode,
|
||||||
)
|
)
|
||||||
             elif block_name == "compress_all":
@@ -466,15 +455,6 @@ class Decoder(nn.Module):
                 output_channel * 2, 0, operations=ops,
             )
             self.last_scale_shift_table = nn.Parameter(torch.empty(2, output_channel))
-        else:
-            self.register_buffer(
-                "last_scale_shift_table",
-                torch.tensor(
-                    [0.0, 0.0],
-                    device="cpu" if in_meta_context() else None
-                ).unsqueeze(1).expand(2, output_channel),
-                persistent=False,
-            )


    # def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
@@ -903,15 +883,6 @@ class ResnetBlock3D(nn.Module):
             self.scale_shift_table = nn.Parameter(
                 torch.randn(4, in_channels) / in_channels**0.5
             )
-        else:
-            self.register_buffer(
-                "scale_shift_table",
-                torch.tensor(
-                    [0.0, 0.0, 0.0, 0.0],
-                    device="cpu" if in_meta_context() else None
-                ).unsqueeze(1).expand(4, in_channels),
-                persistent=False,
-            )

         self.temporal_cache_state={}

@@ -1041,6 +1012,9 @@ class processor(nn.Module):
         super().__init__()
         self.register_buffer("std-of-means", torch.empty(128))
         self.register_buffer("mean-of-means", torch.empty(128))
+        self.register_buffer("mean-of-stds", torch.empty(128))
+        self.register_buffer("mean-of-stds_over_std-of-means", torch.empty(128))
+        self.register_buffer("channel", torch.empty(128))

    def un_normalize(self, x):
        return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)) + self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)
@@ -1053,12 +1027,9 @@ class VideoVAE(nn.Module):
         super().__init__()

         if config is None:
-            config = self.get_default_config(version)
+            config = self.guess_config(version)

-        self.config = config
         self.timestep_conditioning = config.get("timestep_conditioning", False)
-        self.decode_noise_scale = config.get("decode_noise_scale", 0.025)
-        self.decode_timestep = config.get("decode_timestep", 0.05)
         double_z = config.get("double_z", True)
         latent_log_var = config.get(
             "latent_log_var", "per_channel" if double_z else "none"
@@ -1073,7 +1044,6 @@ class VideoVAE(nn.Module):
             latent_log_var=latent_log_var,
             norm_layer=config.get("norm_layer", "group_norm"),
             spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
-            base_channels=config.get("encoder_base_channels", 128),
         )

         self.decoder = Decoder(
@@ -1081,7 +1051,6 @@ class VideoVAE(nn.Module):
             in_channels=config["latent_channels"],
             out_channels=config.get("out_channels", 3),
             blocks=config.get("decoder_blocks", config.get("decoder_blocks", config.get("blocks"))),
-            base_channels=config.get("decoder_base_channels", 128),
             patch_size=config.get("patch_size", 1),
             norm_layer=config.get("norm_layer", "group_norm"),
             causal=config.get("causal_decoder", False),
@@ -1091,7 +1060,7 @@ class VideoVAE(nn.Module):

         self.per_channel_statistics = processor()

-    def get_default_config(self, version):
+    def guess_config(self, version):
         if version == 0:
             config = {
                 "_class_name": "CausalVideoAutoencoder",
@@ -1198,7 +1167,8 @@ class VideoVAE(nn.Module):
         means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
         return self.per_channel_statistics.normalize(means)

-    def decode(self, x):
+    def decode(self, x, timestep=0.05, noise_scale=0.025):
         if self.timestep_conditioning: #TODO: seed
-            x = torch.randn_like(x) * self.decode_noise_scale + (1.0 - self.decode_noise_scale) * x
-        return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=self.decode_timestep)
+            x = torch.randn_like(x) * noise_scale + (1.0 - noise_scale) * x
+        return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=timestep)

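The decode() change above moves the decode timestep and the noise blend out of the config-backed attributes (self.decode_timestep, self.decode_noise_scale) and into call arguments. A minimal sketch of what that noise blend computes, using plain tensors with made-up shapes rather than the repository's VAE:

    import torch

    def mix_decode_noise(latent: torch.Tensor, noise_scale: float = 0.025) -> torch.Tensor:
        # Blend a small amount of fresh Gaussian noise into the latent before
        # decoding, as the timestep-conditioned path does:
        #     x = noise * s + (1 - s) * x
        return torch.randn_like(latent) * noise_scale + (1.0 - noise_scale) * latent

    latent = torch.zeros(1, 128, 2, 4, 4)   # hypothetical (B, C, T, H, W) latent
    mixed = mix_decode_noise(latent)
    print(mixed.std())                      # roughly noise_scale, since the input is zero
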
@@ -2,9 +2,7 @@ import torch
 import torch.nn.functional as F
 import torch.nn as nn
 import comfy.ops
-import comfy.model_management
 import numpy as np
-import math

 ops = comfy.ops.disable_weight_init

@@ -14,307 +12,6 @@ def get_padding(kernel_size, dilation=1):
     return int((kernel_size * dilation - dilation) / 2)


-# ---------------------------------------------------------------------------
-# Anti-aliased resampling helpers (kaiser-sinc filters) for BigVGAN v2
-# Adopted from https://github.com/NVIDIA/BigVGAN
-# ---------------------------------------------------------------------------
-
-
-def _sinc(x: torch.Tensor):
-    return torch.where(
-        x == 0,
-        torch.tensor(1.0, device=x.device, dtype=x.dtype),
-        torch.sin(math.pi * x) / math.pi / x,
-    )
-
-
-def kaiser_sinc_filter1d(cutoff, half_width, kernel_size):
-    even = kernel_size % 2 == 0
-    half_size = kernel_size // 2
-    delta_f = 4 * half_width
-    A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
-    if A > 50.0:
-        beta = 0.1102 * (A - 8.7)
-    elif A >= 21.0:
-        beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
-    else:
-        beta = 0.0
-    window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
-    if even:
-        time = torch.arange(-half_size, half_size) + 0.5
-    else:
-        time = torch.arange(kernel_size) - half_size
-    if cutoff == 0:
-        filter_ = torch.zeros_like(time)
-    else:
-        filter_ = 2 * cutoff * window * _sinc(2 * cutoff * time)
-        filter_ /= filter_.sum()
-    filter = filter_.view(1, 1, kernel_size)
-    return filter
-
-
-class LowPassFilter1d(nn.Module):
-    def __init__(
-        self,
-        cutoff=0.5,
-        half_width=0.6,
-        stride=1,
-        padding=True,
-        padding_mode="replicate",
-        kernel_size=12,
-    ):
-        super().__init__()
-        if cutoff < -0.0:
-            raise ValueError("Minimum cutoff must be larger than zero.")
-        if cutoff > 0.5:
-            raise ValueError("A cutoff above 0.5 does not make sense.")
-        self.kernel_size = kernel_size
-        self.even = kernel_size % 2 == 0
-        self.pad_left = kernel_size // 2 - int(self.even)
-        self.pad_right = kernel_size // 2
-        self.stride = stride
-        self.padding = padding
-        self.padding_mode = padding_mode
-        filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
-        self.register_buffer("filter", filter)
-
-    def forward(self, x):
-        _, C, _ = x.shape
-        if self.padding:
-            x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
-        return F.conv1d(x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device), stride=self.stride, groups=C)
-
-
-class UpSample1d(nn.Module):
-    def __init__(self, ratio=2, kernel_size=None, persistent=True, window_type="kaiser"):
-        super().__init__()
-        self.ratio = ratio
-        self.stride = ratio
-
-        if window_type == "hann":
-            # Hann-windowed sinc filter — identical to torchaudio.functional.resample
-            # with its default parameters (rolloff=0.99, lowpass_filter_width=6).
-            # Uses replicate boundary padding, matching the reference resampler exactly.
-            rolloff = 0.99
-            lowpass_filter_width = 6
-            width = math.ceil(lowpass_filter_width / rolloff)
-            self.kernel_size = 2 * width * ratio + 1
-            self.pad = width
-            self.pad_left = 2 * width * ratio
-            self.pad_right = self.kernel_size - ratio
-            t = (torch.arange(self.kernel_size) / ratio - width) * rolloff
-            t_clamped = t.clamp(-lowpass_filter_width, lowpass_filter_width)
-            window = torch.cos(t_clamped * math.pi / lowpass_filter_width / 2) ** 2
-            filter = (torch.sinc(t) * window * rolloff / ratio).view(1, 1, -1)
-        else:
-            # Kaiser-windowed sinc filter (BigVGAN default).
-            self.kernel_size = (
-                int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
-            )
-            self.pad = self.kernel_size // ratio - 1
-            self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
-            self.pad_right = (
-                self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
-            )
-            filter = kaiser_sinc_filter1d(
-                cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
-            )
-
-        self.register_buffer("filter", filter, persistent=persistent)
-
-    def forward(self, x):
-        _, C, _ = x.shape
-        x = F.pad(x, (self.pad, self.pad), mode="replicate")
-        x = self.ratio * F.conv_transpose1d(
-            x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device), stride=self.stride, groups=C
-        )
-        x = x[..., self.pad_left : -self.pad_right]
-        return x
-
-
-class DownSample1d(nn.Module):
-    def __init__(self, ratio=2, kernel_size=None):
-        super().__init__()
-        self.ratio = ratio
-        self.kernel_size = (
-            int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
-        )
-        self.lowpass = LowPassFilter1d(
-            cutoff=0.5 / ratio,
-            half_width=0.6 / ratio,
-            stride=ratio,
-            kernel_size=self.kernel_size,
-        )
-
-    def forward(self, x):
-        return self.lowpass(x)
-
-
-class Activation1d(nn.Module):
-    def __init__(
-        self,
-        activation,
-        up_ratio=2,
-        down_ratio=2,
-        up_kernel_size=12,
-        down_kernel_size=12,
-    ):
-        super().__init__()
-        self.act = activation
-        self.upsample = UpSample1d(up_ratio, up_kernel_size)
-        self.downsample = DownSample1d(down_ratio, down_kernel_size)
-
-    def forward(self, x):
-        x = self.upsample(x)
-        x = self.act(x)
-        x = self.downsample(x)
-        return x
-
-
-# ---------------------------------------------------------------------------
-# BigVGAN v2 activations (Snake / SnakeBeta)
-# ---------------------------------------------------------------------------
-
-
-class Snake(nn.Module):
-    def __init__(
-        self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True
-    ):
-        super().__init__()
-        self.alpha_logscale = alpha_logscale
-        self.alpha = nn.Parameter(
-            torch.zeros(in_features)
-            if alpha_logscale
-            else torch.ones(in_features) * alpha
-        )
-        self.alpha.requires_grad = alpha_trainable
-        self.eps = 1e-9
-
-    def forward(self, x):
-        a = comfy.model_management.cast_to(self.alpha.unsqueeze(0).unsqueeze(-1), dtype=x.dtype, device=x.device)
-        if self.alpha_logscale:
-            a = torch.exp(a)
-        return x + (1.0 / (a + self.eps)) * torch.sin(x * a).pow(2)
-
-
-class SnakeBeta(nn.Module):
-    def __init__(
-        self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True
-    ):
-        super().__init__()
-        self.alpha_logscale = alpha_logscale
-        self.alpha = nn.Parameter(
-            torch.zeros(in_features)
-            if alpha_logscale
-            else torch.ones(in_features) * alpha
-        )
-        self.alpha.requires_grad = alpha_trainable
-        self.beta = nn.Parameter(
-            torch.zeros(in_features)
-            if alpha_logscale
-            else torch.ones(in_features) * alpha
-        )
-        self.beta.requires_grad = alpha_trainable
-        self.eps = 1e-9
-
-    def forward(self, x):
-        a = comfy.model_management.cast_to(self.alpha.unsqueeze(0).unsqueeze(-1), dtype=x.dtype, device=x.device)
-        b = comfy.model_management.cast_to(self.beta.unsqueeze(0).unsqueeze(-1), dtype=x.dtype, device=x.device)
-        if self.alpha_logscale:
-            a = torch.exp(a)
-            b = torch.exp(b)
-        return x + (1.0 / (b + self.eps)) * torch.sin(x * a).pow(2)
-
-
-# ---------------------------------------------------------------------------
-# BigVGAN v2 AMPBlock (Anti-aliased Multi-Periodicity)
-# ---------------------------------------------------------------------------
-
-
-class AMPBlock1(torch.nn.Module):
-    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), activation="snake"):
-        super().__init__()
-        act_cls = SnakeBeta if activation == "snakebeta" else Snake
-        self.convs1 = nn.ModuleList(
-            [
-                ops.Conv1d(
-                    channels,
-                    channels,
-                    kernel_size,
-                    1,
-                    dilation=dilation[0],
-                    padding=get_padding(kernel_size, dilation[0]),
-                ),
-                ops.Conv1d(
-                    channels,
-                    channels,
-                    kernel_size,
-                    1,
-                    dilation=dilation[1],
-                    padding=get_padding(kernel_size, dilation[1]),
-                ),
-                ops.Conv1d(
-                    channels,
-                    channels,
-                    kernel_size,
-                    1,
-                    dilation=dilation[2],
-                    padding=get_padding(kernel_size, dilation[2]),
-                ),
-            ]
-        )
-
-        self.convs2 = nn.ModuleList(
-            [
-                ops.Conv1d(
-                    channels,
-                    channels,
-                    kernel_size,
-                    1,
-                    dilation=1,
-                    padding=get_padding(kernel_size, 1),
-                ),
-                ops.Conv1d(
-                    channels,
-                    channels,
-                    kernel_size,
-                    1,
-                    dilation=1,
-                    padding=get_padding(kernel_size, 1),
-                ),
-                ops.Conv1d(
-                    channels,
-                    channels,
-                    kernel_size,
-                    1,
-                    dilation=1,
-                    padding=get_padding(kernel_size, 1),
-                ),
-            ]
-        )
-
-        self.acts1 = nn.ModuleList(
-            [Activation1d(act_cls(channels)) for _ in range(len(self.convs1))]
-        )
-        self.acts2 = nn.ModuleList(
-            [Activation1d(act_cls(channels)) for _ in range(len(self.convs2))]
-        )
-
-    def forward(self, x):
-        for c1, c2, a1, a2 in zip(self.convs1, self.convs2, self.acts1, self.acts2):
-            xt = a1(x)
-            xt = c1(xt)
-            xt = a2(xt)
-            xt = c2(xt)
-            x = x + xt
-        return x
-
-
-# ---------------------------------------------------------------------------
-# HiFi-GAN residual blocks
-# ---------------------------------------------------------------------------
-
-
 class ResBlock1(torch.nn.Module):
     def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
         super(ResBlock1, self).__init__()
@@ -422,7 +119,6 @@ class Vocoder(torch.nn.Module):
     """
     Vocoder model for synthesizing audio from spectrograms, based on: https://github.com/jik876/hifi-gan.

-    Supports both HiFi-GAN (resblock "1"/"2") and BigVGAN v2 (resblock "AMP1").
     """

     def __init__(self, config=None):
@@ -432,39 +128,19 @@ class Vocoder(torch.nn.Module):
             config = self.get_default_config()

         resblock_kernel_sizes = config.get("resblock_kernel_sizes", [3, 7, 11])
-        upsample_rates = config.get("upsample_rates", [5, 4, 2, 2, 2])
-        upsample_kernel_sizes = config.get("upsample_kernel_sizes", [16, 16, 8, 4, 4])
+        upsample_rates = config.get("upsample_rates", [6, 5, 2, 2, 2])
+        upsample_kernel_sizes = config.get("upsample_kernel_sizes", [16, 15, 8, 4, 4])
         resblock_dilation_sizes = config.get("resblock_dilation_sizes", [[1, 3, 5], [1, 3, 5], [1, 3, 5]])
         upsample_initial_channel = config.get("upsample_initial_channel", 1024)
         stereo = config.get("stereo", True)
-        activation = config.get("activation", "snake")
-        use_bias_at_final = config.get("use_bias_at_final", True)
+        resblock = config.get("resblock", "1")

-        # "output_sample_rate" is not present in recent checkpoint configs.
-        # When absent (None), AudioVAE.output_sample_rate computes it as:
-        #     sample_rate * vocoder.upsample_factor / mel_hop_length
-        # where upsample_factor = product of all upsample stride lengths,
-        # and mel_hop_length is loaded from the autoencoder config at
-        # preprocessing.stft.hop_length (see CausalAudioAutoencoder).
         self.output_sample_rate = config.get("output_sample_rate")
-        self.resblock = config.get("resblock", "1")
-        self.use_tanh_at_final = config.get("use_tanh_at_final", True)
-        self.apply_final_activation = config.get("apply_final_activation", True)
         self.num_kernels = len(resblock_kernel_sizes)
         self.num_upsamples = len(upsample_rates)

         in_channels = 128 if stereo else 64
         self.conv_pre = ops.Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)
-        if self.resblock == "1":
-            resblock_cls = ResBlock1
-        elif self.resblock == "2":
-            resblock_cls = ResBlock2
-        elif self.resblock == "AMP1":
-            resblock_cls = AMPBlock1
-        else:
-            raise ValueError(f"Unknown resblock type: {self.resblock}")
+        resblock_class = ResBlock1 if resblock == "1" else ResBlock2

         self.ups = nn.ModuleList()
         for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
@@ -481,40 +157,25 @@ class Vocoder(torch.nn.Module):
         self.resblocks = nn.ModuleList()
         for i in range(len(self.ups)):
             ch = upsample_initial_channel // (2 ** (i + 1))
-            for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
-                if self.resblock == "AMP1":
-                    self.resblocks.append(resblock_cls(ch, k, d, activation=activation))
-                else:
-                    self.resblocks.append(resblock_cls(ch, k, d))
+            for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
+                self.resblocks.append(resblock_class(ch, k, d))

         out_channels = 2 if stereo else 1
-        if self.resblock == "AMP1":
-            act_cls = SnakeBeta if activation == "snakebeta" else Snake
-            self.act_post = Activation1d(act_cls(ch))
-        else:
-            self.act_post = nn.LeakyReLU()
-
-        self.conv_post = ops.Conv1d(
-            ch, out_channels, 7, 1, padding=3, bias=use_bias_at_final
-        )
+        self.conv_post = ops.Conv1d(ch, out_channels, 7, 1, padding=3)

         self.upsample_factor = np.prod([self.ups[i].stride[0] for i in range(len(self.ups))])


    def get_default_config(self):
        """Generate default configuration for the vocoder."""

        config = {
            "resblock_kernel_sizes": [3, 7, 11],
-            "upsample_rates": [5, 4, 2, 2, 2],
-            "upsample_kernel_sizes": [16, 16, 8, 4, 4],
+            "upsample_rates": [6, 5, 2, 2, 2],
+            "upsample_kernel_sizes": [16, 15, 8, 4, 4],
            "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
            "upsample_initial_channel": 1024,
            "stereo": True,
            "resblock": "1",
-            "activation": "snake",
-            "use_bias_at_final": True,
-            "use_tanh_at_final": True,
        }

        return config
@@ -535,9 +196,7 @@ class Vocoder(torch.nn.Module):
         assert x.shape[1] == 2, "Input must have 2 channels for stereo"
         x = torch.cat((x[:, 0, :, :], x[:, 1, :, :]), dim=1)
         x = self.conv_pre(x)

         for i in range(self.num_upsamples):
-            if self.resblock != "AMP1":
-                x = F.leaky_relu(x, LRELU_SLOPE)
+            x = F.leaky_relu(x, LRELU_SLOPE)
             x = self.ups[i](x)
             xs = None
@@ -547,167 +206,8 @@ class Vocoder(torch.nn.Module):
             else:
                 xs += self.resblocks[i * self.num_kernels + j](x)
             x = xs / self.num_kernels
-        x = self.act_post(x)
+        x = F.leaky_relu(x)
         x = self.conv_post(x)

-        if self.apply_final_activation:
-            if self.use_tanh_at_final:
-                x = torch.tanh(x)
-            else:
-                x = torch.clamp(x, -1, 1)
+        x = torch.tanh(x)

         return x
-
-
-class _STFTFn(nn.Module):
-    """Implements STFT as a convolution with precomputed DFT × Hann-window bases.
-
-    The DFT basis rows (real and imaginary parts interleaved) multiplied by the causal
-    Hann window are stored as buffers and loaded from the checkpoint. Using the exact
-    bfloat16 bases from training ensures the mel values fed to the BWE generator are
-    bit-identical to what it was trained on.
-    """
-
-    def __init__(self, filter_length: int, hop_length: int, win_length: int):
-        super().__init__()
-        self.hop_length = hop_length
-        self.win_length = win_length
-        n_freqs = filter_length // 2 + 1
-        self.register_buffer("forward_basis", torch.zeros(n_freqs * 2, 1, filter_length))
-        self.register_buffer("inverse_basis", torch.zeros(n_freqs * 2, 1, filter_length))
-
-    def forward(self, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
-        """Compute magnitude and phase spectrogram from a batch of waveforms.
-
-        Applies causal (left-only) padding of win_length - hop_length samples so that
-        each output frame depends only on past and present input — no lookahead.
-        The STFT is computed by convolving the padded signal with forward_basis.
-
-        Args:
-            y: Waveform tensor of shape (B, T).
-
-        Returns:
-            magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames).
-            phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames).
-                Computed in float32 for numerical stability, then cast back to
-                the input dtype.
-        """
-        if y.dim() == 2:
-            y = y.unsqueeze(1)  # (B, 1, T)
-        left_pad = max(0, self.win_length - self.hop_length)  # causal: left-only
-        y = F.pad(y, (left_pad, 0))
-        spec = F.conv1d(y, comfy.model_management.cast_to(self.forward_basis, dtype=y.dtype, device=y.device), stride=self.hop_length, padding=0)
-        n_freqs = spec.shape[1] // 2
-        real, imag = spec[:, :n_freqs], spec[:, n_freqs:]
-        magnitude = torch.sqrt(real ** 2 + imag ** 2)
-        phase = torch.atan2(imag.float(), real.float()).to(real.dtype)
-        return magnitude, phase
-
-
-class MelSTFT(nn.Module):
-    """Causal log-mel spectrogram module whose buffers are loaded from the checkpoint.
-
-    Computes a log-mel spectrogram by running the causal STFT (_STFTFn) on the input
-    waveform and projecting the linear magnitude spectrum onto the mel filterbank.
-
-    The module's state dict layout matches the 'mel_stft.*' keys stored in the checkpoint
-    (mel_basis, stft_fn.forward_basis, stft_fn.inverse_basis).
-    """
-
-    def __init__(
-        self,
-        filter_length: int,
-        hop_length: int,
-        win_length: int,
-        n_mel_channels: int,
-        sampling_rate: int,
-        mel_fmin: float,
-        mel_fmax: float,
-    ):
-        super().__init__()
-        self.stft_fn = _STFTFn(filter_length, hop_length, win_length)
-
-        n_freqs = filter_length // 2 + 1
-        self.register_buffer("mel_basis", torch.zeros(n_mel_channels, n_freqs))
-
-    def mel_spectrogram(
-        self, y: torch.Tensor
-    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-        """Compute log-mel spectrogram and auxiliary spectral quantities.
-
-        Args:
-            y: Waveform tensor of shape (B, T).
-
-        Returns:
-            log_mel: Log-compressed mel spectrogram, shape (B, n_mel_channels, T_frames).
-                Computed as log(clamp(mel_basis @ magnitude, min=1e-5)).
-            magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames).
-            phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames).
-            energy: Per-frame energy (L2 norm over frequency), shape (B, T_frames).
-        """
-        magnitude, phase = self.stft_fn(y)
-        energy = torch.norm(magnitude, dim=1)
-        mel = torch.matmul(comfy.model_management.cast_to(self.mel_basis, dtype=magnitude.dtype, device=y.device), magnitude)
-        log_mel = torch.log(torch.clamp(mel, min=1e-5))
-        return log_mel, magnitude, phase, energy
-
-
-class VocoderWithBWE(torch.nn.Module):
-    """Vocoder with bandwidth extension (BWE) for higher sample rate output.
-
-    Chains a base vocoder (mel → low-rate waveform) with a BWE stage that upsamples
-    to a higher rate. The BWE computes a mel spectrogram from the low-rate waveform.
-    """
-
-    def __init__(self, config):
-        super().__init__()
-        vocoder_config = config["vocoder"]
-        bwe_config = config["bwe"]
-
-        self.vocoder = Vocoder(config=vocoder_config)
-        self.bwe_generator = Vocoder(
-            config={**bwe_config, "apply_final_activation": False}
-        )
-
-        self.input_sample_rate = bwe_config["input_sampling_rate"]
-        self.output_sample_rate = bwe_config["output_sampling_rate"]
-        self.hop_length = bwe_config["hop_length"]
-
-        self.mel_stft = MelSTFT(
-            filter_length=bwe_config["n_fft"],
-            hop_length=bwe_config["hop_length"],
-            win_length=bwe_config["n_fft"],
-            n_mel_channels=bwe_config["num_mels"],
-            sampling_rate=bwe_config["input_sampling_rate"],
-            mel_fmin=0.0,
-            mel_fmax=bwe_config["input_sampling_rate"] / 2.0,
-        )
-        self.resampler = UpSample1d(
-            ratio=bwe_config["output_sampling_rate"] // bwe_config["input_sampling_rate"],
-            persistent=False,
-            window_type="hann",
-        )
-
-    def _compute_mel(self, audio):
-        """Compute log-mel spectrogram from waveform using causal STFT bases."""
-        B, C, T = audio.shape
-        flat = audio.reshape(B * C, -1)  # (B*C, T)
-        mel, _, _, _ = self.mel_stft.mel_spectrogram(flat)  # (B*C, n_mels, T_frames)
-        return mel.reshape(B, C, mel.shape[1], mel.shape[2])  # (B, C, n_mels, T_frames)
-
-    def forward(self, mel_spec):
-        x = self.vocoder(mel_spec)
-        _, _, T_low = x.shape
-        T_out = T_low * self.output_sample_rate // self.input_sample_rate
-
-        remainder = T_low % self.hop_length
-        if remainder != 0:
-            x = F.pad(x, (0, self.hop_length - remainder))
-
-        mel = self._compute_mel(x)
-        residual = self.bwe_generator(mel)
-        skip = self.resampler(x)
-        assert residual.shape == skip.shape, f"residual {residual.shape} != skip {skip.shape}"
-
-        return torch.clamp(residual + skip, -1, 1)[..., :T_out]

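The removed comment block above documents how the output sample rate is derived when "output_sample_rate" is missing from a checkpoint config: sample_rate * upsample_factor / mel_hop_length, where upsample_factor is the product of the upsample strides. A sketch of that arithmetic using the default upsample_rates from this diff; the sample rate and hop length below are assumed values for illustration, not numbers stated in the source:

    import numpy as np

    upsample_rates = [5, 4, 2, 2, 2]                 # default strides from the old config
    upsample_factor = int(np.prod(upsample_rates))   # 160

    sample_rate = 16000       # assumption: mel-domain sample rate
    mel_hop_length = 160      # assumption: preprocessing.stft.hop_length

    output_sample_rate = sample_rate * upsample_factor // mel_hop_length
    print(upsample_factor, output_sample_rate)       # 160 16000 under these assumptions
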
@@ -14,7 +14,6 @@ from comfy.ldm.flux.layers import EmbedND
 from comfy.ldm.flux.math import apply_rope
 import comfy.patcher_extension
 import comfy.utils
-from comfy.ldm.chroma_radiance.layers import NerfEmbedder


 def invert_slices(slices, length):
@@ -859,267 +858,3 @@ class NextDiT(nn.Module):
         img = self.unpatchify(img, img_size, cap_size, return_tensor=x_is_tensor)[:, :, :h, :w]
         return -img

-
-#############################################################################
-#                   Pixel Space Decoder Components                          #
-#############################################################################
-
-def _modulate_shift_scale(x, shift, scale):
-    return x * (1 + scale) + shift
-
-
-class PixelResBlock(nn.Module):
-    """
-    Residual block with AdaLN modulation, zero-initialised so it starts as
-    an identity at the beginning of training.
-    """
-
-    def __init__(self, channels: int, dtype=None, device=None, operations=None):
-        super().__init__()
-        self.in_ln = operations.LayerNorm(channels, eps=1e-6, dtype=dtype, device=device)
-        self.mlp = nn.Sequential(
-            operations.Linear(channels, channels, bias=True, dtype=dtype, device=device),
-            nn.SiLU(),
-            operations.Linear(channels, channels, bias=True, dtype=dtype, device=device),
-        )
-        self.adaLN_modulation = nn.Sequential(
-            nn.SiLU(),
-            operations.Linear(channels, 3 * channels, bias=True, dtype=dtype, device=device),
-        )
-
-    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-        shift, scale, gate = self.adaLN_modulation(y).chunk(3, dim=-1)
-        h = _modulate_shift_scale(self.in_ln(x), shift, scale)
-        h = self.mlp(h)
-        return x + gate * h
-
-
-class DCTFinalLayer(nn.Module):
-    """Zero-initialised output projection (adopted from DiT)."""
-
-    def __init__(self, model_channels: int, out_channels: int, dtype=None, device=None, operations=None):
-        super().__init__()
-        self.norm_final = operations.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
-        self.linear = operations.Linear(model_channels, out_channels, bias=True, dtype=dtype, device=device)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.linear(self.norm_final(x))
-
-
-class SimpleMLPAdaLN(nn.Module):
-    """
-    Small MLP decoder head for the pixel-space variant.
-
-    Takes per-patch pixel values and a per-patch conditioning vector from the
-    transformer backbone and predicts the denoised pixel values.
-
-        x : [B*N, P^2, C]  – noisy pixel values per patch position
-        c : [B*N, dim]     – backbone hidden state per patch (conditioning)
-        → [B*N, P^2, C]
-    """
-
-    def __init__(
-        self,
-        in_channels: int,
-        model_channels: int,
-        out_channels: int,
-        z_channels: int,
-        num_res_blocks: int,
-        max_freqs: int = 8,
-        dtype=None,
-        device=None,
-        operations=None,
-    ):
-        super().__init__()
-        self.dtype = dtype
-
-        # Project backbone hidden state → per-patch conditioning
-        self.cond_embed = operations.Linear(z_channels, model_channels, dtype=dtype, device=device)
-
-        # Input projection with DCT positional encoding
-        self.input_embedder = NerfEmbedder(
-            in_channels=in_channels,
-            hidden_size_input=model_channels,
-            max_freqs=max_freqs,
-            dtype=dtype,
-            device=device,
-            operations=operations,
-        )
-
-        # Residual blocks
-        self.res_blocks = nn.ModuleList([
-            PixelResBlock(model_channels, dtype=dtype, device=device, operations=operations) for _ in range(num_res_blocks)
-        ])
-
-        # Output projection
-        self.final_layer = DCTFinalLayer(model_channels, out_channels, dtype=dtype, device=device, operations=operations)
-
-    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
-        # x: [B*N, 1, P^2*C], c: [B*N, dim]
-        original_dtype = x.dtype
-        weight_dtype = self.cond_embed.weight.dtype if hasattr(self.cond_embed, "weight") and self.cond_embed.weight is not None else (self.dtype or x.dtype)
-        x = self.input_embedder(x)  # [B*N, 1, model_channels]
-        y = self.cond_embed(c.to(weight_dtype)).unsqueeze(1)  # [B*N, 1, model_channels]
-        x = x.to(weight_dtype)
-        for block in self.res_blocks:
-            x = block(x, y)
-        return self.final_layer(x).to(original_dtype)  # [B*N, 1, P^2*C]
-
-
-#############################################################################
-#                       NextDiT – Pixel Space                               #
-#############################################################################
-
-class NextDiTPixelSpace(NextDiT):
-    """
-    Pixel-space variant of NextDiT.
-
-    Identical transformer backbone to NextDiT, but the output head is replaced
-    with a small MLP decoder (SimpleMLPAdaLN) that operates on raw pixel values
-    per patch rather than a single affine projection.
-
-    Key differences vs NextDiT:
-      • ``final_layer`` is removed; ``dec_net`` (SimpleMLPAdaLN) is used instead.
-      • ``_forward`` stores the raw patchified pixel values before the backbone
-        embedding and feeds them to ``dec_net`` together with the per-patch
-        backbone hidden states.
-      • Supports optional x0 prediction via ``use_x0``.
-    """
-
-    def __init__(
-        self,
-        # decoder-specific
-        decoder_hidden_size: int = 3840,
-        decoder_num_res_blocks: int = 4,
-        decoder_max_freqs: int = 8,
-        decoder_in_channels: int = None,  # full flattened patch size (patch_size^2 * in_channels)
-        use_x0: bool = False,
-        # all NextDiT args forwarded unchanged
-        **kwargs,
-    ):
-        super().__init__(**kwargs)
-
-        # Remove the latent-space final layer – not used in pixel space
-        del self.final_layer
-
-        patch_size = kwargs.get("patch_size", 2)
-        in_channels = kwargs.get("in_channels", 4)
-        dim = kwargs.get("dim", 4096)
-
-        # decoder_in_channels is the full flattened patch: patch_size^2 * in_channels
-        dec_in_ch = decoder_in_channels if decoder_in_channels is not None else patch_size ** 2 * in_channels
-
-        self.dec_net = SimpleMLPAdaLN(
-            in_channels=dec_in_ch,
-            model_channels=decoder_hidden_size,
-            out_channels=dec_in_ch,
-            z_channels=dim,
-            num_res_blocks=decoder_num_res_blocks,
-            max_freqs=decoder_max_freqs,
-            dtype=kwargs.get("dtype"),
-            device=kwargs.get("device"),
-            operations=kwargs.get("operations"),
-        )
-
-        if use_x0:
-            self.register_buffer("__x0__", torch.tensor([]))
-
-    # ------------------------------------------------------------------
-    # Forward — mirrors NextDiT._forward exactly, replacing final_layer
-    # with the pixel-space dec_net decoder.
-    # ------------------------------------------------------------------
-    def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, ref_latents=[], ref_contexts=[], siglip_feats=[], transformer_options={}, **kwargs):
-        omni = len(ref_latents) > 0
-        if omni:
-            timesteps = torch.cat([timesteps * 0, timesteps], dim=0)
-
-        t = 1.0 - timesteps
-        cap_feats = context
-        cap_mask = attention_mask
-        bs, c, h, w = x.shape
-        x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
-
-        t = self.t_embedder(t * self.time_scale, dtype=x.dtype)
-        adaln_input = t
-
-        if self.clip_text_pooled_proj is not None:
-            pooled = kwargs.get("clip_text_pooled", None)
-            if pooled is not None:
-                pooled = self.clip_text_pooled_proj(pooled)
-            else:
-                pooled = torch.zeros((x.shape[0], self.clip_text_dim), device=x.device, dtype=x.dtype)
-            adaln_input = self.time_text_embed(torch.cat((t, pooled), dim=-1))
-
-        # ---- capture raw pixel patches before patchify_and_embed embeds them ----
-        pH = pW = self.patch_size
-        B, C, H, W = x.shape
-        pixel_patches = (
-            x.view(B, C, H // pH, pH, W // pW, pW)
-            .permute(0, 2, 4, 3, 5, 1)  # [B, Ht, Wt, pH, pW, C]
-            .flatten(3)                 # [B, Ht, Wt, pH*pW*C]
-            .flatten(1, 2)              # [B, N, pH*pW*C]
-        )
-        N = pixel_patches.shape[1]
-        # decoder sees one token per patch: [B*N, 1, P^2*C]
-        pixel_values = pixel_patches.reshape(B * N, 1, pH * pW * C)
-
-        patches = transformer_options.get("patches", {})
-        x_is_tensor = isinstance(x, torch.Tensor)
-        img, mask, img_size, cap_size, freqs_cis, timestep_zero_index = self.patchify_and_embed(
-            x, cap_feats, cap_mask, adaln_input, num_tokens,
-            ref_latents=ref_latents, ref_contexts=ref_contexts,
-            siglip_feats=siglip_feats, transformer_options=transformer_options
-        )
-        freqs_cis = freqs_cis.to(img.device)
-
-        transformer_options["total_blocks"] = len(self.layers)
-        transformer_options["block_type"] = "double"
-        img_input = img
-        for i, layer in enumerate(self.layers):
-            transformer_options["block_index"] = i
-            img = layer(img, mask, freqs_cis, adaln_input, timestep_zero_index=timestep_zero_index, transformer_options=transformer_options)
-            if "double_block" in patches:
-                for p in patches["double_block"]:
-                    out = p({"img": img[:, cap_size[0]:], "img_input": img_input[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options})
-                    if "img" in out:
-                        img[:, cap_size[0]:] = out["img"]
-                    if "txt" in out:
-                        img[:, :cap_size[0]] = out["txt"]
-
-        # ---- pixel-space decoder (replaces final_layer + unpatchify) ----
-        # img may have padding tokens beyond N; only the first N are real image patches
-        img_hidden = img[:, cap_size[0]:cap_size[0] + N, :]  # [B, N, dim]
-        decoder_cond = img_hidden.reshape(B * N, self.dim)   # [B*N, dim]
-
-        output = self.dec_net(pixel_values, decoder_cond)    # [B*N, 1, P^2*C]
-        output = output.reshape(B, N, -1)                    # [B, N, P^2*C]
-
-        # prepend zero cap placeholder so unpatchify indexing works unchanged
-        cap_placeholder = torch.zeros(
-            B, cap_size[0], output.shape[-1], device=output.device, dtype=output.dtype
-        )
-        img_out = self.unpatchify(
-            torch.cat([cap_placeholder, output], dim=1),
-            img_size, cap_size, return_tensor=x_is_tensor
-        )[:, :, :h, :w]
-
-        return -img_out
-
-    def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
-        # _forward returns neg_x0 = -x0 (negated decoder output).
-        #
-        # Reference inference (working_inference_reference.py):
-        #     out = _forward(img, t)            # = -x0
-        #     pred = (img - out) / t            # = (img + x0) / t   [_apply_x0_residual]
-        #     img += (t_prev - t_curr) * pred   # Euler step
-        #
-        # ComfyUI's Euler sampler does the same:
-        #     x_next = x + (sigma_next - sigma) * model_output
-        # So model_output must equal pred = (x - neg_x0) / t = (x - (-x0)) / t = (x + x0) / t
-        neg_x0 = comfy.patcher_extension.WrapperExecutor.new_class_executor(
-            self._forward,
-            self,
-            comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
-        ).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs)
-
-        return (x - neg_x0) / timesteps.view(-1, 1, 1, 1)

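The removed forward() comment above carries the load-bearing sampler math: _forward returns neg_x0 = -x0, and the value handed to ComfyUI's Euler sampler must be pred = (x - neg_x0) / t = (x + x0) / t. A self-contained check of that identity with plain tensors and no model, as a sketch:

    import torch

    t = torch.tensor(0.8)        # current flow timestep
    t_prev = torch.tensor(0.6)   # next timestep in the schedule
    x = torch.randn(4)           # current sample
    x0 = torch.randn(4)          # hypothetical clean prediction

    neg_x0 = -x0                 # what _forward returns
    pred = (x - neg_x0) / t      # converted model output, = (x + x0) / t

    euler_next = x + (t_prev - t) * pred             # ComfyUI-style Euler update
    reference = x + (t_prev - t) * (x + x0) / t      # reference-inference update
    assert torch.allclose(euler_next, reference)
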
@@ -18,8 +18,6 @@ import comfy.patcher_extension
 import comfy.ops
 ops = comfy.ops.disable_weight_init

-from ..sdpose import HeatmapHead
-
 class TimestepBlock(nn.Module):
     """
     Any module where forward() takes timestep embeddings as a second argument.
@@ -443,7 +441,6 @@ class UNetModel(nn.Module):
         disable_temporal_crossattention=False,
         max_ddpm_temb_period=10000,
         attn_precision=None,
-        heatmap_head=False,
         device=None,
         operations=ops,
     ):
@@ -830,9 +827,6 @@ class UNetModel(nn.Module):
             #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
         )

-        if heatmap_head:
-            self.heatmap_head = HeatmapHead(device=device, dtype=self.dtype, operations=operations)
-
    def forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
        return comfy.patcher_extension.WrapperExecutor.new_class_executor(
            self._forward,
@@ -1,130 +0,0 @@
-import torch
-import numpy as np
-from scipy.ndimage import gaussian_filter
-
-class HeatmapHead(torch.nn.Module):
-    def __init__(
-        self,
-        in_channels=640,
-        out_channels=133,
-        input_size=(768, 1024),
-        heatmap_scale=4,
-        deconv_out_channels=(640,),
-        deconv_kernel_sizes=(4,),
-        conv_out_channels=(640,),
-        conv_kernel_sizes=(1,),
-        final_layer_kernel_size=1,
-        device=None, dtype=None, operations=None
-    ):
-        super().__init__()
-
-        self.heatmap_size = (input_size[0] // heatmap_scale, input_size[1] // heatmap_scale)
-        self.scale_factor = ((np.array(input_size) - 1) / (np.array(self.heatmap_size) - 1)).astype(np.float32)
-
-        # Deconv layers
-        if deconv_out_channels:
-            deconv_layers = []
-            for out_ch, kernel_size in zip(deconv_out_channels, deconv_kernel_sizes):
-                if kernel_size == 4:
-                    padding, output_padding = 1, 0
-                elif kernel_size == 3:
-                    padding, output_padding = 1, 1
-                elif kernel_size == 2:
-                    padding, output_padding = 0, 0
-                else:
-                    raise ValueError(f'Unsupported kernel size {kernel_size}')
-
-                deconv_layers.extend([
-                    operations.ConvTranspose2d(in_channels, out_ch, kernel_size,
-                        stride=2, padding=padding, output_padding=output_padding, bias=False, device=device, dtype=dtype),
-                    torch.nn.InstanceNorm2d(out_ch, device=device, dtype=dtype),
-                    torch.nn.SiLU(inplace=True)
-                ])
-                in_channels = out_ch
-            self.deconv_layers = torch.nn.Sequential(*deconv_layers)
-        else:
-            self.deconv_layers = torch.nn.Identity()
-
-        # Conv layers
-        if conv_out_channels:
-            conv_layers = []
-            for out_ch, kernel_size in zip(conv_out_channels, conv_kernel_sizes):
-                padding = (kernel_size - 1) // 2
-                conv_layers.extend([
-                    operations.Conv2d(in_channels, out_ch, kernel_size,
-                        stride=1, padding=padding, device=device, dtype=dtype),
-                    torch.nn.InstanceNorm2d(out_ch, device=device, dtype=dtype),
-                    torch.nn.SiLU(inplace=True)
-                ])
-                in_channels = out_ch
-            self.conv_layers = torch.nn.Sequential(*conv_layers)
-        else:
-            self.conv_layers = torch.nn.Identity()
-
-        self.final_layer = operations.Conv2d(in_channels, out_channels, kernel_size=final_layer_kernel_size, padding=final_layer_kernel_size // 2, device=device, dtype=dtype)
-
-    def forward(self, x):  # Decode heatmaps to keypoints
-        heatmaps = self.final_layer(self.conv_layers(self.deconv_layers(x)))
-        heatmaps_np = heatmaps.float().cpu().numpy()  # (B, K, H, W)
-        B, K, H, W = heatmaps_np.shape
-
-        batch_keypoints = []
-        batch_scores = []
-
-        for b in range(B):
-            hm = heatmaps_np[b].copy()  # (K, H, W)
-
-            # --- vectorised argmax ---
-            flat = hm.reshape(K, -1)
-            idx = np.argmax(flat, axis=1)
-            scores = flat[np.arange(K), idx].copy()
-            y_locs, x_locs = np.unravel_index(idx, (H, W))
-            keypoints = np.stack([x_locs, y_locs], axis=-1).astype(np.float32)  # (K, 2) in heatmap space
-            invalid = scores <= 0.
-            keypoints[invalid] = -1
-
-            # --- DARK sub-pixel refinement (UDP) ---
-            # 1. Gaussian blur with max-preserving normalisation
-            border = 5  # (kernel-1)//2 for kernel=11
-            for k in range(K):
-                origin_max = np.max(hm[k])
-                dr = np.zeros((H + 2 * border, W + 2 * border), dtype=np.float32)
-                dr[border:-border, border:-border] = hm[k].copy()
-                dr = gaussian_filter(dr, sigma=2.0)
-                hm[k] = dr[border:-border, border:-border].copy()
-                cur_max = np.max(hm[k])
-                if cur_max > 0:
-                    hm[k] *= origin_max / cur_max
-            # 2. Log-space for Taylor expansion
-            np.clip(hm, 1e-3, 50., hm)
-            np.log(hm, hm)
-            # 3. Hessian-based Newton step
-            hm_pad = np.pad(hm, ((0, 0), (1, 1), (1, 1)), mode='edge').flatten()
-            index = keypoints[:, 0] + 1 + (keypoints[:, 1] + 1) * (W + 2)
-            index += (W + 2) * (H + 2) * np.arange(0, K)
-            index = index.astype(int).reshape(-1, 1)
-            i_ = hm_pad[index]
-            ix1 = hm_pad[index + 1]
-            iy1 = hm_pad[index + W + 2]
-            ix1y1 = hm_pad[index + W + 3]
-            ix1_y1_ = hm_pad[index - W - 3]
-            ix1_ = hm_pad[index - 1]
-            iy1_ = hm_pad[index - 2 - W]
-            dx = 0.5 * (ix1 - ix1_)
-            dy = 0.5 * (iy1 - iy1_)
-            derivative = np.concatenate([dx, dy], axis=1).reshape(K, 2, 1)
-            dxx = ix1 - 2 * i_ + ix1_
-            dyy = iy1 - 2 * i_ + iy1_
-            dxy = 0.5 * (ix1y1 - ix1 - iy1 + i_ + i_ - ix1_ - iy1_ + ix1_y1_)
-            hessian = np.concatenate([dxx, dxy, dxy, dyy], axis=1).reshape(K, 2, 2)
-            hessian = np.linalg.inv(hessian + np.finfo(np.float32).eps * np.eye(2))
-            keypoints -= np.einsum('imn,ink->imk', hessian, derivative).squeeze(axis=-1)
-
-            # --- restore to input image space ---
-            keypoints = keypoints * self.scale_factor
-            keypoints[invalid] = -1
-
-            batch_keypoints.append(keypoints)
-            batch_scores.append(scores)
-
-        return batch_keypoints, batch_scores
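The deleted forward() above starts keypoint decoding with a vectorised per-channel argmax before the DARK/UDP refinement. That first stage in isolation, as a sketch over a random heatmap rather than the deleted module:

    import numpy as np

    K, H, W = 17, 192, 256                      # hypothetical keypoint count and heatmap size
    hm = np.random.rand(K, H, W).astype(np.float32)

    flat = hm.reshape(K, -1)                    # (K, H*W)
    idx = np.argmax(flat, axis=1)               # peak location per keypoint channel
    scores = flat[np.arange(K), idx]            # peak value doubles as confidence
    y_locs, x_locs = np.unravel_index(idx, (H, W))
    keypoints = np.stack([x_locs, y_locs], axis=-1).astype(np.float32)  # (K, 2), heatmap space
    keypoints[scores <= 0.0] = -1               # flag channels with no response
    print(keypoints.shape, scores.shape)        # (17, 2) (17,)
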
@@ -1621,118 +1621,3 @@ class HumoWanModel(WanModel):
         # unpatchify
         x = self.unpatchify(x, grid_sizes)
         return x
-
-class SCAILWanModel(WanModel):
-    def __init__(self, model_type="scail", patch_size=(1, 2, 2), in_dim=20, dim=5120, operations=None, device=None, dtype=None, **kwargs):
-        super().__init__(model_type='i2v', patch_size=patch_size, in_dim=in_dim, dim=dim, operations=operations, device=device, dtype=dtype, **kwargs)
-
-        self.patch_embedding_pose = operations.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=torch.float32)
-
-    def forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, pose_latents=None, reference_latent=None, **kwargs):
-
-        if reference_latent is not None:
-            x = torch.cat((reference_latent, x), dim=2)
-
-        # embeddings
-        x = self.patch_embedding(x.float()).to(x.dtype)
-        grid_sizes = x.shape[2:]
-        transformer_options["grid_sizes"] = grid_sizes
-        x = x.flatten(2).transpose(1, 2)
-
-        scail_pose_seq_len = 0
-        if pose_latents is not None:
-            scail_x = self.patch_embedding_pose(pose_latents.float()).to(x.dtype)
-            scail_x = scail_x.flatten(2).transpose(1, 2)
-            scail_pose_seq_len = scail_x.shape[1]
-            x = torch.cat([x, scail_x], dim=1)
-            del scail_x
-
-        # time embeddings
-        e = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype))
-        e = e.reshape(t.shape[0], -1, e.shape[-1])
-        e0 = self.time_projection(e).unflatten(2, (6, self.dim))
-
-        # context
-        context = self.text_embedding(context)
-
-        context_img_len = None
-        if clip_fea is not None:
-            if self.img_emb is not None:
-                context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
-                context = torch.cat([context_clip, context], dim=1)
-            context_img_len = clip_fea.shape[-2]
-
-        patches_replace = transformer_options.get("patches_replace", {})
-        blocks_replace = patches_replace.get("dit", {})
-        transformer_options["total_blocks"] = len(self.blocks)
-        transformer_options["block_type"] = "double"
-        for i, block in enumerate(self.blocks):
-            transformer_options["block_index"] = i
-            if ("double_block", i) in blocks_replace:
-                def block_wrap(args):
-                    out = {}
-                    out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"])
-                    return out
-                out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
-                x = out["img"]
-            else:
-                x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
-
-        # head
-        x = self.head(x, e)
-
-        if scail_pose_seq_len > 0:
-            x = x[:, :-scail_pose_seq_len]
-
-        # unpatchify
-        x = self.unpatchify(x, grid_sizes)
-
-        if reference_latent is not None:
-            x = x[:, :, reference_latent.shape[2]:]
-
-        return x
-
-    def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, pose_latents=None, reference_latent=None, transformer_options={}):
-        main_freqs = super().rope_encode(t, h, w, t_start=t_start, steps_t=steps_t, steps_h=steps_h, steps_w=steps_w, device=device, dtype=dtype, transformer_options=transformer_options)
-
-        if pose_latents is None:
-            return main_freqs
-
-        ref_t_patches = 0
-        if reference_latent is not None:
-            ref_t_patches = (reference_latent.shape[2] + (self.patch_size[0] // 2)) // self.patch_size[0]
-
-        F_pose, H_pose, W_pose = pose_latents.shape[-3], pose_latents.shape[-2], pose_latents.shape[-1]
-
-        # if pose is at half resolution, scale_y/scale_x=2 stretches the position range to cover the same RoPE extent as the main frames
-        h_scale = h / H_pose
-        w_scale = w / W_pose
-
-        # 120 w-offset and shift 0.5 to place positions at midpoints (0.5, 2.5, ...) to match the original code
-        h_shift = (h_scale - 1) / 2
-        w_shift = (w_scale - 1) / 2
-        pose_transformer_options = {"rope_options": {"shift_y": h_shift, "shift_x": 120.0 + w_shift, "scale_y": h_scale, "scale_x": w_scale}}
-        pose_freqs = super().rope_encode(F_pose, H_pose, W_pose, t_start=t_start+ref_t_patches, device=device, dtype=dtype, transformer_options=pose_transformer_options)
-
-        return torch.cat([main_freqs, pose_freqs], dim=1)
-
-    def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, pose_latents=None, **kwargs):
-        bs, c, t, h, w = x.shape
-        x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
-
-        if pose_latents is not None:
-            pose_latents = comfy.ldm.common_dit.pad_to_patch_size(pose_latents, self.patch_size)
-
-        t_len = t
-        if time_dim_concat is not None:
-            time_dim_concat = comfy.ldm.common_dit.pad_to_patch_size(time_dim_concat, self.patch_size)
-            x = torch.cat([x, time_dim_concat], dim=2)
-            t_len = x.shape[2]
-
-        reference_latent = None
-        if "reference_latent" in kwargs:
-            reference_latent = comfy.ldm.common_dit.pad_to_patch_size(kwargs.pop("reference_latent"), self.patch_size)
-            t_len += reference_latent.shape[2]
-
-        freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent)
-        return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent, **kwargs)[:, :, :t, :h, :w]

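The deleted rope_encode above maps half-resolution pose tokens onto the full-resolution RoPE grid: positions are scaled by h/H_pose and w/W_pose, shifted to patch midpoints, and given a fixed 120.0 offset on the width axis. A sketch of the position mapping alone, assuming rope_options applies positions as index * scale + shift (the constants match the deleted code; the grid sizes are made up):

    import torch

    h, w = 64, 64                    # main latent grid (in patches)
    H_pose, W_pose = 32, 32          # pose grid at half resolution

    h_scale, w_scale = h / H_pose, w / W_pose                 # 2.0, 2.0
    h_shift, w_shift = (h_scale - 1) / 2, (w_scale - 1) / 2   # midpoint alignment

    pose_y = torch.arange(H_pose) * h_scale + h_shift           # 0.5, 2.5, 4.5, ...
    pose_x = torch.arange(W_pose) * w_scale + w_shift + 120.0   # width offset from the deleted code
    print(pose_y[:3].tolist(), pose_x[:3].tolist())
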
@@ -459,7 +459,6 @@ class WanVAE(nn.Module):
|
|||||||
attn_scales=[],
|
attn_scales=[],
|
||||||
temperal_downsample=[True, True, False],
|
temperal_downsample=[True, True, False],
|
||||||
image_channels=3,
|
image_channels=3,
|
||||||
conv_out_channels=3,
|
|
||||||
dropout=0.0):
|
dropout=0.0):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
@@ -475,7 +474,7 @@ class WanVAE(nn.Module):
|
|||||||
attn_scales, self.temperal_downsample, dropout)
|
attn_scales, self.temperal_downsample, dropout)
|
||||||
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
|
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
|
||||||
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
|
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
|
||||||
self.decoder = Decoder3d(dim, z_dim, conv_out_channels, dim_mult, num_res_blocks,
|
self.decoder = Decoder3d(dim, z_dim, image_channels, dim_mult, num_res_blocks,
|
||||||
attn_scales, self.temperal_upsample, dropout)
|
attn_scales, self.temperal_upsample, dropout)
|
||||||
|
|
||||||
def encode(self, x):
|
def encode(self, x):
|
||||||
@@ -485,7 +484,7 @@ class WanVAE(nn.Module):
|
|||||||
iter_ = 1 + (t - 1) // 4
|
iter_ = 1 + (t - 1) // 4
|
||||||
feat_map = None
|
feat_map = None
|
||||||
if iter_ > 1:
|
if iter_ > 1:
|
||||||
feat_map = [None] * count_conv3d(self.encoder)
|
feat_map = [None] * count_conv3d(self.decoder)
|
||||||
## 对encode输入的x,按时间拆分为1、4、4、4....
|
## 对encode输入的x,按时间拆分为1、4、4、4....
|
||||||
for i in range(iter_):
|
for i in range(iter_):
|
||||||
conv_idx = [0]
|
conv_idx = [0]
|
||||||
|
|||||||
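Note: `iter_ = 1 + (t - 1) // 4` in the hunk above encodes the 1, 4, 4, 4, ... temporal split described in the translated comment. A standalone sanity check, assuming the usual t = 4*k + 1 frame count for this causal video VAE:

    def wan_vae_chunk_sizes(t):
        # First chunk is a single frame; each later chunk covers 4 frames.
        sizes = [1]
        remaining = t - 1
        while remaining > 0:
            sizes.append(min(4, remaining))
            remaining -= 4
        return sizes

    assert wan_vae_chunk_sizes(17) == [1, 4, 4, 4, 4]
    assert len(wan_vae_chunk_sizes(17)) == 1 + (17 - 1) // 4  # matches iter_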
@@ -337,7 +337,6 @@ def model_lora_keys_unet(model, key_map={}):
         if k.startswith("diffusion_model.decoder.") and k.endswith(".weight"):
             key_lora = k[len("diffusion_model.decoder."):-len(".weight")]
             key_map["base_model.model.{}".format(key_lora)] = k # Official base model loras
-            key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k # LyCORIS/LoKR format

     return key_map

@@ -78,4 +78,4 @@ def interpret_gathered_like(tensors, gathered):

     return dest_views

-aimdo_enabled = False
+aimdo_allocator = None
@@ -76,7 +76,6 @@ class ModelType(Enum):
     FLUX = 8
     IMG_TO_IMG = 9
     FLOW_COSMOS = 10
-    IMG_TO_IMG_FLOW = 11


 def model_sampling(model_config, model_type):
@@ -109,8 +108,6 @@ def model_sampling(model_config, model_type):
     elif model_type == ModelType.FLOW_COSMOS:
         c = comfy.model_sampling.COSMOS_RFLOW
         s = comfy.model_sampling.ModelSamplingCosmosRFlow
-    elif model_type == ModelType.IMG_TO_IMG_FLOW:
-        c = comfy.model_sampling.IMG_TO_IMG_FLOW

     class ModelSampling(s, c):
         pass
@@ -181,7 +178,10 @@ class BaseModel(torch.nn.Module):
             xc = torch.cat([xc] + [comfy.model_management.cast_to_device(c_concat, xc.device, xc.dtype)], dim=1)

         context = c_crossattn
-        dtype = self.get_dtype_inference()
+        dtype = self.get_dtype()
+
+        if self.manual_cast_dtype is not None:
+            dtype = self.manual_cast_dtype

         xc = xc.to(dtype)
         device = xc.device
@@ -218,13 +218,6 @@ class BaseModel(torch.nn.Module):
     def get_dtype(self):
         return self.diffusion_model.dtype

-    def get_dtype_inference(self):
-        dtype = self.get_dtype()
-
-        if self.manual_cast_dtype is not None:
-            dtype = self.manual_cast_dtype
-        return dtype
-
     def encode_adm(self, **kwargs):
         return None

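Note: the helper removed above only folds manual_cast_dtype over the module dtype, which is why each call site in this diff re-inlines the same two lines. As a sketch, it is equivalent to the one-liner below (assuming manual_cast_dtype is either None or a torch.dtype; torch dtypes are always truthy):

    def get_dtype_inference(self):
        # None falls through to the model's native dtype; otherwise use the
        # dtype the weights are cast to at inference time.
        return self.manual_cast_dtype or self.get_dtype()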
@@ -379,7 +372,9 @@ class BaseModel(torch.nn.Module):
             input_shapes += shape

         if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
-            dtype = self.get_dtype_inference()
+            dtype = self.get_dtype()
+            if self.manual_cast_dtype is not None:
+                dtype = self.manual_cast_dtype
             #TODO: this needs to be tweaked
             area = sum(map(lambda input_shape: input_shape[0] * math.prod(input_shape[2:]), input_shapes))
             return (area * comfy.model_management.dtype_size(dtype) * 0.01 * self.memory_usage_factor) * (1024 * 1024)
@@ -925,25 +920,6 @@ class Flux(BaseModel):
             out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))])
         return out

-class LongCatImage(Flux):
-    def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
-        transformer_options = transformer_options.copy()
-        rope_opts = transformer_options.get("rope_options", {})
-        rope_opts = dict(rope_opts)
-        rope_opts.setdefault("shift_t", 1.0)
-        rope_opts.setdefault("shift_y", 512.0)
-        rope_opts.setdefault("shift_x", 512.0)
-        transformer_options["rope_options"] = rope_opts
-        return super()._apply_model(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs)
-
-    def encode_adm(self, **kwargs):
-        return None
-
-    def extra_conds(self, **kwargs):
-        out = super().extra_conds(**kwargs)
-        out.pop('guidance', None)
-        return out
-
 class Flux2(Flux):
     def extra_conds(self, **kwargs):
         out = super().extra_conds(**kwargs)
@@ -993,10 +969,6 @@ class LTXV(BaseModel):
         if keyframe_idxs is not None:
             out['keyframe_idxs'] = comfy.conds.CONDRegular(keyframe_idxs)

-        guide_attention_entries = kwargs.get("guide_attention_entries", None)
-        if guide_attention_entries is not None:
-            out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries)
-
         return out

     def process_timestep(self, timestep, x, denoise_mask=None, **kwargs):
@@ -1014,14 +986,10 @@ class LTXAV(BaseModel):
     def extra_conds(self, **kwargs):
         out = super().extra_conds(**kwargs)
         attention_mask = kwargs.get("attention_mask", None)
-        device = kwargs["device"]
-
         if attention_mask is not None:
             out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
         cross_attn = kwargs.get("cross_attn", None)
         if cross_attn is not None:
-            if hasattr(self.diffusion_model, "preprocess_text_embeds"):
-                cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference()), unprocessed=kwargs.get("unprocessed_ltxav_embeds", False))
             out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)

         out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25))
@@ -1049,10 +1017,6 @@ class LTXAV(BaseModel):
         if latent_shapes is not None:
             out['latent_shapes'] = comfy.conds.CONDConstant(latent_shapes)

-        guide_attention_entries = kwargs.get("guide_attention_entries", None)
-        if guide_attention_entries is not None:
-            out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries)
-
         return out

     def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs):
@@ -1201,7 +1165,7 @@ class Anima(BaseModel):
             t5xxl_ids = t5xxl_ids.unsqueeze(0)

             if torch.is_inference_mode_enabled(): # if not we are training
-                cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference()), t5xxl_ids.to(device=device), t5xxl_weights=t5xxl_weights.to(device=device, dtype=self.get_dtype_inference()))
+                cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype()), t5xxl_ids.to(device=device), t5xxl_weights=t5xxl_weights.to(device=device, dtype=self.get_dtype()))
             else:
                 out['t5xxl_ids'] = comfy.conds.CONDRegular(t5xxl_ids)
                 out['t5xxl_weights'] = comfy.conds.CONDRegular(t5xxl_weights)
@@ -1263,11 +1227,6 @@ class Lumina2(BaseModel):
             out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))])
         return out

-class ZImagePixelSpace(Lumina2):
-    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
-        BaseModel.__init__(self, model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiTPixelSpace)
-        self.memory_usage_factor_conds = ("ref_latents",)
-
 class WAN21(BaseModel):
     def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
@@ -1501,50 +1460,6 @@ class WAN22(WAN21):
     def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
         return latent_image

-class WAN21_FlowRVS(WAN21):
-    def __init__(self, model_config, model_type=ModelType.IMG_TO_IMG_FLOW, image_to_video=False, device=None):
-        model_config.unet_config["model_type"] = "t2v"
-        super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
-        self.image_to_video = image_to_video
-
-class WAN21_SCAIL(WAN21):
-    def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
-        super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.SCAILWanModel)
-        self.memory_usage_factor_conds = ("reference_latent", "pose_latents")
-        self.memory_usage_shape_process = {"pose_latents": lambda shape: [shape[0], shape[1], 1.5, shape[-2], shape[-1]]}
-        self.image_to_video = image_to_video
-
-    def extra_conds(self, **kwargs):
-        out = super().extra_conds(**kwargs)
-
-        reference_latents = kwargs.get("reference_latents", None)
-        if reference_latents is not None:
-            ref_latent = self.process_latent_in(reference_latents[-1])
-            ref_mask = torch.ones_like(ref_latent[:, :4])
-            ref_latent = torch.cat([ref_latent, ref_mask], dim=1)
-            out['reference_latent'] = comfy.conds.CONDRegular(ref_latent)
-
-        pose_latents = kwargs.get("pose_video_latent", None)
-        if pose_latents is not None:
-            pose_latents = self.process_latent_in(pose_latents)
-            pose_mask = torch.ones_like(pose_latents[:, :4])
-            pose_latents = torch.cat([pose_latents, pose_mask], dim=1)
-            out['pose_latents'] = comfy.conds.CONDRegular(pose_latents)
-
-        return out
-
-    def extra_conds_shapes(self, **kwargs):
-        out = {}
-        ref_latents = kwargs.get("reference_latents", None)
-        if ref_latents is not None:
-            out['reference_latent'] = list([1, 20, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
-
-        pose_latents = kwargs.get("pose_video_latent", None)
-        if pose_latents is not None:
-            out['pose_latents'] = [pose_latents.shape[0], 20, *pose_latents.shape[2:]]
-
-        return out
-
 class Hunyuan3Dv2(BaseModel):
     def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2)
@@ -279,8 +279,6 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         dit_config["txt_norm"] = any_suffix_in(state_dict_keys, key_prefix, 'txt_norm.', ["weight", "scale"])
         if dit_config["yak_mlp"] and dit_config["txt_norm"]: # Ovis model
             dit_config["txt_ids_dims"] = [1, 2]
-        if dit_config.get("context_in_dim") == 3584 and dit_config["vec_in_dim"] is None: # LongCat-Image
-            dit_config["txt_ids_dims"] = [1, 2]

         return dit_config

@@ -423,7 +421,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
             dit_config["extra_per_block_abs_pos_emb_type"] = "learnable"
         return dit_config

-    if '{}cap_embedder.1.weight'.format(key_prefix) in state_dict_keys and '{}noise_refiner.0.attention.k_norm.weight'.format(key_prefix) in state_dict_keys: # Lumina 2
+    if '{}cap_embedder.1.weight'.format(key_prefix) in state_dict_keys: # Lumina 2
         dit_config = {}
         dit_config["image_model"] = "lumina2"
         dit_config["patch_size"] = 2
@@ -464,29 +462,6 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         if sig_weight is not None:
             dit_config["siglip_feat_dim"] = sig_weight.shape[0]

-        dec_cond_key = '{}dec_net.cond_embed.weight'.format(key_prefix)
-        if dec_cond_key in state_dict_keys: # pixel-space variant
-            dit_config["image_model"] = "zimage_pixel"
-            # patch_size and in_channels are derived from x_embedder:
-            # x_embedder: Linear(patch_size * patch_size * in_channels, dim)
-            # The decoder also receives the full flat patch, so decoder_in_channels = x_embedder input dim.
-            x_emb_in = state_dict['{}x_embedder.weight'.format(key_prefix)].shape[1]
-            dec_out = state_dict['{}dec_net.final_layer.linear.weight'.format(key_prefix)].shape[0]
-            # patch_size: infer from decoder final layer output matching x_embedder input
-            # in_channels: infer from dec_net input_embedder (in_features = dec_in_ch + max_freqs^2)
-            embedder_w = state_dict['{}dec_net.input_embedder.embedder.0.weight'.format(key_prefix)]
-            dec_in_ch = dec_out # decoder in == decoder out (same pixel space)
-            dit_config["patch_size"] = round((x_emb_in / 3) ** 0.5) # assume RGB (in_channels=3)
-            dit_config["in_channels"] = 3
-            dit_config["decoder_in_channels"] = dec_in_ch
-            dit_config["decoder_hidden_size"] = state_dict[dec_cond_key].shape[0]
-            dit_config["decoder_num_res_blocks"] = count_blocks(
-                state_dict_keys, '{}dec_net.res_blocks.'.format(key_prefix) + '{}.'
-            )
-            dit_config["decoder_max_freqs"] = int((embedder_w.shape[1] - dec_in_ch) ** 0.5)
-            if '{}__x0__'.format(key_prefix) in state_dict_keys:
-                dit_config["use_x0"] = True
-
         return dit_config

     if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1
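Note: the removed pixel-space detection block recovers the decoder config from tensor shapes alone. The key inversion: x_embedder is a Linear over a flattened patch, so with RGB input its in_features equals patch_size**2 * 3. A standalone sketch with hypothetical dimensions:

    import torch

    # Hypothetical x_embedder for patch_size=2, RGB, hidden dim 64:
    # Linear(in_features=patch_size * patch_size * 3, out_features=dim)
    x_embedder = torch.nn.Linear(2 * 2 * 3, 64)

    x_emb_in = x_embedder.weight.shape[1]      # 12; weight is [out_features, in_features]
    patch_size = round((x_emb_in / 3) ** 0.5)  # sqrt(12 / 3) = 2, assuming in_channels == 3
    assert patch_size == 2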
@@ -521,8 +496,6 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
             dit_config["model_type"] = "humo"
         elif '{}face_adapter.fuser_blocks.0.k_norm.weight'.format(key_prefix) in state_dict_keys:
             dit_config["model_type"] = "animate"
-        elif '{}patch_embedding_pose.weight'.format(key_prefix) in state_dict_keys:
-            dit_config["model_type"] = "scail"
         else:
             if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
                 dit_config["model_type"] = "i2v"
@@ -536,9 +509,6 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         if ref_conv_weight is not None:
             dit_config["in_dim_ref_conv"] = ref_conv_weight.shape[1]

-        if metadata is not None and "config" in metadata:
-            dit_config.update(json.loads(metadata["config"]).get("transformer", {}))
-
         return dit_config

     if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D
@@ -556,7 +526,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
         return dit_config

-    if f"{key_prefix}t_embedder.mlp.2.weight" in state_dict_keys and f"{key_prefix}blocks.0.attn1.k_norm.weight" in state_dict_keys: # Hunyuan 3D 2.1
+    if f"{key_prefix}t_embedder.mlp.2.weight" in state_dict_keys: # Hunyuan 3D 2.1
+
         dit_config = {}
         dit_config["image_model"] = "hunyuan3d2_1"
         dit_config["in_channels"] = state_dict[f"{key_prefix}x_embedder.weight"].shape[1]
@@ -821,10 +792,6 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
     unet_config["use_temporal_resblock"] = False
     unet_config["use_temporal_attention"] = False

-    heatmap_key = '{}heatmap_head.conv_layers.0.weight'.format(key_prefix)
-    if heatmap_key in state_dict_keys:
-        unet_config["heatmap_head"] = True
-
     return unet_config

 def model_config_from_unet_config(unet_config, state_dict=None):
@@ -1045,7 +1012,7 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):

     LotusD = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': 4,
               'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0],
-              'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, 'num_head_channels': 64,
+              'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, 'num_heads': 8,
               'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
               'use_temporal_attention': False, 'use_temporal_resblock': False}

@@ -1077,13 +1044,6 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""):
     elif 'adaln_single.emb.timestep_embedder.linear_1.bias' in state_dict and 'pos_embed.proj.bias' in state_dict: # PixArt
         num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
         sd_map = comfy.utils.pixart_to_diffusers({"depth": num_blocks}, output_prefix=output_prefix)
-    elif 'noise_refiner.0.attention.norm_k.weight' in state_dict:
-        n_layers = count_blocks(state_dict, 'layers.{}.')
-        dim = state_dict['noise_refiner.0.attention.to_k.weight'].shape[0]
-        sd_map = comfy.utils.z_image_to_diffusers({"n_layers": n_layers, "dim": dim}, output_prefix=output_prefix)
-        for k in state_dict: # For zeta chroma
-            if k not in sd_map:
-                sd_map[k] = k
     elif 'x_embedder.weight' in state_dict: #Flux
         depth = count_blocks(state_dict, 'transformer_blocks.{}.')
         depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.')
@@ -32,6 +32,9 @@ import comfy.memory_management
 import comfy.utils
 import comfy.quant_ops

+import comfy_aimdo.torch
+import comfy_aimdo.model_vbar
+
 class VRAMState(Enum):
     DISABLED = 0    #No vram present: no need to move models to vram
     NO_VRAM = 1     #Very low vram: enable all the options to save vram
@@ -177,14 +180,6 @@ def is_ixuca():
         return True
     return False

-def is_wsl():
-    version = platform.uname().release
-    if version.endswith("-Microsoft"):
-        return True
-    elif version.endswith("microsoft-standard-WSL2"):
-        return True
-    return False
-
 def get_torch_device():
     global directml_enabled
     global cpu_state
@@ -355,7 +350,7 @@ AMD_ENABLE_MIOPEN_ENV = 'COMFYUI_ENABLE_MIOPEN'

 try:
     if is_amd():
-        arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName.split(':')[0]
+        arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
         if not (any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH)):
             if os.getenv(AMD_ENABLE_MIOPEN_ENV) != '1':
                 torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD
@@ -383,7 +378,7 @@ try:
         if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
             if aotriton_supported(arch): # AMD efficient attention implementation depends on aotriton.
                 if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
-                    if any((a in arch) for a in ["gfx90a", "gfx942", "gfx950", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
+                    if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
                         ENABLE_PYTORCH_ATTENTION = True
                 if rocm_version >= (7, 0):
                     if any((a in arch) for a in ["gfx1200", "gfx1201"]):
@@ -636,6 +631,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
     if not DISABLE_SMART_MEMORY:
         memory_to_free = memory_required - get_free_memory(device)
         ram_to_free = ram_required - get_free_ram()
+
         if current_loaded_models[i].model.is_dynamic() and for_dynamic:
             #don't actually unload dynamic models for the sake of other dynamic models
             #as that works on-demand.
@@ -796,8 +792,6 @@ def archive_model_dtypes(model):
     for name, module in model.named_modules():
         for param_name, param in module.named_parameters(recurse=False):
             setattr(module, f"{param_name}_comfy_model_dtype", param.dtype)
-        for buf_name, buf in module.named_buffers(recurse=False):
-            setattr(module, f"{buf_name}_comfy_model_dtype", buf.dtype)


 def cleanup_models():
@@ -830,14 +824,11 @@ def unet_offload_device():
     return torch.device("cpu")

 def unet_inital_load_device(parameters, dtype):
-    cpu_dev = torch.device("cpu")
-    if comfy.memory_management.aimdo_enabled:
-        return cpu_dev
-
     torch_dev = get_torch_device()
     if vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.SHARED:
         return torch_dev

+    cpu_dev = torch.device("cpu")
     if DISABLE_SMART_MEMORY or vram_state == VRAMState.NO_VRAM:
         return cpu_dev

@@ -845,7 +836,7 @@ def unet_inital_load_device(parameters, dtype):

     mem_dev = get_free_memory(torch_dev)
     mem_cpu = get_free_memory(cpu_dev)
-    if mem_dev > mem_cpu and model_size < mem_dev:
+    if mem_dev > mem_cpu and model_size < mem_dev and comfy.memory_management.aimdo_allocator is None:
         return torch_dev
     else:
         return cpu_dev
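Note: after these two hunks the aimdo check moves from an early CPU return into the final free-memory comparison. The resulting decision, as a simplified standalone sketch (names are illustrative, not the real function):

    def pick_initial_device(mem_dev, mem_cpu, model_size, aimdo_allocator, high_vram=False):
        if high_vram:
            return "gpu"
        # Load straight to the GPU only when it has more free memory than the
        # host, the whole model fits, and no paging allocator manages VRAM.
        if mem_dev > mem_cpu and model_size < mem_dev and aimdo_allocator is None:
            return "gpu"
        return "cpu"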
@@ -948,9 +939,6 @@ def text_encoder_device():
     return torch.device("cpu")

 def text_encoder_initial_device(load_device, offload_device, model_size=0):
-    if comfy.memory_management.aimdo_enabled:
-        return offload_device
-
     if load_device == offload_device or model_size <= 1024 * 1024 * 1024:
         return offload_device

@@ -1133,6 +1121,7 @@ def get_cast_buffer(offload_stream, device, size, ref):
             synchronize()
             del STREAM_CAST_BUFFERS[offload_stream]
             del cast_buffer
+            #FIXME: This doesn't work in Aimdo because mempool can't clear cache
             soft_empty_cache()
         with wf_context:
             cast_buffer = torch.empty((size), dtype=torch.int8, device=device)
@@ -1211,6 +1200,43 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None):


 def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None):
+    if hasattr(weight, "_v"):
+        #Unexpected usage patterns. There is no reason these don't work but they
+        #have no testing and no callers do this.
+        assert r is None
+        assert stream is None
+
+        cast_geometry = comfy.memory_management.tensors_to_geometries([ weight ])
+
+        if dtype is None:
+            dtype = weight._model_dtype
+
+        signature = comfy_aimdo.model_vbar.vbar_fault(weight._v)
+        if signature is not None:
+            if comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature):
+                v_tensor = weight._v_tensor
+            else:
+                raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device)
+                v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, raw_tensor)[0]
+                weight._v_tensor = v_tensor
+                weight._v_signature = signature
+            #Send it over
+            v_tensor.copy_(weight, non_blocking=non_blocking)
+            return v_tensor.to(dtype=dtype)
+
+        r = torch.empty_like(weight, dtype=dtype, device=device)
+
+        if weight.dtype != r.dtype and weight.dtype != weight._model_dtype:
+            #Offloaded casting could skip this, however it would make the quantizations
+            #inconsistent between loaded and offloaded weights. So force the double casting
+            #that would happen in regular flow to make offload deterministic.
+            cast_buffer = torch.empty_like(weight, dtype=weight._model_dtype, device=device)
+            cast_buffer.copy_(weight, non_blocking=non_blocking)
+            weight = cast_buffer
+        r.copy_(weight, non_blocking=non_blocking)
+
+        return r
+
     if device is None or weight.device == device:
         if not copy:
             if dtype is None or weight.dtype == dtype:
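Note: the fast path added to cast_to above is a validate-or-rebuild cache: fault the virtual buffer in, compare its signature with the one recorded on the tensor, and rebuild the device view only when the backing allocation has moved. The generic shape of the pattern (illustrative helper, not the comfy_aimdo API):

    def cached_device_view(weight, fault, build_view):
        # `fault` pages the allocation in and returns an identity token for it;
        # `build_view` wraps the raw allocation in a tensor view.
        signature = fault(weight)
        if signature != getattr(weight, "_cached_signature", None):
            weight._cached_view = build_view(weight)  # storage moved: rebuild
            weight._cached_signature = signature
        return weight._cached_view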
@@ -1666,16 +1692,12 @@ def lora_compute_dtype(device):
     return dtype

 def synchronize():
-    if cpu_mode():
-        return
     if is_intel_xpu():
         torch.xpu.synchronize()
     elif torch.cuda.is_available():
         torch.cuda.synchronize()

 def soft_empty_cache(force=False):
-    if cpu_mode():
-        return
     global cpu_state
     if cpu_state == CPUState.MPS:
         torch.mps.empty_cache()
@@ -241,7 +241,6 @@ class ModelPatcher:

         self.patches = {}
         self.backup = {}
-        self.backup_buffers = {}
         self.object_patches = {}
         self.object_patches_backup = {}
         self.weight_wrapper_patches = {}
@@ -272,7 +271,6 @@ class ModelPatcher:
         self.is_clip = False
         self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed

-        self.cached_patcher_init: tuple[Callable, tuple] | None = None
         if not hasattr(self.model, 'model_loaded_weight_memory'):
             self.model.model_loaded_weight_memory = 0

@@ -307,30 +305,10 @@ class ModelPatcher:
         return self.model.lowvram_patch_counter

     def get_free_memory(self, device):
-        #Prioritize batching (incl. CFG/conds etc) over keeping the model resident. In
-        #the vast majority of setups a little bit of offloading on the giant model more
-        #than pays for CFG. So return everything both torch and Aimdo could give us
-        aimdo_mem = 0
-        if comfy.memory_management.aimdo_enabled:
-            aimdo_mem = comfy_aimdo.model_vbar.vbars_analyze()
-        return comfy.model_management.get_free_memory(device) + aimdo_mem
+        return comfy.model_management.get_free_memory(device)

-    def get_clone_model_override(self):
-        return self.model, (self.backup, self.backup_buffers, self.object_patches_backup, self.pinned)
-
-    def clone(self, disable_dynamic=False, model_override=None):
-        class_ = self.__class__
-        if self.is_dynamic() and disable_dynamic:
-            class_ = ModelPatcher
-        if model_override is None:
-            if self.cached_patcher_init is None:
-                raise RuntimeError("Cannot create non-dynamic delegate: cached_patcher_init is not initialized.")
-            temp_model_patcher = self.cached_patcher_init[0](*self.cached_patcher_init[1], disable_dynamic=True)
-            model_override = temp_model_patcher.get_clone_model_override()
-        if model_override is None:
-            model_override = self.get_clone_model_override()
-
-        n = class_(model_override[0], self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update)
+    def clone(self):
+        n = self.__class__(self.model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update)
         n.patches = {}
         for k in self.patches:
             n.patches[k] = self.patches[k][:]
@@ -339,12 +317,13 @@ class ModelPatcher:
         n.object_patches = self.object_patches.copy()
         n.weight_wrapper_patches = self.weight_wrapper_patches.copy()
         n.model_options = comfy.utils.deepcopy_list_dict(self.model_options)
+        n.backup = self.backup
+        n.object_patches_backup = self.object_patches_backup
         n.parent = self
+        n.pinned = self.pinned

         n.force_cast_weights = self.force_cast_weights
-
-        n.backup, n.backup_buffers, n.object_patches_backup, n.pinned = model_override[1]

         # attachments
         n.attachments = {}
         for k in self.attachments:
@@ -383,8 +362,6 @@ class ModelPatcher:
         n.is_clip = self.is_clip
         n.hook_mode = self.hook_mode

-        n.cached_patcher_init = self.cached_patcher_init
-
         for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE):
             callback(self, n)
         return n
@@ -429,16 +406,13 @@ class ModelPatcher:
     def memory_required(self, input_shape):
         return self.model.memory_required(input_shape=input_shape)

-    def disable_model_cfg1_optimization(self):
-        self.model_options["disable_cfg1_optimization"] = True
-
     def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False):
         if len(inspect.signature(sampler_cfg_function).parameters) == 3:
             self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
         else:
             self.model_options["sampler_cfg_function"] = sampler_cfg_function
         if disable_cfg1_optimization:
-            self.disable_model_cfg1_optimization()
+            self.model_options["disable_cfg1_optimization"] = True

     def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False):
         self.model_options = set_model_options_post_cfg_function(self.model_options, post_cfg_function, disable_cfg1_optimization)
@@ -705,7 +679,7 @@ class ModelPatcher:
         for key in list(self.pinned):
             self.unpin_weight(key)

-    def _load_list(self, for_dynamic=False, default_device=None):
+    def _load_list(self, prio_comfy_cast_weights=False, default_device=None):
         loading = []
         for n, m in self.model.named_modules():
             default = False
@@ -733,13 +707,8 @@ class ModelPatcher:
                 return 0
             module_offload_mem += check_module_offload_mem("{}.weight".format(n))
             module_offload_mem += check_module_offload_mem("{}.bias".format(n))
-            # Dynamic: small weights (<64KB) first, then larger weights prioritized by size.
-            # Non-dynamic: prioritize by module offload cost.
-            if for_dynamic:
-                sort_criteria = (module_offload_mem >= 64 * 1024, -module_offload_mem)
-            else:
-                sort_criteria = (module_offload_mem,)
-            loading.append(sort_criteria + (module_mem, n, m, params))
+            prepend = (not hasattr(m, "comfy_cast_weights"),) if prio_comfy_cast_weights else ()
+            loading.append(prepend + (module_offload_mem, module_mem, n, m, params))
         return loading

     def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
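Note: both versions of _load_list order modules by building sort-key tuples and letting Python compare them lexicographically. The removed dynamic ordering (module_offload_mem >= 64 * 1024, -module_offload_mem) groups all sub-64KB weights first and sorts each group largest-first. A tiny standalone demonstration:

    sizes = [128, 70_000, 1_000_000, 4_096]
    ordered = sorted(sizes, key=lambda s: (s >= 64 * 1024, -s))
    print(ordered)  # [4096, 128, 1000000, 70000]: small group first, largest-first within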
@@ -1447,9 +1416,12 @@ class ModelPatcherDynamic(ModelPatcher):

     def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
         super().__init__(model, load_device, offload_device, size, weight_inplace_update)
+        #this is now way more dynamic and we don't support the same base model for both Dynamic
+        #and non-dynamic patchers.
+        if hasattr(self.model, "model_loaded_weight_memory"):
+            del self.model.model_loaded_weight_memory
         if not hasattr(self.model, "dynamic_vbars"):
             self.model.dynamic_vbars = {}
-        self.non_dynamic_delegate_model = None
         assert load_device is not None

     def is_dynamic(self):
@@ -1469,7 +1441,15 @@ class ModelPatcherDynamic(ModelPatcher):

     def loaded_size(self):
         vbar = self._vbar_get()
-        return (vbar.loaded_size() if vbar is not None else 0) + self.model.model_loaded_weight_memory
+        if vbar is None:
+            return 0
+        return vbar.loaded_size()
+
+    def get_free_memory(self, device):
+        #NOTE: on high condition / batch counts, estimate should have already vacated
+        #all non-dynamic models so this is safe even if it's not 100% true that this
+        #would all be available for inference use.
+        return comfy.model_management.get_total_memory(device) - self.model_size()

     #Pinning is deferred to ops time. Assert against this API to avoid pin leaks.
@@ -1504,7 +1484,6 @@ class ModelPatcherDynamic(ModelPatcher):

         num_patches = 0
         allocated_size = 0
-        self.model.model_loaded_weight_memory = 0

         with self.use_ejected():
             self.unpatch_hooks()
@@ -1513,11 +1492,15 @@ class ModelPatcherDynamic(ModelPatcher):
             if vbar is not None:
                 vbar.prioritize()

-            loading = self._load_list(for_dynamic=True, default_device=device_to)
-            loading.sort()
+            #We force reserve VRAM for the non comfy-weight so we don't have to deal
+            #with pin and unpin synchronization which can be expensive for small weights
+            #with a high layer rate (e.g. autoregressive LLMs).
+            #prioritize the non-comfy weights (note the order reverse).
+            loading = self._load_list(prio_comfy_cast_weights=True, default_device=device_to)
+            loading.sort(reverse=True)

             for x in loading:
-                *_, module_mem, n, m, params = x
+                _, _, _, n, m, params = x

                 def set_dirty(item, dirty):
                     if dirty or not hasattr(item, "_v_signature"):
@@ -1555,9 +1538,6 @@ class ModelPatcherDynamic(ModelPatcher):
                         if key in self.backup:
                             comfy.utils.set_attr_param(self.model, key, self.backup[key].weight)
                         self.patch_weight_to_device(key, device_to=device_to)
-                        weight, _, _ = get_key_weight(self.model, key)
-                        if weight is not None:
-                            self.model.model_loaded_weight_memory += weight.numel() * weight.element_size()

                 if hasattr(m, "comfy_cast_weights"):
                     m.comfy_cast_weights = True
@@ -1583,26 +1563,21 @@ class ModelPatcherDynamic(ModelPatcher):
                 for param in params:
                     key = key_param_name_to_key(n, param)
                     weight, _, _ = get_key_weight(self.model, key)
-                    if key not in self.backup:
-                        self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight, False)
-                    model_dtype = getattr(m, param + "_comfy_model_dtype", None)
-                    casted_weight = weight.to(dtype=model_dtype, device=device_to)
-                    comfy.utils.set_attr_param(self.model, key, casted_weight)
-                    self.model.model_loaded_weight_memory += casted_weight.numel() * casted_weight.element_size()
+                    weight.seed_key = key
+                    set_dirty(weight, dirty)
+                    geometry = weight
+                    model_dtype = getattr(m, param + "_comfy_model_dtype", None) or weight.dtype
+                    geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
+                    weight_size = geometry.numel() * geometry.element_size()
+                    if vbar is not None and not hasattr(weight, "_v"):
+                        weight._v = vbar.alloc(weight_size)
+                        weight._model_dtype = model_dtype
+                        allocated_size += weight_size
+                        vbar.set_watermark_limit(allocated_size)

                 move_weight_functions(m, device_to)

-            for key, buf in self.model.named_buffers(recurse=True):
-                if key not in self.backup_buffers:
-                    self.backup_buffers[key] = buf
-                module, buf_name = comfy.utils.resolve_attr(self.model, key)
-                model_dtype = getattr(module, buf_name + "_comfy_model_dtype", None)
-                casted_buf = buf.to(dtype=model_dtype, device=device_to)
-                comfy.utils.set_attr_buffer(self.model, key, casted_buf)
-                self.model.model_loaded_weight_memory += casted_buf.numel() * casted_buf.element_size()
-
-            force_load_stat = f" Force pre-loaded {len(self.backup)} weights: {self.model.model_loaded_weight_memory // 1024} KB." if len(self.backup) > 0 else ""
-            logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.{force_load_stat}")
+            logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.")

             self.model.device = device_to
             self.model.current_weight_patches_uuid = self.patches_uuid
@@ -1618,23 +1593,12 @@ class ModelPatcherDynamic(ModelPatcher):
         assert self.load_device != torch.device("cpu")

         vbar = self._vbar_get()
-        freed = 0 if vbar is None else vbar.free_memory(memory_to_free)
-
-        if freed < memory_to_free:
-            for key in list(self.backup.keys()):
-                bk = self.backup.pop(key)
-                comfy.utils.set_attr_param(self.model, key, bk.weight)
-            for key in list(self.backup_buffers.keys()):
-                comfy.utils.set_attr_buffer(self.model, key, self.backup_buffers.pop(key))
-            freed += self.model.model_loaded_weight_memory
-            self.model.model_loaded_weight_memory = 0
-
-        return freed
+        return 0 if vbar is None else vbar.free_memory(memory_to_free)

     def partially_unload_ram(self, ram_to_unload):
-        loading = self._load_list(for_dynamic=True, default_device=self.offload_device)
+        loading = self._load_list(prio_comfy_cast_weights=True, default_device=self.offload_device)
         for x in loading:
-            *_, m, _ = x
+            _, _, _, _, m, _ = x
             ram_to_unload -= comfy.pinned_memory.unpin_memory(m)
             if ram_to_unload <= 0:
                 return
@@ -1656,6 +1620,11 @@ class ModelPatcherDynamic(ModelPatcher):
         for m in self.model.modules():
             move_weight_functions(m, device_to)

+        keys = list(self.backup.keys())
+        for k in keys:
+            bk = self.backup[k]
+            comfy.utils.set_attr_param(self.model, k, bk.weight)
+
     def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
         assert not force_patch_weights #See above
         with self.use_ejected(skip_and_inject_on_exit_only=True):
@@ -1687,10 +1656,4 @@ class ModelPatcherDynamic(ModelPatcher):
     def unpatch_hooks(self, whitelist_keys_set: set[str]=None) -> None:
         pass

-    def get_non_dynamic_delegate(self):
-        model_patcher = self.clone(disable_dynamic=True, model_override=self.non_dynamic_delegate_model)
-        self.non_dynamic_delegate_model = model_patcher.get_clone_model_override()
-        return model_patcher
-
-
 CoreModelPatcher = ModelPatcher
@@ -83,16 +83,6 @@ class IMG_TO_IMG(X0):
     def calculate_input(self, sigma, noise):
         return noise

-class IMG_TO_IMG_FLOW(CONST):
-    def calculate_denoised(self, sigma, model_output, model_input):
-        return model_output
-
-    def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
-        return latent_image
-
-    def inverse_noise_scaling(self, sigma, latent):
-        return 1.0 - latent
-
 class COSMOS_RFLOW:
     def calculate_input(self, sigma, noise):
         sigma = (sigma / (sigma + 1))
79
comfy/ops.py
@@ -19,8 +19,9 @@
 import torch
 import logging
 import comfy.model_management
-from comfy.cli_args import args, PerformanceFeature
+from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram
 import comfy.float
+import comfy.rmsnorm
 import json
 import comfy.memory_management
 import comfy.pinned_memory
@@ -79,22 +80,7 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
     return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)


-def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant):
-
-    #vbar doesn't support CPU weights, but some custom nodes have weird paths
-    #that might switch the layer to the CPU and expect it to work. We have to take
-    #a clone conservatively as we are mmapped and some SFT files are packed misaligned
-    #If you are a custom node author reading this, please move your layer to the GPU
-    #or declare your ModelPatcher as CPU in the first place.
-    if comfy.model_management.is_device_cpu(device):
-        weight = s.weight.to(dtype=dtype, copy=True)
-        if isinstance(weight, QuantizedTensor):
-            weight = weight.dequantize()
-        bias = None
-        if s.bias is not None:
-            bias = s.bias.to(dtype=bias_dtype, copy=True)
-        return weight, bias, (None, None, None)
-
+def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype):
     offload_stream = None
     xfer_dest = None

@@ -182,15 +168,17 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant):
                 x = to_dequant(x, dtype)
             if not resident and lowvram_fn is not None:
                 x = to_dequant(x, dtype if compute_dtype is None else compute_dtype)
+                #FIXME: this is not accurate, we need to be sensitive to the compute dtype
                 x = lowvram_fn(x)
-            if (want_requant and len(fns) == 0 or update_weight):
+            if (isinstance(orig, QuantizedTensor) and
+                (orig.dtype == dtype and len(fns) == 0 or update_weight)):
                 seed = comfy.utils.string_to_seed(s.seed_key)
-                if isinstance(orig, QuantizedTensor):
-                    y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed)
-                else:
-                    y = comfy.float.stochastic_rounding(x, orig.dtype, seed=seed)
-                if want_requant and len(fns) == 0:
-                    x = y
+                y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed)
+                if orig.dtype == dtype and len(fns) == 0:
+                    #The layer actually wants our freshly saved QT
+                    x = y
+            elif update_weight:
+                y = comfy.float.stochastic_rounding(x, orig.dtype, seed = comfy.utils.string_to_seed(s.seed_key))
             if update_weight:
                 orig.copy_(y)
             for f in fns:
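Note: the requant path above seeds stochastic rounding from the weight's key name via string_to_seed(s.seed_key), so re-quantizing the same layer is reproducible. A hand-rolled sketch of key-seeded stochastic rounding to bf16 (the real helpers live in comfy.utils and comfy.float; everything below is an illustrative stand-in):

    import hashlib
    import torch

    def string_to_seed(key: str) -> int:
        # Stable per-layer seed derived from the parameter name.
        return int.from_bytes(hashlib.sha256(key.encode()).digest()[:8], "little")

    def stochastic_round_to_bf16(x: torch.Tensor, seed: int) -> torch.Tensor:
        # Bit trick: add uniform noise to the 16 mantissa bits that bf16 drops,
        # then truncate them. Same seed -> bit-identical result.
        g = torch.Generator(device=x.device).manual_seed(seed)
        bits = x.float().contiguous().view(torch.int32)
        noise = torch.randint(0, 1 << 16, bits.shape, generator=g,
                              dtype=torch.int32, device=x.device)
        return ((bits + noise) & ~0xFFFF).view(torch.float32).to(torch.bfloat16)

    w = torch.randn(4, 4)
    key = "blocks.0.attn.qkv.weight"
    assert torch.equal(stochastic_round_to_bf16(w, string_to_seed(key)),
                       stochastic_round_to_bf16(w, string_to_seed(key)))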
@@ -207,7 +195,7 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant):
     return weight, bias, (offload_stream, device if signature is not None else None, None)


-def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None, want_requant=False):
+def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None):
     # NOTE: offloadable=False is a legacy and if you are a custom node author reading this please pass
     # offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This
     # will add async-offload support to your cast and improve performance.
@@ -225,7 +213,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
|
|||||||
non_blocking = comfy.model_management.device_supports_non_blocking(device)
|
non_blocking = comfy.model_management.device_supports_non_blocking(device)
|
||||||
|
|
||||||
if hasattr(s, "_v"):
|
if hasattr(s, "_v"):
|
||||||
return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant)
|
return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype)
|
||||||
|
|
||||||
if offloadable and (device != s.weight.device or
|
if offloadable and (device != s.weight.device or
|
||||||
(s.bias is not None and device != s.bias.device)):
|
(s.bias is not None and device != s.bias.device)):
|
||||||
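For custom node authors, the contract in the NOTE above looks roughly like the following; the layer class here is hypothetical, only cast_bias_weight/uncast_bias_weight and the offloadable flag come from this diff:

import torch
import comfy.ops

class MyPatchedLinear(comfy.ops.disable_weight_init.Linear):  # hypothetical op
    def forward_comfy_cast_weights(self, input):
        # offloadable=True may hand back an offload stream; returning the
        # weight/bias via uncast_bias_weight after the last use is what
        # enables async offload.
        weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True)
        x = torch.nn.functional.linear(input, weight, bias)
        comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
        return x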
@@ -284,8 +272,8 @@ def uncast_bias_weight(s, weight, bias, offload_stream):
         return
     os, weight_a, bias_a = offload_stream
     device=None
-    #FIXME: This is really bad RTTI
-    if weight_a is not None and not isinstance(weight_a, torch.Tensor):
+    #FIXME: This is not good RTTI
+    if not isinstance(weight_a, torch.Tensor):
         comfy_aimdo.model_vbar.vbar_unpin(s._v)
         device = weight_a
     if os is None:
@@ -309,7 +297,7 @@ class disable_weight_init:
     class Linear(torch.nn.Linear, CastWeightBiasOp):
 
         def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
-            if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled:
+            if not comfy.model_management.WINDOWS or not enables_dynamic_vram():
                 super().__init__(in_features, out_features, bias, device, dtype)
                 return
 
@@ -330,7 +318,7 @@ class disable_weight_init:
         def _load_from_state_dict(self, state_dict, prefix, local_metadata,
                                   strict, missing_keys, unexpected_keys, error_msgs):
 
-            if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled:
+            if not comfy.model_management.WINDOWS or not enables_dynamic_vram():
                 return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                                      missing_keys, unexpected_keys, error_msgs)
             assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)
@@ -475,7 +463,7 @@ class disable_weight_init:
         else:
             return super().forward(*args, **kwargs)
 
-    class RMSNorm(torch.nn.RMSNorm, CastWeightBiasOp):
+    class RMSNorm(comfy.rmsnorm.RMSNorm, CastWeightBiasOp):
         def reset_parameters(self):
             self.bias = None
             return None
@@ -487,7 +475,8 @@ class disable_weight_init:
             weight = None
             bias = None
             offload_stream = None
-            x = torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
+            x = comfy.rmsnorm.rms_norm(input, weight, self.eps)  # TODO: switch to commented out line when old torch is deprecated
+            # x = torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
             uncast_bias_weight(self, weight, bias, offload_stream)
             return x
 
@@ -630,8 +619,7 @@ def fp8_linear(self, input):
 
     if input.ndim != 2:
         return None
-    lora_compute_dtype=comfy.model_management.lora_compute_dtype(input.device)
-    w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True, compute_dtype=lora_compute_dtype, want_requant=True)
+    w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
     scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
 
     scale_input = torch.ones((), device=input.device, dtype=torch.float32)
@@ -675,30 +663,24 @@ class fp8_ops(manual_cast):
 
 CUBLAS_IS_AVAILABLE = False
 try:
-    from cublas_ops import CublasLinear, cublas_half_matmul
+    from cublas_ops import CublasLinear
     CUBLAS_IS_AVAILABLE = True
 except ImportError:
     pass
 
 if CUBLAS_IS_AVAILABLE:
-    class cublas_ops(manual_cast):
-        class Linear(CublasLinear, manual_cast.Linear):
+    class cublas_ops(disable_weight_init):
+        class Linear(CublasLinear, disable_weight_init.Linear):
             def reset_parameters(self):
                 return None
 
             def forward_comfy_cast_weights(self, input):
-                weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
-                x = cublas_half_matmul(input, weight, bias, self._epilogue_str, self.has_bias)
-                uncast_bias_weight(self, weight, bias, offload_stream)
-                return x
+                return super().forward(input)
 
             def forward(self, *args, **kwargs):
-                run_every_op()
-                if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
-                    return self.forward_comfy_cast_weights(*args, **kwargs)
-                else:
                     return super().forward(*args, **kwargs)
 
 
 # ==============================================================================
 # Mixed Precision Operations
 # ==============================================================================
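Side note on the mixin order in class Linear(CublasLinear, disable_weight_init.Linear): super().forward(input) inside forward_comfy_cast_weights reaches the cublas kernel because Python's MRO walks the bases left to right. A toy model of that layout; the names are stand-ins, not the real classes:

class KernelLinear:               # plays the role of CublasLinear
    def forward(self, x):
        return ("kernel", x)

class CastLinear:                 # plays the role of disable_weight_init.Linear
    def forward(self, x):
        return ("cast", x)

class Linear(KernelLinear, CastLinear):
    def forward_comfy_cast_weights(self, x):
        return super().forward(x)  # -> KernelLinear.forward, per the MRO

assert Linear().forward_comfy_cast_weights(3) == ("kernel", 3)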
@@ -847,10 +829,6 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
             else:
                 sd = {}
 
-            if not hasattr(self, 'weight'):
-                logging.warning("Warning: state dict on uninitialized op {}".format(prefix))
-                return sd
-
             if self.bias is not None:
                 sd["{}bias".format(prefix)] = self.bias
 
@@ -874,8 +852,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
         def _forward(self, input, weight, bias):
             return torch.nn.functional.linear(input, weight, bias)
 
-        def forward_comfy_cast_weights(self, input, compute_dtype=None, want_requant=False):
-            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True, compute_dtype=compute_dtype, want_requant=want_requant)
+        def forward_comfy_cast_weights(self, input, compute_dtype=None):
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True, compute_dtype=compute_dtype)
             x = self._forward(input, weight, bias)
             uncast_bias_weight(self, weight, bias, offload_stream)
             return x
@@ -905,7 +883,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
                 scale = comfy.model_management.cast_to_device(scale, input.device, None)
                 input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale)
 
-            output = self.forward_comfy_cast_weights(input, compute_dtype, want_requant=isinstance(input, QuantizedTensor))
+            output = self.forward_comfy_cast_weights(input, compute_dtype)
 
             # Reshape output back to 3D if input was 3D
             if reshaped_3d:
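The hunk above quantizes the activations with a precomputed scale before handing them to the same cast path as the weights. A toy round trip showing what the scale buys, with plain int8 standing in for QuantizedTensor.from_float:

import torch

def quantize_int8(x):
    scale = x.abs().amax() / 127.0          # per-tensor symmetric scale
    q = torch.clamp(torch.round(x / scale), -127, 127).to(torch.int8)
    return q, scale

def dequantize_int8(q, scale):
    return q.to(torch.float32) * scale

x = torch.randn(4, 8)
q, scale = quantize_int8(x)
max_err = (dequantize_int8(q, scale) - x).abs().max()  # bounded by ~scale / 2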
@@ -1,10 +1,57 @@
 import torch
 import comfy.model_management
+import numbers
+import logging
+
+RMSNorm = None
+
+try:
+    rms_norm_torch = torch.nn.functional.rms_norm
+    RMSNorm = torch.nn.RMSNorm
+except:
+    rms_norm_torch = None
+    logging.warning("Please update pytorch to use native RMSNorm")
 
-RMSNorm = torch.nn.RMSNorm
 
 def rms_norm(x, weight=None, eps=1e-6):
+    if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()):
         if weight is None:
-            return torch.nn.functional.rms_norm(x, (x.shape[-1],), eps=eps)
+            return rms_norm_torch(x, (x.shape[-1],), eps=eps)
         else:
-            return torch.nn.functional.rms_norm(x, weight.shape, weight=comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
+            return rms_norm_torch(x, weight.shape, weight=comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
+    else:
+        r = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
+        if weight is None:
+            return r
+        else:
+            return r * comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device)
+
+
+if RMSNorm is None:
+    class RMSNorm(torch.nn.Module):
+        def __init__(
+            self,
+            normalized_shape,
+            eps=1e-6,
+            elementwise_affine=True,
+            device=None,
+            dtype=None,
+        ):
+            factory_kwargs = {"device": device, "dtype": dtype}
+            super().__init__()
+            if isinstance(normalized_shape, numbers.Integral):
+                # mypy error: incompatible types in assignment
+                normalized_shape = (normalized_shape,)  # type: ignore[assignment]
+            self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
+            self.eps = eps
+            self.elementwise_affine = elementwise_affine
+            if self.elementwise_affine:
+                self.weight = torch.nn.Parameter(
+                    torch.empty(self.normalized_shape, **factory_kwargs)
+                )
+            else:
+                self.register_parameter("weight", None)
+            self.bias = None
+
+        def forward(self, x):
+            return rms_norm(x, self.weight, self.eps)
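A quick self-check that the manual fallback in rms_norm matches the fused op, on a torch build that has torch.nn.functional.rms_norm (inputs arbitrary):

import torch

x = torch.randn(2, 16)
w = torch.randn(16)
eps = 1e-6

fused = torch.nn.functional.rms_norm(x, w.shape, weight=w, eps=eps)
manual = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps) * w
assert torch.allclose(fused, manual, atol=1e-5)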
@@ -66,18 +66,6 @@ def convert_cond(cond):
         out.append(temp)
     return out
 
-def cond_has_hooks(cond):
-    for c in cond:
-        temp = c[1]
-        if "hooks" in temp:
-            return True
-        if "control" in temp:
-            control = temp["control"]
-            extra_hooks = control.get_extra_hooks()
-            if len(extra_hooks) > 0:
-                return True
-    return False
-
 def get_additional_models(conds, dtype):
     """loads additional models in conditioning"""
     cnets: list[ControlBase] = []
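The removed cond_has_hooks walks the conditioning layout that convert_cond produces: a list of [tensor, options-dict] pairs. An illustrative sample; the tensors and hook object here are fabricated:

import torch

cond = [
    [torch.zeros(1, 77, 768), {"pooled_output": torch.zeros(1, 768)}],
    [torch.zeros(1, 77, 768), {"hooks": object()}],  # any "hooks" key would trip the check
]
has_hooks = any("hooks" in c[1] for c in cond)  # True for this sample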
@@ -946,8 +946,6 @@ class CFGGuider:
 
     def inner_set_conds(self, conds):
         for k in conds:
-            if self.model_patcher.is_dynamic() and comfy.sampler_helpers.cond_has_hooks(conds[k]):
-                self.model_patcher = self.model_patcher.get_non_dynamic_delegate()
             self.original_conds[k] = comfy.sampler_helpers.convert_cond(conds[k])
 
     def __call__(self, *args, **kwargs):

comfy/sd.py
@@ -60,7 +60,6 @@ import comfy.text_encoders.jina_clip_2
 import comfy.text_encoders.newbie
 import comfy.text_encoders.anima
 import comfy.text_encoders.ace15
-import comfy.text_encoders.longcat_image
 
 import comfy.model_patcher
 import comfy.lora
@@ -204,7 +203,7 @@ def load_bypass_lora_for_models(model, clip, lora, strength_model, strength_clip
 
 
 class CLIP:
-    def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, state_dict=[], model_options={}, disable_dynamic=False):
+    def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, state_dict=[], model_options={}):
         if no_init:
             return
         params = target.params.copy()
@@ -233,8 +232,7 @@ class CLIP:
         model_management.archive_model_dtypes(self.cond_stage_model)
 
         self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
-        ModelPatcher = comfy.model_patcher.ModelPatcher if disable_dynamic else comfy.model_patcher.CoreModelPatcher
-        self.patcher = ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
+        self.patcher = comfy.model_patcher.CoreModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
         #Match torch.float32 hardcode upcast in TE implemention
         self.patcher.set_model_compute_dtype(torch.float32)
         self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
@@ -268,9 +266,9 @@ class CLIP:
         logging.info("CLIP/text encoder model load device: {}, offload device: {}, current: {}, dtype: {}".format(load_device, offload_device, params['device'], dtype))
         self.tokenizer_options = {}
 
-    def clone(self, disable_dynamic=False):
+    def clone(self):
         n = CLIP(no_init=True)
-        n.patcher = self.patcher.clone(disable_dynamic=disable_dynamic)
+        n.patcher = self.patcher.clone()
         n.cond_stage_model = self.cond_stage_model
         n.tokenizer = self.tokenizer
         n.layer_idx = self.layer_idx
@@ -425,17 +423,6 @@ class CLIP:
     def get_key_patches(self):
         return self.patcher.get_key_patches()
 
-    def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None):
-        self.cond_stage_model.reset_clip_options()
-
-        self.load_model(tokens)
-        self.cond_stage_model.set_clip_options({"layer": None})
-        self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
-        return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed)
-
-    def decode(self, token_ids, skip_special_tokens=True):
-        return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
-
 class VAE:
     def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None):
         if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
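For context, the removed CLIP.generate/decode pair was driven roughly like this, with clip being a loaded comfy.sd.CLIP; the prompt and sampling values are arbitrary and clip.tokenize is assumed from the surrounding API:

tokens = clip.tokenize("a photo of a cat")
token_ids = clip.generate(tokens, do_sample=True, max_length=256,
                          temperature=1.0, top_k=50, top_p=0.95, seed=42)
text = clip.decode(token_ids, skip_special_tokens=True)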
@@ -696,9 +683,8 @@ class VAE:
             self.latent_dim = 3
             self.latent_channels = 16
             self.output_channels = sd["encoder.conv1.weight"].shape[1]
-            self.conv_out_channels = sd["decoder.head.2.weight"].shape[0]
             self.pad_channel_value = 1.0
-            ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "image_channels": self.output_channels, "conv_out_channels": self.conv_out_channels, "dropout": 0.0}
+            ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "image_channels": self.output_channels, "dropout": 0.0}
             self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig)
             self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
             self.memory_used_encode = lambda shape, dtype: (1500 if shape[2]<=4 else 6000) * shape[3] * shape[4] * model_management.dtype_size(dtype)
@@ -1162,24 +1148,16 @@ class CLIPType(Enum):
     KANDINSKY5_IMAGE = 23
     NEWBIE = 24
     FLUX2 = 25
-    LONGCAT_IMAGE = 26
 
 
-def load_clip_model_patcher(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}, disable_dynamic=False):
-    clip = load_clip(ckpt_paths, embedding_directory, clip_type, model_options, disable_dynamic)
-    return clip.patcher
-
-def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}, disable_dynamic=False):
+def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
     clip_data = []
     for p in ckpt_paths:
         sd, metadata = comfy.utils.load_torch_file(p, safe_load=True, return_metadata=True)
         if model_options.get("custom_operations", None) is None:
             sd, metadata = comfy.utils.convert_old_quants(sd, model_prefix="", metadata=metadata)
         clip_data.append(sd)
-    clip = load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options, disable_dynamic=disable_dynamic)
-    clip.patcher.cached_patcher_init = (load_clip_model_patcher, (ckpt_paths, embedding_directory, clip_type, model_options))
-    return clip
+    return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options)
 
 
 class TEModel(Enum):
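The cached_patcher_init tuples removed above store a factory plus its arguments, so an equivalent patcher can be rebuilt from disk later instead of being kept alive. The shape of the pattern, reduced to a toy; make_patcher is hypothetical:

def make_patcher(path, options):      # stands in for load_clip_model_patcher
    return {"path": path, "options": options}

patcher = make_patcher("model.safetensors", {})
patcher_init = (make_patcher, ("model.safetensors", {}))

# later, after the heavy object has been dropped:
fn, args = patcher_init
rebuilt = fn(*args)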
@@ -1204,7 +1182,6 @@ class TEModel(Enum):
     JINA_CLIP_2 = 19
     QWEN3_8B = 20
     QWEN3_06B = 21
-    GEMMA_3_4B_VISION = 22
 
 
 def detect_te_model(sd):
@@ -1233,9 +1210,6 @@ def detect_te_model(sd):
     if 'model.layers.47.self_attn.q_norm.weight' in sd:
         return TEModel.GEMMA_3_12B
     if 'model.layers.0.self_attn.q_norm.weight' in sd:
-        if 'vision_model.embeddings.patch_embedding.weight' in sd:
-            return TEModel.GEMMA_3_4B_VISION
-        else:
         return TEModel.GEMMA_3_4B
     return TEModel.GEMMA_2_2B
     if 'model.layers.0.self_attn.k_proj.bias' in sd:
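detect_te_model dispatches purely on which parameter names exist in the checkpoint. A trimmed illustration using the keys from the hunk above; the sd dict here is fabricated:

sd = {"model.layers.0.self_attn.q_norm.weight": None}

if "model.layers.47.self_attn.q_norm.weight" in sd:
    te = "GEMMA_3_12B"    # a 48th layer exists -> the 12B variant
elif "model.layers.0.self_attn.q_norm.weight" in sd:
    te = "GEMMA_3_4B"     # q_norm present at all -> Gemma 3
else:
    te = "GEMMA_2_2B"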
@@ -1284,7 +1258,7 @@ def llama_detect(clip_data):
 
     return {}
 
-def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}, disable_dynamic=False):
+def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
     clip_data = state_dicts
 
     class EmptyClass:
@@ -1296,8 +1270,6 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
     else:
         if "text_projection" in clip_data[i]:
             clip_data[i]["text_projection.weight"] = clip_data[i]["text_projection"].transpose(0, 1) #old models saved with the CLIPSave node
-        if "lm_head.weight" in clip_data[i]:
-            clip_data[i]["model.lm_head.weight"] = clip_data[i].pop("lm_head.weight") # prefix missing in some models
 
     tokenizer_data = {}
     clip_target = EmptyClass()
@@ -1363,14 +1335,6 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
             clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data), model_type="gemma3_4b")
             clip_target.tokenizer = comfy.text_encoders.lumina2.NTokenizer
             tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
-        elif te_model == TEModel.GEMMA_3_4B_VISION:
-            clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data), model_type="gemma3_4b_vision")
-            clip_target.tokenizer = comfy.text_encoders.lumina2.NTokenizer
-            tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
-        elif te_model == TEModel.GEMMA_3_12B:
-            clip_target.clip = comfy.text_encoders.lt.gemma3_te(**llama_detect(clip_data))
-            clip_target.tokenizer = comfy.text_encoders.lt.Gemma3_12BTokenizer
-            tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
         elif te_model == TEModel.LLAMA3_8:
             clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data),
                                                                         clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None)
@@ -1382,9 +1346,6 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
             if clip_type == CLIPType.HUNYUAN_IMAGE:
                 clip_target.clip = comfy.text_encoders.hunyuan_image.te(byt5=False, **llama_detect(clip_data))
                 clip_target.tokenizer = comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer
-            elif clip_type == CLIPType.LONGCAT_IMAGE:
-                clip_target.clip = comfy.text_encoders.longcat_image.te(**llama_detect(clip_data))
-                clip_target.tokenizer = comfy.text_encoders.longcat_image.LongCatImageTokenizer
             else:
                 clip_target.clip = comfy.text_encoders.qwen_image.te(**llama_detect(clip_data))
                 clip_target.tokenizer = comfy.text_encoders.qwen_image.QwenImageTokenizer
@@ -1467,7 +1428,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
             clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data))
             clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage
         elif clip_type == CLIPType.LTXV:
-            clip_target.clip = comfy.text_encoders.lt.ltxav_te(**llama_detect(clip_data), **comfy.text_encoders.lt.sd_detect(clip_data))
+            clip_target.clip = comfy.text_encoders.lt.ltxav_te(**llama_detect(clip_data))
             clip_target.tokenizer = comfy.text_encoders.lt.LTXAVGemmaTokenizer
             tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
         elif clip_type == CLIPType.NEWBIE:
@@ -1504,7 +1465,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
             parameters += comfy.utils.calculate_parameters(c)
             tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options)
 
-    clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, state_dict=clip_data, model_options=model_options, disable_dynamic=disable_dynamic)
+    clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, state_dict=clip_data, model_options=model_options)
     return clip
 
 def load_gligen(ckpt_path):
@@ -1544,34 +1505,14 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
 
     return (model, clip, vae)
 
-def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, disable_dynamic=False):
+def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}):
     sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True)
-    out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic)
+    out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata)
     if out is None:
         raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd)))
-    if output_model and out[0] is not None:
-        out[0].cached_patcher_init = (load_checkpoint_guess_config_model_only, (ckpt_path, embedding_directory, model_options, te_model_options))
-    if output_clip and out[1] is not None:
-        out[1].patcher.cached_patcher_init = (load_checkpoint_guess_config_clip_only, (ckpt_path, embedding_directory, model_options, te_model_options))
     return out
 
-def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
-    model, *_ = load_checkpoint_guess_config(ckpt_path, False, False, False,
-                                             embedding_directory=embedding_directory,
-                                             model_options=model_options,
-                                             te_model_options=te_model_options,
-                                             disable_dynamic=disable_dynamic)
-    return model
-
-def load_checkpoint_guess_config_clip_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
-    _, clip, *_ = load_checkpoint_guess_config(ckpt_path, False, True, False,
-                                               embedding_directory=embedding_directory, output_model=False,
-                                               model_options=model_options,
-                                               te_model_options=te_model_options,
-                                               disable_dynamic=disable_dynamic)
-    return clip.patcher
-
-def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None, disable_dynamic=False):
+def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None):
     clip = None
     clipvision = None
     vae = None
@@ -1620,8 +1561,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
     if output_model:
         inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
         model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
-        ModelPatcher = comfy.model_patcher.ModelPatcher if disable_dynamic else comfy.model_patcher.CoreModelPatcher
-        model_patcher = ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
+        model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
         model.load_model_weights(sd, diffusion_model_prefix, assign=model_patcher.is_dynamic())
 
     if output_vae:
@@ -1656,7 +1596,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
         clip_sd = model_config.process_clip_state_dict(sd)
         if len(clip_sd) > 0:
             parameters = comfy.utils.calculate_parameters(clip_sd)
-            clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, state_dict=clip_sd, model_options=te_model_options, disable_dynamic=disable_dynamic)
+            clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, state_dict=clip_sd, model_options=te_model_options)
         else:
             logging.warning("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.")
 
@@ -1672,7 +1612,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
     return (model_patcher, clip, vae, clipvision)
 
 
-def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable_dynamic=False):
+def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
     """
     Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats.
 
@@ -1756,8 +1696,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable
         model_config.optimizations["fp8"] = True
 
     model = model_config.get_model(new_sd, "")
-    ModelPatcher = comfy.model_patcher.ModelPatcher if disable_dynamic else comfy.model_patcher.CoreModelPatcher
-    model_patcher = ModelPatcher(model, load_device=load_device, offload_device=offload_device)
+    model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=load_device, offload_device=offload_device)
    if not model_management.is_device_cpu(offload_device):
        model.to(offload_device)
    model.load_model_weights(new_sd, "", assign=model_patcher.is_dynamic())
@@ -1766,13 +1705,12 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable
     logging.info("left over keys in diffusion model: {}".format(left_over))
     return model_patcher
 
-def load_diffusion_model(unet_path, model_options={}, disable_dynamic=False):
+def load_diffusion_model(unet_path, model_options={}):
     sd, metadata = comfy.utils.load_torch_file(unet_path, return_metadata=True)
-    model = load_diffusion_model_state_dict(sd, model_options=model_options, metadata=metadata, disable_dynamic=disable_dynamic)
+    model = load_diffusion_model_state_dict(sd, model_options=model_options, metadata=metadata)
     if model is None:
         logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path))
         raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd)))
-    model.cached_patcher_init = (load_diffusion_model, (unet_path, model_options))
     return model
 
 def load_unet(unet_path, dtype=None):

@@ -308,15 +308,6 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
     def load_sd(self, sd):
         return self.transformer.load_state_dict(sd, strict=False, assign=getattr(self, "can_assign_sd", False))
 
-    def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
-        if isinstance(tokens, dict):
-            tokens_only = next(iter(tokens.values())) # todo: get this better?
-        else:
-            tokens_only = tokens
-        tokens_only = [[t[0] for t in b] for b in tokens_only]
-        embeds = self.process_tokens(tokens_only, device=self.execution_device)[0]
-        return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed)
-
 def parse_parentheses(string):
     result = []
     current_item = ""
@@ -573,8 +564,6 @@ class SDTokenizer:
         min_length = tokenizer_options.get("{}_min_length".format(self.embedding_key), self.min_length)
         min_padding = tokenizer_options.get("{}_min_padding".format(self.embedding_key), self.min_padding)
-
-        min_length = kwargs.get("min_length", min_length)
 
         text = escape_important(text)
         if kwargs.get("disable_weights", self.disable_weights):
             parsed_weights = [(text, 1.0)]
@@ -674,9 +663,6 @@ class SDTokenizer:
     def state_dict(self):
         return {}
 
-    def decode(self, token_ids, skip_special_tokens=True):
-        return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
-
 class SD1Tokenizer:
     def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer, name=None):
         if name is not None:
@@ -700,9 +686,6 @@ class SD1Tokenizer:
     def state_dict(self):
         return getattr(self, self.clip).state_dict()
 
-    def decode(self, token_ids, skip_special_tokens=True):
-        return getattr(self, self.clip).decode(token_ids, skip_special_tokens=skip_special_tokens)
-
 class SD1CheckpointClipModel(SDClipModel):
     def __init__(self, device="cpu", dtype=None, model_options={}):
         super().__init__(device=device, return_projected_pooled=False, dtype=dtype, model_options=model_options)
@@ -739,6 +722,3 @@ class SD1ClipModel(torch.nn.Module):
 
     def load_sd(self, sd):
         return getattr(self, self.clip).load_sd(sd)
-
-    def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None):
-        return getattr(self, self.clip).generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed)

@@ -25,7 +25,6 @@ import comfy.text_encoders.kandinsky5
 import comfy.text_encoders.z_image
 import comfy.text_encoders.anima
 import comfy.text_encoders.ace15
-import comfy.text_encoders.longcat_image
 
 from . import supported_models_base
 from . import latent_formats
@@ -526,8 +525,7 @@ class LotusD(SD20):
     }
 
     unet_extra_config = {
-        "num_classes": 'sequential',
-        "num_head_channels": 64,
+        "num_classes": 'sequential'
     }
 
     def get_model(self, state_dict, prefix="", device=None):
@@ -1118,20 +1116,6 @@ class ZImage(Lumina2):
         hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_4b.transformer.".format(pref))
         return supported_models_base.ClipTarget(comfy.text_encoders.z_image.ZImageTokenizer, comfy.text_encoders.z_image.te(**hunyuan_detect))
 
-class ZImagePixelSpace(ZImage):
-    unet_config = {
-        "image_model": "zimage_pixel",
-    }
-
-    # Pixel-space model: no spatial compression, operates on raw RGB patches.
-    latent_format = latent_formats.ZImagePixelSpace
-
-    # Much lower memory than latent-space models (no VAE, small patches).
-    memory_usage_factor = 0.03 # TODO: figure out the optimal value for this.
-
-    def get_model(self, state_dict, prefix="", device=None):
-        return model_base.ZImagePixelSpace(self, device=device)
-
 class WAN21_T2V(supported_models_base.BASE):
     unet_config = {
         "image_model": "wan2.1",
@@ -1272,26 +1256,6 @@ class WAN22_T2V(WAN21_T2V):
         out = model_base.WAN22(self, image_to_video=True, device=device)
         return out
 
-class WAN21_FlowRVS(WAN21_T2V):
-    unet_config = {
-        "image_model": "wan2.1",
-        "model_type": "flow_rvs",
-    }
-
-    def get_model(self, state_dict, prefix="", device=None):
-        out = model_base.WAN21_FlowRVS(self, image_to_video=True, device=device)
-        return out
-
-class WAN21_SCAIL(WAN21_T2V):
-    unet_config = {
-        "image_model": "wan2.1",
-        "model_type": "scail",
-    }
-
-    def get_model(self, state_dict, prefix="", device=None):
-        out = model_base.WAN21_SCAIL(self, image_to_video=False, device=device)
-        return out
-
 class Hunyuan3Dv2(supported_models_base.BASE):
     unet_config = {
         "image_model": "hunyuan3d2",
@@ -1703,37 +1667,6 @@ class ACEStep15(supported_models_base.BASE):
         return supported_models_base.ClipTarget(comfy.text_encoders.ace15.ACE15Tokenizer, comfy.text_encoders.ace15.te(**detect))
 
 
-class LongCatImage(supported_models_base.BASE):
-    unet_config = {
-        "image_model": "flux",
-        "guidance_embed": False,
-        "vec_in_dim": None,
-        "context_in_dim": 3584,
-        "txt_ids_dims": [1, 2],
-    }
-
-    sampling_settings = {
-    }
-
-    unet_extra_config = {}
-    latent_format = latent_formats.Flux
-
-    memory_usage_factor = 2.5
-
-    supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
-
-    vae_key_prefix = ["vae."]
-    text_encoder_key_prefix = ["text_encoders."]
-
-    def get_model(self, state_dict, prefix="", device=None):
-        out = model_base.LongCatImage(self, device=device)
-        return out
-
-    def clip_target(self, state_dict={}):
-        pref = self.text_encoder_key_prefix[0]
-        hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
-        return supported_models_base.ClipTarget(comfy.text_encoders.longcat_image.LongCatImageTokenizer, comfy.text_encoders.longcat_image.te(**hunyuan_detect))
-
-models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
+models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
 
 models += [SVD_img2vid]

@@ -328,14 +328,14 @@ class ACE15TEModel(torch.nn.Module):
         return getattr(self, self.lm_model).load_sd(sd)
 
     def memory_estimation_function(self, token_weight_pairs, device=None):
-        lm_metadata = token_weight_pairs.get("lm_metadata", {})
+        lm_metadata = token_weight_pairs["lm_metadata"]
         constant = self.constant
         if comfy.model_management.should_use_bf16(device):
             constant *= 0.5
 
         token_weight_pairs = token_weight_pairs.get("lm_prompt", [])
         num_tokens = sum(map(lambda a: len(a), token_weight_pairs))
-        num_tokens += lm_metadata.get("min_tokens", 0)
+        num_tokens += lm_metadata['min_tokens']
         return num_tokens * constant * 1024 * 1024
 
 def te(dtype_llama=None, llama_quantization_metadata=None, lm_model="qwen3_2b"):
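The estimate above works out to num_tokens * constant MiB, with the constant halved when bf16 is in use. Worked with made-up numbers:

constant = 2.0 * 0.5                 # hypothetical per-token constant, halved for bf16
num_tokens = 512
bytes_needed = num_tokens * constant * 1024 * 1024   # = 512 MiB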
@@ -33,8 +33,6 @@ class AnimaTokenizer:
     def state_dict(self):
         return {}
 
-    def decode(self, token_ids, **kwargs):
-        return self.qwen3_06b.decode(token_ids, **kwargs)
-
 class Qwen3_06BModel(sd1_clip.SDClipModel):
     def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
@@ -3,8 +3,6 @@ import torch.nn as nn
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional, Any, Tuple
|
from typing import Optional, Any, Tuple
|
||||||
import math
|
import math
|
||||||
from tqdm import tqdm
|
|
||||||
import comfy.utils
|
|
||||||
|
|
||||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
@@ -105,7 +103,6 @@ class Qwen3_06BConfig:
|
|||||||
rope_scale = None
|
rope_scale = None
|
||||||
final_norm: bool = True
|
final_norm: bool = True
|
||||||
lm_head: bool = False
|
lm_head: bool = False
|
||||||
stop_tokens = [151643, 151645]
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Qwen3_06B_ACE15_Config:
|
class Qwen3_06B_ACE15_Config:
|
||||||
@@ -129,7 +126,6 @@ class Qwen3_06B_ACE15_Config:
|
|||||||
rope_scale = None
|
rope_scale = None
|
||||||
final_norm: bool = True
|
final_norm: bool = True
|
||||||
lm_head: bool = False
|
lm_head: bool = False
|
||||||
stop_tokens = [151643, 151645]
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Qwen3_2B_ACE15_lm_Config:
|
class Qwen3_2B_ACE15_lm_Config:
|
||||||
@@ -153,7 +149,6 @@ class Qwen3_2B_ACE15_lm_Config:
|
|||||||
rope_scale = None
|
rope_scale = None
|
||||||
final_norm: bool = True
|
final_norm: bool = True
|
||||||
lm_head: bool = False
|
lm_head: bool = False
|
||||||
stop_tokens = [151643, 151645]
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Qwen3_4B_ACE15_lm_Config:
|
class Qwen3_4B_ACE15_lm_Config:
|
||||||
@@ -177,7 +172,6 @@ class Qwen3_4B_ACE15_lm_Config:
|
|||||||
rope_scale = None
|
rope_scale = None
|
||||||
final_norm: bool = True
|
final_norm: bool = True
|
||||||
lm_head: bool = False
|
lm_head: bool = False
|
||||||
stop_tokens = [151643, 151645]
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Qwen3_4BConfig:
|
class Qwen3_4BConfig:
|
||||||
@@ -201,7 +195,6 @@ class Qwen3_4BConfig:
|
|||||||
rope_scale = None
|
rope_scale = None
|
||||||
final_norm: bool = True
|
final_norm: bool = True
|
||||||
lm_head: bool = False
|
lm_head: bool = False
|
||||||
stop_tokens = [151643, 151645]
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Qwen3_8BConfig:
|
class Qwen3_8BConfig:
|
||||||
@@ -225,7 +218,6 @@ class Qwen3_8BConfig:
|
|||||||
rope_scale = None
|
rope_scale = None
|
||||||
final_norm: bool = True
|
final_norm: bool = True
|
||||||
     lm_head: bool = False
-    stop_tokens = [151643, 151645]
 
 @dataclass
 class Ovis25_2BConfig:
@@ -296,7 +288,6 @@ class Gemma2_2B_Config:
     rope_scale = None
     final_norm: bool = True
     lm_head: bool = False
-    stop_tokens = [1]
 
 @dataclass
 class Gemma3_4B_Config:
@@ -321,14 +312,6 @@ class Gemma3_4B_Config:
     rope_scale = [8.0, 1.0]
     final_norm: bool = True
     lm_head: bool = False
-    stop_tokens = [1, 106]
 
-GEMMA3_VISION_CONFIG = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "model_type": "siglip_vision_model", "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 14}
-
-@dataclass
-class Gemma3_4B_Vision_Config(Gemma3_4B_Config):
-    vision_config = GEMMA3_VISION_CONFIG
-    mm_tokens_per_image = 256
-
 @dataclass
 class Gemma3_12B_Config:
@@ -353,9 +336,8 @@ class Gemma3_12B_Config:
     rope_scale = [8.0, 1.0]
     final_norm: bool = True
     lm_head: bool = False
-    vision_config = GEMMA3_VISION_CONFIG
+    vision_config = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "model_type": "siglip_vision_model", "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 14}
     mm_tokens_per_image = 256
-    stop_tokens = [1, 106]
 
 class RMSNorm(nn.Module):
     def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
@@ -459,10 +441,8 @@ class Attention(nn.Module):
         freqs_cis: Optional[torch.Tensor] = None,
         optimized_attention=None,
         past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-        sliding_window: Optional[int] = None,
     ):
         batch_size, seq_length, _ = hidden_states.shape
-
         xq = self.q_proj(hidden_states)
         xk = self.k_proj(hidden_states)
         xv = self.v_proj(hidden_states)
@@ -497,11 +477,6 @@ class Attention(nn.Module):
         else:
             present_key_value = (xk, xv, index + num_tokens)
 
-        if sliding_window is not None and xk.shape[2] > sliding_window:
-            xk = xk[:, :, -sliding_window:]
-            xv = xv[:, :, -sliding_window:]
-            attention_mask = attention_mask[..., -sliding_window:] if attention_mask is not None else None
-
         xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
         xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
 
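The hunk above drops the sliding-window trim that kept the KV cache bounded before the grouped-query expansion. A minimal standalone sketch of that trim (the helper name and the shapes are illustrative assumptions, not the repository's API):

```python
import torch

def trim_kv_to_window(xk, xv, attention_mask, sliding_window):
    # Keep only the most recent `sliding_window` cached positions so
    # attention cost stays O(seq * window) instead of O(seq^2).
    if sliding_window is not None and xk.shape[2] > sliding_window:
        xk = xk[:, :, -sliding_window:]
        xv = xv[:, :, -sliding_window:]
        if attention_mask is not None:
            attention_mask = attention_mask[..., -sliding_window:]
    return xk, xv, attention_mask

xk = torch.randn(1, 4, 1000, 64)   # [batch, kv_heads, seq, head_dim]
xv = torch.randn(1, 4, 1000, 64)
xk, xv, _ = trim_kv_to_window(xk, xv, None, sliding_window=512)
print(xk.shape)  # torch.Size([1, 4, 512, 64])
```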
@@ -584,12 +559,10 @@ class TransformerBlockGemma2(nn.Module):
         optimized_attention=None,
         past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     ):
-        sliding_window = None
         if self.transformer_type == 'gemma3':
             if self.sliding_attention:
-                sliding_window = self.sliding_attention
                 if x.shape[1] > self.sliding_attention:
-                    sliding_mask = torch.full((x.shape[1], x.shape[1]), torch.finfo(x.dtype).min, device=x.device, dtype=x.dtype)
+                    sliding_mask = torch.full((x.shape[1], x.shape[1]), float("-inf"), device=x.device, dtype=x.dtype)
                     sliding_mask.tril_(diagonal=-self.sliding_attention)
                     if attention_mask is not None:
                         attention_mask = attention_mask + sliding_mask
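Both sides build the same additive band mask; they differ only in the fill value, where `torch.finfo(dtype).min` keeps the mask finite so that summing two masks cannot produce `inf - inf` NaNs. A small illustrative sketch of the band construction:

```python
import torch

seq, window = 6, 3
dtype = torch.float32
# tril_ with a negative diagonal keeps only entries at column j <= i - window,
# i.e. the positions that have fallen out of the sliding window.
mask = torch.full((seq, seq), torch.finfo(dtype).min, dtype=dtype)
mask.tril_(diagonal=-window)
print((mask != 0).int())
# Row i has nonzero (blocking) bias exactly at the too-old positions;
# added to a causal mask this bounds how far back attention can look.
```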
@@ -608,7 +581,6 @@ class TransformerBlockGemma2(nn.Module):
             freqs_cis=freqs_cis,
             optimized_attention=optimized_attention,
             past_key_value=past_key_value,
-            sliding_window=sliding_window,
         )
 
         x = self.post_attention_layernorm(x)
@@ -793,107 +765,6 @@ class BaseLlama:
     def forward(self, input_ids, *args, **kwargs):
         return self.model(input_ids, *args, **kwargs)
 
-class BaseGenerate:
-    def logits(self, x):
-        input = x[:, -1:]
-        if hasattr(self.model, "lm_head"):
-            module = self.model.lm_head
-        else:
-            module = self.model.embed_tokens
-
-        offload_stream = None
-        if module.comfy_cast_weights:
-            weight, _, offload_stream = comfy.ops.cast_bias_weight(module, input, offloadable=True)
-        else:
-            weight = self.model.embed_tokens.weight.to(x)
-
-        x = torch.nn.functional.linear(input, weight, None)
-
-        comfy.ops.uncast_bias_weight(module, weight, None, offload_stream)
-        return x
-
-    def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0):
-        device = embeds.device
-        model_config = self.model.config
-
-        if stop_tokens is None:
-            stop_tokens = self.model.config.stop_tokens
-
-        if execution_dtype is None:
-            if comfy.model_management.should_use_bf16(device):
-                execution_dtype = torch.bfloat16
-            else:
-                execution_dtype = torch.float32
-        embeds = embeds.to(execution_dtype)
-
-        if embeds.ndim == 2:
-            embeds = embeds.unsqueeze(0)
-
-        past_key_values = [] #kv_cache init
-        max_cache_len = embeds.shape[1] + max_length
-        for x in range(model_config.num_hidden_layers):
-            past_key_values.append((torch.empty([embeds.shape[0], model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype),
-                                    torch.empty([embeds.shape[0], model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0))
-
-        generator = torch.Generator(device=device).manual_seed(seed) if do_sample else None
-
-        generated_token_ids = []
-        pbar = comfy.utils.ProgressBar(max_length)
-
-        # Generation loop
-        for step in tqdm(range(max_length), desc="Generating tokens"):
-            x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values)
-            logits = self.logits(x)[:, -1]
-            next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample)
-            token_id = next_token[0].item()
-            generated_token_ids.append(token_id)
-
-            embeds = self.model.embed_tokens(next_token).to(execution_dtype)
-            pbar.update(1)
-
-            if token_id in stop_tokens:
-                break
-
-        return generated_token_ids
-
-    def sample_token(self, logits, temperature, top_k, top_p, min_p, repetition_penalty, token_history, generator, do_sample=True):
-
-        if not do_sample or temperature == 0.0:
-            return torch.argmax(logits, dim=-1, keepdim=True)
-
-        # Sampling mode
-        if repetition_penalty != 1.0:
-            for i in range(logits.shape[0]):
-                for token_id in set(token_history):
-                    logits[i, token_id] *= repetition_penalty if logits[i, token_id] < 0 else 1/repetition_penalty
-
-        if temperature != 1.0:
-            logits = logits / temperature
-
-        if top_k > 0:
-            indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
-            logits[indices_to_remove] = torch.finfo(logits.dtype).min
-
-        if min_p > 0.0:
-            probs_before_filter = torch.nn.functional.softmax(logits, dim=-1)
-            top_probs, _ = probs_before_filter.max(dim=-1, keepdim=True)
-            min_threshold = min_p * top_probs
-            indices_to_remove = probs_before_filter < min_threshold
-            logits[indices_to_remove] = torch.finfo(logits.dtype).min
-
-        if top_p < 1.0:
-            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
-            cumulative_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
-            sorted_indices_to_remove = cumulative_probs > top_p
-            sorted_indices_to_remove[..., 0] = False
-            indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
-            indices_to_remove.scatter_(1, sorted_indices, sorted_indices_to_remove)
-            logits[indices_to_remove] = torch.finfo(logits.dtype).min
-
-        probs = torch.nn.functional.softmax(logits, dim=-1)
-
-        return torch.multinomial(probs, num_samples=1, generator=generator)
-
 class BaseQwen3:
     def logits(self, x):
         input = x[:, -1:]
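The deleted `sample_token` chains repetition penalty, temperature, top-k, min-p and top-p filters before a multinomial draw. A self-contained sketch of just the top-p (nucleus) step on toy logits; the function name is ours and the snippet is illustrative:

```python
import torch

def top_p_filter(logits, top_p=0.9):
    # Keep the smallest set of tokens whose cumulative probability
    # exceeds top_p; always keep the single most likely token.
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
    remove = cumulative > top_p
    remove[..., 0] = False
    mask = torch.zeros_like(logits, dtype=torch.bool)
    mask.scatter_(1, sorted_indices, remove)
    return logits.masked_fill(mask, torch.finfo(logits.dtype).min)

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
filtered = top_p_filter(logits, top_p=0.8)
probs = torch.softmax(filtered, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
```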
@@ -937,7 +808,7 @@ class Qwen25_3B(BaseLlama, torch.nn.Module):
         self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
         self.dtype = dtype
 
-class Qwen3_06B(BaseLlama, BaseQwen3, BaseGenerate, torch.nn.Module):
+class Qwen3_06B(BaseLlama, BaseQwen3, torch.nn.Module):
     def __init__(self, config_dict, dtype, device, operations):
         super().__init__()
         config = Qwen3_06BConfig(**config_dict)
@@ -964,7 +835,7 @@ class Qwen3_2B_ACE15_lm(BaseLlama, BaseQwen3, torch.nn.Module):
         self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
         self.dtype = dtype
 
-class Qwen3_4B(BaseLlama, BaseQwen3, BaseGenerate, torch.nn.Module):
+class Qwen3_4B(BaseLlama, BaseQwen3, torch.nn.Module):
     def __init__(self, config_dict, dtype, device, operations):
         super().__init__()
         config = Qwen3_4BConfig(**config_dict)
@@ -982,7 +853,7 @@ class Qwen3_4B_ACE15_lm(BaseLlama, BaseQwen3, torch.nn.Module):
         self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
         self.dtype = dtype
 
-class Qwen3_8B(BaseLlama, BaseQwen3, BaseGenerate, torch.nn.Module):
+class Qwen3_8B(BaseLlama, BaseQwen3, torch.nn.Module):
     def __init__(self, config_dict, dtype, device, operations):
         super().__init__()
         config = Qwen3_8BConfig(**config_dict)
@@ -1000,7 +871,7 @@ class Ovis25_2B(BaseLlama, torch.nn.Module):
         self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
         self.dtype = dtype
 
-class Qwen25_7BVLI(BaseLlama, BaseGenerate, torch.nn.Module):
+class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
     def __init__(self, config_dict, dtype, device, operations):
         super().__init__()
         config = Qwen25_7BVLI_Config(**config_dict)
@@ -1010,9 +881,6 @@ class Qwen25_7BVLI(BaseLlama, BaseGenerate, torch.nn.Module):
         self.visual = qwen_vl.Qwen2VLVisionTransformer(hidden_size=1280, output_hidden_size=config.hidden_size, device=device, dtype=dtype, ops=operations)
         self.dtype = dtype
 
-        # todo: should this be tied or not?
-        #self.lm_head = operations.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
-
     def preprocess_embed(self, embed, device):
         if embed["type"] == "image":
             image, grid = qwen_vl.process_qwen2vl_images(embed["data"])
@@ -1046,7 +914,7 @@ class Qwen25_7BVLI(BaseLlama, BaseGenerate, torch.nn.Module):
 
         return super().forward(x, attention_mask=attention_mask, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=final_layer_norm_intermediate, dtype=dtype, position_ids=position_ids)
 
-class Gemma2_2B(BaseLlama, BaseGenerate, torch.nn.Module):
+class Gemma2_2B(BaseLlama, torch.nn.Module):
     def __init__(self, config_dict, dtype, device, operations):
         super().__init__()
         config = Gemma2_2B_Config(**config_dict)
@@ -1055,7 +923,7 @@ class Gemma2_2B(BaseLlama, BaseGenerate, torch.nn.Module):
         self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
         self.dtype = dtype
 
-class Gemma3_4B(BaseLlama, BaseGenerate, torch.nn.Module):
+class Gemma3_4B(BaseLlama, torch.nn.Module):
     def __init__(self, config_dict, dtype, device, operations):
         super().__init__()
         config = Gemma3_4B_Config(**config_dict)
@@ -1064,25 +932,7 @@ class Gemma3_4B(BaseLlama, BaseGenerate, torch.nn.Module):
         self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
         self.dtype = dtype
 
-class Gemma3_4B_Vision(BaseLlama, BaseGenerate, torch.nn.Module):
-    def __init__(self, config_dict, dtype, device, operations):
-        super().__init__()
-        config = Gemma3_4B_Vision_Config(**config_dict)
-        self.num_layers = config.num_hidden_layers
-
-        self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
-        self.dtype = dtype
-        self.multi_modal_projector = Gemma3MultiModalProjector(config, dtype, device, operations)
-        self.vision_model = comfy.clip_model.CLIPVision(config.vision_config, dtype, device, operations)
-        self.image_size = config.vision_config["image_size"]
-
-    def preprocess_embed(self, embed, device):
-        if embed["type"] == "image":
-            image = comfy.clip_model.clip_preprocess(embed["data"], size=self.image_size, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], crop=True)
-            return self.multi_modal_projector(self.vision_model(image.to(device, dtype=torch.float32))[0]), None
-        return None, None
-
-class Gemma3_12B(BaseLlama, BaseGenerate, torch.nn.Module):
+class Gemma3_12B(BaseLlama, torch.nn.Module):
     def __init__(self, config_dict, dtype, device, operations):
         super().__init__()
         config = Gemma3_12B_Config(**config_dict)
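The deleted vision path normalizes images with mean and std 0.5 (mapping pixels to [-1, 1]) before the SigLIP tower. A hedged sketch of that normalization, with a plain bilinear resize standing in for the repository's `clip_preprocess` (which additionally center-crops):

```python
import torch

def normalize_for_siglip(image, size=896):
    # image: [batch, H, W, C] floats in [0, 1]; mean/std of 0.5 maps
    # pixel values to [-1, 1], matching the deleted preprocess_embed.
    x = image.movedim(-1, 1)  # -> [batch, C, H, W]
    x = torch.nn.functional.interpolate(x, size=(size, size), mode="bilinear")
    mean = torch.tensor([0.5, 0.5, 0.5]).view(1, 3, 1, 1)
    std = torch.tensor([0.5, 0.5, 0.5]).view(1, 3, 1, 1)
    return (x - mean) / std

img = torch.rand(1, 512, 768, 3)
print(normalize_for_siglip(img).shape)  # torch.Size([1, 3, 896, 896])
```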
@@ -1,184 +0,0 @@
-import re
-import numbers
-import torch
-from comfy import sd1_clip
-from comfy.text_encoders.qwen_image import Qwen25_7BVLITokenizer, Qwen25_7BVLIModel
-import logging
-
-logger = logging.getLogger(__name__)
-
-QUOTE_PAIRS = [("'", "'"), ('"', '"'), ("\u2018", "\u2019"), ("\u201c", "\u201d")]
-QUOTE_PATTERN = "|".join(
-    [
-        re.escape(q1) + r"[^" + re.escape(q1 + q2) + r"]*?" + re.escape(q2)
-        for q1, q2 in QUOTE_PAIRS
-    ]
-)
-WORD_INTERNAL_QUOTE_RE = re.compile(r"[a-zA-Z]+'[a-zA-Z]+")
-
-
-def split_quotation(prompt):
-    matches = WORD_INTERNAL_QUOTE_RE.findall(prompt)
-    mapping = []
-    for i, word_src in enumerate(set(matches)):
-        word_tgt = "longcat_$##$_longcat" * (i + 1)
-        prompt = prompt.replace(word_src, word_tgt)
-        mapping.append((word_src, word_tgt))
-
-    parts = re.split(f"({QUOTE_PATTERN})", prompt)
-    result = []
-    for part in parts:
-        for word_src, word_tgt in mapping:
-            part = part.replace(word_tgt, word_src)
-        if not part:
-            continue
-        is_quoted = bool(re.match(QUOTE_PATTERN, part))
-        result.append((part, is_quoted))
-    return result
-
-
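Illustrative behavior of `split_quotation`, assuming the deleted definitions above are in scope: word-internal apostrophes are protected by a sentinel so only real quote pairs get split out and flagged for per-character tokenization.

```python
# Apostrophes inside words do not open a quote; real quote pairs do.
parts = split_quotation("don't say 'HI'")
print(parts)
# -> [("don't say ", False), ("'HI'", True)]
# The True span is later tokenized one character at a time, which is
# useful for prompts that render literal text in an image.
```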
-class LongCatImageBaseTokenizer(Qwen25_7BVLITokenizer):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.max_length = 512
-
-    def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
-        parts = split_quotation(text)
-        all_tokens = []
-        for part_text, is_quoted in parts:
-            if is_quoted:
-                for char in part_text:
-                    ids = self.tokenizer(char, add_special_tokens=False)["input_ids"]
-                    all_tokens.extend(ids)
-            else:
-                ids = self.tokenizer(part_text, add_special_tokens=False)["input_ids"]
-                all_tokens.extend(ids)
-
-        if len(all_tokens) > self.max_length:
-            all_tokens = all_tokens[: self.max_length]
-            logger.warning(f"Truncated prompt to {self.max_length} tokens")
-
-        output = [(t, 1.0) for t in all_tokens]
-        # Pad to max length
-        self.pad_tokens(output, self.max_length - len(output))
-        return [output]
-
-
-class LongCatImageTokenizer(sd1_clip.SD1Tokenizer):
-    def __init__(self, embedding_directory=None, tokenizer_data={}):
-        super().__init__(
-            embedding_directory=embedding_directory,
-            tokenizer_data=tokenizer_data,
-            name="qwen25_7b",
-            tokenizer=LongCatImageBaseTokenizer,
-        )
-        self.longcat_template_prefix = "<|im_start|>system\nAs an image captioning expert, generate a descriptive text prompt based on an image content, suitable for input to a text-to-image model.<|im_end|>\n<|im_start|>user\n"
-        self.longcat_template_suffix = "<|im_end|>\n<|im_start|>assistant\n"
-
-    def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
-        skip_template = False
-        if text.startswith("<|im_start|>"):
-            skip_template = True
-        if text.startswith("<|start_header_id|>"):
-            skip_template = True
-        if text == "":
-            text = " "
-
-        base_tok = getattr(self, "qwen25_7b")
-        if skip_template:
-            tokens = super().tokenize_with_weights(
-                text, return_word_ids=return_word_ids, disable_weights=True, **kwargs
-            )
-        else:
-            prefix_ids = base_tok.tokenizer(
-                self.longcat_template_prefix, add_special_tokens=False
-            )["input_ids"]
-            suffix_ids = base_tok.tokenizer(
-                self.longcat_template_suffix, add_special_tokens=False
-            )["input_ids"]
-
-            prompt_tokens = base_tok.tokenize_with_weights(
-                text, return_word_ids=return_word_ids, **kwargs
-            )
-            prompt_pairs = prompt_tokens[0]
-
-            prefix_pairs = [(t, 1.0) for t in prefix_ids]
-            suffix_pairs = [(t, 1.0) for t in suffix_ids]
-
-            combined = prefix_pairs + prompt_pairs + suffix_pairs
-            tokens = {"qwen25_7b": [combined]}
-
-        return tokens
-
-
-class LongCatImageTEModel(sd1_clip.SD1ClipModel):
-    def __init__(self, device="cpu", dtype=None, model_options={}):
-        super().__init__(
-            device=device,
-            dtype=dtype,
-            name="qwen25_7b",
-            clip_model=Qwen25_7BVLIModel,
-            model_options=model_options,
-        )
-
-    def encode_token_weights(self, token_weight_pairs, template_end=-1):
-        out, pooled, extra = super().encode_token_weights(token_weight_pairs)
-        tok_pairs = token_weight_pairs["qwen25_7b"][0]
-        count_im_start = 0
-        if template_end == -1:
-            for i, v in enumerate(tok_pairs):
-                elem = v[0]
-                if not torch.is_tensor(elem):
-                    if isinstance(elem, numbers.Integral):
-                        if elem == 151644 and count_im_start < 2:
-                            template_end = i
-                            count_im_start += 1
-
-        if out.shape[1] > (template_end + 3):
-            if tok_pairs[template_end + 1][0] == 872:
-                if tok_pairs[template_end + 2][0] == 198:
-                    template_end += 3
-
-        if template_end == -1:
-            template_end = 0
-
-        suffix_start = None
-        for i in range(len(tok_pairs) - 1, -1, -1):
-            elem = tok_pairs[i][0]
-            if not torch.is_tensor(elem) and isinstance(elem, numbers.Integral):
-                if elem == 151645:
-                    suffix_start = i
-                    break
-
-        out = out[:, template_end:]
-
-        if "attention_mask" in extra:
-            extra["attention_mask"] = extra["attention_mask"][:, template_end:]
-            if extra["attention_mask"].sum() == torch.numel(extra["attention_mask"]):
-                extra.pop("attention_mask")
-
-        if suffix_start is not None:
-            suffix_len = len(tok_pairs) - suffix_start
-            if suffix_len > 0 and out.shape[1] > suffix_len:
-                out = out[:, :-suffix_len]
-                if "attention_mask" in extra:
-                    extra["attention_mask"] = extra["attention_mask"][:, :-suffix_len]
-                    if extra["attention_mask"].sum() == torch.numel(
-                        extra["attention_mask"]
-                    ):
-                        extra.pop("attention_mask")
-
-        return out, pooled, extra
-
-
-def te(dtype_llama=None, llama_quantization_metadata=None):
-    class LongCatImageTEModel_(LongCatImageTEModel):
-        def __init__(self, device="cpu", dtype=None, model_options={}):
-            if llama_quantization_metadata is not None:
-                model_options = model_options.copy()
-                model_options["quantization_metadata"] = llama_quantization_metadata
-            if dtype_llama is not None:
-                dtype = dtype_llama
-            super().__init__(device=device, dtype=dtype, model_options=model_options)
-
-    return LongCatImageTEModel_
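The `encode_token_weights` above slices the chat template out of the returned embeddings by scanning token ids; in the usual Qwen2.5 vocabulary, 151644/151645 are `<|im_start|>`/`<|im_end|>`, 872 is `user` and 198 a newline, stated here as an assumption. A toy sketch of the same trim:

```python
import torch

# Dummy "embeddings" for an 8-token sequence and the ids that produced
# them; 151644/151645 play the role of <|im_start|>/<|im_end|>.
ids = [151644, 872, 198, 101, 102, 103, 151645, 198]
out = torch.arange(8).float().view(1, 8, 1)

template_end = ids.index(151644) + 3          # skip "<|im_start|>user\n"
suffix_start = max(i for i, t in enumerate(ids) if t == 151645)
suffix_len = len(ids) - suffix_start

trimmed = out[:, template_end:]
trimmed = trimmed[:, :-suffix_len]
print(trimmed.flatten().tolist())  # [3.0, 4.0, 5.0] -> just the prompt tokens
```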
@@ -3,10 +3,9 @@ import os
 from transformers import T5TokenizerFast
 from .spiece_tokenizer import SPieceTokenizer
 import comfy.text_encoders.genmo
+from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
 import torch
 import comfy.utils
-import math
-import itertools
 
 class T5XXLTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
@@ -23,119 +22,53 @@ def ltxv_te(*args, **kwargs):
     return comfy.text_encoders.genmo.mochi_te(*args, **kwargs)
 
 
-class Gemma3_Tokenizer():
-    def state_dict(self):
-        return {"spiece_model": self.tokenizer.serialize_model()}
-
-    def tokenize_with_weights(self, text, return_word_ids=False, image=None, llama_template=None, skip_template=True, **kwargs):
-        self.llama_template = "<start_of_turn>system\nYou are a helpful assistant.<end_of_turn>\n<start_of_turn>user\n{}<end_of_turn>\n<start_of_turn>model\n"
-        self.llama_template_images = "<start_of_turn>system\nYou are a helpful assistant.<end_of_turn>\n<start_of_turn>user\n\n<image_soft_token>{}<end_of_turn>\n\n<start_of_turn>model\n"
-
-        if image is None:
-            images = []
-        else:
-            samples = image.movedim(-1, 1)
-            total = int(896 * 896)
-
-            scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
-            width = round(samples.shape[3] * scale_by)
-            height = round(samples.shape[2] * scale_by)
-
-            s = comfy.utils.common_upscale(samples, width, height, "area", "disabled").movedim(1, -1)
-            images = [s[:, :, :, :3]]
-
-        if text.startswith('<start_of_turn>'):
-            skip_template = True
-
-        if skip_template:
-            llama_text = text
-        else:
-            if llama_template is None:
-                if len(images) > 0:
-                    llama_text = self.llama_template_images.format(text)
-                else:
-                    llama_text = self.llama_template.format(text)
-            else:
-                llama_text = llama_template.format(text)
-
-        text_tokens = super().tokenize_with_weights(llama_text, return_word_ids)
-
-        if len(images) > 0:
-            embed_count = 0
-            for r in text_tokens:
-                for i, token in enumerate(r):
-                    if token[0] == 262144 and embed_count < len(images):
-                        r[i] = ({"type": "image", "data": images[embed_count]},) + token[1:]
-                        embed_count += 1
-        return text_tokens
-
-
-class Gemma3_12BTokenizer(Gemma3_Tokenizer, sd1_clip.SDTokenizer):
+class Gemma3_12BTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         tokenizer = tokenizer_data.get("spiece_model", None)
-        special_tokens = {"<image_soft_token>": 262144, "<end_of_turn>": 106}
-        super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1024, pad_left=True, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False, "special_tokens": special_tokens}, tokenizer_data=tokenizer_data)
+        super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_left=True, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
 
+    def state_dict(self):
+        return {"spiece_model": self.tokenizer.serialize_model()}
 
 class LTXAVGemmaTokenizer(sd1_clip.SD1Tokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma3_12b", tokenizer=Gemma3_12BTokenizer)
 
 
 class Gemma3_12BModel(sd1_clip.SDClipModel):
     def __init__(self, device="cpu", layer="all", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
         llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
         if llama_quantization_metadata is not None:
             model_options = model_options.copy()
             model_options["quantization_metadata"] = llama_quantization_metadata
-        self.dtypes = set()
-        self.dtypes.add(dtype)
         super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_12B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
 
-    def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
-        tokens_only = [[t[0] for t in b] for b in tokens]
-        embeds, _, _, embeds_info = self.process_tokens(tokens_only, self.execution_device)
-        comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
-        return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106]) # 106 is <end_of_turn>
-
-class DualLinearProjection(torch.nn.Module):
-    def __init__(self, in_dim, out_dim_video, out_dim_audio, dtype=None, device=None, operations=None):
-        super().__init__()
-        self.audio_aggregate_embed = operations.Linear(in_dim, out_dim_audio, bias=True, dtype=dtype, device=device)
-        self.video_aggregate_embed = operations.Linear(in_dim, out_dim_video, bias=True, dtype=dtype, device=device)
-
-    def forward(self, x):
-        source_dim = x.shape[-1]
-        x = x.movedim(1, -1)
-        x = (x * torch.rsqrt(torch.mean(x**2, dim=2, keepdim=True) + 1e-6)).flatten(start_dim=2)
-
-        video = self.video_aggregate_embed(x * math.sqrt(self.video_aggregate_embed.out_features / source_dim))
-        audio = self.audio_aggregate_embed(x * math.sqrt(self.audio_aggregate_embed.out_features / source_dim))
-        return torch.cat((video, audio), dim=-1)
-
+    def tokenize_with_weights(self, text, return_word_ids=False, llama_template="{}", image_embeds=None, **kwargs):
+        text = llama_template.format(text)
+        text_tokens = super().tokenize_with_weights(text, return_word_ids)
+        embed_count = 0
+        for k in text_tokens:
+            tt = text_tokens[k]
+            for r in tt:
+                for i in range(len(r)):
+                    if r[i][0] == 262144:
+                        if image_embeds is not None and embed_count < image_embeds.shape[0]:
+                            r[i] = ({"type": "embedding", "data": image_embeds[embed_count], "original_type": "image"},) + r[i][1:]
+                            embed_count += 1
+        return text_tokens
 
 class LTXAVTEModel(torch.nn.Module):
-    def __init__(self, dtype_llama=None, device="cpu", dtype=None, text_projection_type="single_linear", model_options={}):
+    def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}):
         super().__init__()
         self.dtypes = set()
         self.dtypes.add(dtype)
-        self.compat_mode = False
-        self.text_projection_type = text_projection_type
 
         self.gemma3_12b = Gemma3_12BModel(device=device, dtype=dtype_llama, model_options=model_options, layer="all", layer_idx=None)
         self.dtypes.add(dtype_llama)
 
         operations = self.gemma3_12b.operations # TODO
 
-        if self.text_projection_type == "single_linear":
         self.text_embedding_projection = operations.Linear(3840 * 49, 3840, bias=False, dtype=dtype, device=device)
-        elif self.text_projection_type == "dual_linear":
-            self.text_embedding_projection = DualLinearProjection(3840 * 49, 4096, 2048, dtype=dtype, device=device, operations=operations)
-
-
-    def enable_compat_mode(self): # TODO: remove
-        from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
-        operations = self.gemma3_12b.operations
-        dtype = self.text_embedding_projection.weight.dtype
-        device = self.text_embedding_projection.weight.device
         self.audio_embeddings_connector = Embeddings1DConnector(
             split_rope=True,
             double_precision_rope=True,
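The deleted `DualLinearProjection.forward` RMS-normalizes the stacked hidden states over the feature axis, flattens the 49 layers into one vector per token, and rescales each branch's input by `sqrt(out_features / source_dim)`. A standalone sketch with plain `nn.Linear` layers and hypothetical shapes:

```python
import torch

# x stacks 49 hidden-state layers of a 3840-wide model: [batch, layers, tokens, hidden]
batch, layers, tokens, hidden = 1, 49, 8, 3840
x = torch.randn(batch, layers, tokens, hidden)

source_dim = x.shape[-1]
x = x.movedim(1, -1)  # [batch, tokens, hidden, layers]
x = x * torch.rsqrt(torch.mean(x ** 2, dim=2, keepdim=True) + 1e-6)  # RMS-normalize over hidden
x = x.flatten(start_dim=2)  # [batch, tokens, hidden * layers]

video_proj = torch.nn.Linear(hidden * layers, 4096, bias=True)
audio_proj = torch.nn.Linear(hidden * layers, 2048, bias=True)
video = video_proj(x * (4096 / source_dim) ** 0.5)  # per-branch input rescaling
audio = audio_proj(x * (2048 / source_dim) ** 0.5)
out = torch.cat((video, audio), dim=-1)  # [batch, tokens, 6144]
```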
@@ -151,7 +84,6 @@ class LTXAVTEModel(torch.nn.Module):
             device=device,
             operations=operations,
         )
-        self.compat_mode = True
 
     def set_clip_options(self, options):
         self.execution_device = options.get("execution_device", self.execution_device)
@@ -169,57 +101,35 @@ class LTXAVTEModel(torch.nn.Module):
         out_device = out.device
         if comfy.model_management.should_use_bf16(self.execution_device):
             out = out.to(device=self.execution_device, dtype=torch.bfloat16)
 
-        if self.text_projection_type == "single_linear":
         out = out.movedim(1, -1).to(self.execution_device)
         out = 8.0 * (out - out.mean(dim=(1, 2), keepdim=True)) / (out.amax(dim=(1, 2), keepdim=True) - out.amin(dim=(1, 2), keepdim=True) + 1e-6)
         out = out.reshape((out.shape[0], out.shape[1], -1))
         out = self.text_embedding_projection(out)
-        if self.compat_mode:
+        out = out.float()
         out_vid = self.video_embeddings_connector(out)[0]
         out_audio = self.audio_embeddings_connector(out)[0]
         out = torch.concat((out_vid, out_audio), dim=-1)
-                extra = {}
-            else:
-                extra = {"unprocessed_ltxav_embeds": True}
-        elif self.text_projection_type == "dual_linear":
-            out = self.text_embedding_projection(out)
-            extra = {"unprocessed_ltxav_embeds": True}
 
-        return out.to(device=out_device, dtype=torch.float), pooled, extra
+        return out.to(out_device), pooled
 
-    def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
-        return self.gemma3_12b.generate(tokens["gemma3_12b"], do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed)
-
     def load_sd(self, sd):
         if "model.layers.47.self_attn.q_norm.weight" in sd:
             return self.gemma3_12b.load_sd(sd)
         else:
-            sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight", "text_embedding_projection.": "text_embedding_projection."}, filter_keys=True)
+            sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight", "model.diffusion_model.video_embeddings_connector.": "video_embeddings_connector.", "model.diffusion_model.audio_embeddings_connector.": "audio_embeddings_connector."}, filter_keys=True)
             if len(sdo) == 0:
                 sdo = sd
 
             missing_all = []
             unexpected_all = []
 
-            for prefix, component in [("text_embedding_projection.", self.text_embedding_projection)]:
+            for prefix, component in [("text_embedding_projection.", self.text_embedding_projection), ("video_embeddings_connector.", self.video_embeddings_connector), ("audio_embeddings_connector.", self.audio_embeddings_connector)]:
                 component_sd = {k.replace(prefix, ""): v for k, v in sdo.items() if k.startswith(prefix)}
                 if component_sd:
                     missing, unexpected = component.load_state_dict(component_sd, strict=False, assign=getattr(self, "can_assign_sd", False))
                     missing_all.extend([f"{prefix}{k}" for k in missing])
                     unexpected_all.extend([f"{prefix}{k}" for k in unexpected])
 
-            if "model.diffusion_model.audio_embeddings_connector.transformer_1d_blocks.2.attn1.to_q.bias" not in sd: # TODO: remove
-                ww = sd.get("model.diffusion_model.audio_embeddings_connector.transformer_1d_blocks.0.attn1.to_q.bias", None)
-                if ww is not None:
-                    if ww.shape[0] == 3840:
-                        self.enable_compat_mode()
-                        sdv = comfy.utils.state_dict_prefix_replace(sd, {"model.diffusion_model.video_embeddings_connector.": ""}, filter_keys=True)
-                        self.video_embeddings_connector.load_state_dict(sdv, strict=False, assign=getattr(self, "can_assign_sd", False))
-                        sda = comfy.utils.state_dict_prefix_replace(sd, {"model.diffusion_model.audio_embeddings_connector.": ""}, filter_keys=True)
-                        self.audio_embeddings_connector.load_state_dict(sda, strict=False, assign=getattr(self, "can_assign_sd", False))
-
             return (missing_all, unexpected_all)
 
     def memory_estimation_function(self, token_weight_pairs, device=None):
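The surviving projection path range-normalizes the stacked embeddings before the linear projection: per sample, subtract the mean and divide by (max - min), scaled by 8.0. A quick illustration with hypothetical shapes:

```python
import torch

# Center each sample and divide by its value range so every prompt's
# stacked hidden states occupy a comparable dynamic range (~8 wide).
out = torch.randn(2, 16, 3840)  # hypothetical [batch, tokens, features]
centered = out - out.mean(dim=(1, 2), keepdim=True)
rng = out.amax(dim=(1, 2), keepdim=True) - out.amin(dim=(1, 2), keepdim=True)
normed = 8.0 * centered / (rng + 1e-6)
print(normed.amax().item() - normed.amin().item())  # spread is ~8 per sample
```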
@@ -228,13 +138,11 @@ class LTXAVTEModel(torch.nn.Module):
             constant /= 2.0
 
         token_weight_pairs = token_weight_pairs.get("gemma3_12b", [])
-        m = min([sum(1 for _ in itertools.takewhile(lambda x: x[0] == 0, sub)) for sub in token_weight_pairs])
-
-        num_tokens = sum(map(lambda a: len(a), token_weight_pairs)) - m
-        num_tokens = max(num_tokens, 642)
+        num_tokens = sum(map(lambda a: len(a), token_weight_pairs))
+        num_tokens = max(num_tokens, 64)
         return num_tokens * constant * 1024 * 1024
 
-def ltxav_te(dtype_llama=None, llama_quantization_metadata=None, text_projection_type="single_linear"):
+def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):
     class LTXAVTEModel_(LTXAVTEModel):
         def __init__(self, device="cpu", dtype=None, model_options={}):
             if llama_quantization_metadata is not None:
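The new-side estimate charges every token (floored at 64) at `constant` megabytes each, instead of subtracting the left-padding run with `itertools.takewhile` and flooring at 642. A worked check under an assumed per-token constant:

```python
# Worked example of the revised memory estimate (constant is hypothetical):
constant = 2.2                    # MB charged per token
num_tokens = max(200, 64)         # a 200-token prompt clears the floor
estimate_bytes = num_tokens * constant * 1024 * 1024
print(estimate_bytes / (1024 ** 3))  # ~0.43 GiB
```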
@@ -242,26 +150,5 @@ def ltxav_te(dtype_llama=None, llama_quantization_metadata=None, text_projection
                 model_options["llama_quantization_metadata"] = llama_quantization_metadata
             if dtype_llama is not None:
                 dtype = dtype_llama
-            super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, text_projection_type=text_projection_type, model_options=model_options)
+            super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
     return LTXAVTEModel_
-
-
-def sd_detect(state_dict_list, prefix=""):
-    for sd in state_dict_list:
-        if "{}text_embedding_projection.audio_aggregate_embed.bias".format(prefix) in sd:
-            return {"text_projection_type": "dual_linear"}
-        if "{}text_embedding_projection.weight".format(prefix) in sd or "{}text_embedding_projection.aggregate_embed.weight".format(prefix) in sd:
-            return {"text_projection_type": "single_linear"}
-    return {}
-
-
-def gemma3_te(dtype_llama=None, llama_quantization_metadata=None):
-    class Gemma3_12BModel_(Gemma3_12BModel):
-        def __init__(self, device="cpu", dtype=None, model_options={}):
-            if llama_quantization_metadata is not None:
-                model_options = model_options.copy()
-                model_options["llama_quantization_metadata"] = llama_quantization_metadata
-            if dtype_llama is not None:
-                dtype = dtype_llama
-            super().__init__(device=device, dtype=dtype, model_options=model_options)
-    return Gemma3_12BModel_
@@ -1,23 +1,23 @@
 from comfy import sd1_clip
 from .spiece_tokenizer import SPieceTokenizer
 import comfy.text_encoders.llama
-from comfy.text_encoders.lt import Gemma3_Tokenizer
-import comfy.utils
 
 class Gemma2BTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         tokenizer = tokenizer_data.get("spiece_model", None)
-        special_tokens = {"<end_of_turn>": 107}
-        super().__init__(tokenizer, pad_with_end=False, embedding_size=2304, embedding_key='gemma2_2b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False, "special_tokens": special_tokens}, tokenizer_data=tokenizer_data)
+        super().__init__(tokenizer, pad_with_end=False, embedding_size=2304, embedding_key='gemma2_2b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
 
     def state_dict(self):
         return {"spiece_model": self.tokenizer.serialize_model()}
 
-class Gemma3_4BTokenizer(Gemma3_Tokenizer, sd1_clip.SDTokenizer):
+class Gemma3_4BTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         tokenizer = tokenizer_data.get("spiece_model", None)
-        special_tokens = {"<image_soft_token>": 262144, "<end_of_turn>": 106}
-        super().__init__(tokenizer, pad_with_end=False, embedding_size=2560, embedding_key='gemma3_4b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False, "special_tokens": special_tokens}, disable_weights=True, tokenizer_data=tokenizer_data)
+        super().__init__(tokenizer, pad_with_end=False, embedding_size=2560, embedding_key='gemma3_4b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, disable_weights=True, tokenizer_data=tokenizer_data)
+
+    def state_dict(self):
+        return {"spiece_model": self.tokenizer.serialize_model()}
 
 class LuminaTokenizer(sd1_clip.SD1Tokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
@@ -40,20 +40,6 @@ class Gemma3_4BModel(sd1_clip.SDClipModel):
 
         super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
 
-class Gemma3_4B_Vision_Model(sd1_clip.SDClipModel):
-    def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
-        llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
-        if llama_quantization_metadata is not None:
-            model_options = model_options.copy()
-            model_options["quantization_metadata"] = llama_quantization_metadata
-
-        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B_Vision, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
-
-    def process_tokens(self, tokens, device):
-        embeds, _, _, embeds_info = super().process_tokens(tokens, device)
-        comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
-        return embeds
-
 class LuminaModel(sd1_clip.SD1ClipModel):
     def __init__(self, device="cpu", dtype=None, model_options={}, name="gemma2_2b", clip_model=Gemma2_2BModel):
         super().__init__(device=device, dtype=dtype, name=name, clip_model=clip_model, model_options=model_options)
@@ -64,8 +50,6 @@ def te(dtype_llama=None, llama_quantization_metadata=None, model_type="gemma2_2b
         model = Gemma2_2BModel
     elif model_type == "gemma3_4b":
         model = Gemma3_4BModel
-    elif model_type == "gemma3_4b_vision":
-        model = Gemma3_4B_Vision_Model
 
     class LuminaTEModel_(LuminaModel):
         def __init__(self, device="cpu", dtype=None, model_options={}):
@@ -6,10 +6,9 @@ class SPieceTokenizer:
     def from_pretrained(path, **kwargs):
         return SPieceTokenizer(path, **kwargs)
 
-    def __init__(self, tokenizer_path, add_bos=False, add_eos=True, special_tokens=None):
+    def __init__(self, tokenizer_path, add_bos=False, add_eos=True):
         self.add_bos = add_bos
         self.add_eos = add_eos
-        self.special_tokens = special_tokens
         import sentencepiece
         if torch.is_tensor(tokenizer_path):
             tokenizer_path = tokenizer_path.numpy().tobytes()
@@ -28,32 +27,8 @@ class SPieceTokenizer:
         return out
 
     def __call__(self, string):
-        if self.special_tokens is not None:
-            import re
-            special_tokens_pattern = '|'.join(re.escape(token) for token in self.special_tokens.keys())
-            if special_tokens_pattern and re.search(special_tokens_pattern, string):
-                parts = re.split(f'({special_tokens_pattern})', string)
-                result = []
-                for part in parts:
-                    if not part:
-                        continue
-                    if part in self.special_tokens:
-                        result.append(self.special_tokens[part])
-                    else:
-                        encoded = self.tokenizer.encode(part, add_bos=False, add_eos=False)
-                        result.extend(encoded)
-                return {"input_ids": result}
-
         out = self.tokenizer.encode(string)
         return {"input_ids": out}
 
-    def decode(self, token_ids, skip_special_tokens=False):
-
-        if skip_special_tokens and self.special_tokens:
-            special_token_ids = set(self.special_tokens.values())
-            token_ids = [tid for tid in token_ids if tid not in special_token_ids]
-
-        return self.tokenizer.decode(token_ids)
-
     def serialize_model(self):
         return torch.ByteTensor(list(self.tokenizer.serialized_model_proto()))
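The deleted `__call__` branch splits the input on literal special-token strings so that markers such as `<end_of_turn>` map to fixed ids rather than being piece-encoded. A self-contained sketch with a stubbed encoder (the helper name and the stub are ours, standing in for the sentencepiece encoder):

```python
import re

def encode_with_special_tokens(string, special_tokens, encode_piece):
    # Split on literal special-token strings; map those to fixed ids and
    # encode everything in between with the provided piece encoder.
    pattern = "|".join(re.escape(tok) for tok in special_tokens)
    result = []
    for part in re.split(f"({pattern})", string):
        if not part:
            continue
        if part in special_tokens:
            result.append(special_tokens[part])
        else:
            result.extend(encode_piece(part))
    return result

fake_encode = lambda s: [len(w) for w in s.split()]  # stand-in for spm encode
ids = encode_with_special_tokens("hi there<end_of_turn>", {"<end_of_turn>": 106}, fake_encode)
print(ids)  # [2, 5, 106]
```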
Some files were not shown because too many files have changed in this diff