Compare commits

..

1 Commits

Author SHA1 Message Date
Yoland Yan
ade9a4e034 Add Window run back 2025-05-20 10:13:06 -07:00
196 changed files with 5302 additions and 177249 deletions

View File

@@ -4,9 +4,6 @@ if you have a NVIDIA gpu:
run_nvidia_gpu.bat
if you want to enable the fast fp16 accumulation (faster for fp16 models with slightly less quality):
run_nvidia_gpu_fast_fp16_accumulation.bat
To run it in slow CPU mode:

1
.gitattributes vendored
View File

@@ -1,3 +1,2 @@
/web/assets/** linguist-generated
/web/** linguist-vendored
comfy_api_nodes/apis/__init__.py linguist-generated

View File

@@ -15,14 +15,6 @@ body:
steps to replicate what went wrong and others will be able to repeat your steps and see the same issue happen.
If unsure, ask on the [ComfyUI Matrix Space](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) or the [Comfy Org Discord](https://discord.gg/comfyorg) first.
- type: checkboxes
id: custom-nodes-test
attributes:
label: Custom Node Testing
description: Please confirm you have tried to reproduce the issue with all custom nodes disabled.
options:
- label: I have tried disabling custom nodes and the issue persists (see [how to disable custom nodes](https://docs.comfy.org/troubleshooting/custom-node-issues#step-1%3A-test-with-all-custom-nodes-disabled) if you need help)
required: true
- type: textarea
attributes:
label: Expected Behavior

View File

@@ -11,14 +11,6 @@ body:
**2:** You have made an effort to find public answers to your question before asking here. In other words, you googled it first, and scrolled through recent help topics.
If unsure, ask on the [ComfyUI Matrix Space](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) or the [Comfy Org Discord](https://discord.gg/comfyorg) first.
- type: checkboxes
id: custom-nodes-test
attributes:
label: Custom Node Testing
description: Please confirm you have tried to reproduce the issue with all custom nodes disabled.
options:
- label: I have tried disabling custom nodes and the issue persists (see [how to disable custom nodes](https://docs.comfy.org/troubleshooting/custom-node-issues#step-1%3A-test-with-all-custom-nodes-disabled) if you need help)
required: true
- type: textarea
attributes:
label: Your question

View File

@@ -1,40 +0,0 @@
name: Check for Windows Line Endings
on:
pull_request:
branches: ['*'] # Trigger on all pull requests to any branch
jobs:
check-line-endings:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0 # Fetch all history to compare changes
- name: Check for Windows line endings (CRLF)
run: |
# Get the list of changed files in the PR
CHANGED_FILES=$(git diff --name-only ${{ github.event.pull_request.base.sha }}..${{ github.event.pull_request.head.sha }})
# Flag to track if CRLF is found
CRLF_FOUND=false
# Loop through each changed file
for FILE in $CHANGED_FILES; do
# Check if the file exists and is a text file
if [ -f "$FILE" ] && file "$FILE" | grep -q "text"; then
# Check for CRLF line endings
if grep -UP '\r$' "$FILE"; then
echo "Error: Windows line endings (CRLF) detected in $FILE"
CRLF_FOUND=true
fi
fi
done
# Exit with error if CRLF was found
if [ "$CRLF_FOUND" = true ]; then
exit 1
fi

View File

@@ -1,108 +0,0 @@
name: Release Webhook
on:
release:
types: [published]
jobs:
send-webhook:
runs-on: ubuntu-latest
steps:
- name: Send release webhook
env:
WEBHOOK_URL: ${{ secrets.RELEASE_GITHUB_WEBHOOK_URL }}
WEBHOOK_SECRET: ${{ secrets.RELEASE_GITHUB_WEBHOOK_SECRET }}
run: |
# Generate UUID for delivery ID
DELIVERY_ID=$(uuidgen)
HOOK_ID="release-webhook-$(date +%s)"
# Create webhook payload matching GitHub release webhook format
PAYLOAD=$(cat <<EOF
{
"action": "published",
"release": {
"id": ${{ github.event.release.id }},
"node_id": "${{ github.event.release.node_id }}",
"url": "${{ github.event.release.url }}",
"html_url": "${{ github.event.release.html_url }}",
"assets_url": "${{ github.event.release.assets_url }}",
"upload_url": "${{ github.event.release.upload_url }}",
"tag_name": "${{ github.event.release.tag_name }}",
"target_commitish": "${{ github.event.release.target_commitish }}",
"name": ${{ toJSON(github.event.release.name) }},
"body": ${{ toJSON(github.event.release.body) }},
"draft": ${{ github.event.release.draft }},
"prerelease": ${{ github.event.release.prerelease }},
"created_at": "${{ github.event.release.created_at }}",
"published_at": "${{ github.event.release.published_at }}",
"author": {
"login": "${{ github.event.release.author.login }}",
"id": ${{ github.event.release.author.id }},
"node_id": "${{ github.event.release.author.node_id }}",
"avatar_url": "${{ github.event.release.author.avatar_url }}",
"url": "${{ github.event.release.author.url }}",
"html_url": "${{ github.event.release.author.html_url }}",
"type": "${{ github.event.release.author.type }}",
"site_admin": ${{ github.event.release.author.site_admin }}
},
"tarball_url": "${{ github.event.release.tarball_url }}",
"zipball_url": "${{ github.event.release.zipball_url }}",
"assets": ${{ toJSON(github.event.release.assets) }}
},
"repository": {
"id": ${{ github.event.repository.id }},
"node_id": "${{ github.event.repository.node_id }}",
"name": "${{ github.event.repository.name }}",
"full_name": "${{ github.event.repository.full_name }}",
"private": ${{ github.event.repository.private }},
"owner": {
"login": "${{ github.event.repository.owner.login }}",
"id": ${{ github.event.repository.owner.id }},
"node_id": "${{ github.event.repository.owner.node_id }}",
"avatar_url": "${{ github.event.repository.owner.avatar_url }}",
"url": "${{ github.event.repository.owner.url }}",
"html_url": "${{ github.event.repository.owner.html_url }}",
"type": "${{ github.event.repository.owner.type }}",
"site_admin": ${{ github.event.repository.owner.site_admin }}
},
"html_url": "${{ github.event.repository.html_url }}",
"clone_url": "${{ github.event.repository.clone_url }}",
"git_url": "${{ github.event.repository.git_url }}",
"ssh_url": "${{ github.event.repository.ssh_url }}",
"url": "${{ github.event.repository.url }}",
"created_at": "${{ github.event.repository.created_at }}",
"updated_at": "${{ github.event.repository.updated_at }}",
"pushed_at": "${{ github.event.repository.pushed_at }}",
"default_branch": "${{ github.event.repository.default_branch }}",
"fork": ${{ github.event.repository.fork }}
},
"sender": {
"login": "${{ github.event.sender.login }}",
"id": ${{ github.event.sender.id }},
"node_id": "${{ github.event.sender.node_id }}",
"avatar_url": "${{ github.event.sender.avatar_url }}",
"url": "${{ github.event.sender.url }}",
"html_url": "${{ github.event.sender.html_url }}",
"type": "${{ github.event.sender.type }}",
"site_admin": ${{ github.event.sender.site_admin }}
}
}
EOF
)
# Generate HMAC-SHA256 signature
SIGNATURE=$(echo -n "$PAYLOAD" | openssl dgst -sha256 -hmac "$WEBHOOK_SECRET" -hex | cut -d' ' -f2)
# Send webhook with required headers
curl -X POST "$WEBHOOK_URL" \
-H "Content-Type: application/json" \
-H "X-GitHub-Event: release" \
-H "X-GitHub-Delivery: $DELIVERY_ID" \
-H "X-GitHub-Hook-ID: $HOOK_ID" \
-H "X-Hub-Signature-256: sha256=$SIGNATURE" \
-H "User-Agent: GitHub-Actions-Webhook/1.0" \
-d "$PAYLOAD" \
--fail --silent --show-error
echo "✅ Release webhook sent successfully"

View File

@@ -102,4 +102,5 @@ jobs:
file: ComfyUI_windows_portable_nvidia.7z
tag: ${{ inputs.git_tag }}
overwrite: true
draft: true
prerelease: true
make_latest: false

View File

@@ -20,9 +20,9 @@ jobs:
strategy:
fail-fast: false
matrix:
# os: [macos, linux, windows]
os: [macos, linux]
python_version: ["3.9", "3.10", "3.11", "3.12"]
os: [macos, linux, windows]
# os: [macos, linux]
python_version: ["3.10", "3.11", "3.12"]
cuda_version: ["12.1"]
torch_version: ["stable"]
include:
@@ -32,9 +32,9 @@ jobs:
- os: linux
runner_label: [self-hosted, Linux]
flags: ""
# - os: windows
# runner_label: [self-hosted, Windows]
# flags: ""
- os: windows
runner_label: [self-hosted, Windows]
flags: ""
runs-on: ${{ matrix.runner_label }}
steps:
- name: Test Workflows
@@ -46,28 +46,28 @@ jobs:
google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
comfyui_flags: ${{ matrix.flags }}
# test-win-nightly:
# strategy:
# fail-fast: true
# matrix:
# os: [windows]
# python_version: ["3.9", "3.10", "3.11", "3.12"]
# cuda_version: ["12.1"]
# torch_version: ["nightly"]
# include:
# - os: windows
# runner_label: [self-hosted, Windows]
# flags: ""
# runs-on: ${{ matrix.runner_label }}
# steps:
# - name: Test Workflows
# uses: comfy-org/comfy-action@main
# with:
# os: ${{ matrix.os }}
# python_version: ${{ matrix.python_version }}
# torch_version: ${{ matrix.torch_version }}
# google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
# comfyui_flags: ${{ matrix.flags }}
test-win-nightly:
strategy:
fail-fast: true
matrix:
os: [windows]
python_version: ["3.10", "3.11", "3.12"]
cuda_version: ["12.1"]
torch_version: ["nightly"]
include:
- os: windows
runner_label: [self-hosted, Windows]
flags: ""
runs-on: ${{ matrix.runner_label }}
steps:
- name: Test Workflows
uses: comfy-org/comfy-action@main
with:
os: ${{ matrix.os }}
python_version: ${{ matrix.python_version }}
torch_version: ${{ matrix.torch_version }}
google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
comfyui_flags: ${{ matrix.flags }}
test-unix-nightly:
strategy:

View File

@@ -7,7 +7,7 @@ on:
description: 'cuda version'
required: true
type: string
default: "129"
default: "128"
python_minor:
description: 'python minor version'
@@ -19,7 +19,7 @@ on:
description: 'python patch version'
required: true
type: string
default: "5"
default: "2"
# push:
# branches:
# - master
@@ -53,8 +53,6 @@ jobs:
ls ../temp_wheel_dir
./python.exe -s -m pip install --pre ../temp_wheel_dir/*
sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
rm ./Lib/site-packages/torch/lib/dnnl.lib #I don't think this is actually used and I need the space
cd ..
git clone --depth 1 https://github.com/comfyanonymous/taesd

View File

@@ -5,20 +5,20 @@
# Inlined the team members for now.
# Maintainers
*.md @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/tests/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/tests-unit/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/notebooks/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/script_examples/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/.github/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/requirements.txt @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/pyproject.toml @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
*.md @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/tests/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/tests-unit/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/notebooks/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/script_examples/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/.github/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/requirements.txt @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/pyproject.toml @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
# Python web server
/api_server/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
/app/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
/utils/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
/api_server/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @christian-byrne
/app/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @christian-byrne
/utils/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @christian-byrne
# Node developers
/comfy_extras/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
/comfy/comfy_types/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
/comfy/comfy_types/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne

View File

@@ -6,7 +6,6 @@
[![Website][website-shield]][website-url]
[![Dynamic JSON Badge][discord-shield]][discord-url]
[![Twitter][twitter-shield]][twitter-url]
[![Matrix][matrix-shield]][matrix-url]
<br>
[![][github-release-shield]][github-release-link]
@@ -21,8 +20,6 @@
<!-- Workaround to display total user from https://github.com/badges/shields/issues/4500#issuecomment-2060079995 -->
[discord-shield]: https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fdiscord.com%2Fapi%2Finvites%2Fcomfyorg%3Fwith_counts%3Dtrue&query=%24.approximate_member_count&logo=discord&logoColor=white&label=Discord&color=green&suffix=%20total
[discord-url]: https://www.comfy.org/discord
[twitter-shield]: https://img.shields.io/twitter/follow/ComfyUI
[twitter-url]: https://x.com/ComfyUI
[github-release-shield]: https://img.shields.io/github/v/release/comfyanonymous/ComfyUI?style=flat&sort=semver
[github-release-link]: https://github.com/comfyanonymous/ComfyUI/releases
@@ -55,7 +52,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
## Features
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
- Image Models
- SD1.x, SD2.x ([unCLIP](https://comfyanonymous.github.io/ComfyUI_examples/unclip/))
- SD1.x, SD2.x,
- [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [SDXL Turbo](https://comfyanonymous.github.io/ComfyUI_examples/sdturbo/)
- [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/)
- [SD3 and SD3.5](https://comfyanonymous.github.io/ComfyUI_examples/sd3/)
@@ -65,20 +62,13 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
- [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
- [Lumina Image 2.0](https://comfyanonymous.github.io/ComfyUI_examples/lumina2/)
- [HiDream](https://comfyanonymous.github.io/ComfyUI_examples/hidream/)
- [Cosmos Predict2](https://comfyanonymous.github.io/ComfyUI_examples/cosmos_predict2/)
- [Qwen Image](https://comfyanonymous.github.io/ComfyUI_examples/qwen_image/)
- Image Editing Models
- [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/)
- [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model)
- [HiDream E1.1](https://comfyanonymous.github.io/ComfyUI_examples/hidream/#hidream-e11)
- Video Models
- [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/)
- [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
- [LTX-Video](https://comfyanonymous.github.io/ComfyUI_examples/ltxv/)
- [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
- [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/) and [Cosmos Predict2](https://comfyanonymous.github.io/ComfyUI_examples/cosmos_predict2/)
- [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/)
- [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/)
- [Wan 2.2](https://comfyanonymous.github.io/ComfyUI_examples/wan22/)
- Audio Models
- [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
- [ACE Step](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
@@ -86,10 +76,9 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
- [Hunyuan3D 2.0](https://docs.comfy.org/tutorials/3d/hunyuan3D-2)
- Asynchronous Queue system
- Many optimizations: Only re-executes the parts of the workflow that changes between executions.
- Smart memory management: can automatically run large models on GPUs with as low as 1GB vram with smart offloading.
- Smart memory management: can automatically run models on GPUs with as low as 1GB vram.
- Works even if you don't have a GPU with: ```--cpu``` (slow)
- Can load ckpt and safetensors: All in one checkpoints or standalone diffusion models, VAEs and CLIP models.
- Safe loading of ckpt, pt, pth, etc.. files.
- Can load ckpt, safetensors and diffusers models/checkpoints. Standalone VAEs and CLIP models.
- Embeddings/Textual inversion
- [Loras (regular, locon and loha)](https://comfyanonymous.github.io/ComfyUI_examples/lora/)
- [Hypernetworks](https://comfyanonymous.github.io/ComfyUI_examples/hypernetworks/)
@@ -100,19 +89,20 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
- [Inpainting](https://comfyanonymous.github.io/ComfyUI_examples/inpaint/) with both regular and inpainting models.
- [ControlNet and T2I-Adapter](https://comfyanonymous.github.io/ComfyUI_examples/controlnet/)
- [Upscale Models (ESRGAN, ESRGAN variants, SwinIR, Swin2SR, etc...)](https://comfyanonymous.github.io/ComfyUI_examples/upscale_models/)
- [unCLIP Models](https://comfyanonymous.github.io/ComfyUI_examples/unclip/)
- [GLIGEN](https://comfyanonymous.github.io/ComfyUI_examples/gligen/)
- [Model Merging](https://comfyanonymous.github.io/ComfyUI_examples/model_merging/)
- [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/)
- Latent previews with [TAESD](#how-to-show-high-quality-previews)
- Works fully offline: core will never download anything unless you want to.
- Optional API nodes to use paid models from external providers through the online [Comfy API](https://docs.comfy.org/tutorials/api-nodes/overview).
- Starts up very fast.
- Works fully offline: will never download anything.
- [Config file](extra_model_paths.yaml.example) to set the search paths for models.
Workflow examples can be found on the [Examples page](https://comfyanonymous.github.io/ComfyUI_examples/)
## Release Process
ComfyUI follows a weekly release cycle targeting Friday but this regularly changes because of model releases or large changes to the codebase. There are three interconnected repositories:
ComfyUI follows a weekly release cycle every Friday, with three interconnected repositories:
1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)**
- Releases a new stable version (e.g., v0.7.0)
@@ -180,6 +170,10 @@ If you have trouble extracting it, right click the file -> properties -> unblock
See the [Config file](extra_model_paths.yaml.example) to set the search paths for models. In the standalone windows build you can find this file in the ComfyUI directory. Rename this file to extra_model_paths.yaml and edit it with your favorite text editor.
## Jupyter Notebook
To run it on services like paperspace, kaggle or colab you can use my [Jupyter Notebook](notebooks/comfyui_colab.ipynb)
## [comfy-cli](https://docs.comfy.org/comfy-cli/getting-started)
@@ -203,7 +197,7 @@ Put your VAE in: models/vae
### AMD GPUs (Linux only)
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.4```
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.3```
This is the command to install the nightly with ROCm 6.4 which might have some performance improvements:
@@ -237,11 +231,11 @@ Additional discussion and help can be found [here](https://github.com/comfyanony
Nvidia users should install stable pytorch using this command:
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu129```
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu128```
This is the command to install pytorch nightly instead which might have performance improvements.
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu129```
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128```
#### Troubleshooting
@@ -274,8 +268,6 @@ You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS ve
#### DirectML (AMD Cards on Windows)
This is very badly supported and is not recommended. There are some unofficial builds of pytorch ROCm on windows that exist that will give you a much better experience than this. This readme will be updated once official pytorch ROCm builds for windows come out.
```pip install torch-directml``` Then you can launch ComfyUI with: ```python main.py --directml```
#### Ascend NPUs
@@ -295,13 +287,6 @@ For models compatible with Cambricon Extension for PyTorch (torch_mlu). Here's a
2. Next, install the PyTorch(torch_mlu) following the instructions on the [Installation](https://www.cambricon.com/docs/sdk_1.15.0/cambricon_pytorch_1.17.0/user_guide_1.9/index.html)
3. Launch ComfyUI by running `python main.py`
#### Iluvatar Corex
For models compatible with Iluvatar Extension for PyTorch. Here's a step-by-step guide tailored to your platform and installation method:
1. Install the Iluvatar Corex Toolkit by adhering to the platform-specific instructions on the [Installation](https://support.iluvatar.com/#/DocumentCentre?id=1&nameCenter=2&productId=520117912052801536)
2. Launch ComfyUI by running `python main.py`
# Running
```python main.py```

View File

@@ -1,84 +0,0 @@
# A generic, single database configuration.
[alembic]
# path to migration scripts
# Use forward slashes (/) also on windows to provide an os agnostic path
script_location = alembic_db
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
# for all available tokens
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory.
prepend_sys_path = .
# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library.
# Any required deps can installed by adding `alembic[tz]` to the pip requirements
# string value is passed to ZoneInfo()
# leave blank for localtime
# timezone =
# max length of characters to apply to the "slug" field
# truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false
# version location specification; This defaults
# to alembic_db/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "version_path_separator" below.
# version_locations = %(here)s/bar:%(here)s/bat:alembic_db/versions
# version path separator; As mentioned above, this is the character used to split
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
# Valid values for version_path_separator are:
#
# version_path_separator = :
# version_path_separator = ;
# version_path_separator = space
# version_path_separator = newline
#
# Use os.pathsep. Default configuration used for new projects.
version_path_separator = os
# set to 'true' to search source files recursively
# in each "version_locations" directory
# new in Alembic version 1.10
# recursive_version_locations = false
# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8
sqlalchemy.url = sqlite:///user/comfyui.db
[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples
# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME
# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
# hooks = ruff
# ruff.type = exec
# ruff.executable = %(here)s/.venv/bin/ruff
# ruff.options = check --fix REVISION_SCRIPT_FILENAME

View File

@@ -1,4 +0,0 @@
## Generate new revision
1. Update models in `/app/database/models.py`
2. Run `alembic revision --autogenerate -m "{your message}"`

View File

@@ -1,64 +0,0 @@
from sqlalchemy import engine_from_config
from sqlalchemy import pool
from alembic import context
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
from app.database.models import Base
target_metadata = Base.metadata
# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.
This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
"""
url = config.get_main_option("sqlalchemy.url")
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
def run_migrations_online() -> None:
"""Run migrations in 'online' mode.
In this scenario we need to create an Engine
and associate a connection with the context.
"""
connectable = engine_from_config(
config.get_section(config.config_ini_section, {}),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
with connectable.connect() as connection:
context.configure(
connection=connection, target_metadata=target_metadata
)
with context.begin_transaction():
context.run_migrations()
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()

View File

@@ -1,28 +0,0 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
def upgrade() -> None:
"""Upgrade schema."""
${upgrades if upgrades else "pass"}
def downgrade() -> None:
"""Downgrade schema."""
${downgrades if downgrades else "pass"}

View File

@@ -1,40 +0,0 @@
"""init
Revision ID: e9c714da8d57
Revises:
Create Date: 2025-05-30 20:14:33.772039
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = 'e9c714da8d57'
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
op.create_table('model',
sa.Column('type', sa.Text(), nullable=False),
sa.Column('path', sa.Text(), nullable=False),
sa.Column('file_name', sa.Text(), nullable=True),
sa.Column('file_size', sa.Integer(), nullable=True),
sa.Column('hash', sa.Text(), nullable=True),
sa.Column('hash_algorithm', sa.Text(), nullable=True),
sa.Column('source_url', sa.Text(), nullable=True),
sa.Column('date_added', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True),
sa.PrimaryKeyConstraint('type', 'path')
)
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('model')
# ### end Alembic commands ###

View File

@@ -1,112 +0,0 @@
import logging
import os
import shutil
from app.logger import log_startup_warning
from utils.install_util import get_missing_requirements_message
from comfy.cli_args import args
_DB_AVAILABLE = False
Session = None
try:
from alembic import command
from alembic.config import Config
from alembic.runtime.migration import MigrationContext
from alembic.script import ScriptDirectory
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
_DB_AVAILABLE = True
except ImportError as e:
log_startup_warning(
f"""
------------------------------------------------------------------------
Error importing dependencies: {e}
{get_missing_requirements_message()}
This error is happening because ComfyUI now uses a local sqlite database.
------------------------------------------------------------------------
""".strip()
)
def dependencies_available():
"""
Temporary function to check if the dependencies are available
"""
return _DB_AVAILABLE
def can_create_session():
"""
Temporary function to check if the database is available to create a session
During initial release there may be environmental issues (or missing dependencies) that prevent the database from being created
"""
return dependencies_available() and Session is not None
def get_alembic_config():
root_path = os.path.join(os.path.dirname(__file__), "../..")
config_path = os.path.abspath(os.path.join(root_path, "alembic.ini"))
scripts_path = os.path.abspath(os.path.join(root_path, "alembic_db"))
config = Config(config_path)
config.set_main_option("script_location", scripts_path)
config.set_main_option("sqlalchemy.url", args.database_url)
return config
def get_db_path():
url = args.database_url
if url.startswith("sqlite:///"):
return url.split("///")[1]
else:
raise ValueError(f"Unsupported database URL '{url}'.")
def init_db():
db_url = args.database_url
logging.debug(f"Database URL: {db_url}")
db_path = get_db_path()
db_exists = os.path.exists(db_path)
config = get_alembic_config()
# Check if we need to upgrade
engine = create_engine(db_url)
conn = engine.connect()
context = MigrationContext.configure(conn)
current_rev = context.get_current_revision()
script = ScriptDirectory.from_config(config)
target_rev = script.get_current_head()
if target_rev is None:
logging.warning("No target revision found.")
elif current_rev != target_rev:
# Backup the database pre upgrade
backup_path = db_path + ".bkp"
if db_exists:
shutil.copy(db_path, backup_path)
else:
backup_path = None
try:
command.upgrade(config, target_rev)
logging.info(f"Database upgraded from {current_rev} to {target_rev}")
except Exception as e:
if backup_path:
# Restore the database from backup if upgrade fails
shutil.copy(backup_path, db_path)
os.remove(backup_path)
logging.exception("Error upgrading database: ")
raise e
global Session
Session = sessionmaker(bind=engine)
def create_session():
return Session()

View File

@@ -1,59 +0,0 @@
from sqlalchemy import (
Column,
Integer,
Text,
DateTime,
)
from sqlalchemy.orm import declarative_base
from sqlalchemy.sql import func
Base = declarative_base()
def to_dict(obj):
fields = obj.__table__.columns.keys()
return {
field: (val.to_dict() if hasattr(val, "to_dict") else val)
for field in fields
if (val := getattr(obj, field))
}
class Model(Base):
"""
sqlalchemy model representing a model file in the system.
This class defines the database schema for storing information about model files,
including their type, path, hash, and when they were added to the system.
Attributes:
type (Text): The type of the model, this is the name of the folder in the models folder (primary key)
path (Text): The file path of the model relative to the type folder (primary key)
file_name (Text): The name of the model file
file_size (Integer): The size of the model file in bytes
hash (Text): A hash of the model file
hash_algorithm (Text): The algorithm used to generate the hash
source_url (Text): The URL of the model file
date_added (DateTime): Timestamp of when the model was added to the system
"""
__tablename__ = "model"
type = Column(Text, primary_key=True)
path = Column(Text, primary_key=True)
file_name = Column(Text)
file_size = Column(Integer)
hash = Column(Text)
hash_algorithm = Column(Text)
source_url = Column(Text)
date_added = Column(DateTime, server_default=func.now())
def to_dict(self):
"""
Convert the model instance to a dictionary representation.
Returns:
dict: A dictionary containing the attributes of the model
"""
dict = to_dict(self)
return dict

View File

@@ -16,61 +16,40 @@ from importlib.metadata import version
import requests
from typing_extensions import NotRequired
from utils.install_util import get_missing_requirements_message, requirements_path
from comfy.cli_args import DEFAULT_VERSION_STRING
import app.logger
# The path to the requirements.txt file
req_path = Path(__file__).parents[1] / "requirements.txt"
def frontend_install_warning_message():
"""The warning message to display when the frontend version is not up to date."""
extra = ""
if sys.flags.no_user_site:
extra = "-s "
return f"""
{get_missing_requirements_message()}
Please install the updated requirements.txt file by running:
{sys.executable} {extra}-m pip install -r {req_path}
This error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.
If you are on the portable package you can run: update\\update_comfyui.bat to solve this problem
""".strip()
def parse_version(version: str) -> tuple[int, int, int]:
return tuple(map(int, version.split(".")))
def is_valid_version(version: str) -> bool:
"""Validate if a string is a valid semantic version (X.Y.Z format)."""
pattern = r"^(\d+)\.(\d+)\.(\d+)$"
return bool(re.match(pattern, version))
def get_installed_frontend_version():
"""Get the currently installed frontend package version."""
frontend_version_str = version("comfyui-frontend-package")
return frontend_version_str
def get_required_frontend_version():
"""Get the required frontend version from requirements.txt."""
try:
with open(requirements_path, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if line.startswith("comfyui-frontend-package=="):
version_str = line.split("==")[-1]
if not is_valid_version(version_str):
logging.error(f"Invalid version format in requirements.txt: {version_str}")
return None
return version_str
logging.error("comfyui-frontend-package not found in requirements.txt")
return None
except FileNotFoundError:
logging.error("requirements.txt not found. Cannot determine required frontend version.")
return None
except Exception as e:
logging.error(f"Error reading requirements.txt: {e}")
return None
def check_frontend_version():
"""Check if the frontend version is up to date."""
def parse_version(version: str) -> tuple[int, int, int]:
return tuple(map(int, version.split(".")))
try:
frontend_version_str = get_installed_frontend_version()
frontend_version_str = version("comfyui-frontend-package")
frontend_version = parse_version(frontend_version_str)
required_frontend_str = get_required_frontend_version()
required_frontend = parse_version(required_frontend_str)
with open(req_path, "r", encoding="utf-8") as f:
required_frontend = parse_version(f.readline().split("=")[-1])
if frontend_version < required_frontend:
app.logger.log_startup_warning(
f"""
@@ -142,22 +121,9 @@ class FrontEndProvider:
response.raise_for_status() # Raises an HTTPError if the response was an error
return response.json()
@cached_property
def latest_prerelease(self) -> Release:
"""Get the latest pre-release version - even if it's older than the latest release"""
release = [release for release in self.all_releases if release["prerelease"]]
if not release:
raise ValueError("No pre-releases found")
# GitHub returns releases in reverse chronological order, so first is latest
return release[0]
def get_release(self, version: str) -> Release:
if version == "latest":
return self.latest_release
elif version == "prerelease":
return self.latest_prerelease
else:
for release in self.all_releases:
if release["tag_name"] in [version, f"v{version}"]:
@@ -196,35 +162,10 @@ def download_release_asset_zip(release: Release, destination_path: str) -> None:
class FrontendManager:
"""
A class to manage ComfyUI frontend versions and installations.
This class handles the initialization and management of different frontend versions,
including the default frontend from the pip package and custom frontend versions
from GitHub repositories.
Attributes:
CUSTOM_FRONTENDS_ROOT (str): The root directory where custom frontend versions are stored.
"""
CUSTOM_FRONTENDS_ROOT = str(Path(__file__).parents[1] / "web_custom_versions")
@classmethod
def get_required_frontend_version(cls) -> str:
"""Get the required frontend package version."""
return get_required_frontend_version()
@classmethod
def default_frontend_path(cls) -> str:
"""
Get the path to the default frontend installation from the pip package.
Returns:
str: The path to the default frontend static files.
Raises:
SystemExit: If the comfyui-frontend-package is not installed.
"""
try:
import comfyui_frontend_package
@@ -245,15 +186,6 @@ comfyui-frontend-package is not installed.
@classmethod
def templates_path(cls) -> str:
"""
Get the path to the workflow templates.
Returns:
str: The path to the workflow templates directory.
Raises:
SystemExit: If the comfyui-workflow-templates package is not installed.
"""
try:
import comfyui_workflow_templates
@@ -273,37 +205,19 @@ comfyui-workflow-templates is not installed.
""".strip()
)
@classmethod
def embedded_docs_path(cls) -> str:
"""Get the path to embedded documentation"""
try:
import comfyui_embedded_docs
return str(
importlib.resources.files(comfyui_embedded_docs) / "docs"
)
except ImportError:
logging.info("comfyui-embedded-docs package not found")
return None
@classmethod
def parse_version_string(cls, value: str) -> tuple[str, str, str]:
"""
Parse a version string into its components.
The version string should be in the format: 'owner/repo@version'
where version can be either a semantic version (v1.2.3) or 'latest'.
Args:
value (str): The version string to parse.
Returns:
tuple[str, str, str]: A tuple containing (owner, repo, version).
tuple[str, str]: A tuple containing provider name and version.
Raises:
argparse.ArgumentTypeError: If the version string is invalid.
"""
VERSION_PATTERN = r"^([a-zA-Z0-9][a-zA-Z0-9-]{0,38})/([a-zA-Z0-9_.-]+)@(v?\d+\.\d+\.\d+[-._a-zA-Z0-9]*|latest|prerelease)$"
VERSION_PATTERN = r"^([a-zA-Z0-9][a-zA-Z0-9-]{0,38})/([a-zA-Z0-9_.-]+)@(v?\d+\.\d+\.\d+|latest)$"
match_result = re.match(VERSION_PATTERN, value)
if match_result is None:
raise argparse.ArgumentTypeError(f"Invalid version string: {value}")
@@ -315,22 +229,18 @@ comfyui-workflow-templates is not installed.
cls, version_string: str, provider: Optional[FrontEndProvider] = None
) -> str:
"""
Initialize a frontend version without error handling.
This method attempts to initialize a specific frontend version, either from
the default pip package or from a custom GitHub repository. It will download
and extract the frontend files if necessary.
Initializes the frontend for the specified version.
Args:
version_string (str): The version string specifying which frontend to use.
provider (FrontEndProvider, optional): The provider to use for custom frontends.
version_string (str): The version string.
provider (FrontEndProvider, optional): The provider to use. Defaults to None.
Returns:
str: The path to the initialized frontend.
Raises:
Exception: If there is an error during initialization (e.g., network timeout,
invalid URL, or missing assets).
Exception: If there is an error during the initialization process.
main error source might be request timeout or invalid URL.
"""
if version_string == DEFAULT_VERSION_STRING:
check_frontend_version()
@@ -382,17 +292,13 @@ comfyui-workflow-templates is not installed.
@classmethod
def init_frontend(cls, version_string: str) -> str:
"""
Initialize a frontend version with error handling.
This is the main method to initialize a frontend version. It wraps init_frontend_unsafe
with error handling, falling back to the default frontend if initialization fails.
Initializes the frontend with the specified version string.
Args:
version_string (str): The version string specifying which frontend to use.
version_string (str): The version string to initialize the frontend with.
Returns:
str: The path to the initialized frontend. If initialization fails,
returns the path to the default frontend.
str: The path of the initialized frontend.
"""
try:
return cls.init_frontend_unsafe(version_string)

View File

@@ -130,21 +130,10 @@ class ModelFileManager:
for file_name in filenames:
try:
full_path = os.path.join(dirpath, file_name)
relative_path = os.path.relpath(full_path, directory)
# Get file metadata
file_info = {
"name": relative_path,
"pathIndex": pathIndex,
"modified": os.path.getmtime(full_path), # Add modification time
"created": os.path.getctime(full_path), # Add creation time
"size": os.path.getsize(full_path) # Add file size
}
result.append(file_info)
except Exception as e:
logging.warning(f"Warning: Unable to access {file_name}. Error: {e}. Skipping this file.")
relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory)
result.append(relative_path)
except:
logging.warning(f"Warning: Unable to access {file_name}. Skipping this file.")
continue
for d in subdirs:
@@ -155,7 +144,7 @@ class ModelFileManager:
logging.warning(f"Warning: Unable to access {path}. Skipping this path.")
continue
return result, dirs, time.perf_counter()
return [{"name": f, "pathIndex": pathIndex} for f in result], dirs, time.perf_counter()
def get_model_previews(self, filepath: str) -> list[str | BytesIO]:
dirname = os.path.dirname(filepath)

View File

@@ -1,331 +0,0 @@
import os
import logging
import time
import requests
from tqdm import tqdm
from folder_paths import get_relative_path, get_full_path
from app.database.db import create_session, dependencies_available, can_create_session
import blake3
import comfy.utils
if dependencies_available():
from app.database.models import Model
class ModelProcessor:
def _validate_path(self, model_path):
try:
if not self._file_exists(model_path):
logging.error(f"Model file not found: {model_path}")
return None
result = get_relative_path(model_path)
if not result:
logging.error(
f"Model file not in a recognized model directory: {model_path}"
)
return None
return result
except Exception as e:
logging.error(f"Error validating model path {model_path}: {str(e)}")
return None
def _file_exists(self, path):
"""Check if a file exists."""
return os.path.exists(path)
def _get_file_size(self, path):
"""Get file size."""
return os.path.getsize(path)
def _get_hasher(self):
return blake3.blake3()
def _hash_file(self, model_path):
try:
hasher = self._get_hasher()
with open(model_path, "rb", buffering=0) as f:
b = bytearray(128 * 1024)
mv = memoryview(b)
while n := f.readinto(mv):
hasher.update(mv[:n])
return hasher.hexdigest()
except Exception as e:
logging.error(f"Error hashing file {model_path}: {str(e)}")
return None
def _get_existing_model(self, session, model_type, model_relative_path):
return (
session.query(Model)
.filter(Model.type == model_type)
.filter(Model.path == model_relative_path)
.first()
)
def _ensure_source_url(self, session, model, source_url):
if model.source_url is None:
model.source_url = source_url
session.commit()
def _update_database(
self,
session,
model_type,
model_path,
model_relative_path,
model_hash,
model,
source_url,
):
try:
if not model:
model = self._get_existing_model(
session, model_type, model_relative_path
)
if not model:
model = Model(
path=model_relative_path,
type=model_type,
file_name=os.path.basename(model_path),
)
session.add(model)
model.file_size = self._get_file_size(model_path)
model.hash = model_hash
if model_hash:
model.hash_algorithm = "blake3"
model.source_url = source_url
session.commit()
return model
except Exception as e:
logging.error(
f"Error updating database for {model_relative_path}: {str(e)}"
)
def process_file(self, model_path, source_url=None, model_hash=None):
"""
Process a model file and update the database with metadata.
If the file already exists and matches the database, it will not be processed again.
Returns the model object or if an error occurs, returns None.
"""
try:
if not can_create_session():
return
result = self._validate_path(model_path)
if not result:
return
model_type, model_relative_path = result
with create_session() as session:
session.expire_on_commit = False
existing_model = self._get_existing_model(
session, model_type, model_relative_path
)
if (
existing_model
and existing_model.hash
and existing_model.file_size == self._get_file_size(model_path)
):
# File exists with hash and same size, no need to process
self._ensure_source_url(session, existing_model, source_url)
return existing_model
if model_hash:
model_hash = model_hash.lower()
logging.info(f"Using provided hash: {model_hash}")
else:
start_time = time.time()
logging.info(f"Hashing model {model_relative_path}")
model_hash = self._hash_file(model_path)
if not model_hash:
return
logging.info(
f"Model hash: {model_hash} (duration: {time.time() - start_time} seconds)"
)
return self._update_database(
session,
model_type,
model_path,
model_relative_path,
model_hash,
existing_model,
source_url,
)
except Exception as e:
logging.error(f"Error processing model file {model_path}: {str(e)}")
return None
def retrieve_model_by_hash(self, model_hash, model_type=None, session=None):
"""
Retrieve a model file from the database by hash and optionally by model type.
Returns the model object or None if the model doesnt exist or an error occurs.
"""
try:
if not can_create_session():
return
dispose_session = False
if session is None:
session = create_session()
dispose_session = True
model = session.query(Model).filter(Model.hash == model_hash)
if model_type is not None:
model = model.filter(Model.type == model_type)
return model.first()
except Exception as e:
logging.error(f"Error retrieving model by hash {model_hash}: {str(e)}")
return None
finally:
if dispose_session:
session.close()
def retrieve_hash(self, model_path, model_type=None):
"""
Retrieve the hash of a model file from the database.
Returns the hash or None if the model doesnt exist or an error occurs.
"""
try:
if not can_create_session():
return
if model_type is not None:
result = self._validate_path(model_path)
if not result:
return None
model_type, model_relative_path = result
with create_session() as session:
model = self._get_existing_model(
session, model_type, model_relative_path
)
if model and model.hash:
return model.hash
return None
except Exception as e:
logging.error(f"Error retrieving hash for {model_path}: {str(e)}")
return None
def _validate_file_extension(self, file_name):
"""Validate that the file extension is supported."""
extension = os.path.splitext(file_name)[1]
if extension not in (".safetensors", ".sft", ".txt", ".csv", ".json", ".yaml"):
raise ValueError(f"Unsupported unsafe file for download: {file_name}")
def _check_existing_file(self, model_type, file_name, expected_hash):
"""Check if file exists and has correct hash."""
destination_path = get_full_path(model_type, file_name, allow_missing=True)
if self._file_exists(destination_path):
model = self.process_file(destination_path)
if model and (expected_hash is None or model.hash == expected_hash):
logging.debug(
f"File {destination_path} already exists in the database and has the correct hash or no hash was provided."
)
return destination_path
else:
raise ValueError(
f"File {destination_path} exists with hash {model.hash if model else 'unknown'} but expected {expected_hash}. Please delete the file and try again."
)
return None
def _check_existing_file_by_hash(self, hash, type, url):
"""Check if a file with the given hash exists in the database and on disk."""
hash = hash.lower()
with create_session() as session:
model = self.retrieve_model_by_hash(hash, type, session)
if model:
existing_path = get_full_path(type, model.path)
if existing_path:
logging.debug(
f"File {model.path} already exists in the database at {existing_path}"
)
self._ensure_source_url(session, model, url)
return existing_path
else:
logging.debug(
f"File {model.path} exists in the database but not on disk"
)
return None
def _download_file(self, url, destination_path, hasher):
"""Download a file and update the hasher with its contents."""
response = requests.get(url, stream=True)
logging.info(f"Downloading {url} to {destination_path}")
with open(destination_path, "wb") as f:
total_size = int(response.headers.get("content-length", 0))
if total_size > 0:
pbar = comfy.utils.ProgressBar(total_size)
else:
pbar = None
with tqdm(total=total_size, unit="B", unit_scale=True) as progress_bar:
for chunk in response.iter_content(chunk_size=128 * 1024):
if chunk:
f.write(chunk)
hasher.update(chunk)
progress_bar.update(len(chunk))
if pbar:
pbar.update(len(chunk))
def _verify_downloaded_hash(self, calculated_hash, expected_hash, destination_path):
"""Verify that the downloaded file has the expected hash."""
if expected_hash is not None and calculated_hash != expected_hash:
self._remove_file(destination_path)
raise ValueError(
f"Downloaded file hash {calculated_hash} does not match expected hash {expected_hash}"
)
def _remove_file(self, file_path):
"""Remove a file from disk."""
os.remove(file_path)
def ensure_downloaded(self, type, url, desired_file_name, hash=None):
"""
Ensure a model file is downloaded and has the correct hash.
Returns the path to the downloaded file.
"""
logging.debug(
f"Ensuring {type} file is downloaded. URL='{url}' Destination='{desired_file_name}' Hash='{hash}'"
)
# Validate file extension
self._validate_file_extension(desired_file_name)
# Check if file exists with correct hash
if hash:
existing_path = self._check_existing_file_by_hash(hash, type, url)
if existing_path:
return existing_path
# Check if file exists locally
destination_path = get_full_path(type, desired_file_name, allow_missing=True)
existing_path = self._check_existing_file(type, desired_file_name, hash)
if existing_path:
return existing_path
# Download the file
hasher = self._get_hasher()
self._download_file(url, destination_path, hasher)
# Verify hash
calculated_hash = hasher.hexdigest()
self._verify_downloaded_hash(calculated_hash, hash, destination_path)
# Update database
self.process_file(destination_path, url, calculated_hash)
# TODO: Notify frontend to reload models
return destination_path
model_processor = ModelProcessor()

View File

@@ -20,15 +20,13 @@ class FileInfo(TypedDict):
path: str
size: int
modified: int
created: int
def get_file_info(path: str, relative_to: str) -> FileInfo:
return {
"path": os.path.relpath(path, relative_to).replace(os.sep, '/'),
"size": os.path.getsize(path),
"modified": os.path.getmtime(path),
"created": os.path.getctime(path)
"modified": os.path.getmtime(path)
}

View File

@@ -49,8 +49,7 @@ parser.add_argument("--temp-directory", type=str, default=None, help="Set the Co
parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory. Overrides --base-directory.")
parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use. All other devices will not be visible.")
parser.add_argument("--default-device", type=int, default=None, metavar="DEFAULT_DEVICE_ID", help="Set the id of the default device, all other devices will stay visible.")
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
cm_group = parser.add_mutually_exclusive_group()
cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Disable cudaMallocAsync.")
@@ -89,7 +88,6 @@ parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE"
parser.add_argument("--oneapi-device-selector", type=str, default=None, metavar="SELECTOR_STRING", help="Sets the oneAPI device(s) this instance will use.")
parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize default when loading models with Intel's Extension for Pytorch.")
parser.add_argument("--supports-fp8-compute", action="store_true", help="ComfyUI will act like if the device supports fp8 compute.")
class LatentPreviewMethod(enum.Enum):
NoPreviews = "none"
@@ -145,7 +143,6 @@ class PerformanceFeature(enum.Enum):
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult cublas_ops")
parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.")
parser.add_argument("--disable-mmap", action="store_true", help="Don't use mmap when loading safetensors.")
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
@@ -153,7 +150,6 @@ parser.add_argument("--windows-standalone-build", action="store_true", help="Win
parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.")
parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Disable loading all custom nodes.")
parser.add_argument("--whitelist-custom-nodes", type=str, nargs='+', default=[], help="Specify custom node folders to load even when --disable-all-custom-nodes is enabled.")
parser.add_argument("--disable-api-nodes", action="store_true", help="Disable loading all api nodes.")
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
@@ -206,12 +202,6 @@ parser.add_argument(
help="Set the base URL for the ComfyUI API. (default: https://api.comfy.org)",
)
database_default_path = os.path.abspath(
os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db")
)
parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
parser.add_argument("--disable-model-processing", action="store_true", help="Disable model file processing, e.g. computing hashes and extracting metadata.")
if comfy.options.args_parsing:
args = parser.parse_args()
else:

View File

@@ -37,8 +37,6 @@ class IO(StrEnum):
CONTROL_NET = "CONTROL_NET"
VAE = "VAE"
MODEL = "MODEL"
LORA_MODEL = "LORA_MODEL"
LOSS_MAP = "LOSS_MAP"
CLIP_VISION = "CLIP_VISION"
CLIP_VISION_OUTPUT = "CLIP_VISION_OUTPUT"
STYLE_MODEL = "STYLE_MODEL"

View File

@@ -1,7 +1,6 @@
import torch
import math
import comfy.utils
import logging
class CONDRegular:
@@ -11,15 +10,12 @@ class CONDRegular:
def _copy_with(self, cond):
return self.__class__(cond)
def process_cond(self, batch_size, **kwargs):
return self._copy_with(comfy.utils.repeat_to_batch_size(self.cond, batch_size))
def process_cond(self, batch_size, device, **kwargs):
return self._copy_with(comfy.utils.repeat_to_batch_size(self.cond, batch_size).to(device))
def can_concat(self, other):
if self.cond.shape != other.cond.shape:
return False
if self.cond.device != other.cond.device:
logging.warning("WARNING: conds not on same device, skipping concat.")
return False
return True
def concat(self, others):
@@ -28,19 +24,15 @@ class CONDRegular:
conds.append(x.cond)
return torch.cat(conds)
def size(self):
return list(self.cond.size())
class CONDNoiseShape(CONDRegular):
def process_cond(self, batch_size, area, **kwargs):
def process_cond(self, batch_size, device, area, **kwargs):
data = self.cond
if area is not None:
dims = len(area) // 2
for i in range(dims):
data = data.narrow(i + 2, area[i + dims], area[i])
return self._copy_with(comfy.utils.repeat_to_batch_size(data, batch_size))
return self._copy_with(comfy.utils.repeat_to_batch_size(data, batch_size).to(device))
class CONDCrossAttn(CONDRegular):
@@ -55,9 +47,6 @@ class CONDCrossAttn(CONDRegular):
diff = mult_min // min(s1[1], s2[1])
if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
return False
if self.cond.device != other.cond.device:
logging.warning("WARNING: conds not on same device: skipping concat.")
return False
return True
def concat(self, others):
@@ -75,12 +64,11 @@ class CONDCrossAttn(CONDRegular):
out.append(c)
return torch.cat(out)
class CONDConstant(CONDRegular):
def __init__(self, cond):
self.cond = cond
def process_cond(self, batch_size, **kwargs):
def process_cond(self, batch_size, device, **kwargs):
return self._copy_with(self.cond)
def can_concat(self, other):
@@ -90,48 +78,3 @@ class CONDConstant(CONDRegular):
def concat(self, others):
return self.cond
def size(self):
return [1]
class CONDList(CONDRegular):
def __init__(self, cond):
self.cond = cond
def process_cond(self, batch_size, **kwargs):
out = []
for c in self.cond:
out.append(comfy.utils.repeat_to_batch_size(c, batch_size))
return self._copy_with(out)
def can_concat(self, other):
if len(self.cond) != len(other.cond):
return False
for i in range(len(self.cond)):
if self.cond[i].shape != other.cond[i].shape:
return False
return True
def concat(self, others):
out = []
for i in range(len(self.cond)):
o = [self.cond[i]]
for x in others:
o.append(x.cond[i])
out.append(torch.cat(o))
return out
def size(self): # hackish implementation to make the mem estimation work
o = 0
c = 1
for c in self.cond:
size = c.size()
o += math.prod(size)
if len(size) > 1:
c = size[1]
return [1, c, o // c]

View File

@@ -28,7 +28,6 @@ import comfy.model_detection
import comfy.model_patcher
import comfy.ops
import comfy.latent_formats
import comfy.model_base
import comfy.cldm.cldm
import comfy.t2i_adapter.adapter
@@ -44,6 +43,7 @@ if TYPE_CHECKING:
def broadcast_image_to(tensor, target_batch_size, batched_number):
current_batch_size = tensor.shape[0]
#print(current_batch_size, target_batch_size)
if current_batch_size == 1:
return tensor
@@ -265,12 +265,12 @@ class ControlNet(ControlBase):
for c in self.extra_conds:
temp = cond.get(c, None)
if temp is not None:
extra[c] = comfy.model_base.convert_tensor(temp, dtype, x_noisy.device)
extra[c] = temp.to(dtype)
timestep = self.model_sampling_current.timestep(t)
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=comfy.model_management.cast_to_device(context, x_noisy.device, dtype), **extra)
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=context.to(dtype), **extra)
return self.control_merge(control, control_prev, output_dtype=None)
def copy(self):
@@ -390,9 +390,8 @@ class ControlLora(ControlNet):
pass
for k in self.control_weights:
if (k not in {"lora_controlnet"}):
if (k.endswith(".up") or k.endswith(".down") or k.endswith(".weight") or k.endswith(".bias")) and ("__" not in k):
comfy.utils.set_attr_param(self.control_model, k, self.control_weights[k].to(dtype).to(comfy.model_management.get_torch_device()))
if k not in {"lora_controlnet"}:
comfy.utils.set_attr_param(self.control_model, k, self.control_weights[k].to(dtype).to(comfy.model_management.get_torch_device()))
def copy(self):
c = ControlLora(self.control_weights, global_average_pooling=self.global_average_pooling)

View File

@@ -1,10 +1,55 @@
import math
import torch
from torch import nn
from .ldm.modules.attention import CrossAttention, FeedForward
from .ldm.modules.attention import CrossAttention
from inspect import isfunction
import comfy.ops
ops = comfy.ops.manual_cast
def exists(val):
return val is not None
def uniq(arr):
return{el: True for el in arr}.keys()
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
# feedforward
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = ops.Linear(dim_in, dim_out * 2)
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
return x * torch.nn.functional.gelu(gate)
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = nn.Sequential(
ops.Linear(dim, inner_dim),
nn.GELU()
) if not glu else GEGLU(dim, inner_dim)
self.net = nn.Sequential(
project_in,
nn.Dropout(dropout),
ops.Linear(inner_dim, dim_out)
)
def forward(self, x):
return self.net(x)
class GatedCrossAttentionDense(nn.Module):
def __init__(self, query_dim, context_dim, n_heads, d_head):

View File

@@ -1,121 +0,0 @@
# SA-Solver: Stochastic Adams Solver (NeurIPS 2023, arXiv:2309.05019)
# Conference: https://proceedings.neurips.cc/paper_files/paper/2023/file/f4a6806490d31216a3ba667eb240c897-Paper-Conference.pdf
# Codebase ref: https://github.com/scxue/SA-Solver
import math
from typing import Union, Callable
import torch
def compute_exponential_coeffs(s: torch.Tensor, t: torch.Tensor, solver_order: int, tau_t: float) -> torch.Tensor:
"""Compute (1 + tau^2) * integral of exp((1 + tau^2) * x) * x^p dx from s to t with exp((1 + tau^2) * t) factored out, using integration by parts.
Integral of exp((1 + tau^2) * x) * x^p dx
= product_terms[p] - (p / (1 + tau^2)) * integral of exp((1 + tau^2) * x) * x^(p-1) dx,
with base case p=0 where integral equals product_terms[0].
where
product_terms[p] = x^p * exp((1 + tau^2) * x) / (1 + tau^2).
Construct a recursive coefficient matrix following the above recursive relation to compute all integral terms up to p = (solver_order - 1).
Return coefficients used by the SA-Solver in data prediction mode.
Args:
s: Start time s.
t: End time t.
solver_order: Current order of the solver.
tau_t: Stochastic strength parameter in the SDE.
Returns:
Exponential coefficients used in data prediction, with exp((1 + tau^2) * t) factored out, ordered from p=0 to p=solver_order1, shape (solver_order,).
"""
tau_mul = 1 + tau_t ** 2
h = t - s
p = torch.arange(solver_order, dtype=s.dtype, device=s.device)
# product_terms after factoring out exp((1 + tau^2) * t)
# Includes (1 + tau^2) factor from outside the integral
product_terms_factored = (t ** p - s ** p * (-tau_mul * h).exp())
# Lower triangular recursive coefficient matrix
# Accumulates recursive coefficients based on p / (1 + tau^2)
recursive_depth_mat = p.unsqueeze(1) - p.unsqueeze(0)
log_factorial = (p + 1).lgamma()
recursive_coeff_mat = log_factorial.unsqueeze(1) - log_factorial.unsqueeze(0)
if tau_t > 0:
recursive_coeff_mat = recursive_coeff_mat - (recursive_depth_mat * math.log(tau_mul))
signs = torch.where(recursive_depth_mat % 2 == 0, 1.0, -1.0)
recursive_coeff_mat = (recursive_coeff_mat.exp() * signs).tril()
return recursive_coeff_mat @ product_terms_factored
def compute_simple_stochastic_adams_b_coeffs(sigma_next: torch.Tensor, curr_lambdas: torch.Tensor, lambda_s: torch.Tensor, lambda_t: torch.Tensor, tau_t: float, is_corrector_step: bool = False) -> torch.Tensor:
"""Compute simple order-2 b coefficients from SA-Solver paper (Appendix D. Implementation Details)."""
tau_mul = 1 + tau_t ** 2
h = lambda_t - lambda_s
alpha_t = sigma_next * lambda_t.exp()
if is_corrector_step:
# Simplified 1-step (order-2) corrector
b_1 = alpha_t * (0.5 * tau_mul * h)
b_2 = alpha_t * (-h * tau_mul).expm1().neg() - b_1
else:
# Simplified 2-step predictor
b_2 = alpha_t * (0.5 * tau_mul * h ** 2) / (curr_lambdas[-2] - lambda_s)
b_1 = alpha_t * (-h * tau_mul).expm1().neg() - b_2
return torch.stack([b_2, b_1])
def compute_stochastic_adams_b_coeffs(sigma_next: torch.Tensor, curr_lambdas: torch.Tensor, lambda_s: torch.Tensor, lambda_t: torch.Tensor, tau_t: float, simple_order_2: bool = False, is_corrector_step: bool = False) -> torch.Tensor:
"""Compute b_i coefficients for the SA-Solver (see eqs. 15 and 18).
The solver order corresponds to the number of input lambdas (half-logSNR points).
Args:
sigma_next: Sigma at end time t.
curr_lambdas: Lambda time points used to construct the Lagrange basis, shape (N,).
lambda_s: Lambda at start time s.
lambda_t: Lambda at end time t.
tau_t: Stochastic strength parameter in the SDE.
simple_order_2: Whether to enable the simple order-2 scheme.
is_corrector_step: Flag for corrector step in simple order-2 mode.
Returns:
b_i coefficients for the SA-Solver, shape (N,), where N is the solver order.
"""
num_timesteps = curr_lambdas.shape[0]
if simple_order_2 and num_timesteps == 2:
return compute_simple_stochastic_adams_b_coeffs(sigma_next, curr_lambdas, lambda_s, lambda_t, tau_t, is_corrector_step)
# Compute coefficients by solving a linear system from Lagrange basis interpolation
exp_integral_coeffs = compute_exponential_coeffs(lambda_s, lambda_t, num_timesteps, tau_t)
vandermonde_matrix_T = torch.vander(curr_lambdas, num_timesteps, increasing=True).T
lagrange_integrals = torch.linalg.solve(vandermonde_matrix_T, exp_integral_coeffs)
# (sigma_t * exp(-tau^2 * lambda_t)) * exp((1 + tau^2) * lambda_t)
# = sigma_t * exp(lambda_t) = alpha_t
# exp((1 + tau^2) * lambda_t) is extracted from the integral
alpha_t = sigma_next * lambda_t.exp()
return alpha_t * lagrange_integrals
def get_tau_interval_func(start_sigma: float, end_sigma: float, eta: float = 1.0) -> Callable[[Union[torch.Tensor, float]], float]:
"""Return a function that controls the stochasticity of SA-Solver.
When eta = 0, SA-Solver runs as ODE. The official approach uses
time t to determine the SDE interval, while here we use sigma instead.
See:
https://github.com/scxue/SA-Solver/blob/main/README.md
"""
def tau_func(sigma: Union[torch.Tensor, float]) -> float:
if eta <= 0:
return 0.0 # ODE
if isinstance(sigma, torch.Tensor):
sigma = sigma.item()
return eta if start_sigma >= sigma >= end_sigma else 0.0
return tau_func

View File

@@ -1,5 +1,4 @@
import math
from functools import partial
from scipy import integrate
import torch
@@ -9,7 +8,6 @@ from tqdm.auto import trange, tqdm
from . import utils
from . import deis
from . import sa_solver
import comfy.model_patcher
import comfy.model_sampling
@@ -144,33 +142,6 @@ class BrownianTreeNoiseSampler:
return self.tree(t0, t1) / (t1 - t0).abs().sqrt()
def sigma_to_half_log_snr(sigma, model_sampling):
"""Convert sigma to half-logSNR log(alpha_t / sigma_t)."""
if isinstance(model_sampling, comfy.model_sampling.CONST):
# log((1 - t) / t) = log((1 - sigma) / sigma)
return sigma.logit().neg()
return sigma.log().neg()
def half_log_snr_to_sigma(half_log_snr, model_sampling):
"""Convert half-logSNR log(alpha_t / sigma_t) to sigma."""
if isinstance(model_sampling, comfy.model_sampling.CONST):
# 1 / (1 + exp(half_log_snr))
return half_log_snr.neg().sigmoid()
return half_log_snr.neg().exp()
def offset_first_sigma_for_snr(sigmas, model_sampling, percent_offset=1e-4):
"""Adjust the first sigma to avoid invalid logSNR."""
if len(sigmas) <= 1:
return sigmas
if isinstance(model_sampling, comfy.model_sampling.CONST):
if sigmas[0] >= 1:
sigmas = sigmas.clone()
sigmas[0] = model_sampling.percent_to_sigma(percent_offset)
return sigmas
@torch.no_grad()
def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
@@ -413,13 +384,9 @@ def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, o
ds.pop(0)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if sigmas[i + 1] == 0:
# Denoising step
x = denoised
else:
cur_order = min(i + 1, order)
coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)]
x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
cur_order = min(i + 1, order)
coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)]
x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
return x
@@ -715,7 +682,6 @@ def sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args=None, callback=Non
# logged_x = torch.cat((logged_x, x.unsqueeze(0)), dim=0)
return x
@torch.no_grad()
def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
"""DPM-Solver++ (stochastic)."""
@@ -727,49 +693,38 @@ def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=N
seed = extra_args.get("seed", None)
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling)
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
sigma_fn = lambda t: t.neg().exp()
t_fn = lambda sigma: sigma.log().neg()
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if sigmas[i + 1] == 0:
# Denoising step
x = denoised
# Euler method
d = to_d(x, sigmas[i], denoised)
dt = sigmas[i + 1] - sigmas[i]
x = x + d * dt
else:
# DPM-Solver++
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
h = lambda_t - lambda_s
lambda_s_1 = lambda_s + r * h
t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
h = t_next - t
s = t + h * r
fac = 1 / (2 * r)
sigma_s_1 = sigma_fn(lambda_s_1)
alpha_s = sigmas[i] * lambda_s.exp()
alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
alpha_t = sigmas[i + 1] * lambda_t.exp()
# Step 1
sd, su = get_ancestral_step(lambda_s.neg().exp(), lambda_s_1.neg().exp(), eta)
lambda_s_1_ = sd.log().neg()
h_ = lambda_s_1_ - lambda_s
x_2 = (alpha_s_1 / alpha_s) * (-h_).exp() * x - alpha_s_1 * (-h_).expm1() * denoised
if eta > 0 and s_noise > 0:
x_2 = x_2 + alpha_s_1 * noise_sampler(sigmas[i], sigma_s_1) * s_noise * su
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(s), eta)
s_ = t_fn(sd)
x_2 = (sigma_fn(s_) / sigma_fn(t)) * x - (t - s_).expm1() * denoised
x_2 = x_2 + noise_sampler(sigma_fn(t), sigma_fn(s)) * s_noise * su
denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
# Step 2
sd, su = get_ancestral_step(lambda_s.neg().exp(), lambda_t.neg().exp(), eta)
lambda_t_ = sd.log().neg()
h_ = lambda_t_ - lambda_s
sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(t_next), eta)
t_next_ = t_fn(sd)
denoised_d = (1 - fac) * denoised + fac * denoised_2
x = (alpha_t / alpha_s) * (-h_).exp() * x - alpha_t * (-h_).expm1() * denoised_d
if eta > 0 and s_noise > 0:
x = x + alpha_t * noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * su
x = (sigma_fn(t_next_) / sigma_fn(t)) * x - (t - t_next_).expm1() * denoised_d
x = x + noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * su
return x
@@ -798,7 +753,6 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No
old_denoised = denoised
return x
@torch.no_grad()
def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
"""DPM-Solver++(2M) SDE."""
@@ -814,12 +768,9 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
old_denoised = None
h, h_last = None, None
h_last = None
h = None
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
@@ -830,29 +781,26 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
x = denoised
else:
# DPM-Solver++(2M) SDE
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
h = lambda_t - lambda_s
h_eta = h * (eta + 1)
t, s = -sigmas[i].log(), -sigmas[i + 1].log()
h = s - t
eta_h = eta * h
alpha_t = sigmas[i + 1] * lambda_t.exp()
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x + alpha_t * (-h_eta).expm1().neg() * denoised
x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + (-h - eta_h).expm1().neg() * denoised
if old_denoised is not None:
r = h_last / h
if solver_type == 'heun':
x = x + alpha_t * ((-h_eta).expm1().neg() / (-h_eta) + 1) * (1 / r) * (denoised - old_denoised)
x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * (1 / r) * (denoised - old_denoised)
elif solver_type == 'midpoint':
x = x + 0.5 * alpha_t * (-h_eta).expm1().neg() * (1 / r) * (denoised - old_denoised)
x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised)
if eta > 0 and s_noise > 0:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise
if eta:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise
old_denoised = denoised
h_last = h
return x
@torch.no_grad()
def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""DPM-Solver++(3M) SDE."""
@@ -866,10 +814,6 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
denoised_1, denoised_2 = None, None
h, h_1, h_2 = None, None, None
@@ -881,16 +825,13 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
# Denoising step
x = denoised
else:
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
h = lambda_t - lambda_s
t, s = -sigmas[i].log(), -sigmas[i + 1].log()
h = s - t
h_eta = h * (eta + 1)
alpha_t = sigmas[i + 1] * lambda_t.exp()
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x + alpha_t * (-h_eta).expm1().neg() * denoised
x = torch.exp(-h_eta) * x + (-h_eta).expm1().neg() * denoised
if h_2 is not None:
# DPM-Solver++(3M) SDE
r0 = h_1 / h
r1 = h_2 / h
d1_0 = (denoised - denoised_1) / r0
@@ -899,22 +840,20 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
d2 = (d1_0 - d1_1) / (r0 + r1)
phi_2 = h_eta.neg().expm1() / h_eta + 1
phi_3 = phi_2 / h_eta - 0.5
x = x + (alpha_t * phi_2) * d1 - (alpha_t * phi_3) * d2
x = x + phi_2 * d1 - phi_3 * d2
elif h_1 is not None:
# DPM-Solver++(2M) SDE
r = h_1 / h
d = (denoised - denoised_1) / r
phi_2 = h_eta.neg().expm1() / h_eta + 1
x = x + (alpha_t * phi_2) * d
x = x + phi_2 * d
if eta > 0 and s_noise > 0:
if eta:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise
denoised_1, denoised_2 = denoised, denoised_1
h_1, h_2 = h, h_1
return x
@torch.no_grad()
def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
if len(sigmas) <= 1:
@@ -924,7 +863,6 @@ def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler)
@torch.no_grad()
def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
if len(sigmas) <= 1:
@@ -934,7 +872,6 @@ def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
@torch.no_grad()
def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
if len(sigmas) <= 1:
@@ -1072,9 +1009,7 @@ def sample_ipndm(model, x, sigmas, extra_args=None, callback=None, disable=None,
d_cur = (x_cur - denoised) / t_cur
order = min(max_order, i+1)
if t_next == 0: # Denoising step
x_next = denoised
elif order == 1: # First Euler step.
if order == 1: # First Euler step.
x_next = x_cur + (t_next - t_cur) * d_cur
elif order == 2: # Use one history point.
x_next = x_cur + (t_next - t_cur) * (3 * d_cur - buffer_model[-1]) / 2
@@ -1092,7 +1027,6 @@ def sample_ipndm(model, x, sigmas, extra_args=None, callback=None, disable=None,
return x_next
#From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
#under Apache 2 license
def sample_ipndm_v(model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=4):
@@ -1116,9 +1050,7 @@ def sample_ipndm_v(model, x, sigmas, extra_args=None, callback=None, disable=Non
d_cur = (x_cur - denoised) / t_cur
order = min(max_order, i+1)
if t_next == 0: # Denoising step
x_next = denoised
elif order == 1: # First Euler step.
if order == 1: # First Euler step.
x_next = x_cur + (t_next - t_cur) * d_cur
elif order == 2: # Use one history point.
h_n = (t_next - t_cur)
@@ -1158,7 +1090,6 @@ def sample_ipndm_v(model, x, sigmas, extra_args=None, callback=None, disable=Non
return x_next
#From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
#under Apache 2 license
@torch.no_grad()
@@ -1209,22 +1140,39 @@ def sample_deis(model, x, sigmas, extra_args=None, callback=None, disable=None,
return x_next
@torch.no_grad()
def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
extra_args = {} if extra_args is None else extra_args
temp = [0]
def post_cfg_function(args):
temp[0] = args["uncond_denoised"]
return args["denoised"]
model_options = extra_args.get("model_options", {}).copy()
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
sigma_hat = sigmas[i]
denoised = model(x, sigma_hat * s_in, **extra_args)
d = to_d(x, sigma_hat, temp[0])
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
# Euler method
x = denoised + d * sigmas[i + 1]
return x
@torch.no_grad()
def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""Ancestral sampling with Euler method steps (CFG++)."""
"""Ancestral sampling with Euler method steps."""
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling")
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
uncond_denoised = None
temp = [0]
def post_cfg_function(args):
nonlocal uncond_denoised
uncond_denoised = args["uncond_denoised"]
temp[0] = args["uncond_denoised"]
return args["denoised"]
model_options = extra_args.get("model_options", {}).copy()
@@ -1233,33 +1181,15 @@ def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=No
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if sigmas[i + 1] == 0:
# Denoising step
x = denoised
else:
alpha_s = sigmas[i] * lambda_fn(sigmas[i]).exp()
alpha_t = sigmas[i + 1] * lambda_fn(sigmas[i + 1]).exp()
d = to_d(x, sigmas[i], alpha_s * uncond_denoised) # to noise
# DDIM stochastic sampling
sigma_down, sigma_up = get_ancestral_step(sigmas[i] / alpha_s, sigmas[i + 1] / alpha_t, eta=eta)
sigma_down = alpha_t * sigma_down
# Euler method
x = alpha_t * denoised + sigma_down * d
if eta > 0 and s_noise > 0:
x = x + alpha_t * noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
d = to_d(x, sigmas[i], temp[0])
# Euler method
x = denoised + d * sigma_down
if sigmas[i + 1] > 0:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
return x
@torch.no_grad()
def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
"""Euler method steps (CFG++)."""
return sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=0.0, s_noise=0.0, noise_sampler=None)
@torch.no_grad()
def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
@@ -1416,7 +1346,6 @@ def sample_res_multistep_ancestral(model, x, sigmas, extra_args=None, callback=N
def sample_res_multistep_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=eta, cfg_pp=True)
@torch.no_grad()
def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2., cfg_pp=False):
"""Gradient-estimation sampler. Paper: https://openreview.net/pdf?id=o2ND9v0CeK"""
@@ -1443,32 +1372,31 @@ def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None,
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
dt = sigmas[i + 1] - sigmas[i]
if sigmas[i + 1] == 0:
# Denoising step
x = denoised
else:
if i == 0:
# Euler method
if cfg_pp:
x = denoised + d * sigmas[i + 1]
else:
x = x + d * dt
if i >= 1:
# Gradient estimation
else:
# Gradient estimation
if cfg_pp:
d_bar = (ge_gamma - 1) * (d - old_d)
x = denoised + d * sigmas[i + 1] + d_bar * dt
else:
d_bar = ge_gamma * d + (1 - ge_gamma) * old_d
x = x + d_bar * dt
old_d = d
return x
@torch.no_grad()
def sample_gradient_estimation_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2.):
return sample_gradient_estimation(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, ge_gamma=ge_gamma, cfg_pp=True)
@torch.no_grad()
def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1.0, noise_sampler=None, noise_scaler=None, max_stage=3):
"""Extended Reverse-Time SDE solver (VP ER-SDE-Solver-3). arXiv: https://arxiv.org/abs/2309.06169.
def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None, noise_scaler=None, max_stage=3):
"""
Extended Reverse-Time SDE solver (VE ER-SDE-Solver-3). Arxiv: https://arxiv.org/abs/2309.06169.
Code reference: https://github.com/QinpengCui/ER-SDE-Solver/blob/main/er_sde_solver.py.
"""
extra_args = {} if extra_args is None else extra_args
@@ -1476,18 +1404,12 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
def default_er_sde_noise_scaler(x):
return x * ((x ** 0.3).exp() + 10.0)
noise_scaler = default_er_sde_noise_scaler if noise_scaler is None else noise_scaler
def default_noise_scaler(sigma):
return sigma * ((sigma ** 0.3).exp() + 10.0)
noise_scaler = default_noise_scaler if noise_scaler is None else noise_scaler
num_integration_points = 200.0
point_indice = torch.arange(0, num_integration_points, dtype=torch.float32, device=x.device)
model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling")
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
half_log_snrs = sigma_to_half_log_snr(sigmas, model_sampling)
er_lambdas = half_log_snrs.neg().exp() # er_lambda_t = sigma_t / alpha_t
old_denoised = None
old_denoised_d = None
@@ -1498,45 +1420,41 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
stage_used = min(max_stage, i + 1)
if sigmas[i + 1] == 0:
x = denoised
elif stage_used == 1:
r = noise_scaler(sigmas[i + 1]) / noise_scaler(sigmas[i])
x = r * x + (1 - r) * denoised
else:
er_lambda_s, er_lambda_t = er_lambdas[i], er_lambdas[i + 1]
alpha_s = sigmas[i] / er_lambda_s
alpha_t = sigmas[i + 1] / er_lambda_t
r_alpha = alpha_t / alpha_s
r = noise_scaler(er_lambda_t) / noise_scaler(er_lambda_s)
r = noise_scaler(sigmas[i + 1]) / noise_scaler(sigmas[i])
x = r * x + (1 - r) * denoised
# Stage 1 Euler
x = r_alpha * r * x + alpha_t * (1 - r) * denoised
dt = sigmas[i + 1] - sigmas[i]
sigma_step_size = -dt / num_integration_points
sigma_pos = sigmas[i + 1] + point_indice * sigma_step_size
scaled_pos = noise_scaler(sigma_pos)
if stage_used >= 2:
dt = er_lambda_t - er_lambda_s
lambda_step_size = -dt / num_integration_points
lambda_pos = er_lambda_t + point_indice * lambda_step_size
scaled_pos = noise_scaler(lambda_pos)
# Stage 2
s = torch.sum(1 / scaled_pos) * sigma_step_size
denoised_d = (denoised - old_denoised) / (sigmas[i] - sigmas[i - 1])
x = x + (dt + s * noise_scaler(sigmas[i + 1])) * denoised_d
# Stage 2
s = torch.sum(1 / scaled_pos) * lambda_step_size
denoised_d = (denoised - old_denoised) / (er_lambda_s - er_lambdas[i - 1])
x = x + alpha_t * (dt + s * noise_scaler(er_lambda_t)) * denoised_d
if stage_used >= 3:
# Stage 3
s_u = torch.sum((sigma_pos - sigmas[i]) / scaled_pos) * sigma_step_size
denoised_u = (denoised_d - old_denoised_d) / ((sigmas[i] - sigmas[i - 2]) / 2)
x = x + ((dt ** 2) / 2 + s_u * noise_scaler(sigmas[i + 1])) * denoised_u
old_denoised_d = denoised_d
if stage_used >= 3:
# Stage 3
s_u = torch.sum((lambda_pos - er_lambda_s) / scaled_pos) * lambda_step_size
denoised_u = (denoised_d - old_denoised_d) / ((er_lambda_s - er_lambdas[i - 2]) / 2)
x = x + alpha_t * ((dt ** 2) / 2 + s_u * noise_scaler(er_lambda_t)) * denoised_u
old_denoised_d = denoised_d
if s_noise > 0:
x = x + alpha_t * noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * (er_lambda_t ** 2 - er_lambda_s ** 2 * r ** 2).sqrt().nan_to_num(nan=0.0)
if s_noise != 0 and sigmas[i + 1] > 0:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * (sigmas[i + 1] ** 2 - sigmas[i] ** 2 * r ** 2).sqrt().nan_to_num(nan=0.0)
old_denoised = denoised
return x
@torch.no_grad()
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5):
"""SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2.
arXiv: https://arxiv.org/abs/2305.14267
"""
'''
SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VE Data Prediction) stage 2
Arxiv: https://arxiv.org/abs/2305.14267
'''
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
@@ -1544,11 +1462,6 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
inject_noise = eta > 0 and s_noise > 0
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling)
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
@@ -1556,43 +1469,38 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
if sigmas[i + 1] == 0:
x = denoised
else:
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
h = lambda_t - lambda_s
t, t_next = -sigmas[i].log(), -sigmas[i + 1].log()
h = t_next - t
h_eta = h * (eta + 1)
lambda_s_1 = lambda_s + r * h
s = t + r * h
fac = 1 / (2 * r)
sigma_s_1 = sigma_fn(lambda_s_1)
# alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t)
alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
alpha_t = sigmas[i + 1] * lambda_t.exp()
sigma_s = s.neg().exp()
coeff_1, coeff_2 = (-r * h_eta).expm1(), (-h_eta).expm1()
if inject_noise:
# 0 < r < 1
noise_coeff_1 = (-2 * r * h * eta).expm1().neg().sqrt()
noise_coeff_2 = (-r * h * eta).exp() * (-2 * (1 - r) * h * eta).expm1().neg().sqrt()
noise_1, noise_2 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigmas[i + 1])
noise_coeff_2 = ((-2 * r * h * eta).expm1() - (-2 * h * eta).expm1()).sqrt()
noise_1, noise_2 = noise_sampler(sigmas[i], sigma_s), noise_sampler(sigma_s, sigmas[i + 1])
# Step 1
x_2 = sigma_s_1 / sigmas[i] * (-r * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised
x_2 = (coeff_1 + 1) * x - coeff_1 * denoised
if inject_noise:
x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
x_2 = x_2 + sigma_s * (noise_coeff_1 * noise_1) * s_noise
denoised_2 = model(x_2, sigma_s * s_in, **extra_args)
# Step 2
denoised_d = (1 - fac) * denoised + fac * denoised_2
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_2 * denoised_d
x = (coeff_2 + 1) * x - coeff_2 * denoised_d
if inject_noise:
x = x + sigmas[i + 1] * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
return x
@torch.no_grad()
def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3):
"""SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 3.
arXiv: https://arxiv.org/abs/2305.14267
"""
'''
SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VE Data Prediction) stage 3
Arxiv: https://arxiv.org/abs/2305.14267
'''
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
@@ -1600,11 +1508,6 @@ def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=Non
inject_noise = eta > 0 and s_noise > 0
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling)
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
@@ -1612,150 +1515,34 @@ def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=Non
if sigmas[i + 1] == 0:
x = denoised
else:
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
h = lambda_t - lambda_s
t, t_next = -sigmas[i].log(), -sigmas[i + 1].log()
h = t_next - t
h_eta = h * (eta + 1)
lambda_s_1 = lambda_s + r_1 * h
lambda_s_2 = lambda_s + r_2 * h
sigma_s_1, sigma_s_2 = sigma_fn(lambda_s_1), sigma_fn(lambda_s_2)
# alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t)
alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
alpha_s_2 = sigma_s_2 * lambda_s_2.exp()
alpha_t = sigmas[i + 1] * lambda_t.exp()
s_1 = t + r_1 * h
s_2 = t + r_2 * h
sigma_s_1, sigma_s_2 = s_1.neg().exp(), s_2.neg().exp()
coeff_1, coeff_2, coeff_3 = (-r_1 * h_eta).expm1(), (-r_2 * h_eta).expm1(), (-h_eta).expm1()
if inject_noise:
# 0 < r_1 < r_2 < 1
noise_coeff_1 = (-2 * r_1 * h * eta).expm1().neg().sqrt()
noise_coeff_2 = (-r_1 * h * eta).exp() * (-2 * (r_2 - r_1) * h * eta).expm1().neg().sqrt()
noise_coeff_3 = (-r_2 * h * eta).exp() * (-2 * (1 - r_2) * h * eta).expm1().neg().sqrt()
noise_coeff_2 = ((-2 * r_1 * h * eta).expm1() - (-2 * r_2 * h * eta).expm1()).sqrt()
noise_coeff_3 = ((-2 * r_2 * h * eta).expm1() - (-2 * h * eta).expm1()).sqrt()
noise_1, noise_2, noise_3 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigma_s_2), noise_sampler(sigma_s_2, sigmas[i + 1])
# Step 1
x_2 = sigma_s_1 / sigmas[i] * (-r_1 * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised
x_2 = (coeff_1 + 1) * x - coeff_1 * denoised
if inject_noise:
x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
# Step 2
x_3 = sigma_s_2 / sigmas[i] * (-r_2 * h * eta).exp() * x - alpha_s_2 * coeff_2 * denoised + (r_2 / r_1) * alpha_s_2 * (coeff_2 / (r_2 * h_eta) + 1) * (denoised_2 - denoised)
x_3 = (coeff_2 + 1) * x - coeff_2 * denoised + (r_2 / r_1) * (coeff_2 / (r_2 * h_eta) + 1) * (denoised_2 - denoised)
if inject_noise:
x_3 = x_3 + sigma_s_2 * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args)
# Step 3
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_3 * denoised + (1. / r_2) * alpha_t * (coeff_3 / h_eta + 1) * (denoised_3 - denoised)
x = (coeff_3 + 1) * x - coeff_3 * denoised + (1. / r_2) * (coeff_3 / h_eta + 1) * (denoised_3 - denoised)
if inject_noise:
x = x + sigmas[i + 1] * (noise_coeff_3 * noise_1 + noise_coeff_2 * noise_2 + noise_coeff_1 * noise_3) * s_noise
return x
@torch.no_grad()
def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=False, tau_func=None, s_noise=1.0, noise_sampler=None, predictor_order=3, corrector_order=4, use_pece=False, simple_order_2=False):
"""Stochastic Adams Solver with predictor-corrector method (NeurIPS 2023)."""
if len(sigmas) <= 1:
return x
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling")
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
lambdas = sigma_to_half_log_snr(sigmas, model_sampling=model_sampling)
if tau_func is None:
# Use default interval for stochastic sampling
start_sigma = model_sampling.percent_to_sigma(0.2)
end_sigma = model_sampling.percent_to_sigma(0.8)
tau_func = sa_solver.get_tau_interval_func(start_sigma, end_sigma, eta=1.0)
max_used_order = max(predictor_order, corrector_order)
x_pred = x # x: current state, x_pred: predicted next state
h = 0.0
tau_t = 0.0
noise = 0.0
pred_list = []
# Lower order near the end to improve stability
lower_order_to_end = sigmas[-1].item() == 0
for i in trange(len(sigmas) - 1, disable=disable):
# Evaluation
denoised = model(x_pred, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({"x": x_pred, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "denoised": denoised})
pred_list.append(denoised)
pred_list = pred_list[-max_used_order:]
predictor_order_used = min(predictor_order, len(pred_list))
if i == 0 or (sigmas[i + 1] == 0 and not use_pece):
corrector_order_used = 0
else:
corrector_order_used = min(corrector_order, len(pred_list))
if lower_order_to_end:
predictor_order_used = min(predictor_order_used, len(sigmas) - 2 - i)
corrector_order_used = min(corrector_order_used, len(sigmas) - 1 - i)
# Corrector
if corrector_order_used == 0:
# Update by the predicted state
x = x_pred
else:
curr_lambdas = lambdas[i - corrector_order_used + 1:i + 1]
b_coeffs = sa_solver.compute_stochastic_adams_b_coeffs(
sigmas[i],
curr_lambdas,
lambdas[i - 1],
lambdas[i],
tau_t,
simple_order_2,
is_corrector_step=True,
)
pred_mat = torch.stack(pred_list[-corrector_order_used:], dim=1) # (B, K, ...)
corr_res = torch.tensordot(pred_mat, b_coeffs, dims=([1], [0])) # (B, ...)
x = sigmas[i] / sigmas[i - 1] * (-(tau_t ** 2) * h).exp() * x + corr_res
if tau_t > 0 and s_noise > 0:
# The noise from the previous predictor step
x = x + noise
if use_pece:
# Evaluate the corrected state
denoised = model(x, sigmas[i] * s_in, **extra_args)
pred_list[-1] = denoised
# Predictor
if sigmas[i + 1] == 0:
# Denoising step
x = denoised
else:
tau_t = tau_func(sigmas[i + 1])
curr_lambdas = lambdas[i - predictor_order_used + 1:i + 1]
b_coeffs = sa_solver.compute_stochastic_adams_b_coeffs(
sigmas[i + 1],
curr_lambdas,
lambdas[i],
lambdas[i + 1],
tau_t,
simple_order_2,
is_corrector_step=False,
)
pred_mat = torch.stack(pred_list[-predictor_order_used:], dim=1) # (B, K, ...)
pred_res = torch.tensordot(pred_mat, b_coeffs, dims=([1], [0])) # (B, ...)
h = lambdas[i + 1] - lambdas[i]
x_pred = sigmas[i + 1] / sigmas[i] * (-(tau_t ** 2) * h).exp() * x + pred_res
if tau_t > 0 and s_noise > 0:
noise = noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * tau_t ** 2 * h).expm1().neg().sqrt() * s_noise
x_pred = x_pred + noise
return x
@torch.no_grad()
def sample_sa_solver_pece(model, x, sigmas, extra_args=None, callback=None, disable=False, tau_func=None, s_noise=1.0, noise_sampler=None, predictor_order=3, corrector_order=4, simple_order_2=False):
"""Stochastic Adams Solver with PECE (PredictEvaluateCorrectEvaluate) mode (NeurIPS 2023)."""
return sample_sa_solver(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, tau_func=tau_func, s_noise=s_noise, noise_sampler=noise_sampler, predictor_order=predictor_order, corrector_order=corrector_order, use_pece=True, simple_order_2=simple_order_2)

View File

@@ -457,82 +457,6 @@ class Wan21(LatentFormat):
latents_std = self.latents_std.to(latent.device, latent.dtype)
return latent * latents_std / self.scale_factor + latents_mean
class Wan22(Wan21):
latent_channels = 48
latent_dimensions = 3
latent_rgb_factors = [
[ 0.0119, 0.0103, 0.0046],
[-0.1062, -0.0504, 0.0165],
[ 0.0140, 0.0409, 0.0491],
[-0.0813, -0.0677, 0.0607],
[ 0.0656, 0.0851, 0.0808],
[ 0.0264, 0.0463, 0.0912],
[ 0.0295, 0.0326, 0.0590],
[-0.0244, -0.0270, 0.0025],
[ 0.0443, -0.0102, 0.0288],
[-0.0465, -0.0090, -0.0205],
[ 0.0359, 0.0236, 0.0082],
[-0.0776, 0.0854, 0.1048],
[ 0.0564, 0.0264, 0.0561],
[ 0.0006, 0.0594, 0.0418],
[-0.0319, -0.0542, -0.0637],
[-0.0268, 0.0024, 0.0260],
[ 0.0539, 0.0265, 0.0358],
[-0.0359, -0.0312, -0.0287],
[-0.0285, -0.1032, -0.1237],
[ 0.1041, 0.0537, 0.0622],
[-0.0086, -0.0374, -0.0051],
[ 0.0390, 0.0670, 0.2863],
[ 0.0069, 0.0144, 0.0082],
[ 0.0006, -0.0167, 0.0079],
[ 0.0313, -0.0574, -0.0232],
[-0.1454, -0.0902, -0.0481],
[ 0.0714, 0.0827, 0.0447],
[-0.0304, -0.0574, -0.0196],
[ 0.0401, 0.0384, 0.0204],
[-0.0758, -0.0297, -0.0014],
[ 0.0568, 0.1307, 0.1372],
[-0.0055, -0.0310, -0.0380],
[ 0.0239, -0.0305, 0.0325],
[-0.0663, -0.0673, -0.0140],
[-0.0416, -0.0047, -0.0023],
[ 0.0166, 0.0112, -0.0093],
[-0.0211, 0.0011, 0.0331],
[ 0.1833, 0.1466, 0.2250],
[-0.0368, 0.0370, 0.0295],
[-0.3441, -0.3543, -0.2008],
[-0.0479, -0.0489, -0.0420],
[-0.0660, -0.0153, 0.0800],
[-0.0101, 0.0068, 0.0156],
[-0.0690, -0.0452, -0.0927],
[-0.0145, 0.0041, 0.0015],
[ 0.0421, 0.0451, 0.0373],
[ 0.0504, -0.0483, -0.0356],
[-0.0837, 0.0168, 0.0055]
]
latent_rgb_factors_bias = [0.0317, -0.0878, -0.1388]
def __init__(self):
self.scale_factor = 1.0
self.latents_mean = torch.tensor([
-0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557,
-0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825,
-0.2246, -0.1207, -0.0698, 0.5109, 0.2665, -0.2108, -0.2158, 0.2502,
-0.2055, -0.0322, 0.1109, 0.1567, -0.0729, 0.0899, -0.2799, -0.1230,
-0.0313, -0.1649, 0.0117, 0.0723, -0.2839, -0.2083, -0.0520, 0.3748,
0.0152, 0.1957, 0.1433, -0.2944, 0.3573, -0.0548, -0.1681, -0.0667,
]).view(1, self.latent_channels, 1, 1, 1)
self.latents_std = torch.tensor([
0.4765, 1.0364, 0.4514, 1.1677, 0.5313, 0.4990, 0.4818, 0.5013,
0.8158, 1.0344, 0.5894, 1.0901, 0.6885, 0.6165, 0.8454, 0.4978,
0.5759, 0.3523, 0.7135, 0.6804, 0.5833, 1.4146, 0.8986, 0.5659,
0.7069, 0.5338, 0.4889, 0.4917, 0.4069, 0.4999, 0.6866, 0.4093,
0.5709, 0.6065, 0.6415, 0.4944, 0.5726, 1.2042, 0.5458, 1.6887,
0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744
]).view(1, self.latent_channels, 1, 1, 1)
class Hunyuan3Dv2(LatentFormat):
latent_channels = 64
latent_dimensions = 1

View File

@@ -80,13 +80,15 @@ class DoubleStreamBlock(nn.Module):
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
# prepare image for attention
img_modulated = torch.addcmul(img_mod1.shift, 1 + img_mod1.scale, self.img_norm1(img))
img_modulated = self.img_norm1(img)
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
img_qkv = self.img_attn.qkv(img_modulated)
img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
# prepare txt for attention
txt_modulated = torch.addcmul(txt_mod1.shift, 1 + txt_mod1.scale, self.txt_norm1(txt))
txt_modulated = self.txt_norm1(txt)
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
txt_qkv = self.txt_attn.qkv(txt_modulated)
txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
@@ -100,12 +102,12 @@ class DoubleStreamBlock(nn.Module):
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
# calculate the img bloks
img.addcmul_(img_mod1.gate, self.img_attn.proj(img_attn))
img.addcmul_(img_mod2.gate, self.img_mlp(torch.addcmul(img_mod2.shift, 1 + img_mod2.scale, self.img_norm2(img))))
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
# calculate the txt bloks
txt.addcmul_(txt_mod1.gate, self.txt_attn.proj(txt_attn))
txt.addcmul_(txt_mod2.gate, self.txt_mlp(torch.addcmul(txt_mod2.shift, 1 + txt_mod2.scale, self.txt_norm2(txt))))
txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
if txt.dtype == torch.float16:
txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
@@ -150,7 +152,7 @@ class SingleStreamBlock(nn.Module):
def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None) -> Tensor:
mod = vec
x_mod = torch.addcmul(mod.shift, 1 + mod.scale, self.pre_norm(x))
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
@@ -160,7 +162,7 @@ class SingleStreamBlock(nn.Module):
attn = attention(q, k, v, pe=pe, mask=attn_mask)
# compute activation in mlp stream, cat again and run second linear layer
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
x.addcmul_(mod.gate, output)
x += mod.gate * output
if x.dtype == torch.float16:
x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
return x
@@ -176,6 +178,6 @@ class LastLayer(nn.Module):
shift, scale = vec
shift = shift.squeeze(1)
scale = scale.squeeze(1)
x = torch.addcmul(shift[:, None, :], 1 + scale[:, None, :], self.norm_final(x))
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
x = self.linear(x)
return x

View File

@@ -163,7 +163,7 @@ class Chroma(nn.Module):
distil_guidance = timestep_embedding(guidance.detach().clone(), 16).to(img.device, img.dtype)
# get all modulation index
modulation_index = timestep_embedding(torch.arange(mod_index_length, device=img.device), 32).to(img.device, img.dtype)
modulation_index = timestep_embedding(torch.arange(mod_index_length), 32).to(img.device, img.dtype)
# we need to broadcast the modulation index here so each batch has all of the index
modulation_index = modulation_index.unsqueeze(0).repeat(img.shape[0], 1, 1).to(img.device, img.dtype)
# and we need to broadcast timestep and guidance along too
@@ -254,12 +254,13 @@ class Chroma(nn.Module):
def forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
bs, c, h, w = x.shape
x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
patch_size = 2
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=self.patch_size, pw=self.patch_size)
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
h_len = ((h + (self.patch_size // 2)) // self.patch_size)
w_len = ((w + (self.patch_size // 2)) // self.patch_size)
h_len = ((h + (patch_size // 2)) // patch_size)
w_len = ((w + (patch_size // 2)) // patch_size)
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
@@ -267,4 +268,4 @@ class Chroma(nn.Module):
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=self.patch_size, pw=self.patch_size)[:,:,:h,:w]
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]

View File

@@ -26,6 +26,16 @@ from torch import nn
from comfy.ldm.modules.attention import optimized_attention
def apply_rotary_pos_emb(
t: torch.Tensor,
freqs: torch.Tensor,
) -> torch.Tensor:
t_ = t.reshape(*t.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2).float()
t_out = freqs[..., 0] * t_[..., 0] + freqs[..., 1] * t_[..., 1]
t_out = t_out.movedim(-1, -2).reshape(*t.shape).type_as(t)
return t_out
def get_normalization(name: str, channels: int, weight_args={}, operations=None):
if name == "I":
return nn.Identity()

View File

@@ -58,8 +58,7 @@ def is_odd(n: int) -> bool:
def nonlinearity(x):
# x * sigmoid(x)
return torch.nn.functional.silu(x)
return x * torch.sigmoid(x)
def Normalize(in_channels, num_groups=32):

View File

@@ -66,16 +66,15 @@ class VideoRopePosition3DEmb(VideoPositionEmb):
h_extrapolation_ratio: float = 1.0,
w_extrapolation_ratio: float = 1.0,
t_extrapolation_ratio: float = 1.0,
enable_fps_modulation: bool = True,
device=None,
**kwargs, # used for compatibility with other positional embeddings; unused in this class
):
del kwargs
super().__init__()
self.register_buffer("seq", torch.arange(max(len_h, len_w, len_t), dtype=torch.float, device=device))
self.base_fps = base_fps
self.max_h = len_h
self.max_w = len_w
self.enable_fps_modulation = enable_fps_modulation
dim = head_dim
dim_h = dim // 6 * 2
@@ -133,19 +132,21 @@ class VideoRopePosition3DEmb(VideoPositionEmb):
temporal_freqs = 1.0 / (t_theta**self.dim_temporal_range.to(device=device))
B, T, H, W, _ = B_T_H_W_C
seq = torch.arange(max(H, W, T), dtype=torch.float, device=device)
uniform_fps = (fps is None) or isinstance(fps, (int, float)) or (fps.min() == fps.max())
assert (
uniform_fps or B == 1 or T == 1
), "For video batch, batch size should be 1 for non-uniform fps. For image batch, T should be 1"
half_emb_h = torch.outer(seq[:H].to(device=device), h_spatial_freqs)
half_emb_w = torch.outer(seq[:W].to(device=device), w_spatial_freqs)
assert (
H <= self.max_h and W <= self.max_w
), f"Input dimensions (H={H}, W={W}) exceed the maximum dimensions (max_h={self.max_h}, max_w={self.max_w})"
half_emb_h = torch.outer(self.seq[:H].to(device=device), h_spatial_freqs)
half_emb_w = torch.outer(self.seq[:W].to(device=device), w_spatial_freqs)
# apply sequence scaling in temporal dimension
if fps is None or self.enable_fps_modulation is False: # image case
half_emb_t = torch.outer(seq[:T].to(device=device), temporal_freqs)
if fps is None: # image case
half_emb_t = torch.outer(self.seq[:T].to(device=device), temporal_freqs)
else:
half_emb_t = torch.outer(seq[:T].to(device=device) / fps * self.base_fps, temporal_freqs)
half_emb_t = torch.outer(self.seq[:T].to(device=device) / fps * self.base_fps, temporal_freqs)
half_emb_h = torch.stack([torch.cos(half_emb_h), -torch.sin(half_emb_h), torch.sin(half_emb_h), torch.cos(half_emb_h)], dim=-1)
half_emb_w = torch.stack([torch.cos(half_emb_w), -torch.sin(half_emb_w), torch.sin(half_emb_w), torch.cos(half_emb_w)], dim=-1)

View File

@@ -1,864 +0,0 @@
# original code from: https://github.com/nvidia-cosmos/cosmos-predict2
import torch
from torch import nn
from einops import rearrange
from einops.layers.torch import Rearrange
import logging
from typing import Callable, Optional, Tuple
import math
from .position_embedding import VideoRopePosition3DEmb, LearnablePosEmbAxis
from torchvision import transforms
from comfy.ldm.modules.attention import optimized_attention
def apply_rotary_pos_emb(
t: torch.Tensor,
freqs: torch.Tensor,
) -> torch.Tensor:
t_ = t.reshape(*t.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2).float()
t_out = freqs[..., 0] * t_[..., 0] + freqs[..., 1] * t_[..., 1]
t_out = t_out.movedim(-1, -2).reshape(*t.shape).type_as(t)
return t_out
# ---------------------- Feed Forward Network -----------------------
class GPT2FeedForward(nn.Module):
def __init__(self, d_model: int, d_ff: int, device=None, dtype=None, operations=None) -> None:
super().__init__()
self.activation = nn.GELU()
self.layer1 = operations.Linear(d_model, d_ff, bias=False, device=device, dtype=dtype)
self.layer2 = operations.Linear(d_ff, d_model, bias=False, device=device, dtype=dtype)
self._layer_id = None
self._dim = d_model
self._hidden_dim = d_ff
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.layer1(x)
x = self.activation(x)
x = self.layer2(x)
return x
def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor) -> torch.Tensor:
"""Computes multi-head attention using PyTorch's native implementation.
This function provides a PyTorch backend alternative to Transformer Engine's attention operation.
It rearranges the input tensors to match PyTorch's expected format, computes scaled dot-product
attention, and rearranges the output back to the original format.
The input tensor names use the following dimension conventions:
- B: batch size
- S: sequence length
- H: number of attention heads
- D: head dimension
Args:
q_B_S_H_D: Query tensor with shape (batch, seq_len, n_heads, head_dim)
k_B_S_H_D: Key tensor with shape (batch, seq_len, n_heads, head_dim)
v_B_S_H_D: Value tensor with shape (batch, seq_len, n_heads, head_dim)
Returns:
Attention output tensor with shape (batch, seq_len, n_heads * head_dim)
"""
in_q_shape = q_B_S_H_D.shape
in_k_shape = k_B_S_H_D.shape
q_B_H_S_D = rearrange(q_B_S_H_D, "b ... h k -> b h ... k").view(in_q_shape[0], in_q_shape[-2], -1, in_q_shape[-1])
k_B_H_S_D = rearrange(k_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
v_B_H_S_D = rearrange(v_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
return optimized_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, in_q_shape[-2], skip_reshape=True)
class Attention(nn.Module):
"""
A flexible attention module supporting both self-attention and cross-attention mechanisms.
This module implements a multi-head attention layer that can operate in either self-attention
or cross-attention mode. The mode is determined by whether a context dimension is provided.
The implementation uses scaled dot-product attention and supports optional bias terms and
dropout regularization.
Args:
query_dim (int): The dimensionality of the query vectors.
context_dim (int, optional): The dimensionality of the context (key/value) vectors.
If None, the module operates in self-attention mode using query_dim. Default: None
n_heads (int, optional): Number of attention heads for multi-head attention. Default: 8
head_dim (int, optional): The dimension of each attention head. Default: 64
dropout (float, optional): Dropout probability applied to the output. Default: 0.0
qkv_format (str, optional): Format specification for QKV tensors. Default: "bshd"
backend (str, optional): Backend to use for the attention operation. Default: "transformer_engine"
Examples:
>>> # Self-attention with 512 dimensions and 8 heads
>>> self_attn = Attention(query_dim=512)
>>> x = torch.randn(32, 16, 512) # (batch_size, seq_len, dim)
>>> out = self_attn(x) # (32, 16, 512)
>>> # Cross-attention
>>> cross_attn = Attention(query_dim=512, context_dim=256)
>>> query = torch.randn(32, 16, 512)
>>> context = torch.randn(32, 8, 256)
>>> out = cross_attn(query, context) # (32, 16, 512)
"""
def __init__(
self,
query_dim: int,
context_dim: Optional[int] = None,
n_heads: int = 8,
head_dim: int = 64,
dropout: float = 0.0,
device=None,
dtype=None,
operations=None,
) -> None:
super().__init__()
logging.debug(
f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
f"{n_heads} heads with a dimension of {head_dim}."
)
self.is_selfattn = context_dim is None # self attention
context_dim = query_dim if context_dim is None else context_dim
inner_dim = head_dim * n_heads
self.n_heads = n_heads
self.head_dim = head_dim
self.query_dim = query_dim
self.context_dim = context_dim
self.q_proj = operations.Linear(query_dim, inner_dim, bias=False, device=device, dtype=dtype)
self.q_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype)
self.k_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype)
self.k_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype)
self.v_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype)
self.v_norm = nn.Identity()
self.output_proj = operations.Linear(inner_dim, query_dim, bias=False, device=device, dtype=dtype)
self.output_dropout = nn.Dropout(dropout) if dropout > 1e-4 else nn.Identity()
self.attn_op = torch_attention_op
self._query_dim = query_dim
self._context_dim = context_dim
self._inner_dim = inner_dim
def compute_qkv(
self,
x: torch.Tensor,
context: Optional[torch.Tensor] = None,
rope_emb: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
q = self.q_proj(x)
context = x if context is None else context
k = self.k_proj(context)
v = self.v_proj(context)
q, k, v = map(
lambda t: rearrange(t, "b ... (h d) -> b ... h d", h=self.n_heads, d=self.head_dim),
(q, k, v),
)
def apply_norm_and_rotary_pos_emb(
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, rope_emb: Optional[torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
q = self.q_norm(q)
k = self.k_norm(k)
v = self.v_norm(v)
if self.is_selfattn and rope_emb is not None: # only apply to self-attention!
q = apply_rotary_pos_emb(q, rope_emb)
k = apply_rotary_pos_emb(k, rope_emb)
return q, k, v
q, k, v = apply_norm_and_rotary_pos_emb(q, k, v, rope_emb)
return q, k, v
def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
result = self.attn_op(q, k, v) # [B, S, H, D]
return self.output_dropout(self.output_proj(result))
def forward(
self,
x: torch.Tensor,
context: Optional[torch.Tensor] = None,
rope_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Args:
x (Tensor): The query tensor of shape [B, Mq, K]
context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
"""
q, k, v = self.compute_qkv(x, context, rope_emb=rope_emb)
return self.compute_attention(q, k, v)
class Timesteps(nn.Module):
def __init__(self, num_channels: int):
super().__init__()
self.num_channels = num_channels
def forward(self, timesteps_B_T: torch.Tensor) -> torch.Tensor:
assert timesteps_B_T.ndim == 2, f"Expected 2D input, got {timesteps_B_T.ndim}"
timesteps = timesteps_B_T.flatten().float()
half_dim = self.num_channels // 2
exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
exponent = exponent / (half_dim - 0.0)
emb = torch.exp(exponent)
emb = timesteps[:, None].float() * emb[None, :]
sin_emb = torch.sin(emb)
cos_emb = torch.cos(emb)
emb = torch.cat([cos_emb, sin_emb], dim=-1)
return rearrange(emb, "(b t) d -> b t d", b=timesteps_B_T.shape[0], t=timesteps_B_T.shape[1])
class TimestepEmbedding(nn.Module):
def __init__(self, in_features: int, out_features: int, use_adaln_lora: bool = False, device=None, dtype=None, operations=None):
super().__init__()
logging.debug(
f"Using AdaLN LoRA Flag: {use_adaln_lora}. We enable bias if no AdaLN LoRA for backward compatibility."
)
self.in_dim = in_features
self.out_dim = out_features
self.linear_1 = operations.Linear(in_features, out_features, bias=not use_adaln_lora, device=device, dtype=dtype)
self.activation = nn.SiLU()
self.use_adaln_lora = use_adaln_lora
if use_adaln_lora:
self.linear_2 = operations.Linear(out_features, 3 * out_features, bias=False, device=device, dtype=dtype)
else:
self.linear_2 = operations.Linear(out_features, out_features, bias=False, device=device, dtype=dtype)
def forward(self, sample: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
emb = self.linear_1(sample)
emb = self.activation(emb)
emb = self.linear_2(emb)
if self.use_adaln_lora:
adaln_lora_B_T_3D = emb
emb_B_T_D = sample
else:
adaln_lora_B_T_3D = None
emb_B_T_D = emb
return emb_B_T_D, adaln_lora_B_T_3D
class PatchEmbed(nn.Module):
"""
PatchEmbed is a module for embedding patches from an input tensor by applying either 3D or 2D convolutional layers,
depending on the . This module can process inputs with temporal (video) and spatial (image) dimensions,
making it suitable for video and image processing tasks. It supports dividing the input into patches
and embedding each patch into a vector of size `out_channels`.
Parameters:
- spatial_patch_size (int): The size of each spatial patch.
- temporal_patch_size (int): The size of each temporal patch.
- in_channels (int): Number of input channels. Default: 3.
- out_channels (int): The dimension of the embedding vector for each patch. Default: 768.
- bias (bool): If True, adds a learnable bias to the output of the convolutional layers. Default: True.
"""
def __init__(
self,
spatial_patch_size: int,
temporal_patch_size: int,
in_channels: int = 3,
out_channels: int = 768,
device=None, dtype=None, operations=None
):
super().__init__()
self.spatial_patch_size = spatial_patch_size
self.temporal_patch_size = temporal_patch_size
self.proj = nn.Sequential(
Rearrange(
"b c (t r) (h m) (w n) -> b t h w (c r m n)",
r=temporal_patch_size,
m=spatial_patch_size,
n=spatial_patch_size,
),
operations.Linear(
in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size, out_channels, bias=False, device=device, dtype=dtype
),
)
self.dim = in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass of the PatchEmbed module.
Parameters:
- x (torch.Tensor): The input tensor of shape (B, C, T, H, W) where
B is the batch size,
C is the number of channels,
T is the temporal dimension,
H is the height, and
W is the width of the input.
Returns:
- torch.Tensor: The embedded patches as a tensor, with shape b t h w c.
"""
assert x.dim() == 5
_, _, T, H, W = x.shape
assert (
H % self.spatial_patch_size == 0 and W % self.spatial_patch_size == 0
), f"H,W {(H, W)} should be divisible by spatial_patch_size {self.spatial_patch_size}"
assert T % self.temporal_patch_size == 0
x = self.proj(x)
return x
class FinalLayer(nn.Module):
"""
The final layer of video DiT.
"""
def __init__(
self,
hidden_size: int,
spatial_patch_size: int,
temporal_patch_size: int,
out_channels: int,
use_adaln_lora: bool = False,
adaln_lora_dim: int = 256,
device=None, dtype=None, operations=None
):
super().__init__()
self.layer_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = operations.Linear(
hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False, device=device, dtype=dtype
)
self.hidden_size = hidden_size
self.n_adaln_chunks = 2
self.use_adaln_lora = use_adaln_lora
self.adaln_lora_dim = adaln_lora_dim
if use_adaln_lora:
self.adaln_modulation = nn.Sequential(
nn.SiLU(),
operations.Linear(hidden_size, adaln_lora_dim, bias=False, device=device, dtype=dtype),
operations.Linear(adaln_lora_dim, self.n_adaln_chunks * hidden_size, bias=False, device=device, dtype=dtype),
)
else:
self.adaln_modulation = nn.Sequential(
nn.SiLU(), operations.Linear(hidden_size, self.n_adaln_chunks * hidden_size, bias=False, device=device, dtype=dtype)
)
def forward(
self,
x_B_T_H_W_D: torch.Tensor,
emb_B_T_D: torch.Tensor,
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
):
if self.use_adaln_lora:
assert adaln_lora_B_T_3D is not None
shift_B_T_D, scale_B_T_D = (
self.adaln_modulation(emb_B_T_D) + adaln_lora_B_T_3D[:, :, : 2 * self.hidden_size]
).chunk(2, dim=-1)
else:
shift_B_T_D, scale_B_T_D = self.adaln_modulation(emb_B_T_D).chunk(2, dim=-1)
shift_B_T_1_1_D, scale_B_T_1_1_D = rearrange(shift_B_T_D, "b t d -> b t 1 1 d"), rearrange(
scale_B_T_D, "b t d -> b t 1 1 d"
)
def _fn(
_x_B_T_H_W_D: torch.Tensor,
_norm_layer: nn.Module,
_scale_B_T_1_1_D: torch.Tensor,
_shift_B_T_1_1_D: torch.Tensor,
) -> torch.Tensor:
return _norm_layer(_x_B_T_H_W_D) * (1 + _scale_B_T_1_1_D) + _shift_B_T_1_1_D
x_B_T_H_W_D = _fn(x_B_T_H_W_D, self.layer_norm, scale_B_T_1_1_D, shift_B_T_1_1_D)
x_B_T_H_W_O = self.linear(x_B_T_H_W_D)
return x_B_T_H_W_O
class Block(nn.Module):
"""
A transformer block that combines self-attention, cross-attention and MLP layers with AdaLN modulation.
Each component (self-attention, cross-attention, MLP) has its own layer normalization and AdaLN modulation.
Parameters:
x_dim (int): Dimension of input features
context_dim (int): Dimension of context features for cross-attention
num_heads (int): Number of attention heads
mlp_ratio (float): Multiplier for MLP hidden dimension. Default: 4.0
use_adaln_lora (bool): Whether to use AdaLN-LoRA modulation. Default: False
adaln_lora_dim (int): Hidden dimension for AdaLN-LoRA layers. Default: 256
The block applies the following sequence:
1. Self-attention with AdaLN modulation
2. Cross-attention with AdaLN modulation
3. MLP with AdaLN modulation
Each component uses skip connections and layer normalization.
"""
def __init__(
self,
x_dim: int,
context_dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
use_adaln_lora: bool = False,
adaln_lora_dim: int = 256,
device=None,
dtype=None,
operations=None,
):
super().__init__()
self.x_dim = x_dim
self.layer_norm_self_attn = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
self.self_attn = Attention(x_dim, None, num_heads, x_dim // num_heads, device=device, dtype=dtype, operations=operations)
self.layer_norm_cross_attn = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
self.cross_attn = Attention(
x_dim, context_dim, num_heads, x_dim // num_heads, device=device, dtype=dtype, operations=operations
)
self.layer_norm_mlp = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
self.mlp = GPT2FeedForward(x_dim, int(x_dim * mlp_ratio), device=device, dtype=dtype, operations=operations)
self.use_adaln_lora = use_adaln_lora
if self.use_adaln_lora:
self.adaln_modulation_self_attn = nn.Sequential(
nn.SiLU(),
operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype),
operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype),
)
self.adaln_modulation_cross_attn = nn.Sequential(
nn.SiLU(),
operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype),
operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype),
)
self.adaln_modulation_mlp = nn.Sequential(
nn.SiLU(),
operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype),
operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype),
)
else:
self.adaln_modulation_self_attn = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype))
self.adaln_modulation_cross_attn = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype))
self.adaln_modulation_mlp = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype))
def forward(
self,
x_B_T_H_W_D: torch.Tensor,
emb_B_T_D: torch.Tensor,
crossattn_emb: torch.Tensor,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if extra_per_block_pos_emb is not None:
x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb
if self.use_adaln_lora:
shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = (
self.adaln_modulation_self_attn(emb_B_T_D) + adaln_lora_B_T_3D
).chunk(3, dim=-1)
shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = (
self.adaln_modulation_cross_attn(emb_B_T_D) + adaln_lora_B_T_3D
).chunk(3, dim=-1)
shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = (
self.adaln_modulation_mlp(emb_B_T_D) + adaln_lora_B_T_3D
).chunk(3, dim=-1)
else:
shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = self.adaln_modulation_self_attn(
emb_B_T_D
).chunk(3, dim=-1)
shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = self.adaln_modulation_cross_attn(
emb_B_T_D
).chunk(3, dim=-1)
shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = self.adaln_modulation_mlp(emb_B_T_D).chunk(3, dim=-1)
# Reshape tensors from (B, T, D) to (B, T, 1, 1, D) for broadcasting
shift_self_attn_B_T_1_1_D = rearrange(shift_self_attn_B_T_D, "b t d -> b t 1 1 d")
scale_self_attn_B_T_1_1_D = rearrange(scale_self_attn_B_T_D, "b t d -> b t 1 1 d")
gate_self_attn_B_T_1_1_D = rearrange(gate_self_attn_B_T_D, "b t d -> b t 1 1 d")
shift_cross_attn_B_T_1_1_D = rearrange(shift_cross_attn_B_T_D, "b t d -> b t 1 1 d")
scale_cross_attn_B_T_1_1_D = rearrange(scale_cross_attn_B_T_D, "b t d -> b t 1 1 d")
gate_cross_attn_B_T_1_1_D = rearrange(gate_cross_attn_B_T_D, "b t d -> b t 1 1 d")
shift_mlp_B_T_1_1_D = rearrange(shift_mlp_B_T_D, "b t d -> b t 1 1 d")
scale_mlp_B_T_1_1_D = rearrange(scale_mlp_B_T_D, "b t d -> b t 1 1 d")
gate_mlp_B_T_1_1_D = rearrange(gate_mlp_B_T_D, "b t d -> b t 1 1 d")
B, T, H, W, D = x_B_T_H_W_D.shape
def _fn(_x_B_T_H_W_D, _norm_layer, _scale_B_T_1_1_D, _shift_B_T_1_1_D):
return _norm_layer(_x_B_T_H_W_D) * (1 + _scale_B_T_1_1_D) + _shift_B_T_1_1_D
normalized_x_B_T_H_W_D = _fn(
x_B_T_H_W_D,
self.layer_norm_self_attn,
scale_self_attn_B_T_1_1_D,
shift_self_attn_B_T_1_1_D,
)
result_B_T_H_W_D = rearrange(
self.self_attn(
# normalized_x_B_T_HW_D,
rearrange(normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
None,
rope_emb=rope_emb_L_1_1_D,
),
"b (t h w) d -> b t h w d",
t=T,
h=H,
w=W,
)
x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D * result_B_T_H_W_D
def _x_fn(
_x_B_T_H_W_D: torch.Tensor,
layer_norm_cross_attn: Callable,
_scale_cross_attn_B_T_1_1_D: torch.Tensor,
_shift_cross_attn_B_T_1_1_D: torch.Tensor,
) -> torch.Tensor:
_normalized_x_B_T_H_W_D = _fn(
_x_B_T_H_W_D, layer_norm_cross_attn, _scale_cross_attn_B_T_1_1_D, _shift_cross_attn_B_T_1_1_D
)
_result_B_T_H_W_D = rearrange(
self.cross_attn(
rearrange(_normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
crossattn_emb,
rope_emb=rope_emb_L_1_1_D,
),
"b (t h w) d -> b t h w d",
t=T,
h=H,
w=W,
)
return _result_B_T_H_W_D
result_B_T_H_W_D = _x_fn(
x_B_T_H_W_D,
self.layer_norm_cross_attn,
scale_cross_attn_B_T_1_1_D,
shift_cross_attn_B_T_1_1_D,
)
x_B_T_H_W_D = result_B_T_H_W_D * gate_cross_attn_B_T_1_1_D + x_B_T_H_W_D
normalized_x_B_T_H_W_D = _fn(
x_B_T_H_W_D,
self.layer_norm_mlp,
scale_mlp_B_T_1_1_D,
shift_mlp_B_T_1_1_D,
)
result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D)
x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D * result_B_T_H_W_D
return x_B_T_H_W_D
class MiniTrainDIT(nn.Module):
"""
A clean impl of DIT that can load and reproduce the training results of the original DIT model in~(cosmos 1)
A general implementation of adaln-modulated VIT-like~(DiT) transformer for video processing.
Args:
max_img_h (int): Maximum height of the input images.
max_img_w (int): Maximum width of the input images.
max_frames (int): Maximum number of frames in the video sequence.
in_channels (int): Number of input channels (e.g., RGB channels for color images).
out_channels (int): Number of output channels.
patch_spatial (tuple): Spatial resolution of patches for input processing.
patch_temporal (int): Temporal resolution of patches for input processing.
concat_padding_mask (bool): If True, includes a mask channel in the input to handle padding.
model_channels (int): Base number of channels used throughout the model.
num_blocks (int): Number of transformer blocks.
num_heads (int): Number of heads in the multi-head attention layers.
mlp_ratio (float): Expansion ratio for MLP blocks.
crossattn_emb_channels (int): Number of embedding channels for cross-attention.
pos_emb_cls (str): Type of positional embeddings.
pos_emb_learnable (bool): Whether positional embeddings are learnable.
pos_emb_interpolation (str): Method for interpolating positional embeddings.
min_fps (int): Minimum frames per second.
max_fps (int): Maximum frames per second.
use_adaln_lora (bool): Whether to use AdaLN-LoRA.
adaln_lora_dim (int): Dimension for AdaLN-LoRA.
rope_h_extrapolation_ratio (float): Height extrapolation ratio for RoPE.
rope_w_extrapolation_ratio (float): Width extrapolation ratio for RoPE.
rope_t_extrapolation_ratio (float): Temporal extrapolation ratio for RoPE.
extra_per_block_abs_pos_emb (bool): Whether to use extra per-block absolute positional embeddings.
extra_h_extrapolation_ratio (float): Height extrapolation ratio for extra embeddings.
extra_w_extrapolation_ratio (float): Width extrapolation ratio for extra embeddings.
extra_t_extrapolation_ratio (float): Temporal extrapolation ratio for extra embeddings.
"""
def __init__(
self,
max_img_h: int,
max_img_w: int,
max_frames: int,
in_channels: int,
out_channels: int,
patch_spatial: int, # tuple,
patch_temporal: int,
concat_padding_mask: bool = True,
# attention settings
model_channels: int = 768,
num_blocks: int = 10,
num_heads: int = 16,
mlp_ratio: float = 4.0,
# cross attention settings
crossattn_emb_channels: int = 1024,
# positional embedding settings
pos_emb_cls: str = "sincos",
pos_emb_learnable: bool = False,
pos_emb_interpolation: str = "crop",
min_fps: int = 1,
max_fps: int = 30,
use_adaln_lora: bool = False,
adaln_lora_dim: int = 256,
rope_h_extrapolation_ratio: float = 1.0,
rope_w_extrapolation_ratio: float = 1.0,
rope_t_extrapolation_ratio: float = 1.0,
extra_per_block_abs_pos_emb: bool = False,
extra_h_extrapolation_ratio: float = 1.0,
extra_w_extrapolation_ratio: float = 1.0,
extra_t_extrapolation_ratio: float = 1.0,
rope_enable_fps_modulation: bool = True,
image_model=None,
device=None,
dtype=None,
operations=None,
) -> None:
super().__init__()
self.dtype = dtype
self.max_img_h = max_img_h
self.max_img_w = max_img_w
self.max_frames = max_frames
self.in_channels = in_channels
self.out_channels = out_channels
self.patch_spatial = patch_spatial
self.patch_temporal = patch_temporal
self.num_heads = num_heads
self.num_blocks = num_blocks
self.model_channels = model_channels
self.concat_padding_mask = concat_padding_mask
# positional embedding settings
self.pos_emb_cls = pos_emb_cls
self.pos_emb_learnable = pos_emb_learnable
self.pos_emb_interpolation = pos_emb_interpolation
self.min_fps = min_fps
self.max_fps = max_fps
self.rope_h_extrapolation_ratio = rope_h_extrapolation_ratio
self.rope_w_extrapolation_ratio = rope_w_extrapolation_ratio
self.rope_t_extrapolation_ratio = rope_t_extrapolation_ratio
self.extra_per_block_abs_pos_emb = extra_per_block_abs_pos_emb
self.extra_h_extrapolation_ratio = extra_h_extrapolation_ratio
self.extra_w_extrapolation_ratio = extra_w_extrapolation_ratio
self.extra_t_extrapolation_ratio = extra_t_extrapolation_ratio
self.rope_enable_fps_modulation = rope_enable_fps_modulation
self.build_pos_embed(device=device, dtype=dtype)
self.use_adaln_lora = use_adaln_lora
self.adaln_lora_dim = adaln_lora_dim
self.t_embedder = nn.Sequential(
Timesteps(model_channels),
TimestepEmbedding(model_channels, model_channels, use_adaln_lora=use_adaln_lora, device=device, dtype=dtype, operations=operations,),
)
in_channels = in_channels + 1 if concat_padding_mask else in_channels
self.x_embedder = PatchEmbed(
spatial_patch_size=patch_spatial,
temporal_patch_size=patch_temporal,
in_channels=in_channels,
out_channels=model_channels,
device=device, dtype=dtype, operations=operations,
)
self.blocks = nn.ModuleList(
[
Block(
x_dim=model_channels,
context_dim=crossattn_emb_channels,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
use_adaln_lora=use_adaln_lora,
adaln_lora_dim=adaln_lora_dim,
device=device, dtype=dtype, operations=operations,
)
for _ in range(num_blocks)
]
)
self.final_layer = FinalLayer(
hidden_size=self.model_channels,
spatial_patch_size=self.patch_spatial,
temporal_patch_size=self.patch_temporal,
out_channels=self.out_channels,
use_adaln_lora=self.use_adaln_lora,
adaln_lora_dim=self.adaln_lora_dim,
device=device, dtype=dtype, operations=operations,
)
self.t_embedding_norm = operations.RMSNorm(model_channels, eps=1e-6, device=device, dtype=dtype)
def build_pos_embed(self, device=None, dtype=None) -> None:
if self.pos_emb_cls == "rope3d":
cls_type = VideoRopePosition3DEmb
else:
raise ValueError(f"Unknown pos_emb_cls {self.pos_emb_cls}")
logging.debug(f"Building positional embedding with {self.pos_emb_cls} class, impl {cls_type}")
kwargs = dict(
model_channels=self.model_channels,
len_h=self.max_img_h // self.patch_spatial,
len_w=self.max_img_w // self.patch_spatial,
len_t=self.max_frames // self.patch_temporal,
max_fps=self.max_fps,
min_fps=self.min_fps,
is_learnable=self.pos_emb_learnable,
interpolation=self.pos_emb_interpolation,
head_dim=self.model_channels // self.num_heads,
h_extrapolation_ratio=self.rope_h_extrapolation_ratio,
w_extrapolation_ratio=self.rope_w_extrapolation_ratio,
t_extrapolation_ratio=self.rope_t_extrapolation_ratio,
enable_fps_modulation=self.rope_enable_fps_modulation,
device=device,
)
self.pos_embedder = cls_type(
**kwargs, # type: ignore
)
if self.extra_per_block_abs_pos_emb:
kwargs["h_extrapolation_ratio"] = self.extra_h_extrapolation_ratio
kwargs["w_extrapolation_ratio"] = self.extra_w_extrapolation_ratio
kwargs["t_extrapolation_ratio"] = self.extra_t_extrapolation_ratio
kwargs["device"] = device
kwargs["dtype"] = dtype
self.extra_pos_embedder = LearnablePosEmbAxis(
**kwargs, # type: ignore
)
def prepare_embedded_sequence(
self,
x_B_C_T_H_W: torch.Tensor,
fps: Optional[torch.Tensor] = None,
padding_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
Prepares an embedded sequence tensor by applying positional embeddings and handling padding masks.
Args:
x_B_C_T_H_W (torch.Tensor): video
fps (Optional[torch.Tensor]): Frames per second tensor to be used for positional embedding when required.
If None, a default value (`self.base_fps`) will be used.
padding_mask (Optional[torch.Tensor]): current it is not used
Returns:
Tuple[torch.Tensor, Optional[torch.Tensor]]:
- A tensor of shape (B, T, H, W, D) with the embedded sequence.
- An optional positional embedding tensor, returned only if the positional embedding class
(`self.pos_emb_cls`) includes 'rope'. Otherwise, None.
Notes:
- If `self.concat_padding_mask` is True, a padding mask channel is concatenated to the input tensor.
- The method of applying positional embeddings depends on the value of `self.pos_emb_cls`.
- If 'rope' is in `self.pos_emb_cls` (case insensitive), the positional embeddings are generated using
the `self.pos_embedder` with the shape [T, H, W].
- If "fps_aware" is in `self.pos_emb_cls`, the positional embeddings are generated using the
`self.pos_embedder` with the fps tensor.
- Otherwise, the positional embeddings are generated without considering fps.
"""
if self.concat_padding_mask:
if padding_mask is None:
padding_mask = torch.zeros(x_B_C_T_H_W.shape[0], 1, x_B_C_T_H_W.shape[3], x_B_C_T_H_W.shape[4], dtype=x_B_C_T_H_W.dtype, device=x_B_C_T_H_W.device)
else:
padding_mask = transforms.functional.resize(
padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
)
x_B_C_T_H_W = torch.cat(
[x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1
)
x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W)
if self.extra_per_block_abs_pos_emb:
extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device, dtype=x_B_C_T_H_W.dtype)
else:
extra_pos_emb = None
if "rope" in self.pos_emb_cls.lower():
return x_B_T_H_W_D, self.pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device), extra_pos_emb
x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, device=x_B_C_T_H_W.device) # [B, T, H, W, D]
return x_B_T_H_W_D, None, extra_pos_emb
def unpatchify(self, x_B_T_H_W_M: torch.Tensor) -> torch.Tensor:
x_B_C_Tt_Hp_Wp = rearrange(
x_B_T_H_W_M,
"B T H W (p1 p2 t C) -> B C (T t) (H p1) (W p2)",
p1=self.patch_spatial,
p2=self.patch_spatial,
t=self.patch_temporal,
)
return x_B_C_Tt_Hp_Wp
def forward(
self,
x: torch.Tensor,
timesteps: torch.Tensor,
context: torch.Tensor,
fps: Optional[torch.Tensor] = None,
padding_mask: Optional[torch.Tensor] = None,
**kwargs,
):
x_B_C_T_H_W = x
timesteps_B_T = timesteps
crossattn_emb = context
"""
Args:
x: (B, C, T, H, W) tensor of spatial-temp inputs
timesteps: (B, ) tensor of timesteps
crossattn_emb: (B, N, D) tensor of cross-attention embeddings
"""
x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence(
x_B_C_T_H_W,
fps=fps,
padding_mask=padding_mask,
)
if timesteps_B_T.ndim == 1:
timesteps_B_T = timesteps_B_T.unsqueeze(1)
t_embedding_B_T_D, adaln_lora_B_T_3D = self.t_embedder[1](self.t_embedder[0](timesteps_B_T).to(x_B_T_H_W_D.dtype))
t_embedding_B_T_D = self.t_embedding_norm(t_embedding_B_T_D)
# for logging purpose
affline_scale_log_info = {}
affline_scale_log_info["t_embedding_B_T_D"] = t_embedding_B_T_D.detach()
self.affline_scale_log_info = affline_scale_log_info
self.affline_emb = t_embedding_B_T_D
self.crossattn_emb = crossattn_emb
if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
assert (
x_B_T_H_W_D.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape
), f"{x_B_T_H_W_D.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape}"
block_kwargs = {
"rope_emb_L_1_1_D": rope_emb_L_1_1_D.unsqueeze(1).unsqueeze(0),
"adaln_lora_B_T_3D": adaln_lora_B_T_3D,
"extra_per_block_pos_emb": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
}
for block in self.blocks:
x_B_T_H_W_D = block(
x_B_T_H_W_D,
t_embedding_B_T_D,
crossattn_emb,
**block_kwargs,
)
x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D)
x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)
return x_B_C_Tt_Hp_Wp

View File

@@ -121,11 +121,6 @@ class ControlNetFlux(Flux):
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
if y is None:
y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
else:
y = y[:, :self.params.vec_in_dim]
# running on sequences img
img = self.img_in(img)
@@ -179,7 +174,7 @@ class ControlNetFlux(Flux):
out["output"] = out_output[:self.main_model_single]
return out
def forward(self, x, timesteps, context, y=None, guidance=None, hint=None, **kwargs):
def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs):
patch_size = 2
if self.latent_input:
hint = comfy.ldm.common_dit.pad_to_patch_size(hint, (patch_size, patch_size))

View File

@@ -118,7 +118,7 @@ class Modulation(nn.Module):
def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None):
if modulation_dims is None:
if m_add is not None:
return torch.addcmul(m_add, tensor, m_mult)
return tensor * m_mult + m_add
else:
return tensor * m_mult
else:

View File

@@ -101,10 +101,6 @@ class Flux(nn.Module):
transformer_options={},
attn_mask: Tensor = None,
) -> Tensor:
if y is None:
y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
patches_replace = transformer_options.get("patches_replace", {})
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
@@ -159,9 +155,6 @@ class Flux(nn.Module):
if add is not None:
img += add
if img.dtype == torch.float16:
img = torch.nan_to_num(img, nan=0.0, posinf=65504, neginf=-65504)
img = torch.cat((txt, img), 1)
for i, block in enumerate(self.single_blocks):
@@ -195,50 +188,20 @@ class Flux(nn.Module):
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
return img
def process_img(self, x, index=0, h_offset=0, w_offset=0):
def forward(self, x, timestep, context, y, guidance=None, control=None, transformer_options={}, **kwargs):
bs, c, h, w = x.shape
patch_size = self.patch_size
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
h_len = ((h + (patch_size // 2)) // patch_size)
w_len = ((w + (patch_size // 2)) // patch_size)
h_offset = ((h_offset + (patch_size // 2)) // patch_size)
w_offset = ((w_offset + (patch_size // 2)) // patch_size)
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[:, :, 0] = img_ids[:, :, 1] + index
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
return img, repeat(img_ids, "h w c -> b (h w) c", b=bs)
def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
bs, c, h_orig, w_orig = x.shape
patch_size = self.patch_size
h_len = ((h_orig + (patch_size // 2)) // patch_size)
w_len = ((w_orig + (patch_size // 2)) // patch_size)
img, img_ids = self.process_img(x)
img_tokens = img.shape[1]
if ref_latents is not None:
h = 0
w = 0
for ref in ref_latents:
h_offset = 0
w_offset = 0
if ref.shape[-2] + h > ref.shape[-1] + w:
w_offset = w
else:
h_offset = h
kontext, kontext_ids = self.process_img(ref, index=1, h_offset=h_offset, w_offset=w_offset)
img = torch.cat([img, kontext], dim=1)
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
h = max(h, ref.shape[-2] + h_offset)
w = max(w, ref.shape[-1] + w_offset)
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
out = out[:, :img_tokens]
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h_orig,:w_orig]
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]

View File

@@ -261,8 +261,8 @@ class CrossAttention(nn.Module):
self.heads = heads
self.dim_head = dim_head
self.q_norm = operations.RMSNorm(inner_dim, eps=1e-5, dtype=dtype, device=device)
self.k_norm = operations.RMSNorm(inner_dim, eps=1e-5, dtype=dtype, device=device)
self.q_norm = operations.RMSNorm(inner_dim, dtype=dtype, device=device)
self.k_norm = operations.RMSNorm(inner_dim, dtype=dtype, device=device)
self.to_q = operations.Linear(query_dim, inner_dim, bias=True, dtype=dtype, device=device)
self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)

View File

@@ -973,7 +973,7 @@ class VideoVAE(nn.Module):
norm_layer=config.get("norm_layer", "group_norm"),
causal=config.get("causal_decoder", False),
timestep_conditioning=self.timestep_conditioning,
spatial_padding_mode=config.get("spatial_padding_mode", "reflect"),
spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
)
self.per_channel_statistics = processor()

View File

@@ -11,7 +11,7 @@ from comfy.ldm.modules.ema import LitEma
import comfy.ops
class DiagonalGaussianRegularizer(torch.nn.Module):
def __init__(self, sample: bool = False):
def __init__(self, sample: bool = True):
super().__init__()
self.sample = sample
@@ -19,12 +19,16 @@ class DiagonalGaussianRegularizer(torch.nn.Module):
yield from ()
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
log = dict()
posterior = DiagonalGaussianDistribution(z)
if self.sample:
z = posterior.sample()
else:
z = posterior.mode()
return z, None
kl_loss = posterior.kl()
kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
log["kl_loss"] = kl_loss
return z, log
class AbstractAutoencoder(torch.nn.Module):

View File

@@ -20,11 +20,8 @@ if model_management.xformers_enabled():
if model_management.sage_attention_enabled():
try:
from sageattention import sageattn
except ModuleNotFoundError as e:
if e.name == "sageattention":
logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention")
else:
raise e
except ModuleNotFoundError:
logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention")
exit(-1)
if model_management.flash_attention_enabled():
@@ -753,7 +750,7 @@ class BasicTransformerBlock(nn.Module):
for p in patch:
n = p(n, extra_options)
x = n + x
x += n
if "middle_patch" in transformer_patches:
patch = transformer_patches["middle_patch"]
for p in patch:
@@ -793,12 +790,12 @@ class BasicTransformerBlock(nn.Module):
for p in patch:
n = p(n, extra_options)
x = n + x
x += n
if self.is_res:
x_skip = x
x = self.ff(self.norm3(x))
if self.is_res:
x = x_skip + x
x += x_skip
return x

View File

@@ -36,7 +36,7 @@ def get_timestep_embedding(timesteps, embedding_dim):
def nonlinearity(x):
# swish
return torch.nn.functional.silu(x)
return x*torch.sigmoid(x)
def Normalize(in_channels, num_groups=32):

View File

@@ -31,7 +31,7 @@ def dynamic_slice(
starts: List[int],
sizes: List[int],
) -> Tensor:
slicing = tuple(slice(start, start + size) for start, size in zip(starts, sizes))
slicing = [slice(start, start + size) for start, size in zip(starts, sizes)]
return x[slicing]
class AttnChunk(NamedTuple):

View File

@@ -1,469 +0,0 @@
# Original code: https://github.com/VectorSpaceLab/OmniGen2
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from comfy.ldm.lightricks.model import Timesteps
from comfy.ldm.flux.layers import EmbedND
from comfy.ldm.modules.attention import optimized_attention_masked
import comfy.model_management
import comfy.ldm.common_dit
def apply_rotary_emb(x, freqs_cis):
if x.shape[1] == 0:
return x
t_ = x.reshape(*x.shape[:-1], -1, 1, 2)
t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
return t_out.reshape(*x.shape).to(dtype=x.dtype)
def swiglu(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return F.silu(x) * y
class TimestepEmbedding(nn.Module):
def __init__(self, in_channels: int, time_embed_dim: int, dtype=None, device=None, operations=None):
super().__init__()
self.linear_1 = operations.Linear(in_channels, time_embed_dim, dtype=dtype, device=device)
self.act = nn.SiLU()
self.linear_2 = operations.Linear(time_embed_dim, time_embed_dim, dtype=dtype, device=device)
def forward(self, sample: torch.Tensor) -> torch.Tensor:
sample = self.linear_1(sample)
sample = self.act(sample)
sample = self.linear_2(sample)
return sample
class LuminaRMSNormZero(nn.Module):
def __init__(self, embedding_dim: int, norm_eps: float = 1e-5, dtype=None, device=None, operations=None):
super().__init__()
self.silu = nn.SiLU()
self.linear = operations.Linear(min(embedding_dim, 1024), 4 * embedding_dim, dtype=dtype, device=device)
self.norm = operations.RMSNorm(embedding_dim, eps=norm_eps, dtype=dtype, device=device)
def forward(self, x: torch.Tensor, emb: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
emb = self.linear(self.silu(emb))
scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
x = self.norm(x) * (1 + scale_msa[:, None])
return x, gate_msa, scale_mlp, gate_mlp
class LuminaLayerNormContinuous(nn.Module):
def __init__(self, embedding_dim: int, conditioning_embedding_dim: int, elementwise_affine: bool = False, eps: float = 1e-6, out_dim: Optional[int] = None, dtype=None, device=None, operations=None):
super().__init__()
self.silu = nn.SiLU()
self.linear_1 = operations.Linear(conditioning_embedding_dim, embedding_dim, dtype=dtype, device=device)
self.norm = operations.LayerNorm(embedding_dim, eps, elementwise_affine, dtype=dtype, device=device)
self.linear_2 = operations.Linear(embedding_dim, out_dim, bias=True, dtype=dtype, device=device) if out_dim is not None else None
def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
emb = self.linear_1(self.silu(conditioning_embedding).to(x.dtype))
x = self.norm(x) * (1 + emb)[:, None, :]
if self.linear_2 is not None:
x = self.linear_2(x)
return x
class LuminaFeedForward(nn.Module):
def __init__(self, dim: int, inner_dim: int, multiple_of: int = 256, dtype=None, device=None, operations=None):
super().__init__()
inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of)
self.linear_1 = operations.Linear(dim, inner_dim, bias=False, dtype=dtype, device=device)
self.linear_2 = operations.Linear(inner_dim, dim, bias=False, dtype=dtype, device=device)
self.linear_3 = operations.Linear(dim, inner_dim, bias=False, dtype=dtype, device=device)
def forward(self, x: torch.Tensor) -> torch.Tensor:
h1, h2 = self.linear_1(x), self.linear_3(x)
return self.linear_2(swiglu(h1, h2))
class Lumina2CombinedTimestepCaptionEmbedding(nn.Module):
def __init__(self, hidden_size: int = 4096, text_feat_dim: int = 2048, frequency_embedding_size: int = 256, norm_eps: float = 1e-5, timestep_scale: float = 1.0, dtype=None, device=None, operations=None):
super().__init__()
self.time_proj = Timesteps(num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0, scale=timestep_scale)
self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=min(hidden_size, 1024), dtype=dtype, device=device, operations=operations)
self.caption_embedder = nn.Sequential(
operations.RMSNorm(text_feat_dim, eps=norm_eps, dtype=dtype, device=device),
operations.Linear(text_feat_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
def forward(self, timestep: torch.Tensor, text_hidden_states: torch.Tensor, dtype: torch.dtype) -> Tuple[torch.Tensor, torch.Tensor]:
timestep_proj = self.time_proj(timestep).to(dtype=dtype)
time_embed = self.timestep_embedder(timestep_proj)
caption_embed = self.caption_embedder(text_hidden_states)
return time_embed, caption_embed
class Attention(nn.Module):
def __init__(self, query_dim: int, dim_head: int, heads: int, kv_heads: int, eps: float = 1e-5, bias: bool = False, dtype=None, device=None, operations=None):
super().__init__()
self.heads = heads
self.kv_heads = kv_heads
self.dim_head = dim_head
self.scale = dim_head ** -0.5
self.to_q = operations.Linear(query_dim, heads * dim_head, bias=bias, dtype=dtype, device=device)
self.to_k = operations.Linear(query_dim, kv_heads * dim_head, bias=bias, dtype=dtype, device=device)
self.to_v = operations.Linear(query_dim, kv_heads * dim_head, bias=bias, dtype=dtype, device=device)
self.norm_q = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
self.norm_k = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
self.to_out = nn.Sequential(
operations.Linear(heads * dim_head, query_dim, bias=bias, dtype=dtype, device=device),
nn.Dropout(0.0)
)
def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None) -> torch.Tensor:
batch_size, sequence_length, _ = hidden_states.shape
query = self.to_q(hidden_states)
key = self.to_k(encoder_hidden_states)
value = self.to_v(encoder_hidden_states)
query = query.view(batch_size, -1, self.heads, self.dim_head)
key = key.view(batch_size, -1, self.kv_heads, self.dim_head)
value = value.view(batch_size, -1, self.kv_heads, self.dim_head)
query = self.norm_q(query)
key = self.norm_k(key)
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
if self.kv_heads < self.heads:
key = key.repeat_interleave(self.heads // self.kv_heads, dim=1)
value = value.repeat_interleave(self.heads // self.kv_heads, dim=1)
hidden_states = optimized_attention_masked(query, key, value, self.heads, attention_mask, skip_reshape=True)
hidden_states = self.to_out[0](hidden_states)
return hidden_states
class OmniGen2TransformerBlock(nn.Module):
def __init__(self, dim: int, num_attention_heads: int, num_kv_heads: int, multiple_of: int, ffn_dim_multiplier: float, norm_eps: float, modulation: bool = True, dtype=None, device=None, operations=None):
super().__init__()
self.modulation = modulation
self.attn = Attention(
query_dim=dim,
dim_head=dim // num_attention_heads,
heads=num_attention_heads,
kv_heads=num_kv_heads,
eps=1e-5,
bias=False,
dtype=dtype, device=device, operations=operations,
)
self.feed_forward = LuminaFeedForward(
dim=dim,
inner_dim=4 * dim,
multiple_of=multiple_of,
dtype=dtype, device=device, operations=operations
)
if modulation:
self.norm1 = LuminaRMSNormZero(embedding_dim=dim, norm_eps=norm_eps, dtype=dtype, device=device, operations=operations)
else:
self.norm1 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
self.ffn_norm1 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
self.norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
self.ffn_norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, image_rotary_emb: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
if self.modulation:
norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb)
hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output)
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
else:
norm_hidden_states = self.norm1(hidden_states)
attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb)
hidden_states = hidden_states + self.norm2(attn_output)
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
hidden_states = hidden_states + self.ffn_norm2(mlp_output)
return hidden_states
class OmniGen2RotaryPosEmbed(nn.Module):
def __init__(self, theta: int, axes_dim: Tuple[int, int, int], axes_lens: Tuple[int, int, int] = (300, 512, 512), patch_size: int = 2):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim
self.axes_lens = axes_lens
self.patch_size = patch_size
self.rope_embedder = EmbedND(dim=sum(axes_dim), theta=self.theta, axes_dim=axes_dim)
def forward(self, batch_size, encoder_seq_len, l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len, ref_img_sizes, img_sizes, device):
p = self.patch_size
seq_lengths = [cap_len + sum(ref_img_len) + img_len for cap_len, ref_img_len, img_len in zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len)]
max_seq_len = max(seq_lengths)
max_ref_img_len = max([sum(ref_img_len) for ref_img_len in l_effective_ref_img_len])
max_img_len = max(l_effective_img_len)
position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device)
for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
position_ids[i, :cap_seq_len] = repeat(torch.arange(cap_seq_len, dtype=torch.int32, device=device), "l -> l 3")
pe_shift = cap_seq_len
pe_shift_len = cap_seq_len
if ref_img_sizes[i] is not None:
for ref_img_size, ref_img_len in zip(ref_img_sizes[i], l_effective_ref_img_len[i]):
H, W = ref_img_size
ref_H_tokens, ref_W_tokens = H // p, W // p
row_ids = repeat(torch.arange(ref_H_tokens, dtype=torch.int32, device=device), "h -> h w", w=ref_W_tokens).flatten()
col_ids = repeat(torch.arange(ref_W_tokens, dtype=torch.int32, device=device), "w -> h w", h=ref_H_tokens).flatten()
position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 0] = pe_shift
position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 1] = row_ids
position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 2] = col_ids
pe_shift += max(ref_H_tokens, ref_W_tokens)
pe_shift_len += ref_img_len
H, W = img_sizes[i]
H_tokens, W_tokens = H // p, W // p
row_ids = repeat(torch.arange(H_tokens, dtype=torch.int32, device=device), "h -> h w", w=W_tokens).flatten()
col_ids = repeat(torch.arange(W_tokens, dtype=torch.int32, device=device), "w -> h w", h=H_tokens).flatten()
position_ids[i, pe_shift_len: seq_len, 0] = pe_shift
position_ids[i, pe_shift_len: seq_len, 1] = row_ids
position_ids[i, pe_shift_len: seq_len, 2] = col_ids
freqs_cis = self.rope_embedder(position_ids).movedim(1, 2)
cap_freqs_cis_shape = list(freqs_cis.shape)
cap_freqs_cis_shape[1] = encoder_seq_len
cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
ref_img_freqs_cis_shape = list(freqs_cis.shape)
ref_img_freqs_cis_shape[1] = max_ref_img_len
ref_img_freqs_cis = torch.zeros(*ref_img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
img_freqs_cis_shape = list(freqs_cis.shape)
img_freqs_cis_shape[1] = max_img_len
img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
for i, (cap_seq_len, ref_img_len, img_len, seq_len) in enumerate(zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len, seq_lengths)):
cap_freqs_cis[i, :cap_seq_len] = freqs_cis[i, :cap_seq_len]
ref_img_freqs_cis[i, :sum(ref_img_len)] = freqs_cis[i, cap_seq_len:cap_seq_len + sum(ref_img_len)]
img_freqs_cis[i, :img_len] = freqs_cis[i, cap_seq_len + sum(ref_img_len):cap_seq_len + sum(ref_img_len) + img_len]
return cap_freqs_cis, ref_img_freqs_cis, img_freqs_cis, freqs_cis, l_effective_cap_len, seq_lengths
class OmniGen2Transformer2DModel(nn.Module):
def __init__(
self,
patch_size: int = 2,
in_channels: int = 16,
out_channels: Optional[int] = None,
hidden_size: int = 2304,
num_layers: int = 26,
num_refiner_layers: int = 2,
num_attention_heads: int = 24,
num_kv_heads: int = 8,
multiple_of: int = 256,
ffn_dim_multiplier: Optional[float] = None,
norm_eps: float = 1e-5,
axes_dim_rope: Tuple[int, int, int] = (32, 32, 32),
axes_lens: Tuple[int, int, int] = (300, 512, 512),
text_feat_dim: int = 1024,
timestep_scale: float = 1.0,
image_model=None,
device=None,
dtype=None,
operations=None,
):
super().__init__()
self.patch_size = patch_size
self.out_channels = out_channels or in_channels
self.hidden_size = hidden_size
self.dtype = dtype
self.rope_embedder = OmniGen2RotaryPosEmbed(
theta=10000,
axes_dim=axes_dim_rope,
axes_lens=axes_lens,
patch_size=patch_size,
)
self.x_embedder = operations.Linear(patch_size * patch_size * in_channels, hidden_size, dtype=dtype, device=device)
self.ref_image_patch_embedder = operations.Linear(patch_size * patch_size * in_channels, hidden_size, dtype=dtype, device=device)
self.time_caption_embed = Lumina2CombinedTimestepCaptionEmbedding(
hidden_size=hidden_size,
text_feat_dim=text_feat_dim,
norm_eps=norm_eps,
timestep_scale=timestep_scale, dtype=dtype, device=device, operations=operations
)
self.noise_refiner = nn.ModuleList([
OmniGen2TransformerBlock(
hidden_size, num_attention_heads, num_kv_heads,
multiple_of, ffn_dim_multiplier, norm_eps, modulation=True, dtype=dtype, device=device, operations=operations
) for _ in range(num_refiner_layers)
])
self.ref_image_refiner = nn.ModuleList([
OmniGen2TransformerBlock(
hidden_size, num_attention_heads, num_kv_heads,
multiple_of, ffn_dim_multiplier, norm_eps, modulation=True, dtype=dtype, device=device, operations=operations
) for _ in range(num_refiner_layers)
])
self.context_refiner = nn.ModuleList([
OmniGen2TransformerBlock(
hidden_size, num_attention_heads, num_kv_heads,
multiple_of, ffn_dim_multiplier, norm_eps, modulation=False, dtype=dtype, device=device, operations=operations
) for _ in range(num_refiner_layers)
])
self.layers = nn.ModuleList([
OmniGen2TransformerBlock(
hidden_size, num_attention_heads, num_kv_heads,
multiple_of, ffn_dim_multiplier, norm_eps, modulation=True, dtype=dtype, device=device, operations=operations
) for _ in range(num_layers)
])
self.norm_out = LuminaLayerNormContinuous(
embedding_dim=hidden_size,
conditioning_embedding_dim=min(hidden_size, 1024),
elementwise_affine=False,
eps=1e-6,
out_dim=patch_size * patch_size * self.out_channels, dtype=dtype, device=device, operations=operations
)
self.image_index_embedding = nn.Parameter(torch.empty(5, hidden_size, device=device, dtype=dtype))
def flat_and_pad_to_seq(self, hidden_states, ref_image_hidden_states):
batch_size = len(hidden_states)
p = self.patch_size
img_sizes = [(img.size(1), img.size(2)) for img in hidden_states]
l_effective_img_len = [(H // p) * (W // p) for (H, W) in img_sizes]
if ref_image_hidden_states is not None:
ref_image_hidden_states = list(map(lambda ref: comfy.ldm.common_dit.pad_to_patch_size(ref, (p, p)), ref_image_hidden_states))
ref_img_sizes = [[(imgs.size(2), imgs.size(3)) if imgs is not None else None for imgs in ref_image_hidden_states]] * batch_size
l_effective_ref_img_len = [[(ref_img_size[0] // p) * (ref_img_size[1] // p) for ref_img_size in _ref_img_sizes] if _ref_img_sizes is not None else [0] for _ref_img_sizes in ref_img_sizes]
else:
ref_img_sizes = [None for _ in range(batch_size)]
l_effective_ref_img_len = [[0] for _ in range(batch_size)]
flat_ref_img_hidden_states = None
if ref_image_hidden_states is not None:
imgs = []
for ref_img in ref_image_hidden_states:
B, C, H, W = ref_img.size()
ref_img = rearrange(ref_img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
imgs.append(ref_img)
flat_ref_img_hidden_states = torch.cat(imgs, dim=1)
img = hidden_states
B, C, H, W = img.size()
flat_hidden_states = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
return (
flat_hidden_states, flat_ref_img_hidden_states,
None, None,
l_effective_ref_img_len, l_effective_img_len,
ref_img_sizes, img_sizes,
)
def img_patch_embed_and_refine(self, hidden_states, ref_image_hidden_states, padded_img_mask, padded_ref_img_mask, noise_rotary_emb, ref_img_rotary_emb, l_effective_ref_img_len, l_effective_img_len, temb):
batch_size = len(hidden_states)
hidden_states = self.x_embedder(hidden_states)
if ref_image_hidden_states is not None:
ref_image_hidden_states = self.ref_image_patch_embedder(ref_image_hidden_states)
image_index_embedding = comfy.model_management.cast_to(self.image_index_embedding, dtype=hidden_states.dtype, device=hidden_states.device)
for i in range(batch_size):
shift = 0
for j, ref_img_len in enumerate(l_effective_ref_img_len[i]):
ref_image_hidden_states[i, shift:shift + ref_img_len, :] = ref_image_hidden_states[i, shift:shift + ref_img_len, :] + image_index_embedding[j]
shift += ref_img_len
for layer in self.noise_refiner:
hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb)
if ref_image_hidden_states is not None:
for layer in self.ref_image_refiner:
ref_image_hidden_states = layer(ref_image_hidden_states, padded_ref_img_mask, ref_img_rotary_emb, temb)
hidden_states = torch.cat([ref_image_hidden_states, hidden_states], dim=1)
return hidden_states
def forward(self, x, timesteps, context, num_tokens, ref_latents=None, attention_mask=None, **kwargs):
B, C, H, W = x.shape
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
_, _, H_padded, W_padded = hidden_states.shape
timestep = 1.0 - timesteps
text_hidden_states = context
text_attention_mask = attention_mask
ref_image_hidden_states = ref_latents
device = hidden_states.device
temb, text_hidden_states = self.time_caption_embed(timestep, text_hidden_states, hidden_states[0].dtype)
(
hidden_states, ref_image_hidden_states,
img_mask, ref_img_mask,
l_effective_ref_img_len, l_effective_img_len,
ref_img_sizes, img_sizes,
) = self.flat_and_pad_to_seq(hidden_states, ref_image_hidden_states)
(
context_rotary_emb, ref_img_rotary_emb, noise_rotary_emb,
rotary_emb, encoder_seq_lengths, seq_lengths,
) = self.rope_embedder(
hidden_states.shape[0], text_hidden_states.shape[1], [num_tokens] * text_hidden_states.shape[0],
l_effective_ref_img_len, l_effective_img_len,
ref_img_sizes, img_sizes, device,
)
for layer in self.context_refiner:
text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb)
img_len = hidden_states.shape[1]
combined_img_hidden_states = self.img_patch_embed_and_refine(
hidden_states, ref_image_hidden_states,
img_mask, ref_img_mask,
noise_rotary_emb, ref_img_rotary_emb,
l_effective_ref_img_len, l_effective_img_len,
temb,
)
hidden_states = torch.cat([text_hidden_states, combined_img_hidden_states], dim=1)
attention_mask = None
for layer in self.layers:
hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb)
hidden_states = self.norm_out(hidden_states, temb)
p = self.patch_size
output = rearrange(hidden_states[:, -img_len:], 'b (h w) (p1 p2 c) -> b c (h p1) (w p2)', h=H_padded // p, w=W_padded// p, p1=p, p2=p)[:, :, :H, :W]
return -output

View File

@@ -1,256 +1,256 @@
# Based on:
# https://github.com/PixArt-alpha/PixArt-alpha [Apache 2.0 license]
# https://github.com/PixArt-alpha/PixArt-sigma [Apache 2.0 license]
import torch
import torch.nn as nn
from .blocks import (
t2i_modulate,
CaptionEmbedder,
AttentionKVCompress,
MultiHeadCrossAttention,
T2IFinalLayer,
SizeEmbedder,
)
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, PatchEmbed, Mlp, get_1d_sincos_pos_embed_from_grid_torch
def get_2d_sincos_pos_embed_torch(embed_dim, w, h, pe_interpolation=1.0, base_size=16, device=None, dtype=torch.float32):
grid_h, grid_w = torch.meshgrid(
torch.arange(h, device=device, dtype=dtype) / (h/base_size) / pe_interpolation,
torch.arange(w, device=device, dtype=dtype) / (w/base_size) / pe_interpolation,
indexing='ij'
)
emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_h, device=device, dtype=dtype)
emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_w, device=device, dtype=dtype)
emb = torch.cat([emb_w, emb_h], dim=1) # (H*W, D)
return emb
class PixArtMSBlock(nn.Module):
"""
A PixArt block with adaptive layer norm zero (adaLN-Zero) conditioning.
"""
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0., input_size=None,
sampling=None, sr_ratio=1, qk_norm=False, dtype=None, device=None, operations=None, **block_kwargs):
super().__init__()
self.hidden_size = hidden_size
self.norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.attn = AttentionKVCompress(
hidden_size, num_heads=num_heads, qkv_bias=True, sampling=sampling, sr_ratio=sr_ratio,
qk_norm=qk_norm, dtype=dtype, device=device, operations=operations, **block_kwargs
)
self.cross_attn = MultiHeadCrossAttention(
hidden_size, num_heads, dtype=dtype, device=device, operations=operations, **block_kwargs
)
self.norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
# to be compatible with lower version pytorch
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.mlp = Mlp(
in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu,
dtype=dtype, device=device, operations=operations
)
self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size ** 0.5)
def forward(self, x, y, t, mask=None, HW=None, **kwargs):
B, N, C = x.shape
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None].to(dtype=x.dtype, device=x.device) + t.reshape(B, 6, -1)).chunk(6, dim=1)
x = x + (gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa), HW=HW))
x = x + self.cross_attn(x, y, mask)
x = x + (gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))
return x
### Core PixArt Model ###
class PixArtMS(nn.Module):
"""
Diffusion model with a Transformer backbone.
"""
def __init__(
self,
input_size=32,
patch_size=2,
in_channels=4,
hidden_size=1152,
depth=28,
num_heads=16,
mlp_ratio=4.0,
class_dropout_prob=0.1,
learn_sigma=True,
pred_sigma=True,
drop_path: float = 0.,
caption_channels=4096,
pe_interpolation=None,
pe_precision=None,
config=None,
model_max_length=120,
micro_condition=True,
qk_norm=False,
kv_compress_config=None,
dtype=None,
device=None,
operations=None,
**kwargs,
):
nn.Module.__init__(self)
self.dtype = dtype
self.pred_sigma = pred_sigma
self.in_channels = in_channels
self.out_channels = in_channels * 2 if pred_sigma else in_channels
self.patch_size = patch_size
self.num_heads = num_heads
self.pe_interpolation = pe_interpolation
self.pe_precision = pe_precision
self.hidden_size = hidden_size
self.depth = depth
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.t_block = nn.Sequential(
nn.SiLU(),
operations.Linear(hidden_size, 6 * hidden_size, bias=True, dtype=dtype, device=device)
)
self.x_embedder = PatchEmbed(
patch_size=patch_size,
in_chans=in_channels,
embed_dim=hidden_size,
bias=True,
dtype=dtype,
device=device,
operations=operations
)
self.t_embedder = TimestepEmbedder(
hidden_size, dtype=dtype, device=device, operations=operations,
)
self.y_embedder = CaptionEmbedder(
in_channels=caption_channels, hidden_size=hidden_size, uncond_prob=class_dropout_prob,
act_layer=approx_gelu, token_num=model_max_length,
dtype=dtype, device=device, operations=operations,
)
self.micro_conditioning = micro_condition
if self.micro_conditioning:
self.csize_embedder = SizeEmbedder(hidden_size//3, dtype=dtype, device=device, operations=operations)
self.ar_embedder = SizeEmbedder(hidden_size//3, dtype=dtype, device=device, operations=operations)
# For fixed sin-cos embedding:
# num_patches = (input_size // patch_size) * (input_size // patch_size)
# self.base_size = input_size // self.patch_size
# self.register_buffer("pos_embed", torch.zeros(1, num_patches, hidden_size))
drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule
if kv_compress_config is None:
kv_compress_config = {
'sampling': None,
'scale_factor': 1,
'kv_compress_layer': [],
}
self.blocks = nn.ModuleList([
PixArtMSBlock(
hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i],
sampling=kv_compress_config['sampling'],
sr_ratio=int(kv_compress_config['scale_factor']) if i in kv_compress_config['kv_compress_layer'] else 1,
qk_norm=qk_norm,
dtype=dtype,
device=device,
operations=operations,
)
for i in range(depth)
])
self.final_layer = T2IFinalLayer(
hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations
)
def forward_orig(self, x, timestep, y, mask=None, c_size=None, c_ar=None, **kwargs):
"""
Original forward pass of PixArt.
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
t: (N,) tensor of diffusion timesteps
y: (N, 1, 120, C) conditioning
ar: (N, 1): aspect ratio
cs: (N ,2) size conditioning for height/width
"""
B, C, H, W = x.shape
c_res = (H + W) // 2
pe_interpolation = self.pe_interpolation
if pe_interpolation is None or self.pe_precision is not None:
# calculate pe_interpolation on-the-fly
pe_interpolation = round(c_res / (512/8.0), self.pe_precision or 0)
pos_embed = get_2d_sincos_pos_embed_torch(
self.hidden_size,
h=(H // self.patch_size),
w=(W // self.patch_size),
pe_interpolation=pe_interpolation,
base_size=((round(c_res / 64) * 64) // self.patch_size),
device=x.device,
dtype=x.dtype,
).unsqueeze(0)
x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2
t = self.t_embedder(timestep, x.dtype) # (N, D)
if self.micro_conditioning and (c_size is not None and c_ar is not None):
bs = x.shape[0]
c_size = self.csize_embedder(c_size, bs) # (N, D)
c_ar = self.ar_embedder(c_ar, bs) # (N, D)
t = t + torch.cat([c_size, c_ar], dim=1)
t0 = self.t_block(t)
y = self.y_embedder(y, self.training) # (N, D)
if mask is not None:
if mask.shape[0] != y.shape[0]:
mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
mask = mask.squeeze(1).squeeze(1)
y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
y_lens = mask.sum(dim=1).tolist()
else:
y_lens = None
y = y.squeeze(1).view(1, -1, x.shape[-1])
for block in self.blocks:
x = block(x, y, t0, y_lens, (H, W), **kwargs) # (N, T, D)
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
x = self.unpatchify(x, H, W) # (N, out_channels, H, W)
return x
def forward(self, x, timesteps, context, c_size=None, c_ar=None, **kwargs):
B, C, H, W = x.shape
# Fallback for missing microconds
if self.micro_conditioning:
if c_size is None:
c_size = torch.tensor([H*8, W*8], dtype=x.dtype, device=x.device).repeat(B, 1)
if c_ar is None:
c_ar = torch.tensor([H/W], dtype=x.dtype, device=x.device).repeat(B, 1)
## Still accepts the input w/o that dim but returns garbage
if len(context.shape) == 3:
context = context.unsqueeze(1)
## run original forward pass
out = self.forward_orig(x, timesteps, context, c_size=c_size, c_ar=c_ar)
## only return EPS
if self.pred_sigma:
return out[:, :self.in_channels]
return out
def unpatchify(self, x, h, w):
"""
x: (N, T, patch_size**2 * C)
imgs: (N, H, W, C)
"""
c = self.out_channels
p = self.x_embedder.patch_size[0]
h = h // self.patch_size
w = w // self.patch_size
assert h * w == x.shape[1]
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
x = torch.einsum('nhwpqc->nchpwq', x)
imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
return imgs
# Based on:
# https://github.com/PixArt-alpha/PixArt-alpha [Apache 2.0 license]
# https://github.com/PixArt-alpha/PixArt-sigma [Apache 2.0 license]
import torch
import torch.nn as nn
from .blocks import (
t2i_modulate,
CaptionEmbedder,
AttentionKVCompress,
MultiHeadCrossAttention,
T2IFinalLayer,
SizeEmbedder,
)
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, PatchEmbed, Mlp, get_1d_sincos_pos_embed_from_grid_torch
def get_2d_sincos_pos_embed_torch(embed_dim, w, h, pe_interpolation=1.0, base_size=16, device=None, dtype=torch.float32):
grid_h, grid_w = torch.meshgrid(
torch.arange(h, device=device, dtype=dtype) / (h/base_size) / pe_interpolation,
torch.arange(w, device=device, dtype=dtype) / (w/base_size) / pe_interpolation,
indexing='ij'
)
emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_h, device=device, dtype=dtype)
emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_w, device=device, dtype=dtype)
emb = torch.cat([emb_w, emb_h], dim=1) # (H*W, D)
return emb
class PixArtMSBlock(nn.Module):
"""
A PixArt block with adaptive layer norm zero (adaLN-Zero) conditioning.
"""
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0., input_size=None,
sampling=None, sr_ratio=1, qk_norm=False, dtype=None, device=None, operations=None, **block_kwargs):
super().__init__()
self.hidden_size = hidden_size
self.norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.attn = AttentionKVCompress(
hidden_size, num_heads=num_heads, qkv_bias=True, sampling=sampling, sr_ratio=sr_ratio,
qk_norm=qk_norm, dtype=dtype, device=device, operations=operations, **block_kwargs
)
self.cross_attn = MultiHeadCrossAttention(
hidden_size, num_heads, dtype=dtype, device=device, operations=operations, **block_kwargs
)
self.norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
# to be compatible with lower version pytorch
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.mlp = Mlp(
in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu,
dtype=dtype, device=device, operations=operations
)
self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size ** 0.5)
def forward(self, x, y, t, mask=None, HW=None, **kwargs):
B, N, C = x.shape
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None].to(dtype=x.dtype, device=x.device) + t.reshape(B, 6, -1)).chunk(6, dim=1)
x = x + (gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa), HW=HW))
x = x + self.cross_attn(x, y, mask)
x = x + (gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))
return x
### Core PixArt Model ###
class PixArtMS(nn.Module):
"""
Diffusion model with a Transformer backbone.
"""
def __init__(
self,
input_size=32,
patch_size=2,
in_channels=4,
hidden_size=1152,
depth=28,
num_heads=16,
mlp_ratio=4.0,
class_dropout_prob=0.1,
learn_sigma=True,
pred_sigma=True,
drop_path: float = 0.,
caption_channels=4096,
pe_interpolation=None,
pe_precision=None,
config=None,
model_max_length=120,
micro_condition=True,
qk_norm=False,
kv_compress_config=None,
dtype=None,
device=None,
operations=None,
**kwargs,
):
nn.Module.__init__(self)
self.dtype = dtype
self.pred_sigma = pred_sigma
self.in_channels = in_channels
self.out_channels = in_channels * 2 if pred_sigma else in_channels
self.patch_size = patch_size
self.num_heads = num_heads
self.pe_interpolation = pe_interpolation
self.pe_precision = pe_precision
self.hidden_size = hidden_size
self.depth = depth
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.t_block = nn.Sequential(
nn.SiLU(),
operations.Linear(hidden_size, 6 * hidden_size, bias=True, dtype=dtype, device=device)
)
self.x_embedder = PatchEmbed(
patch_size=patch_size,
in_chans=in_channels,
embed_dim=hidden_size,
bias=True,
dtype=dtype,
device=device,
operations=operations
)
self.t_embedder = TimestepEmbedder(
hidden_size, dtype=dtype, device=device, operations=operations,
)
self.y_embedder = CaptionEmbedder(
in_channels=caption_channels, hidden_size=hidden_size, uncond_prob=class_dropout_prob,
act_layer=approx_gelu, token_num=model_max_length,
dtype=dtype, device=device, operations=operations,
)
self.micro_conditioning = micro_condition
if self.micro_conditioning:
self.csize_embedder = SizeEmbedder(hidden_size//3, dtype=dtype, device=device, operations=operations)
self.ar_embedder = SizeEmbedder(hidden_size//3, dtype=dtype, device=device, operations=operations)
# For fixed sin-cos embedding:
# num_patches = (input_size // patch_size) * (input_size // patch_size)
# self.base_size = input_size // self.patch_size
# self.register_buffer("pos_embed", torch.zeros(1, num_patches, hidden_size))
drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule
if kv_compress_config is None:
kv_compress_config = {
'sampling': None,
'scale_factor': 1,
'kv_compress_layer': [],
}
self.blocks = nn.ModuleList([
PixArtMSBlock(
hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i],
sampling=kv_compress_config['sampling'],
sr_ratio=int(kv_compress_config['scale_factor']) if i in kv_compress_config['kv_compress_layer'] else 1,
qk_norm=qk_norm,
dtype=dtype,
device=device,
operations=operations,
)
for i in range(depth)
])
self.final_layer = T2IFinalLayer(
hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations
)
def forward_orig(self, x, timestep, y, mask=None, c_size=None, c_ar=None, **kwargs):
"""
Original forward pass of PixArt.
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
t: (N,) tensor of diffusion timesteps
y: (N, 1, 120, C) conditioning
ar: (N, 1): aspect ratio
cs: (N ,2) size conditioning for height/width
"""
B, C, H, W = x.shape
c_res = (H + W) // 2
pe_interpolation = self.pe_interpolation
if pe_interpolation is None or self.pe_precision is not None:
# calculate pe_interpolation on-the-fly
pe_interpolation = round(c_res / (512/8.0), self.pe_precision or 0)
pos_embed = get_2d_sincos_pos_embed_torch(
self.hidden_size,
h=(H // self.patch_size),
w=(W // self.patch_size),
pe_interpolation=pe_interpolation,
base_size=((round(c_res / 64) * 64) // self.patch_size),
device=x.device,
dtype=x.dtype,
).unsqueeze(0)
x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2
t = self.t_embedder(timestep, x.dtype) # (N, D)
if self.micro_conditioning and (c_size is not None and c_ar is not None):
bs = x.shape[0]
c_size = self.csize_embedder(c_size, bs) # (N, D)
c_ar = self.ar_embedder(c_ar, bs) # (N, D)
t = t + torch.cat([c_size, c_ar], dim=1)
t0 = self.t_block(t)
y = self.y_embedder(y, self.training) # (N, D)
if mask is not None:
if mask.shape[0] != y.shape[0]:
mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
mask = mask.squeeze(1).squeeze(1)
y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
y_lens = mask.sum(dim=1).tolist()
else:
y_lens = None
y = y.squeeze(1).view(1, -1, x.shape[-1])
for block in self.blocks:
x = block(x, y, t0, y_lens, (H, W), **kwargs) # (N, T, D)
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
x = self.unpatchify(x, H, W) # (N, out_channels, H, W)
return x
def forward(self, x, timesteps, context, c_size=None, c_ar=None, **kwargs):
B, C, H, W = x.shape
# Fallback for missing microconds
if self.micro_conditioning:
if c_size is None:
c_size = torch.tensor([H*8, W*8], dtype=x.dtype, device=x.device).repeat(B, 1)
if c_ar is None:
c_ar = torch.tensor([H/W], dtype=x.dtype, device=x.device).repeat(B, 1)
## Still accepts the input w/o that dim but returns garbage
if len(context.shape) == 3:
context = context.unsqueeze(1)
## run original forward pass
out = self.forward_orig(x, timesteps, context, c_size=c_size, c_ar=c_ar)
## only return EPS
if self.pred_sigma:
return out[:, :self.in_channels]
return out
def unpatchify(self, x, h, w):
"""
x: (N, T, patch_size**2 * C)
imgs: (N, H, W, C)
"""
c = self.out_channels
p = self.x_embedder.patch_size[0]
h = h // self.patch_size
w = w // self.patch_size
assert h * w == x.shape[1]
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
x = torch.einsum('nhwpqc->nchpwq', x)
imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
return imgs

View File

@@ -1,400 +0,0 @@
# https://github.com/QwenLM/Qwen-Image (Apache 2.0)
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple
from einops import repeat
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
from comfy.ldm.modules.attention import optimized_attention_masked
from comfy.ldm.flux.layers import EmbedND
import comfy.ldm.common_dit
class GELU(nn.Module):
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None):
super().__init__()
self.proj = operations.Linear(dim_in, dim_out, bias=bias, dtype=dtype, device=device)
self.approximate = approximate
def forward(self, hidden_states):
hidden_states = self.proj(hidden_states)
hidden_states = F.gelu(hidden_states, approximate=self.approximate)
return hidden_states
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
dim_out: Optional[int] = None,
mult: int = 4,
dropout: float = 0.0,
inner_dim=None,
bias: bool = True,
dtype=None, device=None, operations=None
):
super().__init__()
if inner_dim is None:
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim
self.net = nn.ModuleList([])
self.net.append(GELU(dim, inner_dim, approximate="tanh", bias=bias, dtype=dtype, device=device, operations=operations))
self.net.append(nn.Dropout(dropout))
self.net.append(operations.Linear(inner_dim, dim_out, bias=bias, dtype=dtype, device=device))
def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
for module in self.net:
hidden_states = module(hidden_states)
return hidden_states
def apply_rotary_emb(x, freqs_cis):
if x.shape[1] == 0:
return x
t_ = x.reshape(*x.shape[:-1], -1, 1, 2)
t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
return t_out.reshape(*x.shape)
class QwenTimestepProjEmbeddings(nn.Module):
def __init__(self, embedding_dim, pooled_projection_dim, dtype=None, device=None, operations=None):
super().__init__()
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000)
self.timestep_embedder = TimestepEmbedding(
in_channels=256,
time_embed_dim=embedding_dim,
dtype=dtype,
device=device,
operations=operations
)
def forward(self, timestep, hidden_states):
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype))
return timesteps_emb
class Attention(nn.Module):
def __init__(
self,
query_dim: int,
dim_head: int = 64,
heads: int = 8,
dropout: float = 0.0,
bias: bool = False,
eps: float = 1e-5,
out_bias: bool = True,
out_dim: int = None,
out_context_dim: int = None,
dtype=None,
device=None,
operations=None
):
super().__init__()
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
self.inner_kv_dim = self.inner_dim
self.heads = heads
self.dim_head = dim_head
self.out_dim = out_dim if out_dim is not None else query_dim
self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim
self.dropout = dropout
# Q/K normalization
self.norm_q = operations.RMSNorm(dim_head, eps=eps, elementwise_affine=True, dtype=dtype, device=device)
self.norm_k = operations.RMSNorm(dim_head, eps=eps, elementwise_affine=True, dtype=dtype, device=device)
self.norm_added_q = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
self.norm_added_k = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
# Image stream projections
self.to_q = operations.Linear(query_dim, self.inner_dim, bias=bias, dtype=dtype, device=device)
self.to_k = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
self.to_v = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
# Text stream projections
self.add_q_proj = operations.Linear(query_dim, self.inner_dim, bias=bias, dtype=dtype, device=device)
self.add_k_proj = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
self.add_v_proj = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
# Output projections
self.to_out = nn.ModuleList([
operations.Linear(self.inner_dim, self.out_dim, bias=out_bias, dtype=dtype, device=device),
nn.Dropout(dropout)
])
self.to_add_out = operations.Linear(self.inner_dim, self.out_context_dim, bias=out_bias, dtype=dtype, device=device)
def forward(
self,
hidden_states: torch.FloatTensor, # Image stream
encoder_hidden_states: torch.FloatTensor = None, # Text stream
encoder_hidden_states_mask: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
seq_txt = encoder_hidden_states.shape[1]
img_query = self.to_q(hidden_states).unflatten(-1, (self.heads, -1))
img_key = self.to_k(hidden_states).unflatten(-1, (self.heads, -1))
img_value = self.to_v(hidden_states).unflatten(-1, (self.heads, -1))
txt_query = self.add_q_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
txt_key = self.add_k_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
txt_value = self.add_v_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
img_query = self.norm_q(img_query)
img_key = self.norm_k(img_key)
txt_query = self.norm_added_q(txt_query)
txt_key = self.norm_added_k(txt_key)
joint_query = torch.cat([txt_query, img_query], dim=1)
joint_key = torch.cat([txt_key, img_key], dim=1)
joint_value = torch.cat([txt_value, img_value], dim=1)
joint_query = apply_rotary_emb(joint_query, image_rotary_emb)
joint_key = apply_rotary_emb(joint_key, image_rotary_emb)
joint_query = joint_query.flatten(start_dim=2)
joint_key = joint_key.flatten(start_dim=2)
joint_value = joint_value.flatten(start_dim=2)
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attention_mask)
txt_attn_output = joint_hidden_states[:, :seq_txt, :]
img_attn_output = joint_hidden_states[:, seq_txt:, :]
img_attn_output = self.to_out[0](img_attn_output)
img_attn_output = self.to_out[1](img_attn_output)
txt_attn_output = self.to_add_out(txt_attn_output)
return img_attn_output, txt_attn_output
class QwenImageTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
eps: float = 1e-6,
dtype=None,
device=None,
operations=None
):
super().__init__()
self.dim = dim
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
self.img_mod = nn.Sequential(
nn.SiLU(),
operations.Linear(dim, 6 * dim, bias=True, dtype=dtype, device=device),
)
self.img_norm1 = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, dtype=dtype, device=device)
self.img_norm2 = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, dtype=dtype, device=device)
self.img_mlp = FeedForward(dim=dim, dim_out=dim, dtype=dtype, device=device, operations=operations)
self.txt_mod = nn.Sequential(
nn.SiLU(),
operations.Linear(dim, 6 * dim, bias=True, dtype=dtype, device=device),
)
self.txt_norm1 = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, dtype=dtype, device=device)
self.txt_norm2 = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, dtype=dtype, device=device)
self.txt_mlp = FeedForward(dim=dim, dim_out=dim, dtype=dtype, device=device, operations=operations)
self.attn = Attention(
query_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
bias=True,
eps=eps,
dtype=dtype,
device=device,
operations=operations,
)
def _modulate(self, x, mod_params):
shift, scale, gate = mod_params.chunk(3, dim=-1)
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
encoder_hidden_states_mask: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
img_mod_params = self.img_mod(temb)
txt_mod_params = self.txt_mod(temb)
img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1)
txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1)
img_normed = self.img_norm1(hidden_states)
img_modulated, img_gate1 = self._modulate(img_normed, img_mod1)
txt_normed = self.txt_norm1(encoder_hidden_states)
txt_modulated, txt_gate1 = self._modulate(txt_normed, txt_mod1)
img_attn_output, txt_attn_output = self.attn(
hidden_states=img_modulated,
encoder_hidden_states=txt_modulated,
encoder_hidden_states_mask=encoder_hidden_states_mask,
image_rotary_emb=image_rotary_emb,
)
hidden_states = hidden_states + img_gate1 * img_attn_output
encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
img_normed2 = self.img_norm2(hidden_states)
img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
hidden_states = hidden_states + img_gate2 * self.img_mlp(img_modulated2)
txt_normed2 = self.txt_norm2(encoder_hidden_states)
txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2)
encoder_hidden_states = encoder_hidden_states + txt_gate2 * self.txt_mlp(txt_modulated2)
return encoder_hidden_states, hidden_states
class LastLayer(nn.Module):
def __init__(
self,
embedding_dim: int,
conditioning_embedding_dim: int,
elementwise_affine=False,
eps=1e-6,
bias=True,
dtype=None, device=None, operations=None
):
super().__init__()
self.silu = nn.SiLU()
self.linear = operations.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias, dtype=dtype, device=device)
self.norm = operations.LayerNorm(embedding_dim, eps, elementwise_affine=False, bias=bias, dtype=dtype, device=device)
def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
emb = self.linear(self.silu(conditioning_embedding))
scale, shift = torch.chunk(emb, 2, dim=1)
x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
return x
class QwenImageTransformer2DModel(nn.Module):
def __init__(
self,
patch_size: int = 2,
in_channels: int = 64,
out_channels: Optional[int] = 16,
num_layers: int = 60,
attention_head_dim: int = 128,
num_attention_heads: int = 24,
joint_attention_dim: int = 3584,
pooled_projection_dim: int = 768,
guidance_embeds: bool = False,
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
image_model=None,
dtype=None,
device=None,
operations=None,
):
super().__init__()
self.dtype = dtype
self.patch_size = patch_size
self.out_channels = out_channels or in_channels
self.inner_dim = num_attention_heads * attention_head_dim
self.pe_embedder = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope))
self.time_text_embed = QwenTimestepProjEmbeddings(
embedding_dim=self.inner_dim,
pooled_projection_dim=pooled_projection_dim,
dtype=dtype,
device=device,
operations=operations
)
self.txt_norm = operations.RMSNorm(joint_attention_dim, eps=1e-6, dtype=dtype, device=device)
self.img_in = operations.Linear(in_channels, self.inner_dim, dtype=dtype, device=device)
self.txt_in = operations.Linear(joint_attention_dim, self.inner_dim, dtype=dtype, device=device)
self.transformer_blocks = nn.ModuleList([
QwenImageTransformerBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
dtype=dtype,
device=device,
operations=operations
)
for _ in range(num_layers)
])
self.norm_out = LastLayer(self.inner_dim, self.inner_dim, dtype=dtype, device=device, operations=operations)
self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device)
self.gradient_checkpointing = False
def pos_embeds(self, x, context):
bs, c, t, h, w = x.shape
patch_size = self.patch_size
h_len = ((h + (patch_size // 2)) // patch_size)
w_len = ((w + (patch_size // 2)) // patch_size)
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
txt_start = round(max(h_len, w_len))
txt_ids = torch.linspace(txt_start, txt_start + context.shape[1], steps=context.shape[1], device=x.device, dtype=x.dtype).reshape(1, -1, 1).repeat(bs, 1, 3)
ids = torch.cat((txt_ids, img_ids), dim=1)
return self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
def forward(
self,
x,
timesteps,
context,
attention_mask=None,
guidance: torch.Tensor = None,
**kwargs
):
timestep = timesteps
encoder_hidden_states = context
encoder_hidden_states_mask = attention_mask
image_rotary_emb = self.pos_embeds(x, context)
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (1, self.patch_size, self.patch_size))
orig_shape = hidden_states.shape
hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5)
hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
hidden_states = self.img_in(hidden_states)
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
encoder_hidden_states = self.txt_in(encoder_hidden_states)
if guidance is not None:
guidance = guidance * 1000
temb = (
self.time_text_embed(timestep, hidden_states)
if guidance is None
else self.time_text_embed(timestep, guidance, hidden_states)
)
for block in self.transformer_blocks:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
encoder_hidden_states_mask=encoder_hidden_states_mask,
temb=temb,
image_rotary_emb=image_rotary_emb,
)
hidden_states = self.norm_out(hidden_states, temb)
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
hidden_states = hidden_states.permute(0, 3, 1, 4, 2, 5)
return hidden_states.reshape(orig_shape)[:, :, :, :x.shape[-2], :x.shape[-1]]

View File

@@ -146,15 +146,6 @@ WAN_CROSSATTENTION_CLASSES = {
}
def repeat_e(e, x):
repeats = 1
if e.shape[1] > 1:
repeats = x.shape[1] // e.shape[1]
if repeats == 1:
return e
return torch.repeat_interleave(e, repeats, dim=1)
class WanAttentionBlock(nn.Module):
def __init__(self,
@@ -211,23 +202,20 @@ class WanAttentionBlock(nn.Module):
"""
# assert e.dtype == torch.float32
if e.ndim < 4:
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
else:
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e).unbind(2)
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
# assert e[0].dtype == torch.float32
# self-attention
y = self.self_attn(
self.norm1(x) * (1 + repeat_e(e[1], x)) + repeat_e(e[0], x),
self.norm1(x) * (1 + e[1]) + e[0],
freqs)
x = x + y * repeat_e(e[2], x)
x = x + y * e[2]
# cross-attention & ffn
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len)
y = self.ffn(self.norm2(x) * (1 + repeat_e(e[4], x)) + repeat_e(e[3], x))
x = x + y * repeat_e(e[5], x)
y = self.ffn(self.norm2(x) * (1 + e[4]) + e[3])
x = x + y * e[5]
return x
@@ -337,12 +325,8 @@ class Head(nn.Module):
e(Tensor): Shape [B, C]
"""
# assert e.dtype == torch.float32
if e.ndim < 3:
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e.unsqueeze(1)).chunk(2, dim=1)
else:
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e.unsqueeze(2)).unbind(2)
x = (self.head(self.norm(x) * (1 + repeat_e(e[1], x)) + repeat_e(e[0], x)))
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e.unsqueeze(1)).chunk(2, dim=1)
x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
return x
@@ -522,9 +506,8 @@ class WanModel(torch.nn.Module):
# time embeddings
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype))
e = e.reshape(t.shape[0], -1, e.shape[-1])
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x[0].dtype))
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
# context
context = self.text_embedding(context)
@@ -556,20 +539,13 @@ class WanModel(torch.nn.Module):
x = self.unpatchify(x, grid_sizes)
return x
def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs):
def forward(self, x, timestep, context, clip_fea=None, transformer_options={}, **kwargs):
bs, c, t, h, w = x.shape
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
patch_size = self.patch_size
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
if time_dim_concat is not None:
time_dim_concat = comfy.ldm.common_dit.pad_to_patch_size(time_dim_concat, self.patch_size)
x = torch.cat([x, time_dim_concat], dim=2)
t_len = ((x.shape[2] + (patch_size[0] // 2)) // patch_size[0])
img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
@@ -659,7 +635,7 @@ class VaceWanModel(WanModel):
t,
context,
vace_context,
vace_strength,
vace_strength=1.0,
clip_fea=None,
freqs=None,
transformer_options={},
@@ -685,11 +661,8 @@ class VaceWanModel(WanModel):
context = torch.concat([context_clip, context], dim=1)
context_img_len = clip_fea.shape[-2]
orig_shape = list(vace_context.shape)
vace_context = vace_context.movedim(0, 1).reshape([-1] + orig_shape[2:])
c = self.vace_patch_embedding(vace_context.float()).to(vace_context.dtype)
c = c.flatten(2).transpose(1, 2)
c = list(c.split(orig_shape[0], dim=0))
# arguments
x_orig = x
@@ -709,9 +682,8 @@ class VaceWanModel(WanModel):
ii = self.vace_layers_mapping.get(i, None)
if ii is not None:
for iii in range(len(c)):
c_skip, c[iii] = self.vace_blocks[ii](c[iii], x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
x += c_skip * vace_strength[iii]
c_skip, c = self.vace_blocks[ii](c, x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
x += c_skip * vace_strength
del c_skip
# head
x = self.head(x, e)
@@ -769,7 +741,8 @@ class CameraWanModel(WanModel):
# embeddings
x = self.patch_embedding(x.float()).to(x.dtype)
if self.control_adapter is not None and camera_conditions is not None:
x = x + self.control_adapter(camera_conditions).to(x.dtype)
x_camera = self.control_adapter(camera_conditions).to(x.dtype)
x = x + x_camera
grid_sizes = x.shape[2:]
x = x.flatten(2).transpose(1, 2)

View File

@@ -24,17 +24,12 @@ class CausalConv3d(ops.Conv3d):
self.padding[1], 2 * self.padding[0], 0)
self.padding = (0, 0, 0)
def forward(self, x, cache_x=None, cache_list=None, cache_idx=None):
if cache_list is not None:
cache_x = cache_list[cache_idx]
cache_list[cache_idx] = None
def forward(self, x, cache_x=None):
padding = list(self._padding)
if cache_x is not None and self._padding[4] > 0:
cache_x = cache_x.to(x.device)
x = torch.cat([cache_x, x], dim=2)
padding[4] -= cache_x.shape[2]
del cache_x
x = F.pad(x, padding)
return super().forward(x)
@@ -57,6 +52,15 @@ class RMS_norm(nn.Module):
x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma.to(x) + (self.bias.to(x) if self.bias is not None else 0)
class Upsample(nn.Upsample):
def forward(self, x):
"""
Fix bfloat16 support for nearest neighbor interpolation.
"""
return super().forward(x.float()).type_as(x)
class Resample(nn.Module):
def __init__(self, dim, mode):
@@ -69,11 +73,11 @@ class Resample(nn.Module):
# layers
if mode == 'upsample2d':
self.resample = nn.Sequential(
nn.Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
ops.Conv2d(dim, dim // 2, 3, padding=1))
elif mode == 'upsample3d':
self.resample = nn.Sequential(
nn.Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
ops.Conv2d(dim, dim // 2, 3, padding=1))
self.time_conv = CausalConv3d(
dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
@@ -153,6 +157,29 @@ class Resample(nn.Module):
feat_idx[0] += 1
return x
def init_weight(self, conv):
conv_weight = conv.weight
nn.init.zeros_(conv_weight)
c1, c2, t, h, w = conv_weight.size()
one_matrix = torch.eye(c1, c2)
init_matrix = one_matrix
nn.init.zeros_(conv_weight)
#conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5
conv_weight.data[:, :, 1, 0, 0] = init_matrix #* 0.5
conv.weight.data.copy_(conv_weight)
nn.init.zeros_(conv.bias.data)
def init_weight2(self, conv):
conv_weight = conv.weight.data
nn.init.zeros_(conv_weight)
c1, c2, t, h, w = conv_weight.size()
init_matrix = torch.eye(c1 // 2, c2)
#init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2)
conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
conv.weight.data.copy_(conv_weight)
nn.init.zeros_(conv.bias.data)
class ResidualBlock(nn.Module):
@@ -171,7 +198,7 @@ class ResidualBlock(nn.Module):
if in_dim != out_dim else nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0]):
old_x = x
h = self.shortcut(x)
for layer in self.residual:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
@@ -183,12 +210,12 @@ class ResidualBlock(nn.Module):
cache_x.device), cache_x
],
dim=2)
x = layer(x, cache_list=feat_cache, cache_idx=idx)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x + self.shortcut(old_x)
return x + h
class AttentionBlock(nn.Module):
@@ -467,6 +494,12 @@ class WanVAE(nn.Module):
self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
attn_scales, self.temperal_upsample, dropout)
def forward(self, x):
mu, log_var = self.encode(x)
z = self.reparameterize(mu, log_var)
x_recon = self.decode(z)
return x_recon, mu, log_var
def encode(self, x):
self.clear_cache()
## cache
@@ -512,6 +545,18 @@ class WanVAE(nn.Module):
self.clear_cache()
return out
def reparameterize(self, mu, log_var):
std = torch.exp(0.5 * log_var)
eps = torch.randn_like(std)
return eps * std + mu
def sample(self, imgs, deterministic=False):
mu, log_var = self.encode(imgs)
if deterministic:
return mu
std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
return mu + std * torch.randn_like(std)
def clear_cache(self):
self._conv_num = count_conv3d(self.decoder)
self._conv_idx = [0]

View File

@@ -1,726 +0,0 @@
# original version: https://github.com/Wan-Video/Wan2.2/blob/main/wan/modules/vae2_2.py
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from .vae import AttentionBlock, CausalConv3d, RMS_norm
import comfy.ops
ops = comfy.ops.disable_weight_init
CACHE_T = 2
class Resample(nn.Module):
def __init__(self, dim, mode):
assert mode in (
"none",
"upsample2d",
"upsample3d",
"downsample2d",
"downsample3d",
)
super().__init__()
self.dim = dim
self.mode = mode
# layers
if mode == "upsample2d":
self.resample = nn.Sequential(
nn.Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
ops.Conv2d(dim, dim, 3, padding=1),
)
elif mode == "upsample3d":
self.resample = nn.Sequential(
nn.Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
ops.Conv2d(dim, dim, 3, padding=1),
# ops.Conv2d(dim, dim//2, 3, padding=1)
)
self.time_conv = CausalConv3d(
dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
elif mode == "downsample2d":
self.resample = nn.Sequential(
nn.ZeroPad2d((0, 1, 0, 1)),
ops.Conv2d(dim, dim, 3, stride=(2, 2)))
elif mode == "downsample3d":
self.resample = nn.Sequential(
nn.ZeroPad2d((0, 1, 0, 1)),
ops.Conv2d(dim, dim, 3, stride=(2, 2)))
self.time_conv = CausalConv3d(
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
else:
self.resample = nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0]):
b, c, t, h, w = x.size()
if self.mode == "upsample3d":
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = "Rep"
feat_idx[0] += 1
else:
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if (cache_x.shape[2] < 2 and feat_cache[idx] is not None and
feat_cache[idx] != "Rep"):
# cache last frame of last two chunk
cache_x = torch.cat(
[
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device),
cache_x,
],
dim=2,
)
if (cache_x.shape[2] < 2 and feat_cache[idx] is not None and
feat_cache[idx] == "Rep"):
cache_x = torch.cat(
[
torch.zeros_like(cache_x).to(cache_x.device),
cache_x
],
dim=2,
)
if feat_cache[idx] == "Rep":
x = self.time_conv(x)
else:
x = self.time_conv(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
x = x.reshape(b, 2, c, t, h, w)
x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
3)
x = x.reshape(b, c, t * 2, h, w)
t = x.shape[2]
x = rearrange(x, "b c t h w -> (b t) c h w")
x = self.resample(x)
x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
if self.mode == "downsample3d":
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = x.clone()
feat_idx[0] += 1
else:
cache_x = x[:, :, -1:, :, :].clone()
x = self.time_conv(
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
feat_cache[idx] = cache_x
feat_idx[0] += 1
return x
class ResidualBlock(nn.Module):
def __init__(self, in_dim, out_dim, dropout=0.0):
super().__init__()
self.in_dim = in_dim
self.out_dim = out_dim
# layers
self.residual = nn.Sequential(
RMS_norm(in_dim, images=False),
nn.SiLU(),
CausalConv3d(in_dim, out_dim, 3, padding=1),
RMS_norm(out_dim, images=False),
nn.SiLU(),
nn.Dropout(dropout),
CausalConv3d(out_dim, out_dim, 3, padding=1),
)
self.shortcut = (
CausalConv3d(in_dim, out_dim, 1)
if in_dim != out_dim else nn.Identity())
def forward(self, x, feat_cache=None, feat_idx=[0]):
old_x = x
for layer in self.residual:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat(
[
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device),
cache_x,
],
dim=2,
)
x = layer(x, cache_list=feat_cache, cache_idx=idx)
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x + self.shortcut(old_x)
def patchify(x, patch_size):
if patch_size == 1:
return x
if x.dim() == 4:
x = rearrange(
x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size, r=patch_size)
elif x.dim() == 5:
x = rearrange(
x,
"b c f (h q) (w r) -> b (c r q) f h w",
q=patch_size,
r=patch_size,
)
else:
raise ValueError(f"Invalid input shape: {x.shape}")
return x
def unpatchify(x, patch_size):
if patch_size == 1:
return x
if x.dim() == 4:
x = rearrange(
x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size, r=patch_size)
elif x.dim() == 5:
x = rearrange(
x,
"b (c r q) f h w -> b c f (h q) (w r)",
q=patch_size,
r=patch_size,
)
return x
class AvgDown3D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
factor_t,
factor_s=1,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.factor_t = factor_t
self.factor_s = factor_s
self.factor = self.factor_t * self.factor_s * self.factor_s
assert in_channels * self.factor % out_channels == 0
self.group_size = in_channels * self.factor // out_channels
def forward(self, x: torch.Tensor) -> torch.Tensor:
pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t
pad = (0, 0, 0, 0, pad_t, 0)
x = F.pad(x, pad)
B, C, T, H, W = x.shape
x = x.view(
B,
C,
T // self.factor_t,
self.factor_t,
H // self.factor_s,
self.factor_s,
W // self.factor_s,
self.factor_s,
)
x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
x = x.view(
B,
C * self.factor,
T // self.factor_t,
H // self.factor_s,
W // self.factor_s,
)
x = x.view(
B,
self.out_channels,
self.group_size,
T // self.factor_t,
H // self.factor_s,
W // self.factor_s,
)
x = x.mean(dim=2)
return x
class DupUp3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
factor_t,
factor_s=1,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.factor_t = factor_t
self.factor_s = factor_s
self.factor = self.factor_t * self.factor_s * self.factor_s
assert out_channels * self.factor % in_channels == 0
self.repeats = out_channels * self.factor // in_channels
def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor:
x = x.repeat_interleave(self.repeats, dim=1)
x = x.view(
x.size(0),
self.out_channels,
self.factor_t,
self.factor_s,
self.factor_s,
x.size(2),
x.size(3),
x.size(4),
)
x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
x = x.view(
x.size(0),
self.out_channels,
x.size(2) * self.factor_t,
x.size(4) * self.factor_s,
x.size(6) * self.factor_s,
)
if first_chunk:
x = x[:, :, self.factor_t - 1:, :, :]
return x
class Down_ResidualBlock(nn.Module):
def __init__(self,
in_dim,
out_dim,
dropout,
mult,
temperal_downsample=False,
down_flag=False):
super().__init__()
# Shortcut path with downsample
self.avg_shortcut = AvgDown3D(
in_dim,
out_dim,
factor_t=2 if temperal_downsample else 1,
factor_s=2 if down_flag else 1,
)
# Main path with residual blocks and downsample
downsamples = []
for _ in range(mult):
downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
in_dim = out_dim
# Add the final downsample block
if down_flag:
mode = "downsample3d" if temperal_downsample else "downsample2d"
downsamples.append(Resample(out_dim, mode=mode))
self.downsamples = nn.Sequential(*downsamples)
def forward(self, x, feat_cache=None, feat_idx=[0]):
x_copy = x
for module in self.downsamples:
x = module(x, feat_cache, feat_idx)
return x + self.avg_shortcut(x_copy)
class Up_ResidualBlock(nn.Module):
def __init__(self,
in_dim,
out_dim,
dropout,
mult,
temperal_upsample=False,
up_flag=False):
super().__init__()
# Shortcut path with upsample
if up_flag:
self.avg_shortcut = DupUp3D(
in_dim,
out_dim,
factor_t=2 if temperal_upsample else 1,
factor_s=2 if up_flag else 1,
)
else:
self.avg_shortcut = None
# Main path with residual blocks and upsample
upsamples = []
for _ in range(mult):
upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
in_dim = out_dim
# Add the final upsample block
if up_flag:
mode = "upsample3d" if temperal_upsample else "upsample2d"
upsamples.append(Resample(out_dim, mode=mode))
self.upsamples = nn.Sequential(*upsamples)
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
x_main = x
for module in self.upsamples:
x_main = module(x_main, feat_cache, feat_idx)
if self.avg_shortcut is not None:
x_shortcut = self.avg_shortcut(x, first_chunk)
return x_main + x_shortcut
else:
return x_main
class Encoder3d(nn.Module):
def __init__(
self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[True, True, False],
dropout=0.0,
):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_downsample = temperal_downsample
# dimensions
dims = [dim * u for u in [1] + dim_mult]
scale = 1.0
# init block
self.conv1 = CausalConv3d(12, dims[0], 3, padding=1)
# downsample blocks
downsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
t_down_flag = (
temperal_downsample[i]
if i < len(temperal_downsample) else False)
downsamples.append(
Down_ResidualBlock(
in_dim=in_dim,
out_dim=out_dim,
dropout=dropout,
mult=num_res_blocks,
temperal_downsample=t_down_flag,
down_flag=i != len(dim_mult) - 1,
))
scale /= 2.0
self.downsamples = nn.Sequential(*downsamples)
# middle blocks
self.middle = nn.Sequential(
ResidualBlock(out_dim, out_dim, dropout),
AttentionBlock(out_dim),
ResidualBlock(out_dim, out_dim, dropout),
)
# # output blocks
self.head = nn.Sequential(
RMS_norm(out_dim, images=False),
nn.SiLU(),
CausalConv3d(out_dim, z_dim, 3, padding=1),
)
def forward(self, x, feat_cache=None, feat_idx=[0]):
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
cache_x = torch.cat(
[
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device),
cache_x,
],
dim=2,
)
x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv1(x)
## downsamples
for layer in self.downsamples:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## middle
for layer in self.middle:
if isinstance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## head
for layer in self.head:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
cache_x = torch.cat(
[
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device),
cache_x,
],
dim=2,
)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x
class Decoder3d(nn.Module):
def __init__(
self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_upsample=[False, True, True],
dropout=0.0,
):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_upsample = temperal_upsample
# dimensions
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
# init block
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
# middle blocks
self.middle = nn.Sequential(
ResidualBlock(dims[0], dims[0], dropout),
AttentionBlock(dims[0]),
ResidualBlock(dims[0], dims[0], dropout),
)
# upsample blocks
upsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
t_up_flag = temperal_upsample[i] if i < len(
temperal_upsample) else False
upsamples.append(
Up_ResidualBlock(
in_dim=in_dim,
out_dim=out_dim,
dropout=dropout,
mult=num_res_blocks + 1,
temperal_upsample=t_up_flag,
up_flag=i != len(dim_mult) - 1,
))
self.upsamples = nn.Sequential(*upsamples)
# output blocks
self.head = nn.Sequential(
RMS_norm(out_dim, images=False),
nn.SiLU(),
CausalConv3d(out_dim, 12, 3, padding=1),
)
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
cache_x = torch.cat(
[
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device),
cache_x,
],
dim=2,
)
x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv1(x)
for layer in self.middle:
if isinstance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## upsamples
for layer in self.upsamples:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx, first_chunk)
else:
x = layer(x)
## head
for layer in self.head:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
cache_x = torch.cat(
[
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device),
cache_x,
],
dim=2,
)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x
def count_conv3d(model):
count = 0
for m in model.modules():
if isinstance(m, CausalConv3d):
count += 1
return count
class WanVAE(nn.Module):
def __init__(
self,
dim=160,
dec_dim=256,
z_dim=16,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[True, True, False],
dropout=0.0,
):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_downsample = temperal_downsample
self.temperal_upsample = temperal_downsample[::-1]
# modules
self.encoder = Encoder3d(
dim,
z_dim * 2,
dim_mult,
num_res_blocks,
attn_scales,
self.temperal_downsample,
dropout,
)
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
self.decoder = Decoder3d(
dec_dim,
z_dim,
dim_mult,
num_res_blocks,
attn_scales,
self.temperal_upsample,
dropout,
)
def encode(self, x):
self.clear_cache()
x = patchify(x, patch_size=2)
t = x.shape[2]
iter_ = 1 + (t - 1) // 4
for i in range(iter_):
self._enc_conv_idx = [0]
if i == 0:
out = self.encoder(
x[:, :, :1, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx,
)
else:
out_ = self.encoder(
x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx,
)
out = torch.cat([out, out_], 2)
mu, log_var = self.conv1(out).chunk(2, dim=1)
self.clear_cache()
return mu
def decode(self, z):
self.clear_cache()
iter_ = z.shape[2]
x = self.conv2(z)
for i in range(iter_):
self._conv_idx = [0]
if i == 0:
out = self.decoder(
x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx,
first_chunk=True,
)
else:
out_ = self.decoder(
x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx,
)
out = torch.cat([out, out_], 2)
out = unpatchify(out, patch_size=2)
self.clear_cache()
return out
def reparameterize(self, mu, log_var):
std = torch.exp(0.5 * log_var)
eps = torch.randn_like(std)
return eps * std + mu
def sample(self, imgs, deterministic=False):
mu, log_var = self.encode(imgs)
if deterministic:
return mu
std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
return mu + std * torch.randn_like(std)
def clear_cache(self):
self._conv_num = count_conv3d(self.decoder)
self._conv_idx = [0]
self._feat_map = [None] * self._conv_num
# cache encode
self._enc_conv_num = count_conv3d(self.encoder)
self._enc_conv_idx = [0]
self._enc_feat_map = [None] * self._enc_conv_num

View File

@@ -283,9 +283,8 @@ def model_lora_keys_unet(model, key_map={}):
for k in sdk:
if k.startswith("diffusion_model."):
if k.endswith(".weight"):
key_lora = k[len("diffusion_model."):-len(".weight")]
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k #SimpleTuner lycoris format
key_map["transformer.{}".format(key_lora)] = k #SimpleTuner regular format
key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
key_map["lycoris_{}".format(key_lora)] = k #SimpleTuner lycoris format
if isinstance(model, comfy.model_base.ACEStep):
for k in sdk:
@@ -293,16 +292,6 @@ def model_lora_keys_unet(model, key_map={}):
key_lora = k[len("diffusion_model."):-len(".weight")]
key_map["{}".format(key_lora)] = k
if isinstance(model, comfy.model_base.QwenImage):
for k in sdk:
if k.startswith("diffusion_model.") and k.endswith(".weight"): #QwenImage lora format
key_lora = k[len("diffusion_model."):-len(".weight")]
# Direct mapping for transformer_blocks format (QwenImage LoRA format)
key_map["{}".format(key_lora)] = k
# Support transformer prefix format
key_map["transformer.{}".format(key_lora)] = k
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k #SimpleTuner lycoris format
return key_map

View File

@@ -34,15 +34,12 @@ import comfy.ldm.flux.model
import comfy.ldm.lightricks.model
import comfy.ldm.hunyuan_video.model
import comfy.ldm.cosmos.model
import comfy.ldm.cosmos.predict2
import comfy.ldm.lumina.model
import comfy.ldm.wan.model
import comfy.ldm.hunyuan3d.model
import comfy.ldm.hidream.model
import comfy.ldm.chroma.model
import comfy.ldm.ace.model
import comfy.ldm.omnigen.omnigen2
import comfy.ldm.qwen_image.model
import comfy.model_management
import comfy.patcher_extension
@@ -51,7 +48,6 @@ import comfy.ops
from enum import Enum
from . import utils
import comfy.latent_formats
import comfy.model_sampling
import math
from typing import TYPE_CHECKING
if TYPE_CHECKING:
@@ -67,39 +63,38 @@ class ModelType(Enum):
V_PREDICTION_CONTINUOUS = 7
FLUX = 8
IMG_TO_IMG = 9
FLOW_COSMOS = 10
from comfy.model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling, ModelSamplingContinuousV
def model_sampling(model_config, model_type):
s = comfy.model_sampling.ModelSamplingDiscrete
s = ModelSamplingDiscrete
if model_type == ModelType.EPS:
c = comfy.model_sampling.EPS
c = EPS
elif model_type == ModelType.V_PREDICTION:
c = comfy.model_sampling.V_PREDICTION
c = V_PREDICTION
elif model_type == ModelType.V_PREDICTION_EDM:
c = comfy.model_sampling.V_PREDICTION
s = comfy.model_sampling.ModelSamplingContinuousEDM
c = V_PREDICTION
s = ModelSamplingContinuousEDM
elif model_type == ModelType.FLOW:
c = comfy.model_sampling.CONST
s = comfy.model_sampling.ModelSamplingDiscreteFlow
elif model_type == ModelType.STABLE_CASCADE:
c = comfy.model_sampling.EPS
s = comfy.model_sampling.StableCascadeSampling
c = EPS
s = StableCascadeSampling
elif model_type == ModelType.EDM:
c = comfy.model_sampling.EDM
s = comfy.model_sampling.ModelSamplingContinuousEDM
c = EDM
s = ModelSamplingContinuousEDM
elif model_type == ModelType.V_PREDICTION_CONTINUOUS:
c = comfy.model_sampling.V_PREDICTION
s = comfy.model_sampling.ModelSamplingContinuousV
c = V_PREDICTION
s = ModelSamplingContinuousV
elif model_type == ModelType.FLUX:
c = comfy.model_sampling.CONST
s = comfy.model_sampling.ModelSamplingFlux
elif model_type == ModelType.IMG_TO_IMG:
c = comfy.model_sampling.IMG_TO_IMG
elif model_type == ModelType.FLOW_COSMOS:
c = comfy.model_sampling.COSMOS_RFLOW
s = comfy.model_sampling.ModelSamplingCosmosRFlow
class ModelSampling(s, c):
pass
@@ -107,15 +102,6 @@ def model_sampling(model_config, model_type):
return ModelSampling(model_config)
def convert_tensor(extra, dtype, device):
if hasattr(extra, "dtype"):
if extra.dtype != torch.int and extra.dtype != torch.long:
extra = comfy.model_management.cast_to_device(extra, device, dtype)
else:
extra = comfy.model_management.cast_to_device(extra, device, None)
return extra
class BaseModel(torch.nn.Module):
def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_model=UNetModel):
super().__init__()
@@ -149,7 +135,6 @@ class BaseModel(torch.nn.Module):
logging.info("model_type {}".format(model_type.name))
logging.debug("adm {}".format(self.adm_channels))
self.memory_usage_factor = model_config.memory_usage_factor
self.memory_usage_factor_conds = ()
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
@@ -163,7 +148,7 @@ class BaseModel(torch.nn.Module):
xc = self.model_sampling.calculate_input(sigma, x)
if c_concat is not None:
xc = torch.cat([xc] + [comfy.model_management.cast_to_device(c_concat, xc.device, xc.dtype)], dim=1)
xc = torch.cat([xc] + [c_concat], dim=1)
context = c_crossattn
dtype = self.get_dtype()
@@ -172,22 +157,16 @@ class BaseModel(torch.nn.Module):
dtype = self.manual_cast_dtype
xc = xc.to(dtype)
device = xc.device
t = self.model_sampling.timestep(t).float()
if context is not None:
context = comfy.model_management.cast_to_device(context, device, dtype)
context = context.to(dtype)
extra_conds = {}
for o in kwargs:
extra = kwargs[o]
if hasattr(extra, "dtype"):
extra = convert_tensor(extra, dtype, device)
elif isinstance(extra, list):
ex = []
for ext in extra:
ex.append(convert_tensor(ext, dtype, device))
extra = ex
if extra.dtype != torch.int and extra.dtype != torch.long:
extra = extra.to(dtype)
extra_conds[o] = extra
t = self.process_timestep(t, x=x, **extra_conds)
@@ -346,28 +325,19 @@ class BaseModel(torch.nn.Module):
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
return self.model_sampling.noise_scaling(sigma.reshape([sigma.shape[0]] + [1] * (len(noise.shape) - 1)), noise, latent_image)
def memory_required(self, input_shape, cond_shapes={}):
input_shapes = [input_shape]
for c in self.memory_usage_factor_conds:
shape = cond_shapes.get(c, None)
if shape is not None and len(shape) > 0:
input_shapes += shape
def memory_required(self, input_shape):
if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
dtype = self.get_dtype()
if self.manual_cast_dtype is not None:
dtype = self.manual_cast_dtype
#TODO: this needs to be tweaked
area = sum(map(lambda input_shape: input_shape[0] * math.prod(input_shape[2:]), input_shapes))
area = input_shape[0] * math.prod(input_shape[2:])
return (area * comfy.model_management.dtype_size(dtype) * 0.01 * self.memory_usage_factor) * (1024 * 1024)
else:
#TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
area = sum(map(lambda input_shape: input_shape[0] * math.prod(input_shape[2:]), input_shapes))
area = input_shape[0] * math.prod(input_shape[2:])
return (area * 0.15 * self.memory_usage_factor) * (1024 * 1024)
def extra_conds_shapes(self, **kwargs):
return {}
def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0, seed=None):
adm_inputs = []
@@ -402,7 +372,7 @@ class SD21UNCLIP(BaseModel):
unclip_conditioning = kwargs.get("unclip_conditioning", None)
device = kwargs["device"]
if unclip_conditioning is None:
return torch.zeros((1, self.adm_channels), device=device)
return torch.zeros((1, self.adm_channels))
else:
return unclip_adm(unclip_conditioning, device, self.noise_augmentor, kwargs.get("unclip_noise_augment_merge", 0.05), kwargs.get("seed", 0) - 10)
@@ -616,11 +586,9 @@ class IP2P:
if image is None:
image = torch.zeros_like(noise)
else:
image = image.to(device=device)
if image.shape[1:] != noise.shape[1:]:
image = utils.common_upscale(image, noise.shape[-1], noise.shape[-2], "bilinear", "center")
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
image = utils.resize_to_batch_size(image, noise.shape[0])
return self.process_ip2p_image_in(image)
@@ -699,7 +667,7 @@ class StableCascade_B(BaseModel):
#size of prior doesn't really matter if zeros because it gets resized but I still want it to get batched
prior = kwargs.get("stable_cascade_prior", torch.zeros((1, 16, (noise.shape[2] * 4) // 42, (noise.shape[3] * 4) // 42), dtype=noise.dtype, layout=noise.layout, device=noise.device))
out["effnet"] = comfy.conds.CONDRegular(prior.to(device=noise.device))
out["effnet"] = comfy.conds.CONDRegular(prior)
out["sca"] = comfy.conds.CONDRegular(torch.zeros((1,)))
return out
@@ -822,7 +790,6 @@ class PixArt(BaseModel):
class Flux(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLUX, device=None, unet_model=comfy.ldm.flux.model.Flux):
super().__init__(model_config, model_type, device=device, unet_model=unet_model)
self.memory_usage_factor_conds = ("ref_latents",)
def concat_cond(self, **kwargs):
try:
@@ -883,23 +850,8 @@ class Flux(BaseModel):
guidance = kwargs.get("guidance", 3.5)
if guidance is not None:
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
ref_latents = kwargs.get("reference_latents", None)
if ref_latents is not None:
latents = []
for lat in ref_latents:
latents.append(self.process_latent_in(lat))
out['ref_latents'] = comfy.conds.CONDList(latents)
return out
def extra_conds_shapes(self, **kwargs):
out = {}
ref_latents = kwargs.get("reference_latents", None)
if ref_latents is not None:
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
return out
class GenmoMochi(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.genmo.joint_model.asymm_models_joint.AsymmDiTJoint)
@@ -1024,45 +976,6 @@ class CosmosVideo(BaseModel):
latent_image = self.model_sampling.calculate_input(torch.tensor([sigma_noise_augmentation], device=latent_image.device, dtype=latent_image.dtype), latent_image)
return latent_image * ((sigma ** 2 + self.model_sampling.sigma_data ** 2) ** 0.5)
class CosmosPredict2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW_COSMOS, image_to_video=False, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.cosmos.predict2.MiniTrainDIT)
self.image_to_video = image_to_video
if self.image_to_video:
self.concat_keys = ("mask_inverted",)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
if denoise_mask is not None:
out["denoise_mask"] = comfy.conds.CONDRegular(denoise_mask)
out['fps'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", None))
return out
def process_timestep(self, timestep, x, denoise_mask=None, **kwargs):
if denoise_mask is None:
return timestep
if denoise_mask.ndim <= 4:
return timestep
condition_video_mask_B_1_T_1_1 = denoise_mask.mean(dim=[1, 3, 4], keepdim=True)
c_noise_B_1_T_1_1 = 0.0 * (1.0 - condition_video_mask_B_1_T_1_1) + timestep.reshape(timestep.shape[0], 1, 1, 1, 1) * condition_video_mask_B_1_T_1_1
out = c_noise_B_1_T_1_1.squeeze(dim=[1, 3, 4])
return out
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
sigma = sigma.reshape([sigma.shape[0]] + [1] * (len(noise.shape) - 1))
sigma_noise_augmentation = 0 #TODO
if sigma_noise_augmentation != 0:
latent_image = latent_image + noise
latent_image = self.model_sampling.calculate_input(torch.tensor([sigma_noise_augmentation], device=latent_image.device, dtype=latent_image.dtype), latent_image)
sigma = (sigma / (sigma + 1))
return latent_image / (1.0 - sigma)
class Lumina2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiT)
@@ -1103,9 +1016,8 @@ class WAN21(BaseModel):
image[:, i: i + 16] = self.process_latent_in(image[:, i: i + 16])
image = utils.resize_to_batch_size(image, noise.shape[0])
if extra_channels != image.shape[1] + 4:
if not self.image_to_video or extra_channels == image.shape[1]:
return image
if not self.image_to_video or extra_channels == image.shape[1]:
return image
if image.shape[1] > (extra_channels - 4):
image = image[:, :(extra_channels - 4)]
@@ -1135,11 +1047,6 @@ class WAN21(BaseModel):
clip_vision_output = kwargs.get("clip_vision_output", None)
if clip_vision_output is not None:
out['clip_fea'] = comfy.conds.CONDRegular(clip_vision_output.penultimate_hidden_states)
time_dim_concat = kwargs.get("time_dim_concat", None)
if time_dim_concat is not None:
out['time_dim_concat'] = comfy.conds.CONDRegular(self.process_latent_in(time_dim_concat))
return out
@@ -1155,25 +1062,20 @@ class WAN21_Vace(WAN21):
vace_frames = kwargs.get("vace_frames", None)
if vace_frames is None:
noise_shape[1] = 32
vace_frames = [torch.zeros(noise_shape, device=noise.device, dtype=noise.dtype)]
vace_frames = torch.zeros(noise_shape, device=noise.device, dtype=noise.dtype)
for i in range(0, vace_frames.shape[1], 16):
vace_frames = vace_frames.clone()
vace_frames[:, i:i + 16] = self.process_latent_in(vace_frames[:, i:i + 16])
mask = kwargs.get("vace_mask", None)
if mask is None:
noise_shape[1] = 64
mask = [torch.ones(noise_shape, device=noise.device, dtype=noise.dtype)] * len(vace_frames)
mask = torch.ones(noise_shape, device=noise.device, dtype=noise.dtype)
vace_frames_out = []
for j in range(len(vace_frames)):
vf = vace_frames[j].to(device=noise.device, dtype=noise.dtype, copy=True)
for i in range(0, vf.shape[1], 16):
vf[:, i:i + 16] = self.process_latent_in(vf[:, i:i + 16])
vf = torch.cat([vf, mask[j].to(device=noise.device, dtype=noise.dtype)], dim=1)
vace_frames_out.append(vf)
out['vace_context'] = comfy.conds.CONDRegular(torch.cat([vace_frames.to(noise), mask.to(noise)], dim=1))
vace_frames = torch.stack(vace_frames_out, dim=1)
out['vace_context'] = comfy.conds.CONDRegular(vace_frames)
vace_strength = kwargs.get("vace_strength", [1.0] * len(vace_frames_out))
vace_strength = kwargs.get("vace_strength", 1.0)
out['vace_strength'] = comfy.conds.CONDConstant(vace_strength)
return out
@@ -1189,31 +1091,6 @@ class WAN21_Camera(WAN21):
out['camera_conditions'] = comfy.conds.CONDRegular(camera_conditions)
return out
class WAN22(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
self.image_to_video = image_to_video
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
if denoise_mask is not None:
out["denoise_mask"] = comfy.conds.CONDRegular(denoise_mask)
return out
def process_timestep(self, timestep, x, denoise_mask=None, **kwargs):
if denoise_mask is None:
return timestep
temp_ts = (torch.mean(denoise_mask[:, :, :, :, :], dim=(1, 3, 4), keepdim=True) * timestep.view([timestep.shape[0]] + [1] * (denoise_mask.ndim - 1))).reshape(timestep.shape[0], -1)
return temp_ts
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
return latent_image
class Hunyuan3Dv2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2)
@@ -1279,44 +1156,3 @@ class ACEStep(BaseModel):
out['speaker_embeds'] = comfy.conds.CONDRegular(torch.zeros(noise.shape[0], 512, device=noise.device, dtype=noise.dtype))
out['lyrics_strength'] = comfy.conds.CONDConstant(kwargs.get("lyrics_strength", 1.0))
return out
class Omnigen2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.omnigen.omnigen2.OmniGen2Transformer2DModel)
self.memory_usage_factor_conds = ("ref_latents",)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None:
if torch.numel(attention_mask) != attention_mask.sum():
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
out['num_tokens'] = comfy.conds.CONDConstant(max(1, torch.sum(attention_mask).item()))
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
ref_latents = kwargs.get("reference_latents", None)
if ref_latents is not None:
latents = []
for lat in ref_latents:
latents.append(self.process_latent_in(lat))
out['ref_latents'] = comfy.conds.CONDList(latents)
return out
def extra_conds_shapes(self, **kwargs):
out = {}
ref_latents = kwargs.get("reference_latents", None)
if ref_latents is not None:
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
return out
class QwenImage(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.qwen_image.model.QwenImageTransformer2DModel)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
return out

View File

@@ -346,9 +346,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config = {}
dit_config["image_model"] = "wan2.1"
dim = state_dict['{}head.modulation'.format(key_prefix)].shape[-1]
out_dim = state_dict['{}head.head.weight'.format(key_prefix)].shape[0] // 4
dit_config["dim"] = dim
dit_config["out_dim"] = out_dim
dit_config["num_heads"] = dim // 128
dit_config["ffn_dim"] = state_dict['{}blocks.0.ffn.0.weight'.format(key_prefix)].shape[0]
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.')
@@ -409,83 +407,6 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["text_emb_dim"] = 2048
return dit_config
if '{}blocks.0.mlp.layer1.weight'.format(key_prefix) in state_dict_keys: # Cosmos predict2
dit_config = {}
dit_config["image_model"] = "cosmos_predict2"
dit_config["max_img_h"] = 240
dit_config["max_img_w"] = 240
dit_config["max_frames"] = 128
concat_padding_mask = True
dit_config["in_channels"] = (state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[1] // 4) - int(concat_padding_mask)
dit_config["out_channels"] = 16
dit_config["patch_spatial"] = 2
dit_config["patch_temporal"] = 1
dit_config["model_channels"] = state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[0]
dit_config["concat_padding_mask"] = concat_padding_mask
dit_config["crossattn_emb_channels"] = 1024
dit_config["pos_emb_cls"] = "rope3d"
dit_config["pos_emb_learnable"] = True
dit_config["pos_emb_interpolation"] = "crop"
dit_config["min_fps"] = 1
dit_config["max_fps"] = 30
dit_config["use_adaln_lora"] = True
dit_config["adaln_lora_dim"] = 256
if dit_config["model_channels"] == 2048:
dit_config["num_blocks"] = 28
dit_config["num_heads"] = 16
elif dit_config["model_channels"] == 5120:
dit_config["num_blocks"] = 36
dit_config["num_heads"] = 40
if dit_config["in_channels"] == 16:
dit_config["extra_per_block_abs_pos_emb"] = False
dit_config["rope_h_extrapolation_ratio"] = 4.0
dit_config["rope_w_extrapolation_ratio"] = 4.0
dit_config["rope_t_extrapolation_ratio"] = 1.0
elif dit_config["in_channels"] == 17: # img to video
if dit_config["model_channels"] == 2048:
dit_config["extra_per_block_abs_pos_emb"] = False
dit_config["rope_h_extrapolation_ratio"] = 3.0
dit_config["rope_w_extrapolation_ratio"] = 3.0
dit_config["rope_t_extrapolation_ratio"] = 1.0
elif dit_config["model_channels"] == 5120:
dit_config["rope_h_extrapolation_ratio"] = 2.0
dit_config["rope_w_extrapolation_ratio"] = 2.0
dit_config["rope_t_extrapolation_ratio"] = 0.8333333333333334
dit_config["extra_h_extrapolation_ratio"] = 1.0
dit_config["extra_w_extrapolation_ratio"] = 1.0
dit_config["extra_t_extrapolation_ratio"] = 1.0
dit_config["rope_enable_fps_modulation"] = False
return dit_config
if '{}time_caption_embed.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: # Omnigen2
dit_config = {}
dit_config["image_model"] = "omnigen2"
dit_config["axes_dim_rope"] = [40, 40, 40]
dit_config["axes_lens"] = [1024, 1664, 1664]
dit_config["ffn_dim_multiplier"] = None
dit_config["hidden_size"] = 2520
dit_config["in_channels"] = 16
dit_config["multiple_of"] = 256
dit_config["norm_eps"] = 1e-05
dit_config["num_attention_heads"] = 21
dit_config["num_kv_heads"] = 7
dit_config["num_layers"] = 32
dit_config["num_refiner_layers"] = 2
dit_config["out_channels"] = None
dit_config["patch_size"] = 2
dit_config["text_feat_dim"] = 2048
dit_config["timestep_scale"] = 1000.0
return dit_config
if '{}txt_norm.weight'.format(key_prefix) in state_dict_keys: # Qwen Image
dit_config = {}
dit_config["image_model"] = "qwen_image"
return dit_config
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
return None
@@ -699,9 +620,6 @@ def convert_config(unet_config):
def unet_config_from_diffusers_unet(state_dict, dtype=None):
if "conv_in.weight" not in state_dict:
return None
match = {}
transformer_depth = []
@@ -872,7 +790,7 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""):
depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.')
hidden_size = state_dict["x_embedder.bias"].shape[0]
sd_map = comfy.utils.flux_to_diffusers({"depth": depth, "depth_single_blocks": depth_single_blocks, "hidden_size": hidden_size}, output_prefix=output_prefix)
elif 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict and 'pos_embed.proj.weight' in state_dict: #SD3
elif 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict: #SD3
num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
depth = state_dict["pos_embed.proj.weight"].shape[0] // 64
sd_map = comfy.utils.mmdit_to_diffusers({"depth": depth, "num_blocks": num_blocks}, output_prefix=output_prefix)

View File

@@ -101,7 +101,7 @@ if args.directml is not None:
lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default.
try:
import intel_extension_for_pytorch as ipex # noqa: F401
import intel_extension_for_pytorch as ipex
_ = torch.xpu.device_count()
xpu_available = xpu_available or torch.xpu.is_available()
except:
@@ -128,11 +128,6 @@ try:
except:
mlu_available = False
try:
ixuca_available = hasattr(torch, "corex")
except:
ixuca_available = False
if args.cpu:
cpu_state = CPUState.CPU
@@ -156,12 +151,6 @@ def is_mlu():
return True
return False
def is_ixuca():
global ixuca_available
if ixuca_available:
return True
return False
def get_torch_device():
global directml_enabled
global cpu_state
@@ -197,9 +186,8 @@ def get_total_memory(dev=None, torch_total_too=False):
elif is_intel_xpu():
stats = torch.xpu.memory_stats(dev)
mem_reserved = stats['reserved_bytes.all.current']
mem_total_xpu = torch.xpu.get_device_properties(dev).total_memory
mem_total_torch = mem_reserved
mem_total = mem_total_xpu
mem_total = torch.xpu.get_device_properties(dev).total_memory
elif is_ascend_npu():
stats = torch.npu.memory_stats(dev)
mem_reserved = stats['reserved_bytes.all.current']
@@ -300,34 +288,21 @@ try:
if torch_version_numeric[0] >= 2:
if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
ENABLE_PYTORCH_ATTENTION = True
if is_intel_xpu() or is_ascend_npu() or is_mlu() or is_ixuca():
if is_intel_xpu() or is_ascend_npu() or is_mlu():
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
ENABLE_PYTORCH_ATTENTION = True
except:
pass
SUPPORT_FP8_OPS = args.supports_fp8_compute
try:
if is_amd():
try:
rocm_version = tuple(map(int, str(torch.version.hip).split(".")[:2]))
except:
rocm_version = (6, -1)
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
logging.info("AMD arch: {}".format(arch))
logging.info("ROCm version: {}".format(rocm_version))
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
if torch_version_numeric[0] >= 2 and torch_version_numeric[1] >= 7: # works on 2.6 but doesn't actually seem to improve much
if any((a in arch) for a in ["gfx1100", "gfx1101"]): # TODO: more arches
ENABLE_PYTORCH_ATTENTION = True
# if torch_version_numeric >= (2, 8):
# if any((a in arch) for a in ["gfx1201"]):
# ENABLE_PYTORCH_ATTENTION = True
if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4):
if any((a in arch) for a in ["gfx1201", "gfx942", "gfx950"]): # TODO: more arches
SUPPORT_FP8_OPS = True
except:
pass
@@ -340,7 +315,7 @@ if ENABLE_PYTORCH_ATTENTION:
PRIORITIZE_FP16 = False # TODO: remove and replace with something that shows exactly which dtype is faster than the other
try:
if (is_nvidia() or is_amd()) and PerformanceFeature.Fp16Accumulation in args.fast:
if is_nvidia() and PerformanceFeature.Fp16Accumulation in args.fast:
torch.backends.cuda.matmul.allow_fp16_accumulation = True
PRIORITIZE_FP16 = True # TODO: limit to cards where it actually boosts performance
logging.info("Enabled fp16 accumulation.")
@@ -348,7 +323,7 @@ except:
pass
try:
if torch_version_numeric >= (2, 5):
if torch_version_numeric[0] == 2 and torch_version_numeric[1] >= 5:
torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)
except:
logging.warning("Warning, could not set allow_fp16_bf16_reduction_math_sdp")
@@ -392,8 +367,6 @@ def get_torch_device_name(device):
except:
allocator_backend = ""
return "{} {} : {}".format(device, torch.cuda.get_device_name(device), allocator_backend)
elif device.type == "xpu":
return "{} {}".format(device, torch.xpu.get_device_name(device))
else:
return "{}".format(device.type)
elif is_intel_xpu():
@@ -529,8 +502,6 @@ WINDOWS = any(platform.win32_ver())
EXTRA_RESERVED_VRAM = 400 * 1024 * 1024
if WINDOWS:
EXTRA_RESERVED_VRAM = 600 * 1024 * 1024 #Windows is higher because of the shared vram issue
if total_vram > (15 * 1024): # more extra reserved vram on 16GB+ cards
EXTRA_RESERVED_VRAM += 100 * 1024 * 1024
if args.reserve_vram is not None:
EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024
@@ -724,7 +695,7 @@ def unet_inital_load_device(parameters, dtype):
return torch_dev
cpu_dev = torch.device("cpu")
if DISABLE_SMART_MEMORY or vram_state == VRAMState.NO_VRAM:
if DISABLE_SMART_MEMORY:
return cpu_dev
model_size = dtype_size(dtype) * parameters
@@ -895,7 +866,6 @@ def vae_dtype(device=None, allowed_dtypes=[]):
return d
# NOTE: bfloat16 seems to work on AMD for the VAE but is extremely slow in some cases compared to fp32
# slowness still a problem on pytorch nightly 2.9.0.dev20250720+rocm6.4 tested on RDNA3
if d == torch.bfloat16 and (not is_amd()) and should_use_bf16(device):
return d
@@ -949,7 +919,7 @@ def device_supports_non_blocking(device):
if is_device_mps(device):
return False #pytorch bug? mps doesn't support non blocking
if is_intel_xpu():
return True
return False
if args.deterministic: #TODO: figure out why deterministic breaks non blocking from gpu to cpu (previews)
return False
if directml_enabled:
@@ -988,8 +958,6 @@ def get_offload_stream(device):
stream_counter = (stream_counter + 1) % len(ss)
if is_device_cuda(device):
ss[stream_counter].wait_stream(torch.cuda.current_stream())
elif is_device_xpu(device):
ss[stream_counter].wait_stream(torch.xpu.current_stream())
stream_counters[device] = stream_counter
return s
elif is_device_cuda(device):
@@ -1001,15 +969,6 @@ def get_offload_stream(device):
stream_counter = (stream_counter + 1) % len(ss)
stream_counters[device] = stream_counter
return s
elif is_device_xpu(device):
ss = []
for k in range(NUM_STREAMS):
ss.append(torch.xpu.Stream(device=device, priority=0))
STREAMS[device] = ss
s = ss[stream_counter]
stream_counter = (stream_counter + 1) % len(ss)
stream_counters[device] = stream_counter
return s
return None
def sync_stream(device, stream):
@@ -1017,8 +976,6 @@ def sync_stream(device, stream):
return
if is_device_cuda(device):
torch.cuda.current_stream().wait_stream(stream)
elif is_device_xpu(device):
torch.xpu.current_stream().wait_stream(stream)
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None):
if device is None or weight.device == device:
@@ -1060,8 +1017,6 @@ def xformers_enabled():
return False
if is_mlu():
return False
if is_ixuca():
return False
if directml_enabled:
return False
return XFORMERS_IS_AVAILABLE
@@ -1087,7 +1042,7 @@ def pytorch_attention_flash_attention():
global ENABLE_PYTORCH_ATTENTION
if ENABLE_PYTORCH_ATTENTION:
#TODO: more reliable way of checking for flash attention?
if is_nvidia():
if is_nvidia(): #pytorch flash attention only works on Nvidia
return True
if is_intel_xpu():
return True
@@ -1097,15 +1052,13 @@ def pytorch_attention_flash_attention():
return True
if is_amd():
return True #if you have pytorch attention enabled on AMD it probably supports at least mem efficient attention
if is_ixuca():
return True
return False
def force_upcast_attention_dtype():
upcast = args.force_upcast_attention
macos_version = mac_version()
if macos_version is not None and ((14, 5) <= macos_version): # black image bug on recent versions of macOS, I don't think it's ever getting fixed
if macos_version is not None and ((14, 5) <= macos_version < (16,)): # black image bug on recent versions of macOS
upcast = True
if upcast:
@@ -1129,8 +1082,8 @@ def get_free_memory(dev=None, torch_free_too=False):
stats = torch.xpu.memory_stats(dev)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_xpu = torch.xpu.get_device_properties(dev).total_memory - mem_reserved
mem_free_torch = mem_reserved - mem_active
mem_free_xpu = torch.xpu.get_device_properties(dev).total_memory - mem_reserved
mem_free_total = mem_free_xpu + mem_free_torch
elif is_ascend_npu():
stats = torch.npu.memory_stats(dev)
@@ -1179,9 +1132,6 @@ def is_device_cpu(device):
def is_device_mps(device):
return is_device_type(device, 'mps')
def is_device_xpu(device):
return is_device_type(device, 'xpu')
def is_device_cuda(device):
return is_device_type(device, 'cuda')
@@ -1213,10 +1163,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
return False
if is_intel_xpu():
if torch_version_numeric < (2, 3):
return True
else:
return torch.xpu.get_device_properties(device).has_fp16
return True
if is_ascend_npu():
return True
@@ -1224,9 +1171,6 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
if is_mlu():
return True
if is_ixuca():
return True
if torch.version.hip:
return True
@@ -1282,15 +1226,9 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
return False
if is_intel_xpu():
if torch_version_numeric < (2, 6):
return True
else:
return torch.xpu.get_device_capability(device)['has_bfloat16_conversions']
if is_ascend_npu():
return True
if is_ixuca():
if is_ascend_npu():
return True
if is_amd():
@@ -1319,9 +1257,6 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
return False
def supports_fp8_compute(device=None):
if SUPPORT_FP8_OPS:
return True
if not is_nvidia():
return False
@@ -1333,22 +1268,15 @@ def supports_fp8_compute(device=None):
if props.minor < 9:
return False
if torch_version_numeric < (2, 3):
if torch_version_numeric[0] < 2 or (torch_version_numeric[0] == 2 and torch_version_numeric[1] < 3):
return False
if WINDOWS:
if torch_version_numeric < (2, 4):
if (torch_version_numeric[0] == 2 and torch_version_numeric[1] < 4):
return False
return True
def extended_fp16_support():
# TODO: check why some models work with fp16 on newer torch versions but not on older
if torch_version_numeric < (2, 7):
return False
return True
def soft_empty_cache(force=False):
global cpu_state
if cpu_state == CPUState.MPS:

View File

@@ -17,26 +17,23 @@
"""
from __future__ import annotations
import collections
from typing import Optional, Callable
import torch
import copy
import inspect
import logging
import math
import uuid
from typing import Callable, Optional
import collections
import math
import torch
import comfy.float
import comfy.hooks
import comfy.lora
import comfy.model_management
import comfy.patcher_extension
import comfy.utils
import comfy.float
import comfy.model_management
import comfy.lora
import comfy.hooks
import comfy.patcher_extension
from comfy.patcher_extension import CallbacksMP, WrappersMP, PatcherInjection
from comfy.comfy_types import UnetWrapperFunction
from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
def string_to_seed(data):
crc = 0xFFFFFFFF
@@ -379,9 +376,6 @@ class ModelPatcher:
def set_model_sampler_pre_cfg_function(self, pre_cfg_function, disable_cfg1_optimization=False):
self.model_options = set_model_options_pre_cfg_function(self.model_options, pre_cfg_function, disable_cfg1_optimization)
def set_model_sampler_calc_cond_batch_function(self, sampler_calc_cond_batch_function):
self.model_options["sampler_calc_cond_batch_function"] = sampler_calc_cond_batch_function
def set_model_unet_function_wrapper(self, unet_wrapper_function: UnetWrapperFunction):
self.model_options["model_function_wrapper"] = unet_wrapper_function

View File

@@ -77,25 +77,6 @@ class IMG_TO_IMG(X0):
def calculate_input(self, sigma, noise):
return noise
class COSMOS_RFLOW:
def calculate_input(self, sigma, noise):
sigma = (sigma / (sigma + 1))
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
return noise * (1.0 - sigma)
def calculate_denoised(self, sigma, model_output, model_input):
sigma = (sigma / (sigma + 1))
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
return model_input * (1.0 - sigma) - model_output * sigma
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
noise = noise * sigma
noise += latent_image
return noise
def inverse_noise_scaling(self, sigma, latent):
return latent
class ModelSamplingDiscrete(torch.nn.Module):
def __init__(self, model_config=None, zsnr=None):
@@ -369,15 +350,3 @@ class ModelSamplingFlux(torch.nn.Module):
if percent >= 1.0:
return 0.0
return flux_time_shift(self.shift, 1.0, 1.0 - percent)
class ModelSamplingCosmosRFlow(ModelSamplingContinuousEDM):
def timestep(self, sigma):
return sigma / (sigma + 1)
def sigma(self, timestep):
sigma_max = self.sigma_max
if timestep >= (sigma_max / (sigma_max + 1)):
return sigma_max
return timestep / (1 - timestep)

View File

@@ -336,12 +336,9 @@ class fp8_ops(manual_cast):
return None
def forward_comfy_cast_weights(self, input):
try:
out = fp8_linear(self, input)
if out is not None:
return out
except Exception as e:
logging.info("Exception during fp8 op: {}".format(e))
out = fp8_linear(self, input)
if out is not None:
return out
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.linear(input, weight, bias)

View File

@@ -1,7 +1,5 @@
from __future__ import annotations
import uuid
import math
import collections
import comfy.model_management
import comfy.conds
import comfy.utils
@@ -106,21 +104,6 @@ def cleanup_additional_models(models):
if hasattr(m, 'cleanup'):
m.cleanup()
def estimate_memory(model, noise_shape, conds):
cond_shapes = collections.defaultdict(list)
cond_shapes_min = {}
for _, cs in conds.items():
for cond in cs:
for k, v in model.model.extra_conds_shapes(**cond).items():
cond_shapes[k].append(v)
if cond_shapes_min.get(k, None) is None:
cond_shapes_min[k] = [v]
elif math.prod(v) > math.prod(cond_shapes_min[k][0]):
cond_shapes_min[k] = [v]
memory_required = model.model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:]), cond_shapes=cond_shapes)
minimum_memory_required = model.model.memory_required([noise_shape[0]] + list(noise_shape[1:]), cond_shapes=cond_shapes_min)
return memory_required, minimum_memory_required
def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
executor = comfy.patcher_extension.WrapperExecutor.new_executor(
@@ -134,8 +117,9 @@ def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=Non
models, inference_memory = get_additional_models(conds, model.model_dtype())
models += get_additional_models_from_model_options(model_options)
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory)
memory_required = model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory
minimum_memory_required = model.memory_required([noise_shape[0]] + list(noise_shape[1:])) + inference_memory
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required)
real_model = model.model
return real_model, conds, models

View File

@@ -89,7 +89,7 @@ def get_area_and_mult(conds, x_in, timestep_in):
conditioning = {}
model_conds = conds["model_conds"]
for c in model_conds:
conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], area=area)
conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area)
hooks = conds.get('hooks', None)
control = conds.get('control', None)
@@ -256,13 +256,7 @@ def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Te
for i in range(1, len(to_batch_temp) + 1):
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
cond_shapes = collections.defaultdict(list)
for tt in batch_amount:
cond = {k: v.size() for k, v in to_run[tt][0].conditioning.items()}
for k, v in to_run[tt][0].conditioning.items():
cond_shapes[k].append(v.size())
if model.memory_required(input_shape, cond_shapes=cond_shapes) * 1.5 < free_memory:
if model.memory_required(input_shape) * 1.5 < free_memory:
to_batch = batch_amount
break
@@ -373,11 +367,7 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
uncond_ = uncond
conds = [cond, uncond_]
if "sampler_calc_cond_batch_function" in model_options:
args = {"conds": conds, "input": x, "sigma": timestep, "model": model, "model_options": model_options}
out = model_options["sampler_calc_cond_batch_function"](args)
else:
out = calc_cond_batch(model, conds, x, timestep, model_options)
out = calc_cond_batch(model, conds, x, timestep, model_options)
for fn in model_options.get("sampler_pre_cfg_function", []):
args = {"conds":conds, "conds_out": out, "cond_scale": cond_scale, "timestep": timestep,
@@ -720,7 +710,7 @@ KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_c
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
"ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
"gradient_estimation", "gradient_estimation_cfg_pp", "er_sde", "seeds_2", "seeds_3", "sa_solver", "sa_solver_pece"]
"gradient_estimation", "gradient_estimation_cfg_pp", "er_sde", "seeds_2", "seeds_3"]
class KSAMPLER(Sampler):
def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
@@ -1043,13 +1033,13 @@ class SchedulerHandler(NamedTuple):
use_ms: bool = True
SCHEDULER_HANDLERS = {
"simple": SchedulerHandler(simple_scheduler),
"sgm_uniform": SchedulerHandler(partial(normal_scheduler, sgm=True)),
"normal": SchedulerHandler(normal_scheduler),
"karras": SchedulerHandler(k_diffusion_sampling.get_sigmas_karras, use_ms=False),
"exponential": SchedulerHandler(k_diffusion_sampling.get_sigmas_exponential, use_ms=False),
"sgm_uniform": SchedulerHandler(partial(normal_scheduler, sgm=True)),
"simple": SchedulerHandler(simple_scheduler),
"ddim_uniform": SchedulerHandler(ddim_scheduler),
"beta": SchedulerHandler(beta_scheduler),
"normal": SchedulerHandler(normal_scheduler),
"linear_quadratic": SchedulerHandler(linear_quadratic_schedule),
"kl_optimal": SchedulerHandler(kl_optimal_scheduler, use_ms=False),
}

View File

@@ -14,12 +14,10 @@ import comfy.ldm.genmo.vae.model
import comfy.ldm.lightricks.vae.causal_video_autoencoder
import comfy.ldm.cosmos.vae
import comfy.ldm.wan.vae
import comfy.ldm.wan.vae2_2
import comfy.ldm.hunyuan3d.vae
import comfy.ldm.ace.vae.music_dcae_pipeline
import yaml
import math
import os
import comfy.utils
@@ -46,8 +44,6 @@ import comfy.text_encoders.lumina2
import comfy.text_encoders.wan
import comfy.text_encoders.hidream
import comfy.text_encoders.ace
import comfy.text_encoders.omnigen2
import comfy.text_encoders.qwen_image
import comfy.model_patcher
import comfy.lora
@@ -422,30 +418,17 @@ class VAE:
self.memory_used_encode = lambda shape, dtype: (50 * (round((shape[2] + 7) / 8) * 8) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
self.working_dtypes = [torch.bfloat16, torch.float32]
elif "decoder.middle.0.residual.0.gamma" in sd:
if "decoder.upsamples.0.upsamples.0.residual.2.weight" in sd: # Wan 2.2 VAE
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
self.upscale_index_formula = (4, 16, 16)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16)
self.downscale_index_formula = (4, 16, 16)
self.latent_dim = 3
self.latent_channels = 48
ddconfig = {"dim": 160, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0}
self.first_stage_model = comfy.ldm.wan.vae2_2.WanVAE(**ddconfig)
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
self.memory_used_encode = lambda shape, dtype: 3300 * shape[3] * shape[4] * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: 8000 * shape[3] * shape[4] * (16 * 16) * model_management.dtype_size(dtype)
else: # Wan 2.1 VAE
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
self.upscale_index_formula = (4, 8, 8)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
self.downscale_index_formula = (4, 8, 8)
self.latent_dim = 3
self.latent_channels = 16
ddconfig = {"dim": 96, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0}
self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig)
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: 7000 * shape[3] * shape[4] * (8 * 8) * model_management.dtype_size(dtype)
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
self.upscale_index_formula = (4, 8, 8)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
self.downscale_index_formula = (4, 8, 8)
self.latent_dim = 3
self.latent_channels = 16
ddconfig = {"dim": 96, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0}
self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig)
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: 7000 * shape[3] * shape[4] * (8 * 8) * model_management.dtype_size(dtype)
elif "geo_decoder.cross_attn_decoder.ln_1.bias" in sd:
self.latent_dim = 1
ln_post = "geo_decoder.ln_post.weight" in sd
@@ -771,8 +754,6 @@ class CLIPType(Enum):
HIDREAM = 14
CHROMA = 15
ACE = 16
OMNIGEN2 = 17
QWEN_IMAGE = 18
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
@@ -792,8 +773,6 @@ class TEModel(Enum):
LLAMA3_8 = 7
T5_XXL_OLD = 8
GEMMA_2_2B = 9
QWEN25_3B = 10
QWEN25_7B = 11
def detect_te_model(sd):
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
@@ -814,12 +793,6 @@ def detect_te_model(sd):
return TEModel.T5_BASE
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
return TEModel.GEMMA_2_2B
if 'model.layers.0.self_attn.k_proj.bias' in sd:
weight = sd['model.layers.0.self_attn.k_proj.bias']
if weight.shape[0] == 256:
return TEModel.QWEN25_3B
if weight.shape[0] == 512:
return TEModel.QWEN25_7B
if "model.layers.0.post_attention_layernorm.weight" in sd:
return TEModel.LLAMA3_8
return None
@@ -921,12 +894,6 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data),
clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None, t5xxl_scaled_fp8=None)
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
elif te_model == TEModel.QWEN25_3B:
clip_target.clip = comfy.text_encoders.omnigen2.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.omnigen2.Omnigen2Tokenizer
elif te_model == TEModel.QWEN25_7B:
clip_target.clip = comfy.text_encoders.qwen_image.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.qwen_image.QwenImageTokenizer
else:
# clip_l
if clip_type == CLIPType.SD3:
@@ -1002,12 +969,6 @@ def load_gligen(ckpt_path):
model = model.half()
return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())
def model_detection_error_hint(path, state_dict):
filename = os.path.basename(path)
if 'lora' in filename.lower():
return "\nHINT: This seems to be a Lora file and Lora files should be put in the lora folder and loaded with a lora loader node.."
return ""
def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, embedding_directory=None, state_dict=None, config=None):
logging.warning("Warning: The load checkpoint with config function is deprecated and will eventually be removed, please use the other one.")
model, clip, vae, _ = load_checkpoint_guess_config(ckpt_path, output_vae=output_vae, output_clip=output_clip, output_clipvision=False, embedding_directory=embedding_directory, output_model=True)
@@ -1036,7 +997,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True)
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata)
if out is None:
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd)))
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
return out
def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None):
@@ -1120,28 +1081,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
return (model_patcher, clip, vae, clipvision)
def load_diffusion_model_state_dict(sd, model_options={}):
"""
Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats.
Args:
sd (dict): State dictionary containing model weights and configuration
model_options (dict, optional): Additional options for model loading. Supports:
- dtype: Override model data type
- custom_operations: Custom model operations
- fp8_optimizations: Enable FP8 optimizations
Returns:
ModelPatcher: A wrapped model instance that handles device management and weight loading.
Returns None if the model configuration cannot be detected.
The function:
1. Detects and handles different model formats (regular, diffusers, mmdit)
2. Configures model dtype based on parameters and device capabilities
3. Handles weight conversion and device placement
4. Manages model optimization settings
5. Loads weights and returns a device-managed model instance
"""
def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffusers or regular format
dtype = model_options.get("dtype", None)
#Allow loading unets from checkpoint files
@@ -1199,7 +1139,7 @@ def load_diffusion_model_state_dict(sd, model_options={}):
model.load_model_weights(new_sd, "")
left_over = sd.keys()
if len(left_over) > 0:
logging.info("left over keys in diffusion model: {}".format(left_over))
logging.info("left over keys in unet: {}".format(left_over))
return comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)
@@ -1207,8 +1147,8 @@ def load_diffusion_model(unet_path, model_options={}):
sd = comfy.utils.load_torch_file(unet_path)
model = load_diffusion_model_state_dict(sd, model_options=model_options)
if model is None:
logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path))
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd)))
logging.error("ERROR UNSUPPORTED UNET {}".format(unet_path))
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
return model
def load_unet(unet_path, dtype=None):

View File

@@ -462,7 +462,7 @@ class SDTokenizer:
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args)
self.max_length = tokenizer_data.get("{}_max_length".format(embedding_key), max_length)
self.min_length = tokenizer_data.get("{}_min_length".format(embedding_key), min_length)
self.min_length = min_length
self.end_token = None
self.min_padding = min_padding
@@ -482,8 +482,7 @@ class SDTokenizer:
if end_token is not None:
self.end_token = end_token
else:
if has_end_token:
self.end_token = empty[0]
self.end_token = empty[0]
if pad_token is not None:
self.pad_token = pad_token

View File

@@ -18,7 +18,7 @@
"single_word": false
},
"errors": "replace",
"model_max_length": 8192,
"model_max_length": 77,
"name_or_path": "openai/clip-vit-large-patch14",
"pad_token": "<|endoftext|>",
"special_tokens_map_file": "./special_tokens_map.json",

View File

@@ -18,8 +18,6 @@ import comfy.text_encoders.cosmos
import comfy.text_encoders.lumina2
import comfy.text_encoders.wan
import comfy.text_encoders.ace
import comfy.text_encoders.omnigen2
import comfy.text_encoders.qwen_image
from . import supported_models_base
from . import latent_formats
@@ -910,48 +908,6 @@ class CosmosI2V(CosmosT2V):
out = model_base.CosmosVideo(self, image_to_video=True, device=device)
return out
class CosmosT2IPredict2(supported_models_base.BASE):
unet_config = {
"image_model": "cosmos_predict2",
"in_channels": 16,
}
sampling_settings = {
"sigma_data": 1.0,
"sigma_max": 80.0,
"sigma_min": 0.002,
}
unet_extra_config = {}
latent_format = latent_formats.Wan21
memory_usage_factor = 1.0
supported_inference_dtypes = [torch.bfloat16, torch.float32]
def __init__(self, unet_config):
super().__init__(unet_config)
self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.9
def get_model(self, state_dict, prefix="", device=None):
out = model_base.CosmosPredict2(self, device=device)
return out
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.cosmos.CosmosT5Tokenizer, comfy.text_encoders.cosmos.te(**t5_detect))
class CosmosI2VPredict2(CosmosT2IPredict2):
unet_config = {
"image_model": "cosmos_predict2",
"in_channels": 17,
}
def get_model(self, state_dict, prefix="", device=None):
out = model_base.CosmosPredict2(self, image_to_video=True, device=device)
return out
class Lumina2(supported_models_base.BASE):
unet_config = {
"image_model": "lumina2",
@@ -1060,19 +1016,6 @@ class WAN21_Vace(WAN21_T2V):
out = model_base.WAN21_Vace(self, image_to_video=False, device=device)
return out
class WAN22_T2V(WAN21_T2V):
unet_config = {
"image_model": "wan2.1",
"model_type": "t2v",
"out_dim": 48,
}
latent_format = latent_formats.Wan22
def get_model(self, state_dict, prefix="", device=None):
out = model_base.WAN22(self, image_to_video=True, device=device)
return out
class Hunyuan3Dv2(supported_models_base.BASE):
unet_config = {
"image_model": "hunyuan3d2",
@@ -1196,70 +1139,6 @@ class ACEStep(supported_models_base.BASE):
def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(comfy.text_encoders.ace.AceT5Tokenizer, comfy.text_encoders.ace.AceT5Model)
class Omnigen2(supported_models_base.BASE):
unet_config = {
"image_model": "omnigen2",
}
sampling_settings = {
"multiplier": 1.0,
"shift": 2.6,
}
memory_usage_factor = 1.65 #TODO
unet_extra_config = {}
latent_format = latent_formats.Flux
supported_inference_dtypes = [torch.bfloat16, torch.float32]
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def __init__(self, unet_config):
super().__init__(unet_config)
if comfy.model_management.extended_fp16_support():
self.supported_inference_dtypes = [torch.float16] + self.supported_inference_dtypes
def get_model(self, state_dict, prefix="", device=None):
out = model_base.Omnigen2(self, device=device)
return out
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_3b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.omnigen2.Omnigen2Tokenizer, comfy.text_encoders.omnigen2.te(**hunyuan_detect))
class QwenImage(supported_models_base.BASE):
unet_config = {
"image_model": "qwen_image",
}
sampling_settings = {
"multiplier": 1.0,
"shift": 1.15,
}
memory_usage_factor = 1.8 #TODO
unet_extra_config = {}
latent_format = latent_formats.Wan21
supported_inference_dtypes = [torch.bfloat16, torch.float32]
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.QwenImage(self, device=device)
return out
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.qwen_image.QwenImageTokenizer, comfy.text_encoders.qwen_image.te(**hunyuan_detect))
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2, QwenImage]
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep]
models += [SVD_img2vid]

View File

@@ -24,41 +24,6 @@ class Llama2Config:
head_dim = 128
rms_norm_add = False
mlp_activation = "silu"
qkv_bias = False
@dataclass
class Qwen25_3BConfig:
vocab_size: int = 151936
hidden_size: int = 2048
intermediate_size: int = 11008
num_hidden_layers: int = 36
num_attention_heads: int = 16
num_key_value_heads: int = 2
max_position_embeddings: int = 128000
rms_norm_eps: float = 1e-6
rope_theta: float = 1000000.0
transformer_type: str = "llama"
head_dim = 128
rms_norm_add = False
mlp_activation = "silu"
qkv_bias = True
@dataclass
class Qwen25_7BVLI_Config:
vocab_size: int = 152064
hidden_size: int = 3584
intermediate_size: int = 18944
num_hidden_layers: int = 28
num_attention_heads: int = 28
num_key_value_heads: int = 4
max_position_embeddings: int = 128000
rms_norm_eps: float = 1e-6
rope_theta: float = 1000000.0
transformer_type: str = "llama"
head_dim = 128
rms_norm_add = False
mlp_activation = "silu"
qkv_bias = True
@dataclass
class Gemma2_2B_Config:
@@ -75,7 +40,6 @@ class Gemma2_2B_Config:
head_dim = 256
rms_norm_add = True
mlp_activation = "gelu_pytorch_tanh"
qkv_bias = False
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
@@ -134,9 +98,9 @@ class Attention(nn.Module):
self.inner_size = self.num_heads * self.head_dim
ops = ops or nn
self.q_proj = ops.Linear(config.hidden_size, self.inner_size, bias=config.qkv_bias, device=device, dtype=dtype)
self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
self.q_proj = ops.Linear(config.hidden_size, self.inner_size, bias=False, device=device, dtype=dtype)
self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False, device=device, dtype=dtype)
self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False, device=device, dtype=dtype)
self.o_proj = ops.Linear(self.inner_size, config.hidden_size, bias=False, device=device, dtype=dtype)
def forward(
@@ -356,23 +320,6 @@ class Llama2(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Qwen25_3B(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Qwen25_3BConfig(**config_dict)
self.num_layers = config.num_hidden_layers
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Qwen25_7BVLI_Config(**config_dict)
self.num_layers = config.num_hidden_layers
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Gemma2_2B(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):

View File

@@ -0,0 +1,25 @@
{
"_name_or_path": "openai/clip-vit-large-patch14",
"architectures": [
"CLIPTextModel"
],
"attention_dropout": 0.0,
"bos_token_id": 0,
"dropout": 0.0,
"eos_token_id": 49407,
"hidden_act": "quick_gelu",
"hidden_size": 768,
"initializer_factor": 1.0,
"initializer_range": 0.02,
"intermediate_size": 3072,
"layer_norm_eps": 1e-05,
"max_position_embeddings": 248,
"model_type": "clip_text_model",
"num_attention_heads": 12,
"num_hidden_layers": 12,
"pad_token_id": 1,
"projection_dim": 768,
"torch_dtype": "float32",
"transformers_version": "4.24.0",
"vocab_size": 49408
}

View File

@@ -1,44 +0,0 @@
from transformers import Qwen2Tokenizer
from comfy import sd1_clip
import comfy.text_encoders.llama
import os
class Qwen25_3BTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='qwen25_3b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)
class Omnigen2Tokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen25_3b", tokenizer=Qwen25_3BTokenizer)
self.llama_template = '<|im_start|>system\nYou are a helpful assistant that generates high-quality images based on user instructions.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n'
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None,**kwargs):
if llama_template is None:
llama_text = self.llama_template.format(text)
else:
llama_text = llama_template.format(text)
return super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, **kwargs)
class Qwen25_3BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen25_3B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class Omnigen2Model(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, name="qwen25_3b", clip_model=Qwen25_3BModel, model_options=model_options)
def te(dtype_llama=None, llama_scaled_fp8=None):
class Omnigen2TEModel_(Omnigen2Model):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["scaled_fp8"] = llama_scaled_fp8
if dtype_llama is not None:
dtype = dtype_llama
super().__init__(device=device, dtype=dtype, model_options=model_options)
return Omnigen2TEModel_

View File

@@ -1,42 +1,42 @@
import os
from comfy import sd1_clip
import comfy.text_encoders.t5
import comfy.text_encoders.sd3_clip
from comfy.sd1_clip import gen_empty_tokens
from transformers import T5TokenizerFast
class T5XXLModel(comfy.text_encoders.sd3_clip.T5XXLModel):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def gen_empty_tokens(self, special_tokens, *args, **kwargs):
# PixArt expects the negative to be all pad tokens
special_tokens = special_tokens.copy()
special_tokens.pop("end")
return gen_empty_tokens(special_tokens, *args, **kwargs)
class PixArtT5XXL(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, name="t5xxl", clip_model=T5XXLModel, model_options=model_options)
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_data=tokenizer_data) # no padding
class PixArtTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
def pixart_te(dtype_t5=None, t5xxl_scaled_fp8=None):
class PixArtTEModel_(PixArtT5XXL):
def __init__(self, device="cpu", dtype=None, model_options={}):
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
if dtype is None:
dtype = dtype_t5
super().__init__(device=device, dtype=dtype, model_options=model_options)
return PixArtTEModel_
import os
from comfy import sd1_clip
import comfy.text_encoders.t5
import comfy.text_encoders.sd3_clip
from comfy.sd1_clip import gen_empty_tokens
from transformers import T5TokenizerFast
class T5XXLModel(comfy.text_encoders.sd3_clip.T5XXLModel):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def gen_empty_tokens(self, special_tokens, *args, **kwargs):
# PixArt expects the negative to be all pad tokens
special_tokens = special_tokens.copy()
special_tokens.pop("end")
return gen_empty_tokens(special_tokens, *args, **kwargs)
class PixArtT5XXL(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, name="t5xxl", clip_model=T5XXLModel, model_options=model_options)
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_data=tokenizer_data) # no padding
class PixArtTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
def pixart_te(dtype_t5=None, t5xxl_scaled_fp8=None):
class PixArtTEModel_(PixArtT5XXL):
def __init__(self, device="cpu", dtype=None, model_options={}):
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
if dtype is None:
dtype = dtype_t5
super().__init__(device=device, dtype=dtype, model_options=model_options)
return PixArtTEModel_

File diff suppressed because it is too large Load Diff

View File

@@ -1,241 +0,0 @@
{
"add_bos_token": false,
"add_prefix_space": false,
"added_tokens_decoder": {
"151643": {
"content": "<|endoftext|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151644": {
"content": "<|im_start|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151645": {
"content": "<|im_end|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151646": {
"content": "<|object_ref_start|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151647": {
"content": "<|object_ref_end|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151648": {
"content": "<|box_start|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151649": {
"content": "<|box_end|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151650": {
"content": "<|quad_start|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151651": {
"content": "<|quad_end|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151652": {
"content": "<|vision_start|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151653": {
"content": "<|vision_end|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151654": {
"content": "<|vision_pad|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151655": {
"content": "<|image_pad|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151656": {
"content": "<|video_pad|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151657": {
"content": "<tool_call>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151658": {
"content": "</tool_call>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151659": {
"content": "<|fim_prefix|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151660": {
"content": "<|fim_middle|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151661": {
"content": "<|fim_suffix|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151662": {
"content": "<|fim_pad|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151663": {
"content": "<|repo_name|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151664": {
"content": "<|file_sep|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151665": {
"content": "<|img|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151666": {
"content": "<|endofimg|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151667": {
"content": "<|meta|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151668": {
"content": "<|endofmeta|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
}
},
"additional_special_tokens": [
"<|im_start|>",
"<|im_end|>",
"<|object_ref_start|>",
"<|object_ref_end|>",
"<|box_start|>",
"<|box_end|>",
"<|quad_start|>",
"<|quad_end|>",
"<|vision_start|>",
"<|vision_end|>",
"<|vision_pad|>",
"<|image_pad|>",
"<|video_pad|>"
],
"bos_token": null,
"chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are a helpful assistant.' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n",
"clean_up_tokenization_spaces": false,
"eos_token": "<|im_end|>",
"errors": "replace",
"extra_special_tokens": {},
"model_max_length": 131072,
"pad_token": "<|endoftext|>",
"processor_class": "Qwen2_5_VLProcessor",
"split_special_tokens": false,
"tokenizer_class": "Qwen2Tokenizer",
"unk_token": null
}

File diff suppressed because one or more lines are too long

View File

@@ -1,71 +0,0 @@
from transformers import Qwen2Tokenizer
from comfy import sd1_clip
import comfy.text_encoders.llama
import os
import torch
import numbers
class Qwen25_7BVLITokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=3584, embedding_key='qwen25_7b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)
class QwenImageTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen25_7b", tokenizer=Qwen25_7BVLITokenizer)
self.llama_template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None,**kwargs):
if llama_template is None:
llama_text = self.llama_template.format(text)
else:
llama_text = llama_template.format(text)
return super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, **kwargs)
class Qwen25_7BVLIModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen25_7BVLI, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class QwenImageTEModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, name="qwen25_7b", clip_model=Qwen25_7BVLIModel, model_options=model_options)
def encode_token_weights(self, token_weight_pairs):
out, pooled, extra = super().encode_token_weights(token_weight_pairs)
tok_pairs = token_weight_pairs["qwen25_7b"][0]
count_im_start = 0
for i, v in enumerate(tok_pairs):
elem = v[0]
if not torch.is_tensor(elem):
if isinstance(elem, numbers.Integral):
if elem == 151644 and count_im_start < 2:
template_end = i
count_im_start += 1
if out.shape[1] > (template_end + 3):
if tok_pairs[template_end + 1][0] == 872:
if tok_pairs[template_end + 2][0] == 198:
template_end += 3
out = out[:, template_end:]
extra["attention_mask"] = extra["attention_mask"][:, template_end:]
if extra["attention_mask"].sum() == torch.numel(extra["attention_mask"]):
extra.pop("attention_mask") # attention mask is useless if no masked elements
return out, pooled, extra
def te(dtype_llama=None, llama_scaled_fp8=None):
class QwenImageTEModel_(QwenImageTEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["scaled_fp8"] = llama_scaled_fp8
if dtype_llama is not None:
dtype = dtype_llama
super().__init__(device=device, dtype=dtype, model_options=model_options)
return QwenImageTEModel_

View File

@@ -146,7 +146,7 @@ class T5Attention(torch.nn.Module):
)
values = self.relative_attention_bias(relative_position_bucket, out_dtype=dtype) # shape (query_length, key_length, num_heads)
values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
return values.contiguous()
return values
def forward(self, x, mask=None, past_bias=None, optimized_attention=None):
q = self.q(x)

View File

@@ -31,7 +31,6 @@ from einops import rearrange
from comfy.cli_args import args
MMAP_TORCH_FILES = args.mmap_torch_files
DISABLE_MMAP = args.disable_mmap
ALWAYS_SAFE_LOAD = False
if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in pytorch 2.4, the unsafe path should be removed once earlier versions are deprecated
@@ -50,30 +49,19 @@ if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in
else:
logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.")
def is_html_file(file_path):
with open(file_path, "rb") as f:
content = f.read(100)
return b"<!DOCTYPE html>" in content or b"<html" in content
def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
if device is None:
device = torch.device("cpu")
metadata = None
if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
try:
with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
sd = {}
for k in f.keys():
tensor = f.get_tensor(k)
if DISABLE_MMAP: # TODO: Not sure if this is the best way to bypass the mmap issues
tensor = tensor.to(device=device, copy=True)
sd[k] = tensor
sd[k] = f.get_tensor(k)
if return_metadata:
metadata = f.metadata()
except Exception as e:
if is_html_file(ckpt):
raise ValueError("{}\n\nFile path: {}\n\nThe requested file is an HTML document not a safetensors file. Please re-download the file, not the web page.".format(e, ckpt))
if len(e.args) > 0:
message = e.args[0]
if "HeaderTooLarge" in message:
@@ -89,7 +77,6 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
if safe_load or ALWAYS_SAFE_LOAD:
pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args)
else:
logging.warning("WARNING: loading {} unsafely, upgrade your pytorch to 2.4 or newer to load this file safely.".format(ckpt))
pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle)
if "state_dict" in pl_sd:
sd = pl_sd["state_dict"]
@@ -101,13 +88,6 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
sd = pl_sd
else:
sd = pl_sd
try:
from app.model_processor import model_processor
model_processor.process_file(ckpt)
except Exception as e:
logging.error(f"Error processing file {ckpt}: {e}")
return (sd, metadata) if return_metadata else sd
def save_torch_file(sd, ckpt, metadata=None):
@@ -713,26 +693,6 @@ def resize_to_batch_size(tensor, batch_size):
return output
def resize_list_to_batch_size(l, batch_size):
in_batch_size = len(l)
if in_batch_size == batch_size or in_batch_size == 0:
return l
if batch_size <= 1:
return l[:batch_size]
output = []
if batch_size < in_batch_size:
scale = (in_batch_size - 1) / (batch_size - 1)
for i in range(batch_size):
output.append(l[min(round(i * scale), in_batch_size - 1)])
else:
scale = in_batch_size / batch_size
for i in range(batch_size):
output.append(l[min(math.floor((i + 0.5) * scale), in_batch_size - 1)])
return output
def convert_sd_to(state_dict, dtype):
keys = list(state_dict.keys())
for k in keys:
@@ -1037,12 +997,11 @@ def set_progress_bar_global_hook(function):
PROGRESS_BAR_HOOK = function
class ProgressBar:
def __init__(self, total, node_id=None):
def __init__(self, total):
global PROGRESS_BAR_HOOK
self.total = total
self.current = 0
self.hook = PROGRESS_BAR_HOOK
self.node_id = node_id
def update_absolute(self, value, total=None, preview=None):
if total is not None:
@@ -1051,7 +1010,7 @@ class ProgressBar:
value = self.total
self.current = value
if self.hook is not None:
self.hook(self.current, self.total, preview, node_id=self.node_id)
self.hook(self.current, self.total, preview)
def update(self, value):
self.update_absolute(self.current + value)

View File

@@ -1,4 +1,4 @@
from .base import WeightAdapterBase, WeightAdapterTrainBase
from .base import WeightAdapterBase
from .lora import LoRAAdapter
from .loha import LoHaAdapter
from .lokr import LoKrAdapter
@@ -15,20 +15,3 @@ adapters: list[type[WeightAdapterBase]] = [
OFTAdapter,
BOFTAdapter,
]
adapter_maps: dict[str, type[WeightAdapterBase]] = {
"LoRA": LoRAAdapter,
"LoHa": LoHaAdapter,
"LoKr": LoKrAdapter,
"OFT": OFTAdapter,
## We disable not implemented algo for now
# "GLoRA": GLoRAAdapter,
# "BOFT": BOFTAdapter,
}
__all__ = [
"WeightAdapterBase",
"WeightAdapterTrainBase",
"adapters",
"adapter_maps",
] + [a.__name__ for a in adapters]

View File

@@ -12,20 +12,12 @@ class WeightAdapterBase:
weights: list[torch.Tensor]
@classmethod
def load(cls, x: str, lora: dict[str, torch.Tensor], alpha: float, dora_scale: torch.Tensor) -> Optional["WeightAdapterBase"]:
def load(cls, x: str, lora: dict[str, torch.Tensor]) -> Optional["WeightAdapterBase"]:
raise NotImplementedError
def to_train(self) -> "WeightAdapterTrainBase":
raise NotImplementedError
@classmethod
def create_train(cls, weight, *args) -> "WeightAdapterTrainBase":
"""
weight: The original weight tensor to be modified.
*args: Additional arguments for configuration, such as rank, alpha etc.
"""
raise NotImplementedError
def calculate_weight(
self,
weight,
@@ -41,22 +33,10 @@ class WeightAdapterBase:
class WeightAdapterTrainBase(nn.Module):
# We follow the scheme of PR #7032
def __init__(self):
super().__init__()
def __call__(self, w):
"""
w: The original weight tensor to be modified.
"""
raise NotImplementedError
def passive_memory_usage(self):
raise NotImplementedError("passive_memory_usage is not implemented")
def move_to(self, device):
self.to(device)
return self.passive_memory_usage()
# [TODO] Collaborate with LoRA training PR #7032
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function):
@@ -122,54 +102,3 @@ def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Ten
padded_tensor[new_slices] = tensor[orig_slices]
return padded_tensor
def tucker_weight_from_conv(up, down, mid):
up = up.reshape(up.size(0), up.size(1))
down = down.reshape(down.size(0), down.size(1))
return torch.einsum("m n ..., i m, n j -> i j ...", mid, up, down)
def tucker_weight(wa, wb, t):
temp = torch.einsum("i j ..., j r -> i r ...", t, wb)
return torch.einsum("i j ..., i r -> r j ...", temp, wa)
def factorization(dimension: int, factor: int = -1) -> tuple[int, int]:
"""
return a tuple of two value of input dimension decomposed by the number closest to factor
second value is higher or equal than first value.
examples)
factor
-1 2 4 8 16 ...
127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127
128 -> 8, 16 128 -> 2, 64 128 -> 4, 32 128 -> 8, 16 128 -> 8, 16
250 -> 10, 25 250 -> 2, 125 250 -> 2, 125 250 -> 5, 50 250 -> 10, 25
360 -> 8, 45 360 -> 2, 180 360 -> 4, 90 360 -> 8, 45 360 -> 12, 30
512 -> 16, 32 512 -> 2, 256 512 -> 4, 128 512 -> 8, 64 512 -> 16, 32
1024 -> 32, 32 1024 -> 2, 512 1024 -> 4, 256 1024 -> 8, 128 1024 -> 16, 64
"""
if factor > 0 and (dimension % factor) == 0 and dimension >= factor**2:
m = factor
n = dimension // factor
if m > n:
n, m = m, n
return m, n
if factor < 0:
factor = dimension
m, n = 1, dimension
length = m + n
while m < n:
new_m = m + 1
while dimension % new_m != 0:
new_m += 1
new_n = dimension // new_m
if new_m + new_n > length or new_m > factor:
break
else:
m, n = new_m, new_n
if m > n:
n, m = m, n
return m, n

View File

@@ -3,120 +3,7 @@ from typing import Optional
import torch
import comfy.model_management
from .base import WeightAdapterBase, WeightAdapterTrainBase, weight_decompose
class HadaWeight(torch.autograd.Function):
@staticmethod
def forward(ctx, w1u, w1d, w2u, w2d, scale=torch.tensor(1)):
ctx.save_for_backward(w1d, w1u, w2d, w2u, scale)
diff_weight = ((w1u @ w1d) * (w2u @ w2d)) * scale
return diff_weight
@staticmethod
def backward(ctx, grad_out):
(w1d, w1u, w2d, w2u, scale) = ctx.saved_tensors
grad_out = grad_out * scale
temp = grad_out * (w2u @ w2d)
grad_w1u = temp @ w1d.T
grad_w1d = w1u.T @ temp
temp = grad_out * (w1u @ w1d)
grad_w2u = temp @ w2d.T
grad_w2d = w2u.T @ temp
del temp
return grad_w1u, grad_w1d, grad_w2u, grad_w2d, None
class HadaWeightTucker(torch.autograd.Function):
@staticmethod
def forward(ctx, t1, w1u, w1d, t2, w2u, w2d, scale=torch.tensor(1)):
ctx.save_for_backward(t1, w1d, w1u, t2, w2d, w2u, scale)
rebuild1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1d, w1u)
rebuild2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2d, w2u)
return rebuild1 * rebuild2 * scale
@staticmethod
def backward(ctx, grad_out):
(t1, w1d, w1u, t2, w2d, w2u, scale) = ctx.saved_tensors
grad_out = grad_out * scale
temp = torch.einsum("i j ..., j r -> i r ...", t2, w2d)
rebuild = torch.einsum("i j ..., i r -> r j ...", temp, w2u)
grad_w = rebuild * grad_out
del rebuild
grad_w1u = torch.einsum("r j ..., i j ... -> r i", temp, grad_w)
grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w1u.T)
del grad_w, temp
grad_w1d = torch.einsum("i r ..., i j ... -> r j", t1, grad_temp)
grad_t1 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w1d.T)
del grad_temp
temp = torch.einsum("i j ..., j r -> i r ...", t1, w1d)
rebuild = torch.einsum("i j ..., i r -> r j ...", temp, w1u)
grad_w = rebuild * grad_out
del rebuild
grad_w2u = torch.einsum("r j ..., i j ... -> r i", temp, grad_w)
grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w2u.T)
del grad_w, temp
grad_w2d = torch.einsum("i r ..., i j ... -> r j", t2, grad_temp)
grad_t2 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w2d.T)
del grad_temp
return grad_t1, grad_w1u, grad_w1d, grad_t2, grad_w2u, grad_w2d, None
class LohaDiff(WeightAdapterTrainBase):
def __init__(self, weights):
super().__init__()
# Unpack weights tuple from LoHaAdapter
w1a, w1b, alpha, w2a, w2b, t1, t2, _ = weights
# Create trainable parameters
self.hada_w1_a = torch.nn.Parameter(w1a)
self.hada_w1_b = torch.nn.Parameter(w1b)
self.hada_w2_a = torch.nn.Parameter(w2a)
self.hada_w2_b = torch.nn.Parameter(w2b)
self.use_tucker = False
if t1 is not None and t2 is not None:
self.use_tucker = True
self.hada_t1 = torch.nn.Parameter(t1)
self.hada_t2 = torch.nn.Parameter(t2)
else:
# Keep the attributes for consistent access
self.hada_t1 = None
self.hada_t2 = None
# Store rank and non-trainable alpha
self.rank = w1b.shape[0]
self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False)
def __call__(self, w):
org_dtype = w.dtype
scale = self.alpha / self.rank
if self.use_tucker:
diff_weight = HadaWeightTucker.apply(self.hada_t1, self.hada_w1_a, self.hada_w1_b, self.hada_t2, self.hada_w2_a, self.hada_w2_b, scale)
else:
diff_weight = HadaWeight.apply(self.hada_w1_a, self.hada_w1_b, self.hada_w2_a, self.hada_w2_b, scale)
# Add the scaled difference to the original weight
weight = w.to(diff_weight) + diff_weight.reshape(w.shape)
return weight.to(org_dtype)
def passive_memory_usage(self):
"""Calculates memory usage of the trainable parameters."""
return sum(param.numel() * param.element_size() for param in self.parameters())
from .base import WeightAdapterBase, weight_decompose
class LoHaAdapter(WeightAdapterBase):
@@ -126,25 +13,6 @@ class LoHaAdapter(WeightAdapterBase):
self.loaded_keys = loaded_keys
self.weights = weights
@classmethod
def create_train(cls, weight, rank=1, alpha=1.0):
out_dim = weight.shape[0]
in_dim = weight.shape[1:].numel()
mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype)
mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype)
torch.nn.init.normal_(mat1, 0.1)
torch.nn.init.constant_(mat2, 0.0)
mat3 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype)
mat4 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype)
torch.nn.init.normal_(mat3, 0.1)
torch.nn.init.normal_(mat4, 0.01)
return LohaDiff(
(mat1, mat2, alpha, mat3, mat4, None, None, None)
)
def to_train(self):
return LohaDiff(self.weights)
@classmethod
def load(
cls,

View File

@@ -3,77 +3,7 @@ from typing import Optional
import torch
import comfy.model_management
from .base import (
WeightAdapterBase,
WeightAdapterTrainBase,
weight_decompose,
factorization,
)
class LokrDiff(WeightAdapterTrainBase):
def __init__(self, weights):
super().__init__()
(lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale) = weights
self.use_tucker = False
if lokr_w1_a is not None:
_, rank_a = lokr_w1_a.shape[0], lokr_w1_a.shape[1]
rank_a, _ = lokr_w1_b.shape[0], lokr_w1_b.shape[1]
self.lokr_w1_a = torch.nn.Parameter(lokr_w1_a)
self.lokr_w1_b = torch.nn.Parameter(lokr_w1_b)
self.w1_rebuild = True
self.ranka = rank_a
if lokr_w2_a is not None:
_, rank_b = lokr_w2_a.shape[0], lokr_w2_a.shape[1]
rank_b, _ = lokr_w2_b.shape[0], lokr_w2_b.shape[1]
self.lokr_w2_a = torch.nn.Parameter(lokr_w2_a)
self.lokr_w2_b = torch.nn.Parameter(lokr_w2_b)
if lokr_t2 is not None:
self.use_tucker = True
self.lokr_t2 = torch.nn.Parameter(lokr_t2)
self.w2_rebuild = True
self.rankb = rank_b
if lokr_w1 is not None:
self.lokr_w1 = torch.nn.Parameter(lokr_w1)
self.w1_rebuild = False
if lokr_w2 is not None:
self.lokr_w2 = torch.nn.Parameter(lokr_w2)
self.w2_rebuild = False
self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False)
@property
def w1(self):
if self.w1_rebuild:
return (self.lokr_w1_a @ self.lokr_w1_b) * (self.alpha / self.ranka)
else:
return self.lokr_w1
@property
def w2(self):
if self.w2_rebuild:
if self.use_tucker:
w2 = torch.einsum(
'i j k l, j r, i p -> p r k l',
self.lokr_t2,
self.lokr_w2_b,
self.lokr_w2_a
)
else:
w2 = self.lokr_w2_a @ self.lokr_w2_b
return w2 * (self.alpha / self.rankb)
else:
return self.lokr_w2
def __call__(self, w):
diff = torch.kron(self.w1, self.w2)
return w + diff.reshape(w.shape).to(w)
def passive_memory_usage(self):
return sum(param.numel() * param.element_size() for param in self.parameters())
from .base import WeightAdapterBase, weight_decompose
class LoKrAdapter(WeightAdapterBase):
@@ -83,20 +13,6 @@ class LoKrAdapter(WeightAdapterBase):
self.loaded_keys = loaded_keys
self.weights = weights
@classmethod
def create_train(cls, weight, rank=1, alpha=1.0):
out_dim = weight.shape[0]
in_dim = weight.shape[1:].numel()
out1, out2 = factorization(out_dim, rank)
in1, in2 = factorization(in_dim, rank)
mat1 = torch.empty(out1, in1, device=weight.device, dtype=weight.dtype)
mat2 = torch.empty(out2, in2, device=weight.device, dtype=weight.dtype)
torch.nn.init.kaiming_uniform_(mat2, a=5**0.5)
torch.nn.init.constant_(mat1, 0.0)
return LokrDiff(
(mat1, mat2, alpha, None, None, None, None, None, None)
)
@classmethod
def load(
cls,

View File

@@ -3,56 +3,7 @@ from typing import Optional
import torch
import comfy.model_management
from .base import (
WeightAdapterBase,
WeightAdapterTrainBase,
weight_decompose,
pad_tensor_to_shape,
tucker_weight_from_conv,
)
class LoraDiff(WeightAdapterTrainBase):
def __init__(self, weights):
super().__init__()
mat1, mat2, alpha, mid, dora_scale, reshape = weights
out_dim, rank = mat1.shape[0], mat1.shape[1]
rank, in_dim = mat2.shape[0], mat2.shape[1]
if mid is not None:
convdim = mid.ndim - 2
layer = (
torch.nn.Conv1d,
torch.nn.Conv2d,
torch.nn.Conv3d
)[convdim]
else:
layer = torch.nn.Linear
self.lora_up = layer(rank, out_dim, bias=False)
self.lora_down = layer(in_dim, rank, bias=False)
self.lora_up.weight.data.copy_(mat1)
self.lora_down.weight.data.copy_(mat2)
if mid is not None:
self.lora_mid = layer(mid, rank, bias=False)
self.lora_mid.weight.data.copy_(mid)
else:
self.lora_mid = None
self.rank = rank
self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False)
def __call__(self, w):
org_dtype = w.dtype
if self.lora_mid is None:
diff = self.lora_up.weight @ self.lora_down.weight
else:
diff = tucker_weight_from_conv(
self.lora_up.weight, self.lora_down.weight, self.lora_mid.weight
)
scale = self.alpha / self.rank
weight = w + scale * diff.reshape(w.shape)
return weight.to(org_dtype)
def passive_memory_usage(self):
return sum(param.numel() * param.element_size() for param in self.parameters())
from .base import WeightAdapterBase, weight_decompose, pad_tensor_to_shape
class LoRAAdapter(WeightAdapterBase):
@@ -62,21 +13,6 @@ class LoRAAdapter(WeightAdapterBase):
self.loaded_keys = loaded_keys
self.weights = weights
@classmethod
def create_train(cls, weight, rank=1, alpha=1.0):
out_dim = weight.shape[0]
in_dim = weight.shape[1:].numel()
mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype)
mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype)
torch.nn.init.kaiming_uniform_(mat1, a=5**0.5)
torch.nn.init.constant_(mat2, 0.0)
return LoraDiff(
(mat1, mat2, alpha, None, None, None)
)
def to_train(self):
return LoraDiff(self.weights)
@classmethod
def load(
cls,
@@ -96,7 +32,6 @@ class LoRAAdapter(WeightAdapterBase):
diffusers3_lora = "{}.lora.up.weight".format(x)
mochi_lora = "{}.lora_B".format(x)
transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
qwen_default_lora = "{}.lora_B.default.weight".format(x)
A_name = None
if regular_lora in lora.keys():
@@ -123,10 +58,6 @@ class LoRAAdapter(WeightAdapterBase):
A_name = transformers_lora
B_name = "{}.lora_linear_layer.down.weight".format(x)
mid_name = None
elif qwen_default_lora in lora.keys():
A_name = qwen_default_lora
B_name = "{}.lora_A.default.weight".format(x)
mid_name = None
if A_name is not None:
mid = None

View File

@@ -3,58 +3,7 @@ from typing import Optional
import torch
import comfy.model_management
from .base import WeightAdapterBase, WeightAdapterTrainBase, weight_decompose, factorization
class OFTDiff(WeightAdapterTrainBase):
def __init__(self, weights):
super().__init__()
# Unpack weights tuple from LoHaAdapter
blocks, rescale, alpha, _ = weights
# Create trainable parameters
self.oft_blocks = torch.nn.Parameter(blocks)
if rescale is not None:
self.rescale = torch.nn.Parameter(rescale)
self.rescaled = True
else:
self.rescaled = False
self.block_num, self.block_size, _ = blocks.shape
self.constraint = float(alpha)
self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False)
def __call__(self, w):
org_dtype = w.dtype
I = torch.eye(self.block_size, device=self.oft_blocks.device)
## generate r
# for Q = -Q^T
q = self.oft_blocks - self.oft_blocks.transpose(1, 2)
normed_q = q
if self.constraint:
q_norm = torch.norm(q) + 1e-8
if q_norm > self.constraint:
normed_q = q * self.constraint / q_norm
# use float() to prevent unsupported type
r = (I + normed_q) @ (I - normed_q).float().inverse()
## Apply chunked matmul on weight
_, *shape = w.shape
org_weight = w.to(dtype=r.dtype)
org_weight = org_weight.unflatten(0, (self.block_num, self.block_size))
# Init R=0, so add I on it to ensure the output of step0 is original model output
weight = torch.einsum(
"k n m, k n ... -> k m ...",
r,
org_weight,
).flatten(0, 1)
if self.rescaled:
weight = self.rescale * weight
return weight.to(org_dtype)
def passive_memory_usage(self):
"""Calculates memory usage of the trainable parameters."""
return sum(param.numel() * param.element_size() for param in self.parameters())
from .base import WeightAdapterBase, weight_decompose
class OFTAdapter(WeightAdapterBase):
@@ -64,18 +13,6 @@ class OFTAdapter(WeightAdapterBase):
self.loaded_keys = loaded_keys
self.weights = weights
@classmethod
def create_train(cls, weight, rank=1, alpha=1.0):
out_dim = weight.shape[0]
block_size, block_num = factorization(out_dim, rank)
block = torch.zeros(block_num, block_size, block_size, device=weight.device, dtype=weight.dtype)
return OFTDiff(
(block, None, alpha, None)
)
def to_train(self):
return OFTDiff(self.weights)
@classmethod
def load(
cls,
@@ -123,8 +60,6 @@ class OFTAdapter(WeightAdapterBase):
blocks = v[0]
rescale = v[1]
alpha = v[2]
if alpha is None:
alpha = 0
dora_scale = v[3]
blocks = comfy.model_management.cast_to_device(blocks, weight.device, intermediate_dtype)

View File

@@ -1,69 +0,0 @@
"""
Feature flags module for ComfyUI WebSocket protocol negotiation.
This module handles capability negotiation between frontend and backend,
allowing graceful protocol evolution while maintaining backward compatibility.
"""
from typing import Any, Dict
from comfy.cli_args import args
# Default server capabilities
SERVER_FEATURE_FLAGS: Dict[str, Any] = {
"supports_preview_metadata": True,
"max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes
}
def get_connection_feature(
sockets_metadata: Dict[str, Dict[str, Any]],
sid: str,
feature_name: str,
default: Any = False
) -> Any:
"""
Get a feature flag value for a specific connection.
Args:
sockets_metadata: Dictionary of socket metadata
sid: Session ID of the connection
feature_name: Name of the feature to check
default: Default value if feature not found
Returns:
Feature value or default if not found
"""
if sid not in sockets_metadata:
return default
return sockets_metadata[sid].get("feature_flags", {}).get(feature_name, default)
def supports_feature(
sockets_metadata: Dict[str, Dict[str, Any]],
sid: str,
feature_name: str
) -> bool:
"""
Check if a connection supports a specific feature.
Args:
sockets_metadata: Dictionary of socket metadata
sid: Session ID of the connection
feature_name: Name of the feature to check
Returns:
Boolean indicating if feature is supported
"""
return get_connection_feature(sockets_metadata, sid, feature_name, False) is True
def get_server_features() -> Dict[str, Any]:
"""
Get the server's feature flags.
Returns:
Dictionary of server feature flags
"""
return SERVER_FEATURE_FLAGS.copy()

View File

@@ -1,86 +0,0 @@
#!/usr/bin/env python3
"""
Script to generate .pyi stub files for the synchronous API wrappers.
This allows generating stubs without running the full ComfyUI application.
"""
import os
import sys
import logging
import importlib
# Add ComfyUI to path so we can import modules
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from comfy_api.internal.async_to_sync import AsyncToSyncConverter
from comfy_api.version_list import supported_versions
def generate_stubs_for_module(module_name: str) -> None:
"""Generate stub files for a specific module that exports ComfyAPI and ComfyAPISync."""
try:
# Import the module
module = importlib.import_module(module_name)
# Check if module has ComfyAPISync (the sync wrapper)
if hasattr(module, "ComfyAPISync"):
# Module already has a sync class
api_class = getattr(module, "ComfyAPI", None)
sync_class = getattr(module, "ComfyAPISync")
if api_class:
# Generate the stub file
AsyncToSyncConverter.generate_stub_file(api_class, sync_class)
logging.info(f"Generated stub file for {module_name}")
else:
logging.warning(
f"Module {module_name} has ComfyAPISync but no ComfyAPI"
)
elif hasattr(module, "ComfyAPI"):
# Module only has async API, need to create sync wrapper first
from comfy_api.internal.async_to_sync import create_sync_class
api_class = getattr(module, "ComfyAPI")
sync_class = create_sync_class(api_class)
# Generate the stub file
AsyncToSyncConverter.generate_stub_file(api_class, sync_class)
logging.info(f"Generated stub file for {module_name}")
else:
logging.warning(
f"Module {module_name} does not export ComfyAPI or ComfyAPISync"
)
except Exception as e:
logging.error(f"Failed to generate stub for {module_name}: {e}")
import traceback
traceback.print_exc()
def main():
"""Main function to generate all API stub files."""
logging.basicConfig(level=logging.INFO)
logging.info("Starting stub generation...")
# Dynamically get module names from supported_versions
api_modules = []
for api_class in supported_versions:
# Extract module name from the class
module_name = api_class.__module__
if module_name not in api_modules:
api_modules.append(module_name)
logging.info(f"Found {len(api_modules)} API modules: {api_modules}")
# Generate stubs for each module
for module_name in api_modules:
generate_stubs_for_module(module_name)
logging.info("Stub generation complete!")
if __name__ == "__main__":
main()

View File

@@ -1,16 +1,8 @@
# This file only exists for backwards compatibility.
from comfy_api.latest._input import (
ImageInput,
AudioInput,
MaskInput,
LatentInput,
VideoInput,
)
from .basic_types import ImageInput, AudioInput
from .video_types import VideoInput
__all__ = [
"ImageInput",
"AudioInput",
"MaskInput",
"LatentInput",
"VideoInput",
]

View File

@@ -1,14 +1,20 @@
# This file only exists for backwards compatibility.
from comfy_api.latest._input.basic_types import (
ImageInput,
AudioInput,
MaskInput,
LatentInput,
)
import torch
from typing import TypedDict
ImageInput = torch.Tensor
"""
An image in format [B, H, W, C] where B is the batch size, C is the number of channels,
"""
class AudioInput(TypedDict):
"""
TypedDict representing audio input.
"""
waveform: torch.Tensor
"""
Tensor in the format [B, C, T] where B is the batch size, C is the number of channels,
"""
sample_rate: int
__all__ = [
"ImageInput",
"AudioInput",
"MaskInput",
"LatentInput",
]

View File

@@ -1,6 +1,55 @@
# This file only exists for backwards compatibility.
from comfy_api.latest._input.video_types import VideoInput
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Optional
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
__all__ = [
"VideoInput",
]
class VideoInput(ABC):
"""
Abstract base class for video input types.
"""
@abstractmethod
def get_components(self) -> VideoComponents:
"""
Abstract method to get the video components (images, audio, and frame rate).
Returns:
VideoComponents containing images, audio, and frame rate
"""
pass
@abstractmethod
def save_to(
self,
path: str,
format: VideoContainer = VideoContainer.AUTO,
codec: VideoCodec = VideoCodec.AUTO,
metadata: Optional[dict] = None
):
"""
Abstract method to save the video input to a file.
"""
pass
# Provide a default implementation, but subclasses can provide optimized versions
# if possible.
def get_dimensions(self) -> tuple[int, int]:
"""
Returns the dimensions of the video input.
Returns:
Tuple of (width, height)
"""
components = self.get_components()
return components.images.shape[2], components.images.shape[1]
def get_duration(self) -> float:
"""
Returns the duration of the video in seconds.
Returns:
Duration in seconds
"""
components = self.get_components()
frame_count = components.images.shape[0]
return float(frame_count / components.frame_rate)

View File

@@ -1,7 +1,7 @@
# This file only exists for backwards compatibility.
from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents
from .video_types import VideoFromFile, VideoFromComponents
__all__ = [
# Implementations
"VideoFromFile",
"VideoFromComponents",
]

View File

@@ -1,2 +1,303 @@
# This file only exists for backwards compatibility.
from comfy_api.latest._input_impl.video_types import * # noqa: F403
from __future__ import annotations
from av.container import InputContainer
from av.subtitles.stream import SubtitleStream
from fractions import Fraction
from typing import Optional
from comfy_api.input import AudioInput
import av
import io
import json
import numpy as np
import torch
from comfy_api.input import VideoInput
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
def container_to_output_format(container_format: str | None) -> str | None:
"""
A container's `format` may be a comma-separated list of formats.
E.g., iso container's `format` may be `mov,mp4,m4a,3gp,3g2,mj2`.
However, writing to a file/stream with `av.open` requires a single format,
or `None` to auto-detect.
"""
if not container_format:
return None # Auto-detect
if "," not in container_format:
return container_format
formats = container_format.split(",")
return formats[0]
def get_open_write_kwargs(
dest: str | io.BytesIO, container_format: str, to_format: str | None
) -> dict:
"""Get kwargs for writing a `VideoFromFile` to a file/stream with `av.open`"""
open_kwargs = {
"mode": "w",
# If isobmff, preserve custom metadata tags (workflow, prompt, extra_pnginfo)
"options": {"movflags": "use_metadata_tags"},
}
is_write_to_buffer = isinstance(dest, io.BytesIO)
if is_write_to_buffer:
# Set output format explicitly, since it cannot be inferred from file extension
if to_format == VideoContainer.AUTO:
to_format = container_format.lower()
elif isinstance(to_format, str):
to_format = to_format.lower()
open_kwargs["format"] = container_to_output_format(to_format)
return open_kwargs
class VideoFromFile(VideoInput):
"""
Class representing video input from a file.
"""
def __init__(self, file: str | io.BytesIO):
"""
Initialize the VideoFromFile object based off of either a path on disk or a BytesIO object
containing the file contents.
"""
self.__file = file
def get_dimensions(self) -> tuple[int, int]:
"""
Returns the dimensions of the video input.
Returns:
Tuple of (width, height)
"""
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0) # Reset the BytesIO object to the beginning
with av.open(self.__file, mode='r') as container:
for stream in container.streams:
if stream.type == 'video':
assert isinstance(stream, av.VideoStream)
return stream.width, stream.height
raise ValueError(f"No video stream found in file '{self.__file}'")
def get_duration(self) -> float:
"""
Returns the duration of the video in seconds.
Returns:
Duration in seconds
"""
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0)
with av.open(self.__file, mode="r") as container:
if container.duration is not None:
return float(container.duration / av.time_base)
# Fallback: calculate from frame count and frame rate
video_stream = next(
(s for s in container.streams if s.type == "video"), None
)
if video_stream and video_stream.frames and video_stream.average_rate:
return float(video_stream.frames / video_stream.average_rate)
# Last resort: decode frames to count them
if video_stream and video_stream.average_rate:
frame_count = 0
container.seek(0)
for packet in container.demux(video_stream):
for _ in packet.decode():
frame_count += 1
if frame_count > 0:
return float(frame_count / video_stream.average_rate)
raise ValueError(f"Could not determine duration for file '{self.__file}'")
def get_components_internal(self, container: InputContainer) -> VideoComponents:
# Get video frames
frames = []
for frame in container.decode(video=0):
img = frame.to_ndarray(format='rgb24') # shape: (H, W, 3)
img = torch.from_numpy(img) / 255.0 # shape: (H, W, 3)
frames.append(img)
images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 3, 0, 0)
# Get frame rate
video_stream = next(s for s in container.streams if s.type == 'video')
frame_rate = Fraction(video_stream.average_rate) if video_stream and video_stream.average_rate else Fraction(1)
# Get audio if available
audio = None
try:
container.seek(0) # Reset the container to the beginning
for stream in container.streams:
if stream.type != 'audio':
continue
assert isinstance(stream, av.AudioStream)
audio_frames = []
for packet in container.demux(stream):
for frame in packet.decode():
assert isinstance(frame, av.AudioFrame)
audio_frames.append(frame.to_ndarray()) # shape: (channels, samples)
if len(audio_frames) > 0:
audio_data = np.concatenate(audio_frames, axis=1) # shape: (channels, total_samples)
audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) # shape: (1, channels, total_samples)
audio = AudioInput({
"waveform": audio_tensor,
"sample_rate": int(stream.sample_rate) if stream.sample_rate else 1,
})
except StopIteration:
pass # No audio stream
metadata = container.metadata
return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata)
def get_components(self) -> VideoComponents:
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0) # Reset the BytesIO object to the beginning
with av.open(self.__file, mode='r') as container:
return self.get_components_internal(container)
raise ValueError(f"No video stream found in file '{self.__file}'")
def save_to(
self,
path: str | io.BytesIO,
format: VideoContainer = VideoContainer.AUTO,
codec: VideoCodec = VideoCodec.AUTO,
metadata: Optional[dict] = None
):
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0) # Reset the BytesIO object to the beginning
with av.open(self.__file, mode='r') as container:
container_format = container.format.name
video_encoding = container.streams.video[0].codec.name if len(container.streams.video) > 0 else None
reuse_streams = True
if format != VideoContainer.AUTO and format not in container_format.split(","):
reuse_streams = False
if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None:
reuse_streams = False
if not reuse_streams:
components = self.get_components_internal(container)
video = VideoFromComponents(components)
return video.save_to(
path,
format=format,
codec=codec,
metadata=metadata
)
streams = container.streams
open_kwargs = get_open_write_kwargs(path, container_format, format)
with av.open(path, **open_kwargs) as output_container:
# Copy over the original metadata
for key, value in container.metadata.items():
if metadata is None or key not in metadata:
output_container.metadata[key] = value
# Add our new metadata
if metadata is not None:
for key, value in metadata.items():
if isinstance(value, str):
output_container.metadata[key] = value
else:
output_container.metadata[key] = json.dumps(value)
# Add streams to the new container
stream_map = {}
for stream in streams:
if isinstance(stream, (av.VideoStream, av.AudioStream, SubtitleStream)):
out_stream = output_container.add_stream_from_template(template=stream, opaque=True)
stream_map[stream] = out_stream
# Write packets to the new container
for packet in container.demux():
if packet.stream in stream_map and packet.dts is not None:
packet.stream = stream_map[packet.stream]
output_container.mux(packet)
class VideoFromComponents(VideoInput):
"""
Class representing video input from tensors.
"""
def __init__(self, components: VideoComponents):
self.__components = components
def get_components(self) -> VideoComponents:
return VideoComponents(
images=self.__components.images,
audio=self.__components.audio,
frame_rate=self.__components.frame_rate
)
def save_to(
self,
path: str,
format: VideoContainer = VideoContainer.AUTO,
codec: VideoCodec = VideoCodec.AUTO,
metadata: Optional[dict] = None
):
if format != VideoContainer.AUTO and format != VideoContainer.MP4:
raise ValueError("Only MP4 format is supported for now")
if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
raise ValueError("Only H264 codec is supported for now")
with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}) as output:
# Add metadata before writing any streams
if metadata is not None:
for key, value in metadata.items():
output.metadata[key] = json.dumps(value)
frame_rate = Fraction(round(self.__components.frame_rate * 1000), 1000)
# Create a video stream
video_stream = output.add_stream('h264', rate=frame_rate)
video_stream.width = self.__components.images.shape[2]
video_stream.height = self.__components.images.shape[1]
video_stream.pix_fmt = 'yuv420p'
# Create an audio stream
audio_sample_rate = 1
audio_stream: Optional[av.AudioStream] = None
if self.__components.audio:
audio_sample_rate = int(self.__components.audio['sample_rate'])
audio_stream = output.add_stream('aac', rate=audio_sample_rate)
audio_stream.sample_rate = audio_sample_rate
audio_stream.format = 'fltp'
# Encode video
for i, frame in enumerate(self.__components.images):
img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3)
frame = av.VideoFrame.from_ndarray(img, format='rgb24')
frame = frame.reformat(format='yuv420p') # Convert to YUV420P as required by h264
packet = video_stream.encode(frame)
output.mux(packet)
# Flush video
packet = video_stream.encode(None)
output.mux(packet)
if audio_stream and self.__components.audio:
# Encode audio
samples_per_frame = int(audio_sample_rate / frame_rate)
num_frames = self.__components.audio['waveform'].shape[2] // samples_per_frame
for i in range(num_frames):
start = i * samples_per_frame
end = start + samples_per_frame
# TODO(Feature) - Add support for stereo audio
chunk = (
self.__components.audio["waveform"][0, 0, start:end]
.unsqueeze(0)
.contiguous()
.numpy()
)
audio_frame = av.AudioFrame.from_ndarray(chunk, format='fltp', layout='mono')
audio_frame.sample_rate = audio_sample_rate
audio_frame.pts = i * samples_per_frame
for packet in audio_stream.encode(audio_frame):
output.mux(packet)
# Flush audio
for packet in audio_stream.encode(None):
output.mux(packet)

View File

@@ -1,150 +0,0 @@
# Internal infrastructure for ComfyAPI
from .api_registry import (
ComfyAPIBase as ComfyAPIBase,
ComfyAPIWithVersion as ComfyAPIWithVersion,
register_versions as register_versions,
get_all_versions as get_all_versions,
)
import asyncio
from dataclasses import asdict
from typing import Callable, Optional
def first_real_override(cls: type, name: str, *, base: type=None) -> Optional[Callable]:
"""Return the *callable* override of `name` visible on `cls`, or None if every
implementation up to (and including) `base` is the placeholder defined on `base`.
If base is not provided, it will assume cls has a GET_BASE_CLASS
"""
if base is None:
if not hasattr(cls, "GET_BASE_CLASS"):
raise ValueError("base is required if cls does not have a GET_BASE_CLASS; is this a valid ComfyNode subclass?")
base = cls.GET_BASE_CLASS()
base_attr = getattr(base, name, None)
if base_attr is None:
return None
base_func = base_attr.__func__
for c in cls.mro(): # NodeB, NodeA, ComfyNode, object …
if c is base: # reached the placeholder we're done
break
if name in c.__dict__: # first class that *defines* the attr
func = getattr(c, name).__func__
if func is not base_func: # real override
return getattr(cls, name) # bound to *cls*
return None
class _ComfyNodeInternal:
"""Class that all V3-based APIs inherit from for ComfyNode.
This is intended to only be referenced within execution.py, as it has to handle all V3 APIs going forward."""
@classmethod
def GET_NODE_INFO_V1(cls):
...
class _NodeOutputInternal:
"""Class that all V3-based APIs inherit from for NodeOutput.
This is intended to only be referenced within execution.py, as it has to handle all V3 APIs going forward."""
...
def as_pruned_dict(dataclass_obj):
'''Return dict of dataclass object with pruned None values.'''
return prune_dict(asdict(dataclass_obj))
def prune_dict(d: dict):
return {k: v for k,v in d.items() if v is not None}
def is_class(obj):
'''
Returns True if is a class type.
Returns False if is a class instance.
'''
return isinstance(obj, type)
def copy_class(cls: type) -> type:
'''
Copy a class and its attributes.
'''
if cls is None:
return None
cls_dict = {
k: v for k, v in cls.__dict__.items()
if k not in ('__dict__', '__weakref__', '__module__', '__doc__')
}
# new class
new_cls = type(
cls.__name__,
(cls,),
cls_dict
)
# metadata preservation
new_cls.__module__ = cls.__module__
new_cls.__doc__ = cls.__doc__
return new_cls
class classproperty(object):
def __init__(self, f):
self.f = f
def __get__(self, obj, owner):
return self.f(owner)
# NOTE: this was ai generated and validated by hand
def shallow_clone_class(cls, new_name=None):
'''
Shallow clone a class while preserving super() functionality.
'''
new_name = new_name or f"{cls.__name__}Clone"
# Include the original class in the bases to maintain proper inheritance
new_bases = (cls,) + cls.__bases__
return type(new_name, new_bases, dict(cls.__dict__))
# NOTE: this was ai generated and validated by hand
def lock_class(cls):
'''
Lock a class so that its top-levelattributes cannot be modified.
'''
# Locked instance __setattr__
def locked_instance_setattr(self, name, value):
raise AttributeError(
f"Cannot set attribute '{name}' on immutable instance of {type(self).__name__}"
)
# Locked metaclass
class LockedMeta(type(cls)):
def __setattr__(cls_, name, value):
raise AttributeError(
f"Cannot modify class attribute '{name}' on locked class '{cls_.__name__}'"
)
# Rebuild class with locked behavior
locked_dict = dict(cls.__dict__)
locked_dict['__setattr__'] = locked_instance_setattr
return LockedMeta(cls.__name__, cls.__bases__, locked_dict)
def make_locked_method_func(type_obj, func, class_clone):
"""
Returns a function that, when called with **inputs, will execute:
getattr(type_obj, func).__func__(lock_class(class_clone), **inputs)
Supports both synchronous and asynchronous methods.
"""
locked_class = lock_class(class_clone)
method = getattr(type_obj, func).__func__
# Check if the original method is async
if asyncio.iscoroutinefunction(method):
async def wrapped_async_func(**inputs):
return await method(locked_class, **inputs)
return wrapped_async_func
else:
def wrapped_func(**inputs):
return method(locked_class, **inputs)
return wrapped_func

View File

@@ -1,39 +0,0 @@
from typing import Type, List, NamedTuple
from comfy_api.internal.singleton import ProxiedSingleton
from packaging import version as packaging_version
class ComfyAPIBase(ProxiedSingleton):
def __init__(self):
pass
class ComfyAPIWithVersion(NamedTuple):
version: str
api_class: Type[ComfyAPIBase]
def parse_version(version_str: str) -> packaging_version.Version:
"""
Parses a version string into a packaging_version.Version object.
Raises ValueError if the version string is invalid.
"""
if version_str == "latest":
return packaging_version.parse("9999999.9999999.9999999")
return packaging_version.parse(version_str)
registered_versions: List[ComfyAPIWithVersion] = []
def register_versions(versions: List[ComfyAPIWithVersion]):
versions.sort(key=lambda x: parse_version(x.version))
global registered_versions
registered_versions = versions
def get_all_versions() -> List[ComfyAPIWithVersion]:
"""
Returns a list of all registered ComfyAPI versions.
"""
return registered_versions

View File

@@ -1,987 +0,0 @@
import asyncio
import concurrent.futures
import contextvars
import functools
import inspect
import logging
import os
import textwrap
import threading
from enum import Enum
from typing import Optional, Type, get_origin, get_args
class TypeTracker:
"""Tracks types discovered during stub generation for automatic import generation."""
def __init__(self):
self.discovered_types = {} # type_name -> (module, qualname)
self.builtin_types = {
"Any",
"Dict",
"List",
"Optional",
"Tuple",
"Union",
"Set",
"Sequence",
"cast",
"NamedTuple",
"str",
"int",
"float",
"bool",
"None",
"bytes",
"object",
"type",
"dict",
"list",
"tuple",
"set",
}
self.already_imported = (
set()
) # Track types already imported to avoid duplicates
def track_type(self, annotation):
"""Track a type annotation and record its module/import info."""
if annotation is None or annotation is type(None):
return
# Skip builtins and typing module types we already import
type_name = getattr(annotation, "__name__", None)
if type_name and (
type_name in self.builtin_types or type_name in self.already_imported
):
return
# Get module and qualname
module = getattr(annotation, "__module__", None)
qualname = getattr(annotation, "__qualname__", type_name or "")
# Skip types from typing module (they're already imported)
if module == "typing":
return
# Skip UnionType and GenericAlias from types module as they're handled specially
if module == "types" and type_name in ("UnionType", "GenericAlias"):
return
if module and module not in ["builtins", "__main__"]:
# Store the type info
if type_name:
self.discovered_types[type_name] = (module, qualname)
def get_imports(self, main_module_name: str) -> list[str]:
"""Generate import statements for all discovered types."""
imports = []
imports_by_module = {}
for type_name, (module, qualname) in sorted(self.discovered_types.items()):
# Skip types from the main module (they're already imported)
if main_module_name and module == main_module_name:
continue
if module not in imports_by_module:
imports_by_module[module] = []
if type_name not in imports_by_module[module]: # Avoid duplicates
imports_by_module[module].append(type_name)
# Generate import statements
for module, types in sorted(imports_by_module.items()):
if len(types) == 1:
imports.append(f"from {module} import {types[0]}")
else:
imports.append(f"from {module} import {', '.join(sorted(set(types)))}")
return imports
class AsyncToSyncConverter:
"""
Provides utilities to convert async classes to sync classes with proper type hints.
"""
_thread_pool: Optional[concurrent.futures.ThreadPoolExecutor] = None
_thread_pool_lock = threading.Lock()
_thread_pool_initialized = False
@classmethod
def get_thread_pool(cls, max_workers=None) -> concurrent.futures.ThreadPoolExecutor:
"""Get or create the shared thread pool with proper thread-safe initialization."""
# Fast path - check if already initialized without acquiring lock
if cls._thread_pool_initialized:
assert cls._thread_pool is not None, "Thread pool should be initialized"
return cls._thread_pool
# Slow path - acquire lock and create pool if needed
with cls._thread_pool_lock:
if not cls._thread_pool_initialized:
cls._thread_pool = concurrent.futures.ThreadPoolExecutor(
max_workers=max_workers, thread_name_prefix="async_to_sync_"
)
cls._thread_pool_initialized = True
# This should never be None at this point, but add assertion for type checker
assert cls._thread_pool is not None
return cls._thread_pool
@classmethod
def run_async_in_thread(cls, coro_func, *args, **kwargs):
"""
Run an async function in a separate thread from the thread pool.
Blocks until the async function completes.
Properly propagates contextvars between threads and manages event loops.
"""
# Capture current context - this includes all context variables
context = contextvars.copy_context()
# Store the result and any exception that occurs
result_container: dict = {"result": None, "exception": None}
# Function that runs in the thread pool
def run_in_thread():
# Create new event loop for this thread
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
# Create the coroutine within the context
async def run_with_context():
# The coroutine function might access context variables
return await coro_func(*args, **kwargs)
# Run the coroutine with the captured context
# This ensures all context variables are available in the async function
result = context.run(loop.run_until_complete, run_with_context())
result_container["result"] = result
except Exception as e:
# Store the exception to re-raise in the calling thread
result_container["exception"] = e
finally:
# Ensure event loop is properly closed to prevent warnings
try:
# Cancel any remaining tasks
pending = asyncio.all_tasks(loop)
for task in pending:
task.cancel()
# Run the loop briefly to handle cancellations
if pending:
loop.run_until_complete(
asyncio.gather(*pending, return_exceptions=True)
)
except Exception:
pass # Ignore errors during cleanup
# Close the event loop
loop.close()
# Clear the event loop from the thread
asyncio.set_event_loop(None)
# Submit to thread pool and wait for result
thread_pool = cls.get_thread_pool()
future = thread_pool.submit(run_in_thread)
future.result() # Wait for completion
# Re-raise any exception that occurred in the thread
if result_container["exception"] is not None:
raise result_container["exception"]
return result_container["result"]
@classmethod
def create_sync_class(cls, async_class: Type, thread_pool_size=10) -> Type:
"""
Creates a new class with synchronous versions of all async methods.
Args:
async_class: The async class to convert
thread_pool_size: Size of thread pool to use
Returns:
A new class with sync versions of all async methods
"""
sync_class_name = "ComfyAPISyncStub"
cls.get_thread_pool(thread_pool_size)
# Create a proper class with docstrings and proper base classes
sync_class_dict = {
"__doc__": async_class.__doc__,
"__module__": async_class.__module__,
"__qualname__": sync_class_name,
"__orig_class__": async_class, # Store original class for typing references
}
# Create __init__ method
def __init__(self, *args, **kwargs):
self._async_instance = async_class(*args, **kwargs)
# Handle annotated class attributes (like execution: Execution)
# Get all annotations from the class hierarchy
all_annotations = {}
for base_class in reversed(inspect.getmro(async_class)):
if hasattr(base_class, "__annotations__"):
all_annotations.update(base_class.__annotations__)
# For each annotated attribute, check if it needs to be created or wrapped
for attr_name, attr_type in all_annotations.items():
if hasattr(self._async_instance, attr_name):
# Attribute exists on the instance
attr = getattr(self._async_instance, attr_name)
# Check if this attribute needs a sync wrapper
if hasattr(attr, "__class__"):
from comfy_api.internal.singleton import ProxiedSingleton
if isinstance(attr, ProxiedSingleton):
# Create a sync version of this attribute
try:
sync_attr_class = cls.create_sync_class(attr.__class__)
# Create instance of the sync wrapper with the async instance
sync_attr = object.__new__(sync_attr_class) # type: ignore
sync_attr._async_instance = attr
setattr(self, attr_name, sync_attr)
except Exception:
# If we can't create a sync version, keep the original
setattr(self, attr_name, attr)
else:
# Not async, just copy the reference
setattr(self, attr_name, attr)
else:
# Attribute doesn't exist, but is annotated - create it
# This handles cases like execution: Execution
if isinstance(attr_type, type):
# Check if the type is defined as an inner class
if hasattr(async_class, attr_type.__name__):
inner_class = getattr(async_class, attr_type.__name__)
from comfy_api.internal.singleton import ProxiedSingleton
# Create an instance of the inner class
try:
# For ProxiedSingleton classes, get or create the singleton instance
if issubclass(inner_class, ProxiedSingleton):
async_instance = inner_class.get_instance()
else:
async_instance = inner_class()
# Create sync wrapper
sync_attr_class = cls.create_sync_class(inner_class)
sync_attr = object.__new__(sync_attr_class) # type: ignore
sync_attr._async_instance = async_instance
setattr(self, attr_name, sync_attr)
# Also set on the async instance for consistency
setattr(self._async_instance, attr_name, async_instance)
except Exception as e:
logging.warning(
f"Failed to create instance for {attr_name}: {e}"
)
# Handle other instance attributes that might not be annotated
for name, attr in inspect.getmembers(self._async_instance):
if name.startswith("_") or hasattr(self, name):
continue
# If attribute is an instance of a class, and that class is defined in the original class
# we need to check if it needs a sync wrapper
if isinstance(attr, object) and not isinstance(
attr, (str, int, float, bool, list, dict, tuple)
):
from comfy_api.internal.singleton import ProxiedSingleton
if isinstance(attr, ProxiedSingleton):
# Create a sync version of this nested class
try:
sync_attr_class = cls.create_sync_class(attr.__class__)
# Create instance of the sync wrapper with the async instance
sync_attr = object.__new__(sync_attr_class) # type: ignore
sync_attr._async_instance = attr
setattr(self, name, sync_attr)
except Exception:
# If we can't create a sync version, keep the original
setattr(self, name, attr)
sync_class_dict["__init__"] = __init__
# Process methods from the async class
for name, method in inspect.getmembers(
async_class, predicate=inspect.isfunction
):
if name.startswith("_"):
continue
# Extract the actual return type from a coroutine
if inspect.iscoroutinefunction(method):
# Create sync version of async method with proper signature
@functools.wraps(method)
def sync_method(self, *args, _method_name=name, **kwargs):
async_method = getattr(self._async_instance, _method_name)
return AsyncToSyncConverter.run_async_in_thread(
async_method, *args, **kwargs
)
# Add to the class dict
sync_class_dict[name] = sync_method
else:
# For regular methods, create a proxy method
@functools.wraps(method)
def proxy_method(self, *args, _method_name=name, **kwargs):
method = getattr(self._async_instance, _method_name)
return method(*args, **kwargs)
# Add to the class dict
sync_class_dict[name] = proxy_method
# Handle property access
for name, prop in inspect.getmembers(
async_class, lambda x: isinstance(x, property)
):
def make_property(name, prop_obj):
def getter(self):
value = getattr(self._async_instance, name)
if inspect.iscoroutinefunction(value):
def sync_fn(*args, **kwargs):
return AsyncToSyncConverter.run_async_in_thread(
value, *args, **kwargs
)
return sync_fn
return value
def setter(self, value):
setattr(self._async_instance, name, value)
return property(getter, setter if prop_obj.fset else None)
sync_class_dict[name] = make_property(name, prop)
# Create the class
sync_class = type(sync_class_name, (object,), sync_class_dict)
return sync_class
@classmethod
def _format_type_annotation(
cls, annotation, type_tracker: Optional[TypeTracker] = None
) -> str:
"""Convert a type annotation to its string representation for stub files."""
if (
annotation is inspect.Parameter.empty
or annotation is inspect.Signature.empty
):
return "Any"
# Handle None type
if annotation is type(None):
return "None"
# Track the type if we have a tracker
if type_tracker:
type_tracker.track_type(annotation)
# Try using typing.get_origin/get_args for Python 3.8+
try:
origin = get_origin(annotation)
args = get_args(annotation)
if origin is not None:
# Track the origin type
if type_tracker:
type_tracker.track_type(origin)
# Get the origin name
origin_name = getattr(origin, "__name__", str(origin))
if "." in origin_name:
origin_name = origin_name.split(".")[-1]
# Special handling for types.UnionType (Python 3.10+ pipe operator)
# Convert to old-style Union for compatibility
if str(origin) == "<class 'types.UnionType'>" or origin_name == "UnionType":
origin_name = "Union"
# Format arguments recursively
if args:
formatted_args = []
for arg in args:
# Track each type in the union
if type_tracker:
type_tracker.track_type(arg)
formatted_args.append(cls._format_type_annotation(arg, type_tracker))
return f"{origin_name}[{', '.join(formatted_args)}]"
else:
return origin_name
except (AttributeError, TypeError):
# Fallback for older Python versions or non-generic types
pass
# Handle generic types the old way for compatibility
if hasattr(annotation, "__origin__") and hasattr(annotation, "__args__"):
origin = annotation.__origin__
origin_name = (
origin.__name__
if hasattr(origin, "__name__")
else str(origin).split("'")[1]
)
# Format each type argument
args = []
for arg in annotation.__args__:
args.append(cls._format_type_annotation(arg, type_tracker))
return f"{origin_name}[{', '.join(args)}]"
# Handle regular types with __name__
if hasattr(annotation, "__name__"):
return annotation.__name__
# Handle special module types (like types from typing module)
if hasattr(annotation, "__module__") and hasattr(annotation, "__qualname__"):
# For types like typing.Literal, typing.TypedDict, etc.
return annotation.__qualname__
# Last resort: string conversion with cleanup
type_str = str(annotation)
# Clean up common patterns more robustly
if type_str.startswith("<class '") and type_str.endswith("'>"):
type_str = type_str[8:-2] # Remove "<class '" and "'>"
# Remove module prefixes for common modules
for prefix in ["typing.", "builtins.", "types."]:
if type_str.startswith(prefix):
type_str = type_str[len(prefix) :]
# Handle special cases
if type_str in ("_empty", "inspect._empty"):
return "None"
# Fix NoneType (this should rarely be needed now)
if type_str == "NoneType":
return "None"
return type_str
@classmethod
def _extract_coroutine_return_type(cls, annotation):
"""Extract the actual return type from a Coroutine annotation."""
if hasattr(annotation, "__args__") and len(annotation.__args__) > 2:
# Coroutine[Any, Any, ReturnType] -> extract ReturnType
return annotation.__args__[2]
return annotation
@classmethod
def _format_parameter_default(cls, default_value) -> str:
"""Format a parameter's default value for stub files."""
if default_value is inspect.Parameter.empty:
return ""
elif default_value is None:
return " = None"
elif isinstance(default_value, bool):
return f" = {default_value}"
elif default_value == {}:
return " = {}"
elif default_value == []:
return " = []"
else:
return f" = {default_value}"
@classmethod
def _format_method_parameters(
cls,
sig: inspect.Signature,
skip_self: bool = True,
type_hints: Optional[dict] = None,
type_tracker: Optional[TypeTracker] = None,
) -> str:
"""Format method parameters for stub files."""
params = []
if type_hints is None:
type_hints = {}
for i, (param_name, param) in enumerate(sig.parameters.items()):
if i == 0 and param_name == "self" and skip_self:
params.append("self")
else:
# Get type annotation from type hints if available, otherwise from signature
annotation = type_hints.get(param_name, param.annotation)
type_str = cls._format_type_annotation(annotation, type_tracker)
# Get default value
default_str = cls._format_parameter_default(param.default)
# Combine parameter parts
if annotation is inspect.Parameter.empty:
params.append(f"{param_name}: Any{default_str}")
else:
params.append(f"{param_name}: {type_str}{default_str}")
return ", ".join(params)
@classmethod
def _generate_method_signature(
cls,
method_name: str,
method,
is_async: bool = False,
type_tracker: Optional[TypeTracker] = None,
) -> str:
"""Generate a complete method signature for stub files."""
sig = inspect.signature(method)
# Try to get evaluated type hints to resolve string annotations
try:
from typing import get_type_hints
type_hints = get_type_hints(method)
except Exception:
# Fallback to empty dict if we can't get type hints
type_hints = {}
# For async methods, extract the actual return type
return_annotation = type_hints.get('return', sig.return_annotation)
if is_async and inspect.iscoroutinefunction(method):
return_annotation = cls._extract_coroutine_return_type(return_annotation)
# Format parameters with type hints
params_str = cls._format_method_parameters(sig, type_hints=type_hints, type_tracker=type_tracker)
# Format return type
return_type = cls._format_type_annotation(return_annotation, type_tracker)
if return_annotation is inspect.Signature.empty:
return_type = "None"
return f"def {method_name}({params_str}) -> {return_type}: ..."
@classmethod
def _generate_imports(
cls, async_class: Type, type_tracker: TypeTracker
) -> list[str]:
"""Generate import statements for the stub file."""
imports = []
# Add standard typing imports
imports.append(
"from typing import Any, Dict, List, Optional, Tuple, Union, Set, Sequence, cast, NamedTuple"
)
# Add imports from the original module
if async_class.__module__ != "builtins":
module = inspect.getmodule(async_class)
additional_types = []
if module:
# Check if module has __all__ defined
module_all = getattr(module, "__all__", None)
for name, obj in sorted(inspect.getmembers(module)):
if isinstance(obj, type):
# Skip if __all__ is defined and this name isn't in it
# unless it's already been tracked as used in type annotations
if module_all is not None and name not in module_all:
# Check if this type was actually used in annotations
if name not in type_tracker.discovered_types:
continue
# Check for NamedTuple
if issubclass(obj, tuple) and hasattr(obj, "_fields"):
additional_types.append(name)
# Mark as already imported
type_tracker.already_imported.add(name)
# Check for Enum
elif issubclass(obj, Enum) and name != "Enum":
additional_types.append(name)
# Mark as already imported
type_tracker.already_imported.add(name)
if additional_types:
type_imports = ", ".join([async_class.__name__] + additional_types)
imports.append(f"from {async_class.__module__} import {type_imports}")
else:
imports.append(
f"from {async_class.__module__} import {async_class.__name__}"
)
# Add imports for all discovered types
# Pass the main module name to avoid duplicate imports
imports.extend(
type_tracker.get_imports(main_module_name=async_class.__module__)
)
# Add base module import if needed
if hasattr(inspect.getmodule(async_class), "__name__"):
module_name = inspect.getmodule(async_class).__name__
if "." in module_name:
base_module = module_name.split(".")[0]
# Only add if not already importing from it
if not any(imp.startswith(f"from {base_module}") for imp in imports):
imports.append(f"import {base_module}")
return imports
@classmethod
def _get_class_attributes(cls, async_class: Type) -> list[tuple[str, Type]]:
"""Extract class attributes that are classes themselves."""
class_attributes = []
# Look for class attributes that are classes
for name, attr in sorted(inspect.getmembers(async_class)):
if isinstance(attr, type) and not name.startswith("_"):
class_attributes.append((name, attr))
elif (
hasattr(async_class, "__annotations__")
and name in async_class.__annotations__
):
annotation = async_class.__annotations__[name]
if isinstance(annotation, type):
class_attributes.append((name, annotation))
return class_attributes
@classmethod
def _generate_inner_class_stub(
cls,
name: str,
attr: Type,
indent: str = " ",
type_tracker: Optional[TypeTracker] = None,
) -> list[str]:
"""Generate stub for an inner class."""
stub_lines = []
stub_lines.append(f"{indent}class {name}Sync:")
# Add docstring if available
if hasattr(attr, "__doc__") and attr.__doc__:
stub_lines.extend(
cls._format_docstring_for_stub(attr.__doc__, f"{indent} ")
)
# Add __init__ if it exists
if hasattr(attr, "__init__"):
try:
init_method = getattr(attr, "__init__")
init_sig = inspect.signature(init_method)
# Try to get type hints
try:
from typing import get_type_hints
init_hints = get_type_hints(init_method)
except Exception:
init_hints = {}
# Format parameters
params_str = cls._format_method_parameters(
init_sig, type_hints=init_hints, type_tracker=type_tracker
)
# Add __init__ docstring if available (before the method)
if hasattr(init_method, "__doc__") and init_method.__doc__:
stub_lines.extend(
cls._format_docstring_for_stub(
init_method.__doc__, f"{indent} "
)
)
stub_lines.append(
f"{indent} def __init__({params_str}) -> None: ..."
)
except (ValueError, TypeError):
stub_lines.append(
f"{indent} def __init__(self, *args, **kwargs) -> None: ..."
)
# Add methods to the inner class
has_methods = False
for method_name, method in sorted(
inspect.getmembers(attr, predicate=inspect.isfunction)
):
if method_name.startswith("_"):
continue
has_methods = True
try:
# Add method docstring if available (before the method signature)
if method.__doc__:
stub_lines.extend(
cls._format_docstring_for_stub(method.__doc__, f"{indent} ")
)
method_sig = cls._generate_method_signature(
method_name, method, is_async=True, type_tracker=type_tracker
)
stub_lines.append(f"{indent} {method_sig}")
except (ValueError, TypeError):
stub_lines.append(
f"{indent} def {method_name}(self, *args, **kwargs): ..."
)
if not has_methods:
stub_lines.append(f"{indent} pass")
return stub_lines
@classmethod
def _format_docstring_for_stub(
cls, docstring: str, indent: str = " "
) -> list[str]:
"""Format a docstring for inclusion in a stub file with proper indentation."""
if not docstring:
return []
# First, dedent the docstring to remove any existing indentation
dedented = textwrap.dedent(docstring).strip()
# Split into lines
lines = dedented.split("\n")
# Build the properly indented docstring
result = []
result.append(f'{indent}"""')
for line in lines:
if line.strip(): # Non-empty line
result.append(f"{indent}{line}")
else: # Empty line
result.append("")
result.append(f'{indent}"""')
return result
@classmethod
def _post_process_stub_content(cls, stub_content: list[str]) -> list[str]:
"""Post-process stub content to fix any remaining issues."""
processed = []
for line in stub_content:
# Skip processing imports
if line.startswith(("from ", "import ")):
processed.append(line)
continue
# Fix method signatures missing return types
if (
line.strip().startswith("def ")
and line.strip().endswith(": ...")
and ") -> " not in line
):
# Add -> None for methods without return annotation
line = line.replace(": ...", " -> None: ...")
processed.append(line)
return processed
@classmethod
def generate_stub_file(cls, async_class: Type, sync_class: Type) -> None:
"""
Generate a .pyi stub file for the sync class to help IDEs with type checking.
"""
try:
# Only generate stub if we can determine module path
if async_class.__module__ == "__main__":
return
module = inspect.getmodule(async_class)
if not module:
return
module_path = module.__file__
if not module_path:
return
# Create stub file path in a 'generated' subdirectory
module_dir = os.path.dirname(module_path)
stub_dir = os.path.join(module_dir, "generated")
# Ensure the generated directory exists
os.makedirs(stub_dir, exist_ok=True)
module_name = os.path.basename(module_path)
if module_name.endswith(".py"):
module_name = module_name[:-3]
sync_stub_path = os.path.join(stub_dir, f"{sync_class.__name__}.pyi")
# Create a type tracker for this stub generation
type_tracker = TypeTracker()
stub_content = []
# We'll generate imports after processing all methods to capture all types
# Leave a placeholder for imports
imports_placeholder_index = len(stub_content)
stub_content.append("") # Will be replaced with imports later
# Class definition
stub_content.append(f"class {sync_class.__name__}:")
# Docstring
if async_class.__doc__:
stub_content.extend(
cls._format_docstring_for_stub(async_class.__doc__, " ")
)
# Generate __init__
try:
init_method = async_class.__init__
init_signature = inspect.signature(init_method)
# Try to get type hints for __init__
try:
from typing import get_type_hints
init_hints = get_type_hints(init_method)
except Exception:
init_hints = {}
# Format parameters
params_str = cls._format_method_parameters(
init_signature, type_hints=init_hints, type_tracker=type_tracker
)
# Add __init__ docstring if available (before the method)
if hasattr(init_method, "__doc__") and init_method.__doc__:
stub_content.extend(
cls._format_docstring_for_stub(init_method.__doc__, " ")
)
stub_content.append(f" def __init__({params_str}) -> None: ...")
except (ValueError, TypeError):
stub_content.append(
" def __init__(self, *args, **kwargs) -> None: ..."
)
stub_content.append("") # Add newline after __init__
# Get class attributes
class_attributes = cls._get_class_attributes(async_class)
# Generate inner classes
for name, attr in class_attributes:
inner_class_stub = cls._generate_inner_class_stub(
name, attr, type_tracker=type_tracker
)
stub_content.extend(inner_class_stub)
stub_content.append("") # Add newline after the inner class
# Add methods to the main class
processed_methods = set() # Keep track of methods we've processed
for name, method in sorted(
inspect.getmembers(async_class, predicate=inspect.isfunction)
):
if name.startswith("_") or name in processed_methods:
continue
processed_methods.add(name)
try:
method_sig = cls._generate_method_signature(
name, method, is_async=True, type_tracker=type_tracker
)
# Add docstring if available (before the method signature for proper formatting)
if method.__doc__:
stub_content.extend(
cls._format_docstring_for_stub(method.__doc__, " ")
)
stub_content.append(f" {method_sig}")
stub_content.append("") # Add newline after each method
except (ValueError, TypeError):
# If we can't get the signature, just add a simple stub
stub_content.append(f" def {name}(self, *args, **kwargs): ...")
stub_content.append("") # Add newline
# Add properties
for name, prop in sorted(
inspect.getmembers(async_class, lambda x: isinstance(x, property))
):
stub_content.append(" @property")
stub_content.append(f" def {name}(self) -> Any: ...")
if prop.fset:
stub_content.append(f" @{name}.setter")
stub_content.append(
f" def {name}(self, value: Any) -> None: ..."
)
stub_content.append("") # Add newline after each property
# Add placeholders for the nested class instances
# Check the actual attribute names from class annotations and attributes
attribute_mappings = {}
# First check annotations for typed attributes (including from parent classes)
# Collect all annotations from the class hierarchy
all_annotations = {}
for base_class in reversed(inspect.getmro(async_class)):
if hasattr(base_class, "__annotations__"):
all_annotations.update(base_class.__annotations__)
for attr_name, attr_type in sorted(all_annotations.items()):
for class_name, class_type in class_attributes:
# If the class type matches the annotated type
if (
attr_type == class_type
or (hasattr(attr_type, "__name__") and attr_type.__name__ == class_name)
or (isinstance(attr_type, str) and attr_type == class_name)
):
attribute_mappings[class_name] = attr_name
# Remove the extra checking - annotations should be sufficient
# Add the attribute declarations with proper names
for class_name, class_type in class_attributes:
# Check if there's a mapping from annotation
attr_name = attribute_mappings.get(class_name, class_name)
# Use the annotation name if it exists, even if the attribute doesn't exist yet
# This is because the attribute might be created at runtime
stub_content.append(f" {attr_name}: {class_name}Sync")
stub_content.append("") # Add a final newline
# Now generate imports with all discovered types
imports = cls._generate_imports(async_class, type_tracker)
# Deduplicate imports while preserving order
seen = set()
unique_imports = []
for imp in imports:
if imp not in seen:
seen.add(imp)
unique_imports.append(imp)
else:
logging.warning(f"Duplicate import detected: {imp}")
# Replace the placeholder with actual imports
stub_content[imports_placeholder_index : imports_placeholder_index + 1] = (
unique_imports
)
# Post-process stub content
stub_content = cls._post_process_stub_content(stub_content)
# Write stub file
with open(sync_stub_path, "w") as f:
f.write("\n".join(stub_content))
logging.info(f"Generated stub file: {sync_stub_path}")
except Exception as e:
# If stub generation fails, log the error but don't break the main functionality
logging.error(
f"Error generating stub file for {sync_class.__name__}: {str(e)}"
)
import traceback
logging.error(traceback.format_exc())
def create_sync_class(async_class: Type, thread_pool_size=10) -> Type:
"""
Creates a sync version of an async class
Args:
async_class: The async class to convert
thread_pool_size: Size of thread pool to use
Returns:
A new class with sync versions of all async methods
"""
return AsyncToSyncConverter.create_sync_class(async_class, thread_pool_size)

View File

@@ -1,33 +0,0 @@
from typing import Type, TypeVar
class SingletonMetaclass(type):
T = TypeVar("T", bound="SingletonMetaclass")
_instances = {}
def __call__(cls, *args, **kwargs):
if cls not in cls._instances:
cls._instances[cls] = super(SingletonMetaclass, cls).__call__(
*args, **kwargs
)
return cls._instances[cls]
def inject_instance(cls: Type[T], instance: T) -> None:
assert cls not in SingletonMetaclass._instances, (
"Cannot inject instance after first instantiation"
)
SingletonMetaclass._instances[cls] = instance
def get_instance(cls: Type[T], *args, **kwargs) -> T:
"""
Gets the singleton instance of the class, creating it if it doesn't exist.
"""
if cls not in SingletonMetaclass._instances:
SingletonMetaclass._instances[cls] = super(
SingletonMetaclass, cls
).__call__(*args, **kwargs)
return cls._instances[cls]
class ProxiedSingleton(object, metaclass=SingletonMetaclass):
def __init__(self):
super().__init__()

View File

@@ -1,124 +0,0 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Type, TYPE_CHECKING
from comfy_api.internal import ComfyAPIBase
from comfy_api.internal.singleton import ProxiedSingleton
from comfy_api.internal.async_to_sync import create_sync_class
from comfy_api.latest._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents
from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents
from comfy_api.latest._io import _IO as io #noqa: F401
from comfy_api.latest._ui import _UI as ui #noqa: F401
# from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401
from comfy_execution.utils import get_executing_context
from comfy_execution.progress import get_progress_state, PreviewImageTuple
from PIL import Image
from comfy.cli_args import args
import numpy as np
class ComfyAPI_latest(ComfyAPIBase):
VERSION = "latest"
STABLE = False
class Execution(ProxiedSingleton):
async def set_progress(
self,
value: float,
max_value: float,
node_id: str | None = None,
preview_image: Image.Image | ImageInput | None = None,
ignore_size_limit: bool = False,
) -> None:
"""
Update the progress bar displayed in the ComfyUI interface.
This function allows custom nodes and API calls to report their progress
back to the user interface, providing visual feedback during long operations.
Migration from previous API: comfy.utils.PROGRESS_BAR_HOOK
"""
executing_context = get_executing_context()
if node_id is None and executing_context is not None:
node_id = executing_context.node_id
if node_id is None:
raise ValueError("node_id must be provided if not in executing context")
# Convert preview_image to PreviewImageTuple if needed
to_display: PreviewImageTuple | Image.Image | ImageInput | None = preview_image
if to_display is not None:
# First convert to PIL Image if needed
if isinstance(to_display, ImageInput):
# Convert ImageInput (torch.Tensor) to PIL Image
# Handle tensor shape [B, H, W, C] -> get first image if batch
tensor = to_display
if len(tensor.shape) == 4:
tensor = tensor[0]
# Convert to numpy array and scale to 0-255
image_np = (tensor.cpu().numpy() * 255).astype(np.uint8)
to_display = Image.fromarray(image_np)
if isinstance(to_display, Image.Image):
# Detect image format from PIL Image
image_format = to_display.format if to_display.format else "JPEG"
# Use None for preview_size if ignore_size_limit is True
preview_size = None if ignore_size_limit else args.preview_size
to_display = (image_format, to_display, preview_size)
get_progress_state().update_progress(
node_id=node_id,
value=value,
max_value=max_value,
image=to_display,
)
execution: Execution
class ComfyExtension(ABC):
async def on_load(self) -> None:
"""
Called when an extension is loaded.
This should be used to initialize any global resources neeeded by the extension.
"""
@abstractmethod
async def get_node_list(self) -> list[type[io.ComfyNode]]:
"""
Returns a list of nodes that this extension provides.
"""
class Input:
Image = ImageInput
Audio = AudioInput
Mask = MaskInput
Latent = LatentInput
Video = VideoInput
class InputImpl:
VideoFromFile = VideoFromFile
VideoFromComponents = VideoFromComponents
class Types:
VideoCodec = VideoCodec
VideoContainer = VideoContainer
VideoComponents = VideoComponents
ComfyAPI = ComfyAPI_latest
# Create a synchronous version of the API
if TYPE_CHECKING:
import comfy_api.latest.generated.ComfyAPISyncStub # type: ignore
ComfyAPISync: Type[comfy_api.latest.generated.ComfyAPISyncStub.ComfyAPISyncStub]
ComfyAPISync = create_sync_class(ComfyAPI_latest)
__all__ = [
"ComfyAPI",
"ComfyAPISync",
"Input",
"InputImpl",
"Types",
"ComfyExtension",
]

View File

@@ -1,10 +0,0 @@
from .basic_types import ImageInput, AudioInput, MaskInput, LatentInput
from .video_types import VideoInput
__all__ = [
"ImageInput",
"AudioInput",
"VideoInput",
"MaskInput",
"LatentInput",
]

View File

@@ -1,42 +0,0 @@
import torch
from typing import TypedDict, List, Optional
ImageInput = torch.Tensor
"""
An image in format [B, H, W, C] where B is the batch size, C is the number of channels,
"""
MaskInput = torch.Tensor
"""
A mask in format [B, H, W] where B is the batch size
"""
class AudioInput(TypedDict):
"""
TypedDict representing audio input.
"""
waveform: torch.Tensor
"""
Tensor in the format [B, C, T] where B is the batch size, C is the number of channels,
"""
sample_rate: int
class LatentInput(TypedDict):
"""
TypedDict representing latent input.
"""
samples: torch.Tensor
"""
Tensor in the format [B, C, H, W] where B is the batch size, C is the number of channels,
H is the height, and W is the width.
"""
noise_mask: Optional[MaskInput]
"""
Optional noise mask tensor in the same format as samples.
"""
batch_index: Optional[List[int]]

View File

@@ -1,85 +0,0 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Optional, Union
import io
import av
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
class VideoInput(ABC):
"""
Abstract base class for video input types.
"""
@abstractmethod
def get_components(self) -> VideoComponents:
"""
Abstract method to get the video components (images, audio, and frame rate).
Returns:
VideoComponents containing images, audio, and frame rate
"""
pass
@abstractmethod
def save_to(
self,
path: str,
format: VideoContainer = VideoContainer.AUTO,
codec: VideoCodec = VideoCodec.AUTO,
metadata: Optional[dict] = None
):
"""
Abstract method to save the video input to a file.
"""
pass
def get_stream_source(self) -> Union[str, io.BytesIO]:
"""
Get a streamable source for the video. This allows processing without
loading the entire video into memory.
Returns:
Either a file path (str) or a BytesIO object that can be opened with av.
Default implementation creates a BytesIO buffer, but subclasses should
override this for better performance when possible.
"""
buffer = io.BytesIO()
self.save_to(buffer)
buffer.seek(0)
return buffer
# Provide a default implementation, but subclasses can provide optimized versions
# if possible.
def get_dimensions(self) -> tuple[int, int]:
"""
Returns the dimensions of the video input.
Returns:
Tuple of (width, height)
"""
components = self.get_components()
return components.images.shape[2], components.images.shape[1]
def get_duration(self) -> float:
"""
Returns the duration of the video in seconds.
Returns:
Duration in seconds
"""
components = self.get_components()
frame_count = components.images.shape[0]
return float(frame_count / components.frame_rate)
def get_container_format(self) -> str:
"""
Returns the container format of the video (e.g., 'mp4', 'mov', 'avi').
Returns:
Container format as string
"""
# Default implementation - subclasses should override for better performance
source = self.get_stream_source()
with av.open(source, mode="r") as container:
return container.format.name

View File

@@ -1,7 +0,0 @@
from .video_types import VideoFromFile, VideoFromComponents
__all__ = [
# Implementations
"VideoFromFile",
"VideoFromComponents",
]

View File

@@ -1,324 +0,0 @@
from __future__ import annotations
from av.container import InputContainer
from av.subtitles.stream import SubtitleStream
from fractions import Fraction
from typing import Optional
from comfy_api.latest._input import AudioInput, VideoInput
import av
import io
import json
import numpy as np
import torch
from comfy_api.latest._util import VideoContainer, VideoCodec, VideoComponents
def container_to_output_format(container_format: str | None) -> str | None:
"""
A container's `format` may be a comma-separated list of formats.
E.g., iso container's `format` may be `mov,mp4,m4a,3gp,3g2,mj2`.
However, writing to a file/stream with `av.open` requires a single format,
or `None` to auto-detect.
"""
if not container_format:
return None # Auto-detect
if "," not in container_format:
return container_format
formats = container_format.split(",")
return formats[0]
def get_open_write_kwargs(
dest: str | io.BytesIO, container_format: str, to_format: str | None
) -> dict:
"""Get kwargs for writing a `VideoFromFile` to a file/stream with `av.open`"""
open_kwargs = {
"mode": "w",
# If isobmff, preserve custom metadata tags (workflow, prompt, extra_pnginfo)
"options": {"movflags": "use_metadata_tags"},
}
is_write_to_buffer = isinstance(dest, io.BytesIO)
if is_write_to_buffer:
# Set output format explicitly, since it cannot be inferred from file extension
if to_format == VideoContainer.AUTO:
to_format = container_format.lower()
elif isinstance(to_format, str):
to_format = to_format.lower()
open_kwargs["format"] = container_to_output_format(to_format)
return open_kwargs
class VideoFromFile(VideoInput):
"""
Class representing video input from a file.
"""
def __init__(self, file: str | io.BytesIO):
"""
Initialize the VideoFromFile object based off of either a path on disk or a BytesIO object
containing the file contents.
"""
self.__file = file
def get_stream_source(self) -> str | io.BytesIO:
"""
Return the underlying file source for efficient streaming.
This avoids unnecessary memory copies when the source is already a file path.
"""
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0)
return self.__file
def get_dimensions(self) -> tuple[int, int]:
"""
Returns the dimensions of the video input.
Returns:
Tuple of (width, height)
"""
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0) # Reset the BytesIO object to the beginning
with av.open(self.__file, mode='r') as container:
for stream in container.streams:
if stream.type == 'video':
assert isinstance(stream, av.VideoStream)
return stream.width, stream.height
raise ValueError(f"No video stream found in file '{self.__file}'")
def get_duration(self) -> float:
"""
Returns the duration of the video in seconds.
Returns:
Duration in seconds
"""
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0)
with av.open(self.__file, mode="r") as container:
if container.duration is not None:
return float(container.duration / av.time_base)
# Fallback: calculate from frame count and frame rate
video_stream = next(
(s for s in container.streams if s.type == "video"), None
)
if video_stream and video_stream.frames and video_stream.average_rate:
return float(video_stream.frames / video_stream.average_rate)
# Last resort: decode frames to count them
if video_stream and video_stream.average_rate:
frame_count = 0
container.seek(0)
for packet in container.demux(video_stream):
for _ in packet.decode():
frame_count += 1
if frame_count > 0:
return float(frame_count / video_stream.average_rate)
raise ValueError(f"Could not determine duration for file '{self.__file}'")
def get_container_format(self) -> str:
"""
Returns the container format of the video (e.g., 'mp4', 'mov', 'avi').
Returns:
Container format as string
"""
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0)
with av.open(self.__file, mode='r') as container:
return container.format.name
def get_components_internal(self, container: InputContainer) -> VideoComponents:
# Get video frames
frames = []
for frame in container.decode(video=0):
img = frame.to_ndarray(format='rgb24') # shape: (H, W, 3)
img = torch.from_numpy(img) / 255.0 # shape: (H, W, 3)
frames.append(img)
images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 3, 0, 0)
# Get frame rate
video_stream = next(s for s in container.streams if s.type == 'video')
frame_rate = Fraction(video_stream.average_rate) if video_stream and video_stream.average_rate else Fraction(1)
# Get audio if available
audio = None
try:
container.seek(0) # Reset the container to the beginning
for stream in container.streams:
if stream.type != 'audio':
continue
assert isinstance(stream, av.AudioStream)
audio_frames = []
for packet in container.demux(stream):
for frame in packet.decode():
assert isinstance(frame, av.AudioFrame)
audio_frames.append(frame.to_ndarray()) # shape: (channels, samples)
if len(audio_frames) > 0:
audio_data = np.concatenate(audio_frames, axis=1) # shape: (channels, total_samples)
audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) # shape: (1, channels, total_samples)
audio = AudioInput({
"waveform": audio_tensor,
"sample_rate": int(stream.sample_rate) if stream.sample_rate else 1,
})
except StopIteration:
pass # No audio stream
metadata = container.metadata
return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata)
def get_components(self) -> VideoComponents:
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0) # Reset the BytesIO object to the beginning
with av.open(self.__file, mode='r') as container:
return self.get_components_internal(container)
raise ValueError(f"No video stream found in file '{self.__file}'")
def save_to(
self,
path: str | io.BytesIO,
format: VideoContainer = VideoContainer.AUTO,
codec: VideoCodec = VideoCodec.AUTO,
metadata: Optional[dict] = None
):
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0) # Reset the BytesIO object to the beginning
with av.open(self.__file, mode='r') as container:
container_format = container.format.name
video_encoding = container.streams.video[0].codec.name if len(container.streams.video) > 0 else None
reuse_streams = True
if format != VideoContainer.AUTO and format not in container_format.split(","):
reuse_streams = False
if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None:
reuse_streams = False
if not reuse_streams:
components = self.get_components_internal(container)
video = VideoFromComponents(components)
return video.save_to(
path,
format=format,
codec=codec,
metadata=metadata
)
streams = container.streams
open_kwargs = get_open_write_kwargs(path, container_format, format)
with av.open(path, **open_kwargs) as output_container:
# Copy over the original metadata
for key, value in container.metadata.items():
if metadata is None or key not in metadata:
output_container.metadata[key] = value
# Add our new metadata
if metadata is not None:
for key, value in metadata.items():
if isinstance(value, str):
output_container.metadata[key] = value
else:
output_container.metadata[key] = json.dumps(value)
# Add streams to the new container
stream_map = {}
for stream in streams:
if isinstance(stream, (av.VideoStream, av.AudioStream, SubtitleStream)):
out_stream = output_container.add_stream_from_template(template=stream, opaque=True)
stream_map[stream] = out_stream
# Write packets to the new container
for packet in container.demux():
if packet.stream in stream_map and packet.dts is not None:
packet.stream = stream_map[packet.stream]
output_container.mux(packet)
class VideoFromComponents(VideoInput):
"""
Class representing video input from tensors.
"""
def __init__(self, components: VideoComponents):
self.__components = components
def get_components(self) -> VideoComponents:
return VideoComponents(
images=self.__components.images,
audio=self.__components.audio,
frame_rate=self.__components.frame_rate
)
def save_to(
self,
path: str,
format: VideoContainer = VideoContainer.AUTO,
codec: VideoCodec = VideoCodec.AUTO,
metadata: Optional[dict] = None
):
if format != VideoContainer.AUTO and format != VideoContainer.MP4:
raise ValueError("Only MP4 format is supported for now")
if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
raise ValueError("Only H264 codec is supported for now")
with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}) as output:
# Add metadata before writing any streams
if metadata is not None:
for key, value in metadata.items():
output.metadata[key] = json.dumps(value)
frame_rate = Fraction(round(self.__components.frame_rate * 1000), 1000)
# Create a video stream
video_stream = output.add_stream('h264', rate=frame_rate)
video_stream.width = self.__components.images.shape[2]
video_stream.height = self.__components.images.shape[1]
video_stream.pix_fmt = 'yuv420p'
# Create an audio stream
audio_sample_rate = 1
audio_stream: Optional[av.AudioStream] = None
if self.__components.audio:
audio_sample_rate = int(self.__components.audio['sample_rate'])
audio_stream = output.add_stream('aac', rate=audio_sample_rate)
audio_stream.sample_rate = audio_sample_rate
audio_stream.format = 'fltp'
# Encode video
for i, frame in enumerate(self.__components.images):
img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3)
frame = av.VideoFrame.from_ndarray(img, format='rgb24')
frame = frame.reformat(format='yuv420p') # Convert to YUV420P as required by h264
packet = video_stream.encode(frame)
output.mux(packet)
# Flush video
packet = video_stream.encode(None)
output.mux(packet)
if audio_stream and self.__components.audio:
# Encode audio
samples_per_frame = int(audio_sample_rate / frame_rate)
num_frames = self.__components.audio['waveform'].shape[2] // samples_per_frame
for i in range(num_frames):
start = i * samples_per_frame
end = start + samples_per_frame
# TODO(Feature) - Add support for stereo audio
chunk = (
self.__components.audio["waveform"][0, 0, start:end]
.unsqueeze(0)
.contiguous()
.numpy()
)
audio_frame = av.AudioFrame.from_ndarray(chunk, format='fltp', layout='mono')
audio_frame.sample_rate = audio_sample_rate
audio_frame.pts = i * samples_per_frame
for packet in audio_stream.encode(audio_frame):
output.mux(packet)
# Flush audio
for packet in audio_stream.encode(None):
output.mux(packet)

File diff suppressed because it is too large Load Diff

View File

@@ -1,72 +0,0 @@
from __future__ import annotations
import comfy.utils
import folder_paths
import logging
from abc import ABC, abstractmethod
from typing import Any
import torch
class ResourceKey(ABC):
Type = Any
def __init__(self):
...
class TorchDictFolderFilename(ResourceKey):
'''Key for requesting a torch file via file_name from a folder category.'''
Type = dict[str, torch.Tensor]
def __init__(self, folder_name: str, file_name: str):
self.folder_name = folder_name
self.file_name = file_name
def __hash__(self):
return hash((self.folder_name, self.file_name))
def __eq__(self, other: object) -> bool:
if not isinstance(other, TorchDictFolderFilename):
return False
return self.folder_name == other.folder_name and self.file_name == other.file_name
def __str__(self):
return f"{self.folder_name} -> {self.file_name}"
class Resources(ABC):
def __init__(self):
...
@abstractmethod
def get(self, key: ResourceKey, default: Any=...) -> Any:
pass
class ResourcesLocal(Resources):
def __init__(self):
super().__init__()
self.local_resources: dict[ResourceKey, Any] = {}
def get(self, key: ResourceKey, default: Any=...) -> Any:
cached = self.local_resources.get(key, None)
if cached is not None:
logging.info(f"Using cached resource '{key}'")
return cached
logging.info(f"Loading resource '{key}'")
to_return = None
if isinstance(key, TorchDictFolderFilename):
if default is ...:
to_return = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise(key.folder_name, key.file_name), safe_load=True)
else:
full_path = folder_paths.get_full_path(key.folder_name, key.file_name)
if full_path is not None:
to_return = comfy.utils.load_torch_file(full_path, safe_load=True)
if to_return is not None:
self.local_resources[key] = to_return
return to_return
if default is not ...:
return default
raise Exception(f"Unsupported resource key type: {type(key)}")
class _RESOURCES:
ResourceKey = ResourceKey
TorchDictFolderFilename = TorchDictFolderFilename
Resources = Resources
ResourcesLocal = ResourcesLocal

View File

@@ -1,463 +0,0 @@
from __future__ import annotations
import json
import os
import random
from io import BytesIO
from typing import Type
import av
import numpy as np
import torch
try:
import torchaudio
TORCH_AUDIO_AVAILABLE = True
except ImportError:
TORCH_AUDIO_AVAILABLE = False
from PIL import Image as PILImage
from PIL.PngImagePlugin import PngInfo
import folder_paths
# used for image preview
from comfy.cli_args import args
from comfy_api.latest._io import ComfyNode, FolderType, Image, _UIOutput
class SavedResult(dict):
def __init__(self, filename: str, subfolder: str, type: FolderType):
super().__init__(filename=filename, subfolder=subfolder,type=type.value)
@property
def filename(self) -> str:
return self["filename"]
@property
def subfolder(self) -> str:
return self["subfolder"]
@property
def type(self) -> FolderType:
return FolderType(self["type"])
class SavedImages(_UIOutput):
"""A UI output class to represent one or more saved images, potentially animated."""
def __init__(self, results: list[SavedResult], is_animated: bool = False):
super().__init__()
self.results = results
self.is_animated = is_animated
def as_dict(self) -> dict:
data = {"images": self.results}
if self.is_animated:
data["animated"] = (True,)
return data
class SavedAudios(_UIOutput):
"""UI wrapper around one or more audio files on disk (FLAC / MP3 / Opus)."""
def __init__(self, results: list[SavedResult]):
super().__init__()
self.results = results
def as_dict(self) -> dict:
return {"audio": self.results}
def _get_directory_by_folder_type(folder_type: FolderType) -> str:
if folder_type == FolderType.input:
return folder_paths.get_input_directory()
if folder_type == FolderType.output:
return folder_paths.get_output_directory()
return folder_paths.get_temp_directory()
class ImageSaveHelper:
"""A helper class with static methods to handle image saving and metadata."""
@staticmethod
def _convert_tensor_to_pil(image_tensor: torch.Tensor) -> PILImage.Image:
"""Converts a single torch tensor to a PIL Image."""
return PILImage.fromarray(np.clip(255.0 * image_tensor.cpu().numpy(), 0, 255).astype(np.uint8))
@staticmethod
def _create_png_metadata(cls: Type[ComfyNode] | None) -> PngInfo | None:
"""Creates a PngInfo object with prompt and extra_pnginfo."""
if args.disable_metadata or cls is None or not cls.hidden:
return None
metadata = PngInfo()
if cls.hidden.prompt:
metadata.add_text("prompt", json.dumps(cls.hidden.prompt))
if cls.hidden.extra_pnginfo:
for x in cls.hidden.extra_pnginfo:
metadata.add_text(x, json.dumps(cls.hidden.extra_pnginfo[x]))
return metadata
@staticmethod
def _create_animated_png_metadata(cls: Type[ComfyNode] | None) -> PngInfo | None:
"""Creates a PngInfo object with prompt and extra_pnginfo for animated PNGs (APNG)."""
if args.disable_metadata or cls is None or not cls.hidden:
return None
metadata = PngInfo()
if cls.hidden.prompt:
metadata.add(
b"comf",
"prompt".encode("latin-1", "strict")
+ b"\0"
+ json.dumps(cls.hidden.prompt).encode("latin-1", "strict"),
after_idat=True,
)
if cls.hidden.extra_pnginfo:
for x in cls.hidden.extra_pnginfo:
metadata.add(
b"comf",
x.encode("latin-1", "strict")
+ b"\0"
+ json.dumps(cls.hidden.extra_pnginfo[x]).encode("latin-1", "strict"),
after_idat=True,
)
return metadata
@staticmethod
def _create_webp_metadata(pil_image: PILImage.Image, cls: Type[ComfyNode] | None) -> PILImage.Exif:
"""Creates EXIF metadata bytes for WebP images."""
exif_data = pil_image.getexif()
if args.disable_metadata or cls is None or cls.hidden is None:
return exif_data
if cls.hidden.prompt is not None:
exif_data[0x0110] = "prompt:{}".format(json.dumps(cls.hidden.prompt)) # EXIF 0x0110 = Model
if cls.hidden.extra_pnginfo is not None:
inital_exif_tag = 0x010F # EXIF 0x010f = Make
for key, value in cls.hidden.extra_pnginfo.items():
exif_data[inital_exif_tag] = "{}:{}".format(key, json.dumps(value))
inital_exif_tag -= 1
return exif_data
@staticmethod
def save_images(
images, filename_prefix: str, folder_type: FolderType, cls: Type[ComfyNode] | None, compress_level = 4,
) -> list[SavedResult]:
"""Saves a batch of images as individual PNG files."""
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
filename_prefix, _get_directory_by_folder_type(folder_type), images[0].shape[1], images[0].shape[0]
)
results = []
metadata = ImageSaveHelper._create_png_metadata(cls)
for batch_number, image_tensor in enumerate(images):
img = ImageSaveHelper._convert_tensor_to_pil(image_tensor)
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
file = f"{filename_with_batch_num}_{counter:05}_.png"
img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=compress_level)
results.append(SavedResult(file, subfolder, folder_type))
counter += 1
return results
@staticmethod
def get_save_images_ui(images, filename_prefix: str, cls: Type[ComfyNode] | None, compress_level=4) -> SavedImages:
"""Saves a batch of images and returns a UI object for the node output."""
return SavedImages(
ImageSaveHelper.save_images(
images,
filename_prefix=filename_prefix,
folder_type=FolderType.output,
cls=cls,
compress_level=compress_level,
)
)
@staticmethod
def save_animated_png(
images, filename_prefix: str, folder_type: FolderType, cls: Type[ComfyNode] | None, fps: float, compress_level: int
) -> SavedResult:
"""Saves a batch of images as a single animated PNG."""
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
filename_prefix, _get_directory_by_folder_type(folder_type), images[0].shape[1], images[0].shape[0]
)
pil_images = [ImageSaveHelper._convert_tensor_to_pil(img) for img in images]
metadata = ImageSaveHelper._create_animated_png_metadata(cls)
file = f"{filename}_{counter:05}_.png"
save_path = os.path.join(full_output_folder, file)
pil_images[0].save(
save_path,
pnginfo=metadata,
compress_level=compress_level,
save_all=True,
duration=int(1000.0 / fps),
append_images=pil_images[1:],
)
return SavedResult(file, subfolder, folder_type)
@staticmethod
def get_save_animated_png_ui(
images, filename_prefix: str, cls: Type[ComfyNode] | None, fps: float, compress_level: int
) -> SavedImages:
"""Saves an animated PNG and returns a UI object for the node output."""
result = ImageSaveHelper.save_animated_png(
images,
filename_prefix=filename_prefix,
folder_type=FolderType.output,
cls=cls,
fps=fps,
compress_level=compress_level,
)
return SavedImages([result], is_animated=len(images) > 1)
@staticmethod
def save_animated_webp(
images,
filename_prefix: str,
folder_type: FolderType,
cls: Type[ComfyNode] | None,
fps: float,
lossless: bool,
quality: int,
method: int,
) -> SavedResult:
"""Saves a batch of images as a single animated WebP."""
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
filename_prefix, _get_directory_by_folder_type(folder_type), images[0].shape[1], images[0].shape[0]
)
pil_images = [ImageSaveHelper._convert_tensor_to_pil(img) for img in images]
pil_exif = ImageSaveHelper._create_webp_metadata(pil_images[0], cls)
file = f"{filename}_{counter:05}_.webp"
pil_images[0].save(
os.path.join(full_output_folder, file),
save_all=True,
duration=int(1000.0 / fps),
append_images=pil_images[1:],
exif=pil_exif,
lossless=lossless,
quality=quality,
method=method,
)
return SavedResult(file, subfolder, folder_type)
@staticmethod
def get_save_animated_webp_ui(
images,
filename_prefix: str,
cls: Type[ComfyNode] | None,
fps: float,
lossless: bool,
quality: int,
method: int,
) -> SavedImages:
"""Saves an animated WebP and returns a UI object for the node output."""
result = ImageSaveHelper.save_animated_webp(
images,
filename_prefix=filename_prefix,
folder_type=FolderType.output,
cls=cls,
fps=fps,
lossless=lossless,
quality=quality,
method=method,
)
return SavedImages([result], is_animated=len(images) > 1)
class AudioSaveHelper:
"""A helper class with static methods to handle audio saving and metadata."""
_OPUS_RATES = [8000, 12000, 16000, 24000, 48000]
@staticmethod
def save_audio(
audio: dict,
filename_prefix: str,
folder_type: FolderType,
cls: Type[ComfyNode] | None,
format: str = "flac",
quality: str = "128k",
) -> list[SavedResult]:
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
filename_prefix, _get_directory_by_folder_type(folder_type)
)
metadata = {}
if not args.disable_metadata and cls is not None:
if cls.hidden.prompt is not None:
metadata["prompt"] = json.dumps(cls.hidden.prompt)
if cls.hidden.extra_pnginfo is not None:
for x in cls.hidden.extra_pnginfo:
metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x])
results = []
for batch_number, waveform in enumerate(audio["waveform"].cpu()):
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
file = f"{filename_with_batch_num}_{counter:05}_.{format}"
output_path = os.path.join(full_output_folder, file)
# Use original sample rate initially
sample_rate = audio["sample_rate"]
# Handle Opus sample rate requirements
if format == "opus":
if sample_rate > 48000:
sample_rate = 48000
elif sample_rate not in AudioSaveHelper._OPUS_RATES:
# Find the next highest supported rate
for rate in sorted(AudioSaveHelper._OPUS_RATES):
if rate > sample_rate:
sample_rate = rate
break
if sample_rate not in AudioSaveHelper._OPUS_RATES: # Fallback if still not supported
sample_rate = 48000
# Resample if necessary
if sample_rate != audio["sample_rate"]:
if not TORCH_AUDIO_AVAILABLE:
raise Exception("torchaudio is not available; cannot resample audio.")
waveform = torchaudio.functional.resample(waveform, audio["sample_rate"], sample_rate)
# Create output with specified format
output_buffer = BytesIO()
output_container = av.open(output_buffer, mode="w", format=format)
# Set metadata on the container
for key, value in metadata.items():
output_container.metadata[key] = value
# Set up the output stream with appropriate properties
if format == "opus":
out_stream = output_container.add_stream("libopus", rate=sample_rate)
if quality == "64k":
out_stream.bit_rate = 64000
elif quality == "96k":
out_stream.bit_rate = 96000
elif quality == "128k":
out_stream.bit_rate = 128000
elif quality == "192k":
out_stream.bit_rate = 192000
elif quality == "320k":
out_stream.bit_rate = 320000
elif format == "mp3":
out_stream = output_container.add_stream("libmp3lame", rate=sample_rate)
if quality == "V0":
# TODO i would really love to support V3 and V5 but there doesn't seem to be a way to set the qscale level, the property below is a bool
out_stream.codec_context.qscale = 1
elif quality == "128k":
out_stream.bit_rate = 128000
elif quality == "320k":
out_stream.bit_rate = 320000
else: # format == "flac":
out_stream = output_container.add_stream("flac", rate=sample_rate)
frame = av.AudioFrame.from_ndarray(
waveform.movedim(0, 1).reshape(1, -1).float().numpy(),
format="flt",
layout="mono" if waveform.shape[0] == 1 else "stereo",
)
frame.sample_rate = sample_rate
frame.pts = 0
output_container.mux(out_stream.encode(frame))
# Flush encoder
output_container.mux(out_stream.encode(None))
# Close containers
output_container.close()
# Write the output to file
output_buffer.seek(0)
with open(output_path, "wb") as f:
f.write(output_buffer.getbuffer())
results.append(SavedResult(file, subfolder, folder_type))
counter += 1
return results
@staticmethod
def get_save_audio_ui(
audio, filename_prefix: str, cls: Type[ComfyNode] | None, format: str = "flac", quality: str = "128k",
) -> SavedAudios:
"""Save and instantly wrap for UI."""
return SavedAudios(
AudioSaveHelper.save_audio(
audio,
filename_prefix=filename_prefix,
folder_type=FolderType.output,
cls=cls,
format=format,
quality=quality,
)
)
class PreviewImage(_UIOutput):
def __init__(self, image: Image.Type, animated: bool = False, cls: Type[ComfyNode] = None, **kwargs):
self.values = ImageSaveHelper.save_images(
image,
filename_prefix="ComfyUI_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for _ in range(5)),
folder_type=FolderType.temp,
cls=cls,
compress_level=1,
)
self.animated = animated
def as_dict(self):
return {
"images": self.values,
"animated": (self.animated,)
}
class PreviewMask(PreviewImage):
def __init__(self, mask: PreviewMask.Type, animated: bool=False, cls: ComfyNode=None, **kwargs):
preview = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
super().__init__(preview, animated, cls, **kwargs)
class PreviewAudio(_UIOutput):
def __init__(self, audio: dict, cls: Type[ComfyNode] = None, **kwargs):
self.values = AudioSaveHelper.save_audio(
audio,
filename_prefix="ComfyUI_temp_" + "".join(random.choice("abcdefghijklmnopqrstuvwxyz") for _ in range(5)),
folder_type=FolderType.temp,
cls=cls,
format="flac",
quality="128k",
)
def as_dict(self) -> dict:
return {"audio": self.values}
class PreviewVideo(_UIOutput):
def __init__(self, values: list[SavedResult | dict], **kwargs):
self.values = values
def as_dict(self):
return {"images": self.values, "animated": (True,)}
class PreviewUI3D(_UIOutput):
def __init__(self, model_file, camera_info, **kwargs):
self.model_file = model_file
self.camera_info = camera_info
def as_dict(self):
return {"result": [self.model_file, self.camera_info]}
class PreviewText(_UIOutput):
def __init__(self, value: str, **kwargs):
self.value = value
def as_dict(self):
return {"text": (self.value,)}
class _UI:
SavedResult = SavedResult
SavedImages = SavedImages
SavedAudios = SavedAudios
ImageSaveHelper = ImageSaveHelper
AudioSaveHelper = AudioSaveHelper
PreviewImage = PreviewImage
PreviewMask = PreviewMask
PreviewAudio = PreviewAudio
PreviewVideo = PreviewVideo
PreviewUI3D = PreviewUI3D
PreviewText = PreviewText

Some files were not shown because too many files have changed in this diff Show More