Compare commits

...

246 Commits

Author SHA1 Message Date
comfyanonymous
59afc39848 ComfyUI v0.3.77 2025-12-03 00:02:09 -05:00
Alexander Piskun
028e17dd7a add check for the format arg type in VideoFromComponents.save_to function (#11046)
* add check for the format var type in VideoFromComponents.save_to function

* convert "format" to VideoContainer enum
2025-12-03 00:01:24 -05:00
comfyanonymous
30c259cac8 ComfyUI version v0.3.76 2025-12-01 20:25:35 -05:00
Alexander Piskun
1cb7e22a95 [API Nodes] add Kling O1 model support (#11025)
* feat(api-nodes): add Kling O1 model support

* fix: increase max allowed duration to 10.05 seconds

* fix(VideoInput): respect "format" argument
2025-12-01 16:11:52 -08:00
comfyanonymous
2640acb31c Update qwen tokenizer to add qwen 3 tokens. (#11029)
Doesn't actually change anything for current workflows because none of the
current models have a template with the think tokens.
2025-12-01 17:13:48 -05:00
Christian Byrne
7dbd5dfe91 bump comfyui-frontend-package to 1.32.10 (#11018) 2025-12-01 13:27:17 -05:00
comfyanonymous
f8b981ae9a Next AMD portable will have pytorch with ROCm 7.1.1 (#11002) 2025-11-30 04:21:31 -05:00
ComfyUI Wiki
4967f81778 update template to 0.7.25 (#10996)
* update template to 0.7.24

* Update template to 0.7.25
2025-11-29 18:07:26 -08:00
comfyanonymous
0a6746898d Make the ScaleRope node work on Z Image and Lumina. (#10994) 2025-11-29 18:00:55 -05:00
comfyanonymous
5151cff293 Add some missing z image lora layers. (#10980) 2025-11-28 23:55:00 -05:00
Dr.Lt.Data
af96d9812d feat(security): add System User protection with __ prefix (#10966)
* feat(security): add System User protection with `__` prefix

Add protected namespace for custom nodes to store sensitive data
(API keys, licenses) that cannot be accessed via HTTP endpoints.

Key changes:
- New API: get_system_user_directory() for internal access
- New API: get_public_user_directory() with structural blocking
- 3-layer defense: header validation, path blocking, creation prevention
- 54 tests covering security, edge cases, and backward compatibility

System Users use `__` prefix (e.g., __system, __cache) following
Python's private member convention. They exist in user_directory/
but are completely blocked from /userdata HTTP endpoints.

* style: remove unused imports
2025-11-28 21:28:42 -05:00
comfyanonymous
52a32e2b32 Support some z image lora formats. (#10978) 2025-11-28 21:12:42 -05:00
Jukka Seppänen
b907085709 Support video tiny VAEs (#10884)
* Support video tiny VAEs

* lighttaew scaling fix

* Also support video taes in previews

Only first frame for now as live preview playback is currently only available through VHS custom nodes.

* Support Wan 2.1 lightVAE

* Relocate elif block and set Wan VAE dim directly without using pruning rate for lightvae
2025-11-28 19:40:19 -05:00
comfyanonymous
065a2fbbec Update driver link in AMD portable README (#10974) 2025-11-28 19:37:39 -05:00
rattus
0ff0457892 mm: wrap the raw stream in context manager (#10958)
The documentation of torch.foo.Stream being usable with with: suggests
it starts at version 2.7. Use the old API for backwards compatibility.
2025-11-28 16:38:12 -05:00
Urle Sistiana
6484ac89dc fix QuantizedTensor.is_contiguous (#10956) (#10959) 2025-11-28 16:33:07 -05:00
comfyanonymous
f55c98a89f Disable offload stream when torch compile. (#10961) 2025-11-28 16:16:46 -05:00
Dr.Lt.Data
ca7808f240 fix(user_manager): fix typo in move_userdata dest validation (#10967)
Check `dest` instead of `source` when validating destination path
in move_userdata endpoint.
2025-11-28 12:43:17 -08:00
Alexander Piskun
52e778fff3 feat(Kling-API-Nodes): add v2-5-turbo model to FirstLastFrame node (#10938) 2025-11-28 02:52:59 -08:00
comfyanonymous
9d8a817985 Enable async offloading by default on Nvidia. (#10953)
Add --disable-async-offload to disable it.

If this causes OOMs that go away when you --disable-async-offload please
report it.
2025-11-27 17:46:12 -05:00
ComfyUI Wiki
b59750a86a Update template to 0.7.23 (#10949) 2025-11-27 17:12:56 -05:00
rattus
3f382a4f98 quant ops: Dequantize weight in-place (#10935)
In flux2 these weights are huge (200MB). As plain_tensor is a throw-away
deep copy, do this multiplication in-place to save VRAM.
2025-11-27 08:06:30 -08:00
rattus
f17251bec6 Account for the VRAM cost of weight offloading (#10733)
* mm: default to 0 for NUM_STREAMS

Dont count the compute stream as an offload stream. This makes async
offload accounting easier.

* mm: remove 128MB minimum

This is from a previous offloading system requirement. Remove it to
make behaviour of the loader and partial unloader consistent.

* mp: order the module list by offload expense

Calculate an approximate offloading temporary VRAM cost to offload a
weight and primary order the module load list by that. In the simple
case this is just the same as the module weight, but with Loras, a
weight with a lora consumes considerably more VRAM to do the Lora
application on-the-fly.

This will slightly prioritize lora weights, but is really for
proper VRAM offload accounting.

* mp: Account for the VRAM cost of weight offloading

when checking the VRAM headroom, assume that the weight needs to be
offloaded, and only load if it has space for both the load and offload
 * the number of streams.

As the weights are ordered from largest to smallest by offload cost
this is guaranteed to fit in VRAM (tm), as all weights that follow
will be smaller.

Make the partial unload aware of this system as well by saving the
budget for offload VRAM to the model state and accounting accordingly.
Its possible that partial unload increases the size of the largest
offloaded weights, and thus needs to unload a little bit more than
asked to accomodate the bigger temp buffers.

Honor the existing codes floor on model weight loading of 128MB by
having the patcher honor this separately withough regard to offloading.
Otherwise when MM specifies its 128MB minimum, MP will see the biggest
weights, and budget that 128MB to only offload buffer and load nothing
which isnt the intent of these minimums. The same clamp applies in
case of partial offload of the currently loading model.
2025-11-27 01:03:03 -05:00
Haoming
c38e7d6599 block info (#10841) 2025-11-26 20:28:44 -08:00
comfyanonymous
eaf68c9b5b Make lora training work on Z Image and remove some redundant nodes. (#10927) 2025-11-26 19:25:32 -05:00
Kohaku-Blueleaf
cc6a8dcd1a Dataset Processing Nodes and Improved LoRA Trainer Nodes with multi resolution supports. (#10708)
* Create nodes_dataset.py

* Add encoded dataset caching mechanism

* make training node to work with our dataset system

* allow trainer node to get different resolution dataset

* move all dataset related implementation to nodes_dataset

* Rewrite dataset system with new io schema

* Rewrite training system with new io schema

* add ui pbar

* Add outputs' id/name

* Fix bad id/naming

* use single process instead of input list when no need

* fix wrong output_list flag

* use torch.load/save and fix bad behaviors
2025-11-26 19:18:08 -05:00
Alexander Piskun
a2d60aad0f convert nodes_customer_sampler.py to V3 schema (#10206) 2025-11-26 14:55:31 -08:00
Alexander Piskun
d8433c63fd chore(api-nodes): remove chat widgets from OpenAI/Gemini nodes (#10861) 2025-11-26 14:42:01 -08:00
comfyanonymous
dd41b74549 Add Z Image to readme. (#10924) 2025-11-26 15:36:38 -05:00
comfyanonymous
55f654db3d Fix the CSP offline feature. (#10923) 2025-11-26 15:16:40 -05:00
Terry Jia
58c6ed541d Merge 3d animation node (#10025) 2025-11-26 14:58:27 -05:00
Christian Byrne
234c3dc85f Bump frontend to 1.32.9 (#10867) 2025-11-26 14:58:08 -05:00
Alexander Piskun
8908ee2628 fix(gemini): use first 10 images as fileData (URLs) and remaining images as inline base64 (#10918) 2025-11-26 10:38:30 -08:00
Alexander Piskun
1105e0d139 improve UX for batch uploads in upload_images_to_comfyapi (#10913) 2025-11-26 09:23:14 -08:00
Alexander Piskun
8938aa3f30 add Veo3 First-Last-Frame node (#10878) 2025-11-26 09:14:02 -08:00
comfyanonymous
f16219e3aa Add cheap latent preview for flux 2. (#10907)
Thank you to the person who calculated them. You saved me a percent of my
time.
2025-11-26 04:00:43 -05:00
comfyanonymous
8402c8700a ComfyUI version v0.3.75 2025-11-26 02:41:13 -05:00
comfyanonymous
58b8574661 Fix Flux2 reference image mem estimation. (#10905) 2025-11-26 02:36:19 -05:00
comfyanonymous
90b3995ec8 ComfyUI v0.3.74 2025-11-26 00:34:15 -05:00
comfyanonymous
bdb10a583f Fix loras not working on mixed fp8. (#10899) 2025-11-26 00:07:58 -05:00
comfyanonymous
0e24dbb19f Adjustments to Z Image. (#10893) 2025-11-25 19:02:51 -05:00
comfyanonymous
e9aae31fa2 Z Image model. (#10892) 2025-11-25 18:41:45 -05:00
comfyanonymous
0c18842acb ComfyUI v0.3.73 2025-11-25 14:59:37 -05:00
comfyanonymous
d196a905bb Lower vram usage for flux 2 text encoder. (#10887) 2025-11-25 14:58:39 -05:00
ComfyUI Wiki
18b79acba9 Update workflow templates to v0.7.20 (#10883) 2025-11-25 14:58:21 -05:00
comfyanonymous
dff996ca39 Fix crash. (#10885) 2025-11-25 14:30:24 -05:00
comfyanonymous
828b1b9953 ComfyUI version v0.3.72 2025-11-25 12:40:58 -05:00
comfyanonymous
af81cb962d Add Flux 2 support to README. (#10882) 2025-11-25 11:40:32 -05:00
Alexander Piskun
5c7b08ca58 [API Nodes] add Flux.2 Pro node (#10880) 2025-11-25 11:09:07 -05:00
comfyanonymous
6b573ae0cb Flux 2 (#10879) 2025-11-25 10:50:19 -05:00
comfyanonymous
015a0599d0 I found a case where this is needed (#10875) 2025-11-25 03:23:19 -05:00
comfyanonymous
acfaa5c4a1 Don't try fp8 matrix mult in quantized ops if not supported by hardware. (#10874) 2025-11-25 02:55:49 -05:00
comfyanonymous
b6805429b9 Allow pinning quantized tensors. (#10873) 2025-11-25 02:48:20 -05:00
comfyanonymous
25022e0b09 Cleanup and fix issues with text encoder quants. (#10872) 2025-11-25 01:48:53 -05:00
comfyanonymous
22a2644e57 Bump transformers version in requirements.txt (#10869) 2025-11-24 19:45:54 -05:00
Haoming
b2ef58e2b1 block info (#10844) 2025-11-24 10:40:09 -08:00
Haoming
6a6d456c88 block info (#10842) 2025-11-24 10:38:38 -08:00
Haoming
3d1fdaf9f4 block info (#10843) 2025-11-24 10:30:40 -08:00
Alexander Piskun
1286fcfe40 add get_frame_count and get_frame_rate methods to VideoInput class (#10851) 2025-11-24 10:24:29 -08:00
Alexander Piskun
3bd71554a2 fix(api-nodes): edge cases in responses for Gemini models (#10860) 2025-11-24 09:48:37 -08:00
guill
f66183a541 [fix] Fixes non-async public API access (#10857)
It looks like the synchronous version of the public API broke due to an
addition of `from __future__ import annotations`. This change updates
the async-to-sync adapter to work with both types of type annotations.
2025-11-23 22:56:20 -08:00
comfyanonymous
cbd68e3d58 Add better error message for common error. (#10846) 2025-11-23 04:55:22 -05:00
comfyanonymous
d89c29f259 Add display names to Hunyuan latent video nodes. (#10837) 2025-11-22 22:51:53 -05:00
Christian Byrne
a9c35256bc Update requirements.txt (#10834) 2025-11-22 02:28:29 -08:00
comfyanonymous
532938b16b --disable-api-nodes now sets CSP header to force frontend offline. (#10829) 2025-11-21 17:51:55 -05:00
Christian Byrne
ecb683b057 update frontend to 1.30 (#10793) 2025-11-21 16:34:47 -05:00
comfyanonymous
c55fd74816 ComfyUI 0.3.71 2025-11-21 00:49:13 -05:00
comfyanonymous
3398123752 Fix wrong path. (#10821) 2025-11-20 23:39:37 -05:00
comfyanonymous
943b3b615d HunyuanVideo 1.5 (#10819)
* init

* update

* Update model.py

* Update model.py

* remove print

* Fix text encoding

* Prevent empty negative prompt

Really doesn't work otherwise

* fp16 works

* I2V

* Update model_base.py

* Update nodes_hunyuan.py

* Better latent rgb factors

* Use the correct sigclip output...

* Support HunyuanVideo1.5 SR model

* whitespaces...

* Proper latent channel count

* SR model fixes

This also still needs timesteps scheduling based on the noise scale, can be used with two samplers too already

* vae_refiner: roll the convolution through temporal

Work in progress.

Roll the convolution through time using 2-latent-frame chunks and a
FIFO queue for the convolution seams.

* Support HunyuanVideo15 latent resampler

* fix

* Some cleanup

Co-Authored-By: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>

* Proper hyvid15 I2V channels

Co-Authored-By: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>

* Fix TokenRefiner for fp16

Otherwise x.sum has infs, just in case only casting if input is fp16, I don't know if necessary.

* Bugfix for the HunyuanVideo15 SR model

* vae_refiner: roll the convolution through temporal II

Roll the convolution through time using 2-latent-frame chunks and a
FIFO queue for the convolution seams.

Added support for encoder, lowered to 1 latent frame to save more
VRAM, made work for Hunyuan Image 3.0 (as code shared).

Fixed names, cleaned up code.

* Allow any number of input frames in VAE.

* Better VAE encode mem estimation.

* Lowvram fix.

* Fix hunyuan image 2.1 refiner.

* Fix mistake.

* Name changes.

* Rename.

* Whitespace.

* Fix.

* Fix.

---------

Co-authored-by: kijai <40791699+kijai@users.noreply.github.com>
Co-authored-by: Rattus <rattus128@gmail.com>
2025-11-20 22:44:43 -05:00
Christian Byrne
10e90a5757 bump comfyui-workflow-templates for nano banana 2 (#10818)
* bump templates

* bump templates
2025-11-20 18:20:52 -08:00
Alexander Piskun
b75d349f25 fix(KlingLipSyncAudioToVideoNode): convert audio to mp3 format (#10811) 2025-11-20 16:33:54 -08:00
Alexander Piskun
7b8389578e feat(api-nodes): add Nano Banana Pro (#10814)
* feat(api-nodes): add Nano Banana Pro

* frontend bump to 1.28.9
2025-11-20 16:17:47 -08:00
Jedrzej Kosinski
9e00ce5b76 Make Batch Images node add alpha channel when one of the inputs has it (#10816)
* When one Batch Image input has alpha and one does not, add empty alpha channel

* Use torch.nn.functional.pad
2025-11-20 17:42:46 -05:00
comfyanonymous
f5e66d5e47 Fix ImageBatch with different channel count. (#10815) 2025-11-20 15:08:03 -05:00
Christian Byrne
87b0359392 Update server templates handler to use new multi-package distribution (comfyui-workflow-templates versions >=0.3) (#10791)
* update templates for monorepo

* refactor
2025-11-19 22:36:56 -08:00
comfyanonymous
cb96d4d18c Disable workaround on newer cudnn. (#10807) 2025-11-19 23:56:23 -05:00
Alexander Piskun
394348f5ca feat(api-nodes): add Topaz API nodes (#10755) 2025-11-19 17:44:04 -08:00
comfyanonymous
7601e89255 Fix workflow name. (#10806) 2025-11-19 20:17:15 -05:00
Alexander Piskun
6a1d3a1ae1 convert hunyuan3d.py to V3 schema (#10664) 2025-11-19 14:49:01 -08:00
Alexander Piskun
65ee24c978 change display name of PreviewAny node to "Preview as Text" (#10796) 2025-11-19 01:25:28 -08:00
comfyanonymous
17027f2a6a Add a way to disable the final norm in the llama based TE models. (#10794) 2025-11-18 22:36:03 -05:00
comfyanonymous
b5c8be8b1d ComfyUI 0.3.70 2025-11-18 19:37:20 -05:00
Alexander Piskun
24fdb92edf feat(api-nodes): add new Gemini model (#10789) 2025-11-18 14:26:44 -08:00
comfyanonymous
d526974576 Fix hunyuan 3d 2.0 (#10792) 2025-11-18 16:46:19 -05:00
Jukka Seppänen
e1ab6bb394 EasyCache: Fix for mismatch in input/output channels with some models (#10788)
Slices model input with output channels so the caching tracks only the noise channels, resolves channel mismatch with models like WanVideo I2V

Also fix for slicing deprecation in pytorch 2.9
2025-11-18 07:00:21 -08:00
Alexander Piskun
048f49adbd chore(api-nodes): adjusted PR template; set min python version for pylint to 3.10 (#10787) 2025-11-18 03:59:27 -08:00
comfyanonymous
47bfd5a33f Native block swap custom nodes considered harmful. (#10783) 2025-11-18 00:26:44 -05:00
ComfyUI Wiki
fdf49a2861 Fix the portable download link for CUDA 12.6 (#10780) 2025-11-17 22:04:06 -05:00
comfyanonymous
f41e5f398d Update README with new portable download link (#10778) 2025-11-17 19:59:19 -05:00
comfyanonymous
27cbac865e Add release workflow for NVIDIA cu126 (#10777) 2025-11-17 19:04:04 -05:00
comfyanonymous
3d0003c24c ComfyUI version 0.3.69 2025-11-17 17:17:24 -05:00
comfyanonymous
7d6103325e Change ROCm nightly install command to 7.1 (#10764) 2025-11-16 03:01:14 -05:00
Alexander Piskun
2d4a08b717 Revert "chore(api-nodes): mark OpenAIDalle2 and OpenAIDalle3 nodes as deprecated (#10757)" (#10759)
This reverts commit 9a02382568.
2025-11-15 12:37:34 -08:00
Alexander Piskun
9a02382568 chore(api-nodes): mark OpenAIDalle2 and OpenAIDalle3 nodes as deprecated (#10757) 2025-11-15 11:18:49 -08:00
comfyanonymous
bd01d9f7fd Add left padding support to tokenizers. (#10753) 2025-11-15 06:54:40 -05:00
comfyanonymous
443056c401 Fix custom nodes import error. (#10747)
This should fix the import errors but will break if the custom nodes actually try to use the class.
2025-11-14 03:26:05 -05:00
comfyanonymous
f60923590c Use same code for chroma and flux blocks so that optimizations are shared. (#10746) 2025-11-14 01:28:05 -05:00
comfyanonymous
1ef328c007 Better instructions for the portable. (#10743) 2025-11-13 21:32:39 -05:00
rattus
94c298f962 flux: reduce VRAM usage (#10737)
Cleanup a bunch of stack tensors on Flux. This take me from B=19 to B=22
for 1600x1600 on RTX5090.
2025-11-13 16:02:03 -08:00
ric-yu
2fde9597f4 feat: add create_time dict to prompt field in /history and /queue (#10741) 2025-11-13 15:11:52 -08:00
Alexander Piskun
f91078b1ff add PR template for API-Nodes (#10736) 2025-11-13 10:05:26 -08:00
contentis
3b3ef9a77a Quantized Ops fixes (#10715)
* offload support, bug fixes, remove mixins

* add readme
2025-11-12 18:26:52 -05:00
comfyanonymous
8b0b93df51 Update Python 3.14 compatibility notes in README (#10730) 2025-11-12 17:04:41 -05:00
rattus
1c7eaeca10 qwen: reduce VRAM usage (#10725)
Clean up a bunch of stacked and no-longer-needed tensors on the QWEN
VRAM peak (currently FFN).

With this I go from OOMing at B=37x1328x1328 to being able to
succesfully run B=47 (RTX5090).
2025-11-12 16:20:53 -05:00
rattus
18e7d6dba5 mm/mp: always unload re-used but modified models (#10724)
The partial unloader path in model re-use flow skips straight to the
actual unload without any check of the patching UUID. This means that
if you do an upscale flow with a model patch on an existing model, it
will not apply your patchings.

Fix by delaying the partial_unload until after the uuid checks. This
is done by making partial_unload a model of partial_load where extra_mem
is -ve.
2025-11-12 16:19:53 -05:00
Qiacheng Li
e1d85e7577 Update README.md for Intel Arc GPU installation, remove IPEX (#10729)
IPEX is no longer needed for Intel Arc GPUs.  Removing instruction to setup ipex.
2025-11-12 15:21:05 -05:00
comfyanonymous
1199411747 Don't pin tensor if not a torch.nn.parameter.Parameter (#10718) 2025-11-11 19:33:30 -05:00
comfyanonymous
5ebcab3c7d Update CI workflow to remove dead macOS runner. (#10704)
* Update CI workflow to remove dead macOS runner.

* revert

* revert
2025-11-10 15:35:29 -05:00
rattus
c350009236 ops: Put weight cast on the offload stream (#10697)
This needs to be on the offload stream. This reproduced a black screen
with low resolution images on a slow bus when using FP8.
2025-11-09 22:52:11 -05:00
comfyanonymous
dea899f221 Unload weights if vram usage goes up between runs. (#10690) 2025-11-09 18:51:33 -05:00
comfyanonymous
e632e5de28 Add logging for model unloading. (#10692) 2025-11-09 18:06:39 -05:00
comfyanonymous
2abd2b5c20 Make ScaleROPE node work on Flux. (#10686) 2025-11-08 15:52:02 -05:00
comfyanonymous
a1a70362ca Only unpin tensor if it was pinned by ComfyUI (#10677) 2025-11-07 11:15:05 -05:00
rattus
cf97b033ee mm: guard against double pin and unpin explicitly (#10672)
As commented, if you let cuda be the one to detect double pin/unpinning
it actually creates an asyc GPU error.
2025-11-06 21:20:48 -05:00
comfyanonymous
eb1c42f649 Tell users they need to upload their logs in bug reports. (#10671) 2025-11-06 20:24:28 -05:00
comfyanonymous
e05c907126 Clarify release cycle. (#10667) 2025-11-06 04:11:30 -05:00
comfyanonymous
09dc24c8a9 Pinned mem also seems to work on AMD. (#10658) 2025-11-05 19:11:15 -05:00
comfyanonymous
1d69245981 Enable pinned memory by default on Nvidia. (#10656)
Removed the --fast pinned_memory flag.

You can use --disable-pinned-memory to disable it. Please report if it
causes any issues.
2025-11-05 18:08:13 -05:00
comfyanonymous
97f198e421 Fix qwen controlnet regression. (#10657) 2025-11-05 18:07:35 -05:00
Alexander Piskun
bda0eb2448 feat(API-nodes): move Rodin3D nodes to new client; removed old api client.py (#10645) 2025-11-05 02:16:00 -08:00
comfyanonymous
c4a6b389de Lower ltxv mem usage to what it was before previous pr. (#10643)
Bring back qwen behavior to what it was before previous pr.
2025-11-04 22:47:35 -05:00
contentis
4cd881866b Use single apply_rope function across models (#10547) 2025-11-04 20:10:11 -05:00
comfyanonymous
265adad858 ComfyUI version v0.3.68 2025-11-04 19:42:23 -05:00
comfyanonymous
7f3e4d486c Limit amount of pinned memory on windows to prevent issues. (#10638) 2025-11-04 17:37:50 -05:00
rattus
a389ee01bb caching: Handle None outputs tuple case (#10637) 2025-11-04 14:14:10 -08:00
ComfyUI Wiki
9c71a66790 chore: update workflow templates to v0.2.11 (#10634) 2025-11-04 10:51:53 -08:00
comfyanonymous
af4b7b5edb More fp8 torch.compile regressions fixed. (#10625) 2025-11-03 22:14:20 -05:00
comfyanonymous
0f4ef3afa0 This seems to slow things down slightly on Linux. (#10624) 2025-11-03 21:47:14 -05:00
comfyanonymous
6b88478f9f Bring back fp8 torch compile performance to what it should be. (#10622) 2025-11-03 19:22:10 -05:00
comfyanonymous
e199c8cc67 Fixes (#10621) 2025-11-03 17:58:24 -05:00
comfyanonymous
0652cb8e2d Speed up torch.compile (#10620) 2025-11-03 17:37:12 -05:00
comfyanonymous
958a17199a People should update their pytorch versions. (#10618) 2025-11-03 17:08:30 -05:00
ComfyUI Wiki
e974e554ca chore: update embedded docs to v0.3.1 (#10614) 2025-11-03 10:59:44 -08:00
Alexander Piskun
4e2110c794 feat(Pika-API-nodes): use new API client (#10608) 2025-11-03 00:29:08 -08:00
Alexander Piskun
e617cddf24 convert nodes_openai.py to V3 schema (#10604) 2025-11-03 00:28:13 -08:00
Alexander Piskun
1f3f7a2823 convert nodes_hypernetwork.py to V3 schema (#10583) 2025-11-03 00:21:47 -08:00
EverNebula
88df172790 fix(caching): treat bytes as hashable (#10567) 2025-11-03 00:16:40 -08:00
Alexander Piskun
6d6a18b0b7 fix(api-nodes-cloud): stop using sub-folder and absolute path for output of Rodin3D nodes (#10556) 2025-11-03 00:04:56 -08:00
comfyanonymous
97ff9fae7e Clarify help text for --fast argument (#10609)
Updated help text for the --fast argument to clarify potential risks.
2025-11-02 13:14:04 -05:00
rattus
135fa49ec2 Small speed improvements to --async-offload (#10593)
* ops: dont take an offload stream if you dont need one

* ops: prioritize mem transfer

The async offload streams reason for existence is to transfer from
RAM to GPU. The post processing compute steps are a bonus on the side
stream, but if the compute stream is running a long kernel, it can
stall the side stream, as it wait to type-cast the bias before
transferring the weight. So do a pure xfer of the weight straight up,
then do everything bias, then go back to fix the weight type and do
weight patches.
2025-11-01 18:48:53 -04:00
comfyanonymous
44869ff786 Fix issue with pinned memory. (#10597) 2025-11-01 17:25:59 -04:00
Alexander Piskun
20182a393f convert StabilityAI to use new API client (#10582) 2025-11-01 12:14:06 -07:00
Alexander Piskun
5f109fe6a0 added 12s-20s as available output durations for the LTXV API nodes (#10570) 2025-11-01 12:13:39 -07:00
comfyanonymous
c58c13b2ba Fix torch compile regression on fp8 ops. (#10580) 2025-11-01 00:25:17 -04:00
comfyanonymous
7f374e42c8 ScaleROPE now works on Lumina models. (#10578) 2025-10-31 15:41:40 -04:00
comfyanonymous
27d1bd8829 Fix rope scaling. (#10560) 2025-10-30 22:51:58 -04:00
comfyanonymous
614cf9805e Add a ScaleROPE node. Currently only works on WAN models. (#10559) 2025-10-30 22:11:38 -04:00
rattus
513b0c46fb Add RAM Pressure cache mode (#10454)
* execution: Roll the UI cache into the outputs

Currently the UI cache is parallel to the output cache with
expectations of being a content superset of the output cache.
At the same time the UI and output cache are maintained completely
seperately, making it awkward to free the output cache content without
changing the behaviour of the UI cache.

There are two actual users (getters) of the UI cache. The first is
the case of a direct content hit on the output cache when executing a
node. This case is very naturally handled by merging the UI and outputs
cache.

The second case is the history JSON generation at the end of the prompt.
This currently works by asking the cache for all_node_ids and then
pulling the cache contents for those nodes. all_node_ids is the nodes
of the dynamic prompt.

So fold the UI cache into the output cache. The current UI cache setter
now writes to a prompt-scope dict. When the output cache is set, just
get this value from the dict and tuple up with the outputs.

When generating the history, simply iterate prompt-scope dict.

This prepares support for more complex caching strategies (like RAM
pressure caching) where less than 1 workflow will be cached and it
will be desirable to keep the UI cache and output cache in sync.

* sd: Implement RAM getter for VAE

* model_patcher: Implement RAM getter for ModelPatcher

* sd: Implement RAM getter for CLIP

* Implement RAM Pressure cache

Implement a cache sensitive to RAM pressure. When RAM headroom drops
down below a certain threshold, evict RAM-expensive nodes from the
cache.

Models and tensors are measured directly for RAM usage. An OOM score
is then computed based on the RAM usage of the node.

Note the due to indirection through shared objects (like a model
patcher), multiple nodes can account the same RAM as their individual
usage. The intent is this will free chains of nodes particularly
model loaders and associate loras as they all score similar and are
sorted in close to each other.

Has a bias towards unloading model nodes mid flow while being able
to keep results like text encodings and VAE.

* execution: Convert the cache entry to NamedTuple

As commented in review.

Convert this to a named tuple and abstract away the tuple type
completely from graph.py.
2025-10-30 17:39:02 -04:00
Alexander Piskun
dfac94695b fix img2img operation in Dall2 node (#10552) 2025-10-30 10:22:35 -07:00
Alexander Piskun
163b629c70 use new API client in Pixverse and Ideogram nodes (#10543) 2025-10-29 23:49:03 -07:00
Jedrzej Kosinski
998bf60beb Add units/info for the numbers displayed on 'load completely' and 'load partially' log messages (#10538) 2025-10-29 19:37:06 -04:00
comfyanonymous
906c089957 Fix small performance regression with fp8 fast and scaled fp8. (#10537) 2025-10-29 19:29:01 -04:00
comfyanonymous
25de7b1bfa Try to fix slow load issue on low ram hardware with pinned mem. (#10536) 2025-10-29 17:20:27 -04:00
rattus
ab7ab5be23 Fix Race condition in --async-offload that can cause corruption (#10501)
* mm: factor out the current stream getter

Make this a reusable function.

* ops: sync the offload stream with the consumption of w&b

This sync is nessacary as pytorch will queue cuda async frees on the
same stream as created to tensor. In the case of async offload, this
will be on the offload stream.

Weights and biases can go out of scope in python which then
triggers the pytorch garbage collector to queue the free operation on
the offload stream possible before the compute stream has used the
weight. This causes a use after free on weight data leading to total
corruption of some workflows.

So sync the offload stream with the compute stream after the weight
has been used so the free has to wait for the weight to be used.

The cast_bias_weight is extended in a backwards compatible way with
the new behaviour opt-in on a defaulted parameter. This handles
custom node packs calling cast_bias_weight and defeatures
async-offload for them (as they do not handle the race).

The pattern is now:

cast_bias_weight(... , offloadable=True) #This might be offloaded
thing(weight, bias, ...)
uncast_bias_weight(...)

* controlnet: adopt new cast_bias_weight synchronization scheme

This is nessacary for safe async weight offloading.

* mm: sync the last stream in the queue, not the next

Currently this peeks ahead to sync the next stream in the queue of
streams with the compute stream. This doesnt allow a lot of
parallelization, as then end result is you can only get one weight load
ahead regardless of how many streams you have.

Rotate the loop logic here to synchronize the end of the queue before
returning the next stream. This allows weights to be loaded ahead of the
compute streams position.
2025-10-29 17:17:46 -04:00
comfyanonymous
ec4fc2a09a Fix case of weights not being unpinned. (#10533) 2025-10-29 15:48:06 -04:00
comfyanonymous
1a58087ac2 Reduce memory usage for fp8 scaled op. (#10531) 2025-10-29 15:43:51 -04:00
Alexander Piskun
6c14f3afac use new API client in Luma and Minimax nodes (#10528) 2025-10-29 11:14:56 -07:00
comfyanonymous
e525673f72 Fix issue. (#10527) 2025-10-29 00:37:00 -04:00
comfyanonymous
3fa7a5c04a Speed up offloading using pinned memory. (#10526)
To enable this feature use: --fast pinned_memory
2025-10-29 00:21:01 -04:00
Alexander Piskun
210f7a1ba5 convert nodes_recraft.py to V3 schema (#10507) 2025-10-28 14:38:05 -07:00
rattus
d202c2ba74 execution: Allow a subgraph nodes to execute multiple times (#10499)
In the case of --cache-none lazy and subgraph execution can cause
anything to be run multiple times per workflow. If that rerun nodes is
in itself a subgraph generator, this will crash for two reasons.

pending_subgraph_results[] does not cleanup entries after their use.
So when a pending_subgraph_result is consumed, remove it from the list
so that if the corresponding node is fully re-executed this misses
lookup and it fall through to execute the node as it should.

Secondly, theres is an explicit enforcement against dups in the
addition of subgraphs nodes as ephemerals to the dymprompt. Remove this
enforcement as the use case is now valid.
2025-10-28 16:22:08 -04:00
contentis
8817f8fc14 Mixed Precision Quantization System (#10498)
* Implement mixed precision operations with a registry design and metadate for quant spec in checkpoint.

* Updated design using Tensor Subclasses

* Fix FP8 MM

* An actually functional POC

* Remove CK reference and ensure correct compute dtype

* Update unit tests

* ruff lint

* Implement mixed precision operations with a registry design and metadate for quant spec in checkpoint.

* Updated design using Tensor Subclasses

* Fix FP8 MM

* An actually functional POC

* Remove CK reference and ensure correct compute dtype

* Update unit tests

* ruff lint

* Fix missing keys

* Rename quant dtype parameter

* Rename quant dtype parameter

* Fix unittests for CPU build
2025-10-28 16:20:53 -04:00
comfyanonymous
22e40d2ace Tell users to update their nvidia drivers if portable doesn't start. (#10518) 2025-10-28 15:08:08 -04:00
comfyanonymous
3bea4efc6b Tell users to update nvidia drivers if problem with portable. (#10510) 2025-10-28 04:45:45 -04:00
comfyanonymous
8cf2ba4ba6 Remove comfy api key from queue api. (#10502) 2025-10-28 03:23:52 -04:00
comfyanonymous
b61a40cbc9 Bump stable portable to cu130 python 3.13.9 (#10508) 2025-10-28 03:21:45 -04:00
comfyanonymous
f2bb3230b7 ComfyUI version v0.3.67 2025-10-28 03:03:59 -04:00
Jedrzej Kosinski
614b8d3345 frontend bump to 1.28.8 (#10506) 2025-10-28 03:01:13 -04:00
ComfyUI Wiki
6abc30aae9 Update template to 0.2.4 (#10505) 2025-10-28 01:56:30 -04:00
Alexander Piskun
55bad30375 feat(api-nodes): add LTXV API nodes (#10496) 2025-10-27 22:25:29 -07:00
ComfyUI Wiki
c305deed56 Update template to 0.2.3 (#10503) 2025-10-27 22:24:16 -07:00
comfyanonymous
601ee1775a Add a bat to run comfyui portable without api nodes. (#10504) 2025-10-27 23:54:00 -04:00
comfyanonymous
c170fd2db5 Bump portable deps workflow to torch cu130 python 3.13.9 (#10493) 2025-10-26 20:23:01 -04:00
Alexander Piskun
9d529e5308 fix(api-nodes): random issues on Windows by capturing general OSError for retries (#10486) 2025-10-25 23:51:06 -07:00
comfyanonymous
f6bbc1ac84 Fix mistake. (#10484) 2025-10-25 23:07:29 -04:00
comfyanonymous
098a352f13 Add warning for torch-directml usage (#10482)
Added a warning message about the state of torch-directml.
2025-10-25 20:05:22 -04:00
Alexander Piskun
e86b79ab9e convert Gemini API nodes to V3 schema (#10476) 2025-10-25 14:35:30 -07:00
comfyanonymous
426cde37f1 Remove useless function (#10472) 2025-10-24 19:56:51 -04:00
Alexander Piskun
dd5af0c587 convert Tripo API nodes to V3 schema (#10469) 2025-10-24 15:48:34 -07:00
Alexander Piskun
388b306a2b feat(api-nodes): network client v2: async ops, cancellation, downloads, refactor (#10390)
* feat(api-nodes): implement new API client for V3 nodes

* feat(api-nodes): implement new API client for V3 nodes

* feat(api-nodes): implement new API client for V3 nodes

* converted WAN nodes to use new client; polishing

* fix(auth): do not leak authentification for the absolute urls

* convert BFL API nodes to use new API client; remove deprecated BFL nodes

* converted Google Veo nodes

* fix(Veo3.1 model): take into account "generate_audio" parameter
2025-10-23 22:37:16 -07:00
ComfyUI Wiki
24188b3141 Update template to 0.2.2 (#10461)
Fix template typo issue
2025-10-24 01:36:30 -04:00
comfyanonymous
1bcda6df98 WIP way to support multi multi dimensional latents. (#10456) 2025-10-23 21:21:14 -04:00
comfyanonymous
a1864c01f2 Small readme improvement. (#10442) 2025-10-22 17:26:22 -04:00
rattus
4739d7717f execution: fold in dependency aware caching / Fix --cache-none with loops/lazy etc (Resubmit) (#10440)
* execution: fold in dependency aware caching

This makes --cache-none compatiable with lazy and expanded
subgraphs.

Currently the --cache-none option is powered by the
DependencyAwareCache. The cache attempts to maintain a parallel
copy of the execution list data structure, however it is only
setup once at the start of execution and does not get meaninigful
updates to the execution list.

This causes multiple problems when --cache-none is used with lazy
and expanded subgraphs as the DAC does not accurately update its
copy of the execution data structure.

DAC has an attempt to handle subgraphs ensure_subcache however
this does not accurately connect to nodes outside the subgraph.
The current semantics of DAC are to free a node ASAP after the
dependent nodes are executed.

This means that if a subgraph refs such a node it will be requed
and re-executed by the execution_list but DAC wont see it in
its to-free lists anymore and leak memory.

Rather than try and cover all the cases where the execution list
changes from inside the cache, move the while problem to the
executor which maintains an always up-to-date copy of the wanted
data-structure.

The executor now has a fast-moving run-local cache of its own.
Each _to node has its own mini cache, and the cache is unconditionally
primed at the time of add_strong_link.

add_strong_link is called for all of static workflows, lazy links
and expanded subgraphs so its the singular source of truth for
output dependendencies.

In the case of a cache-hit, the executor cache will hold the non-none
value (it will respect updates if they happen somehow as well).

In the case of a cache-miss, the executor caches a None and will
wait for a notification to update the value when the node completes.

When a node completes execution, it simply releases its mini-cache
and in turn its strong refs on its direct anscestor outputs, allowing
for ASAP freeing (same as the DependencyAwareCache but a little more
automatic).

This now allows for re-implementation of --cache-none with no cache
at all. The dependency aware cache was also observing the dependency
sematics for the objects and UI cache which is not accurate (this
entire logic was always outputs specific).

This also prepares for more complex caching strategies (such as RAM
pressure based caching), where a cache can implement any freeing
strategy completely independently of the DepedancyAwareness
requirement.

* main: re-implement --cache-none as no cache at all

The execution list now tracks the dependency aware caching more
correctly that the DependancyAwareCache.

Change it to a cache that does nothing.

* test_execution: add --cache-none to the test suite

--cache-none is now expected to work universally. Run it through the
full unit test suite. Propagate the server parameterization for whether
or not the server is capabale of caching, so that the minority of tests
that specifically check for cache hits can if else. Hard assert NOT
caching in the else to give some coverage of --cache-none expected
behaviour to not acutally cache.
2025-10-22 15:49:05 -04:00
Jedrzej Kosinski
f13cff0be6 Add custom node published subgraphs endpoint (#10438)
* Add get_subgraphs_dir to ComfyExtension and PUBLISHED_SUBGRAPH_DIRS to nodes.py

* Created initial endpoints, although the returned paths are a bit off currently

* Fix path and actually return real data

* Sanitize returned /api/global_subgraphs entries

* Remove leftover function from early prototyping

* Remove added whitespace

* Add None check for sanitize_entry
2025-10-21 23:16:16 -04:00
comfyanonymous
9cdc64998f Only disable cudnn on newer AMD GPUs. (#10437) 2025-10-21 19:15:23 -04:00
comfyanonymous
560b1bdfca ComfyUI version v0.3.66 2025-10-21 01:12:32 -04:00
comfyanonymous
b7992f871a Revert "execution: fold in dependency aware caching / Fix --cache-none with l…" (#10422)
This reverts commit b1467da480.
2025-10-20 19:03:06 -04:00
comfyanonymous
2c2aa409b0 Log message for cudnn disable on AMD. (#10418) 2025-10-20 15:43:24 -04:00
ComfyUI Wiki
a4787ac83b Update template to 0.2.1 (#10413)
* Update template to 0.1.97

* Update template to 0.2.1
2025-10-20 15:28:36 -04:00
Christian Byrne
b5c59b763c Deprecation warning on unused files (#10387)
* only warn for unused files

* include internal extensions
2025-10-19 13:05:46 -07:00
comfyanonymous
b4f30bd408 Pytorch is stupid. (#10398) 2025-10-19 01:25:35 -04:00
comfyanonymous
dad076aee6 Speed up chroma radiance. (#10395) 2025-10-18 23:19:52 -04:00
comfyanonymous
0cf33953a7 Fix batch size above 1 giving bad output in chroma radiance. (#10394) 2025-10-18 23:15:34 -04:00
comfyanonymous
5b80addafd Turn off cuda malloc by default when --fast autotune is turned on. (#10393) 2025-10-18 22:35:46 -04:00
comfyanonymous
9da397ea2f Disable torch compiler for cast_bias_weight function (#10384)
* Disable torch compiler for cast_bias_weight function

* Fix torch compile.
2025-10-17 20:03:28 -04:00
comfyanonymous
92d97380bd Update Python 3.14 installation instructions (#10385)
Removed mention of installing pytorch nightly for Python 3.14.
2025-10-17 18:22:59 -04:00
Alexander Piskun
99ce2a1f66 convert nodes_controlnet.py to V3 schema (#10202) 2025-10-17 14:13:05 -07:00
rattus128
b1467da480 execution: fold in dependency aware caching / Fix --cache-none with loops/lazy etc (#10368)
* execution: fold in dependency aware caching

This makes --cache-none compatiable with lazy and expanded
subgraphs.

Currently the --cache-none option is powered by the
DependencyAwareCache. The cache attempts to maintain a parallel
copy of the execution list data structure, however it is only
setup once at the start of execution and does not get meaninigful
updates to the execution list.

This causes multiple problems when --cache-none is used with lazy
and expanded subgraphs as the DAC does not accurately update its
copy of the execution data structure.

DAC has an attempt to handle subgraphs ensure_subcache however
this does not accurately connect to nodes outside the subgraph.
The current semantics of DAC are to free a node ASAP after the
dependent nodes are executed.

This means that if a subgraph refs such a node it will be requed
and re-executed by the execution_list but DAC wont see it in
its to-free lists anymore and leak memory.

Rather than try and cover all the cases where the execution list
changes from inside the cache, move the while problem to the
executor which maintains an always up-to-date copy of the wanted
data-structure.

The executor now has a fast-moving run-local cache of its own.
Each _to node has its own mini cache, and the cache is unconditionally
primed at the time of add_strong_link.

add_strong_link is called for all of static workflows, lazy links
and expanded subgraphs so its the singular source of truth for
output dependendencies.

In the case of a cache-hit, the executor cache will hold the non-none
value (it will respect updates if they happen somehow as well).

In the case of a cache-miss, the executor caches a None and will
wait for a notification to update the value when the node completes.

When a node completes execution, it simply releases its mini-cache
and in turn its strong refs on its direct anscestor outputs, allowing
for ASAP freeing (same as the DependencyAwareCache but a little more
automatic).

This now allows for re-implementation of --cache-none with no cache
at all. The dependency aware cache was also observing the dependency
sematics for the objects and UI cache which is not accurate (this
entire logic was always outputs specific).

This also prepares for more complex caching strategies (such as RAM
pressure based caching), where a cache can implement any freeing
strategy completely independently of the DepedancyAwareness
requirement.

* main: re-implement --cache-none as no cache at all

The execution list now tracks the dependency aware caching more
correctly that the DependancyAwareCache.

Change it to a cache that does nothing.

* test_execution: add --cache-none to the test suite

--cache-none is now expected to work universally. Run it through the
full unit test suite. Propagate the server parameterization for whether
or not the server is capabale of caching, so that the minority of tests
that specifically check for cache hits can if else. Hard assert NOT
caching in the else to give some coverage of --cache-none expected
behaviour to not acutally cache.
2025-10-17 13:55:15 -07:00
Jedrzej Kosinski
d8d60b5609 Do batch_slice in EasyCache's apply_cache_diff (#10376) 2025-10-17 00:39:37 -04:00
comfyanonymous
b1293d50ef workaround also works on cudnn 91200 (#10375) 2025-10-16 19:59:56 -04:00
comfyanonymous
19b466160c Workaround for nvidia issue where VAE uses 3x more memory on torch 2.9 (#10373) 2025-10-16 18:16:03 -04:00
Alexander Piskun
bc0ad9bb49 fix(api-nodes): remove "veo2" model from Veo3 node (#10372) 2025-10-16 10:12:50 -07:00
Rizumu Ayaka
4054b4bf38 feat: deprecated API alert (#10366) 2025-10-16 01:13:31 -07:00
Arjan Singh
55ac7d333c Bump frontend to 1.28.7 (#10364) 2025-10-15 20:30:39 -07:00
Faych
afa8a24fe1 refactor: Replace manual patches merging with merge_nested_dicts (#10360) 2025-10-15 17:16:09 -07:00
Jedrzej Kosinski
493b81e48f Fix order of inputs nested merge_nested_dicts (#10362) 2025-10-15 16:47:26 -07:00
comfyanonymous
6b035bfce2 Latest pytorch stable is cu130 (#10361) 2025-10-15 18:48:12 -04:00
Alexander Piskun
74b7f0b04b feat(api-nodes): add Veo3.1 model (#10357) 2025-10-15 15:41:45 -07:00
chaObserv
f72c6616b2 Add TemporalScoreRescaling node (#10351)
* Add TemporalScoreRescaling node

* Mention image generation in tsr_k's tooltip
2025-10-15 18:12:25 -04:00
comfyanonymous
1c10b33f9b gfx942 doesn't support fp8 operations. (#10348) 2025-10-15 00:21:11 -04:00
Arjan Singh
ddfce1af4f Bump frontend to 1.28.6 (#10345) 2025-10-14 21:08:23 -04:00
Alexander Piskun
7a883849ea api-nodes: fixed dynamic pricing format; import comfy_io directly (#10336) 2025-10-13 23:55:56 -07:00
comfyanonymous
84867067ea Python 3.14 instructions. (#10337) 2025-10-14 02:09:12 -04:00
comfyanonymous
3374e900d0 Faster workflow cancelling. (#10301) 2025-10-13 23:43:53 -04:00
comfyanonymous
51696e3fdc ComfyUI version 0.3.65 2025-10-13 23:39:55 -04:00
comfyanonymous
dfff7e5332 Better memory estimation for the SD/Flux VAE on AMD. (#10334) 2025-10-13 22:37:19 -04:00
comfyanonymous
e4ea393666 Fix loading old stable diffusion ckpt files on newer numpy. (#10333) 2025-10-13 22:18:58 -04:00
comfyanonymous
c8674bc6e9 Enable RDNA4 pytorch attention on ROCm 7.0 and up. (#10332) 2025-10-13 21:19:03 -04:00
Alexander Piskun
3dfdcf66b6 convert nodes_hunyuan.py to V3 schema (#10136) 2025-10-13 12:36:26 -07:00
rattus128
95ca2e56c8 WAN2.2: Fix cache VRAM leak on error (#10308)
Same change pattern as 7e8dd275c2
applied to WAN2.2

If this suffers an exception (such as a VRAM oom) it will leave the
encode() and decode() methods which skips the cleanup of the WAN
feature cache. The comfy node cache then ultimately keeps a reference
this object which is in turn reffing large tensors from the failed
execution.

The feature cache is currently setup at a class variable on the
encoder/decoder however, the encode and decode functions always clear
it on both entry and exit of normal execution.

Its likely the design intent is this is usable as a streaming encoder
where the input comes in batches, however the functions as they are
today don't support that.

So simplify by bringing the cache back to local variable, so that if
it does VRAM OOM the cache itself is properly garbage when the
encode()/decode() functions dissappear from the stack.
2025-10-13 15:23:11 -04:00
Daniel Harte
27ffd12c45 add indent=4 kwarg to json.dumps() (#10307) 2025-10-13 12:14:52 -07:00
comfyanonymous
e693e4db6a Always set diffusion model to eval() mode. (#10331) 2025-10-13 14:57:27 -04:00
comfyanonymous
d68ece7301 Update the extra_model_paths.yaml.example (#10319) 2025-10-12 23:54:41 -04:00
Christian Byrne
894837de9a update extra models paths example (#10316) 2025-10-12 23:35:33 -04:00
ComfyUI Wiki
fdc92863b6 Update node docs to 0.3.0 (#10318) 2025-10-12 23:32:02 -04:00
comfyanonymous
a125cd84b0 Improve AMD performance. (#10302)
I honestly have no idea why this improves things but it does.
2025-10-12 00:28:01 -04:00
comfyanonymous
84e9ce32c6 Implement the mmaudio VAE. (#10300) 2025-10-11 22:57:23 -04:00
ComfyUI Wiki
f43b8ab2a2 Update template to 0.1.95 (#10294) 2025-10-11 10:27:22 -07:00
Alexander Piskun
14d642acd6 feat(api-nodes): add price extractor feature; small fixes to Kling & Pika nodes (#10284) 2025-10-10 16:21:40 -07:00
Alexander Piskun
aa895db7e8 feat(GeminiImage-ApiNode): add aspect_ratio and release version of model (#10255) 2025-10-10 16:17:20 -07:00
comfyanonymous
cdfc25a160 Fix save audio nodes saving mono audio as stereo. (#10289) 2025-10-10 17:33:51 -04:00
Alexander Piskun
81e4dac107 convert nodes_upscale_model.py to V3 schema (#10149) 2025-10-09 16:08:40 -07:00
Alexander Piskun
90853fb9cd convert nodes_flux to V3 schema (#10122) 2025-10-09 16:07:17 -07:00
comfyanonymous
f1dd6e50f8 Fix bug with applying loras on fp8 scaled without fp8 ops. (#10279) 2025-10-09 19:02:40 -04:00
Alexander Piskun
fc0fbf141c convert nodes_sd3.py and nodes_slg.py to V3 schema (#10162) 2025-10-09 15:18:23 -07:00
Alexander Piskun
f3d5d328a3 fix(v3,api-nodes): V3 schema typing; corrected Pika API nodes (#10265) 2025-10-09 15:15:03 -07:00
comfyanonymous
139addd53c More surgical fix for #10267 (#10276) 2025-10-09 16:37:35 -04:00
Alexander Piskun
cbee7d3390 convert nodes_latent.py to V3 schema (#10160) 2025-10-08 23:14:00 -07:00
Alexander Piskun
6732014a0a convert nodes_compositing.py to V3 schema (#10174) 2025-10-08 23:13:15 -07:00
Alexander Piskun
989f715d92 convert nodes_lora_extract.py to V3 schema (#10182) 2025-10-08 23:11:45 -07:00
Alexander Piskun
2ba8d7cce8 convert nodes_model_downscale.py to V3 schema (#10199) 2025-10-08 23:10:23 -07:00
Alexander Piskun
51fb505ffa feat(api-nodes, pylint): use lazy formatting in logging functions (#10248) 2025-10-08 23:06:56 -07:00
Jedrzej Kosinski
72c2071972 Mvly/node update (#10042)
* updated V2V node to allow for control image input
exposing steps in v2v
fixing guidance_scale as input parameter

TODO: allow for motion_intensity as input param.

* refactor: comment out unsupported resolution and adjust default values in video nodes

* set control_after_generate

* adding new defaults

* fixes

* changed control_after_generate back to True

* changed control_after_generate back to False

---------

Co-authored-by: thorsten <thorsten@tripod-digital.co.nz>
2025-10-08 20:30:41 -04:00
comfyanonymous
6e59934089 Refactor model sampling sigmas code. (#10250) 2025-10-08 17:49:02 -04:00
Alexander Piskun
3e0eb8d33f feat(V3-io): allow Enum classes for Combo options (#10237) 2025-10-08 00:14:04 -07:00
164 changed files with 20110 additions and 11766 deletions

View File

@@ -1,5 +1,5 @@
As of the time of writing this you need this preview driver for best results:
https://www.amd.com/en/resources/support-articles/release-notes/RN-AMDGPU-WINDOWS-PYTORCH-PREVIEW.html
As of the time of writing this you need this driver for best results:
https://www.amd.com/en/resources/support-articles/release-notes/RN-AMDGPU-WINDOWS-PYTORCH-7-1-1.html
HOW TO RUN:
@@ -25,3 +25,4 @@ In the ComfyUI directory you will find a file: extra_model_paths.yaml.example
Rename this file to: extra_model_paths.yaml and edit it with your favorite text editor.

View File

@@ -0,0 +1,3 @@
..\python_embeded\python.exe -s ..\ComfyUI\main.py --windows-standalone-build --disable-api-nodes
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
pause

View File

@@ -1,2 +1,3 @@
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
pause

View File

@@ -1,2 +1,3 @@
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast fp16_accumulation
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
pause

View File

@@ -8,13 +8,15 @@ body:
Before submitting a **Bug Report**, please ensure the following:
- **1:** You are running the latest version of ComfyUI.
- **2:** You have looked at the existing bug reports and made sure this isn't already reported.
- **2:** You have your ComfyUI logs and relevant workflow on hand and will post them in this bug report.
- **3:** You confirmed that the bug is not caused by a custom node. You can disable all custom nodes by passing
`--disable-all-custom-nodes` command line argument.
`--disable-all-custom-nodes` command line argument. If you have custom node try updating them to the latest version.
- **4:** This is an actual bug in ComfyUI, not just a support question. A bug is when you can specify exact
steps to replicate what went wrong and others will be able to repeat your steps and see the same issue happen.
If unsure, ask on the [ComfyUI Matrix Space](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) or the [Comfy Org Discord](https://discord.gg/comfyorg) first.
## Very Important
Please make sure that you post ALL your ComfyUI logs in the bug report. A bug report without logs will likely be ignored.
- type: checkboxes
id: custom-nodes-test
attributes:

View File

@@ -0,0 +1,21 @@
<!-- API_NODE_PR_CHECKLIST: do not remove -->
## API Node PR Checklist
### Scope
- [ ] **Is API Node Change**
### Pricing & Billing
- [ ] **Need pricing update**
- [ ] **No pricing update**
If **Need pricing update**:
- [ ] Metronome rate cards updated
- [ ] Autobilling tests updated and passing
### QA
- [ ] **QA done**
- [ ] **QA not required**
### Comms
- [ ] Informed **Kosinkadink**

58
.github/workflows/api-node-template.yml vendored Normal file
View File

@@ -0,0 +1,58 @@
name: Append API Node PR template
on:
pull_request_target:
types: [opened, reopened, synchronize, ready_for_review]
paths:
- 'comfy_api_nodes/**' # only run if these files changed
permissions:
contents: read
pull-requests: write
jobs:
inject:
runs-on: ubuntu-latest
steps:
- name: Ensure template exists and append to PR body
uses: actions/github-script@v7
with:
script: |
const { owner, repo } = context.repo;
const number = context.payload.pull_request.number;
const templatePath = '.github/PULL_REQUEST_TEMPLATE/api-node.md';
const marker = '<!-- API_NODE_PR_CHECKLIST: do not remove -->';
const { data: pr } = await github.rest.pulls.get({ owner, repo, pull_number: number });
let templateText;
try {
const res = await github.rest.repos.getContent({
owner,
repo,
path: templatePath,
ref: pr.base.ref
});
const buf = Buffer.from(res.data.content, res.data.encoding || 'base64');
templateText = buf.toString('utf8');
} catch (e) {
core.setFailed(`Required PR template not found at "${templatePath}" on ${pr.base.ref}. Please add it to the repo.`);
return;
}
// Enforce the presence of the marker inside the template (for idempotence)
if (!templateText.includes(marker)) {
core.setFailed(`Template at "${templatePath}" does not contain the required marker:\n${marker}\nAdd it so we can detect duplicates safely.`);
return;
}
// If the PR already contains the marker, do not append again.
const body = pr.body || '';
if (body.includes(marker)) {
core.info('Template already present in PR body; nothing to inject.');
return;
}
const newBody = (body ? body + '\n\n' : '') + templateText + '\n';
await github.rest.pulls.update({ owner, repo, pull_number: number, body: newBody });
core.notice('API Node template appended to PR description.');

View File

@@ -14,13 +14,13 @@ jobs:
contents: "write"
packages: "write"
pull-requests: "read"
name: "Release NVIDIA Default (cu129)"
name: "Release NVIDIA Default (cu130)"
uses: ./.github/workflows/stable-release.yml
with:
git_tag: ${{ inputs.git_tag }}
cache_tag: "cu129"
cache_tag: "cu130"
python_minor: "13"
python_patch: "6"
python_patch: "9"
rel_name: "nvidia"
rel_extra_name: ""
test_release: true
@@ -43,16 +43,33 @@ jobs:
test_release: true
secrets: inherit
release_nvidia_cu126:
permissions:
contents: "write"
packages: "write"
pull-requests: "read"
name: "Release NVIDIA cu126"
uses: ./.github/workflows/stable-release.yml
with:
git_tag: ${{ inputs.git_tag }}
cache_tag: "cu126"
python_minor: "12"
python_patch: "10"
rel_name: "nvidia"
rel_extra_name: "_cu126"
test_release: true
secrets: inherit
release_amd_rocm:
permissions:
contents: "write"
packages: "write"
pull-requests: "read"
name: "Release AMD ROCm 6.4.4"
name: "Release AMD ROCm 7.1.1"
uses: ./.github/workflows/stable-release.yml
with:
git_tag: ${{ inputs.git_tag }}
cache_tag: "rocm644"
cache_tag: "rocm711"
python_minor: "12"
python_patch: "10"
rel_name: "amd"

View File

@@ -21,14 +21,15 @@ jobs:
fail-fast: false
matrix:
# os: [macos, linux, windows]
os: [macos, linux]
python_version: ["3.9", "3.10", "3.11", "3.12"]
# os: [macos, linux]
os: [linux]
python_version: ["3.10", "3.11", "3.12"]
cuda_version: ["12.1"]
torch_version: ["stable"]
include:
- os: macos
runner_label: [self-hosted, macOS]
flags: "--use-pytorch-cross-attention"
# - os: macos
# runner_label: [self-hosted, macOS]
# flags: "--use-pytorch-cross-attention"
- os: linux
runner_label: [self-hosted, Linux]
flags: ""
@@ -73,14 +74,15 @@ jobs:
strategy:
fail-fast: false
matrix:
os: [macos, linux]
# os: [macos, linux]
os: [linux]
python_version: ["3.11"]
cuda_version: ["12.1"]
torch_version: ["nightly"]
include:
- os: macos
runner_label: [self-hosted, macOS]
flags: "--use-pytorch-cross-attention"
# - os: macos
# runner_label: [self-hosted, macOS]
# flags: "--use-pytorch-cross-attention"
- os: linux
runner_label: [self-hosted, Linux]
flags: ""

View File

@@ -17,7 +17,7 @@ on:
description: 'cuda version'
required: true
type: string
default: "129"
default: "130"
python_minor:
description: 'python minor version'
@@ -29,7 +29,7 @@ on:
description: 'python patch version'
required: true
type: string
default: "6"
default: "9"
# push:
# branches:
# - master

168
QUANTIZATION.md Normal file
View File

@@ -0,0 +1,168 @@
# The Comfy guide to Quantization
## How does quantization work?
Quantization aims to map a high-precision value x_f to a lower precision format with minimal loss in accuracy. These smaller formats then serve to reduce the models memory footprint and increase throughput by using specialized hardware.
When simply converting a value from FP16 to FP8 using the round-nearest method we might hit two issues:
- The dynamic range of FP16 (-65,504, 65,504) far exceeds FP8 formats like E4M3 (-448, 448) or E5M2 (-57,344, 57,344), potentially resulting in clipped values
- The original values are concentrated in a small range (e.g. -1,1) leaving many FP8-bits "unused"
By using a scaling factor, we aim to map these values into the quantized-dtype range, making use of the full spectrum. One of the easiest approaches, and common, is using per-tensor absolute-maximum scaling.
```
absmax = max(abs(tensor))
scale = amax / max_dynamic_range_low_precision
# Quantization
tensor_q = (tensor / scale).to(low_precision_dtype)
# De-Quantization
tensor_dq = tensor_q.to(fp16) * scale
tensor_dq ~ tensor
```
Given that additional information (scaling factor) is needed to "interpret" the quantized values, we describe those as derived datatypes.
## Quantization in Comfy
```
QuantizedTensor (torch.Tensor subclass)
↓ __torch_dispatch__
Two-Level Registry (generic + layout handlers)
MixedPrecisionOps + Metadata Detection
```
### Representation
To represent these derived datatypes, ComfyUI uses a subclass of torch.Tensor to implements these using the `QuantizedTensor` class found in `comfy/quant_ops.py`
A `Layout` class defines how a specific quantization format behaves:
- Required parameters
- Quantize method
- De-Quantize method
```python
from comfy.quant_ops import QuantizedLayout
class MyLayout(QuantizedLayout):
@classmethod
def quantize(cls, tensor, **kwargs):
# Convert to quantized format
qdata = ...
params = {'scale': ..., 'orig_dtype': tensor.dtype}
return qdata, params
@staticmethod
def dequantize(qdata, scale, orig_dtype, **kwargs):
return qdata.to(orig_dtype) * scale
```
To then run operations using these QuantizedTensors we use two registry systems to define supported operations.
The first is a **generic registry** that handles operations common to all quantized formats (e.g., `.to()`, `.clone()`, `.reshape()`).
The second registry is layout-specific and allows to implement fast-paths like nn.Linear.
```python
from comfy.quant_ops import register_layout_op
@register_layout_op(torch.ops.aten.linear.default, MyLayout)
def my_linear(func, args, kwargs):
# Extract tensors, call optimized kernel
...
```
When `torch.nn.functional.linear()` is called with QuantizedTensor arguments, `__torch_dispatch__` automatically routes to the registered implementation.
For any unsupported operation, QuantizedTensor will fallback to call `dequantize` and dispatch using the high-precision implementation.
### Mixed Precision
The `MixedPrecisionOps` class (lines 542-648 in `comfy/ops.py`) enables per-layer quantization decisions, allowing different layers in a model to use different precisions. This is activated when a model config contains a `layer_quant_config` dictionary that specifies which layers should be quantized and how.
**Architecture:**
```python
class MixedPrecisionOps(disable_weight_init):
_layer_quant_config = {} # Maps layer names to quantization configs
_compute_dtype = torch.bfloat16 # Default compute / dequantize precision
```
**Key mechanism:**
The custom `Linear._load_from_state_dict()` method inspects each layer during model loading:
- If the layer name is **not** in `_layer_quant_config`: load weight as regular tensor in `_compute_dtype`
- If the layer name **is** in `_layer_quant_config`:
- Load weight as `QuantizedTensor` with the specified layout (e.g., `TensorCoreFP8Layout`)
- Load associated quantization parameters (scales, block_size, etc.)
**Why it's needed:**
Not all layers tolerate quantization equally. Sensitive operations like final projections can be kept in higher precision, while compute-heavy matmuls are quantized. This provides most of the performance benefits while maintaining quality.
The system is selected in `pick_operations()` when `model_config.layer_quant_config` is present, making it the highest-priority operation mode.
## Checkpoint Format
Quantized checkpoints are stored as standard safetensors files with quantized weight tensors and associated scaling parameters, plus a `_quantization_metadata` JSON entry describing the quantization scheme.
The quantized checkpoint will contain the same layers as the original checkpoint but:
- The weights are stored as quantized values, sometimes using a different storage datatype. E.g. uint8 container for fp8.
- For each quantized weight a number of additional scaling parameters are stored alongside depending on the recipe.
- We store a metadata.json in the metadata of the final safetensor containing the `_quantization_metadata` describing which layers are quantized and what layout has been used.
### Scaling Parameters details
We define 4 possible scaling parameters that should cover most recipes in the near-future:
- **weight_scale**: quantization scalers for the weights
- **weight_scale_2**: global scalers in the context of double scaling
- **pre_quant_scale**: scalers used for smoothing salient weights
- **input_scale**: quantization scalers for the activations
| Format | Storage dtype | weight_scale | weight_scale_2 | pre_quant_scale | input_scale |
|--------|---------------|--------------|----------------|-----------------|-------------|
| float8_e4m3fn | float32 | float32 (scalar) | - | - | float32 (scalar) |
You can find the defined formats in `comfy/quant_ops.py` (QUANT_ALGOS).
### Quantization Metadata
The metadata stored alongside the checkpoint contains:
- **format_version**: String to define a version of the standard
- **layers**: A dictionary mapping layer names to their quantization format. The format string maps to the definitions found in `QUANT_ALGOS`.
Example:
```json
{
"_quantization_metadata": {
"format_version": "1.0",
"layers": {
"model.layers.0.mlp.up_proj": "float8_e4m3fn",
"model.layers.0.mlp.down_proj": "float8_e4m3fn",
"model.layers.1.mlp.up_proj": "float8_e4m3fn"
}
}
}
```
## Creating Quantized Checkpoints
To create compatible checkpoints, use any quantization tool provided the output follows the checkpoint format described above and uses a layout defined in `QUANT_ALGOS`.
### Weight Quantization
Weight quantization is straightforward - compute the scaling factor directly from the weight tensor using the absolute maximum method described earlier. Each layer's weights are quantized independently and stored with their corresponding `weight_scale` parameter.
### Calibration (for Activation Quantization)
Activation quantization (e.g., for FP8 Tensor Core operations) requires `input_scale` parameters that cannot be determined from static weights alone. Since activation values depend on actual inputs, we use **post-training calibration (PTQ)**:
1. **Collect statistics**: Run inference on N representative samples
2. **Track activations**: Record the absolute maximum (`amax`) of inputs to each quantized layer
3. **Compute scales**: Derive `input_scale` from collected statistics
4. **Store in checkpoint**: Save `input_scale` parameters alongside weights
The calibration dataset should be representative of your target use case. For diffusion models, this typically means a diverse set of prompts and generation parameters.

View File

@@ -67,6 +67,8 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
- [HiDream](https://comfyanonymous.github.io/ComfyUI_examples/hidream/)
- [Qwen Image](https://comfyanonymous.github.io/ComfyUI_examples/qwen_image/)
- [Hunyuan Image 2.1](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_image/)
- [Flux 2](https://comfyanonymous.github.io/ComfyUI_examples/flux2/)
- [Z Image](https://comfyanonymous.github.io/ComfyUI_examples/z_image/)
- Image Editing Models
- [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/)
- [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model)
@@ -112,10 +114,11 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
## Release Process
ComfyUI follows a weekly release cycle targeting Friday but this regularly changes because of model releases or large changes to the codebase. There are three interconnected repositories:
ComfyUI follows a weekly release cycle targeting Monday but this regularly changes because of model releases or large changes to the codebase. There are three interconnected repositories:
1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)**
- Releases a new stable version (e.g., v0.7.0)
- Releases a new stable version (e.g., v0.7.0) roughly every week.
- Commits outside of the stable release tags may be very unstable and break many custom nodes.
- Serves as the foundation for the desktop release
2. **[ComfyUI Desktop](https://github.com/Comfy-Org/desktop)**
@@ -172,15 +175,19 @@ There is a portable standalone build for Windows that should work for running on
### [Direct link to download](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia.7z)
Simply download, extract with [7-Zip](https://7-zip.org) and run. Make sure you put your Stable Diffusion checkpoints/models (the huge ckpt/safetensors files) in: ComfyUI\models\checkpoints
Simply download, extract with [7-Zip](https://7-zip.org) or with the windows explorer on recent windows versions and run. For smaller models you normally only need to put the checkpoints (the huge ckpt/safetensors files) in: ComfyUI\models\checkpoints but many of the larger models have multiple files. Make sure to follow the instructions to know which subfolder to put them in ComfyUI\models\
If you have trouble extracting it, right click the file -> properties -> unblock
Update your Nvidia drivers if it doesn't start.
#### Alternative Downloads:
[Experimental portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z)
[Portable with pytorch cuda 12.8 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu128.7z) (Supports Nvidia 10 series and older GPUs).
[Portable with pytorch cuda 12.8 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu128.7z).
[Portable with pytorch cuda 12.6 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu126.7z) (Supports Nvidia 10 series and older GPUs).
#### How do I share models between another UI and ComfyUI?
@@ -197,7 +204,11 @@ comfy install
## Manual Install (Windows, Linux)
Python 3.13 is very well supported. If you have trouble with some custom node dependencies you can try 3.12
Python 3.14 works but you may encounter issues with the torch compile node. The free threaded variant is still missing some dependencies.
Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12
### Instructions:
Git clone this repo.
@@ -214,7 +225,7 @@ AMD users can install rocm and pytorch with pip if you don't have it already ins
This is the command to install the nightly with ROCm 7.0 which might have some performance improvements:
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.0```
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.1```
### AMD GPUs (Experimental: Windows and Linux), RDNA 3, 3.5 and 4 only.
@@ -235,7 +246,7 @@ RDNA 4 (RX 9000 series):
### Intel GPUs (Windows and Linux)
(Option 1) Intel Arc GPU users can install native PyTorch with torch.xpu support using pip. More information can be found [here](https://pytorch.org/docs/main/notes/get_start_xpu.html)
Intel Arc GPU users can install native PyTorch with torch.xpu support using pip. More information can be found [here](https://pytorch.org/docs/main/notes/get_start_xpu.html)
1. To install PyTorch xpu, use the following command:
@@ -245,15 +256,11 @@ This is the command to install the Pytorch xpu nightly which might have some per
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/xpu```
(Option 2) Alternatively, Intel GPUs supported by Intel Extension for PyTorch (IPEX) can leverage IPEX for improved performance.
1. visit [Installation](https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=gpu) for more information.
### NVIDIA
Nvidia users should install stable pytorch using this command:
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu129```
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu130```
This is the command to install pytorch nightly instead which might have performance improvements.

View File

@@ -10,7 +10,8 @@ import importlib
from dataclasses import dataclass
from functools import cached_property
from pathlib import Path
from typing import TypedDict, Optional
from typing import Dict, TypedDict, Optional
from aiohttp import web
from importlib.metadata import version
import requests
@@ -257,7 +258,54 @@ comfyui-frontend-package is not installed.
sys.exit(-1)
@classmethod
def templates_path(cls) -> str:
def template_asset_map(cls) -> Optional[Dict[str, str]]:
"""Return a mapping of template asset names to their absolute paths."""
try:
from comfyui_workflow_templates import (
get_asset_path,
iter_templates,
)
except ImportError:
logging.error(
f"""
********** ERROR ***********
comfyui-workflow-templates is not installed.
{frontend_install_warning_message()}
********** ERROR ***********
""".strip()
)
return None
try:
template_entries = list(iter_templates())
except Exception as exc:
logging.error(f"Failed to enumerate workflow templates: {exc}")
return None
asset_map: Dict[str, str] = {}
try:
for entry in template_entries:
for asset in entry.assets:
asset_map[asset.filename] = get_asset_path(
entry.template_id, asset.filename
)
except Exception as exc:
logging.error(f"Failed to resolve template asset paths: {exc}")
return None
if not asset_map:
logging.error("No workflow template assets found. Did the packages install correctly?")
return None
return asset_map
@classmethod
def legacy_templates_path(cls) -> Optional[str]:
"""Return the legacy templates directory shipped inside the meta package."""
try:
import comfyui_workflow_templates
@@ -276,6 +324,7 @@ comfyui-workflow-templates is not installed.
********** ERROR ***********
""".strip()
)
return None
@classmethod
def embedded_docs_path(cls) -> str:
@@ -392,3 +441,17 @@ comfyui-workflow-templates is not installed.
logging.info("Falling back to the default frontend.")
check_frontend_version()
return cls.default_frontend_path()
@classmethod
def template_asset_handler(cls):
assets = cls.template_asset_map()
if not assets:
return None
async def serve_template(request: web.Request) -> web.StreamResponse:
rel_path = request.match_info.get("path", "")
target = assets.get(rel_path)
if target is None:
raise web.HTTPNotFound()
return web.FileResponse(target)
return serve_template

112
app/subgraph_manager.py Normal file
View File

@@ -0,0 +1,112 @@
from __future__ import annotations
from typing import TypedDict
import os
import folder_paths
import glob
from aiohttp import web
import hashlib
class Source:
custom_node = "custom_node"
class SubgraphEntry(TypedDict):
source: str
"""
Source of subgraph - custom_nodes vs templates.
"""
path: str
"""
Relative path of the subgraph file.
For custom nodes, will be the relative directory like <custom_node_dir>/subgraphs/<name>.json
"""
name: str
"""
Name of subgraph file.
"""
info: CustomNodeSubgraphEntryInfo
"""
Additional info about subgraph; in the case of custom_nodes, will contain nodepack name
"""
data: str
class CustomNodeSubgraphEntryInfo(TypedDict):
node_pack: str
"""Node pack name."""
class SubgraphManager:
def __init__(self):
self.cached_custom_node_subgraphs: dict[SubgraphEntry] | None = None
async def load_entry_data(self, entry: SubgraphEntry):
with open(entry['path'], 'r') as f:
entry['data'] = f.read()
return entry
async def sanitize_entry(self, entry: SubgraphEntry | None, remove_data=False) -> SubgraphEntry | None:
if entry is None:
return None
entry = entry.copy()
entry.pop('path', None)
if remove_data:
entry.pop('data', None)
return entry
async def sanitize_entries(self, entries: dict[str, SubgraphEntry], remove_data=False) -> dict[str, SubgraphEntry]:
entries = entries.copy()
for key in list(entries.keys()):
entries[key] = await self.sanitize_entry(entries[key], remove_data)
return entries
async def get_custom_node_subgraphs(self, loadedModules, force_reload=False):
# if not forced to reload and cached, return cache
if not force_reload and self.cached_custom_node_subgraphs is not None:
return self.cached_custom_node_subgraphs
# Load subgraphs from custom nodes
subfolder = "subgraphs"
subgraphs_dict: dict[SubgraphEntry] = {}
for folder in folder_paths.get_folder_paths("custom_nodes"):
pattern = os.path.join(folder, f"*/{subfolder}/*.json")
matched_files = glob.glob(pattern)
for file in matched_files:
# replace backslashes with forward slashes
file = file.replace('\\', '/')
info: CustomNodeSubgraphEntryInfo = {
"node_pack": "custom_nodes." + file.split('/')[-3]
}
source = Source.custom_node
# hash source + path to make sure id will be as unique as possible, but
# reproducible across backend reloads
id = hashlib.sha256(f"{source}{file}".encode()).hexdigest()
entry: SubgraphEntry = {
"source": Source.custom_node,
"name": os.path.splitext(os.path.basename(file))[0],
"path": file,
"info": info,
}
subgraphs_dict[id] = entry
self.cached_custom_node_subgraphs = subgraphs_dict
return subgraphs_dict
async def get_custom_node_subgraph(self, id: str, loadedModules):
subgraphs = await self.get_custom_node_subgraphs(loadedModules)
entry: SubgraphEntry = subgraphs.get(id, None)
if entry is not None and entry.get('data', None) is None:
await self.load_entry_data(entry)
return entry
def add_routes(self, routes, loadedModules):
@routes.get("/global_subgraphs")
async def get_global_subgraphs(request):
subgraphs_dict = await self.get_custom_node_subgraphs(loadedModules)
# NOTE: we may want to include other sources of global subgraphs such as templates in the future;
# that's the reasoning for the current implementation
return web.json_response(await self.sanitize_entries(subgraphs_dict, remove_data=True))
@routes.get("/global_subgraphs/{id}")
async def get_global_subgraph(request):
id = request.match_info.get("id", None)
subgraph = await self.get_custom_node_subgraph(id, loadedModules)
return web.json_response(await self.sanitize_entry(subgraph))

View File

@@ -59,6 +59,9 @@ class UserManager():
user = "default"
if args.multi_user and "comfy-user" in request.headers:
user = request.headers["comfy-user"]
# Block System Users (use same error message to prevent probing)
if user.startswith(folder_paths.SYSTEM_USER_PREFIX):
raise KeyError("Unknown user: " + user)
if user not in self.users:
raise KeyError("Unknown user: " + user)
@@ -66,15 +69,16 @@ class UserManager():
return user
def get_request_user_filepath(self, request, file, type="userdata", create_dir=True):
user_directory = folder_paths.get_user_directory()
if type == "userdata":
root_dir = user_directory
root_dir = folder_paths.get_user_directory()
else:
raise KeyError("Unknown filepath type:" + type)
user = self.get_request_user_id(request)
path = user_root = os.path.abspath(os.path.join(root_dir, user))
user_root = folder_paths.get_public_user_directory(user)
if user_root is None:
return None
path = user_root
# prevent leaving /{type}
if os.path.commonpath((root_dir, user_root)) != root_dir:
@@ -101,7 +105,11 @@ class UserManager():
name = name.strip()
if not name:
raise ValueError("username not provided")
if name.startswith(folder_paths.SYSTEM_USER_PREFIX):
raise ValueError("System User prefix not allowed")
user_id = re.sub("[^a-zA-Z0-9-_]+", '-', name)
if user_id.startswith(folder_paths.SYSTEM_USER_PREFIX):
raise ValueError("System User prefix not allowed")
user_id = user_id + "_" + str(uuid.uuid4())
self.users[user_id] = name
@@ -132,7 +140,10 @@ class UserManager():
if username in self.users.values():
return web.json_response({"error": "Duplicate username."}, status=400)
user_id = self.add_user(username)
try:
user_id = self.add_user(username)
except ValueError as e:
return web.json_response({"error": str(e)}, status=400)
return web.json_response(user_id)
@routes.get("/userdata")
@@ -424,7 +435,7 @@ class UserManager():
return source
dest = get_user_data_path(request, check_exists=False, param="dest")
if not isinstance(source, str):
if not isinstance(dest, str):
return dest
overwrite = request.query.get("overwrite", 'true') != "false"

View File

@@ -413,7 +413,8 @@ class ControlNet(nn.Module):
out_middle = []
if self.num_classes is not None:
assert y.shape[0] == x.shape[0]
if y is None:
raise ValueError("y is None, did you try using a controlnet for SDXL on SD1?")
emb = emb + self.label_emb(y)
h = x

View File

@@ -105,6 +105,7 @@ cache_group = parser.add_mutually_exclusive_group()
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
cache_group.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threhold the cache remove large items to free RAM. Default 4GB")
attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
@@ -130,7 +131,8 @@ vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for e
parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS.")
parser.add_argument("--async-offload", action="store_true", help="Use async weight offloading.")
parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=None, metavar="NUM_STREAMS", help="Use async weight offloading. An optional argument controls the amount of offload streams. Default is 2. Enabled by default on Nvidia.")
parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.")
parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")
@@ -145,7 +147,9 @@ class PerformanceFeature(enum.Enum):
CublasOps = "cublas_ops"
AutoTune = "autotune"
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
parser.add_argument("--disable-pinned-memory", action="store_true", help="Disable pinned memory use.")
parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.")
parser.add_argument("--disable-mmap", action="store_true", help="Don't use mmap when loading safetensors.")
@@ -157,7 +161,7 @@ parser.add_argument("--windows-standalone-build", action="store_true", help="Win
parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.")
parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Disable loading all custom nodes.")
parser.add_argument("--whitelist-custom-nodes", type=str, nargs='+', default=[], help="Specify custom node folders to load even when --disable-all-custom-nodes is enabled.")
parser.add_argument("--disable-api-nodes", action="store_true", help="Disable loading all api nodes.")
parser.add_argument("--disable-api-nodes", action="store_true", help="Disable loading all api nodes. Also prevents the frontend from communicating with the internet.")
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")

View File

@@ -310,11 +310,13 @@ class ControlLoraOps:
self.bias = None
def forward(self, input):
weight, bias = comfy.ops.cast_bias_weight(self, input)
weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True)
if self.up is not None:
return torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
x = torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
else:
return torch.nn.functional.linear(input, weight, bias)
x = torch.nn.functional.linear(input, weight, bias)
comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
return x
class Conv2d(torch.nn.Module, comfy.ops.CastWeightBiasOp):
def __init__(
@@ -350,12 +352,13 @@ class ControlLoraOps:
def forward(self, input):
weight, bias = comfy.ops.cast_bias_weight(self, input)
weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True)
if self.up is not None:
return torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
x = torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
else:
return torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
x = torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
return x
class ControlLora(ControlNet):
def __init__(self, control_weights, global_average_pooling=False, model_options={}): #TODO? model_options

View File

@@ -6,6 +6,7 @@ class LatentFormat:
latent_dimensions = 2
latent_rgb_factors = None
latent_rgb_factors_bias = None
latent_rgb_factors_reshape = None
taesd_decoder_name = None
def process_in(self, latent):
@@ -178,6 +179,54 @@ class Flux(SD3):
def process_out(self, latent):
return (latent / self.scale_factor) + self.shift_factor
class Flux2(LatentFormat):
latent_channels = 128
def __init__(self):
self.latent_rgb_factors =[
[0.0058, 0.0113, 0.0073],
[0.0495, 0.0443, 0.0836],
[-0.0099, 0.0096, 0.0644],
[0.2144, 0.3009, 0.3652],
[0.0166, -0.0039, -0.0054],
[0.0157, 0.0103, -0.0160],
[-0.0398, 0.0902, -0.0235],
[-0.0052, 0.0095, 0.0109],
[-0.3527, -0.2712, -0.1666],
[-0.0301, -0.0356, -0.0180],
[-0.0107, 0.0078, 0.0013],
[0.0746, 0.0090, -0.0941],
[0.0156, 0.0169, 0.0070],
[-0.0034, -0.0040, -0.0114],
[0.0032, 0.0181, 0.0080],
[-0.0939, -0.0008, 0.0186],
[0.0018, 0.0043, 0.0104],
[0.0284, 0.0056, -0.0127],
[-0.0024, -0.0022, -0.0030],
[0.1207, -0.0026, 0.0065],
[0.0128, 0.0101, 0.0142],
[0.0137, -0.0072, -0.0007],
[0.0095, 0.0092, -0.0059],
[0.0000, -0.0077, -0.0049],
[-0.0465, -0.0204, -0.0312],
[0.0095, 0.0012, -0.0066],
[0.0290, -0.0034, 0.0025],
[0.0220, 0.0169, -0.0048],
[-0.0332, -0.0457, -0.0468],
[-0.0085, 0.0389, 0.0609],
[-0.0076, 0.0003, -0.0043],
[-0.0111, -0.0460, -0.0614],
]
self.latent_rgb_factors_bias = [-0.0329, -0.0718, -0.0851]
self.latent_rgb_factors_reshape = lambda t: t.reshape(t.shape[0], 32, 2, 2, t.shape[-2], t.shape[-1]).permute(0, 1, 4, 2, 5, 3).reshape(t.shape[0], 32, t.shape[-2] * 2, t.shape[-1] * 2)
def process_in(self, latent):
return latent
def process_out(self, latent):
return latent
class Mochi(LatentFormat):
latent_channels = 12
latent_dimensions = 3
@@ -382,6 +431,7 @@ class HunyuanVideo(LatentFormat):
]
latent_rgb_factors_bias = [ 0.0259, -0.0192, -0.0761]
taesd_decoder_name = "taehv"
class Cosmos1CV8x8x8(LatentFormat):
latent_channels = 16
@@ -445,7 +495,7 @@ class Wan21(LatentFormat):
]).view(1, self.latent_channels, 1, 1, 1)
self.taesd_decoder_name = None #TODO
self.taesd_decoder_name = "lighttaew2_1"
def process_in(self, latent):
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
@@ -516,6 +566,7 @@ class Wan22(Wan21):
def __init__(self):
self.scale_factor = 1.0
self.taesd_decoder_name = "lighttaew2_2"
self.latents_mean = torch.tensor([
-0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557,
-0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825,
@@ -611,6 +662,67 @@ class HunyuanImage21Refiner(LatentFormat):
latent_dimensions = 3
scale_factor = 1.03682
def process_in(self, latent):
out = latent * self.scale_factor
out = torch.cat((out[:, :, :1], out), dim=2)
out = out.permute(0, 2, 1, 3, 4)
b, f_times_2, c, h, w = out.shape
out = out.reshape(b, f_times_2 // 2, 2 * c, h, w)
out = out.permute(0, 2, 1, 3, 4).contiguous()
return out
def process_out(self, latent):
z = latent / self.scale_factor
z = z.permute(0, 2, 1, 3, 4)
b, f, c, h, w = z.shape
z = z.reshape(b, f, 2, c // 2, h, w)
z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w)
z = z.permute(0, 2, 1, 3, 4)
z = z[:, :, 1:]
return z
class HunyuanVideo15(LatentFormat):
latent_rgb_factors = [
[ 0.0568, -0.0521, -0.0131],
[ 0.0014, 0.0735, 0.0326],
[ 0.0186, 0.0531, -0.0138],
[-0.0031, 0.0051, 0.0288],
[ 0.0110, 0.0556, 0.0432],
[-0.0041, -0.0023, -0.0485],
[ 0.0530, 0.0413, 0.0253],
[ 0.0283, 0.0251, 0.0339],
[ 0.0277, -0.0372, -0.0093],
[ 0.0393, 0.0944, 0.1131],
[ 0.0020, 0.0251, 0.0037],
[-0.0017, 0.0012, 0.0234],
[ 0.0468, 0.0436, 0.0203],
[ 0.0354, 0.0439, -0.0233],
[ 0.0090, 0.0123, 0.0346],
[ 0.0382, 0.0029, 0.0217],
[ 0.0261, -0.0300, 0.0030],
[-0.0088, -0.0220, -0.0283],
[-0.0272, -0.0121, -0.0363],
[-0.0664, -0.0622, 0.0144],
[ 0.0414, 0.0479, 0.0529],
[ 0.0355, 0.0612, -0.0247],
[ 0.0147, 0.0264, 0.0174],
[ 0.0438, 0.0038, 0.0542],
[ 0.0431, -0.0573, -0.0033],
[-0.0162, -0.0211, -0.0406],
[-0.0487, -0.0295, -0.0393],
[ 0.0005, -0.0109, 0.0253],
[ 0.0296, 0.0591, 0.0353],
[ 0.0119, 0.0181, -0.0306],
[-0.0085, -0.0362, 0.0229],
[ 0.0005, -0.0106, 0.0242]
]
latent_rgb_factors_bias = [ 0.0456, -0.0202, -0.0644]
latent_channels = 32
latent_dimensions = 3
scale_factor = 1.03682
taesd_decoder_name = "lighttaehy1_5"
class Hunyuan3Dv2(LatentFormat):
latent_channels = 64
latent_dimensions = 1

View File

@@ -1,15 +1,15 @@
import torch
from torch import Tensor, nn
from comfy.ldm.flux.math import attention
from comfy.ldm.flux.layers import (
MLPEmbedder,
RMSNorm,
QKNorm,
SelfAttention,
ModulationOut,
)
# TODO: remove this in a few months
SingleStreamBlock = None
DoubleStreamBlock = None
class ChromaModulationOut(ModulationOut):
@@ -48,124 +48,6 @@ class Approximator(nn.Module):
return x
class DoubleStreamBlock(nn.Module):
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
super().__init__()
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.num_heads = num_heads
self.hidden_size = hidden_size
self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.img_mlp = nn.Sequential(
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
nn.GELU(approximate="tanh"),
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.txt_mlp = nn.Sequential(
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
nn.GELU(approximate="tanh"),
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
self.flipped_img_txt = flipped_img_txt
def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None, transformer_options={}):
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
# prepare image for attention
img_modulated = torch.addcmul(img_mod1.shift, 1 + img_mod1.scale, self.img_norm1(img))
img_qkv = self.img_attn.qkv(img_modulated)
img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
# prepare txt for attention
txt_modulated = torch.addcmul(txt_mod1.shift, 1 + txt_mod1.scale, self.txt_norm1(txt))
txt_qkv = self.txt_attn.qkv(txt_modulated)
txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
# run actual attention
attn = attention(torch.cat((txt_q, img_q), dim=2),
torch.cat((txt_k, img_k), dim=2),
torch.cat((txt_v, img_v), dim=2),
pe=pe, mask=attn_mask, transformer_options=transformer_options)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
# calculate the img bloks
img.addcmul_(img_mod1.gate, self.img_attn.proj(img_attn))
img.addcmul_(img_mod2.gate, self.img_mlp(torch.addcmul(img_mod2.shift, 1 + img_mod2.scale, self.img_norm2(img))))
# calculate the txt bloks
txt.addcmul_(txt_mod1.gate, self.txt_attn.proj(txt_attn))
txt.addcmul_(txt_mod2.gate, self.txt_mlp(torch.addcmul(txt_mod2.shift, 1 + txt_mod2.scale, self.txt_norm2(txt))))
if txt.dtype == torch.float16:
txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
return img, txt
class SingleStreamBlock(nn.Module):
"""
A DiT block with parallel linear layers as described in
https://arxiv.org/abs/2302.05442 and adapted modulation interface.
"""
def __init__(
self,
hidden_size: int,
num_heads: int,
mlp_ratio: float = 4.0,
qk_scale: float = None,
dtype=None,
device=None,
operations=None
):
super().__init__()
self.hidden_dim = hidden_size
self.num_heads = num_heads
head_dim = hidden_size // num_heads
self.scale = qk_scale or head_dim**-0.5
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
# qkv and mlp_in
self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device)
# proj and mlp_out
self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device)
self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
self.hidden_size = hidden_size
self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.mlp_act = nn.GELU(approximate="tanh")
def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None, transformer_options={}) -> Tensor:
mod = vec
x_mod = torch.addcmul(mod.shift, 1 + mod.scale, self.pre_norm(x))
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k = self.norm(q, k, v)
# compute attention
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
# compute activation in mlp stream, cat again and run second linear layer
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
x.addcmul_(mod.gate, output)
if x.dtype == torch.float16:
x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
return x
class LastLayer(nn.Module):
def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None):
super().__init__()

View File

@@ -11,12 +11,12 @@ import comfy.ldm.common_dit
from comfy.ldm.flux.layers import (
EmbedND,
timestep_embedding,
DoubleStreamBlock,
SingleStreamBlock,
)
from .layers import (
DoubleStreamBlock,
LastLayer,
SingleStreamBlock,
Approximator,
ChromaModulationOut,
)
@@ -90,6 +90,7 @@ class Chroma(nn.Module):
self.num_heads,
mlp_ratio=params.mlp_ratio,
qkv_bias=params.qkv_bias,
modulation=False,
dtype=dtype, device=device, operations=operations
)
for _ in range(params.depth)
@@ -98,7 +99,7 @@ class Chroma(nn.Module):
self.single_blocks = nn.ModuleList(
[
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations)
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=False, dtype=dtype, device=device, operations=operations)
for _ in range(params.depth_single_blocks)
]
)
@@ -178,7 +179,10 @@ class Chroma(nn.Module):
pe = self.pe_embedder(ids)
blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.double_blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.double_blocks):
transformer_options["block_index"] = i
if i not in self.skip_mmdit:
double_mod = (
self.get_modulations(mod_vectors, "double_img", idx=i),
@@ -221,7 +225,10 @@ class Chroma(nn.Module):
img = torch.cat((txt, img), 1)
transformer_options["total_blocks"] = len(self.single_blocks)
transformer_options["block_type"] = "single"
for i, block in enumerate(self.single_blocks):
transformer_options["block_index"] = i
if i not in self.skip_dit:
single_mod = self.get_modulations(mod_vectors, "single", idx=i)
if ("single_block", i) in blocks_replace:

View File

@@ -10,12 +10,10 @@ from torch import Tensor, nn
from einops import repeat
import comfy.ldm.common_dit
from comfy.ldm.flux.layers import EmbedND
from comfy.ldm.flux.layers import EmbedND, DoubleStreamBlock, SingleStreamBlock
from comfy.ldm.chroma.model import Chroma, ChromaParams
from comfy.ldm.chroma.layers import (
DoubleStreamBlock,
SingleStreamBlock,
Approximator,
)
from .layers import (
@@ -89,7 +87,6 @@ class ChromaRadiance(Chroma):
dtype=dtype, device=device, operations=operations
)
self.double_blocks = nn.ModuleList(
[
DoubleStreamBlock(
@@ -97,6 +94,7 @@ class ChromaRadiance(Chroma):
self.num_heads,
mlp_ratio=params.mlp_ratio,
qkv_bias=params.qkv_bias,
modulation=False,
dtype=dtype, device=device, operations=operations
)
for _ in range(params.depth)
@@ -109,6 +107,7 @@ class ChromaRadiance(Chroma):
self.hidden_size,
self.num_heads,
mlp_ratio=params.mlp_ratio,
modulation=False,
dtype=dtype, device=device, operations=operations,
)
for _ in range(params.depth_single_blocks)
@@ -189,15 +188,15 @@ class ChromaRadiance(Chroma):
nerf_pixels = nn.functional.unfold(img_orig, kernel_size=patch_size, stride=patch_size)
nerf_pixels = nerf_pixels.transpose(1, 2) # -> [B, NumPatches, C * P * P]
# Reshape for per-patch processing
nerf_hidden = img_out.reshape(B * num_patches, params.hidden_size)
nerf_pixels = nerf_pixels.reshape(B * num_patches, C, patch_size**2).transpose(1, 2)
if params.nerf_tile_size > 0 and num_patches > params.nerf_tile_size:
# Enable tiling if nerf_tile_size isn't 0 and we actually have more patches than
# the tile size.
img_dct = self.forward_tiled_nerf(img_out, nerf_pixels, B, C, num_patches, patch_size, params)
img_dct = self.forward_tiled_nerf(nerf_hidden, nerf_pixels, B, C, num_patches, patch_size, params)
else:
# Reshape for per-patch processing
nerf_hidden = img_out.reshape(B * num_patches, params.hidden_size)
nerf_pixels = nerf_pixels.reshape(B * num_patches, C, patch_size**2).transpose(1, 2)
# Get DCT-encoded pixel embeddings [pixel-dct]
img_dct = self.nerf_image_embedder(nerf_pixels)
@@ -240,17 +239,8 @@ class ChromaRadiance(Chroma):
end = min(i + tile_size, num_patches)
# Slice the current tile from the input tensors
nerf_hidden_tile = nerf_hidden[:, i:end, :]
nerf_pixels_tile = nerf_pixels[:, i:end, :]
# Get the actual number of patches in this tile (can be smaller for the last tile)
num_patches_tile = nerf_hidden_tile.shape[1]
# Reshape the tile for per-patch processing
# [B, NumPatches_tile, D] -> [B * NumPatches_tile, D]
nerf_hidden_tile = nerf_hidden_tile.reshape(batch * num_patches_tile, params.hidden_size)
# [B, NumPatches_tile, C*P*P] -> [B*NumPatches_tile, C, P*P] -> [B*NumPatches_tile, P*P, C]
nerf_pixels_tile = nerf_pixels_tile.reshape(batch * num_patches_tile, channels, patch_size**2).transpose(1, 2)
nerf_hidden_tile = nerf_hidden[i * batch:end * batch]
nerf_pixels_tile = nerf_pixels[i * batch:end * batch]
# get DCT-encoded pixel embeddings [pixel-dct]
img_dct_tile = self.nerf_image_embedder(nerf_pixels_tile)

View File

@@ -48,11 +48,11 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10
return embedding
class MLPEmbedder(nn.Module):
def __init__(self, in_dim: int, hidden_dim: int, dtype=None, device=None, operations=None):
def __init__(self, in_dim: int, hidden_dim: int, bias=True, dtype=None, device=None, operations=None):
super().__init__()
self.in_layer = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
self.in_layer = operations.Linear(in_dim, hidden_dim, bias=bias, dtype=dtype, device=device)
self.silu = nn.SiLU()
self.out_layer = operations.Linear(hidden_dim, hidden_dim, bias=True, dtype=dtype, device=device)
self.out_layer = operations.Linear(hidden_dim, hidden_dim, bias=bias, dtype=dtype, device=device)
def forward(self, x: Tensor) -> Tensor:
return self.out_layer(self.silu(self.in_layer(x)))
@@ -80,14 +80,14 @@ class QKNorm(torch.nn.Module):
class SelfAttention(nn.Module):
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, dtype=None, device=None, operations=None):
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, proj_bias: bool = True, dtype=None, device=None, operations=None):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
self.proj = operations.Linear(dim, dim, bias=proj_bias, dtype=dtype, device=device)
@dataclass
@@ -98,11 +98,11 @@ class ModulationOut:
class Modulation(nn.Module):
def __init__(self, dim: int, double: bool, dtype=None, device=None, operations=None):
def __init__(self, dim: int, double: bool, bias=True, dtype=None, device=None, operations=None):
super().__init__()
self.is_double = double
self.multiplier = 6 if double else 3
self.lin = operations.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device)
self.lin = operations.Linear(dim, self.multiplier * dim, bias=bias, dtype=dtype, device=device)
def forward(self, vec: Tensor) -> tuple:
if vec.ndim == 2:
@@ -129,77 +129,129 @@ def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None):
return tensor
class SiLUActivation(nn.Module):
def __init__(self):
super().__init__()
self.gate_fn = nn.SiLU()
def forward(self, x: Tensor) -> Tensor:
x1, x2 = x.chunk(2, dim=-1)
return self.gate_fn(x1) * x2
class DoubleStreamBlock(nn.Module):
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, dtype=None, device=None, operations=None):
super().__init__()
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.num_heads = num_heads
self.hidden_size = hidden_size
self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
self.modulation = modulation
if self.modulation:
self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias, dtype=dtype, device=device, operations=operations)
self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.img_mlp = nn.Sequential(
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
nn.GELU(approximate="tanh"),
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
if mlp_silu_act:
self.img_mlp = nn.Sequential(
operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device),
SiLUActivation(),
operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device),
)
else:
self.img_mlp = nn.Sequential(
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
nn.GELU(approximate="tanh"),
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
if self.modulation:
self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias, dtype=dtype, device=device, operations=operations)
self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.txt_mlp = nn.Sequential(
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
nn.GELU(approximate="tanh"),
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
if mlp_silu_act:
self.txt_mlp = nn.Sequential(
operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device),
SiLUActivation(),
operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device),
)
else:
self.txt_mlp = nn.Sequential(
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
nn.GELU(approximate="tanh"),
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
self.flipped_img_txt = flipped_img_txt
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}):
img_mod1, img_mod2 = self.img_mod(vec)
txt_mod1, txt_mod2 = self.txt_mod(vec)
if self.modulation:
img_mod1, img_mod2 = self.img_mod(vec)
txt_mod1, txt_mod2 = self.txt_mod(vec)
else:
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
# prepare image for attention
img_modulated = self.img_norm1(img)
img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img)
img_qkv = self.img_attn.qkv(img_modulated)
del img_modulated
img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
del img_qkv
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
# prepare txt for attention
txt_modulated = self.txt_norm1(txt)
txt_modulated = apply_mod(txt_modulated, (1 + txt_mod1.scale), txt_mod1.shift, modulation_dims_txt)
txt_qkv = self.txt_attn.qkv(txt_modulated)
del txt_modulated
txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
del txt_qkv
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
if self.flipped_img_txt:
q = torch.cat((img_q, txt_q), dim=2)
del img_q, txt_q
k = torch.cat((img_k, txt_k), dim=2)
del img_k, txt_k
v = torch.cat((img_v, txt_v), dim=2)
del img_v, txt_v
# run actual attention
attn = attention(torch.cat((img_q, txt_q), dim=2),
torch.cat((img_k, txt_k), dim=2),
torch.cat((img_v, txt_v), dim=2),
attn = attention(q, k, v,
pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
else:
q = torch.cat((txt_q, img_q), dim=2)
del txt_q, img_q
k = torch.cat((txt_k, img_k), dim=2)
del txt_k, img_k
v = torch.cat((txt_v, img_v), dim=2)
del txt_v, img_v
# run actual attention
attn = attention(torch.cat((txt_q, img_q), dim=2),
torch.cat((txt_k, img_k), dim=2),
torch.cat((txt_v, img_v), dim=2),
attn = attention(q, k, v,
pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
# calculate the img bloks
img = img + apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
img = img + apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)
img += apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
del img_attn
img += apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)
# calculate the txt bloks
txt += apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims_txt)
del txt_attn
txt += apply_mod(self.txt_mlp(apply_mod(self.txt_norm2(txt), (1 + txt_mod2.scale), txt_mod2.shift, modulation_dims_txt)), txt_mod2.gate, None, modulation_dims_txt)
if txt.dtype == torch.float16:
@@ -220,6 +272,9 @@ class SingleStreamBlock(nn.Module):
num_heads: int,
mlp_ratio: float = 4.0,
qk_scale: float = None,
modulation=True,
mlp_silu_act=False,
bias=True,
dtype=None,
device=None,
operations=None
@@ -231,30 +286,47 @@ class SingleStreamBlock(nn.Module):
self.scale = qk_scale or head_dim**-0.5
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.mlp_hidden_dim_first = self.mlp_hidden_dim
if mlp_silu_act:
self.mlp_hidden_dim_first = int(hidden_size * mlp_ratio * 2)
self.mlp_act = SiLUActivation()
else:
self.mlp_act = nn.GELU(approximate="tanh")
# qkv and mlp_in
self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device)
self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim_first, bias=bias, dtype=dtype, device=device)
# proj and mlp_out
self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device)
self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, bias=bias, dtype=dtype, device=device)
self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
self.hidden_size = hidden_size
self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.mlp_act = nn.GELU(approximate="tanh")
self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
if modulation:
self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
else:
self.modulation = None
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None, transformer_options={}) -> Tensor:
mod, _ = self.modulation(vec)
qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
if self.modulation:
mod, _ = self.modulation(vec)
else:
mod = vec
qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim_first], dim=-1)
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
del qkv
q, k = self.norm(q, k, v)
# compute attention
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v
# compute activation in mlp stream, cat again and run second linear layer
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
mlp = self.mlp_act(mlp)
output = self.linear2(torch.cat((attn, mlp), 2))
x += apply_mod(output, mod.gate, None, modulation_dims)
if x.dtype == torch.float16:
x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
@@ -262,11 +334,11 @@ class SingleStreamBlock(nn.Module):
class LastLayer(nn.Module):
def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None):
def __init__(self, hidden_size: int, patch_size: int, out_channels: int, bias=True, dtype=None, device=None, operations=None):
super().__init__()
self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device))
self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=bias, dtype=dtype, device=device)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=bias, dtype=dtype, device=device))
def forward(self, x: Tensor, vec: Tensor, modulation_dims=None) -> Tensor:
if vec.ndim == 2:

View File

@@ -7,15 +7,8 @@ import comfy.model_management
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor:
q_shape = q.shape
k_shape = k.shape
if pe is not None:
q = q.to(dtype=pe.dtype).reshape(*q.shape[:-1], -1, 1, 2)
k = k.to(dtype=pe.dtype).reshape(*k.shape[:-1], -1, 1, 2)
q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v)
k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)
q, k = apply_rope(q, k, pe)
heads = q.shape[1]
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options)
return x

View File

@@ -15,6 +15,7 @@ from .layers import (
MLPEmbedder,
SingleStreamBlock,
timestep_embedding,
Modulation
)
@dataclass
@@ -33,6 +34,11 @@ class FluxParams:
patch_size: int
qkv_bias: bool
guidance_embed: bool
global_modulation: bool = False
mlp_silu_act: bool = False
ops_bias: bool = True
default_ref_method: str = "offset"
ref_index_scale: float = 1.0
class Flux(nn.Module):
@@ -58,13 +64,17 @@ class Flux(nn.Module):
self.hidden_size = params.hidden_size
self.num_heads = params.num_heads
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
self.img_in = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
self.img_in = operations.Linear(self.in_channels, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device)
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device, operations=operations)
if params.vec_in_dim is not None:
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
else:
self.vector_in = None
self.guidance_in = (
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity()
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity()
)
self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device)
self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device)
self.double_blocks = nn.ModuleList(
[
@@ -73,6 +83,9 @@ class Flux(nn.Module):
self.num_heads,
mlp_ratio=params.mlp_ratio,
qkv_bias=params.qkv_bias,
modulation=params.global_modulation is False,
mlp_silu_act=params.mlp_silu_act,
proj_bias=params.ops_bias,
dtype=dtype, device=device, operations=operations
)
for _ in range(params.depth)
@@ -81,13 +94,30 @@ class Flux(nn.Module):
self.single_blocks = nn.ModuleList(
[
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations)
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=params.global_modulation is False, mlp_silu_act=params.mlp_silu_act, bias=params.ops_bias, dtype=dtype, device=device, operations=operations)
for _ in range(params.depth_single_blocks)
]
)
if final_layer:
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, bias=params.ops_bias, dtype=dtype, device=device, operations=operations)
if params.global_modulation:
self.double_stream_modulation_img = Modulation(
self.hidden_size,
double=True,
bias=False,
dtype=dtype, device=device, operations=operations
)
self.double_stream_modulation_txt = Modulation(
self.hidden_size,
double=True,
bias=False,
dtype=dtype, device=device, operations=operations
)
self.single_stream_modulation = Modulation(
self.hidden_size, double=False, bias=False, dtype=dtype, device=device, operations=operations
)
def forward_orig(
self,
@@ -103,9 +133,6 @@ class Flux(nn.Module):
attn_mask: Tensor = None,
) -> Tensor:
if y is None:
y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
patches = transformer_options.get("patches", {})
patches_replace = transformer_options.get("patches_replace", {})
if img.ndim != 3 or txt.ndim != 3:
@@ -118,9 +145,17 @@ class Flux(nn.Module):
if guidance is not None:
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
if self.vector_in is not None:
if y is None:
y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
txt = self.txt_in(txt)
vec_orig = vec
if self.params.global_modulation:
vec = (self.double_stream_modulation_img(vec_orig), self.double_stream_modulation_txt(vec_orig))
if "post_input" in patches:
for p in patches["post_input"]:
out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids})
@@ -136,7 +171,10 @@ class Flux(nn.Module):
pe = None
blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.double_blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.double_blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
@@ -177,7 +215,13 @@ class Flux(nn.Module):
img = torch.cat((txt, img), 1)
if self.params.global_modulation:
vec, _ = self.single_stream_modulation(vec_orig)
transformer_options["total_blocks"] = len(self.single_blocks)
transformer_options["block_type"] = "single"
for i, block in enumerate(self.single_blocks):
transformer_options["block_index"] = i
if ("single_block", i) in blocks_replace:
def block_wrap(args):
out = {}
@@ -207,10 +251,10 @@ class Flux(nn.Module):
img = img[:, txt.shape[1] :, ...]
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
img = self.final_layer(img, vec_orig) # (N, T, patch_size ** 2 * out_channels)
return img
def process_img(self, x, index=0, h_offset=0, w_offset=0):
def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}):
bs, c, h, w = x.shape
patch_size = self.patch_size
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
@@ -222,10 +266,22 @@ class Flux(nn.Module):
h_offset = ((h_offset + (patch_size // 2)) // patch_size)
w_offset = ((w_offset + (patch_size // 2)) // patch_size)
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
steps_h = h_len
steps_w = w_len
rope_options = transformer_options.get("rope_options", None)
if rope_options is not None:
h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
index += rope_options.get("shift_t", 0.0)
h_offset += rope_options.get("shift_y", 0.0)
w_offset += rope_options.get("shift_x", 0.0)
img_ids = torch.zeros((steps_h, steps_w, len(self.params.axes_dim)), device=x.device, dtype=torch.float32)
img_ids[:, :, 0] = img_ids[:, :, 1] + index
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=steps_h, device=x.device, dtype=torch.float32).unsqueeze(1)
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=steps_w, device=x.device, dtype=torch.float32).unsqueeze(0)
return img, repeat(img_ids, "h w c -> b (h w) c", b=bs)
def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
@@ -241,16 +297,16 @@ class Flux(nn.Module):
h_len = ((h_orig + (patch_size // 2)) // patch_size)
w_len = ((w_orig + (patch_size // 2)) // patch_size)
img, img_ids = self.process_img(x)
img, img_ids = self.process_img(x, transformer_options=transformer_options)
img_tokens = img.shape[1]
if ref_latents is not None:
h = 0
w = 0
index = 0
ref_latents_method = kwargs.get("ref_latents_method", "offset")
ref_latents_method = kwargs.get("ref_latents_method", self.params.default_ref_method)
for ref in ref_latents:
if ref_latents_method == "index":
index += 1
index += self.params.ref_index_scale
h_offset = 0
w_offset = 0
elif ref_latents_method == "uxo":
@@ -274,7 +330,11 @@ class Flux(nn.Module):
img = torch.cat([img, kontext], dim=1)
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
txt_ids = torch.zeros((bs, context.shape[1], len(self.params.axes_dim)), device=x.device, dtype=torch.float32)
if len(self.params.axes_dim) == 4: # Flux 2
txt_ids[:, :, 3] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32)
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
out = out[:, :img_tokens]
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h_orig,:w_orig]
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=self.patch_size, pw=self.patch_size)[:,:,:h_orig,:w_orig]

View File

@@ -6,7 +6,6 @@ import comfy.ldm.flux.layers
import comfy.ldm.modules.diffusionmodules.mmdit
from comfy.ldm.modules.attention import optimized_attention
from dataclasses import dataclass
from einops import repeat
@@ -42,6 +41,8 @@ class HunyuanVideoParams:
guidance_embed: bool
byt5: bool
meanflow: bool
use_cond_type_embedding: bool
vision_in_dim: int
class SelfAttentionRef(nn.Module):
@@ -157,7 +158,10 @@ class TokenRefiner(nn.Module):
t = self.t_embedder(timestep_embedding(timesteps, 256, time_factor=1.0).to(x.dtype))
# m = mask.float().unsqueeze(-1)
# c = (x.float() * m).sum(dim=1) / m.sum(dim=1) #TODO: the following works when the x.shape is the same length as the tokens but might break otherwise
c = x.sum(dim=1) / x.shape[1]
if x.dtype == torch.float16:
c = x.float().sum(dim=1) / x.shape[1]
else:
c = x.sum(dim=1) / x.shape[1]
c = t + self.c_embedder(c.to(x.dtype))
x = self.input_embedder(x)
@@ -196,11 +200,15 @@ class HunyuanVideo(nn.Module):
def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
super().__init__()
self.dtype = dtype
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
params = HunyuanVideoParams(**kwargs)
self.params = params
self.patch_size = params.patch_size
self.in_channels = params.in_channels
self.out_channels = params.out_channels
self.use_cond_type_embedding = params.use_cond_type_embedding
self.vision_in_dim = params.vision_in_dim
if params.hidden_size % params.num_heads != 0:
raise ValueError(
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
@@ -266,6 +274,18 @@ class HunyuanVideo(nn.Module):
if final_layer:
self.final_layer = LastLayer(self.hidden_size, self.patch_size[-1], self.out_channels, dtype=dtype, device=device, operations=operations)
# HunyuanVideo 1.5 specific modules
if self.vision_in_dim is not None:
from comfy.ldm.wan.model import MLPProj
self.vision_in = MLPProj(in_dim=self.vision_in_dim, out_dim=self.hidden_size, operation_settings=operation_settings)
else:
self.vision_in = None
if self.use_cond_type_embedding:
# 0: text_encoder feature 1: byt5 feature 2: vision_encoder feature
self.cond_type_embedding = nn.Embedding(3, self.hidden_size)
else:
self.cond_type_embedding = None
def forward_orig(
self,
img: Tensor,
@@ -276,6 +296,7 @@ class HunyuanVideo(nn.Module):
timesteps: Tensor,
y: Tensor = None,
txt_byt5=None,
clip_fea=None,
guidance: Tensor = None,
guiding_frame_index=None,
ref_latent=None,
@@ -331,12 +352,31 @@ class HunyuanVideo(nn.Module):
txt = self.txt_in(txt, timesteps, txt_mask, transformer_options=transformer_options)
if self.cond_type_embedding is not None:
self.cond_type_embedding.to(txt.device)
cond_emb = self.cond_type_embedding(torch.zeros_like(txt[:, :, 0], device=txt.device, dtype=torch.long))
txt = txt + cond_emb.to(txt.dtype)
if self.byt5_in is not None and txt_byt5 is not None:
txt_byt5 = self.byt5_in(txt_byt5)
if self.cond_type_embedding is not None:
cond_emb = self.cond_type_embedding(torch.ones_like(txt_byt5[:, :, 0], device=txt_byt5.device, dtype=torch.long))
txt_byt5 = txt_byt5 + cond_emb.to(txt_byt5.dtype)
txt = torch.cat((txt_byt5, txt), dim=1) # byt5 first for HunyuanVideo1.5
else:
txt = torch.cat((txt, txt_byt5), dim=1)
txt_byt5_ids = torch.zeros((txt_ids.shape[0], txt_byt5.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
txt = torch.cat((txt, txt_byt5), dim=1)
txt_ids = torch.cat((txt_ids, txt_byt5_ids), dim=1)
if clip_fea is not None:
txt_vision_states = self.vision_in(clip_fea)
if self.cond_type_embedding is not None:
cond_emb = self.cond_type_embedding(2 * torch.ones_like(txt_vision_states[:, :, 0], dtype=torch.long, device=txt_vision_states.device))
txt_vision_states = txt_vision_states + cond_emb
txt = torch.cat((txt_vision_states.to(txt.dtype), txt), dim=1)
extra_txt_ids = torch.zeros((txt_ids.shape[0], txt_vision_states.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
txt_ids = torch.cat((txt_ids, extra_txt_ids), dim=1)
ids = torch.cat((img_ids, txt_ids), dim=1)
pe = self.pe_embedder(ids)
@@ -349,7 +389,10 @@ class HunyuanVideo(nn.Module):
attn_mask = None
blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.double_blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.double_blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
@@ -371,7 +414,10 @@ class HunyuanVideo(nn.Module):
img = torch.cat((img, txt), 1)
transformer_options["total_blocks"] = len(self.single_blocks)
transformer_options["block_type"] = "single"
for i, block in enumerate(self.single_blocks):
transformer_options["block_index"] = i
if ("single_block", i) in blocks_replace:
def block_wrap(args):
out = {}
@@ -430,14 +476,14 @@ class HunyuanVideo(nn.Module):
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
return repeat(img_ids, "h w c -> b (h w) c", b=bs)
def forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
def forward(self, x, timestep, context, y=None, txt_byt5=None, clip_fea=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timestep, context, y, txt_byt5, guidance, attention_mask, guiding_frame_index, ref_latent, disable_time_r, control, transformer_options, **kwargs)
).execute(x, timestep, context, y, txt_byt5, clip_fea, guidance, attention_mask, guiding_frame_index, ref_latent, disable_time_r, control, transformer_options, **kwargs)
def _forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
def _forward(self, x, timestep, context, y=None, txt_byt5=None, clip_fea=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
bs = x.shape[0]
if len(self.patch_size) == 3:
img_ids = self.img_ids(x)
@@ -445,5 +491,5 @@ class HunyuanVideo(nn.Module):
else:
img_ids = self.img_ids_2d(x)
txt_ids = torch.zeros((bs, context.shape[1], 2), device=x.device, dtype=x.dtype)
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, guidance, guiding_frame_index, ref_latent, disable_time_r=disable_time_r, control=control, transformer_options=transformer_options)
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, clip_fea, guidance, guiding_frame_index, ref_latent, disable_time_r=disable_time_r, control=control, transformer_options=transformer_options)
return out

View File

@@ -0,0 +1,120 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm, ResnetBlock, VideoConv3d
import model_management, model_patcher
class SRResidualCausalBlock3D(nn.Module):
def __init__(self, channels: int):
super().__init__()
self.block = nn.Sequential(
VideoConv3d(channels, channels, kernel_size=3),
nn.SiLU(inplace=True),
VideoConv3d(channels, channels, kernel_size=3),
nn.SiLU(inplace=True),
VideoConv3d(channels, channels, kernel_size=3),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x + self.block(x)
class SRModel3DV2(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
hidden_channels: int = 64,
num_blocks: int = 6,
global_residual: bool = False,
):
super().__init__()
self.in_conv = VideoConv3d(in_channels, hidden_channels, kernel_size=3)
self.blocks = nn.ModuleList([SRResidualCausalBlock3D(hidden_channels) for _ in range(num_blocks)])
self.out_conv = VideoConv3d(hidden_channels, out_channels, kernel_size=3)
self.global_residual = bool(global_residual)
def forward(self, x: torch.Tensor) -> torch.Tensor:
residual = x
y = self.in_conv(x)
for blk in self.blocks:
y = blk(y)
y = self.out_conv(y)
if self.global_residual and (y.shape == residual.shape):
y = y + residual
return y
class Upsampler(nn.Module):
def __init__(
self,
z_channels: int,
out_channels: int,
block_out_channels: tuple[int, ...],
num_res_blocks: int = 2,
):
super().__init__()
self.num_res_blocks = num_res_blocks
self.block_out_channels = block_out_channels
self.z_channels = z_channels
ch = block_out_channels[0]
self.conv_in = VideoConv3d(z_channels, ch, kernel_size=3)
self.up = nn.ModuleList()
for i, tgt in enumerate(block_out_channels):
stage = nn.Module()
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
out_channels=tgt,
temb_channels=0,
conv_shortcut=False,
conv_op=VideoConv3d, norm_op=RMS_norm)
for j in range(num_res_blocks + 1)])
ch = tgt
self.up.append(stage)
self.norm_out = RMS_norm(ch)
self.conv_out = VideoConv3d(ch, out_channels, kernel_size=3)
def forward(self, z):
"""
Args:
z: (B, C, T, H, W)
target_shape: (H, W)
"""
# z to block_in
repeats = self.block_out_channels[0] // (self.z_channels)
x = self.conv_in(z) + z.repeat_interleave(repeats=repeats, dim=1)
# upsampling
for stage in self.up:
for blk in stage.block:
x = blk(x)
out = self.conv_out(F.silu(self.norm_out(x)))
return out
UPSAMPLERS = {
"720p": SRModel3DV2,
"1080p": Upsampler,
}
class HunyuanVideo15SRModel():
def __init__(self, model_type, config):
self.load_device = model_management.vae_device()
offload_device = model_management.vae_offload_device()
self.dtype = model_management.vae_dtype(self.load_device)
self.model_class = UPSAMPLERS.get(model_type)
self.model = self.model_class(**config).eval()
self.patcher = model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
def load_sd(self, sd):
return self.model.load_state_dict(sd, strict=True)
def get_sd(self):
return self.model.state_dict()
def resample_latent(self, latent):
model_management.load_model_gpu(self.patcher)
return self.model(latent.to(self.load_device))

View File

@@ -4,8 +4,40 @@ import torch.nn.functional as F
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d, Normalize
import comfy.ops
import comfy.ldm.models.autoencoder
import comfy.model_management
ops = comfy.ops.disable_weight_init
class NoPadConv3d(nn.Module):
def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding=0, **kwargs):
super().__init__()
self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, stride=stride, dilation=dilation, **kwargs)
def forward(self, x):
return self.conv(x)
def conv_carry_causal_3d(xl, op, conv_carry_in=None, conv_carry_out=None):
x = xl[0]
xl.clear()
if conv_carry_out is not None:
to_push = x[:, :, -2:, :, :].clone()
conv_carry_out.append(to_push)
if isinstance(op, NoPadConv3d):
if conv_carry_in is None:
x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2, 0), mode = 'replicate')
else:
carry_len = conv_carry_in[0].shape[2]
x = torch.cat([conv_carry_in.pop(0), x], dim=2)
x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2 - carry_len, 0), mode = 'replicate')
out = op(x)
return out
class RMS_norm(nn.Module):
def __init__(self, dim):
super().__init__()
@@ -14,7 +46,7 @@ class RMS_norm(nn.Module):
self.gamma = nn.Parameter(torch.empty(shape))
def forward(self, x):
return F.normalize(x, dim=1) * self.scale * self.gamma
return F.normalize(x, dim=1) * self.scale * comfy.model_management.cast_to(self.gamma, dtype=x.dtype, device=x.device)
class DnSmpl(nn.Module):
def __init__(self, ic, oc, tds=True, refiner_vae=True, op=VideoConv3d):
@@ -27,11 +59,12 @@ class DnSmpl(nn.Module):
self.tds = tds
self.gs = fct * ic // oc
def forward(self, x):
def forward(self, x, conv_carry_in=None, conv_carry_out=None):
r1 = 2 if self.tds else 1
h = self.conv(x)
h = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out)
if self.tds and self.refiner_vae and conv_carry_in is None:
if self.tds and self.refiner_vae:
hf = h[:, :, :1, :, :]
b, c, f, ht, wd = hf.shape
hf = hf.reshape(b, c, f, ht // 2, 2, wd // 2, 2)
@@ -39,14 +72,7 @@ class DnSmpl(nn.Module):
hf = hf.reshape(b, 2 * 2 * c, f, ht // 2, wd // 2)
hf = torch.cat([hf, hf], dim=1)
hn = h[:, :, 1:, :, :]
b, c, frms, ht, wd = hn.shape
nf = frms // r1
hn = hn.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2)
hn = hn.permute(0, 3, 5, 7, 1, 2, 4, 6)
hn = hn.reshape(b, r1 * 2 * 2 * c, nf, ht // 2, wd // 2)
h = torch.cat([hf, hn], dim=2)
h = h[:, :, 1:, :, :]
xf = x[:, :, :1, :, :]
b, ci, f, ht, wd = xf.shape
@@ -54,34 +80,32 @@ class DnSmpl(nn.Module):
xf = xf.permute(0, 4, 6, 1, 2, 3, 5)
xf = xf.reshape(b, 2 * 2 * ci, f, ht // 2, wd // 2)
B, C, T, H, W = xf.shape
xf = xf.view(B, h.shape[1], self.gs // 2, T, H, W).mean(dim=2)
xf = xf.view(B, hf.shape[1], self.gs // 2, T, H, W).mean(dim=2)
xn = x[:, :, 1:, :, :]
b, ci, frms, ht, wd = xn.shape
nf = frms // r1
xn = xn.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2)
xn = xn.permute(0, 3, 5, 7, 1, 2, 4, 6)
xn = xn.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2)
B, C, T, H, W = xn.shape
xn = xn.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2)
sc = torch.cat([xf, xn], dim=2)
else:
b, c, frms, ht, wd = h.shape
x = x[:, :, 1:, :, :]
nf = frms // r1
h = h.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2)
h = h.permute(0, 3, 5, 7, 1, 2, 4, 6)
h = h.reshape(b, r1 * 2 * 2 * c, nf, ht // 2, wd // 2)
if h.shape[2] == 0:
return hf + xf
b, ci, frms, ht, wd = x.shape
nf = frms // r1
sc = x.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2)
sc = sc.permute(0, 3, 5, 7, 1, 2, 4, 6)
sc = sc.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2)
B, C, T, H, W = sc.shape
sc = sc.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2)
b, c, frms, ht, wd = h.shape
nf = frms // r1
h = h.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2)
h = h.permute(0, 3, 5, 7, 1, 2, 4, 6)
h = h.reshape(b, r1 * 2 * 2 * c, nf, ht // 2, wd // 2)
return h + sc
b, ci, frms, ht, wd = x.shape
nf = frms // r1
x = x.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2)
x = x.permute(0, 3, 5, 7, 1, 2, 4, 6)
x = x.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2)
B, C, T, H, W = x.shape
x = x.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2)
if self.tds and self.refiner_vae and conv_carry_in is None:
h = torch.cat([hf, h], dim=2)
x = torch.cat([xf, x], dim=2)
return h + x
class UpSmpl(nn.Module):
@@ -94,11 +118,11 @@ class UpSmpl(nn.Module):
self.tus = tus
self.rp = fct * oc // ic
def forward(self, x):
def forward(self, x, conv_carry_in=None, conv_carry_out=None):
r1 = 2 if self.tus else 1
h = self.conv(x)
h = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out)
if self.tus and self.refiner_vae:
if self.tus and self.refiner_vae and conv_carry_in is None:
hf = h[:, :, :1, :, :]
b, c, f, ht, wd = hf.shape
nc = c // (2 * 2)
@@ -107,14 +131,7 @@ class UpSmpl(nn.Module):
hf = hf.reshape(b, nc, f, ht * 2, wd * 2)
hf = hf[:, : hf.shape[1] // 2]
hn = h[:, :, 1:, :, :]
b, c, frms, ht, wd = hn.shape
nc = c // (r1 * 2 * 2)
hn = hn.reshape(b, r1, 2, 2, nc, frms, ht, wd)
hn = hn.permute(0, 4, 5, 1, 6, 2, 7, 3)
hn = hn.reshape(b, nc, frms * r1, ht * 2, wd * 2)
h = torch.cat([hf, hn], dim=2)
h = h[:, :, 1:, :, :]
xf = x[:, :, :1, :, :]
b, ci, f, ht, wd = xf.shape
@@ -125,29 +142,43 @@ class UpSmpl(nn.Module):
xf = xf.permute(0, 3, 4, 5, 1, 6, 2)
xf = xf.reshape(b, nc, f, ht * 2, wd * 2)
xn = x[:, :, 1:, :, :]
xn = xn.repeat_interleave(repeats=self.rp, dim=1)
b, c, frms, ht, wd = xn.shape
nc = c // (r1 * 2 * 2)
xn = xn.reshape(b, r1, 2, 2, nc, frms, ht, wd)
xn = xn.permute(0, 4, 5, 1, 6, 2, 7, 3)
xn = xn.reshape(b, nc, frms * r1, ht * 2, wd * 2)
sc = torch.cat([xf, xn], dim=2)
else:
b, c, frms, ht, wd = h.shape
nc = c // (r1 * 2 * 2)
h = h.reshape(b, r1, 2, 2, nc, frms, ht, wd)
h = h.permute(0, 4, 5, 1, 6, 2, 7, 3)
h = h.reshape(b, nc, frms * r1, ht * 2, wd * 2)
x = x[:, :, 1:, :, :]
sc = x.repeat_interleave(repeats=self.rp, dim=1)
b, c, frms, ht, wd = sc.shape
nc = c // (r1 * 2 * 2)
sc = sc.reshape(b, r1, 2, 2, nc, frms, ht, wd)
sc = sc.permute(0, 4, 5, 1, 6, 2, 7, 3)
sc = sc.reshape(b, nc, frms * r1, ht * 2, wd * 2)
b, c, frms, ht, wd = h.shape
nc = c // (r1 * 2 * 2)
h = h.reshape(b, r1, 2, 2, nc, frms, ht, wd)
h = h.permute(0, 4, 5, 1, 6, 2, 7, 3)
h = h.reshape(b, nc, frms * r1, ht * 2, wd * 2)
return h + sc
x = x.repeat_interleave(repeats=self.rp, dim=1)
b, c, frms, ht, wd = x.shape
nc = c // (r1 * 2 * 2)
x = x.reshape(b, r1, 2, 2, nc, frms, ht, wd)
x = x.permute(0, 4, 5, 1, 6, 2, 7, 3)
x = x.reshape(b, nc, frms * r1, ht * 2, wd * 2)
if self.tus and self.refiner_vae and conv_carry_in is None:
h = torch.cat([hf, h], dim=2)
x = torch.cat([xf, x], dim=2)
return h + x
class HunyuanRefinerResnetBlock(ResnetBlock):
def __init__(self, in_channels, out_channels, conv_op=NoPadConv3d, norm_op=RMS_norm):
super().__init__(in_channels=in_channels, out_channels=out_channels, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
def forward(self, x, conv_carry_in=None, conv_carry_out=None):
h = x
h = [ self.swish(self.norm1(x)) ]
h = conv_carry_causal_3d(h, self.conv1, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
h = [ self.dropout(self.swish(self.norm2(h))) ]
h = conv_carry_causal_3d(h, self.conv2, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
if self.in_channels != self.out_channels:
x = self.nin_shortcut(x)
return x+h
class Encoder(nn.Module):
def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks,
@@ -160,7 +191,7 @@ class Encoder(nn.Module):
self.refiner_vae = refiner_vae
if self.refiner_vae:
conv_op = VideoConv3d
conv_op = NoPadConv3d
norm_op = RMS_norm
else:
conv_op = ops.Conv3d
@@ -175,10 +206,9 @@ class Encoder(nn.Module):
for i, tgt in enumerate(block_out_channels):
stage = nn.Module()
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
out_channels=tgt,
temb_channels=0,
conv_op=conv_op, norm_op=norm_op)
stage.block = nn.ModuleList([HunyuanRefinerResnetBlock(in_channels=ch if j == 0 else tgt,
out_channels=tgt,
conv_op=conv_op, norm_op=norm_op)
for j in range(num_res_blocks)])
ch = tgt
if i < depth:
@@ -188,9 +218,9 @@ class Encoder(nn.Module):
self.down.append(stage)
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
self.mid.block_1 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
self.mid.block_2 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
self.norm_out = norm_op(ch)
self.conv_out = conv_op(ch, z_channels << 1, 3, 1, 1)
@@ -201,31 +231,50 @@ class Encoder(nn.Module):
if not self.refiner_vae and x.shape[2] == 1:
x = x.expand(-1, -1, self.ffactor_temporal, -1, -1)
x = self.conv_in(x)
if self.refiner_vae:
xl = [x[:, :, :1, :, :]]
if x.shape[2] > self.ffactor_temporal:
xl += torch.split(x[:, :, 1: 1 + ((x.shape[2] - 1) // self.ffactor_temporal) * self.ffactor_temporal, :, :], self.ffactor_temporal * 2, dim=2)
x = xl
else:
x = [x]
out = []
for stage in self.down:
for blk in stage.block:
x = blk(x)
if hasattr(stage, 'downsample'):
x = stage.downsample(x)
conv_carry_in = None
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
for i, x1 in enumerate(x):
conv_carry_out = []
if i == len(x) - 1:
conv_carry_out = None
x1 = [ x1 ]
x1 = conv_carry_causal_3d(x1, self.conv_in, conv_carry_in, conv_carry_out)
for stage in self.down:
for blk in stage.block:
x1 = blk(x1, conv_carry_in, conv_carry_out)
if hasattr(stage, 'downsample'):
x1 = stage.downsample(x1, conv_carry_in, conv_carry_out)
out.append(x1)
conv_carry_in = conv_carry_out
if len(out) > 1:
out = torch.cat(out, dim=2)
else:
out = out[0]
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(out)))
del out
b, c, t, h, w = x.shape
grp = c // (self.z_channels << 1)
skip = x.view(b, c // grp, grp, t, h, w).mean(2)
out = self.conv_out(F.silu(self.norm_out(x))) + skip
out = conv_carry_causal_3d([F.silu(self.norm_out(x))], self.conv_out) + skip
if self.refiner_vae:
out = self.regul(out)[0]
out = torch.cat((out[:, :, :1], out), dim=2)
out = out.permute(0, 2, 1, 3, 4)
b, f_times_2, c, h, w = out.shape
out = out.reshape(b, f_times_2 // 2, 2 * c, h, w)
out = out.permute(0, 2, 1, 3, 4).contiguous()
return out
class Decoder(nn.Module):
@@ -239,7 +288,7 @@ class Decoder(nn.Module):
self.refiner_vae = refiner_vae
if self.refiner_vae:
conv_op = VideoConv3d
conv_op = NoPadConv3d
norm_op = RMS_norm
else:
conv_op = ops.Conv3d
@@ -249,9 +298,9 @@ class Decoder(nn.Module):
self.conv_in = conv_op(z_channels, ch, kernel_size=3, stride=1, padding=1)
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
self.mid.block_1 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
self.mid.block_2 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
self.up = nn.ModuleList()
depth = (ffactor_spatial >> 1).bit_length()
@@ -259,10 +308,9 @@ class Decoder(nn.Module):
for i, tgt in enumerate(block_out_channels):
stage = nn.Module()
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
out_channels=tgt,
temb_channels=0,
conv_op=conv_op, norm_op=norm_op)
stage.block = nn.ModuleList([HunyuanRefinerResnetBlock(in_channels=ch if j == 0 else tgt,
out_channels=tgt,
conv_op=conv_op, norm_op=norm_op)
for j in range(num_res_blocks + 1)])
ch = tgt
if i < depth:
@@ -275,27 +323,41 @@ class Decoder(nn.Module):
self.conv_out = conv_op(ch, out_channels, 3, stride=1, padding=1)
def forward(self, z):
if self.refiner_vae:
z = z.permute(0, 2, 1, 3, 4)
b, f, c, h, w = z.shape
z = z.reshape(b, f, 2, c // 2, h, w)
z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w)
z = z.permute(0, 2, 1, 3, 4)
z = z[:, :, 1:]
x = self.conv_in(z) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)
x = conv_carry_causal_3d([z], self.conv_in) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
for stage in self.up:
for blk in stage.block:
x = blk(x)
if hasattr(stage, 'upsample'):
x = stage.upsample(x)
if self.refiner_vae:
x = torch.split(x, 2, dim=2)
else:
x = [ x ]
out = []
out = self.conv_out(F.silu(self.norm_out(x)))
conv_carry_in = None
for i, x1 in enumerate(x):
conv_carry_out = []
if i == len(x) - 1:
conv_carry_out = None
for stage in self.up:
for blk in stage.block:
x1 = blk(x1, conv_carry_in, conv_carry_out)
if hasattr(stage, 'upsample'):
x1 = stage.upsample(x1, conv_carry_in, conv_carry_out)
x1 = [ F.silu(self.norm_out(x1)) ]
x1 = conv_carry_causal_3d(x1, self.conv_out, conv_carry_in, conv_carry_out)
out.append(x1)
conv_carry_in = conv_carry_out
del x
if len(out) > 1:
out = torch.cat(out, dim=2)
else:
out = out[0]
if not self.refiner_vae:
if z.shape[-3] == 1:
out = out[:, :, -1:]
return out

View File

@@ -3,12 +3,11 @@ from torch import nn
import comfy.patcher_extension
import comfy.ldm.modules.attention
import comfy.ldm.common_dit
from einops import rearrange
import math
from typing import Dict, Optional, Tuple
from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
from comfy.ldm.flux.math import apply_rope1
def get_timestep_embedding(
timesteps: torch.Tensor,
@@ -238,20 +237,6 @@ class FeedForward(nn.Module):
return self.net(x)
def apply_rotary_emb(input_tensor, freqs_cis): #TODO: remove duplicate funcs and pick the best/fastest one
cos_freqs = freqs_cis[0]
sin_freqs = freqs_cis[1]
t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2)
t1, t2 = t_dup.unbind(dim=-1)
t_dup = torch.stack((-t2, t1), dim=-1)
input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)")
out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs
return out
class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=None):
super().__init__()
@@ -281,8 +266,8 @@ class CrossAttention(nn.Module):
k = self.k_norm(k)
if pe is not None:
q = apply_rotary_emb(q, pe)
k = apply_rotary_emb(k, pe)
q = apply_rope1(q.unsqueeze(1), pe).squeeze(1)
k = apply_rope1(k.unsqueeze(1), pe).squeeze(1)
if mask is None:
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
@@ -306,12 +291,17 @@ class BasicTransformerBlock(nn.Module):
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe, transformer_options=transformer_options) * gate_msa
attn1_input = comfy.ldm.common_dit.rms_norm(x)
attn1_input = torch.addcmul(attn1_input, attn1_input, scale_msa).add_(shift_msa)
attn1_input = self.attn1(attn1_input, pe=pe, transformer_options=transformer_options)
x.addcmul_(attn1_input, gate_msa)
del attn1_input
x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options)
y = comfy.ldm.common_dit.rms_norm(x) * (1 + scale_mlp) + shift_mlp
x += self.ff(y) * gate_mlp
y = comfy.ldm.common_dit.rms_norm(x)
y = torch.addcmul(y, y, scale_mlp).add_(shift_mlp)
x.addcmul_(self.ff(y), gate_mlp)
return x
@@ -327,41 +317,35 @@ def get_fractional_positions(indices_grid, max_pos):
def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[20, 2048, 2048]):
dtype = torch.float32 #self.dtype
dtype = torch.float32
device = indices_grid.device
# Get fractional positions and compute frequency indices
fractional_positions = get_fractional_positions(indices_grid, max_pos)
indices = theta ** torch.linspace(0, 1, dim // 6, device=device, dtype=dtype) * math.pi / 2
start = 1
end = theta
device = fractional_positions.device
# Compute frequencies and apply cos/sin
freqs = (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)).transpose(-1, -2).flatten(2)
cos_vals = freqs.cos().repeat_interleave(2, dim=-1)
sin_vals = freqs.sin().repeat_interleave(2, dim=-1)
indices = theta ** (
torch.linspace(
math.log(start, theta),
math.log(end, theta),
dim // 6,
device=device,
dtype=dtype,
)
)
indices = indices.to(dtype=dtype)
indices = indices * math.pi / 2
freqs = (
(indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
.transpose(-1, -2)
.flatten(2)
)
cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
# Pad if dim is not divisible by 6
if dim % 6 != 0:
cos_padding = torch.ones_like(cos_freq[:, :, : dim % 6])
sin_padding = torch.zeros_like(cos_freq[:, :, : dim % 6])
cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
return cos_freq.to(out_dtype), sin_freq.to(out_dtype)
padding_size = dim % 6
cos_vals = torch.cat([torch.ones_like(cos_vals[:, :, :padding_size]), cos_vals], dim=-1)
sin_vals = torch.cat([torch.zeros_like(sin_vals[:, :, :padding_size]), sin_vals], dim=-1)
# Reshape and extract one value per pair (since repeat_interleave duplicates each value)
cos_vals = cos_vals.reshape(*cos_vals.shape[:2], -1, 2)[..., 0].to(out_dtype) # [B, N, dim//2]
sin_vals = sin_vals.reshape(*sin_vals.shape[:2], -1, 2)[..., 0].to(out_dtype) # [B, N, dim//2]
# Build rotation matrix [[cos, -sin], [sin, cos]] and add heads dimension
freqs_cis = torch.stack([
torch.stack([cos_vals, -sin_vals], dim=-1),
torch.stack([sin_vals, cos_vals], dim=-1)
], dim=-2).unsqueeze(1) # [B, 1, N, dim//2, 2, 2]
return freqs_cis
class LTXVModel(torch.nn.Module):
@@ -501,7 +485,7 @@ class LTXVModel(torch.nn.Module):
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
x = self.norm_out(x)
# Modulation
x = x * (1 + scale) + shift
x = torch.addcmul(x, x, scale).add_(shift)
x = self.proj_out(x)
x = self.patchifier.unpatchify(

View File

@@ -11,6 +11,7 @@ import comfy.ldm.common_dit
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder
from comfy.ldm.modules.attention import optimized_attention_masked
from comfy.ldm.flux.layers import EmbedND
from comfy.ldm.flux.math import apply_rope
import comfy.patcher_extension
@@ -31,6 +32,7 @@ class JointAttention(nn.Module):
n_heads: int,
n_kv_heads: Optional[int],
qk_norm: bool,
out_bias: bool = False,
operation_settings={},
):
"""
@@ -59,7 +61,7 @@ class JointAttention(nn.Module):
self.out = operation_settings.get("operations").Linear(
n_heads * self.head_dim,
dim,
bias=False,
bias=out_bias,
device=operation_settings.get("device"),
dtype=operation_settings.get("dtype"),
)
@@ -70,35 +72,6 @@ class JointAttention(nn.Module):
else:
self.q_norm = self.k_norm = nn.Identity()
@staticmethod
def apply_rotary_emb(
x_in: torch.Tensor,
freqs_cis: torch.Tensor,
) -> torch.Tensor:
"""
Apply rotary embeddings to input tensors using the given frequency
tensor.
This function applies rotary embeddings to the given query 'xq' and
key 'xk' tensors using the provided frequency tensor 'freqs_cis'. The
input tensors are reshaped as complex numbers, and the frequency tensor
is reshaped for broadcasting compatibility. The resulting tensors
contain rotary embeddings and are returned as real tensors.
Args:
x_in (torch.Tensor): Query or Key tensor to apply rotary embeddings.
freqs_cis (torch.Tensor): Precomputed frequency tensor for complex
exponentials.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor
and key tensor with rotary embeddings.
"""
t_ = x_in.reshape(*x_in.shape[:-1], -1, 1, 2)
t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
return t_out.reshape(*x_in.shape)
def forward(
self,
x: torch.Tensor,
@@ -134,8 +107,7 @@ class JointAttention(nn.Module):
xq = self.q_norm(xq)
xk = self.k_norm(xk)
xq = JointAttention.apply_rotary_emb(xq, freqs_cis=freqs_cis)
xk = JointAttention.apply_rotary_emb(xk, freqs_cis=freqs_cis)
xq, xk = apply_rope(xq, xk, freqs_cis)
n_rep = self.n_local_heads // self.n_local_kv_heads
if n_rep >= 1:
@@ -215,6 +187,8 @@ class JointTransformerBlock(nn.Module):
norm_eps: float,
qk_norm: bool,
modulation=True,
z_image_modulation=False,
attn_out_bias=False,
operation_settings={},
) -> None:
"""
@@ -235,10 +209,10 @@ class JointTransformerBlock(nn.Module):
super().__init__()
self.dim = dim
self.head_dim = dim // n_heads
self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, operation_settings=operation_settings)
self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, out_bias=attn_out_bias, operation_settings=operation_settings)
self.feed_forward = FeedForward(
dim=dim,
hidden_dim=4 * dim,
hidden_dim=dim,
multiple_of=multiple_of,
ffn_dim_multiplier=ffn_dim_multiplier,
operation_settings=operation_settings,
@@ -252,16 +226,27 @@ class JointTransformerBlock(nn.Module):
self.modulation = modulation
if modulation:
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
operation_settings.get("operations").Linear(
min(dim, 1024),
4 * dim,
bias=True,
device=operation_settings.get("device"),
dtype=operation_settings.get("dtype"),
),
)
if z_image_modulation:
self.adaLN_modulation = nn.Sequential(
operation_settings.get("operations").Linear(
min(dim, 256),
4 * dim,
bias=True,
device=operation_settings.get("device"),
dtype=operation_settings.get("dtype"),
),
)
else:
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
operation_settings.get("operations").Linear(
min(dim, 1024),
4 * dim,
bias=True,
device=operation_settings.get("device"),
dtype=operation_settings.get("dtype"),
),
)
def forward(
self,
@@ -323,7 +308,7 @@ class FinalLayer(nn.Module):
The final layer of NextDiT.
"""
def __init__(self, hidden_size, patch_size, out_channels, operation_settings={}):
def __init__(self, hidden_size, patch_size, out_channels, z_image_modulation=False, operation_settings={}):
super().__init__()
self.norm_final = operation_settings.get("operations").LayerNorm(
hidden_size,
@@ -340,10 +325,15 @@ class FinalLayer(nn.Module):
dtype=operation_settings.get("dtype"),
)
if z_image_modulation:
min_mod = 256
else:
min_mod = 1024
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
operation_settings.get("operations").Linear(
min(hidden_size, 1024),
min(hidden_size, min_mod),
hidden_size,
bias=True,
device=operation_settings.get("device"),
@@ -373,12 +363,16 @@ class NextDiT(nn.Module):
n_heads: int = 32,
n_kv_heads: Optional[int] = None,
multiple_of: int = 256,
ffn_dim_multiplier: Optional[float] = None,
ffn_dim_multiplier: float = 4.0,
norm_eps: float = 1e-5,
qk_norm: bool = False,
cap_feat_dim: int = 5120,
axes_dims: List[int] = (16, 56, 56),
axes_lens: List[int] = (1, 512, 512),
rope_theta=10000.0,
z_image_modulation=False,
time_scale=1.0,
pad_tokens_multiple=None,
image_model=None,
device=None,
dtype=None,
@@ -390,6 +384,8 @@ class NextDiT(nn.Module):
self.in_channels = in_channels
self.out_channels = in_channels
self.patch_size = patch_size
self.time_scale = time_scale
self.pad_tokens_multiple = pad_tokens_multiple
self.x_embedder = operation_settings.get("operations").Linear(
in_features=patch_size * patch_size * in_channels,
@@ -411,6 +407,7 @@ class NextDiT(nn.Module):
norm_eps,
qk_norm,
modulation=True,
z_image_modulation=z_image_modulation,
operation_settings=operation_settings,
)
for layer_id in range(n_refiner_layers)
@@ -434,7 +431,7 @@ class NextDiT(nn.Module):
]
)
self.t_embedder = TimestepEmbedder(min(dim, 1024), **operation_settings)
self.t_embedder = TimestepEmbedder(min(dim, 1024), output_size=256 if z_image_modulation else None, **operation_settings)
self.cap_embedder = nn.Sequential(
operation_settings.get("operations").RMSNorm(cap_feat_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
operation_settings.get("operations").Linear(
@@ -457,18 +454,24 @@ class NextDiT(nn.Module):
ffn_dim_multiplier,
norm_eps,
qk_norm,
z_image_modulation=z_image_modulation,
attn_out_bias=False,
operation_settings=operation_settings,
)
for layer_id in range(n_layers)
]
)
self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.final_layer = FinalLayer(dim, patch_size, self.out_channels, operation_settings=operation_settings)
self.final_layer = FinalLayer(dim, patch_size, self.out_channels, z_image_modulation=z_image_modulation, operation_settings=operation_settings)
if self.pad_tokens_multiple is not None:
self.x_pad_token = nn.Parameter(torch.empty((1, dim), device=device, dtype=dtype))
self.cap_pad_token = nn.Parameter(torch.empty((1, dim), device=device, dtype=dtype))
assert (dim // n_heads) == sum(axes_dims)
self.axes_dims = axes_dims
self.axes_lens = axes_lens
self.rope_embedder = EmbedND(dim=dim // n_heads, theta=10000.0, axes_dim=axes_dims)
self.rope_embedder = EmbedND(dim=dim // n_heads, theta=rope_theta, axes_dim=axes_dims)
self.dim = dim
self.n_heads = n_heads
@@ -503,96 +506,54 @@ class NextDiT(nn.Module):
bsz = len(x)
pH = pW = self.patch_size
device = x[0].device
dtype = x[0].dtype
if cap_mask is not None:
l_effective_cap_len = cap_mask.sum(dim=1).tolist()
else:
l_effective_cap_len = [num_tokens] * bsz
if self.pad_tokens_multiple is not None:
pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple
cap_feats = torch.cat((cap_feats, self.cap_pad_token.to(device=cap_feats.device, dtype=cap_feats.dtype, copy=True).unsqueeze(0).repeat(cap_feats.shape[0], pad_extra, 1)), dim=1)
if cap_mask is not None and not torch.is_floating_point(cap_mask):
cap_mask = (cap_mask - 1).to(dtype) * torch.finfo(dtype).max
cap_pos_ids = torch.zeros(bsz, cap_feats.shape[1], 3, dtype=torch.float32, device=device)
cap_pos_ids[:, :, 0] = torch.arange(cap_feats.shape[1], dtype=torch.float32, device=device) + 1.0
img_sizes = [(img.size(1), img.size(2)) for img in x]
l_effective_img_len = [(H // pH) * (W // pW) for (H, W) in img_sizes]
B, C, H, W = x.shape
x = self.x_embedder(x.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2))
max_seq_len = max(
(cap_len+img_len for cap_len, img_len in zip(l_effective_cap_len, l_effective_img_len))
)
max_cap_len = max(l_effective_cap_len)
max_img_len = max(l_effective_img_len)
rope_options = transformer_options.get("rope_options", None)
h_scale = 1.0
w_scale = 1.0
h_start = 0
w_start = 0
if rope_options is not None:
h_scale = rope_options.get("scale_y", 1.0)
w_scale = rope_options.get("scale_x", 1.0)
position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.int32, device=device)
h_start = rope_options.get("shift_y", 0.0)
w_start = rope_options.get("shift_x", 0.0)
for i in range(bsz):
cap_len = l_effective_cap_len[i]
img_len = l_effective_img_len[i]
H, W = img_sizes[i]
H_tokens, W_tokens = H // pH, W // pW
assert H_tokens * W_tokens == img_len
H_tokens, W_tokens = H // pH, W // pW
x_pos_ids = torch.zeros((bsz, x.shape[1], 3), dtype=torch.float32, device=device)
x_pos_ids[:, :, 0] = cap_feats.shape[1] + 1
x_pos_ids[:, :, 1] = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten()
x_pos_ids[:, :, 2] = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten()
position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device)
position_ids[i, cap_len:cap_len+img_len, 0] = cap_len
row_ids = torch.arange(H_tokens, dtype=torch.int32, device=device).view(-1, 1).repeat(1, W_tokens).flatten()
col_ids = torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten()
position_ids[i, cap_len:cap_len+img_len, 1] = row_ids
position_ids[i, cap_len:cap_len+img_len, 2] = col_ids
if self.pad_tokens_multiple is not None:
pad_extra = (-x.shape[1]) % self.pad_tokens_multiple
x = torch.cat((x, self.x_pad_token.to(device=x.device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(x.shape[0], pad_extra, 1)), dim=1)
x_pos_ids = torch.nn.functional.pad(x_pos_ids, (0, 0, 0, pad_extra))
freqs_cis = self.rope_embedder(position_ids).movedim(1, 2).to(dtype)
# build freqs_cis for cap and image individually
cap_freqs_cis_shape = list(freqs_cis.shape)
# cap_freqs_cis_shape[1] = max_cap_len
cap_freqs_cis_shape[1] = cap_feats.shape[1]
cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
img_freqs_cis_shape = list(freqs_cis.shape)
img_freqs_cis_shape[1] = max_img_len
img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
for i in range(bsz):
cap_len = l_effective_cap_len[i]
img_len = l_effective_img_len[i]
cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len]
img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len:cap_len+img_len]
freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2)
# refine context
for layer in self.context_refiner:
cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis, transformer_options=transformer_options)
cap_feats = layer(cap_feats, cap_mask, freqs_cis[:, :cap_pos_ids.shape[1]], transformer_options=transformer_options)
# refine image
flat_x = []
for i in range(bsz):
img = x[i]
C, H, W = img.size()
img = img.view(C, H // pH, pH, W // pW, pW).permute(1, 3, 2, 4, 0).flatten(2).flatten(0, 1)
flat_x.append(img)
x = flat_x
padded_img_embed = torch.zeros(bsz, max_img_len, x[0].shape[-1], device=device, dtype=x[0].dtype)
padded_img_mask = torch.zeros(bsz, max_img_len, dtype=dtype, device=device)
for i in range(bsz):
padded_img_embed[i, :l_effective_img_len[i]] = x[i]
padded_img_mask[i, l_effective_img_len[i]:] = -torch.finfo(dtype).max
padded_img_embed = self.x_embedder(padded_img_embed)
padded_img_mask = padded_img_mask.unsqueeze(1)
padded_img_mask = None
for layer in self.noise_refiner:
padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t, transformer_options=transformer_options)
if cap_mask is not None:
mask = torch.zeros(bsz, max_seq_len, dtype=dtype, device=device)
mask[:, :max_cap_len] = cap_mask[:, :max_cap_len]
else:
mask = None
padded_full_embed = torch.zeros(bsz, max_seq_len, self.dim, device=device, dtype=x[0].dtype)
for i in range(bsz):
cap_len = l_effective_cap_len[i]
img_len = l_effective_img_len[i]
padded_full_embed[i, :cap_len] = cap_feats[i, :cap_len]
padded_full_embed[i, cap_len:cap_len+img_len] = padded_img_embed[i, :img_len]
x = layer(x, padded_img_mask, freqs_cis[:, cap_pos_ids.shape[1]:], t, transformer_options=transformer_options)
padded_full_embed = torch.cat((cap_feats, x), dim=1)
mask = None
img_sizes = [(H, W)] * bsz
l_effective_cap_len = [cap_feats.shape[1]] * bsz
return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis
def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
@@ -615,7 +576,7 @@ class NextDiT(nn.Module):
y: (N,) tensor of text tokens/features
"""
t = self.t_embedder(t, dtype=x.dtype) # (N, D)
t = self.t_embedder(t * self.time_scale, dtype=x.dtype) # (N, D)
adaln_input = t
cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute

View File

View File

@@ -0,0 +1,120 @@
# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
# LICENSE is in incl_licenses directory.
import torch
from torch import nn, sin, pow
from torch.nn import Parameter
import comfy.model_management
class Snake(nn.Module):
'''
Implementation of a sine-based periodic activation function
Shape:
- Input: (B, C, T)
- Output: (B, C, T), same shape as the input
Parameters:
- alpha - trainable parameter
References:
- This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
https://arxiv.org/abs/2006.08195
Examples:
>>> a1 = snake(256)
>>> x = torch.randn(256)
>>> x = a1(x)
'''
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
'''
Initialization.
INPUT:
- in_features: shape of the input
- alpha: trainable parameter
alpha is initialized to 1 by default, higher values = higher-frequency.
alpha will be trained along with the rest of your model.
'''
super(Snake, self).__init__()
self.in_features = in_features
# initialize alpha
self.alpha_logscale = alpha_logscale
if self.alpha_logscale:
self.alpha = Parameter(torch.empty(in_features))
else:
self.alpha = Parameter(torch.empty(in_features))
self.alpha.requires_grad = alpha_trainable
self.no_div_by_zero = 0.000000001
def forward(self, x):
'''
Forward pass of the function.
Applies the function to the input elementwise.
Snake = x + 1/a * sin^2 (xa)
'''
alpha = comfy.model_management.cast_to(self.alpha, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
if self.alpha_logscale:
alpha = torch.exp(alpha)
x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
return x
class SnakeBeta(nn.Module):
'''
A modified Snake function which uses separate parameters for the magnitude of the periodic components
Shape:
- Input: (B, C, T)
- Output: (B, C, T), same shape as the input
Parameters:
- alpha - trainable parameter that controls frequency
- beta - trainable parameter that controls magnitude
References:
- This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
https://arxiv.org/abs/2006.08195
Examples:
>>> a1 = snakebeta(256)
>>> x = torch.randn(256)
>>> x = a1(x)
'''
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
'''
Initialization.
INPUT:
- in_features: shape of the input
- alpha - trainable parameter that controls frequency
- beta - trainable parameter that controls magnitude
alpha is initialized to 1 by default, higher values = higher-frequency.
beta is initialized to 1 by default, higher values = higher-magnitude.
alpha will be trained along with the rest of your model.
'''
super(SnakeBeta, self).__init__()
self.in_features = in_features
# initialize alpha
self.alpha_logscale = alpha_logscale
if self.alpha_logscale:
self.alpha = Parameter(torch.empty(in_features))
self.beta = Parameter(torch.empty(in_features))
else:
self.alpha = Parameter(torch.empty(in_features))
self.beta = Parameter(torch.empty(in_features))
self.alpha.requires_grad = alpha_trainable
self.beta.requires_grad = alpha_trainable
self.no_div_by_zero = 0.000000001
def forward(self, x):
'''
Forward pass of the function.
Applies the function to the input elementwise.
SnakeBeta = x + 1/b * sin^2 (xa)
'''
alpha = comfy.model_management.cast_to(self.alpha, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
beta = comfy.model_management.cast_to(self.beta, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1)
if self.alpha_logscale:
alpha = torch.exp(alpha)
beta = torch.exp(beta)
x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
return x

View File

@@ -0,0 +1,157 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import comfy.model_management
if 'sinc' in dir(torch):
sinc = torch.sinc
else:
# This code is adopted from adefossez's julius.core.sinc under the MIT License
# https://adefossez.github.io/julius/julius/core.html
# LICENSE is in incl_licenses directory.
def sinc(x: torch.Tensor):
"""
Implementation of sinc, i.e. sin(pi * x) / (pi * x)
__Warning__: Different to julius.sinc, the input is multiplied by `pi`!
"""
return torch.where(x == 0,
torch.tensor(1., device=x.device, dtype=x.dtype),
torch.sin(math.pi * x) / math.pi / x)
# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
# https://adefossez.github.io/julius/julius/lowpass.html
# LICENSE is in incl_licenses directory.
def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size]
even = (kernel_size % 2 == 0)
half_size = kernel_size // 2
#For kaiser window
delta_f = 4 * half_width
A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
if A > 50.:
beta = 0.1102 * (A - 8.7)
elif A >= 21.:
beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.)
else:
beta = 0.
window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
# ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
if even:
time = (torch.arange(-half_size, half_size) + 0.5)
else:
time = torch.arange(kernel_size) - half_size
if cutoff == 0:
filter_ = torch.zeros_like(time)
else:
filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
# Normalize filter to have sum = 1, otherwise we will have a small leakage
# of the constant component in the input signal.
filter_ /= filter_.sum()
filter = filter_.view(1, 1, kernel_size)
return filter
class LowPassFilter1d(nn.Module):
def __init__(self,
cutoff=0.5,
half_width=0.6,
stride: int = 1,
padding: bool = True,
padding_mode: str = 'replicate',
kernel_size: int = 12):
# kernel_size should be even number for stylegan3 setup,
# in this implementation, odd number is also possible.
super().__init__()
if cutoff < -0.:
raise ValueError("Minimum cutoff must be larger than zero.")
if cutoff > 0.5:
raise ValueError("A cutoff above 0.5 does not make sense.")
self.kernel_size = kernel_size
self.even = (kernel_size % 2 == 0)
self.pad_left = kernel_size // 2 - int(self.even)
self.pad_right = kernel_size // 2
self.stride = stride
self.padding = padding
self.padding_mode = padding_mode
filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
self.register_buffer("filter", filter)
#input [B, C, T]
def forward(self, x):
_, C, _ = x.shape
if self.padding:
x = F.pad(x, (self.pad_left, self.pad_right),
mode=self.padding_mode)
out = F.conv1d(x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device),
stride=self.stride, groups=C)
return out
class UpSample1d(nn.Module):
def __init__(self, ratio=2, kernel_size=None):
super().__init__()
self.ratio = ratio
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
self.stride = ratio
self.pad = self.kernel_size // ratio - 1
self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio,
half_width=0.6 / ratio,
kernel_size=self.kernel_size)
self.register_buffer("filter", filter)
# x: [B, C, T]
def forward(self, x):
_, C, _ = x.shape
x = F.pad(x, (self.pad, self.pad), mode='replicate')
x = self.ratio * F.conv_transpose1d(
x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device), stride=self.stride, groups=C)
x = x[..., self.pad_left:-self.pad_right]
return x
class DownSample1d(nn.Module):
def __init__(self, ratio=2, kernel_size=None):
super().__init__()
self.ratio = ratio
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio,
half_width=0.6 / ratio,
stride=ratio,
kernel_size=self.kernel_size)
def forward(self, x):
xx = self.lowpass(x)
return xx
class Activation1d(nn.Module):
def __init__(self,
activation,
up_ratio: int = 2,
down_ratio: int = 2,
up_kernel_size: int = 12,
down_kernel_size: int = 12):
super().__init__()
self.up_ratio = up_ratio
self.down_ratio = down_ratio
self.act = activation
self.upsample = UpSample1d(up_ratio, up_kernel_size)
self.downsample = DownSample1d(down_ratio, down_kernel_size)
# x: [B,C,T]
def forward(self, x):
x = self.upsample(x)
x = self.act(x)
x = self.downsample(x)
return x

View File

@@ -0,0 +1,156 @@
from typing import Literal
import torch
import torch.nn as nn
from .distributions import DiagonalGaussianDistribution
from .vae import VAE_16k
from .bigvgan import BigVGANVocoder
import logging
try:
import torchaudio
except:
logging.warning("torchaudio missing, MMAudio VAE model will be broken")
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5, *, norm_fn):
return norm_fn(torch.clamp(x, min=clip_val) * C)
def spectral_normalize_torch(magnitudes, norm_fn):
output = dynamic_range_compression_torch(magnitudes, norm_fn=norm_fn)
return output
class MelConverter(nn.Module):
def __init__(
self,
*,
sampling_rate: float,
n_fft: int,
num_mels: int,
hop_size: int,
win_size: int,
fmin: float,
fmax: float,
norm_fn,
):
super().__init__()
self.sampling_rate = sampling_rate
self.n_fft = n_fft
self.num_mels = num_mels
self.hop_size = hop_size
self.win_size = win_size
self.fmin = fmin
self.fmax = fmax
self.norm_fn = norm_fn
# mel = librosa_mel_fn(sr=self.sampling_rate,
# n_fft=self.n_fft,
# n_mels=self.num_mels,
# fmin=self.fmin,
# fmax=self.fmax)
# mel_basis = torch.from_numpy(mel).float()
mel_basis = torch.empty((num_mels, 1 + n_fft // 2))
hann_window = torch.hann_window(self.win_size)
self.register_buffer('mel_basis', mel_basis)
self.register_buffer('hann_window', hann_window)
@property
def device(self):
return self.mel_basis.device
def forward(self, waveform: torch.Tensor, center: bool = False) -> torch.Tensor:
waveform = waveform.clamp(min=-1., max=1.).to(self.device)
waveform = torch.nn.functional.pad(
waveform.unsqueeze(1),
[int((self.n_fft - self.hop_size) / 2),
int((self.n_fft - self.hop_size) / 2)],
mode='reflect')
waveform = waveform.squeeze(1)
spec = torch.stft(waveform,
self.n_fft,
hop_length=self.hop_size,
win_length=self.win_size,
window=self.hann_window,
center=center,
pad_mode='reflect',
normalized=False,
onesided=True,
return_complex=True)
spec = torch.view_as_real(spec)
spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
spec = torch.matmul(self.mel_basis, spec)
spec = spectral_normalize_torch(spec, self.norm_fn)
return spec
class AudioAutoencoder(nn.Module):
def __init__(
self,
*,
# ckpt_path: str,
mode=Literal['16k', '44k'],
need_vae_encoder: bool = True,
):
super().__init__()
assert mode == "16k", "Only 16k mode is supported currently."
self.mel_converter = MelConverter(sampling_rate=16_000,
n_fft=1024,
num_mels=80,
hop_size=256,
win_size=1024,
fmin=0,
fmax=8_000,
norm_fn=torch.log10)
self.vae = VAE_16k().eval()
bigvgan_config = {
"resblock": "1",
"num_mels": 80,
"upsample_rates": [4, 4, 2, 2, 2, 2],
"upsample_kernel_sizes": [8, 8, 4, 4, 4, 4],
"upsample_initial_channel": 1536,
"resblock_kernel_sizes": [3, 7, 11],
"resblock_dilation_sizes": [
[1, 3, 5],
[1, 3, 5],
[1, 3, 5],
],
"activation": "snakebeta",
"snake_logscale": True,
}
self.vocoder = BigVGANVocoder(
bigvgan_config
).eval()
@torch.inference_mode()
def encode_audio(self, x) -> DiagonalGaussianDistribution:
# x: (B * L)
mel = self.mel_converter(x)
dist = self.vae.encode(mel)
return dist
@torch.no_grad()
def decode(self, z):
mel_decoded = self.vae.decode(z)
audio = self.vocoder(mel_decoded)
audio = torchaudio.functional.resample(audio, 16000, 44100)
return audio
@torch.no_grad()
def encode(self, audio):
audio = audio.mean(dim=1)
audio = torchaudio.functional.resample(audio, 44100, 16000)
dist = self.encode_audio(audio)
return dist.mean

View File

@@ -0,0 +1,219 @@
# Copyright (c) 2022 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
# LICENSE is in incl_licenses directory.
import torch
import torch.nn as nn
from types import SimpleNamespace
from . import activations
from .alias_free_torch import Activation1d
import comfy.ops
ops = comfy.ops.disable_weight_init
def get_padding(kernel_size, dilation=1):
return int((kernel_size * dilation - dilation) / 2)
class AMPBlock1(torch.nn.Module):
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5), activation=None):
super(AMPBlock1, self).__init__()
self.h = h
self.convs1 = nn.ModuleList([
ops.Conv1d(channels,
channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0])),
ops.Conv1d(channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1])),
ops.Conv1d(channels,
channels,
kernel_size,
1,
dilation=dilation[2],
padding=get_padding(kernel_size, dilation[2]))
])
self.convs2 = nn.ModuleList([
ops.Conv1d(channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1)),
ops.Conv1d(channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1)),
ops.Conv1d(channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1))
])
self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers
if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
self.activations = nn.ModuleList([
Activation1d(
activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
for _ in range(self.num_layers)
])
elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing
self.activations = nn.ModuleList([
Activation1d(
activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
for _ in range(self.num_layers)
])
else:
raise NotImplementedError(
"activation incorrectly specified. check the config file and look for 'activation'."
)
def forward(self, x):
acts1, acts2 = self.activations[::2], self.activations[1::2]
for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
xt = a1(x)
xt = c1(xt)
xt = a2(xt)
xt = c2(xt)
x = xt + x
return x
class AMPBlock2(torch.nn.Module):
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3), activation=None):
super(AMPBlock2, self).__init__()
self.h = h
self.convs = nn.ModuleList([
ops.Conv1d(channels,
channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0])),
ops.Conv1d(channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]))
])
self.num_layers = len(self.convs) # total number of conv layers
if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
self.activations = nn.ModuleList([
Activation1d(
activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
for _ in range(self.num_layers)
])
elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing
self.activations = nn.ModuleList([
Activation1d(
activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
for _ in range(self.num_layers)
])
else:
raise NotImplementedError(
"activation incorrectly specified. check the config file and look for 'activation'."
)
def forward(self, x):
for c, a in zip(self.convs, self.activations):
xt = a(x)
xt = c(xt)
x = xt + x
return x
class BigVGANVocoder(torch.nn.Module):
# this is our main BigVGAN model. Applies anti-aliased periodic activation for resblocks.
def __init__(self, h):
super().__init__()
if isinstance(h, dict):
h = SimpleNamespace(**h)
self.h = h
self.num_kernels = len(h.resblock_kernel_sizes)
self.num_upsamples = len(h.upsample_rates)
# pre conv
self.conv_pre = ops.Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
# define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
resblock = AMPBlock1 if h.resblock == '1' else AMPBlock2
# transposed conv-based upsamplers. does not apply anti-aliasing
self.ups = nn.ModuleList()
for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
self.ups.append(
nn.ModuleList([
ops.ConvTranspose1d(h.upsample_initial_channel // (2**i),
h.upsample_initial_channel // (2**(i + 1)),
k,
u,
padding=(k - u) // 2)
]))
# residual blocks using anti-aliased multi-periodicity composition modules (AMP)
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = h.upsample_initial_channel // (2**(i + 1))
for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
self.resblocks.append(resblock(h, ch, k, d, activation=h.activation))
# post conv
if h.activation == "snake": # periodic nonlinearity with snake function and anti-aliasing
activation_post = activations.Snake(ch, alpha_logscale=h.snake_logscale)
self.activation_post = Activation1d(activation=activation_post)
elif h.activation == "snakebeta": # periodic nonlinearity with snakebeta function and anti-aliasing
activation_post = activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale)
self.activation_post = Activation1d(activation=activation_post)
else:
raise NotImplementedError(
"activation incorrectly specified. check the config file and look for 'activation'."
)
self.conv_post = ops.Conv1d(ch, 1, 7, 1, padding=3)
def forward(self, x):
# pre conv
x = self.conv_pre(x)
for i in range(self.num_upsamples):
# upsampling
for i_up in range(len(self.ups[i])):
x = self.ups[i][i_up](x)
# AMP blocks
xs = None
for j in range(self.num_kernels):
if xs is None:
xs = self.resblocks[i * self.num_kernels + j](x)
else:
xs += self.resblocks[i * self.num_kernels + j](x)
x = xs / self.num_kernels
# post conv
x = self.activation_post(x)
x = self.conv_post(x)
x = torch.tanh(x)
return x

View File

@@ -0,0 +1,92 @@
import torch
import numpy as np
class AbstractDistribution:
def sample(self):
raise NotImplementedError()
def mode(self):
raise NotImplementedError()
class DiracDistribution(AbstractDistribution):
def __init__(self, value):
self.value = value
def sample(self):
return self.value
def mode(self):
return self.value
class DiagonalGaussianDistribution(object):
def __init__(self, parameters, deterministic=False):
self.parameters = parameters
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
self.deterministic = deterministic
self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar)
if self.deterministic:
self.var = self.std = torch.zeros_like(self.mean, device=self.parameters.device)
def sample(self):
x = self.mean + self.std * torch.randn(self.mean.shape, device=self.parameters.device)
return x
def kl(self, other=None):
if self.deterministic:
return torch.Tensor([0.])
else:
if other is None:
return 0.5 * torch.sum(torch.pow(self.mean, 2)
+ self.var - 1.0 - self.logvar,
dim=[1, 2, 3])
else:
return 0.5 * torch.sum(
torch.pow(self.mean - other.mean, 2) / other.var
+ self.var / other.var - 1.0 - self.logvar + other.logvar,
dim=[1, 2, 3])
def nll(self, sample, dims=[1,2,3]):
if self.deterministic:
return torch.Tensor([0.])
logtwopi = np.log(2.0 * np.pi)
return 0.5 * torch.sum(
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
dim=dims)
def mode(self):
return self.mean
def normal_kl(mean1, logvar1, mean2, logvar2):
"""
source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
Compute the KL divergence between two gaussians.
Shapes are automatically broadcasted, so batches can be compared to
scalars, among other use cases.
"""
tensor = None
for obj in (mean1, logvar1, mean2, logvar2):
if isinstance(obj, torch.Tensor):
tensor = obj
break
assert tensor is not None, "at least one argument must be a Tensor"
# Force variances to be Tensors. Broadcasting helps convert scalars to
# Tensors, but it does not work for torch.exp().
logvar1, logvar2 = [
x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
for x in (logvar1, logvar2)
]
return 0.5 * (
-1.0
+ logvar2
- logvar1
+ torch.exp(logvar1 - logvar2)
+ ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
)

View File

@@ -0,0 +1,358 @@
import logging
from typing import Optional
import torch
import torch.nn as nn
from .vae_modules import (AttnBlock1D, Downsample1D, ResnetBlock1D,
Upsample1D, nonlinearity)
from .distributions import DiagonalGaussianDistribution
import comfy.ops
ops = comfy.ops.disable_weight_init
log = logging.getLogger()
DATA_MEAN_80D = [
-1.6058, -1.3676, -1.2520, -1.2453, -1.2078, -1.2224, -1.2419, -1.2439, -1.2922, -1.2927,
-1.3170, -1.3543, -1.3401, -1.3836, -1.3907, -1.3912, -1.4313, -1.4152, -1.4527, -1.4728,
-1.4568, -1.5101, -1.5051, -1.5172, -1.5623, -1.5373, -1.5746, -1.5687, -1.6032, -1.6131,
-1.6081, -1.6331, -1.6489, -1.6489, -1.6700, -1.6738, -1.6953, -1.6969, -1.7048, -1.7280,
-1.7361, -1.7495, -1.7658, -1.7814, -1.7889, -1.8064, -1.8221, -1.8377, -1.8417, -1.8643,
-1.8857, -1.8929, -1.9173, -1.9379, -1.9531, -1.9673, -1.9824, -2.0042, -2.0215, -2.0436,
-2.0766, -2.1064, -2.1418, -2.1855, -2.2319, -2.2767, -2.3161, -2.3572, -2.3954, -2.4282,
-2.4659, -2.5072, -2.5552, -2.6074, -2.6584, -2.7107, -2.7634, -2.8266, -2.8981, -2.9673
]
DATA_STD_80D = [
1.0291, 1.0411, 1.0043, 0.9820, 0.9677, 0.9543, 0.9450, 0.9392, 0.9343, 0.9297, 0.9276, 0.9263,
0.9242, 0.9254, 0.9232, 0.9281, 0.9263, 0.9315, 0.9274, 0.9247, 0.9277, 0.9199, 0.9188, 0.9194,
0.9160, 0.9161, 0.9146, 0.9161, 0.9100, 0.9095, 0.9145, 0.9076, 0.9066, 0.9095, 0.9032, 0.9043,
0.9038, 0.9011, 0.9019, 0.9010, 0.8984, 0.8983, 0.8986, 0.8961, 0.8962, 0.8978, 0.8962, 0.8973,
0.8993, 0.8976, 0.8995, 0.9016, 0.8982, 0.8972, 0.8974, 0.8949, 0.8940, 0.8947, 0.8936, 0.8939,
0.8951, 0.8956, 0.9017, 0.9167, 0.9436, 0.9690, 1.0003, 1.0225, 1.0381, 1.0491, 1.0545, 1.0604,
1.0761, 1.0929, 1.1089, 1.1196, 1.1176, 1.1156, 1.1117, 1.1070
]
DATA_MEAN_128D = [
-3.3462, -2.6723, -2.4893, -2.3143, -2.2664, -2.3317, -2.1802, -2.4006, -2.2357, -2.4597,
-2.3717, -2.4690, -2.5142, -2.4919, -2.6610, -2.5047, -2.7483, -2.5926, -2.7462, -2.7033,
-2.7386, -2.8112, -2.7502, -2.9594, -2.7473, -3.0035, -2.8891, -2.9922, -2.9856, -3.0157,
-3.1191, -2.9893, -3.1718, -3.0745, -3.1879, -3.2310, -3.1424, -3.2296, -3.2791, -3.2782,
-3.2756, -3.3134, -3.3509, -3.3750, -3.3951, -3.3698, -3.4505, -3.4509, -3.5089, -3.4647,
-3.5536, -3.5788, -3.5867, -3.6036, -3.6400, -3.6747, -3.7072, -3.7279, -3.7283, -3.7795,
-3.8259, -3.8447, -3.8663, -3.9182, -3.9605, -3.9861, -4.0105, -4.0373, -4.0762, -4.1121,
-4.1488, -4.1874, -4.2461, -4.3170, -4.3639, -4.4452, -4.5282, -4.6297, -4.7019, -4.7960,
-4.8700, -4.9507, -5.0303, -5.0866, -5.1634, -5.2342, -5.3242, -5.4053, -5.4927, -5.5712,
-5.6464, -5.7052, -5.7619, -5.8410, -5.9188, -6.0103, -6.0955, -6.1673, -6.2362, -6.3120,
-6.3926, -6.4797, -6.5565, -6.6511, -6.8130, -6.9961, -7.1275, -7.2457, -7.3576, -7.4663,
-7.6136, -7.7469, -7.8815, -8.0132, -8.1515, -8.3071, -8.4722, -8.7418, -9.3975, -9.6628,
-9.7671, -9.8863, -9.9992, -10.0860, -10.1709, -10.5418, -11.2795, -11.3861
]
DATA_STD_128D = [
2.3804, 2.4368, 2.3772, 2.3145, 2.2803, 2.2510, 2.2316, 2.2083, 2.1996, 2.1835, 2.1769, 2.1659,
2.1631, 2.1618, 2.1540, 2.1606, 2.1571, 2.1567, 2.1612, 2.1579, 2.1679, 2.1683, 2.1634, 2.1557,
2.1668, 2.1518, 2.1415, 2.1449, 2.1406, 2.1350, 2.1313, 2.1415, 2.1281, 2.1352, 2.1219, 2.1182,
2.1327, 2.1195, 2.1137, 2.1080, 2.1179, 2.1036, 2.1087, 2.1036, 2.1015, 2.1068, 2.0975, 2.0991,
2.0902, 2.1015, 2.0857, 2.0920, 2.0893, 2.0897, 2.0910, 2.0881, 2.0925, 2.0873, 2.0960, 2.0900,
2.0957, 2.0958, 2.0978, 2.0936, 2.0886, 2.0905, 2.0845, 2.0855, 2.0796, 2.0840, 2.0813, 2.0817,
2.0838, 2.0840, 2.0917, 2.1061, 2.1431, 2.1976, 2.2482, 2.3055, 2.3700, 2.4088, 2.4372, 2.4609,
2.4731, 2.4847, 2.5072, 2.5451, 2.5772, 2.6147, 2.6529, 2.6596, 2.6645, 2.6726, 2.6803, 2.6812,
2.6899, 2.6916, 2.6931, 2.6998, 2.7062, 2.7262, 2.7222, 2.7158, 2.7041, 2.7485, 2.7491, 2.7451,
2.7485, 2.7233, 2.7297, 2.7233, 2.7145, 2.6958, 2.6788, 2.6439, 2.6007, 2.4786, 2.2469, 2.1877,
2.1392, 2.0717, 2.0107, 1.9676, 1.9140, 1.7102, 0.9101, 0.7164
]
class VAE(nn.Module):
def __init__(
self,
*,
data_dim: int,
embed_dim: int,
hidden_dim: int,
):
super().__init__()
if data_dim == 80:
self.data_mean = nn.Buffer(torch.tensor(DATA_MEAN_80D, dtype=torch.float32))
self.data_std = nn.Buffer(torch.tensor(DATA_STD_80D, dtype=torch.float32))
elif data_dim == 128:
self.data_mean = nn.Buffer(torch.tensor(DATA_MEAN_128D, dtype=torch.float32))
self.data_std = nn.Buffer(torch.tensor(DATA_STD_128D, dtype=torch.float32))
self.data_mean = self.data_mean.view(1, -1, 1)
self.data_std = self.data_std.view(1, -1, 1)
self.encoder = Encoder1D(
dim=hidden_dim,
ch_mult=(1, 2, 4),
num_res_blocks=2,
attn_layers=[3],
down_layers=[0],
in_dim=data_dim,
embed_dim=embed_dim,
)
self.decoder = Decoder1D(
dim=hidden_dim,
ch_mult=(1, 2, 4),
num_res_blocks=2,
attn_layers=[3],
down_layers=[0],
in_dim=data_dim,
out_dim=data_dim,
embed_dim=embed_dim,
)
self.embed_dim = embed_dim
# self.quant_conv = nn.Conv1d(2 * embed_dim, 2 * embed_dim, 1)
# self.post_quant_conv = nn.Conv1d(embed_dim, embed_dim, 1)
self.initialize_weights()
def initialize_weights(self):
pass
def encode(self, x: torch.Tensor, normalize: bool = True) -> DiagonalGaussianDistribution:
if normalize:
x = self.normalize(x)
moments = self.encoder(x)
posterior = DiagonalGaussianDistribution(moments)
return posterior
def decode(self, z: torch.Tensor, unnormalize: bool = True) -> torch.Tensor:
dec = self.decoder(z)
if unnormalize:
dec = self.unnormalize(dec)
return dec
def normalize(self, x: torch.Tensor) -> torch.Tensor:
return (x - comfy.model_management.cast_to(self.data_mean, dtype=x.dtype, device=x.device)) / comfy.model_management.cast_to(self.data_std, dtype=x.dtype, device=x.device)
def unnormalize(self, x: torch.Tensor) -> torch.Tensor:
return x * comfy.model_management.cast_to(self.data_std, dtype=x.dtype, device=x.device) + comfy.model_management.cast_to(self.data_mean, dtype=x.dtype, device=x.device)
def forward(
self,
x: torch.Tensor,
sample_posterior: bool = True,
rng: Optional[torch.Generator] = None,
normalize: bool = True,
unnormalize: bool = True,
) -> tuple[torch.Tensor, DiagonalGaussianDistribution]:
posterior = self.encode(x, normalize=normalize)
if sample_posterior:
z = posterior.sample(rng)
else:
z = posterior.mode()
dec = self.decode(z, unnormalize=unnormalize)
return dec, posterior
def load_weights(self, src_dict) -> None:
self.load_state_dict(src_dict, strict=True)
@property
def device(self) -> torch.device:
return next(self.parameters()).device
def get_last_layer(self):
return self.decoder.conv_out.weight
def remove_weight_norm(self):
return self
class Encoder1D(nn.Module):
def __init__(self,
*,
dim: int,
ch_mult: tuple[int] = (1, 2, 4, 8),
num_res_blocks: int,
attn_layers: list[int] = [],
down_layers: list[int] = [],
resamp_with_conv: bool = True,
in_dim: int,
embed_dim: int,
double_z: bool = True,
kernel_size: int = 3,
clip_act: float = 256.0):
super().__init__()
self.dim = dim
self.num_layers = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.in_channels = in_dim
self.clip_act = clip_act
self.down_layers = down_layers
self.attn_layers = attn_layers
self.conv_in = ops.Conv1d(in_dim, self.dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
in_ch_mult = (1, ) + tuple(ch_mult)
self.in_ch_mult = in_ch_mult
# downsampling
self.down = nn.ModuleList()
for i_level in range(self.num_layers):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = dim * in_ch_mult[i_level]
block_out = dim * ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(
ResnetBlock1D(in_dim=block_in,
out_dim=block_out,
kernel_size=kernel_size,
use_norm=True))
block_in = block_out
if i_level in attn_layers:
attn.append(AttnBlock1D(block_in))
down = nn.Module()
down.block = block
down.attn = attn
if i_level in down_layers:
down.downsample = Downsample1D(block_in, resamp_with_conv)
self.down.append(down)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock1D(in_dim=block_in,
out_dim=block_in,
kernel_size=kernel_size,
use_norm=True)
self.mid.attn_1 = AttnBlock1D(block_in)
self.mid.block_2 = ResnetBlock1D(in_dim=block_in,
out_dim=block_in,
kernel_size=kernel_size,
use_norm=True)
# end
self.conv_out = ops.Conv1d(block_in,
2 * embed_dim if double_z else embed_dim,
kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
self.learnable_gain = nn.Parameter(torch.zeros([]))
def forward(self, x):
# downsampling
h = self.conv_in(x)
for i_level in range(self.num_layers):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](h)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
h = h.clamp(-self.clip_act, self.clip_act)
if i_level in self.down_layers:
h = self.down[i_level].downsample(h)
# middle
h = self.mid.block_1(h)
h = self.mid.attn_1(h)
h = self.mid.block_2(h)
h = h.clamp(-self.clip_act, self.clip_act)
# end
h = nonlinearity(h)
h = self.conv_out(h) * (self.learnable_gain + 1)
return h
class Decoder1D(nn.Module):
def __init__(self,
*,
dim: int,
out_dim: int,
ch_mult: tuple[int] = (1, 2, 4, 8),
num_res_blocks: int,
attn_layers: list[int] = [],
down_layers: list[int] = [],
kernel_size: int = 3,
resamp_with_conv: bool = True,
in_dim: int,
embed_dim: int,
clip_act: float = 256.0):
super().__init__()
self.ch = dim
self.num_layers = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.in_channels = in_dim
self.clip_act = clip_act
self.down_layers = [i + 1 for i in down_layers] # each downlayer add one
# compute in_ch_mult, block_in and curr_res at lowest res
block_in = dim * ch_mult[self.num_layers - 1]
# z to block_in
self.conv_in = ops.Conv1d(embed_dim, block_in, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock1D(in_dim=block_in, out_dim=block_in, use_norm=True)
self.mid.attn_1 = AttnBlock1D(block_in)
self.mid.block_2 = ResnetBlock1D(in_dim=block_in, out_dim=block_in, use_norm=True)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_layers)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = dim * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
block.append(ResnetBlock1D(in_dim=block_in, out_dim=block_out, use_norm=True))
block_in = block_out
if i_level in attn_layers:
attn.append(AttnBlock1D(block_in))
up = nn.Module()
up.block = block
up.attn = attn
if i_level in self.down_layers:
up.upsample = Upsample1D(block_in, resamp_with_conv)
self.up.insert(0, up) # prepend to get consistent order
# end
self.conv_out = ops.Conv1d(block_in, out_dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
self.learnable_gain = nn.Parameter(torch.zeros([]))
def forward(self, z):
# z to block_in
h = self.conv_in(z)
# middle
h = self.mid.block_1(h)
h = self.mid.attn_1(h)
h = self.mid.block_2(h)
h = h.clamp(-self.clip_act, self.clip_act)
# upsampling
for i_level in reversed(range(self.num_layers)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
h = h.clamp(-self.clip_act, self.clip_act)
if i_level in self.down_layers:
h = self.up[i_level].upsample(h)
h = nonlinearity(h)
h = self.conv_out(h) * (self.learnable_gain + 1)
return h
def VAE_16k(**kwargs) -> VAE:
return VAE(data_dim=80, embed_dim=20, hidden_dim=384, **kwargs)
def VAE_44k(**kwargs) -> VAE:
return VAE(data_dim=128, embed_dim=40, hidden_dim=512, **kwargs)
def get_my_vae(name: str, **kwargs) -> VAE:
if name == '16k':
return VAE_16k(**kwargs)
if name == '44k':
return VAE_44k(**kwargs)
raise ValueError(f'Unknown model: {name}')

View File

@@ -0,0 +1,121 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.diffusionmodules.model import vae_attention
import math
import comfy.ops
ops = comfy.ops.disable_weight_init
def nonlinearity(x):
# swish
return torch.nn.functional.silu(x) / 0.596
def mp_sum(a, b, t=0.5):
return a.lerp(b, t) / math.sqrt((1 - t)**2 + t**2)
def normalize(x, dim=None, eps=1e-4):
if dim is None:
dim = list(range(1, x.ndim))
norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32)
norm = torch.add(eps, norm, alpha=math.sqrt(norm.numel() / x.numel()))
return x / norm.to(x.dtype)
class ResnetBlock1D(nn.Module):
def __init__(self, *, in_dim, out_dim=None, conv_shortcut=False, kernel_size=3, use_norm=True):
super().__init__()
self.in_dim = in_dim
out_dim = in_dim if out_dim is None else out_dim
self.out_dim = out_dim
self.use_conv_shortcut = conv_shortcut
self.use_norm = use_norm
self.conv1 = ops.Conv1d(in_dim, out_dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
self.conv2 = ops.Conv1d(out_dim, out_dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
if self.in_dim != self.out_dim:
if self.use_conv_shortcut:
self.conv_shortcut = ops.Conv1d(in_dim, out_dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
else:
self.nin_shortcut = ops.Conv1d(in_dim, out_dim, kernel_size=1, padding=0, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# pixel norm
if self.use_norm:
x = normalize(x, dim=1)
h = x
h = nonlinearity(h)
h = self.conv1(h)
h = nonlinearity(h)
h = self.conv2(h)
if self.in_dim != self.out_dim:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)
return mp_sum(x, h, t=0.3)
class AttnBlock1D(nn.Module):
def __init__(self, in_channels, num_heads=1):
super().__init__()
self.in_channels = in_channels
self.num_heads = num_heads
self.qkv = ops.Conv1d(in_channels, in_channels * 3, kernel_size=1, padding=0, bias=False)
self.proj_out = ops.Conv1d(in_channels, in_channels, kernel_size=1, padding=0, bias=False)
self.optimized_attention = vae_attention()
def forward(self, x):
h = x
y = self.qkv(h)
y = y.reshape(y.shape[0], -1, 3, y.shape[-1])
q, k, v = normalize(y, dim=1).unbind(2)
h = self.optimized_attention(q, k, v)
h = self.proj_out(h)
return mp_sum(x, h, t=0.3)
class Upsample1D(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
self.conv = ops.Conv1d(in_channels, in_channels, kernel_size=3, padding=1, bias=False)
def forward(self, x):
x = F.interpolate(x, scale_factor=2.0, mode='nearest-exact') # support 3D tensor(B,C,T)
if self.with_conv:
x = self.conv(x)
return x
class Downsample1D(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
# no asymmetric padding in torch conv, must do it ourselves
self.conv1 = ops.Conv1d(in_channels, in_channels, kernel_size=1, padding=0, bias=False)
self.conv2 = ops.Conv1d(in_channels, in_channels, kernel_size=1, padding=0, bias=False)
def forward(self, x):
if self.with_conv:
x = self.conv1(x)
x = F.avg_pool1d(x, kernel_size=2, stride=2)
if self.with_conv:
x = self.conv2(x)
return x

View File

@@ -9,6 +9,8 @@ from comfy.ldm.modules.distributions.distributions import DiagonalGaussianDistri
from comfy.ldm.util import get_obj_from_str, instantiate_from_config
from comfy.ldm.modules.ema import LitEma
import comfy.ops
from einops import rearrange
import comfy.model_management
class DiagonalGaussianRegularizer(torch.nn.Module):
def __init__(self, sample: bool = False):
@@ -179,6 +181,21 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
self.post_quant_conv = conv_op(embed_dim, ddconfig["z_channels"], 1)
self.embed_dim = embed_dim
if ddconfig.get("batch_norm_latent", False):
self.bn_eps = 1e-4
self.bn_momentum = 0.1
self.ps = [2, 2]
self.bn = torch.nn.BatchNorm2d(math.prod(self.ps) * ddconfig["z_channels"],
eps=self.bn_eps,
momentum=self.bn_momentum,
affine=False,
track_running_stats=True,
)
self.bn.eval()
else:
self.bn = None
def get_autoencoder_params(self) -> list:
params = super().get_autoencoder_params()
return params
@@ -201,11 +218,36 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
z = torch.cat(z, 0)
z, reg_log = self.regularization(z)
if self.bn is not None:
z = rearrange(z,
"... c (i pi) (j pj) -> ... (c pi pj) i j",
pi=self.ps[0],
pj=self.ps[1],
)
z = torch.nn.functional.batch_norm(z,
comfy.model_management.cast_to(self.bn.running_mean, dtype=z.dtype, device=z.device),
comfy.model_management.cast_to(self.bn.running_var, dtype=z.dtype, device=z.device),
momentum=self.bn_momentum,
eps=self.bn_eps)
if return_reg_log:
return z, reg_log
return z
def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor:
if self.bn is not None:
s = torch.sqrt(comfy.model_management.cast_to(self.bn.running_var.view(1, -1, 1, 1), dtype=z.dtype, device=z.device) + self.bn_eps)
m = comfy.model_management.cast_to(self.bn.running_mean.view(1, -1, 1, 1), dtype=z.dtype, device=z.device)
z = z * s + m
z = rearrange(
z,
"... (c pi pj) i j -> ... c (i pi) (j pj)",
pi=self.ps[0],
pj=self.ps[1],
)
if self.max_batch_size is None:
dec = self.post_quant_conv(z)
dec = self.decoder(dec, **decoder_kwargs)

View File

@@ -211,12 +211,14 @@ class TimestepEmbedder(nn.Module):
Embeds scalar timesteps into vector representations.
"""
def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None):
def __init__(self, hidden_size, frequency_embedding_size=256, output_size=None, dtype=None, device=None, operations=None):
super().__init__()
if output_size is None:
output_size = hidden_size
self.mlp = nn.Sequential(
operations.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device),
nn.SiLU(),
operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
operations.Linear(hidden_size, output_size, bias=True, dtype=dtype, device=device),
)
self.frequency_embedding_size = frequency_embedding_size

View File

@@ -44,7 +44,7 @@ class QwenImageControlNetModel(QwenImageTransformer2DModel):
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
ids = torch.cat((txt_ids, img_ids), dim=1)
image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
del ids, txt_ids, img_ids
hidden_states = self.img_in(hidden_states) + self.controlnet_x_embedder(hint)

View File

@@ -10,6 +10,7 @@ from comfy.ldm.modules.attention import optimized_attention_masked
from comfy.ldm.flux.layers import EmbedND
import comfy.ldm.common_dit
import comfy.patcher_extension
from comfy.ldm.flux.math import apply_rope1
class GELU(nn.Module):
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None):
@@ -134,33 +135,34 @@ class Attention(nn.Module):
image_rotary_emb: Optional[torch.Tensor] = None,
transformer_options={},
) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size = hidden_states.shape[0]
seq_img = hidden_states.shape[1]
seq_txt = encoder_hidden_states.shape[1]
img_query = self.to_q(hidden_states).unflatten(-1, (self.heads, -1))
img_key = self.to_k(hidden_states).unflatten(-1, (self.heads, -1))
img_value = self.to_v(hidden_states).unflatten(-1, (self.heads, -1))
# Project and reshape to BHND format (batch, heads, seq, dim)
img_query = self.to_q(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
img_key = self.to_k(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
img_value = self.to_v(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2)
txt_query = self.add_q_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
txt_key = self.add_k_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
txt_value = self.add_v_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
txt_query = self.add_q_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous()
txt_key = self.add_k_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous()
txt_value = self.add_v_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2)
img_query = self.norm_q(img_query)
img_key = self.norm_k(img_key)
txt_query = self.norm_added_q(txt_query)
txt_key = self.norm_added_k(txt_key)
joint_query = torch.cat([txt_query, img_query], dim=1)
joint_key = torch.cat([txt_key, img_key], dim=1)
joint_value = torch.cat([txt_value, img_value], dim=1)
joint_query = torch.cat([txt_query, img_query], dim=2)
joint_key = torch.cat([txt_key, img_key], dim=2)
joint_value = torch.cat([txt_value, img_value], dim=2)
joint_query = apply_rotary_emb(joint_query, image_rotary_emb)
joint_key = apply_rotary_emb(joint_key, image_rotary_emb)
joint_query = apply_rope1(joint_query, image_rotary_emb)
joint_key = apply_rope1(joint_key, image_rotary_emb)
joint_query = joint_query.flatten(start_dim=2)
joint_key = joint_key.flatten(start_dim=2)
joint_value = joint_value.flatten(start_dim=2)
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attention_mask, transformer_options=transformer_options)
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads,
attention_mask, transformer_options=transformer_options,
skip_reshape=True)
txt_attn_output = joint_hidden_states[:, :seq_txt, :]
img_attn_output = joint_hidden_states[:, seq_txt:, :]
@@ -234,10 +236,10 @@ class QwenImageTransformerBlock(nn.Module):
img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1)
txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1)
img_normed = self.img_norm1(hidden_states)
img_modulated, img_gate1 = self._modulate(img_normed, img_mod1)
txt_normed = self.txt_norm1(encoder_hidden_states)
txt_modulated, txt_gate1 = self._modulate(txt_normed, txt_mod1)
img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1)
del img_mod1
txt_modulated, txt_gate1 = self._modulate(self.txt_norm1(encoder_hidden_states), txt_mod1)
del txt_mod1
img_attn_output, txt_attn_output = self.attn(
hidden_states=img_modulated,
@@ -246,16 +248,20 @@ class QwenImageTransformerBlock(nn.Module):
image_rotary_emb=image_rotary_emb,
transformer_options=transformer_options,
)
del img_modulated
del txt_modulated
hidden_states = hidden_states + img_gate1 * img_attn_output
encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
del img_attn_output
del txt_attn_output
del img_gate1
del txt_gate1
img_normed2 = self.img_norm2(hidden_states)
img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2)
hidden_states = torch.addcmul(hidden_states, img_gate2, self.img_mlp(img_modulated2))
txt_normed2 = self.txt_norm2(encoder_hidden_states)
txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2)
txt_modulated2, txt_gate2 = self._modulate(self.txt_norm2(encoder_hidden_states), txt_mod2)
encoder_hidden_states = torch.addcmul(encoder_hidden_states, txt_gate2, self.txt_mlp(txt_modulated2))
return encoder_hidden_states, hidden_states
@@ -413,7 +419,7 @@ class QwenImageTransformer2DModel(nn.Module):
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
ids = torch.cat((txt_ids, img_ids), dim=1)
image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
del ids, txt_ids, img_ids
hidden_states = self.img_in(hidden_states)
@@ -433,7 +439,10 @@ class QwenImageTransformer2DModel(nn.Module):
patches = transformer_options.get("patches", {})
blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.transformer_blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.transformer_blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}

View File

@@ -232,6 +232,7 @@ class WanAttentionBlock(nn.Module):
# assert e[0].dtype == torch.float32
# self-attention
x = x.contiguous() # otherwise implicit in LayerNorm
y = self.self_attn(
torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
freqs, transformer_options=transformer_options)
@@ -588,7 +589,7 @@ class WanModel(torch.nn.Module):
x = self.unpatchify(x, grid_sizes)
return x
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None):
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}):
patch_size = self.patch_size
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
@@ -601,10 +602,22 @@ class WanModel(torch.nn.Module):
if steps_w is None:
steps_w = w_len
h_start = 0
w_start = 0
rope_options = transformer_options.get("rope_options", None)
if rope_options is not None:
t_len = (t_len - 1.0) * rope_options.get("scale_t", 1.0) + 1.0
h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
t_start += rope_options.get("shift_t", 0.0)
h_start += rope_options.get("shift_y", 0.0)
w_start += rope_options.get("shift_x", 0.0)
img_ids = torch.zeros((steps_t, steps_h, steps_w, 3), device=device, dtype=dtype)
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (t_len - 1), steps=steps_t, device=device, dtype=dtype).reshape(-1, 1, 1)
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1)
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1)
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_start, h_start + (h_len - 1), steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1)
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_start, w_start + (w_len - 1), steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1)
img_ids = img_ids.reshape(1, -1, img_ids.shape[-1])
freqs = self.rope_embedder(img_ids).movedim(1, 2)
@@ -630,7 +643,7 @@ class WanModel(torch.nn.Module):
if self.ref_conv is not None and "reference_latent" in kwargs:
t_len += 1
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype)
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options)
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, **kwargs)[:, :, :t, :h, :w]
def unpatchify(self, x, grid_sizes):

View File

@@ -657,51 +657,51 @@ class WanVAE(nn.Module):
)
def encode(self, x):
self.clear_cache()
conv_idx = [0]
feat_map = [None] * count_conv3d(self.encoder)
x = patchify(x, patch_size=2)
t = x.shape[2]
iter_ = 1 + (t - 1) // 4
for i in range(iter_):
self._enc_conv_idx = [0]
conv_idx = [0]
if i == 0:
out = self.encoder(
x[:, :, :1, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx,
feat_cache=feat_map,
feat_idx=conv_idx,
)
else:
out_ = self.encoder(
x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx,
feat_cache=feat_map,
feat_idx=conv_idx,
)
out = torch.cat([out, out_], 2)
mu, log_var = self.conv1(out).chunk(2, dim=1)
self.clear_cache()
return mu
def decode(self, z):
self.clear_cache()
conv_idx = [0]
feat_map = [None] * count_conv3d(self.decoder)
iter_ = z.shape[2]
x = self.conv2(z)
for i in range(iter_):
self._conv_idx = [0]
conv_idx = [0]
if i == 0:
out = self.decoder(
x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx,
feat_cache=feat_map,
feat_idx=conv_idx,
first_chunk=True,
)
else:
out_ = self.decoder(
x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx,
feat_cache=feat_map,
feat_idx=conv_idx,
)
out = torch.cat([out, out_], 2)
out = unpatchify(out, patch_size=2)
self.clear_cache()
return out
def reparameterize(self, mu, log_var):
@@ -715,12 +715,3 @@ class WanVAE(nn.Module):
return mu
std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
return mu + std * torch.randn_like(std)
def clear_cache(self):
self._conv_num = count_conv3d(self.decoder)
self._conv_idx = [0]
self._feat_map = [None] * self._conv_num
# cache encode
self._enc_conv_num = count_conv3d(self.encoder)
self._enc_conv_idx = [0]
self._enc_feat_map = [None] * self._enc_conv_num

View File

@@ -313,6 +313,15 @@ def model_lora_keys_unet(model, key_map={}):
key_map["transformer.{}".format(key_lora)] = k
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k #SimpleTuner lycoris format
if isinstance(model, comfy.model_base.Lumina2):
diffusers_keys = comfy.utils.z_image_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
for k in diffusers_keys:
if k.endswith(".weight"):
to = diffusers_keys[k]
key_lora = k[:-len(".weight")]
key_map["diffusion_model.{}".format(key_lora)] = to
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to
return key_map

View File

@@ -134,10 +134,11 @@ class BaseModel(torch.nn.Module):
if not unet_config.get("disable_unet_model_creation", False):
if model_config.custom_operations is None:
fp8 = model_config.optimizations.get("fp8", False)
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8)
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8, model_config=model_config)
else:
operations = model_config.custom_operations
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
self.diffusion_model.eval()
if comfy.model_management.force_channels_last():
self.diffusion_model.to(memory_format=torch.channels_last)
logging.debug("using channels last mode for diffusion model")
@@ -196,8 +197,14 @@ class BaseModel(torch.nn.Module):
extra_conds[o] = extra
t = self.process_timestep(t, x=x, **extra_conds)
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
return self.model_sampling.calculate_denoised(sigma, model_output, x)
if "latent_shapes" in extra_conds:
xc = utils.unpack_latents(xc, extra_conds.pop("latent_shapes"))
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds)
if len(model_output) > 1 and not torch.is_tensor(model_output):
model_output, _ = utils.pack_latents(model_output)
return self.model_sampling.calculate_denoised(sigma, model_output.float(), x)
def process_timestep(self, timestep, **kwargs):
return timestep
@@ -326,6 +333,14 @@ class BaseModel(torch.nn.Module):
if self.model_config.scaled_fp8 is not None:
unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8)
# Save mixed precision metadata
if hasattr(self.model_config, 'layer_quant_config') and self.model_config.layer_quant_config:
metadata = {
"format_version": "1.0",
"layers": self.model_config.layer_quant_config
}
unet_state_dict["_quantization_metadata"] = metadata
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
if self.model_type == ModelType.V_PREDICTION:
@@ -669,7 +684,6 @@ class Lotus(BaseModel):
class StableCascade_C(BaseModel):
def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
super().__init__(model_config, model_type, device=device, unet_model=StageC)
self.diffusion_model.eval().requires_grad_(False)
def extra_conds(self, **kwargs):
out = {}
@@ -698,7 +712,6 @@ class StableCascade_C(BaseModel):
class StableCascade_B(BaseModel):
def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
super().__init__(model_config, model_type, device=device, unet_model=StageB)
self.diffusion_model.eval().requires_grad_(False)
def extra_conds(self, **kwargs):
out = {}
@@ -885,12 +898,13 @@ class Flux(BaseModel):
attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None:
shape = kwargs["noise"].shape
mask_ref_size = kwargs["attention_mask_img_shape"]
# the model will pad to the patch size, and then divide
# essentially dividing and rounding up
(h_tok, w_tok) = (math.ceil(shape[2] / self.diffusion_model.patch_size), math.ceil(shape[3] / self.diffusion_model.patch_size))
attention_mask = utils.upscale_dit_mask(attention_mask, mask_ref_size, (h_tok, w_tok))
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
mask_ref_size = kwargs.get("attention_mask_img_shape", None)
if mask_ref_size is not None:
# the model will pad to the patch size, and then divide
# essentially dividing and rounding up
(h_tok, w_tok) = (math.ceil(shape[2] / self.diffusion_model.patch_size), math.ceil(shape[3] / self.diffusion_model.patch_size))
attention_mask = utils.upscale_dit_mask(attention_mask, mask_ref_size, (h_tok, w_tok))
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
guidance = kwargs.get("guidance", 3.5)
if guidance is not None:
@@ -912,9 +926,19 @@ class Flux(BaseModel):
out = {}
ref_latents = kwargs.get("reference_latents", None)
if ref_latents is not None:
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))])
return out
class Flux2(Flux):
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
target_text_len = 512
if cross_attn.shape[1] < target_text_len:
cross_attn = torch.nn.functional.pad(cross_attn, (0, 0, target_text_len - cross_attn.shape[1], 0))
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
return out
class GenmoMochi(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
@@ -1090,9 +1114,13 @@ class Lumina2(BaseModel):
if torch.numel(attention_mask) != attention_mask.sum():
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
out['num_tokens'] = comfy.conds.CONDConstant(max(1, torch.sum(attention_mask).item()))
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
if 'num_tokens' not in out:
out['num_tokens'] = comfy.conds.CONDConstant(cross_attn.shape[1])
return out
class WAN21(BaseModel):
@@ -1523,3 +1551,94 @@ class HunyuanImage21Refiner(HunyuanImage21):
out = super().extra_conds(**kwargs)
out['disable_time_r'] = comfy.conds.CONDConstant(True)
return out
class HunyuanVideo15(HunyuanVideo):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device)
def concat_cond(self, **kwargs):
noise = kwargs.get("noise", None)
extra_channels = self.diffusion_model.img_in.proj.weight.shape[1] - noise.shape[1] - 1 #noise 32 img cond 32 + mask 1
if extra_channels == 0:
return None
image = kwargs.get("concat_latent_image", None)
device = kwargs["device"]
if image is None:
shape_image = list(noise.shape)
shape_image[1] = extra_channels
image = torch.zeros(shape_image, dtype=noise.dtype, layout=noise.layout, device=noise.device)
else:
latent_dim = self.latent_format.latent_channels
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
for i in range(0, image.shape[1], latent_dim):
image[:, i: i + latent_dim] = self.process_latent_in(image[:, i: i + latent_dim])
image = utils.resize_to_batch_size(image, noise.shape[0])
mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
if mask is None:
mask = torch.zeros_like(noise)[:, :1]
else:
mask = 1.0 - mask
mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
if mask.shape[-3] < noise.shape[-3]:
mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0)
mask = utils.resize_to_batch_size(mask, noise.shape[0])
return torch.cat((image, mask), dim=1)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None:
if torch.numel(attention_mask) != attention_mask.sum():
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
conditioning_byt5small = kwargs.get("conditioning_byt5small", None)
if conditioning_byt5small is not None:
out['txt_byt5'] = comfy.conds.CONDRegular(conditioning_byt5small)
guidance = kwargs.get("guidance", 6.0)
if guidance is not None:
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
clip_vision_output = kwargs.get("clip_vision_output", None)
if clip_vision_output is not None:
out['clip_fea'] = comfy.conds.CONDRegular(clip_vision_output.last_hidden_state)
return out
class HunyuanVideo15_SR_Distilled(HunyuanVideo15):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device)
def concat_cond(self, **kwargs):
noise = kwargs.get("noise", None)
image = kwargs.get("concat_latent_image", None)
noise_augmentation = kwargs.get("noise_augmentation", 0.0)
device = kwargs["device"]
if image is None:
image = torch.zeros([noise.shape[0], noise.shape[1] * 2 + 2, noise.shape[-3], noise.shape[-2], noise.shape[-1]], device=comfy.model_management.intermediate_device())
else:
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
#image = self.process_latent_in(image) # scaling wasn't applied in reference code
image = utils.resize_to_batch_size(image, noise.shape[0])
lq_image_slice = slice(noise.shape[1] + 1, 2 * noise.shape[1] + 1)
if noise_augmentation > 0:
generator = torch.Generator(device="cpu")
generator.manual_seed(kwargs.get("seed", 0) - 10)
noise = torch.randn(image[:, lq_image_slice].shape, generator=generator, dtype=image.dtype, device="cpu").to(image.device)
image[:, lq_image_slice] = noise_augmentation * noise + min(1.0 - noise_augmentation, 0.75) * image[:, lq_image_slice]
else:
image[:, lq_image_slice] = 0.75 * image[:, lq_image_slice]
return image
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
out['disable_time_r'] = comfy.conds.CONDConstant(False)
return out

View File

@@ -6,6 +6,20 @@ import math
import logging
import torch
def detect_layer_quantization(metadata):
quant_key = "_quantization_metadata"
if metadata is not None and quant_key in metadata:
quant_metadata = metadata.pop(quant_key)
quant_metadata = json.loads(quant_metadata)
if isinstance(quant_metadata, dict) and "layers" in quant_metadata:
logging.info(f"Found quantization metadata (version {quant_metadata.get('format_version', 'unknown')})")
return quant_metadata["layers"]
else:
raise ValueError("Invalid quantization metadata format")
return None
def count_blocks(state_dict_keys, prefix_string):
count = 0
while True:
@@ -172,30 +186,68 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
guidance_keys = list(filter(lambda a: a.startswith("{}guidance_in.".format(key_prefix)), state_dict_keys))
dit_config["guidance_embed"] = len(guidance_keys) > 0
# HunyuanVideo 1.5
if '{}cond_type_embedding.weight'.format(key_prefix) in state_dict_keys:
dit_config["use_cond_type_embedding"] = True
else:
dit_config["use_cond_type_embedding"] = False
if '{}vision_in.proj.0.weight'.format(key_prefix) in state_dict_keys:
dit_config["vision_in_dim"] = state_dict['{}vision_in.proj.0.weight'.format(key_prefix)].shape[0]
else:
dit_config["vision_in_dim"] = None
return dit_config
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
dit_config = {}
dit_config["image_model"] = "flux"
if '{}double_stream_modulation_img.lin.weight'.format(key_prefix) in state_dict_keys:
dit_config["image_model"] = "flux2"
dit_config["axes_dim"] = [32, 32, 32, 32]
dit_config["num_heads"] = 48
dit_config["mlp_ratio"] = 3.0
dit_config["theta"] = 2000
dit_config["out_channels"] = 128
dit_config["global_modulation"] = True
dit_config["vec_in_dim"] = None
dit_config["mlp_silu_act"] = True
dit_config["qkv_bias"] = False
dit_config["ops_bias"] = False
dit_config["default_ref_method"] = "index"
dit_config["ref_index_scale"] = 10.0
patch_size = 1
else:
dit_config["image_model"] = "flux"
dit_config["axes_dim"] = [16, 56, 56]
dit_config["num_heads"] = 24
dit_config["mlp_ratio"] = 4.0
dit_config["theta"] = 10000
dit_config["out_channels"] = 16
dit_config["qkv_bias"] = True
patch_size = 2
dit_config["in_channels"] = 16
patch_size = 2
dit_config["hidden_size"] = 3072
dit_config["context_in_dim"] = 4096
dit_config["patch_size"] = patch_size
in_key = "{}img_in.weight".format(key_prefix)
if in_key in state_dict_keys:
dit_config["in_channels"] = state_dict[in_key].shape[1] // (patch_size * patch_size)
dit_config["out_channels"] = 16
w = state_dict[in_key]
dit_config["in_channels"] = w.shape[1] // (patch_size * patch_size)
dit_config["hidden_size"] = w.shape[0]
txt_in_key = "{}txt_in.weight".format(key_prefix)
if txt_in_key in state_dict_keys:
w = state_dict[txt_in_key]
dit_config["context_in_dim"] = w.shape[1]
dit_config["hidden_size"] = w.shape[0]
vec_in_key = '{}vector_in.in_layer.weight'.format(key_prefix)
if vec_in_key in state_dict_keys:
dit_config["vec_in_dim"] = state_dict[vec_in_key].shape[1]
dit_config["context_in_dim"] = 4096
dit_config["hidden_size"] = 3072
dit_config["mlp_ratio"] = 4.0
dit_config["num_heads"] = 24
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
dit_config["axes_dim"] = [16, 56, 56]
dit_config["theta"] = 10000
dit_config["qkv_bias"] = True
if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: #Chroma
dit_config["image_model"] = "chroma"
dit_config["in_channels"] = 64
@@ -213,7 +265,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["nerf_mlp_ratio"] = 4
dit_config["nerf_depth"] = 4
dit_config["nerf_max_freqs"] = 8
dit_config["nerf_tile_size"] = 32
dit_config["nerf_tile_size"] = 512
dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
dit_config["nerf_embedder_dtype"] = torch.float32
else:
@@ -364,14 +416,31 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["image_model"] = "lumina2"
dit_config["patch_size"] = 2
dit_config["in_channels"] = 16
dit_config["dim"] = 2304
dit_config["cap_feat_dim"] = state_dict['{}cap_embedder.1.weight'.format(key_prefix)].shape[1]
w = state_dict['{}cap_embedder.1.weight'.format(key_prefix)]
dit_config["dim"] = w.shape[0]
dit_config["cap_feat_dim"] = w.shape[1]
dit_config["n_layers"] = count_blocks(state_dict_keys, '{}layers.'.format(key_prefix) + '{}.')
dit_config["n_heads"] = 24
dit_config["n_kv_heads"] = 8
dit_config["qk_norm"] = True
dit_config["axes_dims"] = [32, 32, 32]
dit_config["axes_lens"] = [300, 512, 512]
if dit_config["dim"] == 2304: # Original Lumina 2
dit_config["n_heads"] = 24
dit_config["n_kv_heads"] = 8
dit_config["axes_dims"] = [32, 32, 32]
dit_config["axes_lens"] = [300, 512, 512]
dit_config["rope_theta"] = 10000.0
dit_config["ffn_dim_multiplier"] = 4.0
elif dit_config["dim"] == 3840: # Z image
dit_config["n_heads"] = 30
dit_config["n_kv_heads"] = 30
dit_config["axes_dims"] = [32, 48, 48]
dit_config["axes_lens"] = [1536, 512, 512]
dit_config["rope_theta"] = 256.0
dit_config["ffn_dim_multiplier"] = (8.0 / 3.0)
dit_config["z_image_modulation"] = True
dit_config["time_scale"] = 1000.0
if '{}cap_pad_token'.format(key_prefix) in state_dict_keys:
dit_config["pad_tokens_multiple"] = 32
return dit_config
if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1
@@ -701,6 +770,12 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
else:
model_config.optimizations["fp8"] = True
# Detect per-layer quantization (mixed precision)
layer_quant_config = detect_layer_quantization(metadata)
if layer_quant_config:
model_config.layer_quant_config = layer_quant_config
logging.info(f"Detected mixed precision quantization: {len(layer_quant_config)} layers quantized")
return model_config
def unet_prefix_from_state_dict(state_dict):

View File

@@ -89,6 +89,7 @@ if args.deterministic:
directml_enabled = False
if args.directml is not None:
logging.warning("WARNING: torch-directml barely works, is very slow, has not been updated in over 1 year and might be removed soon, please don't use it, there are better options.")
import torch_directml
directml_enabled = True
device_index = args.directml
@@ -330,13 +331,21 @@ except:
SUPPORT_FP8_OPS = args.supports_fp8_compute
AMD_RDNA2_AND_OLDER_ARCH = ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]
try:
if is_amd():
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
if not (any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH)):
torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD
logging.info("Set: torch.backends.cudnn.enabled = False for better AMD performance.")
try:
rocm_version = tuple(map(int, str(torch.version.hip).split(".")[:2]))
except:
rocm_version = (6, -1)
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
logging.info("AMD arch: {}".format(arch))
logging.info("ROCm version: {}".format(rocm_version))
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
@@ -344,11 +353,11 @@ try:
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
ENABLE_PYTORCH_ATTENTION = True
# if torch_version_numeric >= (2, 8):
# if any((a in arch) for a in ["gfx1201"]):
# ENABLE_PYTORCH_ATTENTION = True
if rocm_version >= (7, 0):
if any((a in arch) for a in ["gfx1201"]):
ENABLE_PYTORCH_ATTENTION = True
if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4):
if any((a in arch) for a in ["gfx1200", "gfx1201", "gfx942", "gfx950"]): # TODO: more arches
if any((a in arch) for a in ["gfx1200", "gfx1201", "gfx950"]): # TODO: more arches, "gfx942" gives error on pytorch nightly 2.10 1013 rocm7.0
SUPPORT_FP8_OPS = True
except:
@@ -370,6 +379,9 @@ try:
except:
pass
if torch.cuda.is_available() and torch.backends.cudnn.is_available() and PerformanceFeature.AutoTune in args.fast:
torch.backends.cudnn.benchmark = True
try:
if torch_version_numeric >= (2, 5):
torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)
@@ -492,6 +504,7 @@ class LoadedModel:
if use_more_vram == 0:
use_more_vram = 1e32
self.model_use_more_vram(use_more_vram, force_patch_weights=force_patch_weights)
real_model = self.model.model
if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and real_model is not None:
@@ -676,8 +689,11 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
loaded_memory = loaded_model.model_loaded_memory()
current_free_mem = get_free_memory(torch_dev) + loaded_memory
lowvram_model_memory = max(128 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
lowvram_model_memory = max(0.1, lowvram_model_memory - loaded_memory)
lowvram_model_memory = max(0, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
lowvram_model_memory = lowvram_model_memory - loaded_memory
if lowvram_model_memory == 0:
lowvram_model_memory = 0.1
if vram_set_state == VRAMState.NO_VRAM:
lowvram_model_memory = 0.1
@@ -925,11 +941,7 @@ def vae_dtype(device=None, allowed_dtypes=[]):
if d == torch.float16 and should_use_fp16(device):
return d
# NOTE: bfloat16 seems to work on AMD for the VAE but is extremely slow in some cases compared to fp32
# slowness still a problem on pytorch nightly 2.9.0.dev20250720+rocm6.4 tested on RDNA3
# also a problem on RDNA4 except fp32 is also slow there.
# This is due to large bf16 convolutions being extremely slow.
if d == torch.bfloat16 and ((not is_amd()) or amd_min_version(device, min_rdna_version=4)) and should_use_bf16(device):
if d == torch.bfloat16 and should_use_bf16(device):
return d
return torch.float32
@@ -991,12 +1003,6 @@ def device_supports_non_blocking(device):
return False
return True
def device_should_use_non_blocking(device):
if not device_supports_non_blocking(device):
return False
return False
# return True #TODO: figure out why this causes memory issues on Nvidia and possibly others
def force_channels_last():
if args.force_channels_last:
return True
@@ -1006,54 +1012,72 @@ def force_channels_last():
STREAMS = {}
NUM_STREAMS = 1
if args.async_offload:
NUM_STREAMS = 2
NUM_STREAMS = 0
if args.async_offload is not None:
NUM_STREAMS = args.async_offload
else:
# Enable by default on Nvidia
if is_nvidia():
NUM_STREAMS = 2
if args.disable_async_offload:
NUM_STREAMS = 0
if NUM_STREAMS > 0:
logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS))
def current_stream(device):
if device is None:
return None
if is_device_cuda(device):
return torch.cuda.current_stream()
elif is_device_xpu(device):
return torch.xpu.current_stream()
else:
return None
stream_counters = {}
def get_offload_stream(device):
stream_counter = stream_counters.get(device, 0)
if NUM_STREAMS <= 1:
if NUM_STREAMS == 0:
return None
if torch.compiler.is_compiling():
return None
if device in STREAMS:
ss = STREAMS[device]
s = ss[stream_counter]
#Sync the oldest stream in the queue with the current
ss[stream_counter].wait_stream(current_stream(device))
stream_counter = (stream_counter + 1) % len(ss)
if is_device_cuda(device):
ss[stream_counter].wait_stream(torch.cuda.current_stream())
elif is_device_xpu(device):
ss[stream_counter].wait_stream(torch.xpu.current_stream())
stream_counters[device] = stream_counter
return s
return ss[stream_counter]
elif is_device_cuda(device):
ss = []
for k in range(NUM_STREAMS):
ss.append(torch.cuda.Stream(device=device, priority=0))
s1 = torch.cuda.Stream(device=device, priority=0)
s1.as_context = torch.cuda.stream
ss.append(s1)
STREAMS[device] = ss
s = ss[stream_counter]
stream_counter = (stream_counter + 1) % len(ss)
stream_counters[device] = stream_counter
return s
elif is_device_xpu(device):
ss = []
for k in range(NUM_STREAMS):
ss.append(torch.xpu.Stream(device=device, priority=0))
s1 = torch.xpu.Stream(device=device, priority=0)
s1.as_context = torch.xpu.stream
ss.append(s1)
STREAMS[device] = ss
s = ss[stream_counter]
stream_counter = (stream_counter + 1) % len(ss)
stream_counters[device] = stream_counter
return s
return None
def sync_stream(device, stream):
if stream is None:
if stream is None or current_stream(device) is None:
return
if is_device_cuda(device):
torch.cuda.current_stream().wait_stream(stream)
elif is_device_xpu(device):
torch.xpu.current_stream().wait_stream(stream)
current_stream(device).wait_stream(stream)
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None):
if device is None or weight.device == device:
@@ -1061,12 +1085,19 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
if dtype is None or weight.dtype == dtype:
return weight
if stream is not None:
with stream:
wf_context = stream
if hasattr(wf_context, "as_context"):
wf_context = wf_context.as_context(stream)
with wf_context:
return weight.to(dtype=dtype, copy=copy)
return weight.to(dtype=dtype, copy=copy)
if stream is not None:
with stream:
wf_context = stream
if hasattr(wf_context, "as_context"):
wf_context = wf_context.as_context(stream)
with wf_context:
r = torch.empty_like(weight, dtype=dtype, device=device)
r.copy_(weight, non_blocking=non_blocking)
else:
@@ -1078,6 +1109,83 @@ def cast_to_device(tensor, device, dtype, copy=False):
non_blocking = device_supports_non_blocking(device)
return cast_to(tensor, dtype=dtype, device=device, non_blocking=non_blocking, copy=copy)
PINNED_MEMORY = {}
TOTAL_PINNED_MEMORY = 0
MAX_PINNED_MEMORY = -1
if not args.disable_pinned_memory:
if is_nvidia() or is_amd():
if WINDOWS:
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.45 # Windows limit is apparently 50%
else:
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95
logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024)))
PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"])
def pin_memory(tensor):
global TOTAL_PINNED_MEMORY
if MAX_PINNED_MEMORY <= 0:
return False
if type(tensor).__name__ not in PINNING_ALLOWED_TYPES:
return False
if not is_device_cpu(tensor.device):
return False
if tensor.is_pinned():
#NOTE: Cuda does detect when a tensor is already pinned and would
#error below, but there are proven cases where this also queues an error
#on the GPU async. So dont trust the CUDA API and guard here
return False
if not tensor.is_contiguous():
return False
size = tensor.numel() * tensor.element_size()
if (TOTAL_PINNED_MEMORY + size) > MAX_PINNED_MEMORY:
return False
ptr = tensor.data_ptr()
if ptr == 0:
return False
if torch.cuda.cudart().cudaHostRegister(ptr, size, 1) == 0:
PINNED_MEMORY[ptr] = size
TOTAL_PINNED_MEMORY += size
return True
return False
def unpin_memory(tensor):
global TOTAL_PINNED_MEMORY
if MAX_PINNED_MEMORY <= 0:
return False
if not is_device_cpu(tensor.device):
return False
ptr = tensor.data_ptr()
size = tensor.numel() * tensor.element_size()
size_stored = PINNED_MEMORY.get(ptr, None)
if size_stored is None:
logging.warning("Tried to unpin tensor not pinned by ComfyUI")
return False
if size != size_stored:
logging.warning("Size of pinned tensor changed")
return False
if torch.cuda.cudart().cudaHostUnregister(ptr) == 0:
TOTAL_PINNED_MEMORY -= PINNED_MEMORY.pop(ptr)
if len(PINNED_MEMORY) == 0:
TOTAL_PINNED_MEMORY = 0
return True
return False
def sage_attention_enabled():
return args.use_sage_attention
@@ -1330,7 +1438,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
if is_amd():
arch = torch.cuda.get_device_properties(device).gcnArchName
if any((a in arch) for a in ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]): # RDNA2 and older don't support bf16
if any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH): # RDNA2 and older don't support bf16
if manual_cast:
return True
return False

View File

@@ -123,16 +123,39 @@ def move_weight_functions(m, device):
return memory
class LowVramPatch:
def __init__(self, key, patches):
def __init__(self, key, patches, convert_func=None, set_func=None):
self.key = key
self.patches = patches
self.convert_func = convert_func
self.set_func = set_func
def __call__(self, weight):
intermediate_dtype = weight.dtype
if self.convert_func is not None:
weight = self.convert_func(weight, inplace=False)
if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops
intermediate_dtype = torch.float32
return comfy.float.stochastic_rounding(comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype), weight.dtype, seed=string_to_seed(self.key))
out = comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype)
if self.set_func is None:
return comfy.float.stochastic_rounding(out, weight.dtype, seed=string_to_seed(self.key))
else:
return self.set_func(out, seed=string_to_seed(self.key), return_weight=True)
return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
out = comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
if self.set_func is not None:
return self.set_func(out, seed=string_to_seed(self.key), return_weight=True).to(dtype=intermediate_dtype)
else:
return out
#The above patch logic may cast up the weight to fp32, and do math. Go with fp32 x 3
LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 3
def low_vram_patch_estimate_vram(model, key):
weight, set_func, convert_func = get_key_weight(model, key)
if weight is None:
return 0
return weight.numel() * torch.float32.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR
def get_key_weight(model, key):
set_func = None
@@ -217,13 +240,13 @@ class ModelPatcher:
self.object_patches_backup = {}
self.weight_wrapper_patches = {}
self.model_options = {"transformer_options":{}}
self.model_size()
self.load_device = load_device
self.offload_device = offload_device
self.weight_inplace_update = weight_inplace_update
self.force_cast_weights = False
self.patches_uuid = uuid.uuid4()
self.parent = None
self.pinned = set()
self.attachments: dict[str] = {}
self.additional_models: dict[str, list[ModelPatcher]] = {}
@@ -255,12 +278,18 @@ class ModelPatcher:
if not hasattr(self.model, 'current_weight_patches_uuid'):
self.model.current_weight_patches_uuid = None
if not hasattr(self.model, 'model_offload_buffer_memory'):
self.model.model_offload_buffer_memory = 0
def model_size(self):
if self.size > 0:
return self.size
self.size = comfy.model_management.module_size(self.model)
return self.size
def get_ram_usage(self):
return self.model_size()
def loaded_size(self):
return self.model.model_loaded_weight_memory
@@ -268,7 +297,7 @@ class ModelPatcher:
return self.model.lowvram_patch_counter
def clone(self):
n = self.__class__(self.model, self.load_device, self.offload_device, self.size, weight_inplace_update=self.weight_inplace_update)
n = self.__class__(self.model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update)
n.patches = {}
for k in self.patches:
n.patches[k] = self.patches[k][:]
@@ -280,6 +309,7 @@ class ModelPatcher:
n.backup = self.backup
n.object_patches_backup = self.object_patches_backup
n.parent = self
n.pinned = self.pinned
n.force_cast_weights = self.force_cast_weights
@@ -436,6 +466,19 @@ class ModelPatcher:
def set_model_post_input_patch(self, patch):
self.set_model_patch(patch, "post_input")
def set_model_rope_options(self, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t, **kwargs):
rope_options = self.model_options["transformer_options"].get("rope_options", {})
rope_options["scale_x"] = scale_x
rope_options["scale_y"] = scale_y
rope_options["scale_t"] = scale_t
rope_options["shift_x"] = shift_x
rope_options["shift_y"] = shift_y
rope_options["shift_t"] = shift_t
self.model_options["transformer_options"]["rope_options"] = rope_options
def add_object_patch(self, name, obj):
self.object_patches[name] = obj
@@ -604,6 +647,21 @@ class ModelPatcher:
else:
set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
def pin_weight_to_device(self, key):
weight, set_func, convert_func = get_key_weight(self.model, key)
if comfy.model_management.pin_memory(weight):
self.pinned.add(key)
def unpin_weight(self, key):
if key in self.pinned:
weight, set_func, convert_func = get_key_weight(self.model, key)
comfy.model_management.unpin_memory(weight)
self.pinned.remove(key)
def unpin_all_weights(self):
for key in list(self.pinned):
self.unpin_weight(key)
def _load_list(self):
loading = []
for n, m in self.model.named_modules():
@@ -616,7 +674,16 @@ class ModelPatcher:
skip = True # skip random weights in non leaf modules
break
if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
loading.append((comfy.model_management.module_size(m), n, m, params))
module_mem = comfy.model_management.module_size(m)
module_offload_mem = module_mem
if hasattr(m, "comfy_cast_weights"):
weight_key = "{}.weight".format(n)
bias_key = "{}.bias".format(n)
if weight_key in self.patches:
module_offload_mem += low_vram_patch_estimate_vram(self.model, weight_key)
if bias_key in self.patches:
module_offload_mem += low_vram_patch_estimate_vram(self.model, bias_key)
loading.append((module_offload_mem, module_mem, n, m, params))
return loading
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
@@ -625,25 +692,30 @@ class ModelPatcher:
mem_counter = 0
patch_counter = 0
lowvram_counter = 0
lowvram_mem_counter = 0
loading = self._load_list()
load_completely = []
offloaded = []
offload_buffer = 0
loading.sort(reverse=True)
for x in loading:
n = x[1]
m = x[2]
params = x[3]
module_mem = x[0]
module_offload_mem, module_mem, n, m, params = x
lowvram_weight = False
potential_offload = max(offload_buffer, module_offload_mem * (comfy.model_management.NUM_STREAMS + 1))
lowvram_fits = mem_counter + module_mem + potential_offload < lowvram_model_memory
weight_key = "{}.weight".format(n)
bias_key = "{}.bias".format(n)
if not full_load and hasattr(m, "comfy_cast_weights"):
if mem_counter + module_mem >= lowvram_model_memory:
if not lowvram_fits:
offload_buffer = potential_offload
lowvram_weight = True
lowvram_counter += 1
lowvram_mem_counter += module_mem
if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed
continue
@@ -657,23 +729,28 @@ class ModelPatcher:
if force_patch_weights:
self.patch_weight_to_device(weight_key)
else:
m.weight_function = [LowVramPatch(weight_key, self.patches)]
_, set_func, convert_func = get_key_weight(self.model, weight_key)
m.weight_function = [LowVramPatch(weight_key, self.patches, convert_func, set_func)]
patch_counter += 1
if bias_key in self.patches:
if force_patch_weights:
self.patch_weight_to_device(bias_key)
else:
m.bias_function = [LowVramPatch(bias_key, self.patches)]
_, set_func, convert_func = get_key_weight(self.model, bias_key)
m.bias_function = [LowVramPatch(bias_key, self.patches, convert_func, set_func)]
patch_counter += 1
cast_weight = True
offloaded.append((module_mem, n, m, params))
else:
if hasattr(m, "comfy_cast_weights"):
wipe_lowvram_weight(m)
if full_load or mem_counter + module_mem < lowvram_model_memory:
if full_load or lowvram_fits:
mem_counter += module_mem
load_completely.append((module_mem, n, m, params))
else:
offload_buffer = potential_offload
if cast_weight and hasattr(m, "comfy_cast_weights"):
m.prev_comfy_cast_weights = m.comfy_cast_weights
@@ -697,7 +774,9 @@ class ModelPatcher:
continue
for param in params:
self.patch_weight_to_device("{}.{}".format(n, param), device_to=device_to)
key = "{}.{}".format(n, param)
self.unpin_weight(key)
self.patch_weight_to_device(key, device_to=device_to)
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
m.comfy_patched_weights = True
@@ -705,11 +784,17 @@ class ModelPatcher:
for x in load_completely:
x[2].to(device_to)
for x in offloaded:
n = x[1]
params = x[3]
for param in params:
self.pin_weight_to_device("{}.{}".format(n, param))
if lowvram_counter > 0:
logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), offload_buffer / (1024 * 1024), patch_counter))
self.model.model_lowvram = True
else:
logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
logging.info("loaded completely; {:.2f} MB usable, {:.2f} MB loaded, full load: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
self.model.model_lowvram = False
if full_load:
self.model.to(device_to)
@@ -718,6 +803,7 @@ class ModelPatcher:
self.model.lowvram_patch_counter += patch_counter
self.model.device = device_to
self.model.model_loaded_weight_memory = mem_counter
self.model.model_offload_buffer_memory = offload_buffer
self.model.current_weight_patches_uuid = self.patches_uuid
for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD):
@@ -746,6 +832,7 @@ class ModelPatcher:
self.eject_model()
if unpatch_weights:
self.unpatch_hooks()
self.unpin_all_weights()
if self.model.model_lowvram:
for m in self.model.modules():
move_weight_functions(m, device_to)
@@ -770,6 +857,7 @@ class ModelPatcher:
self.model.to(device_to)
self.model.device = device_to
self.model.model_loaded_weight_memory = 0
self.model.model_offload_buffer_memory = 0
for m in self.model.modules():
if hasattr(m, "comfy_patched_weights"):
@@ -781,20 +869,21 @@ class ModelPatcher:
self.object_patches_backup.clear()
def partially_unload(self, device_to, memory_to_free=0):
def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=False):
with self.use_ejected():
hooks_unpatched = False
memory_freed = 0
patch_counter = 0
unload_list = self._load_list()
unload_list.sort()
offload_buffer = self.model.model_offload_buffer_memory
for unload in unload_list:
if memory_to_free < memory_freed:
if memory_to_free + offload_buffer - self.model.model_offload_buffer_memory < memory_freed:
break
module_mem = unload[0]
n = unload[1]
m = unload[2]
params = unload[3]
module_offload_mem, module_mem, n, m, params = unload
potential_offload = (comfy.model_management.NUM_STREAMS + 1) * module_offload_mem
lowvram_possible = hasattr(m, "comfy_cast_weights")
if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
@@ -825,11 +914,19 @@ class ModelPatcher:
module_mem += move_weight_functions(m, device_to)
if lowvram_possible:
if weight_key in self.patches:
m.weight_function.append(LowVramPatch(weight_key, self.patches))
patch_counter += 1
if force_patch_weights:
self.patch_weight_to_device(weight_key)
else:
_, set_func, convert_func = get_key_weight(self.model, weight_key)
m.weight_function.append(LowVramPatch(weight_key, self.patches, convert_func, set_func))
patch_counter += 1
if bias_key in self.patches:
m.bias_function.append(LowVramPatch(bias_key, self.patches))
patch_counter += 1
if force_patch_weights:
self.patch_weight_to_device(bias_key)
else:
_, set_func, convert_func = get_key_weight(self.model, bias_key)
m.bias_function.append(LowVramPatch(bias_key, self.patches, convert_func, set_func))
patch_counter += 1
cast_weight = True
if cast_weight:
@@ -837,11 +934,18 @@ class ModelPatcher:
m.comfy_cast_weights = True
m.comfy_patched_weights = False
memory_freed += module_mem
offload_buffer = max(offload_buffer, potential_offload)
logging.debug("freed {}".format(n))
for param in params:
self.pin_weight_to_device("{}.{}".format(n, param))
self.model.model_lowvram = True
self.model.lowvram_patch_counter += patch_counter
self.model.model_loaded_weight_memory -= memory_freed
self.model.model_offload_buffer_memory = offload_buffer
logging.info("Unloaded partially: {:.2f} MB freed, {:.2f} MB remains loaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(memory_freed / (1024 * 1024), self.model.model_loaded_weight_memory / (1024 * 1024), offload_buffer / (1024 * 1024), self.model.lowvram_patch_counter))
return memory_freed
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
@@ -854,6 +958,9 @@ class ModelPatcher:
extra_memory += (used - self.model.model_loaded_weight_memory)
self.patch_model(load_weights=False)
if extra_memory < 0 and not unpatch_weights:
self.partially_unload(self.offload_device, -extra_memory, force_patch_weights=force_patch_weights)
return 0
full_load = False
if self.model.model_lowvram == False and self.model.model_loaded_weight_memory > 0:
self.apply_hooks(self.forced_hooks, force_apply=True)
@@ -1241,5 +1348,6 @@ class ModelPatcher:
self.clear_cached_hook_weights()
def __del__(self):
self.unpin_all_weights()
self.detach(unpatch_all=False)

View File

@@ -21,17 +21,23 @@ def rescale_zero_terminal_snr_sigmas(sigmas):
alphas_bar[-1] = 4.8973451890853435e-08
return ((1 - alphas_bar) / alphas_bar) ** 0.5
def reshape_sigma(sigma, noise_dim):
if sigma.nelement() == 1:
return sigma.view(())
else:
return sigma.view(sigma.shape[:1] + (1,) * (noise_dim - 1))
class EPS:
def calculate_input(self, sigma, noise):
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
sigma = reshape_sigma(sigma, noise.ndim)
return noise / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
def calculate_denoised(self, sigma, model_output, model_input):
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
sigma = reshape_sigma(sigma, model_output.ndim)
return model_input - model_output * sigma
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
sigma = reshape_sigma(sigma, noise.ndim)
if max_denoise:
noise = noise * torch.sqrt(1.0 + sigma ** 2.0)
else:
@@ -45,12 +51,12 @@ class EPS:
class V_PREDICTION(EPS):
def calculate_denoised(self, sigma, model_output, model_input):
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
sigma = reshape_sigma(sigma, model_output.ndim)
return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
class EDM(V_PREDICTION):
def calculate_denoised(self, sigma, model_output, model_input):
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
sigma = reshape_sigma(sigma, model_output.ndim)
return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) + model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
class CONST:
@@ -58,15 +64,15 @@ class CONST:
return noise
def calculate_denoised(self, sigma, model_output, model_input):
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
sigma = reshape_sigma(sigma, model_output.ndim)
return model_input - model_output * sigma
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
sigma = reshape_sigma(sigma, noise.ndim)
return sigma * noise + (1.0 - sigma) * latent_image
def inverse_noise_scaling(self, sigma, latent):
sigma = sigma.view(sigma.shape[:1] + (1,) * (latent.ndim - 1))
sigma = reshape_sigma(sigma, latent.ndim)
return latent / (1.0 - sigma)
class X0(EPS):
@@ -80,16 +86,16 @@ class IMG_TO_IMG(X0):
class COSMOS_RFLOW:
def calculate_input(self, sigma, noise):
sigma = (sigma / (sigma + 1))
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
sigma = reshape_sigma(sigma, noise.ndim)
return noise * (1.0 - sigma)
def calculate_denoised(self, sigma, model_output, model_input):
sigma = (sigma / (sigma + 1))
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
sigma = reshape_sigma(sigma, model_output.ndim)
return model_input * (1.0 - sigma) - model_output * sigma
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
sigma = reshape_sigma(sigma, noise.ndim)
noise = noise * sigma
noise += latent_image
return noise

91
comfy/nested_tensor.py Normal file
View File

@@ -0,0 +1,91 @@
import torch
class NestedTensor:
def __init__(self, tensors):
self.tensors = list(tensors)
self.is_nested = True
def _copy(self):
return NestedTensor(self.tensors)
def apply_operation(self, other, operation):
o = self._copy()
if isinstance(other, NestedTensor):
for i, t in enumerate(o.tensors):
o.tensors[i] = operation(t, other.tensors[i])
else:
for i, t in enumerate(o.tensors):
o.tensors[i] = operation(t, other)
return o
def __add__(self, b):
return self.apply_operation(b, lambda x, y: x + y)
def __sub__(self, b):
return self.apply_operation(b, lambda x, y: x - y)
def __mul__(self, b):
return self.apply_operation(b, lambda x, y: x * y)
# def __itruediv__(self, b):
# return self.apply_operation(b, lambda x, y: x / y)
def __truediv__(self, b):
return self.apply_operation(b, lambda x, y: x / y)
def __getitem__(self, *args, **kwargs):
return self.apply_operation(None, lambda x, y: x.__getitem__(*args, **kwargs))
def unbind(self):
return self.tensors
def to(self, *args, **kwargs):
o = self._copy()
for i, t in enumerate(o.tensors):
o.tensors[i] = t.to(*args, **kwargs)
return o
def new_ones(self, *args, **kwargs):
return self.tensors[0].new_ones(*args, **kwargs)
def float(self):
return self.to(dtype=torch.float)
def chunk(self, *args, **kwargs):
return self.apply_operation(None, lambda x, y: x.chunk(*args, **kwargs))
def size(self):
return self.tensors[0].size()
@property
def shape(self):
return self.tensors[0].shape
@property
def ndim(self):
dims = 0
for t in self.tensors:
dims = max(t.ndim, dims)
return dims
@property
def device(self):
return self.tensors[0].device
@property
def dtype(self):
return self.tensors[0].dtype
@property
def layout(self):
return self.tensors[0].layout
def cat_nested(tensors, *args, **kwargs):
cated_tensors = []
for i in range(len(tensors[0].tensors)):
tens = []
for j in range(len(tensors)):
tens.append(tensors[j].tensors[i])
cated_tensors.append(torch.cat(tens, *args, **kwargs))
return NestedTensor(cated_tensors)

View File

@@ -24,13 +24,18 @@ import comfy.float
import comfy.rmsnorm
import contextlib
def run_every_op():
if torch.compiler.is_compiling():
return
comfy.model_management.throw_exception_if_processing_interrupted()
def scaled_dot_product_attention(q, k, v, *args, **kwargs):
return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
try:
if torch.cuda.is_available():
if torch.cuda.is_available() and comfy.model_management.WINDOWS:
from torch.nn.attention import SDPBackend, sdpa_kernel
import inspect
if "set_priority" in inspect.signature(sdpa_kernel).parameters:
@@ -50,49 +55,94 @@ try:
except (ModuleNotFoundError, TypeError):
logging.warning("Could not set sdpa backend priority.")
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
NVIDIA_MEMORY_CONV_BUG_WORKAROUND = False
try:
if comfy.model_management.is_nvidia():
cudnn_version = torch.backends.cudnn.version()
if (cudnn_version >= 91002 and cudnn_version < 91500) and comfy.model_management.torch_version_numeric >= (2, 9) and comfy.model_management.torch_version_numeric <= (2, 10):
#TODO: change upper bound version once it's fixed'
NVIDIA_MEMORY_CONV_BUG_WORKAROUND = True
logging.info("working around nvidia conv3d memory bug.")
except:
pass
if torch.cuda.is_available() and torch.backends.cudnn.is_available() and PerformanceFeature.AutoTune in args.fast:
torch.backends.cudnn.benchmark = True
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
def cast_to_input(weight, input, non_blocking=False, copy=True):
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False):
# NOTE: offloadable=False is a a legacy and if you are a custom node author reading this please pass
# offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This
# will add async-offload support to your cast and improve performance.
if input is not None:
if dtype is None:
dtype = input.dtype
if isinstance(input, QuantizedTensor):
dtype = input._layout_params["orig_dtype"]
else:
dtype = input.dtype
if bias_dtype is None:
bias_dtype = dtype
if device is None:
device = input.device
offload_stream = comfy.model_management.get_offload_stream(device)
if offloadable and (device != s.weight.device or
(s.bias is not None and device != s.bias.device)):
offload_stream = comfy.model_management.get_offload_stream(device)
else:
offload_stream = None
if offload_stream is not None:
wf_context = offload_stream
if hasattr(wf_context, "as_context"):
wf_context = wf_context.as_context(offload_stream)
else:
wf_context = contextlib.nullcontext()
bias = None
non_blocking = comfy.model_management.device_supports_non_blocking(device)
if s.bias is not None:
has_function = len(s.bias_function) > 0
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)
if has_function:
weight_has_function = len(s.weight_function) > 0
bias_has_function = len(s.bias_function) > 0
weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream)
bias = None
if s.bias is not None:
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream)
if bias_has_function:
with wf_context:
for f in s.bias_function:
bias = f(bias)
has_function = len(s.weight_function) > 0
weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)
if has_function:
if weight_has_function or weight.dtype != dtype:
with wf_context:
weight = weight.to(dtype=dtype)
if isinstance(weight, QuantizedTensor):
weight = weight.dequantize()
for f in s.weight_function:
weight = f(weight)
comfy.model_management.sync_stream(device, offload_stream)
return weight, bias
if offloadable:
return weight, bias, offload_stream
else:
#Legacy function signature
return weight, bias
def uncast_bias_weight(s, weight, bias, offload_stream):
if offload_stream is None:
return
if weight is not None:
device = weight.device
else:
if bias is None:
return
device = bias.device
offload_stream.wait_stream(comfy.model_management.current_stream(device))
class CastWeightBiasOp:
comfy_cast_weights = False
@@ -105,10 +155,13 @@ class disable_weight_init:
return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.linear(input, weight, bias)
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = torch.nn.functional.linear(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
@@ -119,10 +172,13 @@ class disable_weight_init:
return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias)
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = self._conv_forward(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
@@ -133,10 +189,13 @@ class disable_weight_init:
return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias)
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = self._conv_forward(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
@@ -146,11 +205,23 @@ class disable_weight_init:
def reset_parameters(self):
return None
def _conv_forward(self, input, weight, bias, *args, **kwargs):
if NVIDIA_MEMORY_CONV_BUG_WORKAROUND and weight.dtype in (torch.float16, torch.bfloat16):
out = torch.cudnn_convolution(input, weight, self.padding, self.stride, self.dilation, self.groups, benchmark=False, deterministic=False, allow_tf32=True)
if bias is not None:
out += bias.reshape((1, -1) + (1,) * (out.ndim - 2))
return out
else:
return super()._conv_forward(input, weight, bias, *args, **kwargs)
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias)
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = self._conv_forward(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
@@ -161,10 +232,13 @@ class disable_weight_init:
return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
@@ -176,13 +250,17 @@ class disable_weight_init:
def forward_comfy_cast_weights(self, input):
if self.weight is not None:
weight, bias = cast_bias_weight(self, input)
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
else:
weight = None
bias = None
return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
offload_stream = None
x = torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
@@ -195,13 +273,18 @@ class disable_weight_init:
def forward_comfy_cast_weights(self, input):
if self.weight is not None:
weight, bias = cast_bias_weight(self, input)
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
else:
weight = None
return comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated
# return torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
bias = None
offload_stream = None
x = comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated
# x = torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
@@ -217,12 +300,15 @@ class disable_weight_init:
input, output_size, self.stride, self.padding, self.kernel_size,
num_spatial_dims, self.dilation)
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.conv_transpose2d(
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = torch.nn.functional.conv_transpose2d(
input, weight, bias, self.stride, self.padding,
output_padding, self.groups, self.dilation)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
@@ -238,12 +324,15 @@ class disable_weight_init:
input, output_size, self.stride, self.padding, self.kernel_size,
num_spatial_dims, self.dilation)
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.conv_transpose1d(
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = torch.nn.functional.conv_transpose1d(
input, weight, bias, self.stride, self.padding,
output_padding, self.groups, self.dilation)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
@@ -258,10 +347,14 @@ class disable_weight_init:
output_dtype = out_dtype
if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16:
out_dtype = None
weight, bias = cast_bias_weight(self, device=input.device, dtype=out_dtype)
return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
weight, bias, offload_stream = cast_bias_weight(self, device=input.device, dtype=out_dtype, offloadable=True)
x = torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
@@ -312,20 +405,18 @@ class manual_cast(disable_weight_init):
def fp8_linear(self, input):
"""
Legacy FP8 linear function for backward compatibility.
Uses QuantizedTensor subclass for dispatch.
"""
dtype = self.weight.dtype
if dtype not in [torch.float8_e4m3fn]:
return None
tensor_2d = False
if len(input.shape) == 2:
tensor_2d = True
input = input.unsqueeze(1)
input_shape = input.shape
input_dtype = input.dtype
if len(input.shape) == 3:
w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype)
w = w.t()
if input.ndim == 3 or input.ndim == 2:
w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
scale_weight = self.scale_weight
scale_input = self.scale_input
@@ -337,23 +428,20 @@ def fp8_linear(self, input):
if scale_input is None:
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
input = torch.clamp(input, min=-448, max=448, out=input)
input = input.reshape(-1, input_shape[2]).to(dtype).contiguous()
layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight)
else:
scale_input = scale_input.to(input.device)
input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype).contiguous()
quantized_input = QuantizedTensor.from_float(input, "TensorCoreFP8Layout", scale=scale_input, dtype=dtype)
if bias is not None:
o = torch._scaled_mm(input, w, out_dtype=input_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
else:
o = torch._scaled_mm(input, w, out_dtype=input_dtype, scale_a=scale_input, scale_b=scale_weight)
# Wrap weight in QuantizedTensor - this enables unified dispatch
# Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype}
quantized_weight = QuantizedTensor(w, "TensorCoreFP8Layout", layout_params_weight)
o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
if isinstance(o, tuple):
o = o[0]
if tensor_2d:
return o.reshape(input_shape[0], -1)
return o.reshape((-1, input_shape[1], self.weight.shape[0]))
uncast_bias_weight(self, w, bias, offload_stream)
return o
return None
@@ -373,8 +461,10 @@ class fp8_ops(manual_cast):
except Exception as e:
logging.info("Exception during fp8 op: {}".format(e))
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.linear(input, weight, bias)
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = torch.nn.functional.linear(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input))
@@ -402,22 +492,26 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None
if out is not None:
return out
weight, bias = cast_bias_weight(self, input)
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
if weight.numel() < input.numel(): #TODO: optimize
return torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
x = torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
else:
return torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
x = torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def convert_weight(self, weight, inplace=False, **kwargs):
if inplace:
weight *= self.scale_weight.to(device=weight.device, dtype=weight.dtype)
return weight
else:
return weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype)
return weight.to(dtype=torch.float32) * self.scale_weight.to(device=weight.device, dtype=torch.float32)
def set_weight(self, weight, inplace_update=False, seed=None, **kwargs):
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
if return_weight:
return weight
if inplace_update:
self.weight.data.copy_(weight)
else:
@@ -444,8 +538,142 @@ if CUBLAS_IS_AVAILABLE:
def forward(self, *args, **kwargs):
return super().forward(*args, **kwargs)
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None):
fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
# ==============================================================================
# Mixed Precision Operations
# ==============================================================================
from .quant_ops import QuantizedTensor, QUANT_ALGOS
def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False):
class MixedPrecisionOps(manual_cast):
_layer_quant_config = layer_quant_config
_compute_dtype = compute_dtype
_full_precision_mm = full_precision_mm
class Linear(torch.nn.Module, CastWeightBiasOp):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
device=None,
dtype=None,
) -> None:
super().__init__()
self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
# self.factory_kwargs = {"device": device, "dtype": dtype}
self.in_features = in_features
self.out_features = out_features
if bias:
self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
else:
self.register_parameter("bias", None)
self.tensor_class = None
self._full_precision_mm = MixedPrecisionOps._full_precision_mm
def reset_parameters(self):
return None
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys, error_msgs):
device = self.factory_kwargs["device"]
layer_name = prefix.rstrip('.')
weight_key = f"{prefix}weight"
weight = state_dict.pop(weight_key, None)
if weight is None:
raise ValueError(f"Missing weight for layer {layer_name}")
manually_loaded_keys = [weight_key]
if layer_name not in MixedPrecisionOps._layer_quant_config:
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
else:
quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None)
if quant_format is None:
raise ValueError(f"Unknown quantization format for layer {layer_name}")
qconfig = QUANT_ALGOS[quant_format]
self.layout_type = qconfig["comfy_tensor_layout"]
weight_scale_key = f"{prefix}weight_scale"
layout_params = {
'scale': state_dict.pop(weight_scale_key, None),
'orig_dtype': MixedPrecisionOps._compute_dtype,
'block_size': qconfig.get("group_size", None),
}
if layout_params['scale'] is not None:
manually_loaded_keys.append(weight_scale_key)
self.weight = torch.nn.Parameter(
QuantizedTensor(weight.to(device=device), self.layout_type, layout_params),
requires_grad=False
)
for param_name in qconfig["parameters"]:
param_key = f"{prefix}{param_name}"
_v = state_dict.pop(param_key, None)
if _v is None:
continue
setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
manually_loaded_keys.append(param_key)
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
for key in manually_loaded_keys:
if key in missing_keys:
missing_keys.remove(key)
def _forward(self, input, weight, bias):
return torch.nn.functional.linear(input, weight, bias)
def forward_comfy_cast_weights(self, input):
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = self._forward(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, input, *args, **kwargs):
run_every_op()
if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(input, *args, **kwargs)
if (getattr(self, 'layout_type', None) is not None and
getattr(self, 'input_scale', None) is not None and
not isinstance(input, QuantizedTensor)):
input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype)
return self._forward(input, self.weight, self.bias)
def convert_weight(self, weight, inplace=False, **kwargs):
if isinstance(weight, QuantizedTensor):
return weight.dequantize()
else:
return weight
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
if getattr(self, 'layout_type', None) is not None:
weight = QuantizedTensor.from_float(weight, self.layout_type, scale=None, dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True)
else:
weight = weight.to(self.weight.dtype)
if return_weight:
return weight
assert inplace_update is False # TODO: eventually remove the inplace_update stuff
self.weight = torch.nn.Parameter(weight, requires_grad=False)
return MixedPrecisionOps
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None):
fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular
if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config:
logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers")
return mixed_precision_ops(model_config.layer_quant_config, compute_dtype, full_precision_mm=not fp8_compute)
if scaled_fp8 is not None:
return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)

View File

@@ -150,7 +150,7 @@ def merge_nested_dicts(dict1: dict, dict2: dict, copy_dict1=True):
for key, value in dict2.items():
if isinstance(value, dict):
curr_value = merged_dict.setdefault(key, {})
merged_dict[key] = merge_nested_dicts(value, curr_value)
merged_dict[key] = merge_nested_dicts(curr_value, value)
elif isinstance(value, list):
merged_dict.setdefault(key, []).extend(value)
else:

573
comfy/quant_ops.py Normal file
View File

@@ -0,0 +1,573 @@
import torch
import logging
from typing import Tuple, Dict
import comfy.float
_LAYOUT_REGISTRY = {}
_GENERIC_UTILS = {}
def register_layout_op(torch_op, layout_type):
"""
Decorator to register a layout-specific operation handler.
Args:
torch_op: PyTorch operation (e.g., torch.ops.aten.linear.default)
layout_type: Layout class (e.g., TensorCoreFP8Layout)
Example:
@register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout)
def fp8_linear(func, args, kwargs):
# FP8-specific linear implementation
...
"""
def decorator(handler_func):
if torch_op not in _LAYOUT_REGISTRY:
_LAYOUT_REGISTRY[torch_op] = {}
_LAYOUT_REGISTRY[torch_op][layout_type] = handler_func
return handler_func
return decorator
def register_generic_util(torch_op):
"""
Decorator to register a generic utility that works for all layouts.
Args:
torch_op: PyTorch operation (e.g., torch.ops.aten.detach.default)
Example:
@register_generic_util(torch.ops.aten.detach.default)
def generic_detach(func, args, kwargs):
# Works for any layout
...
"""
def decorator(handler_func):
_GENERIC_UTILS[torch_op] = handler_func
return handler_func
return decorator
def _get_layout_from_args(args):
for arg in args:
if isinstance(arg, QuantizedTensor):
return arg._layout_type
elif isinstance(arg, (list, tuple)):
for item in arg:
if isinstance(item, QuantizedTensor):
return item._layout_type
return None
def _move_layout_params_to_device(params, device):
new_params = {}
for k, v in params.items():
if isinstance(v, torch.Tensor):
new_params[k] = v.to(device=device)
else:
new_params[k] = v
return new_params
def _copy_layout_params(params):
new_params = {}
for k, v in params.items():
if isinstance(v, torch.Tensor):
new_params[k] = v.clone()
else:
new_params[k] = v
return new_params
def _copy_layout_params_inplace(src, dst, non_blocking=False):
for k, v in src.items():
if isinstance(v, torch.Tensor):
dst[k].copy_(v, non_blocking=non_blocking)
else:
dst[k] = v
class QuantizedLayout:
"""
Base class for quantization layouts.
A layout encapsulates the format-specific logic for quantization/dequantization
and provides a uniform interface for extracting raw tensors needed for computation.
New quantization formats should subclass this and implement the required methods.
"""
@classmethod
def quantize(cls, tensor, **kwargs) -> Tuple[torch.Tensor, Dict]:
raise NotImplementedError(f"{cls.__name__} must implement quantize()")
@staticmethod
def dequantize(qdata, **layout_params) -> torch.Tensor:
raise NotImplementedError("TensorLayout must implement dequantize()")
@classmethod
def get_plain_tensors(cls, qtensor) -> torch.Tensor:
raise NotImplementedError(f"{cls.__name__} must implement get_plain_tensors()")
class QuantizedTensor(torch.Tensor):
"""
Universal quantized tensor that works with any layout.
This tensor subclass uses a pluggable layout system to support multiple
quantization formats (FP8, INT4, INT8, etc.) without code duplication.
The layout_type determines format-specific behavior, while common operations
(detach, clone, to) are handled generically.
Attributes:
_qdata: The quantized tensor data
_layout_type: Layout class (e.g., TensorCoreFP8Layout)
_layout_params: Dict with layout-specific params (scale, zero_point, etc.)
"""
@staticmethod
def __new__(cls, qdata, layout_type, layout_params):
"""
Create a quantized tensor.
Args:
qdata: The quantized data tensor
layout_type: Layout class (subclass of QuantizedLayout)
layout_params: Dict with layout-specific parameters
"""
return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False)
def __init__(self, qdata, layout_type, layout_params):
self._qdata = qdata
self._layout_type = layout_type
self._layout_params = layout_params
def __repr__(self):
layout_name = self._layout_type
param_str = ", ".join(f"{k}={v}" for k, v in list(self._layout_params.items())[:2])
return f"QuantizedTensor(shape={self.shape}, layout={layout_name}, {param_str})"
@property
def layout_type(self):
return self._layout_type
def __tensor_flatten__(self):
"""
Tensor flattening protocol for proper device movement.
"""
inner_tensors = ["_qdata"]
ctx = {
"layout_type": self._layout_type,
}
tensor_params = {}
non_tensor_params = {}
for k, v in self._layout_params.items():
if isinstance(v, torch.Tensor):
tensor_params[k] = v
else:
non_tensor_params[k] = v
ctx["tensor_param_keys"] = list(tensor_params.keys())
ctx["non_tensor_params"] = non_tensor_params
for k, v in tensor_params.items():
attr_name = f"_layout_param_{k}"
object.__setattr__(self, attr_name, v)
inner_tensors.append(attr_name)
return inner_tensors, ctx
@staticmethod
def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
"""
Tensor unflattening protocol for proper device movement.
Reconstructs the QuantizedTensor after device movement.
"""
layout_type = ctx["layout_type"]
layout_params = dict(ctx["non_tensor_params"])
for key in ctx["tensor_param_keys"]:
attr_name = f"_layout_param_{key}"
layout_params[key] = inner_tensors[attr_name]
return QuantizedTensor(inner_tensors["_qdata"], layout_type, layout_params)
@classmethod
def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor':
qdata, layout_params = LAYOUTS[layout_type].quantize(tensor, **quantize_kwargs)
return cls(qdata, layout_type, layout_params)
def dequantize(self) -> torch.Tensor:
return LAYOUTS[self._layout_type].dequantize(self._qdata, **self._layout_params)
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
kwargs = kwargs or {}
# Step 1: Check generic utilities first (detach, clone, to, etc.)
if func in _GENERIC_UTILS:
return _GENERIC_UTILS[func](func, args, kwargs)
# Step 2: Check layout-specific handlers (linear, matmul, etc.)
layout_type = _get_layout_from_args(args)
if layout_type and func in _LAYOUT_REGISTRY:
handler = _LAYOUT_REGISTRY[func].get(layout_type)
if handler:
return handler(func, args, kwargs)
# Step 3: Fallback to dequantization
if isinstance(args[0] if args else None, QuantizedTensor):
logging.info(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. kwargs={kwargs}")
return cls._dequant_and_fallback(func, args, kwargs)
@classmethod
def _dequant_and_fallback(cls, func, args, kwargs):
def dequant_arg(arg):
if isinstance(arg, QuantizedTensor):
return arg.dequantize()
elif isinstance(arg, (list, tuple)):
return type(arg)(dequant_arg(a) for a in arg)
return arg
new_args = dequant_arg(args)
new_kwargs = dequant_arg(kwargs)
return func(*new_args, **new_kwargs)
def data_ptr(self):
return self._qdata.data_ptr()
def is_pinned(self):
return self._qdata.is_pinned()
def is_contiguous(self, *arg, **kwargs):
return self._qdata.is_contiguous(*arg, **kwargs)
# ==============================================================================
# Generic Utilities (Layout-Agnostic Operations)
# ==============================================================================
def _create_transformed_qtensor(qt, transform_fn):
new_data = transform_fn(qt._qdata)
new_params = _copy_layout_params(qt._layout_params)
return QuantizedTensor(new_data, qt._layout_type, new_params)
def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=None, op_name="to"):
if target_dtype is not None and target_dtype != qt.dtype:
logging.warning(
f"QuantizedTensor: dtype conversion requested to {target_dtype}, "
f"but not supported for quantized tensors. Ignoring dtype."
)
if target_layout is not None and target_layout != torch.strided:
logging.warning(
f"QuantizedTensor: layout change requested to {target_layout}, "
f"but not supported. Ignoring layout."
)
# Handle device transfer
current_device = qt._qdata.device
if target_device is not None:
# Normalize device for comparison
if isinstance(target_device, str):
target_device = torch.device(target_device)
if isinstance(current_device, str):
current_device = torch.device(current_device)
if target_device != current_device:
logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}")
new_q_data = qt._qdata.to(device=target_device)
new_params = _move_layout_params_to_device(qt._layout_params, target_device)
new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params)
logging.debug(f"QuantizedTensor.{op_name}: Created new tensor on {target_device}")
return new_qt
logging.debug(f"QuantizedTensor.{op_name}: No device change needed, returning original")
return qt
@register_generic_util(torch.ops.aten.detach.default)
def generic_detach(func, args, kwargs):
"""Detach operation - creates a detached copy of the quantized tensor."""
qt = args[0]
if isinstance(qt, QuantizedTensor):
return _create_transformed_qtensor(qt, lambda x: x.detach())
return func(*args, **kwargs)
@register_generic_util(torch.ops.aten.clone.default)
def generic_clone(func, args, kwargs):
"""Clone operation - creates a deep copy of the quantized tensor."""
qt = args[0]
if isinstance(qt, QuantizedTensor):
return _create_transformed_qtensor(qt, lambda x: x.clone())
return func(*args, **kwargs)
@register_generic_util(torch.ops.aten._to_copy.default)
def generic_to_copy(func, args, kwargs):
"""Device/dtype transfer operation - handles .to(device) calls."""
qt = args[0]
if isinstance(qt, QuantizedTensor):
return _handle_device_transfer(
qt,
target_device=kwargs.get('device', None),
target_dtype=kwargs.get('dtype', None),
op_name="_to_copy"
)
return func(*args, **kwargs)
@register_generic_util(torch.ops.aten.to.dtype_layout)
def generic_to_dtype_layout(func, args, kwargs):
"""Handle .to(device) calls using the dtype_layout variant."""
qt = args[0]
if isinstance(qt, QuantizedTensor):
return _handle_device_transfer(
qt,
target_device=kwargs.get('device', None),
target_dtype=kwargs.get('dtype', None),
target_layout=kwargs.get('layout', None),
op_name="to"
)
return func(*args, **kwargs)
@register_generic_util(torch.ops.aten.copy_.default)
def generic_copy_(func, args, kwargs):
qt_dest = args[0]
src = args[1]
non_blocking = args[2] if len(args) > 2 else False
if isinstance(qt_dest, QuantizedTensor):
if isinstance(src, QuantizedTensor):
# Copy from another quantized tensor
qt_dest._qdata.copy_(src._qdata, non_blocking=non_blocking)
qt_dest._layout_type = src._layout_type
_copy_layout_params_inplace(src._layout_params, qt_dest._layout_params, non_blocking=non_blocking)
else:
# Copy from regular tensor - just copy raw data
qt_dest._qdata.copy_(src)
return qt_dest
return func(*args, **kwargs)
@register_generic_util(torch.ops.aten.to.dtype)
def generic_to_dtype(func, args, kwargs):
"""Handle .to(dtype) calls - dtype conversion only."""
src = args[0]
if isinstance(src, QuantizedTensor):
# For dtype-only conversion, just change the orig_dtype, no real cast is needed
target_dtype = args[1] if len(args) > 1 else kwargs.get('dtype')
src._layout_params["orig_dtype"] = target_dtype
return src
return func(*args, **kwargs)
@register_generic_util(torch.ops.aten._has_compatible_shallow_copy_type.default)
def generic_has_compatible_shallow_copy_type(func, args, kwargs):
return True
@register_generic_util(torch.ops.aten.empty_like.default)
def generic_empty_like(func, args, kwargs):
"""Empty_like operation - creates an empty tensor with the same quantized structure."""
qt = args[0]
if isinstance(qt, QuantizedTensor):
# Create empty tensor with same shape and dtype as the quantized data
hp_dtype = kwargs.pop('dtype', qt._layout_params["orig_dtype"])
new_qdata = torch.empty_like(qt._qdata, **kwargs)
# Handle device transfer for layout params
target_device = kwargs.get('device', new_qdata.device)
new_params = _move_layout_params_to_device(qt._layout_params, target_device)
# Update orig_dtype if dtype is specified
new_params['orig_dtype'] = hp_dtype
return QuantizedTensor(new_qdata, qt._layout_type, new_params)
return func(*args, **kwargs)
# ==============================================================================
# FP8 Layout + Operation Handlers
# ==============================================================================
class TensorCoreFP8Layout(QuantizedLayout):
"""
Storage format:
- qdata: FP8 tensor (torch.float8_e4m3fn or torch.float8_e5m2)
- scale: Scalar tensor (float32) for dequantization
- orig_dtype: Original dtype before quantization (for casting back)
"""
@classmethod
def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn, stochastic_rounding=0, inplace_ops=False):
orig_dtype = tensor.dtype
if scale is None:
scale = torch.amax(tensor.abs()) / torch.finfo(dtype).max
if not isinstance(scale, torch.Tensor):
scale = torch.tensor(scale)
scale = scale.to(device=tensor.device, dtype=torch.float32)
if inplace_ops:
tensor *= (1.0 / scale).to(tensor.dtype)
else:
tensor = tensor * (1.0 / scale).to(tensor.dtype)
if stochastic_rounding > 0:
tensor = comfy.float.stochastic_rounding(tensor, dtype=dtype, seed=stochastic_rounding)
else:
lp_amax = torch.finfo(dtype).max
torch.clamp(tensor, min=-lp_amax, max=lp_amax, out=tensor)
tensor = tensor.to(dtype, memory_format=torch.contiguous_format)
layout_params = {
'scale': scale,
'orig_dtype': orig_dtype
}
return tensor, layout_params
@staticmethod
def dequantize(qdata, scale, orig_dtype, **kwargs):
plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype)
plain_tensor.mul_(scale)
return plain_tensor
@classmethod
def get_plain_tensors(cls, qtensor):
return qtensor._qdata, qtensor._layout_params['scale']
QUANT_ALGOS = {
"float8_e4m3fn": {
"storage_t": torch.float8_e4m3fn,
"parameters": {"weight_scale", "input_scale"},
"comfy_tensor_layout": "TensorCoreFP8Layout",
},
}
LAYOUTS = {
"TensorCoreFP8Layout": TensorCoreFP8Layout,
}
@register_layout_op(torch.ops.aten.linear.default, "TensorCoreFP8Layout")
def fp8_linear(func, args, kwargs):
input_tensor = args[0]
weight = args[1]
bias = args[2] if len(args) > 2 else None
if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight)
out_dtype = kwargs.get("out_dtype")
if out_dtype is None:
out_dtype = input_tensor._layout_params['orig_dtype']
weight_t = plain_weight.t()
tensor_2d = False
if len(plain_input.shape) == 2:
tensor_2d = True
plain_input = plain_input.unsqueeze(1)
input_shape = plain_input.shape
if len(input_shape) != 3:
return None
try:
output = torch._scaled_mm(
plain_input.reshape(-1, input_shape[2]).contiguous(),
weight_t,
bias=bias,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=out_dtype,
)
if isinstance(output, tuple): # TODO: remove when we drop support for torch 2.4
output = output[0]
if not tensor_2d:
output = output.reshape((-1, input_shape[1], weight.shape[0]))
if output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
output_scale = scale_a * scale_b
output_params = {
'scale': output_scale,
'orig_dtype': input_tensor._layout_params['orig_dtype']
}
return QuantizedTensor(output, "TensorCoreFP8Layout", output_params)
else:
return output
except Exception as e:
raise RuntimeError(f"FP8 _scaled_mm failed, falling back to dequantization: {e}")
# Case 2: DQ Fallback
if isinstance(weight, QuantizedTensor):
weight = weight.dequantize()
if isinstance(input_tensor, QuantizedTensor):
input_tensor = input_tensor.dequantize()
return torch.nn.functional.linear(input_tensor, weight, bias)
def fp8_mm_(input_tensor, weight, bias=None, out_dtype=None):
if out_dtype is None:
out_dtype = input_tensor._layout_params['orig_dtype']
plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight)
output = torch._scaled_mm(
plain_input.contiguous(),
plain_weight,
bias=bias,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=out_dtype,
)
if isinstance(output, tuple): # TODO: remove when we drop support for torch 2.4
output = output[0]
return output
@register_layout_op(torch.ops.aten.addmm.default, "TensorCoreFP8Layout")
def fp8_addmm(func, args, kwargs):
input_tensor = args[1]
weight = args[2]
bias = args[0]
if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
return fp8_mm_(input_tensor, weight, bias=bias, out_dtype=kwargs.get("out_dtype", None))
a = list(args)
if isinstance(args[0], QuantizedTensor):
a[0] = args[0].dequantize()
if isinstance(args[1], QuantizedTensor):
a[1] = args[1].dequantize()
if isinstance(args[2], QuantizedTensor):
a[2] = args[2].dequantize()
return func(*a, **kwargs)
@register_layout_op(torch.ops.aten.mm.default, "TensorCoreFP8Layout")
def fp8_mm(func, args, kwargs):
input_tensor = args[0]
weight = args[1]
if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
return fp8_mm_(input_tensor, weight, bias=None, out_dtype=kwargs.get("out_dtype", None))
a = list(args)
if isinstance(args[0], QuantizedTensor):
a[0] = args[0].dequantize()
if isinstance(args[1], QuantizedTensor):
a[1] = args[1].dequantize()
return func(*a, **kwargs)
@register_layout_op(torch.ops.aten.view.default, "TensorCoreFP8Layout")
@register_layout_op(torch.ops.aten.t.default, "TensorCoreFP8Layout")
def fp8_func(func, args, kwargs):
input_tensor = args[0]
if isinstance(input_tensor, QuantizedTensor):
plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
ar = list(args)
ar[0] = plain_input
return QuantizedTensor(func(*ar, **kwargs), "TensorCoreFP8Layout", input_tensor._layout_params)
return func(*args, **kwargs)

View File

@@ -4,13 +4,9 @@ import comfy.samplers
import comfy.utils
import numpy as np
import logging
import comfy.nested_tensor
def prepare_noise(latent_image, seed, noise_inds=None):
"""
creates random noise given a latent image and a seed.
optional arg skip can be used to skip and discard x number of noise generations for a given seed
"""
generator = torch.manual_seed(seed)
def prepare_noise_inner(latent_image, generator, noise_inds=None):
if noise_inds is None:
return torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
@@ -21,10 +17,29 @@ def prepare_noise(latent_image, seed, noise_inds=None):
if i in unique_inds:
noises.append(noise)
noises = [noises[i] for i in inverse]
noises = torch.cat(noises, axis=0)
return torch.cat(noises, axis=0)
def prepare_noise(latent_image, seed, noise_inds=None):
"""
creates random noise given a latent image and a seed.
optional arg skip can be used to skip and discard x number of noise generations for a given seed
"""
generator = torch.manual_seed(seed)
if latent_image.is_nested:
tensors = latent_image.unbind()
noises = []
for t in tensors:
noises.append(prepare_noise_inner(t, generator, noise_inds))
noises = comfy.nested_tensor.NestedTensor(noises)
else:
noises = prepare_noise_inner(latent_image, generator, noise_inds)
return noises
def fix_empty_latent_channels(model, latent_image):
if latent_image.is_nested:
return latent_image
latent_format = model.get_model_object("latent_format") #Resize the empty latent image so it has the right number of channels
if latent_format.latent_channels != latent_image.shape[1] and torch.count_nonzero(latent_image) == 0:
latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1)

View File

@@ -306,17 +306,10 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
copy_dict1=False)
if patches is not None:
# TODO: replace with merge_nested_dicts function
if "patches" in transformer_options:
cur_patches = transformer_options["patches"].copy()
for p in patches:
if p in cur_patches:
cur_patches[p] = cur_patches[p] + patches[p]
else:
cur_patches[p] = patches[p]
transformer_options["patches"] = cur_patches
else:
transformer_options["patches"] = patches
transformer_options["patches"] = comfy.patcher_extension.merge_nested_dicts(
transformer_options.get("patches", {}),
patches
)
transformer_options["cond_or_uncond"] = cond_or_uncond[:]
transformer_options["uuids"] = uuids[:]
@@ -789,7 +782,7 @@ def ksampler(sampler_name, extra_options={}, inpaint_options={}):
return KSAMPLER(sampler_function, extra_options, inpaint_options)
def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=None, seed=None):
def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=None, seed=None, latent_shapes=None):
for k in conds:
conds[k] = conds[k][:]
resolve_areas_and_cond_masks_multidim(conds[k], noise.shape[2:], device)
@@ -799,7 +792,7 @@ def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=N
if hasattr(model, 'extra_conds'):
for k in conds:
conds[k] = encode_model_conds(model.extra_conds, conds[k], noise, device, k, latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)
conds[k] = encode_model_conds(model.extra_conds, conds[k], noise, device, k, latent_image=latent_image, denoise_mask=denoise_mask, seed=seed, latent_shapes=latent_shapes)
#make sure each cond area has an opposite one with the same area
for k in conds:
@@ -969,11 +962,11 @@ class CFGGuider:
def predict_noise(self, x, timestep, model_options={}, seed=None):
return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None), self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed)
def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed):
def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=None):
if latent_image is not None and torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image.
latent_image = self.inner_model.process_latent_in(latent_image)
self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed)
self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed, latent_shapes=latent_shapes)
extra_model_options = comfy.model_patcher.create_model_options_clone(self.model_options)
extra_model_options.setdefault("transformer_options", {})["sample_sigmas"] = sigmas
@@ -987,7 +980,7 @@ class CFGGuider:
samples = executor.execute(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
return self.inner_model.process_latent_out(samples.to(torch.float32))
def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None, latent_shapes=None):
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
device = self.model_patcher.load_device
@@ -1001,7 +994,7 @@ class CFGGuider:
try:
self.model_patcher.pre_run()
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
finally:
self.model_patcher.cleanup()
@@ -1014,6 +1007,12 @@ class CFGGuider:
if sigmas.shape[-1] == 0:
return latent_image
if latent_image.is_nested:
latent_image, latent_shapes = comfy.utils.pack_latents(latent_image.unbind())
noise, _ = comfy.utils.pack_latents(noise.unbind())
else:
latent_shapes = [latent_image.shape]
self.conds = {}
for k in self.original_conds:
self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k]))
@@ -1033,7 +1032,7 @@ class CFGGuider:
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, self.model_options, is_model_options=True)
)
output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
finally:
cast_to_load_options(self.model_options, device=self.model_patcher.offload_device)
self.model_options = orig_model_options
@@ -1041,6 +1040,9 @@ class CFGGuider:
self.model_patcher.restore_hook_patches()
del self.conds
if len(latent_shapes) > 1:
output = comfy.nested_tensor.NestedTensor(comfy.utils.unpack_latents(output, latent_shapes))
return output

View File

@@ -18,6 +18,7 @@ import comfy.ldm.wan.vae2_2
import comfy.ldm.hunyuan3d.vae
import comfy.ldm.ace.vae.music_dcae_pipeline
import comfy.ldm.hunyuan_video.vae
import comfy.ldm.mmaudio.vae.autoencoder
import comfy.pixel_space_convert
import yaml
import math
@@ -51,6 +52,7 @@ import comfy.text_encoders.ace
import comfy.text_encoders.omnigen2
import comfy.text_encoders.qwen_image
import comfy.text_encoders.hunyuan_image
import comfy.text_encoders.z_image
import comfy.model_patcher
import comfy.lora
@@ -58,6 +60,8 @@ import comfy.lora_convert
import comfy.hooks
import comfy.t2i_adapter.adapter
import comfy.taesd.taesd
import comfy.taesd.taehv
import comfy.latent_formats
import comfy.ldm.flux.redux
@@ -142,6 +146,9 @@ class CLIP:
n.apply_hooks_to_conds = self.apply_hooks_to_conds
return n
def get_ram_usage(self):
return self.patcher.get_ram_usage()
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
return self.patcher.add_patches(patches, strength_patch, strength_model)
@@ -275,8 +282,13 @@ class VAE:
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
sd = diffusers_convert.convert_vae_state_dict(sd)
self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower)
self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype)
if model_management.is_amd():
VAE_KL_MEM_RATIO = 2.73
else:
VAE_KL_MEM_RATIO = 1.0
self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) * VAE_KL_MEM_RATIO #These are for AutoencoderKL and need tweaking (should be lower)
self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype) * VAE_KL_MEM_RATIO
self.downscale_ratio = 8
self.upscale_ratio = 8
self.latent_channels = 4
@@ -287,10 +299,12 @@ class VAE:
self.working_dtypes = [torch.bfloat16, torch.float32]
self.disable_offload = False
self.not_video = False
self.size = None
self.downscale_index_formula = None
self.upscale_index_formula = None
self.extra_1d_channel = None
self.crop_input = True
if config is None:
if "decoder.mid.block_1.mix_factor" in sd:
@@ -345,7 +359,7 @@ class VAE:
self.memory_used_encode = lambda shape, dtype: (700 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (700 * shape[2] * shape[3] * 32 * 32) * model_management.dtype_size(dtype)
elif sd['decoder.conv_in.weight'].shape[1] == 32:
elif sd['decoder.conv_in.weight'].shape[1] == 32 and sd['decoder.conv_in.weight'].ndim == 5:
ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True, "refiner_vae": False}
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
@@ -371,6 +385,17 @@ class VAE:
self.upscale_ratio = 4
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
if 'decoder.post_quant_conv.weight' in sd:
sd = comfy.utils.state_dict_prefix_replace(sd, {"decoder.post_quant_conv.": "post_quant_conv.", "encoder.quant_conv.": "quant_conv."})
if 'bn.running_mean' in sd:
ddconfig["batch_norm_latent"] = True
self.downscale_ratio *= 2
self.upscale_ratio *= 2
self.latent_channels *= 4
old_memory_used_decode = self.memory_used_decode
self.memory_used_decode = lambda shape, dtype: old_memory_used_decode(shape, dtype) * 4.0
if 'post_quant_conv.weight' in sd:
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
else:
@@ -430,20 +455,20 @@ class VAE:
elif "decoder.conv_in.conv.weight" in sd and sd['decoder.conv_in.conv.weight'].shape[1] == 32:
ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True}
ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
self.latent_channels = 64
self.latent_channels = 32
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
self.upscale_index_formula = (4, 16, 16)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16)
self.downscale_index_formula = (4, 16, 16)
self.latent_dim = 3
self.not_video = True
self.not_video = False
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.EmptyRegularizer"},
encoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Encoder", 'params': ddconfig},
decoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Decoder", 'params': ddconfig})
self.memory_used_encode = lambda shape, dtype: (1400 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (1400 * shape[-3] * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)
self.memory_used_encode = lambda shape, dtype: (1400 * 9 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (2800 * 4 * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)
elif "decoder.conv_in.conv.weight" in sd:
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
ddconfig["conv3d"] = True
@@ -485,13 +510,14 @@ class VAE:
self.memory_used_encode = lambda shape, dtype: 3300 * shape[3] * shape[4] * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: 8000 * shape[3] * shape[4] * (16 * 16) * model_management.dtype_size(dtype)
else: # Wan 2.1 VAE
dim = sd["decoder.head.0.gamma"].shape[0]
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
self.upscale_index_formula = (4, 8, 8)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
self.downscale_index_formula = (4, 8, 8)
self.latent_dim = 3
self.latent_channels = 16
ddconfig = {"dim": 96, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0}
ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0}
self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig)
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype)
@@ -542,6 +568,54 @@ class VAE:
self.latent_channels = 3
self.latent_dim = 2
self.output_channels = 3
elif "vocoder.activation_post.downsample.lowpass.filter" in sd: #MMAudio VAE
sample_rate = 16000
if sample_rate == 16000:
mode = '16k'
else:
mode = '44k'
self.first_stage_model = comfy.ldm.mmaudio.vae.autoencoder.AudioAutoencoder(mode=mode)
self.memory_used_encode = lambda shape, dtype: (30 * shape[2]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (90 * shape[2] * 1411.2) * model_management.dtype_size(dtype)
self.latent_channels = 20
self.output_channels = 2
self.upscale_ratio = 512 * (44100 / sample_rate)
self.downscale_ratio = 512 * (44100 / sample_rate)
self.latent_dim = 1
self.process_output = lambda audio: audio
self.process_input = lambda audio: audio
self.working_dtypes = [torch.float32]
self.crop_input = False
elif "decoder.22.bias" in sd: # taehv, taew and lighttae
self.latent_channels = sd["decoder.1.weight"].shape[1]
self.latent_dim = 3
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
self.upscale_index_formula = (4, 16, 16)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16)
self.downscale_index_formula = (4, 16, 16)
if self.latent_channels == 48: # Wan 2.2
self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=None) # taehv doesn't need scaling
self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently"))
self.process_output = lambda image: image
self.memory_used_decode = lambda shape, dtype: (1800 * (max(1, (shape[-3] ** 0.7 * 0.1)) * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype))
elif self.latent_channels == 32 and sd["decoder.22.bias"].shape[0] == 12: # lighttae_hv15
self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=comfy.latent_formats.HunyuanVideo15)
self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently"))
self.memory_used_decode = lambda shape, dtype: (1200 * (max(1, (shape[-3] ** 0.7 * 0.05)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype))
else:
if sd["decoder.1.weight"].dtype == torch.float16: # taehv currently only available in float16, so assume it's not lighttaew2_1 as otherwise state dicts are identical
latent_format=comfy.latent_formats.HunyuanVideo
else:
latent_format=None # lighttaew2_1 doesn't need scaling
self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=latent_format)
self.process_input = self.process_output = lambda image: image
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
self.upscale_index_formula = (4, 8, 8)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
self.downscale_index_formula = (4, 8, 8)
self.memory_used_encode = lambda shape, dtype: (700 * (max(1, (shape[-3] ** 0.66 * 0.11)) * shape[-2] * shape[-1]) * model_management.dtype_size(dtype))
self.memory_used_decode = lambda shape, dtype: (50 * (max(1, (shape[-3] ** 0.65 * 0.26)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype))
else:
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
self.first_stage_model = None
@@ -569,12 +643,25 @@ class VAE:
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
self.model_size()
def model_size(self):
if self.size is not None:
return self.size
self.size = comfy.model_management.module_size(self.first_stage_model)
return self.size
def get_ram_usage(self):
return self.model_size()
def throw_exception_if_invalid(self):
if self.first_stage_model is None:
raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.")
def vae_encode_crop_pixels(self, pixels):
if not self.crop_input:
return pixels
downscale_ratio = self.spacial_compression_encode()
dims = pixels.shape[1:-1]
@@ -868,12 +955,18 @@ class CLIPType(Enum):
OMNIGEN2 = 17
QWEN_IMAGE = 18
HUNYUAN_IMAGE = 19
HUNYUAN_VIDEO_15 = 20
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
clip_data = []
for p in ckpt_paths:
clip_data.append(comfy.utils.load_torch_file(p, safe_load=True))
sd, metadata = comfy.utils.load_torch_file(p, safe_load=True, return_metadata=True)
if metadata is not None:
quant_metadata = metadata.get("_quantization_metadata", None)
if quant_metadata is not None:
sd["_quantization_metadata"] = quant_metadata
clip_data.append(sd)
return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options)
@@ -891,6 +984,10 @@ class TEModel(Enum):
QWEN25_7B = 11
BYT5_SMALL_GLYPH = 12
GEMMA_3_4B = 13
MISTRAL3_24B = 14
MISTRAL3_24B_PRUNED_FLUX2 = 15
QWEN3_4B = 16
def detect_te_model(sd):
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
@@ -923,6 +1020,15 @@ def detect_te_model(sd):
if weight.shape[0] == 512:
return TEModel.QWEN25_7B
if "model.layers.0.post_attention_layernorm.weight" in sd:
if 'model.layers.0.self_attn.q_norm.weight' in sd:
return TEModel.QWEN3_4B
weight = sd['model.layers.0.post_attention_layernorm.weight']
if weight.shape[0] == 5120:
if "model.layers.39.post_attention_layernorm.weight" in sd:
return TEModel.MISTRAL3_24B
else:
return TEModel.MISTRAL3_24B_PRUNED_FLUX2
return TEModel.LLAMA3_8
return None
@@ -1037,6 +1143,13 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
else:
clip_target.clip = comfy.text_encoders.qwen_image.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.qwen_image.QwenImageTokenizer
elif te_model == TEModel.MISTRAL3_24B or te_model == TEModel.MISTRAL3_24B_PRUNED_FLUX2:
clip_target.clip = comfy.text_encoders.flux.flux2_te(**llama_detect(clip_data), pruned=te_model == TEModel.MISTRAL3_24B_PRUNED_FLUX2)
clip_target.tokenizer = comfy.text_encoders.flux.Flux2Tokenizer
tokenizer_data["tekken_model"] = clip_data[0].get("tekken_model", None)
elif te_model == TEModel.QWEN3_4B:
clip_target.clip = comfy.text_encoders.z_image.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.z_image.ZImageTokenizer
else:
# clip_l
if clip_type == CLIPType.SD3:
@@ -1083,6 +1196,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
elif clip_type == CLIPType.HUNYUAN_IMAGE:
clip_target.clip = comfy.text_encoders.hunyuan_image.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer
elif clip_type == CLIPType.HUNYUAN_VIDEO_15:
clip_target.clip = comfy.text_encoders.hunyuan_image.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer
else:
clip_target.clip = sdxl_clip.SDXLClipModel
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
@@ -1095,6 +1211,8 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
parameters = 0
for c in clip_data:
if "_quantization_metadata" in c:
c.pop("_quantization_metadata")
parameters += comfy.utils.calculate_parameters(c)
tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options)
@@ -1233,7 +1351,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
return (model_patcher, clip, vae, clipvision)
def load_diffusion_model_state_dict(sd, model_options={}):
def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
"""
Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats.
@@ -1267,7 +1385,7 @@ def load_diffusion_model_state_dict(sd, model_options={}):
weight_dtype = comfy.utils.weight_dtype(sd)
load_device = model_management.get_torch_device()
model_config = model_detection.model_config_from_unet(sd, "")
model_config = model_detection.model_config_from_unet(sd, "", metadata=metadata)
if model_config is not None:
new_sd = sd
@@ -1301,7 +1419,10 @@ def load_diffusion_model_state_dict(sd, model_options={}):
else:
unet_dtype = dtype
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
if model_config.layer_quant_config is not None:
manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
else:
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
model_config.custom_operations = model_options.get("custom_operations", model_config.custom_operations)
if model_options.get("fp8_optimizations", False):
@@ -1317,8 +1438,8 @@ def load_diffusion_model_state_dict(sd, model_options={}):
def load_diffusion_model(unet_path, model_options={}):
sd = comfy.utils.load_torch_file(unet_path)
model = load_diffusion_model_state_dict(sd, model_options=model_options)
sd, metadata = comfy.utils.load_torch_file(unet_path, return_metadata=True)
model = load_diffusion_model_state_dict(sd, model_options=model_options, metadata=metadata)
if model is None:
logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path))
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd)))

View File

@@ -90,7 +90,6 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False,
return_projected_pooled=True, return_attention_masks=False, model_options={}): # clip-vit-base-patch32
super().__init__()
assert layer in self.LAYERS
if textmodel_json_config is None:
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
@@ -109,13 +108,23 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
operations = model_options.get("custom_operations", None)
scaled_fp8 = None
quantization_metadata = model_options.get("quantization_metadata", None)
if operations is None:
scaled_fp8 = model_options.get("scaled_fp8", None)
if scaled_fp8 is not None:
operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8)
layer_quant_config = None
if quantization_metadata is not None:
layer_quant_config = json.loads(quantization_metadata).get("layers", None)
if layer_quant_config is not None:
operations = comfy.ops.mixed_precision_ops(layer_quant_config, dtype, full_precision_mm=True)
logging.info(f"Using MixedPrecisionOps for text encoder: {len(layer_quant_config)} quantized layers")
else:
operations = comfy.ops.manual_cast
# Fallback to scaled_fp8_ops for backward compatibility
scaled_fp8 = model_options.get("scaled_fp8", None)
if scaled_fp8 is not None:
operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8)
else:
operations = comfy.ops.manual_cast
self.operations = operations
self.transformer = model_class(config, dtype, device, self.operations)
@@ -154,7 +163,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
def set_clip_options(self, options):
layer_idx = options.get("layer", self.layer_idx)
self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled)
if self.layer == "all":
if isinstance(self.layer, list) or self.layer == "all":
pass
elif layer_idx is None or abs(layer_idx) > self.num_layers:
self.layer = "last"
@@ -256,7 +265,9 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
if self.enable_attention_masks:
attention_mask_model = attention_mask
if self.layer == "all":
if isinstance(self.layer, list):
intermediate_output = self.layer
elif self.layer == "all":
intermediate_output = "all"
else:
intermediate_output = self.layer_idx
@@ -460,7 +471,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
return embed_out
class SDTokenizer:
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, tokenizer_data={}, tokenizer_args={}):
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, pad_left=False, tokenizer_data={}, tokenizer_args={}):
if tokenizer_path is None:
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args)
@@ -468,6 +479,7 @@ class SDTokenizer:
self.min_length = tokenizer_data.get("{}_min_length".format(embedding_key), min_length)
self.end_token = None
self.min_padding = min_padding
self.pad_left = pad_left
empty = self.tokenizer('')["input_ids"]
self.tokenizer_adds_end_token = has_end_token
@@ -522,6 +534,12 @@ class SDTokenizer:
return (embed, "{} {}".format(embedding_name[len(stripped):], leftover))
return (embed, leftover)
def pad_tokens(self, tokens, amount):
if self.pad_left:
for i in range(amount):
tokens.insert(0, (self.pad_token, 1.0, 0))
else:
tokens.extend([(self.pad_token, 1.0, 0)] * amount)
def tokenize_with_weights(self, text:str, return_word_ids=False, tokenizer_options={}, **kwargs):
'''
@@ -600,7 +618,7 @@ class SDTokenizer:
if self.end_token is not None:
batch.append((self.end_token, 1.0, 0))
if self.pad_to_max_length:
batch.extend([(self.pad_token, 1.0, 0)] * (remaining_length))
self.pad_tokens(batch, remaining_length)
#start new batch
batch = []
if self.start_token is not None:
@@ -614,11 +632,11 @@ class SDTokenizer:
if self.end_token is not None:
batch.append((self.end_token, 1.0, 0))
if min_padding is not None:
batch.extend([(self.pad_token, 1.0, 0)] * min_padding)
self.pad_tokens(batch, min_padding)
if self.pad_to_max_length and len(batch) < self.max_length:
batch.extend([(self.pad_token, 1.0, 0)] * (self.max_length - len(batch)))
self.pad_tokens(batch, self.max_length - len(batch))
if min_length is not None and len(batch) < min_length:
batch.extend([(self.pad_token, 1.0, 0)] * (min_length - len(batch)))
self.pad_tokens(batch, min_length - len(batch))
if not return_word_ids:
batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens]

View File

@@ -21,6 +21,7 @@ import comfy.text_encoders.ace
import comfy.text_encoders.omnigen2
import comfy.text_encoders.qwen_image
import comfy.text_encoders.hunyuan_image
import comfy.text_encoders.z_image
from . import supported_models_base
from . import latent_formats
@@ -741,6 +742,37 @@ class FluxSchnell(Flux):
out = model_base.Flux(self, model_type=model_base.ModelType.FLOW, device=device)
return out
class Flux2(Flux):
unet_config = {
"image_model": "flux2",
}
sampling_settings = {
"shift": 2.02,
}
unet_extra_config = {}
latent_format = latent_formats.Flux2
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def __init__(self, unet_config):
super().__init__(unet_config)
self.memory_usage_factor = self.memory_usage_factor * (2.0 * 2.0) * 2.36
def get_model(self, state_dict, prefix="", device=None):
out = model_base.Flux2(self, device=device)
return out
def clip_target(self, state_dict={}):
return None # TODO
pref = self.text_encoder_key_prefix[0]
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(**t5_detect))
class GenmoMochi(supported_models_base.BASE):
unet_config = {
"image_model": "mochi_preview",
@@ -963,7 +995,7 @@ class Lumina2(supported_models_base.BASE):
"shift": 6.0,
}
memory_usage_factor = 1.2
memory_usage_factor = 1.4
unet_extra_config = {}
latent_format = latent_formats.Flux
@@ -982,6 +1014,24 @@ class Lumina2(supported_models_base.BASE):
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}gemma2_2b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.lumina2.LuminaTokenizer, comfy.text_encoders.lumina2.te(**hunyuan_detect))
class ZImage(Lumina2):
unet_config = {
"image_model": "lumina2",
"dim": 3840,
}
sampling_settings = {
"multiplier": 1.0,
"shift": 3.0,
}
memory_usage_factor = 1.7
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_4b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.z_image.ZImageTokenizer, comfy.text_encoders.z_image.te(**hunyuan_detect))
class WAN21_T2V(supported_models_base.BASE):
unet_config = {
"image_model": "wan2.1",
@@ -1374,6 +1424,55 @@ class HunyuanImage21Refiner(HunyuanVideo):
out = model_base.HunyuanImage21Refiner(self, device=device)
return out
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage]
class HunyuanVideo15(HunyuanVideo):
unet_config = {
"image_model": "hunyuan_video",
"vision_in_dim": 1152,
}
sampling_settings = {
"shift": 7.0,
}
memory_usage_factor = 4.0 #TODO
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
latent_format = latent_formats.HunyuanVideo15
def get_model(self, state_dict, prefix="", device=None):
out = model_base.HunyuanVideo15(self, device=device)
return out
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect))
class HunyuanVideo15_SR_Distilled(HunyuanVideo):
unet_config = {
"image_model": "hunyuan_video",
"vision_in_dim": 1152,
"in_channels": 98,
}
sampling_settings = {
"shift": 2.0,
}
memory_usage_factor = 4.0 #TODO
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
latent_format = latent_formats.HunyuanVideo15
def get_model(self, state_dict, prefix="", device=None):
out = model_base.HunyuanVideo15_SR_Distilled(self, device=device)
return out
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect))
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2]
models += [SVD_img2vid]

View File

@@ -50,6 +50,7 @@ class BASE:
manual_cast_dtype = None
custom_operations = None
scaled_fp8 = None
layer_quant_config = None # Per-layer quantization configuration for mixed precision
optimizations = {"fp8": False}
@classmethod

171
comfy/taesd/taehv.py Normal file
View File

@@ -0,0 +1,171 @@
# Tiny AutoEncoder for HunyuanVideo and WanVideo https://github.com/madebyollin/taehv
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import tqdm
from collections import namedtuple, deque
import comfy.ops
operations=comfy.ops.disable_weight_init
DecoderResult = namedtuple("DecoderResult", ("frame", "memory"))
TWorkItem = namedtuple("TWorkItem", ("input_tensor", "block_index"))
def conv(n_in, n_out, **kwargs):
return operations.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
class Clamp(nn.Module):
def forward(self, x):
return torch.tanh(x / 3) * 3
class MemBlock(nn.Module):
def __init__(self, n_in, n_out, act_func):
super().__init__()
self.conv = nn.Sequential(conv(n_in * 2, n_out), act_func, conv(n_out, n_out), act_func, conv(n_out, n_out))
self.skip = operations.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
self.act = act_func
def forward(self, x, past):
return self.act(self.conv(torch.cat([x, past], 1)) + self.skip(x))
class TPool(nn.Module):
def __init__(self, n_f, stride):
super().__init__()
self.stride = stride
self.conv = operations.Conv2d(n_f*stride,n_f, 1, bias=False)
def forward(self, x):
_NT, C, H, W = x.shape
return self.conv(x.reshape(-1, self.stride * C, H, W))
class TGrow(nn.Module):
def __init__(self, n_f, stride):
super().__init__()
self.stride = stride
self.conv = operations.Conv2d(n_f, n_f*stride, 1, bias=False)
def forward(self, x):
_NT, C, H, W = x.shape
x = self.conv(x)
return x.reshape(-1, C, H, W)
def apply_model_with_memblocks(model, x, parallel, show_progress_bar):
B, T, C, H, W = x.shape
if parallel:
x = x.reshape(B*T, C, H, W)
# parallel over input timesteps, iterate over blocks
for b in tqdm(model, disable=not show_progress_bar):
if isinstance(b, MemBlock):
BT, C, H, W = x.shape
T = BT // B
_x = x.reshape(B, T, C, H, W)
mem = F.pad(_x, (0,0,0,0,0,0,1,0), value=0)[:,:T].reshape(x.shape)
x = b(x, mem)
else:
x = b(x)
BT, C, H, W = x.shape
T = BT // B
x = x.view(B, T, C, H, W)
else:
out = []
work_queue = deque([TWorkItem(xt, 0) for t, xt in enumerate(x.reshape(B, T * C, H, W).chunk(T, dim=1))])
progress_bar = tqdm(range(T), disable=not show_progress_bar)
mem = [None] * len(model)
while work_queue:
xt, i = work_queue.popleft()
if i == 0:
progress_bar.update(1)
if i == len(model):
out.append(xt)
del xt
else:
b = model[i]
if isinstance(b, MemBlock):
if mem[i] is None:
xt_new = b(xt, xt * 0)
mem[i] = xt.detach().clone()
else:
xt_new = b(xt, mem[i])
mem[i] = xt.detach().clone()
del xt
work_queue.appendleft(TWorkItem(xt_new, i+1))
elif isinstance(b, TPool):
if mem[i] is None:
mem[i] = []
mem[i].append(xt.detach().clone())
if len(mem[i]) == b.stride:
B, C, H, W = xt.shape
xt = b(torch.cat(mem[i], 1).view(B*b.stride, C, H, W))
mem[i] = []
work_queue.appendleft(TWorkItem(xt, i+1))
elif isinstance(b, TGrow):
xt = b(xt)
NT, C, H, W = xt.shape
for xt_next in reversed(xt.view(B, b.stride*C, H, W).chunk(b.stride, 1)):
work_queue.appendleft(TWorkItem(xt_next, i+1))
del xt
else:
xt = b(xt)
work_queue.appendleft(TWorkItem(xt, i+1))
progress_bar.close()
x = torch.stack(out, 1)
return x
class TAEHV(nn.Module):
def __init__(self, latent_channels, parallel=False, decoder_time_upscale=(True, True), decoder_space_upscale=(True, True, True), latent_format=None, show_progress_bar=True):
super().__init__()
self.image_channels = 3
self.patch_size = 1
self.latent_channels = latent_channels
self.parallel = parallel
self.latent_format = latent_format
self.show_progress_bar = show_progress_bar
self.process_in = latent_format().process_in if latent_format is not None else (lambda x: x)
self.process_out = latent_format().process_out if latent_format is not None else (lambda x: x)
if self.latent_channels in [48, 32]: # Wan 2.2 and HunyuanVideo1.5
self.patch_size = 2
if self.latent_channels == 32: # HunyuanVideo1.5
act_func = nn.LeakyReLU(0.2, inplace=True)
else: # HunyuanVideo, Wan 2.1
act_func = nn.ReLU(inplace=True)
self.encoder = nn.Sequential(
conv(self.image_channels*self.patch_size**2, 64), act_func,
TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
TPool(64, 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
conv(64, self.latent_channels),
)
n_f = [256, 128, 64, 64]
self.frames_to_trim = 2**sum(decoder_time_upscale) - 1
self.decoder = nn.Sequential(
Clamp(), conv(self.latent_channels, n_f[0]), act_func,
MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1), TGrow(n_f[0], 1), conv(n_f[0], n_f[1], bias=False),
MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1), TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1), conv(n_f[1], n_f[2], bias=False),
MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1), conv(n_f[2], n_f[3], bias=False),
act_func, conv(n_f[3], self.image_channels*self.patch_size**2),
)
@property
def show_progress_bar(self):
return self._show_progress_bar
@show_progress_bar.setter
def show_progress_bar(self, value):
self._show_progress_bar = value
def encode(self, x, **kwargs):
if self.patch_size > 1: x = F.pixel_unshuffle(x, self.patch_size)
x = x.movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
if x.shape[1] % 4 != 0:
# pad at end to multiple of 4
n_pad = 4 - x.shape[1] % 4
padding = x[:, -1:].repeat_interleave(n_pad, dim=1)
x = torch.cat([x, padding], 1)
x = apply_model_with_memblocks(self.encoder, x, self.parallel, self.show_progress_bar).movedim(2, 1)
return self.process_out(x)
def decode(self, x, **kwargs):
x = self.process_in(x).movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
x = apply_model_with_memblocks(self.decoder, x, self.parallel, self.show_progress_bar)
if self.patch_size > 1: x = F.pixel_shuffle(x, self.patch_size)
return x[:, self.frames_to_trim:].movedim(2, 1)

View File

@@ -1,10 +1,13 @@
from comfy import sd1_clip
import comfy.text_encoders.t5
import comfy.text_encoders.sd3_clip
import comfy.text_encoders.llama
import comfy.model_management
from transformers import T5TokenizerFast
from transformers import T5TokenizerFast, LlamaTokenizerFast
import torch
import os
import json
import base64
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
@@ -68,3 +71,106 @@ def flux_clip(dtype_t5=None, t5xxl_scaled_fp8=None):
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options)
return FluxClipModel_
def load_mistral_tokenizer(data):
if torch.is_tensor(data):
data = data.numpy().tobytes()
try:
from transformers.integrations.mistral import MistralConverter
except ModuleNotFoundError:
from transformers.models.pixtral.convert_pixtral_weights_to_hf import MistralConverter
mistral_vocab = json.loads(data)
special_tokens = {}
vocab = {}
max_vocab = mistral_vocab["config"]["default_vocab_size"]
max_vocab -= len(mistral_vocab["special_tokens"])
for w in mistral_vocab["vocab"]:
r = w["rank"]
if r >= max_vocab:
continue
vocab[base64.b64decode(w["token_bytes"])] = r
for w in mistral_vocab["special_tokens"]:
if "token_bytes" in w:
special_tokens[base64.b64decode(w["token_bytes"])] = w["rank"]
else:
special_tokens[w["token_str"]] = w["rank"]
all_special = []
for v in special_tokens:
all_special.append(v)
special_tokens.update(vocab)
vocab = special_tokens
return {"tokenizer_object": MistralConverter(vocab=vocab, additional_special_tokens=all_special).converted(), "legacy": False}
class MistralTokenizerClass:
@staticmethod
def from_pretrained(path, **kwargs):
return LlamaTokenizerFast(**kwargs)
class Mistral3Tokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
self.tekken_data = tokenizer_data.get("tekken_model", None)
super().__init__("", pad_with_end=False, embedding_size=5120, embedding_key='mistral3_24b', tokenizer_class=MistralTokenizerClass, has_end_token=False, pad_to_max_length=False, pad_token=11, max_length=99999999, min_length=1, pad_left=True, tokenizer_args=load_mistral_tokenizer(self.tekken_data), tokenizer_data=tokenizer_data)
def state_dict(self):
return {"tekken_model": self.tekken_data}
class Flux2Tokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="mistral3_24b", tokenizer=Mistral3Tokenizer)
self.llama_template = '[SYSTEM_PROMPT]You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object\nattribution and actions without speculation.[/SYSTEM_PROMPT][INST]{}[/INST]'
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs):
if llama_template is None:
llama_text = self.llama_template.format(text)
else:
llama_text = llama_template.format(text)
tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
return tokens
class Mistral3_24BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer=[10, 20, 30], layer_idx=None, dtype=None, attention_mask=True, model_options={}):
textmodel_json_config = {}
num_layers = model_options.get("num_layers", None)
if num_layers is not None:
textmodel_json_config["num_hidden_layers"] = num_layers
if num_layers < 40:
textmodel_json_config["final_norm"] = False
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 1, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Mistral3Small24B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class Flux2TEModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}, name="mistral3_24b", clip_model=Mistral3_24BModel):
super().__init__(device=device, dtype=dtype, name=name, clip_model=clip_model, model_options=model_options)
def encode_token_weights(self, token_weight_pairs):
out, pooled, extra = super().encode_token_weights(token_weight_pairs)
out = torch.stack((out[:, 0], out[:, 1], out[:, 2]), dim=1)
out = out.movedim(1, 2)
out = out.reshape(out.shape[0], out.shape[1], -1)
return out, pooled, extra
def flux2_te(dtype_llama=None, llama_scaled_fp8=None, llama_quantization_metadata=None, pruned=False):
class Flux2TEModel_(Flux2TEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["scaled_fp8"] = llama_scaled_fp8
if dtype_llama is not None:
dtype = dtype_llama
if llama_quantization_metadata is not None:
model_options["quantization_metadata"] = llama_quantization_metadata
if pruned:
model_options = model_options.copy()
model_options["num_layers"] = 30
super().__init__(device=device, dtype=dtype, model_options=model_options)
return Flux2TEModel_

View File

@@ -1,6 +1,7 @@
from comfy import sd1_clip
import comfy.model_management
import comfy.text_encoders.llama
from .hunyuan_image import HunyuanImageTokenizer
from transformers import LlamaTokenizerFast
import torch
import os
@@ -17,6 +18,9 @@ def llama_detect(state_dict, prefix=""):
if scaled_fp8_key in state_dict:
out["llama_scaled_fp8"] = state_dict[scaled_fp8_key].dtype
if "_quantization_metadata" in state_dict:
out["llama_quantization_metadata"] = state_dict["_quantization_metadata"]
return out
@@ -73,6 +77,14 @@ class HunyuanVideoTokenizer:
return {}
class HunyuanVideo15Tokenizer(HunyuanImageTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.llama_template = "<|im_start|>system\nYou are a helpful assistant. Describe the video by detailing the following aspects:\n1. The main content and theme of the video.\n2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.\n3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.\n4. background environment, light, style and atmosphere.\n5. camera angles, movements, and transitions used in the video.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
return super().tokenize_with_weights(text, return_word_ids, prevent_empty_text=True, **kwargs)
class HunyuanVideoClipModel(torch.nn.Module):
def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}):
super().__init__()

View File

@@ -32,6 +32,29 @@ class Llama2Config:
q_norm = None
k_norm = None
rope_scale = None
final_norm: bool = True
@dataclass
class Mistral3Small24BConfig:
vocab_size: int = 131072
hidden_size: int = 5120
intermediate_size: int = 32768
num_hidden_layers: int = 40
num_attention_heads: int = 32
num_key_value_heads: int = 8
max_position_embeddings: int = 8192
rms_norm_eps: float = 1e-5
rope_theta: float = 1000000000.0
transformer_type: str = "llama"
head_dim = 128
rms_norm_add = False
mlp_activation = "silu"
qkv_bias = False
rope_dims = None
q_norm = None
k_norm = None
rope_scale = None
final_norm: bool = True
@dataclass
class Qwen25_3BConfig:
@@ -53,6 +76,29 @@ class Qwen25_3BConfig:
q_norm = None
k_norm = None
rope_scale = None
final_norm: bool = True
@dataclass
class Qwen3_4BConfig:
vocab_size: int = 151936
hidden_size: int = 2560
intermediate_size: int = 9728
num_hidden_layers: int = 36
num_attention_heads: int = 32
num_key_value_heads: int = 8
max_position_embeddings: int = 40960
rms_norm_eps: float = 1e-6
rope_theta: float = 1000000.0
transformer_type: str = "llama"
head_dim = 128
rms_norm_add = False
mlp_activation = "silu"
qkv_bias = False
rope_dims = None
q_norm = "gemma3"
k_norm = "gemma3"
rope_scale = None
final_norm: bool = True
@dataclass
class Qwen25_7BVLI_Config:
@@ -74,6 +120,7 @@ class Qwen25_7BVLI_Config:
q_norm = None
k_norm = None
rope_scale = None
final_norm: bool = True
@dataclass
class Gemma2_2B_Config:
@@ -96,6 +143,7 @@ class Gemma2_2B_Config:
k_norm = None
sliding_attention = None
rope_scale = None
final_norm: bool = True
@dataclass
class Gemma3_4B_Config:
@@ -118,6 +166,7 @@ class Gemma3_4B_Config:
k_norm = "gemma3"
sliding_attention = [False, False, False, False, False, 1024]
rope_scale = [1.0, 8.0]
final_norm: bool = True
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
@@ -366,7 +415,12 @@ class Llama2_(nn.Module):
transformer(config, index=i, device=device, dtype=dtype, ops=ops)
for i in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
if config.final_norm:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
else:
self.norm = None
# self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[]):
@@ -402,8 +456,12 @@ class Llama2_(nn.Module):
intermediate = None
all_intermediate = None
only_layers = None
if intermediate_output is not None:
if intermediate_output == "all":
if isinstance(intermediate_output, list):
all_intermediate = []
only_layers = set(intermediate_output)
elif intermediate_output == "all":
all_intermediate = []
intermediate_output = None
elif intermediate_output < 0:
@@ -411,7 +469,8 @@ class Llama2_(nn.Module):
for i, layer in enumerate(self.layers):
if all_intermediate is not None:
all_intermediate.append(x.unsqueeze(1).clone())
if only_layers is None or (i in only_layers):
all_intermediate.append(x.unsqueeze(1).clone())
x = layer(
x=x,
attention_mask=mask,
@@ -421,14 +480,17 @@ class Llama2_(nn.Module):
if i == intermediate_output:
intermediate = x.clone()
x = self.norm(x)
if self.norm is not None:
x = self.norm(x)
if all_intermediate is not None:
all_intermediate.append(x.unsqueeze(1).clone())
if only_layers is None or ((i + 1) in only_layers):
all_intermediate.append(x.unsqueeze(1).clone())
if all_intermediate is not None:
intermediate = torch.cat(all_intermediate, dim=1)
if intermediate is not None and final_layer_norm_intermediate:
if intermediate is not None and final_layer_norm_intermediate and self.norm is not None:
intermediate = self.norm(intermediate)
return x, intermediate
@@ -453,6 +515,15 @@ class Llama2(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Mistral3Small24B(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Mistral3Small24BConfig(**config_dict)
self.num_layers = config.num_hidden_layers
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Qwen25_3B(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
@@ -462,6 +533,15 @@ class Qwen25_3B(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Qwen3_4B(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Qwen3_4BConfig(**config_dict)
self.num_layers = config.num_hidden_layers
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()

View File

@@ -179,36 +179,36 @@
"special": false
},
"151665": {
"content": "<|img|>",
"content": "<tool_response>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
"special": false
},
"151666": {
"content": "<|endofimg|>",
"content": "</tool_response>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
"special": false
},
"151667": {
"content": "<|meta|>",
"content": "<think>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
"special": false
},
"151668": {
"content": "<|endofmeta|>",
"content": "</think>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
"special": false
}
},
"additional_special_tokens": [

View File

@@ -17,12 +17,14 @@ class QwenImageTokenizer(sd1_clip.SD1Tokenizer):
self.llama_template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
self.llama_template_images = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], **kwargs):
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], prevent_empty_text=False, **kwargs):
skip_template = False
if text.startswith('<|im_start|>'):
skip_template = True
if text.startswith('<|start_header_id|>'):
skip_template = True
if prevent_empty_text and text == '':
text = ' '
if skip_template:
llama_text = text

View File

@@ -0,0 +1,48 @@
from transformers import Qwen2Tokenizer
import comfy.text_encoders.llama
from comfy import sd1_clip
import os
class Qwen3Tokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2560, embedding_key='qwen3_4b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)
class ZImageTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen3_4b", tokenizer=Qwen3Tokenizer)
self.llama_template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs):
if llama_template is None:
llama_text = self.llama_template.format(text)
else:
llama_text = llama_template.format(text)
tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
return tokens
class Qwen3_4BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class ZImageTEModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, name="qwen3_4b", clip_model=Qwen3_4BModel, model_options=model_options)
def te(dtype_llama=None, llama_scaled_fp8=None, llama_quantization_metadata=None):
class ZImageTEModel_(ZImageTEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["scaled_fp8"] = llama_scaled_fp8
if dtype_llama is not None:
dtype = dtype_llama
if llama_quantization_metadata is not None:
model_options["quantization_metadata"] = llama_quantization_metadata
super().__init__(device=device, dtype=dtype, model_options=model_options)
return ZImageTEModel_

View File

@@ -39,7 +39,11 @@ if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in
pass
ModelCheckpoint.__module__ = "pytorch_lightning.callbacks.model_checkpoint"
from numpy.core.multiarray import scalar
def scalar(*args, **kwargs):
from numpy.core.multiarray import scalar as sc
return sc(*args, **kwargs)
scalar.__module__ = "numpy.core.multiarray"
from numpy import dtype
from numpy.dtypes import Float64DType
from _codecs import encode
@@ -671,6 +675,72 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
return key_map
def z_image_to_diffusers(mmdit_config, output_prefix=""):
n_layers = mmdit_config.get("n_layers", 0)
hidden_size = mmdit_config.get("dim", 0)
n_context_refiner = mmdit_config.get("n_refiner_layers", 2)
n_noise_refiner = mmdit_config.get("n_refiner_layers", 2)
key_map = {}
def add_block_keys(prefix_from, prefix_to, has_adaln=True):
for end in ("weight", "bias"):
k = "{}.attention.".format(prefix_from)
qkv = "{}.attention.qkv.{}".format(prefix_to, end)
key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, hidden_size))
key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))
block_map = {
"attention.norm_q.weight": "attention.q_norm.weight",
"attention.norm_k.weight": "attention.k_norm.weight",
"attention.to_out.0.weight": "attention.out.weight",
"attention.to_out.0.bias": "attention.out.bias",
"attention_norm1.weight": "attention_norm1.weight",
"attention_norm2.weight": "attention_norm2.weight",
"feed_forward.w1.weight": "feed_forward.w1.weight",
"feed_forward.w2.weight": "feed_forward.w2.weight",
"feed_forward.w3.weight": "feed_forward.w3.weight",
"ffn_norm1.weight": "ffn_norm1.weight",
"ffn_norm2.weight": "ffn_norm2.weight",
}
if has_adaln:
block_map["adaLN_modulation.0.weight"] = "adaLN_modulation.0.weight"
block_map["adaLN_modulation.0.bias"] = "adaLN_modulation.0.bias"
for k, v in block_map.items():
key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, v)
for i in range(n_layers):
add_block_keys("layers.{}".format(i), "{}layers.{}".format(output_prefix, i))
for i in range(n_context_refiner):
add_block_keys("context_refiner.{}".format(i), "{}context_refiner.{}".format(output_prefix, i))
for i in range(n_noise_refiner):
add_block_keys("noise_refiner.{}".format(i), "{}noise_refiner.{}".format(output_prefix, i))
MAP_BASIC = [
("final_layer.linear.weight", "all_final_layer.2-1.linear.weight"),
("final_layer.linear.bias", "all_final_layer.2-1.linear.bias"),
("final_layer.adaLN_modulation.1.weight", "all_final_layer.2-1.adaLN_modulation.1.weight"),
("final_layer.adaLN_modulation.1.bias", "all_final_layer.2-1.adaLN_modulation.1.bias"),
("x_embedder.weight", "all_x_embedder.2-1.weight"),
("x_embedder.bias", "all_x_embedder.2-1.bias"),
("x_pad_token", "x_pad_token"),
("cap_embedder.0.weight", "cap_embedder.0.weight"),
("cap_embedder.1.weight", "cap_embedder.1.weight"),
("cap_embedder.1.bias", "cap_embedder.1.bias"),
("cap_pad_token", "cap_pad_token"),
("t_embedder.mlp.0.weight", "t_embedder.mlp.0.weight"),
("t_embedder.mlp.0.bias", "t_embedder.mlp.0.bias"),
("t_embedder.mlp.2.weight", "t_embedder.mlp.2.weight"),
("t_embedder.mlp.2.bias", "t_embedder.mlp.2.bias"),
]
for c, diffusers in MAP_BASIC:
key_map[diffusers] = "{}{}".format(output_prefix, c)
return key_map
def repeat_to_batch_size(tensor, batch_size, dim=0):
if tensor.shape[dim] > batch_size:
return tensor.narrow(dim, 0, batch_size)
@@ -1102,3 +1172,25 @@ def upscale_dit_mask(mask: torch.Tensor, img_size_in, img_size_out):
dim=1
)
return out
def pack_latents(latents):
latent_shapes = []
tensors = []
for tensor in latents:
latent_shapes.append(tensor.shape)
tensors.append(tensor.reshape(tensor.shape[0], 1, -1))
latent = torch.cat(tensors, dim=-1)
return latent, latent_shapes
def unpack_latents(combined_latent, latent_shapes):
if len(latent_shapes) > 1:
output_tensors = []
for shape in latent_shapes:
cut = math.prod(shape[1:])
tens = combined_latent[:, :, :cut]
combined_latent = combined_latent[:, :, cut:]
output_tensors.append(tens.reshape([tens.shape[0]] + list(shape)[1:]))
else:
output_tensors = combined_latent
return output_tensors

View File

@@ -194,6 +194,7 @@ class LoRAAdapter(WeightAdapterBase):
lora_diff = torch.mm(
mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)
).reshape(weight.shape)
del mat1, mat2
if dora_scale is not None:
weight = weight_decompose(
dora_scale,

View File

@@ -8,7 +8,7 @@ import os
import textwrap
import threading
from enum import Enum
from typing import Optional, Type, get_origin, get_args
from typing import Optional, Type, get_origin, get_args, get_type_hints
class TypeTracker:
@@ -220,11 +220,18 @@ class AsyncToSyncConverter:
self._async_instance = async_class(*args, **kwargs)
# Handle annotated class attributes (like execution: Execution)
# Get all annotations from the class hierarchy
all_annotations = {}
for base_class in reversed(inspect.getmro(async_class)):
if hasattr(base_class, "__annotations__"):
all_annotations.update(base_class.__annotations__)
# Get all annotations from the class hierarchy and resolve string annotations
try:
# get_type_hints resolves string annotations to actual type objects
# This handles classes using 'from __future__ import annotations'
all_annotations = get_type_hints(async_class)
except Exception:
# Fallback to raw annotations if get_type_hints fails
# (e.g., for undefined forward references)
all_annotations = {}
for base_class in reversed(inspect.getmro(async_class)):
if hasattr(base_class, "__annotations__"):
all_annotations.update(base_class.__annotations__)
# For each annotated attribute, check if it needs to be created or wrapped
for attr_name, attr_type in all_annotations.items():
@@ -625,15 +632,19 @@ class AsyncToSyncConverter:
"""Extract class attributes that are classes themselves."""
class_attributes = []
# Get resolved type hints to handle string annotations
try:
type_hints = get_type_hints(async_class)
except Exception:
type_hints = {}
# Look for class attributes that are classes
for name, attr in sorted(inspect.getmembers(async_class)):
if isinstance(attr, type) and not name.startswith("_"):
class_attributes.append((name, attr))
elif (
hasattr(async_class, "__annotations__")
and name in async_class.__annotations__
):
annotation = async_class.__annotations__[name]
elif name in type_hints:
# Use resolved type hint instead of raw annotation
annotation = type_hints[name]
if isinstance(annotation, type):
class_attributes.append((name, annotation))
@@ -908,11 +919,15 @@ class AsyncToSyncConverter:
attribute_mappings = {}
# First check annotations for typed attributes (including from parent classes)
# Collect all annotations from the class hierarchy
all_annotations = {}
for base_class in reversed(inspect.getmro(async_class)):
if hasattr(base_class, "__annotations__"):
all_annotations.update(base_class.__annotations__)
# Resolve string annotations to actual types
try:
all_annotations = get_type_hints(async_class)
except Exception:
# Fallback to raw annotations
all_annotations = {}
for base_class in reversed(inspect.getmro(async_class)):
if hasattr(base_class, "__annotations__"):
all_annotations.update(base_class.__annotations__)
for attr_name, attr_type in sorted(all_annotations.items()):
for class_name, class_type in class_attributes:

View File

@@ -7,9 +7,9 @@ from comfy_api.internal.singleton import ProxiedSingleton
from comfy_api.internal.async_to_sync import create_sync_class
from comfy_api.latest._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents
from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents
from comfy_api.latest._io import _IO as io #noqa: F401
from comfy_api.latest._ui import _UI as ui #noqa: F401
from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL
from . import _io as io
from . import _ui as ui
# from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401
from comfy_execution.utils import get_executing_context
from comfy_execution.progress import get_progress_state, PreviewImageTuple
@@ -104,6 +104,8 @@ class Types:
VideoCodec = VideoCodec
VideoContainer = VideoContainer
VideoComponents = VideoComponents
MESH = MESH
VOXEL = VOXEL
ComfyAPI = ComfyAPI_latest
@@ -114,6 +116,10 @@ if TYPE_CHECKING:
ComfyAPISync: Type[comfy_api.latest.generated.ComfyAPISyncStub.ComfyAPISyncStub]
ComfyAPISync = create_sync_class(ComfyAPI_latest)
# create new aliases for io and ui
IO = io
UI = ui
__all__ = [
"ComfyAPI",
"ComfyAPISync",
@@ -121,4 +127,8 @@ __all__ = [
"InputImpl",
"Types",
"ComfyExtension",
"io",
"IO",
"ui",
"UI",
]

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Optional, Union
from fractions import Fraction
from typing import Optional, Union, IO
import io
import av
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
@@ -23,7 +24,7 @@ class VideoInput(ABC):
@abstractmethod
def save_to(
self,
path: str,
path: Union[str, IO[bytes]],
format: VideoContainer = VideoContainer.AUTO,
codec: VideoCodec = VideoCodec.AUTO,
metadata: Optional[dict] = None
@@ -72,6 +73,33 @@ class VideoInput(ABC):
frame_count = components.images.shape[0]
return float(frame_count / components.frame_rate)
def get_frame_count(self) -> int:
"""
Returns the number of frames in the video.
Default implementation uses :meth:`get_components`, which may require
loading all frames into memory. File-based implementations should
override this method and use container/stream metadata instead.
Returns:
Total number of frames as an integer.
"""
return int(self.get_components().images.shape[0])
def get_frame_rate(self) -> Fraction:
"""
Returns the frame rate of the video.
Default implementation materializes the video into memory via
`get_components()`. Subclasses that can inspect the underlying
container (e.g. `VideoFromFile`) should override this with a more
efficient implementation.
Returns:
Frame rate as a Fraction.
"""
return self.get_components().frame_rate
def get_container_format(self) -> str:
"""
Returns the container format of the video (e.g., 'mp4', 'mov', 'avi').

View File

@@ -121,6 +121,71 @@ class VideoFromFile(VideoInput):
raise ValueError(f"Could not determine duration for file '{self.__file}'")
def get_frame_count(self) -> int:
"""
Returns the number of frames in the video without materializing them as
torch tensors.
"""
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0)
with av.open(self.__file, mode="r") as container:
video_stream = self._get_first_video_stream(container)
# 1. Prefer the frames field if available
if video_stream.frames and video_stream.frames > 0:
return int(video_stream.frames)
# 2. Try to estimate from duration and average_rate using only metadata
if container.duration is not None and video_stream.average_rate:
duration_seconds = float(container.duration / av.time_base)
estimated_frames = int(round(duration_seconds * float(video_stream.average_rate)))
if estimated_frames > 0:
return estimated_frames
if (
getattr(video_stream, "duration", None) is not None
and getattr(video_stream, "time_base", None) is not None
and video_stream.average_rate
):
duration_seconds = float(video_stream.duration * video_stream.time_base)
estimated_frames = int(round(duration_seconds * float(video_stream.average_rate)))
if estimated_frames > 0:
return estimated_frames
# 3. Last resort: decode frames and count them (streaming)
frame_count = 0
container.seek(0)
for packet in container.demux(video_stream):
for _ in packet.decode():
frame_count += 1
if frame_count == 0:
raise ValueError(f"Could not determine frame count for file '{self.__file}'")
return frame_count
def get_frame_rate(self) -> Fraction:
"""
Returns the average frame rate of the video using container metadata
without decoding all frames.
"""
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0)
with av.open(self.__file, mode="r") as container:
video_stream = self._get_first_video_stream(container)
# Preferred: use PyAV's average_rate (usually already a Fraction-like)
if video_stream.average_rate:
return Fraction(video_stream.average_rate)
# Fallback: estimate from frames + duration if available
if video_stream.frames and container.duration:
duration_seconds = float(container.duration / av.time_base)
if duration_seconds > 0:
return Fraction(video_stream.frames / duration_seconds).limit_denominator()
# Last resort: match get_components_internal default
return Fraction(1)
def get_container_format(self) -> str:
"""
Returns the container format of the video (e.g., 'mp4', 'mov', 'avi').
@@ -238,6 +303,13 @@ class VideoFromFile(VideoInput):
packet.stream = stream_map[packet.stream]
output_container.mux(packet)
def _get_first_video_stream(self, container: InputContainer):
video_stream = next((s for s in container.streams if s.type == "video"), None)
if video_stream is None:
raise ValueError(f"No video stream found in file '{self.__file}'")
return video_stream
class VideoFromComponents(VideoInput):
"""
Class representing video input from tensors.
@@ -264,7 +336,10 @@ class VideoFromComponents(VideoInput):
raise ValueError("Only MP4 format is supported for now")
if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
raise ValueError("Only H264 codec is supported for now")
with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}) as output:
extra_kwargs = {}
if isinstance(format, VideoContainer) and format != VideoContainer.AUTO:
extra_kwargs["format"] = format.value
with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}, **extra_kwargs) as output:
# Add metadata before writing any streams
if metadata is not None:
for key, value in metadata.items():

View File

@@ -27,6 +27,7 @@ from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classpr
prune_dict, shallow_clone_class)
from comfy_api.latest._resources import Resources, ResourcesLocal
from comfy_execution.graph_utils import ExecutionBlocker
from ._util import MESH, VOXEL
# from comfy_extras.nodes_images import SVG as SVG_ # NOTE: needs to be moved before can be imported due to circular reference
@@ -336,11 +337,25 @@ class Combo(ComfyTypeIO):
class Input(WidgetInput):
"""Combo input (dropdown)."""
Type = str
def __init__(self, id: str, options: list[str]=None, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: str=None, control_after_generate: bool=None,
upload: UploadType=None, image_folder: FolderType=None,
remote: RemoteOptions=None,
socketless: bool=None):
def __init__(
self,
id: str,
options: list[str] | list[int] | type[Enum] = None,
display_name: str=None,
optional=False,
tooltip: str=None,
lazy: bool=None,
default: str | int | Enum = None,
control_after_generate: bool=None,
upload: UploadType=None,
image_folder: FolderType=None,
remote: RemoteOptions=None,
socketless: bool=None,
):
if isinstance(options, type) and issubclass(options, Enum):
options = [v.value for v in options]
if isinstance(default, Enum):
default = default.value
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless)
self.multiselect = False
self.options = options
@@ -614,6 +629,10 @@ class UpscaleModel(ComfyTypeIO):
if TYPE_CHECKING:
Type = ImageModelDescriptor
@comfytype(io_type="LATENT_UPSCALE_MODEL")
class LatentUpscaleModel(ComfyTypeIO):
Type = Any
@comfytype(io_type="AUDIO")
class Audio(ComfyTypeIO):
class AudioDict(TypedDict):
@@ -642,11 +661,11 @@ class LossMap(ComfyTypeIO):
@comfytype(io_type="VOXEL")
class Voxel(ComfyTypeIO):
Type = Any # TODO: VOXEL class is defined in comfy_extras/nodes_hunyuan3d.py; should be moved to somewhere else before referenced directly in v3
Type = VOXEL
@comfytype(io_type="MESH")
class Mesh(ComfyTypeIO):
Type = Any # TODO: MESH class is defined in comfy_extras/nodes_hunyuan3d.py; should be moved to somewhere else before referenced directly in v3
Type = MESH
@comfytype(io_type="HOOKS")
class Hooks(ComfyTypeIO):
@@ -1568,78 +1587,78 @@ class _UIOutput(ABC):
...
class _IO:
FolderType = FolderType
UploadType = UploadType
RemoteOptions = RemoteOptions
NumberDisplay = NumberDisplay
__all__ = [
"FolderType",
"UploadType",
"RemoteOptions",
"NumberDisplay",
comfytype = staticmethod(comfytype)
Custom = staticmethod(Custom)
Input = Input
WidgetInput = WidgetInput
Output = Output
ComfyTypeI = ComfyTypeI
ComfyTypeIO = ComfyTypeIO
#---------------------------------
"comfytype",
"Custom",
"Input",
"WidgetInput",
"Output",
"ComfyTypeI",
"ComfyTypeIO",
# Supported Types
Boolean = Boolean
Int = Int
Float = Float
String = String
Combo = Combo
MultiCombo = MultiCombo
Image = Image
WanCameraEmbedding = WanCameraEmbedding
Webcam = Webcam
Mask = Mask
Latent = Latent
Conditioning = Conditioning
Sampler = Sampler
Sigmas = Sigmas
Noise = Noise
Guider = Guider
Clip = Clip
ControlNet = ControlNet
Vae = Vae
Model = Model
ClipVision = ClipVision
ClipVisionOutput = ClipVisionOutput
AudioEncoder = AudioEncoder
AudioEncoderOutput = AudioEncoderOutput
StyleModel = StyleModel
Gligen = Gligen
UpscaleModel = UpscaleModel
Audio = Audio
Video = Video
SVG = SVG
LoraModel = LoraModel
LossMap = LossMap
Voxel = Voxel
Mesh = Mesh
Hooks = Hooks
HookKeyframes = HookKeyframes
TimestepsRange = TimestepsRange
LatentOperation = LatentOperation
FlowControl = FlowControl
Accumulation = Accumulation
Load3DCamera = Load3DCamera
Load3D = Load3D
Load3DAnimation = Load3DAnimation
Photomaker = Photomaker
Point = Point
FaceAnalysis = FaceAnalysis
BBOX = BBOX
SEGS = SEGS
AnyType = AnyType
MultiType = MultiType
#---------------------------------
HiddenHolder = HiddenHolder
Hidden = Hidden
NodeInfoV1 = NodeInfoV1
NodeInfoV3 = NodeInfoV3
Schema = Schema
ComfyNode = ComfyNode
NodeOutput = NodeOutput
add_to_dict_v1 = staticmethod(add_to_dict_v1)
add_to_dict_v3 = staticmethod(add_to_dict_v3)
"Boolean",
"Int",
"Float",
"String",
"Combo",
"MultiCombo",
"Image",
"WanCameraEmbedding",
"Webcam",
"Mask",
"Latent",
"Conditioning",
"Sampler",
"Sigmas",
"Noise",
"Guider",
"Clip",
"ControlNet",
"Vae",
"Model",
"ClipVision",
"ClipVisionOutput",
"AudioEncoder",
"AudioEncoderOutput",
"StyleModel",
"Gligen",
"UpscaleModel",
"Audio",
"Video",
"SVG",
"LoraModel",
"LossMap",
"Voxel",
"Mesh",
"Hooks",
"HookKeyframes",
"TimestepsRange",
"LatentOperation",
"FlowControl",
"Accumulation",
"Load3DCamera",
"Load3D",
"Load3DAnimation",
"Photomaker",
"Point",
"FaceAnalysis",
"BBOX",
"SEGS",
"AnyType",
"MultiType",
# Other classes
"HiddenHolder",
"Hidden",
"NodeInfoV1",
"NodeInfoV3",
"Schema",
"ComfyNode",
"NodeOutput",
"add_to_dict_v1",
"add_to_dict_v3",
]

View File

@@ -449,15 +449,16 @@ class PreviewText(_UIOutput):
return {"text": (self.value,)}
class _UI:
SavedResult = SavedResult
SavedImages = SavedImages
SavedAudios = SavedAudios
ImageSaveHelper = ImageSaveHelper
AudioSaveHelper = AudioSaveHelper
PreviewImage = PreviewImage
PreviewMask = PreviewMask
PreviewAudio = PreviewAudio
PreviewVideo = PreviewVideo
PreviewUI3D = PreviewUI3D
PreviewText = PreviewText
__all__ = [
"SavedResult",
"SavedImages",
"SavedAudios",
"ImageSaveHelper",
"AudioSaveHelper",
"PreviewImage",
"PreviewMask",
"PreviewAudio",
"PreviewVideo",
"PreviewUI3D",
"PreviewText",
]

View File

@@ -1,8 +1,11 @@
from .video_types import VideoContainer, VideoCodec, VideoComponents
from .geometry_types import VOXEL, MESH
__all__ = [
# Utility Types
"VideoContainer",
"VideoCodec",
"VideoComponents",
"VOXEL",
"MESH",
]

View File

@@ -0,0 +1,12 @@
import torch
class VOXEL:
def __init__(self, data: torch.Tensor):
self.data = data
class MESH:
def __init__(self, vertices: torch.Tensor, faces: torch.Tensor):
self.vertices = vertices
self.faces = faces

View File

@@ -1,704 +0,0 @@
from __future__ import annotations
import aiohttp
import io
import logging
import mimetypes
from typing import Optional, Union
from comfy.utils import common_upscale
from comfy_api.input_impl import VideoFromFile
from comfy_api.util import VideoContainer, VideoCodec
from comfy_api.input.video_types import VideoInput
from comfy_api.input.basic_types import AudioInput
from comfy_api_nodes.apis.client import (
ApiClient,
ApiEndpoint,
HttpMethod,
SynchronousOperation,
UploadRequest,
UploadResponse,
)
from server import PromptServer
from comfy.cli_args import args
import numpy as np
from PIL import Image
import torch
import math
import base64
import uuid
from io import BytesIO
import av
async def download_url_to_video_output(
video_url: str, timeout: int = None, auth_kwargs: Optional[dict[str, str]] = None
) -> VideoFromFile:
"""Downloads a video from a URL and returns a `VIDEO` output.
Args:
video_url: The URL of the video to download.
Returns:
A Comfy node `VIDEO` output.
"""
video_io = await download_url_to_bytesio(video_url, timeout, auth_kwargs=auth_kwargs)
if video_io is None:
error_msg = f"Failed to download video from {video_url}"
logging.error(error_msg)
raise ValueError(error_msg)
return VideoFromFile(video_io)
def downscale_image_tensor(image, total_pixels=1536 * 1024) -> torch.Tensor:
"""Downscale input image tensor to roughly the specified total pixels."""
samples = image.movedim(-1, 1)
total = int(total_pixels)
scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
if scale_by >= 1:
return image
width = round(samples.shape[3] * scale_by)
height = round(samples.shape[2] * scale_by)
s = common_upscale(samples, width, height, "lanczos", "disabled")
s = s.movedim(1, -1)
return s
async def validate_and_cast_response(
response, timeout: int = None, node_id: Union[str, None] = None
) -> torch.Tensor:
"""Validates and casts a response to a torch.Tensor.
Args:
response: The response to validate and cast.
timeout: Request timeout in seconds. Defaults to None (no timeout).
Returns:
A torch.Tensor representing the image (1, H, W, C).
Raises:
ValueError: If the response is not valid.
"""
# validate raw JSON response
data = response.data
if not data or len(data) == 0:
raise ValueError("No images returned from API endpoint")
# Initialize list to store image tensors
image_tensors: list[torch.Tensor] = []
# Process each image in the data array
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=timeout)) as session:
for img_data in data:
img_bytes: bytes
if img_data.b64_json:
img_bytes = base64.b64decode(img_data.b64_json)
elif img_data.url:
if node_id:
PromptServer.instance.send_progress_text(f"Result URL: {img_data.url}", node_id)
async with session.get(img_data.url) as resp:
if resp.status != 200:
raise ValueError("Failed to download generated image")
img_bytes = await resp.read()
else:
raise ValueError("Invalid image payload neither URL nor base64 data present.")
pil_img = Image.open(BytesIO(img_bytes)).convert("RGBA")
arr = np.asarray(pil_img).astype(np.float32) / 255.0
image_tensors.append(torch.from_numpy(arr))
return torch.stack(image_tensors, dim=0)
def validate_aspect_ratio(
aspect_ratio: str,
minimum_ratio: float,
maximum_ratio: float,
minimum_ratio_str: str,
maximum_ratio_str: str,
) -> float:
"""Validates and casts an aspect ratio string to a float.
Args:
aspect_ratio: The aspect ratio string to validate.
minimum_ratio: The minimum aspect ratio.
maximum_ratio: The maximum aspect ratio.
minimum_ratio_str: The minimum aspect ratio string.
maximum_ratio_str: The maximum aspect ratio string.
Returns:
The validated and cast aspect ratio.
Raises:
Exception: If the aspect ratio is not valid.
"""
# get ratio values
numbers = aspect_ratio.split(":")
if len(numbers) != 2:
raise TypeError(
f"Aspect ratio must be in the format X:Y, such as 16:9, but was {aspect_ratio}."
)
try:
numerator = int(numbers[0])
denominator = int(numbers[1])
except ValueError as exc:
raise TypeError(
f"Aspect ratio must contain numbers separated by ':', such as 16:9, but was {aspect_ratio}."
) from exc
calculated_ratio = numerator / denominator
# if not close to minimum and maximum, check bounds
if not math.isclose(calculated_ratio, minimum_ratio) or not math.isclose(
calculated_ratio, maximum_ratio
):
if calculated_ratio < minimum_ratio:
raise TypeError(
f"Aspect ratio cannot reduce to any less than {minimum_ratio_str} ({minimum_ratio}), but was {aspect_ratio} ({calculated_ratio})."
)
if calculated_ratio > maximum_ratio:
raise TypeError(
f"Aspect ratio cannot reduce to any greater than {maximum_ratio_str} ({maximum_ratio}), but was {aspect_ratio} ({calculated_ratio})."
)
return aspect_ratio
def mimetype_to_extension(mime_type: str) -> str:
"""Converts a MIME type to a file extension."""
return mime_type.split("/")[-1].lower()
async def download_url_to_bytesio(
url: str, timeout: int = None, auth_kwargs: Optional[dict[str, str]] = None
) -> BytesIO:
"""Downloads content from a URL using requests and returns it as BytesIO.
Args:
url: The URL to download.
timeout: Request timeout in seconds. Defaults to None (no timeout).
Returns:
BytesIO object containing the downloaded content.
"""
headers = {}
if url.startswith("/proxy/"):
url = str(args.comfy_api_base).rstrip("/") + url
auth_token = auth_kwargs.get("auth_token")
comfy_api_key = auth_kwargs.get("comfy_api_key")
if auth_token:
headers["Authorization"] = f"Bearer {auth_token}"
elif comfy_api_key:
headers["X-API-KEY"] = comfy_api_key
timeout_cfg = aiohttp.ClientTimeout(total=timeout) if timeout else None
async with aiohttp.ClientSession(timeout=timeout_cfg) as session:
async with session.get(url, headers=headers) as resp:
resp.raise_for_status() # Raises HTTPError for bad responses (4XX or 5XX)
return BytesIO(await resp.read())
def bytesio_to_image_tensor(image_bytesio: BytesIO, mode: str = "RGBA") -> torch.Tensor:
"""Converts image data from BytesIO to a torch.Tensor.
Args:
image_bytesio: BytesIO object containing the image data.
mode: The PIL mode to convert the image to (e.g., "RGB", "RGBA").
Returns:
A torch.Tensor representing the image (1, H, W, C).
Raises:
PIL.UnidentifiedImageError: If the image data cannot be identified.
ValueError: If the specified mode is invalid.
"""
image = Image.open(image_bytesio)
image = image.convert(mode)
image_array = np.array(image).astype(np.float32) / 255.0
return torch.from_numpy(image_array).unsqueeze(0)
async def download_url_to_image_tensor(url: str, timeout: int = None) -> torch.Tensor:
"""Downloads an image from a URL and returns a [B, H, W, C] tensor."""
image_bytesio = await download_url_to_bytesio(url, timeout)
return bytesio_to_image_tensor(image_bytesio)
def process_image_response(response_content: bytes | str) -> torch.Tensor:
"""Uses content from a Response object and converts it to a torch.Tensor"""
return bytesio_to_image_tensor(BytesIO(response_content))
def _tensor_to_pil(image: torch.Tensor, total_pixels: int = 2048 * 2048) -> Image.Image:
"""Converts a single torch.Tensor image [H, W, C] to a PIL Image, optionally downscaling."""
if len(image.shape) > 3:
image = image[0]
# TODO: remove alpha if not allowed and present
input_tensor = image.cpu()
input_tensor = downscale_image_tensor(
input_tensor.unsqueeze(0), total_pixels=total_pixels
).squeeze()
image_np = (input_tensor.numpy() * 255).astype(np.uint8)
img = Image.fromarray(image_np)
return img
def _pil_to_bytesio(img: Image.Image, mime_type: str = "image/png") -> BytesIO:
"""Converts a PIL Image to a BytesIO object."""
if not mime_type:
mime_type = "image/png"
img_byte_arr = io.BytesIO()
# Derive PIL format from MIME type (e.g., 'image/png' -> 'PNG')
pil_format = mime_type.split("/")[-1].upper()
if pil_format == "JPG":
pil_format = "JPEG"
img.save(img_byte_arr, format=pil_format)
img_byte_arr.seek(0)
return img_byte_arr
def tensor_to_bytesio(
image: torch.Tensor,
name: Optional[str] = None,
total_pixels: int = 2048 * 2048,
mime_type: str = "image/png",
) -> BytesIO:
"""Converts a torch.Tensor image to a named BytesIO object.
Args:
image: Input torch.Tensor image.
name: Optional filename for the BytesIO object.
total_pixels: Maximum total pixels for potential downscaling.
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4').
Returns:
Named BytesIO object containing the image data.
"""
if not mime_type:
mime_type = "image/png"
pil_image = _tensor_to_pil(image, total_pixels=total_pixels)
img_binary = _pil_to_bytesio(pil_image, mime_type=mime_type)
img_binary.name = (
f"{name if name else uuid.uuid4()}.{mimetype_to_extension(mime_type)}"
)
return img_binary
def tensor_to_base64_string(
image_tensor: torch.Tensor,
total_pixels: int = 2048 * 2048,
mime_type: str = "image/png",
) -> str:
"""Convert [B, H, W, C] or [H, W, C] tensor to a base64 string.
Args:
image_tensor: Input torch.Tensor image.
total_pixels: Maximum total pixels for potential downscaling.
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4').
Returns:
Base64 encoded string of the image.
"""
pil_image = _tensor_to_pil(image_tensor, total_pixels=total_pixels)
img_byte_arr = _pil_to_bytesio(pil_image, mime_type=mime_type)
img_bytes = img_byte_arr.getvalue()
# Encode bytes to base64 string
base64_encoded_string = base64.b64encode(img_bytes).decode("utf-8")
return base64_encoded_string
def tensor_to_data_uri(
image_tensor: torch.Tensor,
total_pixels: int = 2048 * 2048,
mime_type: str = "image/png",
) -> str:
"""Converts a tensor image to a Data URI string.
Args:
image_tensor: Input torch.Tensor image.
total_pixels: Maximum total pixels for potential downscaling.
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp').
Returns:
Data URI string (e.g., 'data:image/png;base64,...').
"""
base64_string = tensor_to_base64_string(image_tensor, total_pixels, mime_type)
return f"data:{mime_type};base64,{base64_string}"
def text_filepath_to_base64_string(filepath: str) -> str:
"""Converts a text file to a base64 string."""
with open(filepath, "rb") as f:
file_content = f.read()
return base64.b64encode(file_content).decode("utf-8")
def text_filepath_to_data_uri(filepath: str) -> str:
"""Converts a text file to a data URI."""
base64_string = text_filepath_to_base64_string(filepath)
mime_type, _ = mimetypes.guess_type(filepath)
if mime_type is None:
mime_type = "application/octet-stream"
return f"data:{mime_type};base64,{base64_string}"
async def upload_file_to_comfyapi(
file_bytes_io: BytesIO,
filename: str,
upload_mime_type: Optional[str],
auth_kwargs: Optional[dict[str, str]] = None,
) -> str:
"""
Uploads a single file to ComfyUI API and returns its download URL.
Args:
file_bytes_io: BytesIO object containing the file data.
filename: The filename of the file.
upload_mime_type: MIME type of the file.
auth_kwargs: Optional authentication token(s).
Returns:
The download URL for the uploaded file.
"""
if upload_mime_type is None:
request_object = UploadRequest(file_name=filename)
else:
request_object = UploadRequest(file_name=filename, content_type=upload_mime_type)
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/customers/storage",
method=HttpMethod.POST,
request_model=UploadRequest,
response_model=UploadResponse,
),
request=request_object,
auth_kwargs=auth_kwargs,
)
response: UploadResponse = await operation.execute()
await ApiClient.upload_file(response.upload_url, file_bytes_io, content_type=upload_mime_type)
return response.download_url
def video_to_base64_string(
video: VideoInput,
container_format: VideoContainer = None,
codec: VideoCodec = None
) -> str:
"""
Converts a video input to a base64 string.
Args:
video: The video input to convert
container_format: Optional container format to use (defaults to video.container if available)
codec: Optional codec to use (defaults to video.codec if available)
"""
video_bytes_io = io.BytesIO()
# Use provided format/codec if specified, otherwise use video's own if available
format_to_use = container_format if container_format is not None else getattr(video, 'container', VideoContainer.MP4)
codec_to_use = codec if codec is not None else getattr(video, 'codec', VideoCodec.H264)
video.save_to(video_bytes_io, format=format_to_use, codec=codec_to_use)
video_bytes_io.seek(0)
return base64.b64encode(video_bytes_io.getvalue()).decode("utf-8")
async def upload_video_to_comfyapi(
video: VideoInput,
auth_kwargs: Optional[dict[str, str]] = None,
container: VideoContainer = VideoContainer.MP4,
codec: VideoCodec = VideoCodec.H264,
max_duration: Optional[int] = None,
) -> str:
"""
Uploads a single video to ComfyUI API and returns its download URL.
Uses the specified container and codec for saving the video before upload.
Args:
video: VideoInput object (Comfy VIDEO type).
auth_kwargs: Optional authentication token(s).
container: The video container format to use (default: MP4).
codec: The video codec to use (default: H264).
max_duration: Optional maximum duration of the video in seconds. If the video is longer than this, an error will be raised.
Returns:
The download URL for the uploaded video file.
"""
if max_duration is not None:
try:
actual_duration = video.duration_seconds
if actual_duration is not None and actual_duration > max_duration:
raise ValueError(
f"Video duration ({actual_duration:.2f}s) exceeds the maximum allowed ({max_duration}s)."
)
except Exception as e:
logging.error(f"Error getting video duration: {e}")
raise ValueError(f"Could not verify video duration from source: {e}") from e
upload_mime_type = f"video/{container.value.lower()}"
filename = f"uploaded_video.{container.value.lower()}"
# Convert VideoInput to BytesIO using specified container/codec
video_bytes_io = io.BytesIO()
video.save_to(video_bytes_io, format=container, codec=codec)
video_bytes_io.seek(0)
return await upload_file_to_comfyapi(video_bytes_io, filename, upload_mime_type, auth_kwargs)
def audio_tensor_to_contiguous_ndarray(waveform: torch.Tensor) -> np.ndarray:
"""
Prepares audio waveform for av library by converting to a contiguous numpy array.
Args:
waveform: a tensor of shape (1, channels, samples) derived from a Comfy `AUDIO` type.
Returns:
Contiguous numpy array of the audio waveform. If the audio was batched,
the first item is taken.
"""
if waveform.ndim != 3 or waveform.shape[0] != 1:
raise ValueError("Expected waveform tensor shape (1, channels, samples)")
# If batch is > 1, take first item
if waveform.shape[0] > 1:
waveform = waveform[0]
# Prepare for av: remove batch dim, move to CPU, make contiguous, convert to numpy array
audio_data_np = waveform.squeeze(0).cpu().contiguous().numpy()
if audio_data_np.dtype != np.float32:
audio_data_np = audio_data_np.astype(np.float32)
return audio_data_np
def audio_ndarray_to_bytesio(
audio_data_np: np.ndarray,
sample_rate: int,
container_format: str = "mp4",
codec_name: str = "aac",
) -> BytesIO:
"""
Encodes a numpy array of audio data into a BytesIO object.
"""
audio_bytes_io = io.BytesIO()
with av.open(audio_bytes_io, mode="w", format=container_format) as output_container:
audio_stream = output_container.add_stream(codec_name, rate=sample_rate)
frame = av.AudioFrame.from_ndarray(
audio_data_np,
format="fltp",
layout="stereo" if audio_data_np.shape[0] > 1 else "mono",
)
frame.sample_rate = sample_rate
frame.pts = 0
for packet in audio_stream.encode(frame):
output_container.mux(packet)
# Flush stream
for packet in audio_stream.encode(None):
output_container.mux(packet)
audio_bytes_io.seek(0)
return audio_bytes_io
async def upload_audio_to_comfyapi(
audio: AudioInput,
auth_kwargs: Optional[dict[str, str]] = None,
container_format: str = "mp4",
codec_name: str = "aac",
mime_type: str = "audio/mp4",
filename: str = "uploaded_audio.mp4",
) -> str:
"""
Uploads a single audio input to ComfyUI API and returns its download URL.
Encodes the raw waveform into the specified format before uploading.
Args:
audio: a Comfy `AUDIO` type (contains waveform tensor and sample_rate)
auth_kwargs: Optional authentication token(s).
Returns:
The download URL for the uploaded audio file.
"""
sample_rate: int = audio["sample_rate"]
waveform: torch.Tensor = audio["waveform"]
audio_data_np = audio_tensor_to_contiguous_ndarray(waveform)
audio_bytes_io = audio_ndarray_to_bytesio(
audio_data_np, sample_rate, container_format, codec_name
)
return await upload_file_to_comfyapi(audio_bytes_io, filename, mime_type, auth_kwargs)
def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
"""Convert audio to float 32 bits PCM format. Copy-paste from nodes_audio.py file."""
if wav.dtype.is_floating_point:
return wav
elif wav.dtype == torch.int16:
return wav.float() / (2 ** 15)
elif wav.dtype == torch.int32:
return wav.float() / (2 ** 31)
raise ValueError(f"Unsupported wav dtype: {wav.dtype}")
def audio_bytes_to_audio_input(audio_bytes: bytes,) -> dict:
"""
Decode any common audio container from bytes using PyAV and return
a Comfy AUDIO dict: {"waveform": [1, C, T] float32, "sample_rate": int}.
"""
with av.open(io.BytesIO(audio_bytes)) as af:
if not af.streams.audio:
raise ValueError("No audio stream found in response.")
stream = af.streams.audio[0]
in_sr = int(stream.codec_context.sample_rate)
out_sr = in_sr
frames: list[torch.Tensor] = []
n_channels = stream.channels or 1
for frame in af.decode(streams=stream.index):
arr = frame.to_ndarray() # shape can be [C, T] or [T, C] or [T]
buf = torch.from_numpy(arr)
if buf.ndim == 1:
buf = buf.unsqueeze(0) # [T] -> [1, T]
elif buf.shape[0] != n_channels and buf.shape[-1] == n_channels:
buf = buf.transpose(0, 1).contiguous() # [T, C] -> [C, T]
elif buf.shape[0] != n_channels:
buf = buf.reshape(-1, n_channels).t().contiguous() # fallback to [C, T]
frames.append(buf)
if not frames:
raise ValueError("Decoded zero audio frames.")
wav = torch.cat(frames, dim=1) # [C, T]
wav = f32_pcm(wav)
return {"waveform": wav.unsqueeze(0).contiguous(), "sample_rate": out_sr}
def audio_input_to_mp3(audio: AudioInput) -> io.BytesIO:
waveform = audio["waveform"].cpu()
output_buffer = io.BytesIO()
output_container = av.open(output_buffer, mode='w', format="mp3")
out_stream = output_container.add_stream("libmp3lame", rate=audio["sample_rate"])
out_stream.bit_rate = 320000
frame = av.AudioFrame.from_ndarray(waveform.movedim(0, 1).reshape(1, -1).float().numpy(), format='flt', layout='mono' if waveform.shape[0] == 1 else 'stereo')
frame.sample_rate = audio["sample_rate"]
frame.pts = 0
output_container.mux(out_stream.encode(frame))
output_container.mux(out_stream.encode(None))
output_container.close()
output_buffer.seek(0)
return output_buffer
def audio_to_base64_string(
audio: AudioInput, container_format: str = "mp4", codec_name: str = "aac"
) -> str:
"""Converts an audio input to a base64 string."""
sample_rate: int = audio["sample_rate"]
waveform: torch.Tensor = audio["waveform"]
audio_data_np = audio_tensor_to_contiguous_ndarray(waveform)
audio_bytes_io = audio_ndarray_to_bytesio(
audio_data_np, sample_rate, container_format, codec_name
)
audio_bytes = audio_bytes_io.getvalue()
return base64.b64encode(audio_bytes).decode("utf-8")
async def upload_images_to_comfyapi(
image: torch.Tensor,
max_images=8,
auth_kwargs: Optional[dict[str, str]] = None,
mime_type: Optional[str] = None,
) -> list[str]:
"""
Uploads images to ComfyUI API and returns download URLs.
To upload multiple images, stack them in the batch dimension first.
Args:
image: Input torch.Tensor image.
max_images: Maximum number of images to upload.
auth_kwargs: Optional authentication token(s).
mime_type: Optional MIME type for the image.
"""
# if batch, try to upload each file if max_images is greater than 0
download_urls: list[str] = []
is_batch = len(image.shape) > 3
batch_len = image.shape[0] if is_batch else 1
for idx in range(min(batch_len, max_images)):
tensor = image[idx] if is_batch else image
img_io = tensor_to_bytesio(tensor, mime_type=mime_type)
url = await upload_file_to_comfyapi(img_io, img_io.name, mime_type, auth_kwargs)
download_urls.append(url)
return download_urls
def resize_mask_to_image(
mask: torch.Tensor,
image: torch.Tensor,
upscale_method="nearest-exact",
crop="disabled",
allow_gradient=True,
add_channel_dim=False,
):
"""
Resize mask to be the same dimensions as an image, while maintaining proper format for API calls.
"""
_, H, W, _ = image.shape
mask = mask.unsqueeze(-1)
mask = mask.movedim(-1, 1)
mask = common_upscale(
mask, width=W, height=H, upscale_method=upscale_method, crop=crop
)
mask = mask.movedim(1, -1)
if not add_channel_dim:
mask = mask.squeeze(-1)
if not allow_gradient:
mask = (mask > 0.5).float()
return mask
def validate_string(
string: str,
strip_whitespace=True,
field_name="prompt",
min_length=None,
max_length=None,
):
if string is None:
raise Exception(f"Field '{field_name}' cannot be empty.")
if strip_whitespace:
string = string.strip()
if min_length and len(string) < min_length:
raise Exception(
f"Field '{field_name}' cannot be shorter than {min_length} characters; was {len(string)} characters long."
)
if max_length and len(string) > max_length:
raise Exception(
f" Field '{field_name} cannot be longer than {max_length} characters; was {len(string)} characters long."
)
def image_tensor_pair_to_batch(
image1: torch.Tensor, image2: torch.Tensor
) -> torch.Tensor:
"""
Converts a pair of image tensors to a batch tensor.
If the images are not the same size, the smaller image is resized to
match the larger image.
"""
if image1.shape[1:] != image2.shape[1:]:
image2 = common_upscale(
image2.movedim(-1, 1),
image1.shape[2],
image1.shape[1],
"bilinear",
"center",
).movedim(1, -1)
return torch.cat((image1, image2), dim=0)

View File

@@ -1,17 +0,0 @@
# generated by datamodel-codegen:
# filename: filtered-openapi.yaml
# timestamp: 2025-04-29T23:44:54+00:00
from __future__ import annotations
from typing import Optional
from pydantic import BaseModel
from . import PixverseDto
class ResponseData(BaseModel):
ErrCode: Optional[int] = None
ErrMsg: Optional[str] = None
Resp: Optional[PixverseDto.V2OpenAPII2VResp] = None

View File

@@ -1,57 +0,0 @@
# generated by datamodel-codegen:
# filename: filtered-openapi.yaml
# timestamp: 2025-04-29T23:44:54+00:00
from __future__ import annotations
from typing import Optional
from pydantic import BaseModel, Field
class V2OpenAPII2VResp(BaseModel):
video_id: Optional[int] = Field(None, description='Video_id')
class V2OpenAPIT2VReq(BaseModel):
aspect_ratio: str = Field(
..., description='Aspect ratio (16:9, 4:3, 1:1, 3:4, 9:16)', examples=['16:9']
)
duration: int = Field(
...,
description='Video duration (5, 8 seconds, --model=v3.5 only allows 5,8; --quality=1080p does not support 8s)',
examples=[5],
)
model: str = Field(
..., description='Model version (only supports v3.5)', examples=['v3.5']
)
motion_mode: Optional[str] = Field(
'normal',
description='Motion mode (normal, fast, --fast only available when duration=5; --quality=1080p does not support fast)',
examples=['normal'],
)
negative_prompt: Optional[str] = Field(
None, description='Negative prompt\n', max_length=2048
)
prompt: str = Field(..., description='Prompt', max_length=2048)
quality: str = Field(
...,
description='Video quality ("360p"(Turbo model), "540p", "720p", "1080p")',
examples=['540p'],
)
seed: Optional[int] = Field(None, description='Random seed, range: 0 - 2147483647')
style: Optional[str] = Field(
None,
description='Style (effective when model=v3.5, "anime", "3d_animation", "clay", "comic", "cyberpunk") Do not include style parameter unless needed',
examples=['anime'],
)
template_id: Optional[int] = Field(
None,
description='Template ID (template_id must be activated before use)',
examples=[302325299692608],
)
water_mark: Optional[bool] = Field(
False,
description='Watermark (true: add watermark, false: no watermark)',
examples=[False],
)

View File

@@ -50,44 +50,6 @@ class BFLFluxFillImageRequest(BaseModel):
mask: str = Field(None, description='A Base64-encoded string representing the mask of the areas you with to modify.')
class BFLFluxCannyImageRequest(BaseModel):
prompt: str = Field(..., description='Text prompt for image generation')
prompt_upsampling: Optional[bool] = Field(
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
)
canny_low_threshold: Optional[int] = Field(None, description='Low threshold for Canny edge detection')
canny_high_threshold: Optional[int] = Field(None, description='High threshold for Canny edge detection')
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process')
guidance: confloat(ge=1, le=100) = Field(..., description='Guidance strength for the image generation process')
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
)
output_format: Optional[BFLOutputFormat] = Field(
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
)
control_image: Optional[str] = Field(None, description='Base64 encoded image to use as control input if no preprocessed image is provided')
preprocessed_image: Optional[str] = Field(None, description='Optional pre-processed image that will bypass the control preprocessing step')
class BFLFluxDepthImageRequest(BaseModel):
prompt: str = Field(..., description='Text prompt for image generation')
prompt_upsampling: Optional[bool] = Field(
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
)
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process')
guidance: confloat(ge=1, le=100) = Field(..., description='Guidance strength for the image generation process')
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
)
output_format: Optional[BFLOutputFormat] = Field(
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
)
control_image: Optional[str] = Field(None, description='Base64 encoded image to use as control input if no preprocessed image is provided')
preprocessed_image: Optional[str] = Field(None, description='Optional pre-processed image that will bypass the control preprocessing step')
class BFLFluxProGenerateRequest(BaseModel):
prompt: str = Field(..., description='The text prompt for image generation.')
prompt_upsampling: Optional[bool] = Field(
@@ -108,6 +70,29 @@ class BFLFluxProGenerateRequest(BaseModel):
# )
class Flux2ProGenerateRequest(BaseModel):
prompt: str = Field(...)
width: int = Field(1024, description="Must be a multiple of 32.")
height: int = Field(768, description="Must be a multiple of 32.")
seed: int | None = Field(None)
prompt_upsampling: bool | None = Field(None)
input_image: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
input_image_2: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
input_image_3: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
input_image_4: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
input_image_5: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
input_image_6: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
input_image_7: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
input_image_8: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
input_image_9: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
safety_tolerance: int | None = Field(
5, description="Tolerance level for input and output moderation. Value 0 being most strict.", ge=0, le=5
)
output_format: str | None = Field(
"png", description="Output format for the generated image. Can be 'jpeg' or 'png'."
)
class BFLFluxKontextProGenerateRequest(BaseModel):
prompt: str = Field(..., description='The text prompt for what you wannt to edit.')
input_image: Optional[str] = Field(None, description='Image to edit in base64 format')
@@ -147,8 +132,9 @@ class BFLFluxProUltraGenerateRequest(BaseModel):
class BFLFluxProGenerateResponse(BaseModel):
id: str = Field(..., description='The unique identifier for the generation task.')
polling_url: str = Field(..., description='URL to poll for the generation result.')
id: str = Field(..., description="The unique identifier for the generation task.")
polling_url: str = Field(..., description="URL to poll for the generation result.")
cost: float | None = Field(None, description="Price in cents")
class BFLStatus(str, Enum):
@@ -160,15 +146,8 @@ class BFLStatus(str, Enum):
error = "Error"
class BFLFluxProStatusResponse(BaseModel):
class BFLFluxStatusResponse(BaseModel):
id: str = Field(..., description="The unique identifier for the generation task.")
status: BFLStatus = Field(..., description="The status of the task.")
result: Optional[Dict[str, Any]] = Field(
None, description="The result of the task (null if not completed)."
)
progress: confloat(ge=0.0, le=1.0) = Field(
..., description="The progress of the task (0.0 to 1.0)."
)
details: Optional[Dict[str, Any]] = Field(
None, description="Additional details about the task (null if not available)."
)
result: Optional[Dict[str, Any]] = Field(None, description="The result of the task (null if not completed).")
progress: Optional[float] = Field(None, description="The progress of the task (0.0 to 1.0).", ge=0.0, le=1.0)

View File

@@ -1,963 +0,0 @@
"""
API Client Framework for api.comfy.org.
This module provides a flexible framework for making API requests from ComfyUI nodes.
It supports both synchronous and asynchronous API operations with proper type validation.
Key Components:
--------------
1. ApiClient - Handles HTTP requests with authentication and error handling
2. ApiEndpoint - Defines a single HTTP endpoint with its request/response models
3. ApiOperation - Executes a single synchronous API operation
Usage Examples:
--------------
# Example 1: Synchronous API Operation
# ------------------------------------
# For a simple API call that returns the result immediately:
# 1. Create the API client
api_client = ApiClient(
base_url="https://api.example.com",
auth_token="your_auth_token_here",
comfy_api_key="your_comfy_api_key_here",
timeout=30.0,
verify_ssl=True
)
# 2. Define the endpoint
user_info_endpoint = ApiEndpoint(
path="/v1/users/me",
method=HttpMethod.GET,
request_model=EmptyRequest, # No request body needed
response_model=UserProfile, # Pydantic model for the response
query_params=None
)
# 3. Create the request object
request = EmptyRequest()
# 4. Create and execute the operation
operation = ApiOperation(
endpoint=user_info_endpoint,
request=request
)
user_profile = await operation.execute(client=api_client) # Returns immediately with the result
# Example 2: Asynchronous API Operation with Polling
# -------------------------------------------------
# For an API that starts a task and requires polling for completion:
# 1. Define the endpoints (initial request and polling)
generate_image_endpoint = ApiEndpoint(
path="/v1/images/generate",
method=HttpMethod.POST,
request_model=ImageGenerationRequest,
response_model=TaskCreatedResponse,
query_params=None
)
check_task_endpoint = ApiEndpoint(
path="/v1/tasks/{task_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=ImageGenerationResult,
query_params=None
)
# 2. Create the request object
request = ImageGenerationRequest(
prompt="a beautiful sunset over mountains",
width=1024,
height=1024,
num_images=1
)
# 3. Create and execute the polling operation
operation = PollingOperation(
initial_endpoint=generate_image_endpoint,
initial_request=request,
poll_endpoint=check_task_endpoint,
task_id_field="task_id",
status_field="status",
completed_statuses=["completed"],
failed_statuses=["failed", "error"]
)
# This will make the initial request and then poll until completion
result = await operation.execute(client=api_client) # Returns the final ImageGenerationResult when done
"""
from __future__ import annotations
import aiohttp
import asyncio
import logging
import io
import os
import socket
from aiohttp.client_exceptions import ClientError, ClientResponseError
from typing import Dict, Type, Optional, Any, TypeVar, Generic, Callable, Tuple
from enum import Enum
import json
from urllib.parse import urljoin, urlparse
from pydantic import BaseModel, Field
import uuid # For generating unique operation IDs
from server import PromptServer
from comfy.cli_args import args
from comfy import utils
from . import request_logger
T = TypeVar("T", bound=BaseModel)
R = TypeVar("R", bound=BaseModel)
P = TypeVar("P", bound=BaseModel) # For poll response
PROGRESS_BAR_MAX = 100
class NetworkError(Exception):
"""Base exception for network-related errors with diagnostic information."""
pass
class LocalNetworkError(NetworkError):
"""Exception raised when local network connectivity issues are detected."""
pass
class ApiServerError(NetworkError):
"""Exception raised when the API server is unreachable but internet is working."""
pass
class EmptyRequest(BaseModel):
"""Base class for empty request bodies.
For GET requests, fields will be sent as query parameters."""
pass
class UploadRequest(BaseModel):
file_name: str = Field(..., description="Filename to upload")
content_type: Optional[str] = Field(
None,
description="Mime type of the file. For example: image/png, image/jpeg, video/mp4, etc.",
)
class UploadResponse(BaseModel):
download_url: str = Field(..., description="URL to GET uploaded file")
upload_url: str = Field(..., description="URL to PUT file to upload")
class HttpMethod(str, Enum):
GET = "GET"
POST = "POST"
PUT = "PUT"
DELETE = "DELETE"
PATCH = "PATCH"
class ApiClient:
"""
Client for making HTTP requests to an API with authentication, error handling, and retry logic.
"""
def __init__(
self,
base_url: str,
auth_token: Optional[str] = None,
comfy_api_key: Optional[str] = None,
timeout: float = 3600.0,
verify_ssl: bool = True,
max_retries: int = 3,
retry_delay: float = 1.0,
retry_backoff_factor: float = 2.0,
retry_status_codes: Optional[Tuple[int, ...]] = None,
session: Optional[aiohttp.ClientSession] = None,
):
self.base_url = base_url
self.auth_token = auth_token
self.comfy_api_key = comfy_api_key
self.timeout = timeout
self.verify_ssl = verify_ssl
self.max_retries = max_retries
self.retry_delay = retry_delay
self.retry_backoff_factor = retry_backoff_factor
# Default retry status codes: 408 (Request Timeout), 429 (Too Many Requests),
# 500, 502, 503, 504 (Server Errors)
self.retry_status_codes = retry_status_codes or (408, 429, 500, 502, 503, 504)
self._session: Optional[aiohttp.ClientSession] = session
self._owns_session = session is None # Track if we have to close it
@staticmethod
def _generate_operation_id(path: str) -> str:
"""Generates a unique operation ID for logging."""
return f"{path.strip('/').replace('/', '_')}_{uuid.uuid4().hex[:8]}"
@staticmethod
def _create_json_payload_args(
data: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None,
) -> Dict[str, Any]:
return {
"json": data,
"headers": headers,
}
def _create_form_data_args(
self,
data: Dict[str, Any] | None,
files: Dict[str, Any] | None,
headers: Optional[Dict[str, str]] = None,
multipart_parser: Callable | None = None,
) -> Dict[str, Any]:
if headers and "Content-Type" in headers:
del headers["Content-Type"]
if multipart_parser and data:
data = multipart_parser(data)
if isinstance(data, aiohttp.FormData):
form = data # If the parser already returned a FormData, pass it through
else:
form = aiohttp.FormData(default_to_multipart=True)
if data: # regular text fields
for k, v in data.items():
if v is None:
continue # aiohttp fails to serialize "None" values
# aiohttp expects strings or bytes; convert enums etc.
form.add_field(k, str(v) if not isinstance(v, (bytes, bytearray)) else v)
if files:
file_iter = files if isinstance(files, list) else files.items()
for field_name, file_obj in file_iter:
if file_obj is None:
continue # aiohttp fails to serialize "None" values
# file_obj can be (filename, bytes/io.BytesIO, content_type) tuple
if isinstance(file_obj, tuple):
filename, file_value, content_type = self._unpack_tuple(file_obj)
else:
file_value = file_obj
filename = getattr(file_obj, "name", field_name)
content_type = "application/octet-stream"
form.add_field(
name=field_name,
value=file_value,
filename=filename,
content_type=content_type,
)
return {"data": form, "headers": headers or {}}
@staticmethod
def _create_urlencoded_form_data_args(
data: Dict[str, Any],
headers: Optional[Dict[str, str]] = None,
) -> Dict[str, Any]:
headers = headers or {}
headers["Content-Type"] = "application/x-www-form-urlencoded"
return {
"data": data,
"headers": headers,
}
def get_headers(self) -> Dict[str, str]:
"""Get headers for API requests, including authentication if available"""
headers = {"Content-Type": "application/json", "Accept": "application/json"}
if self.auth_token:
headers["Authorization"] = f"Bearer {self.auth_token}"
elif self.comfy_api_key:
headers["X-API-KEY"] = self.comfy_api_key
return headers
async def _check_connectivity(self, target_url: str) -> Dict[str, bool]:
"""
Check connectivity to determine if network issues are local or server-related.
Args:
target_url: URL to check connectivity to
Returns:
Dictionary with connectivity status details
"""
results = {
"internet_accessible": False,
"api_accessible": False,
"is_local_issue": False,
"is_api_issue": False,
}
timeout = aiohttp.ClientTimeout(total=5.0)
async with aiohttp.ClientSession(timeout=timeout) as session:
try:
async with session.get("https://www.google.com", ssl=self.verify_ssl) as resp:
results["internet_accessible"] = resp.status < 500
except (ClientError, asyncio.TimeoutError, socket.gaierror):
results["is_local_issue"] = True
return results # cannot reach the internet early exit
# Now check API health endpoint
parsed = urlparse(target_url)
health_url = f"{parsed.scheme}://{parsed.netloc}/health"
try:
async with session.get(health_url, ssl=self.verify_ssl) as resp:
results["api_accessible"] = resp.status < 500
except ClientError:
pass # leave as False
results["is_api_issue"] = results["internet_accessible"] and not results["api_accessible"]
return results
async def request(
self,
method: str,
path: str,
params: Optional[Dict[str, Any]] = None,
data: Optional[Dict[str, Any]] = None,
files: Optional[Dict[str, Any] | list[tuple[str, Any]]] = None,
headers: Optional[Dict[str, str]] = None,
content_type: str = "application/json",
multipart_parser: Callable | None = None,
retry_count: int = 0, # Used internally for tracking retries
) -> Dict[str, Any]:
"""
Make an HTTP request to the API with automatic retries for transient errors.
Args:
method: HTTP method (GET, POST, etc.)
path: API endpoint path (will be joined with base_url)
params: Query parameters
data: body data
files: Files to upload
headers: Additional headers
content_type: Content type of the request. Defaults to application/json.
retry_count: Internal parameter for tracking retries, do not set manually
Returns:
Parsed JSON response
Raises:
LocalNetworkError: If local network connectivity issues are detected
ApiServerError: If the API server is unreachable but internet is working
Exception: For other request failures
"""
# Build full URL and merge headers
relative_path = path.lstrip("/")
url = urljoin(self.base_url, relative_path)
self._check_auth(self.auth_token, self.comfy_api_key)
request_headers = self.get_headers()
if headers:
request_headers.update(headers)
if files:
request_headers.pop("Content-Type", None)
if params:
params = {k: v for k, v in params.items() if v is not None} # aiohttp fails to serialize None values
logging.debug(f"[DEBUG] Request Headers: {request_headers}")
logging.debug(f"[DEBUG] Files: {files}")
logging.debug(f"[DEBUG] Params: {params}")
logging.debug(f"[DEBUG] Data: {data}")
if content_type == "application/x-www-form-urlencoded":
payload_args = self._create_urlencoded_form_data_args(data or {}, request_headers)
elif content_type == "multipart/form-data":
payload_args = self._create_form_data_args(data, files, request_headers, multipart_parser)
else:
payload_args = self._create_json_payload_args(data, request_headers)
operation_id = self._generate_operation_id(path)
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
request_headers=request_headers,
request_params=params,
request_data=data if content_type == "application/json" else "[form-data or other]",
)
session = await self._get_session()
try:
async with session.request(
method,
url,
params=params,
ssl=self.verify_ssl,
**payload_args,
) as resp:
if resp.status >= 400:
try:
error_data = await resp.json()
except (aiohttp.ContentTypeError, json.JSONDecodeError):
error_data = await resp.text()
return await self._handle_http_error(
ClientResponseError(resp.request_info, resp.history, status=resp.status, message=error_data),
operation_id,
method,
url,
params,
data,
files,
headers,
content_type,
multipart_parser,
retry_count=retry_count,
response_content=error_data,
)
# Success parse JSON (safely) and log
try:
payload = await resp.json()
response_content_to_log = payload
except (aiohttp.ContentTypeError, json.JSONDecodeError):
payload = {}
response_content_to_log = await resp.text()
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content=response_content_to_log,
)
return payload
except (ClientError, asyncio.TimeoutError, socket.gaierror) as e:
# Treat as *connection* problem optionally retry, else escalate
if retry_count < self.max_retries:
delay = self.retry_delay * (self.retry_backoff_factor ** retry_count)
logging.warning("Connection error. Retrying in %.2fs (%s/%s): %s", delay, retry_count + 1,
self.max_retries, str(e))
await asyncio.sleep(delay)
return await self.request(
method,
path,
params=params,
data=data,
files=files,
headers=headers,
content_type=content_type,
multipart_parser=multipart_parser,
retry_count=retry_count + 1,
)
# One final connectivity check for diagnostics
connectivity = await self._check_connectivity(self.base_url)
if connectivity["is_local_issue"]:
raise LocalNetworkError(
"Unable to connect to the API server due to local network issues. "
"Please check your internet connection and try again."
) from e
raise ApiServerError(
f"The API server at {self.base_url} is currently unreachable. "
f"The service may be experiencing issues. Please try again later."
) from e
@staticmethod
def _check_auth(auth_token, comfy_api_key):
"""Verify that an auth token is present or comfy_api_key is present"""
if auth_token is None and comfy_api_key is None:
raise Exception("Unauthorized: Please login first to use this node.")
return auth_token or comfy_api_key
@staticmethod
async def upload_file(
upload_url: str,
file: io.BytesIO | str,
content_type: str | None = None,
max_retries: int = 3,
retry_delay: float = 1.0,
retry_backoff_factor: float = 2.0,
) -> aiohttp.ClientResponse:
"""Upload a file to the API with retry logic.
Args:
upload_url: The URL to upload to
file: Either a file path string, BytesIO object, or tuple of (file_path, filename)
content_type: Optional mime type to set for the upload
max_retries: Maximum number of retry attempts
retry_delay: Initial delay between retries in seconds
retry_backoff_factor: Multiplier for the delay after each retry
"""
headers: Dict[str, str] = {}
skip_auto_headers: set[str] = set()
if content_type:
headers["Content-Type"] = content_type
else:
# tell aiohttp not to add Content-Type that will break the request signature and result in a 403 status.
skip_auto_headers.add("Content-Type")
# Extract file bytes
if isinstance(file, io.BytesIO):
file.seek(0)
data = file.read()
elif isinstance(file, str):
with open(file, "rb") as f:
data = f.read()
else:
raise ValueError("File must be BytesIO or str path")
parsed = urlparse(upload_url)
basename = os.path.basename(parsed.path) or parsed.netloc or "upload"
operation_id = f"upload_{basename}_{uuid.uuid4().hex[:8]}"
request_logger.log_request_response(
operation_id=operation_id,
request_method="PUT",
request_url=upload_url,
request_headers=headers,
request_data=f"[File data {len(data)} bytes]",
)
delay = retry_delay
for attempt in range(max_retries + 1):
try:
timeout = aiohttp.ClientTimeout(total=None) # honour server side timeouts
async with aiohttp.ClientSession(timeout=timeout) as session:
async with session.put(
upload_url, data=data, headers=headers, skip_auto_headers=skip_auto_headers,
) as resp:
resp.raise_for_status()
request_logger.log_request_response(
operation_id=operation_id,
request_method="PUT",
request_url=upload_url,
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content="File uploaded successfully.",
)
return resp
except (ClientError, asyncio.TimeoutError) as e:
request_logger.log_request_response(
operation_id=operation_id,
request_method="PUT",
request_url=upload_url,
response_status_code=e.status if hasattr(e, "status") else None,
response_headers=dict(e.headers) if hasattr(e, "headers") else None,
response_content=None,
error_message=f"{type(e).__name__}: {str(e)}",
)
if attempt < max_retries:
logging.warning(
"Upload failed (%s/%s). Retrying in %.2fs. %s", attempt + 1, max_retries, delay, str(e)
)
await asyncio.sleep(delay)
delay *= retry_backoff_factor
else:
raise NetworkError(f"Failed to upload file after {max_retries + 1} attempts: {e}") from e
async def _handle_http_error(
self,
exc: ClientResponseError,
operation_id: str,
*req_meta,
retry_count: int,
response_content: dict | str = "",
) -> Dict[str, Any]:
status_code = exc.status
if status_code == 401:
user_friendly = "Unauthorized: Please login first to use this node."
elif status_code == 402:
user_friendly = "Payment Required: Please add credits to your account to use this node."
elif status_code == 409:
user_friendly = "There is a problem with your account. Please contact support@comfy.org."
elif status_code == 429:
user_friendly = "Rate Limit Exceeded: Please try again later."
else:
if isinstance(response_content, dict):
if "error" in response_content and "message" in response_content["error"]:
user_friendly = f"API Error: {response_content['error']['message']}"
if "type" in response_content["error"]:
user_friendly += f" (Type: {response_content['error']['type']})"
else: # Handle cases where error is just a JSON dict with unknown format
user_friendly = f"API Error: {json.dumps(response_content)}"
else:
if len(response_content) < 200: # Arbitrary limit for display
user_friendly = f"API Error (raw): {response_content}"
else:
user_friendly = f"API Error (raw, status {response_content})"
request_logger.log_request_response(
operation_id=operation_id,
request_method=req_meta[0],
request_url=req_meta[1],
response_status_code=exc.status,
response_headers=dict(req_meta[5]) if req_meta[5] else None,
response_content=response_content,
error_message=f"HTTP Error {exc.status}",
)
logging.debug(f"[DEBUG] API Error: {user_friendly} (Status: {status_code})")
if response_content:
logging.debug(f"[DEBUG] Response content: {response_content}")
# Retry if eligible
if status_code in self.retry_status_codes and retry_count < self.max_retries:
delay = self.retry_delay * (self.retry_backoff_factor ** retry_count)
logging.warning(
"HTTP error %s. Retrying in %.2fs (%s/%s)",
status_code,
delay,
retry_count + 1,
self.max_retries,
)
await asyncio.sleep(delay)
return await self.request(
req_meta[0], # method
req_meta[1].replace(self.base_url, ""), # path
params=req_meta[2],
data=req_meta[3],
files=req_meta[4],
headers=req_meta[5],
content_type=req_meta[6],
multipart_parser=req_meta[7],
retry_count=retry_count + 1,
)
raise Exception(user_friendly) from exc
@staticmethod
def _unpack_tuple(t):
"""Helper to normalise (filename, file, content_type) tuples."""
if len(t) == 3:
return t
elif len(t) == 2:
return t[0], t[1], "application/octet-stream"
else:
raise ValueError("files tuple must be (filename, file[, content_type])")
async def _get_session(self) -> aiohttp.ClientSession:
if self._session is None or self._session.closed:
timeout = aiohttp.ClientTimeout(total=self.timeout)
self._session = aiohttp.ClientSession(timeout=timeout)
self._owns_session = True
return self._session
async def close(self) -> None:
if self._owns_session and self._session and not self._session.closed:
await self._session.close()
async def __aenter__(self) -> "ApiClient":
"""Allow usage as asynccontextmanager ensures clean teardown"""
return self
async def __aexit__(self, exc_type, exc, tb):
await self.close()
class ApiEndpoint(Generic[T, R]):
"""Defines an API endpoint with its request and response types"""
def __init__(
self,
path: str,
method: HttpMethod,
request_model: Type[T],
response_model: Type[R],
query_params: Optional[Dict[str, Any]] = None,
):
"""Initialize an API endpoint definition.
Args:
path: The URL path for this endpoint, can include placeholders like {id}
method: The HTTP method to use (GET, POST, etc.)
request_model: Pydantic model class that defines the structure and validation rules for API requests to this endpoint
response_model: Pydantic model class that defines the structure and validation rules for API responses from this endpoint
query_params: Optional dictionary of query parameters to include in the request
"""
self.path = path
self.method = method
self.request_model = request_model
self.response_model = response_model
self.query_params = query_params or {}
class SynchronousOperation(Generic[T, R]):
"""Represents a single synchronous API operation."""
def __init__(
self,
endpoint: ApiEndpoint[T, R],
request: T,
files: Optional[Dict[str, Any] | list[tuple[str, Any]]] = None,
api_base: str | None = None,
auth_token: Optional[str] = None,
comfy_api_key: Optional[str] = None,
auth_kwargs: Optional[Dict[str, str]] = None,
timeout: float = 7200.0,
verify_ssl: bool = True,
content_type: str = "application/json",
multipart_parser: Callable | None = None,
max_retries: int = 3,
retry_delay: float = 1.0,
retry_backoff_factor: float = 2.0,
) -> None:
self.endpoint = endpoint
self.request = request
self.files = files
self.api_base: str = api_base or args.comfy_api_base
self.auth_token = auth_token
self.comfy_api_key = comfy_api_key
if auth_kwargs is not None:
self.auth_token = auth_kwargs.get("auth_token", self.auth_token)
self.comfy_api_key = auth_kwargs.get("comfy_api_key", self.comfy_api_key)
self.timeout = timeout
self.verify_ssl = verify_ssl
self.content_type = content_type
self.multipart_parser = multipart_parser
self.max_retries = max_retries
self.retry_delay = retry_delay
self.retry_backoff_factor = retry_backoff_factor
async def execute(self, client: Optional[ApiClient] = None) -> R:
owns_client = client is None
if owns_client:
client = ApiClient(
base_url=self.api_base,
auth_token=self.auth_token,
comfy_api_key=self.comfy_api_key,
timeout=self.timeout,
verify_ssl=self.verify_ssl,
max_retries=self.max_retries,
retry_delay=self.retry_delay,
retry_backoff_factor=self.retry_backoff_factor,
)
try:
request_dict: Optional[Dict[str, Any]]
if isinstance(self.request, EmptyRequest):
request_dict = None
else:
request_dict = self.request.model_dump(exclude_none=True)
for k, v in list(request_dict.items()):
if isinstance(v, Enum):
request_dict[k] = v.value
logging.debug(
f"[DEBUG] API Request: {self.endpoint.method.value} {self.endpoint.path}"
)
logging.debug(f"[DEBUG] Request Data: {json.dumps(request_dict, indent=2)}")
logging.debug(f"[DEBUG] Query Params: {self.endpoint.query_params}")
response_json = await client.request(
self.endpoint.method.value,
self.endpoint.path,
params=self.endpoint.query_params,
data=request_dict,
files=self.files,
content_type=self.content_type,
multipart_parser=self.multipart_parser,
)
logging.debug("=" * 50)
logging.debug("[DEBUG] RESPONSE DETAILS:")
logging.debug("[DEBUG] Status Code: 200 (Success)")
logging.debug(f"[DEBUG] Response Body: {json.dumps(response_json, indent=2)}")
logging.debug("=" * 50)
parsed_response = self.endpoint.response_model.model_validate(response_json)
logging.debug(f"[DEBUG] Parsed Response: {parsed_response}")
return parsed_response
finally:
if owns_client:
await client.close()
class TaskStatus(str, Enum):
"""Enum for task status values"""
COMPLETED = "completed"
FAILED = "failed"
PENDING = "pending"
class PollingOperation(Generic[T, R]):
"""Represents an asynchronous API operation that requires polling for completion."""
def __init__(
self,
poll_endpoint: ApiEndpoint[EmptyRequest, R],
completed_statuses: list[str],
failed_statuses: list[str],
status_extractor: Callable[[R], str],
progress_extractor: Callable[[R], float] | None = None,
result_url_extractor: Callable[[R], str] | None = None,
request: Optional[T] = None,
api_base: str | None = None,
auth_token: Optional[str] = None,
comfy_api_key: Optional[str] = None,
auth_kwargs: Optional[Dict[str, str]] = None,
poll_interval: float = 5.0,
max_poll_attempts: int = 120, # Default max polling attempts (10 minutes with 5s interval)
max_retries: int = 3, # Max retries per individual API call
retry_delay: float = 1.0,
retry_backoff_factor: float = 2.0,
estimated_duration: Optional[float] = None,
node_id: Optional[str] = None,
) -> None:
self.poll_endpoint = poll_endpoint
self.request = request
self.api_base: str = api_base or args.comfy_api_base
self.auth_token = auth_token
self.comfy_api_key = comfy_api_key
if auth_kwargs is not None:
self.auth_token = auth_kwargs.get("auth_token", self.auth_token)
self.comfy_api_key = auth_kwargs.get("comfy_api_key", self.comfy_api_key)
self.poll_interval = poll_interval
self.max_poll_attempts = max_poll_attempts
self.max_retries = max_retries
self.retry_delay = retry_delay
self.retry_backoff_factor = retry_backoff_factor
self.estimated_duration = estimated_duration
self.status_extractor = status_extractor or (lambda x: getattr(x, "status", None))
self.progress_extractor = progress_extractor
self.result_url_extractor = result_url_extractor
self.node_id = node_id
self.completed_statuses = completed_statuses
self.failed_statuses = failed_statuses
self.final_response: Optional[R] = None
async def execute(self, client: Optional[ApiClient] = None) -> R:
owns_client = client is None
if owns_client:
client = ApiClient(
base_url=self.api_base,
auth_token=self.auth_token,
comfy_api_key=self.comfy_api_key,
max_retries=self.max_retries,
retry_delay=self.retry_delay,
retry_backoff_factor=self.retry_backoff_factor,
)
try:
return await self._poll_until_complete(client)
finally:
if owns_client:
await client.close()
def _display_text_on_node(self, text: str):
if not self.node_id:
return
PromptServer.instance.send_progress_text(text, self.node_id)
def _display_time_progress_on_node(self, time_completed: int | float):
if not self.node_id:
return
if self.estimated_duration is not None:
remaining = max(0, int(self.estimated_duration) - time_completed)
message = f"Task in progress: {time_completed}s (~{remaining}s remaining)"
else:
message = f"Task in progress: {time_completed}s"
self._display_text_on_node(message)
def _check_task_status(self, response: R) -> TaskStatus:
try:
status = self.status_extractor(response)
if status in self.completed_statuses:
return TaskStatus.COMPLETED
if status in self.failed_statuses:
return TaskStatus.FAILED
return TaskStatus.PENDING
except Exception as e:
logging.error("Error extracting status: %s", e)
return TaskStatus.PENDING
async def _poll_until_complete(self, client: ApiClient) -> R:
"""Poll until the task is complete"""
consecutive_errors = 0
max_consecutive_errors = min(5, self.max_retries * 2) # Limit consecutive errors
if self.progress_extractor:
progress = utils.ProgressBar(PROGRESS_BAR_MAX)
status = TaskStatus.PENDING
for poll_count in range(1, self.max_poll_attempts + 1):
try:
logging.debug(f"[DEBUG] Polling attempt #{poll_count}")
request_dict = (
None if self.request is None else self.request.model_dump(exclude_none=True)
)
if poll_count == 1:
logging.debug(
f"[DEBUG] Poll Request: {self.poll_endpoint.method.value} {self.poll_endpoint.path}"
)
logging.debug(
f"[DEBUG] Poll Request Data: {json.dumps(request_dict, indent=2) if request_dict else 'None'}"
)
# Query task status
resp = await client.request(
self.poll_endpoint.method.value,
self.poll_endpoint.path,
params=self.poll_endpoint.query_params,
data=request_dict,
)
consecutive_errors = 0 # reset on success
response_obj: R = self.poll_endpoint.response_model.model_validate(resp)
# Check if task is complete
status = self._check_task_status(response_obj)
logging.debug(f"[DEBUG] Task Status: {status}")
# If progress extractor is provided, extract progress
if self.progress_extractor:
new_progress = self.progress_extractor(response_obj)
if new_progress is not None:
progress.update_absolute(new_progress, total=PROGRESS_BAR_MAX)
if status == TaskStatus.COMPLETED:
message = "Task completed successfully"
if self.result_url_extractor:
result_url = self.result_url_extractor(response_obj)
if result_url:
message = f"Result URL: {result_url}"
logging.debug(f"[DEBUG] {message}")
self._display_text_on_node(message)
self.final_response = response_obj
if self.progress_extractor:
progress.update(100)
return self.final_response
if status == TaskStatus.FAILED:
message = f"Task failed: {json.dumps(resp)}"
logging.error(f"[DEBUG] {message}")
raise Exception(message)
logging.debug("[DEBUG] Task still pending, continuing to poll...")
# Task pending wait
for i in range(int(self.poll_interval)):
self._display_time_progress_on_node((poll_count - 1) * self.poll_interval + i)
await asyncio.sleep(1)
except (LocalNetworkError, ApiServerError, NetworkError) as e:
consecutive_errors += 1
if consecutive_errors >= max_consecutive_errors:
raise Exception(
f"Polling aborted after {consecutive_errors} network errors: {str(e)}"
) from e
logging.warning("Network error (%s/%s): %s", consecutive_errors, max_consecutive_errors, str(e))
await asyncio.sleep(self.poll_interval)
except Exception as e:
# For other errors, increment count and potentially abort
consecutive_errors += 1
if consecutive_errors >= max_consecutive_errors or status == TaskStatus.FAILED:
raise Exception(
f"Polling aborted after {consecutive_errors} consecutive errors: {str(e)}"
) from e
logging.error(f"[DEBUG] Polling error: {str(e)}")
logging.warning(
f"Error during polling (attempt {poll_count}/{self.max_poll_attempts}): {str(e)}. "
f"Will retry in {self.poll_interval} seconds."
)
await asyncio.sleep(self.poll_interval)
# If we've exhausted all polling attempts
raise Exception(
f"Polling timed out after {self.max_poll_attempts} attempts (" f"{self.max_poll_attempts * self.poll_interval} seconds). "
"The operation may still be running on the server but is taking longer than expected."
)

View File

@@ -1,19 +1,236 @@
from __future__ import annotations
from datetime import date
from enum import Enum
from typing import Any
from typing import List, Optional
from pydantic import BaseModel, Field
from comfy_api_nodes.apis import GeminiGenerationConfig, GeminiContent, GeminiSafetySetting, GeminiSystemInstructionContent, GeminiTool, GeminiVideoMetadata
from pydantic import BaseModel
class GeminiSafetyCategory(str, Enum):
HARM_CATEGORY_SEXUALLY_EXPLICIT = "HARM_CATEGORY_SEXUALLY_EXPLICIT"
HARM_CATEGORY_HATE_SPEECH = "HARM_CATEGORY_HATE_SPEECH"
HARM_CATEGORY_HARASSMENT = "HARM_CATEGORY_HARASSMENT"
HARM_CATEGORY_DANGEROUS_CONTENT = "HARM_CATEGORY_DANGEROUS_CONTENT"
class GeminiSafetyThreshold(str, Enum):
OFF = "OFF"
BLOCK_NONE = "BLOCK_NONE"
BLOCK_LOW_AND_ABOVE = "BLOCK_LOW_AND_ABOVE"
BLOCK_MEDIUM_AND_ABOVE = "BLOCK_MEDIUM_AND_ABOVE"
BLOCK_ONLY_HIGH = "BLOCK_ONLY_HIGH"
class GeminiSafetySetting(BaseModel):
category: GeminiSafetyCategory
threshold: GeminiSafetyThreshold
class GeminiRole(str, Enum):
user = "user"
model = "model"
class GeminiMimeType(str, Enum):
application_pdf = "application/pdf"
audio_mpeg = "audio/mpeg"
audio_mp3 = "audio/mp3"
audio_wav = "audio/wav"
image_png = "image/png"
image_jpeg = "image/jpeg"
image_webp = "image/webp"
text_plain = "text/plain"
video_mov = "video/mov"
video_mpeg = "video/mpeg"
video_mp4 = "video/mp4"
video_mpg = "video/mpg"
video_avi = "video/avi"
video_wmv = "video/wmv"
video_mpegps = "video/mpegps"
video_flv = "video/flv"
class GeminiInlineData(BaseModel):
data: str | None = Field(
None,
description="The base64 encoding of the image, PDF, or video to include inline in the prompt. "
"When including media inline, you must also specify the media type (mimeType) of the data. Size limit: 20MB",
)
mimeType: GeminiMimeType | None = Field(None)
class GeminiFileData(BaseModel):
fileUri: str | None = Field(None)
mimeType: GeminiMimeType | None = Field(None)
class GeminiPart(BaseModel):
inlineData: GeminiInlineData | None = Field(None)
fileData: GeminiFileData | None = Field(None)
text: str | None = Field(None)
class GeminiTextPart(BaseModel):
text: str | None = Field(None)
class GeminiContent(BaseModel):
parts: list[GeminiPart] = Field([])
role: GeminiRole = Field(..., examples=["user"])
class GeminiSystemInstructionContent(BaseModel):
parts: list[GeminiTextPart] = Field(
...,
description="A list of ordered parts that make up a single message. "
"Different parts may have different IANA MIME types.",
)
role: GeminiRole = Field(
...,
description="The identity of the entity that creates the message. "
"The following values are supported: "
"user: This indicates that the message is sent by a real person, typically a user-generated message. "
"model: This indicates that the message is generated by the model. "
"The model value is used to insert messages from model into the conversation during multi-turn conversations. "
"For non-multi-turn conversations, this field can be left blank or unset.",
)
class GeminiFunctionDeclaration(BaseModel):
description: str | None = Field(None)
name: str = Field(...)
parameters: dict[str, Any] = Field(..., description="JSON schema for the function parameters")
class GeminiTool(BaseModel):
functionDeclarations: list[GeminiFunctionDeclaration] | None = Field(None)
class GeminiOffset(BaseModel):
nanos: int | None = Field(None, ge=0, le=999999999)
seconds: int | None = Field(None, ge=-315576000000, le=315576000000)
class GeminiVideoMetadata(BaseModel):
endOffset: GeminiOffset | None = Field(None)
startOffset: GeminiOffset | None = Field(None)
class GeminiGenerationConfig(BaseModel):
maxOutputTokens: int | None = Field(None, ge=16, le=8192)
seed: int | None = Field(None)
stopSequences: list[str] | None = Field(None)
temperature: float | None = Field(None, ge=0.0, le=2.0)
topK: int | None = Field(None, ge=1)
topP: float | None = Field(None, ge=0.0, le=1.0)
class GeminiImageConfig(BaseModel):
aspectRatio: str | None = Field(None)
imageSize: str | None = Field(None)
class GeminiImageGenerationConfig(GeminiGenerationConfig):
responseModalities: Optional[List[str]] = None
responseModalities: list[str] | None = Field(None)
imageConfig: GeminiImageConfig | None = Field(None)
class GeminiImageGenerateContentRequest(BaseModel):
contents: List[GeminiContent]
generationConfig: Optional[GeminiImageGenerationConfig] = None
safetySettings: Optional[List[GeminiSafetySetting]] = None
systemInstruction: Optional[GeminiSystemInstructionContent] = None
tools: Optional[List[GeminiTool]] = None
videoMetadata: Optional[GeminiVideoMetadata] = None
contents: list[GeminiContent] = Field(...)
generationConfig: GeminiImageGenerationConfig | None = Field(None)
safetySettings: list[GeminiSafetySetting] | None = Field(None)
systemInstruction: GeminiSystemInstructionContent | None = Field(None)
tools: list[GeminiTool] | None = Field(None)
videoMetadata: GeminiVideoMetadata | None = Field(None)
class GeminiGenerateContentRequest(BaseModel):
contents: list[GeminiContent] = Field(...)
generationConfig: GeminiGenerationConfig | None = Field(None)
safetySettings: list[GeminiSafetySetting] | None = Field(None)
systemInstruction: GeminiSystemInstructionContent | None = Field(None)
tools: list[GeminiTool] | None = Field(None)
videoMetadata: GeminiVideoMetadata | None = Field(None)
class Modality(str, Enum):
MODALITY_UNSPECIFIED = "MODALITY_UNSPECIFIED"
TEXT = "TEXT"
IMAGE = "IMAGE"
VIDEO = "VIDEO"
AUDIO = "AUDIO"
DOCUMENT = "DOCUMENT"
class ModalityTokenCount(BaseModel):
modality: Modality | None = None
tokenCount: int | None = Field(None, description="Number of tokens for the given modality.")
class Probability(str, Enum):
NEGLIGIBLE = "NEGLIGIBLE"
LOW = "LOW"
MEDIUM = "MEDIUM"
HIGH = "HIGH"
UNKNOWN = "UNKNOWN"
class GeminiSafetyRating(BaseModel):
category: GeminiSafetyCategory | None = None
probability: Probability | None = Field(
None,
description="The probability that the content violates the specified safety category",
)
class GeminiCitation(BaseModel):
authors: list[str] | None = None
endIndex: int | None = None
license: str | None = None
publicationDate: date | None = None
startIndex: int | None = None
title: str | None = None
uri: str | None = None
class GeminiCitationMetadata(BaseModel):
citations: list[GeminiCitation] | None = None
class GeminiCandidate(BaseModel):
citationMetadata: GeminiCitationMetadata | None = None
content: GeminiContent | None = None
finishReason: str | None = None
safetyRatings: list[GeminiSafetyRating] | None = None
class GeminiPromptFeedback(BaseModel):
blockReason: str | None = None
blockReasonMessage: str | None = None
safetyRatings: list[GeminiSafetyRating] | None = None
class GeminiUsageMetadata(BaseModel):
cachedContentTokenCount: int | None = Field(
None,
description="Output only. Number of tokens in the cached part in the input (the cached content).",
)
candidatesTokenCount: int | None = Field(None, description="Number of tokens in the response(s).")
candidatesTokensDetails: list[ModalityTokenCount] | None = Field(
None, description="Breakdown of candidate tokens by modality."
)
promptTokenCount: int | None = Field(
None,
description="Number of tokens in the request. When cachedContent is set, this is still the total effective prompt size meaning this includes the number of tokens in the cached content.",
)
promptTokensDetails: list[ModalityTokenCount] | None = Field(
None, description="Breakdown of prompt tokens by modality."
)
thoughtsTokenCount: int | None = Field(None, description="Number of tokens present in thoughts output.")
toolUsePromptTokenCount: int | None = Field(None, description="Number of tokens present in tool-use prompt(s).")
class GeminiGenerateContentResponse(BaseModel):
candidates: list[GeminiCandidate] | None = Field(None)
promptFeedback: GeminiPromptFeedback | None = Field(None)
usageMetadata: GeminiUsageMetadata | None = Field(None)
modelVersion: str | None = Field(None)

View File

@@ -0,0 +1,66 @@
from pydantic import BaseModel, Field
class OmniProText2VideoRequest(BaseModel):
model_name: str = Field(..., description="kling-video-o1")
aspect_ratio: str = Field(..., description="'16:9', '9:16' or '1:1'")
duration: str = Field(..., description="'5' or '10'")
prompt: str = Field(...)
mode: str = Field("pro")
class OmniParamImage(BaseModel):
image_url: str = Field(...)
type: str | None = Field(None, description="Can be 'first_frame' or 'end_frame'")
class OmniParamVideo(BaseModel):
video_url: str = Field(...)
refer_type: str | None = Field(..., description="Can be 'base' or 'feature'")
keep_original_sound: str = Field(..., description="'yes' or 'no'")
class OmniProFirstLastFrameRequest(BaseModel):
model_name: str = Field(..., description="kling-video-o1")
image_list: list[OmniParamImage] = Field(..., min_length=1, max_length=7)
duration: str = Field(..., description="'5' or '10'")
prompt: str = Field(...)
mode: str = Field("pro")
class OmniProReferences2VideoRequest(BaseModel):
model_name: str = Field(..., description="kling-video-o1")
aspect_ratio: str | None = Field(..., description="'16:9', '9:16' or '1:1'")
image_list: list[OmniParamImage] | None = Field(
None, max_length=7, description="Max length 4 when video is present."
)
video_list: list[OmniParamVideo] | None = Field(None, max_length=1)
duration: str | None = Field(..., description="From 3 to 10.")
prompt: str = Field(...)
mode: str = Field("pro")
class TaskStatusVideoResult(BaseModel):
duration: str | None = Field(None, description="Total video duration")
id: str | None = Field(None, description="Generated video ID")
url: str | None = Field(None, description="URL for generated video")
class TaskStatusVideoResults(BaseModel):
videos: list[TaskStatusVideoResult] | None = Field(None)
class TaskStatusVideoResponseData(BaseModel):
created_at: int | None = Field(None, description="Task creation time")
updated_at: int | None = Field(None, description="Task update time")
task_status: str | None = None
task_status_msg: str | None = Field(None, description="Additional failure reason. Only for polling endpoint.")
task_id: str | None = Field(None, description="Task ID")
task_result: TaskStatusVideoResults | None = Field(None)
class TaskStatusVideoResponse(BaseModel):
code: int | None = Field(None, description="Error code")
message: str | None = Field(None, description="Error message")
request_id: str | None = Field(None, description="Request ID")
data: TaskStatusVideoResponseData | None = Field(None)

View File

@@ -0,0 +1,120 @@
from enum import Enum
from typing import Optional
from pydantic import BaseModel, Field
class MinimaxBaseResponse(BaseModel):
status_code: int = Field(
...,
description='Status code. 0 indicates success, other values indicate errors.',
)
status_msg: str = Field(
..., description='Specific error details or success message.'
)
class File(BaseModel):
bytes: Optional[int] = Field(None, description='File size in bytes')
created_at: Optional[int] = Field(
None, description='Unix timestamp when the file was created, in seconds'
)
download_url: Optional[str] = Field(
None, description='The URL to download the video'
)
backup_download_url: Optional[str] = Field(
None, description='The backup URL to download the video'
)
file_id: Optional[int] = Field(None, description='Unique identifier for the file')
filename: Optional[str] = Field(None, description='The name of the file')
purpose: Optional[str] = Field(None, description='The purpose of using the file')
class MinimaxFileRetrieveResponse(BaseModel):
base_resp: MinimaxBaseResponse
file: File
class MiniMaxModel(str, Enum):
T2V_01_Director = 'T2V-01-Director'
I2V_01_Director = 'I2V-01-Director'
S2V_01 = 'S2V-01'
I2V_01 = 'I2V-01'
I2V_01_live = 'I2V-01-live'
T2V_01 = 'T2V-01'
Hailuo_02 = 'MiniMax-Hailuo-02'
class Status6(str, Enum):
Queueing = 'Queueing'
Preparing = 'Preparing'
Processing = 'Processing'
Success = 'Success'
Fail = 'Fail'
class MinimaxTaskResultResponse(BaseModel):
base_resp: MinimaxBaseResponse
file_id: Optional[str] = Field(
None,
description='After the task status changes to Success, this field returns the file ID corresponding to the generated video.',
)
status: Status6 = Field(
...,
description="Task status: 'Queueing' (in queue), 'Preparing' (task is preparing), 'Processing' (generating), 'Success' (task completed successfully), or 'Fail' (task failed).",
)
task_id: str = Field(..., description='The task ID being queried.')
class SubjectReferenceItem(BaseModel):
image: Optional[str] = Field(
None, description='URL or base64 encoding of the subject reference image.'
)
mask: Optional[str] = Field(
None,
description='URL or base64 encoding of the mask for the subject reference image.',
)
class MinimaxVideoGenerationRequest(BaseModel):
callback_url: Optional[str] = Field(
None,
description='Optional. URL to receive real-time status updates about the video generation task.',
)
first_frame_image: Optional[str] = Field(
None,
description='URL or base64 encoding of the first frame image. Required when model is I2V-01, I2V-01-Director, or I2V-01-live.',
)
model: MiniMaxModel = Field(
...,
description='Required. ID of model. Options: T2V-01-Director, I2V-01-Director, S2V-01, I2V-01, I2V-01-live, T2V-01',
)
prompt: Optional[str] = Field(
None,
description='Description of the video. Should be less than 2000 characters. Supports camera movement instructions in [brackets].',
max_length=2000,
)
prompt_optimizer: Optional[bool] = Field(
True,
description='If true (default), the model will automatically optimize the prompt. Set to false for more precise control.',
)
subject_reference: Optional[list[SubjectReferenceItem]] = Field(
None,
description='Only available when model is S2V-01. The model will generate a video based on the subject uploaded through this parameter.',
)
duration: Optional[int] = Field(
None,
description="The length of the output video in seconds."
)
resolution: Optional[str] = Field(
None,
description="The dimensions of the video display. 1080p corresponds to 1920 x 1080 pixels, 768p corresponds to 1366 x 768 pixels."
)
class MinimaxVideoGenerationResponse(BaseModel):
base_resp: MinimaxBaseResponse
task_id: str = Field(
..., description='The task ID for the asynchronous video generation task.'
)

View File

@@ -0,0 +1,100 @@
from typing import Optional
from enum import Enum
from pydantic import BaseModel, Field
class Pikaffect(str, Enum):
Cake_ify = "Cake-ify"
Crumble = "Crumble"
Crush = "Crush"
Decapitate = "Decapitate"
Deflate = "Deflate"
Dissolve = "Dissolve"
Explode = "Explode"
Eye_pop = "Eye-pop"
Inflate = "Inflate"
Levitate = "Levitate"
Melt = "Melt"
Peel = "Peel"
Poke = "Poke"
Squish = "Squish"
Ta_da = "Ta-da"
Tear = "Tear"
class PikaBodyGenerate22C2vGenerate22PikascenesPost(BaseModel):
aspectRatio: Optional[float] = Field(None, description='Aspect ratio (width / height)')
duration: Optional[int] = Field(5)
ingredientsMode: str = Field(...)
negativePrompt: Optional[str] = Field(None)
promptText: Optional[str] = Field(None)
resolution: Optional[str] = Field('1080p')
seed: Optional[int] = Field(None)
class PikaGenerateResponse(BaseModel):
video_id: str = Field(...)
class PikaBodyGenerate22I2vGenerate22I2vPost(BaseModel):
duration: Optional[int] = 5
negativePrompt: Optional[str] = Field(None)
promptText: Optional[str] = Field(None)
resolution: Optional[str] = '1080p'
seed: Optional[int] = Field(None)
class PikaBodyGenerate22KeyframeGenerate22PikaframesPost(BaseModel):
duration: Optional[int] = Field(None, ge=5, le=10)
negativePrompt: Optional[str] = Field(None)
promptText: str = Field(...)
resolution: Optional[str] = '1080p'
seed: Optional[int] = Field(None)
class PikaBodyGenerate22T2vGenerate22T2vPost(BaseModel):
aspectRatio: Optional[float] = Field(
1.7777777777777777,
description='Aspect ratio (width / height)',
ge=0.4,
le=2.5,
)
duration: Optional[int] = 5
negativePrompt: Optional[str] = Field(None)
promptText: str = Field(...)
resolution: Optional[str] = '1080p'
seed: Optional[int] = Field(None)
class PikaBodyGeneratePikadditionsGeneratePikadditionsPost(BaseModel):
negativePrompt: Optional[str] = Field(None)
promptText: Optional[str] = Field(None)
seed: Optional[int] = Field(None)
class PikaBodyGeneratePikaffectsGeneratePikaffectsPost(BaseModel):
negativePrompt: Optional[str] = Field(None)
pikaffect: Optional[str] = None
promptText: Optional[str] = Field(None)
seed: Optional[int] = Field(None)
class PikaBodyGeneratePikaswapsGeneratePikaswapsPost(BaseModel):
negativePrompt: Optional[str] = Field(None)
promptText: Optional[str] = Field(None)
seed: Optional[int] = Field(None)
modifyRegionRoi: Optional[str] = Field(None)
class PikaStatusEnum(str, Enum):
queued = "queued"
started = "started"
finished = "finished"
failed = "failed"
class PikaVideoResponse(BaseModel):
id: str = Field(...)
progress: Optional[int] = Field(None)
status: PikaStatusEnum
url: Optional[str] = Field(None)

View File

@@ -0,0 +1,133 @@
from typing import Optional, Union
from pydantic import BaseModel, Field
class ImageEnhanceRequest(BaseModel):
model: str = Field("Reimagine")
output_format: str = Field("jpeg")
subject_detection: str = Field("All")
face_enhancement: bool = Field(True)
face_enhancement_creativity: float = Field(0, description="Is ignored if face_enhancement is false")
face_enhancement_strength: float = Field(0.8, description="Is ignored if face_enhancement is false")
source_url: str = Field(...)
output_width: Optional[int] = Field(None)
output_height: Optional[int] = Field(None)
crop_to_fill: bool = Field(False)
prompt: Optional[str] = Field(None, description="Text prompt for creative upscaling guidance")
creativity: int = Field(3, description="Creativity settings range from 1 to 9")
face_preservation: str = Field("true", description="To preserve the identity of characters")
color_preservation: str = Field("true", description="To preserve the original color")
class ImageAsyncTaskResponse(BaseModel):
process_id: str = Field(...)
class ImageStatusResponse(BaseModel):
process_id: str = Field(...)
status: str = Field(...)
progress: Optional[int] = Field(None)
credits: int = Field(...)
class ImageDownloadResponse(BaseModel):
download_url: str = Field(...)
expiry: int = Field(...)
class Resolution(BaseModel):
width: int = Field(...)
height: int = Field(...)
class CreateCreateVideoRequestSource(BaseModel):
container: str = Field(...)
size: int = Field(..., description="Size of the video file in bytes")
duration: int = Field(..., description="Duration of the video file in seconds")
frameCount: int = Field(..., description="Total number of frames in the video")
frameRate: int = Field(...)
resolution: Resolution = Field(...)
class VideoFrameInterpolationFilter(BaseModel):
model: str = Field(...)
slowmo: Optional[int] = Field(None)
fps: int = Field(...)
duplicate: bool = Field(...)
duplicate_threshold: float = Field(...)
class VideoEnhancementFilter(BaseModel):
model: str = Field(...)
auto: Optional[str] = Field(None, description="Auto, Manual, Relative")
focusFixLevel: Optional[str] = Field(None, description="Downscales video input for correction of blurred subjects")
compression: Optional[float] = Field(None, description="Strength of compression recovery")
details: Optional[float] = Field(None, description="Amount of detail reconstruction")
prenoise: Optional[float] = Field(None, description="Amount of noise to add to input to reduce over-smoothing")
noise: Optional[float] = Field(None, description="Amount of noise reduction")
halo: Optional[float] = Field(None, description="Amount of halo reduction")
preblur: Optional[float] = Field(None, description="Anti-aliasing and deblurring strength")
blur: Optional[float] = Field(None, description="Amount of sharpness applied")
grain: Optional[float] = Field(None, description="Grain after AI model processing")
grainSize: Optional[float] = Field(None, description="Size of generated grain")
recoverOriginalDetailValue: Optional[float] = Field(None, description="Source details into the output video")
creativity: Optional[str] = Field(None, description="Creativity level(high, low) for slc-1 only")
isOptimizedMode: Optional[bool] = Field(None, description="Set to true for Starlight Creative (slc-1) only")
class OutputInformationVideo(BaseModel):
resolution: Resolution = Field(...)
frameRate: int = Field(...)
audioCodec: Optional[str] = Field(..., description="Required if audioTransfer is Copy or Convert")
audioTransfer: str = Field(..., description="Copy, Convert, None")
dynamicCompressionLevel: str = Field(..., description="Low, Mid, High")
class Overrides(BaseModel):
isPaidDiffusion: bool = Field(True)
class CreateVideoRequest(BaseModel):
source: CreateCreateVideoRequestSource = Field(...)
filters: list[Union[VideoFrameInterpolationFilter, VideoEnhancementFilter]] = Field(...)
output: OutputInformationVideo = Field(...)
overrides: Overrides = Field(Overrides(isPaidDiffusion=True))
class CreateVideoResponse(BaseModel):
requestId: str = Field(...)
class VideoAcceptResponse(BaseModel):
uploadId: str = Field(...)
urls: list[str] = Field(...)
class VideoCompleteUploadRequestPart(BaseModel):
partNum: int = Field(...)
eTag: str = Field(...)
class VideoCompleteUploadRequest(BaseModel):
uploadResults: list[VideoCompleteUploadRequestPart] = Field(...)
class VideoCompleteUploadResponse(BaseModel):
message: str = Field(..., description="Confirmation message")
class VideoStatusResponseEstimates(BaseModel):
cost: list[int] = Field(...)
class VideoStatusResponseDownloadUrl(BaseModel):
url: str = Field(...)
class VideoStatusResponse(BaseModel):
status: str = Field(...)
estimates: Optional[VideoStatusResponseEstimates] = Field(None)
progress: Optional[float] = Field(None)
message: Optional[str] = Field("")
download: Optional[VideoStatusResponseDownloadUrl] = Field(None)

View File

@@ -1,13 +1,20 @@
from __future__ import annotations
from comfy_api_nodes.apis import (
TripoModelVersion,
TripoTextureQuality,
)
from enum import Enum
from typing import Optional, List, Dict, Any, Union
from pydantic import BaseModel, Field, RootModel
class TripoModelVersion(str, Enum):
v2_5_20250123 = 'v2.5-20250123'
v2_0_20240919 = 'v2.0-20240919'
v1_4_20240625 = 'v1.4-20240625'
class TripoTextureQuality(str, Enum):
standard = 'standard'
detailed = 'detailed'
class TripoStyle(str, Enum):
PERSON_TO_CARTOON = "person:person2cartoon"
ANIMAL_VENOM = "animal:venom"

View File

@@ -0,0 +1,99 @@
from typing import Optional
from pydantic import BaseModel, Field
class VeoRequestInstanceImage(BaseModel):
bytesBase64Encoded: str | None = Field(None)
gcsUri: str | None = Field(None)
mimeType: str | None = Field(None)
class VeoRequestInstance(BaseModel):
image: VeoRequestInstanceImage | None = Field(None)
lastFrame: VeoRequestInstanceImage | None = Field(None)
prompt: str = Field(..., description='Text description of the video')
class VeoRequestParameters(BaseModel):
aspectRatio: Optional[str] = Field(None, examples=['16:9'])
durationSeconds: Optional[int] = None
enhancePrompt: Optional[bool] = None
generateAudio: Optional[bool] = Field(
None,
description='Generate audio for the video. Only supported by veo 3 models.',
)
negativePrompt: Optional[str] = None
personGeneration: str | None = Field(None, description="ALLOW or BLOCK")
sampleCount: Optional[int] = None
seed: Optional[int] = None
storageUri: Optional[str] = Field(
None, description='Optional Cloud Storage URI to upload the video'
)
resolution: str | None = Field(None)
class VeoGenVidRequest(BaseModel):
instances: list[VeoRequestInstance] | None = Field(None)
parameters: VeoRequestParameters | None = Field(None)
class VeoGenVidResponse(BaseModel):
name: str = Field(
...,
description='Operation resource name',
examples=[
'projects/PROJECT_ID/locations/us-central1/publishers/google/models/MODEL_ID/operations/a1b07c8e-7b5a-4aba-bb34-3e1ccb8afcc8'
],
)
class VeoGenVidPollRequest(BaseModel):
operationName: str = Field(
...,
description='Full operation name (from predict response)',
examples=[
'projects/PROJECT_ID/locations/us-central1/publishers/google/models/MODEL_ID/operations/OPERATION_ID'
],
)
class Video(BaseModel):
bytesBase64Encoded: Optional[str] = Field(
None, description='Base64-encoded video content'
)
gcsUri: Optional[str] = Field(None, description='Cloud Storage URI of the video')
mimeType: Optional[str] = Field(None, description='Video MIME type')
class Error1(BaseModel):
code: Optional[int] = Field(None, description='Error code')
message: Optional[str] = Field(None, description='Error message')
class Response1(BaseModel):
field_type: Optional[str] = Field(
None,
alias='@type',
examples=[
'type.googleapis.com/cloud.ai.large_models.vision.GenerateVideoResponse'
],
)
raiMediaFilteredCount: Optional[int] = Field(
None, description='Count of media filtered by responsible AI policies'
)
raiMediaFilteredReasons: Optional[list[str]] = Field(
None, description='Reasons why media was filtered by responsible AI policies'
)
videos: Optional[list[Video]] = None
class VeoGenVidPollResponse(BaseModel):
done: Optional[bool] = None
error: Optional[Error1] = Field(
None, description='Error details if operation failed'
)
name: Optional[str] = None
response: Optional[Response1] = Field(
None, description='The actual prediction response if done is true'
)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -1,6 +1,6 @@
from io import BytesIO
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io as comfy_io
from comfy_api.latest import IO, ComfyExtension
from PIL import Image
import numpy as np
import torch
@@ -11,19 +11,13 @@ from comfy_api_nodes.apis import (
IdeogramV3Request,
IdeogramV3EditRequest,
)
from comfy_api_nodes.apis.client import (
from comfy_api_nodes.util import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
)
from comfy_api_nodes.apinode_utils import (
download_url_to_bytesio,
bytesio_to_image_tensor,
download_url_as_bytesio,
resize_mask_to_image,
sync_op,
)
from server import PromptServer
V1_V1_RES_MAP = {
"Auto":"AUTO",
@@ -220,7 +214,7 @@ async def download_and_process_images(image_urls):
for image_url in image_urls:
# Using functions from apinode_utils.py to handle downloading and processing
image_bytesio = await download_url_to_bytesio(image_url) # Download image content to BytesIO
image_bytesio = await download_url_as_bytesio(image_url) # Download image content to BytesIO
img_tensor = bytesio_to_image_tensor(image_bytesio, mode="RGB") # Convert to torch.Tensor with RGB mode
image_tensors.append(img_tensor)
@@ -233,89 +227,76 @@ async def download_and_process_images(image_urls):
return stacked_tensors
def display_image_urls_on_node(image_urls, node_id):
if node_id and image_urls:
if len(image_urls) == 1:
PromptServer.instance.send_progress_text(
f"Generated Image URL:\n{image_urls[0]}", node_id
)
else:
urls_text = "Generated Image URLs:\n" + "\n".join(
f"{i+1}. {url}" for i, url in enumerate(image_urls)
)
PromptServer.instance.send_progress_text(urls_text, node_id)
class IdeogramV1(comfy_io.ComfyNode):
class IdeogramV1(IO.ComfyNode):
@classmethod
def define_schema(cls):
return comfy_io.Schema(
return IO.Schema(
node_id="IdeogramV1",
display_name="Ideogram V1",
category="api node/image/Ideogram",
description="Generates images using the Ideogram V1 model.",
is_api_node=True,
inputs=[
comfy_io.String.Input(
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt for the image generation",
),
comfy_io.Boolean.Input(
IO.Boolean.Input(
"turbo",
default=False,
tooltip="Whether to use turbo mode (faster generation, potentially lower quality)",
),
comfy_io.Combo.Input(
IO.Combo.Input(
"aspect_ratio",
options=list(V1_V2_RATIO_MAP.keys()),
default="1:1",
tooltip="The aspect ratio for image generation.",
optional=True,
),
comfy_io.Combo.Input(
IO.Combo.Input(
"magic_prompt_option",
options=["AUTO", "ON", "OFF"],
default="AUTO",
tooltip="Determine if MagicPrompt should be used in generation",
optional=True,
),
comfy_io.Int.Input(
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
control_after_generate=True,
display_mode=comfy_io.NumberDisplay.number,
display_mode=IO.NumberDisplay.number,
optional=True,
),
comfy_io.String.Input(
IO.String.Input(
"negative_prompt",
multiline=True,
default="",
tooltip="Description of what to exclude from the image",
optional=True,
),
comfy_io.Int.Input(
IO.Int.Input(
"num_images",
default=1,
min=1,
max=8,
step=1,
display_mode=comfy_io.NumberDisplay.number,
display_mode=IO.NumberDisplay.number,
optional=True,
),
],
outputs=[
comfy_io.Image.Output(),
IO.Image.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
)
@@ -334,77 +315,63 @@ class IdeogramV1(comfy_io.ComfyNode):
aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None)
model = "V_1_TURBO" if turbo else "V_1"
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/ideogram/generate",
method=HttpMethod.POST,
request_model=IdeogramGenerateRequest,
response_model=IdeogramGenerateResponse,
),
request=IdeogramGenerateRequest(
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/ideogram/generate", method="POST"),
response_model=IdeogramGenerateResponse,
data=IdeogramGenerateRequest(
image_request=ImageRequest(
prompt=prompt,
model=model,
num_images=num_images,
seed=seed,
aspect_ratio=aspect_ratio if aspect_ratio != "ASPECT_1_1" else None,
magic_prompt_option=(
magic_prompt_option if magic_prompt_option != "AUTO" else None
),
magic_prompt_option=(magic_prompt_option if magic_prompt_option != "AUTO" else None),
negative_prompt=negative_prompt if negative_prompt else None,
)
),
auth_kwargs=auth,
max_retries=1,
)
response = await operation.execute()
if not response.data or len(response.data) == 0:
raise Exception("No images were generated in the response")
image_urls = [image_data.url for image_data in response.data if image_data.url]
if not image_urls:
raise Exception("No image URLs were generated in the response")
display_image_urls_on_node(image_urls, cls.hidden.unique_id)
return comfy_io.NodeOutput(await download_and_process_images(image_urls))
return IO.NodeOutput(await download_and_process_images(image_urls))
class IdeogramV2(comfy_io.ComfyNode):
class IdeogramV2(IO.ComfyNode):
@classmethod
def define_schema(cls):
return comfy_io.Schema(
return IO.Schema(
node_id="IdeogramV2",
display_name="Ideogram V2",
category="api node/image/Ideogram",
description="Generates images using the Ideogram V2 model.",
is_api_node=True,
inputs=[
comfy_io.String.Input(
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt for the image generation",
),
comfy_io.Boolean.Input(
IO.Boolean.Input(
"turbo",
default=False,
tooltip="Whether to use turbo mode (faster generation, potentially lower quality)",
),
comfy_io.Combo.Input(
IO.Combo.Input(
"aspect_ratio",
options=list(V1_V2_RATIO_MAP.keys()),
default="1:1",
tooltip="The aspect ratio for image generation. Ignored if resolution is not set to AUTO.",
optional=True,
),
comfy_io.Combo.Input(
IO.Combo.Input(
"resolution",
options=list(V1_V1_RES_MAP.keys()),
default="Auto",
@@ -412,44 +379,44 @@ class IdeogramV2(comfy_io.ComfyNode):
"If not set to AUTO, this overrides the aspect_ratio setting.",
optional=True,
),
comfy_io.Combo.Input(
IO.Combo.Input(
"magic_prompt_option",
options=["AUTO", "ON", "OFF"],
default="AUTO",
tooltip="Determine if MagicPrompt should be used in generation",
optional=True,
),
comfy_io.Int.Input(
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
control_after_generate=True,
display_mode=comfy_io.NumberDisplay.number,
display_mode=IO.NumberDisplay.number,
optional=True,
),
comfy_io.Combo.Input(
IO.Combo.Input(
"style_type",
options=["AUTO", "GENERAL", "REALISTIC", "DESIGN", "RENDER_3D", "ANIME"],
default="NONE",
tooltip="Style type for generation (V2 only)",
optional=True,
),
comfy_io.String.Input(
IO.String.Input(
"negative_prompt",
multiline=True,
default="",
tooltip="Description of what to exclude from the image",
optional=True,
),
comfy_io.Int.Input(
IO.Int.Input(
"num_images",
default=1,
min=1,
max=8,
step=1,
display_mode=comfy_io.NumberDisplay.number,
display_mode=IO.NumberDisplay.number,
optional=True,
),
#"color_palette": (
@@ -462,12 +429,12 @@ class IdeogramV2(comfy_io.ComfyNode):
#),
],
outputs=[
comfy_io.Image.Output(),
IO.Image.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
)
@@ -500,18 +467,11 @@ class IdeogramV2(comfy_io.ComfyNode):
else:
final_aspect_ratio = aspect_ratio if aspect_ratio != "ASPECT_1_1" else None
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/ideogram/generate",
method=HttpMethod.POST,
request_model=IdeogramGenerateRequest,
response_model=IdeogramGenerateResponse,
),
request=IdeogramGenerateRequest(
response = await sync_op(
cls,
endpoint=ApiEndpoint(path="/proxy/ideogram/generate", method="POST"),
response_model=IdeogramGenerateResponse,
data=IdeogramGenerateRequest(
image_request=ImageRequest(
prompt=prompt,
model=model,
@@ -519,36 +479,28 @@ class IdeogramV2(comfy_io.ComfyNode):
seed=seed,
aspect_ratio=final_aspect_ratio,
resolution=final_resolution,
magic_prompt_option=(
magic_prompt_option if magic_prompt_option != "AUTO" else None
),
magic_prompt_option=(magic_prompt_option if magic_prompt_option != "AUTO" else None),
style_type=style_type if style_type != "NONE" else None,
negative_prompt=negative_prompt if negative_prompt else None,
color_palette=color_palette if color_palette else None,
)
),
auth_kwargs=auth,
max_retries=1,
)
response = await operation.execute()
if not response.data or len(response.data) == 0:
raise Exception("No images were generated in the response")
image_urls = [image_data.url for image_data in response.data if image_data.url]
if not image_urls:
raise Exception("No image URLs were generated in the response")
display_image_urls_on_node(image_urls, cls.hidden.unique_id)
return comfy_io.NodeOutput(await download_and_process_images(image_urls))
return IO.NodeOutput(await download_and_process_images(image_urls))
class IdeogramV3(comfy_io.ComfyNode):
class IdeogramV3(IO.ComfyNode):
@classmethod
def define_schema(cls):
return comfy_io.Schema(
return IO.Schema(
node_id="IdeogramV3",
display_name="Ideogram V3",
category="api node/image/Ideogram",
@@ -556,30 +508,30 @@ class IdeogramV3(comfy_io.ComfyNode):
"Supports both regular image generation from text prompts and image editing with mask.",
is_api_node=True,
inputs=[
comfy_io.String.Input(
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt for the image generation or editing",
),
comfy_io.Image.Input(
IO.Image.Input(
"image",
tooltip="Optional reference image for image editing.",
optional=True,
),
comfy_io.Mask.Input(
IO.Mask.Input(
"mask",
tooltip="Optional mask for inpainting (white areas will be replaced)",
optional=True,
),
comfy_io.Combo.Input(
IO.Combo.Input(
"aspect_ratio",
options=list(V3_RATIO_MAP.keys()),
default="1:1",
tooltip="The aspect ratio for image generation. Ignored if resolution is not set to Auto.",
optional=True,
),
comfy_io.Combo.Input(
IO.Combo.Input(
"resolution",
options=V3_RESOLUTIONS,
default="Auto",
@@ -587,57 +539,57 @@ class IdeogramV3(comfy_io.ComfyNode):
"If not set to Auto, this overrides the aspect_ratio setting.",
optional=True,
),
comfy_io.Combo.Input(
IO.Combo.Input(
"magic_prompt_option",
options=["AUTO", "ON", "OFF"],
default="AUTO",
tooltip="Determine if MagicPrompt should be used in generation",
optional=True,
),
comfy_io.Int.Input(
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
control_after_generate=True,
display_mode=comfy_io.NumberDisplay.number,
display_mode=IO.NumberDisplay.number,
optional=True,
),
comfy_io.Int.Input(
IO.Int.Input(
"num_images",
default=1,
min=1,
max=8,
step=1,
display_mode=comfy_io.NumberDisplay.number,
display_mode=IO.NumberDisplay.number,
optional=True,
),
comfy_io.Combo.Input(
IO.Combo.Input(
"rendering_speed",
options=["DEFAULT", "TURBO", "QUALITY"],
default="DEFAULT",
tooltip="Controls the trade-off between generation speed and quality",
optional=True,
),
comfy_io.Image.Input(
IO.Image.Input(
"character_image",
tooltip="Image to use as character reference.",
optional=True,
),
comfy_io.Mask.Input(
IO.Mask.Input(
"character_mask",
tooltip="Optional mask for character reference image.",
optional=True,
),
],
outputs=[
comfy_io.Image.Output(),
IO.Image.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
)
@@ -656,10 +608,6 @@ class IdeogramV3(comfy_io.ComfyNode):
character_image=None,
character_mask=None,
):
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
if rendering_speed == "BALANCED": # for backward compatibility
rendering_speed = "DEFAULT"
@@ -694,9 +642,6 @@ class IdeogramV3(comfy_io.ComfyNode):
# Check if both image and mask are provided for editing mode
if image is not None and mask is not None:
# Edit mode
path = "/proxy/ideogram/ideogram-v3/edit"
# Process image and mask
input_tensor = image.squeeze().cpu()
# Resize mask to match image dimension
@@ -749,27 +694,20 @@ class IdeogramV3(comfy_io.ComfyNode):
if character_mask_binary:
files["character_mask_binary"] = character_mask_binary
# Execute the operation for edit mode
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=path,
method=HttpMethod.POST,
request_model=IdeogramV3EditRequest,
response_model=IdeogramGenerateResponse,
),
request=edit_request,
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/ideogram/ideogram-v3/edit", method="POST"),
response_model=IdeogramGenerateResponse,
data=edit_request,
files=files,
content_type="multipart/form-data",
auth_kwargs=auth,
max_retries=1,
)
elif image is not None or mask is not None:
# If only one of image or mask is provided, raise an error
raise Exception("Ideogram V3 image editing requires both an image AND a mask")
else:
# Generation mode
path = "/proxy/ideogram/ideogram-v3/generate"
# Create generation request
gen_request = IdeogramV3Request(
prompt=prompt,
@@ -800,43 +738,34 @@ class IdeogramV3(comfy_io.ComfyNode):
if files:
gen_request.style_type = "AUTO"
# Execute the operation for generation mode
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=path,
method=HttpMethod.POST,
request_model=IdeogramV3Request,
response_model=IdeogramGenerateResponse,
),
request=gen_request,
response = await sync_op(
cls,
endpoint=ApiEndpoint(path="/proxy/ideogram/ideogram-v3/generate", method="POST"),
response_model=IdeogramGenerateResponse,
data=gen_request,
files=files if files else None,
content_type="multipart/form-data",
auth_kwargs=auth,
max_retries=1,
)
# Execute the operation and process response
response = await operation.execute()
if not response.data or len(response.data) == 0:
raise Exception("No images were generated in the response")
image_urls = [image_data.url for image_data in response.data if image_data.url]
if not image_urls:
raise Exception("No image URLs were generated in the response")
display_image_urls_on_node(image_urls, cls.hidden.unique_id)
return comfy_io.NodeOutput(await download_and_process_images(image_urls))
return IO.NodeOutput(await download_and_process_images(image_urls))
class IdeogramExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
IdeogramV1,
IdeogramV2,
IdeogramV3,
]
async def comfy_entrypoint() -> IdeogramExtension:
return IdeogramExtension()

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,199 @@
from io import BytesIO
from typing import Optional
import torch
from pydantic import BaseModel, Field
from typing_extensions import override
from comfy_api.input_impl import VideoFromFile
from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.util import (
ApiEndpoint,
get_number_of_images,
sync_op_raw,
upload_images_to_comfyapi,
validate_string,
)
MODELS_MAP = {
"LTX-2 (Pro)": "ltx-2-pro",
"LTX-2 (Fast)": "ltx-2-fast",
}
class ExecuteTaskRequest(BaseModel):
prompt: str = Field(...)
model: str = Field(...)
duration: int = Field(...)
resolution: str = Field(...)
fps: Optional[int] = Field(25)
generate_audio: Optional[bool] = Field(True)
image_uri: Optional[str] = Field(None)
class TextToVideoNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="LtxvApiTextToVideo",
display_name="LTXV Text To Video",
category="api node/video/LTXV",
description="Professional-quality videos with customizable duration and resolution.",
inputs=[
IO.Combo.Input("model", options=list(MODELS_MAP.keys())),
IO.String.Input(
"prompt",
multiline=True,
default="",
),
IO.Combo.Input("duration", options=[6, 8, 10, 12, 14, 16, 18, 20], default=8),
IO.Combo.Input(
"resolution",
options=[
"1920x1080",
"2560x1440",
"3840x2160",
],
),
IO.Combo.Input("fps", options=[25, 50], default=25),
IO.Boolean.Input(
"generate_audio",
default=False,
optional=True,
tooltip="When true, the generated video will include AI-generated audio matching the scene.",
),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
model: str,
prompt: str,
duration: int,
resolution: str,
fps: int = 25,
generate_audio: bool = False,
) -> IO.NodeOutput:
validate_string(prompt, min_length=1, max_length=10000)
if duration > 10 and (model != "LTX-2 (Fast)" or resolution != "1920x1080" or fps != 25):
raise ValueError(
"Durations over 10s are only available for the Fast model at 1920x1080 resolution and 25 FPS."
)
response = await sync_op_raw(
cls,
ApiEndpoint("/proxy/ltx/v1/text-to-video", "POST"),
data=ExecuteTaskRequest(
prompt=prompt,
model=MODELS_MAP[model],
duration=duration,
resolution=resolution,
fps=fps,
generate_audio=generate_audio,
),
as_binary=True,
max_retries=1,
)
return IO.NodeOutput(VideoFromFile(BytesIO(response)))
class ImageToVideoNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="LtxvApiImageToVideo",
display_name="LTXV Image To Video",
category="api node/video/LTXV",
description="Professional-quality videos with customizable duration and resolution based on start image.",
inputs=[
IO.Image.Input("image", tooltip="First frame to be used for the video."),
IO.Combo.Input("model", options=list(MODELS_MAP.keys())),
IO.String.Input(
"prompt",
multiline=True,
default="",
),
IO.Combo.Input("duration", options=[6, 8, 10, 12, 14, 16, 18, 20], default=8),
IO.Combo.Input(
"resolution",
options=[
"1920x1080",
"2560x1440",
"3840x2160",
],
),
IO.Combo.Input("fps", options=[25, 50], default=25),
IO.Boolean.Input(
"generate_audio",
default=False,
optional=True,
tooltip="When true, the generated video will include AI-generated audio matching the scene.",
),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
image: torch.Tensor,
model: str,
prompt: str,
duration: int,
resolution: str,
fps: int = 25,
generate_audio: bool = False,
) -> IO.NodeOutput:
validate_string(prompt, min_length=1, max_length=10000)
if duration > 10 and (model != "LTX-2 (Fast)" or resolution != "1920x1080" or fps != 25):
raise ValueError(
"Durations over 10s are only available for the Fast model at 1920x1080 resolution and 25 FPS."
)
if get_number_of_images(image) != 1:
raise ValueError("Currently only one input image is supported.")
response = await sync_op_raw(
cls,
ApiEndpoint("/proxy/ltx/v1/image-to-video", "POST"),
data=ExecuteTaskRequest(
image_uri=(await upload_images_to_comfyapi(cls, image, max_images=1, mime_type="image/png"))[0],
prompt=prompt,
model=MODELS_MAP[model],
duration=duration,
resolution=resolution,
fps=fps,
generate_audio=generate_audio,
),
as_binary=True,
max_retries=1,
)
return IO.NodeOutput(VideoFromFile(BytesIO(response)))
class LtxvApiExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
TextToVideoNode,
ImageToVideoNode,
]
async def comfy_entrypoint() -> LtxvApiExtension:
return LtxvApiExtension()

View File

@@ -1,75 +1,57 @@
from __future__ import annotations
from inspect import cleandoc
from typing import Optional
import torch
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io as comfy_io
from comfy_api.input_impl.video_types import VideoFromFile
from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.apis.luma_api import (
LumaImageModel,
LumaVideoModel,
LumaVideoOutputResolution,
LumaVideoModelOutputDuration,
LumaAspectRatio,
LumaState,
LumaImageGenerationRequest,
LumaGenerationRequest,
LumaGeneration,
LumaCharacterRef,
LumaModifyImageRef,
LumaConceptChain,
LumaGeneration,
LumaGenerationRequest,
LumaImageGenerationRequest,
LumaImageIdentity,
LumaImageModel,
LumaImageReference,
LumaIO,
LumaKeyframes,
LumaModifyImageRef,
LumaReference,
LumaReferenceChain,
LumaImageReference,
LumaKeyframes,
LumaConceptChain,
LumaIO,
LumaVideoModel,
LumaVideoModelOutputDuration,
LumaVideoOutputResolution,
get_luma_concepts,
)
from comfy_api_nodes.apis.client import (
from comfy_api_nodes.util import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
PollingOperation,
EmptyRequest,
)
from comfy_api_nodes.apinode_utils import (
download_url_to_image_tensor,
download_url_to_video_output,
poll_op,
sync_op,
upload_images_to_comfyapi,
process_image_response,
validate_string,
)
from server import PromptServer
import aiohttp
import torch
from io import BytesIO
LUMA_T2V_AVERAGE_DURATION = 105
LUMA_I2V_AVERAGE_DURATION = 100
def image_result_url_extractor(response: LumaGeneration):
return response.assets.image if hasattr(response, "assets") and hasattr(response.assets, "image") else None
def video_result_url_extractor(response: LumaGeneration):
return response.assets.video if hasattr(response, "assets") and hasattr(response.assets, "video") else None
class LumaReferenceNode(comfy_io.ComfyNode):
"""
Holds an image and weight for use with Luma Generate Image node.
"""
class LumaReferenceNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="LumaReferenceNode",
display_name="Luma Reference",
category="api node/image/Luma",
description=cleandoc(cls.__doc__ or ""),
description="Holds an image and weight for use with Luma Generate Image node.",
inputs=[
comfy_io.Image.Input(
IO.Image.Input(
"image",
tooltip="Image to use as reference.",
),
comfy_io.Float.Input(
IO.Float.Input(
"weight",
default=1.0,
min=0.0,
@@ -77,72 +59,56 @@ class LumaReferenceNode(comfy_io.ComfyNode):
step=0.01,
tooltip="Weight of image reference.",
),
comfy_io.Custom(LumaIO.LUMA_REF).Input(
IO.Custom(LumaIO.LUMA_REF).Input(
"luma_ref",
optional=True,
),
],
outputs=[comfy_io.Custom(LumaIO.LUMA_REF).Output(display_name="luma_ref")],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
outputs=[IO.Custom(LumaIO.LUMA_REF).Output(display_name="luma_ref")],
)
@classmethod
def execute(
cls, image: torch.Tensor, weight: float, luma_ref: LumaReferenceChain = None
) -> comfy_io.NodeOutput:
def execute(cls, image: torch.Tensor, weight: float, luma_ref: LumaReferenceChain = None) -> IO.NodeOutput:
if luma_ref is not None:
luma_ref = luma_ref.clone()
else:
luma_ref = LumaReferenceChain()
luma_ref.add(LumaReference(image=image, weight=round(weight, 2)))
return comfy_io.NodeOutput(luma_ref)
return IO.NodeOutput(luma_ref)
class LumaConceptsNode(comfy_io.ComfyNode):
"""
Holds one or more Camera Concepts for use with Luma Text to Video and Luma Image to Video nodes.
"""
class LumaConceptsNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="LumaConceptsNode",
display_name="Luma Concepts",
category="api node/video/Luma",
description=cleandoc(cls.__doc__ or ""),
description="Camera Concepts for use with Luma Text to Video and Luma Image to Video nodes.",
inputs=[
comfy_io.Combo.Input(
IO.Combo.Input(
"concept1",
options=get_luma_concepts(include_none=True),
),
comfy_io.Combo.Input(
IO.Combo.Input(
"concept2",
options=get_luma_concepts(include_none=True),
),
comfy_io.Combo.Input(
IO.Combo.Input(
"concept3",
options=get_luma_concepts(include_none=True),
),
comfy_io.Combo.Input(
IO.Combo.Input(
"concept4",
options=get_luma_concepts(include_none=True),
),
comfy_io.Custom(LumaIO.LUMA_CONCEPTS).Input(
IO.Custom(LumaIO.LUMA_CONCEPTS).Input(
"luma_concepts",
tooltip="Optional Camera Concepts to add to the ones chosen here.",
optional=True,
),
],
outputs=[comfy_io.Custom(LumaIO.LUMA_CONCEPTS).Output(display_name="luma_concepts")],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
outputs=[IO.Custom(LumaIO.LUMA_CONCEPTS).Output(display_name="luma_concepts")],
)
@classmethod
@@ -153,42 +119,38 @@ class LumaConceptsNode(comfy_io.ComfyNode):
concept3: str,
concept4: str,
luma_concepts: LumaConceptChain = None,
) -> comfy_io.NodeOutput:
) -> IO.NodeOutput:
chain = LumaConceptChain(str_list=[concept1, concept2, concept3, concept4])
if luma_concepts is not None:
chain = luma_concepts.clone_and_merge(chain)
return comfy_io.NodeOutput(chain)
return IO.NodeOutput(chain)
class LumaImageGenerationNode(comfy_io.ComfyNode):
"""
Generates images synchronously based on prompt and aspect ratio.
"""
class LumaImageGenerationNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="LumaImageNode",
display_name="Luma Text to Image",
category="api node/image/Luma",
description=cleandoc(cls.__doc__ or ""),
description="Generates images synchronously based on prompt and aspect ratio.",
inputs=[
comfy_io.String.Input(
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt for the image generation",
),
comfy_io.Combo.Input(
IO.Combo.Input(
"model",
options=[model.value for model in LumaImageModel],
options=LumaImageModel,
),
comfy_io.Combo.Input(
IO.Combo.Input(
"aspect_ratio",
options=[ratio.value for ratio in LumaAspectRatio],
options=LumaAspectRatio,
default=LumaAspectRatio.ratio_16_9,
),
comfy_io.Int.Input(
IO.Int.Input(
"seed",
default=0,
min=0,
@@ -196,7 +158,7 @@ class LumaImageGenerationNode(comfy_io.ComfyNode):
control_after_generate=True,
tooltip="Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
),
comfy_io.Float.Input(
IO.Float.Input(
"style_image_weight",
default=1.0,
min=0.0,
@@ -204,27 +166,27 @@ class LumaImageGenerationNode(comfy_io.ComfyNode):
step=0.01,
tooltip="Weight of style image. Ignored if no style_image provided.",
),
comfy_io.Custom(LumaIO.LUMA_REF).Input(
IO.Custom(LumaIO.LUMA_REF).Input(
"image_luma_ref",
tooltip="Luma Reference node connection to influence generation with input images; up to 4 images can be considered.",
optional=True,
),
comfy_io.Image.Input(
IO.Image.Input(
"style_image",
tooltip="Style reference image; only 1 image will be used.",
optional=True,
),
comfy_io.Image.Input(
IO.Image.Input(
"character_image",
tooltip="Character reference images; can be a batch of multiple, up to 4 images can be considered.",
optional=True,
),
],
outputs=[comfy_io.Image.Output()],
outputs=[IO.Image.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@@ -237,45 +199,30 @@ class LumaImageGenerationNode(comfy_io.ComfyNode):
aspect_ratio: str,
seed,
style_image_weight: float,
image_luma_ref: LumaReferenceChain = None,
style_image: torch.Tensor = None,
character_image: torch.Tensor = None,
) -> comfy_io.NodeOutput:
image_luma_ref: Optional[LumaReferenceChain] = None,
style_image: Optional[torch.Tensor] = None,
character_image: Optional[torch.Tensor] = None,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=3)
auth_kwargs = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
# handle image_luma_ref
api_image_ref = None
if image_luma_ref is not None:
api_image_ref = await cls._convert_luma_refs(
image_luma_ref, max_refs=4, auth_kwargs=auth_kwargs,
)
api_image_ref = await cls._convert_luma_refs(image_luma_ref, max_refs=4)
# handle style_luma_ref
api_style_ref = None
if style_image is not None:
api_style_ref = await cls._convert_style_image(
style_image, weight=style_image_weight, auth_kwargs=auth_kwargs,
)
api_style_ref = await cls._convert_style_image(style_image, weight=style_image_weight)
# handle character_ref images
character_ref = None
if character_image is not None:
download_urls = await upload_images_to_comfyapi(
character_image, max_images=4, auth_kwargs=auth_kwargs,
)
character_ref = LumaCharacterRef(
identity0=LumaImageIdentity(images=download_urls)
)
download_urls = await upload_images_to_comfyapi(cls, character_image, max_images=4)
character_ref = LumaCharacterRef(identity0=LumaImageIdentity(images=download_urls))
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/luma/generations/image",
method=HttpMethod.POST,
request_model=LumaImageGenerationRequest,
response_model=LumaGeneration,
),
request=LumaImageGenerationRequest(
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/luma/generations/image", method="POST"),
response_model=LumaGeneration,
data=LumaImageGenerationRequest(
prompt=prompt,
model=model,
aspect_ratio=aspect_ratio,
@@ -283,41 +230,21 @@ class LumaImageGenerationNode(comfy_io.ComfyNode):
style_ref=api_style_ref,
character_ref=character_ref,
),
auth_kwargs=auth_kwargs,
)
response_api: LumaGeneration = await operation.execute()
operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"/proxy/luma/generations/{response_api.id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=LumaGeneration,
),
completed_statuses=[LumaState.completed],
failed_statuses=[LumaState.failed],
response_poll = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/luma/generations/{response_api.id}"),
response_model=LumaGeneration,
status_extractor=lambda x: x.state,
result_url_extractor=image_result_url_extractor,
node_id=cls.hidden.unique_id,
auth_kwargs=auth_kwargs,
)
response_poll = await operation.execute()
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.assets.image) as img_response:
img = process_image_response(await img_response.content.read())
return comfy_io.NodeOutput(img)
return IO.NodeOutput(await download_url_to_image_tensor(response_poll.assets.image))
@classmethod
async def _convert_luma_refs(
cls, luma_ref: LumaReferenceChain, max_refs: int, auth_kwargs: Optional[dict[str,str]] = None
):
async def _convert_luma_refs(cls, luma_ref: LumaReferenceChain, max_refs: int):
luma_urls = []
ref_count = 0
for ref in luma_ref.refs:
download_urls = await upload_images_to_comfyapi(
ref.image, max_images=1, auth_kwargs=auth_kwargs
)
download_urls = await upload_images_to_comfyapi(cls, ref.image, max_images=1)
luma_urls.append(download_urls[0])
ref_count += 1
if ref_count >= max_refs:
@@ -325,38 +252,30 @@ class LumaImageGenerationNode(comfy_io.ComfyNode):
return luma_ref.create_api_model(download_urls=luma_urls, max_refs=max_refs)
@classmethod
async def _convert_style_image(
cls, style_image: torch.Tensor, weight: float, auth_kwargs: Optional[dict[str,str]] = None
):
chain = LumaReferenceChain(
first_ref=LumaReference(image=style_image, weight=weight)
)
return await cls._convert_luma_refs(chain, max_refs=1, auth_kwargs=auth_kwargs)
async def _convert_style_image(cls, style_image: torch.Tensor, weight: float):
chain = LumaReferenceChain(first_ref=LumaReference(image=style_image, weight=weight))
return await cls._convert_luma_refs(chain, max_refs=1)
class LumaImageModifyNode(comfy_io.ComfyNode):
"""
Modifies images synchronously based on prompt and aspect ratio.
"""
class LumaImageModifyNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="LumaImageModifyNode",
display_name="Luma Image to Image",
category="api node/image/Luma",
description=cleandoc(cls.__doc__ or ""),
description="Modifies images synchronously based on prompt and aspect ratio.",
inputs=[
comfy_io.Image.Input(
IO.Image.Input(
"image",
),
comfy_io.String.Input(
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt for the image generation",
),
comfy_io.Float.Input(
IO.Float.Input(
"image_weight",
default=0.1,
min=0.0,
@@ -364,11 +283,11 @@ class LumaImageModifyNode(comfy_io.ComfyNode):
step=0.01,
tooltip="Weight of the image; the closer to 1.0, the less the image will be modified.",
),
comfy_io.Combo.Input(
IO.Combo.Input(
"model",
options=[model.value for model in LumaImageModel],
options=LumaImageModel,
),
comfy_io.Int.Input(
IO.Int.Input(
"seed",
default=0,
min=0,
@@ -377,11 +296,11 @@ class LumaImageModifyNode(comfy_io.ComfyNode):
tooltip="Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
),
],
outputs=[comfy_io.Image.Output()],
outputs=[IO.Image.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@@ -394,99 +313,68 @@ class LumaImageModifyNode(comfy_io.ComfyNode):
image: torch.Tensor,
image_weight: float,
seed,
) -> comfy_io.NodeOutput:
auth_kwargs = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
# first, upload image
download_urls = await upload_images_to_comfyapi(
image, max_images=1, auth_kwargs=auth_kwargs,
)
) -> IO.NodeOutput:
download_urls = await upload_images_to_comfyapi(cls, image, max_images=1)
image_url = download_urls[0]
# next, make Luma call with download url provided
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/luma/generations/image",
method=HttpMethod.POST,
request_model=LumaImageGenerationRequest,
response_model=LumaGeneration,
),
request=LumaImageGenerationRequest(
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/luma/generations/image", method="POST"),
response_model=LumaGeneration,
data=LumaImageGenerationRequest(
prompt=prompt,
model=model,
modify_image_ref=LumaModifyImageRef(
url=image_url, weight=round(max(min(1.0-image_weight, 0.98), 0.0), 2)
url=image_url, weight=round(max(min(1.0 - image_weight, 0.98), 0.0), 2)
),
),
auth_kwargs=auth_kwargs,
)
response_api: LumaGeneration = await operation.execute()
operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"/proxy/luma/generations/{response_api.id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=LumaGeneration,
),
completed_statuses=[LumaState.completed],
failed_statuses=[LumaState.failed],
response_poll = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/luma/generations/{response_api.id}"),
response_model=LumaGeneration,
status_extractor=lambda x: x.state,
result_url_extractor=image_result_url_extractor,
node_id=cls.hidden.unique_id,
auth_kwargs=auth_kwargs,
)
response_poll = await operation.execute()
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.assets.image) as img_response:
img = process_image_response(await img_response.content.read())
return comfy_io.NodeOutput(img)
return IO.NodeOutput(await download_url_to_image_tensor(response_poll.assets.image))
class LumaTextToVideoGenerationNode(comfy_io.ComfyNode):
"""
Generates videos synchronously based on prompt and output_size.
"""
class LumaTextToVideoGenerationNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="LumaVideoNode",
display_name="Luma Text to Video",
category="api node/video/Luma",
description=cleandoc(cls.__doc__ or ""),
description="Generates videos synchronously based on prompt and output_size.",
inputs=[
comfy_io.String.Input(
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt for the video generation",
),
comfy_io.Combo.Input(
IO.Combo.Input(
"model",
options=[model.value for model in LumaVideoModel],
options=LumaVideoModel,
),
comfy_io.Combo.Input(
IO.Combo.Input(
"aspect_ratio",
options=[ratio.value for ratio in LumaAspectRatio],
options=LumaAspectRatio,
default=LumaAspectRatio.ratio_16_9,
),
comfy_io.Combo.Input(
IO.Combo.Input(
"resolution",
options=[resolution.value for resolution in LumaVideoOutputResolution],
options=LumaVideoOutputResolution,
default=LumaVideoOutputResolution.res_540p,
),
comfy_io.Combo.Input(
IO.Combo.Input(
"duration",
options=[dur.value for dur in LumaVideoModelOutputDuration],
options=LumaVideoModelOutputDuration,
),
comfy_io.Boolean.Input(
IO.Boolean.Input(
"loop",
default=False,
),
comfy_io.Int.Input(
IO.Int.Input(
"seed",
default=0,
min=0,
@@ -494,17 +382,17 @@ class LumaTextToVideoGenerationNode(comfy_io.ComfyNode):
control_after_generate=True,
tooltip="Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
),
comfy_io.Custom(LumaIO.LUMA_CONCEPTS).Input(
IO.Custom(LumaIO.LUMA_CONCEPTS).Input(
"luma_concepts",
tooltip="Optional Camera Concepts to dictate camera motion via the Luma Concepts node.",
optional=True,
)
),
],
outputs=[comfy_io.Video.Output()],
outputs=[IO.Video.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@@ -519,24 +407,17 @@ class LumaTextToVideoGenerationNode(comfy_io.ComfyNode):
duration: str,
loop: bool,
seed,
luma_concepts: LumaConceptChain = None,
) -> comfy_io.NodeOutput:
luma_concepts: Optional[LumaConceptChain] = None,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False, min_length=3)
duration = duration if model != LumaVideoModel.ray_1_6 else None
resolution = resolution if model != LumaVideoModel.ray_1_6 else None
auth_kwargs = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/luma/generations",
method=HttpMethod.POST,
request_model=LumaGenerationRequest,
response_model=LumaGeneration,
),
request=LumaGenerationRequest(
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/luma/generations", method="POST"),
response_model=LumaGeneration,
data=LumaGenerationRequest(
prompt=prompt,
model=model,
resolution=resolution,
@@ -545,77 +426,55 @@ class LumaTextToVideoGenerationNode(comfy_io.ComfyNode):
loop=loop,
concepts=luma_concepts.create_api_model() if luma_concepts else None,
),
auth_kwargs=auth_kwargs,
)
response_api: LumaGeneration = await operation.execute()
if cls.hidden.unique_id:
PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", cls.hidden.unique_id)
operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"/proxy/luma/generations/{response_api.id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=LumaGeneration,
),
completed_statuses=[LumaState.completed],
failed_statuses=[LumaState.failed],
response_poll = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/luma/generations/{response_api.id}"),
response_model=LumaGeneration,
status_extractor=lambda x: x.state,
result_url_extractor=video_result_url_extractor,
node_id=cls.hidden.unique_id,
estimated_duration=LUMA_T2V_AVERAGE_DURATION,
auth_kwargs=auth_kwargs,
)
response_poll = await operation.execute()
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.assets.video) as vid_response:
return comfy_io.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
return IO.NodeOutput(await download_url_to_video_output(response_poll.assets.video))
class LumaImageToVideoGenerationNode(comfy_io.ComfyNode):
"""
Generates videos synchronously based on prompt, input images, and output_size.
"""
class LumaImageToVideoGenerationNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="LumaImageToVideoNode",
display_name="Luma Image to Video",
category="api node/video/Luma",
description=cleandoc(cls.__doc__ or ""),
description="Generates videos synchronously based on prompt, input images, and output_size.",
inputs=[
comfy_io.String.Input(
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt for the video generation",
),
comfy_io.Combo.Input(
IO.Combo.Input(
"model",
options=[model.value for model in LumaVideoModel],
options=LumaVideoModel,
),
# comfy_io.Combo.Input(
# IO.Combo.Input(
# "aspect_ratio",
# options=[ratio.value for ratio in LumaAspectRatio],
# default=LumaAspectRatio.ratio_16_9,
# ),
comfy_io.Combo.Input(
IO.Combo.Input(
"resolution",
options=[resolution.value for resolution in LumaVideoOutputResolution],
options=LumaVideoOutputResolution,
default=LumaVideoOutputResolution.res_540p,
),
comfy_io.Combo.Input(
IO.Combo.Input(
"duration",
options=[dur.value for dur in LumaVideoModelOutputDuration],
),
comfy_io.Boolean.Input(
IO.Boolean.Input(
"loop",
default=False,
),
comfy_io.Int.Input(
IO.Int.Input(
"seed",
default=0,
min=0,
@@ -623,27 +482,27 @@ class LumaImageToVideoGenerationNode(comfy_io.ComfyNode):
control_after_generate=True,
tooltip="Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
),
comfy_io.Image.Input(
IO.Image.Input(
"first_image",
tooltip="First frame of generated video.",
optional=True,
),
comfy_io.Image.Input(
IO.Image.Input(
"last_image",
tooltip="Last frame of generated video.",
optional=True,
),
comfy_io.Custom(LumaIO.LUMA_CONCEPTS).Input(
IO.Custom(LumaIO.LUMA_CONCEPTS).Input(
"luma_concepts",
tooltip="Optional Camera Concepts to dictate camera motion via the Luma Concepts node.",
optional=True,
)
),
],
outputs=[comfy_io.Video.Output()],
outputs=[IO.Video.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@@ -660,27 +519,17 @@ class LumaImageToVideoGenerationNode(comfy_io.ComfyNode):
first_image: torch.Tensor = None,
last_image: torch.Tensor = None,
luma_concepts: LumaConceptChain = None,
) -> comfy_io.NodeOutput:
) -> IO.NodeOutput:
if first_image is None and last_image is None:
raise Exception(
"At least one of first_image and last_image requires an input."
)
auth_kwargs = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
keyframes = await cls._convert_to_keyframes(first_image, last_image, auth_kwargs=auth_kwargs)
raise Exception("At least one of first_image and last_image requires an input.")
keyframes = await cls._convert_to_keyframes(first_image, last_image)
duration = duration if model != LumaVideoModel.ray_1_6 else None
resolution = resolution if model != LumaVideoModel.ray_1_6 else None
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/luma/generations",
method=HttpMethod.POST,
request_model=LumaGenerationRequest,
response_model=LumaGeneration,
),
request=LumaGenerationRequest(
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/luma/generations", method="POST"),
response_model=LumaGeneration,
data=LumaGenerationRequest(
prompt=prompt,
model=model,
aspect_ratio=LumaAspectRatio.ratio_16_9, # ignored, but still needed by the API for some reason
@@ -690,61 +539,38 @@ class LumaImageToVideoGenerationNode(comfy_io.ComfyNode):
keyframes=keyframes,
concepts=luma_concepts.create_api_model() if luma_concepts else None,
),
auth_kwargs=auth_kwargs,
)
response_api: LumaGeneration = await operation.execute()
if cls.hidden.unique_id:
PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", cls.hidden.unique_id)
operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"/proxy/luma/generations/{response_api.id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=LumaGeneration,
),
completed_statuses=[LumaState.completed],
failed_statuses=[LumaState.failed],
response_poll = await poll_op(
cls,
poll_endpoint=ApiEndpoint(path=f"/proxy/luma/generations/{response_api.id}"),
response_model=LumaGeneration,
status_extractor=lambda x: x.state,
result_url_extractor=video_result_url_extractor,
node_id=cls.hidden.unique_id,
estimated_duration=LUMA_I2V_AVERAGE_DURATION,
auth_kwargs=auth_kwargs,
)
response_poll = await operation.execute()
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.assets.video) as vid_response:
return comfy_io.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
return IO.NodeOutput(await download_url_to_video_output(response_poll.assets.video))
@classmethod
async def _convert_to_keyframes(
cls,
first_image: torch.Tensor = None,
last_image: torch.Tensor = None,
auth_kwargs: Optional[dict[str,str]] = None,
):
if first_image is None and last_image is None:
return None
frame0 = None
frame1 = None
if first_image is not None:
download_urls = await upload_images_to_comfyapi(
first_image, max_images=1, auth_kwargs=auth_kwargs,
)
download_urls = await upload_images_to_comfyapi(cls, first_image, max_images=1)
frame0 = LumaImageReference(type="image", url=download_urls[0])
if last_image is not None:
download_urls = await upload_images_to_comfyapi(
last_image, max_images=1, auth_kwargs=auth_kwargs,
)
download_urls = await upload_images_to_comfyapi(cls, last_image, max_images=1)
frame1 = LumaImageReference(type="image", url=download_urls[0])
return LumaKeyframes(frame0=frame0, frame1=frame1)
class LumaExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
LumaImageGenerationNode,
LumaImageModifyNode,

View File

@@ -1,71 +1,57 @@
from inspect import cleandoc
from typing import Optional
import logging
import torch
import torch
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io as comfy_io
from comfy_api.input_impl.video_types import VideoFromFile
from comfy_api_nodes.apis import (
from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.apis.minimax_api import (
MinimaxFileRetrieveResponse,
MiniMaxModel,
MinimaxTaskResultResponse,
MinimaxVideoGenerationRequest,
MinimaxVideoGenerationResponse,
MinimaxFileRetrieveResponse,
MinimaxTaskResultResponse,
SubjectReferenceItem,
MiniMaxModel,
)
from comfy_api_nodes.apis.client import (
from comfy_api_nodes.util import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
PollingOperation,
EmptyRequest,
)
from comfy_api_nodes.apinode_utils import (
download_url_to_bytesio,
download_url_to_video_output,
poll_op,
sync_op,
upload_images_to_comfyapi,
validate_string,
)
from server import PromptServer
I2V_AVERAGE_DURATION = 114
T2V_AVERAGE_DURATION = 234
async def _generate_mm_video(
cls: type[IO.ComfyNode],
*,
auth: dict[str, str],
node_id: str,
prompt_text: str,
seed: int,
model: str,
image: Optional[torch.Tensor] = None, # used for ImageToVideo
subject: Optional[torch.Tensor] = None, # used for SubjectToVideo
image: Optional[torch.Tensor] = None, # used for ImageToVideo
subject: Optional[torch.Tensor] = None, # used for SubjectToVideo
average_duration: Optional[int] = None,
) -> comfy_io.NodeOutput:
) -> IO.NodeOutput:
if image is None:
validate_string(prompt_text, field_name="prompt_text")
# upload image, if passed in
image_url = None
if image is not None:
image_url = (await upload_images_to_comfyapi(image, max_images=1, auth_kwargs=auth))[0]
image_url = (await upload_images_to_comfyapi(cls, image, max_images=1))[0]
# TODO: figure out how to deal with subject properly, API returns invalid params when using S2V-01 model
subject_reference = None
if subject is not None:
subject_url = (await upload_images_to_comfyapi(subject, max_images=1, auth_kwargs=auth))[0]
subject_url = (await upload_images_to_comfyapi(cls, subject, max_images=1))[0]
subject_reference = [SubjectReferenceItem(image=subject_url)]
video_generate_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/minimax/video_generation",
method=HttpMethod.POST,
request_model=MinimaxVideoGenerationRequest,
response_model=MinimaxVideoGenerationResponse,
),
request=MinimaxVideoGenerationRequest(
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/minimax/video_generation", method="POST"),
response_model=MinimaxVideoGenerationResponse,
data=MinimaxVideoGenerationRequest(
model=MiniMaxModel(model),
prompt=prompt_text,
callback_url=None,
@@ -73,95 +59,64 @@ async def _generate_mm_video(
subject_reference=subject_reference,
prompt_optimizer=None,
),
auth_kwargs=auth,
)
response = await video_generate_operation.execute()
task_id = response.task_id
if not task_id:
raise Exception(f"MiniMax generation failed: {response.base_resp}")
video_generate_operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path="/proxy/minimax/query/video_generation",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=MinimaxTaskResultResponse,
query_params={"task_id": task_id},
),
completed_statuses=["Success"],
failed_statuses=["Fail"],
task_result = await poll_op(
cls,
ApiEndpoint(path="/proxy/minimax/query/video_generation", query_params={"task_id": task_id}),
response_model=MinimaxTaskResultResponse,
status_extractor=lambda x: x.status.value,
estimated_duration=average_duration,
node_id=node_id,
auth_kwargs=auth,
)
task_result = await video_generate_operation.execute()
file_id = task_result.file_id
if file_id is None:
raise Exception("Request was not successful. Missing file ID.")
file_retrieve_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/minimax/files/retrieve",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=MinimaxFileRetrieveResponse,
query_params={"file_id": int(file_id)},
),
request=EmptyRequest(),
auth_kwargs=auth,
file_result = await sync_op(
cls,
ApiEndpoint(path="/proxy/minimax/files/retrieve", query_params={"file_id": int(file_id)}),
response_model=MinimaxFileRetrieveResponse,
)
file_result = await file_retrieve_operation.execute()
file_url = file_result.file.download_url
if file_url is None:
raise Exception(
f"No video was found in the response. Full response: {file_result.model_dump()}"
)
logging.info("Generated video URL: %s", file_url)
if node_id:
if hasattr(file_result.file, "backup_download_url"):
message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}"
else:
message = f"Result URL: {file_url}"
PromptServer.instance.send_progress_text(message, node_id)
# Download and return as VideoFromFile
video_io = await download_url_to_bytesio(file_url)
if video_io is None:
error_msg = f"Failed to download video from {file_url}"
logging.error(error_msg)
raise Exception(error_msg)
return comfy_io.NodeOutput(VideoFromFile(video_io))
raise Exception(f"No video was found in the response. Full response: {file_result.model_dump()}")
if file_result.file.backup_download_url:
try:
return IO.NodeOutput(await download_url_to_video_output(file_url, timeout=10, max_retries=2))
except Exception: # if we have a second URL to retrieve the result, try again using that one
return IO.NodeOutput(
await download_url_to_video_output(file_result.file.backup_download_url, max_retries=3)
)
return IO.NodeOutput(await download_url_to_video_output(file_url))
class MinimaxTextToVideoNode(comfy_io.ComfyNode):
"""
Generates videos synchronously based on a prompt, and optional parameters using MiniMax's API.
"""
class MinimaxTextToVideoNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="MinimaxTextToVideoNode",
display_name="MiniMax Text to Video",
category="api node/video/MiniMax",
description=cleandoc(cls.__doc__ or ""),
description="Generates videos synchronously based on a prompt, and optional parameters.",
inputs=[
comfy_io.String.Input(
IO.String.Input(
"prompt_text",
multiline=True,
default="",
tooltip="Text prompt to guide the video generation",
),
comfy_io.Combo.Input(
IO.Combo.Input(
"model",
options=["T2V-01", "T2V-01-Director"],
default="T2V-01",
tooltip="Model to use for video generation",
),
comfy_io.Int.Input(
IO.Int.Input(
"seed",
default=0,
min=0,
@@ -172,11 +127,11 @@ class MinimaxTextToVideoNode(comfy_io.ComfyNode):
optional=True,
),
],
outputs=[comfy_io.Video.Output()],
outputs=[IO.Video.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@@ -187,13 +142,9 @@ class MinimaxTextToVideoNode(comfy_io.ComfyNode):
prompt_text: str,
model: str = "T2V-01",
seed: int = 0,
) -> comfy_io.NodeOutput:
) -> IO.NodeOutput:
return await _generate_mm_video(
auth={
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
node_id=cls.hidden.unique_id,
cls,
prompt_text=prompt_text,
seed=seed,
model=model,
@@ -203,36 +154,32 @@ class MinimaxTextToVideoNode(comfy_io.ComfyNode):
)
class MinimaxImageToVideoNode(comfy_io.ComfyNode):
"""
Generates videos synchronously based on an image and prompt, and optional parameters using MiniMax's API.
"""
class MinimaxImageToVideoNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="MinimaxImageToVideoNode",
display_name="MiniMax Image to Video",
category="api node/video/MiniMax",
description=cleandoc(cls.__doc__ or ""),
description="Generates videos synchronously based on an image and prompt, and optional parameters.",
inputs=[
comfy_io.Image.Input(
IO.Image.Input(
"image",
tooltip="Image to use as first frame of video generation",
),
comfy_io.String.Input(
IO.String.Input(
"prompt_text",
multiline=True,
default="",
tooltip="Text prompt to guide the video generation",
),
comfy_io.Combo.Input(
IO.Combo.Input(
"model",
options=["I2V-01-Director", "I2V-01", "I2V-01-live"],
default="I2V-01",
tooltip="Model to use for video generation",
),
comfy_io.Int.Input(
IO.Int.Input(
"seed",
default=0,
min=0,
@@ -243,11 +190,11 @@ class MinimaxImageToVideoNode(comfy_io.ComfyNode):
optional=True,
),
],
outputs=[comfy_io.Video.Output()],
outputs=[IO.Video.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@@ -259,13 +206,9 @@ class MinimaxImageToVideoNode(comfy_io.ComfyNode):
prompt_text: str,
model: str = "I2V-01",
seed: int = 0,
) -> comfy_io.NodeOutput:
) -> IO.NodeOutput:
return await _generate_mm_video(
auth={
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
node_id=cls.hidden.unique_id,
cls,
prompt_text=prompt_text,
seed=seed,
model=model,
@@ -275,36 +218,32 @@ class MinimaxImageToVideoNode(comfy_io.ComfyNode):
)
class MinimaxSubjectToVideoNode(comfy_io.ComfyNode):
"""
Generates videos synchronously based on an image and prompt, and optional parameters using MiniMax's API.
"""
class MinimaxSubjectToVideoNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="MinimaxSubjectToVideoNode",
display_name="MiniMax Subject to Video",
category="api node/video/MiniMax",
description=cleandoc(cls.__doc__ or ""),
description="Generates videos synchronously based on an image and prompt, and optional parameters.",
inputs=[
comfy_io.Image.Input(
IO.Image.Input(
"subject",
tooltip="Image of subject to reference for video generation",
),
comfy_io.String.Input(
IO.String.Input(
"prompt_text",
multiline=True,
default="",
tooltip="Text prompt to guide the video generation",
),
comfy_io.Combo.Input(
IO.Combo.Input(
"model",
options=["S2V-01"],
default="S2V-01",
tooltip="Model to use for video generation",
),
comfy_io.Int.Input(
IO.Int.Input(
"seed",
default=0,
min=0,
@@ -315,11 +254,11 @@ class MinimaxSubjectToVideoNode(comfy_io.ComfyNode):
optional=True,
),
],
outputs=[comfy_io.Video.Output()],
outputs=[IO.Video.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@@ -331,13 +270,9 @@ class MinimaxSubjectToVideoNode(comfy_io.ComfyNode):
prompt_text: str,
model: str = "S2V-01",
seed: int = 0,
) -> comfy_io.NodeOutput:
) -> IO.NodeOutput:
return await _generate_mm_video(
auth={
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
node_id=cls.hidden.unique_id,
cls,
prompt_text=prompt_text,
seed=seed,
model=model,
@@ -347,24 +282,22 @@ class MinimaxSubjectToVideoNode(comfy_io.ComfyNode):
)
class MinimaxHailuoVideoNode(comfy_io.ComfyNode):
"""Generates videos from prompt, with optional start frame using the new MiniMax Hailuo-02 model."""
class MinimaxHailuoVideoNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="MinimaxHailuoVideoNode",
display_name="MiniMax Hailuo Video",
category="api node/video/MiniMax",
description=cleandoc(cls.__doc__ or ""),
description="Generates videos from prompt, with optional start frame using the new MiniMax Hailuo-02 model.",
inputs=[
comfy_io.String.Input(
IO.String.Input(
"prompt_text",
multiline=True,
default="",
tooltip="Text prompt to guide the video generation.",
),
comfy_io.Int.Input(
IO.Int.Input(
"seed",
default=0,
min=0,
@@ -374,25 +307,25 @@ class MinimaxHailuoVideoNode(comfy_io.ComfyNode):
tooltip="The random seed used for creating the noise.",
optional=True,
),
comfy_io.Image.Input(
IO.Image.Input(
"first_frame_image",
tooltip="Optional image to use as the first frame to generate a video.",
optional=True,
),
comfy_io.Boolean.Input(
IO.Boolean.Input(
"prompt_optimizer",
default=True,
tooltip="Optimize prompt to improve generation quality when needed.",
optional=True,
),
comfy_io.Combo.Input(
IO.Combo.Input(
"duration",
options=[6, 10],
default=6,
tooltip="The length of the output video in seconds.",
optional=True,
),
comfy_io.Combo.Input(
IO.Combo.Input(
"resolution",
options=["768P", "1080P"],
default="768P",
@@ -400,11 +333,11 @@ class MinimaxHailuoVideoNode(comfy_io.ComfyNode):
optional=True,
),
],
outputs=[comfy_io.Video.Output()],
outputs=[IO.Video.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@@ -419,11 +352,7 @@ class MinimaxHailuoVideoNode(comfy_io.ComfyNode):
duration: int = 6,
resolution: str = "768P",
model: str = "MiniMax-Hailuo-02",
) -> comfy_io.NodeOutput:
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
) -> IO.NodeOutput:
if first_frame_image is None:
validate_string(prompt_text, field_name="prompt_text")
@@ -435,16 +364,13 @@ class MinimaxHailuoVideoNode(comfy_io.ComfyNode):
# upload image, if passed in
image_url = None
if first_frame_image is not None:
image_url = (await upload_images_to_comfyapi(first_frame_image, max_images=1, auth_kwargs=auth))[0]
image_url = (await upload_images_to_comfyapi(cls, first_frame_image, max_images=1))[0]
video_generate_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/minimax/video_generation",
method=HttpMethod.POST,
request_model=MinimaxVideoGenerationRequest,
response_model=MinimaxVideoGenerationResponse,
),
request=MinimaxVideoGenerationRequest(
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/minimax/video_generation", method="POST"),
response_model=MinimaxVideoGenerationResponse,
data=MinimaxVideoGenerationRequest(
model=MiniMaxModel(model),
prompt=prompt_text,
callback_url=None,
@@ -453,72 +379,47 @@ class MinimaxHailuoVideoNode(comfy_io.ComfyNode):
duration=duration,
resolution=resolution,
),
auth_kwargs=auth,
)
response = await video_generate_operation.execute()
task_id = response.task_id
if not task_id:
raise Exception(f"MiniMax generation failed: {response.base_resp}")
average_duration = 120 if resolution == "768P" else 240
video_generate_operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path="/proxy/minimax/query/video_generation",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=MinimaxTaskResultResponse,
query_params={"task_id": task_id},
),
completed_statuses=["Success"],
failed_statuses=["Fail"],
task_result = await poll_op(
cls,
ApiEndpoint(path="/proxy/minimax/query/video_generation", query_params={"task_id": task_id}),
response_model=MinimaxTaskResultResponse,
status_extractor=lambda x: x.status.value,
estimated_duration=average_duration,
node_id=cls.hidden.unique_id,
auth_kwargs=auth,
)
task_result = await video_generate_operation.execute()
file_id = task_result.file_id
if file_id is None:
raise Exception("Request was not successful. Missing file ID.")
file_retrieve_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/minimax/files/retrieve",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=MinimaxFileRetrieveResponse,
query_params={"file_id": int(file_id)},
),
request=EmptyRequest(),
auth_kwargs=auth,
file_result = await sync_op(
cls,
ApiEndpoint(path="/proxy/minimax/files/retrieve", query_params={"file_id": int(file_id)}),
response_model=MinimaxFileRetrieveResponse,
)
file_result = await file_retrieve_operation.execute()
file_url = file_result.file.download_url
if file_url is None:
raise Exception(
f"No video was found in the response. Full response: {file_result.model_dump()}"
)
logging.info(f"Generated video URL: {file_url}")
if cls.hidden.unique_id:
if hasattr(file_result.file, "backup_download_url"):
message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}"
else:
message = f"Result URL: {file_url}"
PromptServer.instance.send_progress_text(message, cls.hidden.unique_id)
raise Exception(f"No video was found in the response. Full response: {file_result.model_dump()}")
video_io = await download_url_to_bytesio(file_url)
if video_io is None:
error_msg = f"Failed to download video from {file_url}"
logging.error(error_msg)
raise Exception(error_msg)
return comfy_io.NodeOutput(VideoFromFile(video_io))
if file_result.file.backup_download_url:
try:
return IO.NodeOutput(await download_url_to_video_output(file_url, timeout=10, max_retries=2))
except Exception: # if we have a second URL to retrieve the result, try again using that one
return IO.NodeOutput(
await download_url_to_video_output(file_result.file.backup_download_url, max_retries=3)
)
return IO.NodeOutput(await download_url_to_video_output(file_url))
class MinimaxExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
MinimaxTextToVideoNode,
MinimaxImageToVideoNode,

View File

@@ -1,33 +1,30 @@
import logging
from typing import Any, Callable, Optional, TypeVar
from typing import Optional
import torch
from typing_extensions import override
from comfy_api_nodes.util.validation_utils import validate_image_dimensions
from comfy_api_nodes.apis import (
MoonvalleyTextToVideoRequest,
MoonvalleyTextToVideoInferenceParams,
MoonvalleyVideoToVideoInferenceParams,
MoonvalleyVideoToVideoRequest,
MoonvalleyPromptResponse,
)
from comfy_api_nodes.apis.client import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
PollingOperation,
EmptyRequest,
)
from comfy_api_nodes.apinode_utils import (
download_url_to_video_output,
upload_images_to_comfyapi,
upload_video_to_comfyapi,
)
from comfy_api.input import VideoInput
from comfy_api.latest import ComfyExtension, InputImpl, io as comfy_io
import av
import io
from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.apis import (
MoonvalleyPromptResponse,
MoonvalleyTextToVideoInferenceParams,
MoonvalleyTextToVideoRequest,
MoonvalleyVideoToVideoInferenceParams,
MoonvalleyVideoToVideoRequest,
)
from comfy_api_nodes.util import (
ApiEndpoint,
download_url_to_video_output,
poll_op,
sync_op,
trim_video,
upload_images_to_comfyapi,
upload_video_to_comfyapi,
validate_container_format_is_mp4,
validate_image_dimensions,
validate_string,
)
API_UPLOADS_ENDPOINT = "/proxy/moonvalley/uploads"
API_PROMPTS_ENDPOINT = "/proxy/moonvalley/prompts"
@@ -50,13 +47,6 @@ MAX_VID_HEIGHT = 10000
MAX_VIDEO_SIZE = 1024 * 1024 * 1024 # 1 GB max for in-memory video processing
MOONVALLEY_MAREY_MAX_PROMPT_LENGTH = 5000
R = TypeVar("R")
class MoonvalleyApiError(Exception):
"""Base exception for Moonvalley API errors."""
pass
def is_valid_task_creation_response(response: MoonvalleyPromptResponse) -> bool:
@@ -68,64 +58,7 @@ def validate_task_creation_response(response) -> None:
if not is_valid_task_creation_response(response):
error_msg = f"Moonvalley Marey API: Initial request failed. Code: {response.code}, Message: {response.message}, Data: {response}"
logging.error(error_msg)
raise MoonvalleyApiError(error_msg)
def get_video_from_response(response):
video = response.output_url
logging.info(
"Moonvalley Marey API: Task %s succeeded. Video URL: %s", response.id, video
)
return video
def get_video_url_from_response(response) -> Optional[str]:
"""Returns the first video url from the Moonvalley video generation task result.
Will not raise an error if the response is not valid.
"""
if response:
return str(get_video_from_response(response))
else:
return None
async def poll_until_finished(
auth_kwargs: dict[str, str],
api_endpoint: ApiEndpoint[Any, R],
result_url_extractor: Optional[Callable[[R], str]] = None,
node_id: Optional[str] = None,
) -> R:
"""Polls the Moonvalley API endpoint until the task reaches a terminal state, then returns the response."""
return await PollingOperation(
poll_endpoint=api_endpoint,
completed_statuses=[
"completed",
],
max_poll_attempts=240, # 64 minutes with 16s interval
poll_interval=16.0,
failed_statuses=["error"],
status_extractor=lambda response: (
response.status if response and response.status else None
),
auth_kwargs=auth_kwargs,
result_url_extractor=result_url_extractor,
node_id=node_id,
).execute()
def validate_prompts(
prompt: str, negative_prompt: str, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH
):
"""Verifies that the prompt isn't empty and that neither prompt is too long."""
if not prompt:
raise ValueError("Positive prompt is empty")
if len(prompt) > max_length:
raise ValueError(f"Positive prompt is too long: {len(prompt)} characters")
if negative_prompt and len(negative_prompt) > max_length:
raise ValueError(
f"Negative prompt is too long: {len(negative_prompt)} characters"
)
return True
raise RuntimeError(error_msg)
def validate_video_to_video_input(video: VideoInput) -> VideoInput:
@@ -144,7 +77,7 @@ def validate_video_to_video_input(video: VideoInput) -> VideoInput:
"""
width, height = _get_video_dimensions(video)
_validate_video_dimensions(width, height)
_validate_container_format(video)
validate_container_format_is_mp4(video)
return _validate_and_trim_duration(video)
@@ -169,21 +102,8 @@ def _validate_video_dimensions(width: int, height: int) -> None:
}
if (width, height) not in supported_resolutions:
supported_list = ", ".join(
[f"{w}x{h}" for w, h in sorted(supported_resolutions)]
)
raise ValueError(
f"Resolution {width}x{height} not supported. Supported: {supported_list}"
)
def _validate_container_format(video: VideoInput) -> None:
"""Validates video container format is MP4."""
container_format = video.get_container_format()
if container_format not in ["mp4", "mov,mp4,m4a,3gp,3g2,mj2"]:
raise ValueError(
f"Only MP4 container format supported. Got: {container_format}"
)
supported_list = ", ".join([f"{w}x{h}" for w, h in sorted(supported_resolutions)])
raise ValueError(f"Resolution {width}x{height} not supported. Supported: {supported_list}")
def _validate_and_trim_duration(video: VideoInput) -> VideoInput:
@@ -196,7 +116,7 @@ def _validate_and_trim_duration(video: VideoInput) -> VideoInput:
def _validate_minimum_duration(duration: float) -> None:
"""Ensures video is at least 5 seconds long."""
if duration < 5:
raise MoonvalleyApiError("Input video must be at least 5 seconds long.")
raise ValueError("Input video must be at least 5 seconds long.")
def _trim_if_too_long(video: VideoInput, duration: float) -> VideoInput:
@@ -206,127 +126,6 @@ def _trim_if_too_long(video: VideoInput, duration: float) -> VideoInput:
return video
def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
"""
Returns a new VideoInput object trimmed from the beginning to the specified duration,
using av to avoid loading entire video into memory.
Args:
video: Input video to trim
duration_sec: Duration in seconds to keep from the beginning
Returns:
VideoFromFile object that owns the output buffer
"""
output_buffer = io.BytesIO()
input_container = None
output_container = None
try:
# Get the stream source - this avoids loading entire video into memory
# when the source is already a file path
input_source = video.get_stream_source()
# Open containers
input_container = av.open(input_source, mode="r")
output_container = av.open(output_buffer, mode="w", format="mp4")
# Set up output streams for re-encoding
video_stream = None
audio_stream = None
for stream in input_container.streams:
logging.info(f"Found stream: type={stream.type}, class={type(stream)}")
if isinstance(stream, av.VideoStream):
# Create output video stream with same parameters
video_stream = output_container.add_stream(
"h264", rate=stream.average_rate
)
video_stream.width = stream.width
video_stream.height = stream.height
video_stream.pix_fmt = "yuv420p"
logging.info(
f"Added video stream: {stream.width}x{stream.height} @ {stream.average_rate}fps"
)
elif isinstance(stream, av.AudioStream):
# Create output audio stream with same parameters
audio_stream = output_container.add_stream(
"aac", rate=stream.sample_rate
)
audio_stream.sample_rate = stream.sample_rate
audio_stream.layout = stream.layout
logging.info(
f"Added audio stream: {stream.sample_rate}Hz, {stream.channels} channels"
)
# Calculate target frame count that's divisible by 16
fps = input_container.streams.video[0].average_rate
estimated_frames = int(duration_sec * fps)
target_frames = (
estimated_frames // 16
) * 16 # Round down to nearest multiple of 16
if target_frames == 0:
raise ValueError("Video too short: need at least 16 frames for Moonvalley")
frame_count = 0
audio_frame_count = 0
# Decode and re-encode video frames
if video_stream:
for frame in input_container.decode(video=0):
if frame_count >= target_frames:
break
# Re-encode frame
for packet in video_stream.encode(frame):
output_container.mux(packet)
frame_count += 1
# Flush encoder
for packet in video_stream.encode():
output_container.mux(packet)
logging.info(
f"Encoded {frame_count} video frames (target: {target_frames})"
)
# Decode and re-encode audio frames
if audio_stream:
input_container.seek(0) # Reset to beginning for audio
for frame in input_container.decode(audio=0):
if frame.time >= duration_sec:
break
# Re-encode frame
for packet in audio_stream.encode(frame):
output_container.mux(packet)
audio_frame_count += 1
# Flush encoder
for packet in audio_stream.encode():
output_container.mux(packet)
logging.info(f"Encoded {audio_frame_count} audio frames")
# Close containers
output_container.close()
input_container.close()
# Return as VideoFromFile using the buffer
output_buffer.seek(0)
return InputImpl.VideoFromFile(output_buffer)
except Exception as e:
# Clean up on error
if input_container is not None:
input_container.close()
if output_container is not None:
output_container.close()
raise RuntimeError(f"Failed to trim video: {str(e)}") from e
def parse_width_height_from_res(resolution: str):
# Accepts a string like "16:9 (1920 x 1080)" and returns width, height as a dict
res_map = {
@@ -335,7 +134,7 @@ def parse_width_height_from_res(resolution: str):
"1:1 (1152 x 1152)": {"width": 1152, "height": 1152},
"4:3 (1536 x 1152)": {"width": 1536, "height": 1152},
"3:4 (1152 x 1536)": {"width": 1152, "height": 1536},
"21:9 (2560 x 1080)": {"width": 2560, "height": 1080},
# "21:9 (2560 x 1080)": {"width": 2560, "height": 1080},
}
return res_map.get(resolution, {"width": 1920, "height": 1080})
@@ -350,52 +149,47 @@ def parse_control_parameter(value):
return control_map.get(value, control_map["Motion Transfer"])
async def get_response(
task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
) -> MoonvalleyPromptResponse:
return await poll_until_finished(
auth_kwargs,
ApiEndpoint(
path=f"{API_PROMPTS_ENDPOINT}/{task_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=MoonvalleyPromptResponse,
),
result_url_extractor=get_video_url_from_response,
node_id=node_id,
async def get_response(cls: type[IO.ComfyNode], task_id: str) -> MoonvalleyPromptResponse:
return await poll_op(
cls,
ApiEndpoint(path=f"{API_PROMPTS_ENDPOINT}/{task_id}"),
response_model=MoonvalleyPromptResponse,
status_extractor=lambda r: (r.status if r and r.status else None),
poll_interval=16.0,
max_poll_attempts=240,
)
class MoonvalleyImg2VideoNode(comfy_io.ComfyNode):
class MoonvalleyImg2VideoNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="MoonvalleyImg2VideoNode",
display_name="Moonvalley Marey Image to Video",
category="api node/video/Moonvalley Marey",
description="Moonvalley Marey Image to Video Node",
inputs=[
comfy_io.Image.Input(
IO.Image.Input(
"image",
tooltip="The reference image used to generate the video",
),
comfy_io.String.Input(
IO.String.Input(
"prompt",
multiline=True,
),
comfy_io.String.Input(
IO.String.Input(
"negative_prompt",
multiline=True,
default="<synthetic> <scene cut> gopro, bright, contrast, static, overexposed, vignette, "
"artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, "
"flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, "
"cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, "
"blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, "
"wobbly, weird, low quality, plastic, stock footage, video camera, boring",
"artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, "
"flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, "
"cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, "
"blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, "
"wobbly, weird, low quality, plastic, stock footage, video camera, boring",
tooltip="Negative prompt text",
),
comfy_io.Combo.Input(
IO.Combo.Input(
"resolution",
options=[
"16:9 (1920 x 1080)",
@@ -403,42 +197,43 @@ class MoonvalleyImg2VideoNode(comfy_io.ComfyNode):
"1:1 (1152 x 1152)",
"4:3 (1536 x 1152)",
"3:4 (1152 x 1536)",
"21:9 (2560 x 1080)",
# "21:9 (2560 x 1080)",
],
default="16:9 (1920 x 1080)",
tooltip="Resolution of the output video",
),
comfy_io.Float.Input(
IO.Float.Input(
"prompt_adherence",
default=10.0,
default=4.5,
min=1.0,
max=20.0,
step=1.0,
tooltip="Guidance scale for generation control",
),
comfy_io.Int.Input(
IO.Int.Input(
"seed",
default=9,
min=0,
max=4294967295,
step=1,
display_mode=comfy_io.NumberDisplay.number,
display_mode=IO.NumberDisplay.number,
tooltip="Random seed value",
control_after_generate=True,
),
comfy_io.Int.Input(
IO.Int.Input(
"steps",
default=100,
default=33,
min=1,
max=100,
step=1,
tooltip="Number of denoising steps",
),
],
outputs=[comfy_io.Video.Output()],
outputs=[IO.Video.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@@ -453,22 +248,17 @@ class MoonvalleyImg2VideoNode(comfy_io.ComfyNode):
prompt_adherence: float,
seed: int,
steps: int,
) -> comfy_io.NodeOutput:
) -> IO.NodeOutput:
validate_image_dimensions(image, min_width=300, min_height=300, max_height=MAX_HEIGHT, max_width=MAX_WIDTH)
validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
validate_string(prompt, min_length=1, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
validate_string(negative_prompt, field_name="negative_prompt", max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
width_height = parse_width_height_from_res(resolution)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
inference_params = MoonvalleyTextToVideoInferenceParams(
negative_prompt=negative_prompt,
steps=steps,
seed=seed,
guidance_scale=prompt_adherence,
num_frames=128,
width=width_height["width"],
height=width_height["height"],
use_negative_prompts=True,
@@ -476,85 +266,69 @@ class MoonvalleyImg2VideoNode(comfy_io.ComfyNode):
# Get MIME type from tensor - assuming PNG format for image tensors
mime_type = "image/png"
image_url = (
await upload_images_to_comfyapi(
image, max_images=1, auth_kwargs=auth, mime_type=mime_type
)
)[0]
request = MoonvalleyTextToVideoRequest(
image_url=image_url, prompt_text=prompt, inference_params=inference_params
)
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=API_IMG2VIDEO_ENDPOINT,
method=HttpMethod.POST,
request_model=MoonvalleyTextToVideoRequest,
response_model=MoonvalleyPromptResponse,
image_url = (await upload_images_to_comfyapi(cls, image, max_images=1, mime_type=mime_type))[0]
task_creation_response = await sync_op(
cls,
endpoint=ApiEndpoint(path=API_IMG2VIDEO_ENDPOINT, method="POST"),
response_model=MoonvalleyPromptResponse,
data=MoonvalleyTextToVideoRequest(
image_url=image_url, prompt_text=prompt, inference_params=inference_params
),
request=request,
auth_kwargs=auth,
)
task_creation_response = await initial_operation.execute()
validate_task_creation_response(task_creation_response)
task_id = task_creation_response.id
final_response = await get_response(
task_id, auth_kwargs=auth, node_id=cls.hidden.unique_id
)
final_response = await get_response(cls, task_creation_response.id)
video = await download_url_to_video_output(final_response.output_url)
return comfy_io.NodeOutput(video)
return IO.NodeOutput(video)
class MoonvalleyVideo2VideoNode(comfy_io.ComfyNode):
class MoonvalleyVideo2VideoNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="MoonvalleyVideo2VideoNode",
display_name="Moonvalley Marey Video to Video",
category="api node/video/Moonvalley Marey",
description="",
inputs=[
comfy_io.String.Input(
IO.String.Input(
"prompt",
multiline=True,
tooltip="Describes the video to generate",
),
comfy_io.String.Input(
IO.String.Input(
"negative_prompt",
multiline=True,
default="<synthetic> <scene cut> gopro, bright, contrast, static, overexposed, vignette, "
"artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, "
"flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, "
"cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, "
"blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, "
"wobbly, weird, low quality, plastic, stock footage, video camera, boring",
"artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, "
"flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, "
"cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, "
"blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, "
"wobbly, weird, low quality, plastic, stock footage, video camera, boring",
tooltip="Negative prompt text",
),
comfy_io.Int.Input(
IO.Int.Input(
"seed",
default=9,
min=0,
max=4294967295,
step=1,
display_mode=comfy_io.NumberDisplay.number,
display_mode=IO.NumberDisplay.number,
tooltip="Random seed value",
control_after_generate=False,
),
comfy_io.Video.Input(
IO.Video.Input(
"video",
tooltip="The reference video used to generate the output video. Must be at least 5 seconds long. "
"Videos longer than 5s will be automatically trimmed. Only MP4 format supported.",
"Videos longer than 5s will be automatically trimmed. Only MP4 format supported.",
),
comfy_io.Combo.Input(
IO.Combo.Input(
"control_type",
options=["Motion Transfer", "Pose Transfer"],
default="Motion Transfer",
optional=True,
),
comfy_io.Int.Input(
IO.Int.Input(
"motion_intensity",
default=100,
min=0,
@@ -563,12 +337,21 @@ class MoonvalleyVideo2VideoNode(comfy_io.ComfyNode):
tooltip="Only used if control_type is 'Motion Transfer'",
optional=True,
),
IO.Int.Input(
"steps",
default=33,
min=1,
max=100,
step=1,
display_mode=IO.NumberDisplay.number,
tooltip="Number of inference steps",
),
],
outputs=[comfy_io.Video.Output()],
outputs=[IO.Video.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@@ -582,16 +365,13 @@ class MoonvalleyVideo2VideoNode(comfy_io.ComfyNode):
video: Optional[VideoInput] = None,
control_type: str = "Motion Transfer",
motion_intensity: Optional[int] = 100,
) -> comfy_io.NodeOutput:
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
steps=33,
prompt_adherence=4.5,
) -> IO.NodeOutput:
validated_video = validate_video_to_video_input(video)
video_url = await upload_video_to_comfyapi(validated_video, auth_kwargs=auth)
validate_prompts(prompt, negative_prompt)
video_url = await upload_video_to_comfyapi(cls, validated_video)
validate_string(prompt, min_length=1, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
validate_string(negative_prompt, field_name="negative_prompt", max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
# Only include motion_intensity for Motion Transfer
control_params = {}
@@ -602,65 +382,52 @@ class MoonvalleyVideo2VideoNode(comfy_io.ComfyNode):
negative_prompt=negative_prompt,
seed=seed,
control_params=control_params,
steps=steps,
guidance_scale=prompt_adherence,
)
control = parse_control_parameter(control_type)
request = MoonvalleyVideoToVideoRequest(
control_type=control,
video_url=video_url,
prompt_text=prompt,
inference_params=inference_params,
)
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=API_VIDEO2VIDEO_ENDPOINT,
method=HttpMethod.POST,
request_model=MoonvalleyVideoToVideoRequest,
response_model=MoonvalleyPromptResponse,
task_creation_response = await sync_op(
cls,
endpoint=ApiEndpoint(path=API_VIDEO2VIDEO_ENDPOINT, method="POST"),
response_model=MoonvalleyPromptResponse,
data=MoonvalleyVideoToVideoRequest(
control_type=parse_control_parameter(control_type),
video_url=video_url,
prompt_text=prompt,
inference_params=inference_params,
),
request=request,
auth_kwargs=auth,
)
task_creation_response = await initial_operation.execute()
validate_task_creation_response(task_creation_response)
task_id = task_creation_response.id
final_response = await get_response(
task_id, auth_kwargs=auth, node_id=cls.hidden.unique_id
)
video = await download_url_to_video_output(final_response.output_url)
return comfy_io.NodeOutput(video)
final_response = await get_response(cls, task_creation_response.id)
return IO.NodeOutput(await download_url_to_video_output(final_response.output_url))
class MoonvalleyTxt2VideoNode(comfy_io.ComfyNode):
class MoonvalleyTxt2VideoNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="MoonvalleyTxt2VideoNode",
display_name="Moonvalley Marey Text to Video",
category="api node/video/Moonvalley Marey",
description="",
inputs=[
comfy_io.String.Input(
IO.String.Input(
"prompt",
multiline=True,
),
comfy_io.String.Input(
IO.String.Input(
"negative_prompt",
multiline=True,
default="<synthetic> <scene cut> gopro, bright, contrast, static, overexposed, vignette, "
"artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, "
"flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, "
"cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, "
"blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, "
"wobbly, weird, low quality, plastic, stock footage, video camera, boring",
"artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, "
"flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, "
"cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, "
"blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, "
"wobbly, weird, low quality, plastic, stock footage, video camera, boring",
tooltip="Negative prompt text",
),
comfy_io.Combo.Input(
IO.Combo.Input(
"resolution",
options=[
"16:9 (1920 x 1080)",
@@ -673,37 +440,38 @@ class MoonvalleyTxt2VideoNode(comfy_io.ComfyNode):
default="16:9 (1920 x 1080)",
tooltip="Resolution of the output video",
),
comfy_io.Float.Input(
IO.Float.Input(
"prompt_adherence",
default=10.0,
default=4.0,
min=1.0,
max=20.0,
step=1.0,
tooltip="Guidance scale for generation control",
),
comfy_io.Int.Input(
IO.Int.Input(
"seed",
default=9,
min=0,
max=4294967295,
step=1,
display_mode=comfy_io.NumberDisplay.number,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="Random seed value",
),
comfy_io.Int.Input(
IO.Int.Input(
"steps",
default=100,
default=33,
min=1,
max=100,
step=1,
tooltip="Inference steps",
),
],
outputs=[comfy_io.Video.Output()],
outputs=[IO.Video.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@@ -717,15 +485,11 @@ class MoonvalleyTxt2VideoNode(comfy_io.ComfyNode):
prompt_adherence: float,
seed: int,
steps: int,
) -> comfy_io.NodeOutput:
validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
) -> IO.NodeOutput:
validate_string(prompt, min_length=1, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
validate_string(negative_prompt, field_name="negative_prompt", max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
width_height = parse_width_height_from_res(resolution)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
inference_params = MoonvalleyTextToVideoInferenceParams(
negative_prompt=negative_prompt,
steps=steps,
@@ -735,35 +499,21 @@ class MoonvalleyTxt2VideoNode(comfy_io.ComfyNode):
width=width_height["width"],
height=width_height["height"],
)
request = MoonvalleyTextToVideoRequest(
prompt_text=prompt, inference_params=inference_params
)
init_op = SynchronousOperation(
endpoint=ApiEndpoint(
path=API_TXT2VIDEO_ENDPOINT,
method=HttpMethod.POST,
request_model=MoonvalleyTextToVideoRequest,
response_model=MoonvalleyPromptResponse,
),
request=request,
auth_kwargs=auth,
task_creation_response = await sync_op(
cls,
endpoint=ApiEndpoint(path=API_TXT2VIDEO_ENDPOINT, method="POST"),
response_model=MoonvalleyPromptResponse,
data=MoonvalleyTextToVideoRequest(prompt_text=prompt, inference_params=inference_params),
)
task_creation_response = await init_op.execute()
validate_task_creation_response(task_creation_response)
task_id = task_creation_response.id
final_response = await get_response(
task_id, auth_kwargs=auth, node_id=cls.hidden.unique_id
)
video = await download_url_to_video_output(final_response.output_url)
return comfy_io.NodeOutput(video)
final_response = await get_response(cls, task_creation_response.id)
return IO.NodeOutput(await download_url_to_video_output(final_response.output_url))
class MoonvalleyExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
MoonvalleyImg2VideoNode,
MoonvalleyTxt2VideoNode,

File diff suppressed because it is too large Load Diff

View File

@@ -7,40 +7,23 @@ from __future__ import annotations
from io import BytesIO
import logging
from typing import Optional, TypeVar
from enum import Enum
from typing import Optional
import numpy as np
import torch
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io as comfy_io
from comfy_api.input_impl import VideoFromFile
from comfy_api.latest import ComfyExtension, IO
from comfy_api.input_impl.video_types import VideoCodec, VideoContainer, VideoInput
from comfy_api_nodes.apinode_utils import (
from comfy_api_nodes.apis import pika_api as pika_defs
from comfy_api_nodes.util import (
validate_string,
download_url_to_video_output,
tensor_to_bytesio,
)
from comfy_api_nodes.apis import (
PikaBodyGenerate22C2vGenerate22PikascenesPost,
PikaBodyGenerate22I2vGenerate22I2vPost,
PikaBodyGenerate22KeyframeGenerate22PikaframesPost,
PikaBodyGenerate22T2vGenerate22T2vPost,
PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
PikaBodyGeneratePikaswapsGeneratePikaswapsPost,
PikaGenerateResponse,
PikaVideoResponse,
)
from comfy_api_nodes.apis.client import (
ApiEndpoint,
EmptyRequest,
HttpMethod,
PollingOperation,
SynchronousOperation,
sync_op,
poll_op,
)
R = TypeVar("R")
PATH_PIKADDITIONS = "/proxy/pika/generate/pikadditions"
PATH_PIKASWAPS = "/proxy/pika/generate/pikaswaps"
@@ -55,152 +38,58 @@ PATH_PIKASCENES = f"/proxy/pika/generate/{PIKA_API_VERSION}/pikascenes"
PATH_VIDEO_GET = "/proxy/pika/videos"
class PikaDurationEnum(int, Enum):
integer_5 = 5
integer_10 = 10
class PikaResolutionEnum(str, Enum):
field_1080p = "1080p"
field_720p = "720p"
class Pikaffect(str, Enum):
Cake_ify = "Cake-ify"
Crumble = "Crumble"
Crush = "Crush"
Decapitate = "Decapitate"
Deflate = "Deflate"
Dissolve = "Dissolve"
Explode = "Explode"
Eye_pop = "Eye-pop"
Inflate = "Inflate"
Levitate = "Levitate"
Melt = "Melt"
Peel = "Peel"
Poke = "Poke"
Squish = "Squish"
Ta_da = "Ta-da"
Tear = "Tear"
class PikaApiError(Exception):
"""Exception for Pika API errors."""
pass
def is_valid_video_response(response: PikaVideoResponse) -> bool:
"""Check if the video response is valid."""
return hasattr(response, "url") and response.url is not None
def is_valid_initial_response(response: PikaGenerateResponse) -> bool:
"""Check if the initial response is valid."""
return hasattr(response, "video_id") and response.video_id is not None
async def poll_for_task_status(
task_id: str,
auth_kwargs: Optional[dict[str, str]] = None,
node_id: Optional[str] = None,
) -> PikaGenerateResponse:
polling_operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"{PATH_VIDEO_GET}/{task_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=PikaVideoResponse,
),
completed_statuses=[
"finished",
],
failed_statuses=["failed", "cancelled"],
status_extractor=lambda response: (
response.status.value if response.status else None
),
progress_extractor=lambda response: (
response.progress if hasattr(response, "progress") else None
),
auth_kwargs=auth_kwargs,
result_url_extractor=lambda response: (
response.url if hasattr(response, "url") else None
),
node_id=node_id,
estimated_duration=60
)
return await polling_operation.execute()
async def execute_task(
initial_operation: SynchronousOperation[R, PikaGenerateResponse],
auth_kwargs: Optional[dict[str, str]] = None,
node_id: Optional[str] = None,
) -> tuple[VideoFromFile]:
"""Executes the initial operation then polls for the task status until it is completed.
Args:
initial_operation: The initial operation to execute.
auth_kwargs: The authentication token(s) to use for the API call.
Returns:
A tuple containing the video file as a VIDEO output.
"""
initial_response = await initial_operation.execute()
if not is_valid_initial_response(initial_response):
error_msg = f"Pika initial request failed. Code: {initial_response.code}, Message: {initial_response.message}, Data: {initial_response.data}"
task_id: str,
cls: type[IO.ComfyNode],
) -> IO.NodeOutput:
final_response: pika_defs.PikaVideoResponse = await poll_op(
cls,
ApiEndpoint(path=f"{PATH_VIDEO_GET}/{task_id}"),
response_model=pika_defs.PikaVideoResponse,
status_extractor=lambda response: (response.status.value if response.status else None),
progress_extractor=lambda response: (response.progress if hasattr(response, "progress") else None),
estimated_duration=60,
max_poll_attempts=240,
)
if not final_response.url:
error_msg = f"Pika task {task_id} succeeded but no video data found in response:\n{final_response}"
logging.error(error_msg)
raise PikaApiError(error_msg)
task_id = initial_response.video_id
final_response = await poll_for_task_status(task_id, auth_kwargs, node_id=node_id)
if not is_valid_video_response(final_response):
error_msg = (
f"Pika task {task_id} succeeded but no video data found in response."
)
logging.error(error_msg)
raise PikaApiError(error_msg)
video_url = str(final_response.url)
raise Exception(error_msg)
video_url = final_response.url
logging.info("Pika task %s succeeded. Video URL: %s", task_id, video_url)
return (await download_url_to_video_output(video_url),)
return IO.NodeOutput(await download_url_to_video_output(video_url))
def get_base_inputs_types() -> list[comfy_io.Input]:
def get_base_inputs_types() -> list[IO.Input]:
"""Get the base required inputs types common to all Pika nodes."""
return [
comfy_io.String.Input("prompt_text", multiline=True),
comfy_io.String.Input("negative_prompt", multiline=True),
comfy_io.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True),
comfy_io.Combo.Input(
"resolution", options=[resolution.value for resolution in PikaResolutionEnum], default="1080p"
),
comfy_io.Combo.Input(
"duration", options=[duration.value for duration in PikaDurationEnum], default=5
),
IO.String.Input("prompt_text", multiline=True),
IO.String.Input("negative_prompt", multiline=True),
IO.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True),
IO.Combo.Input("resolution", options=["1080p", "720p"], default="1080p"),
IO.Combo.Input("duration", options=[5, 10], default=5),
]
class PikaImageToVideoV2_2(comfy_io.ComfyNode):
class PikaImageToVideo(IO.ComfyNode):
"""Pika 2.2 Image to Video Node."""
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="PikaImageToVideoNode2_2",
display_name="Pika Image to Video",
description="Sends an image and prompt to the Pika API v2.2 to generate a video.",
category="api node/video/Pika",
inputs=[
comfy_io.Image.Input("image", tooltip="The image to convert to video"),
IO.Image.Input("image", tooltip="The image to convert to video"),
*get_base_inputs_types(),
],
outputs=[comfy_io.Video.Output()],
outputs=[IO.Video.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@@ -214,53 +103,40 @@ class PikaImageToVideoV2_2(comfy_io.ComfyNode):
seed: int,
resolution: str,
duration: int,
) -> comfy_io.NodeOutput:
# Convert image to BytesIO
) -> IO.NodeOutput:
image_bytes_io = tensor_to_bytesio(image)
image_bytes_io.seek(0)
pika_files = {"image": ("image.png", image_bytes_io, "image/png")}
# Prepare non-file data
pika_request_data = PikaBodyGenerate22I2vGenerate22I2vPost(
pika_request_data = pika_defs.PikaBodyGenerate22I2vGenerate22I2vPost(
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
resolution=resolution,
duration=duration,
)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_IMAGE_TO_VIDEO,
method=HttpMethod.POST,
request_model=PikaBodyGenerate22I2vGenerate22I2vPost,
response_model=PikaGenerateResponse,
),
request=pika_request_data,
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_request_data,
files=pika_files,
content_type="multipart/form-data",
auth_kwargs=auth,
)
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
return await execute_task(initial_operation.video_id, cls)
class PikaTextToVideoNodeV2_2(comfy_io.ComfyNode):
class PikaTextToVideoNode(IO.ComfyNode):
"""Pika Text2Video v2.2 Node."""
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="PikaTextToVideoNode2_2",
display_name="Pika Text to Video",
description="Sends a text prompt to the Pika API v2.2 to generate a video.",
category="api node/video/Pika",
inputs=[
*get_base_inputs_types(),
comfy_io.Float.Input(
IO.Float.Input(
"aspect_ratio",
step=0.001,
min=0.4,
@@ -269,11 +145,11 @@ class PikaTextToVideoNodeV2_2(comfy_io.ComfyNode):
tooltip="Aspect ratio (width / height)",
)
],
outputs=[comfy_io.Video.Output()],
outputs=[IO.Video.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@@ -287,19 +163,12 @@ class PikaTextToVideoNodeV2_2(comfy_io.ComfyNode):
resolution: str,
duration: int,
aspect_ratio: float,
) -> comfy_io.NodeOutput:
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_TEXT_TO_VIDEO,
method=HttpMethod.POST,
request_model=PikaBodyGenerate22T2vGenerate22T2vPost,
response_model=PikaGenerateResponse,
),
request=PikaBodyGenerate22T2vGenerate22T2vPost(
) -> IO.NodeOutput:
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_TEXT_TO_VIDEO, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost(
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
@@ -307,30 +176,29 @@ class PikaTextToVideoNodeV2_2(comfy_io.ComfyNode):
duration=duration,
aspectRatio=aspect_ratio,
),
auth_kwargs=auth,
content_type="application/x-www-form-urlencoded",
)
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
return await execute_task(initial_operation.video_id, cls)
class PikaScenesV2_2(comfy_io.ComfyNode):
class PikaScenes(IO.ComfyNode):
"""PikaScenes v2.2 Node."""
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="PikaScenesV2_2",
display_name="Pika Scenes (Video Image Composition)",
description="Combine your images to create a video with the objects in them. Upload multiple images as ingredients and generate a high-quality video that incorporates all of them.",
category="api node/video/Pika",
inputs=[
*get_base_inputs_types(),
comfy_io.Combo.Input(
IO.Combo.Input(
"ingredients_mode",
options=["creative", "precise"],
default="creative",
),
comfy_io.Float.Input(
IO.Float.Input(
"aspect_ratio",
step=0.001,
min=0.4,
@@ -338,37 +206,37 @@ class PikaScenesV2_2(comfy_io.ComfyNode):
default=1.7777777777777777,
tooltip="Aspect ratio (width / height)",
),
comfy_io.Image.Input(
IO.Image.Input(
"image_ingredient_1",
optional=True,
tooltip="Image that will be used as ingredient to create a video.",
),
comfy_io.Image.Input(
IO.Image.Input(
"image_ingredient_2",
optional=True,
tooltip="Image that will be used as ingredient to create a video.",
),
comfy_io.Image.Input(
IO.Image.Input(
"image_ingredient_3",
optional=True,
tooltip="Image that will be used as ingredient to create a video.",
),
comfy_io.Image.Input(
IO.Image.Input(
"image_ingredient_4",
optional=True,
tooltip="Image that will be used as ingredient to create a video.",
),
comfy_io.Image.Input(
IO.Image.Input(
"image_ingredient_5",
optional=True,
tooltip="Image that will be used as ingredient to create a video.",
),
],
outputs=[comfy_io.Video.Output()],
outputs=[IO.Video.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@@ -388,8 +256,7 @@ class PikaScenesV2_2(comfy_io.ComfyNode):
image_ingredient_3: Optional[torch.Tensor] = None,
image_ingredient_4: Optional[torch.Tensor] = None,
image_ingredient_5: Optional[torch.Tensor] = None,
) -> comfy_io.NodeOutput:
# Convert all passed images to BytesIO
) -> IO.NodeOutput:
all_image_bytes_io = []
for image in [
image_ingredient_1,
@@ -399,16 +266,14 @@ class PikaScenesV2_2(comfy_io.ComfyNode):
image_ingredient_5,
]:
if image is not None:
image_bytes_io = tensor_to_bytesio(image)
image_bytes_io.seek(0)
all_image_bytes_io.append(image_bytes_io)
all_image_bytes_io.append(tensor_to_bytesio(image))
pika_files = [
("images", (f"image_{i}.png", image_bytes_io, "image/png"))
for i, image_bytes_io in enumerate(all_image_bytes_io)
]
pika_request_data = PikaBodyGenerate22C2vGenerate22PikascenesPost(
pika_request_data = pika_defs.PikaBodyGenerate22C2vGenerate22PikascenesPost(
ingredientsMode=ingredients_mode,
promptText=prompt_text,
negativePrompt=negative_prompt,
@@ -417,53 +282,45 @@ class PikaScenesV2_2(comfy_io.ComfyNode):
duration=duration,
aspectRatio=aspect_ratio,
)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_PIKASCENES,
method=HttpMethod.POST,
request_model=PikaBodyGenerate22C2vGenerate22PikascenesPost,
response_model=PikaGenerateResponse,
),
request=pika_request_data,
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_PIKASCENES, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_request_data,
files=pika_files,
content_type="multipart/form-data",
auth_kwargs=auth,
)
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
return await execute_task(initial_operation.video_id, cls)
class PikAdditionsNode(comfy_io.ComfyNode):
class PikAdditionsNode(IO.ComfyNode):
"""Pika Pikadditions Node. Add an image into a video."""
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="Pikadditions",
display_name="Pikadditions (Video Object Insertion)",
description="Add any object or image into your video. Upload a video and specify what you'd like to add to create a seamlessly integrated result.",
category="api node/video/Pika",
inputs=[
comfy_io.Video.Input("video", tooltip="The video to add an image to."),
comfy_io.Image.Input("image", tooltip="The image to add to the video."),
comfy_io.String.Input("prompt_text", multiline=True),
comfy_io.String.Input("negative_prompt", multiline=True),
comfy_io.Int.Input(
IO.Video.Input("video", tooltip="The video to add an image to."),
IO.Image.Input("image", tooltip="The image to add to the video."),
IO.String.Input("prompt_text", multiline=True),
IO.String.Input("negative_prompt", multiline=True),
IO.Int.Input(
"seed",
min=0,
max=0xFFFFFFFF,
control_after_generate=True,
),
],
outputs=[comfy_io.Video.Output()],
outputs=[IO.Video.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@@ -476,70 +333,70 @@ class PikAdditionsNode(comfy_io.ComfyNode):
prompt_text: str,
negative_prompt: str,
seed: int,
) -> comfy_io.NodeOutput:
# Convert video to BytesIO
) -> IO.NodeOutput:
video_bytes_io = BytesIO()
video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264)
video_bytes_io.seek(0)
# Convert image to BytesIO
image_bytes_io = tensor_to_bytesio(image)
image_bytes_io.seek(0)
pika_files = {
"video": ("video.mp4", video_bytes_io, "video/mp4"),
"image": ("image.png", image_bytes_io, "image/png"),
}
# Prepare non-file data
pika_request_data = PikaBodyGeneratePikadditionsGeneratePikadditionsPost(
pika_request_data = pika_defs.PikaBodyGeneratePikadditionsGeneratePikadditionsPost(
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_PIKADDITIONS,
method=HttpMethod.POST,
request_model=PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
response_model=PikaGenerateResponse,
),
request=pika_request_data,
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_PIKADDITIONS, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_request_data,
files=pika_files,
content_type="multipart/form-data",
auth_kwargs=auth,
)
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
return await execute_task(initial_operation.video_id, cls)
class PikaSwapsNode(comfy_io.ComfyNode):
class PikaSwapsNode(IO.ComfyNode):
"""Pika Pikaswaps Node."""
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="Pikaswaps",
display_name="Pika Swaps (Video Object Replacement)",
description="Swap out any object or region of your video with a new image or object. Define areas to replace either with a mask or coordinates.",
category="api node/video/Pika",
inputs=[
comfy_io.Video.Input("video", tooltip="The video to swap an object in."),
comfy_io.Image.Input("image", tooltip="The image used to replace the masked object in the video."),
comfy_io.Mask.Input("mask", tooltip="Use the mask to define areas in the video to replace"),
comfy_io.String.Input("prompt_text", multiline=True),
comfy_io.String.Input("negative_prompt", multiline=True),
comfy_io.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True),
IO.Video.Input("video", tooltip="The video to swap an object in."),
IO.Image.Input(
"image",
tooltip="The image used to replace the masked object in the video.",
optional=True,
),
IO.Mask.Input(
"mask",
tooltip="Use the mask to define areas in the video to replace.",
optional=True,
),
IO.String.Input("prompt_text", multiline=True, optional=True),
IO.String.Input("negative_prompt", multiline=True, optional=True),
IO.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True, optional=True),
IO.String.Input(
"region_to_modify",
multiline=True,
optional=True,
tooltip="Plaintext description of the object / region to modify.",
),
],
outputs=[comfy_io.Video.Output()],
outputs=[IO.Video.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@@ -548,85 +405,65 @@ class PikaSwapsNode(comfy_io.ComfyNode):
async def execute(
cls,
video: VideoInput,
image: torch.Tensor,
mask: torch.Tensor,
prompt_text: str,
negative_prompt: str,
seed: int,
) -> comfy_io.NodeOutput:
# Convert video to BytesIO
image: Optional[torch.Tensor] = None,
mask: Optional[torch.Tensor] = None,
prompt_text: str = "",
negative_prompt: str = "",
seed: int = 0,
region_to_modify: str = "",
) -> IO.NodeOutput:
video_bytes_io = BytesIO()
video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264)
video_bytes_io.seek(0)
# Convert mask to binary mask with three channels
mask = torch.round(mask)
mask = mask.repeat(1, 3, 1, 1)
# Convert 3-channel binary mask to BytesIO
mask_bytes_io = BytesIO()
mask_bytes_io.write(mask.numpy().astype(np.uint8))
mask_bytes_io.seek(0)
# Convert image to BytesIO
image_bytes_io = tensor_to_bytesio(image)
image_bytes_io.seek(0)
pika_files = {
"video": ("video.mp4", video_bytes_io, "video/mp4"),
"image": ("image.png", image_bytes_io, "image/png"),
"modifyRegionMask": ("mask.png", mask_bytes_io, "image/png"),
}
if mask is not None:
pika_files["modifyRegionMask"] = ("mask.png", tensor_to_bytesio(mask), "image/png")
if image is not None:
pika_files["image"] = ("image.png", tensor_to_bytesio(image), "image/png")
# Prepare non-file data
pika_request_data = PikaBodyGeneratePikaswapsGeneratePikaswapsPost(
pika_request_data = pika_defs.PikaBodyGeneratePikaswapsGeneratePikaswapsPost(
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
modifyRegionRoi=region_to_modify if region_to_modify else None,
)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_PIKADDITIONS,
method=HttpMethod.POST,
request_model=PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
response_model=PikaGenerateResponse,
),
request=pika_request_data,
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_PIKASWAPS, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_request_data,
files=pika_files,
content_type="multipart/form-data",
auth_kwargs=auth,
)
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
return await execute_task(initial_operation.video_id, cls)
class PikaffectsNode(comfy_io.ComfyNode):
class PikaffectsNode(IO.ComfyNode):
"""Pika Pikaffects Node."""
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="Pikaffects",
display_name="Pikaffects (Video Effects)",
description="Generate a video with a specific Pikaffect. Supported Pikaffects: Cake-ify, Crumble, Crush, Decapitate, Deflate, Dissolve, Explode, Eye-pop, Inflate, Levitate, Melt, Peel, Poke, Squish, Ta-da, Tear",
category="api node/video/Pika",
inputs=[
comfy_io.Image.Input("image", tooltip="The reference image to apply the Pikaffect to."),
comfy_io.Combo.Input(
"pikaffect", options=[pikaffect.value for pikaffect in Pikaffect], default="Cake-ify"
IO.Image.Input("image", tooltip="The reference image to apply the Pikaffect to."),
IO.Combo.Input(
"pikaffect", options=pika_defs.Pikaffect, default="Cake-ify"
),
comfy_io.String.Input("prompt_text", multiline=True),
comfy_io.String.Input("negative_prompt", multiline=True),
comfy_io.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True),
IO.String.Input("prompt_text", multiline=True),
IO.String.Input("negative_prompt", multiline=True),
IO.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True),
],
outputs=[comfy_io.Video.Output()],
outputs=[IO.Video.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@@ -639,19 +476,12 @@ class PikaffectsNode(comfy_io.ComfyNode):
prompt_text: str,
negative_prompt: str,
seed: int,
) -> comfy_io.NodeOutput:
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_PIKAFFECTS,
method=HttpMethod.POST,
request_model=PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
response_model=PikaGenerateResponse,
),
request=PikaBodyGeneratePikaffectsGeneratePikaffectsPost(
) -> IO.NodeOutput:
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_PIKAFFECTS, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost(
pikaffect=pikaffect,
promptText=prompt_text,
negativePrompt=negative_prompt,
@@ -659,31 +489,30 @@ class PikaffectsNode(comfy_io.ComfyNode):
),
files={"image": ("image.png", tensor_to_bytesio(image), "image/png")},
content_type="multipart/form-data",
auth_kwargs=auth,
)
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
return await execute_task(initial_operation.video_id, cls)
class PikaStartEndFrameNode2_2(comfy_io.ComfyNode):
class PikaStartEndFrameNode(IO.ComfyNode):
"""PikaFrames v2.2 Node."""
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="PikaStartEndFrameNode2_2",
display_name="Pika Start and End Frame to Video",
description="Generate a video by combining your first and last frame. Upload two images to define the start and end points, and let the AI create a smooth transition between them.",
category="api node/video/Pika",
inputs=[
comfy_io.Image.Input("image_start", tooltip="The first image to combine."),
comfy_io.Image.Input("image_end", tooltip="The last image to combine."),
IO.Image.Input("image_start", tooltip="The first image to combine."),
IO.Image.Input("image_end", tooltip="The last image to combine."),
*get_base_inputs_types(),
],
outputs=[comfy_io.Video.Output()],
outputs=[IO.Video.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@@ -698,23 +527,17 @@ class PikaStartEndFrameNode2_2(comfy_io.ComfyNode):
seed: int,
resolution: str,
duration: int,
) -> comfy_io.NodeOutput:
) -> IO.NodeOutput:
validate_string(prompt_text, field_name="prompt_text", min_length=1)
pika_files = [
("keyFrames", ("image_start.png", tensor_to_bytesio(image_start), "image/png")),
("keyFrames", ("image_end.png", tensor_to_bytesio(image_end), "image/png")),
]
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_PIKAFRAMES,
method=HttpMethod.POST,
request_model=PikaBodyGenerate22KeyframeGenerate22PikaframesPost,
response_model=PikaGenerateResponse,
),
request=PikaBodyGenerate22KeyframeGenerate22PikaframesPost(
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_PIKAFRAMES, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost(
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
@@ -723,22 +546,21 @@ class PikaStartEndFrameNode2_2(comfy_io.ComfyNode):
),
files=pika_files,
content_type="multipart/form-data",
auth_kwargs=auth,
)
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
return await execute_task(initial_operation.video_id, cls)
class PikaApiNodesExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
PikaImageToVideoV2_2,
PikaTextToVideoNodeV2_2,
PikaScenesV2_2,
PikaImageToVideo,
PikaTextToVideoNode,
PikaScenes,
PikAdditionsNode,
PikaSwapsNode,
PikaffectsNode,
PikaStartEndFrameNode2_2,
PikaStartEndFrameNode,
]

Some files were not shown because too many files have changed in this diff Show More