mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-12 03:00:03 +00:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
13ee23042f |
@@ -53,16 +53,6 @@ try:
|
||||
repo.stash(ident)
|
||||
except KeyError:
|
||||
print("nothing to stash") # noqa: T201
|
||||
except:
|
||||
print("Could not stash, cleaning index and trying again.") # noqa: T201
|
||||
repo.state_cleanup()
|
||||
repo.index.read_tree(repo.head.peel().tree)
|
||||
repo.index.write()
|
||||
try:
|
||||
repo.stash(ident)
|
||||
except KeyError:
|
||||
print("nothing to stash.") # noqa: T201
|
||||
|
||||
backup_branch_name = 'backup_branch_{}'.format(datetime.today().strftime('%Y-%m-%d_%H_%M_%S'))
|
||||
print("creating backup branch: {}".format(backup_branch_name)) # noqa: T201
|
||||
try:
|
||||
@@ -76,10 +66,8 @@ if branch is None:
|
||||
try:
|
||||
ref = repo.lookup_reference('refs/remotes/origin/master')
|
||||
except:
|
||||
print("fetching.") # noqa: T201
|
||||
for remote in repo.remotes:
|
||||
if remote.name == "origin":
|
||||
remote.fetch()
|
||||
print("pulling.") # noqa: T201
|
||||
pull(repo)
|
||||
ref = repo.lookup_reference('refs/remotes/origin/master')
|
||||
repo.checkout(ref)
|
||||
branch = repo.lookup_branch('master')
|
||||
@@ -161,4 +149,3 @@ try:
|
||||
shutil.copy(stable_update_script, stable_update_script_to)
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
@@ -1,28 +0,0 @@
|
||||
As of the time of writing this you need this driver for best results:
|
||||
https://www.amd.com/en/resources/support-articles/release-notes/RN-AMDGPU-WINDOWS-PYTORCH-7-1-1.html
|
||||
|
||||
HOW TO RUN:
|
||||
|
||||
If you have a AMD gpu:
|
||||
|
||||
run_amd_gpu.bat
|
||||
|
||||
If you have memory issues you can try disabling the smart memory management by running comfyui with:
|
||||
|
||||
run_amd_gpu_disable_smart_memory.bat
|
||||
|
||||
IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints
|
||||
|
||||
You can download the stable diffusion XL one from: https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0_0.9vae.safetensors
|
||||
|
||||
|
||||
RECOMMENDED WAY TO UPDATE:
|
||||
To update the ComfyUI code: update\update_comfyui.bat
|
||||
|
||||
|
||||
TO SHARE MODELS BETWEEN COMFYUI AND ANOTHER UI:
|
||||
In the ComfyUI directory you will find a file: extra_model_paths.yaml.example
|
||||
Rename this file to: extra_model_paths.yaml and edit it with your favorite text editor.
|
||||
|
||||
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --disable-smart-memory
|
||||
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast fp16_accumulation
|
||||
pause
|
||||
@@ -1,3 +0,0 @@
|
||||
..\python_embeded\python.exe -s ..\ComfyUI\main.py --windows-standalone-build --disable-api-nodes
|
||||
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest. If you get a c10.dll error you need to install vc redist that you can find: https://aka.ms/vc14/vc_redist.x64.exe
|
||||
pause
|
||||
@@ -1,3 +0,0 @@
|
||||
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build
|
||||
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest. If you get a c10.dll error you need to install vc redist that you can find: https://aka.ms/vc14/vc_redist.x64.exe
|
||||
pause
|
||||
@@ -1,3 +0,0 @@
|
||||
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast fp16_accumulation
|
||||
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest. If you get a c10.dll error you need to install vc redist that you can find: https://aka.ms/vc14/vc_redist.x64.exe
|
||||
pause
|
||||
8
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
8
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
@@ -8,15 +8,13 @@ body:
|
||||
Before submitting a **Bug Report**, please ensure the following:
|
||||
|
||||
- **1:** You are running the latest version of ComfyUI.
|
||||
- **2:** You have your ComfyUI logs and relevant workflow on hand and will post them in this bug report.
|
||||
- **2:** You have looked at the existing bug reports and made sure this isn't already reported.
|
||||
- **3:** You confirmed that the bug is not caused by a custom node. You can disable all custom nodes by passing
|
||||
`--disable-all-custom-nodes` command line argument. If you have custom node try updating them to the latest version.
|
||||
`--disable-all-custom-nodes` command line argument.
|
||||
- **4:** This is an actual bug in ComfyUI, not just a support question. A bug is when you can specify exact
|
||||
steps to replicate what went wrong and others will be able to repeat your steps and see the same issue happen.
|
||||
|
||||
## Very Important
|
||||
|
||||
Please make sure that you post ALL your ComfyUI logs in the bug report. A bug report without logs will likely be ignored.
|
||||
If unsure, ask on the [ComfyUI Matrix Space](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) or the [Comfy Org Discord](https://discord.gg/comfyorg) first.
|
||||
- type: checkboxes
|
||||
id: custom-nodes-test
|
||||
attributes:
|
||||
|
||||
21
.github/PULL_REQUEST_TEMPLATE/api-node.md
vendored
21
.github/PULL_REQUEST_TEMPLATE/api-node.md
vendored
@@ -1,21 +0,0 @@
|
||||
<!-- API_NODE_PR_CHECKLIST: do not remove -->
|
||||
|
||||
## API Node PR Checklist
|
||||
|
||||
### Scope
|
||||
- [ ] **Is API Node Change**
|
||||
|
||||
### Pricing & Billing
|
||||
- [ ] **Need pricing update**
|
||||
- [ ] **No pricing update**
|
||||
|
||||
If **Need pricing update**:
|
||||
- [ ] Metronome rate cards updated
|
||||
- [ ] Auto‑billing tests updated and passing
|
||||
|
||||
### QA
|
||||
- [ ] **QA done**
|
||||
- [ ] **QA not required**
|
||||
|
||||
### Comms
|
||||
- [ ] Informed **Kosinkadink**
|
||||
58
.github/workflows/api-node-template.yml
vendored
58
.github/workflows/api-node-template.yml
vendored
@@ -1,58 +0,0 @@
|
||||
name: Append API Node PR template
|
||||
|
||||
on:
|
||||
pull_request_target:
|
||||
types: [opened, reopened, synchronize, ready_for_review]
|
||||
paths:
|
||||
- 'comfy_api_nodes/**' # only run if these files changed
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
|
||||
jobs:
|
||||
inject:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Ensure template exists and append to PR body
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
const { owner, repo } = context.repo;
|
||||
const number = context.payload.pull_request.number;
|
||||
const templatePath = '.github/PULL_REQUEST_TEMPLATE/api-node.md';
|
||||
const marker = '<!-- API_NODE_PR_CHECKLIST: do not remove -->';
|
||||
|
||||
const { data: pr } = await github.rest.pulls.get({ owner, repo, pull_number: number });
|
||||
|
||||
let templateText;
|
||||
try {
|
||||
const res = await github.rest.repos.getContent({
|
||||
owner,
|
||||
repo,
|
||||
path: templatePath,
|
||||
ref: pr.base.ref
|
||||
});
|
||||
const buf = Buffer.from(res.data.content, res.data.encoding || 'base64');
|
||||
templateText = buf.toString('utf8');
|
||||
} catch (e) {
|
||||
core.setFailed(`Required PR template not found at "${templatePath}" on ${pr.base.ref}. Please add it to the repo.`);
|
||||
return;
|
||||
}
|
||||
|
||||
// Enforce the presence of the marker inside the template (for idempotence)
|
||||
if (!templateText.includes(marker)) {
|
||||
core.setFailed(`Template at "${templatePath}" does not contain the required marker:\n${marker}\nAdd it so we can detect duplicates safely.`);
|
||||
return;
|
||||
}
|
||||
|
||||
// If the PR already contains the marker, do not append again.
|
||||
const body = pr.body || '';
|
||||
if (body.includes(marker)) {
|
||||
core.info('Template already present in PR body; nothing to inject.');
|
||||
return;
|
||||
}
|
||||
|
||||
const newBody = (body ? body + '\n\n' : '') + templateText + '\n';
|
||||
await github.rest.pulls.update({ owner, repo, pull_number: number, body: newBody });
|
||||
core.notice('API Node template appended to PR description.');
|
||||
78
.github/workflows/release-stable-all.yml
vendored
78
.github/workflows/release-stable-all.yml
vendored
@@ -1,78 +0,0 @@
|
||||
name: "Release Stable All Portable Versions"
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
git_tag:
|
||||
description: 'Git tag'
|
||||
required: true
|
||||
type: string
|
||||
|
||||
jobs:
|
||||
release_nvidia_default:
|
||||
permissions:
|
||||
contents: "write"
|
||||
packages: "write"
|
||||
pull-requests: "read"
|
||||
name: "Release NVIDIA Default (cu130)"
|
||||
uses: ./.github/workflows/stable-release.yml
|
||||
with:
|
||||
git_tag: ${{ inputs.git_tag }}
|
||||
cache_tag: "cu130"
|
||||
python_minor: "13"
|
||||
python_patch: "9"
|
||||
rel_name: "nvidia"
|
||||
rel_extra_name: ""
|
||||
test_release: true
|
||||
secrets: inherit
|
||||
|
||||
release_nvidia_cu128:
|
||||
permissions:
|
||||
contents: "write"
|
||||
packages: "write"
|
||||
pull-requests: "read"
|
||||
name: "Release NVIDIA cu128"
|
||||
uses: ./.github/workflows/stable-release.yml
|
||||
with:
|
||||
git_tag: ${{ inputs.git_tag }}
|
||||
cache_tag: "cu128"
|
||||
python_minor: "12"
|
||||
python_patch: "10"
|
||||
rel_name: "nvidia"
|
||||
rel_extra_name: "_cu128"
|
||||
test_release: true
|
||||
secrets: inherit
|
||||
|
||||
release_nvidia_cu126:
|
||||
permissions:
|
||||
contents: "write"
|
||||
packages: "write"
|
||||
pull-requests: "read"
|
||||
name: "Release NVIDIA cu126"
|
||||
uses: ./.github/workflows/stable-release.yml
|
||||
with:
|
||||
git_tag: ${{ inputs.git_tag }}
|
||||
cache_tag: "cu126"
|
||||
python_minor: "12"
|
||||
python_patch: "10"
|
||||
rel_name: "nvidia"
|
||||
rel_extra_name: "_cu126"
|
||||
test_release: true
|
||||
secrets: inherit
|
||||
|
||||
release_amd_rocm:
|
||||
permissions:
|
||||
contents: "write"
|
||||
packages: "write"
|
||||
pull-requests: "read"
|
||||
name: "Release AMD ROCm 7.1.1"
|
||||
uses: ./.github/workflows/stable-release.yml
|
||||
with:
|
||||
git_tag: ${{ inputs.git_tag }}
|
||||
cache_tag: "rocm711"
|
||||
python_minor: "12"
|
||||
python_patch: "10"
|
||||
rel_name: "amd"
|
||||
rel_extra_name: ""
|
||||
test_release: false
|
||||
secrets: inherit
|
||||
25
.github/workflows/ruff.yml
vendored
25
.github/workflows/ruff.yml
vendored
@@ -21,28 +21,3 @@ jobs:
|
||||
|
||||
- name: Run Ruff
|
||||
run: ruff check .
|
||||
|
||||
pylint:
|
||||
name: Run Pylint
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.12'
|
||||
|
||||
- name: Install requirements
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
|
||||
pip install -r requirements.txt
|
||||
|
||||
- name: Install Pylint
|
||||
run: pip install pylint
|
||||
|
||||
- name: Run Pylint
|
||||
run: pylint comfy_api_nodes
|
||||
|
||||
98
.github/workflows/stable-release.yml
vendored
98
.github/workflows/stable-release.yml
vendored
@@ -2,53 +2,17 @@
|
||||
name: "Release Stable Version"
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
git_tag:
|
||||
description: 'Git tag'
|
||||
required: true
|
||||
type: string
|
||||
cache_tag:
|
||||
description: 'Cached dependencies tag'
|
||||
required: true
|
||||
type: string
|
||||
default: "cu129"
|
||||
python_minor:
|
||||
description: 'Python minor version'
|
||||
required: true
|
||||
type: string
|
||||
default: "13"
|
||||
python_patch:
|
||||
description: 'Python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "6"
|
||||
rel_name:
|
||||
description: 'Release name'
|
||||
required: true
|
||||
type: string
|
||||
default: "nvidia"
|
||||
rel_extra_name:
|
||||
description: 'Release extra name'
|
||||
required: false
|
||||
type: string
|
||||
default: ""
|
||||
test_release:
|
||||
description: 'Test Release'
|
||||
required: true
|
||||
type: boolean
|
||||
default: true
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
git_tag:
|
||||
description: 'Git tag'
|
||||
required: true
|
||||
type: string
|
||||
cache_tag:
|
||||
description: 'Cached dependencies tag'
|
||||
cu:
|
||||
description: 'CUDA version'
|
||||
required: true
|
||||
type: string
|
||||
default: "cu129"
|
||||
default: "129"
|
||||
python_minor:
|
||||
description: 'Python minor version'
|
||||
required: true
|
||||
@@ -59,21 +23,7 @@ on:
|
||||
required: true
|
||||
type: string
|
||||
default: "6"
|
||||
rel_name:
|
||||
description: 'Release name'
|
||||
required: true
|
||||
type: string
|
||||
default: "nvidia"
|
||||
rel_extra_name:
|
||||
description: 'Release extra name'
|
||||
required: false
|
||||
type: string
|
||||
default: ""
|
||||
test_release:
|
||||
description: 'Test Release'
|
||||
required: true
|
||||
type: boolean
|
||||
default: true
|
||||
|
||||
|
||||
jobs:
|
||||
package_comfy_windows:
|
||||
@@ -92,15 +42,15 @@ jobs:
|
||||
id: cache
|
||||
with:
|
||||
path: |
|
||||
${{ inputs.cache_tag }}_python_deps.tar
|
||||
cu${{ inputs.cu }}_python_deps.tar
|
||||
update_comfyui_and_python_dependencies.bat
|
||||
key: ${{ runner.os }}-build-${{ inputs.cache_tag }}-${{ inputs.python_minor }}
|
||||
key: ${{ runner.os }}-build-cu${{ inputs.cu }}-${{ inputs.python_minor }}
|
||||
- shell: bash
|
||||
run: |
|
||||
mv ${{ inputs.cache_tag }}_python_deps.tar ../
|
||||
mv cu${{ inputs.cu }}_python_deps.tar ../
|
||||
mv update_comfyui_and_python_dependencies.bat ../
|
||||
cd ..
|
||||
tar xf ${{ inputs.cache_tag }}_python_deps.tar
|
||||
tar xf cu${{ inputs.cu }}_python_deps.tar
|
||||
pwd
|
||||
ls
|
||||
|
||||
@@ -115,19 +65,12 @@ jobs:
|
||||
echo 'import site' >> ./python3${{ inputs.python_minor }}._pth
|
||||
curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
|
||||
./python.exe get-pip.py
|
||||
./python.exe -s -m pip install ../${{ inputs.cache_tag }}_python_deps/*
|
||||
|
||||
grep comfyui ../ComfyUI/requirements.txt > ./requirements_comfyui.txt
|
||||
./python.exe -s -m pip install -r requirements_comfyui.txt
|
||||
rm requirements_comfyui.txt
|
||||
|
||||
./python.exe -s -m pip install ../cu${{ inputs.cu }}_python_deps/*
|
||||
sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
|
||||
|
||||
if test -f ./Lib/site-packages/torch/lib/dnnl.lib; then
|
||||
rm ./Lib/site-packages/torch/lib/dnnl.lib #I don't think this is actually used and I need the space
|
||||
rm ./Lib/site-packages/torch/lib/libprotoc.lib
|
||||
rm ./Lib/site-packages/torch/lib/libprotobuf.lib
|
||||
fi
|
||||
rm ./Lib/site-packages/torch/lib/dnnl.lib #I don't think this is actually used and I need the space
|
||||
rm ./Lib/site-packages/torch/lib/libprotoc.lib
|
||||
rm ./Lib/site-packages/torch/lib/libprotobuf.lib
|
||||
|
||||
cd ..
|
||||
|
||||
@@ -142,18 +85,14 @@ jobs:
|
||||
|
||||
mkdir update
|
||||
cp -r ComfyUI/.ci/update_windows/* ./update/
|
||||
cp -r ComfyUI/.ci/windows_${{ inputs.rel_name }}_base_files/* ./
|
||||
cp -r ComfyUI/.ci/windows_base_files/* ./
|
||||
cp ../update_comfyui_and_python_dependencies.bat ./update/
|
||||
|
||||
cd ..
|
||||
|
||||
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=768m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
|
||||
mv ComfyUI_windows_portable.7z ComfyUI/ComfyUI_windows_portable_${{ inputs.rel_name }}${{ inputs.rel_extra_name }}.7z
|
||||
mv ComfyUI_windows_portable.7z ComfyUI/ComfyUI_windows_portable_nvidia.7z
|
||||
|
||||
- shell: bash
|
||||
if: ${{ inputs.test_release }}
|
||||
run: |
|
||||
cd ..
|
||||
cd ComfyUI_windows_portable
|
||||
python_embeded/python.exe -s ComfyUI/main.py --quick-test-for-ci --cpu
|
||||
|
||||
@@ -162,9 +101,10 @@ jobs:
|
||||
ls
|
||||
|
||||
- name: Upload binaries to release
|
||||
uses: softprops/action-gh-release@v2
|
||||
uses: svenstaro/upload-release-action@v2
|
||||
with:
|
||||
files: ComfyUI_windows_portable_${{ inputs.rel_name }}${{ inputs.rel_extra_name }}.7z
|
||||
tag_name: ${{ inputs.git_tag }}
|
||||
repo_token: ${{ secrets.GITHUB_TOKEN }}
|
||||
file: ComfyUI_windows_portable_nvidia.7z
|
||||
tag: ${{ inputs.git_tag }}
|
||||
overwrite: true
|
||||
draft: true
|
||||
overwrite_files: true
|
||||
|
||||
2
.github/workflows/test-build.yml
vendored
2
.github/workflows/test-build.yml
vendored
@@ -18,7 +18,7 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
|
||||
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
|
||||
23
.github/workflows/test-ci.yml
vendored
23
.github/workflows/test-ci.yml
vendored
@@ -5,7 +5,6 @@ on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
- release/**
|
||||
paths-ignore:
|
||||
- 'app/**'
|
||||
- 'input/**'
|
||||
@@ -20,18 +19,16 @@ jobs:
|
||||
test-stable:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
max-parallel: 1 # This forces sequential execution
|
||||
matrix:
|
||||
# os: [macos, linux, windows]
|
||||
# os: [macos, linux]
|
||||
os: [linux]
|
||||
python_version: ["3.10", "3.11", "3.12"]
|
||||
os: [macos, linux]
|
||||
python_version: ["3.9", "3.10", "3.11", "3.12"]
|
||||
cuda_version: ["12.1"]
|
||||
torch_version: ["stable"]
|
||||
include:
|
||||
# - os: macos
|
||||
# runner_label: [self-hosted, macOS]
|
||||
# flags: "--use-pytorch-cross-attention"
|
||||
- os: macos
|
||||
runner_label: [self-hosted, macOS]
|
||||
flags: "--use-pytorch-cross-attention"
|
||||
- os: linux
|
||||
runner_label: [self-hosted, Linux]
|
||||
flags: ""
|
||||
@@ -75,17 +72,15 @@ jobs:
|
||||
test-unix-nightly:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
max-parallel: 1 # This forces sequential execution
|
||||
matrix:
|
||||
# os: [macos, linux]
|
||||
os: [linux]
|
||||
os: [macos, linux]
|
||||
python_version: ["3.11"]
|
||||
cuda_version: ["12.1"]
|
||||
torch_version: ["nightly"]
|
||||
include:
|
||||
# - os: macos
|
||||
# runner_label: [self-hosted, macOS]
|
||||
# flags: "--use-pytorch-cross-attention"
|
||||
- os: macos
|
||||
runner_label: [self-hosted, macOS]
|
||||
flags: "--use-pytorch-cross-attention"
|
||||
- os: linux
|
||||
runner_label: [self-hosted, Linux]
|
||||
flags: ""
|
||||
|
||||
4
.github/workflows/test-execution.yml
vendored
4
.github/workflows/test-execution.yml
vendored
@@ -2,9 +2,9 @@ name: Execution Tests
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main, master, release/** ]
|
||||
branches: [ main, master ]
|
||||
pull_request:
|
||||
branches: [ main, master, release/** ]
|
||||
branches: [ main, master ]
|
||||
|
||||
jobs:
|
||||
test:
|
||||
|
||||
8
.github/workflows/test-launch.yml
vendored
8
.github/workflows/test-launch.yml
vendored
@@ -2,9 +2,9 @@ name: Test server launches without errors
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main, master, release/** ]
|
||||
branches: [ main, master ]
|
||||
pull_request:
|
||||
branches: [ main, master, release/** ]
|
||||
branches: [ main, master ]
|
||||
|
||||
jobs:
|
||||
test:
|
||||
@@ -32,9 +32,7 @@ jobs:
|
||||
working-directory: ComfyUI
|
||||
- name: Check for unhandled exceptions in server log
|
||||
run: |
|
||||
grep -v "Found comfy_kitchen backend triton: {'available': False, 'disabled': True, 'unavailable_reason': \"ImportError: No module named 'triton'\", 'capabilities': \[\]}" console_output.log | grep -v "Found comfy_kitchen backend triton: {'available': False, 'disabled': False, 'unavailable_reason': \"ImportError: No module named 'triton'\", 'capabilities': \[\]}" > console_output_filtered.log
|
||||
cat console_output_filtered.log
|
||||
if grep -qE "Exception|Error" console_output_filtered.log; then
|
||||
if grep -qE "Exception|Error" console_output.log; then
|
||||
echo "Unhandled exception/error found in server log."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
6
.github/workflows/test-unit.yml
vendored
6
.github/workflows/test-unit.yml
vendored
@@ -2,15 +2,15 @@ name: Unit Tests
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main, master, release/** ]
|
||||
branches: [ main, master ]
|
||||
pull_request:
|
||||
branches: [ main, master, release/** ]
|
||||
branches: [ main, master ]
|
||||
|
||||
jobs:
|
||||
test:
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest, windows-2022, macos-latest]
|
||||
os: [ubuntu-latest, windows-latest, macos-latest]
|
||||
runs-on: ${{ matrix.os }}
|
||||
continue-on-error: true
|
||||
steps:
|
||||
|
||||
1
.github/workflows/update-version.yml
vendored
1
.github/workflows/update-version.yml
vendored
@@ -6,7 +6,6 @@ on:
|
||||
- "pyproject.toml"
|
||||
branches:
|
||||
- master
|
||||
- release/**
|
||||
|
||||
jobs:
|
||||
update-version:
|
||||
|
||||
@@ -17,7 +17,7 @@ on:
|
||||
description: 'cuda version'
|
||||
required: true
|
||||
type: string
|
||||
default: "130"
|
||||
default: "129"
|
||||
|
||||
python_minor:
|
||||
description: 'python minor version'
|
||||
@@ -29,7 +29,7 @@ on:
|
||||
description: 'python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "9"
|
||||
default: "6"
|
||||
# push:
|
||||
# branches:
|
||||
# - master
|
||||
@@ -56,8 +56,7 @@ jobs:
|
||||
..\python_embeded\python.exe -s -m pip install --upgrade torch torchvision torchaudio ${{ inputs.xformers }} --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2
|
||||
pause" > update_comfyui_and_python_dependencies.bat
|
||||
|
||||
grep -v comfyui requirements.txt > requirements_nocomfyui.txt
|
||||
python -m pip wheel --no-cache-dir torch torchvision torchaudio ${{ inputs.xformers }} ${{ inputs.extra_dependencies }} --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cu }} -r requirements_nocomfyui.txt pygit2 -w ./temp_wheel_dir
|
||||
python -m pip wheel --no-cache-dir torch torchvision torchaudio ${{ inputs.xformers }} ${{ inputs.extra_dependencies }} --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cu }} -r requirements.txt pygit2 -w ./temp_wheel_dir
|
||||
python -m pip install --no-cache-dir ./temp_wheel_dir/*
|
||||
echo installed basic
|
||||
ls -lah temp_wheel_dir
|
||||
|
||||
@@ -1,64 +0,0 @@
|
||||
name: "Windows Release dependencies Manual"
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
torch_dependencies:
|
||||
description: 'torch dependencies'
|
||||
required: false
|
||||
type: string
|
||||
default: "torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu128"
|
||||
cache_tag:
|
||||
description: 'Cached dependencies tag'
|
||||
required: true
|
||||
type: string
|
||||
default: "cu128"
|
||||
|
||||
python_minor:
|
||||
description: 'python minor version'
|
||||
required: true
|
||||
type: string
|
||||
default: "12"
|
||||
|
||||
python_patch:
|
||||
description: 'python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "10"
|
||||
|
||||
jobs:
|
||||
build_dependencies:
|
||||
runs-on: windows-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: 3.${{ inputs.python_minor }}.${{ inputs.python_patch }}
|
||||
|
||||
- shell: bash
|
||||
run: |
|
||||
echo "@echo off
|
||||
call update_comfyui.bat nopause
|
||||
echo -
|
||||
echo This will try to update pytorch and all python dependencies.
|
||||
echo -
|
||||
echo If you just want to update normally, close this and run update_comfyui.bat instead.
|
||||
echo -
|
||||
pause
|
||||
..\python_embeded\python.exe -s -m pip install --upgrade ${{ inputs.torch_dependencies }} -r ../ComfyUI/requirements.txt pygit2
|
||||
pause" > update_comfyui_and_python_dependencies.bat
|
||||
|
||||
grep -v comfyui requirements.txt > requirements_nocomfyui.txt
|
||||
python -m pip wheel --no-cache-dir ${{ inputs.torch_dependencies }} -r requirements_nocomfyui.txt pygit2 -w ./temp_wheel_dir
|
||||
python -m pip install --no-cache-dir ./temp_wheel_dir/*
|
||||
echo installed basic
|
||||
ls -lah temp_wheel_dir
|
||||
mv temp_wheel_dir ${{ inputs.cache_tag }}_python_deps
|
||||
tar cf ${{ inputs.cache_tag }}_python_deps.tar ${{ inputs.cache_tag }}_python_deps
|
||||
|
||||
- uses: actions/cache/save@v4
|
||||
with:
|
||||
path: |
|
||||
${{ inputs.cache_tag }}_python_deps.tar
|
||||
update_comfyui_and_python_dependencies.bat
|
||||
key: ${{ runner.os }}-build-${{ inputs.cache_tag }}-${{ inputs.python_minor }}
|
||||
@@ -68,7 +68,7 @@ jobs:
|
||||
|
||||
mkdir update
|
||||
cp -r ComfyUI/.ci/update_windows/* ./update/
|
||||
cp -r ComfyUI/.ci/windows_nvidia_base_files/* ./
|
||||
cp -r ComfyUI/.ci/windows_base_files/* ./
|
||||
cp -r ComfyUI/.ci/windows_nightly_base_files/* ./
|
||||
|
||||
echo "call update_comfyui.bat nopause
|
||||
|
||||
@@ -81,7 +81,7 @@ jobs:
|
||||
|
||||
mkdir update
|
||||
cp -r ComfyUI/.ci/update_windows/* ./update/
|
||||
cp -r ComfyUI/.ci/windows_nvidia_base_files/* ./
|
||||
cp -r ComfyUI/.ci/windows_base_files/* ./
|
||||
cp ../update_comfyui_and_python_dependencies.bat ./update/
|
||||
|
||||
cd ..
|
||||
|
||||
25
CODEOWNERS
25
CODEOWNERS
@@ -1,2 +1,25 @@
|
||||
# Admins
|
||||
* @comfyanonymous @kosinkadink @guill
|
||||
* @comfyanonymous
|
||||
|
||||
# Note: Github teams syntax cannot be used here as the repo is not owned by Comfy-Org.
|
||||
# Inlined the team members for now.
|
||||
|
||||
# Maintainers
|
||||
*.md @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
|
||||
/tests/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
|
||||
/tests-unit/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
|
||||
/notebooks/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
|
||||
/script_examples/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
|
||||
/.github/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
|
||||
/requirements.txt @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
|
||||
/pyproject.toml @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
|
||||
|
||||
# Python web server
|
||||
/api_server/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne @guill
|
||||
/app/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne @guill
|
||||
/utils/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne @guill
|
||||
|
||||
# Node developers
|
||||
/comfy_extras/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne @guill
|
||||
/comfy/comfy_types/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne @guill
|
||||
/comfy_api_nodes/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne @guill
|
||||
|
||||
168
QUANTIZATION.md
168
QUANTIZATION.md
@@ -1,168 +0,0 @@
|
||||
# The Comfy guide to Quantization
|
||||
|
||||
|
||||
## How does quantization work?
|
||||
|
||||
Quantization aims to map a high-precision value x_f to a lower precision format with minimal loss in accuracy. These smaller formats then serve to reduce the models memory footprint and increase throughput by using specialized hardware.
|
||||
|
||||
When simply converting a value from FP16 to FP8 using the round-nearest method we might hit two issues:
|
||||
- The dynamic range of FP16 (-65,504, 65,504) far exceeds FP8 formats like E4M3 (-448, 448) or E5M2 (-57,344, 57,344), potentially resulting in clipped values
|
||||
- The original values are concentrated in a small range (e.g. -1,1) leaving many FP8-bits "unused"
|
||||
|
||||
By using a scaling factor, we aim to map these values into the quantized-dtype range, making use of the full spectrum. One of the easiest approaches, and common, is using per-tensor absolute-maximum scaling.
|
||||
|
||||
```
|
||||
absmax = max(abs(tensor))
|
||||
scale = amax / max_dynamic_range_low_precision
|
||||
|
||||
# Quantization
|
||||
tensor_q = (tensor / scale).to(low_precision_dtype)
|
||||
|
||||
# De-Quantization
|
||||
tensor_dq = tensor_q.to(fp16) * scale
|
||||
|
||||
tensor_dq ~ tensor
|
||||
```
|
||||
|
||||
Given that additional information (scaling factor) is needed to "interpret" the quantized values, we describe those as derived datatypes.
|
||||
|
||||
|
||||
## Quantization in Comfy
|
||||
|
||||
```
|
||||
QuantizedTensor (torch.Tensor subclass)
|
||||
↓ __torch_dispatch__
|
||||
Two-Level Registry (generic + layout handlers)
|
||||
↓
|
||||
MixedPrecisionOps + Metadata Detection
|
||||
```
|
||||
|
||||
### Representation
|
||||
|
||||
To represent these derived datatypes, ComfyUI uses a subclass of torch.Tensor to implements these using the `QuantizedTensor` class found in `comfy/quant_ops.py`
|
||||
|
||||
A `Layout` class defines how a specific quantization format behaves:
|
||||
- Required parameters
|
||||
- Quantize method
|
||||
- De-Quantize method
|
||||
|
||||
```python
|
||||
from comfy.quant_ops import QuantizedLayout
|
||||
|
||||
class MyLayout(QuantizedLayout):
|
||||
@classmethod
|
||||
def quantize(cls, tensor, **kwargs):
|
||||
# Convert to quantized format
|
||||
qdata = ...
|
||||
params = {'scale': ..., 'orig_dtype': tensor.dtype}
|
||||
return qdata, params
|
||||
|
||||
@staticmethod
|
||||
def dequantize(qdata, scale, orig_dtype, **kwargs):
|
||||
return qdata.to(orig_dtype) * scale
|
||||
```
|
||||
|
||||
To then run operations using these QuantizedTensors we use two registry systems to define supported operations.
|
||||
The first is a **generic registry** that handles operations common to all quantized formats (e.g., `.to()`, `.clone()`, `.reshape()`).
|
||||
|
||||
The second registry is layout-specific and allows to implement fast-paths like nn.Linear.
|
||||
```python
|
||||
from comfy.quant_ops import register_layout_op
|
||||
|
||||
@register_layout_op(torch.ops.aten.linear.default, MyLayout)
|
||||
def my_linear(func, args, kwargs):
|
||||
# Extract tensors, call optimized kernel
|
||||
...
|
||||
```
|
||||
When `torch.nn.functional.linear()` is called with QuantizedTensor arguments, `__torch_dispatch__` automatically routes to the registered implementation.
|
||||
For any unsupported operation, QuantizedTensor will fallback to call `dequantize` and dispatch using the high-precision implementation.
|
||||
|
||||
|
||||
### Mixed Precision
|
||||
|
||||
The `MixedPrecisionOps` class (lines 542-648 in `comfy/ops.py`) enables per-layer quantization decisions, allowing different layers in a model to use different precisions. This is activated when a model config contains a `layer_quant_config` dictionary that specifies which layers should be quantized and how.
|
||||
|
||||
**Architecture:**
|
||||
|
||||
```python
|
||||
class MixedPrecisionOps(disable_weight_init):
|
||||
_layer_quant_config = {} # Maps layer names to quantization configs
|
||||
_compute_dtype = torch.bfloat16 # Default compute / dequantize precision
|
||||
```
|
||||
|
||||
**Key mechanism:**
|
||||
|
||||
The custom `Linear._load_from_state_dict()` method inspects each layer during model loading:
|
||||
- If the layer name is **not** in `_layer_quant_config`: load weight as regular tensor in `_compute_dtype`
|
||||
- If the layer name **is** in `_layer_quant_config`:
|
||||
- Load weight as `QuantizedTensor` with the specified layout (e.g., `TensorCoreFP8Layout`)
|
||||
- Load associated quantization parameters (scales, block_size, etc.)
|
||||
|
||||
**Why it's needed:**
|
||||
|
||||
Not all layers tolerate quantization equally. Sensitive operations like final projections can be kept in higher precision, while compute-heavy matmuls are quantized. This provides most of the performance benefits while maintaining quality.
|
||||
|
||||
The system is selected in `pick_operations()` when `model_config.layer_quant_config` is present, making it the highest-priority operation mode.
|
||||
|
||||
|
||||
## Checkpoint Format
|
||||
|
||||
Quantized checkpoints are stored as standard safetensors files with quantized weight tensors and associated scaling parameters, plus a `_quantization_metadata` JSON entry describing the quantization scheme.
|
||||
|
||||
The quantized checkpoint will contain the same layers as the original checkpoint but:
|
||||
- The weights are stored as quantized values, sometimes using a different storage datatype. E.g. uint8 container for fp8.
|
||||
- For each quantized weight a number of additional scaling parameters are stored alongside depending on the recipe.
|
||||
- We store a metadata.json in the metadata of the final safetensor containing the `_quantization_metadata` describing which layers are quantized and what layout has been used.
|
||||
|
||||
### Scaling Parameters details
|
||||
We define 4 possible scaling parameters that should cover most recipes in the near-future:
|
||||
- **weight_scale**: quantization scalers for the weights
|
||||
- **weight_scale_2**: global scalers in the context of double scaling
|
||||
- **pre_quant_scale**: scalers used for smoothing salient weights
|
||||
- **input_scale**: quantization scalers for the activations
|
||||
|
||||
| Format | Storage dtype | weight_scale | weight_scale_2 | pre_quant_scale | input_scale |
|
||||
|--------|---------------|--------------|----------------|-----------------|-------------|
|
||||
| float8_e4m3fn | float32 | float32 (scalar) | - | - | float32 (scalar) |
|
||||
|
||||
You can find the defined formats in `comfy/quant_ops.py` (QUANT_ALGOS).
|
||||
|
||||
### Quantization Metadata
|
||||
|
||||
The metadata stored alongside the checkpoint contains:
|
||||
- **format_version**: String to define a version of the standard
|
||||
- **layers**: A dictionary mapping layer names to their quantization format. The format string maps to the definitions found in `QUANT_ALGOS`.
|
||||
|
||||
Example:
|
||||
```json
|
||||
{
|
||||
"_quantization_metadata": {
|
||||
"format_version": "1.0",
|
||||
"layers": {
|
||||
"model.layers.0.mlp.up_proj": "float8_e4m3fn",
|
||||
"model.layers.0.mlp.down_proj": "float8_e4m3fn",
|
||||
"model.layers.1.mlp.up_proj": "float8_e4m3fn"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
## Creating Quantized Checkpoints
|
||||
|
||||
To create compatible checkpoints, use any quantization tool provided the output follows the checkpoint format described above and uses a layout defined in `QUANT_ALGOS`.
|
||||
|
||||
### Weight Quantization
|
||||
|
||||
Weight quantization is straightforward - compute the scaling factor directly from the weight tensor using the absolute maximum method described earlier. Each layer's weights are quantized independently and stored with their corresponding `weight_scale` parameter.
|
||||
|
||||
### Calibration (for Activation Quantization)
|
||||
|
||||
Activation quantization (e.g., for FP8 Tensor Core operations) requires `input_scale` parameters that cannot be determined from static weights alone. Since activation values depend on actual inputs, we use **post-training calibration (PTQ)**:
|
||||
|
||||
1. **Collect statistics**: Run inference on N representative samples
|
||||
2. **Track activations**: Record the absolute maximum (`amax`) of inputs to each quantized layer
|
||||
3. **Compute scales**: Derive `input_scale` from collected statistics
|
||||
4. **Store in checkpoint**: Save `input_scale` parameters alongside weights
|
||||
|
||||
The calibration dataset should be representative of your target use case. For diffusion models, this typically means a diverse set of prompts and generation parameters.
|
||||
97
README.md
97
README.md
@@ -67,8 +67,6 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
|
||||
- [HiDream](https://comfyanonymous.github.io/ComfyUI_examples/hidream/)
|
||||
- [Qwen Image](https://comfyanonymous.github.io/ComfyUI_examples/qwen_image/)
|
||||
- [Hunyuan Image 2.1](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_image/)
|
||||
- [Flux 2](https://comfyanonymous.github.io/ComfyUI_examples/flux2/)
|
||||
- [Z Image](https://comfyanonymous.github.io/ComfyUI_examples/z_image/)
|
||||
- Image Editing Models
|
||||
- [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/)
|
||||
- [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model)
|
||||
@@ -81,7 +79,6 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
|
||||
- [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
|
||||
- [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/)
|
||||
- [Wan 2.2](https://comfyanonymous.github.io/ComfyUI_examples/wan22/)
|
||||
- [Hunyuan Video 1.5](https://docs.comfy.org/tutorials/video/hunyuan/hunyuan-video-1-5)
|
||||
- Audio Models
|
||||
- [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
|
||||
- [ACE Step](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
|
||||
@@ -115,14 +112,10 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
|
||||
|
||||
## Release Process
|
||||
|
||||
ComfyUI follows a weekly release cycle targeting Monday but this regularly changes because of model releases or large changes to the codebase. There are three interconnected repositories:
|
||||
ComfyUI follows a weekly release cycle targeting Friday but this regularly changes because of model releases or large changes to the codebase. There are three interconnected repositories:
|
||||
|
||||
1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)**
|
||||
- Releases a new stable version (e.g., v0.7.0) roughly every week.
|
||||
- Starting from v0.4.0 patch versions will be used for fixes backported onto the current stable release.
|
||||
- Minor versions will be used for releases off the master branch.
|
||||
- Patch versions may still be used for releases on the master branch in cases where a backport would not make sense.
|
||||
- Commits outside of the stable release tags may be very unstable and break many custom nodes.
|
||||
- Releases a new stable version (e.g., v0.7.0)
|
||||
- Serves as the foundation for the desktop release
|
||||
|
||||
2. **[ComfyUI Desktop](https://github.com/Comfy-Org/desktop)**
|
||||
@@ -179,20 +172,10 @@ There is a portable standalone build for Windows that should work for running on
|
||||
|
||||
### [Direct link to download](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia.7z)
|
||||
|
||||
Simply download, extract with [7-Zip](https://7-zip.org) or with the windows explorer on recent windows versions and run. For smaller models you normally only need to put the checkpoints (the huge ckpt/safetensors files) in: ComfyUI\models\checkpoints but many of the larger models have multiple files. Make sure to follow the instructions to know which subfolder to put them in ComfyUI\models\
|
||||
Simply download, extract with [7-Zip](https://7-zip.org) and run. Make sure you put your Stable Diffusion checkpoints/models (the huge ckpt/safetensors files) in: ComfyUI\models\checkpoints
|
||||
|
||||
If you have trouble extracting it, right click the file -> properties -> unblock
|
||||
|
||||
Update your Nvidia drivers if it doesn't start.
|
||||
|
||||
#### Alternative Downloads:
|
||||
|
||||
[Experimental portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z)
|
||||
|
||||
[Portable with pytorch cuda 12.8 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu128.7z).
|
||||
|
||||
[Portable with pytorch cuda 12.6 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu126.7z) (Supports Nvidia 10 series and older GPUs).
|
||||
|
||||
#### How do I share models between another UI and ComfyUI?
|
||||
|
||||
See the [Config file](extra_model_paths.yaml.example) to set the search paths for models. In the standalone windows build you can find this file in the ComfyUI directory. Rename this file to extra_model_paths.yaml and edit it with your favorite text editor.
|
||||
@@ -208,13 +191,7 @@ comfy install
|
||||
|
||||
## Manual Install (Windows, Linux)
|
||||
|
||||
Python 3.14 works but you may encounter issues with the torch compile node. The free threaded variant is still missing some dependencies.
|
||||
|
||||
Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12
|
||||
|
||||
torch 2.4 and above is supported but some features might only work on newer versions. We generally recommend using the latest major version of pytorch unless it is less than 2 weeks old.
|
||||
|
||||
### Instructions:
|
||||
Python 3.13 is very well supported. If you have trouble with some custom node dependencies you can try 3.12
|
||||
|
||||
Git clone this repo.
|
||||
|
||||
@@ -223,36 +200,18 @@ Put your SD checkpoints (the huge ckpt/safetensors files) in: models/checkpoints
|
||||
Put your VAE in: models/vae
|
||||
|
||||
|
||||
### AMD GPUs (Linux)
|
||||
|
||||
### AMD GPUs (Linux only)
|
||||
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
|
||||
|
||||
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.4```
|
||||
|
||||
This is the command to install the nightly with ROCm 7.0 which might have some performance improvements:
|
||||
This is the command to install the nightly with ROCm 6.4 which might have some performance improvements:
|
||||
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.1```
|
||||
|
||||
|
||||
### AMD GPUs (Experimental: Windows and Linux), RDNA 3, 3.5 and 4 only.
|
||||
|
||||
These have less hardware support than the builds above but they work on windows. You also need to install the pytorch version specific to your hardware.
|
||||
|
||||
RDNA 3 (RX 7000 series):
|
||||
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx110X-dgpu/```
|
||||
|
||||
RDNA 3.5 (Strix halo/Ryzen AI Max+ 365):
|
||||
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx1151/```
|
||||
|
||||
RDNA 4 (RX 9000 series):
|
||||
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx120X-all/```
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.4```
|
||||
|
||||
### Intel GPUs (Windows and Linux)
|
||||
|
||||
Intel Arc GPU users can install native PyTorch with torch.xpu support using pip. More information can be found [here](https://pytorch.org/docs/main/notes/get_start_xpu.html)
|
||||
(Option 1) Intel Arc GPU users can install native PyTorch with torch.xpu support using pip. More information can be found [here](https://pytorch.org/docs/main/notes/get_start_xpu.html)
|
||||
|
||||
1. To install PyTorch xpu, use the following command:
|
||||
|
||||
@@ -262,15 +221,19 @@ This is the command to install the Pytorch xpu nightly which might have some per
|
||||
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/xpu```
|
||||
|
||||
(Option 2) Alternatively, Intel GPUs supported by Intel Extension for PyTorch (IPEX) can leverage IPEX for improved performance.
|
||||
|
||||
1. visit [Installation](https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=gpu) for more information.
|
||||
|
||||
### NVIDIA
|
||||
|
||||
Nvidia users should install stable pytorch using this command:
|
||||
|
||||
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu130```
|
||||
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu129```
|
||||
|
||||
This is the command to install pytorch nightly instead which might have performance improvements.
|
||||
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu130```
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu129```
|
||||
|
||||
#### Troubleshooting
|
||||
|
||||
@@ -301,6 +264,12 @@ You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS ve
|
||||
|
||||
> **Note**: Remember to add your models, VAE, LoRAs etc. to the corresponding Comfy folders, as discussed in [ComfyUI manual installation](#manual-install-windows-linux).
|
||||
|
||||
#### DirectML (AMD Cards on Windows)
|
||||
|
||||
This is very badly supported and is not recommended. There are some unofficial builds of pytorch ROCm on windows that exist that will give you a much better experience than this. This readme will be updated once official pytorch ROCm builds for windows come out.
|
||||
|
||||
```pip install torch-directml``` Then you can launch ComfyUI with: ```python main.py --directml```
|
||||
|
||||
#### Ascend NPUs
|
||||
|
||||
For models compatible with Ascend Extension for PyTorch (torch_npu). To get started, ensure your environment meets the prerequisites outlined on the [installation](https://ascend.github.io/docs/sources/ascend/quick_install.html) page. Here's a step-by-step guide tailored to your platform and installation method:
|
||||
@@ -325,32 +294,6 @@ For models compatible with Iluvatar Extension for PyTorch. Here's a step-by-step
|
||||
1. Install the Iluvatar Corex Toolkit by adhering to the platform-specific instructions on the [Installation](https://support.iluvatar.com/#/DocumentCentre?id=1&nameCenter=2&productId=520117912052801536)
|
||||
2. Launch ComfyUI by running `python main.py`
|
||||
|
||||
|
||||
## [ComfyUI-Manager](https://github.com/Comfy-Org/ComfyUI-Manager/tree/manager-v4)
|
||||
|
||||
**ComfyUI-Manager** is an extension that allows you to easily install, update, and manage custom nodes for ComfyUI.
|
||||
|
||||
### Setup
|
||||
|
||||
1. Install the manager dependencies:
|
||||
```bash
|
||||
pip install -r manager_requirements.txt
|
||||
```
|
||||
|
||||
2. Enable the manager with the `--enable-manager` flag when running ComfyUI:
|
||||
```bash
|
||||
python main.py --enable-manager
|
||||
```
|
||||
|
||||
### Command Line Options
|
||||
|
||||
| Flag | Description |
|
||||
|------|-------------|
|
||||
| `--enable-manager` | Enable ComfyUI-Manager |
|
||||
| `--enable-manager-legacy-ui` | Use the legacy manager UI instead of the new UI (requires `--enable-manager`) |
|
||||
| `--disable-manager-ui` | Disable the manager UI and endpoints while keeping background features like security checks and scheduled installation completion (requires `--enable-manager`) |
|
||||
|
||||
|
||||
# Running
|
||||
|
||||
```python main.py```
|
||||
|
||||
@@ -58,13 +58,8 @@ class InternalRoutes:
|
||||
return web.json_response({"error": "Invalid directory type"}, status=400)
|
||||
|
||||
directory = get_directory_by_type(directory_type)
|
||||
|
||||
def is_visible_file(entry: os.DirEntry) -> bool:
|
||||
"""Filter out hidden files (e.g., .DS_Store on macOS)."""
|
||||
return entry.is_file() and not entry.name.startswith('.')
|
||||
|
||||
sorted_files = sorted(
|
||||
(entry for entry in os.scandir(directory) if is_visible_file(entry)),
|
||||
(entry for entry in os.scandir(directory) if entry.is_file()),
|
||||
key=lambda entry: -entry.stat().st_mtime
|
||||
)
|
||||
return web.json_response([entry.name for entry in sorted_files], status=200)
|
||||
|
||||
@@ -10,8 +10,7 @@ import importlib
|
||||
from dataclasses import dataclass
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from typing import Dict, TypedDict, Optional
|
||||
from aiohttp import web
|
||||
from typing import TypedDict, Optional
|
||||
from importlib.metadata import version
|
||||
|
||||
import requests
|
||||
@@ -43,7 +42,6 @@ def get_installed_frontend_version():
|
||||
frontend_version_str = version("comfyui-frontend-package")
|
||||
return frontend_version_str
|
||||
|
||||
|
||||
def get_required_frontend_version():
|
||||
"""Get the required frontend version from requirements.txt."""
|
||||
try:
|
||||
@@ -65,7 +63,6 @@ def get_required_frontend_version():
|
||||
logging.error(f"Error reading requirements.txt: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def check_frontend_version():
|
||||
"""Check if the frontend version is up to date."""
|
||||
|
||||
@@ -206,37 +203,6 @@ class FrontendManager:
|
||||
"""Get the required frontend package version."""
|
||||
return get_required_frontend_version()
|
||||
|
||||
@classmethod
|
||||
def get_installed_templates_version(cls) -> str:
|
||||
"""Get the currently installed workflow templates package version."""
|
||||
try:
|
||||
templates_version_str = version("comfyui-workflow-templates")
|
||||
return templates_version_str
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_required_templates_version(cls) -> str:
|
||||
"""Get the required workflow templates version from requirements.txt."""
|
||||
try:
|
||||
with open(requirements_path, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line.startswith("comfyui-workflow-templates=="):
|
||||
version_str = line.split("==")[-1]
|
||||
if not is_valid_version(version_str):
|
||||
logging.error(f"Invalid templates version format in requirements.txt: {version_str}")
|
||||
return None
|
||||
return version_str
|
||||
logging.error("comfyui-workflow-templates not found in requirements.txt")
|
||||
return None
|
||||
except FileNotFoundError:
|
||||
logging.error("requirements.txt not found. Cannot determine required templates version.")
|
||||
return None
|
||||
except Exception as e:
|
||||
logging.error(f"Error reading requirements.txt: {e}")
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def default_frontend_path(cls) -> str:
|
||||
try:
|
||||
@@ -258,54 +224,7 @@ comfyui-frontend-package is not installed.
|
||||
sys.exit(-1)
|
||||
|
||||
@classmethod
|
||||
def template_asset_map(cls) -> Optional[Dict[str, str]]:
|
||||
"""Return a mapping of template asset names to their absolute paths."""
|
||||
try:
|
||||
from comfyui_workflow_templates import (
|
||||
get_asset_path,
|
||||
iter_templates,
|
||||
)
|
||||
except ImportError:
|
||||
logging.error(
|
||||
f"""
|
||||
********** ERROR ***********
|
||||
|
||||
comfyui-workflow-templates is not installed.
|
||||
|
||||
{frontend_install_warning_message()}
|
||||
|
||||
********** ERROR ***********
|
||||
""".strip()
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
template_entries = list(iter_templates())
|
||||
except Exception as exc:
|
||||
logging.error(f"Failed to enumerate workflow templates: {exc}")
|
||||
return None
|
||||
|
||||
asset_map: Dict[str, str] = {}
|
||||
try:
|
||||
for entry in template_entries:
|
||||
for asset in entry.assets:
|
||||
asset_map[asset.filename] = get_asset_path(
|
||||
entry.template_id, asset.filename
|
||||
)
|
||||
except Exception as exc:
|
||||
logging.error(f"Failed to resolve template asset paths: {exc}")
|
||||
return None
|
||||
|
||||
if not asset_map:
|
||||
logging.error("No workflow template assets found. Did the packages install correctly?")
|
||||
return None
|
||||
|
||||
return asset_map
|
||||
|
||||
|
||||
@classmethod
|
||||
def legacy_templates_path(cls) -> Optional[str]:
|
||||
"""Return the legacy templates directory shipped inside the meta package."""
|
||||
def templates_path(cls) -> str:
|
||||
try:
|
||||
import comfyui_workflow_templates
|
||||
|
||||
@@ -324,7 +243,6 @@ comfyui-workflow-templates is not installed.
|
||||
********** ERROR ***********
|
||||
""".strip()
|
||||
)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def embedded_docs_path(cls) -> str:
|
||||
@@ -441,17 +359,3 @@ comfyui-workflow-templates is not installed.
|
||||
logging.info("Falling back to the default frontend.")
|
||||
check_frontend_version()
|
||||
return cls.default_frontend_path()
|
||||
@classmethod
|
||||
def template_asset_handler(cls):
|
||||
assets = cls.template_asset_map()
|
||||
if not assets:
|
||||
return None
|
||||
|
||||
async def serve_template(request: web.Request) -> web.StreamResponse:
|
||||
rel_path = request.match_info.get("path", "")
|
||||
target = assets.get(rel_path)
|
||||
if target is None:
|
||||
raise web.HTTPNotFound()
|
||||
return web.FileResponse(target)
|
||||
|
||||
return serve_template
|
||||
|
||||
@@ -44,7 +44,7 @@ class ModelFileManager:
|
||||
@routes.get("/experiment/models/{folder}")
|
||||
async def get_all_models(request):
|
||||
folder = request.match_info.get("folder", None)
|
||||
if folder not in folder_paths.folder_names_and_paths:
|
||||
if not folder in folder_paths.folder_names_and_paths:
|
||||
return web.Response(status=404)
|
||||
files = self.get_model_file_list(folder)
|
||||
return web.json_response(files)
|
||||
@@ -55,7 +55,7 @@ class ModelFileManager:
|
||||
path_index = int(request.match_info.get("path_index", None))
|
||||
filename = request.match_info.get("filename", None)
|
||||
|
||||
if folder_name not in folder_paths.folder_names_and_paths:
|
||||
if not folder_name in folder_paths.folder_names_and_paths:
|
||||
return web.Response(status=404)
|
||||
|
||||
folders = folder_paths.folder_names_and_paths[folder_name]
|
||||
|
||||
@@ -1,112 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TypedDict
|
||||
import os
|
||||
import folder_paths
|
||||
import glob
|
||||
from aiohttp import web
|
||||
import hashlib
|
||||
|
||||
|
||||
class Source:
|
||||
custom_node = "custom_node"
|
||||
|
||||
class SubgraphEntry(TypedDict):
|
||||
source: str
|
||||
"""
|
||||
Source of subgraph - custom_nodes vs templates.
|
||||
"""
|
||||
path: str
|
||||
"""
|
||||
Relative path of the subgraph file.
|
||||
For custom nodes, will be the relative directory like <custom_node_dir>/subgraphs/<name>.json
|
||||
"""
|
||||
name: str
|
||||
"""
|
||||
Name of subgraph file.
|
||||
"""
|
||||
info: CustomNodeSubgraphEntryInfo
|
||||
"""
|
||||
Additional info about subgraph; in the case of custom_nodes, will contain nodepack name
|
||||
"""
|
||||
data: str
|
||||
|
||||
class CustomNodeSubgraphEntryInfo(TypedDict):
|
||||
node_pack: str
|
||||
"""Node pack name."""
|
||||
|
||||
class SubgraphManager:
|
||||
def __init__(self):
|
||||
self.cached_custom_node_subgraphs: dict[SubgraphEntry] | None = None
|
||||
|
||||
async def load_entry_data(self, entry: SubgraphEntry):
|
||||
with open(entry['path'], 'r') as f:
|
||||
entry['data'] = f.read()
|
||||
return entry
|
||||
|
||||
async def sanitize_entry(self, entry: SubgraphEntry | None, remove_data=False) -> SubgraphEntry | None:
|
||||
if entry is None:
|
||||
return None
|
||||
entry = entry.copy()
|
||||
entry.pop('path', None)
|
||||
if remove_data:
|
||||
entry.pop('data', None)
|
||||
return entry
|
||||
|
||||
async def sanitize_entries(self, entries: dict[str, SubgraphEntry], remove_data=False) -> dict[str, SubgraphEntry]:
|
||||
entries = entries.copy()
|
||||
for key in list(entries.keys()):
|
||||
entries[key] = await self.sanitize_entry(entries[key], remove_data)
|
||||
return entries
|
||||
|
||||
async def get_custom_node_subgraphs(self, loadedModules, force_reload=False):
|
||||
# if not forced to reload and cached, return cache
|
||||
if not force_reload and self.cached_custom_node_subgraphs is not None:
|
||||
return self.cached_custom_node_subgraphs
|
||||
# Load subgraphs from custom nodes
|
||||
subfolder = "subgraphs"
|
||||
subgraphs_dict: dict[SubgraphEntry] = {}
|
||||
|
||||
for folder in folder_paths.get_folder_paths("custom_nodes"):
|
||||
pattern = os.path.join(folder, f"*/{subfolder}/*.json")
|
||||
matched_files = glob.glob(pattern)
|
||||
for file in matched_files:
|
||||
# replace backslashes with forward slashes
|
||||
file = file.replace('\\', '/')
|
||||
info: CustomNodeSubgraphEntryInfo = {
|
||||
"node_pack": "custom_nodes." + file.split('/')[-3]
|
||||
}
|
||||
source = Source.custom_node
|
||||
# hash source + path to make sure id will be as unique as possible, but
|
||||
# reproducible across backend reloads
|
||||
id = hashlib.sha256(f"{source}{file}".encode()).hexdigest()
|
||||
entry: SubgraphEntry = {
|
||||
"source": Source.custom_node,
|
||||
"name": os.path.splitext(os.path.basename(file))[0],
|
||||
"path": file,
|
||||
"info": info,
|
||||
}
|
||||
subgraphs_dict[id] = entry
|
||||
self.cached_custom_node_subgraphs = subgraphs_dict
|
||||
return subgraphs_dict
|
||||
|
||||
async def get_custom_node_subgraph(self, id: str, loadedModules):
|
||||
subgraphs = await self.get_custom_node_subgraphs(loadedModules)
|
||||
entry: SubgraphEntry = subgraphs.get(id, None)
|
||||
if entry is not None and entry.get('data', None) is None:
|
||||
await self.load_entry_data(entry)
|
||||
return entry
|
||||
|
||||
def add_routes(self, routes, loadedModules):
|
||||
@routes.get("/global_subgraphs")
|
||||
async def get_global_subgraphs(request):
|
||||
subgraphs_dict = await self.get_custom_node_subgraphs(loadedModules)
|
||||
# NOTE: we may want to include other sources of global subgraphs such as templates in the future;
|
||||
# that's the reasoning for the current implementation
|
||||
return web.json_response(await self.sanitize_entries(subgraphs_dict, remove_data=True))
|
||||
|
||||
@routes.get("/global_subgraphs/{id}")
|
||||
async def get_global_subgraph(request):
|
||||
id = request.match_info.get("id", None)
|
||||
subgraph = await self.get_custom_node_subgraph(id, loadedModules)
|
||||
return web.json_response(await self.sanitize_entry(subgraph))
|
||||
@@ -59,9 +59,6 @@ class UserManager():
|
||||
user = "default"
|
||||
if args.multi_user and "comfy-user" in request.headers:
|
||||
user = request.headers["comfy-user"]
|
||||
# Block System Users (use same error message to prevent probing)
|
||||
if user.startswith(folder_paths.SYSTEM_USER_PREFIX):
|
||||
raise KeyError("Unknown user: " + user)
|
||||
|
||||
if user not in self.users:
|
||||
raise KeyError("Unknown user: " + user)
|
||||
@@ -69,16 +66,15 @@ class UserManager():
|
||||
return user
|
||||
|
||||
def get_request_user_filepath(self, request, file, type="userdata", create_dir=True):
|
||||
user_directory = folder_paths.get_user_directory()
|
||||
|
||||
if type == "userdata":
|
||||
root_dir = folder_paths.get_user_directory()
|
||||
root_dir = user_directory
|
||||
else:
|
||||
raise KeyError("Unknown filepath type:" + type)
|
||||
|
||||
user = self.get_request_user_id(request)
|
||||
user_root = folder_paths.get_public_user_directory(user)
|
||||
if user_root is None:
|
||||
return None
|
||||
path = user_root
|
||||
path = user_root = os.path.abspath(os.path.join(root_dir, user))
|
||||
|
||||
# prevent leaving /{type}
|
||||
if os.path.commonpath((root_dir, user_root)) != root_dir:
|
||||
@@ -105,11 +101,7 @@ class UserManager():
|
||||
name = name.strip()
|
||||
if not name:
|
||||
raise ValueError("username not provided")
|
||||
if name.startswith(folder_paths.SYSTEM_USER_PREFIX):
|
||||
raise ValueError("System User prefix not allowed")
|
||||
user_id = re.sub("[^a-zA-Z0-9-_]+", '-', name)
|
||||
if user_id.startswith(folder_paths.SYSTEM_USER_PREFIX):
|
||||
raise ValueError("System User prefix not allowed")
|
||||
user_id = user_id + "_" + str(uuid.uuid4())
|
||||
|
||||
self.users[user_id] = name
|
||||
@@ -140,10 +132,7 @@ class UserManager():
|
||||
if username in self.users.values():
|
||||
return web.json_response({"error": "Duplicate username."}, status=400)
|
||||
|
||||
try:
|
||||
user_id = self.add_user(username)
|
||||
except ValueError as e:
|
||||
return web.json_response({"error": str(e)}, status=400)
|
||||
user_id = self.add_user(username)
|
||||
return web.json_response(user_id)
|
||||
|
||||
@routes.get("/userdata")
|
||||
@@ -435,7 +424,7 @@ class UserManager():
|
||||
return source
|
||||
|
||||
dest = get_user_data_path(request, check_exists=False, param="dest")
|
||||
if not isinstance(dest, str):
|
||||
if not isinstance(source, str):
|
||||
return dest
|
||||
|
||||
overwrite = request.query.get("overwrite", 'true') != "false"
|
||||
|
||||
@@ -413,8 +413,7 @@ class ControlNet(nn.Module):
|
||||
out_middle = []
|
||||
|
||||
if self.num_classes is not None:
|
||||
if y is None:
|
||||
raise ValueError("y is None, did you try using a controlnet for SDXL on SD1?")
|
||||
assert y.shape[0] == x.shape[0]
|
||||
emb = emb + self.label_emb(y)
|
||||
|
||||
h = x
|
||||
|
||||
@@ -97,13 +97,6 @@ class LatentPreviewMethod(enum.Enum):
|
||||
Latent2RGB = "latent2rgb"
|
||||
TAESD = "taesd"
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, value: str):
|
||||
for member in cls:
|
||||
if member.value == value:
|
||||
return member
|
||||
return None
|
||||
|
||||
parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
|
||||
|
||||
parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")
|
||||
@@ -112,7 +105,6 @@ cache_group = parser.add_mutually_exclusive_group()
|
||||
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
|
||||
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
|
||||
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
|
||||
cache_group.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threhold the cache remove large items to free RAM. Default 4GB")
|
||||
|
||||
attn_group = parser.add_mutually_exclusive_group()
|
||||
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
|
||||
@@ -128,12 +120,6 @@ upcast.add_argument("--force-upcast-attention", action="store_true", help="Force
|
||||
upcast.add_argument("--dont-upcast-attention", action="store_true", help="Disable all upcasting of attention. Should be unnecessary except for debugging.")
|
||||
|
||||
|
||||
parser.add_argument("--enable-manager", action="store_true", help="Enable the ComfyUI-Manager feature.")
|
||||
manager_group = parser.add_mutually_exclusive_group()
|
||||
manager_group.add_argument("--disable-manager-ui", action="store_true", help="Disables only the ComfyUI-Manager UI and endpoints. Scheduled installations and similar background tasks will still operate.")
|
||||
manager_group.add_argument("--enable-manager-legacy-ui", action="store_true", help="Enables the legacy UI of ComfyUI-Manager")
|
||||
|
||||
|
||||
vram_group = parser.add_mutually_exclusive_group()
|
||||
vram_group.add_argument("--gpu-only", action="store_true", help="Store and run everything (text encoders/CLIP models, etc... on the GPU).")
|
||||
vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.")
|
||||
@@ -144,8 +130,7 @@ vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for e
|
||||
|
||||
parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS.")
|
||||
|
||||
parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=None, metavar="NUM_STREAMS", help="Use async weight offloading. An optional argument controls the amount of offload streams. Default is 2. Enabled by default on Nvidia.")
|
||||
parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.")
|
||||
parser.add_argument("--async-offload", action="store_true", help="Use async weight offloading.")
|
||||
|
||||
parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")
|
||||
|
||||
@@ -160,9 +145,7 @@ class PerformanceFeature(enum.Enum):
|
||||
CublasOps = "cublas_ops"
|
||||
AutoTune = "autotune"
|
||||
|
||||
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
|
||||
|
||||
parser.add_argument("--disable-pinned-memory", action="store_true", help="Disable pinned memory use.")
|
||||
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
|
||||
|
||||
parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.")
|
||||
parser.add_argument("--disable-mmap", action="store_true", help="Don't use mmap when loading safetensors.")
|
||||
@@ -174,14 +157,13 @@ parser.add_argument("--windows-standalone-build", action="store_true", help="Win
|
||||
parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.")
|
||||
parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Disable loading all custom nodes.")
|
||||
parser.add_argument("--whitelist-custom-nodes", type=str, nargs='+', default=[], help="Specify custom node folders to load even when --disable-all-custom-nodes is enabled.")
|
||||
parser.add_argument("--disable-api-nodes", action="store_true", help="Disable loading all api nodes. Also prevents the frontend from communicating with the internet.")
|
||||
parser.add_argument("--disable-api-nodes", action="store_true", help="Disable loading all api nodes.")
|
||||
|
||||
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
|
||||
|
||||
parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level')
|
||||
parser.add_argument("--log-stdout", action="store_true", help="Send normal process output to stdout instead of stderr (default).")
|
||||
|
||||
|
||||
# The default built-in provider hosted under web/
|
||||
DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
|
||||
|
||||
|
||||
@@ -2,25 +2,6 @@ import torch
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
import comfy.ops
|
||||
|
||||
def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
|
||||
image = image[:, :, :, :3] if image.shape[3] > 3 else image
|
||||
mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
|
||||
std = torch.tensor(std, device=image.device, dtype=image.dtype)
|
||||
image = image.movedim(-1, 1)
|
||||
if not (image.shape[2] == size and image.shape[3] == size):
|
||||
if crop:
|
||||
scale = (size / min(image.shape[2], image.shape[3]))
|
||||
scale_size = (round(scale * image.shape[2]), round(scale * image.shape[3]))
|
||||
else:
|
||||
scale_size = (size, size)
|
||||
|
||||
image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
|
||||
h = (image.shape[2] - size)//2
|
||||
w = (image.shape[3] - size)//2
|
||||
image = image[:,:,h:h+size,w:w+size]
|
||||
image = torch.clip((255. * image), 0, 255).round() / 255.0
|
||||
return (image - mean.view([3,1,1])) / std.view([3,1,1])
|
||||
|
||||
class CLIPAttention(torch.nn.Module):
|
||||
def __init__(self, embed_dim, heads, dtype, device, operations):
|
||||
super().__init__()
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from .utils import load_torch_file, transformers_convert, state_dict_prefix_replace
|
||||
import os
|
||||
import torch
|
||||
import json
|
||||
import logging
|
||||
|
||||
@@ -16,7 +17,24 @@ class Output:
|
||||
def __setitem__(self, key, item):
|
||||
setattr(self, key, item)
|
||||
|
||||
clip_preprocess = comfy.clip_model.clip_preprocess # Prevent some stuff from breaking, TODO: remove eventually
|
||||
def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
|
||||
image = image[:, :, :, :3] if image.shape[3] > 3 else image
|
||||
mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
|
||||
std = torch.tensor(std, device=image.device, dtype=image.dtype)
|
||||
image = image.movedim(-1, 1)
|
||||
if not (image.shape[2] == size and image.shape[3] == size):
|
||||
if crop:
|
||||
scale = (size / min(image.shape[2], image.shape[3]))
|
||||
scale_size = (round(scale * image.shape[2]), round(scale * image.shape[3]))
|
||||
else:
|
||||
scale_size = (size, size)
|
||||
|
||||
image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
|
||||
h = (image.shape[2] - size)//2
|
||||
w = (image.shape[3] - size)//2
|
||||
image = image[:,:,h:h+size,w:w+size]
|
||||
image = torch.clip((255. * image), 0, 255).round() / 255.0
|
||||
return (image - mean.view([3,1,1])) / std.view([3,1,1])
|
||||
|
||||
IMAGE_ENCODERS = {
|
||||
"clip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
|
||||
@@ -55,7 +73,7 @@ class ClipVisionModel():
|
||||
|
||||
def encode_image(self, image, crop=True):
|
||||
comfy.model_management.load_model_gpu(self.patcher)
|
||||
pixel_values = comfy.clip_model.clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
|
||||
pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
|
||||
out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2)
|
||||
|
||||
outputs = Output()
|
||||
|
||||
@@ -51,43 +51,32 @@ class ContextHandlerABC(ABC):
|
||||
|
||||
|
||||
class IndexListContextWindow(ContextWindowABC):
|
||||
def __init__(self, index_list: list[int], dim: int=0, total_frames: int=0):
|
||||
def __init__(self, index_list: list[int], dim: int=0):
|
||||
self.index_list = index_list
|
||||
self.context_length = len(index_list)
|
||||
self.dim = dim
|
||||
self.total_frames = total_frames
|
||||
self.center_ratio = (min(index_list) + max(index_list)) / (2 * total_frames)
|
||||
|
||||
def get_tensor(self, full: torch.Tensor, device=None, dim=None, retain_index_list=[]) -> torch.Tensor:
|
||||
def get_tensor(self, full: torch.Tensor, device=None, dim=None) -> torch.Tensor:
|
||||
if dim is None:
|
||||
dim = self.dim
|
||||
if dim == 0 and full.shape[dim] == 1:
|
||||
return full
|
||||
idx = tuple([slice(None)] * dim + [self.index_list])
|
||||
window = full[idx]
|
||||
if retain_index_list:
|
||||
idx = tuple([slice(None)] * dim + [retain_index_list])
|
||||
window[idx] = full[idx]
|
||||
return window.to(device)
|
||||
idx = [slice(None)] * dim + [self.index_list]
|
||||
return full[idx].to(device)
|
||||
|
||||
def add_window(self, full: torch.Tensor, to_add: torch.Tensor, dim=None) -> torch.Tensor:
|
||||
if dim is None:
|
||||
dim = self.dim
|
||||
idx = tuple([slice(None)] * dim + [self.index_list])
|
||||
idx = [slice(None)] * dim + [self.index_list]
|
||||
full[idx] += to_add
|
||||
return full
|
||||
|
||||
def get_region_index(self, num_regions: int) -> int:
|
||||
region_idx = int(self.center_ratio * num_regions)
|
||||
return min(max(region_idx, 0), num_regions - 1)
|
||||
|
||||
|
||||
class IndexListCallbacks:
|
||||
EVALUATE_CONTEXT_WINDOWS = "evaluate_context_windows"
|
||||
COMBINE_CONTEXT_WINDOW_RESULTS = "combine_context_window_results"
|
||||
EXECUTE_START = "execute_start"
|
||||
EXECUTE_CLEANUP = "execute_cleanup"
|
||||
RESIZE_COND_ITEM = "resize_cond_item"
|
||||
|
||||
def init_callbacks(self):
|
||||
return {}
|
||||
@@ -105,8 +94,7 @@ class ContextFuseMethod:
|
||||
|
||||
ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_conds_out', 'sub_conds', 'window'])
|
||||
class IndexListContextHandler(ContextHandlerABC):
|
||||
def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1,
|
||||
closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False):
|
||||
def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1, closed_loop=False, dim=0):
|
||||
self.context_schedule = context_schedule
|
||||
self.fuse_method = fuse_method
|
||||
self.context_length = context_length
|
||||
@@ -115,18 +103,13 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
self.closed_loop = closed_loop
|
||||
self.dim = dim
|
||||
self._step = 0
|
||||
self.freenoise = freenoise
|
||||
self.cond_retain_index_list = [int(x.strip()) for x in cond_retain_index_list.split(",")] if cond_retain_index_list else []
|
||||
self.split_conds_to_windows = split_conds_to_windows
|
||||
|
||||
self.callbacks = {}
|
||||
|
||||
def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
|
||||
# for now, assume first dim is batch - should have stored on BaseModel in actual implementation
|
||||
if x_in.size(self.dim) > self.context_length:
|
||||
logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {x_in.size(self.dim)} frames.")
|
||||
if self.cond_retain_index_list:
|
||||
logging.info(f"Retaining original cond for indexes: {self.cond_retain_index_list}")
|
||||
logging.info(f"Using context windows {self.context_length} for {x_in.size(self.dim)} frames.")
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -140,11 +123,6 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
return None
|
||||
# reuse or resize cond items to match context requirements
|
||||
resized_cond = []
|
||||
# if multiple conds, split based on primary region
|
||||
if self.split_conds_to_windows and len(cond_in) > 1:
|
||||
region = window.get_region_index(len(cond_in))
|
||||
logging.info(f"Splitting conds to windows; using region {region} for window {window.index_list[0]}-{window.index_list[-1]} with center ratio {window.center_ratio:.3f}")
|
||||
cond_in = [cond_in[region]]
|
||||
# cond object is a list containing a dict - outer list is irrelevant, so just loop through it
|
||||
for actual_cond in cond_in:
|
||||
resized_actual_cond = actual_cond.copy()
|
||||
@@ -167,38 +145,13 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
new_cond_item = cond_item.copy()
|
||||
# when in dictionary, look for tensors and CONDCrossAttn [comfy/conds.py] (has cond attr that is a tensor)
|
||||
for cond_key, cond_value in new_cond_item.items():
|
||||
# Allow callbacks to handle custom conditioning items
|
||||
handled = False
|
||||
for callback in comfy.patcher_extension.get_all_callbacks(
|
||||
IndexListCallbacks.RESIZE_COND_ITEM, self.callbacks
|
||||
):
|
||||
result = callback(cond_key, cond_value, window, x_in, device, new_cond_item)
|
||||
if result is not None:
|
||||
new_cond_item[cond_key] = result
|
||||
handled = True
|
||||
break
|
||||
if handled:
|
||||
continue
|
||||
if isinstance(cond_value, torch.Tensor):
|
||||
if (self.dim < cond_value.ndim and cond_value(self.dim) == x_in.size(self.dim)) or \
|
||||
(cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim)):
|
||||
if cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim):
|
||||
new_cond_item[cond_key] = window.get_tensor(cond_value, device)
|
||||
# Handle audio_embed (temporal dim is 1)
|
||||
elif cond_key == "audio_embed" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
|
||||
audio_cond = cond_value.cond
|
||||
if audio_cond.ndim > 1 and audio_cond.size(1) == x_in.size(self.dim):
|
||||
new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(audio_cond, device, dim=1))
|
||||
# Handle vace_context (temporal dim is 3)
|
||||
elif cond_key == "vace_context" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
|
||||
vace_cond = cond_value.cond
|
||||
if vace_cond.ndim >= 4 and vace_cond.size(3) == x_in.size(self.dim):
|
||||
sliced_vace = window.get_tensor(vace_cond, device, dim=3, retain_index_list=self.cond_retain_index_list)
|
||||
new_cond_item[cond_key] = cond_value._copy_with(sliced_vace)
|
||||
# if has cond that is a Tensor, check if needs to be subset
|
||||
elif hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
|
||||
if (self.dim < cond_value.cond.ndim and cond_value.cond.size(self.dim) == x_in.size(self.dim)) or \
|
||||
(cond_value.cond.ndim < self.dim and cond_value.cond.size(0) == x_in.size(self.dim)):
|
||||
new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(cond_value.cond, device, retain_index_list=self.cond_retain_index_list))
|
||||
if cond_value.cond.ndim < self.dim and cond_value.cond.size(0) == x_in.size(self.dim):
|
||||
new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(cond_value.cond, device))
|
||||
elif cond_key == "num_video_frames": # for SVD
|
||||
new_cond_item[cond_key] = cond_value._copy_with(cond_value.cond)
|
||||
new_cond_item[cond_key].cond = window.context_length
|
||||
@@ -211,7 +164,7 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
return resized_cond
|
||||
|
||||
def set_step(self, timestep: torch.Tensor, model_options: dict[str]):
|
||||
mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep[0], rtol=0.0001)
|
||||
mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep, rtol=0.0001)
|
||||
matches = torch.nonzero(mask)
|
||||
if torch.numel(matches) == 0:
|
||||
raise Exception("No sample_sigmas matched current timestep; something went wrong.")
|
||||
@@ -220,7 +173,7 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:
|
||||
full_length = x_in.size(self.dim) # TODO: choose dim based on model
|
||||
context_windows = self.context_schedule.func(full_length, self, model_options)
|
||||
context_windows = [IndexListContextWindow(window, dim=self.dim, total_frames=full_length) for window in context_windows]
|
||||
context_windows = [IndexListContextWindow(window, dim=self.dim) for window in context_windows]
|
||||
return context_windows
|
||||
|
||||
def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
|
||||
@@ -297,8 +250,8 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
prev_weight = (bias_total / (bias_total + bias))
|
||||
new_weight = (bias / (bias_total + bias))
|
||||
# account for dims of tensors
|
||||
idx_window = tuple([slice(None)] * self.dim + [idx])
|
||||
pos_window = tuple([slice(None)] * self.dim + [pos])
|
||||
idx_window = [slice(None)] * self.dim + [idx]
|
||||
pos_window = [slice(None)] * self.dim + [pos]
|
||||
# apply new values
|
||||
conds_final[i][idx_window] = conds_final[i][idx_window] * prev_weight + sub_conds_out[i][pos_window] * new_weight
|
||||
biases_final[i][idx] = bias_total + bias
|
||||
@@ -334,28 +287,6 @@ def create_prepare_sampling_wrapper(model: ModelPatcher):
|
||||
)
|
||||
|
||||
|
||||
def _sampler_sample_wrapper(executor, guider, sigmas, extra_args, callback, noise, *args, **kwargs):
|
||||
model_options = extra_args.get("model_options", None)
|
||||
if model_options is None:
|
||||
raise Exception("model_options not found in sampler_sample_wrapper; this should never happen, something went wrong.")
|
||||
handler: IndexListContextHandler = model_options.get("context_handler", None)
|
||||
if handler is None:
|
||||
raise Exception("context_handler not found in sampler_sample_wrapper; this should never happen, something went wrong.")
|
||||
if not handler.freenoise:
|
||||
return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
|
||||
noise = apply_freenoise(noise, handler.dim, handler.context_length, handler.context_overlap, extra_args["seed"])
|
||||
|
||||
return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
|
||||
|
||||
|
||||
def create_sampler_sample_wrapper(model: ModelPatcher):
|
||||
model.add_wrapper_with_key(
|
||||
comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE,
|
||||
"ContextWindows_sampler_sample",
|
||||
_sampler_sample_wrapper
|
||||
)
|
||||
|
||||
|
||||
def match_weights_to_dim(weights: list[float], x_in: torch.Tensor, dim: int, device=None) -> torch.Tensor:
|
||||
total_dims = len(x_in.shape)
|
||||
weights_tensor = torch.Tensor(weights).to(device=device)
|
||||
@@ -607,29 +538,3 @@ def shift_window_to_end(window: list[int], num_frames: int):
|
||||
for i in range(len(window)):
|
||||
# 2) add end_delta to each val to slide windows to end
|
||||
window[i] = window[i] + end_delta
|
||||
|
||||
|
||||
# https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved/blob/90fb1331201a4b29488089e4fbffc0d82cc6d0a9/animatediff/sample_settings.py#L465
|
||||
def apply_freenoise(noise: torch.Tensor, dim: int, context_length: int, context_overlap: int, seed: int):
|
||||
logging.info("Context windows: Applying FreeNoise")
|
||||
generator = torch.Generator(device='cpu').manual_seed(seed)
|
||||
latent_video_length = noise.shape[dim]
|
||||
delta = context_length - context_overlap
|
||||
|
||||
for start_idx in range(0, latent_video_length - context_length, delta):
|
||||
place_idx = start_idx + context_length
|
||||
|
||||
actual_delta = min(delta, latent_video_length - place_idx)
|
||||
if actual_delta <= 0:
|
||||
break
|
||||
|
||||
list_idx = torch.randperm(actual_delta, generator=generator, device='cpu') + start_idx
|
||||
|
||||
source_slice = [slice(None)] * noise.ndim
|
||||
source_slice[dim] = list_idx
|
||||
target_slice = [slice(None)] * noise.ndim
|
||||
target_slice[dim] = slice(place_idx, place_idx + actual_delta)
|
||||
|
||||
noise[tuple(target_slice)] = noise[tuple(source_slice)]
|
||||
|
||||
return noise
|
||||
|
||||
@@ -310,13 +310,11 @@ class ControlLoraOps:
|
||||
self.bias = None
|
||||
|
||||
def forward(self, input):
|
||||
weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True)
|
||||
weight, bias = comfy.ops.cast_bias_weight(self, input)
|
||||
if self.up is not None:
|
||||
x = torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
|
||||
return torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
|
||||
else:
|
||||
x = torch.nn.functional.linear(input, weight, bias)
|
||||
comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
|
||||
class Conv2d(torch.nn.Module, comfy.ops.CastWeightBiasOp):
|
||||
def __init__(
|
||||
@@ -352,13 +350,12 @@ class ControlLoraOps:
|
||||
|
||||
|
||||
def forward(self, input):
|
||||
weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True)
|
||||
weight, bias = comfy.ops.cast_bias_weight(self, input)
|
||||
if self.up is not None:
|
||||
x = torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
return torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
else:
|
||||
x = torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
return torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
|
||||
|
||||
class ControlLora(ControlNet):
|
||||
def __init__(self, control_weights, global_average_pooling=False, model_options={}): #TODO? model_options
|
||||
|
||||
@@ -527,8 +527,7 @@ class HookKeyframeGroup:
|
||||
if self._current_keyframe.get_effective_guarantee_steps(max_sigma) > 0:
|
||||
break
|
||||
# if eval_c is outside the percent range, stop looking further
|
||||
else:
|
||||
break
|
||||
else: break
|
||||
# update steps current context is used
|
||||
self._current_used_steps += 1
|
||||
# update current timestep this was performed on
|
||||
|
||||
@@ -74,9 +74,6 @@ def get_ancestral_step(sigma_from, sigma_to, eta=1.):
|
||||
|
||||
def default_noise_sampler(x, seed=None):
|
||||
if seed is not None:
|
||||
if x.device == torch.device("cpu"):
|
||||
seed += 1
|
||||
|
||||
generator = torch.Generator(device=x.device)
|
||||
generator.manual_seed(seed)
|
||||
else:
|
||||
@@ -1560,13 +1557,10 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5, solver_type="phi_1"):
|
||||
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5):
|
||||
"""SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2.
|
||||
arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023)
|
||||
"""
|
||||
if solver_type not in {"phi_1", "phi_2"}:
|
||||
raise ValueError("solver_type must be 'phi_1' or 'phi_2'")
|
||||
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
@@ -1606,14 +1600,8 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
|
||||
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
|
||||
|
||||
# Step 2
|
||||
if solver_type == "phi_1":
|
||||
denoised_d = torch.lerp(denoised, denoised_2, fac)
|
||||
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
|
||||
elif solver_type == "phi_2":
|
||||
b2 = ei_h_phi_2(-h_eta) / r
|
||||
b1 = ei_h_phi_1(-h_eta) - b2
|
||||
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * (b1 * denoised + b2 * denoised_2)
|
||||
|
||||
denoised_d = torch.lerp(denoised, denoised_2, fac)
|
||||
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
|
||||
if inject_noise:
|
||||
segment_factor = (r - 1) * h * eta
|
||||
sde_noise = sde_noise * segment_factor.exp()
|
||||
@@ -1621,17 +1609,6 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
|
||||
x = x + sde_noise * sigmas[i + 1] * s_noise
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_exp_heun_2_x0(model, x, sigmas, extra_args=None, callback=None, disable=None, solver_type="phi_2"):
|
||||
"""Deterministic exponential Heun second order method in data prediction (x0) and logSNR time."""
|
||||
return sample_seeds_2(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=0.0, s_noise=0.0, noise_sampler=None, r=1.0, solver_type=solver_type)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_exp_heun_2_x0_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type="phi_2"):
|
||||
"""Stochastic exponential Heun second order method in data prediction (x0) and logSNR time."""
|
||||
return sample_seeds_2(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=1.0, solver_type=solver_type)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3):
|
||||
@@ -1779,7 +1756,7 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
|
||||
# Predictor
|
||||
if sigmas[i + 1] == 0:
|
||||
# Denoising step
|
||||
x_pred = denoised
|
||||
x = denoised
|
||||
else:
|
||||
tau_t = tau_func(sigmas[i + 1])
|
||||
curr_lambdas = lambdas[i - predictor_order_used + 1:i + 1]
|
||||
@@ -1800,7 +1777,7 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
|
||||
if tau_t > 0 and s_noise > 0:
|
||||
noise = noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * tau_t ** 2 * h).expm1().neg().sqrt() * s_noise
|
||||
x_pred = x_pred + noise
|
||||
return x_pred
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
|
||||
@@ -6,7 +6,6 @@ class LatentFormat:
|
||||
latent_dimensions = 2
|
||||
latent_rgb_factors = None
|
||||
latent_rgb_factors_bias = None
|
||||
latent_rgb_factors_reshape = None
|
||||
taesd_decoder_name = None
|
||||
|
||||
def process_in(self, latent):
|
||||
@@ -179,54 +178,6 @@ class Flux(SD3):
|
||||
def process_out(self, latent):
|
||||
return (latent / self.scale_factor) + self.shift_factor
|
||||
|
||||
class Flux2(LatentFormat):
|
||||
latent_channels = 128
|
||||
|
||||
def __init__(self):
|
||||
self.latent_rgb_factors =[
|
||||
[0.0058, 0.0113, 0.0073],
|
||||
[0.0495, 0.0443, 0.0836],
|
||||
[-0.0099, 0.0096, 0.0644],
|
||||
[0.2144, 0.3009, 0.3652],
|
||||
[0.0166, -0.0039, -0.0054],
|
||||
[0.0157, 0.0103, -0.0160],
|
||||
[-0.0398, 0.0902, -0.0235],
|
||||
[-0.0052, 0.0095, 0.0109],
|
||||
[-0.3527, -0.2712, -0.1666],
|
||||
[-0.0301, -0.0356, -0.0180],
|
||||
[-0.0107, 0.0078, 0.0013],
|
||||
[0.0746, 0.0090, -0.0941],
|
||||
[0.0156, 0.0169, 0.0070],
|
||||
[-0.0034, -0.0040, -0.0114],
|
||||
[0.0032, 0.0181, 0.0080],
|
||||
[-0.0939, -0.0008, 0.0186],
|
||||
[0.0018, 0.0043, 0.0104],
|
||||
[0.0284, 0.0056, -0.0127],
|
||||
[-0.0024, -0.0022, -0.0030],
|
||||
[0.1207, -0.0026, 0.0065],
|
||||
[0.0128, 0.0101, 0.0142],
|
||||
[0.0137, -0.0072, -0.0007],
|
||||
[0.0095, 0.0092, -0.0059],
|
||||
[0.0000, -0.0077, -0.0049],
|
||||
[-0.0465, -0.0204, -0.0312],
|
||||
[0.0095, 0.0012, -0.0066],
|
||||
[0.0290, -0.0034, 0.0025],
|
||||
[0.0220, 0.0169, -0.0048],
|
||||
[-0.0332, -0.0457, -0.0468],
|
||||
[-0.0085, 0.0389, 0.0609],
|
||||
[-0.0076, 0.0003, -0.0043],
|
||||
[-0.0111, -0.0460, -0.0614],
|
||||
]
|
||||
|
||||
self.latent_rgb_factors_bias = [-0.0329, -0.0718, -0.0851]
|
||||
self.latent_rgb_factors_reshape = lambda t: t.reshape(t.shape[0], 32, 2, 2, t.shape[-2], t.shape[-1]).permute(0, 1, 4, 2, 5, 3).reshape(t.shape[0], 32, t.shape[-2] * 2, t.shape[-1] * 2)
|
||||
|
||||
def process_in(self, latent):
|
||||
return latent
|
||||
|
||||
def process_out(self, latent):
|
||||
return latent
|
||||
|
||||
class Mochi(LatentFormat):
|
||||
latent_channels = 12
|
||||
latent_dimensions = 3
|
||||
@@ -407,11 +358,6 @@ class LTXV(LatentFormat):
|
||||
|
||||
self.latent_rgb_factors_bias = [-0.0571, -0.1657, -0.2512]
|
||||
|
||||
class LTXAV(LTXV):
|
||||
def __init__(self):
|
||||
self.latent_rgb_factors = None
|
||||
self.latent_rgb_factors_bias = None
|
||||
|
||||
class HunyuanVideo(LatentFormat):
|
||||
latent_channels = 16
|
||||
latent_dimensions = 3
|
||||
@@ -436,7 +382,6 @@ class HunyuanVideo(LatentFormat):
|
||||
]
|
||||
|
||||
latent_rgb_factors_bias = [ 0.0259, -0.0192, -0.0761]
|
||||
taesd_decoder_name = "taehv"
|
||||
|
||||
class Cosmos1CV8x8x8(LatentFormat):
|
||||
latent_channels = 16
|
||||
@@ -500,7 +445,7 @@ class Wan21(LatentFormat):
|
||||
]).view(1, self.latent_channels, 1, 1, 1)
|
||||
|
||||
|
||||
self.taesd_decoder_name = "lighttaew2_1"
|
||||
self.taesd_decoder_name = None #TODO
|
||||
|
||||
def process_in(self, latent):
|
||||
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
|
||||
@@ -571,7 +516,6 @@ class Wan22(Wan21):
|
||||
|
||||
def __init__(self):
|
||||
self.scale_factor = 1.0
|
||||
self.taesd_decoder_name = "lighttaew2_2"
|
||||
self.latents_mean = torch.tensor([
|
||||
-0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557,
|
||||
-0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825,
|
||||
@@ -667,67 +611,6 @@ class HunyuanImage21Refiner(LatentFormat):
|
||||
latent_dimensions = 3
|
||||
scale_factor = 1.03682
|
||||
|
||||
def process_in(self, latent):
|
||||
out = latent * self.scale_factor
|
||||
out = torch.cat((out[:, :, :1], out), dim=2)
|
||||
out = out.permute(0, 2, 1, 3, 4)
|
||||
b, f_times_2, c, h, w = out.shape
|
||||
out = out.reshape(b, f_times_2 // 2, 2 * c, h, w)
|
||||
out = out.permute(0, 2, 1, 3, 4).contiguous()
|
||||
return out
|
||||
|
||||
def process_out(self, latent):
|
||||
z = latent / self.scale_factor
|
||||
z = z.permute(0, 2, 1, 3, 4)
|
||||
b, f, c, h, w = z.shape
|
||||
z = z.reshape(b, f, 2, c // 2, h, w)
|
||||
z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w)
|
||||
z = z.permute(0, 2, 1, 3, 4)
|
||||
z = z[:, :, 1:]
|
||||
return z
|
||||
|
||||
class HunyuanVideo15(LatentFormat):
|
||||
latent_rgb_factors = [
|
||||
[ 0.0568, -0.0521, -0.0131],
|
||||
[ 0.0014, 0.0735, 0.0326],
|
||||
[ 0.0186, 0.0531, -0.0138],
|
||||
[-0.0031, 0.0051, 0.0288],
|
||||
[ 0.0110, 0.0556, 0.0432],
|
||||
[-0.0041, -0.0023, -0.0485],
|
||||
[ 0.0530, 0.0413, 0.0253],
|
||||
[ 0.0283, 0.0251, 0.0339],
|
||||
[ 0.0277, -0.0372, -0.0093],
|
||||
[ 0.0393, 0.0944, 0.1131],
|
||||
[ 0.0020, 0.0251, 0.0037],
|
||||
[-0.0017, 0.0012, 0.0234],
|
||||
[ 0.0468, 0.0436, 0.0203],
|
||||
[ 0.0354, 0.0439, -0.0233],
|
||||
[ 0.0090, 0.0123, 0.0346],
|
||||
[ 0.0382, 0.0029, 0.0217],
|
||||
[ 0.0261, -0.0300, 0.0030],
|
||||
[-0.0088, -0.0220, -0.0283],
|
||||
[-0.0272, -0.0121, -0.0363],
|
||||
[-0.0664, -0.0622, 0.0144],
|
||||
[ 0.0414, 0.0479, 0.0529],
|
||||
[ 0.0355, 0.0612, -0.0247],
|
||||
[ 0.0147, 0.0264, 0.0174],
|
||||
[ 0.0438, 0.0038, 0.0542],
|
||||
[ 0.0431, -0.0573, -0.0033],
|
||||
[-0.0162, -0.0211, -0.0406],
|
||||
[-0.0487, -0.0295, -0.0393],
|
||||
[ 0.0005, -0.0109, 0.0253],
|
||||
[ 0.0296, 0.0591, 0.0353],
|
||||
[ 0.0119, 0.0181, -0.0306],
|
||||
[-0.0085, -0.0362, 0.0229],
|
||||
[ 0.0005, -0.0106, 0.0242]
|
||||
]
|
||||
|
||||
latent_rgb_factors_bias = [ 0.0456, -0.0202, -0.0644]
|
||||
latent_channels = 32
|
||||
latent_dimensions = 3
|
||||
scale_factor = 1.03682
|
||||
taesd_decoder_name = "lighttaehy1_5"
|
||||
|
||||
class Hunyuan3Dv2(LatentFormat):
|
||||
latent_channels = 64
|
||||
latent_dimensions = 1
|
||||
|
||||
@@ -23,6 +23,8 @@ class MusicDCAE(torch.nn.Module):
|
||||
else:
|
||||
self.source_sample_rate = source_sample_rate
|
||||
|
||||
# self.resampler = torchaudio.transforms.Resample(source_sample_rate, 44100)
|
||||
|
||||
self.transform = transforms.Compose([
|
||||
transforms.Normalize(0.5, 0.5),
|
||||
])
|
||||
@@ -35,6 +37,10 @@ class MusicDCAE(torch.nn.Module):
|
||||
self.scale_factor = 0.1786
|
||||
self.shift_factor = -1.9091
|
||||
|
||||
def load_audio(self, audio_path):
|
||||
audio, sr = torchaudio.load(audio_path)
|
||||
return audio, sr
|
||||
|
||||
def forward_mel(self, audios):
|
||||
mels = []
|
||||
for i in range(len(audios)):
|
||||
@@ -67,8 +73,10 @@ class MusicDCAE(torch.nn.Module):
|
||||
latent = self.dcae.encoder(mel.unsqueeze(0))
|
||||
latents.append(latent)
|
||||
latents = torch.cat(latents, dim=0)
|
||||
# latent_lengths = (audio_lengths / sr * 44100 / 512 / self.time_dimention_multiple).long()
|
||||
latents = (latents - self.shift_factor) * self.scale_factor
|
||||
return latents
|
||||
# return latents, latent_lengths
|
||||
|
||||
@torch.no_grad()
|
||||
def decode(self, latents, audio_lengths=None, sr=None):
|
||||
@@ -83,7 +91,9 @@ class MusicDCAE(torch.nn.Module):
|
||||
wav = self.vocoder.decode(mels[0]).squeeze(1)
|
||||
|
||||
if sr is not None:
|
||||
# resampler = torchaudio.transforms.Resample(44100, sr).to(latents.device).to(latents.dtype)
|
||||
wav = torchaudio.functional.resample(wav, 44100, sr)
|
||||
# wav = resampler(wav)
|
||||
else:
|
||||
sr = 44100
|
||||
pred_wavs.append(wav)
|
||||
@@ -91,6 +101,7 @@ class MusicDCAE(torch.nn.Module):
|
||||
if audio_lengths is not None:
|
||||
pred_wavs = [wav[:, :length].cpu() for wav, length in zip(pred_wavs, audio_lengths)]
|
||||
return torch.stack(pred_wavs)
|
||||
# return sr, pred_wavs
|
||||
|
||||
def forward(self, audios, audio_lengths=None, sr=None):
|
||||
latents, latent_lengths = self.encode(audios=audios, audio_lengths=audio_lengths, sr=sr)
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from comfy.ldm.flux.math import attention
|
||||
from comfy.ldm.flux.layers import (
|
||||
MLPEmbedder,
|
||||
RMSNorm,
|
||||
QKNorm,
|
||||
SelfAttention,
|
||||
ModulationOut,
|
||||
)
|
||||
|
||||
# TODO: remove this in a few months
|
||||
SingleStreamBlock = None
|
||||
DoubleStreamBlock = None
|
||||
|
||||
|
||||
class ChromaModulationOut(ModulationOut):
|
||||
@@ -48,6 +48,124 @@ class Approximator(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class DoubleStreamBlock(nn.Module):
|
||||
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
self.num_heads = num_heads
|
||||
self.hidden_size = hidden_size
|
||||
self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.img_mlp = nn.Sequential(
|
||||
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
|
||||
nn.GELU(approximate="tanh"),
|
||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.txt_mlp = nn.Sequential(
|
||||
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
|
||||
nn.GELU(approximate="tanh"),
|
||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
self.flipped_img_txt = flipped_img_txt
|
||||
|
||||
def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None, transformer_options={}):
|
||||
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
|
||||
|
||||
# prepare image for attention
|
||||
img_modulated = torch.addcmul(img_mod1.shift, 1 + img_mod1.scale, self.img_norm1(img))
|
||||
img_qkv = self.img_attn.qkv(img_modulated)
|
||||
img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
|
||||
|
||||
# prepare txt for attention
|
||||
txt_modulated = torch.addcmul(txt_mod1.shift, 1 + txt_mod1.scale, self.txt_norm1(txt))
|
||||
txt_qkv = self.txt_attn.qkv(txt_modulated)
|
||||
txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
|
||||
|
||||
# run actual attention
|
||||
attn = attention(torch.cat((txt_q, img_q), dim=2),
|
||||
torch.cat((txt_k, img_k), dim=2),
|
||||
torch.cat((txt_v, img_v), dim=2),
|
||||
pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||
|
||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
|
||||
|
||||
# calculate the img bloks
|
||||
img.addcmul_(img_mod1.gate, self.img_attn.proj(img_attn))
|
||||
img.addcmul_(img_mod2.gate, self.img_mlp(torch.addcmul(img_mod2.shift, 1 + img_mod2.scale, self.img_norm2(img))))
|
||||
|
||||
# calculate the txt bloks
|
||||
txt.addcmul_(txt_mod1.gate, self.txt_attn.proj(txt_attn))
|
||||
txt.addcmul_(txt_mod2.gate, self.txt_mlp(torch.addcmul(txt_mod2.shift, 1 + txt_mod2.scale, self.txt_norm2(txt))))
|
||||
|
||||
if txt.dtype == torch.float16:
|
||||
txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
|
||||
|
||||
return img, txt
|
||||
|
||||
|
||||
class SingleStreamBlock(nn.Module):
|
||||
"""
|
||||
A DiT block with parallel linear layers as described in
|
||||
https://arxiv.org/abs/2302.05442 and adapted modulation interface.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
qk_scale: float = None,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_dim = hidden_size
|
||||
self.num_heads = num_heads
|
||||
head_dim = hidden_size // num_heads
|
||||
self.scale = qk_scale or head_dim**-0.5
|
||||
|
||||
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
# qkv and mlp_in
|
||||
self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device)
|
||||
# proj and mlp_out
|
||||
self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device)
|
||||
|
||||
self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
|
||||
self.mlp_act = nn.GELU(approximate="tanh")
|
||||
|
||||
def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None, transformer_options={}) -> Tensor:
|
||||
mod = vec
|
||||
x_mod = torch.addcmul(mod.shift, 1 + mod.scale, self.pre_norm(x))
|
||||
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
||||
|
||||
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
q, k = self.norm(q, k, v)
|
||||
|
||||
# compute attention
|
||||
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||
# compute activation in mlp stream, cat again and run second linear layer
|
||||
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
||||
x.addcmul_(mod.gate, output)
|
||||
if x.dtype == torch.float16:
|
||||
x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
|
||||
return x
|
||||
|
||||
|
||||
class LastLayer(nn.Module):
|
||||
def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
@@ -11,12 +11,12 @@ import comfy.ldm.common_dit
|
||||
from comfy.ldm.flux.layers import (
|
||||
EmbedND,
|
||||
timestep_embedding,
|
||||
DoubleStreamBlock,
|
||||
SingleStreamBlock,
|
||||
)
|
||||
|
||||
from .layers import (
|
||||
DoubleStreamBlock,
|
||||
LastLayer,
|
||||
SingleStreamBlock,
|
||||
Approximator,
|
||||
ChromaModulationOut,
|
||||
)
|
||||
@@ -40,8 +40,7 @@ class ChromaParams:
|
||||
out_dim: int
|
||||
hidden_dim: int
|
||||
n_layers: int
|
||||
txt_ids_dims: list
|
||||
vec_in_dim: int
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -91,7 +90,6 @@ class Chroma(nn.Module):
|
||||
self.num_heads,
|
||||
mlp_ratio=params.mlp_ratio,
|
||||
qkv_bias=params.qkv_bias,
|
||||
modulation=False,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
for _ in range(params.depth)
|
||||
@@ -100,7 +98,7 @@ class Chroma(nn.Module):
|
||||
|
||||
self.single_blocks = nn.ModuleList(
|
||||
[
|
||||
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=False, dtype=dtype, device=device, operations=operations)
|
||||
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations)
|
||||
for _ in range(params.depth_single_blocks)
|
||||
]
|
||||
)
|
||||
@@ -180,10 +178,7 @@ class Chroma(nn.Module):
|
||||
pe = self.pe_embedder(ids)
|
||||
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
transformer_options["total_blocks"] = len(self.double_blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.double_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if i not in self.skip_mmdit:
|
||||
double_mod = (
|
||||
self.get_modulations(mod_vectors, "double_img", idx=i),
|
||||
@@ -226,10 +221,7 @@ class Chroma(nn.Module):
|
||||
|
||||
img = torch.cat((txt, img), 1)
|
||||
|
||||
transformer_options["total_blocks"] = len(self.single_blocks)
|
||||
transformer_options["block_type"] = "single"
|
||||
for i, block in enumerate(self.single_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if i not in self.skip_dit:
|
||||
single_mod = self.get_modulations(mod_vectors, "single", idx=i)
|
||||
if ("single_block", i) in blocks_replace:
|
||||
|
||||
@@ -10,10 +10,12 @@ from torch import Tensor, nn
|
||||
from einops import repeat
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
from comfy.ldm.flux.layers import EmbedND, DoubleStreamBlock, SingleStreamBlock
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
|
||||
from comfy.ldm.chroma.model import Chroma, ChromaParams
|
||||
from comfy.ldm.chroma.layers import (
|
||||
DoubleStreamBlock,
|
||||
SingleStreamBlock,
|
||||
Approximator,
|
||||
)
|
||||
from .layers import (
|
||||
@@ -37,7 +39,7 @@ class ChromaRadianceParams(ChromaParams):
|
||||
nerf_final_head_type: str
|
||||
# None means use the same dtype as the model.
|
||||
nerf_embedder_dtype: Optional[torch.dtype]
|
||||
use_x0: bool
|
||||
|
||||
|
||||
class ChromaRadiance(Chroma):
|
||||
"""
|
||||
@@ -87,6 +89,7 @@ class ChromaRadiance(Chroma):
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
|
||||
self.double_blocks = nn.ModuleList(
|
||||
[
|
||||
DoubleStreamBlock(
|
||||
@@ -94,7 +97,6 @@ class ChromaRadiance(Chroma):
|
||||
self.num_heads,
|
||||
mlp_ratio=params.mlp_ratio,
|
||||
qkv_bias=params.qkv_bias,
|
||||
modulation=False,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
for _ in range(params.depth)
|
||||
@@ -107,7 +109,6 @@ class ChromaRadiance(Chroma):
|
||||
self.hidden_size,
|
||||
self.num_heads,
|
||||
mlp_ratio=params.mlp_ratio,
|
||||
modulation=False,
|
||||
dtype=dtype, device=device, operations=operations,
|
||||
)
|
||||
for _ in range(params.depth_single_blocks)
|
||||
@@ -159,9 +160,6 @@ class ChromaRadiance(Chroma):
|
||||
self.skip_dit = []
|
||||
self.lite = False
|
||||
|
||||
if params.use_x0:
|
||||
self.register_buffer("__x0__", torch.tensor([]))
|
||||
|
||||
@property
|
||||
def _nerf_final_layer(self) -> nn.Module:
|
||||
if self.params.nerf_final_head_type == "linear":
|
||||
@@ -191,15 +189,15 @@ class ChromaRadiance(Chroma):
|
||||
nerf_pixels = nn.functional.unfold(img_orig, kernel_size=patch_size, stride=patch_size)
|
||||
nerf_pixels = nerf_pixels.transpose(1, 2) # -> [B, NumPatches, C * P * P]
|
||||
|
||||
# Reshape for per-patch processing
|
||||
nerf_hidden = img_out.reshape(B * num_patches, params.hidden_size)
|
||||
nerf_pixels = nerf_pixels.reshape(B * num_patches, C, patch_size**2).transpose(1, 2)
|
||||
|
||||
if params.nerf_tile_size > 0 and num_patches > params.nerf_tile_size:
|
||||
# Enable tiling if nerf_tile_size isn't 0 and we actually have more patches than
|
||||
# the tile size.
|
||||
img_dct = self.forward_tiled_nerf(nerf_hidden, nerf_pixels, B, C, num_patches, patch_size, params)
|
||||
img_dct = self.forward_tiled_nerf(img_out, nerf_pixels, B, C, num_patches, patch_size, params)
|
||||
else:
|
||||
# Reshape for per-patch processing
|
||||
nerf_hidden = img_out.reshape(B * num_patches, params.hidden_size)
|
||||
nerf_pixels = nerf_pixels.reshape(B * num_patches, C, patch_size**2).transpose(1, 2)
|
||||
|
||||
# Get DCT-encoded pixel embeddings [pixel-dct]
|
||||
img_dct = self.nerf_image_embedder(nerf_pixels)
|
||||
|
||||
@@ -242,8 +240,17 @@ class ChromaRadiance(Chroma):
|
||||
end = min(i + tile_size, num_patches)
|
||||
|
||||
# Slice the current tile from the input tensors
|
||||
nerf_hidden_tile = nerf_hidden[i * batch:end * batch]
|
||||
nerf_pixels_tile = nerf_pixels[i * batch:end * batch]
|
||||
nerf_hidden_tile = nerf_hidden[:, i:end, :]
|
||||
nerf_pixels_tile = nerf_pixels[:, i:end, :]
|
||||
|
||||
# Get the actual number of patches in this tile (can be smaller for the last tile)
|
||||
num_patches_tile = nerf_hidden_tile.shape[1]
|
||||
|
||||
# Reshape the tile for per-patch processing
|
||||
# [B, NumPatches_tile, D] -> [B * NumPatches_tile, D]
|
||||
nerf_hidden_tile = nerf_hidden_tile.reshape(batch * num_patches_tile, params.hidden_size)
|
||||
# [B, NumPatches_tile, C*P*P] -> [B*NumPatches_tile, C, P*P] -> [B*NumPatches_tile, P*P, C]
|
||||
nerf_pixels_tile = nerf_pixels_tile.reshape(batch * num_patches_tile, channels, patch_size**2).transpose(1, 2)
|
||||
|
||||
# get DCT-encoded pixel embeddings [pixel-dct]
|
||||
img_dct_tile = self.nerf_image_embedder(nerf_pixels_tile)
|
||||
@@ -270,7 +277,7 @@ class ChromaRadiance(Chroma):
|
||||
bad_keys = tuple(
|
||||
k
|
||||
for k, v in overrides.items()
|
||||
if not isinstance(v, type(getattr(params, k))) and (v is not None or k not in nullable_keys)
|
||||
if type(v) != type(getattr(params, k)) and (v is not None or k not in nullable_keys)
|
||||
)
|
||||
if bad_keys:
|
||||
e = f"Invalid value(s) in transformer_options chroma_radiance_options: {', '.join(bad_keys)}"
|
||||
@@ -279,12 +286,6 @@ class ChromaRadiance(Chroma):
|
||||
params_dict |= overrides
|
||||
return params.__class__(**params_dict)
|
||||
|
||||
def _apply_x0_residual(self, predicted, noisy, timesteps):
|
||||
|
||||
# non zero during training to prevent 0 div
|
||||
eps = 0.0
|
||||
return (noisy - predicted) / (timesteps.view(-1,1,1,1) + eps)
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
x: Tensor,
|
||||
@@ -325,11 +326,4 @@ class ChromaRadiance(Chroma):
|
||||
transformer_options,
|
||||
attn_mask=kwargs.get("attention_mask", None),
|
||||
)
|
||||
|
||||
out = self.forward_nerf(img, img_out, params)[:, :, :h, :w]
|
||||
|
||||
# If x0 variant → v-pred, just return this instead
|
||||
if hasattr(self, "__x0__"):
|
||||
out = self._apply_x0_residual(out, img, timestep)
|
||||
return out
|
||||
|
||||
return self.forward_nerf(img, img_out, params)[:, :, :h, :w]
|
||||
|
||||
@@ -48,44 +48,15 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10
|
||||
return embedding
|
||||
|
||||
class MLPEmbedder(nn.Module):
|
||||
def __init__(self, in_dim: int, hidden_dim: int, bias=True, dtype=None, device=None, operations=None):
|
||||
def __init__(self, in_dim: int, hidden_dim: int, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.in_layer = operations.Linear(in_dim, hidden_dim, bias=bias, dtype=dtype, device=device)
|
||||
self.in_layer = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
|
||||
self.silu = nn.SiLU()
|
||||
self.out_layer = operations.Linear(hidden_dim, hidden_dim, bias=bias, dtype=dtype, device=device)
|
||||
self.out_layer = operations.Linear(hidden_dim, hidden_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return self.out_layer(self.silu(self.in_layer(x)))
|
||||
|
||||
class YakMLP(nn.Module):
|
||||
def __init__(self, hidden_size: int, intermediate_size: int, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.gate_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=True, dtype=dtype, device=device)
|
||||
self.up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=True, dtype=dtype, device=device)
|
||||
self.down_proj = operations.Linear(self.intermediate_size, self.hidden_size, bias=True, dtype=dtype, device=device)
|
||||
self.act_fn = nn.SiLU()
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
||||
return down_proj
|
||||
|
||||
def build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=False, yak_mlp=False, dtype=None, device=None, operations=None):
|
||||
if yak_mlp:
|
||||
return YakMLP(hidden_size, mlp_hidden_dim, dtype=dtype, device=device, operations=operations)
|
||||
if mlp_silu_act:
|
||||
return nn.Sequential(
|
||||
operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device),
|
||||
SiLUActivation(),
|
||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device),
|
||||
)
|
||||
else:
|
||||
return nn.Sequential(
|
||||
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
|
||||
nn.GELU(approximate="tanh"),
|
||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, dim: int, dtype=None, device=None, operations=None):
|
||||
@@ -109,14 +80,14 @@ class QKNorm(torch.nn.Module):
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, proj_bias: bool = True, dtype=None, device=None, operations=None):
|
||||
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
|
||||
self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
|
||||
self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
|
||||
self.proj = operations.Linear(dim, dim, bias=proj_bias, dtype=dtype, device=device)
|
||||
self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -127,11 +98,11 @@ class ModulationOut:
|
||||
|
||||
|
||||
class Modulation(nn.Module):
|
||||
def __init__(self, dim: int, double: bool, bias=True, dtype=None, device=None, operations=None):
|
||||
def __init__(self, dim: int, double: bool, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.is_double = double
|
||||
self.multiplier = 6 if double else 3
|
||||
self.lin = operations.Linear(dim, self.multiplier * dim, bias=bias, dtype=dtype, device=device)
|
||||
self.lin = operations.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, vec: Tensor) -> tuple:
|
||||
if vec.ndim == 2:
|
||||
@@ -158,107 +129,77 @@ def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None):
|
||||
return tensor
|
||||
|
||||
|
||||
class SiLUActivation(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.gate_fn = nn.SiLU()
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
x1, x2 = x.chunk(2, dim=-1)
|
||||
return self.gate_fn(x1) * x2
|
||||
|
||||
|
||||
class DoubleStreamBlock(nn.Module):
|
||||
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None):
|
||||
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
self.num_heads = num_heads
|
||||
self.hidden_size = hidden_size
|
||||
self.modulation = modulation
|
||||
|
||||
if self.modulation:
|
||||
self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
|
||||
self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias, dtype=dtype, device=device, operations=operations)
|
||||
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.img_mlp = nn.Sequential(
|
||||
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
|
||||
nn.GELU(approximate="tanh"),
|
||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
self.img_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
if self.modulation:
|
||||
self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
|
||||
self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias, dtype=dtype, device=device, operations=operations)
|
||||
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
|
||||
self.txt_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.txt_mlp = nn.Sequential(
|
||||
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
|
||||
nn.GELU(approximate="tanh"),
|
||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
self.flipped_img_txt = flipped_img_txt
|
||||
|
||||
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}):
|
||||
if self.modulation:
|
||||
img_mod1, img_mod2 = self.img_mod(vec)
|
||||
txt_mod1, txt_mod2 = self.txt_mod(vec)
|
||||
else:
|
||||
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
|
||||
img_mod1, img_mod2 = self.img_mod(vec)
|
||||
txt_mod1, txt_mod2 = self.txt_mod(vec)
|
||||
|
||||
# prepare image for attention
|
||||
img_modulated = self.img_norm1(img)
|
||||
img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img)
|
||||
img_qkv = self.img_attn.qkv(img_modulated)
|
||||
del img_modulated
|
||||
img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
del img_qkv
|
||||
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
|
||||
|
||||
# prepare txt for attention
|
||||
txt_modulated = self.txt_norm1(txt)
|
||||
txt_modulated = apply_mod(txt_modulated, (1 + txt_mod1.scale), txt_mod1.shift, modulation_dims_txt)
|
||||
txt_qkv = self.txt_attn.qkv(txt_modulated)
|
||||
del txt_modulated
|
||||
txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
del txt_qkv
|
||||
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
|
||||
|
||||
if self.flipped_img_txt:
|
||||
q = torch.cat((img_q, txt_q), dim=2)
|
||||
del img_q, txt_q
|
||||
k = torch.cat((img_k, txt_k), dim=2)
|
||||
del img_k, txt_k
|
||||
v = torch.cat((img_v, txt_v), dim=2)
|
||||
del img_v, txt_v
|
||||
# run actual attention
|
||||
attn = attention(q, k, v,
|
||||
attn = attention(torch.cat((img_q, txt_q), dim=2),
|
||||
torch.cat((img_k, txt_k), dim=2),
|
||||
torch.cat((img_v, txt_v), dim=2),
|
||||
pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||
del q, k, v
|
||||
|
||||
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
|
||||
else:
|
||||
q = torch.cat((txt_q, img_q), dim=2)
|
||||
del txt_q, img_q
|
||||
k = torch.cat((txt_k, img_k), dim=2)
|
||||
del txt_k, img_k
|
||||
v = torch.cat((txt_v, img_v), dim=2)
|
||||
del txt_v, img_v
|
||||
# run actual attention
|
||||
attn = attention(q, k, v,
|
||||
attn = attention(torch.cat((txt_q, img_q), dim=2),
|
||||
torch.cat((txt_k, img_k), dim=2),
|
||||
torch.cat((txt_v, img_v), dim=2),
|
||||
pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||
del q, k, v
|
||||
|
||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
|
||||
|
||||
# calculate the img bloks
|
||||
img += apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
|
||||
del img_attn
|
||||
img += apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)
|
||||
img = img + apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
|
||||
img = img + apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)
|
||||
|
||||
# calculate the txt bloks
|
||||
txt += apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims_txt)
|
||||
del txt_attn
|
||||
txt += apply_mod(self.txt_mlp(apply_mod(self.txt_norm2(txt), (1 + txt_mod2.scale), txt_mod2.shift, modulation_dims_txt)), txt_mod2.gate, None, modulation_dims_txt)
|
||||
|
||||
if txt.dtype == torch.float16:
|
||||
@@ -279,10 +220,6 @@ class SingleStreamBlock(nn.Module):
|
||||
num_heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
qk_scale: float = None,
|
||||
modulation=True,
|
||||
mlp_silu_act=False,
|
||||
bias=True,
|
||||
yak_mlp=False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None
|
||||
@@ -294,55 +231,30 @@ class SingleStreamBlock(nn.Module):
|
||||
self.scale = qk_scale or head_dim**-0.5
|
||||
|
||||
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
|
||||
self.mlp_hidden_dim_first = self.mlp_hidden_dim
|
||||
self.yak_mlp = yak_mlp
|
||||
if mlp_silu_act:
|
||||
self.mlp_hidden_dim_first = int(hidden_size * mlp_ratio * 2)
|
||||
self.mlp_act = SiLUActivation()
|
||||
else:
|
||||
self.mlp_act = nn.GELU(approximate="tanh")
|
||||
|
||||
if self.yak_mlp:
|
||||
self.mlp_hidden_dim_first *= 2
|
||||
self.mlp_act = nn.SiLU()
|
||||
|
||||
# qkv and mlp_in
|
||||
self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim_first, bias=bias, dtype=dtype, device=device)
|
||||
self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device)
|
||||
# proj and mlp_out
|
||||
self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, bias=bias, dtype=dtype, device=device)
|
||||
self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device)
|
||||
|
||||
self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
|
||||
if modulation:
|
||||
self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
|
||||
else:
|
||||
self.modulation = None
|
||||
self.mlp_act = nn.GELU(approximate="tanh")
|
||||
self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None, transformer_options={}) -> Tensor:
|
||||
if self.modulation:
|
||||
mod, _ = self.modulation(vec)
|
||||
else:
|
||||
mod = vec
|
||||
|
||||
qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim_first], dim=-1)
|
||||
mod, _ = self.modulation(vec)
|
||||
qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
||||
|
||||
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
del qkv
|
||||
q, k = self.norm(q, k, v)
|
||||
|
||||
# compute attention
|
||||
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||
del q, k, v
|
||||
# compute activation in mlp stream, cat again and run second linear layer
|
||||
if self.yak_mlp:
|
||||
mlp = self.mlp_act(mlp[..., self.mlp_hidden_dim_first // 2:]) * mlp[..., :self.mlp_hidden_dim_first // 2]
|
||||
else:
|
||||
mlp = self.mlp_act(mlp)
|
||||
output = self.linear2(torch.cat((attn, mlp), 2))
|
||||
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
||||
x += apply_mod(output, mod.gate, None, modulation_dims)
|
||||
if x.dtype == torch.float16:
|
||||
x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
|
||||
@@ -350,11 +262,11 @@ class SingleStreamBlock(nn.Module):
|
||||
|
||||
|
||||
class LastLayer(nn.Module):
|
||||
def __init__(self, hidden_size: int, patch_size: int, out_channels: int, bias=True, dtype=None, device=None, operations=None):
|
||||
def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=bias, dtype=dtype, device=device)
|
||||
self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=bias, dtype=dtype, device=device))
|
||||
self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
|
||||
self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device))
|
||||
|
||||
def forward(self, x: Tensor, vec: Tensor, modulation_dims=None) -> Tensor:
|
||||
if vec.ndim == 2:
|
||||
|
||||
@@ -4,16 +4,23 @@ from torch import Tensor
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
import comfy.model_management
|
||||
import logging
|
||||
|
||||
|
||||
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor:
|
||||
q_shape = q.shape
|
||||
k_shape = k.shape
|
||||
|
||||
if pe is not None:
|
||||
q, k = apply_rope(q, k, pe)
|
||||
q = q.to(dtype=pe.dtype).reshape(*q.shape[:-1], -1, 1, 2)
|
||||
k = k.to(dtype=pe.dtype).reshape(*k.shape[:-1], -1, 1, 2)
|
||||
q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v)
|
||||
k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)
|
||||
|
||||
heads = q.shape[1]
|
||||
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options)
|
||||
return x
|
||||
|
||||
|
||||
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
||||
assert dim % 2 == 0
|
||||
if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu() or comfy.model_management.is_directml_enabled():
|
||||
@@ -28,20 +35,10 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
||||
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
|
||||
return out.to(dtype=torch.float32, device=pos.device)
|
||||
|
||||
def apply_rope1(x: Tensor, freqs_cis: Tensor):
|
||||
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
|
||||
x_out = freqs_cis[..., 0] * x_[..., 0] + freqs_cis[..., 1] * x_[..., 1]
|
||||
return x_out.reshape(*x.shape).type_as(x)
|
||||
|
||||
try:
|
||||
import comfy.quant_ops
|
||||
apply_rope = comfy.quant_ops.ck.apply_rope
|
||||
apply_rope1 = comfy.quant_ops.ck.apply_rope1
|
||||
except:
|
||||
logging.warning("No comfy kitchen, using old apply_rope functions.")
|
||||
def apply_rope1(x: Tensor, freqs_cis: Tensor):
|
||||
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
|
||||
|
||||
x_out = freqs_cis[..., 0] * x_[..., 0]
|
||||
x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
|
||||
|
||||
return x_out.reshape(*x.shape).type_as(x)
|
||||
|
||||
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
|
||||
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
|
||||
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
|
||||
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
|
||||
|
||||
@@ -15,8 +15,6 @@ from .layers import (
|
||||
MLPEmbedder,
|
||||
SingleStreamBlock,
|
||||
timestep_embedding,
|
||||
Modulation,
|
||||
RMSNorm
|
||||
)
|
||||
|
||||
@dataclass
|
||||
@@ -35,14 +33,6 @@ class FluxParams:
|
||||
patch_size: int
|
||||
qkv_bias: bool
|
||||
guidance_embed: bool
|
||||
txt_ids_dims: list
|
||||
global_modulation: bool = False
|
||||
mlp_silu_act: bool = False
|
||||
ops_bias: bool = True
|
||||
default_ref_method: str = "offset"
|
||||
ref_index_scale: float = 1.0
|
||||
yak_mlp: bool = False
|
||||
txt_norm: bool = False
|
||||
|
||||
|
||||
class Flux(nn.Module):
|
||||
@@ -68,22 +58,13 @@ class Flux(nn.Module):
|
||||
self.hidden_size = params.hidden_size
|
||||
self.num_heads = params.num_heads
|
||||
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
|
||||
self.img_in = operations.Linear(self.in_channels, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device)
|
||||
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device, operations=operations)
|
||||
if params.vec_in_dim is not None:
|
||||
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
else:
|
||||
self.vector_in = None
|
||||
|
||||
self.img_in = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
|
||||
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
self.guidance_in = (
|
||||
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity()
|
||||
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity()
|
||||
)
|
||||
self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device)
|
||||
|
||||
if params.txt_norm:
|
||||
self.txt_norm = RMSNorm(params.context_in_dim, dtype=dtype, device=device, operations=operations)
|
||||
else:
|
||||
self.txt_norm = None
|
||||
self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device)
|
||||
|
||||
self.double_blocks = nn.ModuleList(
|
||||
[
|
||||
@@ -92,10 +73,6 @@ class Flux(nn.Module):
|
||||
self.num_heads,
|
||||
mlp_ratio=params.mlp_ratio,
|
||||
qkv_bias=params.qkv_bias,
|
||||
modulation=params.global_modulation is False,
|
||||
mlp_silu_act=params.mlp_silu_act,
|
||||
proj_bias=params.ops_bias,
|
||||
yak_mlp=params.yak_mlp,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
for _ in range(params.depth)
|
||||
@@ -104,30 +81,13 @@ class Flux(nn.Module):
|
||||
|
||||
self.single_blocks = nn.ModuleList(
|
||||
[
|
||||
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=params.global_modulation is False, mlp_silu_act=params.mlp_silu_act, bias=params.ops_bias, yak_mlp=params.yak_mlp, dtype=dtype, device=device, operations=operations)
|
||||
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations)
|
||||
for _ in range(params.depth_single_blocks)
|
||||
]
|
||||
)
|
||||
|
||||
if final_layer:
|
||||
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, bias=params.ops_bias, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
if params.global_modulation:
|
||||
self.double_stream_modulation_img = Modulation(
|
||||
self.hidden_size,
|
||||
double=True,
|
||||
bias=False,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
self.double_stream_modulation_txt = Modulation(
|
||||
self.hidden_size,
|
||||
double=True,
|
||||
bias=False,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
self.single_stream_modulation = Modulation(
|
||||
self.hidden_size, double=False, bias=False, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward_orig(
|
||||
self,
|
||||
@@ -143,6 +103,9 @@ class Flux(nn.Module):
|
||||
attn_mask: Tensor = None,
|
||||
) -> Tensor:
|
||||
|
||||
if y is None:
|
||||
y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
|
||||
|
||||
patches = transformer_options.get("patches", {})
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
if img.ndim != 3 or txt.ndim != 3:
|
||||
@@ -155,19 +118,9 @@ class Flux(nn.Module):
|
||||
if guidance is not None:
|
||||
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
|
||||
|
||||
if self.vector_in is not None:
|
||||
if y is None:
|
||||
y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
|
||||
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
|
||||
|
||||
if self.txt_norm is not None:
|
||||
txt = self.txt_norm(txt)
|
||||
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
|
||||
txt = self.txt_in(txt)
|
||||
|
||||
vec_orig = vec
|
||||
if self.params.global_modulation:
|
||||
vec = (self.double_stream_modulation_img(vec_orig), self.double_stream_modulation_txt(vec_orig))
|
||||
|
||||
if "post_input" in patches:
|
||||
for p in patches["post_input"]:
|
||||
out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids})
|
||||
@@ -183,10 +136,7 @@ class Flux(nn.Module):
|
||||
pe = None
|
||||
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
transformer_options["total_blocks"] = len(self.double_blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.double_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
@@ -227,13 +177,7 @@ class Flux(nn.Module):
|
||||
|
||||
img = torch.cat((txt, img), 1)
|
||||
|
||||
if self.params.global_modulation:
|
||||
vec, _ = self.single_stream_modulation(vec_orig)
|
||||
|
||||
transformer_options["total_blocks"] = len(self.single_blocks)
|
||||
transformer_options["block_type"] = "single"
|
||||
for i, block in enumerate(self.single_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("single_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
@@ -263,10 +207,10 @@ class Flux(nn.Module):
|
||||
|
||||
img = img[:, txt.shape[1] :, ...]
|
||||
|
||||
img = self.final_layer(img, vec_orig) # (N, T, patch_size ** 2 * out_channels)
|
||||
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
||||
return img
|
||||
|
||||
def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}):
|
||||
def process_img(self, x, index=0, h_offset=0, w_offset=0):
|
||||
bs, c, h, w = x.shape
|
||||
patch_size = self.patch_size
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
|
||||
@@ -278,22 +222,10 @@ class Flux(nn.Module):
|
||||
h_offset = ((h_offset + (patch_size // 2)) // patch_size)
|
||||
w_offset = ((w_offset + (patch_size // 2)) // patch_size)
|
||||
|
||||
steps_h = h_len
|
||||
steps_w = w_len
|
||||
|
||||
rope_options = transformer_options.get("rope_options", None)
|
||||
if rope_options is not None:
|
||||
h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
|
||||
w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
|
||||
|
||||
index += rope_options.get("shift_t", 0.0)
|
||||
h_offset += rope_options.get("shift_y", 0.0)
|
||||
w_offset += rope_options.get("shift_x", 0.0)
|
||||
|
||||
img_ids = torch.zeros((steps_h, steps_w, len(self.params.axes_dim)), device=x.device, dtype=torch.float32)
|
||||
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
||||
img_ids[:, :, 0] = img_ids[:, :, 1] + index
|
||||
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=steps_h, device=x.device, dtype=torch.float32).unsqueeze(1)
|
||||
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=steps_w, device=x.device, dtype=torch.float32).unsqueeze(0)
|
||||
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
|
||||
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
|
||||
return img, repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
|
||||
def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
|
||||
@@ -309,16 +241,16 @@ class Flux(nn.Module):
|
||||
|
||||
h_len = ((h_orig + (patch_size // 2)) // patch_size)
|
||||
w_len = ((w_orig + (patch_size // 2)) // patch_size)
|
||||
img, img_ids = self.process_img(x, transformer_options=transformer_options)
|
||||
img, img_ids = self.process_img(x)
|
||||
img_tokens = img.shape[1]
|
||||
if ref_latents is not None:
|
||||
h = 0
|
||||
w = 0
|
||||
index = 0
|
||||
ref_latents_method = kwargs.get("ref_latents_method", self.params.default_ref_method)
|
||||
ref_latents_method = kwargs.get("ref_latents_method", "offset")
|
||||
for ref in ref_latents:
|
||||
if ref_latents_method == "index":
|
||||
index += self.params.ref_index_scale
|
||||
index += 1
|
||||
h_offset = 0
|
||||
w_offset = 0
|
||||
elif ref_latents_method == "uxo":
|
||||
@@ -342,12 +274,7 @@ class Flux(nn.Module):
|
||||
img = torch.cat([img, kontext], dim=1)
|
||||
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
|
||||
|
||||
txt_ids = torch.zeros((bs, context.shape[1], len(self.params.axes_dim)), device=x.device, dtype=torch.float32)
|
||||
|
||||
if len(self.params.txt_ids_dims) > 0:
|
||||
for i in self.params.txt_ids_dims:
|
||||
txt_ids[:, :, i] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32)
|
||||
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
|
||||
out = out[:, :img_tokens]
|
||||
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=self.patch_size, pw=self.patch_size)[:,:,:h_orig,:w_orig]
|
||||
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h_orig,:w_orig]
|
||||
|
||||
@@ -6,6 +6,7 @@ import comfy.ldm.flux.layers
|
||||
import comfy.ldm.modules.diffusionmodules.mmdit
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
|
||||
from dataclasses import dataclass
|
||||
from einops import repeat
|
||||
|
||||
@@ -41,9 +42,6 @@ class HunyuanVideoParams:
|
||||
guidance_embed: bool
|
||||
byt5: bool
|
||||
meanflow: bool
|
||||
use_cond_type_embedding: bool
|
||||
vision_in_dim: int
|
||||
meanflow_sum: bool
|
||||
|
||||
|
||||
class SelfAttentionRef(nn.Module):
|
||||
@@ -159,10 +157,7 @@ class TokenRefiner(nn.Module):
|
||||
t = self.t_embedder(timestep_embedding(timesteps, 256, time_factor=1.0).to(x.dtype))
|
||||
# m = mask.float().unsqueeze(-1)
|
||||
# c = (x.float() * m).sum(dim=1) / m.sum(dim=1) #TODO: the following works when the x.shape is the same length as the tokens but might break otherwise
|
||||
if x.dtype == torch.float16:
|
||||
c = x.float().sum(dim=1) / x.shape[1]
|
||||
else:
|
||||
c = x.sum(dim=1) / x.shape[1]
|
||||
c = x.sum(dim=1) / x.shape[1]
|
||||
|
||||
c = t + self.c_embedder(c.to(x.dtype))
|
||||
x = self.input_embedder(x)
|
||||
@@ -201,15 +196,11 @@ class HunyuanVideo(nn.Module):
|
||||
def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
|
||||
|
||||
params = HunyuanVideoParams(**kwargs)
|
||||
self.params = params
|
||||
self.patch_size = params.patch_size
|
||||
self.in_channels = params.in_channels
|
||||
self.out_channels = params.out_channels
|
||||
self.use_cond_type_embedding = params.use_cond_type_embedding
|
||||
self.vision_in_dim = params.vision_in_dim
|
||||
if params.hidden_size % params.num_heads != 0:
|
||||
raise ValueError(
|
||||
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
|
||||
@@ -275,18 +266,6 @@ class HunyuanVideo(nn.Module):
|
||||
if final_layer:
|
||||
self.final_layer = LastLayer(self.hidden_size, self.patch_size[-1], self.out_channels, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
# HunyuanVideo 1.5 specific modules
|
||||
if self.vision_in_dim is not None:
|
||||
from comfy.ldm.wan.model import MLPProj
|
||||
self.vision_in = MLPProj(in_dim=self.vision_in_dim, out_dim=self.hidden_size, operation_settings=operation_settings)
|
||||
else:
|
||||
self.vision_in = None
|
||||
if self.use_cond_type_embedding:
|
||||
# 0: text_encoder feature 1: byt5 feature 2: vision_encoder feature
|
||||
self.cond_type_embedding = nn.Embedding(3, self.hidden_size)
|
||||
else:
|
||||
self.cond_type_embedding = None
|
||||
|
||||
def forward_orig(
|
||||
self,
|
||||
img: Tensor,
|
||||
@@ -297,7 +276,6 @@ class HunyuanVideo(nn.Module):
|
||||
timesteps: Tensor,
|
||||
y: Tensor = None,
|
||||
txt_byt5=None,
|
||||
clip_fea=None,
|
||||
guidance: Tensor = None,
|
||||
guiding_frame_index=None,
|
||||
ref_latent=None,
|
||||
@@ -318,7 +296,7 @@ class HunyuanVideo(nn.Module):
|
||||
timesteps_r = transformer_options['sample_sigmas'][w[0] + 1]
|
||||
timesteps_r = timesteps_r.unsqueeze(0).to(device=timesteps.device, dtype=timesteps.dtype)
|
||||
vec_r = self.time_r_in(timestep_embedding(timesteps_r, 256, time_factor=1000.0).to(img.dtype))
|
||||
vec = (vec + vec_r) if self.params.meanflow_sum else (vec + vec_r) / 2
|
||||
vec = (vec + vec_r) / 2
|
||||
|
||||
if ref_latent is not None:
|
||||
ref_latent_ids = self.img_ids(ref_latent)
|
||||
@@ -353,31 +331,12 @@ class HunyuanVideo(nn.Module):
|
||||
|
||||
txt = self.txt_in(txt, timesteps, txt_mask, transformer_options=transformer_options)
|
||||
|
||||
if self.cond_type_embedding is not None:
|
||||
self.cond_type_embedding.to(txt.device)
|
||||
cond_emb = self.cond_type_embedding(torch.zeros_like(txt[:, :, 0], device=txt.device, dtype=torch.long))
|
||||
txt = txt + cond_emb.to(txt.dtype)
|
||||
|
||||
if self.byt5_in is not None and txt_byt5 is not None:
|
||||
txt_byt5 = self.byt5_in(txt_byt5)
|
||||
if self.cond_type_embedding is not None:
|
||||
cond_emb = self.cond_type_embedding(torch.ones_like(txt_byt5[:, :, 0], device=txt_byt5.device, dtype=torch.long))
|
||||
txt_byt5 = txt_byt5 + cond_emb.to(txt_byt5.dtype)
|
||||
txt = torch.cat((txt_byt5, txt), dim=1) # byt5 first for HunyuanVideo1.5
|
||||
else:
|
||||
txt = torch.cat((txt, txt_byt5), dim=1)
|
||||
txt_byt5_ids = torch.zeros((txt_ids.shape[0], txt_byt5.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
|
||||
txt = torch.cat((txt, txt_byt5), dim=1)
|
||||
txt_ids = torch.cat((txt_ids, txt_byt5_ids), dim=1)
|
||||
|
||||
if clip_fea is not None:
|
||||
txt_vision_states = self.vision_in(clip_fea)
|
||||
if self.cond_type_embedding is not None:
|
||||
cond_emb = self.cond_type_embedding(2 * torch.ones_like(txt_vision_states[:, :, 0], dtype=torch.long, device=txt_vision_states.device))
|
||||
txt_vision_states = txt_vision_states + cond_emb
|
||||
txt = torch.cat((txt_vision_states.to(txt.dtype), txt), dim=1)
|
||||
extra_txt_ids = torch.zeros((txt_ids.shape[0], txt_vision_states.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
|
||||
txt_ids = torch.cat((txt_ids, extra_txt_ids), dim=1)
|
||||
|
||||
ids = torch.cat((img_ids, txt_ids), dim=1)
|
||||
pe = self.pe_embedder(ids)
|
||||
|
||||
@@ -390,10 +349,7 @@ class HunyuanVideo(nn.Module):
|
||||
attn_mask = None
|
||||
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
transformer_options["total_blocks"] = len(self.double_blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.double_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
@@ -415,10 +371,7 @@ class HunyuanVideo(nn.Module):
|
||||
|
||||
img = torch.cat((img, txt), 1)
|
||||
|
||||
transformer_options["total_blocks"] = len(self.single_blocks)
|
||||
transformer_options["block_type"] = "single"
|
||||
for i, block in enumerate(self.single_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("single_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
@@ -477,14 +430,14 @@ class HunyuanVideo(nn.Module):
|
||||
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
|
||||
return repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
|
||||
def forward(self, x, timestep, context, y=None, txt_byt5=None, clip_fea=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
|
||||
def forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timestep, context, y, txt_byt5, clip_fea, guidance, attention_mask, guiding_frame_index, ref_latent, disable_time_r, control, transformer_options, **kwargs)
|
||||
).execute(x, timestep, context, y, txt_byt5, guidance, attention_mask, guiding_frame_index, ref_latent, disable_time_r, control, transformer_options, **kwargs)
|
||||
|
||||
def _forward(self, x, timestep, context, y=None, txt_byt5=None, clip_fea=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
|
||||
def _forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
|
||||
bs = x.shape[0]
|
||||
if len(self.patch_size) == 3:
|
||||
img_ids = self.img_ids(x)
|
||||
@@ -492,5 +445,5 @@ class HunyuanVideo(nn.Module):
|
||||
else:
|
||||
img_ids = self.img_ids_2d(x)
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 2), device=x.device, dtype=x.dtype)
|
||||
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, clip_fea, guidance, guiding_frame_index, ref_latent, disable_time_r=disable_time_r, control=control, transformer_options=transformer_options)
|
||||
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, guidance, guiding_frame_index, ref_latent, disable_time_r=disable_time_r, control=control, transformer_options=transformer_options)
|
||||
return out
|
||||
|
||||
@@ -1,122 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, VideoConv3d
|
||||
from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm
|
||||
import model_management
|
||||
import model_patcher
|
||||
|
||||
class SRResidualCausalBlock3D(nn.Module):
|
||||
def __init__(self, channels: int):
|
||||
super().__init__()
|
||||
self.block = nn.Sequential(
|
||||
VideoConv3d(channels, channels, kernel_size=3),
|
||||
nn.SiLU(inplace=True),
|
||||
VideoConv3d(channels, channels, kernel_size=3),
|
||||
nn.SiLU(inplace=True),
|
||||
VideoConv3d(channels, channels, kernel_size=3),
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x + self.block(x)
|
||||
|
||||
class SRModel3DV2(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
hidden_channels: int = 64,
|
||||
num_blocks: int = 6,
|
||||
global_residual: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_conv = VideoConv3d(in_channels, hidden_channels, kernel_size=3)
|
||||
self.blocks = nn.ModuleList([SRResidualCausalBlock3D(hidden_channels) for _ in range(num_blocks)])
|
||||
self.out_conv = VideoConv3d(hidden_channels, out_channels, kernel_size=3)
|
||||
self.global_residual = bool(global_residual)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
residual = x
|
||||
y = self.in_conv(x)
|
||||
for blk in self.blocks:
|
||||
y = blk(y)
|
||||
y = self.out_conv(y)
|
||||
if self.global_residual and (y.shape == residual.shape):
|
||||
y = y + residual
|
||||
return y
|
||||
|
||||
|
||||
class Upsampler(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
z_channels: int,
|
||||
out_channels: int,
|
||||
block_out_channels: tuple[int, ...],
|
||||
num_res_blocks: int = 2,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.block_out_channels = block_out_channels
|
||||
self.z_channels = z_channels
|
||||
|
||||
ch = block_out_channels[0]
|
||||
self.conv_in = VideoConv3d(z_channels, ch, kernel_size=3)
|
||||
|
||||
self.up = nn.ModuleList()
|
||||
|
||||
for i, tgt in enumerate(block_out_channels):
|
||||
stage = nn.Module()
|
||||
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
|
||||
out_channels=tgt,
|
||||
temb_channels=0,
|
||||
conv_shortcut=False,
|
||||
conv_op=VideoConv3d, norm_op=RMS_norm)
|
||||
for j in range(num_res_blocks + 1)])
|
||||
ch = tgt
|
||||
self.up.append(stage)
|
||||
|
||||
self.norm_out = RMS_norm(ch)
|
||||
self.conv_out = VideoConv3d(ch, out_channels, kernel_size=3)
|
||||
|
||||
def forward(self, z):
|
||||
"""
|
||||
Args:
|
||||
z: (B, C, T, H, W)
|
||||
target_shape: (H, W)
|
||||
"""
|
||||
# z to block_in
|
||||
repeats = self.block_out_channels[0] // (self.z_channels)
|
||||
x = self.conv_in(z) + z.repeat_interleave(repeats=repeats, dim=1)
|
||||
|
||||
# upsampling
|
||||
for stage in self.up:
|
||||
for blk in stage.block:
|
||||
x = blk(x)
|
||||
|
||||
out = self.conv_out(F.silu(self.norm_out(x)))
|
||||
return out
|
||||
|
||||
UPSAMPLERS = {
|
||||
"720p": SRModel3DV2,
|
||||
"1080p": Upsampler,
|
||||
}
|
||||
|
||||
class HunyuanVideo15SRModel():
|
||||
def __init__(self, model_type, config):
|
||||
self.load_device = model_management.vae_device()
|
||||
offload_device = model_management.vae_offload_device()
|
||||
self.dtype = model_management.vae_dtype(self.load_device)
|
||||
self.model_class = UPSAMPLERS.get(model_type)
|
||||
self.model = self.model_class(**config).eval()
|
||||
|
||||
self.patcher = model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||
|
||||
def load_sd(self, sd):
|
||||
return self.model.load_state_dict(sd, strict=True)
|
||||
|
||||
def get_sd(self):
|
||||
return self.model.state_dict()
|
||||
|
||||
def resample_latent(self, latent):
|
||||
model_management.load_model_gpu(self.patcher)
|
||||
return self.model(latent.to(self.load_device))
|
||||
@@ -1,13 +1,11 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, CarriedConv3d, Normalize, conv_carry_causal_3d, torch_cat_if_needed
|
||||
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d
|
||||
import comfy.ops
|
||||
import comfy.ldm.models.autoencoder
|
||||
import comfy.model_management
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
|
||||
class RMS_norm(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
@@ -16,25 +14,23 @@ class RMS_norm(nn.Module):
|
||||
self.gamma = nn.Parameter(torch.empty(shape))
|
||||
|
||||
def forward(self, x):
|
||||
return F.normalize(x, dim=1) * self.scale * comfy.model_management.cast_to(self.gamma, dtype=x.dtype, device=x.device)
|
||||
return F.normalize(x, dim=1) * self.scale * self.gamma
|
||||
|
||||
class DnSmpl(nn.Module):
|
||||
def __init__(self, ic, oc, tds, refiner_vae, op):
|
||||
def __init__(self, ic, oc, tds=True):
|
||||
super().__init__()
|
||||
fct = 2 * 2 * 2 if tds else 1 * 2 * 2
|
||||
assert oc % fct == 0
|
||||
self.conv = op(ic, oc // fct, kernel_size=3, stride=1, padding=1)
|
||||
self.refiner_vae = refiner_vae
|
||||
self.conv = VideoConv3d(ic, oc // fct, kernel_size=3)
|
||||
|
||||
self.tds = tds
|
||||
self.gs = fct * ic // oc
|
||||
|
||||
def forward(self, x, conv_carry_in=None, conv_carry_out=None):
|
||||
def forward(self, x):
|
||||
r1 = 2 if self.tds else 1
|
||||
h = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out)
|
||||
|
||||
if self.tds and self.refiner_vae and conv_carry_in is None:
|
||||
h = self.conv(x)
|
||||
|
||||
if self.tds:
|
||||
hf = h[:, :, :1, :, :]
|
||||
b, c, f, ht, wd = hf.shape
|
||||
hf = hf.reshape(b, c, f, ht // 2, 2, wd // 2, 2)
|
||||
@@ -42,7 +38,14 @@ class DnSmpl(nn.Module):
|
||||
hf = hf.reshape(b, 2 * 2 * c, f, ht // 2, wd // 2)
|
||||
hf = torch.cat([hf, hf], dim=1)
|
||||
|
||||
h = h[:, :, 1:, :, :]
|
||||
hn = h[:, :, 1:, :, :]
|
||||
b, c, frms, ht, wd = hn.shape
|
||||
nf = frms // r1
|
||||
hn = hn.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2)
|
||||
hn = hn.permute(0, 3, 5, 7, 1, 2, 4, 6)
|
||||
hn = hn.reshape(b, r1 * 2 * 2 * c, nf, ht // 2, wd // 2)
|
||||
|
||||
h = torch.cat([hf, hn], dim=2)
|
||||
|
||||
xf = x[:, :, :1, :, :]
|
||||
b, ci, f, ht, wd = xf.shape
|
||||
@@ -50,49 +53,49 @@ class DnSmpl(nn.Module):
|
||||
xf = xf.permute(0, 4, 6, 1, 2, 3, 5)
|
||||
xf = xf.reshape(b, 2 * 2 * ci, f, ht // 2, wd // 2)
|
||||
B, C, T, H, W = xf.shape
|
||||
xf = xf.view(B, hf.shape[1], self.gs // 2, T, H, W).mean(dim=2)
|
||||
xf = xf.view(B, h.shape[1], self.gs // 2, T, H, W).mean(dim=2)
|
||||
|
||||
x = x[:, :, 1:, :, :]
|
||||
xn = x[:, :, 1:, :, :]
|
||||
b, ci, frms, ht, wd = xn.shape
|
||||
nf = frms // r1
|
||||
xn = xn.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2)
|
||||
xn = xn.permute(0, 3, 5, 7, 1, 2, 4, 6)
|
||||
xn = xn.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2)
|
||||
B, C, T, H, W = xn.shape
|
||||
xn = xn.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2)
|
||||
sc = torch.cat([xf, xn], dim=2)
|
||||
else:
|
||||
b, c, frms, ht, wd = h.shape
|
||||
nf = frms // r1
|
||||
h = h.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2)
|
||||
h = h.permute(0, 3, 5, 7, 1, 2, 4, 6)
|
||||
h = h.reshape(b, r1 * 2 * 2 * c, nf, ht // 2, wd // 2)
|
||||
|
||||
if h.shape[2] == 0:
|
||||
return hf + xf
|
||||
b, ci, frms, ht, wd = x.shape
|
||||
nf = frms // r1
|
||||
sc = x.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2)
|
||||
sc = sc.permute(0, 3, 5, 7, 1, 2, 4, 6)
|
||||
sc = sc.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2)
|
||||
B, C, T, H, W = sc.shape
|
||||
sc = sc.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2)
|
||||
|
||||
b, c, frms, ht, wd = h.shape
|
||||
nf = frms // r1
|
||||
h = h.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2)
|
||||
h = h.permute(0, 3, 5, 7, 1, 2, 4, 6)
|
||||
h = h.reshape(b, r1 * 2 * 2 * c, nf, ht // 2, wd // 2)
|
||||
|
||||
b, ci, frms, ht, wd = x.shape
|
||||
nf = frms // r1
|
||||
x = x.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2)
|
||||
x = x.permute(0, 3, 5, 7, 1, 2, 4, 6)
|
||||
x = x.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2)
|
||||
B, C, T, H, W = x.shape
|
||||
x = x.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2)
|
||||
|
||||
if self.tds and self.refiner_vae and conv_carry_in is None:
|
||||
h = torch.cat([hf, h], dim=2)
|
||||
x = torch.cat([xf, x], dim=2)
|
||||
|
||||
return h + x
|
||||
return h + sc
|
||||
|
||||
|
||||
class UpSmpl(nn.Module):
|
||||
def __init__(self, ic, oc, tus, refiner_vae, op):
|
||||
def __init__(self, ic, oc, tus=True):
|
||||
super().__init__()
|
||||
fct = 2 * 2 * 2 if tus else 1 * 2 * 2
|
||||
self.conv = op(ic, oc * fct, kernel_size=3, stride=1, padding=1)
|
||||
self.refiner_vae = refiner_vae
|
||||
self.conv = VideoConv3d(ic, oc * fct, kernel_size=3)
|
||||
|
||||
self.tus = tus
|
||||
self.rp = fct * oc // ic
|
||||
|
||||
def forward(self, x, conv_carry_in=None, conv_carry_out=None):
|
||||
def forward(self, x):
|
||||
r1 = 2 if self.tus else 1
|
||||
h = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out)
|
||||
h = self.conv(x)
|
||||
|
||||
if self.tus and self.refiner_vae and conv_carry_in is None:
|
||||
if self.tus:
|
||||
hf = h[:, :, :1, :, :]
|
||||
b, c, f, ht, wd = hf.shape
|
||||
nc = c // (2 * 2)
|
||||
@@ -101,7 +104,14 @@ class UpSmpl(nn.Module):
|
||||
hf = hf.reshape(b, nc, f, ht * 2, wd * 2)
|
||||
hf = hf[:, : hf.shape[1] // 2]
|
||||
|
||||
h = h[:, :, 1:, :, :]
|
||||
hn = h[:, :, 1:, :, :]
|
||||
b, c, frms, ht, wd = hn.shape
|
||||
nc = c // (r1 * 2 * 2)
|
||||
hn = hn.reshape(b, r1, 2, 2, nc, frms, ht, wd)
|
||||
hn = hn.permute(0, 4, 5, 1, 6, 2, 7, 3)
|
||||
hn = hn.reshape(b, nc, frms * r1, ht * 2, wd * 2)
|
||||
|
||||
h = torch.cat([hf, hn], dim=2)
|
||||
|
||||
xf = x[:, :, :1, :, :]
|
||||
b, ci, f, ht, wd = xf.shape
|
||||
@@ -112,147 +122,109 @@ class UpSmpl(nn.Module):
|
||||
xf = xf.permute(0, 3, 4, 5, 1, 6, 2)
|
||||
xf = xf.reshape(b, nc, f, ht * 2, wd * 2)
|
||||
|
||||
x = x[:, :, 1:, :, :]
|
||||
xn = x[:, :, 1:, :, :]
|
||||
xn = xn.repeat_interleave(repeats=self.rp, dim=1)
|
||||
b, c, frms, ht, wd = xn.shape
|
||||
nc = c // (r1 * 2 * 2)
|
||||
xn = xn.reshape(b, r1, 2, 2, nc, frms, ht, wd)
|
||||
xn = xn.permute(0, 4, 5, 1, 6, 2, 7, 3)
|
||||
xn = xn.reshape(b, nc, frms * r1, ht * 2, wd * 2)
|
||||
sc = torch.cat([xf, xn], dim=2)
|
||||
else:
|
||||
b, c, frms, ht, wd = h.shape
|
||||
nc = c // (r1 * 2 * 2)
|
||||
h = h.reshape(b, r1, 2, 2, nc, frms, ht, wd)
|
||||
h = h.permute(0, 4, 5, 1, 6, 2, 7, 3)
|
||||
h = h.reshape(b, nc, frms * r1, ht * 2, wd * 2)
|
||||
|
||||
b, c, frms, ht, wd = h.shape
|
||||
nc = c // (r1 * 2 * 2)
|
||||
h = h.reshape(b, r1, 2, 2, nc, frms, ht, wd)
|
||||
h = h.permute(0, 4, 5, 1, 6, 2, 7, 3)
|
||||
h = h.reshape(b, nc, frms * r1, ht * 2, wd * 2)
|
||||
sc = x.repeat_interleave(repeats=self.rp, dim=1)
|
||||
b, c, frms, ht, wd = sc.shape
|
||||
nc = c // (r1 * 2 * 2)
|
||||
sc = sc.reshape(b, r1, 2, 2, nc, frms, ht, wd)
|
||||
sc = sc.permute(0, 4, 5, 1, 6, 2, 7, 3)
|
||||
sc = sc.reshape(b, nc, frms * r1, ht * 2, wd * 2)
|
||||
|
||||
x = x.repeat_interleave(repeats=self.rp, dim=1)
|
||||
b, c, frms, ht, wd = x.shape
|
||||
nc = c // (r1 * 2 * 2)
|
||||
x = x.reshape(b, r1, 2, 2, nc, frms, ht, wd)
|
||||
x = x.permute(0, 4, 5, 1, 6, 2, 7, 3)
|
||||
x = x.reshape(b, nc, frms * r1, ht * 2, wd * 2)
|
||||
|
||||
if self.tus and self.refiner_vae and conv_carry_in is None:
|
||||
h = torch.cat([hf, h], dim=2)
|
||||
x = torch.cat([xf, x], dim=2)
|
||||
|
||||
return h + x
|
||||
return h + sc
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks,
|
||||
ffactor_spatial, ffactor_temporal, downsample_match_channel=True, refiner_vae=True, **_):
|
||||
ffactor_spatial, ffactor_temporal, downsample_match_channel=True, **_):
|
||||
super().__init__()
|
||||
self.z_channels = z_channels
|
||||
self.block_out_channels = block_out_channels
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.ffactor_temporal = ffactor_temporal
|
||||
|
||||
self.refiner_vae = refiner_vae
|
||||
if self.refiner_vae:
|
||||
conv_op = CarriedConv3d
|
||||
norm_op = RMS_norm
|
||||
else:
|
||||
conv_op = ops.Conv3d
|
||||
norm_op = Normalize
|
||||
|
||||
self.conv_in = conv_op(in_channels, block_out_channels[0], 3, 1, 1)
|
||||
self.conv_in = VideoConv3d(in_channels, block_out_channels[0], 3, 1, 1)
|
||||
|
||||
self.down = nn.ModuleList()
|
||||
ch = block_out_channels[0]
|
||||
depth = (ffactor_spatial >> 1).bit_length()
|
||||
depth_temporal = ((ffactor_spatial // self.ffactor_temporal) >> 1).bit_length()
|
||||
depth_temporal = ((ffactor_spatial // ffactor_temporal) >> 1).bit_length()
|
||||
|
||||
for i, tgt in enumerate(block_out_channels):
|
||||
stage = nn.Module()
|
||||
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
|
||||
out_channels=tgt,
|
||||
temb_channels=0,
|
||||
conv_op=conv_op, norm_op=norm_op)
|
||||
conv_op=VideoConv3d, norm_op=RMS_norm)
|
||||
for j in range(num_res_blocks)])
|
||||
ch = tgt
|
||||
if i < depth:
|
||||
nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and downsample_match_channel else ch
|
||||
stage.downsample = DnSmpl(ch, nxt, tds=i >= depth_temporal, refiner_vae=self.refiner_vae, op=conv_op)
|
||||
stage.downsample = DnSmpl(ch, nxt, tds=i >= depth_temporal)
|
||||
ch = nxt
|
||||
self.down.append(stage)
|
||||
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
|
||||
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
|
||||
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
|
||||
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=RMS_norm)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
|
||||
|
||||
self.norm_out = norm_op(ch)
|
||||
self.conv_out = conv_op(ch, z_channels << 1, 3, 1, 1)
|
||||
self.norm_out = RMS_norm(ch)
|
||||
self.conv_out = VideoConv3d(ch, z_channels << 1, 3, 1, 1)
|
||||
|
||||
self.regul = comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer()
|
||||
|
||||
def forward(self, x):
|
||||
if not self.refiner_vae and x.shape[2] == 1:
|
||||
x = x.expand(-1, -1, self.ffactor_temporal, -1, -1)
|
||||
x = self.conv_in(x)
|
||||
|
||||
if self.refiner_vae:
|
||||
xl = [x[:, :, :1, :, :]]
|
||||
if x.shape[2] > self.ffactor_temporal:
|
||||
xl += torch.split(x[:, :, 1: 1 + ((x.shape[2] - 1) // self.ffactor_temporal) * self.ffactor_temporal, :, :], self.ffactor_temporal * 2, dim=2)
|
||||
x = xl
|
||||
else:
|
||||
x = [x]
|
||||
out = []
|
||||
for stage in self.down:
|
||||
for blk in stage.block:
|
||||
x = blk(x)
|
||||
if hasattr(stage, 'downsample'):
|
||||
x = stage.downsample(x)
|
||||
|
||||
conv_carry_in = None
|
||||
|
||||
for i, x1 in enumerate(x):
|
||||
conv_carry_out = []
|
||||
if i == len(x) - 1:
|
||||
conv_carry_out = None
|
||||
|
||||
x1 = [ x1 ]
|
||||
x1 = conv_carry_causal_3d(x1, self.conv_in, conv_carry_in, conv_carry_out)
|
||||
|
||||
for stage in self.down:
|
||||
for blk in stage.block:
|
||||
x1 = blk(x1, None, conv_carry_in, conv_carry_out)
|
||||
if hasattr(stage, 'downsample'):
|
||||
x1 = stage.downsample(x1, conv_carry_in, conv_carry_out)
|
||||
|
||||
out.append(x1)
|
||||
conv_carry_in = conv_carry_out
|
||||
|
||||
out = torch_cat_if_needed(out, dim=2)
|
||||
|
||||
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(out)))
|
||||
del out
|
||||
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
|
||||
|
||||
b, c, t, h, w = x.shape
|
||||
grp = c // (self.z_channels << 1)
|
||||
skip = x.view(b, c // grp, grp, t, h, w).mean(2)
|
||||
|
||||
out = conv_carry_causal_3d([F.silu(self.norm_out(x))], self.conv_out) + skip
|
||||
|
||||
if self.refiner_vae:
|
||||
out = self.regul(out)[0]
|
||||
out = self.conv_out(F.silu(self.norm_out(x))) + skip
|
||||
out = self.regul(out)[0]
|
||||
|
||||
out = torch.cat((out[:, :, :1], out), dim=2)
|
||||
out = out.permute(0, 2, 1, 3, 4)
|
||||
b, f_times_2, c, h, w = out.shape
|
||||
out = out.reshape(b, f_times_2 // 2, 2 * c, h, w)
|
||||
out = out.permute(0, 2, 1, 3, 4).contiguous()
|
||||
return out
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks,
|
||||
ffactor_spatial, ffactor_temporal, upsample_match_channel=True, refiner_vae=True, **_):
|
||||
ffactor_spatial, ffactor_temporal, upsample_match_channel=True, **_):
|
||||
super().__init__()
|
||||
block_out_channels = block_out_channels[::-1]
|
||||
self.z_channels = z_channels
|
||||
self.block_out_channels = block_out_channels
|
||||
self.num_res_blocks = num_res_blocks
|
||||
|
||||
self.refiner_vae = refiner_vae
|
||||
if self.refiner_vae:
|
||||
conv_op = CarriedConv3d
|
||||
norm_op = RMS_norm
|
||||
else:
|
||||
conv_op = ops.Conv3d
|
||||
norm_op = Normalize
|
||||
|
||||
ch = block_out_channels[0]
|
||||
self.conv_in = conv_op(z_channels, ch, kernel_size=3, stride=1, padding=1)
|
||||
self.conv_in = VideoConv3d(z_channels, ch, 3)
|
||||
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
|
||||
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
|
||||
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
|
||||
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=RMS_norm)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
|
||||
|
||||
self.up = nn.ModuleList()
|
||||
depth = (ffactor_spatial >> 1).bit_length()
|
||||
@@ -263,51 +235,33 @@ class Decoder(nn.Module):
|
||||
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
|
||||
out_channels=tgt,
|
||||
temb_channels=0,
|
||||
conv_op=conv_op, norm_op=norm_op)
|
||||
conv_op=VideoConv3d, norm_op=RMS_norm)
|
||||
for j in range(num_res_blocks + 1)])
|
||||
ch = tgt
|
||||
if i < depth:
|
||||
nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and upsample_match_channel else ch
|
||||
stage.upsample = UpSmpl(ch, nxt, tus=i < depth_temporal, refiner_vae=self.refiner_vae, op=conv_op)
|
||||
stage.upsample = UpSmpl(ch, nxt, tus=i < depth_temporal)
|
||||
ch = nxt
|
||||
self.up.append(stage)
|
||||
|
||||
self.norm_out = norm_op(ch)
|
||||
self.conv_out = conv_op(ch, out_channels, 3, stride=1, padding=1)
|
||||
self.norm_out = RMS_norm(ch)
|
||||
self.conv_out = VideoConv3d(ch, out_channels, 3)
|
||||
|
||||
def forward(self, z):
|
||||
x = conv_carry_causal_3d([z], self.conv_in) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)
|
||||
z = z.permute(0, 2, 1, 3, 4)
|
||||
b, f, c, h, w = z.shape
|
||||
z = z.reshape(b, f, 2, c // 2, h, w)
|
||||
z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w)
|
||||
z = z.permute(0, 2, 1, 3, 4)
|
||||
z = z[:, :, 1:]
|
||||
|
||||
x = self.conv_in(z) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)
|
||||
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
|
||||
|
||||
if self.refiner_vae:
|
||||
x = torch.split(x, 2, dim=2)
|
||||
else:
|
||||
x = [ x ]
|
||||
out = []
|
||||
|
||||
conv_carry_in = None
|
||||
|
||||
for i, x1 in enumerate(x):
|
||||
conv_carry_out = []
|
||||
if i == len(x) - 1:
|
||||
conv_carry_out = None
|
||||
for stage in self.up:
|
||||
for blk in stage.block:
|
||||
x1 = blk(x1, None, conv_carry_in, conv_carry_out)
|
||||
if hasattr(stage, 'upsample'):
|
||||
x1 = stage.upsample(x1, conv_carry_in, conv_carry_out)
|
||||
|
||||
x1 = [ F.silu(self.norm_out(x1)) ]
|
||||
x1 = conv_carry_causal_3d(x1, self.conv_out, conv_carry_in, conv_carry_out)
|
||||
out.append(x1)
|
||||
conv_carry_in = conv_carry_out
|
||||
del x
|
||||
|
||||
out = torch_cat_if_needed(out, dim=2)
|
||||
|
||||
if not self.refiner_vae:
|
||||
if z.shape[-3] == 1:
|
||||
out = out[:, :, -1:]
|
||||
|
||||
return out
|
||||
for stage in self.up:
|
||||
for blk in stage.block:
|
||||
x = blk(x)
|
||||
if hasattr(stage, 'upsample'):
|
||||
x = stage.upsample(x)
|
||||
|
||||
return self.conv_out(F.silu(self.norm_out(x)))
|
||||
|
||||
@@ -1,413 +0,0 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
import math
|
||||
|
||||
import comfy.ldm.common_dit
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
from comfy.ldm.flux.math import apply_rope1
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
|
||||
def attention(q, k, v, heads, transformer_options={}):
|
||||
return optimized_attention(
|
||||
q.transpose(1, 2),
|
||||
k.transpose(1, 2),
|
||||
v.transpose(1, 2),
|
||||
heads=heads,
|
||||
skip_reshape=True,
|
||||
transformer_options=transformer_options
|
||||
)
|
||||
|
||||
def apply_scale_shift_norm(norm, x, scale, shift):
|
||||
return torch.addcmul(shift, norm(x), scale + 1.0)
|
||||
|
||||
def apply_gate_sum(x, out, gate):
|
||||
return torch.addcmul(x, gate, out)
|
||||
|
||||
def get_shift_scale_gate(params):
|
||||
shift, scale, gate = torch.chunk(params, 3, dim=-1)
|
||||
return tuple(x.unsqueeze(1) for x in (shift, scale, gate))
|
||||
|
||||
def get_freqs(dim, max_period=10000.0):
|
||||
return torch.exp(-math.log(max_period) * torch.arange(start=0, end=dim, dtype=torch.float32) / dim)
|
||||
|
||||
|
||||
class TimeEmbeddings(nn.Module):
|
||||
def __init__(self, model_dim, time_dim, max_period=10000.0, operation_settings=None):
|
||||
super().__init__()
|
||||
assert model_dim % 2 == 0
|
||||
self.model_dim = model_dim
|
||||
self.max_period = max_period
|
||||
self.register_buffer("freqs", get_freqs(model_dim // 2, max_period), persistent=False)
|
||||
operations = operation_settings.get("operations")
|
||||
self.in_layer = operations.Linear(model_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.activation = nn.SiLU()
|
||||
self.out_layer = operations.Linear(time_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
|
||||
def forward(self, timestep, dtype):
|
||||
args = torch.outer(timestep, self.freqs.to(device=timestep.device))
|
||||
time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype)
|
||||
time_embed = self.out_layer(self.activation(self.in_layer(time_embed)))
|
||||
return time_embed
|
||||
|
||||
|
||||
class TextEmbeddings(nn.Module):
|
||||
def __init__(self, text_dim, model_dim, operation_settings=None):
|
||||
super().__init__()
|
||||
operations = operation_settings.get("operations")
|
||||
self.in_layer = operations.Linear(text_dim, model_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.norm = operations.LayerNorm(model_dim, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
|
||||
def forward(self, text_embed):
|
||||
text_embed = self.in_layer(text_embed)
|
||||
return self.norm(text_embed).type_as(text_embed)
|
||||
|
||||
|
||||
class VisualEmbeddings(nn.Module):
|
||||
def __init__(self, visual_dim, model_dim, patch_size, operation_settings=None):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
operations = operation_settings.get("operations")
|
||||
self.in_layer = operations.Linear(visual_dim, model_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
|
||||
def forward(self, x):
|
||||
x = x.movedim(1, -1) # B C T H W -> B T H W C
|
||||
B, T, H, W, dim = x.shape
|
||||
pt, ph, pw = self.patch_size
|
||||
|
||||
x = x.view(
|
||||
B,
|
||||
T // pt, pt,
|
||||
H // ph, ph,
|
||||
W // pw, pw,
|
||||
dim,
|
||||
).permute(0, 1, 3, 5, 2, 4, 6, 7).flatten(4, 7)
|
||||
|
||||
return self.in_layer(x)
|
||||
|
||||
|
||||
class Modulation(nn.Module):
|
||||
def __init__(self, time_dim, model_dim, num_params, operation_settings=None):
|
||||
super().__init__()
|
||||
self.activation = nn.SiLU()
|
||||
self.out_layer = operation_settings.get("operations").Linear(time_dim, num_params * model_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
|
||||
def forward(self, x):
|
||||
return self.out_layer(self.activation(x))
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
def __init__(self, num_channels, head_dim, operation_settings=None):
|
||||
super().__init__()
|
||||
assert num_channels % head_dim == 0
|
||||
self.num_heads = num_channels // head_dim
|
||||
self.head_dim = head_dim
|
||||
|
||||
operations = operation_settings.get("operations")
|
||||
self.to_query = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.to_key = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.to_value = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.query_norm = operations.RMSNorm(head_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.key_norm = operations.RMSNorm(head_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
|
||||
self.out_layer = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.num_chunks = 2
|
||||
|
||||
def _compute_qk(self, x, freqs, proj_fn, norm_fn):
|
||||
result = proj_fn(x).view(*x.shape[:-1], self.num_heads, -1)
|
||||
return apply_rope1(norm_fn(result), freqs)
|
||||
|
||||
def _forward(self, x, freqs, transformer_options={}):
|
||||
q = self._compute_qk(x, freqs, self.to_query, self.query_norm)
|
||||
k = self._compute_qk(x, freqs, self.to_key, self.key_norm)
|
||||
v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
|
||||
out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
|
||||
return self.out_layer(out)
|
||||
|
||||
def _forward_chunked(self, x, freqs, transformer_options={}):
|
||||
def process_chunks(proj_fn, norm_fn):
|
||||
x_chunks = torch.chunk(x, self.num_chunks, dim=1)
|
||||
freqs_chunks = torch.chunk(freqs, self.num_chunks, dim=1)
|
||||
chunks = []
|
||||
for x_chunk, freqs_chunk in zip(x_chunks, freqs_chunks):
|
||||
chunks.append(self._compute_qk(x_chunk, freqs_chunk, proj_fn, norm_fn))
|
||||
return torch.cat(chunks, dim=1)
|
||||
|
||||
q = process_chunks(self.to_query, self.query_norm)
|
||||
k = process_chunks(self.to_key, self.key_norm)
|
||||
v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
|
||||
out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
|
||||
return self.out_layer(out)
|
||||
|
||||
def forward(self, x, freqs, transformer_options={}):
|
||||
if x.shape[1] > 8192:
|
||||
return self._forward_chunked(x, freqs, transformer_options=transformer_options)
|
||||
else:
|
||||
return self._forward(x, freqs, transformer_options=transformer_options)
|
||||
|
||||
|
||||
class CrossAttention(SelfAttention):
|
||||
def get_qkv(self, x, context):
|
||||
q = self.to_query(x).view(*x.shape[:-1], self.num_heads, -1)
|
||||
k = self.to_key(context).view(*context.shape[:-1], self.num_heads, -1)
|
||||
v = self.to_value(context).view(*context.shape[:-1], self.num_heads, -1)
|
||||
return q, k, v
|
||||
|
||||
def forward(self, x, context, transformer_options={}):
|
||||
q, k, v = self.get_qkv(x, context)
|
||||
out = attention(self.query_norm(q), self.key_norm(k), v, self.num_heads, transformer_options=transformer_options)
|
||||
return self.out_layer(out)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, ff_dim, operation_settings=None):
|
||||
super().__init__()
|
||||
operations = operation_settings.get("operations")
|
||||
self.in_layer = operations.Linear(dim, ff_dim, bias=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.activation = nn.GELU()
|
||||
self.out_layer = operations.Linear(ff_dim, dim, bias=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.num_chunks = 4
|
||||
|
||||
def _forward(self, x):
|
||||
return self.out_layer(self.activation(self.in_layer(x)))
|
||||
|
||||
def _forward_chunked(self, x):
|
||||
chunks = torch.chunk(x, self.num_chunks, dim=1)
|
||||
output_chunks = []
|
||||
for chunk in chunks:
|
||||
output_chunks.append(self._forward(chunk))
|
||||
return torch.cat(output_chunks, dim=1)
|
||||
|
||||
def forward(self, x):
|
||||
if x.shape[1] > 8192:
|
||||
return self._forward_chunked(x)
|
||||
else:
|
||||
return self._forward(x)
|
||||
|
||||
|
||||
class OutLayer(nn.Module):
|
||||
def __init__(self, model_dim, time_dim, visual_dim, patch_size, operation_settings=None):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
self.modulation = Modulation(time_dim, model_dim, 2, operation_settings=operation_settings)
|
||||
operations = operation_settings.get("operations")
|
||||
self.norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.out_layer = operations.Linear(model_dim, math.prod(patch_size) * visual_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
|
||||
def forward(self, visual_embed, time_embed):
|
||||
B, T, H, W, _ = visual_embed.shape
|
||||
shift, scale = torch.chunk(self.modulation(time_embed), 2, dim=-1)
|
||||
scale = scale[:, None, None, None, :]
|
||||
shift = shift[:, None, None, None, :]
|
||||
visual_embed = apply_scale_shift_norm(self.norm, visual_embed, scale, shift)
|
||||
x = self.out_layer(visual_embed)
|
||||
|
||||
out_dim = x.shape[-1] // (self.patch_size[0] * self.patch_size[1] * self.patch_size[2])
|
||||
x = x.view(
|
||||
B, T, H, W,
|
||||
out_dim,
|
||||
self.patch_size[0], self.patch_size[1], self.patch_size[2]
|
||||
)
|
||||
return x.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(2, 3).flatten(3, 4).flatten(4, 5)
|
||||
|
||||
|
||||
class TransformerEncoderBlock(nn.Module):
|
||||
def __init__(self, model_dim, time_dim, ff_dim, head_dim, operation_settings=None):
|
||||
super().__init__()
|
||||
self.text_modulation = Modulation(time_dim, model_dim, 6, operation_settings=operation_settings)
|
||||
operations = operation_settings.get("operations")
|
||||
|
||||
self.self_attention_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.self_attention = SelfAttention(model_dim, head_dim, operation_settings=operation_settings)
|
||||
|
||||
self.feed_forward_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.feed_forward = FeedForward(model_dim, ff_dim, operation_settings=operation_settings)
|
||||
|
||||
def forward(self, x, time_embed, freqs, transformer_options={}):
|
||||
self_attn_params, ff_params = torch.chunk(self.text_modulation(time_embed), 2, dim=-1)
|
||||
shift, scale, gate = get_shift_scale_gate(self_attn_params)
|
||||
out = apply_scale_shift_norm(self.self_attention_norm, x, scale, shift)
|
||||
out = self.self_attention(out, freqs, transformer_options=transformer_options)
|
||||
x = apply_gate_sum(x, out, gate)
|
||||
|
||||
shift, scale, gate = get_shift_scale_gate(ff_params)
|
||||
out = apply_scale_shift_norm(self.feed_forward_norm, x, scale, shift)
|
||||
out = self.feed_forward(out)
|
||||
x = apply_gate_sum(x, out, gate)
|
||||
return x
|
||||
|
||||
|
||||
class TransformerDecoderBlock(nn.Module):
|
||||
def __init__(self, model_dim, time_dim, ff_dim, head_dim, operation_settings=None):
|
||||
super().__init__()
|
||||
self.visual_modulation = Modulation(time_dim, model_dim, 9, operation_settings=operation_settings)
|
||||
|
||||
operations = operation_settings.get("operations")
|
||||
self.self_attention_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.self_attention = SelfAttention(model_dim, head_dim, operation_settings=operation_settings)
|
||||
|
||||
self.cross_attention_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.cross_attention = CrossAttention(model_dim, head_dim, operation_settings=operation_settings)
|
||||
|
||||
self.feed_forward_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.feed_forward = FeedForward(model_dim, ff_dim, operation_settings=operation_settings)
|
||||
|
||||
def forward(self, visual_embed, text_embed, time_embed, freqs, transformer_options={}):
|
||||
self_attn_params, cross_attn_params, ff_params = torch.chunk(self.visual_modulation(time_embed), 3, dim=-1)
|
||||
# self attention
|
||||
shift, scale, gate = get_shift_scale_gate(self_attn_params)
|
||||
visual_out = apply_scale_shift_norm(self.self_attention_norm, visual_embed, scale, shift)
|
||||
visual_out = self.self_attention(visual_out, freqs, transformer_options=transformer_options)
|
||||
visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
|
||||
# cross attention
|
||||
shift, scale, gate = get_shift_scale_gate(cross_attn_params)
|
||||
visual_out = apply_scale_shift_norm(self.cross_attention_norm, visual_embed, scale, shift)
|
||||
visual_out = self.cross_attention(visual_out, text_embed, transformer_options=transformer_options)
|
||||
visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
|
||||
# feed forward
|
||||
shift, scale, gate = get_shift_scale_gate(ff_params)
|
||||
visual_out = apply_scale_shift_norm(self.feed_forward_norm, visual_embed, scale, shift)
|
||||
visual_out = self.feed_forward(visual_out)
|
||||
visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
|
||||
return visual_embed
|
||||
|
||||
|
||||
class Kandinsky5(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_visual_dim=16, out_visual_dim=16, in_text_dim=3584, in_text_dim2=768, time_dim=512,
|
||||
model_dim=1792, ff_dim=7168, visual_embed_dim=132, patch_size=(1, 2, 2), num_text_blocks=2, num_visual_blocks=32,
|
||||
axes_dims=(16, 24, 24), rope_scale_factor=(1.0, 2.0, 2.0),
|
||||
dtype=None, device=None, operations=None, **kwargs
|
||||
):
|
||||
super().__init__()
|
||||
head_dim = sum(axes_dims)
|
||||
self.rope_scale_factor = rope_scale_factor
|
||||
self.in_visual_dim = in_visual_dim
|
||||
self.model_dim = model_dim
|
||||
self.patch_size = patch_size
|
||||
self.visual_embed_dim = visual_embed_dim
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
|
||||
|
||||
self.time_embeddings = TimeEmbeddings(model_dim, time_dim, operation_settings=operation_settings)
|
||||
self.text_embeddings = TextEmbeddings(in_text_dim, model_dim, operation_settings=operation_settings)
|
||||
self.pooled_text_embeddings = TextEmbeddings(in_text_dim2, time_dim, operation_settings=operation_settings)
|
||||
self.visual_embeddings = VisualEmbeddings(visual_embed_dim, model_dim, patch_size, operation_settings=operation_settings)
|
||||
|
||||
self.text_transformer_blocks = nn.ModuleList(
|
||||
[TransformerEncoderBlock(model_dim, time_dim, ff_dim, head_dim, operation_settings=operation_settings) for _ in range(num_text_blocks)]
|
||||
)
|
||||
|
||||
self.visual_transformer_blocks = nn.ModuleList(
|
||||
[TransformerDecoderBlock(model_dim, time_dim, ff_dim, head_dim, operation_settings=operation_settings) for _ in range(num_visual_blocks)]
|
||||
)
|
||||
|
||||
self.out_layer = OutLayer(model_dim, time_dim, out_visual_dim, patch_size, operation_settings=operation_settings)
|
||||
|
||||
self.rope_embedder_3d = EmbedND(dim=head_dim, theta=10000.0, axes_dim=axes_dims)
|
||||
self.rope_embedder_1d = EmbedND(dim=head_dim, theta=10000.0, axes_dim=[head_dim])
|
||||
|
||||
def rope_encode_1d(self, seq_len, seq_start=0, steps=None, device=None, dtype=None, transformer_options={}):
|
||||
steps = seq_len if steps is None else steps
|
||||
seq_ids = torch.linspace(seq_start, seq_start + (seq_len - 1), steps=steps, device=device, dtype=dtype)
|
||||
seq_ids = seq_ids.reshape(-1, 1).unsqueeze(0) # Shape: (1, steps, 1)
|
||||
freqs = self.rope_embedder_1d(seq_ids).movedim(1, 2)
|
||||
return freqs
|
||||
|
||||
def rope_encode_3d(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}):
|
||||
|
||||
patch_size = self.patch_size
|
||||
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
|
||||
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
|
||||
w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
|
||||
|
||||
if steps_t is None:
|
||||
steps_t = t_len
|
||||
if steps_h is None:
|
||||
steps_h = h_len
|
||||
if steps_w is None:
|
||||
steps_w = w_len
|
||||
|
||||
h_start = 0
|
||||
w_start = 0
|
||||
rope_options = transformer_options.get("rope_options", None)
|
||||
if rope_options is not None:
|
||||
t_len = (t_len - 1.0) * rope_options.get("scale_t", 1.0) + 1.0
|
||||
h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
|
||||
w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
|
||||
|
||||
t_start += rope_options.get("shift_t", 0.0)
|
||||
h_start += rope_options.get("shift_y", 0.0)
|
||||
w_start += rope_options.get("shift_x", 0.0)
|
||||
else:
|
||||
rope_scale_factor = self.rope_scale_factor
|
||||
if self.model_dim == 4096: # pro video model uses different rope scaling at higher resolutions
|
||||
if h * w >= 14080:
|
||||
rope_scale_factor = (1.0, 3.16, 3.16)
|
||||
|
||||
t_len = (t_len - 1.0) / rope_scale_factor[0] + 1.0
|
||||
h_len = (h_len - 1.0) / rope_scale_factor[1] + 1.0
|
||||
w_len = (w_len - 1.0) / rope_scale_factor[2] + 1.0
|
||||
|
||||
img_ids = torch.zeros((steps_t, steps_h, steps_w, 3), device=device, dtype=dtype)
|
||||
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (t_len - 1), steps=steps_t, device=device, dtype=dtype).reshape(-1, 1, 1)
|
||||
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_start, h_start + (h_len - 1), steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1)
|
||||
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_start, w_start + (w_len - 1), steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1)
|
||||
img_ids = img_ids.reshape(1, -1, img_ids.shape[-1])
|
||||
|
||||
freqs = self.rope_embedder_3d(img_ids).movedim(1, 2)
|
||||
return freqs
|
||||
|
||||
def forward_orig(self, x, timestep, context, y, freqs, freqs_text, transformer_options={}, **kwargs):
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
context = self.text_embeddings(context)
|
||||
time_embed = self.time_embeddings(timestep, x.dtype) + self.pooled_text_embeddings(y)
|
||||
|
||||
for block in self.text_transformer_blocks:
|
||||
context = block(context, time_embed, freqs_text, transformer_options=transformer_options)
|
||||
|
||||
visual_embed = self.visual_embeddings(x)
|
||||
visual_shape = visual_embed.shape[:-1]
|
||||
visual_embed = visual_embed.flatten(1, -2)
|
||||
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
transformer_options["total_blocks"] = len(self.visual_transformer_blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.visual_transformer_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
return block(x=args["x"], context=args["context"], time_embed=args["time_embed"], freqs=args["freqs"], transformer_options=args.get("transformer_options"))
|
||||
visual_embed = blocks_replace[("double_block", i)]({"x": visual_embed, "context": context, "time_embed": time_embed, "freqs": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})["x"]
|
||||
else:
|
||||
visual_embed = block(visual_embed, context, time_embed, freqs=freqs, transformer_options=transformer_options)
|
||||
|
||||
visual_embed = visual_embed.reshape(*visual_shape, -1)
|
||||
return self.out_layer(visual_embed, time_embed)
|
||||
|
||||
def _forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs):
|
||||
original_dims = x.ndim
|
||||
if original_dims == 4:
|
||||
x = x.unsqueeze(2)
|
||||
bs, c, t_len, h, w = x.shape
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
|
||||
|
||||
if time_dim_replace is not None:
|
||||
time_dim_replace = comfy.ldm.common_dit.pad_to_patch_size(time_dim_replace, self.patch_size)
|
||||
x[:, :time_dim_replace.shape[1], :time_dim_replace.shape[2]] = time_dim_replace
|
||||
|
||||
freqs = self.rope_encode_3d(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options)
|
||||
freqs_text = self.rope_encode_1d(context.shape[1], device=x.device, dtype=x.dtype, transformer_options=transformer_options)
|
||||
|
||||
out = self.forward_orig(x, timestep, context, y, freqs, freqs_text, transformer_options=transformer_options, **kwargs)
|
||||
if original_dims == 4:
|
||||
out = out.squeeze(2)
|
||||
return out
|
||||
|
||||
def forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timestep, context, y, time_dim_replace=time_dim_replace, transformer_options=transformer_options, **kwargs)
|
||||
@@ -1,837 +0,0 @@
|
||||
from typing import Tuple
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from comfy.ldm.lightricks.model import (
|
||||
CrossAttention,
|
||||
FeedForward,
|
||||
AdaLayerNormSingle,
|
||||
PixArtAlphaTextProjection,
|
||||
LTXVModel,
|
||||
)
|
||||
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
class BasicAVTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
v_dim,
|
||||
a_dim,
|
||||
v_heads,
|
||||
a_heads,
|
||||
vd_head,
|
||||
ad_head,
|
||||
v_context_dim=None,
|
||||
a_context_dim=None,
|
||||
attn_precision=None,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.attn_precision = attn_precision
|
||||
|
||||
self.attn1 = CrossAttention(
|
||||
query_dim=v_dim,
|
||||
heads=v_heads,
|
||||
dim_head=vd_head,
|
||||
context_dim=None,
|
||||
attn_precision=self.attn_precision,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
self.audio_attn1 = CrossAttention(
|
||||
query_dim=a_dim,
|
||||
heads=a_heads,
|
||||
dim_head=ad_head,
|
||||
context_dim=None,
|
||||
attn_precision=self.attn_precision,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
self.attn2 = CrossAttention(
|
||||
query_dim=v_dim,
|
||||
context_dim=v_context_dim,
|
||||
heads=v_heads,
|
||||
dim_head=vd_head,
|
||||
attn_precision=self.attn_precision,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
self.audio_attn2 = CrossAttention(
|
||||
query_dim=a_dim,
|
||||
context_dim=a_context_dim,
|
||||
heads=a_heads,
|
||||
dim_head=ad_head,
|
||||
attn_precision=self.attn_precision,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
# Q: Video, K,V: Audio
|
||||
self.audio_to_video_attn = CrossAttention(
|
||||
query_dim=v_dim,
|
||||
context_dim=a_dim,
|
||||
heads=a_heads,
|
||||
dim_head=ad_head,
|
||||
attn_precision=self.attn_precision,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
# Q: Audio, K,V: Video
|
||||
self.video_to_audio_attn = CrossAttention(
|
||||
query_dim=a_dim,
|
||||
context_dim=v_dim,
|
||||
heads=a_heads,
|
||||
dim_head=ad_head,
|
||||
attn_precision=self.attn_precision,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
self.ff = FeedForward(
|
||||
v_dim, dim_out=v_dim, glu=True, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
self.audio_ff = FeedForward(
|
||||
a_dim, dim_out=a_dim, glu=True, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
self.scale_shift_table = nn.Parameter(torch.empty(6, v_dim, device=device, dtype=dtype))
|
||||
self.audio_scale_shift_table = nn.Parameter(
|
||||
torch.empty(6, a_dim, device=device, dtype=dtype)
|
||||
)
|
||||
|
||||
self.scale_shift_table_a2v_ca_audio = nn.Parameter(
|
||||
torch.empty(5, a_dim, device=device, dtype=dtype)
|
||||
)
|
||||
self.scale_shift_table_a2v_ca_video = nn.Parameter(
|
||||
torch.empty(5, v_dim, device=device, dtype=dtype)
|
||||
)
|
||||
|
||||
def get_ada_values(
|
||||
self, scale_shift_table: torch.Tensor, batch_size: int, timestep: torch.Tensor, indices: slice = slice(None, None)
|
||||
):
|
||||
num_ada_params = scale_shift_table.shape[0]
|
||||
|
||||
ada_values = (
|
||||
scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(device=timestep.device, dtype=timestep.dtype)
|
||||
+ timestep.reshape(batch_size, timestep.shape[1], num_ada_params, -1)[:, :, indices, :]
|
||||
).unbind(dim=2)
|
||||
return ada_values
|
||||
|
||||
def get_av_ca_ada_values(
|
||||
self,
|
||||
scale_shift_table: torch.Tensor,
|
||||
batch_size: int,
|
||||
scale_shift_timestep: torch.Tensor,
|
||||
gate_timestep: torch.Tensor,
|
||||
num_scale_shift_values: int = 4,
|
||||
):
|
||||
scale_shift_ada_values = self.get_ada_values(
|
||||
scale_shift_table[:num_scale_shift_values, :],
|
||||
batch_size,
|
||||
scale_shift_timestep,
|
||||
)
|
||||
gate_ada_values = self.get_ada_values(
|
||||
scale_shift_table[num_scale_shift_values:, :],
|
||||
batch_size,
|
||||
gate_timestep,
|
||||
)
|
||||
|
||||
scale_shift_chunks = [t.squeeze(2) for t in scale_shift_ada_values]
|
||||
gate_ada_values = [t.squeeze(2) for t in gate_ada_values]
|
||||
|
||||
return (*scale_shift_chunks, *gate_ada_values)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: Tuple[torch.Tensor, torch.Tensor],
|
||||
v_context=None,
|
||||
a_context=None,
|
||||
attention_mask=None,
|
||||
v_timestep=None,
|
||||
a_timestep=None,
|
||||
v_pe=None,
|
||||
a_pe=None,
|
||||
v_cross_pe=None,
|
||||
a_cross_pe=None,
|
||||
v_cross_scale_shift_timestep=None,
|
||||
a_cross_scale_shift_timestep=None,
|
||||
v_cross_gate_timestep=None,
|
||||
a_cross_gate_timestep=None,
|
||||
transformer_options=None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
run_vx = transformer_options.get("run_vx", True)
|
||||
run_ax = transformer_options.get("run_ax", True)
|
||||
|
||||
vx, ax = x
|
||||
run_ax = run_ax and ax.numel() > 0
|
||||
run_a2v = run_vx and transformer_options.get("a2v_cross_attn", True) and ax.numel() > 0
|
||||
run_v2a = run_ax and transformer_options.get("v2a_cross_attn", True)
|
||||
|
||||
if run_vx:
|
||||
vshift_msa, vscale_msa, vgate_msa = (
|
||||
self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 3))
|
||||
)
|
||||
|
||||
norm_vx = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_msa) + vshift_msa
|
||||
vx += self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options) * vgate_msa
|
||||
vx += self.attn2(
|
||||
comfy.ldm.common_dit.rms_norm(vx),
|
||||
context=v_context,
|
||||
mask=attention_mask,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
del vshift_msa, vscale_msa, vgate_msa
|
||||
|
||||
if run_ax:
|
||||
ashift_msa, ascale_msa, agate_msa = (
|
||||
self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(0, 3))
|
||||
)
|
||||
|
||||
norm_ax = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_msa) + ashift_msa
|
||||
ax += (
|
||||
self.audio_attn1(norm_ax, pe=a_pe, transformer_options=transformer_options)
|
||||
* agate_msa
|
||||
)
|
||||
ax += self.audio_attn2(
|
||||
comfy.ldm.common_dit.rms_norm(ax),
|
||||
context=a_context,
|
||||
mask=attention_mask,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
del ashift_msa, ascale_msa, agate_msa
|
||||
|
||||
# Audio - Video cross attention.
|
||||
if run_a2v or run_v2a:
|
||||
# norm3
|
||||
vx_norm3 = comfy.ldm.common_dit.rms_norm(vx)
|
||||
ax_norm3 = comfy.ldm.common_dit.rms_norm(ax)
|
||||
|
||||
(
|
||||
scale_ca_audio_hidden_states_a2v,
|
||||
shift_ca_audio_hidden_states_a2v,
|
||||
scale_ca_audio_hidden_states_v2a,
|
||||
shift_ca_audio_hidden_states_v2a,
|
||||
gate_out_v2a,
|
||||
) = self.get_av_ca_ada_values(
|
||||
self.scale_shift_table_a2v_ca_audio,
|
||||
ax.shape[0],
|
||||
a_cross_scale_shift_timestep,
|
||||
a_cross_gate_timestep,
|
||||
)
|
||||
|
||||
(
|
||||
scale_ca_video_hidden_states_a2v,
|
||||
shift_ca_video_hidden_states_a2v,
|
||||
scale_ca_video_hidden_states_v2a,
|
||||
shift_ca_video_hidden_states_v2a,
|
||||
gate_out_a2v,
|
||||
) = self.get_av_ca_ada_values(
|
||||
self.scale_shift_table_a2v_ca_video,
|
||||
vx.shape[0],
|
||||
v_cross_scale_shift_timestep,
|
||||
v_cross_gate_timestep,
|
||||
)
|
||||
|
||||
if run_a2v:
|
||||
vx_scaled = (
|
||||
vx_norm3 * (1 + scale_ca_video_hidden_states_a2v)
|
||||
+ shift_ca_video_hidden_states_a2v
|
||||
)
|
||||
ax_scaled = (
|
||||
ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v)
|
||||
+ shift_ca_audio_hidden_states_a2v
|
||||
)
|
||||
vx += (
|
||||
self.audio_to_video_attn(
|
||||
vx_scaled,
|
||||
context=ax_scaled,
|
||||
pe=v_cross_pe,
|
||||
k_pe=a_cross_pe,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
* gate_out_a2v
|
||||
)
|
||||
|
||||
del gate_out_a2v
|
||||
del scale_ca_video_hidden_states_a2v,\
|
||||
shift_ca_video_hidden_states_a2v,\
|
||||
scale_ca_audio_hidden_states_a2v,\
|
||||
shift_ca_audio_hidden_states_a2v,\
|
||||
|
||||
if run_v2a:
|
||||
ax_scaled = (
|
||||
ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a)
|
||||
+ shift_ca_audio_hidden_states_v2a
|
||||
)
|
||||
vx_scaled = (
|
||||
vx_norm3 * (1 + scale_ca_video_hidden_states_v2a)
|
||||
+ shift_ca_video_hidden_states_v2a
|
||||
)
|
||||
ax += (
|
||||
self.video_to_audio_attn(
|
||||
ax_scaled,
|
||||
context=vx_scaled,
|
||||
pe=a_cross_pe,
|
||||
k_pe=v_cross_pe,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
* gate_out_v2a
|
||||
)
|
||||
|
||||
del gate_out_v2a
|
||||
del scale_ca_video_hidden_states_v2a,\
|
||||
shift_ca_video_hidden_states_v2a,\
|
||||
scale_ca_audio_hidden_states_v2a,\
|
||||
shift_ca_audio_hidden_states_v2a
|
||||
|
||||
if run_vx:
|
||||
vshift_mlp, vscale_mlp, vgate_mlp = (
|
||||
self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(3, None))
|
||||
)
|
||||
|
||||
vx_scaled = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_mlp) + vshift_mlp
|
||||
vx += self.ff(vx_scaled) * vgate_mlp
|
||||
del vshift_mlp, vscale_mlp, vgate_mlp
|
||||
|
||||
if run_ax:
|
||||
ashift_mlp, ascale_mlp, agate_mlp = (
|
||||
self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(3, None))
|
||||
)
|
||||
|
||||
ax_scaled = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_mlp) + ashift_mlp
|
||||
ax += self.audio_ff(ax_scaled) * agate_mlp
|
||||
|
||||
del ashift_mlp, ascale_mlp, agate_mlp
|
||||
|
||||
|
||||
return vx, ax
|
||||
|
||||
|
||||
class LTXAVModel(LTXVModel):
|
||||
"""LTXAV model for audio-video generation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=128,
|
||||
audio_in_channels=128,
|
||||
cross_attention_dim=4096,
|
||||
audio_cross_attention_dim=2048,
|
||||
attention_head_dim=128,
|
||||
audio_attention_head_dim=64,
|
||||
num_attention_heads=32,
|
||||
audio_num_attention_heads=32,
|
||||
caption_channels=3840,
|
||||
num_layers=48,
|
||||
positional_embedding_theta=10000.0,
|
||||
positional_embedding_max_pos=[20, 2048, 2048],
|
||||
audio_positional_embedding_max_pos=[20],
|
||||
causal_temporal_positioning=False,
|
||||
vae_scale_factors=(8, 32, 32),
|
||||
use_middle_indices_grid=False,
|
||||
timestep_scale_multiplier=1000.0,
|
||||
av_ca_timestep_scale_multiplier=1.0,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
**kwargs,
|
||||
):
|
||||
# Store audio-specific parameters
|
||||
self.audio_in_channels = audio_in_channels
|
||||
self.audio_cross_attention_dim = audio_cross_attention_dim
|
||||
self.audio_attention_head_dim = audio_attention_head_dim
|
||||
self.audio_num_attention_heads = audio_num_attention_heads
|
||||
self.audio_positional_embedding_max_pos = audio_positional_embedding_max_pos
|
||||
|
||||
# Calculate audio dimensions
|
||||
self.audio_inner_dim = audio_num_attention_heads * audio_attention_head_dim
|
||||
self.audio_out_channels = audio_in_channels
|
||||
|
||||
# Audio-specific constants
|
||||
self.num_audio_channels = 8
|
||||
self.audio_frequency_bins = 16
|
||||
|
||||
self.av_ca_timestep_scale_multiplier = av_ca_timestep_scale_multiplier
|
||||
|
||||
super().__init__(
|
||||
in_channels=in_channels,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attention_head_dim=attention_head_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
caption_channels=caption_channels,
|
||||
num_layers=num_layers,
|
||||
positional_embedding_theta=positional_embedding_theta,
|
||||
positional_embedding_max_pos=positional_embedding_max_pos,
|
||||
causal_temporal_positioning=causal_temporal_positioning,
|
||||
vae_scale_factors=vae_scale_factors,
|
||||
use_middle_indices_grid=use_middle_indices_grid,
|
||||
timestep_scale_multiplier=timestep_scale_multiplier,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _init_model_components(self, device, dtype, **kwargs):
|
||||
"""Initialize LTXAV-specific components."""
|
||||
# Audio-specific projections
|
||||
self.audio_patchify_proj = self.operations.Linear(
|
||||
self.audio_in_channels, self.audio_inner_dim, bias=True, dtype=dtype, device=device
|
||||
)
|
||||
|
||||
# Audio-specific AdaLN
|
||||
self.audio_adaln_single = AdaLayerNormSingle(
|
||||
self.audio_inner_dim,
|
||||
use_additional_conditions=False,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
|
||||
num_scale_shift_values = 4
|
||||
self.av_ca_video_scale_shift_adaln_single = AdaLayerNormSingle(
|
||||
self.inner_dim,
|
||||
use_additional_conditions=False,
|
||||
embedding_coefficient=num_scale_shift_values,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
self.av_ca_a2v_gate_adaln_single = AdaLayerNormSingle(
|
||||
self.inner_dim,
|
||||
use_additional_conditions=False,
|
||||
embedding_coefficient=1,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
self.av_ca_audio_scale_shift_adaln_single = AdaLayerNormSingle(
|
||||
self.audio_inner_dim,
|
||||
use_additional_conditions=False,
|
||||
embedding_coefficient=num_scale_shift_values,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
self.av_ca_v2a_gate_adaln_single = AdaLayerNormSingle(
|
||||
self.audio_inner_dim,
|
||||
use_additional_conditions=False,
|
||||
embedding_coefficient=1,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
|
||||
# Audio caption projection
|
||||
self.audio_caption_projection = PixArtAlphaTextProjection(
|
||||
in_features=self.caption_channels,
|
||||
hidden_size=self.audio_inner_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
|
||||
def _init_transformer_blocks(self, device, dtype, **kwargs):
|
||||
"""Initialize transformer blocks for LTXAV."""
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicAVTransformerBlock(
|
||||
v_dim=self.inner_dim,
|
||||
a_dim=self.audio_inner_dim,
|
||||
v_heads=self.num_attention_heads,
|
||||
a_heads=self.audio_num_attention_heads,
|
||||
vd_head=self.attention_head_dim,
|
||||
ad_head=self.audio_attention_head_dim,
|
||||
v_context_dim=self.cross_attention_dim,
|
||||
a_context_dim=self.audio_cross_attention_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
for _ in range(self.num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
def _init_output_components(self, device, dtype):
|
||||
"""Initialize output components for LTXAV."""
|
||||
# Video output components
|
||||
super()._init_output_components(device, dtype)
|
||||
# Audio output components
|
||||
self.audio_scale_shift_table = nn.Parameter(
|
||||
torch.empty(2, self.audio_inner_dim, dtype=dtype, device=device)
|
||||
)
|
||||
self.audio_norm_out = self.operations.LayerNorm(
|
||||
self.audio_inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device
|
||||
)
|
||||
self.audio_proj_out = self.operations.Linear(
|
||||
self.audio_inner_dim, self.audio_out_channels, dtype=dtype, device=device
|
||||
)
|
||||
self.a_patchifier = AudioPatchifier(1, start_end=True)
|
||||
|
||||
def separate_audio_and_video_latents(self, x, audio_length):
|
||||
"""Separate audio and video latents from combined input."""
|
||||
# vx = x[:, : self.in_channels]
|
||||
# ax = x[:, self.in_channels :]
|
||||
#
|
||||
# ax = ax.reshape(ax.shape[0], -1)
|
||||
# ax = ax[:, : audio_length * self.num_audio_channels * self.audio_frequency_bins]
|
||||
#
|
||||
# ax = ax.reshape(
|
||||
# ax.shape[0], self.num_audio_channels, audio_length, self.audio_frequency_bins
|
||||
# )
|
||||
|
||||
vx = x[0]
|
||||
ax = x[1] if len(x) > 1 else torch.zeros(
|
||||
(vx.shape[0], self.num_audio_channels, 0, self.audio_frequency_bins),
|
||||
device=vx.device, dtype=vx.dtype
|
||||
)
|
||||
return vx, ax
|
||||
|
||||
def recombine_audio_and_video_latents(self, vx, ax, target_shape=None):
|
||||
if ax.numel() == 0:
|
||||
return vx
|
||||
else:
|
||||
return [vx, ax]
|
||||
"""Recombine audio and video latents for output."""
|
||||
# if ax.device != vx.device or ax.dtype != vx.dtype:
|
||||
# logging.warning("Audio and video latents are on different devices or dtypes.")
|
||||
# ax = ax.to(device=vx.device, dtype=vx.dtype)
|
||||
# logging.warning(f"Audio audio latent moved to device: {ax.device}, dtype: {ax.dtype}")
|
||||
#
|
||||
# ax = ax.reshape(ax.shape[0], -1)
|
||||
# # pad to f x h x w of the video latents
|
||||
# divisor = vx.shape[-1] * vx.shape[-2] * vx.shape[-3]
|
||||
# if target_shape is None:
|
||||
# repetitions = math.ceil(ax.shape[-1] / divisor)
|
||||
# else:
|
||||
# repetitions = target_shape[1] - vx.shape[1]
|
||||
# padded_len = repetitions * divisor
|
||||
# ax = F.pad(ax, (0, padded_len - ax.shape[-1]))
|
||||
# ax = ax.reshape(ax.shape[0], -1, vx.shape[-3], vx.shape[-2], vx.shape[-1])
|
||||
# return torch.cat([vx, ax], dim=1)
|
||||
|
||||
def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs):
|
||||
"""Process input for LTXAV - separate audio and video, then patchify."""
|
||||
audio_length = kwargs.get("audio_length", 0)
|
||||
# Separate audio and video latents
|
||||
vx, ax = self.separate_audio_and_video_latents(x, audio_length)
|
||||
[vx, v_pixel_coords, additional_args] = super()._process_input(
|
||||
vx, keyframe_idxs, denoise_mask, **kwargs
|
||||
)
|
||||
|
||||
ax, a_latent_coords = self.a_patchifier.patchify(ax)
|
||||
ax = self.audio_patchify_proj(ax)
|
||||
|
||||
# additional_args.update({"av_orig_shape": list(x.shape)})
|
||||
return [vx, ax], [v_pixel_coords, a_latent_coords], additional_args
|
||||
|
||||
def _prepare_timestep(self, timestep, batch_size, hidden_dtype, **kwargs):
|
||||
"""Prepare timestep embeddings."""
|
||||
# TODO: some code reuse is needed here.
|
||||
grid_mask = kwargs.get("grid_mask", None)
|
||||
if grid_mask is not None:
|
||||
timestep = timestep[:, grid_mask]
|
||||
|
||||
timestep = timestep * self.timestep_scale_multiplier
|
||||
v_timestep, v_embedded_timestep = self.adaln_single(
|
||||
timestep.flatten(),
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
|
||||
# Second dimension is 1 or number of tokens (if timestep_per_token)
|
||||
v_timestep = v_timestep.view(batch_size, -1, v_timestep.shape[-1])
|
||||
v_embedded_timestep = v_embedded_timestep.view(
|
||||
batch_size, -1, v_embedded_timestep.shape[-1]
|
||||
)
|
||||
|
||||
# Prepare audio timestep
|
||||
a_timestep = kwargs.get("a_timestep")
|
||||
if a_timestep is not None:
|
||||
a_timestep = a_timestep * self.timestep_scale_multiplier
|
||||
av_ca_factor = self.av_ca_timestep_scale_multiplier / self.timestep_scale_multiplier
|
||||
|
||||
av_ca_audio_scale_shift_timestep, _ = self.av_ca_audio_scale_shift_adaln_single(
|
||||
a_timestep.flatten(),
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
av_ca_video_scale_shift_timestep, _ = self.av_ca_video_scale_shift_adaln_single(
|
||||
timestep.flatten(),
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
av_ca_a2v_gate_noise_timestep, _ = self.av_ca_a2v_gate_adaln_single(
|
||||
timestep.flatten() * av_ca_factor,
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
av_ca_v2a_gate_noise_timestep, _ = self.av_ca_v2a_gate_adaln_single(
|
||||
a_timestep.flatten() * av_ca_factor,
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
|
||||
a_timestep, a_embedded_timestep = self.audio_adaln_single(
|
||||
a_timestep.flatten(),
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
a_timestep = a_timestep.view(batch_size, -1, a_timestep.shape[-1])
|
||||
a_embedded_timestep = a_embedded_timestep.view(
|
||||
batch_size, -1, a_embedded_timestep.shape[-1]
|
||||
)
|
||||
cross_av_timestep_ss = [
|
||||
av_ca_audio_scale_shift_timestep,
|
||||
av_ca_video_scale_shift_timestep,
|
||||
av_ca_a2v_gate_noise_timestep,
|
||||
av_ca_v2a_gate_noise_timestep,
|
||||
]
|
||||
cross_av_timestep_ss = list(
|
||||
[t.view(batch_size, -1, t.shape[-1]) for t in cross_av_timestep_ss]
|
||||
)
|
||||
else:
|
||||
a_timestep = timestep
|
||||
a_embedded_timestep = kwargs.get("embedded_timestep")
|
||||
cross_av_timestep_ss = []
|
||||
|
||||
return [v_timestep, a_timestep, cross_av_timestep_ss], [
|
||||
v_embedded_timestep,
|
||||
a_embedded_timestep,
|
||||
]
|
||||
|
||||
def _prepare_context(self, context, batch_size, x, attention_mask=None):
|
||||
vx = x[0]
|
||||
ax = x[1]
|
||||
v_context, a_context = torch.split(
|
||||
context, int(context.shape[-1] / 2), len(context.shape) - 1
|
||||
)
|
||||
|
||||
v_context, attention_mask = super()._prepare_context(
|
||||
v_context, batch_size, vx, attention_mask
|
||||
)
|
||||
if self.audio_caption_projection is not None:
|
||||
a_context = self.audio_caption_projection(a_context)
|
||||
a_context = a_context.view(batch_size, -1, ax.shape[-1])
|
||||
|
||||
return [v_context, a_context], attention_mask
|
||||
|
||||
def _prepare_positional_embeddings(self, pixel_coords, frame_rate, x_dtype):
|
||||
v_pixel_coords = pixel_coords[0]
|
||||
v_pe = super()._prepare_positional_embeddings(v_pixel_coords, frame_rate, x_dtype)
|
||||
|
||||
a_latent_coords = pixel_coords[1]
|
||||
a_pe = self._precompute_freqs_cis(
|
||||
a_latent_coords,
|
||||
dim=self.audio_inner_dim,
|
||||
out_dtype=x_dtype,
|
||||
max_pos=self.audio_positional_embedding_max_pos,
|
||||
use_middle_indices_grid=self.use_middle_indices_grid,
|
||||
num_attention_heads=self.audio_num_attention_heads,
|
||||
)
|
||||
|
||||
# calculate positional embeddings for the middle of the token duration, to use in av cross attention layers.
|
||||
max_pos = max(
|
||||
self.positional_embedding_max_pos[0], self.audio_positional_embedding_max_pos[0]
|
||||
)
|
||||
v_pixel_coords = v_pixel_coords.to(torch.float32)
|
||||
v_pixel_coords[:, 0] = v_pixel_coords[:, 0] * (1.0 / frame_rate)
|
||||
av_cross_video_freq_cis = self._precompute_freqs_cis(
|
||||
v_pixel_coords[:, 0:1, :],
|
||||
dim=self.audio_cross_attention_dim,
|
||||
out_dtype=x_dtype,
|
||||
max_pos=[max_pos],
|
||||
use_middle_indices_grid=True,
|
||||
num_attention_heads=self.audio_num_attention_heads,
|
||||
)
|
||||
av_cross_audio_freq_cis = self._precompute_freqs_cis(
|
||||
a_latent_coords[:, 0:1, :],
|
||||
dim=self.audio_cross_attention_dim,
|
||||
out_dtype=x_dtype,
|
||||
max_pos=[max_pos],
|
||||
use_middle_indices_grid=True,
|
||||
num_attention_heads=self.audio_num_attention_heads,
|
||||
)
|
||||
|
||||
return [(v_pe, av_cross_video_freq_cis), (a_pe, av_cross_audio_freq_cis)]
|
||||
|
||||
def _process_transformer_blocks(
|
||||
self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs
|
||||
):
|
||||
vx = x[0]
|
||||
ax = x[1]
|
||||
v_context = context[0]
|
||||
a_context = context[1]
|
||||
v_timestep = timestep[0]
|
||||
a_timestep = timestep[1]
|
||||
v_pe, av_cross_video_freq_cis = pe[0]
|
||||
a_pe, av_cross_audio_freq_cis = pe[1]
|
||||
|
||||
(
|
||||
av_ca_audio_scale_shift_timestep,
|
||||
av_ca_video_scale_shift_timestep,
|
||||
av_ca_a2v_gate_noise_timestep,
|
||||
av_ca_v2a_gate_noise_timestep,
|
||||
) = timestep[2]
|
||||
|
||||
"""Process transformer blocks for LTXAV."""
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
|
||||
# Process transformer blocks
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(
|
||||
args["img"],
|
||||
v_context=args["v_context"],
|
||||
a_context=args["a_context"],
|
||||
attention_mask=args["attention_mask"],
|
||||
v_timestep=args["v_timestep"],
|
||||
a_timestep=args["a_timestep"],
|
||||
v_pe=args["v_pe"],
|
||||
a_pe=args["a_pe"],
|
||||
v_cross_pe=args["v_cross_pe"],
|
||||
a_cross_pe=args["a_cross_pe"],
|
||||
v_cross_scale_shift_timestep=args["v_cross_scale_shift_timestep"],
|
||||
a_cross_scale_shift_timestep=args["a_cross_scale_shift_timestep"],
|
||||
v_cross_gate_timestep=args["v_cross_gate_timestep"],
|
||||
a_cross_gate_timestep=args["a_cross_gate_timestep"],
|
||||
transformer_options=args["transformer_options"],
|
||||
)
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)](
|
||||
{
|
||||
"img": (vx, ax),
|
||||
"v_context": v_context,
|
||||
"a_context": a_context,
|
||||
"attention_mask": attention_mask,
|
||||
"v_timestep": v_timestep,
|
||||
"a_timestep": a_timestep,
|
||||
"v_pe": v_pe,
|
||||
"a_pe": a_pe,
|
||||
"v_cross_pe": av_cross_video_freq_cis,
|
||||
"a_cross_pe": av_cross_audio_freq_cis,
|
||||
"v_cross_scale_shift_timestep": av_ca_video_scale_shift_timestep,
|
||||
"a_cross_scale_shift_timestep": av_ca_audio_scale_shift_timestep,
|
||||
"v_cross_gate_timestep": av_ca_a2v_gate_noise_timestep,
|
||||
"a_cross_gate_timestep": av_ca_v2a_gate_noise_timestep,
|
||||
"transformer_options": transformer_options,
|
||||
},
|
||||
{"original_block": block_wrap},
|
||||
)
|
||||
vx, ax = out["img"]
|
||||
else:
|
||||
vx, ax = block(
|
||||
(vx, ax),
|
||||
v_context=v_context,
|
||||
a_context=a_context,
|
||||
attention_mask=attention_mask,
|
||||
v_timestep=v_timestep,
|
||||
a_timestep=a_timestep,
|
||||
v_pe=v_pe,
|
||||
a_pe=a_pe,
|
||||
v_cross_pe=av_cross_video_freq_cis,
|
||||
a_cross_pe=av_cross_audio_freq_cis,
|
||||
v_cross_scale_shift_timestep=av_ca_video_scale_shift_timestep,
|
||||
a_cross_scale_shift_timestep=av_ca_audio_scale_shift_timestep,
|
||||
v_cross_gate_timestep=av_ca_a2v_gate_noise_timestep,
|
||||
a_cross_gate_timestep=av_ca_v2a_gate_noise_timestep,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
return [vx, ax]
|
||||
|
||||
def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):
|
||||
vx = x[0]
|
||||
ax = x[1]
|
||||
v_embedded_timestep = embedded_timestep[0]
|
||||
a_embedded_timestep = embedded_timestep[1]
|
||||
vx = super()._process_output(vx, v_embedded_timestep, keyframe_idxs, **kwargs)
|
||||
|
||||
# Process audio output
|
||||
a_scale_shift_values = (
|
||||
self.audio_scale_shift_table[None, None].to(device=a_embedded_timestep.device, dtype=a_embedded_timestep.dtype)
|
||||
+ a_embedded_timestep[:, :, None]
|
||||
)
|
||||
a_shift, a_scale = a_scale_shift_values[:, :, 0], a_scale_shift_values[:, :, 1]
|
||||
|
||||
ax = self.audio_norm_out(ax)
|
||||
ax = ax * (1 + a_scale) + a_shift
|
||||
ax = self.audio_proj_out(ax)
|
||||
|
||||
# Unpatchify audio
|
||||
ax = self.a_patchifier.unpatchify(
|
||||
ax, channels=self.num_audio_channels, freq=self.audio_frequency_bins
|
||||
)
|
||||
|
||||
# Recombine audio and video
|
||||
original_shape = kwargs.get("av_orig_shape")
|
||||
return self.recombine_audio_and_video_latents(vx, ax, original_shape)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
timestep,
|
||||
context,
|
||||
attention_mask=None,
|
||||
frame_rate=25,
|
||||
transformer_options={},
|
||||
keyframe_idxs=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Forward pass for LTXAV model.
|
||||
|
||||
Args:
|
||||
x: Combined audio-video input tensor
|
||||
timestep: Tuple of (video_timestep, audio_timestep) or single timestep
|
||||
context: Context tensor (e.g., text embeddings)
|
||||
attention_mask: Attention mask tensor
|
||||
frame_rate: Frame rate for temporal processing
|
||||
transformer_options: Additional options for transformer blocks
|
||||
keyframe_idxs: Keyframe indices for temporal processing
|
||||
**kwargs: Additional keyword arguments including audio_length
|
||||
|
||||
Returns:
|
||||
Combined audio-video output tensor
|
||||
"""
|
||||
# Handle timestep format
|
||||
if isinstance(timestep, (tuple, list)) and len(timestep) == 2:
|
||||
v_timestep, a_timestep = timestep
|
||||
kwargs["a_timestep"] = a_timestep
|
||||
timestep = v_timestep
|
||||
else:
|
||||
kwargs["a_timestep"] = timestep
|
||||
|
||||
# Call parent forward method
|
||||
return super().forward(
|
||||
x,
|
||||
timestep,
|
||||
context,
|
||||
attention_mask,
|
||||
frame_rate,
|
||||
transformer_options,
|
||||
keyframe_idxs,
|
||||
**kwargs,
|
||||
)
|
||||
@@ -1,305 +0,0 @@
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import comfy.ldm.common_dit
|
||||
import torch
|
||||
from comfy.ldm.lightricks.model import (
|
||||
CrossAttention,
|
||||
FeedForward,
|
||||
generate_freq_grid_np,
|
||||
interleaved_freqs_cis,
|
||||
split_freqs_cis,
|
||||
)
|
||||
from torch import nn
|
||||
|
||||
|
||||
class BasicTransformerBlock1D(nn.Module):
|
||||
r"""
|
||||
A basic Transformer block.
|
||||
|
||||
Parameters:
|
||||
|
||||
dim (`int`): The number of channels in the input and output.
|
||||
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`): The number of channels in each head.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
||||
attention_bias (:
|
||||
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
|
||||
upcast_attention (`bool`, *optional*):
|
||||
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
|
||||
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use learnable elementwise affine parameters for normalization.
|
||||
standardization_norm (`str`, *optional*, defaults to `"layer_norm"`): The type of pre-normalization to use. Can be `"layer_norm"` or `"rms_norm"`.
|
||||
norm_eps (`float`, *optional*, defaults to 1e-5): Epsilon value for normalization layers.
|
||||
qk_norm (`str`, *optional*, defaults to None):
|
||||
Set to 'layer_norm' or `rms_norm` to perform query and key normalization.
|
||||
final_dropout (`bool` *optional*, defaults to False):
|
||||
Whether to apply a final dropout after the last feed-forward layer.
|
||||
ff_inner_dim (`int`, *optional*): Dimension of the inner feed-forward layer. If not provided, defaults to `dim * 4`.
|
||||
ff_bias (`bool`, *optional*, defaults to `True`): Whether to use bias in the feed-forward layer.
|
||||
attention_out_bias (`bool`, *optional*, defaults to `True`): Whether to use bias in the attention output layer.
|
||||
use_rope (`bool`, *optional*, defaults to `False`): Whether to use Rotary Position Embeddings (RoPE).
|
||||
ffn_dim_mult (`int`, *optional*, defaults to 4): Multiplier for the inner dimension of the feed-forward layer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
n_heads,
|
||||
d_head,
|
||||
context_dim=None,
|
||||
attn_precision=None,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Define 3 blocks. Each block has its own normalization layer.
|
||||
# 1. Self-Attn
|
||||
self.attn1 = CrossAttention(
|
||||
query_dim=dim,
|
||||
heads=n_heads,
|
||||
dim_head=d_head,
|
||||
context_dim=None,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
# 3. Feed-forward
|
||||
self.ff = FeedForward(
|
||||
dim,
|
||||
dim_out=dim,
|
||||
glu=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states, attention_mask=None, pe=None) -> torch.FloatTensor:
|
||||
|
||||
# Notice that normalization is always applied before the real computation in the following blocks.
|
||||
|
||||
# 1. Normalization Before Self-Attention
|
||||
norm_hidden_states = comfy.ldm.common_dit.rms_norm(hidden_states)
|
||||
|
||||
norm_hidden_states = norm_hidden_states.squeeze(1)
|
||||
|
||||
# 2. Self-Attention
|
||||
attn_output = self.attn1(norm_hidden_states, mask=attention_mask, pe=pe)
|
||||
|
||||
hidden_states = attn_output + hidden_states
|
||||
if hidden_states.ndim == 4:
|
||||
hidden_states = hidden_states.squeeze(1)
|
||||
|
||||
# 3. Normalization before Feed-Forward
|
||||
norm_hidden_states = comfy.ldm.common_dit.rms_norm(hidden_states)
|
||||
|
||||
# 4. Feed-forward
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
|
||||
hidden_states = ff_output + hidden_states
|
||||
if hidden_states.ndim == 4:
|
||||
hidden_states = hidden_states.squeeze(1)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Embeddings1DConnector(nn.Module):
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=128,
|
||||
cross_attention_dim=2048,
|
||||
attention_head_dim=128,
|
||||
num_attention_heads=30,
|
||||
num_layers=2,
|
||||
positional_embedding_theta=10000.0,
|
||||
positional_embedding_max_pos=[4096],
|
||||
causal_temporal_positioning=False,
|
||||
num_learnable_registers: Optional[int] = 128,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
split_rope=False,
|
||||
double_precision_rope=False,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
self.out_channels = in_channels
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
self.causal_temporal_positioning = causal_temporal_positioning
|
||||
self.positional_embedding_theta = positional_embedding_theta
|
||||
self.positional_embedding_max_pos = positional_embedding_max_pos
|
||||
self.split_rope = split_rope
|
||||
self.double_precision_rope = double_precision_rope
|
||||
self.transformer_1d_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock1D(
|
||||
self.inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
context_dim=cross_attention_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
self.num_learnable_registers = num_learnable_registers
|
||||
if self.num_learnable_registers:
|
||||
self.learnable_registers = nn.Parameter(
|
||||
torch.rand(
|
||||
self.num_learnable_registers, inner_dim, dtype=dtype, device=device
|
||||
)
|
||||
* 2.0
|
||||
- 1.0
|
||||
)
|
||||
|
||||
def get_fractional_positions(self, indices_grid):
|
||||
fractional_positions = torch.stack(
|
||||
[
|
||||
indices_grid[:, i] / self.positional_embedding_max_pos[i]
|
||||
for i in range(1)
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
return fractional_positions
|
||||
|
||||
def precompute_freqs(self, indices_grid, spacing):
|
||||
source_dtype = indices_grid.dtype
|
||||
dtype = (
|
||||
torch.float32
|
||||
if source_dtype in (torch.bfloat16, torch.float16)
|
||||
else source_dtype
|
||||
)
|
||||
|
||||
fractional_positions = self.get_fractional_positions(indices_grid)
|
||||
indices = (
|
||||
generate_freq_grid_np(
|
||||
self.positional_embedding_theta,
|
||||
indices_grid.shape[1],
|
||||
self.inner_dim,
|
||||
)
|
||||
if self.double_precision_rope
|
||||
else self.generate_freq_grid(spacing, dtype, fractional_positions.device)
|
||||
).to(device=fractional_positions.device)
|
||||
|
||||
if spacing == "exp_2":
|
||||
freqs = (
|
||||
(indices * fractional_positions.unsqueeze(-1))
|
||||
.transpose(-1, -2)
|
||||
.flatten(2)
|
||||
)
|
||||
else:
|
||||
freqs = (
|
||||
(indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
|
||||
.transpose(-1, -2)
|
||||
.flatten(2)
|
||||
)
|
||||
return freqs
|
||||
|
||||
def generate_freq_grid(self, spacing, dtype, device):
|
||||
dim = self.inner_dim
|
||||
theta = self.positional_embedding_theta
|
||||
n_pos_dims = 1
|
||||
n_elem = 2 * n_pos_dims # 2 for cos and sin e.g. x 3 = 6
|
||||
start = 1
|
||||
end = theta
|
||||
|
||||
if spacing == "exp":
|
||||
indices = theta ** (torch.arange(0, dim, n_elem, device="cpu", dtype=torch.float32) / (dim - n_elem))
|
||||
indices = indices.to(dtype=dtype, device=device)
|
||||
elif spacing == "exp_2":
|
||||
indices = 1.0 / theta ** (torch.arange(0, dim, n_elem, device=device) / dim)
|
||||
indices = indices.to(dtype=dtype)
|
||||
elif spacing == "linear":
|
||||
indices = torch.linspace(
|
||||
start, end, dim // n_elem, device=device, dtype=dtype
|
||||
)
|
||||
elif spacing == "sqrt":
|
||||
indices = torch.linspace(
|
||||
start**2, end**2, dim // n_elem, device=device, dtype=dtype
|
||||
).sqrt()
|
||||
|
||||
indices = indices * math.pi / 2
|
||||
|
||||
return indices
|
||||
|
||||
def precompute_freqs_cis(self, indices_grid, spacing="exp"):
|
||||
dim = self.inner_dim
|
||||
n_elem = 2 # 2 because of cos and sin
|
||||
freqs = self.precompute_freqs(indices_grid, spacing)
|
||||
if self.split_rope:
|
||||
expected_freqs = dim // 2
|
||||
current_freqs = freqs.shape[-1]
|
||||
pad_size = expected_freqs - current_freqs
|
||||
cos_freq, sin_freq = split_freqs_cis(
|
||||
freqs, pad_size, self.num_attention_heads
|
||||
)
|
||||
else:
|
||||
cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem)
|
||||
return cos_freq.to(self.dtype), sin_freq.to(self.dtype), self.split_rope
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
):
|
||||
"""
|
||||
The [`Transformer2DModel`] forward method.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
|
||||
Input `hidden_states`.
|
||||
indices_grid (`torch.LongTensor` of shape `(batch size, 3, num latent pixels)`):
|
||||
attention_mask ( `torch.Tensor`, *optional*):
|
||||
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
||||
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
||||
negative values to the attention scores corresponding to "discard" tokens.
|
||||
Returns:
|
||||
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||
`tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
# 1. Input
|
||||
|
||||
if self.num_learnable_registers:
|
||||
num_registers_duplications = math.ceil(
|
||||
max(1024, hidden_states.shape[1]) / self.num_learnable_registers
|
||||
)
|
||||
learnable_registers = torch.tile(
|
||||
self.learnable_registers.to(hidden_states), (num_registers_duplications, 1)
|
||||
)
|
||||
|
||||
hidden_states = torch.cat((hidden_states, learnable_registers[hidden_states.shape[1]:].unsqueeze(0).repeat(hidden_states.shape[0], 1, 1)), dim=1)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = torch.zeros([1, 1, 1, hidden_states.shape[1]], dtype=attention_mask.dtype, device=attention_mask.device)
|
||||
|
||||
indices_grid = torch.arange(
|
||||
hidden_states.shape[1], dtype=torch.float32, device=hidden_states.device
|
||||
)
|
||||
indices_grid = indices_grid[None, None, :]
|
||||
freqs_cis = self.precompute_freqs_cis(indices_grid)
|
||||
|
||||
# 2. Blocks
|
||||
for block_idx, block in enumerate(self.transformer_1d_blocks):
|
||||
hidden_states = block(
|
||||
hidden_states, attention_mask=attention_mask, pe=freqs_cis
|
||||
)
|
||||
|
||||
# 3. Output
|
||||
# if self.output_scale is not None:
|
||||
# hidden_states = hidden_states / self.output_scale
|
||||
|
||||
hidden_states = comfy.ldm.common_dit.rms_norm(hidden_states)
|
||||
|
||||
return hidden_states, attention_mask
|
||||
@@ -1,292 +0,0 @@
|
||||
from typing import Optional, Tuple
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
def _rational_for_scale(scale: float) -> Tuple[int, int]:
|
||||
mapping = {0.75: (3, 4), 1.5: (3, 2), 2.0: (2, 1), 4.0: (4, 1)}
|
||||
if float(scale) not in mapping:
|
||||
raise ValueError(
|
||||
f"Unsupported spatial_scale {scale}. Choose from {list(mapping.keys())}"
|
||||
)
|
||||
return mapping[float(scale)]
|
||||
|
||||
|
||||
class PixelShuffleND(nn.Module):
|
||||
def __init__(self, dims, upscale_factors=(2, 2, 2)):
|
||||
super().__init__()
|
||||
assert dims in [1, 2, 3], "dims must be 1, 2, or 3"
|
||||
self.dims = dims
|
||||
self.upscale_factors = upscale_factors
|
||||
|
||||
def forward(self, x):
|
||||
if self.dims == 3:
|
||||
return rearrange(
|
||||
x,
|
||||
"b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
|
||||
p1=self.upscale_factors[0],
|
||||
p2=self.upscale_factors[1],
|
||||
p3=self.upscale_factors[2],
|
||||
)
|
||||
elif self.dims == 2:
|
||||
return rearrange(
|
||||
x,
|
||||
"b (c p1 p2) h w -> b c (h p1) (w p2)",
|
||||
p1=self.upscale_factors[0],
|
||||
p2=self.upscale_factors[1],
|
||||
)
|
||||
elif self.dims == 1:
|
||||
return rearrange(
|
||||
x,
|
||||
"b (c p1) f h w -> b c (f p1) h w",
|
||||
p1=self.upscale_factors[0],
|
||||
)
|
||||
|
||||
|
||||
class BlurDownsample(nn.Module):
|
||||
"""
|
||||
Anti-aliased spatial downsampling by integer stride using a fixed separable binomial kernel.
|
||||
Applies only on H,W. Works for dims=2 or dims=3 (per-frame).
|
||||
"""
|
||||
|
||||
def __init__(self, dims: int, stride: int):
|
||||
super().__init__()
|
||||
assert dims in (2, 3)
|
||||
assert stride >= 1 and isinstance(stride, int)
|
||||
self.dims = dims
|
||||
self.stride = stride
|
||||
|
||||
# 5x5 separable binomial kernel [1,4,6,4,1] (outer product), normalized
|
||||
k = torch.tensor([1.0, 4.0, 6.0, 4.0, 1.0])
|
||||
k2d = k[:, None] @ k[None, :]
|
||||
k2d = (k2d / k2d.sum()).float() # shape (5,5)
|
||||
self.register_buffer("kernel", k2d[None, None, :, :]) # (1,1,5,5)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if self.stride == 1:
|
||||
return x
|
||||
|
||||
def _apply_2d(x2d: torch.Tensor) -> torch.Tensor:
|
||||
# x2d: (B, C, H, W)
|
||||
B, C, H, W = x2d.shape
|
||||
weight = self.kernel.expand(C, 1, 5, 5) # depthwise
|
||||
x2d = F.conv2d(
|
||||
x2d, weight=weight, bias=None, stride=self.stride, padding=2, groups=C
|
||||
)
|
||||
return x2d
|
||||
|
||||
if self.dims == 2:
|
||||
return _apply_2d(x)
|
||||
else:
|
||||
# dims == 3: apply per-frame on H,W
|
||||
b, c, f, h, w = x.shape
|
||||
x = rearrange(x, "b c f h w -> (b f) c h w")
|
||||
x = _apply_2d(x)
|
||||
h2, w2 = x.shape[-2:]
|
||||
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f, h=h2, w=w2)
|
||||
return x
|
||||
|
||||
|
||||
class SpatialRationalResampler(nn.Module):
|
||||
"""
|
||||
Fully-learned rational spatial scaling: up by 'num' via PixelShuffle, then anti-aliased
|
||||
downsample by 'den' using fixed blur + stride. Operates on H,W only.
|
||||
|
||||
For dims==3, work per-frame for spatial scaling (temporal axis untouched).
|
||||
"""
|
||||
|
||||
def __init__(self, mid_channels: int, scale: float):
|
||||
super().__init__()
|
||||
self.scale = float(scale)
|
||||
self.num, self.den = _rational_for_scale(self.scale)
|
||||
self.conv = nn.Conv2d(
|
||||
mid_channels, (self.num**2) * mid_channels, kernel_size=3, padding=1
|
||||
)
|
||||
self.pixel_shuffle = PixelShuffleND(2, upscale_factors=(self.num, self.num))
|
||||
self.blur_down = BlurDownsample(dims=2, stride=self.den)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
b, c, f, h, w = x.shape
|
||||
x = rearrange(x, "b c f h w -> (b f) c h w")
|
||||
x = self.conv(x)
|
||||
x = self.pixel_shuffle(x)
|
||||
x = self.blur_down(x)
|
||||
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
|
||||
return x
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
def __init__(
|
||||
self, channels: int, mid_channels: Optional[int] = None, dims: int = 3
|
||||
):
|
||||
super().__init__()
|
||||
if mid_channels is None:
|
||||
mid_channels = channels
|
||||
|
||||
Conv = nn.Conv2d if dims == 2 else nn.Conv3d
|
||||
|
||||
self.conv1 = Conv(channels, mid_channels, kernel_size=3, padding=1)
|
||||
self.norm1 = nn.GroupNorm(32, mid_channels)
|
||||
self.conv2 = Conv(mid_channels, channels, kernel_size=3, padding=1)
|
||||
self.norm2 = nn.GroupNorm(32, channels)
|
||||
self.activation = nn.SiLU()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
residual = x
|
||||
x = self.conv1(x)
|
||||
x = self.norm1(x)
|
||||
x = self.activation(x)
|
||||
x = self.conv2(x)
|
||||
x = self.norm2(x)
|
||||
x = self.activation(x + residual)
|
||||
return x
|
||||
|
||||
|
||||
class LatentUpsampler(nn.Module):
|
||||
"""
|
||||
Model to spatially upsample VAE latents.
|
||||
|
||||
Args:
|
||||
in_channels (`int`): Number of channels in the input latent
|
||||
mid_channels (`int`): Number of channels in the middle layers
|
||||
num_blocks_per_stage (`int`): Number of ResBlocks to use in each stage (pre/post upsampling)
|
||||
dims (`int`): Number of dimensions for convolutions (2 or 3)
|
||||
spatial_upsample (`bool`): Whether to spatially upsample the latent
|
||||
temporal_upsample (`bool`): Whether to temporally upsample the latent
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 128,
|
||||
mid_channels: int = 512,
|
||||
num_blocks_per_stage: int = 4,
|
||||
dims: int = 3,
|
||||
spatial_upsample: bool = True,
|
||||
temporal_upsample: bool = False,
|
||||
spatial_scale: float = 2.0,
|
||||
rational_resampler: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.mid_channels = mid_channels
|
||||
self.num_blocks_per_stage = num_blocks_per_stage
|
||||
self.dims = dims
|
||||
self.spatial_upsample = spatial_upsample
|
||||
self.temporal_upsample = temporal_upsample
|
||||
self.spatial_scale = float(spatial_scale)
|
||||
self.rational_resampler = rational_resampler
|
||||
|
||||
Conv = nn.Conv2d if dims == 2 else nn.Conv3d
|
||||
|
||||
self.initial_conv = Conv(in_channels, mid_channels, kernel_size=3, padding=1)
|
||||
self.initial_norm = nn.GroupNorm(32, mid_channels)
|
||||
self.initial_activation = nn.SiLU()
|
||||
|
||||
self.res_blocks = nn.ModuleList(
|
||||
[ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
|
||||
)
|
||||
|
||||
if spatial_upsample and temporal_upsample:
|
||||
self.upsampler = nn.Sequential(
|
||||
nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1),
|
||||
PixelShuffleND(3),
|
||||
)
|
||||
elif spatial_upsample:
|
||||
if rational_resampler:
|
||||
self.upsampler = SpatialRationalResampler(
|
||||
mid_channels=mid_channels, scale=self.spatial_scale
|
||||
)
|
||||
else:
|
||||
self.upsampler = nn.Sequential(
|
||||
nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1),
|
||||
PixelShuffleND(2),
|
||||
)
|
||||
elif temporal_upsample:
|
||||
self.upsampler = nn.Sequential(
|
||||
nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1),
|
||||
PixelShuffleND(1),
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Either spatial_upsample or temporal_upsample must be True"
|
||||
)
|
||||
|
||||
self.post_upsample_res_blocks = nn.ModuleList(
|
||||
[ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
|
||||
)
|
||||
|
||||
self.final_conv = Conv(mid_channels, in_channels, kernel_size=3, padding=1)
|
||||
|
||||
def forward(self, latent: torch.Tensor) -> torch.Tensor:
|
||||
b, c, f, h, w = latent.shape
|
||||
|
||||
if self.dims == 2:
|
||||
x = rearrange(latent, "b c f h w -> (b f) c h w")
|
||||
x = self.initial_conv(x)
|
||||
x = self.initial_norm(x)
|
||||
x = self.initial_activation(x)
|
||||
|
||||
for block in self.res_blocks:
|
||||
x = block(x)
|
||||
|
||||
x = self.upsampler(x)
|
||||
|
||||
for block in self.post_upsample_res_blocks:
|
||||
x = block(x)
|
||||
|
||||
x = self.final_conv(x)
|
||||
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
|
||||
else:
|
||||
x = self.initial_conv(latent)
|
||||
x = self.initial_norm(x)
|
||||
x = self.initial_activation(x)
|
||||
|
||||
for block in self.res_blocks:
|
||||
x = block(x)
|
||||
|
||||
if self.temporal_upsample:
|
||||
x = self.upsampler(x)
|
||||
x = x[:, :, 1:, :, :]
|
||||
else:
|
||||
if isinstance(self.upsampler, SpatialRationalResampler):
|
||||
x = self.upsampler(x)
|
||||
else:
|
||||
x = rearrange(x, "b c f h w -> (b f) c h w")
|
||||
x = self.upsampler(x)
|
||||
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
|
||||
|
||||
for block in self.post_upsample_res_blocks:
|
||||
x = block(x)
|
||||
|
||||
x = self.final_conv(x)
|
||||
|
||||
return x
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config):
|
||||
return cls(
|
||||
in_channels=config.get("in_channels", 4),
|
||||
mid_channels=config.get("mid_channels", 128),
|
||||
num_blocks_per_stage=config.get("num_blocks_per_stage", 4),
|
||||
dims=config.get("dims", 2),
|
||||
spatial_upsample=config.get("spatial_upsample", True),
|
||||
temporal_upsample=config.get("temporal_upsample", False),
|
||||
spatial_scale=config.get("spatial_scale", 2.0),
|
||||
rational_resampler=config.get("rational_resampler", False),
|
||||
)
|
||||
|
||||
def config(self):
|
||||
return {
|
||||
"_class_name": "LatentUpsampler",
|
||||
"in_channels": self.in_channels,
|
||||
"mid_channels": self.mid_channels,
|
||||
"num_blocks_per_stage": self.num_blocks_per_stage,
|
||||
"dims": self.dims,
|
||||
"spatial_upsample": self.spatial_upsample,
|
||||
"temporal_upsample": self.temporal_upsample,
|
||||
"spatial_scale": self.spatial_scale,
|
||||
"rational_resampler": self.rational_resampler,
|
||||
}
|
||||
@@ -1,47 +1,14 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
import functools
|
||||
import math
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
from einops import rearrange
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
import comfy.patcher_extension
|
||||
import comfy.ldm.modules.attention
|
||||
import comfy.ldm.common_dit
|
||||
from einops import rearrange
|
||||
import math
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
|
||||
|
||||
def _log_base(x, base):
|
||||
return np.log(x) / np.log(base)
|
||||
|
||||
class LTXRopeType(str, Enum):
|
||||
INTERLEAVED = "interleaved"
|
||||
SPLIT = "split"
|
||||
|
||||
KEY = "rope_type"
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, kwargs, default=None):
|
||||
if default is None:
|
||||
default = cls.INTERLEAVED
|
||||
return cls(kwargs.get(cls.KEY, default))
|
||||
|
||||
|
||||
class LTXFrequenciesPrecision(str, Enum):
|
||||
FLOAT32 = "float32"
|
||||
FLOAT64 = "float64"
|
||||
|
||||
KEY = "frequencies_precision"
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, kwargs, default=None):
|
||||
if default is None:
|
||||
default = cls.FLOAT32
|
||||
return cls(kwargs.get(cls.KEY, default))
|
||||
|
||||
|
||||
def get_timestep_embedding(
|
||||
timesteps: torch.Tensor,
|
||||
@@ -73,7 +40,9 @@ def get_timestep_embedding(
|
||||
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
|
||||
|
||||
half_dim = embedding_dim // 2
|
||||
exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
|
||||
exponent = -math.log(max_period) * torch.arange(
|
||||
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
|
||||
)
|
||||
exponent = exponent / (half_dim - downscale_freq_shift)
|
||||
|
||||
emb = torch.exp(exponent)
|
||||
@@ -105,9 +74,7 @@ class TimestepEmbedding(nn.Module):
|
||||
post_act_fn: Optional[str] = None,
|
||||
cond_proj_dim=None,
|
||||
sample_proj_bias=True,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
dtype=None, device=None, operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -124,9 +91,7 @@ class TimestepEmbedding(nn.Module):
|
||||
time_embed_dim_out = out_dim
|
||||
else:
|
||||
time_embed_dim_out = time_embed_dim
|
||||
self.linear_2 = operations.Linear(
|
||||
time_embed_dim, time_embed_dim_out, sample_proj_bias, dtype=dtype, device=device
|
||||
)
|
||||
self.linear_2 = operations.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias, dtype=dtype, device=device)
|
||||
|
||||
if post_act_fn is None:
|
||||
self.post_act = None
|
||||
@@ -175,22 +140,12 @@ class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
|
||||
https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim,
|
||||
size_emb_dim,
|
||||
use_additional_conditions: bool = False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
):
|
||||
def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
self.outdim = size_emb_dim
|
||||
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.timestep_embedder = TimestepEmbedding(
|
||||
in_channels=256, time_embed_dim=embedding_dim, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
|
||||
timesteps_proj = self.time_proj(timestep)
|
||||
@@ -209,22 +164,15 @@ class AdaLayerNormSingle(nn.Module):
|
||||
use_additional_conditions (`bool`): To use additional conditions for normalization or not.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, embedding_dim: int, embedding_coefficient: int = 6, use_additional_conditions: bool = False, dtype=None, device=None, operations=None
|
||||
):
|
||||
def __init__(self, embedding_dim: int, use_additional_conditions: bool = False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
|
||||
embedding_dim,
|
||||
size_emb_dim=embedding_dim // 3,
|
||||
use_additional_conditions=use_additional_conditions,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = operations.Linear(embedding_dim, embedding_coefficient * embedding_dim, bias=True, dtype=dtype, device=device)
|
||||
self.linear = operations.Linear(embedding_dim, 6 * embedding_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -238,7 +186,6 @@ class AdaLayerNormSingle(nn.Module):
|
||||
embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
|
||||
return self.linear(self.silu(embedded_timestep)), embedded_timestep
|
||||
|
||||
|
||||
class PixArtAlphaTextProjection(nn.Module):
|
||||
"""
|
||||
Projects caption embeddings. Also handles dropout for classifier-free guidance.
|
||||
@@ -246,24 +193,18 @@ class PixArtAlphaTextProjection(nn.Module):
|
||||
Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh", dtype=None, device=None, operations=None
|
||||
):
|
||||
def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh", dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
if out_features is None:
|
||||
out_features = hidden_size
|
||||
self.linear_1 = operations.Linear(
|
||||
in_features=in_features, out_features=hidden_size, bias=True, dtype=dtype, device=device
|
||||
)
|
||||
self.linear_1 = operations.Linear(in_features=in_features, out_features=hidden_size, bias=True, dtype=dtype, device=device)
|
||||
if act_fn == "gelu_tanh":
|
||||
self.act_1 = nn.GELU(approximate="tanh")
|
||||
elif act_fn == "silu":
|
||||
self.act_1 = nn.SiLU()
|
||||
else:
|
||||
raise ValueError(f"Unknown activation function: {act_fn}")
|
||||
self.linear_2 = operations.Linear(
|
||||
in_features=hidden_size, out_features=out_features, bias=True, dtype=dtype, device=device
|
||||
)
|
||||
self.linear_2 = operations.Linear(in_features=hidden_size, out_features=out_features, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, caption):
|
||||
hidden_states = self.linear_1(caption)
|
||||
@@ -282,28 +223,25 @@ class GELU_approx(nn.Module):
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, dim_out, mult=4, glu=False, dropout=0.0, dtype=None, device=None, operations=None):
|
||||
def __init__(self, dim, dim_out, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
project_in = GELU_approx(dim, inner_dim, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.net = nn.Sequential(
|
||||
project_in, nn.Dropout(dropout), operations.Linear(inner_dim, dim_out, dtype=dtype, device=device)
|
||||
project_in,
|
||||
nn.Dropout(dropout),
|
||||
operations.Linear(inner_dim, dim_out, dtype=dtype, device=device)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
def apply_rotary_emb(input_tensor, freqs_cis):
|
||||
cos_freqs, sin_freqs = freqs_cis[0], freqs_cis[1]
|
||||
split_pe = freqs_cis[2] if len(freqs_cis) > 2 else False
|
||||
return (
|
||||
apply_split_rotary_emb(input_tensor, cos_freqs, sin_freqs)
|
||||
if split_pe else
|
||||
apply_interleaved_rotary_emb(input_tensor, cos_freqs, sin_freqs)
|
||||
)
|
||||
|
||||
def apply_interleaved_rotary_emb(input_tensor, cos_freqs, sin_freqs): # TODO: remove duplicate funcs and pick the best/fastest one
|
||||
def apply_rotary_emb(input_tensor, freqs_cis): #TODO: remove duplicate funcs and pick the best/fastest one
|
||||
cos_freqs = freqs_cis[0]
|
||||
sin_freqs = freqs_cis[1]
|
||||
|
||||
t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2)
|
||||
t1, t2 = t_dup.unbind(dim=-1)
|
||||
t_dup = torch.stack((-t2, t1), dim=-1)
|
||||
@@ -313,37 +251,9 @@ def apply_interleaved_rotary_emb(input_tensor, cos_freqs, sin_freqs): # TODO: r
|
||||
|
||||
return out
|
||||
|
||||
def apply_split_rotary_emb(input_tensor, cos, sin):
|
||||
needs_reshape = False
|
||||
if input_tensor.ndim != 4 and cos.ndim == 4:
|
||||
B, H, T, _ = cos.shape
|
||||
input_tensor = input_tensor.reshape(B, T, H, -1).swapaxes(1, 2)
|
||||
needs_reshape = True
|
||||
split_input = rearrange(input_tensor, "... (d r) -> ... d r", d=2)
|
||||
first_half_input = split_input[..., :1, :]
|
||||
second_half_input = split_input[..., 1:, :]
|
||||
output = split_input * cos.unsqueeze(-2)
|
||||
first_half_output = output[..., :1, :]
|
||||
second_half_output = output[..., 1:, :]
|
||||
first_half_output.addcmul_(-sin.unsqueeze(-2), second_half_input)
|
||||
second_half_output.addcmul_(sin.unsqueeze(-2), first_half_input)
|
||||
output = rearrange(output, "... d r -> ... (d r)")
|
||||
return output.swapaxes(1, 2).reshape(B, T, -1) if needs_reshape else output
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
query_dim,
|
||||
context_dim=None,
|
||||
heads=8,
|
||||
dim_head=64,
|
||||
dropout=0.0,
|
||||
attn_precision=None,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
):
|
||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
context_dim = query_dim if context_dim is None else context_dim
|
||||
@@ -359,11 +269,9 @@ class CrossAttention(nn.Module):
|
||||
self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
|
||||
self.to_v = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout)
|
||||
)
|
||||
self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
|
||||
|
||||
def forward(self, x, context=None, mask=None, pe=None, k_pe=None, transformer_options={}):
|
||||
def forward(self, x, context=None, mask=None, pe=None, transformer_options={}):
|
||||
q = self.to_q(x)
|
||||
context = x if context is None else context
|
||||
k = self.to_k(context)
|
||||
@@ -374,7 +282,7 @@ class CrossAttention(nn.Module):
|
||||
|
||||
if pe is not None:
|
||||
q = apply_rotary_emb(q, pe)
|
||||
k = apply_rotary_emb(k, pe if k_pe is None else k_pe)
|
||||
k = apply_rotary_emb(k, pe)
|
||||
|
||||
if mask is None:
|
||||
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
||||
@@ -384,495 +292,146 @@ class CrossAttention(nn.Module):
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self, dim, n_heads, d_head, context_dim=None, attn_precision=None, dtype=None, device=None, operations=None
|
||||
):
|
||||
def __init__(self, dim, n_heads, d_head, context_dim=None, attn_precision=None, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
self.attn_precision = attn_precision
|
||||
self.attn1 = CrossAttention(
|
||||
query_dim=dim,
|
||||
heads=n_heads,
|
||||
dim_head=d_head,
|
||||
context_dim=None,
|
||||
attn_precision=self.attn_precision,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, context_dim=None, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations)
|
||||
self.ff = FeedForward(dim, dim_out=dim, glu=True, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.attn2 = CrossAttention(
|
||||
query_dim=dim,
|
||||
context_dim=context_dim,
|
||||
heads=n_heads,
|
||||
dim_head=d_head,
|
||||
attn_precision=self.attn_precision,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))
|
||||
|
||||
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}):
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
|
||||
|
||||
attn1_input = comfy.ldm.common_dit.rms_norm(x)
|
||||
attn1_input = torch.addcmul(attn1_input, attn1_input, scale_msa).add_(shift_msa)
|
||||
attn1_input = self.attn1(attn1_input, pe=pe, transformer_options=transformer_options)
|
||||
x.addcmul_(attn1_input, gate_msa)
|
||||
del attn1_input
|
||||
x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe, transformer_options=transformer_options) * gate_msa
|
||||
|
||||
x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options)
|
||||
|
||||
y = comfy.ldm.common_dit.rms_norm(x)
|
||||
y = torch.addcmul(y, y, scale_mlp).add_(shift_mlp)
|
||||
x.addcmul_(self.ff(y), gate_mlp)
|
||||
y = comfy.ldm.common_dit.rms_norm(x) * (1 + scale_mlp) + shift_mlp
|
||||
x += self.ff(y) * gate_mlp
|
||||
|
||||
return x
|
||||
|
||||
def get_fractional_positions(indices_grid, max_pos):
|
||||
n_pos_dims = indices_grid.shape[1]
|
||||
assert n_pos_dims == len(max_pos), f'Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})'
|
||||
fractional_positions = torch.stack(
|
||||
[indices_grid[:, i] / max_pos[i] for i in range(n_pos_dims)],
|
||||
axis=-1,
|
||||
[
|
||||
indices_grid[:, i] / max_pos[i]
|
||||
for i in range(3)
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
return fractional_positions
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=5)
|
||||
def generate_freq_grid_np(positional_embedding_theta, positional_embedding_max_pos_count, inner_dim, _ = None):
|
||||
theta = positional_embedding_theta
|
||||
def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[20, 2048, 2048]):
|
||||
dtype = torch.float32 #self.dtype
|
||||
|
||||
fractional_positions = get_fractional_positions(indices_grid, max_pos)
|
||||
|
||||
start = 1
|
||||
end = theta
|
||||
|
||||
n_elem = 2 * positional_embedding_max_pos_count
|
||||
pow_indices = np.power(
|
||||
theta,
|
||||
np.linspace(
|
||||
_log_base(start, theta),
|
||||
_log_base(end, theta),
|
||||
inner_dim // n_elem,
|
||||
dtype=np.float64,
|
||||
),
|
||||
)
|
||||
return torch.tensor(pow_indices * math.pi / 2, dtype=torch.float32)
|
||||
|
||||
def generate_freq_grid_pytorch(positional_embedding_theta, positional_embedding_max_pos_count, inner_dim, device):
|
||||
theta = positional_embedding_theta
|
||||
start = 1
|
||||
end = theta
|
||||
n_elem = 2 * positional_embedding_max_pos_count
|
||||
device = fractional_positions.device
|
||||
|
||||
indices = theta ** (
|
||||
torch.linspace(
|
||||
math.log(start, theta),
|
||||
math.log(end, theta),
|
||||
inner_dim // n_elem,
|
||||
dim // 6,
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
dtype=dtype,
|
||||
)
|
||||
)
|
||||
indices = indices.to(dtype=torch.float32)
|
||||
indices = indices.to(dtype=dtype)
|
||||
|
||||
indices = indices * math.pi / 2
|
||||
|
||||
return indices
|
||||
|
||||
def generate_freqs(indices, indices_grid, max_pos, use_middle_indices_grid):
|
||||
if use_middle_indices_grid:
|
||||
assert(len(indices_grid.shape) == 4 and indices_grid.shape[-1] ==2)
|
||||
indices_grid_start, indices_grid_end = indices_grid[..., 0], indices_grid[..., 1]
|
||||
indices_grid = (indices_grid_start + indices_grid_end) / 2.0
|
||||
elif len(indices_grid.shape) == 4:
|
||||
indices_grid = indices_grid[..., 0]
|
||||
|
||||
# Get fractional positions and compute frequency indices
|
||||
fractional_positions = get_fractional_positions(indices_grid, max_pos)
|
||||
indices = indices.to(device=fractional_positions.device)
|
||||
|
||||
freqs = (
|
||||
(indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
|
||||
.transpose(-1, -2)
|
||||
.flatten(2)
|
||||
)
|
||||
return freqs
|
||||
|
||||
def interleaved_freqs_cis(freqs, pad_size):
|
||||
cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
|
||||
sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
|
||||
if pad_size != 0:
|
||||
cos_padding = torch.ones_like(cos_freq[:, :, : pad_size])
|
||||
sin_padding = torch.zeros_like(cos_freq[:, :, : pad_size])
|
||||
if dim % 6 != 0:
|
||||
cos_padding = torch.ones_like(cos_freq[:, :, : dim % 6])
|
||||
sin_padding = torch.zeros_like(cos_freq[:, :, : dim % 6])
|
||||
cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
|
||||
sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
|
||||
return cos_freq, sin_freq
|
||||
return cos_freq.to(out_dtype), sin_freq.to(out_dtype)
|
||||
|
||||
def split_freqs_cis(freqs, pad_size, num_attention_heads):
|
||||
cos_freq = freqs.cos()
|
||||
sin_freq = freqs.sin()
|
||||
|
||||
if pad_size != 0:
|
||||
cos_padding = torch.ones_like(cos_freq[:, :, :pad_size])
|
||||
sin_padding = torch.zeros_like(sin_freq[:, :, :pad_size])
|
||||
class LTXVModel(torch.nn.Module):
|
||||
def __init__(self,
|
||||
in_channels=128,
|
||||
cross_attention_dim=2048,
|
||||
attention_head_dim=64,
|
||||
num_attention_heads=32,
|
||||
|
||||
cos_freq = torch.concatenate([cos_padding, cos_freq], axis=-1)
|
||||
sin_freq = torch.concatenate([sin_padding, sin_freq], axis=-1)
|
||||
caption_channels=4096,
|
||||
num_layers=28,
|
||||
|
||||
# Reshape freqs to be compatible with multi-head attention
|
||||
B , T, half_HD = cos_freq.shape
|
||||
|
||||
cos_freq = cos_freq.reshape(B, T, num_attention_heads, half_HD // num_attention_heads)
|
||||
sin_freq = sin_freq.reshape(B, T, num_attention_heads, half_HD // num_attention_heads)
|
||||
|
||||
cos_freq = torch.swapaxes(cos_freq, 1, 2) # (B,H,T,D//2)
|
||||
sin_freq = torch.swapaxes(sin_freq, 1, 2) # (B,H,T,D//2)
|
||||
return cos_freq, sin_freq
|
||||
|
||||
class LTXBaseModel(torch.nn.Module, ABC):
|
||||
"""
|
||||
Abstract base class for LTX models (Lightricks Transformer models).
|
||||
|
||||
This class defines the common interface and shared functionality for all LTX models,
|
||||
including LTXV (video) and LTXAV (audio-video) variants.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
cross_attention_dim: int,
|
||||
attention_head_dim: int,
|
||||
num_attention_heads: int,
|
||||
caption_channels: int,
|
||||
num_layers: int,
|
||||
positional_embedding_theta: float = 10000.0,
|
||||
positional_embedding_max_pos: list = [20, 2048, 2048],
|
||||
causal_temporal_positioning: bool = False,
|
||||
vae_scale_factors: tuple = (8, 32, 32),
|
||||
use_middle_indices_grid=False,
|
||||
timestep_scale_multiplier = 1000.0,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
**kwargs,
|
||||
):
|
||||
positional_embedding_theta=10000.0,
|
||||
positional_embedding_max_pos=[20, 2048, 2048],
|
||||
causal_temporal_positioning=False,
|
||||
vae_scale_factors=(8, 32, 32),
|
||||
dtype=None, device=None, operations=None, **kwargs):
|
||||
super().__init__()
|
||||
self.generator = None
|
||||
self.vae_scale_factors = vae_scale_factors
|
||||
self.use_middle_indices_grid = use_middle_indices_grid
|
||||
self.dtype = dtype
|
||||
self.in_channels = in_channels
|
||||
self.cross_attention_dim = cross_attention_dim
|
||||
self.attention_head_dim = attention_head_dim
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.caption_channels = caption_channels
|
||||
self.num_layers = num_layers
|
||||
self.positional_embedding_theta = positional_embedding_theta
|
||||
self.positional_embedding_max_pos = positional_embedding_max_pos
|
||||
self.split_positional_embedding = LTXRopeType.from_dict(kwargs)
|
||||
self.freq_grid_generator = (
|
||||
generate_freq_grid_np if LTXFrequenciesPrecision.from_dict(kwargs) == LTXFrequenciesPrecision.FLOAT64
|
||||
else generate_freq_grid_pytorch
|
||||
)
|
||||
self.causal_temporal_positioning = causal_temporal_positioning
|
||||
self.operations = operations
|
||||
self.timestep_scale_multiplier = timestep_scale_multiplier
|
||||
|
||||
# Common dimensions
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
self.out_channels = in_channels
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
self.causal_temporal_positioning = causal_temporal_positioning
|
||||
|
||||
# Initialize common components
|
||||
self._init_common_components(device, dtype)
|
||||
|
||||
# Initialize model-specific components
|
||||
self._init_model_components(device, dtype, **kwargs)
|
||||
|
||||
# Initialize transformer blocks
|
||||
self._init_transformer_blocks(device, dtype, **kwargs)
|
||||
|
||||
# Initialize output components
|
||||
self._init_output_components(device, dtype)
|
||||
|
||||
def _init_common_components(self, device, dtype):
|
||||
"""Initialize components common to all LTX models
|
||||
- patchify_proj: Linear projection for patchifying input
|
||||
- adaln_single: AdaLN layer for timestep embedding
|
||||
- caption_projection: Linear projection for caption embedding
|
||||
"""
|
||||
self.patchify_proj = self.operations.Linear(
|
||||
self.in_channels, self.inner_dim, bias=True, dtype=dtype, device=device
|
||||
)
|
||||
self.patchify_proj = operations.Linear(in_channels, self.inner_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
self.adaln_single = AdaLayerNormSingle(
|
||||
self.inner_dim, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations
|
||||
self.inner_dim, use_additional_conditions=False, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
# self.adaln_single.linear = operations.Linear(self.inner_dim, 4 * self.inner_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
self.caption_projection = PixArtAlphaTextProjection(
|
||||
in_features=self.caption_channels,
|
||||
hidden_size=self.inner_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
in_features=caption_channels, hidden_size=self.inner_dim, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def _init_model_components(self, device, dtype, **kwargs):
|
||||
"""Initialize model-specific components. Must be implemented by subclasses."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _init_transformer_blocks(self, device, dtype, **kwargs):
|
||||
"""Initialize transformer blocks. Must be implemented by subclasses."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _init_output_components(self, device, dtype):
|
||||
"""Initialize output components. Must be implemented by subclasses."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs):
|
||||
"""Process input data. Must be implemented by subclasses."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, **kwargs):
|
||||
"""Process transformer blocks. Must be implemented by subclasses."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):
|
||||
"""Process output data. Must be implemented by subclasses."""
|
||||
pass
|
||||
|
||||
def _prepare_timestep(self, timestep, batch_size, hidden_dtype, **kwargs):
|
||||
"""Prepare timestep embeddings."""
|
||||
grid_mask = kwargs.get("grid_mask", None)
|
||||
if grid_mask is not None:
|
||||
timestep = timestep[:, grid_mask]
|
||||
|
||||
timestep = timestep * self.timestep_scale_multiplier
|
||||
timestep, embedded_timestep = self.adaln_single(
|
||||
timestep.flatten(),
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
|
||||
# Second dimension is 1 or number of tokens (if timestep_per_token)
|
||||
timestep = timestep.view(batch_size, -1, timestep.shape[-1])
|
||||
embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.shape[-1])
|
||||
|
||||
return timestep, embedded_timestep
|
||||
|
||||
def _prepare_context(self, context, batch_size, x, attention_mask=None):
|
||||
"""Prepare context for transformer blocks."""
|
||||
if self.caption_projection is not None:
|
||||
context = self.caption_projection(context)
|
||||
context = context.view(batch_size, -1, x.shape[-1])
|
||||
|
||||
return context, attention_mask
|
||||
|
||||
def _precompute_freqs_cis(
|
||||
self,
|
||||
indices_grid,
|
||||
dim,
|
||||
out_dtype,
|
||||
theta=10000.0,
|
||||
max_pos=[20, 2048, 2048],
|
||||
use_middle_indices_grid=False,
|
||||
num_attention_heads=32,
|
||||
):
|
||||
split_mode = self.split_positional_embedding == LTXRopeType.SPLIT
|
||||
indices = self.freq_grid_generator(theta, indices_grid.shape[1], dim, indices_grid.device)
|
||||
freqs = generate_freqs(indices, indices_grid, max_pos, use_middle_indices_grid)
|
||||
|
||||
if split_mode:
|
||||
expected_freqs = dim // 2
|
||||
current_freqs = freqs.shape[-1]
|
||||
pad_size = expected_freqs - current_freqs
|
||||
cos_freq, sin_freq = split_freqs_cis(freqs, pad_size, num_attention_heads)
|
||||
else:
|
||||
# 2 because of cos and sin by 3 for (t, x, y), 1 for temporal only
|
||||
n_elem = 2 * indices_grid.shape[1]
|
||||
cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem)
|
||||
return cos_freq.to(out_dtype), sin_freq.to(out_dtype), split_mode
|
||||
|
||||
def _prepare_positional_embeddings(self, pixel_coords, frame_rate, x_dtype):
|
||||
"""Prepare positional embeddings."""
|
||||
fractional_coords = pixel_coords.to(torch.float32)
|
||||
fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate)
|
||||
pe = self._precompute_freqs_cis(
|
||||
fractional_coords,
|
||||
dim=self.inner_dim,
|
||||
out_dtype=x_dtype,
|
||||
max_pos=self.positional_embedding_max_pos,
|
||||
use_middle_indices_grid=self.use_middle_indices_grid,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
)
|
||||
return pe
|
||||
|
||||
def _prepare_attention_mask(self, attention_mask, x_dtype):
|
||||
"""Prepare attention mask."""
|
||||
if attention_mask is not None and not torch.is_floating_point(attention_mask):
|
||||
attention_mask = (attention_mask - 1).to(x_dtype).reshape(
|
||||
(attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
|
||||
) * torch.finfo(x_dtype).max
|
||||
return attention_mask
|
||||
|
||||
def forward(
|
||||
self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, denoise_mask=None, **kwargs
|
||||
):
|
||||
"""
|
||||
Forward pass for LTX models.
|
||||
|
||||
Args:
|
||||
x: Input tensor
|
||||
timestep: Timestep tensor
|
||||
context: Context tensor (e.g., text embeddings)
|
||||
attention_mask: Attention mask tensor
|
||||
frame_rate: Frame rate for temporal processing
|
||||
transformer_options: Additional options for transformer blocks
|
||||
keyframe_idxs: Keyframe indices for temporal processing
|
||||
**kwargs: Additional keyword arguments
|
||||
|
||||
Returns:
|
||||
Processed output tensor
|
||||
"""
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(
|
||||
comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options
|
||||
),
|
||||
).execute(x, timestep, context, attention_mask, frame_rate, transformer_options, keyframe_idxs, denoise_mask=denoise_mask, **kwargs)
|
||||
|
||||
def _forward(
|
||||
self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, denoise_mask=None, **kwargs
|
||||
):
|
||||
"""
|
||||
Internal forward pass for LTX models.
|
||||
|
||||
Args:
|
||||
x: Input tensor
|
||||
timestep: Timestep tensor
|
||||
context: Context tensor (e.g., text embeddings)
|
||||
attention_mask: Attention mask tensor
|
||||
frame_rate: Frame rate for temporal processing
|
||||
transformer_options: Additional options for transformer blocks
|
||||
keyframe_idxs: Keyframe indices for temporal processing
|
||||
**kwargs: Additional keyword arguments
|
||||
|
||||
Returns:
|
||||
Processed output tensor
|
||||
"""
|
||||
if isinstance(x, list):
|
||||
input_dtype = x[0].dtype
|
||||
batch_size = x[0].shape[0]
|
||||
else:
|
||||
input_dtype = x.dtype
|
||||
batch_size = x.shape[0]
|
||||
# Process input
|
||||
merged_args = {**transformer_options, **kwargs}
|
||||
x, pixel_coords, additional_args = self._process_input(x, keyframe_idxs, denoise_mask, **merged_args)
|
||||
merged_args.update(additional_args)
|
||||
|
||||
# Prepare timestep and context
|
||||
timestep, embedded_timestep = self._prepare_timestep(timestep, batch_size, input_dtype, **merged_args)
|
||||
context, attention_mask = self._prepare_context(context, batch_size, x, attention_mask)
|
||||
|
||||
# Prepare attention mask and positional embeddings
|
||||
attention_mask = self._prepare_attention_mask(attention_mask, input_dtype)
|
||||
pe = self._prepare_positional_embeddings(pixel_coords, frame_rate, input_dtype)
|
||||
|
||||
# Process transformer blocks
|
||||
x = self._process_transformer_blocks(
|
||||
x, context, attention_mask, timestep, pe, transformer_options=transformer_options, **merged_args
|
||||
)
|
||||
|
||||
# Process output
|
||||
x = self._process_output(x, embedded_timestep, keyframe_idxs, **merged_args)
|
||||
return x
|
||||
|
||||
|
||||
class LTXVModel(LTXBaseModel):
|
||||
"""LTXV model for video generation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=128,
|
||||
cross_attention_dim=2048,
|
||||
attention_head_dim=64,
|
||||
num_attention_heads=32,
|
||||
caption_channels=4096,
|
||||
num_layers=28,
|
||||
positional_embedding_theta=10000.0,
|
||||
positional_embedding_max_pos=[20, 2048, 2048],
|
||||
causal_temporal_positioning=False,
|
||||
vae_scale_factors=(8, 32, 32),
|
||||
use_middle_indices_grid=False,
|
||||
timestep_scale_multiplier = 1000.0,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
in_channels=in_channels,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attention_head_dim=attention_head_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
caption_channels=caption_channels,
|
||||
num_layers=num_layers,
|
||||
positional_embedding_theta=positional_embedding_theta,
|
||||
positional_embedding_max_pos=positional_embedding_max_pos,
|
||||
causal_temporal_positioning=causal_temporal_positioning,
|
||||
vae_scale_factors=vae_scale_factors,
|
||||
use_middle_indices_grid=use_middle_indices_grid,
|
||||
timestep_scale_multiplier=timestep_scale_multiplier,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _init_model_components(self, device, dtype, **kwargs):
|
||||
"""Initialize LTXV-specific components."""
|
||||
# No additional components needed for LTXV beyond base class
|
||||
pass
|
||||
|
||||
def _init_transformer_blocks(self, device, dtype, **kwargs):
|
||||
"""Initialize transformer blocks for LTXV."""
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
self.inner_dim,
|
||||
self.num_attention_heads,
|
||||
self.attention_head_dim,
|
||||
context_dim=self.cross_attention_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
context_dim=cross_attention_dim,
|
||||
# attn_precision=attn_precision,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
for _ in range(self.num_layers)
|
||||
for d in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
def _init_output_components(self, device, dtype):
|
||||
"""Initialize output components for LTXV."""
|
||||
self.scale_shift_table = nn.Parameter(torch.empty(2, self.inner_dim, dtype=dtype, device=device))
|
||||
self.norm_out = self.operations.LayerNorm(
|
||||
self.inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device
|
||||
)
|
||||
self.proj_out = self.operations.Linear(self.inner_dim, self.out_channels, dtype=dtype, device=device)
|
||||
self.patchifier = SymmetricPatchifier(1, start_end=True)
|
||||
self.norm_out = operations.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.proj_out = operations.Linear(self.inner_dim, self.out_channels, dtype=dtype, device=device)
|
||||
|
||||
self.patchifier = SymmetricPatchifier(1)
|
||||
|
||||
def forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timestep, context, attention_mask, frame_rate, transformer_options, keyframe_idxs, **kwargs)
|
||||
|
||||
def _forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs):
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
|
||||
orig_shape = list(x.shape)
|
||||
|
||||
def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs):
|
||||
"""Process input for LTXV."""
|
||||
additional_args = {"orig_shape": list(x.shape)}
|
||||
x, latent_coords = self.patchifier.patchify(x)
|
||||
pixel_coords = latent_to_pixel_coords(
|
||||
latent_coords=latent_coords,
|
||||
@@ -880,30 +439,44 @@ class LTXVModel(LTXBaseModel):
|
||||
causal_fix=self.causal_temporal_positioning,
|
||||
)
|
||||
|
||||
grid_mask = None
|
||||
if keyframe_idxs is not None:
|
||||
additional_args.update({ "orig_patchified_shape": list(x.shape)})
|
||||
denoise_mask = self.patchifier.patchify(denoise_mask)[0]
|
||||
grid_mask = ~torch.any(denoise_mask < 0, dim=-1)[0]
|
||||
additional_args.update({"grid_mask": grid_mask})
|
||||
x = x[:, grid_mask, :]
|
||||
pixel_coords = pixel_coords[:, :, grid_mask, ...]
|
||||
pixel_coords[:, :, -keyframe_idxs.shape[2]:] = keyframe_idxs
|
||||
|
||||
kf_grid_mask = grid_mask[-keyframe_idxs.shape[2]:]
|
||||
keyframe_idxs = keyframe_idxs[..., kf_grid_mask, :]
|
||||
pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs
|
||||
fractional_coords = pixel_coords.to(torch.float32)
|
||||
fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate)
|
||||
|
||||
x = self.patchify_proj(x)
|
||||
return x, pixel_coords, additional_args
|
||||
timestep = timestep * 1000.0
|
||||
|
||||
if attention_mask is not None and not torch.is_floating_point(attention_mask):
|
||||
attention_mask = (attention_mask - 1).to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(x.dtype).max
|
||||
|
||||
pe = precompute_freqs_cis(fractional_coords, dim=self.inner_dim, out_dtype=x.dtype)
|
||||
|
||||
batch_size = x.shape[0]
|
||||
timestep, embedded_timestep = self.adaln_single(
|
||||
timestep.flatten(),
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=x.dtype,
|
||||
)
|
||||
# Second dimension is 1 or number of tokens (if timestep_per_token)
|
||||
timestep = timestep.view(batch_size, -1, timestep.shape[-1])
|
||||
embedded_timestep = embedded_timestep.view(
|
||||
batch_size, -1, embedded_timestep.shape[-1]
|
||||
)
|
||||
|
||||
# 2. Blocks
|
||||
if self.caption_projection is not None:
|
||||
batch_size = x.shape[0]
|
||||
context = self.caption_projection(context)
|
||||
context = context.view(
|
||||
batch_size, -1, x.shape[-1]
|
||||
)
|
||||
|
||||
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs):
|
||||
"""Process transformer blocks for LTXV."""
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"])
|
||||
@@ -921,28 +494,16 @@ class LTXVModel(LTXBaseModel):
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
return x
|
||||
|
||||
def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):
|
||||
"""Process output for LTXV."""
|
||||
# Apply scale-shift modulation
|
||||
# 3. Output
|
||||
scale_shift_values = (
|
||||
self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None]
|
||||
)
|
||||
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
|
||||
|
||||
x = self.norm_out(x)
|
||||
# Modulation
|
||||
x = x * (1 + scale) + shift
|
||||
x = self.proj_out(x)
|
||||
|
||||
if keyframe_idxs is not None:
|
||||
grid_mask = kwargs["grid_mask"]
|
||||
orig_patchified_shape = kwargs["orig_patchified_shape"]
|
||||
full_x = torch.zeros(orig_patchified_shape, dtype=x.dtype, device=x.device)
|
||||
full_x[:, grid_mask, :] = x
|
||||
x = full_x
|
||||
# Unpatchify to restore original dimensions
|
||||
orig_shape = kwargs["orig_shape"]
|
||||
x = self.patchifier.unpatchify(
|
||||
latents=x,
|
||||
output_height=orig_shape[3],
|
||||
|
||||
@@ -21,23 +21,20 @@ def latent_to_pixel_coords(
|
||||
Returns:
|
||||
Tensor: A tensor of pixel coordinates corresponding to the input latent coordinates.
|
||||
"""
|
||||
shape = [1] * latent_coords.ndim
|
||||
shape[1] = -1
|
||||
pixel_coords = (
|
||||
latent_coords
|
||||
* torch.tensor(scale_factors, device=latent_coords.device).view(*shape)
|
||||
* torch.tensor(scale_factors, device=latent_coords.device)[None, :, None]
|
||||
)
|
||||
if causal_fix:
|
||||
# Fix temporal scale for first frame to 1 due to causality
|
||||
pixel_coords[:, 0, ...] = (pixel_coords[:, 0, ...] + 1 - scale_factors[0]).clamp(min=0)
|
||||
pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - scale_factors[0]).clamp(min=0)
|
||||
return pixel_coords
|
||||
|
||||
|
||||
class Patchifier(ABC):
|
||||
def __init__(self, patch_size: int, start_end: bool=False):
|
||||
def __init__(self, patch_size: int):
|
||||
super().__init__()
|
||||
self._patch_size = (1, patch_size, patch_size)
|
||||
self.start_end = start_end
|
||||
|
||||
@abstractmethod
|
||||
def patchify(
|
||||
@@ -74,23 +71,11 @@ class Patchifier(ABC):
|
||||
torch.arange(0, latent_width, self._patch_size[2], device=device),
|
||||
indexing="ij",
|
||||
)
|
||||
latent_sample_coords_start = torch.stack(latent_sample_coords, dim=0)
|
||||
delta = torch.tensor(self._patch_size, device=latent_sample_coords_start.device, dtype=latent_sample_coords_start.dtype)[:, None, None, None]
|
||||
latent_sample_coords_end = latent_sample_coords_start + delta
|
||||
|
||||
latent_sample_coords_start = latent_sample_coords_start.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
|
||||
latent_sample_coords_start = rearrange(
|
||||
latent_sample_coords_start, "b c f h w -> b c (f h w)", b=batch_size
|
||||
latent_sample_coords = torch.stack(latent_sample_coords, dim=0)
|
||||
latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
|
||||
latent_coords = rearrange(
|
||||
latent_coords, "b c f h w -> b c (f h w)", b=batch_size
|
||||
)
|
||||
if self.start_end:
|
||||
latent_sample_coords_end = latent_sample_coords_end.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
|
||||
latent_sample_coords_end = rearrange(
|
||||
latent_sample_coords_end, "b c f h w -> b c (f h w)", b=batch_size
|
||||
)
|
||||
|
||||
latent_coords = torch.stack((latent_sample_coords_start, latent_sample_coords_end), dim=-1)
|
||||
else:
|
||||
latent_coords = latent_sample_coords_start
|
||||
return latent_coords
|
||||
|
||||
|
||||
@@ -130,61 +115,3 @@ class SymmetricPatchifier(Patchifier):
|
||||
q=self._patch_size[2],
|
||||
)
|
||||
return latents
|
||||
|
||||
|
||||
class AudioPatchifier(Patchifier):
|
||||
def __init__(self, patch_size: int,
|
||||
sample_rate=16000,
|
||||
hop_length=160,
|
||||
audio_latent_downsample_factor=4,
|
||||
is_causal=True,
|
||||
start_end=False,
|
||||
shift = 0
|
||||
):
|
||||
super().__init__(patch_size, start_end=start_end)
|
||||
self.hop_length = hop_length
|
||||
self.sample_rate = sample_rate
|
||||
self.audio_latent_downsample_factor = audio_latent_downsample_factor
|
||||
self.is_causal = is_causal
|
||||
self.shift = shift
|
||||
|
||||
def copy_with_shift(self, shift):
|
||||
return AudioPatchifier(
|
||||
self.patch_size, self.sample_rate, self.hop_length, self.audio_latent_downsample_factor,
|
||||
self.is_causal, self.start_end, shift
|
||||
)
|
||||
|
||||
def _get_audio_latent_time_in_sec(self, start_latent, end_latent: int, dtype: torch.dtype, device=torch.device):
|
||||
audio_latent_frame = torch.arange(start_latent, end_latent, dtype=dtype, device=device)
|
||||
audio_mel_frame = audio_latent_frame * self.audio_latent_downsample_factor
|
||||
if self.is_causal:
|
||||
audio_mel_frame = (audio_mel_frame + 1 - self.audio_latent_downsample_factor).clip(min=0)
|
||||
return audio_mel_frame * self.hop_length / self.sample_rate
|
||||
|
||||
|
||||
def patchify(self, audio_latents: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# audio_latents: (batch, channels, time, freq)
|
||||
b, _, t, _ = audio_latents.shape
|
||||
audio_latents = rearrange(
|
||||
audio_latents,
|
||||
"b c t f -> b t (c f)",
|
||||
)
|
||||
|
||||
audio_latents_start_timings = self._get_audio_latent_time_in_sec(self.shift, t + self.shift, torch.float32, audio_latents.device)
|
||||
audio_latents_start_timings = audio_latents_start_timings.unsqueeze(0).expand(b, -1).unsqueeze(1)
|
||||
|
||||
if self.start_end:
|
||||
audio_latents_end_timings = self._get_audio_latent_time_in_sec(self.shift + 1, t + self.shift + 1, torch.float32, audio_latents.device)
|
||||
audio_latents_end_timings = audio_latents_end_timings.unsqueeze(0).expand(b, -1).unsqueeze(1)
|
||||
|
||||
audio_latents_timings = torch.stack([audio_latents_start_timings, audio_latents_end_timings], dim=-1)
|
||||
else:
|
||||
audio_latents_timings = audio_latents_start_timings
|
||||
return audio_latents, audio_latents_timings
|
||||
|
||||
def unpatchify(self, audio_latents: torch.Tensor, channels: int, freq: int) -> torch.Tensor:
|
||||
# audio_latents: (batch, time, freq * channels)
|
||||
audio_latents = rearrange(
|
||||
audio_latents, "b t (c f) -> b c t f", c=channels, f=freq
|
||||
)
|
||||
return audio_latents
|
||||
|
||||
@@ -1,286 +0,0 @@
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
import math
|
||||
import torch
|
||||
import torchaudio
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.model_patcher
|
||||
import comfy.utils as utils
|
||||
from comfy.ldm.mmaudio.vae.distributions import DiagonalGaussianDistribution
|
||||
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
|
||||
from comfy.ldm.lightricks.vae.causal_audio_autoencoder import (
|
||||
CausalityAxis,
|
||||
CausalAudioAutoencoder,
|
||||
)
|
||||
from comfy.ldm.lightricks.vocoders.vocoder import Vocoder
|
||||
|
||||
LATENT_DOWNSAMPLE_FACTOR = 4
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AudioVAEComponentConfig:
|
||||
"""Container for model component configuration extracted from metadata."""
|
||||
|
||||
autoencoder: dict
|
||||
vocoder: dict
|
||||
|
||||
@classmethod
|
||||
def from_metadata(cls, metadata: dict) -> "AudioVAEComponentConfig":
|
||||
assert metadata is not None and "config" in metadata, "Metadata is required for audio VAE"
|
||||
|
||||
raw_config = metadata["config"]
|
||||
if isinstance(raw_config, str):
|
||||
parsed_config = json.loads(raw_config)
|
||||
else:
|
||||
parsed_config = raw_config
|
||||
|
||||
audio_config = parsed_config.get("audio_vae")
|
||||
vocoder_config = parsed_config.get("vocoder")
|
||||
|
||||
assert audio_config is not None, "Audio VAE config is required for audio VAE"
|
||||
assert vocoder_config is not None, "Vocoder config is required for audio VAE"
|
||||
|
||||
return cls(autoencoder=audio_config, vocoder=vocoder_config)
|
||||
|
||||
|
||||
class ModelDeviceManager:
|
||||
"""Manages device placement and GPU residency for the composed model."""
|
||||
|
||||
def __init__(self, module: torch.nn.Module):
|
||||
load_device = comfy.model_management.get_torch_device()
|
||||
offload_device = comfy.model_management.vae_offload_device()
|
||||
self.patcher = comfy.model_patcher.ModelPatcher(module, load_device, offload_device)
|
||||
|
||||
def ensure_model_loaded(self) -> None:
|
||||
comfy.model_management.free_memory(
|
||||
self.patcher.model_size(),
|
||||
self.patcher.load_device,
|
||||
)
|
||||
comfy.model_management.load_model_gpu(self.patcher)
|
||||
|
||||
def move_to_load_device(self, tensor: torch.Tensor) -> torch.Tensor:
|
||||
return tensor.to(self.patcher.load_device)
|
||||
|
||||
@property
|
||||
def load_device(self):
|
||||
return self.patcher.load_device
|
||||
|
||||
|
||||
class AudioLatentNormalizer:
|
||||
"""Applies per-channel statistics in patch space and restores original layout."""
|
||||
|
||||
def __init__(self, patchfier: AudioPatchifier, statistics_processor: torch.nn.Module):
|
||||
self.patchifier = patchfier
|
||||
self.statistics = statistics_processor
|
||||
|
||||
def normalize(self, latents: torch.Tensor) -> torch.Tensor:
|
||||
channels = latents.shape[1]
|
||||
freq = latents.shape[3]
|
||||
patched, _ = self.patchifier.patchify(latents)
|
||||
normalized = self.statistics.normalize(patched)
|
||||
return self.patchifier.unpatchify(normalized, channels=channels, freq=freq)
|
||||
|
||||
def denormalize(self, latents: torch.Tensor) -> torch.Tensor:
|
||||
channels = latents.shape[1]
|
||||
freq = latents.shape[3]
|
||||
patched, _ = self.patchifier.patchify(latents)
|
||||
denormalized = self.statistics.un_normalize(patched)
|
||||
return self.patchifier.unpatchify(denormalized, channels=channels, freq=freq)
|
||||
|
||||
|
||||
class AudioPreprocessor:
|
||||
"""Prepares raw waveforms for the autoencoder by matching training conditions."""
|
||||
|
||||
def __init__(self, target_sample_rate: int, mel_bins: int, mel_hop_length: int, n_fft: int):
|
||||
self.target_sample_rate = target_sample_rate
|
||||
self.mel_bins = mel_bins
|
||||
self.mel_hop_length = mel_hop_length
|
||||
self.n_fft = n_fft
|
||||
|
||||
def resample(self, waveform: torch.Tensor, source_rate: int) -> torch.Tensor:
|
||||
if source_rate == self.target_sample_rate:
|
||||
return waveform
|
||||
return torchaudio.functional.resample(waveform, source_rate, self.target_sample_rate)
|
||||
|
||||
@staticmethod
|
||||
def normalize_amplitude(
|
||||
waveform: torch.Tensor, max_amplitude: float = 0.5, eps: float = 1e-5
|
||||
) -> torch.Tensor:
|
||||
waveform = waveform - waveform.mean(dim=2, keepdim=True)
|
||||
peak = torch.max(torch.abs(waveform)) + eps
|
||||
scale = peak.clamp(max=max_amplitude) / peak
|
||||
return waveform * scale
|
||||
|
||||
def waveform_to_mel(
|
||||
self, waveform: torch.Tensor, waveform_sample_rate: int, device
|
||||
) -> torch.Tensor:
|
||||
waveform = self.resample(waveform, waveform_sample_rate)
|
||||
waveform = self.normalize_amplitude(waveform)
|
||||
|
||||
mel_transform = torchaudio.transforms.MelSpectrogram(
|
||||
sample_rate=self.target_sample_rate,
|
||||
n_fft=self.n_fft,
|
||||
win_length=self.n_fft,
|
||||
hop_length=self.mel_hop_length,
|
||||
f_min=0.0,
|
||||
f_max=self.target_sample_rate / 2.0,
|
||||
n_mels=self.mel_bins,
|
||||
window_fn=torch.hann_window,
|
||||
center=True,
|
||||
pad_mode="reflect",
|
||||
power=1.0,
|
||||
mel_scale="slaney",
|
||||
norm="slaney",
|
||||
).to(device)
|
||||
|
||||
mel = mel_transform(waveform)
|
||||
mel = torch.log(torch.clamp(mel, min=1e-5))
|
||||
return mel.permute(0, 1, 3, 2).contiguous()
|
||||
|
||||
|
||||
class AudioVAE(torch.nn.Module):
|
||||
"""High-level Audio VAE wrapper exposing encode and decode entry points."""
|
||||
|
||||
def __init__(self, state_dict: dict, metadata: dict):
|
||||
super().__init__()
|
||||
|
||||
component_config = AudioVAEComponentConfig.from_metadata(metadata)
|
||||
|
||||
vae_sd = utils.state_dict_prefix_replace(state_dict, {"audio_vae.": ""}, filter_keys=True)
|
||||
vocoder_sd = utils.state_dict_prefix_replace(state_dict, {"vocoder.": ""}, filter_keys=True)
|
||||
|
||||
self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder)
|
||||
self.vocoder = Vocoder(config=component_config.vocoder)
|
||||
|
||||
self.autoencoder.load_state_dict(vae_sd, strict=False)
|
||||
self.vocoder.load_state_dict(vocoder_sd, strict=False)
|
||||
|
||||
autoencoder_config = self.autoencoder.get_config()
|
||||
self.normalizer = AudioLatentNormalizer(
|
||||
AudioPatchifier(
|
||||
patch_size=1,
|
||||
audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
|
||||
sample_rate=autoencoder_config["sampling_rate"],
|
||||
hop_length=autoencoder_config["mel_hop_length"],
|
||||
is_causal=autoencoder_config["is_causal"],
|
||||
),
|
||||
self.autoencoder.per_channel_statistics,
|
||||
)
|
||||
|
||||
self.preprocessor = AudioPreprocessor(
|
||||
target_sample_rate=autoencoder_config["sampling_rate"],
|
||||
mel_bins=autoencoder_config["mel_bins"],
|
||||
mel_hop_length=autoencoder_config["mel_hop_length"],
|
||||
n_fft=autoencoder_config["n_fft"],
|
||||
)
|
||||
|
||||
self.device_manager = ModelDeviceManager(self)
|
||||
|
||||
def encode(self, audio: dict) -> torch.Tensor:
|
||||
"""Encode a waveform dictionary into normalized latent tensors."""
|
||||
|
||||
waveform = audio["waveform"]
|
||||
waveform_sample_rate = audio["sample_rate"]
|
||||
input_device = waveform.device
|
||||
# Ensure that Audio VAE is loaded on the correct device.
|
||||
self.device_manager.ensure_model_loaded()
|
||||
|
||||
waveform = self.device_manager.move_to_load_device(waveform)
|
||||
expected_channels = self.autoencoder.encoder.in_channels
|
||||
if waveform.shape[1] != expected_channels:
|
||||
raise ValueError(
|
||||
f"Input audio must have {expected_channels} channels, got {waveform.shape[1]}"
|
||||
)
|
||||
|
||||
mel_spec = self.preprocessor.waveform_to_mel(
|
||||
waveform, waveform_sample_rate, device=self.device_manager.load_device
|
||||
)
|
||||
|
||||
latents = self.autoencoder.encode(mel_spec)
|
||||
posterior = DiagonalGaussianDistribution(latents)
|
||||
latent_mode = posterior.mode()
|
||||
|
||||
normalized = self.normalizer.normalize(latent_mode)
|
||||
return normalized.to(input_device)
|
||||
|
||||
def decode(self, latents: torch.Tensor) -> torch.Tensor:
|
||||
"""Decode normalized latent tensors into an audio waveform."""
|
||||
original_shape = latents.shape
|
||||
|
||||
# Ensure that Audio VAE is loaded on the correct device.
|
||||
self.device_manager.ensure_model_loaded()
|
||||
|
||||
latents = self.device_manager.move_to_load_device(latents)
|
||||
latents = self.normalizer.denormalize(latents)
|
||||
|
||||
target_shape = self.target_shape_from_latents(original_shape)
|
||||
mel_spec = self.autoencoder.decode(latents, target_shape=target_shape)
|
||||
|
||||
waveform = self.run_vocoder(mel_spec)
|
||||
return self.device_manager.move_to_load_device(waveform)
|
||||
|
||||
def target_shape_from_latents(self, latents_shape):
|
||||
batch, _, time, _ = latents_shape
|
||||
target_length = time * LATENT_DOWNSAMPLE_FACTOR
|
||||
if self.autoencoder.causality_axis != CausalityAxis.NONE:
|
||||
target_length -= LATENT_DOWNSAMPLE_FACTOR - 1
|
||||
return (
|
||||
batch,
|
||||
self.autoencoder.decoder.out_ch,
|
||||
target_length,
|
||||
self.autoencoder.mel_bins,
|
||||
)
|
||||
|
||||
def num_of_latents_from_frames(self, frames_number: int, frame_rate: int) -> int:
|
||||
return math.ceil((float(frames_number) / frame_rate) * self.latents_per_second)
|
||||
|
||||
def run_vocoder(self, mel_spec: torch.Tensor) -> torch.Tensor:
|
||||
audio_channels = self.autoencoder.decoder.out_ch
|
||||
vocoder_input = mel_spec.transpose(2, 3)
|
||||
|
||||
if audio_channels == 1:
|
||||
vocoder_input = vocoder_input.squeeze(1)
|
||||
elif audio_channels != 2:
|
||||
raise ValueError(f"Unsupported audio_channels: {audio_channels}")
|
||||
|
||||
return self.vocoder(vocoder_input)
|
||||
|
||||
@property
|
||||
def sample_rate(self) -> int:
|
||||
return int(self.autoencoder.sampling_rate)
|
||||
|
||||
@property
|
||||
def mel_hop_length(self) -> int:
|
||||
return int(self.autoencoder.mel_hop_length)
|
||||
|
||||
@property
|
||||
def mel_bins(self) -> int:
|
||||
return int(self.autoencoder.mel_bins)
|
||||
|
||||
@property
|
||||
def latent_channels(self) -> int:
|
||||
return int(self.autoencoder.decoder.z_channels)
|
||||
|
||||
@property
|
||||
def latent_frequency_bins(self) -> int:
|
||||
return int(self.mel_bins // LATENT_DOWNSAMPLE_FACTOR)
|
||||
|
||||
@property
|
||||
def latents_per_second(self) -> float:
|
||||
return self.sample_rate / self.mel_hop_length / LATENT_DOWNSAMPLE_FACTOR
|
||||
|
||||
@property
|
||||
def output_sample_rate(self) -> int:
|
||||
output_rate = getattr(self.vocoder, "output_sample_rate", None)
|
||||
if output_rate is not None:
|
||||
return int(output_rate)
|
||||
upsample_factor = getattr(self.vocoder, "upsample_factor", None)
|
||||
if upsample_factor is None:
|
||||
raise AttributeError(
|
||||
"Vocoder is missing upsample_factor; cannot infer output sample rate"
|
||||
)
|
||||
return int(self.sample_rate * upsample_factor / self.mel_hop_length)
|
||||
|
||||
def memory_required(self, input_shape):
|
||||
return self.device_manager.patcher.model_size()
|
||||
@@ -1,909 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from typing import Optional
|
||||
from enum import Enum
|
||||
from .pixel_norm import PixelNorm
|
||||
import comfy.ops
|
||||
import logging
|
||||
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
|
||||
class StringConvertibleEnum(Enum):
|
||||
"""
|
||||
Base enum class that provides string-to-enum conversion functionality.
|
||||
|
||||
This mixin adds a str_to_enum() class method that handles conversion from
|
||||
strings, None, or existing enum instances with case-insensitive matching.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def str_to_enum(cls, value):
|
||||
"""
|
||||
Convert a string, enum instance, or None to the appropriate enum member.
|
||||
|
||||
Args:
|
||||
value: Can be an enum instance of this class, a string, or None
|
||||
|
||||
Returns:
|
||||
Enum member of this class
|
||||
|
||||
Raises:
|
||||
ValueError: If the value cannot be converted to a valid enum member
|
||||
"""
|
||||
# Already an enum instance of this class
|
||||
if isinstance(value, cls):
|
||||
return value
|
||||
|
||||
# None maps to NONE member if it exists
|
||||
if value is None:
|
||||
if hasattr(cls, "NONE"):
|
||||
return cls.NONE
|
||||
raise ValueError(f"{cls.__name__} does not have a NONE member to map None to")
|
||||
|
||||
# String conversion (case-insensitive)
|
||||
if isinstance(value, str):
|
||||
value_lower = value.lower()
|
||||
|
||||
# Try to match against enum values
|
||||
for member in cls:
|
||||
# Handle members with None values
|
||||
if member.value is None:
|
||||
if value_lower == "none":
|
||||
return member
|
||||
# Handle members with string values
|
||||
elif isinstance(member.value, str) and member.value.lower() == value_lower:
|
||||
return member
|
||||
|
||||
# Build helpful error message with valid values
|
||||
valid_values = []
|
||||
for member in cls:
|
||||
if member.value is None:
|
||||
valid_values.append("none")
|
||||
elif isinstance(member.value, str):
|
||||
valid_values.append(member.value)
|
||||
|
||||
raise ValueError(f"Invalid {cls.__name__} string: '{value}'. " f"Valid values are: {valid_values}")
|
||||
|
||||
raise ValueError(
|
||||
f"Cannot convert type {type(value).__name__} to {cls.__name__} enum. "
|
||||
f"Expected string, None, or {cls.__name__} instance."
|
||||
)
|
||||
|
||||
|
||||
class AttentionType(StringConvertibleEnum):
|
||||
"""Enum for specifying the attention mechanism type."""
|
||||
|
||||
VANILLA = "vanilla"
|
||||
LINEAR = "linear"
|
||||
NONE = "none"
|
||||
|
||||
|
||||
class CausalityAxis(StringConvertibleEnum):
|
||||
"""Enum for specifying the causality axis in causal convolutions."""
|
||||
|
||||
NONE = None
|
||||
WIDTH = "width"
|
||||
HEIGHT = "height"
|
||||
WIDTH_COMPATIBILITY = "width-compatibility"
|
||||
|
||||
|
||||
def Normalize(in_channels, *, num_groups=32, normtype="group"):
|
||||
if normtype == "group":
|
||||
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
elif normtype == "pixel":
|
||||
return PixelNorm(dim=1, eps=1e-6)
|
||||
else:
|
||||
raise ValueError(f"Invalid normalization type: {normtype}")
|
||||
|
||||
|
||||
class CausalConv2d(nn.Module):
|
||||
"""
|
||||
A causal 2D convolution.
|
||||
|
||||
This layer ensures that the output at time `t` only depends on inputs
|
||||
at time `t` and earlier. It achieves this by applying asymmetric padding
|
||||
to the time dimension (width) before the convolution.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=True,
|
||||
causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.causality_axis = causality_axis
|
||||
|
||||
# Ensure kernel_size and dilation are tuples
|
||||
kernel_size = nn.modules.utils._pair(kernel_size)
|
||||
dilation = nn.modules.utils._pair(dilation)
|
||||
|
||||
# Calculate padding dimensions
|
||||
pad_h = (kernel_size[0] - 1) * dilation[0]
|
||||
pad_w = (kernel_size[1] - 1) * dilation[1]
|
||||
|
||||
# The padding tuple for F.pad is (pad_left, pad_right, pad_top, pad_bottom)
|
||||
match self.causality_axis:
|
||||
case CausalityAxis.NONE:
|
||||
self.padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
|
||||
case CausalityAxis.WIDTH | CausalityAxis.WIDTH_COMPATIBILITY:
|
||||
self.padding = (pad_w, 0, pad_h // 2, pad_h - pad_h // 2)
|
||||
case CausalityAxis.HEIGHT:
|
||||
self.padding = (pad_w // 2, pad_w - pad_w // 2, pad_h, 0)
|
||||
case _:
|
||||
raise ValueError(f"Invalid causality_axis: {causality_axis}")
|
||||
|
||||
# The internal convolution layer uses no padding, as we handle it manually
|
||||
self.conv = ops.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=stride,
|
||||
padding=0,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
# Apply causal padding before convolution
|
||||
x = F.pad(x, self.padding)
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
def make_conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
padding=None,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=True,
|
||||
causality_axis: Optional[CausalityAxis] = None,
|
||||
):
|
||||
"""
|
||||
Create a 2D convolution layer that can be either causal or non-causal.
|
||||
|
||||
Args:
|
||||
in_channels: Number of input channels
|
||||
out_channels: Number of output channels
|
||||
kernel_size: Size of the convolution kernel
|
||||
stride: Convolution stride
|
||||
padding: Padding (if None, will be calculated based on causal flag)
|
||||
dilation: Dilation rate
|
||||
groups: Number of groups for grouped convolution
|
||||
bias: Whether to use bias
|
||||
causality_axis: Dimension along which to apply causality.
|
||||
|
||||
Returns:
|
||||
Either a regular Conv2d or CausalConv2d layer
|
||||
"""
|
||||
if causality_axis is not None:
|
||||
# For causal convolution, padding is handled internally by CausalConv2d
|
||||
return CausalConv2d(in_channels, out_channels, kernel_size, stride, dilation, groups, bias, causality_axis)
|
||||
else:
|
||||
# For non-causal convolution, use symmetric padding if not specified
|
||||
if padding is None:
|
||||
if isinstance(kernel_size, int):
|
||||
padding = kernel_size // 2
|
||||
else:
|
||||
padding = tuple(k // 2 for k in kernel_size)
|
||||
return ops.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
groups,
|
||||
bias,
|
||||
)
|
||||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
def __init__(self, in_channels, with_conv, causality_axis: CausalityAxis = CausalityAxis.HEIGHT):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
self.causality_axis = causality_axis
|
||||
if self.with_conv:
|
||||
self.conv = make_conv2d(in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis)
|
||||
|
||||
def forward(self, x):
|
||||
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
if self.with_conv:
|
||||
x = self.conv(x)
|
||||
# Drop FIRST element in the causal axis to undo encoder's padding, while keeping the length 1 + 2 * n.
|
||||
# For example, if the input is [0, 1, 2], after interpolation, the output is [0, 0, 1, 1, 2, 2].
|
||||
# The causal convolution will pad the first element as [-, -, 0, 0, 1, 1, 2, 2],
|
||||
# So the output elements rely on the following windows:
|
||||
# 0: [-,-,0]
|
||||
# 1: [-,0,0]
|
||||
# 2: [0,0,1]
|
||||
# 3: [0,1,1]
|
||||
# 4: [1,1,2]
|
||||
# 5: [1,2,2]
|
||||
# Notice that the first and second elements in the output rely only on the first element in the input,
|
||||
# while all other elements rely on two elements in the input.
|
||||
# So we can drop the first element to undo the padding (rather than the last element).
|
||||
# This is a no-op for non-causal convolutions.
|
||||
match self.causality_axis:
|
||||
case CausalityAxis.NONE:
|
||||
pass # x remains unchanged
|
||||
case CausalityAxis.HEIGHT:
|
||||
x = x[:, :, 1:, :]
|
||||
case CausalityAxis.WIDTH:
|
||||
x = x[:, :, :, 1:]
|
||||
case CausalityAxis.WIDTH_COMPATIBILITY:
|
||||
pass # x remains unchanged
|
||||
case _:
|
||||
raise ValueError(f"Invalid causality_axis: {self.causality_axis}")
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class Downsample(nn.Module):
|
||||
"""
|
||||
A downsampling layer that can use either a strided convolution
|
||||
or average pooling. Supports standard and causal padding for the
|
||||
convolutional mode.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, with_conv, causality_axis: CausalityAxis = CausalityAxis.WIDTH):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
self.causality_axis = causality_axis
|
||||
|
||||
if self.causality_axis != CausalityAxis.NONE and not self.with_conv:
|
||||
raise ValueError("causality is only supported when `with_conv=True`.")
|
||||
|
||||
if self.with_conv:
|
||||
# Do time downsampling here
|
||||
# no asymmetric padding in torch conv, must do it ourselves
|
||||
self.conv = ops.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
if self.with_conv:
|
||||
# (pad_left, pad_right, pad_top, pad_bottom)
|
||||
match self.causality_axis:
|
||||
case CausalityAxis.NONE:
|
||||
pad = (0, 1, 0, 1)
|
||||
case CausalityAxis.WIDTH:
|
||||
pad = (2, 0, 0, 1)
|
||||
case CausalityAxis.HEIGHT:
|
||||
pad = (0, 1, 2, 0)
|
||||
case CausalityAxis.WIDTH_COMPATIBILITY:
|
||||
pad = (1, 0, 0, 1)
|
||||
case _:
|
||||
raise ValueError(f"Invalid causality_axis: {self.causality_axis}")
|
||||
|
||||
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
||||
x = self.conv(x)
|
||||
else:
|
||||
# This branch is only taken if with_conv=False, which implies causality_axis is NONE.
|
||||
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class ResnetBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
in_channels,
|
||||
out_channels=None,
|
||||
conv_shortcut=False,
|
||||
dropout,
|
||||
temb_channels=512,
|
||||
norm_type="group",
|
||||
causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
|
||||
):
|
||||
super().__init__()
|
||||
self.causality_axis = causality_axis
|
||||
|
||||
if self.causality_axis != CausalityAxis.NONE and norm_type == "group":
|
||||
raise ValueError("Causal ResnetBlock with GroupNorm is not supported.")
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
self.use_conv_shortcut = conv_shortcut
|
||||
|
||||
self.norm1 = Normalize(in_channels, normtype=norm_type)
|
||||
self.non_linearity = nn.SiLU()
|
||||
self.conv1 = make_conv2d(in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis)
|
||||
if temb_channels > 0:
|
||||
self.temb_proj = ops.Linear(temb_channels, out_channels)
|
||||
self.norm2 = Normalize(out_channels, normtype=norm_type)
|
||||
self.dropout = torch.nn.Dropout(dropout)
|
||||
self.conv2 = make_conv2d(out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis)
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
self.conv_shortcut = make_conv2d(
|
||||
in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
|
||||
)
|
||||
else:
|
||||
self.nin_shortcut = make_conv2d(
|
||||
in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis
|
||||
)
|
||||
|
||||
def forward(self, x, temb):
|
||||
h = x
|
||||
h = self.norm1(h)
|
||||
h = self.non_linearity(h)
|
||||
h = self.conv1(h)
|
||||
|
||||
if temb is not None:
|
||||
h = h + self.temb_proj(self.non_linearity(temb))[:, :, None, None]
|
||||
|
||||
h = self.norm2(h)
|
||||
h = self.non_linearity(h)
|
||||
h = self.dropout(h)
|
||||
h = self.conv2(h)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
x = self.conv_shortcut(x)
|
||||
else:
|
||||
x = self.nin_shortcut(x)
|
||||
|
||||
return x + h
|
||||
|
||||
|
||||
class AttnBlock(nn.Module):
|
||||
def __init__(self, in_channels, norm_type="group"):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = Normalize(in_channels, normtype=norm_type)
|
||||
self.q = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.k = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.v = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.proj_out = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
q = self.q(h_)
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
# compute attention
|
||||
b, c, h, w = q.shape
|
||||
q = q.reshape(b, c, h * w).contiguous()
|
||||
q = q.permute(0, 2, 1).contiguous() # b,hw,c
|
||||
k = k.reshape(b, c, h * w).contiguous() # b,c,hw
|
||||
w_ = torch.bmm(q, k).contiguous() # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
||||
w_ = w_ * (int(c) ** (-0.5))
|
||||
w_ = torch.nn.functional.softmax(w_, dim=2)
|
||||
|
||||
# attend to values
|
||||
v = v.reshape(b, c, h * w).contiguous()
|
||||
w_ = w_.permute(0, 2, 1).contiguous() # b,hw,hw (first hw of k, second of q)
|
||||
h_ = torch.bmm(v, w_).contiguous() # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
||||
h_ = h_.reshape(b, c, h, w).contiguous()
|
||||
|
||||
h_ = self.proj_out(h_)
|
||||
|
||||
return x + h_
|
||||
|
||||
|
||||
def make_attn(in_channels, attn_type="vanilla", norm_type="group"):
|
||||
# Convert string to enum if needed
|
||||
attn_type = AttentionType.str_to_enum(attn_type)
|
||||
|
||||
if attn_type != AttentionType.NONE:
|
||||
logging.info(f"making attention of type '{attn_type.value}' with {in_channels} in_channels")
|
||||
else:
|
||||
logging.info(f"making identity attention with {in_channels} in_channels")
|
||||
|
||||
match attn_type:
|
||||
case AttentionType.VANILLA:
|
||||
return AttnBlock(in_channels, norm_type=norm_type)
|
||||
case AttentionType.NONE:
|
||||
return nn.Identity(in_channels)
|
||||
case AttentionType.LINEAR:
|
||||
raise NotImplementedError(f"Attention type {attn_type.value} is not supported yet.")
|
||||
case _:
|
||||
raise ValueError(f"Unknown attention type: {attn_type}")
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ch,
|
||||
out_ch,
|
||||
ch_mult=(1, 2, 4, 8),
|
||||
num_res_blocks,
|
||||
attn_resolutions,
|
||||
dropout=0.0,
|
||||
resamp_with_conv=True,
|
||||
in_channels,
|
||||
resolution,
|
||||
z_channels,
|
||||
double_z=True,
|
||||
attn_type="vanilla",
|
||||
mid_block_add_attention=True,
|
||||
norm_type="group",
|
||||
causality_axis=CausalityAxis.WIDTH.value,
|
||||
**ignore_kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.z_channels = z_channels
|
||||
self.double_z = double_z
|
||||
self.norm_type = norm_type
|
||||
# Convert string to enum if needed (for config loading)
|
||||
causality_axis = CausalityAxis.str_to_enum(causality_axis)
|
||||
self.attn_type = AttentionType.str_to_enum(attn_type)
|
||||
|
||||
# downsampling
|
||||
self.conv_in = make_conv2d(
|
||||
in_channels,
|
||||
self.ch,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
causality_axis=causality_axis,
|
||||
)
|
||||
|
||||
self.non_linearity = nn.SiLU()
|
||||
|
||||
curr_res = resolution
|
||||
in_ch_mult = (1,) + tuple(ch_mult)
|
||||
self.in_ch_mult = in_ch_mult
|
||||
self.down = nn.ModuleList()
|
||||
|
||||
for i_level in range(self.num_resolutions):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_in = ch * in_ch_mult[i_level]
|
||||
block_out = ch * ch_mult[i_level]
|
||||
|
||||
for _ in range(self.num_res_blocks):
|
||||
block.append(
|
||||
ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=causality_axis,
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(make_attn(block_in, attn_type=self.attn_type, norm_type=self.norm_type))
|
||||
|
||||
down = nn.Module()
|
||||
down.block = block
|
||||
down.attn = attn
|
||||
if i_level != self.num_resolutions - 1:
|
||||
down.downsample = Downsample(block_in, resamp_with_conv, causality_axis=causality_axis)
|
||||
curr_res = curr_res // 2
|
||||
self.down.append(down)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=causality_axis,
|
||||
)
|
||||
if mid_block_add_attention:
|
||||
self.mid.attn_1 = make_attn(block_in, attn_type=self.attn_type, norm_type=self.norm_type)
|
||||
else:
|
||||
self.mid.attn_1 = nn.Identity()
|
||||
self.mid.block_2 = ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=causality_axis,
|
||||
)
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in, normtype=self.norm_type)
|
||||
self.conv_out = make_conv2d(
|
||||
block_in,
|
||||
2 * z_channels if double_z else z_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
causality_axis=causality_axis,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Forward pass through the encoder.
|
||||
|
||||
Args:
|
||||
x: Input tensor of shape [batch, channels, time, n_mels]
|
||||
|
||||
Returns:
|
||||
Encoded latent representation
|
||||
"""
|
||||
feature_maps = [self.conv_in(x)]
|
||||
|
||||
# Process each resolution level (from high to low resolution)
|
||||
for resolution_level in range(self.num_resolutions):
|
||||
# Apply residual blocks at current resolution level
|
||||
for block_idx in range(self.num_res_blocks):
|
||||
# Apply ResNet block with optional timestep embedding
|
||||
current_features = self.down[resolution_level].block[block_idx](feature_maps[-1], temb=None)
|
||||
|
||||
# Apply attention if configured for this resolution level
|
||||
if len(self.down[resolution_level].attn) > 0:
|
||||
current_features = self.down[resolution_level].attn[block_idx](current_features)
|
||||
|
||||
# Store processed features
|
||||
feature_maps.append(current_features)
|
||||
|
||||
# Downsample spatial dimensions (except at the final resolution level)
|
||||
if resolution_level != self.num_resolutions - 1:
|
||||
downsampled_features = self.down[resolution_level].downsample(feature_maps[-1])
|
||||
feature_maps.append(downsampled_features)
|
||||
|
||||
# === MIDDLE PROCESSING PHASE ===
|
||||
# Take the lowest resolution features for middle processing
|
||||
bottleneck_features = feature_maps[-1]
|
||||
|
||||
# Apply first middle ResNet block
|
||||
bottleneck_features = self.mid.block_1(bottleneck_features, temb=None)
|
||||
|
||||
# Apply middle attention block
|
||||
bottleneck_features = self.mid.attn_1(bottleneck_features)
|
||||
|
||||
# Apply second middle ResNet block
|
||||
bottleneck_features = self.mid.block_2(bottleneck_features, temb=None)
|
||||
|
||||
# === OUTPUT PHASE ===
|
||||
# Normalize the bottleneck features
|
||||
output_features = self.norm_out(bottleneck_features)
|
||||
|
||||
# Apply non-linearity (SiLU activation)
|
||||
output_features = self.non_linearity(output_features)
|
||||
|
||||
# Final convolution to produce latent representation
|
||||
# [batch, channels, time, n_mels] -> [batch, 2 * z_channels if double_z else z_channels, time, n_mels]
|
||||
return self.conv_out(output_features)
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ch,
|
||||
out_ch,
|
||||
ch_mult=(1, 2, 4, 8),
|
||||
num_res_blocks,
|
||||
attn_resolutions,
|
||||
dropout=0.0,
|
||||
resamp_with_conv=True,
|
||||
in_channels,
|
||||
resolution,
|
||||
z_channels,
|
||||
give_pre_end=False,
|
||||
tanh_out=False,
|
||||
attn_type="vanilla",
|
||||
mid_block_add_attention=True,
|
||||
norm_type="group",
|
||||
causality_axis=CausalityAxis.WIDTH.value,
|
||||
**ignorekwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.out_ch = out_ch
|
||||
self.give_pre_end = give_pre_end
|
||||
self.tanh_out = tanh_out
|
||||
self.norm_type = norm_type
|
||||
self.z_channels = z_channels
|
||||
# Convert string to enum if needed (for config loading)
|
||||
causality_axis = CausalityAxis.str_to_enum(causality_axis)
|
||||
self.attn_type = AttentionType.str_to_enum(attn_type)
|
||||
|
||||
# compute block_in and curr_res at lowest res
|
||||
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
||||
self.z_shape = (1, z_channels, curr_res, curr_res)
|
||||
|
||||
# z to block_in
|
||||
self.conv_in = make_conv2d(z_channels, block_in, kernel_size=3, stride=1, causality_axis=causality_axis)
|
||||
|
||||
self.non_linearity = nn.SiLU()
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=causality_axis,
|
||||
)
|
||||
if mid_block_add_attention:
|
||||
self.mid.attn_1 = make_attn(block_in, attn_type=self.attn_type, norm_type=self.norm_type)
|
||||
else:
|
||||
self.mid.attn_1 = nn.Identity()
|
||||
self.mid.block_2 = ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=causality_axis,
|
||||
)
|
||||
|
||||
# upsampling
|
||||
self.up = nn.ModuleList()
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for _ in range(self.num_res_blocks + 1):
|
||||
block.append(
|
||||
ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=causality_axis,
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(make_attn(block_in, attn_type=self.attn_type, norm_type=self.norm_type))
|
||||
up = nn.Module()
|
||||
up.block = block
|
||||
up.attn = attn
|
||||
if i_level != 0:
|
||||
up.upsample = Upsample(block_in, resamp_with_conv, causality_axis=causality_axis)
|
||||
curr_res = curr_res * 2
|
||||
self.up.insert(0, up) # prepend to get consistent order
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in, normtype=self.norm_type)
|
||||
self.conv_out = make_conv2d(block_in, out_ch, kernel_size=3, stride=1, causality_axis=causality_axis)
|
||||
|
||||
def _adjust_output_shape(self, decoded_output, target_shape):
|
||||
"""
|
||||
Adjust output shape to match target dimensions for variable-length audio.
|
||||
|
||||
This function handles the common case where decoded audio spectrograms need to be
|
||||
resized to match a specific target shape.
|
||||
|
||||
Args:
|
||||
decoded_output: Tensor of shape (batch, channels, time, frequency)
|
||||
target_shape: Target shape tuple (batch, channels, time, frequency)
|
||||
|
||||
Returns:
|
||||
Tensor adjusted to match target_shape exactly
|
||||
"""
|
||||
# Current output shape: (batch, channels, time, frequency)
|
||||
_, _, current_time, current_freq = decoded_output.shape
|
||||
_, target_channels, target_time, target_freq = target_shape
|
||||
|
||||
# Step 1: Crop first to avoid exceeding target dimensions
|
||||
decoded_output = decoded_output[
|
||||
:, :target_channels, : min(current_time, target_time), : min(current_freq, target_freq)
|
||||
]
|
||||
|
||||
# Step 2: Calculate padding needed for time and frequency dimensions
|
||||
time_padding_needed = target_time - decoded_output.shape[2]
|
||||
freq_padding_needed = target_freq - decoded_output.shape[3]
|
||||
|
||||
# Step 3: Apply padding if needed
|
||||
if time_padding_needed > 0 or freq_padding_needed > 0:
|
||||
# PyTorch padding format: (pad_left, pad_right, pad_top, pad_bottom)
|
||||
# For audio: pad_left/right = frequency, pad_top/bottom = time
|
||||
padding = (
|
||||
0,
|
||||
max(freq_padding_needed, 0), # frequency padding (left, right)
|
||||
0,
|
||||
max(time_padding_needed, 0), # time padding (top, bottom)
|
||||
)
|
||||
decoded_output = F.pad(decoded_output, padding)
|
||||
|
||||
# Step 4: Final safety crop to ensure exact target shape
|
||||
decoded_output = decoded_output[:, :target_channels, :target_time, :target_freq]
|
||||
|
||||
return decoded_output
|
||||
|
||||
def get_config(self):
|
||||
return {
|
||||
"ch": self.ch,
|
||||
"out_ch": self.out_ch,
|
||||
"ch_mult": self.ch_mult,
|
||||
"num_res_blocks": self.num_res_blocks,
|
||||
"in_channels": self.in_channels,
|
||||
"resolution": self.resolution,
|
||||
"z_channels": self.z_channels,
|
||||
}
|
||||
|
||||
def forward(self, latent_features, target_shape=None):
|
||||
"""
|
||||
Decode latent features back to audio spectrograms.
|
||||
|
||||
Args:
|
||||
latent_features: Encoded latent representation of shape (batch, channels, height, width)
|
||||
target_shape: Optional target output shape (batch, channels, time, frequency)
|
||||
If provided, output will be cropped/padded to match this shape
|
||||
|
||||
Returns:
|
||||
Reconstructed audio spectrogram of shape (batch, channels, time, frequency)
|
||||
"""
|
||||
assert target_shape is not None, "Target shape is required for CausalAudioAutoencoder Decoder"
|
||||
|
||||
# Transform latent features to decoder's internal feature dimension
|
||||
hidden_features = self.conv_in(latent_features)
|
||||
|
||||
# Middle processing
|
||||
hidden_features = self.mid.block_1(hidden_features, temb=None)
|
||||
hidden_features = self.mid.attn_1(hidden_features)
|
||||
hidden_features = self.mid.block_2(hidden_features, temb=None)
|
||||
|
||||
# Upsampling
|
||||
# Progressively increase spatial resolution from lowest to highest
|
||||
for resolution_level in reversed(range(self.num_resolutions)):
|
||||
# Apply residual blocks at current resolution level
|
||||
for block_index in range(self.num_res_blocks + 1):
|
||||
hidden_features = self.up[resolution_level].block[block_index](hidden_features, temb=None)
|
||||
|
||||
if len(self.up[resolution_level].attn) > 0:
|
||||
hidden_features = self.up[resolution_level].attn[block_index](hidden_features)
|
||||
|
||||
if resolution_level != 0:
|
||||
hidden_features = self.up[resolution_level].upsample(hidden_features)
|
||||
|
||||
# Output
|
||||
if self.give_pre_end:
|
||||
# Return intermediate features before final processing (for debugging/analysis)
|
||||
decoded_output = hidden_features
|
||||
else:
|
||||
# Standard output path: normalize, activate, and convert to output channels
|
||||
# Final normalization layer
|
||||
hidden_features = self.norm_out(hidden_features)
|
||||
|
||||
# Apply SiLU (Swish) activation function
|
||||
hidden_features = self.non_linearity(hidden_features)
|
||||
|
||||
# Final convolution to map to output channels (typically 2 for stereo audio)
|
||||
decoded_output = self.conv_out(hidden_features)
|
||||
|
||||
# Optional tanh activation to bound output values to [-1, 1] range
|
||||
if self.tanh_out:
|
||||
decoded_output = torch.tanh(decoded_output)
|
||||
|
||||
# Adjust shape for audio data
|
||||
if target_shape is not None:
|
||||
decoded_output = self._adjust_output_shape(decoded_output, target_shape)
|
||||
|
||||
return decoded_output
|
||||
|
||||
|
||||
class processor(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.register_buffer("std-of-means", torch.empty(128))
|
||||
self.register_buffer("mean-of-means", torch.empty(128))
|
||||
|
||||
def un_normalize(self, x):
|
||||
return (x * self.get_buffer("std-of-means").to(x)) + self.get_buffer("mean-of-means").to(x)
|
||||
|
||||
def normalize(self, x):
|
||||
return (x - self.get_buffer("mean-of-means").to(x)) / self.get_buffer("std-of-means").to(x)
|
||||
|
||||
|
||||
class CausalAudioAutoencoder(nn.Module):
|
||||
def __init__(self, config=None):
|
||||
super().__init__()
|
||||
|
||||
if config is None:
|
||||
config = self._guess_config()
|
||||
|
||||
# Extract encoder and decoder configs from the new format
|
||||
model_config = config.get("model", {}).get("params", {})
|
||||
variables_config = config.get("variables", {})
|
||||
|
||||
self.sampling_rate = variables_config.get(
|
||||
"sampling_rate",
|
||||
model_config.get("sampling_rate", config.get("sampling_rate", 16000)),
|
||||
)
|
||||
encoder_config = model_config.get("encoder", model_config.get("ddconfig", {}))
|
||||
decoder_config = model_config.get("decoder", encoder_config)
|
||||
|
||||
# Load mel spectrogram parameters
|
||||
self.mel_bins = encoder_config.get("mel_bins", 64)
|
||||
self.mel_hop_length = model_config.get("preprocessing", {}).get("stft", {}).get("hop_length", 160)
|
||||
self.n_fft = model_config.get("preprocessing", {}).get("stft", {}).get("filter_length", 1024)
|
||||
|
||||
# Store causality configuration at VAE level (not just in encoder internals)
|
||||
causality_axis_value = encoder_config.get("causality_axis", CausalityAxis.WIDTH.value)
|
||||
self.causality_axis = CausalityAxis.str_to_enum(causality_axis_value)
|
||||
self.is_causal = self.causality_axis == CausalityAxis.HEIGHT
|
||||
|
||||
self.encoder = Encoder(**encoder_config)
|
||||
self.decoder = Decoder(**decoder_config)
|
||||
|
||||
self.per_channel_statistics = processor()
|
||||
|
||||
def _guess_config(self):
|
||||
encoder_config = {
|
||||
# Required parameters - based on ltx-video-av-1679000 model metadata
|
||||
"ch": 128,
|
||||
"out_ch": 8,
|
||||
"ch_mult": [1, 2, 4], # Based on metadata: [1, 2, 4] not [1, 2, 4, 8]
|
||||
"num_res_blocks": 2,
|
||||
"attn_resolutions": [], # Based on metadata: empty list, no attention
|
||||
"dropout": 0.0,
|
||||
"resamp_with_conv": True,
|
||||
"in_channels": 2, # stereo
|
||||
"resolution": 256,
|
||||
"z_channels": 8,
|
||||
"double_z": True,
|
||||
"attn_type": "vanilla",
|
||||
"mid_block_add_attention": False, # Based on metadata: false
|
||||
"norm_type": "pixel",
|
||||
"causality_axis": "height", # Based on metadata
|
||||
"mel_bins": 64, # Based on metadata: mel_bins = 64
|
||||
}
|
||||
|
||||
decoder_config = {
|
||||
# Inherits encoder config, can override specific params
|
||||
**encoder_config,
|
||||
"out_ch": 2, # Stereo audio output (2 channels)
|
||||
"give_pre_end": False,
|
||||
"tanh_out": False,
|
||||
}
|
||||
|
||||
config = {
|
||||
"_class_name": "CausalAudioAutoencoder",
|
||||
"sampling_rate": 16000,
|
||||
"model": {
|
||||
"params": {
|
||||
"encoder": encoder_config,
|
||||
"decoder": decoder_config,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
return config
|
||||
|
||||
def get_config(self):
|
||||
return {
|
||||
"sampling_rate": self.sampling_rate,
|
||||
"mel_bins": self.mel_bins,
|
||||
"mel_hop_length": self.mel_hop_length,
|
||||
"n_fft": self.n_fft,
|
||||
"causality_axis": self.causality_axis.value,
|
||||
"is_causal": self.is_causal,
|
||||
}
|
||||
|
||||
def encode(self, x):
|
||||
return self.encoder(x)
|
||||
|
||||
def decode(self, x, target_shape=None):
|
||||
return self.decoder(x, target_shape=target_shape)
|
||||
@@ -1,213 +0,0 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.nn as nn
|
||||
import comfy.ops
|
||||
import numpy as np
|
||||
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
LRELU_SLOPE = 0.1
|
||||
|
||||
def get_padding(kernel_size, dilation=1):
|
||||
return int((kernel_size * dilation - dilation) / 2)
|
||||
|
||||
|
||||
class ResBlock1(torch.nn.Module):
|
||||
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
|
||||
super(ResBlock1, self).__init__()
|
||||
self.convs1 = nn.ModuleList(
|
||||
[
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[0],
|
||||
padding=get_padding(kernel_size, dilation[0]),
|
||||
),
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[1],
|
||||
padding=get_padding(kernel_size, dilation[1]),
|
||||
),
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[2],
|
||||
padding=get_padding(kernel_size, dilation[2]),
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
self.convs2 = nn.ModuleList(
|
||||
[
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1),
|
||||
),
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1),
|
||||
),
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1),
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
for c1, c2 in zip(self.convs1, self.convs2):
|
||||
xt = F.leaky_relu(x, LRELU_SLOPE)
|
||||
xt = c1(xt)
|
||||
xt = F.leaky_relu(xt, LRELU_SLOPE)
|
||||
xt = c2(xt)
|
||||
x = xt + x
|
||||
return x
|
||||
|
||||
|
||||
class ResBlock2(torch.nn.Module):
|
||||
def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
|
||||
super(ResBlock2, self).__init__()
|
||||
self.convs = nn.ModuleList(
|
||||
[
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[0],
|
||||
padding=get_padding(kernel_size, dilation[0]),
|
||||
),
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[1],
|
||||
padding=get_padding(kernel_size, dilation[1]),
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
for c in self.convs:
|
||||
xt = F.leaky_relu(x, LRELU_SLOPE)
|
||||
xt = c(xt)
|
||||
x = xt + x
|
||||
return x
|
||||
|
||||
|
||||
class Vocoder(torch.nn.Module):
|
||||
"""
|
||||
Vocoder model for synthesizing audio from spectrograms, based on: https://github.com/jik876/hifi-gan.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, config=None):
|
||||
super(Vocoder, self).__init__()
|
||||
|
||||
if config is None:
|
||||
config = self.get_default_config()
|
||||
|
||||
resblock_kernel_sizes = config.get("resblock_kernel_sizes", [3, 7, 11])
|
||||
upsample_rates = config.get("upsample_rates", [6, 5, 2, 2, 2])
|
||||
upsample_kernel_sizes = config.get("upsample_kernel_sizes", [16, 15, 8, 4, 4])
|
||||
resblock_dilation_sizes = config.get("resblock_dilation_sizes", [[1, 3, 5], [1, 3, 5], [1, 3, 5]])
|
||||
upsample_initial_channel = config.get("upsample_initial_channel", 1024)
|
||||
stereo = config.get("stereo", True)
|
||||
resblock = config.get("resblock", "1")
|
||||
|
||||
self.output_sample_rate = config.get("output_sample_rate")
|
||||
self.num_kernels = len(resblock_kernel_sizes)
|
||||
self.num_upsamples = len(upsample_rates)
|
||||
in_channels = 128 if stereo else 64
|
||||
self.conv_pre = ops.Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)
|
||||
resblock_class = ResBlock1 if resblock == "1" else ResBlock2
|
||||
|
||||
self.ups = nn.ModuleList()
|
||||
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
||||
self.ups.append(
|
||||
ops.ConvTranspose1d(
|
||||
upsample_initial_channel // (2**i),
|
||||
upsample_initial_channel // (2 ** (i + 1)),
|
||||
k,
|
||||
u,
|
||||
padding=(k - u) // 2,
|
||||
)
|
||||
)
|
||||
|
||||
self.resblocks = nn.ModuleList()
|
||||
for i in range(len(self.ups)):
|
||||
ch = upsample_initial_channel // (2 ** (i + 1))
|
||||
for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
|
||||
self.resblocks.append(resblock_class(ch, k, d))
|
||||
|
||||
out_channels = 2 if stereo else 1
|
||||
self.conv_post = ops.Conv1d(ch, out_channels, 7, 1, padding=3)
|
||||
|
||||
self.upsample_factor = np.prod([self.ups[i].stride[0] for i in range(len(self.ups))])
|
||||
|
||||
def get_default_config(self):
|
||||
"""Generate default configuration for the vocoder."""
|
||||
|
||||
config = {
|
||||
"resblock_kernel_sizes": [3, 7, 11],
|
||||
"upsample_rates": [6, 5, 2, 2, 2],
|
||||
"upsample_kernel_sizes": [16, 15, 8, 4, 4],
|
||||
"resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
"upsample_initial_channel": 1024,
|
||||
"stereo": True,
|
||||
"resblock": "1",
|
||||
}
|
||||
|
||||
return config
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Forward pass of the vocoder.
|
||||
|
||||
Args:
|
||||
x: Input spectrogram tensor. Can be:
|
||||
- 3D: (batch_size, channels, time_steps) for mono
|
||||
- 4D: (batch_size, 2, channels, time_steps) for stereo
|
||||
|
||||
Returns:
|
||||
Audio tensor of shape (batch_size, out_channels, audio_length)
|
||||
"""
|
||||
if x.dim() == 4: # stereo
|
||||
assert x.shape[1] == 2, "Input must have 2 channels for stereo"
|
||||
x = torch.cat((x[:, 0, :, :], x[:, 1, :, :]), dim=1)
|
||||
x = self.conv_pre(x)
|
||||
for i in range(self.num_upsamples):
|
||||
x = F.leaky_relu(x, LRELU_SLOPE)
|
||||
x = self.ups[i](x)
|
||||
xs = None
|
||||
for j in range(self.num_kernels):
|
||||
if xs is None:
|
||||
xs = self.resblocks[i * self.num_kernels + j](x)
|
||||
else:
|
||||
xs += self.resblocks[i * self.num_kernels + j](x)
|
||||
x = xs / self.num_kernels
|
||||
x = F.leaky_relu(x)
|
||||
x = self.conv_post(x)
|
||||
x = torch.tanh(x)
|
||||
|
||||
return x
|
||||
@@ -1,160 +0,0 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from .model import JointTransformerBlock
|
||||
|
||||
class ZImageControlTransformerBlock(JointTransformerBlock):
|
||||
def __init__(
|
||||
self,
|
||||
layer_id: int,
|
||||
dim: int,
|
||||
n_heads: int,
|
||||
n_kv_heads: int,
|
||||
multiple_of: int,
|
||||
ffn_dim_multiplier: float,
|
||||
norm_eps: float,
|
||||
qk_norm: bool,
|
||||
modulation=True,
|
||||
block_id=0,
|
||||
operation_settings=None,
|
||||
):
|
||||
super().__init__(layer_id, dim, n_heads, n_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, qk_norm, modulation, z_image_modulation=True, operation_settings=operation_settings)
|
||||
self.block_id = block_id
|
||||
if block_id == 0:
|
||||
self.before_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.after_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
|
||||
def forward(self, c, x, **kwargs):
|
||||
if self.block_id == 0:
|
||||
c = self.before_proj(c) + x
|
||||
c = super().forward(c, **kwargs)
|
||||
c_skip = self.after_proj(c)
|
||||
return c_skip, c
|
||||
|
||||
class ZImage_Control(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int = 3840,
|
||||
n_heads: int = 30,
|
||||
n_kv_heads: int = 30,
|
||||
multiple_of: int = 256,
|
||||
ffn_dim_multiplier: float = (8.0 / 3.0),
|
||||
norm_eps: float = 1e-5,
|
||||
qk_norm: bool = True,
|
||||
n_control_layers=6,
|
||||
control_in_dim=16,
|
||||
additional_in_dim=0,
|
||||
broken=False,
|
||||
refiner_control=False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
|
||||
|
||||
self.broken = broken
|
||||
self.additional_in_dim = additional_in_dim
|
||||
self.control_in_dim = control_in_dim
|
||||
n_refiner_layers = 2
|
||||
self.n_control_layers = n_control_layers
|
||||
self.control_layers = nn.ModuleList(
|
||||
[
|
||||
ZImageControlTransformerBlock(
|
||||
i,
|
||||
dim,
|
||||
n_heads,
|
||||
n_kv_heads,
|
||||
multiple_of,
|
||||
ffn_dim_multiplier,
|
||||
norm_eps,
|
||||
qk_norm,
|
||||
block_id=i,
|
||||
operation_settings=operation_settings,
|
||||
)
|
||||
for i in range(self.n_control_layers)
|
||||
]
|
||||
)
|
||||
|
||||
all_x_embedder = {}
|
||||
patch_size = 2
|
||||
f_patch_size = 1
|
||||
x_embedder = operations.Linear(f_patch_size * patch_size * patch_size * (self.control_in_dim + self.additional_in_dim), dim, bias=True, device=device, dtype=dtype)
|
||||
all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder
|
||||
|
||||
self.refiner_control = refiner_control
|
||||
|
||||
self.control_all_x_embedder = nn.ModuleDict(all_x_embedder)
|
||||
if self.refiner_control:
|
||||
self.control_noise_refiner = nn.ModuleList(
|
||||
[
|
||||
ZImageControlTransformerBlock(
|
||||
layer_id,
|
||||
dim,
|
||||
n_heads,
|
||||
n_kv_heads,
|
||||
multiple_of,
|
||||
ffn_dim_multiplier,
|
||||
norm_eps,
|
||||
qk_norm,
|
||||
block_id=layer_id,
|
||||
operation_settings=operation_settings,
|
||||
)
|
||||
for layer_id in range(n_refiner_layers)
|
||||
]
|
||||
)
|
||||
else:
|
||||
self.control_noise_refiner = nn.ModuleList(
|
||||
[
|
||||
JointTransformerBlock(
|
||||
layer_id,
|
||||
dim,
|
||||
n_heads,
|
||||
n_kv_heads,
|
||||
multiple_of,
|
||||
ffn_dim_multiplier,
|
||||
norm_eps,
|
||||
qk_norm,
|
||||
modulation=True,
|
||||
z_image_modulation=True,
|
||||
operation_settings=operation_settings,
|
||||
)
|
||||
for layer_id in range(n_refiner_layers)
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, cap_feats, control_context, x_freqs_cis, adaln_input):
|
||||
patch_size = 2
|
||||
f_patch_size = 1
|
||||
pH = pW = patch_size
|
||||
B, C, H, W = control_context.shape
|
||||
control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2))
|
||||
|
||||
x_attn_mask = None
|
||||
if not self.refiner_control:
|
||||
for layer in self.control_noise_refiner:
|
||||
control_context = layer(control_context, x_attn_mask, x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input)
|
||||
|
||||
return control_context
|
||||
|
||||
def forward_noise_refiner_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input):
|
||||
if self.refiner_control:
|
||||
if self.broken:
|
||||
if layer_id == 0:
|
||||
return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
|
||||
if layer_id > 0:
|
||||
out = None
|
||||
for i in range(1, len(self.control_layers)):
|
||||
o, control_context = self.control_layers[i](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
|
||||
if out is None:
|
||||
out = o
|
||||
|
||||
return (out, control_context)
|
||||
else:
|
||||
return self.control_noise_refiner[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
|
||||
else:
|
||||
return (None, control_context)
|
||||
|
||||
def forward_control_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input):
|
||||
return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
|
||||
@@ -11,7 +11,6 @@ import comfy.ldm.common_dit
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder
|
||||
from comfy.ldm.modules.attention import optimized_attention_masked
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
from comfy.ldm.flux.math import apply_rope
|
||||
import comfy.patcher_extension
|
||||
|
||||
|
||||
@@ -22,10 +21,6 @@ def modulate(x, scale):
|
||||
# Core NextDiT Model #
|
||||
#############################################################################
|
||||
|
||||
def clamp_fp16(x):
|
||||
if x.dtype == torch.float16:
|
||||
return torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
|
||||
return x
|
||||
|
||||
class JointAttention(nn.Module):
|
||||
"""Multi-head attention module."""
|
||||
@@ -36,7 +31,6 @@ class JointAttention(nn.Module):
|
||||
n_heads: int,
|
||||
n_kv_heads: Optional[int],
|
||||
qk_norm: bool,
|
||||
out_bias: bool = False,
|
||||
operation_settings={},
|
||||
):
|
||||
"""
|
||||
@@ -65,7 +59,7 @@ class JointAttention(nn.Module):
|
||||
self.out = operation_settings.get("operations").Linear(
|
||||
n_heads * self.head_dim,
|
||||
dim,
|
||||
bias=out_bias,
|
||||
bias=False,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
)
|
||||
@@ -76,6 +70,35 @@ class JointAttention(nn.Module):
|
||||
else:
|
||||
self.q_norm = self.k_norm = nn.Identity()
|
||||
|
||||
@staticmethod
|
||||
def apply_rotary_emb(
|
||||
x_in: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Apply rotary embeddings to input tensors using the given frequency
|
||||
tensor.
|
||||
|
||||
This function applies rotary embeddings to the given query 'xq' and
|
||||
key 'xk' tensors using the provided frequency tensor 'freqs_cis'. The
|
||||
input tensors are reshaped as complex numbers, and the frequency tensor
|
||||
is reshaped for broadcasting compatibility. The resulting tensors
|
||||
contain rotary embeddings and are returned as real tensors.
|
||||
|
||||
Args:
|
||||
x_in (torch.Tensor): Query or Key tensor to apply rotary embeddings.
|
||||
freqs_cis (torch.Tensor): Precomputed frequency tensor for complex
|
||||
exponentials.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor
|
||||
and key tensor with rotary embeddings.
|
||||
"""
|
||||
|
||||
t_ = x_in.reshape(*x_in.shape[:-1], -1, 1, 2)
|
||||
t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
|
||||
return t_out.reshape(*x_in.shape)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
@@ -111,7 +134,8 @@ class JointAttention(nn.Module):
|
||||
xq = self.q_norm(xq)
|
||||
xk = self.k_norm(xk)
|
||||
|
||||
xq, xk = apply_rope(xq, xk, freqs_cis)
|
||||
xq = JointAttention.apply_rotary_emb(xq, freqs_cis=freqs_cis)
|
||||
xk = JointAttention.apply_rotary_emb(xk, freqs_cis=freqs_cis)
|
||||
|
||||
n_rep = self.n_local_heads // self.n_local_kv_heads
|
||||
if n_rep >= 1:
|
||||
@@ -173,7 +197,7 @@ class FeedForward(nn.Module):
|
||||
|
||||
# @torch.compile
|
||||
def _forward_silu_gating(self, x1, x3):
|
||||
return clamp_fp16(F.silu(x1) * x3)
|
||||
return F.silu(x1) * x3
|
||||
|
||||
def forward(self, x):
|
||||
return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
|
||||
@@ -191,8 +215,6 @@ class JointTransformerBlock(nn.Module):
|
||||
norm_eps: float,
|
||||
qk_norm: bool,
|
||||
modulation=True,
|
||||
z_image_modulation=False,
|
||||
attn_out_bias=False,
|
||||
operation_settings={},
|
||||
) -> None:
|
||||
"""
|
||||
@@ -213,10 +235,10 @@ class JointTransformerBlock(nn.Module):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.head_dim = dim // n_heads
|
||||
self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, out_bias=attn_out_bias, operation_settings=operation_settings)
|
||||
self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, operation_settings=operation_settings)
|
||||
self.feed_forward = FeedForward(
|
||||
dim=dim,
|
||||
hidden_dim=dim,
|
||||
hidden_dim=4 * dim,
|
||||
multiple_of=multiple_of,
|
||||
ffn_dim_multiplier=ffn_dim_multiplier,
|
||||
operation_settings=operation_settings,
|
||||
@@ -230,27 +252,16 @@ class JointTransformerBlock(nn.Module):
|
||||
|
||||
self.modulation = modulation
|
||||
if modulation:
|
||||
if z_image_modulation:
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
operation_settings.get("operations").Linear(
|
||||
min(dim, 256),
|
||||
4 * dim,
|
||||
bias=True,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operation_settings.get("operations").Linear(
|
||||
min(dim, 1024),
|
||||
4 * dim,
|
||||
bias=True,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
),
|
||||
)
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operation_settings.get("operations").Linear(
|
||||
min(dim, 1024),
|
||||
4 * dim,
|
||||
bias=True,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
),
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -277,27 +288,27 @@ class JointTransformerBlock(nn.Module):
|
||||
scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)
|
||||
|
||||
x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2(
|
||||
clamp_fp16(self.attention(
|
||||
self.attention(
|
||||
modulate(self.attention_norm1(x), scale_msa),
|
||||
x_mask,
|
||||
freqs_cis,
|
||||
transformer_options=transformer_options,
|
||||
))
|
||||
)
|
||||
)
|
||||
x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(
|
||||
clamp_fp16(self.feed_forward(
|
||||
self.feed_forward(
|
||||
modulate(self.ffn_norm1(x), scale_mlp),
|
||||
))
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert adaln_input is None
|
||||
x = x + self.attention_norm2(
|
||||
clamp_fp16(self.attention(
|
||||
self.attention(
|
||||
self.attention_norm1(x),
|
||||
x_mask,
|
||||
freqs_cis,
|
||||
transformer_options=transformer_options,
|
||||
))
|
||||
)
|
||||
)
|
||||
x = x + self.ffn_norm2(
|
||||
self.feed_forward(
|
||||
@@ -312,7 +323,7 @@ class FinalLayer(nn.Module):
|
||||
The final layer of NextDiT.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size, patch_size, out_channels, z_image_modulation=False, operation_settings={}):
|
||||
def __init__(self, hidden_size, patch_size, out_channels, operation_settings={}):
|
||||
super().__init__()
|
||||
self.norm_final = operation_settings.get("operations").LayerNorm(
|
||||
hidden_size,
|
||||
@@ -329,15 +340,10 @@ class FinalLayer(nn.Module):
|
||||
dtype=operation_settings.get("dtype"),
|
||||
)
|
||||
|
||||
if z_image_modulation:
|
||||
min_mod = 256
|
||||
else:
|
||||
min_mod = 1024
|
||||
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operation_settings.get("operations").Linear(
|
||||
min(hidden_size, min_mod),
|
||||
min(hidden_size, 1024),
|
||||
hidden_size,
|
||||
bias=True,
|
||||
device=operation_settings.get("device"),
|
||||
@@ -367,17 +373,12 @@ class NextDiT(nn.Module):
|
||||
n_heads: int = 32,
|
||||
n_kv_heads: Optional[int] = None,
|
||||
multiple_of: int = 256,
|
||||
ffn_dim_multiplier: float = 4.0,
|
||||
ffn_dim_multiplier: Optional[float] = None,
|
||||
norm_eps: float = 1e-5,
|
||||
qk_norm: bool = False,
|
||||
cap_feat_dim: int = 5120,
|
||||
axes_dims: List[int] = (16, 56, 56),
|
||||
axes_lens: List[int] = (1, 512, 512),
|
||||
rope_theta=10000.0,
|
||||
z_image_modulation=False,
|
||||
time_scale=1.0,
|
||||
pad_tokens_multiple=None,
|
||||
clip_text_dim=None,
|
||||
image_model=None,
|
||||
device=None,
|
||||
dtype=None,
|
||||
@@ -389,8 +390,6 @@ class NextDiT(nn.Module):
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels
|
||||
self.patch_size = patch_size
|
||||
self.time_scale = time_scale
|
||||
self.pad_tokens_multiple = pad_tokens_multiple
|
||||
|
||||
self.x_embedder = operation_settings.get("operations").Linear(
|
||||
in_features=patch_size * patch_size * in_channels,
|
||||
@@ -412,7 +411,6 @@ class NextDiT(nn.Module):
|
||||
norm_eps,
|
||||
qk_norm,
|
||||
modulation=True,
|
||||
z_image_modulation=z_image_modulation,
|
||||
operation_settings=operation_settings,
|
||||
)
|
||||
for layer_id in range(n_refiner_layers)
|
||||
@@ -436,7 +434,7 @@ class NextDiT(nn.Module):
|
||||
]
|
||||
)
|
||||
|
||||
self.t_embedder = TimestepEmbedder(min(dim, 1024), output_size=256 if z_image_modulation else None, **operation_settings)
|
||||
self.t_embedder = TimestepEmbedder(min(dim, 1024), **operation_settings)
|
||||
self.cap_embedder = nn.Sequential(
|
||||
operation_settings.get("operations").RMSNorm(cap_feat_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
|
||||
operation_settings.get("operations").Linear(
|
||||
@@ -448,31 +446,6 @@ class NextDiT(nn.Module):
|
||||
),
|
||||
)
|
||||
|
||||
self.clip_text_pooled_proj = None
|
||||
|
||||
if clip_text_dim is not None:
|
||||
self.clip_text_dim = clip_text_dim
|
||||
self.clip_text_pooled_proj = nn.Sequential(
|
||||
operation_settings.get("operations").RMSNorm(clip_text_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
|
||||
operation_settings.get("operations").Linear(
|
||||
clip_text_dim,
|
||||
clip_text_dim,
|
||||
bias=True,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
),
|
||||
)
|
||||
self.time_text_embed = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operation_settings.get("operations").Linear(
|
||||
min(dim, 1024) + clip_text_dim,
|
||||
min(dim, 1024),
|
||||
bias=True,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
),
|
||||
)
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
JointTransformerBlock(
|
||||
@@ -484,25 +457,18 @@ class NextDiT(nn.Module):
|
||||
ffn_dim_multiplier,
|
||||
norm_eps,
|
||||
qk_norm,
|
||||
z_image_modulation=z_image_modulation,
|
||||
attn_out_bias=False,
|
||||
operation_settings=operation_settings,
|
||||
)
|
||||
for layer_id in range(n_layers)
|
||||
]
|
||||
)
|
||||
# This norm final is in the lumina 2.0 code but isn't actually used for anything.
|
||||
# self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.final_layer = FinalLayer(dim, patch_size, self.out_channels, z_image_modulation=z_image_modulation, operation_settings=operation_settings)
|
||||
|
||||
if self.pad_tokens_multiple is not None:
|
||||
self.x_pad_token = nn.Parameter(torch.empty((1, dim), device=device, dtype=dtype))
|
||||
self.cap_pad_token = nn.Parameter(torch.empty((1, dim), device=device, dtype=dtype))
|
||||
self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.final_layer = FinalLayer(dim, patch_size, self.out_channels, operation_settings=operation_settings)
|
||||
|
||||
assert (dim // n_heads) == sum(axes_dims)
|
||||
self.axes_dims = axes_dims
|
||||
self.axes_lens = axes_lens
|
||||
self.rope_embedder = EmbedND(dim=dim // n_heads, theta=rope_theta, axes_dim=axes_dims)
|
||||
self.rope_embedder = EmbedND(dim=dim // n_heads, theta=10000.0, axes_dim=axes_dims)
|
||||
self.dim = dim
|
||||
self.n_heads = n_heads
|
||||
|
||||
@@ -537,63 +503,96 @@ class NextDiT(nn.Module):
|
||||
bsz = len(x)
|
||||
pH = pW = self.patch_size
|
||||
device = x[0].device
|
||||
orig_x = x
|
||||
dtype = x[0].dtype
|
||||
|
||||
if self.pad_tokens_multiple is not None:
|
||||
pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple
|
||||
cap_feats = torch.cat((cap_feats, self.cap_pad_token.to(device=cap_feats.device, dtype=cap_feats.dtype, copy=True).unsqueeze(0).repeat(cap_feats.shape[0], pad_extra, 1)), dim=1)
|
||||
if cap_mask is not None:
|
||||
l_effective_cap_len = cap_mask.sum(dim=1).tolist()
|
||||
else:
|
||||
l_effective_cap_len = [num_tokens] * bsz
|
||||
|
||||
cap_pos_ids = torch.zeros(bsz, cap_feats.shape[1], 3, dtype=torch.float32, device=device)
|
||||
cap_pos_ids[:, :, 0] = torch.arange(cap_feats.shape[1], dtype=torch.float32, device=device) + 1.0
|
||||
if cap_mask is not None and not torch.is_floating_point(cap_mask):
|
||||
cap_mask = (cap_mask - 1).to(dtype) * torch.finfo(dtype).max
|
||||
|
||||
B, C, H, W = x.shape
|
||||
x = self.x_embedder(x.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2))
|
||||
img_sizes = [(img.size(1), img.size(2)) for img in x]
|
||||
l_effective_img_len = [(H // pH) * (W // pW) for (H, W) in img_sizes]
|
||||
|
||||
rope_options = transformer_options.get("rope_options", None)
|
||||
h_scale = 1.0
|
||||
w_scale = 1.0
|
||||
h_start = 0
|
||||
w_start = 0
|
||||
if rope_options is not None:
|
||||
h_scale = rope_options.get("scale_y", 1.0)
|
||||
w_scale = rope_options.get("scale_x", 1.0)
|
||||
max_seq_len = max(
|
||||
(cap_len+img_len for cap_len, img_len in zip(l_effective_cap_len, l_effective_img_len))
|
||||
)
|
||||
max_cap_len = max(l_effective_cap_len)
|
||||
max_img_len = max(l_effective_img_len)
|
||||
|
||||
h_start = rope_options.get("shift_y", 0.0)
|
||||
w_start = rope_options.get("shift_x", 0.0)
|
||||
position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.int32, device=device)
|
||||
|
||||
H_tokens, W_tokens = H // pH, W // pW
|
||||
x_pos_ids = torch.zeros((bsz, x.shape[1], 3), dtype=torch.float32, device=device)
|
||||
x_pos_ids[:, :, 0] = cap_feats.shape[1] + 1
|
||||
x_pos_ids[:, :, 1] = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten()
|
||||
x_pos_ids[:, :, 2] = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten()
|
||||
for i in range(bsz):
|
||||
cap_len = l_effective_cap_len[i]
|
||||
img_len = l_effective_img_len[i]
|
||||
H, W = img_sizes[i]
|
||||
H_tokens, W_tokens = H // pH, W // pW
|
||||
assert H_tokens * W_tokens == img_len
|
||||
|
||||
if self.pad_tokens_multiple is not None:
|
||||
pad_extra = (-x.shape[1]) % self.pad_tokens_multiple
|
||||
x = torch.cat((x, self.x_pad_token.to(device=x.device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(x.shape[0], pad_extra, 1)), dim=1)
|
||||
x_pos_ids = torch.nn.functional.pad(x_pos_ids, (0, 0, 0, pad_extra))
|
||||
position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device)
|
||||
position_ids[i, cap_len:cap_len+img_len, 0] = cap_len
|
||||
row_ids = torch.arange(H_tokens, dtype=torch.int32, device=device).view(-1, 1).repeat(1, W_tokens).flatten()
|
||||
col_ids = torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten()
|
||||
position_ids[i, cap_len:cap_len+img_len, 1] = row_ids
|
||||
position_ids[i, cap_len:cap_len+img_len, 2] = col_ids
|
||||
|
||||
freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2)
|
||||
freqs_cis = self.rope_embedder(position_ids).movedim(1, 2).to(dtype)
|
||||
|
||||
patches = transformer_options.get("patches", {})
|
||||
# build freqs_cis for cap and image individually
|
||||
cap_freqs_cis_shape = list(freqs_cis.shape)
|
||||
# cap_freqs_cis_shape[1] = max_cap_len
|
||||
cap_freqs_cis_shape[1] = cap_feats.shape[1]
|
||||
cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
|
||||
|
||||
img_freqs_cis_shape = list(freqs_cis.shape)
|
||||
img_freqs_cis_shape[1] = max_img_len
|
||||
img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
|
||||
|
||||
for i in range(bsz):
|
||||
cap_len = l_effective_cap_len[i]
|
||||
img_len = l_effective_img_len[i]
|
||||
cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len]
|
||||
img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len:cap_len+img_len]
|
||||
|
||||
# refine context
|
||||
for layer in self.context_refiner:
|
||||
cap_feats = layer(cap_feats, cap_mask, freqs_cis[:, :cap_pos_ids.shape[1]], transformer_options=transformer_options)
|
||||
cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis, transformer_options=transformer_options)
|
||||
|
||||
padded_img_mask = None
|
||||
x_input = x
|
||||
for i, layer in enumerate(self.noise_refiner):
|
||||
x = layer(x, padded_img_mask, freqs_cis[:, cap_pos_ids.shape[1]:], t, transformer_options=transformer_options)
|
||||
if "noise_refiner" in patches:
|
||||
for p in patches["noise_refiner"]:
|
||||
out = p({"img": x, "img_input": x_input, "txt": cap_feats, "pe": freqs_cis[:, cap_pos_ids.shape[1]:], "vec": t, "x": orig_x, "block_index": i, "transformer_options": transformer_options, "block_type": "noise_refiner"})
|
||||
if "img" in out:
|
||||
x = out["img"]
|
||||
# refine image
|
||||
flat_x = []
|
||||
for i in range(bsz):
|
||||
img = x[i]
|
||||
C, H, W = img.size()
|
||||
img = img.view(C, H // pH, pH, W // pW, pW).permute(1, 3, 2, 4, 0).flatten(2).flatten(0, 1)
|
||||
flat_x.append(img)
|
||||
x = flat_x
|
||||
padded_img_embed = torch.zeros(bsz, max_img_len, x[0].shape[-1], device=device, dtype=x[0].dtype)
|
||||
padded_img_mask = torch.zeros(bsz, max_img_len, dtype=dtype, device=device)
|
||||
for i in range(bsz):
|
||||
padded_img_embed[i, :l_effective_img_len[i]] = x[i]
|
||||
padded_img_mask[i, l_effective_img_len[i]:] = -torch.finfo(dtype).max
|
||||
|
||||
padded_img_embed = self.x_embedder(padded_img_embed)
|
||||
padded_img_mask = padded_img_mask.unsqueeze(1)
|
||||
for layer in self.noise_refiner:
|
||||
padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t, transformer_options=transformer_options)
|
||||
|
||||
if cap_mask is not None:
|
||||
mask = torch.zeros(bsz, max_seq_len, dtype=dtype, device=device)
|
||||
mask[:, :max_cap_len] = cap_mask[:, :max_cap_len]
|
||||
else:
|
||||
mask = None
|
||||
|
||||
padded_full_embed = torch.zeros(bsz, max_seq_len, self.dim, device=device, dtype=x[0].dtype)
|
||||
for i in range(bsz):
|
||||
cap_len = l_effective_cap_len[i]
|
||||
img_len = l_effective_img_len[i]
|
||||
|
||||
padded_full_embed[i, :cap_len] = cap_feats[i, :cap_len]
|
||||
padded_full_embed[i, cap_len:cap_len+img_len] = padded_img_embed[i, :img_len]
|
||||
|
||||
padded_full_embed = torch.cat((cap_feats, x), dim=1)
|
||||
mask = None
|
||||
img_sizes = [(H, W)] * bsz
|
||||
l_effective_cap_len = [cap_feats.shape[1]] * bsz
|
||||
return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis
|
||||
|
||||
def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
|
||||
@@ -604,7 +603,7 @@ class NextDiT(nn.Module):
|
||||
).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs)
|
||||
|
||||
# def forward(self, x, t, cap_feats, cap_mask):
|
||||
def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, transformer_options={}, **kwargs):
|
||||
def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
|
||||
t = 1.0 - timesteps
|
||||
cap_feats = context
|
||||
cap_mask = attention_mask
|
||||
@@ -616,41 +615,21 @@ class NextDiT(nn.Module):
|
||||
y: (N,) tensor of text tokens/features
|
||||
"""
|
||||
|
||||
t = self.t_embedder(t * self.time_scale, dtype=x.dtype) # (N, D)
|
||||
t = self.t_embedder(t, dtype=x.dtype) # (N, D)
|
||||
adaln_input = t
|
||||
|
||||
cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute
|
||||
|
||||
if self.clip_text_pooled_proj is not None:
|
||||
pooled = kwargs.get("clip_text_pooled", None)
|
||||
if pooled is not None:
|
||||
pooled = self.clip_text_pooled_proj(pooled)
|
||||
else:
|
||||
pooled = torch.zeros((x.shape[0], self.clip_text_dim), device=x.device, dtype=x.dtype)
|
||||
|
||||
adaln_input = self.time_text_embed(torch.cat((t, pooled), dim=-1))
|
||||
|
||||
patches = transformer_options.get("patches", {})
|
||||
transformer_options = kwargs.get("transformer_options", {})
|
||||
x_is_tensor = isinstance(x, torch.Tensor)
|
||||
img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, adaln_input, num_tokens, transformer_options=transformer_options)
|
||||
freqs_cis = freqs_cis.to(img.device)
|
||||
x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options)
|
||||
freqs_cis = freqs_cis.to(x.device)
|
||||
|
||||
transformer_options["total_blocks"] = len(self.layers)
|
||||
transformer_options["block_type"] = "double"
|
||||
img_input = img
|
||||
for i, layer in enumerate(self.layers):
|
||||
transformer_options["block_index"] = i
|
||||
img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
|
||||
if "double_block" in patches:
|
||||
for p in patches["double_block"]:
|
||||
out = p({"img": img[:, cap_size[0]:], "img_input": img_input[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options})
|
||||
if "img" in out:
|
||||
img[:, cap_size[0]:] = out["img"]
|
||||
if "txt" in out:
|
||||
img[:, :cap_size[0]] = out["txt"]
|
||||
for layer in self.layers:
|
||||
x = layer(x, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
|
||||
|
||||
img = self.final_layer(img, adaln_input)
|
||||
img = self.unpatchify(img, img_size, cap_size, return_tensor=x_is_tensor)[:, :, :h, :w]
|
||||
x = self.final_layer(x, adaln_input)
|
||||
x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor)[:,:,:h,:w]
|
||||
|
||||
return -img
|
||||
return -x
|
||||
|
||||
|
||||
@@ -1,120 +0,0 @@
|
||||
# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
|
||||
# LICENSE is in incl_licenses directory.
|
||||
|
||||
import torch
|
||||
from torch import nn, sin, pow
|
||||
from torch.nn import Parameter
|
||||
import comfy.model_management
|
||||
|
||||
class Snake(nn.Module):
|
||||
'''
|
||||
Implementation of a sine-based periodic activation function
|
||||
Shape:
|
||||
- Input: (B, C, T)
|
||||
- Output: (B, C, T), same shape as the input
|
||||
Parameters:
|
||||
- alpha - trainable parameter
|
||||
References:
|
||||
- This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
|
||||
https://arxiv.org/abs/2006.08195
|
||||
Examples:
|
||||
>>> a1 = snake(256)
|
||||
>>> x = torch.randn(256)
|
||||
>>> x = a1(x)
|
||||
'''
|
||||
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
|
||||
'''
|
||||
Initialization.
|
||||
INPUT:
|
||||
- in_features: shape of the input
|
||||
- alpha: trainable parameter
|
||||
alpha is initialized to 1 by default, higher values = higher-frequency.
|
||||
alpha will be trained along with the rest of your model.
|
||||
'''
|
||||
super(Snake, self).__init__()
|
||||
self.in_features = in_features
|
||||
|
||||
# initialize alpha
|
||||
self.alpha_logscale = alpha_logscale
|
||||
if self.alpha_logscale:
|
||||
self.alpha = Parameter(torch.empty(in_features))
|
||||
else:
|
||||
self.alpha = Parameter(torch.empty(in_features))
|
||||
|
||||
self.alpha.requires_grad = alpha_trainable
|
||||
|
||||
self.no_div_by_zero = 0.000000001
|
||||
|
||||
def forward(self, x):
|
||||
'''
|
||||
Forward pass of the function.
|
||||
Applies the function to the input elementwise.
|
||||
Snake ∶= x + 1/a * sin^2 (xa)
|
||||
'''
|
||||
alpha = comfy.model_management.cast_to(self.alpha, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
|
||||
if self.alpha_logscale:
|
||||
alpha = torch.exp(alpha)
|
||||
x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class SnakeBeta(nn.Module):
|
||||
'''
|
||||
A modified Snake function which uses separate parameters for the magnitude of the periodic components
|
||||
Shape:
|
||||
- Input: (B, C, T)
|
||||
- Output: (B, C, T), same shape as the input
|
||||
Parameters:
|
||||
- alpha - trainable parameter that controls frequency
|
||||
- beta - trainable parameter that controls magnitude
|
||||
References:
|
||||
- This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
|
||||
https://arxiv.org/abs/2006.08195
|
||||
Examples:
|
||||
>>> a1 = snakebeta(256)
|
||||
>>> x = torch.randn(256)
|
||||
>>> x = a1(x)
|
||||
'''
|
||||
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
|
||||
'''
|
||||
Initialization.
|
||||
INPUT:
|
||||
- in_features: shape of the input
|
||||
- alpha - trainable parameter that controls frequency
|
||||
- beta - trainable parameter that controls magnitude
|
||||
alpha is initialized to 1 by default, higher values = higher-frequency.
|
||||
beta is initialized to 1 by default, higher values = higher-magnitude.
|
||||
alpha will be trained along with the rest of your model.
|
||||
'''
|
||||
super(SnakeBeta, self).__init__()
|
||||
self.in_features = in_features
|
||||
|
||||
# initialize alpha
|
||||
self.alpha_logscale = alpha_logscale
|
||||
if self.alpha_logscale:
|
||||
self.alpha = Parameter(torch.empty(in_features))
|
||||
self.beta = Parameter(torch.empty(in_features))
|
||||
else:
|
||||
self.alpha = Parameter(torch.empty(in_features))
|
||||
self.beta = Parameter(torch.empty(in_features))
|
||||
|
||||
self.alpha.requires_grad = alpha_trainable
|
||||
self.beta.requires_grad = alpha_trainable
|
||||
|
||||
self.no_div_by_zero = 0.000000001
|
||||
|
||||
def forward(self, x):
|
||||
'''
|
||||
Forward pass of the function.
|
||||
Applies the function to the input elementwise.
|
||||
SnakeBeta ∶= x + 1/b * sin^2 (xa)
|
||||
'''
|
||||
alpha = comfy.model_management.cast_to(self.alpha, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
|
||||
beta = comfy.model_management.cast_to(self.beta, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1)
|
||||
if self.alpha_logscale:
|
||||
alpha = torch.exp(alpha)
|
||||
beta = torch.exp(beta)
|
||||
x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
|
||||
|
||||
return x
|
||||
@@ -1,157 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
import comfy.model_management
|
||||
|
||||
if 'sinc' in dir(torch):
|
||||
sinc = torch.sinc
|
||||
else:
|
||||
# This code is adopted from adefossez's julius.core.sinc under the MIT License
|
||||
# https://adefossez.github.io/julius/julius/core.html
|
||||
# LICENSE is in incl_licenses directory.
|
||||
def sinc(x: torch.Tensor):
|
||||
"""
|
||||
Implementation of sinc, i.e. sin(pi * x) / (pi * x)
|
||||
__Warning__: Different to julius.sinc, the input is multiplied by `pi`!
|
||||
"""
|
||||
return torch.where(x == 0,
|
||||
torch.tensor(1., device=x.device, dtype=x.dtype),
|
||||
torch.sin(math.pi * x) / math.pi / x)
|
||||
|
||||
|
||||
# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
|
||||
# https://adefossez.github.io/julius/julius/lowpass.html
|
||||
# LICENSE is in incl_licenses directory.
|
||||
def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size]
|
||||
even = (kernel_size % 2 == 0)
|
||||
half_size = kernel_size // 2
|
||||
|
||||
#For kaiser window
|
||||
delta_f = 4 * half_width
|
||||
A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
|
||||
if A > 50.:
|
||||
beta = 0.1102 * (A - 8.7)
|
||||
elif A >= 21.:
|
||||
beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.)
|
||||
else:
|
||||
beta = 0.
|
||||
window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
|
||||
|
||||
# ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
|
||||
if even:
|
||||
time = (torch.arange(-half_size, half_size) + 0.5)
|
||||
else:
|
||||
time = torch.arange(kernel_size) - half_size
|
||||
if cutoff == 0:
|
||||
filter_ = torch.zeros_like(time)
|
||||
else:
|
||||
filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
|
||||
# Normalize filter to have sum = 1, otherwise we will have a small leakage
|
||||
# of the constant component in the input signal.
|
||||
filter_ /= filter_.sum()
|
||||
filter = filter_.view(1, 1, kernel_size)
|
||||
|
||||
return filter
|
||||
|
||||
|
||||
class LowPassFilter1d(nn.Module):
|
||||
def __init__(self,
|
||||
cutoff=0.5,
|
||||
half_width=0.6,
|
||||
stride: int = 1,
|
||||
padding: bool = True,
|
||||
padding_mode: str = 'replicate',
|
||||
kernel_size: int = 12):
|
||||
# kernel_size should be even number for stylegan3 setup,
|
||||
# in this implementation, odd number is also possible.
|
||||
super().__init__()
|
||||
if cutoff < -0.:
|
||||
raise ValueError("Minimum cutoff must be larger than zero.")
|
||||
if cutoff > 0.5:
|
||||
raise ValueError("A cutoff above 0.5 does not make sense.")
|
||||
self.kernel_size = kernel_size
|
||||
self.even = (kernel_size % 2 == 0)
|
||||
self.pad_left = kernel_size // 2 - int(self.even)
|
||||
self.pad_right = kernel_size // 2
|
||||
self.stride = stride
|
||||
self.padding = padding
|
||||
self.padding_mode = padding_mode
|
||||
filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
|
||||
self.register_buffer("filter", filter)
|
||||
|
||||
#input [B, C, T]
|
||||
def forward(self, x):
|
||||
_, C, _ = x.shape
|
||||
|
||||
if self.padding:
|
||||
x = F.pad(x, (self.pad_left, self.pad_right),
|
||||
mode=self.padding_mode)
|
||||
out = F.conv1d(x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device),
|
||||
stride=self.stride, groups=C)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class UpSample1d(nn.Module):
|
||||
def __init__(self, ratio=2, kernel_size=None):
|
||||
super().__init__()
|
||||
self.ratio = ratio
|
||||
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
||||
self.stride = ratio
|
||||
self.pad = self.kernel_size // ratio - 1
|
||||
self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
|
||||
self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
|
||||
filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio,
|
||||
half_width=0.6 / ratio,
|
||||
kernel_size=self.kernel_size)
|
||||
self.register_buffer("filter", filter)
|
||||
|
||||
# x: [B, C, T]
|
||||
def forward(self, x):
|
||||
_, C, _ = x.shape
|
||||
|
||||
x = F.pad(x, (self.pad, self.pad), mode='replicate')
|
||||
x = self.ratio * F.conv_transpose1d(
|
||||
x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device), stride=self.stride, groups=C)
|
||||
x = x[..., self.pad_left:-self.pad_right]
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class DownSample1d(nn.Module):
|
||||
def __init__(self, ratio=2, kernel_size=None):
|
||||
super().__init__()
|
||||
self.ratio = ratio
|
||||
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
||||
self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio,
|
||||
half_width=0.6 / ratio,
|
||||
stride=ratio,
|
||||
kernel_size=self.kernel_size)
|
||||
|
||||
def forward(self, x):
|
||||
xx = self.lowpass(x)
|
||||
|
||||
return xx
|
||||
|
||||
class Activation1d(nn.Module):
|
||||
def __init__(self,
|
||||
activation,
|
||||
up_ratio: int = 2,
|
||||
down_ratio: int = 2,
|
||||
up_kernel_size: int = 12,
|
||||
down_kernel_size: int = 12):
|
||||
super().__init__()
|
||||
self.up_ratio = up_ratio
|
||||
self.down_ratio = down_ratio
|
||||
self.act = activation
|
||||
self.upsample = UpSample1d(up_ratio, up_kernel_size)
|
||||
self.downsample = DownSample1d(down_ratio, down_kernel_size)
|
||||
|
||||
# x: [B,C,T]
|
||||
def forward(self, x):
|
||||
x = self.upsample(x)
|
||||
x = self.act(x)
|
||||
x = self.downsample(x)
|
||||
|
||||
return x
|
||||
@@ -1,156 +0,0 @@
|
||||
from typing import Literal
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .distributions import DiagonalGaussianDistribution
|
||||
from .vae import VAE_16k
|
||||
from .bigvgan import BigVGANVocoder
|
||||
import logging
|
||||
|
||||
try:
|
||||
import torchaudio
|
||||
except:
|
||||
logging.warning("torchaudio missing, MMAudio VAE model will be broken")
|
||||
|
||||
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5, *, norm_fn):
|
||||
return norm_fn(torch.clamp(x, min=clip_val) * C)
|
||||
|
||||
|
||||
def spectral_normalize_torch(magnitudes, norm_fn):
|
||||
output = dynamic_range_compression_torch(magnitudes, norm_fn=norm_fn)
|
||||
return output
|
||||
|
||||
class MelConverter(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
sampling_rate: float,
|
||||
n_fft: int,
|
||||
num_mels: int,
|
||||
hop_size: int,
|
||||
win_size: int,
|
||||
fmin: float,
|
||||
fmax: float,
|
||||
norm_fn,
|
||||
):
|
||||
super().__init__()
|
||||
self.sampling_rate = sampling_rate
|
||||
self.n_fft = n_fft
|
||||
self.num_mels = num_mels
|
||||
self.hop_size = hop_size
|
||||
self.win_size = win_size
|
||||
self.fmin = fmin
|
||||
self.fmax = fmax
|
||||
self.norm_fn = norm_fn
|
||||
|
||||
# mel = librosa_mel_fn(sr=self.sampling_rate,
|
||||
# n_fft=self.n_fft,
|
||||
# n_mels=self.num_mels,
|
||||
# fmin=self.fmin,
|
||||
# fmax=self.fmax)
|
||||
# mel_basis = torch.from_numpy(mel).float()
|
||||
mel_basis = torch.empty((num_mels, 1 + n_fft // 2))
|
||||
hann_window = torch.hann_window(self.win_size)
|
||||
|
||||
self.register_buffer('mel_basis', mel_basis)
|
||||
self.register_buffer('hann_window', hann_window)
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return self.mel_basis.device
|
||||
|
||||
def forward(self, waveform: torch.Tensor, center: bool = False) -> torch.Tensor:
|
||||
waveform = waveform.clamp(min=-1., max=1.).to(self.device)
|
||||
|
||||
waveform = torch.nn.functional.pad(
|
||||
waveform.unsqueeze(1),
|
||||
[int((self.n_fft - self.hop_size) / 2),
|
||||
int((self.n_fft - self.hop_size) / 2)],
|
||||
mode='reflect')
|
||||
waveform = waveform.squeeze(1)
|
||||
|
||||
spec = torch.stft(waveform,
|
||||
self.n_fft,
|
||||
hop_length=self.hop_size,
|
||||
win_length=self.win_size,
|
||||
window=self.hann_window,
|
||||
center=center,
|
||||
pad_mode='reflect',
|
||||
normalized=False,
|
||||
onesided=True,
|
||||
return_complex=True)
|
||||
|
||||
spec = torch.view_as_real(spec)
|
||||
spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
|
||||
spec = torch.matmul(self.mel_basis, spec)
|
||||
spec = spectral_normalize_torch(spec, self.norm_fn)
|
||||
|
||||
return spec
|
||||
|
||||
class AudioAutoencoder(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
# ckpt_path: str,
|
||||
mode=Literal['16k', '44k'],
|
||||
need_vae_encoder: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
assert mode == "16k", "Only 16k mode is supported currently."
|
||||
self.mel_converter = MelConverter(sampling_rate=16_000,
|
||||
n_fft=1024,
|
||||
num_mels=80,
|
||||
hop_size=256,
|
||||
win_size=1024,
|
||||
fmin=0,
|
||||
fmax=8_000,
|
||||
norm_fn=torch.log10)
|
||||
|
||||
self.vae = VAE_16k().eval()
|
||||
|
||||
bigvgan_config = {
|
||||
"resblock": "1",
|
||||
"num_mels": 80,
|
||||
"upsample_rates": [4, 4, 2, 2, 2, 2],
|
||||
"upsample_kernel_sizes": [8, 8, 4, 4, 4, 4],
|
||||
"upsample_initial_channel": 1536,
|
||||
"resblock_kernel_sizes": [3, 7, 11],
|
||||
"resblock_dilation_sizes": [
|
||||
[1, 3, 5],
|
||||
[1, 3, 5],
|
||||
[1, 3, 5],
|
||||
],
|
||||
"activation": "snakebeta",
|
||||
"snake_logscale": True,
|
||||
}
|
||||
|
||||
self.vocoder = BigVGANVocoder(
|
||||
bigvgan_config
|
||||
).eval()
|
||||
|
||||
@torch.inference_mode()
|
||||
def encode_audio(self, x) -> DiagonalGaussianDistribution:
|
||||
# x: (B * L)
|
||||
mel = self.mel_converter(x)
|
||||
dist = self.vae.encode(mel)
|
||||
|
||||
return dist
|
||||
|
||||
@torch.no_grad()
|
||||
def decode(self, z):
|
||||
mel_decoded = self.vae.decode(z)
|
||||
audio = self.vocoder(mel_decoded)
|
||||
|
||||
audio = torchaudio.functional.resample(audio, 16000, 44100)
|
||||
return audio
|
||||
|
||||
@torch.no_grad()
|
||||
def encode(self, audio):
|
||||
audio = audio.mean(dim=1)
|
||||
audio = torchaudio.functional.resample(audio, 44100, 16000)
|
||||
dist = self.encode_audio(audio)
|
||||
return dist.mean
|
||||
@@ -1,219 +0,0 @@
|
||||
# Copyright (c) 2022 NVIDIA CORPORATION.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
|
||||
# LICENSE is in incl_licenses directory.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from types import SimpleNamespace
|
||||
from . import activations
|
||||
from .alias_free_torch import Activation1d
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
def get_padding(kernel_size, dilation=1):
|
||||
return int((kernel_size * dilation - dilation) / 2)
|
||||
|
||||
class AMPBlock1(torch.nn.Module):
|
||||
|
||||
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5), activation=None):
|
||||
super(AMPBlock1, self).__init__()
|
||||
self.h = h
|
||||
|
||||
self.convs1 = nn.ModuleList([
|
||||
ops.Conv1d(channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[0],
|
||||
padding=get_padding(kernel_size, dilation[0])),
|
||||
ops.Conv1d(channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[1],
|
||||
padding=get_padding(kernel_size, dilation[1])),
|
||||
ops.Conv1d(channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[2],
|
||||
padding=get_padding(kernel_size, dilation[2]))
|
||||
])
|
||||
|
||||
self.convs2 = nn.ModuleList([
|
||||
ops.Conv1d(channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1)),
|
||||
ops.Conv1d(channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1)),
|
||||
ops.Conv1d(channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1))
|
||||
])
|
||||
|
||||
self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers
|
||||
|
||||
if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
|
||||
self.activations = nn.ModuleList([
|
||||
Activation1d(
|
||||
activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
|
||||
for _ in range(self.num_layers)
|
||||
])
|
||||
elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing
|
||||
self.activations = nn.ModuleList([
|
||||
Activation1d(
|
||||
activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
|
||||
for _ in range(self.num_layers)
|
||||
])
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"activation incorrectly specified. check the config file and look for 'activation'."
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
acts1, acts2 = self.activations[::2], self.activations[1::2]
|
||||
for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
|
||||
xt = a1(x)
|
||||
xt = c1(xt)
|
||||
xt = a2(xt)
|
||||
xt = c2(xt)
|
||||
x = xt + x
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class AMPBlock2(torch.nn.Module):
|
||||
|
||||
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3), activation=None):
|
||||
super(AMPBlock2, self).__init__()
|
||||
self.h = h
|
||||
|
||||
self.convs = nn.ModuleList([
|
||||
ops.Conv1d(channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[0],
|
||||
padding=get_padding(kernel_size, dilation[0])),
|
||||
ops.Conv1d(channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[1],
|
||||
padding=get_padding(kernel_size, dilation[1]))
|
||||
])
|
||||
|
||||
self.num_layers = len(self.convs) # total number of conv layers
|
||||
|
||||
if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
|
||||
self.activations = nn.ModuleList([
|
||||
Activation1d(
|
||||
activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
|
||||
for _ in range(self.num_layers)
|
||||
])
|
||||
elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing
|
||||
self.activations = nn.ModuleList([
|
||||
Activation1d(
|
||||
activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
|
||||
for _ in range(self.num_layers)
|
||||
])
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"activation incorrectly specified. check the config file and look for 'activation'."
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
for c, a in zip(self.convs, self.activations):
|
||||
xt = a(x)
|
||||
xt = c(xt)
|
||||
x = xt + x
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class BigVGANVocoder(torch.nn.Module):
|
||||
# this is our main BigVGAN model. Applies anti-aliased periodic activation for resblocks.
|
||||
def __init__(self, h):
|
||||
super().__init__()
|
||||
if isinstance(h, dict):
|
||||
h = SimpleNamespace(**h)
|
||||
self.h = h
|
||||
|
||||
self.num_kernels = len(h.resblock_kernel_sizes)
|
||||
self.num_upsamples = len(h.upsample_rates)
|
||||
|
||||
# pre conv
|
||||
self.conv_pre = ops.Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
|
||||
|
||||
# define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
|
||||
resblock = AMPBlock1 if h.resblock == '1' else AMPBlock2
|
||||
|
||||
# transposed conv-based upsamplers. does not apply anti-aliasing
|
||||
self.ups = nn.ModuleList()
|
||||
for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
|
||||
self.ups.append(
|
||||
nn.ModuleList([
|
||||
ops.ConvTranspose1d(h.upsample_initial_channel // (2**i),
|
||||
h.upsample_initial_channel // (2**(i + 1)),
|
||||
k,
|
||||
u,
|
||||
padding=(k - u) // 2)
|
||||
]))
|
||||
|
||||
# residual blocks using anti-aliased multi-periodicity composition modules (AMP)
|
||||
self.resblocks = nn.ModuleList()
|
||||
for i in range(len(self.ups)):
|
||||
ch = h.upsample_initial_channel // (2**(i + 1))
|
||||
for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
|
||||
self.resblocks.append(resblock(h, ch, k, d, activation=h.activation))
|
||||
|
||||
# post conv
|
||||
if h.activation == "snake": # periodic nonlinearity with snake function and anti-aliasing
|
||||
activation_post = activations.Snake(ch, alpha_logscale=h.snake_logscale)
|
||||
self.activation_post = Activation1d(activation=activation_post)
|
||||
elif h.activation == "snakebeta": # periodic nonlinearity with snakebeta function and anti-aliasing
|
||||
activation_post = activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale)
|
||||
self.activation_post = Activation1d(activation=activation_post)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"activation incorrectly specified. check the config file and look for 'activation'."
|
||||
)
|
||||
|
||||
self.conv_post = ops.Conv1d(ch, 1, 7, 1, padding=3)
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
# pre conv
|
||||
x = self.conv_pre(x)
|
||||
|
||||
for i in range(self.num_upsamples):
|
||||
# upsampling
|
||||
for i_up in range(len(self.ups[i])):
|
||||
x = self.ups[i][i_up](x)
|
||||
# AMP blocks
|
||||
xs = None
|
||||
for j in range(self.num_kernels):
|
||||
if xs is None:
|
||||
xs = self.resblocks[i * self.num_kernels + j](x)
|
||||
else:
|
||||
xs += self.resblocks[i * self.num_kernels + j](x)
|
||||
x = xs / self.num_kernels
|
||||
|
||||
# post conv
|
||||
x = self.activation_post(x)
|
||||
x = self.conv_post(x)
|
||||
x = torch.tanh(x)
|
||||
|
||||
return x
|
||||
@@ -1,92 +0,0 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
class AbstractDistribution:
|
||||
def sample(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def mode(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class DiracDistribution(AbstractDistribution):
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
||||
def sample(self):
|
||||
return self.value
|
||||
|
||||
def mode(self):
|
||||
return self.value
|
||||
|
||||
|
||||
class DiagonalGaussianDistribution(object):
|
||||
def __init__(self, parameters, deterministic=False):
|
||||
self.parameters = parameters
|
||||
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
|
||||
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
|
||||
self.deterministic = deterministic
|
||||
self.std = torch.exp(0.5 * self.logvar)
|
||||
self.var = torch.exp(self.logvar)
|
||||
if self.deterministic:
|
||||
self.var = self.std = torch.zeros_like(self.mean, device=self.parameters.device)
|
||||
|
||||
def sample(self):
|
||||
x = self.mean + self.std * torch.randn(self.mean.shape, device=self.parameters.device)
|
||||
return x
|
||||
|
||||
def kl(self, other=None):
|
||||
if self.deterministic:
|
||||
return torch.Tensor([0.])
|
||||
else:
|
||||
if other is None:
|
||||
return 0.5 * torch.sum(torch.pow(self.mean, 2)
|
||||
+ self.var - 1.0 - self.logvar,
|
||||
dim=[1, 2, 3])
|
||||
else:
|
||||
return 0.5 * torch.sum(
|
||||
torch.pow(self.mean - other.mean, 2) / other.var
|
||||
+ self.var / other.var - 1.0 - self.logvar + other.logvar,
|
||||
dim=[1, 2, 3])
|
||||
|
||||
def nll(self, sample, dims=[1,2,3]):
|
||||
if self.deterministic:
|
||||
return torch.Tensor([0.])
|
||||
logtwopi = np.log(2.0 * np.pi)
|
||||
return 0.5 * torch.sum(
|
||||
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
|
||||
dim=dims)
|
||||
|
||||
def mode(self):
|
||||
return self.mean
|
||||
|
||||
|
||||
def normal_kl(mean1, logvar1, mean2, logvar2):
|
||||
"""
|
||||
source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
|
||||
Compute the KL divergence between two gaussians.
|
||||
Shapes are automatically broadcasted, so batches can be compared to
|
||||
scalars, among other use cases.
|
||||
"""
|
||||
tensor = None
|
||||
for obj in (mean1, logvar1, mean2, logvar2):
|
||||
if isinstance(obj, torch.Tensor):
|
||||
tensor = obj
|
||||
break
|
||||
assert tensor is not None, "at least one argument must be a Tensor"
|
||||
|
||||
# Force variances to be Tensors. Broadcasting helps convert scalars to
|
||||
# Tensors, but it does not work for torch.exp().
|
||||
logvar1, logvar2 = [
|
||||
x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
|
||||
for x in (logvar1, logvar2)
|
||||
]
|
||||
|
||||
return 0.5 * (
|
||||
-1.0
|
||||
+ logvar2
|
||||
- logvar1
|
||||
+ torch.exp(logvar1 - logvar2)
|
||||
+ ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
|
||||
)
|
||||
@@ -1,358 +0,0 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .vae_modules import (AttnBlock1D, Downsample1D, ResnetBlock1D,
|
||||
Upsample1D, nonlinearity)
|
||||
from .distributions import DiagonalGaussianDistribution
|
||||
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
log = logging.getLogger()
|
||||
|
||||
DATA_MEAN_80D = [
|
||||
-1.6058, -1.3676, -1.2520, -1.2453, -1.2078, -1.2224, -1.2419, -1.2439, -1.2922, -1.2927,
|
||||
-1.3170, -1.3543, -1.3401, -1.3836, -1.3907, -1.3912, -1.4313, -1.4152, -1.4527, -1.4728,
|
||||
-1.4568, -1.5101, -1.5051, -1.5172, -1.5623, -1.5373, -1.5746, -1.5687, -1.6032, -1.6131,
|
||||
-1.6081, -1.6331, -1.6489, -1.6489, -1.6700, -1.6738, -1.6953, -1.6969, -1.7048, -1.7280,
|
||||
-1.7361, -1.7495, -1.7658, -1.7814, -1.7889, -1.8064, -1.8221, -1.8377, -1.8417, -1.8643,
|
||||
-1.8857, -1.8929, -1.9173, -1.9379, -1.9531, -1.9673, -1.9824, -2.0042, -2.0215, -2.0436,
|
||||
-2.0766, -2.1064, -2.1418, -2.1855, -2.2319, -2.2767, -2.3161, -2.3572, -2.3954, -2.4282,
|
||||
-2.4659, -2.5072, -2.5552, -2.6074, -2.6584, -2.7107, -2.7634, -2.8266, -2.8981, -2.9673
|
||||
]
|
||||
|
||||
DATA_STD_80D = [
|
||||
1.0291, 1.0411, 1.0043, 0.9820, 0.9677, 0.9543, 0.9450, 0.9392, 0.9343, 0.9297, 0.9276, 0.9263,
|
||||
0.9242, 0.9254, 0.9232, 0.9281, 0.9263, 0.9315, 0.9274, 0.9247, 0.9277, 0.9199, 0.9188, 0.9194,
|
||||
0.9160, 0.9161, 0.9146, 0.9161, 0.9100, 0.9095, 0.9145, 0.9076, 0.9066, 0.9095, 0.9032, 0.9043,
|
||||
0.9038, 0.9011, 0.9019, 0.9010, 0.8984, 0.8983, 0.8986, 0.8961, 0.8962, 0.8978, 0.8962, 0.8973,
|
||||
0.8993, 0.8976, 0.8995, 0.9016, 0.8982, 0.8972, 0.8974, 0.8949, 0.8940, 0.8947, 0.8936, 0.8939,
|
||||
0.8951, 0.8956, 0.9017, 0.9167, 0.9436, 0.9690, 1.0003, 1.0225, 1.0381, 1.0491, 1.0545, 1.0604,
|
||||
1.0761, 1.0929, 1.1089, 1.1196, 1.1176, 1.1156, 1.1117, 1.1070
|
||||
]
|
||||
|
||||
DATA_MEAN_128D = [
|
||||
-3.3462, -2.6723, -2.4893, -2.3143, -2.2664, -2.3317, -2.1802, -2.4006, -2.2357, -2.4597,
|
||||
-2.3717, -2.4690, -2.5142, -2.4919, -2.6610, -2.5047, -2.7483, -2.5926, -2.7462, -2.7033,
|
||||
-2.7386, -2.8112, -2.7502, -2.9594, -2.7473, -3.0035, -2.8891, -2.9922, -2.9856, -3.0157,
|
||||
-3.1191, -2.9893, -3.1718, -3.0745, -3.1879, -3.2310, -3.1424, -3.2296, -3.2791, -3.2782,
|
||||
-3.2756, -3.3134, -3.3509, -3.3750, -3.3951, -3.3698, -3.4505, -3.4509, -3.5089, -3.4647,
|
||||
-3.5536, -3.5788, -3.5867, -3.6036, -3.6400, -3.6747, -3.7072, -3.7279, -3.7283, -3.7795,
|
||||
-3.8259, -3.8447, -3.8663, -3.9182, -3.9605, -3.9861, -4.0105, -4.0373, -4.0762, -4.1121,
|
||||
-4.1488, -4.1874, -4.2461, -4.3170, -4.3639, -4.4452, -4.5282, -4.6297, -4.7019, -4.7960,
|
||||
-4.8700, -4.9507, -5.0303, -5.0866, -5.1634, -5.2342, -5.3242, -5.4053, -5.4927, -5.5712,
|
||||
-5.6464, -5.7052, -5.7619, -5.8410, -5.9188, -6.0103, -6.0955, -6.1673, -6.2362, -6.3120,
|
||||
-6.3926, -6.4797, -6.5565, -6.6511, -6.8130, -6.9961, -7.1275, -7.2457, -7.3576, -7.4663,
|
||||
-7.6136, -7.7469, -7.8815, -8.0132, -8.1515, -8.3071, -8.4722, -8.7418, -9.3975, -9.6628,
|
||||
-9.7671, -9.8863, -9.9992, -10.0860, -10.1709, -10.5418, -11.2795, -11.3861
|
||||
]
|
||||
|
||||
DATA_STD_128D = [
|
||||
2.3804, 2.4368, 2.3772, 2.3145, 2.2803, 2.2510, 2.2316, 2.2083, 2.1996, 2.1835, 2.1769, 2.1659,
|
||||
2.1631, 2.1618, 2.1540, 2.1606, 2.1571, 2.1567, 2.1612, 2.1579, 2.1679, 2.1683, 2.1634, 2.1557,
|
||||
2.1668, 2.1518, 2.1415, 2.1449, 2.1406, 2.1350, 2.1313, 2.1415, 2.1281, 2.1352, 2.1219, 2.1182,
|
||||
2.1327, 2.1195, 2.1137, 2.1080, 2.1179, 2.1036, 2.1087, 2.1036, 2.1015, 2.1068, 2.0975, 2.0991,
|
||||
2.0902, 2.1015, 2.0857, 2.0920, 2.0893, 2.0897, 2.0910, 2.0881, 2.0925, 2.0873, 2.0960, 2.0900,
|
||||
2.0957, 2.0958, 2.0978, 2.0936, 2.0886, 2.0905, 2.0845, 2.0855, 2.0796, 2.0840, 2.0813, 2.0817,
|
||||
2.0838, 2.0840, 2.0917, 2.1061, 2.1431, 2.1976, 2.2482, 2.3055, 2.3700, 2.4088, 2.4372, 2.4609,
|
||||
2.4731, 2.4847, 2.5072, 2.5451, 2.5772, 2.6147, 2.6529, 2.6596, 2.6645, 2.6726, 2.6803, 2.6812,
|
||||
2.6899, 2.6916, 2.6931, 2.6998, 2.7062, 2.7262, 2.7222, 2.7158, 2.7041, 2.7485, 2.7491, 2.7451,
|
||||
2.7485, 2.7233, 2.7297, 2.7233, 2.7145, 2.6958, 2.6788, 2.6439, 2.6007, 2.4786, 2.2469, 2.1877,
|
||||
2.1392, 2.0717, 2.0107, 1.9676, 1.9140, 1.7102, 0.9101, 0.7164
|
||||
]
|
||||
|
||||
|
||||
class VAE(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
data_dim: int,
|
||||
embed_dim: int,
|
||||
hidden_dim: int,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if data_dim == 80:
|
||||
self.data_mean = nn.Buffer(torch.tensor(DATA_MEAN_80D, dtype=torch.float32))
|
||||
self.data_std = nn.Buffer(torch.tensor(DATA_STD_80D, dtype=torch.float32))
|
||||
elif data_dim == 128:
|
||||
self.data_mean = nn.Buffer(torch.tensor(DATA_MEAN_128D, dtype=torch.float32))
|
||||
self.data_std = nn.Buffer(torch.tensor(DATA_STD_128D, dtype=torch.float32))
|
||||
|
||||
self.data_mean = self.data_mean.view(1, -1, 1)
|
||||
self.data_std = self.data_std.view(1, -1, 1)
|
||||
|
||||
self.encoder = Encoder1D(
|
||||
dim=hidden_dim,
|
||||
ch_mult=(1, 2, 4),
|
||||
num_res_blocks=2,
|
||||
attn_layers=[3],
|
||||
down_layers=[0],
|
||||
in_dim=data_dim,
|
||||
embed_dim=embed_dim,
|
||||
)
|
||||
self.decoder = Decoder1D(
|
||||
dim=hidden_dim,
|
||||
ch_mult=(1, 2, 4),
|
||||
num_res_blocks=2,
|
||||
attn_layers=[3],
|
||||
down_layers=[0],
|
||||
in_dim=data_dim,
|
||||
out_dim=data_dim,
|
||||
embed_dim=embed_dim,
|
||||
)
|
||||
|
||||
self.embed_dim = embed_dim
|
||||
# self.quant_conv = nn.Conv1d(2 * embed_dim, 2 * embed_dim, 1)
|
||||
# self.post_quant_conv = nn.Conv1d(embed_dim, embed_dim, 1)
|
||||
|
||||
self.initialize_weights()
|
||||
|
||||
def initialize_weights(self):
|
||||
pass
|
||||
|
||||
def encode(self, x: torch.Tensor, normalize: bool = True) -> DiagonalGaussianDistribution:
|
||||
if normalize:
|
||||
x = self.normalize(x)
|
||||
moments = self.encoder(x)
|
||||
posterior = DiagonalGaussianDistribution(moments)
|
||||
return posterior
|
||||
|
||||
def decode(self, z: torch.Tensor, unnormalize: bool = True) -> torch.Tensor:
|
||||
dec = self.decoder(z)
|
||||
if unnormalize:
|
||||
dec = self.unnormalize(dec)
|
||||
return dec
|
||||
|
||||
def normalize(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return (x - comfy.model_management.cast_to(self.data_mean, dtype=x.dtype, device=x.device)) / comfy.model_management.cast_to(self.data_std, dtype=x.dtype, device=x.device)
|
||||
|
||||
def unnormalize(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x * comfy.model_management.cast_to(self.data_std, dtype=x.dtype, device=x.device) + comfy.model_management.cast_to(self.data_mean, dtype=x.dtype, device=x.device)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
sample_posterior: bool = True,
|
||||
rng: Optional[torch.Generator] = None,
|
||||
normalize: bool = True,
|
||||
unnormalize: bool = True,
|
||||
) -> tuple[torch.Tensor, DiagonalGaussianDistribution]:
|
||||
|
||||
posterior = self.encode(x, normalize=normalize)
|
||||
if sample_posterior:
|
||||
z = posterior.sample(rng)
|
||||
else:
|
||||
z = posterior.mode()
|
||||
dec = self.decode(z, unnormalize=unnormalize)
|
||||
return dec, posterior
|
||||
|
||||
def load_weights(self, src_dict) -> None:
|
||||
self.load_state_dict(src_dict, strict=True)
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return next(self.parameters()).device
|
||||
|
||||
def get_last_layer(self):
|
||||
return self.decoder.conv_out.weight
|
||||
|
||||
def remove_weight_norm(self):
|
||||
return self
|
||||
|
||||
|
||||
class Encoder1D(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
dim: int,
|
||||
ch_mult: tuple[int] = (1, 2, 4, 8),
|
||||
num_res_blocks: int,
|
||||
attn_layers: list[int] = [],
|
||||
down_layers: list[int] = [],
|
||||
resamp_with_conv: bool = True,
|
||||
in_dim: int,
|
||||
embed_dim: int,
|
||||
double_z: bool = True,
|
||||
kernel_size: int = 3,
|
||||
clip_act: float = 256.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.num_layers = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.in_channels = in_dim
|
||||
self.clip_act = clip_act
|
||||
self.down_layers = down_layers
|
||||
self.attn_layers = attn_layers
|
||||
self.conv_in = ops.Conv1d(in_dim, self.dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
|
||||
|
||||
in_ch_mult = (1, ) + tuple(ch_mult)
|
||||
self.in_ch_mult = in_ch_mult
|
||||
# downsampling
|
||||
self.down = nn.ModuleList()
|
||||
for i_level in range(self.num_layers):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_in = dim * in_ch_mult[i_level]
|
||||
block_out = dim * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks):
|
||||
block.append(
|
||||
ResnetBlock1D(in_dim=block_in,
|
||||
out_dim=block_out,
|
||||
kernel_size=kernel_size,
|
||||
use_norm=True))
|
||||
block_in = block_out
|
||||
if i_level in attn_layers:
|
||||
attn.append(AttnBlock1D(block_in))
|
||||
down = nn.Module()
|
||||
down.block = block
|
||||
down.attn = attn
|
||||
if i_level in down_layers:
|
||||
down.downsample = Downsample1D(block_in, resamp_with_conv)
|
||||
self.down.append(down)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock1D(in_dim=block_in,
|
||||
out_dim=block_in,
|
||||
kernel_size=kernel_size,
|
||||
use_norm=True)
|
||||
self.mid.attn_1 = AttnBlock1D(block_in)
|
||||
self.mid.block_2 = ResnetBlock1D(in_dim=block_in,
|
||||
out_dim=block_in,
|
||||
kernel_size=kernel_size,
|
||||
use_norm=True)
|
||||
|
||||
# end
|
||||
self.conv_out = ops.Conv1d(block_in,
|
||||
2 * embed_dim if double_z else embed_dim,
|
||||
kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
|
||||
|
||||
self.learnable_gain = nn.Parameter(torch.zeros([]))
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
# downsampling
|
||||
h = self.conv_in(x)
|
||||
for i_level in range(self.num_layers):
|
||||
for i_block in range(self.num_res_blocks):
|
||||
h = self.down[i_level].block[i_block](h)
|
||||
if len(self.down[i_level].attn) > 0:
|
||||
h = self.down[i_level].attn[i_block](h)
|
||||
h = h.clamp(-self.clip_act, self.clip_act)
|
||||
if i_level in self.down_layers:
|
||||
h = self.down[i_level].downsample(h)
|
||||
|
||||
# middle
|
||||
h = self.mid.block_1(h)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h)
|
||||
h = h.clamp(-self.clip_act, self.clip_act)
|
||||
|
||||
# end
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h) * (self.learnable_gain + 1)
|
||||
return h
|
||||
|
||||
|
||||
class Decoder1D(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
dim: int,
|
||||
out_dim: int,
|
||||
ch_mult: tuple[int] = (1, 2, 4, 8),
|
||||
num_res_blocks: int,
|
||||
attn_layers: list[int] = [],
|
||||
down_layers: list[int] = [],
|
||||
kernel_size: int = 3,
|
||||
resamp_with_conv: bool = True,
|
||||
in_dim: int,
|
||||
embed_dim: int,
|
||||
clip_act: float = 256.0):
|
||||
super().__init__()
|
||||
self.ch = dim
|
||||
self.num_layers = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.in_channels = in_dim
|
||||
self.clip_act = clip_act
|
||||
self.down_layers = [i + 1 for i in down_layers] # each downlayer add one
|
||||
|
||||
# compute in_ch_mult, block_in and curr_res at lowest res
|
||||
block_in = dim * ch_mult[self.num_layers - 1]
|
||||
|
||||
# z to block_in
|
||||
self.conv_in = ops.Conv1d(embed_dim, block_in, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock1D(in_dim=block_in, out_dim=block_in, use_norm=True)
|
||||
self.mid.attn_1 = AttnBlock1D(block_in)
|
||||
self.mid.block_2 = ResnetBlock1D(in_dim=block_in, out_dim=block_in, use_norm=True)
|
||||
|
||||
# upsampling
|
||||
self.up = nn.ModuleList()
|
||||
for i_level in reversed(range(self.num_layers)):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_out = dim * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
block.append(ResnetBlock1D(in_dim=block_in, out_dim=block_out, use_norm=True))
|
||||
block_in = block_out
|
||||
if i_level in attn_layers:
|
||||
attn.append(AttnBlock1D(block_in))
|
||||
up = nn.Module()
|
||||
up.block = block
|
||||
up.attn = attn
|
||||
if i_level in self.down_layers:
|
||||
up.upsample = Upsample1D(block_in, resamp_with_conv)
|
||||
self.up.insert(0, up) # prepend to get consistent order
|
||||
|
||||
# end
|
||||
self.conv_out = ops.Conv1d(block_in, out_dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
|
||||
self.learnable_gain = nn.Parameter(torch.zeros([]))
|
||||
|
||||
def forward(self, z):
|
||||
# z to block_in
|
||||
h = self.conv_in(z)
|
||||
|
||||
# middle
|
||||
h = self.mid.block_1(h)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h)
|
||||
h = h.clamp(-self.clip_act, self.clip_act)
|
||||
|
||||
# upsampling
|
||||
for i_level in reversed(range(self.num_layers)):
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
h = self.up[i_level].block[i_block](h)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
h = self.up[i_level].attn[i_block](h)
|
||||
h = h.clamp(-self.clip_act, self.clip_act)
|
||||
if i_level in self.down_layers:
|
||||
h = self.up[i_level].upsample(h)
|
||||
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h) * (self.learnable_gain + 1)
|
||||
return h
|
||||
|
||||
|
||||
def VAE_16k(**kwargs) -> VAE:
|
||||
return VAE(data_dim=80, embed_dim=20, hidden_dim=384, **kwargs)
|
||||
|
||||
|
||||
def VAE_44k(**kwargs) -> VAE:
|
||||
return VAE(data_dim=128, embed_dim=40, hidden_dim=512, **kwargs)
|
||||
|
||||
|
||||
def get_my_vae(name: str, **kwargs) -> VAE:
|
||||
if name == '16k':
|
||||
return VAE_16k(**kwargs)
|
||||
if name == '44k':
|
||||
return VAE_44k(**kwargs)
|
||||
raise ValueError(f'Unknown model: {name}')
|
||||
|
||||
@@ -1,121 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from comfy.ldm.modules.diffusionmodules.model import vae_attention
|
||||
import math
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
def nonlinearity(x):
|
||||
# swish
|
||||
return torch.nn.functional.silu(x) / 0.596
|
||||
|
||||
def mp_sum(a, b, t=0.5):
|
||||
return a.lerp(b, t) / math.sqrt((1 - t)**2 + t**2)
|
||||
|
||||
def normalize(x, dim=None, eps=1e-4):
|
||||
if dim is None:
|
||||
dim = list(range(1, x.ndim))
|
||||
norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32)
|
||||
norm = torch.add(eps, norm, alpha=math.sqrt(norm.numel() / x.numel()))
|
||||
return x / norm.to(x.dtype)
|
||||
|
||||
class ResnetBlock1D(nn.Module):
|
||||
|
||||
def __init__(self, *, in_dim, out_dim=None, conv_shortcut=False, kernel_size=3, use_norm=True):
|
||||
super().__init__()
|
||||
self.in_dim = in_dim
|
||||
out_dim = in_dim if out_dim is None else out_dim
|
||||
self.out_dim = out_dim
|
||||
self.use_conv_shortcut = conv_shortcut
|
||||
self.use_norm = use_norm
|
||||
|
||||
self.conv1 = ops.Conv1d(in_dim, out_dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
|
||||
self.conv2 = ops.Conv1d(out_dim, out_dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
|
||||
if self.in_dim != self.out_dim:
|
||||
if self.use_conv_shortcut:
|
||||
self.conv_shortcut = ops.Conv1d(in_dim, out_dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
|
||||
else:
|
||||
self.nin_shortcut = ops.Conv1d(in_dim, out_dim, kernel_size=1, padding=0, bias=False)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
# pixel norm
|
||||
if self.use_norm:
|
||||
x = normalize(x, dim=1)
|
||||
|
||||
h = x
|
||||
h = nonlinearity(h)
|
||||
h = self.conv1(h)
|
||||
|
||||
h = nonlinearity(h)
|
||||
h = self.conv2(h)
|
||||
|
||||
if self.in_dim != self.out_dim:
|
||||
if self.use_conv_shortcut:
|
||||
x = self.conv_shortcut(x)
|
||||
else:
|
||||
x = self.nin_shortcut(x)
|
||||
|
||||
return mp_sum(x, h, t=0.3)
|
||||
|
||||
|
||||
class AttnBlock1D(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, num_heads=1):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.qkv = ops.Conv1d(in_channels, in_channels * 3, kernel_size=1, padding=0, bias=False)
|
||||
self.proj_out = ops.Conv1d(in_channels, in_channels, kernel_size=1, padding=0, bias=False)
|
||||
self.optimized_attention = vae_attention()
|
||||
|
||||
def forward(self, x):
|
||||
h = x
|
||||
y = self.qkv(h)
|
||||
y = y.reshape(y.shape[0], -1, 3, y.shape[-1])
|
||||
q, k, v = normalize(y, dim=1).unbind(2)
|
||||
|
||||
h = self.optimized_attention(q, k, v)
|
||||
h = self.proj_out(h)
|
||||
|
||||
return mp_sum(x, h, t=0.3)
|
||||
|
||||
|
||||
class Upsample1D(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, with_conv):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
if self.with_conv:
|
||||
self.conv = ops.Conv1d(in_channels, in_channels, kernel_size=3, padding=1, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
x = F.interpolate(x, scale_factor=2.0, mode='nearest-exact') # support 3D tensor(B,C,T)
|
||||
if self.with_conv:
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class Downsample1D(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, with_conv):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
if self.with_conv:
|
||||
# no asymmetric padding in torch conv, must do it ourselves
|
||||
self.conv1 = ops.Conv1d(in_channels, in_channels, kernel_size=1, padding=0, bias=False)
|
||||
self.conv2 = ops.Conv1d(in_channels, in_channels, kernel_size=1, padding=0, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
if self.with_conv:
|
||||
x = self.conv1(x)
|
||||
|
||||
x = F.avg_pool1d(x, kernel_size=2, stride=2)
|
||||
|
||||
if self.with_conv:
|
||||
x = self.conv2(x)
|
||||
|
||||
return x
|
||||
@@ -9,8 +9,6 @@ from comfy.ldm.modules.distributions.distributions import DiagonalGaussianDistri
|
||||
from comfy.ldm.util import get_obj_from_str, instantiate_from_config
|
||||
from comfy.ldm.modules.ema import LitEma
|
||||
import comfy.ops
|
||||
from einops import rearrange
|
||||
import comfy.model_management
|
||||
|
||||
class DiagonalGaussianRegularizer(torch.nn.Module):
|
||||
def __init__(self, sample: bool = False):
|
||||
@@ -181,21 +179,6 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
|
||||
self.post_quant_conv = conv_op(embed_dim, ddconfig["z_channels"], 1)
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
if ddconfig.get("batch_norm_latent", False):
|
||||
self.bn_eps = 1e-4
|
||||
self.bn_momentum = 0.1
|
||||
self.ps = [2, 2]
|
||||
self.bn = torch.nn.BatchNorm2d(math.prod(self.ps) * ddconfig["z_channels"],
|
||||
eps=self.bn_eps,
|
||||
momentum=self.bn_momentum,
|
||||
affine=False,
|
||||
track_running_stats=True,
|
||||
)
|
||||
self.bn.eval()
|
||||
else:
|
||||
self.bn = None
|
||||
|
||||
|
||||
def get_autoencoder_params(self) -> list:
|
||||
params = super().get_autoencoder_params()
|
||||
return params
|
||||
@@ -218,36 +201,11 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
|
||||
z = torch.cat(z, 0)
|
||||
|
||||
z, reg_log = self.regularization(z)
|
||||
|
||||
if self.bn is not None:
|
||||
z = rearrange(z,
|
||||
"... c (i pi) (j pj) -> ... (c pi pj) i j",
|
||||
pi=self.ps[0],
|
||||
pj=self.ps[1],
|
||||
)
|
||||
|
||||
z = torch.nn.functional.batch_norm(z,
|
||||
comfy.model_management.cast_to(self.bn.running_mean, dtype=z.dtype, device=z.device),
|
||||
comfy.model_management.cast_to(self.bn.running_var, dtype=z.dtype, device=z.device),
|
||||
momentum=self.bn_momentum,
|
||||
eps=self.bn_eps)
|
||||
|
||||
if return_reg_log:
|
||||
return z, reg_log
|
||||
return z
|
||||
|
||||
def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor:
|
||||
if self.bn is not None:
|
||||
s = torch.sqrt(comfy.model_management.cast_to(self.bn.running_var.view(1, -1, 1, 1), dtype=z.dtype, device=z.device) + self.bn_eps)
|
||||
m = comfy.model_management.cast_to(self.bn.running_mean.view(1, -1, 1, 1), dtype=z.dtype, device=z.device)
|
||||
z = z * s + m
|
||||
z = rearrange(
|
||||
z,
|
||||
"... (c pi pj) i j -> ... c (i pi) (j pj)",
|
||||
pi=self.ps[0],
|
||||
pj=self.ps[1],
|
||||
)
|
||||
|
||||
if self.max_batch_size is None:
|
||||
dec = self.post_quant_conv(z)
|
||||
dec = self.decoder(dec, **decoder_kwargs)
|
||||
|
||||
@@ -30,13 +30,6 @@ except ImportError as e:
|
||||
raise e
|
||||
exit(-1)
|
||||
|
||||
SAGE_ATTENTION3_IS_AVAILABLE = False
|
||||
try:
|
||||
from sageattn3 import sageattn3_blackwell
|
||||
SAGE_ATTENTION3_IS_AVAILABLE = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
FLASH_ATTENTION_IS_AVAILABLE = False
|
||||
try:
|
||||
from flash_attn import flash_attn_func
|
||||
@@ -524,7 +517,6 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
|
||||
|
||||
@wrap_attn
|
||||
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||
exception_fallback = False
|
||||
if skip_reshape:
|
||||
b, _, _, dim_head = q.shape
|
||||
tensor_layout = "HND"
|
||||
@@ -549,8 +541,6 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
|
||||
out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
|
||||
except Exception as e:
|
||||
logging.error("Error running sage attention: {}, using pytorch attention instead.".format(e))
|
||||
exception_fallback = True
|
||||
if exception_fallback:
|
||||
if tensor_layout == "NHD":
|
||||
q, k, v = map(
|
||||
lambda t: t.transpose(1, 2),
|
||||
@@ -570,93 +560,6 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
|
||||
out = out.reshape(b, -1, heads * dim_head)
|
||||
return out
|
||||
|
||||
@wrap_attn
|
||||
def attention3_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||
exception_fallback = False
|
||||
if (q.device.type != "cuda" or
|
||||
q.dtype not in (torch.float16, torch.bfloat16) or
|
||||
mask is not None):
|
||||
return attention_pytorch(
|
||||
q, k, v, heads,
|
||||
mask=mask,
|
||||
attn_precision=attn_precision,
|
||||
skip_reshape=skip_reshape,
|
||||
skip_output_reshape=skip_output_reshape,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
if skip_reshape:
|
||||
B, H, L, D = q.shape
|
||||
if H != heads:
|
||||
return attention_pytorch(
|
||||
q, k, v, heads,
|
||||
mask=mask,
|
||||
attn_precision=attn_precision,
|
||||
skip_reshape=True,
|
||||
skip_output_reshape=skip_output_reshape,
|
||||
**kwargs
|
||||
)
|
||||
q_s, k_s, v_s = q, k, v
|
||||
N = q.shape[2]
|
||||
dim_head = D
|
||||
else:
|
||||
B, N, inner_dim = q.shape
|
||||
if inner_dim % heads != 0:
|
||||
return attention_pytorch(
|
||||
q, k, v, heads,
|
||||
mask=mask,
|
||||
attn_precision=attn_precision,
|
||||
skip_reshape=False,
|
||||
skip_output_reshape=skip_output_reshape,
|
||||
**kwargs
|
||||
)
|
||||
dim_head = inner_dim // heads
|
||||
|
||||
if dim_head >= 256 or N <= 1024:
|
||||
return attention_pytorch(
|
||||
q, k, v, heads,
|
||||
mask=mask,
|
||||
attn_precision=attn_precision,
|
||||
skip_reshape=skip_reshape,
|
||||
skip_output_reshape=skip_output_reshape,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
if not skip_reshape:
|
||||
q_s, k_s, v_s = map(
|
||||
lambda t: t.view(B, -1, heads, dim_head).permute(0, 2, 1, 3).contiguous(),
|
||||
(q, k, v),
|
||||
)
|
||||
B, H, L, D = q_s.shape
|
||||
|
||||
try:
|
||||
out = sageattn3_blackwell(q_s, k_s, v_s, is_causal=False)
|
||||
except Exception as e:
|
||||
exception_fallback = True
|
||||
logging.error("Error running SageAttention3: %s, falling back to pytorch attention.", e)
|
||||
|
||||
if exception_fallback:
|
||||
if not skip_reshape:
|
||||
del q_s, k_s, v_s
|
||||
return attention_pytorch(
|
||||
q, k, v, heads,
|
||||
mask=mask,
|
||||
attn_precision=attn_precision,
|
||||
skip_reshape=False,
|
||||
skip_output_reshape=skip_output_reshape,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
if skip_reshape:
|
||||
if not skip_output_reshape:
|
||||
out = out.permute(0, 2, 1, 3).reshape(B, L, H * D)
|
||||
else:
|
||||
if skip_output_reshape:
|
||||
pass
|
||||
else:
|
||||
out = out.permute(0, 2, 1, 3).reshape(B, L, H * D)
|
||||
|
||||
return out
|
||||
|
||||
try:
|
||||
@torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
|
||||
@@ -744,8 +647,6 @@ optimized_attention_masked = optimized_attention
|
||||
# register core-supported attention functions
|
||||
if SAGE_ATTENTION_IS_AVAILABLE:
|
||||
register_attention_function("sage", attention_sage)
|
||||
if SAGE_ATTENTION3_IS_AVAILABLE:
|
||||
register_attention_function("sage3", attention3_sage)
|
||||
if FLASH_ATTENTION_IS_AVAILABLE:
|
||||
register_attention_function("flash", attention_flash)
|
||||
if model_management.xformers_enabled():
|
||||
|
||||
@@ -211,14 +211,12 @@ class TimestepEmbedder(nn.Module):
|
||||
Embeds scalar timesteps into vector representations.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size, frequency_embedding_size=256, output_size=None, dtype=None, device=None, operations=None):
|
||||
def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
if output_size is None:
|
||||
output_size = hidden_size
|
||||
self.mlp = nn.Sequential(
|
||||
operations.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Linear(hidden_size, output_size, bias=True, dtype=dtype, device=device),
|
||||
operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
|
||||
|
||||
@@ -13,12 +13,6 @@ if model_management.xformers_enabled_vae():
|
||||
import xformers
|
||||
import xformers.ops
|
||||
|
||||
def torch_cat_if_needed(xl, dim):
|
||||
if len(xl) > 1:
|
||||
return torch.cat(xl, dim)
|
||||
else:
|
||||
return xl[0]
|
||||
|
||||
def get_timestep_embedding(timesteps, embedding_dim):
|
||||
"""
|
||||
This matches the implementation in Denoising Diffusion Probabilistic Models:
|
||||
@@ -49,37 +43,6 @@ def Normalize(in_channels, num_groups=32):
|
||||
return ops.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
|
||||
|
||||
class CarriedConv3d(nn.Module):
|
||||
def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding=0, **kwargs):
|
||||
super().__init__()
|
||||
self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, stride=stride, dilation=dilation, **kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
def conv_carry_causal_3d(xl, op, conv_carry_in=None, conv_carry_out=None):
|
||||
|
||||
x = xl[0]
|
||||
xl.clear()
|
||||
|
||||
if isinstance(op, CarriedConv3d):
|
||||
if conv_carry_in is None:
|
||||
x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2, 0), mode = 'replicate')
|
||||
else:
|
||||
carry_len = conv_carry_in[0].shape[2]
|
||||
x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2 - carry_len, 0), mode = 'replicate')
|
||||
x = torch.cat([conv_carry_in.pop(0), x], dim=2)
|
||||
|
||||
if conv_carry_out is not None:
|
||||
to_push = x[:, :, -2:, :, :].clone()
|
||||
conv_carry_out.append(to_push)
|
||||
|
||||
out = op(x)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class VideoConv3d(nn.Module):
|
||||
def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding_mode='replicate', padding=1, **kwargs):
|
||||
super().__init__()
|
||||
@@ -126,24 +89,29 @@ class Upsample(nn.Module):
|
||||
stride=1,
|
||||
padding=1)
|
||||
|
||||
def forward(self, x, conv_carry_in=None, conv_carry_out=None):
|
||||
def forward(self, x):
|
||||
scale_factor = self.scale_factor
|
||||
if isinstance(scale_factor, (int, float)):
|
||||
scale_factor = (scale_factor,) * (x.ndim - 2)
|
||||
|
||||
if x.ndim == 5 and scale_factor[0] > 1.0:
|
||||
results = []
|
||||
if conv_carry_in is None:
|
||||
first = x[:, :, :1, :, :]
|
||||
results.append(interpolate_up(first.squeeze(2), scale_factor=scale_factor[1:]).unsqueeze(2))
|
||||
x = x[:, :, 1:, :, :]
|
||||
if x.shape[2] > 0:
|
||||
results.append(interpolate_up(x, scale_factor))
|
||||
x = torch_cat_if_needed(results, dim=2)
|
||||
t = x.shape[2]
|
||||
if t > 1:
|
||||
a, b = x.split((1, t - 1), dim=2)
|
||||
del x
|
||||
b = interpolate_up(b, scale_factor)
|
||||
else:
|
||||
a = x
|
||||
|
||||
a = interpolate_up(a.squeeze(2), scale_factor=scale_factor[1:]).unsqueeze(2)
|
||||
if t > 1:
|
||||
x = torch.cat((a, b), dim=2)
|
||||
else:
|
||||
x = a
|
||||
else:
|
||||
x = interpolate_up(x, scale_factor)
|
||||
if self.with_conv:
|
||||
x = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out)
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
@@ -159,20 +127,17 @@ class Downsample(nn.Module):
|
||||
stride=stride,
|
||||
padding=0)
|
||||
|
||||
def forward(self, x, conv_carry_in=None, conv_carry_out=None):
|
||||
def forward(self, x):
|
||||
if self.with_conv:
|
||||
if isinstance(self.conv, CarriedConv3d):
|
||||
x = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out)
|
||||
elif x.ndim == 4:
|
||||
if x.ndim == 4:
|
||||
pad = (0, 1, 0, 1)
|
||||
mode = "constant"
|
||||
x = torch.nn.functional.pad(x, pad, mode=mode, value=0)
|
||||
x = self.conv(x)
|
||||
elif x.ndim == 5:
|
||||
pad = (1, 1, 1, 1, 2, 0)
|
||||
mode = "replicate"
|
||||
x = torch.nn.functional.pad(x, pad, mode=mode)
|
||||
x = self.conv(x)
|
||||
x = self.conv(x)
|
||||
else:
|
||||
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
||||
return x
|
||||
@@ -218,23 +183,23 @@ class ResnetBlock(nn.Module):
|
||||
stride=1,
|
||||
padding=0)
|
||||
|
||||
def forward(self, x, temb=None, conv_carry_in=None, conv_carry_out=None):
|
||||
def forward(self, x, temb=None):
|
||||
h = x
|
||||
h = self.norm1(h)
|
||||
h = [ self.swish(h) ]
|
||||
h = conv_carry_causal_3d(h, self.conv1, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
|
||||
h = self.swish(h)
|
||||
h = self.conv1(h)
|
||||
|
||||
if temb is not None:
|
||||
h = h + self.temb_proj(self.swish(temb))[:,:,None,None]
|
||||
|
||||
h = self.norm2(h)
|
||||
h = self.swish(h)
|
||||
h = [ self.dropout(h) ]
|
||||
h = conv_carry_causal_3d(h, self.conv2, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
|
||||
h = self.dropout(h)
|
||||
h = self.conv2(h)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
x = conv_carry_causal_3d([x], self.conv_shortcut, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
|
||||
x = self.conv_shortcut(x)
|
||||
else:
|
||||
x = self.nin_shortcut(x)
|
||||
|
||||
@@ -314,7 +279,6 @@ def pytorch_attention(q, k, v):
|
||||
orig_shape = q.shape
|
||||
B = orig_shape[0]
|
||||
C = orig_shape[1]
|
||||
oom_fallback = False
|
||||
q, k, v = map(
|
||||
lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
|
||||
(q, k, v),
|
||||
@@ -325,8 +289,6 @@ def pytorch_attention(q, k, v):
|
||||
out = out.transpose(2, 3).reshape(orig_shape)
|
||||
except model_management.OOM_EXCEPTION:
|
||||
logging.warning("scaled_dot_product_attention OOMed: switched to slice attention")
|
||||
oom_fallback = True
|
||||
if oom_fallback:
|
||||
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(orig_shape)
|
||||
return out
|
||||
|
||||
@@ -394,8 +356,7 @@ class Model(nn.Module):
|
||||
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
||||
resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
|
||||
super().__init__()
|
||||
if use_linear_attn:
|
||||
attn_type = "linear"
|
||||
if use_linear_attn: attn_type = "linear"
|
||||
self.ch = ch
|
||||
self.temb_ch = self.ch*4
|
||||
self.num_resolutions = len(ch_mult)
|
||||
@@ -549,22 +510,16 @@ class Encoder(nn.Module):
|
||||
conv3d=False, time_compress=None,
|
||||
**ignore_kwargs):
|
||||
super().__init__()
|
||||
if use_linear_attn:
|
||||
attn_type = "linear"
|
||||
if use_linear_attn: attn_type = "linear"
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.carried = False
|
||||
|
||||
if conv3d:
|
||||
if not attn_resolutions:
|
||||
conv_op = CarriedConv3d
|
||||
self.carried = True
|
||||
else:
|
||||
conv_op = VideoConv3d
|
||||
conv_op = VideoConv3d
|
||||
mid_attn_conv_op = ops.Conv3d
|
||||
else:
|
||||
conv_op = ops.Conv2d
|
||||
@@ -577,7 +532,6 @@ class Encoder(nn.Module):
|
||||
stride=1,
|
||||
padding=1)
|
||||
|
||||
self.time_compress = 1
|
||||
curr_res = resolution
|
||||
in_ch_mult = (1,)+tuple(ch_mult)
|
||||
self.in_ch_mult = in_ch_mult
|
||||
@@ -604,15 +558,10 @@ class Encoder(nn.Module):
|
||||
if time_compress is not None:
|
||||
if (self.num_resolutions - 1 - i_level) > math.log2(time_compress):
|
||||
stride = (1, 2, 2)
|
||||
else:
|
||||
self.time_compress *= 2
|
||||
down.downsample = Downsample(block_in, resamp_with_conv, stride=stride, conv_op=conv_op)
|
||||
curr_res = curr_res // 2
|
||||
self.down.append(down)
|
||||
|
||||
if time_compress is not None:
|
||||
self.time_compress = time_compress
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(in_channels=block_in,
|
||||
@@ -638,42 +587,15 @@ class Encoder(nn.Module):
|
||||
def forward(self, x):
|
||||
# timestep embedding
|
||||
temb = None
|
||||
|
||||
if self.carried:
|
||||
xl = [x[:, :, :1, :, :]]
|
||||
if x.shape[2] > self.time_compress:
|
||||
tc = self.time_compress
|
||||
xl += torch.split(x[:, :, 1: 1 + ((x.shape[2] - 1) // tc) * tc, :, :], tc * 2, dim = 2)
|
||||
x = xl
|
||||
else:
|
||||
x = [x]
|
||||
out = []
|
||||
|
||||
conv_carry_in = None
|
||||
|
||||
for i, x1 in enumerate(x):
|
||||
conv_carry_out = []
|
||||
if i == len(x) - 1:
|
||||
conv_carry_out = None
|
||||
|
||||
# downsampling
|
||||
x1 = [ x1 ]
|
||||
h1 = conv_carry_causal_3d(x1, self.conv_in, conv_carry_in, conv_carry_out)
|
||||
|
||||
for i_level in range(self.num_resolutions):
|
||||
for i_block in range(self.num_res_blocks):
|
||||
h1 = self.down[i_level].block[i_block](h1, temb, conv_carry_in, conv_carry_out)
|
||||
if len(self.down[i_level].attn) > 0:
|
||||
assert i == 0 #carried should not happen if attn exists
|
||||
h1 = self.down[i_level].attn[i_block](h1)
|
||||
if i_level != self.num_resolutions-1:
|
||||
h1 = self.down[i_level].downsample(h1, conv_carry_in, conv_carry_out)
|
||||
|
||||
out.append(h1)
|
||||
conv_carry_in = conv_carry_out
|
||||
|
||||
h = torch_cat_if_needed(out, dim=2)
|
||||
del out
|
||||
# downsampling
|
||||
h = self.conv_in(x)
|
||||
for i_level in range(self.num_resolutions):
|
||||
for i_block in range(self.num_res_blocks):
|
||||
h = self.down[i_level].block[i_block](h, temb)
|
||||
if len(self.down[i_level].attn) > 0:
|
||||
h = self.down[i_level].attn[i_block](h)
|
||||
if i_level != self.num_resolutions-1:
|
||||
h = self.down[i_level].downsample(h)
|
||||
|
||||
# middle
|
||||
h = self.mid.block_1(h, temb)
|
||||
@@ -682,15 +604,15 @@ class Encoder(nn.Module):
|
||||
|
||||
# end
|
||||
h = self.norm_out(h)
|
||||
h = [ nonlinearity(h) ]
|
||||
h = conv_carry_causal_3d(h, self.conv_out)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
||||
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
||||
resolution, z_channels, tanh_out=False, use_linear_attn=False,
|
||||
resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
|
||||
conv_out_op=ops.Conv2d,
|
||||
resnet_op=ResnetBlock,
|
||||
attn_op=AttnBlock,
|
||||
@@ -704,18 +626,12 @@ class Decoder(nn.Module):
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.give_pre_end = give_pre_end
|
||||
self.tanh_out = tanh_out
|
||||
self.carried = False
|
||||
|
||||
if conv3d:
|
||||
if not attn_resolutions and resnet_op == ResnetBlock:
|
||||
conv_op = CarriedConv3d
|
||||
conv_out_op = CarriedConv3d
|
||||
self.carried = True
|
||||
else:
|
||||
conv_op = VideoConv3d
|
||||
conv_out_op = VideoConv3d
|
||||
|
||||
conv_op = VideoConv3d
|
||||
conv_out_op = VideoConv3d
|
||||
mid_attn_conv_op = ops.Conv3d
|
||||
else:
|
||||
conv_op = ops.Conv2d
|
||||
@@ -790,43 +706,29 @@ class Decoder(nn.Module):
|
||||
temb = None
|
||||
|
||||
# z to block_in
|
||||
h = conv_carry_causal_3d([z], self.conv_in)
|
||||
h = self.conv_in(z)
|
||||
|
||||
# middle
|
||||
h = self.mid.block_1(h, temb, **kwargs)
|
||||
h = self.mid.attn_1(h, **kwargs)
|
||||
h = self.mid.block_2(h, temb, **kwargs)
|
||||
|
||||
if self.carried:
|
||||
h = torch.split(h, 2, dim=2)
|
||||
else:
|
||||
h = [ h ]
|
||||
out = []
|
||||
|
||||
conv_carry_in = None
|
||||
|
||||
# upsampling
|
||||
for i, h1 in enumerate(h):
|
||||
conv_carry_out = []
|
||||
if i == len(h) - 1:
|
||||
conv_carry_out = None
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks+1):
|
||||
h1 = self.up[i_level].block[i_block](h1, temb, conv_carry_in, conv_carry_out, **kwargs)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
assert i == 0 #carried should not happen if attn exists
|
||||
h1 = self.up[i_level].attn[i_block](h1, **kwargs)
|
||||
if i_level != 0:
|
||||
h1 = self.up[i_level].upsample(h1, conv_carry_in, conv_carry_out)
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks+1):
|
||||
h = self.up[i_level].block[i_block](h, temb, **kwargs)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
h = self.up[i_level].attn[i_block](h, **kwargs)
|
||||
if i_level != 0:
|
||||
h = self.up[i_level].upsample(h)
|
||||
|
||||
h1 = self.norm_out(h1)
|
||||
h1 = [ nonlinearity(h1) ]
|
||||
h1 = conv_carry_causal_3d(h1, self.conv_out, conv_carry_in, conv_carry_out)
|
||||
if self.tanh_out:
|
||||
h1 = torch.tanh(h1)
|
||||
out.append(h1)
|
||||
conv_carry_in = conv_carry_out
|
||||
# end
|
||||
if self.give_pre_end:
|
||||
return h
|
||||
|
||||
out = torch_cat_if_needed(out, dim=2)
|
||||
|
||||
return out
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h, **kwargs)
|
||||
if self.tanh_out:
|
||||
h = torch.tanh(h)
|
||||
return h
|
||||
|
||||
@@ -45,7 +45,7 @@ class LitEma(nn.Module):
|
||||
shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
|
||||
shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
|
||||
else:
|
||||
assert key not in self.m_name2s_name
|
||||
assert not key in self.m_name2s_name
|
||||
|
||||
def copy_to(self, model):
|
||||
m_param = dict(model.named_parameters())
|
||||
@@ -54,7 +54,7 @@ class LitEma(nn.Module):
|
||||
if m_param[key].requires_grad:
|
||||
m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
|
||||
else:
|
||||
assert key not in self.m_name2s_name
|
||||
assert not key in self.m_name2s_name
|
||||
|
||||
def store(self, parameters):
|
||||
"""
|
||||
|
||||
@@ -44,7 +44,7 @@ class QwenImageControlNetModel(QwenImageTransformer2DModel):
|
||||
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
|
||||
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
|
||||
image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
|
||||
del ids, txt_ids, img_ids
|
||||
|
||||
hidden_states = self.img_in(hidden_states) + self.controlnet_x_embedder(hint)
|
||||
|
||||
@@ -10,7 +10,6 @@ from comfy.ldm.modules.attention import optimized_attention_masked
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
import comfy.ldm.common_dit
|
||||
import comfy.patcher_extension
|
||||
from comfy.ldm.flux.math import apply_rope1
|
||||
|
||||
class GELU(nn.Module):
|
||||
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None):
|
||||
@@ -61,7 +60,7 @@ def apply_rotary_emb(x, freqs_cis):
|
||||
|
||||
|
||||
class QwenTimestepProjEmbeddings(nn.Module):
|
||||
def __init__(self, embedding_dim, pooled_projection_dim, use_additional_t_cond=False, dtype=None, device=None, operations=None):
|
||||
def __init__(self, embedding_dim, pooled_projection_dim, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000)
|
||||
self.timestep_embedder = TimestepEmbedding(
|
||||
@@ -72,19 +71,9 @@ class QwenTimestepProjEmbeddings(nn.Module):
|
||||
operations=operations
|
||||
)
|
||||
|
||||
self.use_additional_t_cond = use_additional_t_cond
|
||||
if self.use_additional_t_cond:
|
||||
self.addition_t_embedding = operations.Embedding(2, embedding_dim, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, timestep, hidden_states, addition_t_cond=None):
|
||||
def forward(self, timestep, hidden_states):
|
||||
timesteps_proj = self.time_proj(timestep)
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype))
|
||||
|
||||
if self.use_additional_t_cond:
|
||||
if addition_t_cond is None:
|
||||
addition_t_cond = torch.zeros((timesteps_emb.shape[0]), device=timesteps_emb.device, dtype=torch.long)
|
||||
timesteps_emb += self.addition_t_embedding(addition_t_cond, out_dtype=timesteps_emb.dtype)
|
||||
|
||||
return timesteps_emb
|
||||
|
||||
|
||||
@@ -145,34 +134,33 @@ class Attention(nn.Module):
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
transformer_options={},
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
batch_size = hidden_states.shape[0]
|
||||
seq_img = hidden_states.shape[1]
|
||||
seq_txt = encoder_hidden_states.shape[1]
|
||||
|
||||
# Project and reshape to BHND format (batch, heads, seq, dim)
|
||||
img_query = self.to_q(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
|
||||
img_key = self.to_k(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
|
||||
img_value = self.to_v(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2)
|
||||
img_query = self.to_q(hidden_states).unflatten(-1, (self.heads, -1))
|
||||
img_key = self.to_k(hidden_states).unflatten(-1, (self.heads, -1))
|
||||
img_value = self.to_v(hidden_states).unflatten(-1, (self.heads, -1))
|
||||
|
||||
txt_query = self.add_q_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous()
|
||||
txt_key = self.add_k_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous()
|
||||
txt_value = self.add_v_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2)
|
||||
txt_query = self.add_q_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
|
||||
txt_key = self.add_k_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
|
||||
txt_value = self.add_v_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
|
||||
|
||||
img_query = self.norm_q(img_query)
|
||||
img_key = self.norm_k(img_key)
|
||||
txt_query = self.norm_added_q(txt_query)
|
||||
txt_key = self.norm_added_k(txt_key)
|
||||
|
||||
joint_query = torch.cat([txt_query, img_query], dim=2)
|
||||
joint_key = torch.cat([txt_key, img_key], dim=2)
|
||||
joint_value = torch.cat([txt_value, img_value], dim=2)
|
||||
joint_query = torch.cat([txt_query, img_query], dim=1)
|
||||
joint_key = torch.cat([txt_key, img_key], dim=1)
|
||||
joint_value = torch.cat([txt_value, img_value], dim=1)
|
||||
|
||||
joint_query = apply_rope1(joint_query, image_rotary_emb)
|
||||
joint_key = apply_rope1(joint_key, image_rotary_emb)
|
||||
joint_query = apply_rotary_emb(joint_query, image_rotary_emb)
|
||||
joint_key = apply_rotary_emb(joint_key, image_rotary_emb)
|
||||
|
||||
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads,
|
||||
attention_mask, transformer_options=transformer_options,
|
||||
skip_reshape=True)
|
||||
joint_query = joint_query.flatten(start_dim=2)
|
||||
joint_key = joint_key.flatten(start_dim=2)
|
||||
joint_value = joint_value.flatten(start_dim=2)
|
||||
|
||||
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attention_mask, transformer_options=transformer_options)
|
||||
|
||||
txt_attn_output = joint_hidden_states[:, :seq_txt, :]
|
||||
img_attn_output = joint_hidden_states[:, seq_txt:, :]
|
||||
@@ -228,24 +216,9 @@ class QwenImageTransformerBlock(nn.Module):
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
def _apply_gate(self, x, y, gate, timestep_zero_index=None):
|
||||
if timestep_zero_index is not None:
|
||||
return y + torch.cat((x[:, :timestep_zero_index] * gate[0], x[:, timestep_zero_index:] * gate[1]), dim=1)
|
||||
else:
|
||||
return torch.addcmul(y, gate, x)
|
||||
|
||||
def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor, timestep_zero_index=None) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
shift, scale, gate = torch.chunk(mod_params, 3, dim=-1)
|
||||
if timestep_zero_index is not None:
|
||||
actual_batch = shift.size(0) // 2
|
||||
shift, shift_0 = shift[:actual_batch], shift[actual_batch:]
|
||||
scale, scale_0 = scale[:actual_batch], scale[actual_batch:]
|
||||
gate, gate_0 = gate[:actual_batch], gate[actual_batch:]
|
||||
reg = torch.addcmul(shift.unsqueeze(1), x[:, :timestep_zero_index], 1 + scale.unsqueeze(1))
|
||||
zero = torch.addcmul(shift_0.unsqueeze(1), x[:, timestep_zero_index:], 1 + scale_0.unsqueeze(1))
|
||||
return torch.cat((reg, zero), dim=1), (gate.unsqueeze(1), gate_0.unsqueeze(1))
|
||||
else:
|
||||
return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1)), gate.unsqueeze(1)
|
||||
return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1)), gate.unsqueeze(1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -254,22 +227,17 @@ class QwenImageTransformerBlock(nn.Module):
|
||||
encoder_hidden_states_mask: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
timestep_zero_index=None,
|
||||
transformer_options={},
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
img_mod_params = self.img_mod(temb)
|
||||
|
||||
if timestep_zero_index is not None:
|
||||
temb = temb.chunk(2, dim=0)[0]
|
||||
|
||||
txt_mod_params = self.txt_mod(temb)
|
||||
img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1)
|
||||
txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1)
|
||||
|
||||
img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1, timestep_zero_index)
|
||||
del img_mod1
|
||||
txt_modulated, txt_gate1 = self._modulate(self.txt_norm1(encoder_hidden_states), txt_mod1)
|
||||
del txt_mod1
|
||||
img_normed = self.img_norm1(hidden_states)
|
||||
img_modulated, img_gate1 = self._modulate(img_normed, img_mod1)
|
||||
txt_normed = self.txt_norm1(encoder_hidden_states)
|
||||
txt_modulated, txt_gate1 = self._modulate(txt_normed, txt_mod1)
|
||||
|
||||
img_attn_output, txt_attn_output = self.attn(
|
||||
hidden_states=img_modulated,
|
||||
@@ -278,20 +246,16 @@ class QwenImageTransformerBlock(nn.Module):
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
del img_modulated
|
||||
del txt_modulated
|
||||
|
||||
hidden_states = self._apply_gate(img_attn_output, hidden_states, img_gate1, timestep_zero_index)
|
||||
hidden_states = hidden_states + img_gate1 * img_attn_output
|
||||
encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
|
||||
del img_attn_output
|
||||
del txt_attn_output
|
||||
del img_gate1
|
||||
del txt_gate1
|
||||
|
||||
img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2, timestep_zero_index)
|
||||
hidden_states = self._apply_gate(self.img_mlp(img_modulated2), hidden_states, img_gate2, timestep_zero_index)
|
||||
img_normed2 = self.img_norm2(hidden_states)
|
||||
img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
|
||||
hidden_states = torch.addcmul(hidden_states, img_gate2, self.img_mlp(img_modulated2))
|
||||
|
||||
txt_modulated2, txt_gate2 = self._modulate(self.txt_norm2(encoder_hidden_states), txt_mod2)
|
||||
txt_normed2 = self.txt_norm2(encoder_hidden_states)
|
||||
txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2)
|
||||
encoder_hidden_states = torch.addcmul(encoder_hidden_states, txt_gate2, self.txt_mlp(txt_modulated2))
|
||||
|
||||
return encoder_hidden_states, hidden_states
|
||||
@@ -330,11 +294,10 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
num_attention_heads: int = 24,
|
||||
joint_attention_dim: int = 3584,
|
||||
pooled_projection_dim: int = 768,
|
||||
guidance_embeds: bool = False,
|
||||
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
|
||||
default_ref_method="index",
|
||||
image_model=None,
|
||||
final_layer=True,
|
||||
use_additional_t_cond=False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
@@ -345,14 +308,12 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels or in_channels
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
self.default_ref_method = default_ref_method
|
||||
|
||||
self.pe_embedder = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope))
|
||||
|
||||
self.time_text_embed = QwenTimestepProjEmbeddings(
|
||||
embedding_dim=self.inner_dim,
|
||||
pooled_projection_dim=pooled_projection_dim,
|
||||
use_additional_t_cond=use_additional_t_cond,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
@@ -374,9 +335,6 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
|
||||
if self.default_ref_method == "index_timestep_zero":
|
||||
self.register_buffer("__index_timestep_zero__", torch.tensor([]))
|
||||
|
||||
if final_layer:
|
||||
self.norm_out = LastLayer(self.inner_dim, self.inner_dim, dtype=dtype, device=device, operations=operations)
|
||||
self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device)
|
||||
@@ -386,33 +344,27 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
patch_size = self.patch_size
|
||||
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (1, self.patch_size, self.patch_size))
|
||||
orig_shape = hidden_states.shape
|
||||
hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-3], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 5, 1, 4, 6)
|
||||
hidden_states = hidden_states.reshape(orig_shape[0], orig_shape[-3] * (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
|
||||
t_len = t
|
||||
hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
|
||||
hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5)
|
||||
hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
|
||||
h_len = ((h + (patch_size // 2)) // patch_size)
|
||||
w_len = ((w + (patch_size // 2)) // patch_size)
|
||||
|
||||
h_offset = ((h_offset + (patch_size // 2)) // patch_size)
|
||||
w_offset = ((w_offset + (patch_size // 2)) // patch_size)
|
||||
|
||||
img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device)
|
||||
img_ids = torch.zeros((h_len, w_len, 3), device=x.device)
|
||||
img_ids[:, :, 0] = img_ids[:, :, 1] + index
|
||||
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) - (h_len // 2)
|
||||
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) - (w_len // 2)
|
||||
return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape
|
||||
|
||||
if t_len > 1:
|
||||
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).unsqueeze(1).unsqueeze(1)
|
||||
else:
|
||||
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + index
|
||||
|
||||
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1).unsqueeze(0) - (h_len // 2)
|
||||
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0).unsqueeze(0) - (w_len // 2)
|
||||
return hidden_states, repeat(img_ids, "t h w c -> b (t h w) c", b=bs), orig_shape
|
||||
|
||||
def forward(self, x, timestep, context, attention_mask=None, ref_latents=None, additional_t_cond=None, transformer_options={}, **kwargs):
|
||||
def forward(self, x, timestep, context, attention_mask=None, guidance=None, ref_latents=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timestep, context, attention_mask, ref_latents, additional_t_cond, transformer_options, **kwargs)
|
||||
).execute(x, timestep, context, attention_mask, guidance, ref_latents, transformer_options, **kwargs)
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
@@ -420,8 +372,8 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
timesteps,
|
||||
context,
|
||||
attention_mask=None,
|
||||
guidance: torch.Tensor = None,
|
||||
ref_latents=None,
|
||||
additional_t_cond=None,
|
||||
transformer_options={},
|
||||
control=None,
|
||||
**kwargs
|
||||
@@ -433,24 +385,16 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
hidden_states, img_ids, orig_shape = self.process_img(x)
|
||||
num_embeds = hidden_states.shape[1]
|
||||
|
||||
timestep_zero_index = None
|
||||
if ref_latents is not None:
|
||||
h = 0
|
||||
w = 0
|
||||
index = 0
|
||||
ref_method = kwargs.get("ref_latents_method", self.default_ref_method)
|
||||
index_ref_method = (ref_method == "index") or (ref_method == "index_timestep_zero")
|
||||
negative_ref_method = ref_method == "negative_index"
|
||||
timestep_zero = ref_method == "index_timestep_zero"
|
||||
index_ref_method = kwargs.get("ref_latents_method", "index") == "index"
|
||||
for ref in ref_latents:
|
||||
if index_ref_method:
|
||||
index += 1
|
||||
h_offset = 0
|
||||
w_offset = 0
|
||||
elif negative_ref_method:
|
||||
index -= 1
|
||||
h_offset = 0
|
||||
w_offset = 0
|
||||
else:
|
||||
index = 1
|
||||
h_offset = 0
|
||||
@@ -465,35 +409,35 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
|
||||
hidden_states = torch.cat([hidden_states, kontext], dim=1)
|
||||
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
|
||||
if timestep_zero:
|
||||
if index > 0:
|
||||
timestep = torch.cat([timestep, timestep * 0], dim=0)
|
||||
timestep_zero_index = num_embeds
|
||||
|
||||
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
|
||||
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
|
||||
image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
|
||||
del ids, txt_ids, img_ids
|
||||
|
||||
hidden_states = self.img_in(hidden_states)
|
||||
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
|
||||
encoder_hidden_states = self.txt_in(encoder_hidden_states)
|
||||
|
||||
temb = self.time_text_embed(timestep, hidden_states, additional_t_cond)
|
||||
if guidance is not None:
|
||||
guidance = guidance * 1000
|
||||
|
||||
temb = (
|
||||
self.time_text_embed(timestep, hidden_states)
|
||||
if guidance is None
|
||||
else self.time_text_embed(timestep, guidance, hidden_states)
|
||||
)
|
||||
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
patches = transformer_options.get("patches", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
|
||||
transformer_options["total_blocks"] = len(self.transformer_blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], timestep_zero_index=timestep_zero_index, transformer_options=args["transformer_options"])
|
||||
out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], transformer_options=args["transformer_options"])
|
||||
return out
|
||||
out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||
hidden_states = out["img"]
|
||||
@@ -505,7 +449,6 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
timestep_zero_index=timestep_zero_index,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
@@ -522,12 +465,9 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
if add is not None:
|
||||
hidden_states[:, :add.shape[1]] += add
|
||||
|
||||
if timestep_zero_index is not None:
|
||||
temb = temb.chunk(2, dim=0)[0]
|
||||
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
hidden_states = hidden_states[:, :num_embeds].view(orig_shape[0], orig_shape[-3], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
|
||||
hidden_states = hidden_states.permute(0, 4, 1, 2, 5, 3, 6)
|
||||
hidden_states = hidden_states[:, :num_embeds].view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
|
||||
hidden_states = hidden_states.permute(0, 3, 1, 4, 2, 5)
|
||||
return hidden_states.reshape(orig_shape)[:, :, :, :x.shape[-2], :x.shape[-1]]
|
||||
|
||||
@@ -71,7 +71,7 @@ def count_params(model, verbose=False):
|
||||
|
||||
|
||||
def instantiate_from_config(config):
|
||||
if "target" not in config:
|
||||
if not "target" in config:
|
||||
if config == '__is_first_stage__':
|
||||
return None
|
||||
elif config == "__is_unconditional__":
|
||||
|
||||
@@ -232,13 +232,11 @@ class WanAttentionBlock(nn.Module):
|
||||
# assert e[0].dtype == torch.float32
|
||||
|
||||
# self-attention
|
||||
x = x.contiguous() # otherwise implicit in LayerNorm
|
||||
y = self.self_attn(
|
||||
torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
|
||||
freqs, transformer_options=transformer_options)
|
||||
|
||||
x = torch.addcmul(x, y, repeat_e(e[2], x))
|
||||
del y
|
||||
|
||||
# cross-attention & ffn
|
||||
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)
|
||||
@@ -568,10 +566,7 @@ class WanModel(torch.nn.Module):
|
||||
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
transformer_options["total_blocks"] = len(self.blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
@@ -592,7 +587,7 @@ class WanModel(torch.nn.Module):
|
||||
x = self.unpatchify(x, grid_sizes)
|
||||
return x
|
||||
|
||||
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}):
|
||||
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None):
|
||||
patch_size = self.patch_size
|
||||
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
|
||||
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
|
||||
@@ -605,22 +600,10 @@ class WanModel(torch.nn.Module):
|
||||
if steps_w is None:
|
||||
steps_w = w_len
|
||||
|
||||
h_start = 0
|
||||
w_start = 0
|
||||
rope_options = transformer_options.get("rope_options", None)
|
||||
if rope_options is not None:
|
||||
t_len = (t_len - 1.0) * rope_options.get("scale_t", 1.0) + 1.0
|
||||
h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
|
||||
w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
|
||||
|
||||
t_start += rope_options.get("shift_t", 0.0)
|
||||
h_start += rope_options.get("shift_y", 0.0)
|
||||
w_start += rope_options.get("shift_x", 0.0)
|
||||
|
||||
img_ids = torch.zeros((steps_t, steps_h, steps_w, 3), device=device, dtype=dtype)
|
||||
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (t_len - 1), steps=steps_t, device=device, dtype=dtype).reshape(-1, 1, 1)
|
||||
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_start, h_start + (h_len - 1), steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1)
|
||||
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_start, w_start + (w_len - 1), steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1)
|
||||
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1)
|
||||
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1)
|
||||
img_ids = img_ids.reshape(1, -1, img_ids.shape[-1])
|
||||
|
||||
freqs = self.rope_embedder(img_ids).movedim(1, 2)
|
||||
@@ -646,7 +629,7 @@ class WanModel(torch.nn.Module):
|
||||
if self.ref_conv is not None and "reference_latent" in kwargs:
|
||||
t_len += 1
|
||||
|
||||
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options)
|
||||
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype)
|
||||
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, **kwargs)[:, :, :t, :h, :w]
|
||||
|
||||
def unpatchify(self, x, grid_sizes):
|
||||
@@ -766,10 +749,7 @@ class VaceWanModel(WanModel):
|
||||
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
transformer_options["total_blocks"] = len(self.blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
@@ -868,10 +848,7 @@ class CameraWanModel(WanModel):
|
||||
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
transformer_options["total_blocks"] = len(self.blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
@@ -925,7 +902,7 @@ class MotionEncoder_tc(nn.Module):
|
||||
def __init__(self,
|
||||
in_dim: int,
|
||||
hidden_dim: int,
|
||||
num_heads: int,
|
||||
num_heads=int,
|
||||
need_global=True,
|
||||
dtype=None,
|
||||
device=None,
|
||||
@@ -1335,19 +1312,16 @@ class WanModel_S2V(WanModel):
|
||||
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
transformer_options["total_blocks"] = len(self.blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], transformer_options=args["transformer_options"])
|
||||
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"])
|
||||
return out
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
|
||||
x = out["img"]
|
||||
else:
|
||||
x = block(x, e=e0, freqs=freqs, context=context, transformer_options=transformer_options)
|
||||
x = block(x, e=e0, freqs=freqs, context=context)
|
||||
if audio_emb is not None:
|
||||
x = self.audio_injector(x, i, audio_emb, audio_emb_global, seq_len)
|
||||
# head
|
||||
@@ -1381,7 +1355,7 @@ class WanT2VCrossAttentionGather(WanSelfAttention):
|
||||
|
||||
x = optimized_attention(q, k, v, heads=self.num_heads, skip_reshape=True, skip_output_reshape=True, transformer_options=transformer_options)
|
||||
|
||||
x = x.transpose(1, 2).reshape(b, -1, n * d)
|
||||
x = x.transpose(1, 2).view(b, -1, n, d).flatten(2)
|
||||
x = self.o(x)
|
||||
return x
|
||||
|
||||
@@ -1586,10 +1560,7 @@ class HumoWanModel(WanModel):
|
||||
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
transformer_options["total_blocks"] = len(self.blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
|
||||
@@ -523,10 +523,7 @@ class AnimateWanModel(WanModel):
|
||||
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
transformer_options["total_blocks"] = len(self.blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
|
||||
@@ -227,7 +227,6 @@ class Encoder3d(nn.Module):
|
||||
def __init__(self,
|
||||
dim=128,
|
||||
z_dim=4,
|
||||
input_channels=3,
|
||||
dim_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
@@ -246,7 +245,7 @@ class Encoder3d(nn.Module):
|
||||
scale = 1.0
|
||||
|
||||
# init block
|
||||
self.conv1 = CausalConv3d(input_channels, dims[0], 3, padding=1)
|
||||
self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
|
||||
|
||||
# downsample blocks
|
||||
downsamples = []
|
||||
@@ -332,7 +331,6 @@ class Decoder3d(nn.Module):
|
||||
def __init__(self,
|
||||
dim=128,
|
||||
z_dim=4,
|
||||
output_channels=3,
|
||||
dim_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
@@ -380,7 +378,7 @@ class Decoder3d(nn.Module):
|
||||
# output blocks
|
||||
self.head = nn.Sequential(
|
||||
RMS_norm(out_dim, images=False), nn.SiLU(),
|
||||
CausalConv3d(out_dim, output_channels, 3, padding=1))
|
||||
CausalConv3d(out_dim, 3, 3, padding=1))
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
## conv1
|
||||
@@ -451,7 +449,6 @@ class WanVAE(nn.Module):
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
temperal_downsample=[True, True, False],
|
||||
image_channels=3,
|
||||
dropout=0.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
@@ -463,54 +460,63 @@ class WanVAE(nn.Module):
|
||||
self.temperal_upsample = temperal_downsample[::-1]
|
||||
|
||||
# modules
|
||||
self.encoder = Encoder3d(dim, z_dim * 2, image_channels, dim_mult, num_res_blocks,
|
||||
self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
|
||||
attn_scales, self.temperal_downsample, dropout)
|
||||
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
|
||||
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
|
||||
self.decoder = Decoder3d(dim, z_dim, image_channels, dim_mult, num_res_blocks,
|
||||
self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
|
||||
attn_scales, self.temperal_upsample, dropout)
|
||||
|
||||
def encode(self, x):
|
||||
conv_idx = [0]
|
||||
feat_map = [None] * count_conv3d(self.decoder)
|
||||
self.clear_cache()
|
||||
## cache
|
||||
t = x.shape[2]
|
||||
iter_ = 1 + (t - 1) // 4
|
||||
## 对encode输入的x,按时间拆分为1、4、4、4....
|
||||
for i in range(iter_):
|
||||
conv_idx = [0]
|
||||
self._enc_conv_idx = [0]
|
||||
if i == 0:
|
||||
out = self.encoder(
|
||||
x[:, :, :1, :, :],
|
||||
feat_cache=feat_map,
|
||||
feat_idx=conv_idx)
|
||||
feat_cache=self._enc_feat_map,
|
||||
feat_idx=self._enc_conv_idx)
|
||||
else:
|
||||
out_ = self.encoder(
|
||||
x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
|
||||
feat_cache=feat_map,
|
||||
feat_idx=conv_idx)
|
||||
feat_cache=self._enc_feat_map,
|
||||
feat_idx=self._enc_conv_idx)
|
||||
out = torch.cat([out, out_], 2)
|
||||
mu, log_var = self.conv1(out).chunk(2, dim=1)
|
||||
self.clear_cache()
|
||||
return mu
|
||||
|
||||
def decode(self, z):
|
||||
conv_idx = [0]
|
||||
feat_map = [None] * count_conv3d(self.decoder)
|
||||
self.clear_cache()
|
||||
# z: [b,c,t,h,w]
|
||||
|
||||
iter_ = z.shape[2]
|
||||
x = self.conv2(z)
|
||||
for i in range(iter_):
|
||||
conv_idx = [0]
|
||||
self._conv_idx = [0]
|
||||
if i == 0:
|
||||
out = self.decoder(
|
||||
x[:, :, i:i + 1, :, :],
|
||||
feat_cache=feat_map,
|
||||
feat_idx=conv_idx)
|
||||
feat_cache=self._feat_map,
|
||||
feat_idx=self._conv_idx)
|
||||
else:
|
||||
out_ = self.decoder(
|
||||
x[:, :, i:i + 1, :, :],
|
||||
feat_cache=feat_map,
|
||||
feat_idx=conv_idx)
|
||||
feat_cache=self._feat_map,
|
||||
feat_idx=self._conv_idx)
|
||||
out = torch.cat([out, out_], 2)
|
||||
self.clear_cache()
|
||||
return out
|
||||
|
||||
def clear_cache(self):
|
||||
self._conv_num = count_conv3d(self.decoder)
|
||||
self._conv_idx = [0]
|
||||
self._feat_map = [None] * self._conv_num
|
||||
#cache encode
|
||||
self._enc_conv_num = count_conv3d(self.encoder)
|
||||
self._enc_conv_idx = [0]
|
||||
self._enc_feat_map = [None] * self._enc_conv_num
|
||||
|
||||
@@ -657,51 +657,51 @@ class WanVAE(nn.Module):
|
||||
)
|
||||
|
||||
def encode(self, x):
|
||||
conv_idx = [0]
|
||||
feat_map = [None] * count_conv3d(self.encoder)
|
||||
self.clear_cache()
|
||||
x = patchify(x, patch_size=2)
|
||||
t = x.shape[2]
|
||||
iter_ = 1 + (t - 1) // 4
|
||||
for i in range(iter_):
|
||||
conv_idx = [0]
|
||||
self._enc_conv_idx = [0]
|
||||
if i == 0:
|
||||
out = self.encoder(
|
||||
x[:, :, :1, :, :],
|
||||
feat_cache=feat_map,
|
||||
feat_idx=conv_idx,
|
||||
feat_cache=self._enc_feat_map,
|
||||
feat_idx=self._enc_conv_idx,
|
||||
)
|
||||
else:
|
||||
out_ = self.encoder(
|
||||
x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
|
||||
feat_cache=feat_map,
|
||||
feat_idx=conv_idx,
|
||||
feat_cache=self._enc_feat_map,
|
||||
feat_idx=self._enc_conv_idx,
|
||||
)
|
||||
out = torch.cat([out, out_], 2)
|
||||
mu, log_var = self.conv1(out).chunk(2, dim=1)
|
||||
self.clear_cache()
|
||||
return mu
|
||||
|
||||
def decode(self, z):
|
||||
conv_idx = [0]
|
||||
feat_map = [None] * count_conv3d(self.decoder)
|
||||
self.clear_cache()
|
||||
iter_ = z.shape[2]
|
||||
x = self.conv2(z)
|
||||
for i in range(iter_):
|
||||
conv_idx = [0]
|
||||
self._conv_idx = [0]
|
||||
if i == 0:
|
||||
out = self.decoder(
|
||||
x[:, :, i:i + 1, :, :],
|
||||
feat_cache=feat_map,
|
||||
feat_idx=conv_idx,
|
||||
feat_cache=self._feat_map,
|
||||
feat_idx=self._conv_idx,
|
||||
first_chunk=True,
|
||||
)
|
||||
else:
|
||||
out_ = self.decoder(
|
||||
x[:, :, i:i + 1, :, :],
|
||||
feat_cache=feat_map,
|
||||
feat_idx=conv_idx,
|
||||
feat_cache=self._feat_map,
|
||||
feat_idx=self._conv_idx,
|
||||
)
|
||||
out = torch.cat([out, out_], 2)
|
||||
out = unpatchify(out, patch_size=2)
|
||||
self.clear_cache()
|
||||
return out
|
||||
|
||||
def reparameterize(self, mu, log_var):
|
||||
@@ -715,3 +715,12 @@ class WanVAE(nn.Module):
|
||||
return mu
|
||||
std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
|
||||
return mu + std * torch.randn_like(std)
|
||||
|
||||
def clear_cache(self):
|
||||
self._conv_num = count_conv3d(self.decoder)
|
||||
self._conv_idx = [0]
|
||||
self._feat_map = [None] * self._conv_num
|
||||
# cache encode
|
||||
self._enc_conv_num = count_conv3d(self.encoder)
|
||||
self._enc_conv_idx = [0]
|
||||
self._enc_feat_map = [None] * self._enc_conv_num
|
||||
|
||||
@@ -313,23 +313,6 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
key_map["transformer.{}".format(key_lora)] = k
|
||||
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k #SimpleTuner lycoris format
|
||||
|
||||
if isinstance(model, comfy.model_base.Lumina2):
|
||||
diffusers_keys = comfy.utils.z_image_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
|
||||
for k in diffusers_keys:
|
||||
if k.endswith(".weight"):
|
||||
to = diffusers_keys[k]
|
||||
key_lora = k[:-len(".weight")]
|
||||
key_map["diffusion_model.{}".format(key_lora)] = to
|
||||
key_map["transformer.{}".format(key_lora)] = to
|
||||
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to
|
||||
|
||||
if isinstance(model, comfy.model_base.Kandinsky5):
|
||||
for k in sdk:
|
||||
if k.startswith("diffusion_model.") and k.endswith(".weight"):
|
||||
key_lora = k[len("diffusion_model."):-len(".weight")]
|
||||
key_map["{}".format(key_lora)] = k
|
||||
key_map["transformer.{}".format(key_lora)] = k
|
||||
|
||||
return key_map
|
||||
|
||||
|
||||
|
||||
@@ -20,7 +20,6 @@ import comfy.ldm.hunyuan3dv2_1
|
||||
import comfy.ldm.hunyuan3dv2_1.hunyuandit
|
||||
import torch
|
||||
import logging
|
||||
import comfy.ldm.lightricks.av_model
|
||||
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
|
||||
from comfy.ldm.cascade.stage_c import StageC
|
||||
from comfy.ldm.cascade.stage_b import StageB
|
||||
@@ -48,7 +47,6 @@ import comfy.ldm.chroma_radiance.model
|
||||
import comfy.ldm.ace.model
|
||||
import comfy.ldm.omnigen.omnigen2
|
||||
import comfy.ldm.qwen_image.model
|
||||
import comfy.ldm.kandinsky5.model
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
@@ -136,11 +134,10 @@ class BaseModel(torch.nn.Module):
|
||||
if not unet_config.get("disable_unet_model_creation", False):
|
||||
if model_config.custom_operations is None:
|
||||
fp8 = model_config.optimizations.get("fp8", False)
|
||||
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, model_config=model_config)
|
||||
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8)
|
||||
else:
|
||||
operations = model_config.custom_operations
|
||||
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
|
||||
self.diffusion_model.eval()
|
||||
if comfy.model_management.force_channels_last():
|
||||
self.diffusion_model.to(memory_format=torch.channels_last)
|
||||
logging.debug("using channels last mode for diffusion model")
|
||||
@@ -199,14 +196,8 @@ class BaseModel(torch.nn.Module):
|
||||
extra_conds[o] = extra
|
||||
|
||||
t = self.process_timestep(t, x=x, **extra_conds)
|
||||
if "latent_shapes" in extra_conds:
|
||||
xc = utils.unpack_latents(xc, extra_conds.pop("latent_shapes"))
|
||||
|
||||
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds)
|
||||
if len(model_output) > 1 and not torch.is_tensor(model_output):
|
||||
model_output, _ = utils.pack_latents(model_output)
|
||||
|
||||
return self.model_sampling.calculate_denoised(sigma, model_output.float(), x)
|
||||
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
|
||||
return self.model_sampling.calculate_denoised(sigma, model_output, x)
|
||||
|
||||
def process_timestep(self, timestep, **kwargs):
|
||||
return timestep
|
||||
@@ -331,6 +322,10 @@ class BaseModel(torch.nn.Module):
|
||||
extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict))
|
||||
|
||||
unet_state_dict = self.diffusion_model.state_dict()
|
||||
|
||||
if self.model_config.scaled_fp8 is not None:
|
||||
unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8)
|
||||
|
||||
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
|
||||
|
||||
if self.model_type == ModelType.V_PREDICTION:
|
||||
@@ -674,6 +669,7 @@ class Lotus(BaseModel):
|
||||
class StableCascade_C(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=StageC)
|
||||
self.diffusion_model.eval().requires_grad_(False)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = {}
|
||||
@@ -702,6 +698,7 @@ class StableCascade_C(BaseModel):
|
||||
class StableCascade_B(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=StageB)
|
||||
self.diffusion_model.eval().requires_grad_(False)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = {}
|
||||
@@ -888,13 +885,12 @@ class Flux(BaseModel):
|
||||
attention_mask = kwargs.get("attention_mask", None)
|
||||
if attention_mask is not None:
|
||||
shape = kwargs["noise"].shape
|
||||
mask_ref_size = kwargs.get("attention_mask_img_shape", None)
|
||||
if mask_ref_size is not None:
|
||||
# the model will pad to the patch size, and then divide
|
||||
# essentially dividing and rounding up
|
||||
(h_tok, w_tok) = (math.ceil(shape[2] / self.diffusion_model.patch_size), math.ceil(shape[3] / self.diffusion_model.patch_size))
|
||||
attention_mask = utils.upscale_dit_mask(attention_mask, mask_ref_size, (h_tok, w_tok))
|
||||
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
|
||||
mask_ref_size = kwargs["attention_mask_img_shape"]
|
||||
# the model will pad to the patch size, and then divide
|
||||
# essentially dividing and rounding up
|
||||
(h_tok, w_tok) = (math.ceil(shape[2] / self.diffusion_model.patch_size), math.ceil(shape[3] / self.diffusion_model.patch_size))
|
||||
attention_mask = utils.upscale_dit_mask(attention_mask, mask_ref_size, (h_tok, w_tok))
|
||||
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
|
||||
|
||||
guidance = kwargs.get("guidance", 3.5)
|
||||
if guidance is not None:
|
||||
@@ -916,19 +912,9 @@ class Flux(BaseModel):
|
||||
out = {}
|
||||
ref_latents = kwargs.get("reference_latents", None)
|
||||
if ref_latents is not None:
|
||||
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))])
|
||||
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
|
||||
return out
|
||||
|
||||
class Flux2(Flux):
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
target_text_len = 512
|
||||
if cross_attn.shape[1] < target_text_len:
|
||||
cross_attn = torch.nn.functional.pad(cross_attn, (0, 0, target_text_len - cross_attn.shape[1], 0))
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
return out
|
||||
|
||||
class GenmoMochi(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
@@ -947,7 +933,7 @@ class GenmoMochi(BaseModel):
|
||||
|
||||
class LTXV(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lightricks.model.LTXVModel)
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lightricks.model.LTXVModel) #TODO
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
@@ -978,60 +964,6 @@ class LTXV(BaseModel):
|
||||
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
|
||||
return latent_image
|
||||
|
||||
class LTXAV(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lightricks.av_model.LTXAVModel) #TODO
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
attention_mask = kwargs.get("attention_mask", None)
|
||||
if attention_mask is not None:
|
||||
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
|
||||
out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25))
|
||||
|
||||
denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
|
||||
|
||||
audio_denoise_mask = None
|
||||
if denoise_mask is not None and "latent_shapes" in kwargs:
|
||||
denoise_mask = utils.unpack_latents(denoise_mask, kwargs["latent_shapes"])
|
||||
if len(denoise_mask) > 1:
|
||||
audio_denoise_mask = denoise_mask[1]
|
||||
denoise_mask = denoise_mask[0]
|
||||
|
||||
if denoise_mask is not None:
|
||||
out["denoise_mask"] = comfy.conds.CONDRegular(denoise_mask)
|
||||
|
||||
if audio_denoise_mask is not None:
|
||||
out["audio_denoise_mask"] = comfy.conds.CONDRegular(audio_denoise_mask)
|
||||
|
||||
keyframe_idxs = kwargs.get("keyframe_idxs", None)
|
||||
if keyframe_idxs is not None:
|
||||
out['keyframe_idxs'] = comfy.conds.CONDRegular(keyframe_idxs)
|
||||
|
||||
latent_shapes = kwargs.get("latent_shapes", None)
|
||||
if latent_shapes is not None:
|
||||
out['latent_shapes'] = comfy.conds.CONDConstant(latent_shapes)
|
||||
|
||||
return out
|
||||
|
||||
def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs):
|
||||
v_timestep = timestep
|
||||
a_timestep = timestep
|
||||
|
||||
if denoise_mask is not None:
|
||||
v_timestep = self.diffusion_model.patchifier.patchify(((denoise_mask) * timestep.view([timestep.shape[0]] + [1] * (denoise_mask.ndim - 1)))[:, :1])[0]
|
||||
if audio_denoise_mask is not None:
|
||||
a_timestep = self.diffusion_model.a_patchifier.patchify(((audio_denoise_mask) * timestep.view([timestep.shape[0]] + [1] * (audio_denoise_mask.ndim - 1)))[:, :1, :, :1])[0]
|
||||
|
||||
return v_timestep, a_timestep
|
||||
|
||||
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
|
||||
return latent_image
|
||||
|
||||
class HunyuanVideo(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo)
|
||||
@@ -1158,17 +1090,9 @@ class Lumina2(BaseModel):
|
||||
if torch.numel(attention_mask) != attention_mask.sum():
|
||||
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
|
||||
out['num_tokens'] = comfy.conds.CONDConstant(max(1, torch.sum(attention_mask).item()))
|
||||
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
if 'num_tokens' not in out:
|
||||
out['num_tokens'] = comfy.conds.CONDConstant(cross_attn.shape[1])
|
||||
|
||||
clip_text_pooled = kwargs.get("pooled_output", None) # NewBie
|
||||
if clip_text_pooled is not None:
|
||||
out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled)
|
||||
|
||||
return out
|
||||
|
||||
class WAN21(BaseModel):
|
||||
@@ -1599,140 +1523,3 @@ class HunyuanImage21Refiner(HunyuanImage21):
|
||||
out = super().extra_conds(**kwargs)
|
||||
out['disable_time_r'] = comfy.conds.CONDConstant(True)
|
||||
return out
|
||||
|
||||
class HunyuanVideo15(HunyuanVideo):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device)
|
||||
|
||||
def concat_cond(self, **kwargs):
|
||||
noise = kwargs.get("noise", None)
|
||||
extra_channels = self.diffusion_model.img_in.proj.weight.shape[1] - noise.shape[1] - 1 #noise 32 img cond 32 + mask 1
|
||||
if extra_channels == 0:
|
||||
return None
|
||||
|
||||
image = kwargs.get("concat_latent_image", None)
|
||||
device = kwargs["device"]
|
||||
|
||||
if image is None:
|
||||
shape_image = list(noise.shape)
|
||||
shape_image[1] = extra_channels
|
||||
image = torch.zeros(shape_image, dtype=noise.dtype, layout=noise.layout, device=noise.device)
|
||||
else:
|
||||
latent_dim = self.latent_format.latent_channels
|
||||
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
||||
for i in range(0, image.shape[1], latent_dim):
|
||||
image[:, i: i + latent_dim] = self.process_latent_in(image[:, i: i + latent_dim])
|
||||
image = utils.resize_to_batch_size(image, noise.shape[0])
|
||||
|
||||
mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
|
||||
if mask is None:
|
||||
mask = torch.zeros_like(noise)[:, :1]
|
||||
else:
|
||||
mask = 1.0 - mask
|
||||
mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
||||
if mask.shape[-3] < noise.shape[-3]:
|
||||
mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0)
|
||||
mask = utils.resize_to_batch_size(mask, noise.shape[0])
|
||||
|
||||
return torch.cat((image, mask), dim=1)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
attention_mask = kwargs.get("attention_mask", None)
|
||||
if attention_mask is not None:
|
||||
if torch.numel(attention_mask) != attention_mask.sum():
|
||||
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
|
||||
conditioning_byt5small = kwargs.get("conditioning_byt5small", None)
|
||||
if conditioning_byt5small is not None:
|
||||
out['txt_byt5'] = comfy.conds.CONDRegular(conditioning_byt5small)
|
||||
|
||||
guidance = kwargs.get("guidance", 6.0)
|
||||
if guidance is not None:
|
||||
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
|
||||
|
||||
clip_vision_output = kwargs.get("clip_vision_output", None)
|
||||
if clip_vision_output is not None:
|
||||
out['clip_fea'] = comfy.conds.CONDRegular(clip_vision_output.last_hidden_state)
|
||||
|
||||
return out
|
||||
|
||||
class HunyuanVideo15_SR_Distilled(HunyuanVideo15):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device)
|
||||
|
||||
def concat_cond(self, **kwargs):
|
||||
noise = kwargs.get("noise", None)
|
||||
image = kwargs.get("concat_latent_image", None)
|
||||
noise_augmentation = kwargs.get("noise_augmentation", 0.0)
|
||||
device = kwargs["device"]
|
||||
|
||||
if image is None:
|
||||
image = torch.zeros([noise.shape[0], noise.shape[1] * 2 + 2, noise.shape[-3], noise.shape[-2], noise.shape[-1]], device=comfy.model_management.intermediate_device())
|
||||
else:
|
||||
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
||||
#image = self.process_latent_in(image) # scaling wasn't applied in reference code
|
||||
image = utils.resize_to_batch_size(image, noise.shape[0])
|
||||
lq_image_slice = slice(noise.shape[1] + 1, 2 * noise.shape[1] + 1)
|
||||
if noise_augmentation > 0:
|
||||
generator = torch.Generator(device="cpu")
|
||||
generator.manual_seed(kwargs.get("seed", 0) - 10)
|
||||
noise = torch.randn(image[:, lq_image_slice].shape, generator=generator, dtype=image.dtype, device="cpu").to(image.device)
|
||||
image[:, lq_image_slice] = noise_augmentation * noise + min(1.0 - noise_augmentation, 0.75) * image[:, lq_image_slice]
|
||||
else:
|
||||
image[:, lq_image_slice] = 0.75 * image[:, lq_image_slice]
|
||||
return image
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
out['disable_time_r'] = comfy.conds.CONDConstant(False)
|
||||
return out
|
||||
|
||||
class Kandinsky5(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.kandinsky5.model.Kandinsky5)
|
||||
|
||||
def encode_adm(self, **kwargs):
|
||||
return kwargs["pooled_output"]
|
||||
|
||||
def concat_cond(self, **kwargs):
|
||||
noise = kwargs.get("noise", None)
|
||||
device = kwargs["device"]
|
||||
image = torch.zeros_like(noise)
|
||||
|
||||
mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
|
||||
if mask is None:
|
||||
mask = torch.zeros_like(noise)[:, :1]
|
||||
else:
|
||||
mask = 1.0 - mask
|
||||
mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
||||
if mask.shape[-3] < noise.shape[-3]:
|
||||
mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0)
|
||||
mask = utils.resize_to_batch_size(mask, noise.shape[0])
|
||||
|
||||
return torch.cat((image, mask), dim=1)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
attention_mask = kwargs.get("attention_mask", None)
|
||||
if attention_mask is not None:
|
||||
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
|
||||
time_dim_replace = kwargs.get("time_dim_replace", None)
|
||||
if time_dim_replace is not None:
|
||||
out['time_dim_replace'] = comfy.conds.CONDRegular(self.process_latent_in(time_dim_replace))
|
||||
|
||||
return out
|
||||
|
||||
class Kandinsky5Image(Kandinsky5):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device)
|
||||
|
||||
def concat_cond(self, **kwargs):
|
||||
return None
|
||||
|
||||
@@ -172,73 +172,30 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
|
||||
guidance_keys = list(filter(lambda a: a.startswith("{}guidance_in.".format(key_prefix)), state_dict_keys))
|
||||
dit_config["guidance_embed"] = len(guidance_keys) > 0
|
||||
|
||||
# HunyuanVideo 1.5
|
||||
if '{}cond_type_embedding.weight'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["use_cond_type_embedding"] = True
|
||||
else:
|
||||
dit_config["use_cond_type_embedding"] = False
|
||||
if '{}vision_in.proj.0.weight'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["vision_in_dim"] = state_dict['{}vision_in.proj.0.weight'.format(key_prefix)].shape[0]
|
||||
dit_config["meanflow_sum"] = True
|
||||
else:
|
||||
dit_config["vision_in_dim"] = None
|
||||
dit_config["meanflow_sum"] = False
|
||||
return dit_config
|
||||
|
||||
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
|
||||
dit_config = {}
|
||||
if '{}double_stream_modulation_img.lin.weight'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["image_model"] = "flux2"
|
||||
dit_config["axes_dim"] = [32, 32, 32, 32]
|
||||
dit_config["num_heads"] = 48
|
||||
dit_config["mlp_ratio"] = 3.0
|
||||
dit_config["theta"] = 2000
|
||||
dit_config["out_channels"] = 128
|
||||
dit_config["global_modulation"] = True
|
||||
dit_config["mlp_silu_act"] = True
|
||||
dit_config["qkv_bias"] = False
|
||||
dit_config["ops_bias"] = False
|
||||
dit_config["default_ref_method"] = "index"
|
||||
dit_config["ref_index_scale"] = 10.0
|
||||
dit_config["txt_ids_dims"] = [3]
|
||||
patch_size = 1
|
||||
else:
|
||||
dit_config["image_model"] = "flux"
|
||||
dit_config["axes_dim"] = [16, 56, 56]
|
||||
dit_config["num_heads"] = 24
|
||||
dit_config["mlp_ratio"] = 4.0
|
||||
dit_config["theta"] = 10000
|
||||
dit_config["out_channels"] = 16
|
||||
dit_config["qkv_bias"] = True
|
||||
dit_config["txt_ids_dims"] = []
|
||||
patch_size = 2
|
||||
|
||||
dit_config["image_model"] = "flux"
|
||||
dit_config["in_channels"] = 16
|
||||
dit_config["hidden_size"] = 3072
|
||||
dit_config["context_in_dim"] = 4096
|
||||
|
||||
patch_size = 2
|
||||
dit_config["patch_size"] = patch_size
|
||||
in_key = "{}img_in.weight".format(key_prefix)
|
||||
if in_key in state_dict_keys:
|
||||
w = state_dict[in_key]
|
||||
dit_config["in_channels"] = w.shape[1] // (patch_size * patch_size)
|
||||
dit_config["hidden_size"] = w.shape[0]
|
||||
|
||||
txt_in_key = "{}txt_in.weight".format(key_prefix)
|
||||
if txt_in_key in state_dict_keys:
|
||||
w = state_dict[txt_in_key]
|
||||
dit_config["context_in_dim"] = w.shape[1]
|
||||
dit_config["hidden_size"] = w.shape[0]
|
||||
|
||||
dit_config["in_channels"] = state_dict[in_key].shape[1] // (patch_size * patch_size)
|
||||
dit_config["out_channels"] = 16
|
||||
vec_in_key = '{}vector_in.in_layer.weight'.format(key_prefix)
|
||||
if vec_in_key in state_dict_keys:
|
||||
dit_config["vec_in_dim"] = state_dict[vec_in_key].shape[1]
|
||||
else:
|
||||
dit_config["vec_in_dim"] = None
|
||||
|
||||
dit_config["context_in_dim"] = 4096
|
||||
dit_config["hidden_size"] = 3072
|
||||
dit_config["mlp_ratio"] = 4.0
|
||||
dit_config["num_heads"] = 24
|
||||
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
|
||||
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
|
||||
dit_config["axes_dim"] = [16, 56, 56]
|
||||
dit_config["theta"] = 10000
|
||||
dit_config["qkv_bias"] = True
|
||||
if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: #Chroma
|
||||
dit_config["image_model"] = "chroma"
|
||||
dit_config["in_channels"] = 64
|
||||
@@ -256,20 +213,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["nerf_mlp_ratio"] = 4
|
||||
dit_config["nerf_depth"] = 4
|
||||
dit_config["nerf_max_freqs"] = 8
|
||||
dit_config["nerf_tile_size"] = 512
|
||||
dit_config["nerf_tile_size"] = 32
|
||||
dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
|
||||
dit_config["nerf_embedder_dtype"] = torch.float32
|
||||
if "{}__x0__".format(key_prefix) in state_dict_keys: # x0 pred
|
||||
dit_config["use_x0"] = True
|
||||
else:
|
||||
dit_config["use_x0"] = False
|
||||
else:
|
||||
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
|
||||
dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys
|
||||
dit_config["txt_norm"] = "{}txt_norm.scale".format(key_prefix) in state_dict_keys
|
||||
if dit_config["yak_mlp"] and dit_config["txt_norm"]: # Ovis model
|
||||
dit_config["txt_ids_dims"] = [1, 2]
|
||||
|
||||
return dit_config
|
||||
|
||||
if '{}t5_yproj.weight'.format(key_prefix) in state_dict_keys: #Genmo mochi preview
|
||||
@@ -305,7 +253,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
|
||||
if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: #Lightricks ltxv
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "ltxav" if f'{key_prefix}audio_adaln_single.linear.weight' in state_dict_keys else "ltxv"
|
||||
dit_config["image_model"] = "ltxv"
|
||||
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
|
||||
shape = state_dict['{}transformer_blocks.0.attn2.to_k.weight'.format(key_prefix)].shape
|
||||
dit_config["attention_head_dim"] = shape[0] // 32
|
||||
@@ -416,35 +364,14 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["image_model"] = "lumina2"
|
||||
dit_config["patch_size"] = 2
|
||||
dit_config["in_channels"] = 16
|
||||
w = state_dict['{}cap_embedder.1.weight'.format(key_prefix)]
|
||||
dit_config["dim"] = w.shape[0]
|
||||
dit_config["cap_feat_dim"] = w.shape[1]
|
||||
dit_config["n_layers"] = count_blocks(state_dict_keys, '{}layers.'.format(key_prefix) + '{}.')
|
||||
dit_config["dim"] = 2304
|
||||
dit_config["cap_feat_dim"] = 2304
|
||||
dit_config["n_layers"] = 26
|
||||
dit_config["n_heads"] = 24
|
||||
dit_config["n_kv_heads"] = 8
|
||||
dit_config["qk_norm"] = True
|
||||
|
||||
if dit_config["dim"] == 2304: # Original Lumina 2
|
||||
dit_config["n_heads"] = 24
|
||||
dit_config["n_kv_heads"] = 8
|
||||
dit_config["axes_dims"] = [32, 32, 32]
|
||||
dit_config["axes_lens"] = [300, 512, 512]
|
||||
dit_config["rope_theta"] = 10000.0
|
||||
dit_config["ffn_dim_multiplier"] = 4.0
|
||||
ctd_weight = state_dict.get('{}clip_text_pooled_proj.0.weight'.format(key_prefix), None)
|
||||
if ctd_weight is not None: # NewBie
|
||||
dit_config["clip_text_dim"] = ctd_weight.shape[0]
|
||||
# NewBie also sets axes_lens = [1024, 512, 512] but it's not used in ComfyUI
|
||||
elif dit_config["dim"] == 3840: # Z image
|
||||
dit_config["n_heads"] = 30
|
||||
dit_config["n_kv_heads"] = 30
|
||||
dit_config["axes_dims"] = [32, 48, 48]
|
||||
dit_config["axes_lens"] = [1536, 512, 512]
|
||||
dit_config["rope_theta"] = 256.0
|
||||
dit_config["ffn_dim_multiplier"] = (8.0 / 3.0)
|
||||
dit_config["z_image_modulation"] = True
|
||||
dit_config["time_scale"] = 1000.0
|
||||
if '{}cap_pad_token'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["pad_tokens_multiple"] = 32
|
||||
|
||||
dit_config["axes_dims"] = [32, 32, 32]
|
||||
dit_config["axes_lens"] = [300, 512, 512]
|
||||
return dit_config
|
||||
|
||||
if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1
|
||||
@@ -619,29 +546,6 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["image_model"] = "qwen_image"
|
||||
dit_config["in_channels"] = state_dict['{}img_in.weight'.format(key_prefix)].shape[1]
|
||||
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
|
||||
if "{}__index_timestep_zero__".format(key_prefix) in state_dict_keys: # 2511
|
||||
dit_config["default_ref_method"] = "index_timestep_zero"
|
||||
if "{}time_text_embed.addition_t_embedding.weight".format(key_prefix) in state_dict_keys: # Layered
|
||||
dit_config["use_additional_t_cond"] = True
|
||||
dit_config["default_ref_method"] = "negative_index"
|
||||
return dit_config
|
||||
|
||||
if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # Kandinsky 5
|
||||
dit_config = {}
|
||||
model_dim = state_dict['{}visual_embeddings.in_layer.bias'.format(key_prefix)].shape[0]
|
||||
dit_config["model_dim"] = model_dim
|
||||
if model_dim in [4096, 2560]: # pro video and lite image
|
||||
dit_config["axes_dims"] = (32, 48, 48)
|
||||
if model_dim == 2560: # lite image
|
||||
dit_config["rope_scale_factor"] = (1.0, 1.0, 1.0)
|
||||
elif model_dim == 1792: # lite video
|
||||
dit_config["axes_dims"] = (16, 24, 24)
|
||||
dit_config["time_dim"] = state_dict['{}time_embeddings.in_layer.bias'.format(key_prefix)].shape[0]
|
||||
dit_config["image_model"] = "kandinsky5"
|
||||
dit_config["ff_dim"] = state_dict['{}visual_transformer_blocks.0.feed_forward.in_layer.weight'.format(key_prefix)].shape[0]
|
||||
dit_config["visual_embed_dim"] = state_dict['{}visual_embeddings.in_layer.weight'.format(key_prefix)].shape[1]
|
||||
dit_config["num_text_blocks"] = count_blocks(state_dict_keys, '{}text_transformer_blocks.'.format(key_prefix) + '{}.')
|
||||
dit_config["num_visual_blocks"] = count_blocks(state_dict_keys, '{}visual_transformer_blocks.'.format(key_prefix) + '{}.')
|
||||
return dit_config
|
||||
|
||||
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
|
||||
@@ -786,11 +690,16 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
|
||||
if model_config is None and use_base_if_no_match:
|
||||
model_config = comfy.supported_models_base.BASE(unet_config)
|
||||
|
||||
# Detect per-layer quantization (mixed precision)
|
||||
quant_config = comfy.utils.detect_layer_quantization(state_dict, unet_key_prefix)
|
||||
if quant_config:
|
||||
model_config.quant_config = quant_config
|
||||
logging.info("Detected mixed precision quantization")
|
||||
scaled_fp8_key = "{}scaled_fp8".format(unet_key_prefix)
|
||||
if scaled_fp8_key in state_dict:
|
||||
scaled_fp8_weight = state_dict.pop(scaled_fp8_key)
|
||||
model_config.scaled_fp8 = scaled_fp8_weight.dtype
|
||||
if model_config.scaled_fp8 == torch.float32:
|
||||
model_config.scaled_fp8 = torch.float8_e4m3fn
|
||||
if scaled_fp8_weight.nelement() == 2:
|
||||
model_config.optimizations["fp8"] = False
|
||||
else:
|
||||
model_config.optimizations["fp8"] = True
|
||||
|
||||
return model_config
|
||||
|
||||
|
||||
@@ -26,7 +26,6 @@ import importlib
|
||||
import platform
|
||||
import weakref
|
||||
import gc
|
||||
import os
|
||||
|
||||
class VRAMState(Enum):
|
||||
DISABLED = 0 #No vram present: no need to move models to vram
|
||||
@@ -90,7 +89,6 @@ if args.deterministic:
|
||||
|
||||
directml_enabled = False
|
||||
if args.directml is not None:
|
||||
logging.warning("WARNING: torch-directml barely works, is very slow, has not been updated in over 1 year and might be removed soon, please don't use it, there are better options.")
|
||||
import torch_directml
|
||||
directml_enabled = True
|
||||
device_index = args.directml
|
||||
@@ -332,23 +330,13 @@ except:
|
||||
|
||||
|
||||
SUPPORT_FP8_OPS = args.supports_fp8_compute
|
||||
|
||||
AMD_RDNA2_AND_OLDER_ARCH = ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]
|
||||
AMD_ENABLE_MIOPEN_ENV = 'COMFYUI_ENABLE_MIOPEN'
|
||||
|
||||
try:
|
||||
if is_amd():
|
||||
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
|
||||
if not (any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH)):
|
||||
if os.getenv(AMD_ENABLE_MIOPEN_ENV) != '1':
|
||||
torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD
|
||||
logging.info("Set: torch.backends.cudnn.enabled = False for better AMD performance.")
|
||||
|
||||
try:
|
||||
rocm_version = tuple(map(int, str(torch.version.hip).split(".")[:2]))
|
||||
except:
|
||||
rocm_version = (6, -1)
|
||||
|
||||
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
|
||||
logging.info("AMD arch: {}".format(arch))
|
||||
logging.info("ROCm version: {}".format(rocm_version))
|
||||
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
||||
@@ -356,11 +344,11 @@ try:
|
||||
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
|
||||
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
|
||||
ENABLE_PYTORCH_ATTENTION = True
|
||||
if rocm_version >= (7, 0):
|
||||
if any((a in arch) for a in ["gfx1201"]):
|
||||
ENABLE_PYTORCH_ATTENTION = True
|
||||
# if torch_version_numeric >= (2, 8):
|
||||
# if any((a in arch) for a in ["gfx1201"]):
|
||||
# ENABLE_PYTORCH_ATTENTION = True
|
||||
if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4):
|
||||
if any((a in arch) for a in ["gfx1200", "gfx1201", "gfx950"]): # TODO: more arches, "gfx942" gives error on pytorch nightly 2.10 1013 rocm7.0
|
||||
if any((a in arch) for a in ["gfx1200", "gfx1201", "gfx942", "gfx950"]): # TODO: more arches
|
||||
SUPPORT_FP8_OPS = True
|
||||
|
||||
except:
|
||||
@@ -382,9 +370,6 @@ try:
|
||||
except:
|
||||
pass
|
||||
|
||||
if torch.cuda.is_available() and torch.backends.cudnn.is_available() and PerformanceFeature.AutoTune in args.fast:
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
try:
|
||||
if torch_version_numeric >= (2, 5):
|
||||
torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)
|
||||
@@ -456,7 +441,7 @@ def module_size(module):
|
||||
sd = module.state_dict()
|
||||
for k in sd:
|
||||
t = sd[k]
|
||||
module_mem += t.nbytes
|
||||
module_mem += t.nelement() * t.element_size()
|
||||
return module_mem
|
||||
|
||||
class LoadedModel:
|
||||
@@ -507,7 +492,6 @@ class LoadedModel:
|
||||
if use_more_vram == 0:
|
||||
use_more_vram = 1e32
|
||||
self.model_use_more_vram(use_more_vram, force_patch_weights=force_patch_weights)
|
||||
|
||||
real_model = self.model.model
|
||||
|
||||
if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and real_model is not None:
|
||||
@@ -661,9 +645,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
|
||||
if loaded_model.model.is_clone(current_loaded_models[i].model):
|
||||
to_unload = [i] + to_unload
|
||||
for i in to_unload:
|
||||
model_to_unload = current_loaded_models.pop(i)
|
||||
model_to_unload.model.detach(unpatch_all=False)
|
||||
model_to_unload.model_finalizer.detach()
|
||||
current_loaded_models.pop(i).model.detach(unpatch_all=False)
|
||||
|
||||
total_memory_required = {}
|
||||
for loaded_model in models_to_load:
|
||||
@@ -692,11 +674,8 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
|
||||
loaded_memory = loaded_model.model_loaded_memory()
|
||||
current_free_mem = get_free_memory(torch_dev) + loaded_memory
|
||||
|
||||
lowvram_model_memory = max(0, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
|
||||
lowvram_model_memory = lowvram_model_memory - loaded_memory
|
||||
|
||||
if lowvram_model_memory == 0:
|
||||
lowvram_model_memory = 0.1
|
||||
lowvram_model_memory = max(128 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
|
||||
lowvram_model_memory = max(0.1, lowvram_model_memory - loaded_memory)
|
||||
|
||||
if vram_set_state == VRAMState.NO_VRAM:
|
||||
lowvram_model_memory = 0.1
|
||||
@@ -944,7 +923,11 @@ def vae_dtype(device=None, allowed_dtypes=[]):
|
||||
if d == torch.float16 and should_use_fp16(device):
|
||||
return d
|
||||
|
||||
if d == torch.bfloat16 and should_use_bf16(device):
|
||||
# NOTE: bfloat16 seems to work on AMD for the VAE but is extremely slow in some cases compared to fp32
|
||||
# slowness still a problem on pytorch nightly 2.9.0.dev20250720+rocm6.4 tested on RDNA3
|
||||
# also a problem on RDNA4 except fp32 is also slow there.
|
||||
# This is due to large bf16 convolutions being extremely slow.
|
||||
if d == torch.bfloat16 and ((not is_amd()) or amd_min_version(device, min_rdna_version=4)) and should_use_bf16(device):
|
||||
return d
|
||||
|
||||
return torch.float32
|
||||
@@ -1006,6 +989,12 @@ def device_supports_non_blocking(device):
|
||||
return False
|
||||
return True
|
||||
|
||||
def device_should_use_non_blocking(device):
|
||||
if not device_supports_non_blocking(device):
|
||||
return False
|
||||
return False
|
||||
# return True #TODO: figure out why this causes memory issues on Nvidia and possibly others
|
||||
|
||||
def force_channels_last():
|
||||
if args.force_channels_last:
|
||||
return True
|
||||
@@ -1015,72 +1004,54 @@ def force_channels_last():
|
||||
|
||||
|
||||
STREAMS = {}
|
||||
NUM_STREAMS = 0
|
||||
if args.async_offload is not None:
|
||||
NUM_STREAMS = args.async_offload
|
||||
else:
|
||||
# Enable by default on Nvidia and AMD
|
||||
if is_nvidia() or is_amd():
|
||||
NUM_STREAMS = 2
|
||||
|
||||
if args.disable_async_offload:
|
||||
NUM_STREAMS = 0
|
||||
|
||||
if NUM_STREAMS > 0:
|
||||
NUM_STREAMS = 1
|
||||
if args.async_offload:
|
||||
NUM_STREAMS = 2
|
||||
logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS))
|
||||
|
||||
def current_stream(device):
|
||||
if device is None:
|
||||
return None
|
||||
if is_device_cuda(device):
|
||||
return torch.cuda.current_stream()
|
||||
elif is_device_xpu(device):
|
||||
return torch.xpu.current_stream()
|
||||
else:
|
||||
return None
|
||||
|
||||
stream_counters = {}
|
||||
def get_offload_stream(device):
|
||||
stream_counter = stream_counters.get(device, 0)
|
||||
if NUM_STREAMS == 0:
|
||||
return None
|
||||
|
||||
if torch.compiler.is_compiling():
|
||||
if NUM_STREAMS <= 1:
|
||||
return None
|
||||
|
||||
if device in STREAMS:
|
||||
ss = STREAMS[device]
|
||||
#Sync the oldest stream in the queue with the current
|
||||
ss[stream_counter].wait_stream(current_stream(device))
|
||||
s = ss[stream_counter]
|
||||
stream_counter = (stream_counter + 1) % len(ss)
|
||||
if is_device_cuda(device):
|
||||
ss[stream_counter].wait_stream(torch.cuda.current_stream())
|
||||
elif is_device_xpu(device):
|
||||
ss[stream_counter].wait_stream(torch.xpu.current_stream())
|
||||
stream_counters[device] = stream_counter
|
||||
return ss[stream_counter]
|
||||
return s
|
||||
elif is_device_cuda(device):
|
||||
ss = []
|
||||
for k in range(NUM_STREAMS):
|
||||
s1 = torch.cuda.Stream(device=device, priority=0)
|
||||
s1.as_context = torch.cuda.stream
|
||||
ss.append(s1)
|
||||
ss.append(torch.cuda.Stream(device=device, priority=0))
|
||||
STREAMS[device] = ss
|
||||
s = ss[stream_counter]
|
||||
stream_counter = (stream_counter + 1) % len(ss)
|
||||
stream_counters[device] = stream_counter
|
||||
return s
|
||||
elif is_device_xpu(device):
|
||||
ss = []
|
||||
for k in range(NUM_STREAMS):
|
||||
s1 = torch.xpu.Stream(device=device, priority=0)
|
||||
s1.as_context = torch.xpu.stream
|
||||
ss.append(s1)
|
||||
ss.append(torch.xpu.Stream(device=device, priority=0))
|
||||
STREAMS[device] = ss
|
||||
s = ss[stream_counter]
|
||||
stream_counter = (stream_counter + 1) % len(ss)
|
||||
stream_counters[device] = stream_counter
|
||||
return s
|
||||
return None
|
||||
|
||||
def sync_stream(device, stream):
|
||||
if stream is None or current_stream(device) is None:
|
||||
if stream is None:
|
||||
return
|
||||
current_stream(device).wait_stream(stream)
|
||||
if is_device_cuda(device):
|
||||
torch.cuda.current_stream().wait_stream(stream)
|
||||
elif is_device_xpu(device):
|
||||
torch.xpu.current_stream().wait_stream(stream)
|
||||
|
||||
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None):
|
||||
if device is None or weight.device == device:
|
||||
@@ -1088,19 +1059,12 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
|
||||
if dtype is None or weight.dtype == dtype:
|
||||
return weight
|
||||
if stream is not None:
|
||||
wf_context = stream
|
||||
if hasattr(wf_context, "as_context"):
|
||||
wf_context = wf_context.as_context(stream)
|
||||
with wf_context:
|
||||
with stream:
|
||||
return weight.to(dtype=dtype, copy=copy)
|
||||
return weight.to(dtype=dtype, copy=copy)
|
||||
|
||||
|
||||
if stream is not None:
|
||||
wf_context = stream
|
||||
if hasattr(wf_context, "as_context"):
|
||||
wf_context = wf_context.as_context(stream)
|
||||
with wf_context:
|
||||
with stream:
|
||||
r = torch.empty_like(weight, dtype=dtype, device=device)
|
||||
r.copy_(weight, non_blocking=non_blocking)
|
||||
else:
|
||||
@@ -1112,99 +1076,6 @@ def cast_to_device(tensor, device, dtype, copy=False):
|
||||
non_blocking = device_supports_non_blocking(device)
|
||||
return cast_to(tensor, dtype=dtype, device=device, non_blocking=non_blocking, copy=copy)
|
||||
|
||||
|
||||
PINNED_MEMORY = {}
|
||||
TOTAL_PINNED_MEMORY = 0
|
||||
MAX_PINNED_MEMORY = -1
|
||||
if not args.disable_pinned_memory:
|
||||
if is_nvidia() or is_amd():
|
||||
if WINDOWS:
|
||||
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.45 # Windows limit is apparently 50%
|
||||
else:
|
||||
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95
|
||||
logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024)))
|
||||
|
||||
PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"])
|
||||
|
||||
def discard_cuda_async_error():
|
||||
try:
|
||||
a = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
|
||||
b = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
|
||||
_ = a + b
|
||||
torch.cuda.synchronize()
|
||||
except torch.AcceleratorError:
|
||||
#Dump it! We already know about it from the synchronous return
|
||||
pass
|
||||
|
||||
def pin_memory(tensor):
|
||||
global TOTAL_PINNED_MEMORY
|
||||
if MAX_PINNED_MEMORY <= 0:
|
||||
return False
|
||||
|
||||
if type(tensor).__name__ not in PINNING_ALLOWED_TYPES:
|
||||
return False
|
||||
|
||||
if not is_device_cpu(tensor.device):
|
||||
return False
|
||||
|
||||
if tensor.is_pinned():
|
||||
#NOTE: Cuda does detect when a tensor is already pinned and would
|
||||
#error below, but there are proven cases where this also queues an error
|
||||
#on the GPU async. So dont trust the CUDA API and guard here
|
||||
return False
|
||||
|
||||
if not tensor.is_contiguous():
|
||||
return False
|
||||
|
||||
size = tensor.nbytes
|
||||
if (TOTAL_PINNED_MEMORY + size) > MAX_PINNED_MEMORY:
|
||||
return False
|
||||
|
||||
ptr = tensor.data_ptr()
|
||||
if ptr == 0:
|
||||
return False
|
||||
|
||||
if torch.cuda.cudart().cudaHostRegister(ptr, size, 1) == 0:
|
||||
PINNED_MEMORY[ptr] = size
|
||||
TOTAL_PINNED_MEMORY += size
|
||||
return True
|
||||
else:
|
||||
logging.warning("Pin error.")
|
||||
discard_cuda_async_error()
|
||||
|
||||
return False
|
||||
|
||||
def unpin_memory(tensor):
|
||||
global TOTAL_PINNED_MEMORY
|
||||
if MAX_PINNED_MEMORY <= 0:
|
||||
return False
|
||||
|
||||
if not is_device_cpu(tensor.device):
|
||||
return False
|
||||
|
||||
ptr = tensor.data_ptr()
|
||||
size = tensor.nbytes
|
||||
|
||||
size_stored = PINNED_MEMORY.get(ptr, None)
|
||||
if size_stored is None:
|
||||
logging.warning("Tried to unpin tensor not pinned by ComfyUI")
|
||||
return False
|
||||
|
||||
if size != size_stored:
|
||||
logging.warning("Size of pinned tensor changed")
|
||||
return False
|
||||
|
||||
if torch.cuda.cudart().cudaHostUnregister(ptr) == 0:
|
||||
TOTAL_PINNED_MEMORY -= PINNED_MEMORY.pop(ptr)
|
||||
if len(PINNED_MEMORY) == 0:
|
||||
TOTAL_PINNED_MEMORY = 0
|
||||
return True
|
||||
else:
|
||||
logging.warning("Unpin error.")
|
||||
discard_cuda_async_error()
|
||||
|
||||
return False
|
||||
|
||||
def sage_attention_enabled():
|
||||
return args.use_sage_attention
|
||||
|
||||
@@ -1457,7 +1328,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
|
||||
|
||||
if is_amd():
|
||||
arch = torch.cuda.get_device_properties(device).gcnArchName
|
||||
if any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH): # RDNA2 and older don't support bf16
|
||||
if any((a in arch) for a in ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]): # RDNA2 and older don't support bf16
|
||||
if manual_cast:
|
||||
return True
|
||||
return False
|
||||
@@ -1504,16 +1375,6 @@ def supports_fp8_compute(device=None):
|
||||
|
||||
return True
|
||||
|
||||
def supports_nvfp4_compute(device=None):
|
||||
if not is_nvidia():
|
||||
return False
|
||||
|
||||
props = torch.cuda.get_device_properties(device)
|
||||
if props.major < 10:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def extended_fp16_support():
|
||||
# TODO: check why some models work with fp16 on newer torch versions but not on older
|
||||
if torch_version_numeric < (2, 7):
|
||||
@@ -1521,20 +1382,6 @@ def extended_fp16_support():
|
||||
|
||||
return True
|
||||
|
||||
LORA_COMPUTE_DTYPES = {}
|
||||
def lora_compute_dtype(device):
|
||||
dtype = LORA_COMPUTE_DTYPES.get(device, None)
|
||||
if dtype is not None:
|
||||
return dtype
|
||||
|
||||
if should_use_fp16(device):
|
||||
dtype = torch.float16
|
||||
else:
|
||||
dtype = torch.float32
|
||||
|
||||
LORA_COMPUTE_DTYPES[device] = dtype
|
||||
return dtype
|
||||
|
||||
def soft_empty_cache(force=False):
|
||||
global cpu_state
|
||||
if cpu_state == CPUState.MPS:
|
||||
@@ -1552,10 +1399,6 @@ def soft_empty_cache(force=False):
|
||||
def unload_all_models():
|
||||
free_memory(1e30, get_torch_device())
|
||||
|
||||
def debug_memory_summary():
|
||||
if is_amd() or is_nvidia():
|
||||
return torch.cuda.memory.memory_summary()
|
||||
return ""
|
||||
|
||||
#TODO: might be cleaner to put this somewhere else
|
||||
import threading
|
||||
|
||||
@@ -35,7 +35,6 @@ import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
import comfy.utils
|
||||
from comfy.comfy_types import UnetWrapperFunction
|
||||
from comfy.quant_ops import QuantizedTensor
|
||||
from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
|
||||
|
||||
|
||||
@@ -124,26 +123,16 @@ def move_weight_functions(m, device):
|
||||
return memory
|
||||
|
||||
class LowVramPatch:
|
||||
def __init__(self, key, patches, convert_func=None, set_func=None):
|
||||
def __init__(self, key, patches):
|
||||
self.key = key
|
||||
self.patches = patches
|
||||
self.convert_func = convert_func # TODO: remove
|
||||
self.set_func = set_func
|
||||
|
||||
def __call__(self, weight):
|
||||
return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype)
|
||||
intermediate_dtype = weight.dtype
|
||||
if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops
|
||||
intermediate_dtype = torch.float32
|
||||
return comfy.float.stochastic_rounding(comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype), weight.dtype, seed=string_to_seed(self.key))
|
||||
|
||||
LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 2
|
||||
|
||||
def low_vram_patch_estimate_vram(model, key):
|
||||
weight, set_func, convert_func = get_key_weight(model, key)
|
||||
if weight is None:
|
||||
return 0
|
||||
model_dtype = getattr(model, "manual_cast_dtype", torch.float32)
|
||||
if model_dtype is None:
|
||||
model_dtype = weight.dtype
|
||||
|
||||
return weight.numel() * model_dtype.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR
|
||||
return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
|
||||
|
||||
def get_key_weight(model, key):
|
||||
set_func = None
|
||||
@@ -228,13 +217,13 @@ class ModelPatcher:
|
||||
self.object_patches_backup = {}
|
||||
self.weight_wrapper_patches = {}
|
||||
self.model_options = {"transformer_options":{}}
|
||||
self.model_size()
|
||||
self.load_device = load_device
|
||||
self.offload_device = offload_device
|
||||
self.weight_inplace_update = weight_inplace_update
|
||||
self.force_cast_weights = False
|
||||
self.patches_uuid = uuid.uuid4()
|
||||
self.parent = None
|
||||
self.pinned = set()
|
||||
|
||||
self.attachments: dict[str] = {}
|
||||
self.additional_models: dict[str, list[ModelPatcher]] = {}
|
||||
@@ -266,18 +255,12 @@ class ModelPatcher:
|
||||
if not hasattr(self.model, 'current_weight_patches_uuid'):
|
||||
self.model.current_weight_patches_uuid = None
|
||||
|
||||
if not hasattr(self.model, 'model_offload_buffer_memory'):
|
||||
self.model.model_offload_buffer_memory = 0
|
||||
|
||||
def model_size(self):
|
||||
if self.size > 0:
|
||||
return self.size
|
||||
self.size = comfy.model_management.module_size(self.model)
|
||||
return self.size
|
||||
|
||||
def get_ram_usage(self):
|
||||
return self.model_size()
|
||||
|
||||
def loaded_size(self):
|
||||
return self.model.model_loaded_weight_memory
|
||||
|
||||
@@ -285,7 +268,7 @@ class ModelPatcher:
|
||||
return self.model.lowvram_patch_counter
|
||||
|
||||
def clone(self):
|
||||
n = self.__class__(self.model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update)
|
||||
n = self.__class__(self.model, self.load_device, self.offload_device, self.size, weight_inplace_update=self.weight_inplace_update)
|
||||
n.patches = {}
|
||||
for k in self.patches:
|
||||
n.patches[k] = self.patches[k][:]
|
||||
@@ -297,7 +280,6 @@ class ModelPatcher:
|
||||
n.backup = self.backup
|
||||
n.object_patches_backup = self.object_patches_backup
|
||||
n.parent = self
|
||||
n.pinned = self.pinned
|
||||
|
||||
n.force_cast_weights = self.force_cast_weights
|
||||
|
||||
@@ -454,22 +436,6 @@ class ModelPatcher:
|
||||
def set_model_post_input_patch(self, patch):
|
||||
self.set_model_patch(patch, "post_input")
|
||||
|
||||
def set_model_noise_refiner_patch(self, patch):
|
||||
self.set_model_patch(patch, "noise_refiner")
|
||||
|
||||
def set_model_rope_options(self, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t, **kwargs):
|
||||
rope_options = self.model_options["transformer_options"].get("rope_options", {})
|
||||
rope_options["scale_x"] = scale_x
|
||||
rope_options["scale_y"] = scale_y
|
||||
rope_options["scale_t"] = scale_t
|
||||
|
||||
rope_options["shift_x"] = shift_x
|
||||
rope_options["shift_y"] = shift_y
|
||||
rope_options["shift_t"] = shift_t
|
||||
|
||||
self.model_options["transformer_options"]["rope_options"] = rope_options
|
||||
|
||||
|
||||
def add_object_patch(self, name, obj):
|
||||
self.object_patches[name] = obj
|
||||
|
||||
@@ -621,11 +587,10 @@ class ModelPatcher:
|
||||
if key not in self.backup:
|
||||
self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update)
|
||||
|
||||
temp_dtype = comfy.model_management.lora_compute_dtype(device_to)
|
||||
if device_to is not None:
|
||||
temp_weight = comfy.model_management.cast_to_device(weight, device_to, temp_dtype, copy=True)
|
||||
temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
|
||||
else:
|
||||
temp_weight = weight.to(temp_dtype, copy=True)
|
||||
temp_weight = weight.to(torch.float32, copy=True)
|
||||
if convert_func is not None:
|
||||
temp_weight = convert_func(temp_weight, inplace=True)
|
||||
|
||||
@@ -639,21 +604,6 @@ class ModelPatcher:
|
||||
else:
|
||||
set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
|
||||
|
||||
def pin_weight_to_device(self, key):
|
||||
weight, set_func, convert_func = get_key_weight(self.model, key)
|
||||
if comfy.model_management.pin_memory(weight):
|
||||
self.pinned.add(key)
|
||||
|
||||
def unpin_weight(self, key):
|
||||
if key in self.pinned:
|
||||
weight, set_func, convert_func = get_key_weight(self.model, key)
|
||||
comfy.model_management.unpin_memory(weight)
|
||||
self.pinned.remove(key)
|
||||
|
||||
def unpin_all_weights(self):
|
||||
for key in list(self.pinned):
|
||||
self.unpin_weight(key)
|
||||
|
||||
def _load_list(self):
|
||||
loading = []
|
||||
for n, m in self.model.named_modules():
|
||||
@@ -666,22 +616,7 @@ class ModelPatcher:
|
||||
skip = True # skip random weights in non leaf modules
|
||||
break
|
||||
if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
|
||||
module_mem = comfy.model_management.module_size(m)
|
||||
module_offload_mem = module_mem
|
||||
if hasattr(m, "comfy_cast_weights"):
|
||||
def check_module_offload_mem(key):
|
||||
if key in self.patches:
|
||||
return low_vram_patch_estimate_vram(self.model, key)
|
||||
model_dtype = getattr(self.model, "manual_cast_dtype", None)
|
||||
weight, _, _ = get_key_weight(self.model, key)
|
||||
if model_dtype is None or weight is None:
|
||||
return 0
|
||||
if (weight.dtype != model_dtype or isinstance(weight, QuantizedTensor)):
|
||||
return weight.numel() * model_dtype.itemsize
|
||||
return 0
|
||||
module_offload_mem += check_module_offload_mem("{}.weight".format(n))
|
||||
module_offload_mem += check_module_offload_mem("{}.bias".format(n))
|
||||
loading.append((module_offload_mem, module_mem, n, m, params))
|
||||
loading.append((comfy.model_management.module_size(m), n, m, params))
|
||||
return loading
|
||||
|
||||
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
|
||||
@@ -690,30 +625,25 @@ class ModelPatcher:
|
||||
mem_counter = 0
|
||||
patch_counter = 0
|
||||
lowvram_counter = 0
|
||||
lowvram_mem_counter = 0
|
||||
loading = self._load_list()
|
||||
|
||||
load_completely = []
|
||||
offloaded = []
|
||||
offload_buffer = 0
|
||||
loading.sort(reverse=True)
|
||||
for i, x in enumerate(loading):
|
||||
module_offload_mem, module_mem, n, m, params = x
|
||||
for x in loading:
|
||||
n = x[1]
|
||||
m = x[2]
|
||||
params = x[3]
|
||||
module_mem = x[0]
|
||||
|
||||
lowvram_weight = False
|
||||
|
||||
potential_offload = max(offload_buffer, module_offload_mem + sum([ x1[1] for x1 in loading[i+1:i+1+comfy.model_management.NUM_STREAMS]]))
|
||||
lowvram_fits = mem_counter + module_mem + potential_offload < lowvram_model_memory
|
||||
|
||||
weight_key = "{}.weight".format(n)
|
||||
bias_key = "{}.bias".format(n)
|
||||
|
||||
if not full_load and hasattr(m, "comfy_cast_weights"):
|
||||
if not lowvram_fits:
|
||||
offload_buffer = potential_offload
|
||||
if mem_counter + module_mem >= lowvram_model_memory:
|
||||
lowvram_weight = True
|
||||
lowvram_counter += 1
|
||||
lowvram_mem_counter += module_mem
|
||||
if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed
|
||||
continue
|
||||
|
||||
@@ -727,28 +657,23 @@ class ModelPatcher:
|
||||
if force_patch_weights:
|
||||
self.patch_weight_to_device(weight_key)
|
||||
else:
|
||||
_, set_func, convert_func = get_key_weight(self.model, weight_key)
|
||||
m.weight_function = [LowVramPatch(weight_key, self.patches, convert_func, set_func)]
|
||||
m.weight_function = [LowVramPatch(weight_key, self.patches)]
|
||||
patch_counter += 1
|
||||
if bias_key in self.patches:
|
||||
if force_patch_weights:
|
||||
self.patch_weight_to_device(bias_key)
|
||||
else:
|
||||
_, set_func, convert_func = get_key_weight(self.model, bias_key)
|
||||
m.bias_function = [LowVramPatch(bias_key, self.patches, convert_func, set_func)]
|
||||
m.bias_function = [LowVramPatch(bias_key, self.patches)]
|
||||
patch_counter += 1
|
||||
|
||||
cast_weight = True
|
||||
offloaded.append((module_mem, n, m, params))
|
||||
else:
|
||||
if hasattr(m, "comfy_cast_weights"):
|
||||
wipe_lowvram_weight(m)
|
||||
|
||||
if full_load or lowvram_fits:
|
||||
if full_load or mem_counter + module_mem < lowvram_model_memory:
|
||||
mem_counter += module_mem
|
||||
load_completely.append((module_mem, n, m, params))
|
||||
else:
|
||||
offload_buffer = potential_offload
|
||||
|
||||
if cast_weight and hasattr(m, "comfy_cast_weights"):
|
||||
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
||||
@@ -772,11 +697,7 @@ class ModelPatcher:
|
||||
continue
|
||||
|
||||
for param in params:
|
||||
key = "{}.{}".format(n, param)
|
||||
self.unpin_weight(key)
|
||||
self.patch_weight_to_device(key, device_to=device_to)
|
||||
if comfy.model_management.is_device_cuda(device_to):
|
||||
torch.cuda.synchronize()
|
||||
self.patch_weight_to_device("{}.{}".format(n, param), device_to=device_to)
|
||||
|
||||
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
|
||||
m.comfy_patched_weights = True
|
||||
@@ -784,17 +705,11 @@ class ModelPatcher:
|
||||
for x in load_completely:
|
||||
x[2].to(device_to)
|
||||
|
||||
for x in offloaded:
|
||||
n = x[1]
|
||||
params = x[3]
|
||||
for param in params:
|
||||
self.pin_weight_to_device("{}.{}".format(n, param))
|
||||
|
||||
if lowvram_counter > 0:
|
||||
logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), offload_buffer / (1024 * 1024), patch_counter))
|
||||
logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
|
||||
self.model.model_lowvram = True
|
||||
else:
|
||||
logging.info("loaded completely; {:.2f} MB usable, {:.2f} MB loaded, full load: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
|
||||
logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
|
||||
self.model.model_lowvram = False
|
||||
if full_load:
|
||||
self.model.to(device_to)
|
||||
@@ -803,7 +718,6 @@ class ModelPatcher:
|
||||
self.model.lowvram_patch_counter += patch_counter
|
||||
self.model.device = device_to
|
||||
self.model.model_loaded_weight_memory = mem_counter
|
||||
self.model.model_offload_buffer_memory = offload_buffer
|
||||
self.model.current_weight_patches_uuid = self.patches_uuid
|
||||
|
||||
for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD):
|
||||
@@ -832,7 +746,6 @@ class ModelPatcher:
|
||||
self.eject_model()
|
||||
if unpatch_weights:
|
||||
self.unpatch_hooks()
|
||||
self.unpin_all_weights()
|
||||
if self.model.model_lowvram:
|
||||
for m in self.model.modules():
|
||||
move_weight_functions(m, device_to)
|
||||
@@ -857,7 +770,6 @@ class ModelPatcher:
|
||||
self.model.to(device_to)
|
||||
self.model.device = device_to
|
||||
self.model.model_loaded_weight_memory = 0
|
||||
self.model.model_offload_buffer_memory = 0
|
||||
|
||||
for m in self.model.modules():
|
||||
if hasattr(m, "comfy_patched_weights"):
|
||||
@@ -869,25 +781,20 @@ class ModelPatcher:
|
||||
|
||||
self.object_patches_backup.clear()
|
||||
|
||||
def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=False):
|
||||
def partially_unload(self, device_to, memory_to_free=0):
|
||||
with self.use_ejected():
|
||||
hooks_unpatched = False
|
||||
memory_freed = 0
|
||||
patch_counter = 0
|
||||
unload_list = self._load_list()
|
||||
unload_list.sort()
|
||||
|
||||
offload_buffer = self.model.model_offload_buffer_memory
|
||||
if len(unload_list) > 0:
|
||||
NS = comfy.model_management.NUM_STREAMS
|
||||
offload_weight_factor = [ min(offload_buffer / (NS + 1), unload_list[0][1]) ] * NS
|
||||
|
||||
for unload in unload_list:
|
||||
if memory_to_free + offload_buffer - self.model.model_offload_buffer_memory < memory_freed:
|
||||
if memory_to_free < memory_freed:
|
||||
break
|
||||
module_offload_mem, module_mem, n, m, params = unload
|
||||
|
||||
potential_offload = module_offload_mem + sum(offload_weight_factor)
|
||||
module_mem = unload[0]
|
||||
n = unload[1]
|
||||
m = unload[2]
|
||||
params = unload[3]
|
||||
|
||||
lowvram_possible = hasattr(m, "comfy_cast_weights")
|
||||
if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
|
||||
@@ -918,40 +825,23 @@ class ModelPatcher:
|
||||
module_mem += move_weight_functions(m, device_to)
|
||||
if lowvram_possible:
|
||||
if weight_key in self.patches:
|
||||
if force_patch_weights:
|
||||
self.patch_weight_to_device(weight_key)
|
||||
else:
|
||||
_, set_func, convert_func = get_key_weight(self.model, weight_key)
|
||||
m.weight_function.append(LowVramPatch(weight_key, self.patches, convert_func, set_func))
|
||||
patch_counter += 1
|
||||
m.weight_function.append(LowVramPatch(weight_key, self.patches))
|
||||
patch_counter += 1
|
||||
if bias_key in self.patches:
|
||||
if force_patch_weights:
|
||||
self.patch_weight_to_device(bias_key)
|
||||
else:
|
||||
_, set_func, convert_func = get_key_weight(self.model, bias_key)
|
||||
m.bias_function.append(LowVramPatch(bias_key, self.patches, convert_func, set_func))
|
||||
patch_counter += 1
|
||||
m.bias_function.append(LowVramPatch(bias_key, self.patches))
|
||||
patch_counter += 1
|
||||
cast_weight = True
|
||||
|
||||
if cast_weight and hasattr(m, "comfy_cast_weights"):
|
||||
if cast_weight:
|
||||
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
||||
m.comfy_cast_weights = True
|
||||
m.comfy_patched_weights = False
|
||||
memory_freed += module_mem
|
||||
offload_buffer = max(offload_buffer, potential_offload)
|
||||
offload_weight_factor.append(module_mem)
|
||||
offload_weight_factor.pop(0)
|
||||
logging.debug("freed {}".format(n))
|
||||
|
||||
for param in params:
|
||||
self.pin_weight_to_device("{}.{}".format(n, param))
|
||||
|
||||
|
||||
self.model.model_lowvram = True
|
||||
self.model.lowvram_patch_counter += patch_counter
|
||||
self.model.model_loaded_weight_memory -= memory_freed
|
||||
self.model.model_offload_buffer_memory = offload_buffer
|
||||
logging.info("Unloaded partially: {:.2f} MB freed, {:.2f} MB remains loaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(memory_freed / (1024 * 1024), self.model.model_loaded_weight_memory / (1024 * 1024), offload_buffer / (1024 * 1024), self.model.lowvram_patch_counter))
|
||||
return memory_freed
|
||||
|
||||
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
|
||||
@@ -964,9 +854,6 @@ class ModelPatcher:
|
||||
extra_memory += (used - self.model.model_loaded_weight_memory)
|
||||
|
||||
self.patch_model(load_weights=False)
|
||||
if extra_memory < 0 and not unpatch_weights:
|
||||
self.partially_unload(self.offload_device, -extra_memory, force_patch_weights=force_patch_weights)
|
||||
return 0
|
||||
full_load = False
|
||||
if self.model.model_lowvram == False and self.model.model_loaded_weight_memory > 0:
|
||||
self.apply_hooks(self.forced_hooks, force_apply=True)
|
||||
@@ -1354,6 +1241,5 @@ class ModelPatcher:
|
||||
self.clear_cached_hook_weights()
|
||||
|
||||
def __del__(self):
|
||||
self.unpin_all_weights()
|
||||
self.detach(unpatch_all=False)
|
||||
|
||||
|
||||
@@ -21,23 +21,17 @@ def rescale_zero_terminal_snr_sigmas(sigmas):
|
||||
alphas_bar[-1] = 4.8973451890853435e-08
|
||||
return ((1 - alphas_bar) / alphas_bar) ** 0.5
|
||||
|
||||
def reshape_sigma(sigma, noise_dim):
|
||||
if sigma.nelement() == 1:
|
||||
return sigma.view(())
|
||||
else:
|
||||
return sigma.view(sigma.shape[:1] + (1,) * (noise_dim - 1))
|
||||
|
||||
class EPS:
|
||||
def calculate_input(self, sigma, noise):
|
||||
sigma = reshape_sigma(sigma, noise.ndim)
|
||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
|
||||
return noise / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
|
||||
|
||||
def calculate_denoised(self, sigma, model_output, model_input):
|
||||
sigma = reshape_sigma(sigma, model_output.ndim)
|
||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
|
||||
return model_input - model_output * sigma
|
||||
|
||||
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
|
||||
sigma = reshape_sigma(sigma, noise.ndim)
|
||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
|
||||
if max_denoise:
|
||||
noise = noise * torch.sqrt(1.0 + sigma ** 2.0)
|
||||
else:
|
||||
@@ -51,12 +45,12 @@ class EPS:
|
||||
|
||||
class V_PREDICTION(EPS):
|
||||
def calculate_denoised(self, sigma, model_output, model_input):
|
||||
sigma = reshape_sigma(sigma, model_output.ndim)
|
||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
|
||||
return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
|
||||
|
||||
class EDM(V_PREDICTION):
|
||||
def calculate_denoised(self, sigma, model_output, model_input):
|
||||
sigma = reshape_sigma(sigma, model_output.ndim)
|
||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
|
||||
return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) + model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
|
||||
|
||||
class CONST:
|
||||
@@ -64,15 +58,15 @@ class CONST:
|
||||
return noise
|
||||
|
||||
def calculate_denoised(self, sigma, model_output, model_input):
|
||||
sigma = reshape_sigma(sigma, model_output.ndim)
|
||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
|
||||
return model_input - model_output * sigma
|
||||
|
||||
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
|
||||
sigma = reshape_sigma(sigma, noise.ndim)
|
||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
|
||||
return sigma * noise + (1.0 - sigma) * latent_image
|
||||
|
||||
def inverse_noise_scaling(self, sigma, latent):
|
||||
sigma = reshape_sigma(sigma, latent.ndim)
|
||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (latent.ndim - 1))
|
||||
return latent / (1.0 - sigma)
|
||||
|
||||
class X0(EPS):
|
||||
@@ -86,16 +80,16 @@ class IMG_TO_IMG(X0):
|
||||
class COSMOS_RFLOW:
|
||||
def calculate_input(self, sigma, noise):
|
||||
sigma = (sigma / (sigma + 1))
|
||||
sigma = reshape_sigma(sigma, noise.ndim)
|
||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
|
||||
return noise * (1.0 - sigma)
|
||||
|
||||
def calculate_denoised(self, sigma, model_output, model_input):
|
||||
sigma = (sigma / (sigma + 1))
|
||||
sigma = reshape_sigma(sigma, model_output.ndim)
|
||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
|
||||
return model_input * (1.0 - sigma) - model_output * sigma
|
||||
|
||||
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
|
||||
sigma = reshape_sigma(sigma, noise.ndim)
|
||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
|
||||
noise = noise * sigma
|
||||
noise += latent_image
|
||||
return noise
|
||||
|
||||
@@ -1,91 +0,0 @@
|
||||
import torch
|
||||
|
||||
class NestedTensor:
|
||||
def __init__(self, tensors):
|
||||
self.tensors = list(tensors)
|
||||
self.is_nested = True
|
||||
|
||||
def _copy(self):
|
||||
return NestedTensor(self.tensors)
|
||||
|
||||
def apply_operation(self, other, operation):
|
||||
o = self._copy()
|
||||
if isinstance(other, NestedTensor):
|
||||
for i, t in enumerate(o.tensors):
|
||||
o.tensors[i] = operation(t, other.tensors[i])
|
||||
else:
|
||||
for i, t in enumerate(o.tensors):
|
||||
o.tensors[i] = operation(t, other)
|
||||
return o
|
||||
|
||||
def __add__(self, b):
|
||||
return self.apply_operation(b, lambda x, y: x + y)
|
||||
|
||||
def __sub__(self, b):
|
||||
return self.apply_operation(b, lambda x, y: x - y)
|
||||
|
||||
def __mul__(self, b):
|
||||
return self.apply_operation(b, lambda x, y: x * y)
|
||||
|
||||
# def __itruediv__(self, b):
|
||||
# return self.apply_operation(b, lambda x, y: x / y)
|
||||
|
||||
def __truediv__(self, b):
|
||||
return self.apply_operation(b, lambda x, y: x / y)
|
||||
|
||||
def __getitem__(self, *args, **kwargs):
|
||||
return self.apply_operation(None, lambda x, y: x.__getitem__(*args, **kwargs))
|
||||
|
||||
def unbind(self):
|
||||
return self.tensors
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
o = self._copy()
|
||||
for i, t in enumerate(o.tensors):
|
||||
o.tensors[i] = t.to(*args, **kwargs)
|
||||
return o
|
||||
|
||||
def new_ones(self, *args, **kwargs):
|
||||
return self.tensors[0].new_ones(*args, **kwargs)
|
||||
|
||||
def float(self):
|
||||
return self.to(dtype=torch.float)
|
||||
|
||||
def chunk(self, *args, **kwargs):
|
||||
return self.apply_operation(None, lambda x, y: x.chunk(*args, **kwargs))
|
||||
|
||||
def size(self):
|
||||
return self.tensors[0].size()
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
return self.tensors[0].shape
|
||||
|
||||
@property
|
||||
def ndim(self):
|
||||
dims = 0
|
||||
for t in self.tensors:
|
||||
dims = max(t.ndim, dims)
|
||||
return dims
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return self.tensors[0].device
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self.tensors[0].dtype
|
||||
|
||||
@property
|
||||
def layout(self):
|
||||
return self.tensors[0].layout
|
||||
|
||||
|
||||
def cat_nested(tensors, *args, **kwargs):
|
||||
cated_tensors = []
|
||||
for i in range(len(tensors[0].tensors)):
|
||||
tens = []
|
||||
for j in range(len(tensors)):
|
||||
tens.append(tensors[j].tensors[i])
|
||||
cated_tensors.append(torch.cat(tens, *args, **kwargs))
|
||||
return NestedTensor(cated_tensors)
|
||||
554
comfy/ops.py
554
comfy/ops.py
@@ -22,20 +22,15 @@ import comfy.model_management
|
||||
from comfy.cli_args import args, PerformanceFeature
|
||||
import comfy.float
|
||||
import comfy.rmsnorm
|
||||
import json
|
||||
import contextlib
|
||||
|
||||
def run_every_op():
|
||||
if torch.compiler.is_compiling():
|
||||
return
|
||||
|
||||
comfy.model_management.throw_exception_if_processing_interrupted()
|
||||
|
||||
def scaled_dot_product_attention(q, k, v, *args, **kwargs):
|
||||
return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
|
||||
|
||||
|
||||
try:
|
||||
if torch.cuda.is_available() and comfy.model_management.WINDOWS:
|
||||
if torch.cuda.is_available():
|
||||
from torch.nn.attention import SDPBackend, sdpa_kernel
|
||||
import inspect
|
||||
if "set_priority" in inspect.signature(sdpa_kernel).parameters:
|
||||
@@ -55,92 +50,49 @@ try:
|
||||
except (ModuleNotFoundError, TypeError):
|
||||
logging.warning("Could not set sdpa backend priority.")
|
||||
|
||||
NVIDIA_MEMORY_CONV_BUG_WORKAROUND = False
|
||||
try:
|
||||
if comfy.model_management.is_nvidia():
|
||||
cudnn_version = torch.backends.cudnn.version()
|
||||
if (cudnn_version >= 91002 and cudnn_version < 91500) and comfy.model_management.torch_version_numeric >= (2, 9) and comfy.model_management.torch_version_numeric <= (2, 10):
|
||||
#TODO: change upper bound version once it's fixed'
|
||||
NVIDIA_MEMORY_CONV_BUG_WORKAROUND = True
|
||||
logging.info("working around nvidia conv3d memory bug.")
|
||||
except:
|
||||
pass
|
||||
|
||||
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
|
||||
|
||||
if torch.cuda.is_available() and torch.backends.cudnn.is_available() and PerformanceFeature.AutoTune in args.fast:
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
def cast_to_input(weight, input, non_blocking=False, copy=True):
|
||||
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
|
||||
|
||||
|
||||
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False):
|
||||
# NOTE: offloadable=False is a a legacy and if you are a custom node author reading this please pass
|
||||
# offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This
|
||||
# will add async-offload support to your cast and improve performance.
|
||||
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
|
||||
if input is not None:
|
||||
if dtype is None:
|
||||
if isinstance(input, QuantizedTensor):
|
||||
dtype = input.params.orig_dtype
|
||||
else:
|
||||
dtype = input.dtype
|
||||
dtype = input.dtype
|
||||
if bias_dtype is None:
|
||||
bias_dtype = dtype
|
||||
if device is None:
|
||||
device = input.device
|
||||
|
||||
if offloadable and (device != s.weight.device or
|
||||
(s.bias is not None and device != s.bias.device)):
|
||||
offload_stream = comfy.model_management.get_offload_stream(device)
|
||||
offload_stream = comfy.model_management.get_offload_stream(device)
|
||||
if offload_stream is not None:
|
||||
wf_context = offload_stream
|
||||
else:
|
||||
offload_stream = None
|
||||
|
||||
non_blocking = comfy.model_management.device_supports_non_blocking(device)
|
||||
|
||||
weight_has_function = len(s.weight_function) > 0
|
||||
bias_has_function = len(s.bias_function) > 0
|
||||
|
||||
weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream)
|
||||
wf_context = contextlib.nullcontext()
|
||||
|
||||
bias = None
|
||||
non_blocking = comfy.model_management.device_supports_non_blocking(device)
|
||||
if s.bias is not None:
|
||||
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream)
|
||||
has_function = len(s.bias_function) > 0
|
||||
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)
|
||||
|
||||
if has_function:
|
||||
with wf_context:
|
||||
for f in s.bias_function:
|
||||
bias = f(bias)
|
||||
|
||||
has_function = len(s.weight_function) > 0
|
||||
weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)
|
||||
if has_function:
|
||||
with wf_context:
|
||||
for f in s.weight_function:
|
||||
weight = f(weight)
|
||||
|
||||
comfy.model_management.sync_stream(device, offload_stream)
|
||||
|
||||
bias_a = bias
|
||||
weight_a = weight
|
||||
|
||||
if s.bias is not None:
|
||||
for f in s.bias_function:
|
||||
bias = f(bias)
|
||||
|
||||
if weight_has_function or weight.dtype != dtype:
|
||||
weight = weight.to(dtype=dtype)
|
||||
if isinstance(weight, QuantizedTensor):
|
||||
weight = weight.dequantize()
|
||||
for f in s.weight_function:
|
||||
weight = f(weight)
|
||||
|
||||
if offloadable:
|
||||
return weight, bias, (offload_stream, weight_a, bias_a)
|
||||
else:
|
||||
#Legacy function signature
|
||||
return weight, bias
|
||||
|
||||
|
||||
def uncast_bias_weight(s, weight, bias, offload_stream):
|
||||
if offload_stream is None:
|
||||
return
|
||||
os, weight_a, bias_a = offload_stream
|
||||
if os is None:
|
||||
return
|
||||
if weight_a is not None:
|
||||
device = weight_a.device
|
||||
else:
|
||||
if bias_a is None:
|
||||
return
|
||||
device = bias_a.device
|
||||
os.wait_stream(comfy.model_management.current_stream(device))
|
||||
|
||||
return weight, bias
|
||||
|
||||
class CastWeightBiasOp:
|
||||
comfy_cast_weights = False
|
||||
@@ -153,13 +105,10 @@ class disable_weight_init:
|
||||
return None
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
x = torch.nn.functional.linear(input, weight, bias)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
@@ -170,13 +119,10 @@ class disable_weight_init:
|
||||
return None
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
x = self._conv_forward(input, weight, bias)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return self._conv_forward(input, weight, bias)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
@@ -187,13 +133,10 @@ class disable_weight_init:
|
||||
return None
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
x = self._conv_forward(input, weight, bias)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return self._conv_forward(input, weight, bias)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
@@ -203,23 +146,11 @@ class disable_weight_init:
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
def _conv_forward(self, input, weight, bias, *args, **kwargs):
|
||||
if NVIDIA_MEMORY_CONV_BUG_WORKAROUND and weight.dtype in (torch.float16, torch.bfloat16):
|
||||
out = torch.cudnn_convolution(input, weight, self.padding, self.stride, self.dilation, self.groups, benchmark=False, deterministic=False, allow_tf32=True)
|
||||
if bias is not None:
|
||||
out += bias.reshape((1, -1) + (1,) * (out.ndim - 2))
|
||||
return out
|
||||
else:
|
||||
return super()._conv_forward(input, weight, bias, *args, **kwargs)
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
x = self._conv_forward(input, weight, bias)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return self._conv_forward(input, weight, bias)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
@@ -230,13 +161,10 @@ class disable_weight_init:
|
||||
return None
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
x = torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
@@ -248,17 +176,13 @@ class disable_weight_init:
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
if self.weight is not None:
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
else:
|
||||
weight = None
|
||||
bias = None
|
||||
offload_stream = None
|
||||
x = torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
@@ -271,18 +195,13 @@ class disable_weight_init:
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
if self.weight is not None:
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
else:
|
||||
weight = None
|
||||
bias = None
|
||||
offload_stream = None
|
||||
x = comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated
|
||||
# x = torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
return comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated
|
||||
# return torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
@@ -298,15 +217,12 @@ class disable_weight_init:
|
||||
input, output_size, self.stride, self.padding, self.kernel_size,
|
||||
num_spatial_dims, self.dilation)
|
||||
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
x = torch.nn.functional.conv_transpose2d(
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return torch.nn.functional.conv_transpose2d(
|
||||
input, weight, bias, self.stride, self.padding,
|
||||
output_padding, self.groups, self.dilation)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
@@ -322,15 +238,12 @@ class disable_weight_init:
|
||||
input, output_size, self.stride, self.padding, self.kernel_size,
|
||||
num_spatial_dims, self.dilation)
|
||||
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
x = torch.nn.functional.conv_transpose1d(
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return torch.nn.functional.conv_transpose1d(
|
||||
input, weight, bias, self.stride, self.padding,
|
||||
output_padding, self.groups, self.dilation)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
@@ -345,14 +258,10 @@ class disable_weight_init:
|
||||
output_dtype = out_dtype
|
||||
if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16:
|
||||
out_dtype = None
|
||||
weight, bias, offload_stream = cast_bias_weight(self, device=input.device, dtype=out_dtype, offloadable=True)
|
||||
x = torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
weight, bias = cast_bias_weight(self, device=input.device, dtype=out_dtype)
|
||||
return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
@@ -403,43 +312,50 @@ class manual_cast(disable_weight_init):
|
||||
|
||||
|
||||
def fp8_linear(self, input):
|
||||
"""
|
||||
Legacy FP8 linear function for backward compatibility.
|
||||
Uses QuantizedTensor subclass for dispatch.
|
||||
"""
|
||||
dtype = self.weight.dtype
|
||||
if dtype not in [torch.float8_e4m3fn]:
|
||||
return None
|
||||
|
||||
input_dtype = input.dtype
|
||||
tensor_2d = False
|
||||
if len(input.shape) == 2:
|
||||
tensor_2d = True
|
||||
input = input.unsqueeze(1)
|
||||
|
||||
input_shape = input.shape
|
||||
tensor_3d = input.ndim == 3
|
||||
input_dtype = input.dtype
|
||||
if len(input.shape) == 3:
|
||||
w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype)
|
||||
w = w.t()
|
||||
|
||||
if tensor_3d:
|
||||
input = input.reshape(-1, input_shape[2])
|
||||
scale_weight = self.scale_weight
|
||||
scale_input = self.scale_input
|
||||
if scale_weight is None:
|
||||
scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
|
||||
else:
|
||||
scale_weight = scale_weight.to(input.device)
|
||||
|
||||
if input.ndim != 2:
|
||||
return None
|
||||
w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
|
||||
scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
|
||||
if scale_input is None:
|
||||
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
|
||||
input = torch.clamp(input, min=-448, max=448, out=input)
|
||||
input = input.reshape(-1, input_shape[2]).to(dtype).contiguous()
|
||||
else:
|
||||
scale_input = scale_input.to(input.device)
|
||||
input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype).contiguous()
|
||||
|
||||
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
|
||||
input = torch.clamp(input, min=-448, max=448, out=input)
|
||||
input_fp8 = input.to(dtype).contiguous()
|
||||
layout_params_input = TensorCoreFP8Layout.Params(scale=scale_input, orig_dtype=input_dtype, orig_shape=tuple(input_fp8.shape))
|
||||
quantized_input = QuantizedTensor(input_fp8, "TensorCoreFP8Layout", layout_params_input)
|
||||
if bias is not None:
|
||||
o = torch._scaled_mm(input, w, out_dtype=input_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
|
||||
else:
|
||||
o = torch._scaled_mm(input, w, out_dtype=input_dtype, scale_a=scale_input, scale_b=scale_weight)
|
||||
|
||||
# Wrap weight in QuantizedTensor - this enables unified dispatch
|
||||
# Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
|
||||
layout_params_weight = TensorCoreFP8Layout.Params(scale=scale_weight, orig_dtype=input_dtype, orig_shape=tuple(w.shape))
|
||||
quantized_weight = QuantizedTensor(w, "TensorCoreFP8Layout", layout_params_weight)
|
||||
o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
|
||||
if isinstance(o, tuple):
|
||||
o = o[0]
|
||||
|
||||
uncast_bias_weight(self, w, bias, offload_stream)
|
||||
if tensor_3d:
|
||||
o = o.reshape((input_shape[0], input_shape[1], w.shape[0]))
|
||||
if tensor_2d:
|
||||
return o.reshape(input_shape[0], -1)
|
||||
|
||||
return o
|
||||
return o.reshape((-1, input_shape[1], self.weight.shape[0]))
|
||||
|
||||
return None
|
||||
|
||||
class fp8_ops(manual_cast):
|
||||
class Linear(manual_cast.Linear):
|
||||
@@ -449,7 +365,7 @@ class fp8_ops(manual_cast):
|
||||
return None
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
if len(self.weight_function) == 0 and len(self.bias_function) == 0:
|
||||
if not self.training:
|
||||
try:
|
||||
out = fp8_linear(self, input)
|
||||
if out is not None:
|
||||
@@ -457,10 +373,57 @@ class fp8_ops(manual_cast):
|
||||
except Exception as e:
|
||||
logging.info("Exception during fp8 op: {}".format(e))
|
||||
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
x = torch.nn.functional.linear(input, weight, bias)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
|
||||
def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
|
||||
logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input))
|
||||
class scaled_fp8_op(manual_cast):
|
||||
class Linear(manual_cast.Linear):
|
||||
def __init__(self, *args, **kwargs):
|
||||
if override_dtype is not None:
|
||||
kwargs['dtype'] = override_dtype
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def reset_parameters(self):
|
||||
if not hasattr(self, 'scale_weight'):
|
||||
self.scale_weight = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
|
||||
|
||||
if not scale_input:
|
||||
self.scale_input = None
|
||||
|
||||
if not hasattr(self, 'scale_input'):
|
||||
self.scale_input = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
|
||||
return None
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
if fp8_matrix_mult:
|
||||
out = fp8_linear(self, input)
|
||||
if out is not None:
|
||||
return out
|
||||
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
|
||||
if weight.numel() < input.numel(): #TODO: optimize
|
||||
return torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
|
||||
else:
|
||||
return torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
|
||||
|
||||
def convert_weight(self, weight, inplace=False, **kwargs):
|
||||
if inplace:
|
||||
weight *= self.scale_weight.to(device=weight.device, dtype=weight.dtype)
|
||||
return weight
|
||||
else:
|
||||
return weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype)
|
||||
|
||||
def set_weight(self, weight, inplace_update=False, seed=None, **kwargs):
|
||||
weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
|
||||
if inplace_update:
|
||||
self.weight.data.copy_(weight)
|
||||
else:
|
||||
self.weight = torch.nn.Parameter(weight, requires_grad=False)
|
||||
|
||||
return scaled_fp8_op
|
||||
|
||||
CUBLAS_IS_AVAILABLE = False
|
||||
try:
|
||||
@@ -481,253 +444,10 @@ if CUBLAS_IS_AVAILABLE:
|
||||
def forward(self, *args, **kwargs):
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# Mixed Precision Operations
|
||||
# ==============================================================================
|
||||
from .quant_ops import (
|
||||
QuantizedTensor,
|
||||
QUANT_ALGOS,
|
||||
TensorCoreFP8Layout,
|
||||
get_layout_class,
|
||||
)
|
||||
|
||||
|
||||
def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False, disabled=[]):
|
||||
class MixedPrecisionOps(manual_cast):
|
||||
_quant_config = quant_config
|
||||
_compute_dtype = compute_dtype
|
||||
_full_precision_mm = full_precision_mm
|
||||
_disabled = disabled
|
||||
|
||||
class Linear(torch.nn.Module, CastWeightBiasOp):
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
device=None,
|
||||
dtype=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
|
||||
# self.factory_kwargs = {"device": device, "dtype": dtype}
|
||||
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
if bias:
|
||||
self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
self.tensor_class = None
|
||||
self._full_precision_mm = MixedPrecisionOps._full_precision_mm
|
||||
self._full_precision_mm_config = False
|
||||
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
def _load_scale_param(self, state_dict, prefix, param_name, device, manually_loaded_keys, dtype=None):
|
||||
key = f"{prefix}{param_name}"
|
||||
value = state_dict.pop(key, None)
|
||||
if value is not None:
|
||||
value = value.to(device=device)
|
||||
if dtype is not None:
|
||||
value = value.view(dtype=dtype)
|
||||
manually_loaded_keys.append(key)
|
||||
return value
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
|
||||
strict, missing_keys, unexpected_keys, error_msgs):
|
||||
|
||||
device = self.factory_kwargs["device"]
|
||||
layer_name = prefix.rstrip('.')
|
||||
weight_key = f"{prefix}weight"
|
||||
weight = state_dict.pop(weight_key, None)
|
||||
if weight is None:
|
||||
raise ValueError(f"Missing weight for layer {layer_name}")
|
||||
|
||||
manually_loaded_keys = [weight_key]
|
||||
|
||||
layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
|
||||
if layer_conf is not None:
|
||||
layer_conf = json.loads(layer_conf.numpy().tobytes())
|
||||
|
||||
if layer_conf is None:
|
||||
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
|
||||
else:
|
||||
self.quant_format = layer_conf.get("format", None)
|
||||
self._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False)
|
||||
if not self._full_precision_mm:
|
||||
self._full_precision_mm = self._full_precision_mm_config
|
||||
|
||||
if self.quant_format in MixedPrecisionOps._disabled:
|
||||
self._full_precision_mm = True
|
||||
|
||||
if self.quant_format is None:
|
||||
raise ValueError(f"Unknown quantization format for layer {layer_name}")
|
||||
|
||||
qconfig = QUANT_ALGOS[self.quant_format]
|
||||
self.layout_type = qconfig["comfy_tensor_layout"]
|
||||
layout_cls = get_layout_class(self.layout_type)
|
||||
|
||||
# Load format-specific parameters
|
||||
if self.quant_format in ["float8_e4m3fn", "float8_e5m2"]:
|
||||
# FP8: single tensor scale
|
||||
scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys)
|
||||
|
||||
params = layout_cls.Params(
|
||||
scale=scale,
|
||||
orig_dtype=MixedPrecisionOps._compute_dtype,
|
||||
orig_shape=(self.out_features, self.in_features),
|
||||
)
|
||||
|
||||
elif self.quant_format == "nvfp4":
|
||||
# NVFP4: tensor_scale (weight_scale_2) + block_scale (weight_scale)
|
||||
tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys)
|
||||
block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys,
|
||||
dtype=torch.float8_e4m3fn)
|
||||
|
||||
if tensor_scale is None or block_scale is None:
|
||||
raise ValueError(f"Missing NVFP4 scales for layer {layer_name}")
|
||||
|
||||
params = layout_cls.Params(
|
||||
scale=tensor_scale,
|
||||
block_scale=block_scale,
|
||||
orig_dtype=MixedPrecisionOps._compute_dtype,
|
||||
orig_shape=(self.out_features, self.in_features),
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported quantization format: {self.quant_format}")
|
||||
|
||||
self.weight = torch.nn.Parameter(
|
||||
QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), self.layout_type, params),
|
||||
requires_grad=False
|
||||
)
|
||||
|
||||
for param_name in qconfig["parameters"]:
|
||||
if param_name in {"weight_scale", "weight_scale_2"}:
|
||||
continue # Already handled above
|
||||
|
||||
param_key = f"{prefix}{param_name}"
|
||||
_v = state_dict.pop(param_key, None)
|
||||
if _v is None:
|
||||
continue
|
||||
self.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
|
||||
manually_loaded_keys.append(param_key)
|
||||
|
||||
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||
|
||||
for key in manually_loaded_keys:
|
||||
if key in missing_keys:
|
||||
missing_keys.remove(key)
|
||||
|
||||
def state_dict(self, *args, destination=None, prefix="", **kwargs):
|
||||
sd = super().state_dict(*args, destination=destination, prefix=prefix, **kwargs)
|
||||
if isinstance(self.weight, QuantizedTensor):
|
||||
layout_cls = self.weight._layout_cls
|
||||
|
||||
# Check if it's any FP8 variant (E4M3 or E5M2)
|
||||
if layout_cls in ("TensorCoreFP8E4M3Layout", "TensorCoreFP8E5M2Layout", "TensorCoreFP8Layout"):
|
||||
sd["{}weight_scale".format(prefix)] = self.weight._params.scale
|
||||
elif layout_cls == "TensorCoreNVFP4Layout":
|
||||
sd["{}weight_scale_2".format(prefix)] = self.weight._params.scale
|
||||
sd["{}weight_scale".format(prefix)] = self.weight._params.block_scale
|
||||
|
||||
quant_conf = {"format": self.quant_format}
|
||||
if self._full_precision_mm_config:
|
||||
quant_conf["full_precision_matrix_mult"] = True
|
||||
sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
|
||||
return sd
|
||||
|
||||
def _forward(self, input, weight, bias):
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
x = self._forward(input, weight, bias)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
def forward(self, input, *args, **kwargs):
|
||||
run_every_op()
|
||||
|
||||
input_shape = input.shape
|
||||
tensor_3d = input.ndim == 3
|
||||
|
||||
if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
return self.forward_comfy_cast_weights(input, *args, **kwargs)
|
||||
|
||||
if (getattr(self, 'layout_type', None) is not None and
|
||||
not isinstance(input, QuantizedTensor)):
|
||||
|
||||
# Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others)
|
||||
if tensor_3d:
|
||||
input = input.reshape(-1, input_shape[2])
|
||||
|
||||
if input.ndim != 2:
|
||||
# Fall back to comfy_cast_weights for non-2D tensors
|
||||
return self.forward_comfy_cast_weights(input.reshape(input_shape), *args, **kwargs)
|
||||
|
||||
# dtype is now implicit in the layout class
|
||||
input = QuantizedTensor.from_float(input, self.layout_type, scale=getattr(self, 'input_scale', None))
|
||||
|
||||
output = self._forward(input, self.weight, self.bias)
|
||||
|
||||
# Reshape output back to 3D if input was 3D
|
||||
if tensor_3d:
|
||||
output = output.reshape((input_shape[0], input_shape[1], self.weight.shape[0]))
|
||||
|
||||
return output
|
||||
|
||||
def convert_weight(self, weight, inplace=False, **kwargs):
|
||||
if isinstance(weight, QuantizedTensor):
|
||||
return weight.dequantize()
|
||||
else:
|
||||
return weight
|
||||
|
||||
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
|
||||
if getattr(self, 'layout_type', None) is not None:
|
||||
# dtype is now implicit in the layout class
|
||||
weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", stochastic_rounding=seed, inplace_ops=True)
|
||||
else:
|
||||
weight = weight.to(self.weight.dtype)
|
||||
if return_weight:
|
||||
return weight
|
||||
|
||||
assert inplace_update is False # TODO: eventually remove the inplace_update stuff
|
||||
self.weight = torch.nn.Parameter(weight, requires_grad=False)
|
||||
|
||||
def _apply(self, fn, recurse=True): # This is to get torch.compile + moving weights to another device working
|
||||
if recurse:
|
||||
for module in self.children():
|
||||
module._apply(fn)
|
||||
|
||||
for key, param in self._parameters.items():
|
||||
if param is None:
|
||||
continue
|
||||
self.register_parameter(key, torch.nn.Parameter(fn(param), requires_grad=False))
|
||||
for key, buf in self._buffers.items():
|
||||
if buf is not None:
|
||||
self._buffers[key] = fn(buf)
|
||||
return self
|
||||
|
||||
return MixedPrecisionOps
|
||||
|
||||
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
|
||||
fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular
|
||||
nvfp4_compute = comfy.model_management.supports_nvfp4_compute(load_device)
|
||||
|
||||
if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config:
|
||||
logging.info("Using mixed precision operations")
|
||||
disabled = set()
|
||||
if not nvfp4_compute:
|
||||
disabled.add("nvfp4")
|
||||
if not fp8_compute:
|
||||
disabled.add("float8_e4m3fn")
|
||||
disabled.add("float8_e5m2")
|
||||
return mixed_precision_ops(model_config.quant_config, compute_dtype, disabled=disabled)
|
||||
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None):
|
||||
fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
|
||||
if scaled_fp8 is not None:
|
||||
return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)
|
||||
|
||||
if (
|
||||
fp8_compute and
|
||||
|
||||
@@ -150,7 +150,7 @@ def merge_nested_dicts(dict1: dict, dict2: dict, copy_dict1=True):
|
||||
for key, value in dict2.items():
|
||||
if isinstance(value, dict):
|
||||
curr_value = merged_dict.setdefault(key, {})
|
||||
merged_dict[key] = merge_nested_dicts(curr_value, value)
|
||||
merged_dict[key] = merge_nested_dicts(value, curr_value)
|
||||
elif isinstance(value, list):
|
||||
merged_dict.setdefault(key, []).extend(value)
|
||||
else:
|
||||
|
||||
@@ -1,140 +0,0 @@
|
||||
import torch
|
||||
import logging
|
||||
|
||||
try:
|
||||
import comfy_kitchen as ck
|
||||
from comfy_kitchen.tensor import (
|
||||
QuantizedTensor,
|
||||
QuantizedLayout,
|
||||
TensorCoreFP8Layout as _CKFp8Layout,
|
||||
TensorCoreNVFP4Layout, # Direct import, no wrapper needed
|
||||
register_layout_op,
|
||||
register_layout_class,
|
||||
get_layout_class,
|
||||
)
|
||||
_CK_AVAILABLE = True
|
||||
if torch.version.cuda is None:
|
||||
ck.registry.disable("cuda")
|
||||
else:
|
||||
cuda_version = tuple(map(int, str(torch.version.cuda).split('.')))
|
||||
if cuda_version < (13,):
|
||||
ck.registry.disable("cuda")
|
||||
|
||||
ck.registry.disable("triton")
|
||||
for k, v in ck.list_backends().items():
|
||||
logging.info(f"Found comfy_kitchen backend {k}: {v}")
|
||||
except ImportError as e:
|
||||
logging.error(f"Failed to import comfy_kitchen, Error: {e}, fp8 and fp4 support will not be available.")
|
||||
_CK_AVAILABLE = False
|
||||
|
||||
class QuantizedTensor:
|
||||
pass
|
||||
|
||||
class _CKFp8Layout:
|
||||
pass
|
||||
|
||||
class TensorCoreNVFP4Layout:
|
||||
pass
|
||||
|
||||
def register_layout_class(name, cls):
|
||||
pass
|
||||
|
||||
def get_layout_class(name):
|
||||
return None
|
||||
|
||||
import comfy.float
|
||||
|
||||
# ==============================================================================
|
||||
# FP8 Layouts with Comfy-Specific Extensions
|
||||
# ==============================================================================
|
||||
|
||||
class _TensorCoreFP8LayoutBase(_CKFp8Layout):
|
||||
FP8_DTYPE = None # Must be overridden in subclass
|
||||
|
||||
@classmethod
|
||||
def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
|
||||
if cls.FP8_DTYPE is None:
|
||||
raise NotImplementedError(f"{cls.__name__} must define FP8_DTYPE")
|
||||
|
||||
orig_dtype = tensor.dtype
|
||||
orig_shape = tuple(tensor.shape)
|
||||
|
||||
if isinstance(scale, str) and scale == "recalculate":
|
||||
scale = torch.amax(tensor.abs()).to(dtype=torch.float32) / torch.finfo(cls.FP8_DTYPE).max
|
||||
if tensor.dtype not in [torch.float32, torch.bfloat16]: # Prevent scale from being too small
|
||||
tensor_info = torch.finfo(tensor.dtype)
|
||||
scale = (1.0 / torch.clamp((1.0 / scale), min=tensor_info.min, max=tensor_info.max))
|
||||
|
||||
if scale is None:
|
||||
scale = torch.ones((), device=tensor.device, dtype=torch.float32)
|
||||
if not isinstance(scale, torch.Tensor):
|
||||
scale = torch.tensor(scale, device=tensor.device, dtype=torch.float32)
|
||||
|
||||
if stochastic_rounding > 0:
|
||||
if inplace_ops:
|
||||
tensor *= (1.0 / scale).to(tensor.dtype)
|
||||
else:
|
||||
tensor = tensor * (1.0 / scale).to(tensor.dtype)
|
||||
qdata = comfy.float.stochastic_rounding(tensor, dtype=cls.FP8_DTYPE, seed=stochastic_rounding)
|
||||
else:
|
||||
qdata = ck.quantize_per_tensor_fp8(tensor, scale, cls.FP8_DTYPE)
|
||||
|
||||
params = cls.Params(scale=scale.float(), orig_dtype=orig_dtype, orig_shape=orig_shape)
|
||||
return qdata, params
|
||||
|
||||
|
||||
class TensorCoreFP8E4M3Layout(_TensorCoreFP8LayoutBase):
|
||||
FP8_DTYPE = torch.float8_e4m3fn
|
||||
|
||||
|
||||
class TensorCoreFP8E5M2Layout(_TensorCoreFP8LayoutBase):
|
||||
FP8_DTYPE = torch.float8_e5m2
|
||||
|
||||
|
||||
# Backward compatibility alias - default to E4M3
|
||||
TensorCoreFP8Layout = TensorCoreFP8E4M3Layout
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# Registry
|
||||
# ==============================================================================
|
||||
|
||||
register_layout_class("TensorCoreFP8Layout", TensorCoreFP8Layout)
|
||||
register_layout_class("TensorCoreFP8E4M3Layout", TensorCoreFP8E4M3Layout)
|
||||
register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout)
|
||||
register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout)
|
||||
|
||||
QUANT_ALGOS = {
|
||||
"float8_e4m3fn": {
|
||||
"storage_t": torch.float8_e4m3fn,
|
||||
"parameters": {"weight_scale", "input_scale"},
|
||||
"comfy_tensor_layout": "TensorCoreFP8E4M3Layout",
|
||||
},
|
||||
"float8_e5m2": {
|
||||
"storage_t": torch.float8_e5m2,
|
||||
"parameters": {"weight_scale", "input_scale"},
|
||||
"comfy_tensor_layout": "TensorCoreFP8E5M2Layout",
|
||||
},
|
||||
"nvfp4": {
|
||||
"storage_t": torch.uint8,
|
||||
"parameters": {"weight_scale", "weight_scale_2", "input_scale"},
|
||||
"comfy_tensor_layout": "TensorCoreNVFP4Layout",
|
||||
"group_size": 16,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# Re-exports for backward compatibility
|
||||
# ==============================================================================
|
||||
|
||||
__all__ = [
|
||||
"QuantizedTensor",
|
||||
"QuantizedLayout",
|
||||
"TensorCoreFP8Layout",
|
||||
"TensorCoreFP8E4M3Layout",
|
||||
"TensorCoreFP8E5M2Layout",
|
||||
"TensorCoreNVFP4Layout",
|
||||
"QUANT_ALGOS",
|
||||
"register_layout_op",
|
||||
]
|
||||
@@ -4,9 +4,13 @@ import comfy.samplers
|
||||
import comfy.utils
|
||||
import numpy as np
|
||||
import logging
|
||||
import comfy.nested_tensor
|
||||
|
||||
def prepare_noise_inner(latent_image, generator, noise_inds=None):
|
||||
def prepare_noise(latent_image, seed, noise_inds=None):
|
||||
"""
|
||||
creates random noise given a latent image and a seed.
|
||||
optional arg skip can be used to skip and discard x number of noise generations for a given seed
|
||||
"""
|
||||
generator = torch.manual_seed(seed)
|
||||
if noise_inds is None:
|
||||
return torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
|
||||
|
||||
@@ -17,29 +21,10 @@ def prepare_noise_inner(latent_image, generator, noise_inds=None):
|
||||
if i in unique_inds:
|
||||
noises.append(noise)
|
||||
noises = [noises[i] for i in inverse]
|
||||
return torch.cat(noises, axis=0)
|
||||
|
||||
def prepare_noise(latent_image, seed, noise_inds=None):
|
||||
"""
|
||||
creates random noise given a latent image and a seed.
|
||||
optional arg skip can be used to skip and discard x number of noise generations for a given seed
|
||||
"""
|
||||
generator = torch.manual_seed(seed)
|
||||
|
||||
if latent_image.is_nested:
|
||||
tensors = latent_image.unbind()
|
||||
noises = []
|
||||
for t in tensors:
|
||||
noises.append(prepare_noise_inner(t, generator, noise_inds))
|
||||
noises = comfy.nested_tensor.NestedTensor(noises)
|
||||
else:
|
||||
noises = prepare_noise_inner(latent_image, generator, noise_inds)
|
||||
|
||||
noises = torch.cat(noises, axis=0)
|
||||
return noises
|
||||
|
||||
def fix_empty_latent_channels(model, latent_image):
|
||||
if latent_image.is_nested:
|
||||
return latent_image
|
||||
latent_format = model.get_model_object("latent_format") #Resize the empty latent image so it has the right number of channels
|
||||
if latent_format.latent_channels != latent_image.shape[1] and torch.count_nonzero(latent_image) == 0:
|
||||
latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1)
|
||||
|
||||
@@ -122,20 +122,20 @@ def estimate_memory(model, noise_shape, conds):
|
||||
minimum_memory_required = model.model.memory_required([noise_shape[0]] + list(noise_shape[1:]), cond_shapes=cond_shapes_min)
|
||||
return memory_required, minimum_memory_required
|
||||
|
||||
def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False):
|
||||
def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
|
||||
executor = comfy.patcher_extension.WrapperExecutor.new_executor(
|
||||
_prepare_sampling,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING, model_options, is_model_options=True)
|
||||
)
|
||||
return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load)
|
||||
return executor.execute(model, noise_shape, conds, model_options=model_options)
|
||||
|
||||
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False):
|
||||
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
|
||||
real_model: BaseModel = None
|
||||
models, inference_memory = get_additional_models(conds, model.model_dtype())
|
||||
models += get_additional_models_from_model_options(model_options)
|
||||
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
|
||||
memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
|
||||
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory, force_full_load=force_full_load)
|
||||
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory)
|
||||
real_model = model.model
|
||||
|
||||
return real_model, conds, models
|
||||
|
||||
@@ -306,10 +306,17 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
|
||||
copy_dict1=False)
|
||||
|
||||
if patches is not None:
|
||||
transformer_options["patches"] = comfy.patcher_extension.merge_nested_dicts(
|
||||
transformer_options.get("patches", {}),
|
||||
patches
|
||||
)
|
||||
# TODO: replace with merge_nested_dicts function
|
||||
if "patches" in transformer_options:
|
||||
cur_patches = transformer_options["patches"].copy()
|
||||
for p in patches:
|
||||
if p in cur_patches:
|
||||
cur_patches[p] = cur_patches[p] + patches[p]
|
||||
else:
|
||||
cur_patches[p] = patches[p]
|
||||
transformer_options["patches"] = cur_patches
|
||||
else:
|
||||
transformer_options["patches"] = patches
|
||||
|
||||
transformer_options["cond_or_uncond"] = cond_or_uncond[:]
|
||||
transformer_options["uuids"] = uuids[:]
|
||||
@@ -353,7 +360,7 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
|
||||
def cfg_function(model, cond_pred, uncond_pred, cond_scale, x, timestep, model_options={}, cond=None, uncond=None):
|
||||
if "sampler_cfg_function" in model_options:
|
||||
args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep,
|
||||
"cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options, "input_cond": cond, "input_uncond": uncond}
|
||||
"cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options}
|
||||
cfg_result = x - model_options["sampler_cfg_function"](args)
|
||||
else:
|
||||
cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale
|
||||
@@ -383,7 +390,7 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
|
||||
for fn in model_options.get("sampler_pre_cfg_function", []):
|
||||
args = {"conds":conds, "conds_out": out, "cond_scale": cond_scale, "timestep": timestep,
|
||||
"input": x, "sigma": timestep, "model": model, "model_options": model_options}
|
||||
out = fn(args)
|
||||
out = fn(args)
|
||||
|
||||
return cfg_function(model, out[0], out[1], cond_scale, x, timestep, model_options=model_options, cond=cond, uncond=uncond_)
|
||||
|
||||
@@ -720,7 +727,7 @@ class Sampler:
|
||||
sigma = float(sigmas[0])
|
||||
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
|
||||
|
||||
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2", "exp_heun_2_x0", "exp_heun_2_x0_sde", "dpm_2", "dpm_2_ancestral",
|
||||
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
|
||||
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
|
||||
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_2m_sde_heun", "dpmpp_2m_sde_heun_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
|
||||
"ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
|
||||
@@ -782,7 +789,7 @@ def ksampler(sampler_name, extra_options={}, inpaint_options={}):
|
||||
return KSAMPLER(sampler_function, extra_options, inpaint_options)
|
||||
|
||||
|
||||
def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=None, seed=None, latent_shapes=None):
|
||||
def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=None, seed=None):
|
||||
for k in conds:
|
||||
conds[k] = conds[k][:]
|
||||
resolve_areas_and_cond_masks_multidim(conds[k], noise.shape[2:], device)
|
||||
@@ -792,7 +799,7 @@ def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=N
|
||||
|
||||
if hasattr(model, 'extra_conds'):
|
||||
for k in conds:
|
||||
conds[k] = encode_model_conds(model.extra_conds, conds[k], noise, device, k, latent_image=latent_image, denoise_mask=denoise_mask, seed=seed, latent_shapes=latent_shapes)
|
||||
conds[k] = encode_model_conds(model.extra_conds, conds[k], noise, device, k, latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)
|
||||
|
||||
#make sure each cond area has an opposite one with the same area
|
||||
for k in conds:
|
||||
@@ -962,11 +969,11 @@ class CFGGuider:
|
||||
def predict_noise(self, x, timestep, model_options={}, seed=None):
|
||||
return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None), self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed)
|
||||
|
||||
def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=None):
|
||||
def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed):
|
||||
if latent_image is not None and torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image.
|
||||
latent_image = self.inner_model.process_latent_in(latent_image)
|
||||
|
||||
self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed, latent_shapes=latent_shapes)
|
||||
self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed)
|
||||
|
||||
extra_model_options = comfy.model_patcher.create_model_options_clone(self.model_options)
|
||||
extra_model_options.setdefault("transformer_options", {})["sample_sigmas"] = sigmas
|
||||
@@ -980,10 +987,13 @@ class CFGGuider:
|
||||
samples = executor.execute(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
|
||||
return self.inner_model.process_latent_out(samples.to(torch.float32))
|
||||
|
||||
def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None, latent_shapes=None):
|
||||
def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
|
||||
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
|
||||
device = self.model_patcher.load_device
|
||||
|
||||
if denoise_mask is not None:
|
||||
denoise_mask = comfy.sampler_helpers.prepare_mask(denoise_mask, noise.shape, device)
|
||||
|
||||
noise = noise.to(device)
|
||||
latent_image = latent_image.to(device)
|
||||
sigmas = sigmas.to(device)
|
||||
@@ -991,7 +1001,7 @@ class CFGGuider:
|
||||
|
||||
try:
|
||||
self.model_patcher.pre_run()
|
||||
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
|
||||
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
|
||||
finally:
|
||||
self.model_patcher.cleanup()
|
||||
|
||||
@@ -1004,30 +1014,6 @@ class CFGGuider:
|
||||
if sigmas.shape[-1] == 0:
|
||||
return latent_image
|
||||
|
||||
if latent_image.is_nested:
|
||||
latent_image, latent_shapes = comfy.utils.pack_latents(latent_image.unbind())
|
||||
noise, _ = comfy.utils.pack_latents(noise.unbind())
|
||||
else:
|
||||
latent_shapes = [latent_image.shape]
|
||||
|
||||
if denoise_mask is not None:
|
||||
if denoise_mask.is_nested:
|
||||
denoise_masks = denoise_mask.unbind()
|
||||
denoise_masks = denoise_masks[:len(latent_shapes)]
|
||||
else:
|
||||
denoise_masks = [denoise_mask]
|
||||
|
||||
for i in range(len(denoise_masks), len(latent_shapes)):
|
||||
denoise_masks.append(torch.ones(latent_shapes[i]))
|
||||
|
||||
for i in range(len(denoise_masks)):
|
||||
denoise_masks[i] = comfy.sampler_helpers.prepare_mask(denoise_masks[i], latent_shapes[i], self.model_patcher.load_device)
|
||||
|
||||
if len(denoise_masks) > 1:
|
||||
denoise_mask, _ = comfy.utils.pack_latents(denoise_masks)
|
||||
else:
|
||||
denoise_mask = denoise_masks[0]
|
||||
|
||||
self.conds = {}
|
||||
for k in self.original_conds:
|
||||
self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k]))
|
||||
@@ -1047,7 +1033,7 @@ class CFGGuider:
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, self.model_options, is_model_options=True)
|
||||
)
|
||||
output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
|
||||
output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
|
||||
finally:
|
||||
cast_to_load_options(self.model_options, device=self.model_patcher.offload_device)
|
||||
self.model_options = orig_model_options
|
||||
@@ -1055,9 +1041,6 @@ class CFGGuider:
|
||||
self.model_patcher.restore_hook_patches()
|
||||
|
||||
del self.conds
|
||||
|
||||
if len(latent_shapes) > 1:
|
||||
output = comfy.nested_tensor.NestedTensor(comfy.utils.unpack_latents(output, latent_shapes))
|
||||
return output
|
||||
|
||||
|
||||
|
||||
427
comfy/sd.py
427
comfy/sd.py
@@ -18,7 +18,6 @@ import comfy.ldm.wan.vae2_2
|
||||
import comfy.ldm.hunyuan3d.vae
|
||||
import comfy.ldm.ace.vae.music_dcae_pipeline
|
||||
import comfy.ldm.hunyuan_video.vae
|
||||
import comfy.ldm.mmaudio.vae.autoencoder
|
||||
import comfy.pixel_space_convert
|
||||
import yaml
|
||||
import math
|
||||
@@ -52,11 +51,6 @@ import comfy.text_encoders.ace
|
||||
import comfy.text_encoders.omnigen2
|
||||
import comfy.text_encoders.qwen_image
|
||||
import comfy.text_encoders.hunyuan_image
|
||||
import comfy.text_encoders.z_image
|
||||
import comfy.text_encoders.ovis
|
||||
import comfy.text_encoders.kandinsky5
|
||||
import comfy.text_encoders.jina_clip_2
|
||||
import comfy.text_encoders.newbie
|
||||
|
||||
import comfy.model_patcher
|
||||
import comfy.lora
|
||||
@@ -64,8 +58,6 @@ import comfy.lora_convert
|
||||
import comfy.hooks
|
||||
import comfy.t2i_adapter.adapter
|
||||
import comfy.taesd.taesd
|
||||
import comfy.taesd.taehv
|
||||
import comfy.latent_formats
|
||||
|
||||
import comfy.ldm.flux.redux
|
||||
|
||||
@@ -101,7 +93,7 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
|
||||
|
||||
|
||||
class CLIP:
|
||||
def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, state_dict=[], model_options={}):
|
||||
def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, model_options={}):
|
||||
if no_init:
|
||||
return
|
||||
params = target.params.copy()
|
||||
@@ -129,32 +121,9 @@ class CLIP:
|
||||
|
||||
self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||
self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
|
||||
#Match torch.float32 hardcode upcast in TE implemention
|
||||
self.patcher.set_model_compute_dtype(torch.float32)
|
||||
self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
|
||||
self.patcher.is_clip = True
|
||||
self.apply_hooks_to_conds = None
|
||||
if len(state_dict) > 0:
|
||||
if isinstance(state_dict, list):
|
||||
for c in state_dict:
|
||||
m, u = self.load_sd(c)
|
||||
if len(m) > 0:
|
||||
logging.warning("clip missing: {}".format(m))
|
||||
|
||||
if len(u) > 0:
|
||||
logging.debug("clip unexpected: {}".format(u))
|
||||
else:
|
||||
m, u = self.load_sd(state_dict, full_model=True)
|
||||
if len(m) > 0:
|
||||
m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
|
||||
if len(m_filter) > 0:
|
||||
logging.warning("clip missing: {}".format(m))
|
||||
else:
|
||||
logging.debug("clip missing: {}".format(m))
|
||||
|
||||
if len(u) > 0:
|
||||
logging.debug("clip unexpected {}:".format(u))
|
||||
|
||||
if params['device'] == load_device:
|
||||
model_management.load_models_gpu([self.patcher], force_full_load=True)
|
||||
self.layer_idx = None
|
||||
@@ -173,9 +142,6 @@ class CLIP:
|
||||
n.apply_hooks_to_conds = self.apply_hooks_to_conds
|
||||
return n
|
||||
|
||||
def get_ram_usage(self):
|
||||
return self.patcher.get_ram_usage()
|
||||
|
||||
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
|
||||
return self.patcher.add_patches(patches, strength_patch, strength_model)
|
||||
|
||||
@@ -219,7 +185,6 @@ class CLIP:
|
||||
self.cond_stage_model.set_clip_options({"projected_pooled": False})
|
||||
|
||||
self.load_model()
|
||||
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
|
||||
all_hooks.reset()
|
||||
self.patcher.patch_hooks(None)
|
||||
if show_pbar:
|
||||
@@ -267,7 +232,6 @@ class CLIP:
|
||||
self.cond_stage_model.set_clip_options({"projected_pooled": False})
|
||||
|
||||
self.load_model()
|
||||
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
|
||||
o = self.cond_stage_model.encode_token_weights(tokens)
|
||||
cond, pooled = o[:2]
|
||||
if return_dict:
|
||||
@@ -311,30 +275,22 @@ class VAE:
|
||||
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
|
||||
sd = diffusers_convert.convert_vae_state_dict(sd)
|
||||
|
||||
if model_management.is_amd():
|
||||
VAE_KL_MEM_RATIO = 2.73
|
||||
else:
|
||||
VAE_KL_MEM_RATIO = 1.0
|
||||
|
||||
self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) * VAE_KL_MEM_RATIO #These are for AutoencoderKL and need tweaking (should be lower)
|
||||
self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype) * VAE_KL_MEM_RATIO
|
||||
self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower)
|
||||
self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype)
|
||||
self.downscale_ratio = 8
|
||||
self.upscale_ratio = 8
|
||||
self.latent_channels = 4
|
||||
self.latent_dim = 2
|
||||
self.output_channels = 3
|
||||
self.pad_channel_value = None
|
||||
self.process_input = lambda image: image * 2.0 - 1.0
|
||||
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
self.working_dtypes = [torch.bfloat16, torch.float32]
|
||||
self.disable_offload = False
|
||||
self.not_video = False
|
||||
self.size = None
|
||||
|
||||
self.downscale_index_formula = None
|
||||
self.upscale_index_formula = None
|
||||
self.extra_1d_channel = None
|
||||
self.crop_input = True
|
||||
|
||||
if config is None:
|
||||
if "decoder.mid.block_1.mix_factor" in sd:
|
||||
@@ -376,69 +332,41 @@ class VAE:
|
||||
self.first_stage_model = StageC_coder()
|
||||
self.downscale_ratio = 32
|
||||
self.latent_channels = 16
|
||||
elif "decoder.conv_in.weight" in sd and sd['decoder.conv_in.weight'].shape[1] == 64:
|
||||
ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}
|
||||
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
|
||||
self.downscale_ratio = 32
|
||||
self.upscale_ratio = 32
|
||||
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
|
||||
encoder_config={'target': "comfy.ldm.hunyuan_video.vae.Encoder", 'params': ddconfig},
|
||||
decoder_config={'target': "comfy.ldm.hunyuan_video.vae.Decoder", 'params': ddconfig})
|
||||
|
||||
self.memory_used_encode = lambda shape, dtype: (700 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
|
||||
self.memory_used_decode = lambda shape, dtype: (700 * shape[2] * shape[3] * 32 * 32) * model_management.dtype_size(dtype)
|
||||
|
||||
elif "decoder.conv_in.weight" in sd:
|
||||
if sd['decoder.conv_in.weight'].shape[1] == 64:
|
||||
ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}
|
||||
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
|
||||
self.downscale_ratio = 32
|
||||
self.upscale_ratio = 32
|
||||
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
|
||||
encoder_config={'target': "comfy.ldm.hunyuan_video.vae.Encoder", 'params': ddconfig},
|
||||
decoder_config={'target': "comfy.ldm.hunyuan_video.vae.Decoder", 'params': ddconfig})
|
||||
#default SD1.x/SD2.x VAE parameters
|
||||
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
|
||||
|
||||
self.memory_used_encode = lambda shape, dtype: (700 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
|
||||
self.memory_used_decode = lambda shape, dtype: (700 * shape[2] * shape[3] * 32 * 32) * model_management.dtype_size(dtype)
|
||||
elif sd['decoder.conv_in.weight'].shape[1] == 32 and sd['decoder.conv_in.weight'].ndim == 5:
|
||||
ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True, "refiner_vae": False}
|
||||
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
|
||||
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
|
||||
self.upscale_index_formula = (4, 16, 16)
|
||||
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16)
|
||||
self.downscale_index_formula = (4, 16, 16)
|
||||
self.latent_dim = 3
|
||||
self.not_video = True
|
||||
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
|
||||
encoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Encoder", 'params': ddconfig},
|
||||
decoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Decoder", 'params': ddconfig})
|
||||
if 'encoder.down.2.downsample.conv.weight' not in sd and 'decoder.up.3.upsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE
|
||||
ddconfig['ch_mult'] = [1, 2, 4]
|
||||
self.downscale_ratio = 4
|
||||
self.upscale_ratio = 4
|
||||
|
||||
self.memory_used_encode = lambda shape, dtype: (2800 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)
|
||||
self.memory_used_decode = lambda shape, dtype: (2800 * shape[-3] * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)
|
||||
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
|
||||
if 'post_quant_conv.weight' in sd:
|
||||
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
|
||||
else:
|
||||
#default SD1.x/SD2.x VAE parameters
|
||||
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
|
||||
|
||||
if 'encoder.down.2.downsample.conv.weight' not in sd and 'decoder.up.3.upsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE
|
||||
ddconfig['ch_mult'] = [1, 2, 4]
|
||||
self.downscale_ratio = 4
|
||||
self.upscale_ratio = 4
|
||||
|
||||
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
|
||||
if 'decoder.post_quant_conv.weight' in sd:
|
||||
sd = comfy.utils.state_dict_prefix_replace(sd, {"decoder.post_quant_conv.": "post_quant_conv.", "encoder.quant_conv.": "quant_conv."})
|
||||
|
||||
if 'bn.running_mean' in sd:
|
||||
ddconfig["batch_norm_latent"] = True
|
||||
self.downscale_ratio *= 2
|
||||
self.upscale_ratio *= 2
|
||||
self.latent_channels *= 4
|
||||
old_memory_used_decode = self.memory_used_decode
|
||||
self.memory_used_decode = lambda shape, dtype: old_memory_used_decode(shape, dtype) * 4.0
|
||||
|
||||
if 'post_quant_conv.weight' in sd:
|
||||
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
|
||||
else:
|
||||
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
|
||||
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig},
|
||||
decoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Decoder", 'params': ddconfig})
|
||||
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
|
||||
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig},
|
||||
decoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Decoder", 'params': ddconfig})
|
||||
elif "decoder.layers.1.layers.0.beta" in sd:
|
||||
self.first_stage_model = AudioOobleckVAE()
|
||||
self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype)
|
||||
self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * 2048) * model_management.dtype_size(dtype)
|
||||
self.latent_channels = 64
|
||||
self.output_channels = 2
|
||||
self.pad_channel_value = "replicate"
|
||||
self.upscale_ratio = 2048
|
||||
self.downscale_ratio = 2048
|
||||
self.latent_dim = 1
|
||||
@@ -486,20 +414,20 @@ class VAE:
|
||||
elif "decoder.conv_in.conv.weight" in sd and sd['decoder.conv_in.conv.weight'].shape[1] == 32:
|
||||
ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True}
|
||||
ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
|
||||
self.latent_channels = 32
|
||||
self.latent_channels = 64
|
||||
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
|
||||
self.upscale_index_formula = (4, 16, 16)
|
||||
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16)
|
||||
self.downscale_index_formula = (4, 16, 16)
|
||||
self.latent_dim = 3
|
||||
self.not_video = False
|
||||
self.not_video = True
|
||||
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.EmptyRegularizer"},
|
||||
encoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Encoder", 'params': ddconfig},
|
||||
decoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Decoder", 'params': ddconfig})
|
||||
|
||||
self.memory_used_encode = lambda shape, dtype: (1400 * 9 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)
|
||||
self.memory_used_decode = lambda shape, dtype: (3600 * 4 * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)
|
||||
self.memory_used_encode = lambda shape, dtype: (1400 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)
|
||||
self.memory_used_decode = lambda shape, dtype: (1400 * shape[-3] * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)
|
||||
elif "decoder.conv_in.conv.weight" in sd:
|
||||
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
|
||||
ddconfig["conv3d"] = True
|
||||
@@ -511,10 +439,8 @@ class VAE:
|
||||
self.latent_dim = 3
|
||||
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
|
||||
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
|
||||
#This is likely to significantly over-estimate with single image or low frame counts as the
|
||||
#implementation is able to completely skip caching. Rework if used as an image only VAE
|
||||
self.memory_used_decode = lambda shape, dtype: (2800 * min(8, ((shape[2] - 1) * 4) + 1) * shape[3] * shape[4] * (8 * 8)) * model_management.dtype_size(dtype)
|
||||
self.memory_used_encode = lambda shape, dtype: (1400 * min(9, shape[2]) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
|
||||
self.memory_used_decode = lambda shape, dtype: (1500 * shape[2] * shape[3] * shape[4] * (4 * 8 * 8)) * model_management.dtype_size(dtype)
|
||||
self.memory_used_encode = lambda shape, dtype: (900 * max(shape[2], 2) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
|
||||
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
elif "decoder.unpatcher3d.wavelets" in sd:
|
||||
self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 8, 8)
|
||||
@@ -543,22 +469,17 @@ class VAE:
|
||||
self.memory_used_encode = lambda shape, dtype: 3300 * shape[3] * shape[4] * model_management.dtype_size(dtype)
|
||||
self.memory_used_decode = lambda shape, dtype: 8000 * shape[3] * shape[4] * (16 * 16) * model_management.dtype_size(dtype)
|
||||
else: # Wan 2.1 VAE
|
||||
dim = sd["decoder.head.0.gamma"].shape[0]
|
||||
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
|
||||
self.upscale_index_formula = (4, 8, 8)
|
||||
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
|
||||
self.downscale_index_formula = (4, 8, 8)
|
||||
self.latent_dim = 3
|
||||
self.latent_channels = 16
|
||||
self.output_channels = sd["encoder.conv1.weight"].shape[1]
|
||||
self.pad_channel_value = 1.0
|
||||
ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "image_channels": self.output_channels, "dropout": 0.0}
|
||||
ddconfig = {"dim": 96, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0}
|
||||
self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig)
|
||||
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
self.memory_used_encode = lambda shape, dtype: (1500 if shape[2]<=4 else 6000) * shape[3] * shape[4] * model_management.dtype_size(dtype)
|
||||
self.memory_used_decode = lambda shape, dtype: (2200 if shape[2]<=4 else 7000) * shape[3] * shape[4] * (8*8) * model_management.dtype_size(dtype)
|
||||
|
||||
|
||||
self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype)
|
||||
self.memory_used_decode = lambda shape, dtype: 7000 * shape[3] * shape[4] * (8 * 8) * model_management.dtype_size(dtype)
|
||||
# Hunyuan 3d v2 2.0 & 2.1
|
||||
elif "geo_decoder.cross_attn_decoder.ln_1.bias" in sd:
|
||||
|
||||
@@ -588,7 +509,6 @@ class VAE:
|
||||
self.memory_used_decode = lambda shape, dtype: (shape[2] * shape[3] * 87000) * model_management.dtype_size(dtype)
|
||||
self.latent_channels = 8
|
||||
self.output_channels = 2
|
||||
self.pad_channel_value = "replicate"
|
||||
self.upscale_ratio = 4096
|
||||
self.downscale_ratio = 4096
|
||||
self.latent_dim = 2
|
||||
@@ -606,54 +526,6 @@ class VAE:
|
||||
self.latent_channels = 3
|
||||
self.latent_dim = 2
|
||||
self.output_channels = 3
|
||||
elif "vocoder.activation_post.downsample.lowpass.filter" in sd: #MMAudio VAE
|
||||
sample_rate = 16000
|
||||
if sample_rate == 16000:
|
||||
mode = '16k'
|
||||
else:
|
||||
mode = '44k'
|
||||
|
||||
self.first_stage_model = comfy.ldm.mmaudio.vae.autoencoder.AudioAutoencoder(mode=mode)
|
||||
self.memory_used_encode = lambda shape, dtype: (30 * shape[2]) * model_management.dtype_size(dtype)
|
||||
self.memory_used_decode = lambda shape, dtype: (90 * shape[2] * 1411.2) * model_management.dtype_size(dtype)
|
||||
self.latent_channels = 20
|
||||
self.output_channels = 2
|
||||
self.upscale_ratio = 512 * (44100 / sample_rate)
|
||||
self.downscale_ratio = 512 * (44100 / sample_rate)
|
||||
self.latent_dim = 1
|
||||
self.process_output = lambda audio: audio
|
||||
self.process_input = lambda audio: audio
|
||||
self.working_dtypes = [torch.float32]
|
||||
self.crop_input = False
|
||||
elif "decoder.22.bias" in sd: # taehv, taew and lighttae
|
||||
self.latent_channels = sd["decoder.1.weight"].shape[1]
|
||||
self.latent_dim = 3
|
||||
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
|
||||
self.upscale_index_formula = (4, 16, 16)
|
||||
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16)
|
||||
self.downscale_index_formula = (4, 16, 16)
|
||||
if self.latent_channels == 48: # Wan 2.2
|
||||
self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=None) # taehv doesn't need scaling
|
||||
self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently"))
|
||||
self.process_output = lambda image: image
|
||||
self.memory_used_decode = lambda shape, dtype: (1800 * (max(1, (shape[-3] ** 0.7 * 0.1)) * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype))
|
||||
elif self.latent_channels == 32 and sd["decoder.22.bias"].shape[0] == 12: # lighttae_hv15
|
||||
self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=comfy.latent_formats.HunyuanVideo15)
|
||||
self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently"))
|
||||
self.memory_used_decode = lambda shape, dtype: (1200 * (max(1, (shape[-3] ** 0.7 * 0.05)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype))
|
||||
else:
|
||||
if sd["decoder.1.weight"].dtype == torch.float16: # taehv currently only available in float16, so assume it's not lighttaew2_1 as otherwise state dicts are identical
|
||||
latent_format=comfy.latent_formats.HunyuanVideo
|
||||
else:
|
||||
latent_format=None # lighttaew2_1 doesn't need scaling
|
||||
self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=latent_format)
|
||||
self.process_input = self.process_output = lambda image: image
|
||||
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
|
||||
self.upscale_index_formula = (4, 8, 8)
|
||||
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
|
||||
self.downscale_index_formula = (4, 8, 8)
|
||||
self.memory_used_encode = lambda shape, dtype: (700 * (max(1, (shape[-3] ** 0.66 * 0.11)) * shape[-2] * shape[-1]) * model_management.dtype_size(dtype))
|
||||
self.memory_used_decode = lambda shape, dtype: (50 * (max(1, (shape[-3] ** 0.65 * 0.26)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype))
|
||||
else:
|
||||
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
|
||||
self.first_stage_model = None
|
||||
@@ -681,44 +553,20 @@ class VAE:
|
||||
|
||||
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
|
||||
logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
|
||||
self.model_size()
|
||||
|
||||
def model_size(self):
|
||||
if self.size is not None:
|
||||
return self.size
|
||||
self.size = comfy.model_management.module_size(self.first_stage_model)
|
||||
return self.size
|
||||
|
||||
def get_ram_usage(self):
|
||||
return self.model_size()
|
||||
|
||||
def throw_exception_if_invalid(self):
|
||||
if self.first_stage_model is None:
|
||||
raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.")
|
||||
|
||||
def vae_encode_crop_pixels(self, pixels):
|
||||
if self.crop_input:
|
||||
downscale_ratio = self.spacial_compression_encode()
|
||||
downscale_ratio = self.spacial_compression_encode()
|
||||
|
||||
dims = pixels.shape[1:-1]
|
||||
for d in range(len(dims)):
|
||||
x = (dims[d] // downscale_ratio) * downscale_ratio
|
||||
x_offset = (dims[d] % downscale_ratio) // 2
|
||||
if x != dims[d]:
|
||||
pixels = pixels.narrow(d + 1, x_offset, x)
|
||||
|
||||
if pixels.shape[-1] > self.output_channels:
|
||||
pixels = pixels[..., :self.output_channels]
|
||||
elif pixels.shape[-1] < self.output_channels:
|
||||
if self.pad_channel_value is not None:
|
||||
if isinstance(self.pad_channel_value, str):
|
||||
mode = self.pad_channel_value
|
||||
value = None
|
||||
else:
|
||||
mode = "constant"
|
||||
value = self.pad_channel_value
|
||||
|
||||
pixels = torch.nn.functional.pad(pixels, (0, self.output_channels - pixels.shape[-1]), mode=mode, value=value)
|
||||
dims = pixels.shape[1:-1]
|
||||
for d in range(len(dims)):
|
||||
x = (dims[d] // downscale_ratio) * downscale_ratio
|
||||
x_offset = (dims[d] % downscale_ratio) // 2
|
||||
if x != dims[d]:
|
||||
pixels = pixels.narrow(d + 1, x_offset, x)
|
||||
return pixels
|
||||
|
||||
def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
|
||||
@@ -788,9 +636,6 @@ class VAE:
|
||||
def decode(self, samples_in, vae_options={}):
|
||||
self.throw_exception_if_invalid()
|
||||
pixel_samples = None
|
||||
do_tile = False
|
||||
if self.latent_dim == 2 and samples_in.ndim == 5:
|
||||
samples_in = samples_in[:, :, 0]
|
||||
try:
|
||||
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
|
||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
|
||||
@@ -806,13 +651,6 @@ class VAE:
|
||||
pixel_samples[x:x+batch_number] = out
|
||||
except model_management.OOM_EXCEPTION:
|
||||
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
|
||||
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
|
||||
#exception and the exception itself refs them all until we get out of this except block.
|
||||
#So we just set a flag for tiler fallback so that tensor gc can happen once the
|
||||
#exception is fully off the books.
|
||||
do_tile = True
|
||||
|
||||
if do_tile:
|
||||
dims = samples_in.ndim - 2
|
||||
if dims == 1 or self.extra_1d_channel is not None:
|
||||
pixel_samples = self.decode_tiled_1d(samples_in)
|
||||
@@ -859,7 +697,6 @@ class VAE:
|
||||
self.throw_exception_if_invalid()
|
||||
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
|
||||
pixel_samples = pixel_samples.movedim(-1, 1)
|
||||
do_tile = False
|
||||
if self.latent_dim == 3 and pixel_samples.ndim < 5:
|
||||
if not self.not_video:
|
||||
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
|
||||
@@ -881,13 +718,6 @@ class VAE:
|
||||
|
||||
except model_management.OOM_EXCEPTION:
|
||||
logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
|
||||
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
|
||||
#exception and the exception itself refs them all until we get out of this except block.
|
||||
#So we just set a flag for tiler fallback so that tensor gc can happen once the
|
||||
#exception is fully off the books.
|
||||
do_tile = True
|
||||
|
||||
if do_tile:
|
||||
if self.latent_dim == 3:
|
||||
tile = 256
|
||||
overlap = tile // 4
|
||||
@@ -1006,20 +836,12 @@ class CLIPType(Enum):
|
||||
OMNIGEN2 = 17
|
||||
QWEN_IMAGE = 18
|
||||
HUNYUAN_IMAGE = 19
|
||||
HUNYUAN_VIDEO_15 = 20
|
||||
OVIS = 21
|
||||
KANDINSKY5 = 22
|
||||
KANDINSKY5_IMAGE = 23
|
||||
NEWBIE = 24
|
||||
|
||||
|
||||
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
||||
clip_data = []
|
||||
for p in ckpt_paths:
|
||||
sd, metadata = comfy.utils.load_torch_file(p, safe_load=True, return_metadata=True)
|
||||
if model_options.get("custom_operations", None) is None:
|
||||
sd, metadata = comfy.utils.convert_old_quants(sd, model_prefix="", metadata=metadata)
|
||||
clip_data.append(sd)
|
||||
clip_data.append(comfy.utils.load_torch_file(p, safe_load=True))
|
||||
return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options)
|
||||
|
||||
|
||||
@@ -1036,14 +858,6 @@ class TEModel(Enum):
|
||||
QWEN25_3B = 10
|
||||
QWEN25_7B = 11
|
||||
BYT5_SMALL_GLYPH = 12
|
||||
GEMMA_3_4B = 13
|
||||
MISTRAL3_24B = 14
|
||||
MISTRAL3_24B_PRUNED_FLUX2 = 15
|
||||
QWEN3_4B = 16
|
||||
QWEN3_2B = 17
|
||||
GEMMA_3_12B = 18
|
||||
JINA_CLIP_2 = 19
|
||||
|
||||
|
||||
def detect_te_model(sd):
|
||||
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
|
||||
@@ -1052,8 +866,6 @@ def detect_te_model(sd):
|
||||
return TEModel.CLIP_H
|
||||
if "text_model.encoder.layers.0.mlp.fc1.weight" in sd:
|
||||
return TEModel.CLIP_L
|
||||
if "model.encoder.layers.0.mixer.Wqkv.weight" in sd:
|
||||
return TEModel.JINA_CLIP_2
|
||||
if "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in sd:
|
||||
weight = sd["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"]
|
||||
if weight.shape[-1] == 4096:
|
||||
@@ -1068,10 +880,6 @@ def detect_te_model(sd):
|
||||
return TEModel.BYT5_SMALL_GLYPH
|
||||
return TEModel.T5_BASE
|
||||
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
|
||||
if 'model.layers.47.self_attn.q_norm.weight' in sd:
|
||||
return TEModel.GEMMA_3_12B
|
||||
if 'model.layers.0.self_attn.q_norm.weight' in sd:
|
||||
return TEModel.GEMMA_3_4B
|
||||
return TEModel.GEMMA_2_2B
|
||||
if 'model.layers.0.self_attn.k_proj.bias' in sd:
|
||||
weight = sd['model.layers.0.self_attn.k_proj.bias']
|
||||
@@ -1080,18 +888,6 @@ def detect_te_model(sd):
|
||||
if weight.shape[0] == 512:
|
||||
return TEModel.QWEN25_7B
|
||||
if "model.layers.0.post_attention_layernorm.weight" in sd:
|
||||
weight = sd['model.layers.0.post_attention_layernorm.weight']
|
||||
if 'model.layers.0.self_attn.q_norm.weight' in sd:
|
||||
if weight.shape[0] == 2560:
|
||||
return TEModel.QWEN3_4B
|
||||
elif weight.shape[0] == 2048:
|
||||
return TEModel.QWEN3_2B
|
||||
if weight.shape[0] == 5120:
|
||||
if "model.layers.39.post_attention_layernorm.weight" in sd:
|
||||
return TEModel.MISTRAL3_24B
|
||||
else:
|
||||
return TEModel.MISTRAL3_24B_PRUNED_FLUX2
|
||||
|
||||
return TEModel.LLAMA3_8
|
||||
return None
|
||||
|
||||
@@ -1141,7 +937,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=True, t5=False)
|
||||
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
||||
elif clip_type == CLIPType.HIDREAM:
|
||||
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=False, clip_g=True, t5=False, llama=False, dtype_t5=None, dtype_llama=None)
|
||||
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=False, clip_g=True, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None)
|
||||
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
|
||||
else:
|
||||
clip_target.clip = sdxl_clip.SDXLRefinerClipModel
|
||||
@@ -1165,7 +961,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
|
||||
elif clip_type == CLIPType.HIDREAM:
|
||||
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**t5xxl_detect(clip_data),
|
||||
clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None)
|
||||
clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None, llama_scaled_fp8=None)
|
||||
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
|
||||
else: #CLIPType.MOCHI
|
||||
clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data))
|
||||
@@ -1188,13 +984,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer
|
||||
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
|
||||
elif te_model == TEModel.GEMMA_3_4B:
|
||||
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data), model_type="gemma3_4b")
|
||||
clip_target.tokenizer = comfy.text_encoders.lumina2.NTokenizer
|
||||
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
|
||||
elif te_model == TEModel.LLAMA3_8:
|
||||
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data),
|
||||
clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None)
|
||||
clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None, t5xxl_scaled_fp8=None)
|
||||
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
|
||||
elif te_model == TEModel.QWEN25_3B:
|
||||
clip_target.clip = comfy.text_encoders.omnigen2.te(**llama_detect(clip_data))
|
||||
@@ -1206,26 +998,13 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
else:
|
||||
clip_target.clip = comfy.text_encoders.qwen_image.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.qwen_image.QwenImageTokenizer
|
||||
elif te_model == TEModel.MISTRAL3_24B or te_model == TEModel.MISTRAL3_24B_PRUNED_FLUX2:
|
||||
clip_target.clip = comfy.text_encoders.flux.flux2_te(**llama_detect(clip_data), pruned=te_model == TEModel.MISTRAL3_24B_PRUNED_FLUX2)
|
||||
clip_target.tokenizer = comfy.text_encoders.flux.Flux2Tokenizer
|
||||
tokenizer_data["tekken_model"] = clip_data[0].get("tekken_model", None)
|
||||
elif te_model == TEModel.QWEN3_4B:
|
||||
clip_target.clip = comfy.text_encoders.z_image.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.z_image.ZImageTokenizer
|
||||
elif te_model == TEModel.QWEN3_2B:
|
||||
clip_target.clip = comfy.text_encoders.ovis.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.ovis.OvisTokenizer
|
||||
elif te_model == TEModel.JINA_CLIP_2:
|
||||
clip_target.clip = comfy.text_encoders.jina_clip_2.JinaClip2TextModelWrapper
|
||||
clip_target.tokenizer = comfy.text_encoders.jina_clip_2.JinaClip2TokenizerWrapper
|
||||
else:
|
||||
# clip_l
|
||||
if clip_type == CLIPType.SD3:
|
||||
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=False, t5=False)
|
||||
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
||||
elif clip_type == CLIPType.HIDREAM:
|
||||
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=True, clip_g=False, t5=False, llama=False, dtype_t5=None, dtype_llama=None)
|
||||
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=True, clip_g=False, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None)
|
||||
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
|
||||
else:
|
||||
clip_target.clip = sd1_clip.SD1ClipModel
|
||||
@@ -1265,30 +1044,6 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
elif clip_type == CLIPType.HUNYUAN_IMAGE:
|
||||
clip_target.clip = comfy.text_encoders.hunyuan_image.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer
|
||||
elif clip_type == CLIPType.HUNYUAN_VIDEO_15:
|
||||
clip_target.clip = comfy.text_encoders.hunyuan_image.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer
|
||||
elif clip_type == CLIPType.KANDINSKY5:
|
||||
clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5Tokenizer
|
||||
elif clip_type == CLIPType.KANDINSKY5_IMAGE:
|
||||
clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage
|
||||
elif clip_type == CLIPType.LTXV:
|
||||
clip_target.clip = comfy.text_encoders.lt.ltxav_te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.lt.LTXAVGemmaTokenizer
|
||||
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
|
||||
elif clip_type == CLIPType.NEWBIE:
|
||||
clip_target.clip = comfy.text_encoders.newbie.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.newbie.NewBieTokenizer
|
||||
if "model.layers.0.self_attn.q_norm.weight" in clip_data[0]:
|
||||
clip_data_gemma = clip_data[0]
|
||||
clip_data_jina = clip_data[1]
|
||||
else:
|
||||
clip_data_gemma = clip_data[1]
|
||||
clip_data_jina = clip_data[0]
|
||||
tokenizer_data["gemma_spiece_model"] = clip_data_gemma.get("spiece_model", None)
|
||||
tokenizer_data["jina_spiece_model"] = clip_data_jina.get("spiece_model", None)
|
||||
else:
|
||||
clip_target.clip = sdxl_clip.SDXLClipModel
|
||||
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
|
||||
@@ -1304,7 +1059,14 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
parameters += comfy.utils.calculate_parameters(c)
|
||||
tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options)
|
||||
|
||||
clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, state_dict=clip_data, model_options=model_options)
|
||||
clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, model_options=model_options)
|
||||
for c in clip_data:
|
||||
m, u = clip.load_sd(c)
|
||||
if len(m) > 0:
|
||||
logging.warning("clip missing: {}".format(m))
|
||||
|
||||
if len(u) > 0:
|
||||
logging.debug("clip unexpected: {}".format(u))
|
||||
return clip
|
||||
|
||||
def load_gligen(ckpt_path):
|
||||
@@ -1363,10 +1125,6 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
||||
weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix)
|
||||
load_device = model_management.get_torch_device()
|
||||
|
||||
custom_operations = model_options.get("custom_operations", None)
|
||||
if custom_operations is None:
|
||||
sd, metadata = comfy.utils.convert_old_quants(sd, diffusion_model_prefix, metadata=metadata)
|
||||
|
||||
model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix, metadata=metadata)
|
||||
if model_config is None:
|
||||
logging.warning("Warning, This is not a checkpoint file, trying to load it as a diffusion model only.")
|
||||
@@ -1375,22 +1133,18 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
||||
return None
|
||||
return (diffusion_model, None, VAE(sd={}), None) # The VAE object is there to throw an exception if it's actually used'
|
||||
|
||||
|
||||
unet_weight_dtype = list(model_config.supported_inference_dtypes)
|
||||
if model_config.quant_config is not None:
|
||||
if model_config.scaled_fp8 is not None:
|
||||
weight_dtype = None
|
||||
|
||||
if custom_operations is not None:
|
||||
model_config.custom_operations = custom_operations
|
||||
|
||||
model_config.custom_operations = model_options.get("custom_operations", None)
|
||||
unet_dtype = model_options.get("dtype", model_options.get("weight_dtype", None))
|
||||
|
||||
if unet_dtype is None:
|
||||
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype, weight_dtype=weight_dtype)
|
||||
|
||||
if model_config.quant_config is not None:
|
||||
manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
|
||||
else:
|
||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
||||
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
||||
|
||||
if model_config.clip_vision_prefix is not None:
|
||||
@@ -1408,33 +1162,22 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
||||
vae = VAE(sd=vae_sd, metadata=metadata)
|
||||
|
||||
if output_clip:
|
||||
if te_model_options.get("custom_operations", None) is None:
|
||||
scaled_fp8_list = []
|
||||
for k in list(sd.keys()): # Convert scaled fp8 to mixed ops
|
||||
if k.endswith(".scaled_fp8"):
|
||||
scaled_fp8_list.append(k[:-len("scaled_fp8")])
|
||||
|
||||
if len(scaled_fp8_list) > 0:
|
||||
out_sd = {}
|
||||
for k in sd:
|
||||
skip = False
|
||||
for pref in scaled_fp8_list:
|
||||
skip = skip or k.startswith(pref)
|
||||
if not skip:
|
||||
out_sd[k] = sd[k]
|
||||
|
||||
for pref in scaled_fp8_list:
|
||||
quant_sd, qmetadata = comfy.utils.convert_old_quants(sd, pref, metadata={})
|
||||
for k in quant_sd:
|
||||
out_sd[k] = quant_sd[k]
|
||||
sd = out_sd
|
||||
|
||||
clip_target = model_config.clip_target(state_dict=sd)
|
||||
if clip_target is not None:
|
||||
clip_sd = model_config.process_clip_state_dict(sd)
|
||||
if len(clip_sd) > 0:
|
||||
parameters = comfy.utils.calculate_parameters(clip_sd)
|
||||
clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, state_dict=clip_sd, model_options=te_model_options)
|
||||
clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, model_options=te_model_options)
|
||||
m, u = clip.load_sd(clip_sd, full_model=True)
|
||||
if len(m) > 0:
|
||||
m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
|
||||
if len(m_filter) > 0:
|
||||
logging.warning("clip missing: {}".format(m))
|
||||
else:
|
||||
logging.debug("clip missing: {}".format(m))
|
||||
|
||||
if len(u) > 0:
|
||||
logging.debug("clip unexpected {}:".format(u))
|
||||
else:
|
||||
logging.warning("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.")
|
||||
|
||||
@@ -1451,7 +1194,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
||||
return (model_patcher, clip, vae, clipvision)
|
||||
|
||||
|
||||
def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
|
||||
def load_diffusion_model_state_dict(sd, model_options={}):
|
||||
"""
|
||||
Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats.
|
||||
|
||||
@@ -1481,14 +1224,11 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
|
||||
if len(temp_sd) > 0:
|
||||
sd = temp_sd
|
||||
|
||||
custom_operations = model_options.get("custom_operations", None)
|
||||
if custom_operations is None:
|
||||
sd, metadata = comfy.utils.convert_old_quants(sd, "", metadata=metadata)
|
||||
parameters = comfy.utils.calculate_parameters(sd)
|
||||
weight_dtype = comfy.utils.weight_dtype(sd)
|
||||
|
||||
load_device = model_management.get_torch_device()
|
||||
model_config = model_detection.model_config_from_unet(sd, "", metadata=metadata)
|
||||
model_config = model_detection.model_config_from_unet(sd, "")
|
||||
|
||||
if model_config is not None:
|
||||
new_sd = sd
|
||||
@@ -1514,7 +1254,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
|
||||
|
||||
offload_device = model_management.unet_offload_device()
|
||||
unet_weight_dtype = list(model_config.supported_inference_dtypes)
|
||||
if model_config.quant_config is not None:
|
||||
if model_config.scaled_fp8 is not None:
|
||||
weight_dtype = None
|
||||
|
||||
if dtype is None:
|
||||
@@ -1522,15 +1262,9 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
|
||||
else:
|
||||
unet_dtype = dtype
|
||||
|
||||
if model_config.quant_config is not None:
|
||||
manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
|
||||
else:
|
||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
||||
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
||||
|
||||
if custom_operations is not None:
|
||||
model_config.custom_operations = custom_operations
|
||||
|
||||
model_config.custom_operations = model_options.get("custom_operations", model_config.custom_operations)
|
||||
if model_options.get("fp8_optimizations", False):
|
||||
model_config.optimizations["fp8"] = True
|
||||
|
||||
@@ -1544,8 +1278,8 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
|
||||
|
||||
|
||||
def load_diffusion_model(unet_path, model_options={}):
|
||||
sd, metadata = comfy.utils.load_torch_file(unet_path, return_metadata=True)
|
||||
model = load_diffusion_model_state_dict(sd, model_options=model_options, metadata=metadata)
|
||||
sd = comfy.utils.load_torch_file(unet_path)
|
||||
model = load_diffusion_model_state_dict(sd, model_options=model_options)
|
||||
if model is None:
|
||||
logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path))
|
||||
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd)))
|
||||
@@ -1569,9 +1303,6 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m
|
||||
if vae is not None:
|
||||
vae_sd = vae.get_sd()
|
||||
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
|
||||
model_management.load_models_gpu(load_models, force_patch_weights=True)
|
||||
clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
|
||||
sd = model.model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd)
|
||||
|
||||
@@ -90,6 +90,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False,
|
||||
return_projected_pooled=True, return_attention_masks=False, model_options={}): # clip-vit-base-patch32
|
||||
super().__init__()
|
||||
assert layer in self.LAYERS
|
||||
|
||||
if textmodel_json_config is None:
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
|
||||
@@ -107,17 +108,19 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
config[k] = v
|
||||
|
||||
operations = model_options.get("custom_operations", None)
|
||||
quant_config = model_options.get("quantization_metadata", None)
|
||||
scaled_fp8 = None
|
||||
|
||||
if operations is None:
|
||||
if quant_config is not None:
|
||||
operations = comfy.ops.mixed_precision_ops(quant_config, dtype, full_precision_mm=True)
|
||||
logging.info("Using MixedPrecisionOps for text encoder")
|
||||
scaled_fp8 = model_options.get("scaled_fp8", None)
|
||||
if scaled_fp8 is not None:
|
||||
operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8)
|
||||
else:
|
||||
operations = comfy.ops.manual_cast
|
||||
|
||||
self.operations = operations
|
||||
self.transformer = model_class(config, dtype, device, self.operations)
|
||||
if scaled_fp8 is not None:
|
||||
self.transformer.scaled_fp8 = torch.nn.Parameter(torch.tensor([], dtype=scaled_fp8))
|
||||
|
||||
self.num_layers = self.transformer.num_layers
|
||||
|
||||
@@ -135,7 +138,6 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
self.layer_norm_hidden_state = layer_norm_hidden_state
|
||||
self.return_projected_pooled = return_projected_pooled
|
||||
self.return_attention_masks = return_attention_masks
|
||||
self.execution_device = None
|
||||
|
||||
if layer == "hidden":
|
||||
assert layer_idx is not None
|
||||
@@ -152,8 +154,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
def set_clip_options(self, options):
|
||||
layer_idx = options.get("layer", self.layer_idx)
|
||||
self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled)
|
||||
self.execution_device = options.get("execution_device", self.execution_device)
|
||||
if isinstance(self.layer, list) or self.layer == "all":
|
||||
if self.layer == "all":
|
||||
pass
|
||||
elif layer_idx is None or abs(layer_idx) > self.num_layers:
|
||||
self.layer = "last"
|
||||
@@ -165,7 +166,6 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
self.layer = self.options_default[0]
|
||||
self.layer_idx = self.options_default[1]
|
||||
self.return_projected_pooled = self.options_default[2]
|
||||
self.execution_device = None
|
||||
|
||||
def process_tokens(self, tokens, device):
|
||||
end_token = self.special_tokens.get("end", None)
|
||||
@@ -249,20 +249,14 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
return torch.cat(embeds_out), torch.tensor(attention_masks, device=device, dtype=torch.long), num_tokens, embeds_info
|
||||
|
||||
def forward(self, tokens):
|
||||
if self.execution_device is None:
|
||||
device = self.transformer.get_input_embeddings().weight.device
|
||||
else:
|
||||
device = self.execution_device
|
||||
|
||||
device = self.transformer.get_input_embeddings().weight.device
|
||||
embeds, attention_mask, num_tokens, embeds_info = self.process_tokens(tokens, device)
|
||||
|
||||
attention_mask_model = None
|
||||
if self.enable_attention_masks:
|
||||
attention_mask_model = attention_mask
|
||||
|
||||
if isinstance(self.layer, list):
|
||||
intermediate_output = self.layer
|
||||
elif self.layer == "all":
|
||||
if self.layer == "all":
|
||||
intermediate_output = "all"
|
||||
else:
|
||||
intermediate_output = self.layer_idx
|
||||
@@ -466,7 +460,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
|
||||
return embed_out
|
||||
|
||||
class SDTokenizer:
|
||||
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, pad_left=False, disable_weights=False, tokenizer_data={}, tokenizer_args={}):
|
||||
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, tokenizer_data={}, tokenizer_args={}):
|
||||
if tokenizer_path is None:
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
|
||||
self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args)
|
||||
@@ -474,7 +468,6 @@ class SDTokenizer:
|
||||
self.min_length = tokenizer_data.get("{}_min_length".format(embedding_key), min_length)
|
||||
self.end_token = None
|
||||
self.min_padding = min_padding
|
||||
self.pad_left = pad_left
|
||||
|
||||
empty = self.tokenizer('')["input_ids"]
|
||||
self.tokenizer_adds_end_token = has_end_token
|
||||
@@ -513,8 +506,6 @@ class SDTokenizer:
|
||||
self.embedding_size = embedding_size
|
||||
self.embedding_key = embedding_key
|
||||
|
||||
self.disable_weights = disable_weights
|
||||
|
||||
def _try_get_embedding(self, embedding_name:str):
|
||||
'''
|
||||
Takes a potential embedding name and tries to retrieve it.
|
||||
@@ -531,12 +522,6 @@ class SDTokenizer:
|
||||
return (embed, "{} {}".format(embedding_name[len(stripped):], leftover))
|
||||
return (embed, leftover)
|
||||
|
||||
def pad_tokens(self, tokens, amount):
|
||||
if self.pad_left:
|
||||
for i in range(amount):
|
||||
tokens.insert(0, (self.pad_token, 1.0, 0))
|
||||
else:
|
||||
tokens.extend([(self.pad_token, 1.0, 0)] * amount)
|
||||
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False, tokenizer_options={}, **kwargs):
|
||||
'''
|
||||
@@ -549,7 +534,7 @@ class SDTokenizer:
|
||||
min_padding = tokenizer_options.get("{}_min_padding".format(self.embedding_key), self.min_padding)
|
||||
|
||||
text = escape_important(text)
|
||||
if kwargs.get("disable_weights", self.disable_weights):
|
||||
if kwargs.get("disable_weights", False):
|
||||
parsed_weights = [(text, 1.0)]
|
||||
else:
|
||||
parsed_weights = token_weights(text, 1.0)
|
||||
@@ -615,7 +600,7 @@ class SDTokenizer:
|
||||
if self.end_token is not None:
|
||||
batch.append((self.end_token, 1.0, 0))
|
||||
if self.pad_to_max_length:
|
||||
self.pad_tokens(batch, remaining_length)
|
||||
batch.extend([(self.pad_token, 1.0, 0)] * (remaining_length))
|
||||
#start new batch
|
||||
batch = []
|
||||
if self.start_token is not None:
|
||||
@@ -629,11 +614,11 @@ class SDTokenizer:
|
||||
if self.end_token is not None:
|
||||
batch.append((self.end_token, 1.0, 0))
|
||||
if min_padding is not None:
|
||||
self.pad_tokens(batch, min_padding)
|
||||
batch.extend([(self.pad_token, 1.0, 0)] * min_padding)
|
||||
if self.pad_to_max_length and len(batch) < self.max_length:
|
||||
self.pad_tokens(batch, self.max_length - len(batch))
|
||||
batch.extend([(self.pad_token, 1.0, 0)] * (self.max_length - len(batch)))
|
||||
if min_length is not None and len(batch) < min_length:
|
||||
self.pad_tokens(batch, min_length - len(batch))
|
||||
batch.extend([(self.pad_token, 1.0, 0)] * (min_length - len(batch)))
|
||||
|
||||
if not return_word_ids:
|
||||
batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens]
|
||||
|
||||
@@ -21,14 +21,11 @@ import comfy.text_encoders.ace
|
||||
import comfy.text_encoders.omnigen2
|
||||
import comfy.text_encoders.qwen_image
|
||||
import comfy.text_encoders.hunyuan_image
|
||||
import comfy.text_encoders.kandinsky5
|
||||
import comfy.text_encoders.z_image
|
||||
|
||||
from . import supported_models_base
|
||||
from . import latent_formats
|
||||
|
||||
from . import diffusers_convert
|
||||
import comfy.model_management
|
||||
|
||||
class SD15(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
@@ -542,7 +539,7 @@ class SD3(supported_models_base.BASE):
|
||||
unet_extra_config = {}
|
||||
latent_format = latent_formats.SD3
|
||||
|
||||
memory_usage_factor = 1.6
|
||||
memory_usage_factor = 1.2
|
||||
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
@@ -744,37 +741,6 @@ class FluxSchnell(Flux):
|
||||
out = model_base.Flux(self, model_type=model_base.ModelType.FLOW, device=device)
|
||||
return out
|
||||
|
||||
class Flux2(Flux):
|
||||
unet_config = {
|
||||
"image_model": "flux2",
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"shift": 2.02,
|
||||
}
|
||||
|
||||
unet_extra_config = {}
|
||||
latent_format = latent_formats.Flux2
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def __init__(self, unet_config):
|
||||
super().__init__(unet_config)
|
||||
self.memory_usage_factor = self.memory_usage_factor * (2.0 * 2.0) * 2.36
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.Flux2(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
return None # TODO
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(**t5_detect))
|
||||
|
||||
class GenmoMochi(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "mochi_preview",
|
||||
@@ -836,21 +802,6 @@ class LTXV(supported_models_base.BASE):
|
||||
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.lt.LTXVT5Tokenizer, comfy.text_encoders.lt.ltxv_te(**t5_detect))
|
||||
|
||||
class LTXAV(LTXV):
|
||||
unet_config = {
|
||||
"image_model": "ltxav",
|
||||
}
|
||||
|
||||
latent_format = latent_formats.LTXAV
|
||||
|
||||
def __init__(self, unet_config):
|
||||
super().__init__(unet_config)
|
||||
self.memory_usage_factor = 0.055 # TODO
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.LTXAV(self, device=device)
|
||||
return out
|
||||
|
||||
class HunyuanVideo(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "hunyuan_video",
|
||||
@@ -981,7 +932,7 @@ class CosmosT2IPredict2(supported_models_base.BASE):
|
||||
|
||||
def __init__(self, unet_config):
|
||||
super().__init__(unet_config)
|
||||
self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.95
|
||||
self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.9
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.CosmosPredict2(self, device=device)
|
||||
@@ -1012,7 +963,7 @@ class Lumina2(supported_models_base.BASE):
|
||||
"shift": 6.0,
|
||||
}
|
||||
|
||||
memory_usage_factor = 1.4
|
||||
memory_usage_factor = 1.2
|
||||
|
||||
unet_extra_config = {}
|
||||
latent_format = latent_formats.Flux
|
||||
@@ -1031,32 +982,6 @@ class Lumina2(supported_models_base.BASE):
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}gemma2_2b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.lumina2.LuminaTokenizer, comfy.text_encoders.lumina2.te(**hunyuan_detect))
|
||||
|
||||
class ZImage(Lumina2):
|
||||
unet_config = {
|
||||
"image_model": "lumina2",
|
||||
"dim": 3840,
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"multiplier": 1.0,
|
||||
"shift": 3.0,
|
||||
}
|
||||
|
||||
memory_usage_factor = 2.0
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
def __init__(self, unet_config):
|
||||
super().__init__(unet_config)
|
||||
if comfy.model_management.extended_fp16_support():
|
||||
self.supported_inference_dtypes = self.supported_inference_dtypes.copy()
|
||||
self.supported_inference_dtypes.insert(1, torch.float16)
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_4b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.z_image.ZImageTokenizer, comfy.text_encoders.z_image.te(**hunyuan_detect))
|
||||
|
||||
class WAN21_T2V(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "wan2.1",
|
||||
@@ -1311,7 +1236,7 @@ class ChromaRadiance(Chroma):
|
||||
latent_format = comfy.latent_formats.ChromaRadiance
|
||||
|
||||
# Pixel-space model, no spatial compression for model input.
|
||||
memory_usage_factor = 0.044
|
||||
memory_usage_factor = 0.038
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
return model_base.ChromaRadiance(self, device=device)
|
||||
@@ -1354,7 +1279,7 @@ class Omnigen2(supported_models_base.BASE):
|
||||
"shift": 2.6,
|
||||
}
|
||||
|
||||
memory_usage_factor = 1.95 #TODO
|
||||
memory_usage_factor = 1.65 #TODO
|
||||
|
||||
unet_extra_config = {}
|
||||
latent_format = latent_formats.Flux
|
||||
@@ -1419,7 +1344,7 @@ class HunyuanImage21(HunyuanVideo):
|
||||
|
||||
latent_format = latent_formats.HunyuanImage21
|
||||
|
||||
memory_usage_factor = 8.7
|
||||
memory_usage_factor = 7.7
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
@@ -1449,108 +1374,6 @@ class HunyuanImage21Refiner(HunyuanVideo):
|
||||
out = model_base.HunyuanImage21Refiner(self, device=device)
|
||||
return out
|
||||
|
||||
class HunyuanVideo15(HunyuanVideo):
|
||||
unet_config = {
|
||||
"image_model": "hunyuan_video",
|
||||
"vision_in_dim": 1152,
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"shift": 7.0,
|
||||
}
|
||||
memory_usage_factor = 4.0 #TODO
|
||||
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
|
||||
latent_format = latent_formats.HunyuanVideo15
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.HunyuanVideo15(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect))
|
||||
|
||||
|
||||
class HunyuanVideo15_SR_Distilled(HunyuanVideo):
|
||||
unet_config = {
|
||||
"image_model": "hunyuan_video",
|
||||
"vision_in_dim": 1152,
|
||||
"in_channels": 98,
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"shift": 2.0,
|
||||
}
|
||||
memory_usage_factor = 4.0 #TODO
|
||||
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
|
||||
latent_format = latent_formats.HunyuanVideo15
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.HunyuanVideo15_SR_Distilled(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect))
|
||||
|
||||
|
||||
class Kandinsky5(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "kandinsky5",
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"shift": 10.0,
|
||||
}
|
||||
|
||||
unet_extra_config = {}
|
||||
latent_format = latent_formats.HunyuanVideo
|
||||
|
||||
memory_usage_factor = 1.25 #TODO
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.Kandinsky5(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5Tokenizer, comfy.text_encoders.kandinsky5.te(**hunyuan_detect))
|
||||
|
||||
|
||||
class Kandinsky5Image(Kandinsky5):
|
||||
unet_config = {
|
||||
"image_model": "kandinsky5",
|
||||
"model_dim": 2560,
|
||||
"visual_embed_dim": 64,
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"shift": 3.0,
|
||||
}
|
||||
|
||||
latent_format = latent_formats.Flux
|
||||
memory_usage_factor = 1.25 #TODO
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.Kandinsky5Image(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect))
|
||||
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5]
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage]
|
||||
|
||||
models += [SVD_img2vid]
|
||||
|
||||
@@ -17,7 +17,6 @@
|
||||
"""
|
||||
|
||||
import torch
|
||||
import logging
|
||||
from . import model_base
|
||||
from . import utils
|
||||
from . import latent_formats
|
||||
@@ -50,7 +49,7 @@ class BASE:
|
||||
|
||||
manual_cast_dtype = None
|
||||
custom_operations = None
|
||||
quant_config = None # quantization configuration for mixed precision
|
||||
scaled_fp8 = None
|
||||
optimizations = {"fp8": False}
|
||||
|
||||
@classmethod
|
||||
@@ -118,7 +117,3 @@ class BASE:
|
||||
def set_inference_dtype(self, dtype, manual_cast_dtype):
|
||||
self.unet_config['dtype'] = dtype
|
||||
self.manual_cast_dtype = manual_cast_dtype
|
||||
|
||||
def __getattr__(self, name):
|
||||
logging.warning("\nWARNING, you accessed {} from the model config object which doesn't exist. Please fix your code.\n".format(name))
|
||||
return None
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user