Compare commits

..

100 Commits

Author SHA1 Message Date
Alexander Piskun
81e4dac107 convert nodes_upscale_model.py to V3 schema (#10149) 2025-10-09 16:08:40 -07:00
Alexander Piskun
90853fb9cd convert nodes_flux to V3 schema (#10122) 2025-10-09 16:07:17 -07:00
comfyanonymous
f1dd6e50f8 Fix bug with applying loras on fp8 scaled without fp8 ops. (#10279) 2025-10-09 19:02:40 -04:00
Alexander Piskun
fc0fbf141c convert nodes_sd3.py and nodes_slg.py to V3 schema (#10162) 2025-10-09 15:18:23 -07:00
Alexander Piskun
f3d5d328a3 fix(v3,api-nodes): V3 schema typing; corrected Pika API nodes (#10265) 2025-10-09 15:15:03 -07:00
comfyanonymous
139addd53c More surgical fix for #10267 (#10276) 2025-10-09 16:37:35 -04:00
Alexander Piskun
cbee7d3390 convert nodes_latent.py to V3 schema (#10160) 2025-10-08 23:14:00 -07:00
Alexander Piskun
6732014a0a convert nodes_compositing.py to V3 schema (#10174) 2025-10-08 23:13:15 -07:00
Alexander Piskun
989f715d92 convert nodes_lora_extract.py to V3 schema (#10182) 2025-10-08 23:11:45 -07:00
Alexander Piskun
2ba8d7cce8 convert nodes_model_downscale.py to V3 schema (#10199) 2025-10-08 23:10:23 -07:00
Alexander Piskun
51fb505ffa feat(api-nodes, pylint): use lazy formatting in logging functions (#10248) 2025-10-08 23:06:56 -07:00
Jedrzej Kosinski
72c2071972 Mvly/node update (#10042)
* updated V2V node to allow for control image input
exposing steps in v2v
fixing guidance_scale as input parameter

TODO: allow for motion_intensity as input param.

* refactor: comment out unsupported resolution and adjust default values in video nodes

* set control_after_generate

* adding new defaults

* fixes

* changed control_after_generate back to True

* changed control_after_generate back to False

---------

Co-authored-by: thorsten <thorsten@tripod-digital.co.nz>
2025-10-08 20:30:41 -04:00
comfyanonymous
6e59934089 Refactor model sampling sigmas code. (#10250) 2025-10-08 17:49:02 -04:00
Alexander Piskun
3e0eb8d33f feat(V3-io): allow Enum classes for Combo options (#10237) 2025-10-08 00:14:04 -07:00
comfyanonymous
637221995f ComfyUI version 0.3.64 2025-10-08 00:53:43 -04:00
ComfyUI Wiki
51697d50dc update template to 0.1.94 (#10253) 2025-10-07 19:48:51 -07:00
filtered
19f595b788 Bump frontend to 1.27.10 (#10252) 2025-10-07 17:54:00 -07:00
comfyanonymous
8a15568f10 Temp fix for LTXV custom nodes. (#10251) 2025-10-07 19:55:23 -04:00
Alexander Piskun
9e984c48bc feat(api-nodes): add Sora2 API node (#10249) 2025-10-07 14:11:37 -07:00
Alexander Piskun
fc34c3d112 fix(ReCraft-API-node): allow custom multipart parser to return FormData (#10244) 2025-10-07 13:15:32 -07:00
comfyanonymous
8aea746212 Implement gemma 3 as a text encoder. (#10241)
Not useful yet.
2025-10-06 22:08:08 -04:00
Alexander Piskun
8c19910427 convert nodes_kling.py to V3 schema (#10236) 2025-10-06 16:26:52 -07:00
Alexander Piskun
e77e0a8f8f convert nodes_pika.py to V3 schema (#10216) 2025-10-06 16:20:26 -07:00
Alexander Piskun
a49007a7b0 fix(api-nodes): allow negative_prompt PixVerse to be multiline (#10196) 2025-10-06 16:13:43 -07:00
Alexander Piskun
6ae3515801 fix(api-nodes): enable more pylint rules (#10213) 2025-10-06 16:05:57 -07:00
comfyanonymous
6bd3f8eb9f ComfyUI version 0.3.63 2025-10-06 14:49:04 -04:00
ComfyUI Wiki
7326e46dee Update template to 0.1.93 (#10235)
* Update template to 0.1.92

* Update template to 0.1.93
2025-10-06 10:57:00 -07:00
comfyanonymous
195e0b0639 Remove useless code. (#10223) 2025-10-05 15:41:19 -04:00
Alexander Piskun
187f43696d fix(api-nodes): disable "std" mode for Kling2.5-turbo (#10212) 2025-10-04 23:34:18 -07:00
comfyanonymous
caf07331ff Remove soundfile dependency. No more torchaudio load or save. (#10210) 2025-10-04 22:05:05 -04:00
Alexander Piskun
b1fa1922df convert nodes_stable3d.py to V3 schema (#10204) 2025-10-04 12:33:48 -07:00
Alexander Piskun
2ed74f7ac7 convert nodes_rodin.py to V3 schema (#10195) 2025-10-04 12:29:09 -07:00
Alexander Piskun
22f99fb97e fix(api-nodes): enable 2 more pylint rules, removed non needed code (#10192) 2025-10-04 12:22:57 -07:00
comfyanonymous
bbd683098e Add instructions to install nightly AMD pytorch for windows. (#10190)
* Add instructions to install nightly AMD pytorch for windows.

* Update README.md
2025-10-03 23:37:43 -04:00
comfyanonymous
08726b64fe Update amd nightly command in readme. (#10189) 2025-10-03 18:22:43 -04:00
Finn-Hecker
93d859cfaa Fix type annotation syntax in MotionEncoder_tc __init__ (#10186)
## Summary
Fixed incorrect type hint syntax in `MotionEncoder_tc.__init__()` parameter list.

## Changes
- Line 647: Changed `num_heads=int` to `num_heads: int` 
- This corrects the parameter annotation from a default value assignment to proper type hint syntax

## Details
The parameter was using assignment syntax (`=`) instead of type annotation syntax (`:`), which would incorrectly set the default value to the `int` class itself rather than annotating the expected type.
2025-10-03 14:32:19 -07:00
Alexander Piskun
4614ee09ca convert nodes_edit_model.py to V3 schema (#10147) 2025-10-03 13:24:42 -07:00
Alexander Piskun
5c8e986e27 convert nodes_tomesd.py to V3 schema (#10180) 2025-10-03 11:50:38 -07:00
Alexander Piskun
8c26d7bbe6 convert nodes_pixverse.py to V3 schema (#10177) 2025-10-03 11:48:21 -07:00
Alexander Piskun
d7aa414141 convert nodes_eps.py to V3 schema (#10172) 2025-10-03 11:45:02 -07:00
Alexander Piskun
3e68bc342c convert nodes_torch_compile.py to V3 schema (#10173) 2025-10-03 11:43:54 -07:00
Alexander Piskun
c2c5a7d5f8 fix(api-nodes): bad indentation in Recraft API node function (#10175) 2025-10-03 11:41:06 -07:00
Alexander Piskun
8a293372ec fix(api-nodes): reimport of base64 in Gemini node (#10181) 2025-10-03 11:40:27 -07:00
Alexander Piskun
ed3ca78e08 feat(api-nodes): add kling-2-5-turbo to txt2video and img2video nodes (#10155) 2025-10-03 11:26:34 -07:00
Alexander Piskun
4ffea0e864 feat(linter, api-nodes): add pylint for comfy_api_nodes folder (#10157) 2025-10-02 19:14:28 -04:00
Alexander Piskun
1395bce9f7 update example_node to use V3 schema (#9723) 2025-10-02 15:20:29 -07:00
comfyanonymous
e9364ee279 Turn on TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL by default. (#10168) 2025-10-02 17:57:15 -04:00
Alexander Piskun
f6e3e9a456 fix(api-nodes): made logging path to be smaller (#10156) 2025-10-02 14:50:31 -07:00
Alexander Piskun
8f4ee9984c convert nodes_morphology.py to V3 schema (#10159) 2025-10-02 13:53:00 -07:00
comfyanonymous
0e9d1724be Add a .bat to the AMD portable to disable smart memory. (#10153) 2025-10-02 00:33:05 -04:00
rattus128
4965c0e2ac WAN: Fix cache VRAM leak on error (#10141)
If this suffers an exception (such as a VRAM oom) it will leave the
encode() and decode() methods which skips the cleanup of the WAN
feature cache. The comfy node cache then ultimately keeps a reference
this object which is in turn reffing large tensors from the failed
execution.

The feature cache is currently setup at a class variable on the
encoder/decoder however, the encode and decode functions always clear
it on both entry and exit of normal execution.

Its likely the design intent is this is usable as a streaming encoder
where the input comes in batches, however the functions as they are
today don't support that.

So simplify by bringing the cache back to local variable, so that if
it does VRAM OOM the cache itself is properly garbage when the
encode()/decode() functions dissappear from the stack.
2025-10-01 18:42:16 -04:00
rattus128
911331c06c sd: fix VAE tiled fallback VRAM leak (#10139)
When the VAE catches this VRAM OOM, it launches the fallback logic
straight from the exception context.

Python however refs the entire call stack that caused the exception
including any local variables for the sake of exception report and
debugging. In the case of tensors, this can hold on the references
to GBs of VRAM and inhibit the VRAM allocated from freeing them.

So dump the except context completely before going back to the VAE
via the tiler by getting out of the except block with nothing but
a flag.

The greately increases the reliability of the tiler fallback,
especially on low VRAM cards, as with the bug, if the leak randomly
leaked more than the headroom needed for a single tile, the tiler
would fallback would OOM and fail the flow.
2025-10-01 18:40:28 -04:00
Koratahiu
bb32d4ec31 feat: Add Epsilon Scaling node for exposure bias correction (#10132) 2025-10-01 17:59:07 -04:00
comfyanonymous
a6f83a4a1a Support the new hunyuan vae. (#10150) 2025-10-01 17:19:13 -04:00
Alexander Piskun
e4f99b479a convert nodes_ip2p.pt to V3 schema (#10097) 2025-10-01 12:20:30 -07:00
Alexander Piskun
d9c0a4053d convert nodes_lt.py to V3 schema (#10084) 2025-10-01 12:19:56 -07:00
Alexander Piskun
11bab7be76 convert nodes_pag.py to V3 schema (#10080) 2025-10-01 12:18:49 -07:00
Alexander Piskun
3af1881455 convert nodes_optimalsteps.py to V3 schema (#10074) 2025-10-01 12:18:04 -07:00
Alexander Piskun
e0210ce0a7 convert nodes_differential_diffusion.py to V3 schema (#10056) 2025-10-01 12:17:33 -07:00
Alexander Piskun
7eb7160db4 convert nodes_gits.py to V3 schema (#9949) 2025-10-01 12:16:59 -07:00
Alexander Piskun
638097829d convert nodes_audio_encoder.py to V3 schema (#10123) 2025-09-30 23:00:22 -07:00
AustinMroz
c4a8cf60ab Bump frontend to 1.27.7 (#10133) 2025-09-30 22:12:32 -07:00
comfyanonymous
bab8ba20bf ComfyUI version 0.3.62. 2025-09-30 15:12:07 -04:00
Alexander Piskun
b682a73c55 enable Seedance Pro model in the FirstLastFrame node (#10120) 2025-09-30 10:43:41 -07:00
Alexander Piskun
631b9ae861 fix(Rodin3D-Gen2): missing "task_uuid" parameter (#10128) 2025-09-30 10:21:47 -07:00
comfyanonymous
f48d7230de Add new portable links to readme. (#10112) 2025-09-30 12:17:49 -04:00
comfyanonymous
6e079abc3a Workflow permission fix. (#10110) 2025-09-29 23:11:37 -04:00
comfyanonymous
977a4ed8c5 ComfyUI version 0.3.61 2025-09-29 23:04:42 -04:00
comfyanonymous
414a178fb6 Add basic readme for AMD portable. (#10109) 2025-09-29 23:03:02 -04:00
comfyanonymous
447884b657 Make stable release workflow callable. (#10108) 2025-09-29 20:37:51 -04:00
comfyanonymous
bed4b49d08 Add action to do the full stable release. (#10107) 2025-09-29 20:31:15 -04:00
comfyanonymous
342cf644ce Add a way to have different names for stable nvidia portables. (#10106) 2025-09-29 20:05:44 -04:00
comfyanonymous
3758848423 Different base files for nvidia and amd portables. (#10105) 2025-09-29 19:54:37 -04:00
comfyanonymous
0db6aabed3 Different base files for different release. (#10104) 2025-09-29 19:54:05 -04:00
comfyanonymous
1673ace19b Make the final release test optional in the stable release action. (#10103) 2025-09-29 19:08:42 -04:00
comfyanonymous
7f38e4c538 Add action to create cached deps with manually specified torch. (#10102) 2025-09-29 17:27:52 -04:00
Alexander Piskun
8accf50908 convert nodes_mahiro.py to V3 schema (#10070) 2025-09-29 12:35:51 -07:00
Christian Byrne
ed0f4a609b dont cache new locale entry points (#10101) 2025-09-29 12:16:02 -07:00
Alexander Piskun
041b8824f5 convert nodes_perpneg.py to V3 schema (#10081) 2025-09-29 12:05:28 -07:00
Alexander Piskun
b1111c2062 convert nodes_mochi.py to V3 schema (#10069) 2025-09-29 12:03:35 -07:00
Alexander Piskun
05a258efd8 add WanImageToImageApi node (#10094) 2025-09-29 12:01:04 -07:00
ComfyUI Wiki
c8276f8c6b Update template to 0.1.91 (#10096) 2025-09-29 11:59:42 -07:00
Changrz
6ec1cfe101 [Rodin3d api nodes] Updated the name of the save file path (changed from timestamp to UUID). (#10011)
* Update savepath name from time to uuid

* delete lib
2025-09-29 11:59:12 -07:00
comfyanonymous
b60dc31627 Update command to install latest nighly pytorch. (#10085) 2025-09-28 13:41:32 -04:00
comfyanonymous
555f902fc1 Fix stable workflow creating multiple draft releases. (#10067) 2025-09-27 22:43:25 -04:00
Rui Wang (王瑞)
1364548c72 feat: ComfyUI can be run on the specified Ascend NPU (#9663)
* feature: Set the Ascend NPU to use a single one

* Enable the `--cuda-device` parameter to support both CUDA and Ascend NPUs simultaneously.

* Make the code just set the ASCENT_RT_VISIBLE_DEVICES environment variable without any other edits to master branch

---------

Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2025-09-27 22:36:02 -04:00
Alexander Piskun
2dadb34860 convert nodes_hypertile.py to V3 schema (#10061) 2025-09-27 19:16:22 -07:00
Alexander Piskun
1cf86f5ae5 convert nodes_lumina2.py to V3 schema (#10058) 2025-09-27 19:12:51 -07:00
Alexander Piskun
a1127b232d convert nodes_lotus.py to V3 schema (#10057) 2025-09-27 19:11:36 -07:00
comfyanonymous
896f2e653c Fix typo in release workflow. (#10066) 2025-09-27 21:30:35 -04:00
comfyanonymous
40ae495ddc Improvements to the stable release workflow. (#10065) 2025-09-27 20:28:49 -04:00
rattus128
653ceab414 Reduce Peak WAN inference VRAM usage - part II (#10062)
* flux: math: Use _addcmul to avoid expensive VRAM intermediate

The rope process can be the VRAM peak and this intermediate
for the addition result before releasing the original can OOM.
addcmul_ it.

* wan: Delete the self attention before cross attention

This saves VRAM when the cross attention and FFN are in play as the
VRAM peak.
2025-09-27 18:14:16 -04:00
Alexander Piskun
160698eb41 convert nodes_qwen.py to V3 schema (#10049) 2025-09-27 12:25:35 -07:00
Alexander Piskun
7eca95657c convert nodes_photomaker.py to V3 schema (#10017) 2025-09-27 02:36:43 -07:00
Alexander Piskun
ad5aef2d0c convert nodes_pixart.py to V3 schema (#10019) 2025-09-27 02:34:32 -07:00
Alexander Piskun
bcfd80dd79 convert nodes_luma.py to V3 schema (#10030) 2025-09-27 02:28:11 -07:00
Alexander Piskun
6b4b671ce7 convert nodes_bfl.py to V3 schema (#10033) 2025-09-27 02:27:01 -07:00
Alexander Piskun
a9cf1cd249 convert nodes_hidream.py to V3 schema (#9946) 2025-09-26 23:13:05 -07:00
Christian Byrne
255572188f Add workflow templates version tracking to system_stats (#9089)
Adds installed and required workflow templates version information to the
/system_stats endpoint, allowing the frontend to detect and notify users
when their templates package is outdated.

- Add get_installed_templates_version() and get_required_templates_version()
  methods to FrontendManager
- Include templates version info in system_stats response
- Add comprehensive unit tests for the new functionality
2025-09-26 21:29:13 -07:00
ComfyUI Wiki
0572029fee Update template to 0.1.88 (#10046) 2025-09-26 21:18:16 -07:00
144 changed files with 6884 additions and 13319 deletions

View File

@@ -0,0 +1,27 @@
As of the time of writing this you need this preview driver for best results:
https://www.amd.com/en/resources/support-articles/release-notes/RN-AMDGPU-WINDOWS-PYTORCH-PREVIEW.html
HOW TO RUN:
If you have a AMD gpu:
run_amd_gpu.bat
If you have memory issues you can try disabling the smart memory management by running comfyui with:
run_amd_gpu_disable_smart_memory.bat
IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints
You can download the stable diffusion XL one from: https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0_0.9vae.safetensors
RECOMMENDED WAY TO UPDATE:
To update the ComfyUI code: update\update_comfyui.bat
TO SHARE MODELS BETWEEN COMFYUI AND ANOTHER UI:
In the ComfyUI directory you will find a file: extra_model_paths.yaml.example
Rename this file to: extra_model_paths.yaml and edit it with your favorite text editor.

View File

@@ -0,0 +1,2 @@
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --disable-smart-memory
pause

View File

@@ -0,0 +1,2 @@
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build
pause

View File

@@ -0,0 +1,61 @@
name: "Release Stable All Portable Versions"
on:
workflow_dispatch:
inputs:
git_tag:
description: 'Git tag'
required: true
type: string
jobs:
release_nvidia_default:
permissions:
contents: "write"
packages: "write"
pull-requests: "read"
name: "Release NVIDIA Default (cu129)"
uses: ./.github/workflows/stable-release.yml
with:
git_tag: ${{ inputs.git_tag }}
cache_tag: "cu129"
python_minor: "13"
python_patch: "6"
rel_name: "nvidia"
rel_extra_name: ""
test_release: true
secrets: inherit
release_nvidia_cu128:
permissions:
contents: "write"
packages: "write"
pull-requests: "read"
name: "Release NVIDIA cu128"
uses: ./.github/workflows/stable-release.yml
with:
git_tag: ${{ inputs.git_tag }}
cache_tag: "cu128"
python_minor: "12"
python_patch: "10"
rel_name: "nvidia"
rel_extra_name: "_cu128"
test_release: true
secrets: inherit
release_amd_rocm:
permissions:
contents: "write"
packages: "write"
pull-requests: "read"
name: "Release AMD ROCm 6.4.4"
uses: ./.github/workflows/stable-release.yml
with:
git_tag: ${{ inputs.git_tag }}
cache_tag: "rocm644"
python_minor: "12"
python_patch: "10"
rel_name: "amd"
rel_extra_name: ""
test_release: false
secrets: inherit

View File

@@ -21,3 +21,28 @@ jobs:
- name: Run Ruff
run: ruff check .
pylint:
name: Run Pylint
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.12'
- name: Install requirements
run: |
python -m pip install --upgrade pip
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install -r requirements.txt
- name: Install Pylint
run: pip install pylint
- name: Run Pylint
run: pylint comfy_api_nodes

View File

@@ -2,17 +2,17 @@
name: "Release Stable Version"
on:
workflow_dispatch:
workflow_call:
inputs:
git_tag:
description: 'Git tag'
required: true
type: string
cu:
description: 'CUDA version'
cache_tag:
description: 'Cached dependencies tag'
required: true
type: string
default: "129"
default: "cu129"
python_minor:
description: 'Python minor version'
required: true
@@ -23,7 +23,57 @@ on:
required: true
type: string
default: "6"
rel_name:
description: 'Release name'
required: true
type: string
default: "nvidia"
rel_extra_name:
description: 'Release extra name'
required: false
type: string
default: ""
test_release:
description: 'Test Release'
required: true
type: boolean
default: true
workflow_dispatch:
inputs:
git_tag:
description: 'Git tag'
required: true
type: string
cache_tag:
description: 'Cached dependencies tag'
required: true
type: string
default: "cu129"
python_minor:
description: 'Python minor version'
required: true
type: string
default: "13"
python_patch:
description: 'Python patch version'
required: true
type: string
default: "6"
rel_name:
description: 'Release name'
required: true
type: string
default: "nvidia"
rel_extra_name:
description: 'Release extra name'
required: false
type: string
default: ""
test_release:
description: 'Test Release'
required: true
type: boolean
default: true
jobs:
package_comfy_windows:
@@ -42,15 +92,15 @@ jobs:
id: cache
with:
path: |
cu${{ inputs.cu }}_python_deps.tar
${{ inputs.cache_tag }}_python_deps.tar
update_comfyui_and_python_dependencies.bat
key: ${{ runner.os }}-build-cu${{ inputs.cu }}-${{ inputs.python_minor }}
key: ${{ runner.os }}-build-${{ inputs.cache_tag }}-${{ inputs.python_minor }}
- shell: bash
run: |
mv cu${{ inputs.cu }}_python_deps.tar ../
mv ${{ inputs.cache_tag }}_python_deps.tar ../
mv update_comfyui_and_python_dependencies.bat ../
cd ..
tar xf cu${{ inputs.cu }}_python_deps.tar
tar xf ${{ inputs.cache_tag }}_python_deps.tar
pwd
ls
@@ -65,12 +115,19 @@ jobs:
echo 'import site' >> ./python3${{ inputs.python_minor }}._pth
curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
./python.exe get-pip.py
./python.exe -s -m pip install ../cu${{ inputs.cu }}_python_deps/*
./python.exe -s -m pip install ../${{ inputs.cache_tag }}_python_deps/*
grep comfyui ../ComfyUI/requirements.txt > ./requirements_comfyui.txt
./python.exe -s -m pip install -r requirements_comfyui.txt
rm requirements_comfyui.txt
sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
rm ./Lib/site-packages/torch/lib/dnnl.lib #I don't think this is actually used and I need the space
rm ./Lib/site-packages/torch/lib/libprotoc.lib
rm ./Lib/site-packages/torch/lib/libprotobuf.lib
if test -f ./Lib/site-packages/torch/lib/dnnl.lib; then
rm ./Lib/site-packages/torch/lib/dnnl.lib #I don't think this is actually used and I need the space
rm ./Lib/site-packages/torch/lib/libprotoc.lib
rm ./Lib/site-packages/torch/lib/libprotobuf.lib
fi
cd ..
@@ -85,14 +142,18 @@ jobs:
mkdir update
cp -r ComfyUI/.ci/update_windows/* ./update/
cp -r ComfyUI/.ci/windows_base_files/* ./
cp -r ComfyUI/.ci/windows_${{ inputs.rel_name }}_base_files/* ./
cp ../update_comfyui_and_python_dependencies.bat ./update/
cd ..
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=768m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
mv ComfyUI_windows_portable.7z ComfyUI/ComfyUI_windows_portable_nvidia.7z
mv ComfyUI_windows_portable.7z ComfyUI/ComfyUI_windows_portable_${{ inputs.rel_name }}${{ inputs.rel_extra_name }}.7z
- shell: bash
if: ${{ inputs.test_release }}
run: |
cd ..
cd ComfyUI_windows_portable
python_embeded/python.exe -s ComfyUI/main.py --quick-test-for-ci --cpu
@@ -101,10 +162,9 @@ jobs:
ls
- name: Upload binaries to release
uses: svenstaro/upload-release-action@v2
uses: softprops/action-gh-release@v2
with:
repo_token: ${{ secrets.GITHUB_TOKEN }}
file: ComfyUI_windows_portable_nvidia.7z
tag: ${{ inputs.git_tag }}
overwrite: true
files: ComfyUI_windows_portable_${{ inputs.rel_name }}${{ inputs.rel_extra_name }}.7z
tag_name: ${{ inputs.git_tag }}
draft: true
overwrite_files: true

View File

@@ -1,173 +0,0 @@
name: Asset System Tests
on:
push:
paths:
- 'app/**'
- 'tests-assets/**'
- '.github/workflows/test-assets.yml'
- 'requirements.txt'
pull_request:
branches: [master]
workflow_dispatch:
permissions:
contents: read
env:
PIP_DISABLE_PIP_VERSION_CHECK: '1'
PYTHONUNBUFFERED: '1'
jobs:
sqlite:
name: SQLite (${{ matrix.sqlite_mode }}) • Python ${{ matrix.python }}
runs-on: ubuntu-latest
timeout-minutes: 40
strategy:
fail-fast: false
matrix:
python: ['3.9', '3.12']
sqlite_mode: ['memory', 'file']
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python }}
- name: Install dependencies
run: |
python -m pip install -U pip wheel
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install -r requirements.txt
pip install pytest pytest-aiohttp pytest-asyncio
- name: Set deterministic test base dir
id: basedir
shell: bash
run: |
BASE="$RUNNER_TEMP/comfyui-assets-tests-${{ matrix.python }}-${{ matrix.sqlite_mode }}-${{ github.run_id }}-${{ github.run_attempt }}"
echo "ASSETS_TEST_BASE_DIR=$BASE" >> "$GITHUB_ENV"
echo "ASSETS_TEST_LOGS=$BASE/logs" >> "$GITHUB_ENV"
mkdir -p "$BASE/logs"
echo "ASSETS_TEST_BASE_DIR=$BASE"
- name: Set DB URL for SQLite
id: setdb
shell: bash
run: |
if [ "${{ matrix.sqlite_mode }}" = "memory" ]; then
echo "ASSETS_TEST_DB_URL=sqlite+aiosqlite:///:memory:" >> "$GITHUB_ENV"
else
DBFILE="$RUNNER_TEMP/assets-tests.sqlite"
mkdir -p "$(dirname "$DBFILE")"
echo "ASSETS_TEST_DB_URL=sqlite+aiosqlite:///$DBFILE" >> "$GITHUB_ENV"
fi
- name: Run tests
run: python -m pytest tests-assets
- name: Show ComfyUI logs
if: always()
shell: bash
run: |
echo "==== ASSETS_TEST_BASE_DIR: $ASSETS_TEST_BASE_DIR ===="
echo "==== ASSETS_TEST_LOGS: $ASSETS_TEST_LOGS ===="
ls -la "$ASSETS_TEST_LOGS" || true
for f in "$ASSETS_TEST_LOGS"/stdout.log "$ASSETS_TEST_LOGS"/stderr.log; do
if [ -f "$f" ]; then
echo "----- BEGIN $f -----"
sed -n '1,400p' "$f"
echo "----- END $f -----"
fi
done
- name: Upload ComfyUI logs
if: always()
uses: actions/upload-artifact@v4
with:
name: asset-logs-sqlite-${{ matrix.sqlite_mode }}-py${{ matrix.python }}
path: ${{ env.ASSETS_TEST_LOGS }}/*.log
if-no-files-found: warn
postgres:
name: PostgreSQL ${{ matrix.pgsql }} • Python ${{ matrix.python }}
runs-on: ubuntu-latest
timeout-minutes: 40
strategy:
fail-fast: false
matrix:
python: ['3.9', '3.12']
pgsql: ['16', '18']
services:
postgres:
image: postgres:${{ matrix.pgsql }}
env:
POSTGRES_DB: assets
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
ports:
- 5432:5432
options: >-
--health-cmd "pg_isready -U postgres -d assets"
--health-interval 10s
--health-timeout 5s
--health-retries 12
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python }}
- name: Install dependencies
run: |
python -m pip install -U pip wheel
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install -r requirements.txt
pip install pytest pytest-aiohttp pytest-asyncio
pip install greenlet psycopg
- name: Set deterministic test base dir
id: basedir
shell: bash
run: |
BASE="$RUNNER_TEMP/comfyui-assets-tests-${{ matrix.python }}-${{ matrix.sqlite_mode }}-${{ github.run_id }}-${{ github.run_attempt }}"
echo "ASSETS_TEST_BASE_DIR=$BASE" >> "$GITHUB_ENV"
echo "ASSETS_TEST_LOGS=$BASE/logs" >> "$GITHUB_ENV"
mkdir -p "$BASE/logs"
echo "ASSETS_TEST_BASE_DIR=$BASE"
- name: Set DB URL for PostgreSQL
shell: bash
run: |
echo "ASSETS_TEST_DB_URL=postgresql+psycopg://postgres:postgres@localhost:5432/assets" >> "$GITHUB_ENV"
- name: Run tests
run: python -m pytest tests-assets
- name: Show ComfyUI logs
if: always()
shell: bash
run: |
echo "==== ASSETS_TEST_BASE_DIR: $ASSETS_TEST_BASE_DIR ===="
echo "==== ASSETS_TEST_LOGS: $ASSETS_TEST_LOGS ===="
ls -la "$ASSETS_TEST_LOGS" || true
for f in "$ASSETS_TEST_LOGS"/stdout.log "$ASSETS_TEST_LOGS"/stderr.log; do
if [ -f "$f" ]; then
echo "----- BEGIN $f -----"
sed -n '1,400p' "$f"
echo "----- END $f -----"
fi
done
- name: Upload ComfyUI logs
if: always()
uses: actions/upload-artifact@v4
with:
name: asset-logs-pgsql-${{ matrix.pgsql }}-py${{ matrix.python }}
path: ${{ env.ASSETS_TEST_LOGS }}/*.log
if-no-files-found: warn

View File

@@ -56,7 +56,8 @@ jobs:
..\python_embeded\python.exe -s -m pip install --upgrade torch torchvision torchaudio ${{ inputs.xformers }} --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2
pause" > update_comfyui_and_python_dependencies.bat
python -m pip wheel --no-cache-dir torch torchvision torchaudio ${{ inputs.xformers }} ${{ inputs.extra_dependencies }} --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cu }} -r requirements.txt pygit2 -w ./temp_wheel_dir
grep -v comfyui requirements.txt > requirements_nocomfyui.txt
python -m pip wheel --no-cache-dir torch torchvision torchaudio ${{ inputs.xformers }} ${{ inputs.extra_dependencies }} --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cu }} -r requirements_nocomfyui.txt pygit2 -w ./temp_wheel_dir
python -m pip install --no-cache-dir ./temp_wheel_dir/*
echo installed basic
ls -lah temp_wheel_dir

View File

@@ -0,0 +1,64 @@
name: "Windows Release dependencies Manual"
on:
workflow_dispatch:
inputs:
torch_dependencies:
description: 'torch dependencies'
required: false
type: string
default: "torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu128"
cache_tag:
description: 'Cached dependencies tag'
required: true
type: string
default: "cu128"
python_minor:
description: 'python minor version'
required: true
type: string
default: "12"
python_patch:
description: 'python patch version'
required: true
type: string
default: "10"
jobs:
build_dependencies:
runs-on: windows-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: 3.${{ inputs.python_minor }}.${{ inputs.python_patch }}
- shell: bash
run: |
echo "@echo off
call update_comfyui.bat nopause
echo -
echo This will try to update pytorch and all python dependencies.
echo -
echo If you just want to update normally, close this and run update_comfyui.bat instead.
echo -
pause
..\python_embeded\python.exe -s -m pip install --upgrade ${{ inputs.torch_dependencies }} -r ../ComfyUI/requirements.txt pygit2
pause" > update_comfyui_and_python_dependencies.bat
grep -v comfyui requirements.txt > requirements_nocomfyui.txt
python -m pip wheel --no-cache-dir ${{ inputs.torch_dependencies }} -r requirements_nocomfyui.txt pygit2 -w ./temp_wheel_dir
python -m pip install --no-cache-dir ./temp_wheel_dir/*
echo installed basic
ls -lah temp_wheel_dir
mv temp_wheel_dir ${{ inputs.cache_tag }}_python_deps
tar cf ${{ inputs.cache_tag }}_python_deps.tar ${{ inputs.cache_tag }}_python_deps
- uses: actions/cache/save@v4
with:
path: |
${{ inputs.cache_tag }}_python_deps.tar
update_comfyui_and_python_dependencies.bat
key: ${{ runner.os }}-build-${{ inputs.cache_tag }}-${{ inputs.python_minor }}

View File

@@ -68,7 +68,7 @@ jobs:
mkdir update
cp -r ComfyUI/.ci/update_windows/* ./update/
cp -r ComfyUI/.ci/windows_base_files/* ./
cp -r ComfyUI/.ci/windows_nvidia_base_files/* ./
cp -r ComfyUI/.ci/windows_nightly_base_files/* ./
echo "call update_comfyui.bat nopause

View File

@@ -81,7 +81,7 @@ jobs:
mkdir update
cp -r ComfyUI/.ci/update_windows/* ./update/
cp -r ComfyUI/.ci/windows_base_files/* ./
cp -r ComfyUI/.ci/windows_nvidia_base_files/* ./
cp ../update_comfyui_and_python_dependencies.bat ./update/
cd ..

View File

@@ -176,6 +176,12 @@ Simply download, extract with [7-Zip](https://7-zip.org) and run. Make sure you
If you have trouble extracting it, right click the file -> properties -> unblock
#### Alternative Downloads:
[Experimental portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z)
[Portable with pytorch cuda 12.8 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu128.7z) (Supports Nvidia 10 series and older GPUs).
#### How do I share models between another UI and ComfyUI?
See the [Config file](extra_model_paths.yaml.example) to set the search paths for models. In the standalone windows build you can find this file in the ComfyUI directory. Rename this file to extra_model_paths.yaml and edit it with your favorite text editor.
@@ -200,14 +206,32 @@ Put your SD checkpoints (the huge ckpt/safetensors files) in: models/checkpoints
Put your VAE in: models/vae
### AMD GPUs (Linux only)
### AMD GPUs (Linux)
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.4```
This is the command to install the nightly with ROCm 6.4 which might have some performance improvements:
This is the command to install the nightly with ROCm 7.0 which might have some performance improvements:
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.4```
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.0```
### AMD GPUs (Experimental: Windows and Linux), RDNA 3, 3.5 and 4 only.
These have less hardware support than the builds above but they work on windows. You also need to install the pytorch version specific to your hardware.
RDNA 3 (RX 7000 series):
```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx110X-dgpu/```
RDNA 3.5 (Strix halo/Ryzen AI Max+ 365):
```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx1151/```
RDNA 4 (RX 9000 series):
```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx120X-all/```
### Intel GPUs (Windows and Linux)
@@ -233,7 +257,7 @@ Nvidia users should install stable pytorch using this command:
This is the command to install pytorch nightly instead which might have performance improvements.
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu129```
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu130```
#### Troubleshooting
@@ -264,12 +288,6 @@ You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS ve
> **Note**: Remember to add your models, VAE, LoRAs etc. to the corresponding Comfy folders, as discussed in [ComfyUI manual installation](#manual-install-windows-linux).
#### DirectML (AMD Cards on Windows)
This is very badly supported and is not recommended. There are some unofficial builds of pytorch ROCm on windows that exist that will give you a much better experience than this. This readme will be updated once official pytorch ROCm builds for windows come out.
```pip install torch-directml``` Then you can launch ComfyUI with: ```python main.py --directml```
#### Ascend NPUs
For models compatible with Ascend Extension for PyTorch (torch_npu). To get started, ensure your environment meets the prerequisites outlined on the [installation](https://ascend.github.io/docs/sources/ascend/quick_install.html) page. Here's a step-by-step guide tailored to your platform and installation method:

View File

@@ -3,7 +3,7 @@
[alembic]
# path to migration scripts
# Use forward slashes (/) also on windows to provide an os agnostic path
script_location = app/alembic_db
script_location = alembic_db
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time

View File

@@ -2,12 +2,13 @@ from sqlalchemy import engine_from_config
from sqlalchemy import pool
from alembic import context
from app.assets.database.models import Base
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
from app.database.models import Base
target_metadata = Base.metadata
# other values from the config, defined by the needs of env.py,

View File

@@ -1,175 +0,0 @@
"""initial assets schema
Revision ID: 0001_assets
Revises:
Create Date: 2025-08-20 00:00:00
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
revision = "0001_assets"
down_revision = None
branch_labels = None
depends_on = None
def upgrade() -> None:
# ASSETS: content identity
op.create_table(
"assets",
sa.Column("id", sa.String(length=36), primary_key=True),
sa.Column("hash", sa.String(length=256), nullable=True),
sa.Column("size_bytes", sa.BigInteger(), nullable=False, server_default="0"),
sa.Column("mime_type", sa.String(length=255), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=False), nullable=False),
sa.CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"),
)
op.create_index("uq_assets_hash", "assets", ["hash"], unique=True)
op.create_index("ix_assets_mime_type", "assets", ["mime_type"])
# ASSETS_INFO: user-visible references
op.create_table(
"assets_info",
sa.Column("id", sa.String(length=36), primary_key=True),
sa.Column("owner_id", sa.String(length=128), nullable=False, server_default=""),
sa.Column("name", sa.String(length=512), nullable=False),
sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False),
sa.Column("preview_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="SET NULL"), nullable=True),
sa.Column("user_metadata", sa.JSON(), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=False), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=False), nullable=False),
sa.Column("last_access_time", sa.DateTime(timezone=False), nullable=False),
sa.UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"),
)
op.create_index("ix_assets_info_owner_id", "assets_info", ["owner_id"])
op.create_index("ix_assets_info_asset_id", "assets_info", ["asset_id"])
op.create_index("ix_assets_info_name", "assets_info", ["name"])
op.create_index("ix_assets_info_created_at", "assets_info", ["created_at"])
op.create_index("ix_assets_info_last_access_time", "assets_info", ["last_access_time"])
op.create_index("ix_assets_info_owner_name", "assets_info", ["owner_id", "name"])
# TAGS: normalized tag vocabulary
op.create_table(
"tags",
sa.Column("name", sa.String(length=512), primary_key=True),
sa.Column("tag_type", sa.String(length=32), nullable=False, server_default="user"),
sa.CheckConstraint("name = lower(name)", name="ck_tags_lowercase"),
)
op.create_index("ix_tags_tag_type", "tags", ["tag_type"])
# ASSET_INFO_TAGS: many-to-many for tags on AssetInfo
op.create_table(
"asset_info_tags",
sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False),
sa.Column("tag_name", sa.String(length=512), sa.ForeignKey("tags.name", ondelete="RESTRICT"), nullable=False),
sa.Column("origin", sa.String(length=32), nullable=False, server_default="manual"),
sa.Column("added_at", sa.DateTime(timezone=False), nullable=False),
sa.PrimaryKeyConstraint("asset_info_id", "tag_name", name="pk_asset_info_tags"),
)
op.create_index("ix_asset_info_tags_tag_name", "asset_info_tags", ["tag_name"])
op.create_index("ix_asset_info_tags_asset_info_id", "asset_info_tags", ["asset_info_id"])
# ASSET_CACHE_STATE: N:1 local cache rows per Asset
op.create_table(
"asset_cache_state",
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="CASCADE"), nullable=False),
sa.Column("file_path", sa.Text(), nullable=False), # absolute local path to cached file
sa.Column("mtime_ns", sa.BigInteger(), nullable=True),
sa.Column("needs_verify", sa.Boolean(), nullable=False, server_default=sa.text("false")),
sa.CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
sa.UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"),
)
op.create_index("ix_asset_cache_state_file_path", "asset_cache_state", ["file_path"])
op.create_index("ix_asset_cache_state_asset_id", "asset_cache_state", ["asset_id"])
# ASSET_INFO_META: typed KV projection of user_metadata for filtering/sorting
op.create_table(
"asset_info_meta",
sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False),
sa.Column("key", sa.String(length=256), nullable=False),
sa.Column("ordinal", sa.Integer(), nullable=False, server_default="0"),
sa.Column("val_str", sa.String(length=2048), nullable=True),
sa.Column("val_num", sa.Numeric(38, 10), nullable=True),
sa.Column("val_bool", sa.Boolean(), nullable=True),
sa.Column("val_json", sa.JSON().with_variant(postgresql.JSONB(), 'postgresql'), nullable=True),
sa.PrimaryKeyConstraint("asset_info_id", "key", "ordinal", name="pk_asset_info_meta"),
)
op.create_index("ix_asset_info_meta_key", "asset_info_meta", ["key"])
op.create_index("ix_asset_info_meta_key_val_str", "asset_info_meta", ["key", "val_str"])
op.create_index("ix_asset_info_meta_key_val_num", "asset_info_meta", ["key", "val_num"])
op.create_index("ix_asset_info_meta_key_val_bool", "asset_info_meta", ["key", "val_bool"])
# Tags vocabulary
tags_table = sa.table(
"tags",
sa.column("name", sa.String(length=512)),
sa.column("tag_type", sa.String()),
)
op.bulk_insert(
tags_table,
[
{"name": "models", "tag_type": "system"},
{"name": "input", "tag_type": "system"},
{"name": "output", "tag_type": "system"},
{"name": "configs", "tag_type": "system"},
{"name": "checkpoints", "tag_type": "system"},
{"name": "loras", "tag_type": "system"},
{"name": "vae", "tag_type": "system"},
{"name": "text_encoders", "tag_type": "system"},
{"name": "diffusion_models", "tag_type": "system"},
{"name": "clip_vision", "tag_type": "system"},
{"name": "style_models", "tag_type": "system"},
{"name": "embeddings", "tag_type": "system"},
{"name": "diffusers", "tag_type": "system"},
{"name": "vae_approx", "tag_type": "system"},
{"name": "controlnet", "tag_type": "system"},
{"name": "gligen", "tag_type": "system"},
{"name": "upscale_models", "tag_type": "system"},
{"name": "hypernetworks", "tag_type": "system"},
{"name": "photomaker", "tag_type": "system"},
{"name": "classifiers", "tag_type": "system"},
{"name": "encoder", "tag_type": "system"},
{"name": "decoder", "tag_type": "system"},
{"name": "missing", "tag_type": "system"},
{"name": "rescan", "tag_type": "system"},
],
)
def downgrade() -> None:
op.drop_index("ix_asset_info_meta_key_val_bool", table_name="asset_info_meta")
op.drop_index("ix_asset_info_meta_key_val_num", table_name="asset_info_meta")
op.drop_index("ix_asset_info_meta_key_val_str", table_name="asset_info_meta")
op.drop_index("ix_asset_info_meta_key", table_name="asset_info_meta")
op.drop_table("asset_info_meta")
op.drop_index("ix_asset_cache_state_asset_id", table_name="asset_cache_state")
op.drop_index("ix_asset_cache_state_file_path", table_name="asset_cache_state")
op.drop_constraint("uq_asset_cache_state_file_path", table_name="asset_cache_state")
op.drop_table("asset_cache_state")
op.drop_index("ix_asset_info_tags_asset_info_id", table_name="asset_info_tags")
op.drop_index("ix_asset_info_tags_tag_name", table_name="asset_info_tags")
op.drop_table("asset_info_tags")
op.drop_index("ix_tags_tag_type", table_name="tags")
op.drop_table("tags")
op.drop_constraint("uq_assets_info_asset_owner_name", table_name="assets_info")
op.drop_index("ix_assets_info_owner_name", table_name="assets_info")
op.drop_index("ix_assets_info_last_access_time", table_name="assets_info")
op.drop_index("ix_assets_info_created_at", table_name="assets_info")
op.drop_index("ix_assets_info_name", table_name="assets_info")
op.drop_index("ix_assets_info_asset_id", table_name="assets_info")
op.drop_index("ix_assets_info_owner_id", table_name="assets_info")
op.drop_table("assets_info")
op.drop_index("uq_assets_hash", table_name="assets")
op.drop_index("ix_assets_mime_type", table_name="assets")
op.drop_table("assets")

View File

@@ -1,4 +0,0 @@
from .api.routes import register_assets_system
from .scanner import sync_seed_assets
__all__ = ["sync_seed_assets", "register_assets_system"]

View File

@@ -1,225 +0,0 @@
import contextlib
import os
import uuid
from datetime import datetime, timezone
from pathlib import Path
from typing import Literal, Optional, Sequence
import folder_paths
from .api import schemas_in
def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
"""Build a list of (folder_name, base_paths[]) categories that are configured for model locations.
We trust `folder_paths.folder_names_and_paths` and include a category if
*any* of its base paths lies under the Comfy `models_dir`.
"""
targets: list[tuple[str, list[str]]] = []
models_root = os.path.abspath(folder_paths.models_dir)
for name, (paths, _exts) in folder_paths.folder_names_and_paths.items():
if any(os.path.abspath(p).startswith(models_root + os.sep) for p in paths):
targets.append((name, paths))
return targets
def get_relative_to_root_category_path_of_asset(file_path: str) -> tuple[Literal["input", "output", "models"], str]:
"""Given an absolute or relative file path, determine which root category the path belongs to:
- 'input' if the file resides under `folder_paths.get_input_directory()`
- 'output' if the file resides under `folder_paths.get_output_directory()`
- 'models' if the file resides under any base path of categories returned by `get_comfy_models_folders()`
Returns:
(root_category, relative_path_inside_that_root)
For 'models', the relative path is prefixed with the category name:
e.g. ('models', 'vae/test/sub/ae.safetensors')
Raises:
ValueError: if the path does not belong to input, output, or configured model bases.
"""
fp_abs = os.path.abspath(file_path)
def _is_within(child: str, parent: str) -> bool:
try:
return os.path.commonpath([child, parent]) == parent
except Exception:
return False
def _rel(child: str, parent: str) -> str:
return os.path.relpath(os.path.join(os.sep, os.path.relpath(child, parent)), os.sep)
# 1) input
input_base = os.path.abspath(folder_paths.get_input_directory())
if _is_within(fp_abs, input_base):
return "input", _rel(fp_abs, input_base)
# 2) output
output_base = os.path.abspath(folder_paths.get_output_directory())
if _is_within(fp_abs, output_base):
return "output", _rel(fp_abs, output_base)
# 3) models (check deepest matching base to avoid ambiguity)
best: Optional[tuple[int, str, str]] = None # (base_len, bucket, rel_inside_bucket)
for bucket, bases in get_comfy_models_folders():
for b in bases:
base_abs = os.path.abspath(b)
if not _is_within(fp_abs, base_abs):
continue
cand = (len(base_abs), bucket, _rel(fp_abs, base_abs))
if best is None or cand[0] > best[0]:
best = cand
if best is not None:
_, bucket, rel_inside = best
combined = os.path.join(bucket, rel_inside)
return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep)
raise ValueError(f"Path is not within input, output, or configured model bases: {file_path}")
def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]:
"""Return a tuple (name, tags) derived from a filesystem path.
Semantics:
- Root category is determined by `get_relative_to_root_category_path_of_asset`.
- The returned `name` is the base filename with extension from the relative path.
- The returned `tags` are:
[root_category] + parent folders of the relative path (in order)
For 'models', this means:
file '/.../ModelsDir/vae/test_tag/ae.safetensors'
-> root_category='models', some_path='vae/test_tag/ae.safetensors'
-> name='ae.safetensors', tags=['models', 'vae', 'test_tag']
Raises:
ValueError: if the path does not belong to input, output, or configured model bases.
"""
root_category, some_path = get_relative_to_root_category_path_of_asset(file_path)
p = Path(some_path)
parent_parts = [part for part in p.parent.parts if part not in (".", "..", p.anchor)]
return p.name, list(dict.fromkeys(normalize_tags([root_category, *parent_parts])))
def normalize_tags(tags: Optional[Sequence[str]]) -> list[str]:
return [t.strip().lower() for t in (tags or []) if (t or "").strip()]
def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]:
"""Validates and maps tags -> (base_dir, subdirs_for_fs)"""
root = tags[0]
if root == "models":
if len(tags) < 2:
raise ValueError("at least two tags required for model asset")
try:
bases = folder_paths.folder_names_and_paths[tags[1]][0]
except KeyError:
raise ValueError(f"unknown model category '{tags[1]}'")
if not bases:
raise ValueError(f"no base path configured for category '{tags[1]}'")
base_dir = os.path.abspath(bases[0])
raw_subdirs = tags[2:]
else:
base_dir = os.path.abspath(
folder_paths.get_input_directory() if root == "input" else folder_paths.get_output_directory()
)
raw_subdirs = tags[1:]
for i in raw_subdirs:
if i in (".", ".."):
raise ValueError("invalid path component in tags")
return base_dir, raw_subdirs if raw_subdirs else []
def ensure_within_base(candidate: str, base: str) -> None:
cand_abs = os.path.abspath(candidate)
base_abs = os.path.abspath(base)
try:
if os.path.commonpath([cand_abs, base_abs]) != base_abs:
raise ValueError("destination escapes base directory")
except Exception:
raise ValueError("invalid destination path")
def compute_relative_filename(file_path: str) -> Optional[str]:
"""
Return the model's path relative to the last well-known folder (the model category),
using forward slashes, eg:
/.../models/checkpoints/flux/123/flux.safetensors -> "flux/123/flux.safetensors"
/.../models/text_encoders/clip_g.safetensors -> "clip_g.safetensors"
For non-model paths, returns None.
NOTE: this is a temporary helper, used only for initializing metadata["filename"] field.
"""
try:
root_category, rel_path = get_relative_to_root_category_path_of_asset(file_path)
except ValueError:
return None
p = Path(rel_path)
parts = [seg for seg in p.parts if seg not in (".", "..", p.anchor)]
if not parts:
return None
if root_category == "models":
# parts[0] is the category ("checkpoints", "vae", etc) drop it
inside = parts[1:] if len(parts) > 1 else [parts[0]]
return "/".join(inside)
return "/".join(parts) # input/output: keep all parts
def list_tree(base_dir: str) -> list[str]:
out: list[str] = []
base_abs = os.path.abspath(base_dir)
if not os.path.isdir(base_abs):
return out
for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False):
for name in filenames:
out.append(os.path.abspath(os.path.join(dirpath, name)))
return out
def prefixes_for_root(root: schemas_in.RootType) -> list[str]:
if root == "models":
bases: list[str] = []
for _bucket, paths in get_comfy_models_folders():
bases.extend(paths)
return [os.path.abspath(p) for p in bases]
if root == "input":
return [os.path.abspath(folder_paths.get_input_directory())]
if root == "output":
return [os.path.abspath(folder_paths.get_output_directory())]
return []
def ts_to_iso(ts: Optional[float]) -> Optional[str]:
if ts is None:
return None
try:
return datetime.fromtimestamp(float(ts), tz=timezone.utc).replace(tzinfo=None).isoformat()
except Exception:
return None
def new_scan_id(root: schemas_in.RootType) -> str:
return f"scan-{root}-{uuid.uuid4().hex[:8]}"
def collect_models_files() -> list[str]:
out: list[str] = []
for folder_name, bases in get_comfy_models_folders():
rel_files = folder_paths.get_filename_list(folder_name) or []
for rel_path in rel_files:
abs_path = folder_paths.get_full_path(folder_name, rel_path)
if not abs_path:
continue
abs_path = os.path.abspath(abs_path)
allowed = False
for b in bases:
base_abs = os.path.abspath(b)
with contextlib.suppress(Exception):
if os.path.commonpath([abs_path, base_abs]) == base_abs:
allowed = True
break
if allowed:
out.append(abs_path)
return out

View File

@@ -1,544 +0,0 @@
import contextlib
import logging
import os
import urllib.parse
import uuid
from typing import Optional
from aiohttp import web
from pydantic import ValidationError
import folder_paths
from ... import user_manager
from .. import manager, scanner
from . import schemas_in, schemas_out
ROUTES = web.RouteTableDef()
USER_MANAGER: Optional[user_manager.UserManager] = None
LOGGER = logging.getLogger(__name__)
# UUID regex (canonical hyphenated form, case-insensitive)
UUID_RE = r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}"
@ROUTES.head("/api/assets/hash/{hash}")
async def head_asset_by_hash(request: web.Request) -> web.Response:
hash_str = request.match_info.get("hash", "").strip().lower()
if not hash_str or ":" not in hash_str:
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
algo, digest = hash_str.split(":", 1)
if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"):
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
exists = await manager.asset_exists(asset_hash=hash_str)
return web.Response(status=200 if exists else 404)
@ROUTES.get("/api/assets")
async def list_assets(request: web.Request) -> web.Response:
qp = request.rel_url.query
query_dict = {}
if "include_tags" in qp:
query_dict["include_tags"] = qp.getall("include_tags")
if "exclude_tags" in qp:
query_dict["exclude_tags"] = qp.getall("exclude_tags")
for k in ("name_contains", "metadata_filter", "limit", "offset", "sort", "order"):
v = qp.get(k)
if v is not None:
query_dict[k] = v
try:
q = schemas_in.ListAssetsQuery.model_validate(query_dict)
except ValidationError as ve:
return _validation_error_response("INVALID_QUERY", ve)
payload = await manager.list_assets(
include_tags=q.include_tags,
exclude_tags=q.exclude_tags,
name_contains=q.name_contains,
metadata_filter=q.metadata_filter,
limit=q.limit,
offset=q.offset,
sort=q.sort,
order=q.order,
owner_id=USER_MANAGER.get_request_user_id(request),
)
return web.json_response(payload.model_dump(mode="json"))
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}/content")
async def download_asset_content(request: web.Request) -> web.Response:
disposition = request.query.get("disposition", "attachment").lower().strip()
if disposition not in {"inline", "attachment"}:
disposition = "attachment"
try:
abs_path, content_type, filename = await manager.resolve_asset_content_for_download(
asset_info_id=str(uuid.UUID(request.match_info["id"])),
owner_id=USER_MANAGER.get_request_user_id(request),
)
except ValueError as ve:
return _error_response(404, "ASSET_NOT_FOUND", str(ve))
except NotImplementedError as nie:
return _error_response(501, "BACKEND_UNSUPPORTED", str(nie))
except FileNotFoundError:
return _error_response(404, "FILE_NOT_FOUND", "Underlying file not found on disk.")
quoted = (filename or "").replace("\r", "").replace("\n", "").replace('"', "'")
cd = f'{disposition}; filename="{quoted}"; filename*=UTF-8\'\'{urllib.parse.quote(filename)}'
resp = web.FileResponse(abs_path)
resp.content_type = content_type
resp.headers["Content-Disposition"] = cd
return resp
@ROUTES.post("/api/assets/from-hash")
async def create_asset_from_hash(request: web.Request) -> web.Response:
try:
payload = await request.json()
body = schemas_in.CreateFromHashBody.model_validate(payload)
except ValidationError as ve:
return _validation_error_response("INVALID_BODY", ve)
except Exception:
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
result = await manager.create_asset_from_hash(
hash_str=body.hash,
name=body.name,
tags=body.tags,
user_metadata=body.user_metadata,
owner_id=USER_MANAGER.get_request_user_id(request),
)
if result is None:
return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist")
return web.json_response(result.model_dump(mode="json"), status=201)
@ROUTES.post("/api/assets")
async def upload_asset(request: web.Request) -> web.Response:
"""Multipart/form-data endpoint for Asset uploads."""
if not (request.content_type or "").lower().startswith("multipart/"):
return _error_response(415, "UNSUPPORTED_MEDIA_TYPE", "Use multipart/form-data for uploads.")
reader = await request.multipart()
file_present = False
file_client_name: Optional[str] = None
tags_raw: list[str] = []
provided_name: Optional[str] = None
user_metadata_raw: Optional[str] = None
provided_hash: Optional[str] = None
provided_hash_exists: Optional[bool] = None
file_written = 0
tmp_path: Optional[str] = None
while True:
field = await reader.next()
if field is None:
break
fname = getattr(field, "name", "") or ""
if fname == "hash":
try:
s = ((await field.text()) or "").strip().lower()
except Exception:
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
if s:
if ":" not in s:
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
algo, digest = s.split(":", 1)
if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"):
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
provided_hash = f"{algo}:{digest}"
try:
provided_hash_exists = await manager.asset_exists(asset_hash=provided_hash)
except Exception:
provided_hash_exists = None # do not fail the whole request here
elif fname == "file":
file_present = True
file_client_name = (field.filename or "").strip()
if provided_hash and provided_hash_exists is True:
# If client supplied a hash that we know exists, drain but do not write to disk
try:
while True:
chunk = await field.read_chunk(8 * 1024 * 1024)
if not chunk:
break
file_written += len(chunk)
except Exception:
return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive uploaded file.")
continue # Do not create temp file; we will create AssetInfo from the existing content
# Otherwise, store to temp for hashing/ingest
uploads_root = os.path.join(folder_paths.get_temp_directory(), "uploads")
unique_dir = os.path.join(uploads_root, uuid.uuid4().hex)
os.makedirs(unique_dir, exist_ok=True)
tmp_path = os.path.join(unique_dir, ".upload.part")
try:
with open(tmp_path, "wb") as f:
while True:
chunk = await field.read_chunk(8 * 1024 * 1024)
if not chunk:
break
f.write(chunk)
file_written += len(chunk)
except Exception:
try:
if os.path.exists(tmp_path or ""):
os.remove(tmp_path)
finally:
return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive and store uploaded file.")
elif fname == "tags":
tags_raw.append((await field.text()) or "")
elif fname == "name":
provided_name = (await field.text()) or None
elif fname == "user_metadata":
user_metadata_raw = (await field.text()) or None
# If client did not send file, and we are not doing a from-hash fast path -> error
if not file_present and not (provided_hash and provided_hash_exists):
return _error_response(400, "MISSING_FILE", "Form must include a 'file' part or a known 'hash'.")
if file_present and file_written == 0 and not (provided_hash and provided_hash_exists):
# Empty upload is only acceptable if we are fast-pathing from existing hash
try:
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
finally:
return _error_response(400, "EMPTY_UPLOAD", "Uploaded file is empty.")
try:
spec = schemas_in.UploadAssetSpec.model_validate({
"tags": tags_raw,
"name": provided_name,
"user_metadata": user_metadata_raw,
"hash": provided_hash,
})
except ValidationError as ve:
try:
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
finally:
return _validation_error_response("INVALID_BODY", ve)
# Validate models category against configured folders (consistent with previous behavior)
if spec.tags and spec.tags[0] == "models":
if len(spec.tags) < 2 or spec.tags[1] not in folder_paths.folder_names_and_paths:
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
return _error_response(
400, "INVALID_BODY", f"unknown models category '{spec.tags[1] if len(spec.tags) >= 2 else ''}'"
)
owner_id = USER_MANAGER.get_request_user_id(request)
# Fast path: if a valid provided hash exists, create AssetInfo without writing anything
if spec.hash and provided_hash_exists is True:
try:
result = await manager.create_asset_from_hash(
hash_str=spec.hash,
name=spec.name or (spec.hash.split(":", 1)[1]),
tags=spec.tags,
user_metadata=spec.user_metadata or {},
owner_id=owner_id,
)
except Exception:
LOGGER.exception("create_asset_from_hash failed for hash=%s, owner_id=%s", spec.hash, owner_id)
return _error_response(500, "INTERNAL", "Unexpected server error.")
if result is None:
return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {spec.hash} does not exist")
# Drain temp if we accidentally saved (e.g., hash field came after file)
if tmp_path and os.path.exists(tmp_path):
with contextlib.suppress(Exception):
os.remove(tmp_path)
status = 200 if (not result.created_new) else 201
return web.json_response(result.model_dump(mode="json"), status=status)
# Otherwise, we must have a temp file path to ingest
if not tmp_path or not os.path.exists(tmp_path):
# The only case we reach here without a temp file is: client sent a hash that does not exist and no file
return _error_response(404, "ASSET_NOT_FOUND", "Provided hash not found and no file uploaded.")
try:
created = await manager.upload_asset_from_temp_path(
spec,
temp_path=tmp_path,
client_filename=file_client_name,
owner_id=owner_id,
expected_asset_hash=spec.hash,
)
status = 201 if created.created_new else 200
return web.json_response(created.model_dump(mode="json"), status=status)
except ValueError as e:
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
msg = str(e)
if "HASH_MISMATCH" in msg or msg.strip().upper() == "HASH_MISMATCH":
return _error_response(
400,
"HASH_MISMATCH",
"Uploaded file hash does not match provided hash.",
)
return _error_response(400, "BAD_REQUEST", "Invalid inputs.")
except Exception:
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
LOGGER.exception("upload_asset_from_temp_path failed for tmp_path=%s, owner_id=%s", tmp_path, owner_id)
return _error_response(500, "INTERNAL", "Unexpected server error.")
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}")
async def get_asset(request: web.Request) -> web.Response:
asset_info_id = str(uuid.UUID(request.match_info["id"]))
try:
result = await manager.get_asset(
asset_info_id=asset_info_id,
owner_id=USER_MANAGER.get_request_user_id(request),
)
except ValueError as ve:
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
except Exception:
LOGGER.exception(
"get_asset failed for asset_info_id=%s, owner_id=%s",
asset_info_id,
USER_MANAGER.get_request_user_id(request),
)
return _error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(result.model_dump(mode="json"), status=200)
@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}")
async def update_asset(request: web.Request) -> web.Response:
asset_info_id = str(uuid.UUID(request.match_info["id"]))
try:
body = schemas_in.UpdateAssetBody.model_validate(await request.json())
except ValidationError as ve:
return _validation_error_response("INVALID_BODY", ve)
except Exception:
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
try:
result = await manager.update_asset(
asset_info_id=asset_info_id,
name=body.name,
tags=body.tags,
user_metadata=body.user_metadata,
owner_id=USER_MANAGER.get_request_user_id(request),
)
except (ValueError, PermissionError) as ve:
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
except Exception:
LOGGER.exception(
"update_asset failed for asset_info_id=%s, owner_id=%s",
asset_info_id,
USER_MANAGER.get_request_user_id(request),
)
return _error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(result.model_dump(mode="json"), status=200)
@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}/preview")
async def set_asset_preview(request: web.Request) -> web.Response:
asset_info_id = str(uuid.UUID(request.match_info["id"]))
try:
body = schemas_in.SetPreviewBody.model_validate(await request.json())
except ValidationError as ve:
return _validation_error_response("INVALID_BODY", ve)
except Exception:
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
try:
result = await manager.set_asset_preview(
asset_info_id=asset_info_id,
preview_asset_id=body.preview_id,
owner_id=USER_MANAGER.get_request_user_id(request),
)
except (PermissionError, ValueError) as ve:
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
except Exception:
LOGGER.exception(
"set_asset_preview failed for asset_info_id=%s, owner_id=%s",
asset_info_id,
USER_MANAGER.get_request_user_id(request),
)
return _error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(result.model_dump(mode="json"), status=200)
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}")
async def delete_asset(request: web.Request) -> web.Response:
asset_info_id = str(uuid.UUID(request.match_info["id"]))
delete_content = request.query.get("delete_content")
delete_content = True if delete_content is None else delete_content.lower() not in {"0", "false", "no"}
try:
deleted = await manager.delete_asset_reference(
asset_info_id=asset_info_id,
owner_id=USER_MANAGER.get_request_user_id(request),
delete_content_if_orphan=delete_content,
)
except Exception:
LOGGER.exception(
"delete_asset_reference failed for asset_info_id=%s, owner_id=%s",
asset_info_id,
USER_MANAGER.get_request_user_id(request),
)
return _error_response(500, "INTERNAL", "Unexpected server error.")
if not deleted:
return _error_response(404, "ASSET_NOT_FOUND", f"AssetInfo {asset_info_id} not found.")
return web.Response(status=204)
@ROUTES.get("/api/tags")
async def get_tags(request: web.Request) -> web.Response:
query_map = dict(request.rel_url.query)
try:
query = schemas_in.TagsListQuery.model_validate(query_map)
except ValidationError as ve:
return web.json_response(
{"error": {"code": "INVALID_QUERY", "message": "Invalid query parameters", "details": ve.errors()}},
status=400,
)
result = await manager.list_tags(
prefix=query.prefix,
limit=query.limit,
offset=query.offset,
order=query.order,
include_zero=query.include_zero,
owner_id=USER_MANAGER.get_request_user_id(request),
)
return web.json_response(result.model_dump(mode="json"))
@ROUTES.post(f"/api/assets/{{id:{UUID_RE}}}/tags")
async def add_asset_tags(request: web.Request) -> web.Response:
asset_info_id = str(uuid.UUID(request.match_info["id"]))
try:
payload = await request.json()
data = schemas_in.TagsAdd.model_validate(payload)
except ValidationError as ve:
return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags add.", {"errors": ve.errors()})
except Exception:
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
try:
result = await manager.add_tags_to_asset(
asset_info_id=asset_info_id,
tags=data.tags,
origin="manual",
owner_id=USER_MANAGER.get_request_user_id(request),
)
except (ValueError, PermissionError) as ve:
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
except Exception:
LOGGER.exception(
"add_tags_to_asset failed for asset_info_id=%s, owner_id=%s",
asset_info_id,
USER_MANAGER.get_request_user_id(request),
)
return _error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(result.model_dump(mode="json"), status=200)
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}/tags")
async def delete_asset_tags(request: web.Request) -> web.Response:
asset_info_id = str(uuid.UUID(request.match_info["id"]))
try:
payload = await request.json()
data = schemas_in.TagsRemove.model_validate(payload)
except ValidationError as ve:
return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags remove.", {"errors": ve.errors()})
except Exception:
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
try:
result = await manager.remove_tags_from_asset(
asset_info_id=asset_info_id,
tags=data.tags,
owner_id=USER_MANAGER.get_request_user_id(request),
)
except ValueError as ve:
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
except Exception:
LOGGER.exception(
"remove_tags_from_asset failed for asset_info_id=%s, owner_id=%s",
asset_info_id,
USER_MANAGER.get_request_user_id(request),
)
return _error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(result.model_dump(mode="json"), status=200)
@ROUTES.post("/api/assets/scan/seed")
async def seed_assets(request: web.Request) -> web.Response:
try:
payload = await request.json()
except Exception:
payload = {}
try:
body = schemas_in.ScheduleAssetScanBody.model_validate(payload)
except ValidationError as ve:
return _validation_error_response("INVALID_BODY", ve)
try:
await scanner.sync_seed_assets(body.roots)
except Exception:
LOGGER.exception("sync_seed_assets failed for roots=%s", body.roots)
return _error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response({"synced": True, "roots": body.roots}, status=200)
@ROUTES.post("/api/assets/scan/schedule")
async def schedule_asset_scan(request: web.Request) -> web.Response:
try:
payload = await request.json()
except Exception:
payload = {}
try:
body = schemas_in.ScheduleAssetScanBody.model_validate(payload)
except ValidationError as ve:
return _validation_error_response("INVALID_BODY", ve)
states = await scanner.schedule_scans(body.roots)
return web.json_response(states.model_dump(mode="json"), status=202)
@ROUTES.get("/api/assets/scan")
async def get_asset_scan_status(request: web.Request) -> web.Response:
root = request.query.get("root", "").strip().lower()
states = scanner.current_statuses()
if root in {"models", "input", "output"}:
states = [s for s in states.scans if s.root == root] # type: ignore
states = schemas_out.AssetScanStatusResponse(scans=states)
return web.json_response(states.model_dump(mode="json"), status=200)
def register_assets_system(app: web.Application, user_manager_instance: user_manager.UserManager) -> None:
global USER_MANAGER
USER_MANAGER = user_manager_instance
app.add_routes(ROUTES)
def _error_response(status: int, code: str, message: str, details: Optional[dict] = None) -> web.Response:
return web.json_response({"error": {"code": code, "message": message, "details": details or {}}}, status=status)
def _validation_error_response(code: str, ve: ValidationError) -> web.Response:
return _error_response(400, code, "Validation failed.", {"errors": ve.json()})

View File

@@ -1,297 +0,0 @@
import json
import uuid
from typing import Any, Literal, Optional
from pydantic import (
BaseModel,
ConfigDict,
Field,
conint,
field_validator,
model_validator,
)
class ListAssetsQuery(BaseModel):
include_tags: list[str] = Field(default_factory=list)
exclude_tags: list[str] = Field(default_factory=list)
name_contains: Optional[str] = None
# Accept either a JSON string (query param) or a dict
metadata_filter: Optional[dict[str, Any]] = None
limit: conint(ge=1, le=500) = 20
offset: conint(ge=0) = 0
sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = "created_at"
order: Literal["asc", "desc"] = "desc"
@field_validator("include_tags", "exclude_tags", mode="before")
@classmethod
def _split_csv_tags(cls, v):
# Accept "a,b,c" or ["a","b"] (we are liberal in what we accept)
if v is None:
return []
if isinstance(v, str):
return [t.strip() for t in v.split(",") if t.strip()]
if isinstance(v, list):
out: list[str] = []
for item in v:
if isinstance(item, str):
out.extend([t.strip() for t in item.split(",") if t.strip()])
return out
return v
@field_validator("metadata_filter", mode="before")
@classmethod
def _parse_metadata_json(cls, v):
if v is None or isinstance(v, dict):
return v
if isinstance(v, str) and v.strip():
try:
parsed = json.loads(v)
except Exception as e:
raise ValueError(f"metadata_filter must be JSON: {e}") from e
if not isinstance(parsed, dict):
raise ValueError("metadata_filter must be a JSON object")
return parsed
return None
class UpdateAssetBody(BaseModel):
name: Optional[str] = None
tags: Optional[list[str]] = None
user_metadata: Optional[dict[str, Any]] = None
@model_validator(mode="after")
def _at_least_one(self):
if self.name is None and self.tags is None and self.user_metadata is None:
raise ValueError("Provide at least one of: name, tags, user_metadata.")
if self.tags is not None:
if not isinstance(self.tags, list) or not all(isinstance(t, str) for t in self.tags):
raise ValueError("Field 'tags' must be an array of strings.")
return self
class CreateFromHashBody(BaseModel):
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
hash: str
name: str
tags: list[str] = Field(default_factory=list)
user_metadata: dict[str, Any] = Field(default_factory=dict)
@field_validator("hash")
@classmethod
def _require_blake3(cls, v):
s = (v or "").strip().lower()
if ":" not in s:
raise ValueError("hash must be 'blake3:<hex>'")
algo, digest = s.split(":", 1)
if algo != "blake3":
raise ValueError("only canonical 'blake3:<hex>' is accepted here")
if not digest or any(c for c in digest if c not in "0123456789abcdef"):
raise ValueError("hash digest must be lowercase hex")
return s
@field_validator("tags", mode="before")
@classmethod
def _tags_norm(cls, v):
if v is None:
return []
if isinstance(v, list):
out = [str(t).strip().lower() for t in v if str(t).strip()]
seen = set()
dedup = []
for t in out:
if t not in seen:
seen.add(t)
dedup.append(t)
return dedup
if isinstance(v, str):
return [t.strip().lower() for t in v.split(",") if t.strip()]
return []
class TagsListQuery(BaseModel):
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
prefix: Optional[str] = Field(None, min_length=1, max_length=256)
limit: int = Field(100, ge=1, le=1000)
offset: int = Field(0, ge=0, le=10_000_000)
order: Literal["count_desc", "name_asc"] = "count_desc"
include_zero: bool = True
@field_validator("prefix")
@classmethod
def normalize_prefix(cls, v: Optional[str]) -> Optional[str]:
if v is None:
return v
v = v.strip()
return v.lower() or None
class TagsAdd(BaseModel):
model_config = ConfigDict(extra="ignore")
tags: list[str] = Field(..., min_length=1)
@field_validator("tags")
@classmethod
def normalize_tags(cls, v: list[str]) -> list[str]:
out = []
for t in v:
if not isinstance(t, str):
raise TypeError("tags must be strings")
tnorm = t.strip().lower()
if tnorm:
out.append(tnorm)
seen = set()
deduplicated = []
for x in out:
if x not in seen:
seen.add(x)
deduplicated.append(x)
return deduplicated
class TagsRemove(TagsAdd):
pass
RootType = Literal["models", "input", "output"]
ALLOWED_ROOTS: tuple[RootType, ...] = ("models", "input", "output")
class ScheduleAssetScanBody(BaseModel):
roots: list[RootType] = Field(..., min_length=1)
class UploadAssetSpec(BaseModel):
"""Upload Asset operation.
- tags: ordered; first is root ('models'|'input'|'output');
if root == 'models', second must be a valid category from folder_paths.folder_names_and_paths
- name: display name
- user_metadata: arbitrary JSON object (optional)
- hash: optional canonical 'blake3:<hex>' provided by the client for validation / fast-path
Files created via this endpoint are stored on disk using the **content hash** as the filename stem
and the original extension is preserved when available.
"""
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
tags: list[str] = Field(..., min_length=1)
name: Optional[str] = Field(default=None, max_length=512, description="Display Name")
user_metadata: dict[str, Any] = Field(default_factory=dict)
hash: Optional[str] = Field(default=None)
@field_validator("hash", mode="before")
@classmethod
def _parse_hash(cls, v):
if v is None:
return None
s = str(v).strip().lower()
if not s:
return None
if ":" not in s:
raise ValueError("hash must be 'blake3:<hex>'")
algo, digest = s.split(":", 1)
if algo != "blake3":
raise ValueError("only canonical 'blake3:<hex>' is accepted here")
if not digest or any(c for c in digest if c not in "0123456789abcdef"):
raise ValueError("hash digest must be lowercase hex")
return f"{algo}:{digest}"
@field_validator("tags", mode="before")
@classmethod
def _parse_tags(cls, v):
"""
Accepts a list of strings (possibly multiple form fields),
where each string can be:
- JSON array (e.g., '["models","loras","foo"]')
- comma-separated ('models, loras, foo')
- single token ('models')
Returns a normalized, deduplicated, ordered list.
"""
items: list[str] = []
if v is None:
return []
if isinstance(v, str):
v = [v]
if isinstance(v, list):
for item in v:
if item is None:
continue
s = str(item).strip()
if not s:
continue
if s.startswith("["):
try:
arr = json.loads(s)
if isinstance(arr, list):
items.extend(str(x) for x in arr)
continue
except Exception:
pass # fallback to CSV parse below
items.extend([p for p in s.split(",") if p.strip()])
else:
return []
# normalize + dedupe
norm = []
seen = set()
for t in items:
tnorm = str(t).strip().lower()
if tnorm and tnorm not in seen:
seen.add(tnorm)
norm.append(tnorm)
return norm
@field_validator("user_metadata", mode="before")
@classmethod
def _parse_metadata_json(cls, v):
if v is None or isinstance(v, dict):
return v or {}
if isinstance(v, str):
s = v.strip()
if not s:
return {}
try:
parsed = json.loads(s)
except Exception as e:
raise ValueError(f"user_metadata must be JSON: {e}") from e
if not isinstance(parsed, dict):
raise ValueError("user_metadata must be a JSON object")
return parsed
return {}
@model_validator(mode="after")
def _validate_order(self):
if not self.tags:
raise ValueError("tags must be provided and non-empty")
root = self.tags[0]
if root not in {"models", "input", "output"}:
raise ValueError("first tag must be one of: models, input, output")
if root == "models":
if len(self.tags) < 2:
raise ValueError("models uploads require a category tag as the second tag")
return self
class SetPreviewBody(BaseModel):
"""Set or clear the preview for an AssetInfo. Provide an Asset.id or null."""
preview_id: Optional[str] = None
@field_validator("preview_id", mode="before")
@classmethod
def _norm_uuid(cls, v):
if v is None:
return None
s = str(v).strip()
if not s:
return None
try:
uuid.UUID(s)
except Exception:
raise ValueError("preview_id must be a UUID")
return s

View File

@@ -1,115 +0,0 @@
from datetime import datetime
from typing import Any, Literal, Optional
from pydantic import BaseModel, ConfigDict, Field, field_serializer
class AssetSummary(BaseModel):
id: str
name: str
asset_hash: Optional[str]
size: Optional[int] = None
mime_type: Optional[str] = None
tags: list[str] = Field(default_factory=list)
preview_url: Optional[str] = None
created_at: Optional[datetime] = None
updated_at: Optional[datetime] = None
last_access_time: Optional[datetime] = None
model_config = ConfigDict(from_attributes=True)
@field_serializer("created_at", "updated_at", "last_access_time")
def _ser_dt(self, v: Optional[datetime], _info):
return v.isoformat() if v else None
class AssetsList(BaseModel):
assets: list[AssetSummary]
total: int
has_more: bool
class AssetUpdated(BaseModel):
id: str
name: str
asset_hash: Optional[str]
tags: list[str] = Field(default_factory=list)
user_metadata: dict[str, Any] = Field(default_factory=dict)
updated_at: Optional[datetime] = None
model_config = ConfigDict(from_attributes=True)
@field_serializer("updated_at")
def _ser_updated(self, v: Optional[datetime], _info):
return v.isoformat() if v else None
class AssetDetail(BaseModel):
id: str
name: str
asset_hash: Optional[str]
size: Optional[int] = None
mime_type: Optional[str] = None
tags: list[str] = Field(default_factory=list)
user_metadata: dict[str, Any] = Field(default_factory=dict)
preview_id: Optional[str] = None
created_at: Optional[datetime] = None
last_access_time: Optional[datetime] = None
model_config = ConfigDict(from_attributes=True)
@field_serializer("created_at", "last_access_time")
def _ser_dt(self, v: Optional[datetime], _info):
return v.isoformat() if v else None
class AssetCreated(AssetDetail):
created_new: bool
class TagUsage(BaseModel):
name: str
count: int
type: str
class TagsList(BaseModel):
tags: list[TagUsage] = Field(default_factory=list)
total: int
has_more: bool
class TagsAdd(BaseModel):
model_config = ConfigDict(str_strip_whitespace=True)
added: list[str] = Field(default_factory=list)
already_present: list[str] = Field(default_factory=list)
total_tags: list[str] = Field(default_factory=list)
class TagsRemove(BaseModel):
model_config = ConfigDict(str_strip_whitespace=True)
removed: list[str] = Field(default_factory=list)
not_present: list[str] = Field(default_factory=list)
total_tags: list[str] = Field(default_factory=list)
class AssetScanError(BaseModel):
path: str
message: str
at: Optional[str] = Field(None, description="ISO timestamp")
class AssetScanStatus(BaseModel):
scan_id: str
root: Literal["models", "input", "output"]
status: Literal["scheduled", "running", "completed", "failed", "cancelled"]
scheduled_at: Optional[str] = None
started_at: Optional[str] = None
finished_at: Optional[str] = None
discovered: int = 0
processed: int = 0
file_errors: list[AssetScanError] = Field(default_factory=list)
class AssetScanStatusResponse(BaseModel):
scans: list[AssetScanStatus] = Field(default_factory=list)

View File

@@ -1,25 +0,0 @@
from .bulk_ops import seed_from_paths_batch
from .escape_like import escape_like_prefix
from .fast_check import fast_asset_file_check
from .filters import apply_metadata_filter, apply_tag_filters
from .ownership import visible_owner_clause
from .projection import is_scalar, project_kv
from .tags import (
add_missing_tag_for_asset_id,
ensure_tags_exist,
remove_missing_tag_for_asset_id,
)
__all__ = [
"apply_tag_filters",
"apply_metadata_filter",
"escape_like_prefix",
"fast_asset_file_check",
"is_scalar",
"project_kv",
"ensure_tags_exist",
"add_missing_tag_for_asset_id",
"remove_missing_tag_for_asset_id",
"seed_from_paths_batch",
"visible_owner_clause",
]

View File

@@ -1,230 +0,0 @@
import os
import uuid
from typing import Iterable, Sequence
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql as d_pg
from sqlalchemy.dialects import sqlite as d_sqlite
from sqlalchemy.ext.asyncio import AsyncSession
from ..models import Asset, AssetCacheState, AssetInfo, AssetInfoMeta, AssetInfoTag
from ..timeutil import utcnow
MAX_BIND_PARAMS = 800
async def seed_from_paths_batch(
session: AsyncSession,
*,
specs: Sequence[dict],
owner_id: str = "",
) -> dict:
"""Each spec is a dict with keys:
- abs_path: str
- size_bytes: int
- mtime_ns: int
- info_name: str
- tags: list[str]
- fname: Optional[str]
"""
if not specs:
return {"inserted_infos": 0, "won_states": 0, "lost_states": 0}
now = utcnow()
dialect = session.bind.dialect.name
if dialect not in ("sqlite", "postgresql"):
raise NotImplementedError(f"Unsupported database dialect: {dialect}")
asset_rows: list[dict] = []
state_rows: list[dict] = []
path_to_asset: dict[str, str] = {}
asset_to_info: dict[str, dict] = {} # asset_id -> prepared info row
path_list: list[str] = []
for sp in specs:
ap = os.path.abspath(sp["abs_path"])
aid = str(uuid.uuid4())
iid = str(uuid.uuid4())
path_list.append(ap)
path_to_asset[ap] = aid
asset_rows.append(
{
"id": aid,
"hash": None,
"size_bytes": sp["size_bytes"],
"mime_type": None,
"created_at": now,
}
)
state_rows.append(
{
"asset_id": aid,
"file_path": ap,
"mtime_ns": sp["mtime_ns"],
}
)
asset_to_info[aid] = {
"id": iid,
"owner_id": owner_id,
"name": sp["info_name"],
"asset_id": aid,
"preview_id": None,
"user_metadata": {"filename": sp["fname"]} if sp["fname"] else None,
"created_at": now,
"updated_at": now,
"last_access_time": now,
"_tags": sp["tags"],
"_filename": sp["fname"],
}
# insert all seed Assets (hash=NULL)
ins_asset = d_sqlite.insert(Asset) if dialect == "sqlite" else d_pg.insert(Asset)
for chunk in _iter_chunks(asset_rows, _rows_per_stmt(5)):
await session.execute(ins_asset, chunk)
# try to claim AssetCacheState (file_path)
winners_by_path: set[str] = set()
if dialect == "sqlite":
ins_state = (
d_sqlite.insert(AssetCacheState)
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
.returning(AssetCacheState.file_path)
)
else:
ins_state = (
d_pg.insert(AssetCacheState)
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
.returning(AssetCacheState.file_path)
)
for chunk in _iter_chunks(state_rows, _rows_per_stmt(3)):
winners_by_path.update((await session.execute(ins_state, chunk)).scalars().all())
all_paths_set = set(path_list)
losers_by_path = all_paths_set - winners_by_path
lost_assets = [path_to_asset[p] for p in losers_by_path]
if lost_assets: # losers get their Asset removed
for id_chunk in _iter_chunks(lost_assets, MAX_BIND_PARAMS):
await session.execute(sa.delete(Asset).where(Asset.id.in_(id_chunk)))
if not winners_by_path:
return {"inserted_infos": 0, "won_states": 0, "lost_states": len(losers_by_path)}
# insert AssetInfo only for winners
winner_info_rows = [asset_to_info[path_to_asset[p]] for p in winners_by_path]
if dialect == "sqlite":
ins_info = (
d_sqlite.insert(AssetInfo)
.on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name])
.returning(AssetInfo.id)
)
else:
ins_info = (
d_pg.insert(AssetInfo)
.on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name])
.returning(AssetInfo.id)
)
inserted_info_ids: set[str] = set()
for chunk in _iter_chunks(winner_info_rows, _rows_per_stmt(9)):
inserted_info_ids.update((await session.execute(ins_info, chunk)).scalars().all())
# build and insert tag + meta rows for the AssetInfo
tag_rows: list[dict] = []
meta_rows: list[dict] = []
if inserted_info_ids:
for row in winner_info_rows:
iid = row["id"]
if iid not in inserted_info_ids:
continue
for t in row["_tags"]:
tag_rows.append({
"asset_info_id": iid,
"tag_name": t,
"origin": "automatic",
"added_at": now,
})
if row["_filename"]:
meta_rows.append(
{
"asset_info_id": iid,
"key": "filename",
"ordinal": 0,
"val_str": row["_filename"],
"val_num": None,
"val_bool": None,
"val_json": None,
}
)
await bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=meta_rows, max_bind_params=MAX_BIND_PARAMS)
return {
"inserted_infos": len(inserted_info_ids),
"won_states": len(winners_by_path),
"lost_states": len(losers_by_path),
}
async def bulk_insert_tags_and_meta(
session: AsyncSession,
*,
tag_rows: list[dict],
meta_rows: list[dict],
max_bind_params: int,
) -> None:
"""Batch insert into asset_info_tags and asset_info_meta with ON CONFLICT DO NOTHING.
- tag_rows keys: asset_info_id, tag_name, origin, added_at
- meta_rows keys: asset_info_id, key, ordinal, val_str, val_num, val_bool, val_json
"""
dialect = session.bind.dialect.name
if tag_rows:
if dialect == "sqlite":
ins_links = (
d_sqlite.insert(AssetInfoTag)
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
)
elif dialect == "postgresql":
ins_links = (
d_pg.insert(AssetInfoTag)
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
)
else:
raise NotImplementedError(f"Unsupported database dialect: {dialect}")
for chunk in _chunk_rows(tag_rows, cols_per_row=4, max_bind_params=max_bind_params):
await session.execute(ins_links, chunk)
if meta_rows:
if dialect == "sqlite":
ins_meta = (
d_sqlite.insert(AssetInfoMeta)
.on_conflict_do_nothing(
index_elements=[AssetInfoMeta.asset_info_id, AssetInfoMeta.key, AssetInfoMeta.ordinal]
)
)
elif dialect == "postgresql":
ins_meta = (
d_pg.insert(AssetInfoMeta)
.on_conflict_do_nothing(
index_elements=[AssetInfoMeta.asset_info_id, AssetInfoMeta.key, AssetInfoMeta.ordinal]
)
)
else:
raise NotImplementedError(f"Unsupported database dialect: {dialect}")
for chunk in _chunk_rows(meta_rows, cols_per_row=7, max_bind_params=max_bind_params):
await session.execute(ins_meta, chunk)
def _chunk_rows(rows: list[dict], cols_per_row: int, max_bind_params: int) -> Iterable[list[dict]]:
if not rows:
return []
rows_per_stmt = max(1, max_bind_params // max(1, cols_per_row))
for i in range(0, len(rows), rows_per_stmt):
yield rows[i:i + rows_per_stmt]
def _iter_chunks(seq, n: int):
for i in range(0, len(seq), n):
yield seq[i:i + n]
def _rows_per_stmt(cols: int) -> int:
return max(1, MAX_BIND_PARAMS // max(1, cols))

View File

@@ -1,7 +0,0 @@
def escape_like_prefix(s: str, escape: str = "!") -> tuple[str, str]:
"""Escapes %, _ and the escape char itself in a LIKE prefix.
Returns (escaped_prefix, escape_char). Caller should append '%' and pass escape=escape_char to .like().
"""
s = s.replace(escape, escape + escape) # escape the escape char first
s = s.replace("%", escape + "%").replace("_", escape + "_") # escape LIKE wildcards
return s, escape

View File

@@ -1,19 +0,0 @@
import os
from typing import Optional
def fast_asset_file_check(
*,
mtime_db: Optional[int],
size_db: Optional[int],
stat_result: os.stat_result,
) -> bool:
if mtime_db is None:
return False
actual_mtime_ns = getattr(stat_result, "st_mtime_ns", int(stat_result.st_mtime * 1_000_000_000))
if int(mtime_db) != int(actual_mtime_ns):
return False
sz = int(size_db or 0)
if sz > 0:
return int(stat_result.st_size) == sz
return True

View File

@@ -1,87 +0,0 @@
from typing import Optional, Sequence
import sqlalchemy as sa
from sqlalchemy import exists
from ..._helpers import normalize_tags
from ..models import AssetInfo, AssetInfoMeta, AssetInfoTag
def apply_tag_filters(
stmt: sa.sql.Select,
include_tags: Optional[Sequence[str]],
exclude_tags: Optional[Sequence[str]],
) -> sa.sql.Select:
"""include_tags: every tag must be present; exclude_tags: none may be present."""
include_tags = normalize_tags(include_tags)
exclude_tags = normalize_tags(exclude_tags)
if include_tags:
for tag_name in include_tags:
stmt = stmt.where(
exists().where(
(AssetInfoTag.asset_info_id == AssetInfo.id)
& (AssetInfoTag.tag_name == tag_name)
)
)
if exclude_tags:
stmt = stmt.where(
~exists().where(
(AssetInfoTag.asset_info_id == AssetInfo.id)
& (AssetInfoTag.tag_name.in_(exclude_tags))
)
)
return stmt
def apply_metadata_filter(
stmt: sa.sql.Select,
metadata_filter: Optional[dict],
) -> sa.sql.Select:
"""Apply filters using asset_info_meta projection table."""
if not metadata_filter:
return stmt
def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
return sa.exists().where(
AssetInfoMeta.asset_info_id == AssetInfo.id,
AssetInfoMeta.key == key,
*preds,
)
def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
if value is None:
no_row_for_key = sa.not_(
sa.exists().where(
AssetInfoMeta.asset_info_id == AssetInfo.id,
AssetInfoMeta.key == key,
)
)
null_row = _exists_for_pred(
key,
AssetInfoMeta.val_json.is_(None),
AssetInfoMeta.val_str.is_(None),
AssetInfoMeta.val_num.is_(None),
AssetInfoMeta.val_bool.is_(None),
)
return sa.or_(no_row_for_key, null_row)
if isinstance(value, bool):
return _exists_for_pred(key, AssetInfoMeta.val_bool == bool(value))
if isinstance(value, (int, float)):
from decimal import Decimal
num = value if isinstance(value, Decimal) else Decimal(str(value))
return _exists_for_pred(key, AssetInfoMeta.val_num == num)
if isinstance(value, str):
return _exists_for_pred(key, AssetInfoMeta.val_str == value)
return _exists_for_pred(key, AssetInfoMeta.val_json == value)
for k, v in metadata_filter.items():
if isinstance(v, list):
ors = [_exists_clause_for_value(k, elem) for elem in v]
if ors:
stmt = stmt.where(sa.or_(*ors))
else:
stmt = stmt.where(_exists_clause_for_value(k, v))
return stmt

View File

@@ -1,12 +0,0 @@
import sqlalchemy as sa
from ..models import AssetInfo
def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
"""Build owner visibility predicate for reads. Owner-less rows are visible to everyone."""
owner_id = (owner_id or "").strip()
if owner_id == "":
return AssetInfo.owner_id == ""
return AssetInfo.owner_id.in_(["", owner_id])

View File

@@ -1,64 +0,0 @@
from decimal import Decimal
def is_scalar(v):
if v is None:
return True
if isinstance(v, bool):
return True
if isinstance(v, (int, float, Decimal, str)):
return True
return False
def project_kv(key: str, value):
"""
Turn a metadata key/value into typed projection rows.
Returns list[dict] with keys:
key, ordinal, and one of val_str / val_num / val_bool / val_json (others None)
"""
rows: list[dict] = []
def _null_row(ordinal: int) -> dict:
return {
"key": key, "ordinal": ordinal,
"val_str": None, "val_num": None, "val_bool": None, "val_json": None
}
if value is None:
rows.append(_null_row(0))
return rows
if is_scalar(value):
if isinstance(value, bool):
rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)})
elif isinstance(value, (int, float, Decimal)):
num = value if isinstance(value, Decimal) else Decimal(str(value))
rows.append({"key": key, "ordinal": 0, "val_num": num})
elif isinstance(value, str):
rows.append({"key": key, "ordinal": 0, "val_str": value})
else:
rows.append({"key": key, "ordinal": 0, "val_json": value})
return rows
if isinstance(value, list):
if all(is_scalar(x) for x in value):
for i, x in enumerate(value):
if x is None:
rows.append(_null_row(i))
elif isinstance(x, bool):
rows.append({"key": key, "ordinal": i, "val_bool": bool(x)})
elif isinstance(x, (int, float, Decimal)):
num = x if isinstance(x, Decimal) else Decimal(str(x))
rows.append({"key": key, "ordinal": i, "val_num": num})
elif isinstance(x, str):
rows.append({"key": key, "ordinal": i, "val_str": x})
else:
rows.append({"key": key, "ordinal": i, "val_json": x})
return rows
for i, x in enumerate(value):
rows.append({"key": key, "ordinal": i, "val_json": x})
return rows
rows.append({"key": key, "ordinal": 0, "val_json": value})
return rows

View File

@@ -1,90 +0,0 @@
from typing import Iterable
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql as d_pg
from sqlalchemy.dialects import sqlite as d_sqlite
from sqlalchemy.ext.asyncio import AsyncSession
from ..._helpers import normalize_tags
from ..models import AssetInfo, AssetInfoTag, Tag
from ..timeutil import utcnow
async def ensure_tags_exist(session: AsyncSession, names: Iterable[str], tag_type: str = "user") -> None:
wanted = normalize_tags(list(names))
if not wanted:
return
rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
dialect = session.bind.dialect.name
if dialect == "sqlite":
ins = (
d_sqlite.insert(Tag)
.values(rows)
.on_conflict_do_nothing(index_elements=[Tag.name])
)
elif dialect == "postgresql":
ins = (
d_pg.insert(Tag)
.values(rows)
.on_conflict_do_nothing(index_elements=[Tag.name])
)
else:
raise NotImplementedError(f"Unsupported database dialect: {dialect}")
await session.execute(ins)
async def add_missing_tag_for_asset_id(
session: AsyncSession,
*,
asset_id: str,
origin: str = "automatic",
) -> None:
select_rows = (
sa.select(
AssetInfo.id.label("asset_info_id"),
sa.literal("missing").label("tag_name"),
sa.literal(origin).label("origin"),
sa.literal(utcnow()).label("added_at"),
)
.where(AssetInfo.asset_id == asset_id)
.where(
sa.not_(
sa.exists().where((AssetInfoTag.asset_info_id == AssetInfo.id) & (AssetInfoTag.tag_name == "missing"))
)
)
)
dialect = session.bind.dialect.name
if dialect == "sqlite":
ins = (
d_sqlite.insert(AssetInfoTag)
.from_select(
["asset_info_id", "tag_name", "origin", "added_at"],
select_rows,
)
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
)
elif dialect == "postgresql":
ins = (
d_pg.insert(AssetInfoTag)
.from_select(
["asset_info_id", "tag_name", "origin", "added_at"],
select_rows,
)
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
)
else:
raise NotImplementedError(f"Unsupported database dialect: {dialect}")
await session.execute(ins)
async def remove_missing_tag_for_asset_id(
session: AsyncSession,
*,
asset_id: str,
) -> None:
await session.execute(
sa.delete(AssetInfoTag).where(
AssetInfoTag.asset_info_id.in_(sa.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)),
AssetInfoTag.tag_name == "missing",
)
)

View File

@@ -1,251 +0,0 @@
import uuid
from datetime import datetime
from typing import Any, Optional
from sqlalchemy import (
JSON,
BigInteger,
Boolean,
CheckConstraint,
DateTime,
ForeignKey,
Index,
Integer,
Numeric,
String,
Text,
UniqueConstraint,
)
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import DeclarativeBase, Mapped, foreign, mapped_column, relationship
from .timeutil import utcnow
JSONB_V = JSON(none_as_null=True).with_variant(JSONB(none_as_null=True), 'postgresql')
class Base(DeclarativeBase):
pass
def to_dict(obj: Any, include_none: bool = False) -> dict[str, Any]:
fields = obj.__table__.columns.keys()
out: dict[str, Any] = {}
for field in fields:
val = getattr(obj, field)
if val is None and not include_none:
continue
if isinstance(val, datetime):
out[field] = val.isoformat()
else:
out[field] = val
return out
class Asset(Base):
__tablename__ = "assets"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
hash: Mapped[Optional[str]] = mapped_column(String(256), nullable=True)
size_bytes: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
mime_type: Mapped[Optional[str]] = mapped_column(String(255))
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=False), nullable=False, default=utcnow
)
infos: Mapped[list["AssetInfo"]] = relationship(
"AssetInfo",
back_populates="asset",
primaryjoin=lambda: Asset.id == foreign(AssetInfo.asset_id),
foreign_keys=lambda: [AssetInfo.asset_id],
cascade="all,delete-orphan",
passive_deletes=True,
)
preview_of: Mapped[list["AssetInfo"]] = relationship(
"AssetInfo",
back_populates="preview_asset",
primaryjoin=lambda: Asset.id == foreign(AssetInfo.preview_id),
foreign_keys=lambda: [AssetInfo.preview_id],
viewonly=True,
)
cache_states: Mapped[list["AssetCacheState"]] = relationship(
back_populates="asset",
cascade="all, delete-orphan",
passive_deletes=True,
)
__table_args__ = (
Index("uq_assets_hash", "hash", unique=True),
Index("ix_assets_mime_type", "mime_type"),
CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"),
)
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
return to_dict(self, include_none=include_none)
def __repr__(self) -> str:
return f"<Asset id={self.id} hash={(self.hash or '')[:12]}>"
class AssetCacheState(Base):
__tablename__ = "asset_cache_state"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="CASCADE"), nullable=False)
file_path: Mapped[str] = mapped_column(Text, nullable=False)
mtime_ns: Mapped[Optional[int]] = mapped_column(BigInteger, nullable=True)
needs_verify: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
asset: Mapped["Asset"] = relationship(back_populates="cache_states")
__table_args__ = (
Index("ix_asset_cache_state_file_path", "file_path"),
Index("ix_asset_cache_state_asset_id", "asset_id"),
CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"),
)
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
return to_dict(self, include_none=include_none)
def __repr__(self) -> str:
return f"<AssetCacheState id={self.id} asset_id={self.asset_id} path={self.file_path!r}>"
class AssetInfo(Base):
__tablename__ = "assets_info"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
owner_id: Mapped[str] = mapped_column(String(128), nullable=False, default="")
name: Mapped[str] = mapped_column(String(512), nullable=False)
asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False)
preview_id: Mapped[Optional[str]] = mapped_column(String(36), ForeignKey("assets.id", ondelete="SET NULL"))
user_metadata: Mapped[Optional[dict[str, Any]]] = mapped_column(JSON(none_as_null=True))
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
last_access_time: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
asset: Mapped[Asset] = relationship(
"Asset",
back_populates="infos",
foreign_keys=[asset_id],
lazy="selectin",
)
preview_asset: Mapped[Optional[Asset]] = relationship(
"Asset",
back_populates="preview_of",
foreign_keys=[preview_id],
)
metadata_entries: Mapped[list["AssetInfoMeta"]] = relationship(
back_populates="asset_info",
cascade="all,delete-orphan",
passive_deletes=True,
)
tag_links: Mapped[list["AssetInfoTag"]] = relationship(
back_populates="asset_info",
cascade="all,delete-orphan",
passive_deletes=True,
overlaps="tags,asset_infos",
)
tags: Mapped[list["Tag"]] = relationship(
secondary="asset_info_tags",
back_populates="asset_infos",
lazy="selectin",
viewonly=True,
overlaps="tag_links,asset_info_links,asset_infos,tag",
)
__table_args__ = (
UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"),
Index("ix_assets_info_owner_name", "owner_id", "name"),
Index("ix_assets_info_owner_id", "owner_id"),
Index("ix_assets_info_asset_id", "asset_id"),
Index("ix_assets_info_name", "name"),
Index("ix_assets_info_created_at", "created_at"),
Index("ix_assets_info_last_access_time", "last_access_time"),
)
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
data = to_dict(self, include_none=include_none)
data["tags"] = [t.name for t in self.tags]
return data
def __repr__(self) -> str:
return f"<AssetInfo id={self.id} name={self.name!r} asset_id={self.asset_id}>"
class AssetInfoMeta(Base):
__tablename__ = "asset_info_meta"
asset_info_id: Mapped[str] = mapped_column(
String(36), ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True
)
key: Mapped[str] = mapped_column(String(256), primary_key=True)
ordinal: Mapped[int] = mapped_column(Integer, primary_key=True, default=0)
val_str: Mapped[Optional[str]] = mapped_column(String(2048), nullable=True)
val_num: Mapped[Optional[float]] = mapped_column(Numeric(38, 10), nullable=True)
val_bool: Mapped[Optional[bool]] = mapped_column(Boolean, nullable=True)
val_json: Mapped[Optional[Any]] = mapped_column(JSONB_V, nullable=True)
asset_info: Mapped["AssetInfo"] = relationship(back_populates="metadata_entries")
__table_args__ = (
Index("ix_asset_info_meta_key", "key"),
Index("ix_asset_info_meta_key_val_str", "key", "val_str"),
Index("ix_asset_info_meta_key_val_num", "key", "val_num"),
Index("ix_asset_info_meta_key_val_bool", "key", "val_bool"),
)
class AssetInfoTag(Base):
__tablename__ = "asset_info_tags"
asset_info_id: Mapped[str] = mapped_column(
String(36), ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True
)
tag_name: Mapped[str] = mapped_column(
String(512), ForeignKey("tags.name", ondelete="RESTRICT"), primary_key=True
)
origin: Mapped[str] = mapped_column(String(32), nullable=False, default="manual")
added_at: Mapped[datetime] = mapped_column(
DateTime(timezone=False), nullable=False, default=utcnow
)
asset_info: Mapped["AssetInfo"] = relationship(back_populates="tag_links")
tag: Mapped["Tag"] = relationship(back_populates="asset_info_links")
__table_args__ = (
Index("ix_asset_info_tags_tag_name", "tag_name"),
Index("ix_asset_info_tags_asset_info_id", "asset_info_id"),
)
class Tag(Base):
__tablename__ = "tags"
name: Mapped[str] = mapped_column(String(512), primary_key=True)
tag_type: Mapped[str] = mapped_column(String(32), nullable=False, default="user")
asset_info_links: Mapped[list["AssetInfoTag"]] = relationship(
back_populates="tag",
overlaps="asset_infos,tags",
)
asset_infos: Mapped[list["AssetInfo"]] = relationship(
secondary="asset_info_tags",
back_populates="tags",
viewonly=True,
overlaps="asset_info_links,tag_links,tags,asset_info",
)
__table_args__ = (
Index("ix_tags_tag_type", "tag_type"),
)
def __repr__(self) -> str:
return f"<Tag {self.name}>"

View File

@@ -1,57 +0,0 @@
from .content import (
check_fs_asset_exists_quick,
compute_hash_and_dedup_for_cache_state,
ingest_fs_asset,
list_cache_states_with_asset_under_prefixes,
list_unhashed_candidates_under_prefixes,
list_verify_candidates_under_prefixes,
redirect_all_references_then_delete_asset,
touch_asset_infos_by_fs_path,
)
from .info import (
add_tags_to_asset_info,
create_asset_info_for_existing_asset,
delete_asset_info_by_id,
fetch_asset_info_and_asset,
fetch_asset_info_asset_and_tags,
get_asset_tags,
list_asset_infos_page,
list_tags_with_usage,
remove_tags_from_asset_info,
replace_asset_info_metadata_projection,
set_asset_info_preview,
set_asset_info_tags,
touch_asset_info_by_id,
update_asset_info_full,
)
from .queries import (
asset_exists_by_hash,
asset_info_exists_for_asset_id,
get_asset_by_hash,
get_asset_info_by_id,
get_cache_state_by_asset_id,
list_cache_states_by_asset_id,
pick_best_live_path,
)
__all__ = [
# queries
"asset_exists_by_hash", "get_asset_by_hash", "get_asset_info_by_id", "asset_info_exists_for_asset_id",
"get_cache_state_by_asset_id",
"list_cache_states_by_asset_id",
"pick_best_live_path",
# info
"list_asset_infos_page", "create_asset_info_for_existing_asset", "set_asset_info_tags",
"update_asset_info_full", "replace_asset_info_metadata_projection",
"touch_asset_info_by_id", "delete_asset_info_by_id",
"add_tags_to_asset_info", "remove_tags_from_asset_info",
"get_asset_tags", "list_tags_with_usage", "set_asset_info_preview",
"fetch_asset_info_and_asset", "fetch_asset_info_asset_and_tags",
# content
"check_fs_asset_exists_quick",
"redirect_all_references_then_delete_asset",
"compute_hash_and_dedup_for_cache_state",
"list_unhashed_candidates_under_prefixes", "list_verify_candidates_under_prefixes",
"ingest_fs_asset", "touch_asset_infos_by_fs_path",
"list_cache_states_with_asset_under_prefixes",
]

View File

@@ -1,721 +0,0 @@
import contextlib
import logging
import os
from datetime import datetime
from typing import Any, Optional, Sequence, Union
import sqlalchemy as sa
from sqlalchemy import select
from sqlalchemy.dialects import postgresql as d_pg
from sqlalchemy.dialects import sqlite as d_sqlite
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import noload
from ..._helpers import compute_relative_filename
from ...storage import hashing as hashing_mod
from ..helpers import (
ensure_tags_exist,
escape_like_prefix,
remove_missing_tag_for_asset_id,
)
from ..models import Asset, AssetCacheState, AssetInfo, AssetInfoTag, Tag
from ..timeutil import utcnow
from .info import replace_asset_info_metadata_projection
from .queries import list_cache_states_by_asset_id, pick_best_live_path
async def check_fs_asset_exists_quick(
session: AsyncSession,
*,
file_path: str,
size_bytes: Optional[int] = None,
mtime_ns: Optional[int] = None,
) -> bool:
"""Returns True if we already track this absolute path with a HASHED asset and the cached mtime/size match."""
locator = os.path.abspath(file_path)
stmt = (
sa.select(sa.literal(True))
.select_from(AssetCacheState)
.join(Asset, Asset.id == AssetCacheState.asset_id)
.where(
AssetCacheState.file_path == locator,
Asset.hash.isnot(None),
AssetCacheState.needs_verify.is_(False),
)
.limit(1)
)
conds = []
if mtime_ns is not None:
conds.append(AssetCacheState.mtime_ns == int(mtime_ns))
if size_bytes is not None:
conds.append(sa.or_(Asset.size_bytes == 0, Asset.size_bytes == int(size_bytes)))
if conds:
stmt = stmt.where(*conds)
return (await session.execute(stmt)).first() is not None
async def redirect_all_references_then_delete_asset(
session: AsyncSession,
*,
duplicate_asset_id: str,
canonical_asset_id: str,
) -> None:
"""
Safely migrate all references from duplicate_asset_id to canonical_asset_id.
- If an AssetInfo for (owner_id, name) already exists on the canonical asset,
merge tags, metadata, times, and preview, then delete the duplicate AssetInfo.
- Otherwise, simply repoint the AssetInfo.asset_id.
- Always retarget AssetCacheState rows.
- Finally delete the duplicate Asset row.
"""
if duplicate_asset_id == canonical_asset_id:
return
# 1) Migrate AssetInfo rows one-by-one to avoid UNIQUE conflicts.
dup_infos = (
await session.execute(
select(AssetInfo).options(noload(AssetInfo.tags)).where(AssetInfo.asset_id == duplicate_asset_id)
)
).unique().scalars().all()
for info in dup_infos:
# Try to find an existing collision on canonical
existing = (
await session.execute(
select(AssetInfo)
.options(noload(AssetInfo.tags))
.where(
AssetInfo.asset_id == canonical_asset_id,
AssetInfo.owner_id == info.owner_id,
AssetInfo.name == info.name,
)
.limit(1)
)
).unique().scalars().first()
if existing:
merged_meta = dict(existing.user_metadata or {})
other_meta = info.user_metadata or {}
for k, v in other_meta.items():
if k not in merged_meta:
merged_meta[k] = v
if merged_meta != (existing.user_metadata or {}):
await replace_asset_info_metadata_projection(
session,
asset_info_id=existing.id,
user_metadata=merged_meta,
)
existing_tags = {
t for (t,) in (
await session.execute(
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == existing.id)
)
).all()
}
from_tags = {
t for (t,) in (
await session.execute(
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == info.id)
)
).all()
}
to_add = sorted(from_tags - existing_tags)
if to_add:
await ensure_tags_exist(session, to_add, tag_type="user")
now = utcnow()
session.add_all([
AssetInfoTag(asset_info_id=existing.id, tag_name=t, origin="automatic", added_at=now)
for t in to_add
])
await session.flush()
if existing.preview_id is None and info.preview_id is not None:
existing.preview_id = info.preview_id
if info.last_access_time and (
existing.last_access_time is None or info.last_access_time > existing.last_access_time
):
existing.last_access_time = info.last_access_time
existing.updated_at = utcnow()
await session.flush()
# Delete the duplicate AssetInfo (cascades will clean its tags/meta)
await session.delete(info)
await session.flush()
else:
# Simple retarget
info.asset_id = canonical_asset_id
info.updated_at = utcnow()
await session.flush()
# 2) Repoint cache states and previews
await session.execute(
sa.update(AssetCacheState)
.where(AssetCacheState.asset_id == duplicate_asset_id)
.values(asset_id=canonical_asset_id)
)
await session.execute(
sa.update(AssetInfo)
.where(AssetInfo.preview_id == duplicate_asset_id)
.values(preview_id=canonical_asset_id)
)
# 3) Remove duplicate Asset
dup = await session.get(Asset, duplicate_asset_id)
if dup:
await session.delete(dup)
await session.flush()
async def compute_hash_and_dedup_for_cache_state(
session: AsyncSession,
*,
state_id: int,
) -> Optional[str]:
"""
Compute hash for the given cache state, deduplicate, and settle verify cases.
Returns the asset_id that this state ends up pointing to, or None if file disappeared.
"""
state = await session.get(AssetCacheState, state_id)
if not state:
return None
path = state.file_path
try:
if not os.path.isfile(path):
# File vanished: drop the state. If the Asset has hash=NULL and has no other states, drop the Asset too.
asset = await session.get(Asset, state.asset_id)
await session.delete(state)
await session.flush()
if asset and asset.hash is None:
remaining = (
await session.execute(
sa.select(sa.func.count())
.select_from(AssetCacheState)
.where(AssetCacheState.asset_id == asset.id)
)
).scalar_one()
if int(remaining or 0) == 0:
await session.delete(asset)
await session.flush()
else:
await _recompute_and_apply_filename_for_asset(session, asset_id=asset.id)
return None
digest = await hashing_mod.blake3_hash(path)
new_hash = f"blake3:{digest}"
st = os.stat(path, follow_symlinks=True)
new_size = int(st.st_size)
mtime_ns = getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
# Current asset of this state
this_asset = await session.get(Asset, state.asset_id)
# If the state got orphaned somehow (race), just reattach appropriately.
if not this_asset:
canonical = (
await session.execute(sa.select(Asset).where(Asset.hash == new_hash).limit(1))
).scalars().first()
if canonical:
state.asset_id = canonical.id
else:
now = utcnow()
new_asset = Asset(hash=new_hash, size_bytes=new_size, mime_type=None, created_at=now)
session.add(new_asset)
await session.flush()
state.asset_id = new_asset.id
state.mtime_ns = mtime_ns
state.needs_verify = False
with contextlib.suppress(Exception):
await remove_missing_tag_for_asset_id(session, asset_id=state.asset_id)
await session.flush()
return state.asset_id
# 1) Seed asset case (hash is NULL): claim or merge into canonical
if this_asset.hash is None:
canonical = (
await session.execute(sa.select(Asset).where(Asset.hash == new_hash).limit(1))
).scalars().first()
if canonical and canonical.id != this_asset.id:
# Merge seed asset into canonical (safe, collision-aware)
await redirect_all_references_then_delete_asset(
session,
duplicate_asset_id=this_asset.id,
canonical_asset_id=canonical.id,
)
state = await session.get(AssetCacheState, state_id)
if state:
state.mtime_ns = mtime_ns
state.needs_verify = False
with contextlib.suppress(Exception):
await remove_missing_tag_for_asset_id(session, asset_id=canonical.id)
await _recompute_and_apply_filename_for_asset(session, asset_id=canonical.id)
await session.flush()
return canonical.id
# No canonical: try to claim the hash; handle races with a SAVEPOINT
try:
async with session.begin_nested():
this_asset.hash = new_hash
if int(this_asset.size_bytes or 0) == 0 and new_size > 0:
this_asset.size_bytes = new_size
await session.flush()
except IntegrityError:
# Someone else claimed it concurrently; fetch canonical and merge
canonical = (
await session.execute(sa.select(Asset).where(Asset.hash == new_hash).limit(1))
).scalars().first()
if canonical and canonical.id != this_asset.id:
await redirect_all_references_then_delete_asset(
session,
duplicate_asset_id=this_asset.id,
canonical_asset_id=canonical.id,
)
state = await session.get(AssetCacheState, state_id)
if state:
state.mtime_ns = mtime_ns
state.needs_verify = False
with contextlib.suppress(Exception):
await remove_missing_tag_for_asset_id(session, asset_id=canonical.id)
await _recompute_and_apply_filename_for_asset(session, asset_id=canonical.id)
await session.flush()
return canonical.id
# If we got here, the integrity error was not about hash uniqueness
raise
# Claimed successfully
state.mtime_ns = mtime_ns
state.needs_verify = False
with contextlib.suppress(Exception):
await remove_missing_tag_for_asset_id(session, asset_id=this_asset.id)
await _recompute_and_apply_filename_for_asset(session, asset_id=this_asset.id)
await session.flush()
return this_asset.id
# 2) Verify case for hashed assets
if this_asset.hash == new_hash:
if int(this_asset.size_bytes or 0) == 0 and new_size > 0:
this_asset.size_bytes = new_size
state.mtime_ns = mtime_ns
state.needs_verify = False
with contextlib.suppress(Exception):
await remove_missing_tag_for_asset_id(session, asset_id=this_asset.id)
await _recompute_and_apply_filename_for_asset(session, asset_id=this_asset.id)
await session.flush()
return this_asset.id
# Content changed on this path only: retarget THIS state, do not move AssetInfo rows
canonical = (
await session.execute(sa.select(Asset).where(Asset.hash == new_hash).limit(1))
).scalars().first()
if canonical:
target_id = canonical.id
else:
now = utcnow()
new_asset = Asset(hash=new_hash, size_bytes=new_size, mime_type=None, created_at=now)
session.add(new_asset)
await session.flush()
target_id = new_asset.id
state.asset_id = target_id
state.mtime_ns = mtime_ns
state.needs_verify = False
with contextlib.suppress(Exception):
await remove_missing_tag_for_asset_id(session, asset_id=target_id)
await _recompute_and_apply_filename_for_asset(session, asset_id=target_id)
await session.flush()
return target_id
except Exception:
raise
async def list_unhashed_candidates_under_prefixes(session: AsyncSession, *, prefixes: list[str]) -> list[int]:
if not prefixes:
return []
conds = []
for p in prefixes:
base = os.path.abspath(p)
if not base.endswith(os.sep):
base += os.sep
escaped, esc = escape_like_prefix(base)
conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc))
path_filter = sa.or_(*conds) if len(conds) > 1 else conds[0]
if session.bind.dialect.name == "postgresql":
stmt = (
sa.select(AssetCacheState.id)
.join(Asset, Asset.id == AssetCacheState.asset_id)
.where(Asset.hash.is_(None), path_filter)
.order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc())
.distinct(AssetCacheState.asset_id)
)
else:
first_id = sa.func.min(AssetCacheState.id).label("first_id")
stmt = (
sa.select(first_id)
.join(Asset, Asset.id == AssetCacheState.asset_id)
.where(Asset.hash.is_(None), path_filter)
.group_by(AssetCacheState.asset_id)
.order_by(first_id.asc())
)
return [int(x) for x in (await session.execute(stmt)).scalars().all()]
async def list_verify_candidates_under_prefixes(
session: AsyncSession, *, prefixes: Sequence[str]
) -> Union[list[int], Sequence[int]]:
if not prefixes:
return []
conds = []
for p in prefixes:
base = os.path.abspath(p)
if not base.endswith(os.sep):
base += os.sep
escaped, esc = escape_like_prefix(base)
conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc))
return (
await session.execute(
sa.select(AssetCacheState.id)
.where(AssetCacheState.needs_verify.is_(True))
.where(sa.or_(*conds))
.order_by(AssetCacheState.id.asc())
)
).scalars().all()
async def ingest_fs_asset(
session: AsyncSession,
*,
asset_hash: str,
abs_path: str,
size_bytes: int,
mtime_ns: int,
mime_type: Optional[str] = None,
info_name: Optional[str] = None,
owner_id: str = "",
preview_id: Optional[str] = None,
user_metadata: Optional[dict] = None,
tags: Sequence[str] = (),
tag_origin: str = "manual",
require_existing_tags: bool = False,
) -> dict:
"""
Idempotently upsert:
- Asset by content hash (create if missing)
- AssetCacheState(file_path) pointing to asset_id
- Optionally AssetInfo + tag links and metadata projection
Returns flags and ids.
"""
locator = os.path.abspath(abs_path)
now = utcnow()
dialect = session.bind.dialect.name
if preview_id:
if not await session.get(Asset, preview_id):
preview_id = None
out: dict[str, Any] = {
"asset_created": False,
"asset_updated": False,
"state_created": False,
"state_updated": False,
"asset_info_id": None,
}
# 1) Asset by hash
asset = (
await session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
).scalars().first()
if not asset:
vals = {
"hash": asset_hash,
"size_bytes": int(size_bytes),
"mime_type": mime_type,
"created_at": now,
}
if dialect == "sqlite":
res = await session.execute(
d_sqlite.insert(Asset)
.values(**vals)
.on_conflict_do_nothing(index_elements=[Asset.hash])
)
if int(res.rowcount or 0) > 0:
out["asset_created"] = True
asset = (
await session.execute(
select(Asset).where(Asset.hash == asset_hash).limit(1)
)
).scalars().first()
elif dialect == "postgresql":
res = await session.execute(
d_pg.insert(Asset)
.values(**vals)
.on_conflict_do_nothing(
index_elements=[Asset.hash],
index_where=Asset.__table__.c.hash.isnot(None),
)
.returning(Asset.id)
)
inserted_id = res.scalar_one_or_none()
if inserted_id:
out["asset_created"] = True
asset = await session.get(Asset, inserted_id)
else:
asset = (
await session.execute(
select(Asset).where(Asset.hash == asset_hash).limit(1)
)
).scalars().first()
else:
raise NotImplementedError(f"Unsupported database dialect: {dialect}")
if not asset:
raise RuntimeError("Asset row not found after upsert.")
else:
changed = False
if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0:
asset.size_bytes = int(size_bytes)
changed = True
if mime_type and asset.mime_type != mime_type:
asset.mime_type = mime_type
changed = True
if changed:
out["asset_updated"] = True
# 2) AssetCacheState upsert by file_path (unique)
vals = {
"asset_id": asset.id,
"file_path": locator,
"mtime_ns": int(mtime_ns),
}
if dialect == "sqlite":
ins = (
d_sqlite.insert(AssetCacheState)
.values(**vals)
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
)
elif dialect == "postgresql":
ins = (
d_pg.insert(AssetCacheState)
.values(**vals)
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
)
else:
raise NotImplementedError(f"Unsupported database dialect: {dialect}")
res = await session.execute(ins)
if int(res.rowcount or 0) > 0:
out["state_created"] = True
else:
upd = (
sa.update(AssetCacheState)
.where(AssetCacheState.file_path == locator)
.where(
sa.or_(
AssetCacheState.asset_id != asset.id,
AssetCacheState.mtime_ns.is_(None),
AssetCacheState.mtime_ns != int(mtime_ns),
)
)
.values(asset_id=asset.id, mtime_ns=int(mtime_ns))
)
res2 = await session.execute(upd)
if int(res2.rowcount or 0) > 0:
out["state_updated"] = True
# 3) Optional AssetInfo + tags + metadata
if info_name:
try:
async with session.begin_nested():
info = AssetInfo(
owner_id=owner_id,
name=info_name,
asset_id=asset.id,
preview_id=preview_id,
created_at=now,
updated_at=now,
last_access_time=now,
)
session.add(info)
await session.flush()
out["asset_info_id"] = info.id
except IntegrityError:
pass
existing_info = (
await session.execute(
select(AssetInfo)
.where(
AssetInfo.asset_id == asset.id,
AssetInfo.name == info_name,
(AssetInfo.owner_id == owner_id),
)
.limit(1)
)
).unique().scalar_one_or_none()
if not existing_info:
raise RuntimeError("Failed to update or insert AssetInfo.")
if preview_id and existing_info.preview_id != preview_id:
existing_info.preview_id = preview_id
existing_info.updated_at = now
if existing_info.last_access_time < now:
existing_info.last_access_time = now
await session.flush()
out["asset_info_id"] = existing_info.id
norm = [t.strip().lower() for t in (tags or []) if (t or "").strip()]
if norm and out["asset_info_id"] is not None:
if not require_existing_tags:
await ensure_tags_exist(session, norm, tag_type="user")
existing_tag_names = set(
name for (name,) in (await session.execute(select(Tag.name).where(Tag.name.in_(norm)))).all()
)
missing = [t for t in norm if t not in existing_tag_names]
if missing and require_existing_tags:
raise ValueError(f"Unknown tags: {missing}")
existing_links = set(
tag_name
for (tag_name,) in (
await session.execute(
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == out["asset_info_id"])
)
).all()
)
to_add = [t for t in norm if t in existing_tag_names and t not in existing_links]
if to_add:
session.add_all(
[
AssetInfoTag(
asset_info_id=out["asset_info_id"],
tag_name=t,
origin=tag_origin,
added_at=now,
)
for t in to_add
]
)
await session.flush()
# metadata["filename"] hack
if out["asset_info_id"] is not None:
primary_path = pick_best_live_path(await list_cache_states_by_asset_id(session, asset_id=asset.id))
computed_filename = compute_relative_filename(primary_path) if primary_path else None
current_meta = existing_info.user_metadata or {}
new_meta = dict(current_meta)
if user_metadata is not None:
for k, v in user_metadata.items():
new_meta[k] = v
if computed_filename:
new_meta["filename"] = computed_filename
if new_meta != current_meta:
await replace_asset_info_metadata_projection(
session,
asset_info_id=out["asset_info_id"],
user_metadata=new_meta,
)
try:
await remove_missing_tag_for_asset_id(session, asset_id=asset.id)
except Exception:
logging.exception("Failed to clear 'missing' tag for asset %s", asset.id)
return out
async def touch_asset_infos_by_fs_path(
session: AsyncSession,
*,
file_path: str,
ts: Optional[datetime] = None,
only_if_newer: bool = True,
) -> None:
locator = os.path.abspath(file_path)
ts = ts or utcnow()
stmt = sa.update(AssetInfo).where(
sa.exists(
sa.select(sa.literal(1))
.select_from(AssetCacheState)
.where(
AssetCacheState.asset_id == AssetInfo.asset_id,
AssetCacheState.file_path == locator,
)
)
)
if only_if_newer:
stmt = stmt.where(
sa.or_(
AssetInfo.last_access_time.is_(None),
AssetInfo.last_access_time < ts,
)
)
await session.execute(stmt.values(last_access_time=ts))
async def list_cache_states_with_asset_under_prefixes(
session: AsyncSession,
*,
prefixes: Sequence[str],
) -> list[tuple[AssetCacheState, Optional[str], int]]:
"""Return (AssetCacheState, asset_hash, size_bytes) for rows under any prefix."""
if not prefixes:
return []
conds = []
for p in prefixes:
if not p:
continue
base = os.path.abspath(p)
if not base.endswith(os.sep):
base = base + os.sep
escaped, esc = escape_like_prefix(base)
conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc))
if not conds:
return []
rows = (
await session.execute(
select(AssetCacheState, Asset.hash, Asset.size_bytes)
.join(Asset, Asset.id == AssetCacheState.asset_id)
.where(sa.or_(*conds))
.order_by(AssetCacheState.id.asc())
)
).all()
return [(r[0], r[1], int(r[2] or 0)) for r in rows]
async def _recompute_and_apply_filename_for_asset(session: AsyncSession, *, asset_id: str) -> None:
"""Compute filename from the first *existing* cache state path and apply it to all AssetInfo (if changed)."""
try:
primary_path = pick_best_live_path(await list_cache_states_by_asset_id(session, asset_id=asset_id))
if not primary_path:
return
new_filename = compute_relative_filename(primary_path)
if not new_filename:
return
infos = (
await session.execute(select(AssetInfo).where(AssetInfo.asset_id == asset_id))
).scalars().all()
for info in infos:
current_meta = info.user_metadata or {}
if current_meta.get("filename") == new_filename:
continue
updated = dict(current_meta)
updated["filename"] = new_filename
await replace_asset_info_metadata_projection(session, asset_info_id=info.id, user_metadata=updated)
except Exception:
logging.exception("Failed to recompute filename metadata for asset %s", asset_id)

View File

@@ -1,586 +0,0 @@
from collections import defaultdict
from datetime import datetime
from typing import Any, Optional, Sequence
import sqlalchemy as sa
from sqlalchemy import delete, func, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import contains_eager, noload
from ..._helpers import compute_relative_filename, normalize_tags
from ..helpers import (
apply_metadata_filter,
apply_tag_filters,
ensure_tags_exist,
escape_like_prefix,
project_kv,
visible_owner_clause,
)
from ..models import Asset, AssetInfo, AssetInfoMeta, AssetInfoTag, Tag
from ..timeutil import utcnow
from .queries import (
get_asset_by_hash,
list_cache_states_by_asset_id,
pick_best_live_path,
)
async def list_asset_infos_page(
session: AsyncSession,
*,
owner_id: str = "",
include_tags: Optional[Sequence[str]] = None,
exclude_tags: Optional[Sequence[str]] = None,
name_contains: Optional[str] = None,
metadata_filter: Optional[dict] = None,
limit: int = 20,
offset: int = 0,
sort: str = "created_at",
order: str = "desc",
) -> tuple[list[AssetInfo], dict[str, list[str]], int]:
base = (
select(AssetInfo)
.join(Asset, Asset.id == AssetInfo.asset_id)
.options(contains_eager(AssetInfo.asset), noload(AssetInfo.tags))
.where(visible_owner_clause(owner_id))
)
if name_contains:
escaped, esc = escape_like_prefix(name_contains)
base = base.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc))
base = apply_tag_filters(base, include_tags, exclude_tags)
base = apply_metadata_filter(base, metadata_filter)
sort = (sort or "created_at").lower()
order = (order or "desc").lower()
sort_map = {
"name": AssetInfo.name,
"created_at": AssetInfo.created_at,
"updated_at": AssetInfo.updated_at,
"last_access_time": AssetInfo.last_access_time,
"size": Asset.size_bytes,
}
sort_col = sort_map.get(sort, AssetInfo.created_at)
sort_exp = sort_col.desc() if order == "desc" else sort_col.asc()
base = base.order_by(sort_exp).limit(limit).offset(offset)
count_stmt = (
select(func.count())
.select_from(AssetInfo)
.join(Asset, Asset.id == AssetInfo.asset_id)
.where(visible_owner_clause(owner_id))
)
if name_contains:
escaped, esc = escape_like_prefix(name_contains)
count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc))
count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags)
count_stmt = apply_metadata_filter(count_stmt, metadata_filter)
total = int((await session.execute(count_stmt)).scalar_one() or 0)
infos = (await session.execute(base)).unique().scalars().all()
id_list: list[str] = [i.id for i in infos]
tag_map: dict[str, list[str]] = defaultdict(list)
if id_list:
rows = await session.execute(
select(AssetInfoTag.asset_info_id, Tag.name)
.join(Tag, Tag.name == AssetInfoTag.tag_name)
.where(AssetInfoTag.asset_info_id.in_(id_list))
)
for aid, tag_name in rows.all():
tag_map[aid].append(tag_name)
return infos, tag_map, total
async def fetch_asset_info_and_asset(
session: AsyncSession,
*,
asset_info_id: str,
owner_id: str = "",
) -> Optional[tuple[AssetInfo, Asset]]:
stmt = (
select(AssetInfo, Asset)
.join(Asset, Asset.id == AssetInfo.asset_id)
.where(
AssetInfo.id == asset_info_id,
visible_owner_clause(owner_id),
)
.limit(1)
.options(noload(AssetInfo.tags))
)
row = await session.execute(stmt)
pair = row.first()
if not pair:
return None
return pair[0], pair[1]
async def fetch_asset_info_asset_and_tags(
session: AsyncSession,
*,
asset_info_id: str,
owner_id: str = "",
) -> Optional[tuple[AssetInfo, Asset, list[str]]]:
stmt = (
select(AssetInfo, Asset, Tag.name)
.join(Asset, Asset.id == AssetInfo.asset_id)
.join(AssetInfoTag, AssetInfoTag.asset_info_id == AssetInfo.id, isouter=True)
.join(Tag, Tag.name == AssetInfoTag.tag_name, isouter=True)
.where(
AssetInfo.id == asset_info_id,
visible_owner_clause(owner_id),
)
.options(noload(AssetInfo.tags))
.order_by(Tag.name.asc())
)
rows = (await session.execute(stmt)).all()
if not rows:
return None
first_info, first_asset, _ = rows[0]
tags: list[str] = []
seen: set[str] = set()
for _info, _asset, tag_name in rows:
if tag_name and tag_name not in seen:
seen.add(tag_name)
tags.append(tag_name)
return first_info, first_asset, tags
async def create_asset_info_for_existing_asset(
session: AsyncSession,
*,
asset_hash: str,
name: str,
user_metadata: Optional[dict] = None,
tags: Optional[Sequence[str]] = None,
tag_origin: str = "manual",
owner_id: str = "",
) -> AssetInfo:
"""Create or return an existing AssetInfo for an Asset identified by asset_hash."""
now = utcnow()
asset = await get_asset_by_hash(session, asset_hash=asset_hash)
if not asset:
raise ValueError(f"Unknown asset hash {asset_hash}")
info = AssetInfo(
owner_id=owner_id,
name=name,
asset_id=asset.id,
preview_id=None,
created_at=now,
updated_at=now,
last_access_time=now,
)
try:
async with session.begin_nested():
session.add(info)
await session.flush()
except IntegrityError:
existing = (
await session.execute(
select(AssetInfo)
.options(noload(AssetInfo.tags))
.where(
AssetInfo.asset_id == asset.id,
AssetInfo.name == name,
AssetInfo.owner_id == owner_id,
)
.limit(1)
)
).unique().scalars().first()
if not existing:
raise RuntimeError("AssetInfo upsert failed to find existing row after conflict.")
return existing
# metadata["filename"] hack
new_meta = dict(user_metadata or {})
computed_filename = None
try:
p = pick_best_live_path(await list_cache_states_by_asset_id(session, asset_id=asset.id))
if p:
computed_filename = compute_relative_filename(p)
except Exception:
computed_filename = None
if computed_filename:
new_meta["filename"] = computed_filename
if new_meta:
await replace_asset_info_metadata_projection(
session,
asset_info_id=info.id,
user_metadata=new_meta,
)
if tags is not None:
await set_asset_info_tags(
session,
asset_info_id=info.id,
tags=tags,
origin=tag_origin,
)
return info
async def set_asset_info_tags(
session: AsyncSession,
*,
asset_info_id: str,
tags: Sequence[str],
origin: str = "manual",
) -> dict:
desired = normalize_tags(tags)
current = set(
tag_name for (tag_name,) in (
await session.execute(select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id))
).all()
)
to_add = [t for t in desired if t not in current]
to_remove = [t for t in current if t not in desired]
if to_add:
await ensure_tags_exist(session, to_add, tag_type="user")
session.add_all([
AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_at=utcnow())
for t in to_add
])
await session.flush()
if to_remove:
await session.execute(
delete(AssetInfoTag)
.where(AssetInfoTag.asset_info_id == asset_info_id, AssetInfoTag.tag_name.in_(to_remove))
)
await session.flush()
return {"added": to_add, "removed": to_remove, "total": desired}
async def update_asset_info_full(
session: AsyncSession,
*,
asset_info_id: str,
name: Optional[str] = None,
tags: Optional[Sequence[str]] = None,
user_metadata: Optional[dict] = None,
tag_origin: str = "manual",
asset_info_row: Any = None,
) -> AssetInfo:
if not asset_info_row:
info = await session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
else:
info = asset_info_row
touched = False
if name is not None and name != info.name:
info.name = name
touched = True
computed_filename = None
try:
p = pick_best_live_path(await list_cache_states_by_asset_id(session, asset_id=info.asset_id))
if p:
computed_filename = compute_relative_filename(p)
except Exception:
computed_filename = None
if user_metadata is not None:
new_meta = dict(user_metadata)
if computed_filename:
new_meta["filename"] = computed_filename
await replace_asset_info_metadata_projection(
session, asset_info_id=asset_info_id, user_metadata=new_meta
)
touched = True
else:
if computed_filename:
current_meta = info.user_metadata or {}
if current_meta.get("filename") != computed_filename:
new_meta = dict(current_meta)
new_meta["filename"] = computed_filename
await replace_asset_info_metadata_projection(
session, asset_info_id=asset_info_id, user_metadata=new_meta
)
touched = True
if tags is not None:
await set_asset_info_tags(
session,
asset_info_id=asset_info_id,
tags=tags,
origin=tag_origin,
)
touched = True
if touched and user_metadata is None:
info.updated_at = utcnow()
await session.flush()
return info
async def replace_asset_info_metadata_projection(
session: AsyncSession,
*,
asset_info_id: str,
user_metadata: Optional[dict],
) -> None:
info = await session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
info.user_metadata = user_metadata or {}
info.updated_at = utcnow()
await session.flush()
await session.execute(delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id))
await session.flush()
if not user_metadata:
return
rows: list[AssetInfoMeta] = []
for k, v in user_metadata.items():
for r in project_kv(k, v):
rows.append(
AssetInfoMeta(
asset_info_id=asset_info_id,
key=r["key"],
ordinal=int(r["ordinal"]),
val_str=r.get("val_str"),
val_num=r.get("val_num"),
val_bool=r.get("val_bool"),
val_json=r.get("val_json"),
)
)
if rows:
session.add_all(rows)
await session.flush()
async def touch_asset_info_by_id(
session: AsyncSession,
*,
asset_info_id: str,
ts: Optional[datetime] = None,
only_if_newer: bool = True,
) -> None:
ts = ts or utcnow()
stmt = sa.update(AssetInfo).where(AssetInfo.id == asset_info_id)
if only_if_newer:
stmt = stmt.where(
sa.or_(AssetInfo.last_access_time.is_(None), AssetInfo.last_access_time < ts)
)
await session.execute(stmt.values(last_access_time=ts))
async def delete_asset_info_by_id(session: AsyncSession, *, asset_info_id: str, owner_id: str) -> bool:
stmt = sa.delete(AssetInfo).where(
AssetInfo.id == asset_info_id,
visible_owner_clause(owner_id),
)
return int((await session.execute(stmt)).rowcount or 0) > 0
async def add_tags_to_asset_info(
session: AsyncSession,
*,
asset_info_id: str,
tags: Sequence[str],
origin: str = "manual",
create_if_missing: bool = True,
asset_info_row: Any = None,
) -> dict:
if not asset_info_row:
info = await session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
norm = normalize_tags(tags)
if not norm:
total = await get_asset_tags(session, asset_info_id=asset_info_id)
return {"added": [], "already_present": [], "total_tags": total}
if create_if_missing:
await ensure_tags_exist(session, norm, tag_type="user")
current = {
tag_name
for (tag_name,) in (
await session.execute(
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
)
).all()
}
want = set(norm)
to_add = sorted(want - current)
if to_add:
async with session.begin_nested() as nested:
try:
session.add_all(
[
AssetInfoTag(
asset_info_id=asset_info_id,
tag_name=t,
origin=origin,
added_at=utcnow(),
)
for t in to_add
]
)
await session.flush()
except IntegrityError:
await nested.rollback()
after = set(await get_asset_tags(session, asset_info_id=asset_info_id))
return {
"added": sorted(((after - current) & want)),
"already_present": sorted(want & current),
"total_tags": sorted(after),
}
async def remove_tags_from_asset_info(
session: AsyncSession,
*,
asset_info_id: str,
tags: Sequence[str],
) -> dict:
info = await session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
norm = normalize_tags(tags)
if not norm:
total = await get_asset_tags(session, asset_info_id=asset_info_id)
return {"removed": [], "not_present": [], "total_tags": total}
existing = {
tag_name
for (tag_name,) in (
await session.execute(
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
)
).all()
}
to_remove = sorted(set(t for t in norm if t in existing))
not_present = sorted(set(t for t in norm if t not in existing))
if to_remove:
await session.execute(
delete(AssetInfoTag)
.where(
AssetInfoTag.asset_info_id == asset_info_id,
AssetInfoTag.tag_name.in_(to_remove),
)
)
await session.flush()
total = await get_asset_tags(session, asset_info_id=asset_info_id)
return {"removed": to_remove, "not_present": not_present, "total_tags": total}
async def list_tags_with_usage(
session: AsyncSession,
*,
prefix: Optional[str] = None,
limit: int = 100,
offset: int = 0,
include_zero: bool = True,
order: str = "count_desc",
owner_id: str = "",
) -> tuple[list[tuple[str, str, int]], int]:
counts_sq = (
select(
AssetInfoTag.tag_name.label("tag_name"),
func.count(AssetInfoTag.asset_info_id).label("cnt"),
)
.select_from(AssetInfoTag)
.join(AssetInfo, AssetInfo.id == AssetInfoTag.asset_info_id)
.where(visible_owner_clause(owner_id))
.group_by(AssetInfoTag.tag_name)
.subquery()
)
q = (
select(
Tag.name,
Tag.tag_type,
func.coalesce(counts_sq.c.cnt, 0).label("count"),
)
.select_from(Tag)
.join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True)
)
if prefix:
escaped, esc = escape_like_prefix(prefix.strip().lower())
q = q.where(Tag.name.like(escaped + "%", escape=esc))
if not include_zero:
q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0)
if order == "name_asc":
q = q.order_by(Tag.name.asc())
else:
q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc())
total_q = select(func.count()).select_from(Tag)
if prefix:
escaped, esc = escape_like_prefix(prefix.strip().lower())
total_q = total_q.where(Tag.name.like(escaped + "%", escape=esc))
if not include_zero:
total_q = total_q.where(
Tag.name.in_(select(AssetInfoTag.tag_name).group_by(AssetInfoTag.tag_name))
)
rows = (await session.execute(q.limit(limit).offset(offset))).all()
total = (await session.execute(total_q)).scalar_one()
rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
return rows_norm, int(total or 0)
async def get_asset_tags(session: AsyncSession, *, asset_info_id: str) -> list[str]:
return [
tag_name
for (tag_name,) in (
await session.execute(
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
)
).all()
]
async def set_asset_info_preview(
session: AsyncSession,
*,
asset_info_id: str,
preview_asset_id: Optional[str],
) -> None:
"""Set or clear preview_id and bump updated_at. Raises on unknown IDs."""
info = await session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if preview_asset_id is None:
info.preview_id = None
else:
# validate preview asset exists
if not await session.get(Asset, preview_asset_id):
raise ValueError(f"Preview Asset {preview_asset_id} not found")
info.preview_id = preview_asset_id
info.updated_at = utcnow()
await session.flush()

View File

@@ -1,76 +0,0 @@
import os
from typing import Optional, Sequence, Union
import sqlalchemy as sa
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from ..models import Asset, AssetCacheState, AssetInfo
async def asset_exists_by_hash(session: AsyncSession, *, asset_hash: str) -> bool:
row = (
await session.execute(
select(sa.literal(True)).select_from(Asset).where(Asset.hash == asset_hash).limit(1)
)
).first()
return row is not None
async def get_asset_by_hash(session: AsyncSession, *, asset_hash: str) -> Optional[Asset]:
return (
await session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
).scalars().first()
async def get_asset_info_by_id(session: AsyncSession, *, asset_info_id: str) -> Optional[AssetInfo]:
return await session.get(AssetInfo, asset_info_id)
async def asset_info_exists_for_asset_id(session: AsyncSession, *, asset_id: str) -> bool:
q = (
select(sa.literal(True))
.select_from(AssetInfo)
.where(AssetInfo.asset_id == asset_id)
.limit(1)
)
return (await session.execute(q)).first() is not None
async def get_cache_state_by_asset_id(session: AsyncSession, *, asset_id: str) -> Optional[AssetCacheState]:
return (
await session.execute(
select(AssetCacheState)
.where(AssetCacheState.asset_id == asset_id)
.order_by(AssetCacheState.id.asc())
.limit(1)
)
).scalars().first()
async def list_cache_states_by_asset_id(
session: AsyncSession, *, asset_id: str
) -> Union[list[AssetCacheState], Sequence[AssetCacheState]]:
return (
await session.execute(
select(AssetCacheState)
.where(AssetCacheState.asset_id == asset_id)
.order_by(AssetCacheState.id.asc())
)
).scalars().all()
def pick_best_live_path(states: Union[list[AssetCacheState], Sequence[AssetCacheState]]) -> str:
"""
Return the best on-disk path among cache states:
1) Prefer a path that exists with needs_verify == False (already verified).
2) Otherwise, pick the first path that exists.
3) Otherwise return empty string.
"""
alive = [s for s in states if getattr(s, "file_path", None) and os.path.isfile(s.file_path)]
if not alive:
return ""
for s in alive:
if not getattr(s, "needs_verify", False):
return s.file_path
return alive[0].file_path

View File

@@ -1,6 +0,0 @@
from datetime import datetime, timezone
def utcnow() -> datetime:
"""Naive UTC timestamp (no tzinfo). We always treat DB datetimes as UTC."""
return datetime.now(timezone.utc).replace(tzinfo=None)

View File

@@ -1,556 +0,0 @@
import contextlib
import logging
import mimetypes
import os
from typing import Optional, Sequence
from comfy_api.internal import async_to_sync
from ..db import create_session
from ._helpers import (
ensure_within_base,
get_name_and_tags_from_asset_path,
resolve_destination_from_tags,
)
from .api import schemas_in, schemas_out
from .database.models import Asset
from .database.services import (
add_tags_to_asset_info,
asset_exists_by_hash,
asset_info_exists_for_asset_id,
check_fs_asset_exists_quick,
create_asset_info_for_existing_asset,
delete_asset_info_by_id,
fetch_asset_info_and_asset,
fetch_asset_info_asset_and_tags,
get_asset_by_hash,
get_asset_info_by_id,
get_asset_tags,
ingest_fs_asset,
list_asset_infos_page,
list_cache_states_by_asset_id,
list_tags_with_usage,
pick_best_live_path,
remove_tags_from_asset_info,
set_asset_info_preview,
touch_asset_info_by_id,
touch_asset_infos_by_fs_path,
update_asset_info_full,
)
from .storage import hashing
async def asset_exists(*, asset_hash: str) -> bool:
async with await create_session() as session:
return await asset_exists_by_hash(session, asset_hash=asset_hash)
def populate_db_with_asset(file_path: str, tags: Optional[list[str]] = None) -> None:
if tags is None:
tags = []
try:
asset_name, path_tags = get_name_and_tags_from_asset_path(file_path)
async_to_sync.AsyncToSyncConverter.run_async_in_thread(
add_local_asset,
tags=list(dict.fromkeys([*path_tags, *tags])),
file_name=asset_name,
file_path=file_path,
)
except ValueError as e:
logging.warning("Skipping non-asset path %s: %s", file_path, e)
async def add_local_asset(tags: list[str], file_name: str, file_path: str) -> None:
abs_path = os.path.abspath(file_path)
size_bytes, mtime_ns = _get_size_mtime_ns(abs_path)
if not size_bytes:
return
async with await create_session() as session:
if await check_fs_asset_exists_quick(session, file_path=abs_path, size_bytes=size_bytes, mtime_ns=mtime_ns):
await touch_asset_infos_by_fs_path(session, file_path=abs_path)
await session.commit()
return
asset_hash = hashing.blake3_hash_sync(abs_path)
async with await create_session() as session:
await ingest_fs_asset(
session,
asset_hash="blake3:" + asset_hash,
abs_path=abs_path,
size_bytes=size_bytes,
mtime_ns=mtime_ns,
mime_type=None,
info_name=file_name,
tag_origin="automatic",
tags=tags,
)
await session.commit()
async def list_assets(
*,
include_tags: Optional[Sequence[str]] = None,
exclude_tags: Optional[Sequence[str]] = None,
name_contains: Optional[str] = None,
metadata_filter: Optional[dict] = None,
limit: int = 20,
offset: int = 0,
sort: str = "created_at",
order: str = "desc",
owner_id: str = "",
) -> schemas_out.AssetsList:
sort = _safe_sort_field(sort)
order = "desc" if (order or "desc").lower() not in {"asc", "desc"} else order.lower()
async with await create_session() as session:
infos, tag_map, total = await list_asset_infos_page(
session,
owner_id=owner_id,
include_tags=include_tags,
exclude_tags=exclude_tags,
name_contains=name_contains,
metadata_filter=metadata_filter,
limit=limit,
offset=offset,
sort=sort,
order=order,
)
summaries: list[schemas_out.AssetSummary] = []
for info in infos:
asset = info.asset
tags = tag_map.get(info.id, [])
summaries.append(
schemas_out.AssetSummary(
id=info.id,
name=info.name,
asset_hash=asset.hash if asset else None,
size=int(asset.size_bytes) if asset else None,
mime_type=asset.mime_type if asset else None,
tags=tags,
preview_url=f"/api/assets/{info.id}/content",
created_at=info.created_at,
updated_at=info.updated_at,
last_access_time=info.last_access_time,
)
)
return schemas_out.AssetsList(
assets=summaries,
total=total,
has_more=(offset + len(summaries)) < total,
)
async def get_asset(*, asset_info_id: str, owner_id: str = "") -> schemas_out.AssetDetail:
async with await create_session() as session:
res = await fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
if not res:
raise ValueError(f"AssetInfo {asset_info_id} not found")
info, asset, tag_names = res
preview_id = info.preview_id
return schemas_out.AssetDetail(
id=info.id,
name=info.name,
asset_hash=asset.hash if asset else None,
size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None,
mime_type=asset.mime_type if asset else None,
tags=tag_names,
user_metadata=info.user_metadata or {},
preview_id=preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
)
async def resolve_asset_content_for_download(
*,
asset_info_id: str,
owner_id: str = "",
) -> tuple[str, str, str]:
async with await create_session() as session:
pair = await fetch_asset_info_and_asset(session, asset_info_id=asset_info_id, owner_id=owner_id)
if not pair:
raise ValueError(f"AssetInfo {asset_info_id} not found")
info, asset = pair
states = await list_cache_states_by_asset_id(session, asset_id=asset.id)
abs_path = pick_best_live_path(states)
if not abs_path:
raise FileNotFoundError
await touch_asset_info_by_id(session, asset_info_id=asset_info_id)
await session.commit()
ctype = asset.mime_type or mimetypes.guess_type(info.name or abs_path)[0] or "application/octet-stream"
download_name = info.name or os.path.basename(abs_path)
return abs_path, ctype, download_name
async def upload_asset_from_temp_path(
spec: schemas_in.UploadAssetSpec,
*,
temp_path: str,
client_filename: Optional[str] = None,
owner_id: str = "",
expected_asset_hash: Optional[str] = None,
) -> schemas_out.AssetCreated:
try:
digest = await hashing.blake3_hash(temp_path)
except Exception as e:
raise RuntimeError(f"failed to hash uploaded file: {e}")
asset_hash = "blake3:" + digest
if expected_asset_hash and asset_hash != expected_asset_hash.strip().lower():
raise ValueError("HASH_MISMATCH")
async with await create_session() as session:
existing = await get_asset_by_hash(session, asset_hash=asset_hash)
if existing is not None:
with contextlib.suppress(Exception):
if temp_path and os.path.exists(temp_path):
os.remove(temp_path)
display_name = _safe_filename(spec.name or (client_filename or ""), fallback=digest)
info = await create_asset_info_for_existing_asset(
session,
asset_hash=asset_hash,
name=display_name,
user_metadata=spec.user_metadata or {},
tags=spec.tags or [],
tag_origin="manual",
owner_id=owner_id,
)
tag_names = await get_asset_tags(session, asset_info_id=info.id)
await session.commit()
return schemas_out.AssetCreated(
id=info.id,
name=info.name,
asset_hash=existing.hash,
size=int(existing.size_bytes) if existing.size_bytes is not None else None,
mime_type=existing.mime_type,
tags=tag_names,
user_metadata=info.user_metadata or {},
preview_id=info.preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
created_new=False,
)
base_dir, subdirs = resolve_destination_from_tags(spec.tags)
dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir
os.makedirs(dest_dir, exist_ok=True)
src_for_ext = (client_filename or spec.name or "").strip()
_ext = os.path.splitext(os.path.basename(src_for_ext))[1] if src_for_ext else ""
ext = _ext if 0 < len(_ext) <= 16 else ""
hashed_basename = f"{digest}{ext}"
dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename))
ensure_within_base(dest_abs, base_dir)
content_type = (
mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0]
or mimetypes.guess_type(hashed_basename, strict=False)[0]
or "application/octet-stream"
)
try:
os.replace(temp_path, dest_abs)
except Exception as e:
raise RuntimeError(f"failed to move uploaded file into place: {e}")
try:
size_bytes, mtime_ns = _get_size_mtime_ns(dest_abs)
except OSError as e:
raise RuntimeError(f"failed to stat destination file: {e}")
async with await create_session() as session:
result = await ingest_fs_asset(
session,
asset_hash=asset_hash,
abs_path=dest_abs,
size_bytes=size_bytes,
mtime_ns=mtime_ns,
mime_type=content_type,
info_name=_safe_filename(spec.name or (client_filename or ""), fallback=digest),
owner_id=owner_id,
preview_id=None,
user_metadata=spec.user_metadata or {},
tags=spec.tags,
tag_origin="manual",
require_existing_tags=False,
)
info_id = result["asset_info_id"]
if not info_id:
raise RuntimeError("failed to create asset metadata")
pair = await fetch_asset_info_and_asset(session, asset_info_id=info_id, owner_id=owner_id)
if not pair:
raise RuntimeError("inconsistent DB state after ingest")
info, asset = pair
tag_names = await get_asset_tags(session, asset_info_id=info.id)
await session.commit()
return schemas_out.AssetCreated(
id=info.id,
name=info.name,
asset_hash=asset.hash,
size=int(asset.size_bytes),
mime_type=asset.mime_type,
tags=tag_names,
user_metadata=info.user_metadata or {},
preview_id=info.preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
created_new=result["asset_created"],
)
async def update_asset(
*,
asset_info_id: str,
name: Optional[str] = None,
tags: Optional[list[str]] = None,
user_metadata: Optional[dict] = None,
owner_id: str = "",
) -> schemas_out.AssetUpdated:
async with await create_session() as session:
info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
info = await update_asset_info_full(
session,
asset_info_id=asset_info_id,
name=name,
tags=tags,
user_metadata=user_metadata,
tag_origin="manual",
asset_info_row=info_row,
)
tag_names = await get_asset_tags(session, asset_info_id=asset_info_id)
await session.commit()
return schemas_out.AssetUpdated(
id=info.id,
name=info.name,
asset_hash=info.asset.hash if info.asset else None,
tags=tag_names,
user_metadata=info.user_metadata or {},
updated_at=info.updated_at,
)
async def set_asset_preview(
*,
asset_info_id: str,
preview_asset_id: Optional[str],
owner_id: str = "",
) -> schemas_out.AssetDetail:
async with await create_session() as session:
info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
await set_asset_info_preview(
session,
asset_info_id=asset_info_id,
preview_asset_id=preview_asset_id,
)
res = await fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
if not res:
raise RuntimeError("State changed during preview update")
info, asset, tags = res
await session.commit()
return schemas_out.AssetDetail(
id=info.id,
name=info.name,
asset_hash=asset.hash if asset else None,
size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None,
mime_type=asset.mime_type if asset else None,
tags=tags,
user_metadata=info.user_metadata or {},
preview_id=info.preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
)
async def delete_asset_reference(*, asset_info_id: str, owner_id: str, delete_content_if_orphan: bool = True) -> bool:
async with await create_session() as session:
info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id)
asset_id = info_row.asset_id if info_row else None
deleted = await delete_asset_info_by_id(session, asset_info_id=asset_info_id, owner_id=owner_id)
if not deleted:
await session.commit()
return False
if not delete_content_if_orphan or not asset_id:
await session.commit()
return True
still_exists = await asset_info_exists_for_asset_id(session, asset_id=asset_id)
if still_exists:
await session.commit()
return True
states = await list_cache_states_by_asset_id(session, asset_id=asset_id)
file_paths = [s.file_path for s in (states or []) if getattr(s, "file_path", None)]
asset_row = await session.get(Asset, asset_id)
if asset_row is not None:
await session.delete(asset_row)
await session.commit()
for p in file_paths:
with contextlib.suppress(Exception):
if p and os.path.isfile(p):
os.remove(p)
return True
async def create_asset_from_hash(
*,
hash_str: str,
name: str,
tags: Optional[list[str]] = None,
user_metadata: Optional[dict] = None,
owner_id: str = "",
) -> Optional[schemas_out.AssetCreated]:
canonical = hash_str.strip().lower()
async with await create_session() as session:
asset = await get_asset_by_hash(session, asset_hash=canonical)
if not asset:
return None
info = await create_asset_info_for_existing_asset(
session,
asset_hash=canonical,
name=_safe_filename(name, fallback=canonical.split(":", 1)[1]),
user_metadata=user_metadata or {},
tags=tags or [],
tag_origin="manual",
owner_id=owner_id,
)
tag_names = await get_asset_tags(session, asset_info_id=info.id)
await session.commit()
return schemas_out.AssetCreated(
id=info.id,
name=info.name,
asset_hash=asset.hash,
size=int(asset.size_bytes),
mime_type=asset.mime_type,
tags=tag_names,
user_metadata=info.user_metadata or {},
preview_id=info.preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
created_new=False,
)
async def list_tags(
*,
prefix: Optional[str] = None,
limit: int = 100,
offset: int = 0,
order: str = "count_desc",
include_zero: bool = True,
owner_id: str = "",
) -> schemas_out.TagsList:
limit = max(1, min(1000, limit))
offset = max(0, offset)
async with await create_session() as session:
rows, total = await list_tags_with_usage(
session,
prefix=prefix,
limit=limit,
offset=offset,
include_zero=include_zero,
order=order,
owner_id=owner_id,
)
tags = [schemas_out.TagUsage(name=name, count=count, type=tag_type) for (name, tag_type, count) in rows]
return schemas_out.TagsList(tags=tags, total=total, has_more=(offset + len(tags)) < total)
async def add_tags_to_asset(
*,
asset_info_id: str,
tags: list[str],
origin: str = "manual",
owner_id: str = "",
) -> schemas_out.TagsAdd:
async with await create_session() as session:
info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
data = await add_tags_to_asset_info(
session,
asset_info_id=asset_info_id,
tags=tags,
origin=origin,
create_if_missing=True,
asset_info_row=info_row,
)
await session.commit()
return schemas_out.TagsAdd(**data)
async def remove_tags_from_asset(
*,
asset_info_id: str,
tags: list[str],
owner_id: str = "",
) -> schemas_out.TagsRemove:
async with await create_session() as session:
info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
data = await remove_tags_from_asset_info(
session,
asset_info_id=asset_info_id,
tags=tags,
)
await session.commit()
return schemas_out.TagsRemove(**data)
def _safe_sort_field(requested: Optional[str]) -> str:
if not requested:
return "created_at"
v = requested.lower()
if v in {"name", "created_at", "updated_at", "size", "last_access_time"}:
return v
return "created_at"
def _get_size_mtime_ns(path: str) -> tuple[int, int]:
st = os.stat(path, follow_symlinks=True)
return st.st_size, getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
def _safe_filename(name: Optional[str], fallback: str) -> str:
n = os.path.basename((name or "").strip() or fallback)
if n:
return n
return fallback

View File

@@ -1,501 +0,0 @@
import asyncio
import contextlib
import logging
import os
import time
from dataclasses import dataclass, field
from typing import Literal, Optional
import sqlalchemy as sa
import folder_paths
from ..db import create_session
from ._helpers import (
collect_models_files,
compute_relative_filename,
get_comfy_models_folders,
get_name_and_tags_from_asset_path,
list_tree,
new_scan_id,
prefixes_for_root,
ts_to_iso,
)
from .api import schemas_in, schemas_out
from .database.helpers import (
add_missing_tag_for_asset_id,
ensure_tags_exist,
escape_like_prefix,
fast_asset_file_check,
remove_missing_tag_for_asset_id,
seed_from_paths_batch,
)
from .database.models import Asset, AssetCacheState, AssetInfo
from .database.services import (
compute_hash_and_dedup_for_cache_state,
list_cache_states_by_asset_id,
list_cache_states_with_asset_under_prefixes,
list_unhashed_candidates_under_prefixes,
list_verify_candidates_under_prefixes,
)
LOGGER = logging.getLogger(__name__)
SLOW_HASH_CONCURRENCY = 1
@dataclass
class ScanProgress:
scan_id: str
root: schemas_in.RootType
status: Literal["scheduled", "running", "completed", "failed", "cancelled"] = "scheduled"
scheduled_at: float = field(default_factory=lambda: time.time())
started_at: Optional[float] = None
finished_at: Optional[float] = None
discovered: int = 0
processed: int = 0
file_errors: list[dict] = field(default_factory=list)
@dataclass
class SlowQueueState:
queue: asyncio.Queue
workers: list[asyncio.Task] = field(default_factory=list)
closed: bool = False
RUNNING_TASKS: dict[schemas_in.RootType, asyncio.Task] = {}
PROGRESS_BY_ROOT: dict[schemas_in.RootType, ScanProgress] = {}
SLOW_STATE_BY_ROOT: dict[schemas_in.RootType, SlowQueueState] = {}
def current_statuses() -> schemas_out.AssetScanStatusResponse:
scans = []
for root in schemas_in.ALLOWED_ROOTS:
prog = PROGRESS_BY_ROOT.get(root)
if not prog:
continue
scans.append(_scan_progress_to_scan_status_model(prog))
return schemas_out.AssetScanStatusResponse(scans=scans)
async def schedule_scans(roots: list[schemas_in.RootType]) -> schemas_out.AssetScanStatusResponse:
results: list[ScanProgress] = []
for root in roots:
if root in RUNNING_TASKS and not RUNNING_TASKS[root].done():
results.append(PROGRESS_BY_ROOT[root])
continue
prog = ScanProgress(scan_id=new_scan_id(root), root=root, status="scheduled")
PROGRESS_BY_ROOT[root] = prog
state = SlowQueueState(queue=asyncio.Queue())
SLOW_STATE_BY_ROOT[root] = state
RUNNING_TASKS[root] = asyncio.create_task(
_run_hash_verify_pipeline(root, prog, state),
name=f"asset-scan:{root}",
)
results.append(prog)
return _status_response_for(results)
async def sync_seed_assets(roots: list[schemas_in.RootType]) -> None:
t_total = time.perf_counter()
created = 0
skipped_existing = 0
paths: list[str] = []
try:
existing_paths: set[str] = set()
for r in roots:
try:
survivors = await _fast_db_consistency_pass(r, collect_existing_paths=True, update_missing_tags=True)
if survivors:
existing_paths.update(survivors)
except Exception as ex:
LOGGER.exception("fast DB reconciliation failed for %s: %s", r, ex)
if "models" in roots:
paths.extend(collect_models_files())
if "input" in roots:
paths.extend(list_tree(folder_paths.get_input_directory()))
if "output" in roots:
paths.extend(list_tree(folder_paths.get_output_directory()))
specs: list[dict] = []
tag_pool: set[str] = set()
for p in paths:
ap = os.path.abspath(p)
if ap in existing_paths:
skipped_existing += 1
continue
try:
st = os.stat(ap, follow_symlinks=True)
except OSError:
continue
if not st.st_size:
continue
name, tags = get_name_and_tags_from_asset_path(ap)
specs.append(
{
"abs_path": ap,
"size_bytes": st.st_size,
"mtime_ns": getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000)),
"info_name": name,
"tags": tags,
"fname": compute_relative_filename(ap),
}
)
for t in tags:
tag_pool.add(t)
if not specs:
return
async with await create_session() as sess:
if tag_pool:
await ensure_tags_exist(sess, tag_pool, tag_type="user")
result = await seed_from_paths_batch(sess, specs=specs, owner_id="")
created += result["inserted_infos"]
await sess.commit()
finally:
LOGGER.info(
"Assets scan(roots=%s) completed in %.3fs (created=%d, skipped_existing=%d, total_seen=%d)",
roots,
time.perf_counter() - t_total,
created,
skipped_existing,
len(paths),
)
def _status_response_for(progresses: list[ScanProgress]) -> schemas_out.AssetScanStatusResponse:
return schemas_out.AssetScanStatusResponse(scans=[_scan_progress_to_scan_status_model(p) for p in progresses])
def _scan_progress_to_scan_status_model(progress: ScanProgress) -> schemas_out.AssetScanStatus:
return schemas_out.AssetScanStatus(
scan_id=progress.scan_id,
root=progress.root,
status=progress.status,
scheduled_at=ts_to_iso(progress.scheduled_at),
started_at=ts_to_iso(progress.started_at),
finished_at=ts_to_iso(progress.finished_at),
discovered=progress.discovered,
processed=progress.processed,
file_errors=[
schemas_out.AssetScanError(
path=e.get("path", ""),
message=e.get("message", ""),
at=e.get("at"),
)
for e in (progress.file_errors or [])
],
)
async def _run_hash_verify_pipeline(root: schemas_in.RootType, prog: ScanProgress, state: SlowQueueState) -> None:
prog.status = "running"
prog.started_at = time.time()
try:
prefixes = prefixes_for_root(root)
await _fast_db_consistency_pass(root)
# collect candidates from DB
async with await create_session() as sess:
verify_ids = await list_verify_candidates_under_prefixes(sess, prefixes=prefixes)
unhashed_ids = await list_unhashed_candidates_under_prefixes(sess, prefixes=prefixes)
# dedupe: prioritize verification first
seen = set()
ordered: list[int] = []
for lst in (verify_ids, unhashed_ids):
for sid in lst:
if sid not in seen:
seen.add(sid)
ordered.append(sid)
prog.discovered = len(ordered)
# queue up work
for sid in ordered:
await state.queue.put(sid)
state.closed = True
_start_state_workers(root, prog, state)
await _await_state_workers_then_finish(root, prog, state)
except asyncio.CancelledError:
prog.status = "cancelled"
raise
except Exception as exc:
_append_error(prog, path="", message=str(exc))
prog.status = "failed"
prog.finished_at = time.time()
LOGGER.exception("Asset scan failed for %s", root)
finally:
RUNNING_TASKS.pop(root, None)
async def _reconcile_missing_tags_for_root(root: schemas_in.RootType, prog: ScanProgress) -> None:
"""
Detect missing files quickly and toggle 'missing' tag per asset_id.
Rules:
- Only hashed assets (assets.hash != NULL) participate in missing tagging.
- We consider ALL cache states of the asset (across roots) before tagging.
"""
if root == "models":
bases: list[str] = []
for _bucket, paths in get_comfy_models_folders():
bases.extend(paths)
elif root == "input":
bases = [folder_paths.get_input_directory()]
else:
bases = [folder_paths.get_output_directory()]
try:
async with await create_session() as sess:
# state + hash + size for the current root
rows = await list_cache_states_with_asset_under_prefixes(sess, prefixes=bases)
# Track fast_ok within the scanned root and whether the asset is hashed
by_asset: dict[str, dict[str, bool]] = {}
for state, a_hash, size_db in rows:
aid = state.asset_id
acc = by_asset.get(aid)
if acc is None:
acc = {"any_fast_ok_here": False, "hashed": (a_hash is not None), "size_db": int(size_db or 0)}
by_asset[aid] = acc
try:
if acc["hashed"]:
st = os.stat(state.file_path, follow_symlinks=True)
if fast_asset_file_check(mtime_db=state.mtime_ns, size_db=acc["size_db"], stat_result=st):
acc["any_fast_ok_here"] = True
except FileNotFoundError:
pass
except OSError as e:
_append_error(prog, path=state.file_path, message=str(e))
# Decide per asset, considering ALL its states (not just this root)
for aid, acc in by_asset.items():
try:
if not acc["hashed"]:
# Never tag seed assets as missing
continue
any_fast_ok_global = acc["any_fast_ok_here"]
if not any_fast_ok_global:
# Check other states outside this root
others = await list_cache_states_by_asset_id(sess, asset_id=aid)
for st in others:
try:
any_fast_ok_global = fast_asset_file_check(
mtime_db=st.mtime_ns,
size_db=acc["size_db"],
stat_result=os.stat(st.file_path, follow_symlinks=True),
)
except OSError:
continue
if any_fast_ok_global:
await remove_missing_tag_for_asset_id(sess, asset_id=aid)
else:
await add_missing_tag_for_asset_id(sess, asset_id=aid, origin="automatic")
except Exception as ex:
_append_error(prog, path="", message=f"reconcile {aid[:8]}: {ex}")
await sess.commit()
except Exception as e:
_append_error(prog, path="", message=f"reconcile failed: {e}")
def _start_state_workers(root: schemas_in.RootType, prog: ScanProgress, state: SlowQueueState) -> None:
if state.workers:
return
async def _worker(_wid: int):
while True:
sid = await state.queue.get()
try:
if sid is None:
return
try:
async with await create_session() as sess:
# Optional: fetch path for better error messages
st = await sess.get(AssetCacheState, sid)
try:
await compute_hash_and_dedup_for_cache_state(sess, state_id=sid)
await sess.commit()
except Exception as e:
path = st.file_path if st else f"state:{sid}"
_append_error(prog, path=path, message=str(e))
raise
except Exception:
pass
finally:
prog.processed += 1
finally:
state.queue.task_done()
state.workers = [
asyncio.create_task(_worker(i), name=f"asset-hash:{root}:{i}")
for i in range(SLOW_HASH_CONCURRENCY)
]
async def _close_when_ready():
while not state.closed:
await asyncio.sleep(0.05)
for _ in range(SLOW_HASH_CONCURRENCY):
await state.queue.put(None)
asyncio.create_task(_close_when_ready())
async def _await_state_workers_then_finish(
root: schemas_in.RootType, prog: ScanProgress, state: SlowQueueState
) -> None:
if state.workers:
await asyncio.gather(*state.workers, return_exceptions=True)
await _reconcile_missing_tags_for_root(root, prog)
prog.finished_at = time.time()
prog.status = "completed"
def _append_error(prog: ScanProgress, *, path: str, message: str) -> None:
prog.file_errors.append({
"path": path,
"message": message,
"at": ts_to_iso(time.time()),
})
async def _fast_db_consistency_pass(
root: schemas_in.RootType,
*,
collect_existing_paths: bool = False,
update_missing_tags: bool = False,
) -> Optional[set[str]]:
"""Fast DB+FS pass for a root:
- Toggle needs_verify per state using fast check
- For hashed assets with at least one fast-ok state in this root: delete stale missing states
- For seed assets with all states missing: delete Asset and its AssetInfos
- Optionally add/remove 'missing' tags based on fast-ok in this root
- Optionally return surviving absolute paths
"""
prefixes = prefixes_for_root(root)
if not prefixes:
return set() if collect_existing_paths else None
conds = []
for p in prefixes:
base = os.path.abspath(p)
if not base.endswith(os.sep):
base += os.sep
escaped, esc = escape_like_prefix(base)
conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc))
async with await create_session() as sess:
rows = (
await sess.execute(
sa.select(
AssetCacheState.id,
AssetCacheState.file_path,
AssetCacheState.mtime_ns,
AssetCacheState.needs_verify,
AssetCacheState.asset_id,
Asset.hash,
Asset.size_bytes,
)
.join(Asset, Asset.id == AssetCacheState.asset_id)
.where(sa.or_(*conds))
.order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc())
)
).all()
by_asset: dict[str, dict] = {}
for sid, fp, mtime_db, needs_verify, aid, a_hash, a_size in rows:
acc = by_asset.get(aid)
if acc is None:
acc = {"hash": a_hash, "size_db": int(a_size or 0), "states": []}
by_asset[aid] = acc
fast_ok = False
try:
exists = True
fast_ok = fast_asset_file_check(
mtime_db=mtime_db,
size_db=acc["size_db"],
stat_result=os.stat(fp, follow_symlinks=True),
)
except FileNotFoundError:
exists = False
except OSError:
exists = False
acc["states"].append({
"sid": sid,
"fp": fp,
"exists": exists,
"fast_ok": fast_ok,
"needs_verify": bool(needs_verify),
})
to_set_verify: list[int] = []
to_clear_verify: list[int] = []
stale_state_ids: list[int] = []
survivors: set[str] = set()
for aid, acc in by_asset.items():
a_hash = acc["hash"]
states = acc["states"]
any_fast_ok = any(s["fast_ok"] for s in states)
all_missing = all(not s["exists"] for s in states)
for s in states:
if not s["exists"]:
continue
if s["fast_ok"] and s["needs_verify"]:
to_clear_verify.append(s["sid"])
if not s["fast_ok"] and not s["needs_verify"]:
to_set_verify.append(s["sid"])
if a_hash is None:
if states and all_missing: # remove seed Asset completely, if no valid AssetCache exists
await sess.execute(sa.delete(AssetInfo).where(AssetInfo.asset_id == aid))
asset = await sess.get(Asset, aid)
if asset:
await sess.delete(asset)
else:
for s in states:
if s["exists"]:
survivors.add(os.path.abspath(s["fp"]))
continue
if any_fast_ok: # if Asset has at least one valid AssetCache record, remove any invalid AssetCache records
for s in states:
if not s["exists"]:
stale_state_ids.append(s["sid"])
if update_missing_tags:
with contextlib.suppress(Exception):
await remove_missing_tag_for_asset_id(sess, asset_id=aid)
elif update_missing_tags:
with contextlib.suppress(Exception):
await add_missing_tag_for_asset_id(sess, asset_id=aid, origin="automatic")
for s in states:
if s["exists"]:
survivors.add(os.path.abspath(s["fp"]))
if stale_state_ids:
await sess.execute(sa.delete(AssetCacheState).where(AssetCacheState.id.in_(stale_state_ids)))
if to_set_verify:
await sess.execute(
sa.update(AssetCacheState)
.where(AssetCacheState.id.in_(to_set_verify))
.values(needs_verify=True)
)
if to_clear_verify:
await sess.execute(
sa.update(AssetCacheState)
.where(AssetCacheState.id.in_(to_clear_verify))
.values(needs_verify=False)
)
await sess.commit()
return survivors if collect_existing_paths else None

View File

@@ -1,72 +0,0 @@
import asyncio
import os
from typing import IO, Union
from blake3 import blake3
DEFAULT_CHUNK = 8 * 1024 * 1024 # 8 MiB
def _hash_file_obj_sync(file_obj: IO[bytes], chunk_size: int) -> str:
"""Hash an already-open binary file object by streaming in chunks.
- Seeks to the beginning before reading (if supported).
- Restores the original position afterward (if tell/seek are supported).
"""
if chunk_size <= 0:
chunk_size = DEFAULT_CHUNK
orig_pos = None
if hasattr(file_obj, "tell"):
orig_pos = file_obj.tell()
try:
if hasattr(file_obj, "seek"):
file_obj.seek(0)
h = blake3()
while True:
chunk = file_obj.read(chunk_size)
if not chunk:
break
h.update(chunk)
return h.hexdigest()
finally:
if hasattr(file_obj, "seek") and orig_pos is not None:
file_obj.seek(orig_pos)
def blake3_hash_sync(
fp: Union[str, bytes, os.PathLike[str], os.PathLike[bytes], IO[bytes]],
chunk_size: int = DEFAULT_CHUNK,
) -> str:
"""Returns a BLAKE3 hex digest for ``fp``, which may be:
- a filename (str/bytes) or PathLike
- an open binary file object
If ``fp`` is a file object, it must be opened in **binary** mode and support
``read``, ``seek``, and ``tell``. The function will seek to the start before
reading and will attempt to restore the original position afterward.
"""
if hasattr(fp, "read"):
return _hash_file_obj_sync(fp, chunk_size)
with open(os.fspath(fp), "rb") as f:
return _hash_file_obj_sync(f, chunk_size)
async def blake3_hash(
fp: Union[str, bytes, os.PathLike[str], os.PathLike[bytes], IO[bytes]],
chunk_size: int = DEFAULT_CHUNK,
) -> str:
"""Async wrapper for ``blake3_hash_sync``.
Uses a worker thread so the event loop remains responsive.
"""
# If it is a path, open inside the worker thread to keep I/O off the loop.
if hasattr(fp, "read"):
return await asyncio.to_thread(blake3_hash_sync, fp, chunk_size)
def _worker() -> str:
with open(os.fspath(fp), "rb") as f:
return _hash_file_obj_sync(f, chunk_size)
return await asyncio.to_thread(_worker)

112
app/database/db.py Normal file
View File

@@ -0,0 +1,112 @@
import logging
import os
import shutil
from app.logger import log_startup_warning
from utils.install_util import get_missing_requirements_message
from comfy.cli_args import args
_DB_AVAILABLE = False
Session = None
try:
from alembic import command
from alembic.config import Config
from alembic.runtime.migration import MigrationContext
from alembic.script import ScriptDirectory
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
_DB_AVAILABLE = True
except ImportError as e:
log_startup_warning(
f"""
------------------------------------------------------------------------
Error importing dependencies: {e}
{get_missing_requirements_message()}
This error is happening because ComfyUI now uses a local sqlite database.
------------------------------------------------------------------------
""".strip()
)
def dependencies_available():
"""
Temporary function to check if the dependencies are available
"""
return _DB_AVAILABLE
def can_create_session():
"""
Temporary function to check if the database is available to create a session
During initial release there may be environmental issues (or missing dependencies) that prevent the database from being created
"""
return dependencies_available() and Session is not None
def get_alembic_config():
root_path = os.path.join(os.path.dirname(__file__), "../..")
config_path = os.path.abspath(os.path.join(root_path, "alembic.ini"))
scripts_path = os.path.abspath(os.path.join(root_path, "alembic_db"))
config = Config(config_path)
config.set_main_option("script_location", scripts_path)
config.set_main_option("sqlalchemy.url", args.database_url)
return config
def get_db_path():
url = args.database_url
if url.startswith("sqlite:///"):
return url.split("///")[1]
else:
raise ValueError(f"Unsupported database URL '{url}'.")
def init_db():
db_url = args.database_url
logging.debug(f"Database URL: {db_url}")
db_path = get_db_path()
db_exists = os.path.exists(db_path)
config = get_alembic_config()
# Check if we need to upgrade
engine = create_engine(db_url)
conn = engine.connect()
context = MigrationContext.configure(conn)
current_rev = context.get_current_revision()
script = ScriptDirectory.from_config(config)
target_rev = script.get_current_head()
if target_rev is None:
logging.warning("No target revision found.")
elif current_rev != target_rev:
# Backup the database pre upgrade
backup_path = db_path + ".bkp"
if db_exists:
shutil.copy(db_path, backup_path)
else:
backup_path = None
try:
command.upgrade(config, target_rev)
logging.info(f"Database upgraded from {current_rev} to {target_rev}")
except Exception as e:
if backup_path:
# Restore the database from backup if upgrade fails
shutil.copy(backup_path, db_path)
os.remove(backup_path)
logging.exception("Error upgrading database: ")
raise e
global Session
Session = sessionmaker(bind=engine)
def create_session():
return Session()

14
app/database/models.py Normal file
View File

@@ -0,0 +1,14 @@
from sqlalchemy.orm import declarative_base
Base = declarative_base()
def to_dict(obj):
fields = obj.__table__.columns.keys()
return {
field: (val.to_dict() if hasattr(val, "to_dict") else val)
for field in fields
if (val := getattr(obj, field))
}
# TODO: Define models here

255
app/db.py
View File

@@ -1,255 +0,0 @@
import logging
import os
import shutil
from contextlib import asynccontextmanager
from typing import Optional
from alembic import command
from alembic.config import Config
from alembic.runtime.migration import MigrationContext
from alembic.script import ScriptDirectory
from sqlalchemy import create_engine, text
from sqlalchemy.engine import make_url
from sqlalchemy.ext.asyncio import (
AsyncEngine,
AsyncSession,
async_sessionmaker,
create_async_engine,
)
from comfy.cli_args import args
LOGGER = logging.getLogger(__name__)
ENGINE: Optional[AsyncEngine] = None
SESSION: Optional[async_sessionmaker] = None
def _root_paths():
"""Resolve alembic.ini and migrations script folder."""
root_path = os.path.abspath(os.path.dirname(__file__))
config_path = os.path.abspath(os.path.join(root_path, "../alembic.ini"))
scripts_path = os.path.abspath(os.path.join(root_path, "alembic_db"))
return config_path, scripts_path
def _absolutize_sqlite_url(db_url: str) -> str:
"""Make SQLite database path absolute. No-op for non-SQLite URLs."""
try:
u = make_url(db_url)
except Exception:
return db_url
if not u.drivername.startswith("sqlite"):
return db_url
db_path: str = u.database or ""
if isinstance(db_path, str) and db_path.startswith("file:"):
return str(u) # Do not touch SQLite URI databases like: "file:xxx?mode=memory&cache=shared"
if not os.path.isabs(db_path):
db_path = os.path.abspath(os.path.join(os.getcwd(), db_path))
u = u.set(database=db_path)
return str(u)
def _normalize_sqlite_memory_url(db_url: str) -> tuple[str, bool]:
"""
If db_url points at an in-memory SQLite DB (":memory:" or file:... mode=memory),
rewrite it to a *named* shared in-memory URI and ensure 'uri=true' is present.
Returns: (normalized_url, is_memory)
"""
try:
u = make_url(db_url)
except Exception:
return db_url, False
if not u.drivername.startswith("sqlite"):
return db_url, False
db = u.database or ""
if db == ":memory:":
u = u.set(database=f"file:comfyui_db_{os.getpid()}?mode=memory&cache=shared&uri=true")
return str(u), True
if isinstance(db, str) and db.startswith("file:") and "mode=memory" in db:
if "uri=true" not in db:
u = u.set(database=(db + ("&" if "?" in db else "?") + "uri=true"))
return str(u), True
return str(u), False
def _get_sqlite_file_path(sync_url: str) -> Optional[str]:
"""Return the on-disk path for a SQLite URL, else None."""
try:
u = make_url(sync_url)
except Exception:
return None
if not u.drivername.startswith("sqlite"):
return None
db_path = u.database
if isinstance(db_path, str) and db_path.startswith("file:"):
return None # Not a real file if it is a URI like "file:...?"
return db_path
def _get_alembic_config(sync_url: str) -> Config:
"""Prepare Alembic Config with script location and DB URL."""
config_path, scripts_path = _root_paths()
cfg = Config(config_path)
cfg.set_main_option("script_location", scripts_path)
cfg.set_main_option("sqlalchemy.url", sync_url)
return cfg
async def init_db_engine() -> None:
"""Initialize async engine + sessionmaker and run migrations to head.
This must be called once on application startup before any DB usage.
"""
global ENGINE, SESSION
if ENGINE is not None:
return
raw_url = args.database_url
if not raw_url:
raise RuntimeError("Database URL is not configured.")
db_url, is_mem = _normalize_sqlite_memory_url(raw_url)
db_url = _absolutize_sqlite_url(db_url)
# Prepare async engine
connect_args = {}
if db_url.startswith("sqlite"):
connect_args = {
"check_same_thread": False,
"timeout": 12,
}
if is_mem:
connect_args["uri"] = True
ENGINE = create_async_engine(
db_url,
connect_args=connect_args,
pool_pre_ping=True,
future=True,
)
# Enforce SQLite pragmas on the async engine
if db_url.startswith("sqlite"):
async with ENGINE.begin() as conn:
if not is_mem:
# WAL for concurrency and durability, Foreign Keys for referential integrity
current_mode = (await conn.execute(text("PRAGMA journal_mode;"))).scalar()
if str(current_mode).lower() != "wal":
new_mode = (await conn.execute(text("PRAGMA journal_mode=WAL;"))).scalar()
if str(new_mode).lower() != "wal":
raise RuntimeError("Failed to set SQLite journal mode to WAL.")
LOGGER.info("SQLite journal mode set to WAL.")
await conn.execute(text("PRAGMA foreign_keys = ON;"))
await conn.execute(text("PRAGMA synchronous = NORMAL;"))
await _run_migrations(database_url=db_url, connect_args=connect_args)
SESSION = async_sessionmaker(
bind=ENGINE,
class_=AsyncSession,
expire_on_commit=False,
autoflush=False,
autocommit=False,
)
async def _run_migrations(database_url: str, connect_args: dict) -> None:
if database_url.find("postgresql+psycopg") == -1:
"""SQLite: Convert an async SQLAlchemy URL to a sync URL for Alembic."""
u = make_url(database_url)
driver = u.drivername
if not driver.startswith("sqlite+aiosqlite"):
raise ValueError(f"Unsupported DB driver: {driver}")
database_url, is_mem = _normalize_sqlite_memory_url(str(u.set(drivername="sqlite")))
database_url = _absolutize_sqlite_url(database_url)
cfg = _get_alembic_config(database_url)
engine = create_engine(database_url, future=True, connect_args=connect_args)
with engine.connect() as conn:
context = MigrationContext.configure(conn)
current_rev = context.get_current_revision()
script = ScriptDirectory.from_config(cfg)
target_rev = script.get_current_head()
if target_rev is None:
LOGGER.warning("Alembic: no target revision found.")
return
if current_rev == target_rev:
LOGGER.debug("Alembic: database already at head %s", target_rev)
return
LOGGER.info("Alembic: upgrading database from %s to %s", current_rev, target_rev)
# Optional backup for SQLite file DBs
backup_path = None
sqlite_path = _get_sqlite_file_path(database_url)
if sqlite_path and os.path.exists(sqlite_path):
backup_path = sqlite_path + ".bkp"
try:
shutil.copy(sqlite_path, backup_path)
except Exception as exc:
LOGGER.warning("Failed to create SQLite backup before migration: %s", exc)
try:
command.upgrade(cfg, target_rev)
except Exception:
if backup_path and os.path.exists(backup_path):
LOGGER.exception("Error upgrading database, attempting restore from backup.")
try:
shutil.copy(backup_path, sqlite_path) # restore
os.remove(backup_path)
except Exception as re:
LOGGER.error("Failed to restore SQLite backup: %s", re)
else:
LOGGER.exception("Error upgrading database, backup is not available.")
raise
def get_engine():
"""Return the global async engine (initialized after init_db_engine())."""
if ENGINE is None:
raise RuntimeError("Engine is not initialized. Call init_db_engine() first.")
return ENGINE
def get_session_maker():
"""Return the global async_sessionmaker (initialized after init_db_engine())."""
if SESSION is None:
raise RuntimeError("Session maker is not initialized. Call init_db_engine() first.")
return SESSION
@asynccontextmanager
async def session_scope():
"""Async context manager for a unit of work:
async with session_scope() as sess:
... use sess ...
"""
maker = get_session_maker()
async with maker() as sess:
try:
yield sess
await sess.commit()
except Exception:
await sess.rollback()
raise
async def create_session():
"""Convenience helper to acquire a single AsyncSession instance.
Typical usage:
async with (await create_session()) as sess:
...
"""
maker = get_session_maker()
return maker()

View File

@@ -42,6 +42,7 @@ def get_installed_frontend_version():
frontend_version_str = version("comfyui-frontend-package")
return frontend_version_str
def get_required_frontend_version():
"""Get the required frontend version from requirements.txt."""
try:
@@ -63,6 +64,7 @@ def get_required_frontend_version():
logging.error(f"Error reading requirements.txt: {e}")
return None
def check_frontend_version():
"""Check if the frontend version is up to date."""
@@ -196,17 +198,6 @@ def download_release_asset_zip(release: Release, destination_path: str) -> None:
class FrontendManager:
"""
A class to manage ComfyUI frontend versions and installations.
This class handles the initialization and management of different frontend versions,
including the default frontend from the pip package and custom frontend versions
from GitHub repositories.
Attributes:
CUSTOM_FRONTENDS_ROOT (str): The root directory where custom frontend versions are stored.
"""
CUSTOM_FRONTENDS_ROOT = str(Path(__file__).parents[1] / "web_custom_versions")
@classmethod
@@ -214,17 +205,39 @@ class FrontendManager:
"""Get the required frontend package version."""
return get_required_frontend_version()
@classmethod
def get_installed_templates_version(cls) -> str:
"""Get the currently installed workflow templates package version."""
try:
templates_version_str = version("comfyui-workflow-templates")
return templates_version_str
except Exception:
return None
@classmethod
def get_required_templates_version(cls) -> str:
"""Get the required workflow templates version from requirements.txt."""
try:
with open(requirements_path, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if line.startswith("comfyui-workflow-templates=="):
version_str = line.split("==")[-1]
if not is_valid_version(version_str):
logging.error(f"Invalid templates version format in requirements.txt: {version_str}")
return None
return version_str
logging.error("comfyui-workflow-templates not found in requirements.txt")
return None
except FileNotFoundError:
logging.error("requirements.txt not found. Cannot determine required templates version.")
return None
except Exception as e:
logging.error(f"Error reading requirements.txt: {e}")
return None
@classmethod
def default_frontend_path(cls) -> str:
"""
Get the path to the default frontend installation from the pip package.
Returns:
str: The path to the default frontend static files.
Raises:
SystemExit: If the comfyui-frontend-package is not installed.
"""
try:
import comfyui_frontend_package
@@ -245,15 +258,6 @@ comfyui-frontend-package is not installed.
@classmethod
def templates_path(cls) -> str:
"""
Get the path to the workflow templates.
Returns:
str: The path to the workflow templates directory.
Raises:
SystemExit: If the comfyui-workflow-templates package is not installed.
"""
try:
import comfyui_workflow_templates
@@ -289,16 +293,11 @@ comfyui-workflow-templates is not installed.
@classmethod
def parse_version_string(cls, value: str) -> tuple[str, str, str]:
"""
Parse a version string into its components.
The version string should be in the format: 'owner/repo@version'
where version can be either a semantic version (v1.2.3) or 'latest'.
Args:
value (str): The version string to parse.
Returns:
tuple[str, str, str]: A tuple containing (owner, repo, version).
tuple[str, str]: A tuple containing provider name and version.
Raises:
argparse.ArgumentTypeError: If the version string is invalid.
@@ -315,22 +314,18 @@ comfyui-workflow-templates is not installed.
cls, version_string: str, provider: Optional[FrontEndProvider] = None
) -> str:
"""
Initialize a frontend version without error handling.
This method attempts to initialize a specific frontend version, either from
the default pip package or from a custom GitHub repository. It will download
and extract the frontend files if necessary.
Initializes the frontend for the specified version.
Args:
version_string (str): The version string specifying which frontend to use.
provider (FrontEndProvider, optional): The provider to use for custom frontends.
version_string (str): The version string.
provider (FrontEndProvider, optional): The provider to use. Defaults to None.
Returns:
str: The path to the initialized frontend.
Raises:
Exception: If there is an error during initialization (e.g., network timeout,
invalid URL, or missing assets).
Exception: If there is an error during the initialization process.
main error source might be request timeout or invalid URL.
"""
if version_string == DEFAULT_VERSION_STRING:
check_frontend_version()
@@ -382,17 +377,13 @@ comfyui-workflow-templates is not installed.
@classmethod
def init_frontend(cls, version_string: str) -> str:
"""
Initialize a frontend version with error handling.
This is the main method to initialize a frontend version. It wraps init_frontend_unsafe
with error handling, falling back to the default frontend if initialization fails.
Initializes the frontend with the specified version string.
Args:
version_string (str): The version string specifying which frontend to use.
version_string (str): The version string to initialize the frontend with.
Returns:
str: The path to the initialized frontend. If initialization fails,
returns the path to the default frontend.
str: The path of the initialized frontend.
"""
try:
return cls.init_frontend_unsafe(version_string)

View File

@@ -212,8 +212,7 @@ parser.add_argument(
database_default_path = os.path.abspath(
os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db")
)
parser.add_argument("--database-url", type=str, default=f"sqlite+aiosqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite+aiosqlite:///:memory:'.")
parser.add_argument("--disable-assets-autoscan", action="store_true", help="Disable asset scanning on startup for database synchronization.")
parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
if comfy.options.args_parsing:
args = parser.parse_args()

View File

@@ -23,8 +23,6 @@ class MusicDCAE(torch.nn.Module):
else:
self.source_sample_rate = source_sample_rate
# self.resampler = torchaudio.transforms.Resample(source_sample_rate, 44100)
self.transform = transforms.Compose([
transforms.Normalize(0.5, 0.5),
])
@@ -37,10 +35,6 @@ class MusicDCAE(torch.nn.Module):
self.scale_factor = 0.1786
self.shift_factor = -1.9091
def load_audio(self, audio_path):
audio, sr = torchaudio.load(audio_path)
return audio, sr
def forward_mel(self, audios):
mels = []
for i in range(len(audios)):
@@ -73,10 +67,8 @@ class MusicDCAE(torch.nn.Module):
latent = self.dcae.encoder(mel.unsqueeze(0))
latents.append(latent)
latents = torch.cat(latents, dim=0)
# latent_lengths = (audio_lengths / sr * 44100 / 512 / self.time_dimention_multiple).long()
latents = (latents - self.shift_factor) * self.scale_factor
return latents
# return latents, latent_lengths
@torch.no_grad()
def decode(self, latents, audio_lengths=None, sr=None):
@@ -91,9 +83,7 @@ class MusicDCAE(torch.nn.Module):
wav = self.vocoder.decode(mels[0]).squeeze(1)
if sr is not None:
# resampler = torchaudio.transforms.Resample(44100, sr).to(latents.device).to(latents.dtype)
wav = torchaudio.functional.resample(wav, 44100, sr)
# wav = resampler(wav)
else:
sr = 44100
pred_wavs.append(wav)
@@ -101,7 +91,6 @@ class MusicDCAE(torch.nn.Module):
if audio_lengths is not None:
pred_wavs = [wav[:, :length].cpu() for wav, length in zip(pred_wavs, audio_lengths)]
return torch.stack(pred_wavs)
# return sr, pred_wavs
def forward(self, audios, audio_lengths=None, sr=None):
latents, latent_lengths = self.encode(audios=audios, audio_lengths=audio_lengths, sr=sr)

View File

@@ -37,7 +37,10 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
def apply_rope1(x: Tensor, freqs_cis: Tensor):
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
x_out = freqs_cis[..., 0] * x_[..., 0] + freqs_cis[..., 1] * x_[..., 1]
x_out = freqs_cis[..., 0] * x_[..., 0]
x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
return x_out.reshape(*x.shape).type_as(x)
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):

View File

@@ -1,7 +1,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d, Normalize
import comfy.ops
import comfy.ldm.models.autoencoder
ops = comfy.ops.disable_weight_init
@@ -17,11 +17,12 @@ class RMS_norm(nn.Module):
return F.normalize(x, dim=1) * self.scale * self.gamma
class DnSmpl(nn.Module):
def __init__(self, ic, oc, tds=True):
def __init__(self, ic, oc, tds=True, refiner_vae=True, op=VideoConv3d):
super().__init__()
fct = 2 * 2 * 2 if tds else 1 * 2 * 2
assert oc % fct == 0
self.conv = VideoConv3d(ic, oc // fct, kernel_size=3)
self.conv = op(ic, oc // fct, kernel_size=3, stride=1, padding=1)
self.refiner_vae = refiner_vae
self.tds = tds
self.gs = fct * ic // oc
@@ -30,7 +31,7 @@ class DnSmpl(nn.Module):
r1 = 2 if self.tds else 1
h = self.conv(x)
if self.tds:
if self.tds and self.refiner_vae:
hf = h[:, :, :1, :, :]
b, c, f, ht, wd = hf.shape
hf = hf.reshape(b, c, f, ht // 2, 2, wd // 2, 2)
@@ -66,6 +67,7 @@ class DnSmpl(nn.Module):
sc = torch.cat([xf, xn], dim=2)
else:
b, c, frms, ht, wd = h.shape
nf = frms // r1
h = h.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2)
h = h.permute(0, 3, 5, 7, 1, 2, 4, 6)
@@ -83,10 +85,11 @@ class DnSmpl(nn.Module):
class UpSmpl(nn.Module):
def __init__(self, ic, oc, tus=True):
def __init__(self, ic, oc, tus=True, refiner_vae=True, op=VideoConv3d):
super().__init__()
fct = 2 * 2 * 2 if tus else 1 * 2 * 2
self.conv = VideoConv3d(ic, oc * fct, kernel_size=3)
self.conv = op(ic, oc * fct, kernel_size=3, stride=1, padding=1)
self.refiner_vae = refiner_vae
self.tus = tus
self.rp = fct * oc // ic
@@ -95,7 +98,7 @@ class UpSmpl(nn.Module):
r1 = 2 if self.tus else 1
h = self.conv(x)
if self.tus:
if self.tus and self.refiner_vae:
hf = h[:, :, :1, :, :]
b, c, f, ht, wd = hf.shape
nc = c // (2 * 2)
@@ -148,43 +151,56 @@ class UpSmpl(nn.Module):
class Encoder(nn.Module):
def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks,
ffactor_spatial, ffactor_temporal, downsample_match_channel=True, **_):
ffactor_spatial, ffactor_temporal, downsample_match_channel=True, refiner_vae=True, **_):
super().__init__()
self.z_channels = z_channels
self.block_out_channels = block_out_channels
self.num_res_blocks = num_res_blocks
self.conv_in = VideoConv3d(in_channels, block_out_channels[0], 3, 1, 1)
self.ffactor_temporal = ffactor_temporal
self.refiner_vae = refiner_vae
if self.refiner_vae:
conv_op = VideoConv3d
norm_op = RMS_norm
else:
conv_op = ops.Conv3d
norm_op = Normalize
self.conv_in = conv_op(in_channels, block_out_channels[0], 3, 1, 1)
self.down = nn.ModuleList()
ch = block_out_channels[0]
depth = (ffactor_spatial >> 1).bit_length()
depth_temporal = ((ffactor_spatial // ffactor_temporal) >> 1).bit_length()
depth_temporal = ((ffactor_spatial // self.ffactor_temporal) >> 1).bit_length()
for i, tgt in enumerate(block_out_channels):
stage = nn.Module()
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
out_channels=tgt,
temb_channels=0,
conv_op=VideoConv3d, norm_op=RMS_norm)
conv_op=conv_op, norm_op=norm_op)
for j in range(num_res_blocks)])
ch = tgt
if i < depth:
nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and downsample_match_channel else ch
stage.downsample = DnSmpl(ch, nxt, tds=i >= depth_temporal)
stage.downsample = DnSmpl(ch, nxt, tds=i >= depth_temporal, refiner_vae=self.refiner_vae, op=conv_op)
ch = nxt
self.down.append(stage)
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=RMS_norm)
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
self.norm_out = RMS_norm(ch)
self.conv_out = VideoConv3d(ch, z_channels << 1, 3, 1, 1)
self.norm_out = norm_op(ch)
self.conv_out = conv_op(ch, z_channels << 1, 3, 1, 1)
self.regul = comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer()
def forward(self, x):
if not self.refiner_vae and x.shape[2] == 1:
x = x.expand(-1, -1, self.ffactor_temporal, -1, -1)
x = self.conv_in(x)
for stage in self.down:
@@ -200,31 +216,42 @@ class Encoder(nn.Module):
skip = x.view(b, c // grp, grp, t, h, w).mean(2)
out = self.conv_out(F.silu(self.norm_out(x))) + skip
out = self.regul(out)[0]
out = torch.cat((out[:, :, :1], out), dim=2)
out = out.permute(0, 2, 1, 3, 4)
b, f_times_2, c, h, w = out.shape
out = out.reshape(b, f_times_2 // 2, 2 * c, h, w)
out = out.permute(0, 2, 1, 3, 4).contiguous()
if self.refiner_vae:
out = self.regul(out)[0]
out = torch.cat((out[:, :, :1], out), dim=2)
out = out.permute(0, 2, 1, 3, 4)
b, f_times_2, c, h, w = out.shape
out = out.reshape(b, f_times_2 // 2, 2 * c, h, w)
out = out.permute(0, 2, 1, 3, 4).contiguous()
return out
class Decoder(nn.Module):
def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks,
ffactor_spatial, ffactor_temporal, upsample_match_channel=True, **_):
ffactor_spatial, ffactor_temporal, upsample_match_channel=True, refiner_vae=True, **_):
super().__init__()
block_out_channels = block_out_channels[::-1]
self.z_channels = z_channels
self.block_out_channels = block_out_channels
self.num_res_blocks = num_res_blocks
self.refiner_vae = refiner_vae
if self.refiner_vae:
conv_op = VideoConv3d
norm_op = RMS_norm
else:
conv_op = ops.Conv3d
norm_op = Normalize
ch = block_out_channels[0]
self.conv_in = VideoConv3d(z_channels, ch, 3)
self.conv_in = conv_op(z_channels, ch, kernel_size=3, stride=1, padding=1)
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=RMS_norm)
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
self.up = nn.ModuleList()
depth = (ffactor_spatial >> 1).bit_length()
@@ -235,25 +262,26 @@ class Decoder(nn.Module):
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
out_channels=tgt,
temb_channels=0,
conv_op=VideoConv3d, norm_op=RMS_norm)
conv_op=conv_op, norm_op=norm_op)
for j in range(num_res_blocks + 1)])
ch = tgt
if i < depth:
nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and upsample_match_channel else ch
stage.upsample = UpSmpl(ch, nxt, tus=i < depth_temporal)
stage.upsample = UpSmpl(ch, nxt, tus=i < depth_temporal, refiner_vae=self.refiner_vae, op=conv_op)
ch = nxt
self.up.append(stage)
self.norm_out = RMS_norm(ch)
self.conv_out = VideoConv3d(ch, out_channels, 3)
self.norm_out = norm_op(ch)
self.conv_out = conv_op(ch, out_channels, 3, stride=1, padding=1)
def forward(self, z):
z = z.permute(0, 2, 1, 3, 4)
b, f, c, h, w = z.shape
z = z.reshape(b, f, 2, c // 2, h, w)
z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w)
z = z.permute(0, 2, 1, 3, 4)
z = z[:, :, 1:]
if self.refiner_vae:
z = z.permute(0, 2, 1, 3, 4)
b, f, c, h, w = z.shape
z = z.reshape(b, f, 2, c // 2, h, w)
z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w)
z = z.permute(0, 2, 1, 3, 4)
z = z[:, :, 1:]
x = self.conv_in(z) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
@@ -264,4 +292,10 @@ class Decoder(nn.Module):
if hasattr(stage, 'upsample'):
x = stage.upsample(x)
return self.conv_out(F.silu(self.norm_out(x)))
out = self.conv_out(F.silu(self.norm_out(x)))
if not self.refiner_vae:
if z.shape[-3] == 1:
out = out[:, :, -1:]
return out

View File

@@ -237,6 +237,7 @@ class WanAttentionBlock(nn.Module):
freqs, transformer_options=transformer_options)
x = torch.addcmul(x, y, repeat_e(e[2], x))
del y
# cross-attention & ffn
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)
@@ -902,7 +903,7 @@ class MotionEncoder_tc(nn.Module):
def __init__(self,
in_dim: int,
hidden_dim: int,
num_heads=int,
num_heads: int,
need_global=True,
dtype=None,
device=None,

View File

@@ -468,55 +468,46 @@ class WanVAE(nn.Module):
attn_scales, self.temperal_upsample, dropout)
def encode(self, x):
self.clear_cache()
conv_idx = [0]
feat_map = [None] * count_conv3d(self.decoder)
## cache
t = x.shape[2]
iter_ = 1 + (t - 1) // 4
## 对encode输入的x按时间拆分为1、4、4、4....
for i in range(iter_):
self._enc_conv_idx = [0]
conv_idx = [0]
if i == 0:
out = self.encoder(
x[:, :, :1, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx)
feat_cache=feat_map,
feat_idx=conv_idx)
else:
out_ = self.encoder(
x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx)
feat_cache=feat_map,
feat_idx=conv_idx)
out = torch.cat([out, out_], 2)
mu, log_var = self.conv1(out).chunk(2, dim=1)
self.clear_cache()
return mu
def decode(self, z):
self.clear_cache()
conv_idx = [0]
feat_map = [None] * count_conv3d(self.decoder)
# z: [b,c,t,h,w]
iter_ = z.shape[2]
x = self.conv2(z)
for i in range(iter_):
self._conv_idx = [0]
conv_idx = [0]
if i == 0:
out = self.decoder(
x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx)
feat_cache=feat_map,
feat_idx=conv_idx)
else:
out_ = self.decoder(
x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx)
feat_cache=feat_map,
feat_idx=conv_idx)
out = torch.cat([out, out_], 2)
self.clear_cache()
return out
def clear_cache(self):
self._conv_num = count_conv3d(self.decoder)
self._conv_idx = [0]
self._feat_map = [None] * self._conv_num
#cache encode
self._enc_conv_num = count_conv3d(self.encoder)
self._enc_conv_idx = [0]
self._enc_feat_map = [None] * self._enc_conv_num

View File

@@ -365,8 +365,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["patch_size"] = 2
dit_config["in_channels"] = 16
dit_config["dim"] = 2304
dit_config["cap_feat_dim"] = 2304
dit_config["n_layers"] = 26
dit_config["cap_feat_dim"] = state_dict['{}cap_embedder.1.weight'.format(key_prefix)].shape[1]
dit_config["n_layers"] = count_blocks(state_dict_keys, '{}layers.'.format(key_prefix) + '{}.')
dit_config["n_heads"] = 24
dit_config["n_kv_heads"] = 8
dit_config["qk_norm"] = True

View File

@@ -123,16 +123,30 @@ def move_weight_functions(m, device):
return memory
class LowVramPatch:
def __init__(self, key, patches):
def __init__(self, key, patches, convert_func=None, set_func=None):
self.key = key
self.patches = patches
self.convert_func = convert_func
self.set_func = set_func
def __call__(self, weight):
intermediate_dtype = weight.dtype
if self.convert_func is not None:
weight = self.convert_func(weight.to(dtype=torch.float32, copy=True), inplace=True)
if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops
intermediate_dtype = torch.float32
return comfy.float.stochastic_rounding(comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype), weight.dtype, seed=string_to_seed(self.key))
out = comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype)
if self.set_func is None:
return comfy.float.stochastic_rounding(out, weight.dtype, seed=string_to_seed(self.key))
else:
return self.set_func(out, seed=string_to_seed(self.key), return_weight=True)
return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
out = comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
if self.set_func is not None:
return self.set_func(out, seed=string_to_seed(self.key), return_weight=True).to(dtype=intermediate_dtype)
else:
return out
def get_key_weight(model, key):
set_func = None
@@ -657,13 +671,15 @@ class ModelPatcher:
if force_patch_weights:
self.patch_weight_to_device(weight_key)
else:
m.weight_function = [LowVramPatch(weight_key, self.patches)]
_, set_func, convert_func = get_key_weight(self.model, weight_key)
m.weight_function = [LowVramPatch(weight_key, self.patches, convert_func, set_func)]
patch_counter += 1
if bias_key in self.patches:
if force_patch_weights:
self.patch_weight_to_device(bias_key)
else:
m.bias_function = [LowVramPatch(bias_key, self.patches)]
_, set_func, convert_func = get_key_weight(self.model, bias_key)
m.bias_function = [LowVramPatch(bias_key, self.patches, convert_func, set_func)]
patch_counter += 1
cast_weight = True
@@ -825,10 +841,12 @@ class ModelPatcher:
module_mem += move_weight_functions(m, device_to)
if lowvram_possible:
if weight_key in self.patches:
m.weight_function.append(LowVramPatch(weight_key, self.patches))
_, set_func, convert_func = get_key_weight(self.model, weight_key)
m.weight_function.append(LowVramPatch(weight_key, self.patches, convert_func, set_func))
patch_counter += 1
if bias_key in self.patches:
m.bias_function.append(LowVramPatch(bias_key, self.patches))
_, set_func, convert_func = get_key_weight(self.model, bias_key)
m.bias_function.append(LowVramPatch(bias_key, self.patches, convert_func, set_func))
patch_counter += 1
cast_weight = True

View File

@@ -21,17 +21,23 @@ def rescale_zero_terminal_snr_sigmas(sigmas):
alphas_bar[-1] = 4.8973451890853435e-08
return ((1 - alphas_bar) / alphas_bar) ** 0.5
def reshape_sigma(sigma, noise_dim):
if sigma.nelement() == 1:
return sigma.view(())
else:
return sigma.view(sigma.shape[:1] + (1,) * (noise_dim - 1))
class EPS:
def calculate_input(self, sigma, noise):
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
sigma = reshape_sigma(sigma, noise.ndim)
return noise / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
def calculate_denoised(self, sigma, model_output, model_input):
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
sigma = reshape_sigma(sigma, model_output.ndim)
return model_input - model_output * sigma
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
sigma = reshape_sigma(sigma, noise.ndim)
if max_denoise:
noise = noise * torch.sqrt(1.0 + sigma ** 2.0)
else:
@@ -45,12 +51,12 @@ class EPS:
class V_PREDICTION(EPS):
def calculate_denoised(self, sigma, model_output, model_input):
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
sigma = reshape_sigma(sigma, model_output.ndim)
return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
class EDM(V_PREDICTION):
def calculate_denoised(self, sigma, model_output, model_input):
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
sigma = reshape_sigma(sigma, model_output.ndim)
return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) + model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
class CONST:
@@ -58,15 +64,15 @@ class CONST:
return noise
def calculate_denoised(self, sigma, model_output, model_input):
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
sigma = reshape_sigma(sigma, model_output.ndim)
return model_input - model_output * sigma
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
sigma = reshape_sigma(sigma, noise.ndim)
return sigma * noise + (1.0 - sigma) * latent_image
def inverse_noise_scaling(self, sigma, latent):
sigma = sigma.view(sigma.shape[:1] + (1,) * (latent.ndim - 1))
sigma = reshape_sigma(sigma, latent.ndim)
return latent / (1.0 - sigma)
class X0(EPS):
@@ -80,16 +86,16 @@ class IMG_TO_IMG(X0):
class COSMOS_RFLOW:
def calculate_input(self, sigma, noise):
sigma = (sigma / (sigma + 1))
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
sigma = reshape_sigma(sigma, noise.ndim)
return noise * (1.0 - sigma)
def calculate_denoised(self, sigma, model_output, model_input):
sigma = (sigma / (sigma + 1))
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
sigma = reshape_sigma(sigma, model_output.ndim)
return model_input * (1.0 - sigma) - model_output * sigma
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
sigma = reshape_sigma(sigma, noise.ndim)
noise = noise * sigma
noise += latent_image
return noise

View File

@@ -416,8 +416,10 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None
else:
return weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype)
def set_weight(self, weight, inplace_update=False, seed=None, **kwargs):
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
if return_weight:
return weight
if inplace_update:
self.weight.data.copy_(weight)
else:

View File

@@ -332,35 +332,51 @@ class VAE:
self.first_stage_model = StageC_coder()
self.downscale_ratio = 32
self.latent_channels = 16
elif "decoder.conv_in.weight" in sd and sd['decoder.conv_in.weight'].shape[1] == 64:
ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
self.downscale_ratio = 32
self.upscale_ratio = 32
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
encoder_config={'target': "comfy.ldm.hunyuan_video.vae.Encoder", 'params': ddconfig},
decoder_config={'target': "comfy.ldm.hunyuan_video.vae.Decoder", 'params': ddconfig})
self.memory_used_encode = lambda shape, dtype: (700 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (700 * shape[2] * shape[3] * 32 * 32) * model_management.dtype_size(dtype)
elif "decoder.conv_in.weight" in sd:
#default SD1.x/SD2.x VAE parameters
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
if 'encoder.down.2.downsample.conv.weight' not in sd and 'decoder.up.3.upsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE
ddconfig['ch_mult'] = [1, 2, 4]
self.downscale_ratio = 4
self.upscale_ratio = 4
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
if 'post_quant_conv.weight' in sd:
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
else:
if sd['decoder.conv_in.weight'].shape[1] == 64:
ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
self.downscale_ratio = 32
self.upscale_ratio = 32
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig},
decoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Decoder", 'params': ddconfig})
encoder_config={'target': "comfy.ldm.hunyuan_video.vae.Encoder", 'params': ddconfig},
decoder_config={'target': "comfy.ldm.hunyuan_video.vae.Decoder", 'params': ddconfig})
self.memory_used_encode = lambda shape, dtype: (700 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (700 * shape[2] * shape[3] * 32 * 32) * model_management.dtype_size(dtype)
elif sd['decoder.conv_in.weight'].shape[1] == 32:
ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True, "refiner_vae": False}
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
self.upscale_index_formula = (4, 16, 16)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16)
self.downscale_index_formula = (4, 16, 16)
self.latent_dim = 3
self.not_video = True
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
encoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Encoder", 'params': ddconfig},
decoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Decoder", 'params': ddconfig})
self.memory_used_encode = lambda shape, dtype: (2800 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (2800 * shape[-3] * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)
else:
#default SD1.x/SD2.x VAE parameters
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
if 'encoder.down.2.downsample.conv.weight' not in sd and 'decoder.up.3.upsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE
ddconfig['ch_mult'] = [1, 2, 4]
self.downscale_ratio = 4
self.upscale_ratio = 4
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
if 'post_quant_conv.weight' in sd:
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
else:
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig},
decoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Decoder", 'params': ddconfig})
elif "decoder.layers.1.layers.0.beta" in sd:
self.first_stage_model = AudioOobleckVAE()
self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype)
@@ -636,6 +652,7 @@ class VAE:
def decode(self, samples_in, vae_options={}):
self.throw_exception_if_invalid()
pixel_samples = None
do_tile = False
try:
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
@@ -651,6 +668,13 @@ class VAE:
pixel_samples[x:x+batch_number] = out
except model_management.OOM_EXCEPTION:
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
#exception and the exception itself refs them all until we get out of this except block.
#So we just set a flag for tiler fallback so that tensor gc can happen once the
#exception is fully off the books.
do_tile = True
if do_tile:
dims = samples_in.ndim - 2
if dims == 1 or self.extra_1d_channel is not None:
pixel_samples = self.decode_tiled_1d(samples_in)
@@ -697,6 +721,7 @@ class VAE:
self.throw_exception_if_invalid()
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
pixel_samples = pixel_samples.movedim(-1, 1)
do_tile = False
if self.latent_dim == 3 and pixel_samples.ndim < 5:
if not self.not_video:
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
@@ -718,6 +743,13 @@ class VAE:
except model_management.OOM_EXCEPTION:
logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
#exception and the exception itself refs them all until we get out of this except block.
#So we just set a flag for tiler fallback so that tensor gc can happen once the
#exception is fully off the books.
do_tile = True
if do_tile:
if self.latent_dim == 3:
tile = 256
overlap = tile // 4
@@ -858,6 +890,7 @@ class TEModel(Enum):
QWEN25_3B = 10
QWEN25_7B = 11
BYT5_SMALL_GLYPH = 12
GEMMA_3_4B = 13
def detect_te_model(sd):
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
@@ -880,6 +913,8 @@ def detect_te_model(sd):
return TEModel.BYT5_SMALL_GLYPH
return TEModel.T5_BASE
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
if 'model.layers.0.self_attn.q_norm.weight' in sd:
return TEModel.GEMMA_3_4B
return TEModel.GEMMA_2_2B
if 'model.layers.0.self_attn.k_proj.bias' in sd:
weight = sd['model.layers.0.self_attn.k_proj.bias']
@@ -984,6 +1019,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
elif te_model == TEModel.GEMMA_3_4B:
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data), model_type="gemma3_4b")
clip_target.tokenizer = comfy.text_encoders.lumina2.NTokenizer
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
elif te_model == TEModel.LLAMA3_8:
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data),
clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None, t5xxl_scaled_fp8=None)

View File

@@ -3,6 +3,7 @@ import torch.nn as nn
from dataclasses import dataclass
from typing import Optional, Any
import math
import logging
from comfy.ldm.modules.attention import optimized_attention_for_device
import comfy.model_management
@@ -28,6 +29,9 @@ class Llama2Config:
mlp_activation = "silu"
qkv_bias = False
rope_dims = None
q_norm = None
k_norm = None
rope_scale = None
@dataclass
class Qwen25_3BConfig:
@@ -46,6 +50,9 @@ class Qwen25_3BConfig:
mlp_activation = "silu"
qkv_bias = True
rope_dims = None
q_norm = None
k_norm = None
rope_scale = None
@dataclass
class Qwen25_7BVLI_Config:
@@ -64,6 +71,9 @@ class Qwen25_7BVLI_Config:
mlp_activation = "silu"
qkv_bias = True
rope_dims = [16, 24, 24]
q_norm = None
k_norm = None
rope_scale = None
@dataclass
class Gemma2_2B_Config:
@@ -82,6 +92,32 @@ class Gemma2_2B_Config:
mlp_activation = "gelu_pytorch_tanh"
qkv_bias = False
rope_dims = None
q_norm = None
k_norm = None
sliding_attention = None
rope_scale = None
@dataclass
class Gemma3_4B_Config:
vocab_size: int = 262208
hidden_size: int = 2560
intermediate_size: int = 10240
num_hidden_layers: int = 34
num_attention_heads: int = 8
num_key_value_heads: int = 4
max_position_embeddings: int = 131072
rms_norm_eps: float = 1e-6
rope_theta = [10000.0, 1000000.0]
transformer_type: str = "gemma3"
head_dim = 256
rms_norm_add = True
mlp_activation = "gelu_pytorch_tanh"
qkv_bias = False
rope_dims = None
q_norm = "gemma3"
k_norm = "gemma3"
sliding_attention = [False, False, False, False, False, 1024]
rope_scale = [1.0, 8.0]
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
@@ -106,25 +142,40 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
def precompute_freqs_cis(head_dim, position_ids, theta, rope_dims=None, device=None):
theta_numerator = torch.arange(0, head_dim, 2, device=device).float()
inv_freq = 1.0 / (theta ** (theta_numerator / head_dim))
def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_dims=None, device=None):
if not isinstance(theta, list):
theta = [theta]
inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
if rope_dims is not None and position_ids.shape[0] > 1:
mrope_section = rope_dims * 2
cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
else:
cos = cos.unsqueeze(1)
sin = sin.unsqueeze(1)
out = []
for index, t in enumerate(theta):
theta_numerator = torch.arange(0, head_dim, 2, device=device).float()
inv_freq = 1.0 / (t ** (theta_numerator / head_dim))
return (cos, sin)
if rope_scale is not None:
if isinstance(rope_scale, list):
inv_freq /= rope_scale[index]
else:
inv_freq /= rope_scale
inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
if rope_dims is not None and position_ids.shape[0] > 1:
mrope_section = rope_dims * 2
cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
else:
cos = cos.unsqueeze(1)
sin = sin.unsqueeze(1)
out.append((cos, sin))
if len(out) == 1:
return out[0]
return out
def apply_rope(xq, xk, freqs_cis):
@@ -152,6 +203,14 @@ class Attention(nn.Module):
self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
self.o_proj = ops.Linear(self.inner_size, config.hidden_size, bias=False, device=device, dtype=dtype)
self.q_norm = None
self.k_norm = None
if config.q_norm == "gemma3":
self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
if config.k_norm == "gemma3":
self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
def forward(
self,
hidden_states: torch.Tensor,
@@ -168,6 +227,11 @@ class Attention(nn.Module):
xk = xk.view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2)
xv = xv.view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2)
if self.q_norm is not None:
xq = self.q_norm(xq)
if self.k_norm is not None:
xk = self.k_norm(xk)
xq, xk = apply_rope(xq, xk, freqs_cis=freqs_cis)
xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
@@ -192,7 +256,7 @@ class MLP(nn.Module):
return self.down_proj(self.activation(self.gate_proj(x)) * self.up_proj(x))
class TransformerBlock(nn.Module):
def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
def __init__(self, config: Llama2Config, index, device=None, dtype=None, ops: Any = None):
super().__init__()
self.self_attn = Attention(config, device=device, dtype=dtype, ops=ops)
self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)
@@ -226,7 +290,7 @@ class TransformerBlock(nn.Module):
return x
class TransformerBlockGemma2(nn.Module):
def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
def __init__(self, config: Llama2Config, index, device=None, dtype=None, ops: Any = None):
super().__init__()
self.self_attn = Attention(config, device=device, dtype=dtype, ops=ops)
self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)
@@ -235,6 +299,13 @@ class TransformerBlockGemma2(nn.Module):
self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
if config.sliding_attention is not None: # TODO: implement. (Not that necessary since models are trained on less than 1024 tokens)
self.sliding_attention = config.sliding_attention[index % len(config.sliding_attention)]
else:
self.sliding_attention = False
self.transformer_type = config.transformer_type
def forward(
self,
x: torch.Tensor,
@@ -242,6 +313,14 @@ class TransformerBlockGemma2(nn.Module):
freqs_cis: Optional[torch.Tensor] = None,
optimized_attention=None,
):
if self.transformer_type == 'gemma3':
if self.sliding_attention:
if x.shape[1] > self.sliding_attention:
logging.warning("Warning: sliding attention not implemented, results may be incorrect")
freqs_cis = freqs_cis[1]
else:
freqs_cis = freqs_cis[0]
# Self Attention
residual = x
x = self.input_layernorm(x)
@@ -276,7 +355,7 @@ class Llama2_(nn.Module):
device=device,
dtype=dtype
)
if self.config.transformer_type == "gemma2":
if self.config.transformer_type == "gemma2" or self.config.transformer_type == "gemma3":
transformer = TransformerBlockGemma2
self.normalize_in = True
else:
@@ -284,8 +363,8 @@ class Llama2_(nn.Module):
self.normalize_in = False
self.layers = nn.ModuleList([
transformer(config, device=device, dtype=dtype, ops=ops)
for _ in range(config.num_hidden_layers)
transformer(config, index=i, device=device, dtype=dtype, ops=ops)
for i in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
# self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
@@ -305,6 +384,7 @@ class Llama2_(nn.Module):
freqs_cis = precompute_freqs_cis(self.config.head_dim,
position_ids,
self.config.rope_theta,
self.config.rope_scale,
self.config.rope_dims,
device=x.device)
@@ -433,3 +513,12 @@ class Gemma2_2B(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Gemma3_4B(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Gemma3_4B_Config(**config_dict)
self.num_layers = config.num_hidden_layers
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype

View File

@@ -11,23 +11,41 @@ class Gemma2BTokenizer(sd1_clip.SDTokenizer):
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}
class Gemma3_4BTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer = tokenizer_data.get("spiece_model", None)
super().__init__(tokenizer, pad_with_end=False, embedding_size=2560, embedding_key='gemma3_4b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}
class LuminaTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma2_2b", tokenizer=Gemma2BTokenizer)
class NTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma3_4b", tokenizer=Gemma3_4BTokenizer)
class Gemma2_2BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma2_2B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class Gemma3_4BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class LuminaModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, name="gemma2_2b", clip_model=Gemma2_2BModel, model_options=model_options)
def __init__(self, device="cpu", dtype=None, model_options={}, name="gemma2_2b", clip_model=Gemma2_2BModel):
super().__init__(device=device, dtype=dtype, name=name, clip_model=clip_model, model_options=model_options)
def te(dtype_llama=None, llama_scaled_fp8=None):
def te(dtype_llama=None, llama_scaled_fp8=None, model_type="gemma2_2b"):
if model_type == "gemma2_2b":
model = Gemma2_2BModel
elif model_type == "gemma3_4b":
model = Gemma3_4BModel
class LuminaTEModel_(LuminaModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
@@ -35,5 +53,5 @@ def te(dtype_llama=None, llama_scaled_fp8=None):
model_options["scaled_fp8"] = llama_scaled_fp8
if dtype_llama is not None:
dtype = dtype_llama
super().__init__(device=device, dtype=dtype, model_options=model_options)
super().__init__(device=device, dtype=dtype, name=model_type, model_options=model_options, clip_model=model)
return LuminaTEModel_

View File

@@ -50,16 +50,10 @@ if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in
else:
logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.")
def is_html_file(file_path):
with open(file_path, "rb") as f:
content = f.read(100)
return b"<!DOCTYPE html>" in content or b"<html" in content
def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
if device is None:
device = torch.device("cpu")
metadata = None
if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
try:
with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
@@ -72,8 +66,6 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
if return_metadata:
metadata = f.metadata()
except Exception as e:
if is_html_file(ckpt):
raise ValueError("{}\n\nFile path: {}\n\nThe requested file is an HTML document not a safetensors file. Please re-download the file, not the web page.".format(e, ckpt))
if len(e.args) > 0:
message = e.args[0]
if "HeaderTooLarge" in message:
@@ -101,8 +93,6 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
sd = pl_sd
else:
sd = pl_sd
# populate_db_with_asset(ckpt) # surprise tool that can help us later - performs hashing on model file
return (sd, metadata) if return_metadata else sd
def save_torch_file(sd, ckpt, metadata=None):

View File

@@ -8,8 +8,8 @@ from comfy_api.internal.async_to_sync import create_sync_class
from comfy_api.latest._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents
from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents
from comfy_api.latest._io import _IO as io #noqa: F401
from comfy_api.latest._ui import _UI as ui #noqa: F401
from . import _io as io
from . import _ui as ui
# from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401
from comfy_execution.utils import get_executing_context
from comfy_execution.progress import get_progress_state, PreviewImageTuple
@@ -114,6 +114,8 @@ if TYPE_CHECKING:
ComfyAPISync: Type[comfy_api.latest.generated.ComfyAPISyncStub.ComfyAPISyncStub]
ComfyAPISync = create_sync_class(ComfyAPI_latest)
comfy_io = io # create the new alias for io
__all__ = [
"ComfyAPI",
"ComfyAPISync",
@@ -121,4 +123,7 @@ __all__ = [
"InputImpl",
"Types",
"ComfyExtension",
"io",
"comfy_io",
"ui",
]

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Optional, Union
from typing import Optional, Union, IO
import io
import av
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
@@ -23,7 +23,7 @@ class VideoInput(ABC):
@abstractmethod
def save_to(
self,
path: str,
path: Union[str, IO[bytes]],
format: VideoContainer = VideoContainer.AUTO,
codec: VideoCodec = VideoCodec.AUTO,
metadata: Optional[dict] = None

View File

@@ -336,11 +336,25 @@ class Combo(ComfyTypeIO):
class Input(WidgetInput):
"""Combo input (dropdown)."""
Type = str
def __init__(self, id: str, options: list[str]=None, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: str=None, control_after_generate: bool=None,
upload: UploadType=None, image_folder: FolderType=None,
remote: RemoteOptions=None,
socketless: bool=None):
def __init__(
self,
id: str,
options: list[str] | list[int] | type[Enum] = None,
display_name: str=None,
optional=False,
tooltip: str=None,
lazy: bool=None,
default: str | int | Enum = None,
control_after_generate: bool=None,
upload: UploadType=None,
image_folder: FolderType=None,
remote: RemoteOptions=None,
socketless: bool=None,
):
if isinstance(options, type) and issubclass(options, Enum):
options = [v.value for v in options]
if isinstance(default, Enum):
default = default.value
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless)
self.multiselect = False
self.options = options
@@ -392,20 +406,6 @@ class MultiCombo(ComfyTypeI):
})
return to_return
@comfytype(io_type="ASSET")
class Asset(ComfyTypeI):
class Input(WidgetInput):
def __init__(self, id: str, query_tags: list[str], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: str=None, socketless: bool=None):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless)
self.query_tags = query_tags
def as_dict(self):
to_return = super().as_dict() | prune_dict({
"query_tags": self.query_tags
})
return to_return
@comfytype(io_type="IMAGE")
class Image(ComfyTypeIO):
Type = torch.Tensor
@@ -1582,77 +1582,78 @@ class _UIOutput(ABC):
...
class _IO:
FolderType = FolderType
UploadType = UploadType
RemoteOptions = RemoteOptions
NumberDisplay = NumberDisplay
__all__ = [
"FolderType",
"UploadType",
"RemoteOptions",
"NumberDisplay",
comfytype = staticmethod(comfytype)
Custom = staticmethod(Custom)
Input = Input
WidgetInput = WidgetInput
Output = Output
ComfyTypeI = ComfyTypeI
ComfyTypeIO = ComfyTypeIO
#---------------------------------
"comfytype",
"Custom",
"Input",
"WidgetInput",
"Output",
"ComfyTypeI",
"ComfyTypeIO",
# Supported Types
Boolean = Boolean
Int = Int
Float = Float
String = String
Combo = Combo
MultiCombo = MultiCombo
Image = Image
WanCameraEmbedding = WanCameraEmbedding
Webcam = Webcam
Mask = Mask
Latent = Latent
Conditioning = Conditioning
Sampler = Sampler
Sigmas = Sigmas
Noise = Noise
Guider = Guider
Clip = Clip
ControlNet = ControlNet
Vae = Vae
Model = Model
ClipVision = ClipVision
ClipVisionOutput = ClipVisionOutput
AudioEncoderOutput = AudioEncoderOutput
StyleModel = StyleModel
Gligen = Gligen
UpscaleModel = UpscaleModel
Audio = Audio
Video = Video
SVG = SVG
LoraModel = LoraModel
LossMap = LossMap
Voxel = Voxel
Mesh = Mesh
Hooks = Hooks
HookKeyframes = HookKeyframes
TimestepsRange = TimestepsRange
LatentOperation = LatentOperation
FlowControl = FlowControl
Accumulation = Accumulation
Load3DCamera = Load3DCamera
Load3D = Load3D
Load3DAnimation = Load3DAnimation
Photomaker = Photomaker
Point = Point
FaceAnalysis = FaceAnalysis
BBOX = BBOX
SEGS = SEGS
AnyType = AnyType
MultiType = MultiType
#---------------------------------
HiddenHolder = HiddenHolder
Hidden = Hidden
NodeInfoV1 = NodeInfoV1
NodeInfoV3 = NodeInfoV3
Schema = Schema
ComfyNode = ComfyNode
NodeOutput = NodeOutput
add_to_dict_v1 = staticmethod(add_to_dict_v1)
add_to_dict_v3 = staticmethod(add_to_dict_v3)
"Boolean",
"Int",
"Float",
"String",
"Combo",
"MultiCombo",
"Image",
"WanCameraEmbedding",
"Webcam",
"Mask",
"Latent",
"Conditioning",
"Sampler",
"Sigmas",
"Noise",
"Guider",
"Clip",
"ControlNet",
"Vae",
"Model",
"ClipVision",
"ClipVisionOutput",
"AudioEncoder",
"AudioEncoderOutput",
"StyleModel",
"Gligen",
"UpscaleModel",
"Audio",
"Video",
"SVG",
"LoraModel",
"LossMap",
"Voxel",
"Mesh",
"Hooks",
"HookKeyframes",
"TimestepsRange",
"LatentOperation",
"FlowControl",
"Accumulation",
"Load3DCamera",
"Load3D",
"Load3DAnimation",
"Photomaker",
"Point",
"FaceAnalysis",
"BBOX",
"SEGS",
"AnyType",
"MultiType",
# Other classes
"HiddenHolder",
"Hidden",
"NodeInfoV1",
"NodeInfoV3",
"Schema",
"ComfyNode",
"NodeOutput",
"add_to_dict_v1",
"add_to_dict_v3",
]

View File

@@ -449,15 +449,16 @@ class PreviewText(_UIOutput):
return {"text": (self.value,)}
class _UI:
SavedResult = SavedResult
SavedImages = SavedImages
SavedAudios = SavedAudios
ImageSaveHelper = ImageSaveHelper
AudioSaveHelper = AudioSaveHelper
PreviewImage = PreviewImage
PreviewMask = PreviewMask
PreviewAudio = PreviewAudio
PreviewVideo = PreviewVideo
PreviewUI3D = PreviewUI3D
PreviewText = PreviewText
__all__ = [
"SavedResult",
"SavedImages",
"SavedAudios",
"ImageSaveHelper",
"AudioSaveHelper",
"PreviewImage",
"PreviewMask",
"PreviewAudio",
"PreviewVideo",
"PreviewUI3D",
"PreviewText",
]

View File

@@ -18,7 +18,7 @@ from comfy_api_nodes.apis.client import (
UploadResponse,
)
from server import PromptServer
from comfy.cli_args import args
import numpy as np
from PIL import Image
@@ -30,7 +30,9 @@ from io import BytesIO
import av
async def download_url_to_video_output(video_url: str, timeout: int = None) -> VideoFromFile:
async def download_url_to_video_output(
video_url: str, timeout: int = None, auth_kwargs: Optional[dict[str, str]] = None
) -> VideoFromFile:
"""Downloads a video from a URL and returns a `VIDEO` output.
Args:
@@ -39,7 +41,7 @@ async def download_url_to_video_output(video_url: str, timeout: int = None) -> V
Returns:
A Comfy node `VIDEO` output.
"""
video_io = await download_url_to_bytesio(video_url, timeout)
video_io = await download_url_to_bytesio(video_url, timeout, auth_kwargs=auth_kwargs)
if video_io is None:
error_msg = f"Failed to download video from {video_url}"
logging.error(error_msg)
@@ -152,7 +154,7 @@ def validate_aspect_ratio(
raise TypeError(
f"Aspect ratio cannot reduce to any less than {minimum_ratio_str} ({minimum_ratio}), but was {aspect_ratio} ({calculated_ratio})."
)
elif calculated_ratio > maximum_ratio:
if calculated_ratio > maximum_ratio:
raise TypeError(
f"Aspect ratio cannot reduce to any greater than {maximum_ratio_str} ({maximum_ratio}), but was {aspect_ratio} ({calculated_ratio})."
)
@@ -164,7 +166,9 @@ def mimetype_to_extension(mime_type: str) -> str:
return mime_type.split("/")[-1].lower()
async def download_url_to_bytesio(url: str, timeout: int = None) -> BytesIO:
async def download_url_to_bytesio(
url: str, timeout: int = None, auth_kwargs: Optional[dict[str, str]] = None
) -> BytesIO:
"""Downloads content from a URL using requests and returns it as BytesIO.
Args:
@@ -174,9 +178,18 @@ async def download_url_to_bytesio(url: str, timeout: int = None) -> BytesIO:
Returns:
BytesIO object containing the downloaded content.
"""
headers = {}
if url.startswith("/proxy/"):
url = str(args.comfy_api_base).rstrip("/") + url
auth_token = auth_kwargs.get("auth_token")
comfy_api_key = auth_kwargs.get("comfy_api_key")
if auth_token:
headers["Authorization"] = f"Bearer {auth_token}"
elif comfy_api_key:
headers["X-API-KEY"] = comfy_api_key
timeout_cfg = aiohttp.ClientTimeout(total=timeout) if timeout else None
async with aiohttp.ClientSession(timeout=timeout_cfg) as session:
async with session.get(url) as resp:
async with session.get(url, headers=headers) as resp:
resp.raise_for_status() # Raises HTTPError for bad responses (4XX or 5XX)
return BytesIO(await resp.read())
@@ -256,7 +269,7 @@ def tensor_to_bytesio(
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4').
Returns:
Named BytesIO object containing the image data.
Named BytesIO object containing the image data, with pointer set to the start of buffer.
"""
if not mime_type:
mime_type = "image/png"
@@ -418,7 +431,7 @@ async def upload_video_to_comfyapi(
f"Video duration ({actual_duration:.2f}s) exceeds the maximum allowed ({max_duration}s)."
)
except Exception as e:
logging.error(f"Error getting video duration: {e}")
logging.error("Error getting video duration: %s", str(e))
raise ValueError(f"Could not verify video duration from source: {e}") from e
upload_mime_type = f"video/{container.value.lower()}"

View File

@@ -2,6 +2,7 @@
# filename: filtered-openapi.yaml
# timestamp: 2025-07-30T08:54:00+00:00
# pylint: disable
from __future__ import annotations
from datetime import date, datetime
@@ -1320,6 +1321,7 @@ class KlingTextToVideoModelName(str, Enum):
kling_v1 = 'kling-v1'
kling_v1_6 = 'kling-v1-6'
kling_v2_1_master = 'kling-v2-1-master'
kling_v2_5_turbo = 'kling-v2-5-turbo'
class KlingVideoGenAspectRatio(str, Enum):
@@ -1354,6 +1356,7 @@ class KlingVideoGenModelName(str, Enum):
kling_v2_master = 'kling-v2-master'
kling_v2_1 = 'kling-v2-1'
kling_v2_1_master = 'kling-v2-1-master'
kling_v2_5_turbo = 'kling-v2-5-turbo'
class KlingVideoResult(BaseModel):

View File

@@ -95,9 +95,10 @@ import aiohttp
import asyncio
import logging
import io
import os
import socket
from aiohttp.client_exceptions import ClientError, ClientResponseError
from typing import Dict, Type, Optional, Any, TypeVar, Generic, Callable, Tuple
from typing import Type, Optional, Any, TypeVar, Generic, Callable
from enum import Enum
import json
from urllib.parse import urljoin, urlparse
@@ -174,7 +175,7 @@ class ApiClient:
max_retries: int = 3,
retry_delay: float = 1.0,
retry_backoff_factor: float = 2.0,
retry_status_codes: Optional[Tuple[int, ...]] = None,
retry_status_codes: Optional[tuple[int, ...]] = None,
session: Optional[aiohttp.ClientSession] = None,
):
self.base_url = base_url
@@ -198,9 +199,9 @@ class ApiClient:
@staticmethod
def _create_json_payload_args(
data: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None,
) -> Dict[str, Any]:
data: Optional[dict[str, Any]] = None,
headers: Optional[dict[str, str]] = None,
) -> dict[str, Any]:
return {
"json": data,
"headers": headers,
@@ -208,24 +209,27 @@ class ApiClient:
def _create_form_data_args(
self,
data: Dict[str, Any] | None,
files: Dict[str, Any] | None,
headers: Optional[Dict[str, str]] = None,
data: dict[str, Any] | None,
files: dict[str, Any] | None,
headers: Optional[dict[str, str]] = None,
multipart_parser: Callable | None = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
if headers and "Content-Type" in headers:
del headers["Content-Type"]
if multipart_parser and data:
data = multipart_parser(data)
form = aiohttp.FormData(default_to_multipart=True)
if data: # regular text fields
for k, v in data.items():
if v is None:
continue # aiohttp fails to serialize "None" values
# aiohttp expects strings or bytes; convert enums etc.
form.add_field(k, str(v) if not isinstance(v, (bytes, bytearray)) else v)
if isinstance(data, aiohttp.FormData):
form = data # If the parser already returned a FormData, pass it through
else:
form = aiohttp.FormData(default_to_multipart=True)
if data: # regular text fields
for k, v in data.items():
if v is None:
continue # aiohttp fails to serialize "None" values
# aiohttp expects strings or bytes; convert enums etc.
form.add_field(k, str(v) if not isinstance(v, (bytes, bytearray)) else v)
if files:
file_iter = files if isinstance(files, list) else files.items()
@@ -250,9 +254,9 @@ class ApiClient:
@staticmethod
def _create_urlencoded_form_data_args(
data: Dict[str, Any],
headers: Optional[Dict[str, str]] = None,
) -> Dict[str, Any]:
data: dict[str, Any],
headers: Optional[dict[str, str]] = None,
) -> dict[str, Any]:
headers = headers or {}
headers["Content-Type"] = "application/x-www-form-urlencoded"
return {
@@ -260,7 +264,7 @@ class ApiClient:
"headers": headers,
}
def get_headers(self) -> Dict[str, str]:
def get_headers(self) -> dict[str, str]:
"""Get headers for API requests, including authentication if available"""
headers = {"Content-Type": "application/json", "Accept": "application/json"}
@@ -271,7 +275,7 @@ class ApiClient:
return headers
async def _check_connectivity(self, target_url: str) -> Dict[str, bool]:
async def _check_connectivity(self, target_url: str) -> dict[str, bool]:
"""
Check connectivity to determine if network issues are local or server-related.
@@ -312,14 +316,14 @@ class ApiClient:
self,
method: str,
path: str,
params: Optional[Dict[str, Any]] = None,
data: Optional[Dict[str, Any]] = None,
files: Optional[Dict[str, Any] | list[tuple[str, Any]]] = None,
headers: Optional[Dict[str, str]] = None,
params: Optional[dict[str, Any]] = None,
data: Optional[dict[str, Any]] = None,
files: Optional[dict[str, Any] | list[tuple[str, Any]]] = None,
headers: Optional[dict[str, str]] = None,
content_type: str = "application/json",
multipart_parser: Callable | None = None,
retry_count: int = 0, # Used internally for tracking retries
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""
Make an HTTP request to the API with automatic retries for transient errors.
@@ -355,10 +359,10 @@ class ApiClient:
if params:
params = {k: v for k, v in params.items() if v is not None} # aiohttp fails to serialize None values
logging.debug(f"[DEBUG] Request Headers: {request_headers}")
logging.debug(f"[DEBUG] Files: {files}")
logging.debug(f"[DEBUG] Params: {params}")
logging.debug(f"[DEBUG] Data: {data}")
logging.debug("[DEBUG] Request Headers: %s", request_headers)
logging.debug("[DEBUG] Files: %s", files)
logging.debug("[DEBUG] Params: %s", params)
logging.debug("[DEBUG] Data: %s", data)
if content_type == "application/x-www-form-urlencoded":
payload_args = self._create_urlencoded_form_data_args(data or {}, request_headers)
@@ -481,7 +485,7 @@ class ApiClient:
retry_delay: Initial delay between retries in seconds
retry_backoff_factor: Multiplier for the delay after each retry
"""
headers: Dict[str, str] = {}
headers: dict[str, str] = {}
skip_auto_headers: set[str] = set()
if content_type:
headers["Content-Type"] = content_type
@@ -499,7 +503,9 @@ class ApiClient:
else:
raise ValueError("File must be BytesIO or str path")
operation_id = f"upload_{upload_url.split('/')[-1]}_{uuid.uuid4().hex[:8]}"
parsed = urlparse(upload_url)
basename = os.path.basename(parsed.path) or parsed.netloc or "upload"
operation_id = f"upload_{basename}_{uuid.uuid4().hex[:8]}"
request_logger.log_request_response(
operation_id=operation_id,
request_method="PUT",
@@ -532,7 +538,7 @@ class ApiClient:
request_method="PUT",
request_url=upload_url,
response_status_code=e.status if hasattr(e, "status") else None,
response_headers=dict(e.headers) if getattr(e, "headers") else None,
response_headers=dict(e.headers) if hasattr(e, "headers") else None,
response_content=None,
error_message=f"{type(e).__name__}: {str(e)}",
)
@@ -552,7 +558,7 @@ class ApiClient:
*req_meta,
retry_count: int,
response_content: dict | str = "",
) -> Dict[str, Any]:
) -> dict[str, Any]:
status_code = exc.status
if status_code == 401:
user_friendly = "Unauthorized: Please login first to use this node."
@@ -586,9 +592,9 @@ class ApiClient:
error_message=f"HTTP Error {exc.status}",
)
logging.debug(f"[DEBUG] API Error: {user_friendly} (Status: {status_code})")
logging.debug("[DEBUG] API Error: %s (Status: %s)", user_friendly, status_code)
if response_content:
logging.debug(f"[DEBUG] Response content: {response_content}")
logging.debug("[DEBUG] Response content: %s", response_content)
# Retry if eligible
if status_code in self.retry_status_codes and retry_count < self.max_retries:
@@ -653,7 +659,7 @@ class ApiEndpoint(Generic[T, R]):
method: HttpMethod,
request_model: Type[T],
response_model: Type[R],
query_params: Optional[Dict[str, Any]] = None,
query_params: Optional[dict[str, Any]] = None,
):
"""Initialize an API endpoint definition.
@@ -678,11 +684,11 @@ class SynchronousOperation(Generic[T, R]):
self,
endpoint: ApiEndpoint[T, R],
request: T,
files: Optional[Dict[str, Any] | list[tuple[str, Any]]] = None,
files: Optional[dict[str, Any] | list[tuple[str, Any]]] = None,
api_base: str | None = None,
auth_token: Optional[str] = None,
comfy_api_key: Optional[str] = None,
auth_kwargs: Optional[Dict[str, str]] = None,
auth_kwargs: Optional[dict[str, str]] = None,
timeout: float = 7200.0,
verify_ssl: bool = True,
content_type: str = "application/json",
@@ -723,7 +729,7 @@ class SynchronousOperation(Generic[T, R]):
)
try:
request_dict: Optional[Dict[str, Any]]
request_dict: Optional[dict[str, Any]]
if isinstance(self.request, EmptyRequest):
request_dict = None
else:
@@ -732,11 +738,9 @@ class SynchronousOperation(Generic[T, R]):
if isinstance(v, Enum):
request_dict[k] = v.value
logging.debug(
f"[DEBUG] API Request: {self.endpoint.method.value} {self.endpoint.path}"
)
logging.debug(f"[DEBUG] Request Data: {json.dumps(request_dict, indent=2)}")
logging.debug(f"[DEBUG] Query Params: {self.endpoint.query_params}")
logging.debug("[DEBUG] API Request: %s %s", self.endpoint.method.value, self.endpoint.path)
logging.debug("[DEBUG] Request Data: %s", json.dumps(request_dict, indent=2))
logging.debug("[DEBUG] Query Params: %s", self.endpoint.query_params)
response_json = await client.request(
self.endpoint.method.value,
@@ -751,11 +755,11 @@ class SynchronousOperation(Generic[T, R]):
logging.debug("=" * 50)
logging.debug("[DEBUG] RESPONSE DETAILS:")
logging.debug("[DEBUG] Status Code: 200 (Success)")
logging.debug(f"[DEBUG] Response Body: {json.dumps(response_json, indent=2)}")
logging.debug("[DEBUG] Response Body: %s", json.dumps(response_json, indent=2))
logging.debug("=" * 50)
parsed_response = self.endpoint.response_model.model_validate(response_json)
logging.debug(f"[DEBUG] Parsed Response: {parsed_response}")
logging.debug("[DEBUG] Parsed Response: %s", parsed_response)
return parsed_response
finally:
if owns_client:
@@ -778,14 +782,14 @@ class PollingOperation(Generic[T, R]):
poll_endpoint: ApiEndpoint[EmptyRequest, R],
completed_statuses: list[str],
failed_statuses: list[str],
status_extractor: Callable[[R], str],
progress_extractor: Callable[[R], float] | None = None,
result_url_extractor: Callable[[R], str] | None = None,
status_extractor: Callable[[R], Optional[str]],
progress_extractor: Callable[[R], Optional[float]] | None = None,
result_url_extractor: Callable[[R], Optional[str]] | None = None,
request: Optional[T] = None,
api_base: str | None = None,
auth_token: Optional[str] = None,
comfy_api_key: Optional[str] = None,
auth_kwargs: Optional[Dict[str, str]] = None,
auth_kwargs: Optional[dict[str, str]] = None,
poll_interval: float = 5.0,
max_poll_attempts: int = 120, # Default max polling attempts (10 minutes with 5s interval)
max_retries: int = 3, # Max retries per individual API call
@@ -871,7 +875,7 @@ class PollingOperation(Generic[T, R]):
status = TaskStatus.PENDING
for poll_count in range(1, self.max_poll_attempts + 1):
try:
logging.debug(f"[DEBUG] Polling attempt #{poll_count}")
logging.debug("[DEBUG] Polling attempt #%s", poll_count)
request_dict = (
None if self.request is None else self.request.model_dump(exclude_none=True)
@@ -879,10 +883,13 @@ class PollingOperation(Generic[T, R]):
if poll_count == 1:
logging.debug(
f"[DEBUG] Poll Request: {self.poll_endpoint.method.value} {self.poll_endpoint.path}"
"[DEBUG] Poll Request: %s %s",
self.poll_endpoint.method.value,
self.poll_endpoint.path,
)
logging.debug(
f"[DEBUG] Poll Request Data: {json.dumps(request_dict, indent=2) if request_dict else 'None'}"
"[DEBUG] Poll Request Data: %s",
json.dumps(request_dict, indent=2) if request_dict else "None",
)
# Query task status
@@ -897,7 +904,7 @@ class PollingOperation(Generic[T, R]):
# Check if task is complete
status = self._check_task_status(response_obj)
logging.debug(f"[DEBUG] Task Status: {status}")
logging.debug("[DEBUG] Task Status: %s", status)
# If progress extractor is provided, extract progress
if self.progress_extractor:
@@ -911,7 +918,7 @@ class PollingOperation(Generic[T, R]):
result_url = self.result_url_extractor(response_obj)
if result_url:
message = f"Result URL: {result_url}"
logging.debug(f"[DEBUG] {message}")
logging.debug("[DEBUG] %s", message)
self._display_text_on_node(message)
self.final_response = response_obj
if self.progress_extractor:
@@ -919,7 +926,7 @@ class PollingOperation(Generic[T, R]):
return self.final_response
if status == TaskStatus.FAILED:
message = f"Task failed: {json.dumps(resp)}"
logging.error(f"[DEBUG] {message}")
logging.error("[DEBUG] %s", message)
raise Exception(message)
logging.debug("[DEBUG] Task still pending, continuing to poll...")
# Task pending wait
@@ -933,7 +940,12 @@ class PollingOperation(Generic[T, R]):
raise Exception(
f"Polling aborted after {consecutive_errors} network errors: {str(e)}"
) from e
logging.warning("Network error (%s/%s): %s", consecutive_errors, max_consecutive_errors, str(e))
logging.warning(
"Network error (%s/%s): %s",
consecutive_errors,
max_consecutive_errors,
str(e),
)
await asyncio.sleep(self.poll_interval)
except Exception as e:
# For other errors, increment count and potentially abort
@@ -943,10 +955,13 @@ class PollingOperation(Generic[T, R]):
f"Polling aborted after {consecutive_errors} consecutive errors: {str(e)}"
) from e
logging.error(f"[DEBUG] Polling error: {str(e)}")
logging.error("[DEBUG] Polling error: %s", str(e))
logging.warning(
f"Error during polling (attempt {poll_count}/{self.max_poll_attempts}): {str(e)}. "
f"Will retry in {self.poll_interval} seconds."
"Error during polling (attempt %s/%s): %s. Will retry in %s seconds.",
poll_count,
self.max_poll_attempts,
str(e),
self.poll_interval,
)
await asyncio.sleep(self.poll_interval)

View File

@@ -0,0 +1,100 @@
from typing import Optional
from enum import Enum
from pydantic import BaseModel, Field
class Pikaffect(str, Enum):
Cake_ify = "Cake-ify"
Crumble = "Crumble"
Crush = "Crush"
Decapitate = "Decapitate"
Deflate = "Deflate"
Dissolve = "Dissolve"
Explode = "Explode"
Eye_pop = "Eye-pop"
Inflate = "Inflate"
Levitate = "Levitate"
Melt = "Melt"
Peel = "Peel"
Poke = "Poke"
Squish = "Squish"
Ta_da = "Ta-da"
Tear = "Tear"
class PikaBodyGenerate22C2vGenerate22PikascenesPost(BaseModel):
aspectRatio: Optional[float] = Field(None, description='Aspect ratio (width / height)')
duration: Optional[int] = Field(5)
ingredientsMode: str = Field(...)
negativePrompt: Optional[str] = Field(None)
promptText: Optional[str] = Field(None)
resolution: Optional[str] = Field('1080p')
seed: Optional[int] = Field(None)
class PikaGenerateResponse(BaseModel):
video_id: str = Field(...)
class PikaBodyGenerate22I2vGenerate22I2vPost(BaseModel):
duration: Optional[int] = 5
negativePrompt: Optional[str] = Field(None)
promptText: Optional[str] = Field(None)
resolution: Optional[str] = '1080p'
seed: Optional[int] = Field(None)
class PikaBodyGenerate22KeyframeGenerate22PikaframesPost(BaseModel):
duration: Optional[int] = Field(None, ge=5, le=10)
negativePrompt: Optional[str] = Field(None)
promptText: str = Field(...)
resolution: Optional[str] = '1080p'
seed: Optional[int] = Field(None)
class PikaBodyGenerate22T2vGenerate22T2vPost(BaseModel):
aspectRatio: Optional[float] = Field(
1.7777777777777777,
description='Aspect ratio (width / height)',
ge=0.4,
le=2.5,
)
duration: Optional[int] = 5
negativePrompt: Optional[str] = Field(None)
promptText: str = Field(...)
resolution: Optional[str] = '1080p'
seed: Optional[int] = Field(None)
class PikaBodyGeneratePikadditionsGeneratePikadditionsPost(BaseModel):
negativePrompt: Optional[str] = Field(None)
promptText: Optional[str] = Field(None)
seed: Optional[int] = Field(None)
class PikaBodyGeneratePikaffectsGeneratePikaffectsPost(BaseModel):
negativePrompt: Optional[str] = Field(None)
pikaffect: Optional[str] = None
promptText: Optional[str] = Field(None)
seed: Optional[int] = Field(None)
class PikaBodyGeneratePikaswapsGeneratePikaswapsPost(BaseModel):
negativePrompt: Optional[str] = Field(None)
promptText: Optional[str] = Field(None)
seed: Optional[int] = Field(None)
modifyRegionRoi: Optional[str] = Field(None)
class PikaStatusEnum(str, Enum):
queued = "queued"
started = "started"
finished = "finished"
failed = "failed"
class PikaVideoResponse(BaseModel):
id: str = Field(...)
progress: Optional[int] = Field(None)
status: PikaStatusEnum
url: Optional[str] = Field(None)

View File

@@ -4,62 +4,99 @@ import os
import datetime
import json
import logging
import re
import hashlib
from typing import Any
import folder_paths
# Get the logger instance
logger = logging.getLogger(__name__)
def get_log_directory():
"""
Ensures the API log directory exists within ComfyUI's temp directory
and returns its path.
"""
"""Ensures the API log directory exists within ComfyUI's temp directory and returns its path."""
base_temp_dir = folder_paths.get_temp_directory()
log_dir = os.path.join(base_temp_dir, "api_logs")
try:
os.makedirs(log_dir, exist_ok=True)
except Exception as e:
logger.error(f"Error creating API log directory {log_dir}: {e}")
logger.error("Error creating API log directory %s: %s", log_dir, str(e))
# Fallback to base temp directory if sub-directory creation fails
return base_temp_dir
return log_dir
def _format_data_for_logging(data):
def _sanitize_filename_component(name: str) -> str:
if not name:
return "log"
sanitized = re.sub(r"[^A-Za-z0-9._-]+", "_", name) # Replace disallowed characters with underscore
sanitized = sanitized.strip(" ._") # Windows: trailing dots or spaces are not allowed
if not sanitized:
sanitized = "log"
return sanitized
def _short_hash(*parts: str, length: int = 10) -> str:
return hashlib.sha1(("|".join(parts)).encode("utf-8")).hexdigest()[:length]
def _build_log_filepath(log_dir: str, operation_id: str, request_url: str) -> str:
"""Build log filepath. We keep it well under common path length limits aiming for <= 240 characters total."""
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
slug = _sanitize_filename_component(operation_id) # Best-effort human-readable slug from operation_id
h = _short_hash(operation_id or "", request_url or "") # Short hash ties log to the full operation and URL
# Compute how much room we have for the slug given the directory length
# Keep total path length reasonably below ~260 on Windows.
max_total_path = 240
prefix = f"{timestamp}_"
suffix = f"_{h}.log"
if not slug:
slug = "op"
max_filename_len = max(60, max_total_path - len(log_dir) - 1)
max_slug_len = max(8, max_filename_len - len(prefix) - len(suffix))
if len(slug) > max_slug_len:
slug = slug[:max_slug_len].rstrip(" ._-")
return os.path.join(log_dir, f"{prefix}{slug}{suffix}")
def _format_data_for_logging(data: Any) -> str:
"""Helper to format data (dict, str, bytes) for logging."""
if isinstance(data, bytes):
try:
return data.decode('utf-8') # Try to decode as text
return data.decode("utf-8") # Try to decode as text
except UnicodeDecodeError:
return f"[Binary data of length {len(data)} bytes]"
elif isinstance(data, (dict, list)):
try:
return json.dumps(data, indent=2, ensure_ascii=False)
except TypeError:
return str(data) # Fallback for non-serializable objects
return str(data) # Fallback for non-serializable objects
return str(data)
def log_request_response(
operation_id: str,
request_method: str,
request_url: str,
request_headers: dict | None = None,
request_params: dict | None = None,
request_data: any = None,
request_data: Any = None,
response_status_code: int | None = None,
response_headers: dict | None = None,
response_content: any = None,
error_message: str | None = None
response_content: Any = None,
error_message: str | None = None,
):
"""
Logs API request and response details to a file in the temp/api_logs directory.
Filenames are sanitized and length-limited for cross-platform safety.
If we still fail to write, we fall back to appending into api.log.
"""
log_dir = get_log_directory()
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
filename = f"{timestamp}_{operation_id.replace('/', '_').replace(':', '_')}.log"
filepath = os.path.join(log_dir, filename)
log_content = []
filepath = _build_log_filepath(log_dir, operation_id, request_url)
log_content: list[str] = []
log_content.append(f"Timestamp: {datetime.datetime.now().isoformat()}")
log_content.append(f"Operation ID: {operation_id}")
log_content.append("-" * 30 + " REQUEST " + "-" * 30)
@@ -69,7 +106,7 @@ def log_request_response(
log_content.append(f"Headers:\n{_format_data_for_logging(request_headers)}")
if request_params:
log_content.append(f"Params:\n{_format_data_for_logging(request_params)}")
if request_data:
if request_data is not None:
log_content.append(f"Data/Body:\n{_format_data_for_logging(request_data)}")
log_content.append("\n" + "-" * 30 + " RESPONSE " + "-" * 30)
@@ -77,7 +114,7 @@ def log_request_response(
log_content.append(f"Status Code: {response_status_code}")
if response_headers:
log_content.append(f"Headers:\n{_format_data_for_logging(response_headers)}")
if response_content:
if response_content is not None:
log_content.append(f"Content:\n{_format_data_for_logging(response_content)}")
if error_message:
log_content.append(f"Error:\n{error_message}")
@@ -85,9 +122,10 @@ def log_request_response(
try:
with open(filepath, "w", encoding="utf-8") as f:
f.write("\n".join(log_content))
logger.debug(f"API log saved to: {filepath}")
logger.debug("API log saved to: %s", filepath)
except Exception as e:
logger.error(f"Error writing API log to {filepath}: {e}")
logger.error("Error writing API log to %s: %s", filepath, str(e))
if __name__ == '__main__':
# Example usage (for testing the logger directly)

View File

@@ -52,7 +52,3 @@ class RodinResourceItem(BaseModel):
class Rodin3DDownloadResponse(BaseModel):
list: List[RodinResourceItem] = Field(..., description="Source List")

File diff suppressed because it is too large Load Diff

View File

@@ -249,8 +249,8 @@ class ByteDanceImageNode(comfy_io.ComfyNode):
inputs=[
comfy_io.Combo.Input(
"model",
options=[model.value for model in Text2ImageModelName],
default=Text2ImageModelName.seedream_3.value,
options=Text2ImageModelName,
default=Text2ImageModelName.seedream_3,
tooltip="Model name",
),
comfy_io.String.Input(
@@ -382,8 +382,8 @@ class ByteDanceImageEditNode(comfy_io.ComfyNode):
inputs=[
comfy_io.Combo.Input(
"model",
options=[model.value for model in Image2ImageModelName],
default=Image2ImageModelName.seededit_3.value,
options=Image2ImageModelName,
default=Image2ImageModelName.seededit_3,
tooltip="Model name",
),
comfy_io.Image.Input(
@@ -676,8 +676,8 @@ class ByteDanceTextToVideoNode(comfy_io.ComfyNode):
inputs=[
comfy_io.Combo.Input(
"model",
options=[model.value for model in Text2VideoModelName],
default=Text2VideoModelName.seedance_1_pro.value,
options=Text2VideoModelName,
default=Text2VideoModelName.seedance_1_pro,
tooltip="Model name",
),
comfy_io.String.Input(
@@ -793,8 +793,8 @@ class ByteDanceImageToVideoNode(comfy_io.ComfyNode):
inputs=[
comfy_io.Combo.Input(
"model",
options=[model.value for model in Image2VideoModelName],
default=Image2VideoModelName.seedance_1_pro.value,
options=Image2VideoModelName,
default=Image2VideoModelName.seedance_1_pro,
tooltip="Model name",
),
comfy_io.String.Input(
@@ -920,7 +920,7 @@ class ByteDanceFirstLastFrameNode(comfy_io.ComfyNode):
inputs=[
comfy_io.Combo.Input(
"model",
options=[Image2VideoModelName.seedance_1_lite.value],
options=[model.value for model in Image2VideoModelName],
default=Image2VideoModelName.seedance_1_lite.value,
tooltip="Model name",
),

View File

@@ -39,6 +39,7 @@ from comfy_api_nodes.apinode_utils import (
tensor_to_base64_string,
bytesio_to_image_tensor,
)
from comfy_api.util import VideoContainer, VideoCodec
GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini"
@@ -310,7 +311,7 @@ class GeminiNode(ComfyNodeABC):
Returns:
List of GeminiPart objects containing the encoded video.
"""
from comfy_api.util import VideoContainer, VideoCodec
base_64_string = video_to_base64_string(
video_input,
container_format=VideoContainer.MP4,
@@ -490,7 +491,6 @@ class GeminiInputFiles(ComfyNodeABC):
# Use base64 string directly, not the data URI
with open(file_path, "rb") as f:
file_content = f.read()
import base64
base64_str = base64.b64encode(file_content).decode("utf-8")
return GeminiPart(

File diff suppressed because it is too large Load Diff

View File

@@ -1,7 +1,8 @@
from __future__ import annotations
from inspect import cleandoc
from typing import Optional
from comfy.comfy_types.node_typing import IO, ComfyNodeABC
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io as comfy_io
from comfy_api.input_impl.video_types import VideoFromFile
from comfy_api_nodes.apis.luma_api import (
LumaImageModel,
@@ -51,174 +52,186 @@ def image_result_url_extractor(response: LumaGeneration):
def video_result_url_extractor(response: LumaGeneration):
return response.assets.video if hasattr(response, "assets") and hasattr(response.assets, "video") else None
class LumaReferenceNode(ComfyNodeABC):
class LumaReferenceNode(comfy_io.ComfyNode):
"""
Holds an image and weight for use with Luma Generate Image node.
"""
RETURN_TYPES = (LumaIO.LUMA_REF,)
RETURN_NAMES = ("luma_ref",)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "create_luma_reference"
CATEGORY = "api node/image/Luma"
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
node_id="LumaReferenceNode",
display_name="Luma Reference",
category="api node/image/Luma",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.Image.Input(
"image",
tooltip="Image to use as reference.",
),
comfy_io.Float.Input(
"weight",
default=1.0,
min=0.0,
max=1.0,
step=0.01,
tooltip="Weight of image reference.",
),
comfy_io.Custom(LumaIO.LUMA_REF).Input(
"luma_ref",
optional=True,
),
],
outputs=[comfy_io.Custom(LumaIO.LUMA_REF).Output(display_name="luma_ref")],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
)
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": (
IO.IMAGE,
{
"tooltip": "Image to use as reference.",
},
),
"weight": (
IO.FLOAT,
{
"default": 1.0,
"min": 0.0,
"max": 1.0,
"step": 0.01,
"tooltip": "Weight of image reference.",
},
),
},
"optional": {"luma_ref": (LumaIO.LUMA_REF,)},
}
def create_luma_reference(
self, image: torch.Tensor, weight: float, luma_ref: LumaReferenceChain = None
):
def execute(
cls, image: torch.Tensor, weight: float, luma_ref: LumaReferenceChain = None
) -> comfy_io.NodeOutput:
if luma_ref is not None:
luma_ref = luma_ref.clone()
else:
luma_ref = LumaReferenceChain()
luma_ref.add(LumaReference(image=image, weight=round(weight, 2)))
return (luma_ref,)
return comfy_io.NodeOutput(luma_ref)
class LumaConceptsNode(ComfyNodeABC):
class LumaConceptsNode(comfy_io.ComfyNode):
"""
Holds one or more Camera Concepts for use with Luma Text to Video and Luma Image to Video nodes.
"""
RETURN_TYPES = (LumaIO.LUMA_CONCEPTS,)
RETURN_NAMES = ("luma_concepts",)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "create_concepts"
CATEGORY = "api node/video/Luma"
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
node_id="LumaConceptsNode",
display_name="Luma Concepts",
category="api node/video/Luma",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.Combo.Input(
"concept1",
options=get_luma_concepts(include_none=True),
),
comfy_io.Combo.Input(
"concept2",
options=get_luma_concepts(include_none=True),
),
comfy_io.Combo.Input(
"concept3",
options=get_luma_concepts(include_none=True),
),
comfy_io.Combo.Input(
"concept4",
options=get_luma_concepts(include_none=True),
),
comfy_io.Custom(LumaIO.LUMA_CONCEPTS).Input(
"luma_concepts",
tooltip="Optional Camera Concepts to add to the ones chosen here.",
optional=True,
),
],
outputs=[comfy_io.Custom(LumaIO.LUMA_CONCEPTS).Output(display_name="luma_concepts")],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
)
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"concept1": (get_luma_concepts(include_none=True),),
"concept2": (get_luma_concepts(include_none=True),),
"concept3": (get_luma_concepts(include_none=True),),
"concept4": (get_luma_concepts(include_none=True),),
},
"optional": {
"luma_concepts": (
LumaIO.LUMA_CONCEPTS,
{
"tooltip": "Optional Camera Concepts to add to the ones chosen here."
},
),
},
}
def create_concepts(
self,
def execute(
cls,
concept1: str,
concept2: str,
concept3: str,
concept4: str,
luma_concepts: LumaConceptChain = None,
):
) -> comfy_io.NodeOutput:
chain = LumaConceptChain(str_list=[concept1, concept2, concept3, concept4])
if luma_concepts is not None:
chain = luma_concepts.clone_and_merge(chain)
return (chain,)
return comfy_io.NodeOutput(chain)
class LumaImageGenerationNode(ComfyNodeABC):
class LumaImageGenerationNode(comfy_io.ComfyNode):
"""
Generates images synchronously based on prompt and aspect ratio.
"""
RETURN_TYPES = (IO.IMAGE,)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "api_call"
API_NODE = True
CATEGORY = "api node/image/Luma"
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
node_id="LumaImageNode",
display_name="Luma Text to Image",
category="api node/image/Luma",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt for the image generation",
),
comfy_io.Combo.Input(
"model",
options=LumaImageModel,
),
comfy_io.Combo.Input(
"aspect_ratio",
options=LumaAspectRatio,
default=LumaAspectRatio.ratio_16_9,
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=0xFFFFFFFFFFFFFFFF,
control_after_generate=True,
tooltip="Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
),
comfy_io.Float.Input(
"style_image_weight",
default=1.0,
min=0.0,
max=1.0,
step=0.01,
tooltip="Weight of style image. Ignored if no style_image provided.",
),
comfy_io.Custom(LumaIO.LUMA_REF).Input(
"image_luma_ref",
tooltip="Luma Reference node connection to influence generation with input images; up to 4 images can be considered.",
optional=True,
),
comfy_io.Image.Input(
"style_image",
tooltip="Style reference image; only 1 image will be used.",
optional=True,
),
comfy_io.Image.Input(
"character_image",
tooltip="Character reference images; can be a batch of multiple, up to 4 images can be considered.",
optional=True,
),
],
outputs=[comfy_io.Image.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Prompt for the image generation",
},
),
"model": ([model.value for model in LumaImageModel],),
"aspect_ratio": (
[ratio.value for ratio in LumaAspectRatio],
{
"default": LumaAspectRatio.ratio_16_9,
},
),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 0xFFFFFFFFFFFFFFFF,
"control_after_generate": True,
"tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
},
),
"style_image_weight": (
IO.FLOAT,
{
"default": 1.0,
"min": 0.0,
"max": 1.0,
"step": 0.01,
"tooltip": "Weight of style image. Ignored if no style_image provided.",
},
),
},
"optional": {
"image_luma_ref": (
LumaIO.LUMA_REF,
{
"tooltip": "Luma Reference node connection to influence generation with input images; up to 4 images can be considered."
},
),
"style_image": (
IO.IMAGE,
{"tooltip": "Style reference image; only 1 image will be used."},
),
"character_image": (
IO.IMAGE,
{
"tooltip": "Character reference images; can be a batch of multiple, up to 4 images can be considered."
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
async def api_call(
self,
async def execute(
cls,
prompt: str,
model: str,
aspect_ratio: str,
@@ -227,27 +240,29 @@ class LumaImageGenerationNode(ComfyNodeABC):
image_luma_ref: LumaReferenceChain = None,
style_image: torch.Tensor = None,
character_image: torch.Tensor = None,
unique_id: str = None,
**kwargs,
):
) -> comfy_io.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=3)
auth_kwargs = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
# handle image_luma_ref
api_image_ref = None
if image_luma_ref is not None:
api_image_ref = await self._convert_luma_refs(
image_luma_ref, max_refs=4, auth_kwargs=kwargs,
api_image_ref = await cls._convert_luma_refs(
image_luma_ref, max_refs=4, auth_kwargs=auth_kwargs,
)
# handle style_luma_ref
api_style_ref = None
if style_image is not None:
api_style_ref = await self._convert_style_image(
style_image, weight=style_image_weight, auth_kwargs=kwargs,
api_style_ref = await cls._convert_style_image(
style_image, weight=style_image_weight, auth_kwargs=auth_kwargs,
)
# handle character_ref images
character_ref = None
if character_image is not None:
download_urls = await upload_images_to_comfyapi(
character_image, max_images=4, auth_kwargs=kwargs,
character_image, max_images=4, auth_kwargs=auth_kwargs,
)
character_ref = LumaCharacterRef(
identity0=LumaImageIdentity(images=download_urls)
@@ -268,7 +283,7 @@ class LumaImageGenerationNode(ComfyNodeABC):
style_ref=api_style_ref,
character_ref=character_ref,
),
auth_kwargs=kwargs,
auth_kwargs=auth_kwargs,
)
response_api: LumaGeneration = await operation.execute()
@@ -283,18 +298,19 @@ class LumaImageGenerationNode(ComfyNodeABC):
failed_statuses=[LumaState.failed],
status_extractor=lambda x: x.state,
result_url_extractor=image_result_url_extractor,
node_id=unique_id,
auth_kwargs=kwargs,
node_id=cls.hidden.unique_id,
auth_kwargs=auth_kwargs,
)
response_poll = await operation.execute()
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.assets.image) as img_response:
img = process_image_response(await img_response.content.read())
return (img,)
return comfy_io.NodeOutput(img)
@classmethod
async def _convert_luma_refs(
self, luma_ref: LumaReferenceChain, max_refs: int, auth_kwargs: Optional[dict[str,str]] = None
cls, luma_ref: LumaReferenceChain, max_refs: int, auth_kwargs: Optional[dict[str,str]] = None
):
luma_urls = []
ref_count = 0
@@ -308,82 +324,84 @@ class LumaImageGenerationNode(ComfyNodeABC):
break
return luma_ref.create_api_model(download_urls=luma_urls, max_refs=max_refs)
@classmethod
async def _convert_style_image(
self, style_image: torch.Tensor, weight: float, auth_kwargs: Optional[dict[str,str]] = None
cls, style_image: torch.Tensor, weight: float, auth_kwargs: Optional[dict[str,str]] = None
):
chain = LumaReferenceChain(
first_ref=LumaReference(image=style_image, weight=weight)
)
return await self._convert_luma_refs(chain, max_refs=1, auth_kwargs=auth_kwargs)
return await cls._convert_luma_refs(chain, max_refs=1, auth_kwargs=auth_kwargs)
class LumaImageModifyNode(ComfyNodeABC):
class LumaImageModifyNode(comfy_io.ComfyNode):
"""
Modifies images synchronously based on prompt and aspect ratio.
"""
RETURN_TYPES = (IO.IMAGE,)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "api_call"
API_NODE = True
CATEGORY = "api node/image/Luma"
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
node_id="LumaImageModifyNode",
display_name="Luma Image to Image",
category="api node/image/Luma",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.Image.Input(
"image",
),
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt for the image generation",
),
comfy_io.Float.Input(
"image_weight",
default=0.1,
min=0.0,
max=0.98,
step=0.01,
tooltip="Weight of the image; the closer to 1.0, the less the image will be modified.",
),
comfy_io.Combo.Input(
"model",
options=LumaImageModel,
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=0xFFFFFFFFFFFFFFFF,
control_after_generate=True,
tooltip="Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
),
],
outputs=[comfy_io.Image.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": (IO.IMAGE,),
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Prompt for the image generation",
},
),
"image_weight": (
IO.FLOAT,
{
"default": 0.1,
"min": 0.0,
"max": 0.98,
"step": 0.01,
"tooltip": "Weight of the image; the closer to 1.0, the less the image will be modified.",
},
),
"model": ([model.value for model in LumaImageModel],),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 0xFFFFFFFFFFFFFFFF,
"control_after_generate": True,
"tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
},
),
},
"optional": {},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
async def api_call(
self,
async def execute(
cls,
prompt: str,
model: str,
image: torch.Tensor,
image_weight: float,
seed,
unique_id: str = None,
**kwargs,
):
) -> comfy_io.NodeOutput:
auth_kwargs = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
# first, upload image
download_urls = await upload_images_to_comfyapi(
image, max_images=1, auth_kwargs=kwargs,
image, max_images=1, auth_kwargs=auth_kwargs,
)
image_url = download_urls[0]
# next, make Luma call with download url provided
@@ -401,7 +419,7 @@ class LumaImageModifyNode(ComfyNodeABC):
url=image_url, weight=round(max(min(1.0-image_weight, 0.98), 0.0), 2)
),
),
auth_kwargs=kwargs,
auth_kwargs=auth_kwargs,
)
response_api: LumaGeneration = await operation.execute()
@@ -416,88 +434,84 @@ class LumaImageModifyNode(ComfyNodeABC):
failed_statuses=[LumaState.failed],
status_extractor=lambda x: x.state,
result_url_extractor=image_result_url_extractor,
node_id=unique_id,
auth_kwargs=kwargs,
node_id=cls.hidden.unique_id,
auth_kwargs=auth_kwargs,
)
response_poll = await operation.execute()
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.assets.image) as img_response:
img = process_image_response(await img_response.content.read())
return (img,)
return comfy_io.NodeOutput(img)
class LumaTextToVideoGenerationNode(ComfyNodeABC):
class LumaTextToVideoGenerationNode(comfy_io.ComfyNode):
"""
Generates videos synchronously based on prompt and output_size.
"""
RETURN_TYPES = (IO.VIDEO,)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "api_call"
API_NODE = True
CATEGORY = "api node/video/Luma"
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
node_id="LumaVideoNode",
display_name="Luma Text to Video",
category="api node/video/Luma",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt for the video generation",
),
comfy_io.Combo.Input(
"model",
options=LumaVideoModel,
),
comfy_io.Combo.Input(
"aspect_ratio",
options=LumaAspectRatio,
default=LumaAspectRatio.ratio_16_9,
),
comfy_io.Combo.Input(
"resolution",
options=LumaVideoOutputResolution,
default=LumaVideoOutputResolution.res_540p,
),
comfy_io.Combo.Input(
"duration",
options=LumaVideoModelOutputDuration,
),
comfy_io.Boolean.Input(
"loop",
default=False,
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=0xFFFFFFFFFFFFFFFF,
control_after_generate=True,
tooltip="Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
),
comfy_io.Custom(LumaIO.LUMA_CONCEPTS).Input(
"luma_concepts",
tooltip="Optional Camera Concepts to dictate camera motion via the Luma Concepts node.",
optional=True,
)
],
outputs=[comfy_io.Video.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Prompt for the video generation",
},
),
"model": ([model.value for model in LumaVideoModel],),
"aspect_ratio": (
[ratio.value for ratio in LumaAspectRatio],
{
"default": LumaAspectRatio.ratio_16_9,
},
),
"resolution": (
[resolution.value for resolution in LumaVideoOutputResolution],
{
"default": LumaVideoOutputResolution.res_540p,
},
),
"duration": ([dur.value for dur in LumaVideoModelOutputDuration],),
"loop": (
IO.BOOLEAN,
{
"default": False,
},
),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 0xFFFFFFFFFFFFFFFF,
"control_after_generate": True,
"tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
},
),
},
"optional": {
"luma_concepts": (
LumaIO.LUMA_CONCEPTS,
{
"tooltip": "Optional Camera Concepts to dictate camera motion via the Luma Concepts node."
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
async def api_call(
self,
async def execute(
cls,
prompt: str,
model: str,
aspect_ratio: str,
@@ -506,13 +520,15 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC):
loop: bool,
seed,
luma_concepts: LumaConceptChain = None,
unique_id: str = None,
**kwargs,
):
) -> comfy_io.NodeOutput:
validate_string(prompt, strip_whitespace=False, min_length=3)
duration = duration if model != LumaVideoModel.ray_1_6 else None
resolution = resolution if model != LumaVideoModel.ray_1_6 else None
auth_kwargs = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/luma/generations",
@@ -529,12 +545,12 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC):
loop=loop,
concepts=luma_concepts.create_api_model() if luma_concepts else None,
),
auth_kwargs=kwargs,
auth_kwargs=auth_kwargs,
)
response_api: LumaGeneration = await operation.execute()
if unique_id:
PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", unique_id)
if cls.hidden.unique_id:
PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", cls.hidden.unique_id)
operation = PollingOperation(
poll_endpoint=ApiEndpoint(
@@ -547,90 +563,94 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC):
failed_statuses=[LumaState.failed],
status_extractor=lambda x: x.state,
result_url_extractor=video_result_url_extractor,
node_id=unique_id,
node_id=cls.hidden.unique_id,
estimated_duration=LUMA_T2V_AVERAGE_DURATION,
auth_kwargs=kwargs,
auth_kwargs=auth_kwargs,
)
response_poll = await operation.execute()
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.assets.video) as vid_response:
return (VideoFromFile(BytesIO(await vid_response.content.read())),)
return comfy_io.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
class LumaImageToVideoGenerationNode(ComfyNodeABC):
class LumaImageToVideoGenerationNode(comfy_io.ComfyNode):
"""
Generates videos synchronously based on prompt, input images, and output_size.
"""
RETURN_TYPES = (IO.VIDEO,)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "api_call"
API_NODE = True
CATEGORY = "api node/video/Luma"
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
node_id="LumaImageToVideoNode",
display_name="Luma Image to Video",
category="api node/video/Luma",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt for the video generation",
),
comfy_io.Combo.Input(
"model",
options=LumaVideoModel,
),
# comfy_io.Combo.Input(
# "aspect_ratio",
# options=[ratio.value for ratio in LumaAspectRatio],
# default=LumaAspectRatio.ratio_16_9,
# ),
comfy_io.Combo.Input(
"resolution",
options=LumaVideoOutputResolution,
default=LumaVideoOutputResolution.res_540p,
),
comfy_io.Combo.Input(
"duration",
options=[dur.value for dur in LumaVideoModelOutputDuration],
),
comfy_io.Boolean.Input(
"loop",
default=False,
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=0xFFFFFFFFFFFFFFFF,
control_after_generate=True,
tooltip="Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
),
comfy_io.Image.Input(
"first_image",
tooltip="First frame of generated video.",
optional=True,
),
comfy_io.Image.Input(
"last_image",
tooltip="Last frame of generated video.",
optional=True,
),
comfy_io.Custom(LumaIO.LUMA_CONCEPTS).Input(
"luma_concepts",
tooltip="Optional Camera Concepts to dictate camera motion via the Luma Concepts node.",
optional=True,
)
],
outputs=[comfy_io.Video.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Prompt for the video generation",
},
),
"model": ([model.value for model in LumaVideoModel],),
# "aspect_ratio": ([ratio.value for ratio in LumaAspectRatio], {
# "default": LumaAspectRatio.ratio_16_9,
# }),
"resolution": (
[resolution.value for resolution in LumaVideoOutputResolution],
{
"default": LumaVideoOutputResolution.res_540p,
},
),
"duration": ([dur.value for dur in LumaVideoModelOutputDuration],),
"loop": (
IO.BOOLEAN,
{
"default": False,
},
),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 0xFFFFFFFFFFFFFFFF,
"control_after_generate": True,
"tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
},
),
},
"optional": {
"first_image": (
IO.IMAGE,
{"tooltip": "First frame of generated video."},
),
"last_image": (IO.IMAGE, {"tooltip": "Last frame of generated video."}),
"luma_concepts": (
LumaIO.LUMA_CONCEPTS,
{
"tooltip": "Optional Camera Concepts to dictate camera motion via the Luma Concepts node."
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
async def api_call(
self,
async def execute(
cls,
prompt: str,
model: str,
resolution: str,
@@ -640,14 +660,16 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
first_image: torch.Tensor = None,
last_image: torch.Tensor = None,
luma_concepts: LumaConceptChain = None,
unique_id: str = None,
**kwargs,
):
) -> comfy_io.NodeOutput:
if first_image is None and last_image is None:
raise Exception(
"At least one of first_image and last_image requires an input."
)
keyframes = await self._convert_to_keyframes(first_image, last_image, auth_kwargs=kwargs)
auth_kwargs = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
keyframes = await cls._convert_to_keyframes(first_image, last_image, auth_kwargs=auth_kwargs)
duration = duration if model != LumaVideoModel.ray_1_6 else None
resolution = resolution if model != LumaVideoModel.ray_1_6 else None
@@ -668,12 +690,12 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
keyframes=keyframes,
concepts=luma_concepts.create_api_model() if luma_concepts else None,
),
auth_kwargs=kwargs,
auth_kwargs=auth_kwargs,
)
response_api: LumaGeneration = await operation.execute()
if unique_id:
PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", unique_id)
if cls.hidden.unique_id:
PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", cls.hidden.unique_id)
operation = PollingOperation(
poll_endpoint=ApiEndpoint(
@@ -686,18 +708,19 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
failed_statuses=[LumaState.failed],
status_extractor=lambda x: x.state,
result_url_extractor=video_result_url_extractor,
node_id=unique_id,
node_id=cls.hidden.unique_id,
estimated_duration=LUMA_I2V_AVERAGE_DURATION,
auth_kwargs=kwargs,
auth_kwargs=auth_kwargs,
)
response_poll = await operation.execute()
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.assets.video) as vid_response:
return (VideoFromFile(BytesIO(await vid_response.content.read())),)
return comfy_io.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
@classmethod
async def _convert_to_keyframes(
self,
cls,
first_image: torch.Tensor = None,
last_image: torch.Tensor = None,
auth_kwargs: Optional[dict[str,str]] = None,
@@ -719,23 +742,18 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
return LumaKeyframes(frame0=frame0, frame1=frame1)
# A dictionary that contains all nodes you want to export with their names
# NOTE: names should be globally unique
NODE_CLASS_MAPPINGS = {
"LumaImageNode": LumaImageGenerationNode,
"LumaImageModifyNode": LumaImageModifyNode,
"LumaVideoNode": LumaTextToVideoGenerationNode,
"LumaImageToVideoNode": LumaImageToVideoGenerationNode,
"LumaReferenceNode": LumaReferenceNode,
"LumaConceptsNode": LumaConceptsNode,
}
class LumaExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
return [
LumaImageGenerationNode,
LumaImageModifyNode,
LumaTextToVideoGenerationNode,
LumaImageToVideoGenerationNode,
LumaReferenceNode,
LumaConceptsNode,
]
# A dictionary that contains the friendly/humanly readable titles for the nodes
NODE_DISPLAY_NAME_MAPPINGS = {
"LumaImageNode": "Luma Text to Image",
"LumaImageModifyNode": "Luma Image to Image",
"LumaVideoNode": "Luma Text to Video",
"LumaImageToVideoNode": "Luma Image to Video",
"LumaReferenceNode": "Luma Reference",
"LumaConceptsNode": "Luma Concepts",
}
async def comfy_entrypoint() -> LumaExtension:
return LumaExtension()

View File

@@ -500,7 +500,7 @@ class MinimaxHailuoVideoNode(comfy_io.ComfyNode):
raise Exception(
f"No video was found in the response. Full response: {file_result.model_dump()}"
)
logging.info(f"Generated video URL: {file_url}")
logging.info("Generated video URL: %s", file_url)
if cls.hidden.unique_id:
if hasattr(file_result.file, "backup_download_url"):
message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}"

View File

@@ -2,11 +2,7 @@ import logging
from typing import Any, Callable, Optional, TypeVar
import torch
from typing_extensions import override
from comfy_api_nodes.util.validation_utils import (
get_image_dimensions,
validate_image_dimensions,
)
from comfy_api_nodes.util.validation_utils import validate_image_dimensions
from comfy_api_nodes.apis import (
MoonvalleyTextToVideoRequest,
@@ -132,47 +128,6 @@ def validate_prompts(
return True
def validate_input_media(width, height, with_frame_conditioning, num_frames_in=None):
# inference validation
# T = num_frames
# in all cases, the following must be true: T divisible by 16 and H,W by 8. in addition...
# with image conditioning: H*W must be divisible by 8192
# without image conditioning: T divisible by 32
if num_frames_in and not num_frames_in % 16 == 0:
return False, ("The input video total frame count must be divisible by 16!")
if height % 8 != 0 or width % 8 != 0:
return False, (
f"Height ({height}) and width ({width}) must be " "divisible by 8"
)
if with_frame_conditioning:
if (height * width) % 8192 != 0:
return False, (
f"Height * width ({height * width}) must be "
"divisible by 8192 for frame conditioning"
)
else:
if num_frames_in and not num_frames_in % 32 == 0:
return False, ("The input video total frame count must be divisible by 32!")
def validate_input_image(
image: torch.Tensor, with_frame_conditioning: bool = False
) -> None:
"""
Validates the input image adheres to the expectations of the API:
- The image resolution should not be less than 300*300px
- The aspect ratio of the image should be between 1:2.5 ~ 2.5:1
"""
height, width = get_image_dimensions(image)
validate_input_media(width, height, with_frame_conditioning)
validate_image_dimensions(
image, min_width=300, min_height=300, max_height=MAX_HEIGHT, max_width=MAX_WIDTH
)
def validate_video_to_video_input(video: VideoInput) -> VideoInput:
"""
Validates and processes video input for Moonvalley Video-to-Video generation.
@@ -282,7 +237,7 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
audio_stream = None
for stream in input_container.streams:
logging.info(f"Found stream: type={stream.type}, class={type(stream)}")
logging.info("Found stream: type=%s, class=%s", stream.type, type(stream))
if isinstance(stream, av.VideoStream):
# Create output video stream with same parameters
video_stream = output_container.add_stream(
@@ -292,7 +247,7 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
video_stream.height = stream.height
video_stream.pix_fmt = "yuv420p"
logging.info(
f"Added video stream: {stream.width}x{stream.height} @ {stream.average_rate}fps"
"Added video stream: %sx%s @ %sfps", stream.width, stream.height, stream.average_rate
)
elif isinstance(stream, av.AudioStream):
# Create output audio stream with same parameters
@@ -301,9 +256,7 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
)
audio_stream.sample_rate = stream.sample_rate
audio_stream.layout = stream.layout
logging.info(
f"Added audio stream: {stream.sample_rate}Hz, {stream.channels} channels"
)
logging.info("Added audio stream: %sHz, %s channels", stream.sample_rate, stream.channels)
# Calculate target frame count that's divisible by 16
fps = input_container.streams.video[0].average_rate
@@ -333,9 +286,7 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
for packet in video_stream.encode():
output_container.mux(packet)
logging.info(
f"Encoded {frame_count} video frames (target: {target_frames})"
)
logging.info("Encoded %s video frames (target: %s)", frame_count, target_frames)
# Decode and re-encode audio frames
if audio_stream:
@@ -353,7 +304,7 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
for packet in audio_stream.encode():
output_container.mux(packet)
logging.info(f"Encoded {audio_frame_count} audio frames")
logging.info("Encoded %s audio frames", audio_frame_count)
# Close containers
output_container.close()
@@ -380,7 +331,7 @@ def parse_width_height_from_res(resolution: str):
"1:1 (1152 x 1152)": {"width": 1152, "height": 1152},
"4:3 (1536 x 1152)": {"width": 1536, "height": 1152},
"3:4 (1152 x 1536)": {"width": 1152, "height": 1536},
"21:9 (2560 x 1080)": {"width": 2560, "height": 1080},
# "21:9 (2560 x 1080)": {"width": 2560, "height": 1080},
}
return res_map.get(resolution, {"width": 1920, "height": 1080})
@@ -433,11 +384,11 @@ class MoonvalleyImg2VideoNode(comfy_io.ComfyNode):
"negative_prompt",
multiline=True,
default="<synthetic> <scene cut> gopro, bright, contrast, static, overexposed, vignette, "
"artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, "
"flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, "
"cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, "
"blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, "
"wobbly, weird, low quality, plastic, stock footage, video camera, boring",
"artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, "
"flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, "
"cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, "
"blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, "
"wobbly, weird, low quality, plastic, stock footage, video camera, boring",
tooltip="Negative prompt text",
),
comfy_io.Combo.Input(
@@ -448,14 +399,14 @@ class MoonvalleyImg2VideoNode(comfy_io.ComfyNode):
"1:1 (1152 x 1152)",
"4:3 (1536 x 1152)",
"3:4 (1152 x 1536)",
"21:9 (2560 x 1080)",
# "21:9 (2560 x 1080)",
],
default="16:9 (1920 x 1080)",
tooltip="Resolution of the output video",
),
comfy_io.Float.Input(
"prompt_adherence",
default=10.0,
default=4.5,
min=1.0,
max=20.0,
step=1.0,
@@ -469,10 +420,11 @@ class MoonvalleyImg2VideoNode(comfy_io.ComfyNode):
step=1,
display_mode=comfy_io.NumberDisplay.number,
tooltip="Random seed value",
control_after_generate=True,
),
comfy_io.Int.Input(
"steps",
default=100,
default=33,
min=1,
max=100,
step=1,
@@ -499,7 +451,7 @@ class MoonvalleyImg2VideoNode(comfy_io.ComfyNode):
seed: int,
steps: int,
) -> comfy_io.NodeOutput:
validate_input_image(image, True)
validate_image_dimensions(image, min_width=300, min_height=300, max_height=MAX_HEIGHT, max_width=MAX_WIDTH)
validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
width_height = parse_width_height_from_res(resolution)
@@ -513,12 +465,11 @@ class MoonvalleyImg2VideoNode(comfy_io.ComfyNode):
steps=steps,
seed=seed,
guidance_scale=prompt_adherence,
num_frames=128,
width=width_height["width"],
height=width_height["height"],
use_negative_prompts=True,
)
"""Upload image to comfy backend to have a URL available for further processing"""
# Get MIME type from tensor - assuming PNG format for image tensors
mime_type = "image/png"
@@ -571,11 +522,11 @@ class MoonvalleyVideo2VideoNode(comfy_io.ComfyNode):
"negative_prompt",
multiline=True,
default="<synthetic> <scene cut> gopro, bright, contrast, static, overexposed, vignette, "
"artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, "
"flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, "
"cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, "
"blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, "
"wobbly, weird, low quality, plastic, stock footage, video camera, boring",
"artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, "
"flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, "
"cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, "
"blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, "
"wobbly, weird, low quality, plastic, stock footage, video camera, boring",
tooltip="Negative prompt text",
),
comfy_io.Int.Input(
@@ -591,7 +542,7 @@ class MoonvalleyVideo2VideoNode(comfy_io.ComfyNode):
comfy_io.Video.Input(
"video",
tooltip="The reference video used to generate the output video. Must be at least 5 seconds long. "
"Videos longer than 5s will be automatically trimmed. Only MP4 format supported.",
"Videos longer than 5s will be automatically trimmed. Only MP4 format supported.",
),
comfy_io.Combo.Input(
"control_type",
@@ -608,6 +559,15 @@ class MoonvalleyVideo2VideoNode(comfy_io.ComfyNode):
tooltip="Only used if control_type is 'Motion Transfer'",
optional=True,
),
comfy_io.Int.Input(
"steps",
default=33,
min=1,
max=100,
step=1,
display_mode=comfy_io.NumberDisplay.number,
tooltip="Number of inference steps",
),
],
outputs=[comfy_io.Video.Output()],
hidden=[
@@ -627,6 +587,8 @@ class MoonvalleyVideo2VideoNode(comfy_io.ComfyNode):
video: Optional[VideoInput] = None,
control_type: str = "Motion Transfer",
motion_intensity: Optional[int] = 100,
steps=33,
prompt_adherence=4.5,
) -> comfy_io.NodeOutput:
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
@@ -636,7 +598,6 @@ class MoonvalleyVideo2VideoNode(comfy_io.ComfyNode):
validated_video = validate_video_to_video_input(video)
video_url = await upload_video_to_comfyapi(validated_video, auth_kwargs=auth)
"""Validate prompts and inference input"""
validate_prompts(prompt, negative_prompt)
# Only include motion_intensity for Motion Transfer
@@ -648,6 +609,8 @@ class MoonvalleyVideo2VideoNode(comfy_io.ComfyNode):
negative_prompt=negative_prompt,
seed=seed,
control_params=control_params,
steps=steps,
guidance_scale=prompt_adherence,
)
control = parse_control_parameter(control_type)
@@ -699,11 +662,11 @@ class MoonvalleyTxt2VideoNode(comfy_io.ComfyNode):
"negative_prompt",
multiline=True,
default="<synthetic> <scene cut> gopro, bright, contrast, static, overexposed, vignette, "
"artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, "
"flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, "
"cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, "
"blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, "
"wobbly, weird, low quality, plastic, stock footage, video camera, boring",
"artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, "
"flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, "
"cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, "
"blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, "
"wobbly, weird, low quality, plastic, stock footage, video camera, boring",
tooltip="Negative prompt text",
),
comfy_io.Combo.Input(
@@ -721,7 +684,7 @@ class MoonvalleyTxt2VideoNode(comfy_io.ComfyNode):
),
comfy_io.Float.Input(
"prompt_adherence",
default=10.0,
default=4.0,
min=1.0,
max=20.0,
step=1.0,
@@ -734,11 +697,12 @@ class MoonvalleyTxt2VideoNode(comfy_io.ComfyNode):
max=4294967295,
step=1,
display_mode=comfy_io.NumberDisplay.number,
control_after_generate=True,
tooltip="Random seed value",
),
comfy_io.Int.Input(
"steps",
default=100,
default=33,
min=1,
max=100,
step=1,

File diff suppressed because it is too large Load Diff

View File

@@ -1,5 +1,7 @@
from inspect import cleandoc
from typing import Optional
from typing_extensions import override
from io import BytesIO
from comfy_api_nodes.apis.pixverse_api import (
PixverseTextVideoRequest,
PixverseImageVideoRequest,
@@ -26,12 +28,11 @@ from comfy_api_nodes.apinode_utils import (
tensor_to_bytesio,
validate_string,
)
from comfy.comfy_types.node_typing import IO, ComfyNodeABC
from comfy_api.input_impl import VideoFromFile
from comfy_api.latest import ComfyExtension, io as comfy_io
import torch
import aiohttp
from io import BytesIO
AVERAGE_DURATION_T2V = 32
@@ -72,100 +73,101 @@ async def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None):
return response_upload.Resp.img_id
class PixverseTemplateNode:
class PixverseTemplateNode(comfy_io.ComfyNode):
"""
Select template for PixVerse Video generation.
"""
RETURN_TYPES = (PixverseIO.TEMPLATE,)
RETURN_NAMES = ("pixverse_template",)
FUNCTION = "create_template"
CATEGORY = "api node/video/PixVerse"
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
node_id="PixverseTemplateNode",
display_name="PixVerse Template",
category="api node/video/PixVerse",
inputs=[
comfy_io.Combo.Input("template", options=list(pixverse_templates.keys())),
],
outputs=[comfy_io.Custom(PixverseIO.TEMPLATE).Output(display_name="pixverse_template")],
)
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"template": (list(pixverse_templates.keys()),),
}
}
def create_template(self, template: str):
def execute(cls, template: str) -> comfy_io.NodeOutput:
template_id = pixverse_templates.get(template, None)
if template_id is None:
raise Exception(f"Template '{template}' is not recognized.")
# just return the integer
return (template_id,)
return comfy_io.NodeOutput(template_id)
class PixverseTextToVideoNode(ComfyNodeABC):
class PixverseTextToVideoNode(comfy_io.ComfyNode):
"""
Generates videos based on prompt and output_size.
"""
RETURN_TYPES = (IO.VIDEO,)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "api_call"
API_NODE = True
CATEGORY = "api node/video/PixVerse"
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
node_id="PixverseTextToVideoNode",
display_name="PixVerse Text to Video",
category="api node/video/PixVerse",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt for the video generation",
),
comfy_io.Combo.Input(
"aspect_ratio",
options=PixverseAspectRatio,
),
comfy_io.Combo.Input(
"quality",
options=PixverseQuality,
default=PixverseQuality.res_540p,
),
comfy_io.Combo.Input(
"duration_seconds",
options=PixverseDuration,
),
comfy_io.Combo.Input(
"motion_mode",
options=PixverseMotionMode,
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Seed for video generation.",
),
comfy_io.String.Input(
"negative_prompt",
default="",
multiline=True,
tooltip="An optional text description of undesired elements on an image.",
optional=True,
),
comfy_io.Custom(PixverseIO.TEMPLATE).Input(
"pixverse_template",
tooltip="An optional template to influence style of generation, created by the PixVerse Template node.",
optional=True,
),
],
outputs=[comfy_io.Video.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Prompt for the video generation",
},
),
"aspect_ratio": ([ratio.value for ratio in PixverseAspectRatio],),
"quality": (
[resolution.value for resolution in PixverseQuality],
{
"default": PixverseQuality.res_540p,
},
),
"duration_seconds": ([dur.value for dur in PixverseDuration],),
"motion_mode": ([mode.value for mode in PixverseMotionMode],),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 2147483647,
"control_after_generate": True,
"tooltip": "Seed for video generation.",
},
),
},
"optional": {
"negative_prompt": (
IO.STRING,
{
"default": "",
"forceInput": True,
"tooltip": "An optional text description of undesired elements on an image.",
},
),
"pixverse_template": (
PixverseIO.TEMPLATE,
{
"tooltip": "An optional template to influence style of generation, created by the PixVerse Template node."
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
async def api_call(
self,
async def execute(
cls,
prompt: str,
aspect_ratio: str,
quality: str,
@@ -174,9 +176,7 @@ class PixverseTextToVideoNode(ComfyNodeABC):
seed,
negative_prompt: str = None,
pixverse_template: int = None,
unique_id: Optional[str] = None,
**kwargs,
):
) -> comfy_io.NodeOutput:
validate_string(prompt, strip_whitespace=False)
# 1080p is limited to 5 seconds duration
# only normal motion_mode supported for 1080p or for non-5 second duration
@@ -186,6 +186,10 @@ class PixverseTextToVideoNode(ComfyNodeABC):
elif duration_seconds != PixverseDuration.dur_5:
motion_mode = PixverseMotionMode.normal
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/pixverse/video/text/generate",
@@ -203,7 +207,7 @@ class PixverseTextToVideoNode(ComfyNodeABC):
template_id=pixverse_template,
seed=seed,
),
auth_kwargs=kwargs,
auth_kwargs=auth,
)
response_api = await operation.execute()
@@ -224,8 +228,8 @@ class PixverseTextToVideoNode(ComfyNodeABC):
PixverseStatus.deleted,
],
status_extractor=lambda x: x.Resp.status,
auth_kwargs=kwargs,
node_id=unique_id,
auth_kwargs=auth,
node_id=cls.hidden.unique_id,
result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_T2V,
)
@@ -233,77 +237,75 @@ class PixverseTextToVideoNode(ComfyNodeABC):
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.Resp.url) as vid_response:
return (VideoFromFile(BytesIO(await vid_response.content.read())),)
return comfy_io.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
class PixverseImageToVideoNode(ComfyNodeABC):
class PixverseImageToVideoNode(comfy_io.ComfyNode):
"""
Generates videos based on prompt and output_size.
"""
RETURN_TYPES = (IO.VIDEO,)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "api_call"
API_NODE = True
CATEGORY = "api node/video/PixVerse"
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
node_id="PixverseImageToVideoNode",
display_name="PixVerse Image to Video",
category="api node/video/PixVerse",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.Image.Input("image"),
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt for the video generation",
),
comfy_io.Combo.Input(
"quality",
options=PixverseQuality,
default=PixverseQuality.res_540p,
),
comfy_io.Combo.Input(
"duration_seconds",
options=PixverseDuration,
),
comfy_io.Combo.Input(
"motion_mode",
options=PixverseMotionMode,
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Seed for video generation.",
),
comfy_io.String.Input(
"negative_prompt",
default="",
multiline=True,
tooltip="An optional text description of undesired elements on an image.",
optional=True,
),
comfy_io.Custom(PixverseIO.TEMPLATE).Input(
"pixverse_template",
tooltip="An optional template to influence style of generation, created by the PixVerse Template node.",
optional=True,
),
],
outputs=[comfy_io.Video.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": (IO.IMAGE,),
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Prompt for the video generation",
},
),
"quality": (
[resolution.value for resolution in PixverseQuality],
{
"default": PixverseQuality.res_540p,
},
),
"duration_seconds": ([dur.value for dur in PixverseDuration],),
"motion_mode": ([mode.value for mode in PixverseMotionMode],),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 2147483647,
"control_after_generate": True,
"tooltip": "Seed for video generation.",
},
),
},
"optional": {
"negative_prompt": (
IO.STRING,
{
"default": "",
"forceInput": True,
"tooltip": "An optional text description of undesired elements on an image.",
},
),
"pixverse_template": (
PixverseIO.TEMPLATE,
{
"tooltip": "An optional template to influence style of generation, created by the PixVerse Template node."
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
async def api_call(
self,
async def execute(
cls,
image: torch.Tensor,
prompt: str,
quality: str,
@@ -312,11 +314,13 @@ class PixverseImageToVideoNode(ComfyNodeABC):
seed,
negative_prompt: str = None,
pixverse_template: int = None,
unique_id: Optional[str] = None,
**kwargs,
):
) -> comfy_io.NodeOutput:
validate_string(prompt, strip_whitespace=False)
img_id = await upload_image_to_pixverse(image, auth_kwargs=kwargs)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
img_id = await upload_image_to_pixverse(image, auth_kwargs=auth)
# 1080p is limited to 5 seconds duration
# only normal motion_mode supported for 1080p or for non-5 second duration
@@ -343,7 +347,7 @@ class PixverseImageToVideoNode(ComfyNodeABC):
template_id=pixverse_template,
seed=seed,
),
auth_kwargs=kwargs,
auth_kwargs=auth,
)
response_api = await operation.execute()
@@ -364,8 +368,8 @@ class PixverseImageToVideoNode(ComfyNodeABC):
PixverseStatus.deleted,
],
status_extractor=lambda x: x.Resp.status,
auth_kwargs=kwargs,
node_id=unique_id,
auth_kwargs=auth,
node_id=cls.hidden.unique_id,
result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_I2V,
)
@@ -373,72 +377,71 @@ class PixverseImageToVideoNode(ComfyNodeABC):
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.Resp.url) as vid_response:
return (VideoFromFile(BytesIO(await vid_response.content.read())),)
return comfy_io.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
class PixverseTransitionVideoNode(ComfyNodeABC):
class PixverseTransitionVideoNode(comfy_io.ComfyNode):
"""
Generates videos based on prompt and output_size.
"""
RETURN_TYPES = (IO.VIDEO,)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "api_call"
API_NODE = True
CATEGORY = "api node/video/PixVerse"
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
node_id="PixverseTransitionVideoNode",
display_name="PixVerse Transition Video",
category="api node/video/PixVerse",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.Image.Input("first_frame"),
comfy_io.Image.Input("last_frame"),
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt for the video generation",
),
comfy_io.Combo.Input(
"quality",
options=PixverseQuality,
default=PixverseQuality.res_540p,
),
comfy_io.Combo.Input(
"duration_seconds",
options=PixverseDuration,
),
comfy_io.Combo.Input(
"motion_mode",
options=PixverseMotionMode,
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Seed for video generation.",
),
comfy_io.String.Input(
"negative_prompt",
default="",
multiline=True,
tooltip="An optional text description of undesired elements on an image.",
optional=True,
),
],
outputs=[comfy_io.Video.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"first_frame": (IO.IMAGE,),
"last_frame": (IO.IMAGE,),
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Prompt for the video generation",
},
),
"quality": (
[resolution.value for resolution in PixverseQuality],
{
"default": PixverseQuality.res_540p,
},
),
"duration_seconds": ([dur.value for dur in PixverseDuration],),
"motion_mode": ([mode.value for mode in PixverseMotionMode],),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 2147483647,
"control_after_generate": True,
"tooltip": "Seed for video generation.",
},
),
},
"optional": {
"negative_prompt": (
IO.STRING,
{
"default": "",
"forceInput": True,
"tooltip": "An optional text description of undesired elements on an image.",
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
async def api_call(
self,
async def execute(
cls,
first_frame: torch.Tensor,
last_frame: torch.Tensor,
prompt: str,
@@ -447,12 +450,14 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
motion_mode: str,
seed,
negative_prompt: str = None,
unique_id: Optional[str] = None,
**kwargs,
):
) -> comfy_io.NodeOutput:
validate_string(prompt, strip_whitespace=False)
first_frame_id = await upload_image_to_pixverse(first_frame, auth_kwargs=kwargs)
last_frame_id = await upload_image_to_pixverse(last_frame, auth_kwargs=kwargs)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
first_frame_id = await upload_image_to_pixverse(first_frame, auth_kwargs=auth)
last_frame_id = await upload_image_to_pixverse(last_frame, auth_kwargs=auth)
# 1080p is limited to 5 seconds duration
# only normal motion_mode supported for 1080p or for non-5 second duration
@@ -479,7 +484,7 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
negative_prompt=negative_prompt if negative_prompt else None,
seed=seed,
),
auth_kwargs=kwargs,
auth_kwargs=auth,
)
response_api = await operation.execute()
@@ -500,8 +505,8 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
PixverseStatus.deleted,
],
status_extractor=lambda x: x.Resp.status,
auth_kwargs=kwargs,
node_id=unique_id,
auth_kwargs=auth,
node_id=cls.hidden.unique_id,
result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_T2V,
)
@@ -509,19 +514,19 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.Resp.url) as vid_response:
return (VideoFromFile(BytesIO(await vid_response.content.read())),)
return comfy_io.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
NODE_CLASS_MAPPINGS = {
"PixverseTextToVideoNode": PixverseTextToVideoNode,
"PixverseImageToVideoNode": PixverseImageToVideoNode,
"PixverseTransitionVideoNode": PixverseTransitionVideoNode,
"PixverseTemplateNode": PixverseTemplateNode,
}
class PixVerseExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
return [
PixverseTextToVideoNode,
PixverseImageToVideoNode,
PixverseTransitionVideoNode,
PixverseTemplateNode,
]
NODE_DISPLAY_NAME_MAPPINGS = {
"PixverseTextToVideoNode": "PixVerse Text to Video",
"PixverseImageToVideoNode": "PixVerse Image to Video",
"PixverseTransitionVideoNode": "PixVerse Transition Video",
"PixverseTemplateNode": "PixVerse Template",
}
async def comfy_entrypoint() -> PixVerseExtension:
return PixVerseExtension()

View File

@@ -35,57 +35,64 @@ from server import PromptServer
import torch
from io import BytesIO
from PIL import UnidentifiedImageError
import aiohttp
async def handle_recraft_file_request(
image: torch.Tensor,
path: str,
mask: torch.Tensor=None,
total_pixels=4096*4096,
timeout=1024,
request=None,
auth_kwargs: dict[str,str] = None,
) -> list[BytesIO]:
"""
Handle sending common Recraft file-only request to get back file bytes.
"""
if request is None:
request = EmptyRequest()
files = {
'image': tensor_to_bytesio(image, total_pixels=total_pixels).read()
}
if mask is not None:
files['mask'] = tensor_to_bytesio(mask, total_pixels=total_pixels).read()
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=path,
method=HttpMethod.POST,
request_model=type(request),
response_model=RecraftImageGenerationResponse,
),
request=request,
files=files,
content_type="multipart/form-data",
auth_kwargs=auth_kwargs,
multipart_parser=recraft_multipart_parser,
)
response: RecraftImageGenerationResponse = await operation.execute()
all_bytesio = []
if response.image is not None:
all_bytesio.append(await download_url_to_bytesio(response.image.url, timeout=timeout))
else:
for data in response.data:
all_bytesio.append(await download_url_to_bytesio(data.url, timeout=timeout))
return all_bytesio
def recraft_multipart_parser(data, parent_key=None, formatter: callable=None, converted_to_check: list[list]=None, is_list=False) -> dict:
image: torch.Tensor,
path: str,
mask: torch.Tensor=None,
total_pixels=4096*4096,
timeout=1024,
request=None,
auth_kwargs: dict[str,str] = None,
) -> list[BytesIO]:
"""
Formats data such that multipart/form-data will work with requests library
when both files and data are present.
Handle sending common Recraft file-only request to get back file bytes.
"""
if request is None:
request = EmptyRequest()
files = {
'image': tensor_to_bytesio(image, total_pixels=total_pixels).read()
}
if mask is not None:
files['mask'] = tensor_to_bytesio(mask, total_pixels=total_pixels).read()
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=path,
method=HttpMethod.POST,
request_model=type(request),
response_model=RecraftImageGenerationResponse,
),
request=request,
files=files,
content_type="multipart/form-data",
auth_kwargs=auth_kwargs,
multipart_parser=recraft_multipart_parser,
)
response: RecraftImageGenerationResponse = await operation.execute()
all_bytesio = []
if response.image is not None:
all_bytesio.append(await download_url_to_bytesio(response.image.url, timeout=timeout))
else:
for data in response.data:
all_bytesio.append(await download_url_to_bytesio(data.url, timeout=timeout))
return all_bytesio
def recraft_multipart_parser(
data,
parent_key=None,
formatter: callable = None,
converted_to_check: list[list] = None,
is_list: bool = False,
return_mode: str = "formdata" # "dict" | "formdata"
) -> dict | aiohttp.FormData:
"""
Formats data such that multipart/form-data will work with aiohttp library when both files and data are present.
The OpenAI client that Recraft uses has a bizarre way of serializing lists:
@@ -103,23 +110,23 @@ def recraft_multipart_parser(data, parent_key=None, formatter: callable=None, co
# Modification of a function that handled a different type of multipart parsing, big ups:
# https://gist.github.com/kazqvaizer/4cebebe5db654a414132809f9f88067b
def handle_converted_lists(data, parent_key, lists_to_check=tuple[list]):
def handle_converted_lists(item, parent_key, lists_to_check=tuple[list]):
# if list already exists exists, just extend list with data
for check_list in lists_to_check:
for conv_tuple in check_list:
if conv_tuple[0] == parent_key and type(conv_tuple[1]) is list:
conv_tuple[1].append(formatter(data))
if conv_tuple[0] == parent_key and isinstance(conv_tuple[1], list):
conv_tuple[1].append(formatter(item))
return True
return False
if converted_to_check is None:
converted_to_check = []
effective_mode = return_mode if parent_key is None else "dict"
if formatter is None:
formatter = lambda v: v # Multipart representation of value
if type(data) is not dict:
if not isinstance(data, dict):
# if list already exists exists, just extend list with data
added = handle_converted_lists(data, parent_key, converted_to_check)
if added:
@@ -136,15 +143,24 @@ def recraft_multipart_parser(data, parent_key=None, formatter: callable=None, co
for key, value in data.items():
current_key = key if parent_key is None else f"{parent_key}[{key}]"
if type(value) is dict:
if isinstance(value, dict):
converted.extend(recraft_multipart_parser(value, current_key, formatter, next_check).items())
elif type(value) is list:
elif isinstance(value, list):
for ind, list_value in enumerate(value):
iter_key = f"{current_key}[]"
converted.extend(recraft_multipart_parser(list_value, iter_key, formatter, next_check, is_list=True).items())
else:
converted.append((current_key, formatter(value)))
if effective_mode == "formdata":
fd = aiohttp.FormData()
for k, v in dict(converted).items():
if isinstance(v, list):
for item in v:
fd.add_field(k, str(item))
else:
fd.add_field(k, str(v))
return fd
return dict(converted)

File diff suppressed because it is too large Load Diff

View File

@@ -200,11 +200,11 @@ class RunwayImageToVideoNodeGen3a(comfy_io.ComfyNode):
),
comfy_io.Combo.Input(
"duration",
options=[model.value for model in Duration],
options=Duration,
),
comfy_io.Combo.Input(
"ratio",
options=[model.value for model in RunwayGen3aAspectRatio],
options=RunwayGen3aAspectRatio,
),
comfy_io.Int.Input(
"seed",
@@ -300,11 +300,11 @@ class RunwayImageToVideoNodeGen4(comfy_io.ComfyNode):
),
comfy_io.Combo.Input(
"duration",
options=[model.value for model in Duration],
options=Duration,
),
comfy_io.Combo.Input(
"ratio",
options=[model.value for model in RunwayGen4TurboAspectRatio],
options=RunwayGen4TurboAspectRatio,
),
comfy_io.Int.Input(
"seed",
@@ -408,11 +408,11 @@ class RunwayFirstLastFrameNode(comfy_io.ComfyNode):
),
comfy_io.Combo.Input(
"duration",
options=[model.value for model in Duration],
options=Duration,
),
comfy_io.Combo.Input(
"ratio",
options=[model.value for model in RunwayGen3aAspectRatio],
options=RunwayGen3aAspectRatio,
),
comfy_io.Int.Input(
"seed",

View File

@@ -0,0 +1,175 @@
from typing import Optional
from typing_extensions import override
import torch
from pydantic import BaseModel, Field
from comfy_api.latest import ComfyExtension, io as comfy_io
from comfy_api_nodes.apis.client import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
PollingOperation,
EmptyRequest,
)
from comfy_api_nodes.util.validation_utils import get_number_of_images
from comfy_api_nodes.apinode_utils import (
download_url_to_video_output,
tensor_to_bytesio,
)
class Sora2GenerationRequest(BaseModel):
prompt: str = Field(...)
model: str = Field(...)
seconds: str = Field(...)
size: str = Field(...)
class Sora2GenerationResponse(BaseModel):
id: str = Field(...)
error: Optional[dict] = Field(None)
status: Optional[str] = Field(None)
class OpenAIVideoSora2(comfy_io.ComfyNode):
@classmethod
def define_schema(cls):
return comfy_io.Schema(
node_id="OpenAIVideoSora2",
display_name="OpenAI Sora - Video",
category="api node/video/Sora",
description="OpenAI video and audio generation.",
inputs=[
comfy_io.Combo.Input(
"model",
options=["sora-2", "sora-2-pro"],
default="sora-2",
),
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Guiding text; may be empty if an input image is present.",
),
comfy_io.Combo.Input(
"size",
options=[
"720x1280",
"1280x720",
"1024x1792",
"1792x1024",
],
default="1280x720",
),
comfy_io.Combo.Input(
"duration",
options=[4, 8, 12],
default=8,
),
comfy_io.Image.Input(
"image",
optional=True,
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
display_mode=comfy_io.NumberDisplay.number,
control_after_generate=True,
optional=True,
tooltip="Seed to determine if node should re-run; "
"actual results are nondeterministic regardless of seed.",
),
],
outputs=[
comfy_io.Video.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
model: str,
prompt: str,
size: str = "1280x720",
duration: int = 8,
seed: int = 0,
image: Optional[torch.Tensor] = None,
):
if model == "sora-2" and size not in ("720x1280", "1280x720"):
raise ValueError("Invalid size for sora-2 model, only 720x1280 and 1280x720 are supported.")
files_input = None
if image is not None:
if get_number_of_images(image) != 1:
raise ValueError("Currently only one input image is supported.")
files_input = {"input_reference": ("image.png", tensor_to_bytesio(image), "image/png")}
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
payload = Sora2GenerationRequest(
model=model,
prompt=prompt,
seconds=str(duration),
size=size,
)
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/openai/v1/videos",
method=HttpMethod.POST,
request_model=Sora2GenerationRequest,
response_model=Sora2GenerationResponse
),
request=payload,
files=files_input,
auth_kwargs=auth,
content_type="multipart/form-data",
)
initial_response = await initial_operation.execute()
if initial_response.error:
raise Exception(initial_response.error.message)
model_time_multiplier = 1 if model == "sora-2" else 2
poll_operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"/proxy/openai/v1/videos/{initial_response.id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=Sora2GenerationResponse
),
completed_statuses=["completed"],
failed_statuses=["failed"],
status_extractor=lambda x: x.status,
auth_kwargs=auth,
poll_interval=8.0,
max_poll_attempts=160,
node_id=cls.hidden.unique_id,
estimated_duration=45 * (duration / 4) * model_time_multiplier,
)
await poll_operation.execute()
return comfy_io.NodeOutput(
await download_url_to_video_output(
f"/proxy/openai/v1/videos/{initial_response.id}/content",
auth_kwargs=auth,
)
)
class OpenAISoraExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
return [
OpenAIVideoSora2,
]
async def comfy_entrypoint() -> OpenAISoraExtension:
return OpenAISoraExtension()

View File

@@ -82,8 +82,8 @@ class StabilityStableImageUltraNode(comfy_io.ComfyNode):
),
comfy_io.Combo.Input(
"aspect_ratio",
options=[x.value for x in StabilityAspectRatio],
default=StabilityAspectRatio.ratio_1_1.value,
options=StabilityAspectRatio,
default=StabilityAspectRatio.ratio_1_1,
tooltip="Aspect ratio of generated image.",
),
comfy_io.Combo.Input(
@@ -217,12 +217,12 @@ class StabilityStableImageSD_3_5Node(comfy_io.ComfyNode):
),
comfy_io.Combo.Input(
"model",
options=[x.value for x in Stability_SD3_5_Model],
options=Stability_SD3_5_Model,
),
comfy_io.Combo.Input(
"aspect_ratio",
options=[x.value for x in StabilityAspectRatio],
default=StabilityAspectRatio.ratio_1_1.value,
options=StabilityAspectRatio,
default=StabilityAspectRatio.ratio_1_1,
tooltip="Aspect ratio of generated image.",
),
comfy_io.Combo.Input(

View File

@@ -215,7 +215,7 @@ class VeoVideoGenerationNode(comfy_io.ComfyNode):
initial_response = await initial_operation.execute()
operation_name = initial_response.name
logging.info(f"Veo generation started with operation name: {operation_name}")
logging.info("Veo generation started with operation name: %s", operation_name)
# Define status extractor function
def status_extractor(response):

View File

@@ -173,8 +173,8 @@ class ViduTextToVideoNode(comfy_io.ComfyNode):
inputs=[
comfy_io.Combo.Input(
"model",
options=[model.value for model in VideoModelName],
default=VideoModelName.vidu_q1.value,
options=VideoModelName,
default=VideoModelName.vidu_q1,
tooltip="Model name",
),
comfy_io.String.Input(
@@ -205,22 +205,22 @@ class ViduTextToVideoNode(comfy_io.ComfyNode):
),
comfy_io.Combo.Input(
"aspect_ratio",
options=[model.value for model in AspectRatio],
default=AspectRatio.r_16_9.value,
options=AspectRatio,
default=AspectRatio.r_16_9,
tooltip="The aspect ratio of the output video",
optional=True,
),
comfy_io.Combo.Input(
"resolution",
options=[model.value for model in Resolution],
default=Resolution.r_1080p.value,
options=Resolution,
default=Resolution.r_1080p,
tooltip="Supported values may vary by model & duration",
optional=True,
),
comfy_io.Combo.Input(
"movement_amplitude",
options=[model.value for model in MovementAmplitude],
default=MovementAmplitude.auto.value,
options=MovementAmplitude,
default=MovementAmplitude.auto,
tooltip="The movement amplitude of objects in the frame",
optional=True,
),
@@ -278,8 +278,8 @@ class ViduImageToVideoNode(comfy_io.ComfyNode):
inputs=[
comfy_io.Combo.Input(
"model",
options=[model.value for model in VideoModelName],
default=VideoModelName.vidu_q1.value,
options=VideoModelName,
default=VideoModelName.vidu_q1,
tooltip="Model name",
),
comfy_io.Image.Input(
@@ -316,14 +316,14 @@ class ViduImageToVideoNode(comfy_io.ComfyNode):
),
comfy_io.Combo.Input(
"resolution",
options=[model.value for model in Resolution],
default=Resolution.r_1080p.value,
options=Resolution,
default=Resolution.r_1080p,
tooltip="Supported values may vary by model & duration",
optional=True,
),
comfy_io.Combo.Input(
"movement_amplitude",
options=[model.value for model in MovementAmplitude],
options=MovementAmplitude,
default=MovementAmplitude.auto.value,
tooltip="The movement amplitude of objects in the frame",
optional=True,
@@ -388,8 +388,8 @@ class ViduReferenceVideoNode(comfy_io.ComfyNode):
inputs=[
comfy_io.Combo.Input(
"model",
options=[model.value for model in VideoModelName],
default=VideoModelName.vidu_q1.value,
options=VideoModelName,
default=VideoModelName.vidu_q1,
tooltip="Model name",
),
comfy_io.Image.Input(
@@ -424,8 +424,8 @@ class ViduReferenceVideoNode(comfy_io.ComfyNode):
),
comfy_io.Combo.Input(
"aspect_ratio",
options=[model.value for model in AspectRatio],
default=AspectRatio.r_16_9.value,
options=AspectRatio,
default=AspectRatio.r_16_9,
tooltip="The aspect ratio of the output video",
optional=True,
),

View File

@@ -28,6 +28,12 @@ class Text2ImageInputField(BaseModel):
negative_prompt: Optional[str] = Field(None)
class Image2ImageInputField(BaseModel):
prompt: str = Field(...)
negative_prompt: Optional[str] = Field(None)
images: list[str] = Field(..., min_length=1, max_length=2)
class Text2VideoInputField(BaseModel):
prompt: str = Field(...)
negative_prompt: Optional[str] = Field(None)
@@ -49,6 +55,13 @@ class Txt2ImageParametersField(BaseModel):
watermark: bool = Field(True)
class Image2ImageParametersField(BaseModel):
size: Optional[str] = Field(None)
n: int = Field(1, description="Number of images to generate.") # we support only value=1
seed: int = Field(..., ge=0, le=2147483647)
watermark: bool = Field(True)
class Text2VideoParametersField(BaseModel):
size: str = Field(...)
seed: int = Field(..., ge=0, le=2147483647)
@@ -73,6 +86,12 @@ class Text2ImageTaskCreationRequest(BaseModel):
parameters: Txt2ImageParametersField = Field(...)
class Image2ImageTaskCreationRequest(BaseModel):
model: str = Field(...)
input: Image2ImageInputField = Field(...)
parameters: Image2ImageParametersField = Field(...)
class Text2VideoTaskCreationRequest(BaseModel):
model: str = Field(...)
input: Text2VideoInputField = Field(...)
@@ -135,7 +154,12 @@ async def process_task(
url: str,
request_model: Type[T],
response_model: Type[R],
payload: Union[Text2ImageTaskCreationRequest, Text2VideoTaskCreationRequest, Image2VideoTaskCreationRequest],
payload: Union[
Text2ImageTaskCreationRequest,
Image2ImageTaskCreationRequest,
Text2VideoTaskCreationRequest,
Image2VideoTaskCreationRequest,
],
node_id: str,
estimated_duration: int,
poll_interval: int,
@@ -288,6 +312,128 @@ class WanTextToImageApi(comfy_io.ComfyNode):
return comfy_io.NodeOutput(await download_url_to_image_tensor(str(response.output.results[0].url)))
class WanImageToImageApi(comfy_io.ComfyNode):
@classmethod
def define_schema(cls):
return comfy_io.Schema(
node_id="WanImageToImageApi",
display_name="Wan Image to Image",
category="api node/image/Wan",
description="Generates an image from one or two input images and a text prompt. "
"The output image is currently fixed at 1.6 MP; its aspect ratio matches the input image(s).",
inputs=[
comfy_io.Combo.Input(
"model",
options=["wan2.5-i2i-preview"],
default="wan2.5-i2i-preview",
tooltip="Model to use.",
),
comfy_io.Image.Input(
"image",
tooltip="Single-image editing or multi-image fusion, maximum 2 images.",
),
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.",
),
comfy_io.String.Input(
"negative_prompt",
multiline=True,
default="",
tooltip="Negative text prompt to guide what to avoid.",
optional=True,
),
# redo this later as an optional combo of recommended resolutions
# comfy_io.Int.Input(
# "width",
# default=1280,
# min=384,
# max=1440,
# step=16,
# optional=True,
# ),
# comfy_io.Int.Input(
# "height",
# default=1280,
# min=384,
# max=1440,
# step=16,
# optional=True,
# ),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
display_mode=comfy_io.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed to use for generation.",
optional=True,
),
comfy_io.Boolean.Input(
"watermark",
default=True,
tooltip="Whether to add an \"AI generated\" watermark to the result.",
optional=True,
),
],
outputs=[
comfy_io.Image.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
model: str,
image: torch.Tensor,
prompt: str,
negative_prompt: str = "",
# width: int = 1024,
# height: int = 1024,
seed: int = 0,
watermark: bool = True,
):
n_images = get_number_of_images(image)
if n_images not in (1, 2):
raise ValueError(f"Expected 1 or 2 input images, got {n_images}.")
images = []
for i in image:
images.append("data:image/png;base64," + tensor_to_base64_string(i, total_pixels=4096*4096))
payload = Image2ImageTaskCreationRequest(
model=model,
input=Image2ImageInputField(prompt=prompt, negative_prompt=negative_prompt, images=images),
parameters=Image2ImageParametersField(
# size=f"{width}*{height}",
seed=seed,
watermark=watermark,
),
)
response = await process_task(
{
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
"/proxy/wan/api/v1/services/aigc/image2image/image-synthesis",
request_model=Image2ImageTaskCreationRequest,
response_model=ImageTaskStatusResponse,
payload=payload,
node_id=cls.hidden.unique_id,
estimated_duration=42,
poll_interval=3,
)
return comfy_io.NodeOutput(await download_url_to_image_tensor(str(response.output.results[0].url)))
class WanTextToVideoApi(comfy_io.ComfyNode):
@classmethod
def define_schema(cls):
@@ -593,6 +739,7 @@ class WanApiExtension(ComfyExtension):
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
return [
WanTextToImageApi,
WanImageToImageApi,
WanTextToVideoApi,
WanImageToVideoApi,
]

View File

@@ -360,7 +360,7 @@ class RecordAudio:
def load(self, audio):
audio_path = folder_paths.get_annotated_filepath(audio)
waveform, sample_rate = torchaudio.load(audio_path)
waveform, sample_rate = load(audio_path)
audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
return (audio, )

View File

@@ -1,44 +1,62 @@
import folder_paths
import comfy.audio_encoders.audio_encoders
import comfy.utils
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
class AudioEncoderLoader:
class AudioEncoderLoader(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "audio_encoder_name": (folder_paths.get_filename_list("audio_encoders"), ),
}}
RETURN_TYPES = ("AUDIO_ENCODER",)
FUNCTION = "load_model"
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="AudioEncoderLoader",
category="loaders",
inputs=[
io.Combo.Input(
"audio_encoder_name",
options=folder_paths.get_filename_list("audio_encoders"),
),
],
outputs=[io.AudioEncoder.Output()],
)
CATEGORY = "loaders"
def load_model(self, audio_encoder_name):
@classmethod
def execute(cls, audio_encoder_name) -> io.NodeOutput:
audio_encoder_name = folder_paths.get_full_path_or_raise("audio_encoders", audio_encoder_name)
sd = comfy.utils.load_torch_file(audio_encoder_name, safe_load=True)
audio_encoder = comfy.audio_encoders.audio_encoders.load_audio_encoder_from_sd(sd)
if audio_encoder is None:
raise RuntimeError("ERROR: audio encoder file is invalid and does not contain a valid model.")
return (audio_encoder,)
return io.NodeOutput(audio_encoder)
class AudioEncoderEncode:
class AudioEncoderEncode(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "audio_encoder": ("AUDIO_ENCODER",),
"audio": ("AUDIO",),
}}
RETURN_TYPES = ("AUDIO_ENCODER_OUTPUT",)
FUNCTION = "encode"
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="AudioEncoderEncode",
category="conditioning",
inputs=[
io.AudioEncoder.Input("audio_encoder"),
io.Audio.Input("audio"),
],
outputs=[io.AudioEncoderOutput.Output()],
)
CATEGORY = "conditioning"
def encode(self, audio_encoder, audio):
@classmethod
def execute(cls, audio_encoder, audio) -> io.NodeOutput:
output = audio_encoder.encode_audio(audio["waveform"], audio["sample_rate"])
return (output,)
return io.NodeOutput(output)
NODE_CLASS_MAPPINGS = {
"AudioEncoderLoader": AudioEncoderLoader,
"AudioEncoderEncode": AudioEncoderEncode,
}
class AudioEncoder(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
AudioEncoderLoader,
AudioEncoderEncode,
]
async def comfy_entrypoint() -> AudioEncoder:
return AudioEncoder()

View File

@@ -1,6 +1,9 @@
import torch
import comfy.utils
from enum import Enum
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
def resize_mask(mask, shape):
return torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[0], shape[1]), mode="bilinear").squeeze(1)
@@ -101,24 +104,28 @@ def porter_duff_composite(src_image: torch.Tensor, src_alpha: torch.Tensor, dst_
return out_image, out_alpha
class PorterDuffImageComposite:
class PorterDuffImageComposite(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"source": ("IMAGE",),
"source_alpha": ("MASK",),
"destination": ("IMAGE",),
"destination_alpha": ("MASK",),
"mode": ([mode.name for mode in PorterDuffMode], {"default": PorterDuffMode.DST.name}),
},
}
def define_schema(cls):
return io.Schema(
node_id="PorterDuffImageComposite",
display_name="Porter-Duff Image Composite",
category="mask/compositing",
inputs=[
io.Image.Input("source"),
io.Mask.Input("source_alpha"),
io.Image.Input("destination"),
io.Mask.Input("destination_alpha"),
io.Combo.Input("mode", options=[mode.name for mode in PorterDuffMode], default=PorterDuffMode.DST.name),
],
outputs=[
io.Image.Output(),
io.Mask.Output(),
],
)
RETURN_TYPES = ("IMAGE", "MASK")
FUNCTION = "composite"
CATEGORY = "mask/compositing"
def composite(self, source: torch.Tensor, source_alpha: torch.Tensor, destination: torch.Tensor, destination_alpha: torch.Tensor, mode):
@classmethod
def execute(cls, source: torch.Tensor, source_alpha: torch.Tensor, destination: torch.Tensor, destination_alpha: torch.Tensor, mode) -> io.NodeOutput:
batch_size = min(len(source), len(source_alpha), len(destination), len(destination_alpha))
out_images = []
out_alphas = []
@@ -150,45 +157,48 @@ class PorterDuffImageComposite:
out_images.append(out_image)
out_alphas.append(out_alpha.squeeze(2))
result = (torch.stack(out_images), torch.stack(out_alphas))
return result
return io.NodeOutput(torch.stack(out_images), torch.stack(out_alphas))
class SplitImageWithAlpha:
class SplitImageWithAlpha(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": ("IMAGE",),
}
}
def define_schema(cls):
return io.Schema(
node_id="SplitImageWithAlpha",
display_name="Split Image with Alpha",
category="mask/compositing",
inputs=[
io.Image.Input("image"),
],
outputs=[
io.Image.Output(),
io.Mask.Output(),
],
)
CATEGORY = "mask/compositing"
RETURN_TYPES = ("IMAGE", "MASK")
FUNCTION = "split_image_with_alpha"
def split_image_with_alpha(self, image: torch.Tensor):
@classmethod
def execute(cls, image: torch.Tensor) -> io.NodeOutput:
out_images = [i[:,:,:3] for i in image]
out_alphas = [i[:,:,3] if i.shape[2] > 3 else torch.ones_like(i[:,:,0]) for i in image]
result = (torch.stack(out_images), 1.0 - torch.stack(out_alphas))
return result
return io.NodeOutput(torch.stack(out_images), 1.0 - torch.stack(out_alphas))
class JoinImageWithAlpha:
class JoinImageWithAlpha(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": ("IMAGE",),
"alpha": ("MASK",),
}
}
def define_schema(cls):
return io.Schema(
node_id="JoinImageWithAlpha",
display_name="Join Image with Alpha",
category="mask/compositing",
inputs=[
io.Image.Input("image"),
io.Mask.Input("alpha"),
],
outputs=[io.Image.Output()],
)
CATEGORY = "mask/compositing"
RETURN_TYPES = ("IMAGE",)
FUNCTION = "join_image_with_alpha"
def join_image_with_alpha(self, image: torch.Tensor, alpha: torch.Tensor):
@classmethod
def execute(cls, image: torch.Tensor, alpha: torch.Tensor) -> io.NodeOutput:
batch_size = min(len(image), len(alpha))
out_images = []
@@ -196,19 +206,18 @@ class JoinImageWithAlpha:
for i in range(batch_size):
out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2))
result = (torch.stack(out_images),)
return result
return io.NodeOutput(torch.stack(out_images))
NODE_CLASS_MAPPINGS = {
"PorterDuffImageComposite": PorterDuffImageComposite,
"SplitImageWithAlpha": SplitImageWithAlpha,
"JoinImageWithAlpha": JoinImageWithAlpha,
}
class CompositingExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
PorterDuffImageComposite,
SplitImageWithAlpha,
JoinImageWithAlpha,
]
NODE_DISPLAY_NAME_MAPPINGS = {
"PorterDuffImageComposite": "Porter-Duff Image Composite",
"SplitImageWithAlpha": "Split Image with Alpha",
"JoinImageWithAlpha": "Join Image with Alpha",
}
async def comfy_entrypoint() -> CompositingExtension:
return CompositingExtension()

View File

@@ -1,34 +1,41 @@
# code adapted from https://github.com/exx8/differential-diffusion
from typing_extensions import override
import torch
from comfy_api.latest import ComfyExtension, io
class DifferentialDiffusion():
class DifferentialDiffusion(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model": ("MODEL", ),
},
"optional": {
"strength": ("FLOAT", {
"default": 1.0,
"min": 0.0,
"max": 1.0,
"step": 0.01,
}),
}
}
RETURN_TYPES = ("MODEL",)
FUNCTION = "apply"
CATEGORY = "_for_testing"
INIT = False
def define_schema(cls):
return io.Schema(
node_id="DifferentialDiffusion",
display_name="Differential Diffusion",
category="_for_testing",
inputs=[
io.Model.Input("model"),
io.Float.Input(
"strength",
default=1.0,
min=0.0,
max=1.0,
step=0.01,
optional=True,
),
],
outputs=[io.Model.Output()],
is_experimental=True,
)
def apply(self, model, strength=1.0):
@classmethod
def execute(cls, model, strength=1.0) -> io.NodeOutput:
model = model.clone()
model.set_model_denoise_mask_function(lambda *args, **kwargs: self.forward(*args, **kwargs, strength=strength))
return (model, )
model.set_model_denoise_mask_function(lambda *args, **kwargs: cls.forward(*args, **kwargs, strength=strength))
return io.NodeOutput(model)
def forward(self, sigma: torch.Tensor, denoise_mask: torch.Tensor, extra_options: dict, strength: float):
@classmethod
def forward(cls, sigma: torch.Tensor, denoise_mask: torch.Tensor, extra_options: dict, strength: float):
model = extra_options["model"]
step_sigmas = extra_options["sigmas"]
sigma_to = model.inner_model.model_sampling.sigma_min
@@ -53,9 +60,13 @@ class DifferentialDiffusion():
return binary_mask
NODE_CLASS_MAPPINGS = {
"DifferentialDiffusion": DifferentialDiffusion,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"DifferentialDiffusion": "Differential Diffusion",
}
class DifferentialDiffusionExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
DifferentialDiffusion,
]
async def comfy_entrypoint() -> DifferentialDiffusionExtension:
return DifferentialDiffusionExtension()

View File

@@ -1,26 +1,38 @@
import node_helpers
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
class ReferenceLatent:
class ReferenceLatent(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {"conditioning": ("CONDITIONING", ),
},
"optional": {"latent": ("LATENT", ),}
}
def define_schema(cls):
return io.Schema(
node_id="ReferenceLatent",
category="advanced/conditioning/edit_models",
description="This node sets the guiding latent for an edit model. If the model supports it you can chain multiple to set multiple reference images.",
inputs=[
io.Conditioning.Input("conditioning"),
io.Latent.Input("latent", optional=True),
],
outputs=[
io.Conditioning.Output(),
]
)
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "append"
CATEGORY = "advanced/conditioning/edit_models"
DESCRIPTION = "This node sets the guiding latent for an edit model. If the model supports it you can chain multiple to set multiple reference images."
def append(self, conditioning, latent=None):
@classmethod
def execute(cls, conditioning, latent=None) -> io.NodeOutput:
if latent is not None:
conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": [latent["samples"]]}, append=True)
return (conditioning, )
return io.NodeOutput(conditioning)
NODE_CLASS_MAPPINGS = {
"ReferenceLatent": ReferenceLatent,
}
class EditModelExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
ReferenceLatent,
]
def comfy_entrypoint() -> EditModelExtension:
return EditModelExtension()

74
comfy_extras/nodes_eps.py Normal file
View File

@@ -0,0 +1,74 @@
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
class EpsilonScaling(io.ComfyNode):
"""
Implements the Epsilon Scaling method from 'Elucidating the Exposure Bias in Diffusion Models'
(https://arxiv.org/abs/2308.15321v6).
This method mitigates exposure bias by scaling the predicted noise during sampling,
which can significantly improve sample quality. This implementation uses the "uniform schedule"
recommended by the paper for its practicality and effectiveness.
"""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="Epsilon Scaling",
category="model_patches/unet",
inputs=[
io.Model.Input("model"),
io.Float.Input(
"scaling_factor",
default=1.005,
min=0.5,
max=1.5,
step=0.001,
display_mode=io.NumberDisplay.number,
),
],
outputs=[
io.Model.Output(),
],
)
@classmethod
def execute(cls, model, scaling_factor) -> io.NodeOutput:
# Prevent division by zero, though the UI's min value should prevent this.
if scaling_factor == 0:
scaling_factor = 1e-9
def epsilon_scaling_function(args):
"""
This function is applied after the CFG guidance has been calculated.
It recalculates the denoised latent by scaling the predicted noise.
"""
denoised = args["denoised"]
x = args["input"]
noise_pred = x - denoised
scaled_noise_pred = noise_pred / scaling_factor
new_denoised = x - scaled_noise_pred
return new_denoised
# Clone the model patcher to avoid modifying the original model in place
model_clone = model.clone()
model_clone.set_model_sampler_post_cfg_function(epsilon_scaling_function)
return io.NodeOutput(model_clone)
class EpsilonScalingExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
EpsilonScaling,
]
async def comfy_entrypoint() -> EpsilonScalingExtension:
return EpsilonScalingExtension()

View File

@@ -1,60 +1,80 @@
import node_helpers
import comfy.utils
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
class CLIPTextEncodeFlux:
class CLIPTextEncodeFlux(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {
"clip": ("CLIP", ),
"clip_l": ("STRING", {"multiline": True, "dynamicPrompts": True}),
"t5xxl": ("STRING", {"multiline": True, "dynamicPrompts": True}),
"guidance": ("FLOAT", {"default": 3.5, "min": 0.0, "max": 100.0, "step": 0.1}),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "encode"
def define_schema(cls):
return io.Schema(
node_id="CLIPTextEncodeFlux",
category="advanced/conditioning/flux",
inputs=[
io.Clip.Input("clip"),
io.String.Input("clip_l", multiline=True, dynamic_prompts=True),
io.String.Input("t5xxl", multiline=True, dynamic_prompts=True),
io.Float.Input("guidance", default=3.5, min=0.0, max=100.0, step=0.1),
],
outputs=[
io.Conditioning.Output(),
],
)
CATEGORY = "advanced/conditioning/flux"
def encode(self, clip, clip_l, t5xxl, guidance):
@classmethod
def execute(cls, clip, clip_l, t5xxl, guidance) -> io.NodeOutput:
tokens = clip.tokenize(clip_l)
tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"]
return (clip.encode_from_tokens_scheduled(tokens, add_dict={"guidance": guidance}), )
return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens, add_dict={"guidance": guidance}))
class FluxGuidance:
encode = execute # TODO: remove
class FluxGuidance(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {
"conditioning": ("CONDITIONING", ),
"guidance": ("FLOAT", {"default": 3.5, "min": 0.0, "max": 100.0, "step": 0.1}),
}}
def define_schema(cls):
return io.Schema(
node_id="FluxGuidance",
category="advanced/conditioning/flux",
inputs=[
io.Conditioning.Input("conditioning"),
io.Float.Input("guidance", default=3.5, min=0.0, max=100.0, step=0.1),
],
outputs=[
io.Conditioning.Output(),
],
)
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "append"
CATEGORY = "advanced/conditioning/flux"
def append(self, conditioning, guidance):
@classmethod
def execute(cls, conditioning, guidance) -> io.NodeOutput:
c = node_helpers.conditioning_set_values(conditioning, {"guidance": guidance})
return (c, )
return io.NodeOutput(c)
append = execute # TODO: remove
class FluxDisableGuidance:
class FluxDisableGuidance(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {
"conditioning": ("CONDITIONING", ),
}}
def define_schema(cls):
return io.Schema(
node_id="FluxDisableGuidance",
category="advanced/conditioning/flux",
description="This node completely disables the guidance embed on Flux and Flux like models",
inputs=[
io.Conditioning.Input("conditioning"),
],
outputs=[
io.Conditioning.Output(),
],
)
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "append"
CATEGORY = "advanced/conditioning/flux"
DESCRIPTION = "This node completely disables the guidance embed on Flux and Flux like models"
def append(self, conditioning):
@classmethod
def execute(cls, conditioning) -> io.NodeOutput:
c = node_helpers.conditioning_set_values(conditioning, {"guidance": None})
return (c, )
return io.NodeOutput(c)
append = execute # TODO: remove
PREFERED_KONTEXT_RESOLUTIONS = [
@@ -78,52 +98,73 @@ PREFERED_KONTEXT_RESOLUTIONS = [
]
class FluxKontextImageScale:
class FluxKontextImageScale(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {"image": ("IMAGE", ),
},
}
def define_schema(cls):
return io.Schema(
node_id="FluxKontextImageScale",
category="advanced/conditioning/flux",
description="This node resizes the image to one that is more optimal for flux kontext.",
inputs=[
io.Image.Input("image"),
],
outputs=[
io.Image.Output(),
],
)
RETURN_TYPES = ("IMAGE",)
FUNCTION = "scale"
CATEGORY = "advanced/conditioning/flux"
DESCRIPTION = "This node resizes the image to one that is more optimal for flux kontext."
def scale(self, image):
@classmethod
def execute(cls, image) -> io.NodeOutput:
width = image.shape[2]
height = image.shape[1]
aspect_ratio = width / height
_, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS)
image = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "lanczos", "center").movedim(1, -1)
return (image, )
return io.NodeOutput(image)
scale = execute # TODO: remove
class FluxKontextMultiReferenceLatentMethod:
class FluxKontextMultiReferenceLatentMethod(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {
"conditioning": ("CONDITIONING", ),
"reference_latents_method": (("offset", "index", "uxo/uno"), ),
}}
def define_schema(cls):
return io.Schema(
node_id="FluxKontextMultiReferenceLatentMethod",
category="advanced/conditioning/flux",
inputs=[
io.Conditioning.Input("conditioning"),
io.Combo.Input(
"reference_latents_method",
options=["offset", "index", "uxo/uno"],
),
],
outputs=[
io.Conditioning.Output(),
],
is_experimental=True,
)
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "append"
EXPERIMENTAL = True
CATEGORY = "advanced/conditioning/flux"
def append(self, conditioning, reference_latents_method):
@classmethod
def execute(cls, conditioning, reference_latents_method) -> io.NodeOutput:
if "uxo" in reference_latents_method or "uso" in reference_latents_method:
reference_latents_method = "uxo"
c = node_helpers.conditioning_set_values(conditioning, {"reference_latents_method": reference_latents_method})
return (c, )
return io.NodeOutput(c)
NODE_CLASS_MAPPINGS = {
"CLIPTextEncodeFlux": CLIPTextEncodeFlux,
"FluxGuidance": FluxGuidance,
"FluxDisableGuidance": FluxDisableGuidance,
"FluxKontextImageScale": FluxKontextImageScale,
"FluxKontextMultiReferenceLatentMethod": FluxKontextMultiReferenceLatentMethod,
}
append = execute # TODO: remove
class FluxExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
CLIPTextEncodeFlux,
FluxGuidance,
FluxDisableGuidance,
FluxKontextImageScale,
FluxKontextMultiReferenceLatentMethod,
]
async def comfy_entrypoint() -> FluxExtension:
return FluxExtension()

View File

@@ -1,6 +1,8 @@
# from https://github.com/zju-pi/diff-sampler/tree/main/gits-main
import numpy as np
import torch
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
def loglinear_interp(t_steps, num_steps):
"""
@@ -333,25 +335,28 @@ NOISE_LEVELS = {
],
}
class GITSScheduler:
class GITSScheduler(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required":
{"coeff": ("FLOAT", {"default": 1.20, "min": 0.80, "max": 1.50, "step": 0.05}),
"steps": ("INT", {"default": 10, "min": 2, "max": 1000}),
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "sampling/custom_sampling/schedulers"
def define_schema(cls):
return io.Schema(
node_id="GITSScheduler",
category="sampling/custom_sampling/schedulers",
inputs=[
io.Float.Input("coeff", default=1.20, min=0.80, max=1.50, step=0.05),
io.Int.Input("steps", default=10, min=2, max=1000),
io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01),
],
outputs=[
io.Sigmas.Output(),
],
)
FUNCTION = "get_sigmas"
def get_sigmas(self, coeff, steps, denoise):
@classmethod
def execute(cls, coeff, steps, denoise):
total_steps = steps
if denoise < 1.0:
if denoise <= 0.0:
return (torch.FloatTensor([]),)
return io.NodeOutput(torch.FloatTensor([]))
total_steps = round(steps * denoise)
if steps <= 20:
@@ -362,8 +367,16 @@ class GITSScheduler:
sigmas = sigmas[-(total_steps + 1):]
sigmas[-1] = 0
return (torch.FloatTensor(sigmas), )
return io.NodeOutput(torch.FloatTensor(sigmas))
NODE_CLASS_MAPPINGS = {
"GITSScheduler": GITSScheduler,
}
class GITSSchedulerExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
GITSScheduler,
]
async def comfy_entrypoint() -> GITSSchedulerExtension:
return GITSSchedulerExtension()

View File

@@ -1,55 +1,73 @@
from typing_extensions import override
import folder_paths
import comfy.sd
import comfy.model_management
from comfy_api.latest import ComfyExtension, io
class QuadrupleCLIPLoader:
class QuadrupleCLIPLoader(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ),
"clip_name2": (folder_paths.get_filename_list("text_encoders"), ),
"clip_name3": (folder_paths.get_filename_list("text_encoders"), ),
"clip_name4": (folder_paths.get_filename_list("text_encoders"), )
}}
RETURN_TYPES = ("CLIP",)
FUNCTION = "load_clip"
def define_schema(cls):
return io.Schema(
node_id="QuadrupleCLIPLoader",
category="advanced/loaders",
description="[Recipes]\n\nhidream: long clip-l, long clip-g, t5xxl, llama_8b_3.1_instruct",
inputs=[
io.Combo.Input("clip_name1", options=folder_paths.get_filename_list("text_encoders")),
io.Combo.Input("clip_name2", options=folder_paths.get_filename_list("text_encoders")),
io.Combo.Input("clip_name3", options=folder_paths.get_filename_list("text_encoders")),
io.Combo.Input("clip_name4", options=folder_paths.get_filename_list("text_encoders")),
],
outputs=[
io.Clip.Output(),
]
)
CATEGORY = "advanced/loaders"
DESCRIPTION = "[Recipes]\n\nhidream: long clip-l, long clip-g, t5xxl, llama_8b_3.1_instruct"
def load_clip(self, clip_name1, clip_name2, clip_name3, clip_name4):
@classmethod
def execute(cls, clip_name1, clip_name2, clip_name3, clip_name4):
clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1)
clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
clip_path3 = folder_paths.get_full_path_or_raise("text_encoders", clip_name3)
clip_path4 = folder_paths.get_full_path_or_raise("text_encoders", clip_name4)
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3, clip_path4], embedding_directory=folder_paths.get_folder_paths("embeddings"))
return (clip,)
return io.NodeOutput(clip)
class CLIPTextEncodeHiDream:
class CLIPTextEncodeHiDream(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {
"clip": ("CLIP", ),
"clip_l": ("STRING", {"multiline": True, "dynamicPrompts": True}),
"clip_g": ("STRING", {"multiline": True, "dynamicPrompts": True}),
"t5xxl": ("STRING", {"multiline": True, "dynamicPrompts": True}),
"llama": ("STRING", {"multiline": True, "dynamicPrompts": True})
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "encode"
CATEGORY = "advanced/conditioning"
def encode(self, clip, clip_l, clip_g, t5xxl, llama):
def define_schema(cls):
return io.Schema(
node_id="CLIPTextEncodeHiDream",
category="advanced/conditioning",
inputs=[
io.Clip.Input("clip"),
io.String.Input("clip_l", multiline=True, dynamic_prompts=True),
io.String.Input("clip_g", multiline=True, dynamic_prompts=True),
io.String.Input("t5xxl", multiline=True, dynamic_prompts=True),
io.String.Input("llama", multiline=True, dynamic_prompts=True),
],
outputs=[
io.Conditioning.Output(),
]
)
@classmethod
def execute(cls, clip, clip_l, clip_g, t5xxl, llama):
tokens = clip.tokenize(clip_g)
tokens["l"] = clip.tokenize(clip_l)["l"]
tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"]
tokens["llama"] = clip.tokenize(llama)["llama"]
return (clip.encode_from_tokens_scheduled(tokens), )
return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))
NODE_CLASS_MAPPINGS = {
"QuadrupleCLIPLoader": QuadrupleCLIPLoader,
"CLIPTextEncodeHiDream": CLIPTextEncodeHiDream,
}
class HiDreamExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
QuadrupleCLIPLoader,
CLIPTextEncodeHiDream,
]
async def comfy_entrypoint() -> HiDreamExtension:
return HiDreamExtension()

Some files were not shown because too many files have changed in this diff Show More