mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-11 10:40:04 +00:00
Compare commits
301 Commits
v0.4.0
...
pysssss/gl
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c2d229a786 | ||
|
|
3c40ee0f02 | ||
|
|
43034b6881 | ||
|
|
bb048d4aaa | ||
|
|
7c1f02d1fa | ||
|
|
292a5918f4 | ||
|
|
0050b66a0b | ||
|
|
0c313f5293 | ||
|
|
1fcf9dca18 | ||
|
|
3b790d24d6 | ||
|
|
92b2b7198a | ||
|
|
8aabe2403e | ||
|
|
309c3e4ec0 | ||
|
|
23591d4388 | ||
|
|
0167653781 | ||
|
|
c3d07bec6d | ||
|
|
0a7993729c | ||
|
|
bbe2c13a70 | ||
|
|
59b955ff54 | ||
|
|
1263d6fe88 | ||
|
|
3aace5c8dc | ||
|
|
b0d9708974 | ||
|
|
23572c6314 | ||
|
|
d809ef8fb1 | ||
|
|
a4317314d2 | ||
|
|
c9b633d84f | ||
|
|
aaea976f36 | ||
|
|
cee092213e | ||
|
|
3da0e9c367 | ||
|
|
9fa8202620 | ||
|
|
b4438c9baf | ||
|
|
1711020904 | ||
|
|
d9b8567547 | ||
|
|
6c5f906bf2 | ||
|
|
4f5bd39b1c | ||
|
|
dcff27fe3f | ||
|
|
09725967cf | ||
|
|
5f62440fbb | ||
|
|
ac91c340f4 | ||
|
|
2db3b0ff90 | ||
|
|
6516ab335d | ||
|
|
ad53e78f11 | ||
|
|
29011ba87e | ||
|
|
cd4985e2f3 | ||
|
|
bfe31d0b9d | ||
|
|
2129e7d278 | ||
|
|
7ee77ff038 | ||
|
|
26c5bbb875 | ||
|
|
a97c98068f | ||
|
|
635406e283 | ||
|
|
ed6002cb60 | ||
|
|
bc72d7f8d1 | ||
|
|
aef4e13588 | ||
|
|
4e6a1b66a9 | ||
|
|
9cf299a9f9 | ||
|
|
e89b22993a | ||
|
|
55bd606e92 | ||
|
|
cc30293d65 | ||
|
|
866d863128 | ||
|
|
79cdbc81cb | ||
|
|
f443b9f2ca | ||
|
|
4e3038114a | ||
|
|
bbb8864778 | ||
|
|
d7f3241bf6 | ||
|
|
09a2e67151 | ||
|
|
0fd1b78736 | ||
|
|
8490eedadf | ||
|
|
72f6be1690 | ||
|
|
16b9aabd52 | ||
|
|
245f6139b6 | ||
|
|
3365ad18a5 | ||
|
|
f09904720d | ||
|
|
abe2ec26a6 | ||
|
|
bdeac8897e | ||
|
|
451af70154 | ||
|
|
0fc15700be | ||
|
|
e755268e7b | ||
|
|
c4a14df9a3 | ||
|
|
965d0ed509 | ||
|
|
ddc541ffda | ||
|
|
8ccc0c94fa | ||
|
|
4edb87aa50 | ||
|
|
0fc3b6e3a6 | ||
|
|
2108167f9f | ||
|
|
9d273d3ab1 | ||
|
|
70c91b8248 | ||
|
|
0da5a0fe58 | ||
|
|
e0eacb0688 | ||
|
|
7458e20465 | ||
|
|
b931b37e30 | ||
|
|
866a4619db | ||
|
|
1a72bf2046 | ||
|
|
034fac7054 | ||
|
|
a498556d0d | ||
|
|
f7ca41ff62 | ||
|
|
ac26065e61 | ||
|
|
190c4416cc | ||
|
|
0fd10ffa09 | ||
|
|
00c775950a | ||
|
|
7ac999bf30 | ||
|
|
0c6b36c6ac | ||
|
|
9125613b53 | ||
|
|
732b707397 | ||
|
|
4c816d5c69 | ||
|
|
6125b3a5e7 | ||
|
|
12918a5f78 | ||
|
|
8f40b43e02 | ||
|
|
3b832231bb | ||
|
|
be518db5a7 | ||
|
|
80441eb15e | ||
|
|
07f2462eae | ||
|
|
d150440466 | ||
|
|
6165c38cb5 | ||
|
|
712cca36a1 | ||
|
|
ac4d8ea9b3 | ||
|
|
c9196f355e | ||
|
|
7eb959ce93 | ||
|
|
469dd9c16a | ||
|
|
eff2b9d412 | ||
|
|
15b312de7a | ||
|
|
1419047fdb | ||
|
|
79f6bb5e4f | ||
|
|
e4b4fb3479 | ||
|
|
d9dc02a7d6 | ||
|
|
c543ad81c3 | ||
|
|
5ac1372533 | ||
|
|
1dcbd9efaf | ||
|
|
db9e6edfa1 | ||
|
|
8af13b439b | ||
|
|
acd0e53653 | ||
|
|
117e7a5853 | ||
|
|
b3c0e4de57 | ||
|
|
ecaeeb990d | ||
|
|
c2b65e2fce | ||
|
|
fd5c0755af | ||
|
|
c881a1d689 | ||
|
|
a3b5d4996a | ||
|
|
c6238047ee | ||
|
|
5cd1113236 | ||
|
|
2f642d5d9b | ||
|
|
cd912963f1 | ||
|
|
6e4b1f9d00 | ||
|
|
dc202a2e51 | ||
|
|
153bc524bf | ||
|
|
393d2880dd | ||
|
|
4484b93d61 | ||
|
|
bd0e6825e8 | ||
|
|
ec0a832acb | ||
|
|
04c49a29b4 | ||
|
|
4609fcd260 | ||
|
|
6207f86c18 | ||
|
|
1dc3da6314 | ||
|
|
114fc73685 | ||
|
|
b48d6a83d4 | ||
|
|
027042db68 | ||
|
|
1a20656448 | ||
|
|
0f11869d55 | ||
|
|
5943fbf457 | ||
|
|
a60b7b86c5 | ||
|
|
2e9d51680a | ||
|
|
50d6e1caf4 | ||
|
|
ac12f77bed | ||
|
|
fcd9a236b0 | ||
|
|
21e8425087 | ||
|
|
b6c79a648a | ||
|
|
25bc1b5b57 | ||
|
|
3cd19e99c1 | ||
|
|
007b87e7ac | ||
|
|
34751fe9f9 | ||
|
|
1c705f7bfb | ||
|
|
48e5ea1dfd | ||
|
|
3cd7b32f1b | ||
|
|
c0c9720d77 | ||
|
|
fc0cb10bcb | ||
|
|
b7d7cc1d49 | ||
|
|
79e94544bd | ||
|
|
ce0000c4f2 | ||
|
|
c5cfb34c07 | ||
|
|
edee33f55e | ||
|
|
2c03884f5f | ||
|
|
6e9ee55cdd | ||
|
|
023cf13721 | ||
|
|
c3566c0d76 | ||
|
|
c3c3e93c5b | ||
|
|
6ffc159bdd | ||
|
|
96e0d0924e | ||
|
|
e14f3b6610 | ||
|
|
1618002411 | ||
|
|
6ef85c4915 | ||
|
|
6da00dd899 | ||
|
|
4f3f9e72a9 | ||
|
|
d157c3299d | ||
|
|
d1b9822f74 | ||
|
|
f2b002372b | ||
|
|
38d0493825 | ||
|
|
acbf08cd60 | ||
|
|
53e762a3af | ||
|
|
9a552df898 | ||
|
|
f2fda021ab | ||
|
|
303b1735f8 | ||
|
|
9e5f677746 | ||
|
|
65cfcf5b1b | ||
|
|
1bdc9a947f | ||
|
|
d622a61874 | ||
|
|
236b9e211d | ||
|
|
6ca3d5c011 | ||
|
|
0be8a76c93 | ||
|
|
0357ed7ec4 | ||
|
|
f59f71cf34 | ||
|
|
178bdc5e14 | ||
|
|
25a1bfab4e | ||
|
|
d7111e426a | ||
|
|
0e6221cc79 | ||
|
|
9ca7e143af | ||
|
|
8fd07170f1 | ||
|
|
2943093a53 | ||
|
|
36deef2c57 | ||
|
|
0d2e4bdd44 | ||
|
|
eff4ea0b62 | ||
|
|
865568b7fc | ||
|
|
1e4e342f54 | ||
|
|
16fb6849d2 | ||
|
|
d9a76cf66e | ||
|
|
532e285079 | ||
|
|
4f067b07fb | ||
|
|
650e716dda | ||
|
|
e4c61d7555 | ||
|
|
22ff1bbfcb | ||
|
|
f4f44bb807 | ||
|
|
33aa808713 | ||
|
|
eb0e10aec4 | ||
|
|
c176b214cc | ||
|
|
91bf6b6aa3 | ||
|
|
807538fe6c | ||
|
|
bbb11e2608 | ||
|
|
0899012ad6 | ||
|
|
fb478f679a | ||
|
|
4c432c11ed | ||
|
|
31e961736a | ||
|
|
767ee30f21 | ||
|
|
3ab9748903 | ||
|
|
0aa7fa464e | ||
|
|
514c24d756 | ||
|
|
809ce68749 | ||
|
|
cc4ddba1b6 | ||
|
|
8376ff6831 | ||
|
|
5b4d0664c8 | ||
|
|
894802b0f9 | ||
|
|
28eaab608b | ||
|
|
6a2678ac65 | ||
|
|
e4fb3a3572 | ||
|
|
e8ebbe668e | ||
|
|
1ca89b810e | ||
|
|
bf7dc63bd6 | ||
|
|
86dbb89fc9 | ||
|
|
ba6080bbab | ||
|
|
16d85ea133 | ||
|
|
5d9ad0c6bf | ||
|
|
c08f97f344 | ||
|
|
887143854b | ||
|
|
3a5f239cb6 | ||
|
|
827bb1512b | ||
|
|
ffdd53b327 | ||
|
|
65e2103b09 | ||
|
|
9304e47351 | ||
|
|
bc606d7d64 | ||
|
|
645ee1881e | ||
|
|
3d082c3206 | ||
|
|
683569de55 | ||
|
|
ea2c117bc3 | ||
|
|
fc4af86068 | ||
|
|
41bcf0619d | ||
|
|
d02d0e5744 | ||
|
|
70541d4e77 | ||
|
|
77b2f7c228 | ||
|
|
43e0d4e3cc | ||
|
|
dbd330454a | ||
|
|
33c7f1179d | ||
|
|
af91eb6c99 | ||
|
|
5cb1e0c9a0 | ||
|
|
51347f9fb8 | ||
|
|
a5e85017d8 | ||
|
|
5ac3b26a7d | ||
|
|
6592bffc60 | ||
|
|
971cefe7d4 | ||
|
|
da2bfb5b0a | ||
|
|
c5a47a1692 | ||
|
|
908fd7d749 | ||
|
|
5495589db3 | ||
|
|
982876d59a | ||
|
|
338d9ae3bb | ||
|
|
eeb020b9b7 | ||
|
|
ae65433a60 | ||
|
|
fdebe18296 | ||
|
|
f8321eb57b | ||
|
|
93948e3fc5 | ||
|
|
e711aaf1a7 | ||
|
|
57ddb7fd13 | ||
|
|
17c92a9f28 | ||
|
|
36357bbcc3 | ||
|
|
f668c2e3c9 |
@@ -53,6 +53,16 @@ try:
|
||||
repo.stash(ident)
|
||||
except KeyError:
|
||||
print("nothing to stash") # noqa: T201
|
||||
except:
|
||||
print("Could not stash, cleaning index and trying again.") # noqa: T201
|
||||
repo.state_cleanup()
|
||||
repo.index.read_tree(repo.head.peel().tree)
|
||||
repo.index.write()
|
||||
try:
|
||||
repo.stash(ident)
|
||||
except KeyError:
|
||||
print("nothing to stash.") # noqa: T201
|
||||
|
||||
backup_branch_name = 'backup_branch_{}'.format(datetime.today().strftime('%Y-%m-%d_%H_%M_%S'))
|
||||
print("creating backup branch: {}".format(backup_branch_name)) # noqa: T201
|
||||
try:
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
..\python_embeded\python.exe -s ..\ComfyUI\main.py --windows-standalone-build --disable-api-nodes
|
||||
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
|
||||
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest. If you get a c10.dll error you need to install vc redist that you can find: https://aka.ms/vc14/vc_redist.x64.exe
|
||||
pause
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build
|
||||
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
|
||||
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest. If you get a c10.dll error you need to install vc redist that you can find: https://aka.ms/vc14/vc_redist.x64.exe
|
||||
pause
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast fp16_accumulation
|
||||
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
|
||||
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest. If you get a c10.dll error you need to install vc redist that you can find: https://aka.ms/vc14/vc_redist.x64.exe
|
||||
pause
|
||||
|
||||
6
.github/workflows/release-stable-all.yml
vendored
6
.github/workflows/release-stable-all.yml
vendored
@@ -20,7 +20,7 @@ jobs:
|
||||
git_tag: ${{ inputs.git_tag }}
|
||||
cache_tag: "cu130"
|
||||
python_minor: "13"
|
||||
python_patch: "9"
|
||||
python_patch: "11"
|
||||
rel_name: "nvidia"
|
||||
rel_extra_name: ""
|
||||
test_release: true
|
||||
@@ -65,11 +65,11 @@ jobs:
|
||||
contents: "write"
|
||||
packages: "write"
|
||||
pull-requests: "read"
|
||||
name: "Release AMD ROCm 7.1.1"
|
||||
name: "Release AMD ROCm 7.2"
|
||||
uses: ./.github/workflows/stable-release.yml
|
||||
with:
|
||||
git_tag: ${{ inputs.git_tag }}
|
||||
cache_tag: "rocm711"
|
||||
cache_tag: "rocm72"
|
||||
python_minor: "12"
|
||||
python_patch: "10"
|
||||
rel_name: "amd"
|
||||
|
||||
2
.github/workflows/stable-release.yml
vendored
2
.github/workflows/stable-release.yml
vendored
@@ -117,7 +117,7 @@ jobs:
|
||||
./python.exe get-pip.py
|
||||
./python.exe -s -m pip install ../${{ inputs.cache_tag }}_python_deps/*
|
||||
|
||||
grep comfyui ../ComfyUI/requirements.txt > ./requirements_comfyui.txt
|
||||
grep comfy ../ComfyUI/requirements.txt > ./requirements_comfyui.txt
|
||||
./python.exe -s -m pip install -r requirements_comfyui.txt
|
||||
rm requirements_comfyui.txt
|
||||
|
||||
|
||||
2
.github/workflows/test-build.yml
vendored
2
.github/workflows/test-build.yml
vendored
@@ -18,7 +18,7 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||
python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
|
||||
1
.github/workflows/test-ci.yml
vendored
1
.github/workflows/test-ci.yml
vendored
@@ -5,6 +5,7 @@ on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
- release/**
|
||||
paths-ignore:
|
||||
- 'app/**'
|
||||
- 'input/**'
|
||||
|
||||
4
.github/workflows/test-execution.yml
vendored
4
.github/workflows/test-execution.yml
vendored
@@ -2,9 +2,9 @@ name: Execution Tests
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main, master ]
|
||||
branches: [ main, master, release/** ]
|
||||
pull_request:
|
||||
branches: [ main, master ]
|
||||
branches: [ main, master, release/** ]
|
||||
|
||||
jobs:
|
||||
test:
|
||||
|
||||
10
.github/workflows/test-launch.yml
vendored
10
.github/workflows/test-launch.yml
vendored
@@ -2,9 +2,9 @@ name: Test server launches without errors
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main, master ]
|
||||
branches: [ main, master, release/** ]
|
||||
pull_request:
|
||||
branches: [ main, master ]
|
||||
branches: [ main, master, release/** ]
|
||||
|
||||
jobs:
|
||||
test:
|
||||
@@ -13,7 +13,7 @@ jobs:
|
||||
- name: Checkout ComfyUI
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
repository: "comfyanonymous/ComfyUI"
|
||||
repository: "Comfy-Org/ComfyUI"
|
||||
path: "ComfyUI"
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
@@ -32,7 +32,9 @@ jobs:
|
||||
working-directory: ComfyUI
|
||||
- name: Check for unhandled exceptions in server log
|
||||
run: |
|
||||
if grep -qE "Exception|Error" console_output.log; then
|
||||
grep -v "Found comfy_kitchen backend triton: {'available': False, 'disabled': True, 'unavailable_reason': \"ImportError: No module named 'triton'\", 'capabilities': \[\]}" console_output.log | grep -v "Found comfy_kitchen backend triton: {'available': False, 'disabled': False, 'unavailable_reason': \"ImportError: No module named 'triton'\", 'capabilities': \[\]}" > console_output_filtered.log
|
||||
cat console_output_filtered.log
|
||||
if grep -qE "Exception|Error" console_output_filtered.log; then
|
||||
echo "Unhandled exception/error found in server log."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
4
.github/workflows/test-unit.yml
vendored
4
.github/workflows/test-unit.yml
vendored
@@ -2,9 +2,9 @@ name: Unit Tests
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main, master ]
|
||||
branches: [ main, master, release/** ]
|
||||
pull_request:
|
||||
branches: [ main, master ]
|
||||
branches: [ main, master, release/** ]
|
||||
|
||||
jobs:
|
||||
test:
|
||||
|
||||
59
.github/workflows/update-ci-container.yml
vendored
Normal file
59
.github/workflows/update-ci-container.yml
vendored
Normal file
@@ -0,0 +1,59 @@
|
||||
name: "CI: Update CI Container"
|
||||
|
||||
on:
|
||||
release:
|
||||
types: [published]
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
version:
|
||||
description: 'ComfyUI version (e.g., v0.7.0)'
|
||||
required: true
|
||||
type: string
|
||||
|
||||
jobs:
|
||||
update-ci-container:
|
||||
runs-on: ubuntu-latest
|
||||
# Skip pre-releases unless manually triggered
|
||||
if: github.event_name == 'workflow_dispatch' || !github.event.release.prerelease
|
||||
steps:
|
||||
- name: Get version
|
||||
id: version
|
||||
run: |
|
||||
if [ "${{ github.event_name }}" = "release" ]; then
|
||||
VERSION="${{ github.event.release.tag_name }}"
|
||||
else
|
||||
VERSION="${{ inputs.version }}"
|
||||
fi
|
||||
echo "version=$VERSION" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Checkout comfyui-ci-container
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
repository: comfy-org/comfyui-ci-container
|
||||
token: ${{ secrets.CI_CONTAINER_PAT }}
|
||||
|
||||
- name: Check current version
|
||||
id: current
|
||||
run: |
|
||||
CURRENT=$(grep -oP 'ARG COMFYUI_VERSION=\K.*' Dockerfile || echo "unknown")
|
||||
echo "current_version=$CURRENT" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Update Dockerfile
|
||||
run: |
|
||||
VERSION="${{ steps.version.outputs.version }}"
|
||||
sed -i "s/^ARG COMFYUI_VERSION=.*/ARG COMFYUI_VERSION=${VERSION}/" Dockerfile
|
||||
|
||||
- name: Create Pull Request
|
||||
id: create-pr
|
||||
uses: peter-evans/create-pull-request@v7
|
||||
with:
|
||||
token: ${{ secrets.CI_CONTAINER_PAT }}
|
||||
branch: automation/comfyui-${{ steps.version.outputs.version }}
|
||||
title: "chore: bump ComfyUI to ${{ steps.version.outputs.version }}"
|
||||
body: |
|
||||
Updates ComfyUI version from `${{ steps.current.outputs.current_version }}` to `${{ steps.version.outputs.version }}`
|
||||
|
||||
**Triggered by:** ${{ github.event_name == 'release' && format('[Release {0}]({1})', github.event.release.tag_name, github.event.release.html_url) || 'Manual workflow dispatch' }}
|
||||
|
||||
labels: automation
|
||||
commit-message: "chore: bump ComfyUI to ${{ steps.version.outputs.version }}"
|
||||
1
.github/workflows/update-version.yml
vendored
1
.github/workflows/update-version.yml
vendored
@@ -6,6 +6,7 @@ on:
|
||||
- "pyproject.toml"
|
||||
branches:
|
||||
- master
|
||||
- release/**
|
||||
|
||||
jobs:
|
||||
update-version:
|
||||
|
||||
15
README.md
15
README.md
@@ -108,7 +108,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
|
||||
- [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/)
|
||||
- Latent previews with [TAESD](#how-to-show-high-quality-previews)
|
||||
- Works fully offline: core will never download anything unless you want to.
|
||||
- Optional API nodes to use paid models from external providers through the online [Comfy API](https://docs.comfy.org/tutorials/api-nodes/overview).
|
||||
- Optional API nodes to use paid models from external providers through the online [Comfy API](https://docs.comfy.org/tutorials/api-nodes/overview) disable with: `--disable-api-nodes`
|
||||
- [Config file](extra_model_paths.yaml.example) to set the search paths for models.
|
||||
|
||||
Workflow examples can be found on the [Examples page](https://comfyanonymous.github.io/ComfyUI_examples/)
|
||||
@@ -119,6 +119,9 @@ ComfyUI follows a weekly release cycle targeting Monday but this regularly chang
|
||||
|
||||
1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)**
|
||||
- Releases a new stable version (e.g., v0.7.0) roughly every week.
|
||||
- Starting from v0.4.0 patch versions will be used for fixes backported onto the current stable release.
|
||||
- Minor versions will be used for releases off the master branch.
|
||||
- Patch versions may still be used for releases on the master branch in cases where a backport would not make sense.
|
||||
- Commits outside of the stable release tags may be very unstable and break many custom nodes.
|
||||
- Serves as the foundation for the desktop release
|
||||
|
||||
@@ -180,7 +183,7 @@ Simply download, extract with [7-Zip](https://7-zip.org) or with the windows exp
|
||||
|
||||
If you have trouble extracting it, right click the file -> properties -> unblock
|
||||
|
||||
Update your Nvidia drivers if it doesn't start.
|
||||
The portable above currently comes with python 3.13 and pytorch cuda 13.0. Update your Nvidia drivers if it doesn't start.
|
||||
|
||||
#### Alternative Downloads:
|
||||
|
||||
@@ -205,10 +208,12 @@ comfy install
|
||||
|
||||
## Manual Install (Windows, Linux)
|
||||
|
||||
Python 3.14 works but you may encounter issues with the torch compile node. The free threaded variant is still missing some dependencies.
|
||||
Python 3.14 works but some custom nodes may have issues. The free threaded variant works but some dependencies will enable the GIL so it's not fully supported.
|
||||
|
||||
Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12
|
||||
|
||||
torch 2.4 and above is supported but some features and optimizations might only work on newer versions. We generally recommend using the latest major version of pytorch with the latest cuda version unless it is less than 2 weeks old.
|
||||
|
||||
### Instructions:
|
||||
|
||||
Git clone this repo.
|
||||
@@ -224,7 +229,7 @@ AMD users can install rocm and pytorch with pip if you don't have it already ins
|
||||
|
||||
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.4```
|
||||
|
||||
This is the command to install the nightly with ROCm 7.0 which might have some performance improvements:
|
||||
This is the command to install the nightly with ROCm 7.1 which might have some performance improvements:
|
||||
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.1```
|
||||
|
||||
@@ -235,7 +240,7 @@ These have less hardware support than the builds above but they work on windows.
|
||||
|
||||
RDNA 3 (RX 7000 series):
|
||||
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx110X-dgpu/```
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx110X-all/```
|
||||
|
||||
RDNA 3.5 (Strix halo/Ryzen AI Max+ 365):
|
||||
|
||||
|
||||
174
alembic_db/versions/0001_assets.py
Normal file
174
alembic_db/versions/0001_assets.py
Normal file
@@ -0,0 +1,174 @@
|
||||
"""
|
||||
Initial assets schema
|
||||
Revision ID: 0001_assets
|
||||
Revises: None
|
||||
Create Date: 2025-12-10 00:00:00
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "0001_assets"
|
||||
down_revision = None
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ASSETS: content identity
|
||||
op.create_table(
|
||||
"assets",
|
||||
sa.Column("id", sa.String(length=36), primary_key=True),
|
||||
sa.Column("hash", sa.String(length=256), nullable=True),
|
||||
sa.Column("size_bytes", sa.BigInteger(), nullable=False, server_default="0"),
|
||||
sa.Column("mime_type", sa.String(length=255), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=False), nullable=False),
|
||||
sa.CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"),
|
||||
)
|
||||
op.create_index("uq_assets_hash", "assets", ["hash"], unique=True)
|
||||
op.create_index("ix_assets_mime_type", "assets", ["mime_type"])
|
||||
|
||||
# ASSETS_INFO: user-visible references
|
||||
op.create_table(
|
||||
"assets_info",
|
||||
sa.Column("id", sa.String(length=36), primary_key=True),
|
||||
sa.Column("owner_id", sa.String(length=128), nullable=False, server_default=""),
|
||||
sa.Column("name", sa.String(length=512), nullable=False),
|
||||
sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False),
|
||||
sa.Column("preview_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="SET NULL"), nullable=True),
|
||||
sa.Column("user_metadata", sa.JSON(), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=False), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=False), nullable=False),
|
||||
sa.Column("last_access_time", sa.DateTime(timezone=False), nullable=False),
|
||||
sa.UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"),
|
||||
)
|
||||
op.create_index("ix_assets_info_owner_id", "assets_info", ["owner_id"])
|
||||
op.create_index("ix_assets_info_asset_id", "assets_info", ["asset_id"])
|
||||
op.create_index("ix_assets_info_name", "assets_info", ["name"])
|
||||
op.create_index("ix_assets_info_created_at", "assets_info", ["created_at"])
|
||||
op.create_index("ix_assets_info_last_access_time", "assets_info", ["last_access_time"])
|
||||
op.create_index("ix_assets_info_owner_name", "assets_info", ["owner_id", "name"])
|
||||
|
||||
# TAGS: normalized tag vocabulary
|
||||
op.create_table(
|
||||
"tags",
|
||||
sa.Column("name", sa.String(length=512), primary_key=True),
|
||||
sa.Column("tag_type", sa.String(length=32), nullable=False, server_default="user"),
|
||||
sa.CheckConstraint("name = lower(name)", name="ck_tags_lowercase"),
|
||||
)
|
||||
op.create_index("ix_tags_tag_type", "tags", ["tag_type"])
|
||||
|
||||
# ASSET_INFO_TAGS: many-to-many for tags on AssetInfo
|
||||
op.create_table(
|
||||
"asset_info_tags",
|
||||
sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False),
|
||||
sa.Column("tag_name", sa.String(length=512), sa.ForeignKey("tags.name", ondelete="RESTRICT"), nullable=False),
|
||||
sa.Column("origin", sa.String(length=32), nullable=False, server_default="manual"),
|
||||
sa.Column("added_at", sa.DateTime(timezone=False), nullable=False),
|
||||
sa.PrimaryKeyConstraint("asset_info_id", "tag_name", name="pk_asset_info_tags"),
|
||||
)
|
||||
op.create_index("ix_asset_info_tags_tag_name", "asset_info_tags", ["tag_name"])
|
||||
op.create_index("ix_asset_info_tags_asset_info_id", "asset_info_tags", ["asset_info_id"])
|
||||
|
||||
# ASSET_CACHE_STATE: N:1 local cache rows per Asset
|
||||
op.create_table(
|
||||
"asset_cache_state",
|
||||
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
|
||||
sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="CASCADE"), nullable=False),
|
||||
sa.Column("file_path", sa.Text(), nullable=False), # absolute local path to cached file
|
||||
sa.Column("mtime_ns", sa.BigInteger(), nullable=True),
|
||||
sa.Column("needs_verify", sa.Boolean(), nullable=False, server_default=sa.text("false")),
|
||||
sa.CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
|
||||
sa.UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"),
|
||||
)
|
||||
op.create_index("ix_asset_cache_state_file_path", "asset_cache_state", ["file_path"])
|
||||
op.create_index("ix_asset_cache_state_asset_id", "asset_cache_state", ["asset_id"])
|
||||
|
||||
# ASSET_INFO_META: typed KV projection of user_metadata for filtering/sorting
|
||||
op.create_table(
|
||||
"asset_info_meta",
|
||||
sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False),
|
||||
sa.Column("key", sa.String(length=256), nullable=False),
|
||||
sa.Column("ordinal", sa.Integer(), nullable=False, server_default="0"),
|
||||
sa.Column("val_str", sa.String(length=2048), nullable=True),
|
||||
sa.Column("val_num", sa.Numeric(38, 10), nullable=True),
|
||||
sa.Column("val_bool", sa.Boolean(), nullable=True),
|
||||
sa.Column("val_json", sa.JSON(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("asset_info_id", "key", "ordinal", name="pk_asset_info_meta"),
|
||||
)
|
||||
op.create_index("ix_asset_info_meta_key", "asset_info_meta", ["key"])
|
||||
op.create_index("ix_asset_info_meta_key_val_str", "asset_info_meta", ["key", "val_str"])
|
||||
op.create_index("ix_asset_info_meta_key_val_num", "asset_info_meta", ["key", "val_num"])
|
||||
op.create_index("ix_asset_info_meta_key_val_bool", "asset_info_meta", ["key", "val_bool"])
|
||||
|
||||
# Tags vocabulary
|
||||
tags_table = sa.table(
|
||||
"tags",
|
||||
sa.column("name", sa.String(length=512)),
|
||||
sa.column("tag_type", sa.String()),
|
||||
)
|
||||
op.bulk_insert(
|
||||
tags_table,
|
||||
[
|
||||
{"name": "models", "tag_type": "system"},
|
||||
{"name": "input", "tag_type": "system"},
|
||||
{"name": "output", "tag_type": "system"},
|
||||
|
||||
{"name": "configs", "tag_type": "system"},
|
||||
{"name": "checkpoints", "tag_type": "system"},
|
||||
{"name": "loras", "tag_type": "system"},
|
||||
{"name": "vae", "tag_type": "system"},
|
||||
{"name": "text_encoders", "tag_type": "system"},
|
||||
{"name": "diffusion_models", "tag_type": "system"},
|
||||
{"name": "clip_vision", "tag_type": "system"},
|
||||
{"name": "style_models", "tag_type": "system"},
|
||||
{"name": "embeddings", "tag_type": "system"},
|
||||
{"name": "diffusers", "tag_type": "system"},
|
||||
{"name": "vae_approx", "tag_type": "system"},
|
||||
{"name": "controlnet", "tag_type": "system"},
|
||||
{"name": "gligen", "tag_type": "system"},
|
||||
{"name": "upscale_models", "tag_type": "system"},
|
||||
{"name": "hypernetworks", "tag_type": "system"},
|
||||
{"name": "photomaker", "tag_type": "system"},
|
||||
{"name": "classifiers", "tag_type": "system"},
|
||||
|
||||
{"name": "encoder", "tag_type": "system"},
|
||||
{"name": "decoder", "tag_type": "system"},
|
||||
|
||||
{"name": "missing", "tag_type": "system"},
|
||||
{"name": "rescan", "tag_type": "system"},
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_asset_info_meta_key_val_bool", table_name="asset_info_meta")
|
||||
op.drop_index("ix_asset_info_meta_key_val_num", table_name="asset_info_meta")
|
||||
op.drop_index("ix_asset_info_meta_key_val_str", table_name="asset_info_meta")
|
||||
op.drop_index("ix_asset_info_meta_key", table_name="asset_info_meta")
|
||||
op.drop_table("asset_info_meta")
|
||||
|
||||
op.drop_index("ix_asset_cache_state_asset_id", table_name="asset_cache_state")
|
||||
op.drop_index("ix_asset_cache_state_file_path", table_name="asset_cache_state")
|
||||
op.drop_constraint("uq_asset_cache_state_file_path", table_name="asset_cache_state")
|
||||
op.drop_table("asset_cache_state")
|
||||
|
||||
op.drop_index("ix_asset_info_tags_asset_info_id", table_name="asset_info_tags")
|
||||
op.drop_index("ix_asset_info_tags_tag_name", table_name="asset_info_tags")
|
||||
op.drop_table("asset_info_tags")
|
||||
|
||||
op.drop_index("ix_tags_tag_type", table_name="tags")
|
||||
op.drop_table("tags")
|
||||
|
||||
op.drop_constraint("uq_assets_info_asset_owner_name", table_name="assets_info")
|
||||
op.drop_index("ix_assets_info_owner_name", table_name="assets_info")
|
||||
op.drop_index("ix_assets_info_last_access_time", table_name="assets_info")
|
||||
op.drop_index("ix_assets_info_created_at", table_name="assets_info")
|
||||
op.drop_index("ix_assets_info_name", table_name="assets_info")
|
||||
op.drop_index("ix_assets_info_asset_id", table_name="assets_info")
|
||||
op.drop_index("ix_assets_info_owner_id", table_name="assets_info")
|
||||
op.drop_table("assets_info")
|
||||
|
||||
op.drop_index("uq_assets_hash", table_name="assets")
|
||||
op.drop_index("ix_assets_mime_type", table_name="assets")
|
||||
op.drop_table("assets")
|
||||
@@ -58,8 +58,13 @@ class InternalRoutes:
|
||||
return web.json_response({"error": "Invalid directory type"}, status=400)
|
||||
|
||||
directory = get_directory_by_type(directory_type)
|
||||
|
||||
def is_visible_file(entry: os.DirEntry) -> bool:
|
||||
"""Filter out hidden files (e.g., .DS_Store on macOS)."""
|
||||
return entry.is_file() and not entry.name.startswith('.')
|
||||
|
||||
sorted_files = sorted(
|
||||
(entry for entry in os.scandir(directory) if entry.is_file()),
|
||||
(entry for entry in os.scandir(directory) if is_visible_file(entry)),
|
||||
key=lambda entry: -entry.stat().st_mtime
|
||||
)
|
||||
return web.json_response([entry.name for entry in sorted_files], status=200)
|
||||
|
||||
102
app/assets/api/routes.py
Normal file
102
app/assets/api/routes.py
Normal file
@@ -0,0 +1,102 @@
|
||||
import logging
|
||||
import uuid
|
||||
from aiohttp import web
|
||||
|
||||
from pydantic import ValidationError
|
||||
|
||||
import app.assets.manager as manager
|
||||
from app import user_manager
|
||||
from app.assets.api import schemas_in
|
||||
from app.assets.helpers import get_query_dict
|
||||
|
||||
ROUTES = web.RouteTableDef()
|
||||
USER_MANAGER: user_manager.UserManager | None = None
|
||||
|
||||
# UUID regex (canonical hyphenated form, case-insensitive)
|
||||
UUID_RE = r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}"
|
||||
|
||||
def register_assets_system(app: web.Application, user_manager_instance: user_manager.UserManager) -> None:
|
||||
global USER_MANAGER
|
||||
USER_MANAGER = user_manager_instance
|
||||
app.add_routes(ROUTES)
|
||||
|
||||
def _error_response(status: int, code: str, message: str, details: dict | None = None) -> web.Response:
|
||||
return web.json_response({"error": {"code": code, "message": message, "details": details or {}}}, status=status)
|
||||
|
||||
|
||||
def _validation_error_response(code: str, ve: ValidationError) -> web.Response:
|
||||
return _error_response(400, code, "Validation failed.", {"errors": ve.json()})
|
||||
|
||||
|
||||
@ROUTES.get("/api/assets")
|
||||
async def list_assets(request: web.Request) -> web.Response:
|
||||
"""
|
||||
GET request to list assets.
|
||||
"""
|
||||
query_dict = get_query_dict(request)
|
||||
try:
|
||||
q = schemas_in.ListAssetsQuery.model_validate(query_dict)
|
||||
except ValidationError as ve:
|
||||
return _validation_error_response("INVALID_QUERY", ve)
|
||||
|
||||
payload = manager.list_assets(
|
||||
include_tags=q.include_tags,
|
||||
exclude_tags=q.exclude_tags,
|
||||
name_contains=q.name_contains,
|
||||
metadata_filter=q.metadata_filter,
|
||||
limit=q.limit,
|
||||
offset=q.offset,
|
||||
sort=q.sort,
|
||||
order=q.order,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return web.json_response(payload.model_dump(mode="json"))
|
||||
|
||||
|
||||
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}")
|
||||
async def get_asset(request: web.Request) -> web.Response:
|
||||
"""
|
||||
GET request to get an asset's info as JSON.
|
||||
"""
|
||||
asset_info_id = str(uuid.UUID(request.match_info["id"]))
|
||||
try:
|
||||
result = manager.get_asset(
|
||||
asset_info_id=asset_info_id,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
except ValueError as e:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", str(e), {"id": asset_info_id})
|
||||
except Exception:
|
||||
logging.exception(
|
||||
"get_asset failed for asset_info_id=%s, owner_id=%s",
|
||||
asset_info_id,
|
||||
USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
return web.json_response(result.model_dump(mode="json"), status=200)
|
||||
|
||||
|
||||
@ROUTES.get("/api/tags")
|
||||
async def get_tags(request: web.Request) -> web.Response:
|
||||
"""
|
||||
GET request to list all tags based on query parameters.
|
||||
"""
|
||||
query_map = dict(request.rel_url.query)
|
||||
|
||||
try:
|
||||
query = schemas_in.TagsListQuery.model_validate(query_map)
|
||||
except ValidationError as e:
|
||||
return web.json_response(
|
||||
{"error": {"code": "INVALID_QUERY", "message": "Invalid query parameters", "details": e.errors()}},
|
||||
status=400,
|
||||
)
|
||||
|
||||
result = manager.list_tags(
|
||||
prefix=query.prefix,
|
||||
limit=query.limit,
|
||||
offset=query.offset,
|
||||
order=query.order,
|
||||
include_zero=query.include_zero,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return web.json_response(result.model_dump(mode="json"))
|
||||
94
app/assets/api/schemas_in.py
Normal file
94
app/assets/api/schemas_in.py
Normal file
@@ -0,0 +1,94 @@
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
conint,
|
||||
field_validator,
|
||||
)
|
||||
|
||||
|
||||
class ListAssetsQuery(BaseModel):
|
||||
include_tags: list[str] = Field(default_factory=list)
|
||||
exclude_tags: list[str] = Field(default_factory=list)
|
||||
name_contains: str | None = None
|
||||
|
||||
# Accept either a JSON string (query param) or a dict
|
||||
metadata_filter: dict[str, Any] | None = None
|
||||
|
||||
limit: conint(ge=1, le=500) = 20
|
||||
offset: conint(ge=0) = 0
|
||||
|
||||
sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = "created_at"
|
||||
order: Literal["asc", "desc"] = "desc"
|
||||
|
||||
@field_validator("include_tags", "exclude_tags", mode="before")
|
||||
@classmethod
|
||||
def _split_csv_tags(cls, v):
|
||||
# Accept "a,b,c" or ["a","b"] (we are liberal in what we accept)
|
||||
if v is None:
|
||||
return []
|
||||
if isinstance(v, str):
|
||||
return [t.strip() for t in v.split(",") if t.strip()]
|
||||
if isinstance(v, list):
|
||||
out: list[str] = []
|
||||
for item in v:
|
||||
if isinstance(item, str):
|
||||
out.extend([t.strip() for t in item.split(",") if t.strip()])
|
||||
return out
|
||||
return v
|
||||
|
||||
@field_validator("metadata_filter", mode="before")
|
||||
@classmethod
|
||||
def _parse_metadata_json(cls, v):
|
||||
if v is None or isinstance(v, dict):
|
||||
return v
|
||||
if isinstance(v, str) and v.strip():
|
||||
try:
|
||||
parsed = json.loads(v)
|
||||
except Exception as e:
|
||||
raise ValueError(f"metadata_filter must be JSON: {e}") from e
|
||||
if not isinstance(parsed, dict):
|
||||
raise ValueError("metadata_filter must be a JSON object")
|
||||
return parsed
|
||||
return None
|
||||
|
||||
|
||||
class TagsListQuery(BaseModel):
|
||||
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
|
||||
|
||||
prefix: str | None = Field(None, min_length=1, max_length=256)
|
||||
limit: int = Field(100, ge=1, le=1000)
|
||||
offset: int = Field(0, ge=0, le=10_000_000)
|
||||
order: Literal["count_desc", "name_asc"] = "count_desc"
|
||||
include_zero: bool = True
|
||||
|
||||
@field_validator("prefix")
|
||||
@classmethod
|
||||
def normalize_prefix(cls, v: str | None) -> str | None:
|
||||
if v is None:
|
||||
return v
|
||||
v = v.strip()
|
||||
return v.lower() or None
|
||||
|
||||
|
||||
class SetPreviewBody(BaseModel):
|
||||
"""Set or clear the preview for an AssetInfo. Provide an Asset.id or null."""
|
||||
preview_id: str | None = None
|
||||
|
||||
@field_validator("preview_id", mode="before")
|
||||
@classmethod
|
||||
def _norm_uuid(cls, v):
|
||||
if v is None:
|
||||
return None
|
||||
s = str(v).strip()
|
||||
if not s:
|
||||
return None
|
||||
try:
|
||||
uuid.UUID(s)
|
||||
except Exception:
|
||||
raise ValueError("preview_id must be a UUID")
|
||||
return s
|
||||
60
app/assets/api/schemas_out.py
Normal file
60
app/assets/api/schemas_out.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_serializer
|
||||
|
||||
|
||||
class AssetSummary(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
asset_hash: str | None = None
|
||||
size: int | None = None
|
||||
mime_type: str | None = None
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
preview_url: str | None = None
|
||||
created_at: datetime | None = None
|
||||
updated_at: datetime | None = None
|
||||
last_access_time: datetime | None = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@field_serializer("created_at", "updated_at", "last_access_time")
|
||||
def _ser_dt(self, v: datetime | None, _info):
|
||||
return v.isoformat() if v else None
|
||||
|
||||
|
||||
class AssetsList(BaseModel):
|
||||
assets: list[AssetSummary]
|
||||
total: int
|
||||
has_more: bool
|
||||
|
||||
|
||||
class AssetDetail(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
asset_hash: str | None = None
|
||||
size: int | None = None
|
||||
mime_type: str | None = None
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
preview_id: str | None = None
|
||||
created_at: datetime | None = None
|
||||
last_access_time: datetime | None = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@field_serializer("created_at", "last_access_time")
|
||||
def _ser_dt(self, v: datetime | None, _info):
|
||||
return v.isoformat() if v else None
|
||||
|
||||
|
||||
class TagUsage(BaseModel):
|
||||
name: str
|
||||
count: int
|
||||
type: str
|
||||
|
||||
|
||||
class TagsList(BaseModel):
|
||||
tags: list[TagUsage] = Field(default_factory=list)
|
||||
total: int
|
||||
has_more: bool
|
||||
204
app/assets/database/bulk_ops.py
Normal file
204
app/assets/database/bulk_ops.py
Normal file
@@ -0,0 +1,204 @@
|
||||
import os
|
||||
import uuid
|
||||
import sqlalchemy
|
||||
from typing import Iterable
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.dialects import sqlite
|
||||
|
||||
from app.assets.helpers import utcnow
|
||||
from app.assets.database.models import Asset, AssetCacheState, AssetInfo, AssetInfoTag, AssetInfoMeta
|
||||
|
||||
MAX_BIND_PARAMS = 800
|
||||
|
||||
def _chunk_rows(rows: list[dict], cols_per_row: int, max_bind_params: int) -> Iterable[list[dict]]:
|
||||
if not rows:
|
||||
return []
|
||||
rows_per_stmt = max(1, max_bind_params // max(1, cols_per_row))
|
||||
for i in range(0, len(rows), rows_per_stmt):
|
||||
yield rows[i:i + rows_per_stmt]
|
||||
|
||||
def _iter_chunks(seq, n: int):
|
||||
for i in range(0, len(seq), n):
|
||||
yield seq[i:i + n]
|
||||
|
||||
def _rows_per_stmt(cols: int) -> int:
|
||||
return max(1, MAX_BIND_PARAMS // max(1, cols))
|
||||
|
||||
|
||||
def seed_from_paths_batch(
|
||||
session: Session,
|
||||
*,
|
||||
specs: list[dict],
|
||||
owner_id: str = "",
|
||||
) -> dict:
|
||||
"""Each spec is a dict with keys:
|
||||
- abs_path: str
|
||||
- size_bytes: int
|
||||
- mtime_ns: int
|
||||
- info_name: str
|
||||
- tags: list[str]
|
||||
- fname: Optional[str]
|
||||
"""
|
||||
if not specs:
|
||||
return {"inserted_infos": 0, "won_states": 0, "lost_states": 0}
|
||||
|
||||
now = utcnow()
|
||||
asset_rows: list[dict] = []
|
||||
state_rows: list[dict] = []
|
||||
path_to_asset: dict[str, str] = {}
|
||||
asset_to_info: dict[str, dict] = {} # asset_id -> prepared info row
|
||||
path_list: list[str] = []
|
||||
|
||||
for sp in specs:
|
||||
ap = os.path.abspath(sp["abs_path"])
|
||||
aid = str(uuid.uuid4())
|
||||
iid = str(uuid.uuid4())
|
||||
path_list.append(ap)
|
||||
path_to_asset[ap] = aid
|
||||
|
||||
asset_rows.append(
|
||||
{
|
||||
"id": aid,
|
||||
"hash": None,
|
||||
"size_bytes": sp["size_bytes"],
|
||||
"mime_type": None,
|
||||
"created_at": now,
|
||||
}
|
||||
)
|
||||
state_rows.append(
|
||||
{
|
||||
"asset_id": aid,
|
||||
"file_path": ap,
|
||||
"mtime_ns": sp["mtime_ns"],
|
||||
}
|
||||
)
|
||||
asset_to_info[aid] = {
|
||||
"id": iid,
|
||||
"owner_id": owner_id,
|
||||
"name": sp["info_name"],
|
||||
"asset_id": aid,
|
||||
"preview_id": None,
|
||||
"user_metadata": {"filename": sp["fname"]} if sp["fname"] else None,
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
"last_access_time": now,
|
||||
"_tags": sp["tags"],
|
||||
"_filename": sp["fname"],
|
||||
}
|
||||
|
||||
# insert all seed Assets (hash=NULL)
|
||||
ins_asset = sqlite.insert(Asset)
|
||||
for chunk in _iter_chunks(asset_rows, _rows_per_stmt(5)):
|
||||
session.execute(ins_asset, chunk)
|
||||
|
||||
# try to claim AssetCacheState (file_path)
|
||||
# Insert with ON CONFLICT DO NOTHING, then query to find which paths were actually inserted
|
||||
ins_state = (
|
||||
sqlite.insert(AssetCacheState)
|
||||
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
|
||||
)
|
||||
for chunk in _iter_chunks(state_rows, _rows_per_stmt(3)):
|
||||
session.execute(ins_state, chunk)
|
||||
|
||||
# Query to find which of our paths won (were actually inserted)
|
||||
winners_by_path: set[str] = set()
|
||||
for chunk in _iter_chunks(path_list, MAX_BIND_PARAMS):
|
||||
result = session.execute(
|
||||
sqlalchemy.select(AssetCacheState.file_path)
|
||||
.where(AssetCacheState.file_path.in_(chunk))
|
||||
.where(AssetCacheState.asset_id.in_([path_to_asset[p] for p in chunk]))
|
||||
)
|
||||
winners_by_path.update(result.scalars().all())
|
||||
|
||||
all_paths_set = set(path_list)
|
||||
losers_by_path = all_paths_set - winners_by_path
|
||||
lost_assets = [path_to_asset[p] for p in losers_by_path]
|
||||
if lost_assets: # losers get their Asset removed
|
||||
for id_chunk in _iter_chunks(lost_assets, MAX_BIND_PARAMS):
|
||||
session.execute(sqlalchemy.delete(Asset).where(Asset.id.in_(id_chunk)))
|
||||
|
||||
if not winners_by_path:
|
||||
return {"inserted_infos": 0, "won_states": 0, "lost_states": len(losers_by_path)}
|
||||
|
||||
# insert AssetInfo only for winners
|
||||
# Insert with ON CONFLICT DO NOTHING, then query to find which were actually inserted
|
||||
winner_info_rows = [asset_to_info[path_to_asset[p]] for p in winners_by_path]
|
||||
ins_info = (
|
||||
sqlite.insert(AssetInfo)
|
||||
.on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name])
|
||||
)
|
||||
for chunk in _iter_chunks(winner_info_rows, _rows_per_stmt(9)):
|
||||
session.execute(ins_info, chunk)
|
||||
|
||||
# Query to find which info rows were actually inserted (by matching our generated IDs)
|
||||
all_info_ids = [row["id"] for row in winner_info_rows]
|
||||
inserted_info_ids: set[str] = set()
|
||||
for chunk in _iter_chunks(all_info_ids, MAX_BIND_PARAMS):
|
||||
result = session.execute(
|
||||
sqlalchemy.select(AssetInfo.id).where(AssetInfo.id.in_(chunk))
|
||||
)
|
||||
inserted_info_ids.update(result.scalars().all())
|
||||
|
||||
# build and insert tag + meta rows for the AssetInfo
|
||||
tag_rows: list[dict] = []
|
||||
meta_rows: list[dict] = []
|
||||
if inserted_info_ids:
|
||||
for row in winner_info_rows:
|
||||
iid = row["id"]
|
||||
if iid not in inserted_info_ids:
|
||||
continue
|
||||
for t in row["_tags"]:
|
||||
tag_rows.append({
|
||||
"asset_info_id": iid,
|
||||
"tag_name": t,
|
||||
"origin": "automatic",
|
||||
"added_at": now,
|
||||
})
|
||||
if row["_filename"]:
|
||||
meta_rows.append(
|
||||
{
|
||||
"asset_info_id": iid,
|
||||
"key": "filename",
|
||||
"ordinal": 0,
|
||||
"val_str": row["_filename"],
|
||||
"val_num": None,
|
||||
"val_bool": None,
|
||||
"val_json": None,
|
||||
}
|
||||
)
|
||||
|
||||
bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=meta_rows, max_bind_params=MAX_BIND_PARAMS)
|
||||
return {
|
||||
"inserted_infos": len(inserted_info_ids),
|
||||
"won_states": len(winners_by_path),
|
||||
"lost_states": len(losers_by_path),
|
||||
}
|
||||
|
||||
|
||||
def bulk_insert_tags_and_meta(
|
||||
session: Session,
|
||||
*,
|
||||
tag_rows: list[dict],
|
||||
meta_rows: list[dict],
|
||||
max_bind_params: int,
|
||||
) -> None:
|
||||
"""Batch insert into asset_info_tags and asset_info_meta with ON CONFLICT DO NOTHING.
|
||||
- tag_rows keys: asset_info_id, tag_name, origin, added_at
|
||||
- meta_rows keys: asset_info_id, key, ordinal, val_str, val_num, val_bool, val_json
|
||||
"""
|
||||
if tag_rows:
|
||||
ins_links = (
|
||||
sqlite.insert(AssetInfoTag)
|
||||
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
|
||||
)
|
||||
for chunk in _chunk_rows(tag_rows, cols_per_row=4, max_bind_params=max_bind_params):
|
||||
session.execute(ins_links, chunk)
|
||||
if meta_rows:
|
||||
ins_meta = (
|
||||
sqlite.insert(AssetInfoMeta)
|
||||
.on_conflict_do_nothing(
|
||||
index_elements=[AssetInfoMeta.asset_info_id, AssetInfoMeta.key, AssetInfoMeta.ordinal]
|
||||
)
|
||||
)
|
||||
for chunk in _chunk_rows(meta_rows, cols_per_row=7, max_bind_params=max_bind_params):
|
||||
session.execute(ins_meta, chunk)
|
||||
233
app/assets/database/models.py
Normal file
233
app/assets/database/models.py
Normal file
@@ -0,0 +1,233 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from typing import Any
|
||||
from sqlalchemy import (
|
||||
JSON,
|
||||
BigInteger,
|
||||
Boolean,
|
||||
CheckConstraint,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
Index,
|
||||
Integer,
|
||||
Numeric,
|
||||
String,
|
||||
Text,
|
||||
UniqueConstraint,
|
||||
)
|
||||
from sqlalchemy.orm import Mapped, foreign, mapped_column, relationship
|
||||
|
||||
from app.assets.helpers import utcnow
|
||||
from app.database.models import to_dict, Base
|
||||
|
||||
|
||||
class Asset(Base):
|
||||
__tablename__ = "assets"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
hash: Mapped[str | None] = mapped_column(String(256), nullable=True)
|
||||
size_bytes: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
|
||||
mime_type: Mapped[str | None] = mapped_column(String(255))
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=False), nullable=False, default=utcnow
|
||||
)
|
||||
|
||||
infos: Mapped[list[AssetInfo]] = relationship(
|
||||
"AssetInfo",
|
||||
back_populates="asset",
|
||||
primaryjoin=lambda: Asset.id == foreign(AssetInfo.asset_id),
|
||||
foreign_keys=lambda: [AssetInfo.asset_id],
|
||||
cascade="all,delete-orphan",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
preview_of: Mapped[list[AssetInfo]] = relationship(
|
||||
"AssetInfo",
|
||||
back_populates="preview_asset",
|
||||
primaryjoin=lambda: Asset.id == foreign(AssetInfo.preview_id),
|
||||
foreign_keys=lambda: [AssetInfo.preview_id],
|
||||
viewonly=True,
|
||||
)
|
||||
|
||||
cache_states: Mapped[list[AssetCacheState]] = relationship(
|
||||
back_populates="asset",
|
||||
cascade="all, delete-orphan",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("uq_assets_hash", "hash", unique=True),
|
||||
Index("ix_assets_mime_type", "mime_type"),
|
||||
CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"),
|
||||
)
|
||||
|
||||
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
|
||||
return to_dict(self, include_none=include_none)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Asset id={self.id} hash={(self.hash or '')[:12]}>"
|
||||
|
||||
|
||||
class AssetCacheState(Base):
|
||||
__tablename__ = "asset_cache_state"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="CASCADE"), nullable=False)
|
||||
file_path: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
mtime_ns: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
|
||||
needs_verify: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
|
||||
asset: Mapped[Asset] = relationship(back_populates="cache_states")
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_asset_cache_state_file_path", "file_path"),
|
||||
Index("ix_asset_cache_state_asset_id", "asset_id"),
|
||||
CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
|
||||
UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"),
|
||||
)
|
||||
|
||||
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
|
||||
return to_dict(self, include_none=include_none)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<AssetCacheState id={self.id} asset_id={self.asset_id} path={self.file_path!r}>"
|
||||
|
||||
|
||||
class AssetInfo(Base):
|
||||
__tablename__ = "assets_info"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
owner_id: Mapped[str] = mapped_column(String(128), nullable=False, default="")
|
||||
name: Mapped[str] = mapped_column(String(512), nullable=False)
|
||||
asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False)
|
||||
preview_id: Mapped[str | None] = mapped_column(String(36), ForeignKey("assets.id", ondelete="SET NULL"))
|
||||
user_metadata: Mapped[dict[str, Any] | None] = mapped_column(JSON(none_as_null=True))
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
|
||||
last_access_time: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
|
||||
|
||||
asset: Mapped[Asset] = relationship(
|
||||
"Asset",
|
||||
back_populates="infos",
|
||||
foreign_keys=[asset_id],
|
||||
lazy="selectin",
|
||||
)
|
||||
preview_asset: Mapped[Asset | None] = relationship(
|
||||
"Asset",
|
||||
back_populates="preview_of",
|
||||
foreign_keys=[preview_id],
|
||||
)
|
||||
|
||||
metadata_entries: Mapped[list[AssetInfoMeta]] = relationship(
|
||||
back_populates="asset_info",
|
||||
cascade="all,delete-orphan",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
tag_links: Mapped[list[AssetInfoTag]] = relationship(
|
||||
back_populates="asset_info",
|
||||
cascade="all,delete-orphan",
|
||||
passive_deletes=True,
|
||||
overlaps="tags,asset_infos",
|
||||
)
|
||||
|
||||
tags: Mapped[list[Tag]] = relationship(
|
||||
secondary="asset_info_tags",
|
||||
back_populates="asset_infos",
|
||||
lazy="selectin",
|
||||
viewonly=True,
|
||||
overlaps="tag_links,asset_info_links,asset_infos,tag",
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"),
|
||||
Index("ix_assets_info_owner_name", "owner_id", "name"),
|
||||
Index("ix_assets_info_owner_id", "owner_id"),
|
||||
Index("ix_assets_info_asset_id", "asset_id"),
|
||||
Index("ix_assets_info_name", "name"),
|
||||
Index("ix_assets_info_created_at", "created_at"),
|
||||
Index("ix_assets_info_last_access_time", "last_access_time"),
|
||||
)
|
||||
|
||||
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
|
||||
data = to_dict(self, include_none=include_none)
|
||||
data["tags"] = [t.name for t in self.tags]
|
||||
return data
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<AssetInfo id={self.id} name={self.name!r} asset_id={self.asset_id}>"
|
||||
|
||||
|
||||
class AssetInfoMeta(Base):
|
||||
__tablename__ = "asset_info_meta"
|
||||
|
||||
asset_info_id: Mapped[str] = mapped_column(
|
||||
String(36), ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
key: Mapped[str] = mapped_column(String(256), primary_key=True)
|
||||
ordinal: Mapped[int] = mapped_column(Integer, primary_key=True, default=0)
|
||||
|
||||
val_str: Mapped[str | None] = mapped_column(String(2048), nullable=True)
|
||||
val_num: Mapped[float | None] = mapped_column(Numeric(38, 10), nullable=True)
|
||||
val_bool: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
|
||||
val_json: Mapped[Any | None] = mapped_column(JSON(none_as_null=True), nullable=True)
|
||||
|
||||
asset_info: Mapped[AssetInfo] = relationship(back_populates="metadata_entries")
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_asset_info_meta_key", "key"),
|
||||
Index("ix_asset_info_meta_key_val_str", "key", "val_str"),
|
||||
Index("ix_asset_info_meta_key_val_num", "key", "val_num"),
|
||||
Index("ix_asset_info_meta_key_val_bool", "key", "val_bool"),
|
||||
)
|
||||
|
||||
|
||||
class AssetInfoTag(Base):
|
||||
__tablename__ = "asset_info_tags"
|
||||
|
||||
asset_info_id: Mapped[str] = mapped_column(
|
||||
String(36), ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
tag_name: Mapped[str] = mapped_column(
|
||||
String(512), ForeignKey("tags.name", ondelete="RESTRICT"), primary_key=True
|
||||
)
|
||||
origin: Mapped[str] = mapped_column(String(32), nullable=False, default="manual")
|
||||
added_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=False), nullable=False, default=utcnow
|
||||
)
|
||||
|
||||
asset_info: Mapped[AssetInfo] = relationship(back_populates="tag_links")
|
||||
tag: Mapped[Tag] = relationship(back_populates="asset_info_links")
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_asset_info_tags_tag_name", "tag_name"),
|
||||
Index("ix_asset_info_tags_asset_info_id", "asset_info_id"),
|
||||
)
|
||||
|
||||
|
||||
class Tag(Base):
|
||||
__tablename__ = "tags"
|
||||
|
||||
name: Mapped[str] = mapped_column(String(512), primary_key=True)
|
||||
tag_type: Mapped[str] = mapped_column(String(32), nullable=False, default="user")
|
||||
|
||||
asset_info_links: Mapped[list[AssetInfoTag]] = relationship(
|
||||
back_populates="tag",
|
||||
overlaps="asset_infos,tags",
|
||||
)
|
||||
asset_infos: Mapped[list[AssetInfo]] = relationship(
|
||||
secondary="asset_info_tags",
|
||||
back_populates="tags",
|
||||
viewonly=True,
|
||||
overlaps="asset_info_links,tag_links,tags,asset_info",
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_tags_tag_type", "tag_type"),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Tag {self.name}>"
|
||||
267
app/assets/database/queries.py
Normal file
267
app/assets/database/queries.py
Normal file
@@ -0,0 +1,267 @@
|
||||
import sqlalchemy as sa
|
||||
from collections import defaultdict
|
||||
from sqlalchemy import select, exists, func
|
||||
from sqlalchemy.orm import Session, contains_eager, noload
|
||||
from app.assets.database.models import Asset, AssetInfo, AssetInfoMeta, AssetInfoTag, Tag
|
||||
from app.assets.helpers import escape_like_prefix, normalize_tags
|
||||
from typing import Sequence
|
||||
|
||||
|
||||
def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
|
||||
"""Build owner visibility predicate for reads. Owner-less rows are visible to everyone."""
|
||||
owner_id = (owner_id or "").strip()
|
||||
if owner_id == "":
|
||||
return AssetInfo.owner_id == ""
|
||||
return AssetInfo.owner_id.in_(["", owner_id])
|
||||
|
||||
|
||||
def apply_tag_filters(
|
||||
stmt: sa.sql.Select,
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
) -> sa.sql.Select:
|
||||
"""include_tags: every tag must be present; exclude_tags: none may be present."""
|
||||
include_tags = normalize_tags(include_tags)
|
||||
exclude_tags = normalize_tags(exclude_tags)
|
||||
|
||||
if include_tags:
|
||||
for tag_name in include_tags:
|
||||
stmt = stmt.where(
|
||||
exists().where(
|
||||
(AssetInfoTag.asset_info_id == AssetInfo.id)
|
||||
& (AssetInfoTag.tag_name == tag_name)
|
||||
)
|
||||
)
|
||||
|
||||
if exclude_tags:
|
||||
stmt = stmt.where(
|
||||
~exists().where(
|
||||
(AssetInfoTag.asset_info_id == AssetInfo.id)
|
||||
& (AssetInfoTag.tag_name.in_(exclude_tags))
|
||||
)
|
||||
)
|
||||
return stmt
|
||||
|
||||
def apply_metadata_filter(
|
||||
stmt: sa.sql.Select,
|
||||
metadata_filter: dict | None = None,
|
||||
) -> sa.sql.Select:
|
||||
"""Apply filters using asset_info_meta projection table."""
|
||||
if not metadata_filter:
|
||||
return stmt
|
||||
|
||||
def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
|
||||
return sa.exists().where(
|
||||
AssetInfoMeta.asset_info_id == AssetInfo.id,
|
||||
AssetInfoMeta.key == key,
|
||||
*preds,
|
||||
)
|
||||
|
||||
def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
|
||||
if value is None:
|
||||
no_row_for_key = sa.not_(
|
||||
sa.exists().where(
|
||||
AssetInfoMeta.asset_info_id == AssetInfo.id,
|
||||
AssetInfoMeta.key == key,
|
||||
)
|
||||
)
|
||||
null_row = _exists_for_pred(
|
||||
key,
|
||||
AssetInfoMeta.val_json.is_(None),
|
||||
AssetInfoMeta.val_str.is_(None),
|
||||
AssetInfoMeta.val_num.is_(None),
|
||||
AssetInfoMeta.val_bool.is_(None),
|
||||
)
|
||||
return sa.or_(no_row_for_key, null_row)
|
||||
|
||||
if isinstance(value, bool):
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_bool == bool(value))
|
||||
if isinstance(value, (int, float)):
|
||||
from decimal import Decimal
|
||||
num = value if isinstance(value, Decimal) else Decimal(str(value))
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_num == num)
|
||||
if isinstance(value, str):
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_str == value)
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_json == value)
|
||||
|
||||
for k, v in metadata_filter.items():
|
||||
if isinstance(v, list):
|
||||
ors = [_exists_clause_for_value(k, elem) for elem in v]
|
||||
if ors:
|
||||
stmt = stmt.where(sa.or_(*ors))
|
||||
else:
|
||||
stmt = stmt.where(_exists_clause_for_value(k, v))
|
||||
return stmt
|
||||
|
||||
|
||||
def asset_exists_by_hash(session: Session, asset_hash: str) -> bool:
|
||||
"""
|
||||
Check if an asset with a given hash exists in database.
|
||||
"""
|
||||
row = (
|
||||
session.execute(
|
||||
select(sa.literal(True)).select_from(Asset).where(Asset.hash == asset_hash).limit(1)
|
||||
)
|
||||
).first()
|
||||
return row is not None
|
||||
|
||||
def get_asset_info_by_id(session: Session, asset_info_id: str) -> AssetInfo | None:
|
||||
return session.get(AssetInfo, asset_info_id)
|
||||
|
||||
def list_asset_infos_page(
|
||||
session: Session,
|
||||
owner_id: str = "",
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
name_contains: str | None = None,
|
||||
metadata_filter: dict | None = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
sort: str = "created_at",
|
||||
order: str = "desc",
|
||||
) -> tuple[list[AssetInfo], dict[str, list[str]], int]:
|
||||
base = (
|
||||
select(AssetInfo)
|
||||
.join(Asset, Asset.id == AssetInfo.asset_id)
|
||||
.options(contains_eager(AssetInfo.asset), noload(AssetInfo.tags))
|
||||
.where(visible_owner_clause(owner_id))
|
||||
)
|
||||
|
||||
if name_contains:
|
||||
escaped, esc = escape_like_prefix(name_contains)
|
||||
base = base.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc))
|
||||
|
||||
base = apply_tag_filters(base, include_tags, exclude_tags)
|
||||
base = apply_metadata_filter(base, metadata_filter)
|
||||
|
||||
sort = (sort or "created_at").lower()
|
||||
order = (order or "desc").lower()
|
||||
sort_map = {
|
||||
"name": AssetInfo.name,
|
||||
"created_at": AssetInfo.created_at,
|
||||
"updated_at": AssetInfo.updated_at,
|
||||
"last_access_time": AssetInfo.last_access_time,
|
||||
"size": Asset.size_bytes,
|
||||
}
|
||||
sort_col = sort_map.get(sort, AssetInfo.created_at)
|
||||
sort_exp = sort_col.desc() if order == "desc" else sort_col.asc()
|
||||
|
||||
base = base.order_by(sort_exp).limit(limit).offset(offset)
|
||||
|
||||
count_stmt = (
|
||||
select(sa.func.count())
|
||||
.select_from(AssetInfo)
|
||||
.join(Asset, Asset.id == AssetInfo.asset_id)
|
||||
.where(visible_owner_clause(owner_id))
|
||||
)
|
||||
if name_contains:
|
||||
escaped, esc = escape_like_prefix(name_contains)
|
||||
count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc))
|
||||
count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags)
|
||||
count_stmt = apply_metadata_filter(count_stmt, metadata_filter)
|
||||
|
||||
total = int((session.execute(count_stmt)).scalar_one() or 0)
|
||||
|
||||
infos = (session.execute(base)).unique().scalars().all()
|
||||
|
||||
id_list: list[str] = [i.id for i in infos]
|
||||
tag_map: dict[str, list[str]] = defaultdict(list)
|
||||
if id_list:
|
||||
rows = session.execute(
|
||||
select(AssetInfoTag.asset_info_id, Tag.name)
|
||||
.join(Tag, Tag.name == AssetInfoTag.tag_name)
|
||||
.where(AssetInfoTag.asset_info_id.in_(id_list))
|
||||
)
|
||||
for aid, tag_name in rows.all():
|
||||
tag_map[aid].append(tag_name)
|
||||
|
||||
return infos, tag_map, total
|
||||
|
||||
def fetch_asset_info_asset_and_tags(
|
||||
session: Session,
|
||||
asset_info_id: str,
|
||||
owner_id: str = "",
|
||||
) -> tuple[AssetInfo, Asset, list[str]] | None:
|
||||
stmt = (
|
||||
select(AssetInfo, Asset, Tag.name)
|
||||
.join(Asset, Asset.id == AssetInfo.asset_id)
|
||||
.join(AssetInfoTag, AssetInfoTag.asset_info_id == AssetInfo.id, isouter=True)
|
||||
.join(Tag, Tag.name == AssetInfoTag.tag_name, isouter=True)
|
||||
.where(
|
||||
AssetInfo.id == asset_info_id,
|
||||
visible_owner_clause(owner_id),
|
||||
)
|
||||
.options(noload(AssetInfo.tags))
|
||||
.order_by(Tag.name.asc())
|
||||
)
|
||||
|
||||
rows = (session.execute(stmt)).all()
|
||||
if not rows:
|
||||
return None
|
||||
|
||||
first_info, first_asset, _ = rows[0]
|
||||
tags: list[str] = []
|
||||
seen: set[str] = set()
|
||||
for _info, _asset, tag_name in rows:
|
||||
if tag_name and tag_name not in seen:
|
||||
seen.add(tag_name)
|
||||
tags.append(tag_name)
|
||||
return first_info, first_asset, tags
|
||||
|
||||
def list_tags_with_usage(
|
||||
session: Session,
|
||||
prefix: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
include_zero: bool = True,
|
||||
order: str = "count_desc",
|
||||
owner_id: str = "",
|
||||
) -> tuple[list[tuple[str, str, int]], int]:
|
||||
counts_sq = (
|
||||
select(
|
||||
AssetInfoTag.tag_name.label("tag_name"),
|
||||
func.count(AssetInfoTag.asset_info_id).label("cnt"),
|
||||
)
|
||||
.select_from(AssetInfoTag)
|
||||
.join(AssetInfo, AssetInfo.id == AssetInfoTag.asset_info_id)
|
||||
.where(visible_owner_clause(owner_id))
|
||||
.group_by(AssetInfoTag.tag_name)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
q = (
|
||||
select(
|
||||
Tag.name,
|
||||
Tag.tag_type,
|
||||
func.coalesce(counts_sq.c.cnt, 0).label("count"),
|
||||
)
|
||||
.select_from(Tag)
|
||||
.join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True)
|
||||
)
|
||||
|
||||
if prefix:
|
||||
escaped, esc = escape_like_prefix(prefix.strip().lower())
|
||||
q = q.where(Tag.name.like(escaped + "%", escape=esc))
|
||||
|
||||
if not include_zero:
|
||||
q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0)
|
||||
|
||||
if order == "name_asc":
|
||||
q = q.order_by(Tag.name.asc())
|
||||
else:
|
||||
q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc())
|
||||
|
||||
total_q = select(func.count()).select_from(Tag)
|
||||
if prefix:
|
||||
escaped, esc = escape_like_prefix(prefix.strip().lower())
|
||||
total_q = total_q.where(Tag.name.like(escaped + "%", escape=esc))
|
||||
if not include_zero:
|
||||
total_q = total_q.where(
|
||||
Tag.name.in_(select(AssetInfoTag.tag_name).group_by(AssetInfoTag.tag_name))
|
||||
)
|
||||
|
||||
rows = (session.execute(q.limit(limit).offset(offset))).all()
|
||||
total = (session.execute(total_q)).scalar_one()
|
||||
|
||||
rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
|
||||
return rows_norm, int(total or 0)
|
||||
62
app/assets/database/tags.py
Normal file
62
app/assets/database/tags.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from typing import Iterable
|
||||
|
||||
import sqlalchemy
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.dialects import sqlite
|
||||
|
||||
from app.assets.helpers import normalize_tags, utcnow
|
||||
from app.assets.database.models import Tag, AssetInfoTag, AssetInfo
|
||||
|
||||
|
||||
def ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "user") -> None:
|
||||
wanted = normalize_tags(list(names))
|
||||
if not wanted:
|
||||
return
|
||||
rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
|
||||
ins = (
|
||||
sqlite.insert(Tag)
|
||||
.values(rows)
|
||||
.on_conflict_do_nothing(index_elements=[Tag.name])
|
||||
)
|
||||
return session.execute(ins)
|
||||
|
||||
def add_missing_tag_for_asset_id(
|
||||
session: Session,
|
||||
*,
|
||||
asset_id: str,
|
||||
origin: str = "automatic",
|
||||
) -> None:
|
||||
select_rows = (
|
||||
sqlalchemy.select(
|
||||
AssetInfo.id.label("asset_info_id"),
|
||||
sqlalchemy.literal("missing").label("tag_name"),
|
||||
sqlalchemy.literal(origin).label("origin"),
|
||||
sqlalchemy.literal(utcnow()).label("added_at"),
|
||||
)
|
||||
.where(AssetInfo.asset_id == asset_id)
|
||||
.where(
|
||||
sqlalchemy.not_(
|
||||
sqlalchemy.exists().where((AssetInfoTag.asset_info_id == AssetInfo.id) & (AssetInfoTag.tag_name == "missing"))
|
||||
)
|
||||
)
|
||||
)
|
||||
session.execute(
|
||||
sqlite.insert(AssetInfoTag)
|
||||
.from_select(
|
||||
["asset_info_id", "tag_name", "origin", "added_at"],
|
||||
select_rows,
|
||||
)
|
||||
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
|
||||
)
|
||||
|
||||
def remove_missing_tag_for_asset_id(
|
||||
session: Session,
|
||||
*,
|
||||
asset_id: str,
|
||||
) -> None:
|
||||
session.execute(
|
||||
sqlalchemy.delete(AssetInfoTag).where(
|
||||
AssetInfoTag.asset_info_id.in_(sqlalchemy.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)),
|
||||
AssetInfoTag.tag_name == "missing",
|
||||
)
|
||||
)
|
||||
75
app/assets/hashing.py
Normal file
75
app/assets/hashing.py
Normal file
@@ -0,0 +1,75 @@
|
||||
from blake3 import blake3
|
||||
from typing import IO
|
||||
import os
|
||||
import asyncio
|
||||
|
||||
|
||||
DEFAULT_CHUNK = 8 * 1024 *1024 # 8MB
|
||||
|
||||
# NOTE: this allows hashing different representations of a file-like object
|
||||
def blake3_hash(
|
||||
fp: str | IO[bytes],
|
||||
chunk_size: int = DEFAULT_CHUNK,
|
||||
) -> str:
|
||||
"""
|
||||
Returns a BLAKE3 hex digest for ``fp``, which may be:
|
||||
- a filename (str/bytes) or PathLike
|
||||
- an open binary file object
|
||||
If ``fp`` is a file object, it must be opened in **binary** mode and support
|
||||
``read``, ``seek``, and ``tell``. The function will seek to the start before
|
||||
reading and will attempt to restore the original position afterward.
|
||||
"""
|
||||
# duck typing to check if input is a file-like object
|
||||
if hasattr(fp, "read"):
|
||||
return _hash_file_obj(fp, chunk_size)
|
||||
|
||||
with open(os.fspath(fp), "rb") as f:
|
||||
return _hash_file_obj(f, chunk_size)
|
||||
|
||||
|
||||
async def blake3_hash_async(
|
||||
fp: str | IO[bytes],
|
||||
chunk_size: int = DEFAULT_CHUNK,
|
||||
) -> str:
|
||||
"""Async wrapper for ``blake3_hash_sync``.
|
||||
Uses a worker thread so the event loop remains responsive.
|
||||
"""
|
||||
# If it is a path, open inside the worker thread to keep I/O off the loop.
|
||||
if hasattr(fp, "read"):
|
||||
return await asyncio.to_thread(blake3_hash, fp, chunk_size)
|
||||
|
||||
def _worker() -> str:
|
||||
with open(os.fspath(fp), "rb") as f:
|
||||
return _hash_file_obj(f, chunk_size)
|
||||
|
||||
return await asyncio.to_thread(_worker)
|
||||
|
||||
|
||||
def _hash_file_obj(file_obj: IO, chunk_size: int = DEFAULT_CHUNK) -> str:
|
||||
"""
|
||||
Hash an already-open binary file object by streaming in chunks.
|
||||
- Seeks to the beginning before reading (if supported).
|
||||
- Restores the original position afterward (if tell/seek are supported).
|
||||
"""
|
||||
if chunk_size <= 0:
|
||||
chunk_size = DEFAULT_CHUNK
|
||||
|
||||
# in case file object is already open and not at the beginning, track so can be restored after hashing
|
||||
orig_pos = file_obj.tell()
|
||||
|
||||
try:
|
||||
# seek to the beginning before reading
|
||||
if orig_pos != 0:
|
||||
file_obj.seek(0)
|
||||
|
||||
h = blake3()
|
||||
while True:
|
||||
chunk = file_obj.read(chunk_size)
|
||||
if not chunk:
|
||||
break
|
||||
h.update(chunk)
|
||||
return h.hexdigest()
|
||||
finally:
|
||||
# restore original position in file object, if needed
|
||||
if orig_pos != 0:
|
||||
file_obj.seek(orig_pos)
|
||||
217
app/assets/helpers.py
Normal file
217
app/assets/helpers.py
Normal file
@@ -0,0 +1,217 @@
|
||||
import contextlib
|
||||
import os
|
||||
from aiohttp import web
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Literal, Any
|
||||
|
||||
import folder_paths
|
||||
|
||||
|
||||
RootType = Literal["models", "input", "output"]
|
||||
ALLOWED_ROOTS: tuple[RootType, ...] = ("models", "input", "output")
|
||||
|
||||
def get_query_dict(request: web.Request) -> dict[str, Any]:
|
||||
"""
|
||||
Gets a dictionary of query parameters from the request.
|
||||
|
||||
'request.query' is a MultiMapping[str], needs to be converted to a dictionary to be validated by Pydantic.
|
||||
"""
|
||||
query_dict = {
|
||||
key: request.query.getall(key) if len(request.query.getall(key)) > 1 else request.query.get(key)
|
||||
for key in request.query.keys()
|
||||
}
|
||||
return query_dict
|
||||
|
||||
def list_tree(base_dir: str) -> list[str]:
|
||||
out: list[str] = []
|
||||
base_abs = os.path.abspath(base_dir)
|
||||
if not os.path.isdir(base_abs):
|
||||
return out
|
||||
for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False):
|
||||
for name in filenames:
|
||||
out.append(os.path.abspath(os.path.join(dirpath, name)))
|
||||
return out
|
||||
|
||||
def prefixes_for_root(root: RootType) -> list[str]:
|
||||
if root == "models":
|
||||
bases: list[str] = []
|
||||
for _bucket, paths in get_comfy_models_folders():
|
||||
bases.extend(paths)
|
||||
return [os.path.abspath(p) for p in bases]
|
||||
if root == "input":
|
||||
return [os.path.abspath(folder_paths.get_input_directory())]
|
||||
if root == "output":
|
||||
return [os.path.abspath(folder_paths.get_output_directory())]
|
||||
return []
|
||||
|
||||
def escape_like_prefix(s: str, escape: str = "!") -> tuple[str, str]:
|
||||
"""Escapes %, _ and the escape char itself in a LIKE prefix.
|
||||
Returns (escaped_prefix, escape_char). Caller should append '%' and pass escape=escape_char to .like().
|
||||
"""
|
||||
s = s.replace(escape, escape + escape) # escape the escape char first
|
||||
s = s.replace("%", escape + "%").replace("_", escape + "_") # escape LIKE wildcards
|
||||
return s, escape
|
||||
|
||||
def fast_asset_file_check(
|
||||
*,
|
||||
mtime_db: int | None,
|
||||
size_db: int | None,
|
||||
stat_result: os.stat_result,
|
||||
) -> bool:
|
||||
if mtime_db is None:
|
||||
return False
|
||||
actual_mtime_ns = getattr(stat_result, "st_mtime_ns", int(stat_result.st_mtime * 1_000_000_000))
|
||||
if int(mtime_db) != int(actual_mtime_ns):
|
||||
return False
|
||||
sz = int(size_db or 0)
|
||||
if sz > 0:
|
||||
return int(stat_result.st_size) == sz
|
||||
return True
|
||||
|
||||
def utcnow() -> datetime:
|
||||
"""Naive UTC timestamp (no tzinfo). We always treat DB datetimes as UTC."""
|
||||
return datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
|
||||
def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
|
||||
"""Build a list of (folder_name, base_paths[]) categories that are configured for model locations.
|
||||
|
||||
We trust `folder_paths.folder_names_and_paths` and include a category if
|
||||
*any* of its base paths lies under the Comfy `models_dir`.
|
||||
"""
|
||||
targets: list[tuple[str, list[str]]] = []
|
||||
models_root = os.path.abspath(folder_paths.models_dir)
|
||||
for name, values in folder_paths.folder_names_and_paths.items():
|
||||
paths, _exts = values[0], values[1] # NOTE: this prevents nodepacks that hackily edit folder_... from breaking ComfyUI
|
||||
if any(os.path.abspath(p).startswith(models_root + os.sep) for p in paths):
|
||||
targets.append((name, paths))
|
||||
return targets
|
||||
|
||||
def compute_relative_filename(file_path: str) -> str | None:
|
||||
"""
|
||||
Return the model's path relative to the last well-known folder (the model category),
|
||||
using forward slashes, eg:
|
||||
/.../models/checkpoints/flux/123/flux.safetensors -> "flux/123/flux.safetensors"
|
||||
/.../models/text_encoders/clip_g.safetensors -> "clip_g.safetensors"
|
||||
|
||||
For non-model paths, returns None.
|
||||
NOTE: this is a temporary helper, used only for initializing metadata["filename"] field.
|
||||
"""
|
||||
try:
|
||||
root_category, rel_path = get_relative_to_root_category_path_of_asset(file_path)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
p = Path(rel_path)
|
||||
parts = [seg for seg in p.parts if seg not in (".", "..", p.anchor)]
|
||||
if not parts:
|
||||
return None
|
||||
|
||||
if root_category == "models":
|
||||
# parts[0] is the category ("checkpoints", "vae", etc) – drop it
|
||||
inside = parts[1:] if len(parts) > 1 else [parts[0]]
|
||||
return "/".join(inside)
|
||||
return "/".join(parts) # input/output: keep all parts
|
||||
|
||||
|
||||
def get_relative_to_root_category_path_of_asset(file_path: str) -> tuple[Literal["input", "output", "models"], str]:
|
||||
"""Given an absolute or relative file path, determine which root category the path belongs to:
|
||||
- 'input' if the file resides under `folder_paths.get_input_directory()`
|
||||
- 'output' if the file resides under `folder_paths.get_output_directory()`
|
||||
- 'models' if the file resides under any base path of categories returned by `get_comfy_models_folders()`
|
||||
|
||||
Returns:
|
||||
(root_category, relative_path_inside_that_root)
|
||||
For 'models', the relative path is prefixed with the category name:
|
||||
e.g. ('models', 'vae/test/sub/ae.safetensors')
|
||||
|
||||
Raises:
|
||||
ValueError: if the path does not belong to input, output, or configured model bases.
|
||||
"""
|
||||
fp_abs = os.path.abspath(file_path)
|
||||
|
||||
def _is_within(child: str, parent: str) -> bool:
|
||||
try:
|
||||
return os.path.commonpath([child, parent]) == parent
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _rel(child: str, parent: str) -> str:
|
||||
return os.path.relpath(os.path.join(os.sep, os.path.relpath(child, parent)), os.sep)
|
||||
|
||||
# 1) input
|
||||
input_base = os.path.abspath(folder_paths.get_input_directory())
|
||||
if _is_within(fp_abs, input_base):
|
||||
return "input", _rel(fp_abs, input_base)
|
||||
|
||||
# 2) output
|
||||
output_base = os.path.abspath(folder_paths.get_output_directory())
|
||||
if _is_within(fp_abs, output_base):
|
||||
return "output", _rel(fp_abs, output_base)
|
||||
|
||||
# 3) models (check deepest matching base to avoid ambiguity)
|
||||
best: tuple[int, str, str] | None = None # (base_len, bucket, rel_inside_bucket)
|
||||
for bucket, bases in get_comfy_models_folders():
|
||||
for b in bases:
|
||||
base_abs = os.path.abspath(b)
|
||||
if not _is_within(fp_abs, base_abs):
|
||||
continue
|
||||
cand = (len(base_abs), bucket, _rel(fp_abs, base_abs))
|
||||
if best is None or cand[0] > best[0]:
|
||||
best = cand
|
||||
|
||||
if best is not None:
|
||||
_, bucket, rel_inside = best
|
||||
combined = os.path.join(bucket, rel_inside)
|
||||
return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep)
|
||||
|
||||
raise ValueError(f"Path is not within input, output, or configured model bases: {file_path}")
|
||||
|
||||
def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]:
|
||||
"""Return a tuple (name, tags) derived from a filesystem path.
|
||||
|
||||
Semantics:
|
||||
- Root category is determined by `get_relative_to_root_category_path_of_asset`.
|
||||
- The returned `name` is the base filename with extension from the relative path.
|
||||
- The returned `tags` are:
|
||||
[root_category] + parent folders of the relative path (in order)
|
||||
For 'models', this means:
|
||||
file '/.../ModelsDir/vae/test_tag/ae.safetensors'
|
||||
-> root_category='models', some_path='vae/test_tag/ae.safetensors'
|
||||
-> name='ae.safetensors', tags=['models', 'vae', 'test_tag']
|
||||
|
||||
Raises:
|
||||
ValueError: if the path does not belong to input, output, or configured model bases.
|
||||
"""
|
||||
root_category, some_path = get_relative_to_root_category_path_of_asset(file_path)
|
||||
p = Path(some_path)
|
||||
parent_parts = [part for part in p.parent.parts if part not in (".", "..", p.anchor)]
|
||||
return p.name, list(dict.fromkeys(normalize_tags([root_category, *parent_parts])))
|
||||
|
||||
def normalize_tags(tags: list[str] | None) -> list[str]:
|
||||
"""
|
||||
Normalize a list of tags by:
|
||||
- Stripping whitespace and converting to lowercase.
|
||||
- Removing duplicates.
|
||||
"""
|
||||
return [t.strip().lower() for t in (tags or []) if (t or "").strip()]
|
||||
|
||||
def collect_models_files() -> list[str]:
|
||||
out: list[str] = []
|
||||
for folder_name, bases in get_comfy_models_folders():
|
||||
rel_files = folder_paths.get_filename_list(folder_name) or []
|
||||
for rel_path in rel_files:
|
||||
abs_path = folder_paths.get_full_path(folder_name, rel_path)
|
||||
if not abs_path:
|
||||
continue
|
||||
abs_path = os.path.abspath(abs_path)
|
||||
allowed = False
|
||||
for b in bases:
|
||||
base_abs = os.path.abspath(b)
|
||||
with contextlib.suppress(Exception):
|
||||
if os.path.commonpath([abs_path, base_abs]) == base_abs:
|
||||
allowed = True
|
||||
break
|
||||
if allowed:
|
||||
out.append(abs_path)
|
||||
return out
|
||||
123
app/assets/manager.py
Normal file
123
app/assets/manager.py
Normal file
@@ -0,0 +1,123 @@
|
||||
from typing import Sequence
|
||||
|
||||
from app.database.db import create_session
|
||||
from app.assets.api import schemas_out
|
||||
from app.assets.database.queries import (
|
||||
asset_exists_by_hash,
|
||||
fetch_asset_info_asset_and_tags,
|
||||
list_asset_infos_page,
|
||||
list_tags_with_usage,
|
||||
)
|
||||
|
||||
|
||||
def _safe_sort_field(requested: str | None) -> str:
|
||||
if not requested:
|
||||
return "created_at"
|
||||
v = requested.lower()
|
||||
if v in {"name", "created_at", "updated_at", "size", "last_access_time"}:
|
||||
return v
|
||||
return "created_at"
|
||||
|
||||
|
||||
def asset_exists(asset_hash: str) -> bool:
|
||||
with create_session() as session:
|
||||
return asset_exists_by_hash(session, asset_hash=asset_hash)
|
||||
|
||||
def list_assets(
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
name_contains: str | None = None,
|
||||
metadata_filter: dict | None = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
sort: str = "created_at",
|
||||
order: str = "desc",
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.AssetsList:
|
||||
sort = _safe_sort_field(sort)
|
||||
order = "desc" if (order or "desc").lower() not in {"asc", "desc"} else order.lower()
|
||||
|
||||
with create_session() as session:
|
||||
infos, tag_map, total = list_asset_infos_page(
|
||||
session,
|
||||
owner_id=owner_id,
|
||||
include_tags=include_tags,
|
||||
exclude_tags=exclude_tags,
|
||||
name_contains=name_contains,
|
||||
metadata_filter=metadata_filter,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
sort=sort,
|
||||
order=order,
|
||||
)
|
||||
|
||||
summaries: list[schemas_out.AssetSummary] = []
|
||||
for info in infos:
|
||||
asset = info.asset
|
||||
tags = tag_map.get(info.id, [])
|
||||
summaries.append(
|
||||
schemas_out.AssetSummary(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=asset.hash if asset else None,
|
||||
size=int(asset.size_bytes) if asset else None,
|
||||
mime_type=asset.mime_type if asset else None,
|
||||
tags=tags,
|
||||
preview_url=f"/api/assets/{info.id}/content",
|
||||
created_at=info.created_at,
|
||||
updated_at=info.updated_at,
|
||||
last_access_time=info.last_access_time,
|
||||
)
|
||||
)
|
||||
|
||||
return schemas_out.AssetsList(
|
||||
assets=summaries,
|
||||
total=total,
|
||||
has_more=(offset + len(summaries)) < total,
|
||||
)
|
||||
|
||||
def get_asset(asset_info_id: str, owner_id: str = "") -> schemas_out.AssetDetail:
|
||||
with create_session() as session:
|
||||
res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||
if not res:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
info, asset, tag_names = res
|
||||
preview_id = info.preview_id
|
||||
|
||||
return schemas_out.AssetDetail(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=asset.hash if asset else None,
|
||||
size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None,
|
||||
mime_type=asset.mime_type if asset else None,
|
||||
tags=tag_names,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_id=preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
)
|
||||
|
||||
def list_tags(
|
||||
prefix: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
order: str = "count_desc",
|
||||
include_zero: bool = True,
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.TagsList:
|
||||
limit = max(1, min(1000, limit))
|
||||
offset = max(0, offset)
|
||||
|
||||
with create_session() as session:
|
||||
rows, total = list_tags_with_usage(
|
||||
session,
|
||||
prefix=prefix,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
include_zero=include_zero,
|
||||
order=order,
|
||||
owner_id=owner_id,
|
||||
)
|
||||
|
||||
tags = [schemas_out.TagUsage(name=name, count=count, type=tag_type) for (name, tag_type, count) in rows]
|
||||
return schemas_out.TagsList(tags=tags, total=total, has_more=(offset + len(tags)) < total)
|
||||
229
app/assets/scanner.py
Normal file
229
app/assets/scanner.py
Normal file
@@ -0,0 +1,229 @@
|
||||
import contextlib
|
||||
import time
|
||||
import logging
|
||||
import os
|
||||
import sqlalchemy
|
||||
|
||||
import folder_paths
|
||||
from app.database.db import create_session, dependencies_available
|
||||
from app.assets.helpers import (
|
||||
collect_models_files, compute_relative_filename, fast_asset_file_check, get_name_and_tags_from_asset_path,
|
||||
list_tree,prefixes_for_root, escape_like_prefix,
|
||||
RootType
|
||||
)
|
||||
from app.assets.database.tags import add_missing_tag_for_asset_id, ensure_tags_exist, remove_missing_tag_for_asset_id
|
||||
from app.assets.database.bulk_ops import seed_from_paths_batch
|
||||
from app.assets.database.models import Asset, AssetCacheState, AssetInfo
|
||||
|
||||
|
||||
def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> None:
|
||||
"""
|
||||
Scan the given roots and seed the assets into the database.
|
||||
"""
|
||||
if not dependencies_available():
|
||||
if enable_logging:
|
||||
logging.warning("Database dependencies not available, skipping assets scan")
|
||||
return
|
||||
t_start = time.perf_counter()
|
||||
created = 0
|
||||
skipped_existing = 0
|
||||
paths: list[str] = []
|
||||
try:
|
||||
existing_paths: set[str] = set()
|
||||
for r in roots:
|
||||
try:
|
||||
survivors: set[str] = _fast_db_consistency_pass(r, collect_existing_paths=True, update_missing_tags=True)
|
||||
if survivors:
|
||||
existing_paths.update(survivors)
|
||||
except Exception as e:
|
||||
logging.exception("fast DB scan failed for %s: %s", r, e)
|
||||
|
||||
if "models" in roots:
|
||||
paths.extend(collect_models_files())
|
||||
if "input" in roots:
|
||||
paths.extend(list_tree(folder_paths.get_input_directory()))
|
||||
if "output" in roots:
|
||||
paths.extend(list_tree(folder_paths.get_output_directory()))
|
||||
|
||||
specs: list[dict] = []
|
||||
tag_pool: set[str] = set()
|
||||
for p in paths:
|
||||
abs_p = os.path.abspath(p)
|
||||
if abs_p in existing_paths:
|
||||
skipped_existing += 1
|
||||
continue
|
||||
try:
|
||||
stat_p = os.stat(abs_p, follow_symlinks=False)
|
||||
except OSError:
|
||||
continue
|
||||
# skip empty files
|
||||
if not stat_p.st_size:
|
||||
continue
|
||||
name, tags = get_name_and_tags_from_asset_path(abs_p)
|
||||
specs.append(
|
||||
{
|
||||
"abs_path": abs_p,
|
||||
"size_bytes": stat_p.st_size,
|
||||
"mtime_ns": getattr(stat_p, "st_mtime_ns", int(stat_p.st_mtime * 1_000_000_000)),
|
||||
"info_name": name,
|
||||
"tags": tags,
|
||||
"fname": compute_relative_filename(abs_p),
|
||||
}
|
||||
)
|
||||
for t in tags:
|
||||
tag_pool.add(t)
|
||||
# if no file specs, nothing to do
|
||||
if not specs:
|
||||
return
|
||||
with create_session() as sess:
|
||||
if tag_pool:
|
||||
ensure_tags_exist(sess, tag_pool, tag_type="user")
|
||||
|
||||
result = seed_from_paths_batch(sess, specs=specs, owner_id="")
|
||||
created += result["inserted_infos"]
|
||||
sess.commit()
|
||||
finally:
|
||||
if enable_logging:
|
||||
logging.info(
|
||||
"Assets scan(roots=%s) completed in %.3fs (created=%d, skipped_existing=%d, total_seen=%d)",
|
||||
roots,
|
||||
time.perf_counter() - t_start,
|
||||
created,
|
||||
skipped_existing,
|
||||
len(paths),
|
||||
)
|
||||
|
||||
|
||||
def _fast_db_consistency_pass(
|
||||
root: RootType,
|
||||
*,
|
||||
collect_existing_paths: bool = False,
|
||||
update_missing_tags: bool = False,
|
||||
) -> set[str] | None:
|
||||
"""Fast DB+FS pass for a root:
|
||||
- Toggle needs_verify per state using fast check
|
||||
- For hashed assets with at least one fast-ok state in this root: delete stale missing states
|
||||
- For seed assets with all states missing: delete Asset and its AssetInfos
|
||||
- Optionally add/remove 'missing' tags based on fast-ok in this root
|
||||
- Optionally return surviving absolute paths
|
||||
"""
|
||||
prefixes = prefixes_for_root(root)
|
||||
if not prefixes:
|
||||
return set() if collect_existing_paths else None
|
||||
|
||||
conds = []
|
||||
for p in prefixes:
|
||||
base = os.path.abspath(p)
|
||||
if not base.endswith(os.sep):
|
||||
base += os.sep
|
||||
escaped, esc = escape_like_prefix(base)
|
||||
conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc))
|
||||
|
||||
with create_session() as sess:
|
||||
rows = (
|
||||
sess.execute(
|
||||
sqlalchemy.select(
|
||||
AssetCacheState.id,
|
||||
AssetCacheState.file_path,
|
||||
AssetCacheState.mtime_ns,
|
||||
AssetCacheState.needs_verify,
|
||||
AssetCacheState.asset_id,
|
||||
Asset.hash,
|
||||
Asset.size_bytes,
|
||||
)
|
||||
.join(Asset, Asset.id == AssetCacheState.asset_id)
|
||||
.where(sqlalchemy.or_(*conds))
|
||||
.order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc())
|
||||
)
|
||||
).all()
|
||||
|
||||
by_asset: dict[str, dict] = {}
|
||||
for sid, fp, mtime_db, needs_verify, aid, a_hash, a_size in rows:
|
||||
acc = by_asset.get(aid)
|
||||
if acc is None:
|
||||
acc = {"hash": a_hash, "size_db": int(a_size or 0), "states": []}
|
||||
by_asset[aid] = acc
|
||||
|
||||
fast_ok = False
|
||||
try:
|
||||
exists = True
|
||||
fast_ok = fast_asset_file_check(
|
||||
mtime_db=mtime_db,
|
||||
size_db=acc["size_db"],
|
||||
stat_result=os.stat(fp, follow_symlinks=True),
|
||||
)
|
||||
except FileNotFoundError:
|
||||
exists = False
|
||||
except OSError:
|
||||
exists = False
|
||||
|
||||
acc["states"].append({
|
||||
"sid": sid,
|
||||
"fp": fp,
|
||||
"exists": exists,
|
||||
"fast_ok": fast_ok,
|
||||
"needs_verify": bool(needs_verify),
|
||||
})
|
||||
|
||||
to_set_verify: list[int] = []
|
||||
to_clear_verify: list[int] = []
|
||||
stale_state_ids: list[int] = []
|
||||
survivors: set[str] = set()
|
||||
|
||||
for aid, acc in by_asset.items():
|
||||
a_hash = acc["hash"]
|
||||
states = acc["states"]
|
||||
any_fast_ok = any(s["fast_ok"] for s in states)
|
||||
all_missing = all(not s["exists"] for s in states)
|
||||
|
||||
for s in states:
|
||||
if not s["exists"]:
|
||||
continue
|
||||
if s["fast_ok"] and s["needs_verify"]:
|
||||
to_clear_verify.append(s["sid"])
|
||||
if not s["fast_ok"] and not s["needs_verify"]:
|
||||
to_set_verify.append(s["sid"])
|
||||
|
||||
if a_hash is None:
|
||||
if states and all_missing: # remove seed Asset completely, if no valid AssetCache exists
|
||||
sess.execute(sqlalchemy.delete(AssetInfo).where(AssetInfo.asset_id == aid))
|
||||
asset = sess.get(Asset, aid)
|
||||
if asset:
|
||||
sess.delete(asset)
|
||||
else:
|
||||
for s in states:
|
||||
if s["exists"]:
|
||||
survivors.add(os.path.abspath(s["fp"]))
|
||||
continue
|
||||
|
||||
if any_fast_ok: # if Asset has at least one valid AssetCache record, remove any invalid AssetCache records
|
||||
for s in states:
|
||||
if not s["exists"]:
|
||||
stale_state_ids.append(s["sid"])
|
||||
if update_missing_tags:
|
||||
with contextlib.suppress(Exception):
|
||||
remove_missing_tag_for_asset_id(sess, asset_id=aid)
|
||||
elif update_missing_tags:
|
||||
with contextlib.suppress(Exception):
|
||||
add_missing_tag_for_asset_id(sess, asset_id=aid, origin="automatic")
|
||||
|
||||
for s in states:
|
||||
if s["exists"]:
|
||||
survivors.add(os.path.abspath(s["fp"]))
|
||||
|
||||
if stale_state_ids:
|
||||
sess.execute(sqlalchemy.delete(AssetCacheState).where(AssetCacheState.id.in_(stale_state_ids)))
|
||||
if to_set_verify:
|
||||
sess.execute(
|
||||
sqlalchemy.update(AssetCacheState)
|
||||
.where(AssetCacheState.id.in_(to_set_verify))
|
||||
.values(needs_verify=True)
|
||||
)
|
||||
if to_clear_verify:
|
||||
sess.execute(
|
||||
sqlalchemy.update(AssetCacheState)
|
||||
.where(AssetCacheState.id.in_(to_clear_verify))
|
||||
.values(needs_verify=False)
|
||||
)
|
||||
sess.commit()
|
||||
return survivors if collect_existing_paths else None
|
||||
@@ -1,14 +1,21 @@
|
||||
from sqlalchemy.orm import declarative_base
|
||||
from typing import Any
|
||||
from datetime import datetime
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
Base = declarative_base()
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
|
||||
def to_dict(obj):
|
||||
def to_dict(obj: Any, include_none: bool = False) -> dict[str, Any]:
|
||||
fields = obj.__table__.columns.keys()
|
||||
return {
|
||||
field: (val.to_dict() if hasattr(val, "to_dict") else val)
|
||||
for field in fields
|
||||
if (val := getattr(obj, field))
|
||||
}
|
||||
out: dict[str, Any] = {}
|
||||
for field in fields:
|
||||
val = getattr(obj, field)
|
||||
if val is None and not include_none:
|
||||
continue
|
||||
if isinstance(val, datetime):
|
||||
out[field] = val.isoformat()
|
||||
else:
|
||||
out[field] = val
|
||||
return out
|
||||
|
||||
# TODO: Define models here
|
||||
|
||||
@@ -44,7 +44,7 @@ class ModelFileManager:
|
||||
@routes.get("/experiment/models/{folder}")
|
||||
async def get_all_models(request):
|
||||
folder = request.match_info.get("folder", None)
|
||||
if not folder in folder_paths.folder_names_and_paths:
|
||||
if folder not in folder_paths.folder_names_and_paths:
|
||||
return web.Response(status=404)
|
||||
files = self.get_model_file_list(folder)
|
||||
return web.json_response(files)
|
||||
@@ -55,7 +55,7 @@ class ModelFileManager:
|
||||
path_index = int(request.match_info.get("path_index", None))
|
||||
filename = request.match_info.get("filename", None)
|
||||
|
||||
if not folder_name in folder_paths.folder_names_and_paths:
|
||||
if folder_name not in folder_paths.folder_names_and_paths:
|
||||
return web.Response(status=404)
|
||||
|
||||
folders = folder_paths.folder_names_and_paths[folder_name]
|
||||
|
||||
@@ -10,6 +10,7 @@ import hashlib
|
||||
|
||||
class Source:
|
||||
custom_node = "custom_node"
|
||||
templates = "templates"
|
||||
|
||||
class SubgraphEntry(TypedDict):
|
||||
source: str
|
||||
@@ -38,6 +39,18 @@ class CustomNodeSubgraphEntryInfo(TypedDict):
|
||||
class SubgraphManager:
|
||||
def __init__(self):
|
||||
self.cached_custom_node_subgraphs: dict[SubgraphEntry] | None = None
|
||||
self.cached_blueprint_subgraphs: dict[SubgraphEntry] | None = None
|
||||
|
||||
def _create_entry(self, file: str, source: str, node_pack: str) -> tuple[str, SubgraphEntry]:
|
||||
"""Create a subgraph entry from a file path. Expects normalized path (forward slashes)."""
|
||||
entry_id = hashlib.sha256(f"{source}{file}".encode()).hexdigest()
|
||||
entry: SubgraphEntry = {
|
||||
"source": source,
|
||||
"name": os.path.splitext(os.path.basename(file))[0],
|
||||
"path": file,
|
||||
"info": {"node_pack": node_pack},
|
||||
}
|
||||
return entry_id, entry
|
||||
|
||||
async def load_entry_data(self, entry: SubgraphEntry):
|
||||
with open(entry['path'], 'r') as f:
|
||||
@@ -60,53 +73,60 @@ class SubgraphManager:
|
||||
return entries
|
||||
|
||||
async def get_custom_node_subgraphs(self, loadedModules, force_reload=False):
|
||||
# if not forced to reload and cached, return cache
|
||||
"""Load subgraphs from custom nodes."""
|
||||
if not force_reload and self.cached_custom_node_subgraphs is not None:
|
||||
return self.cached_custom_node_subgraphs
|
||||
# Load subgraphs from custom nodes
|
||||
subfolder = "subgraphs"
|
||||
subgraphs_dict: dict[SubgraphEntry] = {}
|
||||
|
||||
subgraphs_dict: dict[SubgraphEntry] = {}
|
||||
for folder in folder_paths.get_folder_paths("custom_nodes"):
|
||||
pattern = os.path.join(folder, f"*/{subfolder}/*.json")
|
||||
matched_files = glob.glob(pattern)
|
||||
for file in matched_files:
|
||||
# replace backslashes with forward slashes
|
||||
pattern = os.path.join(folder, "*/subgraphs/*.json")
|
||||
for file in glob.glob(pattern):
|
||||
file = file.replace('\\', '/')
|
||||
info: CustomNodeSubgraphEntryInfo = {
|
||||
"node_pack": "custom_nodes." + file.split('/')[-3]
|
||||
}
|
||||
source = Source.custom_node
|
||||
# hash source + path to make sure id will be as unique as possible, but
|
||||
# reproducible across backend reloads
|
||||
id = hashlib.sha256(f"{source}{file}".encode()).hexdigest()
|
||||
entry: SubgraphEntry = {
|
||||
"source": Source.custom_node,
|
||||
"name": os.path.splitext(os.path.basename(file))[0],
|
||||
"path": file,
|
||||
"info": info,
|
||||
}
|
||||
subgraphs_dict[id] = entry
|
||||
node_pack = "custom_nodes." + file.split('/')[-3]
|
||||
entry_id, entry = self._create_entry(file, Source.custom_node, node_pack)
|
||||
subgraphs_dict[entry_id] = entry
|
||||
|
||||
self.cached_custom_node_subgraphs = subgraphs_dict
|
||||
return subgraphs_dict
|
||||
|
||||
async def get_custom_node_subgraph(self, id: str, loadedModules):
|
||||
subgraphs = await self.get_custom_node_subgraphs(loadedModules)
|
||||
entry: SubgraphEntry = subgraphs.get(id, None)
|
||||
if entry is not None and entry.get('data', None) is None:
|
||||
async def get_blueprint_subgraphs(self, force_reload=False):
|
||||
"""Load subgraphs from the blueprints directory."""
|
||||
if not force_reload and self.cached_blueprint_subgraphs is not None:
|
||||
return self.cached_blueprint_subgraphs
|
||||
|
||||
subgraphs_dict: dict[SubgraphEntry] = {}
|
||||
blueprints_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'blueprints')
|
||||
|
||||
if os.path.exists(blueprints_dir):
|
||||
for file in glob.glob(os.path.join(blueprints_dir, "*.json")):
|
||||
file = file.replace('\\', '/')
|
||||
entry_id, entry = self._create_entry(file, Source.templates, "comfyui")
|
||||
subgraphs_dict[entry_id] = entry
|
||||
|
||||
self.cached_blueprint_subgraphs = subgraphs_dict
|
||||
return subgraphs_dict
|
||||
|
||||
async def get_all_subgraphs(self, loadedModules, force_reload=False):
|
||||
"""Get all subgraphs from all sources (custom nodes and blueprints)."""
|
||||
custom_node_subgraphs = await self.get_custom_node_subgraphs(loadedModules, force_reload)
|
||||
blueprint_subgraphs = await self.get_blueprint_subgraphs(force_reload)
|
||||
return {**custom_node_subgraphs, **blueprint_subgraphs}
|
||||
|
||||
async def get_subgraph(self, id: str, loadedModules):
|
||||
"""Get a specific subgraph by ID from any source."""
|
||||
entry = (await self.get_all_subgraphs(loadedModules)).get(id)
|
||||
if entry is not None and entry.get('data') is None:
|
||||
await self.load_entry_data(entry)
|
||||
return entry
|
||||
|
||||
def add_routes(self, routes, loadedModules):
|
||||
@routes.get("/global_subgraphs")
|
||||
async def get_global_subgraphs(request):
|
||||
subgraphs_dict = await self.get_custom_node_subgraphs(loadedModules)
|
||||
# NOTE: we may want to include other sources of global subgraphs such as templates in the future;
|
||||
# that's the reasoning for the current implementation
|
||||
subgraphs_dict = await self.get_all_subgraphs(loadedModules)
|
||||
return web.json_response(await self.sanitize_entries(subgraphs_dict, remove_data=True))
|
||||
|
||||
@routes.get("/global_subgraphs/{id}")
|
||||
async def get_global_subgraph(request):
|
||||
id = request.match_info.get("id", None)
|
||||
subgraph = await self.get_custom_node_subgraph(id, loadedModules)
|
||||
subgraph = await self.get_subgraph(id, loadedModules)
|
||||
return web.json_response(await self.sanitize_entry(subgraph))
|
||||
|
||||
44
blueprints/.glsl/Brightness_and_Contrast_1.frag
Normal file
44
blueprints/.glsl/Brightness_and_Contrast_1.frag
Normal file
@@ -0,0 +1,44 @@
|
||||
#version 300 es
|
||||
precision highp float;
|
||||
|
||||
uniform sampler2D u_image0;
|
||||
uniform float u_float0; // Brightness slider -100..100
|
||||
uniform float u_float1; // Contrast slider -100..100
|
||||
|
||||
in vec2 v_texCoord;
|
||||
out vec4 fragColor;
|
||||
|
||||
const float MID_GRAY = 0.18; // 18% reflectance
|
||||
|
||||
// sRGB gamma 2.2 approximation
|
||||
vec3 srgbToLinear(vec3 c) {
|
||||
return pow(max(c, 0.0), vec3(2.2));
|
||||
}
|
||||
|
||||
vec3 linearToSrgb(vec3 c) {
|
||||
return pow(max(c, 0.0), vec3(1.0/2.2));
|
||||
}
|
||||
|
||||
float mapBrightness(float b) {
|
||||
return clamp(b / 100.0, -1.0, 1.0);
|
||||
}
|
||||
|
||||
float mapContrast(float c) {
|
||||
return clamp(c / 100.0 + 1.0, 0.0, 2.0);
|
||||
}
|
||||
|
||||
void main() {
|
||||
vec4 orig = texture(u_image0, v_texCoord);
|
||||
|
||||
float brightness = mapBrightness(u_float0);
|
||||
float contrast = mapContrast(u_float1);
|
||||
|
||||
vec3 lin = srgbToLinear(orig.rgb);
|
||||
|
||||
lin = (lin - MID_GRAY) * contrast + brightness + MID_GRAY;
|
||||
|
||||
// Convert back to sRGB
|
||||
vec3 result = linearToSrgb(clamp(lin, 0.0, 1.0));
|
||||
|
||||
fragColor = vec4(result, orig.a);
|
||||
}
|
||||
72
blueprints/.glsl/Chromatic_Aberration_16.frag
Normal file
72
blueprints/.glsl/Chromatic_Aberration_16.frag
Normal file
@@ -0,0 +1,72 @@
|
||||
#version 300 es
|
||||
precision highp float;
|
||||
|
||||
uniform sampler2D u_image0;
|
||||
uniform vec2 u_resolution;
|
||||
uniform int u_int0; // Mode
|
||||
uniform float u_float0; // Amount (0 to 100)
|
||||
|
||||
in vec2 v_texCoord;
|
||||
out vec4 fragColor;
|
||||
|
||||
const int MODE_LINEAR = 0;
|
||||
const int MODE_RADIAL = 1;
|
||||
const int MODE_BARREL = 2;
|
||||
const int MODE_SWIRL = 3;
|
||||
const int MODE_DIAGONAL = 4;
|
||||
|
||||
const float AMOUNT_SCALE = 0.0005;
|
||||
const float RADIAL_MULT = 4.0;
|
||||
const float BARREL_MULT = 8.0;
|
||||
const float INV_SQRT2 = 0.70710678118;
|
||||
|
||||
void main() {
|
||||
vec2 uv = v_texCoord;
|
||||
vec4 original = texture(u_image0, uv);
|
||||
|
||||
float amount = u_float0 * AMOUNT_SCALE;
|
||||
|
||||
if (amount < 0.000001) {
|
||||
fragColor = original;
|
||||
return;
|
||||
}
|
||||
|
||||
// Aspect-corrected coordinates for circular effects
|
||||
float aspect = u_resolution.x / u_resolution.y;
|
||||
vec2 centered = uv - 0.5;
|
||||
vec2 corrected = vec2(centered.x * aspect, centered.y);
|
||||
float r = length(corrected);
|
||||
vec2 dir = r > 0.0001 ? corrected / r : vec2(0.0);
|
||||
vec2 offset = vec2(0.0);
|
||||
|
||||
if (u_int0 == MODE_LINEAR) {
|
||||
// Horizontal shift (no aspect correction needed)
|
||||
offset = vec2(amount, 0.0);
|
||||
}
|
||||
else if (u_int0 == MODE_RADIAL) {
|
||||
// Outward from center, stronger at edges
|
||||
offset = dir * r * amount * RADIAL_MULT;
|
||||
offset.x /= aspect; // Convert back to UV space
|
||||
}
|
||||
else if (u_int0 == MODE_BARREL) {
|
||||
// Lens distortion simulation (r² falloff)
|
||||
offset = dir * r * r * amount * BARREL_MULT;
|
||||
offset.x /= aspect; // Convert back to UV space
|
||||
}
|
||||
else if (u_int0 == MODE_SWIRL) {
|
||||
// Perpendicular to radial (rotational aberration)
|
||||
vec2 perp = vec2(-dir.y, dir.x);
|
||||
offset = perp * r * amount * RADIAL_MULT;
|
||||
offset.x /= aspect; // Convert back to UV space
|
||||
}
|
||||
else if (u_int0 == MODE_DIAGONAL) {
|
||||
// 45° offset (no aspect correction needed)
|
||||
offset = vec2(amount, amount) * INV_SQRT2;
|
||||
}
|
||||
|
||||
float red = texture(u_image0, uv + offset).r;
|
||||
float green = original.g;
|
||||
float blue = texture(u_image0, uv - offset).b;
|
||||
|
||||
fragColor = vec4(red, green, blue, original.a);
|
||||
}
|
||||
78
blueprints/.glsl/Color_Adjustment_15.frag
Normal file
78
blueprints/.glsl/Color_Adjustment_15.frag
Normal file
@@ -0,0 +1,78 @@
|
||||
#version 300 es
|
||||
precision highp float;
|
||||
|
||||
uniform sampler2D u_image0;
|
||||
uniform float u_float0; // temperature (-100 to 100)
|
||||
uniform float u_float1; // tint (-100 to 100)
|
||||
uniform float u_float2; // vibrance (-100 to 100)
|
||||
uniform float u_float3; // saturation (-100 to 100)
|
||||
|
||||
in vec2 v_texCoord;
|
||||
out vec4 fragColor;
|
||||
|
||||
const float INPUT_SCALE = 0.01;
|
||||
const float TEMP_TINT_PRIMARY = 0.3;
|
||||
const float TEMP_TINT_SECONDARY = 0.15;
|
||||
const float VIBRANCE_BOOST = 2.0;
|
||||
const float SATURATION_BOOST = 2.0;
|
||||
const float SKIN_PROTECTION = 0.5;
|
||||
const float EPSILON = 0.001;
|
||||
const vec3 LUMA_WEIGHTS = vec3(0.299, 0.587, 0.114);
|
||||
|
||||
void main() {
|
||||
vec4 tex = texture(u_image0, v_texCoord);
|
||||
vec3 color = tex.rgb;
|
||||
|
||||
// Scale inputs: -100/100 → -1/1
|
||||
float temperature = u_float0 * INPUT_SCALE;
|
||||
float tint = u_float1 * INPUT_SCALE;
|
||||
float vibrance = u_float2 * INPUT_SCALE;
|
||||
float saturation = u_float3 * INPUT_SCALE;
|
||||
|
||||
// Temperature (warm/cool): positive = warm, negative = cool
|
||||
color.r += temperature * TEMP_TINT_PRIMARY;
|
||||
color.b -= temperature * TEMP_TINT_PRIMARY;
|
||||
|
||||
// Tint (green/magenta): positive = green, negative = magenta
|
||||
color.g += tint * TEMP_TINT_PRIMARY;
|
||||
color.r -= tint * TEMP_TINT_SECONDARY;
|
||||
color.b -= tint * TEMP_TINT_SECONDARY;
|
||||
|
||||
// Single clamp after temperature/tint
|
||||
color = clamp(color, 0.0, 1.0);
|
||||
|
||||
// Vibrance with skin protection
|
||||
if (vibrance != 0.0) {
|
||||
float maxC = max(color.r, max(color.g, color.b));
|
||||
float minC = min(color.r, min(color.g, color.b));
|
||||
float sat = maxC - minC;
|
||||
float gray = dot(color, LUMA_WEIGHTS);
|
||||
|
||||
if (vibrance < 0.0) {
|
||||
// Desaturate: -100 → gray
|
||||
color = mix(vec3(gray), color, 1.0 + vibrance);
|
||||
} else {
|
||||
// Boost less saturated colors more
|
||||
float vibranceAmt = vibrance * (1.0 - sat);
|
||||
|
||||
// Branchless skin tone protection
|
||||
float isWarmTone = step(color.b, color.g) * step(color.g, color.r);
|
||||
float warmth = (color.r - color.b) / max(maxC, EPSILON);
|
||||
float skinTone = isWarmTone * warmth * sat * (1.0 - sat);
|
||||
vibranceAmt *= (1.0 - skinTone * SKIN_PROTECTION);
|
||||
|
||||
color = mix(vec3(gray), color, 1.0 + vibranceAmt * VIBRANCE_BOOST);
|
||||
}
|
||||
}
|
||||
|
||||
// Saturation
|
||||
if (saturation != 0.0) {
|
||||
float gray = dot(color, LUMA_WEIGHTS);
|
||||
float satMix = saturation < 0.0
|
||||
? 1.0 + saturation // -100 → gray
|
||||
: 1.0 + saturation * SATURATION_BOOST; // +100 → 3x boost
|
||||
color = mix(vec3(gray), color, satMix);
|
||||
}
|
||||
|
||||
fragColor = vec4(clamp(color, 0.0, 1.0), tex.a);
|
||||
}
|
||||
94
blueprints/.glsl/Edge-Preserving_Blur_128.frag
Normal file
94
blueprints/.glsl/Edge-Preserving_Blur_128.frag
Normal file
@@ -0,0 +1,94 @@
|
||||
#version 300 es
|
||||
precision highp float;
|
||||
|
||||
uniform sampler2D u_image0;
|
||||
uniform float u_float0; // Blur radius (0–20, default ~5)
|
||||
uniform float u_float1; // Edge threshold (0–100, default ~30)
|
||||
uniform int u_int0; // Step size (0/1 = every pixel, 2+ = skip pixels)
|
||||
|
||||
in vec2 v_texCoord;
|
||||
out vec4 fragColor;
|
||||
|
||||
const int MAX_RADIUS = 20;
|
||||
const float EPSILON = 0.0001;
|
||||
|
||||
// Perceptual luminance
|
||||
float getLuminance(vec3 rgb) {
|
||||
return dot(rgb, vec3(0.299, 0.587, 0.114));
|
||||
}
|
||||
|
||||
vec4 bilateralFilter(vec2 uv, vec2 texelSize, int radius,
|
||||
float sigmaSpatial, float sigmaColor)
|
||||
{
|
||||
vec4 center = texture(u_image0, uv);
|
||||
vec3 centerRGB = center.rgb;
|
||||
|
||||
float invSpatial2 = -0.5 / (sigmaSpatial * sigmaSpatial);
|
||||
float invColor2 = -0.5 / (sigmaColor * sigmaColor + EPSILON);
|
||||
|
||||
vec3 sumRGB = vec3(0.0);
|
||||
float sumWeight = 0.0;
|
||||
|
||||
int step = max(u_int0, 1);
|
||||
float radius2 = float(radius * radius);
|
||||
|
||||
for (int dy = -MAX_RADIUS; dy <= MAX_RADIUS; dy++) {
|
||||
if (dy < -radius || dy > radius) continue;
|
||||
if (abs(dy) % step != 0) continue;
|
||||
|
||||
for (int dx = -MAX_RADIUS; dx <= MAX_RADIUS; dx++) {
|
||||
if (dx < -radius || dx > radius) continue;
|
||||
if (abs(dx) % step != 0) continue;
|
||||
|
||||
vec2 offset = vec2(float(dx), float(dy));
|
||||
float dist2 = dot(offset, offset);
|
||||
if (dist2 > radius2) continue;
|
||||
|
||||
vec3 sampleRGB = texture(u_image0, uv + offset * texelSize).rgb;
|
||||
|
||||
// Spatial Gaussian
|
||||
float spatialWeight = exp(dist2 * invSpatial2);
|
||||
|
||||
// Perceptual color distance (weighted RGB)
|
||||
vec3 diff = sampleRGB - centerRGB;
|
||||
float colorDist = dot(diff * diff, vec3(0.299, 0.587, 0.114));
|
||||
float colorWeight = exp(colorDist * invColor2);
|
||||
|
||||
float w = spatialWeight * colorWeight;
|
||||
sumRGB += sampleRGB * w;
|
||||
sumWeight += w;
|
||||
}
|
||||
}
|
||||
|
||||
vec3 resultRGB = sumRGB / max(sumWeight, EPSILON);
|
||||
return vec4(resultRGB, center.a); // preserve center alpha
|
||||
}
|
||||
|
||||
void main() {
|
||||
vec2 texelSize = 1.0 / vec2(textureSize(u_image0, 0));
|
||||
|
||||
float radiusF = clamp(u_float0, 0.0, float(MAX_RADIUS));
|
||||
int radius = int(radiusF + 0.5);
|
||||
|
||||
if (radius == 0) {
|
||||
fragColor = texture(u_image0, v_texCoord);
|
||||
return;
|
||||
}
|
||||
|
||||
// Edge threshold → color sigma
|
||||
// Squared curve for better low-end control
|
||||
float t = clamp(u_float1, 0.0, 100.0) / 100.0;
|
||||
t *= t;
|
||||
float sigmaColor = mix(0.01, 0.5, t);
|
||||
|
||||
// Spatial sigma tied to radius
|
||||
float sigmaSpatial = max(radiusF * 0.75, 0.5);
|
||||
|
||||
fragColor = bilateralFilter(
|
||||
v_texCoord,
|
||||
texelSize,
|
||||
radius,
|
||||
sigmaSpatial,
|
||||
sigmaColor
|
||||
);
|
||||
}
|
||||
124
blueprints/.glsl/Film_Grain_15.frag
Normal file
124
blueprints/.glsl/Film_Grain_15.frag
Normal file
@@ -0,0 +1,124 @@
|
||||
#version 300 es
|
||||
precision highp float;
|
||||
|
||||
uniform sampler2D u_image0;
|
||||
uniform vec2 u_resolution;
|
||||
uniform float u_float0; // grain amount [0.0 – 1.0] typical: 0.2–0.8
|
||||
uniform float u_float1; // grain size [0.3 – 3.0] lower = finer grain
|
||||
uniform float u_float2; // color amount [0.0 – 1.0] 0 = monochrome, 1 = RGB grain
|
||||
uniform float u_float3; // luminance bias [0.0 – 1.0] 0 = uniform, 1 = shadows only
|
||||
uniform int u_int0; // noise mode [0 or 1] 0 = smooth, 1 = grainy
|
||||
|
||||
in vec2 v_texCoord;
|
||||
layout(location = 0) out vec4 fragColor0;
|
||||
|
||||
// High-quality integer hash (pcg-like)
|
||||
uint pcg(uint v) {
|
||||
uint state = v * 747796405u + 2891336453u;
|
||||
uint word = ((state >> ((state >> 28u) + 4u)) ^ state) * 277803737u;
|
||||
return (word >> 22u) ^ word;
|
||||
}
|
||||
|
||||
// 2D -> 1D hash input
|
||||
uint hash2d(uvec2 p) {
|
||||
return pcg(p.x + pcg(p.y));
|
||||
}
|
||||
|
||||
// Hash to float [0, 1]
|
||||
float hashf(uvec2 p) {
|
||||
return float(hash2d(p)) / float(0xffffffffu);
|
||||
}
|
||||
|
||||
// Hash to float with offset (for RGB channels)
|
||||
float hashf(uvec2 p, uint offset) {
|
||||
return float(pcg(hash2d(p) + offset)) / float(0xffffffffu);
|
||||
}
|
||||
|
||||
// Convert uniform [0,1] to roughly Gaussian distribution
|
||||
// Using simple approximation: average of multiple samples
|
||||
float toGaussian(uvec2 p) {
|
||||
float sum = hashf(p, 0u) + hashf(p, 1u) + hashf(p, 2u) + hashf(p, 3u);
|
||||
return (sum - 2.0) * 0.7; // Centered, scaled
|
||||
}
|
||||
|
||||
float toGaussian(uvec2 p, uint offset) {
|
||||
float sum = hashf(p, offset) + hashf(p, offset + 1u)
|
||||
+ hashf(p, offset + 2u) + hashf(p, offset + 3u);
|
||||
return (sum - 2.0) * 0.7;
|
||||
}
|
||||
|
||||
// Smooth noise with better interpolation
|
||||
float smoothNoise(vec2 p) {
|
||||
vec2 i = floor(p);
|
||||
vec2 f = fract(p);
|
||||
|
||||
// Quintic interpolation (less banding than cubic)
|
||||
f = f * f * f * (f * (f * 6.0 - 15.0) + 10.0);
|
||||
|
||||
uvec2 ui = uvec2(i);
|
||||
float a = toGaussian(ui);
|
||||
float b = toGaussian(ui + uvec2(1u, 0u));
|
||||
float c = toGaussian(ui + uvec2(0u, 1u));
|
||||
float d = toGaussian(ui + uvec2(1u, 1u));
|
||||
|
||||
return mix(mix(a, b, f.x), mix(c, d, f.x), f.y);
|
||||
}
|
||||
|
||||
float smoothNoise(vec2 p, uint offset) {
|
||||
vec2 i = floor(p);
|
||||
vec2 f = fract(p);
|
||||
|
||||
f = f * f * f * (f * (f * 6.0 - 15.0) + 10.0);
|
||||
|
||||
uvec2 ui = uvec2(i);
|
||||
float a = toGaussian(ui, offset);
|
||||
float b = toGaussian(ui + uvec2(1u, 0u), offset);
|
||||
float c = toGaussian(ui + uvec2(0u, 1u), offset);
|
||||
float d = toGaussian(ui + uvec2(1u, 1u), offset);
|
||||
|
||||
return mix(mix(a, b, f.x), mix(c, d, f.x), f.y);
|
||||
}
|
||||
|
||||
void main() {
|
||||
vec4 color = texture(u_image0, v_texCoord);
|
||||
|
||||
// Luminance (Rec.709)
|
||||
float luma = dot(color.rgb, vec3(0.2126, 0.7152, 0.0722));
|
||||
|
||||
// Grain UV (resolution-independent)
|
||||
vec2 grainUV = v_texCoord * u_resolution / max(u_float1, 0.01);
|
||||
uvec2 grainPixel = uvec2(grainUV);
|
||||
|
||||
float g;
|
||||
vec3 grainRGB;
|
||||
|
||||
if (u_int0 == 1) {
|
||||
// Grainy mode: pure hash noise (no interpolation = no banding)
|
||||
g = toGaussian(grainPixel);
|
||||
grainRGB = vec3(
|
||||
toGaussian(grainPixel, 100u),
|
||||
toGaussian(grainPixel, 200u),
|
||||
toGaussian(grainPixel, 300u)
|
||||
);
|
||||
} else {
|
||||
// Smooth mode: interpolated with quintic curve
|
||||
g = smoothNoise(grainUV);
|
||||
grainRGB = vec3(
|
||||
smoothNoise(grainUV, 100u),
|
||||
smoothNoise(grainUV, 200u),
|
||||
smoothNoise(grainUV, 300u)
|
||||
);
|
||||
}
|
||||
|
||||
// Luminance weighting (less grain in highlights)
|
||||
float lumWeight = mix(1.0, 1.0 - luma, clamp(u_float3, 0.0, 1.0));
|
||||
|
||||
// Strength
|
||||
float strength = u_float0 * 0.15;
|
||||
|
||||
// Color vs monochrome grain
|
||||
vec3 grainColor = mix(vec3(g), grainRGB, clamp(u_float2, 0.0, 1.0));
|
||||
|
||||
color.rgb += grainColor * strength * lumWeight;
|
||||
fragColor0 = vec4(clamp(color.rgb, 0.0, 1.0), color.a);
|
||||
}
|
||||
133
blueprints/.glsl/Glow_30.frag
Normal file
133
blueprints/.glsl/Glow_30.frag
Normal file
@@ -0,0 +1,133 @@
|
||||
#version 300 es
|
||||
precision mediump float;
|
||||
|
||||
uniform sampler2D u_image0;
|
||||
uniform vec2 u_resolution;
|
||||
uniform int u_int0; // Blend mode
|
||||
uniform int u_int1; // Color tint
|
||||
uniform float u_float0; // Intensity
|
||||
uniform float u_float1; // Radius
|
||||
uniform float u_float2; // Threshold
|
||||
|
||||
in vec2 v_texCoord;
|
||||
out vec4 fragColor;
|
||||
|
||||
const int BLEND_ADD = 0;
|
||||
const int BLEND_SCREEN = 1;
|
||||
const int BLEND_SOFT = 2;
|
||||
const int BLEND_OVERLAY = 3;
|
||||
const int BLEND_LIGHTEN = 4;
|
||||
|
||||
const float GOLDEN_ANGLE = 2.39996323;
|
||||
const int MAX_SAMPLES = 48;
|
||||
const vec3 LUMA = vec3(0.299, 0.587, 0.114);
|
||||
|
||||
float hash(vec2 p) {
|
||||
p = fract(p * vec2(123.34, 456.21));
|
||||
p += dot(p, p + 45.32);
|
||||
return fract(p.x * p.y);
|
||||
}
|
||||
|
||||
vec3 hexToRgb(int h) {
|
||||
return vec3(
|
||||
float((h >> 16) & 255),
|
||||
float((h >> 8) & 255),
|
||||
float(h & 255)
|
||||
) * (1.0 / 255.0);
|
||||
}
|
||||
|
||||
vec3 blend(vec3 base, vec3 glow, int mode) {
|
||||
if (mode == BLEND_SCREEN) {
|
||||
return 1.0 - (1.0 - base) * (1.0 - glow);
|
||||
}
|
||||
if (mode == BLEND_SOFT) {
|
||||
return mix(
|
||||
base - (1.0 - 2.0 * glow) * base * (1.0 - base),
|
||||
base + (2.0 * glow - 1.0) * (sqrt(base) - base),
|
||||
step(0.5, glow)
|
||||
);
|
||||
}
|
||||
if (mode == BLEND_OVERLAY) {
|
||||
return mix(
|
||||
2.0 * base * glow,
|
||||
1.0 - 2.0 * (1.0 - base) * (1.0 - glow),
|
||||
step(0.5, base)
|
||||
);
|
||||
}
|
||||
if (mode == BLEND_LIGHTEN) {
|
||||
return max(base, glow);
|
||||
}
|
||||
return base + glow;
|
||||
}
|
||||
|
||||
void main() {
|
||||
vec4 original = texture(u_image0, v_texCoord);
|
||||
|
||||
float intensity = u_float0 * 0.05;
|
||||
float radius = u_float1 * u_float1 * 0.012;
|
||||
|
||||
if (intensity < 0.001 || radius < 0.1) {
|
||||
fragColor = original;
|
||||
return;
|
||||
}
|
||||
|
||||
float threshold = 1.0 - u_float2 * 0.01;
|
||||
float t0 = threshold - 0.15;
|
||||
float t1 = threshold + 0.15;
|
||||
|
||||
vec2 texelSize = 1.0 / u_resolution;
|
||||
float radius2 = radius * radius;
|
||||
|
||||
float sampleScale = clamp(radius * 0.75, 0.35, 1.0);
|
||||
int samples = int(float(MAX_SAMPLES) * sampleScale);
|
||||
|
||||
float noise = hash(gl_FragCoord.xy);
|
||||
float angleOffset = noise * GOLDEN_ANGLE;
|
||||
float radiusJitter = 0.85 + noise * 0.3;
|
||||
|
||||
float ca = cos(GOLDEN_ANGLE);
|
||||
float sa = sin(GOLDEN_ANGLE);
|
||||
vec2 dir = vec2(cos(angleOffset), sin(angleOffset));
|
||||
|
||||
vec3 glow = vec3(0.0);
|
||||
float totalWeight = 0.0;
|
||||
|
||||
// Center tap
|
||||
float centerMask = smoothstep(t0, t1, dot(original.rgb, LUMA));
|
||||
glow += original.rgb * centerMask * 2.0;
|
||||
totalWeight += 2.0;
|
||||
|
||||
for (int i = 1; i < MAX_SAMPLES; i++) {
|
||||
if (i >= samples) break;
|
||||
|
||||
float fi = float(i);
|
||||
float dist = sqrt(fi / float(samples)) * radius * radiusJitter;
|
||||
|
||||
vec2 offset = dir * dist * texelSize;
|
||||
vec3 c = texture(u_image0, v_texCoord + offset).rgb;
|
||||
float mask = smoothstep(t0, t1, dot(c, LUMA));
|
||||
|
||||
float w = 1.0 - (dist * dist) / (radius2 * 1.5);
|
||||
w = max(w, 0.0);
|
||||
w *= w;
|
||||
|
||||
glow += c * mask * w;
|
||||
totalWeight += w;
|
||||
|
||||
dir = vec2(
|
||||
dir.x * ca - dir.y * sa,
|
||||
dir.x * sa + dir.y * ca
|
||||
);
|
||||
}
|
||||
|
||||
glow *= intensity / max(totalWeight, 0.001);
|
||||
|
||||
if (u_int1 > 0) {
|
||||
glow *= hexToRgb(u_int1);
|
||||
}
|
||||
|
||||
vec3 result = blend(original.rgb, glow, u_int0);
|
||||
result += (noise - 0.5) * (1.0 / 255.0);
|
||||
|
||||
fragColor = vec4(clamp(result, 0.0, 1.0), original.a);
|
||||
}
|
||||
222
blueprints/.glsl/Hue_and_Saturation_1.frag
Normal file
222
blueprints/.glsl/Hue_and_Saturation_1.frag
Normal file
@@ -0,0 +1,222 @@
|
||||
#version 300 es
|
||||
precision highp float;
|
||||
|
||||
uniform sampler2D u_image0;
|
||||
uniform int u_int0; // Mode: 0=Master, 1=Reds, 2=Yellows, 3=Greens, 4=Cyans, 5=Blues, 6=Magentas, 7=Colorize
|
||||
uniform int u_int1; // Color Space: 0=HSL, 1=HSB/HSV
|
||||
uniform float u_float0; // Hue (-180 to 180)
|
||||
uniform float u_float1; // Saturation (-100 to 100)
|
||||
uniform float u_float2; // Lightness/Brightness (-100 to 100)
|
||||
uniform float u_float3; // Overlap (0 to 100) - feathering between adjacent color ranges
|
||||
|
||||
in vec2 v_texCoord;
|
||||
out vec4 fragColor;
|
||||
|
||||
// Color range modes
|
||||
const int MODE_MASTER = 0;
|
||||
const int MODE_RED = 1;
|
||||
const int MODE_YELLOW = 2;
|
||||
const int MODE_GREEN = 3;
|
||||
const int MODE_CYAN = 4;
|
||||
const int MODE_BLUE = 5;
|
||||
const int MODE_MAGENTA = 6;
|
||||
const int MODE_COLORIZE = 7;
|
||||
|
||||
// Color space modes
|
||||
const int COLORSPACE_HSL = 0;
|
||||
const int COLORSPACE_HSB = 1;
|
||||
|
||||
const float EPSILON = 0.0001;
|
||||
|
||||
//=============================================================================
|
||||
// RGB <-> HSL Conversions
|
||||
//=============================================================================
|
||||
|
||||
vec3 rgb2hsl(vec3 c) {
|
||||
float maxC = max(max(c.r, c.g), c.b);
|
||||
float minC = min(min(c.r, c.g), c.b);
|
||||
float delta = maxC - minC;
|
||||
|
||||
float h = 0.0;
|
||||
float s = 0.0;
|
||||
float l = (maxC + minC) * 0.5;
|
||||
|
||||
if (delta > EPSILON) {
|
||||
s = l < 0.5
|
||||
? delta / (maxC + minC)
|
||||
: delta / (2.0 - maxC - minC);
|
||||
|
||||
if (maxC == c.r) {
|
||||
h = (c.g - c.b) / delta + (c.g < c.b ? 6.0 : 0.0);
|
||||
} else if (maxC == c.g) {
|
||||
h = (c.b - c.r) / delta + 2.0;
|
||||
} else {
|
||||
h = (c.r - c.g) / delta + 4.0;
|
||||
}
|
||||
h /= 6.0;
|
||||
}
|
||||
|
||||
return vec3(h, s, l);
|
||||
}
|
||||
|
||||
float hue2rgb(float p, float q, float t) {
|
||||
t = fract(t);
|
||||
if (t < 1.0/6.0) return p + (q - p) * 6.0 * t;
|
||||
if (t < 0.5) return q;
|
||||
if (t < 2.0/3.0) return p + (q - p) * (2.0/3.0 - t) * 6.0;
|
||||
return p;
|
||||
}
|
||||
|
||||
vec3 hsl2rgb(vec3 hsl) {
|
||||
if (hsl.y < EPSILON) return vec3(hsl.z);
|
||||
|
||||
float q = hsl.z < 0.5
|
||||
? hsl.z * (1.0 + hsl.y)
|
||||
: hsl.z + hsl.y - hsl.z * hsl.y;
|
||||
float p = 2.0 * hsl.z - q;
|
||||
|
||||
return vec3(
|
||||
hue2rgb(p, q, hsl.x + 1.0/3.0),
|
||||
hue2rgb(p, q, hsl.x),
|
||||
hue2rgb(p, q, hsl.x - 1.0/3.0)
|
||||
);
|
||||
}
|
||||
|
||||
vec3 rgb2hsb(vec3 c) {
|
||||
float maxC = max(max(c.r, c.g), c.b);
|
||||
float minC = min(min(c.r, c.g), c.b);
|
||||
float delta = maxC - minC;
|
||||
|
||||
float h = 0.0;
|
||||
float s = (maxC > EPSILON) ? delta / maxC : 0.0;
|
||||
float b = maxC;
|
||||
|
||||
if (delta > EPSILON) {
|
||||
if (maxC == c.r) {
|
||||
h = (c.g - c.b) / delta + (c.g < c.b ? 6.0 : 0.0);
|
||||
} else if (maxC == c.g) {
|
||||
h = (c.b - c.r) / delta + 2.0;
|
||||
} else {
|
||||
h = (c.r - c.g) / delta + 4.0;
|
||||
}
|
||||
h /= 6.0;
|
||||
}
|
||||
|
||||
return vec3(h, s, b);
|
||||
}
|
||||
|
||||
vec3 hsb2rgb(vec3 hsb) {
|
||||
vec3 rgb = clamp(abs(mod(hsb.x * 6.0 + vec3(0.0, 4.0, 2.0), 6.0) - 3.0) - 1.0, 0.0, 1.0);
|
||||
return hsb.z * mix(vec3(1.0), rgb, hsb.y);
|
||||
}
|
||||
|
||||
//=============================================================================
|
||||
// Color Range Weight Calculation
|
||||
//=============================================================================
|
||||
|
||||
float hueDistance(float a, float b) {
|
||||
float d = abs(a - b);
|
||||
return min(d, 1.0 - d);
|
||||
}
|
||||
|
||||
float getHueWeight(float hue, float center, float overlap) {
|
||||
float baseWidth = 1.0 / 6.0;
|
||||
float feather = baseWidth * overlap;
|
||||
|
||||
float d = hueDistance(hue, center);
|
||||
|
||||
float inner = baseWidth * 0.5;
|
||||
float outer = inner + feather;
|
||||
|
||||
return 1.0 - smoothstep(inner, outer, d);
|
||||
}
|
||||
|
||||
float getModeWeight(float hue, int mode, float overlap) {
|
||||
if (mode == MODE_MASTER || mode == MODE_COLORIZE) return 1.0;
|
||||
|
||||
if (mode == MODE_RED) {
|
||||
return max(
|
||||
getHueWeight(hue, 0.0, overlap),
|
||||
getHueWeight(hue, 1.0, overlap)
|
||||
);
|
||||
}
|
||||
|
||||
float center = float(mode - 1) / 6.0;
|
||||
return getHueWeight(hue, center, overlap);
|
||||
}
|
||||
|
||||
//=============================================================================
|
||||
// Adjustment Functions
|
||||
//=============================================================================
|
||||
|
||||
float adjustLightness(float l, float amount) {
|
||||
return amount > 0.0
|
||||
? l + (1.0 - l) * amount
|
||||
: l + l * amount;
|
||||
}
|
||||
|
||||
float adjustBrightness(float b, float amount) {
|
||||
return clamp(b + amount, 0.0, 1.0);
|
||||
}
|
||||
|
||||
float adjustSaturation(float s, float amount) {
|
||||
return amount > 0.0
|
||||
? s + (1.0 - s) * amount
|
||||
: s + s * amount;
|
||||
}
|
||||
|
||||
vec3 colorize(vec3 rgb, float hue, float sat, float light) {
|
||||
float lum = dot(rgb, vec3(0.299, 0.587, 0.114));
|
||||
float l = adjustLightness(lum, light);
|
||||
|
||||
vec3 hsl = vec3(fract(hue), clamp(abs(sat), 0.0, 1.0), clamp(l, 0.0, 1.0));
|
||||
return hsl2rgb(hsl);
|
||||
}
|
||||
|
||||
//=============================================================================
|
||||
// Main
|
||||
//=============================================================================
|
||||
|
||||
void main() {
|
||||
vec4 original = texture(u_image0, v_texCoord);
|
||||
|
||||
float hueShift = u_float0 / 360.0; // -180..180 -> -0.5..0.5
|
||||
float satAmount = u_float1 / 100.0; // -100..100 -> -1..1
|
||||
float lightAmount= u_float2 / 100.0; // -100..100 -> -1..1
|
||||
float overlap = u_float3 / 100.0; // 0..100 -> 0..1
|
||||
|
||||
vec3 result;
|
||||
|
||||
if (u_int0 == MODE_COLORIZE) {
|
||||
result = colorize(original.rgb, hueShift, satAmount, lightAmount);
|
||||
fragColor = vec4(result, original.a);
|
||||
return;
|
||||
}
|
||||
|
||||
vec3 hsx = (u_int1 == COLORSPACE_HSL)
|
||||
? rgb2hsl(original.rgb)
|
||||
: rgb2hsb(original.rgb);
|
||||
|
||||
float weight = getModeWeight(hsx.x, u_int0, overlap);
|
||||
|
||||
if (u_int0 != MODE_MASTER && hsx.y < EPSILON) {
|
||||
weight = 0.0;
|
||||
}
|
||||
|
||||
if (weight > EPSILON) {
|
||||
float h = fract(hsx.x + hueShift * weight);
|
||||
float s = clamp(adjustSaturation(hsx.y, satAmount * weight), 0.0, 1.0);
|
||||
float v = (u_int1 == COLORSPACE_HSL)
|
||||
? clamp(adjustLightness(hsx.z, lightAmount * weight), 0.0, 1.0)
|
||||
: clamp(adjustBrightness(hsx.z, lightAmount * weight), 0.0, 1.0);
|
||||
|
||||
vec3 adjusted = vec3(h, s, v);
|
||||
result = (u_int1 == COLORSPACE_HSL)
|
||||
? hsl2rgb(adjusted)
|
||||
: hsb2rgb(adjusted);
|
||||
} else {
|
||||
result = original.rgb;
|
||||
}
|
||||
|
||||
fragColor = vec4(result, original.a);
|
||||
}
|
||||
111
blueprints/.glsl/Image_Blur_1.frag
Normal file
111
blueprints/.glsl/Image_Blur_1.frag
Normal file
@@ -0,0 +1,111 @@
|
||||
#version 300 es
|
||||
#pragma passes 2
|
||||
precision highp float;
|
||||
|
||||
// Blur type constants
|
||||
const int BLUR_GAUSSIAN = 0;
|
||||
const int BLUR_BOX = 1;
|
||||
const int BLUR_RADIAL = 2;
|
||||
|
||||
// Radial blur config
|
||||
const int RADIAL_SAMPLES = 12;
|
||||
const float RADIAL_STRENGTH = 0.0003;
|
||||
|
||||
uniform sampler2D u_image0;
|
||||
uniform vec2 u_resolution;
|
||||
uniform int u_int0; // Blur type (BLUR_GAUSSIAN, BLUR_BOX, BLUR_RADIAL)
|
||||
uniform float u_float0; // Blur radius/amount
|
||||
uniform int u_pass; // Pass index (0 = horizontal, 1 = vertical)
|
||||
|
||||
in vec2 v_texCoord;
|
||||
layout(location = 0) out vec4 fragColor0;
|
||||
|
||||
float gaussian(float x, float sigma) {
|
||||
return exp(-(x * x) / (2.0 * sigma * sigma));
|
||||
}
|
||||
|
||||
void main() {
|
||||
vec2 texelSize = 1.0 / u_resolution;
|
||||
float radius = max(u_float0, 0.0);
|
||||
|
||||
// Radial (angular) blur - single pass, doesn't use separable
|
||||
if (u_int0 == BLUR_RADIAL) {
|
||||
// Only execute on first pass
|
||||
if (u_pass > 0) {
|
||||
fragColor0 = texture(u_image0, v_texCoord);
|
||||
return;
|
||||
}
|
||||
|
||||
vec2 center = vec2(0.5);
|
||||
vec2 dir = v_texCoord - center;
|
||||
float dist = length(dir);
|
||||
|
||||
if (dist < 1e-4) {
|
||||
fragColor0 = texture(u_image0, v_texCoord);
|
||||
return;
|
||||
}
|
||||
|
||||
vec4 sum = vec4(0.0);
|
||||
float totalWeight = 0.0;
|
||||
float angleStep = radius * RADIAL_STRENGTH;
|
||||
|
||||
dir /= dist;
|
||||
|
||||
float cosStep = cos(angleStep);
|
||||
float sinStep = sin(angleStep);
|
||||
|
||||
float negAngle = -float(RADIAL_SAMPLES) * angleStep;
|
||||
vec2 rotDir = vec2(
|
||||
dir.x * cos(negAngle) - dir.y * sin(negAngle),
|
||||
dir.x * sin(negAngle) + dir.y * cos(negAngle)
|
||||
);
|
||||
|
||||
for (int i = -RADIAL_SAMPLES; i <= RADIAL_SAMPLES; i++) {
|
||||
vec2 uv = center + rotDir * dist;
|
||||
float w = 1.0 - abs(float(i)) / float(RADIAL_SAMPLES);
|
||||
sum += texture(u_image0, uv) * w;
|
||||
totalWeight += w;
|
||||
|
||||
rotDir = vec2(
|
||||
rotDir.x * cosStep - rotDir.y * sinStep,
|
||||
rotDir.x * sinStep + rotDir.y * cosStep
|
||||
);
|
||||
}
|
||||
|
||||
fragColor0 = sum / max(totalWeight, 0.001);
|
||||
return;
|
||||
}
|
||||
|
||||
// Separable Gaussian / Box blur
|
||||
int samples = int(ceil(radius));
|
||||
|
||||
if (samples == 0) {
|
||||
fragColor0 = texture(u_image0, v_texCoord);
|
||||
return;
|
||||
}
|
||||
|
||||
// Direction: pass 0 = horizontal, pass 1 = vertical
|
||||
vec2 dir = (u_pass == 0) ? vec2(1.0, 0.0) : vec2(0.0, 1.0);
|
||||
|
||||
vec4 color = vec4(0.0);
|
||||
float totalWeight = 0.0;
|
||||
float sigma = radius / 2.0;
|
||||
|
||||
for (int i = -samples; i <= samples; i++) {
|
||||
vec2 offset = dir * float(i) * texelSize;
|
||||
vec4 sample_color = texture(u_image0, v_texCoord + offset);
|
||||
|
||||
float weight;
|
||||
if (u_int0 == BLUR_GAUSSIAN) {
|
||||
weight = gaussian(float(i), sigma);
|
||||
} else {
|
||||
// BLUR_BOX
|
||||
weight = 1.0;
|
||||
}
|
||||
|
||||
color += sample_color * weight;
|
||||
totalWeight += weight;
|
||||
}
|
||||
|
||||
fragColor0 = color / totalWeight;
|
||||
}
|
||||
19
blueprints/.glsl/Image_Channels_23.frag
Normal file
19
blueprints/.glsl/Image_Channels_23.frag
Normal file
@@ -0,0 +1,19 @@
|
||||
#version 300 es
|
||||
precision highp float;
|
||||
|
||||
uniform sampler2D u_image0;
|
||||
|
||||
in vec2 v_texCoord;
|
||||
layout(location = 0) out vec4 fragColor0;
|
||||
layout(location = 1) out vec4 fragColor1;
|
||||
layout(location = 2) out vec4 fragColor2;
|
||||
layout(location = 3) out vec4 fragColor3;
|
||||
|
||||
void main() {
|
||||
vec4 color = texture(u_image0, v_texCoord);
|
||||
// Output each channel as grayscale to separate render targets
|
||||
fragColor0 = vec4(vec3(color.r), 1.0); // Red channel
|
||||
fragColor1 = vec4(vec3(color.g), 1.0); // Green channel
|
||||
fragColor2 = vec4(vec3(color.b), 1.0); // Blue channel
|
||||
fragColor3 = vec4(vec3(color.a), 1.0); // Alpha channel
|
||||
}
|
||||
71
blueprints/.glsl/Image_Levels_1.frag
Normal file
71
blueprints/.glsl/Image_Levels_1.frag
Normal file
@@ -0,0 +1,71 @@
|
||||
#version 300 es
|
||||
precision highp float;
|
||||
|
||||
// Levels Adjustment
|
||||
// u_int0: channel (0=RGB, 1=R, 2=G, 3=B) default: 0
|
||||
// u_float0: input black (0-255) default: 0
|
||||
// u_float1: input white (0-255) default: 255
|
||||
// u_float2: gamma (0.01-9.99) default: 1.0
|
||||
// u_float3: output black (0-255) default: 0
|
||||
// u_float4: output white (0-255) default: 255
|
||||
|
||||
uniform sampler2D u_image0;
|
||||
uniform int u_int0;
|
||||
uniform float u_float0;
|
||||
uniform float u_float1;
|
||||
uniform float u_float2;
|
||||
uniform float u_float3;
|
||||
uniform float u_float4;
|
||||
|
||||
in vec2 v_texCoord;
|
||||
out vec4 fragColor;
|
||||
|
||||
vec3 applyLevels(vec3 color, float inBlack, float inWhite, float gamma, float outBlack, float outWhite) {
|
||||
float inRange = max(inWhite - inBlack, 0.0001);
|
||||
vec3 result = clamp((color - inBlack) / inRange, 0.0, 1.0);
|
||||
result = pow(result, vec3(1.0 / gamma));
|
||||
result = mix(vec3(outBlack), vec3(outWhite), result);
|
||||
return result;
|
||||
}
|
||||
|
||||
float applySingleChannel(float value, float inBlack, float inWhite, float gamma, float outBlack, float outWhite) {
|
||||
float inRange = max(inWhite - inBlack, 0.0001);
|
||||
float result = clamp((value - inBlack) / inRange, 0.0, 1.0);
|
||||
result = pow(result, 1.0 / gamma);
|
||||
result = mix(outBlack, outWhite, result);
|
||||
return result;
|
||||
}
|
||||
|
||||
void main() {
|
||||
vec4 texColor = texture(u_image0, v_texCoord);
|
||||
vec3 color = texColor.rgb;
|
||||
|
||||
float inBlack = u_float0 / 255.0;
|
||||
float inWhite = u_float1 / 255.0;
|
||||
float gamma = u_float2;
|
||||
float outBlack = u_float3 / 255.0;
|
||||
float outWhite = u_float4 / 255.0;
|
||||
|
||||
vec3 result;
|
||||
|
||||
if (u_int0 == 0) {
|
||||
result = applyLevels(color, inBlack, inWhite, gamma, outBlack, outWhite);
|
||||
}
|
||||
else if (u_int0 == 1) {
|
||||
result = color;
|
||||
result.r = applySingleChannel(color.r, inBlack, inWhite, gamma, outBlack, outWhite);
|
||||
}
|
||||
else if (u_int0 == 2) {
|
||||
result = color;
|
||||
result.g = applySingleChannel(color.g, inBlack, inWhite, gamma, outBlack, outWhite);
|
||||
}
|
||||
else if (u_int0 == 3) {
|
||||
result = color;
|
||||
result.b = applySingleChannel(color.b, inBlack, inWhite, gamma, outBlack, outWhite);
|
||||
}
|
||||
else {
|
||||
result = color;
|
||||
}
|
||||
|
||||
fragColor = vec4(result, texColor.a);
|
||||
}
|
||||
28
blueprints/.glsl/README.md
Normal file
28
blueprints/.glsl/README.md
Normal file
@@ -0,0 +1,28 @@
|
||||
# GLSL Shader Sources
|
||||
|
||||
This folder contains the GLSL fragment shaders extracted from blueprint JSON files for easier editing and version control.
|
||||
|
||||
## File Naming Convention
|
||||
|
||||
`{Blueprint_Name}_{node_id}.frag`
|
||||
|
||||
- **Blueprint_Name**: The JSON filename with spaces/special chars replaced by underscores
|
||||
- **node_id**: The GLSLShader node ID within the subgraph
|
||||
|
||||
## Usage
|
||||
|
||||
```bash
|
||||
# Extract shaders from blueprint JSONs to this folder
|
||||
python update_blueprints.py extract
|
||||
|
||||
# Patch edited shaders back into blueprint JSONs
|
||||
python update_blueprints.py patch
|
||||
```
|
||||
|
||||
## Workflow
|
||||
|
||||
1. Run `extract` to pull current shaders from JSONs
|
||||
2. Edit `.frag` files
|
||||
3. Run `patch` to update the blueprint JSONs
|
||||
4. Test
|
||||
5. Commit both `.frag` files and updated JSONs
|
||||
28
blueprints/.glsl/Sharpen_23.frag
Normal file
28
blueprints/.glsl/Sharpen_23.frag
Normal file
@@ -0,0 +1,28 @@
|
||||
#version 300 es
|
||||
precision highp float;
|
||||
|
||||
uniform sampler2D u_image0;
|
||||
uniform vec2 u_resolution;
|
||||
uniform float u_float0; // strength [0.0 – 2.0] typical: 0.3–1.0
|
||||
|
||||
in vec2 v_texCoord;
|
||||
layout(location = 0) out vec4 fragColor0;
|
||||
|
||||
void main() {
|
||||
vec2 texel = 1.0 / u_resolution;
|
||||
|
||||
// Sample center and neighbors
|
||||
vec4 center = texture(u_image0, v_texCoord);
|
||||
vec4 top = texture(u_image0, v_texCoord + vec2( 0.0, -texel.y));
|
||||
vec4 bottom = texture(u_image0, v_texCoord + vec2( 0.0, texel.y));
|
||||
vec4 left = texture(u_image0, v_texCoord + vec2(-texel.x, 0.0));
|
||||
vec4 right = texture(u_image0, v_texCoord + vec2( texel.x, 0.0));
|
||||
|
||||
// Edge enhancement (Laplacian)
|
||||
vec4 edges = center * 4.0 - top - bottom - left - right;
|
||||
|
||||
// Add edges back scaled by strength
|
||||
vec4 sharpened = center + edges * u_float0;
|
||||
|
||||
fragColor0 = vec4(clamp(sharpened.rgb, 0.0, 1.0), center.a);
|
||||
}
|
||||
61
blueprints/.glsl/Unsharp_Mask_26.frag
Normal file
61
blueprints/.glsl/Unsharp_Mask_26.frag
Normal file
@@ -0,0 +1,61 @@
|
||||
#version 300 es
|
||||
precision highp float;
|
||||
|
||||
uniform sampler2D u_image0;
|
||||
uniform vec2 u_resolution;
|
||||
uniform float u_float0; // amount [0.0 - 3.0] typical: 0.5-1.5
|
||||
uniform float u_float1; // radius [0.5 - 10.0] blur radius in pixels
|
||||
uniform float u_float2; // threshold [0.0 - 0.1] min difference to sharpen
|
||||
|
||||
in vec2 v_texCoord;
|
||||
layout(location = 0) out vec4 fragColor0;
|
||||
|
||||
float gaussian(float x, float sigma) {
|
||||
return exp(-(x * x) / (2.0 * sigma * sigma));
|
||||
}
|
||||
|
||||
float getLuminance(vec3 color) {
|
||||
return dot(color, vec3(0.2126, 0.7152, 0.0722));
|
||||
}
|
||||
|
||||
void main() {
|
||||
vec2 texel = 1.0 / u_resolution;
|
||||
float radius = max(u_float1, 0.5);
|
||||
float amount = u_float0;
|
||||
float threshold = u_float2;
|
||||
|
||||
vec4 original = texture(u_image0, v_texCoord);
|
||||
|
||||
// Gaussian blur for the "unsharp" mask
|
||||
int samples = int(ceil(radius));
|
||||
float sigma = radius / 2.0;
|
||||
|
||||
vec4 blurred = vec4(0.0);
|
||||
float totalWeight = 0.0;
|
||||
|
||||
for (int x = -samples; x <= samples; x++) {
|
||||
for (int y = -samples; y <= samples; y++) {
|
||||
vec2 offset = vec2(float(x), float(y)) * texel;
|
||||
vec4 sample_color = texture(u_image0, v_texCoord + offset);
|
||||
|
||||
float dist = length(vec2(float(x), float(y)));
|
||||
float weight = gaussian(dist, sigma);
|
||||
blurred += sample_color * weight;
|
||||
totalWeight += weight;
|
||||
}
|
||||
}
|
||||
blurred /= totalWeight;
|
||||
|
||||
// Unsharp mask = original - blurred
|
||||
vec3 mask = original.rgb - blurred.rgb;
|
||||
|
||||
// Luminance-based threshold with smooth falloff
|
||||
float lumaDelta = abs(getLuminance(original.rgb) - getLuminance(blurred.rgb));
|
||||
float thresholdScale = smoothstep(0.0, threshold, lumaDelta);
|
||||
mask *= thresholdScale;
|
||||
|
||||
// Sharpen: original + mask * amount
|
||||
vec3 sharpened = original.rgb + mask * amount;
|
||||
|
||||
fragColor0 = vec4(clamp(sharpened, 0.0, 1.0), original.a);
|
||||
}
|
||||
159
blueprints/.glsl/update_blueprints.py
Normal file
159
blueprints/.glsl/update_blueprints.py
Normal file
@@ -0,0 +1,159 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Shader Blueprint Updater
|
||||
|
||||
Syncs GLSL shader files between this folder and blueprint JSON files.
|
||||
|
||||
File naming convention:
|
||||
{Blueprint Name}_{node_id}.frag
|
||||
|
||||
Usage:
|
||||
python update_blueprints.py extract # Extract shaders from JSONs to here
|
||||
python update_blueprints.py patch # Patch shaders back into JSONs
|
||||
python update_blueprints.py # Same as patch (default)
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
GLSL_DIR = Path(__file__).parent
|
||||
BLUEPRINTS_DIR = GLSL_DIR.parent
|
||||
|
||||
|
||||
def get_blueprint_files():
|
||||
"""Get all blueprint JSON files."""
|
||||
return sorted(BLUEPRINTS_DIR.glob("*.json"))
|
||||
|
||||
|
||||
def sanitize_filename(name):
|
||||
"""Convert blueprint name to safe filename."""
|
||||
return re.sub(r'[^\w\-]', '_', name)
|
||||
|
||||
|
||||
def extract_shaders():
|
||||
"""Extract all shaders from blueprint JSONs to this folder."""
|
||||
extracted = 0
|
||||
for json_path in get_blueprint_files():
|
||||
blueprint_name = json_path.stem
|
||||
|
||||
try:
|
||||
with open(json_path, 'r') as f:
|
||||
data = json.load(f)
|
||||
except (json.JSONDecodeError, IOError) as e:
|
||||
logger.warning("Skipping %s: %s", json_path.name, e)
|
||||
continue
|
||||
|
||||
# Find GLSLShader nodes in subgraphs
|
||||
for subgraph in data.get('definitions', {}).get('subgraphs', []):
|
||||
for node in subgraph.get('nodes', []):
|
||||
if node.get('type') == 'GLSLShader':
|
||||
node_id = node.get('id')
|
||||
widgets = node.get('widgets_values', [])
|
||||
|
||||
# Find shader code (first string that looks like GLSL)
|
||||
for widget in widgets:
|
||||
if isinstance(widget, str) and widget.startswith('#version'):
|
||||
safe_name = sanitize_filename(blueprint_name)
|
||||
frag_name = f"{safe_name}_{node_id}.frag"
|
||||
frag_path = GLSL_DIR / frag_name
|
||||
|
||||
with open(frag_path, 'w') as f:
|
||||
f.write(widget)
|
||||
|
||||
logger.info(" Extracted: %s", frag_name)
|
||||
extracted += 1
|
||||
break
|
||||
|
||||
logger.info("\nExtracted %d shader(s)", extracted)
|
||||
|
||||
|
||||
def patch_shaders():
|
||||
"""Patch shaders from this folder back into blueprint JSONs."""
|
||||
# Build lookup: blueprint_name -> [(node_id, shader_code), ...]
|
||||
shader_updates = {}
|
||||
|
||||
for frag_path in sorted(GLSL_DIR.glob("*.frag")):
|
||||
# Parse filename: {blueprint_name}_{node_id}.frag
|
||||
parts = frag_path.stem.rsplit('_', 1)
|
||||
if len(parts) != 2:
|
||||
logger.warning("Skipping %s: invalid filename format", frag_path.name)
|
||||
continue
|
||||
|
||||
blueprint_name, node_id_str = parts
|
||||
|
||||
try:
|
||||
node_id = int(node_id_str)
|
||||
except ValueError:
|
||||
logger.warning("Skipping %s: invalid node_id", frag_path.name)
|
||||
continue
|
||||
|
||||
with open(frag_path, 'r') as f:
|
||||
shader_code = f.read()
|
||||
|
||||
if blueprint_name not in shader_updates:
|
||||
shader_updates[blueprint_name] = []
|
||||
shader_updates[blueprint_name].append((node_id, shader_code))
|
||||
|
||||
# Apply updates to JSON files
|
||||
patched = 0
|
||||
for json_path in get_blueprint_files():
|
||||
blueprint_name = sanitize_filename(json_path.stem)
|
||||
|
||||
if blueprint_name not in shader_updates:
|
||||
continue
|
||||
|
||||
try:
|
||||
with open(json_path, 'r') as f:
|
||||
data = json.load(f)
|
||||
except (json.JSONDecodeError, IOError) as e:
|
||||
logger.error("Error reading %s: %s", json_path.name, e)
|
||||
continue
|
||||
|
||||
modified = False
|
||||
for node_id, shader_code in shader_updates[blueprint_name]:
|
||||
# Find the node and update
|
||||
for subgraph in data.get('definitions', {}).get('subgraphs', []):
|
||||
for node in subgraph.get('nodes', []):
|
||||
if node.get('id') == node_id and node.get('type') == 'GLSLShader':
|
||||
widgets = node.get('widgets_values', [])
|
||||
if len(widgets) > 0 and widgets[0] != shader_code:
|
||||
widgets[0] = shader_code
|
||||
modified = True
|
||||
logger.info(" Patched: %s (node %d)", json_path.name, node_id)
|
||||
patched += 1
|
||||
|
||||
if modified:
|
||||
with open(json_path, 'w') as f:
|
||||
json.dump(data, f)
|
||||
|
||||
if patched == 0:
|
||||
logger.info("No changes to apply.")
|
||||
else:
|
||||
logger.info("\nPatched %d shader(s)", patched)
|
||||
|
||||
|
||||
def main():
|
||||
if len(sys.argv) < 2:
|
||||
command = "patch"
|
||||
else:
|
||||
command = sys.argv[1].lower()
|
||||
|
||||
if command == "extract":
|
||||
logger.info("Extracting shaders from blueprints...")
|
||||
extract_shaders()
|
||||
elif command in ("patch", "update", "apply"):
|
||||
logger.info("Patching shaders into blueprints...")
|
||||
patch_shaders()
|
||||
else:
|
||||
logger.info(__doc__)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
1
blueprints/Brightness and Contrast.json
Normal file
1
blueprints/Brightness and Contrast.json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Chromatic Aberration.json
Normal file
1
blueprints/Chromatic Aberration.json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Color Adjustment.json
Normal file
1
blueprints/Color Adjustment.json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Edge-Preserving Blur.json
Normal file
1
blueprints/Edge-Preserving Blur.json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Film Grain.json
Normal file
1
blueprints/Film Grain.json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Glow.json
Normal file
1
blueprints/Glow.json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Hue and Saturation.json
Normal file
1
blueprints/Hue and Saturation.json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Image Blur.json
Normal file
1
blueprints/Image Blur.json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Image Channels.json
Normal file
1
blueprints/Image Channels.json
Normal file
@@ -0,0 +1 @@
|
||||
{"revision": 0, "last_node_id": 29, "last_link_id": 0, "nodes": [{"id": 29, "type": "4c9d6ea4-b912-40e5-8766-6793a9758c53", "pos": [1970, -230], "size": [180, 86], "flags": {}, "order": 5, "mode": 0, "inputs": [{"label": "image", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": null}], "outputs": [{"label": "R", "localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": []}, {"label": "G", "localized_name": "IMAGE1", "name": "IMAGE1", "type": "IMAGE", "links": []}, {"label": "B", "localized_name": "IMAGE2", "name": "IMAGE2", "type": "IMAGE", "links": []}, {"label": "A", "localized_name": "IMAGE3", "name": "IMAGE3", "type": "IMAGE", "links": []}], "title": "Image Channels", "properties": {"proxyWidgets": []}, "widgets_values": []}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "4c9d6ea4-b912-40e5-8766-6793a9758c53", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 28, "lastLinkId": 39, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Image Channels", "inputNode": {"id": -10, "bounding": [1820, -185, 120, 60]}, "outputNode": {"id": -20, "bounding": [2460, -215, 120, 120]}, "inputs": [{"id": "3522932b-2d86-4a1f-a02a-cb29f3a9d7fe", "name": "images.image0", "type": "IMAGE", "linkIds": [39], "localized_name": "images.image0", "label": "image", "pos": [1920, -165]}], "outputs": [{"id": "605cb9c3-b065-4d9b-81d2-3ec331889b2b", "name": "IMAGE0", "type": "IMAGE", "linkIds": [26], "localized_name": "IMAGE0", "label": "R", "pos": [2480, -195]}, {"id": "fb44a77e-0522-43e9-9527-82e7465b3596", "name": "IMAGE1", "type": "IMAGE", "linkIds": [27], "localized_name": "IMAGE1", "label": "G", "pos": [2480, -175]}, {"id": "81460ee6-0131-402a-874f-6bf3001fc4ff", "name": "IMAGE2", "type": "IMAGE", "linkIds": [28], "localized_name": "IMAGE2", "label": "B", "pos": [2480, -155]}, {"id": "ae690246-80d4-4951-b1d9-9306d8a77417", "name": "IMAGE3", "type": "IMAGE", "linkIds": [29], "localized_name": "IMAGE3", "label": "A", "pos": [2480, -135]}], "widgets": [], "nodes": [{"id": 23, "type": "GLSLShader", "pos": [2000, -330], "size": [400, 172], "flags": {}, "order": 0, "mode": 0, "inputs": [{"label": "image", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": 39}, {"localized_name": "fragment_shader", "name": "fragment_shader", "type": "STRING", "widget": {"name": "fragment_shader"}, "link": null}, {"localized_name": "size_mode", "name": "size_mode", "type": "COMFY_DYNAMICCOMBO_V3", "widget": {"name": "size_mode"}, "link": null}, {"label": "image1", "localized_name": "images.image1", "name": "images.image1", "shape": 7, "type": "IMAGE", "link": null}], "outputs": [{"label": "R", "localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": [26]}, {"label": "G", "localized_name": "IMAGE1", "name": "IMAGE1", "type": "IMAGE", "links": [27]}, {"label": "B", "localized_name": "IMAGE2", "name": "IMAGE2", "type": "IMAGE", "links": [28]}, {"label": "A", "localized_name": "IMAGE3", "name": "IMAGE3", "type": "IMAGE", "links": [29]}], "properties": {"Node name for S&R": "GLSLShader"}, "widgets_values": ["#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\nlayout(location = 1) out vec4 fragColor1;\nlayout(location = 2) out vec4 fragColor2;\nlayout(location = 3) out vec4 fragColor3;\n\nvoid main() {\n vec4 color = texture(u_image0, v_texCoord);\n // Output each channel as grayscale to separate render targets\n fragColor0 = vec4(vec3(color.r), 1.0); // Red channel\n fragColor1 = vec4(vec3(color.g), 1.0); // Green channel\n fragColor2 = vec4(vec3(color.b), 1.0); // Blue channel\n fragColor3 = vec4(vec3(color.a), 1.0); // Alpha channel\n}\n", "from_input"]}], "groups": [], "links": [{"id": 39, "origin_id": -10, "origin_slot": 0, "target_id": 23, "target_slot": 0, "type": "IMAGE"}, {"id": 26, "origin_id": 23, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "IMAGE"}, {"id": 27, "origin_id": 23, "origin_slot": 1, "target_id": -20, "target_slot": 1, "type": "IMAGE"}, {"id": 28, "origin_id": 23, "origin_slot": 2, "target_id": -20, "target_slot": 2, "type": "IMAGE"}, {"id": 29, "origin_id": 23, "origin_slot": 3, "target_id": -20, "target_slot": 3, "type": "IMAGE"}], "extra": {"workflowRendererVersion": "LG"}}]}}
|
||||
1
blueprints/Image Levels.json
Normal file
1
blueprints/Image Levels.json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Sharpen.json
Normal file
1
blueprints/Sharpen.json
Normal file
@@ -0,0 +1 @@
|
||||
{"revision":0,"last_node_id":25,"last_link_id":0,"nodes":[{"id":25,"type":"621ba4e2-22a8-482d-a369-023753198b7b","pos":[4610,-790],"size":[230,58],"flags":{},"order":4,"mode":0,"inputs":[{"label":"image","localized_name":"images.image0","name":"images.image0","type":"IMAGE","link":null}],"outputs":[{"label":"IMAGE","localized_name":"IMAGE0","name":"IMAGE0","type":"IMAGE","links":[]}],"title":"Sharpen","properties":{"proxyWidgets":[["24","value"]]},"widgets_values":[]}],"links":[],"version":0.4,"definitions":{"subgraphs":[{"id":"621ba4e2-22a8-482d-a369-023753198b7b","version":1,"state":{"lastGroupId":0,"lastNodeId":24,"lastLinkId":36,"lastRerouteId":0},"revision":0,"config":{},"name":"Sharpen","inputNode":{"id":-10,"bounding":[4090,-825,120,60]},"outputNode":{"id":-20,"bounding":[5150,-825,120,60]},"inputs":[{"id":"37011fb7-14b7-4e0e-b1a0-6a02e8da1fd7","name":"images.image0","type":"IMAGE","linkIds":[34],"localized_name":"images.image0","label":"image","pos":[4190,-805]}],"outputs":[{"id":"e9182b3f-635c-4cd4-a152-4b4be17ae4b9","name":"IMAGE0","type":"IMAGE","linkIds":[35],"localized_name":"IMAGE0","label":"IMAGE","pos":[5170,-805]}],"widgets":[],"nodes":[{"id":24,"type":"PrimitiveFloat","pos":[4280,-1240],"size":[270,58],"flags":{},"order":0,"mode":0,"inputs":[{"label":"strength","localized_name":"value","name":"value","type":"FLOAT","widget":{"name":"value"},"link":null}],"outputs":[{"localized_name":"FLOAT","name":"FLOAT","type":"FLOAT","links":[36]}],"properties":{"Node name for S&R":"PrimitiveFloat","min":0,"max":3,"precision":2,"step":0.05},"widgets_values":[0.5]},{"id":23,"type":"GLSLShader","pos":[4570,-1240],"size":[370,192],"flags":{},"order":1,"mode":0,"inputs":[{"label":"image0","localized_name":"images.image0","name":"images.image0","type":"IMAGE","link":34},{"label":"image1","localized_name":"images.image1","name":"images.image1","shape":7,"type":"IMAGE","link":null},{"label":"u_float0","localized_name":"floats.u_float0","name":"floats.u_float0","shape":7,"type":"FLOAT","link":36},{"label":"u_float1","localized_name":"floats.u_float1","name":"floats.u_float1","shape":7,"type":"FLOAT","link":null},{"label":"u_int0","localized_name":"ints.u_int0","name":"ints.u_int0","shape":7,"type":"INT","link":null},{"localized_name":"fragment_shader","name":"fragment_shader","type":"STRING","widget":{"name":"fragment_shader"},"link":null},{"localized_name":"size_mode","name":"size_mode","type":"COMFY_DYNAMICCOMBO_V3","widget":{"name":"size_mode"},"link":null}],"outputs":[{"localized_name":"IMAGE0","name":"IMAGE0","type":"IMAGE","links":[35]},{"localized_name":"IMAGE1","name":"IMAGE1","type":"IMAGE","links":null},{"localized_name":"IMAGE2","name":"IMAGE2","type":"IMAGE","links":null},{"localized_name":"IMAGE3","name":"IMAGE3","type":"IMAGE","links":null}],"properties":{"Node name for S&R":"GLSLShader"},"widgets_values":["#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\nuniform vec2 u_resolution;\nuniform float u_float0; // strength [0.0 – 2.0] typical: 0.3–1.0\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\n\nvoid main() {\n vec2 texel = 1.0 / u_resolution;\n \n // Sample center and neighbors\n vec4 center = texture(u_image0, v_texCoord);\n vec4 top = texture(u_image0, v_texCoord + vec2( 0.0, -texel.y));\n vec4 bottom = texture(u_image0, v_texCoord + vec2( 0.0, texel.y));\n vec4 left = texture(u_image0, v_texCoord + vec2(-texel.x, 0.0));\n vec4 right = texture(u_image0, v_texCoord + vec2( texel.x, 0.0));\n \n // Edge enhancement (Laplacian)\n vec4 edges = center * 4.0 - top - bottom - left - right;\n \n // Add edges back scaled by strength\n vec4 sharpened = center + edges * u_float0;\n \n fragColor0 = vec4(clamp(sharpened.rgb, 0.0, 1.0), center.a);\n}","from_input"]}],"groups":[],"links":[{"id":36,"origin_id":24,"origin_slot":0,"target_id":23,"target_slot":2,"type":"FLOAT"},{"id":34,"origin_id":-10,"origin_slot":0,"target_id":23,"target_slot":0,"type":"IMAGE"},{"id":35,"origin_id":23,"origin_slot":0,"target_id":-20,"target_slot":0,"type":"IMAGE"}],"extra":{"workflowRendererVersion":"LG"}}]}}
|
||||
1
blueprints/Unsharp Mask.json
Normal file
1
blueprints/Unsharp Mask.json
Normal file
File diff suppressed because one or more lines are too long
0
blueprints/put_blueprints_here
Normal file
0
blueprints/put_blueprints_here
Normal file
@@ -97,6 +97,13 @@ class LatentPreviewMethod(enum.Enum):
|
||||
Latent2RGB = "latent2rgb"
|
||||
TAESD = "taesd"
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, value: str):
|
||||
for member in cls:
|
||||
if member.value == value:
|
||||
return member
|
||||
return None
|
||||
|
||||
parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
|
||||
|
||||
parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")
|
||||
@@ -224,6 +231,7 @@ database_default_path = os.path.abspath(
|
||||
os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db")
|
||||
)
|
||||
parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
|
||||
parser.add_argument("--disable-assets-autoscan", action="store_true", help="Disable asset scanning on startup for database synchronization.")
|
||||
|
||||
if comfy.options.args_parsing:
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -1,6 +1,59 @@
|
||||
import torch
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
import comfy.ops
|
||||
import math
|
||||
|
||||
def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
|
||||
image = image[:, :, :, :3] if image.shape[3] > 3 else image
|
||||
mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
|
||||
std = torch.tensor(std, device=image.device, dtype=image.dtype)
|
||||
image = image.movedim(-1, 1)
|
||||
if not (image.shape[2] == size and image.shape[3] == size):
|
||||
if crop:
|
||||
scale = (size / min(image.shape[2], image.shape[3]))
|
||||
scale_size = (round(scale * image.shape[2]), round(scale * image.shape[3]))
|
||||
else:
|
||||
scale_size = (size, size)
|
||||
|
||||
image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
|
||||
h = (image.shape[2] - size)//2
|
||||
w = (image.shape[3] - size)//2
|
||||
image = image[:,:,h:h+size,w:w+size]
|
||||
image = torch.clip((255. * image), 0, 255).round() / 255.0
|
||||
return (image - mean.view([3,1,1])) / std.view([3,1,1])
|
||||
|
||||
def siglip2_flex_calc_resolution(oh, ow, patch_size, max_num_patches, eps=1e-5):
|
||||
def scale_dim(size, scale):
|
||||
scaled = math.ceil(size * scale / patch_size) * patch_size
|
||||
return max(patch_size, int(scaled))
|
||||
|
||||
# Binary search for optimal scale
|
||||
lo, hi = eps / 10, 100.0
|
||||
while hi - lo >= eps:
|
||||
mid = (lo + hi) / 2
|
||||
h, w = scale_dim(oh, mid), scale_dim(ow, mid)
|
||||
if (h // patch_size) * (w // patch_size) <= max_num_patches:
|
||||
lo = mid
|
||||
else:
|
||||
hi = mid
|
||||
|
||||
return scale_dim(oh, lo), scale_dim(ow, lo)
|
||||
|
||||
def siglip2_preprocess(image, size, patch_size, num_patches, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], crop=True):
|
||||
if size > 0:
|
||||
return clip_preprocess(image, size=size, mean=mean, std=std, crop=crop)
|
||||
|
||||
image = image[:, :, :, :3] if image.shape[3] > 3 else image
|
||||
mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
|
||||
std = torch.tensor(std, device=image.device, dtype=image.dtype)
|
||||
image = image.movedim(-1, 1)
|
||||
|
||||
b, c, h, w = image.shape
|
||||
h, w = siglip2_flex_calc_resolution(h, w, patch_size, num_patches)
|
||||
|
||||
image = torch.nn.functional.interpolate(image, size=(h, w), mode="bilinear", antialias=True)
|
||||
image = torch.clip((255. * image), 0, 255).round() / 255.0
|
||||
return (image - mean.view([3, 1, 1])) / std.view([3, 1, 1])
|
||||
|
||||
class CLIPAttention(torch.nn.Module):
|
||||
def __init__(self, embed_dim, heads, dtype, device, operations):
|
||||
@@ -156,6 +209,27 @@ class CLIPTextModel(torch.nn.Module):
|
||||
out = self.text_projection(x[2])
|
||||
return (x[0], x[1], out, x[2])
|
||||
|
||||
def siglip2_pos_embed(embed_weight, embeds, orig_shape):
|
||||
embed_weight_len = round(embed_weight.shape[0] ** 0.5)
|
||||
embed_weight = comfy.ops.cast_to_input(embed_weight, embeds).movedim(1, 0).reshape(1, -1, embed_weight_len, embed_weight_len)
|
||||
embed_weight = torch.nn.functional.interpolate(embed_weight, size=orig_shape, mode="bilinear", align_corners=False, antialias=True)
|
||||
embed_weight = embed_weight.reshape(-1, embed_weight.shape[-2] * embed_weight.shape[-1]).movedim(0, 1)
|
||||
return embeds + embed_weight
|
||||
|
||||
class Siglip2Embeddings(torch.nn.Module):
|
||||
def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, model_type="", num_patches=None, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.patch_embedding = operations.Linear(num_channels * patch_size * patch_size, embed_dim, dtype=dtype, device=device)
|
||||
self.position_embedding = operations.Embedding(num_patches, embed_dim, dtype=dtype, device=device)
|
||||
self.patch_size = patch_size
|
||||
|
||||
def forward(self, pixel_values):
|
||||
b, c, h, w = pixel_values.shape
|
||||
img = pixel_values.movedim(1, -1).reshape(b, h // self.patch_size, self.patch_size, w // self.patch_size, self.patch_size, c)
|
||||
img = img.permute(0, 1, 3, 2, 4, 5)
|
||||
img = img.reshape(b, img.shape[1] * img.shape[2], -1)
|
||||
img = self.patch_embedding(img)
|
||||
return siglip2_pos_embed(self.position_embedding.weight, img, (h // self.patch_size, w // self.patch_size))
|
||||
|
||||
class CLIPVisionEmbeddings(torch.nn.Module):
|
||||
def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, model_type="", dtype=None, device=None, operations=None):
|
||||
@@ -199,8 +273,11 @@ class CLIPVision(torch.nn.Module):
|
||||
intermediate_activation = config_dict["hidden_act"]
|
||||
model_type = config_dict["model_type"]
|
||||
|
||||
self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, dtype=dtype, device=device, operations=operations)
|
||||
if model_type == "siglip_vision_model":
|
||||
if model_type in ["siglip2_vision_model"]:
|
||||
self.embeddings = Siglip2Embeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, num_patches=config_dict.get("num_patches", None), dtype=dtype, device=device, operations=operations)
|
||||
else:
|
||||
self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, dtype=dtype, device=device, operations=operations)
|
||||
if model_type in ["siglip_vision_model", "siglip2_vision_model"]:
|
||||
self.pre_layrnorm = lambda a: a
|
||||
self.output_layernorm = True
|
||||
else:
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from .utils import load_torch_file, transformers_convert, state_dict_prefix_replace
|
||||
import os
|
||||
import torch
|
||||
import json
|
||||
import logging
|
||||
|
||||
@@ -17,28 +16,12 @@ class Output:
|
||||
def __setitem__(self, key, item):
|
||||
setattr(self, key, item)
|
||||
|
||||
def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
|
||||
image = image[:, :, :, :3] if image.shape[3] > 3 else image
|
||||
mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
|
||||
std = torch.tensor(std, device=image.device, dtype=image.dtype)
|
||||
image = image.movedim(-1, 1)
|
||||
if not (image.shape[2] == size and image.shape[3] == size):
|
||||
if crop:
|
||||
scale = (size / min(image.shape[2], image.shape[3]))
|
||||
scale_size = (round(scale * image.shape[2]), round(scale * image.shape[3]))
|
||||
else:
|
||||
scale_size = (size, size)
|
||||
|
||||
image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
|
||||
h = (image.shape[2] - size)//2
|
||||
w = (image.shape[3] - size)//2
|
||||
image = image[:,:,h:h+size,w:w+size]
|
||||
image = torch.clip((255. * image), 0, 255).round() / 255.0
|
||||
return (image - mean.view([3,1,1])) / std.view([3,1,1])
|
||||
clip_preprocess = comfy.clip_model.clip_preprocess # Prevent some stuff from breaking, TODO: remove eventually
|
||||
|
||||
IMAGE_ENCODERS = {
|
||||
"clip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
|
||||
"siglip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
|
||||
"siglip2_vision_model": comfy.clip_model.CLIPVisionModelProjection,
|
||||
"dinov2": comfy.image_encoders.dino2.Dinov2Model,
|
||||
}
|
||||
|
||||
@@ -50,9 +33,10 @@ class ClipVisionModel():
|
||||
self.image_size = config.get("image_size", 224)
|
||||
self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
|
||||
self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711])
|
||||
model_type = config.get("model_type", "clip_vision_model")
|
||||
model_class = IMAGE_ENCODERS.get(model_type)
|
||||
if model_type == "siglip_vision_model":
|
||||
self.model_type = config.get("model_type", "clip_vision_model")
|
||||
self.config = config.copy()
|
||||
model_class = IMAGE_ENCODERS.get(self.model_type)
|
||||
if self.model_type == "siglip_vision_model":
|
||||
self.return_all_hidden_states = True
|
||||
else:
|
||||
self.return_all_hidden_states = False
|
||||
@@ -73,12 +57,16 @@ class ClipVisionModel():
|
||||
|
||||
def encode_image(self, image, crop=True):
|
||||
comfy.model_management.load_model_gpu(self.patcher)
|
||||
pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
|
||||
if self.model_type == "siglip2_vision_model":
|
||||
pixel_values = comfy.clip_model.siglip2_preprocess(image.to(self.load_device), size=self.image_size, patch_size=self.config.get("patch_size", 16), num_patches=self.config.get("num_patches", 256), mean=self.image_mean, std=self.image_std, crop=crop).float()
|
||||
else:
|
||||
pixel_values = comfy.clip_model.clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
|
||||
out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2)
|
||||
|
||||
outputs = Output()
|
||||
outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device())
|
||||
outputs["image_embeds"] = out[2].to(comfy.model_management.intermediate_device())
|
||||
outputs["image_sizes"] = [pixel_values.shape[1:]] * pixel_values.shape[0]
|
||||
if self.return_all_hidden_states:
|
||||
all_hs = out[1].to(comfy.model_management.intermediate_device())
|
||||
outputs["penultimate_hidden_states"] = all_hs[:, -2]
|
||||
@@ -125,10 +113,14 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
|
||||
elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd:
|
||||
embed_shape = sd["vision_model.embeddings.position_embedding.weight"].shape[0]
|
||||
if sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0] == 1152:
|
||||
if embed_shape == 729:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json")
|
||||
elif embed_shape == 1024:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_512.json")
|
||||
patch_embedding_shape = sd["vision_model.embeddings.patch_embedding.weight"].shape
|
||||
if len(patch_embedding_shape) == 2:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip2_base_naflex.json")
|
||||
else:
|
||||
if embed_shape == 729:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json")
|
||||
elif embed_shape == 1024:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_512.json")
|
||||
elif embed_shape == 577:
|
||||
if "multi_modal_projector.linear_1.bias" in sd:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336_llava.json")
|
||||
|
||||
14
comfy/clip_vision_siglip2_base_naflex.json
Normal file
14
comfy/clip_vision_siglip2_base_naflex.json
Normal file
@@ -0,0 +1,14 @@
|
||||
{
|
||||
"num_channels": 3,
|
||||
"hidden_act": "gelu_pytorch_tanh",
|
||||
"hidden_size": 1152,
|
||||
"image_size": -1,
|
||||
"intermediate_size": 4304,
|
||||
"model_type": "siglip2_vision_model",
|
||||
"num_attention_heads": 16,
|
||||
"num_hidden_layers": 27,
|
||||
"patch_size": 16,
|
||||
"num_patches": 256,
|
||||
"image_mean": [0.5, 0.5, 0.5],
|
||||
"image_std": [0.5, 0.5, 0.5]
|
||||
}
|
||||
@@ -236,6 +236,8 @@ class ComfyNodeABC(ABC):
|
||||
"""Flags a node as experimental, informing users that it may change or not work as expected."""
|
||||
DEPRECATED: bool
|
||||
"""Flags a node as deprecated, indicating to users that they should find alternatives to this node."""
|
||||
DEV_ONLY: bool
|
||||
"""Flags a node as dev-only, hiding it from search/menus unless dev mode is enabled."""
|
||||
API_NODE: Optional[bool]
|
||||
"""Flags a node as an API node. See: https://docs.comfy.org/tutorials/api-nodes/overview."""
|
||||
|
||||
|
||||
@@ -87,6 +87,7 @@ class IndexListCallbacks:
|
||||
COMBINE_CONTEXT_WINDOW_RESULTS = "combine_context_window_results"
|
||||
EXECUTE_START = "execute_start"
|
||||
EXECUTE_CLEANUP = "execute_cleanup"
|
||||
RESIZE_COND_ITEM = "resize_cond_item"
|
||||
|
||||
def init_callbacks(self):
|
||||
return {}
|
||||
@@ -142,7 +143,7 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
# if multiple conds, split based on primary region
|
||||
if self.split_conds_to_windows and len(cond_in) > 1:
|
||||
region = window.get_region_index(len(cond_in))
|
||||
logging.info(f"Splitting conds to windows; using region {region} for window {window[0]}-{window[-1]} with center ratio {window.center_ratio:.3f}")
|
||||
logging.info(f"Splitting conds to windows; using region {region} for window {window.index_list[0]}-{window.index_list[-1]} with center ratio {window.center_ratio:.3f}")
|
||||
cond_in = [cond_in[region]]
|
||||
# cond object is a list containing a dict - outer list is irrelevant, so just loop through it
|
||||
for actual_cond in cond_in:
|
||||
@@ -166,6 +167,18 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
new_cond_item = cond_item.copy()
|
||||
# when in dictionary, look for tensors and CONDCrossAttn [comfy/conds.py] (has cond attr that is a tensor)
|
||||
for cond_key, cond_value in new_cond_item.items():
|
||||
# Allow callbacks to handle custom conditioning items
|
||||
handled = False
|
||||
for callback in comfy.patcher_extension.get_all_callbacks(
|
||||
IndexListCallbacks.RESIZE_COND_ITEM, self.callbacks
|
||||
):
|
||||
result = callback(cond_key, cond_value, window, x_in, device, new_cond_item)
|
||||
if result is not None:
|
||||
new_cond_item[cond_key] = result
|
||||
handled = True
|
||||
break
|
||||
if handled:
|
||||
continue
|
||||
if isinstance(cond_value, torch.Tensor):
|
||||
if (self.dim < cond_value.ndim and cond_value(self.dim) == x_in.size(self.dim)) or \
|
||||
(cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim)):
|
||||
@@ -175,6 +188,12 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
audio_cond = cond_value.cond
|
||||
if audio_cond.ndim > 1 and audio_cond.size(1) == x_in.size(self.dim):
|
||||
new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(audio_cond, device, dim=1))
|
||||
# Handle vace_context (temporal dim is 3)
|
||||
elif cond_key == "vace_context" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
|
||||
vace_cond = cond_value.cond
|
||||
if vace_cond.ndim >= 4 and vace_cond.size(3) == x_in.size(self.dim):
|
||||
sliced_vace = window.get_tensor(vace_cond, device, dim=3, retain_index_list=self.cond_retain_index_list)
|
||||
new_cond_item[cond_key] = cond_value._copy_with(sliced_vace)
|
||||
# if has cond that is a Tensor, check if needs to be subset
|
||||
elif hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
|
||||
if (self.dim < cond_value.cond.ndim and cond_value.cond.size(self.dim) == x_in.size(self.dim)) or \
|
||||
|
||||
144
comfy/float.py
144
comfy/float.py
@@ -65,3 +65,147 @@ def stochastic_rounding(value, dtype, seed=0):
|
||||
return output
|
||||
|
||||
return value.to(dtype=dtype)
|
||||
|
||||
|
||||
# TODO: improve this?
|
||||
def stochastic_float_to_fp4_e2m1(x, generator):
|
||||
orig_shape = x.shape
|
||||
sign = torch.signbit(x).to(torch.uint8)
|
||||
|
||||
exp = torch.floor(torch.log2(x.abs()) + 1.0).clamp(0, 3)
|
||||
x += (torch.rand(x.size(), dtype=x.dtype, layout=x.layout, device=x.device, generator=generator) - 0.5) * (2 ** (exp - 2.0)) * 1.25
|
||||
|
||||
x = x.abs()
|
||||
exp = torch.floor(torch.log2(x) + 1.1925).clamp(0, 3)
|
||||
|
||||
mantissa = torch.where(
|
||||
exp > 0,
|
||||
(x / (2.0 ** (exp - 1)) - 1.0) * 2.0,
|
||||
(x * 2.0),
|
||||
out=x
|
||||
).round().to(torch.uint8)
|
||||
del x
|
||||
|
||||
exp = exp.to(torch.uint8)
|
||||
|
||||
fp4 = (sign << 3) | (exp << 1) | mantissa
|
||||
del sign, exp, mantissa
|
||||
|
||||
fp4_flat = fp4.view(-1)
|
||||
packed = (fp4_flat[0::2] << 4) | fp4_flat[1::2]
|
||||
return packed.reshape(list(orig_shape)[:-1] + [-1])
|
||||
|
||||
|
||||
def to_blocked(input_matrix, flatten: bool = True) -> torch.Tensor:
|
||||
"""
|
||||
Rearrange a large matrix by breaking it into blocks and applying the rearrangement pattern.
|
||||
See:
|
||||
https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
|
||||
|
||||
Args:
|
||||
input_matrix: Input tensor of shape (H, W)
|
||||
Returns:
|
||||
Rearranged tensor of shape (32*ceil_div(H,128), 16*ceil_div(W,4))
|
||||
"""
|
||||
|
||||
def ceil_div(a, b):
|
||||
return (a + b - 1) // b
|
||||
|
||||
rows, cols = input_matrix.shape
|
||||
n_row_blocks = ceil_div(rows, 128)
|
||||
n_col_blocks = ceil_div(cols, 4)
|
||||
|
||||
# Calculate the padded shape
|
||||
padded_rows = n_row_blocks * 128
|
||||
padded_cols = n_col_blocks * 4
|
||||
|
||||
padded = input_matrix
|
||||
if (rows, cols) != (padded_rows, padded_cols):
|
||||
padded = torch.zeros(
|
||||
(padded_rows, padded_cols),
|
||||
device=input_matrix.device,
|
||||
dtype=input_matrix.dtype,
|
||||
)
|
||||
padded[:rows, :cols] = input_matrix
|
||||
|
||||
# Rearrange the blocks
|
||||
blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
|
||||
rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)
|
||||
if flatten:
|
||||
return rearranged.flatten()
|
||||
|
||||
return rearranged.reshape(padded_rows, padded_cols)
|
||||
|
||||
|
||||
def stochastic_round_quantize_nvfp4_block(x, per_tensor_scale, generator):
|
||||
F4_E2M1_MAX = 6.0
|
||||
F8_E4M3_MAX = 448.0
|
||||
|
||||
orig_shape = x.shape
|
||||
|
||||
block_size = 16
|
||||
|
||||
x = x.reshape(orig_shape[0], -1, block_size)
|
||||
scaled_block_scales_fp8 = torch.clamp(((torch.amax(torch.abs(x), dim=-1)) / F4_E2M1_MAX) / per_tensor_scale.to(x.dtype), max=F8_E4M3_MAX).to(torch.float8_e4m3fn)
|
||||
x = x / (per_tensor_scale.to(x.dtype) * scaled_block_scales_fp8.to(x.dtype)).unsqueeze(-1)
|
||||
|
||||
x = x.view(orig_shape).nan_to_num()
|
||||
data_lp = stochastic_float_to_fp4_e2m1(x, generator=generator)
|
||||
return data_lp, scaled_block_scales_fp8
|
||||
|
||||
|
||||
def stochastic_round_quantize_nvfp4(x, per_tensor_scale, pad_16x, seed=0):
|
||||
def roundup(x: int, multiple: int) -> int:
|
||||
"""Round up x to the nearest multiple."""
|
||||
return ((x + multiple - 1) // multiple) * multiple
|
||||
|
||||
generator = torch.Generator(device=x.device)
|
||||
generator.manual_seed(seed)
|
||||
|
||||
# Handle padding
|
||||
if pad_16x:
|
||||
rows, cols = x.shape
|
||||
padded_rows = roundup(rows, 16)
|
||||
padded_cols = roundup(cols, 16)
|
||||
if padded_rows != rows or padded_cols != cols:
|
||||
x = torch.nn.functional.pad(x, (0, padded_cols - cols, 0, padded_rows - rows))
|
||||
|
||||
x, blocked_scaled = stochastic_round_quantize_nvfp4_block(x, per_tensor_scale, generator)
|
||||
return x, to_blocked(blocked_scaled, flatten=False)
|
||||
|
||||
|
||||
def stochastic_round_quantize_nvfp4_by_block(x, per_tensor_scale, pad_16x, seed=0, block_size=4096 * 4096):
|
||||
def roundup(x: int, multiple: int) -> int:
|
||||
"""Round up x to the nearest multiple."""
|
||||
return ((x + multiple - 1) // multiple) * multiple
|
||||
|
||||
orig_shape = x.shape
|
||||
|
||||
# Handle padding
|
||||
if pad_16x:
|
||||
rows, cols = x.shape
|
||||
padded_rows = roundup(rows, 16)
|
||||
padded_cols = roundup(cols, 16)
|
||||
if padded_rows != rows or padded_cols != cols:
|
||||
x = torch.nn.functional.pad(x, (0, padded_cols - cols, 0, padded_rows - rows))
|
||||
# Note: We update orig_shape because the output tensor logic below assumes x.shape matches
|
||||
# what we want to produce. If we pad here, we want the padded output.
|
||||
orig_shape = x.shape
|
||||
|
||||
orig_shape = list(orig_shape)
|
||||
|
||||
output_fp4 = torch.empty(orig_shape[:-1] + [orig_shape[-1] // 2], dtype=torch.uint8, device=x.device)
|
||||
output_block = torch.empty(orig_shape[:-1] + [orig_shape[-1] // 16], dtype=torch.float8_e4m3fn, device=x.device)
|
||||
|
||||
generator = torch.Generator(device=x.device)
|
||||
generator.manual_seed(seed)
|
||||
|
||||
num_slices = max(1, (x.numel() / block_size))
|
||||
slice_size = max(1, (round(x.shape[0] / num_slices)))
|
||||
|
||||
for i in range(0, x.shape[0], slice_size):
|
||||
fp4, block = stochastic_round_quantize_nvfp4_block(x[i: i + slice_size], per_tensor_scale, generator=generator)
|
||||
output_fp4[i:i + slice_size].copy_(fp4)
|
||||
output_block[i:i + slice_size].copy_(block)
|
||||
|
||||
return output_fp4, to_blocked(output_block, flatten=False)
|
||||
|
||||
@@ -527,7 +527,8 @@ class HookKeyframeGroup:
|
||||
if self._current_keyframe.get_effective_guarantee_steps(max_sigma) > 0:
|
||||
break
|
||||
# if eval_c is outside the percent range, stop looking further
|
||||
else: break
|
||||
else:
|
||||
break
|
||||
# update steps current context is used
|
||||
self._current_used_steps += 1
|
||||
# update current timestep this was performed on
|
||||
|
||||
@@ -74,6 +74,9 @@ def get_ancestral_step(sigma_from, sigma_to, eta=1.):
|
||||
|
||||
def default_noise_sampler(x, seed=None):
|
||||
if seed is not None:
|
||||
if x.device == torch.device("cpu"):
|
||||
seed += 1
|
||||
|
||||
generator = torch.Generator(device=x.device)
|
||||
generator.manual_seed(seed)
|
||||
else:
|
||||
@@ -1557,10 +1560,13 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5):
|
||||
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5, solver_type="phi_1"):
|
||||
"""SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2.
|
||||
arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023)
|
||||
"""
|
||||
if solver_type not in {"phi_1", "phi_2"}:
|
||||
raise ValueError("solver_type must be 'phi_1' or 'phi_2'")
|
||||
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
@@ -1600,8 +1606,14 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
|
||||
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
|
||||
|
||||
# Step 2
|
||||
denoised_d = torch.lerp(denoised, denoised_2, fac)
|
||||
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
|
||||
if solver_type == "phi_1":
|
||||
denoised_d = torch.lerp(denoised, denoised_2, fac)
|
||||
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
|
||||
elif solver_type == "phi_2":
|
||||
b2 = ei_h_phi_2(-h_eta) / r
|
||||
b1 = ei_h_phi_1(-h_eta) - b2
|
||||
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * (b1 * denoised + b2 * denoised_2)
|
||||
|
||||
if inject_noise:
|
||||
segment_factor = (r - 1) * h * eta
|
||||
sde_noise = sde_noise * segment_factor.exp()
|
||||
@@ -1609,6 +1621,17 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
|
||||
x = x + sde_noise * sigmas[i + 1] * s_noise
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_exp_heun_2_x0(model, x, sigmas, extra_args=None, callback=None, disable=None, solver_type="phi_2"):
|
||||
"""Deterministic exponential Heun second order method in data prediction (x0) and logSNR time."""
|
||||
return sample_seeds_2(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=0.0, s_noise=0.0, noise_sampler=None, r=1.0, solver_type=solver_type)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_exp_heun_2_x0_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type="phi_2"):
|
||||
"""Stochastic exponential Heun second order method in data prediction (x0) and logSNR time."""
|
||||
return sample_seeds_2(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=1.0, solver_type=solver_type)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3):
|
||||
@@ -1756,7 +1779,7 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
|
||||
# Predictor
|
||||
if sigmas[i + 1] == 0:
|
||||
# Denoising step
|
||||
x = denoised
|
||||
x_pred = denoised
|
||||
else:
|
||||
tau_t = tau_func(sigmas[i + 1])
|
||||
curr_lambdas = lambdas[i - predictor_order_used + 1:i + 1]
|
||||
@@ -1777,7 +1800,7 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
|
||||
if tau_t > 0 and s_noise > 0:
|
||||
noise = noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * tau_t ** 2 * h).expm1().neg().sqrt() * s_noise
|
||||
x_pred = x_pred + noise
|
||||
return x
|
||||
return x_pred
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
|
||||
@@ -8,6 +8,7 @@ class LatentFormat:
|
||||
latent_rgb_factors_bias = None
|
||||
latent_rgb_factors_reshape = None
|
||||
taesd_decoder_name = None
|
||||
spacial_downscale_ratio = 8
|
||||
|
||||
def process_in(self, latent):
|
||||
return latent * self.scale_factor
|
||||
@@ -80,6 +81,7 @@ class SD_X4(LatentFormat):
|
||||
|
||||
class SC_Prior(LatentFormat):
|
||||
latent_channels = 16
|
||||
spacial_downscale_ratio = 42
|
||||
def __init__(self):
|
||||
self.scale_factor = 1.0
|
||||
self.latent_rgb_factors = [
|
||||
@@ -102,6 +104,7 @@ class SC_Prior(LatentFormat):
|
||||
]
|
||||
|
||||
class SC_B(LatentFormat):
|
||||
spacial_downscale_ratio = 4
|
||||
def __init__(self):
|
||||
self.scale_factor = 1.0 / 0.43
|
||||
self.latent_rgb_factors = [
|
||||
@@ -181,6 +184,7 @@ class Flux(SD3):
|
||||
|
||||
class Flux2(LatentFormat):
|
||||
latent_channels = 128
|
||||
spacial_downscale_ratio = 16
|
||||
|
||||
def __init__(self):
|
||||
self.latent_rgb_factors =[
|
||||
@@ -272,6 +276,7 @@ class Mochi(LatentFormat):
|
||||
class LTXV(LatentFormat):
|
||||
latent_channels = 128
|
||||
latent_dimensions = 3
|
||||
spacial_downscale_ratio = 32
|
||||
|
||||
def __init__(self):
|
||||
self.latent_rgb_factors = [
|
||||
@@ -407,6 +412,11 @@ class LTXV(LatentFormat):
|
||||
|
||||
self.latent_rgb_factors_bias = [-0.0571, -0.1657, -0.2512]
|
||||
|
||||
class LTXAV(LTXV):
|
||||
def __init__(self):
|
||||
self.latent_rgb_factors = None
|
||||
self.latent_rgb_factors_bias = None
|
||||
|
||||
class HunyuanVideo(LatentFormat):
|
||||
latent_channels = 16
|
||||
latent_dimensions = 3
|
||||
@@ -510,6 +520,7 @@ class Wan21(LatentFormat):
|
||||
class Wan22(Wan21):
|
||||
latent_channels = 48
|
||||
latent_dimensions = 3
|
||||
spacial_downscale_ratio = 16
|
||||
|
||||
latent_rgb_factors = [
|
||||
[ 0.0119, 0.0103, 0.0046],
|
||||
@@ -587,6 +598,7 @@ class Wan22(Wan21):
|
||||
class HunyuanImage21(LatentFormat):
|
||||
latent_channels = 64
|
||||
latent_dimensions = 2
|
||||
spacial_downscale_ratio = 32
|
||||
scale_factor = 0.75289
|
||||
|
||||
latent_rgb_factors = [
|
||||
@@ -720,6 +732,7 @@ class HunyuanVideo15(LatentFormat):
|
||||
latent_rgb_factors_bias = [ 0.0456, -0.0202, -0.0644]
|
||||
latent_channels = 32
|
||||
latent_dimensions = 3
|
||||
spacial_downscale_ratio = 16
|
||||
scale_factor = 1.03682
|
||||
taesd_decoder_name = "lighttaehy1_5"
|
||||
|
||||
@@ -744,6 +757,7 @@ class ACEAudio(LatentFormat):
|
||||
|
||||
class ChromaRadiance(LatentFormat):
|
||||
latent_channels = 3
|
||||
spacial_downscale_ratio = 1
|
||||
|
||||
def __init__(self):
|
||||
self.latent_rgb_factors = [
|
||||
|
||||
202
comfy/ldm/anima/model.py
Normal file
202
comfy/ldm/anima/model.py
Normal file
@@ -0,0 +1,202 @@
|
||||
from comfy.ldm.cosmos.predict2 import MiniTrainDIT
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(x, cos, sin, unsqueeze_dim=1):
|
||||
cos = cos.unsqueeze(unsqueeze_dim)
|
||||
sin = sin.unsqueeze(unsqueeze_dim)
|
||||
x_embed = (x * cos) + (rotate_half(x) * sin)
|
||||
return x_embed
|
||||
|
||||
|
||||
class RotaryEmbedding(nn.Module):
|
||||
def __init__(self, head_dim):
|
||||
super().__init__()
|
||||
self.rope_theta = 10000
|
||||
inv_freq = 1.0 / (self.rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.int64).to(dtype=torch.float) / head_dim))
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, x, position_ids):
|
||||
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
|
||||
position_ids_expanded = position_ids[:, None, :].float()
|
||||
|
||||
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False): # Force float32
|
||||
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
|
||||
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, query_dim, context_dim, n_heads, head_dim, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
inner_dim = head_dim * n_heads
|
||||
self.n_heads = n_heads
|
||||
self.head_dim = head_dim
|
||||
self.query_dim = query_dim
|
||||
self.context_dim = context_dim
|
||||
|
||||
self.q_proj = operations.Linear(query_dim, inner_dim, bias=False, device=device, dtype=dtype)
|
||||
self.q_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype)
|
||||
|
||||
self.k_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype)
|
||||
self.k_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype)
|
||||
|
||||
self.v_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype)
|
||||
|
||||
self.o_proj = operations.Linear(inner_dim, query_dim, bias=False, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x, mask=None, context=None, position_embeddings=None, position_embeddings_context=None):
|
||||
context = x if context is None else context
|
||||
input_shape = x.shape[:-1]
|
||||
q_shape = (*input_shape, self.n_heads, self.head_dim)
|
||||
context_shape = context.shape[:-1]
|
||||
kv_shape = (*context_shape, self.n_heads, self.head_dim)
|
||||
|
||||
query_states = self.q_norm(self.q_proj(x).view(q_shape)).transpose(1, 2)
|
||||
key_states = self.k_norm(self.k_proj(context).view(kv_shape)).transpose(1, 2)
|
||||
value_states = self.v_proj(context).view(kv_shape).transpose(1, 2)
|
||||
|
||||
if position_embeddings is not None:
|
||||
assert position_embeddings_context is not None
|
||||
cos, sin = position_embeddings
|
||||
query_states = apply_rotary_pos_emb(query_states, cos, sin)
|
||||
cos, sin = position_embeddings_context
|
||||
key_states = apply_rotary_pos_emb(key_states, cos, sin)
|
||||
|
||||
attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=mask)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).reshape(*input_shape, -1).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
return attn_output
|
||||
|
||||
def init_weights(self):
|
||||
torch.nn.init.zeros_(self.o_proj.weight)
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, source_dim, model_dim, num_heads=16, mlp_ratio=4.0, use_self_attn=False, layer_norm=False, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.use_self_attn = use_self_attn
|
||||
|
||||
if self.use_self_attn:
|
||||
self.norm_self_attn = operations.LayerNorm(model_dim, device=device, dtype=dtype) if layer_norm else operations.RMSNorm(model_dim, eps=1e-6, device=device, dtype=dtype)
|
||||
self.self_attn = Attention(
|
||||
query_dim=model_dim,
|
||||
context_dim=model_dim,
|
||||
n_heads=num_heads,
|
||||
head_dim=model_dim//num_heads,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
self.norm_cross_attn = operations.LayerNorm(model_dim, device=device, dtype=dtype) if layer_norm else operations.RMSNorm(model_dim, eps=1e-6, device=device, dtype=dtype)
|
||||
self.cross_attn = Attention(
|
||||
query_dim=model_dim,
|
||||
context_dim=source_dim,
|
||||
n_heads=num_heads,
|
||||
head_dim=model_dim//num_heads,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
self.norm_mlp = operations.LayerNorm(model_dim, device=device, dtype=dtype) if layer_norm else operations.RMSNorm(model_dim, eps=1e-6, device=device, dtype=dtype)
|
||||
self.mlp = nn.Sequential(
|
||||
operations.Linear(model_dim, int(model_dim * mlp_ratio), device=device, dtype=dtype),
|
||||
nn.GELU(),
|
||||
operations.Linear(int(model_dim * mlp_ratio), model_dim, device=device, dtype=dtype)
|
||||
)
|
||||
|
||||
def forward(self, x, context, target_attention_mask=None, source_attention_mask=None, position_embeddings=None, position_embeddings_context=None):
|
||||
if self.use_self_attn:
|
||||
normed = self.norm_self_attn(x)
|
||||
attn_out = self.self_attn(normed, mask=target_attention_mask, position_embeddings=position_embeddings, position_embeddings_context=position_embeddings)
|
||||
x = x + attn_out
|
||||
|
||||
normed = self.norm_cross_attn(x)
|
||||
attn_out = self.cross_attn(normed, mask=source_attention_mask, context=context, position_embeddings=position_embeddings, position_embeddings_context=position_embeddings_context)
|
||||
x = x + attn_out
|
||||
|
||||
x = x + self.mlp(self.norm_mlp(x))
|
||||
return x
|
||||
|
||||
def init_weights(self):
|
||||
torch.nn.init.zeros_(self.mlp[2].weight)
|
||||
self.cross_attn.init_weights()
|
||||
|
||||
|
||||
class LLMAdapter(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
source_dim=1024,
|
||||
target_dim=1024,
|
||||
model_dim=1024,
|
||||
num_layers=6,
|
||||
num_heads=16,
|
||||
use_self_attn=True,
|
||||
layer_norm=False,
|
||||
device=None,
|
||||
dtype=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.embed = operations.Embedding(32128, target_dim, device=device, dtype=dtype)
|
||||
if model_dim != target_dim:
|
||||
self.in_proj = operations.Linear(target_dim, model_dim, device=device, dtype=dtype)
|
||||
else:
|
||||
self.in_proj = nn.Identity()
|
||||
self.rotary_emb = RotaryEmbedding(model_dim//num_heads)
|
||||
self.blocks = nn.ModuleList([
|
||||
TransformerBlock(source_dim, model_dim, num_heads=num_heads, use_self_attn=use_self_attn, layer_norm=layer_norm, device=device, dtype=dtype, operations=operations) for _ in range(num_layers)
|
||||
])
|
||||
self.out_proj = operations.Linear(model_dim, target_dim, device=device, dtype=dtype)
|
||||
self.norm = operations.RMSNorm(target_dim, eps=1e-6, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, source_hidden_states, target_input_ids, target_attention_mask=None, source_attention_mask=None):
|
||||
if target_attention_mask is not None:
|
||||
target_attention_mask = target_attention_mask.to(torch.bool)
|
||||
if target_attention_mask.ndim == 2:
|
||||
target_attention_mask = target_attention_mask.unsqueeze(1).unsqueeze(1)
|
||||
|
||||
if source_attention_mask is not None:
|
||||
source_attention_mask = source_attention_mask.to(torch.bool)
|
||||
if source_attention_mask.ndim == 2:
|
||||
source_attention_mask = source_attention_mask.unsqueeze(1).unsqueeze(1)
|
||||
|
||||
x = self.in_proj(self.embed(target_input_ids))
|
||||
context = source_hidden_states
|
||||
position_ids = torch.arange(x.shape[1], device=x.device).unsqueeze(0)
|
||||
position_ids_context = torch.arange(context.shape[1], device=x.device).unsqueeze(0)
|
||||
position_embeddings = self.rotary_emb(x, position_ids)
|
||||
position_embeddings_context = self.rotary_emb(x, position_ids_context)
|
||||
for block in self.blocks:
|
||||
x = block(x, context, target_attention_mask=target_attention_mask, source_attention_mask=source_attention_mask, position_embeddings=position_embeddings, position_embeddings_context=position_embeddings_context)
|
||||
return self.norm(self.out_proj(x))
|
||||
|
||||
|
||||
class Anima(MiniTrainDIT):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.llm_adapter = LLMAdapter(device=kwargs.get("device"), dtype=kwargs.get("dtype"), operations=kwargs.get("operations"))
|
||||
|
||||
def preprocess_text_embeds(self, text_embeds, text_ids):
|
||||
if text_ids is not None:
|
||||
return self.llm_adapter(text_embeds, text_ids)
|
||||
else:
|
||||
return text_embeds
|
||||
@@ -270,7 +270,7 @@ class ChromaRadiance(Chroma):
|
||||
bad_keys = tuple(
|
||||
k
|
||||
for k, v in overrides.items()
|
||||
if type(v) != type(getattr(params, k)) and (v is not None or k not in nullable_keys)
|
||||
if not isinstance(v, type(getattr(params, k))) and (v is not None or k not in nullable_keys)
|
||||
)
|
||||
if bad_keys:
|
||||
e = f"Invalid value(s) in transformer_options chroma_radiance_options: {', '.join(bad_keys)}"
|
||||
|
||||
@@ -4,6 +4,7 @@ from torch import Tensor
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
import comfy.model_management
|
||||
import logging
|
||||
|
||||
|
||||
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor:
|
||||
@@ -13,7 +14,6 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transforme
|
||||
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options)
|
||||
return x
|
||||
|
||||
|
||||
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
||||
assert dim % 2 == 0
|
||||
if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu() or comfy.model_management.is_directml_enabled():
|
||||
@@ -28,13 +28,20 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
||||
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
|
||||
return out.to(dtype=torch.float32, device=pos.device)
|
||||
|
||||
def apply_rope1(x: Tensor, freqs_cis: Tensor):
|
||||
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
|
||||
|
||||
x_out = freqs_cis[..., 0] * x_[..., 0]
|
||||
x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
|
||||
try:
|
||||
import comfy.quant_ops
|
||||
apply_rope = comfy.quant_ops.ck.apply_rope
|
||||
apply_rope1 = comfy.quant_ops.ck.apply_rope1
|
||||
except:
|
||||
logging.warning("No comfy kitchen, using old apply_rope functions.")
|
||||
def apply_rope1(x: Tensor, freqs_cis: Tensor):
|
||||
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
|
||||
|
||||
return x_out.reshape(*x.shape).type_as(x)
|
||||
x_out = freqs_cis[..., 0] * x_[..., 0]
|
||||
x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
|
||||
|
||||
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
|
||||
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
|
||||
return x_out.reshape(*x.shape).type_as(x)
|
||||
|
||||
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
|
||||
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
|
||||
|
||||
@@ -3,7 +3,8 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, VideoConv3d
|
||||
from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm
|
||||
import model_management, model_patcher
|
||||
import comfy.model_management
|
||||
import comfy.model_patcher
|
||||
|
||||
class SRResidualCausalBlock3D(nn.Module):
|
||||
def __init__(self, channels: int):
|
||||
@@ -102,13 +103,13 @@ UPSAMPLERS = {
|
||||
|
||||
class HunyuanVideo15SRModel():
|
||||
def __init__(self, model_type, config):
|
||||
self.load_device = model_management.vae_device()
|
||||
offload_device = model_management.vae_offload_device()
|
||||
self.dtype = model_management.vae_dtype(self.load_device)
|
||||
self.load_device = comfy.model_management.vae_device()
|
||||
offload_device = comfy.model_management.vae_offload_device()
|
||||
self.dtype = comfy.model_management.vae_dtype(self.load_device)
|
||||
self.model_class = UPSAMPLERS.get(model_type)
|
||||
self.model = self.model_class(**config).eval()
|
||||
|
||||
self.patcher = model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||
|
||||
def load_sd(self, sd):
|
||||
return self.model.load_state_dict(sd, strict=True)
|
||||
@@ -117,5 +118,5 @@ class HunyuanVideo15SRModel():
|
||||
return self.model.state_dict()
|
||||
|
||||
def resample_latent(self, latent):
|
||||
model_management.load_model_gpu(self.patcher)
|
||||
comfy.model_management.load_model_gpu(self.patcher)
|
||||
return self.model(latent.to(self.load_device))
|
||||
|
||||
871
comfy/ldm/lightricks/av_model.py
Normal file
871
comfy/ldm/lightricks/av_model.py
Normal file
@@ -0,0 +1,871 @@
|
||||
from typing import Tuple
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from comfy.ldm.lightricks.model import (
|
||||
CrossAttention,
|
||||
FeedForward,
|
||||
AdaLayerNormSingle,
|
||||
PixArtAlphaTextProjection,
|
||||
LTXVModel,
|
||||
)
|
||||
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
class CompressedTimestep:
|
||||
"""Store video timestep embeddings in compressed form using per-frame indexing."""
|
||||
__slots__ = ('data', 'batch_size', 'num_frames', 'patches_per_frame', 'feature_dim')
|
||||
|
||||
def __init__(self, tensor: torch.Tensor, patches_per_frame: int):
|
||||
"""
|
||||
tensor: [batch_size, num_tokens, feature_dim] tensor where num_tokens = num_frames * patches_per_frame
|
||||
patches_per_frame: Number of spatial patches per frame (height * width in latent space), or None to disable compression
|
||||
"""
|
||||
self.batch_size, num_tokens, self.feature_dim = tensor.shape
|
||||
|
||||
# Check if compression is valid (num_tokens must be divisible by patches_per_frame)
|
||||
if patches_per_frame is not None and num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame:
|
||||
self.patches_per_frame = patches_per_frame
|
||||
self.num_frames = num_tokens // patches_per_frame
|
||||
|
||||
# Reshape to [batch, frames, patches_per_frame, feature_dim] and store one value per frame
|
||||
# All patches in a frame are identical, so we only keep the first one
|
||||
reshaped = tensor.view(self.batch_size, self.num_frames, patches_per_frame, self.feature_dim)
|
||||
self.data = reshaped[:, :, 0, :].contiguous() # [batch, frames, feature_dim]
|
||||
else:
|
||||
# Not divisible or too small - store directly without compression
|
||||
self.patches_per_frame = 1
|
||||
self.num_frames = num_tokens
|
||||
self.data = tensor
|
||||
|
||||
def expand(self):
|
||||
"""Expand back to original tensor."""
|
||||
if self.patches_per_frame == 1:
|
||||
return self.data
|
||||
|
||||
# [batch, frames, feature_dim] -> [batch, frames, patches_per_frame, feature_dim] -> [batch, tokens, feature_dim]
|
||||
expanded = self.data.unsqueeze(2).expand(self.batch_size, self.num_frames, self.patches_per_frame, self.feature_dim)
|
||||
return expanded.reshape(self.batch_size, -1, self.feature_dim)
|
||||
|
||||
def expand_for_computation(self, scale_shift_table: torch.Tensor, batch_size: int, indices: slice = slice(None, None)):
|
||||
"""Compute ada values on compressed per-frame data, then expand spatially."""
|
||||
num_ada_params = scale_shift_table.shape[0]
|
||||
|
||||
# No compression - compute directly
|
||||
if self.patches_per_frame == 1:
|
||||
num_tokens = self.data.shape[1]
|
||||
dim_per_param = self.feature_dim // num_ada_params
|
||||
reshaped = self.data.reshape(batch_size, num_tokens, num_ada_params, dim_per_param)[:, :, indices, :]
|
||||
table_values = scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(device=self.data.device, dtype=self.data.dtype)
|
||||
ada_values = (table_values + reshaped).unbind(dim=2)
|
||||
return ada_values
|
||||
|
||||
# Compressed: compute on per-frame data then expand spatially
|
||||
# Reshape: [batch, frames, feature_dim] -> [batch, frames, num_ada_params, dim_per_param]
|
||||
frame_reshaped = self.data.reshape(batch_size, self.num_frames, num_ada_params, -1)[:, :, indices, :]
|
||||
table_values = scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(
|
||||
device=self.data.device, dtype=self.data.dtype
|
||||
)
|
||||
frame_ada = (table_values + frame_reshaped).unbind(dim=2)
|
||||
|
||||
# Expand each ada parameter spatially: [batch, frames, dim] -> [batch, frames, patches, dim] -> [batch, tokens, dim]
|
||||
return tuple(
|
||||
frame_val.unsqueeze(2).expand(batch_size, self.num_frames, self.patches_per_frame, -1)
|
||||
.reshape(batch_size, -1, frame_val.shape[-1])
|
||||
for frame_val in frame_ada
|
||||
)
|
||||
|
||||
class BasicAVTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
v_dim,
|
||||
a_dim,
|
||||
v_heads,
|
||||
a_heads,
|
||||
vd_head,
|
||||
ad_head,
|
||||
v_context_dim=None,
|
||||
a_context_dim=None,
|
||||
attn_precision=None,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.attn_precision = attn_precision
|
||||
|
||||
self.attn1 = CrossAttention(
|
||||
query_dim=v_dim,
|
||||
heads=v_heads,
|
||||
dim_head=vd_head,
|
||||
context_dim=None,
|
||||
attn_precision=self.attn_precision,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
self.audio_attn1 = CrossAttention(
|
||||
query_dim=a_dim,
|
||||
heads=a_heads,
|
||||
dim_head=ad_head,
|
||||
context_dim=None,
|
||||
attn_precision=self.attn_precision,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
self.attn2 = CrossAttention(
|
||||
query_dim=v_dim,
|
||||
context_dim=v_context_dim,
|
||||
heads=v_heads,
|
||||
dim_head=vd_head,
|
||||
attn_precision=self.attn_precision,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
self.audio_attn2 = CrossAttention(
|
||||
query_dim=a_dim,
|
||||
context_dim=a_context_dim,
|
||||
heads=a_heads,
|
||||
dim_head=ad_head,
|
||||
attn_precision=self.attn_precision,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
# Q: Video, K,V: Audio
|
||||
self.audio_to_video_attn = CrossAttention(
|
||||
query_dim=v_dim,
|
||||
context_dim=a_dim,
|
||||
heads=a_heads,
|
||||
dim_head=ad_head,
|
||||
attn_precision=self.attn_precision,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
# Q: Audio, K,V: Video
|
||||
self.video_to_audio_attn = CrossAttention(
|
||||
query_dim=a_dim,
|
||||
context_dim=v_dim,
|
||||
heads=a_heads,
|
||||
dim_head=ad_head,
|
||||
attn_precision=self.attn_precision,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
self.ff = FeedForward(
|
||||
v_dim, dim_out=v_dim, glu=True, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
self.audio_ff = FeedForward(
|
||||
a_dim, dim_out=a_dim, glu=True, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
self.scale_shift_table = nn.Parameter(torch.empty(6, v_dim, device=device, dtype=dtype))
|
||||
self.audio_scale_shift_table = nn.Parameter(
|
||||
torch.empty(6, a_dim, device=device, dtype=dtype)
|
||||
)
|
||||
|
||||
self.scale_shift_table_a2v_ca_audio = nn.Parameter(
|
||||
torch.empty(5, a_dim, device=device, dtype=dtype)
|
||||
)
|
||||
self.scale_shift_table_a2v_ca_video = nn.Parameter(
|
||||
torch.empty(5, v_dim, device=device, dtype=dtype)
|
||||
)
|
||||
|
||||
def get_ada_values(
|
||||
self, scale_shift_table: torch.Tensor, batch_size: int, timestep: torch.Tensor, indices: slice = slice(None, None)
|
||||
):
|
||||
if isinstance(timestep, CompressedTimestep):
|
||||
return timestep.expand_for_computation(scale_shift_table, batch_size, indices)
|
||||
|
||||
num_ada_params = scale_shift_table.shape[0]
|
||||
|
||||
ada_values = (
|
||||
scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(device=timestep.device, dtype=timestep.dtype)
|
||||
+ timestep.reshape(batch_size, timestep.shape[1], num_ada_params, -1)[:, :, indices, :]
|
||||
).unbind(dim=2)
|
||||
return ada_values
|
||||
|
||||
def get_av_ca_ada_values(
|
||||
self,
|
||||
scale_shift_table: torch.Tensor,
|
||||
batch_size: int,
|
||||
scale_shift_timestep: torch.Tensor,
|
||||
gate_timestep: torch.Tensor,
|
||||
num_scale_shift_values: int = 4,
|
||||
):
|
||||
scale_shift_ada_values = self.get_ada_values(
|
||||
scale_shift_table[:num_scale_shift_values, :],
|
||||
batch_size,
|
||||
scale_shift_timestep,
|
||||
)
|
||||
gate_ada_values = self.get_ada_values(
|
||||
scale_shift_table[num_scale_shift_values:, :],
|
||||
batch_size,
|
||||
gate_timestep,
|
||||
)
|
||||
|
||||
return (*scale_shift_ada_values, *gate_ada_values)
|
||||
|
||||
def forward(
|
||||
self, x: Tuple[torch.Tensor, torch.Tensor], v_context=None, a_context=None, attention_mask=None, v_timestep=None, a_timestep=None,
|
||||
v_pe=None, a_pe=None, v_cross_pe=None, a_cross_pe=None, v_cross_scale_shift_timestep=None, a_cross_scale_shift_timestep=None,
|
||||
v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
run_vx = transformer_options.get("run_vx", True)
|
||||
run_ax = transformer_options.get("run_ax", True)
|
||||
|
||||
vx, ax = x
|
||||
run_ax = run_ax and ax.numel() > 0
|
||||
run_a2v = run_vx and transformer_options.get("a2v_cross_attn", True) and ax.numel() > 0
|
||||
run_v2a = run_ax and transformer_options.get("v2a_cross_attn", True)
|
||||
|
||||
# video
|
||||
if run_vx:
|
||||
# video self-attention
|
||||
vshift_msa, vscale_msa = (self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 2)))
|
||||
norm_vx = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_msa) + vshift_msa
|
||||
del vshift_msa, vscale_msa
|
||||
attn1_out = self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options)
|
||||
del norm_vx
|
||||
# video cross-attention
|
||||
vgate_msa = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(2, 3))[0]
|
||||
vx.addcmul_(attn1_out, vgate_msa)
|
||||
del vgate_msa, attn1_out
|
||||
vx.add_(self.attn2(comfy.ldm.common_dit.rms_norm(vx), context=v_context, mask=attention_mask, transformer_options=transformer_options))
|
||||
|
||||
# audio
|
||||
if run_ax:
|
||||
# audio self-attention
|
||||
ashift_msa, ascale_msa = (self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(0, 2)))
|
||||
norm_ax = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_msa) + ashift_msa
|
||||
del ashift_msa, ascale_msa
|
||||
attn1_out = self.audio_attn1(norm_ax, pe=a_pe, transformer_options=transformer_options)
|
||||
del norm_ax
|
||||
# audio cross-attention
|
||||
agate_msa = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(2, 3))[0]
|
||||
ax.addcmul_(attn1_out, agate_msa)
|
||||
del agate_msa, attn1_out
|
||||
ax.add_(self.audio_attn2(comfy.ldm.common_dit.rms_norm(ax), context=a_context, mask=attention_mask, transformer_options=transformer_options))
|
||||
|
||||
# video - audio cross attention.
|
||||
if run_a2v or run_v2a:
|
||||
vx_norm3 = comfy.ldm.common_dit.rms_norm(vx)
|
||||
ax_norm3 = comfy.ldm.common_dit.rms_norm(ax)
|
||||
|
||||
# audio to video cross attention
|
||||
if run_a2v:
|
||||
scale_ca_audio_hidden_states_a2v, shift_ca_audio_hidden_states_a2v = self.get_ada_values(
|
||||
self.scale_shift_table_a2v_ca_audio[:4, :], ax.shape[0], a_cross_scale_shift_timestep)[:2]
|
||||
scale_ca_video_hidden_states_a2v_v, shift_ca_video_hidden_states_a2v_v = self.get_ada_values(
|
||||
self.scale_shift_table_a2v_ca_video[:4, :], vx.shape[0], v_cross_scale_shift_timestep)[:2]
|
||||
|
||||
vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_a2v_v) + shift_ca_video_hidden_states_a2v_v
|
||||
ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v) + shift_ca_audio_hidden_states_a2v
|
||||
del scale_ca_video_hidden_states_a2v_v, shift_ca_video_hidden_states_a2v_v, scale_ca_audio_hidden_states_a2v, shift_ca_audio_hidden_states_a2v
|
||||
|
||||
a2v_out = self.audio_to_video_attn(vx_scaled, context=ax_scaled, pe=v_cross_pe, k_pe=a_cross_pe, transformer_options=transformer_options)
|
||||
del vx_scaled, ax_scaled
|
||||
|
||||
gate_out_a2v = self.get_ada_values(self.scale_shift_table_a2v_ca_video[4:, :], vx.shape[0], v_cross_gate_timestep)[0]
|
||||
vx.addcmul_(a2v_out, gate_out_a2v)
|
||||
del gate_out_a2v, a2v_out
|
||||
|
||||
# video to audio cross attention
|
||||
if run_v2a:
|
||||
scale_ca_audio_hidden_states_v2a, shift_ca_audio_hidden_states_v2a = self.get_ada_values(
|
||||
self.scale_shift_table_a2v_ca_audio[:4, :], ax.shape[0], a_cross_scale_shift_timestep)[2:4]
|
||||
scale_ca_video_hidden_states_v2a, shift_ca_video_hidden_states_v2a = self.get_ada_values(
|
||||
self.scale_shift_table_a2v_ca_video[:4, :], vx.shape[0], v_cross_scale_shift_timestep)[2:4]
|
||||
|
||||
ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a) + shift_ca_audio_hidden_states_v2a
|
||||
vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_v2a) + shift_ca_video_hidden_states_v2a
|
||||
del scale_ca_video_hidden_states_v2a, shift_ca_video_hidden_states_v2a, scale_ca_audio_hidden_states_v2a, shift_ca_audio_hidden_states_v2a
|
||||
|
||||
v2a_out = self.video_to_audio_attn(ax_scaled, context=vx_scaled, pe=a_cross_pe, k_pe=v_cross_pe, transformer_options=transformer_options)
|
||||
del ax_scaled, vx_scaled
|
||||
|
||||
gate_out_v2a = self.get_ada_values(self.scale_shift_table_a2v_ca_audio[4:, :], ax.shape[0], a_cross_gate_timestep)[0]
|
||||
ax.addcmul_(v2a_out, gate_out_v2a)
|
||||
del gate_out_v2a, v2a_out
|
||||
|
||||
del vx_norm3, ax_norm3
|
||||
|
||||
# video feedforward
|
||||
if run_vx:
|
||||
vshift_mlp, vscale_mlp = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(3, 5))
|
||||
vx_scaled = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_mlp) + vshift_mlp
|
||||
del vshift_mlp, vscale_mlp
|
||||
|
||||
ff_out = self.ff(vx_scaled)
|
||||
del vx_scaled
|
||||
|
||||
vgate_mlp = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(5, 6))[0]
|
||||
vx.addcmul_(ff_out, vgate_mlp)
|
||||
del vgate_mlp, ff_out
|
||||
|
||||
# audio feedforward
|
||||
if run_ax:
|
||||
ashift_mlp, ascale_mlp = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(3, 5))
|
||||
ax_scaled = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_mlp) + ashift_mlp
|
||||
del ashift_mlp, ascale_mlp
|
||||
|
||||
ff_out = self.audio_ff(ax_scaled)
|
||||
del ax_scaled
|
||||
|
||||
agate_mlp = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(5, 6))[0]
|
||||
ax.addcmul_(ff_out, agate_mlp)
|
||||
del agate_mlp, ff_out
|
||||
|
||||
return vx, ax
|
||||
|
||||
|
||||
class LTXAVModel(LTXVModel):
|
||||
"""LTXAV model for audio-video generation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=128,
|
||||
audio_in_channels=128,
|
||||
cross_attention_dim=4096,
|
||||
audio_cross_attention_dim=2048,
|
||||
attention_head_dim=128,
|
||||
audio_attention_head_dim=64,
|
||||
num_attention_heads=32,
|
||||
audio_num_attention_heads=32,
|
||||
caption_channels=3840,
|
||||
num_layers=48,
|
||||
positional_embedding_theta=10000.0,
|
||||
positional_embedding_max_pos=[20, 2048, 2048],
|
||||
audio_positional_embedding_max_pos=[20],
|
||||
causal_temporal_positioning=False,
|
||||
vae_scale_factors=(8, 32, 32),
|
||||
use_middle_indices_grid=False,
|
||||
timestep_scale_multiplier=1000.0,
|
||||
av_ca_timestep_scale_multiplier=1.0,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
**kwargs,
|
||||
):
|
||||
# Store audio-specific parameters
|
||||
self.audio_in_channels = audio_in_channels
|
||||
self.audio_cross_attention_dim = audio_cross_attention_dim
|
||||
self.audio_attention_head_dim = audio_attention_head_dim
|
||||
self.audio_num_attention_heads = audio_num_attention_heads
|
||||
self.audio_positional_embedding_max_pos = audio_positional_embedding_max_pos
|
||||
|
||||
# Calculate audio dimensions
|
||||
self.audio_inner_dim = audio_num_attention_heads * audio_attention_head_dim
|
||||
self.audio_out_channels = audio_in_channels
|
||||
|
||||
# Audio-specific constants
|
||||
self.num_audio_channels = 8
|
||||
self.audio_frequency_bins = 16
|
||||
|
||||
self.av_ca_timestep_scale_multiplier = av_ca_timestep_scale_multiplier
|
||||
|
||||
super().__init__(
|
||||
in_channels=in_channels,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attention_head_dim=attention_head_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
caption_channels=caption_channels,
|
||||
num_layers=num_layers,
|
||||
positional_embedding_theta=positional_embedding_theta,
|
||||
positional_embedding_max_pos=positional_embedding_max_pos,
|
||||
causal_temporal_positioning=causal_temporal_positioning,
|
||||
vae_scale_factors=vae_scale_factors,
|
||||
use_middle_indices_grid=use_middle_indices_grid,
|
||||
timestep_scale_multiplier=timestep_scale_multiplier,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _init_model_components(self, device, dtype, **kwargs):
|
||||
"""Initialize LTXAV-specific components."""
|
||||
# Audio-specific projections
|
||||
self.audio_patchify_proj = self.operations.Linear(
|
||||
self.audio_in_channels, self.audio_inner_dim, bias=True, dtype=dtype, device=device
|
||||
)
|
||||
|
||||
# Audio-specific AdaLN
|
||||
self.audio_adaln_single = AdaLayerNormSingle(
|
||||
self.audio_inner_dim,
|
||||
use_additional_conditions=False,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
|
||||
num_scale_shift_values = 4
|
||||
self.av_ca_video_scale_shift_adaln_single = AdaLayerNormSingle(
|
||||
self.inner_dim,
|
||||
use_additional_conditions=False,
|
||||
embedding_coefficient=num_scale_shift_values,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
self.av_ca_a2v_gate_adaln_single = AdaLayerNormSingle(
|
||||
self.inner_dim,
|
||||
use_additional_conditions=False,
|
||||
embedding_coefficient=1,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
self.av_ca_audio_scale_shift_adaln_single = AdaLayerNormSingle(
|
||||
self.audio_inner_dim,
|
||||
use_additional_conditions=False,
|
||||
embedding_coefficient=num_scale_shift_values,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
self.av_ca_v2a_gate_adaln_single = AdaLayerNormSingle(
|
||||
self.audio_inner_dim,
|
||||
use_additional_conditions=False,
|
||||
embedding_coefficient=1,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
|
||||
# Audio caption projection
|
||||
self.audio_caption_projection = PixArtAlphaTextProjection(
|
||||
in_features=self.caption_channels,
|
||||
hidden_size=self.audio_inner_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
|
||||
def _init_transformer_blocks(self, device, dtype, **kwargs):
|
||||
"""Initialize transformer blocks for LTXAV."""
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicAVTransformerBlock(
|
||||
v_dim=self.inner_dim,
|
||||
a_dim=self.audio_inner_dim,
|
||||
v_heads=self.num_attention_heads,
|
||||
a_heads=self.audio_num_attention_heads,
|
||||
vd_head=self.attention_head_dim,
|
||||
ad_head=self.audio_attention_head_dim,
|
||||
v_context_dim=self.cross_attention_dim,
|
||||
a_context_dim=self.audio_cross_attention_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
for _ in range(self.num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
def _init_output_components(self, device, dtype):
|
||||
"""Initialize output components for LTXAV."""
|
||||
# Video output components
|
||||
super()._init_output_components(device, dtype)
|
||||
# Audio output components
|
||||
self.audio_scale_shift_table = nn.Parameter(
|
||||
torch.empty(2, self.audio_inner_dim, dtype=dtype, device=device)
|
||||
)
|
||||
self.audio_norm_out = self.operations.LayerNorm(
|
||||
self.audio_inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device
|
||||
)
|
||||
self.audio_proj_out = self.operations.Linear(
|
||||
self.audio_inner_dim, self.audio_out_channels, dtype=dtype, device=device
|
||||
)
|
||||
self.a_patchifier = AudioPatchifier(1, start_end=True)
|
||||
|
||||
def separate_audio_and_video_latents(self, x, audio_length):
|
||||
"""Separate audio and video latents from combined input."""
|
||||
# vx = x[:, : self.in_channels]
|
||||
# ax = x[:, self.in_channels :]
|
||||
#
|
||||
# ax = ax.reshape(ax.shape[0], -1)
|
||||
# ax = ax[:, : audio_length * self.num_audio_channels * self.audio_frequency_bins]
|
||||
#
|
||||
# ax = ax.reshape(
|
||||
# ax.shape[0], self.num_audio_channels, audio_length, self.audio_frequency_bins
|
||||
# )
|
||||
|
||||
vx = x[0]
|
||||
ax = x[1] if len(x) > 1 else torch.zeros(
|
||||
(vx.shape[0], self.num_audio_channels, 0, self.audio_frequency_bins),
|
||||
device=vx.device, dtype=vx.dtype
|
||||
)
|
||||
return vx, ax
|
||||
|
||||
def recombine_audio_and_video_latents(self, vx, ax, target_shape=None):
|
||||
if ax.numel() == 0:
|
||||
return vx
|
||||
else:
|
||||
return [vx, ax]
|
||||
"""Recombine audio and video latents for output."""
|
||||
# if ax.device != vx.device or ax.dtype != vx.dtype:
|
||||
# logging.warning("Audio and video latents are on different devices or dtypes.")
|
||||
# ax = ax.to(device=vx.device, dtype=vx.dtype)
|
||||
# logging.warning(f"Audio audio latent moved to device: {ax.device}, dtype: {ax.dtype}")
|
||||
#
|
||||
# ax = ax.reshape(ax.shape[0], -1)
|
||||
# # pad to f x h x w of the video latents
|
||||
# divisor = vx.shape[-1] * vx.shape[-2] * vx.shape[-3]
|
||||
# if target_shape is None:
|
||||
# repetitions = math.ceil(ax.shape[-1] / divisor)
|
||||
# else:
|
||||
# repetitions = target_shape[1] - vx.shape[1]
|
||||
# padded_len = repetitions * divisor
|
||||
# ax = F.pad(ax, (0, padded_len - ax.shape[-1]))
|
||||
# ax = ax.reshape(ax.shape[0], -1, vx.shape[-3], vx.shape[-2], vx.shape[-1])
|
||||
# return torch.cat([vx, ax], dim=1)
|
||||
|
||||
def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs):
|
||||
"""Process input for LTXAV - separate audio and video, then patchify."""
|
||||
audio_length = kwargs.get("audio_length", 0)
|
||||
# Separate audio and video latents
|
||||
vx, ax = self.separate_audio_and_video_latents(x, audio_length)
|
||||
|
||||
has_spatial_mask = False
|
||||
if denoise_mask is not None:
|
||||
# check if any frame has spatial variation (inpainting)
|
||||
for frame_idx in range(denoise_mask.shape[2]):
|
||||
frame_mask = denoise_mask[0, 0, frame_idx]
|
||||
if frame_mask.numel() > 0 and frame_mask.min() != frame_mask.max():
|
||||
has_spatial_mask = True
|
||||
break
|
||||
|
||||
[vx, v_pixel_coords, additional_args] = super()._process_input(
|
||||
vx, keyframe_idxs, denoise_mask, **kwargs
|
||||
)
|
||||
additional_args["has_spatial_mask"] = has_spatial_mask
|
||||
|
||||
ax, a_latent_coords = self.a_patchifier.patchify(ax)
|
||||
ax = self.audio_patchify_proj(ax)
|
||||
|
||||
# additional_args.update({"av_orig_shape": list(x.shape)})
|
||||
return [vx, ax], [v_pixel_coords, a_latent_coords], additional_args
|
||||
|
||||
def _prepare_timestep(self, timestep, batch_size, hidden_dtype, **kwargs):
|
||||
"""Prepare timestep embeddings."""
|
||||
# TODO: some code reuse is needed here.
|
||||
grid_mask = kwargs.get("grid_mask", None)
|
||||
if grid_mask is not None:
|
||||
timestep = timestep[:, grid_mask]
|
||||
|
||||
timestep_scaled = timestep * self.timestep_scale_multiplier
|
||||
|
||||
v_timestep, v_embedded_timestep = self.adaln_single(
|
||||
timestep_scaled.flatten(),
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
|
||||
# Calculate patches_per_frame from orig_shape: [batch, channels, frames, height, width]
|
||||
# Video tokens are arranged as (frames * height * width), so patches_per_frame = height * width
|
||||
orig_shape = kwargs.get("orig_shape")
|
||||
has_spatial_mask = kwargs.get("has_spatial_mask", None)
|
||||
v_patches_per_frame = None
|
||||
if not has_spatial_mask and orig_shape is not None and len(orig_shape) == 5:
|
||||
# orig_shape[3] = height, orig_shape[4] = width (in latent space)
|
||||
v_patches_per_frame = orig_shape[3] * orig_shape[4]
|
||||
|
||||
# Reshape to [batch_size, num_tokens, dim] and compress for storage
|
||||
v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame)
|
||||
v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame)
|
||||
|
||||
# Prepare audio timestep
|
||||
a_timestep = kwargs.get("a_timestep")
|
||||
if a_timestep is not None:
|
||||
a_timestep_scaled = a_timestep * self.timestep_scale_multiplier
|
||||
a_timestep_flat = a_timestep_scaled.flatten()
|
||||
timestep_flat = timestep_scaled.flatten()
|
||||
av_ca_factor = self.av_ca_timestep_scale_multiplier / self.timestep_scale_multiplier
|
||||
|
||||
# Cross-attention timesteps - compress these too
|
||||
av_ca_audio_scale_shift_timestep, _ = self.av_ca_audio_scale_shift_adaln_single(
|
||||
a_timestep_flat,
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
av_ca_video_scale_shift_timestep, _ = self.av_ca_video_scale_shift_adaln_single(
|
||||
timestep_flat,
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
av_ca_a2v_gate_noise_timestep, _ = self.av_ca_a2v_gate_adaln_single(
|
||||
timestep_flat * av_ca_factor,
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
av_ca_v2a_gate_noise_timestep, _ = self.av_ca_v2a_gate_adaln_single(
|
||||
a_timestep_flat * av_ca_factor,
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
|
||||
# Compress cross-attention timesteps (only video side, audio is too small to benefit)
|
||||
# v_patches_per_frame is None for spatial masks, set for temporal masks or no mask
|
||||
cross_av_timestep_ss = [
|
||||
av_ca_audio_scale_shift_timestep.view(batch_size, -1, av_ca_audio_scale_shift_timestep.shape[-1]),
|
||||
CompressedTimestep(av_ca_video_scale_shift_timestep.view(batch_size, -1, av_ca_video_scale_shift_timestep.shape[-1]), v_patches_per_frame), # video - compressed if possible
|
||||
CompressedTimestep(av_ca_a2v_gate_noise_timestep.view(batch_size, -1, av_ca_a2v_gate_noise_timestep.shape[-1]), v_patches_per_frame), # video - compressed if possible
|
||||
av_ca_v2a_gate_noise_timestep.view(batch_size, -1, av_ca_v2a_gate_noise_timestep.shape[-1]),
|
||||
]
|
||||
|
||||
a_timestep, a_embedded_timestep = self.audio_adaln_single(
|
||||
a_timestep_flat,
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
# Audio timesteps
|
||||
a_timestep = a_timestep.view(batch_size, -1, a_timestep.shape[-1])
|
||||
a_embedded_timestep = a_embedded_timestep.view(batch_size, -1, a_embedded_timestep.shape[-1])
|
||||
else:
|
||||
a_timestep = timestep_scaled
|
||||
a_embedded_timestep = kwargs.get("embedded_timestep")
|
||||
cross_av_timestep_ss = []
|
||||
|
||||
return [v_timestep, a_timestep, cross_av_timestep_ss], [
|
||||
v_embedded_timestep,
|
||||
a_embedded_timestep,
|
||||
]
|
||||
|
||||
def _prepare_context(self, context, batch_size, x, attention_mask=None):
|
||||
vx = x[0]
|
||||
ax = x[1]
|
||||
v_context, a_context = torch.split(
|
||||
context, int(context.shape[-1] / 2), len(context.shape) - 1
|
||||
)
|
||||
|
||||
v_context, attention_mask = super()._prepare_context(
|
||||
v_context, batch_size, vx, attention_mask
|
||||
)
|
||||
if self.audio_caption_projection is not None:
|
||||
a_context = self.audio_caption_projection(a_context)
|
||||
a_context = a_context.view(batch_size, -1, ax.shape[-1])
|
||||
|
||||
return [v_context, a_context], attention_mask
|
||||
|
||||
def _prepare_positional_embeddings(self, pixel_coords, frame_rate, x_dtype):
|
||||
v_pixel_coords = pixel_coords[0]
|
||||
v_pe = super()._prepare_positional_embeddings(v_pixel_coords, frame_rate, x_dtype)
|
||||
|
||||
a_latent_coords = pixel_coords[1]
|
||||
a_pe = self._precompute_freqs_cis(
|
||||
a_latent_coords,
|
||||
dim=self.audio_inner_dim,
|
||||
out_dtype=x_dtype,
|
||||
max_pos=self.audio_positional_embedding_max_pos,
|
||||
use_middle_indices_grid=self.use_middle_indices_grid,
|
||||
num_attention_heads=self.audio_num_attention_heads,
|
||||
)
|
||||
|
||||
# calculate positional embeddings for the middle of the token duration, to use in av cross attention layers.
|
||||
max_pos = max(
|
||||
self.positional_embedding_max_pos[0], self.audio_positional_embedding_max_pos[0]
|
||||
)
|
||||
v_pixel_coords = v_pixel_coords.to(torch.float32)
|
||||
v_pixel_coords[:, 0] = v_pixel_coords[:, 0] * (1.0 / frame_rate)
|
||||
av_cross_video_freq_cis = self._precompute_freqs_cis(
|
||||
v_pixel_coords[:, 0:1, :],
|
||||
dim=self.audio_cross_attention_dim,
|
||||
out_dtype=x_dtype,
|
||||
max_pos=[max_pos],
|
||||
use_middle_indices_grid=True,
|
||||
num_attention_heads=self.audio_num_attention_heads,
|
||||
)
|
||||
av_cross_audio_freq_cis = self._precompute_freqs_cis(
|
||||
a_latent_coords[:, 0:1, :],
|
||||
dim=self.audio_cross_attention_dim,
|
||||
out_dtype=x_dtype,
|
||||
max_pos=[max_pos],
|
||||
use_middle_indices_grid=True,
|
||||
num_attention_heads=self.audio_num_attention_heads,
|
||||
)
|
||||
|
||||
return [(v_pe, av_cross_video_freq_cis), (a_pe, av_cross_audio_freq_cis)]
|
||||
|
||||
def _process_transformer_blocks(
|
||||
self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs
|
||||
):
|
||||
vx = x[0]
|
||||
ax = x[1]
|
||||
v_context = context[0]
|
||||
a_context = context[1]
|
||||
v_timestep = timestep[0]
|
||||
a_timestep = timestep[1]
|
||||
v_pe, av_cross_video_freq_cis = pe[0]
|
||||
a_pe, av_cross_audio_freq_cis = pe[1]
|
||||
|
||||
(
|
||||
av_ca_audio_scale_shift_timestep,
|
||||
av_ca_video_scale_shift_timestep,
|
||||
av_ca_a2v_gate_noise_timestep,
|
||||
av_ca_v2a_gate_noise_timestep,
|
||||
) = timestep[2]
|
||||
|
||||
"""Process transformer blocks for LTXAV."""
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
|
||||
# Process transformer blocks
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(
|
||||
args["img"],
|
||||
v_context=args["v_context"],
|
||||
a_context=args["a_context"],
|
||||
attention_mask=args["attention_mask"],
|
||||
v_timestep=args["v_timestep"],
|
||||
a_timestep=args["a_timestep"],
|
||||
v_pe=args["v_pe"],
|
||||
a_pe=args["a_pe"],
|
||||
v_cross_pe=args["v_cross_pe"],
|
||||
a_cross_pe=args["a_cross_pe"],
|
||||
v_cross_scale_shift_timestep=args["v_cross_scale_shift_timestep"],
|
||||
a_cross_scale_shift_timestep=args["a_cross_scale_shift_timestep"],
|
||||
v_cross_gate_timestep=args["v_cross_gate_timestep"],
|
||||
a_cross_gate_timestep=args["a_cross_gate_timestep"],
|
||||
transformer_options=args["transformer_options"],
|
||||
)
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)](
|
||||
{
|
||||
"img": (vx, ax),
|
||||
"v_context": v_context,
|
||||
"a_context": a_context,
|
||||
"attention_mask": attention_mask,
|
||||
"v_timestep": v_timestep,
|
||||
"a_timestep": a_timestep,
|
||||
"v_pe": v_pe,
|
||||
"a_pe": a_pe,
|
||||
"v_cross_pe": av_cross_video_freq_cis,
|
||||
"a_cross_pe": av_cross_audio_freq_cis,
|
||||
"v_cross_scale_shift_timestep": av_ca_video_scale_shift_timestep,
|
||||
"a_cross_scale_shift_timestep": av_ca_audio_scale_shift_timestep,
|
||||
"v_cross_gate_timestep": av_ca_a2v_gate_noise_timestep,
|
||||
"a_cross_gate_timestep": av_ca_v2a_gate_noise_timestep,
|
||||
"transformer_options": transformer_options,
|
||||
},
|
||||
{"original_block": block_wrap},
|
||||
)
|
||||
vx, ax = out["img"]
|
||||
else:
|
||||
vx, ax = block(
|
||||
(vx, ax),
|
||||
v_context=v_context,
|
||||
a_context=a_context,
|
||||
attention_mask=attention_mask,
|
||||
v_timestep=v_timestep,
|
||||
a_timestep=a_timestep,
|
||||
v_pe=v_pe,
|
||||
a_pe=a_pe,
|
||||
v_cross_pe=av_cross_video_freq_cis,
|
||||
a_cross_pe=av_cross_audio_freq_cis,
|
||||
v_cross_scale_shift_timestep=av_ca_video_scale_shift_timestep,
|
||||
a_cross_scale_shift_timestep=av_ca_audio_scale_shift_timestep,
|
||||
v_cross_gate_timestep=av_ca_a2v_gate_noise_timestep,
|
||||
a_cross_gate_timestep=av_ca_v2a_gate_noise_timestep,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
return [vx, ax]
|
||||
|
||||
def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):
|
||||
vx = x[0]
|
||||
ax = x[1]
|
||||
v_embedded_timestep = embedded_timestep[0]
|
||||
a_embedded_timestep = embedded_timestep[1]
|
||||
|
||||
# Expand compressed video timestep if needed
|
||||
if isinstance(v_embedded_timestep, CompressedTimestep):
|
||||
v_embedded_timestep = v_embedded_timestep.expand()
|
||||
|
||||
vx = super()._process_output(vx, v_embedded_timestep, keyframe_idxs, **kwargs)
|
||||
|
||||
# Process audio output
|
||||
a_scale_shift_values = (
|
||||
self.audio_scale_shift_table[None, None].to(device=a_embedded_timestep.device, dtype=a_embedded_timestep.dtype)
|
||||
+ a_embedded_timestep[:, :, None]
|
||||
)
|
||||
a_shift, a_scale = a_scale_shift_values[:, :, 0], a_scale_shift_values[:, :, 1]
|
||||
|
||||
ax = self.audio_norm_out(ax)
|
||||
ax = ax * (1 + a_scale) + a_shift
|
||||
ax = self.audio_proj_out(ax)
|
||||
|
||||
# Unpatchify audio
|
||||
ax = self.a_patchifier.unpatchify(
|
||||
ax, channels=self.num_audio_channels, freq=self.audio_frequency_bins
|
||||
)
|
||||
|
||||
# Recombine audio and video
|
||||
original_shape = kwargs.get("av_orig_shape")
|
||||
return self.recombine_audio_and_video_latents(vx, ax, original_shape)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
timestep,
|
||||
context,
|
||||
attention_mask=None,
|
||||
frame_rate=25,
|
||||
transformer_options={},
|
||||
keyframe_idxs=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Forward pass for LTXAV model.
|
||||
|
||||
Args:
|
||||
x: Combined audio-video input tensor
|
||||
timestep: Tuple of (video_timestep, audio_timestep) or single timestep
|
||||
context: Context tensor (e.g., text embeddings)
|
||||
attention_mask: Attention mask tensor
|
||||
frame_rate: Frame rate for temporal processing
|
||||
transformer_options: Additional options for transformer blocks
|
||||
keyframe_idxs: Keyframe indices for temporal processing
|
||||
**kwargs: Additional keyword arguments including audio_length
|
||||
|
||||
Returns:
|
||||
Combined audio-video output tensor
|
||||
"""
|
||||
# Handle timestep format
|
||||
if isinstance(timestep, (tuple, list)) and len(timestep) == 2:
|
||||
v_timestep, a_timestep = timestep
|
||||
kwargs["a_timestep"] = a_timestep
|
||||
timestep = v_timestep
|
||||
else:
|
||||
kwargs["a_timestep"] = timestep
|
||||
|
||||
# Call parent forward method
|
||||
return super().forward(
|
||||
x,
|
||||
timestep,
|
||||
context,
|
||||
attention_mask,
|
||||
frame_rate,
|
||||
transformer_options,
|
||||
keyframe_idxs,
|
||||
**kwargs,
|
||||
)
|
||||
305
comfy/ldm/lightricks/embeddings_connector.py
Normal file
305
comfy/ldm/lightricks/embeddings_connector.py
Normal file
@@ -0,0 +1,305 @@
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import comfy.ldm.common_dit
|
||||
import torch
|
||||
from comfy.ldm.lightricks.model import (
|
||||
CrossAttention,
|
||||
FeedForward,
|
||||
generate_freq_grid_np,
|
||||
interleaved_freqs_cis,
|
||||
split_freqs_cis,
|
||||
)
|
||||
from torch import nn
|
||||
|
||||
|
||||
class BasicTransformerBlock1D(nn.Module):
|
||||
r"""
|
||||
A basic Transformer block.
|
||||
|
||||
Parameters:
|
||||
|
||||
dim (`int`): The number of channels in the input and output.
|
||||
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`): The number of channels in each head.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
||||
attention_bias (:
|
||||
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
|
||||
upcast_attention (`bool`, *optional*):
|
||||
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
|
||||
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use learnable elementwise affine parameters for normalization.
|
||||
standardization_norm (`str`, *optional*, defaults to `"layer_norm"`): The type of pre-normalization to use. Can be `"layer_norm"` or `"rms_norm"`.
|
||||
norm_eps (`float`, *optional*, defaults to 1e-5): Epsilon value for normalization layers.
|
||||
qk_norm (`str`, *optional*, defaults to None):
|
||||
Set to 'layer_norm' or `rms_norm` to perform query and key normalization.
|
||||
final_dropout (`bool` *optional*, defaults to False):
|
||||
Whether to apply a final dropout after the last feed-forward layer.
|
||||
ff_inner_dim (`int`, *optional*): Dimension of the inner feed-forward layer. If not provided, defaults to `dim * 4`.
|
||||
ff_bias (`bool`, *optional*, defaults to `True`): Whether to use bias in the feed-forward layer.
|
||||
attention_out_bias (`bool`, *optional*, defaults to `True`): Whether to use bias in the attention output layer.
|
||||
use_rope (`bool`, *optional*, defaults to `False`): Whether to use Rotary Position Embeddings (RoPE).
|
||||
ffn_dim_mult (`int`, *optional*, defaults to 4): Multiplier for the inner dimension of the feed-forward layer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
n_heads,
|
||||
d_head,
|
||||
context_dim=None,
|
||||
attn_precision=None,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Define 3 blocks. Each block has its own normalization layer.
|
||||
# 1. Self-Attn
|
||||
self.attn1 = CrossAttention(
|
||||
query_dim=dim,
|
||||
heads=n_heads,
|
||||
dim_head=d_head,
|
||||
context_dim=None,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
# 3. Feed-forward
|
||||
self.ff = FeedForward(
|
||||
dim,
|
||||
dim_out=dim,
|
||||
glu=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states, attention_mask=None, pe=None) -> torch.FloatTensor:
|
||||
|
||||
# Notice that normalization is always applied before the real computation in the following blocks.
|
||||
|
||||
# 1. Normalization Before Self-Attention
|
||||
norm_hidden_states = comfy.ldm.common_dit.rms_norm(hidden_states)
|
||||
|
||||
norm_hidden_states = norm_hidden_states.squeeze(1)
|
||||
|
||||
# 2. Self-Attention
|
||||
attn_output = self.attn1(norm_hidden_states, mask=attention_mask, pe=pe)
|
||||
|
||||
hidden_states = attn_output + hidden_states
|
||||
if hidden_states.ndim == 4:
|
||||
hidden_states = hidden_states.squeeze(1)
|
||||
|
||||
# 3. Normalization before Feed-Forward
|
||||
norm_hidden_states = comfy.ldm.common_dit.rms_norm(hidden_states)
|
||||
|
||||
# 4. Feed-forward
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
|
||||
hidden_states = ff_output + hidden_states
|
||||
if hidden_states.ndim == 4:
|
||||
hidden_states = hidden_states.squeeze(1)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Embeddings1DConnector(nn.Module):
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=128,
|
||||
cross_attention_dim=2048,
|
||||
attention_head_dim=128,
|
||||
num_attention_heads=30,
|
||||
num_layers=2,
|
||||
positional_embedding_theta=10000.0,
|
||||
positional_embedding_max_pos=[4096],
|
||||
causal_temporal_positioning=False,
|
||||
num_learnable_registers: Optional[int] = 128,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
split_rope=False,
|
||||
double_precision_rope=False,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
self.out_channels = in_channels
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
self.causal_temporal_positioning = causal_temporal_positioning
|
||||
self.positional_embedding_theta = positional_embedding_theta
|
||||
self.positional_embedding_max_pos = positional_embedding_max_pos
|
||||
self.split_rope = split_rope
|
||||
self.double_precision_rope = double_precision_rope
|
||||
self.transformer_1d_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock1D(
|
||||
self.inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
context_dim=cross_attention_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
self.num_learnable_registers = num_learnable_registers
|
||||
if self.num_learnable_registers:
|
||||
self.learnable_registers = nn.Parameter(
|
||||
torch.rand(
|
||||
self.num_learnable_registers, inner_dim, dtype=dtype, device=device
|
||||
)
|
||||
* 2.0
|
||||
- 1.0
|
||||
)
|
||||
|
||||
def get_fractional_positions(self, indices_grid):
|
||||
fractional_positions = torch.stack(
|
||||
[
|
||||
indices_grid[:, i] / self.positional_embedding_max_pos[i]
|
||||
for i in range(1)
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
return fractional_positions
|
||||
|
||||
def precompute_freqs(self, indices_grid, spacing):
|
||||
source_dtype = indices_grid.dtype
|
||||
dtype = (
|
||||
torch.float32
|
||||
if source_dtype in (torch.bfloat16, torch.float16)
|
||||
else source_dtype
|
||||
)
|
||||
|
||||
fractional_positions = self.get_fractional_positions(indices_grid)
|
||||
indices = (
|
||||
generate_freq_grid_np(
|
||||
self.positional_embedding_theta,
|
||||
indices_grid.shape[1],
|
||||
self.inner_dim,
|
||||
)
|
||||
if self.double_precision_rope
|
||||
else self.generate_freq_grid(spacing, dtype, fractional_positions.device)
|
||||
).to(device=fractional_positions.device)
|
||||
|
||||
if spacing == "exp_2":
|
||||
freqs = (
|
||||
(indices * fractional_positions.unsqueeze(-1))
|
||||
.transpose(-1, -2)
|
||||
.flatten(2)
|
||||
)
|
||||
else:
|
||||
freqs = (
|
||||
(indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
|
||||
.transpose(-1, -2)
|
||||
.flatten(2)
|
||||
)
|
||||
return freqs
|
||||
|
||||
def generate_freq_grid(self, spacing, dtype, device):
|
||||
dim = self.inner_dim
|
||||
theta = self.positional_embedding_theta
|
||||
n_pos_dims = 1
|
||||
n_elem = 2 * n_pos_dims # 2 for cos and sin e.g. x 3 = 6
|
||||
start = 1
|
||||
end = theta
|
||||
|
||||
if spacing == "exp":
|
||||
indices = theta ** (torch.arange(0, dim, n_elem, device="cpu", dtype=torch.float32) / (dim - n_elem))
|
||||
indices = indices.to(dtype=dtype, device=device)
|
||||
elif spacing == "exp_2":
|
||||
indices = 1.0 / theta ** (torch.arange(0, dim, n_elem, device=device) / dim)
|
||||
indices = indices.to(dtype=dtype)
|
||||
elif spacing == "linear":
|
||||
indices = torch.linspace(
|
||||
start, end, dim // n_elem, device=device, dtype=dtype
|
||||
)
|
||||
elif spacing == "sqrt":
|
||||
indices = torch.linspace(
|
||||
start**2, end**2, dim // n_elem, device=device, dtype=dtype
|
||||
).sqrt()
|
||||
|
||||
indices = indices * math.pi / 2
|
||||
|
||||
return indices
|
||||
|
||||
def precompute_freqs_cis(self, indices_grid, spacing="exp"):
|
||||
dim = self.inner_dim
|
||||
n_elem = 2 # 2 because of cos and sin
|
||||
freqs = self.precompute_freqs(indices_grid, spacing)
|
||||
if self.split_rope:
|
||||
expected_freqs = dim // 2
|
||||
current_freqs = freqs.shape[-1]
|
||||
pad_size = expected_freqs - current_freqs
|
||||
cos_freq, sin_freq = split_freqs_cis(
|
||||
freqs, pad_size, self.num_attention_heads
|
||||
)
|
||||
else:
|
||||
cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem)
|
||||
return cos_freq.to(self.dtype), sin_freq.to(self.dtype), self.split_rope
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
):
|
||||
"""
|
||||
The [`Transformer2DModel`] forward method.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
|
||||
Input `hidden_states`.
|
||||
indices_grid (`torch.LongTensor` of shape `(batch size, 3, num latent pixels)`):
|
||||
attention_mask ( `torch.Tensor`, *optional*):
|
||||
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
||||
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
||||
negative values to the attention scores corresponding to "discard" tokens.
|
||||
Returns:
|
||||
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||
`tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
# 1. Input
|
||||
|
||||
if self.num_learnable_registers:
|
||||
num_registers_duplications = math.ceil(
|
||||
max(1024, hidden_states.shape[1]) / self.num_learnable_registers
|
||||
)
|
||||
learnable_registers = torch.tile(
|
||||
self.learnable_registers.to(hidden_states), (num_registers_duplications, 1)
|
||||
)
|
||||
|
||||
hidden_states = torch.cat((hidden_states, learnable_registers[hidden_states.shape[1]:].unsqueeze(0).repeat(hidden_states.shape[0], 1, 1)), dim=1)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = torch.zeros([1, 1, 1, hidden_states.shape[1]], dtype=attention_mask.dtype, device=attention_mask.device)
|
||||
|
||||
indices_grid = torch.arange(
|
||||
hidden_states.shape[1], dtype=torch.float32, device=hidden_states.device
|
||||
)
|
||||
indices_grid = indices_grid[None, None, :]
|
||||
freqs_cis = self.precompute_freqs_cis(indices_grid)
|
||||
|
||||
# 2. Blocks
|
||||
for block_idx, block in enumerate(self.transformer_1d_blocks):
|
||||
hidden_states = block(
|
||||
hidden_states, attention_mask=attention_mask, pe=freqs_cis
|
||||
)
|
||||
|
||||
# 3. Output
|
||||
# if self.output_scale is not None:
|
||||
# hidden_states = hidden_states / self.output_scale
|
||||
|
||||
hidden_states = comfy.ldm.common_dit.rms_norm(hidden_states)
|
||||
|
||||
return hidden_states, attention_mask
|
||||
292
comfy/ldm/lightricks/latent_upsampler.py
Normal file
292
comfy/ldm/lightricks/latent_upsampler.py
Normal file
@@ -0,0 +1,292 @@
|
||||
from typing import Optional, Tuple
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
def _rational_for_scale(scale: float) -> Tuple[int, int]:
|
||||
mapping = {0.75: (3, 4), 1.5: (3, 2), 2.0: (2, 1), 4.0: (4, 1)}
|
||||
if float(scale) not in mapping:
|
||||
raise ValueError(
|
||||
f"Unsupported spatial_scale {scale}. Choose from {list(mapping.keys())}"
|
||||
)
|
||||
return mapping[float(scale)]
|
||||
|
||||
|
||||
class PixelShuffleND(nn.Module):
|
||||
def __init__(self, dims, upscale_factors=(2, 2, 2)):
|
||||
super().__init__()
|
||||
assert dims in [1, 2, 3], "dims must be 1, 2, or 3"
|
||||
self.dims = dims
|
||||
self.upscale_factors = upscale_factors
|
||||
|
||||
def forward(self, x):
|
||||
if self.dims == 3:
|
||||
return rearrange(
|
||||
x,
|
||||
"b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
|
||||
p1=self.upscale_factors[0],
|
||||
p2=self.upscale_factors[1],
|
||||
p3=self.upscale_factors[2],
|
||||
)
|
||||
elif self.dims == 2:
|
||||
return rearrange(
|
||||
x,
|
||||
"b (c p1 p2) h w -> b c (h p1) (w p2)",
|
||||
p1=self.upscale_factors[0],
|
||||
p2=self.upscale_factors[1],
|
||||
)
|
||||
elif self.dims == 1:
|
||||
return rearrange(
|
||||
x,
|
||||
"b (c p1) f h w -> b c (f p1) h w",
|
||||
p1=self.upscale_factors[0],
|
||||
)
|
||||
|
||||
|
||||
class BlurDownsample(nn.Module):
|
||||
"""
|
||||
Anti-aliased spatial downsampling by integer stride using a fixed separable binomial kernel.
|
||||
Applies only on H,W. Works for dims=2 or dims=3 (per-frame).
|
||||
"""
|
||||
|
||||
def __init__(self, dims: int, stride: int):
|
||||
super().__init__()
|
||||
assert dims in (2, 3)
|
||||
assert stride >= 1 and isinstance(stride, int)
|
||||
self.dims = dims
|
||||
self.stride = stride
|
||||
|
||||
# 5x5 separable binomial kernel [1,4,6,4,1] (outer product), normalized
|
||||
k = torch.tensor([1.0, 4.0, 6.0, 4.0, 1.0])
|
||||
k2d = k[:, None] @ k[None, :]
|
||||
k2d = (k2d / k2d.sum()).float() # shape (5,5)
|
||||
self.register_buffer("kernel", k2d[None, None, :, :]) # (1,1,5,5)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if self.stride == 1:
|
||||
return x
|
||||
|
||||
def _apply_2d(x2d: torch.Tensor) -> torch.Tensor:
|
||||
# x2d: (B, C, H, W)
|
||||
B, C, H, W = x2d.shape
|
||||
weight = self.kernel.expand(C, 1, 5, 5) # depthwise
|
||||
x2d = F.conv2d(
|
||||
x2d, weight=weight, bias=None, stride=self.stride, padding=2, groups=C
|
||||
)
|
||||
return x2d
|
||||
|
||||
if self.dims == 2:
|
||||
return _apply_2d(x)
|
||||
else:
|
||||
# dims == 3: apply per-frame on H,W
|
||||
b, c, f, h, w = x.shape
|
||||
x = rearrange(x, "b c f h w -> (b f) c h w")
|
||||
x = _apply_2d(x)
|
||||
h2, w2 = x.shape[-2:]
|
||||
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f, h=h2, w=w2)
|
||||
return x
|
||||
|
||||
|
||||
class SpatialRationalResampler(nn.Module):
|
||||
"""
|
||||
Fully-learned rational spatial scaling: up by 'num' via PixelShuffle, then anti-aliased
|
||||
downsample by 'den' using fixed blur + stride. Operates on H,W only.
|
||||
|
||||
For dims==3, work per-frame for spatial scaling (temporal axis untouched).
|
||||
"""
|
||||
|
||||
def __init__(self, mid_channels: int, scale: float):
|
||||
super().__init__()
|
||||
self.scale = float(scale)
|
||||
self.num, self.den = _rational_for_scale(self.scale)
|
||||
self.conv = nn.Conv2d(
|
||||
mid_channels, (self.num**2) * mid_channels, kernel_size=3, padding=1
|
||||
)
|
||||
self.pixel_shuffle = PixelShuffleND(2, upscale_factors=(self.num, self.num))
|
||||
self.blur_down = BlurDownsample(dims=2, stride=self.den)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
b, c, f, h, w = x.shape
|
||||
x = rearrange(x, "b c f h w -> (b f) c h w")
|
||||
x = self.conv(x)
|
||||
x = self.pixel_shuffle(x)
|
||||
x = self.blur_down(x)
|
||||
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
|
||||
return x
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
def __init__(
|
||||
self, channels: int, mid_channels: Optional[int] = None, dims: int = 3
|
||||
):
|
||||
super().__init__()
|
||||
if mid_channels is None:
|
||||
mid_channels = channels
|
||||
|
||||
Conv = nn.Conv2d if dims == 2 else nn.Conv3d
|
||||
|
||||
self.conv1 = Conv(channels, mid_channels, kernel_size=3, padding=1)
|
||||
self.norm1 = nn.GroupNorm(32, mid_channels)
|
||||
self.conv2 = Conv(mid_channels, channels, kernel_size=3, padding=1)
|
||||
self.norm2 = nn.GroupNorm(32, channels)
|
||||
self.activation = nn.SiLU()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
residual = x
|
||||
x = self.conv1(x)
|
||||
x = self.norm1(x)
|
||||
x = self.activation(x)
|
||||
x = self.conv2(x)
|
||||
x = self.norm2(x)
|
||||
x = self.activation(x + residual)
|
||||
return x
|
||||
|
||||
|
||||
class LatentUpsampler(nn.Module):
|
||||
"""
|
||||
Model to spatially upsample VAE latents.
|
||||
|
||||
Args:
|
||||
in_channels (`int`): Number of channels in the input latent
|
||||
mid_channels (`int`): Number of channels in the middle layers
|
||||
num_blocks_per_stage (`int`): Number of ResBlocks to use in each stage (pre/post upsampling)
|
||||
dims (`int`): Number of dimensions for convolutions (2 or 3)
|
||||
spatial_upsample (`bool`): Whether to spatially upsample the latent
|
||||
temporal_upsample (`bool`): Whether to temporally upsample the latent
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 128,
|
||||
mid_channels: int = 512,
|
||||
num_blocks_per_stage: int = 4,
|
||||
dims: int = 3,
|
||||
spatial_upsample: bool = True,
|
||||
temporal_upsample: bool = False,
|
||||
spatial_scale: float = 2.0,
|
||||
rational_resampler: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.mid_channels = mid_channels
|
||||
self.num_blocks_per_stage = num_blocks_per_stage
|
||||
self.dims = dims
|
||||
self.spatial_upsample = spatial_upsample
|
||||
self.temporal_upsample = temporal_upsample
|
||||
self.spatial_scale = float(spatial_scale)
|
||||
self.rational_resampler = rational_resampler
|
||||
|
||||
Conv = nn.Conv2d if dims == 2 else nn.Conv3d
|
||||
|
||||
self.initial_conv = Conv(in_channels, mid_channels, kernel_size=3, padding=1)
|
||||
self.initial_norm = nn.GroupNorm(32, mid_channels)
|
||||
self.initial_activation = nn.SiLU()
|
||||
|
||||
self.res_blocks = nn.ModuleList(
|
||||
[ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
|
||||
)
|
||||
|
||||
if spatial_upsample and temporal_upsample:
|
||||
self.upsampler = nn.Sequential(
|
||||
nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1),
|
||||
PixelShuffleND(3),
|
||||
)
|
||||
elif spatial_upsample:
|
||||
if rational_resampler:
|
||||
self.upsampler = SpatialRationalResampler(
|
||||
mid_channels=mid_channels, scale=self.spatial_scale
|
||||
)
|
||||
else:
|
||||
self.upsampler = nn.Sequential(
|
||||
nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1),
|
||||
PixelShuffleND(2),
|
||||
)
|
||||
elif temporal_upsample:
|
||||
self.upsampler = nn.Sequential(
|
||||
nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1),
|
||||
PixelShuffleND(1),
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Either spatial_upsample or temporal_upsample must be True"
|
||||
)
|
||||
|
||||
self.post_upsample_res_blocks = nn.ModuleList(
|
||||
[ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
|
||||
)
|
||||
|
||||
self.final_conv = Conv(mid_channels, in_channels, kernel_size=3, padding=1)
|
||||
|
||||
def forward(self, latent: torch.Tensor) -> torch.Tensor:
|
||||
b, c, f, h, w = latent.shape
|
||||
|
||||
if self.dims == 2:
|
||||
x = rearrange(latent, "b c f h w -> (b f) c h w")
|
||||
x = self.initial_conv(x)
|
||||
x = self.initial_norm(x)
|
||||
x = self.initial_activation(x)
|
||||
|
||||
for block in self.res_blocks:
|
||||
x = block(x)
|
||||
|
||||
x = self.upsampler(x)
|
||||
|
||||
for block in self.post_upsample_res_blocks:
|
||||
x = block(x)
|
||||
|
||||
x = self.final_conv(x)
|
||||
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
|
||||
else:
|
||||
x = self.initial_conv(latent)
|
||||
x = self.initial_norm(x)
|
||||
x = self.initial_activation(x)
|
||||
|
||||
for block in self.res_blocks:
|
||||
x = block(x)
|
||||
|
||||
if self.temporal_upsample:
|
||||
x = self.upsampler(x)
|
||||
x = x[:, :, 1:, :, :]
|
||||
else:
|
||||
if isinstance(self.upsampler, SpatialRationalResampler):
|
||||
x = self.upsampler(x)
|
||||
else:
|
||||
x = rearrange(x, "b c f h w -> (b f) c h w")
|
||||
x = self.upsampler(x)
|
||||
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
|
||||
|
||||
for block in self.post_upsample_res_blocks:
|
||||
x = block(x)
|
||||
|
||||
x = self.final_conv(x)
|
||||
|
||||
return x
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config):
|
||||
return cls(
|
||||
in_channels=config.get("in_channels", 4),
|
||||
mid_channels=config.get("mid_channels", 128),
|
||||
num_blocks_per_stage=config.get("num_blocks_per_stage", 4),
|
||||
dims=config.get("dims", 2),
|
||||
spatial_upsample=config.get("spatial_upsample", True),
|
||||
temporal_upsample=config.get("temporal_upsample", False),
|
||||
spatial_scale=config.get("spatial_scale", 2.0),
|
||||
rational_resampler=config.get("rational_resampler", False),
|
||||
)
|
||||
|
||||
def config(self):
|
||||
return {
|
||||
"_class_name": "LatentUpsampler",
|
||||
"in_channels": self.in_channels,
|
||||
"mid_channels": self.mid_channels,
|
||||
"num_blocks_per_stage": self.num_blocks_per_stage,
|
||||
"dims": self.dims,
|
||||
"spatial_upsample": self.spatial_upsample,
|
||||
"temporal_upsample": self.temporal_upsample,
|
||||
"spatial_scale": self.spatial_scale,
|
||||
"rational_resampler": self.rational_resampler,
|
||||
}
|
||||
@@ -1,13 +1,47 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
import functools
|
||||
import math
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
from einops import rearrange
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
import comfy.patcher_extension
|
||||
import comfy.ldm.modules.attention
|
||||
import comfy.ldm.common_dit
|
||||
import math
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
|
||||
from comfy.ldm.flux.math import apply_rope1
|
||||
|
||||
def _log_base(x, base):
|
||||
return np.log(x) / np.log(base)
|
||||
|
||||
class LTXRopeType(str, Enum):
|
||||
INTERLEAVED = "interleaved"
|
||||
SPLIT = "split"
|
||||
|
||||
KEY = "rope_type"
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, kwargs, default=None):
|
||||
if default is None:
|
||||
default = cls.INTERLEAVED
|
||||
return cls(kwargs.get(cls.KEY, default))
|
||||
|
||||
|
||||
class LTXFrequenciesPrecision(str, Enum):
|
||||
FLOAT32 = "float32"
|
||||
FLOAT64 = "float64"
|
||||
|
||||
KEY = "frequencies_precision"
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, kwargs, default=None):
|
||||
if default is None:
|
||||
default = cls.FLOAT32
|
||||
return cls(kwargs.get(cls.KEY, default))
|
||||
|
||||
|
||||
def get_timestep_embedding(
|
||||
timesteps: torch.Tensor,
|
||||
@@ -39,9 +73,7 @@ def get_timestep_embedding(
|
||||
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
|
||||
|
||||
half_dim = embedding_dim // 2
|
||||
exponent = -math.log(max_period) * torch.arange(
|
||||
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
|
||||
)
|
||||
exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
|
||||
exponent = exponent / (half_dim - downscale_freq_shift)
|
||||
|
||||
emb = torch.exp(exponent)
|
||||
@@ -73,7 +105,9 @@ class TimestepEmbedding(nn.Module):
|
||||
post_act_fn: Optional[str] = None,
|
||||
cond_proj_dim=None,
|
||||
sample_proj_bias=True,
|
||||
dtype=None, device=None, operations=None,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -90,7 +124,9 @@ class TimestepEmbedding(nn.Module):
|
||||
time_embed_dim_out = out_dim
|
||||
else:
|
||||
time_embed_dim_out = time_embed_dim
|
||||
self.linear_2 = operations.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias, dtype=dtype, device=device)
|
||||
self.linear_2 = operations.Linear(
|
||||
time_embed_dim, time_embed_dim_out, sample_proj_bias, dtype=dtype, device=device
|
||||
)
|
||||
|
||||
if post_act_fn is None:
|
||||
self.post_act = None
|
||||
@@ -139,12 +175,22 @@ class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
|
||||
https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
|
||||
"""
|
||||
|
||||
def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False, dtype=None, device=None, operations=None):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim,
|
||||
size_emb_dim,
|
||||
use_additional_conditions: bool = False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.outdim = size_emb_dim
|
||||
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim, dtype=dtype, device=device, operations=operations)
|
||||
self.timestep_embedder = TimestepEmbedding(
|
||||
in_channels=256, time_embed_dim=embedding_dim, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
|
||||
timesteps_proj = self.time_proj(timestep)
|
||||
@@ -163,15 +209,22 @@ class AdaLayerNormSingle(nn.Module):
|
||||
use_additional_conditions (`bool`): To use additional conditions for normalization or not.
|
||||
"""
|
||||
|
||||
def __init__(self, embedding_dim: int, use_additional_conditions: bool = False, dtype=None, device=None, operations=None):
|
||||
def __init__(
|
||||
self, embedding_dim: int, embedding_coefficient: int = 6, use_additional_conditions: bool = False, dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
|
||||
embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions, dtype=dtype, device=device, operations=operations
|
||||
embedding_dim,
|
||||
size_emb_dim=embedding_dim // 3,
|
||||
use_additional_conditions=use_additional_conditions,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = operations.Linear(embedding_dim, 6 * embedding_dim, bias=True, dtype=dtype, device=device)
|
||||
self.linear = operations.Linear(embedding_dim, embedding_coefficient * embedding_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -185,6 +238,7 @@ class AdaLayerNormSingle(nn.Module):
|
||||
embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
|
||||
return self.linear(self.silu(embedded_timestep)), embedded_timestep
|
||||
|
||||
|
||||
class PixArtAlphaTextProjection(nn.Module):
|
||||
"""
|
||||
Projects caption embeddings. Also handles dropout for classifier-free guidance.
|
||||
@@ -192,18 +246,24 @@ class PixArtAlphaTextProjection(nn.Module):
|
||||
Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
|
||||
"""
|
||||
|
||||
def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh", dtype=None, device=None, operations=None):
|
||||
def __init__(
|
||||
self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh", dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
if out_features is None:
|
||||
out_features = hidden_size
|
||||
self.linear_1 = operations.Linear(in_features=in_features, out_features=hidden_size, bias=True, dtype=dtype, device=device)
|
||||
self.linear_1 = operations.Linear(
|
||||
in_features=in_features, out_features=hidden_size, bias=True, dtype=dtype, device=device
|
||||
)
|
||||
if act_fn == "gelu_tanh":
|
||||
self.act_1 = nn.GELU(approximate="tanh")
|
||||
elif act_fn == "silu":
|
||||
self.act_1 = nn.SiLU()
|
||||
else:
|
||||
raise ValueError(f"Unknown activation function: {act_fn}")
|
||||
self.linear_2 = operations.Linear(in_features=hidden_size, out_features=out_features, bias=True, dtype=dtype, device=device)
|
||||
self.linear_2 = operations.Linear(
|
||||
in_features=hidden_size, out_features=out_features, bias=True, dtype=dtype, device=device
|
||||
)
|
||||
|
||||
def forward(self, caption):
|
||||
hidden_states = self.linear_1(caption)
|
||||
@@ -222,23 +282,68 @@ class GELU_approx(nn.Module):
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, dim_out, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=None):
|
||||
def __init__(self, dim, dim_out, mult=4, glu=False, dropout=0.0, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
project_in = GELU_approx(dim, inner_dim, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.net = nn.Sequential(
|
||||
project_in,
|
||||
nn.Dropout(dropout),
|
||||
operations.Linear(inner_dim, dim_out, dtype=dtype, device=device)
|
||||
project_in, nn.Dropout(dropout), operations.Linear(inner_dim, dim_out, dtype=dtype, device=device)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
def apply_rotary_emb(input_tensor, freqs_cis):
|
||||
cos_freqs, sin_freqs = freqs_cis[0], freqs_cis[1]
|
||||
split_pe = freqs_cis[2] if len(freqs_cis) > 2 else False
|
||||
return (
|
||||
apply_split_rotary_emb(input_tensor, cos_freqs, sin_freqs)
|
||||
if split_pe else
|
||||
apply_interleaved_rotary_emb(input_tensor, cos_freqs, sin_freqs)
|
||||
)
|
||||
|
||||
def apply_interleaved_rotary_emb(input_tensor, cos_freqs, sin_freqs): # TODO: remove duplicate funcs and pick the best/fastest one
|
||||
t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2)
|
||||
t1, t2 = t_dup.unbind(dim=-1)
|
||||
t_dup = torch.stack((-t2, t1), dim=-1)
|
||||
input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)")
|
||||
|
||||
out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs
|
||||
|
||||
return out
|
||||
|
||||
def apply_split_rotary_emb(input_tensor, cos, sin):
|
||||
needs_reshape = False
|
||||
if input_tensor.ndim != 4 and cos.ndim == 4:
|
||||
B, H, T, _ = cos.shape
|
||||
input_tensor = input_tensor.reshape(B, T, H, -1).swapaxes(1, 2)
|
||||
needs_reshape = True
|
||||
split_input = rearrange(input_tensor, "... (d r) -> ... d r", d=2)
|
||||
first_half_input = split_input[..., :1, :]
|
||||
second_half_input = split_input[..., 1:, :]
|
||||
output = split_input * cos.unsqueeze(-2)
|
||||
first_half_output = output[..., :1, :]
|
||||
second_half_output = output[..., 1:, :]
|
||||
first_half_output.addcmul_(-sin.unsqueeze(-2), second_half_input)
|
||||
second_half_output.addcmul_(sin.unsqueeze(-2), first_half_input)
|
||||
output = rearrange(output, "... d r -> ... (d r)")
|
||||
return output.swapaxes(1, 2).reshape(B, T, -1) if needs_reshape else output
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=None):
|
||||
def __init__(
|
||||
self,
|
||||
query_dim,
|
||||
context_dim=None,
|
||||
heads=8,
|
||||
dim_head=64,
|
||||
dropout=0.0,
|
||||
attn_precision=None,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
context_dim = query_dim if context_dim is None else context_dim
|
||||
@@ -254,9 +359,11 @@ class CrossAttention(nn.Module):
|
||||
self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
|
||||
self.to_v = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
|
||||
self.to_out = nn.Sequential(
|
||||
operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x, context=None, mask=None, pe=None, transformer_options={}):
|
||||
def forward(self, x, context=None, mask=None, pe=None, k_pe=None, transformer_options={}):
|
||||
q = self.to_q(x)
|
||||
context = x if context is None else context
|
||||
k = self.to_k(context)
|
||||
@@ -266,8 +373,8 @@ class CrossAttention(nn.Module):
|
||||
k = self.k_norm(k)
|
||||
|
||||
if pe is not None:
|
||||
q = apply_rope1(q.unsqueeze(1), pe).squeeze(1)
|
||||
k = apply_rope1(k.unsqueeze(1), pe).squeeze(1)
|
||||
q = apply_rotary_emb(q, pe)
|
||||
k = apply_rotary_emb(k, pe if k_pe is None else k_pe)
|
||||
|
||||
if mask is None:
|
||||
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
||||
@@ -277,14 +384,34 @@ class CrossAttention(nn.Module):
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
def __init__(self, dim, n_heads, d_head, context_dim=None, attn_precision=None, dtype=None, device=None, operations=None):
|
||||
def __init__(
|
||||
self, dim, n_heads, d_head, context_dim=None, attn_precision=None, dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.attn_precision = attn_precision
|
||||
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, context_dim=None, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations)
|
||||
self.attn1 = CrossAttention(
|
||||
query_dim=dim,
|
||||
heads=n_heads,
|
||||
dim_head=d_head,
|
||||
context_dim=None,
|
||||
attn_precision=self.attn_precision,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
self.ff = FeedForward(dim, dim_out=dim, glu=True, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations)
|
||||
self.attn2 = CrossAttention(
|
||||
query_dim=dim,
|
||||
context_dim=context_dim,
|
||||
heads=n_heads,
|
||||
dim_head=d_head,
|
||||
attn_precision=self.attn_precision,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))
|
||||
|
||||
@@ -306,116 +433,446 @@ class BasicTransformerBlock(nn.Module):
|
||||
return x
|
||||
|
||||
def get_fractional_positions(indices_grid, max_pos):
|
||||
n_pos_dims = indices_grid.shape[1]
|
||||
assert n_pos_dims == len(max_pos), f'Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})'
|
||||
fractional_positions = torch.stack(
|
||||
[
|
||||
indices_grid[:, i] / max_pos[i]
|
||||
for i in range(3)
|
||||
],
|
||||
dim=-1,
|
||||
[indices_grid[:, i] / max_pos[i] for i in range(n_pos_dims)],
|
||||
axis=-1,
|
||||
)
|
||||
return fractional_positions
|
||||
|
||||
|
||||
def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[20, 2048, 2048]):
|
||||
dtype = torch.float32
|
||||
device = indices_grid.device
|
||||
@functools.lru_cache(maxsize=5)
|
||||
def generate_freq_grid_np(positional_embedding_theta, positional_embedding_max_pos_count, inner_dim, _ = None):
|
||||
theta = positional_embedding_theta
|
||||
start = 1
|
||||
end = theta
|
||||
|
||||
n_elem = 2 * positional_embedding_max_pos_count
|
||||
pow_indices = np.power(
|
||||
theta,
|
||||
np.linspace(
|
||||
_log_base(start, theta),
|
||||
_log_base(end, theta),
|
||||
inner_dim // n_elem,
|
||||
dtype=np.float64,
|
||||
),
|
||||
)
|
||||
return torch.tensor(pow_indices * math.pi / 2, dtype=torch.float32)
|
||||
|
||||
def generate_freq_grid_pytorch(positional_embedding_theta, positional_embedding_max_pos_count, inner_dim, device):
|
||||
theta = positional_embedding_theta
|
||||
start = 1
|
||||
end = theta
|
||||
n_elem = 2 * positional_embedding_max_pos_count
|
||||
|
||||
indices = theta ** (
|
||||
torch.linspace(
|
||||
math.log(start, theta),
|
||||
math.log(end, theta),
|
||||
inner_dim // n_elem,
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
)
|
||||
indices = indices.to(dtype=torch.float32)
|
||||
|
||||
indices = indices * math.pi / 2
|
||||
|
||||
return indices
|
||||
|
||||
def generate_freqs(indices, indices_grid, max_pos, use_middle_indices_grid):
|
||||
if use_middle_indices_grid:
|
||||
assert(len(indices_grid.shape) == 4 and indices_grid.shape[-1] ==2)
|
||||
indices_grid_start, indices_grid_end = indices_grid[..., 0], indices_grid[..., 1]
|
||||
indices_grid = (indices_grid_start + indices_grid_end) / 2.0
|
||||
elif len(indices_grid.shape) == 4:
|
||||
indices_grid = indices_grid[..., 0]
|
||||
|
||||
# Get fractional positions and compute frequency indices
|
||||
fractional_positions = get_fractional_positions(indices_grid, max_pos)
|
||||
indices = theta ** torch.linspace(0, 1, dim // 6, device=device, dtype=dtype) * math.pi / 2
|
||||
indices = indices.to(device=fractional_positions.device)
|
||||
|
||||
# Compute frequencies and apply cos/sin
|
||||
freqs = (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)).transpose(-1, -2).flatten(2)
|
||||
cos_vals = freqs.cos().repeat_interleave(2, dim=-1)
|
||||
sin_vals = freqs.sin().repeat_interleave(2, dim=-1)
|
||||
freqs = (
|
||||
(indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
|
||||
.transpose(-1, -2)
|
||||
.flatten(2)
|
||||
)
|
||||
return freqs
|
||||
|
||||
# Pad if dim is not divisible by 6
|
||||
if dim % 6 != 0:
|
||||
padding_size = dim % 6
|
||||
cos_vals = torch.cat([torch.ones_like(cos_vals[:, :, :padding_size]), cos_vals], dim=-1)
|
||||
sin_vals = torch.cat([torch.zeros_like(sin_vals[:, :, :padding_size]), sin_vals], dim=-1)
|
||||
def interleaved_freqs_cis(freqs, pad_size):
|
||||
cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
|
||||
sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
|
||||
if pad_size != 0:
|
||||
cos_padding = torch.ones_like(cos_freq[:, :, : pad_size])
|
||||
sin_padding = torch.zeros_like(cos_freq[:, :, : pad_size])
|
||||
cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
|
||||
sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
|
||||
return cos_freq, sin_freq
|
||||
|
||||
# Reshape and extract one value per pair (since repeat_interleave duplicates each value)
|
||||
cos_vals = cos_vals.reshape(*cos_vals.shape[:2], -1, 2)[..., 0].to(out_dtype) # [B, N, dim//2]
|
||||
sin_vals = sin_vals.reshape(*sin_vals.shape[:2], -1, 2)[..., 0].to(out_dtype) # [B, N, dim//2]
|
||||
def split_freqs_cis(freqs, pad_size, num_attention_heads):
|
||||
cos_freq = freqs.cos()
|
||||
sin_freq = freqs.sin()
|
||||
|
||||
# Build rotation matrix [[cos, -sin], [sin, cos]] and add heads dimension
|
||||
freqs_cis = torch.stack([
|
||||
torch.stack([cos_vals, -sin_vals], dim=-1),
|
||||
torch.stack([sin_vals, cos_vals], dim=-1)
|
||||
], dim=-2).unsqueeze(1) # [B, 1, N, dim//2, 2, 2]
|
||||
if pad_size != 0:
|
||||
cos_padding = torch.ones_like(cos_freq[:, :, :pad_size])
|
||||
sin_padding = torch.zeros_like(sin_freq[:, :, :pad_size])
|
||||
|
||||
return freqs_cis
|
||||
cos_freq = torch.concatenate([cos_padding, cos_freq], axis=-1)
|
||||
sin_freq = torch.concatenate([sin_padding, sin_freq], axis=-1)
|
||||
|
||||
# Reshape freqs to be compatible with multi-head attention
|
||||
B , T, half_HD = cos_freq.shape
|
||||
|
||||
class LTXVModel(torch.nn.Module):
|
||||
def __init__(self,
|
||||
in_channels=128,
|
||||
cross_attention_dim=2048,
|
||||
attention_head_dim=64,
|
||||
num_attention_heads=32,
|
||||
cos_freq = cos_freq.reshape(B, T, num_attention_heads, half_HD // num_attention_heads)
|
||||
sin_freq = sin_freq.reshape(B, T, num_attention_heads, half_HD // num_attention_heads)
|
||||
|
||||
caption_channels=4096,
|
||||
num_layers=28,
|
||||
cos_freq = torch.swapaxes(cos_freq, 1, 2) # (B,H,T,D//2)
|
||||
sin_freq = torch.swapaxes(sin_freq, 1, 2) # (B,H,T,D//2)
|
||||
return cos_freq, sin_freq
|
||||
|
||||
class LTXBaseModel(torch.nn.Module, ABC):
|
||||
"""
|
||||
Abstract base class for LTX models (Lightricks Transformer models).
|
||||
|
||||
positional_embedding_theta=10000.0,
|
||||
positional_embedding_max_pos=[20, 2048, 2048],
|
||||
causal_temporal_positioning=False,
|
||||
vae_scale_factors=(8, 32, 32),
|
||||
dtype=None, device=None, operations=None, **kwargs):
|
||||
This class defines the common interface and shared functionality for all LTX models,
|
||||
including LTXV (video) and LTXAV (audio-video) variants.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
cross_attention_dim: int,
|
||||
attention_head_dim: int,
|
||||
num_attention_heads: int,
|
||||
caption_channels: int,
|
||||
num_layers: int,
|
||||
positional_embedding_theta: float = 10000.0,
|
||||
positional_embedding_max_pos: list = [20, 2048, 2048],
|
||||
causal_temporal_positioning: bool = False,
|
||||
vae_scale_factors: tuple = (8, 32, 32),
|
||||
use_middle_indices_grid=False,
|
||||
timestep_scale_multiplier = 1000.0,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.generator = None
|
||||
self.vae_scale_factors = vae_scale_factors
|
||||
self.use_middle_indices_grid = use_middle_indices_grid
|
||||
self.dtype = dtype
|
||||
self.out_channels = in_channels
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
self.in_channels = in_channels
|
||||
self.cross_attention_dim = cross_attention_dim
|
||||
self.attention_head_dim = attention_head_dim
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.caption_channels = caption_channels
|
||||
self.num_layers = num_layers
|
||||
self.positional_embedding_theta = positional_embedding_theta
|
||||
self.positional_embedding_max_pos = positional_embedding_max_pos
|
||||
self.split_positional_embedding = LTXRopeType.from_dict(kwargs)
|
||||
self.freq_grid_generator = (
|
||||
generate_freq_grid_np if LTXFrequenciesPrecision.from_dict(kwargs) == LTXFrequenciesPrecision.FLOAT64
|
||||
else generate_freq_grid_pytorch
|
||||
)
|
||||
self.causal_temporal_positioning = causal_temporal_positioning
|
||||
self.operations = operations
|
||||
self.timestep_scale_multiplier = timestep_scale_multiplier
|
||||
|
||||
self.patchify_proj = operations.Linear(in_channels, self.inner_dim, bias=True, dtype=dtype, device=device)
|
||||
# Common dimensions
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
self.out_channels = in_channels
|
||||
|
||||
# Initialize common components
|
||||
self._init_common_components(device, dtype)
|
||||
|
||||
# Initialize model-specific components
|
||||
self._init_model_components(device, dtype, **kwargs)
|
||||
|
||||
# Initialize transformer blocks
|
||||
self._init_transformer_blocks(device, dtype, **kwargs)
|
||||
|
||||
# Initialize output components
|
||||
self._init_output_components(device, dtype)
|
||||
|
||||
def _init_common_components(self, device, dtype):
|
||||
"""Initialize components common to all LTX models
|
||||
- patchify_proj: Linear projection for patchifying input
|
||||
- adaln_single: AdaLN layer for timestep embedding
|
||||
- caption_projection: Linear projection for caption embedding
|
||||
"""
|
||||
self.patchify_proj = self.operations.Linear(
|
||||
self.in_channels, self.inner_dim, bias=True, dtype=dtype, device=device
|
||||
)
|
||||
|
||||
self.adaln_single = AdaLayerNormSingle(
|
||||
self.inner_dim, use_additional_conditions=False, dtype=dtype, device=device, operations=operations
|
||||
self.inner_dim, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations
|
||||
)
|
||||
|
||||
# self.adaln_single.linear = operations.Linear(self.inner_dim, 4 * self.inner_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
self.caption_projection = PixArtAlphaTextProjection(
|
||||
in_features=caption_channels, hidden_size=self.inner_dim, dtype=dtype, device=device, operations=operations
|
||||
in_features=self.caption_channels,
|
||||
hidden_size=self.inner_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def _init_model_components(self, device, dtype, **kwargs):
|
||||
"""Initialize model-specific components. Must be implemented by subclasses."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _init_transformer_blocks(self, device, dtype, **kwargs):
|
||||
"""Initialize transformer blocks. Must be implemented by subclasses."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _init_output_components(self, device, dtype):
|
||||
"""Initialize output components. Must be implemented by subclasses."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs):
|
||||
"""Process input data. Must be implemented by subclasses."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, **kwargs):
|
||||
"""Process transformer blocks. Must be implemented by subclasses."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):
|
||||
"""Process output data. Must be implemented by subclasses."""
|
||||
pass
|
||||
|
||||
def _prepare_timestep(self, timestep, batch_size, hidden_dtype, **kwargs):
|
||||
"""Prepare timestep embeddings."""
|
||||
grid_mask = kwargs.get("grid_mask", None)
|
||||
if grid_mask is not None:
|
||||
timestep = timestep[:, grid_mask]
|
||||
|
||||
timestep = timestep * self.timestep_scale_multiplier
|
||||
timestep, embedded_timestep = self.adaln_single(
|
||||
timestep.flatten(),
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
|
||||
# Second dimension is 1 or number of tokens (if timestep_per_token)
|
||||
timestep = timestep.view(batch_size, -1, timestep.shape[-1])
|
||||
embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.shape[-1])
|
||||
|
||||
return timestep, embedded_timestep
|
||||
|
||||
def _prepare_context(self, context, batch_size, x, attention_mask=None):
|
||||
"""Prepare context for transformer blocks."""
|
||||
if self.caption_projection is not None:
|
||||
context = self.caption_projection(context)
|
||||
context = context.view(batch_size, -1, x.shape[-1])
|
||||
|
||||
return context, attention_mask
|
||||
|
||||
def _precompute_freqs_cis(
|
||||
self,
|
||||
indices_grid,
|
||||
dim,
|
||||
out_dtype,
|
||||
theta=10000.0,
|
||||
max_pos=[20, 2048, 2048],
|
||||
use_middle_indices_grid=False,
|
||||
num_attention_heads=32,
|
||||
):
|
||||
split_mode = self.split_positional_embedding == LTXRopeType.SPLIT
|
||||
indices = self.freq_grid_generator(theta, indices_grid.shape[1], dim, indices_grid.device)
|
||||
freqs = generate_freqs(indices, indices_grid, max_pos, use_middle_indices_grid)
|
||||
|
||||
if split_mode:
|
||||
expected_freqs = dim // 2
|
||||
current_freqs = freqs.shape[-1]
|
||||
pad_size = expected_freqs - current_freqs
|
||||
cos_freq, sin_freq = split_freqs_cis(freqs, pad_size, num_attention_heads)
|
||||
else:
|
||||
# 2 because of cos and sin by 3 for (t, x, y), 1 for temporal only
|
||||
n_elem = 2 * indices_grid.shape[1]
|
||||
cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem)
|
||||
return cos_freq.to(out_dtype), sin_freq.to(out_dtype), split_mode
|
||||
|
||||
def _prepare_positional_embeddings(self, pixel_coords, frame_rate, x_dtype):
|
||||
"""Prepare positional embeddings."""
|
||||
fractional_coords = pixel_coords.to(torch.float32)
|
||||
fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate)
|
||||
pe = self._precompute_freqs_cis(
|
||||
fractional_coords,
|
||||
dim=self.inner_dim,
|
||||
out_dtype=x_dtype,
|
||||
max_pos=self.positional_embedding_max_pos,
|
||||
use_middle_indices_grid=self.use_middle_indices_grid,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
)
|
||||
return pe
|
||||
|
||||
def _prepare_attention_mask(self, attention_mask, x_dtype):
|
||||
"""Prepare attention mask."""
|
||||
if attention_mask is not None and not torch.is_floating_point(attention_mask):
|
||||
attention_mask = (attention_mask - 1).to(x_dtype).reshape(
|
||||
(attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
|
||||
) * torch.finfo(x_dtype).max
|
||||
return attention_mask
|
||||
|
||||
def forward(
|
||||
self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, denoise_mask=None, **kwargs
|
||||
):
|
||||
"""
|
||||
Forward pass for LTX models.
|
||||
|
||||
Args:
|
||||
x: Input tensor
|
||||
timestep: Timestep tensor
|
||||
context: Context tensor (e.g., text embeddings)
|
||||
attention_mask: Attention mask tensor
|
||||
frame_rate: Frame rate for temporal processing
|
||||
transformer_options: Additional options for transformer blocks
|
||||
keyframe_idxs: Keyframe indices for temporal processing
|
||||
**kwargs: Additional keyword arguments
|
||||
|
||||
Returns:
|
||||
Processed output tensor
|
||||
"""
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(
|
||||
comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options
|
||||
),
|
||||
).execute(x, timestep, context, attention_mask, frame_rate, transformer_options, keyframe_idxs, denoise_mask=denoise_mask, **kwargs)
|
||||
|
||||
def _forward(
|
||||
self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, denoise_mask=None, **kwargs
|
||||
):
|
||||
"""
|
||||
Internal forward pass for LTX models.
|
||||
|
||||
Args:
|
||||
x: Input tensor
|
||||
timestep: Timestep tensor
|
||||
context: Context tensor (e.g., text embeddings)
|
||||
attention_mask: Attention mask tensor
|
||||
frame_rate: Frame rate for temporal processing
|
||||
transformer_options: Additional options for transformer blocks
|
||||
keyframe_idxs: Keyframe indices for temporal processing
|
||||
**kwargs: Additional keyword arguments
|
||||
|
||||
Returns:
|
||||
Processed output tensor
|
||||
"""
|
||||
if isinstance(x, list):
|
||||
input_dtype = x[0].dtype
|
||||
batch_size = x[0].shape[0]
|
||||
else:
|
||||
input_dtype = x.dtype
|
||||
batch_size = x.shape[0]
|
||||
# Process input
|
||||
merged_args = {**transformer_options, **kwargs}
|
||||
x, pixel_coords, additional_args = self._process_input(x, keyframe_idxs, denoise_mask, **merged_args)
|
||||
merged_args.update(additional_args)
|
||||
|
||||
# Prepare timestep and context
|
||||
timestep, embedded_timestep = self._prepare_timestep(timestep, batch_size, input_dtype, **merged_args)
|
||||
context, attention_mask = self._prepare_context(context, batch_size, x, attention_mask)
|
||||
|
||||
# Prepare attention mask and positional embeddings
|
||||
attention_mask = self._prepare_attention_mask(attention_mask, input_dtype)
|
||||
pe = self._prepare_positional_embeddings(pixel_coords, frame_rate, input_dtype)
|
||||
|
||||
# Process transformer blocks
|
||||
x = self._process_transformer_blocks(
|
||||
x, context, attention_mask, timestep, pe, transformer_options=transformer_options, **merged_args
|
||||
)
|
||||
|
||||
# Process output
|
||||
x = self._process_output(x, embedded_timestep, keyframe_idxs, **merged_args)
|
||||
return x
|
||||
|
||||
|
||||
class LTXVModel(LTXBaseModel):
|
||||
"""LTXV model for video generation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=128,
|
||||
cross_attention_dim=2048,
|
||||
attention_head_dim=64,
|
||||
num_attention_heads=32,
|
||||
caption_channels=4096,
|
||||
num_layers=28,
|
||||
positional_embedding_theta=10000.0,
|
||||
positional_embedding_max_pos=[20, 2048, 2048],
|
||||
causal_temporal_positioning=False,
|
||||
vae_scale_factors=(8, 32, 32),
|
||||
use_middle_indices_grid=False,
|
||||
timestep_scale_multiplier = 1000.0,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
in_channels=in_channels,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attention_head_dim=attention_head_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
caption_channels=caption_channels,
|
||||
num_layers=num_layers,
|
||||
positional_embedding_theta=positional_embedding_theta,
|
||||
positional_embedding_max_pos=positional_embedding_max_pos,
|
||||
causal_temporal_positioning=causal_temporal_positioning,
|
||||
vae_scale_factors=vae_scale_factors,
|
||||
use_middle_indices_grid=use_middle_indices_grid,
|
||||
timestep_scale_multiplier=timestep_scale_multiplier,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _init_model_components(self, device, dtype, **kwargs):
|
||||
"""Initialize LTXV-specific components."""
|
||||
# No additional components needed for LTXV beyond base class
|
||||
pass
|
||||
|
||||
def _init_transformer_blocks(self, device, dtype, **kwargs):
|
||||
"""Initialize transformer blocks for LTXV."""
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
self.inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
context_dim=cross_attention_dim,
|
||||
# attn_precision=attn_precision,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
self.num_attention_heads,
|
||||
self.attention_head_dim,
|
||||
context_dim=self.cross_attention_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
for d in range(num_layers)
|
||||
for _ in range(self.num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
def _init_output_components(self, device, dtype):
|
||||
"""Initialize output components for LTXV."""
|
||||
self.scale_shift_table = nn.Parameter(torch.empty(2, self.inner_dim, dtype=dtype, device=device))
|
||||
self.norm_out = operations.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.proj_out = operations.Linear(self.inner_dim, self.out_channels, dtype=dtype, device=device)
|
||||
|
||||
self.patchifier = SymmetricPatchifier(1)
|
||||
|
||||
def forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timestep, context, attention_mask, frame_rate, transformer_options, keyframe_idxs, **kwargs)
|
||||
|
||||
def _forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs):
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
|
||||
orig_shape = list(x.shape)
|
||||
self.norm_out = self.operations.LayerNorm(
|
||||
self.inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device
|
||||
)
|
||||
self.proj_out = self.operations.Linear(self.inner_dim, self.out_channels, dtype=dtype, device=device)
|
||||
self.patchifier = SymmetricPatchifier(1, start_end=True)
|
||||
|
||||
def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs):
|
||||
"""Process input for LTXV."""
|
||||
additional_args = {"orig_shape": list(x.shape)}
|
||||
x, latent_coords = self.patchifier.patchify(x)
|
||||
pixel_coords = latent_to_pixel_coords(
|
||||
latent_coords=latent_coords,
|
||||
@@ -423,44 +880,30 @@ class LTXVModel(torch.nn.Module):
|
||||
causal_fix=self.causal_temporal_positioning,
|
||||
)
|
||||
|
||||
grid_mask = None
|
||||
if keyframe_idxs is not None:
|
||||
pixel_coords[:, :, -keyframe_idxs.shape[2]:] = keyframe_idxs
|
||||
additional_args.update({ "orig_patchified_shape": list(x.shape)})
|
||||
denoise_mask = self.patchifier.patchify(denoise_mask)[0]
|
||||
grid_mask = ~torch.any(denoise_mask < 0, dim=-1)[0]
|
||||
additional_args.update({"grid_mask": grid_mask})
|
||||
x = x[:, grid_mask, :]
|
||||
pixel_coords = pixel_coords[:, :, grid_mask, ...]
|
||||
|
||||
fractional_coords = pixel_coords.to(torch.float32)
|
||||
fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate)
|
||||
kf_grid_mask = grid_mask[-keyframe_idxs.shape[2]:]
|
||||
keyframe_idxs = keyframe_idxs[..., kf_grid_mask, :]
|
||||
pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs
|
||||
|
||||
x = self.patchify_proj(x)
|
||||
timestep = timestep * 1000.0
|
||||
|
||||
if attention_mask is not None and not torch.is_floating_point(attention_mask):
|
||||
attention_mask = (attention_mask - 1).to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(x.dtype).max
|
||||
|
||||
pe = precompute_freqs_cis(fractional_coords, dim=self.inner_dim, out_dtype=x.dtype)
|
||||
|
||||
batch_size = x.shape[0]
|
||||
timestep, embedded_timestep = self.adaln_single(
|
||||
timestep.flatten(),
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=x.dtype,
|
||||
)
|
||||
# Second dimension is 1 or number of tokens (if timestep_per_token)
|
||||
timestep = timestep.view(batch_size, -1, timestep.shape[-1])
|
||||
embedded_timestep = embedded_timestep.view(
|
||||
batch_size, -1, embedded_timestep.shape[-1]
|
||||
)
|
||||
|
||||
# 2. Blocks
|
||||
if self.caption_projection is not None:
|
||||
batch_size = x.shape[0]
|
||||
context = self.caption_projection(context)
|
||||
context = context.view(
|
||||
batch_size, -1, x.shape[-1]
|
||||
)
|
||||
return x, pixel_coords, additional_args
|
||||
|
||||
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs):
|
||||
"""Process transformer blocks for LTXV."""
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"])
|
||||
@@ -478,16 +921,28 @@ class LTXVModel(torch.nn.Module):
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
# 3. Output
|
||||
return x
|
||||
|
||||
def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):
|
||||
"""Process output for LTXV."""
|
||||
# Apply scale-shift modulation
|
||||
scale_shift_values = (
|
||||
self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None]
|
||||
)
|
||||
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
|
||||
|
||||
x = self.norm_out(x)
|
||||
# Modulation
|
||||
x = torch.addcmul(x, x, scale).add_(shift)
|
||||
x = x * (1 + scale) + shift
|
||||
x = self.proj_out(x)
|
||||
|
||||
if keyframe_idxs is not None:
|
||||
grid_mask = kwargs["grid_mask"]
|
||||
orig_patchified_shape = kwargs["orig_patchified_shape"]
|
||||
full_x = torch.zeros(orig_patchified_shape, dtype=x.dtype, device=x.device)
|
||||
full_x[:, grid_mask, :] = x
|
||||
x = full_x
|
||||
# Unpatchify to restore original dimensions
|
||||
orig_shape = kwargs["orig_shape"]
|
||||
x = self.patchifier.unpatchify(
|
||||
latents=x,
|
||||
output_height=orig_shape[3],
|
||||
|
||||
@@ -21,20 +21,23 @@ def latent_to_pixel_coords(
|
||||
Returns:
|
||||
Tensor: A tensor of pixel coordinates corresponding to the input latent coordinates.
|
||||
"""
|
||||
shape = [1] * latent_coords.ndim
|
||||
shape[1] = -1
|
||||
pixel_coords = (
|
||||
latent_coords
|
||||
* torch.tensor(scale_factors, device=latent_coords.device)[None, :, None]
|
||||
* torch.tensor(scale_factors, device=latent_coords.device).view(*shape)
|
||||
)
|
||||
if causal_fix:
|
||||
# Fix temporal scale for first frame to 1 due to causality
|
||||
pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - scale_factors[0]).clamp(min=0)
|
||||
pixel_coords[:, 0, ...] = (pixel_coords[:, 0, ...] + 1 - scale_factors[0]).clamp(min=0)
|
||||
return pixel_coords
|
||||
|
||||
|
||||
class Patchifier(ABC):
|
||||
def __init__(self, patch_size: int):
|
||||
def __init__(self, patch_size: int, start_end: bool=False):
|
||||
super().__init__()
|
||||
self._patch_size = (1, patch_size, patch_size)
|
||||
self.start_end = start_end
|
||||
|
||||
@abstractmethod
|
||||
def patchify(
|
||||
@@ -71,11 +74,23 @@ class Patchifier(ABC):
|
||||
torch.arange(0, latent_width, self._patch_size[2], device=device),
|
||||
indexing="ij",
|
||||
)
|
||||
latent_sample_coords = torch.stack(latent_sample_coords, dim=0)
|
||||
latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
|
||||
latent_coords = rearrange(
|
||||
latent_coords, "b c f h w -> b c (f h w)", b=batch_size
|
||||
latent_sample_coords_start = torch.stack(latent_sample_coords, dim=0)
|
||||
delta = torch.tensor(self._patch_size, device=latent_sample_coords_start.device, dtype=latent_sample_coords_start.dtype)[:, None, None, None]
|
||||
latent_sample_coords_end = latent_sample_coords_start + delta
|
||||
|
||||
latent_sample_coords_start = latent_sample_coords_start.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
|
||||
latent_sample_coords_start = rearrange(
|
||||
latent_sample_coords_start, "b c f h w -> b c (f h w)", b=batch_size
|
||||
)
|
||||
if self.start_end:
|
||||
latent_sample_coords_end = latent_sample_coords_end.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
|
||||
latent_sample_coords_end = rearrange(
|
||||
latent_sample_coords_end, "b c f h w -> b c (f h w)", b=batch_size
|
||||
)
|
||||
|
||||
latent_coords = torch.stack((latent_sample_coords_start, latent_sample_coords_end), dim=-1)
|
||||
else:
|
||||
latent_coords = latent_sample_coords_start
|
||||
return latent_coords
|
||||
|
||||
|
||||
@@ -115,3 +130,61 @@ class SymmetricPatchifier(Patchifier):
|
||||
q=self._patch_size[2],
|
||||
)
|
||||
return latents
|
||||
|
||||
|
||||
class AudioPatchifier(Patchifier):
|
||||
def __init__(self, patch_size: int,
|
||||
sample_rate=16000,
|
||||
hop_length=160,
|
||||
audio_latent_downsample_factor=4,
|
||||
is_causal=True,
|
||||
start_end=False,
|
||||
shift = 0
|
||||
):
|
||||
super().__init__(patch_size, start_end=start_end)
|
||||
self.hop_length = hop_length
|
||||
self.sample_rate = sample_rate
|
||||
self.audio_latent_downsample_factor = audio_latent_downsample_factor
|
||||
self.is_causal = is_causal
|
||||
self.shift = shift
|
||||
|
||||
def copy_with_shift(self, shift):
|
||||
return AudioPatchifier(
|
||||
self.patch_size, self.sample_rate, self.hop_length, self.audio_latent_downsample_factor,
|
||||
self.is_causal, self.start_end, shift
|
||||
)
|
||||
|
||||
def _get_audio_latent_time_in_sec(self, start_latent, end_latent: int, dtype: torch.dtype, device=torch.device):
|
||||
audio_latent_frame = torch.arange(start_latent, end_latent, dtype=dtype, device=device)
|
||||
audio_mel_frame = audio_latent_frame * self.audio_latent_downsample_factor
|
||||
if self.is_causal:
|
||||
audio_mel_frame = (audio_mel_frame + 1 - self.audio_latent_downsample_factor).clip(min=0)
|
||||
return audio_mel_frame * self.hop_length / self.sample_rate
|
||||
|
||||
|
||||
def patchify(self, audio_latents: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# audio_latents: (batch, channels, time, freq)
|
||||
b, _, t, _ = audio_latents.shape
|
||||
audio_latents = rearrange(
|
||||
audio_latents,
|
||||
"b c t f -> b t (c f)",
|
||||
)
|
||||
|
||||
audio_latents_start_timings = self._get_audio_latent_time_in_sec(self.shift, t + self.shift, torch.float32, audio_latents.device)
|
||||
audio_latents_start_timings = audio_latents_start_timings.unsqueeze(0).expand(b, -1).unsqueeze(1)
|
||||
|
||||
if self.start_end:
|
||||
audio_latents_end_timings = self._get_audio_latent_time_in_sec(self.shift + 1, t + self.shift + 1, torch.float32, audio_latents.device)
|
||||
audio_latents_end_timings = audio_latents_end_timings.unsqueeze(0).expand(b, -1).unsqueeze(1)
|
||||
|
||||
audio_latents_timings = torch.stack([audio_latents_start_timings, audio_latents_end_timings], dim=-1)
|
||||
else:
|
||||
audio_latents_timings = audio_latents_start_timings
|
||||
return audio_latents, audio_latents_timings
|
||||
|
||||
def unpatchify(self, audio_latents: torch.Tensor, channels: int, freq: int) -> torch.Tensor:
|
||||
# audio_latents: (batch, time, freq * channels)
|
||||
audio_latents = rearrange(
|
||||
audio_latents, "b t (c f) -> b c t f", c=channels, f=freq
|
||||
)
|
||||
return audio_latents
|
||||
|
||||
279
comfy/ldm/lightricks/vae/audio_vae.py
Normal file
279
comfy/ldm/lightricks/vae/audio_vae.py
Normal file
@@ -0,0 +1,279 @@
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
import math
|
||||
import torch
|
||||
import torchaudio
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.model_patcher
|
||||
import comfy.utils as utils
|
||||
from comfy.ldm.mmaudio.vae.distributions import DiagonalGaussianDistribution
|
||||
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
|
||||
from comfy.ldm.lightricks.vae.causal_audio_autoencoder import (
|
||||
CausalityAxis,
|
||||
CausalAudioAutoencoder,
|
||||
)
|
||||
from comfy.ldm.lightricks.vocoders.vocoder import Vocoder
|
||||
|
||||
LATENT_DOWNSAMPLE_FACTOR = 4
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AudioVAEComponentConfig:
|
||||
"""Container for model component configuration extracted from metadata."""
|
||||
|
||||
autoencoder: dict
|
||||
vocoder: dict
|
||||
|
||||
@classmethod
|
||||
def from_metadata(cls, metadata: dict) -> "AudioVAEComponentConfig":
|
||||
assert metadata is not None and "config" in metadata, "Metadata is required for audio VAE"
|
||||
|
||||
raw_config = metadata["config"]
|
||||
if isinstance(raw_config, str):
|
||||
parsed_config = json.loads(raw_config)
|
||||
else:
|
||||
parsed_config = raw_config
|
||||
|
||||
audio_config = parsed_config.get("audio_vae")
|
||||
vocoder_config = parsed_config.get("vocoder")
|
||||
|
||||
assert audio_config is not None, "Audio VAE config is required for audio VAE"
|
||||
assert vocoder_config is not None, "Vocoder config is required for audio VAE"
|
||||
|
||||
return cls(autoencoder=audio_config, vocoder=vocoder_config)
|
||||
|
||||
|
||||
class ModelDeviceManager:
|
||||
"""Manages device placement and GPU residency for the composed model."""
|
||||
|
||||
def __init__(self, module: torch.nn.Module):
|
||||
load_device = comfy.model_management.get_torch_device()
|
||||
offload_device = comfy.model_management.vae_offload_device()
|
||||
self.patcher = comfy.model_patcher.ModelPatcher(module, load_device, offload_device)
|
||||
|
||||
def ensure_model_loaded(self) -> None:
|
||||
comfy.model_management.free_memory(
|
||||
self.patcher.model_size(),
|
||||
self.patcher.load_device,
|
||||
)
|
||||
comfy.model_management.load_model_gpu(self.patcher)
|
||||
|
||||
def move_to_load_device(self, tensor: torch.Tensor) -> torch.Tensor:
|
||||
return tensor.to(self.patcher.load_device)
|
||||
|
||||
@property
|
||||
def load_device(self):
|
||||
return self.patcher.load_device
|
||||
|
||||
|
||||
class AudioLatentNormalizer:
|
||||
"""Applies per-channel statistics in patch space and restores original layout."""
|
||||
|
||||
def __init__(self, patchfier: AudioPatchifier, statistics_processor: torch.nn.Module):
|
||||
self.patchifier = patchfier
|
||||
self.statistics = statistics_processor
|
||||
|
||||
def normalize(self, latents: torch.Tensor) -> torch.Tensor:
|
||||
channels = latents.shape[1]
|
||||
freq = latents.shape[3]
|
||||
patched, _ = self.patchifier.patchify(latents)
|
||||
normalized = self.statistics.normalize(patched)
|
||||
return self.patchifier.unpatchify(normalized, channels=channels, freq=freq)
|
||||
|
||||
def denormalize(self, latents: torch.Tensor) -> torch.Tensor:
|
||||
channels = latents.shape[1]
|
||||
freq = latents.shape[3]
|
||||
patched, _ = self.patchifier.patchify(latents)
|
||||
denormalized = self.statistics.un_normalize(patched)
|
||||
return self.patchifier.unpatchify(denormalized, channels=channels, freq=freq)
|
||||
|
||||
|
||||
class AudioPreprocessor:
|
||||
"""Prepares raw waveforms for the autoencoder by matching training conditions."""
|
||||
|
||||
def __init__(self, target_sample_rate: int, mel_bins: int, mel_hop_length: int, n_fft: int):
|
||||
self.target_sample_rate = target_sample_rate
|
||||
self.mel_bins = mel_bins
|
||||
self.mel_hop_length = mel_hop_length
|
||||
self.n_fft = n_fft
|
||||
|
||||
def resample(self, waveform: torch.Tensor, source_rate: int) -> torch.Tensor:
|
||||
if source_rate == self.target_sample_rate:
|
||||
return waveform
|
||||
return torchaudio.functional.resample(waveform, source_rate, self.target_sample_rate)
|
||||
|
||||
def waveform_to_mel(
|
||||
self, waveform: torch.Tensor, waveform_sample_rate: int, device
|
||||
) -> torch.Tensor:
|
||||
waveform = self.resample(waveform, waveform_sample_rate)
|
||||
|
||||
mel_transform = torchaudio.transforms.MelSpectrogram(
|
||||
sample_rate=self.target_sample_rate,
|
||||
n_fft=self.n_fft,
|
||||
win_length=self.n_fft,
|
||||
hop_length=self.mel_hop_length,
|
||||
f_min=0.0,
|
||||
f_max=self.target_sample_rate / 2.0,
|
||||
n_mels=self.mel_bins,
|
||||
window_fn=torch.hann_window,
|
||||
center=True,
|
||||
pad_mode="reflect",
|
||||
power=1.0,
|
||||
mel_scale="slaney",
|
||||
norm="slaney",
|
||||
).to(device)
|
||||
|
||||
mel = mel_transform(waveform)
|
||||
mel = torch.log(torch.clamp(mel, min=1e-5))
|
||||
return mel.permute(0, 1, 3, 2).contiguous()
|
||||
|
||||
|
||||
class AudioVAE(torch.nn.Module):
|
||||
"""High-level Audio VAE wrapper exposing encode and decode entry points."""
|
||||
|
||||
def __init__(self, state_dict: dict, metadata: dict):
|
||||
super().__init__()
|
||||
|
||||
component_config = AudioVAEComponentConfig.from_metadata(metadata)
|
||||
|
||||
vae_sd = utils.state_dict_prefix_replace(state_dict, {"audio_vae.": ""}, filter_keys=True)
|
||||
vocoder_sd = utils.state_dict_prefix_replace(state_dict, {"vocoder.": ""}, filter_keys=True)
|
||||
|
||||
self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder)
|
||||
self.vocoder = Vocoder(config=component_config.vocoder)
|
||||
|
||||
self.autoencoder.load_state_dict(vae_sd, strict=False)
|
||||
self.vocoder.load_state_dict(vocoder_sd, strict=False)
|
||||
|
||||
autoencoder_config = self.autoencoder.get_config()
|
||||
self.normalizer = AudioLatentNormalizer(
|
||||
AudioPatchifier(
|
||||
patch_size=1,
|
||||
audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
|
||||
sample_rate=autoencoder_config["sampling_rate"],
|
||||
hop_length=autoencoder_config["mel_hop_length"],
|
||||
is_causal=autoencoder_config["is_causal"],
|
||||
),
|
||||
self.autoencoder.per_channel_statistics,
|
||||
)
|
||||
|
||||
self.preprocessor = AudioPreprocessor(
|
||||
target_sample_rate=autoencoder_config["sampling_rate"],
|
||||
mel_bins=autoencoder_config["mel_bins"],
|
||||
mel_hop_length=autoencoder_config["mel_hop_length"],
|
||||
n_fft=autoencoder_config["n_fft"],
|
||||
)
|
||||
|
||||
self.device_manager = ModelDeviceManager(self)
|
||||
|
||||
def encode(self, audio: dict) -> torch.Tensor:
|
||||
"""Encode a waveform dictionary into normalized latent tensors."""
|
||||
|
||||
waveform = audio["waveform"]
|
||||
waveform_sample_rate = audio["sample_rate"]
|
||||
input_device = waveform.device
|
||||
# Ensure that Audio VAE is loaded on the correct device.
|
||||
self.device_manager.ensure_model_loaded()
|
||||
|
||||
waveform = self.device_manager.move_to_load_device(waveform)
|
||||
expected_channels = self.autoencoder.encoder.in_channels
|
||||
if waveform.shape[1] != expected_channels:
|
||||
if waveform.shape[1] == 1:
|
||||
waveform = waveform.expand(-1, expected_channels, *waveform.shape[2:])
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Input audio must have {expected_channels} channels, got {waveform.shape[1]}"
|
||||
)
|
||||
|
||||
mel_spec = self.preprocessor.waveform_to_mel(
|
||||
waveform, waveform_sample_rate, device=self.device_manager.load_device
|
||||
)
|
||||
|
||||
latents = self.autoencoder.encode(mel_spec)
|
||||
posterior = DiagonalGaussianDistribution(latents)
|
||||
latent_mode = posterior.mode()
|
||||
|
||||
normalized = self.normalizer.normalize(latent_mode)
|
||||
return normalized.to(input_device)
|
||||
|
||||
def decode(self, latents: torch.Tensor) -> torch.Tensor:
|
||||
"""Decode normalized latent tensors into an audio waveform."""
|
||||
original_shape = latents.shape
|
||||
|
||||
# Ensure that Audio VAE is loaded on the correct device.
|
||||
self.device_manager.ensure_model_loaded()
|
||||
|
||||
latents = self.device_manager.move_to_load_device(latents)
|
||||
latents = self.normalizer.denormalize(latents)
|
||||
|
||||
target_shape = self.target_shape_from_latents(original_shape)
|
||||
mel_spec = self.autoencoder.decode(latents, target_shape=target_shape)
|
||||
|
||||
waveform = self.run_vocoder(mel_spec)
|
||||
return self.device_manager.move_to_load_device(waveform)
|
||||
|
||||
def target_shape_from_latents(self, latents_shape):
|
||||
batch, _, time, _ = latents_shape
|
||||
target_length = time * LATENT_DOWNSAMPLE_FACTOR
|
||||
if self.autoencoder.causality_axis != CausalityAxis.NONE:
|
||||
target_length -= LATENT_DOWNSAMPLE_FACTOR - 1
|
||||
return (
|
||||
batch,
|
||||
self.autoencoder.decoder.out_ch,
|
||||
target_length,
|
||||
self.autoencoder.mel_bins,
|
||||
)
|
||||
|
||||
def num_of_latents_from_frames(self, frames_number: int, frame_rate: int) -> int:
|
||||
return math.ceil((float(frames_number) / frame_rate) * self.latents_per_second)
|
||||
|
||||
def run_vocoder(self, mel_spec: torch.Tensor) -> torch.Tensor:
|
||||
audio_channels = self.autoencoder.decoder.out_ch
|
||||
vocoder_input = mel_spec.transpose(2, 3)
|
||||
|
||||
if audio_channels == 1:
|
||||
vocoder_input = vocoder_input.squeeze(1)
|
||||
elif audio_channels != 2:
|
||||
raise ValueError(f"Unsupported audio_channels: {audio_channels}")
|
||||
|
||||
return self.vocoder(vocoder_input)
|
||||
|
||||
@property
|
||||
def sample_rate(self) -> int:
|
||||
return int(self.autoencoder.sampling_rate)
|
||||
|
||||
@property
|
||||
def mel_hop_length(self) -> int:
|
||||
return int(self.autoencoder.mel_hop_length)
|
||||
|
||||
@property
|
||||
def mel_bins(self) -> int:
|
||||
return int(self.autoencoder.mel_bins)
|
||||
|
||||
@property
|
||||
def latent_channels(self) -> int:
|
||||
return int(self.autoencoder.decoder.z_channels)
|
||||
|
||||
@property
|
||||
def latent_frequency_bins(self) -> int:
|
||||
return int(self.mel_bins // LATENT_DOWNSAMPLE_FACTOR)
|
||||
|
||||
@property
|
||||
def latents_per_second(self) -> float:
|
||||
return self.sample_rate / self.mel_hop_length / LATENT_DOWNSAMPLE_FACTOR
|
||||
|
||||
@property
|
||||
def output_sample_rate(self) -> int:
|
||||
output_rate = getattr(self.vocoder, "output_sample_rate", None)
|
||||
if output_rate is not None:
|
||||
return int(output_rate)
|
||||
upsample_factor = getattr(self.vocoder, "upsample_factor", None)
|
||||
if upsample_factor is None:
|
||||
raise AttributeError(
|
||||
"Vocoder is missing upsample_factor; cannot infer output sample rate"
|
||||
)
|
||||
return int(self.sample_rate * upsample_factor / self.mel_hop_length)
|
||||
|
||||
def memory_required(self, input_shape):
|
||||
return self.device_manager.patcher.model_size()
|
||||
909
comfy/ldm/lightricks/vae/causal_audio_autoencoder.py
Normal file
909
comfy/ldm/lightricks/vae/causal_audio_autoencoder.py
Normal file
@@ -0,0 +1,909 @@
|
||||
from __future__ import annotations
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from typing import Optional
|
||||
from enum import Enum
|
||||
from .pixel_norm import PixelNorm
|
||||
import comfy.ops
|
||||
import logging
|
||||
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
|
||||
class StringConvertibleEnum(Enum):
|
||||
"""
|
||||
Base enum class that provides string-to-enum conversion functionality.
|
||||
|
||||
This mixin adds a str_to_enum() class method that handles conversion from
|
||||
strings, None, or existing enum instances with case-insensitive matching.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def str_to_enum(cls, value):
|
||||
"""
|
||||
Convert a string, enum instance, or None to the appropriate enum member.
|
||||
|
||||
Args:
|
||||
value: Can be an enum instance of this class, a string, or None
|
||||
|
||||
Returns:
|
||||
Enum member of this class
|
||||
|
||||
Raises:
|
||||
ValueError: If the value cannot be converted to a valid enum member
|
||||
"""
|
||||
# Already an enum instance of this class
|
||||
if isinstance(value, cls):
|
||||
return value
|
||||
|
||||
# None maps to NONE member if it exists
|
||||
if value is None:
|
||||
if hasattr(cls, "NONE"):
|
||||
return cls.NONE
|
||||
raise ValueError(f"{cls.__name__} does not have a NONE member to map None to")
|
||||
|
||||
# String conversion (case-insensitive)
|
||||
if isinstance(value, str):
|
||||
value_lower = value.lower()
|
||||
|
||||
# Try to match against enum values
|
||||
for member in cls:
|
||||
# Handle members with None values
|
||||
if member.value is None:
|
||||
if value_lower == "none":
|
||||
return member
|
||||
# Handle members with string values
|
||||
elif isinstance(member.value, str) and member.value.lower() == value_lower:
|
||||
return member
|
||||
|
||||
# Build helpful error message with valid values
|
||||
valid_values = []
|
||||
for member in cls:
|
||||
if member.value is None:
|
||||
valid_values.append("none")
|
||||
elif isinstance(member.value, str):
|
||||
valid_values.append(member.value)
|
||||
|
||||
raise ValueError(f"Invalid {cls.__name__} string: '{value}'. " f"Valid values are: {valid_values}")
|
||||
|
||||
raise ValueError(
|
||||
f"Cannot convert type {type(value).__name__} to {cls.__name__} enum. "
|
||||
f"Expected string, None, or {cls.__name__} instance."
|
||||
)
|
||||
|
||||
|
||||
class AttentionType(StringConvertibleEnum):
|
||||
"""Enum for specifying the attention mechanism type."""
|
||||
|
||||
VANILLA = "vanilla"
|
||||
LINEAR = "linear"
|
||||
NONE = "none"
|
||||
|
||||
|
||||
class CausalityAxis(StringConvertibleEnum):
|
||||
"""Enum for specifying the causality axis in causal convolutions."""
|
||||
|
||||
NONE = None
|
||||
WIDTH = "width"
|
||||
HEIGHT = "height"
|
||||
WIDTH_COMPATIBILITY = "width-compatibility"
|
||||
|
||||
|
||||
def Normalize(in_channels, *, num_groups=32, normtype="group"):
|
||||
if normtype == "group":
|
||||
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
elif normtype == "pixel":
|
||||
return PixelNorm(dim=1, eps=1e-6)
|
||||
else:
|
||||
raise ValueError(f"Invalid normalization type: {normtype}")
|
||||
|
||||
|
||||
class CausalConv2d(nn.Module):
|
||||
"""
|
||||
A causal 2D convolution.
|
||||
|
||||
This layer ensures that the output at time `t` only depends on inputs
|
||||
at time `t` and earlier. It achieves this by applying asymmetric padding
|
||||
to the time dimension (width) before the convolution.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=True,
|
||||
causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.causality_axis = causality_axis
|
||||
|
||||
# Ensure kernel_size and dilation are tuples
|
||||
kernel_size = nn.modules.utils._pair(kernel_size)
|
||||
dilation = nn.modules.utils._pair(dilation)
|
||||
|
||||
# Calculate padding dimensions
|
||||
pad_h = (kernel_size[0] - 1) * dilation[0]
|
||||
pad_w = (kernel_size[1] - 1) * dilation[1]
|
||||
|
||||
# The padding tuple for F.pad is (pad_left, pad_right, pad_top, pad_bottom)
|
||||
match self.causality_axis:
|
||||
case CausalityAxis.NONE:
|
||||
self.padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
|
||||
case CausalityAxis.WIDTH | CausalityAxis.WIDTH_COMPATIBILITY:
|
||||
self.padding = (pad_w, 0, pad_h // 2, pad_h - pad_h // 2)
|
||||
case CausalityAxis.HEIGHT:
|
||||
self.padding = (pad_w // 2, pad_w - pad_w // 2, pad_h, 0)
|
||||
case _:
|
||||
raise ValueError(f"Invalid causality_axis: {causality_axis}")
|
||||
|
||||
# The internal convolution layer uses no padding, as we handle it manually
|
||||
self.conv = ops.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=stride,
|
||||
padding=0,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
# Apply causal padding before convolution
|
||||
x = F.pad(x, self.padding)
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
def make_conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
padding=None,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=True,
|
||||
causality_axis: Optional[CausalityAxis] = None,
|
||||
):
|
||||
"""
|
||||
Create a 2D convolution layer that can be either causal or non-causal.
|
||||
|
||||
Args:
|
||||
in_channels: Number of input channels
|
||||
out_channels: Number of output channels
|
||||
kernel_size: Size of the convolution kernel
|
||||
stride: Convolution stride
|
||||
padding: Padding (if None, will be calculated based on causal flag)
|
||||
dilation: Dilation rate
|
||||
groups: Number of groups for grouped convolution
|
||||
bias: Whether to use bias
|
||||
causality_axis: Dimension along which to apply causality.
|
||||
|
||||
Returns:
|
||||
Either a regular Conv2d or CausalConv2d layer
|
||||
"""
|
||||
if causality_axis is not None:
|
||||
# For causal convolution, padding is handled internally by CausalConv2d
|
||||
return CausalConv2d(in_channels, out_channels, kernel_size, stride, dilation, groups, bias, causality_axis)
|
||||
else:
|
||||
# For non-causal convolution, use symmetric padding if not specified
|
||||
if padding is None:
|
||||
if isinstance(kernel_size, int):
|
||||
padding = kernel_size // 2
|
||||
else:
|
||||
padding = tuple(k // 2 for k in kernel_size)
|
||||
return ops.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
groups,
|
||||
bias,
|
||||
)
|
||||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
def __init__(self, in_channels, with_conv, causality_axis: CausalityAxis = CausalityAxis.HEIGHT):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
self.causality_axis = causality_axis
|
||||
if self.with_conv:
|
||||
self.conv = make_conv2d(in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis)
|
||||
|
||||
def forward(self, x):
|
||||
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
if self.with_conv:
|
||||
x = self.conv(x)
|
||||
# Drop FIRST element in the causal axis to undo encoder's padding, while keeping the length 1 + 2 * n.
|
||||
# For example, if the input is [0, 1, 2], after interpolation, the output is [0, 0, 1, 1, 2, 2].
|
||||
# The causal convolution will pad the first element as [-, -, 0, 0, 1, 1, 2, 2],
|
||||
# So the output elements rely on the following windows:
|
||||
# 0: [-,-,0]
|
||||
# 1: [-,0,0]
|
||||
# 2: [0,0,1]
|
||||
# 3: [0,1,1]
|
||||
# 4: [1,1,2]
|
||||
# 5: [1,2,2]
|
||||
# Notice that the first and second elements in the output rely only on the first element in the input,
|
||||
# while all other elements rely on two elements in the input.
|
||||
# So we can drop the first element to undo the padding (rather than the last element).
|
||||
# This is a no-op for non-causal convolutions.
|
||||
match self.causality_axis:
|
||||
case CausalityAxis.NONE:
|
||||
pass # x remains unchanged
|
||||
case CausalityAxis.HEIGHT:
|
||||
x = x[:, :, 1:, :]
|
||||
case CausalityAxis.WIDTH:
|
||||
x = x[:, :, :, 1:]
|
||||
case CausalityAxis.WIDTH_COMPATIBILITY:
|
||||
pass # x remains unchanged
|
||||
case _:
|
||||
raise ValueError(f"Invalid causality_axis: {self.causality_axis}")
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class Downsample(nn.Module):
|
||||
"""
|
||||
A downsampling layer that can use either a strided convolution
|
||||
or average pooling. Supports standard and causal padding for the
|
||||
convolutional mode.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, with_conv, causality_axis: CausalityAxis = CausalityAxis.WIDTH):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
self.causality_axis = causality_axis
|
||||
|
||||
if self.causality_axis != CausalityAxis.NONE and not self.with_conv:
|
||||
raise ValueError("causality is only supported when `with_conv=True`.")
|
||||
|
||||
if self.with_conv:
|
||||
# Do time downsampling here
|
||||
# no asymmetric padding in torch conv, must do it ourselves
|
||||
self.conv = ops.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
if self.with_conv:
|
||||
# (pad_left, pad_right, pad_top, pad_bottom)
|
||||
match self.causality_axis:
|
||||
case CausalityAxis.NONE:
|
||||
pad = (0, 1, 0, 1)
|
||||
case CausalityAxis.WIDTH:
|
||||
pad = (2, 0, 0, 1)
|
||||
case CausalityAxis.HEIGHT:
|
||||
pad = (0, 1, 2, 0)
|
||||
case CausalityAxis.WIDTH_COMPATIBILITY:
|
||||
pad = (1, 0, 0, 1)
|
||||
case _:
|
||||
raise ValueError(f"Invalid causality_axis: {self.causality_axis}")
|
||||
|
||||
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
||||
x = self.conv(x)
|
||||
else:
|
||||
# This branch is only taken if with_conv=False, which implies causality_axis is NONE.
|
||||
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class ResnetBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
in_channels,
|
||||
out_channels=None,
|
||||
conv_shortcut=False,
|
||||
dropout,
|
||||
temb_channels=512,
|
||||
norm_type="group",
|
||||
causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
|
||||
):
|
||||
super().__init__()
|
||||
self.causality_axis = causality_axis
|
||||
|
||||
if self.causality_axis != CausalityAxis.NONE and norm_type == "group":
|
||||
raise ValueError("Causal ResnetBlock with GroupNorm is not supported.")
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
self.use_conv_shortcut = conv_shortcut
|
||||
|
||||
self.norm1 = Normalize(in_channels, normtype=norm_type)
|
||||
self.non_linearity = nn.SiLU()
|
||||
self.conv1 = make_conv2d(in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis)
|
||||
if temb_channels > 0:
|
||||
self.temb_proj = ops.Linear(temb_channels, out_channels)
|
||||
self.norm2 = Normalize(out_channels, normtype=norm_type)
|
||||
self.dropout = torch.nn.Dropout(dropout)
|
||||
self.conv2 = make_conv2d(out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis)
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
self.conv_shortcut = make_conv2d(
|
||||
in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
|
||||
)
|
||||
else:
|
||||
self.nin_shortcut = make_conv2d(
|
||||
in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis
|
||||
)
|
||||
|
||||
def forward(self, x, temb):
|
||||
h = x
|
||||
h = self.norm1(h)
|
||||
h = self.non_linearity(h)
|
||||
h = self.conv1(h)
|
||||
|
||||
if temb is not None:
|
||||
h = h + self.temb_proj(self.non_linearity(temb))[:, :, None, None]
|
||||
|
||||
h = self.norm2(h)
|
||||
h = self.non_linearity(h)
|
||||
h = self.dropout(h)
|
||||
h = self.conv2(h)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
x = self.conv_shortcut(x)
|
||||
else:
|
||||
x = self.nin_shortcut(x)
|
||||
|
||||
return x + h
|
||||
|
||||
|
||||
class AttnBlock(nn.Module):
|
||||
def __init__(self, in_channels, norm_type="group"):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = Normalize(in_channels, normtype=norm_type)
|
||||
self.q = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.k = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.v = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.proj_out = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
q = self.q(h_)
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
# compute attention
|
||||
b, c, h, w = q.shape
|
||||
q = q.reshape(b, c, h * w).contiguous()
|
||||
q = q.permute(0, 2, 1).contiguous() # b,hw,c
|
||||
k = k.reshape(b, c, h * w).contiguous() # b,c,hw
|
||||
w_ = torch.bmm(q, k).contiguous() # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
||||
w_ = w_ * (int(c) ** (-0.5))
|
||||
w_ = torch.nn.functional.softmax(w_, dim=2)
|
||||
|
||||
# attend to values
|
||||
v = v.reshape(b, c, h * w).contiguous()
|
||||
w_ = w_.permute(0, 2, 1).contiguous() # b,hw,hw (first hw of k, second of q)
|
||||
h_ = torch.bmm(v, w_).contiguous() # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
||||
h_ = h_.reshape(b, c, h, w).contiguous()
|
||||
|
||||
h_ = self.proj_out(h_)
|
||||
|
||||
return x + h_
|
||||
|
||||
|
||||
def make_attn(in_channels, attn_type="vanilla", norm_type="group"):
|
||||
# Convert string to enum if needed
|
||||
attn_type = AttentionType.str_to_enum(attn_type)
|
||||
|
||||
if attn_type != AttentionType.NONE:
|
||||
logging.info(f"making attention of type '{attn_type.value}' with {in_channels} in_channels")
|
||||
else:
|
||||
logging.info(f"making identity attention with {in_channels} in_channels")
|
||||
|
||||
match attn_type:
|
||||
case AttentionType.VANILLA:
|
||||
return AttnBlock(in_channels, norm_type=norm_type)
|
||||
case AttentionType.NONE:
|
||||
return nn.Identity(in_channels)
|
||||
case AttentionType.LINEAR:
|
||||
raise NotImplementedError(f"Attention type {attn_type.value} is not supported yet.")
|
||||
case _:
|
||||
raise ValueError(f"Unknown attention type: {attn_type}")
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ch,
|
||||
out_ch,
|
||||
ch_mult=(1, 2, 4, 8),
|
||||
num_res_blocks,
|
||||
attn_resolutions,
|
||||
dropout=0.0,
|
||||
resamp_with_conv=True,
|
||||
in_channels,
|
||||
resolution,
|
||||
z_channels,
|
||||
double_z=True,
|
||||
attn_type="vanilla",
|
||||
mid_block_add_attention=True,
|
||||
norm_type="group",
|
||||
causality_axis=CausalityAxis.WIDTH.value,
|
||||
**ignore_kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.z_channels = z_channels
|
||||
self.double_z = double_z
|
||||
self.norm_type = norm_type
|
||||
# Convert string to enum if needed (for config loading)
|
||||
causality_axis = CausalityAxis.str_to_enum(causality_axis)
|
||||
self.attn_type = AttentionType.str_to_enum(attn_type)
|
||||
|
||||
# downsampling
|
||||
self.conv_in = make_conv2d(
|
||||
in_channels,
|
||||
self.ch,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
causality_axis=causality_axis,
|
||||
)
|
||||
|
||||
self.non_linearity = nn.SiLU()
|
||||
|
||||
curr_res = resolution
|
||||
in_ch_mult = (1,) + tuple(ch_mult)
|
||||
self.in_ch_mult = in_ch_mult
|
||||
self.down = nn.ModuleList()
|
||||
|
||||
for i_level in range(self.num_resolutions):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_in = ch * in_ch_mult[i_level]
|
||||
block_out = ch * ch_mult[i_level]
|
||||
|
||||
for _ in range(self.num_res_blocks):
|
||||
block.append(
|
||||
ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=causality_axis,
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(make_attn(block_in, attn_type=self.attn_type, norm_type=self.norm_type))
|
||||
|
||||
down = nn.Module()
|
||||
down.block = block
|
||||
down.attn = attn
|
||||
if i_level != self.num_resolutions - 1:
|
||||
down.downsample = Downsample(block_in, resamp_with_conv, causality_axis=causality_axis)
|
||||
curr_res = curr_res // 2
|
||||
self.down.append(down)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=causality_axis,
|
||||
)
|
||||
if mid_block_add_attention:
|
||||
self.mid.attn_1 = make_attn(block_in, attn_type=self.attn_type, norm_type=self.norm_type)
|
||||
else:
|
||||
self.mid.attn_1 = nn.Identity()
|
||||
self.mid.block_2 = ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=causality_axis,
|
||||
)
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in, normtype=self.norm_type)
|
||||
self.conv_out = make_conv2d(
|
||||
block_in,
|
||||
2 * z_channels if double_z else z_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
causality_axis=causality_axis,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Forward pass through the encoder.
|
||||
|
||||
Args:
|
||||
x: Input tensor of shape [batch, channels, time, n_mels]
|
||||
|
||||
Returns:
|
||||
Encoded latent representation
|
||||
"""
|
||||
feature_maps = [self.conv_in(x)]
|
||||
|
||||
# Process each resolution level (from high to low resolution)
|
||||
for resolution_level in range(self.num_resolutions):
|
||||
# Apply residual blocks at current resolution level
|
||||
for block_idx in range(self.num_res_blocks):
|
||||
# Apply ResNet block with optional timestep embedding
|
||||
current_features = self.down[resolution_level].block[block_idx](feature_maps[-1], temb=None)
|
||||
|
||||
# Apply attention if configured for this resolution level
|
||||
if len(self.down[resolution_level].attn) > 0:
|
||||
current_features = self.down[resolution_level].attn[block_idx](current_features)
|
||||
|
||||
# Store processed features
|
||||
feature_maps.append(current_features)
|
||||
|
||||
# Downsample spatial dimensions (except at the final resolution level)
|
||||
if resolution_level != self.num_resolutions - 1:
|
||||
downsampled_features = self.down[resolution_level].downsample(feature_maps[-1])
|
||||
feature_maps.append(downsampled_features)
|
||||
|
||||
# === MIDDLE PROCESSING PHASE ===
|
||||
# Take the lowest resolution features for middle processing
|
||||
bottleneck_features = feature_maps[-1]
|
||||
|
||||
# Apply first middle ResNet block
|
||||
bottleneck_features = self.mid.block_1(bottleneck_features, temb=None)
|
||||
|
||||
# Apply middle attention block
|
||||
bottleneck_features = self.mid.attn_1(bottleneck_features)
|
||||
|
||||
# Apply second middle ResNet block
|
||||
bottleneck_features = self.mid.block_2(bottleneck_features, temb=None)
|
||||
|
||||
# === OUTPUT PHASE ===
|
||||
# Normalize the bottleneck features
|
||||
output_features = self.norm_out(bottleneck_features)
|
||||
|
||||
# Apply non-linearity (SiLU activation)
|
||||
output_features = self.non_linearity(output_features)
|
||||
|
||||
# Final convolution to produce latent representation
|
||||
# [batch, channels, time, n_mels] -> [batch, 2 * z_channels if double_z else z_channels, time, n_mels]
|
||||
return self.conv_out(output_features)
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ch,
|
||||
out_ch,
|
||||
ch_mult=(1, 2, 4, 8),
|
||||
num_res_blocks,
|
||||
attn_resolutions,
|
||||
dropout=0.0,
|
||||
resamp_with_conv=True,
|
||||
in_channels,
|
||||
resolution,
|
||||
z_channels,
|
||||
give_pre_end=False,
|
||||
tanh_out=False,
|
||||
attn_type="vanilla",
|
||||
mid_block_add_attention=True,
|
||||
norm_type="group",
|
||||
causality_axis=CausalityAxis.WIDTH.value,
|
||||
**ignorekwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.out_ch = out_ch
|
||||
self.give_pre_end = give_pre_end
|
||||
self.tanh_out = tanh_out
|
||||
self.norm_type = norm_type
|
||||
self.z_channels = z_channels
|
||||
# Convert string to enum if needed (for config loading)
|
||||
causality_axis = CausalityAxis.str_to_enum(causality_axis)
|
||||
self.attn_type = AttentionType.str_to_enum(attn_type)
|
||||
|
||||
# compute block_in and curr_res at lowest res
|
||||
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
||||
self.z_shape = (1, z_channels, curr_res, curr_res)
|
||||
|
||||
# z to block_in
|
||||
self.conv_in = make_conv2d(z_channels, block_in, kernel_size=3, stride=1, causality_axis=causality_axis)
|
||||
|
||||
self.non_linearity = nn.SiLU()
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=causality_axis,
|
||||
)
|
||||
if mid_block_add_attention:
|
||||
self.mid.attn_1 = make_attn(block_in, attn_type=self.attn_type, norm_type=self.norm_type)
|
||||
else:
|
||||
self.mid.attn_1 = nn.Identity()
|
||||
self.mid.block_2 = ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=causality_axis,
|
||||
)
|
||||
|
||||
# upsampling
|
||||
self.up = nn.ModuleList()
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for _ in range(self.num_res_blocks + 1):
|
||||
block.append(
|
||||
ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=causality_axis,
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(make_attn(block_in, attn_type=self.attn_type, norm_type=self.norm_type))
|
||||
up = nn.Module()
|
||||
up.block = block
|
||||
up.attn = attn
|
||||
if i_level != 0:
|
||||
up.upsample = Upsample(block_in, resamp_with_conv, causality_axis=causality_axis)
|
||||
curr_res = curr_res * 2
|
||||
self.up.insert(0, up) # prepend to get consistent order
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in, normtype=self.norm_type)
|
||||
self.conv_out = make_conv2d(block_in, out_ch, kernel_size=3, stride=1, causality_axis=causality_axis)
|
||||
|
||||
def _adjust_output_shape(self, decoded_output, target_shape):
|
||||
"""
|
||||
Adjust output shape to match target dimensions for variable-length audio.
|
||||
|
||||
This function handles the common case where decoded audio spectrograms need to be
|
||||
resized to match a specific target shape.
|
||||
|
||||
Args:
|
||||
decoded_output: Tensor of shape (batch, channels, time, frequency)
|
||||
target_shape: Target shape tuple (batch, channels, time, frequency)
|
||||
|
||||
Returns:
|
||||
Tensor adjusted to match target_shape exactly
|
||||
"""
|
||||
# Current output shape: (batch, channels, time, frequency)
|
||||
_, _, current_time, current_freq = decoded_output.shape
|
||||
_, target_channels, target_time, target_freq = target_shape
|
||||
|
||||
# Step 1: Crop first to avoid exceeding target dimensions
|
||||
decoded_output = decoded_output[
|
||||
:, :target_channels, : min(current_time, target_time), : min(current_freq, target_freq)
|
||||
]
|
||||
|
||||
# Step 2: Calculate padding needed for time and frequency dimensions
|
||||
time_padding_needed = target_time - decoded_output.shape[2]
|
||||
freq_padding_needed = target_freq - decoded_output.shape[3]
|
||||
|
||||
# Step 3: Apply padding if needed
|
||||
if time_padding_needed > 0 or freq_padding_needed > 0:
|
||||
# PyTorch padding format: (pad_left, pad_right, pad_top, pad_bottom)
|
||||
# For audio: pad_left/right = frequency, pad_top/bottom = time
|
||||
padding = (
|
||||
0,
|
||||
max(freq_padding_needed, 0), # frequency padding (left, right)
|
||||
0,
|
||||
max(time_padding_needed, 0), # time padding (top, bottom)
|
||||
)
|
||||
decoded_output = F.pad(decoded_output, padding)
|
||||
|
||||
# Step 4: Final safety crop to ensure exact target shape
|
||||
decoded_output = decoded_output[:, :target_channels, :target_time, :target_freq]
|
||||
|
||||
return decoded_output
|
||||
|
||||
def get_config(self):
|
||||
return {
|
||||
"ch": self.ch,
|
||||
"out_ch": self.out_ch,
|
||||
"ch_mult": self.ch_mult,
|
||||
"num_res_blocks": self.num_res_blocks,
|
||||
"in_channels": self.in_channels,
|
||||
"resolution": self.resolution,
|
||||
"z_channels": self.z_channels,
|
||||
}
|
||||
|
||||
def forward(self, latent_features, target_shape=None):
|
||||
"""
|
||||
Decode latent features back to audio spectrograms.
|
||||
|
||||
Args:
|
||||
latent_features: Encoded latent representation of shape (batch, channels, height, width)
|
||||
target_shape: Optional target output shape (batch, channels, time, frequency)
|
||||
If provided, output will be cropped/padded to match this shape
|
||||
|
||||
Returns:
|
||||
Reconstructed audio spectrogram of shape (batch, channels, time, frequency)
|
||||
"""
|
||||
assert target_shape is not None, "Target shape is required for CausalAudioAutoencoder Decoder"
|
||||
|
||||
# Transform latent features to decoder's internal feature dimension
|
||||
hidden_features = self.conv_in(latent_features)
|
||||
|
||||
# Middle processing
|
||||
hidden_features = self.mid.block_1(hidden_features, temb=None)
|
||||
hidden_features = self.mid.attn_1(hidden_features)
|
||||
hidden_features = self.mid.block_2(hidden_features, temb=None)
|
||||
|
||||
# Upsampling
|
||||
# Progressively increase spatial resolution from lowest to highest
|
||||
for resolution_level in reversed(range(self.num_resolutions)):
|
||||
# Apply residual blocks at current resolution level
|
||||
for block_index in range(self.num_res_blocks + 1):
|
||||
hidden_features = self.up[resolution_level].block[block_index](hidden_features, temb=None)
|
||||
|
||||
if len(self.up[resolution_level].attn) > 0:
|
||||
hidden_features = self.up[resolution_level].attn[block_index](hidden_features)
|
||||
|
||||
if resolution_level != 0:
|
||||
hidden_features = self.up[resolution_level].upsample(hidden_features)
|
||||
|
||||
# Output
|
||||
if self.give_pre_end:
|
||||
# Return intermediate features before final processing (for debugging/analysis)
|
||||
decoded_output = hidden_features
|
||||
else:
|
||||
# Standard output path: normalize, activate, and convert to output channels
|
||||
# Final normalization layer
|
||||
hidden_features = self.norm_out(hidden_features)
|
||||
|
||||
# Apply SiLU (Swish) activation function
|
||||
hidden_features = self.non_linearity(hidden_features)
|
||||
|
||||
# Final convolution to map to output channels (typically 2 for stereo audio)
|
||||
decoded_output = self.conv_out(hidden_features)
|
||||
|
||||
# Optional tanh activation to bound output values to [-1, 1] range
|
||||
if self.tanh_out:
|
||||
decoded_output = torch.tanh(decoded_output)
|
||||
|
||||
# Adjust shape for audio data
|
||||
if target_shape is not None:
|
||||
decoded_output = self._adjust_output_shape(decoded_output, target_shape)
|
||||
|
||||
return decoded_output
|
||||
|
||||
|
||||
class processor(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.register_buffer("std-of-means", torch.empty(128))
|
||||
self.register_buffer("mean-of-means", torch.empty(128))
|
||||
|
||||
def un_normalize(self, x):
|
||||
return (x * self.get_buffer("std-of-means").to(x)) + self.get_buffer("mean-of-means").to(x)
|
||||
|
||||
def normalize(self, x):
|
||||
return (x - self.get_buffer("mean-of-means").to(x)) / self.get_buffer("std-of-means").to(x)
|
||||
|
||||
|
||||
class CausalAudioAutoencoder(nn.Module):
|
||||
def __init__(self, config=None):
|
||||
super().__init__()
|
||||
|
||||
if config is None:
|
||||
config = self._guess_config()
|
||||
|
||||
# Extract encoder and decoder configs from the new format
|
||||
model_config = config.get("model", {}).get("params", {})
|
||||
variables_config = config.get("variables", {})
|
||||
|
||||
self.sampling_rate = variables_config.get(
|
||||
"sampling_rate",
|
||||
model_config.get("sampling_rate", config.get("sampling_rate", 16000)),
|
||||
)
|
||||
encoder_config = model_config.get("encoder", model_config.get("ddconfig", {}))
|
||||
decoder_config = model_config.get("decoder", encoder_config)
|
||||
|
||||
# Load mel spectrogram parameters
|
||||
self.mel_bins = encoder_config.get("mel_bins", 64)
|
||||
self.mel_hop_length = model_config.get("preprocessing", {}).get("stft", {}).get("hop_length", 160)
|
||||
self.n_fft = model_config.get("preprocessing", {}).get("stft", {}).get("filter_length", 1024)
|
||||
|
||||
# Store causality configuration at VAE level (not just in encoder internals)
|
||||
causality_axis_value = encoder_config.get("causality_axis", CausalityAxis.WIDTH.value)
|
||||
self.causality_axis = CausalityAxis.str_to_enum(causality_axis_value)
|
||||
self.is_causal = self.causality_axis == CausalityAxis.HEIGHT
|
||||
|
||||
self.encoder = Encoder(**encoder_config)
|
||||
self.decoder = Decoder(**decoder_config)
|
||||
|
||||
self.per_channel_statistics = processor()
|
||||
|
||||
def _guess_config(self):
|
||||
encoder_config = {
|
||||
# Required parameters - based on ltx-video-av-1679000 model metadata
|
||||
"ch": 128,
|
||||
"out_ch": 8,
|
||||
"ch_mult": [1, 2, 4], # Based on metadata: [1, 2, 4] not [1, 2, 4, 8]
|
||||
"num_res_blocks": 2,
|
||||
"attn_resolutions": [], # Based on metadata: empty list, no attention
|
||||
"dropout": 0.0,
|
||||
"resamp_with_conv": True,
|
||||
"in_channels": 2, # stereo
|
||||
"resolution": 256,
|
||||
"z_channels": 8,
|
||||
"double_z": True,
|
||||
"attn_type": "vanilla",
|
||||
"mid_block_add_attention": False, # Based on metadata: false
|
||||
"norm_type": "pixel",
|
||||
"causality_axis": "height", # Based on metadata
|
||||
"mel_bins": 64, # Based on metadata: mel_bins = 64
|
||||
}
|
||||
|
||||
decoder_config = {
|
||||
# Inherits encoder config, can override specific params
|
||||
**encoder_config,
|
||||
"out_ch": 2, # Stereo audio output (2 channels)
|
||||
"give_pre_end": False,
|
||||
"tanh_out": False,
|
||||
}
|
||||
|
||||
config = {
|
||||
"_class_name": "CausalAudioAutoencoder",
|
||||
"sampling_rate": 16000,
|
||||
"model": {
|
||||
"params": {
|
||||
"encoder": encoder_config,
|
||||
"decoder": decoder_config,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
return config
|
||||
|
||||
def get_config(self):
|
||||
return {
|
||||
"sampling_rate": self.sampling_rate,
|
||||
"mel_bins": self.mel_bins,
|
||||
"mel_hop_length": self.mel_hop_length,
|
||||
"n_fft": self.n_fft,
|
||||
"causality_axis": self.causality_axis.value,
|
||||
"is_causal": self.is_causal,
|
||||
}
|
||||
|
||||
def encode(self, x):
|
||||
return self.encoder(x)
|
||||
|
||||
def decode(self, x, target_shape=None):
|
||||
return self.decoder(x, target_shape=target_shape)
|
||||
@@ -1,11 +1,11 @@
|
||||
from typing import Tuple, Union
|
||||
|
||||
import threading
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
|
||||
class CausalConv3d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -42,23 +42,34 @@ class CausalConv3d(nn.Module):
|
||||
padding_mode=spatial_padding_mode,
|
||||
groups=groups,
|
||||
)
|
||||
self.temporal_cache_state={}
|
||||
|
||||
def forward(self, x, causal: bool = True):
|
||||
if causal:
|
||||
first_frame_pad = x[:, :, :1, :, :].repeat(
|
||||
(1, 1, self.time_kernel_size - 1, 1, 1)
|
||||
)
|
||||
x = torch.concatenate((first_frame_pad, x), dim=2)
|
||||
else:
|
||||
first_frame_pad = x[:, :, :1, :, :].repeat(
|
||||
(1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
|
||||
)
|
||||
last_frame_pad = x[:, :, -1:, :, :].repeat(
|
||||
(1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
|
||||
)
|
||||
x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2)
|
||||
x = self.conv(x)
|
||||
return x
|
||||
tid = threading.get_ident()
|
||||
|
||||
cached, is_end = self.temporal_cache_state.get(tid, (None, False))
|
||||
if cached is None:
|
||||
padding_length = self.time_kernel_size - 1
|
||||
if not causal:
|
||||
padding_length = padding_length // 2
|
||||
if x.shape[2] == 0:
|
||||
return x
|
||||
cached = x[:, :, :1, :, :].repeat((1, 1, padding_length, 1, 1))
|
||||
pieces = [ cached, x ]
|
||||
if is_end and not causal:
|
||||
pieces.append(x[:, :, -1:, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1)))
|
||||
|
||||
needs_caching = not is_end
|
||||
if needs_caching and x.shape[2] >= self.time_kernel_size - 1:
|
||||
needs_caching = False
|
||||
self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
|
||||
|
||||
x = torch.cat(pieces, dim=2)
|
||||
|
||||
if needs_caching:
|
||||
self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
|
||||
|
||||
return self.conv(x) if x.shape[2] >= self.time_kernel_size else x[:, :, :0, :, :]
|
||||
|
||||
@property
|
||||
def weight(self):
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from __future__ import annotations
|
||||
import threading
|
||||
import torch
|
||||
from torch import nn
|
||||
from functools import partial
|
||||
@@ -6,12 +7,35 @@ import math
|
||||
from einops import rearrange
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from .conv_nd_factory import make_conv_nd, make_linear_nd
|
||||
from .causal_conv3d import CausalConv3d
|
||||
from .pixel_norm import PixelNorm
|
||||
from ..model import PixArtAlphaCombinedTimestepSizeEmbeddings
|
||||
import comfy.ops
|
||||
from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed
|
||||
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
def mark_conv3d_ended(module):
|
||||
tid = threading.get_ident()
|
||||
for _, m in module.named_modules():
|
||||
if isinstance(m, CausalConv3d):
|
||||
current = m.temporal_cache_state.get(tid, (None, False))
|
||||
m.temporal_cache_state[tid] = (current[0], True)
|
||||
|
||||
def split2(tensor, split_point, dim=2):
|
||||
return torch.split(tensor, [split_point, tensor.shape[dim] - split_point], dim=dim)
|
||||
|
||||
def add_exchange_cache(dest, cache_in, new_input, dim=2):
|
||||
if dest is not None:
|
||||
if cache_in is not None:
|
||||
cache_to_dest = min(dest.shape[dim], cache_in.shape[dim])
|
||||
lead_in_dest, dest = split2(dest, cache_to_dest, dim=dim)
|
||||
lead_in_source, cache_in = split2(cache_in, cache_to_dest, dim=dim)
|
||||
lead_in_dest.add_(lead_in_source)
|
||||
body, new_input = split2(new_input, dest.shape[dim], dim)
|
||||
dest.add_(body)
|
||||
return torch_cat_if_needed([cache_in, new_input], dim=dim)
|
||||
|
||||
class Encoder(nn.Module):
|
||||
r"""
|
||||
The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.
|
||||
@@ -205,7 +229,7 @@ class Encoder(nn.Module):
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
def forward_orig(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
r"""The forward method of the `Encoder` class."""
|
||||
|
||||
sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
|
||||
@@ -254,6 +278,22 @@ class Encoder(nn.Module):
|
||||
|
||||
return sample
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
#No encoder support so just flag the end so it doesnt use the cache.
|
||||
mark_conv3d_ended(self)
|
||||
try:
|
||||
return self.forward_orig(*args, **kwargs)
|
||||
finally:
|
||||
tid = threading.get_ident()
|
||||
for _, module in self.named_modules():
|
||||
# ComfyUI doesn't thread this kind of stuff today, but just in case
|
||||
# we key on the thread to make it thread safe.
|
||||
tid = threading.get_ident()
|
||||
if hasattr(module, "temporal_cache_state"):
|
||||
module.temporal_cache_state.pop(tid, None)
|
||||
|
||||
|
||||
MAX_CHUNK_SIZE=(128 * 1024 ** 2)
|
||||
|
||||
class Decoder(nn.Module):
|
||||
r"""
|
||||
@@ -341,18 +381,6 @@ class Decoder(nn.Module):
|
||||
timestep_conditioning=timestep_conditioning,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "attn_res_x":
|
||||
block = UNetMidBlock3D(
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
num_layers=block_params["num_layers"],
|
||||
resnet_groups=norm_num_groups,
|
||||
norm_layer=norm_layer,
|
||||
inject_noise=block_params.get("inject_noise", False),
|
||||
timestep_conditioning=timestep_conditioning,
|
||||
attention_head_dim=block_params["attention_head_dim"],
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "res_x_y":
|
||||
output_channel = output_channel // block_params.get("multiplier", 2)
|
||||
block = ResnetBlock3D(
|
||||
@@ -428,8 +456,9 @@ class Decoder(nn.Module):
|
||||
)
|
||||
self.last_scale_shift_table = nn.Parameter(torch.empty(2, output_channel))
|
||||
|
||||
|
||||
# def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
|
||||
def forward(
|
||||
def forward_orig(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
timestep: Optional[torch.Tensor] = None,
|
||||
@@ -437,6 +466,7 @@ class Decoder(nn.Module):
|
||||
r"""The forward method of the `Decoder` class."""
|
||||
batch_size = sample.shape[0]
|
||||
|
||||
mark_conv3d_ended(self.conv_in)
|
||||
sample = self.conv_in(sample, causal=self.causal)
|
||||
|
||||
checkpoint_fn = (
|
||||
@@ -445,24 +475,12 @@ class Decoder(nn.Module):
|
||||
else lambda x: x
|
||||
)
|
||||
|
||||
scaled_timestep = None
|
||||
timestep_shift_scale = None
|
||||
if self.timestep_conditioning:
|
||||
assert (
|
||||
timestep is not None
|
||||
), "should pass timestep with timestep_conditioning=True"
|
||||
scaled_timestep = timestep * self.timestep_scale_multiplier.to(dtype=sample.dtype, device=sample.device)
|
||||
|
||||
for up_block in self.up_blocks:
|
||||
if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
|
||||
sample = checkpoint_fn(up_block)(
|
||||
sample, causal=self.causal, timestep=scaled_timestep
|
||||
)
|
||||
else:
|
||||
sample = checkpoint_fn(up_block)(sample, causal=self.causal)
|
||||
|
||||
sample = self.conv_norm_out(sample)
|
||||
|
||||
if self.timestep_conditioning:
|
||||
embedded_timestep = self.last_time_embedder(
|
||||
timestep=scaled_timestep.flatten(),
|
||||
resolution=None,
|
||||
@@ -483,16 +501,62 @@ class Decoder(nn.Module):
|
||||
embedded_timestep.shape[-2],
|
||||
embedded_timestep.shape[-1],
|
||||
)
|
||||
shift, scale = ada_values.unbind(dim=1)
|
||||
sample = sample * (1 + scale) + shift
|
||||
timestep_shift_scale = ada_values.unbind(dim=1)
|
||||
|
||||
sample = self.conv_act(sample)
|
||||
sample = self.conv_out(sample, causal=self.causal)
|
||||
output = []
|
||||
|
||||
def run_up(idx, sample, ended):
|
||||
if idx >= len(self.up_blocks):
|
||||
sample = self.conv_norm_out(sample)
|
||||
if timestep_shift_scale is not None:
|
||||
shift, scale = timestep_shift_scale
|
||||
sample = sample * (1 + scale) + shift
|
||||
sample = self.conv_act(sample)
|
||||
if ended:
|
||||
mark_conv3d_ended(self.conv_out)
|
||||
sample = self.conv_out(sample, causal=self.causal)
|
||||
if sample is not None and sample.shape[2] > 0:
|
||||
output.append(sample)
|
||||
return
|
||||
|
||||
up_block = self.up_blocks[idx]
|
||||
if (ended):
|
||||
mark_conv3d_ended(up_block)
|
||||
if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
|
||||
sample = checkpoint_fn(up_block)(
|
||||
sample, causal=self.causal, timestep=scaled_timestep
|
||||
)
|
||||
else:
|
||||
sample = checkpoint_fn(up_block)(sample, causal=self.causal)
|
||||
|
||||
if sample is None or sample.shape[2] == 0:
|
||||
return
|
||||
|
||||
total_bytes = sample.numel() * sample.element_size()
|
||||
num_chunks = (total_bytes + MAX_CHUNK_SIZE - 1) // MAX_CHUNK_SIZE
|
||||
samples = torch.chunk(sample, chunks=num_chunks, dim=2)
|
||||
|
||||
for chunk_idx, sample1 in enumerate(samples):
|
||||
run_up(idx + 1, sample1, ended and chunk_idx == len(samples) - 1)
|
||||
|
||||
run_up(0, sample, True)
|
||||
sample = torch.cat(output, dim=2)
|
||||
|
||||
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
|
||||
|
||||
return sample
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
try:
|
||||
return self.forward_orig(*args, **kwargs)
|
||||
finally:
|
||||
for _, module in self.named_modules():
|
||||
#ComfyUI doesn't thread this kind of stuff today, but just incase
|
||||
#we key on the thread to make it thread safe.
|
||||
tid = threading.get_ident()
|
||||
if hasattr(module, "temporal_cache_state"):
|
||||
module.temporal_cache_state.pop(tid, None)
|
||||
|
||||
|
||||
class UNetMidBlock3D(nn.Module):
|
||||
"""
|
||||
@@ -663,8 +727,22 @@ class DepthToSpaceUpsample(nn.Module):
|
||||
)
|
||||
self.residual = residual
|
||||
self.out_channels_reduction_factor = out_channels_reduction_factor
|
||||
self.temporal_cache_state = {}
|
||||
|
||||
def forward(self, x, causal: bool = True, timestep: Optional[torch.Tensor] = None):
|
||||
tid = threading.get_ident()
|
||||
cached, drop_first_conv, drop_first_res = self.temporal_cache_state.get(tid, (None, True, True))
|
||||
y = self.conv(x, causal=causal)
|
||||
y = rearrange(
|
||||
y,
|
||||
"b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
|
||||
p1=self.stride[0],
|
||||
p2=self.stride[1],
|
||||
p3=self.stride[2],
|
||||
)
|
||||
if self.stride[0] == 2 and y.shape[2] > 0 and drop_first_conv:
|
||||
y = y[:, :, 1:, :, :]
|
||||
drop_first_conv = False
|
||||
if self.residual:
|
||||
# Reshape and duplicate the input to match the output shape
|
||||
x_in = rearrange(
|
||||
@@ -676,21 +754,20 @@ class DepthToSpaceUpsample(nn.Module):
|
||||
)
|
||||
num_repeat = math.prod(self.stride) // self.out_channels_reduction_factor
|
||||
x_in = x_in.repeat(1, num_repeat, 1, 1, 1)
|
||||
if self.stride[0] == 2:
|
||||
if self.stride[0] == 2 and x_in.shape[2] > 0 and drop_first_res:
|
||||
x_in = x_in[:, :, 1:, :, :]
|
||||
x = self.conv(x, causal=causal)
|
||||
x = rearrange(
|
||||
x,
|
||||
"b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
|
||||
p1=self.stride[0],
|
||||
p2=self.stride[1],
|
||||
p3=self.stride[2],
|
||||
)
|
||||
if self.stride[0] == 2:
|
||||
x = x[:, :, 1:, :, :]
|
||||
if self.residual:
|
||||
x = x + x_in
|
||||
return x
|
||||
drop_first_res = False
|
||||
|
||||
if y.shape[2] == 0:
|
||||
y = None
|
||||
|
||||
cached = add_exchange_cache(y, cached, x_in, dim=2)
|
||||
self.temporal_cache_state[tid] = (cached, drop_first_conv, drop_first_res)
|
||||
|
||||
else:
|
||||
self.temporal_cache_state[tid] = (None, drop_first_conv, False)
|
||||
|
||||
return y
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
def __init__(self, dim, eps, elementwise_affine=True) -> None:
|
||||
@@ -807,6 +884,8 @@ class ResnetBlock3D(nn.Module):
|
||||
torch.randn(4, in_channels) / in_channels**0.5
|
||||
)
|
||||
|
||||
self.temporal_cache_state={}
|
||||
|
||||
def _feed_spatial_noise(
|
||||
self, hidden_states: torch.FloatTensor, per_channel_scale: torch.FloatTensor
|
||||
) -> torch.FloatTensor:
|
||||
@@ -880,9 +959,12 @@ class ResnetBlock3D(nn.Module):
|
||||
|
||||
input_tensor = self.conv_shortcut(input_tensor)
|
||||
|
||||
output_tensor = input_tensor + hidden_states
|
||||
tid = threading.get_ident()
|
||||
cached = self.temporal_cache_state.get(tid, None)
|
||||
cached = add_exchange_cache(hidden_states, cached, input_tensor, dim=2)
|
||||
self.temporal_cache_state[tid] = cached
|
||||
|
||||
return output_tensor
|
||||
return hidden_states
|
||||
|
||||
|
||||
def patchify(x, patch_size_hw, patch_size_t=1):
|
||||
|
||||
213
comfy/ldm/lightricks/vocoders/vocoder.py
Normal file
213
comfy/ldm/lightricks/vocoders/vocoder.py
Normal file
@@ -0,0 +1,213 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.nn as nn
|
||||
import comfy.ops
|
||||
import numpy as np
|
||||
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
LRELU_SLOPE = 0.1
|
||||
|
||||
def get_padding(kernel_size, dilation=1):
|
||||
return int((kernel_size * dilation - dilation) / 2)
|
||||
|
||||
|
||||
class ResBlock1(torch.nn.Module):
|
||||
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
|
||||
super(ResBlock1, self).__init__()
|
||||
self.convs1 = nn.ModuleList(
|
||||
[
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[0],
|
||||
padding=get_padding(kernel_size, dilation[0]),
|
||||
),
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[1],
|
||||
padding=get_padding(kernel_size, dilation[1]),
|
||||
),
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[2],
|
||||
padding=get_padding(kernel_size, dilation[2]),
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
self.convs2 = nn.ModuleList(
|
||||
[
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1),
|
||||
),
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1),
|
||||
),
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1),
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
for c1, c2 in zip(self.convs1, self.convs2):
|
||||
xt = F.leaky_relu(x, LRELU_SLOPE)
|
||||
xt = c1(xt)
|
||||
xt = F.leaky_relu(xt, LRELU_SLOPE)
|
||||
xt = c2(xt)
|
||||
x = xt + x
|
||||
return x
|
||||
|
||||
|
||||
class ResBlock2(torch.nn.Module):
|
||||
def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
|
||||
super(ResBlock2, self).__init__()
|
||||
self.convs = nn.ModuleList(
|
||||
[
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[0],
|
||||
padding=get_padding(kernel_size, dilation[0]),
|
||||
),
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[1],
|
||||
padding=get_padding(kernel_size, dilation[1]),
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
for c in self.convs:
|
||||
xt = F.leaky_relu(x, LRELU_SLOPE)
|
||||
xt = c(xt)
|
||||
x = xt + x
|
||||
return x
|
||||
|
||||
|
||||
class Vocoder(torch.nn.Module):
|
||||
"""
|
||||
Vocoder model for synthesizing audio from spectrograms, based on: https://github.com/jik876/hifi-gan.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, config=None):
|
||||
super(Vocoder, self).__init__()
|
||||
|
||||
if config is None:
|
||||
config = self.get_default_config()
|
||||
|
||||
resblock_kernel_sizes = config.get("resblock_kernel_sizes", [3, 7, 11])
|
||||
upsample_rates = config.get("upsample_rates", [6, 5, 2, 2, 2])
|
||||
upsample_kernel_sizes = config.get("upsample_kernel_sizes", [16, 15, 8, 4, 4])
|
||||
resblock_dilation_sizes = config.get("resblock_dilation_sizes", [[1, 3, 5], [1, 3, 5], [1, 3, 5]])
|
||||
upsample_initial_channel = config.get("upsample_initial_channel", 1024)
|
||||
stereo = config.get("stereo", True)
|
||||
resblock = config.get("resblock", "1")
|
||||
|
||||
self.output_sample_rate = config.get("output_sample_rate")
|
||||
self.num_kernels = len(resblock_kernel_sizes)
|
||||
self.num_upsamples = len(upsample_rates)
|
||||
in_channels = 128 if stereo else 64
|
||||
self.conv_pre = ops.Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)
|
||||
resblock_class = ResBlock1 if resblock == "1" else ResBlock2
|
||||
|
||||
self.ups = nn.ModuleList()
|
||||
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
||||
self.ups.append(
|
||||
ops.ConvTranspose1d(
|
||||
upsample_initial_channel // (2**i),
|
||||
upsample_initial_channel // (2 ** (i + 1)),
|
||||
k,
|
||||
u,
|
||||
padding=(k - u) // 2,
|
||||
)
|
||||
)
|
||||
|
||||
self.resblocks = nn.ModuleList()
|
||||
for i in range(len(self.ups)):
|
||||
ch = upsample_initial_channel // (2 ** (i + 1))
|
||||
for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
|
||||
self.resblocks.append(resblock_class(ch, k, d))
|
||||
|
||||
out_channels = 2 if stereo else 1
|
||||
self.conv_post = ops.Conv1d(ch, out_channels, 7, 1, padding=3)
|
||||
|
||||
self.upsample_factor = np.prod([self.ups[i].stride[0] for i in range(len(self.ups))])
|
||||
|
||||
def get_default_config(self):
|
||||
"""Generate default configuration for the vocoder."""
|
||||
|
||||
config = {
|
||||
"resblock_kernel_sizes": [3, 7, 11],
|
||||
"upsample_rates": [6, 5, 2, 2, 2],
|
||||
"upsample_kernel_sizes": [16, 15, 8, 4, 4],
|
||||
"resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
"upsample_initial_channel": 1024,
|
||||
"stereo": True,
|
||||
"resblock": "1",
|
||||
}
|
||||
|
||||
return config
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Forward pass of the vocoder.
|
||||
|
||||
Args:
|
||||
x: Input spectrogram tensor. Can be:
|
||||
- 3D: (batch_size, channels, time_steps) for mono
|
||||
- 4D: (batch_size, 2, channels, time_steps) for stereo
|
||||
|
||||
Returns:
|
||||
Audio tensor of shape (batch_size, out_channels, audio_length)
|
||||
"""
|
||||
if x.dim() == 4: # stereo
|
||||
assert x.shape[1] == 2, "Input must have 2 channels for stereo"
|
||||
x = torch.cat((x[:, 0, :, :], x[:, 1, :, :]), dim=1)
|
||||
x = self.conv_pre(x)
|
||||
for i in range(self.num_upsamples):
|
||||
x = F.leaky_relu(x, LRELU_SLOPE)
|
||||
x = self.ups[i](x)
|
||||
xs = None
|
||||
for j in range(self.num_kernels):
|
||||
if xs is None:
|
||||
xs = self.resblocks[i * self.num_kernels + j](x)
|
||||
else:
|
||||
xs += self.resblocks[i * self.num_kernels + j](x)
|
||||
x = xs / self.num_kernels
|
||||
x = F.leaky_relu(x)
|
||||
x = self.conv_post(x)
|
||||
x = torch.tanh(x)
|
||||
|
||||
return x
|
||||
@@ -41,6 +41,11 @@ class ZImage_Control(torch.nn.Module):
|
||||
ffn_dim_multiplier: float = (8.0 / 3.0),
|
||||
norm_eps: float = 1e-5,
|
||||
qk_norm: bool = True,
|
||||
n_control_layers=6,
|
||||
control_in_dim=16,
|
||||
additional_in_dim=0,
|
||||
broken=False,
|
||||
refiner_control=False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
@@ -49,10 +54,11 @@ class ZImage_Control(torch.nn.Module):
|
||||
super().__init__()
|
||||
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
|
||||
|
||||
self.additional_in_dim = 0
|
||||
self.control_in_dim = 16
|
||||
self.broken = broken
|
||||
self.additional_in_dim = additional_in_dim
|
||||
self.control_in_dim = control_in_dim
|
||||
n_refiner_layers = 2
|
||||
self.n_control_layers = 6
|
||||
self.n_control_layers = n_control_layers
|
||||
self.control_layers = nn.ModuleList(
|
||||
[
|
||||
ZImageControlTransformerBlock(
|
||||
@@ -74,28 +80,49 @@ class ZImage_Control(torch.nn.Module):
|
||||
all_x_embedder = {}
|
||||
patch_size = 2
|
||||
f_patch_size = 1
|
||||
x_embedder = operations.Linear(f_patch_size * patch_size * patch_size * self.control_in_dim, dim, bias=True, device=device, dtype=dtype)
|
||||
x_embedder = operations.Linear(f_patch_size * patch_size * patch_size * (self.control_in_dim + self.additional_in_dim), dim, bias=True, device=device, dtype=dtype)
|
||||
all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder
|
||||
|
||||
self.refiner_control = refiner_control
|
||||
|
||||
self.control_all_x_embedder = nn.ModuleDict(all_x_embedder)
|
||||
self.control_noise_refiner = nn.ModuleList(
|
||||
[
|
||||
JointTransformerBlock(
|
||||
layer_id,
|
||||
dim,
|
||||
n_heads,
|
||||
n_kv_heads,
|
||||
multiple_of,
|
||||
ffn_dim_multiplier,
|
||||
norm_eps,
|
||||
qk_norm,
|
||||
modulation=True,
|
||||
z_image_modulation=True,
|
||||
operation_settings=operation_settings,
|
||||
)
|
||||
for layer_id in range(n_refiner_layers)
|
||||
]
|
||||
)
|
||||
if self.refiner_control:
|
||||
self.control_noise_refiner = nn.ModuleList(
|
||||
[
|
||||
ZImageControlTransformerBlock(
|
||||
layer_id,
|
||||
dim,
|
||||
n_heads,
|
||||
n_kv_heads,
|
||||
multiple_of,
|
||||
ffn_dim_multiplier,
|
||||
norm_eps,
|
||||
qk_norm,
|
||||
block_id=layer_id,
|
||||
operation_settings=operation_settings,
|
||||
)
|
||||
for layer_id in range(n_refiner_layers)
|
||||
]
|
||||
)
|
||||
else:
|
||||
self.control_noise_refiner = nn.ModuleList(
|
||||
[
|
||||
JointTransformerBlock(
|
||||
layer_id,
|
||||
dim,
|
||||
n_heads,
|
||||
n_kv_heads,
|
||||
multiple_of,
|
||||
ffn_dim_multiplier,
|
||||
norm_eps,
|
||||
qk_norm,
|
||||
modulation=True,
|
||||
z_image_modulation=True,
|
||||
operation_settings=operation_settings,
|
||||
)
|
||||
for layer_id in range(n_refiner_layers)
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, cap_feats, control_context, x_freqs_cis, adaln_input):
|
||||
patch_size = 2
|
||||
@@ -105,9 +132,29 @@ class ZImage_Control(torch.nn.Module):
|
||||
control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2))
|
||||
|
||||
x_attn_mask = None
|
||||
for layer in self.control_noise_refiner:
|
||||
control_context = layer(control_context, x_attn_mask, x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input)
|
||||
if not self.refiner_control:
|
||||
for layer in self.control_noise_refiner:
|
||||
control_context = layer(control_context, x_attn_mask, x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input)
|
||||
|
||||
return control_context
|
||||
|
||||
def forward_noise_refiner_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input):
|
||||
if self.refiner_control:
|
||||
if self.broken:
|
||||
if layer_id == 0:
|
||||
return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
|
||||
if layer_id > 0:
|
||||
out = None
|
||||
for i in range(1, len(self.control_layers)):
|
||||
o, control_context = self.control_layers[i](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
|
||||
if out is None:
|
||||
out = o
|
||||
|
||||
return (out, control_context)
|
||||
else:
|
||||
return self.control_noise_refiner[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
|
||||
else:
|
||||
return (None, control_context)
|
||||
|
||||
def forward_control_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input):
|
||||
return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
|
||||
|
||||
@@ -13,10 +13,53 @@ from comfy.ldm.modules.attention import optimized_attention_masked
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
from comfy.ldm.flux.math import apply_rope
|
||||
import comfy.patcher_extension
|
||||
import comfy.utils
|
||||
|
||||
|
||||
def modulate(x, scale):
|
||||
return x * (1 + scale.unsqueeze(1))
|
||||
def invert_slices(slices, length):
|
||||
sorted_slices = sorted(slices)
|
||||
result = []
|
||||
current = 0
|
||||
|
||||
for start, end in sorted_slices:
|
||||
if current < start:
|
||||
result.append((current, start))
|
||||
current = max(current, end)
|
||||
|
||||
if current < length:
|
||||
result.append((current, length))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def modulate(x, scale, timestep_zero_index=None):
|
||||
if timestep_zero_index is None:
|
||||
return x * (1 + scale.unsqueeze(1))
|
||||
else:
|
||||
scale = (1 + scale.unsqueeze(1))
|
||||
actual_batch = scale.size(0) // 2
|
||||
slices = timestep_zero_index
|
||||
invert = invert_slices(timestep_zero_index, x.shape[1])
|
||||
for s in slices:
|
||||
x[:, s[0]:s[1]] *= scale[actual_batch:]
|
||||
for s in invert:
|
||||
x[:, s[0]:s[1]] *= scale[:actual_batch]
|
||||
return x
|
||||
|
||||
|
||||
def apply_gate(gate, x, timestep_zero_index=None):
|
||||
if timestep_zero_index is None:
|
||||
return gate * x
|
||||
else:
|
||||
actual_batch = gate.size(0) // 2
|
||||
|
||||
slices = timestep_zero_index
|
||||
invert = invert_slices(timestep_zero_index, x.shape[1])
|
||||
for s in slices:
|
||||
x[:, s[0]:s[1]] *= gate[actual_batch:]
|
||||
for s in invert:
|
||||
x[:, s[0]:s[1]] *= gate[:actual_batch]
|
||||
return x
|
||||
|
||||
#############################################################################
|
||||
# Core NextDiT Model #
|
||||
@@ -258,6 +301,7 @@ class JointTransformerBlock(nn.Module):
|
||||
x_mask: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
adaln_input: Optional[torch.Tensor]=None,
|
||||
timestep_zero_index=None,
|
||||
transformer_options={},
|
||||
):
|
||||
"""
|
||||
@@ -276,18 +320,18 @@ class JointTransformerBlock(nn.Module):
|
||||
assert adaln_input is not None
|
||||
scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)
|
||||
|
||||
x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2(
|
||||
x = x + apply_gate(gate_msa.unsqueeze(1).tanh(), self.attention_norm2(
|
||||
clamp_fp16(self.attention(
|
||||
modulate(self.attention_norm1(x), scale_msa),
|
||||
modulate(self.attention_norm1(x), scale_msa, timestep_zero_index=timestep_zero_index),
|
||||
x_mask,
|
||||
freqs_cis,
|
||||
transformer_options=transformer_options,
|
||||
))
|
||||
))), timestep_zero_index=timestep_zero_index
|
||||
)
|
||||
x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(
|
||||
x = x + apply_gate(gate_mlp.unsqueeze(1).tanh(), self.ffn_norm2(
|
||||
clamp_fp16(self.feed_forward(
|
||||
modulate(self.ffn_norm1(x), scale_mlp),
|
||||
))
|
||||
modulate(self.ffn_norm1(x), scale_mlp, timestep_zero_index=timestep_zero_index),
|
||||
))), timestep_zero_index=timestep_zero_index
|
||||
)
|
||||
else:
|
||||
assert adaln_input is None
|
||||
@@ -345,13 +389,37 @@ class FinalLayer(nn.Module):
|
||||
),
|
||||
)
|
||||
|
||||
def forward(self, x, c):
|
||||
def forward(self, x, c, timestep_zero_index=None):
|
||||
scale = self.adaLN_modulation(c)
|
||||
x = modulate(self.norm_final(x), scale)
|
||||
x = modulate(self.norm_final(x), scale, timestep_zero_index=timestep_zero_index)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
def pad_zimage(feats, pad_token, pad_tokens_multiple):
|
||||
pad_extra = (-feats.shape[1]) % pad_tokens_multiple
|
||||
return torch.cat((feats, pad_token.to(device=feats.device, dtype=feats.dtype, copy=True).unsqueeze(0).repeat(feats.shape[0], pad_extra, 1)), dim=1), pad_extra
|
||||
|
||||
|
||||
def pos_ids_x(start_t, H_tokens, W_tokens, batch_size, device, transformer_options={}):
|
||||
rope_options = transformer_options.get("rope_options", None)
|
||||
h_scale = 1.0
|
||||
w_scale = 1.0
|
||||
h_start = 0
|
||||
w_start = 0
|
||||
if rope_options is not None:
|
||||
h_scale = rope_options.get("scale_y", 1.0)
|
||||
w_scale = rope_options.get("scale_x", 1.0)
|
||||
|
||||
h_start = rope_options.get("shift_y", 0.0)
|
||||
w_start = rope_options.get("shift_x", 0.0)
|
||||
x_pos_ids = torch.zeros((batch_size, H_tokens * W_tokens, 3), dtype=torch.float32, device=device)
|
||||
x_pos_ids[:, :, 0] = start_t
|
||||
x_pos_ids[:, :, 1] = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten()
|
||||
x_pos_ids[:, :, 2] = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten()
|
||||
return x_pos_ids
|
||||
|
||||
|
||||
class NextDiT(nn.Module):
|
||||
"""
|
||||
Diffusion model with a Transformer backbone.
|
||||
@@ -378,10 +446,12 @@ class NextDiT(nn.Module):
|
||||
time_scale=1.0,
|
||||
pad_tokens_multiple=None,
|
||||
clip_text_dim=None,
|
||||
siglip_feat_dim=None,
|
||||
image_model=None,
|
||||
device=None,
|
||||
dtype=None,
|
||||
operations=None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
@@ -491,7 +561,43 @@ class NextDiT(nn.Module):
|
||||
for layer_id in range(n_layers)
|
||||
]
|
||||
)
|
||||
self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
|
||||
if siglip_feat_dim is not None:
|
||||
self.siglip_embedder = nn.Sequential(
|
||||
operation_settings.get("operations").RMSNorm(siglip_feat_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
|
||||
operation_settings.get("operations").Linear(
|
||||
siglip_feat_dim,
|
||||
dim,
|
||||
bias=True,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
),
|
||||
)
|
||||
self.siglip_refiner = nn.ModuleList(
|
||||
[
|
||||
JointTransformerBlock(
|
||||
layer_id,
|
||||
dim,
|
||||
n_heads,
|
||||
n_kv_heads,
|
||||
multiple_of,
|
||||
ffn_dim_multiplier,
|
||||
norm_eps,
|
||||
qk_norm,
|
||||
modulation=False,
|
||||
operation_settings=operation_settings,
|
||||
)
|
||||
for layer_id in range(n_refiner_layers)
|
||||
]
|
||||
)
|
||||
self.siglip_pad_token = nn.Parameter(torch.empty((1, dim), device=device, dtype=dtype))
|
||||
else:
|
||||
self.siglip_embedder = None
|
||||
self.siglip_refiner = None
|
||||
self.siglip_pad_token = None
|
||||
|
||||
# This norm final is in the lumina 2.0 code but isn't actually used for anything.
|
||||
# self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.final_layer = FinalLayer(dim, patch_size, self.out_channels, z_image_modulation=z_image_modulation, operation_settings=operation_settings)
|
||||
|
||||
if self.pad_tokens_multiple is not None:
|
||||
@@ -530,61 +636,168 @@ class NextDiT(nn.Module):
|
||||
imgs = torch.stack(imgs, dim=0)
|
||||
return imgs
|
||||
|
||||
def patchify_and_embed(
|
||||
self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens, transformer_options={}
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor]:
|
||||
bsz = len(x)
|
||||
pH = pW = self.patch_size
|
||||
device = x[0].device
|
||||
|
||||
if self.pad_tokens_multiple is not None:
|
||||
pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple
|
||||
cap_feats = torch.cat((cap_feats, self.cap_pad_token.to(device=cap_feats.device, dtype=cap_feats.dtype, copy=True).unsqueeze(0).repeat(cap_feats.shape[0], pad_extra, 1)), dim=1)
|
||||
def embed_cap(self, cap_feats=None, offset=0, bsz=1, device=None, dtype=None):
|
||||
if cap_feats is not None:
|
||||
cap_feats = self.cap_embedder(cap_feats)
|
||||
cap_feats_len = cap_feats.shape[1]
|
||||
if self.pad_tokens_multiple is not None:
|
||||
cap_feats, _ = pad_zimage(cap_feats, self.cap_pad_token, self.pad_tokens_multiple)
|
||||
else:
|
||||
cap_feats_len = 0
|
||||
cap_feats = self.cap_pad_token.to(device=device, dtype=dtype, copy=True).unsqueeze(0).repeat(bsz, self.pad_tokens_multiple, 1)
|
||||
|
||||
cap_pos_ids = torch.zeros(bsz, cap_feats.shape[1], 3, dtype=torch.float32, device=device)
|
||||
cap_pos_ids[:, :, 0] = torch.arange(cap_feats.shape[1], dtype=torch.float32, device=device) + 1.0
|
||||
cap_pos_ids[:, :, 0] = torch.arange(cap_feats.shape[1], dtype=torch.float32, device=device) + 1.0 + offset
|
||||
embeds = (cap_feats,)
|
||||
freqs_cis = (self.rope_embedder(cap_pos_ids).movedim(1, 2),)
|
||||
return embeds, freqs_cis, cap_feats_len
|
||||
|
||||
def embed_all(self, x, cap_feats=None, siglip_feats=None, offset=0, omni=False, transformer_options={}):
|
||||
bsz = 1
|
||||
pH = pW = self.patch_size
|
||||
device = x.device
|
||||
embeds, freqs_cis, cap_feats_len = self.embed_cap(cap_feats, offset=offset, bsz=bsz, device=device, dtype=x.dtype)
|
||||
|
||||
if (not omni) or self.siglip_embedder is None:
|
||||
cap_feats_len = embeds[0].shape[1] + offset
|
||||
embeds += (None,)
|
||||
freqs_cis += (None,)
|
||||
else:
|
||||
cap_feats_len += offset
|
||||
if siglip_feats is not None:
|
||||
b, h, w, c = siglip_feats.shape
|
||||
siglip_feats = siglip_feats.permute(0, 3, 1, 2).reshape(b, h * w, c)
|
||||
siglip_feats = self.siglip_embedder(siglip_feats)
|
||||
siglip_pos_ids = torch.zeros((bsz, siglip_feats.shape[1], 3), dtype=torch.float32, device=device)
|
||||
siglip_pos_ids[:, :, 0] = cap_feats_len + 2
|
||||
siglip_pos_ids[:, :, 1] = (torch.linspace(0, h * 8 - 1, steps=h, dtype=torch.float32, device=device).floor()).view(-1, 1).repeat(1, w).flatten()
|
||||
siglip_pos_ids[:, :, 2] = (torch.linspace(0, w * 8 - 1, steps=w, dtype=torch.float32, device=device).floor()).view(1, -1).repeat(h, 1).flatten()
|
||||
if self.siglip_pad_token is not None:
|
||||
siglip_feats, pad_extra = pad_zimage(siglip_feats, self.siglip_pad_token, self.pad_tokens_multiple) # TODO: double check
|
||||
siglip_pos_ids = torch.nn.functional.pad(siglip_pos_ids, (0, 0, 0, pad_extra))
|
||||
else:
|
||||
if self.siglip_pad_token is not None:
|
||||
siglip_feats = self.siglip_pad_token.to(device=device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(bsz, self.pad_tokens_multiple, 1)
|
||||
siglip_pos_ids = torch.zeros((bsz, siglip_feats.shape[1], 3), dtype=torch.float32, device=device)
|
||||
|
||||
if siglip_feats is None:
|
||||
embeds += (None,)
|
||||
freqs_cis += (None,)
|
||||
else:
|
||||
embeds += (siglip_feats,)
|
||||
freqs_cis += (self.rope_embedder(siglip_pos_ids).movedim(1, 2),)
|
||||
|
||||
B, C, H, W = x.shape
|
||||
x = self.x_embedder(x.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2))
|
||||
|
||||
rope_options = transformer_options.get("rope_options", None)
|
||||
h_scale = 1.0
|
||||
w_scale = 1.0
|
||||
h_start = 0
|
||||
w_start = 0
|
||||
if rope_options is not None:
|
||||
h_scale = rope_options.get("scale_y", 1.0)
|
||||
w_scale = rope_options.get("scale_x", 1.0)
|
||||
|
||||
h_start = rope_options.get("shift_y", 0.0)
|
||||
w_start = rope_options.get("shift_x", 0.0)
|
||||
|
||||
H_tokens, W_tokens = H // pH, W // pW
|
||||
x_pos_ids = torch.zeros((bsz, x.shape[1], 3), dtype=torch.float32, device=device)
|
||||
x_pos_ids[:, :, 0] = cap_feats.shape[1] + 1
|
||||
x_pos_ids[:, :, 1] = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten()
|
||||
x_pos_ids[:, :, 2] = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten()
|
||||
|
||||
x_pos_ids = pos_ids_x(cap_feats_len + 1, H // pH, W // pW, bsz, device, transformer_options=transformer_options)
|
||||
if self.pad_tokens_multiple is not None:
|
||||
pad_extra = (-x.shape[1]) % self.pad_tokens_multiple
|
||||
x = torch.cat((x, self.x_pad_token.to(device=x.device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(x.shape[0], pad_extra, 1)), dim=1)
|
||||
x, pad_extra = pad_zimage(x, self.x_pad_token, self.pad_tokens_multiple)
|
||||
x_pos_ids = torch.nn.functional.pad(x_pos_ids, (0, 0, 0, pad_extra))
|
||||
|
||||
freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2)
|
||||
embeds += (x,)
|
||||
freqs_cis += (self.rope_embedder(x_pos_ids).movedim(1, 2),)
|
||||
return embeds, freqs_cis, cap_feats_len + len(freqs_cis) - 1
|
||||
|
||||
|
||||
def patchify_and_embed(
|
||||
self, x: torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens, ref_latents=[], ref_contexts=[], siglip_feats=[], transformer_options={}
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor]:
|
||||
bsz = x.shape[0]
|
||||
cap_mask = None # TODO?
|
||||
main_siglip = None
|
||||
orig_x = x
|
||||
|
||||
embeds = ([], [], [])
|
||||
freqs_cis = ([], [], [])
|
||||
leftover_cap = []
|
||||
|
||||
start_t = 0
|
||||
omni = len(ref_latents) > 0
|
||||
if omni:
|
||||
for i, ref in enumerate(ref_latents):
|
||||
if i < len(ref_contexts):
|
||||
ref_con = ref_contexts[i]
|
||||
else:
|
||||
ref_con = None
|
||||
if i < len(siglip_feats):
|
||||
sig_feat = siglip_feats[i]
|
||||
else:
|
||||
sig_feat = None
|
||||
|
||||
out = self.embed_all(ref, ref_con, sig_feat, offset=start_t, omni=omni, transformer_options=transformer_options)
|
||||
for i, e in enumerate(out[0]):
|
||||
if e is not None:
|
||||
embeds[i].append(comfy.utils.repeat_to_batch_size(e, bsz))
|
||||
freqs_cis[i].append(out[1][i])
|
||||
start_t = out[2]
|
||||
leftover_cap = ref_contexts[len(ref_latents):]
|
||||
|
||||
H, W = x.shape[-2], x.shape[-1]
|
||||
img_sizes = [(H, W)] * bsz
|
||||
out = self.embed_all(x, cap_feats, main_siglip, offset=start_t, omni=omni, transformer_options=transformer_options)
|
||||
img_len = out[0][-1].shape[1]
|
||||
cap_len = out[0][0].shape[1]
|
||||
for i, e in enumerate(out[0]):
|
||||
if e is not None:
|
||||
e = comfy.utils.repeat_to_batch_size(e, bsz)
|
||||
embeds[i].append(e)
|
||||
freqs_cis[i].append(out[1][i])
|
||||
start_t = out[2]
|
||||
|
||||
for cap in leftover_cap:
|
||||
out = self.embed_cap(cap, offset=start_t, bsz=bsz, device=x.device, dtype=x.dtype)
|
||||
cap_len += out[0][0].shape[1]
|
||||
embeds[0].append(comfy.utils.repeat_to_batch_size(out[0][0], bsz))
|
||||
freqs_cis[0].append(out[1][0])
|
||||
start_t += out[2]
|
||||
|
||||
patches = transformer_options.get("patches", {})
|
||||
|
||||
# refine context
|
||||
cap_feats = torch.cat(embeds[0], dim=1)
|
||||
cap_freqs_cis = torch.cat(freqs_cis[0], dim=1)
|
||||
for layer in self.context_refiner:
|
||||
cap_feats = layer(cap_feats, cap_mask, freqs_cis[:, :cap_pos_ids.shape[1]], transformer_options=transformer_options)
|
||||
cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis, transformer_options=transformer_options)
|
||||
|
||||
feats = (cap_feats,)
|
||||
fc = (cap_freqs_cis,)
|
||||
|
||||
if omni and len(embeds[1]) > 0:
|
||||
siglip_mask = None
|
||||
siglip_feats_combined = torch.cat(embeds[1], dim=1)
|
||||
siglip_feats_freqs_cis = torch.cat(freqs_cis[1], dim=1)
|
||||
if self.siglip_refiner is not None:
|
||||
for layer in self.siglip_refiner:
|
||||
siglip_feats_combined = layer(siglip_feats_combined, siglip_mask, siglip_feats_freqs_cis, transformer_options=transformer_options)
|
||||
feats += (siglip_feats_combined,)
|
||||
fc += (siglip_feats_freqs_cis,)
|
||||
|
||||
padded_img_mask = None
|
||||
for layer in self.noise_refiner:
|
||||
x = layer(x, padded_img_mask, freqs_cis[:, cap_pos_ids.shape[1]:], t, transformer_options=transformer_options)
|
||||
x = torch.cat(embeds[-1], dim=1)
|
||||
fc_x = torch.cat(freqs_cis[-1], dim=1)
|
||||
if omni:
|
||||
timestep_zero_index = [(x.shape[1] - img_len, x.shape[1])]
|
||||
else:
|
||||
timestep_zero_index = None
|
||||
|
||||
x_input = x
|
||||
for i, layer in enumerate(self.noise_refiner):
|
||||
x = layer(x, padded_img_mask, fc_x, t, timestep_zero_index=timestep_zero_index, transformer_options=transformer_options)
|
||||
if "noise_refiner" in patches:
|
||||
for p in patches["noise_refiner"]:
|
||||
out = p({"img": x, "img_input": x_input, "txt": cap_feats, "pe": fc_x, "vec": t, "x": orig_x, "block_index": i, "transformer_options": transformer_options, "block_type": "noise_refiner"})
|
||||
if "img" in out:
|
||||
x = out["img"]
|
||||
|
||||
padded_full_embed = torch.cat(feats + (x,), dim=1)
|
||||
if timestep_zero_index is not None:
|
||||
ind = padded_full_embed.shape[1] - x.shape[1]
|
||||
timestep_zero_index = [(ind + x.shape[1] - img_len, ind + x.shape[1])]
|
||||
timestep_zero_index.append((feats[0].shape[1] - cap_len, feats[0].shape[1]))
|
||||
|
||||
padded_full_embed = torch.cat((cap_feats, x), dim=1)
|
||||
mask = None
|
||||
img_sizes = [(H, W)] * bsz
|
||||
l_effective_cap_len = [cap_feats.shape[1]] * bsz
|
||||
return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis
|
||||
l_effective_cap_len = [padded_full_embed.shape[1] - img_len] * bsz
|
||||
return padded_full_embed, mask, img_sizes, l_effective_cap_len, torch.cat(fc + (fc_x,), dim=1), timestep_zero_index
|
||||
|
||||
def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
@@ -594,7 +807,11 @@ class NextDiT(nn.Module):
|
||||
).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs)
|
||||
|
||||
# def forward(self, x, t, cap_feats, cap_mask):
|
||||
def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, transformer_options={}, **kwargs):
|
||||
def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, ref_latents=[], ref_contexts=[], siglip_feats=[], transformer_options={}, **kwargs):
|
||||
omni = len(ref_latents) > 0
|
||||
if omni:
|
||||
timesteps = torch.cat([timesteps * 0, timesteps], dim=0)
|
||||
|
||||
t = 1.0 - timesteps
|
||||
cap_feats = context
|
||||
cap_mask = attention_mask
|
||||
@@ -609,34 +826,35 @@ class NextDiT(nn.Module):
|
||||
t = self.t_embedder(t * self.time_scale, dtype=x.dtype) # (N, D)
|
||||
adaln_input = t
|
||||
|
||||
cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute
|
||||
|
||||
if self.clip_text_pooled_proj is not None:
|
||||
pooled = kwargs.get("clip_text_pooled", None)
|
||||
if pooled is not None:
|
||||
pooled = self.clip_text_pooled_proj(pooled)
|
||||
else:
|
||||
pooled = torch.zeros((1, self.clip_text_dim), device=x.device, dtype=x.dtype)
|
||||
pooled = torch.zeros((x.shape[0], self.clip_text_dim), device=x.device, dtype=x.dtype)
|
||||
|
||||
adaln_input = self.time_text_embed(torch.cat((t, pooled), dim=-1))
|
||||
|
||||
patches = transformer_options.get("patches", {})
|
||||
x_is_tensor = isinstance(x, torch.Tensor)
|
||||
img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options)
|
||||
img, mask, img_size, cap_size, freqs_cis, timestep_zero_index = self.patchify_and_embed(x, cap_feats, cap_mask, adaln_input, num_tokens, ref_latents=ref_latents, ref_contexts=ref_contexts, siglip_feats=siglip_feats, transformer_options=transformer_options)
|
||||
freqs_cis = freqs_cis.to(img.device)
|
||||
|
||||
transformer_options["total_blocks"] = len(self.layers)
|
||||
transformer_options["block_type"] = "double"
|
||||
img_input = img
|
||||
for i, layer in enumerate(self.layers):
|
||||
img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
|
||||
transformer_options["block_index"] = i
|
||||
img = layer(img, mask, freqs_cis, adaln_input, timestep_zero_index=timestep_zero_index, transformer_options=transformer_options)
|
||||
if "double_block" in patches:
|
||||
for p in patches["double_block"]:
|
||||
out = p({"img": img[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options})
|
||||
out = p({"img": img[:, cap_size[0]:], "img_input": img_input[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options})
|
||||
if "img" in out:
|
||||
img[:, cap_size[0]:] = out["img"]
|
||||
if "txt" in out:
|
||||
img[:, :cap_size[0]] = out["txt"]
|
||||
|
||||
img = self.final_layer(img, adaln_input)
|
||||
img = self.final_layer(img, adaln_input, timestep_zero_index=timestep_zero_index)
|
||||
img = self.unpatchify(img, img_size, cap_size, return_tensor=x_is_tensor)[:, :, :h, :w]
|
||||
|
||||
return -img
|
||||
|
||||
|
||||
@@ -30,6 +30,13 @@ except ImportError as e:
|
||||
raise e
|
||||
exit(-1)
|
||||
|
||||
SAGE_ATTENTION3_IS_AVAILABLE = False
|
||||
try:
|
||||
from sageattn3 import sageattn3_blackwell
|
||||
SAGE_ATTENTION3_IS_AVAILABLE = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
FLASH_ATTENTION_IS_AVAILABLE = False
|
||||
try:
|
||||
from flash_attn import flash_attn_func
|
||||
@@ -563,6 +570,93 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
|
||||
out = out.reshape(b, -1, heads * dim_head)
|
||||
return out
|
||||
|
||||
@wrap_attn
|
||||
def attention3_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||
exception_fallback = False
|
||||
if (q.device.type != "cuda" or
|
||||
q.dtype not in (torch.float16, torch.bfloat16) or
|
||||
mask is not None):
|
||||
return attention_pytorch(
|
||||
q, k, v, heads,
|
||||
mask=mask,
|
||||
attn_precision=attn_precision,
|
||||
skip_reshape=skip_reshape,
|
||||
skip_output_reshape=skip_output_reshape,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
if skip_reshape:
|
||||
B, H, L, D = q.shape
|
||||
if H != heads:
|
||||
return attention_pytorch(
|
||||
q, k, v, heads,
|
||||
mask=mask,
|
||||
attn_precision=attn_precision,
|
||||
skip_reshape=True,
|
||||
skip_output_reshape=skip_output_reshape,
|
||||
**kwargs
|
||||
)
|
||||
q_s, k_s, v_s = q, k, v
|
||||
N = q.shape[2]
|
||||
dim_head = D
|
||||
else:
|
||||
B, N, inner_dim = q.shape
|
||||
if inner_dim % heads != 0:
|
||||
return attention_pytorch(
|
||||
q, k, v, heads,
|
||||
mask=mask,
|
||||
attn_precision=attn_precision,
|
||||
skip_reshape=False,
|
||||
skip_output_reshape=skip_output_reshape,
|
||||
**kwargs
|
||||
)
|
||||
dim_head = inner_dim // heads
|
||||
|
||||
if dim_head >= 256 or N <= 1024:
|
||||
return attention_pytorch(
|
||||
q, k, v, heads,
|
||||
mask=mask,
|
||||
attn_precision=attn_precision,
|
||||
skip_reshape=skip_reshape,
|
||||
skip_output_reshape=skip_output_reshape,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
if not skip_reshape:
|
||||
q_s, k_s, v_s = map(
|
||||
lambda t: t.view(B, -1, heads, dim_head).permute(0, 2, 1, 3).contiguous(),
|
||||
(q, k, v),
|
||||
)
|
||||
B, H, L, D = q_s.shape
|
||||
|
||||
try:
|
||||
out = sageattn3_blackwell(q_s, k_s, v_s, is_causal=False)
|
||||
except Exception as e:
|
||||
exception_fallback = True
|
||||
logging.error("Error running SageAttention3: %s, falling back to pytorch attention.", e)
|
||||
|
||||
if exception_fallback:
|
||||
if not skip_reshape:
|
||||
del q_s, k_s, v_s
|
||||
return attention_pytorch(
|
||||
q, k, v, heads,
|
||||
mask=mask,
|
||||
attn_precision=attn_precision,
|
||||
skip_reshape=False,
|
||||
skip_output_reshape=skip_output_reshape,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
if skip_reshape:
|
||||
if not skip_output_reshape:
|
||||
out = out.permute(0, 2, 1, 3).reshape(B, L, H * D)
|
||||
else:
|
||||
if skip_output_reshape:
|
||||
pass
|
||||
else:
|
||||
out = out.permute(0, 2, 1, 3).reshape(B, L, H * D)
|
||||
|
||||
return out
|
||||
|
||||
try:
|
||||
@torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
|
||||
@@ -650,6 +744,8 @@ optimized_attention_masked = optimized_attention
|
||||
# register core-supported attention functions
|
||||
if SAGE_ATTENTION_IS_AVAILABLE:
|
||||
register_attention_function("sage", attention_sage)
|
||||
if SAGE_ATTENTION3_IS_AVAILABLE:
|
||||
register_attention_function("sage3", attention3_sage)
|
||||
if FLASH_ATTENTION_IS_AVAILABLE:
|
||||
register_attention_function("flash", attention_flash)
|
||||
if model_management.xformers_enabled():
|
||||
|
||||
@@ -14,10 +14,13 @@ if model_management.xformers_enabled_vae():
|
||||
import xformers.ops
|
||||
|
||||
def torch_cat_if_needed(xl, dim):
|
||||
xl = [x for x in xl if x is not None and x.shape[dim] > 0]
|
||||
if len(xl) > 1:
|
||||
return torch.cat(xl, dim)
|
||||
else:
|
||||
elif len(xl) == 1:
|
||||
return xl[0]
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_timestep_embedding(timesteps, embedding_dim):
|
||||
"""
|
||||
@@ -394,7 +397,8 @@ class Model(nn.Module):
|
||||
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
||||
resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
|
||||
super().__init__()
|
||||
if use_linear_attn: attn_type = "linear"
|
||||
if use_linear_attn:
|
||||
attn_type = "linear"
|
||||
self.ch = ch
|
||||
self.temb_ch = self.ch*4
|
||||
self.num_resolutions = len(ch_mult)
|
||||
@@ -548,7 +552,8 @@ class Encoder(nn.Module):
|
||||
conv3d=False, time_compress=None,
|
||||
**ignore_kwargs):
|
||||
super().__init__()
|
||||
if use_linear_attn: attn_type = "linear"
|
||||
if use_linear_attn:
|
||||
attn_type = "linear"
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
|
||||
@@ -45,7 +45,7 @@ class LitEma(nn.Module):
|
||||
shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
|
||||
shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
|
||||
else:
|
||||
assert not key in self.m_name2s_name
|
||||
assert key not in self.m_name2s_name
|
||||
|
||||
def copy_to(self, model):
|
||||
m_param = dict(model.named_parameters())
|
||||
@@ -54,7 +54,7 @@ class LitEma(nn.Module):
|
||||
if m_param[key].requires_grad:
|
||||
m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
|
||||
else:
|
||||
assert not key in self.m_name2s_name
|
||||
assert key not in self.m_name2s_name
|
||||
|
||||
def store(self, parameters):
|
||||
"""
|
||||
|
||||
@@ -61,7 +61,7 @@ def apply_rotary_emb(x, freqs_cis):
|
||||
|
||||
|
||||
class QwenTimestepProjEmbeddings(nn.Module):
|
||||
def __init__(self, embedding_dim, pooled_projection_dim, dtype=None, device=None, operations=None):
|
||||
def __init__(self, embedding_dim, pooled_projection_dim, use_additional_t_cond=False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000)
|
||||
self.timestep_embedder = TimestepEmbedding(
|
||||
@@ -72,9 +72,19 @@ class QwenTimestepProjEmbeddings(nn.Module):
|
||||
operations=operations
|
||||
)
|
||||
|
||||
def forward(self, timestep, hidden_states):
|
||||
self.use_additional_t_cond = use_additional_t_cond
|
||||
if self.use_additional_t_cond:
|
||||
self.addition_t_embedding = operations.Embedding(2, embedding_dim, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, timestep, hidden_states, addition_t_cond=None):
|
||||
timesteps_proj = self.time_proj(timestep)
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype))
|
||||
|
||||
if self.use_additional_t_cond:
|
||||
if addition_t_cond is None:
|
||||
addition_t_cond = torch.zeros((timesteps_emb.shape[0]), device=timesteps_emb.device, dtype=torch.long)
|
||||
timesteps_emb += self.addition_t_embedding(addition_t_cond, out_dtype=timesteps_emb.dtype)
|
||||
|
||||
return timesteps_emb
|
||||
|
||||
|
||||
@@ -160,8 +170,14 @@ class Attention(nn.Module):
|
||||
joint_query = apply_rope1(joint_query, image_rotary_emb)
|
||||
joint_key = apply_rope1(joint_key, image_rotary_emb)
|
||||
|
||||
if encoder_hidden_states_mask is not None:
|
||||
attn_mask = torch.zeros((batch_size, 1, seq_txt + seq_img), dtype=hidden_states.dtype, device=hidden_states.device)
|
||||
attn_mask[:, 0, :seq_txt] = encoder_hidden_states_mask
|
||||
else:
|
||||
attn_mask = None
|
||||
|
||||
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads,
|
||||
attention_mask, transformer_options=transformer_options,
|
||||
attn_mask, transformer_options=transformer_options,
|
||||
skip_reshape=True)
|
||||
|
||||
txt_attn_output = joint_hidden_states[:, :seq_txt, :]
|
||||
@@ -218,9 +234,24 @@ class QwenImageTransformerBlock(nn.Module):
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
def _apply_gate(self, x, y, gate, timestep_zero_index=None):
|
||||
if timestep_zero_index is not None:
|
||||
return y + torch.cat((x[:, :timestep_zero_index] * gate[0], x[:, timestep_zero_index:] * gate[1]), dim=1)
|
||||
else:
|
||||
return torch.addcmul(y, gate, x)
|
||||
|
||||
def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor, timestep_zero_index=None) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
shift, scale, gate = torch.chunk(mod_params, 3, dim=-1)
|
||||
return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1)), gate.unsqueeze(1)
|
||||
if timestep_zero_index is not None:
|
||||
actual_batch = shift.size(0) // 2
|
||||
shift, shift_0 = shift[:actual_batch], shift[actual_batch:]
|
||||
scale, scale_0 = scale[:actual_batch], scale[actual_batch:]
|
||||
gate, gate_0 = gate[:actual_batch], gate[actual_batch:]
|
||||
reg = torch.addcmul(shift.unsqueeze(1), x[:, :timestep_zero_index], 1 + scale.unsqueeze(1))
|
||||
zero = torch.addcmul(shift_0.unsqueeze(1), x[:, timestep_zero_index:], 1 + scale_0.unsqueeze(1))
|
||||
return torch.cat((reg, zero), dim=1), (gate.unsqueeze(1), gate_0.unsqueeze(1))
|
||||
else:
|
||||
return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1)), gate.unsqueeze(1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -229,14 +260,19 @@ class QwenImageTransformerBlock(nn.Module):
|
||||
encoder_hidden_states_mask: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
timestep_zero_index=None,
|
||||
transformer_options={},
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
img_mod_params = self.img_mod(temb)
|
||||
|
||||
if timestep_zero_index is not None:
|
||||
temb = temb.chunk(2, dim=0)[0]
|
||||
|
||||
txt_mod_params = self.txt_mod(temb)
|
||||
img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1)
|
||||
txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1)
|
||||
|
||||
img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1)
|
||||
img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1, timestep_zero_index)
|
||||
del img_mod1
|
||||
txt_modulated, txt_gate1 = self._modulate(self.txt_norm1(encoder_hidden_states), txt_mod1)
|
||||
del txt_mod1
|
||||
@@ -251,15 +287,15 @@ class QwenImageTransformerBlock(nn.Module):
|
||||
del img_modulated
|
||||
del txt_modulated
|
||||
|
||||
hidden_states = hidden_states + img_gate1 * img_attn_output
|
||||
hidden_states = self._apply_gate(img_attn_output, hidden_states, img_gate1, timestep_zero_index)
|
||||
encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
|
||||
del img_attn_output
|
||||
del txt_attn_output
|
||||
del img_gate1
|
||||
del txt_gate1
|
||||
|
||||
img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2)
|
||||
hidden_states = torch.addcmul(hidden_states, img_gate2, self.img_mlp(img_modulated2))
|
||||
img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2, timestep_zero_index)
|
||||
hidden_states = self._apply_gate(self.img_mlp(img_modulated2), hidden_states, img_gate2, timestep_zero_index)
|
||||
|
||||
txt_modulated2, txt_gate2 = self._modulate(self.txt_norm2(encoder_hidden_states), txt_mod2)
|
||||
encoder_hidden_states = torch.addcmul(encoder_hidden_states, txt_gate2, self.txt_mlp(txt_modulated2))
|
||||
@@ -300,10 +336,11 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
num_attention_heads: int = 24,
|
||||
joint_attention_dim: int = 3584,
|
||||
pooled_projection_dim: int = 768,
|
||||
guidance_embeds: bool = False,
|
||||
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
|
||||
default_ref_method="index",
|
||||
image_model=None,
|
||||
final_layer=True,
|
||||
use_additional_t_cond=False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
@@ -314,12 +351,14 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels or in_channels
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
self.default_ref_method = default_ref_method
|
||||
|
||||
self.pe_embedder = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope))
|
||||
|
||||
self.time_text_embed = QwenTimestepProjEmbeddings(
|
||||
embedding_dim=self.inner_dim,
|
||||
pooled_projection_dim=pooled_projection_dim,
|
||||
use_additional_t_cond=use_additional_t_cond,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
@@ -341,6 +380,9 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
|
||||
if self.default_ref_method == "index_timestep_zero":
|
||||
self.register_buffer("__index_timestep_zero__", torch.tensor([]))
|
||||
|
||||
if final_layer:
|
||||
self.norm_out = LastLayer(self.inner_dim, self.inner_dim, dtype=dtype, device=device, operations=operations)
|
||||
self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device)
|
||||
@@ -350,27 +392,33 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
patch_size = self.patch_size
|
||||
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (1, self.patch_size, self.patch_size))
|
||||
orig_shape = hidden_states.shape
|
||||
hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
|
||||
hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5)
|
||||
hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
|
||||
hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-3], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 5, 1, 4, 6)
|
||||
hidden_states = hidden_states.reshape(orig_shape[0], orig_shape[-3] * (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
|
||||
t_len = t
|
||||
h_len = ((h + (patch_size // 2)) // patch_size)
|
||||
w_len = ((w + (patch_size // 2)) // patch_size)
|
||||
|
||||
h_offset = ((h_offset + (patch_size // 2)) // patch_size)
|
||||
w_offset = ((w_offset + (patch_size // 2)) // patch_size)
|
||||
|
||||
img_ids = torch.zeros((h_len, w_len, 3), device=x.device)
|
||||
img_ids[:, :, 0] = img_ids[:, :, 1] + index
|
||||
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) - (h_len // 2)
|
||||
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) - (w_len // 2)
|
||||
return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape
|
||||
img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device)
|
||||
|
||||
def forward(self, x, timestep, context, attention_mask=None, guidance=None, ref_latents=None, transformer_options={}, **kwargs):
|
||||
if t_len > 1:
|
||||
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).unsqueeze(1).unsqueeze(1)
|
||||
else:
|
||||
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + index
|
||||
|
||||
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1).unsqueeze(0) - (h_len // 2)
|
||||
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0).unsqueeze(0) - (w_len // 2)
|
||||
return hidden_states, repeat(img_ids, "t h w c -> b (t h w) c", b=bs), orig_shape
|
||||
|
||||
def forward(self, x, timestep, context, attention_mask=None, ref_latents=None, additional_t_cond=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timestep, context, attention_mask, guidance, ref_latents, transformer_options, **kwargs)
|
||||
).execute(x, timestep, context, attention_mask, ref_latents, additional_t_cond, transformer_options, **kwargs)
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
@@ -378,8 +426,8 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
timesteps,
|
||||
context,
|
||||
attention_mask=None,
|
||||
guidance: torch.Tensor = None,
|
||||
ref_latents=None,
|
||||
additional_t_cond=None,
|
||||
transformer_options={},
|
||||
control=None,
|
||||
**kwargs
|
||||
@@ -388,19 +436,30 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
encoder_hidden_states = context
|
||||
encoder_hidden_states_mask = attention_mask
|
||||
|
||||
if encoder_hidden_states_mask is not None and not torch.is_floating_point(encoder_hidden_states_mask):
|
||||
encoder_hidden_states_mask = (encoder_hidden_states_mask - 1).to(x.dtype) * torch.finfo(x.dtype).max
|
||||
|
||||
hidden_states, img_ids, orig_shape = self.process_img(x)
|
||||
num_embeds = hidden_states.shape[1]
|
||||
|
||||
timestep_zero_index = None
|
||||
if ref_latents is not None:
|
||||
h = 0
|
||||
w = 0
|
||||
index = 0
|
||||
index_ref_method = kwargs.get("ref_latents_method", "index") == "index"
|
||||
ref_method = kwargs.get("ref_latents_method", self.default_ref_method)
|
||||
index_ref_method = (ref_method == "index") or (ref_method == "index_timestep_zero")
|
||||
negative_ref_method = ref_method == "negative_index"
|
||||
timestep_zero = ref_method == "index_timestep_zero"
|
||||
for ref in ref_latents:
|
||||
if index_ref_method:
|
||||
index += 1
|
||||
h_offset = 0
|
||||
w_offset = 0
|
||||
elif negative_ref_method:
|
||||
index -= 1
|
||||
h_offset = 0
|
||||
w_offset = 0
|
||||
else:
|
||||
index = 1
|
||||
h_offset = 0
|
||||
@@ -415,6 +474,10 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
|
||||
hidden_states = torch.cat([hidden_states, kontext], dim=1)
|
||||
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
|
||||
if timestep_zero:
|
||||
if index > 0:
|
||||
timestep = torch.cat([timestep, timestep * 0], dim=0)
|
||||
timestep_zero_index = num_embeds
|
||||
|
||||
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
|
||||
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
|
||||
@@ -426,14 +489,7 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
|
||||
encoder_hidden_states = self.txt_in(encoder_hidden_states)
|
||||
|
||||
if guidance is not None:
|
||||
guidance = guidance * 1000
|
||||
|
||||
temb = (
|
||||
self.time_text_embed(timestep, hidden_states)
|
||||
if guidance is None
|
||||
else self.time_text_embed(timestep, guidance, hidden_states)
|
||||
)
|
||||
temb = self.time_text_embed(timestep, hidden_states, additional_t_cond)
|
||||
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
patches = transformer_options.get("patches", {})
|
||||
@@ -446,7 +502,7 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], transformer_options=args["transformer_options"])
|
||||
out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], timestep_zero_index=timestep_zero_index, transformer_options=args["transformer_options"])
|
||||
return out
|
||||
out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||
hidden_states = out["img"]
|
||||
@@ -458,6 +514,7 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
timestep_zero_index=timestep_zero_index,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
@@ -474,9 +531,12 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
if add is not None:
|
||||
hidden_states[:, :add.shape[1]] += add
|
||||
|
||||
if timestep_zero_index is not None:
|
||||
temb = temb.chunk(2, dim=0)[0]
|
||||
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
hidden_states = hidden_states[:, :num_embeds].view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
|
||||
hidden_states = hidden_states.permute(0, 3, 1, 4, 2, 5)
|
||||
hidden_states = hidden_states[:, :num_embeds].view(orig_shape[0], orig_shape[-3], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
|
||||
hidden_states = hidden_states.permute(0, 4, 1, 2, 5, 3, 6)
|
||||
return hidden_states.reshape(orig_shape)[:, :, :, :x.shape[-2], :x.shape[-1]]
|
||||
|
||||
@@ -71,7 +71,7 @@ def count_params(model, verbose=False):
|
||||
|
||||
|
||||
def instantiate_from_config(config):
|
||||
if not "target" in config:
|
||||
if "target" not in config:
|
||||
if config == '__is_first_stage__':
|
||||
return None
|
||||
elif config == "__is_unconditional__":
|
||||
|
||||
@@ -62,6 +62,8 @@ class WanSelfAttention(nn.Module):
|
||||
x(Tensor): Shape [B, L, num_heads, C / num_heads]
|
||||
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
|
||||
"""
|
||||
patches = transformer_options.get("patches", {})
|
||||
|
||||
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
|
||||
|
||||
def qkv_fn_q(x):
|
||||
@@ -86,6 +88,10 @@ class WanSelfAttention(nn.Module):
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
if "attn1_patch" in patches:
|
||||
for p in patches["attn1_patch"]:
|
||||
x = p({"x": x, "q": q, "k": k, "transformer_options": transformer_options})
|
||||
|
||||
x = self.o(x)
|
||||
return x
|
||||
|
||||
@@ -225,6 +231,8 @@ class WanAttentionBlock(nn.Module):
|
||||
"""
|
||||
# assert e.dtype == torch.float32
|
||||
|
||||
patches = transformer_options.get("patches", {})
|
||||
|
||||
if e.ndim < 4:
|
||||
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
|
||||
else:
|
||||
@@ -242,6 +250,11 @@ class WanAttentionBlock(nn.Module):
|
||||
|
||||
# cross-attention & ffn
|
||||
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)
|
||||
|
||||
if "attn2_patch" in patches:
|
||||
for p in patches["attn2_patch"]:
|
||||
x = p({"x": x, "transformer_options": transformer_options})
|
||||
|
||||
y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
|
||||
x = torch.addcmul(x, y, repeat_e(e[5], x))
|
||||
return x
|
||||
@@ -488,7 +501,7 @@ class WanModel(torch.nn.Module):
|
||||
self.blocks = nn.ModuleList([
|
||||
wan_attn_block_class(cross_attn_type, dim, ffn_dim, num_heads,
|
||||
window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings)
|
||||
for _ in range(num_layers)
|
||||
for i in range(num_layers)
|
||||
])
|
||||
|
||||
# head
|
||||
@@ -541,6 +554,7 @@ class WanModel(torch.nn.Module):
|
||||
# embeddings
|
||||
x = self.patch_embedding(x.float()).to(x.dtype)
|
||||
grid_sizes = x.shape[2:]
|
||||
transformer_options["grid_sizes"] = grid_sizes
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
|
||||
# time embeddings
|
||||
@@ -568,7 +582,10 @@ class WanModel(torch.nn.Module):
|
||||
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
transformer_options["total_blocks"] = len(self.blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
@@ -735,6 +752,7 @@ class VaceWanModel(WanModel):
|
||||
# embeddings
|
||||
x = self.patch_embedding(x.float()).to(x.dtype)
|
||||
grid_sizes = x.shape[2:]
|
||||
transformer_options["grid_sizes"] = grid_sizes
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
|
||||
# time embeddings
|
||||
@@ -763,7 +781,10 @@ class VaceWanModel(WanModel):
|
||||
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
transformer_options["total_blocks"] = len(self.blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
@@ -862,7 +883,10 @@ class CameraWanModel(WanModel):
|
||||
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
transformer_options["total_blocks"] = len(self.blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
@@ -1326,16 +1350,19 @@ class WanModel_S2V(WanModel):
|
||||
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
transformer_options["total_blocks"] = len(self.blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"])
|
||||
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], transformer_options=args["transformer_options"])
|
||||
return out
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||
x = out["img"]
|
||||
else:
|
||||
x = block(x, e=e0, freqs=freqs, context=context)
|
||||
x = block(x, e=e0, freqs=freqs, context=context, transformer_options=transformer_options)
|
||||
if audio_emb is not None:
|
||||
x = self.audio_injector(x, i, audio_emb, audio_emb_global, seq_len)
|
||||
# head
|
||||
@@ -1574,7 +1601,10 @@ class HumoWanModel(WanModel):
|
||||
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
transformer_options["total_blocks"] = len(self.blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
|
||||
@@ -523,7 +523,10 @@ class AnimateWanModel(WanModel):
|
||||
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
transformer_options["total_blocks"] = len(self.blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
|
||||
500
comfy/ldm/wan/model_multitalk.py
Normal file
500
comfy/ldm/wan/model_multitalk.py
Normal file
@@ -0,0 +1,500 @@
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
import comfy
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
|
||||
def calculate_x_ref_attn_map(visual_q, ref_k, ref_target_masks, split_num=8):
|
||||
scale = 1.0 / visual_q.shape[-1] ** 0.5
|
||||
visual_q = visual_q.transpose(1, 2) * scale
|
||||
|
||||
B, H, x_seqlens, K = visual_q.shape
|
||||
|
||||
x_ref_attn_maps = []
|
||||
for class_idx, ref_target_mask in enumerate(ref_target_masks):
|
||||
ref_target_mask = ref_target_mask.view(1, 1, 1, -1)
|
||||
|
||||
x_ref_attnmap = torch.zeros(B, H, x_seqlens, device=visual_q.device, dtype=visual_q.dtype)
|
||||
chunk_size = min(max(x_seqlens // split_num, 1), x_seqlens)
|
||||
|
||||
for i in range(0, x_seqlens, chunk_size):
|
||||
end_i = min(i + chunk_size, x_seqlens)
|
||||
|
||||
attn_chunk = visual_q[:, :, i:end_i] @ ref_k.permute(0, 2, 3, 1) # B, H, chunk, ref_seqlens
|
||||
|
||||
# Apply softmax
|
||||
attn_max = attn_chunk.max(dim=-1, keepdim=True).values
|
||||
attn_chunk = (attn_chunk - attn_max).exp()
|
||||
attn_sum = attn_chunk.sum(dim=-1, keepdim=True)
|
||||
attn_chunk = attn_chunk / (attn_sum + 1e-8)
|
||||
|
||||
# Apply mask and sum
|
||||
masked_attn = attn_chunk * ref_target_mask
|
||||
x_ref_attnmap[:, :, i:end_i] = masked_attn.sum(-1) / (ref_target_mask.sum() + 1e-8)
|
||||
|
||||
del attn_chunk, masked_attn
|
||||
|
||||
# Average across heads
|
||||
x_ref_attnmap = x_ref_attnmap.mean(dim=1) # B, x_seqlens
|
||||
x_ref_attn_maps.append(x_ref_attnmap)
|
||||
|
||||
del visual_q, ref_k
|
||||
|
||||
return torch.cat(x_ref_attn_maps, dim=0)
|
||||
|
||||
def get_attn_map_with_target(visual_q, ref_k, shape, ref_target_masks=None, split_num=2):
|
||||
"""Args:
|
||||
query (torch.tensor): B M H K
|
||||
key (torch.tensor): B M H K
|
||||
shape (tuple): (N_t, N_h, N_w)
|
||||
ref_target_masks: [B, N_h * N_w]
|
||||
"""
|
||||
|
||||
N_t, N_h, N_w = shape
|
||||
|
||||
x_seqlens = N_h * N_w
|
||||
ref_k = ref_k[:, :x_seqlens]
|
||||
_, seq_lens, heads, _ = visual_q.shape
|
||||
class_num, _ = ref_target_masks.shape
|
||||
x_ref_attn_maps = torch.zeros(class_num, seq_lens).to(visual_q)
|
||||
|
||||
split_chunk = heads // split_num
|
||||
|
||||
for i in range(split_num):
|
||||
x_ref_attn_maps_perhead = calculate_x_ref_attn_map(
|
||||
visual_q[:, :, i*split_chunk:(i+1)*split_chunk, :],
|
||||
ref_k[:, :, i*split_chunk:(i+1)*split_chunk, :],
|
||||
ref_target_masks
|
||||
)
|
||||
x_ref_attn_maps += x_ref_attn_maps_perhead
|
||||
|
||||
return x_ref_attn_maps / split_num
|
||||
|
||||
|
||||
def normalize_and_scale(column, source_range, target_range, epsilon=1e-8):
|
||||
source_min, source_max = source_range
|
||||
new_min, new_max = target_range
|
||||
normalized = (column - source_min) / (source_max - source_min + epsilon)
|
||||
scaled = normalized * (new_max - new_min) + new_min
|
||||
return scaled
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
x = rearrange(x, "... (d r) -> ... d r", r=2)
|
||||
x1, x2 = x.unbind(dim=-1)
|
||||
x = torch.stack((-x2, x1), dim=-1)
|
||||
return rearrange(x, "... d r -> ... (d r)")
|
||||
|
||||
|
||||
def get_audio_embeds(encoded_audio, audio_start, audio_end):
|
||||
audio_embs = []
|
||||
human_num = len(encoded_audio)
|
||||
audio_frames = encoded_audio[0].shape[0]
|
||||
|
||||
indices = (torch.arange(4 + 1) - 2) * 1
|
||||
|
||||
for human_idx in range(human_num):
|
||||
if audio_end > audio_frames: # in case of not enough audio for current window, pad with first audio frame as that's most likely silence
|
||||
pad_len = audio_end - audio_frames
|
||||
pad_shape = list(encoded_audio[human_idx].shape)
|
||||
pad_shape[0] = pad_len
|
||||
pad_tensor = encoded_audio[human_idx][:1].repeat(pad_len, *([1] * (encoded_audio[human_idx].dim() - 1)))
|
||||
encoded_audio_in = torch.cat([encoded_audio[human_idx], pad_tensor], dim=0)
|
||||
else:
|
||||
encoded_audio_in = encoded_audio[human_idx]
|
||||
center_indices = torch.arange(audio_start, audio_end, 1).unsqueeze(1) + indices.unsqueeze(0)
|
||||
center_indices = torch.clamp(center_indices, min=0, max=encoded_audio_in.shape[0] - 1)
|
||||
audio_emb = encoded_audio_in[center_indices].unsqueeze(0)
|
||||
audio_embs.append(audio_emb)
|
||||
|
||||
return torch.cat(audio_embs, dim=0)
|
||||
|
||||
|
||||
def project_audio_features(audio_proj, encoded_audio, audio_start, audio_end):
|
||||
audio_embs = get_audio_embeds(encoded_audio, audio_start, audio_end)
|
||||
|
||||
first_frame_audio_emb_s = audio_embs[:, :1, ...]
|
||||
latter_frame_audio_emb = audio_embs[:, 1:, ...]
|
||||
latter_frame_audio_emb = rearrange(latter_frame_audio_emb, "b (n_t n) w s c -> b n_t n w s c", n=4)
|
||||
|
||||
middle_index = audio_proj.seq_len // 2
|
||||
|
||||
latter_first_frame_audio_emb = latter_frame_audio_emb[:, :, :1, :middle_index+1, ...]
|
||||
latter_first_frame_audio_emb = rearrange(latter_first_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c")
|
||||
latter_last_frame_audio_emb = latter_frame_audio_emb[:, :, -1:, middle_index:, ...]
|
||||
latter_last_frame_audio_emb = rearrange(latter_last_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c")
|
||||
latter_middle_frame_audio_emb = latter_frame_audio_emb[:, :, 1:-1, middle_index:middle_index+1, ...]
|
||||
latter_middle_frame_audio_emb = rearrange(latter_middle_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c")
|
||||
latter_frame_audio_emb_s = torch.cat([latter_first_frame_audio_emb, latter_middle_frame_audio_emb, latter_last_frame_audio_emb], dim=2)
|
||||
|
||||
audio_emb = audio_proj(first_frame_audio_emb_s, latter_frame_audio_emb_s)
|
||||
audio_emb = torch.cat(audio_emb.split(1), dim=2)
|
||||
|
||||
return audio_emb
|
||||
|
||||
|
||||
class RotaryPositionalEmbedding1D(torch.nn.Module):
|
||||
def __init__(self,
|
||||
head_dim,
|
||||
):
|
||||
super().__init__()
|
||||
self.head_dim = head_dim
|
||||
self.base = 10000
|
||||
|
||||
def precompute_freqs_cis_1d(self, pos_indices):
|
||||
freqs = 1.0 / (self.base ** (torch.arange(0, self.head_dim, 2)[: (self.head_dim // 2)].float() / self.head_dim))
|
||||
freqs = freqs.to(pos_indices.device)
|
||||
freqs = torch.einsum("..., f -> ... f", pos_indices.float(), freqs)
|
||||
freqs = repeat(freqs, "... n -> ... (n r)", r=2)
|
||||
return freqs
|
||||
|
||||
def forward(self, x, pos_indices):
|
||||
freqs_cis = self.precompute_freqs_cis_1d(pos_indices)
|
||||
|
||||
x_ = x.float()
|
||||
|
||||
freqs_cis = freqs_cis.float().to(x.device)
|
||||
cos, sin = freqs_cis.cos(), freqs_cis.sin()
|
||||
cos, sin = rearrange(cos, 'n d -> 1 1 n d'), rearrange(sin, 'n d -> 1 1 n d')
|
||||
x_ = (x_ * cos) + (rotate_half(x_) * sin)
|
||||
|
||||
return x_.type_as(x)
|
||||
|
||||
class SingleStreamAttention(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
encoder_hidden_states_dim: int,
|
||||
num_heads: int,
|
||||
qkv_bias: bool,
|
||||
device=None, dtype=None, operations=None
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.encoder_hidden_states_dim = encoder_hidden_states_dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
|
||||
self.q_linear = operations.Linear(dim, dim, bias=qkv_bias, device=device, dtype=dtype)
|
||||
self.proj = operations.Linear(dim, dim, device=device, dtype=dtype)
|
||||
self.kv_linear = operations.Linear(encoder_hidden_states_dim, dim * 2, bias=qkv_bias, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x: torch.Tensor, encoder_hidden_states: torch.Tensor, shape=None) -> torch.Tensor:
|
||||
N_t, N_h, N_w = shape
|
||||
|
||||
expected_tokens = N_t * N_h * N_w
|
||||
actual_tokens = x.shape[1]
|
||||
x_extra = None
|
||||
|
||||
if actual_tokens != expected_tokens:
|
||||
x_extra = x[:, -N_h * N_w:, :]
|
||||
x = x[:, :-N_h * N_w, :]
|
||||
N_t = N_t - 1
|
||||
|
||||
B = x.shape[0]
|
||||
S = N_h * N_w
|
||||
x = x.view(B * N_t, S, self.dim)
|
||||
|
||||
# get q for hidden_state
|
||||
q = self.q_linear(x).view(B * N_t, S, self.num_heads, self.head_dim)
|
||||
|
||||
# get kv from encoder_hidden_states # shape: (B, N, num_heads, head_dim)
|
||||
kv = self.kv_linear(encoder_hidden_states)
|
||||
encoder_k, encoder_v = kv.view(B * N_t, encoder_hidden_states.shape[1], 2, self.num_heads, self.head_dim).unbind(2)
|
||||
|
||||
#print("q.shape", q.shape) #torch.Size([21, 1024, 40, 128])
|
||||
x = optimized_attention(
|
||||
q.transpose(1, 2),
|
||||
encoder_k.transpose(1, 2),
|
||||
encoder_v.transpose(1, 2),
|
||||
heads=self.num_heads, skip_reshape=True, skip_output_reshape=True).transpose(1, 2)
|
||||
|
||||
# linear transform
|
||||
x = self.proj(x.reshape(B * N_t, S, self.dim))
|
||||
x = x.view(B, N_t * S, self.dim)
|
||||
|
||||
if x_extra is not None:
|
||||
x = torch.cat([x, torch.zeros_like(x_extra)], dim=1)
|
||||
|
||||
return x
|
||||
|
||||
class SingleStreamMultiAttention(SingleStreamAttention):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
encoder_hidden_states_dim: int,
|
||||
num_heads: int,
|
||||
qkv_bias: bool,
|
||||
class_range: int = 24,
|
||||
class_interval: int = 4,
|
||||
device=None, dtype=None, operations=None
|
||||
) -> None:
|
||||
super().__init__(
|
||||
dim=dim,
|
||||
encoder_hidden_states_dim=encoder_hidden_states_dim,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
operations=operations
|
||||
)
|
||||
|
||||
# Rotary-embedding layout parameters
|
||||
self.class_interval = class_interval
|
||||
self.class_range = class_range
|
||||
self.max_humans = self.class_range // self.class_interval
|
||||
|
||||
# Constant bucket used for background tokens
|
||||
self.rope_bak = int(self.class_range // 2)
|
||||
|
||||
self.rope_1d = RotaryPositionalEmbedding1D(self.head_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
shape=None,
|
||||
x_ref_attn_map=None
|
||||
) -> torch.Tensor:
|
||||
encoder_hidden_states = encoder_hidden_states.squeeze(0).to(x.device)
|
||||
human_num = x_ref_attn_map.shape[0] if x_ref_attn_map is not None else 1
|
||||
# Single-speaker fall-through
|
||||
if human_num <= 1:
|
||||
return super().forward(x, encoder_hidden_states, shape)
|
||||
|
||||
N_t, N_h, N_w = shape
|
||||
|
||||
x_extra = None
|
||||
if x.shape[0] * N_t != encoder_hidden_states.shape[0]:
|
||||
x_extra = x[:, -N_h * N_w:, :]
|
||||
x = x[:, :-N_h * N_w, :]
|
||||
N_t = N_t - 1
|
||||
x = rearrange(x, "B (N_t S) C -> (B N_t) S C", N_t=N_t)
|
||||
|
||||
# Query projection
|
||||
B, N, C = x.shape
|
||||
q = self.q_linear(x)
|
||||
q = q.view(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
|
||||
|
||||
# Use `class_range` logic for 2 speakers
|
||||
rope_h1 = (0, self.class_interval)
|
||||
rope_h2 = (self.class_range - self.class_interval, self.class_range)
|
||||
rope_bak = int(self.class_range // 2)
|
||||
|
||||
# Normalize and scale attention maps for each speaker
|
||||
max_values = x_ref_attn_map.max(1).values[:, None, None]
|
||||
min_values = x_ref_attn_map.min(1).values[:, None, None]
|
||||
max_min_values = torch.cat([max_values, min_values], dim=2)
|
||||
|
||||
human1_max_value, human1_min_value = max_min_values[0, :, 0].max(), max_min_values[0, :, 1].min()
|
||||
human2_max_value, human2_min_value = max_min_values[1, :, 0].max(), max_min_values[1, :, 1].min()
|
||||
|
||||
human1 = normalize_and_scale(x_ref_attn_map[0], (human1_min_value, human1_max_value), rope_h1)
|
||||
human2 = normalize_and_scale(x_ref_attn_map[1], (human2_min_value, human2_max_value), rope_h2)
|
||||
back = torch.full((x_ref_attn_map.size(1),), rope_bak, dtype=human1.dtype, device=human1.device)
|
||||
|
||||
# Token-wise speaker dominance
|
||||
max_indices = x_ref_attn_map.argmax(dim=0)
|
||||
normalized_map = torch.stack([human1, human2, back], dim=1)
|
||||
normalized_pos = normalized_map[torch.arange(x_ref_attn_map.size(1)), max_indices]
|
||||
|
||||
# Apply rotary to Q
|
||||
q = rearrange(q, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t)
|
||||
q = self.rope_1d(q, normalized_pos)
|
||||
q = rearrange(q, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t)
|
||||
|
||||
# Keys / Values
|
||||
_, N_a, _ = encoder_hidden_states.shape
|
||||
encoder_kv = self.kv_linear(encoder_hidden_states)
|
||||
encoder_kv = encoder_kv.view(B, N_a, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
|
||||
encoder_k, encoder_v = encoder_kv.unbind(0)
|
||||
|
||||
# Rotary for keys – assign centre of each speaker bucket to its context tokens
|
||||
per_frame = torch.zeros(N_a, dtype=encoder_k.dtype, device=encoder_k.device)
|
||||
per_frame[: per_frame.size(0) // 2] = (rope_h1[0] + rope_h1[1]) / 2
|
||||
per_frame[per_frame.size(0) // 2 :] = (rope_h2[0] + rope_h2[1]) / 2
|
||||
encoder_pos = torch.cat([per_frame] * N_t, dim=0)
|
||||
|
||||
encoder_k = rearrange(encoder_k, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t)
|
||||
encoder_k = self.rope_1d(encoder_k, encoder_pos)
|
||||
encoder_k = rearrange(encoder_k, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t)
|
||||
|
||||
# Final attention
|
||||
q = rearrange(q, "B H M K -> B M H K")
|
||||
encoder_k = rearrange(encoder_k, "B H M K -> B M H K")
|
||||
encoder_v = rearrange(encoder_v, "B H M K -> B M H K")
|
||||
|
||||
x = optimized_attention(
|
||||
q.transpose(1, 2),
|
||||
encoder_k.transpose(1, 2),
|
||||
encoder_v.transpose(1, 2),
|
||||
heads=self.num_heads, skip_reshape=True, skip_output_reshape=True).transpose(1, 2)
|
||||
|
||||
# Linear projection
|
||||
x = x.reshape(B, N, C)
|
||||
x = self.proj(x)
|
||||
|
||||
# Restore original layout
|
||||
x = rearrange(x, "(B N_t) S C -> B (N_t S) C", N_t=N_t)
|
||||
if x_extra is not None:
|
||||
x = torch.cat([x, torch.zeros_like(x_extra)], dim=1)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class MultiTalkAudioProjModel(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
seq_len: int = 5,
|
||||
seq_len_vf: int = 12,
|
||||
blocks: int = 12,
|
||||
channels: int = 768,
|
||||
intermediate_dim: int = 512,
|
||||
out_dim: int = 768,
|
||||
context_tokens: int = 32,
|
||||
device=None, dtype=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.seq_len = seq_len
|
||||
self.blocks = blocks
|
||||
self.channels = channels
|
||||
self.input_dim = seq_len * blocks * channels
|
||||
self.input_dim_vf = seq_len_vf * blocks * channels
|
||||
self.intermediate_dim = intermediate_dim
|
||||
self.context_tokens = context_tokens
|
||||
self.out_dim = out_dim
|
||||
|
||||
# define multiple linear layers
|
||||
self.proj1 = operations.Linear(self.input_dim, intermediate_dim, device=device, dtype=dtype)
|
||||
self.proj1_vf = operations.Linear(self.input_dim_vf, intermediate_dim, device=device, dtype=dtype)
|
||||
self.proj2 = operations.Linear(intermediate_dim, intermediate_dim, device=device, dtype=dtype)
|
||||
self.proj3 = operations.Linear(intermediate_dim, context_tokens * out_dim, device=device, dtype=dtype)
|
||||
self.norm = operations.LayerNorm(out_dim, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, audio_embeds, audio_embeds_vf):
|
||||
video_length = audio_embeds.shape[1] + audio_embeds_vf.shape[1]
|
||||
B, _, _, S, C = audio_embeds.shape
|
||||
|
||||
# process audio of first frame
|
||||
audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c")
|
||||
batch_size, window_size, blocks, channels = audio_embeds.shape
|
||||
audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels)
|
||||
|
||||
# process audio of latter frame
|
||||
audio_embeds_vf = rearrange(audio_embeds_vf, "bz f w b c -> (bz f) w b c")
|
||||
batch_size_vf, window_size_vf, blocks_vf, channels_vf = audio_embeds_vf.shape
|
||||
audio_embeds_vf = audio_embeds_vf.view(batch_size_vf, window_size_vf * blocks_vf * channels_vf)
|
||||
|
||||
# first projection
|
||||
audio_embeds = torch.relu(self.proj1(audio_embeds))
|
||||
audio_embeds_vf = torch.relu(self.proj1_vf(audio_embeds_vf))
|
||||
audio_embeds = rearrange(audio_embeds, "(bz f) c -> bz f c", bz=B)
|
||||
audio_embeds_vf = rearrange(audio_embeds_vf, "(bz f) c -> bz f c", bz=B)
|
||||
audio_embeds_c = torch.concat([audio_embeds, audio_embeds_vf], dim=1)
|
||||
batch_size_c, N_t, C_a = audio_embeds_c.shape
|
||||
audio_embeds_c = audio_embeds_c.view(batch_size_c*N_t, C_a)
|
||||
|
||||
# second projection
|
||||
audio_embeds_c = torch.relu(self.proj2(audio_embeds_c))
|
||||
|
||||
context_tokens = self.proj3(audio_embeds_c).reshape(batch_size_c*N_t, self.context_tokens, self.out_dim)
|
||||
|
||||
# normalization and reshape
|
||||
context_tokens = self.norm(context_tokens)
|
||||
context_tokens = rearrange(context_tokens, "(bz f) m c -> bz f m c", f=video_length)
|
||||
|
||||
return context_tokens
|
||||
|
||||
|
||||
class WanMultiTalkAttentionBlock(torch.nn.Module):
|
||||
def __init__(self, in_dim=5120, out_dim=768, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.audio_cross_attn = SingleStreamMultiAttention(in_dim, out_dim, num_heads=40, qkv_bias=True, device=device, dtype=dtype, operations=operations)
|
||||
self.norm_x = operations.LayerNorm(in_dim, device=device, dtype=dtype, elementwise_affine=True)
|
||||
|
||||
|
||||
class MultiTalkGetAttnMapPatch:
|
||||
def __init__(self, ref_target_masks=None):
|
||||
self.ref_target_masks = ref_target_masks
|
||||
|
||||
def __call__(self, kwargs):
|
||||
transformer_options = kwargs.get("transformer_options", {})
|
||||
x = kwargs["x"]
|
||||
|
||||
if self.ref_target_masks is not None:
|
||||
x_ref_attn_map = get_attn_map_with_target(kwargs["q"], kwargs["k"], transformer_options["grid_sizes"], ref_target_masks=self.ref_target_masks.to(x.device))
|
||||
transformer_options["x_ref_attn_map"] = x_ref_attn_map
|
||||
return x
|
||||
|
||||
|
||||
class MultiTalkCrossAttnPatch:
|
||||
def __init__(self, model_patch, audio_scale=1.0, ref_target_masks=None):
|
||||
self.model_patch = model_patch
|
||||
self.audio_scale = audio_scale
|
||||
self.ref_target_masks = ref_target_masks
|
||||
|
||||
def __call__(self, kwargs):
|
||||
transformer_options = kwargs.get("transformer_options", {})
|
||||
block_idx = transformer_options.get("block_index", None)
|
||||
x = kwargs["x"]
|
||||
if block_idx is None:
|
||||
return torch.zeros_like(x)
|
||||
|
||||
audio_embeds = transformer_options.get("audio_embeds")
|
||||
x_ref_attn_map = transformer_options.pop("x_ref_attn_map", None)
|
||||
|
||||
norm_x = self.model_patch.model.blocks[block_idx].norm_x(x)
|
||||
x_audio = self.model_patch.model.blocks[block_idx].audio_cross_attn(
|
||||
norm_x, audio_embeds.to(x.dtype),
|
||||
shape=transformer_options["grid_sizes"],
|
||||
x_ref_attn_map=x_ref_attn_map
|
||||
)
|
||||
x = x + x_audio * self.audio_scale
|
||||
return x
|
||||
|
||||
def models(self):
|
||||
return [self.model_patch]
|
||||
|
||||
class MultiTalkApplyModelWrapper:
|
||||
def __init__(self, init_latents):
|
||||
self.init_latents = init_latents
|
||||
|
||||
def __call__(self, executor, x, *args, **kwargs):
|
||||
x[:, :, :self.init_latents.shape[2]] = self.init_latents.to(x)
|
||||
samples = executor(x, *args, **kwargs)
|
||||
return samples
|
||||
|
||||
|
||||
class InfiniteTalkOuterSampleWrapper:
|
||||
def __init__(self, motion_frames_latent, model_patch, is_extend=False):
|
||||
self.motion_frames_latent = motion_frames_latent
|
||||
self.model_patch = model_patch
|
||||
self.is_extend = is_extend
|
||||
|
||||
def __call__(self, executor, *args, **kwargs):
|
||||
model_patcher = executor.class_obj.model_patcher
|
||||
model_options = executor.class_obj.model_options
|
||||
process_latent_in = model_patcher.model.process_latent_in
|
||||
|
||||
# for InfiniteTalk, model input first latent(s) need to always be replaced on every step
|
||||
if self.motion_frames_latent is not None:
|
||||
wrappers = model_options["transformer_options"]["wrappers"]
|
||||
w = wrappers.setdefault(comfy.patcher_extension.WrappersMP.APPLY_MODEL, {})
|
||||
w["MultiTalk_apply_model"] = [MultiTalkApplyModelWrapper(process_latent_in(self.motion_frames_latent))]
|
||||
|
||||
# run the sampling process
|
||||
result = executor(*args, **kwargs)
|
||||
|
||||
# insert motion frames before decoding
|
||||
if self.is_extend:
|
||||
overlap = self.motion_frames_latent.shape[2]
|
||||
result = torch.cat([self.motion_frames_latent.to(result), result[:, :, overlap:]], dim=2)
|
||||
|
||||
return result
|
||||
|
||||
def to(self, device_or_dtype):
|
||||
if isinstance(device_or_dtype, torch.device):
|
||||
if self.motion_frames_latent is not None:
|
||||
self.motion_frames_latent = self.motion_frames_latent.to(device_or_dtype)
|
||||
return self
|
||||
@@ -5,7 +5,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from comfy.ldm.modules.diffusionmodules.model import vae_attention
|
||||
from comfy.ldm.modules.diffusionmodules.model import vae_attention, torch_cat_if_needed
|
||||
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
@@ -20,22 +20,29 @@ class CausalConv3d(ops.Conv3d):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._padding = (self.padding[2], self.padding[2], self.padding[1],
|
||||
self.padding[1], 2 * self.padding[0], 0)
|
||||
self.padding = (0, 0, 0)
|
||||
self._padding = 2 * self.padding[0]
|
||||
self.padding = (0, self.padding[1], self.padding[2])
|
||||
|
||||
def forward(self, x, cache_x=None, cache_list=None, cache_idx=None):
|
||||
if cache_list is not None:
|
||||
cache_x = cache_list[cache_idx]
|
||||
cache_list[cache_idx] = None
|
||||
|
||||
padding = list(self._padding)
|
||||
if cache_x is not None and self._padding[4] > 0:
|
||||
cache_x = cache_x.to(x.device)
|
||||
x = torch.cat([cache_x, x], dim=2)
|
||||
padding[4] -= cache_x.shape[2]
|
||||
if cache_x is None and x.shape[2] == 1:
|
||||
#Fast path - the op will pad for use by truncating the weight
|
||||
#and save math on a pile of zeros.
|
||||
return super().forward(x, autopad="causal_zero")
|
||||
|
||||
if self._padding > 0:
|
||||
padding_needed = self._padding
|
||||
if cache_x is not None:
|
||||
cache_x = cache_x.to(x.device)
|
||||
padding_needed = max(0, padding_needed - cache_x.shape[2])
|
||||
padding_shape = list(x.shape)
|
||||
padding_shape[2] = padding_needed
|
||||
padding = torch.zeros(padding_shape, device=x.device, dtype=x.dtype)
|
||||
x = torch_cat_if_needed([padding, cache_x, x], dim=2)
|
||||
del cache_x
|
||||
x = F.pad(x, padding)
|
||||
|
||||
return super().forward(x)
|
||||
|
||||
@@ -227,6 +234,7 @@ class Encoder3d(nn.Module):
|
||||
def __init__(self,
|
||||
dim=128,
|
||||
z_dim=4,
|
||||
input_channels=3,
|
||||
dim_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
@@ -245,7 +253,7 @@ class Encoder3d(nn.Module):
|
||||
scale = 1.0
|
||||
|
||||
# init block
|
||||
self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
|
||||
self.conv1 = CausalConv3d(input_channels, dims[0], 3, padding=1)
|
||||
|
||||
# downsample blocks
|
||||
downsamples = []
|
||||
@@ -331,6 +339,7 @@ class Decoder3d(nn.Module):
|
||||
def __init__(self,
|
||||
dim=128,
|
||||
z_dim=4,
|
||||
output_channels=3,
|
||||
dim_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
@@ -378,7 +387,7 @@ class Decoder3d(nn.Module):
|
||||
# output blocks
|
||||
self.head = nn.Sequential(
|
||||
RMS_norm(out_dim, images=False), nn.SiLU(),
|
||||
CausalConv3d(out_dim, 3, 3, padding=1))
|
||||
CausalConv3d(out_dim, output_channels, 3, padding=1))
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
## conv1
|
||||
@@ -449,6 +458,7 @@ class WanVAE(nn.Module):
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
temperal_downsample=[True, True, False],
|
||||
image_channels=3,
|
||||
dropout=0.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
@@ -460,19 +470,21 @@ class WanVAE(nn.Module):
|
||||
self.temperal_upsample = temperal_downsample[::-1]
|
||||
|
||||
# modules
|
||||
self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
|
||||
self.encoder = Encoder3d(dim, z_dim * 2, image_channels, dim_mult, num_res_blocks,
|
||||
attn_scales, self.temperal_downsample, dropout)
|
||||
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
|
||||
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
|
||||
self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
|
||||
self.decoder = Decoder3d(dim, z_dim, image_channels, dim_mult, num_res_blocks,
|
||||
attn_scales, self.temperal_upsample, dropout)
|
||||
|
||||
def encode(self, x):
|
||||
conv_idx = [0]
|
||||
feat_map = [None] * count_conv3d(self.decoder)
|
||||
## cache
|
||||
t = x.shape[2]
|
||||
iter_ = 1 + (t - 1) // 4
|
||||
feat_map = None
|
||||
if iter_ > 1:
|
||||
feat_map = [None] * count_conv3d(self.decoder)
|
||||
## 对encode输入的x,按时间拆分为1、4、4、4....
|
||||
for i in range(iter_):
|
||||
conv_idx = [0]
|
||||
@@ -492,10 +504,11 @@ class WanVAE(nn.Module):
|
||||
|
||||
def decode(self, z):
|
||||
conv_idx = [0]
|
||||
feat_map = [None] * count_conv3d(self.decoder)
|
||||
# z: [b,c,t,h,w]
|
||||
|
||||
iter_ = z.shape[2]
|
||||
feat_map = None
|
||||
if iter_ > 1:
|
||||
feat_map = [None] * count_conv3d(self.decoder)
|
||||
x = self.conv2(z)
|
||||
for i in range(iter_):
|
||||
conv_idx = [0]
|
||||
|
||||
@@ -260,6 +260,7 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
key_map["transformer.{}".format(k[:-len(".weight")])] = to #simpletrainer and probably regular diffusers flux lora format
|
||||
key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris
|
||||
key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #onetrainer
|
||||
key_map[k[:-len(".weight")]] = to #DiffSynth lora format
|
||||
for k in sdk:
|
||||
hidden_size = model.model_config.unet_config.get("hidden_size", 0)
|
||||
if k.endswith(".weight") and ".linear1." in k:
|
||||
@@ -322,6 +323,7 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
key_map["diffusion_model.{}".format(key_lora)] = to
|
||||
key_map["transformer.{}".format(key_lora)] = to
|
||||
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to
|
||||
key_map[key_lora] = to
|
||||
|
||||
if isinstance(model, comfy.model_base.Kandinsky5):
|
||||
for k in sdk:
|
||||
|
||||
@@ -20,6 +20,7 @@ import comfy.ldm.hunyuan3dv2_1
|
||||
import comfy.ldm.hunyuan3dv2_1.hunyuandit
|
||||
import torch
|
||||
import logging
|
||||
import comfy.ldm.lightricks.av_model
|
||||
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
|
||||
from comfy.ldm.cascade.stage_c import StageC
|
||||
from comfy.ldm.cascade.stage_b import StageB
|
||||
@@ -48,6 +49,7 @@ import comfy.ldm.ace.model
|
||||
import comfy.ldm.omnigen.omnigen2
|
||||
import comfy.ldm.qwen_image.model
|
||||
import comfy.ldm.kandinsky5.model
|
||||
import comfy.ldm.anima.model
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
@@ -946,7 +948,7 @@ class GenmoMochi(BaseModel):
|
||||
|
||||
class LTXV(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lightricks.model.LTXVModel) #TODO
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lightricks.model.LTXVModel)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
@@ -977,6 +979,60 @@ class LTXV(BaseModel):
|
||||
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
|
||||
return latent_image
|
||||
|
||||
class LTXAV(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lightricks.av_model.LTXAVModel) #TODO
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
attention_mask = kwargs.get("attention_mask", None)
|
||||
if attention_mask is not None:
|
||||
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
|
||||
out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25))
|
||||
|
||||
denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
|
||||
|
||||
audio_denoise_mask = None
|
||||
if denoise_mask is not None and "latent_shapes" in kwargs:
|
||||
denoise_mask = utils.unpack_latents(denoise_mask, kwargs["latent_shapes"])
|
||||
if len(denoise_mask) > 1:
|
||||
audio_denoise_mask = denoise_mask[1]
|
||||
denoise_mask = denoise_mask[0]
|
||||
|
||||
if denoise_mask is not None:
|
||||
out["denoise_mask"] = comfy.conds.CONDRegular(denoise_mask)
|
||||
|
||||
if audio_denoise_mask is not None:
|
||||
out["audio_denoise_mask"] = comfy.conds.CONDRegular(audio_denoise_mask)
|
||||
|
||||
keyframe_idxs = kwargs.get("keyframe_idxs", None)
|
||||
if keyframe_idxs is not None:
|
||||
out['keyframe_idxs'] = comfy.conds.CONDRegular(keyframe_idxs)
|
||||
|
||||
latent_shapes = kwargs.get("latent_shapes", None)
|
||||
if latent_shapes is not None:
|
||||
out['latent_shapes'] = comfy.conds.CONDConstant(latent_shapes)
|
||||
|
||||
return out
|
||||
|
||||
def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs):
|
||||
v_timestep = timestep
|
||||
a_timestep = timestep
|
||||
|
||||
if denoise_mask is not None:
|
||||
v_timestep = self.diffusion_model.patchifier.patchify(((denoise_mask) * timestep.view([timestep.shape[0]] + [1] * (denoise_mask.ndim - 1)))[:, :1])[0]
|
||||
if audio_denoise_mask is not None:
|
||||
a_timestep = self.diffusion_model.a_patchifier.patchify(((audio_denoise_mask) * timestep.view([timestep.shape[0]] + [1] * (audio_denoise_mask.ndim - 1)))[:, :1, :, :1])[0]
|
||||
|
||||
return v_timestep, a_timestep
|
||||
|
||||
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
|
||||
return latent_image
|
||||
|
||||
class HunyuanVideo(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo)
|
||||
@@ -1092,9 +1148,31 @@ class CosmosPredict2(BaseModel):
|
||||
sigma = (sigma / (sigma + 1))
|
||||
return latent_image / (1.0 - sigma)
|
||||
|
||||
class Anima(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.anima.model.Anima)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
t5xxl_ids = kwargs.get("t5xxl_ids", None)
|
||||
t5xxl_weights = kwargs.get("t5xxl_weights", None)
|
||||
device = kwargs["device"]
|
||||
if cross_attn is not None:
|
||||
if t5xxl_ids is not None:
|
||||
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype()), t5xxl_ids.unsqueeze(0).to(device=device))
|
||||
if t5xxl_weights is not None:
|
||||
cross_attn *= t5xxl_weights.unsqueeze(0).unsqueeze(-1).to(cross_attn)
|
||||
|
||||
if cross_attn.shape[1] < 512:
|
||||
cross_attn = torch.nn.functional.pad(cross_attn, (0, 0, 0, 512 - cross_attn.shape[1]))
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
return out
|
||||
|
||||
class Lumina2(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiT)
|
||||
self.memory_usage_factor_conds = ("ref_latents",)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
@@ -1110,10 +1188,39 @@ class Lumina2(BaseModel):
|
||||
if 'num_tokens' not in out:
|
||||
out['num_tokens'] = comfy.conds.CONDConstant(cross_attn.shape[1])
|
||||
|
||||
clip_text_pooled = kwargs["pooled_output"] # Newbie
|
||||
clip_text_pooled = kwargs.get("pooled_output", None) # NewBie
|
||||
if clip_text_pooled is not None:
|
||||
out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled)
|
||||
|
||||
clip_vision_outputs = kwargs.get("clip_vision_outputs", list(map(lambda a: a.get("clip_vision_output"), kwargs.get("unclip_conditioning", [{}])))) # Z Image omni
|
||||
if clip_vision_outputs is not None and len(clip_vision_outputs) > 0:
|
||||
sigfeats = []
|
||||
for clip_vision_output in clip_vision_outputs:
|
||||
if clip_vision_output is not None:
|
||||
image_size = clip_vision_output.image_sizes[0]
|
||||
shape = clip_vision_output.last_hidden_state.shape
|
||||
sigfeats.append(clip_vision_output.last_hidden_state.reshape(shape[0], image_size[1] // 16, image_size[2] // 16, shape[-1]))
|
||||
if len(sigfeats) > 0:
|
||||
out['siglip_feats'] = comfy.conds.CONDList(sigfeats)
|
||||
|
||||
ref_latents = kwargs.get("reference_latents", None)
|
||||
if ref_latents is not None:
|
||||
latents = []
|
||||
for lat in ref_latents:
|
||||
latents.append(self.process_latent_in(lat))
|
||||
out['ref_latents'] = comfy.conds.CONDList(latents)
|
||||
|
||||
ref_contexts = kwargs.get("reference_latents_text_embeds", None)
|
||||
if ref_contexts is not None:
|
||||
out['ref_contexts'] = comfy.conds.CONDList(ref_contexts)
|
||||
|
||||
return out
|
||||
|
||||
def extra_conds_shapes(self, **kwargs):
|
||||
out = {}
|
||||
ref_latents = kwargs.get("reference_latents", None)
|
||||
if ref_latents is not None:
|
||||
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))])
|
||||
return out
|
||||
|
||||
class WAN21(BaseModel):
|
||||
@@ -1471,6 +1578,9 @@ class QwenImage(BaseModel):
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
attention_mask = kwargs.get("attention_mask", None)
|
||||
if attention_mask is not None:
|
||||
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
|
||||
@@ -237,6 +237,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
else:
|
||||
dit_config["vec_in_dim"] = None
|
||||
|
||||
dit_config["num_heads"] = dit_config["hidden_size"] // sum(dit_config["axes_dim"])
|
||||
|
||||
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
|
||||
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
|
||||
if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: #Chroma
|
||||
@@ -251,7 +253,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["image_model"] = "chroma_radiance"
|
||||
dit_config["in_channels"] = 3
|
||||
dit_config["out_channels"] = 3
|
||||
dit_config["patch_size"] = 16
|
||||
dit_config["patch_size"] = state_dict.get('{}img_in_patch.weight'.format(key_prefix)).size(dim=-1)
|
||||
dit_config["nerf_hidden_size"] = 64
|
||||
dit_config["nerf_mlp_ratio"] = 4
|
||||
dit_config["nerf_depth"] = 4
|
||||
@@ -259,8 +261,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["nerf_tile_size"] = 512
|
||||
dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
|
||||
dit_config["nerf_embedder_dtype"] = torch.float32
|
||||
if "__x0__" in state_dict_keys: # x0 pred
|
||||
dit_config["use_x0"] = True
|
||||
if "{}__x0__".format(key_prefix) in state_dict_keys: # x0 pred
|
||||
dit_config["use_x0"] = True
|
||||
else:
|
||||
dit_config["use_x0"] = False
|
||||
else:
|
||||
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
|
||||
dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys
|
||||
@@ -303,7 +307,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
|
||||
if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: #Lightricks ltxv
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "ltxv"
|
||||
dit_config["image_model"] = "ltxav" if f'{key_prefix}audio_adaln_single.linear.weight' in state_dict_keys else "ltxv"
|
||||
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
|
||||
shape = state_dict['{}transformer_blocks.0.attn2.to_k.weight'.format(key_prefix)].shape
|
||||
dit_config["attention_head_dim"] = shape[0] // 32
|
||||
@@ -428,8 +432,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["rope_theta"] = 10000.0
|
||||
dit_config["ffn_dim_multiplier"] = 4.0
|
||||
ctd_weight = state_dict.get('{}clip_text_pooled_proj.0.weight'.format(key_prefix), None)
|
||||
if ctd_weight is not None:
|
||||
if ctd_weight is not None: # NewBie
|
||||
dit_config["clip_text_dim"] = ctd_weight.shape[0]
|
||||
# NewBie also sets axes_lens = [1024, 512, 512] but it's not used in ComfyUI
|
||||
elif dit_config["dim"] == 3840: # Z image
|
||||
dit_config["n_heads"] = 30
|
||||
dit_config["n_kv_heads"] = 30
|
||||
@@ -439,8 +444,15 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["ffn_dim_multiplier"] = (8.0 / 3.0)
|
||||
dit_config["z_image_modulation"] = True
|
||||
dit_config["time_scale"] = 1000.0
|
||||
try:
|
||||
dit_config["allow_fp16"] = torch.std(state_dict['{}layers.{}.ffn_norm1.weight'.format(key_prefix, dit_config["n_layers"] - 2)], unbiased=False).item() < 0.42
|
||||
except Exception:
|
||||
pass
|
||||
if '{}cap_pad_token'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["pad_tokens_multiple"] = 32
|
||||
sig_weight = state_dict.get('{}siglip_embedder.0.weight'.format(key_prefix), None)
|
||||
if sig_weight is not None:
|
||||
dit_config["siglip_feat_dim"] = sig_weight.shape[0]
|
||||
|
||||
return dit_config
|
||||
|
||||
@@ -542,6 +554,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
if '{}blocks.0.mlp.layer1.weight'.format(key_prefix) in state_dict_keys: # Cosmos predict2
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "cosmos_predict2"
|
||||
if "{}llm_adapter.blocks.0.cross_attn.q_proj.weight".format(key_prefix) in state_dict_keys:
|
||||
dit_config["image_model"] = "anima"
|
||||
dit_config["max_img_h"] = 240
|
||||
dit_config["max_img_w"] = 240
|
||||
dit_config["max_frames"] = 128
|
||||
@@ -616,6 +630,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["image_model"] = "qwen_image"
|
||||
dit_config["in_channels"] = state_dict['{}img_in.weight'.format(key_prefix)].shape[1]
|
||||
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
|
||||
if "{}__index_timestep_zero__".format(key_prefix) in state_dict_keys: # 2511
|
||||
dit_config["default_ref_method"] = "index_timestep_zero"
|
||||
if "{}time_text_embed.addition_t_embedding.weight".format(key_prefix) in state_dict_keys: # Layered
|
||||
dit_config["use_additional_t_cond"] = True
|
||||
dit_config["default_ref_method"] = "negative_index"
|
||||
return dit_config
|
||||
|
||||
if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # Kandinsky 5
|
||||
|
||||
@@ -22,10 +22,10 @@ from enum import Enum
|
||||
from comfy.cli_args import args, PerformanceFeature
|
||||
import torch
|
||||
import sys
|
||||
import importlib
|
||||
import platform
|
||||
import weakref
|
||||
import gc
|
||||
import os
|
||||
|
||||
class VRAMState(Enum):
|
||||
DISABLED = 0 #No vram present: no need to move models to vram
|
||||
@@ -333,28 +333,42 @@ except:
|
||||
SUPPORT_FP8_OPS = args.supports_fp8_compute
|
||||
|
||||
AMD_RDNA2_AND_OLDER_ARCH = ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]
|
||||
AMD_ENABLE_MIOPEN_ENV = 'COMFYUI_ENABLE_MIOPEN'
|
||||
|
||||
try:
|
||||
if is_amd():
|
||||
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
|
||||
if not (any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH)):
|
||||
torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD
|
||||
logging.info("Set: torch.backends.cudnn.enabled = False for better AMD performance.")
|
||||
if os.getenv(AMD_ENABLE_MIOPEN_ENV) != '1':
|
||||
torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD
|
||||
logging.info("Set: torch.backends.cudnn.enabled = False for better AMD performance.")
|
||||
|
||||
try:
|
||||
rocm_version = tuple(map(int, str(torch.version.hip).split(".")[:2]))
|
||||
except:
|
||||
rocm_version = (6, -1)
|
||||
|
||||
def aotriton_supported(gpu_arch):
|
||||
path = torch.__path__[0]
|
||||
path = os.path.join(os.path.join(path, "lib"), "aotriton.images")
|
||||
gfx = set(map(lambda a: a[4:], filter(lambda a: a.startswith("amd-gfx"), os.listdir(path))))
|
||||
if gpu_arch in gfx:
|
||||
return True
|
||||
if "{}x".format(gpu_arch[:-1]) in gfx:
|
||||
return True
|
||||
if "{}xx".format(gpu_arch[:-2]) in gfx:
|
||||
return True
|
||||
return False
|
||||
|
||||
logging.info("AMD arch: {}".format(arch))
|
||||
logging.info("ROCm version: {}".format(rocm_version))
|
||||
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
||||
if importlib.util.find_spec('triton') is not None: # AMD efficient attention implementation depends on triton. TODO: better way of detecting if it's compiled in or not.
|
||||
if aotriton_supported(arch): # AMD efficient attention implementation depends on aotriton.
|
||||
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
|
||||
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
|
||||
ENABLE_PYTORCH_ATTENTION = True
|
||||
if rocm_version >= (7, 0):
|
||||
if any((a in arch) for a in ["gfx1201"]):
|
||||
if any((a in arch) for a in ["gfx1200", "gfx1201"]):
|
||||
ENABLE_PYTORCH_ATTENTION = True
|
||||
if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4):
|
||||
if any((a in arch) for a in ["gfx1200", "gfx1201", "gfx950"]): # TODO: more arches, "gfx942" gives error on pytorch nightly 2.10 1013 rocm7.0
|
||||
@@ -453,7 +467,7 @@ def module_size(module):
|
||||
sd = module.state_dict()
|
||||
for k in sd:
|
||||
t = sd[k]
|
||||
module_mem += t.nelement() * t.element_size()
|
||||
module_mem += t.nbytes
|
||||
return module_mem
|
||||
|
||||
class LoadedModel:
|
||||
@@ -1016,8 +1030,8 @@ NUM_STREAMS = 0
|
||||
if args.async_offload is not None:
|
||||
NUM_STREAMS = args.async_offload
|
||||
else:
|
||||
# Enable by default on Nvidia
|
||||
if is_nvidia():
|
||||
# Enable by default on Nvidia and AMD
|
||||
if is_nvidia() or is_amd():
|
||||
NUM_STREAMS = 2
|
||||
|
||||
if args.disable_async_offload:
|
||||
@@ -1123,6 +1137,16 @@ if not args.disable_pinned_memory:
|
||||
|
||||
PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"])
|
||||
|
||||
def discard_cuda_async_error():
|
||||
try:
|
||||
a = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
|
||||
b = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
|
||||
_ = a + b
|
||||
torch.cuda.synchronize()
|
||||
except torch.AcceleratorError:
|
||||
#Dump it! We already know about it from the synchronous return
|
||||
pass
|
||||
|
||||
def pin_memory(tensor):
|
||||
global TOTAL_PINNED_MEMORY
|
||||
if MAX_PINNED_MEMORY <= 0:
|
||||
@@ -1143,7 +1167,7 @@ def pin_memory(tensor):
|
||||
if not tensor.is_contiguous():
|
||||
return False
|
||||
|
||||
size = tensor.numel() * tensor.element_size()
|
||||
size = tensor.nbytes
|
||||
if (TOTAL_PINNED_MEMORY + size) > MAX_PINNED_MEMORY:
|
||||
return False
|
||||
|
||||
@@ -1155,6 +1179,9 @@ def pin_memory(tensor):
|
||||
PINNED_MEMORY[ptr] = size
|
||||
TOTAL_PINNED_MEMORY += size
|
||||
return True
|
||||
else:
|
||||
logging.warning("Pin error.")
|
||||
discard_cuda_async_error()
|
||||
|
||||
return False
|
||||
|
||||
@@ -1167,7 +1194,7 @@ def unpin_memory(tensor):
|
||||
return False
|
||||
|
||||
ptr = tensor.data_ptr()
|
||||
size = tensor.numel() * tensor.element_size()
|
||||
size = tensor.nbytes
|
||||
|
||||
size_stored = PINNED_MEMORY.get(ptr, None)
|
||||
if size_stored is None:
|
||||
@@ -1183,6 +1210,9 @@ def unpin_memory(tensor):
|
||||
if len(PINNED_MEMORY) == 0:
|
||||
TOTAL_PINNED_MEMORY = 0
|
||||
return True
|
||||
else:
|
||||
logging.warning("Unpin error.")
|
||||
discard_cuda_async_error()
|
||||
|
||||
return False
|
||||
|
||||
@@ -1485,6 +1515,16 @@ def supports_fp8_compute(device=None):
|
||||
|
||||
return True
|
||||
|
||||
def supports_nvfp4_compute(device=None):
|
||||
if not is_nvidia():
|
||||
return False
|
||||
|
||||
props = torch.cuda.get_device_properties(device)
|
||||
if props.major < 10:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def extended_fp16_support():
|
||||
# TODO: check why some models work with fp16 on newer torch versions but not on older
|
||||
if torch_version_numeric < (2, 7):
|
||||
@@ -1523,6 +1563,10 @@ def soft_empty_cache(force=False):
|
||||
def unload_all_models():
|
||||
free_memory(1e30, get_torch_device())
|
||||
|
||||
def debug_memory_summary():
|
||||
if is_amd() or is_nvidia():
|
||||
return torch.cuda.memory.memory_summary()
|
||||
return ""
|
||||
|
||||
#TODO: might be cleaner to put this somewhere else
|
||||
import threading
|
||||
|
||||
@@ -454,6 +454,9 @@ class ModelPatcher:
|
||||
def set_model_post_input_patch(self, patch):
|
||||
self.set_model_patch(patch, "post_input")
|
||||
|
||||
def set_model_noise_refiner_patch(self, patch):
|
||||
self.set_model_patch(patch, "noise_refiner")
|
||||
|
||||
def set_model_rope_options(self, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t, **kwargs):
|
||||
rope_options = self.model_options["transformer_options"].get("rope_options", {})
|
||||
rope_options["scale_x"] = scale_x
|
||||
@@ -715,6 +718,7 @@ class ModelPatcher:
|
||||
continue
|
||||
|
||||
cast_weight = self.force_cast_weights
|
||||
m.comfy_force_cast_weights = self.force_cast_weights
|
||||
if lowvram_weight:
|
||||
if hasattr(m, "comfy_cast_weights"):
|
||||
m.weight_function = []
|
||||
@@ -787,11 +791,12 @@ class ModelPatcher:
|
||||
for param in params:
|
||||
self.pin_weight_to_device("{}.{}".format(n, param))
|
||||
|
||||
usable_stat = "{:.2f} MB usable,".format(lowvram_model_memory / (1024 * 1024)) if lowvram_model_memory < 1e32 else ""
|
||||
if lowvram_counter > 0:
|
||||
logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), offload_buffer / (1024 * 1024), patch_counter))
|
||||
logging.info("loaded partially; {} {:.2f} MB loaded, {:.2f} MB offloaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(usable_stat, mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), offload_buffer / (1024 * 1024), patch_counter))
|
||||
self.model.model_lowvram = True
|
||||
else:
|
||||
logging.info("loaded completely; {:.2f} MB usable, {:.2f} MB loaded, full load: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
|
||||
logging.info("loaded completely; {} {:.2f} MB loaded, full load: {}".format(usable_stat, mem_counter / (1024 * 1024), full_load))
|
||||
self.model.model_lowvram = False
|
||||
if full_load:
|
||||
self.model.to(device_to)
|
||||
|
||||
191
comfy/ops.py
191
comfy/ops.py
@@ -79,7 +79,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
|
||||
if input is not None:
|
||||
if dtype is None:
|
||||
if isinstance(input, QuantizedTensor):
|
||||
dtype = input._layout_params["orig_dtype"]
|
||||
dtype = input.params.orig_dtype
|
||||
else:
|
||||
dtype = input.dtype
|
||||
if bias_dtype is None:
|
||||
@@ -203,7 +203,9 @@ class disable_weight_init:
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
def _conv_forward(self, input, weight, bias, *args, **kwargs):
|
||||
def _conv_forward(self, input, weight, bias, autopad=None, *args, **kwargs):
|
||||
if autopad == "causal_zero":
|
||||
weight = weight[:, :, -input.shape[2]:, :, :]
|
||||
if NVIDIA_MEMORY_CONV_BUG_WORKAROUND and weight.dtype in (torch.float16, torch.bfloat16):
|
||||
out = torch.cudnn_convolution(input, weight, self.padding, self.stride, self.dilation, self.groups, benchmark=False, deterministic=False, allow_tf32=True)
|
||||
if bias is not None:
|
||||
@@ -212,15 +214,15 @@ class disable_weight_init:
|
||||
else:
|
||||
return super()._conv_forward(input, weight, bias, *args, **kwargs)
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
def forward_comfy_cast_weights(self, input, autopad=None):
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
x = self._conv_forward(input, weight, bias)
|
||||
x = self._conv_forward(input, weight, bias, autopad=autopad)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0 or "autopad" in kwargs:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
@@ -412,26 +414,34 @@ def fp8_linear(self, input):
|
||||
return None
|
||||
|
||||
input_dtype = input.dtype
|
||||
input_shape = input.shape
|
||||
tensor_3d = input.ndim == 3
|
||||
|
||||
if input.ndim == 3 or input.ndim == 2:
|
||||
w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
|
||||
scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
|
||||
if tensor_3d:
|
||||
input = input.reshape(-1, input_shape[2])
|
||||
|
||||
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
|
||||
input = torch.clamp(input, min=-448, max=448, out=input)
|
||||
layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
|
||||
quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight)
|
||||
if input.ndim != 2:
|
||||
return None
|
||||
w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
|
||||
scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
|
||||
|
||||
# Wrap weight in QuantizedTensor - this enables unified dispatch
|
||||
# Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
|
||||
layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype}
|
||||
quantized_weight = QuantizedTensor(w, "TensorCoreFP8Layout", layout_params_weight)
|
||||
o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
|
||||
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
|
||||
input = torch.clamp(input, min=-448, max=448, out=input)
|
||||
input_fp8 = input.to(dtype).contiguous()
|
||||
layout_params_input = TensorCoreFP8Layout.Params(scale=scale_input, orig_dtype=input_dtype, orig_shape=tuple(input_fp8.shape))
|
||||
quantized_input = QuantizedTensor(input_fp8, "TensorCoreFP8Layout", layout_params_input)
|
||||
|
||||
uncast_bias_weight(self, w, bias, offload_stream)
|
||||
return o
|
||||
# Wrap weight in QuantizedTensor - this enables unified dispatch
|
||||
# Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
|
||||
layout_params_weight = TensorCoreFP8Layout.Params(scale=scale_weight, orig_dtype=input_dtype, orig_shape=tuple(w.shape))
|
||||
quantized_weight = QuantizedTensor(w, "TensorCoreFP8Layout", layout_params_weight)
|
||||
o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
|
||||
|
||||
return None
|
||||
uncast_bias_weight(self, w, bias, offload_stream)
|
||||
if tensor_3d:
|
||||
o = o.reshape((input_shape[0], input_shape[1], w.shape[0]))
|
||||
|
||||
return o
|
||||
|
||||
class fp8_ops(manual_cast):
|
||||
class Linear(manual_cast.Linear):
|
||||
@@ -477,14 +487,20 @@ if CUBLAS_IS_AVAILABLE:
|
||||
# ==============================================================================
|
||||
# Mixed Precision Operations
|
||||
# ==============================================================================
|
||||
from .quant_ops import QuantizedTensor, QUANT_ALGOS
|
||||
from .quant_ops import (
|
||||
QuantizedTensor,
|
||||
QUANT_ALGOS,
|
||||
TensorCoreFP8Layout,
|
||||
get_layout_class,
|
||||
)
|
||||
|
||||
|
||||
def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False):
|
||||
def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False, disabled=[]):
|
||||
class MixedPrecisionOps(manual_cast):
|
||||
_quant_config = quant_config
|
||||
_compute_dtype = compute_dtype
|
||||
_full_precision_mm = full_precision_mm
|
||||
_disabled = disabled
|
||||
|
||||
class Linear(torch.nn.Module, CastWeightBiasOp):
|
||||
def __init__(
|
||||
@@ -509,10 +525,21 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
|
||||
self.tensor_class = None
|
||||
self._full_precision_mm = MixedPrecisionOps._full_precision_mm
|
||||
self._full_precision_mm_config = False
|
||||
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
def _load_scale_param(self, state_dict, prefix, param_name, device, manually_loaded_keys, dtype=None):
|
||||
key = f"{prefix}{param_name}"
|
||||
value = state_dict.pop(key, None)
|
||||
if value is not None:
|
||||
value = value.to(device=device)
|
||||
if dtype is not None:
|
||||
value = value.view(dtype=dtype)
|
||||
manually_loaded_keys.append(key)
|
||||
return value
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
|
||||
strict, missing_keys, unexpected_keys, error_msgs):
|
||||
|
||||
@@ -521,7 +548,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
weight_key = f"{prefix}weight"
|
||||
weight = state_dict.pop(weight_key, None)
|
||||
if weight is None:
|
||||
raise ValueError(f"Missing weight for layer {layer_name}")
|
||||
logging.warning(f"Missing weight for layer {layer_name}")
|
||||
return
|
||||
|
||||
manually_loaded_keys = [weight_key]
|
||||
|
||||
@@ -533,34 +561,58 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
|
||||
else:
|
||||
self.quant_format = layer_conf.get("format", None)
|
||||
self._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False)
|
||||
if not self._full_precision_mm:
|
||||
self._full_precision_mm = layer_conf.get("full_precision_matrix_mult", False)
|
||||
self._full_precision_mm = self._full_precision_mm_config
|
||||
|
||||
if self.quant_format in MixedPrecisionOps._disabled:
|
||||
self._full_precision_mm = True
|
||||
|
||||
if self.quant_format is None:
|
||||
raise ValueError(f"Unknown quantization format for layer {layer_name}")
|
||||
|
||||
qconfig = QUANT_ALGOS[self.quant_format]
|
||||
self.layout_type = qconfig["comfy_tensor_layout"]
|
||||
layout_cls = get_layout_class(self.layout_type)
|
||||
|
||||
weight_scale_key = f"{prefix}weight_scale"
|
||||
scale = state_dict.pop(weight_scale_key, None)
|
||||
if scale is not None:
|
||||
scale = scale.to(device)
|
||||
layout_params = {
|
||||
'scale': scale,
|
||||
'orig_dtype': MixedPrecisionOps._compute_dtype,
|
||||
'block_size': qconfig.get("group_size", None),
|
||||
}
|
||||
# Load format-specific parameters
|
||||
if self.quant_format in ["float8_e4m3fn", "float8_e5m2"]:
|
||||
# FP8: single tensor scale
|
||||
scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys)
|
||||
|
||||
if scale is not None:
|
||||
manually_loaded_keys.append(weight_scale_key)
|
||||
params = layout_cls.Params(
|
||||
scale=scale,
|
||||
orig_dtype=MixedPrecisionOps._compute_dtype,
|
||||
orig_shape=(self.out_features, self.in_features),
|
||||
)
|
||||
|
||||
elif self.quant_format == "nvfp4":
|
||||
# NVFP4: tensor_scale (weight_scale_2) + block_scale (weight_scale)
|
||||
tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys)
|
||||
block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys,
|
||||
dtype=torch.float8_e4m3fn)
|
||||
|
||||
if tensor_scale is None or block_scale is None:
|
||||
raise ValueError(f"Missing NVFP4 scales for layer {layer_name}")
|
||||
|
||||
params = layout_cls.Params(
|
||||
scale=tensor_scale,
|
||||
block_scale=block_scale,
|
||||
orig_dtype=MixedPrecisionOps._compute_dtype,
|
||||
orig_shape=(self.out_features, self.in_features),
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported quantization format: {self.quant_format}")
|
||||
|
||||
self.weight = torch.nn.Parameter(
|
||||
QuantizedTensor(weight.to(device=device, dtype=qconfig.get("storage_t", None)), self.layout_type, layout_params),
|
||||
QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), self.layout_type, params),
|
||||
requires_grad=False
|
||||
)
|
||||
|
||||
for param_name in qconfig["parameters"]:
|
||||
if param_name in {"weight_scale", "weight_scale_2"}:
|
||||
continue # Already handled above
|
||||
|
||||
param_key = f"{prefix}{param_name}"
|
||||
_v = state_dict.pop(param_key, None)
|
||||
if _v is None:
|
||||
@@ -575,13 +627,29 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
missing_keys.remove(key)
|
||||
|
||||
def state_dict(self, *args, destination=None, prefix="", **kwargs):
|
||||
sd = super().state_dict(*args, destination=destination, prefix=prefix, **kwargs)
|
||||
if destination is not None:
|
||||
sd = destination
|
||||
else:
|
||||
sd = {}
|
||||
|
||||
if self.bias is not None:
|
||||
sd["{}bias".format(prefix)] = self.bias
|
||||
|
||||
if isinstance(self.weight, QuantizedTensor):
|
||||
sd["{}weight_scale".format(prefix)] = self.weight._layout_params['scale']
|
||||
sd_out = self.weight.state_dict("{}weight".format(prefix))
|
||||
for k in sd_out:
|
||||
sd[k] = sd_out[k]
|
||||
|
||||
quant_conf = {"format": self.quant_format}
|
||||
if self._full_precision_mm:
|
||||
if self._full_precision_mm_config:
|
||||
quant_conf["full_precision_matrix_mult"] = True
|
||||
sd["{}comfy_quant".format(prefix)] = torch.frombuffer(json.dumps(quant_conf).encode('utf-8'), dtype=torch.uint8)
|
||||
sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
|
||||
|
||||
input_scale = getattr(self, 'input_scale', None)
|
||||
if input_scale is not None:
|
||||
sd["{}input_scale".format(prefix)] = input_scale
|
||||
else:
|
||||
sd["{}weight".format(prefix)] = self.weight
|
||||
return sd
|
||||
|
||||
def _forward(self, input, weight, bias):
|
||||
@@ -596,12 +664,33 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
def forward(self, input, *args, **kwargs):
|
||||
run_every_op()
|
||||
|
||||
if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
return self.forward_comfy_cast_weights(input, *args, **kwargs)
|
||||
input_shape = input.shape
|
||||
reshaped_3d = False
|
||||
|
||||
if (getattr(self, 'layout_type', None) is not None and
|
||||
not isinstance(input, QuantizedTensor)):
|
||||
input = QuantizedTensor.from_float(input, self.layout_type, scale=getattr(self, 'input_scale', None), dtype=self.weight.dtype)
|
||||
return self._forward(input, self.weight, self.bias)
|
||||
not isinstance(input, QuantizedTensor) and not self._full_precision_mm and
|
||||
not getattr(self, 'comfy_force_cast_weights', False) and
|
||||
len(self.weight_function) == 0 and len(self.bias_function) == 0):
|
||||
|
||||
# Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others)
|
||||
input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input
|
||||
|
||||
# Fall back to non-quantized for non-2D tensors
|
||||
if input_reshaped.ndim == 2:
|
||||
reshaped_3d = input.ndim == 3
|
||||
# dtype is now implicit in the layout class
|
||||
scale = getattr(self, 'input_scale', None)
|
||||
if scale is not None:
|
||||
scale = comfy.model_management.cast_to_device(scale, input.device, None)
|
||||
input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale)
|
||||
|
||||
output = self.forward_comfy_cast_weights(input)
|
||||
|
||||
# Reshape output back to 3D if input was 3D
|
||||
if reshaped_3d:
|
||||
output = output.reshape((input_shape[0], input_shape[1], self.weight.shape[0]))
|
||||
|
||||
return output
|
||||
|
||||
def convert_weight(self, weight, inplace=False, **kwargs):
|
||||
if isinstance(weight, QuantizedTensor):
|
||||
@@ -611,7 +700,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
|
||||
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
|
||||
if getattr(self, 'layout_type', None) is not None:
|
||||
weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True)
|
||||
# dtype is now implicit in the layout class
|
||||
weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", stochastic_rounding=seed, inplace_ops=True).to(self.weight.dtype)
|
||||
else:
|
||||
weight = weight.to(self.weight.dtype)
|
||||
if return_weight:
|
||||
@@ -638,10 +728,17 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
|
||||
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
|
||||
fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular
|
||||
nvfp4_compute = comfy.model_management.supports_nvfp4_compute(load_device)
|
||||
|
||||
if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config:
|
||||
logging.info("Using mixed precision operations")
|
||||
return mixed_precision_ops(model_config.quant_config, compute_dtype, full_precision_mm=not fp8_compute)
|
||||
disabled = set()
|
||||
if not nvfp4_compute:
|
||||
disabled.add("nvfp4")
|
||||
if not fp8_compute:
|
||||
disabled.add("float8_e4m3fn")
|
||||
disabled.add("float8_e5m2")
|
||||
return mixed_precision_ops(model_config.quant_config, compute_dtype, disabled=disabled)
|
||||
|
||||
if (
|
||||
fp8_compute and
|
||||
|
||||
@@ -1,580 +1,174 @@
|
||||
import torch
|
||||
import logging
|
||||
from typing import Tuple, Dict
|
||||
|
||||
try:
|
||||
import comfy_kitchen as ck
|
||||
from comfy_kitchen.tensor import (
|
||||
QuantizedTensor,
|
||||
QuantizedLayout,
|
||||
TensorCoreFP8Layout as _CKFp8Layout,
|
||||
TensorCoreNVFP4Layout as _CKNvfp4Layout,
|
||||
register_layout_op,
|
||||
register_layout_class,
|
||||
get_layout_class,
|
||||
)
|
||||
_CK_AVAILABLE = True
|
||||
if torch.version.cuda is None:
|
||||
ck.registry.disable("cuda")
|
||||
else:
|
||||
cuda_version = tuple(map(int, str(torch.version.cuda).split('.')))
|
||||
if cuda_version < (13,):
|
||||
ck.registry.disable("cuda")
|
||||
logging.warning("WARNING: You need pytorch with cu130 or higher to use optimized CUDA operations.")
|
||||
|
||||
ck.registry.disable("triton")
|
||||
for k, v in ck.list_backends().items():
|
||||
logging.info(f"Found comfy_kitchen backend {k}: {v}")
|
||||
except ImportError as e:
|
||||
logging.error(f"Failed to import comfy_kitchen, Error: {e}, fp8 and fp4 support will not be available.")
|
||||
_CK_AVAILABLE = False
|
||||
|
||||
class QuantizedTensor:
|
||||
pass
|
||||
|
||||
class _CKFp8Layout:
|
||||
pass
|
||||
|
||||
class _CKNvfp4Layout:
|
||||
pass
|
||||
|
||||
def register_layout_class(name, cls):
|
||||
pass
|
||||
|
||||
def get_layout_class(name):
|
||||
return None
|
||||
|
||||
import comfy.float
|
||||
|
||||
_LAYOUT_REGISTRY = {}
|
||||
_GENERIC_UTILS = {}
|
||||
|
||||
|
||||
def register_layout_op(torch_op, layout_type):
|
||||
"""
|
||||
Decorator to register a layout-specific operation handler.
|
||||
Args:
|
||||
torch_op: PyTorch operation (e.g., torch.ops.aten.linear.default)
|
||||
layout_type: Layout class (e.g., TensorCoreFP8Layout)
|
||||
Example:
|
||||
@register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout)
|
||||
def fp8_linear(func, args, kwargs):
|
||||
# FP8-specific linear implementation
|
||||
...
|
||||
"""
|
||||
def decorator(handler_func):
|
||||
if torch_op not in _LAYOUT_REGISTRY:
|
||||
_LAYOUT_REGISTRY[torch_op] = {}
|
||||
_LAYOUT_REGISTRY[torch_op][layout_type] = handler_func
|
||||
return handler_func
|
||||
return decorator
|
||||
|
||||
|
||||
def register_generic_util(torch_op):
|
||||
"""
|
||||
Decorator to register a generic utility that works for all layouts.
|
||||
Args:
|
||||
torch_op: PyTorch operation (e.g., torch.ops.aten.detach.default)
|
||||
|
||||
Example:
|
||||
@register_generic_util(torch.ops.aten.detach.default)
|
||||
def generic_detach(func, args, kwargs):
|
||||
# Works for any layout
|
||||
...
|
||||
"""
|
||||
def decorator(handler_func):
|
||||
_GENERIC_UTILS[torch_op] = handler_func
|
||||
return handler_func
|
||||
return decorator
|
||||
|
||||
|
||||
def _get_layout_from_args(args):
|
||||
for arg in args:
|
||||
if isinstance(arg, QuantizedTensor):
|
||||
return arg._layout_type
|
||||
elif isinstance(arg, (list, tuple)):
|
||||
for item in arg:
|
||||
if isinstance(item, QuantizedTensor):
|
||||
return item._layout_type
|
||||
return None
|
||||
|
||||
|
||||
def _move_layout_params_to_device(params, device):
|
||||
new_params = {}
|
||||
for k, v in params.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
new_params[k] = v.to(device=device)
|
||||
else:
|
||||
new_params[k] = v
|
||||
return new_params
|
||||
|
||||
|
||||
def _copy_layout_params(params):
|
||||
new_params = {}
|
||||
for k, v in params.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
new_params[k] = v.clone()
|
||||
else:
|
||||
new_params[k] = v
|
||||
return new_params
|
||||
|
||||
def _copy_layout_params_inplace(src, dst, non_blocking=False):
|
||||
for k, v in src.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
dst[k].copy_(v, non_blocking=non_blocking)
|
||||
else:
|
||||
dst[k] = v
|
||||
|
||||
class QuantizedLayout:
|
||||
"""
|
||||
Base class for quantization layouts.
|
||||
|
||||
A layout encapsulates the format-specific logic for quantization/dequantization
|
||||
and provides a uniform interface for extracting raw tensors needed for computation.
|
||||
|
||||
New quantization formats should subclass this and implement the required methods.
|
||||
"""
|
||||
@classmethod
|
||||
def quantize(cls, tensor, **kwargs) -> Tuple[torch.Tensor, Dict]:
|
||||
raise NotImplementedError(f"{cls.__name__} must implement quantize()")
|
||||
|
||||
@staticmethod
|
||||
def dequantize(qdata, **layout_params) -> torch.Tensor:
|
||||
raise NotImplementedError("TensorLayout must implement dequantize()")
|
||||
|
||||
@classmethod
|
||||
def get_plain_tensors(cls, qtensor) -> torch.Tensor:
|
||||
raise NotImplementedError(f"{cls.__name__} must implement get_plain_tensors()")
|
||||
|
||||
|
||||
class QuantizedTensor(torch.Tensor):
|
||||
"""
|
||||
Universal quantized tensor that works with any layout.
|
||||
|
||||
This tensor subclass uses a pluggable layout system to support multiple
|
||||
quantization formats (FP8, INT4, INT8, etc.) without code duplication.
|
||||
|
||||
The layout_type determines format-specific behavior, while common operations
|
||||
(detach, clone, to) are handled generically.
|
||||
|
||||
Attributes:
|
||||
_qdata: The quantized tensor data
|
||||
_layout_type: Layout class (e.g., TensorCoreFP8Layout)
|
||||
_layout_params: Dict with layout-specific params (scale, zero_point, etc.)
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def __new__(cls, qdata, layout_type, layout_params):
|
||||
"""
|
||||
Create a quantized tensor.
|
||||
|
||||
Args:
|
||||
qdata: The quantized data tensor
|
||||
layout_type: Layout class (subclass of QuantizedLayout)
|
||||
layout_params: Dict with layout-specific parameters
|
||||
"""
|
||||
return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False)
|
||||
|
||||
def __init__(self, qdata, layout_type, layout_params):
|
||||
self._qdata = qdata
|
||||
self._layout_type = layout_type
|
||||
self._layout_params = layout_params
|
||||
|
||||
def __repr__(self):
|
||||
layout_name = self._layout_type
|
||||
param_str = ", ".join(f"{k}={v}" for k, v in list(self._layout_params.items())[:2])
|
||||
return f"QuantizedTensor(shape={self.shape}, layout={layout_name}, {param_str})"
|
||||
|
||||
@property
|
||||
def layout_type(self):
|
||||
return self._layout_type
|
||||
|
||||
def __tensor_flatten__(self):
|
||||
"""
|
||||
Tensor flattening protocol for proper device movement.
|
||||
"""
|
||||
inner_tensors = ["_qdata"]
|
||||
ctx = {
|
||||
"layout_type": self._layout_type,
|
||||
}
|
||||
|
||||
tensor_params = {}
|
||||
non_tensor_params = {}
|
||||
for k, v in self._layout_params.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
tensor_params[k] = v
|
||||
else:
|
||||
non_tensor_params[k] = v
|
||||
|
||||
ctx["tensor_param_keys"] = list(tensor_params.keys())
|
||||
ctx["non_tensor_params"] = non_tensor_params
|
||||
|
||||
for k, v in tensor_params.items():
|
||||
attr_name = f"_layout_param_{k}"
|
||||
object.__setattr__(self, attr_name, v)
|
||||
inner_tensors.append(attr_name)
|
||||
|
||||
return inner_tensors, ctx
|
||||
|
||||
@staticmethod
|
||||
def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
|
||||
"""
|
||||
Tensor unflattening protocol for proper device movement.
|
||||
Reconstructs the QuantizedTensor after device movement.
|
||||
"""
|
||||
layout_type = ctx["layout_type"]
|
||||
layout_params = dict(ctx["non_tensor_params"])
|
||||
|
||||
for key in ctx["tensor_param_keys"]:
|
||||
attr_name = f"_layout_param_{key}"
|
||||
layout_params[key] = inner_tensors[attr_name]
|
||||
|
||||
return QuantizedTensor(inner_tensors["_qdata"], layout_type, layout_params)
|
||||
|
||||
@classmethod
|
||||
def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor':
|
||||
qdata, layout_params = LAYOUTS[layout_type].quantize(tensor, **quantize_kwargs)
|
||||
return cls(qdata, layout_type, layout_params)
|
||||
|
||||
def dequantize(self) -> torch.Tensor:
|
||||
return LAYOUTS[self._layout_type].dequantize(self._qdata, **self._layout_params)
|
||||
|
||||
@classmethod
|
||||
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
|
||||
kwargs = kwargs or {}
|
||||
|
||||
# Step 1: Check generic utilities first (detach, clone, to, etc.)
|
||||
if func in _GENERIC_UTILS:
|
||||
return _GENERIC_UTILS[func](func, args, kwargs)
|
||||
|
||||
# Step 2: Check layout-specific handlers (linear, matmul, etc.)
|
||||
layout_type = _get_layout_from_args(args)
|
||||
if layout_type and func in _LAYOUT_REGISTRY:
|
||||
handler = _LAYOUT_REGISTRY[func].get(layout_type)
|
||||
if handler:
|
||||
return handler(func, args, kwargs)
|
||||
|
||||
# Step 3: Fallback to dequantization
|
||||
if isinstance(args[0] if args else None, QuantizedTensor):
|
||||
logging.info(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. kwargs={kwargs}")
|
||||
return cls._dequant_and_fallback(func, args, kwargs)
|
||||
|
||||
@classmethod
|
||||
def _dequant_and_fallback(cls, func, args, kwargs):
|
||||
def dequant_arg(arg):
|
||||
if isinstance(arg, QuantizedTensor):
|
||||
return arg.dequantize()
|
||||
elif isinstance(arg, (list, tuple)):
|
||||
return type(arg)(dequant_arg(a) for a in arg)
|
||||
return arg
|
||||
|
||||
new_args = dequant_arg(args)
|
||||
new_kwargs = dequant_arg(kwargs)
|
||||
return func(*new_args, **new_kwargs)
|
||||
|
||||
def data_ptr(self):
|
||||
return self._qdata.data_ptr()
|
||||
|
||||
def is_pinned(self):
|
||||
return self._qdata.is_pinned()
|
||||
|
||||
def is_contiguous(self, *arg, **kwargs):
|
||||
return self._qdata.is_contiguous(*arg, **kwargs)
|
||||
|
||||
def storage(self):
|
||||
return self._qdata.storage()
|
||||
|
||||
# ==============================================================================
|
||||
# Generic Utilities (Layout-Agnostic Operations)
|
||||
# FP8 Layouts with Comfy-Specific Extensions
|
||||
# ==============================================================================
|
||||
|
||||
def _create_transformed_qtensor(qt, transform_fn):
|
||||
new_data = transform_fn(qt._qdata)
|
||||
new_params = _copy_layout_params(qt._layout_params)
|
||||
return QuantizedTensor(new_data, qt._layout_type, new_params)
|
||||
class _TensorCoreFP8LayoutBase(_CKFp8Layout):
|
||||
FP8_DTYPE = None # Must be overridden in subclass
|
||||
|
||||
|
||||
def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=None, op_name="to"):
|
||||
if target_layout is not None and target_layout != torch.strided:
|
||||
logging.warning(
|
||||
f"QuantizedTensor: layout change requested to {target_layout}, "
|
||||
f"but not supported. Ignoring layout."
|
||||
)
|
||||
|
||||
# Handle device transfer
|
||||
current_device = qt._qdata.device
|
||||
if target_device is not None:
|
||||
# Normalize device for comparison
|
||||
if isinstance(target_device, str):
|
||||
target_device = torch.device(target_device)
|
||||
if isinstance(current_device, str):
|
||||
current_device = torch.device(current_device)
|
||||
|
||||
if target_device != current_device:
|
||||
logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}")
|
||||
new_q_data = qt._qdata.to(device=target_device)
|
||||
new_params = _move_layout_params_to_device(qt._layout_params, target_device)
|
||||
if target_dtype is not None:
|
||||
new_params["orig_dtype"] = target_dtype
|
||||
new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params)
|
||||
logging.debug(f"QuantizedTensor.{op_name}: Created new tensor on {target_device}")
|
||||
return new_qt
|
||||
|
||||
logging.debug(f"QuantizedTensor.{op_name}: No device change needed, returning original")
|
||||
return qt
|
||||
|
||||
|
||||
@register_generic_util(torch.ops.aten.detach.default)
|
||||
def generic_detach(func, args, kwargs):
|
||||
"""Detach operation - creates a detached copy of the quantized tensor."""
|
||||
qt = args[0]
|
||||
if isinstance(qt, QuantizedTensor):
|
||||
return _create_transformed_qtensor(qt, lambda x: x.detach())
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
@register_generic_util(torch.ops.aten.clone.default)
|
||||
def generic_clone(func, args, kwargs):
|
||||
"""Clone operation - creates a deep copy of the quantized tensor."""
|
||||
qt = args[0]
|
||||
if isinstance(qt, QuantizedTensor):
|
||||
return _create_transformed_qtensor(qt, lambda x: x.clone())
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
@register_generic_util(torch.ops.aten._to_copy.default)
|
||||
def generic_to_copy(func, args, kwargs):
|
||||
"""Device/dtype transfer operation - handles .to(device) calls."""
|
||||
qt = args[0]
|
||||
if isinstance(qt, QuantizedTensor):
|
||||
return _handle_device_transfer(
|
||||
qt,
|
||||
target_device=kwargs.get('device', None),
|
||||
target_dtype=kwargs.get('dtype', None),
|
||||
op_name="_to_copy"
|
||||
)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
@register_generic_util(torch.ops.aten.to.dtype_layout)
|
||||
def generic_to_dtype_layout(func, args, kwargs):
|
||||
"""Handle .to(device) calls using the dtype_layout variant."""
|
||||
qt = args[0]
|
||||
if isinstance(qt, QuantizedTensor):
|
||||
return _handle_device_transfer(
|
||||
qt,
|
||||
target_device=kwargs.get('device', None),
|
||||
target_dtype=kwargs.get('dtype', None),
|
||||
target_layout=kwargs.get('layout', None),
|
||||
op_name="to"
|
||||
)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
@register_generic_util(torch.ops.aten.copy_.default)
|
||||
def generic_copy_(func, args, kwargs):
|
||||
qt_dest = args[0]
|
||||
src = args[1]
|
||||
non_blocking = args[2] if len(args) > 2 else False
|
||||
if isinstance(qt_dest, QuantizedTensor):
|
||||
if isinstance(src, QuantizedTensor):
|
||||
# Copy from another quantized tensor
|
||||
qt_dest._qdata.copy_(src._qdata, non_blocking=non_blocking)
|
||||
qt_dest._layout_type = src._layout_type
|
||||
orig_dtype = qt_dest._layout_params["orig_dtype"]
|
||||
_copy_layout_params_inplace(src._layout_params, qt_dest._layout_params, non_blocking=non_blocking)
|
||||
qt_dest._layout_params["orig_dtype"] = orig_dtype
|
||||
else:
|
||||
# Copy from regular tensor - just copy raw data
|
||||
qt_dest._qdata.copy_(src)
|
||||
return qt_dest
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
@register_generic_util(torch.ops.aten.to.dtype)
|
||||
def generic_to_dtype(func, args, kwargs):
|
||||
"""Handle .to(dtype) calls - dtype conversion only."""
|
||||
src = args[0]
|
||||
if isinstance(src, QuantizedTensor):
|
||||
# For dtype-only conversion, just change the orig_dtype, no real cast is needed
|
||||
target_dtype = args[1] if len(args) > 1 else kwargs.get('dtype')
|
||||
src._layout_params["orig_dtype"] = target_dtype
|
||||
return src
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
@register_generic_util(torch.ops.aten._has_compatible_shallow_copy_type.default)
|
||||
def generic_has_compatible_shallow_copy_type(func, args, kwargs):
|
||||
return True
|
||||
|
||||
|
||||
@register_generic_util(torch.ops.aten.empty_like.default)
|
||||
def generic_empty_like(func, args, kwargs):
|
||||
"""Empty_like operation - creates an empty tensor with the same quantized structure."""
|
||||
qt = args[0]
|
||||
if isinstance(qt, QuantizedTensor):
|
||||
# Create empty tensor with same shape and dtype as the quantized data
|
||||
hp_dtype = kwargs.pop('dtype', qt._layout_params["orig_dtype"])
|
||||
new_qdata = torch.empty_like(qt._qdata, **kwargs)
|
||||
|
||||
# Handle device transfer for layout params
|
||||
target_device = kwargs.get('device', new_qdata.device)
|
||||
new_params = _move_layout_params_to_device(qt._layout_params, target_device)
|
||||
|
||||
# Update orig_dtype if dtype is specified
|
||||
new_params['orig_dtype'] = hp_dtype
|
||||
|
||||
return QuantizedTensor(new_qdata, qt._layout_type, new_params)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
# ==============================================================================
|
||||
# FP8 Layout + Operation Handlers
|
||||
# ==============================================================================
|
||||
class TensorCoreFP8Layout(QuantizedLayout):
|
||||
"""
|
||||
Storage format:
|
||||
- qdata: FP8 tensor (torch.float8_e4m3fn or torch.float8_e5m2)
|
||||
- scale: Scalar tensor (float32) for dequantization
|
||||
- orig_dtype: Original dtype before quantization (for casting back)
|
||||
"""
|
||||
@classmethod
|
||||
def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn, stochastic_rounding=0, inplace_ops=False):
|
||||
def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
|
||||
if cls.FP8_DTYPE is None:
|
||||
raise NotImplementedError(f"{cls.__name__} must define FP8_DTYPE")
|
||||
|
||||
orig_dtype = tensor.dtype
|
||||
orig_shape = tuple(tensor.shape)
|
||||
|
||||
if isinstance(scale, str) and scale == "recalculate":
|
||||
scale = torch.amax(tensor.abs()).to(dtype=torch.float32) / torch.finfo(dtype).max
|
||||
scale = torch.amax(tensor.abs()).to(dtype=torch.float32) / torch.finfo(cls.FP8_DTYPE).max
|
||||
if tensor.dtype not in [torch.float32, torch.bfloat16]: # Prevent scale from being too small
|
||||
tensor_info = torch.finfo(tensor.dtype)
|
||||
scale = (1.0 / torch.clamp((1.0 / scale), min=tensor_info.min, max=tensor_info.max))
|
||||
|
||||
if scale is not None:
|
||||
if not isinstance(scale, torch.Tensor):
|
||||
scale = torch.tensor(scale)
|
||||
scale = scale.to(device=tensor.device, dtype=torch.float32)
|
||||
if scale is None:
|
||||
scale = torch.ones((), device=tensor.device, dtype=torch.float32)
|
||||
if not isinstance(scale, torch.Tensor):
|
||||
scale = torch.tensor(scale, device=tensor.device, dtype=torch.float32)
|
||||
|
||||
if stochastic_rounding > 0:
|
||||
if inplace_ops:
|
||||
tensor *= (1.0 / scale).to(tensor.dtype)
|
||||
else:
|
||||
tensor = tensor * (1.0 / scale).to(tensor.dtype)
|
||||
qdata = comfy.float.stochastic_rounding(tensor, dtype=cls.FP8_DTYPE, seed=stochastic_rounding)
|
||||
else:
|
||||
scale = torch.ones((), device=tensor.device, dtype=torch.float32)
|
||||
qdata = ck.quantize_per_tensor_fp8(tensor, scale, cls.FP8_DTYPE)
|
||||
|
||||
params = cls.Params(scale=scale.float(), orig_dtype=orig_dtype, orig_shape=orig_shape)
|
||||
return qdata, params
|
||||
|
||||
|
||||
class TensorCoreNVFP4Layout(_CKNvfp4Layout):
|
||||
@classmethod
|
||||
def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
|
||||
if tensor.dim() != 2:
|
||||
raise ValueError(f"NVFP4 requires 2D tensor, got {tensor.dim()}D")
|
||||
|
||||
orig_dtype = tensor.dtype
|
||||
orig_shape = tuple(tensor.shape)
|
||||
|
||||
if scale is None or (isinstance(scale, str) and scale == "recalculate"):
|
||||
scale = torch.amax(tensor.abs()) / (ck.float_utils.F8_E4M3_MAX * ck.float_utils.F4_E2M1_MAX)
|
||||
|
||||
if not isinstance(scale, torch.Tensor):
|
||||
scale = torch.tensor(scale)
|
||||
scale = scale.to(device=tensor.device, dtype=torch.float32)
|
||||
|
||||
padded_shape = cls.get_padded_shape(orig_shape)
|
||||
needs_padding = padded_shape != orig_shape
|
||||
|
||||
if stochastic_rounding > 0:
|
||||
tensor = comfy.float.stochastic_rounding(tensor, dtype=dtype, seed=stochastic_rounding)
|
||||
qdata, block_scale = comfy.float.stochastic_round_quantize_nvfp4_by_block(tensor, scale, pad_16x=needs_padding, seed=stochastic_rounding)
|
||||
else:
|
||||
lp_amax = torch.finfo(dtype).max
|
||||
torch.clamp(tensor, min=-lp_amax, max=lp_amax, out=tensor)
|
||||
tensor = tensor.to(dtype, memory_format=torch.contiguous_format)
|
||||
qdata, block_scale = ck.quantize_nvfp4(tensor, scale, pad_16x=needs_padding)
|
||||
|
||||
layout_params = {
|
||||
'scale': scale,
|
||||
'orig_dtype': orig_dtype
|
||||
}
|
||||
return tensor, layout_params
|
||||
params = cls.Params(
|
||||
scale=scale,
|
||||
orig_dtype=orig_dtype,
|
||||
orig_shape=orig_shape,
|
||||
block_scale=block_scale,
|
||||
)
|
||||
return qdata, params
|
||||
|
||||
@staticmethod
|
||||
def dequantize(qdata, scale, orig_dtype, **kwargs):
|
||||
plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype)
|
||||
plain_tensor.mul_(scale)
|
||||
return plain_tensor
|
||||
|
||||
@classmethod
|
||||
def get_plain_tensors(cls, qtensor):
|
||||
return qtensor._qdata, qtensor._layout_params['scale']
|
||||
class TensorCoreFP8E4M3Layout(_TensorCoreFP8LayoutBase):
|
||||
FP8_DTYPE = torch.float8_e4m3fn
|
||||
|
||||
|
||||
class TensorCoreFP8E5M2Layout(_TensorCoreFP8LayoutBase):
|
||||
FP8_DTYPE = torch.float8_e5m2
|
||||
|
||||
|
||||
# Backward compatibility alias - default to E4M3
|
||||
TensorCoreFP8Layout = TensorCoreFP8E4M3Layout
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# Registry
|
||||
# ==============================================================================
|
||||
|
||||
register_layout_class("TensorCoreFP8Layout", TensorCoreFP8Layout)
|
||||
register_layout_class("TensorCoreFP8E4M3Layout", TensorCoreFP8E4M3Layout)
|
||||
register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout)
|
||||
register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout)
|
||||
|
||||
QUANT_ALGOS = {
|
||||
"float8_e4m3fn": {
|
||||
"storage_t": torch.float8_e4m3fn,
|
||||
"parameters": {"weight_scale", "input_scale"},
|
||||
"comfy_tensor_layout": "TensorCoreFP8Layout",
|
||||
"comfy_tensor_layout": "TensorCoreFP8E4M3Layout",
|
||||
},
|
||||
"float8_e5m2": {
|
||||
"storage_t": torch.float8_e5m2,
|
||||
"parameters": {"weight_scale", "input_scale"},
|
||||
"comfy_tensor_layout": "TensorCoreFP8E5M2Layout",
|
||||
},
|
||||
"nvfp4": {
|
||||
"storage_t": torch.uint8,
|
||||
"parameters": {"weight_scale", "weight_scale_2", "input_scale"},
|
||||
"comfy_tensor_layout": "TensorCoreNVFP4Layout",
|
||||
"group_size": 16,
|
||||
},
|
||||
}
|
||||
|
||||
LAYOUTS = {
|
||||
"TensorCoreFP8Layout": TensorCoreFP8Layout,
|
||||
}
|
||||
|
||||
# ==============================================================================
|
||||
# Re-exports for backward compatibility
|
||||
# ==============================================================================
|
||||
|
||||
@register_layout_op(torch.ops.aten.linear.default, "TensorCoreFP8Layout")
|
||||
def fp8_linear(func, args, kwargs):
|
||||
input_tensor = args[0]
|
||||
weight = args[1]
|
||||
bias = args[2] if len(args) > 2 else None
|
||||
|
||||
if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
|
||||
plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
|
||||
plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight)
|
||||
|
||||
out_dtype = kwargs.get("out_dtype")
|
||||
if out_dtype is None:
|
||||
out_dtype = input_tensor._layout_params['orig_dtype']
|
||||
|
||||
weight_t = plain_weight.t()
|
||||
|
||||
tensor_2d = False
|
||||
if len(plain_input.shape) == 2:
|
||||
tensor_2d = True
|
||||
plain_input = plain_input.unsqueeze(1)
|
||||
|
||||
input_shape = plain_input.shape
|
||||
if len(input_shape) != 3:
|
||||
return None
|
||||
|
||||
try:
|
||||
output = torch._scaled_mm(
|
||||
plain_input.reshape(-1, input_shape[2]).contiguous(),
|
||||
weight_t,
|
||||
bias=bias,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
out_dtype=out_dtype,
|
||||
)
|
||||
|
||||
if isinstance(output, tuple): # TODO: remove when we drop support for torch 2.4
|
||||
output = output[0]
|
||||
|
||||
if not tensor_2d:
|
||||
output = output.reshape((-1, input_shape[1], weight.shape[0]))
|
||||
|
||||
if output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
|
||||
output_scale = scale_a * scale_b
|
||||
output_params = {
|
||||
'scale': output_scale,
|
||||
'orig_dtype': input_tensor._layout_params['orig_dtype']
|
||||
}
|
||||
return QuantizedTensor(output, "TensorCoreFP8Layout", output_params)
|
||||
else:
|
||||
return output
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"FP8 _scaled_mm failed, falling back to dequantization: {e}")
|
||||
|
||||
# Case 2: DQ Fallback
|
||||
if isinstance(weight, QuantizedTensor):
|
||||
weight = weight.dequantize()
|
||||
if isinstance(input_tensor, QuantizedTensor):
|
||||
input_tensor = input_tensor.dequantize()
|
||||
|
||||
return torch.nn.functional.linear(input_tensor, weight, bias)
|
||||
|
||||
def fp8_mm_(input_tensor, weight, bias=None, out_dtype=None):
|
||||
if out_dtype is None:
|
||||
out_dtype = input_tensor._layout_params['orig_dtype']
|
||||
|
||||
plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
|
||||
plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight)
|
||||
|
||||
output = torch._scaled_mm(
|
||||
plain_input.contiguous(),
|
||||
plain_weight,
|
||||
bias=bias,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
out_dtype=out_dtype,
|
||||
)
|
||||
|
||||
if isinstance(output, tuple): # TODO: remove when we drop support for torch 2.4
|
||||
output = output[0]
|
||||
return output
|
||||
|
||||
@register_layout_op(torch.ops.aten.addmm.default, "TensorCoreFP8Layout")
|
||||
def fp8_addmm(func, args, kwargs):
|
||||
input_tensor = args[1]
|
||||
weight = args[2]
|
||||
bias = args[0]
|
||||
|
||||
if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
|
||||
return fp8_mm_(input_tensor, weight, bias=bias, out_dtype=kwargs.get("out_dtype", None))
|
||||
|
||||
a = list(args)
|
||||
if isinstance(args[0], QuantizedTensor):
|
||||
a[0] = args[0].dequantize()
|
||||
if isinstance(args[1], QuantizedTensor):
|
||||
a[1] = args[1].dequantize()
|
||||
if isinstance(args[2], QuantizedTensor):
|
||||
a[2] = args[2].dequantize()
|
||||
|
||||
return func(*a, **kwargs)
|
||||
|
||||
@register_layout_op(torch.ops.aten.mm.default, "TensorCoreFP8Layout")
|
||||
def fp8_mm(func, args, kwargs):
|
||||
input_tensor = args[0]
|
||||
weight = args[1]
|
||||
|
||||
if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
|
||||
return fp8_mm_(input_tensor, weight, bias=None, out_dtype=kwargs.get("out_dtype", None))
|
||||
|
||||
a = list(args)
|
||||
if isinstance(args[0], QuantizedTensor):
|
||||
a[0] = args[0].dequantize()
|
||||
if isinstance(args[1], QuantizedTensor):
|
||||
a[1] = args[1].dequantize()
|
||||
return func(*a, **kwargs)
|
||||
|
||||
@register_layout_op(torch.ops.aten.view.default, "TensorCoreFP8Layout")
|
||||
@register_layout_op(torch.ops.aten.t.default, "TensorCoreFP8Layout")
|
||||
def fp8_func(func, args, kwargs):
|
||||
input_tensor = args[0]
|
||||
if isinstance(input_tensor, QuantizedTensor):
|
||||
plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
|
||||
ar = list(args)
|
||||
ar[0] = plain_input
|
||||
return QuantizedTensor(func(*ar, **kwargs), "TensorCoreFP8Layout", input_tensor._layout_params)
|
||||
return func(*args, **kwargs)
|
||||
__all__ = [
|
||||
"QuantizedTensor",
|
||||
"QuantizedLayout",
|
||||
"TensorCoreFP8Layout",
|
||||
"TensorCoreFP8E4M3Layout",
|
||||
"TensorCoreFP8E5M2Layout",
|
||||
"TensorCoreNVFP4Layout",
|
||||
"QUANT_ALGOS",
|
||||
"register_layout_op",
|
||||
]
|
||||
|
||||
@@ -37,12 +37,18 @@ def prepare_noise(latent_image, seed, noise_inds=None):
|
||||
|
||||
return noises
|
||||
|
||||
def fix_empty_latent_channels(model, latent_image):
|
||||
def fix_empty_latent_channels(model, latent_image, downscale_ratio_spacial=None):
|
||||
if latent_image.is_nested:
|
||||
return latent_image
|
||||
latent_format = model.get_model_object("latent_format") #Resize the empty latent image so it has the right number of channels
|
||||
if latent_format.latent_channels != latent_image.shape[1] and torch.count_nonzero(latent_image) == 0:
|
||||
latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1)
|
||||
if torch.count_nonzero(latent_image) == 0:
|
||||
if latent_format.latent_channels != latent_image.shape[1]:
|
||||
latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1)
|
||||
if downscale_ratio_spacial is not None:
|
||||
if downscale_ratio_spacial != latent_format.spacial_downscale_ratio:
|
||||
ratio = downscale_ratio_spacial / latent_format.spacial_downscale_ratio
|
||||
latent_image = comfy.utils.common_upscale(latent_image, round(latent_image.shape[-1] * ratio), round(latent_image.shape[-2] * ratio), "nearest-exact", crop="disabled")
|
||||
|
||||
if latent_format.latent_dimensions == 3 and latent_image.ndim == 4:
|
||||
latent_image = latent_image.unsqueeze(2)
|
||||
return latent_image
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user