Compare commits

..

91 Commits

Author SHA1 Message Date
Alexander Piskun
917177e821 move assets related stuff to "app/assets" folder (#10184) 2025-10-03 11:53:26 -07:00
Alexander Piskun
fd6ac0a765 drop PgSQL 14, unite migration for SQLite and PgSQL (#10165) 2025-10-03 11:34:06 -07:00
Alexander Piskun
94941c50b3 move alembic_db inside app folder (#10163) 2025-10-02 15:01:16 -07:00
Jedrzej Kosinski
fbba2e59e5 Satisfy ruff 2025-09-26 20:39:23 -07:00
Jedrzej Kosinski
adccfb2dfd Remove populate_db_with_asset from load_torch_file for now, as nothing yet uses the hashes 2025-09-26 20:33:46 -07:00
Jedrzej Kosinski
9f4c0f3afe Merge branch 'master' into asset-management 2025-09-26 20:24:25 -07:00
Jedrzej Kosinski
ca39552954 Merge branch 'master' into asset-management 2025-09-24 23:44:57 -07:00
Jedrzej Kosinski
4dd843d36f Merge branch 'master' into asset-management 2025-09-18 14:08:20 -07:00
Jedrzej Kosinski
46fdd636de Merge pull request #9545 from bigcat88/asset-management
[Assets] Initial implementation
2025-09-18 14:07:18 -07:00
bigcat88
283cd27bdc final adjustments 2025-09-18 10:05:32 +03:00
bigcat88
1a37d1476d refactor(6): fully batched initial scan 2025-09-17 20:29:29 +03:00
bigcat88
f9602457d6 optimization: initial scan speed(batching metadata[filename]) 2025-09-17 16:47:27 +03:00
bigcat88
85ef08449d optimization: initial scan speed(batching tags) 2025-09-17 14:08:57 +03:00
bigcat88
5b6810a2c6 fixed hash calculation during model loading in ComfyUI 2025-09-17 13:25:56 +03:00
bigcat88
621faaa195 refactor(5): use less DB queries to create seed asset 2025-09-17 10:46:21 +03:00
bigcat88
d0aa64d57b refactor(4): use one query to init DB with all tags for assets 2025-09-16 21:18:18 +03:00
bigcat88
677a0e2508 refactor(3): unite logic for Asset fast check 2025-09-16 20:29:50 +03:00
bigcat88
31ec744317 refactor(2)/fix: skip double checking the existing files during fast check 2025-09-16 19:50:21 +03:00
bigcat88
a336c7c165 refactor(1): use general fast_asset_file_check helper for fast check 2025-09-16 19:19:18 +03:00
bigcat88
77332d3054 optimization: fast scan: commit to the DB in chunks 2025-09-16 14:21:40 +03:00
bigcat88
24a95f5ca4 removed default scanning of "input" and "output" folders; added separate endpoint for test suite. 2025-09-16 11:28:29 +03:00
bigcat88
0be513b213 fix: escape "_" symbol in all other places 2025-09-15 20:26:48 +03:00
bigcat88
f1fb7432a0 fix+test: escape "_" symbol in assets filtering 2025-09-15 19:19:47 +03:00
bigcat88
f3cf99d10c fix+test: escape "_" symbol in tags filtering 2025-09-15 17:29:27 +03:00
bigcat88
5f187fe6fb optimization: make list_unhashed_candidates_under_prefixes single-query instead of N+1 2025-09-15 12:46:35 +03:00
bigcat88
025fc49b4e optimization: DB Queries (Tags) 2025-09-15 10:26:13 +03:00
bigcat88
7becb84341 fixed tests on SQLite file 2025-09-14 23:01:17 +03:00
bigcat88
dda31de690 rework: AssetInfo.name is only a display name 2025-09-14 21:53:44 +03:00
bigcat88
1d970382f0 added final tests 2025-09-14 20:02:28 +03:00
bigcat88
a2fc2bbae4 corrected formatting 2025-09-14 18:12:00 +03:00
bigcat88
a7f2546558 fix: use ".rowcount" instead of ".returning" on SQLite 2025-09-14 17:55:02 +03:00
bigcat88
6cfa94ec58 fixed metadata[filename] feature + new tests for this 2025-09-14 16:28:14 +03:00
bigcat88
a2ec1f7637 simplify code 2025-09-14 15:31:42 +03:00
bigcat88
0b795dc7a7 removed non-needed code 2025-09-14 15:14:24 +03:00
bigcat88
47f7c7ee8c rework + add test for concurrent AssetInfo delete 2025-09-14 15:08:29 +03:00
bigcat88
cdd8d16075 +2 tests for checking Asset downloading logic 2025-09-14 14:57:24 +03:00
bigcat88
37b81e6658 fixed new PgSQL bug 2025-09-14 14:30:38 +03:00
bigcat88
975650060f concurrency upload test + fixed 2 related bugs 2025-09-14 09:39:23 +03:00
bigcat88
4a713654cd added more tests for the Assets logic 2025-09-14 09:10:59 +03:00
bigcat88
9b8e88ba6e added more tests for the Assets logic 2025-09-13 20:09:45 +03:00
bigcat88
bb9ed04758 global refactoring; add support for Assets without the computed hash 2025-09-13 16:39:08 +03:00
Alexander Piskun
934377ac1e removed currently unnecessary "asset_locations" functionality 2025-09-12 14:46:22 +03:00
Alexander Piskun
3c9bf39c20 Merge pull request #1 from bigcat88/asset-management-ci
fix bugs + GH CI tests
2025-09-10 16:31:11 +03:00
bigcat88
0df1ccac6f GitHub CI test for Assets 2025-09-10 16:22:22 +03:00
bigcat88
72548a8ac4 added additional tests; sorted tests 2025-09-10 10:39:55 +03:00
bigcat88
6eaed072c7 add some logic tests 2025-09-10 09:51:06 +03:00
bigcat88
a9096f6c97 removed non-needed code, fix tests, +1 new test 2025-09-09 20:54:11 +03:00
bigcat88
964de8a8ad add more list_assets tests + fix one found bug 2025-09-09 20:35:18 +03:00
bigcat88
1886f10e19 add download tests 2025-09-09 19:30:58 +03:00
bigcat88
357193f7b5 fixed metadata filtering + tests 2025-09-09 19:12:11 +03:00
bigcat88
0ef73e95fd fixed validation error + more tests 2025-09-09 16:02:39 +03:00
bigcat88
faa1e4de17 fixed another test 2025-09-09 15:17:03 +03:00
bigcat88
dfb5703d40 feat: remove Asset when there is no references left + bugfixes + more tests 2025-09-09 15:10:07 +03:00
bigcat88
0e9de2b7c9 feat: add first test 2025-09-08 20:43:45 +03:00
bigcat88
e3311c9229 feat: support for in-memory SQLite databases 2025-09-08 18:15:09 +03:00
bigcat88
3fa0fc496c fix: use UPSERT to eliminate rare race condition during ingesting many small files in parallel 2025-09-08 18:13:32 +03:00
bigcat88
6282d495ca corrected detection of missing files for assets 2025-09-07 22:08:38 +03:00
bigcat88
b8ef9bb92c add detection of the missing files for existing assets 2025-09-07 16:49:39 +03:00
bigcat88
2d9be462d3 add support for assets duplicates 2025-09-06 19:22:51 +03:00
bigcat88
789a62ce35 assume that DB packages always present; refactoring & cleanup 2025-09-06 17:44:01 +03:00
bigcat88
84384ca0b4 temporary restore ModelManager 2025-09-05 23:02:26 +03:00
bigcat88
ce270ba090 added Assets Autoscan feature 2025-09-05 17:46:09 +03:00
bigcat88
bf8363ec87 always autofill "filename" in the metadata 2025-08-29 19:48:42 +03:00
bigcat88
6b86be320a use UUID instead of autoincrement Integer for Assets ID field 2025-08-28 08:22:54 +03:00
bigcat88
bdf4ba24ce removed not needed "assets.updated_at" column 2025-08-27 21:58:17 +03:00
bigcat88
871e41aec6 removed not needed "refcount" column 2025-08-27 21:36:31 +03:00
bigcat88
eb7008a4d3 removed not used "added_by" column 2025-08-27 21:26:35 +03:00
bigcat88
0379eff0b5 allow Upload Asset endpoint to accept hash (as documentation requires) 2025-08-27 21:18:26 +03:00
bigcat88
026b7f209c add "--multi-user" support 2025-08-27 19:47:55 +03:00
bigcat88
7c1b0be496 add Get Asset endpoint 2025-08-27 09:58:12 +03:00
bigcat88
6fade5da38 add AssetsResolver support 2025-08-26 20:58:04 +03:00
bigcat88
a763cbd39d add upload asset endpoint 2025-08-25 16:35:29 +03:00
bigcat88
09dabf95bc refactoring: use the same code for "scan task" and realtime DB population 2025-08-25 13:31:56 +03:00
bigcat88
d7464e9e73 implemented assets scaner 2025-08-24 19:29:21 +03:00
bigcat88
a82577f64a auto-creation of tags and fixed population DB when cloned asset is already present 2025-08-24 16:36:01 +03:00
bigcat88
f2ea0bc22c added create_asset_from_hash endpoint 2025-08-24 14:15:21 +03:00
bigcat88
0755e5320a remove timezone; download asset, delete asset endpoints 2025-08-24 12:36:20 +03:00
bigcat88
8d46bec951 use Pydantic for output; finished Tags endpoints 2025-08-24 11:02:30 +03:00
bigcat88
5c1b5973ac dev: refactor; populate models in more nodes; use Pydantic in endpoints for input validation 2025-08-23 20:14:22 +03:00
bigcat88
f92307cd4c dev: Everything is Assets 2025-08-23 19:21:52 +03:00
Jedrzej Kosinski
c708d0a433 Merge branch 'master' into asset-management 2025-08-18 12:12:15 -07:00
Jedrzej Kosinski
1aa089e0b6 More progress on brainstorming code for asset management for models 2025-08-12 17:59:16 -07:00
Jedrzej Kosinski
f032c1a50a Brainstorming abstraction for asset management stuff 2025-08-11 22:25:42 -07:00
Jedrzej Kosinski
3089936a2c Merge branch 'master' into pysssss-model-db 2025-08-11 14:09:21 -07:00
Jedrzej Kosinski
cd679129e3 Merge branch 'master' into pysssss-model-db 2025-08-07 21:12:30 -07:00
pythongosssss
d7062277a7 fix bad merge 2025-08-03 16:40:27 +01:00
pythongosssss
54cf14cbbb Merge remote-tracking branch 'origin/master' into pysssss-model-db 2025-08-03 16:36:49 +01:00
pythongosssss
7d5160f92c Tidy 2025-06-01 15:45:15 +01:00
pythongosssss
7f7b3f1695 tidy 2025-06-01 15:41:00 +01:00
pythongosssss
9da6aca0d0 Add additional db model metadata fields and model downloading function 2025-06-01 15:32:13 +01:00
pythongosssss
1cb3c98947 Implement database & model hashing 2025-06-01 15:32:02 +01:00
118 changed files with 10907 additions and 3950 deletions

View File

@@ -1,27 +0,0 @@
As of the time of writing this you need this preview driver for best results:
https://www.amd.com/en/resources/support-articles/release-notes/RN-AMDGPU-WINDOWS-PYTORCH-PREVIEW.html
HOW TO RUN:
If you have a AMD gpu:
run_amd_gpu.bat
If you have memory issues you can try disabling the smart memory management by running comfyui with:
run_amd_gpu_disable_smart_memory.bat
IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints
You can download the stable diffusion XL one from: https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0_0.9vae.safetensors
RECOMMENDED WAY TO UPDATE:
To update the ComfyUI code: update\update_comfyui.bat
TO SHARE MODELS BETWEEN COMFYUI AND ANOTHER UI:
In the ComfyUI directory you will find a file: extra_model_paths.yaml.example
Rename this file to: extra_model_paths.yaml and edit it with your favorite text editor.

View File

@@ -1,2 +0,0 @@
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build
pause

View File

@@ -1,2 +0,0 @@
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --disable-smart-memory
pause

View File

@@ -1,61 +0,0 @@
name: "Release Stable All Portable Versions"
on:
workflow_dispatch:
inputs:
git_tag:
description: 'Git tag'
required: true
type: string
jobs:
release_nvidia_default:
permissions:
contents: "write"
packages: "write"
pull-requests: "read"
name: "Release NVIDIA Default (cu129)"
uses: ./.github/workflows/stable-release.yml
with:
git_tag: ${{ inputs.git_tag }}
cache_tag: "cu129"
python_minor: "13"
python_patch: "6"
rel_name: "nvidia"
rel_extra_name: ""
test_release: true
secrets: inherit
release_nvidia_cu128:
permissions:
contents: "write"
packages: "write"
pull-requests: "read"
name: "Release NVIDIA cu128"
uses: ./.github/workflows/stable-release.yml
with:
git_tag: ${{ inputs.git_tag }}
cache_tag: "cu128"
python_minor: "12"
python_patch: "10"
rel_name: "nvidia"
rel_extra_name: "_cu128"
test_release: true
secrets: inherit
release_amd_rocm:
permissions:
contents: "write"
packages: "write"
pull-requests: "read"
name: "Release AMD ROCm 6.4.4"
uses: ./.github/workflows/stable-release.yml
with:
git_tag: ${{ inputs.git_tag }}
cache_tag: "rocm644"
python_minor: "12"
python_patch: "10"
rel_name: "amd"
rel_extra_name: ""
test_release: false
secrets: inherit

View File

@@ -21,28 +21,3 @@ jobs:
- name: Run Ruff
run: ruff check .
pylint:
name: Run Pylint
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.12'
- name: Install requirements
run: |
python -m pip install --upgrade pip
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install -r requirements.txt
- name: Install Pylint
run: pip install pylint
- name: Run Pylint
run: pylint comfy_api_nodes

View File

@@ -2,53 +2,17 @@
name: "Release Stable Version"
on:
workflow_call:
inputs:
git_tag:
description: 'Git tag'
required: true
type: string
cache_tag:
description: 'Cached dependencies tag'
required: true
type: string
default: "cu129"
python_minor:
description: 'Python minor version'
required: true
type: string
default: "13"
python_patch:
description: 'Python patch version'
required: true
type: string
default: "6"
rel_name:
description: 'Release name'
required: true
type: string
default: "nvidia"
rel_extra_name:
description: 'Release extra name'
required: false
type: string
default: ""
test_release:
description: 'Test Release'
required: true
type: boolean
default: true
workflow_dispatch:
inputs:
git_tag:
description: 'Git tag'
required: true
type: string
cache_tag:
description: 'Cached dependencies tag'
cu:
description: 'CUDA version'
required: true
type: string
default: "cu129"
default: "129"
python_minor:
description: 'Python minor version'
required: true
@@ -59,21 +23,7 @@ on:
required: true
type: string
default: "6"
rel_name:
description: 'Release name'
required: true
type: string
default: "nvidia"
rel_extra_name:
description: 'Release extra name'
required: false
type: string
default: ""
test_release:
description: 'Test Release'
required: true
type: boolean
default: true
jobs:
package_comfy_windows:
@@ -92,15 +42,15 @@ jobs:
id: cache
with:
path: |
${{ inputs.cache_tag }}_python_deps.tar
cu${{ inputs.cu }}_python_deps.tar
update_comfyui_and_python_dependencies.bat
key: ${{ runner.os }}-build-${{ inputs.cache_tag }}-${{ inputs.python_minor }}
key: ${{ runner.os }}-build-cu${{ inputs.cu }}-${{ inputs.python_minor }}
- shell: bash
run: |
mv ${{ inputs.cache_tag }}_python_deps.tar ../
mv cu${{ inputs.cu }}_python_deps.tar ../
mv update_comfyui_and_python_dependencies.bat ../
cd ..
tar xf ${{ inputs.cache_tag }}_python_deps.tar
tar xf cu${{ inputs.cu }}_python_deps.tar
pwd
ls
@@ -115,19 +65,12 @@ jobs:
echo 'import site' >> ./python3${{ inputs.python_minor }}._pth
curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
./python.exe get-pip.py
./python.exe -s -m pip install ../${{ inputs.cache_tag }}_python_deps/*
grep comfyui ../ComfyUI/requirements.txt > ./requirements_comfyui.txt
./python.exe -s -m pip install -r requirements_comfyui.txt
rm requirements_comfyui.txt
./python.exe -s -m pip install ../cu${{ inputs.cu }}_python_deps/*
sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
if test -f ./Lib/site-packages/torch/lib/dnnl.lib; then
rm ./Lib/site-packages/torch/lib/dnnl.lib #I don't think this is actually used and I need the space
rm ./Lib/site-packages/torch/lib/libprotoc.lib
rm ./Lib/site-packages/torch/lib/libprotobuf.lib
fi
rm ./Lib/site-packages/torch/lib/dnnl.lib #I don't think this is actually used and I need the space
rm ./Lib/site-packages/torch/lib/libprotoc.lib
rm ./Lib/site-packages/torch/lib/libprotobuf.lib
cd ..
@@ -142,18 +85,14 @@ jobs:
mkdir update
cp -r ComfyUI/.ci/update_windows/* ./update/
cp -r ComfyUI/.ci/windows_${{ inputs.rel_name }}_base_files/* ./
cp -r ComfyUI/.ci/windows_base_files/* ./
cp ../update_comfyui_and_python_dependencies.bat ./update/
cd ..
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=768m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
mv ComfyUI_windows_portable.7z ComfyUI/ComfyUI_windows_portable_${{ inputs.rel_name }}${{ inputs.rel_extra_name }}.7z
mv ComfyUI_windows_portable.7z ComfyUI/ComfyUI_windows_portable_nvidia.7z
- shell: bash
if: ${{ inputs.test_release }}
run: |
cd ..
cd ComfyUI_windows_portable
python_embeded/python.exe -s ComfyUI/main.py --quick-test-for-ci --cpu
@@ -162,9 +101,10 @@ jobs:
ls
- name: Upload binaries to release
uses: softprops/action-gh-release@v2
uses: svenstaro/upload-release-action@v2
with:
files: ComfyUI_windows_portable_${{ inputs.rel_name }}${{ inputs.rel_extra_name }}.7z
tag_name: ${{ inputs.git_tag }}
repo_token: ${{ secrets.GITHUB_TOKEN }}
file: ComfyUI_windows_portable_nvidia.7z
tag: ${{ inputs.git_tag }}
overwrite: true
draft: true
overwrite_files: true

173
.github/workflows/test-assets.yml vendored Normal file
View File

@@ -0,0 +1,173 @@
name: Asset System Tests
on:
push:
paths:
- 'app/**'
- 'tests-assets/**'
- '.github/workflows/test-assets.yml'
- 'requirements.txt'
pull_request:
branches: [master]
workflow_dispatch:
permissions:
contents: read
env:
PIP_DISABLE_PIP_VERSION_CHECK: '1'
PYTHONUNBUFFERED: '1'
jobs:
sqlite:
name: SQLite (${{ matrix.sqlite_mode }}) • Python ${{ matrix.python }}
runs-on: ubuntu-latest
timeout-minutes: 40
strategy:
fail-fast: false
matrix:
python: ['3.9', '3.12']
sqlite_mode: ['memory', 'file']
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python }}
- name: Install dependencies
run: |
python -m pip install -U pip wheel
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install -r requirements.txt
pip install pytest pytest-aiohttp pytest-asyncio
- name: Set deterministic test base dir
id: basedir
shell: bash
run: |
BASE="$RUNNER_TEMP/comfyui-assets-tests-${{ matrix.python }}-${{ matrix.sqlite_mode }}-${{ github.run_id }}-${{ github.run_attempt }}"
echo "ASSETS_TEST_BASE_DIR=$BASE" >> "$GITHUB_ENV"
echo "ASSETS_TEST_LOGS=$BASE/logs" >> "$GITHUB_ENV"
mkdir -p "$BASE/logs"
echo "ASSETS_TEST_BASE_DIR=$BASE"
- name: Set DB URL for SQLite
id: setdb
shell: bash
run: |
if [ "${{ matrix.sqlite_mode }}" = "memory" ]; then
echo "ASSETS_TEST_DB_URL=sqlite+aiosqlite:///:memory:" >> "$GITHUB_ENV"
else
DBFILE="$RUNNER_TEMP/assets-tests.sqlite"
mkdir -p "$(dirname "$DBFILE")"
echo "ASSETS_TEST_DB_URL=sqlite+aiosqlite:///$DBFILE" >> "$GITHUB_ENV"
fi
- name: Run tests
run: python -m pytest tests-assets
- name: Show ComfyUI logs
if: always()
shell: bash
run: |
echo "==== ASSETS_TEST_BASE_DIR: $ASSETS_TEST_BASE_DIR ===="
echo "==== ASSETS_TEST_LOGS: $ASSETS_TEST_LOGS ===="
ls -la "$ASSETS_TEST_LOGS" || true
for f in "$ASSETS_TEST_LOGS"/stdout.log "$ASSETS_TEST_LOGS"/stderr.log; do
if [ -f "$f" ]; then
echo "----- BEGIN $f -----"
sed -n '1,400p' "$f"
echo "----- END $f -----"
fi
done
- name: Upload ComfyUI logs
if: always()
uses: actions/upload-artifact@v4
with:
name: asset-logs-sqlite-${{ matrix.sqlite_mode }}-py${{ matrix.python }}
path: ${{ env.ASSETS_TEST_LOGS }}/*.log
if-no-files-found: warn
postgres:
name: PostgreSQL ${{ matrix.pgsql }} • Python ${{ matrix.python }}
runs-on: ubuntu-latest
timeout-minutes: 40
strategy:
fail-fast: false
matrix:
python: ['3.9', '3.12']
pgsql: ['16', '18']
services:
postgres:
image: postgres:${{ matrix.pgsql }}
env:
POSTGRES_DB: assets
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
ports:
- 5432:5432
options: >-
--health-cmd "pg_isready -U postgres -d assets"
--health-interval 10s
--health-timeout 5s
--health-retries 12
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python }}
- name: Install dependencies
run: |
python -m pip install -U pip wheel
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install -r requirements.txt
pip install pytest pytest-aiohttp pytest-asyncio
pip install greenlet psycopg
- name: Set deterministic test base dir
id: basedir
shell: bash
run: |
BASE="$RUNNER_TEMP/comfyui-assets-tests-${{ matrix.python }}-${{ matrix.sqlite_mode }}-${{ github.run_id }}-${{ github.run_attempt }}"
echo "ASSETS_TEST_BASE_DIR=$BASE" >> "$GITHUB_ENV"
echo "ASSETS_TEST_LOGS=$BASE/logs" >> "$GITHUB_ENV"
mkdir -p "$BASE/logs"
echo "ASSETS_TEST_BASE_DIR=$BASE"
- name: Set DB URL for PostgreSQL
shell: bash
run: |
echo "ASSETS_TEST_DB_URL=postgresql+psycopg://postgres:postgres@localhost:5432/assets" >> "$GITHUB_ENV"
- name: Run tests
run: python -m pytest tests-assets
- name: Show ComfyUI logs
if: always()
shell: bash
run: |
echo "==== ASSETS_TEST_BASE_DIR: $ASSETS_TEST_BASE_DIR ===="
echo "==== ASSETS_TEST_LOGS: $ASSETS_TEST_LOGS ===="
ls -la "$ASSETS_TEST_LOGS" || true
for f in "$ASSETS_TEST_LOGS"/stdout.log "$ASSETS_TEST_LOGS"/stderr.log; do
if [ -f "$f" ]; then
echo "----- BEGIN $f -----"
sed -n '1,400p' "$f"
echo "----- END $f -----"
fi
done
- name: Upload ComfyUI logs
if: always()
uses: actions/upload-artifact@v4
with:
name: asset-logs-pgsql-${{ matrix.pgsql }}-py${{ matrix.python }}
path: ${{ env.ASSETS_TEST_LOGS }}/*.log
if-no-files-found: warn

View File

@@ -56,8 +56,7 @@ jobs:
..\python_embeded\python.exe -s -m pip install --upgrade torch torchvision torchaudio ${{ inputs.xformers }} --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2
pause" > update_comfyui_and_python_dependencies.bat
grep -v comfyui requirements.txt > requirements_nocomfyui.txt
python -m pip wheel --no-cache-dir torch torchvision torchaudio ${{ inputs.xformers }} ${{ inputs.extra_dependencies }} --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cu }} -r requirements_nocomfyui.txt pygit2 -w ./temp_wheel_dir
python -m pip wheel --no-cache-dir torch torchvision torchaudio ${{ inputs.xformers }} ${{ inputs.extra_dependencies }} --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cu }} -r requirements.txt pygit2 -w ./temp_wheel_dir
python -m pip install --no-cache-dir ./temp_wheel_dir/*
echo installed basic
ls -lah temp_wheel_dir

View File

@@ -1,64 +0,0 @@
name: "Windows Release dependencies Manual"
on:
workflow_dispatch:
inputs:
torch_dependencies:
description: 'torch dependencies'
required: false
type: string
default: "torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu128"
cache_tag:
description: 'Cached dependencies tag'
required: true
type: string
default: "cu128"
python_minor:
description: 'python minor version'
required: true
type: string
default: "12"
python_patch:
description: 'python patch version'
required: true
type: string
default: "10"
jobs:
build_dependencies:
runs-on: windows-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: 3.${{ inputs.python_minor }}.${{ inputs.python_patch }}
- shell: bash
run: |
echo "@echo off
call update_comfyui.bat nopause
echo -
echo This will try to update pytorch and all python dependencies.
echo -
echo If you just want to update normally, close this and run update_comfyui.bat instead.
echo -
pause
..\python_embeded\python.exe -s -m pip install --upgrade ${{ inputs.torch_dependencies }} -r ../ComfyUI/requirements.txt pygit2
pause" > update_comfyui_and_python_dependencies.bat
grep -v comfyui requirements.txt > requirements_nocomfyui.txt
python -m pip wheel --no-cache-dir ${{ inputs.torch_dependencies }} -r requirements_nocomfyui.txt pygit2 -w ./temp_wheel_dir
python -m pip install --no-cache-dir ./temp_wheel_dir/*
echo installed basic
ls -lah temp_wheel_dir
mv temp_wheel_dir ${{ inputs.cache_tag }}_python_deps
tar cf ${{ inputs.cache_tag }}_python_deps.tar ${{ inputs.cache_tag }}_python_deps
- uses: actions/cache/save@v4
with:
path: |
${{ inputs.cache_tag }}_python_deps.tar
update_comfyui_and_python_dependencies.bat
key: ${{ runner.os }}-build-${{ inputs.cache_tag }}-${{ inputs.python_minor }}

View File

@@ -68,7 +68,7 @@ jobs:
mkdir update
cp -r ComfyUI/.ci/update_windows/* ./update/
cp -r ComfyUI/.ci/windows_nvidia_base_files/* ./
cp -r ComfyUI/.ci/windows_base_files/* ./
cp -r ComfyUI/.ci/windows_nightly_base_files/* ./
echo "call update_comfyui.bat nopause

View File

@@ -81,7 +81,7 @@ jobs:
mkdir update
cp -r ComfyUI/.ci/update_windows/* ./update/
cp -r ComfyUI/.ci/windows_nvidia_base_files/* ./
cp -r ComfyUI/.ci/windows_base_files/* ./
cp ../update_comfyui_and_python_dependencies.bat ./update/
cd ..

View File

@@ -176,12 +176,6 @@ Simply download, extract with [7-Zip](https://7-zip.org) and run. Make sure you
If you have trouble extracting it, right click the file -> properties -> unblock
#### Alternative Downloads:
[Experimental portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z)
[Portable with pytorch cuda 12.8 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu128.7z) (Supports Nvidia 10 series and older GPUs).
#### How do I share models between another UI and ComfyUI?
See the [Config file](extra_model_paths.yaml.example) to set the search paths for models. In the standalone windows build you can find this file in the ComfyUI directory. Rename this file to extra_model_paths.yaml and edit it with your favorite text editor.
@@ -206,32 +200,14 @@ Put your SD checkpoints (the huge ckpt/safetensors files) in: models/checkpoints
Put your VAE in: models/vae
### AMD GPUs (Linux)
### AMD GPUs (Linux only)
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.4```
This is the command to install the nightly with ROCm 7.0 which might have some performance improvements:
This is the command to install the nightly with ROCm 6.4 which might have some performance improvements:
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.0```
### AMD GPUs (Experimental: Windows and Linux), RDNA 3, 3.5 and 4 only.
These have less hardware support than the builds above but they work on windows. You also need to install the pytorch version specific to your hardware.
RDNA 3 (RX 7000 series):
```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx110X-dgpu/```
RDNA 3.5 (Strix halo/Ryzen AI Max+ 365):
```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx1151/```
RDNA 4 (RX 9000 series):
```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx120X-all/```
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.4```
### Intel GPUs (Windows and Linux)
@@ -257,7 +233,7 @@ Nvidia users should install stable pytorch using this command:
This is the command to install pytorch nightly instead which might have performance improvements.
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu130```
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu129```
#### Troubleshooting
@@ -288,6 +264,12 @@ You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS ve
> **Note**: Remember to add your models, VAE, LoRAs etc. to the corresponding Comfy folders, as discussed in [ComfyUI manual installation](#manual-install-windows-linux).
#### DirectML (AMD Cards on Windows)
This is very badly supported and is not recommended. There are some unofficial builds of pytorch ROCm on windows that exist that will give you a much better experience than this. This readme will be updated once official pytorch ROCm builds for windows come out.
```pip install torch-directml``` Then you can launch ComfyUI with: ```python main.py --directml```
#### Ascend NPUs
For models compatible with Ascend Extension for PyTorch (torch_npu). To get started, ensure your environment meets the prerequisites outlined on the [installation](https://ascend.github.io/docs/sources/ascend/quick_install.html) page. Here's a step-by-step guide tailored to your platform and installation method:

View File

@@ -3,7 +3,7 @@
[alembic]
# path to migration scripts
# Use forward slashes (/) also on windows to provide an os agnostic path
script_location = alembic_db
script_location = app/alembic_db
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time

View File

@@ -2,13 +2,12 @@ from sqlalchemy import engine_from_config
from sqlalchemy import pool
from alembic import context
from app.assets.database.models import Base
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
from app.database.models import Base
target_metadata = Base.metadata
# other values from the config, defined by the needs of env.py,

View File

@@ -0,0 +1,175 @@
"""initial assets schema
Revision ID: 0001_assets
Revises:
Create Date: 2025-08-20 00:00:00
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
revision = "0001_assets"
down_revision = None
branch_labels = None
depends_on = None
def upgrade() -> None:
# ASSETS: content identity
op.create_table(
"assets",
sa.Column("id", sa.String(length=36), primary_key=True),
sa.Column("hash", sa.String(length=256), nullable=True),
sa.Column("size_bytes", sa.BigInteger(), nullable=False, server_default="0"),
sa.Column("mime_type", sa.String(length=255), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=False), nullable=False),
sa.CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"),
)
op.create_index("uq_assets_hash", "assets", ["hash"], unique=True)
op.create_index("ix_assets_mime_type", "assets", ["mime_type"])
# ASSETS_INFO: user-visible references
op.create_table(
"assets_info",
sa.Column("id", sa.String(length=36), primary_key=True),
sa.Column("owner_id", sa.String(length=128), nullable=False, server_default=""),
sa.Column("name", sa.String(length=512), nullable=False),
sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False),
sa.Column("preview_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="SET NULL"), nullable=True),
sa.Column("user_metadata", sa.JSON(), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=False), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=False), nullable=False),
sa.Column("last_access_time", sa.DateTime(timezone=False), nullable=False),
sa.UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"),
)
op.create_index("ix_assets_info_owner_id", "assets_info", ["owner_id"])
op.create_index("ix_assets_info_asset_id", "assets_info", ["asset_id"])
op.create_index("ix_assets_info_name", "assets_info", ["name"])
op.create_index("ix_assets_info_created_at", "assets_info", ["created_at"])
op.create_index("ix_assets_info_last_access_time", "assets_info", ["last_access_time"])
op.create_index("ix_assets_info_owner_name", "assets_info", ["owner_id", "name"])
# TAGS: normalized tag vocabulary
op.create_table(
"tags",
sa.Column("name", sa.String(length=512), primary_key=True),
sa.Column("tag_type", sa.String(length=32), nullable=False, server_default="user"),
sa.CheckConstraint("name = lower(name)", name="ck_tags_lowercase"),
)
op.create_index("ix_tags_tag_type", "tags", ["tag_type"])
# ASSET_INFO_TAGS: many-to-many for tags on AssetInfo
op.create_table(
"asset_info_tags",
sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False),
sa.Column("tag_name", sa.String(length=512), sa.ForeignKey("tags.name", ondelete="RESTRICT"), nullable=False),
sa.Column("origin", sa.String(length=32), nullable=False, server_default="manual"),
sa.Column("added_at", sa.DateTime(timezone=False), nullable=False),
sa.PrimaryKeyConstraint("asset_info_id", "tag_name", name="pk_asset_info_tags"),
)
op.create_index("ix_asset_info_tags_tag_name", "asset_info_tags", ["tag_name"])
op.create_index("ix_asset_info_tags_asset_info_id", "asset_info_tags", ["asset_info_id"])
# ASSET_CACHE_STATE: N:1 local cache rows per Asset
op.create_table(
"asset_cache_state",
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="CASCADE"), nullable=False),
sa.Column("file_path", sa.Text(), nullable=False), # absolute local path to cached file
sa.Column("mtime_ns", sa.BigInteger(), nullable=True),
sa.Column("needs_verify", sa.Boolean(), nullable=False, server_default=sa.text("false")),
sa.CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
sa.UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"),
)
op.create_index("ix_asset_cache_state_file_path", "asset_cache_state", ["file_path"])
op.create_index("ix_asset_cache_state_asset_id", "asset_cache_state", ["asset_id"])
# ASSET_INFO_META: typed KV projection of user_metadata for filtering/sorting
op.create_table(
"asset_info_meta",
sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False),
sa.Column("key", sa.String(length=256), nullable=False),
sa.Column("ordinal", sa.Integer(), nullable=False, server_default="0"),
sa.Column("val_str", sa.String(length=2048), nullable=True),
sa.Column("val_num", sa.Numeric(38, 10), nullable=True),
sa.Column("val_bool", sa.Boolean(), nullable=True),
sa.Column("val_json", sa.JSON().with_variant(postgresql.JSONB(), 'postgresql'), nullable=True),
sa.PrimaryKeyConstraint("asset_info_id", "key", "ordinal", name="pk_asset_info_meta"),
)
op.create_index("ix_asset_info_meta_key", "asset_info_meta", ["key"])
op.create_index("ix_asset_info_meta_key_val_str", "asset_info_meta", ["key", "val_str"])
op.create_index("ix_asset_info_meta_key_val_num", "asset_info_meta", ["key", "val_num"])
op.create_index("ix_asset_info_meta_key_val_bool", "asset_info_meta", ["key", "val_bool"])
# Tags vocabulary
tags_table = sa.table(
"tags",
sa.column("name", sa.String(length=512)),
sa.column("tag_type", sa.String()),
)
op.bulk_insert(
tags_table,
[
{"name": "models", "tag_type": "system"},
{"name": "input", "tag_type": "system"},
{"name": "output", "tag_type": "system"},
{"name": "configs", "tag_type": "system"},
{"name": "checkpoints", "tag_type": "system"},
{"name": "loras", "tag_type": "system"},
{"name": "vae", "tag_type": "system"},
{"name": "text_encoders", "tag_type": "system"},
{"name": "diffusion_models", "tag_type": "system"},
{"name": "clip_vision", "tag_type": "system"},
{"name": "style_models", "tag_type": "system"},
{"name": "embeddings", "tag_type": "system"},
{"name": "diffusers", "tag_type": "system"},
{"name": "vae_approx", "tag_type": "system"},
{"name": "controlnet", "tag_type": "system"},
{"name": "gligen", "tag_type": "system"},
{"name": "upscale_models", "tag_type": "system"},
{"name": "hypernetworks", "tag_type": "system"},
{"name": "photomaker", "tag_type": "system"},
{"name": "classifiers", "tag_type": "system"},
{"name": "encoder", "tag_type": "system"},
{"name": "decoder", "tag_type": "system"},
{"name": "missing", "tag_type": "system"},
{"name": "rescan", "tag_type": "system"},
],
)
def downgrade() -> None:
op.drop_index("ix_asset_info_meta_key_val_bool", table_name="asset_info_meta")
op.drop_index("ix_asset_info_meta_key_val_num", table_name="asset_info_meta")
op.drop_index("ix_asset_info_meta_key_val_str", table_name="asset_info_meta")
op.drop_index("ix_asset_info_meta_key", table_name="asset_info_meta")
op.drop_table("asset_info_meta")
op.drop_index("ix_asset_cache_state_asset_id", table_name="asset_cache_state")
op.drop_index("ix_asset_cache_state_file_path", table_name="asset_cache_state")
op.drop_constraint("uq_asset_cache_state_file_path", table_name="asset_cache_state")
op.drop_table("asset_cache_state")
op.drop_index("ix_asset_info_tags_asset_info_id", table_name="asset_info_tags")
op.drop_index("ix_asset_info_tags_tag_name", table_name="asset_info_tags")
op.drop_table("asset_info_tags")
op.drop_index("ix_tags_tag_type", table_name="tags")
op.drop_table("tags")
op.drop_constraint("uq_assets_info_asset_owner_name", table_name="assets_info")
op.drop_index("ix_assets_info_owner_name", table_name="assets_info")
op.drop_index("ix_assets_info_last_access_time", table_name="assets_info")
op.drop_index("ix_assets_info_created_at", table_name="assets_info")
op.drop_index("ix_assets_info_name", table_name="assets_info")
op.drop_index("ix_assets_info_asset_id", table_name="assets_info")
op.drop_index("ix_assets_info_owner_id", table_name="assets_info")
op.drop_table("assets_info")
op.drop_index("uq_assets_hash", table_name="assets")
op.drop_index("ix_assets_mime_type", table_name="assets")
op.drop_table("assets")

4
app/assets/__init__.py Normal file
View File

@@ -0,0 +1,4 @@
from .api.routes import register_assets_system
from .scanner import sync_seed_assets
__all__ = ["sync_seed_assets", "register_assets_system"]

225
app/assets/_helpers.py Normal file
View File

@@ -0,0 +1,225 @@
import contextlib
import os
import uuid
from datetime import datetime, timezone
from pathlib import Path
from typing import Literal, Optional, Sequence
import folder_paths
from .api import schemas_in
def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
"""Build a list of (folder_name, base_paths[]) categories that are configured for model locations.
We trust `folder_paths.folder_names_and_paths` and include a category if
*any* of its base paths lies under the Comfy `models_dir`.
"""
targets: list[tuple[str, list[str]]] = []
models_root = os.path.abspath(folder_paths.models_dir)
for name, (paths, _exts) in folder_paths.folder_names_and_paths.items():
if any(os.path.abspath(p).startswith(models_root + os.sep) for p in paths):
targets.append((name, paths))
return targets
def get_relative_to_root_category_path_of_asset(file_path: str) -> tuple[Literal["input", "output", "models"], str]:
"""Given an absolute or relative file path, determine which root category the path belongs to:
- 'input' if the file resides under `folder_paths.get_input_directory()`
- 'output' if the file resides under `folder_paths.get_output_directory()`
- 'models' if the file resides under any base path of categories returned by `get_comfy_models_folders()`
Returns:
(root_category, relative_path_inside_that_root)
For 'models', the relative path is prefixed with the category name:
e.g. ('models', 'vae/test/sub/ae.safetensors')
Raises:
ValueError: if the path does not belong to input, output, or configured model bases.
"""
fp_abs = os.path.abspath(file_path)
def _is_within(child: str, parent: str) -> bool:
try:
return os.path.commonpath([child, parent]) == parent
except Exception:
return False
def _rel(child: str, parent: str) -> str:
return os.path.relpath(os.path.join(os.sep, os.path.relpath(child, parent)), os.sep)
# 1) input
input_base = os.path.abspath(folder_paths.get_input_directory())
if _is_within(fp_abs, input_base):
return "input", _rel(fp_abs, input_base)
# 2) output
output_base = os.path.abspath(folder_paths.get_output_directory())
if _is_within(fp_abs, output_base):
return "output", _rel(fp_abs, output_base)
# 3) models (check deepest matching base to avoid ambiguity)
best: Optional[tuple[int, str, str]] = None # (base_len, bucket, rel_inside_bucket)
for bucket, bases in get_comfy_models_folders():
for b in bases:
base_abs = os.path.abspath(b)
if not _is_within(fp_abs, base_abs):
continue
cand = (len(base_abs), bucket, _rel(fp_abs, base_abs))
if best is None or cand[0] > best[0]:
best = cand
if best is not None:
_, bucket, rel_inside = best
combined = os.path.join(bucket, rel_inside)
return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep)
raise ValueError(f"Path is not within input, output, or configured model bases: {file_path}")
def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]:
"""Return a tuple (name, tags) derived from a filesystem path.
Semantics:
- Root category is determined by `get_relative_to_root_category_path_of_asset`.
- The returned `name` is the base filename with extension from the relative path.
- The returned `tags` are:
[root_category] + parent folders of the relative path (in order)
For 'models', this means:
file '/.../ModelsDir/vae/test_tag/ae.safetensors'
-> root_category='models', some_path='vae/test_tag/ae.safetensors'
-> name='ae.safetensors', tags=['models', 'vae', 'test_tag']
Raises:
ValueError: if the path does not belong to input, output, or configured model bases.
"""
root_category, some_path = get_relative_to_root_category_path_of_asset(file_path)
p = Path(some_path)
parent_parts = [part for part in p.parent.parts if part not in (".", "..", p.anchor)]
return p.name, list(dict.fromkeys(normalize_tags([root_category, *parent_parts])))
def normalize_tags(tags: Optional[Sequence[str]]) -> list[str]:
return [t.strip().lower() for t in (tags or []) if (t or "").strip()]
def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]:
"""Validates and maps tags -> (base_dir, subdirs_for_fs)"""
root = tags[0]
if root == "models":
if len(tags) < 2:
raise ValueError("at least two tags required for model asset")
try:
bases = folder_paths.folder_names_and_paths[tags[1]][0]
except KeyError:
raise ValueError(f"unknown model category '{tags[1]}'")
if not bases:
raise ValueError(f"no base path configured for category '{tags[1]}'")
base_dir = os.path.abspath(bases[0])
raw_subdirs = tags[2:]
else:
base_dir = os.path.abspath(
folder_paths.get_input_directory() if root == "input" else folder_paths.get_output_directory()
)
raw_subdirs = tags[1:]
for i in raw_subdirs:
if i in (".", ".."):
raise ValueError("invalid path component in tags")
return base_dir, raw_subdirs if raw_subdirs else []
def ensure_within_base(candidate: str, base: str) -> None:
cand_abs = os.path.abspath(candidate)
base_abs = os.path.abspath(base)
try:
if os.path.commonpath([cand_abs, base_abs]) != base_abs:
raise ValueError("destination escapes base directory")
except Exception:
raise ValueError("invalid destination path")
def compute_relative_filename(file_path: str) -> Optional[str]:
"""
Return the model's path relative to the last well-known folder (the model category),
using forward slashes, eg:
/.../models/checkpoints/flux/123/flux.safetensors -> "flux/123/flux.safetensors"
/.../models/text_encoders/clip_g.safetensors -> "clip_g.safetensors"
For non-model paths, returns None.
NOTE: this is a temporary helper, used only for initializing metadata["filename"] field.
"""
try:
root_category, rel_path = get_relative_to_root_category_path_of_asset(file_path)
except ValueError:
return None
p = Path(rel_path)
parts = [seg for seg in p.parts if seg not in (".", "..", p.anchor)]
if not parts:
return None
if root_category == "models":
# parts[0] is the category ("checkpoints", "vae", etc) drop it
inside = parts[1:] if len(parts) > 1 else [parts[0]]
return "/".join(inside)
return "/".join(parts) # input/output: keep all parts
def list_tree(base_dir: str) -> list[str]:
out: list[str] = []
base_abs = os.path.abspath(base_dir)
if not os.path.isdir(base_abs):
return out
for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False):
for name in filenames:
out.append(os.path.abspath(os.path.join(dirpath, name)))
return out
def prefixes_for_root(root: schemas_in.RootType) -> list[str]:
if root == "models":
bases: list[str] = []
for _bucket, paths in get_comfy_models_folders():
bases.extend(paths)
return [os.path.abspath(p) for p in bases]
if root == "input":
return [os.path.abspath(folder_paths.get_input_directory())]
if root == "output":
return [os.path.abspath(folder_paths.get_output_directory())]
return []
def ts_to_iso(ts: Optional[float]) -> Optional[str]:
if ts is None:
return None
try:
return datetime.fromtimestamp(float(ts), tz=timezone.utc).replace(tzinfo=None).isoformat()
except Exception:
return None
def new_scan_id(root: schemas_in.RootType) -> str:
return f"scan-{root}-{uuid.uuid4().hex[:8]}"
def collect_models_files() -> list[str]:
out: list[str] = []
for folder_name, bases in get_comfy_models_folders():
rel_files = folder_paths.get_filename_list(folder_name) or []
for rel_path in rel_files:
abs_path = folder_paths.get_full_path(folder_name, rel_path)
if not abs_path:
continue
abs_path = os.path.abspath(abs_path)
allowed = False
for b in bases:
base_abs = os.path.abspath(b)
with contextlib.suppress(Exception):
if os.path.commonpath([abs_path, base_abs]) == base_abs:
allowed = True
break
if allowed:
out.append(abs_path)
return out

544
app/assets/api/routes.py Normal file
View File

@@ -0,0 +1,544 @@
import contextlib
import logging
import os
import urllib.parse
import uuid
from typing import Optional
from aiohttp import web
from pydantic import ValidationError
import folder_paths
from ... import user_manager
from .. import manager, scanner
from . import schemas_in, schemas_out
ROUTES = web.RouteTableDef()
USER_MANAGER: Optional[user_manager.UserManager] = None
LOGGER = logging.getLogger(__name__)
# UUID regex (canonical hyphenated form, case-insensitive)
UUID_RE = r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}"
@ROUTES.head("/api/assets/hash/{hash}")
async def head_asset_by_hash(request: web.Request) -> web.Response:
hash_str = request.match_info.get("hash", "").strip().lower()
if not hash_str or ":" not in hash_str:
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
algo, digest = hash_str.split(":", 1)
if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"):
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
exists = await manager.asset_exists(asset_hash=hash_str)
return web.Response(status=200 if exists else 404)
@ROUTES.get("/api/assets")
async def list_assets(request: web.Request) -> web.Response:
qp = request.rel_url.query
query_dict = {}
if "include_tags" in qp:
query_dict["include_tags"] = qp.getall("include_tags")
if "exclude_tags" in qp:
query_dict["exclude_tags"] = qp.getall("exclude_tags")
for k in ("name_contains", "metadata_filter", "limit", "offset", "sort", "order"):
v = qp.get(k)
if v is not None:
query_dict[k] = v
try:
q = schemas_in.ListAssetsQuery.model_validate(query_dict)
except ValidationError as ve:
return _validation_error_response("INVALID_QUERY", ve)
payload = await manager.list_assets(
include_tags=q.include_tags,
exclude_tags=q.exclude_tags,
name_contains=q.name_contains,
metadata_filter=q.metadata_filter,
limit=q.limit,
offset=q.offset,
sort=q.sort,
order=q.order,
owner_id=USER_MANAGER.get_request_user_id(request),
)
return web.json_response(payload.model_dump(mode="json"))
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}/content")
async def download_asset_content(request: web.Request) -> web.Response:
disposition = request.query.get("disposition", "attachment").lower().strip()
if disposition not in {"inline", "attachment"}:
disposition = "attachment"
try:
abs_path, content_type, filename = await manager.resolve_asset_content_for_download(
asset_info_id=str(uuid.UUID(request.match_info["id"])),
owner_id=USER_MANAGER.get_request_user_id(request),
)
except ValueError as ve:
return _error_response(404, "ASSET_NOT_FOUND", str(ve))
except NotImplementedError as nie:
return _error_response(501, "BACKEND_UNSUPPORTED", str(nie))
except FileNotFoundError:
return _error_response(404, "FILE_NOT_FOUND", "Underlying file not found on disk.")
quoted = (filename or "").replace("\r", "").replace("\n", "").replace('"', "'")
cd = f'{disposition}; filename="{quoted}"; filename*=UTF-8\'\'{urllib.parse.quote(filename)}'
resp = web.FileResponse(abs_path)
resp.content_type = content_type
resp.headers["Content-Disposition"] = cd
return resp
@ROUTES.post("/api/assets/from-hash")
async def create_asset_from_hash(request: web.Request) -> web.Response:
try:
payload = await request.json()
body = schemas_in.CreateFromHashBody.model_validate(payload)
except ValidationError as ve:
return _validation_error_response("INVALID_BODY", ve)
except Exception:
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
result = await manager.create_asset_from_hash(
hash_str=body.hash,
name=body.name,
tags=body.tags,
user_metadata=body.user_metadata,
owner_id=USER_MANAGER.get_request_user_id(request),
)
if result is None:
return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist")
return web.json_response(result.model_dump(mode="json"), status=201)
@ROUTES.post("/api/assets")
async def upload_asset(request: web.Request) -> web.Response:
"""Multipart/form-data endpoint for Asset uploads."""
if not (request.content_type or "").lower().startswith("multipart/"):
return _error_response(415, "UNSUPPORTED_MEDIA_TYPE", "Use multipart/form-data for uploads.")
reader = await request.multipart()
file_present = False
file_client_name: Optional[str] = None
tags_raw: list[str] = []
provided_name: Optional[str] = None
user_metadata_raw: Optional[str] = None
provided_hash: Optional[str] = None
provided_hash_exists: Optional[bool] = None
file_written = 0
tmp_path: Optional[str] = None
while True:
field = await reader.next()
if field is None:
break
fname = getattr(field, "name", "") or ""
if fname == "hash":
try:
s = ((await field.text()) or "").strip().lower()
except Exception:
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
if s:
if ":" not in s:
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
algo, digest = s.split(":", 1)
if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"):
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
provided_hash = f"{algo}:{digest}"
try:
provided_hash_exists = await manager.asset_exists(asset_hash=provided_hash)
except Exception:
provided_hash_exists = None # do not fail the whole request here
elif fname == "file":
file_present = True
file_client_name = (field.filename or "").strip()
if provided_hash and provided_hash_exists is True:
# If client supplied a hash that we know exists, drain but do not write to disk
try:
while True:
chunk = await field.read_chunk(8 * 1024 * 1024)
if not chunk:
break
file_written += len(chunk)
except Exception:
return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive uploaded file.")
continue # Do not create temp file; we will create AssetInfo from the existing content
# Otherwise, store to temp for hashing/ingest
uploads_root = os.path.join(folder_paths.get_temp_directory(), "uploads")
unique_dir = os.path.join(uploads_root, uuid.uuid4().hex)
os.makedirs(unique_dir, exist_ok=True)
tmp_path = os.path.join(unique_dir, ".upload.part")
try:
with open(tmp_path, "wb") as f:
while True:
chunk = await field.read_chunk(8 * 1024 * 1024)
if not chunk:
break
f.write(chunk)
file_written += len(chunk)
except Exception:
try:
if os.path.exists(tmp_path or ""):
os.remove(tmp_path)
finally:
return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive and store uploaded file.")
elif fname == "tags":
tags_raw.append((await field.text()) or "")
elif fname == "name":
provided_name = (await field.text()) or None
elif fname == "user_metadata":
user_metadata_raw = (await field.text()) or None
# If client did not send file, and we are not doing a from-hash fast path -> error
if not file_present and not (provided_hash and provided_hash_exists):
return _error_response(400, "MISSING_FILE", "Form must include a 'file' part or a known 'hash'.")
if file_present and file_written == 0 and not (provided_hash and provided_hash_exists):
# Empty upload is only acceptable if we are fast-pathing from existing hash
try:
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
finally:
return _error_response(400, "EMPTY_UPLOAD", "Uploaded file is empty.")
try:
spec = schemas_in.UploadAssetSpec.model_validate({
"tags": tags_raw,
"name": provided_name,
"user_metadata": user_metadata_raw,
"hash": provided_hash,
})
except ValidationError as ve:
try:
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
finally:
return _validation_error_response("INVALID_BODY", ve)
# Validate models category against configured folders (consistent with previous behavior)
if spec.tags and spec.tags[0] == "models":
if len(spec.tags) < 2 or spec.tags[1] not in folder_paths.folder_names_and_paths:
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
return _error_response(
400, "INVALID_BODY", f"unknown models category '{spec.tags[1] if len(spec.tags) >= 2 else ''}'"
)
owner_id = USER_MANAGER.get_request_user_id(request)
# Fast path: if a valid provided hash exists, create AssetInfo without writing anything
if spec.hash and provided_hash_exists is True:
try:
result = await manager.create_asset_from_hash(
hash_str=spec.hash,
name=spec.name or (spec.hash.split(":", 1)[1]),
tags=spec.tags,
user_metadata=spec.user_metadata or {},
owner_id=owner_id,
)
except Exception:
LOGGER.exception("create_asset_from_hash failed for hash=%s, owner_id=%s", spec.hash, owner_id)
return _error_response(500, "INTERNAL", "Unexpected server error.")
if result is None:
return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {spec.hash} does not exist")
# Drain temp if we accidentally saved (e.g., hash field came after file)
if tmp_path and os.path.exists(tmp_path):
with contextlib.suppress(Exception):
os.remove(tmp_path)
status = 200 if (not result.created_new) else 201
return web.json_response(result.model_dump(mode="json"), status=status)
# Otherwise, we must have a temp file path to ingest
if not tmp_path or not os.path.exists(tmp_path):
# The only case we reach here without a temp file is: client sent a hash that does not exist and no file
return _error_response(404, "ASSET_NOT_FOUND", "Provided hash not found and no file uploaded.")
try:
created = await manager.upload_asset_from_temp_path(
spec,
temp_path=tmp_path,
client_filename=file_client_name,
owner_id=owner_id,
expected_asset_hash=spec.hash,
)
status = 201 if created.created_new else 200
return web.json_response(created.model_dump(mode="json"), status=status)
except ValueError as e:
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
msg = str(e)
if "HASH_MISMATCH" in msg or msg.strip().upper() == "HASH_MISMATCH":
return _error_response(
400,
"HASH_MISMATCH",
"Uploaded file hash does not match provided hash.",
)
return _error_response(400, "BAD_REQUEST", "Invalid inputs.")
except Exception:
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
LOGGER.exception("upload_asset_from_temp_path failed for tmp_path=%s, owner_id=%s", tmp_path, owner_id)
return _error_response(500, "INTERNAL", "Unexpected server error.")
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}")
async def get_asset(request: web.Request) -> web.Response:
asset_info_id = str(uuid.UUID(request.match_info["id"]))
try:
result = await manager.get_asset(
asset_info_id=asset_info_id,
owner_id=USER_MANAGER.get_request_user_id(request),
)
except ValueError as ve:
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
except Exception:
LOGGER.exception(
"get_asset failed for asset_info_id=%s, owner_id=%s",
asset_info_id,
USER_MANAGER.get_request_user_id(request),
)
return _error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(result.model_dump(mode="json"), status=200)
@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}")
async def update_asset(request: web.Request) -> web.Response:
asset_info_id = str(uuid.UUID(request.match_info["id"]))
try:
body = schemas_in.UpdateAssetBody.model_validate(await request.json())
except ValidationError as ve:
return _validation_error_response("INVALID_BODY", ve)
except Exception:
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
try:
result = await manager.update_asset(
asset_info_id=asset_info_id,
name=body.name,
tags=body.tags,
user_metadata=body.user_metadata,
owner_id=USER_MANAGER.get_request_user_id(request),
)
except (ValueError, PermissionError) as ve:
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
except Exception:
LOGGER.exception(
"update_asset failed for asset_info_id=%s, owner_id=%s",
asset_info_id,
USER_MANAGER.get_request_user_id(request),
)
return _error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(result.model_dump(mode="json"), status=200)
@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}/preview")
async def set_asset_preview(request: web.Request) -> web.Response:
asset_info_id = str(uuid.UUID(request.match_info["id"]))
try:
body = schemas_in.SetPreviewBody.model_validate(await request.json())
except ValidationError as ve:
return _validation_error_response("INVALID_BODY", ve)
except Exception:
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
try:
result = await manager.set_asset_preview(
asset_info_id=asset_info_id,
preview_asset_id=body.preview_id,
owner_id=USER_MANAGER.get_request_user_id(request),
)
except (PermissionError, ValueError) as ve:
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
except Exception:
LOGGER.exception(
"set_asset_preview failed for asset_info_id=%s, owner_id=%s",
asset_info_id,
USER_MANAGER.get_request_user_id(request),
)
return _error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(result.model_dump(mode="json"), status=200)
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}")
async def delete_asset(request: web.Request) -> web.Response:
asset_info_id = str(uuid.UUID(request.match_info["id"]))
delete_content = request.query.get("delete_content")
delete_content = True if delete_content is None else delete_content.lower() not in {"0", "false", "no"}
try:
deleted = await manager.delete_asset_reference(
asset_info_id=asset_info_id,
owner_id=USER_MANAGER.get_request_user_id(request),
delete_content_if_orphan=delete_content,
)
except Exception:
LOGGER.exception(
"delete_asset_reference failed for asset_info_id=%s, owner_id=%s",
asset_info_id,
USER_MANAGER.get_request_user_id(request),
)
return _error_response(500, "INTERNAL", "Unexpected server error.")
if not deleted:
return _error_response(404, "ASSET_NOT_FOUND", f"AssetInfo {asset_info_id} not found.")
return web.Response(status=204)
@ROUTES.get("/api/tags")
async def get_tags(request: web.Request) -> web.Response:
query_map = dict(request.rel_url.query)
try:
query = schemas_in.TagsListQuery.model_validate(query_map)
except ValidationError as ve:
return web.json_response(
{"error": {"code": "INVALID_QUERY", "message": "Invalid query parameters", "details": ve.errors()}},
status=400,
)
result = await manager.list_tags(
prefix=query.prefix,
limit=query.limit,
offset=query.offset,
order=query.order,
include_zero=query.include_zero,
owner_id=USER_MANAGER.get_request_user_id(request),
)
return web.json_response(result.model_dump(mode="json"))
@ROUTES.post(f"/api/assets/{{id:{UUID_RE}}}/tags")
async def add_asset_tags(request: web.Request) -> web.Response:
asset_info_id = str(uuid.UUID(request.match_info["id"]))
try:
payload = await request.json()
data = schemas_in.TagsAdd.model_validate(payload)
except ValidationError as ve:
return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags add.", {"errors": ve.errors()})
except Exception:
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
try:
result = await manager.add_tags_to_asset(
asset_info_id=asset_info_id,
tags=data.tags,
origin="manual",
owner_id=USER_MANAGER.get_request_user_id(request),
)
except (ValueError, PermissionError) as ve:
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
except Exception:
LOGGER.exception(
"add_tags_to_asset failed for asset_info_id=%s, owner_id=%s",
asset_info_id,
USER_MANAGER.get_request_user_id(request),
)
return _error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(result.model_dump(mode="json"), status=200)
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}/tags")
async def delete_asset_tags(request: web.Request) -> web.Response:
asset_info_id = str(uuid.UUID(request.match_info["id"]))
try:
payload = await request.json()
data = schemas_in.TagsRemove.model_validate(payload)
except ValidationError as ve:
return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags remove.", {"errors": ve.errors()})
except Exception:
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
try:
result = await manager.remove_tags_from_asset(
asset_info_id=asset_info_id,
tags=data.tags,
owner_id=USER_MANAGER.get_request_user_id(request),
)
except ValueError as ve:
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
except Exception:
LOGGER.exception(
"remove_tags_from_asset failed for asset_info_id=%s, owner_id=%s",
asset_info_id,
USER_MANAGER.get_request_user_id(request),
)
return _error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(result.model_dump(mode="json"), status=200)
@ROUTES.post("/api/assets/scan/seed")
async def seed_assets(request: web.Request) -> web.Response:
try:
payload = await request.json()
except Exception:
payload = {}
try:
body = schemas_in.ScheduleAssetScanBody.model_validate(payload)
except ValidationError as ve:
return _validation_error_response("INVALID_BODY", ve)
try:
await scanner.sync_seed_assets(body.roots)
except Exception:
LOGGER.exception("sync_seed_assets failed for roots=%s", body.roots)
return _error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response({"synced": True, "roots": body.roots}, status=200)
@ROUTES.post("/api/assets/scan/schedule")
async def schedule_asset_scan(request: web.Request) -> web.Response:
try:
payload = await request.json()
except Exception:
payload = {}
try:
body = schemas_in.ScheduleAssetScanBody.model_validate(payload)
except ValidationError as ve:
return _validation_error_response("INVALID_BODY", ve)
states = await scanner.schedule_scans(body.roots)
return web.json_response(states.model_dump(mode="json"), status=202)
@ROUTES.get("/api/assets/scan")
async def get_asset_scan_status(request: web.Request) -> web.Response:
root = request.query.get("root", "").strip().lower()
states = scanner.current_statuses()
if root in {"models", "input", "output"}:
states = [s for s in states.scans if s.root == root] # type: ignore
states = schemas_out.AssetScanStatusResponse(scans=states)
return web.json_response(states.model_dump(mode="json"), status=200)
def register_assets_system(app: web.Application, user_manager_instance: user_manager.UserManager) -> None:
global USER_MANAGER
USER_MANAGER = user_manager_instance
app.add_routes(ROUTES)
def _error_response(status: int, code: str, message: str, details: Optional[dict] = None) -> web.Response:
return web.json_response({"error": {"code": code, "message": message, "details": details or {}}}, status=status)
def _validation_error_response(code: str, ve: ValidationError) -> web.Response:
return _error_response(400, code, "Validation failed.", {"errors": ve.json()})

View File

@@ -0,0 +1,297 @@
import json
import uuid
from typing import Any, Literal, Optional
from pydantic import (
BaseModel,
ConfigDict,
Field,
conint,
field_validator,
model_validator,
)
class ListAssetsQuery(BaseModel):
include_tags: list[str] = Field(default_factory=list)
exclude_tags: list[str] = Field(default_factory=list)
name_contains: Optional[str] = None
# Accept either a JSON string (query param) or a dict
metadata_filter: Optional[dict[str, Any]] = None
limit: conint(ge=1, le=500) = 20
offset: conint(ge=0) = 0
sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = "created_at"
order: Literal["asc", "desc"] = "desc"
@field_validator("include_tags", "exclude_tags", mode="before")
@classmethod
def _split_csv_tags(cls, v):
# Accept "a,b,c" or ["a","b"] (we are liberal in what we accept)
if v is None:
return []
if isinstance(v, str):
return [t.strip() for t in v.split(",") if t.strip()]
if isinstance(v, list):
out: list[str] = []
for item in v:
if isinstance(item, str):
out.extend([t.strip() for t in item.split(",") if t.strip()])
return out
return v
@field_validator("metadata_filter", mode="before")
@classmethod
def _parse_metadata_json(cls, v):
if v is None or isinstance(v, dict):
return v
if isinstance(v, str) and v.strip():
try:
parsed = json.loads(v)
except Exception as e:
raise ValueError(f"metadata_filter must be JSON: {e}") from e
if not isinstance(parsed, dict):
raise ValueError("metadata_filter must be a JSON object")
return parsed
return None
class UpdateAssetBody(BaseModel):
name: Optional[str] = None
tags: Optional[list[str]] = None
user_metadata: Optional[dict[str, Any]] = None
@model_validator(mode="after")
def _at_least_one(self):
if self.name is None and self.tags is None and self.user_metadata is None:
raise ValueError("Provide at least one of: name, tags, user_metadata.")
if self.tags is not None:
if not isinstance(self.tags, list) or not all(isinstance(t, str) for t in self.tags):
raise ValueError("Field 'tags' must be an array of strings.")
return self
class CreateFromHashBody(BaseModel):
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
hash: str
name: str
tags: list[str] = Field(default_factory=list)
user_metadata: dict[str, Any] = Field(default_factory=dict)
@field_validator("hash")
@classmethod
def _require_blake3(cls, v):
s = (v or "").strip().lower()
if ":" not in s:
raise ValueError("hash must be 'blake3:<hex>'")
algo, digest = s.split(":", 1)
if algo != "blake3":
raise ValueError("only canonical 'blake3:<hex>' is accepted here")
if not digest or any(c for c in digest if c not in "0123456789abcdef"):
raise ValueError("hash digest must be lowercase hex")
return s
@field_validator("tags", mode="before")
@classmethod
def _tags_norm(cls, v):
if v is None:
return []
if isinstance(v, list):
out = [str(t).strip().lower() for t in v if str(t).strip()]
seen = set()
dedup = []
for t in out:
if t not in seen:
seen.add(t)
dedup.append(t)
return dedup
if isinstance(v, str):
return [t.strip().lower() for t in v.split(",") if t.strip()]
return []
class TagsListQuery(BaseModel):
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
prefix: Optional[str] = Field(None, min_length=1, max_length=256)
limit: int = Field(100, ge=1, le=1000)
offset: int = Field(0, ge=0, le=10_000_000)
order: Literal["count_desc", "name_asc"] = "count_desc"
include_zero: bool = True
@field_validator("prefix")
@classmethod
def normalize_prefix(cls, v: Optional[str]) -> Optional[str]:
if v is None:
return v
v = v.strip()
return v.lower() or None
class TagsAdd(BaseModel):
model_config = ConfigDict(extra="ignore")
tags: list[str] = Field(..., min_length=1)
@field_validator("tags")
@classmethod
def normalize_tags(cls, v: list[str]) -> list[str]:
out = []
for t in v:
if not isinstance(t, str):
raise TypeError("tags must be strings")
tnorm = t.strip().lower()
if tnorm:
out.append(tnorm)
seen = set()
deduplicated = []
for x in out:
if x not in seen:
seen.add(x)
deduplicated.append(x)
return deduplicated
class TagsRemove(TagsAdd):
pass
RootType = Literal["models", "input", "output"]
ALLOWED_ROOTS: tuple[RootType, ...] = ("models", "input", "output")
class ScheduleAssetScanBody(BaseModel):
roots: list[RootType] = Field(..., min_length=1)
class UploadAssetSpec(BaseModel):
"""Upload Asset operation.
- tags: ordered; first is root ('models'|'input'|'output');
if root == 'models', second must be a valid category from folder_paths.folder_names_and_paths
- name: display name
- user_metadata: arbitrary JSON object (optional)
- hash: optional canonical 'blake3:<hex>' provided by the client for validation / fast-path
Files created via this endpoint are stored on disk using the **content hash** as the filename stem
and the original extension is preserved when available.
"""
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
tags: list[str] = Field(..., min_length=1)
name: Optional[str] = Field(default=None, max_length=512, description="Display Name")
user_metadata: dict[str, Any] = Field(default_factory=dict)
hash: Optional[str] = Field(default=None)
@field_validator("hash", mode="before")
@classmethod
def _parse_hash(cls, v):
if v is None:
return None
s = str(v).strip().lower()
if not s:
return None
if ":" not in s:
raise ValueError("hash must be 'blake3:<hex>'")
algo, digest = s.split(":", 1)
if algo != "blake3":
raise ValueError("only canonical 'blake3:<hex>' is accepted here")
if not digest or any(c for c in digest if c not in "0123456789abcdef"):
raise ValueError("hash digest must be lowercase hex")
return f"{algo}:{digest}"
@field_validator("tags", mode="before")
@classmethod
def _parse_tags(cls, v):
"""
Accepts a list of strings (possibly multiple form fields),
where each string can be:
- JSON array (e.g., '["models","loras","foo"]')
- comma-separated ('models, loras, foo')
- single token ('models')
Returns a normalized, deduplicated, ordered list.
"""
items: list[str] = []
if v is None:
return []
if isinstance(v, str):
v = [v]
if isinstance(v, list):
for item in v:
if item is None:
continue
s = str(item).strip()
if not s:
continue
if s.startswith("["):
try:
arr = json.loads(s)
if isinstance(arr, list):
items.extend(str(x) for x in arr)
continue
except Exception:
pass # fallback to CSV parse below
items.extend([p for p in s.split(",") if p.strip()])
else:
return []
# normalize + dedupe
norm = []
seen = set()
for t in items:
tnorm = str(t).strip().lower()
if tnorm and tnorm not in seen:
seen.add(tnorm)
norm.append(tnorm)
return norm
@field_validator("user_metadata", mode="before")
@classmethod
def _parse_metadata_json(cls, v):
if v is None or isinstance(v, dict):
return v or {}
if isinstance(v, str):
s = v.strip()
if not s:
return {}
try:
parsed = json.loads(s)
except Exception as e:
raise ValueError(f"user_metadata must be JSON: {e}") from e
if not isinstance(parsed, dict):
raise ValueError("user_metadata must be a JSON object")
return parsed
return {}
@model_validator(mode="after")
def _validate_order(self):
if not self.tags:
raise ValueError("tags must be provided and non-empty")
root = self.tags[0]
if root not in {"models", "input", "output"}:
raise ValueError("first tag must be one of: models, input, output")
if root == "models":
if len(self.tags) < 2:
raise ValueError("models uploads require a category tag as the second tag")
return self
class SetPreviewBody(BaseModel):
"""Set or clear the preview for an AssetInfo. Provide an Asset.id or null."""
preview_id: Optional[str] = None
@field_validator("preview_id", mode="before")
@classmethod
def _norm_uuid(cls, v):
if v is None:
return None
s = str(v).strip()
if not s:
return None
try:
uuid.UUID(s)
except Exception:
raise ValueError("preview_id must be a UUID")
return s

View File

@@ -0,0 +1,115 @@
from datetime import datetime
from typing import Any, Literal, Optional
from pydantic import BaseModel, ConfigDict, Field, field_serializer
class AssetSummary(BaseModel):
id: str
name: str
asset_hash: Optional[str]
size: Optional[int] = None
mime_type: Optional[str] = None
tags: list[str] = Field(default_factory=list)
preview_url: Optional[str] = None
created_at: Optional[datetime] = None
updated_at: Optional[datetime] = None
last_access_time: Optional[datetime] = None
model_config = ConfigDict(from_attributes=True)
@field_serializer("created_at", "updated_at", "last_access_time")
def _ser_dt(self, v: Optional[datetime], _info):
return v.isoformat() if v else None
class AssetsList(BaseModel):
assets: list[AssetSummary]
total: int
has_more: bool
class AssetUpdated(BaseModel):
id: str
name: str
asset_hash: Optional[str]
tags: list[str] = Field(default_factory=list)
user_metadata: dict[str, Any] = Field(default_factory=dict)
updated_at: Optional[datetime] = None
model_config = ConfigDict(from_attributes=True)
@field_serializer("updated_at")
def _ser_updated(self, v: Optional[datetime], _info):
return v.isoformat() if v else None
class AssetDetail(BaseModel):
id: str
name: str
asset_hash: Optional[str]
size: Optional[int] = None
mime_type: Optional[str] = None
tags: list[str] = Field(default_factory=list)
user_metadata: dict[str, Any] = Field(default_factory=dict)
preview_id: Optional[str] = None
created_at: Optional[datetime] = None
last_access_time: Optional[datetime] = None
model_config = ConfigDict(from_attributes=True)
@field_serializer("created_at", "last_access_time")
def _ser_dt(self, v: Optional[datetime], _info):
return v.isoformat() if v else None
class AssetCreated(AssetDetail):
created_new: bool
class TagUsage(BaseModel):
name: str
count: int
type: str
class TagsList(BaseModel):
tags: list[TagUsage] = Field(default_factory=list)
total: int
has_more: bool
class TagsAdd(BaseModel):
model_config = ConfigDict(str_strip_whitespace=True)
added: list[str] = Field(default_factory=list)
already_present: list[str] = Field(default_factory=list)
total_tags: list[str] = Field(default_factory=list)
class TagsRemove(BaseModel):
model_config = ConfigDict(str_strip_whitespace=True)
removed: list[str] = Field(default_factory=list)
not_present: list[str] = Field(default_factory=list)
total_tags: list[str] = Field(default_factory=list)
class AssetScanError(BaseModel):
path: str
message: str
at: Optional[str] = Field(None, description="ISO timestamp")
class AssetScanStatus(BaseModel):
scan_id: str
root: Literal["models", "input", "output"]
status: Literal["scheduled", "running", "completed", "failed", "cancelled"]
scheduled_at: Optional[str] = None
started_at: Optional[str] = None
finished_at: Optional[str] = None
discovered: int = 0
processed: int = 0
file_errors: list[AssetScanError] = Field(default_factory=list)
class AssetScanStatusResponse(BaseModel):
scans: list[AssetScanStatus] = Field(default_factory=list)

View File

View File

@@ -0,0 +1,25 @@
from .bulk_ops import seed_from_paths_batch
from .escape_like import escape_like_prefix
from .fast_check import fast_asset_file_check
from .filters import apply_metadata_filter, apply_tag_filters
from .ownership import visible_owner_clause
from .projection import is_scalar, project_kv
from .tags import (
add_missing_tag_for_asset_id,
ensure_tags_exist,
remove_missing_tag_for_asset_id,
)
__all__ = [
"apply_tag_filters",
"apply_metadata_filter",
"escape_like_prefix",
"fast_asset_file_check",
"is_scalar",
"project_kv",
"ensure_tags_exist",
"add_missing_tag_for_asset_id",
"remove_missing_tag_for_asset_id",
"seed_from_paths_batch",
"visible_owner_clause",
]

View File

@@ -0,0 +1,230 @@
import os
import uuid
from typing import Iterable, Sequence
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql as d_pg
from sqlalchemy.dialects import sqlite as d_sqlite
from sqlalchemy.ext.asyncio import AsyncSession
from ..models import Asset, AssetCacheState, AssetInfo, AssetInfoMeta, AssetInfoTag
from ..timeutil import utcnow
MAX_BIND_PARAMS = 800
async def seed_from_paths_batch(
session: AsyncSession,
*,
specs: Sequence[dict],
owner_id: str = "",
) -> dict:
"""Each spec is a dict with keys:
- abs_path: str
- size_bytes: int
- mtime_ns: int
- info_name: str
- tags: list[str]
- fname: Optional[str]
"""
if not specs:
return {"inserted_infos": 0, "won_states": 0, "lost_states": 0}
now = utcnow()
dialect = session.bind.dialect.name
if dialect not in ("sqlite", "postgresql"):
raise NotImplementedError(f"Unsupported database dialect: {dialect}")
asset_rows: list[dict] = []
state_rows: list[dict] = []
path_to_asset: dict[str, str] = {}
asset_to_info: dict[str, dict] = {} # asset_id -> prepared info row
path_list: list[str] = []
for sp in specs:
ap = os.path.abspath(sp["abs_path"])
aid = str(uuid.uuid4())
iid = str(uuid.uuid4())
path_list.append(ap)
path_to_asset[ap] = aid
asset_rows.append(
{
"id": aid,
"hash": None,
"size_bytes": sp["size_bytes"],
"mime_type": None,
"created_at": now,
}
)
state_rows.append(
{
"asset_id": aid,
"file_path": ap,
"mtime_ns": sp["mtime_ns"],
}
)
asset_to_info[aid] = {
"id": iid,
"owner_id": owner_id,
"name": sp["info_name"],
"asset_id": aid,
"preview_id": None,
"user_metadata": {"filename": sp["fname"]} if sp["fname"] else None,
"created_at": now,
"updated_at": now,
"last_access_time": now,
"_tags": sp["tags"],
"_filename": sp["fname"],
}
# insert all seed Assets (hash=NULL)
ins_asset = d_sqlite.insert(Asset) if dialect == "sqlite" else d_pg.insert(Asset)
for chunk in _iter_chunks(asset_rows, _rows_per_stmt(5)):
await session.execute(ins_asset, chunk)
# try to claim AssetCacheState (file_path)
winners_by_path: set[str] = set()
if dialect == "sqlite":
ins_state = (
d_sqlite.insert(AssetCacheState)
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
.returning(AssetCacheState.file_path)
)
else:
ins_state = (
d_pg.insert(AssetCacheState)
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
.returning(AssetCacheState.file_path)
)
for chunk in _iter_chunks(state_rows, _rows_per_stmt(3)):
winners_by_path.update((await session.execute(ins_state, chunk)).scalars().all())
all_paths_set = set(path_list)
losers_by_path = all_paths_set - winners_by_path
lost_assets = [path_to_asset[p] for p in losers_by_path]
if lost_assets: # losers get their Asset removed
for id_chunk in _iter_chunks(lost_assets, MAX_BIND_PARAMS):
await session.execute(sa.delete(Asset).where(Asset.id.in_(id_chunk)))
if not winners_by_path:
return {"inserted_infos": 0, "won_states": 0, "lost_states": len(losers_by_path)}
# insert AssetInfo only for winners
winner_info_rows = [asset_to_info[path_to_asset[p]] for p in winners_by_path]
if dialect == "sqlite":
ins_info = (
d_sqlite.insert(AssetInfo)
.on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name])
.returning(AssetInfo.id)
)
else:
ins_info = (
d_pg.insert(AssetInfo)
.on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name])
.returning(AssetInfo.id)
)
inserted_info_ids: set[str] = set()
for chunk in _iter_chunks(winner_info_rows, _rows_per_stmt(9)):
inserted_info_ids.update((await session.execute(ins_info, chunk)).scalars().all())
# build and insert tag + meta rows for the AssetInfo
tag_rows: list[dict] = []
meta_rows: list[dict] = []
if inserted_info_ids:
for row in winner_info_rows:
iid = row["id"]
if iid not in inserted_info_ids:
continue
for t in row["_tags"]:
tag_rows.append({
"asset_info_id": iid,
"tag_name": t,
"origin": "automatic",
"added_at": now,
})
if row["_filename"]:
meta_rows.append(
{
"asset_info_id": iid,
"key": "filename",
"ordinal": 0,
"val_str": row["_filename"],
"val_num": None,
"val_bool": None,
"val_json": None,
}
)
await bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=meta_rows, max_bind_params=MAX_BIND_PARAMS)
return {
"inserted_infos": len(inserted_info_ids),
"won_states": len(winners_by_path),
"lost_states": len(losers_by_path),
}
async def bulk_insert_tags_and_meta(
session: AsyncSession,
*,
tag_rows: list[dict],
meta_rows: list[dict],
max_bind_params: int,
) -> None:
"""Batch insert into asset_info_tags and asset_info_meta with ON CONFLICT DO NOTHING.
- tag_rows keys: asset_info_id, tag_name, origin, added_at
- meta_rows keys: asset_info_id, key, ordinal, val_str, val_num, val_bool, val_json
"""
dialect = session.bind.dialect.name
if tag_rows:
if dialect == "sqlite":
ins_links = (
d_sqlite.insert(AssetInfoTag)
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
)
elif dialect == "postgresql":
ins_links = (
d_pg.insert(AssetInfoTag)
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
)
else:
raise NotImplementedError(f"Unsupported database dialect: {dialect}")
for chunk in _chunk_rows(tag_rows, cols_per_row=4, max_bind_params=max_bind_params):
await session.execute(ins_links, chunk)
if meta_rows:
if dialect == "sqlite":
ins_meta = (
d_sqlite.insert(AssetInfoMeta)
.on_conflict_do_nothing(
index_elements=[AssetInfoMeta.asset_info_id, AssetInfoMeta.key, AssetInfoMeta.ordinal]
)
)
elif dialect == "postgresql":
ins_meta = (
d_pg.insert(AssetInfoMeta)
.on_conflict_do_nothing(
index_elements=[AssetInfoMeta.asset_info_id, AssetInfoMeta.key, AssetInfoMeta.ordinal]
)
)
else:
raise NotImplementedError(f"Unsupported database dialect: {dialect}")
for chunk in _chunk_rows(meta_rows, cols_per_row=7, max_bind_params=max_bind_params):
await session.execute(ins_meta, chunk)
def _chunk_rows(rows: list[dict], cols_per_row: int, max_bind_params: int) -> Iterable[list[dict]]:
if not rows:
return []
rows_per_stmt = max(1, max_bind_params // max(1, cols_per_row))
for i in range(0, len(rows), rows_per_stmt):
yield rows[i:i + rows_per_stmt]
def _iter_chunks(seq, n: int):
for i in range(0, len(seq), n):
yield seq[i:i + n]
def _rows_per_stmt(cols: int) -> int:
return max(1, MAX_BIND_PARAMS // max(1, cols))

View File

@@ -0,0 +1,7 @@
def escape_like_prefix(s: str, escape: str = "!") -> tuple[str, str]:
"""Escapes %, _ and the escape char itself in a LIKE prefix.
Returns (escaped_prefix, escape_char). Caller should append '%' and pass escape=escape_char to .like().
"""
s = s.replace(escape, escape + escape) # escape the escape char first
s = s.replace("%", escape + "%").replace("_", escape + "_") # escape LIKE wildcards
return s, escape

View File

@@ -0,0 +1,19 @@
import os
from typing import Optional
def fast_asset_file_check(
*,
mtime_db: Optional[int],
size_db: Optional[int],
stat_result: os.stat_result,
) -> bool:
if mtime_db is None:
return False
actual_mtime_ns = getattr(stat_result, "st_mtime_ns", int(stat_result.st_mtime * 1_000_000_000))
if int(mtime_db) != int(actual_mtime_ns):
return False
sz = int(size_db or 0)
if sz > 0:
return int(stat_result.st_size) == sz
return True

View File

@@ -0,0 +1,87 @@
from typing import Optional, Sequence
import sqlalchemy as sa
from sqlalchemy import exists
from ..._helpers import normalize_tags
from ..models import AssetInfo, AssetInfoMeta, AssetInfoTag
def apply_tag_filters(
stmt: sa.sql.Select,
include_tags: Optional[Sequence[str]],
exclude_tags: Optional[Sequence[str]],
) -> sa.sql.Select:
"""include_tags: every tag must be present; exclude_tags: none may be present."""
include_tags = normalize_tags(include_tags)
exclude_tags = normalize_tags(exclude_tags)
if include_tags:
for tag_name in include_tags:
stmt = stmt.where(
exists().where(
(AssetInfoTag.asset_info_id == AssetInfo.id)
& (AssetInfoTag.tag_name == tag_name)
)
)
if exclude_tags:
stmt = stmt.where(
~exists().where(
(AssetInfoTag.asset_info_id == AssetInfo.id)
& (AssetInfoTag.tag_name.in_(exclude_tags))
)
)
return stmt
def apply_metadata_filter(
stmt: sa.sql.Select,
metadata_filter: Optional[dict],
) -> sa.sql.Select:
"""Apply filters using asset_info_meta projection table."""
if not metadata_filter:
return stmt
def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
return sa.exists().where(
AssetInfoMeta.asset_info_id == AssetInfo.id,
AssetInfoMeta.key == key,
*preds,
)
def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
if value is None:
no_row_for_key = sa.not_(
sa.exists().where(
AssetInfoMeta.asset_info_id == AssetInfo.id,
AssetInfoMeta.key == key,
)
)
null_row = _exists_for_pred(
key,
AssetInfoMeta.val_json.is_(None),
AssetInfoMeta.val_str.is_(None),
AssetInfoMeta.val_num.is_(None),
AssetInfoMeta.val_bool.is_(None),
)
return sa.or_(no_row_for_key, null_row)
if isinstance(value, bool):
return _exists_for_pred(key, AssetInfoMeta.val_bool == bool(value))
if isinstance(value, (int, float)):
from decimal import Decimal
num = value if isinstance(value, Decimal) else Decimal(str(value))
return _exists_for_pred(key, AssetInfoMeta.val_num == num)
if isinstance(value, str):
return _exists_for_pred(key, AssetInfoMeta.val_str == value)
return _exists_for_pred(key, AssetInfoMeta.val_json == value)
for k, v in metadata_filter.items():
if isinstance(v, list):
ors = [_exists_clause_for_value(k, elem) for elem in v]
if ors:
stmt = stmt.where(sa.or_(*ors))
else:
stmt = stmt.where(_exists_clause_for_value(k, v))
return stmt

View File

@@ -0,0 +1,12 @@
import sqlalchemy as sa
from ..models import AssetInfo
def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
"""Build owner visibility predicate for reads. Owner-less rows are visible to everyone."""
owner_id = (owner_id or "").strip()
if owner_id == "":
return AssetInfo.owner_id == ""
return AssetInfo.owner_id.in_(["", owner_id])

View File

@@ -0,0 +1,64 @@
from decimal import Decimal
def is_scalar(v):
if v is None:
return True
if isinstance(v, bool):
return True
if isinstance(v, (int, float, Decimal, str)):
return True
return False
def project_kv(key: str, value):
"""
Turn a metadata key/value into typed projection rows.
Returns list[dict] with keys:
key, ordinal, and one of val_str / val_num / val_bool / val_json (others None)
"""
rows: list[dict] = []
def _null_row(ordinal: int) -> dict:
return {
"key": key, "ordinal": ordinal,
"val_str": None, "val_num": None, "val_bool": None, "val_json": None
}
if value is None:
rows.append(_null_row(0))
return rows
if is_scalar(value):
if isinstance(value, bool):
rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)})
elif isinstance(value, (int, float, Decimal)):
num = value if isinstance(value, Decimal) else Decimal(str(value))
rows.append({"key": key, "ordinal": 0, "val_num": num})
elif isinstance(value, str):
rows.append({"key": key, "ordinal": 0, "val_str": value})
else:
rows.append({"key": key, "ordinal": 0, "val_json": value})
return rows
if isinstance(value, list):
if all(is_scalar(x) for x in value):
for i, x in enumerate(value):
if x is None:
rows.append(_null_row(i))
elif isinstance(x, bool):
rows.append({"key": key, "ordinal": i, "val_bool": bool(x)})
elif isinstance(x, (int, float, Decimal)):
num = x if isinstance(x, Decimal) else Decimal(str(x))
rows.append({"key": key, "ordinal": i, "val_num": num})
elif isinstance(x, str):
rows.append({"key": key, "ordinal": i, "val_str": x})
else:
rows.append({"key": key, "ordinal": i, "val_json": x})
return rows
for i, x in enumerate(value):
rows.append({"key": key, "ordinal": i, "val_json": x})
return rows
rows.append({"key": key, "ordinal": 0, "val_json": value})
return rows

View File

@@ -0,0 +1,90 @@
from typing import Iterable
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql as d_pg
from sqlalchemy.dialects import sqlite as d_sqlite
from sqlalchemy.ext.asyncio import AsyncSession
from ..._helpers import normalize_tags
from ..models import AssetInfo, AssetInfoTag, Tag
from ..timeutil import utcnow
async def ensure_tags_exist(session: AsyncSession, names: Iterable[str], tag_type: str = "user") -> None:
wanted = normalize_tags(list(names))
if not wanted:
return
rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
dialect = session.bind.dialect.name
if dialect == "sqlite":
ins = (
d_sqlite.insert(Tag)
.values(rows)
.on_conflict_do_nothing(index_elements=[Tag.name])
)
elif dialect == "postgresql":
ins = (
d_pg.insert(Tag)
.values(rows)
.on_conflict_do_nothing(index_elements=[Tag.name])
)
else:
raise NotImplementedError(f"Unsupported database dialect: {dialect}")
await session.execute(ins)
async def add_missing_tag_for_asset_id(
session: AsyncSession,
*,
asset_id: str,
origin: str = "automatic",
) -> None:
select_rows = (
sa.select(
AssetInfo.id.label("asset_info_id"),
sa.literal("missing").label("tag_name"),
sa.literal(origin).label("origin"),
sa.literal(utcnow()).label("added_at"),
)
.where(AssetInfo.asset_id == asset_id)
.where(
sa.not_(
sa.exists().where((AssetInfoTag.asset_info_id == AssetInfo.id) & (AssetInfoTag.tag_name == "missing"))
)
)
)
dialect = session.bind.dialect.name
if dialect == "sqlite":
ins = (
d_sqlite.insert(AssetInfoTag)
.from_select(
["asset_info_id", "tag_name", "origin", "added_at"],
select_rows,
)
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
)
elif dialect == "postgresql":
ins = (
d_pg.insert(AssetInfoTag)
.from_select(
["asset_info_id", "tag_name", "origin", "added_at"],
select_rows,
)
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
)
else:
raise NotImplementedError(f"Unsupported database dialect: {dialect}")
await session.execute(ins)
async def remove_missing_tag_for_asset_id(
session: AsyncSession,
*,
asset_id: str,
) -> None:
await session.execute(
sa.delete(AssetInfoTag).where(
AssetInfoTag.asset_info_id.in_(sa.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)),
AssetInfoTag.tag_name == "missing",
)
)

View File

@@ -0,0 +1,251 @@
import uuid
from datetime import datetime
from typing import Any, Optional
from sqlalchemy import (
JSON,
BigInteger,
Boolean,
CheckConstraint,
DateTime,
ForeignKey,
Index,
Integer,
Numeric,
String,
Text,
UniqueConstraint,
)
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import DeclarativeBase, Mapped, foreign, mapped_column, relationship
from .timeutil import utcnow
JSONB_V = JSON(none_as_null=True).with_variant(JSONB(none_as_null=True), 'postgresql')
class Base(DeclarativeBase):
pass
def to_dict(obj: Any, include_none: bool = False) -> dict[str, Any]:
fields = obj.__table__.columns.keys()
out: dict[str, Any] = {}
for field in fields:
val = getattr(obj, field)
if val is None and not include_none:
continue
if isinstance(val, datetime):
out[field] = val.isoformat()
else:
out[field] = val
return out
class Asset(Base):
__tablename__ = "assets"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
hash: Mapped[Optional[str]] = mapped_column(String(256), nullable=True)
size_bytes: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
mime_type: Mapped[Optional[str]] = mapped_column(String(255))
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=False), nullable=False, default=utcnow
)
infos: Mapped[list["AssetInfo"]] = relationship(
"AssetInfo",
back_populates="asset",
primaryjoin=lambda: Asset.id == foreign(AssetInfo.asset_id),
foreign_keys=lambda: [AssetInfo.asset_id],
cascade="all,delete-orphan",
passive_deletes=True,
)
preview_of: Mapped[list["AssetInfo"]] = relationship(
"AssetInfo",
back_populates="preview_asset",
primaryjoin=lambda: Asset.id == foreign(AssetInfo.preview_id),
foreign_keys=lambda: [AssetInfo.preview_id],
viewonly=True,
)
cache_states: Mapped[list["AssetCacheState"]] = relationship(
back_populates="asset",
cascade="all, delete-orphan",
passive_deletes=True,
)
__table_args__ = (
Index("uq_assets_hash", "hash", unique=True),
Index("ix_assets_mime_type", "mime_type"),
CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"),
)
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
return to_dict(self, include_none=include_none)
def __repr__(self) -> str:
return f"<Asset id={self.id} hash={(self.hash or '')[:12]}>"
class AssetCacheState(Base):
__tablename__ = "asset_cache_state"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="CASCADE"), nullable=False)
file_path: Mapped[str] = mapped_column(Text, nullable=False)
mtime_ns: Mapped[Optional[int]] = mapped_column(BigInteger, nullable=True)
needs_verify: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
asset: Mapped["Asset"] = relationship(back_populates="cache_states")
__table_args__ = (
Index("ix_asset_cache_state_file_path", "file_path"),
Index("ix_asset_cache_state_asset_id", "asset_id"),
CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"),
)
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
return to_dict(self, include_none=include_none)
def __repr__(self) -> str:
return f"<AssetCacheState id={self.id} asset_id={self.asset_id} path={self.file_path!r}>"
class AssetInfo(Base):
__tablename__ = "assets_info"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
owner_id: Mapped[str] = mapped_column(String(128), nullable=False, default="")
name: Mapped[str] = mapped_column(String(512), nullable=False)
asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False)
preview_id: Mapped[Optional[str]] = mapped_column(String(36), ForeignKey("assets.id", ondelete="SET NULL"))
user_metadata: Mapped[Optional[dict[str, Any]]] = mapped_column(JSON(none_as_null=True))
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
last_access_time: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
asset: Mapped[Asset] = relationship(
"Asset",
back_populates="infos",
foreign_keys=[asset_id],
lazy="selectin",
)
preview_asset: Mapped[Optional[Asset]] = relationship(
"Asset",
back_populates="preview_of",
foreign_keys=[preview_id],
)
metadata_entries: Mapped[list["AssetInfoMeta"]] = relationship(
back_populates="asset_info",
cascade="all,delete-orphan",
passive_deletes=True,
)
tag_links: Mapped[list["AssetInfoTag"]] = relationship(
back_populates="asset_info",
cascade="all,delete-orphan",
passive_deletes=True,
overlaps="tags,asset_infos",
)
tags: Mapped[list["Tag"]] = relationship(
secondary="asset_info_tags",
back_populates="asset_infos",
lazy="selectin",
viewonly=True,
overlaps="tag_links,asset_info_links,asset_infos,tag",
)
__table_args__ = (
UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"),
Index("ix_assets_info_owner_name", "owner_id", "name"),
Index("ix_assets_info_owner_id", "owner_id"),
Index("ix_assets_info_asset_id", "asset_id"),
Index("ix_assets_info_name", "name"),
Index("ix_assets_info_created_at", "created_at"),
Index("ix_assets_info_last_access_time", "last_access_time"),
)
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
data = to_dict(self, include_none=include_none)
data["tags"] = [t.name for t in self.tags]
return data
def __repr__(self) -> str:
return f"<AssetInfo id={self.id} name={self.name!r} asset_id={self.asset_id}>"
class AssetInfoMeta(Base):
__tablename__ = "asset_info_meta"
asset_info_id: Mapped[str] = mapped_column(
String(36), ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True
)
key: Mapped[str] = mapped_column(String(256), primary_key=True)
ordinal: Mapped[int] = mapped_column(Integer, primary_key=True, default=0)
val_str: Mapped[Optional[str]] = mapped_column(String(2048), nullable=True)
val_num: Mapped[Optional[float]] = mapped_column(Numeric(38, 10), nullable=True)
val_bool: Mapped[Optional[bool]] = mapped_column(Boolean, nullable=True)
val_json: Mapped[Optional[Any]] = mapped_column(JSONB_V, nullable=True)
asset_info: Mapped["AssetInfo"] = relationship(back_populates="metadata_entries")
__table_args__ = (
Index("ix_asset_info_meta_key", "key"),
Index("ix_asset_info_meta_key_val_str", "key", "val_str"),
Index("ix_asset_info_meta_key_val_num", "key", "val_num"),
Index("ix_asset_info_meta_key_val_bool", "key", "val_bool"),
)
class AssetInfoTag(Base):
__tablename__ = "asset_info_tags"
asset_info_id: Mapped[str] = mapped_column(
String(36), ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True
)
tag_name: Mapped[str] = mapped_column(
String(512), ForeignKey("tags.name", ondelete="RESTRICT"), primary_key=True
)
origin: Mapped[str] = mapped_column(String(32), nullable=False, default="manual")
added_at: Mapped[datetime] = mapped_column(
DateTime(timezone=False), nullable=False, default=utcnow
)
asset_info: Mapped["AssetInfo"] = relationship(back_populates="tag_links")
tag: Mapped["Tag"] = relationship(back_populates="asset_info_links")
__table_args__ = (
Index("ix_asset_info_tags_tag_name", "tag_name"),
Index("ix_asset_info_tags_asset_info_id", "asset_info_id"),
)
class Tag(Base):
__tablename__ = "tags"
name: Mapped[str] = mapped_column(String(512), primary_key=True)
tag_type: Mapped[str] = mapped_column(String(32), nullable=False, default="user")
asset_info_links: Mapped[list["AssetInfoTag"]] = relationship(
back_populates="tag",
overlaps="asset_infos,tags",
)
asset_infos: Mapped[list["AssetInfo"]] = relationship(
secondary="asset_info_tags",
back_populates="tags",
viewonly=True,
overlaps="asset_info_links,tag_links,tags,asset_info",
)
__table_args__ = (
Index("ix_tags_tag_type", "tag_type"),
)
def __repr__(self) -> str:
return f"<Tag {self.name}>"

View File

@@ -0,0 +1,57 @@
from .content import (
check_fs_asset_exists_quick,
compute_hash_and_dedup_for_cache_state,
ingest_fs_asset,
list_cache_states_with_asset_under_prefixes,
list_unhashed_candidates_under_prefixes,
list_verify_candidates_under_prefixes,
redirect_all_references_then_delete_asset,
touch_asset_infos_by_fs_path,
)
from .info import (
add_tags_to_asset_info,
create_asset_info_for_existing_asset,
delete_asset_info_by_id,
fetch_asset_info_and_asset,
fetch_asset_info_asset_and_tags,
get_asset_tags,
list_asset_infos_page,
list_tags_with_usage,
remove_tags_from_asset_info,
replace_asset_info_metadata_projection,
set_asset_info_preview,
set_asset_info_tags,
touch_asset_info_by_id,
update_asset_info_full,
)
from .queries import (
asset_exists_by_hash,
asset_info_exists_for_asset_id,
get_asset_by_hash,
get_asset_info_by_id,
get_cache_state_by_asset_id,
list_cache_states_by_asset_id,
pick_best_live_path,
)
__all__ = [
# queries
"asset_exists_by_hash", "get_asset_by_hash", "get_asset_info_by_id", "asset_info_exists_for_asset_id",
"get_cache_state_by_asset_id",
"list_cache_states_by_asset_id",
"pick_best_live_path",
# info
"list_asset_infos_page", "create_asset_info_for_existing_asset", "set_asset_info_tags",
"update_asset_info_full", "replace_asset_info_metadata_projection",
"touch_asset_info_by_id", "delete_asset_info_by_id",
"add_tags_to_asset_info", "remove_tags_from_asset_info",
"get_asset_tags", "list_tags_with_usage", "set_asset_info_preview",
"fetch_asset_info_and_asset", "fetch_asset_info_asset_and_tags",
# content
"check_fs_asset_exists_quick",
"redirect_all_references_then_delete_asset",
"compute_hash_and_dedup_for_cache_state",
"list_unhashed_candidates_under_prefixes", "list_verify_candidates_under_prefixes",
"ingest_fs_asset", "touch_asset_infos_by_fs_path",
"list_cache_states_with_asset_under_prefixes",
]

View File

@@ -0,0 +1,721 @@
import contextlib
import logging
import os
from datetime import datetime
from typing import Any, Optional, Sequence, Union
import sqlalchemy as sa
from sqlalchemy import select
from sqlalchemy.dialects import postgresql as d_pg
from sqlalchemy.dialects import sqlite as d_sqlite
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import noload
from ..._helpers import compute_relative_filename
from ...storage import hashing as hashing_mod
from ..helpers import (
ensure_tags_exist,
escape_like_prefix,
remove_missing_tag_for_asset_id,
)
from ..models import Asset, AssetCacheState, AssetInfo, AssetInfoTag, Tag
from ..timeutil import utcnow
from .info import replace_asset_info_metadata_projection
from .queries import list_cache_states_by_asset_id, pick_best_live_path
async def check_fs_asset_exists_quick(
session: AsyncSession,
*,
file_path: str,
size_bytes: Optional[int] = None,
mtime_ns: Optional[int] = None,
) -> bool:
"""Returns True if we already track this absolute path with a HASHED asset and the cached mtime/size match."""
locator = os.path.abspath(file_path)
stmt = (
sa.select(sa.literal(True))
.select_from(AssetCacheState)
.join(Asset, Asset.id == AssetCacheState.asset_id)
.where(
AssetCacheState.file_path == locator,
Asset.hash.isnot(None),
AssetCacheState.needs_verify.is_(False),
)
.limit(1)
)
conds = []
if mtime_ns is not None:
conds.append(AssetCacheState.mtime_ns == int(mtime_ns))
if size_bytes is not None:
conds.append(sa.or_(Asset.size_bytes == 0, Asset.size_bytes == int(size_bytes)))
if conds:
stmt = stmt.where(*conds)
return (await session.execute(stmt)).first() is not None
async def redirect_all_references_then_delete_asset(
session: AsyncSession,
*,
duplicate_asset_id: str,
canonical_asset_id: str,
) -> None:
"""
Safely migrate all references from duplicate_asset_id to canonical_asset_id.
- If an AssetInfo for (owner_id, name) already exists on the canonical asset,
merge tags, metadata, times, and preview, then delete the duplicate AssetInfo.
- Otherwise, simply repoint the AssetInfo.asset_id.
- Always retarget AssetCacheState rows.
- Finally delete the duplicate Asset row.
"""
if duplicate_asset_id == canonical_asset_id:
return
# 1) Migrate AssetInfo rows one-by-one to avoid UNIQUE conflicts.
dup_infos = (
await session.execute(
select(AssetInfo).options(noload(AssetInfo.tags)).where(AssetInfo.asset_id == duplicate_asset_id)
)
).unique().scalars().all()
for info in dup_infos:
# Try to find an existing collision on canonical
existing = (
await session.execute(
select(AssetInfo)
.options(noload(AssetInfo.tags))
.where(
AssetInfo.asset_id == canonical_asset_id,
AssetInfo.owner_id == info.owner_id,
AssetInfo.name == info.name,
)
.limit(1)
)
).unique().scalars().first()
if existing:
merged_meta = dict(existing.user_metadata or {})
other_meta = info.user_metadata or {}
for k, v in other_meta.items():
if k not in merged_meta:
merged_meta[k] = v
if merged_meta != (existing.user_metadata or {}):
await replace_asset_info_metadata_projection(
session,
asset_info_id=existing.id,
user_metadata=merged_meta,
)
existing_tags = {
t for (t,) in (
await session.execute(
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == existing.id)
)
).all()
}
from_tags = {
t for (t,) in (
await session.execute(
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == info.id)
)
).all()
}
to_add = sorted(from_tags - existing_tags)
if to_add:
await ensure_tags_exist(session, to_add, tag_type="user")
now = utcnow()
session.add_all([
AssetInfoTag(asset_info_id=existing.id, tag_name=t, origin="automatic", added_at=now)
for t in to_add
])
await session.flush()
if existing.preview_id is None and info.preview_id is not None:
existing.preview_id = info.preview_id
if info.last_access_time and (
existing.last_access_time is None or info.last_access_time > existing.last_access_time
):
existing.last_access_time = info.last_access_time
existing.updated_at = utcnow()
await session.flush()
# Delete the duplicate AssetInfo (cascades will clean its tags/meta)
await session.delete(info)
await session.flush()
else:
# Simple retarget
info.asset_id = canonical_asset_id
info.updated_at = utcnow()
await session.flush()
# 2) Repoint cache states and previews
await session.execute(
sa.update(AssetCacheState)
.where(AssetCacheState.asset_id == duplicate_asset_id)
.values(asset_id=canonical_asset_id)
)
await session.execute(
sa.update(AssetInfo)
.where(AssetInfo.preview_id == duplicate_asset_id)
.values(preview_id=canonical_asset_id)
)
# 3) Remove duplicate Asset
dup = await session.get(Asset, duplicate_asset_id)
if dup:
await session.delete(dup)
await session.flush()
async def compute_hash_and_dedup_for_cache_state(
session: AsyncSession,
*,
state_id: int,
) -> Optional[str]:
"""
Compute hash for the given cache state, deduplicate, and settle verify cases.
Returns the asset_id that this state ends up pointing to, or None if file disappeared.
"""
state = await session.get(AssetCacheState, state_id)
if not state:
return None
path = state.file_path
try:
if not os.path.isfile(path):
# File vanished: drop the state. If the Asset has hash=NULL and has no other states, drop the Asset too.
asset = await session.get(Asset, state.asset_id)
await session.delete(state)
await session.flush()
if asset and asset.hash is None:
remaining = (
await session.execute(
sa.select(sa.func.count())
.select_from(AssetCacheState)
.where(AssetCacheState.asset_id == asset.id)
)
).scalar_one()
if int(remaining or 0) == 0:
await session.delete(asset)
await session.flush()
else:
await _recompute_and_apply_filename_for_asset(session, asset_id=asset.id)
return None
digest = await hashing_mod.blake3_hash(path)
new_hash = f"blake3:{digest}"
st = os.stat(path, follow_symlinks=True)
new_size = int(st.st_size)
mtime_ns = getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
# Current asset of this state
this_asset = await session.get(Asset, state.asset_id)
# If the state got orphaned somehow (race), just reattach appropriately.
if not this_asset:
canonical = (
await session.execute(sa.select(Asset).where(Asset.hash == new_hash).limit(1))
).scalars().first()
if canonical:
state.asset_id = canonical.id
else:
now = utcnow()
new_asset = Asset(hash=new_hash, size_bytes=new_size, mime_type=None, created_at=now)
session.add(new_asset)
await session.flush()
state.asset_id = new_asset.id
state.mtime_ns = mtime_ns
state.needs_verify = False
with contextlib.suppress(Exception):
await remove_missing_tag_for_asset_id(session, asset_id=state.asset_id)
await session.flush()
return state.asset_id
# 1) Seed asset case (hash is NULL): claim or merge into canonical
if this_asset.hash is None:
canonical = (
await session.execute(sa.select(Asset).where(Asset.hash == new_hash).limit(1))
).scalars().first()
if canonical and canonical.id != this_asset.id:
# Merge seed asset into canonical (safe, collision-aware)
await redirect_all_references_then_delete_asset(
session,
duplicate_asset_id=this_asset.id,
canonical_asset_id=canonical.id,
)
state = await session.get(AssetCacheState, state_id)
if state:
state.mtime_ns = mtime_ns
state.needs_verify = False
with contextlib.suppress(Exception):
await remove_missing_tag_for_asset_id(session, asset_id=canonical.id)
await _recompute_and_apply_filename_for_asset(session, asset_id=canonical.id)
await session.flush()
return canonical.id
# No canonical: try to claim the hash; handle races with a SAVEPOINT
try:
async with session.begin_nested():
this_asset.hash = new_hash
if int(this_asset.size_bytes or 0) == 0 and new_size > 0:
this_asset.size_bytes = new_size
await session.flush()
except IntegrityError:
# Someone else claimed it concurrently; fetch canonical and merge
canonical = (
await session.execute(sa.select(Asset).where(Asset.hash == new_hash).limit(1))
).scalars().first()
if canonical and canonical.id != this_asset.id:
await redirect_all_references_then_delete_asset(
session,
duplicate_asset_id=this_asset.id,
canonical_asset_id=canonical.id,
)
state = await session.get(AssetCacheState, state_id)
if state:
state.mtime_ns = mtime_ns
state.needs_verify = False
with contextlib.suppress(Exception):
await remove_missing_tag_for_asset_id(session, asset_id=canonical.id)
await _recompute_and_apply_filename_for_asset(session, asset_id=canonical.id)
await session.flush()
return canonical.id
# If we got here, the integrity error was not about hash uniqueness
raise
# Claimed successfully
state.mtime_ns = mtime_ns
state.needs_verify = False
with contextlib.suppress(Exception):
await remove_missing_tag_for_asset_id(session, asset_id=this_asset.id)
await _recompute_and_apply_filename_for_asset(session, asset_id=this_asset.id)
await session.flush()
return this_asset.id
# 2) Verify case for hashed assets
if this_asset.hash == new_hash:
if int(this_asset.size_bytes or 0) == 0 and new_size > 0:
this_asset.size_bytes = new_size
state.mtime_ns = mtime_ns
state.needs_verify = False
with contextlib.suppress(Exception):
await remove_missing_tag_for_asset_id(session, asset_id=this_asset.id)
await _recompute_and_apply_filename_for_asset(session, asset_id=this_asset.id)
await session.flush()
return this_asset.id
# Content changed on this path only: retarget THIS state, do not move AssetInfo rows
canonical = (
await session.execute(sa.select(Asset).where(Asset.hash == new_hash).limit(1))
).scalars().first()
if canonical:
target_id = canonical.id
else:
now = utcnow()
new_asset = Asset(hash=new_hash, size_bytes=new_size, mime_type=None, created_at=now)
session.add(new_asset)
await session.flush()
target_id = new_asset.id
state.asset_id = target_id
state.mtime_ns = mtime_ns
state.needs_verify = False
with contextlib.suppress(Exception):
await remove_missing_tag_for_asset_id(session, asset_id=target_id)
await _recompute_and_apply_filename_for_asset(session, asset_id=target_id)
await session.flush()
return target_id
except Exception:
raise
async def list_unhashed_candidates_under_prefixes(session: AsyncSession, *, prefixes: list[str]) -> list[int]:
if not prefixes:
return []
conds = []
for p in prefixes:
base = os.path.abspath(p)
if not base.endswith(os.sep):
base += os.sep
escaped, esc = escape_like_prefix(base)
conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc))
path_filter = sa.or_(*conds) if len(conds) > 1 else conds[0]
if session.bind.dialect.name == "postgresql":
stmt = (
sa.select(AssetCacheState.id)
.join(Asset, Asset.id == AssetCacheState.asset_id)
.where(Asset.hash.is_(None), path_filter)
.order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc())
.distinct(AssetCacheState.asset_id)
)
else:
first_id = sa.func.min(AssetCacheState.id).label("first_id")
stmt = (
sa.select(first_id)
.join(Asset, Asset.id == AssetCacheState.asset_id)
.where(Asset.hash.is_(None), path_filter)
.group_by(AssetCacheState.asset_id)
.order_by(first_id.asc())
)
return [int(x) for x in (await session.execute(stmt)).scalars().all()]
async def list_verify_candidates_under_prefixes(
session: AsyncSession, *, prefixes: Sequence[str]
) -> Union[list[int], Sequence[int]]:
if not prefixes:
return []
conds = []
for p in prefixes:
base = os.path.abspath(p)
if not base.endswith(os.sep):
base += os.sep
escaped, esc = escape_like_prefix(base)
conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc))
return (
await session.execute(
sa.select(AssetCacheState.id)
.where(AssetCacheState.needs_verify.is_(True))
.where(sa.or_(*conds))
.order_by(AssetCacheState.id.asc())
)
).scalars().all()
async def ingest_fs_asset(
session: AsyncSession,
*,
asset_hash: str,
abs_path: str,
size_bytes: int,
mtime_ns: int,
mime_type: Optional[str] = None,
info_name: Optional[str] = None,
owner_id: str = "",
preview_id: Optional[str] = None,
user_metadata: Optional[dict] = None,
tags: Sequence[str] = (),
tag_origin: str = "manual",
require_existing_tags: bool = False,
) -> dict:
"""
Idempotently upsert:
- Asset by content hash (create if missing)
- AssetCacheState(file_path) pointing to asset_id
- Optionally AssetInfo + tag links and metadata projection
Returns flags and ids.
"""
locator = os.path.abspath(abs_path)
now = utcnow()
dialect = session.bind.dialect.name
if preview_id:
if not await session.get(Asset, preview_id):
preview_id = None
out: dict[str, Any] = {
"asset_created": False,
"asset_updated": False,
"state_created": False,
"state_updated": False,
"asset_info_id": None,
}
# 1) Asset by hash
asset = (
await session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
).scalars().first()
if not asset:
vals = {
"hash": asset_hash,
"size_bytes": int(size_bytes),
"mime_type": mime_type,
"created_at": now,
}
if dialect == "sqlite":
res = await session.execute(
d_sqlite.insert(Asset)
.values(**vals)
.on_conflict_do_nothing(index_elements=[Asset.hash])
)
if int(res.rowcount or 0) > 0:
out["asset_created"] = True
asset = (
await session.execute(
select(Asset).where(Asset.hash == asset_hash).limit(1)
)
).scalars().first()
elif dialect == "postgresql":
res = await session.execute(
d_pg.insert(Asset)
.values(**vals)
.on_conflict_do_nothing(
index_elements=[Asset.hash],
index_where=Asset.__table__.c.hash.isnot(None),
)
.returning(Asset.id)
)
inserted_id = res.scalar_one_or_none()
if inserted_id:
out["asset_created"] = True
asset = await session.get(Asset, inserted_id)
else:
asset = (
await session.execute(
select(Asset).where(Asset.hash == asset_hash).limit(1)
)
).scalars().first()
else:
raise NotImplementedError(f"Unsupported database dialect: {dialect}")
if not asset:
raise RuntimeError("Asset row not found after upsert.")
else:
changed = False
if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0:
asset.size_bytes = int(size_bytes)
changed = True
if mime_type and asset.mime_type != mime_type:
asset.mime_type = mime_type
changed = True
if changed:
out["asset_updated"] = True
# 2) AssetCacheState upsert by file_path (unique)
vals = {
"asset_id": asset.id,
"file_path": locator,
"mtime_ns": int(mtime_ns),
}
if dialect == "sqlite":
ins = (
d_sqlite.insert(AssetCacheState)
.values(**vals)
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
)
elif dialect == "postgresql":
ins = (
d_pg.insert(AssetCacheState)
.values(**vals)
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
)
else:
raise NotImplementedError(f"Unsupported database dialect: {dialect}")
res = await session.execute(ins)
if int(res.rowcount or 0) > 0:
out["state_created"] = True
else:
upd = (
sa.update(AssetCacheState)
.where(AssetCacheState.file_path == locator)
.where(
sa.or_(
AssetCacheState.asset_id != asset.id,
AssetCacheState.mtime_ns.is_(None),
AssetCacheState.mtime_ns != int(mtime_ns),
)
)
.values(asset_id=asset.id, mtime_ns=int(mtime_ns))
)
res2 = await session.execute(upd)
if int(res2.rowcount or 0) > 0:
out["state_updated"] = True
# 3) Optional AssetInfo + tags + metadata
if info_name:
try:
async with session.begin_nested():
info = AssetInfo(
owner_id=owner_id,
name=info_name,
asset_id=asset.id,
preview_id=preview_id,
created_at=now,
updated_at=now,
last_access_time=now,
)
session.add(info)
await session.flush()
out["asset_info_id"] = info.id
except IntegrityError:
pass
existing_info = (
await session.execute(
select(AssetInfo)
.where(
AssetInfo.asset_id == asset.id,
AssetInfo.name == info_name,
(AssetInfo.owner_id == owner_id),
)
.limit(1)
)
).unique().scalar_one_or_none()
if not existing_info:
raise RuntimeError("Failed to update or insert AssetInfo.")
if preview_id and existing_info.preview_id != preview_id:
existing_info.preview_id = preview_id
existing_info.updated_at = now
if existing_info.last_access_time < now:
existing_info.last_access_time = now
await session.flush()
out["asset_info_id"] = existing_info.id
norm = [t.strip().lower() for t in (tags or []) if (t or "").strip()]
if norm and out["asset_info_id"] is not None:
if not require_existing_tags:
await ensure_tags_exist(session, norm, tag_type="user")
existing_tag_names = set(
name for (name,) in (await session.execute(select(Tag.name).where(Tag.name.in_(norm)))).all()
)
missing = [t for t in norm if t not in existing_tag_names]
if missing and require_existing_tags:
raise ValueError(f"Unknown tags: {missing}")
existing_links = set(
tag_name
for (tag_name,) in (
await session.execute(
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == out["asset_info_id"])
)
).all()
)
to_add = [t for t in norm if t in existing_tag_names and t not in existing_links]
if to_add:
session.add_all(
[
AssetInfoTag(
asset_info_id=out["asset_info_id"],
tag_name=t,
origin=tag_origin,
added_at=now,
)
for t in to_add
]
)
await session.flush()
# metadata["filename"] hack
if out["asset_info_id"] is not None:
primary_path = pick_best_live_path(await list_cache_states_by_asset_id(session, asset_id=asset.id))
computed_filename = compute_relative_filename(primary_path) if primary_path else None
current_meta = existing_info.user_metadata or {}
new_meta = dict(current_meta)
if user_metadata is not None:
for k, v in user_metadata.items():
new_meta[k] = v
if computed_filename:
new_meta["filename"] = computed_filename
if new_meta != current_meta:
await replace_asset_info_metadata_projection(
session,
asset_info_id=out["asset_info_id"],
user_metadata=new_meta,
)
try:
await remove_missing_tag_for_asset_id(session, asset_id=asset.id)
except Exception:
logging.exception("Failed to clear 'missing' tag for asset %s", asset.id)
return out
async def touch_asset_infos_by_fs_path(
session: AsyncSession,
*,
file_path: str,
ts: Optional[datetime] = None,
only_if_newer: bool = True,
) -> None:
locator = os.path.abspath(file_path)
ts = ts or utcnow()
stmt = sa.update(AssetInfo).where(
sa.exists(
sa.select(sa.literal(1))
.select_from(AssetCacheState)
.where(
AssetCacheState.asset_id == AssetInfo.asset_id,
AssetCacheState.file_path == locator,
)
)
)
if only_if_newer:
stmt = stmt.where(
sa.or_(
AssetInfo.last_access_time.is_(None),
AssetInfo.last_access_time < ts,
)
)
await session.execute(stmt.values(last_access_time=ts))
async def list_cache_states_with_asset_under_prefixes(
session: AsyncSession,
*,
prefixes: Sequence[str],
) -> list[tuple[AssetCacheState, Optional[str], int]]:
"""Return (AssetCacheState, asset_hash, size_bytes) for rows under any prefix."""
if not prefixes:
return []
conds = []
for p in prefixes:
if not p:
continue
base = os.path.abspath(p)
if not base.endswith(os.sep):
base = base + os.sep
escaped, esc = escape_like_prefix(base)
conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc))
if not conds:
return []
rows = (
await session.execute(
select(AssetCacheState, Asset.hash, Asset.size_bytes)
.join(Asset, Asset.id == AssetCacheState.asset_id)
.where(sa.or_(*conds))
.order_by(AssetCacheState.id.asc())
)
).all()
return [(r[0], r[1], int(r[2] or 0)) for r in rows]
async def _recompute_and_apply_filename_for_asset(session: AsyncSession, *, asset_id: str) -> None:
"""Compute filename from the first *existing* cache state path and apply it to all AssetInfo (if changed)."""
try:
primary_path = pick_best_live_path(await list_cache_states_by_asset_id(session, asset_id=asset_id))
if not primary_path:
return
new_filename = compute_relative_filename(primary_path)
if not new_filename:
return
infos = (
await session.execute(select(AssetInfo).where(AssetInfo.asset_id == asset_id))
).scalars().all()
for info in infos:
current_meta = info.user_metadata or {}
if current_meta.get("filename") == new_filename:
continue
updated = dict(current_meta)
updated["filename"] = new_filename
await replace_asset_info_metadata_projection(session, asset_info_id=info.id, user_metadata=updated)
except Exception:
logging.exception("Failed to recompute filename metadata for asset %s", asset_id)

View File

@@ -0,0 +1,586 @@
from collections import defaultdict
from datetime import datetime
from typing import Any, Optional, Sequence
import sqlalchemy as sa
from sqlalchemy import delete, func, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import contains_eager, noload
from ..._helpers import compute_relative_filename, normalize_tags
from ..helpers import (
apply_metadata_filter,
apply_tag_filters,
ensure_tags_exist,
escape_like_prefix,
project_kv,
visible_owner_clause,
)
from ..models import Asset, AssetInfo, AssetInfoMeta, AssetInfoTag, Tag
from ..timeutil import utcnow
from .queries import (
get_asset_by_hash,
list_cache_states_by_asset_id,
pick_best_live_path,
)
async def list_asset_infos_page(
session: AsyncSession,
*,
owner_id: str = "",
include_tags: Optional[Sequence[str]] = None,
exclude_tags: Optional[Sequence[str]] = None,
name_contains: Optional[str] = None,
metadata_filter: Optional[dict] = None,
limit: int = 20,
offset: int = 0,
sort: str = "created_at",
order: str = "desc",
) -> tuple[list[AssetInfo], dict[str, list[str]], int]:
base = (
select(AssetInfo)
.join(Asset, Asset.id == AssetInfo.asset_id)
.options(contains_eager(AssetInfo.asset), noload(AssetInfo.tags))
.where(visible_owner_clause(owner_id))
)
if name_contains:
escaped, esc = escape_like_prefix(name_contains)
base = base.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc))
base = apply_tag_filters(base, include_tags, exclude_tags)
base = apply_metadata_filter(base, metadata_filter)
sort = (sort or "created_at").lower()
order = (order or "desc").lower()
sort_map = {
"name": AssetInfo.name,
"created_at": AssetInfo.created_at,
"updated_at": AssetInfo.updated_at,
"last_access_time": AssetInfo.last_access_time,
"size": Asset.size_bytes,
}
sort_col = sort_map.get(sort, AssetInfo.created_at)
sort_exp = sort_col.desc() if order == "desc" else sort_col.asc()
base = base.order_by(sort_exp).limit(limit).offset(offset)
count_stmt = (
select(func.count())
.select_from(AssetInfo)
.join(Asset, Asset.id == AssetInfo.asset_id)
.where(visible_owner_clause(owner_id))
)
if name_contains:
escaped, esc = escape_like_prefix(name_contains)
count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc))
count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags)
count_stmt = apply_metadata_filter(count_stmt, metadata_filter)
total = int((await session.execute(count_stmt)).scalar_one() or 0)
infos = (await session.execute(base)).unique().scalars().all()
id_list: list[str] = [i.id for i in infos]
tag_map: dict[str, list[str]] = defaultdict(list)
if id_list:
rows = await session.execute(
select(AssetInfoTag.asset_info_id, Tag.name)
.join(Tag, Tag.name == AssetInfoTag.tag_name)
.where(AssetInfoTag.asset_info_id.in_(id_list))
)
for aid, tag_name in rows.all():
tag_map[aid].append(tag_name)
return infos, tag_map, total
async def fetch_asset_info_and_asset(
session: AsyncSession,
*,
asset_info_id: str,
owner_id: str = "",
) -> Optional[tuple[AssetInfo, Asset]]:
stmt = (
select(AssetInfo, Asset)
.join(Asset, Asset.id == AssetInfo.asset_id)
.where(
AssetInfo.id == asset_info_id,
visible_owner_clause(owner_id),
)
.limit(1)
.options(noload(AssetInfo.tags))
)
row = await session.execute(stmt)
pair = row.first()
if not pair:
return None
return pair[0], pair[1]
async def fetch_asset_info_asset_and_tags(
session: AsyncSession,
*,
asset_info_id: str,
owner_id: str = "",
) -> Optional[tuple[AssetInfo, Asset, list[str]]]:
stmt = (
select(AssetInfo, Asset, Tag.name)
.join(Asset, Asset.id == AssetInfo.asset_id)
.join(AssetInfoTag, AssetInfoTag.asset_info_id == AssetInfo.id, isouter=True)
.join(Tag, Tag.name == AssetInfoTag.tag_name, isouter=True)
.where(
AssetInfo.id == asset_info_id,
visible_owner_clause(owner_id),
)
.options(noload(AssetInfo.tags))
.order_by(Tag.name.asc())
)
rows = (await session.execute(stmt)).all()
if not rows:
return None
first_info, first_asset, _ = rows[0]
tags: list[str] = []
seen: set[str] = set()
for _info, _asset, tag_name in rows:
if tag_name and tag_name not in seen:
seen.add(tag_name)
tags.append(tag_name)
return first_info, first_asset, tags
async def create_asset_info_for_existing_asset(
session: AsyncSession,
*,
asset_hash: str,
name: str,
user_metadata: Optional[dict] = None,
tags: Optional[Sequence[str]] = None,
tag_origin: str = "manual",
owner_id: str = "",
) -> AssetInfo:
"""Create or return an existing AssetInfo for an Asset identified by asset_hash."""
now = utcnow()
asset = await get_asset_by_hash(session, asset_hash=asset_hash)
if not asset:
raise ValueError(f"Unknown asset hash {asset_hash}")
info = AssetInfo(
owner_id=owner_id,
name=name,
asset_id=asset.id,
preview_id=None,
created_at=now,
updated_at=now,
last_access_time=now,
)
try:
async with session.begin_nested():
session.add(info)
await session.flush()
except IntegrityError:
existing = (
await session.execute(
select(AssetInfo)
.options(noload(AssetInfo.tags))
.where(
AssetInfo.asset_id == asset.id,
AssetInfo.name == name,
AssetInfo.owner_id == owner_id,
)
.limit(1)
)
).unique().scalars().first()
if not existing:
raise RuntimeError("AssetInfo upsert failed to find existing row after conflict.")
return existing
# metadata["filename"] hack
new_meta = dict(user_metadata or {})
computed_filename = None
try:
p = pick_best_live_path(await list_cache_states_by_asset_id(session, asset_id=asset.id))
if p:
computed_filename = compute_relative_filename(p)
except Exception:
computed_filename = None
if computed_filename:
new_meta["filename"] = computed_filename
if new_meta:
await replace_asset_info_metadata_projection(
session,
asset_info_id=info.id,
user_metadata=new_meta,
)
if tags is not None:
await set_asset_info_tags(
session,
asset_info_id=info.id,
tags=tags,
origin=tag_origin,
)
return info
async def set_asset_info_tags(
session: AsyncSession,
*,
asset_info_id: str,
tags: Sequence[str],
origin: str = "manual",
) -> dict:
desired = normalize_tags(tags)
current = set(
tag_name for (tag_name,) in (
await session.execute(select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id))
).all()
)
to_add = [t for t in desired if t not in current]
to_remove = [t for t in current if t not in desired]
if to_add:
await ensure_tags_exist(session, to_add, tag_type="user")
session.add_all([
AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_at=utcnow())
for t in to_add
])
await session.flush()
if to_remove:
await session.execute(
delete(AssetInfoTag)
.where(AssetInfoTag.asset_info_id == asset_info_id, AssetInfoTag.tag_name.in_(to_remove))
)
await session.flush()
return {"added": to_add, "removed": to_remove, "total": desired}
async def update_asset_info_full(
session: AsyncSession,
*,
asset_info_id: str,
name: Optional[str] = None,
tags: Optional[Sequence[str]] = None,
user_metadata: Optional[dict] = None,
tag_origin: str = "manual",
asset_info_row: Any = None,
) -> AssetInfo:
if not asset_info_row:
info = await session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
else:
info = asset_info_row
touched = False
if name is not None and name != info.name:
info.name = name
touched = True
computed_filename = None
try:
p = pick_best_live_path(await list_cache_states_by_asset_id(session, asset_id=info.asset_id))
if p:
computed_filename = compute_relative_filename(p)
except Exception:
computed_filename = None
if user_metadata is not None:
new_meta = dict(user_metadata)
if computed_filename:
new_meta["filename"] = computed_filename
await replace_asset_info_metadata_projection(
session, asset_info_id=asset_info_id, user_metadata=new_meta
)
touched = True
else:
if computed_filename:
current_meta = info.user_metadata or {}
if current_meta.get("filename") != computed_filename:
new_meta = dict(current_meta)
new_meta["filename"] = computed_filename
await replace_asset_info_metadata_projection(
session, asset_info_id=asset_info_id, user_metadata=new_meta
)
touched = True
if tags is not None:
await set_asset_info_tags(
session,
asset_info_id=asset_info_id,
tags=tags,
origin=tag_origin,
)
touched = True
if touched and user_metadata is None:
info.updated_at = utcnow()
await session.flush()
return info
async def replace_asset_info_metadata_projection(
session: AsyncSession,
*,
asset_info_id: str,
user_metadata: Optional[dict],
) -> None:
info = await session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
info.user_metadata = user_metadata or {}
info.updated_at = utcnow()
await session.flush()
await session.execute(delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id))
await session.flush()
if not user_metadata:
return
rows: list[AssetInfoMeta] = []
for k, v in user_metadata.items():
for r in project_kv(k, v):
rows.append(
AssetInfoMeta(
asset_info_id=asset_info_id,
key=r["key"],
ordinal=int(r["ordinal"]),
val_str=r.get("val_str"),
val_num=r.get("val_num"),
val_bool=r.get("val_bool"),
val_json=r.get("val_json"),
)
)
if rows:
session.add_all(rows)
await session.flush()
async def touch_asset_info_by_id(
session: AsyncSession,
*,
asset_info_id: str,
ts: Optional[datetime] = None,
only_if_newer: bool = True,
) -> None:
ts = ts or utcnow()
stmt = sa.update(AssetInfo).where(AssetInfo.id == asset_info_id)
if only_if_newer:
stmt = stmt.where(
sa.or_(AssetInfo.last_access_time.is_(None), AssetInfo.last_access_time < ts)
)
await session.execute(stmt.values(last_access_time=ts))
async def delete_asset_info_by_id(session: AsyncSession, *, asset_info_id: str, owner_id: str) -> bool:
stmt = sa.delete(AssetInfo).where(
AssetInfo.id == asset_info_id,
visible_owner_clause(owner_id),
)
return int((await session.execute(stmt)).rowcount or 0) > 0
async def add_tags_to_asset_info(
session: AsyncSession,
*,
asset_info_id: str,
tags: Sequence[str],
origin: str = "manual",
create_if_missing: bool = True,
asset_info_row: Any = None,
) -> dict:
if not asset_info_row:
info = await session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
norm = normalize_tags(tags)
if not norm:
total = await get_asset_tags(session, asset_info_id=asset_info_id)
return {"added": [], "already_present": [], "total_tags": total}
if create_if_missing:
await ensure_tags_exist(session, norm, tag_type="user")
current = {
tag_name
for (tag_name,) in (
await session.execute(
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
)
).all()
}
want = set(norm)
to_add = sorted(want - current)
if to_add:
async with session.begin_nested() as nested:
try:
session.add_all(
[
AssetInfoTag(
asset_info_id=asset_info_id,
tag_name=t,
origin=origin,
added_at=utcnow(),
)
for t in to_add
]
)
await session.flush()
except IntegrityError:
await nested.rollback()
after = set(await get_asset_tags(session, asset_info_id=asset_info_id))
return {
"added": sorted(((after - current) & want)),
"already_present": sorted(want & current),
"total_tags": sorted(after),
}
async def remove_tags_from_asset_info(
session: AsyncSession,
*,
asset_info_id: str,
tags: Sequence[str],
) -> dict:
info = await session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
norm = normalize_tags(tags)
if not norm:
total = await get_asset_tags(session, asset_info_id=asset_info_id)
return {"removed": [], "not_present": [], "total_tags": total}
existing = {
tag_name
for (tag_name,) in (
await session.execute(
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
)
).all()
}
to_remove = sorted(set(t for t in norm if t in existing))
not_present = sorted(set(t for t in norm if t not in existing))
if to_remove:
await session.execute(
delete(AssetInfoTag)
.where(
AssetInfoTag.asset_info_id == asset_info_id,
AssetInfoTag.tag_name.in_(to_remove),
)
)
await session.flush()
total = await get_asset_tags(session, asset_info_id=asset_info_id)
return {"removed": to_remove, "not_present": not_present, "total_tags": total}
async def list_tags_with_usage(
session: AsyncSession,
*,
prefix: Optional[str] = None,
limit: int = 100,
offset: int = 0,
include_zero: bool = True,
order: str = "count_desc",
owner_id: str = "",
) -> tuple[list[tuple[str, str, int]], int]:
counts_sq = (
select(
AssetInfoTag.tag_name.label("tag_name"),
func.count(AssetInfoTag.asset_info_id).label("cnt"),
)
.select_from(AssetInfoTag)
.join(AssetInfo, AssetInfo.id == AssetInfoTag.asset_info_id)
.where(visible_owner_clause(owner_id))
.group_by(AssetInfoTag.tag_name)
.subquery()
)
q = (
select(
Tag.name,
Tag.tag_type,
func.coalesce(counts_sq.c.cnt, 0).label("count"),
)
.select_from(Tag)
.join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True)
)
if prefix:
escaped, esc = escape_like_prefix(prefix.strip().lower())
q = q.where(Tag.name.like(escaped + "%", escape=esc))
if not include_zero:
q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0)
if order == "name_asc":
q = q.order_by(Tag.name.asc())
else:
q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc())
total_q = select(func.count()).select_from(Tag)
if prefix:
escaped, esc = escape_like_prefix(prefix.strip().lower())
total_q = total_q.where(Tag.name.like(escaped + "%", escape=esc))
if not include_zero:
total_q = total_q.where(
Tag.name.in_(select(AssetInfoTag.tag_name).group_by(AssetInfoTag.tag_name))
)
rows = (await session.execute(q.limit(limit).offset(offset))).all()
total = (await session.execute(total_q)).scalar_one()
rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
return rows_norm, int(total or 0)
async def get_asset_tags(session: AsyncSession, *, asset_info_id: str) -> list[str]:
return [
tag_name
for (tag_name,) in (
await session.execute(
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
)
).all()
]
async def set_asset_info_preview(
session: AsyncSession,
*,
asset_info_id: str,
preview_asset_id: Optional[str],
) -> None:
"""Set or clear preview_id and bump updated_at. Raises on unknown IDs."""
info = await session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if preview_asset_id is None:
info.preview_id = None
else:
# validate preview asset exists
if not await session.get(Asset, preview_asset_id):
raise ValueError(f"Preview Asset {preview_asset_id} not found")
info.preview_id = preview_asset_id
info.updated_at = utcnow()
await session.flush()

View File

@@ -0,0 +1,76 @@
import os
from typing import Optional, Sequence, Union
import sqlalchemy as sa
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from ..models import Asset, AssetCacheState, AssetInfo
async def asset_exists_by_hash(session: AsyncSession, *, asset_hash: str) -> bool:
row = (
await session.execute(
select(sa.literal(True)).select_from(Asset).where(Asset.hash == asset_hash).limit(1)
)
).first()
return row is not None
async def get_asset_by_hash(session: AsyncSession, *, asset_hash: str) -> Optional[Asset]:
return (
await session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
).scalars().first()
async def get_asset_info_by_id(session: AsyncSession, *, asset_info_id: str) -> Optional[AssetInfo]:
return await session.get(AssetInfo, asset_info_id)
async def asset_info_exists_for_asset_id(session: AsyncSession, *, asset_id: str) -> bool:
q = (
select(sa.literal(True))
.select_from(AssetInfo)
.where(AssetInfo.asset_id == asset_id)
.limit(1)
)
return (await session.execute(q)).first() is not None
async def get_cache_state_by_asset_id(session: AsyncSession, *, asset_id: str) -> Optional[AssetCacheState]:
return (
await session.execute(
select(AssetCacheState)
.where(AssetCacheState.asset_id == asset_id)
.order_by(AssetCacheState.id.asc())
.limit(1)
)
).scalars().first()
async def list_cache_states_by_asset_id(
session: AsyncSession, *, asset_id: str
) -> Union[list[AssetCacheState], Sequence[AssetCacheState]]:
return (
await session.execute(
select(AssetCacheState)
.where(AssetCacheState.asset_id == asset_id)
.order_by(AssetCacheState.id.asc())
)
).scalars().all()
def pick_best_live_path(states: Union[list[AssetCacheState], Sequence[AssetCacheState]]) -> str:
"""
Return the best on-disk path among cache states:
1) Prefer a path that exists with needs_verify == False (already verified).
2) Otherwise, pick the first path that exists.
3) Otherwise return empty string.
"""
alive = [s for s in states if getattr(s, "file_path", None) and os.path.isfile(s.file_path)]
if not alive:
return ""
for s in alive:
if not getattr(s, "needs_verify", False):
return s.file_path
return alive[0].file_path

View File

@@ -0,0 +1,6 @@
from datetime import datetime, timezone
def utcnow() -> datetime:
"""Naive UTC timestamp (no tzinfo). We always treat DB datetimes as UTC."""
return datetime.now(timezone.utc).replace(tzinfo=None)

556
app/assets/manager.py Normal file
View File

@@ -0,0 +1,556 @@
import contextlib
import logging
import mimetypes
import os
from typing import Optional, Sequence
from comfy_api.internal import async_to_sync
from ..db import create_session
from ._helpers import (
ensure_within_base,
get_name_and_tags_from_asset_path,
resolve_destination_from_tags,
)
from .api import schemas_in, schemas_out
from .database.models import Asset
from .database.services import (
add_tags_to_asset_info,
asset_exists_by_hash,
asset_info_exists_for_asset_id,
check_fs_asset_exists_quick,
create_asset_info_for_existing_asset,
delete_asset_info_by_id,
fetch_asset_info_and_asset,
fetch_asset_info_asset_and_tags,
get_asset_by_hash,
get_asset_info_by_id,
get_asset_tags,
ingest_fs_asset,
list_asset_infos_page,
list_cache_states_by_asset_id,
list_tags_with_usage,
pick_best_live_path,
remove_tags_from_asset_info,
set_asset_info_preview,
touch_asset_info_by_id,
touch_asset_infos_by_fs_path,
update_asset_info_full,
)
from .storage import hashing
async def asset_exists(*, asset_hash: str) -> bool:
async with await create_session() as session:
return await asset_exists_by_hash(session, asset_hash=asset_hash)
def populate_db_with_asset(file_path: str, tags: Optional[list[str]] = None) -> None:
if tags is None:
tags = []
try:
asset_name, path_tags = get_name_and_tags_from_asset_path(file_path)
async_to_sync.AsyncToSyncConverter.run_async_in_thread(
add_local_asset,
tags=list(dict.fromkeys([*path_tags, *tags])),
file_name=asset_name,
file_path=file_path,
)
except ValueError as e:
logging.warning("Skipping non-asset path %s: %s", file_path, e)
async def add_local_asset(tags: list[str], file_name: str, file_path: str) -> None:
abs_path = os.path.abspath(file_path)
size_bytes, mtime_ns = _get_size_mtime_ns(abs_path)
if not size_bytes:
return
async with await create_session() as session:
if await check_fs_asset_exists_quick(session, file_path=abs_path, size_bytes=size_bytes, mtime_ns=mtime_ns):
await touch_asset_infos_by_fs_path(session, file_path=abs_path)
await session.commit()
return
asset_hash = hashing.blake3_hash_sync(abs_path)
async with await create_session() as session:
await ingest_fs_asset(
session,
asset_hash="blake3:" + asset_hash,
abs_path=abs_path,
size_bytes=size_bytes,
mtime_ns=mtime_ns,
mime_type=None,
info_name=file_name,
tag_origin="automatic",
tags=tags,
)
await session.commit()
async def list_assets(
*,
include_tags: Optional[Sequence[str]] = None,
exclude_tags: Optional[Sequence[str]] = None,
name_contains: Optional[str] = None,
metadata_filter: Optional[dict] = None,
limit: int = 20,
offset: int = 0,
sort: str = "created_at",
order: str = "desc",
owner_id: str = "",
) -> schemas_out.AssetsList:
sort = _safe_sort_field(sort)
order = "desc" if (order or "desc").lower() not in {"asc", "desc"} else order.lower()
async with await create_session() as session:
infos, tag_map, total = await list_asset_infos_page(
session,
owner_id=owner_id,
include_tags=include_tags,
exclude_tags=exclude_tags,
name_contains=name_contains,
metadata_filter=metadata_filter,
limit=limit,
offset=offset,
sort=sort,
order=order,
)
summaries: list[schemas_out.AssetSummary] = []
for info in infos:
asset = info.asset
tags = tag_map.get(info.id, [])
summaries.append(
schemas_out.AssetSummary(
id=info.id,
name=info.name,
asset_hash=asset.hash if asset else None,
size=int(asset.size_bytes) if asset else None,
mime_type=asset.mime_type if asset else None,
tags=tags,
preview_url=f"/api/assets/{info.id}/content",
created_at=info.created_at,
updated_at=info.updated_at,
last_access_time=info.last_access_time,
)
)
return schemas_out.AssetsList(
assets=summaries,
total=total,
has_more=(offset + len(summaries)) < total,
)
async def get_asset(*, asset_info_id: str, owner_id: str = "") -> schemas_out.AssetDetail:
async with await create_session() as session:
res = await fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
if not res:
raise ValueError(f"AssetInfo {asset_info_id} not found")
info, asset, tag_names = res
preview_id = info.preview_id
return schemas_out.AssetDetail(
id=info.id,
name=info.name,
asset_hash=asset.hash if asset else None,
size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None,
mime_type=asset.mime_type if asset else None,
tags=tag_names,
user_metadata=info.user_metadata or {},
preview_id=preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
)
async def resolve_asset_content_for_download(
*,
asset_info_id: str,
owner_id: str = "",
) -> tuple[str, str, str]:
async with await create_session() as session:
pair = await fetch_asset_info_and_asset(session, asset_info_id=asset_info_id, owner_id=owner_id)
if not pair:
raise ValueError(f"AssetInfo {asset_info_id} not found")
info, asset = pair
states = await list_cache_states_by_asset_id(session, asset_id=asset.id)
abs_path = pick_best_live_path(states)
if not abs_path:
raise FileNotFoundError
await touch_asset_info_by_id(session, asset_info_id=asset_info_id)
await session.commit()
ctype = asset.mime_type or mimetypes.guess_type(info.name or abs_path)[0] or "application/octet-stream"
download_name = info.name or os.path.basename(abs_path)
return abs_path, ctype, download_name
async def upload_asset_from_temp_path(
spec: schemas_in.UploadAssetSpec,
*,
temp_path: str,
client_filename: Optional[str] = None,
owner_id: str = "",
expected_asset_hash: Optional[str] = None,
) -> schemas_out.AssetCreated:
try:
digest = await hashing.blake3_hash(temp_path)
except Exception as e:
raise RuntimeError(f"failed to hash uploaded file: {e}")
asset_hash = "blake3:" + digest
if expected_asset_hash and asset_hash != expected_asset_hash.strip().lower():
raise ValueError("HASH_MISMATCH")
async with await create_session() as session:
existing = await get_asset_by_hash(session, asset_hash=asset_hash)
if existing is not None:
with contextlib.suppress(Exception):
if temp_path and os.path.exists(temp_path):
os.remove(temp_path)
display_name = _safe_filename(spec.name or (client_filename or ""), fallback=digest)
info = await create_asset_info_for_existing_asset(
session,
asset_hash=asset_hash,
name=display_name,
user_metadata=spec.user_metadata or {},
tags=spec.tags or [],
tag_origin="manual",
owner_id=owner_id,
)
tag_names = await get_asset_tags(session, asset_info_id=info.id)
await session.commit()
return schemas_out.AssetCreated(
id=info.id,
name=info.name,
asset_hash=existing.hash,
size=int(existing.size_bytes) if existing.size_bytes is not None else None,
mime_type=existing.mime_type,
tags=tag_names,
user_metadata=info.user_metadata or {},
preview_id=info.preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
created_new=False,
)
base_dir, subdirs = resolve_destination_from_tags(spec.tags)
dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir
os.makedirs(dest_dir, exist_ok=True)
src_for_ext = (client_filename or spec.name or "").strip()
_ext = os.path.splitext(os.path.basename(src_for_ext))[1] if src_for_ext else ""
ext = _ext if 0 < len(_ext) <= 16 else ""
hashed_basename = f"{digest}{ext}"
dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename))
ensure_within_base(dest_abs, base_dir)
content_type = (
mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0]
or mimetypes.guess_type(hashed_basename, strict=False)[0]
or "application/octet-stream"
)
try:
os.replace(temp_path, dest_abs)
except Exception as e:
raise RuntimeError(f"failed to move uploaded file into place: {e}")
try:
size_bytes, mtime_ns = _get_size_mtime_ns(dest_abs)
except OSError as e:
raise RuntimeError(f"failed to stat destination file: {e}")
async with await create_session() as session:
result = await ingest_fs_asset(
session,
asset_hash=asset_hash,
abs_path=dest_abs,
size_bytes=size_bytes,
mtime_ns=mtime_ns,
mime_type=content_type,
info_name=_safe_filename(spec.name or (client_filename or ""), fallback=digest),
owner_id=owner_id,
preview_id=None,
user_metadata=spec.user_metadata or {},
tags=spec.tags,
tag_origin="manual",
require_existing_tags=False,
)
info_id = result["asset_info_id"]
if not info_id:
raise RuntimeError("failed to create asset metadata")
pair = await fetch_asset_info_and_asset(session, asset_info_id=info_id, owner_id=owner_id)
if not pair:
raise RuntimeError("inconsistent DB state after ingest")
info, asset = pair
tag_names = await get_asset_tags(session, asset_info_id=info.id)
await session.commit()
return schemas_out.AssetCreated(
id=info.id,
name=info.name,
asset_hash=asset.hash,
size=int(asset.size_bytes),
mime_type=asset.mime_type,
tags=tag_names,
user_metadata=info.user_metadata or {},
preview_id=info.preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
created_new=result["asset_created"],
)
async def update_asset(
*,
asset_info_id: str,
name: Optional[str] = None,
tags: Optional[list[str]] = None,
user_metadata: Optional[dict] = None,
owner_id: str = "",
) -> schemas_out.AssetUpdated:
async with await create_session() as session:
info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
info = await update_asset_info_full(
session,
asset_info_id=asset_info_id,
name=name,
tags=tags,
user_metadata=user_metadata,
tag_origin="manual",
asset_info_row=info_row,
)
tag_names = await get_asset_tags(session, asset_info_id=asset_info_id)
await session.commit()
return schemas_out.AssetUpdated(
id=info.id,
name=info.name,
asset_hash=info.asset.hash if info.asset else None,
tags=tag_names,
user_metadata=info.user_metadata or {},
updated_at=info.updated_at,
)
async def set_asset_preview(
*,
asset_info_id: str,
preview_asset_id: Optional[str],
owner_id: str = "",
) -> schemas_out.AssetDetail:
async with await create_session() as session:
info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
await set_asset_info_preview(
session,
asset_info_id=asset_info_id,
preview_asset_id=preview_asset_id,
)
res = await fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
if not res:
raise RuntimeError("State changed during preview update")
info, asset, tags = res
await session.commit()
return schemas_out.AssetDetail(
id=info.id,
name=info.name,
asset_hash=asset.hash if asset else None,
size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None,
mime_type=asset.mime_type if asset else None,
tags=tags,
user_metadata=info.user_metadata or {},
preview_id=info.preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
)
async def delete_asset_reference(*, asset_info_id: str, owner_id: str, delete_content_if_orphan: bool = True) -> bool:
async with await create_session() as session:
info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id)
asset_id = info_row.asset_id if info_row else None
deleted = await delete_asset_info_by_id(session, asset_info_id=asset_info_id, owner_id=owner_id)
if not deleted:
await session.commit()
return False
if not delete_content_if_orphan or not asset_id:
await session.commit()
return True
still_exists = await asset_info_exists_for_asset_id(session, asset_id=asset_id)
if still_exists:
await session.commit()
return True
states = await list_cache_states_by_asset_id(session, asset_id=asset_id)
file_paths = [s.file_path for s in (states or []) if getattr(s, "file_path", None)]
asset_row = await session.get(Asset, asset_id)
if asset_row is not None:
await session.delete(asset_row)
await session.commit()
for p in file_paths:
with contextlib.suppress(Exception):
if p and os.path.isfile(p):
os.remove(p)
return True
async def create_asset_from_hash(
*,
hash_str: str,
name: str,
tags: Optional[list[str]] = None,
user_metadata: Optional[dict] = None,
owner_id: str = "",
) -> Optional[schemas_out.AssetCreated]:
canonical = hash_str.strip().lower()
async with await create_session() as session:
asset = await get_asset_by_hash(session, asset_hash=canonical)
if not asset:
return None
info = await create_asset_info_for_existing_asset(
session,
asset_hash=canonical,
name=_safe_filename(name, fallback=canonical.split(":", 1)[1]),
user_metadata=user_metadata or {},
tags=tags or [],
tag_origin="manual",
owner_id=owner_id,
)
tag_names = await get_asset_tags(session, asset_info_id=info.id)
await session.commit()
return schemas_out.AssetCreated(
id=info.id,
name=info.name,
asset_hash=asset.hash,
size=int(asset.size_bytes),
mime_type=asset.mime_type,
tags=tag_names,
user_metadata=info.user_metadata or {},
preview_id=info.preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
created_new=False,
)
async def list_tags(
*,
prefix: Optional[str] = None,
limit: int = 100,
offset: int = 0,
order: str = "count_desc",
include_zero: bool = True,
owner_id: str = "",
) -> schemas_out.TagsList:
limit = max(1, min(1000, limit))
offset = max(0, offset)
async with await create_session() as session:
rows, total = await list_tags_with_usage(
session,
prefix=prefix,
limit=limit,
offset=offset,
include_zero=include_zero,
order=order,
owner_id=owner_id,
)
tags = [schemas_out.TagUsage(name=name, count=count, type=tag_type) for (name, tag_type, count) in rows]
return schemas_out.TagsList(tags=tags, total=total, has_more=(offset + len(tags)) < total)
async def add_tags_to_asset(
*,
asset_info_id: str,
tags: list[str],
origin: str = "manual",
owner_id: str = "",
) -> schemas_out.TagsAdd:
async with await create_session() as session:
info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
data = await add_tags_to_asset_info(
session,
asset_info_id=asset_info_id,
tags=tags,
origin=origin,
create_if_missing=True,
asset_info_row=info_row,
)
await session.commit()
return schemas_out.TagsAdd(**data)
async def remove_tags_from_asset(
*,
asset_info_id: str,
tags: list[str],
owner_id: str = "",
) -> schemas_out.TagsRemove:
async with await create_session() as session:
info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
data = await remove_tags_from_asset_info(
session,
asset_info_id=asset_info_id,
tags=tags,
)
await session.commit()
return schemas_out.TagsRemove(**data)
def _safe_sort_field(requested: Optional[str]) -> str:
if not requested:
return "created_at"
v = requested.lower()
if v in {"name", "created_at", "updated_at", "size", "last_access_time"}:
return v
return "created_at"
def _get_size_mtime_ns(path: str) -> tuple[int, int]:
st = os.stat(path, follow_symlinks=True)
return st.st_size, getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
def _safe_filename(name: Optional[str], fallback: str) -> str:
n = os.path.basename((name or "").strip() or fallback)
if n:
return n
return fallback

501
app/assets/scanner.py Normal file
View File

@@ -0,0 +1,501 @@
import asyncio
import contextlib
import logging
import os
import time
from dataclasses import dataclass, field
from typing import Literal, Optional
import sqlalchemy as sa
import folder_paths
from ..db import create_session
from ._helpers import (
collect_models_files,
compute_relative_filename,
get_comfy_models_folders,
get_name_and_tags_from_asset_path,
list_tree,
new_scan_id,
prefixes_for_root,
ts_to_iso,
)
from .api import schemas_in, schemas_out
from .database.helpers import (
add_missing_tag_for_asset_id,
ensure_tags_exist,
escape_like_prefix,
fast_asset_file_check,
remove_missing_tag_for_asset_id,
seed_from_paths_batch,
)
from .database.models import Asset, AssetCacheState, AssetInfo
from .database.services import (
compute_hash_and_dedup_for_cache_state,
list_cache_states_by_asset_id,
list_cache_states_with_asset_under_prefixes,
list_unhashed_candidates_under_prefixes,
list_verify_candidates_under_prefixes,
)
LOGGER = logging.getLogger(__name__)
SLOW_HASH_CONCURRENCY = 1
@dataclass
class ScanProgress:
scan_id: str
root: schemas_in.RootType
status: Literal["scheduled", "running", "completed", "failed", "cancelled"] = "scheduled"
scheduled_at: float = field(default_factory=lambda: time.time())
started_at: Optional[float] = None
finished_at: Optional[float] = None
discovered: int = 0
processed: int = 0
file_errors: list[dict] = field(default_factory=list)
@dataclass
class SlowQueueState:
queue: asyncio.Queue
workers: list[asyncio.Task] = field(default_factory=list)
closed: bool = False
RUNNING_TASKS: dict[schemas_in.RootType, asyncio.Task] = {}
PROGRESS_BY_ROOT: dict[schemas_in.RootType, ScanProgress] = {}
SLOW_STATE_BY_ROOT: dict[schemas_in.RootType, SlowQueueState] = {}
def current_statuses() -> schemas_out.AssetScanStatusResponse:
scans = []
for root in schemas_in.ALLOWED_ROOTS:
prog = PROGRESS_BY_ROOT.get(root)
if not prog:
continue
scans.append(_scan_progress_to_scan_status_model(prog))
return schemas_out.AssetScanStatusResponse(scans=scans)
async def schedule_scans(roots: list[schemas_in.RootType]) -> schemas_out.AssetScanStatusResponse:
results: list[ScanProgress] = []
for root in roots:
if root in RUNNING_TASKS and not RUNNING_TASKS[root].done():
results.append(PROGRESS_BY_ROOT[root])
continue
prog = ScanProgress(scan_id=new_scan_id(root), root=root, status="scheduled")
PROGRESS_BY_ROOT[root] = prog
state = SlowQueueState(queue=asyncio.Queue())
SLOW_STATE_BY_ROOT[root] = state
RUNNING_TASKS[root] = asyncio.create_task(
_run_hash_verify_pipeline(root, prog, state),
name=f"asset-scan:{root}",
)
results.append(prog)
return _status_response_for(results)
async def sync_seed_assets(roots: list[schemas_in.RootType]) -> None:
t_total = time.perf_counter()
created = 0
skipped_existing = 0
paths: list[str] = []
try:
existing_paths: set[str] = set()
for r in roots:
try:
survivors = await _fast_db_consistency_pass(r, collect_existing_paths=True, update_missing_tags=True)
if survivors:
existing_paths.update(survivors)
except Exception as ex:
LOGGER.exception("fast DB reconciliation failed for %s: %s", r, ex)
if "models" in roots:
paths.extend(collect_models_files())
if "input" in roots:
paths.extend(list_tree(folder_paths.get_input_directory()))
if "output" in roots:
paths.extend(list_tree(folder_paths.get_output_directory()))
specs: list[dict] = []
tag_pool: set[str] = set()
for p in paths:
ap = os.path.abspath(p)
if ap in existing_paths:
skipped_existing += 1
continue
try:
st = os.stat(ap, follow_symlinks=True)
except OSError:
continue
if not st.st_size:
continue
name, tags = get_name_and_tags_from_asset_path(ap)
specs.append(
{
"abs_path": ap,
"size_bytes": st.st_size,
"mtime_ns": getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000)),
"info_name": name,
"tags": tags,
"fname": compute_relative_filename(ap),
}
)
for t in tags:
tag_pool.add(t)
if not specs:
return
async with await create_session() as sess:
if tag_pool:
await ensure_tags_exist(sess, tag_pool, tag_type="user")
result = await seed_from_paths_batch(sess, specs=specs, owner_id="")
created += result["inserted_infos"]
await sess.commit()
finally:
LOGGER.info(
"Assets scan(roots=%s) completed in %.3fs (created=%d, skipped_existing=%d, total_seen=%d)",
roots,
time.perf_counter() - t_total,
created,
skipped_existing,
len(paths),
)
def _status_response_for(progresses: list[ScanProgress]) -> schemas_out.AssetScanStatusResponse:
return schemas_out.AssetScanStatusResponse(scans=[_scan_progress_to_scan_status_model(p) for p in progresses])
def _scan_progress_to_scan_status_model(progress: ScanProgress) -> schemas_out.AssetScanStatus:
return schemas_out.AssetScanStatus(
scan_id=progress.scan_id,
root=progress.root,
status=progress.status,
scheduled_at=ts_to_iso(progress.scheduled_at),
started_at=ts_to_iso(progress.started_at),
finished_at=ts_to_iso(progress.finished_at),
discovered=progress.discovered,
processed=progress.processed,
file_errors=[
schemas_out.AssetScanError(
path=e.get("path", ""),
message=e.get("message", ""),
at=e.get("at"),
)
for e in (progress.file_errors or [])
],
)
async def _run_hash_verify_pipeline(root: schemas_in.RootType, prog: ScanProgress, state: SlowQueueState) -> None:
prog.status = "running"
prog.started_at = time.time()
try:
prefixes = prefixes_for_root(root)
await _fast_db_consistency_pass(root)
# collect candidates from DB
async with await create_session() as sess:
verify_ids = await list_verify_candidates_under_prefixes(sess, prefixes=prefixes)
unhashed_ids = await list_unhashed_candidates_under_prefixes(sess, prefixes=prefixes)
# dedupe: prioritize verification first
seen = set()
ordered: list[int] = []
for lst in (verify_ids, unhashed_ids):
for sid in lst:
if sid not in seen:
seen.add(sid)
ordered.append(sid)
prog.discovered = len(ordered)
# queue up work
for sid in ordered:
await state.queue.put(sid)
state.closed = True
_start_state_workers(root, prog, state)
await _await_state_workers_then_finish(root, prog, state)
except asyncio.CancelledError:
prog.status = "cancelled"
raise
except Exception as exc:
_append_error(prog, path="", message=str(exc))
prog.status = "failed"
prog.finished_at = time.time()
LOGGER.exception("Asset scan failed for %s", root)
finally:
RUNNING_TASKS.pop(root, None)
async def _reconcile_missing_tags_for_root(root: schemas_in.RootType, prog: ScanProgress) -> None:
"""
Detect missing files quickly and toggle 'missing' tag per asset_id.
Rules:
- Only hashed assets (assets.hash != NULL) participate in missing tagging.
- We consider ALL cache states of the asset (across roots) before tagging.
"""
if root == "models":
bases: list[str] = []
for _bucket, paths in get_comfy_models_folders():
bases.extend(paths)
elif root == "input":
bases = [folder_paths.get_input_directory()]
else:
bases = [folder_paths.get_output_directory()]
try:
async with await create_session() as sess:
# state + hash + size for the current root
rows = await list_cache_states_with_asset_under_prefixes(sess, prefixes=bases)
# Track fast_ok within the scanned root and whether the asset is hashed
by_asset: dict[str, dict[str, bool]] = {}
for state, a_hash, size_db in rows:
aid = state.asset_id
acc = by_asset.get(aid)
if acc is None:
acc = {"any_fast_ok_here": False, "hashed": (a_hash is not None), "size_db": int(size_db or 0)}
by_asset[aid] = acc
try:
if acc["hashed"]:
st = os.stat(state.file_path, follow_symlinks=True)
if fast_asset_file_check(mtime_db=state.mtime_ns, size_db=acc["size_db"], stat_result=st):
acc["any_fast_ok_here"] = True
except FileNotFoundError:
pass
except OSError as e:
_append_error(prog, path=state.file_path, message=str(e))
# Decide per asset, considering ALL its states (not just this root)
for aid, acc in by_asset.items():
try:
if not acc["hashed"]:
# Never tag seed assets as missing
continue
any_fast_ok_global = acc["any_fast_ok_here"]
if not any_fast_ok_global:
# Check other states outside this root
others = await list_cache_states_by_asset_id(sess, asset_id=aid)
for st in others:
try:
any_fast_ok_global = fast_asset_file_check(
mtime_db=st.mtime_ns,
size_db=acc["size_db"],
stat_result=os.stat(st.file_path, follow_symlinks=True),
)
except OSError:
continue
if any_fast_ok_global:
await remove_missing_tag_for_asset_id(sess, asset_id=aid)
else:
await add_missing_tag_for_asset_id(sess, asset_id=aid, origin="automatic")
except Exception as ex:
_append_error(prog, path="", message=f"reconcile {aid[:8]}: {ex}")
await sess.commit()
except Exception as e:
_append_error(prog, path="", message=f"reconcile failed: {e}")
def _start_state_workers(root: schemas_in.RootType, prog: ScanProgress, state: SlowQueueState) -> None:
if state.workers:
return
async def _worker(_wid: int):
while True:
sid = await state.queue.get()
try:
if sid is None:
return
try:
async with await create_session() as sess:
# Optional: fetch path for better error messages
st = await sess.get(AssetCacheState, sid)
try:
await compute_hash_and_dedup_for_cache_state(sess, state_id=sid)
await sess.commit()
except Exception as e:
path = st.file_path if st else f"state:{sid}"
_append_error(prog, path=path, message=str(e))
raise
except Exception:
pass
finally:
prog.processed += 1
finally:
state.queue.task_done()
state.workers = [
asyncio.create_task(_worker(i), name=f"asset-hash:{root}:{i}")
for i in range(SLOW_HASH_CONCURRENCY)
]
async def _close_when_ready():
while not state.closed:
await asyncio.sleep(0.05)
for _ in range(SLOW_HASH_CONCURRENCY):
await state.queue.put(None)
asyncio.create_task(_close_when_ready())
async def _await_state_workers_then_finish(
root: schemas_in.RootType, prog: ScanProgress, state: SlowQueueState
) -> None:
if state.workers:
await asyncio.gather(*state.workers, return_exceptions=True)
await _reconcile_missing_tags_for_root(root, prog)
prog.finished_at = time.time()
prog.status = "completed"
def _append_error(prog: ScanProgress, *, path: str, message: str) -> None:
prog.file_errors.append({
"path": path,
"message": message,
"at": ts_to_iso(time.time()),
})
async def _fast_db_consistency_pass(
root: schemas_in.RootType,
*,
collect_existing_paths: bool = False,
update_missing_tags: bool = False,
) -> Optional[set[str]]:
"""Fast DB+FS pass for a root:
- Toggle needs_verify per state using fast check
- For hashed assets with at least one fast-ok state in this root: delete stale missing states
- For seed assets with all states missing: delete Asset and its AssetInfos
- Optionally add/remove 'missing' tags based on fast-ok in this root
- Optionally return surviving absolute paths
"""
prefixes = prefixes_for_root(root)
if not prefixes:
return set() if collect_existing_paths else None
conds = []
for p in prefixes:
base = os.path.abspath(p)
if not base.endswith(os.sep):
base += os.sep
escaped, esc = escape_like_prefix(base)
conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc))
async with await create_session() as sess:
rows = (
await sess.execute(
sa.select(
AssetCacheState.id,
AssetCacheState.file_path,
AssetCacheState.mtime_ns,
AssetCacheState.needs_verify,
AssetCacheState.asset_id,
Asset.hash,
Asset.size_bytes,
)
.join(Asset, Asset.id == AssetCacheState.asset_id)
.where(sa.or_(*conds))
.order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc())
)
).all()
by_asset: dict[str, dict] = {}
for sid, fp, mtime_db, needs_verify, aid, a_hash, a_size in rows:
acc = by_asset.get(aid)
if acc is None:
acc = {"hash": a_hash, "size_db": int(a_size or 0), "states": []}
by_asset[aid] = acc
fast_ok = False
try:
exists = True
fast_ok = fast_asset_file_check(
mtime_db=mtime_db,
size_db=acc["size_db"],
stat_result=os.stat(fp, follow_symlinks=True),
)
except FileNotFoundError:
exists = False
except OSError:
exists = False
acc["states"].append({
"sid": sid,
"fp": fp,
"exists": exists,
"fast_ok": fast_ok,
"needs_verify": bool(needs_verify),
})
to_set_verify: list[int] = []
to_clear_verify: list[int] = []
stale_state_ids: list[int] = []
survivors: set[str] = set()
for aid, acc in by_asset.items():
a_hash = acc["hash"]
states = acc["states"]
any_fast_ok = any(s["fast_ok"] for s in states)
all_missing = all(not s["exists"] for s in states)
for s in states:
if not s["exists"]:
continue
if s["fast_ok"] and s["needs_verify"]:
to_clear_verify.append(s["sid"])
if not s["fast_ok"] and not s["needs_verify"]:
to_set_verify.append(s["sid"])
if a_hash is None:
if states and all_missing: # remove seed Asset completely, if no valid AssetCache exists
await sess.execute(sa.delete(AssetInfo).where(AssetInfo.asset_id == aid))
asset = await sess.get(Asset, aid)
if asset:
await sess.delete(asset)
else:
for s in states:
if s["exists"]:
survivors.add(os.path.abspath(s["fp"]))
continue
if any_fast_ok: # if Asset has at least one valid AssetCache record, remove any invalid AssetCache records
for s in states:
if not s["exists"]:
stale_state_ids.append(s["sid"])
if update_missing_tags:
with contextlib.suppress(Exception):
await remove_missing_tag_for_asset_id(sess, asset_id=aid)
elif update_missing_tags:
with contextlib.suppress(Exception):
await add_missing_tag_for_asset_id(sess, asset_id=aid, origin="automatic")
for s in states:
if s["exists"]:
survivors.add(os.path.abspath(s["fp"]))
if stale_state_ids:
await sess.execute(sa.delete(AssetCacheState).where(AssetCacheState.id.in_(stale_state_ids)))
if to_set_verify:
await sess.execute(
sa.update(AssetCacheState)
.where(AssetCacheState.id.in_(to_set_verify))
.values(needs_verify=True)
)
if to_clear_verify:
await sess.execute(
sa.update(AssetCacheState)
.where(AssetCacheState.id.in_(to_clear_verify))
.values(needs_verify=False)
)
await sess.commit()
return survivors if collect_existing_paths else None

View File

View File

@@ -0,0 +1,72 @@
import asyncio
import os
from typing import IO, Union
from blake3 import blake3
DEFAULT_CHUNK = 8 * 1024 * 1024 # 8 MiB
def _hash_file_obj_sync(file_obj: IO[bytes], chunk_size: int) -> str:
"""Hash an already-open binary file object by streaming in chunks.
- Seeks to the beginning before reading (if supported).
- Restores the original position afterward (if tell/seek are supported).
"""
if chunk_size <= 0:
chunk_size = DEFAULT_CHUNK
orig_pos = None
if hasattr(file_obj, "tell"):
orig_pos = file_obj.tell()
try:
if hasattr(file_obj, "seek"):
file_obj.seek(0)
h = blake3()
while True:
chunk = file_obj.read(chunk_size)
if not chunk:
break
h.update(chunk)
return h.hexdigest()
finally:
if hasattr(file_obj, "seek") and orig_pos is not None:
file_obj.seek(orig_pos)
def blake3_hash_sync(
fp: Union[str, bytes, os.PathLike[str], os.PathLike[bytes], IO[bytes]],
chunk_size: int = DEFAULT_CHUNK,
) -> str:
"""Returns a BLAKE3 hex digest for ``fp``, which may be:
- a filename (str/bytes) or PathLike
- an open binary file object
If ``fp`` is a file object, it must be opened in **binary** mode and support
``read``, ``seek``, and ``tell``. The function will seek to the start before
reading and will attempt to restore the original position afterward.
"""
if hasattr(fp, "read"):
return _hash_file_obj_sync(fp, chunk_size)
with open(os.fspath(fp), "rb") as f:
return _hash_file_obj_sync(f, chunk_size)
async def blake3_hash(
fp: Union[str, bytes, os.PathLike[str], os.PathLike[bytes], IO[bytes]],
chunk_size: int = DEFAULT_CHUNK,
) -> str:
"""Async wrapper for ``blake3_hash_sync``.
Uses a worker thread so the event loop remains responsive.
"""
# If it is a path, open inside the worker thread to keep I/O off the loop.
if hasattr(fp, "read"):
return await asyncio.to_thread(blake3_hash_sync, fp, chunk_size)
def _worker() -> str:
with open(os.fspath(fp), "rb") as f:
return _hash_file_obj_sync(f, chunk_size)
return await asyncio.to_thread(_worker)

View File

@@ -1,112 +0,0 @@
import logging
import os
import shutil
from app.logger import log_startup_warning
from utils.install_util import get_missing_requirements_message
from comfy.cli_args import args
_DB_AVAILABLE = False
Session = None
try:
from alembic import command
from alembic.config import Config
from alembic.runtime.migration import MigrationContext
from alembic.script import ScriptDirectory
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
_DB_AVAILABLE = True
except ImportError as e:
log_startup_warning(
f"""
------------------------------------------------------------------------
Error importing dependencies: {e}
{get_missing_requirements_message()}
This error is happening because ComfyUI now uses a local sqlite database.
------------------------------------------------------------------------
""".strip()
)
def dependencies_available():
"""
Temporary function to check if the dependencies are available
"""
return _DB_AVAILABLE
def can_create_session():
"""
Temporary function to check if the database is available to create a session
During initial release there may be environmental issues (or missing dependencies) that prevent the database from being created
"""
return dependencies_available() and Session is not None
def get_alembic_config():
root_path = os.path.join(os.path.dirname(__file__), "../..")
config_path = os.path.abspath(os.path.join(root_path, "alembic.ini"))
scripts_path = os.path.abspath(os.path.join(root_path, "alembic_db"))
config = Config(config_path)
config.set_main_option("script_location", scripts_path)
config.set_main_option("sqlalchemy.url", args.database_url)
return config
def get_db_path():
url = args.database_url
if url.startswith("sqlite:///"):
return url.split("///")[1]
else:
raise ValueError(f"Unsupported database URL '{url}'.")
def init_db():
db_url = args.database_url
logging.debug(f"Database URL: {db_url}")
db_path = get_db_path()
db_exists = os.path.exists(db_path)
config = get_alembic_config()
# Check if we need to upgrade
engine = create_engine(db_url)
conn = engine.connect()
context = MigrationContext.configure(conn)
current_rev = context.get_current_revision()
script = ScriptDirectory.from_config(config)
target_rev = script.get_current_head()
if target_rev is None:
logging.warning("No target revision found.")
elif current_rev != target_rev:
# Backup the database pre upgrade
backup_path = db_path + ".bkp"
if db_exists:
shutil.copy(db_path, backup_path)
else:
backup_path = None
try:
command.upgrade(config, target_rev)
logging.info(f"Database upgraded from {current_rev} to {target_rev}")
except Exception as e:
if backup_path:
# Restore the database from backup if upgrade fails
shutil.copy(backup_path, db_path)
os.remove(backup_path)
logging.exception("Error upgrading database: ")
raise e
global Session
Session = sessionmaker(bind=engine)
def create_session():
return Session()

View File

@@ -1,14 +0,0 @@
from sqlalchemy.orm import declarative_base
Base = declarative_base()
def to_dict(obj):
fields = obj.__table__.columns.keys()
return {
field: (val.to_dict() if hasattr(val, "to_dict") else val)
for field in fields
if (val := getattr(obj, field))
}
# TODO: Define models here

255
app/db.py Normal file
View File

@@ -0,0 +1,255 @@
import logging
import os
import shutil
from contextlib import asynccontextmanager
from typing import Optional
from alembic import command
from alembic.config import Config
from alembic.runtime.migration import MigrationContext
from alembic.script import ScriptDirectory
from sqlalchemy import create_engine, text
from sqlalchemy.engine import make_url
from sqlalchemy.ext.asyncio import (
AsyncEngine,
AsyncSession,
async_sessionmaker,
create_async_engine,
)
from comfy.cli_args import args
LOGGER = logging.getLogger(__name__)
ENGINE: Optional[AsyncEngine] = None
SESSION: Optional[async_sessionmaker] = None
def _root_paths():
"""Resolve alembic.ini and migrations script folder."""
root_path = os.path.abspath(os.path.dirname(__file__))
config_path = os.path.abspath(os.path.join(root_path, "../alembic.ini"))
scripts_path = os.path.abspath(os.path.join(root_path, "alembic_db"))
return config_path, scripts_path
def _absolutize_sqlite_url(db_url: str) -> str:
"""Make SQLite database path absolute. No-op for non-SQLite URLs."""
try:
u = make_url(db_url)
except Exception:
return db_url
if not u.drivername.startswith("sqlite"):
return db_url
db_path: str = u.database or ""
if isinstance(db_path, str) and db_path.startswith("file:"):
return str(u) # Do not touch SQLite URI databases like: "file:xxx?mode=memory&cache=shared"
if not os.path.isabs(db_path):
db_path = os.path.abspath(os.path.join(os.getcwd(), db_path))
u = u.set(database=db_path)
return str(u)
def _normalize_sqlite_memory_url(db_url: str) -> tuple[str, bool]:
"""
If db_url points at an in-memory SQLite DB (":memory:" or file:... mode=memory),
rewrite it to a *named* shared in-memory URI and ensure 'uri=true' is present.
Returns: (normalized_url, is_memory)
"""
try:
u = make_url(db_url)
except Exception:
return db_url, False
if not u.drivername.startswith("sqlite"):
return db_url, False
db = u.database or ""
if db == ":memory:":
u = u.set(database=f"file:comfyui_db_{os.getpid()}?mode=memory&cache=shared&uri=true")
return str(u), True
if isinstance(db, str) and db.startswith("file:") and "mode=memory" in db:
if "uri=true" not in db:
u = u.set(database=(db + ("&" if "?" in db else "?") + "uri=true"))
return str(u), True
return str(u), False
def _get_sqlite_file_path(sync_url: str) -> Optional[str]:
"""Return the on-disk path for a SQLite URL, else None."""
try:
u = make_url(sync_url)
except Exception:
return None
if not u.drivername.startswith("sqlite"):
return None
db_path = u.database
if isinstance(db_path, str) and db_path.startswith("file:"):
return None # Not a real file if it is a URI like "file:...?"
return db_path
def _get_alembic_config(sync_url: str) -> Config:
"""Prepare Alembic Config with script location and DB URL."""
config_path, scripts_path = _root_paths()
cfg = Config(config_path)
cfg.set_main_option("script_location", scripts_path)
cfg.set_main_option("sqlalchemy.url", sync_url)
return cfg
async def init_db_engine() -> None:
"""Initialize async engine + sessionmaker and run migrations to head.
This must be called once on application startup before any DB usage.
"""
global ENGINE, SESSION
if ENGINE is not None:
return
raw_url = args.database_url
if not raw_url:
raise RuntimeError("Database URL is not configured.")
db_url, is_mem = _normalize_sqlite_memory_url(raw_url)
db_url = _absolutize_sqlite_url(db_url)
# Prepare async engine
connect_args = {}
if db_url.startswith("sqlite"):
connect_args = {
"check_same_thread": False,
"timeout": 12,
}
if is_mem:
connect_args["uri"] = True
ENGINE = create_async_engine(
db_url,
connect_args=connect_args,
pool_pre_ping=True,
future=True,
)
# Enforce SQLite pragmas on the async engine
if db_url.startswith("sqlite"):
async with ENGINE.begin() as conn:
if not is_mem:
# WAL for concurrency and durability, Foreign Keys for referential integrity
current_mode = (await conn.execute(text("PRAGMA journal_mode;"))).scalar()
if str(current_mode).lower() != "wal":
new_mode = (await conn.execute(text("PRAGMA journal_mode=WAL;"))).scalar()
if str(new_mode).lower() != "wal":
raise RuntimeError("Failed to set SQLite journal mode to WAL.")
LOGGER.info("SQLite journal mode set to WAL.")
await conn.execute(text("PRAGMA foreign_keys = ON;"))
await conn.execute(text("PRAGMA synchronous = NORMAL;"))
await _run_migrations(database_url=db_url, connect_args=connect_args)
SESSION = async_sessionmaker(
bind=ENGINE,
class_=AsyncSession,
expire_on_commit=False,
autoflush=False,
autocommit=False,
)
async def _run_migrations(database_url: str, connect_args: dict) -> None:
if database_url.find("postgresql+psycopg") == -1:
"""SQLite: Convert an async SQLAlchemy URL to a sync URL for Alembic."""
u = make_url(database_url)
driver = u.drivername
if not driver.startswith("sqlite+aiosqlite"):
raise ValueError(f"Unsupported DB driver: {driver}")
database_url, is_mem = _normalize_sqlite_memory_url(str(u.set(drivername="sqlite")))
database_url = _absolutize_sqlite_url(database_url)
cfg = _get_alembic_config(database_url)
engine = create_engine(database_url, future=True, connect_args=connect_args)
with engine.connect() as conn:
context = MigrationContext.configure(conn)
current_rev = context.get_current_revision()
script = ScriptDirectory.from_config(cfg)
target_rev = script.get_current_head()
if target_rev is None:
LOGGER.warning("Alembic: no target revision found.")
return
if current_rev == target_rev:
LOGGER.debug("Alembic: database already at head %s", target_rev)
return
LOGGER.info("Alembic: upgrading database from %s to %s", current_rev, target_rev)
# Optional backup for SQLite file DBs
backup_path = None
sqlite_path = _get_sqlite_file_path(database_url)
if sqlite_path and os.path.exists(sqlite_path):
backup_path = sqlite_path + ".bkp"
try:
shutil.copy(sqlite_path, backup_path)
except Exception as exc:
LOGGER.warning("Failed to create SQLite backup before migration: %s", exc)
try:
command.upgrade(cfg, target_rev)
except Exception:
if backup_path and os.path.exists(backup_path):
LOGGER.exception("Error upgrading database, attempting restore from backup.")
try:
shutil.copy(backup_path, sqlite_path) # restore
os.remove(backup_path)
except Exception as re:
LOGGER.error("Failed to restore SQLite backup: %s", re)
else:
LOGGER.exception("Error upgrading database, backup is not available.")
raise
def get_engine():
"""Return the global async engine (initialized after init_db_engine())."""
if ENGINE is None:
raise RuntimeError("Engine is not initialized. Call init_db_engine() first.")
return ENGINE
def get_session_maker():
"""Return the global async_sessionmaker (initialized after init_db_engine())."""
if SESSION is None:
raise RuntimeError("Session maker is not initialized. Call init_db_engine() first.")
return SESSION
@asynccontextmanager
async def session_scope():
"""Async context manager for a unit of work:
async with session_scope() as sess:
... use sess ...
"""
maker = get_session_maker()
async with maker() as sess:
try:
yield sess
await sess.commit()
except Exception:
await sess.rollback()
raise
async def create_session():
"""Convenience helper to acquire a single AsyncSession instance.
Typical usage:
async with (await create_session()) as sess:
...
"""
maker = get_session_maker()
return maker()

View File

@@ -42,7 +42,6 @@ def get_installed_frontend_version():
frontend_version_str = version("comfyui-frontend-package")
return frontend_version_str
def get_required_frontend_version():
"""Get the required frontend version from requirements.txt."""
try:
@@ -64,7 +63,6 @@ def get_required_frontend_version():
logging.error(f"Error reading requirements.txt: {e}")
return None
def check_frontend_version():
"""Check if the frontend version is up to date."""
@@ -198,6 +196,17 @@ def download_release_asset_zip(release: Release, destination_path: str) -> None:
class FrontendManager:
"""
A class to manage ComfyUI frontend versions and installations.
This class handles the initialization and management of different frontend versions,
including the default frontend from the pip package and custom frontend versions
from GitHub repositories.
Attributes:
CUSTOM_FRONTENDS_ROOT (str): The root directory where custom frontend versions are stored.
"""
CUSTOM_FRONTENDS_ROOT = str(Path(__file__).parents[1] / "web_custom_versions")
@classmethod
@@ -205,39 +214,17 @@ class FrontendManager:
"""Get the required frontend package version."""
return get_required_frontend_version()
@classmethod
def get_installed_templates_version(cls) -> str:
"""Get the currently installed workflow templates package version."""
try:
templates_version_str = version("comfyui-workflow-templates")
return templates_version_str
except Exception:
return None
@classmethod
def get_required_templates_version(cls) -> str:
"""Get the required workflow templates version from requirements.txt."""
try:
with open(requirements_path, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if line.startswith("comfyui-workflow-templates=="):
version_str = line.split("==")[-1]
if not is_valid_version(version_str):
logging.error(f"Invalid templates version format in requirements.txt: {version_str}")
return None
return version_str
logging.error("comfyui-workflow-templates not found in requirements.txt")
return None
except FileNotFoundError:
logging.error("requirements.txt not found. Cannot determine required templates version.")
return None
except Exception as e:
logging.error(f"Error reading requirements.txt: {e}")
return None
@classmethod
def default_frontend_path(cls) -> str:
"""
Get the path to the default frontend installation from the pip package.
Returns:
str: The path to the default frontend static files.
Raises:
SystemExit: If the comfyui-frontend-package is not installed.
"""
try:
import comfyui_frontend_package
@@ -258,6 +245,15 @@ comfyui-frontend-package is not installed.
@classmethod
def templates_path(cls) -> str:
"""
Get the path to the workflow templates.
Returns:
str: The path to the workflow templates directory.
Raises:
SystemExit: If the comfyui-workflow-templates package is not installed.
"""
try:
import comfyui_workflow_templates
@@ -293,11 +289,16 @@ comfyui-workflow-templates is not installed.
@classmethod
def parse_version_string(cls, value: str) -> tuple[str, str, str]:
"""
Parse a version string into its components.
The version string should be in the format: 'owner/repo@version'
where version can be either a semantic version (v1.2.3) or 'latest'.
Args:
value (str): The version string to parse.
Returns:
tuple[str, str]: A tuple containing provider name and version.
tuple[str, str, str]: A tuple containing (owner, repo, version).
Raises:
argparse.ArgumentTypeError: If the version string is invalid.
@@ -314,18 +315,22 @@ comfyui-workflow-templates is not installed.
cls, version_string: str, provider: Optional[FrontEndProvider] = None
) -> str:
"""
Initializes the frontend for the specified version.
Initialize a frontend version without error handling.
This method attempts to initialize a specific frontend version, either from
the default pip package or from a custom GitHub repository. It will download
and extract the frontend files if necessary.
Args:
version_string (str): The version string.
provider (FrontEndProvider, optional): The provider to use. Defaults to None.
version_string (str): The version string specifying which frontend to use.
provider (FrontEndProvider, optional): The provider to use for custom frontends.
Returns:
str: The path to the initialized frontend.
Raises:
Exception: If there is an error during the initialization process.
main error source might be request timeout or invalid URL.
Exception: If there is an error during initialization (e.g., network timeout,
invalid URL, or missing assets).
"""
if version_string == DEFAULT_VERSION_STRING:
check_frontend_version()
@@ -377,13 +382,17 @@ comfyui-workflow-templates is not installed.
@classmethod
def init_frontend(cls, version_string: str) -> str:
"""
Initializes the frontend with the specified version string.
Initialize a frontend version with error handling.
This is the main method to initialize a frontend version. It wraps init_frontend_unsafe
with error handling, falling back to the default frontend if initialization fails.
Args:
version_string (str): The version string to initialize the frontend with.
version_string (str): The version string specifying which frontend to use.
Returns:
str: The path of the initialized frontend.
str: The path to the initialized frontend. If initialization fails,
returns the path to the default frontend.
"""
try:
return cls.init_frontend_unsafe(version_string)

View File

@@ -212,7 +212,8 @@ parser.add_argument(
database_default_path = os.path.abspath(
os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db")
)
parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
parser.add_argument("--database-url", type=str, default=f"sqlite+aiosqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite+aiosqlite:///:memory:'.")
parser.add_argument("--disable-assets-autoscan", action="store_true", help="Disable asset scanning on startup for database synchronization.")
if comfy.options.args_parsing:
args = parser.parse_args()

View File

@@ -23,6 +23,8 @@ class MusicDCAE(torch.nn.Module):
else:
self.source_sample_rate = source_sample_rate
# self.resampler = torchaudio.transforms.Resample(source_sample_rate, 44100)
self.transform = transforms.Compose([
transforms.Normalize(0.5, 0.5),
])
@@ -35,6 +37,10 @@ class MusicDCAE(torch.nn.Module):
self.scale_factor = 0.1786
self.shift_factor = -1.9091
def load_audio(self, audio_path):
audio, sr = torchaudio.load(audio_path)
return audio, sr
def forward_mel(self, audios):
mels = []
for i in range(len(audios)):
@@ -67,8 +73,10 @@ class MusicDCAE(torch.nn.Module):
latent = self.dcae.encoder(mel.unsqueeze(0))
latents.append(latent)
latents = torch.cat(latents, dim=0)
# latent_lengths = (audio_lengths / sr * 44100 / 512 / self.time_dimention_multiple).long()
latents = (latents - self.shift_factor) * self.scale_factor
return latents
# return latents, latent_lengths
@torch.no_grad()
def decode(self, latents, audio_lengths=None, sr=None):
@@ -83,7 +91,9 @@ class MusicDCAE(torch.nn.Module):
wav = self.vocoder.decode(mels[0]).squeeze(1)
if sr is not None:
# resampler = torchaudio.transforms.Resample(44100, sr).to(latents.device).to(latents.dtype)
wav = torchaudio.functional.resample(wav, 44100, sr)
# wav = resampler(wav)
else:
sr = 44100
pred_wavs.append(wav)
@@ -91,6 +101,7 @@ class MusicDCAE(torch.nn.Module):
if audio_lengths is not None:
pred_wavs = [wav[:, :length].cpu() for wav, length in zip(pred_wavs, audio_lengths)]
return torch.stack(pred_wavs)
# return sr, pred_wavs
def forward(self, audios, audio_lengths=None, sr=None):
latents, latent_lengths = self.encode(audios=audios, audio_lengths=audio_lengths, sr=sr)

View File

@@ -37,10 +37,7 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
def apply_rope1(x: Tensor, freqs_cis: Tensor):
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
x_out = freqs_cis[..., 0] * x_[..., 0]
x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
x_out = freqs_cis[..., 0] * x_[..., 0] + freqs_cis[..., 1] * x_[..., 1]
return x_out.reshape(*x.shape).type_as(x)
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):

View File

@@ -1,7 +1,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d, Normalize
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d
import comfy.ops
import comfy.ldm.models.autoencoder
ops = comfy.ops.disable_weight_init
@@ -17,12 +17,11 @@ class RMS_norm(nn.Module):
return F.normalize(x, dim=1) * self.scale * self.gamma
class DnSmpl(nn.Module):
def __init__(self, ic, oc, tds=True, refiner_vae=True, op=VideoConv3d):
def __init__(self, ic, oc, tds=True):
super().__init__()
fct = 2 * 2 * 2 if tds else 1 * 2 * 2
assert oc % fct == 0
self.conv = op(ic, oc // fct, kernel_size=3, stride=1, padding=1)
self.refiner_vae = refiner_vae
self.conv = VideoConv3d(ic, oc // fct, kernel_size=3)
self.tds = tds
self.gs = fct * ic // oc
@@ -31,7 +30,7 @@ class DnSmpl(nn.Module):
r1 = 2 if self.tds else 1
h = self.conv(x)
if self.tds and self.refiner_vae:
if self.tds:
hf = h[:, :, :1, :, :]
b, c, f, ht, wd = hf.shape
hf = hf.reshape(b, c, f, ht // 2, 2, wd // 2, 2)
@@ -67,7 +66,6 @@ class DnSmpl(nn.Module):
sc = torch.cat([xf, xn], dim=2)
else:
b, c, frms, ht, wd = h.shape
nf = frms // r1
h = h.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2)
h = h.permute(0, 3, 5, 7, 1, 2, 4, 6)
@@ -85,11 +83,10 @@ class DnSmpl(nn.Module):
class UpSmpl(nn.Module):
def __init__(self, ic, oc, tus=True, refiner_vae=True, op=VideoConv3d):
def __init__(self, ic, oc, tus=True):
super().__init__()
fct = 2 * 2 * 2 if tus else 1 * 2 * 2
self.conv = op(ic, oc * fct, kernel_size=3, stride=1, padding=1)
self.refiner_vae = refiner_vae
self.conv = VideoConv3d(ic, oc * fct, kernel_size=3)
self.tus = tus
self.rp = fct * oc // ic
@@ -98,7 +95,7 @@ class UpSmpl(nn.Module):
r1 = 2 if self.tus else 1
h = self.conv(x)
if self.tus and self.refiner_vae:
if self.tus:
hf = h[:, :, :1, :, :]
b, c, f, ht, wd = hf.shape
nc = c // (2 * 2)
@@ -151,56 +148,43 @@ class UpSmpl(nn.Module):
class Encoder(nn.Module):
def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks,
ffactor_spatial, ffactor_temporal, downsample_match_channel=True, refiner_vae=True, **_):
ffactor_spatial, ffactor_temporal, downsample_match_channel=True, **_):
super().__init__()
self.z_channels = z_channels
self.block_out_channels = block_out_channels
self.num_res_blocks = num_res_blocks
self.ffactor_temporal = ffactor_temporal
self.refiner_vae = refiner_vae
if self.refiner_vae:
conv_op = VideoConv3d
norm_op = RMS_norm
else:
conv_op = ops.Conv3d
norm_op = Normalize
self.conv_in = conv_op(in_channels, block_out_channels[0], 3, 1, 1)
self.conv_in = VideoConv3d(in_channels, block_out_channels[0], 3, 1, 1)
self.down = nn.ModuleList()
ch = block_out_channels[0]
depth = (ffactor_spatial >> 1).bit_length()
depth_temporal = ((ffactor_spatial // self.ffactor_temporal) >> 1).bit_length()
depth_temporal = ((ffactor_spatial // ffactor_temporal) >> 1).bit_length()
for i, tgt in enumerate(block_out_channels):
stage = nn.Module()
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
out_channels=tgt,
temb_channels=0,
conv_op=conv_op, norm_op=norm_op)
conv_op=VideoConv3d, norm_op=RMS_norm)
for j in range(num_res_blocks)])
ch = tgt
if i < depth:
nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and downsample_match_channel else ch
stage.downsample = DnSmpl(ch, nxt, tds=i >= depth_temporal, refiner_vae=self.refiner_vae, op=conv_op)
stage.downsample = DnSmpl(ch, nxt, tds=i >= depth_temporal)
ch = nxt
self.down.append(stage)
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=RMS_norm)
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
self.norm_out = norm_op(ch)
self.conv_out = conv_op(ch, z_channels << 1, 3, 1, 1)
self.norm_out = RMS_norm(ch)
self.conv_out = VideoConv3d(ch, z_channels << 1, 3, 1, 1)
self.regul = comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer()
def forward(self, x):
if not self.refiner_vae and x.shape[2] == 1:
x = x.expand(-1, -1, self.ffactor_temporal, -1, -1)
x = self.conv_in(x)
for stage in self.down:
@@ -216,42 +200,31 @@ class Encoder(nn.Module):
skip = x.view(b, c // grp, grp, t, h, w).mean(2)
out = self.conv_out(F.silu(self.norm_out(x))) + skip
out = self.regul(out)[0]
if self.refiner_vae:
out = self.regul(out)[0]
out = torch.cat((out[:, :, :1], out), dim=2)
out = out.permute(0, 2, 1, 3, 4)
b, f_times_2, c, h, w = out.shape
out = out.reshape(b, f_times_2 // 2, 2 * c, h, w)
out = out.permute(0, 2, 1, 3, 4).contiguous()
out = torch.cat((out[:, :, :1], out), dim=2)
out = out.permute(0, 2, 1, 3, 4)
b, f_times_2, c, h, w = out.shape
out = out.reshape(b, f_times_2 // 2, 2 * c, h, w)
out = out.permute(0, 2, 1, 3, 4).contiguous()
return out
class Decoder(nn.Module):
def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks,
ffactor_spatial, ffactor_temporal, upsample_match_channel=True, refiner_vae=True, **_):
ffactor_spatial, ffactor_temporal, upsample_match_channel=True, **_):
super().__init__()
block_out_channels = block_out_channels[::-1]
self.z_channels = z_channels
self.block_out_channels = block_out_channels
self.num_res_blocks = num_res_blocks
self.refiner_vae = refiner_vae
if self.refiner_vae:
conv_op = VideoConv3d
norm_op = RMS_norm
else:
conv_op = ops.Conv3d
norm_op = Normalize
ch = block_out_channels[0]
self.conv_in = conv_op(z_channels, ch, kernel_size=3, stride=1, padding=1)
self.conv_in = VideoConv3d(z_channels, ch, 3)
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=RMS_norm)
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
self.up = nn.ModuleList()
depth = (ffactor_spatial >> 1).bit_length()
@@ -262,26 +235,25 @@ class Decoder(nn.Module):
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
out_channels=tgt,
temb_channels=0,
conv_op=conv_op, norm_op=norm_op)
conv_op=VideoConv3d, norm_op=RMS_norm)
for j in range(num_res_blocks + 1)])
ch = tgt
if i < depth:
nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and upsample_match_channel else ch
stage.upsample = UpSmpl(ch, nxt, tus=i < depth_temporal, refiner_vae=self.refiner_vae, op=conv_op)
stage.upsample = UpSmpl(ch, nxt, tus=i < depth_temporal)
ch = nxt
self.up.append(stage)
self.norm_out = norm_op(ch)
self.conv_out = conv_op(ch, out_channels, 3, stride=1, padding=1)
self.norm_out = RMS_norm(ch)
self.conv_out = VideoConv3d(ch, out_channels, 3)
def forward(self, z):
if self.refiner_vae:
z = z.permute(0, 2, 1, 3, 4)
b, f, c, h, w = z.shape
z = z.reshape(b, f, 2, c // 2, h, w)
z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w)
z = z.permute(0, 2, 1, 3, 4)
z = z[:, :, 1:]
z = z.permute(0, 2, 1, 3, 4)
b, f, c, h, w = z.shape
z = z.reshape(b, f, 2, c // 2, h, w)
z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w)
z = z.permute(0, 2, 1, 3, 4)
z = z[:, :, 1:]
x = self.conv_in(z) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
@@ -292,10 +264,4 @@ class Decoder(nn.Module):
if hasattr(stage, 'upsample'):
x = stage.upsample(x)
out = self.conv_out(F.silu(self.norm_out(x)))
if not self.refiner_vae:
if z.shape[-3] == 1:
out = out[:, :, -1:]
return out
return self.conv_out(F.silu(self.norm_out(x)))

View File

@@ -237,7 +237,6 @@ class WanAttentionBlock(nn.Module):
freqs, transformer_options=transformer_options)
x = torch.addcmul(x, y, repeat_e(e[2], x))
del y
# cross-attention & ffn
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)
@@ -903,7 +902,7 @@ class MotionEncoder_tc(nn.Module):
def __init__(self,
in_dim: int,
hidden_dim: int,
num_heads: int,
num_heads=int,
need_global=True,
dtype=None,
device=None,

View File

@@ -468,46 +468,55 @@ class WanVAE(nn.Module):
attn_scales, self.temperal_upsample, dropout)
def encode(self, x):
conv_idx = [0]
feat_map = [None] * count_conv3d(self.decoder)
self.clear_cache()
## cache
t = x.shape[2]
iter_ = 1 + (t - 1) // 4
## 对encode输入的x按时间拆分为1、4、4、4....
for i in range(iter_):
conv_idx = [0]
self._enc_conv_idx = [0]
if i == 0:
out = self.encoder(
x[:, :, :1, :, :],
feat_cache=feat_map,
feat_idx=conv_idx)
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx)
else:
out_ = self.encoder(
x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
feat_cache=feat_map,
feat_idx=conv_idx)
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx)
out = torch.cat([out, out_], 2)
mu, log_var = self.conv1(out).chunk(2, dim=1)
self.clear_cache()
return mu
def decode(self, z):
conv_idx = [0]
feat_map = [None] * count_conv3d(self.decoder)
self.clear_cache()
# z: [b,c,t,h,w]
iter_ = z.shape[2]
x = self.conv2(z)
for i in range(iter_):
conv_idx = [0]
self._conv_idx = [0]
if i == 0:
out = self.decoder(
x[:, :, i:i + 1, :, :],
feat_cache=feat_map,
feat_idx=conv_idx)
feat_cache=self._feat_map,
feat_idx=self._conv_idx)
else:
out_ = self.decoder(
x[:, :, i:i + 1, :, :],
feat_cache=feat_map,
feat_idx=conv_idx)
feat_cache=self._feat_map,
feat_idx=self._conv_idx)
out = torch.cat([out, out_], 2)
self.clear_cache()
return out
def clear_cache(self):
self._conv_num = count_conv3d(self.decoder)
self._conv_idx = [0]
self._feat_map = [None] * self._conv_num
#cache encode
self._enc_conv_num = count_conv3d(self.encoder)
self._enc_conv_idx = [0]
self._enc_feat_map = [None] * self._enc_conv_num

View File

@@ -332,51 +332,35 @@ class VAE:
self.first_stage_model = StageC_coder()
self.downscale_ratio = 32
self.latent_channels = 16
elif "decoder.conv_in.weight" in sd and sd['decoder.conv_in.weight'].shape[1] == 64:
ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
self.downscale_ratio = 32
self.upscale_ratio = 32
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
encoder_config={'target': "comfy.ldm.hunyuan_video.vae.Encoder", 'params': ddconfig},
decoder_config={'target': "comfy.ldm.hunyuan_video.vae.Decoder", 'params': ddconfig})
self.memory_used_encode = lambda shape, dtype: (700 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (700 * shape[2] * shape[3] * 32 * 32) * model_management.dtype_size(dtype)
elif "decoder.conv_in.weight" in sd:
if sd['decoder.conv_in.weight'].shape[1] == 64:
ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
self.downscale_ratio = 32
self.upscale_ratio = 32
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
encoder_config={'target': "comfy.ldm.hunyuan_video.vae.Encoder", 'params': ddconfig},
decoder_config={'target': "comfy.ldm.hunyuan_video.vae.Decoder", 'params': ddconfig})
#default SD1.x/SD2.x VAE parameters
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
self.memory_used_encode = lambda shape, dtype: (700 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (700 * shape[2] * shape[3] * 32 * 32) * model_management.dtype_size(dtype)
elif sd['decoder.conv_in.weight'].shape[1] == 32:
ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True, "refiner_vae": False}
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
self.upscale_index_formula = (4, 16, 16)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16)
self.downscale_index_formula = (4, 16, 16)
self.latent_dim = 3
self.not_video = True
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
encoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Encoder", 'params': ddconfig},
decoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Decoder", 'params': ddconfig})
if 'encoder.down.2.downsample.conv.weight' not in sd and 'decoder.up.3.upsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE
ddconfig['ch_mult'] = [1, 2, 4]
self.downscale_ratio = 4
self.upscale_ratio = 4
self.memory_used_encode = lambda shape, dtype: (2800 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (2800 * shape[-3] * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
if 'post_quant_conv.weight' in sd:
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
else:
#default SD1.x/SD2.x VAE parameters
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
if 'encoder.down.2.downsample.conv.weight' not in sd and 'decoder.up.3.upsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE
ddconfig['ch_mult'] = [1, 2, 4]
self.downscale_ratio = 4
self.upscale_ratio = 4
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
if 'post_quant_conv.weight' in sd:
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
else:
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig},
decoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Decoder", 'params': ddconfig})
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig},
decoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Decoder", 'params': ddconfig})
elif "decoder.layers.1.layers.0.beta" in sd:
self.first_stage_model = AudioOobleckVAE()
self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype)
@@ -652,7 +636,6 @@ class VAE:
def decode(self, samples_in, vae_options={}):
self.throw_exception_if_invalid()
pixel_samples = None
do_tile = False
try:
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
@@ -668,13 +651,6 @@ class VAE:
pixel_samples[x:x+batch_number] = out
except model_management.OOM_EXCEPTION:
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
#exception and the exception itself refs them all until we get out of this except block.
#So we just set a flag for tiler fallback so that tensor gc can happen once the
#exception is fully off the books.
do_tile = True
if do_tile:
dims = samples_in.ndim - 2
if dims == 1 or self.extra_1d_channel is not None:
pixel_samples = self.decode_tiled_1d(samples_in)
@@ -721,7 +697,6 @@ class VAE:
self.throw_exception_if_invalid()
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
pixel_samples = pixel_samples.movedim(-1, 1)
do_tile = False
if self.latent_dim == 3 and pixel_samples.ndim < 5:
if not self.not_video:
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
@@ -743,13 +718,6 @@ class VAE:
except model_management.OOM_EXCEPTION:
logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
#exception and the exception itself refs them all until we get out of this except block.
#So we just set a flag for tiler fallback so that tensor gc can happen once the
#exception is fully off the books.
do_tile = True
if do_tile:
if self.latent_dim == 3:
tile = 256
overlap = tile // 4

View File

@@ -50,10 +50,16 @@ if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in
else:
logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.")
def is_html_file(file_path):
with open(file_path, "rb") as f:
content = f.read(100)
return b"<!DOCTYPE html>" in content or b"<html" in content
def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
if device is None:
device = torch.device("cpu")
metadata = None
if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
try:
with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
@@ -66,6 +72,8 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
if return_metadata:
metadata = f.metadata()
except Exception as e:
if is_html_file(ckpt):
raise ValueError("{}\n\nFile path: {}\n\nThe requested file is an HTML document not a safetensors file. Please re-download the file, not the web page.".format(e, ckpt))
if len(e.args) > 0:
message = e.args[0]
if "HeaderTooLarge" in message:
@@ -93,6 +101,8 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
sd = pl_sd
else:
sd = pl_sd
# populate_db_with_asset(ckpt) # surprise tool that can help us later - performs hashing on model file
return (sd, metadata) if return_metadata else sd
def save_torch_file(sd, ckpt, metadata=None):

View File

@@ -392,6 +392,20 @@ class MultiCombo(ComfyTypeI):
})
return to_return
@comfytype(io_type="ASSET")
class Asset(ComfyTypeI):
class Input(WidgetInput):
def __init__(self, id: str, query_tags: list[str], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: str=None, socketless: bool=None):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless)
self.query_tags = query_tags
def as_dict(self):
to_return = super().as_dict() | prune_dict({
"query_tags": self.query_tags
})
return to_return
@comfytype(io_type="IMAGE")
class Image(ComfyTypeIO):
Type = torch.Tensor
@@ -1605,7 +1619,6 @@ class _IO:
Model = Model
ClipVision = ClipVision
ClipVisionOutput = ClipVisionOutput
AudioEncoder = AudioEncoder
AudioEncoderOutput = AudioEncoderOutput
StyleModel = StyleModel
Gligen = Gligen

View File

@@ -2,7 +2,6 @@
# filename: filtered-openapi.yaml
# timestamp: 2025-07-30T08:54:00+00:00
# pylint: disable
from __future__ import annotations
from datetime import date, datetime
@@ -1321,7 +1320,6 @@ class KlingTextToVideoModelName(str, Enum):
kling_v1 = 'kling-v1'
kling_v1_6 = 'kling-v1-6'
kling_v2_1_master = 'kling-v2-1-master'
kling_v2_5_turbo = 'kling-v2-5-turbo'
class KlingVideoGenAspectRatio(str, Enum):
@@ -1356,7 +1354,6 @@ class KlingVideoGenModelName(str, Enum):
kling_v2_master = 'kling-v2-master'
kling_v2_1 = 'kling-v2-1'
kling_v2_1_master = 'kling-v2-1-master'
kling_v2_5_turbo = 'kling-v2-5-turbo'
class KlingVideoResult(BaseModel):

View File

@@ -95,7 +95,6 @@ import aiohttp
import asyncio
import logging
import io
import os
import socket
from aiohttp.client_exceptions import ClientError, ClientResponseError
from typing import Dict, Type, Optional, Any, TypeVar, Generic, Callable, Tuple
@@ -500,9 +499,7 @@ class ApiClient:
else:
raise ValueError("File must be BytesIO or str path")
parsed = urlparse(upload_url)
basename = os.path.basename(parsed.path) or parsed.netloc or "upload"
operation_id = f"upload_{basename}_{uuid.uuid4().hex[:8]}"
operation_id = f"upload_{upload_url.split('/')[-1]}_{uuid.uuid4().hex[:8]}"
request_logger.log_request_response(
operation_id=operation_id,
request_method="PUT",
@@ -535,7 +532,7 @@ class ApiClient:
request_method="PUT",
request_url=upload_url,
response_status_code=e.status if hasattr(e, "status") else None,
response_headers=dict(e.headers) if hasattr(e, "headers") else None,
response_headers=dict(e.headers) if getattr(e, "headers") else None,
response_content=None,
error_message=f"{type(e).__name__}: {str(e)}",
)

View File

@@ -4,18 +4,16 @@ import os
import datetime
import json
import logging
import re
import hashlib
from typing import Any
import folder_paths
# Get the logger instance
logger = logging.getLogger(__name__)
def get_log_directory():
"""Ensures the API log directory exists within ComfyUI's temp directory and returns its path."""
"""
Ensures the API log directory exists within ComfyUI's temp directory
and returns its path.
"""
base_temp_dir = folder_paths.get_temp_directory()
log_dir = os.path.join(base_temp_dir, "api_logs")
try:
@@ -26,77 +24,42 @@ def get_log_directory():
return base_temp_dir
return log_dir
def _sanitize_filename_component(name: str) -> str:
if not name:
return "log"
sanitized = re.sub(r"[^A-Za-z0-9._-]+", "_", name) # Replace disallowed characters with underscore
sanitized = sanitized.strip(" ._") # Windows: trailing dots or spaces are not allowed
if not sanitized:
sanitized = "log"
return sanitized
def _short_hash(*parts: str, length: int = 10) -> str:
return hashlib.sha1(("|".join(parts)).encode("utf-8")).hexdigest()[:length]
def _build_log_filepath(log_dir: str, operation_id: str, request_url: str) -> str:
"""Build log filepath. We keep it well under common path length limits aiming for <= 240 characters total."""
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
slug = _sanitize_filename_component(operation_id) # Best-effort human-readable slug from operation_id
h = _short_hash(operation_id or "", request_url or "") # Short hash ties log to the full operation and URL
# Compute how much room we have for the slug given the directory length
# Keep total path length reasonably below ~260 on Windows.
max_total_path = 240
prefix = f"{timestamp}_"
suffix = f"_{h}.log"
if not slug:
slug = "op"
max_filename_len = max(60, max_total_path - len(log_dir) - 1)
max_slug_len = max(8, max_filename_len - len(prefix) - len(suffix))
if len(slug) > max_slug_len:
slug = slug[:max_slug_len].rstrip(" ._-")
return os.path.join(log_dir, f"{prefix}{slug}{suffix}")
def _format_data_for_logging(data: Any) -> str:
def _format_data_for_logging(data):
"""Helper to format data (dict, str, bytes) for logging."""
if isinstance(data, bytes):
try:
return data.decode("utf-8") # Try to decode as text
return data.decode('utf-8') # Try to decode as text
except UnicodeDecodeError:
return f"[Binary data of length {len(data)} bytes]"
elif isinstance(data, (dict, list)):
try:
return json.dumps(data, indent=2, ensure_ascii=False)
except TypeError:
return str(data) # Fallback for non-serializable objects
return str(data) # Fallback for non-serializable objects
return str(data)
def log_request_response(
operation_id: str,
request_method: str,
request_url: str,
request_headers: dict | None = None,
request_params: dict | None = None,
request_data: Any = None,
request_data: any = None,
response_status_code: int | None = None,
response_headers: dict | None = None,
response_content: Any = None,
error_message: str | None = None,
response_content: any = None,
error_message: str | None = None
):
"""
Logs API request and response details to a file in the temp/api_logs directory.
Filenames are sanitized and length-limited for cross-platform safety.
If we still fail to write, we fall back to appending into api.log.
"""
log_dir = get_log_directory()
filepath = _build_log_filepath(log_dir, operation_id, request_url)
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
filename = f"{timestamp}_{operation_id.replace('/', '_').replace(':', '_')}.log"
filepath = os.path.join(log_dir, filename)
log_content = []
log_content: list[str] = []
log_content.append(f"Timestamp: {datetime.datetime.now().isoformat()}")
log_content.append(f"Operation ID: {operation_id}")
log_content.append("-" * 30 + " REQUEST " + "-" * 30)
@@ -106,7 +69,7 @@ def log_request_response(
log_content.append(f"Headers:\n{_format_data_for_logging(request_headers)}")
if request_params:
log_content.append(f"Params:\n{_format_data_for_logging(request_params)}")
if request_data is not None:
if request_data:
log_content.append(f"Data/Body:\n{_format_data_for_logging(request_data)}")
log_content.append("\n" + "-" * 30 + " RESPONSE " + "-" * 30)
@@ -114,7 +77,7 @@ def log_request_response(
log_content.append(f"Status Code: {response_status_code}")
if response_headers:
log_content.append(f"Headers:\n{_format_data_for_logging(response_headers)}")
if response_content is not None:
if response_content:
log_content.append(f"Content:\n{_format_data_for_logging(response_content)}")
if error_message:
log_content.append(f"Error:\n{error_message}")
@@ -126,7 +89,6 @@ def log_request_response(
except Exception as e:
logger.error(f"Error writing API log to {filepath}: {e}")
if __name__ == '__main__':
# Example usage (for testing the logger directly)
logger.setLevel(logging.DEBUG)

View File

@@ -52,3 +52,7 @@ class RodinResourceItem(BaseModel):
class Rodin3DDownloadResponse(BaseModel):
list: List[RodinResourceItem] = Field(..., description="Source List")

File diff suppressed because it is too large Load Diff

View File

@@ -920,7 +920,7 @@ class ByteDanceFirstLastFrameNode(comfy_io.ComfyNode):
inputs=[
comfy_io.Combo.Input(
"model",
options=[model.value for model in Image2VideoModelName],
options=[Image2VideoModelName.seedance_1_lite.value],
default=Image2VideoModelName.seedance_1_lite.value,
tooltip="Model name",
),

View File

@@ -39,7 +39,6 @@ from comfy_api_nodes.apinode_utils import (
tensor_to_base64_string,
bytesio_to_image_tensor,
)
from comfy_api.util import VideoContainer, VideoCodec
GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini"
@@ -311,7 +310,7 @@ class GeminiNode(ComfyNodeABC):
Returns:
List of GeminiPart objects containing the encoded video.
"""
from comfy_api.util import VideoContainer, VideoCodec
base_64_string = video_to_base64_string(
video_input,
container_format=VideoContainer.MP4,
@@ -491,6 +490,7 @@ class GeminiInputFiles(ComfyNodeABC):
# Use base64 string directly, not the data URI
with open(file_path, "rb") as f:
file_content = f.read()
import base64
base64_str = base64.b64encode(file_content).decode("utf-8")
return GeminiPart(

View File

@@ -423,8 +423,6 @@ class KlingTextToVideoNode(KlingNodeBase):
"standard mode / 10s duration / kling-v2-master": ("std", "10", "kling-v2-master"),
"pro mode / 5s duration / kling-v2-1-master": ("pro", "5", "kling-v2-1-master"),
"pro mode / 10s duration / kling-v2-1-master": ("pro", "10", "kling-v2-1-master"),
"pro mode / 5s duration / kling-v2-5-turbo": ("pro", "5", "kling-v2-5-turbo"),
"pro mode / 10s duration / kling-v2-5-turbo": ("pro", "10", "kling-v2-5-turbo"),
}
@classmethod
@@ -712,9 +710,6 @@ class KlingImage2VideoNode(KlingNodeBase):
# Camera control type for image 2 video is always `simple`
camera_control.type = KlingCameraControlType.simple
if mode == "std" and model_name == KlingVideoGenModelName.kling_v2_5_turbo.value:
mode = "pro" # October 5: currently "std" mode is not supported for this model
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_IMAGE_TO_VIDEO,

View File

@@ -1,8 +1,7 @@
from __future__ import annotations
from inspect import cleandoc
from typing import Optional
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io as comfy_io
from comfy.comfy_types.node_typing import IO, ComfyNodeABC
from comfy_api.input_impl.video_types import VideoFromFile
from comfy_api_nodes.apis.luma_api import (
LumaImageModel,
@@ -52,186 +51,174 @@ def image_result_url_extractor(response: LumaGeneration):
def video_result_url_extractor(response: LumaGeneration):
return response.assets.video if hasattr(response, "assets") and hasattr(response.assets, "video") else None
class LumaReferenceNode(comfy_io.ComfyNode):
class LumaReferenceNode(ComfyNodeABC):
"""
Holds an image and weight for use with Luma Generate Image node.
"""
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
node_id="LumaReferenceNode",
display_name="Luma Reference",
category="api node/image/Luma",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.Image.Input(
"image",
tooltip="Image to use as reference.",
),
comfy_io.Float.Input(
"weight",
default=1.0,
min=0.0,
max=1.0,
step=0.01,
tooltip="Weight of image reference.",
),
comfy_io.Custom(LumaIO.LUMA_REF).Input(
"luma_ref",
optional=True,
),
],
outputs=[comfy_io.Custom(LumaIO.LUMA_REF).Output(display_name="luma_ref")],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
)
RETURN_TYPES = (LumaIO.LUMA_REF,)
RETURN_NAMES = ("luma_ref",)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "create_luma_reference"
CATEGORY = "api node/image/Luma"
@classmethod
def execute(
cls, image: torch.Tensor, weight: float, luma_ref: LumaReferenceChain = None
) -> comfy_io.NodeOutput:
def INPUT_TYPES(s):
return {
"required": {
"image": (
IO.IMAGE,
{
"tooltip": "Image to use as reference.",
},
),
"weight": (
IO.FLOAT,
{
"default": 1.0,
"min": 0.0,
"max": 1.0,
"step": 0.01,
"tooltip": "Weight of image reference.",
},
),
},
"optional": {"luma_ref": (LumaIO.LUMA_REF,)},
}
def create_luma_reference(
self, image: torch.Tensor, weight: float, luma_ref: LumaReferenceChain = None
):
if luma_ref is not None:
luma_ref = luma_ref.clone()
else:
luma_ref = LumaReferenceChain()
luma_ref.add(LumaReference(image=image, weight=round(weight, 2)))
return comfy_io.NodeOutput(luma_ref)
return (luma_ref,)
class LumaConceptsNode(comfy_io.ComfyNode):
class LumaConceptsNode(ComfyNodeABC):
"""
Holds one or more Camera Concepts for use with Luma Text to Video and Luma Image to Video nodes.
"""
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
node_id="LumaConceptsNode",
display_name="Luma Concepts",
category="api node/video/Luma",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.Combo.Input(
"concept1",
options=get_luma_concepts(include_none=True),
),
comfy_io.Combo.Input(
"concept2",
options=get_luma_concepts(include_none=True),
),
comfy_io.Combo.Input(
"concept3",
options=get_luma_concepts(include_none=True),
),
comfy_io.Combo.Input(
"concept4",
options=get_luma_concepts(include_none=True),
),
comfy_io.Custom(LumaIO.LUMA_CONCEPTS).Input(
"luma_concepts",
tooltip="Optional Camera Concepts to add to the ones chosen here.",
optional=True,
),
],
outputs=[comfy_io.Custom(LumaIO.LUMA_CONCEPTS).Output(display_name="luma_concepts")],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
)
RETURN_TYPES = (LumaIO.LUMA_CONCEPTS,)
RETURN_NAMES = ("luma_concepts",)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "create_concepts"
CATEGORY = "api node/video/Luma"
@classmethod
def execute(
cls,
def INPUT_TYPES(s):
return {
"required": {
"concept1": (get_luma_concepts(include_none=True),),
"concept2": (get_luma_concepts(include_none=True),),
"concept3": (get_luma_concepts(include_none=True),),
"concept4": (get_luma_concepts(include_none=True),),
},
"optional": {
"luma_concepts": (
LumaIO.LUMA_CONCEPTS,
{
"tooltip": "Optional Camera Concepts to add to the ones chosen here."
},
),
},
}
def create_concepts(
self,
concept1: str,
concept2: str,
concept3: str,
concept4: str,
luma_concepts: LumaConceptChain = None,
) -> comfy_io.NodeOutput:
):
chain = LumaConceptChain(str_list=[concept1, concept2, concept3, concept4])
if luma_concepts is not None:
chain = luma_concepts.clone_and_merge(chain)
return comfy_io.NodeOutput(chain)
return (chain,)
class LumaImageGenerationNode(comfy_io.ComfyNode):
class LumaImageGenerationNode(ComfyNodeABC):
"""
Generates images synchronously based on prompt and aspect ratio.
"""
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
node_id="LumaImageNode",
display_name="Luma Text to Image",
category="api node/image/Luma",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt for the image generation",
),
comfy_io.Combo.Input(
"model",
options=[model.value for model in LumaImageModel],
),
comfy_io.Combo.Input(
"aspect_ratio",
options=[ratio.value for ratio in LumaAspectRatio],
default=LumaAspectRatio.ratio_16_9,
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=0xFFFFFFFFFFFFFFFF,
control_after_generate=True,
tooltip="Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
),
comfy_io.Float.Input(
"style_image_weight",
default=1.0,
min=0.0,
max=1.0,
step=0.01,
tooltip="Weight of style image. Ignored if no style_image provided.",
),
comfy_io.Custom(LumaIO.LUMA_REF).Input(
"image_luma_ref",
tooltip="Luma Reference node connection to influence generation with input images; up to 4 images can be considered.",
optional=True,
),
comfy_io.Image.Input(
"style_image",
tooltip="Style reference image; only 1 image will be used.",
optional=True,
),
comfy_io.Image.Input(
"character_image",
tooltip="Character reference images; can be a batch of multiple, up to 4 images can be considered.",
optional=True,
),
],
outputs=[comfy_io.Image.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
RETURN_TYPES = (IO.IMAGE,)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "api_call"
API_NODE = True
CATEGORY = "api node/image/Luma"
@classmethod
async def execute(
cls,
def INPUT_TYPES(s):
return {
"required": {
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Prompt for the image generation",
},
),
"model": ([model.value for model in LumaImageModel],),
"aspect_ratio": (
[ratio.value for ratio in LumaAspectRatio],
{
"default": LumaAspectRatio.ratio_16_9,
},
),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 0xFFFFFFFFFFFFFFFF,
"control_after_generate": True,
"tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
},
),
"style_image_weight": (
IO.FLOAT,
{
"default": 1.0,
"min": 0.0,
"max": 1.0,
"step": 0.01,
"tooltip": "Weight of style image. Ignored if no style_image provided.",
},
),
},
"optional": {
"image_luma_ref": (
LumaIO.LUMA_REF,
{
"tooltip": "Luma Reference node connection to influence generation with input images; up to 4 images can be considered."
},
),
"style_image": (
IO.IMAGE,
{"tooltip": "Style reference image; only 1 image will be used."},
),
"character_image": (
IO.IMAGE,
{
"tooltip": "Character reference images; can be a batch of multiple, up to 4 images can be considered."
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
async def api_call(
self,
prompt: str,
model: str,
aspect_ratio: str,
@@ -240,29 +227,27 @@ class LumaImageGenerationNode(comfy_io.ComfyNode):
image_luma_ref: LumaReferenceChain = None,
style_image: torch.Tensor = None,
character_image: torch.Tensor = None,
) -> comfy_io.NodeOutput:
unique_id: str = None,
**kwargs,
):
validate_string(prompt, strip_whitespace=True, min_length=3)
auth_kwargs = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
# handle image_luma_ref
api_image_ref = None
if image_luma_ref is not None:
api_image_ref = await cls._convert_luma_refs(
image_luma_ref, max_refs=4, auth_kwargs=auth_kwargs,
api_image_ref = await self._convert_luma_refs(
image_luma_ref, max_refs=4, auth_kwargs=kwargs,
)
# handle style_luma_ref
api_style_ref = None
if style_image is not None:
api_style_ref = await cls._convert_style_image(
style_image, weight=style_image_weight, auth_kwargs=auth_kwargs,
api_style_ref = await self._convert_style_image(
style_image, weight=style_image_weight, auth_kwargs=kwargs,
)
# handle character_ref images
character_ref = None
if character_image is not None:
download_urls = await upload_images_to_comfyapi(
character_image, max_images=4, auth_kwargs=auth_kwargs,
character_image, max_images=4, auth_kwargs=kwargs,
)
character_ref = LumaCharacterRef(
identity0=LumaImageIdentity(images=download_urls)
@@ -283,7 +268,7 @@ class LumaImageGenerationNode(comfy_io.ComfyNode):
style_ref=api_style_ref,
character_ref=character_ref,
),
auth_kwargs=auth_kwargs,
auth_kwargs=kwargs,
)
response_api: LumaGeneration = await operation.execute()
@@ -298,19 +283,18 @@ class LumaImageGenerationNode(comfy_io.ComfyNode):
failed_statuses=[LumaState.failed],
status_extractor=lambda x: x.state,
result_url_extractor=image_result_url_extractor,
node_id=cls.hidden.unique_id,
auth_kwargs=auth_kwargs,
node_id=unique_id,
auth_kwargs=kwargs,
)
response_poll = await operation.execute()
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.assets.image) as img_response:
img = process_image_response(await img_response.content.read())
return comfy_io.NodeOutput(img)
return (img,)
@classmethod
async def _convert_luma_refs(
cls, luma_ref: LumaReferenceChain, max_refs: int, auth_kwargs: Optional[dict[str,str]] = None
self, luma_ref: LumaReferenceChain, max_refs: int, auth_kwargs: Optional[dict[str,str]] = None
):
luma_urls = []
ref_count = 0
@@ -324,84 +308,82 @@ class LumaImageGenerationNode(comfy_io.ComfyNode):
break
return luma_ref.create_api_model(download_urls=luma_urls, max_refs=max_refs)
@classmethod
async def _convert_style_image(
cls, style_image: torch.Tensor, weight: float, auth_kwargs: Optional[dict[str,str]] = None
self, style_image: torch.Tensor, weight: float, auth_kwargs: Optional[dict[str,str]] = None
):
chain = LumaReferenceChain(
first_ref=LumaReference(image=style_image, weight=weight)
)
return await cls._convert_luma_refs(chain, max_refs=1, auth_kwargs=auth_kwargs)
return await self._convert_luma_refs(chain, max_refs=1, auth_kwargs=auth_kwargs)
class LumaImageModifyNode(comfy_io.ComfyNode):
class LumaImageModifyNode(ComfyNodeABC):
"""
Modifies images synchronously based on prompt and aspect ratio.
"""
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
node_id="LumaImageModifyNode",
display_name="Luma Image to Image",
category="api node/image/Luma",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.Image.Input(
"image",
),
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt for the image generation",
),
comfy_io.Float.Input(
"image_weight",
default=0.1,
min=0.0,
max=0.98,
step=0.01,
tooltip="Weight of the image; the closer to 1.0, the less the image will be modified.",
),
comfy_io.Combo.Input(
"model",
options=[model.value for model in LumaImageModel],
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=0xFFFFFFFFFFFFFFFF,
control_after_generate=True,
tooltip="Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
),
],
outputs=[comfy_io.Image.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
RETURN_TYPES = (IO.IMAGE,)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "api_call"
API_NODE = True
CATEGORY = "api node/image/Luma"
@classmethod
async def execute(
cls,
def INPUT_TYPES(s):
return {
"required": {
"image": (IO.IMAGE,),
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Prompt for the image generation",
},
),
"image_weight": (
IO.FLOAT,
{
"default": 0.1,
"min": 0.0,
"max": 0.98,
"step": 0.01,
"tooltip": "Weight of the image; the closer to 1.0, the less the image will be modified.",
},
),
"model": ([model.value for model in LumaImageModel],),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 0xFFFFFFFFFFFFFFFF,
"control_after_generate": True,
"tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
},
),
},
"optional": {},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
async def api_call(
self,
prompt: str,
model: str,
image: torch.Tensor,
image_weight: float,
seed,
) -> comfy_io.NodeOutput:
auth_kwargs = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
unique_id: str = None,
**kwargs,
):
# first, upload image
download_urls = await upload_images_to_comfyapi(
image, max_images=1, auth_kwargs=auth_kwargs,
image, max_images=1, auth_kwargs=kwargs,
)
image_url = download_urls[0]
# next, make Luma call with download url provided
@@ -419,7 +401,7 @@ class LumaImageModifyNode(comfy_io.ComfyNode):
url=image_url, weight=round(max(min(1.0-image_weight, 0.98), 0.0), 2)
),
),
auth_kwargs=auth_kwargs,
auth_kwargs=kwargs,
)
response_api: LumaGeneration = await operation.execute()
@@ -434,84 +416,88 @@ class LumaImageModifyNode(comfy_io.ComfyNode):
failed_statuses=[LumaState.failed],
status_extractor=lambda x: x.state,
result_url_extractor=image_result_url_extractor,
node_id=cls.hidden.unique_id,
auth_kwargs=auth_kwargs,
node_id=unique_id,
auth_kwargs=kwargs,
)
response_poll = await operation.execute()
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.assets.image) as img_response:
img = process_image_response(await img_response.content.read())
return comfy_io.NodeOutput(img)
return (img,)
class LumaTextToVideoGenerationNode(comfy_io.ComfyNode):
class LumaTextToVideoGenerationNode(ComfyNodeABC):
"""
Generates videos synchronously based on prompt and output_size.
"""
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
node_id="LumaVideoNode",
display_name="Luma Text to Video",
category="api node/video/Luma",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt for the video generation",
),
comfy_io.Combo.Input(
"model",
options=[model.value for model in LumaVideoModel],
),
comfy_io.Combo.Input(
"aspect_ratio",
options=[ratio.value for ratio in LumaAspectRatio],
default=LumaAspectRatio.ratio_16_9,
),
comfy_io.Combo.Input(
"resolution",
options=[resolution.value for resolution in LumaVideoOutputResolution],
default=LumaVideoOutputResolution.res_540p,
),
comfy_io.Combo.Input(
"duration",
options=[dur.value for dur in LumaVideoModelOutputDuration],
),
comfy_io.Boolean.Input(
"loop",
default=False,
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=0xFFFFFFFFFFFFFFFF,
control_after_generate=True,
tooltip="Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
),
comfy_io.Custom(LumaIO.LUMA_CONCEPTS).Input(
"luma_concepts",
tooltip="Optional Camera Concepts to dictate camera motion via the Luma Concepts node.",
optional=True,
)
],
outputs=[comfy_io.Video.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
RETURN_TYPES = (IO.VIDEO,)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "api_call"
API_NODE = True
CATEGORY = "api node/video/Luma"
@classmethod
async def execute(
cls,
def INPUT_TYPES(s):
return {
"required": {
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Prompt for the video generation",
},
),
"model": ([model.value for model in LumaVideoModel],),
"aspect_ratio": (
[ratio.value for ratio in LumaAspectRatio],
{
"default": LumaAspectRatio.ratio_16_9,
},
),
"resolution": (
[resolution.value for resolution in LumaVideoOutputResolution],
{
"default": LumaVideoOutputResolution.res_540p,
},
),
"duration": ([dur.value for dur in LumaVideoModelOutputDuration],),
"loop": (
IO.BOOLEAN,
{
"default": False,
},
),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 0xFFFFFFFFFFFFFFFF,
"control_after_generate": True,
"tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
},
),
},
"optional": {
"luma_concepts": (
LumaIO.LUMA_CONCEPTS,
{
"tooltip": "Optional Camera Concepts to dictate camera motion via the Luma Concepts node."
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
async def api_call(
self,
prompt: str,
model: str,
aspect_ratio: str,
@@ -520,15 +506,13 @@ class LumaTextToVideoGenerationNode(comfy_io.ComfyNode):
loop: bool,
seed,
luma_concepts: LumaConceptChain = None,
) -> comfy_io.NodeOutput:
unique_id: str = None,
**kwargs,
):
validate_string(prompt, strip_whitespace=False, min_length=3)
duration = duration if model != LumaVideoModel.ray_1_6 else None
resolution = resolution if model != LumaVideoModel.ray_1_6 else None
auth_kwargs = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/luma/generations",
@@ -545,12 +529,12 @@ class LumaTextToVideoGenerationNode(comfy_io.ComfyNode):
loop=loop,
concepts=luma_concepts.create_api_model() if luma_concepts else None,
),
auth_kwargs=auth_kwargs,
auth_kwargs=kwargs,
)
response_api: LumaGeneration = await operation.execute()
if cls.hidden.unique_id:
PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", cls.hidden.unique_id)
if unique_id:
PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", unique_id)
operation = PollingOperation(
poll_endpoint=ApiEndpoint(
@@ -563,94 +547,90 @@ class LumaTextToVideoGenerationNode(comfy_io.ComfyNode):
failed_statuses=[LumaState.failed],
status_extractor=lambda x: x.state,
result_url_extractor=video_result_url_extractor,
node_id=cls.hidden.unique_id,
node_id=unique_id,
estimated_duration=LUMA_T2V_AVERAGE_DURATION,
auth_kwargs=auth_kwargs,
auth_kwargs=kwargs,
)
response_poll = await operation.execute()
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.assets.video) as vid_response:
return comfy_io.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
return (VideoFromFile(BytesIO(await vid_response.content.read())),)
class LumaImageToVideoGenerationNode(comfy_io.ComfyNode):
class LumaImageToVideoGenerationNode(ComfyNodeABC):
"""
Generates videos synchronously based on prompt, input images, and output_size.
"""
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
node_id="LumaImageToVideoNode",
display_name="Luma Image to Video",
category="api node/video/Luma",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt for the video generation",
),
comfy_io.Combo.Input(
"model",
options=[model.value for model in LumaVideoModel],
),
# comfy_io.Combo.Input(
# "aspect_ratio",
# options=[ratio.value for ratio in LumaAspectRatio],
# default=LumaAspectRatio.ratio_16_9,
# ),
comfy_io.Combo.Input(
"resolution",
options=[resolution.value for resolution in LumaVideoOutputResolution],
default=LumaVideoOutputResolution.res_540p,
),
comfy_io.Combo.Input(
"duration",
options=[dur.value for dur in LumaVideoModelOutputDuration],
),
comfy_io.Boolean.Input(
"loop",
default=False,
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=0xFFFFFFFFFFFFFFFF,
control_after_generate=True,
tooltip="Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
),
comfy_io.Image.Input(
"first_image",
tooltip="First frame of generated video.",
optional=True,
),
comfy_io.Image.Input(
"last_image",
tooltip="Last frame of generated video.",
optional=True,
),
comfy_io.Custom(LumaIO.LUMA_CONCEPTS).Input(
"luma_concepts",
tooltip="Optional Camera Concepts to dictate camera motion via the Luma Concepts node.",
optional=True,
)
],
outputs=[comfy_io.Video.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
RETURN_TYPES = (IO.VIDEO,)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "api_call"
API_NODE = True
CATEGORY = "api node/video/Luma"
@classmethod
async def execute(
cls,
def INPUT_TYPES(s):
return {
"required": {
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Prompt for the video generation",
},
),
"model": ([model.value for model in LumaVideoModel],),
# "aspect_ratio": ([ratio.value for ratio in LumaAspectRatio], {
# "default": LumaAspectRatio.ratio_16_9,
# }),
"resolution": (
[resolution.value for resolution in LumaVideoOutputResolution],
{
"default": LumaVideoOutputResolution.res_540p,
},
),
"duration": ([dur.value for dur in LumaVideoModelOutputDuration],),
"loop": (
IO.BOOLEAN,
{
"default": False,
},
),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 0xFFFFFFFFFFFFFFFF,
"control_after_generate": True,
"tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
},
),
},
"optional": {
"first_image": (
IO.IMAGE,
{"tooltip": "First frame of generated video."},
),
"last_image": (IO.IMAGE, {"tooltip": "Last frame of generated video."}),
"luma_concepts": (
LumaIO.LUMA_CONCEPTS,
{
"tooltip": "Optional Camera Concepts to dictate camera motion via the Luma Concepts node."
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
async def api_call(
self,
prompt: str,
model: str,
resolution: str,
@@ -660,16 +640,14 @@ class LumaImageToVideoGenerationNode(comfy_io.ComfyNode):
first_image: torch.Tensor = None,
last_image: torch.Tensor = None,
luma_concepts: LumaConceptChain = None,
) -> comfy_io.NodeOutput:
unique_id: str = None,
**kwargs,
):
if first_image is None and last_image is None:
raise Exception(
"At least one of first_image and last_image requires an input."
)
auth_kwargs = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
keyframes = await cls._convert_to_keyframes(first_image, last_image, auth_kwargs=auth_kwargs)
keyframes = await self._convert_to_keyframes(first_image, last_image, auth_kwargs=kwargs)
duration = duration if model != LumaVideoModel.ray_1_6 else None
resolution = resolution if model != LumaVideoModel.ray_1_6 else None
@@ -690,12 +668,12 @@ class LumaImageToVideoGenerationNode(comfy_io.ComfyNode):
keyframes=keyframes,
concepts=luma_concepts.create_api_model() if luma_concepts else None,
),
auth_kwargs=auth_kwargs,
auth_kwargs=kwargs,
)
response_api: LumaGeneration = await operation.execute()
if cls.hidden.unique_id:
PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", cls.hidden.unique_id)
if unique_id:
PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", unique_id)
operation = PollingOperation(
poll_endpoint=ApiEndpoint(
@@ -708,19 +686,18 @@ class LumaImageToVideoGenerationNode(comfy_io.ComfyNode):
failed_statuses=[LumaState.failed],
status_extractor=lambda x: x.state,
result_url_extractor=video_result_url_extractor,
node_id=cls.hidden.unique_id,
node_id=unique_id,
estimated_duration=LUMA_I2V_AVERAGE_DURATION,
auth_kwargs=auth_kwargs,
auth_kwargs=kwargs,
)
response_poll = await operation.execute()
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.assets.video) as vid_response:
return comfy_io.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
return (VideoFromFile(BytesIO(await vid_response.content.read())),)
@classmethod
async def _convert_to_keyframes(
cls,
self,
first_image: torch.Tensor = None,
last_image: torch.Tensor = None,
auth_kwargs: Optional[dict[str,str]] = None,
@@ -742,18 +719,23 @@ class LumaImageToVideoGenerationNode(comfy_io.ComfyNode):
return LumaKeyframes(frame0=frame0, frame1=frame1)
class LumaExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
return [
LumaImageGenerationNode,
LumaImageModifyNode,
LumaTextToVideoGenerationNode,
LumaImageToVideoGenerationNode,
LumaReferenceNode,
LumaConceptsNode,
]
# A dictionary that contains all nodes you want to export with their names
# NOTE: names should be globally unique
NODE_CLASS_MAPPINGS = {
"LumaImageNode": LumaImageGenerationNode,
"LumaImageModifyNode": LumaImageModifyNode,
"LumaVideoNode": LumaTextToVideoGenerationNode,
"LumaImageToVideoNode": LumaImageToVideoGenerationNode,
"LumaReferenceNode": LumaReferenceNode,
"LumaConceptsNode": LumaConceptsNode,
}
async def comfy_entrypoint() -> LumaExtension:
return LumaExtension()
# A dictionary that contains the friendly/humanly readable titles for the nodes
NODE_DISPLAY_NAME_MAPPINGS = {
"LumaImageNode": "Luma Text to Image",
"LumaImageModifyNode": "Luma Image to Image",
"LumaVideoNode": "Luma Text to Video",
"LumaImageToVideoNode": "Luma Image to Video",
"LumaReferenceNode": "Luma Reference",
"LumaConceptsNode": "Luma Concepts",
}

View File

@@ -2,7 +2,11 @@ import logging
from typing import Any, Callable, Optional, TypeVar
import torch
from typing_extensions import override
from comfy_api_nodes.util.validation_utils import validate_image_dimensions
from comfy_api_nodes.util.validation_utils import (
get_image_dimensions,
validate_image_dimensions,
)
from comfy_api_nodes.apis import (
MoonvalleyTextToVideoRequest,
@@ -128,6 +132,47 @@ def validate_prompts(
return True
def validate_input_media(width, height, with_frame_conditioning, num_frames_in=None):
# inference validation
# T = num_frames
# in all cases, the following must be true: T divisible by 16 and H,W by 8. in addition...
# with image conditioning: H*W must be divisible by 8192
# without image conditioning: T divisible by 32
if num_frames_in and not num_frames_in % 16 == 0:
return False, ("The input video total frame count must be divisible by 16!")
if height % 8 != 0 or width % 8 != 0:
return False, (
f"Height ({height}) and width ({width}) must be " "divisible by 8"
)
if with_frame_conditioning:
if (height * width) % 8192 != 0:
return False, (
f"Height * width ({height * width}) must be "
"divisible by 8192 for frame conditioning"
)
else:
if num_frames_in and not num_frames_in % 32 == 0:
return False, ("The input video total frame count must be divisible by 32!")
def validate_input_image(
image: torch.Tensor, with_frame_conditioning: bool = False
) -> None:
"""
Validates the input image adheres to the expectations of the API:
- The image resolution should not be less than 300*300px
- The aspect ratio of the image should be between 1:2.5 ~ 2.5:1
"""
height, width = get_image_dimensions(image)
validate_input_media(width, height, with_frame_conditioning)
validate_image_dimensions(
image, min_width=300, min_height=300, max_height=MAX_HEIGHT, max_width=MAX_WIDTH
)
def validate_video_to_video_input(video: VideoInput) -> VideoInput:
"""
Validates and processes video input for Moonvalley Video-to-Video generation.
@@ -454,7 +499,7 @@ class MoonvalleyImg2VideoNode(comfy_io.ComfyNode):
seed: int,
steps: int,
) -> comfy_io.NodeOutput:
validate_image_dimensions(image, min_width=300, min_height=300, max_height=MAX_HEIGHT, max_width=MAX_WIDTH)
validate_input_image(image, True)
validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
width_height = parse_width_height_from_res(resolution)

View File

@@ -1,7 +1,5 @@
from inspect import cleandoc
from typing import Optional
from typing_extensions import override
from io import BytesIO
from comfy_api_nodes.apis.pixverse_api import (
PixverseTextVideoRequest,
PixverseImageVideoRequest,
@@ -28,11 +26,12 @@ from comfy_api_nodes.apinode_utils import (
tensor_to_bytesio,
validate_string,
)
from comfy.comfy_types.node_typing import IO, ComfyNodeABC
from comfy_api.input_impl import VideoFromFile
from comfy_api.latest import ComfyExtension, io as comfy_io
import torch
import aiohttp
from io import BytesIO
AVERAGE_DURATION_T2V = 32
@@ -73,101 +72,100 @@ async def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None):
return response_upload.Resp.img_id
class PixverseTemplateNode(comfy_io.ComfyNode):
class PixverseTemplateNode:
"""
Select template for PixVerse Video generation.
"""
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
node_id="PixverseTemplateNode",
display_name="PixVerse Template",
category="api node/video/PixVerse",
inputs=[
comfy_io.Combo.Input("template", options=[list(pixverse_templates.keys())]),
],
outputs=[comfy_io.Custom(PixverseIO.TEMPLATE).Output(display_name="pixverse_template")],
)
RETURN_TYPES = (PixverseIO.TEMPLATE,)
RETURN_NAMES = ("pixverse_template",)
FUNCTION = "create_template"
CATEGORY = "api node/video/PixVerse"
@classmethod
def execute(cls, template: str) -> comfy_io.NodeOutput:
def INPUT_TYPES(s):
return {
"required": {
"template": (list(pixverse_templates.keys()),),
}
}
def create_template(self, template: str):
template_id = pixverse_templates.get(template, None)
if template_id is None:
raise Exception(f"Template '{template}' is not recognized.")
# just return the integer
return comfy_io.NodeOutput(template_id)
return (template_id,)
class PixverseTextToVideoNode(comfy_io.ComfyNode):
class PixverseTextToVideoNode(ComfyNodeABC):
"""
Generates videos based on prompt and output_size.
"""
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
node_id="PixverseTextToVideoNode",
display_name="PixVerse Text to Video",
category="api node/video/PixVerse",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt for the video generation",
),
comfy_io.Combo.Input(
"aspect_ratio",
options=[ratio.value for ratio in PixverseAspectRatio],
),
comfy_io.Combo.Input(
"quality",
options=[resolution.value for resolution in PixverseQuality],
default=PixverseQuality.res_540p,
),
comfy_io.Combo.Input(
"duration_seconds",
options=[dur.value for dur in PixverseDuration],
),
comfy_io.Combo.Input(
"motion_mode",
options=[mode.value for mode in PixverseMotionMode],
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Seed for video generation.",
),
comfy_io.String.Input(
"negative_prompt",
default="",
force_input=True,
tooltip="An optional text description of undesired elements on an image.",
optional=True,
),
comfy_io.Custom(PixverseIO.TEMPLATE).Input(
"pixverse_template",
tooltip="An optional template to influence style of generation, created by the PixVerse Template node.",
optional=True,
),
],
outputs=[comfy_io.Video.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
RETURN_TYPES = (IO.VIDEO,)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "api_call"
API_NODE = True
CATEGORY = "api node/video/PixVerse"
@classmethod
async def execute(
cls,
def INPUT_TYPES(s):
return {
"required": {
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Prompt for the video generation",
},
),
"aspect_ratio": ([ratio.value for ratio in PixverseAspectRatio],),
"quality": (
[resolution.value for resolution in PixverseQuality],
{
"default": PixverseQuality.res_540p,
},
),
"duration_seconds": ([dur.value for dur in PixverseDuration],),
"motion_mode": ([mode.value for mode in PixverseMotionMode],),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 2147483647,
"control_after_generate": True,
"tooltip": "Seed for video generation.",
},
),
},
"optional": {
"negative_prompt": (
IO.STRING,
{
"default": "",
"forceInput": True,
"tooltip": "An optional text description of undesired elements on an image.",
},
),
"pixverse_template": (
PixverseIO.TEMPLATE,
{
"tooltip": "An optional template to influence style of generation, created by the PixVerse Template node."
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
async def api_call(
self,
prompt: str,
aspect_ratio: str,
quality: str,
@@ -176,7 +174,9 @@ class PixverseTextToVideoNode(comfy_io.ComfyNode):
seed,
negative_prompt: str = None,
pixverse_template: int = None,
) -> comfy_io.NodeOutput:
unique_id: Optional[str] = None,
**kwargs,
):
validate_string(prompt, strip_whitespace=False)
# 1080p is limited to 5 seconds duration
# only normal motion_mode supported for 1080p or for non-5 second duration
@@ -186,10 +186,6 @@ class PixverseTextToVideoNode(comfy_io.ComfyNode):
elif duration_seconds != PixverseDuration.dur_5:
motion_mode = PixverseMotionMode.normal
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/pixverse/video/text/generate",
@@ -207,7 +203,7 @@ class PixverseTextToVideoNode(comfy_io.ComfyNode):
template_id=pixverse_template,
seed=seed,
),
auth_kwargs=auth,
auth_kwargs=kwargs,
)
response_api = await operation.execute()
@@ -228,8 +224,8 @@ class PixverseTextToVideoNode(comfy_io.ComfyNode):
PixverseStatus.deleted,
],
status_extractor=lambda x: x.Resp.status,
auth_kwargs=auth,
node_id=cls.hidden.unique_id,
auth_kwargs=kwargs,
node_id=unique_id,
result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_T2V,
)
@@ -237,75 +233,77 @@ class PixverseTextToVideoNode(comfy_io.ComfyNode):
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.Resp.url) as vid_response:
return comfy_io.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
return (VideoFromFile(BytesIO(await vid_response.content.read())),)
class PixverseImageToVideoNode(comfy_io.ComfyNode):
class PixverseImageToVideoNode(ComfyNodeABC):
"""
Generates videos based on prompt and output_size.
"""
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
node_id="PixverseImageToVideoNode",
display_name="PixVerse Image to Video",
category="api node/video/PixVerse",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.Image.Input("image"),
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt for the video generation",
),
comfy_io.Combo.Input(
"quality",
options=[resolution.value for resolution in PixverseQuality],
default=PixverseQuality.res_540p,
),
comfy_io.Combo.Input(
"duration_seconds",
options=[dur.value for dur in PixverseDuration],
),
comfy_io.Combo.Input(
"motion_mode",
options=[mode.value for mode in PixverseMotionMode],
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Seed for video generation.",
),
comfy_io.String.Input(
"negative_prompt",
default="",
force_input=True,
tooltip="An optional text description of undesired elements on an image.",
optional=True,
),
comfy_io.Custom(PixverseIO.TEMPLATE).Input(
"pixverse_template",
tooltip="An optional template to influence style of generation, created by the PixVerse Template node.",
optional=True,
),
],
outputs=[comfy_io.Video.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
RETURN_TYPES = (IO.VIDEO,)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "api_call"
API_NODE = True
CATEGORY = "api node/video/PixVerse"
@classmethod
async def execute(
cls,
def INPUT_TYPES(s):
return {
"required": {
"image": (IO.IMAGE,),
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Prompt for the video generation",
},
),
"quality": (
[resolution.value for resolution in PixverseQuality],
{
"default": PixverseQuality.res_540p,
},
),
"duration_seconds": ([dur.value for dur in PixverseDuration],),
"motion_mode": ([mode.value for mode in PixverseMotionMode],),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 2147483647,
"control_after_generate": True,
"tooltip": "Seed for video generation.",
},
),
},
"optional": {
"negative_prompt": (
IO.STRING,
{
"default": "",
"forceInput": True,
"tooltip": "An optional text description of undesired elements on an image.",
},
),
"pixverse_template": (
PixverseIO.TEMPLATE,
{
"tooltip": "An optional template to influence style of generation, created by the PixVerse Template node."
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
async def api_call(
self,
image: torch.Tensor,
prompt: str,
quality: str,
@@ -314,13 +312,11 @@ class PixverseImageToVideoNode(comfy_io.ComfyNode):
seed,
negative_prompt: str = None,
pixverse_template: int = None,
) -> comfy_io.NodeOutput:
unique_id: Optional[str] = None,
**kwargs,
):
validate_string(prompt, strip_whitespace=False)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
img_id = await upload_image_to_pixverse(image, auth_kwargs=auth)
img_id = await upload_image_to_pixverse(image, auth_kwargs=kwargs)
# 1080p is limited to 5 seconds duration
# only normal motion_mode supported for 1080p or for non-5 second duration
@@ -347,7 +343,7 @@ class PixverseImageToVideoNode(comfy_io.ComfyNode):
template_id=pixverse_template,
seed=seed,
),
auth_kwargs=auth,
auth_kwargs=kwargs,
)
response_api = await operation.execute()
@@ -368,8 +364,8 @@ class PixverseImageToVideoNode(comfy_io.ComfyNode):
PixverseStatus.deleted,
],
status_extractor=lambda x: x.Resp.status,
auth_kwargs=auth,
node_id=cls.hidden.unique_id,
auth_kwargs=kwargs,
node_id=unique_id,
result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_I2V,
)
@@ -377,71 +373,72 @@ class PixverseImageToVideoNode(comfy_io.ComfyNode):
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.Resp.url) as vid_response:
return comfy_io.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
return (VideoFromFile(BytesIO(await vid_response.content.read())),)
class PixverseTransitionVideoNode(comfy_io.ComfyNode):
class PixverseTransitionVideoNode(ComfyNodeABC):
"""
Generates videos based on prompt and output_size.
"""
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
node_id="PixverseTransitionVideoNode",
display_name="PixVerse Transition Video",
category="api node/video/PixVerse",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.Image.Input("first_frame"),
comfy_io.Image.Input("last_frame"),
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt for the video generation",
),
comfy_io.Combo.Input(
"quality",
options=[resolution.value for resolution in PixverseQuality],
default=PixverseQuality.res_540p,
),
comfy_io.Combo.Input(
"duration_seconds",
options=[dur.value for dur in PixverseDuration],
),
comfy_io.Combo.Input(
"motion_mode",
options=[mode.value for mode in PixverseMotionMode],
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Seed for video generation.",
),
comfy_io.String.Input(
"negative_prompt",
default="",
force_input=True,
tooltip="An optional text description of undesired elements on an image.",
optional=True,
),
],
outputs=[comfy_io.Video.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
RETURN_TYPES = (IO.VIDEO,)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "api_call"
API_NODE = True
CATEGORY = "api node/video/PixVerse"
@classmethod
async def execute(
cls,
def INPUT_TYPES(s):
return {
"required": {
"first_frame": (IO.IMAGE,),
"last_frame": (IO.IMAGE,),
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Prompt for the video generation",
},
),
"quality": (
[resolution.value for resolution in PixverseQuality],
{
"default": PixverseQuality.res_540p,
},
),
"duration_seconds": ([dur.value for dur in PixverseDuration],),
"motion_mode": ([mode.value for mode in PixverseMotionMode],),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 2147483647,
"control_after_generate": True,
"tooltip": "Seed for video generation.",
},
),
},
"optional": {
"negative_prompt": (
IO.STRING,
{
"default": "",
"forceInput": True,
"tooltip": "An optional text description of undesired elements on an image.",
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
async def api_call(
self,
first_frame: torch.Tensor,
last_frame: torch.Tensor,
prompt: str,
@@ -450,14 +447,12 @@ class PixverseTransitionVideoNode(comfy_io.ComfyNode):
motion_mode: str,
seed,
negative_prompt: str = None,
) -> comfy_io.NodeOutput:
unique_id: Optional[str] = None,
**kwargs,
):
validate_string(prompt, strip_whitespace=False)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
first_frame_id = await upload_image_to_pixverse(first_frame, auth_kwargs=auth)
last_frame_id = await upload_image_to_pixverse(last_frame, auth_kwargs=auth)
first_frame_id = await upload_image_to_pixverse(first_frame, auth_kwargs=kwargs)
last_frame_id = await upload_image_to_pixverse(last_frame, auth_kwargs=kwargs)
# 1080p is limited to 5 seconds duration
# only normal motion_mode supported for 1080p or for non-5 second duration
@@ -484,7 +479,7 @@ class PixverseTransitionVideoNode(comfy_io.ComfyNode):
negative_prompt=negative_prompt if negative_prompt else None,
seed=seed,
),
auth_kwargs=auth,
auth_kwargs=kwargs,
)
response_api = await operation.execute()
@@ -505,8 +500,8 @@ class PixverseTransitionVideoNode(comfy_io.ComfyNode):
PixverseStatus.deleted,
],
status_extractor=lambda x: x.Resp.status,
auth_kwargs=auth,
node_id=cls.hidden.unique_id,
auth_kwargs=kwargs,
node_id=unique_id,
result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_T2V,
)
@@ -514,19 +509,19 @@ class PixverseTransitionVideoNode(comfy_io.ComfyNode):
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.Resp.url) as vid_response:
return comfy_io.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
return (VideoFromFile(BytesIO(await vid_response.content.read())),)
class PixVerseExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
return [
PixverseTextToVideoNode,
PixverseImageToVideoNode,
PixverseTransitionVideoNode,
PixverseTemplateNode,
]
NODE_CLASS_MAPPINGS = {
"PixverseTextToVideoNode": PixverseTextToVideoNode,
"PixverseImageToVideoNode": PixverseImageToVideoNode,
"PixverseTransitionVideoNode": PixverseTransitionVideoNode,
"PixverseTemplateNode": PixverseTemplateNode,
}
async def comfy_entrypoint() -> PixVerseExtension:
return PixVerseExtension()
NODE_DISPLAY_NAME_MAPPINGS = {
"PixverseTextToVideoNode": "PixVerse Text to Video",
"PixverseImageToVideoNode": "PixVerse Image to Video",
"PixverseTransitionVideoNode": "PixVerse Transition Video",
"PixverseTemplateNode": "PixVerse Template",
}

View File

@@ -38,48 +38,48 @@ from PIL import UnidentifiedImageError
async def handle_recraft_file_request(
image: torch.Tensor,
path: str,
mask: torch.Tensor=None,
total_pixels=4096*4096,
timeout=1024,
request=None,
auth_kwargs: dict[str,str] = None,
) -> list[BytesIO]:
"""
Handle sending common Recraft file-only request to get back file bytes.
"""
if request is None:
request = EmptyRequest()
image: torch.Tensor,
path: str,
mask: torch.Tensor=None,
total_pixels=4096*4096,
timeout=1024,
request=None,
auth_kwargs: dict[str,str] = None,
) -> list[BytesIO]:
"""
Handle sending common Recraft file-only request to get back file bytes.
"""
if request is None:
request = EmptyRequest()
files = {
'image': tensor_to_bytesio(image, total_pixels=total_pixels).read()
}
if mask is not None:
files['mask'] = tensor_to_bytesio(mask, total_pixels=total_pixels).read()
files = {
'image': tensor_to_bytesio(image, total_pixels=total_pixels).read()
}
if mask is not None:
files['mask'] = tensor_to_bytesio(mask, total_pixels=total_pixels).read()
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=path,
method=HttpMethod.POST,
request_model=type(request),
response_model=RecraftImageGenerationResponse,
),
request=request,
files=files,
content_type="multipart/form-data",
auth_kwargs=auth_kwargs,
multipart_parser=recraft_multipart_parser,
)
response: RecraftImageGenerationResponse = await operation.execute()
all_bytesio = []
if response.image is not None:
all_bytesio.append(await download_url_to_bytesio(response.image.url, timeout=timeout))
else:
for data in response.data:
all_bytesio.append(await download_url_to_bytesio(data.url, timeout=timeout))
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=path,
method=HttpMethod.POST,
request_model=type(request),
response_model=RecraftImageGenerationResponse,
),
request=request,
files=files,
content_type="multipart/form-data",
auth_kwargs=auth_kwargs,
multipart_parser=recraft_multipart_parser,
)
response: RecraftImageGenerationResponse = await operation.execute()
all_bytesio = []
if response.image is not None:
all_bytesio.append(await download_url_to_bytesio(response.image.url, timeout=timeout))
else:
for data in response.data:
all_bytesio.append(await download_url_to_bytesio(data.url, timeout=timeout))
return all_bytesio
return all_bytesio
def recraft_multipart_parser(data, parent_key=None, formatter: callable=None, converted_to_check: list[list]=None, is_list=False) -> dict:

File diff suppressed because it is too large Load Diff

View File

@@ -28,12 +28,6 @@ class Text2ImageInputField(BaseModel):
negative_prompt: Optional[str] = Field(None)
class Image2ImageInputField(BaseModel):
prompt: str = Field(...)
negative_prompt: Optional[str] = Field(None)
images: list[str] = Field(..., min_length=1, max_length=2)
class Text2VideoInputField(BaseModel):
prompt: str = Field(...)
negative_prompt: Optional[str] = Field(None)
@@ -55,13 +49,6 @@ class Txt2ImageParametersField(BaseModel):
watermark: bool = Field(True)
class Image2ImageParametersField(BaseModel):
size: Optional[str] = Field(None)
n: int = Field(1, description="Number of images to generate.") # we support only value=1
seed: int = Field(..., ge=0, le=2147483647)
watermark: bool = Field(True)
class Text2VideoParametersField(BaseModel):
size: str = Field(...)
seed: int = Field(..., ge=0, le=2147483647)
@@ -86,12 +73,6 @@ class Text2ImageTaskCreationRequest(BaseModel):
parameters: Txt2ImageParametersField = Field(...)
class Image2ImageTaskCreationRequest(BaseModel):
model: str = Field(...)
input: Image2ImageInputField = Field(...)
parameters: Image2ImageParametersField = Field(...)
class Text2VideoTaskCreationRequest(BaseModel):
model: str = Field(...)
input: Text2VideoInputField = Field(...)
@@ -154,12 +135,7 @@ async def process_task(
url: str,
request_model: Type[T],
response_model: Type[R],
payload: Union[
Text2ImageTaskCreationRequest,
Image2ImageTaskCreationRequest,
Text2VideoTaskCreationRequest,
Image2VideoTaskCreationRequest,
],
payload: Union[Text2ImageTaskCreationRequest, Text2VideoTaskCreationRequest, Image2VideoTaskCreationRequest],
node_id: str,
estimated_duration: int,
poll_interval: int,
@@ -312,128 +288,6 @@ class WanTextToImageApi(comfy_io.ComfyNode):
return comfy_io.NodeOutput(await download_url_to_image_tensor(str(response.output.results[0].url)))
class WanImageToImageApi(comfy_io.ComfyNode):
@classmethod
def define_schema(cls):
return comfy_io.Schema(
node_id="WanImageToImageApi",
display_name="Wan Image to Image",
category="api node/image/Wan",
description="Generates an image from one or two input images and a text prompt. "
"The output image is currently fixed at 1.6 MP; its aspect ratio matches the input image(s).",
inputs=[
comfy_io.Combo.Input(
"model",
options=["wan2.5-i2i-preview"],
default="wan2.5-i2i-preview",
tooltip="Model to use.",
),
comfy_io.Image.Input(
"image",
tooltip="Single-image editing or multi-image fusion, maximum 2 images.",
),
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.",
),
comfy_io.String.Input(
"negative_prompt",
multiline=True,
default="",
tooltip="Negative text prompt to guide what to avoid.",
optional=True,
),
# redo this later as an optional combo of recommended resolutions
# comfy_io.Int.Input(
# "width",
# default=1280,
# min=384,
# max=1440,
# step=16,
# optional=True,
# ),
# comfy_io.Int.Input(
# "height",
# default=1280,
# min=384,
# max=1440,
# step=16,
# optional=True,
# ),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
display_mode=comfy_io.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed to use for generation.",
optional=True,
),
comfy_io.Boolean.Input(
"watermark",
default=True,
tooltip="Whether to add an \"AI generated\" watermark to the result.",
optional=True,
),
],
outputs=[
comfy_io.Image.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
model: str,
image: torch.Tensor,
prompt: str,
negative_prompt: str = "",
# width: int = 1024,
# height: int = 1024,
seed: int = 0,
watermark: bool = True,
):
n_images = get_number_of_images(image)
if n_images not in (1, 2):
raise ValueError(f"Expected 1 or 2 input images, got {n_images}.")
images = []
for i in image:
images.append("data:image/png;base64," + tensor_to_base64_string(i, total_pixels=4096*4096))
payload = Image2ImageTaskCreationRequest(
model=model,
input=Image2ImageInputField(prompt=prompt, negative_prompt=negative_prompt, images=images),
parameters=Image2ImageParametersField(
# size=f"{width}*{height}",
seed=seed,
watermark=watermark,
),
)
response = await process_task(
{
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
"/proxy/wan/api/v1/services/aigc/image2image/image-synthesis",
request_model=Image2ImageTaskCreationRequest,
response_model=ImageTaskStatusResponse,
payload=payload,
node_id=cls.hidden.unique_id,
estimated_duration=42,
poll_interval=3,
)
return comfy_io.NodeOutput(await download_url_to_image_tensor(str(response.output.results[0].url)))
class WanTextToVideoApi(comfy_io.ComfyNode):
@classmethod
def define_schema(cls):
@@ -739,7 +593,6 @@ class WanApiExtension(ComfyExtension):
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
return [
WanTextToImageApi,
WanImageToImageApi,
WanTextToVideoApi,
WanImageToVideoApi,
]

View File

@@ -360,7 +360,7 @@ class RecordAudio:
def load(self, audio):
audio_path = folder_paths.get_annotated_filepath(audio)
waveform, sample_rate = load(audio_path)
waveform, sample_rate = torchaudio.load(audio_path)
audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
return (audio, )

View File

@@ -1,62 +1,44 @@
import folder_paths
import comfy.audio_encoders.audio_encoders
import comfy.utils
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
class AudioEncoderLoader(io.ComfyNode):
class AudioEncoderLoader:
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="AudioEncoderLoader",
category="loaders",
inputs=[
io.Combo.Input(
"audio_encoder_name",
options=folder_paths.get_filename_list("audio_encoders"),
),
],
outputs=[io.AudioEncoder.Output()],
)
def INPUT_TYPES(s):
return {"required": { "audio_encoder_name": (folder_paths.get_filename_list("audio_encoders"), ),
}}
RETURN_TYPES = ("AUDIO_ENCODER",)
FUNCTION = "load_model"
@classmethod
def execute(cls, audio_encoder_name) -> io.NodeOutput:
CATEGORY = "loaders"
def load_model(self, audio_encoder_name):
audio_encoder_name = folder_paths.get_full_path_or_raise("audio_encoders", audio_encoder_name)
sd = comfy.utils.load_torch_file(audio_encoder_name, safe_load=True)
audio_encoder = comfy.audio_encoders.audio_encoders.load_audio_encoder_from_sd(sd)
if audio_encoder is None:
raise RuntimeError("ERROR: audio encoder file is invalid and does not contain a valid model.")
return io.NodeOutput(audio_encoder)
return (audio_encoder,)
class AudioEncoderEncode(io.ComfyNode):
class AudioEncoderEncode:
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="AudioEncoderEncode",
category="conditioning",
inputs=[
io.AudioEncoder.Input("audio_encoder"),
io.Audio.Input("audio"),
],
outputs=[io.AudioEncoderOutput.Output()],
)
def INPUT_TYPES(s):
return {"required": { "audio_encoder": ("AUDIO_ENCODER",),
"audio": ("AUDIO",),
}}
RETURN_TYPES = ("AUDIO_ENCODER_OUTPUT",)
FUNCTION = "encode"
@classmethod
def execute(cls, audio_encoder, audio) -> io.NodeOutput:
CATEGORY = "conditioning"
def encode(self, audio_encoder, audio):
output = audio_encoder.encode_audio(audio["waveform"], audio["sample_rate"])
return io.NodeOutput(output)
return (output,)
class AudioEncoder(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
AudioEncoderLoader,
AudioEncoderEncode,
]
async def comfy_entrypoint() -> AudioEncoder:
return AudioEncoder()
NODE_CLASS_MAPPINGS = {
"AudioEncoderLoader": AudioEncoderLoader,
"AudioEncoderEncode": AudioEncoderEncode,
}

View File

@@ -1,41 +1,34 @@
# code adapted from https://github.com/exx8/differential-diffusion
from typing_extensions import override
import torch
from comfy_api.latest import ComfyExtension, io
class DifferentialDiffusion(io.ComfyNode):
class DifferentialDiffusion():
@classmethod
def define_schema(cls):
return io.Schema(
node_id="DifferentialDiffusion",
display_name="Differential Diffusion",
category="_for_testing",
inputs=[
io.Model.Input("model"),
io.Float.Input(
"strength",
default=1.0,
min=0.0,
max=1.0,
step=0.01,
optional=True,
),
],
outputs=[io.Model.Output()],
is_experimental=True,
)
def INPUT_TYPES(s):
return {
"required": {
"model": ("MODEL", ),
},
"optional": {
"strength": ("FLOAT", {
"default": 1.0,
"min": 0.0,
"max": 1.0,
"step": 0.01,
}),
}
}
RETURN_TYPES = ("MODEL",)
FUNCTION = "apply"
CATEGORY = "_for_testing"
INIT = False
@classmethod
def execute(cls, model, strength=1.0) -> io.NodeOutput:
def apply(self, model, strength=1.0):
model = model.clone()
model.set_model_denoise_mask_function(lambda *args, **kwargs: cls.forward(*args, **kwargs, strength=strength))
return io.NodeOutput(model)
model.set_model_denoise_mask_function(lambda *args, **kwargs: self.forward(*args, **kwargs, strength=strength))
return (model, )
@classmethod
def forward(cls, sigma: torch.Tensor, denoise_mask: torch.Tensor, extra_options: dict, strength: float):
def forward(self, sigma: torch.Tensor, denoise_mask: torch.Tensor, extra_options: dict, strength: float):
model = extra_options["model"]
step_sigmas = extra_options["sigmas"]
sigma_to = model.inner_model.model_sampling.sigma_min
@@ -60,13 +53,9 @@ class DifferentialDiffusion(io.ComfyNode):
return binary_mask
class DifferentialDiffusionExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
DifferentialDiffusion,
]
async def comfy_entrypoint() -> DifferentialDiffusionExtension:
return DifferentialDiffusionExtension()
NODE_CLASS_MAPPINGS = {
"DifferentialDiffusion": DifferentialDiffusion,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"DifferentialDiffusion": "Differential Diffusion",
}

View File

@@ -1,38 +1,26 @@
import node_helpers
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
class ReferenceLatent(io.ComfyNode):
class ReferenceLatent:
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ReferenceLatent",
category="advanced/conditioning/edit_models",
description="This node sets the guiding latent for an edit model. If the model supports it you can chain multiple to set multiple reference images.",
inputs=[
io.Conditioning.Input("conditioning"),
io.Latent.Input("latent", optional=True),
],
outputs=[
io.Conditioning.Output(),
]
)
def INPUT_TYPES(s):
return {"required": {"conditioning": ("CONDITIONING", ),
},
"optional": {"latent": ("LATENT", ),}
}
@classmethod
def execute(cls, conditioning, latent=None) -> io.NodeOutput:
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "append"
CATEGORY = "advanced/conditioning/edit_models"
DESCRIPTION = "This node sets the guiding latent for an edit model. If the model supports it you can chain multiple to set multiple reference images."
def append(self, conditioning, latent=None):
if latent is not None:
conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": [latent["samples"]]}, append=True)
return io.NodeOutput(conditioning)
return (conditioning, )
class EditModelExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
ReferenceLatent,
]
def comfy_entrypoint() -> EditModelExtension:
return EditModelExtension()
NODE_CLASS_MAPPINGS = {
"ReferenceLatent": ReferenceLatent,
}

View File

@@ -1,74 +0,0 @@
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
class EpsilonScaling(io.ComfyNode):
"""
Implements the Epsilon Scaling method from 'Elucidating the Exposure Bias in Diffusion Models'
(https://arxiv.org/abs/2308.15321v6).
This method mitigates exposure bias by scaling the predicted noise during sampling,
which can significantly improve sample quality. This implementation uses the "uniform schedule"
recommended by the paper for its practicality and effectiveness.
"""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="Epsilon Scaling",
category="model_patches/unet",
inputs=[
io.Model.Input("model"),
io.Float.Input(
"scaling_factor",
default=1.005,
min=0.5,
max=1.5,
step=0.001,
display_mode=io.NumberDisplay.number,
),
],
outputs=[
io.Model.Output(),
],
)
@classmethod
def execute(cls, model, scaling_factor) -> io.NodeOutput:
# Prevent division by zero, though the UI's min value should prevent this.
if scaling_factor == 0:
scaling_factor = 1e-9
def epsilon_scaling_function(args):
"""
This function is applied after the CFG guidance has been calculated.
It recalculates the denoised latent by scaling the predicted noise.
"""
denoised = args["denoised"]
x = args["input"]
noise_pred = x - denoised
scaled_noise_pred = noise_pred / scaling_factor
new_denoised = x - scaled_noise_pred
return new_denoised
# Clone the model patcher to avoid modifying the original model in place
model_clone = model.clone()
model_clone.set_model_sampler_post_cfg_function(epsilon_scaling_function)
return io.NodeOutput(model_clone)
class EpsilonScalingExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
EpsilonScaling,
]
async def comfy_entrypoint() -> EpsilonScalingExtension:
return EpsilonScalingExtension()

View File

@@ -1,8 +1,6 @@
# from https://github.com/zju-pi/diff-sampler/tree/main/gits-main
import numpy as np
import torch
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
def loglinear_interp(t_steps, num_steps):
"""
@@ -335,28 +333,25 @@ NOISE_LEVELS = {
],
}
class GITSScheduler(io.ComfyNode):
class GITSScheduler:
@classmethod
def define_schema(cls):
return io.Schema(
node_id="GITSScheduler",
category="sampling/custom_sampling/schedulers",
inputs=[
io.Float.Input("coeff", default=1.20, min=0.80, max=1.50, step=0.05),
io.Int.Input("steps", default=10, min=2, max=1000),
io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01),
],
outputs=[
io.Sigmas.Output(),
],
)
def INPUT_TYPES(s):
return {"required":
{"coeff": ("FLOAT", {"default": 1.20, "min": 0.80, "max": 1.50, "step": 0.05}),
"steps": ("INT", {"default": 10, "min": 2, "max": 1000}),
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "sampling/custom_sampling/schedulers"
@classmethod
def execute(cls, coeff, steps, denoise):
FUNCTION = "get_sigmas"
def get_sigmas(self, coeff, steps, denoise):
total_steps = steps
if denoise < 1.0:
if denoise <= 0.0:
return io.NodeOutput(torch.FloatTensor([]))
return (torch.FloatTensor([]),)
total_steps = round(steps * denoise)
if steps <= 20:
@@ -367,16 +362,8 @@ class GITSScheduler(io.ComfyNode):
sigmas = sigmas[-(total_steps + 1):]
sigmas[-1] = 0
return io.NodeOutput(torch.FloatTensor(sigmas))
return (torch.FloatTensor(sigmas), )
class GITSSchedulerExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
GITSScheduler,
]
async def comfy_entrypoint() -> GITSSchedulerExtension:
return GITSSchedulerExtension()
NODE_CLASS_MAPPINGS = {
"GITSScheduler": GITSScheduler,
}

View File

@@ -1,73 +1,55 @@
from typing_extensions import override
import folder_paths
import comfy.sd
import comfy.model_management
from comfy_api.latest import ComfyExtension, io
class QuadrupleCLIPLoader(io.ComfyNode):
class QuadrupleCLIPLoader:
@classmethod
def define_schema(cls):
return io.Schema(
node_id="QuadrupleCLIPLoader",
category="advanced/loaders",
description="[Recipes]\n\nhidream: long clip-l, long clip-g, t5xxl, llama_8b_3.1_instruct",
inputs=[
io.Combo.Input("clip_name1", options=folder_paths.get_filename_list("text_encoders")),
io.Combo.Input("clip_name2", options=folder_paths.get_filename_list("text_encoders")),
io.Combo.Input("clip_name3", options=folder_paths.get_filename_list("text_encoders")),
io.Combo.Input("clip_name4", options=folder_paths.get_filename_list("text_encoders")),
],
outputs=[
io.Clip.Output(),
]
)
def INPUT_TYPES(s):
return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ),
"clip_name2": (folder_paths.get_filename_list("text_encoders"), ),
"clip_name3": (folder_paths.get_filename_list("text_encoders"), ),
"clip_name4": (folder_paths.get_filename_list("text_encoders"), )
}}
RETURN_TYPES = ("CLIP",)
FUNCTION = "load_clip"
@classmethod
def execute(cls, clip_name1, clip_name2, clip_name3, clip_name4):
CATEGORY = "advanced/loaders"
DESCRIPTION = "[Recipes]\n\nhidream: long clip-l, long clip-g, t5xxl, llama_8b_3.1_instruct"
def load_clip(self, clip_name1, clip_name2, clip_name3, clip_name4):
clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1)
clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
clip_path3 = folder_paths.get_full_path_or_raise("text_encoders", clip_name3)
clip_path4 = folder_paths.get_full_path_or_raise("text_encoders", clip_name4)
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3, clip_path4], embedding_directory=folder_paths.get_folder_paths("embeddings"))
return io.NodeOutput(clip)
return (clip,)
class CLIPTextEncodeHiDream(io.ComfyNode):
class CLIPTextEncodeHiDream:
@classmethod
def define_schema(cls):
return io.Schema(
node_id="CLIPTextEncodeHiDream",
category="advanced/conditioning",
inputs=[
io.Clip.Input("clip"),
io.String.Input("clip_l", multiline=True, dynamic_prompts=True),
io.String.Input("clip_g", multiline=True, dynamic_prompts=True),
io.String.Input("t5xxl", multiline=True, dynamic_prompts=True),
io.String.Input("llama", multiline=True, dynamic_prompts=True),
],
outputs=[
io.Conditioning.Output(),
]
)
def INPUT_TYPES(s):
return {"required": {
"clip": ("CLIP", ),
"clip_l": ("STRING", {"multiline": True, "dynamicPrompts": True}),
"clip_g": ("STRING", {"multiline": True, "dynamicPrompts": True}),
"t5xxl": ("STRING", {"multiline": True, "dynamicPrompts": True}),
"llama": ("STRING", {"multiline": True, "dynamicPrompts": True})
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "encode"
CATEGORY = "advanced/conditioning"
def encode(self, clip, clip_l, clip_g, t5xxl, llama):
@classmethod
def execute(cls, clip, clip_l, clip_g, t5xxl, llama):
tokens = clip.tokenize(clip_g)
tokens["l"] = clip.tokenize(clip_l)["l"]
tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"]
tokens["llama"] = clip.tokenize(llama)["llama"]
return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))
return (clip.encode_from_tokens_scheduled(tokens), )
class HiDreamExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
QuadrupleCLIPLoader,
CLIPTextEncodeHiDream,
]
async def comfy_entrypoint() -> HiDreamExtension:
return HiDreamExtension()
NODE_CLASS_MAPPINGS = {
"QuadrupleCLIPLoader": QuadrupleCLIPLoader,
"CLIPTextEncodeHiDream": CLIPTextEncodeHiDream,
}

View File

@@ -1,11 +1,9 @@
#Taken from: https://github.com/tfernd/HyperTile/
import math
from typing_extensions import override
from einops import rearrange
# Use torch rng for consistency across generations
from torch import randint
from comfy_api.latest import ComfyExtension, io
def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
min_value = min(min_value, value)
@@ -22,31 +20,25 @@ def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
return ns[idx]
class HyperTile(io.ComfyNode):
class HyperTile:
@classmethod
def define_schema(cls):
return io.Schema(
node_id="HyperTile",
category="model_patches/unet",
inputs=[
io.Model.Input("model"),
io.Int.Input("tile_size", default=256, min=1, max=2048),
io.Int.Input("swap_size", default=2, min=1, max=128),
io.Int.Input("max_depth", default=0, min=0, max=10),
io.Boolean.Input("scale_depth", default=False),
],
outputs=[
io.Model.Output(),
],
)
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"tile_size": ("INT", {"default": 256, "min": 1, "max": 2048}),
"swap_size": ("INT", {"default": 2, "min": 1, "max": 128}),
"max_depth": ("INT", {"default": 0, "min": 0, "max": 10}),
"scale_depth": ("BOOLEAN", {"default": False}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
@classmethod
def execute(cls, model, tile_size, swap_size, max_depth, scale_depth) -> io.NodeOutput:
CATEGORY = "model_patches/unet"
def patch(self, model, tile_size, swap_size, max_depth, scale_depth):
latent_tile_size = max(32, tile_size) // 8
temp = None
self.temp = None
def hypertile_in(q, k, v, extra_options):
nonlocal temp
model_chans = q.shape[-2]
orig_shape = extra_options['original_shape']
apply_to = []
@@ -66,15 +58,14 @@ class HyperTile(io.ComfyNode):
if nh * nw > 1:
q = rearrange(q, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)
temp = (nh, nw, h, w)
self.temp = (nh, nw, h, w)
return q, k, v
return q, k, v
def hypertile_out(out, extra_options):
nonlocal temp
if temp is not None:
nh, nw, h, w = temp
temp = None
if self.temp is not None:
nh, nw, h, w = self.temp
self.temp = None
out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw)
out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw)
return out
@@ -85,14 +76,6 @@ class HyperTile(io.ComfyNode):
m.set_model_attn1_output_patch(hypertile_out)
return (m, )
class HyperTileExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
HyperTile,
]
async def comfy_entrypoint() -> HyperTileExtension:
return HyperTileExtension()
NODE_CLASS_MAPPINGS = {
"HyperTile": HyperTile,
}

View File

@@ -1,30 +1,21 @@
import torch
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
class InstructPixToPixConditioning(io.ComfyNode):
class InstructPixToPixConditioning:
@classmethod
def define_schema(cls):
return io.Schema(
node_id="InstructPixToPixConditioning",
category="conditioning/instructpix2pix",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Vae.Input("vae"),
io.Image.Input("pixels"),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
],
)
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"vae": ("VAE", ),
"pixels": ("IMAGE", ),
}}
@classmethod
def execute(cls, positive, negative, pixels, vae) -> io.NodeOutput:
RETURN_TYPES = ("CONDITIONING","CONDITIONING","LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
FUNCTION = "encode"
CATEGORY = "conditioning/instructpix2pix"
def encode(self, positive, negative, pixels, vae):
x = (pixels.shape[1] // 8) * 8
y = (pixels.shape[2] // 8) * 8
@@ -47,17 +38,8 @@ class InstructPixToPixConditioning(io.ComfyNode):
n = [t[0], d]
c.append(n)
out.append(c)
return io.NodeOutput(out[0], out[1], out_latent)
class InstructPix2PixExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
InstructPixToPixConditioning,
]
async def comfy_entrypoint() -> InstructPix2PixExtension:
return InstructPix2PixExtension()
return (out[0], out[1], out_latent)
NODE_CLASS_MAPPINGS = {
"InstructPixToPixConditioning": InstructPixToPixConditioning,
}

View File

@@ -1,22 +1,20 @@
from typing_extensions import override
import torch
import comfy.model_management as mm
from comfy_api.latest import ComfyExtension, io
class LotusConditioning(io.ComfyNode):
class LotusConditioning:
@classmethod
def define_schema(cls):
return io.Schema(
node_id="LotusConditioning",
category="conditioning/lotus",
inputs=[],
outputs=[io.Conditioning.Output(display_name="conditioning")],
)
def INPUT_TYPES(s):
return {
"required": {
},
}
@classmethod
def execute(cls) -> io.NodeOutput:
RETURN_TYPES = ("CONDITIONING",)
RETURN_NAMES = ("conditioning",)
FUNCTION = "conditioning"
CATEGORY = "conditioning/lotus"
def conditioning(self):
device = mm.get_torch_device()
#lotus uses a frozen encoder and null conditioning, i'm just inlining the results of that operation since it doesn't change
#and getting parity with the reference implementation would otherwise require inference and 800mb of tensors
@@ -24,16 +22,8 @@ class LotusConditioning(io.ComfyNode):
cond = [[prompt_embeds, {}]]
return io.NodeOutput(cond)
return (cond,)
class LotusExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
LotusConditioning,
]
async def comfy_entrypoint() -> LotusExtension:
return LotusExtension()
NODE_CLASS_MAPPINGS = {
"LotusConditioning" : LotusConditioning,
}

View File

@@ -1,3 +1,4 @@
import io
import nodes
import node_helpers
import torch
@@ -7,60 +8,46 @@ import comfy.utils
import math
import numpy as np
import av
from io import BytesIO
from typing_extensions import override
from comfy.ldm.lightricks.symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
from comfy_api.latest import ComfyExtension, io
class EmptyLTXVLatentVideo(io.ComfyNode):
class EmptyLTXVLatentVideo:
@classmethod
def define_schema(cls):
return io.Schema(
node_id="EmptyLTXVLatentVideo",
category="latent/video/ltxv",
inputs=[
io.Int.Input("width", default=768, min=64, max=nodes.MAX_RESOLUTION, step=32),
io.Int.Input("height", default=512, min=64, max=nodes.MAX_RESOLUTION, step=32),
io.Int.Input("length", default=97, min=1, max=nodes.MAX_RESOLUTION, step=8),
io.Int.Input("batch_size", default=1, min=1, max=4096),
],
outputs=[
io.Latent.Output(),
],
)
def INPUT_TYPES(s):
return {"required": { "width": ("INT", {"default": 768, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
"height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
"length": ("INT", {"default": 97, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 8}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "generate"
@classmethod
def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput:
CATEGORY = "latent/video/ltxv"
def generate(self, width, height, length, batch_size=1):
latent = torch.zeros([batch_size, 128, ((length - 1) // 8) + 1, height // 32, width // 32], device=comfy.model_management.intermediate_device())
return io.NodeOutput({"samples": latent})
return ({"samples": latent}, )
class LTXVImgToVideo(io.ComfyNode):
class LTXVImgToVideo:
@classmethod
def define_schema(cls):
return io.Schema(
node_id="LTXVImgToVideo",
category="conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Vae.Input("vae"),
io.Image.Input("image"),
io.Int.Input("width", default=768, min=64, max=nodes.MAX_RESOLUTION, step=32),
io.Int.Input("height", default=512, min=64, max=nodes.MAX_RESOLUTION, step=32),
io.Int.Input("length", default=97, min=9, max=nodes.MAX_RESOLUTION, step=8),
io.Int.Input("batch_size", default=1, min=1, max=4096),
io.Float.Input("strength", default=1.0, min=0.0, max=1.0),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
],
)
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"vae": ("VAE",),
"image": ("IMAGE",),
"width": ("INT", {"default": 768, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
"height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
"length": ("INT", {"default": 97, "min": 9, "max": nodes.MAX_RESOLUTION, "step": 8}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0}),
}}
@classmethod
def execute(cls, positive, negative, image, vae, width, height, length, batch_size, strength) -> io.NodeOutput:
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
CATEGORY = "conditioning/video_models"
FUNCTION = "generate"
def generate(self, positive, negative, image, vae, width, height, length, batch_size, strength):
pixels = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
encode_pixels = pixels[:, :, :, :3]
t = vae.encode(encode_pixels)
@@ -75,7 +62,7 @@ class LTXVImgToVideo(io.ComfyNode):
)
conditioning_latent_frames_mask[:, :, :t.shape[2]] = 1.0 - strength
return io.NodeOutput(positive, negative, {"samples": latent, "noise_mask": conditioning_latent_frames_mask})
return (positive, negative, {"samples": latent, "noise_mask": conditioning_latent_frames_mask}, )
def conditioning_get_any_value(conditioning, key, default=None):
@@ -106,46 +93,35 @@ def get_keyframe_idxs(cond):
num_keyframes = torch.unique(keyframe_idxs[:, 0]).shape[0]
return keyframe_idxs, num_keyframes
class LTXVAddGuide(io.ComfyNode):
NUM_PREFIX_FRAMES = 2
PATCHIFIER = SymmetricPatchifier(1)
class LTXVAddGuide:
@classmethod
def define_schema(cls):
return io.Schema(
node_id="LTXVAddGuide",
category="conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Vae.Input("vae"),
io.Latent.Input("latent"),
io.Image.Input(
"image",
tooltip="Image or video to condition the latent video on. Must be 8*n + 1 frames. "
"If the video is not 8*n + 1 frames, it will be cropped to the nearest 8*n + 1 frames.",
),
io.Int.Input(
"frame_idx",
default=0,
min=-9999,
max=9999,
tooltip="Frame index to start the conditioning at. "
"For single-frame images or videos with 1-8 frames, any frame_idx value is acceptable. "
"For videos with 9+ frames, frame_idx must be divisible by 8, otherwise it will be rounded "
"down to the nearest multiple of 8. Negative values are counted from the end of the video.",
),
io.Float.Input("strength", default=1.0, min=0.0, max=1.0, step=0.01),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
],
)
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"vae": ("VAE",),
"latent": ("LATENT",),
"image": ("IMAGE", {"tooltip": "Image or video to condition the latent video on. Must be 8*n + 1 frames."
"If the video is not 8*n + 1 frames, it will be cropped to the nearest 8*n + 1 frames."}),
"frame_idx": ("INT", {"default": 0, "min": -9999, "max": 9999,
"tooltip": "Frame index to start the conditioning at. For single-frame images or "
"videos with 1-8 frames, any frame_idx value is acceptable. For videos with 9+ "
"frames, frame_idx must be divisible by 8, otherwise it will be rounded down to "
"the nearest multiple of 8. Negative values are counted from the end of the video."}),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}
}
@classmethod
def encode(cls, vae, latent_width, latent_height, images, scale_factors):
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
CATEGORY = "conditioning/video_models"
FUNCTION = "generate"
def __init__(self):
self._num_prefix_frames = 2
self._patchifier = SymmetricPatchifier(1)
def encode(self, vae, latent_width, latent_height, images, scale_factors):
time_scale_factor, width_scale_factor, height_scale_factor = scale_factors
images = images[:(images.shape[0] - 1) // time_scale_factor * time_scale_factor + 1]
pixels = comfy.utils.common_upscale(images.movedim(-1, 1), latent_width * width_scale_factor, latent_height * height_scale_factor, "bilinear", crop="disabled").movedim(1, -1)
@@ -153,8 +129,7 @@ class LTXVAddGuide(io.ComfyNode):
t = vae.encode(encode_pixels)
return encode_pixels, t
@classmethod
def get_latent_index(cls, cond, latent_length, guide_length, frame_idx, scale_factors):
def get_latent_index(self, cond, latent_length, guide_length, frame_idx, scale_factors):
time_scale_factor, _, _ = scale_factors
_, num_keyframes = get_keyframe_idxs(cond)
latent_count = latent_length - num_keyframes
@@ -166,10 +141,9 @@ class LTXVAddGuide(io.ComfyNode):
return frame_idx, latent_idx
@classmethod
def add_keyframe_index(cls, cond, frame_idx, guiding_latent, scale_factors):
def add_keyframe_index(self, cond, frame_idx, guiding_latent, scale_factors):
keyframe_idxs, _ = get_keyframe_idxs(cond)
_, latent_coords = cls.PATCHIFIER.patchify(guiding_latent)
_, latent_coords = self._patchifier.patchify(guiding_latent)
pixel_coords = latent_to_pixel_coords(latent_coords, scale_factors, causal_fix=frame_idx == 0) # we need the causal fix only if we're placing the new latents at index 0
pixel_coords[:, 0] += frame_idx
if keyframe_idxs is None:
@@ -178,9 +152,8 @@ class LTXVAddGuide(io.ComfyNode):
keyframe_idxs = torch.cat([keyframe_idxs, pixel_coords], dim=2)
return node_helpers.conditioning_set_values(cond, {"keyframe_idxs": keyframe_idxs})
@classmethod
def append_keyframe(cls, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors):
_, latent_idx = cls.get_latent_index(
def append_keyframe(self, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors):
_, latent_idx = self.get_latent_index(
cond=positive,
latent_length=latent_image.shape[2],
guide_length=guiding_latent.shape[2],
@@ -189,8 +162,8 @@ class LTXVAddGuide(io.ComfyNode):
)
noise_mask[:, :, latent_idx:latent_idx + guiding_latent.shape[2]] = 1.0
positive = cls.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors)
negative = cls.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors)
positive = self.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors)
negative = self.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors)
mask = torch.full(
(noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]),
@@ -203,8 +176,7 @@ class LTXVAddGuide(io.ComfyNode):
noise_mask = torch.cat([noise_mask, mask], dim=2)
return positive, negative, latent_image, noise_mask
@classmethod
def replace_latent_frames(cls, latent_image, noise_mask, guiding_latent, latent_idx, strength):
def replace_latent_frames(self, latent_image, noise_mask, guiding_latent, latent_idx, strength):
cond_length = guiding_latent.shape[2]
assert latent_image.shape[2] >= latent_idx + cond_length, "Conditioning frames exceed the length of the latent sequence."
@@ -223,21 +195,20 @@ class LTXVAddGuide(io.ComfyNode):
return latent_image, noise_mask
@classmethod
def execute(cls, positive, negative, vae, latent, image, frame_idx, strength) -> io.NodeOutput:
def generate(self, positive, negative, vae, latent, image, frame_idx, strength):
scale_factors = vae.downscale_index_formula
latent_image = latent["samples"]
noise_mask = get_noise_mask(latent)
_, _, latent_length, latent_height, latent_width = latent_image.shape
image, t = cls.encode(vae, latent_width, latent_height, image, scale_factors)
image, t = self.encode(vae, latent_width, latent_height, image, scale_factors)
frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors)
frame_idx, latent_idx = self.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors)
assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence."
num_prefix_frames = min(cls.NUM_PREFIX_FRAMES, t.shape[2])
num_prefix_frames = min(self._num_prefix_frames, t.shape[2])
positive, negative, latent_image, noise_mask = cls.append_keyframe(
positive, negative, latent_image, noise_mask = self.append_keyframe(
positive,
negative,
frame_idx,
@@ -252,9 +223,9 @@ class LTXVAddGuide(io.ComfyNode):
t = t[:, :, num_prefix_frames:]
if t.shape[2] == 0:
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},)
latent_image, noise_mask = cls.replace_latent_frames(
latent_image, noise_mask = self.replace_latent_frames(
latent_image,
noise_mask,
t,
@@ -262,35 +233,34 @@ class LTXVAddGuide(io.ComfyNode):
strength,
)
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},)
class LTXVCropGuides(io.ComfyNode):
class LTXVCropGuides:
@classmethod
def define_schema(cls):
return io.Schema(
node_id="LTXVCropGuides",
category="conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Latent.Input("latent"),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
],
)
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"latent": ("LATENT",),
}
}
@classmethod
def execute(cls, positive, negative, latent) -> io.NodeOutput:
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
CATEGORY = "conditioning/video_models"
FUNCTION = "crop"
def __init__(self):
self._patchifier = SymmetricPatchifier(1)
def crop(self, positive, negative, latent):
latent_image = latent["samples"].clone()
noise_mask = get_noise_mask(latent)
_, num_keyframes = get_keyframe_idxs(positive)
if num_keyframes == 0:
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask},)
return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},)
latent_image = latent_image[:, :, :-num_keyframes]
noise_mask = noise_mask[:, :, :-num_keyframes]
@@ -298,52 +268,44 @@ class LTXVCropGuides(io.ComfyNode):
positive = node_helpers.conditioning_set_values(positive, {"keyframe_idxs": None})
negative = node_helpers.conditioning_set_values(negative, {"keyframe_idxs": None})
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},)
class LTXVConditioning(io.ComfyNode):
class LTXVConditioning:
@classmethod
def define_schema(cls):
return io.Schema(
node_id="LTXVConditioning",
category="conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Float.Input("frame_rate", default=25.0, min=0.0, max=1000.0, step=0.01),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
],
)
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"frame_rate": ("FLOAT", {"default": 25.0, "min": 0.0, "max": 1000.0, "step": 0.01}),
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
RETURN_NAMES = ("positive", "negative")
FUNCTION = "append"
@classmethod
def execute(cls, positive, negative, frame_rate) -> io.NodeOutput:
CATEGORY = "conditioning/video_models"
def append(self, positive, negative, frame_rate):
positive = node_helpers.conditioning_set_values(positive, {"frame_rate": frame_rate})
negative = node_helpers.conditioning_set_values(negative, {"frame_rate": frame_rate})
return io.NodeOutput(positive, negative)
return (positive, negative)
class ModelSamplingLTXV(io.ComfyNode):
class ModelSamplingLTXV:
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ModelSamplingLTXV",
category="advanced/model",
inputs=[
io.Model.Input("model"),
io.Float.Input("max_shift", default=2.05, min=0.0, max=100.0, step=0.01),
io.Float.Input("base_shift", default=0.95, min=0.0, max=100.0, step=0.01),
io.Latent.Input("latent", optional=True),
],
outputs=[
io.Model.Output(),
],
)
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"max_shift": ("FLOAT", {"default": 2.05, "min": 0.0, "max": 100.0, "step":0.01}),
"base_shift": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 100.0, "step":0.01}),
},
"optional": {"latent": ("LATENT",), }
}
@classmethod
def execute(cls, model, max_shift, base_shift, latent=None) -> io.NodeOutput:
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "advanced/model"
def patch(self, model, max_shift, base_shift, latent=None):
m = model.clone()
if latent is None:
@@ -367,41 +329,37 @@ class ModelSamplingLTXV(io.ComfyNode):
model_sampling.set_parameters(shift=shift)
m.add_object_patch("model_sampling", model_sampling)
return io.NodeOutput(m)
return (m, )
class LTXVScheduler(io.ComfyNode):
class LTXVScheduler:
@classmethod
def define_schema(cls):
return io.Schema(
node_id="LTXVScheduler",
category="sampling/custom_sampling/schedulers",
inputs=[
io.Int.Input("steps", default=20, min=1, max=10000),
io.Float.Input("max_shift", default=2.05, min=0.0, max=100.0, step=0.01),
io.Float.Input("base_shift", default=0.95, min=0.0, max=100.0, step=0.01),
io.Boolean.Input(
id="stretch",
default=True,
tooltip="Stretch the sigmas to be in the range [terminal, 1].",
),
io.Float.Input(
id="terminal",
default=0.1,
min=0.0,
max=0.99,
step=0.01,
tooltip="The terminal value of the sigmas after stretching.",
),
io.Latent.Input("latent", optional=True),
],
outputs=[
io.Sigmas.Output(),
],
)
def INPUT_TYPES(s):
return {"required":
{"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
"max_shift": ("FLOAT", {"default": 2.05, "min": 0.0, "max": 100.0, "step":0.01}),
"base_shift": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 100.0, "step":0.01}),
"stretch": ("BOOLEAN", {
"default": True,
"tooltip": "Stretch the sigmas to be in the range [terminal, 1]."
}),
"terminal": (
"FLOAT",
{
"default": 0.1, "min": 0.0, "max": 0.99, "step": 0.01,
"tooltip": "The terminal value of the sigmas after stretching."
},
),
},
"optional": {"latent": ("LATENT",), }
}
@classmethod
def execute(cls, steps, max_shift, base_shift, stretch, terminal, latent=None) -> io.NodeOutput:
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "sampling/custom_sampling/schedulers"
FUNCTION = "get_sigmas"
def get_sigmas(self, steps, max_shift, base_shift, stretch, terminal, latent=None):
if latent is None:
tokens = 4096
else:
@@ -431,7 +389,7 @@ class LTXVScheduler(io.ComfyNode):
stretched = 1.0 - (one_minus_z / scale_factor)
sigmas[non_zero_mask] = stretched
return io.NodeOutput(sigmas)
return (sigmas,)
def encode_single_frame(output_file, image_array: np.ndarray, crf):
container = av.open(output_file, "w", format="mp4")
@@ -465,54 +423,52 @@ def preprocess(image: torch.Tensor, crf=29):
return image
image_array = (image[:(image.shape[0] // 2) * 2, :(image.shape[1] // 2) * 2] * 255.0).byte().cpu().numpy()
with BytesIO() as output_file:
with io.BytesIO() as output_file:
encode_single_frame(output_file, image_array, crf)
video_bytes = output_file.getvalue()
with BytesIO(video_bytes) as video_file:
with io.BytesIO(video_bytes) as video_file:
image_array = decode_single_frame(video_file)
tensor = torch.tensor(image_array, dtype=image.dtype, device=image.device) / 255.0
return tensor
class LTXVPreprocess(io.ComfyNode):
class LTXVPreprocess:
@classmethod
def define_schema(cls):
return io.Schema(
node_id="LTXVPreprocess",
category="image",
inputs=[
io.Image.Input("image"),
io.Int.Input(
id="img_compression", default=35, min=0, max=100, tooltip="Amount of compression to apply on image."
def INPUT_TYPES(s):
return {
"required": {
"image": ("IMAGE",),
"img_compression": (
"INT",
{
"default": 35,
"min": 0,
"max": 100,
"tooltip": "Amount of compression to apply on image.",
},
),
],
outputs=[
io.Image.Output(display_name="output_image"),
],
)
}
}
@classmethod
def execute(cls, image, img_compression) -> io.NodeOutput:
FUNCTION = "preprocess"
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("output_image",)
CATEGORY = "image"
def preprocess(self, image, img_compression):
output_images = []
for i in range(image.shape[0]):
output_images.append(preprocess(image[i], img_compression))
return io.NodeOutput(torch.stack(output_images))
return (torch.stack(output_images),)
class LtxvExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
EmptyLTXVLatentVideo,
LTXVImgToVideo,
ModelSamplingLTXV,
LTXVConditioning,
LTXVScheduler,
LTXVAddGuide,
LTXVPreprocess,
LTXVCropGuides,
]
async def comfy_entrypoint() -> LtxvExtension:
return LtxvExtension()
NODE_CLASS_MAPPINGS = {
"EmptyLTXVLatentVideo": EmptyLTXVLatentVideo,
"LTXVImgToVideo": LTXVImgToVideo,
"ModelSamplingLTXV": ModelSamplingLTXV,
"LTXVConditioning": LTXVConditioning,
"LTXVScheduler": LTXVScheduler,
"LTXVAddGuide": LTXVAddGuide,
"LTXVPreprocess": LTXVPreprocess,
"LTXVCropGuides": LTXVCropGuides,
}

View File

@@ -1,27 +1,20 @@
from typing_extensions import override
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict
import torch
from comfy_api.latest import ComfyExtension, io
class RenormCFG(io.ComfyNode):
class RenormCFG:
@classmethod
def define_schema(cls):
return io.Schema(
node_id="RenormCFG",
category="advanced/model",
inputs=[
io.Model.Input("model"),
io.Float.Input("cfg_trunc", default=100, min=0.0, max=100.0, step=0.01),
io.Float.Input("renorm_cfg", default=1.0, min=0.0, max=100.0, step=0.01),
],
outputs=[
io.Model.Output(),
],
)
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"cfg_trunc": ("FLOAT", {"default": 100, "min": 0.0, "max": 100.0, "step": 0.01}),
"renorm_cfg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
@classmethod
def execute(cls, model, cfg_trunc, renorm_cfg) -> io.NodeOutput:
CATEGORY = "advanced/model"
def patch(self, model, cfg_trunc, renorm_cfg):
def renorm_cfg_func(args):
cond_denoised = args["cond_denoised"]
uncond_denoised = args["uncond_denoised"]
@@ -60,10 +53,10 @@ class RenormCFG(io.ComfyNode):
m = model.clone()
m.set_model_sampler_cfg_function(renorm_cfg_func)
return io.NodeOutput(m)
return (m, )
class CLIPTextEncodeLumina2(io.ComfyNode):
class CLIPTextEncodeLumina2(ComfyNodeABC):
SYSTEM_PROMPT = {
"superior": "You are an assistant designed to generate superior images with the superior "\
"degree of image-text alignment based on textual prompts or user prompts.",
@@ -76,52 +69,36 @@ class CLIPTextEncodeLumina2(io.ComfyNode):
"Alignment: You are an assistant designed to generate high-quality images with the highest "\
"degree of image-text alignment based on textual prompts."
@classmethod
def define_schema(cls):
return io.Schema(
node_id="CLIPTextEncodeLumina2",
display_name="CLIP Text Encode for Lumina2",
category="conditioning",
description="Encodes a system prompt and a user prompt using a CLIP model into an embedding "
"that can be used to guide the diffusion model towards generating specific images.",
inputs=[
io.Combo.Input(
"system_prompt",
options=list(cls.SYSTEM_PROMPT.keys()),
tooltip=cls.SYSTEM_PROMPT_TIP,
),
io.String.Input(
"user_prompt",
multiline=True,
dynamic_prompts=True,
tooltip="The text to be encoded.",
),
io.Clip.Input("clip", tooltip="The CLIP model used for encoding the text."),
],
outputs=[
io.Conditioning.Output(
tooltip="A conditioning containing the embedded text used to guide the diffusion model.",
),
],
)
def INPUT_TYPES(s) -> InputTypeDict:
return {
"required": {
"system_prompt": (list(CLIPTextEncodeLumina2.SYSTEM_PROMPT.keys()), {"tooltip": CLIPTextEncodeLumina2.SYSTEM_PROMPT_TIP}),
"user_prompt": (IO.STRING, {"multiline": True, "dynamicPrompts": True, "tooltip": "The text to be encoded."}),
"clip": (IO.CLIP, {"tooltip": "The CLIP model used for encoding the text."})
}
}
RETURN_TYPES = (IO.CONDITIONING,)
OUTPUT_TOOLTIPS = ("A conditioning containing the embedded text used to guide the diffusion model.",)
FUNCTION = "encode"
@classmethod
def execute(cls, clip, user_prompt, system_prompt) -> io.NodeOutput:
CATEGORY = "conditioning"
DESCRIPTION = "Encodes a system prompt and a user prompt using a CLIP model into an embedding that can be used to guide the diffusion model towards generating specific images."
def encode(self, clip, user_prompt, system_prompt):
if clip is None:
raise RuntimeError("ERROR: clip input is invalid: None\n\nIf the clip is from a checkpoint loader node your checkpoint does not contain a valid clip or text encoder model.")
system_prompt = cls.SYSTEM_PROMPT[system_prompt]
system_prompt = CLIPTextEncodeLumina2.SYSTEM_PROMPT[system_prompt]
prompt = f'{system_prompt} <Prompt Start> {user_prompt}'
tokens = clip.tokenize(prompt)
return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))
return (clip.encode_from_tokens_scheduled(tokens), )
class Lumina2Extension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
CLIPTextEncodeLumina2,
RenormCFG,
]
NODE_CLASS_MAPPINGS = {
"CLIPTextEncodeLumina2": CLIPTextEncodeLumina2,
"RenormCFG": RenormCFG
}
async def comfy_entrypoint() -> Lumina2Extension:
return Lumina2Extension()
NODE_DISPLAY_NAME_MAPPINGS = {
"CLIPTextEncodeLumina2": "CLIP Text Encode for Lumina2",
}

View File

@@ -1,29 +1,17 @@
from typing_extensions import override
import torch
import torch.nn.functional as F
from comfy_api.latest import ComfyExtension, io
class Mahiro(io.ComfyNode):
class Mahiro:
@classmethod
def define_schema(cls):
return io.Schema(
node_id="Mahiro",
display_name="Mahiro is so cute that she deserves a better guidance function!! (。・ω・。)",
category="_for_testing",
description="Modify the guidance to scale more on the 'direction' of the positive prompt rather than the difference between the negative prompt.",
inputs=[
io.Model.Input("model"),
],
outputs=[
io.Model.Output(display_name="patched_model"),
],
is_experimental=True,
)
@classmethod
def execute(cls, model) -> io.NodeOutput:
def INPUT_TYPES(s):
return {"required": {"model": ("MODEL",),
}}
RETURN_TYPES = ("MODEL",)
RETURN_NAMES = ("patched_model",)
FUNCTION = "patch"
CATEGORY = "_for_testing"
DESCRIPTION = "Modify the guidance to scale more on the 'direction' of the positive prompt rather than the difference between the negative prompt."
def patch(self, model):
m = model.clone()
def mahiro_normd(args):
scale: float = args['cond_scale']
@@ -42,16 +30,12 @@ class Mahiro(io.ComfyNode):
wm = (simsc*cfg + (4-simsc)*leap) / 4
return wm
m.set_model_sampler_post_cfg_function(mahiro_normd)
return io.NodeOutput(m)
return (m, )
NODE_CLASS_MAPPINGS = {
"Mahiro": Mahiro
}
class MahiroExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
Mahiro,
]
async def comfy_entrypoint() -> MahiroExtension:
return MahiroExtension()
NODE_DISPLAY_NAME_MAPPINGS = {
"Mahiro": "Mahiro is so cute that she deserves a better guidance function!! (。・ω・。)",
}

View File

@@ -1,40 +1,23 @@
from typing_extensions import override
import nodes
import torch
import comfy.model_management
import nodes
from comfy_api.latest import ComfyExtension, io
class EmptyMochiLatentVideo(io.ComfyNode):
class EmptyMochiLatentVideo:
@classmethod
def define_schema(cls):
return io.Schema(
node_id="EmptyMochiLatentVideo",
category="latent/video",
inputs=[
io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("length", default=25, min=7, max=nodes.MAX_RESOLUTION, step=6),
io.Int.Input("batch_size", default=1, min=1, max=4096),
],
outputs=[
io.Latent.Output(),
],
)
def INPUT_TYPES(s):
return {"required": { "width": ("INT", {"default": 848, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 25, "min": 7, "max": nodes.MAX_RESOLUTION, "step": 6}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "generate"
@classmethod
def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput:
CATEGORY = "latent/video"
def generate(self, width, height, length, batch_size=1):
latent = torch.zeros([batch_size, 12, ((length - 1) // 6) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
return io.NodeOutput({"samples": latent})
return ({"samples":latent}, )
class MochiExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
EmptyMochiLatentVideo,
]
async def comfy_entrypoint() -> MochiExtension:
return MochiExtension()
NODE_CLASS_MAPPINGS = {
"EmptyMochiLatentVideo": EmptyMochiLatentVideo,
}

View File

@@ -1,34 +1,24 @@
import torch
import comfy.model_management
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
from kornia.morphology import dilation, erosion, opening, closing, gradient, top_hat, bottom_hat
import kornia.color
class Morphology(io.ComfyNode):
class Morphology:
@classmethod
def define_schema(cls):
return io.Schema(
node_id="Morphology",
display_name="ImageMorphology",
category="image/postprocessing",
inputs=[
io.Image.Input("image"),
io.Combo.Input(
"operation",
options=["erode", "dilate", "open", "close", "gradient", "bottom_hat", "top_hat"],
),
io.Int.Input("kernel_size", default=3, min=3, max=999, step=1),
],
outputs=[
io.Image.Output(),
],
)
def INPUT_TYPES(s):
return {"required": {"image": ("IMAGE",),
"operation": (["erode", "dilate", "open", "close", "gradient", "bottom_hat", "top_hat"],),
"kernel_size": ("INT", {"default": 3, "min": 3, "max": 999, "step": 1}),
}}
@classmethod
def execute(cls, image, operation, kernel_size) -> io.NodeOutput:
RETURN_TYPES = ("IMAGE",)
FUNCTION = "process"
CATEGORY = "image/postprocessing"
def process(self, image, operation, kernel_size):
device = comfy.model_management.get_torch_device()
kernel = torch.ones(kernel_size, kernel_size, device=device)
image_k = image.to(device).movedim(-1, 1)
@@ -49,63 +39,49 @@ class Morphology(io.ComfyNode):
else:
raise ValueError(f"Invalid operation {operation} for morphology. Must be one of 'erode', 'dilate', 'open', 'close', 'gradient', 'tophat', 'bottomhat'")
img_out = output.to(comfy.model_management.intermediate_device()).movedim(1, -1)
return io.NodeOutput(img_out)
return (img_out,)
class ImageRGBToYUV(io.ComfyNode):
class ImageRGBToYUV:
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ImageRGBToYUV",
category="image/batch",
inputs=[
io.Image.Input("image"),
],
outputs=[
io.Image.Output(display_name="Y"),
io.Image.Output(display_name="U"),
io.Image.Output(display_name="V"),
],
)
def INPUT_TYPES(s):
return {"required": { "image": ("IMAGE",),
}}
@classmethod
def execute(cls, image) -> io.NodeOutput:
RETURN_TYPES = ("IMAGE", "IMAGE", "IMAGE")
RETURN_NAMES = ("Y", "U", "V")
FUNCTION = "execute"
CATEGORY = "image/batch"
def execute(self, image):
out = kornia.color.rgb_to_ycbcr(image.movedim(-1, 1)).movedim(1, -1)
return io.NodeOutput(out[..., 0:1].expand_as(image), out[..., 1:2].expand_as(image), out[..., 2:3].expand_as(image))
return (out[..., 0:1].expand_as(image), out[..., 1:2].expand_as(image), out[..., 2:3].expand_as(image))
class ImageYUVToRGB(io.ComfyNode):
class ImageYUVToRGB:
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ImageYUVToRGB",
category="image/batch",
inputs=[
io.Image.Input("Y"),
io.Image.Input("U"),
io.Image.Input("V"),
],
outputs=[
io.Image.Output(),
],
)
def INPUT_TYPES(s):
return {"required": {"Y": ("IMAGE",),
"U": ("IMAGE",),
"V": ("IMAGE",),
}}
@classmethod
def execute(cls, Y, U, V) -> io.NodeOutput:
RETURN_TYPES = ("IMAGE",)
FUNCTION = "execute"
CATEGORY = "image/batch"
def execute(self, Y, U, V):
image = torch.cat([torch.mean(Y, dim=-1, keepdim=True), torch.mean(U, dim=-1, keepdim=True), torch.mean(V, dim=-1, keepdim=True)], dim=-1)
out = kornia.color.ycbcr_to_rgb(image.movedim(-1, 1)).movedim(1, -1)
return io.NodeOutput(out)
return (out,)
NODE_CLASS_MAPPINGS = {
"Morphology": Morphology,
"ImageRGBToYUV": ImageRGBToYUV,
"ImageYUVToRGB": ImageYUVToRGB,
}
class MorphologyExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
Morphology,
ImageRGBToYUV,
ImageYUVToRGB,
]
async def comfy_entrypoint() -> MorphologyExtension:
return MorphologyExtension()
NODE_DISPLAY_NAME_MAPPINGS = {
"Morphology": "ImageMorphology",
}

View File

@@ -1,12 +1,9 @@
# from https://github.com/bebebe666/OptimalSteps
import numpy as np
import torch
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
def loglinear_interp(t_steps, num_steps):
"""
Performs log-linear interpolation of a given array of decreasing numbers.
@@ -26,28 +23,25 @@ NOISE_LEVELS = {"FLUX": [0.9968, 0.9886, 0.9819, 0.975, 0.966, 0.9471, 0.9158, 0
"Chroma": [0.992, 0.99, 0.988, 0.985, 0.982, 0.978, 0.973, 0.968, 0.961, 0.953, 0.943, 0.931, 0.917, 0.9, 0.881, 0.858, 0.832, 0.802, 0.769, 0.731, 0.69, 0.646, 0.599, 0.55, 0.501, 0.451, 0.402, 0.355, 0.311, 0.27, 0.232, 0.199, 0.169, 0.143, 0.12, 0.101, 0.084, 0.07, 0.058, 0.048, 0.001],
}
class OptimalStepsScheduler(io.ComfyNode):
class OptimalStepsScheduler:
@classmethod
def define_schema(cls):
return io.Schema(
node_id="OptimalStepsScheduler",
category="sampling/custom_sampling/schedulers",
inputs=[
io.Combo.Input("model_type", options=["FLUX", "Wan", "Chroma"]),
io.Int.Input("steps", default=20, min=3, max=1000),
io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01),
],
outputs=[
io.Sigmas.Output(),
],
)
def INPUT_TYPES(s):
return {"required":
{"model_type": (["FLUX", "Wan", "Chroma"], ),
"steps": ("INT", {"default": 20, "min": 3, "max": 1000}),
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "sampling/custom_sampling/schedulers"
@classmethod
def execute(cls, model_type, steps, denoise) ->io.NodeOutput:
FUNCTION = "get_sigmas"
def get_sigmas(self, model_type, steps, denoise):
total_steps = steps
if denoise < 1.0:
if denoise <= 0.0:
return io.NodeOutput(torch.FloatTensor([]))
return (torch.FloatTensor([]),)
total_steps = round(steps * denoise)
sigmas = NOISE_LEVELS[model_type][:]
@@ -56,16 +50,8 @@ class OptimalStepsScheduler(io.ComfyNode):
sigmas = sigmas[-(total_steps + 1):]
sigmas[-1] = 0
return io.NodeOutput(torch.FloatTensor(sigmas))
return (torch.FloatTensor(sigmas), )
class OptimalStepsExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
OptimalStepsScheduler,
]
async def comfy_entrypoint() -> OptimalStepsExtension:
return OptimalStepsExtension()
NODE_CLASS_MAPPINGS = {
"OptimalStepsScheduler": OptimalStepsScheduler,
}

View File

@@ -3,30 +3,25 @@
#My modified one here is more basic but has less chances of breaking with ComfyUI updates.
from typing_extensions import override
import comfy.model_patcher
import comfy.samplers
from comfy_api.latest import ComfyExtension, io
class PerturbedAttentionGuidance(io.ComfyNode):
class PerturbedAttentionGuidance:
@classmethod
def define_schema(cls):
return io.Schema(
node_id="PerturbedAttentionGuidance",
category="model_patches/unet",
inputs=[
io.Model.Input("model"),
io.Float.Input("scale", default=3.0, min=0.0, max=100.0, step=0.01, round=0.01),
],
outputs=[
io.Model.Output(),
],
)
def INPUT_TYPES(s):
return {
"required": {
"model": ("MODEL",),
"scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": 0.01}),
}
}
@classmethod
def execute(cls, model, scale) -> io.NodeOutput:
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "model_patches/unet"
def patch(self, model, scale):
unet_block = "middle"
unet_block_id = 0
m = model.clone()
@@ -54,16 +49,8 @@ class PerturbedAttentionGuidance(io.ComfyNode):
m.set_model_sampler_post_cfg_function(post_cfg_function)
return io.NodeOutput(m)
return (m,)
class PAGExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
PerturbedAttentionGuidance,
]
async def comfy_entrypoint() -> PAGExtension:
return PAGExtension()
NODE_CLASS_MAPPINGS = {
"PerturbedAttentionGuidance": PerturbedAttentionGuidance,
}

View File

@@ -5,9 +5,6 @@ import comfy.samplers
import comfy.utils
import node_helpers
import math
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
def perp_neg(x, noise_pred_pos, noise_pred_neg, noise_pred_nocond, neg_scale, cond_scale):
pos = noise_pred_pos - noise_pred_nocond
@@ -19,27 +16,20 @@ def perp_neg(x, noise_pred_pos, noise_pred_neg, noise_pred_nocond, neg_scale, co
return cfg_result
#TODO: This node should be removed, it has been replaced with PerpNegGuider
class PerpNeg(io.ComfyNode):
class PerpNeg:
@classmethod
def define_schema(cls):
return io.Schema(
node_id="PerpNeg",
display_name="Perp-Neg (DEPRECATED by PerpNegGuider)",
category="_for_testing",
inputs=[
io.Model.Input("model"),
io.Conditioning.Input("empty_conditioning"),
io.Float.Input("neg_scale", default=1.0, min=0.0, max=100.0, step=0.01),
],
outputs=[
io.Model.Output(),
],
is_experimental=True,
is_deprecated=True,
)
def INPUT_TYPES(s):
return {"required": {"model": ("MODEL", ),
"empty_conditioning": ("CONDITIONING", ),
"neg_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
@classmethod
def execute(cls, model, empty_conditioning, neg_scale) -> io.NodeOutput:
CATEGORY = "_for_testing"
DEPRECATED = True
def patch(self, model, empty_conditioning, neg_scale):
m = model.clone()
nocond = comfy.sampler_helpers.convert_cond(empty_conditioning)
@@ -60,7 +50,7 @@ class PerpNeg(io.ComfyNode):
m.set_model_sampler_cfg_function(cfg_function)
return io.NodeOutput(m)
return (m, )
class Guider_PerpNeg(comfy.samplers.CFGGuider):
@@ -122,42 +112,35 @@ class Guider_PerpNeg(comfy.samplers.CFGGuider):
return cfg_result
class PerpNegGuider(io.ComfyNode):
class PerpNegGuider:
@classmethod
def define_schema(cls):
return io.Schema(
node_id="PerpNegGuider",
category="_for_testing",
inputs=[
io.Model.Input("model"),
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Conditioning.Input("empty_conditioning"),
io.Float.Input("cfg", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01),
io.Float.Input("neg_scale", default=1.0, min=0.0, max=100.0, step=0.01),
],
outputs=[
io.Guider.Output(),
],
is_experimental=True,
)
def INPUT_TYPES(s):
return {"required":
{"model": ("MODEL",),
"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"empty_conditioning": ("CONDITIONING", ),
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
"neg_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}),
}
}
@classmethod
def execute(cls, model, positive, negative, empty_conditioning, cfg, neg_scale) -> io.NodeOutput:
RETURN_TYPES = ("GUIDER",)
FUNCTION = "get_guider"
CATEGORY = "_for_testing"
def get_guider(self, model, positive, negative, empty_conditioning, cfg, neg_scale):
guider = Guider_PerpNeg(model)
guider.set_conds(positive, negative, empty_conditioning)
guider.set_cfg(cfg, neg_scale)
return io.NodeOutput(guider)
return (guider,)
NODE_CLASS_MAPPINGS = {
"PerpNeg": PerpNeg,
"PerpNegGuider": PerpNegGuider,
}
class PerpNegExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
PerpNeg,
PerpNegGuider,
]
async def comfy_entrypoint() -> PerpNegExtension:
return PerpNegExtension()
NODE_DISPLAY_NAME_MAPPINGS = {
"PerpNeg": "Perp-Neg (DEPRECATED by PerpNegGuider)",
}

View File

@@ -4,8 +4,6 @@ import folder_paths
import comfy.clip_model
import comfy.clip_vision
import comfy.ops
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
# code for model from: https://github.com/TencentARC/PhotoMaker/blob/main/photomaker/model.py under Apache License Version 2.0
VISION_CONFIG_DICT = {
@@ -118,52 +116,41 @@ class PhotoMakerIDEncoder(comfy.clip_model.CLIPVisionModelProjection):
return updated_prompt_embeds
class PhotoMakerLoader(io.ComfyNode):
class PhotoMakerLoader:
@classmethod
def define_schema(cls):
return io.Schema(
node_id="PhotoMakerLoader",
category="_for_testing/photomaker",
inputs=[
io.Combo.Input("photomaker_model_name", options=folder_paths.get_filename_list("photomaker")),
],
outputs=[
io.Photomaker.Output(),
],
is_experimental=True,
)
def INPUT_TYPES(s):
return {"required": { "photomaker_model_name": (folder_paths.get_filename_list("photomaker"), )}}
@classmethod
def execute(cls, photomaker_model_name):
RETURN_TYPES = ("PHOTOMAKER",)
FUNCTION = "load_photomaker_model"
CATEGORY = "_for_testing/photomaker"
def load_photomaker_model(self, photomaker_model_name):
photomaker_model_path = folder_paths.get_full_path_or_raise("photomaker", photomaker_model_name)
photomaker_model = PhotoMakerIDEncoder()
data = comfy.utils.load_torch_file(photomaker_model_path, safe_load=True)
if "id_encoder" in data:
data = data["id_encoder"]
photomaker_model.load_state_dict(data)
return io.NodeOutput(photomaker_model)
return (photomaker_model,)
class PhotoMakerEncode(io.ComfyNode):
class PhotoMakerEncode:
@classmethod
def define_schema(cls):
return io.Schema(
node_id="PhotoMakerEncode",
category="_for_testing/photomaker",
inputs=[
io.Photomaker.Input("photomaker"),
io.Image.Input("image"),
io.Clip.Input("clip"),
io.String.Input("text", multiline=True, dynamic_prompts=True, default="photograph of photomaker"),
],
outputs=[
io.Conditioning.Output(),
],
is_experimental=True,
)
def INPUT_TYPES(s):
return {"required": { "photomaker": ("PHOTOMAKER",),
"image": ("IMAGE",),
"clip": ("CLIP", ),
"text": ("STRING", {"multiline": True, "dynamicPrompts": True, "default": "photograph of photomaker"}),
}}
@classmethod
def execute(cls, photomaker, image, clip, text):
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "apply_photomaker"
CATEGORY = "_for_testing/photomaker"
def apply_photomaker(self, photomaker, image, clip, text):
special_token = "photomaker"
pixel_values = comfy.clip_vision.clip_preprocess(image.to(photomaker.load_device)).float()
try:
@@ -191,16 +178,11 @@ class PhotoMakerEncode(io.ComfyNode):
else:
out = cond
return io.NodeOutput([[out, {"pooled_output": pooled}]])
return ([[out, {"pooled_output": pooled}]], )
class PhotomakerExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
PhotoMakerLoader,
PhotoMakerEncode,
]
NODE_CLASS_MAPPINGS = {
"PhotoMakerLoader": PhotoMakerLoader,
"PhotoMakerEncode": PhotoMakerEncode,
}
async def comfy_entrypoint() -> PhotomakerExtension:
return PhotomakerExtension()

View File

@@ -1,38 +1,24 @@
from typing_extensions import override
import nodes
from comfy_api.latest import ComfyExtension, io
from nodes import MAX_RESOLUTION
class CLIPTextEncodePixArtAlpha(io.ComfyNode):
class CLIPTextEncodePixArtAlpha:
@classmethod
def define_schema(cls):
return io.Schema(
node_id="CLIPTextEncodePixArtAlpha",
category="advanced/conditioning",
description="Encodes text and sets the resolution conditioning for PixArt Alpha. Does not apply to PixArt Sigma.",
inputs=[
io.Int.Input("width", default=1024, min=0, max=nodes.MAX_RESOLUTION),
io.Int.Input("height", default=1024, min=0, max=nodes.MAX_RESOLUTION),
# "aspect_ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
io.String.Input("text", multiline=True, dynamic_prompts=True),
io.Clip.Input("clip"),
],
outputs=[
io.Conditioning.Output(),
],
)
def INPUT_TYPES(s):
return {"required": {
"width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
"height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
# "aspect_ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"text": ("STRING", {"multiline": True, "dynamicPrompts": True}), "clip": ("CLIP", ),
}}
@classmethod
def execute(cls, clip, width, height, text):
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "encode"
CATEGORY = "advanced/conditioning"
DESCRIPTION = "Encodes text and sets the resolution conditioning for PixArt Alpha. Does not apply to PixArt Sigma."
def encode(self, clip, width, height, text):
tokens = clip.tokenize(text)
return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens, add_dict={"width": width, "height": height}))
return (clip.encode_from_tokens_scheduled(tokens, add_dict={"width": width, "height": height}),)
class PixArtExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
CLIPTextEncodePixArtAlpha,
]
async def comfy_entrypoint() -> PixArtExtension:
return PixArtExtension()
NODE_CLASS_MAPPINGS = {
"CLIPTextEncodePixArtAlpha": CLIPTextEncodePixArtAlpha,
}

View File

@@ -1,29 +1,24 @@
import node_helpers
import comfy.utils
import math
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
class TextEncodeQwenImageEdit(io.ComfyNode):
class TextEncodeQwenImageEdit:
@classmethod
def define_schema(cls):
return io.Schema(
node_id="TextEncodeQwenImageEdit",
category="advanced/conditioning",
inputs=[
io.Clip.Input("clip"),
io.String.Input("prompt", multiline=True, dynamic_prompts=True),
io.Vae.Input("vae", optional=True),
io.Image.Input("image", optional=True),
],
outputs=[
io.Conditioning.Output(),
],
)
def INPUT_TYPES(s):
return {"required": {
"clip": ("CLIP", ),
"prompt": ("STRING", {"multiline": True, "dynamicPrompts": True}),
},
"optional": {"vae": ("VAE", ),
"image": ("IMAGE", ),}}
@classmethod
def execute(cls, clip, prompt, vae=None, image=None) -> io.NodeOutput:
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "encode"
CATEGORY = "advanced/conditioning"
def encode(self, clip, prompt, vae=None, image=None):
ref_latent = None
if image is None:
images = []
@@ -45,30 +40,28 @@ class TextEncodeQwenImageEdit(io.ComfyNode):
conditioning = clip.encode_from_tokens_scheduled(tokens)
if ref_latent is not None:
conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": [ref_latent]}, append=True)
return io.NodeOutput(conditioning)
return (conditioning, )
class TextEncodeQwenImageEditPlus(io.ComfyNode):
class TextEncodeQwenImageEditPlus:
@classmethod
def define_schema(cls):
return io.Schema(
node_id="TextEncodeQwenImageEditPlus",
category="advanced/conditioning",
inputs=[
io.Clip.Input("clip"),
io.String.Input("prompt", multiline=True, dynamic_prompts=True),
io.Vae.Input("vae", optional=True),
io.Image.Input("image1", optional=True),
io.Image.Input("image2", optional=True),
io.Image.Input("image3", optional=True),
],
outputs=[
io.Conditioning.Output(),
],
)
def INPUT_TYPES(s):
return {"required": {
"clip": ("CLIP", ),
"prompt": ("STRING", {"multiline": True, "dynamicPrompts": True}),
},
"optional": {"vae": ("VAE", ),
"image1": ("IMAGE", ),
"image2": ("IMAGE", ),
"image3": ("IMAGE", ),
}}
@classmethod
def execute(cls, clip, prompt, vae=None, image1=None, image2=None, image3=None) -> io.NodeOutput:
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "encode"
CATEGORY = "advanced/conditioning"
def encode(self, clip, prompt, vae=None, image1=None, image2=None, image3=None):
ref_latents = []
images = [image1, image2, image3]
images_vl = []
@@ -101,17 +94,10 @@ class TextEncodeQwenImageEditPlus(io.ComfyNode):
conditioning = clip.encode_from_tokens_scheduled(tokens)
if len(ref_latents) > 0:
conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": ref_latents}, append=True)
return io.NodeOutput(conditioning)
return (conditioning, )
class QwenExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
TextEncodeQwenImageEdit,
TextEncodeQwenImageEditPlus,
]
async def comfy_entrypoint() -> QwenExtension:
return QwenExtension()
NODE_CLASS_MAPPINGS = {
"TextEncodeQwenImageEdit": TextEncodeQwenImageEdit,
"TextEncodeQwenImageEditPlus": TextEncodeQwenImageEditPlus,
}

View File

@@ -1,8 +1,6 @@
import torch
import nodes
import comfy.utils
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
def camera_embeddings(elevation, azimuth):
elevation = torch.as_tensor([elevation])
@@ -22,31 +20,26 @@ def camera_embeddings(elevation, azimuth):
return embeddings
class StableZero123_Conditioning(io.ComfyNode):
class StableZero123_Conditioning:
@classmethod
def define_schema(cls):
return io.Schema(
node_id="StableZero123_Conditioning",
category="conditioning/3d_models",
inputs=[
io.ClipVision.Input("clip_vision"),
io.Image.Input("init_image"),
io.Vae.Input("vae"),
io.Int.Input("width", default=256, min=16, max=nodes.MAX_RESOLUTION, step=8),
io.Int.Input("height", default=256, min=16, max=nodes.MAX_RESOLUTION, step=8),
io.Int.Input("batch_size", default=1, min=1, max=4096),
io.Float.Input("elevation", default=0.0, min=-180.0, max=180.0, step=0.1, round=False),
io.Float.Input("azimuth", default=0.0, min=-180.0, max=180.0, step=0.1, round=False)
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent")
]
)
def INPUT_TYPES(s):
return {"required": { "clip_vision": ("CLIP_VISION",),
"init_image": ("IMAGE",),
"vae": ("VAE",),
"width": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
"elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}),
"azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}),
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
@classmethod
def execute(cls, clip_vision, init_image, vae, width, height, batch_size, elevation, azimuth) -> io.NodeOutput:
FUNCTION = "encode"
CATEGORY = "conditioning/3d_models"
def encode(self, clip_vision, init_image, vae, width, height, batch_size, elevation, azimuth):
output = clip_vision.encode_image(init_image)
pooled = output.image_embeds.unsqueeze(0)
pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
@@ -58,35 +51,30 @@ class StableZero123_Conditioning(io.ComfyNode):
positive = [[cond, {"concat_latent_image": t}]]
negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t)}]]
latent = torch.zeros([batch_size, 4, height // 8, width // 8])
return io.NodeOutput(positive, negative, {"samples":latent})
return (positive, negative, {"samples":latent})
class StableZero123_Conditioning_Batched(io.ComfyNode):
class StableZero123_Conditioning_Batched:
@classmethod
def define_schema(cls):
return io.Schema(
node_id="StableZero123_Conditioning_Batched",
category="conditioning/3d_models",
inputs=[
io.ClipVision.Input("clip_vision"),
io.Image.Input("init_image"),
io.Vae.Input("vae"),
io.Int.Input("width", default=256, min=16, max=nodes.MAX_RESOLUTION, step=8),
io.Int.Input("height", default=256, min=16, max=nodes.MAX_RESOLUTION, step=8),
io.Int.Input("batch_size", default=1, min=1, max=4096),
io.Float.Input("elevation", default=0.0, min=-180.0, max=180.0, step=0.1, round=False),
io.Float.Input("azimuth", default=0.0, min=-180.0, max=180.0, step=0.1, round=False),
io.Float.Input("elevation_batch_increment", default=0.0, min=-180.0, max=180.0, step=0.1, round=False),
io.Float.Input("azimuth_batch_increment", default=0.0, min=-180.0, max=180.0, step=0.1, round=False)
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent")
]
)
def INPUT_TYPES(s):
return {"required": { "clip_vision": ("CLIP_VISION",),
"init_image": ("IMAGE",),
"vae": ("VAE",),
"width": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
"elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}),
"azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}),
"elevation_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}),
"azimuth_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}),
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
@classmethod
def execute(cls, clip_vision, init_image, vae, width, height, batch_size, elevation, azimuth, elevation_batch_increment, azimuth_batch_increment) -> io.NodeOutput:
FUNCTION = "encode"
CATEGORY = "conditioning/3d_models"
def encode(self, clip_vision, init_image, vae, width, height, batch_size, elevation, azimuth, elevation_batch_increment, azimuth_batch_increment):
output = clip_vision.encode_image(init_image)
pooled = output.image_embeds.unsqueeze(0)
pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
@@ -105,32 +93,27 @@ class StableZero123_Conditioning_Batched(io.ComfyNode):
positive = [[cond, {"concat_latent_image": t}]]
negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t)}]]
latent = torch.zeros([batch_size, 4, height // 8, width // 8])
return io.NodeOutput(positive, negative, {"samples":latent, "batch_index": [0] * batch_size})
return (positive, negative, {"samples":latent, "batch_index": [0] * batch_size})
class SV3D_Conditioning(io.ComfyNode):
class SV3D_Conditioning:
@classmethod
def define_schema(cls):
return io.Schema(
node_id="SV3D_Conditioning",
category="conditioning/3d_models",
inputs=[
io.ClipVision.Input("clip_vision"),
io.Image.Input("init_image"),
io.Vae.Input("vae"),
io.Int.Input("width", default=576, min=16, max=nodes.MAX_RESOLUTION, step=8),
io.Int.Input("height", default=576, min=16, max=nodes.MAX_RESOLUTION, step=8),
io.Int.Input("video_frames", default=21, min=1, max=4096),
io.Float.Input("elevation", default=0.0, min=-90.0, max=90.0, step=0.1, round=False)
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent")
]
)
def INPUT_TYPES(s):
return {"required": { "clip_vision": ("CLIP_VISION",),
"init_image": ("IMAGE",),
"vae": ("VAE",),
"width": ("INT", {"default": 576, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 576, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
"video_frames": ("INT", {"default": 21, "min": 1, "max": 4096}),
"elevation": ("FLOAT", {"default": 0.0, "min": -90.0, "max": 90.0, "step": 0.1, "round": False}),
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
@classmethod
def execute(cls, clip_vision, init_image, vae, width, height, video_frames, elevation) -> io.NodeOutput:
FUNCTION = "encode"
CATEGORY = "conditioning/3d_models"
def encode(self, clip_vision, init_image, vae, width, height, video_frames, elevation):
output = clip_vision.encode_image(init_image)
pooled = output.image_embeds.unsqueeze(0)
pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
@@ -150,17 +133,11 @@ class SV3D_Conditioning(io.ComfyNode):
positive = [[pooled, {"concat_latent_image": t, "elevation": elevations, "azimuth": azimuths}]]
negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t), "elevation": elevations, "azimuth": azimuths}]]
latent = torch.zeros([video_frames, 4, height // 8, width // 8])
return io.NodeOutput(positive, negative, {"samples":latent})
return (positive, negative, {"samples":latent})
class Stable3DExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
StableZero123_Conditioning,
StableZero123_Conditioning_Batched,
SV3D_Conditioning,
]
async def comfy_entrypoint() -> Stable3DExtension:
return Stable3DExtension()
NODE_CLASS_MAPPINGS = {
"StableZero123_Conditioning": StableZero123_Conditioning,
"StableZero123_Conditioning_Batched": StableZero123_Conditioning_Batched,
"SV3D_Conditioning": SV3D_Conditioning,
}

View File

@@ -1,9 +1,7 @@
#Taken from: https://github.com/dbolya/tomesd
import torch
from typing import Tuple, Callable, Optional
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
from typing import Tuple, Callable
import math
def do_nothing(x: torch.Tensor, mode:str=None):
@@ -146,45 +144,33 @@ def get_functions(x, ratio, original_shape):
class TomePatchModel(io.ComfyNode):
class TomePatchModel:
@classmethod
def define_schema(cls):
return io.Schema(
node_id="TomePatchModel",
category="model_patches/unet",
inputs=[
io.Model.Input("model"),
io.Float.Input("ratio", default=0.3, min=0.0, max=1.0, step=0.01),
],
outputs=[io.Model.Output()],
)
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"ratio": ("FLOAT", {"default": 0.3, "min": 0.0, "max": 1.0, "step": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
@classmethod
def execute(cls, model, ratio) -> io.NodeOutput:
u: Optional[Callable] = None
CATEGORY = "model_patches/unet"
def patch(self, model, ratio):
self.u = None
def tomesd_m(q, k, v, extra_options):
nonlocal u
#NOTE: In the reference code get_functions takes x (input of the transformer block) as the argument instead of q
#however from my basic testing it seems that using q instead gives better results
m, u = get_functions(q, ratio, extra_options["original_shape"])
m, self.u = get_functions(q, ratio, extra_options["original_shape"])
return m(q), k, v
def tomesd_u(n, extra_options):
nonlocal u
return u(n)
return self.u(n)
m = model.clone()
m.set_model_attn1_patch(tomesd_m)
m.set_model_attn1_output_patch(tomesd_u)
return io.NodeOutput(m)
return (m, )
class TomePatchModelExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
TomePatchModel,
]
async def comfy_entrypoint() -> TomePatchModelExtension:
return TomePatchModelExtension()
NODE_CLASS_MAPPINGS = {
"TomePatchModel": TomePatchModel,
}

View File

@@ -1,39 +1,23 @@
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
from comfy_api.torch_helpers import set_torch_compile_wrapper
class TorchCompileModel(io.ComfyNode):
class TorchCompileModel:
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="TorchCompileModel",
category="_for_testing",
inputs=[
io.Model.Input("model"),
io.Combo.Input(
"backend",
options=["inductor", "cudagraphs"],
),
],
outputs=[io.Model.Output()],
is_experimental=True,
)
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"backend": (["inductor", "cudagraphs"],),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
@classmethod
def execute(cls, model, backend) -> io.NodeOutput:
CATEGORY = "_for_testing"
EXPERIMENTAL = True
def patch(self, model, backend):
m = model.clone()
set_torch_compile_wrapper(model=m, backend=backend)
return io.NodeOutput(m)
return (m, )
class TorchCompileExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
TorchCompileModel,
]
async def comfy_entrypoint() -> TorchCompileExtension:
return TorchCompileExtension()
NODE_CLASS_MAPPINGS = {
"TorchCompileModel": TorchCompileModel,
}

View File

@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.3.63"
__version__ = "0.3.60"

View File

@@ -1,70 +1,96 @@
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
class Example(io.ComfyNode):
class Example:
"""
An example node
A example node
Class methods
-------------
define_schema (io.Schema):
Tell the main program the metadata, input, output parameters of nodes.
fingerprint_inputs:
INPUT_TYPES (dict):
Tell the main program input parameters of nodes.
IS_CHANGED:
optional method to control when the node is re executed.
check_lazy_status:
optional method to control list of input names that need to be evaluated.
Attributes
----------
RETURN_TYPES (`tuple`):
The type of each element in the output tuple.
RETURN_NAMES (`tuple`):
Optional: The name of each output in the output tuple.
FUNCTION (`str`):
The name of the entry-point method. For example, if `FUNCTION = "execute"` then it will run Example().execute()
OUTPUT_NODE ([`bool`]):
If this node is an output node that outputs a result/image from the graph. The SaveImage node is an example.
The backend iterates on these output nodes and tries to execute all their parents if their parent graph is properly connected.
Assumed to be False if not present.
CATEGORY (`str`):
The category the node should appear in the UI.
DEPRECATED (`bool`):
Indicates whether the node is deprecated. Deprecated nodes are hidden by default in the UI, but remain
functional in existing workflows that use them.
EXPERIMENTAL (`bool`):
Indicates whether the node is experimental. Experimental nodes are marked as such in the UI and may be subject to
significant changes or removal in future versions. Use with caution in production workflows.
execute(s) -> tuple || None:
The entry point method. The name of this method must be the same as the value of property `FUNCTION`.
For example, if `FUNCTION = "execute"` then this method's name must be `execute`, if `FUNCTION = "foo"` then it must be `foo`.
"""
def __init__(self):
pass
@classmethod
def define_schema(cls) -> io.Schema:
def INPUT_TYPES(s):
"""
Return a schema which contains all information about the node.
Some types: "Model", "Vae", "Clip", "Conditioning", "Latent", "Image", "Int", "String", "Float", "Combo".
For outputs the "io.Model.Output" should be used, for inputs the "io.Model.Input" can be used.
The type can be a "Combo" - this will be a list for selection.
"""
return io.Schema(
node_id="Example",
display_name="Example Node",
category="Example",
inputs=[
io.Image.Input("image"),
io.Int.Input(
"int_field",
min=0,
max=4096,
step=64, # Slider's step
display_mode=io.NumberDisplay.number, # Cosmetic only: display as "number" or "slider"
lazy=True, # Will only be evaluated if check_lazy_status requires it
),
io.Float.Input(
"float_field",
default=1.0,
min=0.0,
max=10.0,
step=0.01,
round=0.001, #The value representing the precision to round to, will be set to the step value by default. Can be set to False to disable rounding.
display_mode=io.NumberDisplay.number,
lazy=True,
),
io.Combo.Input("print_to_screen", options=["enable", "disable"]),
io.String.Input(
"string_field",
multiline=False, # True if you want the field to look like the one on the ClipTextEncode node
default="Hello world!",
lazy=True,
)
],
outputs=[
io.Image.Output(),
],
)
Return a dictionary which contains config for all input fields.
Some types (string): "MODEL", "VAE", "CLIP", "CONDITIONING", "LATENT", "IMAGE", "INT", "STRING", "FLOAT".
Input types "INT", "STRING" or "FLOAT" are special values for fields on the node.
The type can be a list for selection.
@classmethod
def check_lazy_status(cls, image, string_field, int_field, float_field, print_to_screen):
Returns: `dict`:
- Key input_fields_group (`string`): Can be either required, hidden or optional. A node class must have property `required`
- Value input_fields (`dict`): Contains input fields config:
* Key field_name (`string`): Name of a entry-point method's argument
* Value field_config (`tuple`):
+ First value is a string indicate the type of field or a list for selection.
+ Second value is a config for type "INT", "STRING" or "FLOAT".
"""
return {
"required": {
"image": ("IMAGE",),
"int_field": ("INT", {
"default": 0,
"min": 0, #Minimum value
"max": 4096, #Maximum value
"step": 64, #Slider's step
"display": "number", # Cosmetic only: display as "number" or "slider"
"lazy": True # Will only be evaluated if check_lazy_status requires it
}),
"float_field": ("FLOAT", {
"default": 1.0,
"min": 0.0,
"max": 10.0,
"step": 0.01,
"round": 0.001, #The value representing the precision to round to, will be set to the step value by default. Can be set to False to disable rounding.
"display": "number",
"lazy": True
}),
"print_to_screen": (["enable", "disable"],),
"string_field": ("STRING", {
"multiline": False, #True if you want the field to look like the one on the ClipTextEncode node
"default": "Hello World!",
"lazy": True
}),
},
}
RETURN_TYPES = ("IMAGE",)
#RETURN_NAMES = ("image_output_name",)
FUNCTION = "test"
#OUTPUT_NODE = False
CATEGORY = "Example"
def check_lazy_status(self, image, string_field, int_field, float_field, print_to_screen):
"""
Return a list of input names that need to be evaluated.
@@ -81,8 +107,7 @@ class Example(io.ComfyNode):
else:
return []
@classmethod
def execute(cls, image, string_field, int_field, float_field, print_to_screen) -> io.NodeOutput:
def test(self, image, string_field, int_field, float_field, print_to_screen):
if print_to_screen == "enable":
print(f"""Your input contains:
string_field aka input text: {string_field}
@@ -91,7 +116,7 @@ class Example(io.ComfyNode):
""")
#do some processing on the image, in this example I just invert it
image = 1.0 - image
return io.NodeOutput(image)
return (image,)
"""
The node will always be re executed if any of the inputs change but
@@ -102,7 +127,7 @@ class Example(io.ComfyNode):
changes between executions the LoadImage node is executed again.
"""
#@classmethod
#def fingerprint_inputs(s, image, string_field, int_field, float_field, print_to_screen):
#def IS_CHANGED(s, image, string_field, int_field, float_field, print_to_screen):
# return ""
# Set the web directory, any .js file in that directory will be loaded by the frontend as a frontend extension
@@ -118,13 +143,13 @@ async def get_hello(request):
return web.json_response("hello")
class ExampleExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
Example,
]
# A dictionary that contains all nodes you want to export with their names
# NOTE: names should be globally unique
NODE_CLASS_MAPPINGS = {
"Example": Example
}
async def comfy_entrypoint() -> ExampleExtension: # ComfyUI calls this to load your extension and its nodes.
return ExampleExtension()
# A dictionary that contains the friendly/humanly readable titles for the nodes
NODE_DISPLAY_NAME_MAPPINGS = {
"Example": "Example Node"
}

View File

@@ -279,10 +279,7 @@ def filter_files_extensions(files: Collection[str], extensions: Collection[str])
def get_full_path(folder_name: str, filename: str) -> str | None:
"""
Get the full path of a file in a folder, has to be a file
"""
def get_full_path(folder_name: str, filename: str, allow_missing: bool = False) -> str | None:
global folder_names_and_paths
folder_name = map_legacy(folder_name)
if folder_name not in folder_names_and_paths:
@@ -295,6 +292,8 @@ def get_full_path(folder_name: str, filename: str) -> str | None:
return full_path
elif os.path.islink(full_path):
logging.warning("WARNING path {} exists but doesn't link anywhere, skipping.".format(full_path))
elif allow_missing:
return full_path
return None
@@ -309,6 +308,27 @@ def get_full_path_or_raise(folder_name: str, filename: str) -> str:
return full_path
def get_relative_path(full_path: str) -> tuple[str, str] | None:
"""Convert a full path back to a type-relative path.
Args:
full_path: The full path to the file
Returns:
tuple[str, str] | None: A tuple of (model_type, relative_path) if found, None otherwise
"""
global folder_names_and_paths
full_path = os.path.normpath(full_path)
for model_type, (paths, _) in folder_names_and_paths.items():
for base_path in paths:
base_path = os.path.normpath(base_path)
if full_path.startswith(base_path):
relative_path = os.path.relpath(full_path, base_path)
return model_type, relative_path
return None
def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float]:
folder_name = map_legacy(folder_name)
global folder_names_and_paths

Some files were not shown because too many files have changed in this diff Show More