Compare commits

..

91 Commits

Author SHA1 Message Date
Alexander Piskun
917177e821 move assets related stuff to "app/assets" folder (#10184) 2025-10-03 11:53:26 -07:00
Alexander Piskun
fd6ac0a765 drop PgSQL 14, unite migration for SQLite and PgSQL (#10165) 2025-10-03 11:34:06 -07:00
Alexander Piskun
94941c50b3 move alembic_db inside app folder (#10163) 2025-10-02 15:01:16 -07:00
Jedrzej Kosinski
fbba2e59e5 Satisfy ruff 2025-09-26 20:39:23 -07:00
Jedrzej Kosinski
adccfb2dfd Remove populate_db_with_asset from load_torch_file for now, as nothing yet uses the hashes 2025-09-26 20:33:46 -07:00
Jedrzej Kosinski
9f4c0f3afe Merge branch 'master' into asset-management 2025-09-26 20:24:25 -07:00
Jedrzej Kosinski
ca39552954 Merge branch 'master' into asset-management 2025-09-24 23:44:57 -07:00
Jedrzej Kosinski
4dd843d36f Merge branch 'master' into asset-management 2025-09-18 14:08:20 -07:00
Jedrzej Kosinski
46fdd636de Merge pull request #9545 from bigcat88/asset-management
[Assets] Initial implementation
2025-09-18 14:07:18 -07:00
bigcat88
283cd27bdc final adjustments 2025-09-18 10:05:32 +03:00
bigcat88
1a37d1476d refactor(6): fully batched initial scan 2025-09-17 20:29:29 +03:00
bigcat88
f9602457d6 optimization: initial scan speed(batching metadata[filename]) 2025-09-17 16:47:27 +03:00
bigcat88
85ef08449d optimization: initial scan speed(batching tags) 2025-09-17 14:08:57 +03:00
bigcat88
5b6810a2c6 fixed hash calculation during model loading in ComfyUI 2025-09-17 13:25:56 +03:00
bigcat88
621faaa195 refactor(5): use less DB queries to create seed asset 2025-09-17 10:46:21 +03:00
bigcat88
d0aa64d57b refactor(4): use one query to init DB with all tags for assets 2025-09-16 21:18:18 +03:00
bigcat88
677a0e2508 refactor(3): unite logic for Asset fast check 2025-09-16 20:29:50 +03:00
bigcat88
31ec744317 refactor(2)/fix: skip double checking the existing files during fast check 2025-09-16 19:50:21 +03:00
bigcat88
a336c7c165 refactor(1): use general fast_asset_file_check helper for fast check 2025-09-16 19:19:18 +03:00
bigcat88
77332d3054 optimization: fast scan: commit to the DB in chunks 2025-09-16 14:21:40 +03:00
bigcat88
24a95f5ca4 removed default scanning of "input" and "output" folders; added separate endpoint for test suite. 2025-09-16 11:28:29 +03:00
bigcat88
0be513b213 fix: escape "_" symbol in all other places 2025-09-15 20:26:48 +03:00
bigcat88
f1fb7432a0 fix+test: escape "_" symbol in assets filtering 2025-09-15 19:19:47 +03:00
bigcat88
f3cf99d10c fix+test: escape "_" symbol in tags filtering 2025-09-15 17:29:27 +03:00
bigcat88
5f187fe6fb optimization: make list_unhashed_candidates_under_prefixes single-query instead of N+1 2025-09-15 12:46:35 +03:00
bigcat88
025fc49b4e optimization: DB Queries (Tags) 2025-09-15 10:26:13 +03:00
bigcat88
7becb84341 fixed tests on SQLite file 2025-09-14 23:01:17 +03:00
bigcat88
dda31de690 rework: AssetInfo.name is only a display name 2025-09-14 21:53:44 +03:00
bigcat88
1d970382f0 added final tests 2025-09-14 20:02:28 +03:00
bigcat88
a2fc2bbae4 corrected formatting 2025-09-14 18:12:00 +03:00
bigcat88
a7f2546558 fix: use ".rowcount" instead of ".returning" on SQLite 2025-09-14 17:55:02 +03:00
bigcat88
6cfa94ec58 fixed metadata[filename] feature + new tests for this 2025-09-14 16:28:14 +03:00
bigcat88
a2ec1f7637 simplify code 2025-09-14 15:31:42 +03:00
bigcat88
0b795dc7a7 removed non-needed code 2025-09-14 15:14:24 +03:00
bigcat88
47f7c7ee8c rework + add test for concurrent AssetInfo delete 2025-09-14 15:08:29 +03:00
bigcat88
cdd8d16075 +2 tests for checking Asset downloading logic 2025-09-14 14:57:24 +03:00
bigcat88
37b81e6658 fixed new PgSQL bug 2025-09-14 14:30:38 +03:00
bigcat88
975650060f concurrency upload test + fixed 2 related bugs 2025-09-14 09:39:23 +03:00
bigcat88
4a713654cd added more tests for the Assets logic 2025-09-14 09:10:59 +03:00
bigcat88
9b8e88ba6e added more tests for the Assets logic 2025-09-13 20:09:45 +03:00
bigcat88
bb9ed04758 global refactoring; add support for Assets without the computed hash 2025-09-13 16:39:08 +03:00
Alexander Piskun
934377ac1e removed currently unnecessary "asset_locations" functionality 2025-09-12 14:46:22 +03:00
Alexander Piskun
3c9bf39c20 Merge pull request #1 from bigcat88/asset-management-ci
fix bugs + GH CI tests
2025-09-10 16:31:11 +03:00
bigcat88
0df1ccac6f GitHub CI test for Assets 2025-09-10 16:22:22 +03:00
bigcat88
72548a8ac4 added additional tests; sorted tests 2025-09-10 10:39:55 +03:00
bigcat88
6eaed072c7 add some logic tests 2025-09-10 09:51:06 +03:00
bigcat88
a9096f6c97 removed non-needed code, fix tests, +1 new test 2025-09-09 20:54:11 +03:00
bigcat88
964de8a8ad add more list_assets tests + fix one found bug 2025-09-09 20:35:18 +03:00
bigcat88
1886f10e19 add download tests 2025-09-09 19:30:58 +03:00
bigcat88
357193f7b5 fixed metadata filtering + tests 2025-09-09 19:12:11 +03:00
bigcat88
0ef73e95fd fixed validation error + more tests 2025-09-09 16:02:39 +03:00
bigcat88
faa1e4de17 fixed another test 2025-09-09 15:17:03 +03:00
bigcat88
dfb5703d40 feat: remove Asset when there is no references left + bugfixes + more tests 2025-09-09 15:10:07 +03:00
bigcat88
0e9de2b7c9 feat: add first test 2025-09-08 20:43:45 +03:00
bigcat88
e3311c9229 feat: support for in-memory SQLite databases 2025-09-08 18:15:09 +03:00
bigcat88
3fa0fc496c fix: use UPSERT to eliminate rare race condition during ingesting many small files in parallel 2025-09-08 18:13:32 +03:00
bigcat88
6282d495ca corrected detection of missing files for assets 2025-09-07 22:08:38 +03:00
bigcat88
b8ef9bb92c add detection of the missing files for existing assets 2025-09-07 16:49:39 +03:00
bigcat88
2d9be462d3 add support for assets duplicates 2025-09-06 19:22:51 +03:00
bigcat88
789a62ce35 assume that DB packages always present; refactoring & cleanup 2025-09-06 17:44:01 +03:00
bigcat88
84384ca0b4 temporary restore ModelManager 2025-09-05 23:02:26 +03:00
bigcat88
ce270ba090 added Assets Autoscan feature 2025-09-05 17:46:09 +03:00
bigcat88
bf8363ec87 always autofill "filename" in the metadata 2025-08-29 19:48:42 +03:00
bigcat88
6b86be320a use UUID instead of autoincrement Integer for Assets ID field 2025-08-28 08:22:54 +03:00
bigcat88
bdf4ba24ce removed not needed "assets.updated_at" column 2025-08-27 21:58:17 +03:00
bigcat88
871e41aec6 removed not needed "refcount" column 2025-08-27 21:36:31 +03:00
bigcat88
eb7008a4d3 removed not used "added_by" column 2025-08-27 21:26:35 +03:00
bigcat88
0379eff0b5 allow Upload Asset endpoint to accept hash (as documentation requires) 2025-08-27 21:18:26 +03:00
bigcat88
026b7f209c add "--multi-user" support 2025-08-27 19:47:55 +03:00
bigcat88
7c1b0be496 add Get Asset endpoint 2025-08-27 09:58:12 +03:00
bigcat88
6fade5da38 add AssetsResolver support 2025-08-26 20:58:04 +03:00
bigcat88
a763cbd39d add upload asset endpoint 2025-08-25 16:35:29 +03:00
bigcat88
09dabf95bc refactoring: use the same code for "scan task" and realtime DB population 2025-08-25 13:31:56 +03:00
bigcat88
d7464e9e73 implemented assets scaner 2025-08-24 19:29:21 +03:00
bigcat88
a82577f64a auto-creation of tags and fixed population DB when cloned asset is already present 2025-08-24 16:36:01 +03:00
bigcat88
f2ea0bc22c added create_asset_from_hash endpoint 2025-08-24 14:15:21 +03:00
bigcat88
0755e5320a remove timezone; download asset, delete asset endpoints 2025-08-24 12:36:20 +03:00
bigcat88
8d46bec951 use Pydantic for output; finished Tags endpoints 2025-08-24 11:02:30 +03:00
bigcat88
5c1b5973ac dev: refactor; populate models in more nodes; use Pydantic in endpoints for input validation 2025-08-23 20:14:22 +03:00
bigcat88
f92307cd4c dev: Everything is Assets 2025-08-23 19:21:52 +03:00
Jedrzej Kosinski
c708d0a433 Merge branch 'master' into asset-management 2025-08-18 12:12:15 -07:00
Jedrzej Kosinski
1aa089e0b6 More progress on brainstorming code for asset management for models 2025-08-12 17:59:16 -07:00
Jedrzej Kosinski
f032c1a50a Brainstorming abstraction for asset management stuff 2025-08-11 22:25:42 -07:00
Jedrzej Kosinski
3089936a2c Merge branch 'master' into pysssss-model-db 2025-08-11 14:09:21 -07:00
Jedrzej Kosinski
cd679129e3 Merge branch 'master' into pysssss-model-db 2025-08-07 21:12:30 -07:00
pythongosssss
d7062277a7 fix bad merge 2025-08-03 16:40:27 +01:00
pythongosssss
54cf14cbbb Merge remote-tracking branch 'origin/master' into pysssss-model-db 2025-08-03 16:36:49 +01:00
pythongosssss
7d5160f92c Tidy 2025-06-01 15:45:15 +01:00
pythongosssss
7f7b3f1695 tidy 2025-06-01 15:41:00 +01:00
pythongosssss
9da6aca0d0 Add additional db model metadata fields and model downloading function 2025-06-01 15:32:13 +01:00
pythongosssss
1cb3c98947 Implement database & model hashing 2025-06-01 15:32:02 +01:00
223 changed files with 21097 additions and 18597 deletions

View File

@@ -1,27 +0,0 @@
As of the time of writing this you need this preview driver for best results:
https://www.amd.com/en/resources/support-articles/release-notes/RN-AMDGPU-WINDOWS-PYTORCH-PREVIEW.html
HOW TO RUN:
If you have a AMD gpu:
run_amd_gpu.bat
If you have memory issues you can try disabling the smart memory management by running comfyui with:
run_amd_gpu_disable_smart_memory.bat
IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints
You can download the stable diffusion XL one from: https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0_0.9vae.safetensors
RECOMMENDED WAY TO UPDATE:
To update the ComfyUI code: update\update_comfyui.bat
TO SHARE MODELS BETWEEN COMFYUI AND ANOTHER UI:
In the ComfyUI directory you will find a file: extra_model_paths.yaml.example
Rename this file to: extra_model_paths.yaml and edit it with your favorite text editor.

View File

@@ -1,2 +1,2 @@
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --disable-smart-memory
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast fp16_accumulation
pause

View File

@@ -1,3 +0,0 @@
..\python_embeded\python.exe -s ..\ComfyUI\main.py --windows-standalone-build --disable-api-nodes
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
pause

View File

@@ -1,3 +0,0 @@
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
pause

View File

@@ -1,3 +0,0 @@
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast fp16_accumulation
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
pause

View File

@@ -8,15 +8,13 @@ body:
Before submitting a **Bug Report**, please ensure the following:
- **1:** You are running the latest version of ComfyUI.
- **2:** You have your ComfyUI logs and relevant workflow on hand and will post them in this bug report.
- **2:** You have looked at the existing bug reports and made sure this isn't already reported.
- **3:** You confirmed that the bug is not caused by a custom node. You can disable all custom nodes by passing
`--disable-all-custom-nodes` command line argument. If you have custom node try updating them to the latest version.
`--disable-all-custom-nodes` command line argument.
- **4:** This is an actual bug in ComfyUI, not just a support question. A bug is when you can specify exact
steps to replicate what went wrong and others will be able to repeat your steps and see the same issue happen.
## Very Important
Please make sure that you post ALL your ComfyUI logs in the bug report. A bug report without logs will likely be ignored.
If unsure, ask on the [ComfyUI Matrix Space](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) or the [Comfy Org Discord](https://discord.gg/comfyorg) first.
- type: checkboxes
id: custom-nodes-test
attributes:

View File

@@ -1,21 +0,0 @@
<!-- API_NODE_PR_CHECKLIST: do not remove -->
## API Node PR Checklist
### Scope
- [ ] **Is API Node Change**
### Pricing & Billing
- [ ] **Need pricing update**
- [ ] **No pricing update**
If **Need pricing update**:
- [ ] Metronome rate cards updated
- [ ] Autobilling tests updated and passing
### QA
- [ ] **QA done**
- [ ] **QA not required**
### Comms
- [ ] Informed **Kosinkadink**

View File

@@ -1,58 +0,0 @@
name: Append API Node PR template
on:
pull_request_target:
types: [opened, reopened, synchronize, ready_for_review]
paths:
- 'comfy_api_nodes/**' # only run if these files changed
permissions:
contents: read
pull-requests: write
jobs:
inject:
runs-on: ubuntu-latest
steps:
- name: Ensure template exists and append to PR body
uses: actions/github-script@v7
with:
script: |
const { owner, repo } = context.repo;
const number = context.payload.pull_request.number;
const templatePath = '.github/PULL_REQUEST_TEMPLATE/api-node.md';
const marker = '<!-- API_NODE_PR_CHECKLIST: do not remove -->';
const { data: pr } = await github.rest.pulls.get({ owner, repo, pull_number: number });
let templateText;
try {
const res = await github.rest.repos.getContent({
owner,
repo,
path: templatePath,
ref: pr.base.ref
});
const buf = Buffer.from(res.data.content, res.data.encoding || 'base64');
templateText = buf.toString('utf8');
} catch (e) {
core.setFailed(`Required PR template not found at "${templatePath}" on ${pr.base.ref}. Please add it to the repo.`);
return;
}
// Enforce the presence of the marker inside the template (for idempotence)
if (!templateText.includes(marker)) {
core.setFailed(`Template at "${templatePath}" does not contain the required marker:\n${marker}\nAdd it so we can detect duplicates safely.`);
return;
}
// If the PR already contains the marker, do not append again.
const body = pr.body || '';
if (body.includes(marker)) {
core.info('Template already present in PR body; nothing to inject.');
return;
}
const newBody = (body ? body + '\n\n' : '') + templateText + '\n';
await github.rest.pulls.update({ owner, repo, pull_number: number, body: newBody });
core.notice('API Node template appended to PR description.');

View File

@@ -1,78 +0,0 @@
name: "Release Stable All Portable Versions"
on:
workflow_dispatch:
inputs:
git_tag:
description: 'Git tag'
required: true
type: string
jobs:
release_nvidia_default:
permissions:
contents: "write"
packages: "write"
pull-requests: "read"
name: "Release NVIDIA Default (cu130)"
uses: ./.github/workflows/stable-release.yml
with:
git_tag: ${{ inputs.git_tag }}
cache_tag: "cu130"
python_minor: "13"
python_patch: "9"
rel_name: "nvidia"
rel_extra_name: ""
test_release: true
secrets: inherit
release_nvidia_cu128:
permissions:
contents: "write"
packages: "write"
pull-requests: "read"
name: "Release NVIDIA cu128"
uses: ./.github/workflows/stable-release.yml
with:
git_tag: ${{ inputs.git_tag }}
cache_tag: "cu128"
python_minor: "12"
python_patch: "10"
rel_name: "nvidia"
rel_extra_name: "_cu128"
test_release: true
secrets: inherit
release_nvidia_cu126:
permissions:
contents: "write"
packages: "write"
pull-requests: "read"
name: "Release NVIDIA cu126"
uses: ./.github/workflows/stable-release.yml
with:
git_tag: ${{ inputs.git_tag }}
cache_tag: "cu126"
python_minor: "12"
python_patch: "10"
rel_name: "nvidia"
rel_extra_name: "_cu126"
test_release: true
secrets: inherit
release_amd_rocm:
permissions:
contents: "write"
packages: "write"
pull-requests: "read"
name: "Release AMD ROCm 6.4.4"
uses: ./.github/workflows/stable-release.yml
with:
git_tag: ${{ inputs.git_tag }}
cache_tag: "rocm644"
python_minor: "12"
python_patch: "10"
rel_name: "amd"
rel_extra_name: ""
test_release: false
secrets: inherit

View File

@@ -21,28 +21,3 @@ jobs:
- name: Run Ruff
run: ruff check .
pylint:
name: Run Pylint
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.12'
- name: Install requirements
run: |
python -m pip install --upgrade pip
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install -r requirements.txt
- name: Install Pylint
run: pip install pylint
- name: Run Pylint
run: pylint comfy_api_nodes

View File

@@ -2,53 +2,17 @@
name: "Release Stable Version"
on:
workflow_call:
inputs:
git_tag:
description: 'Git tag'
required: true
type: string
cache_tag:
description: 'Cached dependencies tag'
required: true
type: string
default: "cu129"
python_minor:
description: 'Python minor version'
required: true
type: string
default: "13"
python_patch:
description: 'Python patch version'
required: true
type: string
default: "6"
rel_name:
description: 'Release name'
required: true
type: string
default: "nvidia"
rel_extra_name:
description: 'Release extra name'
required: false
type: string
default: ""
test_release:
description: 'Test Release'
required: true
type: boolean
default: true
workflow_dispatch:
inputs:
git_tag:
description: 'Git tag'
required: true
type: string
cache_tag:
description: 'Cached dependencies tag'
cu:
description: 'CUDA version'
required: true
type: string
default: "cu129"
default: "129"
python_minor:
description: 'Python minor version'
required: true
@@ -59,21 +23,7 @@ on:
required: true
type: string
default: "6"
rel_name:
description: 'Release name'
required: true
type: string
default: "nvidia"
rel_extra_name:
description: 'Release extra name'
required: false
type: string
default: ""
test_release:
description: 'Test Release'
required: true
type: boolean
default: true
jobs:
package_comfy_windows:
@@ -92,15 +42,15 @@ jobs:
id: cache
with:
path: |
${{ inputs.cache_tag }}_python_deps.tar
cu${{ inputs.cu }}_python_deps.tar
update_comfyui_and_python_dependencies.bat
key: ${{ runner.os }}-build-${{ inputs.cache_tag }}-${{ inputs.python_minor }}
key: ${{ runner.os }}-build-cu${{ inputs.cu }}-${{ inputs.python_minor }}
- shell: bash
run: |
mv ${{ inputs.cache_tag }}_python_deps.tar ../
mv cu${{ inputs.cu }}_python_deps.tar ../
mv update_comfyui_and_python_dependencies.bat ../
cd ..
tar xf ${{ inputs.cache_tag }}_python_deps.tar
tar xf cu${{ inputs.cu }}_python_deps.tar
pwd
ls
@@ -115,19 +65,12 @@ jobs:
echo 'import site' >> ./python3${{ inputs.python_minor }}._pth
curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
./python.exe get-pip.py
./python.exe -s -m pip install ../${{ inputs.cache_tag }}_python_deps/*
grep comfyui ../ComfyUI/requirements.txt > ./requirements_comfyui.txt
./python.exe -s -m pip install -r requirements_comfyui.txt
rm requirements_comfyui.txt
./python.exe -s -m pip install ../cu${{ inputs.cu }}_python_deps/*
sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
if test -f ./Lib/site-packages/torch/lib/dnnl.lib; then
rm ./Lib/site-packages/torch/lib/dnnl.lib #I don't think this is actually used and I need the space
rm ./Lib/site-packages/torch/lib/libprotoc.lib
rm ./Lib/site-packages/torch/lib/libprotobuf.lib
fi
rm ./Lib/site-packages/torch/lib/dnnl.lib #I don't think this is actually used and I need the space
rm ./Lib/site-packages/torch/lib/libprotoc.lib
rm ./Lib/site-packages/torch/lib/libprotobuf.lib
cd ..
@@ -142,18 +85,14 @@ jobs:
mkdir update
cp -r ComfyUI/.ci/update_windows/* ./update/
cp -r ComfyUI/.ci/windows_${{ inputs.rel_name }}_base_files/* ./
cp -r ComfyUI/.ci/windows_base_files/* ./
cp ../update_comfyui_and_python_dependencies.bat ./update/
cd ..
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=768m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
mv ComfyUI_windows_portable.7z ComfyUI/ComfyUI_windows_portable_${{ inputs.rel_name }}${{ inputs.rel_extra_name }}.7z
mv ComfyUI_windows_portable.7z ComfyUI/ComfyUI_windows_portable_nvidia.7z
- shell: bash
if: ${{ inputs.test_release }}
run: |
cd ..
cd ComfyUI_windows_portable
python_embeded/python.exe -s ComfyUI/main.py --quick-test-for-ci --cpu
@@ -162,9 +101,10 @@ jobs:
ls
- name: Upload binaries to release
uses: softprops/action-gh-release@v2
uses: svenstaro/upload-release-action@v2
with:
files: ComfyUI_windows_portable_${{ inputs.rel_name }}${{ inputs.rel_extra_name }}.7z
tag_name: ${{ inputs.git_tag }}
repo_token: ${{ secrets.GITHUB_TOKEN }}
file: ComfyUI_windows_portable_nvidia.7z
tag: ${{ inputs.git_tag }}
overwrite: true
draft: true
overwrite_files: true

173
.github/workflows/test-assets.yml vendored Normal file
View File

@@ -0,0 +1,173 @@
name: Asset System Tests
on:
push:
paths:
- 'app/**'
- 'tests-assets/**'
- '.github/workflows/test-assets.yml'
- 'requirements.txt'
pull_request:
branches: [master]
workflow_dispatch:
permissions:
contents: read
env:
PIP_DISABLE_PIP_VERSION_CHECK: '1'
PYTHONUNBUFFERED: '1'
jobs:
sqlite:
name: SQLite (${{ matrix.sqlite_mode }}) • Python ${{ matrix.python }}
runs-on: ubuntu-latest
timeout-minutes: 40
strategy:
fail-fast: false
matrix:
python: ['3.9', '3.12']
sqlite_mode: ['memory', 'file']
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python }}
- name: Install dependencies
run: |
python -m pip install -U pip wheel
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install -r requirements.txt
pip install pytest pytest-aiohttp pytest-asyncio
- name: Set deterministic test base dir
id: basedir
shell: bash
run: |
BASE="$RUNNER_TEMP/comfyui-assets-tests-${{ matrix.python }}-${{ matrix.sqlite_mode }}-${{ github.run_id }}-${{ github.run_attempt }}"
echo "ASSETS_TEST_BASE_DIR=$BASE" >> "$GITHUB_ENV"
echo "ASSETS_TEST_LOGS=$BASE/logs" >> "$GITHUB_ENV"
mkdir -p "$BASE/logs"
echo "ASSETS_TEST_BASE_DIR=$BASE"
- name: Set DB URL for SQLite
id: setdb
shell: bash
run: |
if [ "${{ matrix.sqlite_mode }}" = "memory" ]; then
echo "ASSETS_TEST_DB_URL=sqlite+aiosqlite:///:memory:" >> "$GITHUB_ENV"
else
DBFILE="$RUNNER_TEMP/assets-tests.sqlite"
mkdir -p "$(dirname "$DBFILE")"
echo "ASSETS_TEST_DB_URL=sqlite+aiosqlite:///$DBFILE" >> "$GITHUB_ENV"
fi
- name: Run tests
run: python -m pytest tests-assets
- name: Show ComfyUI logs
if: always()
shell: bash
run: |
echo "==== ASSETS_TEST_BASE_DIR: $ASSETS_TEST_BASE_DIR ===="
echo "==== ASSETS_TEST_LOGS: $ASSETS_TEST_LOGS ===="
ls -la "$ASSETS_TEST_LOGS" || true
for f in "$ASSETS_TEST_LOGS"/stdout.log "$ASSETS_TEST_LOGS"/stderr.log; do
if [ -f "$f" ]; then
echo "----- BEGIN $f -----"
sed -n '1,400p' "$f"
echo "----- END $f -----"
fi
done
- name: Upload ComfyUI logs
if: always()
uses: actions/upload-artifact@v4
with:
name: asset-logs-sqlite-${{ matrix.sqlite_mode }}-py${{ matrix.python }}
path: ${{ env.ASSETS_TEST_LOGS }}/*.log
if-no-files-found: warn
postgres:
name: PostgreSQL ${{ matrix.pgsql }} • Python ${{ matrix.python }}
runs-on: ubuntu-latest
timeout-minutes: 40
strategy:
fail-fast: false
matrix:
python: ['3.9', '3.12']
pgsql: ['16', '18']
services:
postgres:
image: postgres:${{ matrix.pgsql }}
env:
POSTGRES_DB: assets
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
ports:
- 5432:5432
options: >-
--health-cmd "pg_isready -U postgres -d assets"
--health-interval 10s
--health-timeout 5s
--health-retries 12
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python }}
- name: Install dependencies
run: |
python -m pip install -U pip wheel
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install -r requirements.txt
pip install pytest pytest-aiohttp pytest-asyncio
pip install greenlet psycopg
- name: Set deterministic test base dir
id: basedir
shell: bash
run: |
BASE="$RUNNER_TEMP/comfyui-assets-tests-${{ matrix.python }}-${{ matrix.sqlite_mode }}-${{ github.run_id }}-${{ github.run_attempt }}"
echo "ASSETS_TEST_BASE_DIR=$BASE" >> "$GITHUB_ENV"
echo "ASSETS_TEST_LOGS=$BASE/logs" >> "$GITHUB_ENV"
mkdir -p "$BASE/logs"
echo "ASSETS_TEST_BASE_DIR=$BASE"
- name: Set DB URL for PostgreSQL
shell: bash
run: |
echo "ASSETS_TEST_DB_URL=postgresql+psycopg://postgres:postgres@localhost:5432/assets" >> "$GITHUB_ENV"
- name: Run tests
run: python -m pytest tests-assets
- name: Show ComfyUI logs
if: always()
shell: bash
run: |
echo "==== ASSETS_TEST_BASE_DIR: $ASSETS_TEST_BASE_DIR ===="
echo "==== ASSETS_TEST_LOGS: $ASSETS_TEST_LOGS ===="
ls -la "$ASSETS_TEST_LOGS" || true
for f in "$ASSETS_TEST_LOGS"/stdout.log "$ASSETS_TEST_LOGS"/stderr.log; do
if [ -f "$f" ]; then
echo "----- BEGIN $f -----"
sed -n '1,400p' "$f"
echo "----- END $f -----"
fi
done
- name: Upload ComfyUI logs
if: always()
uses: actions/upload-artifact@v4
with:
name: asset-logs-pgsql-${{ matrix.pgsql }}-py${{ matrix.python }}
path: ${{ env.ASSETS_TEST_LOGS }}/*.log
if-no-files-found: warn

View File

@@ -21,15 +21,14 @@ jobs:
fail-fast: false
matrix:
# os: [macos, linux, windows]
# os: [macos, linux]
os: [linux]
python_version: ["3.10", "3.11", "3.12"]
os: [macos, linux]
python_version: ["3.9", "3.10", "3.11", "3.12"]
cuda_version: ["12.1"]
torch_version: ["stable"]
include:
# - os: macos
# runner_label: [self-hosted, macOS]
# flags: "--use-pytorch-cross-attention"
- os: macos
runner_label: [self-hosted, macOS]
flags: "--use-pytorch-cross-attention"
- os: linux
runner_label: [self-hosted, Linux]
flags: ""
@@ -74,15 +73,14 @@ jobs:
strategy:
fail-fast: false
matrix:
# os: [macos, linux]
os: [linux]
os: [macos, linux]
python_version: ["3.11"]
cuda_version: ["12.1"]
torch_version: ["nightly"]
include:
# - os: macos
# runner_label: [self-hosted, macOS]
# flags: "--use-pytorch-cross-attention"
- os: macos
runner_label: [self-hosted, macOS]
flags: "--use-pytorch-cross-attention"
- os: linux
runner_label: [self-hosted, Linux]
flags: ""

View File

@@ -17,7 +17,7 @@ on:
description: 'cuda version'
required: true
type: string
default: "130"
default: "129"
python_minor:
description: 'python minor version'
@@ -29,7 +29,7 @@ on:
description: 'python patch version'
required: true
type: string
default: "9"
default: "6"
# push:
# branches:
# - master
@@ -56,8 +56,7 @@ jobs:
..\python_embeded\python.exe -s -m pip install --upgrade torch torchvision torchaudio ${{ inputs.xformers }} --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2
pause" > update_comfyui_and_python_dependencies.bat
grep -v comfyui requirements.txt > requirements_nocomfyui.txt
python -m pip wheel --no-cache-dir torch torchvision torchaudio ${{ inputs.xformers }} ${{ inputs.extra_dependencies }} --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cu }} -r requirements_nocomfyui.txt pygit2 -w ./temp_wheel_dir
python -m pip wheel --no-cache-dir torch torchvision torchaudio ${{ inputs.xformers }} ${{ inputs.extra_dependencies }} --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cu }} -r requirements.txt pygit2 -w ./temp_wheel_dir
python -m pip install --no-cache-dir ./temp_wheel_dir/*
echo installed basic
ls -lah temp_wheel_dir

View File

@@ -1,64 +0,0 @@
name: "Windows Release dependencies Manual"
on:
workflow_dispatch:
inputs:
torch_dependencies:
description: 'torch dependencies'
required: false
type: string
default: "torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu128"
cache_tag:
description: 'Cached dependencies tag'
required: true
type: string
default: "cu128"
python_minor:
description: 'python minor version'
required: true
type: string
default: "12"
python_patch:
description: 'python patch version'
required: true
type: string
default: "10"
jobs:
build_dependencies:
runs-on: windows-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: 3.${{ inputs.python_minor }}.${{ inputs.python_patch }}
- shell: bash
run: |
echo "@echo off
call update_comfyui.bat nopause
echo -
echo This will try to update pytorch and all python dependencies.
echo -
echo If you just want to update normally, close this and run update_comfyui.bat instead.
echo -
pause
..\python_embeded\python.exe -s -m pip install --upgrade ${{ inputs.torch_dependencies }} -r ../ComfyUI/requirements.txt pygit2
pause" > update_comfyui_and_python_dependencies.bat
grep -v comfyui requirements.txt > requirements_nocomfyui.txt
python -m pip wheel --no-cache-dir ${{ inputs.torch_dependencies }} -r requirements_nocomfyui.txt pygit2 -w ./temp_wheel_dir
python -m pip install --no-cache-dir ./temp_wheel_dir/*
echo installed basic
ls -lah temp_wheel_dir
mv temp_wheel_dir ${{ inputs.cache_tag }}_python_deps
tar cf ${{ inputs.cache_tag }}_python_deps.tar ${{ inputs.cache_tag }}_python_deps
- uses: actions/cache/save@v4
with:
path: |
${{ inputs.cache_tag }}_python_deps.tar
update_comfyui_and_python_dependencies.bat
key: ${{ runner.os }}-build-${{ inputs.cache_tag }}-${{ inputs.python_minor }}

View File

@@ -68,7 +68,7 @@ jobs:
mkdir update
cp -r ComfyUI/.ci/update_windows/* ./update/
cp -r ComfyUI/.ci/windows_nvidia_base_files/* ./
cp -r ComfyUI/.ci/windows_base_files/* ./
cp -r ComfyUI/.ci/windows_nightly_base_files/* ./
echo "call update_comfyui.bat nopause

View File

@@ -81,7 +81,7 @@ jobs:
mkdir update
cp -r ComfyUI/.ci/update_windows/* ./update/
cp -r ComfyUI/.ci/windows_nvidia_base_files/* ./
cp -r ComfyUI/.ci/windows_base_files/* ./
cp ../update_comfyui_and_python_dependencies.bat ./update/
cd ..

View File

@@ -1,168 +0,0 @@
# The Comfy guide to Quantization
## How does quantization work?
Quantization aims to map a high-precision value x_f to a lower precision format with minimal loss in accuracy. These smaller formats then serve to reduce the models memory footprint and increase throughput by using specialized hardware.
When simply converting a value from FP16 to FP8 using the round-nearest method we might hit two issues:
- The dynamic range of FP16 (-65,504, 65,504) far exceeds FP8 formats like E4M3 (-448, 448) or E5M2 (-57,344, 57,344), potentially resulting in clipped values
- The original values are concentrated in a small range (e.g. -1,1) leaving many FP8-bits "unused"
By using a scaling factor, we aim to map these values into the quantized-dtype range, making use of the full spectrum. One of the easiest approaches, and common, is using per-tensor absolute-maximum scaling.
```
absmax = max(abs(tensor))
scale = amax / max_dynamic_range_low_precision
# Quantization
tensor_q = (tensor / scale).to(low_precision_dtype)
# De-Quantization
tensor_dq = tensor_q.to(fp16) * scale
tensor_dq ~ tensor
```
Given that additional information (scaling factor) is needed to "interpret" the quantized values, we describe those as derived datatypes.
## Quantization in Comfy
```
QuantizedTensor (torch.Tensor subclass)
↓ __torch_dispatch__
Two-Level Registry (generic + layout handlers)
MixedPrecisionOps + Metadata Detection
```
### Representation
To represent these derived datatypes, ComfyUI uses a subclass of torch.Tensor to implements these using the `QuantizedTensor` class found in `comfy/quant_ops.py`
A `Layout` class defines how a specific quantization format behaves:
- Required parameters
- Quantize method
- De-Quantize method
```python
from comfy.quant_ops import QuantizedLayout
class MyLayout(QuantizedLayout):
@classmethod
def quantize(cls, tensor, **kwargs):
# Convert to quantized format
qdata = ...
params = {'scale': ..., 'orig_dtype': tensor.dtype}
return qdata, params
@staticmethod
def dequantize(qdata, scale, orig_dtype, **kwargs):
return qdata.to(orig_dtype) * scale
```
To then run operations using these QuantizedTensors we use two registry systems to define supported operations.
The first is a **generic registry** that handles operations common to all quantized formats (e.g., `.to()`, `.clone()`, `.reshape()`).
The second registry is layout-specific and allows to implement fast-paths like nn.Linear.
```python
from comfy.quant_ops import register_layout_op
@register_layout_op(torch.ops.aten.linear.default, MyLayout)
def my_linear(func, args, kwargs):
# Extract tensors, call optimized kernel
...
```
When `torch.nn.functional.linear()` is called with QuantizedTensor arguments, `__torch_dispatch__` automatically routes to the registered implementation.
For any unsupported operation, QuantizedTensor will fallback to call `dequantize` and dispatch using the high-precision implementation.
### Mixed Precision
The `MixedPrecisionOps` class (lines 542-648 in `comfy/ops.py`) enables per-layer quantization decisions, allowing different layers in a model to use different precisions. This is activated when a model config contains a `layer_quant_config` dictionary that specifies which layers should be quantized and how.
**Architecture:**
```python
class MixedPrecisionOps(disable_weight_init):
_layer_quant_config = {} # Maps layer names to quantization configs
_compute_dtype = torch.bfloat16 # Default compute / dequantize precision
```
**Key mechanism:**
The custom `Linear._load_from_state_dict()` method inspects each layer during model loading:
- If the layer name is **not** in `_layer_quant_config`: load weight as regular tensor in `_compute_dtype`
- If the layer name **is** in `_layer_quant_config`:
- Load weight as `QuantizedTensor` with the specified layout (e.g., `TensorCoreFP8Layout`)
- Load associated quantization parameters (scales, block_size, etc.)
**Why it's needed:**
Not all layers tolerate quantization equally. Sensitive operations like final projections can be kept in higher precision, while compute-heavy matmuls are quantized. This provides most of the performance benefits while maintaining quality.
The system is selected in `pick_operations()` when `model_config.layer_quant_config` is present, making it the highest-priority operation mode.
## Checkpoint Format
Quantized checkpoints are stored as standard safetensors files with quantized weight tensors and associated scaling parameters, plus a `_quantization_metadata` JSON entry describing the quantization scheme.
The quantized checkpoint will contain the same layers as the original checkpoint but:
- The weights are stored as quantized values, sometimes using a different storage datatype. E.g. uint8 container for fp8.
- For each quantized weight a number of additional scaling parameters are stored alongside depending on the recipe.
- We store a metadata.json in the metadata of the final safetensor containing the `_quantization_metadata` describing which layers are quantized and what layout has been used.
### Scaling Parameters details
We define 4 possible scaling parameters that should cover most recipes in the near-future:
- **weight_scale**: quantization scalers for the weights
- **weight_scale_2**: global scalers in the context of double scaling
- **pre_quant_scale**: scalers used for smoothing salient weights
- **input_scale**: quantization scalers for the activations
| Format | Storage dtype | weight_scale | weight_scale_2 | pre_quant_scale | input_scale |
|--------|---------------|--------------|----------------|-----------------|-------------|
| float8_e4m3fn | float32 | float32 (scalar) | - | - | float32 (scalar) |
You can find the defined formats in `comfy/quant_ops.py` (QUANT_ALGOS).
### Quantization Metadata
The metadata stored alongside the checkpoint contains:
- **format_version**: String to define a version of the standard
- **layers**: A dictionary mapping layer names to their quantization format. The format string maps to the definitions found in `QUANT_ALGOS`.
Example:
```json
{
"_quantization_metadata": {
"format_version": "1.0",
"layers": {
"model.layers.0.mlp.up_proj": "float8_e4m3fn",
"model.layers.0.mlp.down_proj": "float8_e4m3fn",
"model.layers.1.mlp.up_proj": "float8_e4m3fn"
}
}
}
```
## Creating Quantized Checkpoints
To create compatible checkpoints, use any quantization tool provided the output follows the checkpoint format described above and uses a layout defined in `QUANT_ALGOS`.
### Weight Quantization
Weight quantization is straightforward - compute the scaling factor directly from the weight tensor using the absolute maximum method described earlier. Each layer's weights are quantized independently and stored with their corresponding `weight_scale` parameter.
### Calibration (for Activation Quantization)
Activation quantization (e.g., for FP8 Tensor Core operations) requires `input_scale` parameters that cannot be determined from static weights alone. Since activation values depend on actual inputs, we use **post-training calibration (PTQ)**:
1. **Collect statistics**: Run inference on N representative samples
2. **Track activations**: Record the absolute maximum (`amax`) of inputs to each quantized layer
3. **Compute scales**: Derive `input_scale` from collected statistics
4. **Store in checkpoint**: Save `input_scale` parameters alongside weights
The calibration dataset should be representative of your target use case. For diffusion models, this typically means a diverse set of prompts and generation parameters.

View File

@@ -112,11 +112,10 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
## Release Process
ComfyUI follows a weekly release cycle targeting Monday but this regularly changes because of model releases or large changes to the codebase. There are three interconnected repositories:
ComfyUI follows a weekly release cycle targeting Friday but this regularly changes because of model releases or large changes to the codebase. There are three interconnected repositories:
1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)**
- Releases a new stable version (e.g., v0.7.0) roughly every week.
- Commits outside of the stable release tags may be very unstable and break many custom nodes.
- Releases a new stable version (e.g., v0.7.0)
- Serves as the foundation for the desktop release
2. **[ComfyUI Desktop](https://github.com/Comfy-Org/desktop)**
@@ -173,20 +172,10 @@ There is a portable standalone build for Windows that should work for running on
### [Direct link to download](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia.7z)
Simply download, extract with [7-Zip](https://7-zip.org) or with the windows explorer on recent windows versions and run. For smaller models you normally only need to put the checkpoints (the huge ckpt/safetensors files) in: ComfyUI\models\checkpoints but many of the larger models have multiple files. Make sure to follow the instructions to know which subfolder to put them in ComfyUI\models\
Simply download, extract with [7-Zip](https://7-zip.org) and run. Make sure you put your Stable Diffusion checkpoints/models (the huge ckpt/safetensors files) in: ComfyUI\models\checkpoints
If you have trouble extracting it, right click the file -> properties -> unblock
Update your Nvidia drivers if it doesn't start.
#### Alternative Downloads:
[Experimental portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z)
[Portable with pytorch cuda 12.8 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu128.7z).
[Portable with pytorch cuda 12.6 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu126.7z) (Supports Nvidia 10 series and older GPUs).
#### How do I share models between another UI and ComfyUI?
See the [Config file](extra_model_paths.yaml.example) to set the search paths for models. In the standalone windows build you can find this file in the ComfyUI directory. Rename this file to extra_model_paths.yaml and edit it with your favorite text editor.
@@ -202,11 +191,7 @@ comfy install
## Manual Install (Windows, Linux)
Python 3.14 works but you may encounter issues with the torch compile node. The free threaded variant is still missing some dependencies.
Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12
### Instructions:
Python 3.13 is very well supported. If you have trouble with some custom node dependencies you can try 3.12
Git clone this repo.
@@ -215,36 +200,18 @@ Put your SD checkpoints (the huge ckpt/safetensors files) in: models/checkpoints
Put your VAE in: models/vae
### AMD GPUs (Linux)
### AMD GPUs (Linux only)
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.4```
This is the command to install the nightly with ROCm 7.0 which might have some performance improvements:
This is the command to install the nightly with ROCm 6.4 which might have some performance improvements:
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.1```
### AMD GPUs (Experimental: Windows and Linux), RDNA 3, 3.5 and 4 only.
These have less hardware support than the builds above but they work on windows. You also need to install the pytorch version specific to your hardware.
RDNA 3 (RX 7000 series):
```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx110X-dgpu/```
RDNA 3.5 (Strix halo/Ryzen AI Max+ 365):
```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx1151/```
RDNA 4 (RX 9000 series):
```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx120X-all/```
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.4```
### Intel GPUs (Windows and Linux)
Intel Arc GPU users can install native PyTorch with torch.xpu support using pip. More information can be found [here](https://pytorch.org/docs/main/notes/get_start_xpu.html)
(Option 1) Intel Arc GPU users can install native PyTorch with torch.xpu support using pip. More information can be found [here](https://pytorch.org/docs/main/notes/get_start_xpu.html)
1. To install PyTorch xpu, use the following command:
@@ -254,15 +221,19 @@ This is the command to install the Pytorch xpu nightly which might have some per
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/xpu```
(Option 2) Alternatively, Intel GPUs supported by Intel Extension for PyTorch (IPEX) can leverage IPEX for improved performance.
1. visit [Installation](https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=gpu) for more information.
### NVIDIA
Nvidia users should install stable pytorch using this command:
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu130```
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu129```
This is the command to install pytorch nightly instead which might have performance improvements.
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu130```
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu129```
#### Troubleshooting
@@ -293,6 +264,12 @@ You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS ve
> **Note**: Remember to add your models, VAE, LoRAs etc. to the corresponding Comfy folders, as discussed in [ComfyUI manual installation](#manual-install-windows-linux).
#### DirectML (AMD Cards on Windows)
This is very badly supported and is not recommended. There are some unofficial builds of pytorch ROCm on windows that exist that will give you a much better experience than this. This readme will be updated once official pytorch ROCm builds for windows come out.
```pip install torch-directml``` Then you can launch ComfyUI with: ```python main.py --directml```
#### Ascend NPUs
For models compatible with Ascend Extension for PyTorch (torch_npu). To get started, ensure your environment meets the prerequisites outlined on the [installation](https://ascend.github.io/docs/sources/ascend/quick_install.html) page. Here's a step-by-step guide tailored to your platform and installation method:

View File

@@ -3,7 +3,7 @@
[alembic]
# path to migration scripts
# Use forward slashes (/) also on windows to provide an os agnostic path
script_location = alembic_db
script_location = app/alembic_db
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time

View File

@@ -2,13 +2,12 @@ from sqlalchemy import engine_from_config
from sqlalchemy import pool
from alembic import context
from app.assets.database.models import Base
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
from app.database.models import Base
target_metadata = Base.metadata
# other values from the config, defined by the needs of env.py,

View File

@@ -0,0 +1,175 @@
"""initial assets schema
Revision ID: 0001_assets
Revises:
Create Date: 2025-08-20 00:00:00
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
revision = "0001_assets"
down_revision = None
branch_labels = None
depends_on = None
def upgrade() -> None:
# ASSETS: content identity
op.create_table(
"assets",
sa.Column("id", sa.String(length=36), primary_key=True),
sa.Column("hash", sa.String(length=256), nullable=True),
sa.Column("size_bytes", sa.BigInteger(), nullable=False, server_default="0"),
sa.Column("mime_type", sa.String(length=255), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=False), nullable=False),
sa.CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"),
)
op.create_index("uq_assets_hash", "assets", ["hash"], unique=True)
op.create_index("ix_assets_mime_type", "assets", ["mime_type"])
# ASSETS_INFO: user-visible references
op.create_table(
"assets_info",
sa.Column("id", sa.String(length=36), primary_key=True),
sa.Column("owner_id", sa.String(length=128), nullable=False, server_default=""),
sa.Column("name", sa.String(length=512), nullable=False),
sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False),
sa.Column("preview_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="SET NULL"), nullable=True),
sa.Column("user_metadata", sa.JSON(), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=False), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=False), nullable=False),
sa.Column("last_access_time", sa.DateTime(timezone=False), nullable=False),
sa.UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"),
)
op.create_index("ix_assets_info_owner_id", "assets_info", ["owner_id"])
op.create_index("ix_assets_info_asset_id", "assets_info", ["asset_id"])
op.create_index("ix_assets_info_name", "assets_info", ["name"])
op.create_index("ix_assets_info_created_at", "assets_info", ["created_at"])
op.create_index("ix_assets_info_last_access_time", "assets_info", ["last_access_time"])
op.create_index("ix_assets_info_owner_name", "assets_info", ["owner_id", "name"])
# TAGS: normalized tag vocabulary
op.create_table(
"tags",
sa.Column("name", sa.String(length=512), primary_key=True),
sa.Column("tag_type", sa.String(length=32), nullable=False, server_default="user"),
sa.CheckConstraint("name = lower(name)", name="ck_tags_lowercase"),
)
op.create_index("ix_tags_tag_type", "tags", ["tag_type"])
# ASSET_INFO_TAGS: many-to-many for tags on AssetInfo
op.create_table(
"asset_info_tags",
sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False),
sa.Column("tag_name", sa.String(length=512), sa.ForeignKey("tags.name", ondelete="RESTRICT"), nullable=False),
sa.Column("origin", sa.String(length=32), nullable=False, server_default="manual"),
sa.Column("added_at", sa.DateTime(timezone=False), nullable=False),
sa.PrimaryKeyConstraint("asset_info_id", "tag_name", name="pk_asset_info_tags"),
)
op.create_index("ix_asset_info_tags_tag_name", "asset_info_tags", ["tag_name"])
op.create_index("ix_asset_info_tags_asset_info_id", "asset_info_tags", ["asset_info_id"])
# ASSET_CACHE_STATE: N:1 local cache rows per Asset
op.create_table(
"asset_cache_state",
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="CASCADE"), nullable=False),
sa.Column("file_path", sa.Text(), nullable=False), # absolute local path to cached file
sa.Column("mtime_ns", sa.BigInteger(), nullable=True),
sa.Column("needs_verify", sa.Boolean(), nullable=False, server_default=sa.text("false")),
sa.CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
sa.UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"),
)
op.create_index("ix_asset_cache_state_file_path", "asset_cache_state", ["file_path"])
op.create_index("ix_asset_cache_state_asset_id", "asset_cache_state", ["asset_id"])
# ASSET_INFO_META: typed KV projection of user_metadata for filtering/sorting
op.create_table(
"asset_info_meta",
sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False),
sa.Column("key", sa.String(length=256), nullable=False),
sa.Column("ordinal", sa.Integer(), nullable=False, server_default="0"),
sa.Column("val_str", sa.String(length=2048), nullable=True),
sa.Column("val_num", sa.Numeric(38, 10), nullable=True),
sa.Column("val_bool", sa.Boolean(), nullable=True),
sa.Column("val_json", sa.JSON().with_variant(postgresql.JSONB(), 'postgresql'), nullable=True),
sa.PrimaryKeyConstraint("asset_info_id", "key", "ordinal", name="pk_asset_info_meta"),
)
op.create_index("ix_asset_info_meta_key", "asset_info_meta", ["key"])
op.create_index("ix_asset_info_meta_key_val_str", "asset_info_meta", ["key", "val_str"])
op.create_index("ix_asset_info_meta_key_val_num", "asset_info_meta", ["key", "val_num"])
op.create_index("ix_asset_info_meta_key_val_bool", "asset_info_meta", ["key", "val_bool"])
# Tags vocabulary
tags_table = sa.table(
"tags",
sa.column("name", sa.String(length=512)),
sa.column("tag_type", sa.String()),
)
op.bulk_insert(
tags_table,
[
{"name": "models", "tag_type": "system"},
{"name": "input", "tag_type": "system"},
{"name": "output", "tag_type": "system"},
{"name": "configs", "tag_type": "system"},
{"name": "checkpoints", "tag_type": "system"},
{"name": "loras", "tag_type": "system"},
{"name": "vae", "tag_type": "system"},
{"name": "text_encoders", "tag_type": "system"},
{"name": "diffusion_models", "tag_type": "system"},
{"name": "clip_vision", "tag_type": "system"},
{"name": "style_models", "tag_type": "system"},
{"name": "embeddings", "tag_type": "system"},
{"name": "diffusers", "tag_type": "system"},
{"name": "vae_approx", "tag_type": "system"},
{"name": "controlnet", "tag_type": "system"},
{"name": "gligen", "tag_type": "system"},
{"name": "upscale_models", "tag_type": "system"},
{"name": "hypernetworks", "tag_type": "system"},
{"name": "photomaker", "tag_type": "system"},
{"name": "classifiers", "tag_type": "system"},
{"name": "encoder", "tag_type": "system"},
{"name": "decoder", "tag_type": "system"},
{"name": "missing", "tag_type": "system"},
{"name": "rescan", "tag_type": "system"},
],
)
def downgrade() -> None:
op.drop_index("ix_asset_info_meta_key_val_bool", table_name="asset_info_meta")
op.drop_index("ix_asset_info_meta_key_val_num", table_name="asset_info_meta")
op.drop_index("ix_asset_info_meta_key_val_str", table_name="asset_info_meta")
op.drop_index("ix_asset_info_meta_key", table_name="asset_info_meta")
op.drop_table("asset_info_meta")
op.drop_index("ix_asset_cache_state_asset_id", table_name="asset_cache_state")
op.drop_index("ix_asset_cache_state_file_path", table_name="asset_cache_state")
op.drop_constraint("uq_asset_cache_state_file_path", table_name="asset_cache_state")
op.drop_table("asset_cache_state")
op.drop_index("ix_asset_info_tags_asset_info_id", table_name="asset_info_tags")
op.drop_index("ix_asset_info_tags_tag_name", table_name="asset_info_tags")
op.drop_table("asset_info_tags")
op.drop_index("ix_tags_tag_type", table_name="tags")
op.drop_table("tags")
op.drop_constraint("uq_assets_info_asset_owner_name", table_name="assets_info")
op.drop_index("ix_assets_info_owner_name", table_name="assets_info")
op.drop_index("ix_assets_info_last_access_time", table_name="assets_info")
op.drop_index("ix_assets_info_created_at", table_name="assets_info")
op.drop_index("ix_assets_info_name", table_name="assets_info")
op.drop_index("ix_assets_info_asset_id", table_name="assets_info")
op.drop_index("ix_assets_info_owner_id", table_name="assets_info")
op.drop_table("assets_info")
op.drop_index("uq_assets_hash", table_name="assets")
op.drop_index("ix_assets_mime_type", table_name="assets")
op.drop_table("assets")

4
app/assets/__init__.py Normal file
View File

@@ -0,0 +1,4 @@
from .api.routes import register_assets_system
from .scanner import sync_seed_assets
__all__ = ["sync_seed_assets", "register_assets_system"]

225
app/assets/_helpers.py Normal file
View File

@@ -0,0 +1,225 @@
import contextlib
import os
import uuid
from datetime import datetime, timezone
from pathlib import Path
from typing import Literal, Optional, Sequence
import folder_paths
from .api import schemas_in
def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
"""Build a list of (folder_name, base_paths[]) categories that are configured for model locations.
We trust `folder_paths.folder_names_and_paths` and include a category if
*any* of its base paths lies under the Comfy `models_dir`.
"""
targets: list[tuple[str, list[str]]] = []
models_root = os.path.abspath(folder_paths.models_dir)
for name, (paths, _exts) in folder_paths.folder_names_and_paths.items():
if any(os.path.abspath(p).startswith(models_root + os.sep) for p in paths):
targets.append((name, paths))
return targets
def get_relative_to_root_category_path_of_asset(file_path: str) -> tuple[Literal["input", "output", "models"], str]:
"""Given an absolute or relative file path, determine which root category the path belongs to:
- 'input' if the file resides under `folder_paths.get_input_directory()`
- 'output' if the file resides under `folder_paths.get_output_directory()`
- 'models' if the file resides under any base path of categories returned by `get_comfy_models_folders()`
Returns:
(root_category, relative_path_inside_that_root)
For 'models', the relative path is prefixed with the category name:
e.g. ('models', 'vae/test/sub/ae.safetensors')
Raises:
ValueError: if the path does not belong to input, output, or configured model bases.
"""
fp_abs = os.path.abspath(file_path)
def _is_within(child: str, parent: str) -> bool:
try:
return os.path.commonpath([child, parent]) == parent
except Exception:
return False
def _rel(child: str, parent: str) -> str:
return os.path.relpath(os.path.join(os.sep, os.path.relpath(child, parent)), os.sep)
# 1) input
input_base = os.path.abspath(folder_paths.get_input_directory())
if _is_within(fp_abs, input_base):
return "input", _rel(fp_abs, input_base)
# 2) output
output_base = os.path.abspath(folder_paths.get_output_directory())
if _is_within(fp_abs, output_base):
return "output", _rel(fp_abs, output_base)
# 3) models (check deepest matching base to avoid ambiguity)
best: Optional[tuple[int, str, str]] = None # (base_len, bucket, rel_inside_bucket)
for bucket, bases in get_comfy_models_folders():
for b in bases:
base_abs = os.path.abspath(b)
if not _is_within(fp_abs, base_abs):
continue
cand = (len(base_abs), bucket, _rel(fp_abs, base_abs))
if best is None or cand[0] > best[0]:
best = cand
if best is not None:
_, bucket, rel_inside = best
combined = os.path.join(bucket, rel_inside)
return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep)
raise ValueError(f"Path is not within input, output, or configured model bases: {file_path}")
def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]:
"""Return a tuple (name, tags) derived from a filesystem path.
Semantics:
- Root category is determined by `get_relative_to_root_category_path_of_asset`.
- The returned `name` is the base filename with extension from the relative path.
- The returned `tags` are:
[root_category] + parent folders of the relative path (in order)
For 'models', this means:
file '/.../ModelsDir/vae/test_tag/ae.safetensors'
-> root_category='models', some_path='vae/test_tag/ae.safetensors'
-> name='ae.safetensors', tags=['models', 'vae', 'test_tag']
Raises:
ValueError: if the path does not belong to input, output, or configured model bases.
"""
root_category, some_path = get_relative_to_root_category_path_of_asset(file_path)
p = Path(some_path)
parent_parts = [part for part in p.parent.parts if part not in (".", "..", p.anchor)]
return p.name, list(dict.fromkeys(normalize_tags([root_category, *parent_parts])))
def normalize_tags(tags: Optional[Sequence[str]]) -> list[str]:
return [t.strip().lower() for t in (tags or []) if (t or "").strip()]
def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]:
"""Validates and maps tags -> (base_dir, subdirs_for_fs)"""
root = tags[0]
if root == "models":
if len(tags) < 2:
raise ValueError("at least two tags required for model asset")
try:
bases = folder_paths.folder_names_and_paths[tags[1]][0]
except KeyError:
raise ValueError(f"unknown model category '{tags[1]}'")
if not bases:
raise ValueError(f"no base path configured for category '{tags[1]}'")
base_dir = os.path.abspath(bases[0])
raw_subdirs = tags[2:]
else:
base_dir = os.path.abspath(
folder_paths.get_input_directory() if root == "input" else folder_paths.get_output_directory()
)
raw_subdirs = tags[1:]
for i in raw_subdirs:
if i in (".", ".."):
raise ValueError("invalid path component in tags")
return base_dir, raw_subdirs if raw_subdirs else []
def ensure_within_base(candidate: str, base: str) -> None:
cand_abs = os.path.abspath(candidate)
base_abs = os.path.abspath(base)
try:
if os.path.commonpath([cand_abs, base_abs]) != base_abs:
raise ValueError("destination escapes base directory")
except Exception:
raise ValueError("invalid destination path")
def compute_relative_filename(file_path: str) -> Optional[str]:
"""
Return the model's path relative to the last well-known folder (the model category),
using forward slashes, eg:
/.../models/checkpoints/flux/123/flux.safetensors -> "flux/123/flux.safetensors"
/.../models/text_encoders/clip_g.safetensors -> "clip_g.safetensors"
For non-model paths, returns None.
NOTE: this is a temporary helper, used only for initializing metadata["filename"] field.
"""
try:
root_category, rel_path = get_relative_to_root_category_path_of_asset(file_path)
except ValueError:
return None
p = Path(rel_path)
parts = [seg for seg in p.parts if seg not in (".", "..", p.anchor)]
if not parts:
return None
if root_category == "models":
# parts[0] is the category ("checkpoints", "vae", etc) drop it
inside = parts[1:] if len(parts) > 1 else [parts[0]]
return "/".join(inside)
return "/".join(parts) # input/output: keep all parts
def list_tree(base_dir: str) -> list[str]:
out: list[str] = []
base_abs = os.path.abspath(base_dir)
if not os.path.isdir(base_abs):
return out
for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False):
for name in filenames:
out.append(os.path.abspath(os.path.join(dirpath, name)))
return out
def prefixes_for_root(root: schemas_in.RootType) -> list[str]:
if root == "models":
bases: list[str] = []
for _bucket, paths in get_comfy_models_folders():
bases.extend(paths)
return [os.path.abspath(p) for p in bases]
if root == "input":
return [os.path.abspath(folder_paths.get_input_directory())]
if root == "output":
return [os.path.abspath(folder_paths.get_output_directory())]
return []
def ts_to_iso(ts: Optional[float]) -> Optional[str]:
if ts is None:
return None
try:
return datetime.fromtimestamp(float(ts), tz=timezone.utc).replace(tzinfo=None).isoformat()
except Exception:
return None
def new_scan_id(root: schemas_in.RootType) -> str:
return f"scan-{root}-{uuid.uuid4().hex[:8]}"
def collect_models_files() -> list[str]:
out: list[str] = []
for folder_name, bases in get_comfy_models_folders():
rel_files = folder_paths.get_filename_list(folder_name) or []
for rel_path in rel_files:
abs_path = folder_paths.get_full_path(folder_name, rel_path)
if not abs_path:
continue
abs_path = os.path.abspath(abs_path)
allowed = False
for b in bases:
base_abs = os.path.abspath(b)
with contextlib.suppress(Exception):
if os.path.commonpath([abs_path, base_abs]) == base_abs:
allowed = True
break
if allowed:
out.append(abs_path)
return out

544
app/assets/api/routes.py Normal file
View File

@@ -0,0 +1,544 @@
import contextlib
import logging
import os
import urllib.parse
import uuid
from typing import Optional
from aiohttp import web
from pydantic import ValidationError
import folder_paths
from ... import user_manager
from .. import manager, scanner
from . import schemas_in, schemas_out
ROUTES = web.RouteTableDef()
USER_MANAGER: Optional[user_manager.UserManager] = None
LOGGER = logging.getLogger(__name__)
# UUID regex (canonical hyphenated form, case-insensitive)
UUID_RE = r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}"
@ROUTES.head("/api/assets/hash/{hash}")
async def head_asset_by_hash(request: web.Request) -> web.Response:
hash_str = request.match_info.get("hash", "").strip().lower()
if not hash_str or ":" not in hash_str:
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
algo, digest = hash_str.split(":", 1)
if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"):
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
exists = await manager.asset_exists(asset_hash=hash_str)
return web.Response(status=200 if exists else 404)
@ROUTES.get("/api/assets")
async def list_assets(request: web.Request) -> web.Response:
qp = request.rel_url.query
query_dict = {}
if "include_tags" in qp:
query_dict["include_tags"] = qp.getall("include_tags")
if "exclude_tags" in qp:
query_dict["exclude_tags"] = qp.getall("exclude_tags")
for k in ("name_contains", "metadata_filter", "limit", "offset", "sort", "order"):
v = qp.get(k)
if v is not None:
query_dict[k] = v
try:
q = schemas_in.ListAssetsQuery.model_validate(query_dict)
except ValidationError as ve:
return _validation_error_response("INVALID_QUERY", ve)
payload = await manager.list_assets(
include_tags=q.include_tags,
exclude_tags=q.exclude_tags,
name_contains=q.name_contains,
metadata_filter=q.metadata_filter,
limit=q.limit,
offset=q.offset,
sort=q.sort,
order=q.order,
owner_id=USER_MANAGER.get_request_user_id(request),
)
return web.json_response(payload.model_dump(mode="json"))
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}/content")
async def download_asset_content(request: web.Request) -> web.Response:
disposition = request.query.get("disposition", "attachment").lower().strip()
if disposition not in {"inline", "attachment"}:
disposition = "attachment"
try:
abs_path, content_type, filename = await manager.resolve_asset_content_for_download(
asset_info_id=str(uuid.UUID(request.match_info["id"])),
owner_id=USER_MANAGER.get_request_user_id(request),
)
except ValueError as ve:
return _error_response(404, "ASSET_NOT_FOUND", str(ve))
except NotImplementedError as nie:
return _error_response(501, "BACKEND_UNSUPPORTED", str(nie))
except FileNotFoundError:
return _error_response(404, "FILE_NOT_FOUND", "Underlying file not found on disk.")
quoted = (filename or "").replace("\r", "").replace("\n", "").replace('"', "'")
cd = f'{disposition}; filename="{quoted}"; filename*=UTF-8\'\'{urllib.parse.quote(filename)}'
resp = web.FileResponse(abs_path)
resp.content_type = content_type
resp.headers["Content-Disposition"] = cd
return resp
@ROUTES.post("/api/assets/from-hash")
async def create_asset_from_hash(request: web.Request) -> web.Response:
try:
payload = await request.json()
body = schemas_in.CreateFromHashBody.model_validate(payload)
except ValidationError as ve:
return _validation_error_response("INVALID_BODY", ve)
except Exception:
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
result = await manager.create_asset_from_hash(
hash_str=body.hash,
name=body.name,
tags=body.tags,
user_metadata=body.user_metadata,
owner_id=USER_MANAGER.get_request_user_id(request),
)
if result is None:
return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist")
return web.json_response(result.model_dump(mode="json"), status=201)
@ROUTES.post("/api/assets")
async def upload_asset(request: web.Request) -> web.Response:
"""Multipart/form-data endpoint for Asset uploads."""
if not (request.content_type or "").lower().startswith("multipart/"):
return _error_response(415, "UNSUPPORTED_MEDIA_TYPE", "Use multipart/form-data for uploads.")
reader = await request.multipart()
file_present = False
file_client_name: Optional[str] = None
tags_raw: list[str] = []
provided_name: Optional[str] = None
user_metadata_raw: Optional[str] = None
provided_hash: Optional[str] = None
provided_hash_exists: Optional[bool] = None
file_written = 0
tmp_path: Optional[str] = None
while True:
field = await reader.next()
if field is None:
break
fname = getattr(field, "name", "") or ""
if fname == "hash":
try:
s = ((await field.text()) or "").strip().lower()
except Exception:
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
if s:
if ":" not in s:
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
algo, digest = s.split(":", 1)
if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"):
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
provided_hash = f"{algo}:{digest}"
try:
provided_hash_exists = await manager.asset_exists(asset_hash=provided_hash)
except Exception:
provided_hash_exists = None # do not fail the whole request here
elif fname == "file":
file_present = True
file_client_name = (field.filename or "").strip()
if provided_hash and provided_hash_exists is True:
# If client supplied a hash that we know exists, drain but do not write to disk
try:
while True:
chunk = await field.read_chunk(8 * 1024 * 1024)
if not chunk:
break
file_written += len(chunk)
except Exception:
return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive uploaded file.")
continue # Do not create temp file; we will create AssetInfo from the existing content
# Otherwise, store to temp for hashing/ingest
uploads_root = os.path.join(folder_paths.get_temp_directory(), "uploads")
unique_dir = os.path.join(uploads_root, uuid.uuid4().hex)
os.makedirs(unique_dir, exist_ok=True)
tmp_path = os.path.join(unique_dir, ".upload.part")
try:
with open(tmp_path, "wb") as f:
while True:
chunk = await field.read_chunk(8 * 1024 * 1024)
if not chunk:
break
f.write(chunk)
file_written += len(chunk)
except Exception:
try:
if os.path.exists(tmp_path or ""):
os.remove(tmp_path)
finally:
return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive and store uploaded file.")
elif fname == "tags":
tags_raw.append((await field.text()) or "")
elif fname == "name":
provided_name = (await field.text()) or None
elif fname == "user_metadata":
user_metadata_raw = (await field.text()) or None
# If client did not send file, and we are not doing a from-hash fast path -> error
if not file_present and not (provided_hash and provided_hash_exists):
return _error_response(400, "MISSING_FILE", "Form must include a 'file' part or a known 'hash'.")
if file_present and file_written == 0 and not (provided_hash and provided_hash_exists):
# Empty upload is only acceptable if we are fast-pathing from existing hash
try:
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
finally:
return _error_response(400, "EMPTY_UPLOAD", "Uploaded file is empty.")
try:
spec = schemas_in.UploadAssetSpec.model_validate({
"tags": tags_raw,
"name": provided_name,
"user_metadata": user_metadata_raw,
"hash": provided_hash,
})
except ValidationError as ve:
try:
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
finally:
return _validation_error_response("INVALID_BODY", ve)
# Validate models category against configured folders (consistent with previous behavior)
if spec.tags and spec.tags[0] == "models":
if len(spec.tags) < 2 or spec.tags[1] not in folder_paths.folder_names_and_paths:
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
return _error_response(
400, "INVALID_BODY", f"unknown models category '{spec.tags[1] if len(spec.tags) >= 2 else ''}'"
)
owner_id = USER_MANAGER.get_request_user_id(request)
# Fast path: if a valid provided hash exists, create AssetInfo without writing anything
if spec.hash and provided_hash_exists is True:
try:
result = await manager.create_asset_from_hash(
hash_str=spec.hash,
name=spec.name or (spec.hash.split(":", 1)[1]),
tags=spec.tags,
user_metadata=spec.user_metadata or {},
owner_id=owner_id,
)
except Exception:
LOGGER.exception("create_asset_from_hash failed for hash=%s, owner_id=%s", spec.hash, owner_id)
return _error_response(500, "INTERNAL", "Unexpected server error.")
if result is None:
return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {spec.hash} does not exist")
# Drain temp if we accidentally saved (e.g., hash field came after file)
if tmp_path and os.path.exists(tmp_path):
with contextlib.suppress(Exception):
os.remove(tmp_path)
status = 200 if (not result.created_new) else 201
return web.json_response(result.model_dump(mode="json"), status=status)
# Otherwise, we must have a temp file path to ingest
if not tmp_path or not os.path.exists(tmp_path):
# The only case we reach here without a temp file is: client sent a hash that does not exist and no file
return _error_response(404, "ASSET_NOT_FOUND", "Provided hash not found and no file uploaded.")
try:
created = await manager.upload_asset_from_temp_path(
spec,
temp_path=tmp_path,
client_filename=file_client_name,
owner_id=owner_id,
expected_asset_hash=spec.hash,
)
status = 201 if created.created_new else 200
return web.json_response(created.model_dump(mode="json"), status=status)
except ValueError as e:
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
msg = str(e)
if "HASH_MISMATCH" in msg or msg.strip().upper() == "HASH_MISMATCH":
return _error_response(
400,
"HASH_MISMATCH",
"Uploaded file hash does not match provided hash.",
)
return _error_response(400, "BAD_REQUEST", "Invalid inputs.")
except Exception:
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
LOGGER.exception("upload_asset_from_temp_path failed for tmp_path=%s, owner_id=%s", tmp_path, owner_id)
return _error_response(500, "INTERNAL", "Unexpected server error.")
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}")
async def get_asset(request: web.Request) -> web.Response:
asset_info_id = str(uuid.UUID(request.match_info["id"]))
try:
result = await manager.get_asset(
asset_info_id=asset_info_id,
owner_id=USER_MANAGER.get_request_user_id(request),
)
except ValueError as ve:
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
except Exception:
LOGGER.exception(
"get_asset failed for asset_info_id=%s, owner_id=%s",
asset_info_id,
USER_MANAGER.get_request_user_id(request),
)
return _error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(result.model_dump(mode="json"), status=200)
@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}")
async def update_asset(request: web.Request) -> web.Response:
asset_info_id = str(uuid.UUID(request.match_info["id"]))
try:
body = schemas_in.UpdateAssetBody.model_validate(await request.json())
except ValidationError as ve:
return _validation_error_response("INVALID_BODY", ve)
except Exception:
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
try:
result = await manager.update_asset(
asset_info_id=asset_info_id,
name=body.name,
tags=body.tags,
user_metadata=body.user_metadata,
owner_id=USER_MANAGER.get_request_user_id(request),
)
except (ValueError, PermissionError) as ve:
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
except Exception:
LOGGER.exception(
"update_asset failed for asset_info_id=%s, owner_id=%s",
asset_info_id,
USER_MANAGER.get_request_user_id(request),
)
return _error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(result.model_dump(mode="json"), status=200)
@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}/preview")
async def set_asset_preview(request: web.Request) -> web.Response:
asset_info_id = str(uuid.UUID(request.match_info["id"]))
try:
body = schemas_in.SetPreviewBody.model_validate(await request.json())
except ValidationError as ve:
return _validation_error_response("INVALID_BODY", ve)
except Exception:
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
try:
result = await manager.set_asset_preview(
asset_info_id=asset_info_id,
preview_asset_id=body.preview_id,
owner_id=USER_MANAGER.get_request_user_id(request),
)
except (PermissionError, ValueError) as ve:
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
except Exception:
LOGGER.exception(
"set_asset_preview failed for asset_info_id=%s, owner_id=%s",
asset_info_id,
USER_MANAGER.get_request_user_id(request),
)
return _error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(result.model_dump(mode="json"), status=200)
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}")
async def delete_asset(request: web.Request) -> web.Response:
asset_info_id = str(uuid.UUID(request.match_info["id"]))
delete_content = request.query.get("delete_content")
delete_content = True if delete_content is None else delete_content.lower() not in {"0", "false", "no"}
try:
deleted = await manager.delete_asset_reference(
asset_info_id=asset_info_id,
owner_id=USER_MANAGER.get_request_user_id(request),
delete_content_if_orphan=delete_content,
)
except Exception:
LOGGER.exception(
"delete_asset_reference failed for asset_info_id=%s, owner_id=%s",
asset_info_id,
USER_MANAGER.get_request_user_id(request),
)
return _error_response(500, "INTERNAL", "Unexpected server error.")
if not deleted:
return _error_response(404, "ASSET_NOT_FOUND", f"AssetInfo {asset_info_id} not found.")
return web.Response(status=204)
@ROUTES.get("/api/tags")
async def get_tags(request: web.Request) -> web.Response:
query_map = dict(request.rel_url.query)
try:
query = schemas_in.TagsListQuery.model_validate(query_map)
except ValidationError as ve:
return web.json_response(
{"error": {"code": "INVALID_QUERY", "message": "Invalid query parameters", "details": ve.errors()}},
status=400,
)
result = await manager.list_tags(
prefix=query.prefix,
limit=query.limit,
offset=query.offset,
order=query.order,
include_zero=query.include_zero,
owner_id=USER_MANAGER.get_request_user_id(request),
)
return web.json_response(result.model_dump(mode="json"))
@ROUTES.post(f"/api/assets/{{id:{UUID_RE}}}/tags")
async def add_asset_tags(request: web.Request) -> web.Response:
asset_info_id = str(uuid.UUID(request.match_info["id"]))
try:
payload = await request.json()
data = schemas_in.TagsAdd.model_validate(payload)
except ValidationError as ve:
return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags add.", {"errors": ve.errors()})
except Exception:
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
try:
result = await manager.add_tags_to_asset(
asset_info_id=asset_info_id,
tags=data.tags,
origin="manual",
owner_id=USER_MANAGER.get_request_user_id(request),
)
except (ValueError, PermissionError) as ve:
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
except Exception:
LOGGER.exception(
"add_tags_to_asset failed for asset_info_id=%s, owner_id=%s",
asset_info_id,
USER_MANAGER.get_request_user_id(request),
)
return _error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(result.model_dump(mode="json"), status=200)
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}/tags")
async def delete_asset_tags(request: web.Request) -> web.Response:
asset_info_id = str(uuid.UUID(request.match_info["id"]))
try:
payload = await request.json()
data = schemas_in.TagsRemove.model_validate(payload)
except ValidationError as ve:
return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags remove.", {"errors": ve.errors()})
except Exception:
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
try:
result = await manager.remove_tags_from_asset(
asset_info_id=asset_info_id,
tags=data.tags,
owner_id=USER_MANAGER.get_request_user_id(request),
)
except ValueError as ve:
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
except Exception:
LOGGER.exception(
"remove_tags_from_asset failed for asset_info_id=%s, owner_id=%s",
asset_info_id,
USER_MANAGER.get_request_user_id(request),
)
return _error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(result.model_dump(mode="json"), status=200)
@ROUTES.post("/api/assets/scan/seed")
async def seed_assets(request: web.Request) -> web.Response:
try:
payload = await request.json()
except Exception:
payload = {}
try:
body = schemas_in.ScheduleAssetScanBody.model_validate(payload)
except ValidationError as ve:
return _validation_error_response("INVALID_BODY", ve)
try:
await scanner.sync_seed_assets(body.roots)
except Exception:
LOGGER.exception("sync_seed_assets failed for roots=%s", body.roots)
return _error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response({"synced": True, "roots": body.roots}, status=200)
@ROUTES.post("/api/assets/scan/schedule")
async def schedule_asset_scan(request: web.Request) -> web.Response:
try:
payload = await request.json()
except Exception:
payload = {}
try:
body = schemas_in.ScheduleAssetScanBody.model_validate(payload)
except ValidationError as ve:
return _validation_error_response("INVALID_BODY", ve)
states = await scanner.schedule_scans(body.roots)
return web.json_response(states.model_dump(mode="json"), status=202)
@ROUTES.get("/api/assets/scan")
async def get_asset_scan_status(request: web.Request) -> web.Response:
root = request.query.get("root", "").strip().lower()
states = scanner.current_statuses()
if root in {"models", "input", "output"}:
states = [s for s in states.scans if s.root == root] # type: ignore
states = schemas_out.AssetScanStatusResponse(scans=states)
return web.json_response(states.model_dump(mode="json"), status=200)
def register_assets_system(app: web.Application, user_manager_instance: user_manager.UserManager) -> None:
global USER_MANAGER
USER_MANAGER = user_manager_instance
app.add_routes(ROUTES)
def _error_response(status: int, code: str, message: str, details: Optional[dict] = None) -> web.Response:
return web.json_response({"error": {"code": code, "message": message, "details": details or {}}}, status=status)
def _validation_error_response(code: str, ve: ValidationError) -> web.Response:
return _error_response(400, code, "Validation failed.", {"errors": ve.json()})

View File

@@ -0,0 +1,297 @@
import json
import uuid
from typing import Any, Literal, Optional
from pydantic import (
BaseModel,
ConfigDict,
Field,
conint,
field_validator,
model_validator,
)
class ListAssetsQuery(BaseModel):
include_tags: list[str] = Field(default_factory=list)
exclude_tags: list[str] = Field(default_factory=list)
name_contains: Optional[str] = None
# Accept either a JSON string (query param) or a dict
metadata_filter: Optional[dict[str, Any]] = None
limit: conint(ge=1, le=500) = 20
offset: conint(ge=0) = 0
sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = "created_at"
order: Literal["asc", "desc"] = "desc"
@field_validator("include_tags", "exclude_tags", mode="before")
@classmethod
def _split_csv_tags(cls, v):
# Accept "a,b,c" or ["a","b"] (we are liberal in what we accept)
if v is None:
return []
if isinstance(v, str):
return [t.strip() for t in v.split(",") if t.strip()]
if isinstance(v, list):
out: list[str] = []
for item in v:
if isinstance(item, str):
out.extend([t.strip() for t in item.split(",") if t.strip()])
return out
return v
@field_validator("metadata_filter", mode="before")
@classmethod
def _parse_metadata_json(cls, v):
if v is None or isinstance(v, dict):
return v
if isinstance(v, str) and v.strip():
try:
parsed = json.loads(v)
except Exception as e:
raise ValueError(f"metadata_filter must be JSON: {e}") from e
if not isinstance(parsed, dict):
raise ValueError("metadata_filter must be a JSON object")
return parsed
return None
class UpdateAssetBody(BaseModel):
name: Optional[str] = None
tags: Optional[list[str]] = None
user_metadata: Optional[dict[str, Any]] = None
@model_validator(mode="after")
def _at_least_one(self):
if self.name is None and self.tags is None and self.user_metadata is None:
raise ValueError("Provide at least one of: name, tags, user_metadata.")
if self.tags is not None:
if not isinstance(self.tags, list) or not all(isinstance(t, str) for t in self.tags):
raise ValueError("Field 'tags' must be an array of strings.")
return self
class CreateFromHashBody(BaseModel):
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
hash: str
name: str
tags: list[str] = Field(default_factory=list)
user_metadata: dict[str, Any] = Field(default_factory=dict)
@field_validator("hash")
@classmethod
def _require_blake3(cls, v):
s = (v or "").strip().lower()
if ":" not in s:
raise ValueError("hash must be 'blake3:<hex>'")
algo, digest = s.split(":", 1)
if algo != "blake3":
raise ValueError("only canonical 'blake3:<hex>' is accepted here")
if not digest or any(c for c in digest if c not in "0123456789abcdef"):
raise ValueError("hash digest must be lowercase hex")
return s
@field_validator("tags", mode="before")
@classmethod
def _tags_norm(cls, v):
if v is None:
return []
if isinstance(v, list):
out = [str(t).strip().lower() for t in v if str(t).strip()]
seen = set()
dedup = []
for t in out:
if t not in seen:
seen.add(t)
dedup.append(t)
return dedup
if isinstance(v, str):
return [t.strip().lower() for t in v.split(",") if t.strip()]
return []
class TagsListQuery(BaseModel):
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
prefix: Optional[str] = Field(None, min_length=1, max_length=256)
limit: int = Field(100, ge=1, le=1000)
offset: int = Field(0, ge=0, le=10_000_000)
order: Literal["count_desc", "name_asc"] = "count_desc"
include_zero: bool = True
@field_validator("prefix")
@classmethod
def normalize_prefix(cls, v: Optional[str]) -> Optional[str]:
if v is None:
return v
v = v.strip()
return v.lower() or None
class TagsAdd(BaseModel):
model_config = ConfigDict(extra="ignore")
tags: list[str] = Field(..., min_length=1)
@field_validator("tags")
@classmethod
def normalize_tags(cls, v: list[str]) -> list[str]:
out = []
for t in v:
if not isinstance(t, str):
raise TypeError("tags must be strings")
tnorm = t.strip().lower()
if tnorm:
out.append(tnorm)
seen = set()
deduplicated = []
for x in out:
if x not in seen:
seen.add(x)
deduplicated.append(x)
return deduplicated
class TagsRemove(TagsAdd):
pass
RootType = Literal["models", "input", "output"]
ALLOWED_ROOTS: tuple[RootType, ...] = ("models", "input", "output")
class ScheduleAssetScanBody(BaseModel):
roots: list[RootType] = Field(..., min_length=1)
class UploadAssetSpec(BaseModel):
"""Upload Asset operation.
- tags: ordered; first is root ('models'|'input'|'output');
if root == 'models', second must be a valid category from folder_paths.folder_names_and_paths
- name: display name
- user_metadata: arbitrary JSON object (optional)
- hash: optional canonical 'blake3:<hex>' provided by the client for validation / fast-path
Files created via this endpoint are stored on disk using the **content hash** as the filename stem
and the original extension is preserved when available.
"""
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
tags: list[str] = Field(..., min_length=1)
name: Optional[str] = Field(default=None, max_length=512, description="Display Name")
user_metadata: dict[str, Any] = Field(default_factory=dict)
hash: Optional[str] = Field(default=None)
@field_validator("hash", mode="before")
@classmethod
def _parse_hash(cls, v):
if v is None:
return None
s = str(v).strip().lower()
if not s:
return None
if ":" not in s:
raise ValueError("hash must be 'blake3:<hex>'")
algo, digest = s.split(":", 1)
if algo != "blake3":
raise ValueError("only canonical 'blake3:<hex>' is accepted here")
if not digest or any(c for c in digest if c not in "0123456789abcdef"):
raise ValueError("hash digest must be lowercase hex")
return f"{algo}:{digest}"
@field_validator("tags", mode="before")
@classmethod
def _parse_tags(cls, v):
"""
Accepts a list of strings (possibly multiple form fields),
where each string can be:
- JSON array (e.g., '["models","loras","foo"]')
- comma-separated ('models, loras, foo')
- single token ('models')
Returns a normalized, deduplicated, ordered list.
"""
items: list[str] = []
if v is None:
return []
if isinstance(v, str):
v = [v]
if isinstance(v, list):
for item in v:
if item is None:
continue
s = str(item).strip()
if not s:
continue
if s.startswith("["):
try:
arr = json.loads(s)
if isinstance(arr, list):
items.extend(str(x) for x in arr)
continue
except Exception:
pass # fallback to CSV parse below
items.extend([p for p in s.split(",") if p.strip()])
else:
return []
# normalize + dedupe
norm = []
seen = set()
for t in items:
tnorm = str(t).strip().lower()
if tnorm and tnorm not in seen:
seen.add(tnorm)
norm.append(tnorm)
return norm
@field_validator("user_metadata", mode="before")
@classmethod
def _parse_metadata_json(cls, v):
if v is None or isinstance(v, dict):
return v or {}
if isinstance(v, str):
s = v.strip()
if not s:
return {}
try:
parsed = json.loads(s)
except Exception as e:
raise ValueError(f"user_metadata must be JSON: {e}") from e
if not isinstance(parsed, dict):
raise ValueError("user_metadata must be a JSON object")
return parsed
return {}
@model_validator(mode="after")
def _validate_order(self):
if not self.tags:
raise ValueError("tags must be provided and non-empty")
root = self.tags[0]
if root not in {"models", "input", "output"}:
raise ValueError("first tag must be one of: models, input, output")
if root == "models":
if len(self.tags) < 2:
raise ValueError("models uploads require a category tag as the second tag")
return self
class SetPreviewBody(BaseModel):
"""Set or clear the preview for an AssetInfo. Provide an Asset.id or null."""
preview_id: Optional[str] = None
@field_validator("preview_id", mode="before")
@classmethod
def _norm_uuid(cls, v):
if v is None:
return None
s = str(v).strip()
if not s:
return None
try:
uuid.UUID(s)
except Exception:
raise ValueError("preview_id must be a UUID")
return s

View File

@@ -0,0 +1,115 @@
from datetime import datetime
from typing import Any, Literal, Optional
from pydantic import BaseModel, ConfigDict, Field, field_serializer
class AssetSummary(BaseModel):
id: str
name: str
asset_hash: Optional[str]
size: Optional[int] = None
mime_type: Optional[str] = None
tags: list[str] = Field(default_factory=list)
preview_url: Optional[str] = None
created_at: Optional[datetime] = None
updated_at: Optional[datetime] = None
last_access_time: Optional[datetime] = None
model_config = ConfigDict(from_attributes=True)
@field_serializer("created_at", "updated_at", "last_access_time")
def _ser_dt(self, v: Optional[datetime], _info):
return v.isoformat() if v else None
class AssetsList(BaseModel):
assets: list[AssetSummary]
total: int
has_more: bool
class AssetUpdated(BaseModel):
id: str
name: str
asset_hash: Optional[str]
tags: list[str] = Field(default_factory=list)
user_metadata: dict[str, Any] = Field(default_factory=dict)
updated_at: Optional[datetime] = None
model_config = ConfigDict(from_attributes=True)
@field_serializer("updated_at")
def _ser_updated(self, v: Optional[datetime], _info):
return v.isoformat() if v else None
class AssetDetail(BaseModel):
id: str
name: str
asset_hash: Optional[str]
size: Optional[int] = None
mime_type: Optional[str] = None
tags: list[str] = Field(default_factory=list)
user_metadata: dict[str, Any] = Field(default_factory=dict)
preview_id: Optional[str] = None
created_at: Optional[datetime] = None
last_access_time: Optional[datetime] = None
model_config = ConfigDict(from_attributes=True)
@field_serializer("created_at", "last_access_time")
def _ser_dt(self, v: Optional[datetime], _info):
return v.isoformat() if v else None
class AssetCreated(AssetDetail):
created_new: bool
class TagUsage(BaseModel):
name: str
count: int
type: str
class TagsList(BaseModel):
tags: list[TagUsage] = Field(default_factory=list)
total: int
has_more: bool
class TagsAdd(BaseModel):
model_config = ConfigDict(str_strip_whitespace=True)
added: list[str] = Field(default_factory=list)
already_present: list[str] = Field(default_factory=list)
total_tags: list[str] = Field(default_factory=list)
class TagsRemove(BaseModel):
model_config = ConfigDict(str_strip_whitespace=True)
removed: list[str] = Field(default_factory=list)
not_present: list[str] = Field(default_factory=list)
total_tags: list[str] = Field(default_factory=list)
class AssetScanError(BaseModel):
path: str
message: str
at: Optional[str] = Field(None, description="ISO timestamp")
class AssetScanStatus(BaseModel):
scan_id: str
root: Literal["models", "input", "output"]
status: Literal["scheduled", "running", "completed", "failed", "cancelled"]
scheduled_at: Optional[str] = None
started_at: Optional[str] = None
finished_at: Optional[str] = None
discovered: int = 0
processed: int = 0
file_errors: list[AssetScanError] = Field(default_factory=list)
class AssetScanStatusResponse(BaseModel):
scans: list[AssetScanStatus] = Field(default_factory=list)

View File

@@ -0,0 +1,25 @@
from .bulk_ops import seed_from_paths_batch
from .escape_like import escape_like_prefix
from .fast_check import fast_asset_file_check
from .filters import apply_metadata_filter, apply_tag_filters
from .ownership import visible_owner_clause
from .projection import is_scalar, project_kv
from .tags import (
add_missing_tag_for_asset_id,
ensure_tags_exist,
remove_missing_tag_for_asset_id,
)
__all__ = [
"apply_tag_filters",
"apply_metadata_filter",
"escape_like_prefix",
"fast_asset_file_check",
"is_scalar",
"project_kv",
"ensure_tags_exist",
"add_missing_tag_for_asset_id",
"remove_missing_tag_for_asset_id",
"seed_from_paths_batch",
"visible_owner_clause",
]

View File

@@ -0,0 +1,230 @@
import os
import uuid
from typing import Iterable, Sequence
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql as d_pg
from sqlalchemy.dialects import sqlite as d_sqlite
from sqlalchemy.ext.asyncio import AsyncSession
from ..models import Asset, AssetCacheState, AssetInfo, AssetInfoMeta, AssetInfoTag
from ..timeutil import utcnow
MAX_BIND_PARAMS = 800
async def seed_from_paths_batch(
session: AsyncSession,
*,
specs: Sequence[dict],
owner_id: str = "",
) -> dict:
"""Each spec is a dict with keys:
- abs_path: str
- size_bytes: int
- mtime_ns: int
- info_name: str
- tags: list[str]
- fname: Optional[str]
"""
if not specs:
return {"inserted_infos": 0, "won_states": 0, "lost_states": 0}
now = utcnow()
dialect = session.bind.dialect.name
if dialect not in ("sqlite", "postgresql"):
raise NotImplementedError(f"Unsupported database dialect: {dialect}")
asset_rows: list[dict] = []
state_rows: list[dict] = []
path_to_asset: dict[str, str] = {}
asset_to_info: dict[str, dict] = {} # asset_id -> prepared info row
path_list: list[str] = []
for sp in specs:
ap = os.path.abspath(sp["abs_path"])
aid = str(uuid.uuid4())
iid = str(uuid.uuid4())
path_list.append(ap)
path_to_asset[ap] = aid
asset_rows.append(
{
"id": aid,
"hash": None,
"size_bytes": sp["size_bytes"],
"mime_type": None,
"created_at": now,
}
)
state_rows.append(
{
"asset_id": aid,
"file_path": ap,
"mtime_ns": sp["mtime_ns"],
}
)
asset_to_info[aid] = {
"id": iid,
"owner_id": owner_id,
"name": sp["info_name"],
"asset_id": aid,
"preview_id": None,
"user_metadata": {"filename": sp["fname"]} if sp["fname"] else None,
"created_at": now,
"updated_at": now,
"last_access_time": now,
"_tags": sp["tags"],
"_filename": sp["fname"],
}
# insert all seed Assets (hash=NULL)
ins_asset = d_sqlite.insert(Asset) if dialect == "sqlite" else d_pg.insert(Asset)
for chunk in _iter_chunks(asset_rows, _rows_per_stmt(5)):
await session.execute(ins_asset, chunk)
# try to claim AssetCacheState (file_path)
winners_by_path: set[str] = set()
if dialect == "sqlite":
ins_state = (
d_sqlite.insert(AssetCacheState)
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
.returning(AssetCacheState.file_path)
)
else:
ins_state = (
d_pg.insert(AssetCacheState)
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
.returning(AssetCacheState.file_path)
)
for chunk in _iter_chunks(state_rows, _rows_per_stmt(3)):
winners_by_path.update((await session.execute(ins_state, chunk)).scalars().all())
all_paths_set = set(path_list)
losers_by_path = all_paths_set - winners_by_path
lost_assets = [path_to_asset[p] for p in losers_by_path]
if lost_assets: # losers get their Asset removed
for id_chunk in _iter_chunks(lost_assets, MAX_BIND_PARAMS):
await session.execute(sa.delete(Asset).where(Asset.id.in_(id_chunk)))
if not winners_by_path:
return {"inserted_infos": 0, "won_states": 0, "lost_states": len(losers_by_path)}
# insert AssetInfo only for winners
winner_info_rows = [asset_to_info[path_to_asset[p]] for p in winners_by_path]
if dialect == "sqlite":
ins_info = (
d_sqlite.insert(AssetInfo)
.on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name])
.returning(AssetInfo.id)
)
else:
ins_info = (
d_pg.insert(AssetInfo)
.on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name])
.returning(AssetInfo.id)
)
inserted_info_ids: set[str] = set()
for chunk in _iter_chunks(winner_info_rows, _rows_per_stmt(9)):
inserted_info_ids.update((await session.execute(ins_info, chunk)).scalars().all())
# build and insert tag + meta rows for the AssetInfo
tag_rows: list[dict] = []
meta_rows: list[dict] = []
if inserted_info_ids:
for row in winner_info_rows:
iid = row["id"]
if iid not in inserted_info_ids:
continue
for t in row["_tags"]:
tag_rows.append({
"asset_info_id": iid,
"tag_name": t,
"origin": "automatic",
"added_at": now,
})
if row["_filename"]:
meta_rows.append(
{
"asset_info_id": iid,
"key": "filename",
"ordinal": 0,
"val_str": row["_filename"],
"val_num": None,
"val_bool": None,
"val_json": None,
}
)
await bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=meta_rows, max_bind_params=MAX_BIND_PARAMS)
return {
"inserted_infos": len(inserted_info_ids),
"won_states": len(winners_by_path),
"lost_states": len(losers_by_path),
}
async def bulk_insert_tags_and_meta(
session: AsyncSession,
*,
tag_rows: list[dict],
meta_rows: list[dict],
max_bind_params: int,
) -> None:
"""Batch insert into asset_info_tags and asset_info_meta with ON CONFLICT DO NOTHING.
- tag_rows keys: asset_info_id, tag_name, origin, added_at
- meta_rows keys: asset_info_id, key, ordinal, val_str, val_num, val_bool, val_json
"""
dialect = session.bind.dialect.name
if tag_rows:
if dialect == "sqlite":
ins_links = (
d_sqlite.insert(AssetInfoTag)
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
)
elif dialect == "postgresql":
ins_links = (
d_pg.insert(AssetInfoTag)
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
)
else:
raise NotImplementedError(f"Unsupported database dialect: {dialect}")
for chunk in _chunk_rows(tag_rows, cols_per_row=4, max_bind_params=max_bind_params):
await session.execute(ins_links, chunk)
if meta_rows:
if dialect == "sqlite":
ins_meta = (
d_sqlite.insert(AssetInfoMeta)
.on_conflict_do_nothing(
index_elements=[AssetInfoMeta.asset_info_id, AssetInfoMeta.key, AssetInfoMeta.ordinal]
)
)
elif dialect == "postgresql":
ins_meta = (
d_pg.insert(AssetInfoMeta)
.on_conflict_do_nothing(
index_elements=[AssetInfoMeta.asset_info_id, AssetInfoMeta.key, AssetInfoMeta.ordinal]
)
)
else:
raise NotImplementedError(f"Unsupported database dialect: {dialect}")
for chunk in _chunk_rows(meta_rows, cols_per_row=7, max_bind_params=max_bind_params):
await session.execute(ins_meta, chunk)
def _chunk_rows(rows: list[dict], cols_per_row: int, max_bind_params: int) -> Iterable[list[dict]]:
if not rows:
return []
rows_per_stmt = max(1, max_bind_params // max(1, cols_per_row))
for i in range(0, len(rows), rows_per_stmt):
yield rows[i:i + rows_per_stmt]
def _iter_chunks(seq, n: int):
for i in range(0, len(seq), n):
yield seq[i:i + n]
def _rows_per_stmt(cols: int) -> int:
return max(1, MAX_BIND_PARAMS // max(1, cols))

View File

@@ -0,0 +1,7 @@
def escape_like_prefix(s: str, escape: str = "!") -> tuple[str, str]:
"""Escapes %, _ and the escape char itself in a LIKE prefix.
Returns (escaped_prefix, escape_char). Caller should append '%' and pass escape=escape_char to .like().
"""
s = s.replace(escape, escape + escape) # escape the escape char first
s = s.replace("%", escape + "%").replace("_", escape + "_") # escape LIKE wildcards
return s, escape

View File

@@ -0,0 +1,19 @@
import os
from typing import Optional
def fast_asset_file_check(
*,
mtime_db: Optional[int],
size_db: Optional[int],
stat_result: os.stat_result,
) -> bool:
if mtime_db is None:
return False
actual_mtime_ns = getattr(stat_result, "st_mtime_ns", int(stat_result.st_mtime * 1_000_000_000))
if int(mtime_db) != int(actual_mtime_ns):
return False
sz = int(size_db or 0)
if sz > 0:
return int(stat_result.st_size) == sz
return True

View File

@@ -0,0 +1,87 @@
from typing import Optional, Sequence
import sqlalchemy as sa
from sqlalchemy import exists
from ..._helpers import normalize_tags
from ..models import AssetInfo, AssetInfoMeta, AssetInfoTag
def apply_tag_filters(
stmt: sa.sql.Select,
include_tags: Optional[Sequence[str]],
exclude_tags: Optional[Sequence[str]],
) -> sa.sql.Select:
"""include_tags: every tag must be present; exclude_tags: none may be present."""
include_tags = normalize_tags(include_tags)
exclude_tags = normalize_tags(exclude_tags)
if include_tags:
for tag_name in include_tags:
stmt = stmt.where(
exists().where(
(AssetInfoTag.asset_info_id == AssetInfo.id)
& (AssetInfoTag.tag_name == tag_name)
)
)
if exclude_tags:
stmt = stmt.where(
~exists().where(
(AssetInfoTag.asset_info_id == AssetInfo.id)
& (AssetInfoTag.tag_name.in_(exclude_tags))
)
)
return stmt
def apply_metadata_filter(
stmt: sa.sql.Select,
metadata_filter: Optional[dict],
) -> sa.sql.Select:
"""Apply filters using asset_info_meta projection table."""
if not metadata_filter:
return stmt
def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
return sa.exists().where(
AssetInfoMeta.asset_info_id == AssetInfo.id,
AssetInfoMeta.key == key,
*preds,
)
def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
if value is None:
no_row_for_key = sa.not_(
sa.exists().where(
AssetInfoMeta.asset_info_id == AssetInfo.id,
AssetInfoMeta.key == key,
)
)
null_row = _exists_for_pred(
key,
AssetInfoMeta.val_json.is_(None),
AssetInfoMeta.val_str.is_(None),
AssetInfoMeta.val_num.is_(None),
AssetInfoMeta.val_bool.is_(None),
)
return sa.or_(no_row_for_key, null_row)
if isinstance(value, bool):
return _exists_for_pred(key, AssetInfoMeta.val_bool == bool(value))
if isinstance(value, (int, float)):
from decimal import Decimal
num = value if isinstance(value, Decimal) else Decimal(str(value))
return _exists_for_pred(key, AssetInfoMeta.val_num == num)
if isinstance(value, str):
return _exists_for_pred(key, AssetInfoMeta.val_str == value)
return _exists_for_pred(key, AssetInfoMeta.val_json == value)
for k, v in metadata_filter.items():
if isinstance(v, list):
ors = [_exists_clause_for_value(k, elem) for elem in v]
if ors:
stmt = stmt.where(sa.or_(*ors))
else:
stmt = stmt.where(_exists_clause_for_value(k, v))
return stmt

View File

@@ -0,0 +1,12 @@
import sqlalchemy as sa
from ..models import AssetInfo
def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
"""Build owner visibility predicate for reads. Owner-less rows are visible to everyone."""
owner_id = (owner_id or "").strip()
if owner_id == "":
return AssetInfo.owner_id == ""
return AssetInfo.owner_id.in_(["", owner_id])

View File

@@ -0,0 +1,64 @@
from decimal import Decimal
def is_scalar(v):
if v is None:
return True
if isinstance(v, bool):
return True
if isinstance(v, (int, float, Decimal, str)):
return True
return False
def project_kv(key: str, value):
"""
Turn a metadata key/value into typed projection rows.
Returns list[dict] with keys:
key, ordinal, and one of val_str / val_num / val_bool / val_json (others None)
"""
rows: list[dict] = []
def _null_row(ordinal: int) -> dict:
return {
"key": key, "ordinal": ordinal,
"val_str": None, "val_num": None, "val_bool": None, "val_json": None
}
if value is None:
rows.append(_null_row(0))
return rows
if is_scalar(value):
if isinstance(value, bool):
rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)})
elif isinstance(value, (int, float, Decimal)):
num = value if isinstance(value, Decimal) else Decimal(str(value))
rows.append({"key": key, "ordinal": 0, "val_num": num})
elif isinstance(value, str):
rows.append({"key": key, "ordinal": 0, "val_str": value})
else:
rows.append({"key": key, "ordinal": 0, "val_json": value})
return rows
if isinstance(value, list):
if all(is_scalar(x) for x in value):
for i, x in enumerate(value):
if x is None:
rows.append(_null_row(i))
elif isinstance(x, bool):
rows.append({"key": key, "ordinal": i, "val_bool": bool(x)})
elif isinstance(x, (int, float, Decimal)):
num = x if isinstance(x, Decimal) else Decimal(str(x))
rows.append({"key": key, "ordinal": i, "val_num": num})
elif isinstance(x, str):
rows.append({"key": key, "ordinal": i, "val_str": x})
else:
rows.append({"key": key, "ordinal": i, "val_json": x})
return rows
for i, x in enumerate(value):
rows.append({"key": key, "ordinal": i, "val_json": x})
return rows
rows.append({"key": key, "ordinal": 0, "val_json": value})
return rows

View File

@@ -0,0 +1,90 @@
from typing import Iterable
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql as d_pg
from sqlalchemy.dialects import sqlite as d_sqlite
from sqlalchemy.ext.asyncio import AsyncSession
from ..._helpers import normalize_tags
from ..models import AssetInfo, AssetInfoTag, Tag
from ..timeutil import utcnow
async def ensure_tags_exist(session: AsyncSession, names: Iterable[str], tag_type: str = "user") -> None:
wanted = normalize_tags(list(names))
if not wanted:
return
rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
dialect = session.bind.dialect.name
if dialect == "sqlite":
ins = (
d_sqlite.insert(Tag)
.values(rows)
.on_conflict_do_nothing(index_elements=[Tag.name])
)
elif dialect == "postgresql":
ins = (
d_pg.insert(Tag)
.values(rows)
.on_conflict_do_nothing(index_elements=[Tag.name])
)
else:
raise NotImplementedError(f"Unsupported database dialect: {dialect}")
await session.execute(ins)
async def add_missing_tag_for_asset_id(
session: AsyncSession,
*,
asset_id: str,
origin: str = "automatic",
) -> None:
select_rows = (
sa.select(
AssetInfo.id.label("asset_info_id"),
sa.literal("missing").label("tag_name"),
sa.literal(origin).label("origin"),
sa.literal(utcnow()).label("added_at"),
)
.where(AssetInfo.asset_id == asset_id)
.where(
sa.not_(
sa.exists().where((AssetInfoTag.asset_info_id == AssetInfo.id) & (AssetInfoTag.tag_name == "missing"))
)
)
)
dialect = session.bind.dialect.name
if dialect == "sqlite":
ins = (
d_sqlite.insert(AssetInfoTag)
.from_select(
["asset_info_id", "tag_name", "origin", "added_at"],
select_rows,
)
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
)
elif dialect == "postgresql":
ins = (
d_pg.insert(AssetInfoTag)
.from_select(
["asset_info_id", "tag_name", "origin", "added_at"],
select_rows,
)
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
)
else:
raise NotImplementedError(f"Unsupported database dialect: {dialect}")
await session.execute(ins)
async def remove_missing_tag_for_asset_id(
session: AsyncSession,
*,
asset_id: str,
) -> None:
await session.execute(
sa.delete(AssetInfoTag).where(
AssetInfoTag.asset_info_id.in_(sa.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)),
AssetInfoTag.tag_name == "missing",
)
)

View File

@@ -0,0 +1,251 @@
import uuid
from datetime import datetime
from typing import Any, Optional
from sqlalchemy import (
JSON,
BigInteger,
Boolean,
CheckConstraint,
DateTime,
ForeignKey,
Index,
Integer,
Numeric,
String,
Text,
UniqueConstraint,
)
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import DeclarativeBase, Mapped, foreign, mapped_column, relationship
from .timeutil import utcnow
JSONB_V = JSON(none_as_null=True).with_variant(JSONB(none_as_null=True), 'postgresql')
class Base(DeclarativeBase):
pass
def to_dict(obj: Any, include_none: bool = False) -> dict[str, Any]:
fields = obj.__table__.columns.keys()
out: dict[str, Any] = {}
for field in fields:
val = getattr(obj, field)
if val is None and not include_none:
continue
if isinstance(val, datetime):
out[field] = val.isoformat()
else:
out[field] = val
return out
class Asset(Base):
__tablename__ = "assets"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
hash: Mapped[Optional[str]] = mapped_column(String(256), nullable=True)
size_bytes: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
mime_type: Mapped[Optional[str]] = mapped_column(String(255))
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=False), nullable=False, default=utcnow
)
infos: Mapped[list["AssetInfo"]] = relationship(
"AssetInfo",
back_populates="asset",
primaryjoin=lambda: Asset.id == foreign(AssetInfo.asset_id),
foreign_keys=lambda: [AssetInfo.asset_id],
cascade="all,delete-orphan",
passive_deletes=True,
)
preview_of: Mapped[list["AssetInfo"]] = relationship(
"AssetInfo",
back_populates="preview_asset",
primaryjoin=lambda: Asset.id == foreign(AssetInfo.preview_id),
foreign_keys=lambda: [AssetInfo.preview_id],
viewonly=True,
)
cache_states: Mapped[list["AssetCacheState"]] = relationship(
back_populates="asset",
cascade="all, delete-orphan",
passive_deletes=True,
)
__table_args__ = (
Index("uq_assets_hash", "hash", unique=True),
Index("ix_assets_mime_type", "mime_type"),
CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"),
)
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
return to_dict(self, include_none=include_none)
def __repr__(self) -> str:
return f"<Asset id={self.id} hash={(self.hash or '')[:12]}>"
class AssetCacheState(Base):
__tablename__ = "asset_cache_state"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="CASCADE"), nullable=False)
file_path: Mapped[str] = mapped_column(Text, nullable=False)
mtime_ns: Mapped[Optional[int]] = mapped_column(BigInteger, nullable=True)
needs_verify: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
asset: Mapped["Asset"] = relationship(back_populates="cache_states")
__table_args__ = (
Index("ix_asset_cache_state_file_path", "file_path"),
Index("ix_asset_cache_state_asset_id", "asset_id"),
CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"),
)
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
return to_dict(self, include_none=include_none)
def __repr__(self) -> str:
return f"<AssetCacheState id={self.id} asset_id={self.asset_id} path={self.file_path!r}>"
class AssetInfo(Base):
__tablename__ = "assets_info"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
owner_id: Mapped[str] = mapped_column(String(128), nullable=False, default="")
name: Mapped[str] = mapped_column(String(512), nullable=False)
asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False)
preview_id: Mapped[Optional[str]] = mapped_column(String(36), ForeignKey("assets.id", ondelete="SET NULL"))
user_metadata: Mapped[Optional[dict[str, Any]]] = mapped_column(JSON(none_as_null=True))
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
last_access_time: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
asset: Mapped[Asset] = relationship(
"Asset",
back_populates="infos",
foreign_keys=[asset_id],
lazy="selectin",
)
preview_asset: Mapped[Optional[Asset]] = relationship(
"Asset",
back_populates="preview_of",
foreign_keys=[preview_id],
)
metadata_entries: Mapped[list["AssetInfoMeta"]] = relationship(
back_populates="asset_info",
cascade="all,delete-orphan",
passive_deletes=True,
)
tag_links: Mapped[list["AssetInfoTag"]] = relationship(
back_populates="asset_info",
cascade="all,delete-orphan",
passive_deletes=True,
overlaps="tags,asset_infos",
)
tags: Mapped[list["Tag"]] = relationship(
secondary="asset_info_tags",
back_populates="asset_infos",
lazy="selectin",
viewonly=True,
overlaps="tag_links,asset_info_links,asset_infos,tag",
)
__table_args__ = (
UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"),
Index("ix_assets_info_owner_name", "owner_id", "name"),
Index("ix_assets_info_owner_id", "owner_id"),
Index("ix_assets_info_asset_id", "asset_id"),
Index("ix_assets_info_name", "name"),
Index("ix_assets_info_created_at", "created_at"),
Index("ix_assets_info_last_access_time", "last_access_time"),
)
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
data = to_dict(self, include_none=include_none)
data["tags"] = [t.name for t in self.tags]
return data
def __repr__(self) -> str:
return f"<AssetInfo id={self.id} name={self.name!r} asset_id={self.asset_id}>"
class AssetInfoMeta(Base):
__tablename__ = "asset_info_meta"
asset_info_id: Mapped[str] = mapped_column(
String(36), ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True
)
key: Mapped[str] = mapped_column(String(256), primary_key=True)
ordinal: Mapped[int] = mapped_column(Integer, primary_key=True, default=0)
val_str: Mapped[Optional[str]] = mapped_column(String(2048), nullable=True)
val_num: Mapped[Optional[float]] = mapped_column(Numeric(38, 10), nullable=True)
val_bool: Mapped[Optional[bool]] = mapped_column(Boolean, nullable=True)
val_json: Mapped[Optional[Any]] = mapped_column(JSONB_V, nullable=True)
asset_info: Mapped["AssetInfo"] = relationship(back_populates="metadata_entries")
__table_args__ = (
Index("ix_asset_info_meta_key", "key"),
Index("ix_asset_info_meta_key_val_str", "key", "val_str"),
Index("ix_asset_info_meta_key_val_num", "key", "val_num"),
Index("ix_asset_info_meta_key_val_bool", "key", "val_bool"),
)
class AssetInfoTag(Base):
__tablename__ = "asset_info_tags"
asset_info_id: Mapped[str] = mapped_column(
String(36), ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True
)
tag_name: Mapped[str] = mapped_column(
String(512), ForeignKey("tags.name", ondelete="RESTRICT"), primary_key=True
)
origin: Mapped[str] = mapped_column(String(32), nullable=False, default="manual")
added_at: Mapped[datetime] = mapped_column(
DateTime(timezone=False), nullable=False, default=utcnow
)
asset_info: Mapped["AssetInfo"] = relationship(back_populates="tag_links")
tag: Mapped["Tag"] = relationship(back_populates="asset_info_links")
__table_args__ = (
Index("ix_asset_info_tags_tag_name", "tag_name"),
Index("ix_asset_info_tags_asset_info_id", "asset_info_id"),
)
class Tag(Base):
__tablename__ = "tags"
name: Mapped[str] = mapped_column(String(512), primary_key=True)
tag_type: Mapped[str] = mapped_column(String(32), nullable=False, default="user")
asset_info_links: Mapped[list["AssetInfoTag"]] = relationship(
back_populates="tag",
overlaps="asset_infos,tags",
)
asset_infos: Mapped[list["AssetInfo"]] = relationship(
secondary="asset_info_tags",
back_populates="tags",
viewonly=True,
overlaps="asset_info_links,tag_links,tags,asset_info",
)
__table_args__ = (
Index("ix_tags_tag_type", "tag_type"),
)
def __repr__(self) -> str:
return f"<Tag {self.name}>"

View File

@@ -0,0 +1,57 @@
from .content import (
check_fs_asset_exists_quick,
compute_hash_and_dedup_for_cache_state,
ingest_fs_asset,
list_cache_states_with_asset_under_prefixes,
list_unhashed_candidates_under_prefixes,
list_verify_candidates_under_prefixes,
redirect_all_references_then_delete_asset,
touch_asset_infos_by_fs_path,
)
from .info import (
add_tags_to_asset_info,
create_asset_info_for_existing_asset,
delete_asset_info_by_id,
fetch_asset_info_and_asset,
fetch_asset_info_asset_and_tags,
get_asset_tags,
list_asset_infos_page,
list_tags_with_usage,
remove_tags_from_asset_info,
replace_asset_info_metadata_projection,
set_asset_info_preview,
set_asset_info_tags,
touch_asset_info_by_id,
update_asset_info_full,
)
from .queries import (
asset_exists_by_hash,
asset_info_exists_for_asset_id,
get_asset_by_hash,
get_asset_info_by_id,
get_cache_state_by_asset_id,
list_cache_states_by_asset_id,
pick_best_live_path,
)
__all__ = [
# queries
"asset_exists_by_hash", "get_asset_by_hash", "get_asset_info_by_id", "asset_info_exists_for_asset_id",
"get_cache_state_by_asset_id",
"list_cache_states_by_asset_id",
"pick_best_live_path",
# info
"list_asset_infos_page", "create_asset_info_for_existing_asset", "set_asset_info_tags",
"update_asset_info_full", "replace_asset_info_metadata_projection",
"touch_asset_info_by_id", "delete_asset_info_by_id",
"add_tags_to_asset_info", "remove_tags_from_asset_info",
"get_asset_tags", "list_tags_with_usage", "set_asset_info_preview",
"fetch_asset_info_and_asset", "fetch_asset_info_asset_and_tags",
# content
"check_fs_asset_exists_quick",
"redirect_all_references_then_delete_asset",
"compute_hash_and_dedup_for_cache_state",
"list_unhashed_candidates_under_prefixes", "list_verify_candidates_under_prefixes",
"ingest_fs_asset", "touch_asset_infos_by_fs_path",
"list_cache_states_with_asset_under_prefixes",
]

View File

@@ -0,0 +1,721 @@
import contextlib
import logging
import os
from datetime import datetime
from typing import Any, Optional, Sequence, Union
import sqlalchemy as sa
from sqlalchemy import select
from sqlalchemy.dialects import postgresql as d_pg
from sqlalchemy.dialects import sqlite as d_sqlite
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import noload
from ..._helpers import compute_relative_filename
from ...storage import hashing as hashing_mod
from ..helpers import (
ensure_tags_exist,
escape_like_prefix,
remove_missing_tag_for_asset_id,
)
from ..models import Asset, AssetCacheState, AssetInfo, AssetInfoTag, Tag
from ..timeutil import utcnow
from .info import replace_asset_info_metadata_projection
from .queries import list_cache_states_by_asset_id, pick_best_live_path
async def check_fs_asset_exists_quick(
session: AsyncSession,
*,
file_path: str,
size_bytes: Optional[int] = None,
mtime_ns: Optional[int] = None,
) -> bool:
"""Returns True if we already track this absolute path with a HASHED asset and the cached mtime/size match."""
locator = os.path.abspath(file_path)
stmt = (
sa.select(sa.literal(True))
.select_from(AssetCacheState)
.join(Asset, Asset.id == AssetCacheState.asset_id)
.where(
AssetCacheState.file_path == locator,
Asset.hash.isnot(None),
AssetCacheState.needs_verify.is_(False),
)
.limit(1)
)
conds = []
if mtime_ns is not None:
conds.append(AssetCacheState.mtime_ns == int(mtime_ns))
if size_bytes is not None:
conds.append(sa.or_(Asset.size_bytes == 0, Asset.size_bytes == int(size_bytes)))
if conds:
stmt = stmt.where(*conds)
return (await session.execute(stmt)).first() is not None
async def redirect_all_references_then_delete_asset(
session: AsyncSession,
*,
duplicate_asset_id: str,
canonical_asset_id: str,
) -> None:
"""
Safely migrate all references from duplicate_asset_id to canonical_asset_id.
- If an AssetInfo for (owner_id, name) already exists on the canonical asset,
merge tags, metadata, times, and preview, then delete the duplicate AssetInfo.
- Otherwise, simply repoint the AssetInfo.asset_id.
- Always retarget AssetCacheState rows.
- Finally delete the duplicate Asset row.
"""
if duplicate_asset_id == canonical_asset_id:
return
# 1) Migrate AssetInfo rows one-by-one to avoid UNIQUE conflicts.
dup_infos = (
await session.execute(
select(AssetInfo).options(noload(AssetInfo.tags)).where(AssetInfo.asset_id == duplicate_asset_id)
)
).unique().scalars().all()
for info in dup_infos:
# Try to find an existing collision on canonical
existing = (
await session.execute(
select(AssetInfo)
.options(noload(AssetInfo.tags))
.where(
AssetInfo.asset_id == canonical_asset_id,
AssetInfo.owner_id == info.owner_id,
AssetInfo.name == info.name,
)
.limit(1)
)
).unique().scalars().first()
if existing:
merged_meta = dict(existing.user_metadata or {})
other_meta = info.user_metadata or {}
for k, v in other_meta.items():
if k not in merged_meta:
merged_meta[k] = v
if merged_meta != (existing.user_metadata or {}):
await replace_asset_info_metadata_projection(
session,
asset_info_id=existing.id,
user_metadata=merged_meta,
)
existing_tags = {
t for (t,) in (
await session.execute(
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == existing.id)
)
).all()
}
from_tags = {
t for (t,) in (
await session.execute(
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == info.id)
)
).all()
}
to_add = sorted(from_tags - existing_tags)
if to_add:
await ensure_tags_exist(session, to_add, tag_type="user")
now = utcnow()
session.add_all([
AssetInfoTag(asset_info_id=existing.id, tag_name=t, origin="automatic", added_at=now)
for t in to_add
])
await session.flush()
if existing.preview_id is None and info.preview_id is not None:
existing.preview_id = info.preview_id
if info.last_access_time and (
existing.last_access_time is None or info.last_access_time > existing.last_access_time
):
existing.last_access_time = info.last_access_time
existing.updated_at = utcnow()
await session.flush()
# Delete the duplicate AssetInfo (cascades will clean its tags/meta)
await session.delete(info)
await session.flush()
else:
# Simple retarget
info.asset_id = canonical_asset_id
info.updated_at = utcnow()
await session.flush()
# 2) Repoint cache states and previews
await session.execute(
sa.update(AssetCacheState)
.where(AssetCacheState.asset_id == duplicate_asset_id)
.values(asset_id=canonical_asset_id)
)
await session.execute(
sa.update(AssetInfo)
.where(AssetInfo.preview_id == duplicate_asset_id)
.values(preview_id=canonical_asset_id)
)
# 3) Remove duplicate Asset
dup = await session.get(Asset, duplicate_asset_id)
if dup:
await session.delete(dup)
await session.flush()
async def compute_hash_and_dedup_for_cache_state(
session: AsyncSession,
*,
state_id: int,
) -> Optional[str]:
"""
Compute hash for the given cache state, deduplicate, and settle verify cases.
Returns the asset_id that this state ends up pointing to, or None if file disappeared.
"""
state = await session.get(AssetCacheState, state_id)
if not state:
return None
path = state.file_path
try:
if not os.path.isfile(path):
# File vanished: drop the state. If the Asset has hash=NULL and has no other states, drop the Asset too.
asset = await session.get(Asset, state.asset_id)
await session.delete(state)
await session.flush()
if asset and asset.hash is None:
remaining = (
await session.execute(
sa.select(sa.func.count())
.select_from(AssetCacheState)
.where(AssetCacheState.asset_id == asset.id)
)
).scalar_one()
if int(remaining or 0) == 0:
await session.delete(asset)
await session.flush()
else:
await _recompute_and_apply_filename_for_asset(session, asset_id=asset.id)
return None
digest = await hashing_mod.blake3_hash(path)
new_hash = f"blake3:{digest}"
st = os.stat(path, follow_symlinks=True)
new_size = int(st.st_size)
mtime_ns = getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
# Current asset of this state
this_asset = await session.get(Asset, state.asset_id)
# If the state got orphaned somehow (race), just reattach appropriately.
if not this_asset:
canonical = (
await session.execute(sa.select(Asset).where(Asset.hash == new_hash).limit(1))
).scalars().first()
if canonical:
state.asset_id = canonical.id
else:
now = utcnow()
new_asset = Asset(hash=new_hash, size_bytes=new_size, mime_type=None, created_at=now)
session.add(new_asset)
await session.flush()
state.asset_id = new_asset.id
state.mtime_ns = mtime_ns
state.needs_verify = False
with contextlib.suppress(Exception):
await remove_missing_tag_for_asset_id(session, asset_id=state.asset_id)
await session.flush()
return state.asset_id
# 1) Seed asset case (hash is NULL): claim or merge into canonical
if this_asset.hash is None:
canonical = (
await session.execute(sa.select(Asset).where(Asset.hash == new_hash).limit(1))
).scalars().first()
if canonical and canonical.id != this_asset.id:
# Merge seed asset into canonical (safe, collision-aware)
await redirect_all_references_then_delete_asset(
session,
duplicate_asset_id=this_asset.id,
canonical_asset_id=canonical.id,
)
state = await session.get(AssetCacheState, state_id)
if state:
state.mtime_ns = mtime_ns
state.needs_verify = False
with contextlib.suppress(Exception):
await remove_missing_tag_for_asset_id(session, asset_id=canonical.id)
await _recompute_and_apply_filename_for_asset(session, asset_id=canonical.id)
await session.flush()
return canonical.id
# No canonical: try to claim the hash; handle races with a SAVEPOINT
try:
async with session.begin_nested():
this_asset.hash = new_hash
if int(this_asset.size_bytes or 0) == 0 and new_size > 0:
this_asset.size_bytes = new_size
await session.flush()
except IntegrityError:
# Someone else claimed it concurrently; fetch canonical and merge
canonical = (
await session.execute(sa.select(Asset).where(Asset.hash == new_hash).limit(1))
).scalars().first()
if canonical and canonical.id != this_asset.id:
await redirect_all_references_then_delete_asset(
session,
duplicate_asset_id=this_asset.id,
canonical_asset_id=canonical.id,
)
state = await session.get(AssetCacheState, state_id)
if state:
state.mtime_ns = mtime_ns
state.needs_verify = False
with contextlib.suppress(Exception):
await remove_missing_tag_for_asset_id(session, asset_id=canonical.id)
await _recompute_and_apply_filename_for_asset(session, asset_id=canonical.id)
await session.flush()
return canonical.id
# If we got here, the integrity error was not about hash uniqueness
raise
# Claimed successfully
state.mtime_ns = mtime_ns
state.needs_verify = False
with contextlib.suppress(Exception):
await remove_missing_tag_for_asset_id(session, asset_id=this_asset.id)
await _recompute_and_apply_filename_for_asset(session, asset_id=this_asset.id)
await session.flush()
return this_asset.id
# 2) Verify case for hashed assets
if this_asset.hash == new_hash:
if int(this_asset.size_bytes or 0) == 0 and new_size > 0:
this_asset.size_bytes = new_size
state.mtime_ns = mtime_ns
state.needs_verify = False
with contextlib.suppress(Exception):
await remove_missing_tag_for_asset_id(session, asset_id=this_asset.id)
await _recompute_and_apply_filename_for_asset(session, asset_id=this_asset.id)
await session.flush()
return this_asset.id
# Content changed on this path only: retarget THIS state, do not move AssetInfo rows
canonical = (
await session.execute(sa.select(Asset).where(Asset.hash == new_hash).limit(1))
).scalars().first()
if canonical:
target_id = canonical.id
else:
now = utcnow()
new_asset = Asset(hash=new_hash, size_bytes=new_size, mime_type=None, created_at=now)
session.add(new_asset)
await session.flush()
target_id = new_asset.id
state.asset_id = target_id
state.mtime_ns = mtime_ns
state.needs_verify = False
with contextlib.suppress(Exception):
await remove_missing_tag_for_asset_id(session, asset_id=target_id)
await _recompute_and_apply_filename_for_asset(session, asset_id=target_id)
await session.flush()
return target_id
except Exception:
raise
async def list_unhashed_candidates_under_prefixes(session: AsyncSession, *, prefixes: list[str]) -> list[int]:
if not prefixes:
return []
conds = []
for p in prefixes:
base = os.path.abspath(p)
if not base.endswith(os.sep):
base += os.sep
escaped, esc = escape_like_prefix(base)
conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc))
path_filter = sa.or_(*conds) if len(conds) > 1 else conds[0]
if session.bind.dialect.name == "postgresql":
stmt = (
sa.select(AssetCacheState.id)
.join(Asset, Asset.id == AssetCacheState.asset_id)
.where(Asset.hash.is_(None), path_filter)
.order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc())
.distinct(AssetCacheState.asset_id)
)
else:
first_id = sa.func.min(AssetCacheState.id).label("first_id")
stmt = (
sa.select(first_id)
.join(Asset, Asset.id == AssetCacheState.asset_id)
.where(Asset.hash.is_(None), path_filter)
.group_by(AssetCacheState.asset_id)
.order_by(first_id.asc())
)
return [int(x) for x in (await session.execute(stmt)).scalars().all()]
async def list_verify_candidates_under_prefixes(
session: AsyncSession, *, prefixes: Sequence[str]
) -> Union[list[int], Sequence[int]]:
if not prefixes:
return []
conds = []
for p in prefixes:
base = os.path.abspath(p)
if not base.endswith(os.sep):
base += os.sep
escaped, esc = escape_like_prefix(base)
conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc))
return (
await session.execute(
sa.select(AssetCacheState.id)
.where(AssetCacheState.needs_verify.is_(True))
.where(sa.or_(*conds))
.order_by(AssetCacheState.id.asc())
)
).scalars().all()
async def ingest_fs_asset(
session: AsyncSession,
*,
asset_hash: str,
abs_path: str,
size_bytes: int,
mtime_ns: int,
mime_type: Optional[str] = None,
info_name: Optional[str] = None,
owner_id: str = "",
preview_id: Optional[str] = None,
user_metadata: Optional[dict] = None,
tags: Sequence[str] = (),
tag_origin: str = "manual",
require_existing_tags: bool = False,
) -> dict:
"""
Idempotently upsert:
- Asset by content hash (create if missing)
- AssetCacheState(file_path) pointing to asset_id
- Optionally AssetInfo + tag links and metadata projection
Returns flags and ids.
"""
locator = os.path.abspath(abs_path)
now = utcnow()
dialect = session.bind.dialect.name
if preview_id:
if not await session.get(Asset, preview_id):
preview_id = None
out: dict[str, Any] = {
"asset_created": False,
"asset_updated": False,
"state_created": False,
"state_updated": False,
"asset_info_id": None,
}
# 1) Asset by hash
asset = (
await session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
).scalars().first()
if not asset:
vals = {
"hash": asset_hash,
"size_bytes": int(size_bytes),
"mime_type": mime_type,
"created_at": now,
}
if dialect == "sqlite":
res = await session.execute(
d_sqlite.insert(Asset)
.values(**vals)
.on_conflict_do_nothing(index_elements=[Asset.hash])
)
if int(res.rowcount or 0) > 0:
out["asset_created"] = True
asset = (
await session.execute(
select(Asset).where(Asset.hash == asset_hash).limit(1)
)
).scalars().first()
elif dialect == "postgresql":
res = await session.execute(
d_pg.insert(Asset)
.values(**vals)
.on_conflict_do_nothing(
index_elements=[Asset.hash],
index_where=Asset.__table__.c.hash.isnot(None),
)
.returning(Asset.id)
)
inserted_id = res.scalar_one_or_none()
if inserted_id:
out["asset_created"] = True
asset = await session.get(Asset, inserted_id)
else:
asset = (
await session.execute(
select(Asset).where(Asset.hash == asset_hash).limit(1)
)
).scalars().first()
else:
raise NotImplementedError(f"Unsupported database dialect: {dialect}")
if not asset:
raise RuntimeError("Asset row not found after upsert.")
else:
changed = False
if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0:
asset.size_bytes = int(size_bytes)
changed = True
if mime_type and asset.mime_type != mime_type:
asset.mime_type = mime_type
changed = True
if changed:
out["asset_updated"] = True
# 2) AssetCacheState upsert by file_path (unique)
vals = {
"asset_id": asset.id,
"file_path": locator,
"mtime_ns": int(mtime_ns),
}
if dialect == "sqlite":
ins = (
d_sqlite.insert(AssetCacheState)
.values(**vals)
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
)
elif dialect == "postgresql":
ins = (
d_pg.insert(AssetCacheState)
.values(**vals)
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
)
else:
raise NotImplementedError(f"Unsupported database dialect: {dialect}")
res = await session.execute(ins)
if int(res.rowcount or 0) > 0:
out["state_created"] = True
else:
upd = (
sa.update(AssetCacheState)
.where(AssetCacheState.file_path == locator)
.where(
sa.or_(
AssetCacheState.asset_id != asset.id,
AssetCacheState.mtime_ns.is_(None),
AssetCacheState.mtime_ns != int(mtime_ns),
)
)
.values(asset_id=asset.id, mtime_ns=int(mtime_ns))
)
res2 = await session.execute(upd)
if int(res2.rowcount or 0) > 0:
out["state_updated"] = True
# 3) Optional AssetInfo + tags + metadata
if info_name:
try:
async with session.begin_nested():
info = AssetInfo(
owner_id=owner_id,
name=info_name,
asset_id=asset.id,
preview_id=preview_id,
created_at=now,
updated_at=now,
last_access_time=now,
)
session.add(info)
await session.flush()
out["asset_info_id"] = info.id
except IntegrityError:
pass
existing_info = (
await session.execute(
select(AssetInfo)
.where(
AssetInfo.asset_id == asset.id,
AssetInfo.name == info_name,
(AssetInfo.owner_id == owner_id),
)
.limit(1)
)
).unique().scalar_one_or_none()
if not existing_info:
raise RuntimeError("Failed to update or insert AssetInfo.")
if preview_id and existing_info.preview_id != preview_id:
existing_info.preview_id = preview_id
existing_info.updated_at = now
if existing_info.last_access_time < now:
existing_info.last_access_time = now
await session.flush()
out["asset_info_id"] = existing_info.id
norm = [t.strip().lower() for t in (tags or []) if (t or "").strip()]
if norm and out["asset_info_id"] is not None:
if not require_existing_tags:
await ensure_tags_exist(session, norm, tag_type="user")
existing_tag_names = set(
name for (name,) in (await session.execute(select(Tag.name).where(Tag.name.in_(norm)))).all()
)
missing = [t for t in norm if t not in existing_tag_names]
if missing and require_existing_tags:
raise ValueError(f"Unknown tags: {missing}")
existing_links = set(
tag_name
for (tag_name,) in (
await session.execute(
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == out["asset_info_id"])
)
).all()
)
to_add = [t for t in norm if t in existing_tag_names and t not in existing_links]
if to_add:
session.add_all(
[
AssetInfoTag(
asset_info_id=out["asset_info_id"],
tag_name=t,
origin=tag_origin,
added_at=now,
)
for t in to_add
]
)
await session.flush()
# metadata["filename"] hack
if out["asset_info_id"] is not None:
primary_path = pick_best_live_path(await list_cache_states_by_asset_id(session, asset_id=asset.id))
computed_filename = compute_relative_filename(primary_path) if primary_path else None
current_meta = existing_info.user_metadata or {}
new_meta = dict(current_meta)
if user_metadata is not None:
for k, v in user_metadata.items():
new_meta[k] = v
if computed_filename:
new_meta["filename"] = computed_filename
if new_meta != current_meta:
await replace_asset_info_metadata_projection(
session,
asset_info_id=out["asset_info_id"],
user_metadata=new_meta,
)
try:
await remove_missing_tag_for_asset_id(session, asset_id=asset.id)
except Exception:
logging.exception("Failed to clear 'missing' tag for asset %s", asset.id)
return out
async def touch_asset_infos_by_fs_path(
session: AsyncSession,
*,
file_path: str,
ts: Optional[datetime] = None,
only_if_newer: bool = True,
) -> None:
locator = os.path.abspath(file_path)
ts = ts or utcnow()
stmt = sa.update(AssetInfo).where(
sa.exists(
sa.select(sa.literal(1))
.select_from(AssetCacheState)
.where(
AssetCacheState.asset_id == AssetInfo.asset_id,
AssetCacheState.file_path == locator,
)
)
)
if only_if_newer:
stmt = stmt.where(
sa.or_(
AssetInfo.last_access_time.is_(None),
AssetInfo.last_access_time < ts,
)
)
await session.execute(stmt.values(last_access_time=ts))
async def list_cache_states_with_asset_under_prefixes(
session: AsyncSession,
*,
prefixes: Sequence[str],
) -> list[tuple[AssetCacheState, Optional[str], int]]:
"""Return (AssetCacheState, asset_hash, size_bytes) for rows under any prefix."""
if not prefixes:
return []
conds = []
for p in prefixes:
if not p:
continue
base = os.path.abspath(p)
if not base.endswith(os.sep):
base = base + os.sep
escaped, esc = escape_like_prefix(base)
conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc))
if not conds:
return []
rows = (
await session.execute(
select(AssetCacheState, Asset.hash, Asset.size_bytes)
.join(Asset, Asset.id == AssetCacheState.asset_id)
.where(sa.or_(*conds))
.order_by(AssetCacheState.id.asc())
)
).all()
return [(r[0], r[1], int(r[2] or 0)) for r in rows]
async def _recompute_and_apply_filename_for_asset(session: AsyncSession, *, asset_id: str) -> None:
"""Compute filename from the first *existing* cache state path and apply it to all AssetInfo (if changed)."""
try:
primary_path = pick_best_live_path(await list_cache_states_by_asset_id(session, asset_id=asset_id))
if not primary_path:
return
new_filename = compute_relative_filename(primary_path)
if not new_filename:
return
infos = (
await session.execute(select(AssetInfo).where(AssetInfo.asset_id == asset_id))
).scalars().all()
for info in infos:
current_meta = info.user_metadata or {}
if current_meta.get("filename") == new_filename:
continue
updated = dict(current_meta)
updated["filename"] = new_filename
await replace_asset_info_metadata_projection(session, asset_info_id=info.id, user_metadata=updated)
except Exception:
logging.exception("Failed to recompute filename metadata for asset %s", asset_id)

View File

@@ -0,0 +1,586 @@
from collections import defaultdict
from datetime import datetime
from typing import Any, Optional, Sequence
import sqlalchemy as sa
from sqlalchemy import delete, func, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import contains_eager, noload
from ..._helpers import compute_relative_filename, normalize_tags
from ..helpers import (
apply_metadata_filter,
apply_tag_filters,
ensure_tags_exist,
escape_like_prefix,
project_kv,
visible_owner_clause,
)
from ..models import Asset, AssetInfo, AssetInfoMeta, AssetInfoTag, Tag
from ..timeutil import utcnow
from .queries import (
get_asset_by_hash,
list_cache_states_by_asset_id,
pick_best_live_path,
)
async def list_asset_infos_page(
session: AsyncSession,
*,
owner_id: str = "",
include_tags: Optional[Sequence[str]] = None,
exclude_tags: Optional[Sequence[str]] = None,
name_contains: Optional[str] = None,
metadata_filter: Optional[dict] = None,
limit: int = 20,
offset: int = 0,
sort: str = "created_at",
order: str = "desc",
) -> tuple[list[AssetInfo], dict[str, list[str]], int]:
base = (
select(AssetInfo)
.join(Asset, Asset.id == AssetInfo.asset_id)
.options(contains_eager(AssetInfo.asset), noload(AssetInfo.tags))
.where(visible_owner_clause(owner_id))
)
if name_contains:
escaped, esc = escape_like_prefix(name_contains)
base = base.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc))
base = apply_tag_filters(base, include_tags, exclude_tags)
base = apply_metadata_filter(base, metadata_filter)
sort = (sort or "created_at").lower()
order = (order or "desc").lower()
sort_map = {
"name": AssetInfo.name,
"created_at": AssetInfo.created_at,
"updated_at": AssetInfo.updated_at,
"last_access_time": AssetInfo.last_access_time,
"size": Asset.size_bytes,
}
sort_col = sort_map.get(sort, AssetInfo.created_at)
sort_exp = sort_col.desc() if order == "desc" else sort_col.asc()
base = base.order_by(sort_exp).limit(limit).offset(offset)
count_stmt = (
select(func.count())
.select_from(AssetInfo)
.join(Asset, Asset.id == AssetInfo.asset_id)
.where(visible_owner_clause(owner_id))
)
if name_contains:
escaped, esc = escape_like_prefix(name_contains)
count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc))
count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags)
count_stmt = apply_metadata_filter(count_stmt, metadata_filter)
total = int((await session.execute(count_stmt)).scalar_one() or 0)
infos = (await session.execute(base)).unique().scalars().all()
id_list: list[str] = [i.id for i in infos]
tag_map: dict[str, list[str]] = defaultdict(list)
if id_list:
rows = await session.execute(
select(AssetInfoTag.asset_info_id, Tag.name)
.join(Tag, Tag.name == AssetInfoTag.tag_name)
.where(AssetInfoTag.asset_info_id.in_(id_list))
)
for aid, tag_name in rows.all():
tag_map[aid].append(tag_name)
return infos, tag_map, total
async def fetch_asset_info_and_asset(
session: AsyncSession,
*,
asset_info_id: str,
owner_id: str = "",
) -> Optional[tuple[AssetInfo, Asset]]:
stmt = (
select(AssetInfo, Asset)
.join(Asset, Asset.id == AssetInfo.asset_id)
.where(
AssetInfo.id == asset_info_id,
visible_owner_clause(owner_id),
)
.limit(1)
.options(noload(AssetInfo.tags))
)
row = await session.execute(stmt)
pair = row.first()
if not pair:
return None
return pair[0], pair[1]
async def fetch_asset_info_asset_and_tags(
session: AsyncSession,
*,
asset_info_id: str,
owner_id: str = "",
) -> Optional[tuple[AssetInfo, Asset, list[str]]]:
stmt = (
select(AssetInfo, Asset, Tag.name)
.join(Asset, Asset.id == AssetInfo.asset_id)
.join(AssetInfoTag, AssetInfoTag.asset_info_id == AssetInfo.id, isouter=True)
.join(Tag, Tag.name == AssetInfoTag.tag_name, isouter=True)
.where(
AssetInfo.id == asset_info_id,
visible_owner_clause(owner_id),
)
.options(noload(AssetInfo.tags))
.order_by(Tag.name.asc())
)
rows = (await session.execute(stmt)).all()
if not rows:
return None
first_info, first_asset, _ = rows[0]
tags: list[str] = []
seen: set[str] = set()
for _info, _asset, tag_name in rows:
if tag_name and tag_name not in seen:
seen.add(tag_name)
tags.append(tag_name)
return first_info, first_asset, tags
async def create_asset_info_for_existing_asset(
session: AsyncSession,
*,
asset_hash: str,
name: str,
user_metadata: Optional[dict] = None,
tags: Optional[Sequence[str]] = None,
tag_origin: str = "manual",
owner_id: str = "",
) -> AssetInfo:
"""Create or return an existing AssetInfo for an Asset identified by asset_hash."""
now = utcnow()
asset = await get_asset_by_hash(session, asset_hash=asset_hash)
if not asset:
raise ValueError(f"Unknown asset hash {asset_hash}")
info = AssetInfo(
owner_id=owner_id,
name=name,
asset_id=asset.id,
preview_id=None,
created_at=now,
updated_at=now,
last_access_time=now,
)
try:
async with session.begin_nested():
session.add(info)
await session.flush()
except IntegrityError:
existing = (
await session.execute(
select(AssetInfo)
.options(noload(AssetInfo.tags))
.where(
AssetInfo.asset_id == asset.id,
AssetInfo.name == name,
AssetInfo.owner_id == owner_id,
)
.limit(1)
)
).unique().scalars().first()
if not existing:
raise RuntimeError("AssetInfo upsert failed to find existing row after conflict.")
return existing
# metadata["filename"] hack
new_meta = dict(user_metadata or {})
computed_filename = None
try:
p = pick_best_live_path(await list_cache_states_by_asset_id(session, asset_id=asset.id))
if p:
computed_filename = compute_relative_filename(p)
except Exception:
computed_filename = None
if computed_filename:
new_meta["filename"] = computed_filename
if new_meta:
await replace_asset_info_metadata_projection(
session,
asset_info_id=info.id,
user_metadata=new_meta,
)
if tags is not None:
await set_asset_info_tags(
session,
asset_info_id=info.id,
tags=tags,
origin=tag_origin,
)
return info
async def set_asset_info_tags(
session: AsyncSession,
*,
asset_info_id: str,
tags: Sequence[str],
origin: str = "manual",
) -> dict:
desired = normalize_tags(tags)
current = set(
tag_name for (tag_name,) in (
await session.execute(select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id))
).all()
)
to_add = [t for t in desired if t not in current]
to_remove = [t for t in current if t not in desired]
if to_add:
await ensure_tags_exist(session, to_add, tag_type="user")
session.add_all([
AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_at=utcnow())
for t in to_add
])
await session.flush()
if to_remove:
await session.execute(
delete(AssetInfoTag)
.where(AssetInfoTag.asset_info_id == asset_info_id, AssetInfoTag.tag_name.in_(to_remove))
)
await session.flush()
return {"added": to_add, "removed": to_remove, "total": desired}
async def update_asset_info_full(
session: AsyncSession,
*,
asset_info_id: str,
name: Optional[str] = None,
tags: Optional[Sequence[str]] = None,
user_metadata: Optional[dict] = None,
tag_origin: str = "manual",
asset_info_row: Any = None,
) -> AssetInfo:
if not asset_info_row:
info = await session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
else:
info = asset_info_row
touched = False
if name is not None and name != info.name:
info.name = name
touched = True
computed_filename = None
try:
p = pick_best_live_path(await list_cache_states_by_asset_id(session, asset_id=info.asset_id))
if p:
computed_filename = compute_relative_filename(p)
except Exception:
computed_filename = None
if user_metadata is not None:
new_meta = dict(user_metadata)
if computed_filename:
new_meta["filename"] = computed_filename
await replace_asset_info_metadata_projection(
session, asset_info_id=asset_info_id, user_metadata=new_meta
)
touched = True
else:
if computed_filename:
current_meta = info.user_metadata or {}
if current_meta.get("filename") != computed_filename:
new_meta = dict(current_meta)
new_meta["filename"] = computed_filename
await replace_asset_info_metadata_projection(
session, asset_info_id=asset_info_id, user_metadata=new_meta
)
touched = True
if tags is not None:
await set_asset_info_tags(
session,
asset_info_id=asset_info_id,
tags=tags,
origin=tag_origin,
)
touched = True
if touched and user_metadata is None:
info.updated_at = utcnow()
await session.flush()
return info
async def replace_asset_info_metadata_projection(
session: AsyncSession,
*,
asset_info_id: str,
user_metadata: Optional[dict],
) -> None:
info = await session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
info.user_metadata = user_metadata or {}
info.updated_at = utcnow()
await session.flush()
await session.execute(delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id))
await session.flush()
if not user_metadata:
return
rows: list[AssetInfoMeta] = []
for k, v in user_metadata.items():
for r in project_kv(k, v):
rows.append(
AssetInfoMeta(
asset_info_id=asset_info_id,
key=r["key"],
ordinal=int(r["ordinal"]),
val_str=r.get("val_str"),
val_num=r.get("val_num"),
val_bool=r.get("val_bool"),
val_json=r.get("val_json"),
)
)
if rows:
session.add_all(rows)
await session.flush()
async def touch_asset_info_by_id(
session: AsyncSession,
*,
asset_info_id: str,
ts: Optional[datetime] = None,
only_if_newer: bool = True,
) -> None:
ts = ts or utcnow()
stmt = sa.update(AssetInfo).where(AssetInfo.id == asset_info_id)
if only_if_newer:
stmt = stmt.where(
sa.or_(AssetInfo.last_access_time.is_(None), AssetInfo.last_access_time < ts)
)
await session.execute(stmt.values(last_access_time=ts))
async def delete_asset_info_by_id(session: AsyncSession, *, asset_info_id: str, owner_id: str) -> bool:
stmt = sa.delete(AssetInfo).where(
AssetInfo.id == asset_info_id,
visible_owner_clause(owner_id),
)
return int((await session.execute(stmt)).rowcount or 0) > 0
async def add_tags_to_asset_info(
session: AsyncSession,
*,
asset_info_id: str,
tags: Sequence[str],
origin: str = "manual",
create_if_missing: bool = True,
asset_info_row: Any = None,
) -> dict:
if not asset_info_row:
info = await session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
norm = normalize_tags(tags)
if not norm:
total = await get_asset_tags(session, asset_info_id=asset_info_id)
return {"added": [], "already_present": [], "total_tags": total}
if create_if_missing:
await ensure_tags_exist(session, norm, tag_type="user")
current = {
tag_name
for (tag_name,) in (
await session.execute(
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
)
).all()
}
want = set(norm)
to_add = sorted(want - current)
if to_add:
async with session.begin_nested() as nested:
try:
session.add_all(
[
AssetInfoTag(
asset_info_id=asset_info_id,
tag_name=t,
origin=origin,
added_at=utcnow(),
)
for t in to_add
]
)
await session.flush()
except IntegrityError:
await nested.rollback()
after = set(await get_asset_tags(session, asset_info_id=asset_info_id))
return {
"added": sorted(((after - current) & want)),
"already_present": sorted(want & current),
"total_tags": sorted(after),
}
async def remove_tags_from_asset_info(
session: AsyncSession,
*,
asset_info_id: str,
tags: Sequence[str],
) -> dict:
info = await session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
norm = normalize_tags(tags)
if not norm:
total = await get_asset_tags(session, asset_info_id=asset_info_id)
return {"removed": [], "not_present": [], "total_tags": total}
existing = {
tag_name
for (tag_name,) in (
await session.execute(
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
)
).all()
}
to_remove = sorted(set(t for t in norm if t in existing))
not_present = sorted(set(t for t in norm if t not in existing))
if to_remove:
await session.execute(
delete(AssetInfoTag)
.where(
AssetInfoTag.asset_info_id == asset_info_id,
AssetInfoTag.tag_name.in_(to_remove),
)
)
await session.flush()
total = await get_asset_tags(session, asset_info_id=asset_info_id)
return {"removed": to_remove, "not_present": not_present, "total_tags": total}
async def list_tags_with_usage(
session: AsyncSession,
*,
prefix: Optional[str] = None,
limit: int = 100,
offset: int = 0,
include_zero: bool = True,
order: str = "count_desc",
owner_id: str = "",
) -> tuple[list[tuple[str, str, int]], int]:
counts_sq = (
select(
AssetInfoTag.tag_name.label("tag_name"),
func.count(AssetInfoTag.asset_info_id).label("cnt"),
)
.select_from(AssetInfoTag)
.join(AssetInfo, AssetInfo.id == AssetInfoTag.asset_info_id)
.where(visible_owner_clause(owner_id))
.group_by(AssetInfoTag.tag_name)
.subquery()
)
q = (
select(
Tag.name,
Tag.tag_type,
func.coalesce(counts_sq.c.cnt, 0).label("count"),
)
.select_from(Tag)
.join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True)
)
if prefix:
escaped, esc = escape_like_prefix(prefix.strip().lower())
q = q.where(Tag.name.like(escaped + "%", escape=esc))
if not include_zero:
q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0)
if order == "name_asc":
q = q.order_by(Tag.name.asc())
else:
q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc())
total_q = select(func.count()).select_from(Tag)
if prefix:
escaped, esc = escape_like_prefix(prefix.strip().lower())
total_q = total_q.where(Tag.name.like(escaped + "%", escape=esc))
if not include_zero:
total_q = total_q.where(
Tag.name.in_(select(AssetInfoTag.tag_name).group_by(AssetInfoTag.tag_name))
)
rows = (await session.execute(q.limit(limit).offset(offset))).all()
total = (await session.execute(total_q)).scalar_one()
rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
return rows_norm, int(total or 0)
async def get_asset_tags(session: AsyncSession, *, asset_info_id: str) -> list[str]:
return [
tag_name
for (tag_name,) in (
await session.execute(
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
)
).all()
]
async def set_asset_info_preview(
session: AsyncSession,
*,
asset_info_id: str,
preview_asset_id: Optional[str],
) -> None:
"""Set or clear preview_id and bump updated_at. Raises on unknown IDs."""
info = await session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if preview_asset_id is None:
info.preview_id = None
else:
# validate preview asset exists
if not await session.get(Asset, preview_asset_id):
raise ValueError(f"Preview Asset {preview_asset_id} not found")
info.preview_id = preview_asset_id
info.updated_at = utcnow()
await session.flush()

View File

@@ -0,0 +1,76 @@
import os
from typing import Optional, Sequence, Union
import sqlalchemy as sa
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from ..models import Asset, AssetCacheState, AssetInfo
async def asset_exists_by_hash(session: AsyncSession, *, asset_hash: str) -> bool:
row = (
await session.execute(
select(sa.literal(True)).select_from(Asset).where(Asset.hash == asset_hash).limit(1)
)
).first()
return row is not None
async def get_asset_by_hash(session: AsyncSession, *, asset_hash: str) -> Optional[Asset]:
return (
await session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
).scalars().first()
async def get_asset_info_by_id(session: AsyncSession, *, asset_info_id: str) -> Optional[AssetInfo]:
return await session.get(AssetInfo, asset_info_id)
async def asset_info_exists_for_asset_id(session: AsyncSession, *, asset_id: str) -> bool:
q = (
select(sa.literal(True))
.select_from(AssetInfo)
.where(AssetInfo.asset_id == asset_id)
.limit(1)
)
return (await session.execute(q)).first() is not None
async def get_cache_state_by_asset_id(session: AsyncSession, *, asset_id: str) -> Optional[AssetCacheState]:
return (
await session.execute(
select(AssetCacheState)
.where(AssetCacheState.asset_id == asset_id)
.order_by(AssetCacheState.id.asc())
.limit(1)
)
).scalars().first()
async def list_cache_states_by_asset_id(
session: AsyncSession, *, asset_id: str
) -> Union[list[AssetCacheState], Sequence[AssetCacheState]]:
return (
await session.execute(
select(AssetCacheState)
.where(AssetCacheState.asset_id == asset_id)
.order_by(AssetCacheState.id.asc())
)
).scalars().all()
def pick_best_live_path(states: Union[list[AssetCacheState], Sequence[AssetCacheState]]) -> str:
"""
Return the best on-disk path among cache states:
1) Prefer a path that exists with needs_verify == False (already verified).
2) Otherwise, pick the first path that exists.
3) Otherwise return empty string.
"""
alive = [s for s in states if getattr(s, "file_path", None) and os.path.isfile(s.file_path)]
if not alive:
return ""
for s in alive:
if not getattr(s, "needs_verify", False):
return s.file_path
return alive[0].file_path

View File

@@ -0,0 +1,6 @@
from datetime import datetime, timezone
def utcnow() -> datetime:
"""Naive UTC timestamp (no tzinfo). We always treat DB datetimes as UTC."""
return datetime.now(timezone.utc).replace(tzinfo=None)

556
app/assets/manager.py Normal file
View File

@@ -0,0 +1,556 @@
import contextlib
import logging
import mimetypes
import os
from typing import Optional, Sequence
from comfy_api.internal import async_to_sync
from ..db import create_session
from ._helpers import (
ensure_within_base,
get_name_and_tags_from_asset_path,
resolve_destination_from_tags,
)
from .api import schemas_in, schemas_out
from .database.models import Asset
from .database.services import (
add_tags_to_asset_info,
asset_exists_by_hash,
asset_info_exists_for_asset_id,
check_fs_asset_exists_quick,
create_asset_info_for_existing_asset,
delete_asset_info_by_id,
fetch_asset_info_and_asset,
fetch_asset_info_asset_and_tags,
get_asset_by_hash,
get_asset_info_by_id,
get_asset_tags,
ingest_fs_asset,
list_asset_infos_page,
list_cache_states_by_asset_id,
list_tags_with_usage,
pick_best_live_path,
remove_tags_from_asset_info,
set_asset_info_preview,
touch_asset_info_by_id,
touch_asset_infos_by_fs_path,
update_asset_info_full,
)
from .storage import hashing
async def asset_exists(*, asset_hash: str) -> bool:
async with await create_session() as session:
return await asset_exists_by_hash(session, asset_hash=asset_hash)
def populate_db_with_asset(file_path: str, tags: Optional[list[str]] = None) -> None:
if tags is None:
tags = []
try:
asset_name, path_tags = get_name_and_tags_from_asset_path(file_path)
async_to_sync.AsyncToSyncConverter.run_async_in_thread(
add_local_asset,
tags=list(dict.fromkeys([*path_tags, *tags])),
file_name=asset_name,
file_path=file_path,
)
except ValueError as e:
logging.warning("Skipping non-asset path %s: %s", file_path, e)
async def add_local_asset(tags: list[str], file_name: str, file_path: str) -> None:
abs_path = os.path.abspath(file_path)
size_bytes, mtime_ns = _get_size_mtime_ns(abs_path)
if not size_bytes:
return
async with await create_session() as session:
if await check_fs_asset_exists_quick(session, file_path=abs_path, size_bytes=size_bytes, mtime_ns=mtime_ns):
await touch_asset_infos_by_fs_path(session, file_path=abs_path)
await session.commit()
return
asset_hash = hashing.blake3_hash_sync(abs_path)
async with await create_session() as session:
await ingest_fs_asset(
session,
asset_hash="blake3:" + asset_hash,
abs_path=abs_path,
size_bytes=size_bytes,
mtime_ns=mtime_ns,
mime_type=None,
info_name=file_name,
tag_origin="automatic",
tags=tags,
)
await session.commit()
async def list_assets(
*,
include_tags: Optional[Sequence[str]] = None,
exclude_tags: Optional[Sequence[str]] = None,
name_contains: Optional[str] = None,
metadata_filter: Optional[dict] = None,
limit: int = 20,
offset: int = 0,
sort: str = "created_at",
order: str = "desc",
owner_id: str = "",
) -> schemas_out.AssetsList:
sort = _safe_sort_field(sort)
order = "desc" if (order or "desc").lower() not in {"asc", "desc"} else order.lower()
async with await create_session() as session:
infos, tag_map, total = await list_asset_infos_page(
session,
owner_id=owner_id,
include_tags=include_tags,
exclude_tags=exclude_tags,
name_contains=name_contains,
metadata_filter=metadata_filter,
limit=limit,
offset=offset,
sort=sort,
order=order,
)
summaries: list[schemas_out.AssetSummary] = []
for info in infos:
asset = info.asset
tags = tag_map.get(info.id, [])
summaries.append(
schemas_out.AssetSummary(
id=info.id,
name=info.name,
asset_hash=asset.hash if asset else None,
size=int(asset.size_bytes) if asset else None,
mime_type=asset.mime_type if asset else None,
tags=tags,
preview_url=f"/api/assets/{info.id}/content",
created_at=info.created_at,
updated_at=info.updated_at,
last_access_time=info.last_access_time,
)
)
return schemas_out.AssetsList(
assets=summaries,
total=total,
has_more=(offset + len(summaries)) < total,
)
async def get_asset(*, asset_info_id: str, owner_id: str = "") -> schemas_out.AssetDetail:
async with await create_session() as session:
res = await fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
if not res:
raise ValueError(f"AssetInfo {asset_info_id} not found")
info, asset, tag_names = res
preview_id = info.preview_id
return schemas_out.AssetDetail(
id=info.id,
name=info.name,
asset_hash=asset.hash if asset else None,
size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None,
mime_type=asset.mime_type if asset else None,
tags=tag_names,
user_metadata=info.user_metadata or {},
preview_id=preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
)
async def resolve_asset_content_for_download(
*,
asset_info_id: str,
owner_id: str = "",
) -> tuple[str, str, str]:
async with await create_session() as session:
pair = await fetch_asset_info_and_asset(session, asset_info_id=asset_info_id, owner_id=owner_id)
if not pair:
raise ValueError(f"AssetInfo {asset_info_id} not found")
info, asset = pair
states = await list_cache_states_by_asset_id(session, asset_id=asset.id)
abs_path = pick_best_live_path(states)
if not abs_path:
raise FileNotFoundError
await touch_asset_info_by_id(session, asset_info_id=asset_info_id)
await session.commit()
ctype = asset.mime_type or mimetypes.guess_type(info.name or abs_path)[0] or "application/octet-stream"
download_name = info.name or os.path.basename(abs_path)
return abs_path, ctype, download_name
async def upload_asset_from_temp_path(
spec: schemas_in.UploadAssetSpec,
*,
temp_path: str,
client_filename: Optional[str] = None,
owner_id: str = "",
expected_asset_hash: Optional[str] = None,
) -> schemas_out.AssetCreated:
try:
digest = await hashing.blake3_hash(temp_path)
except Exception as e:
raise RuntimeError(f"failed to hash uploaded file: {e}")
asset_hash = "blake3:" + digest
if expected_asset_hash and asset_hash != expected_asset_hash.strip().lower():
raise ValueError("HASH_MISMATCH")
async with await create_session() as session:
existing = await get_asset_by_hash(session, asset_hash=asset_hash)
if existing is not None:
with contextlib.suppress(Exception):
if temp_path and os.path.exists(temp_path):
os.remove(temp_path)
display_name = _safe_filename(spec.name or (client_filename or ""), fallback=digest)
info = await create_asset_info_for_existing_asset(
session,
asset_hash=asset_hash,
name=display_name,
user_metadata=spec.user_metadata or {},
tags=spec.tags or [],
tag_origin="manual",
owner_id=owner_id,
)
tag_names = await get_asset_tags(session, asset_info_id=info.id)
await session.commit()
return schemas_out.AssetCreated(
id=info.id,
name=info.name,
asset_hash=existing.hash,
size=int(existing.size_bytes) if existing.size_bytes is not None else None,
mime_type=existing.mime_type,
tags=tag_names,
user_metadata=info.user_metadata or {},
preview_id=info.preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
created_new=False,
)
base_dir, subdirs = resolve_destination_from_tags(spec.tags)
dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir
os.makedirs(dest_dir, exist_ok=True)
src_for_ext = (client_filename or spec.name or "").strip()
_ext = os.path.splitext(os.path.basename(src_for_ext))[1] if src_for_ext else ""
ext = _ext if 0 < len(_ext) <= 16 else ""
hashed_basename = f"{digest}{ext}"
dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename))
ensure_within_base(dest_abs, base_dir)
content_type = (
mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0]
or mimetypes.guess_type(hashed_basename, strict=False)[0]
or "application/octet-stream"
)
try:
os.replace(temp_path, dest_abs)
except Exception as e:
raise RuntimeError(f"failed to move uploaded file into place: {e}")
try:
size_bytes, mtime_ns = _get_size_mtime_ns(dest_abs)
except OSError as e:
raise RuntimeError(f"failed to stat destination file: {e}")
async with await create_session() as session:
result = await ingest_fs_asset(
session,
asset_hash=asset_hash,
abs_path=dest_abs,
size_bytes=size_bytes,
mtime_ns=mtime_ns,
mime_type=content_type,
info_name=_safe_filename(spec.name or (client_filename or ""), fallback=digest),
owner_id=owner_id,
preview_id=None,
user_metadata=spec.user_metadata or {},
tags=spec.tags,
tag_origin="manual",
require_existing_tags=False,
)
info_id = result["asset_info_id"]
if not info_id:
raise RuntimeError("failed to create asset metadata")
pair = await fetch_asset_info_and_asset(session, asset_info_id=info_id, owner_id=owner_id)
if not pair:
raise RuntimeError("inconsistent DB state after ingest")
info, asset = pair
tag_names = await get_asset_tags(session, asset_info_id=info.id)
await session.commit()
return schemas_out.AssetCreated(
id=info.id,
name=info.name,
asset_hash=asset.hash,
size=int(asset.size_bytes),
mime_type=asset.mime_type,
tags=tag_names,
user_metadata=info.user_metadata or {},
preview_id=info.preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
created_new=result["asset_created"],
)
async def update_asset(
*,
asset_info_id: str,
name: Optional[str] = None,
tags: Optional[list[str]] = None,
user_metadata: Optional[dict] = None,
owner_id: str = "",
) -> schemas_out.AssetUpdated:
async with await create_session() as session:
info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
info = await update_asset_info_full(
session,
asset_info_id=asset_info_id,
name=name,
tags=tags,
user_metadata=user_metadata,
tag_origin="manual",
asset_info_row=info_row,
)
tag_names = await get_asset_tags(session, asset_info_id=asset_info_id)
await session.commit()
return schemas_out.AssetUpdated(
id=info.id,
name=info.name,
asset_hash=info.asset.hash if info.asset else None,
tags=tag_names,
user_metadata=info.user_metadata or {},
updated_at=info.updated_at,
)
async def set_asset_preview(
*,
asset_info_id: str,
preview_asset_id: Optional[str],
owner_id: str = "",
) -> schemas_out.AssetDetail:
async with await create_session() as session:
info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
await set_asset_info_preview(
session,
asset_info_id=asset_info_id,
preview_asset_id=preview_asset_id,
)
res = await fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
if not res:
raise RuntimeError("State changed during preview update")
info, asset, tags = res
await session.commit()
return schemas_out.AssetDetail(
id=info.id,
name=info.name,
asset_hash=asset.hash if asset else None,
size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None,
mime_type=asset.mime_type if asset else None,
tags=tags,
user_metadata=info.user_metadata or {},
preview_id=info.preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
)
async def delete_asset_reference(*, asset_info_id: str, owner_id: str, delete_content_if_orphan: bool = True) -> bool:
async with await create_session() as session:
info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id)
asset_id = info_row.asset_id if info_row else None
deleted = await delete_asset_info_by_id(session, asset_info_id=asset_info_id, owner_id=owner_id)
if not deleted:
await session.commit()
return False
if not delete_content_if_orphan or not asset_id:
await session.commit()
return True
still_exists = await asset_info_exists_for_asset_id(session, asset_id=asset_id)
if still_exists:
await session.commit()
return True
states = await list_cache_states_by_asset_id(session, asset_id=asset_id)
file_paths = [s.file_path for s in (states or []) if getattr(s, "file_path", None)]
asset_row = await session.get(Asset, asset_id)
if asset_row is not None:
await session.delete(asset_row)
await session.commit()
for p in file_paths:
with contextlib.suppress(Exception):
if p and os.path.isfile(p):
os.remove(p)
return True
async def create_asset_from_hash(
*,
hash_str: str,
name: str,
tags: Optional[list[str]] = None,
user_metadata: Optional[dict] = None,
owner_id: str = "",
) -> Optional[schemas_out.AssetCreated]:
canonical = hash_str.strip().lower()
async with await create_session() as session:
asset = await get_asset_by_hash(session, asset_hash=canonical)
if not asset:
return None
info = await create_asset_info_for_existing_asset(
session,
asset_hash=canonical,
name=_safe_filename(name, fallback=canonical.split(":", 1)[1]),
user_metadata=user_metadata or {},
tags=tags or [],
tag_origin="manual",
owner_id=owner_id,
)
tag_names = await get_asset_tags(session, asset_info_id=info.id)
await session.commit()
return schemas_out.AssetCreated(
id=info.id,
name=info.name,
asset_hash=asset.hash,
size=int(asset.size_bytes),
mime_type=asset.mime_type,
tags=tag_names,
user_metadata=info.user_metadata or {},
preview_id=info.preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
created_new=False,
)
async def list_tags(
*,
prefix: Optional[str] = None,
limit: int = 100,
offset: int = 0,
order: str = "count_desc",
include_zero: bool = True,
owner_id: str = "",
) -> schemas_out.TagsList:
limit = max(1, min(1000, limit))
offset = max(0, offset)
async with await create_session() as session:
rows, total = await list_tags_with_usage(
session,
prefix=prefix,
limit=limit,
offset=offset,
include_zero=include_zero,
order=order,
owner_id=owner_id,
)
tags = [schemas_out.TagUsage(name=name, count=count, type=tag_type) for (name, tag_type, count) in rows]
return schemas_out.TagsList(tags=tags, total=total, has_more=(offset + len(tags)) < total)
async def add_tags_to_asset(
*,
asset_info_id: str,
tags: list[str],
origin: str = "manual",
owner_id: str = "",
) -> schemas_out.TagsAdd:
async with await create_session() as session:
info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
data = await add_tags_to_asset_info(
session,
asset_info_id=asset_info_id,
tags=tags,
origin=origin,
create_if_missing=True,
asset_info_row=info_row,
)
await session.commit()
return schemas_out.TagsAdd(**data)
async def remove_tags_from_asset(
*,
asset_info_id: str,
tags: list[str],
owner_id: str = "",
) -> schemas_out.TagsRemove:
async with await create_session() as session:
info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
data = await remove_tags_from_asset_info(
session,
asset_info_id=asset_info_id,
tags=tags,
)
await session.commit()
return schemas_out.TagsRemove(**data)
def _safe_sort_field(requested: Optional[str]) -> str:
if not requested:
return "created_at"
v = requested.lower()
if v in {"name", "created_at", "updated_at", "size", "last_access_time"}:
return v
return "created_at"
def _get_size_mtime_ns(path: str) -> tuple[int, int]:
st = os.stat(path, follow_symlinks=True)
return st.st_size, getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
def _safe_filename(name: Optional[str], fallback: str) -> str:
n = os.path.basename((name or "").strip() or fallback)
if n:
return n
return fallback

501
app/assets/scanner.py Normal file
View File

@@ -0,0 +1,501 @@
import asyncio
import contextlib
import logging
import os
import time
from dataclasses import dataclass, field
from typing import Literal, Optional
import sqlalchemy as sa
import folder_paths
from ..db import create_session
from ._helpers import (
collect_models_files,
compute_relative_filename,
get_comfy_models_folders,
get_name_and_tags_from_asset_path,
list_tree,
new_scan_id,
prefixes_for_root,
ts_to_iso,
)
from .api import schemas_in, schemas_out
from .database.helpers import (
add_missing_tag_for_asset_id,
ensure_tags_exist,
escape_like_prefix,
fast_asset_file_check,
remove_missing_tag_for_asset_id,
seed_from_paths_batch,
)
from .database.models import Asset, AssetCacheState, AssetInfo
from .database.services import (
compute_hash_and_dedup_for_cache_state,
list_cache_states_by_asset_id,
list_cache_states_with_asset_under_prefixes,
list_unhashed_candidates_under_prefixes,
list_verify_candidates_under_prefixes,
)
LOGGER = logging.getLogger(__name__)
SLOW_HASH_CONCURRENCY = 1
@dataclass
class ScanProgress:
scan_id: str
root: schemas_in.RootType
status: Literal["scheduled", "running", "completed", "failed", "cancelled"] = "scheduled"
scheduled_at: float = field(default_factory=lambda: time.time())
started_at: Optional[float] = None
finished_at: Optional[float] = None
discovered: int = 0
processed: int = 0
file_errors: list[dict] = field(default_factory=list)
@dataclass
class SlowQueueState:
queue: asyncio.Queue
workers: list[asyncio.Task] = field(default_factory=list)
closed: bool = False
RUNNING_TASKS: dict[schemas_in.RootType, asyncio.Task] = {}
PROGRESS_BY_ROOT: dict[schemas_in.RootType, ScanProgress] = {}
SLOW_STATE_BY_ROOT: dict[schemas_in.RootType, SlowQueueState] = {}
def current_statuses() -> schemas_out.AssetScanStatusResponse:
scans = []
for root in schemas_in.ALLOWED_ROOTS:
prog = PROGRESS_BY_ROOT.get(root)
if not prog:
continue
scans.append(_scan_progress_to_scan_status_model(prog))
return schemas_out.AssetScanStatusResponse(scans=scans)
async def schedule_scans(roots: list[schemas_in.RootType]) -> schemas_out.AssetScanStatusResponse:
results: list[ScanProgress] = []
for root in roots:
if root in RUNNING_TASKS and not RUNNING_TASKS[root].done():
results.append(PROGRESS_BY_ROOT[root])
continue
prog = ScanProgress(scan_id=new_scan_id(root), root=root, status="scheduled")
PROGRESS_BY_ROOT[root] = prog
state = SlowQueueState(queue=asyncio.Queue())
SLOW_STATE_BY_ROOT[root] = state
RUNNING_TASKS[root] = asyncio.create_task(
_run_hash_verify_pipeline(root, prog, state),
name=f"asset-scan:{root}",
)
results.append(prog)
return _status_response_for(results)
async def sync_seed_assets(roots: list[schemas_in.RootType]) -> None:
t_total = time.perf_counter()
created = 0
skipped_existing = 0
paths: list[str] = []
try:
existing_paths: set[str] = set()
for r in roots:
try:
survivors = await _fast_db_consistency_pass(r, collect_existing_paths=True, update_missing_tags=True)
if survivors:
existing_paths.update(survivors)
except Exception as ex:
LOGGER.exception("fast DB reconciliation failed for %s: %s", r, ex)
if "models" in roots:
paths.extend(collect_models_files())
if "input" in roots:
paths.extend(list_tree(folder_paths.get_input_directory()))
if "output" in roots:
paths.extend(list_tree(folder_paths.get_output_directory()))
specs: list[dict] = []
tag_pool: set[str] = set()
for p in paths:
ap = os.path.abspath(p)
if ap in existing_paths:
skipped_existing += 1
continue
try:
st = os.stat(ap, follow_symlinks=True)
except OSError:
continue
if not st.st_size:
continue
name, tags = get_name_and_tags_from_asset_path(ap)
specs.append(
{
"abs_path": ap,
"size_bytes": st.st_size,
"mtime_ns": getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000)),
"info_name": name,
"tags": tags,
"fname": compute_relative_filename(ap),
}
)
for t in tags:
tag_pool.add(t)
if not specs:
return
async with await create_session() as sess:
if tag_pool:
await ensure_tags_exist(sess, tag_pool, tag_type="user")
result = await seed_from_paths_batch(sess, specs=specs, owner_id="")
created += result["inserted_infos"]
await sess.commit()
finally:
LOGGER.info(
"Assets scan(roots=%s) completed in %.3fs (created=%d, skipped_existing=%d, total_seen=%d)",
roots,
time.perf_counter() - t_total,
created,
skipped_existing,
len(paths),
)
def _status_response_for(progresses: list[ScanProgress]) -> schemas_out.AssetScanStatusResponse:
return schemas_out.AssetScanStatusResponse(scans=[_scan_progress_to_scan_status_model(p) for p in progresses])
def _scan_progress_to_scan_status_model(progress: ScanProgress) -> schemas_out.AssetScanStatus:
return schemas_out.AssetScanStatus(
scan_id=progress.scan_id,
root=progress.root,
status=progress.status,
scheduled_at=ts_to_iso(progress.scheduled_at),
started_at=ts_to_iso(progress.started_at),
finished_at=ts_to_iso(progress.finished_at),
discovered=progress.discovered,
processed=progress.processed,
file_errors=[
schemas_out.AssetScanError(
path=e.get("path", ""),
message=e.get("message", ""),
at=e.get("at"),
)
for e in (progress.file_errors or [])
],
)
async def _run_hash_verify_pipeline(root: schemas_in.RootType, prog: ScanProgress, state: SlowQueueState) -> None:
prog.status = "running"
prog.started_at = time.time()
try:
prefixes = prefixes_for_root(root)
await _fast_db_consistency_pass(root)
# collect candidates from DB
async with await create_session() as sess:
verify_ids = await list_verify_candidates_under_prefixes(sess, prefixes=prefixes)
unhashed_ids = await list_unhashed_candidates_under_prefixes(sess, prefixes=prefixes)
# dedupe: prioritize verification first
seen = set()
ordered: list[int] = []
for lst in (verify_ids, unhashed_ids):
for sid in lst:
if sid not in seen:
seen.add(sid)
ordered.append(sid)
prog.discovered = len(ordered)
# queue up work
for sid in ordered:
await state.queue.put(sid)
state.closed = True
_start_state_workers(root, prog, state)
await _await_state_workers_then_finish(root, prog, state)
except asyncio.CancelledError:
prog.status = "cancelled"
raise
except Exception as exc:
_append_error(prog, path="", message=str(exc))
prog.status = "failed"
prog.finished_at = time.time()
LOGGER.exception("Asset scan failed for %s", root)
finally:
RUNNING_TASKS.pop(root, None)
async def _reconcile_missing_tags_for_root(root: schemas_in.RootType, prog: ScanProgress) -> None:
"""
Detect missing files quickly and toggle 'missing' tag per asset_id.
Rules:
- Only hashed assets (assets.hash != NULL) participate in missing tagging.
- We consider ALL cache states of the asset (across roots) before tagging.
"""
if root == "models":
bases: list[str] = []
for _bucket, paths in get_comfy_models_folders():
bases.extend(paths)
elif root == "input":
bases = [folder_paths.get_input_directory()]
else:
bases = [folder_paths.get_output_directory()]
try:
async with await create_session() as sess:
# state + hash + size for the current root
rows = await list_cache_states_with_asset_under_prefixes(sess, prefixes=bases)
# Track fast_ok within the scanned root and whether the asset is hashed
by_asset: dict[str, dict[str, bool]] = {}
for state, a_hash, size_db in rows:
aid = state.asset_id
acc = by_asset.get(aid)
if acc is None:
acc = {"any_fast_ok_here": False, "hashed": (a_hash is not None), "size_db": int(size_db or 0)}
by_asset[aid] = acc
try:
if acc["hashed"]:
st = os.stat(state.file_path, follow_symlinks=True)
if fast_asset_file_check(mtime_db=state.mtime_ns, size_db=acc["size_db"], stat_result=st):
acc["any_fast_ok_here"] = True
except FileNotFoundError:
pass
except OSError as e:
_append_error(prog, path=state.file_path, message=str(e))
# Decide per asset, considering ALL its states (not just this root)
for aid, acc in by_asset.items():
try:
if not acc["hashed"]:
# Never tag seed assets as missing
continue
any_fast_ok_global = acc["any_fast_ok_here"]
if not any_fast_ok_global:
# Check other states outside this root
others = await list_cache_states_by_asset_id(sess, asset_id=aid)
for st in others:
try:
any_fast_ok_global = fast_asset_file_check(
mtime_db=st.mtime_ns,
size_db=acc["size_db"],
stat_result=os.stat(st.file_path, follow_symlinks=True),
)
except OSError:
continue
if any_fast_ok_global:
await remove_missing_tag_for_asset_id(sess, asset_id=aid)
else:
await add_missing_tag_for_asset_id(sess, asset_id=aid, origin="automatic")
except Exception as ex:
_append_error(prog, path="", message=f"reconcile {aid[:8]}: {ex}")
await sess.commit()
except Exception as e:
_append_error(prog, path="", message=f"reconcile failed: {e}")
def _start_state_workers(root: schemas_in.RootType, prog: ScanProgress, state: SlowQueueState) -> None:
if state.workers:
return
async def _worker(_wid: int):
while True:
sid = await state.queue.get()
try:
if sid is None:
return
try:
async with await create_session() as sess:
# Optional: fetch path for better error messages
st = await sess.get(AssetCacheState, sid)
try:
await compute_hash_and_dedup_for_cache_state(sess, state_id=sid)
await sess.commit()
except Exception as e:
path = st.file_path if st else f"state:{sid}"
_append_error(prog, path=path, message=str(e))
raise
except Exception:
pass
finally:
prog.processed += 1
finally:
state.queue.task_done()
state.workers = [
asyncio.create_task(_worker(i), name=f"asset-hash:{root}:{i}")
for i in range(SLOW_HASH_CONCURRENCY)
]
async def _close_when_ready():
while not state.closed:
await asyncio.sleep(0.05)
for _ in range(SLOW_HASH_CONCURRENCY):
await state.queue.put(None)
asyncio.create_task(_close_when_ready())
async def _await_state_workers_then_finish(
root: schemas_in.RootType, prog: ScanProgress, state: SlowQueueState
) -> None:
if state.workers:
await asyncio.gather(*state.workers, return_exceptions=True)
await _reconcile_missing_tags_for_root(root, prog)
prog.finished_at = time.time()
prog.status = "completed"
def _append_error(prog: ScanProgress, *, path: str, message: str) -> None:
prog.file_errors.append({
"path": path,
"message": message,
"at": ts_to_iso(time.time()),
})
async def _fast_db_consistency_pass(
root: schemas_in.RootType,
*,
collect_existing_paths: bool = False,
update_missing_tags: bool = False,
) -> Optional[set[str]]:
"""Fast DB+FS pass for a root:
- Toggle needs_verify per state using fast check
- For hashed assets with at least one fast-ok state in this root: delete stale missing states
- For seed assets with all states missing: delete Asset and its AssetInfos
- Optionally add/remove 'missing' tags based on fast-ok in this root
- Optionally return surviving absolute paths
"""
prefixes = prefixes_for_root(root)
if not prefixes:
return set() if collect_existing_paths else None
conds = []
for p in prefixes:
base = os.path.abspath(p)
if not base.endswith(os.sep):
base += os.sep
escaped, esc = escape_like_prefix(base)
conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc))
async with await create_session() as sess:
rows = (
await sess.execute(
sa.select(
AssetCacheState.id,
AssetCacheState.file_path,
AssetCacheState.mtime_ns,
AssetCacheState.needs_verify,
AssetCacheState.asset_id,
Asset.hash,
Asset.size_bytes,
)
.join(Asset, Asset.id == AssetCacheState.asset_id)
.where(sa.or_(*conds))
.order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc())
)
).all()
by_asset: dict[str, dict] = {}
for sid, fp, mtime_db, needs_verify, aid, a_hash, a_size in rows:
acc = by_asset.get(aid)
if acc is None:
acc = {"hash": a_hash, "size_db": int(a_size or 0), "states": []}
by_asset[aid] = acc
fast_ok = False
try:
exists = True
fast_ok = fast_asset_file_check(
mtime_db=mtime_db,
size_db=acc["size_db"],
stat_result=os.stat(fp, follow_symlinks=True),
)
except FileNotFoundError:
exists = False
except OSError:
exists = False
acc["states"].append({
"sid": sid,
"fp": fp,
"exists": exists,
"fast_ok": fast_ok,
"needs_verify": bool(needs_verify),
})
to_set_verify: list[int] = []
to_clear_verify: list[int] = []
stale_state_ids: list[int] = []
survivors: set[str] = set()
for aid, acc in by_asset.items():
a_hash = acc["hash"]
states = acc["states"]
any_fast_ok = any(s["fast_ok"] for s in states)
all_missing = all(not s["exists"] for s in states)
for s in states:
if not s["exists"]:
continue
if s["fast_ok"] and s["needs_verify"]:
to_clear_verify.append(s["sid"])
if not s["fast_ok"] and not s["needs_verify"]:
to_set_verify.append(s["sid"])
if a_hash is None:
if states and all_missing: # remove seed Asset completely, if no valid AssetCache exists
await sess.execute(sa.delete(AssetInfo).where(AssetInfo.asset_id == aid))
asset = await sess.get(Asset, aid)
if asset:
await sess.delete(asset)
else:
for s in states:
if s["exists"]:
survivors.add(os.path.abspath(s["fp"]))
continue
if any_fast_ok: # if Asset has at least one valid AssetCache record, remove any invalid AssetCache records
for s in states:
if not s["exists"]:
stale_state_ids.append(s["sid"])
if update_missing_tags:
with contextlib.suppress(Exception):
await remove_missing_tag_for_asset_id(sess, asset_id=aid)
elif update_missing_tags:
with contextlib.suppress(Exception):
await add_missing_tag_for_asset_id(sess, asset_id=aid, origin="automatic")
for s in states:
if s["exists"]:
survivors.add(os.path.abspath(s["fp"]))
if stale_state_ids:
await sess.execute(sa.delete(AssetCacheState).where(AssetCacheState.id.in_(stale_state_ids)))
if to_set_verify:
await sess.execute(
sa.update(AssetCacheState)
.where(AssetCacheState.id.in_(to_set_verify))
.values(needs_verify=True)
)
if to_clear_verify:
await sess.execute(
sa.update(AssetCacheState)
.where(AssetCacheState.id.in_(to_clear_verify))
.values(needs_verify=False)
)
await sess.commit()
return survivors if collect_existing_paths else None

View File

@@ -0,0 +1,72 @@
import asyncio
import os
from typing import IO, Union
from blake3 import blake3
DEFAULT_CHUNK = 8 * 1024 * 1024 # 8 MiB
def _hash_file_obj_sync(file_obj: IO[bytes], chunk_size: int) -> str:
"""Hash an already-open binary file object by streaming in chunks.
- Seeks to the beginning before reading (if supported).
- Restores the original position afterward (if tell/seek are supported).
"""
if chunk_size <= 0:
chunk_size = DEFAULT_CHUNK
orig_pos = None
if hasattr(file_obj, "tell"):
orig_pos = file_obj.tell()
try:
if hasattr(file_obj, "seek"):
file_obj.seek(0)
h = blake3()
while True:
chunk = file_obj.read(chunk_size)
if not chunk:
break
h.update(chunk)
return h.hexdigest()
finally:
if hasattr(file_obj, "seek") and orig_pos is not None:
file_obj.seek(orig_pos)
def blake3_hash_sync(
fp: Union[str, bytes, os.PathLike[str], os.PathLike[bytes], IO[bytes]],
chunk_size: int = DEFAULT_CHUNK,
) -> str:
"""Returns a BLAKE3 hex digest for ``fp``, which may be:
- a filename (str/bytes) or PathLike
- an open binary file object
If ``fp`` is a file object, it must be opened in **binary** mode and support
``read``, ``seek``, and ``tell``. The function will seek to the start before
reading and will attempt to restore the original position afterward.
"""
if hasattr(fp, "read"):
return _hash_file_obj_sync(fp, chunk_size)
with open(os.fspath(fp), "rb") as f:
return _hash_file_obj_sync(f, chunk_size)
async def blake3_hash(
fp: Union[str, bytes, os.PathLike[str], os.PathLike[bytes], IO[bytes]],
chunk_size: int = DEFAULT_CHUNK,
) -> str:
"""Async wrapper for ``blake3_hash_sync``.
Uses a worker thread so the event loop remains responsive.
"""
# If it is a path, open inside the worker thread to keep I/O off the loop.
if hasattr(fp, "read"):
return await asyncio.to_thread(blake3_hash_sync, fp, chunk_size)
def _worker() -> str:
with open(os.fspath(fp), "rb") as f:
return _hash_file_obj_sync(f, chunk_size)
return await asyncio.to_thread(_worker)

View File

@@ -1,112 +0,0 @@
import logging
import os
import shutil
from app.logger import log_startup_warning
from utils.install_util import get_missing_requirements_message
from comfy.cli_args import args
_DB_AVAILABLE = False
Session = None
try:
from alembic import command
from alembic.config import Config
from alembic.runtime.migration import MigrationContext
from alembic.script import ScriptDirectory
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
_DB_AVAILABLE = True
except ImportError as e:
log_startup_warning(
f"""
------------------------------------------------------------------------
Error importing dependencies: {e}
{get_missing_requirements_message()}
This error is happening because ComfyUI now uses a local sqlite database.
------------------------------------------------------------------------
""".strip()
)
def dependencies_available():
"""
Temporary function to check if the dependencies are available
"""
return _DB_AVAILABLE
def can_create_session():
"""
Temporary function to check if the database is available to create a session
During initial release there may be environmental issues (or missing dependencies) that prevent the database from being created
"""
return dependencies_available() and Session is not None
def get_alembic_config():
root_path = os.path.join(os.path.dirname(__file__), "../..")
config_path = os.path.abspath(os.path.join(root_path, "alembic.ini"))
scripts_path = os.path.abspath(os.path.join(root_path, "alembic_db"))
config = Config(config_path)
config.set_main_option("script_location", scripts_path)
config.set_main_option("sqlalchemy.url", args.database_url)
return config
def get_db_path():
url = args.database_url
if url.startswith("sqlite:///"):
return url.split("///")[1]
else:
raise ValueError(f"Unsupported database URL '{url}'.")
def init_db():
db_url = args.database_url
logging.debug(f"Database URL: {db_url}")
db_path = get_db_path()
db_exists = os.path.exists(db_path)
config = get_alembic_config()
# Check if we need to upgrade
engine = create_engine(db_url)
conn = engine.connect()
context = MigrationContext.configure(conn)
current_rev = context.get_current_revision()
script = ScriptDirectory.from_config(config)
target_rev = script.get_current_head()
if target_rev is None:
logging.warning("No target revision found.")
elif current_rev != target_rev:
# Backup the database pre upgrade
backup_path = db_path + ".bkp"
if db_exists:
shutil.copy(db_path, backup_path)
else:
backup_path = None
try:
command.upgrade(config, target_rev)
logging.info(f"Database upgraded from {current_rev} to {target_rev}")
except Exception as e:
if backup_path:
# Restore the database from backup if upgrade fails
shutil.copy(backup_path, db_path)
os.remove(backup_path)
logging.exception("Error upgrading database: ")
raise e
global Session
Session = sessionmaker(bind=engine)
def create_session():
return Session()

View File

@@ -1,14 +0,0 @@
from sqlalchemy.orm import declarative_base
Base = declarative_base()
def to_dict(obj):
fields = obj.__table__.columns.keys()
return {
field: (val.to_dict() if hasattr(val, "to_dict") else val)
for field in fields
if (val := getattr(obj, field))
}
# TODO: Define models here

255
app/db.py Normal file
View File

@@ -0,0 +1,255 @@
import logging
import os
import shutil
from contextlib import asynccontextmanager
from typing import Optional
from alembic import command
from alembic.config import Config
from alembic.runtime.migration import MigrationContext
from alembic.script import ScriptDirectory
from sqlalchemy import create_engine, text
from sqlalchemy.engine import make_url
from sqlalchemy.ext.asyncio import (
AsyncEngine,
AsyncSession,
async_sessionmaker,
create_async_engine,
)
from comfy.cli_args import args
LOGGER = logging.getLogger(__name__)
ENGINE: Optional[AsyncEngine] = None
SESSION: Optional[async_sessionmaker] = None
def _root_paths():
"""Resolve alembic.ini and migrations script folder."""
root_path = os.path.abspath(os.path.dirname(__file__))
config_path = os.path.abspath(os.path.join(root_path, "../alembic.ini"))
scripts_path = os.path.abspath(os.path.join(root_path, "alembic_db"))
return config_path, scripts_path
def _absolutize_sqlite_url(db_url: str) -> str:
"""Make SQLite database path absolute. No-op for non-SQLite URLs."""
try:
u = make_url(db_url)
except Exception:
return db_url
if not u.drivername.startswith("sqlite"):
return db_url
db_path: str = u.database or ""
if isinstance(db_path, str) and db_path.startswith("file:"):
return str(u) # Do not touch SQLite URI databases like: "file:xxx?mode=memory&cache=shared"
if not os.path.isabs(db_path):
db_path = os.path.abspath(os.path.join(os.getcwd(), db_path))
u = u.set(database=db_path)
return str(u)
def _normalize_sqlite_memory_url(db_url: str) -> tuple[str, bool]:
"""
If db_url points at an in-memory SQLite DB (":memory:" or file:... mode=memory),
rewrite it to a *named* shared in-memory URI and ensure 'uri=true' is present.
Returns: (normalized_url, is_memory)
"""
try:
u = make_url(db_url)
except Exception:
return db_url, False
if not u.drivername.startswith("sqlite"):
return db_url, False
db = u.database or ""
if db == ":memory:":
u = u.set(database=f"file:comfyui_db_{os.getpid()}?mode=memory&cache=shared&uri=true")
return str(u), True
if isinstance(db, str) and db.startswith("file:") and "mode=memory" in db:
if "uri=true" not in db:
u = u.set(database=(db + ("&" if "?" in db else "?") + "uri=true"))
return str(u), True
return str(u), False
def _get_sqlite_file_path(sync_url: str) -> Optional[str]:
"""Return the on-disk path for a SQLite URL, else None."""
try:
u = make_url(sync_url)
except Exception:
return None
if not u.drivername.startswith("sqlite"):
return None
db_path = u.database
if isinstance(db_path, str) and db_path.startswith("file:"):
return None # Not a real file if it is a URI like "file:...?"
return db_path
def _get_alembic_config(sync_url: str) -> Config:
"""Prepare Alembic Config with script location and DB URL."""
config_path, scripts_path = _root_paths()
cfg = Config(config_path)
cfg.set_main_option("script_location", scripts_path)
cfg.set_main_option("sqlalchemy.url", sync_url)
return cfg
async def init_db_engine() -> None:
"""Initialize async engine + sessionmaker and run migrations to head.
This must be called once on application startup before any DB usage.
"""
global ENGINE, SESSION
if ENGINE is not None:
return
raw_url = args.database_url
if not raw_url:
raise RuntimeError("Database URL is not configured.")
db_url, is_mem = _normalize_sqlite_memory_url(raw_url)
db_url = _absolutize_sqlite_url(db_url)
# Prepare async engine
connect_args = {}
if db_url.startswith("sqlite"):
connect_args = {
"check_same_thread": False,
"timeout": 12,
}
if is_mem:
connect_args["uri"] = True
ENGINE = create_async_engine(
db_url,
connect_args=connect_args,
pool_pre_ping=True,
future=True,
)
# Enforce SQLite pragmas on the async engine
if db_url.startswith("sqlite"):
async with ENGINE.begin() as conn:
if not is_mem:
# WAL for concurrency and durability, Foreign Keys for referential integrity
current_mode = (await conn.execute(text("PRAGMA journal_mode;"))).scalar()
if str(current_mode).lower() != "wal":
new_mode = (await conn.execute(text("PRAGMA journal_mode=WAL;"))).scalar()
if str(new_mode).lower() != "wal":
raise RuntimeError("Failed to set SQLite journal mode to WAL.")
LOGGER.info("SQLite journal mode set to WAL.")
await conn.execute(text("PRAGMA foreign_keys = ON;"))
await conn.execute(text("PRAGMA synchronous = NORMAL;"))
await _run_migrations(database_url=db_url, connect_args=connect_args)
SESSION = async_sessionmaker(
bind=ENGINE,
class_=AsyncSession,
expire_on_commit=False,
autoflush=False,
autocommit=False,
)
async def _run_migrations(database_url: str, connect_args: dict) -> None:
if database_url.find("postgresql+psycopg") == -1:
"""SQLite: Convert an async SQLAlchemy URL to a sync URL for Alembic."""
u = make_url(database_url)
driver = u.drivername
if not driver.startswith("sqlite+aiosqlite"):
raise ValueError(f"Unsupported DB driver: {driver}")
database_url, is_mem = _normalize_sqlite_memory_url(str(u.set(drivername="sqlite")))
database_url = _absolutize_sqlite_url(database_url)
cfg = _get_alembic_config(database_url)
engine = create_engine(database_url, future=True, connect_args=connect_args)
with engine.connect() as conn:
context = MigrationContext.configure(conn)
current_rev = context.get_current_revision()
script = ScriptDirectory.from_config(cfg)
target_rev = script.get_current_head()
if target_rev is None:
LOGGER.warning("Alembic: no target revision found.")
return
if current_rev == target_rev:
LOGGER.debug("Alembic: database already at head %s", target_rev)
return
LOGGER.info("Alembic: upgrading database from %s to %s", current_rev, target_rev)
# Optional backup for SQLite file DBs
backup_path = None
sqlite_path = _get_sqlite_file_path(database_url)
if sqlite_path and os.path.exists(sqlite_path):
backup_path = sqlite_path + ".bkp"
try:
shutil.copy(sqlite_path, backup_path)
except Exception as exc:
LOGGER.warning("Failed to create SQLite backup before migration: %s", exc)
try:
command.upgrade(cfg, target_rev)
except Exception:
if backup_path and os.path.exists(backup_path):
LOGGER.exception("Error upgrading database, attempting restore from backup.")
try:
shutil.copy(backup_path, sqlite_path) # restore
os.remove(backup_path)
except Exception as re:
LOGGER.error("Failed to restore SQLite backup: %s", re)
else:
LOGGER.exception("Error upgrading database, backup is not available.")
raise
def get_engine():
"""Return the global async engine (initialized after init_db_engine())."""
if ENGINE is None:
raise RuntimeError("Engine is not initialized. Call init_db_engine() first.")
return ENGINE
def get_session_maker():
"""Return the global async_sessionmaker (initialized after init_db_engine())."""
if SESSION is None:
raise RuntimeError("Session maker is not initialized. Call init_db_engine() first.")
return SESSION
@asynccontextmanager
async def session_scope():
"""Async context manager for a unit of work:
async with session_scope() as sess:
... use sess ...
"""
maker = get_session_maker()
async with maker() as sess:
try:
yield sess
await sess.commit()
except Exception:
await sess.rollback()
raise
async def create_session():
"""Convenience helper to acquire a single AsyncSession instance.
Typical usage:
async with (await create_session()) as sess:
...
"""
maker = get_session_maker()
return maker()

View File

@@ -10,8 +10,7 @@ import importlib
from dataclasses import dataclass
from functools import cached_property
from pathlib import Path
from typing import Dict, TypedDict, Optional
from aiohttp import web
from typing import TypedDict, Optional
from importlib.metadata import version
import requests
@@ -43,7 +42,6 @@ def get_installed_frontend_version():
frontend_version_str = version("comfyui-frontend-package")
return frontend_version_str
def get_required_frontend_version():
"""Get the required frontend version from requirements.txt."""
try:
@@ -65,7 +63,6 @@ def get_required_frontend_version():
logging.error(f"Error reading requirements.txt: {e}")
return None
def check_frontend_version():
"""Check if the frontend version is up to date."""
@@ -199,6 +196,17 @@ def download_release_asset_zip(release: Release, destination_path: str) -> None:
class FrontendManager:
"""
A class to manage ComfyUI frontend versions and installations.
This class handles the initialization and management of different frontend versions,
including the default frontend from the pip package and custom frontend versions
from GitHub repositories.
Attributes:
CUSTOM_FRONTENDS_ROOT (str): The root directory where custom frontend versions are stored.
"""
CUSTOM_FRONTENDS_ROOT = str(Path(__file__).parents[1] / "web_custom_versions")
@classmethod
@@ -206,39 +214,17 @@ class FrontendManager:
"""Get the required frontend package version."""
return get_required_frontend_version()
@classmethod
def get_installed_templates_version(cls) -> str:
"""Get the currently installed workflow templates package version."""
try:
templates_version_str = version("comfyui-workflow-templates")
return templates_version_str
except Exception:
return None
@classmethod
def get_required_templates_version(cls) -> str:
"""Get the required workflow templates version from requirements.txt."""
try:
with open(requirements_path, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if line.startswith("comfyui-workflow-templates=="):
version_str = line.split("==")[-1]
if not is_valid_version(version_str):
logging.error(f"Invalid templates version format in requirements.txt: {version_str}")
return None
return version_str
logging.error("comfyui-workflow-templates not found in requirements.txt")
return None
except FileNotFoundError:
logging.error("requirements.txt not found. Cannot determine required templates version.")
return None
except Exception as e:
logging.error(f"Error reading requirements.txt: {e}")
return None
@classmethod
def default_frontend_path(cls) -> str:
"""
Get the path to the default frontend installation from the pip package.
Returns:
str: The path to the default frontend static files.
Raises:
SystemExit: If the comfyui-frontend-package is not installed.
"""
try:
import comfyui_frontend_package
@@ -258,54 +244,16 @@ comfyui-frontend-package is not installed.
sys.exit(-1)
@classmethod
def template_asset_map(cls) -> Optional[Dict[str, str]]:
"""Return a mapping of template asset names to their absolute paths."""
try:
from comfyui_workflow_templates import (
get_asset_path,
iter_templates,
)
except ImportError:
logging.error(
f"""
********** ERROR ***********
def templates_path(cls) -> str:
"""
Get the path to the workflow templates.
comfyui-workflow-templates is not installed.
Returns:
str: The path to the workflow templates directory.
{frontend_install_warning_message()}
********** ERROR ***********
""".strip()
)
return None
try:
template_entries = list(iter_templates())
except Exception as exc:
logging.error(f"Failed to enumerate workflow templates: {exc}")
return None
asset_map: Dict[str, str] = {}
try:
for entry in template_entries:
for asset in entry.assets:
asset_map[asset.filename] = get_asset_path(
entry.template_id, asset.filename
)
except Exception as exc:
logging.error(f"Failed to resolve template asset paths: {exc}")
return None
if not asset_map:
logging.error("No workflow template assets found. Did the packages install correctly?")
return None
return asset_map
@classmethod
def legacy_templates_path(cls) -> Optional[str]:
"""Return the legacy templates directory shipped inside the meta package."""
Raises:
SystemExit: If the comfyui-workflow-templates package is not installed.
"""
try:
import comfyui_workflow_templates
@@ -324,7 +272,6 @@ comfyui-workflow-templates is not installed.
********** ERROR ***********
""".strip()
)
return None
@classmethod
def embedded_docs_path(cls) -> str:
@@ -342,11 +289,16 @@ comfyui-workflow-templates is not installed.
@classmethod
def parse_version_string(cls, value: str) -> tuple[str, str, str]:
"""
Parse a version string into its components.
The version string should be in the format: 'owner/repo@version'
where version can be either a semantic version (v1.2.3) or 'latest'.
Args:
value (str): The version string to parse.
Returns:
tuple[str, str]: A tuple containing provider name and version.
tuple[str, str, str]: A tuple containing (owner, repo, version).
Raises:
argparse.ArgumentTypeError: If the version string is invalid.
@@ -363,18 +315,22 @@ comfyui-workflow-templates is not installed.
cls, version_string: str, provider: Optional[FrontEndProvider] = None
) -> str:
"""
Initializes the frontend for the specified version.
Initialize a frontend version without error handling.
This method attempts to initialize a specific frontend version, either from
the default pip package or from a custom GitHub repository. It will download
and extract the frontend files if necessary.
Args:
version_string (str): The version string.
provider (FrontEndProvider, optional): The provider to use. Defaults to None.
version_string (str): The version string specifying which frontend to use.
provider (FrontEndProvider, optional): The provider to use for custom frontends.
Returns:
str: The path to the initialized frontend.
Raises:
Exception: If there is an error during the initialization process.
main error source might be request timeout or invalid URL.
Exception: If there is an error during initialization (e.g., network timeout,
invalid URL, or missing assets).
"""
if version_string == DEFAULT_VERSION_STRING:
check_frontend_version()
@@ -426,13 +382,17 @@ comfyui-workflow-templates is not installed.
@classmethod
def init_frontend(cls, version_string: str) -> str:
"""
Initializes the frontend with the specified version string.
Initialize a frontend version with error handling.
This is the main method to initialize a frontend version. It wraps init_frontend_unsafe
with error handling, falling back to the default frontend if initialization fails.
Args:
version_string (str): The version string to initialize the frontend with.
version_string (str): The version string specifying which frontend to use.
Returns:
str: The path of the initialized frontend.
str: The path to the initialized frontend. If initialization fails,
returns the path to the default frontend.
"""
try:
return cls.init_frontend_unsafe(version_string)
@@ -441,17 +401,3 @@ comfyui-workflow-templates is not installed.
logging.info("Falling back to the default frontend.")
check_frontend_version()
return cls.default_frontend_path()
@classmethod
def template_asset_handler(cls):
assets = cls.template_asset_map()
if not assets:
return None
async def serve_template(request: web.Request) -> web.StreamResponse:
rel_path = request.match_info.get("path", "")
target = assets.get(rel_path)
if target is None:
raise web.HTTPNotFound()
return web.FileResponse(target)
return serve_template

View File

@@ -1,112 +0,0 @@
from __future__ import annotations
from typing import TypedDict
import os
import folder_paths
import glob
from aiohttp import web
import hashlib
class Source:
custom_node = "custom_node"
class SubgraphEntry(TypedDict):
source: str
"""
Source of subgraph - custom_nodes vs templates.
"""
path: str
"""
Relative path of the subgraph file.
For custom nodes, will be the relative directory like <custom_node_dir>/subgraphs/<name>.json
"""
name: str
"""
Name of subgraph file.
"""
info: CustomNodeSubgraphEntryInfo
"""
Additional info about subgraph; in the case of custom_nodes, will contain nodepack name
"""
data: str
class CustomNodeSubgraphEntryInfo(TypedDict):
node_pack: str
"""Node pack name."""
class SubgraphManager:
def __init__(self):
self.cached_custom_node_subgraphs: dict[SubgraphEntry] | None = None
async def load_entry_data(self, entry: SubgraphEntry):
with open(entry['path'], 'r') as f:
entry['data'] = f.read()
return entry
async def sanitize_entry(self, entry: SubgraphEntry | None, remove_data=False) -> SubgraphEntry | None:
if entry is None:
return None
entry = entry.copy()
entry.pop('path', None)
if remove_data:
entry.pop('data', None)
return entry
async def sanitize_entries(self, entries: dict[str, SubgraphEntry], remove_data=False) -> dict[str, SubgraphEntry]:
entries = entries.copy()
for key in list(entries.keys()):
entries[key] = await self.sanitize_entry(entries[key], remove_data)
return entries
async def get_custom_node_subgraphs(self, loadedModules, force_reload=False):
# if not forced to reload and cached, return cache
if not force_reload and self.cached_custom_node_subgraphs is not None:
return self.cached_custom_node_subgraphs
# Load subgraphs from custom nodes
subfolder = "subgraphs"
subgraphs_dict: dict[SubgraphEntry] = {}
for folder in folder_paths.get_folder_paths("custom_nodes"):
pattern = os.path.join(folder, f"*/{subfolder}/*.json")
matched_files = glob.glob(pattern)
for file in matched_files:
# replace backslashes with forward slashes
file = file.replace('\\', '/')
info: CustomNodeSubgraphEntryInfo = {
"node_pack": "custom_nodes." + file.split('/')[-3]
}
source = Source.custom_node
# hash source + path to make sure id will be as unique as possible, but
# reproducible across backend reloads
id = hashlib.sha256(f"{source}{file}".encode()).hexdigest()
entry: SubgraphEntry = {
"source": Source.custom_node,
"name": os.path.splitext(os.path.basename(file))[0],
"path": file,
"info": info,
}
subgraphs_dict[id] = entry
self.cached_custom_node_subgraphs = subgraphs_dict
return subgraphs_dict
async def get_custom_node_subgraph(self, id: str, loadedModules):
subgraphs = await self.get_custom_node_subgraphs(loadedModules)
entry: SubgraphEntry = subgraphs.get(id, None)
if entry is not None and entry.get('data', None) is None:
await self.load_entry_data(entry)
return entry
def add_routes(self, routes, loadedModules):
@routes.get("/global_subgraphs")
async def get_global_subgraphs(request):
subgraphs_dict = await self.get_custom_node_subgraphs(loadedModules)
# NOTE: we may want to include other sources of global subgraphs such as templates in the future;
# that's the reasoning for the current implementation
return web.json_response(await self.sanitize_entries(subgraphs_dict, remove_data=True))
@routes.get("/global_subgraphs/{id}")
async def get_global_subgraph(request):
id = request.match_info.get("id", None)
subgraph = await self.get_custom_node_subgraph(id, loadedModules)
return web.json_response(await self.sanitize_entry(subgraph))

View File

@@ -105,7 +105,6 @@ cache_group = parser.add_mutually_exclusive_group()
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
cache_group.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threhold the cache remove large items to free RAM. Default 4GB")
attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
@@ -146,9 +145,7 @@ class PerformanceFeature(enum.Enum):
CublasOps = "cublas_ops"
AutoTune = "autotune"
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
parser.add_argument("--disable-pinned-memory", action="store_true", help="Disable pinned memory use.")
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.")
parser.add_argument("--disable-mmap", action="store_true", help="Don't use mmap when loading safetensors.")
@@ -215,7 +212,8 @@ parser.add_argument(
database_default_path = os.path.abspath(
os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db")
)
parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
parser.add_argument("--database-url", type=str, default=f"sqlite+aiosqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite+aiosqlite:///:memory:'.")
parser.add_argument("--disable-assets-autoscan", action="store_true", help="Disable asset scanning on startup for database synchronization.")
if comfy.options.args_parsing:
args = parser.parse_args()

View File

@@ -310,13 +310,11 @@ class ControlLoraOps:
self.bias = None
def forward(self, input):
weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True)
weight, bias = comfy.ops.cast_bias_weight(self, input)
if self.up is not None:
x = torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
return torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
else:
x = torch.nn.functional.linear(input, weight, bias)
comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
return x
return torch.nn.functional.linear(input, weight, bias)
class Conv2d(torch.nn.Module, comfy.ops.CastWeightBiasOp):
def __init__(
@@ -352,13 +350,12 @@ class ControlLoraOps:
def forward(self, input):
weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True)
weight, bias = comfy.ops.cast_bias_weight(self, input)
if self.up is not None:
x = torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
return torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
else:
x = torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
return x
return torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
class ControlLora(ControlNet):
def __init__(self, control_weights, global_average_pooling=False, model_options={}): #TODO? model_options

View File

@@ -611,66 +611,6 @@ class HunyuanImage21Refiner(LatentFormat):
latent_dimensions = 3
scale_factor = 1.03682
def process_in(self, latent):
out = latent * self.scale_factor
out = torch.cat((out[:, :, :1], out), dim=2)
out = out.permute(0, 2, 1, 3, 4)
b, f_times_2, c, h, w = out.shape
out = out.reshape(b, f_times_2 // 2, 2 * c, h, w)
out = out.permute(0, 2, 1, 3, 4).contiguous()
return out
def process_out(self, latent):
z = latent / self.scale_factor
z = z.permute(0, 2, 1, 3, 4)
b, f, c, h, w = z.shape
z = z.reshape(b, f, 2, c // 2, h, w)
z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w)
z = z.permute(0, 2, 1, 3, 4)
z = z[:, :, 1:]
return z
class HunyuanVideo15(LatentFormat):
latent_rgb_factors = [
[ 0.0568, -0.0521, -0.0131],
[ 0.0014, 0.0735, 0.0326],
[ 0.0186, 0.0531, -0.0138],
[-0.0031, 0.0051, 0.0288],
[ 0.0110, 0.0556, 0.0432],
[-0.0041, -0.0023, -0.0485],
[ 0.0530, 0.0413, 0.0253],
[ 0.0283, 0.0251, 0.0339],
[ 0.0277, -0.0372, -0.0093],
[ 0.0393, 0.0944, 0.1131],
[ 0.0020, 0.0251, 0.0037],
[-0.0017, 0.0012, 0.0234],
[ 0.0468, 0.0436, 0.0203],
[ 0.0354, 0.0439, -0.0233],
[ 0.0090, 0.0123, 0.0346],
[ 0.0382, 0.0029, 0.0217],
[ 0.0261, -0.0300, 0.0030],
[-0.0088, -0.0220, -0.0283],
[-0.0272, -0.0121, -0.0363],
[-0.0664, -0.0622, 0.0144],
[ 0.0414, 0.0479, 0.0529],
[ 0.0355, 0.0612, -0.0247],
[ 0.0147, 0.0264, 0.0174],
[ 0.0438, 0.0038, 0.0542],
[ 0.0431, -0.0573, -0.0033],
[-0.0162, -0.0211, -0.0406],
[-0.0487, -0.0295, -0.0393],
[ 0.0005, -0.0109, 0.0253],
[ 0.0296, 0.0591, 0.0353],
[ 0.0119, 0.0181, -0.0306],
[-0.0085, -0.0362, 0.0229],
[ 0.0005, -0.0106, 0.0242]
]
latent_rgb_factors_bias = [ 0.0456, -0.0202, -0.0644]
latent_channels = 32
latent_dimensions = 3
scale_factor = 1.03682
class Hunyuan3Dv2(LatentFormat):
latent_channels = 64
latent_dimensions = 1

View File

@@ -23,6 +23,8 @@ class MusicDCAE(torch.nn.Module):
else:
self.source_sample_rate = source_sample_rate
# self.resampler = torchaudio.transforms.Resample(source_sample_rate, 44100)
self.transform = transforms.Compose([
transforms.Normalize(0.5, 0.5),
])
@@ -35,6 +37,10 @@ class MusicDCAE(torch.nn.Module):
self.scale_factor = 0.1786
self.shift_factor = -1.9091
def load_audio(self, audio_path):
audio, sr = torchaudio.load(audio_path)
return audio, sr
def forward_mel(self, audios):
mels = []
for i in range(len(audios)):
@@ -67,8 +73,10 @@ class MusicDCAE(torch.nn.Module):
latent = self.dcae.encoder(mel.unsqueeze(0))
latents.append(latent)
latents = torch.cat(latents, dim=0)
# latent_lengths = (audio_lengths / sr * 44100 / 512 / self.time_dimention_multiple).long()
latents = (latents - self.shift_factor) * self.scale_factor
return latents
# return latents, latent_lengths
@torch.no_grad()
def decode(self, latents, audio_lengths=None, sr=None):
@@ -83,7 +91,9 @@ class MusicDCAE(torch.nn.Module):
wav = self.vocoder.decode(mels[0]).squeeze(1)
if sr is not None:
# resampler = torchaudio.transforms.Resample(44100, sr).to(latents.device).to(latents.dtype)
wav = torchaudio.functional.resample(wav, 44100, sr)
# wav = resampler(wav)
else:
sr = 44100
pred_wavs.append(wav)
@@ -91,6 +101,7 @@ class MusicDCAE(torch.nn.Module):
if audio_lengths is not None:
pred_wavs = [wav[:, :length].cpu() for wav, length in zip(pred_wavs, audio_lengths)]
return torch.stack(pred_wavs)
# return sr, pred_wavs
def forward(self, audios, audio_lengths=None, sr=None):
latents, latent_lengths = self.encode(audios=audios, audio_lengths=audio_lengths, sr=sr)

View File

@@ -1,15 +1,15 @@
import torch
from torch import Tensor, nn
from comfy.ldm.flux.math import attention
from comfy.ldm.flux.layers import (
MLPEmbedder,
RMSNorm,
QKNorm,
SelfAttention,
ModulationOut,
)
# TODO: remove this in a few months
SingleStreamBlock = None
DoubleStreamBlock = None
class ChromaModulationOut(ModulationOut):
@@ -48,6 +48,124 @@ class Approximator(nn.Module):
return x
class DoubleStreamBlock(nn.Module):
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
super().__init__()
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.num_heads = num_heads
self.hidden_size = hidden_size
self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.img_mlp = nn.Sequential(
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
nn.GELU(approximate="tanh"),
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.txt_mlp = nn.Sequential(
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
nn.GELU(approximate="tanh"),
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
self.flipped_img_txt = flipped_img_txt
def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None, transformer_options={}):
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
# prepare image for attention
img_modulated = torch.addcmul(img_mod1.shift, 1 + img_mod1.scale, self.img_norm1(img))
img_qkv = self.img_attn.qkv(img_modulated)
img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
# prepare txt for attention
txt_modulated = torch.addcmul(txt_mod1.shift, 1 + txt_mod1.scale, self.txt_norm1(txt))
txt_qkv = self.txt_attn.qkv(txt_modulated)
txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
# run actual attention
attn = attention(torch.cat((txt_q, img_q), dim=2),
torch.cat((txt_k, img_k), dim=2),
torch.cat((txt_v, img_v), dim=2),
pe=pe, mask=attn_mask, transformer_options=transformer_options)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
# calculate the img bloks
img.addcmul_(img_mod1.gate, self.img_attn.proj(img_attn))
img.addcmul_(img_mod2.gate, self.img_mlp(torch.addcmul(img_mod2.shift, 1 + img_mod2.scale, self.img_norm2(img))))
# calculate the txt bloks
txt.addcmul_(txt_mod1.gate, self.txt_attn.proj(txt_attn))
txt.addcmul_(txt_mod2.gate, self.txt_mlp(torch.addcmul(txt_mod2.shift, 1 + txt_mod2.scale, self.txt_norm2(txt))))
if txt.dtype == torch.float16:
txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
return img, txt
class SingleStreamBlock(nn.Module):
"""
A DiT block with parallel linear layers as described in
https://arxiv.org/abs/2302.05442 and adapted modulation interface.
"""
def __init__(
self,
hidden_size: int,
num_heads: int,
mlp_ratio: float = 4.0,
qk_scale: float = None,
dtype=None,
device=None,
operations=None
):
super().__init__()
self.hidden_dim = hidden_size
self.num_heads = num_heads
head_dim = hidden_size // num_heads
self.scale = qk_scale or head_dim**-0.5
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
# qkv and mlp_in
self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device)
# proj and mlp_out
self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device)
self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
self.hidden_size = hidden_size
self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.mlp_act = nn.GELU(approximate="tanh")
def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None, transformer_options={}) -> Tensor:
mod = vec
x_mod = torch.addcmul(mod.shift, 1 + mod.scale, self.pre_norm(x))
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k = self.norm(q, k, v)
# compute attention
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
# compute activation in mlp stream, cat again and run second linear layer
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
x.addcmul_(mod.gate, output)
if x.dtype == torch.float16:
x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
return x
class LastLayer(nn.Module):
def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None):
super().__init__()

View File

@@ -11,12 +11,12 @@ import comfy.ldm.common_dit
from comfy.ldm.flux.layers import (
EmbedND,
timestep_embedding,
DoubleStreamBlock,
SingleStreamBlock,
)
from .layers import (
DoubleStreamBlock,
LastLayer,
SingleStreamBlock,
Approximator,
ChromaModulationOut,
)
@@ -90,7 +90,6 @@ class Chroma(nn.Module):
self.num_heads,
mlp_ratio=params.mlp_ratio,
qkv_bias=params.qkv_bias,
modulation=False,
dtype=dtype, device=device, operations=operations
)
for _ in range(params.depth)
@@ -99,7 +98,7 @@ class Chroma(nn.Module):
self.single_blocks = nn.ModuleList(
[
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=False, dtype=dtype, device=device, operations=operations)
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations)
for _ in range(params.depth_single_blocks)
]
)

View File

@@ -10,10 +10,12 @@ from torch import Tensor, nn
from einops import repeat
import comfy.ldm.common_dit
from comfy.ldm.flux.layers import EmbedND, DoubleStreamBlock, SingleStreamBlock
from comfy.ldm.flux.layers import EmbedND
from comfy.ldm.chroma.model import Chroma, ChromaParams
from comfy.ldm.chroma.layers import (
DoubleStreamBlock,
SingleStreamBlock,
Approximator,
)
from .layers import (
@@ -87,6 +89,7 @@ class ChromaRadiance(Chroma):
dtype=dtype, device=device, operations=operations
)
self.double_blocks = nn.ModuleList(
[
DoubleStreamBlock(
@@ -94,7 +97,6 @@ class ChromaRadiance(Chroma):
self.num_heads,
mlp_ratio=params.mlp_ratio,
qkv_bias=params.qkv_bias,
modulation=False,
dtype=dtype, device=device, operations=operations
)
for _ in range(params.depth)
@@ -107,7 +109,6 @@ class ChromaRadiance(Chroma):
self.hidden_size,
self.num_heads,
mlp_ratio=params.mlp_ratio,
modulation=False,
dtype=dtype, device=device, operations=operations,
)
for _ in range(params.depth_single_blocks)
@@ -188,15 +189,15 @@ class ChromaRadiance(Chroma):
nerf_pixels = nn.functional.unfold(img_orig, kernel_size=patch_size, stride=patch_size)
nerf_pixels = nerf_pixels.transpose(1, 2) # -> [B, NumPatches, C * P * P]
# Reshape for per-patch processing
nerf_hidden = img_out.reshape(B * num_patches, params.hidden_size)
nerf_pixels = nerf_pixels.reshape(B * num_patches, C, patch_size**2).transpose(1, 2)
if params.nerf_tile_size > 0 and num_patches > params.nerf_tile_size:
# Enable tiling if nerf_tile_size isn't 0 and we actually have more patches than
# the tile size.
img_dct = self.forward_tiled_nerf(nerf_hidden, nerf_pixels, B, C, num_patches, patch_size, params)
img_dct = self.forward_tiled_nerf(img_out, nerf_pixels, B, C, num_patches, patch_size, params)
else:
# Reshape for per-patch processing
nerf_hidden = img_out.reshape(B * num_patches, params.hidden_size)
nerf_pixels = nerf_pixels.reshape(B * num_patches, C, patch_size**2).transpose(1, 2)
# Get DCT-encoded pixel embeddings [pixel-dct]
img_dct = self.nerf_image_embedder(nerf_pixels)
@@ -239,8 +240,17 @@ class ChromaRadiance(Chroma):
end = min(i + tile_size, num_patches)
# Slice the current tile from the input tensors
nerf_hidden_tile = nerf_hidden[i * batch:end * batch]
nerf_pixels_tile = nerf_pixels[i * batch:end * batch]
nerf_hidden_tile = nerf_hidden[:, i:end, :]
nerf_pixels_tile = nerf_pixels[:, i:end, :]
# Get the actual number of patches in this tile (can be smaller for the last tile)
num_patches_tile = nerf_hidden_tile.shape[1]
# Reshape the tile for per-patch processing
# [B, NumPatches_tile, D] -> [B * NumPatches_tile, D]
nerf_hidden_tile = nerf_hidden_tile.reshape(batch * num_patches_tile, params.hidden_size)
# [B, NumPatches_tile, C*P*P] -> [B*NumPatches_tile, C, P*P] -> [B*NumPatches_tile, P*P, C]
nerf_pixels_tile = nerf_pixels_tile.reshape(batch * num_patches_tile, channels, patch_size**2).transpose(1, 2)
# get DCT-encoded pixel embeddings [pixel-dct]
img_dct_tile = self.nerf_image_embedder(nerf_pixels_tile)

View File

@@ -130,17 +130,13 @@ def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None):
class DoubleStreamBlock(nn.Module):
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, dtype=None, device=None, operations=None):
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
super().__init__()
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.num_heads = num_heads
self.hidden_size = hidden_size
self.modulation = modulation
if self.modulation:
self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
@@ -151,9 +147,7 @@ class DoubleStreamBlock(nn.Module):
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
if self.modulation:
self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
@@ -166,65 +160,46 @@ class DoubleStreamBlock(nn.Module):
self.flipped_img_txt = flipped_img_txt
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}):
if self.modulation:
img_mod1, img_mod2 = self.img_mod(vec)
txt_mod1, txt_mod2 = self.txt_mod(vec)
else:
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
img_mod1, img_mod2 = self.img_mod(vec)
txt_mod1, txt_mod2 = self.txt_mod(vec)
# prepare image for attention
img_modulated = self.img_norm1(img)
img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img)
img_qkv = self.img_attn.qkv(img_modulated)
del img_modulated
img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
del img_qkv
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
# prepare txt for attention
txt_modulated = self.txt_norm1(txt)
txt_modulated = apply_mod(txt_modulated, (1 + txt_mod1.scale), txt_mod1.shift, modulation_dims_txt)
txt_qkv = self.txt_attn.qkv(txt_modulated)
del txt_modulated
txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
del txt_qkv
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
if self.flipped_img_txt:
q = torch.cat((img_q, txt_q), dim=2)
del img_q, txt_q
k = torch.cat((img_k, txt_k), dim=2)
del img_k, txt_k
v = torch.cat((img_v, txt_v), dim=2)
del img_v, txt_v
# run actual attention
attn = attention(q, k, v,
attn = attention(torch.cat((img_q, txt_q), dim=2),
torch.cat((img_k, txt_k), dim=2),
torch.cat((img_v, txt_v), dim=2),
pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
else:
q = torch.cat((txt_q, img_q), dim=2)
del txt_q, img_q
k = torch.cat((txt_k, img_k), dim=2)
del txt_k, img_k
v = torch.cat((txt_v, img_v), dim=2)
del txt_v, img_v
# run actual attention
attn = attention(q, k, v,
attn = attention(torch.cat((txt_q, img_q), dim=2),
torch.cat((txt_k, img_k), dim=2),
torch.cat((txt_v, img_v), dim=2),
pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
# calculate the img bloks
img += apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
del img_attn
img += apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)
img = img + apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
img = img + apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)
# calculate the txt bloks
txt += apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims_txt)
del txt_attn
txt += apply_mod(self.txt_mlp(apply_mod(self.txt_norm2(txt), (1 + txt_mod2.scale), txt_mod2.shift, modulation_dims_txt)), txt_mod2.gate, None, modulation_dims_txt)
if txt.dtype == torch.float16:
@@ -245,7 +220,6 @@ class SingleStreamBlock(nn.Module):
num_heads: int,
mlp_ratio: float = 4.0,
qk_scale: float = None,
modulation=True,
dtype=None,
device=None,
operations=None
@@ -268,29 +242,19 @@ class SingleStreamBlock(nn.Module):
self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.mlp_act = nn.GELU(approximate="tanh")
if modulation:
self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
else:
self.modulation = None
self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None, transformer_options={}) -> Tensor:
if self.modulation:
mod, _ = self.modulation(vec)
else:
mod = vec
mod, _ = self.modulation(vec)
qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
del qkv
q, k = self.norm(q, k, v)
# compute attention
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v
# compute activation in mlp stream, cat again and run second linear layer
mlp = self.mlp_act(mlp)
output = self.linear2(torch.cat((attn, mlp), 2))
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
x += apply_mod(output, mod.gate, None, modulation_dims)
if x.dtype == torch.float16:
x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)

View File

@@ -7,8 +7,15 @@ import comfy.model_management
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor:
q_shape = q.shape
k_shape = k.shape
if pe is not None:
q, k = apply_rope(q, k, pe)
q = q.to(dtype=pe.dtype).reshape(*q.shape[:-1], -1, 1, 2)
k = k.to(dtype=pe.dtype).reshape(*k.shape[:-1], -1, 1, 2)
q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v)
k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)
heads = q.shape[1]
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options)
return x
@@ -30,10 +37,7 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
def apply_rope1(x: Tensor, freqs_cis: Tensor):
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
x_out = freqs_cis[..., 0] * x_[..., 0]
x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
x_out = freqs_cis[..., 0] * x_[..., 0] + freqs_cis[..., 1] * x_[..., 1]
return x_out.reshape(*x.shape).type_as(x)
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):

View File

@@ -210,7 +210,7 @@ class Flux(nn.Module):
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
return img
def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}):
def process_img(self, x, index=0, h_offset=0, w_offset=0):
bs, c, h, w = x.shape
patch_size = self.patch_size
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
@@ -222,22 +222,10 @@ class Flux(nn.Module):
h_offset = ((h_offset + (patch_size // 2)) // patch_size)
w_offset = ((w_offset + (patch_size // 2)) // patch_size)
steps_h = h_len
steps_w = w_len
rope_options = transformer_options.get("rope_options", None)
if rope_options is not None:
h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
index += rope_options.get("shift_t", 0.0)
h_offset += rope_options.get("shift_y", 0.0)
w_offset += rope_options.get("shift_x", 0.0)
img_ids = torch.zeros((steps_h, steps_w, 3), device=x.device, dtype=x.dtype)
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[:, :, 0] = img_ids[:, :, 1] + index
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=steps_h, device=x.device, dtype=x.dtype).unsqueeze(1)
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=steps_w, device=x.device, dtype=x.dtype).unsqueeze(0)
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
return img, repeat(img_ids, "h w c -> b (h w) c", b=bs)
def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
@@ -253,7 +241,7 @@ class Flux(nn.Module):
h_len = ((h_orig + (patch_size // 2)) // patch_size)
w_len = ((w_orig + (patch_size // 2)) // patch_size)
img, img_ids = self.process_img(x, transformer_options=transformer_options)
img, img_ids = self.process_img(x)
img_tokens = img.shape[1]
if ref_latents is not None:
h = 0

View File

@@ -6,6 +6,7 @@ import comfy.ldm.flux.layers
import comfy.ldm.modules.diffusionmodules.mmdit
from comfy.ldm.modules.attention import optimized_attention
from dataclasses import dataclass
from einops import repeat
@@ -41,8 +42,6 @@ class HunyuanVideoParams:
guidance_embed: bool
byt5: bool
meanflow: bool
use_cond_type_embedding: bool
vision_in_dim: int
class SelfAttentionRef(nn.Module):
@@ -158,10 +157,7 @@ class TokenRefiner(nn.Module):
t = self.t_embedder(timestep_embedding(timesteps, 256, time_factor=1.0).to(x.dtype))
# m = mask.float().unsqueeze(-1)
# c = (x.float() * m).sum(dim=1) / m.sum(dim=1) #TODO: the following works when the x.shape is the same length as the tokens but might break otherwise
if x.dtype == torch.float16:
c = x.float().sum(dim=1) / x.shape[1]
else:
c = x.sum(dim=1) / x.shape[1]
c = x.sum(dim=1) / x.shape[1]
c = t + self.c_embedder(c.to(x.dtype))
x = self.input_embedder(x)
@@ -200,15 +196,11 @@ class HunyuanVideo(nn.Module):
def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
super().__init__()
self.dtype = dtype
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
params = HunyuanVideoParams(**kwargs)
self.params = params
self.patch_size = params.patch_size
self.in_channels = params.in_channels
self.out_channels = params.out_channels
self.use_cond_type_embedding = params.use_cond_type_embedding
self.vision_in_dim = params.vision_in_dim
if params.hidden_size % params.num_heads != 0:
raise ValueError(
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
@@ -274,18 +266,6 @@ class HunyuanVideo(nn.Module):
if final_layer:
self.final_layer = LastLayer(self.hidden_size, self.patch_size[-1], self.out_channels, dtype=dtype, device=device, operations=operations)
# HunyuanVideo 1.5 specific modules
if self.vision_in_dim is not None:
from comfy.ldm.wan.model import MLPProj
self.vision_in = MLPProj(in_dim=self.vision_in_dim, out_dim=self.hidden_size, operation_settings=operation_settings)
else:
self.vision_in = None
if self.use_cond_type_embedding:
# 0: text_encoder feature 1: byt5 feature 2: vision_encoder feature
self.cond_type_embedding = nn.Embedding(3, self.hidden_size)
else:
self.cond_type_embedding = None
def forward_orig(
self,
img: Tensor,
@@ -296,7 +276,6 @@ class HunyuanVideo(nn.Module):
timesteps: Tensor,
y: Tensor = None,
txt_byt5=None,
clip_fea=None,
guidance: Tensor = None,
guiding_frame_index=None,
ref_latent=None,
@@ -352,31 +331,12 @@ class HunyuanVideo(nn.Module):
txt = self.txt_in(txt, timesteps, txt_mask, transformer_options=transformer_options)
if self.cond_type_embedding is not None:
self.cond_type_embedding.to(txt.device)
cond_emb = self.cond_type_embedding(torch.zeros_like(txt[:, :, 0], device=txt.device, dtype=torch.long))
txt = txt + cond_emb.to(txt.dtype)
if self.byt5_in is not None and txt_byt5 is not None:
txt_byt5 = self.byt5_in(txt_byt5)
if self.cond_type_embedding is not None:
cond_emb = self.cond_type_embedding(torch.ones_like(txt_byt5[:, :, 0], device=txt_byt5.device, dtype=torch.long))
txt_byt5 = txt_byt5 + cond_emb.to(txt_byt5.dtype)
txt = torch.cat((txt_byt5, txt), dim=1) # byt5 first for HunyuanVideo1.5
else:
txt = torch.cat((txt, txt_byt5), dim=1)
txt_byt5_ids = torch.zeros((txt_ids.shape[0], txt_byt5.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
txt = torch.cat((txt, txt_byt5), dim=1)
txt_ids = torch.cat((txt_ids, txt_byt5_ids), dim=1)
if clip_fea is not None:
txt_vision_states = self.vision_in(clip_fea)
if self.cond_type_embedding is not None:
cond_emb = self.cond_type_embedding(2 * torch.ones_like(txt_vision_states[:, :, 0], dtype=torch.long, device=txt_vision_states.device))
txt_vision_states = txt_vision_states + cond_emb
txt = torch.cat((txt_vision_states.to(txt.dtype), txt), dim=1)
extra_txt_ids = torch.zeros((txt_ids.shape[0], txt_vision_states.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
txt_ids = torch.cat((txt_ids, extra_txt_ids), dim=1)
ids = torch.cat((img_ids, txt_ids), dim=1)
pe = self.pe_embedder(ids)
@@ -470,14 +430,14 @@ class HunyuanVideo(nn.Module):
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
return repeat(img_ids, "h w c -> b (h w) c", b=bs)
def forward(self, x, timestep, context, y=None, txt_byt5=None, clip_fea=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
def forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timestep, context, y, txt_byt5, clip_fea, guidance, attention_mask, guiding_frame_index, ref_latent, disable_time_r, control, transformer_options, **kwargs)
).execute(x, timestep, context, y, txt_byt5, guidance, attention_mask, guiding_frame_index, ref_latent, disable_time_r, control, transformer_options, **kwargs)
def _forward(self, x, timestep, context, y=None, txt_byt5=None, clip_fea=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
def _forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
bs = x.shape[0]
if len(self.patch_size) == 3:
img_ids = self.img_ids(x)
@@ -485,5 +445,5 @@ class HunyuanVideo(nn.Module):
else:
img_ids = self.img_ids_2d(x)
txt_ids = torch.zeros((bs, context.shape[1], 2), device=x.device, dtype=x.dtype)
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, clip_fea, guidance, guiding_frame_index, ref_latent, disable_time_r=disable_time_r, control=control, transformer_options=transformer_options)
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, guidance, guiding_frame_index, ref_latent, disable_time_r=disable_time_r, control=control, transformer_options=transformer_options)
return out

View File

@@ -1,120 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm, ResnetBlock, VideoConv3d
import model_management, model_patcher
class SRResidualCausalBlock3D(nn.Module):
def __init__(self, channels: int):
super().__init__()
self.block = nn.Sequential(
VideoConv3d(channels, channels, kernel_size=3),
nn.SiLU(inplace=True),
VideoConv3d(channels, channels, kernel_size=3),
nn.SiLU(inplace=True),
VideoConv3d(channels, channels, kernel_size=3),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x + self.block(x)
class SRModel3DV2(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
hidden_channels: int = 64,
num_blocks: int = 6,
global_residual: bool = False,
):
super().__init__()
self.in_conv = VideoConv3d(in_channels, hidden_channels, kernel_size=3)
self.blocks = nn.ModuleList([SRResidualCausalBlock3D(hidden_channels) for _ in range(num_blocks)])
self.out_conv = VideoConv3d(hidden_channels, out_channels, kernel_size=3)
self.global_residual = bool(global_residual)
def forward(self, x: torch.Tensor) -> torch.Tensor:
residual = x
y = self.in_conv(x)
for blk in self.blocks:
y = blk(y)
y = self.out_conv(y)
if self.global_residual and (y.shape == residual.shape):
y = y + residual
return y
class Upsampler(nn.Module):
def __init__(
self,
z_channels: int,
out_channels: int,
block_out_channels: tuple[int, ...],
num_res_blocks: int = 2,
):
super().__init__()
self.num_res_blocks = num_res_blocks
self.block_out_channels = block_out_channels
self.z_channels = z_channels
ch = block_out_channels[0]
self.conv_in = VideoConv3d(z_channels, ch, kernel_size=3)
self.up = nn.ModuleList()
for i, tgt in enumerate(block_out_channels):
stage = nn.Module()
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
out_channels=tgt,
temb_channels=0,
conv_shortcut=False,
conv_op=VideoConv3d, norm_op=RMS_norm)
for j in range(num_res_blocks + 1)])
ch = tgt
self.up.append(stage)
self.norm_out = RMS_norm(ch)
self.conv_out = VideoConv3d(ch, out_channels, kernel_size=3)
def forward(self, z):
"""
Args:
z: (B, C, T, H, W)
target_shape: (H, W)
"""
# z to block_in
repeats = self.block_out_channels[0] // (self.z_channels)
x = self.conv_in(z) + z.repeat_interleave(repeats=repeats, dim=1)
# upsampling
for stage in self.up:
for blk in stage.block:
x = blk(x)
out = self.conv_out(F.silu(self.norm_out(x)))
return out
UPSAMPLERS = {
"720p": SRModel3DV2,
"1080p": Upsampler,
}
class HunyuanVideo15SRModel():
def __init__(self, model_type, config):
self.load_device = model_management.vae_device()
offload_device = model_management.vae_offload_device()
self.dtype = model_management.vae_dtype(self.load_device)
self.model_class = UPSAMPLERS.get(model_type)
self.model = self.model_class(**config).eval()
self.patcher = model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
def load_sd(self, sd):
return self.model.load_state_dict(sd, strict=True)
def get_sd(self):
return self.model.state_dict()
def resample_latent(self, latent):
model_management.load_model_gpu(self.patcher)
return self.model(latent.to(self.load_device))

View File

@@ -1,43 +1,11 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d, Normalize
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d
import comfy.ops
import comfy.ldm.models.autoencoder
import comfy.model_management
ops = comfy.ops.disable_weight_init
class NoPadConv3d(nn.Module):
def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding=0, **kwargs):
super().__init__()
self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, stride=stride, dilation=dilation, **kwargs)
def forward(self, x):
return self.conv(x)
def conv_carry_causal_3d(xl, op, conv_carry_in=None, conv_carry_out=None):
x = xl[0]
xl.clear()
if conv_carry_out is not None:
to_push = x[:, :, -2:, :, :].clone()
conv_carry_out.append(to_push)
if isinstance(op, NoPadConv3d):
if conv_carry_in is None:
x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2, 0), mode = 'replicate')
else:
carry_len = conv_carry_in[0].shape[2]
x = torch.cat([conv_carry_in.pop(0), x], dim=2)
x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2 - carry_len, 0), mode = 'replicate')
out = op(x)
return out
class RMS_norm(nn.Module):
def __init__(self, dim):
super().__init__()
@@ -46,25 +14,23 @@ class RMS_norm(nn.Module):
self.gamma = nn.Parameter(torch.empty(shape))
def forward(self, x):
return F.normalize(x, dim=1) * self.scale * comfy.model_management.cast_to(self.gamma, dtype=x.dtype, device=x.device)
return F.normalize(x, dim=1) * self.scale * self.gamma
class DnSmpl(nn.Module):
def __init__(self, ic, oc, tds=True, refiner_vae=True, op=VideoConv3d):
def __init__(self, ic, oc, tds=True):
super().__init__()
fct = 2 * 2 * 2 if tds else 1 * 2 * 2
assert oc % fct == 0
self.conv = op(ic, oc // fct, kernel_size=3, stride=1, padding=1)
self.refiner_vae = refiner_vae
self.conv = VideoConv3d(ic, oc // fct, kernel_size=3)
self.tds = tds
self.gs = fct * ic // oc
def forward(self, x, conv_carry_in=None, conv_carry_out=None):
def forward(self, x):
r1 = 2 if self.tds else 1
h = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out)
if self.tds and self.refiner_vae and conv_carry_in is None:
h = self.conv(x)
if self.tds:
hf = h[:, :, :1, :, :]
b, c, f, ht, wd = hf.shape
hf = hf.reshape(b, c, f, ht // 2, 2, wd // 2, 2)
@@ -72,7 +38,14 @@ class DnSmpl(nn.Module):
hf = hf.reshape(b, 2 * 2 * c, f, ht // 2, wd // 2)
hf = torch.cat([hf, hf], dim=1)
h = h[:, :, 1:, :, :]
hn = h[:, :, 1:, :, :]
b, c, frms, ht, wd = hn.shape
nf = frms // r1
hn = hn.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2)
hn = hn.permute(0, 3, 5, 7, 1, 2, 4, 6)
hn = hn.reshape(b, r1 * 2 * 2 * c, nf, ht // 2, wd // 2)
h = torch.cat([hf, hn], dim=2)
xf = x[:, :, :1, :, :]
b, ci, f, ht, wd = xf.shape
@@ -80,49 +53,49 @@ class DnSmpl(nn.Module):
xf = xf.permute(0, 4, 6, 1, 2, 3, 5)
xf = xf.reshape(b, 2 * 2 * ci, f, ht // 2, wd // 2)
B, C, T, H, W = xf.shape
xf = xf.view(B, hf.shape[1], self.gs // 2, T, H, W).mean(dim=2)
xf = xf.view(B, h.shape[1], self.gs // 2, T, H, W).mean(dim=2)
x = x[:, :, 1:, :, :]
xn = x[:, :, 1:, :, :]
b, ci, frms, ht, wd = xn.shape
nf = frms // r1
xn = xn.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2)
xn = xn.permute(0, 3, 5, 7, 1, 2, 4, 6)
xn = xn.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2)
B, C, T, H, W = xn.shape
xn = xn.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2)
sc = torch.cat([xf, xn], dim=2)
else:
b, c, frms, ht, wd = h.shape
nf = frms // r1
h = h.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2)
h = h.permute(0, 3, 5, 7, 1, 2, 4, 6)
h = h.reshape(b, r1 * 2 * 2 * c, nf, ht // 2, wd // 2)
if h.shape[2] == 0:
return hf + xf
b, ci, frms, ht, wd = x.shape
nf = frms // r1
sc = x.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2)
sc = sc.permute(0, 3, 5, 7, 1, 2, 4, 6)
sc = sc.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2)
B, C, T, H, W = sc.shape
sc = sc.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2)
b, c, frms, ht, wd = h.shape
nf = frms // r1
h = h.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2)
h = h.permute(0, 3, 5, 7, 1, 2, 4, 6)
h = h.reshape(b, r1 * 2 * 2 * c, nf, ht // 2, wd // 2)
b, ci, frms, ht, wd = x.shape
nf = frms // r1
x = x.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2)
x = x.permute(0, 3, 5, 7, 1, 2, 4, 6)
x = x.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2)
B, C, T, H, W = x.shape
x = x.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2)
if self.tds and self.refiner_vae and conv_carry_in is None:
h = torch.cat([hf, h], dim=2)
x = torch.cat([xf, x], dim=2)
return h + x
return h + sc
class UpSmpl(nn.Module):
def __init__(self, ic, oc, tus=True, refiner_vae=True, op=VideoConv3d):
def __init__(self, ic, oc, tus=True):
super().__init__()
fct = 2 * 2 * 2 if tus else 1 * 2 * 2
self.conv = op(ic, oc * fct, kernel_size=3, stride=1, padding=1)
self.refiner_vae = refiner_vae
self.conv = VideoConv3d(ic, oc * fct, kernel_size=3)
self.tus = tus
self.rp = fct * oc // ic
def forward(self, x, conv_carry_in=None, conv_carry_out=None):
def forward(self, x):
r1 = 2 if self.tus else 1
h = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out)
h = self.conv(x)
if self.tus and self.refiner_vae and conv_carry_in is None:
if self.tus:
hf = h[:, :, :1, :, :]
b, c, f, ht, wd = hf.shape
nc = c // (2 * 2)
@@ -131,7 +104,14 @@ class UpSmpl(nn.Module):
hf = hf.reshape(b, nc, f, ht * 2, wd * 2)
hf = hf[:, : hf.shape[1] // 2]
h = h[:, :, 1:, :, :]
hn = h[:, :, 1:, :, :]
b, c, frms, ht, wd = hn.shape
nc = c // (r1 * 2 * 2)
hn = hn.reshape(b, r1, 2, 2, nc, frms, ht, wd)
hn = hn.permute(0, 4, 5, 1, 6, 2, 7, 3)
hn = hn.reshape(b, nc, frms * r1, ht * 2, wd * 2)
h = torch.cat([hf, hn], dim=2)
xf = x[:, :, :1, :, :]
b, ci, f, ht, wd = xf.shape
@@ -142,165 +122,109 @@ class UpSmpl(nn.Module):
xf = xf.permute(0, 3, 4, 5, 1, 6, 2)
xf = xf.reshape(b, nc, f, ht * 2, wd * 2)
x = x[:, :, 1:, :, :]
xn = x[:, :, 1:, :, :]
xn = xn.repeat_interleave(repeats=self.rp, dim=1)
b, c, frms, ht, wd = xn.shape
nc = c // (r1 * 2 * 2)
xn = xn.reshape(b, r1, 2, 2, nc, frms, ht, wd)
xn = xn.permute(0, 4, 5, 1, 6, 2, 7, 3)
xn = xn.reshape(b, nc, frms * r1, ht * 2, wd * 2)
sc = torch.cat([xf, xn], dim=2)
else:
b, c, frms, ht, wd = h.shape
nc = c // (r1 * 2 * 2)
h = h.reshape(b, r1, 2, 2, nc, frms, ht, wd)
h = h.permute(0, 4, 5, 1, 6, 2, 7, 3)
h = h.reshape(b, nc, frms * r1, ht * 2, wd * 2)
b, c, frms, ht, wd = h.shape
nc = c // (r1 * 2 * 2)
h = h.reshape(b, r1, 2, 2, nc, frms, ht, wd)
h = h.permute(0, 4, 5, 1, 6, 2, 7, 3)
h = h.reshape(b, nc, frms * r1, ht * 2, wd * 2)
sc = x.repeat_interleave(repeats=self.rp, dim=1)
b, c, frms, ht, wd = sc.shape
nc = c // (r1 * 2 * 2)
sc = sc.reshape(b, r1, 2, 2, nc, frms, ht, wd)
sc = sc.permute(0, 4, 5, 1, 6, 2, 7, 3)
sc = sc.reshape(b, nc, frms * r1, ht * 2, wd * 2)
x = x.repeat_interleave(repeats=self.rp, dim=1)
b, c, frms, ht, wd = x.shape
nc = c // (r1 * 2 * 2)
x = x.reshape(b, r1, 2, 2, nc, frms, ht, wd)
x = x.permute(0, 4, 5, 1, 6, 2, 7, 3)
x = x.reshape(b, nc, frms * r1, ht * 2, wd * 2)
if self.tus and self.refiner_vae and conv_carry_in is None:
h = torch.cat([hf, h], dim=2)
x = torch.cat([xf, x], dim=2)
return h + x
class HunyuanRefinerResnetBlock(ResnetBlock):
def __init__(self, in_channels, out_channels, conv_op=NoPadConv3d, norm_op=RMS_norm):
super().__init__(in_channels=in_channels, out_channels=out_channels, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
def forward(self, x, conv_carry_in=None, conv_carry_out=None):
h = x
h = [ self.swish(self.norm1(x)) ]
h = conv_carry_causal_3d(h, self.conv1, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
h = [ self.dropout(self.swish(self.norm2(h))) ]
h = conv_carry_causal_3d(h, self.conv2, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
if self.in_channels != self.out_channels:
x = self.nin_shortcut(x)
return x+h
return h + sc
class Encoder(nn.Module):
def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks,
ffactor_spatial, ffactor_temporal, downsample_match_channel=True, refiner_vae=True, **_):
ffactor_spatial, ffactor_temporal, downsample_match_channel=True, **_):
super().__init__()
self.z_channels = z_channels
self.block_out_channels = block_out_channels
self.num_res_blocks = num_res_blocks
self.ffactor_temporal = ffactor_temporal
self.refiner_vae = refiner_vae
if self.refiner_vae:
conv_op = NoPadConv3d
norm_op = RMS_norm
else:
conv_op = ops.Conv3d
norm_op = Normalize
self.conv_in = conv_op(in_channels, block_out_channels[0], 3, 1, 1)
self.conv_in = VideoConv3d(in_channels, block_out_channels[0], 3, 1, 1)
self.down = nn.ModuleList()
ch = block_out_channels[0]
depth = (ffactor_spatial >> 1).bit_length()
depth_temporal = ((ffactor_spatial // self.ffactor_temporal) >> 1).bit_length()
depth_temporal = ((ffactor_spatial // ffactor_temporal) >> 1).bit_length()
for i, tgt in enumerate(block_out_channels):
stage = nn.Module()
stage.block = nn.ModuleList([HunyuanRefinerResnetBlock(in_channels=ch if j == 0 else tgt,
out_channels=tgt,
conv_op=conv_op, norm_op=norm_op)
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
out_channels=tgt,
temb_channels=0,
conv_op=VideoConv3d, norm_op=RMS_norm)
for j in range(num_res_blocks)])
ch = tgt
if i < depth:
nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and downsample_match_channel else ch
stage.downsample = DnSmpl(ch, nxt, tds=i >= depth_temporal, refiner_vae=self.refiner_vae, op=conv_op)
stage.downsample = DnSmpl(ch, nxt, tds=i >= depth_temporal)
ch = nxt
self.down.append(stage)
self.mid = nn.Module()
self.mid.block_1 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
self.mid.block_2 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=RMS_norm)
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
self.norm_out = norm_op(ch)
self.conv_out = conv_op(ch, z_channels << 1, 3, 1, 1)
self.norm_out = RMS_norm(ch)
self.conv_out = VideoConv3d(ch, z_channels << 1, 3, 1, 1)
self.regul = comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer()
def forward(self, x):
if not self.refiner_vae and x.shape[2] == 1:
x = x.expand(-1, -1, self.ffactor_temporal, -1, -1)
x = self.conv_in(x)
if self.refiner_vae:
xl = [x[:, :, :1, :, :]]
if x.shape[2] > self.ffactor_temporal:
xl += torch.split(x[:, :, 1: 1 + ((x.shape[2] - 1) // self.ffactor_temporal) * self.ffactor_temporal, :, :], self.ffactor_temporal * 2, dim=2)
x = xl
else:
x = [x]
out = []
for stage in self.down:
for blk in stage.block:
x = blk(x)
if hasattr(stage, 'downsample'):
x = stage.downsample(x)
conv_carry_in = None
for i, x1 in enumerate(x):
conv_carry_out = []
if i == len(x) - 1:
conv_carry_out = None
x1 = [ x1 ]
x1 = conv_carry_causal_3d(x1, self.conv_in, conv_carry_in, conv_carry_out)
for stage in self.down:
for blk in stage.block:
x1 = blk(x1, conv_carry_in, conv_carry_out)
if hasattr(stage, 'downsample'):
x1 = stage.downsample(x1, conv_carry_in, conv_carry_out)
out.append(x1)
conv_carry_in = conv_carry_out
if len(out) > 1:
out = torch.cat(out, dim=2)
else:
out = out[0]
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(out)))
del out
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
b, c, t, h, w = x.shape
grp = c // (self.z_channels << 1)
skip = x.view(b, c // grp, grp, t, h, w).mean(2)
out = conv_carry_causal_3d([F.silu(self.norm_out(x))], self.conv_out) + skip
if self.refiner_vae:
out = self.regul(out)[0]
out = self.conv_out(F.silu(self.norm_out(x))) + skip
out = self.regul(out)[0]
out = torch.cat((out[:, :, :1], out), dim=2)
out = out.permute(0, 2, 1, 3, 4)
b, f_times_2, c, h, w = out.shape
out = out.reshape(b, f_times_2 // 2, 2 * c, h, w)
out = out.permute(0, 2, 1, 3, 4).contiguous()
return out
class Decoder(nn.Module):
def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks,
ffactor_spatial, ffactor_temporal, upsample_match_channel=True, refiner_vae=True, **_):
ffactor_spatial, ffactor_temporal, upsample_match_channel=True, **_):
super().__init__()
block_out_channels = block_out_channels[::-1]
self.z_channels = z_channels
self.block_out_channels = block_out_channels
self.num_res_blocks = num_res_blocks
self.refiner_vae = refiner_vae
if self.refiner_vae:
conv_op = NoPadConv3d
norm_op = RMS_norm
else:
conv_op = ops.Conv3d
norm_op = Normalize
ch = block_out_channels[0]
self.conv_in = conv_op(z_channels, ch, kernel_size=3, stride=1, padding=1)
self.conv_in = VideoConv3d(z_channels, ch, 3)
self.mid = nn.Module()
self.mid.block_1 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
self.mid.block_2 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=RMS_norm)
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
self.up = nn.ModuleList()
depth = (ffactor_spatial >> 1).bit_length()
@@ -308,56 +232,36 @@ class Decoder(nn.Module):
for i, tgt in enumerate(block_out_channels):
stage = nn.Module()
stage.block = nn.ModuleList([HunyuanRefinerResnetBlock(in_channels=ch if j == 0 else tgt,
out_channels=tgt,
conv_op=conv_op, norm_op=norm_op)
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
out_channels=tgt,
temb_channels=0,
conv_op=VideoConv3d, norm_op=RMS_norm)
for j in range(num_res_blocks + 1)])
ch = tgt
if i < depth:
nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and upsample_match_channel else ch
stage.upsample = UpSmpl(ch, nxt, tus=i < depth_temporal, refiner_vae=self.refiner_vae, op=conv_op)
stage.upsample = UpSmpl(ch, nxt, tus=i < depth_temporal)
ch = nxt
self.up.append(stage)
self.norm_out = norm_op(ch)
self.conv_out = conv_op(ch, out_channels, 3, stride=1, padding=1)
self.norm_out = RMS_norm(ch)
self.conv_out = VideoConv3d(ch, out_channels, 3)
def forward(self, z):
x = conv_carry_causal_3d([z], self.conv_in) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)
z = z.permute(0, 2, 1, 3, 4)
b, f, c, h, w = z.shape
z = z.reshape(b, f, 2, c // 2, h, w)
z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w)
z = z.permute(0, 2, 1, 3, 4)
z = z[:, :, 1:]
x = self.conv_in(z) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
if self.refiner_vae:
x = torch.split(x, 2, dim=2)
else:
x = [ x ]
out = []
conv_carry_in = None
for i, x1 in enumerate(x):
conv_carry_out = []
if i == len(x) - 1:
conv_carry_out = None
for stage in self.up:
for blk in stage.block:
x1 = blk(x1, conv_carry_in, conv_carry_out)
if hasattr(stage, 'upsample'):
x1 = stage.upsample(x1, conv_carry_in, conv_carry_out)
x1 = [ F.silu(self.norm_out(x1)) ]
x1 = conv_carry_causal_3d(x1, self.conv_out, conv_carry_in, conv_carry_out)
out.append(x1)
conv_carry_in = conv_carry_out
del x
if len(out) > 1:
out = torch.cat(out, dim=2)
else:
out = out[0]
if not self.refiner_vae:
if z.shape[-3] == 1:
out = out[:, :, -1:]
return out
for stage in self.up:
for blk in stage.block:
x = blk(x)
if hasattr(stage, 'upsample'):
x = stage.upsample(x)
return self.conv_out(F.silu(self.norm_out(x)))

View File

@@ -3,11 +3,12 @@ from torch import nn
import comfy.patcher_extension
import comfy.ldm.modules.attention
import comfy.ldm.common_dit
from einops import rearrange
import math
from typing import Dict, Optional, Tuple
from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
from comfy.ldm.flux.math import apply_rope1
def get_timestep_embedding(
timesteps: torch.Tensor,
@@ -237,6 +238,20 @@ class FeedForward(nn.Module):
return self.net(x)
def apply_rotary_emb(input_tensor, freqs_cis): #TODO: remove duplicate funcs and pick the best/fastest one
cos_freqs = freqs_cis[0]
sin_freqs = freqs_cis[1]
t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2)
t1, t2 = t_dup.unbind(dim=-1)
t_dup = torch.stack((-t2, t1), dim=-1)
input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)")
out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs
return out
class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=None):
super().__init__()
@@ -266,8 +281,8 @@ class CrossAttention(nn.Module):
k = self.k_norm(k)
if pe is not None:
q = apply_rope1(q.unsqueeze(1), pe).squeeze(1)
k = apply_rope1(k.unsqueeze(1), pe).squeeze(1)
q = apply_rotary_emb(q, pe)
k = apply_rotary_emb(k, pe)
if mask is None:
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
@@ -291,17 +306,12 @@ class BasicTransformerBlock(nn.Module):
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
attn1_input = comfy.ldm.common_dit.rms_norm(x)
attn1_input = torch.addcmul(attn1_input, attn1_input, scale_msa).add_(shift_msa)
attn1_input = self.attn1(attn1_input, pe=pe, transformer_options=transformer_options)
x.addcmul_(attn1_input, gate_msa)
del attn1_input
x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe, transformer_options=transformer_options) * gate_msa
x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options)
y = comfy.ldm.common_dit.rms_norm(x)
y = torch.addcmul(y, y, scale_mlp).add_(shift_mlp)
x.addcmul_(self.ff(y), gate_mlp)
y = comfy.ldm.common_dit.rms_norm(x) * (1 + scale_mlp) + shift_mlp
x += self.ff(y) * gate_mlp
return x
@@ -317,35 +327,41 @@ def get_fractional_positions(indices_grid, max_pos):
def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[20, 2048, 2048]):
dtype = torch.float32
device = indices_grid.device
dtype = torch.float32 #self.dtype
# Get fractional positions and compute frequency indices
fractional_positions = get_fractional_positions(indices_grid, max_pos)
indices = theta ** torch.linspace(0, 1, dim // 6, device=device, dtype=dtype) * math.pi / 2
# Compute frequencies and apply cos/sin
freqs = (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)).transpose(-1, -2).flatten(2)
cos_vals = freqs.cos().repeat_interleave(2, dim=-1)
sin_vals = freqs.sin().repeat_interleave(2, dim=-1)
start = 1
end = theta
device = fractional_positions.device
# Pad if dim is not divisible by 6
indices = theta ** (
torch.linspace(
math.log(start, theta),
math.log(end, theta),
dim // 6,
device=device,
dtype=dtype,
)
)
indices = indices.to(dtype=dtype)
indices = indices * math.pi / 2
freqs = (
(indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
.transpose(-1, -2)
.flatten(2)
)
cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
if dim % 6 != 0:
padding_size = dim % 6
cos_vals = torch.cat([torch.ones_like(cos_vals[:, :, :padding_size]), cos_vals], dim=-1)
sin_vals = torch.cat([torch.zeros_like(sin_vals[:, :, :padding_size]), sin_vals], dim=-1)
# Reshape and extract one value per pair (since repeat_interleave duplicates each value)
cos_vals = cos_vals.reshape(*cos_vals.shape[:2], -1, 2)[..., 0].to(out_dtype) # [B, N, dim//2]
sin_vals = sin_vals.reshape(*sin_vals.shape[:2], -1, 2)[..., 0].to(out_dtype) # [B, N, dim//2]
# Build rotation matrix [[cos, -sin], [sin, cos]] and add heads dimension
freqs_cis = torch.stack([
torch.stack([cos_vals, -sin_vals], dim=-1),
torch.stack([sin_vals, cos_vals], dim=-1)
], dim=-2).unsqueeze(1) # [B, 1, N, dim//2, 2, 2]
return freqs_cis
cos_padding = torch.ones_like(cos_freq[:, :, : dim % 6])
sin_padding = torch.zeros_like(cos_freq[:, :, : dim % 6])
cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
return cos_freq.to(out_dtype), sin_freq.to(out_dtype)
class LTXVModel(torch.nn.Module):
@@ -485,7 +501,7 @@ class LTXVModel(torch.nn.Module):
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
x = self.norm_out(x)
# Modulation
x = torch.addcmul(x, x, scale).add_(shift)
x = x * (1 + scale) + shift
x = self.proj_out(x)
x = self.patchifier.unpatchify(

View File

@@ -522,7 +522,7 @@ class NextDiT(nn.Module):
max_cap_len = max(l_effective_cap_len)
max_img_len = max(l_effective_img_len)
position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.float32, device=device)
position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.int32, device=device)
for i in range(bsz):
cap_len = l_effective_cap_len[i]
@@ -531,22 +531,10 @@ class NextDiT(nn.Module):
H_tokens, W_tokens = H // pH, W // pW
assert H_tokens * W_tokens == img_len
rope_options = transformer_options.get("rope_options", None)
h_scale = 1.0
w_scale = 1.0
h_start = 0
w_start = 0
if rope_options is not None:
h_scale = rope_options.get("scale_y", 1.0)
w_scale = rope_options.get("scale_x", 1.0)
h_start = rope_options.get("shift_y", 0.0)
w_start = rope_options.get("shift_x", 0.0)
position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.float32, device=device)
position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device)
position_ids[i, cap_len:cap_len+img_len, 0] = cap_len
row_ids = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten()
col_ids = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten()
row_ids = torch.arange(H_tokens, dtype=torch.int32, device=device).view(-1, 1).repeat(1, W_tokens).flatten()
col_ids = torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten()
position_ids[i, cap_len:cap_len+img_len, 1] = row_ids
position_ids[i, cap_len:cap_len+img_len, 2] = col_ids

View File

@@ -1,120 +0,0 @@
# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
# LICENSE is in incl_licenses directory.
import torch
from torch import nn, sin, pow
from torch.nn import Parameter
import comfy.model_management
class Snake(nn.Module):
'''
Implementation of a sine-based periodic activation function
Shape:
- Input: (B, C, T)
- Output: (B, C, T), same shape as the input
Parameters:
- alpha - trainable parameter
References:
- This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
https://arxiv.org/abs/2006.08195
Examples:
>>> a1 = snake(256)
>>> x = torch.randn(256)
>>> x = a1(x)
'''
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
'''
Initialization.
INPUT:
- in_features: shape of the input
- alpha: trainable parameter
alpha is initialized to 1 by default, higher values = higher-frequency.
alpha will be trained along with the rest of your model.
'''
super(Snake, self).__init__()
self.in_features = in_features
# initialize alpha
self.alpha_logscale = alpha_logscale
if self.alpha_logscale:
self.alpha = Parameter(torch.empty(in_features))
else:
self.alpha = Parameter(torch.empty(in_features))
self.alpha.requires_grad = alpha_trainable
self.no_div_by_zero = 0.000000001
def forward(self, x):
'''
Forward pass of the function.
Applies the function to the input elementwise.
Snake = x + 1/a * sin^2 (xa)
'''
alpha = comfy.model_management.cast_to(self.alpha, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
if self.alpha_logscale:
alpha = torch.exp(alpha)
x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
return x
class SnakeBeta(nn.Module):
'''
A modified Snake function which uses separate parameters for the magnitude of the periodic components
Shape:
- Input: (B, C, T)
- Output: (B, C, T), same shape as the input
Parameters:
- alpha - trainable parameter that controls frequency
- beta - trainable parameter that controls magnitude
References:
- This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
https://arxiv.org/abs/2006.08195
Examples:
>>> a1 = snakebeta(256)
>>> x = torch.randn(256)
>>> x = a1(x)
'''
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
'''
Initialization.
INPUT:
- in_features: shape of the input
- alpha - trainable parameter that controls frequency
- beta - trainable parameter that controls magnitude
alpha is initialized to 1 by default, higher values = higher-frequency.
beta is initialized to 1 by default, higher values = higher-magnitude.
alpha will be trained along with the rest of your model.
'''
super(SnakeBeta, self).__init__()
self.in_features = in_features
# initialize alpha
self.alpha_logscale = alpha_logscale
if self.alpha_logscale:
self.alpha = Parameter(torch.empty(in_features))
self.beta = Parameter(torch.empty(in_features))
else:
self.alpha = Parameter(torch.empty(in_features))
self.beta = Parameter(torch.empty(in_features))
self.alpha.requires_grad = alpha_trainable
self.beta.requires_grad = alpha_trainable
self.no_div_by_zero = 0.000000001
def forward(self, x):
'''
Forward pass of the function.
Applies the function to the input elementwise.
SnakeBeta = x + 1/b * sin^2 (xa)
'''
alpha = comfy.model_management.cast_to(self.alpha, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
beta = comfy.model_management.cast_to(self.beta, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1)
if self.alpha_logscale:
alpha = torch.exp(alpha)
beta = torch.exp(beta)
x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
return x

View File

@@ -1,157 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import comfy.model_management
if 'sinc' in dir(torch):
sinc = torch.sinc
else:
# This code is adopted from adefossez's julius.core.sinc under the MIT License
# https://adefossez.github.io/julius/julius/core.html
# LICENSE is in incl_licenses directory.
def sinc(x: torch.Tensor):
"""
Implementation of sinc, i.e. sin(pi * x) / (pi * x)
__Warning__: Different to julius.sinc, the input is multiplied by `pi`!
"""
return torch.where(x == 0,
torch.tensor(1., device=x.device, dtype=x.dtype),
torch.sin(math.pi * x) / math.pi / x)
# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
# https://adefossez.github.io/julius/julius/lowpass.html
# LICENSE is in incl_licenses directory.
def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size]
even = (kernel_size % 2 == 0)
half_size = kernel_size // 2
#For kaiser window
delta_f = 4 * half_width
A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
if A > 50.:
beta = 0.1102 * (A - 8.7)
elif A >= 21.:
beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.)
else:
beta = 0.
window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
# ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
if even:
time = (torch.arange(-half_size, half_size) + 0.5)
else:
time = torch.arange(kernel_size) - half_size
if cutoff == 0:
filter_ = torch.zeros_like(time)
else:
filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
# Normalize filter to have sum = 1, otherwise we will have a small leakage
# of the constant component in the input signal.
filter_ /= filter_.sum()
filter = filter_.view(1, 1, kernel_size)
return filter
class LowPassFilter1d(nn.Module):
def __init__(self,
cutoff=0.5,
half_width=0.6,
stride: int = 1,
padding: bool = True,
padding_mode: str = 'replicate',
kernel_size: int = 12):
# kernel_size should be even number for stylegan3 setup,
# in this implementation, odd number is also possible.
super().__init__()
if cutoff < -0.:
raise ValueError("Minimum cutoff must be larger than zero.")
if cutoff > 0.5:
raise ValueError("A cutoff above 0.5 does not make sense.")
self.kernel_size = kernel_size
self.even = (kernel_size % 2 == 0)
self.pad_left = kernel_size // 2 - int(self.even)
self.pad_right = kernel_size // 2
self.stride = stride
self.padding = padding
self.padding_mode = padding_mode
filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
self.register_buffer("filter", filter)
#input [B, C, T]
def forward(self, x):
_, C, _ = x.shape
if self.padding:
x = F.pad(x, (self.pad_left, self.pad_right),
mode=self.padding_mode)
out = F.conv1d(x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device),
stride=self.stride, groups=C)
return out
class UpSample1d(nn.Module):
def __init__(self, ratio=2, kernel_size=None):
super().__init__()
self.ratio = ratio
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
self.stride = ratio
self.pad = self.kernel_size // ratio - 1
self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio,
half_width=0.6 / ratio,
kernel_size=self.kernel_size)
self.register_buffer("filter", filter)
# x: [B, C, T]
def forward(self, x):
_, C, _ = x.shape
x = F.pad(x, (self.pad, self.pad), mode='replicate')
x = self.ratio * F.conv_transpose1d(
x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device), stride=self.stride, groups=C)
x = x[..., self.pad_left:-self.pad_right]
return x
class DownSample1d(nn.Module):
def __init__(self, ratio=2, kernel_size=None):
super().__init__()
self.ratio = ratio
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio,
half_width=0.6 / ratio,
stride=ratio,
kernel_size=self.kernel_size)
def forward(self, x):
xx = self.lowpass(x)
return xx
class Activation1d(nn.Module):
def __init__(self,
activation,
up_ratio: int = 2,
down_ratio: int = 2,
up_kernel_size: int = 12,
down_kernel_size: int = 12):
super().__init__()
self.up_ratio = up_ratio
self.down_ratio = down_ratio
self.act = activation
self.upsample = UpSample1d(up_ratio, up_kernel_size)
self.downsample = DownSample1d(down_ratio, down_kernel_size)
# x: [B,C,T]
def forward(self, x):
x = self.upsample(x)
x = self.act(x)
x = self.downsample(x)
return x

View File

@@ -1,156 +0,0 @@
from typing import Literal
import torch
import torch.nn as nn
from .distributions import DiagonalGaussianDistribution
from .vae import VAE_16k
from .bigvgan import BigVGANVocoder
import logging
try:
import torchaudio
except:
logging.warning("torchaudio missing, MMAudio VAE model will be broken")
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5, *, norm_fn):
return norm_fn(torch.clamp(x, min=clip_val) * C)
def spectral_normalize_torch(magnitudes, norm_fn):
output = dynamic_range_compression_torch(magnitudes, norm_fn=norm_fn)
return output
class MelConverter(nn.Module):
def __init__(
self,
*,
sampling_rate: float,
n_fft: int,
num_mels: int,
hop_size: int,
win_size: int,
fmin: float,
fmax: float,
norm_fn,
):
super().__init__()
self.sampling_rate = sampling_rate
self.n_fft = n_fft
self.num_mels = num_mels
self.hop_size = hop_size
self.win_size = win_size
self.fmin = fmin
self.fmax = fmax
self.norm_fn = norm_fn
# mel = librosa_mel_fn(sr=self.sampling_rate,
# n_fft=self.n_fft,
# n_mels=self.num_mels,
# fmin=self.fmin,
# fmax=self.fmax)
# mel_basis = torch.from_numpy(mel).float()
mel_basis = torch.empty((num_mels, 1 + n_fft // 2))
hann_window = torch.hann_window(self.win_size)
self.register_buffer('mel_basis', mel_basis)
self.register_buffer('hann_window', hann_window)
@property
def device(self):
return self.mel_basis.device
def forward(self, waveform: torch.Tensor, center: bool = False) -> torch.Tensor:
waveform = waveform.clamp(min=-1., max=1.).to(self.device)
waveform = torch.nn.functional.pad(
waveform.unsqueeze(1),
[int((self.n_fft - self.hop_size) / 2),
int((self.n_fft - self.hop_size) / 2)],
mode='reflect')
waveform = waveform.squeeze(1)
spec = torch.stft(waveform,
self.n_fft,
hop_length=self.hop_size,
win_length=self.win_size,
window=self.hann_window,
center=center,
pad_mode='reflect',
normalized=False,
onesided=True,
return_complex=True)
spec = torch.view_as_real(spec)
spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
spec = torch.matmul(self.mel_basis, spec)
spec = spectral_normalize_torch(spec, self.norm_fn)
return spec
class AudioAutoencoder(nn.Module):
def __init__(
self,
*,
# ckpt_path: str,
mode=Literal['16k', '44k'],
need_vae_encoder: bool = True,
):
super().__init__()
assert mode == "16k", "Only 16k mode is supported currently."
self.mel_converter = MelConverter(sampling_rate=16_000,
n_fft=1024,
num_mels=80,
hop_size=256,
win_size=1024,
fmin=0,
fmax=8_000,
norm_fn=torch.log10)
self.vae = VAE_16k().eval()
bigvgan_config = {
"resblock": "1",
"num_mels": 80,
"upsample_rates": [4, 4, 2, 2, 2, 2],
"upsample_kernel_sizes": [8, 8, 4, 4, 4, 4],
"upsample_initial_channel": 1536,
"resblock_kernel_sizes": [3, 7, 11],
"resblock_dilation_sizes": [
[1, 3, 5],
[1, 3, 5],
[1, 3, 5],
],
"activation": "snakebeta",
"snake_logscale": True,
}
self.vocoder = BigVGANVocoder(
bigvgan_config
).eval()
@torch.inference_mode()
def encode_audio(self, x) -> DiagonalGaussianDistribution:
# x: (B * L)
mel = self.mel_converter(x)
dist = self.vae.encode(mel)
return dist
@torch.no_grad()
def decode(self, z):
mel_decoded = self.vae.decode(z)
audio = self.vocoder(mel_decoded)
audio = torchaudio.functional.resample(audio, 16000, 44100)
return audio
@torch.no_grad()
def encode(self, audio):
audio = audio.mean(dim=1)
audio = torchaudio.functional.resample(audio, 44100, 16000)
dist = self.encode_audio(audio)
return dist.mean

View File

@@ -1,219 +0,0 @@
# Copyright (c) 2022 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
# LICENSE is in incl_licenses directory.
import torch
import torch.nn as nn
from types import SimpleNamespace
from . import activations
from .alias_free_torch import Activation1d
import comfy.ops
ops = comfy.ops.disable_weight_init
def get_padding(kernel_size, dilation=1):
return int((kernel_size * dilation - dilation) / 2)
class AMPBlock1(torch.nn.Module):
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5), activation=None):
super(AMPBlock1, self).__init__()
self.h = h
self.convs1 = nn.ModuleList([
ops.Conv1d(channels,
channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0])),
ops.Conv1d(channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1])),
ops.Conv1d(channels,
channels,
kernel_size,
1,
dilation=dilation[2],
padding=get_padding(kernel_size, dilation[2]))
])
self.convs2 = nn.ModuleList([
ops.Conv1d(channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1)),
ops.Conv1d(channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1)),
ops.Conv1d(channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1))
])
self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers
if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
self.activations = nn.ModuleList([
Activation1d(
activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
for _ in range(self.num_layers)
])
elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing
self.activations = nn.ModuleList([
Activation1d(
activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
for _ in range(self.num_layers)
])
else:
raise NotImplementedError(
"activation incorrectly specified. check the config file and look for 'activation'."
)
def forward(self, x):
acts1, acts2 = self.activations[::2], self.activations[1::2]
for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
xt = a1(x)
xt = c1(xt)
xt = a2(xt)
xt = c2(xt)
x = xt + x
return x
class AMPBlock2(torch.nn.Module):
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3), activation=None):
super(AMPBlock2, self).__init__()
self.h = h
self.convs = nn.ModuleList([
ops.Conv1d(channels,
channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0])),
ops.Conv1d(channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]))
])
self.num_layers = len(self.convs) # total number of conv layers
if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
self.activations = nn.ModuleList([
Activation1d(
activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
for _ in range(self.num_layers)
])
elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing
self.activations = nn.ModuleList([
Activation1d(
activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
for _ in range(self.num_layers)
])
else:
raise NotImplementedError(
"activation incorrectly specified. check the config file and look for 'activation'."
)
def forward(self, x):
for c, a in zip(self.convs, self.activations):
xt = a(x)
xt = c(xt)
x = xt + x
return x
class BigVGANVocoder(torch.nn.Module):
# this is our main BigVGAN model. Applies anti-aliased periodic activation for resblocks.
def __init__(self, h):
super().__init__()
if isinstance(h, dict):
h = SimpleNamespace(**h)
self.h = h
self.num_kernels = len(h.resblock_kernel_sizes)
self.num_upsamples = len(h.upsample_rates)
# pre conv
self.conv_pre = ops.Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
# define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
resblock = AMPBlock1 if h.resblock == '1' else AMPBlock2
# transposed conv-based upsamplers. does not apply anti-aliasing
self.ups = nn.ModuleList()
for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
self.ups.append(
nn.ModuleList([
ops.ConvTranspose1d(h.upsample_initial_channel // (2**i),
h.upsample_initial_channel // (2**(i + 1)),
k,
u,
padding=(k - u) // 2)
]))
# residual blocks using anti-aliased multi-periodicity composition modules (AMP)
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = h.upsample_initial_channel // (2**(i + 1))
for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
self.resblocks.append(resblock(h, ch, k, d, activation=h.activation))
# post conv
if h.activation == "snake": # periodic nonlinearity with snake function and anti-aliasing
activation_post = activations.Snake(ch, alpha_logscale=h.snake_logscale)
self.activation_post = Activation1d(activation=activation_post)
elif h.activation == "snakebeta": # periodic nonlinearity with snakebeta function and anti-aliasing
activation_post = activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale)
self.activation_post = Activation1d(activation=activation_post)
else:
raise NotImplementedError(
"activation incorrectly specified. check the config file and look for 'activation'."
)
self.conv_post = ops.Conv1d(ch, 1, 7, 1, padding=3)
def forward(self, x):
# pre conv
x = self.conv_pre(x)
for i in range(self.num_upsamples):
# upsampling
for i_up in range(len(self.ups[i])):
x = self.ups[i][i_up](x)
# AMP blocks
xs = None
for j in range(self.num_kernels):
if xs is None:
xs = self.resblocks[i * self.num_kernels + j](x)
else:
xs += self.resblocks[i * self.num_kernels + j](x)
x = xs / self.num_kernels
# post conv
x = self.activation_post(x)
x = self.conv_post(x)
x = torch.tanh(x)
return x

View File

@@ -1,92 +0,0 @@
import torch
import numpy as np
class AbstractDistribution:
def sample(self):
raise NotImplementedError()
def mode(self):
raise NotImplementedError()
class DiracDistribution(AbstractDistribution):
def __init__(self, value):
self.value = value
def sample(self):
return self.value
def mode(self):
return self.value
class DiagonalGaussianDistribution(object):
def __init__(self, parameters, deterministic=False):
self.parameters = parameters
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
self.deterministic = deterministic
self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar)
if self.deterministic:
self.var = self.std = torch.zeros_like(self.mean, device=self.parameters.device)
def sample(self):
x = self.mean + self.std * torch.randn(self.mean.shape, device=self.parameters.device)
return x
def kl(self, other=None):
if self.deterministic:
return torch.Tensor([0.])
else:
if other is None:
return 0.5 * torch.sum(torch.pow(self.mean, 2)
+ self.var - 1.0 - self.logvar,
dim=[1, 2, 3])
else:
return 0.5 * torch.sum(
torch.pow(self.mean - other.mean, 2) / other.var
+ self.var / other.var - 1.0 - self.logvar + other.logvar,
dim=[1, 2, 3])
def nll(self, sample, dims=[1,2,3]):
if self.deterministic:
return torch.Tensor([0.])
logtwopi = np.log(2.0 * np.pi)
return 0.5 * torch.sum(
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
dim=dims)
def mode(self):
return self.mean
def normal_kl(mean1, logvar1, mean2, logvar2):
"""
source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
Compute the KL divergence between two gaussians.
Shapes are automatically broadcasted, so batches can be compared to
scalars, among other use cases.
"""
tensor = None
for obj in (mean1, logvar1, mean2, logvar2):
if isinstance(obj, torch.Tensor):
tensor = obj
break
assert tensor is not None, "at least one argument must be a Tensor"
# Force variances to be Tensors. Broadcasting helps convert scalars to
# Tensors, but it does not work for torch.exp().
logvar1, logvar2 = [
x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
for x in (logvar1, logvar2)
]
return 0.5 * (
-1.0
+ logvar2
- logvar1
+ torch.exp(logvar1 - logvar2)
+ ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
)

View File

@@ -1,358 +0,0 @@
import logging
from typing import Optional
import torch
import torch.nn as nn
from .vae_modules import (AttnBlock1D, Downsample1D, ResnetBlock1D,
Upsample1D, nonlinearity)
from .distributions import DiagonalGaussianDistribution
import comfy.ops
ops = comfy.ops.disable_weight_init
log = logging.getLogger()
DATA_MEAN_80D = [
-1.6058, -1.3676, -1.2520, -1.2453, -1.2078, -1.2224, -1.2419, -1.2439, -1.2922, -1.2927,
-1.3170, -1.3543, -1.3401, -1.3836, -1.3907, -1.3912, -1.4313, -1.4152, -1.4527, -1.4728,
-1.4568, -1.5101, -1.5051, -1.5172, -1.5623, -1.5373, -1.5746, -1.5687, -1.6032, -1.6131,
-1.6081, -1.6331, -1.6489, -1.6489, -1.6700, -1.6738, -1.6953, -1.6969, -1.7048, -1.7280,
-1.7361, -1.7495, -1.7658, -1.7814, -1.7889, -1.8064, -1.8221, -1.8377, -1.8417, -1.8643,
-1.8857, -1.8929, -1.9173, -1.9379, -1.9531, -1.9673, -1.9824, -2.0042, -2.0215, -2.0436,
-2.0766, -2.1064, -2.1418, -2.1855, -2.2319, -2.2767, -2.3161, -2.3572, -2.3954, -2.4282,
-2.4659, -2.5072, -2.5552, -2.6074, -2.6584, -2.7107, -2.7634, -2.8266, -2.8981, -2.9673
]
DATA_STD_80D = [
1.0291, 1.0411, 1.0043, 0.9820, 0.9677, 0.9543, 0.9450, 0.9392, 0.9343, 0.9297, 0.9276, 0.9263,
0.9242, 0.9254, 0.9232, 0.9281, 0.9263, 0.9315, 0.9274, 0.9247, 0.9277, 0.9199, 0.9188, 0.9194,
0.9160, 0.9161, 0.9146, 0.9161, 0.9100, 0.9095, 0.9145, 0.9076, 0.9066, 0.9095, 0.9032, 0.9043,
0.9038, 0.9011, 0.9019, 0.9010, 0.8984, 0.8983, 0.8986, 0.8961, 0.8962, 0.8978, 0.8962, 0.8973,
0.8993, 0.8976, 0.8995, 0.9016, 0.8982, 0.8972, 0.8974, 0.8949, 0.8940, 0.8947, 0.8936, 0.8939,
0.8951, 0.8956, 0.9017, 0.9167, 0.9436, 0.9690, 1.0003, 1.0225, 1.0381, 1.0491, 1.0545, 1.0604,
1.0761, 1.0929, 1.1089, 1.1196, 1.1176, 1.1156, 1.1117, 1.1070
]
DATA_MEAN_128D = [
-3.3462, -2.6723, -2.4893, -2.3143, -2.2664, -2.3317, -2.1802, -2.4006, -2.2357, -2.4597,
-2.3717, -2.4690, -2.5142, -2.4919, -2.6610, -2.5047, -2.7483, -2.5926, -2.7462, -2.7033,
-2.7386, -2.8112, -2.7502, -2.9594, -2.7473, -3.0035, -2.8891, -2.9922, -2.9856, -3.0157,
-3.1191, -2.9893, -3.1718, -3.0745, -3.1879, -3.2310, -3.1424, -3.2296, -3.2791, -3.2782,
-3.2756, -3.3134, -3.3509, -3.3750, -3.3951, -3.3698, -3.4505, -3.4509, -3.5089, -3.4647,
-3.5536, -3.5788, -3.5867, -3.6036, -3.6400, -3.6747, -3.7072, -3.7279, -3.7283, -3.7795,
-3.8259, -3.8447, -3.8663, -3.9182, -3.9605, -3.9861, -4.0105, -4.0373, -4.0762, -4.1121,
-4.1488, -4.1874, -4.2461, -4.3170, -4.3639, -4.4452, -4.5282, -4.6297, -4.7019, -4.7960,
-4.8700, -4.9507, -5.0303, -5.0866, -5.1634, -5.2342, -5.3242, -5.4053, -5.4927, -5.5712,
-5.6464, -5.7052, -5.7619, -5.8410, -5.9188, -6.0103, -6.0955, -6.1673, -6.2362, -6.3120,
-6.3926, -6.4797, -6.5565, -6.6511, -6.8130, -6.9961, -7.1275, -7.2457, -7.3576, -7.4663,
-7.6136, -7.7469, -7.8815, -8.0132, -8.1515, -8.3071, -8.4722, -8.7418, -9.3975, -9.6628,
-9.7671, -9.8863, -9.9992, -10.0860, -10.1709, -10.5418, -11.2795, -11.3861
]
DATA_STD_128D = [
2.3804, 2.4368, 2.3772, 2.3145, 2.2803, 2.2510, 2.2316, 2.2083, 2.1996, 2.1835, 2.1769, 2.1659,
2.1631, 2.1618, 2.1540, 2.1606, 2.1571, 2.1567, 2.1612, 2.1579, 2.1679, 2.1683, 2.1634, 2.1557,
2.1668, 2.1518, 2.1415, 2.1449, 2.1406, 2.1350, 2.1313, 2.1415, 2.1281, 2.1352, 2.1219, 2.1182,
2.1327, 2.1195, 2.1137, 2.1080, 2.1179, 2.1036, 2.1087, 2.1036, 2.1015, 2.1068, 2.0975, 2.0991,
2.0902, 2.1015, 2.0857, 2.0920, 2.0893, 2.0897, 2.0910, 2.0881, 2.0925, 2.0873, 2.0960, 2.0900,
2.0957, 2.0958, 2.0978, 2.0936, 2.0886, 2.0905, 2.0845, 2.0855, 2.0796, 2.0840, 2.0813, 2.0817,
2.0838, 2.0840, 2.0917, 2.1061, 2.1431, 2.1976, 2.2482, 2.3055, 2.3700, 2.4088, 2.4372, 2.4609,
2.4731, 2.4847, 2.5072, 2.5451, 2.5772, 2.6147, 2.6529, 2.6596, 2.6645, 2.6726, 2.6803, 2.6812,
2.6899, 2.6916, 2.6931, 2.6998, 2.7062, 2.7262, 2.7222, 2.7158, 2.7041, 2.7485, 2.7491, 2.7451,
2.7485, 2.7233, 2.7297, 2.7233, 2.7145, 2.6958, 2.6788, 2.6439, 2.6007, 2.4786, 2.2469, 2.1877,
2.1392, 2.0717, 2.0107, 1.9676, 1.9140, 1.7102, 0.9101, 0.7164
]
class VAE(nn.Module):
def __init__(
self,
*,
data_dim: int,
embed_dim: int,
hidden_dim: int,
):
super().__init__()
if data_dim == 80:
self.data_mean = nn.Buffer(torch.tensor(DATA_MEAN_80D, dtype=torch.float32))
self.data_std = nn.Buffer(torch.tensor(DATA_STD_80D, dtype=torch.float32))
elif data_dim == 128:
self.data_mean = nn.Buffer(torch.tensor(DATA_MEAN_128D, dtype=torch.float32))
self.data_std = nn.Buffer(torch.tensor(DATA_STD_128D, dtype=torch.float32))
self.data_mean = self.data_mean.view(1, -1, 1)
self.data_std = self.data_std.view(1, -1, 1)
self.encoder = Encoder1D(
dim=hidden_dim,
ch_mult=(1, 2, 4),
num_res_blocks=2,
attn_layers=[3],
down_layers=[0],
in_dim=data_dim,
embed_dim=embed_dim,
)
self.decoder = Decoder1D(
dim=hidden_dim,
ch_mult=(1, 2, 4),
num_res_blocks=2,
attn_layers=[3],
down_layers=[0],
in_dim=data_dim,
out_dim=data_dim,
embed_dim=embed_dim,
)
self.embed_dim = embed_dim
# self.quant_conv = nn.Conv1d(2 * embed_dim, 2 * embed_dim, 1)
# self.post_quant_conv = nn.Conv1d(embed_dim, embed_dim, 1)
self.initialize_weights()
def initialize_weights(self):
pass
def encode(self, x: torch.Tensor, normalize: bool = True) -> DiagonalGaussianDistribution:
if normalize:
x = self.normalize(x)
moments = self.encoder(x)
posterior = DiagonalGaussianDistribution(moments)
return posterior
def decode(self, z: torch.Tensor, unnormalize: bool = True) -> torch.Tensor:
dec = self.decoder(z)
if unnormalize:
dec = self.unnormalize(dec)
return dec
def normalize(self, x: torch.Tensor) -> torch.Tensor:
return (x - comfy.model_management.cast_to(self.data_mean, dtype=x.dtype, device=x.device)) / comfy.model_management.cast_to(self.data_std, dtype=x.dtype, device=x.device)
def unnormalize(self, x: torch.Tensor) -> torch.Tensor:
return x * comfy.model_management.cast_to(self.data_std, dtype=x.dtype, device=x.device) + comfy.model_management.cast_to(self.data_mean, dtype=x.dtype, device=x.device)
def forward(
self,
x: torch.Tensor,
sample_posterior: bool = True,
rng: Optional[torch.Generator] = None,
normalize: bool = True,
unnormalize: bool = True,
) -> tuple[torch.Tensor, DiagonalGaussianDistribution]:
posterior = self.encode(x, normalize=normalize)
if sample_posterior:
z = posterior.sample(rng)
else:
z = posterior.mode()
dec = self.decode(z, unnormalize=unnormalize)
return dec, posterior
def load_weights(self, src_dict) -> None:
self.load_state_dict(src_dict, strict=True)
@property
def device(self) -> torch.device:
return next(self.parameters()).device
def get_last_layer(self):
return self.decoder.conv_out.weight
def remove_weight_norm(self):
return self
class Encoder1D(nn.Module):
def __init__(self,
*,
dim: int,
ch_mult: tuple[int] = (1, 2, 4, 8),
num_res_blocks: int,
attn_layers: list[int] = [],
down_layers: list[int] = [],
resamp_with_conv: bool = True,
in_dim: int,
embed_dim: int,
double_z: bool = True,
kernel_size: int = 3,
clip_act: float = 256.0):
super().__init__()
self.dim = dim
self.num_layers = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.in_channels = in_dim
self.clip_act = clip_act
self.down_layers = down_layers
self.attn_layers = attn_layers
self.conv_in = ops.Conv1d(in_dim, self.dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
in_ch_mult = (1, ) + tuple(ch_mult)
self.in_ch_mult = in_ch_mult
# downsampling
self.down = nn.ModuleList()
for i_level in range(self.num_layers):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = dim * in_ch_mult[i_level]
block_out = dim * ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(
ResnetBlock1D(in_dim=block_in,
out_dim=block_out,
kernel_size=kernel_size,
use_norm=True))
block_in = block_out
if i_level in attn_layers:
attn.append(AttnBlock1D(block_in))
down = nn.Module()
down.block = block
down.attn = attn
if i_level in down_layers:
down.downsample = Downsample1D(block_in, resamp_with_conv)
self.down.append(down)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock1D(in_dim=block_in,
out_dim=block_in,
kernel_size=kernel_size,
use_norm=True)
self.mid.attn_1 = AttnBlock1D(block_in)
self.mid.block_2 = ResnetBlock1D(in_dim=block_in,
out_dim=block_in,
kernel_size=kernel_size,
use_norm=True)
# end
self.conv_out = ops.Conv1d(block_in,
2 * embed_dim if double_z else embed_dim,
kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
self.learnable_gain = nn.Parameter(torch.zeros([]))
def forward(self, x):
# downsampling
h = self.conv_in(x)
for i_level in range(self.num_layers):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](h)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
h = h.clamp(-self.clip_act, self.clip_act)
if i_level in self.down_layers:
h = self.down[i_level].downsample(h)
# middle
h = self.mid.block_1(h)
h = self.mid.attn_1(h)
h = self.mid.block_2(h)
h = h.clamp(-self.clip_act, self.clip_act)
# end
h = nonlinearity(h)
h = self.conv_out(h) * (self.learnable_gain + 1)
return h
class Decoder1D(nn.Module):
def __init__(self,
*,
dim: int,
out_dim: int,
ch_mult: tuple[int] = (1, 2, 4, 8),
num_res_blocks: int,
attn_layers: list[int] = [],
down_layers: list[int] = [],
kernel_size: int = 3,
resamp_with_conv: bool = True,
in_dim: int,
embed_dim: int,
clip_act: float = 256.0):
super().__init__()
self.ch = dim
self.num_layers = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.in_channels = in_dim
self.clip_act = clip_act
self.down_layers = [i + 1 for i in down_layers] # each downlayer add one
# compute in_ch_mult, block_in and curr_res at lowest res
block_in = dim * ch_mult[self.num_layers - 1]
# z to block_in
self.conv_in = ops.Conv1d(embed_dim, block_in, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock1D(in_dim=block_in, out_dim=block_in, use_norm=True)
self.mid.attn_1 = AttnBlock1D(block_in)
self.mid.block_2 = ResnetBlock1D(in_dim=block_in, out_dim=block_in, use_norm=True)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_layers)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = dim * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
block.append(ResnetBlock1D(in_dim=block_in, out_dim=block_out, use_norm=True))
block_in = block_out
if i_level in attn_layers:
attn.append(AttnBlock1D(block_in))
up = nn.Module()
up.block = block
up.attn = attn
if i_level in self.down_layers:
up.upsample = Upsample1D(block_in, resamp_with_conv)
self.up.insert(0, up) # prepend to get consistent order
# end
self.conv_out = ops.Conv1d(block_in, out_dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
self.learnable_gain = nn.Parameter(torch.zeros([]))
def forward(self, z):
# z to block_in
h = self.conv_in(z)
# middle
h = self.mid.block_1(h)
h = self.mid.attn_1(h)
h = self.mid.block_2(h)
h = h.clamp(-self.clip_act, self.clip_act)
# upsampling
for i_level in reversed(range(self.num_layers)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
h = h.clamp(-self.clip_act, self.clip_act)
if i_level in self.down_layers:
h = self.up[i_level].upsample(h)
h = nonlinearity(h)
h = self.conv_out(h) * (self.learnable_gain + 1)
return h
def VAE_16k(**kwargs) -> VAE:
return VAE(data_dim=80, embed_dim=20, hidden_dim=384, **kwargs)
def VAE_44k(**kwargs) -> VAE:
return VAE(data_dim=128, embed_dim=40, hidden_dim=512, **kwargs)
def get_my_vae(name: str, **kwargs) -> VAE:
if name == '16k':
return VAE_16k(**kwargs)
if name == '44k':
return VAE_44k(**kwargs)
raise ValueError(f'Unknown model: {name}')

View File

@@ -1,121 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.diffusionmodules.model import vae_attention
import math
import comfy.ops
ops = comfy.ops.disable_weight_init
def nonlinearity(x):
# swish
return torch.nn.functional.silu(x) / 0.596
def mp_sum(a, b, t=0.5):
return a.lerp(b, t) / math.sqrt((1 - t)**2 + t**2)
def normalize(x, dim=None, eps=1e-4):
if dim is None:
dim = list(range(1, x.ndim))
norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32)
norm = torch.add(eps, norm, alpha=math.sqrt(norm.numel() / x.numel()))
return x / norm.to(x.dtype)
class ResnetBlock1D(nn.Module):
def __init__(self, *, in_dim, out_dim=None, conv_shortcut=False, kernel_size=3, use_norm=True):
super().__init__()
self.in_dim = in_dim
out_dim = in_dim if out_dim is None else out_dim
self.out_dim = out_dim
self.use_conv_shortcut = conv_shortcut
self.use_norm = use_norm
self.conv1 = ops.Conv1d(in_dim, out_dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
self.conv2 = ops.Conv1d(out_dim, out_dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
if self.in_dim != self.out_dim:
if self.use_conv_shortcut:
self.conv_shortcut = ops.Conv1d(in_dim, out_dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
else:
self.nin_shortcut = ops.Conv1d(in_dim, out_dim, kernel_size=1, padding=0, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# pixel norm
if self.use_norm:
x = normalize(x, dim=1)
h = x
h = nonlinearity(h)
h = self.conv1(h)
h = nonlinearity(h)
h = self.conv2(h)
if self.in_dim != self.out_dim:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)
return mp_sum(x, h, t=0.3)
class AttnBlock1D(nn.Module):
def __init__(self, in_channels, num_heads=1):
super().__init__()
self.in_channels = in_channels
self.num_heads = num_heads
self.qkv = ops.Conv1d(in_channels, in_channels * 3, kernel_size=1, padding=0, bias=False)
self.proj_out = ops.Conv1d(in_channels, in_channels, kernel_size=1, padding=0, bias=False)
self.optimized_attention = vae_attention()
def forward(self, x):
h = x
y = self.qkv(h)
y = y.reshape(y.shape[0], -1, 3, y.shape[-1])
q, k, v = normalize(y, dim=1).unbind(2)
h = self.optimized_attention(q, k, v)
h = self.proj_out(h)
return mp_sum(x, h, t=0.3)
class Upsample1D(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
self.conv = ops.Conv1d(in_channels, in_channels, kernel_size=3, padding=1, bias=False)
def forward(self, x):
x = F.interpolate(x, scale_factor=2.0, mode='nearest-exact') # support 3D tensor(B,C,T)
if self.with_conv:
x = self.conv(x)
return x
class Downsample1D(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
# no asymmetric padding in torch conv, must do it ourselves
self.conv1 = ops.Conv1d(in_channels, in_channels, kernel_size=1, padding=0, bias=False)
self.conv2 = ops.Conv1d(in_channels, in_channels, kernel_size=1, padding=0, bias=False)
def forward(self, x):
if self.with_conv:
x = self.conv1(x)
x = F.avg_pool1d(x, kernel_size=2, stride=2)
if self.with_conv:
x = self.conv2(x)
return x

View File

@@ -44,7 +44,7 @@ class QwenImageControlNetModel(QwenImageTransformer2DModel):
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
ids = torch.cat((txt_ids, img_ids), dim=1)
image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
del ids, txt_ids, img_ids
hidden_states = self.img_in(hidden_states) + self.controlnet_x_embedder(hint)

View File

@@ -10,7 +10,6 @@ from comfy.ldm.modules.attention import optimized_attention_masked
from comfy.ldm.flux.layers import EmbedND
import comfy.ldm.common_dit
import comfy.patcher_extension
from comfy.ldm.flux.math import apply_rope1
class GELU(nn.Module):
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None):
@@ -135,34 +134,33 @@ class Attention(nn.Module):
image_rotary_emb: Optional[torch.Tensor] = None,
transformer_options={},
) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size = hidden_states.shape[0]
seq_img = hidden_states.shape[1]
seq_txt = encoder_hidden_states.shape[1]
# Project and reshape to BHND format (batch, heads, seq, dim)
img_query = self.to_q(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
img_key = self.to_k(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
img_value = self.to_v(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2)
img_query = self.to_q(hidden_states).unflatten(-1, (self.heads, -1))
img_key = self.to_k(hidden_states).unflatten(-1, (self.heads, -1))
img_value = self.to_v(hidden_states).unflatten(-1, (self.heads, -1))
txt_query = self.add_q_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous()
txt_key = self.add_k_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous()
txt_value = self.add_v_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2)
txt_query = self.add_q_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
txt_key = self.add_k_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
txt_value = self.add_v_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
img_query = self.norm_q(img_query)
img_key = self.norm_k(img_key)
txt_query = self.norm_added_q(txt_query)
txt_key = self.norm_added_k(txt_key)
joint_query = torch.cat([txt_query, img_query], dim=2)
joint_key = torch.cat([txt_key, img_key], dim=2)
joint_value = torch.cat([txt_value, img_value], dim=2)
joint_query = torch.cat([txt_query, img_query], dim=1)
joint_key = torch.cat([txt_key, img_key], dim=1)
joint_value = torch.cat([txt_value, img_value], dim=1)
joint_query = apply_rope1(joint_query, image_rotary_emb)
joint_key = apply_rope1(joint_key, image_rotary_emb)
joint_query = apply_rotary_emb(joint_query, image_rotary_emb)
joint_key = apply_rotary_emb(joint_key, image_rotary_emb)
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads,
attention_mask, transformer_options=transformer_options,
skip_reshape=True)
joint_query = joint_query.flatten(start_dim=2)
joint_key = joint_key.flatten(start_dim=2)
joint_value = joint_value.flatten(start_dim=2)
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attention_mask, transformer_options=transformer_options)
txt_attn_output = joint_hidden_states[:, :seq_txt, :]
img_attn_output = joint_hidden_states[:, seq_txt:, :]
@@ -236,10 +234,10 @@ class QwenImageTransformerBlock(nn.Module):
img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1)
txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1)
img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1)
del img_mod1
txt_modulated, txt_gate1 = self._modulate(self.txt_norm1(encoder_hidden_states), txt_mod1)
del txt_mod1
img_normed = self.img_norm1(hidden_states)
img_modulated, img_gate1 = self._modulate(img_normed, img_mod1)
txt_normed = self.txt_norm1(encoder_hidden_states)
txt_modulated, txt_gate1 = self._modulate(txt_normed, txt_mod1)
img_attn_output, txt_attn_output = self.attn(
hidden_states=img_modulated,
@@ -248,20 +246,16 @@ class QwenImageTransformerBlock(nn.Module):
image_rotary_emb=image_rotary_emb,
transformer_options=transformer_options,
)
del img_modulated
del txt_modulated
hidden_states = hidden_states + img_gate1 * img_attn_output
encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
del img_attn_output
del txt_attn_output
del img_gate1
del txt_gate1
img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2)
img_normed2 = self.img_norm2(hidden_states)
img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
hidden_states = torch.addcmul(hidden_states, img_gate2, self.img_mlp(img_modulated2))
txt_modulated2, txt_gate2 = self._modulate(self.txt_norm2(encoder_hidden_states), txt_mod2)
txt_normed2 = self.txt_norm2(encoder_hidden_states)
txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2)
encoder_hidden_states = torch.addcmul(encoder_hidden_states, txt_gate2, self.txt_mlp(txt_modulated2))
return encoder_hidden_states, hidden_states
@@ -419,7 +413,7 @@ class QwenImageTransformer2DModel(nn.Module):
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
ids = torch.cat((txt_ids, img_ids), dim=1)
image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
del ids, txt_ids, img_ids
hidden_states = self.img_in(hidden_states)

View File

@@ -232,13 +232,11 @@ class WanAttentionBlock(nn.Module):
# assert e[0].dtype == torch.float32
# self-attention
x = x.contiguous() # otherwise implicit in LayerNorm
y = self.self_attn(
torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
freqs, transformer_options=transformer_options)
x = torch.addcmul(x, y, repeat_e(e[2], x))
del y
# cross-attention & ffn
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)
@@ -589,7 +587,7 @@ class WanModel(torch.nn.Module):
x = self.unpatchify(x, grid_sizes)
return x
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}):
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None):
patch_size = self.patch_size
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
@@ -602,22 +600,10 @@ class WanModel(torch.nn.Module):
if steps_w is None:
steps_w = w_len
h_start = 0
w_start = 0
rope_options = transformer_options.get("rope_options", None)
if rope_options is not None:
t_len = (t_len - 1.0) * rope_options.get("scale_t", 1.0) + 1.0
h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
t_start += rope_options.get("shift_t", 0.0)
h_start += rope_options.get("shift_y", 0.0)
w_start += rope_options.get("shift_x", 0.0)
img_ids = torch.zeros((steps_t, steps_h, steps_w, 3), device=device, dtype=dtype)
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (t_len - 1), steps=steps_t, device=device, dtype=dtype).reshape(-1, 1, 1)
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_start, h_start + (h_len - 1), steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1)
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_start, w_start + (w_len - 1), steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1)
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1)
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1)
img_ids = img_ids.reshape(1, -1, img_ids.shape[-1])
freqs = self.rope_embedder(img_ids).movedim(1, 2)
@@ -643,7 +629,7 @@ class WanModel(torch.nn.Module):
if self.ref_conv is not None and "reference_latent" in kwargs:
t_len += 1
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options)
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype)
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, **kwargs)[:, :, :t, :h, :w]
def unpatchify(self, x, grid_sizes):
@@ -916,7 +902,7 @@ class MotionEncoder_tc(nn.Module):
def __init__(self,
in_dim: int,
hidden_dim: int,
num_heads: int,
num_heads=int,
need_global=True,
dtype=None,
device=None,

View File

@@ -468,46 +468,55 @@ class WanVAE(nn.Module):
attn_scales, self.temperal_upsample, dropout)
def encode(self, x):
conv_idx = [0]
feat_map = [None] * count_conv3d(self.decoder)
self.clear_cache()
## cache
t = x.shape[2]
iter_ = 1 + (t - 1) // 4
## 对encode输入的x按时间拆分为1、4、4、4....
for i in range(iter_):
conv_idx = [0]
self._enc_conv_idx = [0]
if i == 0:
out = self.encoder(
x[:, :, :1, :, :],
feat_cache=feat_map,
feat_idx=conv_idx)
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx)
else:
out_ = self.encoder(
x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
feat_cache=feat_map,
feat_idx=conv_idx)
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx)
out = torch.cat([out, out_], 2)
mu, log_var = self.conv1(out).chunk(2, dim=1)
self.clear_cache()
return mu
def decode(self, z):
conv_idx = [0]
feat_map = [None] * count_conv3d(self.decoder)
self.clear_cache()
# z: [b,c,t,h,w]
iter_ = z.shape[2]
x = self.conv2(z)
for i in range(iter_):
conv_idx = [0]
self._conv_idx = [0]
if i == 0:
out = self.decoder(
x[:, :, i:i + 1, :, :],
feat_cache=feat_map,
feat_idx=conv_idx)
feat_cache=self._feat_map,
feat_idx=self._conv_idx)
else:
out_ = self.decoder(
x[:, :, i:i + 1, :, :],
feat_cache=feat_map,
feat_idx=conv_idx)
feat_cache=self._feat_map,
feat_idx=self._conv_idx)
out = torch.cat([out, out_], 2)
self.clear_cache()
return out
def clear_cache(self):
self._conv_num = count_conv3d(self.decoder)
self._conv_idx = [0]
self._feat_map = [None] * self._conv_num
#cache encode
self._enc_conv_num = count_conv3d(self.encoder)
self._enc_conv_idx = [0]
self._enc_feat_map = [None] * self._enc_conv_num

View File

@@ -657,51 +657,51 @@ class WanVAE(nn.Module):
)
def encode(self, x):
conv_idx = [0]
feat_map = [None] * count_conv3d(self.encoder)
self.clear_cache()
x = patchify(x, patch_size=2)
t = x.shape[2]
iter_ = 1 + (t - 1) // 4
for i in range(iter_):
conv_idx = [0]
self._enc_conv_idx = [0]
if i == 0:
out = self.encoder(
x[:, :, :1, :, :],
feat_cache=feat_map,
feat_idx=conv_idx,
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx,
)
else:
out_ = self.encoder(
x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
feat_cache=feat_map,
feat_idx=conv_idx,
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx,
)
out = torch.cat([out, out_], 2)
mu, log_var = self.conv1(out).chunk(2, dim=1)
self.clear_cache()
return mu
def decode(self, z):
conv_idx = [0]
feat_map = [None] * count_conv3d(self.decoder)
self.clear_cache()
iter_ = z.shape[2]
x = self.conv2(z)
for i in range(iter_):
conv_idx = [0]
self._conv_idx = [0]
if i == 0:
out = self.decoder(
x[:, :, i:i + 1, :, :],
feat_cache=feat_map,
feat_idx=conv_idx,
feat_cache=self._feat_map,
feat_idx=self._conv_idx,
first_chunk=True,
)
else:
out_ = self.decoder(
x[:, :, i:i + 1, :, :],
feat_cache=feat_map,
feat_idx=conv_idx,
feat_cache=self._feat_map,
feat_idx=self._conv_idx,
)
out = torch.cat([out, out_], 2)
out = unpatchify(out, patch_size=2)
self.clear_cache()
return out
def reparameterize(self, mu, log_var):
@@ -715,3 +715,12 @@ class WanVAE(nn.Module):
return mu
std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
return mu + std * torch.randn_like(std)
def clear_cache(self):
self._conv_num = count_conv3d(self.decoder)
self._conv_idx = [0]
self._feat_map = [None] * self._conv_num
# cache encode
self._enc_conv_num = count_conv3d(self.encoder)
self._enc_conv_idx = [0]
self._enc_feat_map = [None] * self._enc_conv_num

View File

@@ -134,11 +134,10 @@ class BaseModel(torch.nn.Module):
if not unet_config.get("disable_unet_model_creation", False):
if model_config.custom_operations is None:
fp8 = model_config.optimizations.get("fp8", False)
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8, model_config=model_config)
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8)
else:
operations = model_config.custom_operations
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
self.diffusion_model.eval()
if comfy.model_management.force_channels_last():
self.diffusion_model.to(memory_format=torch.channels_last)
logging.debug("using channels last mode for diffusion model")
@@ -197,14 +196,8 @@ class BaseModel(torch.nn.Module):
extra_conds[o] = extra
t = self.process_timestep(t, x=x, **extra_conds)
if "latent_shapes" in extra_conds:
xc = utils.unpack_latents(xc, extra_conds.pop("latent_shapes"))
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds)
if len(model_output) > 1 and not torch.is_tensor(model_output):
model_output, _ = utils.pack_latents(model_output)
return self.model_sampling.calculate_denoised(sigma, model_output.float(), x)
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
return self.model_sampling.calculate_denoised(sigma, model_output, x)
def process_timestep(self, timestep, **kwargs):
return timestep
@@ -333,14 +326,6 @@ class BaseModel(torch.nn.Module):
if self.model_config.scaled_fp8 is not None:
unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8)
# Save mixed precision metadata
if hasattr(self.model_config, 'layer_quant_config') and self.model_config.layer_quant_config:
metadata = {
"format_version": "1.0",
"layers": self.model_config.layer_quant_config
}
unet_state_dict["_quantization_metadata"] = metadata
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
if self.model_type == ModelType.V_PREDICTION:
@@ -684,6 +669,7 @@ class Lotus(BaseModel):
class StableCascade_C(BaseModel):
def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
super().__init__(model_config, model_type, device=device, unet_model=StageC)
self.diffusion_model.eval().requires_grad_(False)
def extra_conds(self, **kwargs):
out = {}
@@ -712,6 +698,7 @@ class StableCascade_C(BaseModel):
class StableCascade_B(BaseModel):
def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
super().__init__(model_config, model_type, device=device, unet_model=StageB)
self.diffusion_model.eval().requires_grad_(False)
def extra_conds(self, **kwargs):
out = {}
@@ -1536,94 +1523,3 @@ class HunyuanImage21Refiner(HunyuanImage21):
out = super().extra_conds(**kwargs)
out['disable_time_r'] = comfy.conds.CONDConstant(True)
return out
class HunyuanVideo15(HunyuanVideo):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device)
def concat_cond(self, **kwargs):
noise = kwargs.get("noise", None)
extra_channels = self.diffusion_model.img_in.proj.weight.shape[1] - noise.shape[1] - 1 #noise 32 img cond 32 + mask 1
if extra_channels == 0:
return None
image = kwargs.get("concat_latent_image", None)
device = kwargs["device"]
if image is None:
shape_image = list(noise.shape)
shape_image[1] = extra_channels
image = torch.zeros(shape_image, dtype=noise.dtype, layout=noise.layout, device=noise.device)
else:
latent_dim = self.latent_format.latent_channels
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
for i in range(0, image.shape[1], latent_dim):
image[:, i: i + latent_dim] = self.process_latent_in(image[:, i: i + latent_dim])
image = utils.resize_to_batch_size(image, noise.shape[0])
mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
if mask is None:
mask = torch.zeros_like(noise)[:, :1]
else:
mask = 1.0 - mask
mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
if mask.shape[-3] < noise.shape[-3]:
mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0)
mask = utils.resize_to_batch_size(mask, noise.shape[0])
return torch.cat((image, mask), dim=1)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None:
if torch.numel(attention_mask) != attention_mask.sum():
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
conditioning_byt5small = kwargs.get("conditioning_byt5small", None)
if conditioning_byt5small is not None:
out['txt_byt5'] = comfy.conds.CONDRegular(conditioning_byt5small)
guidance = kwargs.get("guidance", 6.0)
if guidance is not None:
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
clip_vision_output = kwargs.get("clip_vision_output", None)
if clip_vision_output is not None:
out['clip_fea'] = comfy.conds.CONDRegular(clip_vision_output.last_hidden_state)
return out
class HunyuanVideo15_SR_Distilled(HunyuanVideo15):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device)
def concat_cond(self, **kwargs):
noise = kwargs.get("noise", None)
image = kwargs.get("concat_latent_image", None)
noise_augmentation = kwargs.get("noise_augmentation", 0.0)
device = kwargs["device"]
if image is None:
image = torch.zeros([noise.shape[0], noise.shape[1] * 2 + 2, noise.shape[-3], noise.shape[-2], noise.shape[-1]], device=comfy.model_management.intermediate_device())
else:
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
#image = self.process_latent_in(image) # scaling wasn't applied in reference code
image = utils.resize_to_batch_size(image, noise.shape[0])
lq_image_slice = slice(noise.shape[1] + 1, 2 * noise.shape[1] + 1)
if noise_augmentation > 0:
generator = torch.Generator(device="cpu")
generator.manual_seed(kwargs.get("seed", 0) - 10)
noise = torch.randn(image[:, lq_image_slice].shape, generator=generator, dtype=image.dtype, device="cpu").to(image.device)
image[:, lq_image_slice] = noise_augmentation * noise + min(1.0 - noise_augmentation, 0.75) * image[:, lq_image_slice]
else:
image[:, lq_image_slice] = 0.75 * image[:, lq_image_slice]
return image
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
out['disable_time_r'] = comfy.conds.CONDConstant(False)
return out

View File

@@ -6,20 +6,6 @@ import math
import logging
import torch
def detect_layer_quantization(metadata):
quant_key = "_quantization_metadata"
if metadata is not None and quant_key in metadata:
quant_metadata = metadata.pop(quant_key)
quant_metadata = json.loads(quant_metadata)
if isinstance(quant_metadata, dict) and "layers" in quant_metadata:
logging.info(f"Found quantization metadata (version {quant_metadata.get('format_version', 'unknown')})")
return quant_metadata["layers"]
else:
raise ValueError("Invalid quantization metadata format")
return None
def count_blocks(state_dict_keys, prefix_string):
count = 0
while True:
@@ -186,16 +172,6 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
guidance_keys = list(filter(lambda a: a.startswith("{}guidance_in.".format(key_prefix)), state_dict_keys))
dit_config["guidance_embed"] = len(guidance_keys) > 0
# HunyuanVideo 1.5
if '{}cond_type_embedding.weight'.format(key_prefix) in state_dict_keys:
dit_config["use_cond_type_embedding"] = True
else:
dit_config["use_cond_type_embedding"] = False
if '{}vision_in.proj.0.weight'.format(key_prefix) in state_dict_keys:
dit_config["vision_in_dim"] = state_dict['{}vision_in.proj.0.weight'.format(key_prefix)].shape[0]
else:
dit_config["vision_in_dim"] = None
return dit_config
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
@@ -237,7 +213,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["nerf_mlp_ratio"] = 4
dit_config["nerf_depth"] = 4
dit_config["nerf_max_freqs"] = 8
dit_config["nerf_tile_size"] = 512
dit_config["nerf_tile_size"] = 32
dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
dit_config["nerf_embedder_dtype"] = torch.float32
else:
@@ -389,8 +365,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["patch_size"] = 2
dit_config["in_channels"] = 16
dit_config["dim"] = 2304
dit_config["cap_feat_dim"] = state_dict['{}cap_embedder.1.weight'.format(key_prefix)].shape[1]
dit_config["n_layers"] = count_blocks(state_dict_keys, '{}layers.'.format(key_prefix) + '{}.')
dit_config["cap_feat_dim"] = 2304
dit_config["n_layers"] = 26
dit_config["n_heads"] = 24
dit_config["n_kv_heads"] = 8
dit_config["qk_norm"] = True
@@ -725,12 +701,6 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
else:
model_config.optimizations["fp8"] = True
# Detect per-layer quantization (mixed precision)
layer_quant_config = detect_layer_quantization(metadata)
if layer_quant_config:
model_config.layer_quant_config = layer_quant_config
logging.info(f"Detected mixed precision quantization: {len(layer_quant_config)} layers quantized")
return model_config
def unet_prefix_from_state_dict(state_dict):

View File

@@ -89,7 +89,6 @@ if args.deterministic:
directml_enabled = False
if args.directml is not None:
logging.warning("WARNING: torch-directml barely works, is very slow, has not been updated in over 1 year and might be removed soon, please don't use it, there are better options.")
import torch_directml
directml_enabled = True
device_index = args.directml
@@ -331,21 +330,13 @@ except:
SUPPORT_FP8_OPS = args.supports_fp8_compute
AMD_RDNA2_AND_OLDER_ARCH = ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]
try:
if is_amd():
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
if not (any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH)):
torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD
logging.info("Set: torch.backends.cudnn.enabled = False for better AMD performance.")
try:
rocm_version = tuple(map(int, str(torch.version.hip).split(".")[:2]))
except:
rocm_version = (6, -1)
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
logging.info("AMD arch: {}".format(arch))
logging.info("ROCm version: {}".format(rocm_version))
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
@@ -353,11 +344,11 @@ try:
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
ENABLE_PYTORCH_ATTENTION = True
if rocm_version >= (7, 0):
if any((a in arch) for a in ["gfx1201"]):
ENABLE_PYTORCH_ATTENTION = True
# if torch_version_numeric >= (2, 8):
# if any((a in arch) for a in ["gfx1201"]):
# ENABLE_PYTORCH_ATTENTION = True
if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4):
if any((a in arch) for a in ["gfx1200", "gfx1201", "gfx950"]): # TODO: more arches, "gfx942" gives error on pytorch nightly 2.10 1013 rocm7.0
if any((a in arch) for a in ["gfx1200", "gfx1201", "gfx942", "gfx950"]): # TODO: more arches
SUPPORT_FP8_OPS = True
except:
@@ -379,9 +370,6 @@ try:
except:
pass
if torch.cuda.is_available() and torch.backends.cudnn.is_available() and PerformanceFeature.AutoTune in args.fast:
torch.backends.cudnn.benchmark = True
try:
if torch_version_numeric >= (2, 5):
torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)
@@ -504,7 +492,6 @@ class LoadedModel:
if use_more_vram == 0:
use_more_vram = 1e32
self.model_use_more_vram(use_more_vram, force_patch_weights=force_patch_weights)
real_model = self.model.model
if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and real_model is not None:
@@ -690,10 +677,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
current_free_mem = get_free_memory(torch_dev) + loaded_memory
lowvram_model_memory = max(128 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
lowvram_model_memory = lowvram_model_memory - loaded_memory
if lowvram_model_memory == 0:
lowvram_model_memory = 0.1
lowvram_model_memory = max(0.1, lowvram_model_memory - loaded_memory)
if vram_set_state == VRAMState.NO_VRAM:
lowvram_model_memory = 0.1
@@ -941,7 +925,11 @@ def vae_dtype(device=None, allowed_dtypes=[]):
if d == torch.float16 and should_use_fp16(device):
return d
if d == torch.bfloat16 and should_use_bf16(device):
# NOTE: bfloat16 seems to work on AMD for the VAE but is extremely slow in some cases compared to fp32
# slowness still a problem on pytorch nightly 2.9.0.dev20250720+rocm6.4 tested on RDNA3
# also a problem on RDNA4 except fp32 is also slow there.
# This is due to large bf16 convolutions being extremely slow.
if d == torch.bfloat16 and ((not is_amd()) or amd_min_version(device, min_rdna_version=4)) and should_use_bf16(device):
return d
return torch.float32
@@ -1003,6 +991,12 @@ def device_supports_non_blocking(device):
return False
return True
def device_should_use_non_blocking(device):
if not device_supports_non_blocking(device):
return False
return False
# return True #TODO: figure out why this causes memory issues on Nvidia and possibly others
def force_channels_last():
if args.force_channels_last:
return True
@@ -1017,16 +1011,6 @@ if args.async_offload:
NUM_STREAMS = 2
logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS))
def current_stream(device):
if device is None:
return None
if is_device_cuda(device):
return torch.cuda.current_stream()
elif is_device_xpu(device):
return torch.xpu.current_stream()
else:
return None
stream_counters = {}
def get_offload_stream(device):
stream_counter = stream_counters.get(device, 0)
@@ -1035,17 +1019,21 @@ def get_offload_stream(device):
if device in STREAMS:
ss = STREAMS[device]
#Sync the oldest stream in the queue with the current
ss[stream_counter].wait_stream(current_stream(device))
s = ss[stream_counter]
stream_counter = (stream_counter + 1) % len(ss)
if is_device_cuda(device):
ss[stream_counter].wait_stream(torch.cuda.current_stream())
elif is_device_xpu(device):
ss[stream_counter].wait_stream(torch.xpu.current_stream())
stream_counters[device] = stream_counter
return ss[stream_counter]
return s
elif is_device_cuda(device):
ss = []
for k in range(NUM_STREAMS):
ss.append(torch.cuda.Stream(device=device, priority=0))
STREAMS[device] = ss
s = ss[stream_counter]
stream_counter = (stream_counter + 1) % len(ss)
stream_counters[device] = stream_counter
return s
elif is_device_xpu(device):
@@ -1054,14 +1042,18 @@ def get_offload_stream(device):
ss.append(torch.xpu.Stream(device=device, priority=0))
STREAMS[device] = ss
s = ss[stream_counter]
stream_counter = (stream_counter + 1) % len(ss)
stream_counters[device] = stream_counter
return s
return None
def sync_stream(device, stream):
if stream is None or current_stream(device) is None:
if stream is None:
return
current_stream(device).wait_stream(stream)
if is_device_cuda(device):
torch.cuda.current_stream().wait_stream(stream)
elif is_device_xpu(device):
torch.xpu.current_stream().wait_stream(stream)
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None):
if device is None or weight.device == device:
@@ -1086,79 +1078,6 @@ def cast_to_device(tensor, device, dtype, copy=False):
non_blocking = device_supports_non_blocking(device)
return cast_to(tensor, dtype=dtype, device=device, non_blocking=non_blocking, copy=copy)
PINNED_MEMORY = {}
TOTAL_PINNED_MEMORY = 0
MAX_PINNED_MEMORY = -1
if not args.disable_pinned_memory:
if is_nvidia() or is_amd():
if WINDOWS:
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.45 # Windows limit is apparently 50%
else:
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95
logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024)))
def pin_memory(tensor):
global TOTAL_PINNED_MEMORY
if MAX_PINNED_MEMORY <= 0:
return False
if type(tensor) is not torch.nn.parameter.Parameter:
return False
if not is_device_cpu(tensor.device):
return False
if tensor.is_pinned():
#NOTE: Cuda does detect when a tensor is already pinned and would
#error below, but there are proven cases where this also queues an error
#on the GPU async. So dont trust the CUDA API and guard here
return False
if not tensor.is_contiguous():
return False
size = tensor.numel() * tensor.element_size()
if (TOTAL_PINNED_MEMORY + size) > MAX_PINNED_MEMORY:
return False
ptr = tensor.data_ptr()
if torch.cuda.cudart().cudaHostRegister(ptr, size, 1) == 0:
PINNED_MEMORY[ptr] = size
TOTAL_PINNED_MEMORY += size
return True
return False
def unpin_memory(tensor):
global TOTAL_PINNED_MEMORY
if MAX_PINNED_MEMORY <= 0:
return False
if not is_device_cpu(tensor.device):
return False
ptr = tensor.data_ptr()
size = tensor.numel() * tensor.element_size()
size_stored = PINNED_MEMORY.get(ptr, None)
if size_stored is None:
logging.warning("Tried to unpin tensor not pinned by ComfyUI")
return False
if size != size_stored:
logging.warning("Size of pinned tensor changed")
return False
if torch.cuda.cudart().cudaHostUnregister(ptr) == 0:
TOTAL_PINNED_MEMORY -= PINNED_MEMORY.pop(ptr)
if len(PINNED_MEMORY) == 0:
TOTAL_PINNED_MEMORY = 0
return True
return False
def sage_attention_enabled():
return args.use_sage_attention
@@ -1411,7 +1330,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
if is_amd():
arch = torch.cuda.get_device_properties(device).gcnArchName
if any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH): # RDNA2 and older don't support bf16
if any((a in arch) for a in ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]): # RDNA2 and older don't support bf16
if manual_cast:
return True
return False

View File

@@ -123,30 +123,16 @@ def move_weight_functions(m, device):
return memory
class LowVramPatch:
def __init__(self, key, patches, convert_func=None, set_func=None):
def __init__(self, key, patches):
self.key = key
self.patches = patches
self.convert_func = convert_func
self.set_func = set_func
def __call__(self, weight):
intermediate_dtype = weight.dtype
if self.convert_func is not None:
weight = self.convert_func(weight.to(dtype=torch.float32, copy=True), inplace=True)
if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops
intermediate_dtype = torch.float32
out = comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype)
if self.set_func is None:
return comfy.float.stochastic_rounding(out, weight.dtype, seed=string_to_seed(self.key))
else:
return self.set_func(out, seed=string_to_seed(self.key), return_weight=True)
return comfy.float.stochastic_rounding(comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype), weight.dtype, seed=string_to_seed(self.key))
out = comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
if self.set_func is not None:
return self.set_func(out, seed=string_to_seed(self.key), return_weight=True).to(dtype=intermediate_dtype)
else:
return out
return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
def get_key_weight(model, key):
set_func = None
@@ -238,7 +224,6 @@ class ModelPatcher:
self.force_cast_weights = False
self.patches_uuid = uuid.uuid4()
self.parent = None
self.pinned = set()
self.attachments: dict[str] = {}
self.additional_models: dict[str, list[ModelPatcher]] = {}
@@ -276,9 +261,6 @@ class ModelPatcher:
self.size = comfy.model_management.module_size(self.model)
return self.size
def get_ram_usage(self):
return self.model_size()
def loaded_size(self):
return self.model.model_loaded_weight_memory
@@ -298,7 +280,6 @@ class ModelPatcher:
n.backup = self.backup
n.object_patches_backup = self.object_patches_backup
n.parent = self
n.pinned = self.pinned
n.force_cast_weights = self.force_cast_weights
@@ -455,19 +436,6 @@ class ModelPatcher:
def set_model_post_input_patch(self, patch):
self.set_model_patch(patch, "post_input")
def set_model_rope_options(self, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t, **kwargs):
rope_options = self.model_options["transformer_options"].get("rope_options", {})
rope_options["scale_x"] = scale_x
rope_options["scale_y"] = scale_y
rope_options["scale_t"] = scale_t
rope_options["shift_x"] = shift_x
rope_options["shift_y"] = shift_y
rope_options["shift_t"] = shift_t
self.model_options["transformer_options"]["rope_options"] = rope_options
def add_object_patch(self, name, obj):
self.object_patches[name] = obj
@@ -636,21 +604,6 @@ class ModelPatcher:
else:
set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
def pin_weight_to_device(self, key):
weight, set_func, convert_func = get_key_weight(self.model, key)
if comfy.model_management.pin_memory(weight):
self.pinned.add(key)
def unpin_weight(self, key):
if key in self.pinned:
weight, set_func, convert_func = get_key_weight(self.model, key)
comfy.model_management.unpin_memory(weight)
self.pinned.remove(key)
def unpin_all_weights(self):
for key in list(self.pinned):
self.unpin_weight(key)
def _load_list(self):
loading = []
for n, m in self.model.named_modules():
@@ -672,11 +625,9 @@ class ModelPatcher:
mem_counter = 0
patch_counter = 0
lowvram_counter = 0
lowvram_mem_counter = 0
loading = self._load_list()
load_completely = []
offloaded = []
loading.sort(reverse=True)
for x in loading:
n = x[1]
@@ -693,7 +644,6 @@ class ModelPatcher:
if mem_counter + module_mem >= lowvram_model_memory:
lowvram_weight = True
lowvram_counter += 1
lowvram_mem_counter += module_mem
if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed
continue
@@ -707,19 +657,16 @@ class ModelPatcher:
if force_patch_weights:
self.patch_weight_to_device(weight_key)
else:
_, set_func, convert_func = get_key_weight(self.model, weight_key)
m.weight_function = [LowVramPatch(weight_key, self.patches, convert_func, set_func)]
m.weight_function = [LowVramPatch(weight_key, self.patches)]
patch_counter += 1
if bias_key in self.patches:
if force_patch_weights:
self.patch_weight_to_device(bias_key)
else:
_, set_func, convert_func = get_key_weight(self.model, bias_key)
m.bias_function = [LowVramPatch(bias_key, self.patches, convert_func, set_func)]
m.bias_function = [LowVramPatch(bias_key, self.patches)]
patch_counter += 1
cast_weight = True
offloaded.append((module_mem, n, m, params))
else:
if hasattr(m, "comfy_cast_weights"):
wipe_lowvram_weight(m)
@@ -750,9 +697,7 @@ class ModelPatcher:
continue
for param in params:
key = "{}.{}".format(n, param)
self.unpin_weight(key)
self.patch_weight_to_device(key, device_to=device_to)
self.patch_weight_to_device("{}.{}".format(n, param), device_to=device_to)
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
m.comfy_patched_weights = True
@@ -760,17 +705,11 @@ class ModelPatcher:
for x in load_completely:
x[2].to(device_to)
for x in offloaded:
n = x[1]
params = x[3]
for param in params:
self.pin_weight_to_device("{}.{}".format(n, param))
if lowvram_counter > 0:
logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), patch_counter))
logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
self.model.model_lowvram = True
else:
logging.info("loaded completely; {:.2f} MB usable, {:.2f} MB loaded, full load: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
self.model.model_lowvram = False
if full_load:
self.model.to(device_to)
@@ -807,7 +746,6 @@ class ModelPatcher:
self.eject_model()
if unpatch_weights:
self.unpatch_hooks()
self.unpin_all_weights()
if self.model.model_lowvram:
for m in self.model.modules():
move_weight_functions(m, device_to)
@@ -843,7 +781,7 @@ class ModelPatcher:
self.object_patches_backup.clear()
def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=False):
def partially_unload(self, device_to, memory_to_free=0):
with self.use_ejected():
hooks_unpatched = False
memory_freed = 0
@@ -887,19 +825,11 @@ class ModelPatcher:
module_mem += move_weight_functions(m, device_to)
if lowvram_possible:
if weight_key in self.patches:
if force_patch_weights:
self.patch_weight_to_device(weight_key)
else:
_, set_func, convert_func = get_key_weight(self.model, weight_key)
m.weight_function.append(LowVramPatch(weight_key, self.patches, convert_func, set_func))
patch_counter += 1
m.weight_function.append(LowVramPatch(weight_key, self.patches))
patch_counter += 1
if bias_key in self.patches:
if force_patch_weights:
self.patch_weight_to_device(bias_key)
else:
_, set_func, convert_func = get_key_weight(self.model, bias_key)
m.bias_function.append(LowVramPatch(bias_key, self.patches, convert_func, set_func))
patch_counter += 1
m.bias_function.append(LowVramPatch(bias_key, self.patches))
patch_counter += 1
cast_weight = True
if cast_weight:
@@ -909,13 +839,9 @@ class ModelPatcher:
memory_freed += module_mem
logging.debug("freed {}".format(n))
for param in params:
self.pin_weight_to_device("{}.{}".format(n, param))
self.model.model_lowvram = True
self.model.lowvram_patch_counter += patch_counter
self.model.model_loaded_weight_memory -= memory_freed
logging.info("loaded partially: {:.2f} MB loaded, lowvram patches: {}".format(self.model.model_loaded_weight_memory / (1024 * 1024), self.model.lowvram_patch_counter))
return memory_freed
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
@@ -928,9 +854,6 @@ class ModelPatcher:
extra_memory += (used - self.model.model_loaded_weight_memory)
self.patch_model(load_weights=False)
if extra_memory < 0 and not unpatch_weights:
self.partially_unload(self.offload_device, -extra_memory, force_patch_weights=force_patch_weights)
return 0
full_load = False
if self.model.model_lowvram == False and self.model.model_loaded_weight_memory > 0:
self.apply_hooks(self.forced_hooks, force_apply=True)
@@ -1318,6 +1241,5 @@ class ModelPatcher:
self.clear_cached_hook_weights()
def __del__(self):
self.unpin_all_weights()
self.detach(unpatch_all=False)

View File

@@ -21,23 +21,17 @@ def rescale_zero_terminal_snr_sigmas(sigmas):
alphas_bar[-1] = 4.8973451890853435e-08
return ((1 - alphas_bar) / alphas_bar) ** 0.5
def reshape_sigma(sigma, noise_dim):
if sigma.nelement() == 1:
return sigma.view(())
else:
return sigma.view(sigma.shape[:1] + (1,) * (noise_dim - 1))
class EPS:
def calculate_input(self, sigma, noise):
sigma = reshape_sigma(sigma, noise.ndim)
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
return noise / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
def calculate_denoised(self, sigma, model_output, model_input):
sigma = reshape_sigma(sigma, model_output.ndim)
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
return model_input - model_output * sigma
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
sigma = reshape_sigma(sigma, noise.ndim)
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
if max_denoise:
noise = noise * torch.sqrt(1.0 + sigma ** 2.0)
else:
@@ -51,12 +45,12 @@ class EPS:
class V_PREDICTION(EPS):
def calculate_denoised(self, sigma, model_output, model_input):
sigma = reshape_sigma(sigma, model_output.ndim)
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
class EDM(V_PREDICTION):
def calculate_denoised(self, sigma, model_output, model_input):
sigma = reshape_sigma(sigma, model_output.ndim)
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) + model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
class CONST:
@@ -64,15 +58,15 @@ class CONST:
return noise
def calculate_denoised(self, sigma, model_output, model_input):
sigma = reshape_sigma(sigma, model_output.ndim)
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
return model_input - model_output * sigma
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
sigma = reshape_sigma(sigma, noise.ndim)
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
return sigma * noise + (1.0 - sigma) * latent_image
def inverse_noise_scaling(self, sigma, latent):
sigma = reshape_sigma(sigma, latent.ndim)
sigma = sigma.view(sigma.shape[:1] + (1,) * (latent.ndim - 1))
return latent / (1.0 - sigma)
class X0(EPS):
@@ -86,16 +80,16 @@ class IMG_TO_IMG(X0):
class COSMOS_RFLOW:
def calculate_input(self, sigma, noise):
sigma = (sigma / (sigma + 1))
sigma = reshape_sigma(sigma, noise.ndim)
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
return noise * (1.0 - sigma)
def calculate_denoised(self, sigma, model_output, model_input):
sigma = (sigma / (sigma + 1))
sigma = reshape_sigma(sigma, model_output.ndim)
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
return model_input * (1.0 - sigma) - model_output * sigma
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
sigma = reshape_sigma(sigma, noise.ndim)
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
noise = noise * sigma
noise += latent_image
return noise

View File

@@ -1,91 +0,0 @@
import torch
class NestedTensor:
def __init__(self, tensors):
self.tensors = list(tensors)
self.is_nested = True
def _copy(self):
return NestedTensor(self.tensors)
def apply_operation(self, other, operation):
o = self._copy()
if isinstance(other, NestedTensor):
for i, t in enumerate(o.tensors):
o.tensors[i] = operation(t, other.tensors[i])
else:
for i, t in enumerate(o.tensors):
o.tensors[i] = operation(t, other)
return o
def __add__(self, b):
return self.apply_operation(b, lambda x, y: x + y)
def __sub__(self, b):
return self.apply_operation(b, lambda x, y: x - y)
def __mul__(self, b):
return self.apply_operation(b, lambda x, y: x * y)
# def __itruediv__(self, b):
# return self.apply_operation(b, lambda x, y: x / y)
def __truediv__(self, b):
return self.apply_operation(b, lambda x, y: x / y)
def __getitem__(self, *args, **kwargs):
return self.apply_operation(None, lambda x, y: x.__getitem__(*args, **kwargs))
def unbind(self):
return self.tensors
def to(self, *args, **kwargs):
o = self._copy()
for i, t in enumerate(o.tensors):
o.tensors[i] = t.to(*args, **kwargs)
return o
def new_ones(self, *args, **kwargs):
return self.tensors[0].new_ones(*args, **kwargs)
def float(self):
return self.to(dtype=torch.float)
def chunk(self, *args, **kwargs):
return self.apply_operation(None, lambda x, y: x.chunk(*args, **kwargs))
def size(self):
return self.tensors[0].size()
@property
def shape(self):
return self.tensors[0].shape
@property
def ndim(self):
dims = 0
for t in self.tensors:
dims = max(t.ndim, dims)
return dims
@property
def device(self):
return self.tensors[0].device
@property
def dtype(self):
return self.tensors[0].dtype
@property
def layout(self):
return self.tensors[0].layout
def cat_nested(tensors, *args, **kwargs):
cated_tensors = []
for i in range(len(tensors[0].tensors)):
tens = []
for j in range(len(tensors)):
tens.append(tensors[j].tensors[i])
cated_tensors.append(torch.cat(tens, *args, **kwargs))
return NestedTensor(cated_tensors)

View File

@@ -24,18 +24,13 @@ import comfy.float
import comfy.rmsnorm
import contextlib
def run_every_op():
if torch.compiler.is_compiling():
return
comfy.model_management.throw_exception_if_processing_interrupted()
def scaled_dot_product_attention(q, k, v, *args, **kwargs):
return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
try:
if torch.cuda.is_available() and comfy.model_management.WINDOWS:
if torch.cuda.is_available():
from torch.nn.attention import SDPBackend, sdpa_kernel
import inspect
if "set_priority" in inspect.signature(sdpa_kernel).parameters:
@@ -55,90 +50,49 @@ try:
except (ModuleNotFoundError, TypeError):
logging.warning("Could not set sdpa backend priority.")
NVIDIA_MEMORY_CONV_BUG_WORKAROUND = False
try:
if comfy.model_management.is_nvidia():
cudnn_version = torch.backends.cudnn.version()
if (cudnn_version >= 91002 and cudnn_version < 91500) and comfy.model_management.torch_version_numeric >= (2, 9) and comfy.model_management.torch_version_numeric <= (2, 10):
#TODO: change upper bound version once it's fixed'
NVIDIA_MEMORY_CONV_BUG_WORKAROUND = True
logging.info("working around nvidia conv3d memory bug.")
except:
pass
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
if torch.cuda.is_available() and torch.backends.cudnn.is_available() and PerformanceFeature.AutoTune in args.fast:
torch.backends.cudnn.benchmark = True
def cast_to_input(weight, input, non_blocking=False, copy=True):
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False):
# NOTE: offloadable=False is a a legacy and if you are a custom node author reading this please pass
# offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This
# will add async-offload support to your cast and improve performance.
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
if input is not None:
if dtype is None:
if isinstance(input, QuantizedTensor):
dtype = input._layout_params["orig_dtype"]
else:
dtype = input.dtype
dtype = input.dtype
if bias_dtype is None:
bias_dtype = dtype
if device is None:
device = input.device
if offloadable and (device != s.weight.device or
(s.bias is not None and device != s.bias.device)):
offload_stream = comfy.model_management.get_offload_stream(device)
else:
offload_stream = None
offload_stream = comfy.model_management.get_offload_stream(device)
if offload_stream is not None:
wf_context = offload_stream
else:
wf_context = contextlib.nullcontext()
non_blocking = comfy.model_management.device_supports_non_blocking(device)
weight_has_function = len(s.weight_function) > 0
bias_has_function = len(s.bias_function) > 0
weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream)
bias = None
non_blocking = comfy.model_management.device_supports_non_blocking(device)
if s.bias is not None:
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream)
has_function = len(s.bias_function) > 0
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)
if bias_has_function:
if has_function:
with wf_context:
for f in s.bias_function:
bias = f(bias)
if weight_has_function or weight.dtype != dtype:
has_function = len(s.weight_function) > 0
weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)
if has_function:
with wf_context:
weight = weight.to(dtype=dtype)
for f in s.weight_function:
weight = f(weight)
comfy.model_management.sync_stream(device, offload_stream)
if offloadable:
return weight, bias, offload_stream
else:
#Legacy function signature
return weight, bias
def uncast_bias_weight(s, weight, bias, offload_stream):
if offload_stream is None:
return
if weight is not None:
device = weight.device
else:
if bias is None:
return
device = bias.device
offload_stream.wait_stream(comfy.model_management.current_stream(device))
return weight, bias
class CastWeightBiasOp:
comfy_cast_weights = False
@@ -151,13 +105,10 @@ class disable_weight_init:
return None
def forward_comfy_cast_weights(self, input):
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = torch.nn.functional.linear(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.linear(input, weight, bias)
def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
@@ -168,13 +119,10 @@ class disable_weight_init:
return None
def forward_comfy_cast_weights(self, input):
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = self._conv_forward(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias)
def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
@@ -185,13 +133,10 @@ class disable_weight_init:
return None
def forward_comfy_cast_weights(self, input):
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = self._conv_forward(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias)
def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
@@ -201,23 +146,11 @@ class disable_weight_init:
def reset_parameters(self):
return None
def _conv_forward(self, input, weight, bias, *args, **kwargs):
if NVIDIA_MEMORY_CONV_BUG_WORKAROUND and weight.dtype in (torch.float16, torch.bfloat16):
out = torch.cudnn_convolution(input, weight, self.padding, self.stride, self.dilation, self.groups, benchmark=False, deterministic=False, allow_tf32=True)
if bias is not None:
out += bias.reshape((1, -1) + (1,) * (out.ndim - 2))
return out
else:
return super()._conv_forward(input, weight, bias, *args, **kwargs)
def forward_comfy_cast_weights(self, input):
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = self._conv_forward(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias)
def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
@@ -228,13 +161,10 @@ class disable_weight_init:
return None
def forward_comfy_cast_weights(self, input):
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
@@ -246,17 +176,13 @@ class disable_weight_init:
def forward_comfy_cast_weights(self, input):
if self.weight is not None:
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
weight, bias = cast_bias_weight(self, input)
else:
weight = None
bias = None
offload_stream = None
x = torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
@@ -269,18 +195,13 @@ class disable_weight_init:
def forward_comfy_cast_weights(self, input):
if self.weight is not None:
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
weight, bias = cast_bias_weight(self, input)
else:
weight = None
bias = None
offload_stream = None
x = comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated
# x = torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
return comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated
# return torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
@@ -296,15 +217,12 @@ class disable_weight_init:
input, output_size, self.stride, self.padding, self.kernel_size,
num_spatial_dims, self.dilation)
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = torch.nn.functional.conv_transpose2d(
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.conv_transpose2d(
input, weight, bias, self.stride, self.padding,
output_padding, self.groups, self.dilation)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
@@ -320,15 +238,12 @@ class disable_weight_init:
input, output_size, self.stride, self.padding, self.kernel_size,
num_spatial_dims, self.dilation)
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = torch.nn.functional.conv_transpose1d(
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.conv_transpose1d(
input, weight, bias, self.stride, self.padding,
output_padding, self.groups, self.dilation)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
@@ -343,14 +258,10 @@ class disable_weight_init:
output_dtype = out_dtype
if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16:
out_dtype = None
weight, bias, offload_stream = cast_bias_weight(self, device=input.device, dtype=out_dtype, offloadable=True)
x = torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
weight, bias = cast_bias_weight(self, device=input.device, dtype=out_dtype)
return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
@@ -401,18 +312,20 @@ class manual_cast(disable_weight_init):
def fp8_linear(self, input):
"""
Legacy FP8 linear function for backward compatibility.
Uses QuantizedTensor subclass for dispatch.
"""
dtype = self.weight.dtype
if dtype not in [torch.float8_e4m3fn]:
return None
input_dtype = input.dtype
tensor_2d = False
if len(input.shape) == 2:
tensor_2d = True
input = input.unsqueeze(1)
if input.ndim == 3 or input.ndim == 2:
w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
input_shape = input.shape
input_dtype = input.dtype
if len(input.shape) == 3:
w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype)
w = w.t()
scale_weight = self.scale_weight
scale_input = self.scale_input
@@ -424,20 +337,23 @@ def fp8_linear(self, input):
if scale_input is None:
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
input = torch.clamp(input, min=-448, max=448, out=input)
layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight)
input = input.reshape(-1, input_shape[2]).to(dtype).contiguous()
else:
scale_input = scale_input.to(input.device)
quantized_input = QuantizedTensor.from_float(input, "TensorCoreFP8Layout", scale=scale_input, dtype=dtype)
input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype).contiguous()
# Wrap weight in QuantizedTensor - this enables unified dispatch
# Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype}
quantized_weight = QuantizedTensor(w, "TensorCoreFP8Layout", layout_params_weight)
o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
if bias is not None:
o = torch._scaled_mm(input, w, out_dtype=input_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
else:
o = torch._scaled_mm(input, w, out_dtype=input_dtype, scale_a=scale_input, scale_b=scale_weight)
uncast_bias_weight(self, w, bias, offload_stream)
return o
if isinstance(o, tuple):
o = o[0]
if tensor_2d:
return o.reshape(input_shape[0], -1)
return o.reshape((-1, input_shape[1], self.weight.shape[0]))
return None
@@ -457,10 +373,8 @@ class fp8_ops(manual_cast):
except Exception as e:
logging.info("Exception during fp8 op: {}".format(e))
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = torch.nn.functional.linear(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.linear(input, weight, bias)
def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input))
@@ -488,14 +402,12 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None
if out is not None:
return out
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
weight, bias = cast_bias_weight(self, input)
if weight.numel() < input.numel(): #TODO: optimize
x = torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
return torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
else:
x = torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
return torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
def convert_weight(self, weight, inplace=False, **kwargs):
if inplace:
@@ -504,10 +416,8 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None
else:
return weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype)
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
def set_weight(self, weight, inplace_update=False, seed=None, **kwargs):
weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
if return_weight:
return weight
if inplace_update:
self.weight.data.copy_(weight)
else:
@@ -534,120 +444,7 @@ if CUBLAS_IS_AVAILABLE:
def forward(self, *args, **kwargs):
return super().forward(*args, **kwargs)
# ==============================================================================
# Mixed Precision Operations
# ==============================================================================
from .quant_ops import QuantizedTensor, QUANT_ALGOS
class MixedPrecisionOps(disable_weight_init):
_layer_quant_config = {}
_compute_dtype = torch.bfloat16
class Linear(torch.nn.Module, CastWeightBiasOp):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
device=None,
dtype=None,
) -> None:
super().__init__()
self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
# self.factory_kwargs = {"device": device, "dtype": dtype}
self.in_features = in_features
self.out_features = out_features
if bias:
self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
else:
self.register_parameter("bias", None)
self.tensor_class = None
def reset_parameters(self):
return None
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys, error_msgs):
device = self.factory_kwargs["device"]
layer_name = prefix.rstrip('.')
weight_key = f"{prefix}weight"
weight = state_dict.pop(weight_key, None)
if weight is None:
raise ValueError(f"Missing weight for layer {layer_name}")
manually_loaded_keys = [weight_key]
if layer_name not in MixedPrecisionOps._layer_quant_config:
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
else:
quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None)
if quant_format is None:
raise ValueError(f"Unknown quantization format for layer {layer_name}")
qconfig = QUANT_ALGOS[quant_format]
self.layout_type = qconfig["comfy_tensor_layout"]
weight_scale_key = f"{prefix}weight_scale"
layout_params = {
'scale': state_dict.pop(weight_scale_key, None),
'orig_dtype': MixedPrecisionOps._compute_dtype,
'block_size': qconfig.get("group_size", None),
}
if layout_params['scale'] is not None:
manually_loaded_keys.append(weight_scale_key)
self.weight = torch.nn.Parameter(
QuantizedTensor(weight.to(device=device), self.layout_type, layout_params),
requires_grad=False
)
for param_name in qconfig["parameters"]:
param_key = f"{prefix}{param_name}"
_v = state_dict.pop(param_key, None)
if _v is None:
continue
setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
manually_loaded_keys.append(param_key)
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
for key in manually_loaded_keys:
if key in missing_keys:
missing_keys.remove(key)
def _forward(self, input, weight, bias):
return torch.nn.functional.linear(input, weight, bias)
def forward_comfy_cast_weights(self, input):
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = self._forward(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, input, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(input, *args, **kwargs)
if (getattr(self, 'layout_type', None) is not None and
getattr(self, 'input_scale', None) is not None and
not isinstance(input, QuantizedTensor)):
input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype)
return self._forward(input, self.weight, self.bias)
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None):
if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config:
MixedPrecisionOps._layer_quant_config = model_config.layer_quant_config
MixedPrecisionOps._compute_dtype = compute_dtype
logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers")
return MixedPrecisionOps
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None):
fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
if scaled_fp8 is not None:
return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)

View File

@@ -150,7 +150,7 @@ def merge_nested_dicts(dict1: dict, dict2: dict, copy_dict1=True):
for key, value in dict2.items():
if isinstance(value, dict):
curr_value = merged_dict.setdefault(key, {})
merged_dict[key] = merge_nested_dicts(curr_value, value)
merged_dict[key] = merge_nested_dicts(value, curr_value)
elif isinstance(value, list):
merged_dict.setdefault(key, []).extend(value)
else:

View File

@@ -1,545 +0,0 @@
import torch
import logging
from typing import Tuple, Dict
_LAYOUT_REGISTRY = {}
_GENERIC_UTILS = {}
def register_layout_op(torch_op, layout_type):
"""
Decorator to register a layout-specific operation handler.
Args:
torch_op: PyTorch operation (e.g., torch.ops.aten.linear.default)
layout_type: Layout class (e.g., TensorCoreFP8Layout)
Example:
@register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout)
def fp8_linear(func, args, kwargs):
# FP8-specific linear implementation
...
"""
def decorator(handler_func):
if torch_op not in _LAYOUT_REGISTRY:
_LAYOUT_REGISTRY[torch_op] = {}
_LAYOUT_REGISTRY[torch_op][layout_type] = handler_func
return handler_func
return decorator
def register_generic_util(torch_op):
"""
Decorator to register a generic utility that works for all layouts.
Args:
torch_op: PyTorch operation (e.g., torch.ops.aten.detach.default)
Example:
@register_generic_util(torch.ops.aten.detach.default)
def generic_detach(func, args, kwargs):
# Works for any layout
...
"""
def decorator(handler_func):
_GENERIC_UTILS[torch_op] = handler_func
return handler_func
return decorator
def _get_layout_from_args(args):
for arg in args:
if isinstance(arg, QuantizedTensor):
return arg._layout_type
elif isinstance(arg, (list, tuple)):
for item in arg:
if isinstance(item, QuantizedTensor):
return item._layout_type
return None
def _move_layout_params_to_device(params, device):
new_params = {}
for k, v in params.items():
if isinstance(v, torch.Tensor):
new_params[k] = v.to(device=device)
else:
new_params[k] = v
return new_params
def _copy_layout_params(params):
new_params = {}
for k, v in params.items():
if isinstance(v, torch.Tensor):
new_params[k] = v.clone()
else:
new_params[k] = v
return new_params
def _copy_layout_params_inplace(src, dst, non_blocking=False):
for k, v in src.items():
if isinstance(v, torch.Tensor):
dst[k].copy_(v, non_blocking=non_blocking)
else:
dst[k] = v
class QuantizedLayout:
"""
Base class for quantization layouts.
A layout encapsulates the format-specific logic for quantization/dequantization
and provides a uniform interface for extracting raw tensors needed for computation.
New quantization formats should subclass this and implement the required methods.
"""
@classmethod
def quantize(cls, tensor, **kwargs) -> Tuple[torch.Tensor, Dict]:
raise NotImplementedError(f"{cls.__name__} must implement quantize()")
@staticmethod
def dequantize(qdata, **layout_params) -> torch.Tensor:
raise NotImplementedError("TensorLayout must implement dequantize()")
@classmethod
def get_plain_tensors(cls, qtensor) -> torch.Tensor:
raise NotImplementedError(f"{cls.__name__} must implement get_plain_tensors()")
class QuantizedTensor(torch.Tensor):
"""
Universal quantized tensor that works with any layout.
This tensor subclass uses a pluggable layout system to support multiple
quantization formats (FP8, INT4, INT8, etc.) without code duplication.
The layout_type determines format-specific behavior, while common operations
(detach, clone, to) are handled generically.
Attributes:
_qdata: The quantized tensor data
_layout_type: Layout class (e.g., TensorCoreFP8Layout)
_layout_params: Dict with layout-specific params (scale, zero_point, etc.)
"""
@staticmethod
def __new__(cls, qdata, layout_type, layout_params):
"""
Create a quantized tensor.
Args:
qdata: The quantized data tensor
layout_type: Layout class (subclass of QuantizedLayout)
layout_params: Dict with layout-specific parameters
"""
return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False)
def __init__(self, qdata, layout_type, layout_params):
self._qdata = qdata
self._layout_type = layout_type
self._layout_params = layout_params
def __repr__(self):
layout_name = self._layout_type
param_str = ", ".join(f"{k}={v}" for k, v in list(self._layout_params.items())[:2])
return f"QuantizedTensor(shape={self.shape}, layout={layout_name}, {param_str})"
@property
def layout_type(self):
return self._layout_type
def __tensor_flatten__(self):
"""
Tensor flattening protocol for proper device movement.
"""
inner_tensors = ["_qdata"]
ctx = {
"layout_type": self._layout_type,
}
tensor_params = {}
non_tensor_params = {}
for k, v in self._layout_params.items():
if isinstance(v, torch.Tensor):
tensor_params[k] = v
else:
non_tensor_params[k] = v
ctx["tensor_param_keys"] = list(tensor_params.keys())
ctx["non_tensor_params"] = non_tensor_params
for k, v in tensor_params.items():
attr_name = f"_layout_param_{k}"
object.__setattr__(self, attr_name, v)
inner_tensors.append(attr_name)
return inner_tensors, ctx
@staticmethod
def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
"""
Tensor unflattening protocol for proper device movement.
Reconstructs the QuantizedTensor after device movement.
"""
layout_type = ctx["layout_type"]
layout_params = dict(ctx["non_tensor_params"])
for key in ctx["tensor_param_keys"]:
attr_name = f"_layout_param_{key}"
layout_params[key] = inner_tensors[attr_name]
return QuantizedTensor(inner_tensors["_qdata"], layout_type, layout_params)
@classmethod
def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor':
qdata, layout_params = LAYOUTS[layout_type].quantize(tensor, **quantize_kwargs)
return cls(qdata, layout_type, layout_params)
def dequantize(self) -> torch.Tensor:
return LAYOUTS[self._layout_type].dequantize(self._qdata, **self._layout_params)
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
kwargs = kwargs or {}
# Step 1: Check generic utilities first (detach, clone, to, etc.)
if func in _GENERIC_UTILS:
return _GENERIC_UTILS[func](func, args, kwargs)
# Step 2: Check layout-specific handlers (linear, matmul, etc.)
layout_type = _get_layout_from_args(args)
if layout_type and func in _LAYOUT_REGISTRY:
handler = _LAYOUT_REGISTRY[func].get(layout_type)
if handler:
return handler(func, args, kwargs)
# Step 3: Fallback to dequantization
if isinstance(args[0] if args else None, QuantizedTensor):
logging.info(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. kwargs={kwargs}")
return cls._dequant_and_fallback(func, args, kwargs)
@classmethod
def _dequant_and_fallback(cls, func, args, kwargs):
def dequant_arg(arg):
if isinstance(arg, QuantizedTensor):
return arg.dequantize()
elif isinstance(arg, (list, tuple)):
return type(arg)(dequant_arg(a) for a in arg)
return arg
new_args = dequant_arg(args)
new_kwargs = dequant_arg(kwargs)
return func(*new_args, **new_kwargs)
# ==============================================================================
# Generic Utilities (Layout-Agnostic Operations)
# ==============================================================================
def _create_transformed_qtensor(qt, transform_fn):
new_data = transform_fn(qt._qdata)
new_params = _copy_layout_params(qt._layout_params)
return QuantizedTensor(new_data, qt._layout_type, new_params)
def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=None, op_name="to"):
if target_dtype is not None and target_dtype != qt.dtype:
logging.warning(
f"QuantizedTensor: dtype conversion requested to {target_dtype}, "
f"but not supported for quantized tensors. Ignoring dtype."
)
if target_layout is not None and target_layout != torch.strided:
logging.warning(
f"QuantizedTensor: layout change requested to {target_layout}, "
f"but not supported. Ignoring layout."
)
# Handle device transfer
current_device = qt._qdata.device
if target_device is not None:
# Normalize device for comparison
if isinstance(target_device, str):
target_device = torch.device(target_device)
if isinstance(current_device, str):
current_device = torch.device(current_device)
if target_device != current_device:
logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}")
new_q_data = qt._qdata.to(device=target_device)
new_params = _move_layout_params_to_device(qt._layout_params, target_device)
new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params)
logging.debug(f"QuantizedTensor.{op_name}: Created new tensor on {target_device}")
return new_qt
logging.debug(f"QuantizedTensor.{op_name}: No device change needed, returning original")
return qt
@register_generic_util(torch.ops.aten.detach.default)
def generic_detach(func, args, kwargs):
"""Detach operation - creates a detached copy of the quantized tensor."""
qt = args[0]
if isinstance(qt, QuantizedTensor):
return _create_transformed_qtensor(qt, lambda x: x.detach())
return func(*args, **kwargs)
@register_generic_util(torch.ops.aten.clone.default)
def generic_clone(func, args, kwargs):
"""Clone operation - creates a deep copy of the quantized tensor."""
qt = args[0]
if isinstance(qt, QuantizedTensor):
return _create_transformed_qtensor(qt, lambda x: x.clone())
return func(*args, **kwargs)
@register_generic_util(torch.ops.aten._to_copy.default)
def generic_to_copy(func, args, kwargs):
"""Device/dtype transfer operation - handles .to(device) calls."""
qt = args[0]
if isinstance(qt, QuantizedTensor):
return _handle_device_transfer(
qt,
target_device=kwargs.get('device', None),
target_dtype=kwargs.get('dtype', None),
op_name="_to_copy"
)
return func(*args, **kwargs)
@register_generic_util(torch.ops.aten.to.dtype_layout)
def generic_to_dtype_layout(func, args, kwargs):
"""Handle .to(device) calls using the dtype_layout variant."""
qt = args[0]
if isinstance(qt, QuantizedTensor):
return _handle_device_transfer(
qt,
target_device=kwargs.get('device', None),
target_dtype=kwargs.get('dtype', None),
target_layout=kwargs.get('layout', None),
op_name="to"
)
return func(*args, **kwargs)
@register_generic_util(torch.ops.aten.copy_.default)
def generic_copy_(func, args, kwargs):
qt_dest = args[0]
src = args[1]
non_blocking = args[2] if len(args) > 2 else False
if isinstance(qt_dest, QuantizedTensor):
if isinstance(src, QuantizedTensor):
# Copy from another quantized tensor
qt_dest._qdata.copy_(src._qdata, non_blocking=non_blocking)
qt_dest._layout_type = src._layout_type
_copy_layout_params_inplace(src._layout_params, qt_dest._layout_params, non_blocking=non_blocking)
else:
# Copy from regular tensor - just copy raw data
qt_dest._qdata.copy_(src)
return qt_dest
return func(*args, **kwargs)
@register_generic_util(torch.ops.aten._has_compatible_shallow_copy_type.default)
def generic_has_compatible_shallow_copy_type(func, args, kwargs):
return True
@register_generic_util(torch.ops.aten.empty_like.default)
def generic_empty_like(func, args, kwargs):
"""Empty_like operation - creates an empty tensor with the same quantized structure."""
qt = args[0]
if isinstance(qt, QuantizedTensor):
# Create empty tensor with same shape and dtype as the quantized data
hp_dtype = kwargs.pop('dtype', qt._layout_params["orig_dtype"])
new_qdata = torch.empty_like(qt._qdata, **kwargs)
# Handle device transfer for layout params
target_device = kwargs.get('device', new_qdata.device)
new_params = _move_layout_params_to_device(qt._layout_params, target_device)
# Update orig_dtype if dtype is specified
new_params['orig_dtype'] = hp_dtype
return QuantizedTensor(new_qdata, qt._layout_type, new_params)
return func(*args, **kwargs)
# ==============================================================================
# FP8 Layout + Operation Handlers
# ==============================================================================
class TensorCoreFP8Layout(QuantizedLayout):
"""
Storage format:
- qdata: FP8 tensor (torch.float8_e4m3fn or torch.float8_e5m2)
- scale: Scalar tensor (float32) for dequantization
- orig_dtype: Original dtype before quantization (for casting back)
"""
@classmethod
def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn):
orig_dtype = tensor.dtype
if scale is None:
scale = torch.amax(tensor.abs()) / torch.finfo(dtype).max
if not isinstance(scale, torch.Tensor):
scale = torch.tensor(scale)
scale = scale.to(device=tensor.device, dtype=torch.float32)
tensor_scaled = tensor * (1.0 / scale).to(tensor.dtype)
# TODO: uncomment this if it's actually needed because the clamp has a small performance penality'
# lp_amax = torch.finfo(dtype).max
# torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
qdata = tensor_scaled.to(dtype, memory_format=torch.contiguous_format)
layout_params = {
'scale': scale,
'orig_dtype': orig_dtype
}
return qdata, layout_params
@staticmethod
def dequantize(qdata, scale, orig_dtype, **kwargs):
plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype)
return plain_tensor * scale
@classmethod
def get_plain_tensors(cls, qtensor):
return qtensor._qdata, qtensor._layout_params['scale']
QUANT_ALGOS = {
"float8_e4m3fn": {
"storage_t": torch.float8_e4m3fn,
"parameters": {"weight_scale", "input_scale"},
"comfy_tensor_layout": "TensorCoreFP8Layout",
},
}
LAYOUTS = {
"TensorCoreFP8Layout": TensorCoreFP8Layout,
}
@register_layout_op(torch.ops.aten.linear.default, "TensorCoreFP8Layout")
def fp8_linear(func, args, kwargs):
input_tensor = args[0]
weight = args[1]
bias = args[2] if len(args) > 2 else None
if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight)
out_dtype = kwargs.get("out_dtype")
if out_dtype is None:
out_dtype = input_tensor._layout_params['orig_dtype']
weight_t = plain_weight.t()
tensor_2d = False
if len(plain_input.shape) == 2:
tensor_2d = True
plain_input = plain_input.unsqueeze(1)
input_shape = plain_input.shape
if len(input_shape) != 3:
return None
try:
output = torch._scaled_mm(
plain_input.reshape(-1, input_shape[2]).contiguous(),
weight_t,
bias=bias,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=out_dtype,
)
if isinstance(output, tuple): # TODO: remove when we drop support for torch 2.4
output = output[0]
if not tensor_2d:
output = output.reshape((-1, input_shape[1], weight.shape[0]))
if output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
output_scale = scale_a * scale_b
output_params = {
'scale': output_scale,
'orig_dtype': input_tensor._layout_params['orig_dtype']
}
return QuantizedTensor(output, "TensorCoreFP8Layout", output_params)
else:
return output
except Exception as e:
raise RuntimeError(f"FP8 _scaled_mm failed, falling back to dequantization: {e}")
# Case 2: DQ Fallback
if isinstance(weight, QuantizedTensor):
weight = weight.dequantize()
if isinstance(input_tensor, QuantizedTensor):
input_tensor = input_tensor.dequantize()
return torch.nn.functional.linear(input_tensor, weight, bias)
def fp8_mm_(input_tensor, weight, bias=None, out_dtype=None):
if out_dtype is None:
out_dtype = input_tensor._layout_params['orig_dtype']
plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight)
output = torch._scaled_mm(
plain_input.contiguous(),
plain_weight,
bias=bias,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=out_dtype,
)
if isinstance(output, tuple): # TODO: remove when we drop support for torch 2.4
output = output[0]
return output
@register_layout_op(torch.ops.aten.addmm.default, "TensorCoreFP8Layout")
def fp8_addmm(func, args, kwargs):
input_tensor = args[1]
weight = args[2]
bias = args[0]
if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
return fp8_mm_(input_tensor, weight, bias=bias, out_dtype=kwargs.get("out_dtype", None))
a = list(args)
if isinstance(args[0], QuantizedTensor):
a[0] = args[0].dequantize()
if isinstance(args[1], QuantizedTensor):
a[1] = args[1].dequantize()
if isinstance(args[2], QuantizedTensor):
a[2] = args[2].dequantize()
return func(*a, **kwargs)
@register_layout_op(torch.ops.aten.mm.default, "TensorCoreFP8Layout")
def fp8_mm(func, args, kwargs):
input_tensor = args[0]
weight = args[1]
if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
return fp8_mm_(input_tensor, weight, bias=None, out_dtype=kwargs.get("out_dtype", None))
a = list(args)
if isinstance(args[0], QuantizedTensor):
a[0] = args[0].dequantize()
if isinstance(args[1], QuantizedTensor):
a[1] = args[1].dequantize()
return func(*a, **kwargs)
@register_layout_op(torch.ops.aten.view.default, "TensorCoreFP8Layout")
@register_layout_op(torch.ops.aten.t.default, "TensorCoreFP8Layout")
def fp8_func(func, args, kwargs):
input_tensor = args[0]
if isinstance(input_tensor, QuantizedTensor):
plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
ar = list(args)
ar[0] = plain_input
return QuantizedTensor(func(*ar, **kwargs), "TensorCoreFP8Layout", input_tensor._layout_params)
return func(*args, **kwargs)

View File

@@ -4,9 +4,13 @@ import comfy.samplers
import comfy.utils
import numpy as np
import logging
import comfy.nested_tensor
def prepare_noise_inner(latent_image, generator, noise_inds=None):
def prepare_noise(latent_image, seed, noise_inds=None):
"""
creates random noise given a latent image and a seed.
optional arg skip can be used to skip and discard x number of noise generations for a given seed
"""
generator = torch.manual_seed(seed)
if noise_inds is None:
return torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
@@ -17,29 +21,10 @@ def prepare_noise_inner(latent_image, generator, noise_inds=None):
if i in unique_inds:
noises.append(noise)
noises = [noises[i] for i in inverse]
return torch.cat(noises, axis=0)
def prepare_noise(latent_image, seed, noise_inds=None):
"""
creates random noise given a latent image and a seed.
optional arg skip can be used to skip and discard x number of noise generations for a given seed
"""
generator = torch.manual_seed(seed)
if latent_image.is_nested:
tensors = latent_image.unbind()
noises = []
for t in tensors:
noises.append(prepare_noise_inner(t, generator, noise_inds))
noises = comfy.nested_tensor.NestedTensor(noises)
else:
noises = prepare_noise_inner(latent_image, generator, noise_inds)
noises = torch.cat(noises, axis=0)
return noises
def fix_empty_latent_channels(model, latent_image):
if latent_image.is_nested:
return latent_image
latent_format = model.get_model_object("latent_format") #Resize the empty latent image so it has the right number of channels
if latent_format.latent_channels != latent_image.shape[1] and torch.count_nonzero(latent_image) == 0:
latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1)

View File

@@ -306,10 +306,17 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
copy_dict1=False)
if patches is not None:
transformer_options["patches"] = comfy.patcher_extension.merge_nested_dicts(
transformer_options.get("patches", {}),
patches
)
# TODO: replace with merge_nested_dicts function
if "patches" in transformer_options:
cur_patches = transformer_options["patches"].copy()
for p in patches:
if p in cur_patches:
cur_patches[p] = cur_patches[p] + patches[p]
else:
cur_patches[p] = patches[p]
transformer_options["patches"] = cur_patches
else:
transformer_options["patches"] = patches
transformer_options["cond_or_uncond"] = cond_or_uncond[:]
transformer_options["uuids"] = uuids[:]
@@ -782,7 +789,7 @@ def ksampler(sampler_name, extra_options={}, inpaint_options={}):
return KSAMPLER(sampler_function, extra_options, inpaint_options)
def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=None, seed=None, latent_shapes=None):
def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=None, seed=None):
for k in conds:
conds[k] = conds[k][:]
resolve_areas_and_cond_masks_multidim(conds[k], noise.shape[2:], device)
@@ -792,7 +799,7 @@ def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=N
if hasattr(model, 'extra_conds'):
for k in conds:
conds[k] = encode_model_conds(model.extra_conds, conds[k], noise, device, k, latent_image=latent_image, denoise_mask=denoise_mask, seed=seed, latent_shapes=latent_shapes)
conds[k] = encode_model_conds(model.extra_conds, conds[k], noise, device, k, latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)
#make sure each cond area has an opposite one with the same area
for k in conds:
@@ -962,11 +969,11 @@ class CFGGuider:
def predict_noise(self, x, timestep, model_options={}, seed=None):
return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None), self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed)
def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=None):
def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed):
if latent_image is not None and torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image.
latent_image = self.inner_model.process_latent_in(latent_image)
self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed, latent_shapes=latent_shapes)
self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed)
extra_model_options = comfy.model_patcher.create_model_options_clone(self.model_options)
extra_model_options.setdefault("transformer_options", {})["sample_sigmas"] = sigmas
@@ -980,7 +987,7 @@ class CFGGuider:
samples = executor.execute(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
return self.inner_model.process_latent_out(samples.to(torch.float32))
def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None, latent_shapes=None):
def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
device = self.model_patcher.load_device
@@ -994,7 +1001,7 @@ class CFGGuider:
try:
self.model_patcher.pre_run()
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
finally:
self.model_patcher.cleanup()
@@ -1007,12 +1014,6 @@ class CFGGuider:
if sigmas.shape[-1] == 0:
return latent_image
if latent_image.is_nested:
latent_image, latent_shapes = comfy.utils.pack_latents(latent_image.unbind())
noise, _ = comfy.utils.pack_latents(noise.unbind())
else:
latent_shapes = [latent_image.shape]
self.conds = {}
for k in self.original_conds:
self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k]))
@@ -1032,7 +1033,7 @@ class CFGGuider:
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, self.model_options, is_model_options=True)
)
output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
finally:
cast_to_load_options(self.model_options, device=self.model_patcher.offload_device)
self.model_options = orig_model_options
@@ -1040,9 +1041,6 @@ class CFGGuider:
self.model_patcher.restore_hook_patches()
del self.conds
if len(latent_shapes) > 1:
output = comfy.nested_tensor.NestedTensor(comfy.utils.unpack_latents(output, latent_shapes))
return output

View File

@@ -18,7 +18,6 @@ import comfy.ldm.wan.vae2_2
import comfy.ldm.hunyuan3d.vae
import comfy.ldm.ace.vae.music_dcae_pipeline
import comfy.ldm.hunyuan_video.vae
import comfy.ldm.mmaudio.vae.autoencoder
import comfy.pixel_space_convert
import yaml
import math
@@ -143,9 +142,6 @@ class CLIP:
n.apply_hooks_to_conds = self.apply_hooks_to_conds
return n
def get_ram_usage(self):
return self.patcher.get_ram_usage()
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
return self.patcher.add_patches(patches, strength_patch, strength_model)
@@ -279,13 +275,8 @@ class VAE:
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
sd = diffusers_convert.convert_vae_state_dict(sd)
if model_management.is_amd():
VAE_KL_MEM_RATIO = 2.73
else:
VAE_KL_MEM_RATIO = 1.0
self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) * VAE_KL_MEM_RATIO #These are for AutoencoderKL and need tweaking (should be lower)
self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype) * VAE_KL_MEM_RATIO
self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower)
self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype)
self.downscale_ratio = 8
self.upscale_ratio = 8
self.latent_channels = 4
@@ -296,12 +287,10 @@ class VAE:
self.working_dtypes = [torch.bfloat16, torch.float32]
self.disable_offload = False
self.not_video = False
self.size = None
self.downscale_index_formula = None
self.upscale_index_formula = None
self.extra_1d_channel = None
self.crop_input = True
if config is None:
if "decoder.mid.block_1.mix_factor" in sd:
@@ -343,51 +332,35 @@ class VAE:
self.first_stage_model = StageC_coder()
self.downscale_ratio = 32
self.latent_channels = 16
elif "decoder.conv_in.weight" in sd and sd['decoder.conv_in.weight'].shape[1] == 64:
ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
self.downscale_ratio = 32
self.upscale_ratio = 32
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
encoder_config={'target': "comfy.ldm.hunyuan_video.vae.Encoder", 'params': ddconfig},
decoder_config={'target': "comfy.ldm.hunyuan_video.vae.Decoder", 'params': ddconfig})
self.memory_used_encode = lambda shape, dtype: (700 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (700 * shape[2] * shape[3] * 32 * 32) * model_management.dtype_size(dtype)
elif "decoder.conv_in.weight" in sd:
if sd['decoder.conv_in.weight'].shape[1] == 64:
ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
self.downscale_ratio = 32
self.upscale_ratio = 32
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
encoder_config={'target': "comfy.ldm.hunyuan_video.vae.Encoder", 'params': ddconfig},
decoder_config={'target': "comfy.ldm.hunyuan_video.vae.Decoder", 'params': ddconfig})
#default SD1.x/SD2.x VAE parameters
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
self.memory_used_encode = lambda shape, dtype: (700 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (700 * shape[2] * shape[3] * 32 * 32) * model_management.dtype_size(dtype)
elif sd['decoder.conv_in.weight'].shape[1] == 32:
ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True, "refiner_vae": False}
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
self.upscale_index_formula = (4, 16, 16)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16)
self.downscale_index_formula = (4, 16, 16)
self.latent_dim = 3
self.not_video = True
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
encoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Encoder", 'params': ddconfig},
decoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Decoder", 'params': ddconfig})
if 'encoder.down.2.downsample.conv.weight' not in sd and 'decoder.up.3.upsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE
ddconfig['ch_mult'] = [1, 2, 4]
self.downscale_ratio = 4
self.upscale_ratio = 4
self.memory_used_encode = lambda shape, dtype: (2800 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (2800 * shape[-3] * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
if 'post_quant_conv.weight' in sd:
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
else:
#default SD1.x/SD2.x VAE parameters
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
if 'encoder.down.2.downsample.conv.weight' not in sd and 'decoder.up.3.upsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE
ddconfig['ch_mult'] = [1, 2, 4]
self.downscale_ratio = 4
self.upscale_ratio = 4
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
if 'post_quant_conv.weight' in sd:
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
else:
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig},
decoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Decoder", 'params': ddconfig})
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig},
decoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Decoder", 'params': ddconfig})
elif "decoder.layers.1.layers.0.beta" in sd:
self.first_stage_model = AudioOobleckVAE()
self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype)
@@ -441,20 +414,20 @@ class VAE:
elif "decoder.conv_in.conv.weight" in sd and sd['decoder.conv_in.conv.weight'].shape[1] == 32:
ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True}
ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
self.latent_channels = 32
self.latent_channels = 64
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
self.upscale_index_formula = (4, 16, 16)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16)
self.downscale_index_formula = (4, 16, 16)
self.latent_dim = 3
self.not_video = False
self.not_video = True
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.EmptyRegularizer"},
encoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Encoder", 'params': ddconfig},
decoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Decoder", 'params': ddconfig})
self.memory_used_encode = lambda shape, dtype: (1400 * 9 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (2800 * 4 * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)
self.memory_used_encode = lambda shape, dtype: (1400 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (1400 * shape[-3] * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)
elif "decoder.conv_in.conv.weight" in sd:
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
ddconfig["conv3d"] = True
@@ -553,25 +526,6 @@ class VAE:
self.latent_channels = 3
self.latent_dim = 2
self.output_channels = 3
elif "vocoder.activation_post.downsample.lowpass.filter" in sd: #MMAudio VAE
sample_rate = 16000
if sample_rate == 16000:
mode = '16k'
else:
mode = '44k'
self.first_stage_model = comfy.ldm.mmaudio.vae.autoencoder.AudioAutoencoder(mode=mode)
self.memory_used_encode = lambda shape, dtype: (30 * shape[2]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (90 * shape[2] * 1411.2) * model_management.dtype_size(dtype)
self.latent_channels = 20
self.output_channels = 2
self.upscale_ratio = 512 * (44100 / sample_rate)
self.downscale_ratio = 512 * (44100 / sample_rate)
self.latent_dim = 1
self.process_output = lambda audio: audio
self.process_input = lambda audio: audio
self.working_dtypes = [torch.float32]
self.crop_input = False
else:
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
self.first_stage_model = None
@@ -599,25 +553,12 @@ class VAE:
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
self.model_size()
def model_size(self):
if self.size is not None:
return self.size
self.size = comfy.model_management.module_size(self.first_stage_model)
return self.size
def get_ram_usage(self):
return self.model_size()
def throw_exception_if_invalid(self):
if self.first_stage_model is None:
raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.")
def vae_encode_crop_pixels(self, pixels):
if not self.crop_input:
return pixels
downscale_ratio = self.spacial_compression_encode()
dims = pixels.shape[1:-1]
@@ -695,7 +636,6 @@ class VAE:
def decode(self, samples_in, vae_options={}):
self.throw_exception_if_invalid()
pixel_samples = None
do_tile = False
try:
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
@@ -711,13 +651,6 @@ class VAE:
pixel_samples[x:x+batch_number] = out
except model_management.OOM_EXCEPTION:
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
#exception and the exception itself refs them all until we get out of this except block.
#So we just set a flag for tiler fallback so that tensor gc can happen once the
#exception is fully off the books.
do_tile = True
if do_tile:
dims = samples_in.ndim - 2
if dims == 1 or self.extra_1d_channel is not None:
pixel_samples = self.decode_tiled_1d(samples_in)
@@ -764,7 +697,6 @@ class VAE:
self.throw_exception_if_invalid()
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
pixel_samples = pixel_samples.movedim(-1, 1)
do_tile = False
if self.latent_dim == 3 and pixel_samples.ndim < 5:
if not self.not_video:
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
@@ -786,13 +718,6 @@ class VAE:
except model_management.OOM_EXCEPTION:
logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
#exception and the exception itself refs them all until we get out of this except block.
#So we just set a flag for tiler fallback so that tensor gc can happen once the
#exception is fully off the books.
do_tile = True
if do_tile:
if self.latent_dim == 3:
tile = 256
overlap = tile // 4
@@ -911,7 +836,6 @@ class CLIPType(Enum):
OMNIGEN2 = 17
QWEN_IMAGE = 18
HUNYUAN_IMAGE = 19
HUNYUAN_VIDEO_15 = 20
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
@@ -934,7 +858,6 @@ class TEModel(Enum):
QWEN25_3B = 10
QWEN25_7B = 11
BYT5_SMALL_GLYPH = 12
GEMMA_3_4B = 13
def detect_te_model(sd):
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
@@ -957,8 +880,6 @@ def detect_te_model(sd):
return TEModel.BYT5_SMALL_GLYPH
return TEModel.T5_BASE
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
if 'model.layers.0.self_attn.q_norm.weight' in sd:
return TEModel.GEMMA_3_4B
return TEModel.GEMMA_2_2B
if 'model.layers.0.self_attn.k_proj.bias' in sd:
weight = sd['model.layers.0.self_attn.k_proj.bias']
@@ -1063,10 +984,6 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
elif te_model == TEModel.GEMMA_3_4B:
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data), model_type="gemma3_4b")
clip_target.tokenizer = comfy.text_encoders.lumina2.NTokenizer
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
elif te_model == TEModel.LLAMA3_8:
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data),
clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None, t5xxl_scaled_fp8=None)
@@ -1127,9 +1044,6 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
elif clip_type == CLIPType.HUNYUAN_IMAGE:
clip_target.clip = comfy.text_encoders.hunyuan_image.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer
elif clip_type == CLIPType.HUNYUAN_VIDEO_15:
clip_target.clip = comfy.text_encoders.hunyuan_image.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer
else:
clip_target.clip = sdxl_clip.SDXLClipModel
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
@@ -1280,7 +1194,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
return (model_patcher, clip, vae, clipvision)
def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
def load_diffusion_model_state_dict(sd, model_options={}):
"""
Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats.
@@ -1314,7 +1228,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
weight_dtype = comfy.utils.weight_dtype(sd)
load_device = model_management.get_torch_device()
model_config = model_detection.model_config_from_unet(sd, "", metadata=metadata)
model_config = model_detection.model_config_from_unet(sd, "")
if model_config is not None:
new_sd = sd
@@ -1348,10 +1262,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
else:
unet_dtype = dtype
if model_config.layer_quant_config is not None:
manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
else:
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
model_config.custom_operations = model_options.get("custom_operations", model_config.custom_operations)
if model_options.get("fp8_optimizations", False):
@@ -1367,8 +1278,8 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
def load_diffusion_model(unet_path, model_options={}):
sd, metadata = comfy.utils.load_torch_file(unet_path, return_metadata=True)
model = load_diffusion_model_state_dict(sd, model_options=model_options, metadata=metadata)
sd = comfy.utils.load_torch_file(unet_path)
model = load_diffusion_model_state_dict(sd, model_options=model_options)
if model is None:
logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path))
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd)))

View File

@@ -460,7 +460,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
return embed_out
class SDTokenizer:
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, pad_left=False, tokenizer_data={}, tokenizer_args={}):
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, tokenizer_data={}, tokenizer_args={}):
if tokenizer_path is None:
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args)
@@ -468,7 +468,6 @@ class SDTokenizer:
self.min_length = tokenizer_data.get("{}_min_length".format(embedding_key), min_length)
self.end_token = None
self.min_padding = min_padding
self.pad_left = pad_left
empty = self.tokenizer('')["input_ids"]
self.tokenizer_adds_end_token = has_end_token
@@ -523,12 +522,6 @@ class SDTokenizer:
return (embed, "{} {}".format(embedding_name[len(stripped):], leftover))
return (embed, leftover)
def pad_tokens(self, tokens, amount):
if self.pad_left:
for i in range(amount):
tokens.insert(0, (self.pad_token, 1.0, 0))
else:
tokens.extend([(self.pad_token, 1.0, 0)] * amount)
def tokenize_with_weights(self, text:str, return_word_ids=False, tokenizer_options={}, **kwargs):
'''
@@ -607,7 +600,7 @@ class SDTokenizer:
if self.end_token is not None:
batch.append((self.end_token, 1.0, 0))
if self.pad_to_max_length:
self.pad_tokens(batch, remaining_length)
batch.extend([(self.pad_token, 1.0, 0)] * (remaining_length))
#start new batch
batch = []
if self.start_token is not None:
@@ -621,11 +614,11 @@ class SDTokenizer:
if self.end_token is not None:
batch.append((self.end_token, 1.0, 0))
if min_padding is not None:
self.pad_tokens(batch, min_padding)
batch.extend([(self.pad_token, 1.0, 0)] * min_padding)
if self.pad_to_max_length and len(batch) < self.max_length:
self.pad_tokens(batch, self.max_length - len(batch))
batch.extend([(self.pad_token, 1.0, 0)] * (self.max_length - len(batch)))
if min_length is not None and len(batch) < min_length:
self.pad_tokens(batch, min_length - len(batch))
batch.extend([(self.pad_token, 1.0, 0)] * (min_length - len(batch)))
if not return_word_ids:
batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens]

View File

@@ -1374,54 +1374,6 @@ class HunyuanImage21Refiner(HunyuanVideo):
out = model_base.HunyuanImage21Refiner(self, device=device)
return out
class HunyuanVideo15(HunyuanVideo):
unet_config = {
"image_model": "hunyuan_video",
"vision_in_dim": 1152,
}
sampling_settings = {
"shift": 7.0,
}
memory_usage_factor = 4.0 #TODO
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
latent_format = latent_formats.HunyuanVideo15
def get_model(self, state_dict, prefix="", device=None):
out = model_base.HunyuanVideo15(self, device=device)
return out
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect))
class HunyuanVideo15_SR_Distilled(HunyuanVideo):
unet_config = {
"image_model": "hunyuan_video",
"vision_in_dim": 1152,
"in_channels": 98,
}
sampling_settings = {
"shift": 2.0,
}
memory_usage_factor = 4.0 #TODO
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
latent_format = latent_formats.HunyuanVideo15
def get_model(self, state_dict, prefix="", device=None):
out = model_base.HunyuanVideo15_SR_Distilled(self, device=device)
return out
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect))
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage]
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage]
models += [SVD_img2vid]

View File

@@ -50,7 +50,6 @@ class BASE:
manual_cast_dtype = None
custom_operations = None
scaled_fp8 = None
layer_quant_config = None # Per-layer quantization configuration for mixed precision
optimizations = {"fp8": False}
@classmethod

View File

@@ -1,7 +1,6 @@
from comfy import sd1_clip
import comfy.model_management
import comfy.text_encoders.llama
from .hunyuan_image import HunyuanImageTokenizer
from transformers import LlamaTokenizerFast
import torch
import os
@@ -74,14 +73,6 @@ class HunyuanVideoTokenizer:
return {}
class HunyuanVideo15Tokenizer(HunyuanImageTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.llama_template = "<|im_start|>system\nYou are a helpful assistant. Describe the video by detailing the following aspects:\n1. The main content and theme of the video.\n2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.\n3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.\n4. background environment, light, style and atmosphere.\n5. camera angles, movements, and transitions used in the video.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
return super().tokenize_with_weights(text, return_word_ids, prevent_empty_text=True, **kwargs)
class HunyuanVideoClipModel(torch.nn.Module):
def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}):
super().__init__()

Some files were not shown because too many files have changed in this diff Show More