mirror of https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-07 22:30:00 +00:00

Compare commits: fix/cpu-me...assets-red
46 Commits
| SHA1 |
|---|
| 810265e011 |
| 0b1b234d90 |
| eb2b38458c |
| 03ddcaa3fa |
| e7bebcc8d0 |
| b2f6532b30 |
| 612893018c |
| c0e26b93cc |
| 11da0e6c46 |
| 1e622d3923 |
| eb78ea0cff |
| 6840ad0bbe |
| 2f0db0e680 |
| 69f6c37868 |
| f484d66eb0 |
| 25f83d7401 |
| 2aafb71388 |
| 902e84d7ad |
| d5e6e2a81f |
| e735a8fd85 |
| 32ce7a70a7 |
| cf950e47ab |
| 724145fb55 |
| 32d4888d99 |
| b16390c2fd |
| 4866bbfd8c |
| e17542b5c7 |
| 0bb6d3a3e9 |
| 6a450a8070 |
| 702cfcde3a |
| 8e9c801940 |
| facda426b4 |
| 65a5992f2d |
| 287da646e5 |
| 63f9f1b11b |
| 9e3f559189 |
| 63c98d0c75 |
| e69a5aa1be |
| e0c063f93e |
| 6db4f4e3f1 |
| 41d364030b |
| fab9b71f5d |
| e5c1de4777 |
| a5ed151e51 |
| e527b72b09 |
| f14129947c |
.coderabbit.yaml (127 lines)
@@ -1,127 +0,0 @@
# yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json
language: "en-US"
early_access: false
tone_instructions: "Only comment on issues introduced by this PR's changes. Do not flag pre-existing problems in moved, re-indented, or reformatted code."

reviews:
  profile: "chill"
  request_changes_workflow: false
  high_level_summary: false
  poem: false
  review_status: false
  review_details: false
  commit_status: true
  collapse_walkthrough: true
  changed_files_summary: false
  sequence_diagrams: false
  estimate_code_review_effort: false
  assess_linked_issues: false
  related_issues: false
  related_prs: false
  suggested_labels: false
  auto_apply_labels: false
  suggested_reviewers: false
  auto_assign_reviewers: false
  in_progress_fortune: false
  enable_prompt_for_ai_agents: true

  path_filters:
    - "!comfy_api_nodes/apis/**"
    - "!**/generated/*.pyi"
    - "!.ci/**"
    - "!script_examples/**"
    - "!**/__pycache__/**"
    - "!**/*.ipynb"
    - "!**/*.png"
    - "!**/*.bat"

  path_instructions:
    - path: "**"
      instructions: |
        IMPORTANT: Only comment on issues directly introduced by this PR's code changes.
        Do NOT flag pre-existing issues in code that was merely moved, re-indented,
        de-indented, or reformatted without logic changes. If code appears in the diff
        only due to whitespace or structural reformatting (e.g., removing a `with:` block),
        treat it as unchanged. Contributors should not feel obligated to address
        pre-existing issues outside the scope of their contribution.
    - path: "comfy/**"
      instructions: |
        Core ML/diffusion engine. Focus on:
        - Backward compatibility (breaking changes affect all custom nodes)
        - Memory management and GPU resource handling
        - Performance implications in hot paths
        - Thread safety for concurrent execution
    - path: "comfy_api_nodes/**"
      instructions: |
        Third-party API integration nodes. Focus on:
        - No hardcoded API keys or secrets
        - Proper error handling for API failures (timeouts, rate limits, auth errors)
        - Correct Pydantic model usage
        - Security of user data passed to external APIs
    - path: "comfy_extras/**"
      instructions: |
        Community-contributed extra nodes. Focus on:
        - Consistency with node patterns (INPUT_TYPES, RETURN_TYPES, FUNCTION, CATEGORY)
        - No breaking changes to existing node interfaces
    - path: "comfy_execution/**"
      instructions: |
        Execution engine (graph execution, caching, jobs). Focus on:
        - Caching correctness
        - Concurrent execution safety
        - Graph validation edge cases
    - path: "nodes.py"
      instructions: |
        Core node definitions (2500+ lines). Focus on:
        - Backward compatibility of NODE_CLASS_MAPPINGS
        - Consistency of INPUT_TYPES return format
    - path: "alembic_db/**"
      instructions: |
        Database migrations. Focus on:
        - Migration safety and rollback support
        - Data preservation during schema changes

  auto_review:
    enabled: true
    auto_incremental_review: true
    drafts: false
    ignore_title_keywords:
      - "WIP"
      - "DO NOT REVIEW"
      - "DO NOT MERGE"

  finishing_touches:
    docstrings:
      enabled: false
    unit_tests:
      enabled: false

  tools:
    ruff:
      enabled: false
    pylint:
      enabled: false
    flake8:
      enabled: false
    gitleaks:
      enabled: true
    shellcheck:
      enabled: false
    markdownlint:
      enabled: false
    yamllint:
      enabled: false
    languagetool:
      enabled: false
    github-checks:
      enabled: true
      timeout_ms: 90000
    ast-grep:
      essential_rules: true

chat:
  auto_reply: true

knowledge_base:
  opt_out: false
  learnings:
    scope: "auto"
.github/ISSUE_TEMPLATE/bug-report.yml (vendored, 2 lines)
@@ -16,7 +16,7 @@ body:

        ## Very Important

        Please make sure that you post ALL your ComfyUI logs in the bug report **even if there is no crash**. Just paste everything. The startup log (everything before "To see the GUI go to: ...") contains critical information to developers trying to help. For a performance issue or crash, paste everything from "got prompt" to the end, including the crash. More is better - always. A bug report without logs will likely be ignored.
        Please make sure that you post ALL your ComfyUI logs in the bug report. A bug report without logs will likely be ignored.
  - type: checkboxes
    id: custom-nodes-test
    attributes:
.github/workflows/release-webhook.yml (vendored, 36 lines)
@@ -7,8 +7,6 @@ on:
jobs:
  send-webhook:
    runs-on: ubuntu-latest
    env:
      DESKTOP_REPO_DISPATCH_TOKEN: ${{ secrets.DESKTOP_REPO_DISPATCH_TOKEN }}
    steps:
      - name: Send release webhook
        env:
@@ -108,37 +106,3 @@ jobs:
            --fail --silent --show-error

          echo "✅ Release webhook sent successfully"

      - name: Send repository dispatch to desktop
        env:
          DISPATCH_TOKEN: ${{ env.DESKTOP_REPO_DISPATCH_TOKEN }}
          RELEASE_TAG: ${{ github.event.release.tag_name }}
          RELEASE_URL: ${{ github.event.release.html_url }}
        run: |
          set -euo pipefail

          if [ -z "${DISPATCH_TOKEN:-}" ]; then
            echo "::error::DESKTOP_REPO_DISPATCH_TOKEN is required but not set."
            exit 1
          fi

          PAYLOAD="$(jq -n \
            --arg release_tag "$RELEASE_TAG" \
            --arg release_url "$RELEASE_URL" \
            '{
              event_type: "comfyui_release_published",
              client_payload: {
                release_tag: $release_tag,
                release_url: $release_url
              }
            }')"

          curl -fsSL \
            -X POST \
            -H "Accept: application/vnd.github+json" \
            -H "Content-Type: application/json" \
            -H "Authorization: Bearer ${DISPATCH_TOKEN}" \
            https://api.github.com/repos/Comfy-Org/desktop/dispatches \
            -d "$PAYLOAD"

          echo "✅ Dispatched ComfyUI release ${RELEASE_TAG} to Comfy-Org/desktop"
.github/workflows/test-assets.yml (vendored, new file, 30 lines)
@@ -0,0 +1,30 @@
name: Assets Tests

on:
  push:
    branches: [ main, master, release/** ]
  pull_request:
    branches: [ main, master, release/** ]

jobs:
  test:
    strategy:
      matrix:
        os: [ubuntu-latest, windows-latest, macos-latest]
    runs-on: ${{ matrix.os }}
    continue-on-error: true
    steps:
      - uses: actions/checkout@v4
      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: '3.12'
      - name: Install requirements
        run: |
          python -m pip install --upgrade pip
          pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
          pip install -r requirements.txt
      - name: Run Assets Tests
        run: |
          pip install -r tests-assets/requirements.txt
          python -m pytest tests-assets -v
@@ -29,7 +29,7 @@ on:
        description: 'python patch version'
        required: true
        type: string
        default: "11"
        default: "9"
# push:
#   branches:
#     - master
.gitignore (vendored, 2 lines)
@@ -11,7 +11,7 @@ extra_model_paths.yaml
/.vs
.vscode/
.idea/
venv*/
venv/
.venv/
/web/extensions/*
!/web/extensions/logging.js.example
@@ -189,6 +189,8 @@ The portable above currently comes with python 3.13 and pytorch cuda 13.0. Updat

[Experimental portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z)

[Portable with pytorch cuda 12.8 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu128.7z).

[Portable with pytorch cuda 12.6 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu126.7z) (Supports Nvidia 10 series and older GPUs).

#### How do I share models between another UI and ComfyUI?

@@ -225,11 +227,11 @@ Put your VAE in: models/vae

AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:

```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm7.1```
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.4```

This is the command to install the nightly with ROCm 7.2 which might have some performance improvements:
This is the command to install the nightly with ROCm 7.1 which might have some performance improvements:

```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.2```
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.1```


### AMD GPUs (Experimental: Windows and Linux), RDNA 3, 3.5 and 4 only.
@@ -17,7 +17,7 @@ from importlib.metadata import version
import requests
from typing_extensions import NotRequired

from utils.install_util import get_missing_requirements_message, get_required_packages_versions
from utils.install_util import get_missing_requirements_message, requirements_path

from comfy.cli_args import DEFAULT_VERSION_STRING
import app.logger

@@ -45,7 +45,25 @@ def get_installed_frontend_version():


def get_required_frontend_version():
    return get_required_packages_versions().get("comfyui-frontend-package", None)
    """Get the required frontend version from requirements.txt."""
    try:
        with open(requirements_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line.startswith("comfyui-frontend-package=="):
                    version_str = line.split("==")[-1]
                    if not is_valid_version(version_str):
                        logging.error(f"Invalid version format in requirements.txt: {version_str}")
                        return None
                    return version_str
            logging.error("comfyui-frontend-package not found in requirements.txt")
            return None
    except FileNotFoundError:
        logging.error("requirements.txt not found. Cannot determine required frontend version.")
        return None
    except Exception as e:
        logging.error(f"Error reading requirements.txt: {e}")
        return None


def check_frontend_version():
@@ -199,7 +217,25 @@ class FrontendManager:

    @classmethod
    def get_required_templates_version(cls) -> str:
        return get_required_packages_versions().get("comfyui-workflow-templates", None)
        """Get the required workflow templates version from requirements.txt."""
        try:
            with open(requirements_path, "r", encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if line.startswith("comfyui-workflow-templates=="):
                        version_str = line.split("==")[-1]
                        if not is_valid_version(version_str):
                            logging.error(f"Invalid templates version format in requirements.txt: {version_str}")
                            return None
                        return version_str
                logging.error("comfyui-workflow-templates not found in requirements.txt")
                return None
        except FileNotFoundError:
            logging.error("requirements.txt not found. Cannot determine required templates version.")
            return None
        except Exception as e:
            logging.error(f"Error reading requirements.txt: {e}")
            return None

    @classmethod
    def default_frontend_path(cls) -> str:
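A minimal, self-contained sketch of the requirements.txt pinning logic the two new helpers share (the real code also calls `is_valid_version`, stood in for here by a simple semver regex; the sample file content is hypothetical):

```python
import re

def parse_pinned_version(requirements_text: str, package: str) -> str | None:
    """Return the version pinned as `package==X.Y.Z`, or None if absent/invalid."""
    for line in requirements_text.splitlines():
        line = line.strip()
        if line.startswith(f"{package}=="):
            version_str = line.split("==")[-1]
            # stand-in for the diff's is_valid_version() check
            if re.fullmatch(r"\d+\.\d+\.\d+", version_str):
                return version_str
            return None
    return None

# hypothetical requirements.txt content, for illustration only
sample = "torch\ncomfyui-frontend-package==1.2.3\n"
assert parse_pinned_version(sample, "comfyui-frontend-package") == "1.2.3"
```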
@@ -1,107 +0,0 @@
from __future__ import annotations

from aiohttp import web

from typing import TYPE_CHECKING, TypedDict
if TYPE_CHECKING:
    from comfy_api.latest._io_public import NodeReplace

from comfy_execution.graph_utils import is_link
import nodes

class NodeStruct(TypedDict):
    inputs: dict[str, str | int | float | bool | tuple[str, int]]
    class_type: str
    _meta: dict[str, str]

def copy_node_struct(node_struct: NodeStruct, empty_inputs: bool = False) -> NodeStruct:
    new_node_struct = node_struct.copy()
    if empty_inputs:
        new_node_struct["inputs"] = {}
    else:
        new_node_struct["inputs"] = node_struct["inputs"].copy()
    new_node_struct["_meta"] = node_struct["_meta"].copy()
    return new_node_struct


class NodeReplaceManager:
    """Manages node replacement registrations."""

    def __init__(self):
        self._replacements: dict[str, list[NodeReplace]] = {}

    def register(self, node_replace: NodeReplace):
        """Register a node replacement mapping."""
        self._replacements.setdefault(node_replace.old_node_id, []).append(node_replace)

    def get_replacement(self, old_node_id: str) -> list[NodeReplace] | None:
        """Get replacements for an old node ID."""
        return self._replacements.get(old_node_id)

    def has_replacement(self, old_node_id: str) -> bool:
        """Check if a replacement exists for an old node ID."""
        return old_node_id in self._replacements

    def apply_replacements(self, prompt: dict[str, NodeStruct]):
        connections: dict[str, list[tuple[str, str, int]]] = {}
        need_replacement: set[str] = set()
        for node_number, node_struct in prompt.items():
            if "class_type" not in node_struct or "inputs" not in node_struct:
                continue
            class_type = node_struct["class_type"]
            # need replacement if not in NODE_CLASS_MAPPINGS and has replacement
            if class_type not in nodes.NODE_CLASS_MAPPINGS.keys() and self.has_replacement(class_type):
                need_replacement.add(node_number)
            # keep track of connections
            for input_id, input_value in node_struct["inputs"].items():
                if is_link(input_value):
                    conn_number = input_value[0]
                    connections.setdefault(conn_number, []).append((node_number, input_id, input_value[1]))
        for node_number in need_replacement:
            node_struct = prompt[node_number]
            class_type = node_struct["class_type"]
            replacements = self.get_replacement(class_type)
            if replacements is None:
                continue
            # just use the first replacement
            replacement = replacements[0]
            new_node_id = replacement.new_node_id
            # if replacement is not a valid node, skip trying to replace it as will only cause confusion
            if new_node_id not in nodes.NODE_CLASS_MAPPINGS.keys():
                continue
            # first, replace node id (class_type)
            new_node_struct = copy_node_struct(node_struct, empty_inputs=True)
            new_node_struct["class_type"] = new_node_id
            # TODO: consider replacing display_name in _meta as well for error reporting purposes; would need to query node schema
            # second, replace inputs
            if replacement.input_mapping is not None:
                for input_map in replacement.input_mapping:
                    if "set_value" in input_map:
                        new_node_struct["inputs"][input_map["new_id"]] = input_map["set_value"]
                    elif "old_id" in input_map:
                        new_node_struct["inputs"][input_map["new_id"]] = node_struct["inputs"][input_map["old_id"]]
            # finalize input replacement
            prompt[node_number] = new_node_struct
            # third, replace outputs
            if replacement.output_mapping is not None:
                # re-mapping outputs requires changing the input values of nodes that receive connections from this one
                if node_number in connections:
                    for conns in connections[node_number]:
                        conn_node_number, conn_input_id, old_output_idx = conns
                        for output_map in replacement.output_mapping:
                            if output_map["old_idx"] == old_output_idx:
                                new_output_idx = output_map["new_idx"]
                                previous_input = prompt[conn_node_number]["inputs"][conn_input_id]
                                previous_input[1] = new_output_idx

    def as_dict(self):
        """Serialize all replacements to dict."""
        return {
            k: [v.as_dict() for v in v_list]
            for k, v_list in self._replacements.items()
        }

    def add_routes(self, routes):
        @routes.get("/node_replacements")
        async def get_node_replacements(request):
            return web.json_response(self.as_dict())
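A self-contained sketch of the input-remapping step that `apply_replacements` performs per replaced node (the mapping keys mirror the dicts read above; the node inputs, input IDs, and mapping are hypothetical examples, not real ComfyUI nodes):

```python
def remap_inputs(old_inputs: dict, input_mapping: list[dict]) -> dict:
    """Build the new node's inputs from an input_mapping, as in apply_replacements."""
    new_inputs = {}
    for m in input_mapping:
        if "set_value" in m:                      # inject a constant
            new_inputs[m["new_id"]] = m["set_value"]
        elif "old_id" in m:                       # carry an old input across
            new_inputs[m["new_id"]] = old_inputs[m["old_id"]]
    return new_inputs

old = {"name": "model.safetensors", "strength": 1.0}   # hypothetical old node inputs
mapping = [{"new_id": "ckpt_name", "old_id": "name"},
           {"new_id": "device", "set_value": "default"}]
assert remap_inputs(old, mapping) == {"ckpt_name": "model.safetensors",
                                      "device": "default"}
```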
@@ -53,7 +53,7 @@ class SubgraphManager:
        return entry_id, entry

    async def load_entry_data(self, entry: SubgraphEntry):
        with open(entry['path'], 'r', encoding='utf-8') as f:
        with open(entry['path'], 'r') as f:
            entry['data'] = f.read()
        return entry
@@ -1,44 +0,0 @@
#version 300 es
precision highp float;

uniform sampler2D u_image0;
uniform float u_float0; // Brightness slider -100..100
uniform float u_float1; // Contrast slider -100..100

in vec2 v_texCoord;
out vec4 fragColor;

const float MID_GRAY = 0.18; // 18% reflectance

// sRGB gamma 2.2 approximation
vec3 srgbToLinear(vec3 c) {
    return pow(max(c, 0.0), vec3(2.2));
}

vec3 linearToSrgb(vec3 c) {
    return pow(max(c, 0.0), vec3(1.0/2.2));
}

float mapBrightness(float b) {
    return clamp(b / 100.0, -1.0, 1.0);
}

float mapContrast(float c) {
    return clamp(c / 100.0 + 1.0, 0.0, 2.0);
}

void main() {
    vec4 orig = texture(u_image0, v_texCoord);

    float brightness = mapBrightness(u_float0);
    float contrast = mapContrast(u_float1);

    vec3 lin = srgbToLinear(orig.rgb);

    lin = (lin - MID_GRAY) * contrast + brightness + MID_GRAY;

    // Convert back to sRGB
    vec3 result = linearToSrgb(clamp(lin, 0.0, 1.0));

    fragColor = vec4(result, orig.a);
}
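The same pivot math as plain Python, for a quick numeric check (the constants mirror the shader; the sample values are arbitrary):

```python
# Mirrors the shader: gamma-2.2 sRGB approximation, pivot at 18% gray.
MID_GRAY = 0.18

def adjust(srgb: float, brightness_slider: float, contrast_slider: float) -> float:
    b = max(-1.0, min(1.0, brightness_slider / 100.0))      # mapBrightness
    c = max(0.0, min(2.0, contrast_slider / 100.0 + 1.0))   # mapContrast
    lin = max(srgb, 0.0) ** 2.2                             # srgbToLinear
    lin = (lin - MID_GRAY) * c + b + MID_GRAY               # pivot around mid-gray
    return min(max(lin, 0.0), 1.0) ** (1.0 / 2.2)           # linearToSrgb

print(adjust(0.5, 0, 0))   # identity: sliders at 0 return ~0.5
print(adjust(0.5, 0, 50))  # contrast 1.5 pushes the tone away from 18% gray
```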
@@ -1,72 +0,0 @@
#version 300 es
precision highp float;

uniform sampler2D u_image0;
uniform vec2 u_resolution;
uniform int u_int0;     // Mode
uniform float u_float0; // Amount (0 to 100)

in vec2 v_texCoord;
out vec4 fragColor;

const int MODE_LINEAR = 0;
const int MODE_RADIAL = 1;
const int MODE_BARREL = 2;
const int MODE_SWIRL = 3;
const int MODE_DIAGONAL = 4;

const float AMOUNT_SCALE = 0.0005;
const float RADIAL_MULT = 4.0;
const float BARREL_MULT = 8.0;
const float INV_SQRT2 = 0.70710678118;

void main() {
    vec2 uv = v_texCoord;
    vec4 original = texture(u_image0, uv);

    float amount = u_float0 * AMOUNT_SCALE;

    if (amount < 0.000001) {
        fragColor = original;
        return;
    }

    // Aspect-corrected coordinates for circular effects
    float aspect = u_resolution.x / u_resolution.y;
    vec2 centered = uv - 0.5;
    vec2 corrected = vec2(centered.x * aspect, centered.y);
    float r = length(corrected);
    vec2 dir = r > 0.0001 ? corrected / r : vec2(0.0);
    vec2 offset = vec2(0.0);

    if (u_int0 == MODE_LINEAR) {
        // Horizontal shift (no aspect correction needed)
        offset = vec2(amount, 0.0);
    }
    else if (u_int0 == MODE_RADIAL) {
        // Outward from center, stronger at edges
        offset = dir * r * amount * RADIAL_MULT;
        offset.x /= aspect; // Convert back to UV space
    }
    else if (u_int0 == MODE_BARREL) {
        // Lens distortion simulation (r² falloff)
        offset = dir * r * r * amount * BARREL_MULT;
        offset.x /= aspect; // Convert back to UV space
    }
    else if (u_int0 == MODE_SWIRL) {
        // Perpendicular to radial (rotational aberration)
        vec2 perp = vec2(-dir.y, dir.x);
        offset = perp * r * amount * RADIAL_MULT;
        offset.x /= aspect; // Convert back to UV space
    }
    else if (u_int0 == MODE_DIAGONAL) {
        // 45° offset (no aspect correction needed)
        offset = vec2(amount, amount) * INV_SQRT2;
    }

    float red = texture(u_image0, uv + offset).r;
    float green = original.g;
    float blue = texture(u_image0, uv - offset).b;

    fragColor = vec4(red, green, blue, original.a);
}
@@ -1,78 +0,0 @@
#version 300 es
precision highp float;

uniform sampler2D u_image0;
uniform float u_float0; // temperature (-100 to 100)
uniform float u_float1; // tint (-100 to 100)
uniform float u_float2; // vibrance (-100 to 100)
uniform float u_float3; // saturation (-100 to 100)

in vec2 v_texCoord;
out vec4 fragColor;

const float INPUT_SCALE = 0.01;
const float TEMP_TINT_PRIMARY = 0.3;
const float TEMP_TINT_SECONDARY = 0.15;
const float VIBRANCE_BOOST = 2.0;
const float SATURATION_BOOST = 2.0;
const float SKIN_PROTECTION = 0.5;
const float EPSILON = 0.001;
const vec3 LUMA_WEIGHTS = vec3(0.299, 0.587, 0.114);

void main() {
    vec4 tex = texture(u_image0, v_texCoord);
    vec3 color = tex.rgb;

    // Scale inputs: -100/100 → -1/1
    float temperature = u_float0 * INPUT_SCALE;
    float tint = u_float1 * INPUT_SCALE;
    float vibrance = u_float2 * INPUT_SCALE;
    float saturation = u_float3 * INPUT_SCALE;

    // Temperature (warm/cool): positive = warm, negative = cool
    color.r += temperature * TEMP_TINT_PRIMARY;
    color.b -= temperature * TEMP_TINT_PRIMARY;

    // Tint (green/magenta): positive = green, negative = magenta
    color.g += tint * TEMP_TINT_PRIMARY;
    color.r -= tint * TEMP_TINT_SECONDARY;
    color.b -= tint * TEMP_TINT_SECONDARY;

    // Single clamp after temperature/tint
    color = clamp(color, 0.0, 1.0);

    // Vibrance with skin protection
    if (vibrance != 0.0) {
        float maxC = max(color.r, max(color.g, color.b));
        float minC = min(color.r, min(color.g, color.b));
        float sat = maxC - minC;
        float gray = dot(color, LUMA_WEIGHTS);

        if (vibrance < 0.0) {
            // Desaturate: -100 → gray
            color = mix(vec3(gray), color, 1.0 + vibrance);
        } else {
            // Boost less saturated colors more
            float vibranceAmt = vibrance * (1.0 - sat);

            // Branchless skin tone protection
            float isWarmTone = step(color.b, color.g) * step(color.g, color.r);
            float warmth = (color.r - color.b) / max(maxC, EPSILON);
            float skinTone = isWarmTone * warmth * sat * (1.0 - sat);
            vibranceAmt *= (1.0 - skinTone * SKIN_PROTECTION);

            color = mix(vec3(gray), color, 1.0 + vibranceAmt * VIBRANCE_BOOST);
        }
    }

    // Saturation
    if (saturation != 0.0) {
        float gray = dot(color, LUMA_WEIGHTS);
        float satMix = saturation < 0.0
            ? 1.0 + saturation                      // -100 → gray
            : 1.0 + saturation * SATURATION_BOOST;  // +100 → 3x boost
        color = mix(vec3(gray), color, satMix);
    }

    fragColor = vec4(clamp(color, 0.0, 1.0), tex.a);
}
@@ -1,94 +0,0 @@
#version 300 es
precision highp float;

uniform sampler2D u_image0;
uniform float u_float0; // Blur radius (0–20, default ~5)
uniform float u_float1; // Edge threshold (0–100, default ~30)
uniform int u_int0;     // Step size (0/1 = every pixel, 2+ = skip pixels)

in vec2 v_texCoord;
out vec4 fragColor;

const int MAX_RADIUS = 20;
const float EPSILON = 0.0001;

// Perceptual luminance
float getLuminance(vec3 rgb) {
    return dot(rgb, vec3(0.299, 0.587, 0.114));
}

vec4 bilateralFilter(vec2 uv, vec2 texelSize, int radius,
                     float sigmaSpatial, float sigmaColor)
{
    vec4 center = texture(u_image0, uv);
    vec3 centerRGB = center.rgb;

    float invSpatial2 = -0.5 / (sigmaSpatial * sigmaSpatial);
    float invColor2 = -0.5 / (sigmaColor * sigmaColor + EPSILON);

    vec3 sumRGB = vec3(0.0);
    float sumWeight = 0.0;

    int step = max(u_int0, 1);
    float radius2 = float(radius * radius);

    for (int dy = -MAX_RADIUS; dy <= MAX_RADIUS; dy++) {
        if (dy < -radius || dy > radius) continue;
        if (abs(dy) % step != 0) continue;

        for (int dx = -MAX_RADIUS; dx <= MAX_RADIUS; dx++) {
            if (dx < -radius || dx > radius) continue;
            if (abs(dx) % step != 0) continue;

            vec2 offset = vec2(float(dx), float(dy));
            float dist2 = dot(offset, offset);
            if (dist2 > radius2) continue;

            vec3 sampleRGB = texture(u_image0, uv + offset * texelSize).rgb;

            // Spatial Gaussian
            float spatialWeight = exp(dist2 * invSpatial2);

            // Perceptual color distance (weighted RGB)
            vec3 diff = sampleRGB - centerRGB;
            float colorDist = dot(diff * diff, vec3(0.299, 0.587, 0.114));
            float colorWeight = exp(colorDist * invColor2);

            float w = spatialWeight * colorWeight;
            sumRGB += sampleRGB * w;
            sumWeight += w;
        }
    }

    vec3 resultRGB = sumRGB / max(sumWeight, EPSILON);
    return vec4(resultRGB, center.a); // preserve center alpha
}

void main() {
    vec2 texelSize = 1.0 / vec2(textureSize(u_image0, 0));

    float radiusF = clamp(u_float0, 0.0, float(MAX_RADIUS));
    int radius = int(radiusF + 0.5);

    if (radius == 0) {
        fragColor = texture(u_image0, v_texCoord);
        return;
    }

    // Edge threshold → color sigma
    // Squared curve for better low-end control
    float t = clamp(u_float1, 0.0, 100.0) / 100.0;
    t *= t;
    float sigmaColor = mix(0.01, 0.5, t);

    // Spatial sigma tied to radius
    float sigmaSpatial = max(radiusF * 0.75, 0.5);

    fragColor = bilateralFilter(
        v_texCoord,
        texelSize,
        radius,
        sigmaSpatial,
        sigmaColor
    );
}
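To see how the two Gaussian terms interact, here is the weight product from `bilateralFilter` evaluated once in Python, using the shader's own sigma derivations (the sample distances are arbitrary):

```python
import math

# radius 5 -> sigmaSpatial = max(5 * 0.75, 0.5) = 3.75;
# edge threshold 30 -> t = 0.3, t^2 = 0.09, sigmaColor = 0.01 + 0.49 * 0.09 ≈ 0.054
sigma_spatial, sigma_color, eps = 3.75, 0.0541, 0.0001
dist2 = 3.0 ** 2                 # neighbour 3 px away
color_dist = 0.01                # luma-weighted squared RGB difference of ~0.1
w_spatial = math.exp(-0.5 * dist2 / sigma_spatial ** 2)
w_color = math.exp(-0.5 * color_dist / (sigma_color ** 2 + eps))
print(w_spatial, w_color, w_spatial * w_color)
# ~0.73 * ~0.19 ≈ 0.14: the colour term dominates, so pixels across an
# edge contribute little while flat regions are smoothed normally.
```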
@@ -1,124 +0,0 @@
#version 300 es
precision highp float;

uniform sampler2D u_image0;
uniform vec2 u_resolution;
uniform float u_float0; // grain amount    [0.0 – 1.0] typical: 0.2–0.8
uniform float u_float1; // grain size      [0.3 – 3.0] lower = finer grain
uniform float u_float2; // color amount    [0.0 – 1.0] 0 = monochrome, 1 = RGB grain
uniform float u_float3; // luminance bias  [0.0 – 1.0] 0 = uniform, 1 = shadows only
uniform int u_int0;     // noise mode      [0 or 1] 0 = smooth, 1 = grainy

in vec2 v_texCoord;
layout(location = 0) out vec4 fragColor0;

// High-quality integer hash (pcg-like)
uint pcg(uint v) {
    uint state = v * 747796405u + 2891336453u;
    uint word = ((state >> ((state >> 28u) + 4u)) ^ state) * 277803737u;
    return (word >> 22u) ^ word;
}

// 2D -> 1D hash input
uint hash2d(uvec2 p) {
    return pcg(p.x + pcg(p.y));
}

// Hash to float [0, 1]
float hashf(uvec2 p) {
    return float(hash2d(p)) / float(0xffffffffu);
}

// Hash to float with offset (for RGB channels)
float hashf(uvec2 p, uint offset) {
    return float(pcg(hash2d(p) + offset)) / float(0xffffffffu);
}

// Convert uniform [0,1] to roughly Gaussian distribution
// Using simple approximation: average of multiple samples
float toGaussian(uvec2 p) {
    float sum = hashf(p, 0u) + hashf(p, 1u) + hashf(p, 2u) + hashf(p, 3u);
    return (sum - 2.0) * 0.7; // Centered, scaled
}

float toGaussian(uvec2 p, uint offset) {
    float sum = hashf(p, offset) + hashf(p, offset + 1u)
              + hashf(p, offset + 2u) + hashf(p, offset + 3u);
    return (sum - 2.0) * 0.7;
}

// Smooth noise with better interpolation
float smoothNoise(vec2 p) {
    vec2 i = floor(p);
    vec2 f = fract(p);

    // Quintic interpolation (less banding than cubic)
    f = f * f * f * (f * (f * 6.0 - 15.0) + 10.0);

    uvec2 ui = uvec2(i);
    float a = toGaussian(ui);
    float b = toGaussian(ui + uvec2(1u, 0u));
    float c = toGaussian(ui + uvec2(0u, 1u));
    float d = toGaussian(ui + uvec2(1u, 1u));

    return mix(mix(a, b, f.x), mix(c, d, f.x), f.y);
}

float smoothNoise(vec2 p, uint offset) {
    vec2 i = floor(p);
    vec2 f = fract(p);

    f = f * f * f * (f * (f * 6.0 - 15.0) + 10.0);

    uvec2 ui = uvec2(i);
    float a = toGaussian(ui, offset);
    float b = toGaussian(ui + uvec2(1u, 0u), offset);
    float c = toGaussian(ui + uvec2(0u, 1u), offset);
    float d = toGaussian(ui + uvec2(1u, 1u), offset);

    return mix(mix(a, b, f.x), mix(c, d, f.x), f.y);
}

void main() {
    vec4 color = texture(u_image0, v_texCoord);

    // Luminance (Rec.709)
    float luma = dot(color.rgb, vec3(0.2126, 0.7152, 0.0722));

    // Grain UV (resolution-independent)
    vec2 grainUV = v_texCoord * u_resolution / max(u_float1, 0.01);
    uvec2 grainPixel = uvec2(grainUV);

    float g;
    vec3 grainRGB;

    if (u_int0 == 1) {
        // Grainy mode: pure hash noise (no interpolation = no banding)
        g = toGaussian(grainPixel);
        grainRGB = vec3(
            toGaussian(grainPixel, 100u),
            toGaussian(grainPixel, 200u),
            toGaussian(grainPixel, 300u)
        );
    } else {
        // Smooth mode: interpolated with quintic curve
        g = smoothNoise(grainUV);
        grainRGB = vec3(
            smoothNoise(grainUV, 100u),
            smoothNoise(grainUV, 200u),
            smoothNoise(grainUV, 300u)
        );
    }

    // Luminance weighting (less grain in highlights)
    float lumWeight = mix(1.0, 1.0 - luma, clamp(u_float3, 0.0, 1.0));

    // Strength
    float strength = u_float0 * 0.15;

    // Color vs monochrome grain
    vec3 grainColor = mix(vec3(g), grainRGB, clamp(u_float2, 0.0, 1.0));

    color.rgb += grainColor * strength * lumWeight;
    fragColor0 = vec4(clamp(color.rgb, 0.0, 1.0), color.a);
}
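The grain is seeded by a PCG-style integer hash. The same bit arithmetic, emulated in Python with explicit uint32 masking (the input coordinates are arbitrary):

```python
M = 0xFFFFFFFF  # emulate GLSL uint32 wraparound

def pcg(v: int) -> int:
    state = (v * 747796405 + 2891336453) & M
    word = (((state >> ((state >> 28) + 4)) ^ state) * 277803737) & M
    return (word >> 22) ^ word

def hashf(x: int, y: int) -> float:
    # hash2d(p) = pcg(p.x + pcg(p.y)), then normalized to [0, 1]
    return pcg((x + pcg(y)) & M) / M

print(hashf(10, 20))  # deterministic pseudo-random value in [0, 1]
```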
@@ -1,133 +0,0 @@
#version 300 es
precision mediump float;

uniform sampler2D u_image0;
uniform vec2 u_resolution;
uniform int u_int0;     // Blend mode
uniform int u_int1;     // Color tint
uniform float u_float0; // Intensity
uniform float u_float1; // Radius
uniform float u_float2; // Threshold

in vec2 v_texCoord;
out vec4 fragColor;

const int BLEND_ADD = 0;
const int BLEND_SCREEN = 1;
const int BLEND_SOFT = 2;
const int BLEND_OVERLAY = 3;
const int BLEND_LIGHTEN = 4;

const float GOLDEN_ANGLE = 2.39996323;
const int MAX_SAMPLES = 48;
const vec3 LUMA = vec3(0.299, 0.587, 0.114);

float hash(vec2 p) {
    p = fract(p * vec2(123.34, 456.21));
    p += dot(p, p + 45.32);
    return fract(p.x * p.y);
}

vec3 hexToRgb(int h) {
    return vec3(
        float((h >> 16) & 255),
        float((h >> 8) & 255),
        float(h & 255)
    ) * (1.0 / 255.0);
}

vec3 blend(vec3 base, vec3 glow, int mode) {
    if (mode == BLEND_SCREEN) {
        return 1.0 - (1.0 - base) * (1.0 - glow);
    }
    if (mode == BLEND_SOFT) {
        return mix(
            base - (1.0 - 2.0 * glow) * base * (1.0 - base),
            base + (2.0 * glow - 1.0) * (sqrt(base) - base),
            step(0.5, glow)
        );
    }
    if (mode == BLEND_OVERLAY) {
        return mix(
            2.0 * base * glow,
            1.0 - 2.0 * (1.0 - base) * (1.0 - glow),
            step(0.5, base)
        );
    }
    if (mode == BLEND_LIGHTEN) {
        return max(base, glow);
    }
    return base + glow;
}

void main() {
    vec4 original = texture(u_image0, v_texCoord);

    float intensity = u_float0 * 0.05;
    float radius = u_float1 * u_float1 * 0.012;

    if (intensity < 0.001 || radius < 0.1) {
        fragColor = original;
        return;
    }

    float threshold = 1.0 - u_float2 * 0.01;
    float t0 = threshold - 0.15;
    float t1 = threshold + 0.15;

    vec2 texelSize = 1.0 / u_resolution;
    float radius2 = radius * radius;

    float sampleScale = clamp(radius * 0.75, 0.35, 1.0);
    int samples = int(float(MAX_SAMPLES) * sampleScale);

    float noise = hash(gl_FragCoord.xy);
    float angleOffset = noise * GOLDEN_ANGLE;
    float radiusJitter = 0.85 + noise * 0.3;

    float ca = cos(GOLDEN_ANGLE);
    float sa = sin(GOLDEN_ANGLE);
    vec2 dir = vec2(cos(angleOffset), sin(angleOffset));

    vec3 glow = vec3(0.0);
    float totalWeight = 0.0;

    // Center tap
    float centerMask = smoothstep(t0, t1, dot(original.rgb, LUMA));
    glow += original.rgb * centerMask * 2.0;
    totalWeight += 2.0;

    for (int i = 1; i < MAX_SAMPLES; i++) {
        if (i >= samples) break;

        float fi = float(i);
        float dist = sqrt(fi / float(samples)) * radius * radiusJitter;

        vec2 offset = dir * dist * texelSize;
        vec3 c = texture(u_image0, v_texCoord + offset).rgb;
        float mask = smoothstep(t0, t1, dot(c, LUMA));

        float w = 1.0 - (dist * dist) / (radius2 * 1.5);
        w = max(w, 0.0);
        w *= w;

        glow += c * mask * w;
        totalWeight += w;

        dir = vec2(
            dir.x * ca - dir.y * sa,
            dir.x * sa + dir.y * ca
        );
    }

    glow *= intensity / max(totalWeight, 0.001);

    if (u_int1 > 0) {
        glow *= hexToRgb(u_int1);
    }

    vec3 result = blend(original.rgb, glow, u_int0);
    result += (noise - 0.5) * (1.0 / 255.0);

    fragColor = vec4(clamp(result, 0.0, 1.0), original.a);
}
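The `hexToRgb` helper above unpacks an integer tint into normalized RGB. A quick check of the bit arithmetic in Python (0xFF8800 is an arbitrary example value):

```python
def hex_to_rgb(h: int) -> tuple[float, float, float]:
    # same shifts and masks as the shader's hexToRgb
    return ((h >> 16 & 255) / 255.0, (h >> 8 & 255) / 255.0, (h & 255) / 255.0)

print(hex_to_rgb(0xFF8800))  # (1.0, 0.533..., 0.0): an orange tint
```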
@@ -1,222 +0,0 @@
#version 300 es
precision highp float;

uniform sampler2D u_image0;
uniform int u_int0;     // Mode: 0=Master, 1=Reds, 2=Yellows, 3=Greens, 4=Cyans, 5=Blues, 6=Magentas, 7=Colorize
uniform int u_int1;     // Color Space: 0=HSL, 1=HSB/HSV
uniform float u_float0; // Hue (-180 to 180)
uniform float u_float1; // Saturation (-100 to 100)
uniform float u_float2; // Lightness/Brightness (-100 to 100)
uniform float u_float3; // Overlap (0 to 100) - feathering between adjacent color ranges

in vec2 v_texCoord;
out vec4 fragColor;

// Color range modes
const int MODE_MASTER = 0;
const int MODE_RED = 1;
const int MODE_YELLOW = 2;
const int MODE_GREEN = 3;
const int MODE_CYAN = 4;
const int MODE_BLUE = 5;
const int MODE_MAGENTA = 6;
const int MODE_COLORIZE = 7;

// Color space modes
const int COLORSPACE_HSL = 0;
const int COLORSPACE_HSB = 1;

const float EPSILON = 0.0001;

//=============================================================================
// RGB <-> HSL Conversions
//=============================================================================

vec3 rgb2hsl(vec3 c) {
    float maxC = max(max(c.r, c.g), c.b);
    float minC = min(min(c.r, c.g), c.b);
    float delta = maxC - minC;

    float h = 0.0;
    float s = 0.0;
    float l = (maxC + minC) * 0.5;

    if (delta > EPSILON) {
        s = l < 0.5
            ? delta / (maxC + minC)
            : delta / (2.0 - maxC - minC);

        if (maxC == c.r) {
            h = (c.g - c.b) / delta + (c.g < c.b ? 6.0 : 0.0);
        } else if (maxC == c.g) {
            h = (c.b - c.r) / delta + 2.0;
        } else {
            h = (c.r - c.g) / delta + 4.0;
        }
        h /= 6.0;
    }

    return vec3(h, s, l);
}

float hue2rgb(float p, float q, float t) {
    t = fract(t);
    if (t < 1.0/6.0) return p + (q - p) * 6.0 * t;
    if (t < 0.5) return q;
    if (t < 2.0/3.0) return p + (q - p) * (2.0/3.0 - t) * 6.0;
    return p;
}

vec3 hsl2rgb(vec3 hsl) {
    if (hsl.y < EPSILON) return vec3(hsl.z);

    float q = hsl.z < 0.5
        ? hsl.z * (1.0 + hsl.y)
        : hsl.z + hsl.y - hsl.z * hsl.y;
    float p = 2.0 * hsl.z - q;

    return vec3(
        hue2rgb(p, q, hsl.x + 1.0/3.0),
        hue2rgb(p, q, hsl.x),
        hue2rgb(p, q, hsl.x - 1.0/3.0)
    );
}

vec3 rgb2hsb(vec3 c) {
    float maxC = max(max(c.r, c.g), c.b);
    float minC = min(min(c.r, c.g), c.b);
    float delta = maxC - minC;

    float h = 0.0;
    float s = (maxC > EPSILON) ? delta / maxC : 0.0;
    float b = maxC;

    if (delta > EPSILON) {
        if (maxC == c.r) {
            h = (c.g - c.b) / delta + (c.g < c.b ? 6.0 : 0.0);
        } else if (maxC == c.g) {
            h = (c.b - c.r) / delta + 2.0;
        } else {
            h = (c.r - c.g) / delta + 4.0;
        }
        h /= 6.0;
    }

    return vec3(h, s, b);
}

vec3 hsb2rgb(vec3 hsb) {
    vec3 rgb = clamp(abs(mod(hsb.x * 6.0 + vec3(0.0, 4.0, 2.0), 6.0) - 3.0) - 1.0, 0.0, 1.0);
    return hsb.z * mix(vec3(1.0), rgb, hsb.y);
}

//=============================================================================
// Color Range Weight Calculation
//=============================================================================

float hueDistance(float a, float b) {
    float d = abs(a - b);
    return min(d, 1.0 - d);
}

float getHueWeight(float hue, float center, float overlap) {
    float baseWidth = 1.0 / 6.0;
    float feather = baseWidth * overlap;

    float d = hueDistance(hue, center);

    float inner = baseWidth * 0.5;
    float outer = inner + feather;

    return 1.0 - smoothstep(inner, outer, d);
}

float getModeWeight(float hue, int mode, float overlap) {
    if (mode == MODE_MASTER || mode == MODE_COLORIZE) return 1.0;

    if (mode == MODE_RED) {
        return max(
            getHueWeight(hue, 0.0, overlap),
            getHueWeight(hue, 1.0, overlap)
        );
    }

    float center = float(mode - 1) / 6.0;
    return getHueWeight(hue, center, overlap);
}

//=============================================================================
// Adjustment Functions
//=============================================================================

float adjustLightness(float l, float amount) {
    return amount > 0.0
        ? l + (1.0 - l) * amount
        : l + l * amount;
}

float adjustBrightness(float b, float amount) {
    return clamp(b + amount, 0.0, 1.0);
}

float adjustSaturation(float s, float amount) {
    return amount > 0.0
        ? s + (1.0 - s) * amount
        : s + s * amount;
}

vec3 colorize(vec3 rgb, float hue, float sat, float light) {
    float lum = dot(rgb, vec3(0.299, 0.587, 0.114));
    float l = adjustLightness(lum, light);

    vec3 hsl = vec3(fract(hue), clamp(sat, 0.0, 1.0), clamp(l, 0.0, 1.0));
    return hsl2rgb(hsl);
}

//=============================================================================
// Main
//=============================================================================

void main() {
    vec4 original = texture(u_image0, v_texCoord);

    float hueShift = u_float0 / 360.0;    // -180..180 -> -0.5..0.5
    float satAmount = u_float1 / 100.0;   // -100..100 -> -1..1
    float lightAmount = u_float2 / 100.0; // -100..100 -> -1..1
    float overlap = u_float3 / 100.0;     // 0..100 -> 0..1

    vec3 result;

    if (u_int0 == MODE_COLORIZE) {
        result = colorize(original.rgb, hueShift, satAmount, lightAmount);
        fragColor = vec4(result, original.a);
        return;
    }

    vec3 hsx = (u_int1 == COLORSPACE_HSL)
        ? rgb2hsl(original.rgb)
        : rgb2hsb(original.rgb);

    float weight = getModeWeight(hsx.x, u_int0, overlap);

    if (u_int0 != MODE_MASTER && hsx.y < EPSILON) {
        weight = 0.0;
    }

    if (weight > EPSILON) {
        float h = fract(hsx.x + hueShift * weight);
        float s = clamp(adjustSaturation(hsx.y, satAmount * weight), 0.0, 1.0);
        float v = (u_int1 == COLORSPACE_HSL)
            ? clamp(adjustLightness(hsx.z, lightAmount * weight), 0.0, 1.0)
            : clamp(adjustBrightness(hsx.z, lightAmount * weight), 0.0, 1.0);

        vec3 adjusted = vec3(h, s, v);
        result = (u_int1 == COLORSPACE_HSL)
            ? hsl2rgb(adjusted)
            : hsb2rgb(adjusted);
    } else {
        result = original.rgb;
    }

    fragColor = vec4(result, original.a);
}
@@ -1,111 +0,0 @@
#version 300 es
#pragma passes 2
precision highp float;

// Blur type constants
const int BLUR_GAUSSIAN = 0;
const int BLUR_BOX = 1;
const int BLUR_RADIAL = 2;

// Radial blur config
const int RADIAL_SAMPLES = 12;
const float RADIAL_STRENGTH = 0.0003;

uniform sampler2D u_image0;
uniform vec2 u_resolution;
uniform int u_int0;     // Blur type (BLUR_GAUSSIAN, BLUR_BOX, BLUR_RADIAL)
uniform float u_float0; // Blur radius/amount
uniform int u_pass;     // Pass index (0 = horizontal, 1 = vertical)

in vec2 v_texCoord;
layout(location = 0) out vec4 fragColor0;

float gaussian(float x, float sigma) {
    return exp(-(x * x) / (2.0 * sigma * sigma));
}

void main() {
    vec2 texelSize = 1.0 / u_resolution;
    float radius = max(u_float0, 0.0);

    // Radial (angular) blur - single pass, doesn't use separable
    if (u_int0 == BLUR_RADIAL) {
        // Only execute on first pass
        if (u_pass > 0) {
            fragColor0 = texture(u_image0, v_texCoord);
            return;
        }

        vec2 center = vec2(0.5);
        vec2 dir = v_texCoord - center;
        float dist = length(dir);

        if (dist < 1e-4) {
            fragColor0 = texture(u_image0, v_texCoord);
            return;
        }

        vec4 sum = vec4(0.0);
        float totalWeight = 0.0;
        float angleStep = radius * RADIAL_STRENGTH;

        dir /= dist;

        float cosStep = cos(angleStep);
        float sinStep = sin(angleStep);

        float negAngle = -float(RADIAL_SAMPLES) * angleStep;
        vec2 rotDir = vec2(
            dir.x * cos(negAngle) - dir.y * sin(negAngle),
            dir.x * sin(negAngle) + dir.y * cos(negAngle)
        );

        for (int i = -RADIAL_SAMPLES; i <= RADIAL_SAMPLES; i++) {
            vec2 uv = center + rotDir * dist;
            float w = 1.0 - abs(float(i)) / float(RADIAL_SAMPLES);
            sum += texture(u_image0, uv) * w;
            totalWeight += w;

            rotDir = vec2(
                rotDir.x * cosStep - rotDir.y * sinStep,
                rotDir.x * sinStep + rotDir.y * cosStep
            );
        }

        fragColor0 = sum / max(totalWeight, 0.001);
        return;
    }

    // Separable Gaussian / Box blur
    int samples = int(ceil(radius));

    if (samples == 0) {
        fragColor0 = texture(u_image0, v_texCoord);
        return;
    }

    // Direction: pass 0 = horizontal, pass 1 = vertical
    vec2 dir = (u_pass == 0) ? vec2(1.0, 0.0) : vec2(0.0, 1.0);

    vec4 color = vec4(0.0);
    float totalWeight = 0.0;
    float sigma = radius / 2.0;

    for (int i = -samples; i <= samples; i++) {
        vec2 offset = dir * float(i) * texelSize;
        vec4 sample_color = texture(u_image0, v_texCoord + offset);

        float weight;
        if (u_int0 == BLUR_GAUSSIAN) {
            weight = gaussian(float(i), sigma);
        } else {
            // BLUR_BOX
            weight = 1.0;
        }

        color += sample_color * weight;
        totalWeight += weight;
    }

    fragColor0 = color / totalWeight;
}
@@ -1,19 +0,0 @@
#version 300 es
precision highp float;

uniform sampler2D u_image0;

in vec2 v_texCoord;
layout(location = 0) out vec4 fragColor0;
layout(location = 1) out vec4 fragColor1;
layout(location = 2) out vec4 fragColor2;
layout(location = 3) out vec4 fragColor3;

void main() {
    vec4 color = texture(u_image0, v_texCoord);
    // Output each channel as grayscale to separate render targets
    fragColor0 = vec4(vec3(color.r), 1.0); // Red channel
    fragColor1 = vec4(vec3(color.g), 1.0); // Green channel
    fragColor2 = vec4(vec3(color.b), 1.0); // Blue channel
    fragColor3 = vec4(vec3(color.a), 1.0); // Alpha channel
}
@@ -1,71 +0,0 @@
#version 300 es
precision highp float;

// Levels Adjustment
// u_int0:   channel (0=RGB, 1=R, 2=G, 3=B) default: 0
// u_float0: input black (0-255)   default: 0
// u_float1: input white (0-255)   default: 255
// u_float2: gamma (0.01-9.99)     default: 1.0
// u_float3: output black (0-255)  default: 0
// u_float4: output white (0-255)  default: 255

uniform sampler2D u_image0;
uniform int u_int0;
uniform float u_float0;
uniform float u_float1;
uniform float u_float2;
uniform float u_float3;
uniform float u_float4;

in vec2 v_texCoord;
out vec4 fragColor;

vec3 applyLevels(vec3 color, float inBlack, float inWhite, float gamma, float outBlack, float outWhite) {
    float inRange = max(inWhite - inBlack, 0.0001);
    vec3 result = clamp((color - inBlack) / inRange, 0.0, 1.0);
    result = pow(result, vec3(1.0 / gamma));
    result = mix(vec3(outBlack), vec3(outWhite), result);
    return result;
}

float applySingleChannel(float value, float inBlack, float inWhite, float gamma, float outBlack, float outWhite) {
    float inRange = max(inWhite - inBlack, 0.0001);
    float result = clamp((value - inBlack) / inRange, 0.0, 1.0);
    result = pow(result, 1.0 / gamma);
    result = mix(outBlack, outWhite, result);
    return result;
}

void main() {
    vec4 texColor = texture(u_image0, v_texCoord);
    vec3 color = texColor.rgb;

    float inBlack = u_float0 / 255.0;
    float inWhite = u_float1 / 255.0;
    float gamma = u_float2;
    float outBlack = u_float3 / 255.0;
    float outWhite = u_float4 / 255.0;

    vec3 result;

    if (u_int0 == 0) {
        result = applyLevels(color, inBlack, inWhite, gamma, outBlack, outWhite);
    }
    else if (u_int0 == 1) {
        result = color;
        result.r = applySingleChannel(color.r, inBlack, inWhite, gamma, outBlack, outWhite);
    }
    else if (u_int0 == 2) {
        result = color;
        result.g = applySingleChannel(color.g, inBlack, inWhite, gamma, outBlack, outWhite);
    }
    else if (u_int0 == 3) {
        result = color;
        result.b = applySingleChannel(color.b, inBlack, inWhite, gamma, outBlack, outWhite);
    }
    else {
        result = color;
    }

    fragColor = vec4(result, texColor.a);
}
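The levels mapping above is a clamp-normalize, then gamma, then a remap to the output range. The same arithmetic in Python, with illustrative slider values (input black 32, input white 224, neutral gamma):

```python
def levels(v: float, in_black=32/255, in_white=224/255, gamma=1.0,
           out_black=0.0, out_white=1.0) -> float:
    in_range = max(in_white - in_black, 1e-4)
    t = min(max((v - in_black) / in_range, 0.0), 1.0)  # clamp-normalize
    t = t ** (1.0 / gamma)                             # gamma
    return out_black + (out_white - out_black) * t     # mix to output range

print(levels(128/255))  # (128-32)/(224-32) = 0.5: mid-gray survives the stretch
```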
@@ -1,28 +0,0 @@
# GLSL Shader Sources

This folder contains the GLSL fragment shaders extracted from blueprint JSON files for easier editing and version control.

## File Naming Convention

`{Blueprint_Name}_{node_id}.frag`

- **Blueprint_Name**: The JSON filename with spaces/special chars replaced by underscores
- **node_id**: The GLSLShader node ID within the subgraph

## Usage

```bash
# Extract shaders from blueprint JSONs to this folder
python update_blueprints.py extract

# Patch edited shaders back into blueprint JSONs
python update_blueprints.py patch
```

## Workflow

1. Run `extract` to pull current shaders from JSONs
2. Edit `.frag` files
3. Run `patch` to update the blueprint JSONs
4. Test
5. Commit both `.frag` files and updated JSONs
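For instance, following the convention above (the blueprint name here is hypothetical; `sanitize_filename` is the helper defined in `update_blueprints.py` below):

```python
import re

def sanitize_filename(name: str) -> str:
    # same rule as update_blueprints.py: non-word, non-hyphen chars become "_"
    return re.sub(r'[^\w\-]', '_', name)

# a blueprint "Film Grain.json" holding a GLSLShader node with id 23:
print(f"{sanitize_filename('Film Grain')}_23.frag")  # Film_Grain_23.frag
```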
@@ -1,28 +0,0 @@
#version 300 es
precision highp float;

uniform sampler2D u_image0;
uniform vec2 u_resolution;
uniform float u_float0; // strength [0.0 – 2.0] typical: 0.3–1.0

in vec2 v_texCoord;
layout(location = 0) out vec4 fragColor0;

void main() {
    vec2 texel = 1.0 / u_resolution;

    // Sample center and neighbors
    vec4 center = texture(u_image0, v_texCoord);
    vec4 top    = texture(u_image0, v_texCoord + vec2( 0.0,    -texel.y));
    vec4 bottom = texture(u_image0, v_texCoord + vec2( 0.0,     texel.y));
    vec4 left   = texture(u_image0, v_texCoord + vec2(-texel.x, 0.0));
    vec4 right  = texture(u_image0, v_texCoord + vec2( texel.x, 0.0));

    // Edge enhancement (Laplacian)
    vec4 edges = center * 4.0 - top - bottom - left - right;

    // Add edges back scaled by strength
    vec4 sharpened = center + edges * u_float0;

    fragColor0 = vec4(clamp(sharpened.rgb, 0.0, 1.0), center.a);
}
@@ -1,61 +0,0 @@
#version 300 es
precision highp float;

uniform sampler2D u_image0;
uniform vec2 u_resolution;
uniform float u_float0; // amount    [0.0 - 3.0] typical: 0.5-1.5
uniform float u_float1; // radius    [0.5 - 10.0] blur radius in pixels
uniform float u_float2; // threshold [0.0 - 0.1] min difference to sharpen

in vec2 v_texCoord;
layout(location = 0) out vec4 fragColor0;

float gaussian(float x, float sigma) {
    return exp(-(x * x) / (2.0 * sigma * sigma));
}

float getLuminance(vec3 color) {
    return dot(color, vec3(0.2126, 0.7152, 0.0722));
}

void main() {
    vec2 texel = 1.0 / u_resolution;
    float radius = max(u_float1, 0.5);
    float amount = u_float0;
    float threshold = u_float2;

    vec4 original = texture(u_image0, v_texCoord);

    // Gaussian blur for the "unsharp" mask
    int samples = int(ceil(radius));
    float sigma = radius / 2.0;

    vec4 blurred = vec4(0.0);
    float totalWeight = 0.0;

    for (int x = -samples; x <= samples; x++) {
        for (int y = -samples; y <= samples; y++) {
            vec2 offset = vec2(float(x), float(y)) * texel;
            vec4 sample_color = texture(u_image0, v_texCoord + offset);

            float dist = length(vec2(float(x), float(y)));
            float weight = gaussian(dist, sigma);
            blurred += sample_color * weight;
            totalWeight += weight;
        }
    }
    blurred /= totalWeight;

    // Unsharp mask = original - blurred
    vec3 mask = original.rgb - blurred.rgb;

    // Luminance-based threshold with smooth falloff
    float lumaDelta = abs(getLuminance(original.rgb) - getLuminance(blurred.rgb));
    float thresholdScale = smoothstep(0.0, threshold, lumaDelta);
    mask *= thresholdScale;

    // Sharpen: original + mask * amount
    vec3 sharpened = original.rgb + mask * amount;

    fragColor0 = vec4(clamp(sharpened, 0.0, 1.0), original.a);
}
@@ -1,159 +0,0 @@
#!/usr/bin/env python3
"""
Shader Blueprint Updater

Syncs GLSL shader files between this folder and blueprint JSON files.

File naming convention:
    {Blueprint Name}_{node_id}.frag

Usage:
    python update_blueprints.py extract   # Extract shaders from JSONs to here
    python update_blueprints.py patch     # Patch shaders back into JSONs
    python update_blueprints.py           # Same as patch (default)
"""

import json
import logging
import sys
import re
from pathlib import Path

logging.basicConfig(level=logging.INFO, format='%(message)s')
logger = logging.getLogger(__name__)

GLSL_DIR = Path(__file__).parent
BLUEPRINTS_DIR = GLSL_DIR.parent


def get_blueprint_files():
    """Get all blueprint JSON files."""
    return sorted(BLUEPRINTS_DIR.glob("*.json"))


def sanitize_filename(name):
    """Convert blueprint name to safe filename."""
    return re.sub(r'[^\w\-]', '_', name)


def extract_shaders():
    """Extract all shaders from blueprint JSONs to this folder."""
    extracted = 0
    for json_path in get_blueprint_files():
        blueprint_name = json_path.stem

        try:
            with open(json_path, 'r') as f:
                data = json.load(f)
        except (json.JSONDecodeError, IOError) as e:
            logger.warning("Skipping %s: %s", json_path.name, e)
            continue

        # Find GLSLShader nodes in subgraphs
        for subgraph in data.get('definitions', {}).get('subgraphs', []):
            for node in subgraph.get('nodes', []):
                if node.get('type') == 'GLSLShader':
                    node_id = node.get('id')
                    widgets = node.get('widgets_values', [])

                    # Find shader code (first string that looks like GLSL)
                    for widget in widgets:
                        if isinstance(widget, str) and widget.startswith('#version'):
                            safe_name = sanitize_filename(blueprint_name)
                            frag_name = f"{safe_name}_{node_id}.frag"
                            frag_path = GLSL_DIR / frag_name

                            with open(frag_path, 'w') as f:
                                f.write(widget)

                            logger.info(" Extracted: %s", frag_name)
                            extracted += 1
                            break

    logger.info("\nExtracted %d shader(s)", extracted)


def patch_shaders():
    """Patch shaders from this folder back into blueprint JSONs."""
    # Build lookup: blueprint_name -> [(node_id, shader_code), ...]
    shader_updates = {}

    for frag_path in sorted(GLSL_DIR.glob("*.frag")):
        # Parse filename: {blueprint_name}_{node_id}.frag
        parts = frag_path.stem.rsplit('_', 1)
        if len(parts) != 2:
            logger.warning("Skipping %s: invalid filename format", frag_path.name)
            continue

        blueprint_name, node_id_str = parts

        try:
            node_id = int(node_id_str)
        except ValueError:
            logger.warning("Skipping %s: invalid node_id", frag_path.name)
            continue

        with open(frag_path, 'r') as f:
            shader_code = f.read()

        if blueprint_name not in shader_updates:
            shader_updates[blueprint_name] = []
        shader_updates[blueprint_name].append((node_id, shader_code))

    # Apply updates to JSON files
    patched = 0
    for json_path in get_blueprint_files():
        blueprint_name = sanitize_filename(json_path.stem)

        if blueprint_name not in shader_updates:
            continue

        try:
            with open(json_path, 'r') as f:
                data = json.load(f)
        except (json.JSONDecodeError, IOError) as e:
            logger.error("Error reading %s: %s", json_path.name, e)
            continue

        modified = False
        for node_id, shader_code in shader_updates[blueprint_name]:
            # Find the node and update
            for subgraph in data.get('definitions', {}).get('subgraphs', []):
                for node in subgraph.get('nodes', []):
                    if node.get('id') == node_id and node.get('type') == 'GLSLShader':
                        widgets = node.get('widgets_values', [])
                        if len(widgets) > 0 and widgets[0] != shader_code:
                            widgets[0] = shader_code
                            modified = True
                            logger.info(" Patched: %s (node %d)", json_path.name, node_id)
                            patched += 1

        if modified:
            with open(json_path, 'w') as f:
                json.dump(data, f)

    if patched == 0:
        logger.info("No changes to apply.")
    else:
        logger.info("\nPatched %d shader(s)", patched)


def main():
    if len(sys.argv) < 2:
        command = "patch"
    else:
        command = sys.argv[1].lower()

    if command == "extract":
        logger.info("Extracting shaders from blueprints...")
        extract_shaders()
    elif command in ("patch", "update", "apply"):
        logger.info("Patching shaders into blueprints...")
        patch_shaders()
    else:
        logger.info(__doc__)
        sys.exit(1)


if __name__ == "__main__":
    main()
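For reference, the filename convention the script round-trips, applied to a hypothetical file:

    # "Color_Grade_7.frag" -> blueprint "Color_Grade", GLSLShader node id 7
    from pathlib import Path
    Path("Color_Grade_7.frag").stem.rsplit("_", 1)   # ['Color_Grade', '7']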
File diff suppressed because one or more lines are too long
@@ -1 +0,0 @@
{"revision": 0, "last_node_id": 29, "last_link_id": 0, "nodes": [{"id": 29, "type": "4c9d6ea4-b912-40e5-8766-6793a9758c53", "pos": [1970, -230], "size": [180, 86], "flags": {}, "order": 5, "mode": 0, "inputs": [{"label": "image", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": null}], "outputs": [{"label": "R", "localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": []}, {"label": "G", "localized_name": "IMAGE1", "name": "IMAGE1", "type": "IMAGE", "links": []}, {"label": "B", "localized_name": "IMAGE2", "name": "IMAGE2", "type": "IMAGE", "links": []}, {"label": "A", "localized_name": "IMAGE3", "name": "IMAGE3", "type": "IMAGE", "links": []}], "title": "Image Channels", "properties": {"proxyWidgets": []}, "widgets_values": []}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "4c9d6ea4-b912-40e5-8766-6793a9758c53", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 28, "lastLinkId": 39, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Image Channels", "inputNode": {"id": -10, "bounding": [1820, -185, 120, 60]}, "outputNode": {"id": -20, "bounding": [2460, -215, 120, 120]}, "inputs": [{"id": "3522932b-2d86-4a1f-a02a-cb29f3a9d7fe", "name": "images.image0", "type": "IMAGE", "linkIds": [39], "localized_name": "images.image0", "label": "image", "pos": [1920, -165]}], "outputs": [{"id": "605cb9c3-b065-4d9b-81d2-3ec331889b2b", "name": "IMAGE0", "type": "IMAGE", "linkIds": [26], "localized_name": "IMAGE0", "label": "R", "pos": [2480, -195]}, {"id": "fb44a77e-0522-43e9-9527-82e7465b3596", "name": "IMAGE1", "type": "IMAGE", "linkIds": [27], "localized_name": "IMAGE1", "label": "G", "pos": [2480, -175]}, {"id": "81460ee6-0131-402a-874f-6bf3001fc4ff", "name": "IMAGE2", "type": "IMAGE", "linkIds": [28], "localized_name": "IMAGE2", "label": "B", "pos": [2480, -155]}, {"id": "ae690246-80d4-4951-b1d9-9306d8a77417", "name": "IMAGE3", "type": "IMAGE", "linkIds": [29], "localized_name": "IMAGE3", "label": "A", "pos": [2480, -135]}], "widgets": [], "nodes": [{"id": 23, "type": "GLSLShader", "pos": [2000, -330], "size": [400, 172], "flags": {}, "order": 0, "mode": 0, "inputs": [{"label": "image", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": 39}, {"localized_name": "fragment_shader", "name": "fragment_shader", "type": "STRING", "widget": {"name": "fragment_shader"}, "link": null}, {"localized_name": "size_mode", "name": "size_mode", "type": "COMFY_DYNAMICCOMBO_V3", "widget": {"name": "size_mode"}, "link": null}, {"label": "image1", "localized_name": "images.image1", "name": "images.image1", "shape": 7, "type": "IMAGE", "link": null}], "outputs": [{"label": "R", "localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": [26]}, {"label": "G", "localized_name": "IMAGE1", "name": "IMAGE1", "type": "IMAGE", "links": [27]}, {"label": "B", "localized_name": "IMAGE2", "name": "IMAGE2", "type": "IMAGE", "links": [28]}, {"label": "A", "localized_name": "IMAGE3", "name": "IMAGE3", "type": "IMAGE", "links": [29]}], "properties": {"Node name for S&R": "GLSLShader"}, "widgets_values": ["#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\nlayout(location = 1) out vec4 fragColor1;\nlayout(location = 2) out vec4 fragColor2;\nlayout(location = 3) out vec4 fragColor3;\n\nvoid main() {\n vec4 color = texture(u_image0, v_texCoord);\n // Output each channel as grayscale to separate render targets\n fragColor0 = 
vec4(vec3(color.r), 1.0); // Red channel\n fragColor1 = vec4(vec3(color.g), 1.0); // Green channel\n fragColor2 = vec4(vec3(color.b), 1.0); // Blue channel\n fragColor3 = vec4(vec3(color.a), 1.0); // Alpha channel\n}\n", "from_input"]}], "groups": [], "links": [{"id": 39, "origin_id": -10, "origin_slot": 0, "target_id": 23, "target_slot": 0, "type": "IMAGE"}, {"id": 26, "origin_id": 23, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "IMAGE"}, {"id": 27, "origin_id": 23, "origin_slot": 1, "target_id": -20, "target_slot": 1, "type": "IMAGE"}, {"id": 28, "origin_id": 23, "origin_slot": 2, "target_id": -20, "target_slot": 2, "type": "IMAGE"}, {"id": 29, "origin_id": 23, "origin_slot": 3, "target_id": -20, "target_slot": 3, "type": "IMAGE"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Image Tools/Color adjust"}]}}
File diff suppressed because one or more lines are too long
@@ -1 +0,0 @@
{"revision": 0, "last_node_id": 15, "last_link_id": 0, "nodes": [{"id": 15, "type": "24d8bbfd-39d4-4774-bff0-3de40cc7a471", "pos": [-1490, 2040], "size": [400, 260], "flags": {}, "order": 0, "mode": 0, "inputs": [{"name": "prompt", "type": "STRING", "widget": {"name": "prompt"}, "link": null}, {"label": "reference images", "name": "images", "type": "IMAGE", "link": null}], "outputs": [{"name": "STRING", "type": "STRING", "links": null}], "title": "Prompt Enhance", "properties": {"proxyWidgets": [["-1", "prompt"]], "cnr_id": "comfy-core", "ver": "0.14.1"}, "widgets_values": [""]}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "24d8bbfd-39d4-4774-bff0-3de40cc7a471", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 15, "lastLinkId": 14, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Prompt Enhance", "inputNode": {"id": -10, "bounding": [-2170, 2110, 138.876953125, 80]}, "outputNode": {"id": -20, "bounding": [-640, 2110, 120, 60]}, "inputs": [{"id": "aeab7216-00e0-4528-a09b-bba50845c5a6", "name": "prompt", "type": "STRING", "linkIds": [11], "pos": [-2051.123046875, 2130]}, {"id": "7b73fd36-aa31-4771-9066-f6c83879994b", "name": "images", "type": "IMAGE", "linkIds": [14], "label": "reference images", "pos": [-2051.123046875, 2150]}], "outputs": [{"id": "c7b0d930-68a1-48d1-b496-0519e5837064", "name": "STRING", "type": "STRING", "linkIds": [13], "pos": [-620, 2130]}], "widgets": [], "nodes": [{"id": 11, "type": "GeminiNode", "pos": [-1560, 1990], "size": [470, 470], "flags": {}, "order": 0, "mode": 0, "inputs": [{"localized_name": "images", "name": "images", "shape": 7, "type": "IMAGE", "link": 14}, {"localized_name": "audio", "name": "audio", "shape": 7, "type": "AUDIO", "link": null}, {"localized_name": "video", "name": "video", "shape": 7, "type": "VIDEO", "link": null}, {"localized_name": "files", "name": "files", "shape": 7, "type": "GEMINI_INPUT_FILES", "link": null}, {"localized_name": "prompt", "name": "prompt", "type": "STRING", "widget": {"name": "prompt"}, "link": 11}, {"localized_name": "model", "name": "model", "type": "COMBO", "widget": {"name": "model"}, "link": null}, {"localized_name": "seed", "name": "seed", "type": "INT", "widget": {"name": "seed"}, "link": null}, {"localized_name": "system_prompt", "name": "system_prompt", "shape": 7, "type": "STRING", "widget": {"name": "system_prompt"}, "link": null}], "outputs": [{"localized_name": "STRING", "name": "STRING", "type": "STRING", "links": [13]}], "properties": {"cnr_id": "comfy-core", "ver": "0.14.1", "Node name for S&R": "GeminiNode"}, "widgets_values": ["", "gemini-3-pro-preview", 42, "randomize", "You are an expert in prompt writing.\nBased on the input, rewrite the user's input into a detailed prompt.\nincluding camera settings, lighting, composition, and style.\nReturn the prompt only"], "color": "#432", "bgcolor": "#653"}], "groups": [], "links": [{"id": 11, "origin_id": -10, "origin_slot": 0, "target_id": 11, "target_slot": 4, "type": "STRING"}, {"id": 13, "origin_id": 11, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "STRING"}, {"id": 14, "origin_id": -10, "origin_slot": 1, "target_id": 11, "target_slot": 0, "type": "IMAGE"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Text generation/Prompt enhance"}]}, "extra": {}}
@@ -1 +0,0 @@
{"revision": 0, "last_node_id": 25, "last_link_id": 0, "nodes": [{"id": 25, "type": "621ba4e2-22a8-482d-a369-023753198b7b", "pos": [4610, -790], "size": [230, 58], "flags": {}, "order": 4, "mode": 0, "inputs": [{"label": "image", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": null}], "outputs": [{"label": "IMAGE", "localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": []}], "title": "Sharpen", "properties": {"proxyWidgets": [["24", "value"]]}, "widgets_values": []}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "621ba4e2-22a8-482d-a369-023753198b7b", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 24, "lastLinkId": 36, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Sharpen", "inputNode": {"id": -10, "bounding": [4090, -825, 120, 60]}, "outputNode": {"id": -20, "bounding": [5150, -825, 120, 60]}, "inputs": [{"id": "37011fb7-14b7-4e0e-b1a0-6a02e8da1fd7", "name": "images.image0", "type": "IMAGE", "linkIds": [34], "localized_name": "images.image0", "label": "image", "pos": [4190, -805]}], "outputs": [{"id": "e9182b3f-635c-4cd4-a152-4b4be17ae4b9", "name": "IMAGE0", "type": "IMAGE", "linkIds": [35], "localized_name": "IMAGE0", "label": "IMAGE", "pos": [5170, -805]}], "widgets": [], "nodes": [{"id": 24, "type": "PrimitiveFloat", "pos": [4280, -1240], "size": [270, 58], "flags": {}, "order": 0, "mode": 0, "inputs": [{"label": "strength", "localized_name": "value", "name": "value", "type": "FLOAT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "FLOAT", "name": "FLOAT", "type": "FLOAT", "links": [36]}], "properties": {"Node name for S&R": "PrimitiveFloat", "min": 0, "max": 3, "precision": 2, "step": 0.05}, "widgets_values": [0.5]}, {"id": 23, "type": "GLSLShader", "pos": [4570, -1240], "size": [370, 192], "flags": {}, "order": 1, "mode": 0, "inputs": [{"label": "image0", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": 34}, {"label": "image1", "localized_name": "images.image1", "name": "images.image1", "shape": 7, "type": "IMAGE", "link": null}, {"label": "u_float0", "localized_name": "floats.u_float0", "name": "floats.u_float0", "shape": 7, "type": "FLOAT", "link": 36}, {"label": "u_float1", "localized_name": "floats.u_float1", "name": "floats.u_float1", "shape": 7, "type": "FLOAT", "link": null}, {"label": "u_int0", "localized_name": "ints.u_int0", "name": "ints.u_int0", "shape": 7, "type": "INT", "link": null}, {"localized_name": "fragment_shader", "name": "fragment_shader", "type": "STRING", "widget": {"name": "fragment_shader"}, "link": null}, {"localized_name": "size_mode", "name": "size_mode", "type": "COMFY_DYNAMICCOMBO_V3", "widget": {"name": "size_mode"}, "link": null}], "outputs": [{"localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": [35]}, {"localized_name": "IMAGE1", "name": "IMAGE1", "type": "IMAGE", "links": null}, {"localized_name": "IMAGE2", "name": "IMAGE2", "type": "IMAGE", "links": null}, {"localized_name": "IMAGE3", "name": "IMAGE3", "type": "IMAGE", "links": null}], "properties": {"Node name for S&R": "GLSLShader"}, "widgets_values": ["#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\nuniform vec2 u_resolution;\nuniform float u_float0; // strength [0.0 – 2.0] typical: 0.3–1.0\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\n\nvoid main() {\n vec2 texel = 1.0 / u_resolution;\n \n // Sample center and neighbors\n vec4 center = texture(u_image0, v_texCoord);\n 
vec4 top = texture(u_image0, v_texCoord + vec2( 0.0, -texel.y));\n vec4 bottom = texture(u_image0, v_texCoord + vec2( 0.0, texel.y));\n vec4 left = texture(u_image0, v_texCoord + vec2(-texel.x, 0.0));\n vec4 right = texture(u_image0, v_texCoord + vec2( texel.x, 0.0));\n \n // Edge enhancement (Laplacian)\n vec4 edges = center * 4.0 - top - bottom - left - right;\n \n // Add edges back scaled by strength\n vec4 sharpened = center + edges * u_float0;\n \n fragColor0 = vec4(clamp(sharpened.rgb, 0.0, 1.0), center.a);\n}", "from_input"]}], "groups": [], "links": [{"id": 36, "origin_id": 24, "origin_slot": 0, "target_id": 23, "target_slot": 2, "type": "FLOAT"}, {"id": 34, "origin_id": -10, "origin_slot": 0, "target_id": 23, "target_slot": 0, "type": "IMAGE"}, {"id": 35, "origin_id": 23, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "IMAGE"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Image Tools/Sharpen"}]}}
File diff suppressed because one or more lines are too long
@@ -1 +0,0 @@
{"revision": 0, "last_node_id": 13, "last_link_id": 0, "nodes": [{"id": 13, "type": "cf95b747-3e17-46cb-8097-cac60ff9b2e1", "pos": [1120, 330], "size": [240, 58], "flags": {}, "order": 3, "mode": 0, "inputs": [{"localized_name": "video", "name": "video", "type": "VIDEO", "link": null}, {"name": "model_name", "type": "COMBO", "widget": {"name": "model_name"}, "link": null}], "outputs": [{"localized_name": "VIDEO", "name": "VIDEO", "type": "VIDEO", "links": []}], "title": "Video Upscale(GAN x4)", "properties": {"proxyWidgets": [["-1", "model_name"]], "cnr_id": "comfy-core", "ver": "0.14.1"}, "widgets_values": ["RealESRGAN_x4plus.safetensors"]}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "cf95b747-3e17-46cb-8097-cac60ff9b2e1", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 13, "lastLinkId": 19, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Video Upscale(GAN x4)", "inputNode": {"id": -10, "bounding": [550, 460, 120, 80]}, "outputNode": {"id": -20, "bounding": [1490, 460, 120, 60]}, "inputs": [{"id": "666d633e-93e7-42dc-8d11-2b7b99b0f2a6", "name": "video", "type": "VIDEO", "linkIds": [10], "localized_name": "video", "pos": [650, 480]}, {"id": "2e23a087-caa8-4d65-99e6-662761aa905a", "name": "model_name", "type": "COMBO", "linkIds": [19], "pos": [650, 500]}], "outputs": [{"id": "0c1768ea-3ec2-412f-9af6-8e0fa36dae70", "name": "VIDEO", "type": "VIDEO", "linkIds": [15], "localized_name": "VIDEO", "pos": [1510, 480]}], "widgets": [], "nodes": [{"id": 2, "type": "ImageUpscaleWithModel", "pos": [1110, 450], "size": [320, 46], "flags": {}, "order": 1, "mode": 0, "inputs": [{"localized_name": "upscale_model", "name": "upscale_model", "type": "UPSCALE_MODEL", "link": 1}, {"localized_name": "image", "name": "image", "type": "IMAGE", "link": 14}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "links": [13]}], "properties": {"cnr_id": "comfy-core", "ver": "0.10.0", "Node name for S&R": "ImageUpscaleWithModel"}}, {"id": 11, "type": "CreateVideo", "pos": [1110, 550], "size": [320, 78], "flags": {}, "order": 3, "mode": 0, "inputs": [{"localized_name": "images", "name": "images", "type": "IMAGE", "link": 13}, {"localized_name": "audio", "name": "audio", "shape": 7, "type": "AUDIO", "link": 16}, {"localized_name": "fps", "name": "fps", "type": "FLOAT", "widget": {"name": "fps"}, "link": 12}], "outputs": [{"localized_name": "VIDEO", "name": "VIDEO", "type": "VIDEO", "links": [15]}], "properties": {"cnr_id": "comfy-core", "ver": "0.10.0", "Node name for S&R": "CreateVideo"}, "widgets_values": [30]}, {"id": 10, "type": "GetVideoComponents", "pos": [1110, 330], "size": [320, 70], "flags": {}, "order": 2, "mode": 0, "inputs": [{"localized_name": "video", "name": "video", "type": "VIDEO", "link": 10}], "outputs": [{"localized_name": "images", "name": "images", "type": "IMAGE", "links": [14]}, {"localized_name": "audio", "name": "audio", "type": "AUDIO", "links": [16]}, {"localized_name": "fps", "name": "fps", "type": "FLOAT", "links": [12]}], "properties": {"cnr_id": "comfy-core", "ver": "0.10.0", "Node name for S&R": "GetVideoComponents"}}, {"id": 1, "type": "UpscaleModelLoader", "pos": [750, 450], "size": [280, 60], "flags": {}, "order": 0, "mode": 0, "inputs": [{"localized_name": "model_name", "name": "model_name", "type": "COMBO", "widget": {"name": "model_name"}, "link": 19}], "outputs": [{"localized_name": "UPSCALE_MODEL", "name": "UPSCALE_MODEL", "type": "UPSCALE_MODEL", "links": [1]}], "properties": {"cnr_id": "comfy-core", "ver": 
"0.10.0", "Node name for S&R": "UpscaleModelLoader", "models": [{"name": "RealESRGAN_x4plus.safetensors", "url": "https://huggingface.co/Comfy-Org/Real-ESRGAN_repackaged/resolve/main/RealESRGAN_x4plus.safetensors", "directory": "upscale_models"}]}, "widgets_values": ["RealESRGAN_x4plus.safetensors"]}], "groups": [], "links": [{"id": 1, "origin_id": 1, "origin_slot": 0, "target_id": 2, "target_slot": 0, "type": "UPSCALE_MODEL"}, {"id": 14, "origin_id": 10, "origin_slot": 0, "target_id": 2, "target_slot": 1, "type": "IMAGE"}, {"id": 13, "origin_id": 2, "origin_slot": 0, "target_id": 11, "target_slot": 0, "type": "IMAGE"}, {"id": 16, "origin_id": 10, "origin_slot": 1, "target_id": 11, "target_slot": 1, "type": "AUDIO"}, {"id": 12, "origin_id": 10, "origin_slot": 2, "target_id": 11, "target_slot": 2, "type": "FLOAT"}, {"id": 10, "origin_id": -10, "origin_slot": 0, "target_id": 10, "target_slot": 0, "type": "VIDEO"}, {"id": 15, "origin_id": 11, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "VIDEO"}, {"id": 19, "origin_id": -10, "origin_slot": 1, "target_id": 1, "target_slot": 0, "type": "COMBO"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Video generation and editing/Enhance video"}]}, "extra": {}}
@@ -25,11 +25,11 @@ class AudioEncoderModel():
        elif model_type == "whisper3":
            self.model = WhisperLargeV3(**model_config)
        self.model.eval()
        self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
        self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
        self.model_sample_rate = 16000

    def load_sd(self, sd):
        return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())
        return self.model.load_state_dict(sd, strict=False)

    def get_sd(self):
        return self.model.state_dict()

comfy/checkpoint_pickle.py (new file, +13)
@@ -0,0 +1,13 @@
import pickle

load = pickle.load

class Empty:
    pass

class Unpickler(pickle.Unpickler):
    def find_class(self, module, name):
        #TODO: safe unpickle
        if module.startswith("pytorch_lightning"):
            return Empty
        return super().find_class(module, name)
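A minimal usage sketch for the new module: torch.load accepts a custom pickle_module, which is how a stubbed Unpickler like the one above is typically wired in (the checkpoint path is illustrative):

    import torch
    import comfy.checkpoint_pickle

    sd = torch.load("legacy_model.ckpt", map_location="cpu",
                    weights_only=False,  # a custom pickle_module requires the full unpickler path
                    pickle_module=comfy.checkpoint_pickle)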
@@ -146,7 +146,6 @@ parser.add_argument("--reserve-vram", type=float, default=None, help="Set the am

parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=None, metavar="NUM_STREAMS", help="Use async weight offloading. An optional argument controls the amount of offload streams. Default is 2. Enabled by default on Nvidia.")
parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.")
parser.add_argument("--disable-dynamic-vram", action="store_true", help="Disable dynamic VRAM and use estimate based model loading.")

parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")

@@ -258,6 +257,3 @@ elif args.fast == []:
# '--fast' is provided with a list of performance features, use that list
else:
    args.fast = set(args.fast)

def enables_dynamic_vram():
    return not args.disable_dynamic_vram and not args.highvram and not args.gpu_only and not args.novram and not args.cpu
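Illustrative invocations of the flags above (the entry-point name is assumed):

    python main.py --async-offload 4         # use four offload streams
    python main.py --disable-dynamic-vram    # fall back to estimate-based model loading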
@@ -47,10 +47,10 @@ class ClipVisionModel():
        self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast)
        self.model.eval()

        self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
        self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)

    def load_sd(self, sd):
        return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())
        return self.model.load_state_dict(sd, strict=False)

    def get_sd(self):
        return self.model.state_dict()

@@ -176,8 +176,6 @@ class InputTypeOptions(TypedDict):
    """COMBO type only. Specifies the configuration for a multi-select widget.
    Available after ComfyUI frontend v1.13.4
    https://github.com/Comfy-Org/ComfyUI_frontend/pull/2987"""
    gradient_stops: NotRequired[list[list[float]]]
    """Gradient color stops for gradientslider display mode. Each stop is [offset, r, g, b] (``FLOAT``)."""


class HiddenInputTypeDict(TypedDict):

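For illustration, a plausible gradient_stops value matching the documented [offset, r, g, b] layout, black fading to white (ranges assumed to be 0-1):

    gradient_stops = [[0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0]]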
@@ -4,25 +4,6 @@ import comfy.utils
import logging


def is_equal(x, y):
    if torch.is_tensor(x) and torch.is_tensor(y):
        return torch.equal(x, y)
    elif isinstance(x, dict) and isinstance(y, dict):
        if x.keys() != y.keys():
            return False
        return all(is_equal(x[k], y[k]) for k in x)
    elif isinstance(x, (list, tuple)) and isinstance(y, (list, tuple)):
        if type(x) is not type(y) or len(x) != len(y):
            return False
        return all(is_equal(a, b) for a, b in zip(x, y))
    else:
        try:
            return x == y
        except Exception:
            logging.warning("comparison issue with COND")
            return False


class CONDRegular:
    def __init__(self, cond):
        self.cond = cond
@@ -103,7 +84,7 @@ class CONDConstant(CONDRegular):
        return self._copy_with(self.cond)

    def can_concat(self, other):
        if not is_equal(self.cond, other.cond):
        if self.cond != other.cond:
            return False
        return True


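Behavioral examples for the is_equal helper above (values illustrative):

    is_equal({"k": torch.ones(2)}, {"k": torch.ones(2)})   # True: tensors go through torch.equal
    is_equal((1, 2), [1, 2])                               # False: tuple vs. list fails the type check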
@@ -214,7 +214,7 @@ class IndexListContextHandler(ContextHandlerABC):
        mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep[0], rtol=0.0001)
        matches = torch.nonzero(mask)
        if torch.numel(matches) == 0:
            return # substep from multi-step sampler: keep self._step from the last full step
            raise Exception("No sample_sigmas matched current timestep; something went wrong.")
        self._step = int(matches[0].item())

    def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:

@@ -203,7 +203,7 @@ class ControlNet(ControlBase):
        self.control_model = control_model
        self.load_device = load_device
        if control_model is not None:
            self.control_model_wrapped = comfy.model_patcher.CoreModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
            self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())

        self.compression_ratio = compression_ratio
        self.global_average_pooling = global_average_pooling
@@ -297,30 +297,6 @@ class ControlNet(ControlBase):
        self.model_sampling_current = None
        super().cleanup()


class QwenFunControlNet(ControlNet):
    def get_control(self, x_noisy, t, cond, batched_number, transformer_options):
        # Fun checkpoints are more sensitive to high strengths in the generic
        # ControlNet merge path. Use a soft response curve so strength=1.0 stays
        # unchanged while >1 grows more gently.
        original_strength = self.strength
        self.strength = math.sqrt(max(self.strength, 0.0))
        try:
            return super().get_control(x_noisy, t, cond, batched_number, transformer_options)
        finally:
            self.strength = original_strength

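    # The soft curve above in numbers (strength -> sqrt(strength)):
    #   0.25 -> 0.50,  1.0 -> 1.00,  2.0 -> ~1.41,  4.0 -> 2.00
    # so strength = 1.0 passes through unchanged while values above 1 are damped.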
    def pre_run(self, model, percent_to_timestep_function):
        super().pre_run(model, percent_to_timestep_function)
        self.set_extra_arg("base_model", model.diffusion_model)

    def copy(self):
        c = QwenFunControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
        c.control_model = self.control_model
        c.control_model_wrapped = self.control_model_wrapped
        self.copy_to(c)
        return c

class ControlLoraOps:
    class Linear(torch.nn.Module, comfy.ops.CastWeightBiasOp):
        def __init__(self, in_features: int, out_features: int, bias: bool = True,
@@ -584,7 +560,6 @@ def load_controlnet_hunyuandit(controlnet_data, model_options={}):
def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False, model_options={}):
    model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
    control_model = comfy.ldm.flux.controlnet.ControlNetFlux(mistoline=mistoline, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
    sd = model_config.process_unet_state_dict(sd)
    control_model = controlnet_load_state_dict(control_model, sd)
    extra_conds = ['y', 'guidance']
    control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
@@ -630,53 +605,6 @@ def load_controlnet_qwen_instantx(sd, model_options={}):
    control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
    return control


def load_controlnet_qwen_fun(sd, model_options={}):
    load_device = comfy.model_management.get_torch_device()
    weight_dtype = comfy.utils.weight_dtype(sd)
    unet_dtype = model_options.get("dtype", weight_dtype)
    manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)

    operations = model_options.get("custom_operations", None)
    if operations is None:
        operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True)

    in_features = sd["control_img_in.weight"].shape[1]
    inner_dim = sd["control_img_in.weight"].shape[0]

    block_weight = sd["control_blocks.0.attn.to_q.weight"]
    attention_head_dim = sd["control_blocks.0.attn.norm_q.weight"].shape[0]
    num_attention_heads = max(1, block_weight.shape[0] // max(1, attention_head_dim))

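    # Worked example with hypothetical shapes: if control_blocks.0.attn.to_q.weight
    # is (3072, 3072) and norm_q.weight is (128,), then attention_head_dim = 128 and
    # num_attention_heads = 3072 // 128 = 24.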
    model = comfy.ldm.qwen_image.controlnet.QwenImageFunControlNetModel(
        control_in_features=in_features,
        inner_dim=inner_dim,
        num_attention_heads=num_attention_heads,
        attention_head_dim=attention_head_dim,
        num_control_blocks=5,
        main_model_double=60,
        injection_layers=(0, 12, 24, 36, 48),
        operations=operations,
        device=comfy.model_management.unet_offload_device(),
        dtype=unet_dtype,
    )
    model = controlnet_load_state_dict(model, sd)

    latent_format = comfy.latent_formats.Wan21()
    control = QwenFunControlNet(
        model,
        compression_ratio=1,
        latent_format=latent_format,
        # Fun checkpoints already expect their own 33-channel context handling.
        # Enabling generic concat_mask injects an extra mask channel at apply-time
        # and breaks the intended fallback packing path.
        concat_mask=False,
        load_device=load_device,
        manual_cast_dtype=manual_cast_dtype,
        extra_conds=[],
    )
    return control

def convert_mistoline(sd):
    return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})

@@ -754,8 +682,6 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
        return load_controlnet_qwen_instantx(controlnet_data, model_options=model_options)
    elif "controlnet_x_embedder.weight" in controlnet_data:
        return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
    elif "control_blocks.0.after_proj.weight" in controlnet_data and "control_img_in.weight" in controlnet_data:
        return load_controlnet_qwen_fun(controlnet_data, model_options=model_options)

    elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
        return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options)

@@ -5,7 +5,7 @@ from scipy import integrate
import torch
from torch import nn
import torchsde
from tqdm.auto import tqdm
from tqdm.auto import trange, tqdm

from . import utils
from . import deis
@@ -13,9 +13,6 @@ from . import sa_solver
import comfy.model_patcher
import comfy.model_sampling

import comfy.memory_management
from comfy.utils import model_trange as trange

def append_zero(x):
    return torch.cat([x, x.new_zeros([1])])


@@ -755,10 +755,6 @@ class ACEAudio(LatentFormat):
    latent_channels = 8
    latent_dimensions = 2

class ACEAudio15(LatentFormat):
    latent_channels = 64
    latent_dimensions = 1

class ChromaRadiance(LatentFormat):
    latent_channels = 3
    spacial_downscale_ratio = 1
@@ -776,10 +772,3 @@ class ChromaRadiance(LatentFormat):

    def process_out(self, latent):
        return latent


class ZImagePixelSpace(ChromaRadiance):
    """Pixel-space latent format for ZImage DCT variant.
    No VAE encoding/decoding — the model operates directly on RGB pixels.
    """
    pass

File diff suppressed because it is too large
@@ -179,8 +179,8 @@ class LLMAdapter(nn.Module):
        if source_attention_mask.ndim == 2:
            source_attention_mask = source_attention_mask.unsqueeze(1).unsqueeze(1)

        x = self.in_proj(self.embed(target_input_ids))
        context = source_hidden_states
        x = self.in_proj(self.embed(target_input_ids, out_dtype=context.dtype))
        position_ids = torch.arange(x.shape[1], device=x.device).unsqueeze(0)
        position_ids_context = torch.arange(context.shape[1], device=x.device).unsqueeze(0)
        position_embeddings = self.rotary_emb(x, position_ids)
@@ -195,20 +195,8 @@ class Anima(MiniTrainDIT):
        super().__init__(*args, **kwargs)
        self.llm_adapter = LLMAdapter(device=kwargs.get("device"), dtype=kwargs.get("dtype"), operations=kwargs.get("operations"))

    def preprocess_text_embeds(self, text_embeds, text_ids, t5xxl_weights=None):
    def preprocess_text_embeds(self, text_embeds, text_ids):
        if text_ids is not None:
            out = self.llm_adapter(text_embeds, text_ids)
            if t5xxl_weights is not None:
                out = out * t5xxl_weights

            if out.shape[1] < 512:
                out = torch.nn.functional.pad(out, (0, 0, 0, 512 - out.shape[1]))
            return out
            return self.llm_adapter(text_embeds, text_ids)
        else:
            return text_embeds

    def forward(self, x, timesteps, context, **kwargs):
        t5xxl_ids = kwargs.pop("t5xxl_ids", None)
        if t5xxl_ids is not None:
            context = self.preprocess_text_embeds(context, t5xxl_ids, t5xxl_weights=kwargs.pop("t5xxl_weights", None))
        return super().forward(x, timesteps, context, **kwargs)

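A note on the F.pad call above: pad tuples apply from the last dimension backwards, so (0, 0, 0, 512 - T) on a (B, T, D) tensor leaves D untouched and right-pads the sequence dimension to 512:

    import torch
    out = torch.zeros(1, 300, 4096)
    torch.nn.functional.pad(out, (0, 0, 0, 512 - out.shape[1])).shape   # (1, 512, 4096)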
@@ -3,6 +3,7 @@ from torch import Tensor, nn

from comfy.ldm.flux.layers import (
    MLPEmbedder,
    RMSNorm,
    ModulationOut,
)

@@ -28,7 +29,7 @@ class Approximator(nn.Module):
        super().__init__()
        self.in_proj = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
        self.layers = nn.ModuleList([MLPEmbedder(hidden_dim, hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)])
        self.norms = nn.ModuleList([operations.RMSNorm(hidden_dim, dtype=dtype, device=device) for x in range( n_layers)])
        self.norms = nn.ModuleList([RMSNorm(hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)])
        self.out_proj = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device)

    @property

@@ -152,7 +152,6 @@ class Chroma(nn.Module):
        transformer_options={},
        attn_mask: Tensor = None,
    ) -> Tensor:
        transformer_options = transformer_options.copy()
        patches_replace = transformer_options.get("patches_replace", {})

        # running on sequences img
@@ -229,7 +228,6 @@ class Chroma(nn.Module):

        transformer_options["total_blocks"] = len(self.single_blocks)
        transformer_options["block_type"] = "single"
        transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
        for i, block in enumerate(self.single_blocks):
            transformer_options["block_index"] = i
            if i not in self.skip_dit:

@@ -4,6 +4,8 @@ from functools import lru_cache
import torch
from torch import nn

from comfy.ldm.flux.layers import RMSNorm


class NerfEmbedder(nn.Module):
    """
@@ -143,7 +145,7 @@ class NerfGLUBlock(nn.Module):
        # We now need to generate parameters for 3 matrices.
        total_params = 3 * hidden_size_x**2 * mlp_ratio
        self.param_generator = operations.Linear(hidden_size_s, total_params, dtype=dtype, device=device)
        self.norm = operations.RMSNorm(hidden_size_x, dtype=dtype, device=device)
        self.norm = RMSNorm(hidden_size_x, dtype=dtype, device=device, operations=operations)
        self.mlp_ratio = mlp_ratio


@@ -176,7 +178,7 @@ class NerfGLUBlock(nn.Module):
class NerfFinalLayer(nn.Module):
    def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None):
        super().__init__()
        self.norm = operations.RMSNorm(hidden_size, dtype=dtype, device=device)
        self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
        self.linear = operations.Linear(hidden_size, out_channels, dtype=dtype, device=device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -188,7 +190,7 @@ class NerfFinalLayerConv(nn.Module):
class NerfFinalLayerConv(nn.Module):
    def __init__(self, hidden_size: int, out_channels: int, dtype=None, device=None, operations=None):
        super().__init__()
        self.norm = operations.RMSNorm(hidden_size, dtype=dtype, device=device)
        self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
        self.conv = operations.Conv2d(
            in_channels=hidden_size,
            out_channels=out_channels,

@@ -13,7 +13,6 @@ from torchvision import transforms

import comfy.patcher_extension
from comfy.ldm.modules.attention import optimized_attention
import comfy.ldm.common_dit

def apply_rotary_pos_emb(
    t: torch.Tensor,
@@ -335,7 +334,7 @@ class FinalLayer(nn.Module):
        device=None, dtype=None, operations=None
    ):
        super().__init__()
        self.layer_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.layer_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = operations.Linear(
            hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False, device=device, dtype=dtype
        )
@@ -463,8 +462,6 @@ class Block(nn.Module):
        extra_per_block_pos_emb: Optional[torch.Tensor] = None,
        transformer_options: Optional[dict] = {},
    ) -> torch.Tensor:
        residual_dtype = x_B_T_H_W_D.dtype
        compute_dtype = emb_B_T_D.dtype
        if extra_per_block_pos_emb is not None:
            x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb

@@ -514,7 +511,7 @@ class Block(nn.Module):
        result_B_T_H_W_D = rearrange(
            self.self_attn(
                # normalized_x_B_T_HW_D,
                rearrange(normalized_x_B_T_H_W_D.to(compute_dtype), "b t h w d -> b (t h w) d"),
                rearrange(normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
                None,
                rope_emb=rope_emb_L_1_1_D,
                transformer_options=transformer_options,
@@ -524,7 +521,7 @@ class Block(nn.Module):
            h=H,
            w=W,
        )
        x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype)
        x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D * result_B_T_H_W_D

        def _x_fn(
            _x_B_T_H_W_D: torch.Tensor,
@@ -538,7 +535,7 @@ class Block(nn.Module):
            )
            _result_B_T_H_W_D = rearrange(
                self.cross_attn(
                    rearrange(_normalized_x_B_T_H_W_D.to(compute_dtype), "b t h w d -> b (t h w) d"),
                    rearrange(_normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
                    crossattn_emb,
                    rope_emb=rope_emb_L_1_1_D,
                    transformer_options=transformer_options,
@@ -557,7 +554,7 @@ class Block(nn.Module):
            shift_cross_attn_B_T_1_1_D,
            transformer_options=transformer_options,
        )
        x_B_T_H_W_D = result_B_T_H_W_D.to(residual_dtype) * gate_cross_attn_B_T_1_1_D.to(residual_dtype) + x_B_T_H_W_D
        x_B_T_H_W_D = result_B_T_H_W_D * gate_cross_attn_B_T_1_1_D + x_B_T_H_W_D

        normalized_x_B_T_H_W_D = _fn(
            x_B_T_H_W_D,
@@ -565,8 +562,8 @@ class Block(nn.Module):
            scale_mlp_B_T_1_1_D,
            shift_mlp_B_T_1_1_D,
        )
        result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D.to(compute_dtype))
        x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype)
        result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D)
        x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D * result_B_T_H_W_D
        return x_B_T_H_W_D


@@ -838,8 +835,6 @@ class MiniTrainDIT(nn.Module):
        padding_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        orig_shape = list(x.shape)
        x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_temporal, self.patch_spatial, self.patch_spatial))
        x_B_C_T_H_W = x
        timesteps_B_T = timesteps
        crossattn_emb = context
@@ -878,14 +873,6 @@ class MiniTrainDIT(nn.Module):
            "extra_per_block_pos_emb": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
            "transformer_options": kwargs.get("transformer_options", {}),
        }

        # The residual stream for this model has large values. To make fp16 compute_dtype work, we keep the residual stream
        # in fp32, but run attention and MLP modules in fp16.
        # An alternate method that clamps fp16 values "works" in the sense that it makes coherent images, but there is noticeable
        # quality degradation and visual artifacts.
        if x_B_T_H_W_D.dtype == torch.float16:
            x_B_T_H_W_D = x_B_T_H_W_D.float()

        for block in self.blocks:
            x_B_T_H_W_D = block(
                x_B_T_H_W_D,
@@ -894,6 +881,6 @@ class MiniTrainDIT(nn.Module):
                **block_kwargs,
            )

        x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D.to(crossattn_emb.dtype), t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D)
        x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)[:, :, :orig_shape[-3], :orig_shape[-2], :orig_shape[-1]]
        x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D)
        x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)
        return x_B_C_Tt_Hp_Wp

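A minimal sketch of the mixed-precision pattern described in the comment above (names and shapes assumed, not the model's actual module): keep the residual stream in fp32 while the heavy submodules run in fp16, casting at the boundaries:

    import torch

    def block_step(x_fp32, module_fp16, gate):
        out = module_fp16(x_fp32.to(torch.float16))   # heavy compute in fp16
        return x_fp32 + gate.float() * out.float()    # accumulate residual in fp32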
@@ -5,9 +5,9 @@ import torch
from torch import Tensor, nn

from .math import attention, rope
import comfy.ops
import comfy.ldm.common_dit

# Fix import for some custom nodes, TODO: delete eventually.
RMSNorm = None

class EmbedND(nn.Module):
    def __init__(self, dim: int, theta: int, axes_dim: list):
@@ -87,12 +87,20 @@ def build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=False, yak_mlp=False, dt
        operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
    )

class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, dtype=None, device=None, operations=None):
        super().__init__()
        self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))

    def forward(self, x: Tensor):
        return comfy.ldm.common_dit.rms_norm(x, self.scale, 1e-6)

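For reference, the usual RMSNorm semantics behind the comfy.ldm.common_dit.rms_norm call above, as a sketch (assuming normalization over the last dimension, with eps matching the 1e-6 used there):

    import torch

    def rms_norm_ref(x, scale, eps=1e-6):
        # divide by the root-mean-square over the last dim, then rescale
        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps) * scale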

class QKNorm(torch.nn.Module):
    def __init__(self, dim: int, dtype=None, device=None, operations=None):
        super().__init__()
        self.query_norm = operations.RMSNorm(dim, dtype=dtype, device=device)
        self.key_norm = operations.RMSNorm(dim, dtype=dtype, device=device)
        self.query_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
        self.key_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple:
        q = self.query_norm(q)
@@ -161,7 +169,7 @@ class SiLUActivation(nn.Module):


class DoubleStreamBlock(nn.Module):
    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None):
    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None):
        super().__init__()

        mlp_hidden_dim = int(hidden_size * mlp_ratio)
@@ -189,6 +197,8 @@ class DoubleStreamBlock(nn.Module):

        self.txt_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations)

        self.flipped_img_txt = flipped_img_txt

    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}):
        if self.modulation:
            img_mod1, img_mod2 = self.img_mod(vec)
@@ -196,9 +206,6 @@ class DoubleStreamBlock(nn.Module):
        else:
            (img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec

        transformer_patches = transformer_options.get("patches", {})
        extra_options = transformer_options.copy()

        # prepare image for attention
        img_modulated = self.img_norm1(img)
        img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img)
@@ -217,23 +224,32 @@ class DoubleStreamBlock(nn.Module):
        del txt_qkv
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

        q = torch.cat((txt_q, img_q), dim=2)
        del txt_q, img_q
        k = torch.cat((txt_k, img_k), dim=2)
        del txt_k, img_k
        v = torch.cat((txt_v, img_v), dim=2)
        del txt_v, img_v
        # run actual attention
        attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
        del q, k, v
        if self.flipped_img_txt:
            q = torch.cat((img_q, txt_q), dim=2)
            del img_q, txt_q
            k = torch.cat((img_k, txt_k), dim=2)
            del img_k, txt_k
            v = torch.cat((img_v, txt_v), dim=2)
            del img_v, txt_v
            # run actual attention
            attn = attention(q, k, v,
                             pe=pe, mask=attn_mask, transformer_options=transformer_options)
            del q, k, v

            if "attn1_output_patch" in transformer_patches:
                extra_options["img_slice"] = [txt.shape[1], attn.shape[1]]
                patch = transformer_patches["attn1_output_patch"]
                for p in patch:
                    attn = p(attn, extra_options)
            img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
        else:
            q = torch.cat((txt_q, img_q), dim=2)
            del txt_q, img_q
            k = torch.cat((txt_k, img_k), dim=2)
            del txt_k, img_k
            v = torch.cat((txt_v, img_v), dim=2)
            del txt_v, img_v
            # run actual attention
            attn = attention(q, k, v,
                             pe=pe, mask=attn_mask, transformer_options=transformer_options)
            del q, k, v

            txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]

        # calculate the img bloks
        img += apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
@@ -312,9 +328,6 @@ class SingleStreamBlock(nn.Module):
        else:
            mod = vec

        transformer_patches = transformer_options.get("patches", {})
        extra_options = transformer_options.copy()

        qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim_first], dim=-1)

        q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
@@ -324,12 +337,6 @@ class SingleStreamBlock(nn.Module):
        # compute attention
        attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
        del q, k, v

        if "attn1_output_patch" in transformer_patches:
            patch = transformer_patches["attn1_output_patch"]
            for p in patch:
                attn = p(attn, extra_options)

        # compute activation in mlp stream, cat again and run second linear layer
        if self.yak_mlp:
            mlp = self.mlp_act(mlp[..., self.mlp_hidden_dim_first // 2:]) * mlp[..., :self.mlp_hidden_dim_first // 2]

@@ -29,34 +29,19 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    return out.to(dtype=torch.float32, device=pos.device)


def _apply_rope1(x: Tensor, freqs_cis: Tensor):
    x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)

    x_out = freqs_cis[..., 0] * x_[..., 0]
    x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])

    return x_out.reshape(*x.shape).type_as(x)


def _apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
    return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)


try:
    import comfy.quant_ops
    q_apply_rope = comfy.quant_ops.ck.apply_rope
    q_apply_rope1 = comfy.quant_ops.ck.apply_rope1
    def apply_rope(xq, xk, freqs_cis):
        if comfy.model_management.in_training:
            return _apply_rope(xq, xk, freqs_cis)
        else:
            return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
    def apply_rope1(x, freqs_cis):
        if comfy.model_management.in_training:
            return _apply_rope1(x, freqs_cis)
        else:
            return q_apply_rope1(x, freqs_cis)
    apply_rope = comfy.quant_ops.ck.apply_rope
    apply_rope1 = comfy.quant_ops.ck.apply_rope1
except:
    logging.warning("No comfy kitchen, using old apply_rope functions.")
    apply_rope = _apply_rope
    apply_rope1 = _apply_rope1
def apply_rope1(x: Tensor, freqs_cis: Tensor):
    x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)

    x_out = freqs_cis[..., 0] * x_[..., 0]
    x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])

    return x_out.reshape(*x.shape).type_as(x)

def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
    return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)

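A shape note on the rope helpers (a sketch under the usual conventions, which the rope() builder above is assumed to follow): x of shape (..., d) with d even is viewed as d/2 feature pairs, and freqs_cis stores a 2x2 rotation per pair, so the two multiply-adds in apply_rope1 compute, per pair (x0, x1) and angle theta:

    # out0 = cos(theta) * x0 - sin(theta) * x1
    # out1 = sin(theta) * x0 + cos(theta) * x1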
@@ -16,6 +16,7 @@ from .layers import (
    SingleStreamBlock,
    timestep_embedding,
    Modulation,
    RMSNorm
)

@dataclass
@@ -80,7 +81,7 @@ class Flux(nn.Module):
        self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device)

        if params.txt_norm:
            self.txt_norm = operations.RMSNorm(params.context_in_dim, dtype=dtype, device=device)
            self.txt_norm = RMSNorm(params.context_in_dim, dtype=dtype, device=device, operations=operations)
        else:
            self.txt_norm = None

@@ -142,7 +143,6 @@ class Flux(nn.Module):
        attn_mask: Tensor = None,
    ) -> Tensor:

        transformer_options = transformer_options.copy()
        patches = transformer_options.get("patches", {})
        patches_replace = transformer_options.get("patches_replace", {})
        if img.ndim != 3 or txt.ndim != 3:
@@ -232,7 +232,6 @@ class Flux(nn.Module):

        transformer_options["total_blocks"] = len(self.single_blocks)
        transformer_options["block_type"] = "single"
        transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
        for i, block in enumerate(self.single_blocks):
            transformer_options["block_index"] = i
            if ("single_block", i) in blocks_replace:

@@ -241,6 +241,7 @@ class HunyuanVideo(nn.Module):
                self.num_heads,
                mlp_ratio=params.mlp_ratio,
                qkv_bias=params.qkv_bias,
                flipped_img_txt=True,
                dtype=dtype, device=device, operations=operations
            )
            for _ in range(params.depth)
@@ -304,7 +305,6 @@ class HunyuanVideo(nn.Module):
        control=None,
        transformer_options={},
    ) -> Tensor:
        transformer_options = transformer_options.copy()
        patches_replace = transformer_options.get("patches_replace", {})

        initial_shape = list(img.shape)
@@ -378,14 +378,14 @@ class HunyuanVideo(nn.Module):
            extra_txt_ids = torch.zeros((txt_ids.shape[0], txt_vision_states.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
            txt_ids = torch.cat((txt_ids, extra_txt_ids), dim=1)

        ids = torch.cat((txt_ids, img_ids), dim=1)
        ids = torch.cat((img_ids, txt_ids), dim=1)
        pe = self.pe_embedder(ids)

        img_len = img.shape[1]
        if txt_mask is not None:
            attn_mask_len = img_len + txt.shape[1]
            attn_mask = torch.zeros((1, 1, attn_mask_len), dtype=img.dtype, device=img.device)
            attn_mask[:, 0, :txt.shape[1]] = txt_mask
            attn_mask[:, 0, img_len:] = txt_mask
        else:
            attn_mask = None

@@ -413,11 +413,10 @@ class HunyuanVideo(nn.Module):
                    if add is not None:
                        img += add

        img = torch.cat((txt, img), 1)
        img = torch.cat((img, txt), 1)

        transformer_options["total_blocks"] = len(self.single_blocks)
        transformer_options["block_type"] = "single"
        transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
        for i, block in enumerate(self.single_blocks):
            transformer_options["block_index"] = i
            if ("single_block", i) in blocks_replace:
@@ -436,9 +435,9 @@ class HunyuanVideo(nn.Module):
                if i < len(control_o):
                    add = control_o[i]
                    if add is not None:
                        img[:, txt.shape[1]: img_len + txt.shape[1]] += add
                        img[:, : img_len] += add

        img = img[:, txt.shape[1]: img_len + txt.shape[1]]
        img = img[:, : img_len]
        if ref_latent is not None:
            img = img[:, ref_latent.shape[1]:]


@@ -109,10 +109,10 @@ class HunyuanVideo15SRModel():
        self.model_class = UPSAMPLERS.get(model_type)
        self.model = self.model_class(**config).eval()

        self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
        self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)

    def load_sd(self, sd):
        return self.model.load_state_dict(sd, strict=True, assign=self.patcher.is_dynamic())
        return self.model.load_state_dict(sd, strict=True)

    def get_sd(self):
        return self.model.state_dict()

@@ -2,19 +2,13 @@ from typing import Tuple
import torch
import torch.nn as nn
from comfy.ldm.lightricks.model import (
    ADALN_BASE_PARAMS_COUNT,
    ADALN_CROSS_ATTN_PARAMS_COUNT,
    CrossAttention,
    FeedForward,
    AdaLayerNormSingle,
    PixArtAlphaTextProjection,
    NormSingleLinearTextProjection,
    LTXVModel,
    apply_cross_attention_adaln,
    compute_prompt_timestep,
)
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
import comfy.ldm.common_dit

class CompressedTimestep:
@@ -92,8 +86,6 @@ class BasicAVTransformerBlock(nn.Module):
        v_context_dim=None,
        a_context_dim=None,
        attn_precision=None,
        apply_gated_attention=False,
        cross_attention_adaln=False,
        dtype=None,
        device=None,
        operations=None,
@@ -101,7 +93,6 @@ class BasicAVTransformerBlock(nn.Module):
        super().__init__()

        self.attn_precision = attn_precision
        self.cross_attention_adaln = cross_attention_adaln

        self.attn1 = CrossAttention(
            query_dim=v_dim,
@@ -109,7 +100,6 @@ class BasicAVTransformerBlock(nn.Module):
            dim_head=vd_head,
            context_dim=None,
            attn_precision=self.attn_precision,
            apply_gated_attention=apply_gated_attention,
            dtype=dtype,
            device=device,
            operations=operations,
@@ -120,7 +110,6 @@ class BasicAVTransformerBlock(nn.Module):
            dim_head=ad_head,
            context_dim=None,
            attn_precision=self.attn_precision,
            apply_gated_attention=apply_gated_attention,
            dtype=dtype,
            device=device,
            operations=operations,
@@ -132,7 +121,6 @@ class BasicAVTransformerBlock(nn.Module):
            heads=v_heads,
            dim_head=vd_head,
            attn_precision=self.attn_precision,
            apply_gated_attention=apply_gated_attention,
            dtype=dtype,
            device=device,
            operations=operations,
@@ -143,7 +131,6 @@ class BasicAVTransformerBlock(nn.Module):
            heads=a_heads,
            dim_head=ad_head,
            attn_precision=self.attn_precision,
            apply_gated_attention=apply_gated_attention,
            dtype=dtype,
            device=device,
            operations=operations,
@@ -156,7 +143,6 @@ class BasicAVTransformerBlock(nn.Module):
            heads=a_heads,
            dim_head=ad_head,
            attn_precision=self.attn_precision,
            apply_gated_attention=apply_gated_attention,
            dtype=dtype,
            device=device,
            operations=operations,
@@ -169,7 +155,6 @@ class BasicAVTransformerBlock(nn.Module):
            heads=a_heads,
            dim_head=ad_head,
            attn_precision=self.attn_precision,
            apply_gated_attention=apply_gated_attention,
            dtype=dtype,
            device=device,
            operations=operations,
@@ -182,16 +167,11 @@ class BasicAVTransformerBlock(nn.Module):
            a_dim, dim_out=a_dim, glu=True, dtype=dtype, device=device, operations=operations
        )

        num_ada_params = ADALN_CROSS_ATTN_PARAMS_COUNT if cross_attention_adaln else ADALN_BASE_PARAMS_COUNT
        self.scale_shift_table = nn.Parameter(torch.empty(num_ada_params, v_dim, device=device, dtype=dtype))
        self.scale_shift_table = nn.Parameter(torch.empty(6, v_dim, device=device, dtype=dtype))
        self.audio_scale_shift_table = nn.Parameter(
            torch.empty(num_ada_params, a_dim, device=device, dtype=dtype)
            torch.empty(6, a_dim, device=device, dtype=dtype)
        )

        if cross_attention_adaln:
            self.prompt_scale_shift_table = nn.Parameter(torch.empty(2, v_dim, device=device, dtype=dtype))
            self.audio_prompt_scale_shift_table = nn.Parameter(torch.empty(2, a_dim, device=device, dtype=dtype))

        self.scale_shift_table_a2v_ca_audio = nn.Parameter(
            torch.empty(5, a_dim, device=device, dtype=dtype)
        )
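
# ---------------------------------------------------------------------------
# Editor's note: the scale_shift_table hunks above grow the AdaLN table from
# 6 rows to 9 when cross_attention_adaln is enabled. Rows 0-5 keep the usual
# shift/scale/gate pairs for self-attention and the MLP; rows 6-8 add
# shift/scale/gate for the cross-attention query path (see the
# ADALN_*_PARAMS_COUNT constants later in this diff). A toy sketch of how the
# table combines with a per-token timestep embedding; shapes are illustrative.

import torch

B, T, dim = 2, 5, 32
table = torch.zeros(9, dim)               # learned (num_ada_params, dim) table
timestep = torch.randn(B, T, 9 * dim)     # AdaLayerNormSingle output per token

ada = table[None, None] + timestep.reshape(B, T, 9, dim)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada[:, :, :6].unbind(dim=2)
shift_q, scale_q, gate_q = ada[:, :, 6:9].unbind(dim=2)  # cross-attention extras
# ---------------------------------------------------------------------------
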
@@ -234,30 +214,10 @@ class BasicAVTransformerBlock(nn.Module):

        return (*scale_shift_ada_values, *gate_ada_values)

    def _apply_text_cross_attention(
        self, x, context, attn, scale_shift_table, prompt_scale_shift_table,
        timestep, prompt_timestep, attention_mask, transformer_options,
    ):
        """Apply text cross-attention, with optional ADaLN modulation."""
        if self.cross_attention_adaln:
            shift_q, scale_q, gate = self.get_ada_values(
                scale_shift_table, x.shape[0], timestep, slice(6, 9)
            )
            return apply_cross_attention_adaln(
                x, context, attn, shift_q, scale_q, gate,
                prompt_scale_shift_table, prompt_timestep,
                attention_mask, transformer_options,
            )
        return attn(
            comfy.ldm.common_dit.rms_norm(x), context=context,
            mask=attention_mask, transformer_options=transformer_options,
        )

    def forward(
        self, x: Tuple[torch.Tensor, torch.Tensor], v_context=None, a_context=None, attention_mask=None, v_timestep=None, a_timestep=None,
        v_pe=None, a_pe=None, v_cross_pe=None, a_cross_pe=None, v_cross_scale_shift_timestep=None, a_cross_scale_shift_timestep=None,
        v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None, self_attention_mask=None,
        v_prompt_timestep=None, a_prompt_timestep=None,
        v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        run_vx = transformer_options.get("run_vx", True)
        run_ax = transformer_options.get("run_ax", True)
@@ -273,17 +233,13 @@ class BasicAVTransformerBlock(nn.Module):
            vshift_msa, vscale_msa = (self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 2)))
            norm_vx = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_msa) + vshift_msa
            del vshift_msa, vscale_msa
            attn1_out = self.attn1(norm_vx, pe=v_pe, mask=self_attention_mask, transformer_options=transformer_options)
            attn1_out = self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options)
            del norm_vx
            # video cross-attention
            vgate_msa = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(2, 3))[0]
            vx.addcmul_(attn1_out, vgate_msa)
            del vgate_msa, attn1_out
            vx.add_(self._apply_text_cross_attention(
                vx, v_context, self.attn2, self.scale_shift_table,
                getattr(self, 'prompt_scale_shift_table', None),
                v_timestep, v_prompt_timestep, attention_mask, transformer_options,)
            )
            vx.add_(self.attn2(comfy.ldm.common_dit.rms_norm(vx), context=v_context, mask=attention_mask, transformer_options=transformer_options))

        # audio
        if run_ax:
@@ -297,11 +253,7 @@ class BasicAVTransformerBlock(nn.Module):
            agate_msa = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(2, 3))[0]
            ax.addcmul_(attn1_out, agate_msa)
            del agate_msa, attn1_out
            ax.add_(self._apply_text_cross_attention(
                ax, a_context, self.audio_attn2, self.audio_scale_shift_table,
                getattr(self, 'audio_prompt_scale_shift_table', None),
                a_timestep, a_prompt_timestep, attention_mask, transformer_options,)
            )
            ax.add_(self.audio_attn2(comfy.ldm.common_dit.rms_norm(ax), context=a_context, mask=attention_mask, transformer_options=transformer_options))

        # video - audio cross attention.
        if run_a2v or run_v2a:
@@ -398,9 +350,6 @@ class LTXAVModel(LTXVModel):
        use_middle_indices_grid=False,
        timestep_scale_multiplier=1000.0,
        av_ca_timestep_scale_multiplier=1.0,
        apply_gated_attention=False,
        caption_proj_before_connector=False,
        cross_attention_adaln=False,
        dtype=None,
        device=None,
        operations=None,
@@ -412,7 +361,6 @@ class LTXAVModel(LTXVModel):
        self.audio_attention_head_dim = audio_attention_head_dim
        self.audio_num_attention_heads = audio_num_attention_heads
        self.audio_positional_embedding_max_pos = audio_positional_embedding_max_pos
        self.apply_gated_attention = apply_gated_attention

        # Calculate audio dimensions
        self.audio_inner_dim = audio_num_attention_heads * audio_attention_head_dim
@@ -437,8 +385,6 @@ class LTXAVModel(LTXVModel):
            vae_scale_factors=vae_scale_factors,
            use_middle_indices_grid=use_middle_indices_grid,
            timestep_scale_multiplier=timestep_scale_multiplier,
            caption_proj_before_connector=caption_proj_before_connector,
            cross_attention_adaln=cross_attention_adaln,
            dtype=dtype,
            device=device,
            operations=operations,
@@ -453,28 +399,14 @@ class LTXAVModel(LTXVModel):
        )

        # Audio-specific AdaLN
        audio_embedding_coefficient = ADALN_CROSS_ATTN_PARAMS_COUNT if self.cross_attention_adaln else ADALN_BASE_PARAMS_COUNT
        self.audio_adaln_single = AdaLayerNormSingle(
            self.audio_inner_dim,
            embedding_coefficient=audio_embedding_coefficient,
            use_additional_conditions=False,
            dtype=dtype,
            device=device,
            operations=self.operations,
        )

        if self.cross_attention_adaln:
            self.audio_prompt_adaln_single = AdaLayerNormSingle(
                self.audio_inner_dim,
                embedding_coefficient=2,
                use_additional_conditions=False,
                dtype=dtype,
                device=device,
                operations=self.operations,
            )
        else:
            self.audio_prompt_adaln_single = None

        num_scale_shift_values = 4
        self.av_ca_video_scale_shift_adaln_single = AdaLayerNormSingle(
            self.inner_dim,
@@ -510,75 +442,14 @@ class LTXAVModel(LTXVModel):
        )

        # Audio caption projection
        if self.caption_proj_before_connector:
            if self.caption_projection_first_linear:
                self.audio_caption_projection = NormSingleLinearTextProjection(
                    in_features=self.caption_channels,
                    hidden_size=self.audio_inner_dim,
                    dtype=dtype,
                    device=device,
                    operations=self.operations,
                )
            else:
                self.audio_caption_projection = lambda a: a
        else:
            self.audio_caption_projection = PixArtAlphaTextProjection(
                in_features=self.caption_channels,
                hidden_size=self.audio_inner_dim,
                dtype=dtype,
                device=device,
                operations=self.operations,
            )

        connector_split_rope = kwargs.get("rope_type", "split") == "split"
        connector_gated_attention = kwargs.get("connector_apply_gated_attention", False)
        attention_head_dim = kwargs.get("connector_attention_head_dim", 128)
        num_attention_heads = kwargs.get("connector_num_attention_heads", 30)
        num_layers = kwargs.get("connector_num_layers", 2)

        self.audio_embeddings_connector = Embeddings1DConnector(
            attention_head_dim=kwargs.get("audio_connector_attention_head_dim", attention_head_dim),
            num_attention_heads=kwargs.get("audio_connector_num_attention_heads", num_attention_heads),
            num_layers=num_layers,
            split_rope=connector_split_rope,
            double_precision_rope=True,
            apply_gated_attention=connector_gated_attention,
        self.audio_caption_projection = PixArtAlphaTextProjection(
            in_features=self.caption_channels,
            hidden_size=self.audio_inner_dim,
            dtype=dtype,
            device=device,
            operations=self.operations,
        )

        self.video_embeddings_connector = Embeddings1DConnector(
            attention_head_dim=attention_head_dim,
            num_attention_heads=num_attention_heads,
            num_layers=num_layers,
            split_rope=connector_split_rope,
            double_precision_rope=True,
            apply_gated_attention=connector_gated_attention,
            dtype=dtype,
            device=device,
            operations=self.operations,
        )

    def preprocess_text_embeds(self, context, unprocessed=False):
        # LTXv2 fully processed context has dimension of self.caption_channels * 2
        # LTXv2.3 fully processed context has dimension of self.cross_attention_dim + self.audio_cross_attention_dim
        if not unprocessed:
            if context.shape[-1] in (self.cross_attention_dim + self.audio_cross_attention_dim, self.caption_channels * 2):
                return context
        if context.shape[-1] == self.cross_attention_dim + self.audio_cross_attention_dim:
            context_vid = context[:, :, :self.cross_attention_dim]
            context_audio = context[:, :, self.cross_attention_dim:]
        else:
            context_vid = context
            context_audio = context
        if self.caption_proj_before_connector:
            context_vid = self.caption_projection(context_vid)
            context_audio = self.audio_caption_projection(context_audio)
        out_vid = self.video_embeddings_connector(context_vid)[0]
        out_audio = self.audio_embeddings_connector(context_audio)[0]
        return torch.concat((out_vid, out_audio), dim=-1)

    def _init_transformer_blocks(self, device, dtype, **kwargs):
        """Initialize transformer blocks for LTXAV."""
        self.transformer_blocks = nn.ModuleList(
@@ -592,8 +463,6 @@ class LTXAVModel(LTXVModel):
                ad_head=self.audio_attention_head_dim,
                v_context_dim=self.cross_attention_dim,
                a_context_dim=self.audio_cross_attention_dim,
                apply_gated_attention=self.apply_gated_attention,
                cross_attention_adaln=self.cross_attention_adaln,
                dtype=dtype,
                device=device,
                operations=self.operations,
@@ -715,10 +584,6 @@ class LTXAVModel(LTXVModel):
        v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame)
        v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame)

        v_prompt_timestep = compute_prompt_timestep(
            self.prompt_adaln_single, timestep_scaled, batch_size, hidden_dtype
        )

        # Prepare audio timestep
        a_timestep = kwargs.get("a_timestep")
        if a_timestep is not None:
@@ -729,25 +594,25 @@ class LTXAVModel(LTXVModel):

            # Cross-attention timesteps - compress these too
            av_ca_audio_scale_shift_timestep, _ = self.av_ca_audio_scale_shift_adaln_single(
                timestep.max().expand_as(a_timestep_flat),
                a_timestep_flat,
                {"resolution": None, "aspect_ratio": None},
                batch_size=batch_size,
                hidden_dtype=hidden_dtype,
            )
            av_ca_video_scale_shift_timestep, _ = self.av_ca_video_scale_shift_adaln_single(
                a_timestep.max().expand_as(timestep_flat),
                timestep_flat,
                {"resolution": None, "aspect_ratio": None},
                batch_size=batch_size,
                hidden_dtype=hidden_dtype,
            )
            av_ca_a2v_gate_noise_timestep, _ = self.av_ca_a2v_gate_adaln_single(
                a_timestep.max().expand_as(timestep_flat) * av_ca_factor,
                timestep_flat * av_ca_factor,
                {"resolution": None, "aspect_ratio": None},
                batch_size=batch_size,
                hidden_dtype=hidden_dtype,
            )
            av_ca_v2a_gate_noise_timestep, _ = self.av_ca_v2a_gate_adaln_single(
                timestep.max().expand_as(a_timestep_flat) * av_ca_factor,
                a_timestep_flat * av_ca_factor,
                {"resolution": None, "aspect_ratio": None},
                batch_size=batch_size,
                hidden_dtype=hidden_dtype,
@@ -771,40 +636,29 @@ class LTXAVModel(LTXVModel):
            # Audio timesteps
            a_timestep = a_timestep.view(batch_size, -1, a_timestep.shape[-1])
            a_embedded_timestep = a_embedded_timestep.view(batch_size, -1, a_embedded_timestep.shape[-1])

            a_prompt_timestep = compute_prompt_timestep(
                self.audio_prompt_adaln_single, a_timestep_scaled, batch_size, hidden_dtype
            )
        else:
            a_timestep = timestep_scaled
            a_embedded_timestep = kwargs.get("embedded_timestep")
            cross_av_timestep_ss = []
            a_prompt_timestep = None

        return [v_timestep, a_timestep, cross_av_timestep_ss, v_prompt_timestep, a_prompt_timestep], [
        return [v_timestep, a_timestep, cross_av_timestep_ss], [
            v_embedded_timestep,
            a_embedded_timestep,
        ], None
        ]

    def _prepare_context(self, context, batch_size, x, attention_mask=None):
        vx = x[0]
        ax = x[1]
        video_dim = vx.shape[-1]
        audio_dim = ax.shape[-1]

        v_context_dim = self.caption_channels if self.caption_proj_before_connector is False else video_dim
        a_context_dim = self.caption_channels if self.caption_proj_before_connector is False else audio_dim

        v_context, a_context = torch.split(
            context, [v_context_dim, a_context_dim], len(context.shape) - 1
            context, int(context.shape[-1] / 2), len(context.shape) - 1
        )

        v_context, attention_mask = super()._prepare_context(
            v_context, batch_size, vx, attention_mask
        )
        if self.caption_proj_before_connector is False:
            if self.audio_caption_projection is not None:
                a_context = self.audio_caption_projection(a_context)
            a_context = a_context.view(batch_size, -1, audio_dim)
            a_context = a_context.view(batch_size, -1, ax.shape[-1])

        return [v_context, a_context], attention_mask

@@ -848,7 +702,7 @@ class LTXAVModel(LTXVModel):
        return [(v_pe, av_cross_video_freq_cis), (a_pe, av_cross_audio_freq_cis)]

    def _process_transformer_blocks(
        self, x, context, attention_mask, timestep, pe, transformer_options={}, self_attention_mask=None, **kwargs
        self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs
    ):
        vx = x[0]
        ax = x[1]
@@ -866,9 +720,6 @@ class LTXAVModel(LTXVModel):
            av_ca_v2a_gate_noise_timestep,
        ) = timestep[2]

        v_prompt_timestep = timestep[3]
        a_prompt_timestep = timestep[4]

        """Process transformer blocks for LTXAV."""
        patches_replace = transformer_options.get("patches_replace", {})
        blocks_replace = patches_replace.get("dit", {})
@@ -895,9 +746,6 @@ class LTXAVModel(LTXVModel):
                        v_cross_gate_timestep=args["v_cross_gate_timestep"],
                        a_cross_gate_timestep=args["a_cross_gate_timestep"],
                        transformer_options=args["transformer_options"],
                        self_attention_mask=args.get("self_attention_mask"),
                        v_prompt_timestep=args.get("v_prompt_timestep"),
                        a_prompt_timestep=args.get("a_prompt_timestep"),
                    )
                    return out

@@ -918,9 +766,6 @@ class LTXAVModel(LTXVModel):
                        "v_cross_gate_timestep": av_ca_a2v_gate_noise_timestep,
                        "a_cross_gate_timestep": av_ca_v2a_gate_noise_timestep,
                        "transformer_options": transformer_options,
                        "self_attention_mask": self_attention_mask,
                        "v_prompt_timestep": v_prompt_timestep,
                        "a_prompt_timestep": a_prompt_timestep,
                    },
                    {"original_block": block_wrap},
                )
@@ -942,9 +787,6 @@ class LTXAVModel(LTXVModel):
                    v_cross_gate_timestep=av_ca_a2v_gate_noise_timestep,
                    a_cross_gate_timestep=av_ca_v2a_gate_noise_timestep,
                    transformer_options=transformer_options,
                    self_attention_mask=self_attention_mask,
                    v_prompt_timestep=v_prompt_timestep,
                    a_prompt_timestep=a_prompt_timestep,
                )

        return [vx, ax]

@@ -50,7 +50,6 @@ class BasicTransformerBlock1D(nn.Module):
        d_head,
        context_dim=None,
        attn_precision=None,
        apply_gated_attention=False,
        dtype=None,
        device=None,
        operations=None,
@@ -64,7 +63,6 @@ class BasicTransformerBlock1D(nn.Module):
            heads=n_heads,
            dim_head=d_head,
            context_dim=None,
            apply_gated_attention=apply_gated_attention,
            dtype=dtype,
            device=device,
            operations=operations,
@@ -123,7 +121,6 @@ class Embeddings1DConnector(nn.Module):
        positional_embedding_max_pos=[4096],
        causal_temporal_positioning=False,
        num_learnable_registers: Optional[int] = 128,
        apply_gated_attention=False,
        dtype=None,
        device=None,
        operations=None,
@@ -148,7 +145,6 @@ class Embeddings1DConnector(nn.Module):
                num_attention_heads,
                attention_head_dim,
                context_dim=cross_attention_dim,
                apply_gated_attention=apply_gated_attention,
                dtype=dtype,
                device=device,
                operations=operations,
@@ -161,9 +157,11 @@ class Embeddings1DConnector(nn.Module):
        self.num_learnable_registers = num_learnable_registers
        if self.num_learnable_registers:
            self.learnable_registers = nn.Parameter(
                torch.empty(
                torch.rand(
                    self.num_learnable_registers, inner_dim, dtype=dtype, device=device
                )
                * 2.0
                - 1.0
            )

    def get_fractional_positions(self, indices_grid):
@@ -236,7 +234,7 @@ class Embeddings1DConnector(nn.Module):

        return indices

    def precompute_freqs_cis(self, indices_grid, spacing="exp", out_dtype=None):
    def precompute_freqs_cis(self, indices_grid, spacing="exp"):
        dim = self.inner_dim
        n_elem = 2  # 2 because of cos and sin
        freqs = self.precompute_freqs(indices_grid, spacing)
@@ -249,7 +247,7 @@ class Embeddings1DConnector(nn.Module):
            )
        else:
            cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem)
        return cos_freq.to(dtype=out_dtype), sin_freq.to(dtype=out_dtype), self.split_rope
        return cos_freq.to(self.dtype), sin_freq.to(self.dtype), self.split_rope

    def forward(
        self,
@@ -290,7 +288,7 @@ class Embeddings1DConnector(nn.Module):
            hidden_states.shape[1], dtype=torch.float32, device=hidden_states.device
        )
        indices_grid = indices_grid[None, None, :]
        freqs_cis = self.precompute_freqs_cis(indices_grid, out_dtype=hidden_states.dtype)
        freqs_cis = self.precompute_freqs_cis(indices_grid)

        # 2. Blocks
        for block_idx, block in enumerate(self.transformer_1d_blocks):
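
# ---------------------------------------------------------------------------
# Editor's note: the learnable_registers hunk above replaces uninitialized
# torch.empty with an explicit uniform init on [-1, 1) via
# torch.rand(...) * 2.0 - 1.0. The same distribution, as a standalone sketch:

import torch
import torch.nn as nn

num_registers, inner_dim = 128, 64
registers = nn.Parameter(torch.rand(num_registers, inner_dim) * 2.0 - 1.0)
# equivalently: nn.init.uniform_(torch.empty(num_registers, inner_dim), a=-1.0, b=1.0)
# ---------------------------------------------------------------------------
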
@@ -1,7 +1,6 @@
from abc import ABC, abstractmethod
from enum import Enum
import functools
import logging
import math
from typing import Dict, Optional, Tuple

@@ -15,8 +14,6 @@ import comfy.ldm.common_dit

from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords

logger = logging.getLogger(__name__)

def _log_base(x, base):
    return np.log(x) / np.log(base)

@@ -275,30 +272,6 @@ class PixArtAlphaTextProjection(nn.Module):
        return hidden_states


class NormSingleLinearTextProjection(nn.Module):
    """Text projection for 20B models - single linear with RMSNorm (no activation)."""

    def __init__(
        self, in_features, hidden_size, dtype=None, device=None, operations=None
    ):
        super().__init__()
        if operations is None:
            operations = comfy.ops.disable_weight_init
        self.in_norm = operations.RMSNorm(
            in_features, eps=1e-6, elementwise_affine=False
        )
        self.linear_1 = operations.Linear(
            in_features, hidden_size, bias=True, dtype=dtype, device=device
        )
        self.hidden_size = hidden_size
        self.in_features = in_features

    def forward(self, caption):
        caption = self.in_norm(caption)
        caption = caption * (self.hidden_size / self.in_features) ** 0.5
        return self.linear_1(caption)


class GELU_approx(nn.Module):
    def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=None):
        super().__init__()
@@ -367,7 +340,6 @@ class CrossAttention(nn.Module):
        dim_head=64,
        dropout=0.0,
        attn_precision=None,
        apply_gated_attention=False,
        dtype=None,
        device=None,
        operations=None,
@@ -387,12 +359,6 @@ class CrossAttention(nn.Module):
        self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
        self.to_v = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)

        # Optional per-head gating
        if apply_gated_attention:
            self.to_gate_logits = operations.Linear(query_dim, heads, bias=True, dtype=dtype, device=device)
        else:
            self.to_gate_logits = None

        self.to_out = nn.Sequential(
            operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout)
        )
@@ -414,30 +380,16 @@ class CrossAttention(nn.Module):
            out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
        else:
            out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)

        # Apply per-head gating if enabled
        if self.to_gate_logits is not None:
            gate_logits = self.to_gate_logits(x)  # (B, T, H)
            b, t, _ = out.shape
            out = out.view(b, t, self.heads, self.dim_head)
            gates = 2.0 * torch.sigmoid(gate_logits)  # zero-init -> identity
            out = out * gates.unsqueeze(-1)
            out = out.view(b, t, self.heads * self.dim_head)

        return self.to_out(out)
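
# ---------------------------------------------------------------------------
# Editor's note: the per-head gating added to CrossAttention above scales
# each head's output by 2 * sigmoid(gate_logits). With the gate projection
# zero-initialized, every gate is exactly 1 and the layer starts out as an
# identity on the attention output. Self-contained sketch of the reshape and
# gating; shapes are illustrative.

import torch

B, T, heads, dim_head = 2, 7, 4, 16
out = torch.randn(B, T, heads * dim_head)   # attention output
gate_logits = torch.zeros(B, T, heads)      # what a zero-init projection emits

gates = 2.0 * torch.sigmoid(gate_logits)    # all ones at init -> identity
gated = out.view(B, T, heads, dim_head) * gates.unsqueeze(-1)
gated = gated.view(B, T, heads * dim_head)
assert torch.allclose(gated, out)
# ---------------------------------------------------------------------------
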

# 6 base ADaLN params (shift/scale/gate for MSA + MLP), +3 for cross-attention Q (shift/scale/gate)
ADALN_BASE_PARAMS_COUNT = 6
ADALN_CROSS_ATTN_PARAMS_COUNT = 9

class BasicTransformerBlock(nn.Module):
    def __init__(
        self, dim, n_heads, d_head, context_dim=None, attn_precision=None, cross_attention_adaln=False, dtype=None, device=None, operations=None
        self, dim, n_heads, d_head, context_dim=None, attn_precision=None, dtype=None, device=None, operations=None
    ):
        super().__init__()

        self.attn_precision = attn_precision
        self.cross_attention_adaln = cross_attention_adaln
        self.attn1 = CrossAttention(
            query_dim=dim,
            heads=n_heads,
@@ -461,25 +413,18 @@ class BasicTransformerBlock(nn.Module):
            operations=operations,
        )

        num_ada_params = ADALN_CROSS_ATTN_PARAMS_COUNT if cross_attention_adaln else ADALN_BASE_PARAMS_COUNT
        self.scale_shift_table = nn.Parameter(torch.empty(num_ada_params, dim, device=device, dtype=dtype))
        self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))

        if cross_attention_adaln:
            self.prompt_scale_shift_table = nn.Parameter(torch.empty(2, dim, device=device, dtype=dtype))
    def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)

    def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}, self_attention_mask=None, prompt_timestep=None):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None, :6].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)[:, :, :6, :]).unbind(dim=2)
        attn1_input = comfy.ldm.common_dit.rms_norm(x)
        attn1_input = torch.addcmul(attn1_input, attn1_input, scale_msa).add_(shift_msa)
        attn1_input = self.attn1(attn1_input, pe=pe, transformer_options=transformer_options)
        x.addcmul_(attn1_input, gate_msa)
        del attn1_input

        x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe, mask=self_attention_mask, transformer_options=transformer_options) * gate_msa

        if self.cross_attention_adaln:
            shift_q_mca, scale_q_mca, gate_mca = (self.scale_shift_table[None, None, 6:9].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)[:, :, 6:9, :]).unbind(dim=2)
            x += apply_cross_attention_adaln(
                x, context, self.attn2, shift_q_mca, scale_q_mca, gate_mca,
                self.prompt_scale_shift_table, prompt_timestep, attention_mask, transformer_options,
            )
        else:
            x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options)
        x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options)

        y = comfy.ldm.common_dit.rms_norm(x)
        y = torch.addcmul(y, y, scale_mlp).add_(shift_mlp)
@@ -487,47 +432,6 @@ class BasicTransformerBlock(nn.Module):

        return x

def compute_prompt_timestep(adaln_module, timestep_scaled, batch_size, hidden_dtype):
    """Compute a single global prompt timestep for cross-attention ADaLN.

    Uses the max across tokens (matching JAX max_per_segment) and broadcasts
    over text tokens. Returns None when *adaln_module* is None.
    """
    if adaln_module is None:
        return None
    ts_input = (
        timestep_scaled.max(dim=1, keepdim=True).values.flatten()
        if timestep_scaled.dim() > 1
        else timestep_scaled.flatten()
    )
    prompt_ts, _ = adaln_module(
        ts_input,
        {"resolution": None, "aspect_ratio": None},
        batch_size=batch_size,
        hidden_dtype=hidden_dtype,
    )
    return prompt_ts.view(batch_size, 1, prompt_ts.shape[-1])


def apply_cross_attention_adaln(
    x, context, attn, q_shift, q_scale, q_gate,
    prompt_scale_shift_table, prompt_timestep,
    attention_mask=None, transformer_options={},
):
    """Apply cross-attention with ADaLN modulation (shift/scale/gate on Q and KV).

    Q params (q_shift, q_scale, q_gate) are pre-extracted by the caller so
    that both regular tensors and CompressedTimestep are supported.
    """
    batch_size = x.shape[0]
    shift_kv, scale_kv = (
        prompt_scale_shift_table[None, None].to(device=x.device, dtype=x.dtype)
        + prompt_timestep.reshape(batch_size, prompt_timestep.shape[1], 2, -1)
    ).unbind(dim=2)
    attn_input = comfy.ldm.common_dit.rms_norm(x) * (1 + q_scale) + q_shift
    encoder_hidden_states = context * (1 + scale_kv) + shift_kv
    return attn(attn_input, context=encoder_hidden_states, mask=attention_mask, transformer_options=transformer_options) * q_gate
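
# ---------------------------------------------------------------------------
# Editor's note: apply_cross_attention_adaln above modulates both sides of
# the cross-attention: the query input gets the usual (1 + scale) * norm(x)
# + shift, the text context gets its own shift/scale derived from
# prompt_scale_shift_table plus the global prompt timestep, and the result
# is gated. The KV-side modulation in isolation (attention call omitted;
# shapes are illustrative):

import torch

B, S, dim = 2, 3, 32
context = torch.randn(B, S, dim)              # text tokens (key/value side)
prompt_table = torch.zeros(2, dim)            # learned (2, dim) table
prompt_timestep = torch.randn(B, 1, 2 * dim)  # from compute_prompt_timestep

shift_kv, scale_kv = (prompt_table[None, None]
                      + prompt_timestep.reshape(B, 1, 2, dim)).unbind(dim=2)
modulated_context = context * (1 + scale_kv) + shift_kv  # broadcasts over S
# ---------------------------------------------------------------------------
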
def get_fractional_positions(indices_grid, max_pos):
    n_pos_dims = indices_grid.shape[1]
    assert n_pos_dims == len(max_pos), f'Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})'
@@ -649,9 +553,6 @@ class LTXBaseModel(torch.nn.Module, ABC):
        vae_scale_factors: tuple = (8, 32, 32),
        use_middle_indices_grid=False,
        timestep_scale_multiplier = 1000.0,
        caption_proj_before_connector=False,
        cross_attention_adaln=False,
        caption_projection_first_linear=True,
        dtype=None,
        device=None,
        operations=None,
@@ -678,9 +579,6 @@ class LTXBaseModel(torch.nn.Module, ABC):
        self.causal_temporal_positioning = causal_temporal_positioning
        self.operations = operations
        self.timestep_scale_multiplier = timestep_scale_multiplier
        self.caption_proj_before_connector = caption_proj_before_connector
        self.cross_attention_adaln = cross_attention_adaln
        self.caption_projection_first_linear = caption_projection_first_linear

        # Common dimensions
        self.inner_dim = num_attention_heads * attention_head_dim
@@ -708,37 +606,17 @@ class LTXBaseModel(torch.nn.Module, ABC):
            self.in_channels, self.inner_dim, bias=True, dtype=dtype, device=device
        )

        embedding_coefficient = ADALN_CROSS_ATTN_PARAMS_COUNT if self.cross_attention_adaln else ADALN_BASE_PARAMS_COUNT
        self.adaln_single = AdaLayerNormSingle(
            self.inner_dim, embedding_coefficient=embedding_coefficient, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations
            self.inner_dim, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations
        )

        if self.cross_attention_adaln:
            self.prompt_adaln_single = AdaLayerNormSingle(
                self.inner_dim, embedding_coefficient=2, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations
            )
        else:
            self.prompt_adaln_single = None

        if self.caption_proj_before_connector:
            if self.caption_projection_first_linear:
                self.caption_projection = NormSingleLinearTextProjection(
                    in_features=self.caption_channels,
                    hidden_size=self.inner_dim,
                    dtype=dtype,
                    device=device,
                    operations=self.operations,
                )
            else:
                self.caption_projection = lambda a: a
        else:
            self.caption_projection = PixArtAlphaTextProjection(
                in_features=self.caption_channels,
                hidden_size=self.inner_dim,
                dtype=dtype,
                device=device,
                operations=self.operations,
            )
        self.caption_projection = PixArtAlphaTextProjection(
            in_features=self.caption_channels,
            hidden_size=self.inner_dim,
            dtype=dtype,
            device=device,
            operations=self.operations,
        )

    @abstractmethod
    def _init_model_components(self, device, dtype, **kwargs):
@@ -760,16 +638,8 @@ class LTXBaseModel(torch.nn.Module, ABC):
        """Process input data. Must be implemented by subclasses."""
        pass

    def _build_guide_self_attention_mask(self, x, transformer_options, merged_args):
        """Build self-attention mask for per-guide attention attenuation.

        Base implementation returns None (no attenuation). Subclasses that
        support guide-based attention control should override this.
        """
        return None

    @abstractmethod
    def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, self_attention_mask=None, **kwargs):
    def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, **kwargs):
        """Process transformer blocks. Must be implemented by subclasses."""
        pass

@@ -784,9 +654,9 @@ class LTXBaseModel(torch.nn.Module, ABC):
        if grid_mask is not None:
            timestep = timestep[:, grid_mask]

        timestep_scaled = timestep * self.timestep_scale_multiplier
        timestep = timestep * self.timestep_scale_multiplier
        timestep, embedded_timestep = self.adaln_single(
            timestep_scaled.flatten(),
            timestep.flatten(),
            {"resolution": None, "aspect_ratio": None},
            batch_size=batch_size,
            hidden_dtype=hidden_dtype,
@@ -796,18 +666,14 @@ class LTXBaseModel(torch.nn.Module, ABC):
        timestep = timestep.view(batch_size, -1, timestep.shape[-1])
        embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.shape[-1])

        prompt_timestep = compute_prompt_timestep(
            self.prompt_adaln_single, timestep_scaled, batch_size, hidden_dtype
        )

        return timestep, embedded_timestep, prompt_timestep
        return timestep, embedded_timestep

    def _prepare_context(self, context, batch_size, x, attention_mask=None):
        """Prepare context for transformer blocks."""
        if self.caption_proj_before_connector is False:
            if self.caption_projection is not None:
                context = self.caption_projection(context)
            context = context.view(batch_size, -1, x.shape[-1])

        context = context.view(batch_size, -1, x.shape[-1])
        return context, attention_mask

    def _precompute_freqs_cis(
@@ -915,25 +781,16 @@ class LTXBaseModel(torch.nn.Module, ABC):
        merged_args.update(additional_args)

        # Prepare timestep and context
        timestep, embedded_timestep, prompt_timestep = self._prepare_timestep(timestep, batch_size, input_dtype, **merged_args)
        merged_args["prompt_timestep"] = prompt_timestep
        timestep, embedded_timestep = self._prepare_timestep(timestep, batch_size, input_dtype, **merged_args)
        context, attention_mask = self._prepare_context(context, batch_size, x, attention_mask)

        # Prepare attention mask and positional embeddings
        attention_mask = self._prepare_attention_mask(attention_mask, input_dtype)
        pe = self._prepare_positional_embeddings(pixel_coords, frame_rate, input_dtype)

        # Build self-attention mask for per-guide attenuation
        self_attention_mask = self._build_guide_self_attention_mask(
            x, transformer_options, merged_args
        )

        # Process transformer blocks
        x = self._process_transformer_blocks(
            x, context, attention_mask, timestep, pe,
            transformer_options=transformer_options,
            self_attention_mask=self_attention_mask,
            **merged_args,
            x, context, attention_mask, timestep, pe, transformer_options=transformer_options, **merged_args
        )

        # Process output
@@ -957,9 +814,7 @@ class LTXVModel(LTXBaseModel):
        causal_temporal_positioning=False,
        vae_scale_factors=(8, 32, 32),
        use_middle_indices_grid=False,
        timestep_scale_multiplier=1000.0,
        caption_proj_before_connector=False,
        cross_attention_adaln=False,
        timestep_scale_multiplier = 1000.0,
        dtype=None,
        device=None,
        operations=None,
@@ -978,8 +833,6 @@ class LTXVModel(LTXBaseModel):
            vae_scale_factors=vae_scale_factors,
            use_middle_indices_grid=use_middle_indices_grid,
            timestep_scale_multiplier=timestep_scale_multiplier,
            caption_proj_before_connector=caption_proj_before_connector,
            cross_attention_adaln=cross_attention_adaln,
            dtype=dtype,
            device=device,
            operations=operations,
@@ -988,6 +841,7 @@ class LTXVModel(LTXBaseModel):

    def _init_model_components(self, device, dtype, **kwargs):
        """Initialize LTXV-specific components."""
        # No additional components needed for LTXV beyond base class
        pass

    def _init_transformer_blocks(self, device, dtype, **kwargs):
@@ -999,7 +853,6 @@ class LTXVModel(LTXBaseModel):
                self.num_attention_heads,
                self.attention_head_dim,
                context_dim=self.cross_attention_dim,
                cross_attention_adaln=self.cross_attention_adaln,
                dtype=dtype,
                device=device,
                operations=self.operations,
@@ -1037,257 +890,26 @@ class LTXVModel(LTXBaseModel):
            pixel_coords = pixel_coords[:, :, grid_mask, ...]

            kf_grid_mask = grid_mask[-keyframe_idxs.shape[2]:]

            # Compute per-guide surviving token counts from guide_attention_entries.
            # Each entry tracks one guide reference; they are appended in order and
            # their pre_filter_counts partition the kf_grid_mask.
            guide_entries = kwargs.get("guide_attention_entries", None)
            if guide_entries:
                total_pfc = sum(e["pre_filter_count"] for e in guide_entries)
                if total_pfc != len(kf_grid_mask):
                    raise ValueError(
                        f"guide pre_filter_counts ({total_pfc}) != "
                        f"keyframe grid mask length ({len(kf_grid_mask)})"
                    )
                resolved_entries = []
                offset = 0
                for entry in guide_entries:
                    pfc = entry["pre_filter_count"]
                    entry_mask = kf_grid_mask[offset:offset + pfc]
                    surviving = int(entry_mask.sum().item())
                    resolved_entries.append({
                        **entry,
                        "surviving_count": surviving,
                    })
                    offset += pfc
                additional_args["resolved_guide_entries"] = resolved_entries

            keyframe_idxs = keyframe_idxs[..., kf_grid_mask, :]
            pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs

            # Total surviving guide tokens (all guides)
            additional_args["num_guide_tokens"] = keyframe_idxs.shape[2]

        x = self.patchify_proj(x)
        return x, pixel_coords, additional_args

    def _build_guide_self_attention_mask(self, x, transformer_options, merged_args):
        """Build self-attention mask for per-guide attention attenuation.

        Reads resolved_guide_entries from merged_args (computed in _process_input)
        to build a log-space additive bias mask that attenuates noisy ↔ guide
        attention for each guide reference independently.

        Returns None if no attenuation is needed (all strengths == 1.0 and no
        spatial masks, or no guide tokens).
        """
        if isinstance(x, list):
            # AV model: x = [vx, ax]; use vx for token count and device
            total_tokens = x[0].shape[1]
            device = x[0].device
            dtype = x[0].dtype
        else:
            total_tokens = x.shape[1]
            device = x.device
            dtype = x.dtype

        num_guide_tokens = merged_args.get("num_guide_tokens", 0)
        if num_guide_tokens == 0:
            return None

        resolved_entries = merged_args.get("resolved_guide_entries", None)
        if not resolved_entries:
            return None

        # Check if any attenuation is actually needed
        needs_attenuation = any(
            e["strength"] < 1.0 or e.get("pixel_mask") is not None
            for e in resolved_entries
        )
        if not needs_attenuation:
            return None

        # Build per-guide-token weights for all tracked guide tokens.
        # Guides are appended in order at the end of the sequence.
        guide_start = total_tokens - num_guide_tokens
        all_weights = []
        total_tracked = 0

        for entry in resolved_entries:
            surviving = entry["surviving_count"]
            if surviving == 0:
                continue

            strength = entry["strength"]
            pixel_mask = entry.get("pixel_mask")
            latent_shape = entry.get("latent_shape")

            if pixel_mask is not None and latent_shape is not None:
                f_lat, h_lat, w_lat = latent_shape
                per_token = self._downsample_mask_to_latent(
                    pixel_mask.to(device=device, dtype=dtype),
                    f_lat, h_lat, w_lat,
                )
                # per_token shape: (B, f_lat*h_lat*w_lat).
                # Collapse batch dim — the mask is assumed identical across the
                # batch; validate and take the first element to get (1, tokens).
                if per_token.shape[0] > 1:
                    ref = per_token[0]
                    for bi in range(1, per_token.shape[0]):
                        if not torch.equal(ref, per_token[bi]):
                            logger.warning(
                                "pixel_mask differs across batch elements; "
                                "using first element only."
                            )
                            break
                    per_token = per_token[:1]
                # `surviving` is the post-grid_mask token count.
                # Clamp to surviving to handle any mismatch safely.
                n_weights = min(per_token.shape[1], surviving)
                weights = per_token[:, :n_weights] * strength  # (1, n_weights)
            else:
                weights = torch.full(
                    (1, surviving), strength, device=device, dtype=dtype
                )

            all_weights.append(weights)
            total_tracked += weights.shape[1]

        if not all_weights:
            return None

        # Concatenate per-token weights for all tracked guides
        tracked_weights = torch.cat(all_weights, dim=1)  # (1, total_tracked)

        # Check if any weight is actually < 1.0 (otherwise no attenuation needed)
        if (tracked_weights >= 1.0).all():
            return None

        # Build the mask: guide tokens are at the end of the sequence.
        # Tracked guides come first (in order), untracked follow.
        return self._build_self_attention_mask(
            total_tokens, num_guide_tokens, total_tracked,
            tracked_weights, guide_start, device, dtype,
        )

    @staticmethod
    def _downsample_mask_to_latent(mask, f_lat, h_lat, w_lat):
        """Downsample a pixel-space mask to per-token latent weights.

        Args:
            mask: (B, 1, F_pix, H_pix, W_pix) pixel-space mask with values in [0, 1].
            f_lat: Number of latent frames (pre-dilation original count).
            h_lat: Latent height (pre-dilation original height).
            w_lat: Latent width (pre-dilation original width).

        Returns:
            (B, F_lat * H_lat * W_lat) flattened per-token weights.
        """
        b = mask.shape[0]
        f_pix = mask.shape[2]

        # Spatial downsampling: area interpolation per frame
        spatial_down = torch.nn.functional.interpolate(
            rearrange(mask, "b 1 f h w -> (b f) 1 h w"),
            size=(h_lat, w_lat),
            mode="area",
        )
        spatial_down = rearrange(spatial_down, "(b f) 1 h w -> b 1 f h w", b=b)

        # Temporal downsampling: first pixel frame maps to first latent frame,
        # remaining pixel frames are averaged in groups for causal temporal structure.
        first_frame = spatial_down[:, :, :1, :, :]
        if f_pix > 1 and f_lat > 1:
            remaining_pix = f_pix - 1
            remaining_lat = f_lat - 1
            t = remaining_pix // remaining_lat
            if t < 1:
                # Fewer pixel frames than latent frames — upsample by repeating
                # the available pixel frames via nearest interpolation.
                rest_flat = rearrange(
                    spatial_down[:, :, 1:, :, :],
                    "b 1 f h w -> (b h w) 1 f",
                )
                rest_up = torch.nn.functional.interpolate(
                    rest_flat, size=remaining_lat, mode="nearest",
                )
                rest = rearrange(
                    rest_up, "(b h w) 1 f -> b 1 f h w",
                    b=b, h=h_lat, w=w_lat,
                )
            else:
                # Trim trailing pixel frames that don't fill a complete group
                usable = remaining_lat * t
                rest = rearrange(
                    spatial_down[:, :, 1:1 + usable, :, :],
                    "b 1 (f t) h w -> b 1 f t h w",
                    t=t,
                )
                rest = rest.mean(dim=3)
            latent_mask = torch.cat([first_frame, rest], dim=2)
        elif f_lat > 1:
            # Single pixel frame but multiple latent frames — repeat the
            # single frame across all latent frames.
            latent_mask = first_frame.expand(-1, -1, f_lat, -1, -1)
        else:
            latent_mask = first_frame

        return rearrange(latent_mask, "b 1 f h w -> b (f h w)")
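
# ---------------------------------------------------------------------------
# Editor's note: _downsample_mask_to_latent above keeps the first pixel
# frame as the first latent frame and averages the remaining pixel frames in
# groups of t = (f_pix - 1) // (f_lat - 1), matching the causal temporal
# layout of the video VAE. Toy example: 9 pixel frames into 3 latent frames,
# so frames 1-8 are averaged in two groups of four.

import torch
from einops import rearrange

f_pix, f_lat = 9, 3
m = torch.arange(f_pix, dtype=torch.float32).view(1, 1, f_pix, 1, 1)  # fake mask

t = (f_pix - 1) // (f_lat - 1)  # 8 // 2 = 4 frames per group
first = m[:, :, :1]
rest = rearrange(m[:, :, 1:1 + (f_lat - 1) * t], "b 1 (f t) h w -> b 1 f t h w", t=t).mean(dim=3)
latent = torch.cat([first, rest], dim=2)
print(latent.flatten().tolist())  # [0.0, 2.5, 6.5]
# ---------------------------------------------------------------------------
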
    @staticmethod
    def _build_self_attention_mask(total_tokens, num_guide_tokens, tracked_count,
                                   tracked_weights, guide_start, device, dtype):
        """Build a log-space additive self-attention bias mask.

        Attenuates attention between noisy tokens and tracked guide tokens.
        Untracked guide tokens (at the end of the guide portion) keep full attention.

        Args:
            total_tokens: Total sequence length.
            num_guide_tokens: Total guide tokens (all guides) at end of sequence.
            tracked_count: Number of tracked guide tokens (first in the guide portion).
            tracked_weights: (1, tracked_count) tensor, values in [0, 1].
            guide_start: Index where guide tokens begin in the sequence.
            device: Target device.
            dtype: Target dtype.

        Returns:
            (1, 1, total_tokens, total_tokens) additive bias mask.
            0.0 = full attention, negative = attenuated, finfo.min = effectively fully masked.
        """
        finfo = torch.finfo(dtype)
        mask = torch.zeros((1, 1, total_tokens, total_tokens), device=device, dtype=dtype)
        tracked_end = guide_start + tracked_count

        # Convert weights to log-space bias
        w = tracked_weights.to(device=device, dtype=dtype)  # (1, tracked_count)
        log_w = torch.full_like(w, finfo.min)
        positive_mask = w > 0
        if positive_mask.any():
            log_w[positive_mask] = torch.log(w[positive_mask].clamp(min=finfo.tiny))

        # noisy → tracked guides: each noisy row gets the same per-guide weight
        mask[:, :, :guide_start, guide_start:tracked_end] = log_w.view(1, 1, 1, -1)
        # tracked guides → noisy: each guide row broadcasts its weight across noisy cols
        mask[:, :, guide_start:tracked_end, :guide_start] = log_w.view(1, 1, -1, 1)

        return mask
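
# ---------------------------------------------------------------------------
# Editor's note: the guide mask above works in log space. Adding log(w) to an
# attention logit before the softmax multiplies that position's unnormalized
# weight by exactly w, so strength 0.5 halves the attention a noisy token
# pays to a guide token. A minimal demonstration with toy logits:

import torch

logits = torch.tensor([1.0, 2.0, 3.0])
w = 0.5  # attenuation strength for the last (guide) position

bias = torch.tensor([0.0, 0.0, float(torch.log(torch.tensor(w)))])
p_plain = torch.softmax(logits, dim=0)
p_biased = torch.softmax(logits + bias, dim=0)

# The guide position's odds against any unbiased position scale by exactly w:
assert torch.allclose(p_biased[2] / p_biased[0], w * p_plain[2] / p_plain[0])
# ---------------------------------------------------------------------------
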
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, self_attention_mask=None, **kwargs):
|
||||
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs):
|
||||
"""Process transformer blocks for LTXV."""
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
prompt_timestep = kwargs.get("prompt_timestep", None)
|
||||
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"], self_attention_mask=args.get("self_attention_mask"), prompt_timestep=args.get("prompt_timestep"))
|
||||
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options, "self_attention_mask": self_attention_mask, "prompt_timestep": prompt_timestep}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||
x = out["img"]
|
||||
else:
|
||||
x = block(
|
||||
@@ -1297,8 +919,6 @@ class LTXVModel(LTXBaseModel):
|
||||
timestep=timestep,
|
||||
pe=pe,
|
||||
transformer_options=transformer_options,
|
||||
self_attention_mask=self_attention_mask,
|
||||
prompt_timestep=prompt_timestep,
|
||||
)
|
||||
|
||||
return x
|
||||
|
||||
@@ -13,7 +13,7 @@ from comfy.ldm.lightricks.vae.causal_audio_autoencoder import (
|
||||
CausalityAxis,
|
||||
CausalAudioAutoencoder,
|
||||
)
|
||||
from comfy.ldm.lightricks.vocoders.vocoder import Vocoder, VocoderWithBWE
|
||||
from comfy.ldm.lightricks.vocoders.vocoder import Vocoder
|
||||
|
||||
LATENT_DOWNSAMPLE_FACTOR = 4
|
||||
|
||||
@@ -141,10 +141,7 @@ class AudioVAE(torch.nn.Module):
|
||||
vocoder_sd = utils.state_dict_prefix_replace(state_dict, {"vocoder.": ""}, filter_keys=True)
|
||||
|
||||
self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder)
|
||||
if "bwe" in component_config.vocoder:
|
||||
self.vocoder = VocoderWithBWE(config=component_config.vocoder)
|
||||
else:
|
||||
self.vocoder = Vocoder(config=component_config.vocoder)
|
||||
self.vocoder = Vocoder(config=component_config.vocoder)
|
||||
|
||||
self.autoencoder.load_state_dict(vae_sd, strict=False)
|
||||
self.vocoder.load_state_dict(vocoder_sd, strict=False)
|
||||
|
||||
@@ -822,23 +822,26 @@ class CausalAudioAutoencoder(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
if config is None:
|
||||
config = self.get_default_config()
|
||||
config = self._guess_config()
|
||||
|
||||
# Extract encoder and decoder configs from the new format
|
||||
model_config = config.get("model", {}).get("params", {})
|
||||
variables_config = config.get("variables", {})
|
||||
|
||||
self.sampling_rate = model_config.get(
|
||||
"sampling_rate", config.get("sampling_rate", 16000)
|
||||
self.sampling_rate = variables_config.get(
|
||||
"sampling_rate",
|
||||
model_config.get("sampling_rate", config.get("sampling_rate", 16000)),
|
||||
)
|
||||
encoder_config = model_config.get("encoder", model_config.get("ddconfig", {}))
|
||||
decoder_config = model_config.get("decoder", encoder_config)
|
||||
|
||||
# Load mel spectrogram parameters
|
||||
self.mel_bins = encoder_config.get("mel_bins", 64)
|
||||
self.mel_hop_length = config.get("preprocessing", {}).get("stft", {}).get("hop_length", 160)
|
||||
self.n_fft = config.get("preprocessing", {}).get("stft", {}).get("filter_length", 1024)
|
||||
self.mel_hop_length = model_config.get("preprocessing", {}).get("stft", {}).get("hop_length", 160)
|
||||
self.n_fft = model_config.get("preprocessing", {}).get("stft", {}).get("filter_length", 1024)
|
||||
|
||||
# Store causality configuration at VAE level (not just in encoder internals)
|
||||
causality_axis_value = encoder_config.get("causality_axis", CausalityAxis.HEIGHT.value)
|
||||
causality_axis_value = encoder_config.get("causality_axis", CausalityAxis.WIDTH.value)
|
||||
self.causality_axis = CausalityAxis.str_to_enum(causality_axis_value)
|
||||
self.is_causal = self.causality_axis == CausalityAxis.HEIGHT
|
||||
|
||||
@@ -847,38 +850,44 @@ class CausalAudioAutoencoder(nn.Module):
|
||||
|
||||
self.per_channel_statistics = processor()
|
||||
|
||||
def get_default_config(self):
|
||||
ddconfig = {
|
||||
"double_z": True,
|
||||
"mel_bins": 64,
|
||||
"z_channels": 8,
|
||||
"resolution": 256,
|
||||
"downsample_time": False,
|
||||
"in_channels": 2,
|
||||
"out_ch": 2,
|
||||
def _guess_config(self):
|
||||
encoder_config = {
|
||||
# Required parameters - based on ltx-video-av-1679000 model metadata
|
||||
"ch": 128,
|
||||
"ch_mult": [1, 2, 4],
|
||||
"out_ch": 8,
|
||||
"ch_mult": [1, 2, 4], # Based on metadata: [1, 2, 4] not [1, 2, 4, 8]
|
||||
"num_res_blocks": 2,
|
||||
"attn_resolutions": [],
|
||||
"attn_resolutions": [], # Based on metadata: empty list, no attention
|
||||
"dropout": 0.0,
|
||||
"mid_block_add_attention": False,
|
||||
"resamp_with_conv": True,
|
||||
"in_channels": 2, # stereo
|
||||
"resolution": 256,
|
||||
"z_channels": 8,
|
||||
"double_z": True,
|
||||
"attn_type": "vanilla",
|
||||
"mid_block_add_attention": False, # Based on metadata: false
|
||||
"norm_type": "pixel",
|
||||
"causality_axis": "height",
|
||||
"causality_axis": "height", # Based on metadata
|
||||
"mel_bins": 64, # Based on metadata: mel_bins = 64
|
||||
}
|
||||
|
||||
decoder_config = {
|
||||
# Inherits encoder config, can override specific params
|
||||
**encoder_config,
|
||||
"out_ch": 2, # Stereo audio output (2 channels)
|
||||
"give_pre_end": False,
|
||||
"tanh_out": False,
|
||||
}
|
||||
|
||||
config = {
|
||||
"_class_name": "CausalAudioAutoencoder",
|
||||
"sampling_rate": 16000,
|
||||
"model": {
|
||||
"params": {
|
||||
"ddconfig": ddconfig,
|
||||
"sampling_rate": 16000,
|
||||
"encoder": encoder_config,
|
||||
"decoder": decoder_config,
|
||||
}
|
||||
},
|
||||
"preprocessing": {
|
||||
"stft": {
|
||||
"filter_length": 1024,
|
||||
"hop_length": 160,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return config
|
||||
|
||||
@@ -15,9 +15,6 @@ from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed
|
||||
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
def in_meta_context():
|
||||
return torch.device("meta") == torch.empty(0).device
|
||||
|
||||
def mark_conv3d_ended(module):
|
||||
tid = threading.get_ident()
|
||||
for _, m in module.named_modules():
|
||||
@@ -353,10 +350,6 @@ class Decoder(nn.Module):
|
||||
output_channel = output_channel * block_params.get("multiplier", 2)
|
||||
if block_name == "compress_all":
|
||||
output_channel = output_channel * block_params.get("multiplier", 1)
|
||||
if block_name == "compress_space":
|
||||
output_channel = output_channel * block_params.get("multiplier", 1)
|
||||
if block_name == "compress_time":
|
||||
output_channel = output_channel * block_params.get("multiplier", 1)
|
||||
|
||||
self.conv_in = make_conv_nd(
|
||||
dims,
|
||||
@@ -402,21 +395,17 @@ class Decoder(nn.Module):
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_time":
|
||||
output_channel = output_channel // block_params.get("multiplier", 1)
|
||||
block = DepthToSpaceUpsample(
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
stride=(2, 1, 1),
|
||||
out_channels_reduction_factor=block_params.get("multiplier", 1),
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_space":
|
||||
output_channel = output_channel // block_params.get("multiplier", 1)
|
||||
block = DepthToSpaceUpsample(
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
stride=(1, 2, 2),
|
||||
out_channels_reduction_factor=block_params.get("multiplier", 1),
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_all":
|
||||
@@ -466,15 +455,6 @@ class Decoder(nn.Module):
                output_channel * 2, 0, operations=ops,
            )
            self.last_scale_shift_table = nn.Parameter(torch.empty(2, output_channel))
        else:
            self.register_buffer(
                "last_scale_shift_table",
                torch.tensor(
                    [0.0, 0.0],
                    device="cpu" if in_meta_context() else None
                ).unsqueeze(1).expand(2, output_channel),
                persistent=False,
            )

    # def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
@@ -903,15 +883,6 @@ class ResnetBlock3D(nn.Module):
            self.scale_shift_table = nn.Parameter(
                torch.randn(4, in_channels) / in_channels**0.5
            )
        else:
            self.register_buffer(
                "scale_shift_table",
                torch.tensor(
                    [0.0, 0.0, 0.0, 0.0],
                    device="cpu" if in_meta_context() else None
                ).unsqueeze(1).expand(4, in_channels),
                persistent=False,
            )

        self.temporal_cache_state = {}
@@ -1041,6 +1012,9 @@ class processor(nn.Module):
        super().__init__()
        self.register_buffer("std-of-means", torch.empty(128))
        self.register_buffer("mean-of-means", torch.empty(128))
        self.register_buffer("mean-of-stds", torch.empty(128))
        self.register_buffer("mean-of-stds_over_std-of-means", torch.empty(128))
        self.register_buffer("channel", torch.empty(128))

    def un_normalize(self, x):
        return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)) + self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)
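
# The hyphenated buffer names above ("std-of-means", ...) are not valid Python
# identifiers, so they cannot be read back as attributes and must go through
# get_buffer(). A minimal round-trip sketch, not part of the diff, assuming a
# normalize() defined as the inverse of un_normalize():
p = processor()
p.get_buffer("std-of-means").fill_(2.0)
p.get_buffer("mean-of-means").fill_(1.0)
z = torch.randn(1, 128, 4, 8, 8)           # (B, C, T, H, W) latent
x = p.un_normalize(z)                      # x = z * std + mean
assert torch.allclose((x - 1.0) / 2.0, z)  # inverting by hand recovers z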
@@ -1053,12 +1027,9 @@ class VideoVAE(nn.Module):
        super().__init__()

        if config is None:
            config = self.get_default_config(version)
            config = self.guess_config(version)

        self.config = config
        self.timestep_conditioning = config.get("timestep_conditioning", False)
        self.decode_noise_scale = config.get("decode_noise_scale", 0.025)
        self.decode_timestep = config.get("decode_timestep", 0.05)
        double_z = config.get("double_z", True)
        latent_log_var = config.get(
            "latent_log_var", "per_channel" if double_z else "none"
@@ -1073,7 +1044,6 @@
            latent_log_var=latent_log_var,
            norm_layer=config.get("norm_layer", "group_norm"),
            spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
            base_channels=config.get("encoder_base_channels", 128),
        )

        self.decoder = Decoder(
@@ -1081,7 +1051,6 @@
            in_channels=config["latent_channels"],
            out_channels=config.get("out_channels", 3),
            blocks=config.get("decoder_blocks", config.get("blocks")),
            base_channels=config.get("decoder_base_channels", 128),
            patch_size=config.get("patch_size", 1),
            norm_layer=config.get("norm_layer", "group_norm"),
            causal=config.get("causal_decoder", False),
@@ -1091,7 +1060,7 @@

        self.per_channel_statistics = processor()

    def get_default_config(self, version):
    def guess_config(self, version):
        if version == 0:
            config = {
                "_class_name": "CausalVideoAutoencoder",
@@ -1198,7 +1167,8 @@
        means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
        return self.per_channel_statistics.normalize(means)

    def decode(self, x):
    def decode(self, x, timestep=0.05, noise_scale=0.025):
        if self.timestep_conditioning:  # TODO: seed
            x = torch.randn_like(x) * self.decode_noise_scale + (1.0 - self.decode_noise_scale) * x
            return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=self.decode_timestep)
            x = torch.randn_like(x) * noise_scale + (1.0 - noise_scale) * x
            return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=timestep)

@@ -3,7 +3,6 @@ import torch.nn.functional as F
import torch.nn as nn
import comfy.ops
import numpy as np
import math

ops = comfy.ops.disable_weight_init

@@ -13,307 +12,6 @@ def get_padding(kernel_size, dilation=1):
    return int((kernel_size * dilation - dilation) / 2)


# ---------------------------------------------------------------------------
# Anti-aliased resampling helpers (kaiser-sinc filters) for BigVGAN v2
# Adopted from https://github.com/NVIDIA/BigVGAN
# ---------------------------------------------------------------------------


def _sinc(x: torch.Tensor):
    return torch.where(
        x == 0,
        torch.tensor(1.0, device=x.device, dtype=x.dtype),
        torch.sin(math.pi * x) / math.pi / x,
    )


def kaiser_sinc_filter1d(cutoff, half_width, kernel_size):
    even = kernel_size % 2 == 0
    half_size = kernel_size // 2
    delta_f = 4 * half_width
    A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
    if A > 50.0:
        beta = 0.1102 * (A - 8.7)
    elif A >= 21.0:
        beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
    else:
        beta = 0.0
    window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
    if even:
        time = torch.arange(-half_size, half_size) + 0.5
    else:
        time = torch.arange(kernel_size) - half_size
    if cutoff == 0:
        filter_ = torch.zeros_like(time)
    else:
        filter_ = 2 * cutoff * window * _sinc(2 * cutoff * time)
        filter_ /= filter_.sum()
    filter = filter_.view(1, 1, kernel_size)
    return filter
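
# The A/beta branches above are the standard Kaiser window design equations: A
# estimates the achievable stopband attenuation for the transition bandwidth
# delta_f, and beta is Kaiser's empirical fit for that attenuation. A worked
# instance (values rounded) for the defaults DownSample1d(ratio=2) passes in,
# namely cutoff=0.25, half_width=0.3, kernel_size=12:
#     half_size = 12 // 2 = 6
#     delta_f   = 4 * 0.3 = 1.2
#     A         = 2.285 * 5 * pi * 1.2 + 7.95 ~= 51.0   (> 50, first branch)
#     beta      = 0.1102 * (51.0 - 8.7)       ~= 4.66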

class LowPassFilter1d(nn.Module):
    def __init__(
        self,
        cutoff=0.5,
        half_width=0.6,
        stride=1,
        padding=True,
        padding_mode="replicate",
        kernel_size=12,
    ):
        super().__init__()
        if cutoff < -0.0:
            raise ValueError("Minimum cutoff must be larger than zero.")
        if cutoff > 0.5:
            raise ValueError("A cutoff above 0.5 does not make sense.")
        self.kernel_size = kernel_size
        self.even = kernel_size % 2 == 0
        self.pad_left = kernel_size // 2 - int(self.even)
        self.pad_right = kernel_size // 2
        self.stride = stride
        self.padding = padding
        self.padding_mode = padding_mode
        filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
        self.register_buffer("filter", filter)

    def forward(self, x):
        _, C, _ = x.shape
        if self.padding:
            x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
        return F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)

class UpSample1d(nn.Module):
    def __init__(self, ratio=2, kernel_size=None, persistent=True, window_type="kaiser"):
        super().__init__()
        self.ratio = ratio
        self.stride = ratio

        if window_type == "hann":
            # Hann-windowed sinc filter — identical to torchaudio.functional.resample
            # with its default parameters (rolloff=0.99, lowpass_filter_width=6).
            # Uses replicate boundary padding, matching the reference resampler exactly.
            rolloff = 0.99
            lowpass_filter_width = 6
            width = math.ceil(lowpass_filter_width / rolloff)
            self.kernel_size = 2 * width * ratio + 1
            self.pad = width
            self.pad_left = 2 * width * ratio
            self.pad_right = self.kernel_size - ratio
            t = (torch.arange(self.kernel_size) / ratio - width) * rolloff
            t_clamped = t.clamp(-lowpass_filter_width, lowpass_filter_width)
            window = torch.cos(t_clamped * math.pi / lowpass_filter_width / 2) ** 2
            filter = (torch.sinc(t) * window * rolloff / ratio).view(1, 1, -1)
        else:
            # Kaiser-windowed sinc filter (BigVGAN default).
            self.kernel_size = (
                int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
            )
            self.pad = self.kernel_size // ratio - 1
            self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
            self.pad_right = (
                self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
            )
            filter = kaiser_sinc_filter1d(
                cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
            )

        self.register_buffer("filter", filter, persistent=persistent)

    def forward(self, x):
        _, C, _ = x.shape
        x = F.pad(x, (self.pad, self.pad), mode="replicate")
        x = self.ratio * F.conv_transpose1d(
            x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
        )
        x = x[..., self.pad_left : -self.pad_right]
        return x
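
# The "hann" branch above claims parity with torchaudio's default resampler; a
# quick check along those lines (a sketch, assuming torchaudio is installed,
# not part of the diff):
import torchaudio

up = UpSample1d(ratio=2, window_type="hann")
wav = torch.randn(1, 1, 16000)  # (B, C, T) at 16 kHz
ref = torchaudio.functional.resample(wav, orig_freq=16000, new_freq=32000)
print(torch.allclose(up(wav), ref, atol=1e-5))  # expected True up to edge padding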

class DownSample1d(nn.Module):
    def __init__(self, ratio=2, kernel_size=None):
        super().__init__()
        self.ratio = ratio
        self.kernel_size = (
            int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
        )
        self.lowpass = LowPassFilter1d(
            cutoff=0.5 / ratio,
            half_width=0.6 / ratio,
            stride=ratio,
            kernel_size=self.kernel_size,
        )

    def forward(self, x):
        return self.lowpass(x)


class Activation1d(nn.Module):
    def __init__(
        self,
        activation,
        up_ratio=2,
        down_ratio=2,
        up_kernel_size=12,
        down_kernel_size=12,
    ):
        super().__init__()
        self.act = activation
        self.upsample = UpSample1d(up_ratio, up_kernel_size)
        self.downsample = DownSample1d(down_ratio, down_kernel_size)

    def forward(self, x):
        x = self.upsample(x)
        x = self.act(x)
        x = self.downsample(x)
        return x

# ---------------------------------------------------------------------------
# BigVGAN v2 activations (Snake / SnakeBeta)
# ---------------------------------------------------------------------------


class Snake(nn.Module):
    def __init__(
        self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True
    ):
        super().__init__()
        self.alpha_logscale = alpha_logscale
        self.alpha = nn.Parameter(
            torch.zeros(in_features)
            if alpha_logscale
            else torch.ones(in_features) * alpha
        )
        self.alpha.requires_grad = alpha_trainable
        self.eps = 1e-9

    def forward(self, x):
        a = self.alpha.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            a = torch.exp(a)
        return x + (1.0 / (a + self.eps)) * torch.sin(x * a).pow(2)


class SnakeBeta(nn.Module):
    def __init__(
        self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True
    ):
        super().__init__()
        self.alpha_logscale = alpha_logscale
        self.alpha = nn.Parameter(
            torch.zeros(in_features)
            if alpha_logscale
            else torch.ones(in_features) * alpha
        )
        self.alpha.requires_grad = alpha_trainable
        self.beta = nn.Parameter(
            torch.zeros(in_features)
            if alpha_logscale
            else torch.ones(in_features) * alpha
        )
        self.beta.requires_grad = alpha_trainable
        self.eps = 1e-9

    def forward(self, x):
        a = self.alpha.unsqueeze(0).unsqueeze(-1)
        b = self.beta.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            a = torch.exp(a)
            b = torch.exp(b)
        return x + (1.0 / (b + self.eps)) * torch.sin(x * a).pow(2)

# ---------------------------------------------------------------------------
# BigVGAN v2 AMPBlock (Anti-aliased Multi-Periodicity)
# ---------------------------------------------------------------------------


class AMPBlock1(torch.nn.Module):
    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), activation="snake"):
        super().__init__()
        act_cls = SnakeBeta if activation == "snakebeta" else Snake
        self.convs1 = nn.ModuleList(
            [
                ops.Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=dilation[0],
                    padding=get_padding(kernel_size, dilation[0]),
                ),
                ops.Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=dilation[1],
                    padding=get_padding(kernel_size, dilation[1]),
                ),
                ops.Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=dilation[2],
                    padding=get_padding(kernel_size, dilation[2]),
                ),
            ]
        )

        self.convs2 = nn.ModuleList(
            [
                ops.Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=1,
                    padding=get_padding(kernel_size, 1),
                ),
                ops.Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=1,
                    padding=get_padding(kernel_size, 1),
                ),
                ops.Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=1,
                    padding=get_padding(kernel_size, 1),
                ),
            ]
        )

        self.acts1 = nn.ModuleList(
            [Activation1d(act_cls(channels)) for _ in range(len(self.convs1))]
        )
        self.acts2 = nn.ModuleList(
            [Activation1d(act_cls(channels)) for _ in range(len(self.convs2))]
        )

    def forward(self, x):
        for c1, c2, a1, a2 in zip(self.convs1, self.convs2, self.acts1, self.acts2):
            xt = a1(x)
            xt = c1(xt)
            xt = a2(xt)
            xt = c2(xt)
            x = x + xt
        return x

# ---------------------------------------------------------------------------
# HiFi-GAN residual blocks
# ---------------------------------------------------------------------------


class ResBlock1(torch.nn.Module):
    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
        super(ResBlock1, self).__init__()
@@ -421,7 +119,6 @@ class Vocoder(torch.nn.Module):
    """
    Vocoder model for synthesizing audio from spectrograms, based on: https://github.com/jik876/hifi-gan.

    Supports both HiFi-GAN (resblock "1"/"2") and BigVGAN v2 (resblock "AMP1").
    """

    def __init__(self, config=None):
@@ -431,39 +128,19 @@ class Vocoder(torch.nn.Module):
            config = self.get_default_config()

        resblock_kernel_sizes = config.get("resblock_kernel_sizes", [3, 7, 11])
        upsample_rates = config.get("upsample_rates", [5, 4, 2, 2, 2])
        upsample_kernel_sizes = config.get("upsample_kernel_sizes", [16, 16, 8, 4, 4])
        upsample_rates = config.get("upsample_rates", [6, 5, 2, 2, 2])
        upsample_kernel_sizes = config.get("upsample_kernel_sizes", [16, 15, 8, 4, 4])
        resblock_dilation_sizes = config.get("resblock_dilation_sizes", [[1, 3, 5], [1, 3, 5], [1, 3, 5]])
        upsample_initial_channel = config.get("upsample_initial_channel", 1024)
        stereo = config.get("stereo", True)
        activation = config.get("activation", "snake")
        use_bias_at_final = config.get("use_bias_at_final", True)
        resblock = config.get("resblock", "1")

        # "output_sample_rate" is not present in recent checkpoint configs.
        # When absent (None), AudioVAE.output_sample_rate computes it as:
        #     sample_rate * vocoder.upsample_factor / mel_hop_length
        # where upsample_factor = product of all upsample stride lengths,
        # and mel_hop_length is loaded from the autoencoder config at
        # preprocessing.stft.hop_length (see CausalAudioAutoencoder).
        self.output_sample_rate = config.get("output_sample_rate")
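        # Worked instance of that fallback, using the defaults visible in this
        # file (illustrative arithmetic only):
        #     upsample_factor    = 6 * 5 * 2 * 2 * 2 = 240
        #     mel_hop_length     = 160   (preprocessing.stft.hop_length above)
        #     sample_rate        = 16000
        #     output_sample_rate = 16000 * 240 / 160 = 24000 Hz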
        self.resblock = config.get("resblock", "1")
        self.use_tanh_at_final = config.get("use_tanh_at_final", True)
        self.apply_final_activation = config.get("apply_final_activation", True)
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)

        in_channels = 128 if stereo else 64
        self.conv_pre = ops.Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)

        if self.resblock == "1":
            resblock_cls = ResBlock1
        elif self.resblock == "2":
            resblock_cls = ResBlock2
        elif self.resblock == "AMP1":
            resblock_cls = AMPBlock1
        else:
            raise ValueError(f"Unknown resblock type: {self.resblock}")
        resblock_class = ResBlock1 if resblock == "1" else ResBlock2

        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
@@ -480,40 +157,25 @@
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
                if self.resblock == "AMP1":
                    self.resblocks.append(resblock_cls(ch, k, d, activation=activation))
                else:
                    self.resblocks.append(resblock_cls(ch, k, d))
            for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                self.resblocks.append(resblock_class(ch, k, d))

        out_channels = 2 if stereo else 1
        if self.resblock == "AMP1":
            act_cls = SnakeBeta if activation == "snakebeta" else Snake
            self.act_post = Activation1d(act_cls(ch))
        else:
            self.act_post = nn.LeakyReLU()

        self.conv_post = ops.Conv1d(
            ch, out_channels, 7, 1, padding=3, bias=use_bias_at_final
        )
        self.conv_post = ops.Conv1d(ch, out_channels, 7, 1, padding=3)

        self.upsample_factor = np.prod([self.ups[i].stride[0] for i in range(len(self.ups))])

    def get_default_config(self):
        """Generate default configuration for the vocoder."""

        config = {
            "resblock_kernel_sizes": [3, 7, 11],
            "upsample_rates": [5, 4, 2, 2, 2],
            "upsample_kernel_sizes": [16, 16, 8, 4, 4],
            "upsample_rates": [6, 5, 2, 2, 2],
            "upsample_kernel_sizes": [16, 15, 8, 4, 4],
            "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
            "upsample_initial_channel": 1024,
            "stereo": True,
            "resblock": "1",
            "activation": "snake",
            "use_bias_at_final": True,
            "use_tanh_at_final": True,
        }

        return config
@@ -534,10 +196,8 @@ class Vocoder(torch.nn.Module):
        assert x.shape[1] == 2, "Input must have 2 channels for stereo"
        x = torch.cat((x[:, 0, :, :], x[:, 1, :, :]), dim=1)
        x = self.conv_pre(x)

        for i in range(self.num_upsamples):
            if self.resblock != "AMP1":
                x = F.leaky_relu(x, LRELU_SLOPE)
            x = F.leaky_relu(x, LRELU_SLOPE)
            x = self.ups[i](x)
            xs = None
            for j in range(self.num_kernels):
@@ -546,167 +206,8 @@
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels

        x = self.act_post(x)
        x = F.leaky_relu(x)
        x = self.conv_post(x)

        if self.apply_final_activation:
            if self.use_tanh_at_final:
                x = torch.tanh(x)
            else:
                x = torch.clamp(x, -1, 1)
        x = torch.tanh(x)

        return x


class _STFTFn(nn.Module):
    """Implements STFT as a convolution with precomputed DFT × Hann-window bases.

    The DFT basis rows (real and imaginary parts interleaved) multiplied by the causal
    Hann window are stored as buffers and loaded from the checkpoint. Using the exact
    bfloat16 bases from training ensures the mel values fed to the BWE generator are
    bit-identical to what it was trained on.
    """

    def __init__(self, filter_length: int, hop_length: int, win_length: int):
        super().__init__()
        self.hop_length = hop_length
        self.win_length = win_length
        n_freqs = filter_length // 2 + 1
        self.register_buffer("forward_basis", torch.zeros(n_freqs * 2, 1, filter_length))
        self.register_buffer("inverse_basis", torch.zeros(n_freqs * 2, 1, filter_length))

    def forward(self, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute magnitude and phase spectrogram from a batch of waveforms.

        Applies causal (left-only) padding of win_length - hop_length samples so that
        each output frame depends only on past and present input — no lookahead.
        The STFT is computed by convolving the padded signal with forward_basis.

        Args:
            y: Waveform tensor of shape (B, T).

        Returns:
            magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames).
            phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames).
                Computed in float32 for numerical stability, then cast back to
                the input dtype.
        """
        if y.dim() == 2:
            y = y.unsqueeze(1)  # (B, 1, T)
        left_pad = max(0, self.win_length - self.hop_length)  # causal: left-only
        y = F.pad(y, (left_pad, 0))
        spec = F.conv1d(y, self.forward_basis, stride=self.hop_length, padding=0)
        n_freqs = spec.shape[1] // 2
        real, imag = spec[:, :n_freqs], spec[:, n_freqs:]
        magnitude = torch.sqrt(real ** 2 + imag ** 2)
        phase = torch.atan2(imag.float(), real.float()).to(real.dtype)
        return magnitude, phase
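
# Although the bases here are loaded from the checkpoint, what they encode is
# the ordinary conv-STFT construction: row k of forward_basis is the k-th DFT
# basis vector times the analysis window. A from-scratch sketch of such a basis
# (illustrative only; the checkpoint's exact bfloat16 values are what guarantee
# bit-parity with training):
def make_forward_basis(filter_length=1024, win_length=1024):
    n_freqs = filter_length // 2 + 1
    fourier = torch.fft.fft(torch.eye(filter_length))  # DFT of unit impulses
    basis = torch.cat([fourier.real[:n_freqs], fourier.imag[:n_freqs]], dim=0)
    window = torch.hann_window(win_length)             # analysis window
    return (basis * window).unsqueeze(1)               # (2*n_freqs, 1, filter_length)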

class MelSTFT(nn.Module):
    """Causal log-mel spectrogram module whose buffers are loaded from the checkpoint.

    Computes a log-mel spectrogram by running the causal STFT (_STFTFn) on the input
    waveform and projecting the linear magnitude spectrum onto the mel filterbank.

    The module's state dict layout matches the 'mel_stft.*' keys stored in the checkpoint
    (mel_basis, stft_fn.forward_basis, stft_fn.inverse_basis).
    """

    def __init__(
        self,
        filter_length: int,
        hop_length: int,
        win_length: int,
        n_mel_channels: int,
        sampling_rate: int,
        mel_fmin: float,
        mel_fmax: float,
    ):
        super().__init__()
        self.stft_fn = _STFTFn(filter_length, hop_length, win_length)

        n_freqs = filter_length // 2 + 1
        self.register_buffer("mel_basis", torch.zeros(n_mel_channels, n_freqs))

    def mel_spectrogram(
        self, y: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute log-mel spectrogram and auxiliary spectral quantities.

        Args:
            y: Waveform tensor of shape (B, T).

        Returns:
            log_mel: Log-compressed mel spectrogram, shape (B, n_mel_channels, T_frames).
                Computed as log(clamp(mel_basis @ magnitude, min=1e-5)).
            magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames).
            phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames).
            energy: Per-frame energy (L2 norm over frequency), shape (B, T_frames).
        """
        magnitude, phase = self.stft_fn(y)
        energy = torch.norm(magnitude, dim=1)
        mel = torch.matmul(self.mel_basis.to(magnitude.dtype), magnitude)
        log_mel = torch.log(torch.clamp(mel, min=1e-5))
        return log_mel, magnitude, phase, energy

class VocoderWithBWE(torch.nn.Module):
    """Vocoder with bandwidth extension (BWE) for higher sample rate output.

    Chains a base vocoder (mel → low-rate waveform) with a BWE stage that upsamples
    to a higher rate. The BWE computes a mel spectrogram from the low-rate waveform.
    """

    def __init__(self, config):
        super().__init__()
        vocoder_config = config["vocoder"]
        bwe_config = config["bwe"]

        self.vocoder = Vocoder(config=vocoder_config)
        self.bwe_generator = Vocoder(
            config={**bwe_config, "apply_final_activation": False}
        )

        self.input_sample_rate = bwe_config["input_sampling_rate"]
        self.output_sample_rate = bwe_config["output_sampling_rate"]
        self.hop_length = bwe_config["hop_length"]

        self.mel_stft = MelSTFT(
            filter_length=bwe_config["n_fft"],
            hop_length=bwe_config["hop_length"],
            win_length=bwe_config["n_fft"],
            n_mel_channels=bwe_config["num_mels"],
            sampling_rate=bwe_config["input_sampling_rate"],
            mel_fmin=0.0,
            mel_fmax=bwe_config["input_sampling_rate"] / 2.0,
        )
        self.resampler = UpSample1d(
            ratio=bwe_config["output_sampling_rate"] // bwe_config["input_sampling_rate"],
            persistent=False,
            window_type="hann",
        )

    def _compute_mel(self, audio):
        """Compute log-mel spectrogram from waveform using causal STFT bases."""
        B, C, T = audio.shape
        flat = audio.reshape(B * C, -1)  # (B*C, T)
        mel, _, _, _ = self.mel_stft.mel_spectrogram(flat)  # (B*C, n_mels, T_frames)
        return mel.reshape(B, C, mel.shape[1], mel.shape[2])  # (B, C, n_mels, T_frames)

    def forward(self, mel_spec):
        x = self.vocoder(mel_spec)
        _, _, T_low = x.shape
        T_out = T_low * self.output_sample_rate // self.input_sample_rate

        remainder = T_low % self.hop_length
        if remainder != 0:
            x = F.pad(x, (0, self.hop_length - remainder))

        mel = self._compute_mel(x)
        residual = self.bwe_generator(mel)
        skip = self.resampler(x)
        assert residual.shape == skip.shape, f"residual {residual.shape} != skip {skip.shape}"

        return torch.clamp(residual + skip, -1, 1)[..., :T_out]
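
# Shape walk-through of the forward pass under assumed config values (24 kHz in,
# 48 kHz out; illustrative numbers, not taken from a checkpoint):
#     x        = vocoder(mel_spec)      -> (B, 2, T_low)       low-rate waveform
#     T_out    = T_low * 48000 // 24000 =  2 * T_low           final sample count
#     x is right-padded so hop_length divides T_low evenly (whole mel frames)
#     mel      = _compute_mel(x)        -> (B, 2, num_mels, T_frames)
#     residual = bwe_generator(mel)     -> (B, 2, 2 * T_low)   band-extension detail
#     skip     = resampler(x)           -> (B, 2, 2 * T_low)   Hann-sinc 2x upsample
#     output   = clamp(residual + skip, -1, 1)[..., :T_out]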

@@ -14,7 +14,6 @@ from comfy.ldm.flux.layers import EmbedND
from comfy.ldm.flux.math import apply_rope
import comfy.patcher_extension
import comfy.utils
from comfy.ldm.chroma_radiance.layers import NerfEmbedder


def invert_slices(slices, length):
@@ -859,267 +858,3 @@ class NextDiT(nn.Module):
        img = self.unpatchify(img, img_size, cap_size, return_tensor=x_is_tensor)[:, :, :h, :w]
        return -img

#############################################################################
#                     Pixel Space Decoder Components                       #
#############################################################################

def _modulate_shift_scale(x, shift, scale):
    return x * (1 + scale) + shift


class PixelResBlock(nn.Module):
    """
    Residual block with AdaLN modulation, zero-initialised so it starts as
    an identity at the beginning of training.
    """

    def __init__(self, channels: int, dtype=None, device=None, operations=None):
        super().__init__()
        self.in_ln = operations.LayerNorm(channels, eps=1e-6, dtype=dtype, device=device)
        self.mlp = nn.Sequential(
            operations.Linear(channels, channels, bias=True, dtype=dtype, device=device),
            nn.SiLU(),
            operations.Linear(channels, channels, bias=True, dtype=dtype, device=device),
        )
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            operations.Linear(channels, 3 * channels, bias=True, dtype=dtype, device=device),
        )

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        shift, scale, gate = self.adaLN_modulation(y).chunk(3, dim=-1)
        h = _modulate_shift_scale(self.in_ln(x), shift, scale)
        h = self.mlp(h)
        return x + gate * h

class DCTFinalLayer(nn.Module):
    """Zero-initialised output projection (adopted from DiT)."""

    def __init__(self, model_channels: int, out_channels: int, dtype=None, device=None, operations=None):
        super().__init__()
        self.norm_final = operations.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.linear = operations.Linear(model_channels, out_channels, bias=True, dtype=dtype, device=device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(self.norm_final(x))

class SimpleMLPAdaLN(nn.Module):
    """
    Small MLP decoder head for the pixel-space variant.

    Takes per-patch pixel values and a per-patch conditioning vector from the
    transformer backbone and predicts the denoised pixel values.

        x : [B*N, P^2, C]  – noisy pixel values per patch position
        c : [B*N, dim]     – backbone hidden state per patch (conditioning)
        → [B*N, P^2, C]
    """

    def __init__(
        self,
        in_channels: int,
        model_channels: int,
        out_channels: int,
        z_channels: int,
        num_res_blocks: int,
        max_freqs: int = 8,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        self.dtype = dtype

        # Project backbone hidden state → per-patch conditioning
        self.cond_embed = operations.Linear(z_channels, model_channels, dtype=dtype, device=device)

        # Input projection with DCT positional encoding
        self.input_embedder = NerfEmbedder(
            in_channels=in_channels,
            hidden_size_input=model_channels,
            max_freqs=max_freqs,
            dtype=dtype,
            device=device,
            operations=operations,
        )

        # Residual blocks
        self.res_blocks = nn.ModuleList([
            PixelResBlock(model_channels, dtype=dtype, device=device, operations=operations) for _ in range(num_res_blocks)
        ])

        # Output projection
        self.final_layer = DCTFinalLayer(model_channels, out_channels, dtype=dtype, device=device, operations=operations)

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        # x: [B*N, 1, P^2*C], c: [B*N, dim]
        original_dtype = x.dtype
        weight_dtype = self.cond_embed.weight.dtype if hasattr(self.cond_embed, "weight") and self.cond_embed.weight is not None else (self.dtype or x.dtype)
        x = self.input_embedder(x)  # [B*N, 1, model_channels]
        y = self.cond_embed(c.to(weight_dtype)).unsqueeze(1)  # [B*N, 1, model_channels]
        x = x.to(weight_dtype)
        for block in self.res_blocks:
            x = block(x, y)
        return self.final_layer(x).to(original_dtype)  # [B*N, 1, P^2*C]

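# A minimal shape check for the decoder head (a sketch; it passes torch.nn as
# the operations wrapper, which assumes NerfEmbedder accepts the same stand-in):
head = SimpleMLPAdaLN(
    in_channels=16,    # P^2 * C for patch_size=2, C=4
    model_channels=64,
    out_channels=16,
    z_channels=128,    # backbone dim
    num_res_blocks=2,
    operations=torch.nn,
)
pixels = torch.randn(8, 1, 16)   # [B*N, 1, P^2*C]
cond = torch.randn(8, 128)       # [B*N, dim]
print(head(pixels, cond).shape)  # torch.Size([8, 1, 16])
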
#############################################################################
#                          NextDiT – Pixel Space                           #
#############################################################################

class NextDiTPixelSpace(NextDiT):
    """
    Pixel-space variant of NextDiT.

    Identical transformer backbone to NextDiT, but the output head is replaced
    with a small MLP decoder (SimpleMLPAdaLN) that operates on raw pixel values
    per patch rather than a single affine projection.

    Key differences vs NextDiT:
      • ``final_layer`` is removed; ``dec_net`` (SimpleMLPAdaLN) is used instead.
      • ``_forward`` stores the raw patchified pixel values before the backbone
        embedding and feeds them to ``dec_net`` together with the per-patch
        backbone hidden states.
      • Supports optional x0 prediction via ``use_x0``.
    """

    def __init__(
        self,
        # decoder-specific
        decoder_hidden_size: int = 3840,
        decoder_num_res_blocks: int = 4,
        decoder_max_freqs: int = 8,
        decoder_in_channels: int = None,  # full flattened patch size (patch_size^2 * in_channels)
        use_x0: bool = False,
        # all NextDiT args forwarded unchanged
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Remove the latent-space final layer – not used in pixel space
        del self.final_layer

        patch_size = kwargs.get("patch_size", 2)
        in_channels = kwargs.get("in_channels", 4)
        dim = kwargs.get("dim", 4096)

        # decoder_in_channels is the full flattened patch: patch_size^2 * in_channels
        dec_in_ch = decoder_in_channels if decoder_in_channels is not None else patch_size ** 2 * in_channels

        self.dec_net = SimpleMLPAdaLN(
            in_channels=dec_in_ch,
            model_channels=decoder_hidden_size,
            out_channels=dec_in_ch,
            z_channels=dim,
            num_res_blocks=decoder_num_res_blocks,
            max_freqs=decoder_max_freqs,
            dtype=kwargs.get("dtype"),
            device=kwargs.get("device"),
            operations=kwargs.get("operations"),
        )

        if use_x0:
            self.register_buffer("__x0__", torch.tensor([]))

    # ------------------------------------------------------------------
    # Forward — mirrors NextDiT._forward exactly, replacing final_layer
    # with the pixel-space dec_net decoder.
    # ------------------------------------------------------------------
    def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, ref_latents=[], ref_contexts=[], siglip_feats=[], transformer_options={}, **kwargs):
        omni = len(ref_latents) > 0
        if omni:
            timesteps = torch.cat([timesteps * 0, timesteps], dim=0)

        t = 1.0 - timesteps
        cap_feats = context
        cap_mask = attention_mask
        bs, c, h, w = x.shape
        x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))

        t = self.t_embedder(t * self.time_scale, dtype=x.dtype)
        adaln_input = t

        if self.clip_text_pooled_proj is not None:
            pooled = kwargs.get("clip_text_pooled", None)
            if pooled is not None:
                pooled = self.clip_text_pooled_proj(pooled)
            else:
                pooled = torch.zeros((x.shape[0], self.clip_text_dim), device=x.device, dtype=x.dtype)
            adaln_input = self.time_text_embed(torch.cat((t, pooled), dim=-1))

        # ---- capture raw pixel patches before patchify_and_embed embeds them ----
        pH = pW = self.patch_size
        B, C, H, W = x.shape
        pixel_patches = (
            x.view(B, C, H // pH, pH, W // pW, pW)
            .permute(0, 2, 4, 3, 5, 1)  # [B, Ht, Wt, pH, pW, C]
            .flatten(3)                 # [B, Ht, Wt, pH*pW*C]
            .flatten(1, 2)              # [B, N, pH*pW*C]
        )
        N = pixel_patches.shape[1]
        # decoder sees one token per patch: [B*N, 1, P^2*C]
        pixel_values = pixel_patches.reshape(B * N, 1, pH * pW * C)

        patches = transformer_options.get("patches", {})
        x_is_tensor = isinstance(x, torch.Tensor)
        img, mask, img_size, cap_size, freqs_cis, timestep_zero_index = self.patchify_and_embed(
            x, cap_feats, cap_mask, adaln_input, num_tokens,
            ref_latents=ref_latents, ref_contexts=ref_contexts,
            siglip_feats=siglip_feats, transformer_options=transformer_options
        )
        freqs_cis = freqs_cis.to(img.device)

        transformer_options["total_blocks"] = len(self.layers)
        transformer_options["block_type"] = "double"
        img_input = img
        for i, layer in enumerate(self.layers):
            transformer_options["block_index"] = i
            img = layer(img, mask, freqs_cis, adaln_input, timestep_zero_index=timestep_zero_index, transformer_options=transformer_options)
            if "double_block" in patches:
                for p in patches["double_block"]:
                    out = p({"img": img[:, cap_size[0]:], "img_input": img_input[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options})
                    if "img" in out:
                        img[:, cap_size[0]:] = out["img"]
                    if "txt" in out:
                        img[:, :cap_size[0]] = out["txt"]

        # ---- pixel-space decoder (replaces final_layer + unpatchify) ----
        # img may have padding tokens beyond N; only the first N are real image patches
        img_hidden = img[:, cap_size[0]:cap_size[0] + N, :]  # [B, N, dim]
        decoder_cond = img_hidden.reshape(B * N, self.dim)   # [B*N, dim]

        output = self.dec_net(pixel_values, decoder_cond)    # [B*N, 1, P^2*C]
        output = output.reshape(B, N, -1)                    # [B, N, P^2*C]

        # prepend zero cap placeholder so unpatchify indexing works unchanged
        cap_placeholder = torch.zeros(
            B, cap_size[0], output.shape[-1], device=output.device, dtype=output.dtype
        )
        img_out = self.unpatchify(
            torch.cat([cap_placeholder, output], dim=1),
            img_size, cap_size, return_tensor=x_is_tensor
        )[:, :, :h, :w]

        return -img_out

    def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
        # _forward returns neg_x0 = -x0 (negated decoder output).
        #
        # Reference inference (working_inference_reference.py):
        #     out  = _forward(img, t)          # = -x0
        #     pred = (img - out) / t           # = (img + x0) / t   [_apply_x0_residual]
        #     img += (t_prev - t_curr) * pred  # Euler step
        #
        # ComfyUI's Euler sampler does the same:
        #     x_next = x + (sigma_next - sigma) * model_output
        # So model_output must equal pred = (x - neg_x0) / t = (x - (-x0)) / t = (x + x0) / t
        neg_x0 = comfy.patcher_extension.WrapperExecutor.new_class_executor(
            self._forward,
            self,
            comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
        ).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs)

        return (x - neg_x0) / timesteps.view(-1, 1, 1, 1)
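
# A toy numeric check of the sign algebra in that comment (standalone, no model):
x = torch.tensor([2.0])          # current sample
x0 = torch.tensor([0.5])         # decoder's denoised prediction
t = 0.8
neg_x0 = -x0                     # what _forward returns
model_output = (x - neg_x0) / t  # forward()'s return value
print(torch.allclose(model_output, (x + x0) / t))  # True, matching the derivation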

@@ -524,9 +524,6 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha

@wrap_attn
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
    if kwargs.get("low_precision_attention", True) is False:
        return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=skip_reshape, skip_output_reshape=skip_output_reshape, **kwargs)

    exception_fallback = False
    if skip_reshape:
        b, _, _, dim_head = q.shape

@@ -102,7 +102,19 @@ class VideoConv3d(nn.Module):
        return self.conv(x)

def interpolate_up(x, scale_factor):
    return torch.nn.functional.interpolate(x, scale_factor=scale_factor, mode="nearest")
    try:
        return torch.nn.functional.interpolate(x, scale_factor=scale_factor, mode="nearest")
    except Exception:  # nearest interpolation not implemented for bf16 on some backends
        orig_shape = list(x.shape)
        out_shape = orig_shape[:2]
        for i in range(len(orig_shape) - 2):
            out_shape.append(round(orig_shape[i + 2] * scale_factor[i]))
        out = torch.empty(out_shape, dtype=x.dtype, layout=x.layout, device=x.device)
        split = 8
        l = max(1, out.shape[1] // split)  # guard against a zero step when channels < split
        for i in range(0, out.shape[1], l):
            out[:, i:i + l] = torch.nn.functional.interpolate(x[:, i:i + l].to(torch.float32), scale_factor=scale_factor, mode="nearest").to(x.dtype)
        return out

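# Illustration of the fallback path, assuming the backend rejects bf16
# nearest-neighbour upsampling (on most CUDA builds the try branch succeeds):
x = torch.randn(1, 4, 3, 8, 8, dtype=torch.bfloat16)  # (B, C, T, H, W)
y = interpolate_up(x, scale_factor=(1.0, 2.0, 2.0))   # per-dim scales for T, H, W
print(y.shape, y.dtype)  # torch.Size([1, 4, 3, 16, 16]) torch.bfloat16
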
class Upsample(nn.Module):
    def __init__(self, in_channels, with_conv, conv_op=ops.Conv2d, scale_factor=2.0):

@@ -18,8 +18,6 @@ import comfy.patcher_extension
import comfy.ops
ops = comfy.ops.disable_weight_init

from ..sdpose import HeatmapHead

class TimestepBlock(nn.Module):
    """
    Any module where forward() takes timestep embeddings as a second argument.
@@ -443,7 +441,6 @@ class UNetModel(nn.Module):
        disable_temporal_crossattention=False,
        max_ddpm_temb_period=10000,
        attn_precision=None,
        heatmap_head=False,
        device=None,
        operations=ops,
    ):
@@ -830,9 +827,6 @@
            #nn.LogSoftmax(dim=1)  # change to cross_entropy and produce non-normalized logits
        )

        if heatmap_head:
            self.heatmap_head = HeatmapHead(device=device, dtype=self.dtype, operations=operations)

    def forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
        return comfy.patcher_extension.WrapperExecutor.new_class_executor(
            self._forward,

@@ -1,130 +0,0 @@
import torch
import numpy as np
from scipy.ndimage import gaussian_filter

class HeatmapHead(torch.nn.Module):
    def __init__(
        self,
        in_channels=640,
        out_channels=133,
        input_size=(768, 1024),
        heatmap_scale=4,
        deconv_out_channels=(640,),
        deconv_kernel_sizes=(4,),
        conv_out_channels=(640,),
        conv_kernel_sizes=(1,),
        final_layer_kernel_size=1,
        device=None, dtype=None, operations=None
    ):
        super().__init__()

        self.heatmap_size = (input_size[0] // heatmap_scale, input_size[1] // heatmap_scale)
        self.scale_factor = ((np.array(input_size) - 1) / (np.array(self.heatmap_size) - 1)).astype(np.float32)

        # Deconv layers
        if deconv_out_channels:
            deconv_layers = []
            for out_ch, kernel_size in zip(deconv_out_channels, deconv_kernel_sizes):
                if kernel_size == 4:
                    padding, output_padding = 1, 0
                elif kernel_size == 3:
                    padding, output_padding = 1, 1
                elif kernel_size == 2:
                    padding, output_padding = 0, 0
                else:
                    raise ValueError(f'Unsupported kernel size {kernel_size}')

                deconv_layers.extend([
                    operations.ConvTranspose2d(in_channels, out_ch, kernel_size,
                                               stride=2, padding=padding, output_padding=output_padding, bias=False, device=device, dtype=dtype),
                    torch.nn.InstanceNorm2d(out_ch, device=device, dtype=dtype),
                    torch.nn.SiLU(inplace=True)
                ])
                in_channels = out_ch
            self.deconv_layers = torch.nn.Sequential(*deconv_layers)
        else:
            self.deconv_layers = torch.nn.Identity()

        # Conv layers
        if conv_out_channels:
            conv_layers = []
            for out_ch, kernel_size in zip(conv_out_channels, conv_kernel_sizes):
                padding = (kernel_size - 1) // 2
                conv_layers.extend([
                    operations.Conv2d(in_channels, out_ch, kernel_size,
                                      stride=1, padding=padding, device=device, dtype=dtype),
                    torch.nn.InstanceNorm2d(out_ch, device=device, dtype=dtype),
                    torch.nn.SiLU(inplace=True)
                ])
                in_channels = out_ch
            self.conv_layers = torch.nn.Sequential(*conv_layers)
        else:
            self.conv_layers = torch.nn.Identity()

        self.final_layer = operations.Conv2d(in_channels, out_channels, kernel_size=final_layer_kernel_size, padding=final_layer_kernel_size // 2, device=device, dtype=dtype)
    def forward(self, x):  # Decode heatmaps to keypoints
        heatmaps = self.final_layer(self.conv_layers(self.deconv_layers(x)))
        heatmaps_np = heatmaps.float().cpu().numpy()  # (B, K, H, W)
        B, K, H, W = heatmaps_np.shape

        batch_keypoints = []
        batch_scores = []

        for b in range(B):
            hm = heatmaps_np[b].copy()  # (K, H, W)

            # --- vectorised argmax ---
            flat = hm.reshape(K, -1)
            idx = np.argmax(flat, axis=1)
            scores = flat[np.arange(K), idx].copy()
            y_locs, x_locs = np.unravel_index(idx, (H, W))
            keypoints = np.stack([x_locs, y_locs], axis=-1).astype(np.float32)  # (K, 2) in heatmap space
            invalid = scores <= 0.
            keypoints[invalid] = -1

            # --- DARK sub-pixel refinement (UDP) ---
            # 1. Gaussian blur with max-preserving normalisation
            border = 5  # (kernel-1)//2 for kernel=11
            for k in range(K):
                origin_max = np.max(hm[k])
                dr = np.zeros((H + 2 * border, W + 2 * border), dtype=np.float32)
                dr[border:-border, border:-border] = hm[k].copy()
                dr = gaussian_filter(dr, sigma=2.0)
                hm[k] = dr[border:-border, border:-border].copy()
                cur_max = np.max(hm[k])
                if cur_max > 0:
                    hm[k] *= origin_max / cur_max
            # 2. Log-space for Taylor expansion
            np.clip(hm, 1e-3, 50., hm)
            np.log(hm, hm)
            # 3. Hessian-based Newton step
            hm_pad = np.pad(hm, ((0, 0), (1, 1), (1, 1)), mode='edge').flatten()
            index = keypoints[:, 0] + 1 + (keypoints[:, 1] + 1) * (W + 2)
            index += (W + 2) * (H + 2) * np.arange(0, K)
            index = index.astype(int).reshape(-1, 1)
            i_ = hm_pad[index]
            ix1 = hm_pad[index + 1]
            iy1 = hm_pad[index + W + 2]
            ix1y1 = hm_pad[index + W + 3]
            ix1_y1_ = hm_pad[index - W - 3]
            ix1_ = hm_pad[index - 1]
            iy1_ = hm_pad[index - 2 - W]
            dx = 0.5 * (ix1 - ix1_)
            dy = 0.5 * (iy1 - iy1_)
            derivative = np.concatenate([dx, dy], axis=1).reshape(K, 2, 1)
            dxx = ix1 - 2 * i_ + ix1_
            dyy = iy1 - 2 * i_ + iy1_
            dxy = 0.5 * (ix1y1 - ix1 - iy1 + i_ + i_ - ix1_ - iy1_ + ix1_y1_)
            hessian = np.concatenate([dxx, dxy, dxy, dyy], axis=1).reshape(K, 2, 2)
            hessian = np.linalg.inv(hessian + np.finfo(np.float32).eps * np.eye(2))
            keypoints -= np.einsum('imn,ink->imk', hessian, derivative).squeeze(axis=-1)

            # --- restore to input image space ---
            keypoints = keypoints * self.scale_factor
            keypoints[invalid] = -1

            batch_keypoints.append(keypoints)
            batch_scores.append(scores)

        return batch_keypoints, batch_scores
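
# The Newton step in part 3 is the heart of DARK decoding: around the integer
# argmax m, log p is approximated by a second-order Taylor expansion, and the
# sub-pixel peak is the stationary point x* = m - H^{-1} D, with gradient D and
# Hessian H of log p at m. A toy verification on an exact quadratic, where one
# Newton step must recover the true optimum (illustrative, not from the diff):
import numpy as np

# log p(x, y) = -((x - 3.3)^2 + (y - 5.7)^2), true peak at (3.3, 5.7)
peak = np.array([3.3, 5.7])
m = np.round(peak)                       # integer argmax, (3.0, 6.0)
grad = -2.0 * (m - peak)                 # D = gradient of log p at m
hess = -2.0 * np.eye(2)                  # H = constant Hessian of the quadratic
refined = m - np.linalg.inv(hess) @ grad
print(refined)                           # [3.3 5.7], exact in one Newton step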
@@ -2,196 +2,6 @@ import torch
import math

from .model import QwenImageTransformer2DModel
from .model import QwenImageTransformerBlock


class QwenImageFunControlBlock(QwenImageTransformerBlock):
    def __init__(self, dim, num_attention_heads, attention_head_dim, has_before_proj=False, dtype=None, device=None, operations=None):
        super().__init__(
            dim=dim,
            num_attention_heads=num_attention_heads,
            attention_head_dim=attention_head_dim,
            dtype=dtype,
            device=device,
            operations=operations,
        )
        self.has_before_proj = has_before_proj
        if has_before_proj:
            self.before_proj = operations.Linear(dim, dim, device=device, dtype=dtype)
        self.after_proj = operations.Linear(dim, dim, device=device, dtype=dtype)


class QwenImageFunControlNetModel(torch.nn.Module):
    def __init__(
        self,
        control_in_features=132,
        inner_dim=3072,
        num_attention_heads=24,
        attention_head_dim=128,
        num_control_blocks=5,
        main_model_double=60,
        injection_layers=(0, 12, 24, 36, 48),
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        self.dtype = dtype
        self.main_model_double = main_model_double
        self.injection_layers = tuple(injection_layers)
        # Keep base hint scaling at 1.0 so user-facing strength behaves similarly
        # to the reference Gen2/VideoX implementation around strength=1.
        self.hint_scale = 1.0
        self.control_img_in = operations.Linear(control_in_features, inner_dim, device=device, dtype=dtype)

        self.control_blocks = torch.nn.ModuleList([])
        for i in range(num_control_blocks):
            self.control_blocks.append(
                QwenImageFunControlBlock(
                    dim=inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    has_before_proj=(i == 0),
                    dtype=dtype,
                    device=device,
                    operations=operations,
                )
            )

    def _process_hint_tokens(self, hint):
        if hint is None:
            return None
        if hint.ndim == 4:
            hint = hint.unsqueeze(2)

        # Fun checkpoints are trained with 33 latent channels before 2x2 packing:
        #     [control_latent(16), mask(1), inpaint_latent(16)] -> 132 features.
        # Default behavior (no inpaint input in stock Apply ControlNet) should use
        # zeros for mask/inpaint branches, matching VideoX fallback semantics.
        expected_c = self.control_img_in.weight.shape[1] // 4
        if hint.shape[1] == 16 and expected_c == 33:
            zeros_mask = torch.zeros_like(hint[:, :1])
            zeros_inpaint = torch.zeros_like(hint)
            hint = torch.cat([hint, zeros_mask, zeros_inpaint], dim=1)

        bs, c, t, h, w = hint.shape
        hidden_states = torch.nn.functional.pad(hint, (0, w % 2, 0, h % 2))
        orig_shape = hidden_states.shape
        hidden_states = hidden_states.view(
            orig_shape[0],
            orig_shape[1],
            orig_shape[-3],
            orig_shape[-2] // 2,
            2,
            orig_shape[-1] // 2,
            2,
        )
        hidden_states = hidden_states.permute(0, 2, 3, 5, 1, 4, 6)
        hidden_states = hidden_states.reshape(
            bs,
            t * ((h + 1) // 2) * ((w + 1) // 2),
            c * 4,
        )

        expected_in = self.control_img_in.weight.shape[1]
        cur_in = hidden_states.shape[-1]
        if cur_in < expected_in:
            pad = torch.zeros(
                (hidden_states.shape[0], hidden_states.shape[1], expected_in - cur_in),
                device=hidden_states.device,
                dtype=hidden_states.dtype,
            )
            hidden_states = torch.cat([hidden_states, pad], dim=-1)
        elif cur_in > expected_in:
            hidden_states = hidden_states[:, :, :expected_in]

        return hidden_states
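
    # Concrete shape trace through _process_hint_tokens for a typical hint
    # (illustrative numbers, chosen to match the 132-feature default):
    #     hint               : (B, 16, 1, 64, 64)  control latent (4D input unsqueezed to 5D)
    #     after zero branches: (B, 33, 1, 64, 64)  + mask(1) + inpaint(16) zeros
    #     after 2x2 packing  : (B, 1*32*32, 132)   t * ceil(h/2) * ceil(w/2) tokens, c*4 features
    #     control_img_in     : (B, 1024, 132) -> (B, 1024, 3072)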

    def forward(
        self,
        x,
        timesteps,
        context,
        attention_mask=None,
        guidance: torch.Tensor = None,
        hint=None,
        transformer_options={},
        base_model=None,
        **kwargs,
    ):
        if base_model is None:
            raise RuntimeError("Qwen Fun ControlNet requires a QwenImage base model at runtime.")

        encoder_hidden_states_mask = attention_mask
        # Keep attention mask disabled inside Fun control blocks to mirror
        # VideoX behavior (they rely on seq lengths for RoPE, not masked attention).
        encoder_hidden_states_mask = None

        hidden_states, img_ids, _ = base_model.process_img(x)
        hint_tokens = self._process_hint_tokens(hint)
        if hint_tokens is None:
            raise RuntimeError("Qwen Fun ControlNet requires a control hint image.")

        if hint_tokens.shape[1] != hidden_states.shape[1]:
            max_tokens = min(hint_tokens.shape[1], hidden_states.shape[1])
            hint_tokens = hint_tokens[:, :max_tokens]
            hidden_states = hidden_states[:, :max_tokens]
            img_ids = img_ids[:, :max_tokens]

        txt_start = round(
            max(
                ((x.shape[-1] + (base_model.patch_size // 2)) // base_model.patch_size) // 2,
                ((x.shape[-2] + (base_model.patch_size // 2)) // base_model.patch_size) // 2,
            )
        )
        txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
        ids = torch.cat((txt_ids, img_ids), dim=1)
        image_rotary_emb = base_model.pe_embedder(ids).to(x.dtype).contiguous()

        hidden_states = base_model.img_in(hidden_states)
        encoder_hidden_states = base_model.txt_norm(context)
        encoder_hidden_states = base_model.txt_in(encoder_hidden_states)

        if guidance is not None:
            guidance = guidance * 1000

        temb = (
            base_model.time_text_embed(timesteps, hidden_states)
            if guidance is None
            else base_model.time_text_embed(timesteps, guidance, hidden_states)
        )

        c = self.control_img_in(hint_tokens)

        for i, block in enumerate(self.control_blocks):
            if i == 0:
                c_in = block.before_proj(c) + hidden_states
                all_c = []
            else:
                all_c = list(torch.unbind(c, dim=0))
                c_in = all_c.pop(-1)

            encoder_hidden_states, c_out = block(
                hidden_states=c_in,
                encoder_hidden_states=encoder_hidden_states,
                encoder_hidden_states_mask=encoder_hidden_states_mask,
                temb=temb,
                image_rotary_emb=image_rotary_emb,
                transformer_options=transformer_options,
            )

            c_skip = block.after_proj(c_out) * self.hint_scale
            all_c += [c_skip, c_out]
            c = torch.stack(all_c, dim=0)

        hints = torch.unbind(c, dim=0)[:-1]

        controlnet_block_samples = [None] * self.main_model_double
        for local_idx, base_idx in enumerate(self.injection_layers):
            if local_idx < len(hints) and base_idx < len(controlnet_block_samples):
                controlnet_block_samples[base_idx] = hints[local_idx]

        return {"input": controlnet_block_samples}


class QwenImageControlNetModel(QwenImageTransformer2DModel):

@@ -1621,118 +1621,3 @@ class HumoWanModel(WanModel):
        # unpatchify
        x = self.unpatchify(x, grid_sizes)
        return x

class SCAILWanModel(WanModel):
    def __init__(self, model_type="scail", patch_size=(1, 2, 2), in_dim=20, dim=5120, operations=None, device=None, dtype=None, **kwargs):
        super().__init__(model_type='i2v', patch_size=patch_size, in_dim=in_dim, dim=dim, operations=operations, device=device, dtype=dtype, **kwargs)

        self.patch_embedding_pose = operations.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=torch.float32)

    def forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, pose_latents=None, reference_latent=None, **kwargs):

        if reference_latent is not None:
            x = torch.cat((reference_latent, x), dim=2)

        # embeddings
        x = self.patch_embedding(x.float()).to(x.dtype)
        grid_sizes = x.shape[2:]
        transformer_options["grid_sizes"] = grid_sizes
        x = x.flatten(2).transpose(1, 2)

        scail_pose_seq_len = 0
        if pose_latents is not None:
            scail_x = self.patch_embedding_pose(pose_latents.float()).to(x.dtype)
            scail_x = scail_x.flatten(2).transpose(1, 2)
            scail_pose_seq_len = scail_x.shape[1]
            x = torch.cat([x, scail_x], dim=1)
            del scail_x

        # time embeddings
        e = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype))
        e = e.reshape(t.shape[0], -1, e.shape[-1])
        e0 = self.time_projection(e).unflatten(2, (6, self.dim))

        # context
        context = self.text_embedding(context)

        context_img_len = None
        if clip_fea is not None:
            if self.img_emb is not None:
                context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
                context = torch.cat([context_clip, context], dim=1)
            context_img_len = clip_fea.shape[-2]

        patches_replace = transformer_options.get("patches_replace", {})
        blocks_replace = patches_replace.get("dit", {})
        transformer_options["total_blocks"] = len(self.blocks)
        transformer_options["block_type"] = "double"
        for i, block in enumerate(self.blocks):
            transformer_options["block_index"] = i
            if ("double_block", i) in blocks_replace:
                def block_wrap(args):
                    out = {}
                    out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"])
                    return out
                out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
                x = out["img"]
            else:
                x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)

        # head
        x = self.head(x, e)

        if scail_pose_seq_len > 0:
            x = x[:, :-scail_pose_seq_len]

        # unpatchify
        x = self.unpatchify(x, grid_sizes)

        if reference_latent is not None:
            x = x[:, :, reference_latent.shape[2]:]

        return x

    def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, pose_latents=None, reference_latent=None, transformer_options={}):
        main_freqs = super().rope_encode(t, h, w, t_start=t_start, steps_t=steps_t, steps_h=steps_h, steps_w=steps_w, device=device, dtype=dtype, transformer_options=transformer_options)

        if pose_latents is None:
            return main_freqs

        ref_t_patches = 0
        if reference_latent is not None:
            ref_t_patches = (reference_latent.shape[2] + (self.patch_size[0] // 2)) // self.patch_size[0]

        F_pose, H_pose, W_pose = pose_latents.shape[-3], pose_latents.shape[-2], pose_latents.shape[-1]

        # If the pose stream is at half resolution, scale_y/scale_x=2 stretches its
        # position range to cover the same RoPE extent as the main frames.
        h_scale = h / H_pose
        w_scale = w / W_pose

        # 120 w-offset; the half-step shift places positions at midpoints
        # (0.5, 2.5, ...) to match the original code.
        h_shift = (h_scale - 1) / 2
        w_shift = (w_scale - 1) / 2
        pose_transformer_options = {"rope_options": {"shift_y": h_shift, "shift_x": 120.0 + w_shift, "scale_y": h_scale, "scale_x": w_scale}}
        pose_freqs = super().rope_encode(F_pose, H_pose, W_pose, t_start=t_start+ref_t_patches, device=device, dtype=dtype, transformer_options=pose_transformer_options)

        return torch.cat([main_freqs, pose_freqs], dim=1)
|
||||
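# Editor's note: a worked example (illustrative numbers, assuming rope_options
# maps position p to p * scale + shift) of the shift/scale choice above. With
# main width w = 16 and pose width W_pose = 8:
#   w_scale = 16 / 8 = 2.0, w_shift = (2.0 - 1) / 2 = 0.5 (before the 120.0 offset)
# so pose column p lands at the midpoint of each pair of main positions 0..15.
pose_positions = [p * 2.0 + 0.5 for p in range(8)]
assert pose_positions == [0.5, 2.5, 4.5, 6.5, 8.5, 10.5, 12.5, 14.5]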
    def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, pose_latents=None, **kwargs):
        bs, c, t, h, w = x.shape
        x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)

        if pose_latents is not None:
            pose_latents = comfy.ldm.common_dit.pad_to_patch_size(pose_latents, self.patch_size)

        t_len = t
        if time_dim_concat is not None:
            time_dim_concat = comfy.ldm.common_dit.pad_to_patch_size(time_dim_concat, self.patch_size)
            x = torch.cat([x, time_dim_concat], dim=2)
            t_len = x.shape[2]

        reference_latent = None
        if "reference_latent" in kwargs:
            reference_latent = comfy.ldm.common_dit.pad_to_patch_size(kwargs.pop("reference_latent"), self.patch_size)
            t_len += reference_latent.shape[2]

        freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent)
        return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent, **kwargs)[:, :, :t, :h, :w]
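# Editor's note: an illustrative sketch (an assumption, not this repo's
# implementation) of the contract pad_to_patch_size must satisfy for the calls
# above: each of the last len(patch_size) dims is padded up to a multiple of
# the corresponding patch dim so that patchify/unpatchify see whole patches.
import torch.nn.functional as F

def pad_to_patch_size_sketch(img, patch_size):
    pad = []
    for dim, p in zip(img.shape[-len(patch_size):], patch_size):
        pad = [0, (p - dim % p) % p] + pad  # F.pad lists pairs last-dim-first
    return F.pad(img, pad)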
@@ -459,7 +459,6 @@ class WanVAE(nn.Module):
                  attn_scales=[],
                  temperal_downsample=[True, True, False],
                  image_channels=3,
-                 conv_out_channels=3,
                  dropout=0.0):
         super().__init__()
         self.dim = dim
@@ -475,7 +474,7 @@ class WanVAE(nn.Module):
                                attn_scales, self.temperal_downsample, dropout)
         self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
         self.conv2 = CausalConv3d(z_dim, z_dim, 1)
-        self.decoder = Decoder3d(dim, z_dim, conv_out_channels, dim_mult, num_res_blocks,
+        self.decoder = Decoder3d(dim, z_dim, image_channels, dim_mult, num_res_blocks,
                                  attn_scales, self.temperal_upsample, dropout)

     def encode(self, x):
@@ -485,7 +484,7 @@ class WanVAE(nn.Module):
         iter_ = 1 + (t - 1) // 4
         feat_map = None
         if iter_ > 1:
-            feat_map = [None] * count_conv3d(self.encoder)
+            feat_map = [None] * count_conv3d(self.decoder)
         ## split the x passed to encode along the time axis into chunks of 1, 4, 4, 4, ...
         for i in range(iter_):
             conv_idx = [0]
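# Editor's note: a worked example (illustrative) of the 1, 4, 4, 4, ... temporal
# split described by the translated comment above. iter_ = 1 + (t - 1) // 4
# yields one 1-frame chunk plus k 4-frame chunks whenever t = 4k + 1, the usual
# causal-video-VAE frame count.
for t in (1, 5, 17):
    iter_ = 1 + (t - 1) // 4
    chunks = [1] + [4] * (iter_ - 1)
    assert sum(chunks) == t  # exact whenever t % 4 == 1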
@@ -332,13 +332,6 @@ def model_lora_keys_unet(model, key_map={}):
             key_map["{}".format(key_lora)] = k
             key_map["transformer.{}".format(key_lora)] = k

-    if isinstance(model, comfy.model_base.ACEStep15):
-        for k in sdk:
-            if k.startswith("diffusion_model.decoder.") and k.endswith(".weight"):
-                key_lora = k[len("diffusion_model.decoder."):-len(".weight")]
-                key_map["base_model.model.{}".format(key_lora)] = k  # official base model loras
-                key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k  # LyCORIS/LoKR format
-
     return key_map

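# Editor's note: an illustrative trace (hypothetical checkpoint key) of the two
# mappings built in the removed ACEStep15 branch above.
k = "diffusion_model.decoder.layers.0.linear.weight"
key_lora = k[len("diffusion_model.decoder."):-len(".weight")]
assert key_lora == "layers.0.linear"
assert "base_model.model.{}".format(key_lora) == "base_model.model.layers.0.linear"
assert "lycoris_{}".format(key_lora.replace(".", "_")) == "lycoris_layers_0_linear"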
@@ -375,31 +368,6 @@ def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Ten

     return padded_tensor

-def calculate_shape(patches, weight, key, original_weights=None):
-    current_shape = weight.shape
-
-    for p in patches:
-        v = p[1]
-        offset = p[3]
-
-        # offsets restore the old shape; lists force a diff without metadata
-        if offset is not None or isinstance(v, list):
-            continue
-
-        if isinstance(v, weight_adapter.WeightAdapterBase):
-            adapter_shape = v.calculate_shape(key)
-            if adapter_shape is not None:
-                current_shape = adapter_shape
-            continue
-
-        # standard diff logic with padding
-        if len(v) == 2:
-            patch_type, patch_data = v[0], v[1]
-            if patch_type == "diff" and len(patch_data) > 1 and patch_data[1]['pad_weight']:
-                current_shape = patch_data[0].shape
-
-    return current_shape
-
 def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, original_weights=None):
     for p in patches:
         strength = p[0]
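# Editor's note: an illustrative patch tuple (values hypothetical) exercising
# the "diff with pad_weight" branch of the removed calculate_shape above,
# assuming the p = (strength, v, ?, offset) layout visible in the code.
import torch
diff_tensor = torch.zeros(320, 4, 3, 3)
p = (1.0, ("diff", (diff_tensor, {"pad_weight": True})), None, None)
v, offset = p[1], p[3]
if offset is None and not isinstance(v, list) and len(v) == 2:
    patch_type, patch_data = v
    if patch_type == "diff" and len(patch_data) > 1 and patch_data[1]["pad_weight"]:
        assert patch_data[0].shape == (320, 4, 3, 3)  # becomes current_shape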
@@ -5,7 +5,7 @@ import comfy.utils

 def convert_lora_bfl_control(sd):  # BFL loras for Flux
     sd_out = {}
     for k in sd:
-        k_to = "diffusion_model.{}".format(k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.set_weight"))
+        k_to = "diffusion_model.{}".format(k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.scale.set_weight"))
         sd_out[k_to] = sd[k]

     sd_out["diffusion_model.img_in.reshape_weight"] = torch.tensor([sd["img_in.lora_B.weight"].shape[0], sd["img_in.lora_A.weight"].shape[1]])
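# Editor's note: illustrative before/after (hypothetical key) for the rename on
# the "+" side above.
k = "double_blocks.0.img_norm.scale"
k_to = "diffusion_model.{}".format(k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.scale.set_weight"))
assert k_to == "diffusion_model.double_blocks.0.img_norm.scale.set_weight"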
@@ -1,81 +0,0 @@
import math
import torch
from typing import NamedTuple

from comfy.quant_ops import QuantizedTensor

class TensorGeometry(NamedTuple):
    shape: any
    dtype: torch.dtype

    def element_size(self):
        info = torch.finfo(self.dtype) if self.dtype.is_floating_point else torch.iinfo(self.dtype)
        return info.bits // 8

    def numel(self):
        return math.prod(self.shape)

def tensors_to_geometries(tensors, dtype=None):
    geometries = []
    for t in tensors:
        if t is None or isinstance(t, QuantizedTensor):
            geometries.append(t)
            continue
        tdtype = t.dtype
        if hasattr(t, "_model_dtype"):
            tdtype = t._model_dtype
        if dtype is not None:
            tdtype = dtype
        geometries.append(TensorGeometry(shape=t.shape, dtype=tdtype))
    return geometries

def vram_aligned_size(tensor):
    if isinstance(tensor, list):
        return sum([vram_aligned_size(t) for t in tensor])

    if isinstance(tensor, QuantizedTensor):
        inner_tensors, _ = tensor.__tensor_flatten__()
        return vram_aligned_size([getattr(tensor, attr) for attr in inner_tensors])

    if tensor is None:
        return 0

    size = tensor.numel() * tensor.element_size()
    alignment_req = 1024
    return (size + alignment_req - 1) // alignment_req * alignment_req

def interpret_gathered_like(tensors, gathered):
    offset = 0
    dest_views = []

    if gathered.dim() != 1 or gathered.element_size() != 1:
        raise ValueError(f"Buffer must be 1D and single-byte (got {gathered.dim()}D {gathered.dtype})")

    for tensor in tensors:

        if tensor is None:
            dest_views.append(None)
            continue

        if isinstance(tensor, QuantizedTensor):
            inner_tensors, qt_ctx = tensor.__tensor_flatten__()
            templates = {attr: getattr(tensor, attr) for attr in inner_tensors}
        else:
            templates = {"data": tensor}

        actuals = {}
        for attr, template in templates.items():
            size = template.numel() * template.element_size()
            if offset + size > gathered.numel():
                raise ValueError(f"Buffer too small: needs {offset + size} bytes, but only has {gathered.numel()}.")
            actuals[attr] = gathered[offset:offset + size].view(dtype=template.dtype).view(template.shape)
            offset += vram_aligned_size(template)

        if isinstance(tensor, QuantizedTensor):
            dest_views.append(QuantizedTensor.__tensor_unflatten__(actuals, qt_ctx, 0, 0))
        else:
            dest_views.append(actuals["data"])

    return dest_views

aimdo_enabled = False
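# Editor's note: an illustrative use (not from this diff) of the deleted
# helpers above: pack two tensors into one byte buffer at 1024-byte-aligned
# offsets, then reinterpret typed views over it.
import torch
a = torch.arange(6, dtype=torch.float32).reshape(2, 3)  # 24 bytes -> 1024 aligned
b = torch.ones(4, dtype=torch.int64)                    # 32 bytes -> 1024 aligned
gathered = torch.zeros(vram_aligned_size([a, b]), dtype=torch.uint8)
views = interpret_gathered_like([a, b], gathered)
views[0].copy_(a)
views[1].copy_(b)
assert torch.equal(views[0], a) and views[1].dtype == torch.int64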
@@ -50,7 +50,6 @@ import comfy.ldm.omnigen.omnigen2
 import comfy.ldm.qwen_image.model
 import comfy.ldm.kandinsky5.model
 import comfy.ldm.anima.model
-import comfy.ldm.ace.ace_step15

 import comfy.model_management
 import comfy.patcher_extension
@@ -76,7 +75,6 @@ class ModelType(Enum):
     FLUX = 8
     IMG_TO_IMG = 9
     FLOW_COSMOS = 10
-    IMG_TO_IMG_FLOW = 11


 def model_sampling(model_config, model_type):
@@ -109,8 +107,6 @@ def model_sampling(model_config, model_type):
     elif model_type == ModelType.FLOW_COSMOS:
         c = comfy.model_sampling.COSMOS_RFLOW
         s = comfy.model_sampling.ModelSamplingCosmosRFlow
-    elif model_type == ModelType.IMG_TO_IMG_FLOW:
-        c = comfy.model_sampling.IMG_TO_IMG_FLOW

     class ModelSampling(s, c):
         pass
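# Editor's note: a minimal, generic illustration (names invented) of the mixin
# trick model_sampling uses above: choose a sampling base `s` and a prediction
# mixin `c` at runtime, then compose them into a single class.
class EpsPrediction:             # stands in for a `c` choice
    kind = "eps"

class ModelSamplingDiscrete:     # stands in for an `s` choice
    steps = 1000

class ModelSampling(ModelSamplingDiscrete, EpsPrediction):
    pass

assert (ModelSampling.kind, ModelSampling.steps) == ("eps", 1000)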
@@ -150,8 +146,6 @@ class BaseModel(torch.nn.Module):
                 self.diffusion_model.to(memory_format=torch.channels_last)
                 logging.debug("using channels last mode for diffusion model")
             logging.info("model weight dtype {}, manual cast: {}".format(self.get_dtype(), self.manual_cast_dtype))
-            comfy.model_management.archive_model_dtypes(self.diffusion_model)

         self.model_type = model_type
         self.model_sampling = model_sampling(model_config, model_type)
@@ -181,7 +175,10 @@ class BaseModel(torch.nn.Module):
             xc = torch.cat([xc] + [comfy.model_management.cast_to_device(c_concat, xc.device, xc.dtype)], dim=1)

         context = c_crossattn
-        dtype = self.get_dtype_inference()
+        dtype = self.get_dtype()
+
+        if self.manual_cast_dtype is not None:
+            dtype = self.manual_cast_dtype

         xc = xc.to(dtype)
         device = xc.device
@@ -218,13 +215,6 @@ class BaseModel(torch.nn.Module):
     def get_dtype(self):
         return self.diffusion_model.dtype

-    def get_dtype_inference(self):
-        dtype = self.get_dtype()
-
-        if self.manual_cast_dtype is not None:
-            dtype = self.manual_cast_dtype
-        return dtype
-
     def encode_adm(self, **kwargs):
         return None
@@ -309,7 +299,7 @@ class BaseModel(torch.nn.Module):

         return out

-    def load_model_weights(self, sd, unet_prefix="", assign=False):
+    def load_model_weights(self, sd, unet_prefix=""):
         to_load = {}
         keys = list(sd.keys())
         for k in keys:
@@ -317,7 +307,7 @@ class BaseModel(torch.nn.Module):
                 to_load[k[len(unet_prefix):]] = sd.pop(k)

         to_load = self.model_config.process_unet_state_dict(to_load)
-        m, u = self.diffusion_model.load_state_dict(to_load, strict=False, assign=assign)
+        m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
         if len(m) > 0:
             logging.warning("unet missing: {}".format(m))
@@ -332,7 +322,7 @@ class BaseModel(torch.nn.Module):
     def process_latent_out(self, latent):
         return self.latent_format.process_out(latent)

-    def state_dict_for_saving(self, unet_state_dict, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
+    def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
         extra_sds = []
         if clip_state_dict is not None:
             extra_sds.append(self.model_config.process_clip_state_dict_for_saving(clip_state_dict))
@@ -340,7 +330,10 @@ class BaseModel(torch.nn.Module):
             extra_sds.append(self.model_config.process_vae_state_dict_for_saving(vae_state_dict))
         if clip_vision_state_dict is not None:
             extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict))

+        unet_state_dict = self.diffusion_model.state_dict()
         unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)

         if self.model_type == ModelType.V_PREDICTION:
             unet_state_dict["v_pred"] = torch.tensor([])
@@ -379,7 +372,9 @@ class BaseModel(torch.nn.Module):
                 input_shapes += shape

         if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
-            dtype = self.get_dtype_inference()
+            dtype = self.get_dtype()
+            if self.manual_cast_dtype is not None:
+                dtype = self.manual_cast_dtype
             #TODO: this needs to be tweaked
             area = sum(map(lambda input_shape: input_shape[0] * math.prod(input_shape[2:]), input_shapes))
             return (area * comfy.model_management.dtype_size(dtype) * 0.01 * self.memory_usage_factor) * (1024 * 1024)
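# Editor's note: a worked example (illustrative numbers) of the estimate above.
# One input of shape [2, 4, 64, 64] gives area = 2 * 64 * 64 = 8192; with fp16
# (dtype_size = 2 bytes) and memory_usage_factor = 1.0 the estimate is
# 8192 * 2 * 0.01 MiB, i.e. about 164 MiB.
import math
area = 2 * 64 * 64
estimate = area * 2 * 0.01 * 1.0 * (1024 * 1024)
assert math.isclose(estimate / (1024 * 1024), 163.84)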
@@ -781,8 +776,8 @@ class StableAudio1(BaseModel):
         out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
         return out

-    def state_dict_for_saving(self, unet_state_dict, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
-        sd = super().state_dict_for_saving(unet_state_dict, clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict)
+    def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
+        sd = super().state_dict_for_saving(clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict)
         d = {"conditioner.conditioners.seconds_start.": self.seconds_start_embedder.state_dict(), "conditioner.conditioners.seconds_total.": self.seconds_total_embedder.state_dict()}
         for k in d:
             s = d[k]
@@ -925,25 +920,6 @@ class Flux(BaseModel):
             out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))])
         return out

-class LongCatImage(Flux):
-    def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
-        transformer_options = transformer_options.copy()
-        rope_opts = transformer_options.get("rope_options", {})
-        rope_opts = dict(rope_opts)
-        rope_opts.setdefault("shift_t", 1.0)
-        rope_opts.setdefault("shift_y", 512.0)
-        rope_opts.setdefault("shift_x", 512.0)
-        transformer_options["rope_options"] = rope_opts
-        return super()._apply_model(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs)
-
-    def encode_adm(self, **kwargs):
-        return None
-
-    def extra_conds(self, **kwargs):
-        out = super().extra_conds(**kwargs)
-        out.pop('guidance', None)
-        return out
-
 class Flux2(Flux):
     def extra_conds(self, **kwargs):
         out = super().extra_conds(**kwargs)
@@ -993,10 +969,6 @@ class LTXV(BaseModel):
         if keyframe_idxs is not None:
             out['keyframe_idxs'] = comfy.conds.CONDRegular(keyframe_idxs)

-        guide_attention_entries = kwargs.get("guide_attention_entries", None)
-        if guide_attention_entries is not None:
-            out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries)
-
         return out

     def process_timestep(self, timestep, x, denoise_mask=None, **kwargs):
@@ -1014,14 +986,10 @@ class LTXAV(BaseModel):
     def extra_conds(self, **kwargs):
         out = super().extra_conds(**kwargs)
         attention_mask = kwargs.get("attention_mask", None)
-        device = kwargs["device"]

         if attention_mask is not None:
             out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
         cross_attn = kwargs.get("cross_attn", None)
         if cross_attn is not None:
-            if hasattr(self.diffusion_model, "preprocess_text_embeds"):
-                cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference()), unprocessed=kwargs.get("unprocessed_ltxav_embeds", False))
             out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)

         out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25))
@@ -1049,10 +1017,6 @@ class LTXAV(BaseModel):
         if latent_shapes is not None:
             out['latent_shapes'] = comfy.conds.CONDConstant(latent_shapes)

-        guide_attention_entries = kwargs.get("guide_attention_entries", None)
-        if guide_attention_entries is not None:
-            out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries)
-
         return out

     def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs):
@@ -1196,16 +1160,12 @@ class Anima(BaseModel):
         device = kwargs["device"]
         if cross_attn is not None:
             if t5xxl_ids is not None:
+                cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype()), t5xxl_ids.unsqueeze(0).to(device=device))
                 if t5xxl_weights is not None:
-                    t5xxl_weights = t5xxl_weights.unsqueeze(0).unsqueeze(-1).to(cross_attn)
-                t5xxl_ids = t5xxl_ids.unsqueeze(0)
-
-                if torch.is_inference_mode_enabled():  # if not, we are training
-                    cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference()), t5xxl_ids.to(device=device), t5xxl_weights=t5xxl_weights.to(device=device, dtype=self.get_dtype_inference()))
-                else:
-                    out['t5xxl_ids'] = comfy.conds.CONDRegular(t5xxl_ids)
-                    out['t5xxl_weights'] = comfy.conds.CONDRegular(t5xxl_weights)
+                    cross_attn *= t5xxl_weights.unsqueeze(0).unsqueeze(-1).to(cross_attn)

             if cross_attn.shape[1] < 512:
                 cross_attn = torch.nn.functional.pad(cross_attn, (0, 0, 0, 512 - cross_attn.shape[1]))
             out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
         return out
@@ -1263,11 +1223,6 @@ class Lumina2(BaseModel):
             out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))])
         return out

-class ZImagePixelSpace(Lumina2):
-    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
-        BaseModel.__init__(self, model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiTPixelSpace)
-        self.memory_usage_factor_conds = ("ref_latents",)
-
 class WAN21(BaseModel):
     def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
@@ -1501,50 +1456,6 @@ class WAN22(WAN21):
     def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
         return latent_image

-class WAN21_FlowRVS(WAN21):
-    def __init__(self, model_config, model_type=ModelType.IMG_TO_IMG_FLOW, image_to_video=False, device=None):
-        model_config.unet_config["model_type"] = "t2v"
-        super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
-        self.image_to_video = image_to_video
-
-class WAN21_SCAIL(WAN21):
-    def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
-        super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.SCAILWanModel)
-        self.memory_usage_factor_conds = ("reference_latent", "pose_latents")
-        self.memory_usage_shape_process = {"pose_latents": lambda shape: [shape[0], shape[1], 1.5, shape[-2], shape[-1]]}
-        self.image_to_video = image_to_video
-
-    def extra_conds(self, **kwargs):
-        out = super().extra_conds(**kwargs)
-
-        reference_latents = kwargs.get("reference_latents", None)
-        if reference_latents is not None:
-            ref_latent = self.process_latent_in(reference_latents[-1])
-            ref_mask = torch.ones_like(ref_latent[:, :4])
-            ref_latent = torch.cat([ref_latent, ref_mask], dim=1)
-            out['reference_latent'] = comfy.conds.CONDRegular(ref_latent)
-
-        pose_latents = kwargs.get("pose_video_latent", None)
-        if pose_latents is not None:
-            pose_latents = self.process_latent_in(pose_latents)
-            pose_mask = torch.ones_like(pose_latents[:, :4])
-            pose_latents = torch.cat([pose_latents, pose_mask], dim=1)
-            out['pose_latents'] = comfy.conds.CONDRegular(pose_latents)
-
-        return out
-
-    def extra_conds_shapes(self, **kwargs):
-        out = {}
-        ref_latents = kwargs.get("reference_latents", None)
-        if ref_latents is not None:
-            out['reference_latent'] = list([1, 20, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
-
-        pose_latents = kwargs.get("pose_video_latent", None)
-        if pose_latents is not None:
-            out['pose_latents'] = [pose_latents.shape[0], 20, *pose_latents.shape[2:]]
-
-        return out
-
 class Hunyuan3Dv2(BaseModel):
     def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2)
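# Editor's note: an illustrative evaluation (shape invented) of the
# memory_usage_shape_process lambda in the removed WAN21_SCAIL class above:
# pose latents are budgeted as if their time dimension were 1.5 frames.
shape = [1, 20, 9, 60, 104]
assert (lambda s: [s[0], s[1], 1.5, s[-2], s[-1]])(shape) == [1, 20, 1.5, 60, 104]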
@@ -1630,49 +1541,6 @@ class ACEStep(BaseModel):
         out['lyrics_strength'] = comfy.conds.CONDConstant(kwargs.get("lyrics_strength", 1.0))
         return out

-class ACEStep15(BaseModel):
-    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
-        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ace.ace_step15.AceStepConditionGenerationModel)
-
-    def extra_conds(self, **kwargs):
-        out = super().extra_conds(**kwargs)
-        device = kwargs["device"]
-        noise = kwargs["noise"]
-
-        cross_attn = kwargs.get("cross_attn", None)
-        if cross_attn is not None:
-            if torch.count_nonzero(cross_attn) == 0:
-                out['replace_with_null_embeds'] = comfy.conds.CONDConstant(True)
-            out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
-
-        conditioning_lyrics = kwargs.get("conditioning_lyrics", None)
-        if conditioning_lyrics is not None:
-            out['lyric_embed'] = comfy.conds.CONDRegular(conditioning_lyrics)
-
-        refer_audio = kwargs.get("reference_audio_timbre_latents", None)
-        if refer_audio is None or len(refer_audio) == 0:
-            refer_audio = comfy.ldm.ace.ace_step15.get_silence_latent(noise.shape[2], device)
-            pass_audio_codes = True
-        else:
-            refer_audio = refer_audio[-1][:, :, :noise.shape[2]]
-            out['is_covers'] = comfy.conds.CONDConstant(True)
-            pass_audio_codes = False
-
-        if pass_audio_codes:
-            audio_codes = kwargs.get("audio_codes", None)
-            if audio_codes is not None:
-                out['audio_codes'] = comfy.conds.CONDRegular(torch.tensor(audio_codes, device=device))
-                refer_audio = refer_audio[:, :, :750]
-            else:
-                out['is_covers'] = comfy.conds.CONDConstant(False)
-
-        if refer_audio.shape[2] < noise.shape[2]:
-            pad = comfy.ldm.ace.ace_step15.get_silence_latent(noise.shape[2], device)
-            refer_audio = torch.cat([refer_audio.to(pad), pad[:, :, refer_audio.shape[2]:]], dim=2)
-
-        out['refer_audio'] = comfy.conds.CONDRegular(refer_audio)
-        return out
-
 class Omnigen2(BaseModel):
     def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.omnigen.omnigen2.OmniGen2Transformer2DModel)
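# Editor's note: an illustrative sketch (shapes invented) of the refer_audio
# padding at the end of the removed ACEStep15.extra_conds above: a short timbre
# reference is extended with silence up to the noise length along dim 2. The
# zeros tensor stands in for get_silence_latent.
import torch
noise_len, channels = 750, 64
refer_audio = torch.randn(1, channels, 500)
pad = torch.zeros(1, channels, noise_len)
refer_audio = torch.cat([refer_audio.to(pad), pad[:, :, refer_audio.shape[2]:]], dim=2)
assert refer_audio.shape[2] == noise_len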
Some files were not shown because too many files have changed in this diff.