Compare commits

..

160 Commits

Author SHA1 Message Date
comfyanonymous
32a95bba8a ComfyUI version 0.3.49 2025-08-05 07:33:02 -04:00
ComfyUI Wiki
da1ad9b516 Update template to 0.1.51 (#9187) 2025-08-05 07:24:12 -04:00
comfyanonymous
d044a24398 Fix default shift and any latent size for qwen image model. (#9186) 2025-08-05 06:12:27 -04:00
ComfyUI Wiki
5be6fd09ff Update template to 0.1.48 (#9182) 2025-08-05 03:48:56 -04:00
Christian Byrne
f69609bbd6 Add Veo3 video generation node with audio support (#9110)
- Create new Veo3VideoGenerationNode that extends VeoVideoGenerationNode
- Add support for generateAudio parameter (only for Veo3 models)
- Support new Veo3 models: veo-3.0-generate-001, veo-3.0-fast-generate-001
- Fix Veo3 duration constraint to 8 seconds only
- Update original node to be clearly Veo 2 only
- Update API paths to use model parameter: /proxy/veo/{model}/generate
- Regenerate API types from staging to include generateAudio parameter
- Fix TripoModelVersion enum reference after regeneration
- Mark generated API types file in .gitattributes
2025-08-05 01:52:25 -04:00
comfyanonymous
c012400240 Initial support for qwen image model. (#9179) 2025-08-04 22:53:25 -04:00
comfyanonymous
03895dea7c Fix another issue with the PR. (#9170) 2025-08-04 04:33:04 -04:00
comfyanonymous
84f9759424 Add some warnings and prevent crash when cond devices don't match. (#9169) 2025-08-04 04:20:12 -04:00
comfyanonymous
7991341e89 Various fixes for broken things from earlier PR. (#9168) 2025-08-04 04:02:40 -04:00
comfyanonymous
140ffc7fdc Fix broken controlnet from last PR. (#9167) 2025-08-04 03:28:12 -04:00
comfyanonymous
182f90b5ec Lower cond vram use by casting at the same time as device transfer. (#9159) 2025-08-04 03:11:53 -04:00
comfyanonymous
aebac22193 Cleanup. (#9160) 2025-08-03 07:08:11 -04:00
comfyanonymous
13aaa66ec2 Make sure context is on the right device. (#9154) 2025-08-02 15:09:23 -04:00
comfyanonymous
5f582a9757 Make sure all the conds are on the right device. (#9151) 2025-08-02 15:00:13 -04:00
ComfyUI Wiki
fbcc23945d Update template to 0.1.47 (#9153) 2025-08-02 14:15:29 -04:00
Johnpaul Chiwetelu
3dfefc88d0 API for Recently Used Items (#8792)
* feat: add file creation time to model file metadata and user file info

* fix linting
2025-08-01 22:02:06 -04:00
comfyanonymous
bff60b5cfc ComfyUI version 0.3.48 2025-08-01 20:03:22 -04:00
comfyanonymous
1e638a140b Tiny wan vae optimizations. (#9136) 2025-08-01 05:25:38 -04:00
ComfyUI Wiki
4696d74305 update template to 0.1.45 (#9135) 2025-08-01 03:06:18 -04:00
comfyanonymous
5ee381c058 Fix WanFirstLastFrameToVideo node when no clip vision. (#9134) 2025-07-31 23:33:27 -04:00
Jedrzej Kosinski
4887743a2a V3 Node Schema Definition - initial (#8656) 2025-07-31 18:02:12 -04:00
comfyanonymous
97b8a2c26a More accurate explanation of release process. (#9126) 2025-07-31 05:46:23 -04:00
guill
97eb256a35 Add support for partial execution in backend (#9123)
When a prompt is submitted, it can optionally include
`partial_execution_targets` as a list of ids. If it does, rather than
adding all outputs to the execution list, we add only those in the list.
2025-07-30 22:55:28 -04:00
chaObserv
61b08d4ba6 Replace manual x * sigmoid(x) with torch silu in VAE nonlinearity (#9057) 2025-07-30 19:25:56 -04:00
comfyanonymous
da9dab7edd Small wan camera memory optimization. (#9111) 2025-07-30 05:55:26 -04:00
ComfyUI Wiki
d2aaef029c Update template to 0.1.44 (#9104) 2025-07-29 22:50:49 -04:00
guill
0a3d062e06 ComfyAPI Core v0.0.2 (#8962)
* ComfyAPI Core v0.0.2

* Respond to PR feedback

* Fix Python 3.9 errors

* Fix missing backward compatibility proxy

* Reorganize types a bit

The input types, input impls, and utility types are now all available in
the versioned API. See the change in `comfy_extras/nodes_video.py` for
an example of their usage.

* Remove the need for `--generate-api-stubs`

* Fix generated stubs differing by Python version

* Fix ruff formatting issues
2025-07-29 22:17:22 -04:00
comfyanonymous
2f74e17975 ComfyUI version 0.3.47 2025-07-29 20:08:25 -04:00
comfyanonymous
dca6bdd4fa Make wan2.2 5B i2v take a lot less memory. (#9102) 2025-07-29 19:44:18 -04:00
comfyanonymous
7d593baf91 Extra reserved vram on large cards on windows. (#9093) 2025-07-29 04:07:45 -04:00
comfyanonymous
c60dc4177c Remove unecessary clones in the wan2.2 VAE. (#9083) 2025-07-28 14:48:19 -04:00
comfyanonymous
5d4cc3ba1b ComfyUI 0.3.46 2025-07-28 08:04:04 -04:00
comfyanonymous
9f1388c0a3 Add wan2.2 to readme. (#9081) 2025-07-28 08:01:53 -04:00
comfyanonymous
a88788dce6 Wan 2.2 support. (#9080) 2025-07-28 08:00:23 -04:00
ComfyUI Wiki
d0210fe2e5 Update template to 0.1.41 (#9079) 2025-07-28 07:55:02 -04:00
Christian Byrne
e6d9f62744 Add Moonvalley Marey V2V node with updated input validation (#9069)
* [moonvalley] Update V2V node to match API specification

- Add exact resolution validation for supported resolutions (1920x1080, 1080x1920, 1152x1152, 1536x1152, 1152x1536)
- Change frame count validation from divisible by 32 to 16
- Add MP4 container format validation
- Remove internal parameters (steps, guidance_scale) from V2V inference params
- Update video duration handling to support only 5 seconds (auto-trim if longer)
- Add motion_intensity parameter (0-100) for Motion Transfer control type
- Add get_container_format() method to VideoInput classes

* update negative prompt
2025-07-27 19:51:36 -04:00
comfyanonymous
78672d0ee6 Small readme update. (#9071) 2025-07-27 07:42:58 -04:00
ComfyUI Wiki
1ef70fcde4 Fix the broken link (#9060) 2025-07-26 17:25:33 -04:00
comfyanonymous
0621d73a9c Remove useless code. (#9059) 2025-07-26 04:44:19 -04:00
comfyanonymous
b850d9a8bb Add map_function to get_history. (#9056) 2025-07-25 21:25:45 -04:00
Thor-ATX
c60467a148 Update negative prompt for Moonvalley nodes (#9038)
Co-authored-by: thorsten <thorsten@tripod-digital.co.nz>
2025-07-25 17:27:03 -04:00
comfyanonymous
c0207b473f Fix issue with line endings github workflow. (#9053) 2025-07-25 17:25:08 -04:00
ComfyUI Wiki
93bc2f8e4d Update template to 0.1.40 (#9048) 2025-07-25 13:24:23 -04:00
comfyanonymous
e6e5d33b35 Remove useless code. (#9041)
This is only needed on old pytorch 2.0 and older.
2025-07-25 04:58:28 -04:00
Eugene Fairley
4293e4da21 Add WAN ATI support (#8874)
* Add WAN ATI support

* Fixes

* Fix length

* Remove extra functions

* Fix

* Fix

* Ruff fix

* Remove torch.no_grad

* Add batch trajectory logic

* Scale inputs before and after motion patch

* Batch image/trajectory

* Ruff fix

* Clean up
2025-07-24 20:59:19 -04:00
comfyanonymous
69cb57b342 Print xpu device name. (#9035) 2025-07-24 15:06:25 -04:00
SHIVANSH GUPTA
d03ae077b4 Added parameter required_frontend_version in the /system_stats API response (#8875)
* Added the parameter required_frontend_version in the /system_stats  api response

* Update server.py

* Created a function get_required_frontend_version and wrote tests for it

* Refactored the function to return currently installed frontend pacakage version

* Moved required_frontend to a new function and imported that in server.py

* Corrected test cases using mocking techniques

* Corrected files to comply with ruff formatting
2025-07-24 14:05:54 -04:00
honglyua
0ccc88b03f Support Iluvatar CoreX (#8585)
* Support Iluvatar CoreX
Co-authored-by: mingjiang.li <mingjiang.li@iluvatar.com>
2025-07-24 13:57:36 -04:00
Kohaku-Blueleaf
eb2f78b4e0 [Training Node] algo support, grad acc, optional grad ckpt (#9015)
* Add factorization utils for lokr

* Add lokr train impl

* Add loha train impl

* Add adapter map for algo selection

* Add optional grad ckpt and algo selection

* Update __init__.py

* correct key name for loha

* Use custom fwd/bwd func and better init for loha

* Support gradient accumulation

* Fix bugs of loha

* use more stable init

* Add OFT training

* linting
2025-07-23 20:57:27 -04:00
chaObserv
e729a5cc11 Separate denoised and noise estimation in Euler CFG++ (#9008)
This will change their behavior with the sampling CONST type.
It also combines euler_cfg_pp and euler_ancestral_cfg_pp into one main function.
2025-07-23 19:47:05 -04:00
comfyanonymous
e78d230496 Only enable cuda malloc on cuda torch. (#9031) 2025-07-23 19:37:43 -04:00
comfyanonymous
d3504e1778 Enable pytorch attention by default for gfx1201 on torch 2.8 (#9029) 2025-07-23 19:21:29 -04:00
comfyanonymous
a86a58c308 Fix xpu function not implemented p2. (#9027) 2025-07-23 18:18:20 -04:00
comfyanonymous
39dda1d40d Fix xpu function not implemented. (#9026) 2025-07-23 18:10:59 -04:00
comfyanonymous
5ad33787de Add default device argument. (#9023) 2025-07-23 14:20:49 -04:00
Simon Lui
255f139863 Add xpu version for async offload and some other things. (#9004) 2025-07-22 15:20:09 -04:00
comfyanonymous
5ac9ec214b Try to fix line endings workflow. (#9001) 2025-07-22 04:07:51 -04:00
comfyanonymous
0aa1c58b04 This is not needed. (#8991) 2025-07-21 16:48:25 -04:00
comfyanonymous
5249e45a1c Add hidream e1.1 example to readme. (#8990) 2025-07-21 15:23:41 -04:00
comfyanonymous
54a45b9967 Replace torchaudio.load with pyav. (#8989) 2025-07-21 14:19:14 -04:00
comfyanonymous
9a470e073e ComfyUI version 0.3.45 2025-07-21 14:05:43 -04:00
ComfyUI Wiki
7d627f764c Update template to 0.1.39 (#8981) 2025-07-20 15:58:35 -04:00
comfyanonymous
a0c0785635 Document what the fast_fp16_accumulation is in the portable. (#8973) 2025-07-20 01:24:09 -04:00
chaObserv
100c2478ea Add SamplingPercentToSigma node (#8963)
It's helpful to adjust start_percent or end_percent based on the corresponding sigma.
2025-07-19 23:09:11 -04:00
ComfyUI Wiki
1da5639e86 Update template to 0.1.37 (#8967) 2025-07-19 06:08:00 -04:00
comfyanonymous
1b96fae1d4 Add nested style of dual cfg to DualCFGGuider node. (#8965) 2025-07-19 04:55:23 -04:00
comfyanonymous
7f492522b6 Forgot this (#8957) 2025-07-18 05:43:02 -04:00
comfyanonymous
650838fd6f Experimental CFGNorm node. (#8942)
This is from the new hidream e1 1 model code. Figured it might be useful as a generic cfg trick.
2025-07-17 04:11:07 -04:00
comfyanonymous
491fafbd64 Silence clip tokenizer warning. (#8934) 2025-07-16 14:42:07 -04:00
Harel Cain
9bc2798f72 LTXV VAE decoder: switch default padding mode (#8930) 2025-07-16 13:54:38 -04:00
comfyanonymous
50afba747c Add attempt to work around the safetensors mmap issue. (#8928) 2025-07-16 03:42:17 -04:00
Brandon Wallace
6b8062f414 Fix MaskComposite error when destination has 2 dimensions (#8915)
Fix code that is using the original `destination` input instead of the reshaped value.
2025-07-15 21:08:27 -04:00
comfyanonymous
b1ae4126c3 Add action to detect windows line endings. (#8917) 2025-07-15 02:27:18 -04:00
Yoland Yan
9dabda19f0 Update nodes_gemini.py (#8912) 2025-07-14 20:59:35 -04:00
Yoland Yan
543c24108c Fix wrong reference bug (#8910) 2025-07-14 20:45:55 -04:00
FeepingCreature
260a5ca5d9 Allow the prompt request to specify the prompt ID. (#8189)
This makes it easier to write asynchronous clients that submit requests, because they can store the task immediately.
Duplicate prompt IDs are rejected by the job queue.
2025-07-14 14:48:31 -04:00
ComfyUI Wiki
861c3bbb3d Upate template to 0.1.36 (#8904) 2025-07-14 13:27:57 -04:00
comfyanonymous
9ca581c941 Remove windows line endings. (#8902) 2025-07-14 13:10:20 -04:00
comfyanonymous
4831e9c2c4 Refactor previous pr. (#8893) 2025-07-13 04:59:17 -04:00
Christian Byrne
480375f349 Remove auth tokens from history storage (#8889)
Remove auth_token_comfy_org and api_key_comfy_org from extra_data before
storing prompt history to prevent sensitive authentication tokens from
being persisted in the history endpoint response.
2025-07-13 04:46:27 -04:00
comfyanonymous
b40143984c Add model detection error hint for lora. (#8880) 2025-07-12 03:49:26 -04:00
chaObserv
b43916a134 Fix fresca's input and output (#8871) 2025-07-11 12:52:58 -04:00
JettHu
7bc7dd2aa2 Execute async node earlier (#8865) 2025-07-11 12:51:06 -04:00
comfyanonymous
938d3e8216 Remove windows line endings. (#8866) 2025-07-11 02:37:51 -04:00
Christian Byrne
8f05fb48ea [fix] increase Kling API polling timeout to prevent user timeouts (#8860)
Extends polling duration from 10 minutes to ~68 minutes (256 attempts × 16 seconds) to accommodate longer Kling API operations that were frequently timing out for users.
2025-07-10 18:00:29 -04:00
comfyanonymous
b7ff5bd14d Fix python3.9 (#8858) 2025-07-10 15:21:18 -04:00
guill
2b653e8c18 Support for async node functions (#8830)
* Support for async execution functions

This commit adds support for node execution functions defined as async. When
a node's execution function is defined as async, we can continue
executing other nodes while it is processing.

Standard uses of `await` should "just work", but people will still have
to be careful if they spawn actual threads. Because torch doesn't really
have async/await versions of functions, this won't particularly help
with most locally-executing nodes, but it does work for e.g. web
requests to other machines.

In addition to the execute function, the `VALIDATE_INPUTS` and
`check_lazy_status` functions can also be defined as async, though we'll
only resolve one node at a time right now for those.

* Add the execution model tests to CI

* Add a missing file

It looks like this got caught by .gitignore? There's probably a better
place to put it, but I'm not sure what that is.

* Add the websocket library for automated tests

* Add additional tests for async error cases

Also fixes one bug that was found when an async function throws an error
after being scheduled on a task.

* Add a feature flags message to reduce bandwidth

We now only send 1 preview message of the latest type the client can
support.

We'll add a console warning when the client fails to send a feature
flags message at some point in the future.

* Add async tests to CI

* Don't actually add new tests in this PR

Will do it in a separate PR

* Resolve unit test in GPU-less runner

* Just remove the tests that GHA can't handle

* Change line endings to UNIX-style

* Avoid loading model_management.py so early

Because model_management.py has a top-level `logging.info`, we have to
be careful not to import that file before we call `setup_logging`. If we
do, we end up having the default logging handler registered in addition
to our custom one.
2025-07-10 14:46:19 -04:00
comfyanonymous
1fd306824d Add warning to catch torch import mistakes. (#8852) 2025-07-10 01:03:27 -04:00
Kohaku-Blueleaf
1205afc708 Better training loop implementation (#8820) 2025-07-09 11:41:22 -04:00
comfyanonymous
5612670ee4 Remove unmaintained notebook. (#8845) 2025-07-09 03:45:48 -04:00
Kohaku-Blueleaf
181a9bf26d Support Multi Image-Caption dataset in lora training node (#8819)
* initial impl of multi img/text dataset

* Update nodes_train.py

* Support Kohya-ss structure
2025-07-08 20:18:04 -04:00
chaObserv
aac10ad23a Add SA-Solver sampler (#8834) 2025-07-08 16:17:06 -04:00
josephrocca
974254218a Un-hardcode chroma patch_size (#8840) 2025-07-08 15:56:59 -04:00
comfyanonymous
c5de4955bb ComfyUI version 0.3.44 2025-07-08 08:56:38 -04:00
Christian Byrne
9fd0cd7cf7 Add Moonvalley nodes (#8832) 2025-07-08 08:54:30 -04:00
ComfyUI Wiki
b5e97db9ac Update template to 0.1.35 (#8831) 2025-07-08 08:52:02 -04:00
Christian Byrne
1359c969e4 Update template to 0.1.34 (#8829) 2025-07-07 23:35:41 -04:00
ComfyUI Wiki
059cd38aa2 Update template and node docs package version (#8825) 2025-07-07 20:43:56 -04:00
comfyanonymous
e740dfd806 Fix warning in audio save nodes. (#8818) 2025-07-07 03:16:00 -04:00
comfyanonymous
7eab7d2944 Remove dependency on deprecated torchaudio.save function (#8815) 2025-07-06 14:01:32 -04:00
comfyanonymous
75d327abd5 Remove some useless code. (#8812) 2025-07-06 07:07:39 -04:00
comfyanonymous
ee615ac269 Add warning when loading file unsafely. (#8800) 2025-07-05 14:34:57 -04:00
comfyanonymous
27870ec3c3 Add that ckpt files are loaded safely to README. (#8791) 2025-07-04 04:49:11 -04:00
chaObserv
f41f323c52 Add the denoising step to several samplers (#8780) 2025-07-03 19:20:53 -04:00
comfyanonymous
f74fc4d927 Add ImageRotate and ImageFlip nodes. (#8789) 2025-07-03 19:16:30 -04:00
ComfyUI Wiki
ae26cd99b5 Update template to 0.1.32 (#8782) 2025-07-03 14:41:16 -04:00
comfyanonymous
e9af97ba1a Use torch cu129 for nvidia pytorch nightly. (#8786)
* update nightly workflow with cu129

* Remove unused file to lower standalone size.
2025-07-03 14:39:11 -04:00
City
d9277301d2 Initial code for new SLG node (#8759) 2025-07-02 20:13:43 -04:00
comfyanonymous
34c8eeec06 Fix ImageColorToMask not returning right mask values. (#8771) 2025-07-02 15:35:11 -04:00
Harel Cain
9f1069290c nodes_lt: fixes to latent conditioning at index > 0 (#8769) 2025-07-02 15:34:51 -04:00
comfyanonymous
111f583e00 Fallback to regular op when fp8 op throws exception. (#8761) 2025-07-02 00:57:13 -04:00
Terry Jia
79ed752748 support upload 3d model to custom subfolder (#8597) 2025-07-01 20:43:48 -04:00
comfyanonymous
772de7c006 PerpNeg Guider optimizations. (#8753) 2025-07-01 03:09:07 -04:00
chaObserv
b22e97dcfa Migrate ER-SDE from VE to VP algorithm and add its sampler node (#8744)
Apply alpha scaling in the algorithm for reverse-time SDE and add custom ER-SDE sampler node for other solver types (SDE, ODE).
2025-07-01 02:38:52 -04:00
chaObserv
f02de13316 Add TCFG node (#8730) 2025-07-01 02:33:07 -04:00
ComfyUI Wiki
c46268bf60 Update requirements.txt (#8741) 2025-06-30 14:18:43 -04:00
comfyanonymous
cf49a2c5b5 Dual cfg node optimizations when cfg is 1.0 (#8747) 2025-06-30 14:18:25 -04:00
comfyanonymous
170c7bb90c Fix contiguous issue with pytorch nightly. (#8729) 2025-06-29 06:38:40 -04:00
bmcomfy
2a0b138feb build: add gh action to process releases (#8652) 2025-06-28 19:11:40 -04:00
comfyanonymous
e195c1b13f Make stable release workflow publish drafts. (#8723) 2025-06-28 19:11:16 -04:00
chaObserv
5b4eb021cb Perpneg guider with updated pre and post-cfg (#8698) 2025-06-28 18:13:13 -04:00
comfyanonymous
396454fa41 Reorder the schedulers so simple is the default one. (#8722) 2025-06-28 18:12:56 -04:00
comfyanonymous
a3cf272522 Skip custom node logic completely if disabled and no whitelisted nodes. (#8719) 2025-06-28 15:53:40 -04:00
xufeng
ba9548f756 “--whitelist-custom-nodes” args for comfy core to go with “--disable-all-custom-nodes” for development purposes (#8592)
* feat: “--whitelist-custom-nodes” args for comfy core to go with “--disable-all-custom-nodes” for development purposes

* feat: Simplify custom nodes whitelist logic to use consistent code paths
2025-06-28 15:24:02 -04:00
comfyanonymous
e18f53cca9 ComfyUI version 0.3.43 2025-06-27 17:22:02 -04:00
comfyanonymous
c36be0ea09 Fix memory estimation bug with kontext. (#8709) 2025-06-27 17:21:12 -04:00
comfyanonymous
9093301a49 Don't add tiny bit of random noise when VAE encoding. (#8705)
Shouldn't change outputs but might make things a tiny bit more
deterministic.
2025-06-27 14:14:56 -04:00
comfyanonymous
bd951a714f Add Flux Kontext and Omnigen 2 models to readme. (#8682) 2025-06-26 12:26:29 -04:00
comfyanonymous
6493709d6a ComfyUI version 0.3.42 2025-06-26 11:47:07 -04:00
filtered
b976f934ae Update frontend to 1.23.4 (#8681) 2025-06-26 11:44:12 -04:00
comfyanonymous
7d8cf4cacc Update requirements.txt (#8680) 2025-06-26 11:39:40 -04:00
filtered
68f4496b8e Update frontend to 1.23.3 (#8678) 2025-06-26 11:29:03 -04:00
comfyanonymous
ef5266b1c1 Support Flux Kontext Dev model. (#8679) 2025-06-26 11:28:41 -04:00
comfyanonymous
a96e65df18 Disable omnigen2 fp16 on older pytorch versions. (#8672) 2025-06-26 03:39:09 -04:00
comfyanonymous
93a49a45de Bump minimum transformers version. (#8671) 2025-06-26 02:33:02 -04:00
comfyanonymous
ec70ed6aea Omnigen2 model implementation. (#8669) 2025-06-25 19:35:57 -04:00
comfyanonymous
7a13f74220 unet -> diffusion model (#8659) 2025-06-25 04:52:34 -04:00
chaObserv
8042eb20c6 Singlestep DPM++ SDE for RF (#8627)
Refactor the algorithm, and apply alpha scaling.
2025-06-24 14:59:09 -04:00
comfyanonymous
bd9f166c12 Cosmos predict2 model merging nodes. (#8647) 2025-06-24 05:17:16 -04:00
comfyanonymous
dd94416db2 Indicate that directml is not recommended in the README. (#8644) 2025-06-23 14:04:49 -04:00
comfyanonymous
ae0e7c4dff Resize and pad image node. (#8636) 2025-06-22 17:59:31 -04:00
comfyanonymous
78f79266a9 Allow padding in ImageStitch node to be white. (#8631) 2025-06-22 00:19:41 -04:00
comfyanonymous
1883e70b43 Fix exception when using a noise mask with cosmos predict2. (#8621)
* Fix exception when using a noise mask with cosmos predict2.

* Fix ruff.
2025-06-21 03:30:39 -04:00
Lucas - BLOCK33
31ca603ccb Improve the log time function for 10 minute + renders (#6207)
* modified:   main.py

* Update main.py
2025-06-20 23:04:55 -04:00
comfyanonymous
f7fb193712 Small flux optimization. (#8611) 2025-06-20 05:37:32 -04:00
comfyanonymous
7e9267fa77 Make flux controlnet work with sd3 text enc. (#8599) 2025-06-19 18:50:05 -04:00
comfyanonymous
91d40086db Fix pytorch warning. (#8593) 2025-06-19 11:04:52 -04:00
coderfromthenorth93
5b12b55e32 Add new fields to the config types (#8507) 2025-06-18 15:12:29 -04:00
comfyanonymous
e9e9a031a8 Show a better error when the workflow OOMs. (#8574) 2025-06-18 06:55:21 -04:00
filtered
d7430c529a Update frontend to 1.22.2 (#8567) 2025-06-17 18:58:28 -04:00
ComfyUI Wiki
cd88f709ab Update template version (#8563) 2025-06-17 04:11:59 -07:00
comfyanonymous
4459a17e82 Add Cosmos Predict2 to README. (#8562) 2025-06-17 05:18:01 -04:00
comfyanonymous
483b3e62e0 ComfyUI version v0.3.41 2025-06-16 23:34:46 -04:00
chaObserv
8e81c507d2 Multistep DPM++ SDE samplers for RF (#8541)
Include alpha in sampling and minor refactoring
2025-06-16 14:47:10 -04:00
comfyanonymous
e1c6dc720e Allow setting min_length with tokenizer_data. (#8547) 2025-06-16 13:43:52 -04:00
comfyanonymous
7ea79ebb9d Add correct eps to ltxv rmsnorm. (#8542) 2025-06-15 12:21:25 -04:00
comfyanonymous
ae75a084df SaveLora now saves in the same filename format as all the other nodes. (#8538) 2025-06-15 03:44:59 -04:00
comfyanonymous
d6a2137fc3 Support Cosmos predict2 image to video models. (#8535)
Use the CosmosPredict2ImageToVideoLatent node.
2025-06-14 21:37:07 -04:00
chaObserv
53e8d8193c Generalize SEEDS samplers (#8529)
Restore VP algorithm for RF and refactor noise_coeffs and half-logSNR calculations
2025-06-14 16:58:16 -04:00
comfyanonymous
29596bd53f Small cosmos attention code refactor. (#8530) 2025-06-14 05:02:05 -04:00
141 changed files with 166413 additions and 1877 deletions

View File

@@ -4,6 +4,9 @@ if you have a NVIDIA gpu:
run_nvidia_gpu.bat
if you want to enable the fast fp16 accumulation (faster for fp16 models with slightly less quality):
run_nvidia_gpu_fast_fp16_accumulation.bat
To run it in slow CPU mode:

1
.gitattributes vendored
View File

@@ -1,2 +1,3 @@
/web/assets/** linguist-generated
/web/** linguist-vendored
comfy_api_nodes/apis/__init__.py linguist-generated

View File

@@ -0,0 +1,40 @@
name: Check for Windows Line Endings
on:
pull_request:
branches: ['*'] # Trigger on all pull requests to any branch
jobs:
check-line-endings:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0 # Fetch all history to compare changes
- name: Check for Windows line endings (CRLF)
run: |
# Get the list of changed files in the PR
CHANGED_FILES=$(git diff --name-only ${{ github.event.pull_request.base.sha }}..${{ github.event.pull_request.head.sha }})
# Flag to track if CRLF is found
CRLF_FOUND=false
# Loop through each changed file
for FILE in $CHANGED_FILES; do
# Check if the file exists and is a text file
if [ -f "$FILE" ] && file "$FILE" | grep -q "text"; then
# Check for CRLF line endings
if grep -UP '\r$' "$FILE"; then
echo "Error: Windows line endings (CRLF) detected in $FILE"
CRLF_FOUND=true
fi
fi
done
# Exit with error if CRLF was found
if [ "$CRLF_FOUND" = true ]; then
exit 1
fi

108
.github/workflows/release-webhook.yml vendored Normal file
View File

@@ -0,0 +1,108 @@
name: Release Webhook
on:
release:
types: [published]
jobs:
send-webhook:
runs-on: ubuntu-latest
steps:
- name: Send release webhook
env:
WEBHOOK_URL: ${{ secrets.RELEASE_GITHUB_WEBHOOK_URL }}
WEBHOOK_SECRET: ${{ secrets.RELEASE_GITHUB_WEBHOOK_SECRET }}
run: |
# Generate UUID for delivery ID
DELIVERY_ID=$(uuidgen)
HOOK_ID="release-webhook-$(date +%s)"
# Create webhook payload matching GitHub release webhook format
PAYLOAD=$(cat <<EOF
{
"action": "published",
"release": {
"id": ${{ github.event.release.id }},
"node_id": "${{ github.event.release.node_id }}",
"url": "${{ github.event.release.url }}",
"html_url": "${{ github.event.release.html_url }}",
"assets_url": "${{ github.event.release.assets_url }}",
"upload_url": "${{ github.event.release.upload_url }}",
"tag_name": "${{ github.event.release.tag_name }}",
"target_commitish": "${{ github.event.release.target_commitish }}",
"name": ${{ toJSON(github.event.release.name) }},
"body": ${{ toJSON(github.event.release.body) }},
"draft": ${{ github.event.release.draft }},
"prerelease": ${{ github.event.release.prerelease }},
"created_at": "${{ github.event.release.created_at }}",
"published_at": "${{ github.event.release.published_at }}",
"author": {
"login": "${{ github.event.release.author.login }}",
"id": ${{ github.event.release.author.id }},
"node_id": "${{ github.event.release.author.node_id }}",
"avatar_url": "${{ github.event.release.author.avatar_url }}",
"url": "${{ github.event.release.author.url }}",
"html_url": "${{ github.event.release.author.html_url }}",
"type": "${{ github.event.release.author.type }}",
"site_admin": ${{ github.event.release.author.site_admin }}
},
"tarball_url": "${{ github.event.release.tarball_url }}",
"zipball_url": "${{ github.event.release.zipball_url }}",
"assets": ${{ toJSON(github.event.release.assets) }}
},
"repository": {
"id": ${{ github.event.repository.id }},
"node_id": "${{ github.event.repository.node_id }}",
"name": "${{ github.event.repository.name }}",
"full_name": "${{ github.event.repository.full_name }}",
"private": ${{ github.event.repository.private }},
"owner": {
"login": "${{ github.event.repository.owner.login }}",
"id": ${{ github.event.repository.owner.id }},
"node_id": "${{ github.event.repository.owner.node_id }}",
"avatar_url": "${{ github.event.repository.owner.avatar_url }}",
"url": "${{ github.event.repository.owner.url }}",
"html_url": "${{ github.event.repository.owner.html_url }}",
"type": "${{ github.event.repository.owner.type }}",
"site_admin": ${{ github.event.repository.owner.site_admin }}
},
"html_url": "${{ github.event.repository.html_url }}",
"clone_url": "${{ github.event.repository.clone_url }}",
"git_url": "${{ github.event.repository.git_url }}",
"ssh_url": "${{ github.event.repository.ssh_url }}",
"url": "${{ github.event.repository.url }}",
"created_at": "${{ github.event.repository.created_at }}",
"updated_at": "${{ github.event.repository.updated_at }}",
"pushed_at": "${{ github.event.repository.pushed_at }}",
"default_branch": "${{ github.event.repository.default_branch }}",
"fork": ${{ github.event.repository.fork }}
},
"sender": {
"login": "${{ github.event.sender.login }}",
"id": ${{ github.event.sender.id }},
"node_id": "${{ github.event.sender.node_id }}",
"avatar_url": "${{ github.event.sender.avatar_url }}",
"url": "${{ github.event.sender.url }}",
"html_url": "${{ github.event.sender.html_url }}",
"type": "${{ github.event.sender.type }}",
"site_admin": ${{ github.event.sender.site_admin }}
}
}
EOF
)
# Generate HMAC-SHA256 signature
SIGNATURE=$(echo -n "$PAYLOAD" | openssl dgst -sha256 -hmac "$WEBHOOK_SECRET" -hex | cut -d' ' -f2)
# Send webhook with required headers
curl -X POST "$WEBHOOK_URL" \
-H "Content-Type: application/json" \
-H "X-GitHub-Event: release" \
-H "X-GitHub-Delivery: $DELIVERY_ID" \
-H "X-GitHub-Hook-ID: $HOOK_ID" \
-H "X-Hub-Signature-256: sha256=$SIGNATURE" \
-H "User-Agent: GitHub-Actions-Webhook/1.0" \
-d "$PAYLOAD" \
--fail --silent --show-error
echo "✅ Release webhook sent successfully"

View File

@@ -102,5 +102,4 @@ jobs:
file: ComfyUI_windows_portable_nvidia.7z
tag: ${{ inputs.git_tag }}
overwrite: true
prerelease: true
make_latest: false
draft: true

View File

@@ -28,7 +28,3 @@ jobs:
run: |
pip install -r tests-unit/requirements.txt
python -m pytest tests-unit
- name: Run Execution Model Tests
run: |
python -m pytest tests/inference/test_execution.py

View File

@@ -7,7 +7,7 @@ on:
description: 'cuda version'
required: true
type: string
default: "128"
default: "129"
python_minor:
description: 'python minor version'
@@ -19,7 +19,7 @@ on:
description: 'python patch version'
required: true
type: string
default: "2"
default: "5"
# push:
# branches:
# - master
@@ -53,6 +53,8 @@ jobs:
ls ../temp_wheel_dir
./python.exe -s -m pip install --pre ../temp_wheel_dir/*
sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
rm ./Lib/site-packages/torch/lib/dnnl.lib #I don't think this is actually used and I need the space
cd ..
git clone --depth 1 https://github.com/comfyanonymous/taesd

View File

@@ -55,7 +55,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
## Features
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
- Image Models
- SD1.x, SD2.x,
- SD1.x, SD2.x ([unCLIP](https://comfyanonymous.github.io/ComfyUI_examples/unclip/))
- [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [SDXL Turbo](https://comfyanonymous.github.io/ComfyUI_examples/sdturbo/)
- [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/)
- [SD3 and SD3.5](https://comfyanonymous.github.io/ComfyUI_examples/sd3/)
@@ -65,13 +65,19 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
- [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
- [Lumina Image 2.0](https://comfyanonymous.github.io/ComfyUI_examples/lumina2/)
- [HiDream](https://comfyanonymous.github.io/ComfyUI_examples/hidream/)
- [Cosmos Predict2](https://comfyanonymous.github.io/ComfyUI_examples/cosmos_predict2/)
- Image Editing Models
- [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/)
- [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model)
- [HiDream E1.1](https://comfyanonymous.github.io/ComfyUI_examples/hidream/#hidream-e11)
- Video Models
- [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/)
- [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
- [LTX-Video](https://comfyanonymous.github.io/ComfyUI_examples/ltxv/)
- [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
- [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/)
- [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/) and [Cosmos Predict2](https://comfyanonymous.github.io/ComfyUI_examples/cosmos_predict2/)
- [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/)
- [Wan 2.2](https://comfyanonymous.github.io/ComfyUI_examples/wan22/)
- Audio Models
- [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
- [ACE Step](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
@@ -79,9 +85,10 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
- [Hunyuan3D 2.0](https://docs.comfy.org/tutorials/3d/hunyuan3D-2)
- Asynchronous Queue system
- Many optimizations: Only re-executes the parts of the workflow that changes between executions.
- Smart memory management: can automatically run models on GPUs with as low as 1GB vram.
- Smart memory management: can automatically run large models on GPUs with as low as 1GB vram with smart offloading.
- Works even if you don't have a GPU with: ```--cpu``` (slow)
- Can load ckpt, safetensors and diffusers models/checkpoints. Standalone VAEs and CLIP models.
- Can load ckpt and safetensors: All in one checkpoints or standalone diffusion models, VAEs and CLIP models.
- Safe loading of ckpt, pt, pth, etc.. files.
- Embeddings/Textual inversion
- [Loras (regular, locon and loha)](https://comfyanonymous.github.io/ComfyUI_examples/lora/)
- [Hypernetworks](https://comfyanonymous.github.io/ComfyUI_examples/hypernetworks/)
@@ -92,12 +99,10 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
- [Inpainting](https://comfyanonymous.github.io/ComfyUI_examples/inpaint/) with both regular and inpainting models.
- [ControlNet and T2I-Adapter](https://comfyanonymous.github.io/ComfyUI_examples/controlnet/)
- [Upscale Models (ESRGAN, ESRGAN variants, SwinIR, Swin2SR, etc...)](https://comfyanonymous.github.io/ComfyUI_examples/upscale_models/)
- [unCLIP Models](https://comfyanonymous.github.io/ComfyUI_examples/unclip/)
- [GLIGEN](https://comfyanonymous.github.io/ComfyUI_examples/gligen/)
- [Model Merging](https://comfyanonymous.github.io/ComfyUI_examples/model_merging/)
- [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/)
- Latent previews with [TAESD](#how-to-show-high-quality-previews)
- Starts up very fast.
- Works fully offline: core will never download anything unless you want to.
- Optional API nodes to use paid models from external providers through the online [Comfy API](https://docs.comfy.org/tutorials/api-nodes/overview).
- [Config file](extra_model_paths.yaml.example) to set the search paths for models.
@@ -106,7 +111,7 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
## Release Process
ComfyUI follows a weekly release cycle every Friday, with three interconnected repositories:
ComfyUI follows a weekly release cycle targeting Friday but this regularly changes because of model releases or large changes to the codebase. There are three interconnected repositories:
1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)**
- Releases a new stable version (e.g., v0.7.0)
@@ -174,10 +179,6 @@ If you have trouble extracting it, right click the file -> properties -> unblock
See the [Config file](extra_model_paths.yaml.example) to set the search paths for models. In the standalone windows build you can find this file in the ComfyUI directory. Rename this file to extra_model_paths.yaml and edit it with your favorite text editor.
## Jupyter Notebook
To run it on services like paperspace, kaggle or colab you can use my [Jupyter Notebook](notebooks/comfyui_colab.ipynb)
## [comfy-cli](https://docs.comfy.org/comfy-cli/getting-started)
@@ -239,7 +240,7 @@ Nvidia users should install stable pytorch using this command:
This is the command to install pytorch nightly instead which might have performance improvements.
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128```
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu129```
#### Troubleshooting
@@ -272,6 +273,8 @@ You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS ve
#### DirectML (AMD Cards on Windows)
This is very badly supported and is not recommended. There are some unofficial builds of pytorch ROCm on windows that exist that will give you a much better experience than this. This readme will be updated once official pytorch ROCm builds for windows come out.
```pip install torch-directml``` Then you can launch ComfyUI with: ```python main.py --directml```
#### Ascend NPUs
@@ -291,6 +294,13 @@ For models compatible with Cambricon Extension for PyTorch (torch_mlu). Here's a
2. Next, install the PyTorch(torch_mlu) following the instructions on the [Installation](https://www.cambricon.com/docs/sdk_1.15.0/cambricon_pytorch_1.17.0/user_guide_1.9/index.html)
3. Launch ComfyUI by running `python main.py`
#### Iluvatar Corex
For models compatible with Iluvatar Extension for PyTorch. Here's a step-by-step guide tailored to your platform and installation method:
1. Install the Iluvatar Corex Toolkit by adhering to the platform-specific instructions on the [Installation](https://support.iluvatar.com/#/DocumentCentre?id=1&nameCenter=2&productId=520117912052801536)
2. Launch ComfyUI by running `python main.py`
# Running
```python main.py```

View File

@@ -29,18 +29,48 @@ def frontend_install_warning_message():
This error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.
""".strip()
def parse_version(version: str) -> tuple[int, int, int]:
return tuple(map(int, version.split(".")))
def is_valid_version(version: str) -> bool:
"""Validate if a string is a valid semantic version (X.Y.Z format)."""
pattern = r"^(\d+)\.(\d+)\.(\d+)$"
return bool(re.match(pattern, version))
def get_installed_frontend_version():
"""Get the currently installed frontend package version."""
frontend_version_str = version("comfyui-frontend-package")
return frontend_version_str
def get_required_frontend_version():
"""Get the required frontend version from requirements.txt."""
try:
with open(requirements_path, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if line.startswith("comfyui-frontend-package=="):
version_str = line.split("==")[-1]
if not is_valid_version(version_str):
logging.error(f"Invalid version format in requirements.txt: {version_str}")
return None
return version_str
logging.error("comfyui-frontend-package not found in requirements.txt")
return None
except FileNotFoundError:
logging.error("requirements.txt not found. Cannot determine required frontend version.")
return None
except Exception as e:
logging.error(f"Error reading requirements.txt: {e}")
return None
def check_frontend_version():
"""Check if the frontend version is up to date."""
def parse_version(version: str) -> tuple[int, int, int]:
return tuple(map(int, version.split(".")))
try:
frontend_version_str = version("comfyui-frontend-package")
frontend_version_str = get_installed_frontend_version()
frontend_version = parse_version(frontend_version_str)
with open(requirements_path, "r", encoding="utf-8") as f:
required_frontend = parse_version(f.readline().split("=")[-1])
required_frontend_str = get_required_frontend_version()
required_frontend = parse_version(required_frontend_str)
if frontend_version < required_frontend:
app.logger.log_startup_warning(
f"""
@@ -168,6 +198,11 @@ def download_release_asset_zip(release: Release, destination_path: str) -> None:
class FrontendManager:
CUSTOM_FRONTENDS_ROOT = str(Path(__file__).parents[1] / "web_custom_versions")
@classmethod
def get_required_frontend_version(cls) -> str:
"""Get the required frontend package version."""
return get_required_frontend_version()
@classmethod
def default_frontend_path(cls) -> str:
try:

View File

@@ -130,10 +130,21 @@ class ModelFileManager:
for file_name in filenames:
try:
relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory)
result.append(relative_path)
except:
logging.warning(f"Warning: Unable to access {file_name}. Skipping this file.")
full_path = os.path.join(dirpath, file_name)
relative_path = os.path.relpath(full_path, directory)
# Get file metadata
file_info = {
"name": relative_path,
"pathIndex": pathIndex,
"modified": os.path.getmtime(full_path), # Add modification time
"created": os.path.getctime(full_path), # Add creation time
"size": os.path.getsize(full_path) # Add file size
}
result.append(file_info)
except Exception as e:
logging.warning(f"Warning: Unable to access {file_name}. Error: {e}. Skipping this file.")
continue
for d in subdirs:
@@ -144,7 +155,7 @@ class ModelFileManager:
logging.warning(f"Warning: Unable to access {path}. Skipping this path.")
continue
return [{"name": f, "pathIndex": pathIndex} for f in result], dirs, time.perf_counter()
return result, dirs, time.perf_counter()
def get_model_previews(self, filepath: str) -> list[str | BytesIO]:
dirname = os.path.dirname(filepath)

View File

@@ -20,13 +20,15 @@ class FileInfo(TypedDict):
path: str
size: int
modified: int
created: int
def get_file_info(path: str, relative_to: str) -> FileInfo:
return {
"path": os.path.relpath(path, relative_to).replace(os.sep, '/'),
"size": os.path.getsize(path),
"modified": os.path.getmtime(path)
"modified": os.path.getmtime(path),
"created": os.path.getctime(path)
}

View File

@@ -49,7 +49,8 @@ parser.add_argument("--temp-directory", type=str, default=None, help="Set the Co
parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory. Overrides --base-directory.")
parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use. All other devices will not be visible.")
parser.add_argument("--default-device", type=int, default=None, metavar="DEFAULT_DEVICE_ID", help="Set the id of the default device, all other devices will stay visible.")
cm_group = parser.add_mutually_exclusive_group()
cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Disable cudaMallocAsync.")
@@ -144,6 +145,7 @@ class PerformanceFeature(enum.Enum):
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult cublas_ops")
parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.")
parser.add_argument("--disable-mmap", action="store_true", help="Don't use mmap when loading safetensors.")
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
@@ -151,6 +153,7 @@ parser.add_argument("--windows-standalone-build", action="store_true", help="Win
parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.")
parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Disable loading all custom nodes.")
parser.add_argument("--whitelist-custom-nodes", type=str, nargs='+', default=[], help="Specify custom node folders to load even when --disable-all-custom-nodes is enabled.")
parser.add_argument("--disable-api-nodes", action="store_true", help="Disable loading all api nodes.")
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")

View File

@@ -1,6 +1,7 @@
import torch
import math
import comfy.utils
import logging
class CONDRegular:
@@ -10,12 +11,15 @@ class CONDRegular:
def _copy_with(self, cond):
return self.__class__(cond)
def process_cond(self, batch_size, device, **kwargs):
return self._copy_with(comfy.utils.repeat_to_batch_size(self.cond, batch_size).to(device))
def process_cond(self, batch_size, **kwargs):
return self._copy_with(comfy.utils.repeat_to_batch_size(self.cond, batch_size))
def can_concat(self, other):
if self.cond.shape != other.cond.shape:
return False
if self.cond.device != other.cond.device:
logging.warning("WARNING: conds not on same device, skipping concat.")
return False
return True
def concat(self, others):
@@ -29,14 +33,14 @@ class CONDRegular:
class CONDNoiseShape(CONDRegular):
def process_cond(self, batch_size, device, area, **kwargs):
def process_cond(self, batch_size, area, **kwargs):
data = self.cond
if area is not None:
dims = len(area) // 2
for i in range(dims):
data = data.narrow(i + 2, area[i + dims], area[i])
return self._copy_with(comfy.utils.repeat_to_batch_size(data, batch_size).to(device))
return self._copy_with(comfy.utils.repeat_to_batch_size(data, batch_size))
class CONDCrossAttn(CONDRegular):
@@ -51,6 +55,9 @@ class CONDCrossAttn(CONDRegular):
diff = mult_min // min(s1[1], s2[1])
if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
return False
if self.cond.device != other.cond.device:
logging.warning("WARNING: conds not on same device: skipping concat.")
return False
return True
def concat(self, others):
@@ -73,7 +80,7 @@ class CONDConstant(CONDRegular):
def __init__(self, cond):
self.cond = cond
def process_cond(self, batch_size, device, **kwargs):
def process_cond(self, batch_size, **kwargs):
return self._copy_with(self.cond)
def can_concat(self, other):
@@ -92,10 +99,10 @@ class CONDList(CONDRegular):
def __init__(self, cond):
self.cond = cond
def process_cond(self, batch_size, device, **kwargs):
def process_cond(self, batch_size, **kwargs):
out = []
for c in self.cond:
out.append(comfy.utils.repeat_to_batch_size(c, batch_size).to(device))
out.append(comfy.utils.repeat_to_batch_size(c, batch_size))
return self._copy_with(out)

View File

@@ -28,6 +28,7 @@ import comfy.model_detection
import comfy.model_patcher
import comfy.ops
import comfy.latent_formats
import comfy.model_base
import comfy.cldm.cldm
import comfy.t2i_adapter.adapter
@@ -43,7 +44,6 @@ if TYPE_CHECKING:
def broadcast_image_to(tensor, target_batch_size, batched_number):
current_batch_size = tensor.shape[0]
#print(current_batch_size, target_batch_size)
if current_batch_size == 1:
return tensor
@@ -265,12 +265,12 @@ class ControlNet(ControlBase):
for c in self.extra_conds:
temp = cond.get(c, None)
if temp is not None:
extra[c] = temp.to(dtype)
extra[c] = comfy.model_base.convert_tensor(temp, dtype, x_noisy.device)
timestep = self.model_sampling_current.timestep(t)
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=context.to(dtype), **extra)
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=comfy.model_management.cast_to_device(context, x_noisy.device, dtype), **extra)
return self.control_merge(control, control_prev, output_dtype=None)
def copy(self):

View File

@@ -1,55 +1,10 @@
import math
import torch
from torch import nn
from .ldm.modules.attention import CrossAttention
from inspect import isfunction
from .ldm.modules.attention import CrossAttention, FeedForward
import comfy.ops
ops = comfy.ops.manual_cast
def exists(val):
return val is not None
def uniq(arr):
return{el: True for el in arr}.keys()
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
# feedforward
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = ops.Linear(dim_in, dim_out * 2)
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
return x * torch.nn.functional.gelu(gate)
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = nn.Sequential(
ops.Linear(dim, inner_dim),
nn.GELU()
) if not glu else GEGLU(dim, inner_dim)
self.net = nn.Sequential(
project_in,
nn.Dropout(dropout),
ops.Linear(inner_dim, dim_out)
)
def forward(self, x):
return self.net(x)
class GatedCrossAttentionDense(nn.Module):
def __init__(self, query_dim, context_dim, n_heads, d_head):

View File

@@ -0,0 +1,121 @@
# SA-Solver: Stochastic Adams Solver (NeurIPS 2023, arXiv:2309.05019)
# Conference: https://proceedings.neurips.cc/paper_files/paper/2023/file/f4a6806490d31216a3ba667eb240c897-Paper-Conference.pdf
# Codebase ref: https://github.com/scxue/SA-Solver
import math
from typing import Union, Callable
import torch
def compute_exponential_coeffs(s: torch.Tensor, t: torch.Tensor, solver_order: int, tau_t: float) -> torch.Tensor:
"""Compute (1 + tau^2) * integral of exp((1 + tau^2) * x) * x^p dx from s to t with exp((1 + tau^2) * t) factored out, using integration by parts.
Integral of exp((1 + tau^2) * x) * x^p dx
= product_terms[p] - (p / (1 + tau^2)) * integral of exp((1 + tau^2) * x) * x^(p-1) dx,
with base case p=0 where integral equals product_terms[0].
where
product_terms[p] = x^p * exp((1 + tau^2) * x) / (1 + tau^2).
Construct a recursive coefficient matrix following the above recursive relation to compute all integral terms up to p = (solver_order - 1).
Return coefficients used by the SA-Solver in data prediction mode.
Args:
s: Start time s.
t: End time t.
solver_order: Current order of the solver.
tau_t: Stochastic strength parameter in the SDE.
Returns:
Exponential coefficients used in data prediction, with exp((1 + tau^2) * t) factored out, ordered from p=0 to p=solver_order1, shape (solver_order,).
"""
tau_mul = 1 + tau_t ** 2
h = t - s
p = torch.arange(solver_order, dtype=s.dtype, device=s.device)
# product_terms after factoring out exp((1 + tau^2) * t)
# Includes (1 + tau^2) factor from outside the integral
product_terms_factored = (t ** p - s ** p * (-tau_mul * h).exp())
# Lower triangular recursive coefficient matrix
# Accumulates recursive coefficients based on p / (1 + tau^2)
recursive_depth_mat = p.unsqueeze(1) - p.unsqueeze(0)
log_factorial = (p + 1).lgamma()
recursive_coeff_mat = log_factorial.unsqueeze(1) - log_factorial.unsqueeze(0)
if tau_t > 0:
recursive_coeff_mat = recursive_coeff_mat - (recursive_depth_mat * math.log(tau_mul))
signs = torch.where(recursive_depth_mat % 2 == 0, 1.0, -1.0)
recursive_coeff_mat = (recursive_coeff_mat.exp() * signs).tril()
return recursive_coeff_mat @ product_terms_factored
def compute_simple_stochastic_adams_b_coeffs(sigma_next: torch.Tensor, curr_lambdas: torch.Tensor, lambda_s: torch.Tensor, lambda_t: torch.Tensor, tau_t: float, is_corrector_step: bool = False) -> torch.Tensor:
"""Compute simple order-2 b coefficients from SA-Solver paper (Appendix D. Implementation Details)."""
tau_mul = 1 + tau_t ** 2
h = lambda_t - lambda_s
alpha_t = sigma_next * lambda_t.exp()
if is_corrector_step:
# Simplified 1-step (order-2) corrector
b_1 = alpha_t * (0.5 * tau_mul * h)
b_2 = alpha_t * (-h * tau_mul).expm1().neg() - b_1
else:
# Simplified 2-step predictor
b_2 = alpha_t * (0.5 * tau_mul * h ** 2) / (curr_lambdas[-2] - lambda_s)
b_1 = alpha_t * (-h * tau_mul).expm1().neg() - b_2
return torch.stack([b_2, b_1])
def compute_stochastic_adams_b_coeffs(sigma_next: torch.Tensor, curr_lambdas: torch.Tensor, lambda_s: torch.Tensor, lambda_t: torch.Tensor, tau_t: float, simple_order_2: bool = False, is_corrector_step: bool = False) -> torch.Tensor:
"""Compute b_i coefficients for the SA-Solver (see eqs. 15 and 18).
The solver order corresponds to the number of input lambdas (half-logSNR points).
Args:
sigma_next: Sigma at end time t.
curr_lambdas: Lambda time points used to construct the Lagrange basis, shape (N,).
lambda_s: Lambda at start time s.
lambda_t: Lambda at end time t.
tau_t: Stochastic strength parameter in the SDE.
simple_order_2: Whether to enable the simple order-2 scheme.
is_corrector_step: Flag for corrector step in simple order-2 mode.
Returns:
b_i coefficients for the SA-Solver, shape (N,), where N is the solver order.
"""
num_timesteps = curr_lambdas.shape[0]
if simple_order_2 and num_timesteps == 2:
return compute_simple_stochastic_adams_b_coeffs(sigma_next, curr_lambdas, lambda_s, lambda_t, tau_t, is_corrector_step)
# Compute coefficients by solving a linear system from Lagrange basis interpolation
exp_integral_coeffs = compute_exponential_coeffs(lambda_s, lambda_t, num_timesteps, tau_t)
vandermonde_matrix_T = torch.vander(curr_lambdas, num_timesteps, increasing=True).T
lagrange_integrals = torch.linalg.solve(vandermonde_matrix_T, exp_integral_coeffs)
# (sigma_t * exp(-tau^2 * lambda_t)) * exp((1 + tau^2) * lambda_t)
# = sigma_t * exp(lambda_t) = alpha_t
# exp((1 + tau^2) * lambda_t) is extracted from the integral
alpha_t = sigma_next * lambda_t.exp()
return alpha_t * lagrange_integrals
def get_tau_interval_func(start_sigma: float, end_sigma: float, eta: float = 1.0) -> Callable[[Union[torch.Tensor, float]], float]:
"""Return a function that controls the stochasticity of SA-Solver.
When eta = 0, SA-Solver runs as ODE. The official approach uses
time t to determine the SDE interval, while here we use sigma instead.
See:
https://github.com/scxue/SA-Solver/blob/main/README.md
"""
def tau_func(sigma: Union[torch.Tensor, float]) -> float:
if eta <= 0:
return 0.0 # ODE
if isinstance(sigma, torch.Tensor):
sigma = sigma.item()
return eta if start_sigma >= sigma >= end_sigma else 0.0
return tau_func

View File

@@ -1,4 +1,5 @@
import math
from functools import partial
from scipy import integrate
import torch
@@ -8,6 +9,7 @@ from tqdm.auto import trange, tqdm
from . import utils
from . import deis
from . import sa_solver
import comfy.model_patcher
import comfy.model_sampling
@@ -142,6 +144,33 @@ class BrownianTreeNoiseSampler:
return self.tree(t0, t1) / (t1 - t0).abs().sqrt()
def sigma_to_half_log_snr(sigma, model_sampling):
"""Convert sigma to half-logSNR log(alpha_t / sigma_t)."""
if isinstance(model_sampling, comfy.model_sampling.CONST):
# log((1 - t) / t) = log((1 - sigma) / sigma)
return sigma.logit().neg()
return sigma.log().neg()
def half_log_snr_to_sigma(half_log_snr, model_sampling):
"""Convert half-logSNR log(alpha_t / sigma_t) to sigma."""
if isinstance(model_sampling, comfy.model_sampling.CONST):
# 1 / (1 + exp(half_log_snr))
return half_log_snr.neg().sigmoid()
return half_log_snr.neg().exp()
def offset_first_sigma_for_snr(sigmas, model_sampling, percent_offset=1e-4):
"""Adjust the first sigma to avoid invalid logSNR."""
if len(sigmas) <= 1:
return sigmas
if isinstance(model_sampling, comfy.model_sampling.CONST):
if sigmas[0] >= 1:
sigmas = sigmas.clone()
sigmas[0] = model_sampling.percent_to_sigma(percent_offset)
return sigmas
@torch.no_grad()
def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
@@ -384,9 +413,13 @@ def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, o
ds.pop(0)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
cur_order = min(i + 1, order)
coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)]
x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
if sigmas[i + 1] == 0:
# Denoising step
x = denoised
else:
cur_order = min(i + 1, order)
coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)]
x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
return x
@@ -682,6 +715,7 @@ def sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args=None, callback=Non
# logged_x = torch.cat((logged_x, x.unsqueeze(0)), dim=0)
return x
@torch.no_grad()
def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
"""DPM-Solver++ (stochastic)."""
@@ -693,38 +727,49 @@ def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=N
seed = extra_args.get("seed", None)
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
sigma_fn = lambda t: t.neg().exp()
t_fn = lambda sigma: sigma.log().neg()
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling)
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if sigmas[i + 1] == 0:
# Euler method
d = to_d(x, sigmas[i], denoised)
dt = sigmas[i + 1] - sigmas[i]
x = x + d * dt
# Denoising step
x = denoised
else:
# DPM-Solver++
t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
h = t_next - t
s = t + h * r
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
h = lambda_t - lambda_s
lambda_s_1 = lambda_s + r * h
fac = 1 / (2 * r)
sigma_s_1 = sigma_fn(lambda_s_1)
alpha_s = sigmas[i] * lambda_s.exp()
alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
alpha_t = sigmas[i + 1] * lambda_t.exp()
# Step 1
sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(s), eta)
s_ = t_fn(sd)
x_2 = (sigma_fn(s_) / sigma_fn(t)) * x - (t - s_).expm1() * denoised
x_2 = x_2 + noise_sampler(sigma_fn(t), sigma_fn(s)) * s_noise * su
denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
sd, su = get_ancestral_step(lambda_s.neg().exp(), lambda_s_1.neg().exp(), eta)
lambda_s_1_ = sd.log().neg()
h_ = lambda_s_1_ - lambda_s
x_2 = (alpha_s_1 / alpha_s) * (-h_).exp() * x - alpha_s_1 * (-h_).expm1() * denoised
if eta > 0 and s_noise > 0:
x_2 = x_2 + alpha_s_1 * noise_sampler(sigmas[i], sigma_s_1) * s_noise * su
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
# Step 2
sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(t_next), eta)
t_next_ = t_fn(sd)
sd, su = get_ancestral_step(lambda_s.neg().exp(), lambda_t.neg().exp(), eta)
lambda_t_ = sd.log().neg()
h_ = lambda_t_ - lambda_s
denoised_d = (1 - fac) * denoised + fac * denoised_2
x = (sigma_fn(t_next_) / sigma_fn(t)) * x - (t - t_next_).expm1() * denoised_d
x = x + noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * su
x = (alpha_t / alpha_s) * (-h_).exp() * x - alpha_t * (-h_).expm1() * denoised_d
if eta > 0 and s_noise > 0:
x = x + alpha_t * noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * su
return x
@@ -753,6 +798,7 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No
old_denoised = denoised
return x
@torch.no_grad()
def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
"""DPM-Solver++(2M) SDE."""
@@ -768,9 +814,12 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
old_denoised = None
h_last = None
h = None
h, h_last = None, None
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
@@ -781,26 +830,29 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
x = denoised
else:
# DPM-Solver++(2M) SDE
t, s = -sigmas[i].log(), -sigmas[i + 1].log()
h = s - t
eta_h = eta * h
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
h = lambda_t - lambda_s
h_eta = h * (eta + 1)
x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + (-h - eta_h).expm1().neg() * denoised
alpha_t = sigmas[i + 1] * lambda_t.exp()
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x + alpha_t * (-h_eta).expm1().neg() * denoised
if old_denoised is not None:
r = h_last / h
if solver_type == 'heun':
x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * (1 / r) * (denoised - old_denoised)
x = x + alpha_t * ((-h_eta).expm1().neg() / (-h_eta) + 1) * (1 / r) * (denoised - old_denoised)
elif solver_type == 'midpoint':
x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised)
x = x + 0.5 * alpha_t * (-h_eta).expm1().neg() * (1 / r) * (denoised - old_denoised)
if eta:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise
if eta > 0 and s_noise > 0:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise
old_denoised = denoised
h_last = h
return x
@torch.no_grad()
def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""DPM-Solver++(3M) SDE."""
@@ -814,6 +866,10 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
denoised_1, denoised_2 = None, None
h, h_1, h_2 = None, None, None
@@ -825,13 +881,16 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
# Denoising step
x = denoised
else:
t, s = -sigmas[i].log(), -sigmas[i + 1].log()
h = s - t
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
h = lambda_t - lambda_s
h_eta = h * (eta + 1)
x = torch.exp(-h_eta) * x + (-h_eta).expm1().neg() * denoised
alpha_t = sigmas[i + 1] * lambda_t.exp()
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x + alpha_t * (-h_eta).expm1().neg() * denoised
if h_2 is not None:
# DPM-Solver++(3M) SDE
r0 = h_1 / h
r1 = h_2 / h
d1_0 = (denoised - denoised_1) / r0
@@ -840,20 +899,22 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
d2 = (d1_0 - d1_1) / (r0 + r1)
phi_2 = h_eta.neg().expm1() / h_eta + 1
phi_3 = phi_2 / h_eta - 0.5
x = x + phi_2 * d1 - phi_3 * d2
x = x + (alpha_t * phi_2) * d1 - (alpha_t * phi_3) * d2
elif h_1 is not None:
# DPM-Solver++(2M) SDE
r = h_1 / h
d = (denoised - denoised_1) / r
phi_2 = h_eta.neg().expm1() / h_eta + 1
x = x + phi_2 * d
x = x + (alpha_t * phi_2) * d
if eta:
if eta > 0 and s_noise > 0:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise
denoised_1, denoised_2 = denoised, denoised_1
h_1, h_2 = h, h_1
return x
@torch.no_grad()
def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
if len(sigmas) <= 1:
@@ -863,6 +924,7 @@ def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler)
@torch.no_grad()
def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
if len(sigmas) <= 1:
@@ -872,6 +934,7 @@ def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
@torch.no_grad()
def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
if len(sigmas) <= 1:
@@ -1009,7 +1072,9 @@ def sample_ipndm(model, x, sigmas, extra_args=None, callback=None, disable=None,
d_cur = (x_cur - denoised) / t_cur
order = min(max_order, i+1)
if order == 1: # First Euler step.
if t_next == 0: # Denoising step
x_next = denoised
elif order == 1: # First Euler step.
x_next = x_cur + (t_next - t_cur) * d_cur
elif order == 2: # Use one history point.
x_next = x_cur + (t_next - t_cur) * (3 * d_cur - buffer_model[-1]) / 2
@@ -1027,6 +1092,7 @@ def sample_ipndm(model, x, sigmas, extra_args=None, callback=None, disable=None,
return x_next
#From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
#under Apache 2 license
def sample_ipndm_v(model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=4):
@@ -1050,7 +1116,9 @@ def sample_ipndm_v(model, x, sigmas, extra_args=None, callback=None, disable=Non
d_cur = (x_cur - denoised) / t_cur
order = min(max_order, i+1)
if order == 1: # First Euler step.
if t_next == 0: # Denoising step
x_next = denoised
elif order == 1: # First Euler step.
x_next = x_cur + (t_next - t_cur) * d_cur
elif order == 2: # Use one history point.
h_n = (t_next - t_cur)
@@ -1090,6 +1158,7 @@ def sample_ipndm_v(model, x, sigmas, extra_args=None, callback=None, disable=Non
return x_next
#From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
#under Apache 2 license
@torch.no_grad()
@@ -1140,39 +1209,22 @@ def sample_deis(model, x, sigmas, extra_args=None, callback=None, disable=None,
return x_next
@torch.no_grad()
def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
extra_args = {} if extra_args is None else extra_args
temp = [0]
def post_cfg_function(args):
temp[0] = args["uncond_denoised"]
return args["denoised"]
model_options = extra_args.get("model_options", {}).copy()
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
sigma_hat = sigmas[i]
denoised = model(x, sigma_hat * s_in, **extra_args)
d = to_d(x, sigma_hat, temp[0])
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
# Euler method
x = denoised + d * sigmas[i + 1]
return x
@torch.no_grad()
def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""Ancestral sampling with Euler method steps."""
"""Ancestral sampling with Euler method steps (CFG++)."""
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
temp = [0]
model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling")
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
uncond_denoised = None
def post_cfg_function(args):
temp[0] = args["uncond_denoised"]
nonlocal uncond_denoised
uncond_denoised = args["uncond_denoised"]
return args["denoised"]
model_options = extra_args.get("model_options", {}).copy()
@@ -1181,15 +1233,33 @@ def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=No
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
d = to_d(x, sigmas[i], temp[0])
# Euler method
x = denoised + d * sigma_down
if sigmas[i + 1] > 0:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
if sigmas[i + 1] == 0:
# Denoising step
x = denoised
else:
alpha_s = sigmas[i] * lambda_fn(sigmas[i]).exp()
alpha_t = sigmas[i + 1] * lambda_fn(sigmas[i + 1]).exp()
d = to_d(x, sigmas[i], alpha_s * uncond_denoised) # to noise
# DDIM stochastic sampling
sigma_down, sigma_up = get_ancestral_step(sigmas[i] / alpha_s, sigmas[i + 1] / alpha_t, eta=eta)
sigma_down = alpha_t * sigma_down
# Euler method
x = alpha_t * denoised + sigma_down * d
if eta > 0 and s_noise > 0:
x = x + alpha_t * noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
return x
@torch.no_grad()
def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
"""Euler method steps (CFG++)."""
return sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=0.0, s_noise=0.0, noise_sampler=None)
@torch.no_grad()
def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
@@ -1346,6 +1416,7 @@ def sample_res_multistep_ancestral(model, x, sigmas, extra_args=None, callback=N
def sample_res_multistep_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=eta, cfg_pp=True)
@torch.no_grad()
def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2., cfg_pp=False):
"""Gradient-estimation sampler. Paper: https://openreview.net/pdf?id=o2ND9v0CeK"""
@@ -1372,31 +1443,32 @@ def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None,
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
dt = sigmas[i + 1] - sigmas[i]
if i == 0:
if sigmas[i + 1] == 0:
# Denoising step
x = denoised
else:
# Euler method
if cfg_pp:
x = denoised + d * sigmas[i + 1]
else:
x = x + d * dt
else:
# Gradient estimation
if cfg_pp:
if i >= 1:
# Gradient estimation
d_bar = (ge_gamma - 1) * (d - old_d)
x = denoised + d * sigmas[i + 1] + d_bar * dt
else:
d_bar = ge_gamma * d + (1 - ge_gamma) * old_d
x = x + d_bar * dt
old_d = d
return x
@torch.no_grad()
def sample_gradient_estimation_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2.):
return sample_gradient_estimation(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, ge_gamma=ge_gamma, cfg_pp=True)
@torch.no_grad()
def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None, noise_scaler=None, max_stage=3):
"""
Extended Reverse-Time SDE solver (VE ER-SDE-Solver-3). Arxiv: https://arxiv.org/abs/2309.06169.
def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1.0, noise_sampler=None, noise_scaler=None, max_stage=3):
"""Extended Reverse-Time SDE solver (VP ER-SDE-Solver-3). arXiv: https://arxiv.org/abs/2309.06169.
Code reference: https://github.com/QinpengCui/ER-SDE-Solver/blob/main/er_sde_solver.py.
"""
extra_args = {} if extra_args is None else extra_args
@@ -1404,12 +1476,18 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
def default_noise_scaler(sigma):
return sigma * ((sigma ** 0.3).exp() + 10.0)
noise_scaler = default_noise_scaler if noise_scaler is None else noise_scaler
def default_er_sde_noise_scaler(x):
return x * ((x ** 0.3).exp() + 10.0)
noise_scaler = default_er_sde_noise_scaler if noise_scaler is None else noise_scaler
num_integration_points = 200.0
point_indice = torch.arange(0, num_integration_points, dtype=torch.float32, device=x.device)
model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling")
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
half_log_snrs = sigma_to_half_log_snr(sigmas, model_sampling)
er_lambdas = half_log_snrs.neg().exp() # er_lambda_t = sigma_t / alpha_t
old_denoised = None
old_denoised_d = None
@@ -1420,41 +1498,45 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
stage_used = min(max_stage, i + 1)
if sigmas[i + 1] == 0:
x = denoised
elif stage_used == 1:
r = noise_scaler(sigmas[i + 1]) / noise_scaler(sigmas[i])
x = r * x + (1 - r) * denoised
else:
r = noise_scaler(sigmas[i + 1]) / noise_scaler(sigmas[i])
x = r * x + (1 - r) * denoised
er_lambda_s, er_lambda_t = er_lambdas[i], er_lambdas[i + 1]
alpha_s = sigmas[i] / er_lambda_s
alpha_t = sigmas[i + 1] / er_lambda_t
r_alpha = alpha_t / alpha_s
r = noise_scaler(er_lambda_t) / noise_scaler(er_lambda_s)
dt = sigmas[i + 1] - sigmas[i]
sigma_step_size = -dt / num_integration_points
sigma_pos = sigmas[i + 1] + point_indice * sigma_step_size
scaled_pos = noise_scaler(sigma_pos)
# Stage 1 Euler
x = r_alpha * r * x + alpha_t * (1 - r) * denoised
# Stage 2
s = torch.sum(1 / scaled_pos) * sigma_step_size
denoised_d = (denoised - old_denoised) / (sigmas[i] - sigmas[i - 1])
x = x + (dt + s * noise_scaler(sigmas[i + 1])) * denoised_d
if stage_used >= 2:
dt = er_lambda_t - er_lambda_s
lambda_step_size = -dt / num_integration_points
lambda_pos = er_lambda_t + point_indice * lambda_step_size
scaled_pos = noise_scaler(lambda_pos)
if stage_used >= 3:
# Stage 3
s_u = torch.sum((sigma_pos - sigmas[i]) / scaled_pos) * sigma_step_size
denoised_u = (denoised_d - old_denoised_d) / ((sigmas[i] - sigmas[i - 2]) / 2)
x = x + ((dt ** 2) / 2 + s_u * noise_scaler(sigmas[i + 1])) * denoised_u
old_denoised_d = denoised_d
# Stage 2
s = torch.sum(1 / scaled_pos) * lambda_step_size
denoised_d = (denoised - old_denoised) / (er_lambda_s - er_lambdas[i - 1])
x = x + alpha_t * (dt + s * noise_scaler(er_lambda_t)) * denoised_d
if s_noise != 0 and sigmas[i + 1] > 0:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * (sigmas[i + 1] ** 2 - sigmas[i] ** 2 * r ** 2).sqrt().nan_to_num(nan=0.0)
if stage_used >= 3:
# Stage 3
s_u = torch.sum((lambda_pos - er_lambda_s) / scaled_pos) * lambda_step_size
denoised_u = (denoised_d - old_denoised_d) / ((er_lambda_s - er_lambdas[i - 2]) / 2)
x = x + alpha_t * ((dt ** 2) / 2 + s_u * noise_scaler(er_lambda_t)) * denoised_u
old_denoised_d = denoised_d
if s_noise > 0:
x = x + alpha_t * noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * (er_lambda_t ** 2 - er_lambda_s ** 2 * r ** 2).sqrt().nan_to_num(nan=0.0)
old_denoised = denoised
return x
@torch.no_grad()
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5):
'''
SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VE Data Prediction) stage 2
Arxiv: https://arxiv.org/abs/2305.14267
'''
"""SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2.
arXiv: https://arxiv.org/abs/2305.14267
"""
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
@@ -1462,6 +1544,11 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
inject_noise = eta > 0 and s_noise > 0
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling)
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
@@ -1469,80 +1556,206 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
if sigmas[i + 1] == 0:
x = denoised
else:
t, t_next = -sigmas[i].log(), -sigmas[i + 1].log()
h = t_next - t
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
h = lambda_t - lambda_s
h_eta = h * (eta + 1)
s = t + r * h
lambda_s_1 = lambda_s + r * h
fac = 1 / (2 * r)
sigma_s = s.neg().exp()
sigma_s_1 = sigma_fn(lambda_s_1)
# alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t)
alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
alpha_t = sigmas[i + 1] * lambda_t.exp()
coeff_1, coeff_2 = (-r * h_eta).expm1(), (-h_eta).expm1()
if inject_noise:
# 0 < r < 1
noise_coeff_1 = (-2 * r * h * eta).expm1().neg().sqrt()
noise_coeff_2 = ((-2 * r * h * eta).expm1() - (-2 * h * eta).expm1()).sqrt()
noise_1, noise_2 = noise_sampler(sigmas[i], sigma_s), noise_sampler(sigma_s, sigmas[i + 1])
noise_coeff_2 = (-r * h * eta).exp() * (-2 * (1 - r) * h * eta).expm1().neg().sqrt()
noise_1, noise_2 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigmas[i + 1])
# Step 1
x_2 = (coeff_1 + 1) * x - coeff_1 * denoised
if inject_noise:
x_2 = x_2 + sigma_s * (noise_coeff_1 * noise_1) * s_noise
denoised_2 = model(x_2, sigma_s * s_in, **extra_args)
# Step 2
denoised_d = (1 - fac) * denoised + fac * denoised_2
x = (coeff_2 + 1) * x - coeff_2 * denoised_d
if inject_noise:
x = x + sigmas[i + 1] * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
return x
@torch.no_grad()
def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3):
'''
SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VE Data Prediction) stage 3
Arxiv: https://arxiv.org/abs/2305.14267
'''
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
inject_noise = eta > 0 and s_noise > 0
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if sigmas[i + 1] == 0:
x = denoised
else:
t, t_next = -sigmas[i].log(), -sigmas[i + 1].log()
h = t_next - t
h_eta = h * (eta + 1)
s_1 = t + r_1 * h
s_2 = t + r_2 * h
sigma_s_1, sigma_s_2 = s_1.neg().exp(), s_2.neg().exp()
coeff_1, coeff_2, coeff_3 = (-r_1 * h_eta).expm1(), (-r_2 * h_eta).expm1(), (-h_eta).expm1()
if inject_noise:
noise_coeff_1 = (-2 * r_1 * h * eta).expm1().neg().sqrt()
noise_coeff_2 = ((-2 * r_1 * h * eta).expm1() - (-2 * r_2 * h * eta).expm1()).sqrt()
noise_coeff_3 = ((-2 * r_2 * h * eta).expm1() - (-2 * h * eta).expm1()).sqrt()
noise_1, noise_2, noise_3 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigma_s_2), noise_sampler(sigma_s_2, sigmas[i + 1])
# Step 1
x_2 = (coeff_1 + 1) * x - coeff_1 * denoised
x_2 = sigma_s_1 / sigmas[i] * (-r * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised
if inject_noise:
x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
# Step 2
x_3 = (coeff_2 + 1) * x - coeff_2 * denoised + (r_2 / r_1) * (coeff_2 / (r_2 * h_eta) + 1) * (denoised_2 - denoised)
denoised_d = (1 - fac) * denoised + fac * denoised_2
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_2 * denoised_d
if inject_noise:
x = x + sigmas[i + 1] * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
return x
@torch.no_grad()
def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3):
"""SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 3.
arXiv: https://arxiv.org/abs/2305.14267
"""
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
inject_noise = eta > 0 and s_noise > 0
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling)
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if sigmas[i + 1] == 0:
x = denoised
else:
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
h = lambda_t - lambda_s
h_eta = h * (eta + 1)
lambda_s_1 = lambda_s + r_1 * h
lambda_s_2 = lambda_s + r_2 * h
sigma_s_1, sigma_s_2 = sigma_fn(lambda_s_1), sigma_fn(lambda_s_2)
# alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t)
alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
alpha_s_2 = sigma_s_2 * lambda_s_2.exp()
alpha_t = sigmas[i + 1] * lambda_t.exp()
coeff_1, coeff_2, coeff_3 = (-r_1 * h_eta).expm1(), (-r_2 * h_eta).expm1(), (-h_eta).expm1()
if inject_noise:
# 0 < r_1 < r_2 < 1
noise_coeff_1 = (-2 * r_1 * h * eta).expm1().neg().sqrt()
noise_coeff_2 = (-r_1 * h * eta).exp() * (-2 * (r_2 - r_1) * h * eta).expm1().neg().sqrt()
noise_coeff_3 = (-r_2 * h * eta).exp() * (-2 * (1 - r_2) * h * eta).expm1().neg().sqrt()
noise_1, noise_2, noise_3 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigma_s_2), noise_sampler(sigma_s_2, sigmas[i + 1])
# Step 1
x_2 = sigma_s_1 / sigmas[i] * (-r_1 * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised
if inject_noise:
x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
# Step 2
x_3 = sigma_s_2 / sigmas[i] * (-r_2 * h * eta).exp() * x - alpha_s_2 * coeff_2 * denoised + (r_2 / r_1) * alpha_s_2 * (coeff_2 / (r_2 * h_eta) + 1) * (denoised_2 - denoised)
if inject_noise:
x_3 = x_3 + sigma_s_2 * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args)
# Step 3
x = (coeff_3 + 1) * x - coeff_3 * denoised + (1. / r_2) * (coeff_3 / h_eta + 1) * (denoised_3 - denoised)
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_3 * denoised + (1. / r_2) * alpha_t * (coeff_3 / h_eta + 1) * (denoised_3 - denoised)
if inject_noise:
x = x + sigmas[i + 1] * (noise_coeff_3 * noise_1 + noise_coeff_2 * noise_2 + noise_coeff_1 * noise_3) * s_noise
return x
@torch.no_grad()
def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=False, tau_func=None, s_noise=1.0, noise_sampler=None, predictor_order=3, corrector_order=4, use_pece=False, simple_order_2=False):
"""Stochastic Adams Solver with predictor-corrector method (NeurIPS 2023)."""
if len(sigmas) <= 1:
return x
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling")
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
lambdas = sigma_to_half_log_snr(sigmas, model_sampling=model_sampling)
if tau_func is None:
# Use default interval for stochastic sampling
start_sigma = model_sampling.percent_to_sigma(0.2)
end_sigma = model_sampling.percent_to_sigma(0.8)
tau_func = sa_solver.get_tau_interval_func(start_sigma, end_sigma, eta=1.0)
max_used_order = max(predictor_order, corrector_order)
x_pred = x # x: current state, x_pred: predicted next state
h = 0.0
tau_t = 0.0
noise = 0.0
pred_list = []
# Lower order near the end to improve stability
lower_order_to_end = sigmas[-1].item() == 0
for i in trange(len(sigmas) - 1, disable=disable):
# Evaluation
denoised = model(x_pred, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({"x": x_pred, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "denoised": denoised})
pred_list.append(denoised)
pred_list = pred_list[-max_used_order:]
predictor_order_used = min(predictor_order, len(pred_list))
if i == 0 or (sigmas[i + 1] == 0 and not use_pece):
corrector_order_used = 0
else:
corrector_order_used = min(corrector_order, len(pred_list))
if lower_order_to_end:
predictor_order_used = min(predictor_order_used, len(sigmas) - 2 - i)
corrector_order_used = min(corrector_order_used, len(sigmas) - 1 - i)
# Corrector
if corrector_order_used == 0:
# Update by the predicted state
x = x_pred
else:
curr_lambdas = lambdas[i - corrector_order_used + 1:i + 1]
b_coeffs = sa_solver.compute_stochastic_adams_b_coeffs(
sigmas[i],
curr_lambdas,
lambdas[i - 1],
lambdas[i],
tau_t,
simple_order_2,
is_corrector_step=True,
)
pred_mat = torch.stack(pred_list[-corrector_order_used:], dim=1) # (B, K, ...)
corr_res = torch.tensordot(pred_mat, b_coeffs, dims=([1], [0])) # (B, ...)
x = sigmas[i] / sigmas[i - 1] * (-(tau_t ** 2) * h).exp() * x + corr_res
if tau_t > 0 and s_noise > 0:
# The noise from the previous predictor step
x = x + noise
if use_pece:
# Evaluate the corrected state
denoised = model(x, sigmas[i] * s_in, **extra_args)
pred_list[-1] = denoised
# Predictor
if sigmas[i + 1] == 0:
# Denoising step
x = denoised
else:
tau_t = tau_func(sigmas[i + 1])
curr_lambdas = lambdas[i - predictor_order_used + 1:i + 1]
b_coeffs = sa_solver.compute_stochastic_adams_b_coeffs(
sigmas[i + 1],
curr_lambdas,
lambdas[i],
lambdas[i + 1],
tau_t,
simple_order_2,
is_corrector_step=False,
)
pred_mat = torch.stack(pred_list[-predictor_order_used:], dim=1) # (B, K, ...)
pred_res = torch.tensordot(pred_mat, b_coeffs, dims=([1], [0])) # (B, ...)
h = lambdas[i + 1] - lambdas[i]
x_pred = sigmas[i + 1] / sigmas[i] * (-(tau_t ** 2) * h).exp() * x + pred_res
if tau_t > 0 and s_noise > 0:
noise = noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * tau_t ** 2 * h).expm1().neg().sqrt() * s_noise
x_pred = x_pred + noise
return x
@torch.no_grad()
def sample_sa_solver_pece(model, x, sigmas, extra_args=None, callback=None, disable=False, tau_func=None, s_noise=1.0, noise_sampler=None, predictor_order=3, corrector_order=4, simple_order_2=False):
"""Stochastic Adams Solver with PECE (PredictEvaluateCorrectEvaluate) mode (NeurIPS 2023)."""
return sample_sa_solver(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, tau_func=tau_func, s_noise=s_noise, noise_sampler=noise_sampler, predictor_order=predictor_order, corrector_order=corrector_order, use_pece=True, simple_order_2=simple_order_2)

View File

@@ -457,6 +457,82 @@ class Wan21(LatentFormat):
latents_std = self.latents_std.to(latent.device, latent.dtype)
return latent * latents_std / self.scale_factor + latents_mean
class Wan22(Wan21):
latent_channels = 48
latent_dimensions = 3
latent_rgb_factors = [
[ 0.0119, 0.0103, 0.0046],
[-0.1062, -0.0504, 0.0165],
[ 0.0140, 0.0409, 0.0491],
[-0.0813, -0.0677, 0.0607],
[ 0.0656, 0.0851, 0.0808],
[ 0.0264, 0.0463, 0.0912],
[ 0.0295, 0.0326, 0.0590],
[-0.0244, -0.0270, 0.0025],
[ 0.0443, -0.0102, 0.0288],
[-0.0465, -0.0090, -0.0205],
[ 0.0359, 0.0236, 0.0082],
[-0.0776, 0.0854, 0.1048],
[ 0.0564, 0.0264, 0.0561],
[ 0.0006, 0.0594, 0.0418],
[-0.0319, -0.0542, -0.0637],
[-0.0268, 0.0024, 0.0260],
[ 0.0539, 0.0265, 0.0358],
[-0.0359, -0.0312, -0.0287],
[-0.0285, -0.1032, -0.1237],
[ 0.1041, 0.0537, 0.0622],
[-0.0086, -0.0374, -0.0051],
[ 0.0390, 0.0670, 0.2863],
[ 0.0069, 0.0144, 0.0082],
[ 0.0006, -0.0167, 0.0079],
[ 0.0313, -0.0574, -0.0232],
[-0.1454, -0.0902, -0.0481],
[ 0.0714, 0.0827, 0.0447],
[-0.0304, -0.0574, -0.0196],
[ 0.0401, 0.0384, 0.0204],
[-0.0758, -0.0297, -0.0014],
[ 0.0568, 0.1307, 0.1372],
[-0.0055, -0.0310, -0.0380],
[ 0.0239, -0.0305, 0.0325],
[-0.0663, -0.0673, -0.0140],
[-0.0416, -0.0047, -0.0023],
[ 0.0166, 0.0112, -0.0093],
[-0.0211, 0.0011, 0.0331],
[ 0.1833, 0.1466, 0.2250],
[-0.0368, 0.0370, 0.0295],
[-0.3441, -0.3543, -0.2008],
[-0.0479, -0.0489, -0.0420],
[-0.0660, -0.0153, 0.0800],
[-0.0101, 0.0068, 0.0156],
[-0.0690, -0.0452, -0.0927],
[-0.0145, 0.0041, 0.0015],
[ 0.0421, 0.0451, 0.0373],
[ 0.0504, -0.0483, -0.0356],
[-0.0837, 0.0168, 0.0055]
]
latent_rgb_factors_bias = [0.0317, -0.0878, -0.1388]
def __init__(self):
self.scale_factor = 1.0
self.latents_mean = torch.tensor([
-0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557,
-0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825,
-0.2246, -0.1207, -0.0698, 0.5109, 0.2665, -0.2108, -0.2158, 0.2502,
-0.2055, -0.0322, 0.1109, 0.1567, -0.0729, 0.0899, -0.2799, -0.1230,
-0.0313, -0.1649, 0.0117, 0.0723, -0.2839, -0.2083, -0.0520, 0.3748,
0.0152, 0.1957, 0.1433, -0.2944, 0.3573, -0.0548, -0.1681, -0.0667,
]).view(1, self.latent_channels, 1, 1, 1)
self.latents_std = torch.tensor([
0.4765, 1.0364, 0.4514, 1.1677, 0.5313, 0.4990, 0.4818, 0.5013,
0.8158, 1.0344, 0.5894, 1.0901, 0.6885, 0.6165, 0.8454, 0.4978,
0.5759, 0.3523, 0.7135, 0.6804, 0.5833, 1.4146, 0.8986, 0.5659,
0.7069, 0.5338, 0.4889, 0.4917, 0.4069, 0.4999, 0.6866, 0.4093,
0.5709, 0.6065, 0.6415, 0.4944, 0.5726, 1.2042, 0.5458, 1.6887,
0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744
]).view(1, self.latent_channels, 1, 1, 1)
class Hunyuan3Dv2(LatentFormat):
latent_channels = 64
latent_dimensions = 1

View File

@@ -254,13 +254,12 @@ class Chroma(nn.Module):
def forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
bs, c, h, w = x.shape
patch_size = 2
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=self.patch_size, pw=self.patch_size)
h_len = ((h + (patch_size // 2)) // patch_size)
w_len = ((w + (patch_size // 2)) // patch_size)
h_len = ((h + (self.patch_size // 2)) // self.patch_size)
w_len = ((w + (self.patch_size // 2)) // self.patch_size)
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
@@ -268,4 +267,4 @@ class Chroma(nn.Module):
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=self.patch_size, pw=self.patch_size)[:,:,:h,:w]

View File

@@ -58,7 +58,8 @@ def is_odd(n: int) -> bool:
def nonlinearity(x):
return x * torch.sigmoid(x)
# x * sigmoid(x)
return torch.nn.functional.silu(x)
def Normalize(in_channels, num_groups=32):

View File

@@ -70,11 +70,7 @@ def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H
q_B_H_S_D = rearrange(q_B_S_H_D, "b ... h k -> b h ... k").view(in_q_shape[0], in_q_shape[-2], -1, in_q_shape[-1])
k_B_H_S_D = rearrange(k_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
v_B_H_S_D = rearrange(v_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
result_B_S_HD = rearrange(
optimized_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, in_q_shape[-2], skip_reshape=True, skip_output_reshape=True), "b h ... l -> b ... (h l)"
)
return result_B_S_HD
return optimized_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, in_q_shape[-2], skip_reshape=True)
class Attention(nn.Module):

View File

@@ -123,6 +123,8 @@ class ControlNetFlux(Flux):
if y is None:
y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
else:
y = y[:, :self.params.vec_in_dim]
# running on sequences img
img = self.img_in(img)

View File

@@ -118,7 +118,7 @@ class Modulation(nn.Module):
def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None):
if modulation_dims is None:
if m_add is not None:
return tensor * m_mult + m_add
return torch.addcmul(m_add, tensor, m_mult)
else:
return tensor * m_mult
else:

View File

@@ -195,20 +195,50 @@ class Flux(nn.Module):
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
return img
def forward(self, x, timestep, context, y=None, guidance=None, control=None, transformer_options={}, **kwargs):
def process_img(self, x, index=0, h_offset=0, w_offset=0):
bs, c, h, w = x.shape
patch_size = self.patch_size
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
h_len = ((h + (patch_size // 2)) // patch_size)
w_len = ((w + (patch_size // 2)) // patch_size)
h_offset = ((h_offset + (patch_size // 2)) // patch_size)
w_offset = ((w_offset + (patch_size // 2)) // patch_size)
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
img_ids[:, :, 0] = img_ids[:, :, 1] + index
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
return img, repeat(img_ids, "h w c -> b (h w) c", b=bs)
def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
bs, c, h_orig, w_orig = x.shape
patch_size = self.patch_size
h_len = ((h_orig + (patch_size // 2)) // patch_size)
w_len = ((w_orig + (patch_size // 2)) // patch_size)
img, img_ids = self.process_img(x)
img_tokens = img.shape[1]
if ref_latents is not None:
h = 0
w = 0
for ref in ref_latents:
h_offset = 0
w_offset = 0
if ref.shape[-2] + h > ref.shape[-1] + w:
w_offset = w
else:
h_offset = h
kontext, kontext_ids = self.process_img(ref, index=1, h_offset=h_offset, w_offset=w_offset)
img = torch.cat([img, kontext], dim=1)
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
h = max(h, ref.shape[-2] + h_offset)
w = max(w, ref.shape[-1] + w_offset)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]
out = out[:, :img_tokens]
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h_orig,:w_orig]

View File

@@ -261,8 +261,8 @@ class CrossAttention(nn.Module):
self.heads = heads
self.dim_head = dim_head
self.q_norm = operations.RMSNorm(inner_dim, dtype=dtype, device=device)
self.k_norm = operations.RMSNorm(inner_dim, dtype=dtype, device=device)
self.q_norm = operations.RMSNorm(inner_dim, eps=1e-5, dtype=dtype, device=device)
self.k_norm = operations.RMSNorm(inner_dim, eps=1e-5, dtype=dtype, device=device)
self.to_q = operations.Linear(query_dim, inner_dim, bias=True, dtype=dtype, device=device)
self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)

View File

@@ -973,7 +973,7 @@ class VideoVAE(nn.Module):
norm_layer=config.get("norm_layer", "group_norm"),
causal=config.get("causal_decoder", False),
timestep_conditioning=self.timestep_conditioning,
spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
spatial_padding_mode=config.get("spatial_padding_mode", "reflect"),
)
self.per_channel_statistics = processor()

View File

@@ -11,7 +11,7 @@ from comfy.ldm.modules.ema import LitEma
import comfy.ops
class DiagonalGaussianRegularizer(torch.nn.Module):
def __init__(self, sample: bool = True):
def __init__(self, sample: bool = False):
super().__init__()
self.sample = sample
@@ -19,16 +19,12 @@ class DiagonalGaussianRegularizer(torch.nn.Module):
yield from ()
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
log = dict()
posterior = DiagonalGaussianDistribution(z)
if self.sample:
z = posterior.sample()
else:
z = posterior.mode()
kl_loss = posterior.kl()
kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
log["kl_loss"] = kl_loss
return z, log
return z, None
class AbstractAutoencoder(torch.nn.Module):

View File

@@ -36,7 +36,7 @@ def get_timestep_embedding(timesteps, embedding_dim):
def nonlinearity(x):
# swish
return x*torch.sigmoid(x)
return torch.nn.functional.silu(x)
def Normalize(in_channels, num_groups=32):

View File

@@ -31,7 +31,7 @@ def dynamic_slice(
starts: List[int],
sizes: List[int],
) -> Tensor:
slicing = [slice(start, start + size) for start, size in zip(starts, sizes)]
slicing = tuple(slice(start, start + size) for start, size in zip(starts, sizes))
return x[slicing]
class AttnChunk(NamedTuple):

View File

@@ -0,0 +1,469 @@
# Original code: https://github.com/VectorSpaceLab/OmniGen2
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from comfy.ldm.lightricks.model import Timesteps
from comfy.ldm.flux.layers import EmbedND
from comfy.ldm.modules.attention import optimized_attention_masked
import comfy.model_management
import comfy.ldm.common_dit
def apply_rotary_emb(x, freqs_cis):
if x.shape[1] == 0:
return x
t_ = x.reshape(*x.shape[:-1], -1, 1, 2)
t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
return t_out.reshape(*x.shape).to(dtype=x.dtype)
def swiglu(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return F.silu(x) * y
class TimestepEmbedding(nn.Module):
def __init__(self, in_channels: int, time_embed_dim: int, dtype=None, device=None, operations=None):
super().__init__()
self.linear_1 = operations.Linear(in_channels, time_embed_dim, dtype=dtype, device=device)
self.act = nn.SiLU()
self.linear_2 = operations.Linear(time_embed_dim, time_embed_dim, dtype=dtype, device=device)
def forward(self, sample: torch.Tensor) -> torch.Tensor:
sample = self.linear_1(sample)
sample = self.act(sample)
sample = self.linear_2(sample)
return sample
class LuminaRMSNormZero(nn.Module):
def __init__(self, embedding_dim: int, norm_eps: float = 1e-5, dtype=None, device=None, operations=None):
super().__init__()
self.silu = nn.SiLU()
self.linear = operations.Linear(min(embedding_dim, 1024), 4 * embedding_dim, dtype=dtype, device=device)
self.norm = operations.RMSNorm(embedding_dim, eps=norm_eps, dtype=dtype, device=device)
def forward(self, x: torch.Tensor, emb: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
emb = self.linear(self.silu(emb))
scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
x = self.norm(x) * (1 + scale_msa[:, None])
return x, gate_msa, scale_mlp, gate_mlp
class LuminaLayerNormContinuous(nn.Module):
def __init__(self, embedding_dim: int, conditioning_embedding_dim: int, elementwise_affine: bool = False, eps: float = 1e-6, out_dim: Optional[int] = None, dtype=None, device=None, operations=None):
super().__init__()
self.silu = nn.SiLU()
self.linear_1 = operations.Linear(conditioning_embedding_dim, embedding_dim, dtype=dtype, device=device)
self.norm = operations.LayerNorm(embedding_dim, eps, elementwise_affine, dtype=dtype, device=device)
self.linear_2 = operations.Linear(embedding_dim, out_dim, bias=True, dtype=dtype, device=device) if out_dim is not None else None
def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
emb = self.linear_1(self.silu(conditioning_embedding).to(x.dtype))
x = self.norm(x) * (1 + emb)[:, None, :]
if self.linear_2 is not None:
x = self.linear_2(x)
return x
class LuminaFeedForward(nn.Module):
def __init__(self, dim: int, inner_dim: int, multiple_of: int = 256, dtype=None, device=None, operations=None):
super().__init__()
inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of)
self.linear_1 = operations.Linear(dim, inner_dim, bias=False, dtype=dtype, device=device)
self.linear_2 = operations.Linear(inner_dim, dim, bias=False, dtype=dtype, device=device)
self.linear_3 = operations.Linear(dim, inner_dim, bias=False, dtype=dtype, device=device)
def forward(self, x: torch.Tensor) -> torch.Tensor:
h1, h2 = self.linear_1(x), self.linear_3(x)
return self.linear_2(swiglu(h1, h2))
class Lumina2CombinedTimestepCaptionEmbedding(nn.Module):
def __init__(self, hidden_size: int = 4096, text_feat_dim: int = 2048, frequency_embedding_size: int = 256, norm_eps: float = 1e-5, timestep_scale: float = 1.0, dtype=None, device=None, operations=None):
super().__init__()
self.time_proj = Timesteps(num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0, scale=timestep_scale)
self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=min(hidden_size, 1024), dtype=dtype, device=device, operations=operations)
self.caption_embedder = nn.Sequential(
operations.RMSNorm(text_feat_dim, eps=norm_eps, dtype=dtype, device=device),
operations.Linear(text_feat_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
def forward(self, timestep: torch.Tensor, text_hidden_states: torch.Tensor, dtype: torch.dtype) -> Tuple[torch.Tensor, torch.Tensor]:
timestep_proj = self.time_proj(timestep).to(dtype=dtype)
time_embed = self.timestep_embedder(timestep_proj)
caption_embed = self.caption_embedder(text_hidden_states)
return time_embed, caption_embed
class Attention(nn.Module):
def __init__(self, query_dim: int, dim_head: int, heads: int, kv_heads: int, eps: float = 1e-5, bias: bool = False, dtype=None, device=None, operations=None):
super().__init__()
self.heads = heads
self.kv_heads = kv_heads
self.dim_head = dim_head
self.scale = dim_head ** -0.5
self.to_q = operations.Linear(query_dim, heads * dim_head, bias=bias, dtype=dtype, device=device)
self.to_k = operations.Linear(query_dim, kv_heads * dim_head, bias=bias, dtype=dtype, device=device)
self.to_v = operations.Linear(query_dim, kv_heads * dim_head, bias=bias, dtype=dtype, device=device)
self.norm_q = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
self.norm_k = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
self.to_out = nn.Sequential(
operations.Linear(heads * dim_head, query_dim, bias=bias, dtype=dtype, device=device),
nn.Dropout(0.0)
)
def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None) -> torch.Tensor:
batch_size, sequence_length, _ = hidden_states.shape
query = self.to_q(hidden_states)
key = self.to_k(encoder_hidden_states)
value = self.to_v(encoder_hidden_states)
query = query.view(batch_size, -1, self.heads, self.dim_head)
key = key.view(batch_size, -1, self.kv_heads, self.dim_head)
value = value.view(batch_size, -1, self.kv_heads, self.dim_head)
query = self.norm_q(query)
key = self.norm_k(key)
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
if self.kv_heads < self.heads:
key = key.repeat_interleave(self.heads // self.kv_heads, dim=1)
value = value.repeat_interleave(self.heads // self.kv_heads, dim=1)
hidden_states = optimized_attention_masked(query, key, value, self.heads, attention_mask, skip_reshape=True)
hidden_states = self.to_out[0](hidden_states)
return hidden_states
class OmniGen2TransformerBlock(nn.Module):
def __init__(self, dim: int, num_attention_heads: int, num_kv_heads: int, multiple_of: int, ffn_dim_multiplier: float, norm_eps: float, modulation: bool = True, dtype=None, device=None, operations=None):
super().__init__()
self.modulation = modulation
self.attn = Attention(
query_dim=dim,
dim_head=dim // num_attention_heads,
heads=num_attention_heads,
kv_heads=num_kv_heads,
eps=1e-5,
bias=False,
dtype=dtype, device=device, operations=operations,
)
self.feed_forward = LuminaFeedForward(
dim=dim,
inner_dim=4 * dim,
multiple_of=multiple_of,
dtype=dtype, device=device, operations=operations
)
if modulation:
self.norm1 = LuminaRMSNormZero(embedding_dim=dim, norm_eps=norm_eps, dtype=dtype, device=device, operations=operations)
else:
self.norm1 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
self.ffn_norm1 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
self.norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
self.ffn_norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, image_rotary_emb: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
if self.modulation:
norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb)
hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output)
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
else:
norm_hidden_states = self.norm1(hidden_states)
attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb)
hidden_states = hidden_states + self.norm2(attn_output)
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
hidden_states = hidden_states + self.ffn_norm2(mlp_output)
return hidden_states
class OmniGen2RotaryPosEmbed(nn.Module):
def __init__(self, theta: int, axes_dim: Tuple[int, int, int], axes_lens: Tuple[int, int, int] = (300, 512, 512), patch_size: int = 2):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim
self.axes_lens = axes_lens
self.patch_size = patch_size
self.rope_embedder = EmbedND(dim=sum(axes_dim), theta=self.theta, axes_dim=axes_dim)
def forward(self, batch_size, encoder_seq_len, l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len, ref_img_sizes, img_sizes, device):
p = self.patch_size
seq_lengths = [cap_len + sum(ref_img_len) + img_len for cap_len, ref_img_len, img_len in zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len)]
max_seq_len = max(seq_lengths)
max_ref_img_len = max([sum(ref_img_len) for ref_img_len in l_effective_ref_img_len])
max_img_len = max(l_effective_img_len)
position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device)
for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
position_ids[i, :cap_seq_len] = repeat(torch.arange(cap_seq_len, dtype=torch.int32, device=device), "l -> l 3")
pe_shift = cap_seq_len
pe_shift_len = cap_seq_len
if ref_img_sizes[i] is not None:
for ref_img_size, ref_img_len in zip(ref_img_sizes[i], l_effective_ref_img_len[i]):
H, W = ref_img_size
ref_H_tokens, ref_W_tokens = H // p, W // p
row_ids = repeat(torch.arange(ref_H_tokens, dtype=torch.int32, device=device), "h -> h w", w=ref_W_tokens).flatten()
col_ids = repeat(torch.arange(ref_W_tokens, dtype=torch.int32, device=device), "w -> h w", h=ref_H_tokens).flatten()
position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 0] = pe_shift
position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 1] = row_ids
position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 2] = col_ids
pe_shift += max(ref_H_tokens, ref_W_tokens)
pe_shift_len += ref_img_len
H, W = img_sizes[i]
H_tokens, W_tokens = H // p, W // p
row_ids = repeat(torch.arange(H_tokens, dtype=torch.int32, device=device), "h -> h w", w=W_tokens).flatten()
col_ids = repeat(torch.arange(W_tokens, dtype=torch.int32, device=device), "w -> h w", h=H_tokens).flatten()
position_ids[i, pe_shift_len: seq_len, 0] = pe_shift
position_ids[i, pe_shift_len: seq_len, 1] = row_ids
position_ids[i, pe_shift_len: seq_len, 2] = col_ids
freqs_cis = self.rope_embedder(position_ids).movedim(1, 2)
cap_freqs_cis_shape = list(freqs_cis.shape)
cap_freqs_cis_shape[1] = encoder_seq_len
cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
ref_img_freqs_cis_shape = list(freqs_cis.shape)
ref_img_freqs_cis_shape[1] = max_ref_img_len
ref_img_freqs_cis = torch.zeros(*ref_img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
img_freqs_cis_shape = list(freqs_cis.shape)
img_freqs_cis_shape[1] = max_img_len
img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
for i, (cap_seq_len, ref_img_len, img_len, seq_len) in enumerate(zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len, seq_lengths)):
cap_freqs_cis[i, :cap_seq_len] = freqs_cis[i, :cap_seq_len]
ref_img_freqs_cis[i, :sum(ref_img_len)] = freqs_cis[i, cap_seq_len:cap_seq_len + sum(ref_img_len)]
img_freqs_cis[i, :img_len] = freqs_cis[i, cap_seq_len + sum(ref_img_len):cap_seq_len + sum(ref_img_len) + img_len]
return cap_freqs_cis, ref_img_freqs_cis, img_freqs_cis, freqs_cis, l_effective_cap_len, seq_lengths
class OmniGen2Transformer2DModel(nn.Module):
def __init__(
self,
patch_size: int = 2,
in_channels: int = 16,
out_channels: Optional[int] = None,
hidden_size: int = 2304,
num_layers: int = 26,
num_refiner_layers: int = 2,
num_attention_heads: int = 24,
num_kv_heads: int = 8,
multiple_of: int = 256,
ffn_dim_multiplier: Optional[float] = None,
norm_eps: float = 1e-5,
axes_dim_rope: Tuple[int, int, int] = (32, 32, 32),
axes_lens: Tuple[int, int, int] = (300, 512, 512),
text_feat_dim: int = 1024,
timestep_scale: float = 1.0,
image_model=None,
device=None,
dtype=None,
operations=None,
):
super().__init__()
self.patch_size = patch_size
self.out_channels = out_channels or in_channels
self.hidden_size = hidden_size
self.dtype = dtype
self.rope_embedder = OmniGen2RotaryPosEmbed(
theta=10000,
axes_dim=axes_dim_rope,
axes_lens=axes_lens,
patch_size=patch_size,
)
self.x_embedder = operations.Linear(patch_size * patch_size * in_channels, hidden_size, dtype=dtype, device=device)
self.ref_image_patch_embedder = operations.Linear(patch_size * patch_size * in_channels, hidden_size, dtype=dtype, device=device)
self.time_caption_embed = Lumina2CombinedTimestepCaptionEmbedding(
hidden_size=hidden_size,
text_feat_dim=text_feat_dim,
norm_eps=norm_eps,
timestep_scale=timestep_scale, dtype=dtype, device=device, operations=operations
)
self.noise_refiner = nn.ModuleList([
OmniGen2TransformerBlock(
hidden_size, num_attention_heads, num_kv_heads,
multiple_of, ffn_dim_multiplier, norm_eps, modulation=True, dtype=dtype, device=device, operations=operations
) for _ in range(num_refiner_layers)
])
self.ref_image_refiner = nn.ModuleList([
OmniGen2TransformerBlock(
hidden_size, num_attention_heads, num_kv_heads,
multiple_of, ffn_dim_multiplier, norm_eps, modulation=True, dtype=dtype, device=device, operations=operations
) for _ in range(num_refiner_layers)
])
self.context_refiner = nn.ModuleList([
OmniGen2TransformerBlock(
hidden_size, num_attention_heads, num_kv_heads,
multiple_of, ffn_dim_multiplier, norm_eps, modulation=False, dtype=dtype, device=device, operations=operations
) for _ in range(num_refiner_layers)
])
self.layers = nn.ModuleList([
OmniGen2TransformerBlock(
hidden_size, num_attention_heads, num_kv_heads,
multiple_of, ffn_dim_multiplier, norm_eps, modulation=True, dtype=dtype, device=device, operations=operations
) for _ in range(num_layers)
])
self.norm_out = LuminaLayerNormContinuous(
embedding_dim=hidden_size,
conditioning_embedding_dim=min(hidden_size, 1024),
elementwise_affine=False,
eps=1e-6,
out_dim=patch_size * patch_size * self.out_channels, dtype=dtype, device=device, operations=operations
)
self.image_index_embedding = nn.Parameter(torch.empty(5, hidden_size, device=device, dtype=dtype))
def flat_and_pad_to_seq(self, hidden_states, ref_image_hidden_states):
batch_size = len(hidden_states)
p = self.patch_size
img_sizes = [(img.size(1), img.size(2)) for img in hidden_states]
l_effective_img_len = [(H // p) * (W // p) for (H, W) in img_sizes]
if ref_image_hidden_states is not None:
ref_image_hidden_states = list(map(lambda ref: comfy.ldm.common_dit.pad_to_patch_size(ref, (p, p)), ref_image_hidden_states))
ref_img_sizes = [[(imgs.size(2), imgs.size(3)) if imgs is not None else None for imgs in ref_image_hidden_states]] * batch_size
l_effective_ref_img_len = [[(ref_img_size[0] // p) * (ref_img_size[1] // p) for ref_img_size in _ref_img_sizes] if _ref_img_sizes is not None else [0] for _ref_img_sizes in ref_img_sizes]
else:
ref_img_sizes = [None for _ in range(batch_size)]
l_effective_ref_img_len = [[0] for _ in range(batch_size)]
flat_ref_img_hidden_states = None
if ref_image_hidden_states is not None:
imgs = []
for ref_img in ref_image_hidden_states:
B, C, H, W = ref_img.size()
ref_img = rearrange(ref_img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
imgs.append(ref_img)
flat_ref_img_hidden_states = torch.cat(imgs, dim=1)
img = hidden_states
B, C, H, W = img.size()
flat_hidden_states = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
return (
flat_hidden_states, flat_ref_img_hidden_states,
None, None,
l_effective_ref_img_len, l_effective_img_len,
ref_img_sizes, img_sizes,
)
def img_patch_embed_and_refine(self, hidden_states, ref_image_hidden_states, padded_img_mask, padded_ref_img_mask, noise_rotary_emb, ref_img_rotary_emb, l_effective_ref_img_len, l_effective_img_len, temb):
batch_size = len(hidden_states)
hidden_states = self.x_embedder(hidden_states)
if ref_image_hidden_states is not None:
ref_image_hidden_states = self.ref_image_patch_embedder(ref_image_hidden_states)
image_index_embedding = comfy.model_management.cast_to(self.image_index_embedding, dtype=hidden_states.dtype, device=hidden_states.device)
for i in range(batch_size):
shift = 0
for j, ref_img_len in enumerate(l_effective_ref_img_len[i]):
ref_image_hidden_states[i, shift:shift + ref_img_len, :] = ref_image_hidden_states[i, shift:shift + ref_img_len, :] + image_index_embedding[j]
shift += ref_img_len
for layer in self.noise_refiner:
hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb)
if ref_image_hidden_states is not None:
for layer in self.ref_image_refiner:
ref_image_hidden_states = layer(ref_image_hidden_states, padded_ref_img_mask, ref_img_rotary_emb, temb)
hidden_states = torch.cat([ref_image_hidden_states, hidden_states], dim=1)
return hidden_states
def forward(self, x, timesteps, context, num_tokens, ref_latents=None, attention_mask=None, **kwargs):
B, C, H, W = x.shape
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
_, _, H_padded, W_padded = hidden_states.shape
timestep = 1.0 - timesteps
text_hidden_states = context
text_attention_mask = attention_mask
ref_image_hidden_states = ref_latents
device = hidden_states.device
temb, text_hidden_states = self.time_caption_embed(timestep, text_hidden_states, hidden_states[0].dtype)
(
hidden_states, ref_image_hidden_states,
img_mask, ref_img_mask,
l_effective_ref_img_len, l_effective_img_len,
ref_img_sizes, img_sizes,
) = self.flat_and_pad_to_seq(hidden_states, ref_image_hidden_states)
(
context_rotary_emb, ref_img_rotary_emb, noise_rotary_emb,
rotary_emb, encoder_seq_lengths, seq_lengths,
) = self.rope_embedder(
hidden_states.shape[0], text_hidden_states.shape[1], [num_tokens] * text_hidden_states.shape[0],
l_effective_ref_img_len, l_effective_img_len,
ref_img_sizes, img_sizes, device,
)
for layer in self.context_refiner:
text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb)
img_len = hidden_states.shape[1]
combined_img_hidden_states = self.img_patch_embed_and_refine(
hidden_states, ref_image_hidden_states,
img_mask, ref_img_mask,
noise_rotary_emb, ref_img_rotary_emb,
l_effective_ref_img_len, l_effective_img_len,
temb,
)
hidden_states = torch.cat([text_hidden_states, combined_img_hidden_states], dim=1)
attention_mask = None
for layer in self.layers:
hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb)
hidden_states = self.norm_out(hidden_states, temb)
p = self.patch_size
output = rearrange(hidden_states[:, -img_len:], 'b (h w) (p1 p2 c) -> b c (h p1) (w p2)', h=H_padded // p, w=W_padded// p, p1=p, p2=p)[:, :, :H, :W]
return -output

View File

@@ -1,256 +1,256 @@
# Based on:
# https://github.com/PixArt-alpha/PixArt-alpha [Apache 2.0 license]
# https://github.com/PixArt-alpha/PixArt-sigma [Apache 2.0 license]
import torch
import torch.nn as nn
from .blocks import (
t2i_modulate,
CaptionEmbedder,
AttentionKVCompress,
MultiHeadCrossAttention,
T2IFinalLayer,
SizeEmbedder,
)
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, PatchEmbed, Mlp, get_1d_sincos_pos_embed_from_grid_torch
def get_2d_sincos_pos_embed_torch(embed_dim, w, h, pe_interpolation=1.0, base_size=16, device=None, dtype=torch.float32):
grid_h, grid_w = torch.meshgrid(
torch.arange(h, device=device, dtype=dtype) / (h/base_size) / pe_interpolation,
torch.arange(w, device=device, dtype=dtype) / (w/base_size) / pe_interpolation,
indexing='ij'
)
emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_h, device=device, dtype=dtype)
emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_w, device=device, dtype=dtype)
emb = torch.cat([emb_w, emb_h], dim=1) # (H*W, D)
return emb
class PixArtMSBlock(nn.Module):
"""
A PixArt block with adaptive layer norm zero (adaLN-Zero) conditioning.
"""
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0., input_size=None,
sampling=None, sr_ratio=1, qk_norm=False, dtype=None, device=None, operations=None, **block_kwargs):
super().__init__()
self.hidden_size = hidden_size
self.norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.attn = AttentionKVCompress(
hidden_size, num_heads=num_heads, qkv_bias=True, sampling=sampling, sr_ratio=sr_ratio,
qk_norm=qk_norm, dtype=dtype, device=device, operations=operations, **block_kwargs
)
self.cross_attn = MultiHeadCrossAttention(
hidden_size, num_heads, dtype=dtype, device=device, operations=operations, **block_kwargs
)
self.norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
# to be compatible with lower version pytorch
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.mlp = Mlp(
in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu,
dtype=dtype, device=device, operations=operations
)
self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size ** 0.5)
def forward(self, x, y, t, mask=None, HW=None, **kwargs):
B, N, C = x.shape
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None].to(dtype=x.dtype, device=x.device) + t.reshape(B, 6, -1)).chunk(6, dim=1)
x = x + (gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa), HW=HW))
x = x + self.cross_attn(x, y, mask)
x = x + (gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))
return x
### Core PixArt Model ###
class PixArtMS(nn.Module):
"""
Diffusion model with a Transformer backbone.
"""
def __init__(
self,
input_size=32,
patch_size=2,
in_channels=4,
hidden_size=1152,
depth=28,
num_heads=16,
mlp_ratio=4.0,
class_dropout_prob=0.1,
learn_sigma=True,
pred_sigma=True,
drop_path: float = 0.,
caption_channels=4096,
pe_interpolation=None,
pe_precision=None,
config=None,
model_max_length=120,
micro_condition=True,
qk_norm=False,
kv_compress_config=None,
dtype=None,
device=None,
operations=None,
**kwargs,
):
nn.Module.__init__(self)
self.dtype = dtype
self.pred_sigma = pred_sigma
self.in_channels = in_channels
self.out_channels = in_channels * 2 if pred_sigma else in_channels
self.patch_size = patch_size
self.num_heads = num_heads
self.pe_interpolation = pe_interpolation
self.pe_precision = pe_precision
self.hidden_size = hidden_size
self.depth = depth
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.t_block = nn.Sequential(
nn.SiLU(),
operations.Linear(hidden_size, 6 * hidden_size, bias=True, dtype=dtype, device=device)
)
self.x_embedder = PatchEmbed(
patch_size=patch_size,
in_chans=in_channels,
embed_dim=hidden_size,
bias=True,
dtype=dtype,
device=device,
operations=operations
)
self.t_embedder = TimestepEmbedder(
hidden_size, dtype=dtype, device=device, operations=operations,
)
self.y_embedder = CaptionEmbedder(
in_channels=caption_channels, hidden_size=hidden_size, uncond_prob=class_dropout_prob,
act_layer=approx_gelu, token_num=model_max_length,
dtype=dtype, device=device, operations=operations,
)
self.micro_conditioning = micro_condition
if self.micro_conditioning:
self.csize_embedder = SizeEmbedder(hidden_size//3, dtype=dtype, device=device, operations=operations)
self.ar_embedder = SizeEmbedder(hidden_size//3, dtype=dtype, device=device, operations=operations)
# For fixed sin-cos embedding:
# num_patches = (input_size // patch_size) * (input_size // patch_size)
# self.base_size = input_size // self.patch_size
# self.register_buffer("pos_embed", torch.zeros(1, num_patches, hidden_size))
drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule
if kv_compress_config is None:
kv_compress_config = {
'sampling': None,
'scale_factor': 1,
'kv_compress_layer': [],
}
self.blocks = nn.ModuleList([
PixArtMSBlock(
hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i],
sampling=kv_compress_config['sampling'],
sr_ratio=int(kv_compress_config['scale_factor']) if i in kv_compress_config['kv_compress_layer'] else 1,
qk_norm=qk_norm,
dtype=dtype,
device=device,
operations=operations,
)
for i in range(depth)
])
self.final_layer = T2IFinalLayer(
hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations
)
def forward_orig(self, x, timestep, y, mask=None, c_size=None, c_ar=None, **kwargs):
"""
Original forward pass of PixArt.
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
t: (N,) tensor of diffusion timesteps
y: (N, 1, 120, C) conditioning
ar: (N, 1): aspect ratio
cs: (N ,2) size conditioning for height/width
"""
B, C, H, W = x.shape
c_res = (H + W) // 2
pe_interpolation = self.pe_interpolation
if pe_interpolation is None or self.pe_precision is not None:
# calculate pe_interpolation on-the-fly
pe_interpolation = round(c_res / (512/8.0), self.pe_precision or 0)
pos_embed = get_2d_sincos_pos_embed_torch(
self.hidden_size,
h=(H // self.patch_size),
w=(W // self.patch_size),
pe_interpolation=pe_interpolation,
base_size=((round(c_res / 64) * 64) // self.patch_size),
device=x.device,
dtype=x.dtype,
).unsqueeze(0)
x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2
t = self.t_embedder(timestep, x.dtype) # (N, D)
if self.micro_conditioning and (c_size is not None and c_ar is not None):
bs = x.shape[0]
c_size = self.csize_embedder(c_size, bs) # (N, D)
c_ar = self.ar_embedder(c_ar, bs) # (N, D)
t = t + torch.cat([c_size, c_ar], dim=1)
t0 = self.t_block(t)
y = self.y_embedder(y, self.training) # (N, D)
if mask is not None:
if mask.shape[0] != y.shape[0]:
mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
mask = mask.squeeze(1).squeeze(1)
y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
y_lens = mask.sum(dim=1).tolist()
else:
y_lens = None
y = y.squeeze(1).view(1, -1, x.shape[-1])
for block in self.blocks:
x = block(x, y, t0, y_lens, (H, W), **kwargs) # (N, T, D)
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
x = self.unpatchify(x, H, W) # (N, out_channels, H, W)
return x
def forward(self, x, timesteps, context, c_size=None, c_ar=None, **kwargs):
B, C, H, W = x.shape
# Fallback for missing microconds
if self.micro_conditioning:
if c_size is None:
c_size = torch.tensor([H*8, W*8], dtype=x.dtype, device=x.device).repeat(B, 1)
if c_ar is None:
c_ar = torch.tensor([H/W], dtype=x.dtype, device=x.device).repeat(B, 1)
## Still accepts the input w/o that dim but returns garbage
if len(context.shape) == 3:
context = context.unsqueeze(1)
## run original forward pass
out = self.forward_orig(x, timesteps, context, c_size=c_size, c_ar=c_ar)
## only return EPS
if self.pred_sigma:
return out[:, :self.in_channels]
return out
def unpatchify(self, x, h, w):
"""
x: (N, T, patch_size**2 * C)
imgs: (N, H, W, C)
"""
c = self.out_channels
p = self.x_embedder.patch_size[0]
h = h // self.patch_size
w = w // self.patch_size
assert h * w == x.shape[1]
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
x = torch.einsum('nhwpqc->nchpwq', x)
imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
return imgs
# Based on:
# https://github.com/PixArt-alpha/PixArt-alpha [Apache 2.0 license]
# https://github.com/PixArt-alpha/PixArt-sigma [Apache 2.0 license]
import torch
import torch.nn as nn
from .blocks import (
t2i_modulate,
CaptionEmbedder,
AttentionKVCompress,
MultiHeadCrossAttention,
T2IFinalLayer,
SizeEmbedder,
)
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, PatchEmbed, Mlp, get_1d_sincos_pos_embed_from_grid_torch
def get_2d_sincos_pos_embed_torch(embed_dim, w, h, pe_interpolation=1.0, base_size=16, device=None, dtype=torch.float32):
grid_h, grid_w = torch.meshgrid(
torch.arange(h, device=device, dtype=dtype) / (h/base_size) / pe_interpolation,
torch.arange(w, device=device, dtype=dtype) / (w/base_size) / pe_interpolation,
indexing='ij'
)
emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_h, device=device, dtype=dtype)
emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_w, device=device, dtype=dtype)
emb = torch.cat([emb_w, emb_h], dim=1) # (H*W, D)
return emb
class PixArtMSBlock(nn.Module):
"""
A PixArt block with adaptive layer norm zero (adaLN-Zero) conditioning.
"""
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0., input_size=None,
sampling=None, sr_ratio=1, qk_norm=False, dtype=None, device=None, operations=None, **block_kwargs):
super().__init__()
self.hidden_size = hidden_size
self.norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.attn = AttentionKVCompress(
hidden_size, num_heads=num_heads, qkv_bias=True, sampling=sampling, sr_ratio=sr_ratio,
qk_norm=qk_norm, dtype=dtype, device=device, operations=operations, **block_kwargs
)
self.cross_attn = MultiHeadCrossAttention(
hidden_size, num_heads, dtype=dtype, device=device, operations=operations, **block_kwargs
)
self.norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
# to be compatible with lower version pytorch
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.mlp = Mlp(
in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu,
dtype=dtype, device=device, operations=operations
)
self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size ** 0.5)
def forward(self, x, y, t, mask=None, HW=None, **kwargs):
B, N, C = x.shape
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None].to(dtype=x.dtype, device=x.device) + t.reshape(B, 6, -1)).chunk(6, dim=1)
x = x + (gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa), HW=HW))
x = x + self.cross_attn(x, y, mask)
x = x + (gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))
return x
### Core PixArt Model ###
class PixArtMS(nn.Module):
"""
Diffusion model with a Transformer backbone.
"""
def __init__(
self,
input_size=32,
patch_size=2,
in_channels=4,
hidden_size=1152,
depth=28,
num_heads=16,
mlp_ratio=4.0,
class_dropout_prob=0.1,
learn_sigma=True,
pred_sigma=True,
drop_path: float = 0.,
caption_channels=4096,
pe_interpolation=None,
pe_precision=None,
config=None,
model_max_length=120,
micro_condition=True,
qk_norm=False,
kv_compress_config=None,
dtype=None,
device=None,
operations=None,
**kwargs,
):
nn.Module.__init__(self)
self.dtype = dtype
self.pred_sigma = pred_sigma
self.in_channels = in_channels
self.out_channels = in_channels * 2 if pred_sigma else in_channels
self.patch_size = patch_size
self.num_heads = num_heads
self.pe_interpolation = pe_interpolation
self.pe_precision = pe_precision
self.hidden_size = hidden_size
self.depth = depth
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.t_block = nn.Sequential(
nn.SiLU(),
operations.Linear(hidden_size, 6 * hidden_size, bias=True, dtype=dtype, device=device)
)
self.x_embedder = PatchEmbed(
patch_size=patch_size,
in_chans=in_channels,
embed_dim=hidden_size,
bias=True,
dtype=dtype,
device=device,
operations=operations
)
self.t_embedder = TimestepEmbedder(
hidden_size, dtype=dtype, device=device, operations=operations,
)
self.y_embedder = CaptionEmbedder(
in_channels=caption_channels, hidden_size=hidden_size, uncond_prob=class_dropout_prob,
act_layer=approx_gelu, token_num=model_max_length,
dtype=dtype, device=device, operations=operations,
)
self.micro_conditioning = micro_condition
if self.micro_conditioning:
self.csize_embedder = SizeEmbedder(hidden_size//3, dtype=dtype, device=device, operations=operations)
self.ar_embedder = SizeEmbedder(hidden_size//3, dtype=dtype, device=device, operations=operations)
# For fixed sin-cos embedding:
# num_patches = (input_size // patch_size) * (input_size // patch_size)
# self.base_size = input_size // self.patch_size
# self.register_buffer("pos_embed", torch.zeros(1, num_patches, hidden_size))
drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule
if kv_compress_config is None:
kv_compress_config = {
'sampling': None,
'scale_factor': 1,
'kv_compress_layer': [],
}
self.blocks = nn.ModuleList([
PixArtMSBlock(
hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i],
sampling=kv_compress_config['sampling'],
sr_ratio=int(kv_compress_config['scale_factor']) if i in kv_compress_config['kv_compress_layer'] else 1,
qk_norm=qk_norm,
dtype=dtype,
device=device,
operations=operations,
)
for i in range(depth)
])
self.final_layer = T2IFinalLayer(
hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations
)
def forward_orig(self, x, timestep, y, mask=None, c_size=None, c_ar=None, **kwargs):
"""
Original forward pass of PixArt.
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
t: (N,) tensor of diffusion timesteps
y: (N, 1, 120, C) conditioning
ar: (N, 1): aspect ratio
cs: (N ,2) size conditioning for height/width
"""
B, C, H, W = x.shape
c_res = (H + W) // 2
pe_interpolation = self.pe_interpolation
if pe_interpolation is None or self.pe_precision is not None:
# calculate pe_interpolation on-the-fly
pe_interpolation = round(c_res / (512/8.0), self.pe_precision or 0)
pos_embed = get_2d_sincos_pos_embed_torch(
self.hidden_size,
h=(H // self.patch_size),
w=(W // self.patch_size),
pe_interpolation=pe_interpolation,
base_size=((round(c_res / 64) * 64) // self.patch_size),
device=x.device,
dtype=x.dtype,
).unsqueeze(0)
x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2
t = self.t_embedder(timestep, x.dtype) # (N, D)
if self.micro_conditioning and (c_size is not None and c_ar is not None):
bs = x.shape[0]
c_size = self.csize_embedder(c_size, bs) # (N, D)
c_ar = self.ar_embedder(c_ar, bs) # (N, D)
t = t + torch.cat([c_size, c_ar], dim=1)
t0 = self.t_block(t)
y = self.y_embedder(y, self.training) # (N, D)
if mask is not None:
if mask.shape[0] != y.shape[0]:
mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
mask = mask.squeeze(1).squeeze(1)
y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
y_lens = mask.sum(dim=1).tolist()
else:
y_lens = None
y = y.squeeze(1).view(1, -1, x.shape[-1])
for block in self.blocks:
x = block(x, y, t0, y_lens, (H, W), **kwargs) # (N, T, D)
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
x = self.unpatchify(x, H, W) # (N, out_channels, H, W)
return x
def forward(self, x, timesteps, context, c_size=None, c_ar=None, **kwargs):
B, C, H, W = x.shape
# Fallback for missing microconds
if self.micro_conditioning:
if c_size is None:
c_size = torch.tensor([H*8, W*8], dtype=x.dtype, device=x.device).repeat(B, 1)
if c_ar is None:
c_ar = torch.tensor([H/W], dtype=x.dtype, device=x.device).repeat(B, 1)
## Still accepts the input w/o that dim but returns garbage
if len(context.shape) == 3:
context = context.unsqueeze(1)
## run original forward pass
out = self.forward_orig(x, timesteps, context, c_size=c_size, c_ar=c_ar)
## only return EPS
if self.pred_sigma:
return out[:, :self.in_channels]
return out
def unpatchify(self, x, h, w):
"""
x: (N, T, patch_size**2 * C)
imgs: (N, H, W, C)
"""
c = self.out_channels
p = self.x_embedder.patch_size[0]
h = h // self.patch_size
w = w // self.patch_size
assert h * w == x.shape[1]
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
x = torch.einsum('nhwpqc->nchpwq', x)
imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
return imgs

View File

@@ -0,0 +1,400 @@
# https://github.com/QwenLM/Qwen-Image (Apache 2.0)
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple
from einops import repeat
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
from comfy.ldm.modules.attention import optimized_attention_masked
from comfy.ldm.flux.layers import EmbedND
import comfy.ldm.common_dit
class GELU(nn.Module):
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None):
super().__init__()
self.proj = operations.Linear(dim_in, dim_out, bias=bias, dtype=dtype, device=device)
self.approximate = approximate
def forward(self, hidden_states):
hidden_states = self.proj(hidden_states)
hidden_states = F.gelu(hidden_states, approximate=self.approximate)
return hidden_states
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
dim_out: Optional[int] = None,
mult: int = 4,
dropout: float = 0.0,
inner_dim=None,
bias: bool = True,
dtype=None, device=None, operations=None
):
super().__init__()
if inner_dim is None:
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim
self.net = nn.ModuleList([])
self.net.append(GELU(dim, inner_dim, approximate="tanh", bias=bias, dtype=dtype, device=device, operations=operations))
self.net.append(nn.Dropout(dropout))
self.net.append(operations.Linear(inner_dim, dim_out, bias=bias, dtype=dtype, device=device))
def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
for module in self.net:
hidden_states = module(hidden_states)
return hidden_states
def apply_rotary_emb(x, freqs_cis):
if x.shape[1] == 0:
return x
t_ = x.reshape(*x.shape[:-1], -1, 1, 2)
t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
return t_out.reshape(*x.shape)
class QwenTimestepProjEmbeddings(nn.Module):
def __init__(self, embedding_dim, pooled_projection_dim, dtype=None, device=None, operations=None):
super().__init__()
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000)
self.timestep_embedder = TimestepEmbedding(
in_channels=256,
time_embed_dim=embedding_dim,
dtype=dtype,
device=device,
operations=operations
)
def forward(self, timestep, hidden_states):
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype))
return timesteps_emb
class Attention(nn.Module):
def __init__(
self,
query_dim: int,
dim_head: int = 64,
heads: int = 8,
dropout: float = 0.0,
bias: bool = False,
eps: float = 1e-5,
out_bias: bool = True,
out_dim: int = None,
out_context_dim: int = None,
dtype=None,
device=None,
operations=None
):
super().__init__()
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
self.inner_kv_dim = self.inner_dim
self.heads = heads
self.dim_head = dim_head
self.out_dim = out_dim if out_dim is not None else query_dim
self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim
self.dropout = dropout
# Q/K normalization
self.norm_q = operations.RMSNorm(dim_head, eps=eps, elementwise_affine=True, dtype=dtype, device=device)
self.norm_k = operations.RMSNorm(dim_head, eps=eps, elementwise_affine=True, dtype=dtype, device=device)
self.norm_added_q = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
self.norm_added_k = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
# Image stream projections
self.to_q = operations.Linear(query_dim, self.inner_dim, bias=bias, dtype=dtype, device=device)
self.to_k = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
self.to_v = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
# Text stream projections
self.add_q_proj = operations.Linear(query_dim, self.inner_dim, bias=bias, dtype=dtype, device=device)
self.add_k_proj = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
self.add_v_proj = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
# Output projections
self.to_out = nn.ModuleList([
operations.Linear(self.inner_dim, self.out_dim, bias=out_bias, dtype=dtype, device=device),
nn.Dropout(dropout)
])
self.to_add_out = operations.Linear(self.inner_dim, self.out_context_dim, bias=out_bias, dtype=dtype, device=device)
def forward(
self,
hidden_states: torch.FloatTensor, # Image stream
encoder_hidden_states: torch.FloatTensor = None, # Text stream
encoder_hidden_states_mask: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
seq_txt = encoder_hidden_states.shape[1]
img_query = self.to_q(hidden_states).unflatten(-1, (self.heads, -1))
img_key = self.to_k(hidden_states).unflatten(-1, (self.heads, -1))
img_value = self.to_v(hidden_states).unflatten(-1, (self.heads, -1))
txt_query = self.add_q_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
txt_key = self.add_k_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
txt_value = self.add_v_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
img_query = self.norm_q(img_query)
img_key = self.norm_k(img_key)
txt_query = self.norm_added_q(txt_query)
txt_key = self.norm_added_k(txt_key)
joint_query = torch.cat([txt_query, img_query], dim=1)
joint_key = torch.cat([txt_key, img_key], dim=1)
joint_value = torch.cat([txt_value, img_value], dim=1)
joint_query = apply_rotary_emb(joint_query, image_rotary_emb)
joint_key = apply_rotary_emb(joint_key, image_rotary_emb)
joint_query = joint_query.flatten(start_dim=2)
joint_key = joint_key.flatten(start_dim=2)
joint_value = joint_value.flatten(start_dim=2)
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attention_mask)
txt_attn_output = joint_hidden_states[:, :seq_txt, :]
img_attn_output = joint_hidden_states[:, seq_txt:, :]
img_attn_output = self.to_out[0](img_attn_output)
img_attn_output = self.to_out[1](img_attn_output)
txt_attn_output = self.to_add_out(txt_attn_output)
return img_attn_output, txt_attn_output
class QwenImageTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
eps: float = 1e-6,
dtype=None,
device=None,
operations=None
):
super().__init__()
self.dim = dim
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
self.img_mod = nn.Sequential(
nn.SiLU(),
operations.Linear(dim, 6 * dim, bias=True, dtype=dtype, device=device),
)
self.img_norm1 = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, dtype=dtype, device=device)
self.img_norm2 = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, dtype=dtype, device=device)
self.img_mlp = FeedForward(dim=dim, dim_out=dim, dtype=dtype, device=device, operations=operations)
self.txt_mod = nn.Sequential(
nn.SiLU(),
operations.Linear(dim, 6 * dim, bias=True, dtype=dtype, device=device),
)
self.txt_norm1 = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, dtype=dtype, device=device)
self.txt_norm2 = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, dtype=dtype, device=device)
self.txt_mlp = FeedForward(dim=dim, dim_out=dim, dtype=dtype, device=device, operations=operations)
self.attn = Attention(
query_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
bias=True,
eps=eps,
dtype=dtype,
device=device,
operations=operations,
)
def _modulate(self, x, mod_params):
shift, scale, gate = mod_params.chunk(3, dim=-1)
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
encoder_hidden_states_mask: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
img_mod_params = self.img_mod(temb)
txt_mod_params = self.txt_mod(temb)
img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1)
txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1)
img_normed = self.img_norm1(hidden_states)
img_modulated, img_gate1 = self._modulate(img_normed, img_mod1)
txt_normed = self.txt_norm1(encoder_hidden_states)
txt_modulated, txt_gate1 = self._modulate(txt_normed, txt_mod1)
img_attn_output, txt_attn_output = self.attn(
hidden_states=img_modulated,
encoder_hidden_states=txt_modulated,
encoder_hidden_states_mask=encoder_hidden_states_mask,
image_rotary_emb=image_rotary_emb,
)
hidden_states = hidden_states + img_gate1 * img_attn_output
encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
img_normed2 = self.img_norm2(hidden_states)
img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
hidden_states = hidden_states + img_gate2 * self.img_mlp(img_modulated2)
txt_normed2 = self.txt_norm2(encoder_hidden_states)
txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2)
encoder_hidden_states = encoder_hidden_states + txt_gate2 * self.txt_mlp(txt_modulated2)
return encoder_hidden_states, hidden_states
class LastLayer(nn.Module):
def __init__(
self,
embedding_dim: int,
conditioning_embedding_dim: int,
elementwise_affine=False,
eps=1e-6,
bias=True,
dtype=None, device=None, operations=None
):
super().__init__()
self.silu = nn.SiLU()
self.linear = operations.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias, dtype=dtype, device=device)
self.norm = operations.LayerNorm(embedding_dim, eps, elementwise_affine=False, bias=bias, dtype=dtype, device=device)
def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
emb = self.linear(self.silu(conditioning_embedding))
scale, shift = torch.chunk(emb, 2, dim=1)
x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
return x
class QwenImageTransformer2DModel(nn.Module):
def __init__(
self,
patch_size: int = 2,
in_channels: int = 64,
out_channels: Optional[int] = 16,
num_layers: int = 60,
attention_head_dim: int = 128,
num_attention_heads: int = 24,
joint_attention_dim: int = 3584,
pooled_projection_dim: int = 768,
guidance_embeds: bool = False,
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
image_model=None,
dtype=None,
device=None,
operations=None,
):
super().__init__()
self.dtype = dtype
self.patch_size = patch_size
self.out_channels = out_channels or in_channels
self.inner_dim = num_attention_heads * attention_head_dim
self.pe_embedder = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope))
self.time_text_embed = QwenTimestepProjEmbeddings(
embedding_dim=self.inner_dim,
pooled_projection_dim=pooled_projection_dim,
dtype=dtype,
device=device,
operations=operations
)
self.txt_norm = operations.RMSNorm(joint_attention_dim, eps=1e-6, dtype=dtype, device=device)
self.img_in = operations.Linear(in_channels, self.inner_dim, dtype=dtype, device=device)
self.txt_in = operations.Linear(joint_attention_dim, self.inner_dim, dtype=dtype, device=device)
self.transformer_blocks = nn.ModuleList([
QwenImageTransformerBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
dtype=dtype,
device=device,
operations=operations
)
for _ in range(num_layers)
])
self.norm_out = LastLayer(self.inner_dim, self.inner_dim, dtype=dtype, device=device, operations=operations)
self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device)
self.gradient_checkpointing = False
def pos_embeds(self, x, context):
bs, c, t, h, w = x.shape
patch_size = self.patch_size
h_len = ((h + (patch_size // 2)) // patch_size)
w_len = ((w + (patch_size // 2)) // patch_size)
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
txt_start = round(max(h_len, w_len))
txt_ids = torch.linspace(txt_start, txt_start + context.shape[1], steps=context.shape[1], device=x.device, dtype=x.dtype).reshape(1, -1, 1).repeat(bs, 1, 3)
ids = torch.cat((txt_ids, img_ids), dim=1)
return self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
def forward(
self,
x,
timesteps,
context,
attention_mask=None,
guidance: torch.Tensor = None,
**kwargs
):
timestep = timesteps
encoder_hidden_states = context
encoder_hidden_states_mask = attention_mask
image_rotary_emb = self.pos_embeds(x, context)
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (1, self.patch_size, self.patch_size))
orig_shape = hidden_states.shape
hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5)
hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
hidden_states = self.img_in(hidden_states)
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
encoder_hidden_states = self.txt_in(encoder_hidden_states)
if guidance is not None:
guidance = guidance * 1000
temb = (
self.time_text_embed(timestep, hidden_states)
if guidance is None
else self.time_text_embed(timestep, guidance, hidden_states)
)
for block in self.transformer_blocks:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
encoder_hidden_states_mask=encoder_hidden_states_mask,
temb=temb,
image_rotary_emb=image_rotary_emb,
)
hidden_states = self.norm_out(hidden_states, temb)
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
hidden_states = hidden_states.permute(0, 3, 1, 4, 2, 5)
return hidden_states.reshape(orig_shape)[:, :, :, :x.shape[-2], :x.shape[-1]]

View File

@@ -146,6 +146,15 @@ WAN_CROSSATTENTION_CLASSES = {
}
def repeat_e(e, x):
repeats = 1
if e.shape[1] > 1:
repeats = x.shape[1] // e.shape[1]
if repeats == 1:
return e
return torch.repeat_interleave(e, repeats, dim=1)
class WanAttentionBlock(nn.Module):
def __init__(self,
@@ -202,20 +211,23 @@ class WanAttentionBlock(nn.Module):
"""
# assert e.dtype == torch.float32
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
if e.ndim < 4:
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
else:
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e).unbind(2)
# assert e[0].dtype == torch.float32
# self-attention
y = self.self_attn(
self.norm1(x) * (1 + e[1]) + e[0],
self.norm1(x) * (1 + repeat_e(e[1], x)) + repeat_e(e[0], x),
freqs)
x = x + y * e[2]
x = x + y * repeat_e(e[2], x)
# cross-attention & ffn
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len)
y = self.ffn(self.norm2(x) * (1 + e[4]) + e[3])
x = x + y * e[5]
y = self.ffn(self.norm2(x) * (1 + repeat_e(e[4], x)) + repeat_e(e[3], x))
x = x + y * repeat_e(e[5], x)
return x
@@ -325,8 +337,12 @@ class Head(nn.Module):
e(Tensor): Shape [B, C]
"""
# assert e.dtype == torch.float32
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e.unsqueeze(1)).chunk(2, dim=1)
x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
if e.ndim < 3:
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e.unsqueeze(1)).chunk(2, dim=1)
else:
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e.unsqueeze(2)).unbind(2)
x = (self.head(self.norm(x) * (1 + repeat_e(e[1], x)) + repeat_e(e[0], x)))
return x
@@ -506,8 +522,9 @@ class WanModel(torch.nn.Module):
# time embeddings
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x[0].dtype))
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype))
e = e.reshape(t.shape[0], -1, e.shape[-1])
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
# context
context = self.text_embedding(context)
@@ -752,8 +769,7 @@ class CameraWanModel(WanModel):
# embeddings
x = self.patch_embedding(x.float()).to(x.dtype)
if self.control_adapter is not None and camera_conditions is not None:
x_camera = self.control_adapter(camera_conditions).to(x.dtype)
x = x + x_camera
x = x + self.control_adapter(camera_conditions).to(x.dtype)
grid_sizes = x.shape[2:]
x = x.flatten(2).transpose(1, 2)

View File

@@ -24,12 +24,17 @@ class CausalConv3d(ops.Conv3d):
self.padding[1], 2 * self.padding[0], 0)
self.padding = (0, 0, 0)
def forward(self, x, cache_x=None):
def forward(self, x, cache_x=None, cache_list=None, cache_idx=None):
if cache_list is not None:
cache_x = cache_list[cache_idx]
cache_list[cache_idx] = None
padding = list(self._padding)
if cache_x is not None and self._padding[4] > 0:
cache_x = cache_x.to(x.device)
x = torch.cat([cache_x, x], dim=2)
padding[4] -= cache_x.shape[2]
del cache_x
x = F.pad(x, padding)
return super().forward(x)
@@ -52,15 +57,6 @@ class RMS_norm(nn.Module):
x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma.to(x) + (self.bias.to(x) if self.bias is not None else 0)
class Upsample(nn.Upsample):
def forward(self, x):
"""
Fix bfloat16 support for nearest neighbor interpolation.
"""
return super().forward(x.float()).type_as(x)
class Resample(nn.Module):
def __init__(self, dim, mode):
@@ -73,11 +69,11 @@ class Resample(nn.Module):
# layers
if mode == 'upsample2d':
self.resample = nn.Sequential(
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
nn.Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
ops.Conv2d(dim, dim // 2, 3, padding=1))
elif mode == 'upsample3d':
self.resample = nn.Sequential(
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
nn.Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
ops.Conv2d(dim, dim // 2, 3, padding=1))
self.time_conv = CausalConv3d(
dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
@@ -157,29 +153,6 @@ class Resample(nn.Module):
feat_idx[0] += 1
return x
def init_weight(self, conv):
conv_weight = conv.weight
nn.init.zeros_(conv_weight)
c1, c2, t, h, w = conv_weight.size()
one_matrix = torch.eye(c1, c2)
init_matrix = one_matrix
nn.init.zeros_(conv_weight)
#conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5
conv_weight.data[:, :, 1, 0, 0] = init_matrix #* 0.5
conv.weight.data.copy_(conv_weight)
nn.init.zeros_(conv.bias.data)
def init_weight2(self, conv):
conv_weight = conv.weight.data
nn.init.zeros_(conv_weight)
c1, c2, t, h, w = conv_weight.size()
init_matrix = torch.eye(c1 // 2, c2)
#init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2)
conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
conv.weight.data.copy_(conv_weight)
nn.init.zeros_(conv.bias.data)
class ResidualBlock(nn.Module):
@@ -198,7 +171,7 @@ class ResidualBlock(nn.Module):
if in_dim != out_dim else nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0]):
h = self.shortcut(x)
old_x = x
for layer in self.residual:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
@@ -210,12 +183,12 @@ class ResidualBlock(nn.Module):
cache_x.device), cache_x
],
dim=2)
x = layer(x, feat_cache[idx])
x = layer(x, cache_list=feat_cache, cache_idx=idx)
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x + h
return x + self.shortcut(old_x)
class AttentionBlock(nn.Module):
@@ -494,12 +467,6 @@ class WanVAE(nn.Module):
self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
attn_scales, self.temperal_upsample, dropout)
def forward(self, x):
mu, log_var = self.encode(x)
z = self.reparameterize(mu, log_var)
x_recon = self.decode(z)
return x_recon, mu, log_var
def encode(self, x):
self.clear_cache()
## cache
@@ -545,18 +512,6 @@ class WanVAE(nn.Module):
self.clear_cache()
return out
def reparameterize(self, mu, log_var):
std = torch.exp(0.5 * log_var)
eps = torch.randn_like(std)
return eps * std + mu
def sample(self, imgs, deterministic=False):
mu, log_var = self.encode(imgs)
if deterministic:
return mu
std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
return mu + std * torch.randn_like(std)
def clear_cache(self):
self._conv_num = count_conv3d(self.decoder)
self._conv_idx = [0]

726
comfy/ldm/wan/vae2_2.py Normal file
View File

@@ -0,0 +1,726 @@
# original version: https://github.com/Wan-Video/Wan2.2/blob/main/wan/modules/vae2_2.py
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from .vae import AttentionBlock, CausalConv3d, RMS_norm
import comfy.ops
ops = comfy.ops.disable_weight_init
CACHE_T = 2
class Resample(nn.Module):
def __init__(self, dim, mode):
assert mode in (
"none",
"upsample2d",
"upsample3d",
"downsample2d",
"downsample3d",
)
super().__init__()
self.dim = dim
self.mode = mode
# layers
if mode == "upsample2d":
self.resample = nn.Sequential(
nn.Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
ops.Conv2d(dim, dim, 3, padding=1),
)
elif mode == "upsample3d":
self.resample = nn.Sequential(
nn.Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
ops.Conv2d(dim, dim, 3, padding=1),
# ops.Conv2d(dim, dim//2, 3, padding=1)
)
self.time_conv = CausalConv3d(
dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
elif mode == "downsample2d":
self.resample = nn.Sequential(
nn.ZeroPad2d((0, 1, 0, 1)),
ops.Conv2d(dim, dim, 3, stride=(2, 2)))
elif mode == "downsample3d":
self.resample = nn.Sequential(
nn.ZeroPad2d((0, 1, 0, 1)),
ops.Conv2d(dim, dim, 3, stride=(2, 2)))
self.time_conv = CausalConv3d(
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
else:
self.resample = nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0]):
b, c, t, h, w = x.size()
if self.mode == "upsample3d":
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = "Rep"
feat_idx[0] += 1
else:
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if (cache_x.shape[2] < 2 and feat_cache[idx] is not None and
feat_cache[idx] != "Rep"):
# cache last frame of last two chunk
cache_x = torch.cat(
[
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device),
cache_x,
],
dim=2,
)
if (cache_x.shape[2] < 2 and feat_cache[idx] is not None and
feat_cache[idx] == "Rep"):
cache_x = torch.cat(
[
torch.zeros_like(cache_x).to(cache_x.device),
cache_x
],
dim=2,
)
if feat_cache[idx] == "Rep":
x = self.time_conv(x)
else:
x = self.time_conv(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
x = x.reshape(b, 2, c, t, h, w)
x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
3)
x = x.reshape(b, c, t * 2, h, w)
t = x.shape[2]
x = rearrange(x, "b c t h w -> (b t) c h w")
x = self.resample(x)
x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
if self.mode == "downsample3d":
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = x.clone()
feat_idx[0] += 1
else:
cache_x = x[:, :, -1:, :, :].clone()
x = self.time_conv(
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
feat_cache[idx] = cache_x
feat_idx[0] += 1
return x
class ResidualBlock(nn.Module):
def __init__(self, in_dim, out_dim, dropout=0.0):
super().__init__()
self.in_dim = in_dim
self.out_dim = out_dim
# layers
self.residual = nn.Sequential(
RMS_norm(in_dim, images=False),
nn.SiLU(),
CausalConv3d(in_dim, out_dim, 3, padding=1),
RMS_norm(out_dim, images=False),
nn.SiLU(),
nn.Dropout(dropout),
CausalConv3d(out_dim, out_dim, 3, padding=1),
)
self.shortcut = (
CausalConv3d(in_dim, out_dim, 1)
if in_dim != out_dim else nn.Identity())
def forward(self, x, feat_cache=None, feat_idx=[0]):
old_x = x
for layer in self.residual:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat(
[
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device),
cache_x,
],
dim=2,
)
x = layer(x, cache_list=feat_cache, cache_idx=idx)
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x + self.shortcut(old_x)
def patchify(x, patch_size):
if patch_size == 1:
return x
if x.dim() == 4:
x = rearrange(
x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size, r=patch_size)
elif x.dim() == 5:
x = rearrange(
x,
"b c f (h q) (w r) -> b (c r q) f h w",
q=patch_size,
r=patch_size,
)
else:
raise ValueError(f"Invalid input shape: {x.shape}")
return x
def unpatchify(x, patch_size):
if patch_size == 1:
return x
if x.dim() == 4:
x = rearrange(
x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size, r=patch_size)
elif x.dim() == 5:
x = rearrange(
x,
"b (c r q) f h w -> b c f (h q) (w r)",
q=patch_size,
r=patch_size,
)
return x
class AvgDown3D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
factor_t,
factor_s=1,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.factor_t = factor_t
self.factor_s = factor_s
self.factor = self.factor_t * self.factor_s * self.factor_s
assert in_channels * self.factor % out_channels == 0
self.group_size = in_channels * self.factor // out_channels
def forward(self, x: torch.Tensor) -> torch.Tensor:
pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t
pad = (0, 0, 0, 0, pad_t, 0)
x = F.pad(x, pad)
B, C, T, H, W = x.shape
x = x.view(
B,
C,
T // self.factor_t,
self.factor_t,
H // self.factor_s,
self.factor_s,
W // self.factor_s,
self.factor_s,
)
x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
x = x.view(
B,
C * self.factor,
T // self.factor_t,
H // self.factor_s,
W // self.factor_s,
)
x = x.view(
B,
self.out_channels,
self.group_size,
T // self.factor_t,
H // self.factor_s,
W // self.factor_s,
)
x = x.mean(dim=2)
return x
class DupUp3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
factor_t,
factor_s=1,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.factor_t = factor_t
self.factor_s = factor_s
self.factor = self.factor_t * self.factor_s * self.factor_s
assert out_channels * self.factor % in_channels == 0
self.repeats = out_channels * self.factor // in_channels
def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor:
x = x.repeat_interleave(self.repeats, dim=1)
x = x.view(
x.size(0),
self.out_channels,
self.factor_t,
self.factor_s,
self.factor_s,
x.size(2),
x.size(3),
x.size(4),
)
x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
x = x.view(
x.size(0),
self.out_channels,
x.size(2) * self.factor_t,
x.size(4) * self.factor_s,
x.size(6) * self.factor_s,
)
if first_chunk:
x = x[:, :, self.factor_t - 1:, :, :]
return x
class Down_ResidualBlock(nn.Module):
def __init__(self,
in_dim,
out_dim,
dropout,
mult,
temperal_downsample=False,
down_flag=False):
super().__init__()
# Shortcut path with downsample
self.avg_shortcut = AvgDown3D(
in_dim,
out_dim,
factor_t=2 if temperal_downsample else 1,
factor_s=2 if down_flag else 1,
)
# Main path with residual blocks and downsample
downsamples = []
for _ in range(mult):
downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
in_dim = out_dim
# Add the final downsample block
if down_flag:
mode = "downsample3d" if temperal_downsample else "downsample2d"
downsamples.append(Resample(out_dim, mode=mode))
self.downsamples = nn.Sequential(*downsamples)
def forward(self, x, feat_cache=None, feat_idx=[0]):
x_copy = x
for module in self.downsamples:
x = module(x, feat_cache, feat_idx)
return x + self.avg_shortcut(x_copy)
class Up_ResidualBlock(nn.Module):
def __init__(self,
in_dim,
out_dim,
dropout,
mult,
temperal_upsample=False,
up_flag=False):
super().__init__()
# Shortcut path with upsample
if up_flag:
self.avg_shortcut = DupUp3D(
in_dim,
out_dim,
factor_t=2 if temperal_upsample else 1,
factor_s=2 if up_flag else 1,
)
else:
self.avg_shortcut = None
# Main path with residual blocks and upsample
upsamples = []
for _ in range(mult):
upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
in_dim = out_dim
# Add the final upsample block
if up_flag:
mode = "upsample3d" if temperal_upsample else "upsample2d"
upsamples.append(Resample(out_dim, mode=mode))
self.upsamples = nn.Sequential(*upsamples)
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
x_main = x
for module in self.upsamples:
x_main = module(x_main, feat_cache, feat_idx)
if self.avg_shortcut is not None:
x_shortcut = self.avg_shortcut(x, first_chunk)
return x_main + x_shortcut
else:
return x_main
class Encoder3d(nn.Module):
def __init__(
self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[True, True, False],
dropout=0.0,
):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_downsample = temperal_downsample
# dimensions
dims = [dim * u for u in [1] + dim_mult]
scale = 1.0
# init block
self.conv1 = CausalConv3d(12, dims[0], 3, padding=1)
# downsample blocks
downsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
t_down_flag = (
temperal_downsample[i]
if i < len(temperal_downsample) else False)
downsamples.append(
Down_ResidualBlock(
in_dim=in_dim,
out_dim=out_dim,
dropout=dropout,
mult=num_res_blocks,
temperal_downsample=t_down_flag,
down_flag=i != len(dim_mult) - 1,
))
scale /= 2.0
self.downsamples = nn.Sequential(*downsamples)
# middle blocks
self.middle = nn.Sequential(
ResidualBlock(out_dim, out_dim, dropout),
AttentionBlock(out_dim),
ResidualBlock(out_dim, out_dim, dropout),
)
# # output blocks
self.head = nn.Sequential(
RMS_norm(out_dim, images=False),
nn.SiLU(),
CausalConv3d(out_dim, z_dim, 3, padding=1),
)
def forward(self, x, feat_cache=None, feat_idx=[0]):
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
cache_x = torch.cat(
[
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device),
cache_x,
],
dim=2,
)
x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv1(x)
## downsamples
for layer in self.downsamples:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## middle
for layer in self.middle:
if isinstance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## head
for layer in self.head:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
cache_x = torch.cat(
[
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device),
cache_x,
],
dim=2,
)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x
class Decoder3d(nn.Module):
def __init__(
self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_upsample=[False, True, True],
dropout=0.0,
):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_upsample = temperal_upsample
# dimensions
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
# init block
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
# middle blocks
self.middle = nn.Sequential(
ResidualBlock(dims[0], dims[0], dropout),
AttentionBlock(dims[0]),
ResidualBlock(dims[0], dims[0], dropout),
)
# upsample blocks
upsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
t_up_flag = temperal_upsample[i] if i < len(
temperal_upsample) else False
upsamples.append(
Up_ResidualBlock(
in_dim=in_dim,
out_dim=out_dim,
dropout=dropout,
mult=num_res_blocks + 1,
temperal_upsample=t_up_flag,
up_flag=i != len(dim_mult) - 1,
))
self.upsamples = nn.Sequential(*upsamples)
# output blocks
self.head = nn.Sequential(
RMS_norm(out_dim, images=False),
nn.SiLU(),
CausalConv3d(out_dim, 12, 3, padding=1),
)
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
cache_x = torch.cat(
[
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device),
cache_x,
],
dim=2,
)
x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv1(x)
for layer in self.middle:
if isinstance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## upsamples
for layer in self.upsamples:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx, first_chunk)
else:
x = layer(x)
## head
for layer in self.head:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
cache_x = torch.cat(
[
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device),
cache_x,
],
dim=2,
)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x
def count_conv3d(model):
count = 0
for m in model.modules():
if isinstance(m, CausalConv3d):
count += 1
return count
class WanVAE(nn.Module):
def __init__(
self,
dim=160,
dec_dim=256,
z_dim=16,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[True, True, False],
dropout=0.0,
):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_downsample = temperal_downsample
self.temperal_upsample = temperal_downsample[::-1]
# modules
self.encoder = Encoder3d(
dim,
z_dim * 2,
dim_mult,
num_res_blocks,
attn_scales,
self.temperal_downsample,
dropout,
)
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
self.decoder = Decoder3d(
dec_dim,
z_dim,
dim_mult,
num_res_blocks,
attn_scales,
self.temperal_upsample,
dropout,
)
def encode(self, x):
self.clear_cache()
x = patchify(x, patch_size=2)
t = x.shape[2]
iter_ = 1 + (t - 1) // 4
for i in range(iter_):
self._enc_conv_idx = [0]
if i == 0:
out = self.encoder(
x[:, :, :1, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx,
)
else:
out_ = self.encoder(
x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx,
)
out = torch.cat([out, out_], 2)
mu, log_var = self.conv1(out).chunk(2, dim=1)
self.clear_cache()
return mu
def decode(self, z):
self.clear_cache()
iter_ = z.shape[2]
x = self.conv2(z)
for i in range(iter_):
self._conv_idx = [0]
if i == 0:
out = self.decoder(
x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx,
first_chunk=True,
)
else:
out_ = self.decoder(
x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx,
)
out = torch.cat([out, out_], 2)
out = unpatchify(out, patch_size=2)
self.clear_cache()
return out
def reparameterize(self, mu, log_var):
std = torch.exp(0.5 * log_var)
eps = torch.randn_like(std)
return eps * std + mu
def sample(self, imgs, deterministic=False):
mu, log_var = self.encode(imgs)
if deterministic:
return mu
std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
return mu + std * torch.randn_like(std)
def clear_cache(self):
self._conv_num = count_conv3d(self.decoder)
self._conv_idx = [0]
self._feat_map = [None] * self._conv_num
# cache encode
self._enc_conv_num = count_conv3d(self.encoder)
self._enc_conv_idx = [0]
self._enc_feat_map = [None] * self._enc_conv_num

View File

@@ -41,6 +41,8 @@ import comfy.ldm.hunyuan3d.model
import comfy.ldm.hidream.model
import comfy.ldm.chroma.model
import comfy.ldm.ace.model
import comfy.ldm.omnigen.omnigen2
import comfy.ldm.qwen_image.model
import comfy.model_management
import comfy.patcher_extension
@@ -105,10 +107,12 @@ def model_sampling(model_config, model_type):
return ModelSampling(model_config)
def convert_tensor(extra, dtype):
def convert_tensor(extra, dtype, device):
if hasattr(extra, "dtype"):
if extra.dtype != torch.int and extra.dtype != torch.long:
extra = extra.to(dtype)
extra = comfy.model_management.cast_to_device(extra, device, dtype)
else:
extra = comfy.model_management.cast_to_device(extra, device, None)
return extra
@@ -159,7 +163,7 @@ class BaseModel(torch.nn.Module):
xc = self.model_sampling.calculate_input(sigma, x)
if c_concat is not None:
xc = torch.cat([xc] + [c_concat], dim=1)
xc = torch.cat([xc] + [comfy.model_management.cast_to_device(c_concat, xc.device, xc.dtype)], dim=1)
context = c_crossattn
dtype = self.get_dtype()
@@ -168,20 +172,21 @@ class BaseModel(torch.nn.Module):
dtype = self.manual_cast_dtype
xc = xc.to(dtype)
device = xc.device
t = self.model_sampling.timestep(t).float()
if context is not None:
context = context.to(dtype)
context = comfy.model_management.cast_to_device(context, device, dtype)
extra_conds = {}
for o in kwargs:
extra = kwargs[o]
if hasattr(extra, "dtype"):
extra = convert_tensor(extra, dtype)
extra = convert_tensor(extra, dtype, device)
elif isinstance(extra, list):
ex = []
for ext in extra:
ex.append(convert_tensor(ext, dtype))
ex.append(convert_tensor(ext, dtype, device))
extra = ex
extra_conds[o] = extra
@@ -397,7 +402,7 @@ class SD21UNCLIP(BaseModel):
unclip_conditioning = kwargs.get("unclip_conditioning", None)
device = kwargs["device"]
if unclip_conditioning is None:
return torch.zeros((1, self.adm_channels))
return torch.zeros((1, self.adm_channels), device=device)
else:
return unclip_adm(unclip_conditioning, device, self.noise_augmentor, kwargs.get("unclip_noise_augment_merge", 0.05), kwargs.get("seed", 0) - 10)
@@ -611,9 +616,11 @@ class IP2P:
if image is None:
image = torch.zeros_like(noise)
else:
image = image.to(device=device)
if image.shape[1:] != noise.shape[1:]:
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
image = utils.common_upscale(image, noise.shape[-1], noise.shape[-2], "bilinear", "center")
image = utils.resize_to_batch_size(image, noise.shape[0])
return self.process_ip2p_image_in(image)
@@ -692,7 +699,7 @@ class StableCascade_B(BaseModel):
#size of prior doesn't really matter if zeros because it gets resized but I still want it to get batched
prior = kwargs.get("stable_cascade_prior", torch.zeros((1, 16, (noise.shape[2] * 4) // 42, (noise.shape[3] * 4) // 42), dtype=noise.dtype, layout=noise.layout, device=noise.device))
out["effnet"] = comfy.conds.CONDRegular(prior)
out["effnet"] = comfy.conds.CONDRegular(prior.to(device=noise.device))
out["sca"] = comfy.conds.CONDRegular(torch.zeros((1,)))
return out
@@ -815,6 +822,7 @@ class PixArt(BaseModel):
class Flux(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLUX, device=None, unet_model=comfy.ldm.flux.model.Flux):
super().__init__(model_config, model_type, device=device, unet_model=unet_model)
self.memory_usage_factor_conds = ("ref_latents",)
def concat_cond(self, **kwargs):
try:
@@ -875,8 +883,23 @@ class Flux(BaseModel):
guidance = kwargs.get("guidance", 3.5)
if guidance is not None:
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
ref_latents = kwargs.get("reference_latents", None)
if ref_latents is not None:
latents = []
for lat in ref_latents:
latents.append(self.process_latent_in(lat))
out['ref_latents'] = comfy.conds.CONDList(latents)
return out
def extra_conds_shapes(self, **kwargs):
out = {}
ref_latents = kwargs.get("reference_latents", None)
if ref_latents is not None:
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
return out
class GenmoMochi(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.genmo.joint_model.asymm_models_joint.AsymmDiTJoint)
@@ -1014,9 +1037,32 @@ class CosmosPredict2(BaseModel):
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
if denoise_mask is not None:
out["denoise_mask"] = comfy.conds.CONDRegular(denoise_mask)
out['fps'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", None))
return out
def process_timestep(self, timestep, x, denoise_mask=None, **kwargs):
if denoise_mask is None:
return timestep
if denoise_mask.ndim <= 4:
return timestep
condition_video_mask_B_1_T_1_1 = denoise_mask.mean(dim=[1, 3, 4], keepdim=True)
c_noise_B_1_T_1_1 = 0.0 * (1.0 - condition_video_mask_B_1_T_1_1) + timestep.reshape(timestep.shape[0], 1, 1, 1, 1) * condition_video_mask_B_1_T_1_1
out = c_noise_B_1_T_1_1.squeeze(dim=[1, 3, 4])
return out
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
sigma = sigma.reshape([sigma.shape[0]] + [1] * (len(noise.shape) - 1))
sigma_noise_augmentation = 0 #TODO
if sigma_noise_augmentation != 0:
latent_image = latent_image + noise
latent_image = self.model_sampling.calculate_input(torch.tensor([sigma_noise_augmentation], device=latent_image.device, dtype=latent_image.dtype), latent_image)
sigma = (sigma / (sigma + 1))
return latent_image / (1.0 - sigma)
class Lumina2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiT)
@@ -1057,8 +1103,9 @@ class WAN21(BaseModel):
image[:, i: i + 16] = self.process_latent_in(image[:, i: i + 16])
image = utils.resize_to_batch_size(image, noise.shape[0])
if not self.image_to_video or extra_channels == image.shape[1]:
return image
if extra_channels != image.shape[1] + 4:
if not self.image_to_video or extra_channels == image.shape[1]:
return image
if image.shape[1] > (extra_channels - 4):
image = image[:, :(extra_channels - 4)]
@@ -1117,10 +1164,10 @@ class WAN21_Vace(WAN21):
vace_frames_out = []
for j in range(len(vace_frames)):
vf = vace_frames[j].clone()
vf = vace_frames[j].to(device=noise.device, dtype=noise.dtype, copy=True)
for i in range(0, vf.shape[1], 16):
vf[:, i:i + 16] = self.process_latent_in(vf[:, i:i + 16])
vf = torch.cat([vf, mask[j]], dim=1)
vf = torch.cat([vf, mask[j].to(device=noise.device, dtype=noise.dtype)], dim=1)
vace_frames_out.append(vf)
vace_frames = torch.stack(vace_frames_out, dim=1)
@@ -1142,6 +1189,31 @@ class WAN21_Camera(WAN21):
out['camera_conditions'] = comfy.conds.CONDRegular(camera_conditions)
return out
class WAN22(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
self.image_to_video = image_to_video
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
if denoise_mask is not None:
out["denoise_mask"] = comfy.conds.CONDRegular(denoise_mask)
return out
def process_timestep(self, timestep, x, denoise_mask=None, **kwargs):
if denoise_mask is None:
return timestep
temp_ts = (torch.mean(denoise_mask[:, :, :, :, :], dim=(1, 3, 4), keepdim=True) * timestep.view([timestep.shape[0]] + [1] * (denoise_mask.ndim - 1))).reshape(timestep.shape[0], -1)
return temp_ts
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
return latent_image
class Hunyuan3Dv2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2)
@@ -1207,3 +1279,44 @@ class ACEStep(BaseModel):
out['speaker_embeds'] = comfy.conds.CONDRegular(torch.zeros(noise.shape[0], 512, device=noise.device, dtype=noise.dtype))
out['lyrics_strength'] = comfy.conds.CONDConstant(kwargs.get("lyrics_strength", 1.0))
return out
class Omnigen2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.omnigen.omnigen2.OmniGen2Transformer2DModel)
self.memory_usage_factor_conds = ("ref_latents",)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None:
if torch.numel(attention_mask) != attention_mask.sum():
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
out['num_tokens'] = comfy.conds.CONDConstant(max(1, torch.sum(attention_mask).item()))
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
ref_latents = kwargs.get("reference_latents", None)
if ref_latents is not None:
latents = []
for lat in ref_latents:
latents.append(self.process_latent_in(lat))
out['ref_latents'] = comfy.conds.CONDList(latents)
return out
def extra_conds_shapes(self, **kwargs):
out = {}
ref_latents = kwargs.get("reference_latents", None)
if ref_latents is not None:
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
return out
class QwenImage(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.qwen_image.model.QwenImageTransformer2DModel)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
return out

View File

@@ -346,7 +346,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config = {}
dit_config["image_model"] = "wan2.1"
dim = state_dict['{}head.modulation'.format(key_prefix)].shape[-1]
out_dim = state_dict['{}head.head.weight'.format(key_prefix)].shape[0] // 4
dit_config["dim"] = dim
dit_config["out_dim"] = out_dim
dit_config["num_heads"] = dim // 128
dit_config["ffn_dim"] = state_dict['{}blocks.0.ffn.0.weight'.format(key_prefix)].shape[0]
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.')
@@ -441,11 +443,16 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["rope_h_extrapolation_ratio"] = 4.0
dit_config["rope_w_extrapolation_ratio"] = 4.0
dit_config["rope_t_extrapolation_ratio"] = 1.0
elif dit_config["in_channels"] == 17:
dit_config["extra_per_block_abs_pos_emb"] = False
dit_config["rope_h_extrapolation_ratio"] = 3.0
dit_config["rope_w_extrapolation_ratio"] = 3.0
dit_config["rope_t_extrapolation_ratio"] = 1.0
elif dit_config["in_channels"] == 17: # img to video
if dit_config["model_channels"] == 2048:
dit_config["extra_per_block_abs_pos_emb"] = False
dit_config["rope_h_extrapolation_ratio"] = 3.0
dit_config["rope_w_extrapolation_ratio"] = 3.0
dit_config["rope_t_extrapolation_ratio"] = 1.0
elif dit_config["model_channels"] == 5120:
dit_config["rope_h_extrapolation_ratio"] = 2.0
dit_config["rope_w_extrapolation_ratio"] = 2.0
dit_config["rope_t_extrapolation_ratio"] = 0.8333333333333334
dit_config["extra_h_extrapolation_ratio"] = 1.0
dit_config["extra_w_extrapolation_ratio"] = 1.0
@@ -454,6 +461,31 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
return dit_config
if '{}time_caption_embed.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: # Omnigen2
dit_config = {}
dit_config["image_model"] = "omnigen2"
dit_config["axes_dim_rope"] = [40, 40, 40]
dit_config["axes_lens"] = [1024, 1664, 1664]
dit_config["ffn_dim_multiplier"] = None
dit_config["hidden_size"] = 2520
dit_config["in_channels"] = 16
dit_config["multiple_of"] = 256
dit_config["norm_eps"] = 1e-05
dit_config["num_attention_heads"] = 21
dit_config["num_kv_heads"] = 7
dit_config["num_layers"] = 32
dit_config["num_refiner_layers"] = 2
dit_config["out_channels"] = None
dit_config["patch_size"] = 2
dit_config["text_feat_dim"] = 2048
dit_config["timestep_scale"] = 1000.0
return dit_config
if '{}txt_norm.weight'.format(key_prefix) in state_dict_keys: # Qwen Image
dit_config = {}
dit_config["image_model"] = "qwen_image"
return dit_config
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
return None
@@ -840,7 +872,7 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""):
depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.')
hidden_size = state_dict["x_embedder.bias"].shape[0]
sd_map = comfy.utils.flux_to_diffusers({"depth": depth, "depth_single_blocks": depth_single_blocks, "hidden_size": hidden_size}, output_prefix=output_prefix)
elif 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict: #SD3
elif 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict and 'pos_embed.proj.weight' in state_dict: #SD3
num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
depth = state_dict["pos_embed.proj.weight"].shape[0] // 64
sd_map = comfy.utils.mmdit_to_diffusers({"depth": depth, "num_blocks": num_blocks}, output_prefix=output_prefix)

View File

@@ -101,7 +101,7 @@ if args.directml is not None:
lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default.
try:
import intel_extension_for_pytorch as ipex
import intel_extension_for_pytorch as ipex # noqa: F401
_ = torch.xpu.device_count()
xpu_available = xpu_available or torch.xpu.is_available()
except:
@@ -128,6 +128,11 @@ try:
except:
mlu_available = False
try:
ixuca_available = hasattr(torch, "corex")
except:
ixuca_available = False
if args.cpu:
cpu_state = CPUState.CPU
@@ -151,6 +156,12 @@ def is_mlu():
return True
return False
def is_ixuca():
global ixuca_available
if ixuca_available:
return True
return False
def get_torch_device():
global directml_enabled
global cpu_state
@@ -186,8 +197,9 @@ def get_total_memory(dev=None, torch_total_too=False):
elif is_intel_xpu():
stats = torch.xpu.memory_stats(dev)
mem_reserved = stats['reserved_bytes.all.current']
mem_total_xpu = torch.xpu.get_device_properties(dev).total_memory
mem_total_torch = mem_reserved
mem_total = torch.xpu.get_device_properties(dev).total_memory
mem_total = mem_total_xpu
elif is_ascend_npu():
stats = torch.npu.memory_stats(dev)
mem_reserved = stats['reserved_bytes.all.current']
@@ -288,7 +300,7 @@ try:
if torch_version_numeric[0] >= 2:
if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
ENABLE_PYTORCH_ATTENTION = True
if is_intel_xpu() or is_ascend_npu() or is_mlu():
if is_intel_xpu() or is_ascend_npu() or is_mlu() or is_ixuca():
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
ENABLE_PYTORCH_ATTENTION = True
except:
@@ -307,7 +319,10 @@ try:
logging.info("ROCm version: {}".format(rocm_version))
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx1201 and gfx950
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
ENABLE_PYTORCH_ATTENTION = True
if torch_version_numeric >= (2, 8):
if any((a in arch) for a in ["gfx1201"]):
ENABLE_PYTORCH_ATTENTION = True
if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4):
if any((a in arch) for a in ["gfx1201", "gfx942", "gfx950"]): # TODO: more arches
@@ -377,6 +392,8 @@ def get_torch_device_name(device):
except:
allocator_backend = ""
return "{} {} : {}".format(device, torch.cuda.get_device_name(device), allocator_backend)
elif device.type == "xpu":
return "{} {}".format(device, torch.xpu.get_device_name(device))
else:
return "{}".format(device.type)
elif is_intel_xpu():
@@ -512,6 +529,8 @@ WINDOWS = any(platform.win32_ver())
EXTRA_RESERVED_VRAM = 400 * 1024 * 1024
if WINDOWS:
EXTRA_RESERVED_VRAM = 600 * 1024 * 1024 #Windows is higher because of the shared vram issue
if total_vram > (15 * 1024): # more extra reserved vram on 16GB+ cards
EXTRA_RESERVED_VRAM += 100 * 1024 * 1024
if args.reserve_vram is not None:
EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024
@@ -876,6 +895,7 @@ def vae_dtype(device=None, allowed_dtypes=[]):
return d
# NOTE: bfloat16 seems to work on AMD for the VAE but is extremely slow in some cases compared to fp32
# slowness still a problem on pytorch nightly 2.9.0.dev20250720+rocm6.4 tested on RDNA3
if d == torch.bfloat16 and (not is_amd()) and should_use_bf16(device):
return d
@@ -929,7 +949,7 @@ def device_supports_non_blocking(device):
if is_device_mps(device):
return False #pytorch bug? mps doesn't support non blocking
if is_intel_xpu():
return False
return True
if args.deterministic: #TODO: figure out why deterministic breaks non blocking from gpu to cpu (previews)
return False
if directml_enabled:
@@ -968,6 +988,8 @@ def get_offload_stream(device):
stream_counter = (stream_counter + 1) % len(ss)
if is_device_cuda(device):
ss[stream_counter].wait_stream(torch.cuda.current_stream())
elif is_device_xpu(device):
ss[stream_counter].wait_stream(torch.xpu.current_stream())
stream_counters[device] = stream_counter
return s
elif is_device_cuda(device):
@@ -979,6 +1001,15 @@ def get_offload_stream(device):
stream_counter = (stream_counter + 1) % len(ss)
stream_counters[device] = stream_counter
return s
elif is_device_xpu(device):
ss = []
for k in range(NUM_STREAMS):
ss.append(torch.xpu.Stream(device=device, priority=0))
STREAMS[device] = ss
s = ss[stream_counter]
stream_counter = (stream_counter + 1) % len(ss)
stream_counters[device] = stream_counter
return s
return None
def sync_stream(device, stream):
@@ -986,6 +1017,8 @@ def sync_stream(device, stream):
return
if is_device_cuda(device):
torch.cuda.current_stream().wait_stream(stream)
elif is_device_xpu(device):
torch.xpu.current_stream().wait_stream(stream)
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None):
if device is None or weight.device == device:
@@ -1027,6 +1060,8 @@ def xformers_enabled():
return False
if is_mlu():
return False
if is_ixuca():
return False
if directml_enabled:
return False
return XFORMERS_IS_AVAILABLE
@@ -1062,6 +1097,8 @@ def pytorch_attention_flash_attention():
return True
if is_amd():
return True #if you have pytorch attention enabled on AMD it probably supports at least mem efficient attention
if is_ixuca():
return True
return False
def force_upcast_attention_dtype():
@@ -1092,8 +1129,8 @@ def get_free_memory(dev=None, torch_free_too=False):
stats = torch.xpu.memory_stats(dev)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_torch = mem_reserved - mem_active
mem_free_xpu = torch.xpu.get_device_properties(dev).total_memory - mem_reserved
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_xpu + mem_free_torch
elif is_ascend_npu():
stats = torch.npu.memory_stats(dev)
@@ -1142,6 +1179,9 @@ def is_device_cpu(device):
def is_device_mps(device):
return is_device_type(device, 'mps')
def is_device_xpu(device):
return is_device_type(device, 'xpu')
def is_device_cuda(device):
return is_device_type(device, 'cuda')
@@ -1173,7 +1213,10 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
return False
if is_intel_xpu():
return True
if torch_version_numeric < (2, 3):
return True
else:
return torch.xpu.get_device_properties(device).has_fp16
if is_ascend_npu():
return True
@@ -1181,6 +1224,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
if is_mlu():
return True
if is_ixuca():
return True
if torch.version.hip:
return True
@@ -1236,11 +1282,17 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
return False
if is_intel_xpu():
return True
if torch_version_numeric < (2, 6):
return True
else:
return torch.xpu.get_device_capability(device)['has_bfloat16_conversions']
if is_ascend_npu():
return True
if is_ixuca():
return True
if is_amd():
arch = torch.cuda.get_device_properties(device).gcnArchName
if any((a in arch) for a in ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]): # RDNA2 and older don't support bf16
@@ -1290,6 +1342,13 @@ def supports_fp8_compute(device=None):
return True
def extended_fp16_support():
# TODO: check why some models work with fp16 on newer torch versions but not on older
if torch_version_numeric < (2, 7):
return False
return True
def soft_empty_cache(force=False):
global cpu_state
if cpu_state == CPUState.MPS:

View File

@@ -379,6 +379,9 @@ class ModelPatcher:
def set_model_sampler_pre_cfg_function(self, pre_cfg_function, disable_cfg1_optimization=False):
self.model_options = set_model_options_pre_cfg_function(self.model_options, pre_cfg_function, disable_cfg1_optimization)
def set_model_sampler_calc_cond_batch_function(self, sampler_calc_cond_batch_function):
self.model_options["sampler_calc_cond_batch_function"] = sampler_calc_cond_batch_function
def set_model_unet_function_wrapper(self, unet_wrapper_function: UnetWrapperFunction):
self.model_options["model_function_wrapper"] = unet_wrapper_function

View File

@@ -336,9 +336,12 @@ class fp8_ops(manual_cast):
return None
def forward_comfy_cast_weights(self, input):
out = fp8_linear(self, input)
if out is not None:
return out
try:
out = fp8_linear(self, input)
if out is not None:
return out
except Exception as e:
logging.info("Exception during fp8 op: {}".format(e))
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.linear(input, weight, bias)

View File

@@ -89,7 +89,7 @@ def get_area_and_mult(conds, x_in, timestep_in):
conditioning = {}
model_conds = conds["model_conds"]
for c in model_conds:
conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area)
conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], area=area)
hooks = conds.get('hooks', None)
control = conds.get('control', None)
@@ -373,7 +373,11 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
uncond_ = uncond
conds = [cond, uncond_]
out = calc_cond_batch(model, conds, x, timestep, model_options)
if "sampler_calc_cond_batch_function" in model_options:
args = {"conds": conds, "input": x, "sigma": timestep, "model": model, "model_options": model_options}
out = model_options["sampler_calc_cond_batch_function"](args)
else:
out = calc_cond_batch(model, conds, x, timestep, model_options)
for fn in model_options.get("sampler_pre_cfg_function", []):
args = {"conds":conds, "conds_out": out, "cond_scale": cond_scale, "timestep": timestep,
@@ -716,7 +720,7 @@ KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_c
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
"ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
"gradient_estimation", "gradient_estimation_cfg_pp", "er_sde", "seeds_2", "seeds_3"]
"gradient_estimation", "gradient_estimation_cfg_pp", "er_sde", "seeds_2", "seeds_3", "sa_solver", "sa_solver_pece"]
class KSAMPLER(Sampler):
def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
@@ -1039,13 +1043,13 @@ class SchedulerHandler(NamedTuple):
use_ms: bool = True
SCHEDULER_HANDLERS = {
"normal": SchedulerHandler(normal_scheduler),
"simple": SchedulerHandler(simple_scheduler),
"sgm_uniform": SchedulerHandler(partial(normal_scheduler, sgm=True)),
"karras": SchedulerHandler(k_diffusion_sampling.get_sigmas_karras, use_ms=False),
"exponential": SchedulerHandler(k_diffusion_sampling.get_sigmas_exponential, use_ms=False),
"sgm_uniform": SchedulerHandler(partial(normal_scheduler, sgm=True)),
"simple": SchedulerHandler(simple_scheduler),
"ddim_uniform": SchedulerHandler(ddim_scheduler),
"beta": SchedulerHandler(beta_scheduler),
"normal": SchedulerHandler(normal_scheduler),
"linear_quadratic": SchedulerHandler(linear_quadratic_schedule),
"kl_optimal": SchedulerHandler(kl_optimal_scheduler, use_ms=False),
}

View File

@@ -14,10 +14,12 @@ import comfy.ldm.genmo.vae.model
import comfy.ldm.lightricks.vae.causal_video_autoencoder
import comfy.ldm.cosmos.vae
import comfy.ldm.wan.vae
import comfy.ldm.wan.vae2_2
import comfy.ldm.hunyuan3d.vae
import comfy.ldm.ace.vae.music_dcae_pipeline
import yaml
import math
import os
import comfy.utils
@@ -44,6 +46,8 @@ import comfy.text_encoders.lumina2
import comfy.text_encoders.wan
import comfy.text_encoders.hidream
import comfy.text_encoders.ace
import comfy.text_encoders.omnigen2
import comfy.text_encoders.qwen_image
import comfy.model_patcher
import comfy.lora
@@ -418,17 +422,30 @@ class VAE:
self.memory_used_encode = lambda shape, dtype: (50 * (round((shape[2] + 7) / 8) * 8) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
self.working_dtypes = [torch.bfloat16, torch.float32]
elif "decoder.middle.0.residual.0.gamma" in sd:
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
self.upscale_index_formula = (4, 8, 8)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
self.downscale_index_formula = (4, 8, 8)
self.latent_dim = 3
self.latent_channels = 16
ddconfig = {"dim": 96, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0}
self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig)
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: 7000 * shape[3] * shape[4] * (8 * 8) * model_management.dtype_size(dtype)
if "decoder.upsamples.0.upsamples.0.residual.2.weight" in sd: # Wan 2.2 VAE
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
self.upscale_index_formula = (4, 16, 16)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16)
self.downscale_index_formula = (4, 16, 16)
self.latent_dim = 3
self.latent_channels = 48
ddconfig = {"dim": 160, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0}
self.first_stage_model = comfy.ldm.wan.vae2_2.WanVAE(**ddconfig)
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
self.memory_used_encode = lambda shape, dtype: 3300 * shape[3] * shape[4] * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: 8000 * shape[3] * shape[4] * (16 * 16) * model_management.dtype_size(dtype)
else: # Wan 2.1 VAE
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
self.upscale_index_formula = (4, 8, 8)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
self.downscale_index_formula = (4, 8, 8)
self.latent_dim = 3
self.latent_channels = 16
ddconfig = {"dim": 96, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0}
self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig)
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: 7000 * shape[3] * shape[4] * (8 * 8) * model_management.dtype_size(dtype)
elif "geo_decoder.cross_attn_decoder.ln_1.bias" in sd:
self.latent_dim = 1
ln_post = "geo_decoder.ln_post.weight" in sd
@@ -754,6 +771,8 @@ class CLIPType(Enum):
HIDREAM = 14
CHROMA = 15
ACE = 16
OMNIGEN2 = 17
QWEN_IMAGE = 18
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
@@ -773,6 +792,8 @@ class TEModel(Enum):
LLAMA3_8 = 7
T5_XXL_OLD = 8
GEMMA_2_2B = 9
QWEN25_3B = 10
QWEN25_7B = 11
def detect_te_model(sd):
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
@@ -793,6 +814,12 @@ def detect_te_model(sd):
return TEModel.T5_BASE
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
return TEModel.GEMMA_2_2B
if 'model.layers.0.self_attn.k_proj.bias' in sd:
weight = sd['model.layers.0.self_attn.k_proj.bias']
if weight.shape[0] == 256:
return TEModel.QWEN25_3B
if weight.shape[0] == 512:
return TEModel.QWEN25_7B
if "model.layers.0.post_attention_layernorm.weight" in sd:
return TEModel.LLAMA3_8
return None
@@ -894,6 +921,12 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data),
clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None, t5xxl_scaled_fp8=None)
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
elif te_model == TEModel.QWEN25_3B:
clip_target.clip = comfy.text_encoders.omnigen2.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.omnigen2.Omnigen2Tokenizer
elif te_model == TEModel.QWEN25_7B:
clip_target.clip = comfy.text_encoders.qwen_image.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.qwen_image.QwenImageTokenizer
else:
# clip_l
if clip_type == CLIPType.SD3:
@@ -969,6 +1002,12 @@ def load_gligen(ckpt_path):
model = model.half()
return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())
def model_detection_error_hint(path, state_dict):
filename = os.path.basename(path)
if 'lora' in filename.lower():
return "\nHINT: This seems to be a Lora file and Lora files should be put in the lora folder and loaded with a lora loader node.."
return ""
def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, embedding_directory=None, state_dict=None, config=None):
logging.warning("Warning: The load checkpoint with config function is deprecated and will eventually be removed, please use the other one.")
model, clip, vae, _ = load_checkpoint_guess_config(ckpt_path, output_vae=output_vae, output_clip=output_clip, output_clipvision=False, embedding_directory=embedding_directory, output_model=True)
@@ -997,7 +1036,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True)
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata)
if out is None:
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd)))
return out
def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None):
@@ -1160,7 +1199,7 @@ def load_diffusion_model_state_dict(sd, model_options={}):
model.load_model_weights(new_sd, "")
left_over = sd.keys()
if len(left_over) > 0:
logging.info("left over keys in unet: {}".format(left_over))
logging.info("left over keys in diffusion model: {}".format(left_over))
return comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)
@@ -1168,8 +1207,8 @@ def load_diffusion_model(unet_path, model_options={}):
sd = comfy.utils.load_torch_file(unet_path)
model = load_diffusion_model_state_dict(sd, model_options=model_options)
if model is None:
logging.error("ERROR UNSUPPORTED UNET {}".format(unet_path))
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path))
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd)))
return model
def load_unet(unet_path, dtype=None):

View File

@@ -462,7 +462,7 @@ class SDTokenizer:
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args)
self.max_length = tokenizer_data.get("{}_max_length".format(embedding_key), max_length)
self.min_length = min_length
self.min_length = tokenizer_data.get("{}_min_length".format(embedding_key), min_length)
self.end_token = None
self.min_padding = min_padding
@@ -482,7 +482,8 @@ class SDTokenizer:
if end_token is not None:
self.end_token = end_token
else:
self.end_token = empty[0]
if has_end_token:
self.end_token = empty[0]
if pad_token is not None:
self.pad_token = pad_token

View File

@@ -18,7 +18,7 @@
"single_word": false
},
"errors": "replace",
"model_max_length": 77,
"model_max_length": 8192,
"name_or_path": "openai/clip-vit-large-patch14",
"pad_token": "<|endoftext|>",
"special_tokens_map_file": "./special_tokens_map.json",

View File

@@ -18,6 +18,8 @@ import comfy.text_encoders.cosmos
import comfy.text_encoders.lumina2
import comfy.text_encoders.wan
import comfy.text_encoders.ace
import comfy.text_encoders.omnigen2
import comfy.text_encoders.qwen_image
from . import supported_models_base
from . import latent_formats
@@ -1058,6 +1060,19 @@ class WAN21_Vace(WAN21_T2V):
out = model_base.WAN21_Vace(self, image_to_video=False, device=device)
return out
class WAN22_T2V(WAN21_T2V):
unet_config = {
"image_model": "wan2.1",
"model_type": "t2v",
"out_dim": 48,
}
latent_format = latent_formats.Wan22
def get_model(self, state_dict, prefix="", device=None):
out = model_base.WAN22(self, image_to_video=True, device=device)
return out
class Hunyuan3Dv2(supported_models_base.BASE):
unet_config = {
"image_model": "hunyuan3d2",
@@ -1181,6 +1196,70 @@ class ACEStep(supported_models_base.BASE):
def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(comfy.text_encoders.ace.AceT5Tokenizer, comfy.text_encoders.ace.AceT5Model)
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep]
class Omnigen2(supported_models_base.BASE):
unet_config = {
"image_model": "omnigen2",
}
sampling_settings = {
"multiplier": 1.0,
"shift": 2.6,
}
memory_usage_factor = 1.65 #TODO
unet_extra_config = {}
latent_format = latent_formats.Flux
supported_inference_dtypes = [torch.bfloat16, torch.float32]
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def __init__(self, unet_config):
super().__init__(unet_config)
if comfy.model_management.extended_fp16_support():
self.supported_inference_dtypes = [torch.float16] + self.supported_inference_dtypes
def get_model(self, state_dict, prefix="", device=None):
out = model_base.Omnigen2(self, device=device)
return out
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_3b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.omnigen2.Omnigen2Tokenizer, comfy.text_encoders.omnigen2.te(**hunyuan_detect))
class QwenImage(supported_models_base.BASE):
unet_config = {
"image_model": "qwen_image",
}
sampling_settings = {
"multiplier": 1.0,
"shift": 1.15,
}
memory_usage_factor = 1.8 #TODO
unet_extra_config = {}
latent_format = latent_formats.Wan21
supported_inference_dtypes = [torch.bfloat16, torch.float32]
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.QwenImage(self, device=device)
return out
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.qwen_image.QwenImageTokenizer, comfy.text_encoders.qwen_image.te(**hunyuan_detect))
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2, QwenImage]
models += [SVD_img2vid]

View File

@@ -24,6 +24,41 @@ class Llama2Config:
head_dim = 128
rms_norm_add = False
mlp_activation = "silu"
qkv_bias = False
@dataclass
class Qwen25_3BConfig:
vocab_size: int = 151936
hidden_size: int = 2048
intermediate_size: int = 11008
num_hidden_layers: int = 36
num_attention_heads: int = 16
num_key_value_heads: int = 2
max_position_embeddings: int = 128000
rms_norm_eps: float = 1e-6
rope_theta: float = 1000000.0
transformer_type: str = "llama"
head_dim = 128
rms_norm_add = False
mlp_activation = "silu"
qkv_bias = True
@dataclass
class Qwen25_7BVLI_Config:
vocab_size: int = 152064
hidden_size: int = 3584
intermediate_size: int = 18944
num_hidden_layers: int = 28
num_attention_heads: int = 28
num_key_value_heads: int = 4
max_position_embeddings: int = 128000
rms_norm_eps: float = 1e-6
rope_theta: float = 1000000.0
transformer_type: str = "llama"
head_dim = 128
rms_norm_add = False
mlp_activation = "silu"
qkv_bias = True
@dataclass
class Gemma2_2B_Config:
@@ -40,6 +75,7 @@ class Gemma2_2B_Config:
head_dim = 256
rms_norm_add = True
mlp_activation = "gelu_pytorch_tanh"
qkv_bias = False
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
@@ -98,9 +134,9 @@ class Attention(nn.Module):
self.inner_size = self.num_heads * self.head_dim
ops = ops or nn
self.q_proj = ops.Linear(config.hidden_size, self.inner_size, bias=False, device=device, dtype=dtype)
self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False, device=device, dtype=dtype)
self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False, device=device, dtype=dtype)
self.q_proj = ops.Linear(config.hidden_size, self.inner_size, bias=config.qkv_bias, device=device, dtype=dtype)
self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
self.o_proj = ops.Linear(self.inner_size, config.hidden_size, bias=False, device=device, dtype=dtype)
def forward(
@@ -320,6 +356,23 @@ class Llama2(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Qwen25_3B(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Qwen25_3BConfig(**config_dict)
self.num_layers = config.num_hidden_layers
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Qwen25_7BVLI_Config(**config_dict)
self.num_layers = config.num_hidden_layers
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Gemma2_2B(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):

View File

@@ -0,0 +1,44 @@
from transformers import Qwen2Tokenizer
from comfy import sd1_clip
import comfy.text_encoders.llama
import os
class Qwen25_3BTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='qwen25_3b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)
class Omnigen2Tokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen25_3b", tokenizer=Qwen25_3BTokenizer)
self.llama_template = '<|im_start|>system\nYou are a helpful assistant that generates high-quality images based on user instructions.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n'
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None,**kwargs):
if llama_template is None:
llama_text = self.llama_template.format(text)
else:
llama_text = llama_template.format(text)
return super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, **kwargs)
class Qwen25_3BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen25_3B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class Omnigen2Model(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, name="qwen25_3b", clip_model=Qwen25_3BModel, model_options=model_options)
def te(dtype_llama=None, llama_scaled_fp8=None):
class Omnigen2TEModel_(Omnigen2Model):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["scaled_fp8"] = llama_scaled_fp8
if dtype_llama is not None:
dtype = dtype_llama
super().__init__(device=device, dtype=dtype, model_options=model_options)
return Omnigen2TEModel_

View File

@@ -1,42 +1,42 @@
import os
from comfy import sd1_clip
import comfy.text_encoders.t5
import comfy.text_encoders.sd3_clip
from comfy.sd1_clip import gen_empty_tokens
from transformers import T5TokenizerFast
class T5XXLModel(comfy.text_encoders.sd3_clip.T5XXLModel):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def gen_empty_tokens(self, special_tokens, *args, **kwargs):
# PixArt expects the negative to be all pad tokens
special_tokens = special_tokens.copy()
special_tokens.pop("end")
return gen_empty_tokens(special_tokens, *args, **kwargs)
class PixArtT5XXL(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, name="t5xxl", clip_model=T5XXLModel, model_options=model_options)
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_data=tokenizer_data) # no padding
class PixArtTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
def pixart_te(dtype_t5=None, t5xxl_scaled_fp8=None):
class PixArtTEModel_(PixArtT5XXL):
def __init__(self, device="cpu", dtype=None, model_options={}):
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
if dtype is None:
dtype = dtype_t5
super().__init__(device=device, dtype=dtype, model_options=model_options)
return PixArtTEModel_
import os
from comfy import sd1_clip
import comfy.text_encoders.t5
import comfy.text_encoders.sd3_clip
from comfy.sd1_clip import gen_empty_tokens
from transformers import T5TokenizerFast
class T5XXLModel(comfy.text_encoders.sd3_clip.T5XXLModel):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def gen_empty_tokens(self, special_tokens, *args, **kwargs):
# PixArt expects the negative to be all pad tokens
special_tokens = special_tokens.copy()
special_tokens.pop("end")
return gen_empty_tokens(special_tokens, *args, **kwargs)
class PixArtT5XXL(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, name="t5xxl", clip_model=T5XXLModel, model_options=model_options)
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_data=tokenizer_data) # no padding
class PixArtTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
def pixart_te(dtype_t5=None, t5xxl_scaled_fp8=None):
class PixArtTEModel_(PixArtT5XXL):
def __init__(self, device="cpu", dtype=None, model_options={}):
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
if dtype is None:
dtype = dtype_t5
super().__init__(device=device, dtype=dtype, model_options=model_options)
return PixArtTEModel_

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,241 @@
{
"add_bos_token": false,
"add_prefix_space": false,
"added_tokens_decoder": {
"151643": {
"content": "<|endoftext|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151644": {
"content": "<|im_start|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151645": {
"content": "<|im_end|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151646": {
"content": "<|object_ref_start|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151647": {
"content": "<|object_ref_end|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151648": {
"content": "<|box_start|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151649": {
"content": "<|box_end|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151650": {
"content": "<|quad_start|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151651": {
"content": "<|quad_end|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151652": {
"content": "<|vision_start|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151653": {
"content": "<|vision_end|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151654": {
"content": "<|vision_pad|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151655": {
"content": "<|image_pad|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151656": {
"content": "<|video_pad|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151657": {
"content": "<tool_call>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151658": {
"content": "</tool_call>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151659": {
"content": "<|fim_prefix|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151660": {
"content": "<|fim_middle|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151661": {
"content": "<|fim_suffix|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151662": {
"content": "<|fim_pad|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151663": {
"content": "<|repo_name|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151664": {
"content": "<|file_sep|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151665": {
"content": "<|img|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151666": {
"content": "<|endofimg|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151667": {
"content": "<|meta|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151668": {
"content": "<|endofmeta|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
}
},
"additional_special_tokens": [
"<|im_start|>",
"<|im_end|>",
"<|object_ref_start|>",
"<|object_ref_end|>",
"<|box_start|>",
"<|box_end|>",
"<|quad_start|>",
"<|quad_end|>",
"<|vision_start|>",
"<|vision_end|>",
"<|vision_pad|>",
"<|image_pad|>",
"<|video_pad|>"
],
"bos_token": null,
"chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are a helpful assistant.' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n",
"clean_up_tokenization_spaces": false,
"eos_token": "<|im_end|>",
"errors": "replace",
"extra_special_tokens": {},
"model_max_length": 131072,
"pad_token": "<|endoftext|>",
"processor_class": "Qwen2_5_VLProcessor",
"split_special_tokens": false,
"tokenizer_class": "Qwen2Tokenizer",
"unk_token": null
}

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,71 @@
from transformers import Qwen2Tokenizer
from comfy import sd1_clip
import comfy.text_encoders.llama
import os
import torch
import numbers
class Qwen25_7BVLITokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=3584, embedding_key='qwen25_7b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)
class QwenImageTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen25_7b", tokenizer=Qwen25_7BVLITokenizer)
self.llama_template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None,**kwargs):
if llama_template is None:
llama_text = self.llama_template.format(text)
else:
llama_text = llama_template.format(text)
return super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, **kwargs)
class Qwen25_7BVLIModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen25_7BVLI, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class QwenImageTEModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, name="qwen25_7b", clip_model=Qwen25_7BVLIModel, model_options=model_options)
def encode_token_weights(self, token_weight_pairs):
out, pooled, extra = super().encode_token_weights(token_weight_pairs)
tok_pairs = token_weight_pairs["qwen25_7b"][0]
count_im_start = 0
for i, v in enumerate(tok_pairs):
elem = v[0]
if not torch.is_tensor(elem):
if isinstance(elem, numbers.Integral):
if elem == 151644 and count_im_start < 2:
template_end = i
count_im_start += 1
if out.shape[1] > (template_end + 3):
if tok_pairs[template_end + 1][0] == 872:
if tok_pairs[template_end + 2][0] == 198:
template_end += 3
out = out[:, template_end:]
extra["attention_mask"] = extra["attention_mask"][:, template_end:]
if extra["attention_mask"].sum() == torch.numel(extra["attention_mask"]):
extra.pop("attention_mask") # attention mask is useless if no masked elements
return out, pooled, extra
def te(dtype_llama=None, llama_scaled_fp8=None):
class QwenImageTEModel_(QwenImageTEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["scaled_fp8"] = llama_scaled_fp8
if dtype_llama is not None:
dtype = dtype_llama
super().__init__(device=device, dtype=dtype, model_options=model_options)
return QwenImageTEModel_

View File

@@ -146,7 +146,7 @@ class T5Attention(torch.nn.Module):
)
values = self.relative_attention_bias(relative_position_bucket, out_dtype=dtype) # shape (query_length, key_length, num_heads)
values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
return values
return values.contiguous()
def forward(self, x, mask=None, past_bias=None, optimized_attention=None):
q = self.q(x)

View File

@@ -31,6 +31,7 @@ from einops import rearrange
from comfy.cli_args import args
MMAP_TORCH_FILES = args.mmap_torch_files
DISABLE_MMAP = args.disable_mmap
ALWAYS_SAFE_LOAD = False
if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in pytorch 2.4, the unsafe path should be removed once earlier versions are deprecated
@@ -58,7 +59,10 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
sd = {}
for k in f.keys():
sd[k] = f.get_tensor(k)
tensor = f.get_tensor(k)
if DISABLE_MMAP: # TODO: Not sure if this is the best way to bypass the mmap issues
tensor = tensor.to(device=device, copy=True)
sd[k] = tensor
if return_metadata:
metadata = f.metadata()
except Exception as e:
@@ -77,6 +81,7 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
if safe_load or ALWAYS_SAFE_LOAD:
pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args)
else:
logging.warning("WARNING: loading {} unsafely, upgrade your pytorch to 2.4 or newer to load this file safely.".format(ckpt))
pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle)
if "state_dict" in pl_sd:
sd = pl_sd["state_dict"]
@@ -693,6 +698,26 @@ def resize_to_batch_size(tensor, batch_size):
return output
def resize_list_to_batch_size(l, batch_size):
in_batch_size = len(l)
if in_batch_size == batch_size or in_batch_size == 0:
return l
if batch_size <= 1:
return l[:batch_size]
output = []
if batch_size < in_batch_size:
scale = (in_batch_size - 1) / (batch_size - 1)
for i in range(batch_size):
output.append(l[min(round(i * scale), in_batch_size - 1)])
else:
scale = in_batch_size / batch_size
for i in range(batch_size):
output.append(l[min(math.floor((i + 0.5) * scale), in_batch_size - 1)])
return output
def convert_sd_to(state_dict, dtype):
keys = list(state_dict.keys())
for k in keys:

View File

@@ -15,9 +15,20 @@ adapters: list[type[WeightAdapterBase]] = [
OFTAdapter,
BOFTAdapter,
]
adapter_maps: dict[str, type[WeightAdapterBase]] = {
"LoRA": LoRAAdapter,
"LoHa": LoHaAdapter,
"LoKr": LoKrAdapter,
"OFT": OFTAdapter,
## We disable not implemented algo for now
# "GLoRA": GLoRAAdapter,
# "BOFT": BOFTAdapter,
}
__all__ = [
"WeightAdapterBase",
"WeightAdapterTrainBase",
"adapters"
"adapters",
"adapter_maps",
] + [a.__name__ for a in adapters]

View File

@@ -133,3 +133,43 @@ def tucker_weight_from_conv(up, down, mid):
def tucker_weight(wa, wb, t):
temp = torch.einsum("i j ..., j r -> i r ...", t, wb)
return torch.einsum("i j ..., i r -> r j ...", temp, wa)
def factorization(dimension: int, factor: int = -1) -> tuple[int, int]:
"""
return a tuple of two value of input dimension decomposed by the number closest to factor
second value is higher or equal than first value.
examples)
factor
-1 2 4 8 16 ...
127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127
128 -> 8, 16 128 -> 2, 64 128 -> 4, 32 128 -> 8, 16 128 -> 8, 16
250 -> 10, 25 250 -> 2, 125 250 -> 2, 125 250 -> 5, 50 250 -> 10, 25
360 -> 8, 45 360 -> 2, 180 360 -> 4, 90 360 -> 8, 45 360 -> 12, 30
512 -> 16, 32 512 -> 2, 256 512 -> 4, 128 512 -> 8, 64 512 -> 16, 32
1024 -> 32, 32 1024 -> 2, 512 1024 -> 4, 256 1024 -> 8, 128 1024 -> 16, 64
"""
if factor > 0 and (dimension % factor) == 0 and dimension >= factor**2:
m = factor
n = dimension // factor
if m > n:
n, m = m, n
return m, n
if factor < 0:
factor = dimension
m, n = 1, dimension
length = m + n
while m < n:
new_m = m + 1
while dimension % new_m != 0:
new_m += 1
new_n = dimension // new_m
if new_m + new_n > length or new_m > factor:
break
else:
m, n = new_m, new_n
if m > n:
n, m = m, n
return m, n

View File

@@ -3,7 +3,120 @@ from typing import Optional
import torch
import comfy.model_management
from .base import WeightAdapterBase, weight_decompose
from .base import WeightAdapterBase, WeightAdapterTrainBase, weight_decompose
class HadaWeight(torch.autograd.Function):
@staticmethod
def forward(ctx, w1u, w1d, w2u, w2d, scale=torch.tensor(1)):
ctx.save_for_backward(w1d, w1u, w2d, w2u, scale)
diff_weight = ((w1u @ w1d) * (w2u @ w2d)) * scale
return diff_weight
@staticmethod
def backward(ctx, grad_out):
(w1d, w1u, w2d, w2u, scale) = ctx.saved_tensors
grad_out = grad_out * scale
temp = grad_out * (w2u @ w2d)
grad_w1u = temp @ w1d.T
grad_w1d = w1u.T @ temp
temp = grad_out * (w1u @ w1d)
grad_w2u = temp @ w2d.T
grad_w2d = w2u.T @ temp
del temp
return grad_w1u, grad_w1d, grad_w2u, grad_w2d, None
class HadaWeightTucker(torch.autograd.Function):
@staticmethod
def forward(ctx, t1, w1u, w1d, t2, w2u, w2d, scale=torch.tensor(1)):
ctx.save_for_backward(t1, w1d, w1u, t2, w2d, w2u, scale)
rebuild1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1d, w1u)
rebuild2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2d, w2u)
return rebuild1 * rebuild2 * scale
@staticmethod
def backward(ctx, grad_out):
(t1, w1d, w1u, t2, w2d, w2u, scale) = ctx.saved_tensors
grad_out = grad_out * scale
temp = torch.einsum("i j ..., j r -> i r ...", t2, w2d)
rebuild = torch.einsum("i j ..., i r -> r j ...", temp, w2u)
grad_w = rebuild * grad_out
del rebuild
grad_w1u = torch.einsum("r j ..., i j ... -> r i", temp, grad_w)
grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w1u.T)
del grad_w, temp
grad_w1d = torch.einsum("i r ..., i j ... -> r j", t1, grad_temp)
grad_t1 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w1d.T)
del grad_temp
temp = torch.einsum("i j ..., j r -> i r ...", t1, w1d)
rebuild = torch.einsum("i j ..., i r -> r j ...", temp, w1u)
grad_w = rebuild * grad_out
del rebuild
grad_w2u = torch.einsum("r j ..., i j ... -> r i", temp, grad_w)
grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w2u.T)
del grad_w, temp
grad_w2d = torch.einsum("i r ..., i j ... -> r j", t2, grad_temp)
grad_t2 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w2d.T)
del grad_temp
return grad_t1, grad_w1u, grad_w1d, grad_t2, grad_w2u, grad_w2d, None
class LohaDiff(WeightAdapterTrainBase):
def __init__(self, weights):
super().__init__()
# Unpack weights tuple from LoHaAdapter
w1a, w1b, alpha, w2a, w2b, t1, t2, _ = weights
# Create trainable parameters
self.hada_w1_a = torch.nn.Parameter(w1a)
self.hada_w1_b = torch.nn.Parameter(w1b)
self.hada_w2_a = torch.nn.Parameter(w2a)
self.hada_w2_b = torch.nn.Parameter(w2b)
self.use_tucker = False
if t1 is not None and t2 is not None:
self.use_tucker = True
self.hada_t1 = torch.nn.Parameter(t1)
self.hada_t2 = torch.nn.Parameter(t2)
else:
# Keep the attributes for consistent access
self.hada_t1 = None
self.hada_t2 = None
# Store rank and non-trainable alpha
self.rank = w1b.shape[0]
self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False)
def __call__(self, w):
org_dtype = w.dtype
scale = self.alpha / self.rank
if self.use_tucker:
diff_weight = HadaWeightTucker.apply(self.hada_t1, self.hada_w1_a, self.hada_w1_b, self.hada_t2, self.hada_w2_a, self.hada_w2_b, scale)
else:
diff_weight = HadaWeight.apply(self.hada_w1_a, self.hada_w1_b, self.hada_w2_a, self.hada_w2_b, scale)
# Add the scaled difference to the original weight
weight = w.to(diff_weight) + diff_weight.reshape(w.shape)
return weight.to(org_dtype)
def passive_memory_usage(self):
"""Calculates memory usage of the trainable parameters."""
return sum(param.numel() * param.element_size() for param in self.parameters())
class LoHaAdapter(WeightAdapterBase):
@@ -13,6 +126,25 @@ class LoHaAdapter(WeightAdapterBase):
self.loaded_keys = loaded_keys
self.weights = weights
@classmethod
def create_train(cls, weight, rank=1, alpha=1.0):
out_dim = weight.shape[0]
in_dim = weight.shape[1:].numel()
mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype)
mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype)
torch.nn.init.normal_(mat1, 0.1)
torch.nn.init.constant_(mat2, 0.0)
mat3 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype)
mat4 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype)
torch.nn.init.normal_(mat3, 0.1)
torch.nn.init.normal_(mat4, 0.01)
return LohaDiff(
(mat1, mat2, alpha, mat3, mat4, None, None, None)
)
def to_train(self):
return LohaDiff(self.weights)
@classmethod
def load(
cls,

View File

@@ -3,7 +3,77 @@ from typing import Optional
import torch
import comfy.model_management
from .base import WeightAdapterBase, weight_decompose
from .base import (
WeightAdapterBase,
WeightAdapterTrainBase,
weight_decompose,
factorization,
)
class LokrDiff(WeightAdapterTrainBase):
def __init__(self, weights):
super().__init__()
(lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale) = weights
self.use_tucker = False
if lokr_w1_a is not None:
_, rank_a = lokr_w1_a.shape[0], lokr_w1_a.shape[1]
rank_a, _ = lokr_w1_b.shape[0], lokr_w1_b.shape[1]
self.lokr_w1_a = torch.nn.Parameter(lokr_w1_a)
self.lokr_w1_b = torch.nn.Parameter(lokr_w1_b)
self.w1_rebuild = True
self.ranka = rank_a
if lokr_w2_a is not None:
_, rank_b = lokr_w2_a.shape[0], lokr_w2_a.shape[1]
rank_b, _ = lokr_w2_b.shape[0], lokr_w2_b.shape[1]
self.lokr_w2_a = torch.nn.Parameter(lokr_w2_a)
self.lokr_w2_b = torch.nn.Parameter(lokr_w2_b)
if lokr_t2 is not None:
self.use_tucker = True
self.lokr_t2 = torch.nn.Parameter(lokr_t2)
self.w2_rebuild = True
self.rankb = rank_b
if lokr_w1 is not None:
self.lokr_w1 = torch.nn.Parameter(lokr_w1)
self.w1_rebuild = False
if lokr_w2 is not None:
self.lokr_w2 = torch.nn.Parameter(lokr_w2)
self.w2_rebuild = False
self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False)
@property
def w1(self):
if self.w1_rebuild:
return (self.lokr_w1_a @ self.lokr_w1_b) * (self.alpha / self.ranka)
else:
return self.lokr_w1
@property
def w2(self):
if self.w2_rebuild:
if self.use_tucker:
w2 = torch.einsum(
'i j k l, j r, i p -> p r k l',
self.lokr_t2,
self.lokr_w2_b,
self.lokr_w2_a
)
else:
w2 = self.lokr_w2_a @ self.lokr_w2_b
return w2 * (self.alpha / self.rankb)
else:
return self.lokr_w2
def __call__(self, w):
diff = torch.kron(self.w1, self.w2)
return w + diff.reshape(w.shape).to(w)
def passive_memory_usage(self):
return sum(param.numel() * param.element_size() for param in self.parameters())
class LoKrAdapter(WeightAdapterBase):
@@ -13,6 +83,20 @@ class LoKrAdapter(WeightAdapterBase):
self.loaded_keys = loaded_keys
self.weights = weights
@classmethod
def create_train(cls, weight, rank=1, alpha=1.0):
out_dim = weight.shape[0]
in_dim = weight.shape[1:].numel()
out1, out2 = factorization(out_dim, rank)
in1, in2 = factorization(in_dim, rank)
mat1 = torch.empty(out1, in1, device=weight.device, dtype=weight.dtype)
mat2 = torch.empty(out2, in2, device=weight.device, dtype=weight.dtype)
torch.nn.init.kaiming_uniform_(mat2, a=5**0.5)
torch.nn.init.constant_(mat1, 0.0)
return LokrDiff(
(mat1, mat2, alpha, None, None, None, None, None, None)
)
@classmethod
def load(
cls,

View File

@@ -3,7 +3,58 @@ from typing import Optional
import torch
import comfy.model_management
from .base import WeightAdapterBase, weight_decompose
from .base import WeightAdapterBase, WeightAdapterTrainBase, weight_decompose, factorization
class OFTDiff(WeightAdapterTrainBase):
def __init__(self, weights):
super().__init__()
# Unpack weights tuple from LoHaAdapter
blocks, rescale, alpha, _ = weights
# Create trainable parameters
self.oft_blocks = torch.nn.Parameter(blocks)
if rescale is not None:
self.rescale = torch.nn.Parameter(rescale)
self.rescaled = True
else:
self.rescaled = False
self.block_num, self.block_size, _ = blocks.shape
self.constraint = float(alpha)
self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False)
def __call__(self, w):
org_dtype = w.dtype
I = torch.eye(self.block_size, device=self.oft_blocks.device)
## generate r
# for Q = -Q^T
q = self.oft_blocks - self.oft_blocks.transpose(1, 2)
normed_q = q
if self.constraint:
q_norm = torch.norm(q) + 1e-8
if q_norm > self.constraint:
normed_q = q * self.constraint / q_norm
# use float() to prevent unsupported type
r = (I + normed_q) @ (I - normed_q).float().inverse()
## Apply chunked matmul on weight
_, *shape = w.shape
org_weight = w.to(dtype=r.dtype)
org_weight = org_weight.unflatten(0, (self.block_num, self.block_size))
# Init R=0, so add I on it to ensure the output of step0 is original model output
weight = torch.einsum(
"k n m, k n ... -> k m ...",
r,
org_weight,
).flatten(0, 1)
if self.rescaled:
weight = self.rescale * weight
return weight.to(org_dtype)
def passive_memory_usage(self):
"""Calculates memory usage of the trainable parameters."""
return sum(param.numel() * param.element_size() for param in self.parameters())
class OFTAdapter(WeightAdapterBase):
@@ -13,6 +64,18 @@ class OFTAdapter(WeightAdapterBase):
self.loaded_keys = loaded_keys
self.weights = weights
@classmethod
def create_train(cls, weight, rank=1, alpha=1.0):
out_dim = weight.shape[0]
block_size, block_num = factorization(out_dim, rank)
block = torch.zeros(block_num, block_size, block_size, device=weight.device, dtype=weight.dtype)
return OFTDiff(
(block, None, alpha, None)
)
def to_train(self):
return OFTDiff(self.weights)
@classmethod
def load(
cls,
@@ -60,6 +123,8 @@ class OFTAdapter(WeightAdapterBase):
blocks = v[0]
rescale = v[1]
alpha = v[2]
if alpha is None:
alpha = 0
dora_scale = v[3]
blocks = comfy.model_management.cast_to_device(blocks, weight.device, intermediate_dtype)

View File

@@ -0,0 +1,69 @@
"""
Feature flags module for ComfyUI WebSocket protocol negotiation.
This module handles capability negotiation between frontend and backend,
allowing graceful protocol evolution while maintaining backward compatibility.
"""
from typing import Any, Dict
from comfy.cli_args import args
# Default server capabilities
SERVER_FEATURE_FLAGS: Dict[str, Any] = {
"supports_preview_metadata": True,
"max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes
}
def get_connection_feature(
sockets_metadata: Dict[str, Dict[str, Any]],
sid: str,
feature_name: str,
default: Any = False
) -> Any:
"""
Get a feature flag value for a specific connection.
Args:
sockets_metadata: Dictionary of socket metadata
sid: Session ID of the connection
feature_name: Name of the feature to check
default: Default value if feature not found
Returns:
Feature value or default if not found
"""
if sid not in sockets_metadata:
return default
return sockets_metadata[sid].get("feature_flags", {}).get(feature_name, default)
def supports_feature(
sockets_metadata: Dict[str, Dict[str, Any]],
sid: str,
feature_name: str
) -> bool:
"""
Check if a connection supports a specific feature.
Args:
sockets_metadata: Dictionary of socket metadata
sid: Session ID of the connection
feature_name: Name of the feature to check
Returns:
Boolean indicating if feature is supported
"""
return get_connection_feature(sockets_metadata, sid, feature_name, False) is True
def get_server_features() -> Dict[str, Any]:
"""
Get the server's feature flags.
Returns:
Dictionary of server feature flags
"""
return SERVER_FEATURE_FLAGS.copy()

View File

@@ -0,0 +1,86 @@
#!/usr/bin/env python3
"""
Script to generate .pyi stub files for the synchronous API wrappers.
This allows generating stubs without running the full ComfyUI application.
"""
import os
import sys
import logging
import importlib
# Add ComfyUI to path so we can import modules
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from comfy_api.internal.async_to_sync import AsyncToSyncConverter
from comfy_api.version_list import supported_versions
def generate_stubs_for_module(module_name: str) -> None:
"""Generate stub files for a specific module that exports ComfyAPI and ComfyAPISync."""
try:
# Import the module
module = importlib.import_module(module_name)
# Check if module has ComfyAPISync (the sync wrapper)
if hasattr(module, "ComfyAPISync"):
# Module already has a sync class
api_class = getattr(module, "ComfyAPI", None)
sync_class = getattr(module, "ComfyAPISync")
if api_class:
# Generate the stub file
AsyncToSyncConverter.generate_stub_file(api_class, sync_class)
logging.info(f"Generated stub file for {module_name}")
else:
logging.warning(
f"Module {module_name} has ComfyAPISync but no ComfyAPI"
)
elif hasattr(module, "ComfyAPI"):
# Module only has async API, need to create sync wrapper first
from comfy_api.internal.async_to_sync import create_sync_class
api_class = getattr(module, "ComfyAPI")
sync_class = create_sync_class(api_class)
# Generate the stub file
AsyncToSyncConverter.generate_stub_file(api_class, sync_class)
logging.info(f"Generated stub file for {module_name}")
else:
logging.warning(
f"Module {module_name} does not export ComfyAPI or ComfyAPISync"
)
except Exception as e:
logging.error(f"Failed to generate stub for {module_name}: {e}")
import traceback
traceback.print_exc()
def main():
"""Main function to generate all API stub files."""
logging.basicConfig(level=logging.INFO)
logging.info("Starting stub generation...")
# Dynamically get module names from supported_versions
api_modules = []
for api_class in supported_versions:
# Extract module name from the class
module_name = api_class.__module__
if module_name not in api_modules:
api_modules.append(module_name)
logging.info(f"Found {len(api_modules)} API modules: {api_modules}")
# Generate stubs for each module
for module_name in api_modules:
generate_stubs_for_module(module_name)
logging.info("Stub generation complete!")
if __name__ == "__main__":
main()

View File

@@ -1,8 +1,16 @@
from .basic_types import ImageInput, AudioInput
from .video_types import VideoInput
# This file only exists for backwards compatibility.
from comfy_api.latest._input import (
ImageInput,
AudioInput,
MaskInput,
LatentInput,
VideoInput,
)
__all__ = [
"ImageInput",
"AudioInput",
"MaskInput",
"LatentInput",
"VideoInput",
]

View File

@@ -1,20 +1,14 @@
import torch
from typing import TypedDict
ImageInput = torch.Tensor
"""
An image in format [B, H, W, C] where B is the batch size, C is the number of channels,
"""
class AudioInput(TypedDict):
"""
TypedDict representing audio input.
"""
waveform: torch.Tensor
"""
Tensor in the format [B, C, T] where B is the batch size, C is the number of channels,
"""
sample_rate: int
# This file only exists for backwards compatibility.
from comfy_api.latest._input.basic_types import (
ImageInput,
AudioInput,
MaskInput,
LatentInput,
)
__all__ = [
"ImageInput",
"AudioInput",
"MaskInput",
"LatentInput",
]

View File

@@ -1,55 +1,6 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Optional
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
# This file only exists for backwards compatibility.
from comfy_api.latest._input.video_types import VideoInput
class VideoInput(ABC):
"""
Abstract base class for video input types.
"""
@abstractmethod
def get_components(self) -> VideoComponents:
"""
Abstract method to get the video components (images, audio, and frame rate).
Returns:
VideoComponents containing images, audio, and frame rate
"""
pass
@abstractmethod
def save_to(
self,
path: str,
format: VideoContainer = VideoContainer.AUTO,
codec: VideoCodec = VideoCodec.AUTO,
metadata: Optional[dict] = None
):
"""
Abstract method to save the video input to a file.
"""
pass
# Provide a default implementation, but subclasses can provide optimized versions
# if possible.
def get_dimensions(self) -> tuple[int, int]:
"""
Returns the dimensions of the video input.
Returns:
Tuple of (width, height)
"""
components = self.get_components()
return components.images.shape[2], components.images.shape[1]
def get_duration(self) -> float:
"""
Returns the duration of the video in seconds.
Returns:
Duration in seconds
"""
components = self.get_components()
frame_count = components.images.shape[0]
return float(frame_count / components.frame_rate)
__all__ = [
"VideoInput",
]

View File

@@ -1,7 +1,7 @@
from .video_types import VideoFromFile, VideoFromComponents
# This file only exists for backwards compatibility.
from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents
__all__ = [
# Implementations
"VideoFromFile",
"VideoFromComponents",
]

View File

@@ -1,303 +1,2 @@
from __future__ import annotations
from av.container import InputContainer
from av.subtitles.stream import SubtitleStream
from fractions import Fraction
from typing import Optional
from comfy_api.input import AudioInput
import av
import io
import json
import numpy as np
import torch
from comfy_api.input import VideoInput
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
def container_to_output_format(container_format: str | None) -> str | None:
"""
A container's `format` may be a comma-separated list of formats.
E.g., iso container's `format` may be `mov,mp4,m4a,3gp,3g2,mj2`.
However, writing to a file/stream with `av.open` requires a single format,
or `None` to auto-detect.
"""
if not container_format:
return None # Auto-detect
if "," not in container_format:
return container_format
formats = container_format.split(",")
return formats[0]
def get_open_write_kwargs(
dest: str | io.BytesIO, container_format: str, to_format: str | None
) -> dict:
"""Get kwargs for writing a `VideoFromFile` to a file/stream with `av.open`"""
open_kwargs = {
"mode": "w",
# If isobmff, preserve custom metadata tags (workflow, prompt, extra_pnginfo)
"options": {"movflags": "use_metadata_tags"},
}
is_write_to_buffer = isinstance(dest, io.BytesIO)
if is_write_to_buffer:
# Set output format explicitly, since it cannot be inferred from file extension
if to_format == VideoContainer.AUTO:
to_format = container_format.lower()
elif isinstance(to_format, str):
to_format = to_format.lower()
open_kwargs["format"] = container_to_output_format(to_format)
return open_kwargs
class VideoFromFile(VideoInput):
"""
Class representing video input from a file.
"""
def __init__(self, file: str | io.BytesIO):
"""
Initialize the VideoFromFile object based off of either a path on disk or a BytesIO object
containing the file contents.
"""
self.__file = file
def get_dimensions(self) -> tuple[int, int]:
"""
Returns the dimensions of the video input.
Returns:
Tuple of (width, height)
"""
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0) # Reset the BytesIO object to the beginning
with av.open(self.__file, mode='r') as container:
for stream in container.streams:
if stream.type == 'video':
assert isinstance(stream, av.VideoStream)
return stream.width, stream.height
raise ValueError(f"No video stream found in file '{self.__file}'")
def get_duration(self) -> float:
"""
Returns the duration of the video in seconds.
Returns:
Duration in seconds
"""
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0)
with av.open(self.__file, mode="r") as container:
if container.duration is not None:
return float(container.duration / av.time_base)
# Fallback: calculate from frame count and frame rate
video_stream = next(
(s for s in container.streams if s.type == "video"), None
)
if video_stream and video_stream.frames and video_stream.average_rate:
return float(video_stream.frames / video_stream.average_rate)
# Last resort: decode frames to count them
if video_stream and video_stream.average_rate:
frame_count = 0
container.seek(0)
for packet in container.demux(video_stream):
for _ in packet.decode():
frame_count += 1
if frame_count > 0:
return float(frame_count / video_stream.average_rate)
raise ValueError(f"Could not determine duration for file '{self.__file}'")
def get_components_internal(self, container: InputContainer) -> VideoComponents:
# Get video frames
frames = []
for frame in container.decode(video=0):
img = frame.to_ndarray(format='rgb24') # shape: (H, W, 3)
img = torch.from_numpy(img) / 255.0 # shape: (H, W, 3)
frames.append(img)
images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 3, 0, 0)
# Get frame rate
video_stream = next(s for s in container.streams if s.type == 'video')
frame_rate = Fraction(video_stream.average_rate) if video_stream and video_stream.average_rate else Fraction(1)
# Get audio if available
audio = None
try:
container.seek(0) # Reset the container to the beginning
for stream in container.streams:
if stream.type != 'audio':
continue
assert isinstance(stream, av.AudioStream)
audio_frames = []
for packet in container.demux(stream):
for frame in packet.decode():
assert isinstance(frame, av.AudioFrame)
audio_frames.append(frame.to_ndarray()) # shape: (channels, samples)
if len(audio_frames) > 0:
audio_data = np.concatenate(audio_frames, axis=1) # shape: (channels, total_samples)
audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) # shape: (1, channels, total_samples)
audio = AudioInput({
"waveform": audio_tensor,
"sample_rate": int(stream.sample_rate) if stream.sample_rate else 1,
})
except StopIteration:
pass # No audio stream
metadata = container.metadata
return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata)
def get_components(self) -> VideoComponents:
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0) # Reset the BytesIO object to the beginning
with av.open(self.__file, mode='r') as container:
return self.get_components_internal(container)
raise ValueError(f"No video stream found in file '{self.__file}'")
def save_to(
self,
path: str | io.BytesIO,
format: VideoContainer = VideoContainer.AUTO,
codec: VideoCodec = VideoCodec.AUTO,
metadata: Optional[dict] = None
):
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0) # Reset the BytesIO object to the beginning
with av.open(self.__file, mode='r') as container:
container_format = container.format.name
video_encoding = container.streams.video[0].codec.name if len(container.streams.video) > 0 else None
reuse_streams = True
if format != VideoContainer.AUTO and format not in container_format.split(","):
reuse_streams = False
if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None:
reuse_streams = False
if not reuse_streams:
components = self.get_components_internal(container)
video = VideoFromComponents(components)
return video.save_to(
path,
format=format,
codec=codec,
metadata=metadata
)
streams = container.streams
open_kwargs = get_open_write_kwargs(path, container_format, format)
with av.open(path, **open_kwargs) as output_container:
# Copy over the original metadata
for key, value in container.metadata.items():
if metadata is None or key not in metadata:
output_container.metadata[key] = value
# Add our new metadata
if metadata is not None:
for key, value in metadata.items():
if isinstance(value, str):
output_container.metadata[key] = value
else:
output_container.metadata[key] = json.dumps(value)
# Add streams to the new container
stream_map = {}
for stream in streams:
if isinstance(stream, (av.VideoStream, av.AudioStream, SubtitleStream)):
out_stream = output_container.add_stream_from_template(template=stream, opaque=True)
stream_map[stream] = out_stream
# Write packets to the new container
for packet in container.demux():
if packet.stream in stream_map and packet.dts is not None:
packet.stream = stream_map[packet.stream]
output_container.mux(packet)
class VideoFromComponents(VideoInput):
"""
Class representing video input from tensors.
"""
def __init__(self, components: VideoComponents):
self.__components = components
def get_components(self) -> VideoComponents:
return VideoComponents(
images=self.__components.images,
audio=self.__components.audio,
frame_rate=self.__components.frame_rate
)
def save_to(
self,
path: str,
format: VideoContainer = VideoContainer.AUTO,
codec: VideoCodec = VideoCodec.AUTO,
metadata: Optional[dict] = None
):
if format != VideoContainer.AUTO and format != VideoContainer.MP4:
raise ValueError("Only MP4 format is supported for now")
if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
raise ValueError("Only H264 codec is supported for now")
with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}) as output:
# Add metadata before writing any streams
if metadata is not None:
for key, value in metadata.items():
output.metadata[key] = json.dumps(value)
frame_rate = Fraction(round(self.__components.frame_rate * 1000), 1000)
# Create a video stream
video_stream = output.add_stream('h264', rate=frame_rate)
video_stream.width = self.__components.images.shape[2]
video_stream.height = self.__components.images.shape[1]
video_stream.pix_fmt = 'yuv420p'
# Create an audio stream
audio_sample_rate = 1
audio_stream: Optional[av.AudioStream] = None
if self.__components.audio:
audio_sample_rate = int(self.__components.audio['sample_rate'])
audio_stream = output.add_stream('aac', rate=audio_sample_rate)
audio_stream.sample_rate = audio_sample_rate
audio_stream.format = 'fltp'
# Encode video
for i, frame in enumerate(self.__components.images):
img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3)
frame = av.VideoFrame.from_ndarray(img, format='rgb24')
frame = frame.reformat(format='yuv420p') # Convert to YUV420P as required by h264
packet = video_stream.encode(frame)
output.mux(packet)
# Flush video
packet = video_stream.encode(None)
output.mux(packet)
if audio_stream and self.__components.audio:
# Encode audio
samples_per_frame = int(audio_sample_rate / frame_rate)
num_frames = self.__components.audio['waveform'].shape[2] // samples_per_frame
for i in range(num_frames):
start = i * samples_per_frame
end = start + samples_per_frame
# TODO(Feature) - Add support for stereo audio
chunk = (
self.__components.audio["waveform"][0, 0, start:end]
.unsqueeze(0)
.contiguous()
.numpy()
)
audio_frame = av.AudioFrame.from_ndarray(chunk, format='fltp', layout='mono')
audio_frame.sample_rate = audio_sample_rate
audio_frame.pts = i * samples_per_frame
for packet in audio_stream.encode(audio_frame):
output.mux(packet)
# Flush audio
for packet in audio_stream.encode(None):
output.mux(packet)
# This file only exists for backwards compatibility.
from comfy_api.latest._input_impl.video_types import * # noqa: F403

View File

@@ -0,0 +1,150 @@
# Internal infrastructure for ComfyAPI
from .api_registry import (
ComfyAPIBase as ComfyAPIBase,
ComfyAPIWithVersion as ComfyAPIWithVersion,
register_versions as register_versions,
get_all_versions as get_all_versions,
)
import asyncio
from dataclasses import asdict
from typing import Callable, Optional
def first_real_override(cls: type, name: str, *, base: type=None) -> Optional[Callable]:
"""Return the *callable* override of `name` visible on `cls`, or None if every
implementation up to (and including) `base` is the placeholder defined on `base`.
If base is not provided, it will assume cls has a GET_BASE_CLASS
"""
if base is None:
if not hasattr(cls, "GET_BASE_CLASS"):
raise ValueError("base is required if cls does not have a GET_BASE_CLASS; is this a valid ComfyNode subclass?")
base = cls.GET_BASE_CLASS()
base_attr = getattr(base, name, None)
if base_attr is None:
return None
base_func = base_attr.__func__
for c in cls.mro(): # NodeB, NodeA, ComfyNode, object …
if c is base: # reached the placeholder we're done
break
if name in c.__dict__: # first class that *defines* the attr
func = getattr(c, name).__func__
if func is not base_func: # real override
return getattr(cls, name) # bound to *cls*
return None
class _ComfyNodeInternal:
"""Class that all V3-based APIs inherit from for ComfyNode.
This is intended to only be referenced within execution.py, as it has to handle all V3 APIs going forward."""
@classmethod
def GET_NODE_INFO_V1(cls):
...
class _NodeOutputInternal:
"""Class that all V3-based APIs inherit from for NodeOutput.
This is intended to only be referenced within execution.py, as it has to handle all V3 APIs going forward."""
...
def as_pruned_dict(dataclass_obj):
'''Return dict of dataclass object with pruned None values.'''
return prune_dict(asdict(dataclass_obj))
def prune_dict(d: dict):
return {k: v for k,v in d.items() if v is not None}
def is_class(obj):
'''
Returns True if is a class type.
Returns False if is a class instance.
'''
return isinstance(obj, type)
def copy_class(cls: type) -> type:
'''
Copy a class and its attributes.
'''
if cls is None:
return None
cls_dict = {
k: v for k, v in cls.__dict__.items()
if k not in ('__dict__', '__weakref__', '__module__', '__doc__')
}
# new class
new_cls = type(
cls.__name__,
(cls,),
cls_dict
)
# metadata preservation
new_cls.__module__ = cls.__module__
new_cls.__doc__ = cls.__doc__
return new_cls
class classproperty(object):
def __init__(self, f):
self.f = f
def __get__(self, obj, owner):
return self.f(owner)
# NOTE: this was ai generated and validated by hand
def shallow_clone_class(cls, new_name=None):
'''
Shallow clone a class while preserving super() functionality.
'''
new_name = new_name or f"{cls.__name__}Clone"
# Include the original class in the bases to maintain proper inheritance
new_bases = (cls,) + cls.__bases__
return type(new_name, new_bases, dict(cls.__dict__))
# NOTE: this was ai generated and validated by hand
def lock_class(cls):
'''
Lock a class so that its top-levelattributes cannot be modified.
'''
# Locked instance __setattr__
def locked_instance_setattr(self, name, value):
raise AttributeError(
f"Cannot set attribute '{name}' on immutable instance of {type(self).__name__}"
)
# Locked metaclass
class LockedMeta(type(cls)):
def __setattr__(cls_, name, value):
raise AttributeError(
f"Cannot modify class attribute '{name}' on locked class '{cls_.__name__}'"
)
# Rebuild class with locked behavior
locked_dict = dict(cls.__dict__)
locked_dict['__setattr__'] = locked_instance_setattr
return LockedMeta(cls.__name__, cls.__bases__, locked_dict)
def make_locked_method_func(type_obj, func, class_clone):
"""
Returns a function that, when called with **inputs, will execute:
getattr(type_obj, func).__func__(lock_class(class_clone), **inputs)
Supports both synchronous and asynchronous methods.
"""
locked_class = lock_class(class_clone)
method = getattr(type_obj, func).__func__
# Check if the original method is async
if asyncio.iscoroutinefunction(method):
async def wrapped_async_func(**inputs):
return await method(locked_class, **inputs)
return wrapped_async_func
else:
def wrapped_func(**inputs):
return method(locked_class, **inputs)
return wrapped_func

View File

@@ -0,0 +1,39 @@
from typing import Type, List, NamedTuple
from comfy_api.internal.singleton import ProxiedSingleton
from packaging import version as packaging_version
class ComfyAPIBase(ProxiedSingleton):
def __init__(self):
pass
class ComfyAPIWithVersion(NamedTuple):
version: str
api_class: Type[ComfyAPIBase]
def parse_version(version_str: str) -> packaging_version.Version:
"""
Parses a version string into a packaging_version.Version object.
Raises ValueError if the version string is invalid.
"""
if version_str == "latest":
return packaging_version.parse("9999999.9999999.9999999")
return packaging_version.parse(version_str)
registered_versions: List[ComfyAPIWithVersion] = []
def register_versions(versions: List[ComfyAPIWithVersion]):
versions.sort(key=lambda x: parse_version(x.version))
global registered_versions
registered_versions = versions
def get_all_versions() -> List[ComfyAPIWithVersion]:
"""
Returns a list of all registered ComfyAPI versions.
"""
return registered_versions

View File

@@ -0,0 +1,987 @@
import asyncio
import concurrent.futures
import contextvars
import functools
import inspect
import logging
import os
import textwrap
import threading
from enum import Enum
from typing import Optional, Type, get_origin, get_args
class TypeTracker:
"""Tracks types discovered during stub generation for automatic import generation."""
def __init__(self):
self.discovered_types = {} # type_name -> (module, qualname)
self.builtin_types = {
"Any",
"Dict",
"List",
"Optional",
"Tuple",
"Union",
"Set",
"Sequence",
"cast",
"NamedTuple",
"str",
"int",
"float",
"bool",
"None",
"bytes",
"object",
"type",
"dict",
"list",
"tuple",
"set",
}
self.already_imported = (
set()
) # Track types already imported to avoid duplicates
def track_type(self, annotation):
"""Track a type annotation and record its module/import info."""
if annotation is None or annotation is type(None):
return
# Skip builtins and typing module types we already import
type_name = getattr(annotation, "__name__", None)
if type_name and (
type_name in self.builtin_types or type_name in self.already_imported
):
return
# Get module and qualname
module = getattr(annotation, "__module__", None)
qualname = getattr(annotation, "__qualname__", type_name or "")
# Skip types from typing module (they're already imported)
if module == "typing":
return
# Skip UnionType and GenericAlias from types module as they're handled specially
if module == "types" and type_name in ("UnionType", "GenericAlias"):
return
if module and module not in ["builtins", "__main__"]:
# Store the type info
if type_name:
self.discovered_types[type_name] = (module, qualname)
def get_imports(self, main_module_name: str) -> list[str]:
"""Generate import statements for all discovered types."""
imports = []
imports_by_module = {}
for type_name, (module, qualname) in sorted(self.discovered_types.items()):
# Skip types from the main module (they're already imported)
if main_module_name and module == main_module_name:
continue
if module not in imports_by_module:
imports_by_module[module] = []
if type_name not in imports_by_module[module]: # Avoid duplicates
imports_by_module[module].append(type_name)
# Generate import statements
for module, types in sorted(imports_by_module.items()):
if len(types) == 1:
imports.append(f"from {module} import {types[0]}")
else:
imports.append(f"from {module} import {', '.join(sorted(set(types)))}")
return imports
class AsyncToSyncConverter:
"""
Provides utilities to convert async classes to sync classes with proper type hints.
"""
_thread_pool: Optional[concurrent.futures.ThreadPoolExecutor] = None
_thread_pool_lock = threading.Lock()
_thread_pool_initialized = False
@classmethod
def get_thread_pool(cls, max_workers=None) -> concurrent.futures.ThreadPoolExecutor:
"""Get or create the shared thread pool with proper thread-safe initialization."""
# Fast path - check if already initialized without acquiring lock
if cls._thread_pool_initialized:
assert cls._thread_pool is not None, "Thread pool should be initialized"
return cls._thread_pool
# Slow path - acquire lock and create pool if needed
with cls._thread_pool_lock:
if not cls._thread_pool_initialized:
cls._thread_pool = concurrent.futures.ThreadPoolExecutor(
max_workers=max_workers, thread_name_prefix="async_to_sync_"
)
cls._thread_pool_initialized = True
# This should never be None at this point, but add assertion for type checker
assert cls._thread_pool is not None
return cls._thread_pool
@classmethod
def run_async_in_thread(cls, coro_func, *args, **kwargs):
"""
Run an async function in a separate thread from the thread pool.
Blocks until the async function completes.
Properly propagates contextvars between threads and manages event loops.
"""
# Capture current context - this includes all context variables
context = contextvars.copy_context()
# Store the result and any exception that occurs
result_container: dict = {"result": None, "exception": None}
# Function that runs in the thread pool
def run_in_thread():
# Create new event loop for this thread
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
# Create the coroutine within the context
async def run_with_context():
# The coroutine function might access context variables
return await coro_func(*args, **kwargs)
# Run the coroutine with the captured context
# This ensures all context variables are available in the async function
result = context.run(loop.run_until_complete, run_with_context())
result_container["result"] = result
except Exception as e:
# Store the exception to re-raise in the calling thread
result_container["exception"] = e
finally:
# Ensure event loop is properly closed to prevent warnings
try:
# Cancel any remaining tasks
pending = asyncio.all_tasks(loop)
for task in pending:
task.cancel()
# Run the loop briefly to handle cancellations
if pending:
loop.run_until_complete(
asyncio.gather(*pending, return_exceptions=True)
)
except Exception:
pass # Ignore errors during cleanup
# Close the event loop
loop.close()
# Clear the event loop from the thread
asyncio.set_event_loop(None)
# Submit to thread pool and wait for result
thread_pool = cls.get_thread_pool()
future = thread_pool.submit(run_in_thread)
future.result() # Wait for completion
# Re-raise any exception that occurred in the thread
if result_container["exception"] is not None:
raise result_container["exception"]
return result_container["result"]
@classmethod
def create_sync_class(cls, async_class: Type, thread_pool_size=10) -> Type:
"""
Creates a new class with synchronous versions of all async methods.
Args:
async_class: The async class to convert
thread_pool_size: Size of thread pool to use
Returns:
A new class with sync versions of all async methods
"""
sync_class_name = "ComfyAPISyncStub"
cls.get_thread_pool(thread_pool_size)
# Create a proper class with docstrings and proper base classes
sync_class_dict = {
"__doc__": async_class.__doc__,
"__module__": async_class.__module__,
"__qualname__": sync_class_name,
"__orig_class__": async_class, # Store original class for typing references
}
# Create __init__ method
def __init__(self, *args, **kwargs):
self._async_instance = async_class(*args, **kwargs)
# Handle annotated class attributes (like execution: Execution)
# Get all annotations from the class hierarchy
all_annotations = {}
for base_class in reversed(inspect.getmro(async_class)):
if hasattr(base_class, "__annotations__"):
all_annotations.update(base_class.__annotations__)
# For each annotated attribute, check if it needs to be created or wrapped
for attr_name, attr_type in all_annotations.items():
if hasattr(self._async_instance, attr_name):
# Attribute exists on the instance
attr = getattr(self._async_instance, attr_name)
# Check if this attribute needs a sync wrapper
if hasattr(attr, "__class__"):
from comfy_api.internal.singleton import ProxiedSingleton
if isinstance(attr, ProxiedSingleton):
# Create a sync version of this attribute
try:
sync_attr_class = cls.create_sync_class(attr.__class__)
# Create instance of the sync wrapper with the async instance
sync_attr = object.__new__(sync_attr_class) # type: ignore
sync_attr._async_instance = attr
setattr(self, attr_name, sync_attr)
except Exception:
# If we can't create a sync version, keep the original
setattr(self, attr_name, attr)
else:
# Not async, just copy the reference
setattr(self, attr_name, attr)
else:
# Attribute doesn't exist, but is annotated - create it
# This handles cases like execution: Execution
if isinstance(attr_type, type):
# Check if the type is defined as an inner class
if hasattr(async_class, attr_type.__name__):
inner_class = getattr(async_class, attr_type.__name__)
from comfy_api.internal.singleton import ProxiedSingleton
# Create an instance of the inner class
try:
# For ProxiedSingleton classes, get or create the singleton instance
if issubclass(inner_class, ProxiedSingleton):
async_instance = inner_class.get_instance()
else:
async_instance = inner_class()
# Create sync wrapper
sync_attr_class = cls.create_sync_class(inner_class)
sync_attr = object.__new__(sync_attr_class) # type: ignore
sync_attr._async_instance = async_instance
setattr(self, attr_name, sync_attr)
# Also set on the async instance for consistency
setattr(self._async_instance, attr_name, async_instance)
except Exception as e:
logging.warning(
f"Failed to create instance for {attr_name}: {e}"
)
# Handle other instance attributes that might not be annotated
for name, attr in inspect.getmembers(self._async_instance):
if name.startswith("_") or hasattr(self, name):
continue
# If attribute is an instance of a class, and that class is defined in the original class
# we need to check if it needs a sync wrapper
if isinstance(attr, object) and not isinstance(
attr, (str, int, float, bool, list, dict, tuple)
):
from comfy_api.internal.singleton import ProxiedSingleton
if isinstance(attr, ProxiedSingleton):
# Create a sync version of this nested class
try:
sync_attr_class = cls.create_sync_class(attr.__class__)
# Create instance of the sync wrapper with the async instance
sync_attr = object.__new__(sync_attr_class) # type: ignore
sync_attr._async_instance = attr
setattr(self, name, sync_attr)
except Exception:
# If we can't create a sync version, keep the original
setattr(self, name, attr)
sync_class_dict["__init__"] = __init__
# Process methods from the async class
for name, method in inspect.getmembers(
async_class, predicate=inspect.isfunction
):
if name.startswith("_"):
continue
# Extract the actual return type from a coroutine
if inspect.iscoroutinefunction(method):
# Create sync version of async method with proper signature
@functools.wraps(method)
def sync_method(self, *args, _method_name=name, **kwargs):
async_method = getattr(self._async_instance, _method_name)
return AsyncToSyncConverter.run_async_in_thread(
async_method, *args, **kwargs
)
# Add to the class dict
sync_class_dict[name] = sync_method
else:
# For regular methods, create a proxy method
@functools.wraps(method)
def proxy_method(self, *args, _method_name=name, **kwargs):
method = getattr(self._async_instance, _method_name)
return method(*args, **kwargs)
# Add to the class dict
sync_class_dict[name] = proxy_method
# Handle property access
for name, prop in inspect.getmembers(
async_class, lambda x: isinstance(x, property)
):
def make_property(name, prop_obj):
def getter(self):
value = getattr(self._async_instance, name)
if inspect.iscoroutinefunction(value):
def sync_fn(*args, **kwargs):
return AsyncToSyncConverter.run_async_in_thread(
value, *args, **kwargs
)
return sync_fn
return value
def setter(self, value):
setattr(self._async_instance, name, value)
return property(getter, setter if prop_obj.fset else None)
sync_class_dict[name] = make_property(name, prop)
# Create the class
sync_class = type(sync_class_name, (object,), sync_class_dict)
return sync_class
@classmethod
def _format_type_annotation(
cls, annotation, type_tracker: Optional[TypeTracker] = None
) -> str:
"""Convert a type annotation to its string representation for stub files."""
if (
annotation is inspect.Parameter.empty
or annotation is inspect.Signature.empty
):
return "Any"
# Handle None type
if annotation is type(None):
return "None"
# Track the type if we have a tracker
if type_tracker:
type_tracker.track_type(annotation)
# Try using typing.get_origin/get_args for Python 3.8+
try:
origin = get_origin(annotation)
args = get_args(annotation)
if origin is not None:
# Track the origin type
if type_tracker:
type_tracker.track_type(origin)
# Get the origin name
origin_name = getattr(origin, "__name__", str(origin))
if "." in origin_name:
origin_name = origin_name.split(".")[-1]
# Special handling for types.UnionType (Python 3.10+ pipe operator)
# Convert to old-style Union for compatibility
if str(origin) == "<class 'types.UnionType'>" or origin_name == "UnionType":
origin_name = "Union"
# Format arguments recursively
if args:
formatted_args = []
for arg in args:
# Track each type in the union
if type_tracker:
type_tracker.track_type(arg)
formatted_args.append(cls._format_type_annotation(arg, type_tracker))
return f"{origin_name}[{', '.join(formatted_args)}]"
else:
return origin_name
except (AttributeError, TypeError):
# Fallback for older Python versions or non-generic types
pass
# Handle generic types the old way for compatibility
if hasattr(annotation, "__origin__") and hasattr(annotation, "__args__"):
origin = annotation.__origin__
origin_name = (
origin.__name__
if hasattr(origin, "__name__")
else str(origin).split("'")[1]
)
# Format each type argument
args = []
for arg in annotation.__args__:
args.append(cls._format_type_annotation(arg, type_tracker))
return f"{origin_name}[{', '.join(args)}]"
# Handle regular types with __name__
if hasattr(annotation, "__name__"):
return annotation.__name__
# Handle special module types (like types from typing module)
if hasattr(annotation, "__module__") and hasattr(annotation, "__qualname__"):
# For types like typing.Literal, typing.TypedDict, etc.
return annotation.__qualname__
# Last resort: string conversion with cleanup
type_str = str(annotation)
# Clean up common patterns more robustly
if type_str.startswith("<class '") and type_str.endswith("'>"):
type_str = type_str[8:-2] # Remove "<class '" and "'>"
# Remove module prefixes for common modules
for prefix in ["typing.", "builtins.", "types."]:
if type_str.startswith(prefix):
type_str = type_str[len(prefix) :]
# Handle special cases
if type_str in ("_empty", "inspect._empty"):
return "None"
# Fix NoneType (this should rarely be needed now)
if type_str == "NoneType":
return "None"
return type_str
@classmethod
def _extract_coroutine_return_type(cls, annotation):
"""Extract the actual return type from a Coroutine annotation."""
if hasattr(annotation, "__args__") and len(annotation.__args__) > 2:
# Coroutine[Any, Any, ReturnType] -> extract ReturnType
return annotation.__args__[2]
return annotation
@classmethod
def _format_parameter_default(cls, default_value) -> str:
"""Format a parameter's default value for stub files."""
if default_value is inspect.Parameter.empty:
return ""
elif default_value is None:
return " = None"
elif isinstance(default_value, bool):
return f" = {default_value}"
elif default_value == {}:
return " = {}"
elif default_value == []:
return " = []"
else:
return f" = {default_value}"
@classmethod
def _format_method_parameters(
cls,
sig: inspect.Signature,
skip_self: bool = True,
type_hints: Optional[dict] = None,
type_tracker: Optional[TypeTracker] = None,
) -> str:
"""Format method parameters for stub files."""
params = []
if type_hints is None:
type_hints = {}
for i, (param_name, param) in enumerate(sig.parameters.items()):
if i == 0 and param_name == "self" and skip_self:
params.append("self")
else:
# Get type annotation from type hints if available, otherwise from signature
annotation = type_hints.get(param_name, param.annotation)
type_str = cls._format_type_annotation(annotation, type_tracker)
# Get default value
default_str = cls._format_parameter_default(param.default)
# Combine parameter parts
if annotation is inspect.Parameter.empty:
params.append(f"{param_name}: Any{default_str}")
else:
params.append(f"{param_name}: {type_str}{default_str}")
return ", ".join(params)
@classmethod
def _generate_method_signature(
cls,
method_name: str,
method,
is_async: bool = False,
type_tracker: Optional[TypeTracker] = None,
) -> str:
"""Generate a complete method signature for stub files."""
sig = inspect.signature(method)
# Try to get evaluated type hints to resolve string annotations
try:
from typing import get_type_hints
type_hints = get_type_hints(method)
except Exception:
# Fallback to empty dict if we can't get type hints
type_hints = {}
# For async methods, extract the actual return type
return_annotation = type_hints.get('return', sig.return_annotation)
if is_async and inspect.iscoroutinefunction(method):
return_annotation = cls._extract_coroutine_return_type(return_annotation)
# Format parameters with type hints
params_str = cls._format_method_parameters(sig, type_hints=type_hints, type_tracker=type_tracker)
# Format return type
return_type = cls._format_type_annotation(return_annotation, type_tracker)
if return_annotation is inspect.Signature.empty:
return_type = "None"
return f"def {method_name}({params_str}) -> {return_type}: ..."
@classmethod
def _generate_imports(
cls, async_class: Type, type_tracker: TypeTracker
) -> list[str]:
"""Generate import statements for the stub file."""
imports = []
# Add standard typing imports
imports.append(
"from typing import Any, Dict, List, Optional, Tuple, Union, Set, Sequence, cast, NamedTuple"
)
# Add imports from the original module
if async_class.__module__ != "builtins":
module = inspect.getmodule(async_class)
additional_types = []
if module:
# Check if module has __all__ defined
module_all = getattr(module, "__all__", None)
for name, obj in sorted(inspect.getmembers(module)):
if isinstance(obj, type):
# Skip if __all__ is defined and this name isn't in it
# unless it's already been tracked as used in type annotations
if module_all is not None and name not in module_all:
# Check if this type was actually used in annotations
if name not in type_tracker.discovered_types:
continue
# Check for NamedTuple
if issubclass(obj, tuple) and hasattr(obj, "_fields"):
additional_types.append(name)
# Mark as already imported
type_tracker.already_imported.add(name)
# Check for Enum
elif issubclass(obj, Enum) and name != "Enum":
additional_types.append(name)
# Mark as already imported
type_tracker.already_imported.add(name)
if additional_types:
type_imports = ", ".join([async_class.__name__] + additional_types)
imports.append(f"from {async_class.__module__} import {type_imports}")
else:
imports.append(
f"from {async_class.__module__} import {async_class.__name__}"
)
# Add imports for all discovered types
# Pass the main module name to avoid duplicate imports
imports.extend(
type_tracker.get_imports(main_module_name=async_class.__module__)
)
# Add base module import if needed
if hasattr(inspect.getmodule(async_class), "__name__"):
module_name = inspect.getmodule(async_class).__name__
if "." in module_name:
base_module = module_name.split(".")[0]
# Only add if not already importing from it
if not any(imp.startswith(f"from {base_module}") for imp in imports):
imports.append(f"import {base_module}")
return imports
@classmethod
def _get_class_attributes(cls, async_class: Type) -> list[tuple[str, Type]]:
"""Extract class attributes that are classes themselves."""
class_attributes = []
# Look for class attributes that are classes
for name, attr in sorted(inspect.getmembers(async_class)):
if isinstance(attr, type) and not name.startswith("_"):
class_attributes.append((name, attr))
elif (
hasattr(async_class, "__annotations__")
and name in async_class.__annotations__
):
annotation = async_class.__annotations__[name]
if isinstance(annotation, type):
class_attributes.append((name, annotation))
return class_attributes
@classmethod
def _generate_inner_class_stub(
cls,
name: str,
attr: Type,
indent: str = " ",
type_tracker: Optional[TypeTracker] = None,
) -> list[str]:
"""Generate stub for an inner class."""
stub_lines = []
stub_lines.append(f"{indent}class {name}Sync:")
# Add docstring if available
if hasattr(attr, "__doc__") and attr.__doc__:
stub_lines.extend(
cls._format_docstring_for_stub(attr.__doc__, f"{indent} ")
)
# Add __init__ if it exists
if hasattr(attr, "__init__"):
try:
init_method = getattr(attr, "__init__")
init_sig = inspect.signature(init_method)
# Try to get type hints
try:
from typing import get_type_hints
init_hints = get_type_hints(init_method)
except Exception:
init_hints = {}
# Format parameters
params_str = cls._format_method_parameters(
init_sig, type_hints=init_hints, type_tracker=type_tracker
)
# Add __init__ docstring if available (before the method)
if hasattr(init_method, "__doc__") and init_method.__doc__:
stub_lines.extend(
cls._format_docstring_for_stub(
init_method.__doc__, f"{indent} "
)
)
stub_lines.append(
f"{indent} def __init__({params_str}) -> None: ..."
)
except (ValueError, TypeError):
stub_lines.append(
f"{indent} def __init__(self, *args, **kwargs) -> None: ..."
)
# Add methods to the inner class
has_methods = False
for method_name, method in sorted(
inspect.getmembers(attr, predicate=inspect.isfunction)
):
if method_name.startswith("_"):
continue
has_methods = True
try:
# Add method docstring if available (before the method signature)
if method.__doc__:
stub_lines.extend(
cls._format_docstring_for_stub(method.__doc__, f"{indent} ")
)
method_sig = cls._generate_method_signature(
method_name, method, is_async=True, type_tracker=type_tracker
)
stub_lines.append(f"{indent} {method_sig}")
except (ValueError, TypeError):
stub_lines.append(
f"{indent} def {method_name}(self, *args, **kwargs): ..."
)
if not has_methods:
stub_lines.append(f"{indent} pass")
return stub_lines
@classmethod
def _format_docstring_for_stub(
cls, docstring: str, indent: str = " "
) -> list[str]:
"""Format a docstring for inclusion in a stub file with proper indentation."""
if not docstring:
return []
# First, dedent the docstring to remove any existing indentation
dedented = textwrap.dedent(docstring).strip()
# Split into lines
lines = dedented.split("\n")
# Build the properly indented docstring
result = []
result.append(f'{indent}"""')
for line in lines:
if line.strip(): # Non-empty line
result.append(f"{indent}{line}")
else: # Empty line
result.append("")
result.append(f'{indent}"""')
return result
@classmethod
def _post_process_stub_content(cls, stub_content: list[str]) -> list[str]:
"""Post-process stub content to fix any remaining issues."""
processed = []
for line in stub_content:
# Skip processing imports
if line.startswith(("from ", "import ")):
processed.append(line)
continue
# Fix method signatures missing return types
if (
line.strip().startswith("def ")
and line.strip().endswith(": ...")
and ") -> " not in line
):
# Add -> None for methods without return annotation
line = line.replace(": ...", " -> None: ...")
processed.append(line)
return processed
@classmethod
def generate_stub_file(cls, async_class: Type, sync_class: Type) -> None:
"""
Generate a .pyi stub file for the sync class to help IDEs with type checking.
"""
try:
# Only generate stub if we can determine module path
if async_class.__module__ == "__main__":
return
module = inspect.getmodule(async_class)
if not module:
return
module_path = module.__file__
if not module_path:
return
# Create stub file path in a 'generated' subdirectory
module_dir = os.path.dirname(module_path)
stub_dir = os.path.join(module_dir, "generated")
# Ensure the generated directory exists
os.makedirs(stub_dir, exist_ok=True)
module_name = os.path.basename(module_path)
if module_name.endswith(".py"):
module_name = module_name[:-3]
sync_stub_path = os.path.join(stub_dir, f"{sync_class.__name__}.pyi")
# Create a type tracker for this stub generation
type_tracker = TypeTracker()
stub_content = []
# We'll generate imports after processing all methods to capture all types
# Leave a placeholder for imports
imports_placeholder_index = len(stub_content)
stub_content.append("") # Will be replaced with imports later
# Class definition
stub_content.append(f"class {sync_class.__name__}:")
# Docstring
if async_class.__doc__:
stub_content.extend(
cls._format_docstring_for_stub(async_class.__doc__, " ")
)
# Generate __init__
try:
init_method = async_class.__init__
init_signature = inspect.signature(init_method)
# Try to get type hints for __init__
try:
from typing import get_type_hints
init_hints = get_type_hints(init_method)
except Exception:
init_hints = {}
# Format parameters
params_str = cls._format_method_parameters(
init_signature, type_hints=init_hints, type_tracker=type_tracker
)
# Add __init__ docstring if available (before the method)
if hasattr(init_method, "__doc__") and init_method.__doc__:
stub_content.extend(
cls._format_docstring_for_stub(init_method.__doc__, " ")
)
stub_content.append(f" def __init__({params_str}) -> None: ...")
except (ValueError, TypeError):
stub_content.append(
" def __init__(self, *args, **kwargs) -> None: ..."
)
stub_content.append("") # Add newline after __init__
# Get class attributes
class_attributes = cls._get_class_attributes(async_class)
# Generate inner classes
for name, attr in class_attributes:
inner_class_stub = cls._generate_inner_class_stub(
name, attr, type_tracker=type_tracker
)
stub_content.extend(inner_class_stub)
stub_content.append("") # Add newline after the inner class
# Add methods to the main class
processed_methods = set() # Keep track of methods we've processed
for name, method in sorted(
inspect.getmembers(async_class, predicate=inspect.isfunction)
):
if name.startswith("_") or name in processed_methods:
continue
processed_methods.add(name)
try:
method_sig = cls._generate_method_signature(
name, method, is_async=True, type_tracker=type_tracker
)
# Add docstring if available (before the method signature for proper formatting)
if method.__doc__:
stub_content.extend(
cls._format_docstring_for_stub(method.__doc__, " ")
)
stub_content.append(f" {method_sig}")
stub_content.append("") # Add newline after each method
except (ValueError, TypeError):
# If we can't get the signature, just add a simple stub
stub_content.append(f" def {name}(self, *args, **kwargs): ...")
stub_content.append("") # Add newline
# Add properties
for name, prop in sorted(
inspect.getmembers(async_class, lambda x: isinstance(x, property))
):
stub_content.append(" @property")
stub_content.append(f" def {name}(self) -> Any: ...")
if prop.fset:
stub_content.append(f" @{name}.setter")
stub_content.append(
f" def {name}(self, value: Any) -> None: ..."
)
stub_content.append("") # Add newline after each property
# Add placeholders for the nested class instances
# Check the actual attribute names from class annotations and attributes
attribute_mappings = {}
# First check annotations for typed attributes (including from parent classes)
# Collect all annotations from the class hierarchy
all_annotations = {}
for base_class in reversed(inspect.getmro(async_class)):
if hasattr(base_class, "__annotations__"):
all_annotations.update(base_class.__annotations__)
for attr_name, attr_type in sorted(all_annotations.items()):
for class_name, class_type in class_attributes:
# If the class type matches the annotated type
if (
attr_type == class_type
or (hasattr(attr_type, "__name__") and attr_type.__name__ == class_name)
or (isinstance(attr_type, str) and attr_type == class_name)
):
attribute_mappings[class_name] = attr_name
# Remove the extra checking - annotations should be sufficient
# Add the attribute declarations with proper names
for class_name, class_type in class_attributes:
# Check if there's a mapping from annotation
attr_name = attribute_mappings.get(class_name, class_name)
# Use the annotation name if it exists, even if the attribute doesn't exist yet
# This is because the attribute might be created at runtime
stub_content.append(f" {attr_name}: {class_name}Sync")
stub_content.append("") # Add a final newline
# Now generate imports with all discovered types
imports = cls._generate_imports(async_class, type_tracker)
# Deduplicate imports while preserving order
seen = set()
unique_imports = []
for imp in imports:
if imp not in seen:
seen.add(imp)
unique_imports.append(imp)
else:
logging.warning(f"Duplicate import detected: {imp}")
# Replace the placeholder with actual imports
stub_content[imports_placeholder_index : imports_placeholder_index + 1] = (
unique_imports
)
# Post-process stub content
stub_content = cls._post_process_stub_content(stub_content)
# Write stub file
with open(sync_stub_path, "w") as f:
f.write("\n".join(stub_content))
logging.info(f"Generated stub file: {sync_stub_path}")
except Exception as e:
# If stub generation fails, log the error but don't break the main functionality
logging.error(
f"Error generating stub file for {sync_class.__name__}: {str(e)}"
)
import traceback
logging.error(traceback.format_exc())
def create_sync_class(async_class: Type, thread_pool_size=10) -> Type:
"""
Creates a sync version of an async class
Args:
async_class: The async class to convert
thread_pool_size: Size of thread pool to use
Returns:
A new class with sync versions of all async methods
"""
return AsyncToSyncConverter.create_sync_class(async_class, thread_pool_size)

View File

@@ -0,0 +1,33 @@
from typing import Type, TypeVar
class SingletonMetaclass(type):
T = TypeVar("T", bound="SingletonMetaclass")
_instances = {}
def __call__(cls, *args, **kwargs):
if cls not in cls._instances:
cls._instances[cls] = super(SingletonMetaclass, cls).__call__(
*args, **kwargs
)
return cls._instances[cls]
def inject_instance(cls: Type[T], instance: T) -> None:
assert cls not in SingletonMetaclass._instances, (
"Cannot inject instance after first instantiation"
)
SingletonMetaclass._instances[cls] = instance
def get_instance(cls: Type[T], *args, **kwargs) -> T:
"""
Gets the singleton instance of the class, creating it if it doesn't exist.
"""
if cls not in SingletonMetaclass._instances:
SingletonMetaclass._instances[cls] = super(
SingletonMetaclass, cls
).__call__(*args, **kwargs)
return cls._instances[cls]
class ProxiedSingleton(object, metaclass=SingletonMetaclass):
def __init__(self):
super().__init__()

View File

@@ -0,0 +1,124 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Type, TYPE_CHECKING
from comfy_api.internal import ComfyAPIBase
from comfy_api.internal.singleton import ProxiedSingleton
from comfy_api.internal.async_to_sync import create_sync_class
from comfy_api.latest._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents
from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents
from comfy_api.latest._io import _IO as io #noqa: F401
from comfy_api.latest._ui import _UI as ui #noqa: F401
# from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401
from comfy_execution.utils import get_executing_context
from comfy_execution.progress import get_progress_state, PreviewImageTuple
from PIL import Image
from comfy.cli_args import args
import numpy as np
class ComfyAPI_latest(ComfyAPIBase):
VERSION = "latest"
STABLE = False
class Execution(ProxiedSingleton):
async def set_progress(
self,
value: float,
max_value: float,
node_id: str | None = None,
preview_image: Image.Image | ImageInput | None = None,
ignore_size_limit: bool = False,
) -> None:
"""
Update the progress bar displayed in the ComfyUI interface.
This function allows custom nodes and API calls to report their progress
back to the user interface, providing visual feedback during long operations.
Migration from previous API: comfy.utils.PROGRESS_BAR_HOOK
"""
executing_context = get_executing_context()
if node_id is None and executing_context is not None:
node_id = executing_context.node_id
if node_id is None:
raise ValueError("node_id must be provided if not in executing context")
# Convert preview_image to PreviewImageTuple if needed
to_display: PreviewImageTuple | Image.Image | ImageInput | None = preview_image
if to_display is not None:
# First convert to PIL Image if needed
if isinstance(to_display, ImageInput):
# Convert ImageInput (torch.Tensor) to PIL Image
# Handle tensor shape [B, H, W, C] -> get first image if batch
tensor = to_display
if len(tensor.shape) == 4:
tensor = tensor[0]
# Convert to numpy array and scale to 0-255
image_np = (tensor.cpu().numpy() * 255).astype(np.uint8)
to_display = Image.fromarray(image_np)
if isinstance(to_display, Image.Image):
# Detect image format from PIL Image
image_format = to_display.format if to_display.format else "JPEG"
# Use None for preview_size if ignore_size_limit is True
preview_size = None if ignore_size_limit else args.preview_size
to_display = (image_format, to_display, preview_size)
get_progress_state().update_progress(
node_id=node_id,
value=value,
max_value=max_value,
image=to_display,
)
execution: Execution
class ComfyExtension(ABC):
async def on_load(self) -> None:
"""
Called when an extension is loaded.
This should be used to initialize any global resources neeeded by the extension.
"""
@abstractmethod
async def get_node_list(self) -> list[type[io.ComfyNode]]:
"""
Returns a list of nodes that this extension provides.
"""
class Input:
Image = ImageInput
Audio = AudioInput
Mask = MaskInput
Latent = LatentInput
Video = VideoInput
class InputImpl:
VideoFromFile = VideoFromFile
VideoFromComponents = VideoFromComponents
class Types:
VideoCodec = VideoCodec
VideoContainer = VideoContainer
VideoComponents = VideoComponents
ComfyAPI = ComfyAPI_latest
# Create a synchronous version of the API
if TYPE_CHECKING:
import comfy_api.latest.generated.ComfyAPISyncStub # type: ignore
ComfyAPISync: Type[comfy_api.latest.generated.ComfyAPISyncStub.ComfyAPISyncStub]
ComfyAPISync = create_sync_class(ComfyAPI_latest)
__all__ = [
"ComfyAPI",
"ComfyAPISync",
"Input",
"InputImpl",
"Types",
"ComfyExtension",
]

View File

@@ -0,0 +1,10 @@
from .basic_types import ImageInput, AudioInput, MaskInput, LatentInput
from .video_types import VideoInput
__all__ = [
"ImageInput",
"AudioInput",
"VideoInput",
"MaskInput",
"LatentInput",
]

View File

@@ -0,0 +1,42 @@
import torch
from typing import TypedDict, List, Optional
ImageInput = torch.Tensor
"""
An image in format [B, H, W, C] where B is the batch size, C is the number of channels,
"""
MaskInput = torch.Tensor
"""
A mask in format [B, H, W] where B is the batch size
"""
class AudioInput(TypedDict):
"""
TypedDict representing audio input.
"""
waveform: torch.Tensor
"""
Tensor in the format [B, C, T] where B is the batch size, C is the number of channels,
"""
sample_rate: int
class LatentInput(TypedDict):
"""
TypedDict representing latent input.
"""
samples: torch.Tensor
"""
Tensor in the format [B, C, H, W] where B is the batch size, C is the number of channels,
H is the height, and W is the width.
"""
noise_mask: Optional[MaskInput]
"""
Optional noise mask tensor in the same format as samples.
"""
batch_index: Optional[List[int]]

View File

@@ -0,0 +1,85 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Optional, Union
import io
import av
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
class VideoInput(ABC):
"""
Abstract base class for video input types.
"""
@abstractmethod
def get_components(self) -> VideoComponents:
"""
Abstract method to get the video components (images, audio, and frame rate).
Returns:
VideoComponents containing images, audio, and frame rate
"""
pass
@abstractmethod
def save_to(
self,
path: str,
format: VideoContainer = VideoContainer.AUTO,
codec: VideoCodec = VideoCodec.AUTO,
metadata: Optional[dict] = None
):
"""
Abstract method to save the video input to a file.
"""
pass
def get_stream_source(self) -> Union[str, io.BytesIO]:
"""
Get a streamable source for the video. This allows processing without
loading the entire video into memory.
Returns:
Either a file path (str) or a BytesIO object that can be opened with av.
Default implementation creates a BytesIO buffer, but subclasses should
override this for better performance when possible.
"""
buffer = io.BytesIO()
self.save_to(buffer)
buffer.seek(0)
return buffer
# Provide a default implementation, but subclasses can provide optimized versions
# if possible.
def get_dimensions(self) -> tuple[int, int]:
"""
Returns the dimensions of the video input.
Returns:
Tuple of (width, height)
"""
components = self.get_components()
return components.images.shape[2], components.images.shape[1]
def get_duration(self) -> float:
"""
Returns the duration of the video in seconds.
Returns:
Duration in seconds
"""
components = self.get_components()
frame_count = components.images.shape[0]
return float(frame_count / components.frame_rate)
def get_container_format(self) -> str:
"""
Returns the container format of the video (e.g., 'mp4', 'mov', 'avi').
Returns:
Container format as string
"""
# Default implementation - subclasses should override for better performance
source = self.get_stream_source()
with av.open(source, mode="r") as container:
return container.format.name

View File

@@ -0,0 +1,7 @@
from .video_types import VideoFromFile, VideoFromComponents
__all__ = [
# Implementations
"VideoFromFile",
"VideoFromComponents",
]

View File

@@ -0,0 +1,324 @@
from __future__ import annotations
from av.container import InputContainer
from av.subtitles.stream import SubtitleStream
from fractions import Fraction
from typing import Optional
from comfy_api.latest._input import AudioInput, VideoInput
import av
import io
import json
import numpy as np
import torch
from comfy_api.latest._util import VideoContainer, VideoCodec, VideoComponents
def container_to_output_format(container_format: str | None) -> str | None:
"""
A container's `format` may be a comma-separated list of formats.
E.g., iso container's `format` may be `mov,mp4,m4a,3gp,3g2,mj2`.
However, writing to a file/stream with `av.open` requires a single format,
or `None` to auto-detect.
"""
if not container_format:
return None # Auto-detect
if "," not in container_format:
return container_format
formats = container_format.split(",")
return formats[0]
def get_open_write_kwargs(
dest: str | io.BytesIO, container_format: str, to_format: str | None
) -> dict:
"""Get kwargs for writing a `VideoFromFile` to a file/stream with `av.open`"""
open_kwargs = {
"mode": "w",
# If isobmff, preserve custom metadata tags (workflow, prompt, extra_pnginfo)
"options": {"movflags": "use_metadata_tags"},
}
is_write_to_buffer = isinstance(dest, io.BytesIO)
if is_write_to_buffer:
# Set output format explicitly, since it cannot be inferred from file extension
if to_format == VideoContainer.AUTO:
to_format = container_format.lower()
elif isinstance(to_format, str):
to_format = to_format.lower()
open_kwargs["format"] = container_to_output_format(to_format)
return open_kwargs
class VideoFromFile(VideoInput):
"""
Class representing video input from a file.
"""
def __init__(self, file: str | io.BytesIO):
"""
Initialize the VideoFromFile object based off of either a path on disk or a BytesIO object
containing the file contents.
"""
self.__file = file
def get_stream_source(self) -> str | io.BytesIO:
"""
Return the underlying file source for efficient streaming.
This avoids unnecessary memory copies when the source is already a file path.
"""
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0)
return self.__file
def get_dimensions(self) -> tuple[int, int]:
"""
Returns the dimensions of the video input.
Returns:
Tuple of (width, height)
"""
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0) # Reset the BytesIO object to the beginning
with av.open(self.__file, mode='r') as container:
for stream in container.streams:
if stream.type == 'video':
assert isinstance(stream, av.VideoStream)
return stream.width, stream.height
raise ValueError(f"No video stream found in file '{self.__file}'")
def get_duration(self) -> float:
"""
Returns the duration of the video in seconds.
Returns:
Duration in seconds
"""
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0)
with av.open(self.__file, mode="r") as container:
if container.duration is not None:
return float(container.duration / av.time_base)
# Fallback: calculate from frame count and frame rate
video_stream = next(
(s for s in container.streams if s.type == "video"), None
)
if video_stream and video_stream.frames and video_stream.average_rate:
return float(video_stream.frames / video_stream.average_rate)
# Last resort: decode frames to count them
if video_stream and video_stream.average_rate:
frame_count = 0
container.seek(0)
for packet in container.demux(video_stream):
for _ in packet.decode():
frame_count += 1
if frame_count > 0:
return float(frame_count / video_stream.average_rate)
raise ValueError(f"Could not determine duration for file '{self.__file}'")
def get_container_format(self) -> str:
"""
Returns the container format of the video (e.g., 'mp4', 'mov', 'avi').
Returns:
Container format as string
"""
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0)
with av.open(self.__file, mode='r') as container:
return container.format.name
def get_components_internal(self, container: InputContainer) -> VideoComponents:
# Get video frames
frames = []
for frame in container.decode(video=0):
img = frame.to_ndarray(format='rgb24') # shape: (H, W, 3)
img = torch.from_numpy(img) / 255.0 # shape: (H, W, 3)
frames.append(img)
images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 3, 0, 0)
# Get frame rate
video_stream = next(s for s in container.streams if s.type == 'video')
frame_rate = Fraction(video_stream.average_rate) if video_stream and video_stream.average_rate else Fraction(1)
# Get audio if available
audio = None
try:
container.seek(0) # Reset the container to the beginning
for stream in container.streams:
if stream.type != 'audio':
continue
assert isinstance(stream, av.AudioStream)
audio_frames = []
for packet in container.demux(stream):
for frame in packet.decode():
assert isinstance(frame, av.AudioFrame)
audio_frames.append(frame.to_ndarray()) # shape: (channels, samples)
if len(audio_frames) > 0:
audio_data = np.concatenate(audio_frames, axis=1) # shape: (channels, total_samples)
audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) # shape: (1, channels, total_samples)
audio = AudioInput({
"waveform": audio_tensor,
"sample_rate": int(stream.sample_rate) if stream.sample_rate else 1,
})
except StopIteration:
pass # No audio stream
metadata = container.metadata
return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata)
def get_components(self) -> VideoComponents:
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0) # Reset the BytesIO object to the beginning
with av.open(self.__file, mode='r') as container:
return self.get_components_internal(container)
raise ValueError(f"No video stream found in file '{self.__file}'")
def save_to(
self,
path: str | io.BytesIO,
format: VideoContainer = VideoContainer.AUTO,
codec: VideoCodec = VideoCodec.AUTO,
metadata: Optional[dict] = None
):
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0) # Reset the BytesIO object to the beginning
with av.open(self.__file, mode='r') as container:
container_format = container.format.name
video_encoding = container.streams.video[0].codec.name if len(container.streams.video) > 0 else None
reuse_streams = True
if format != VideoContainer.AUTO and format not in container_format.split(","):
reuse_streams = False
if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None:
reuse_streams = False
if not reuse_streams:
components = self.get_components_internal(container)
video = VideoFromComponents(components)
return video.save_to(
path,
format=format,
codec=codec,
metadata=metadata
)
streams = container.streams
open_kwargs = get_open_write_kwargs(path, container_format, format)
with av.open(path, **open_kwargs) as output_container:
# Copy over the original metadata
for key, value in container.metadata.items():
if metadata is None or key not in metadata:
output_container.metadata[key] = value
# Add our new metadata
if metadata is not None:
for key, value in metadata.items():
if isinstance(value, str):
output_container.metadata[key] = value
else:
output_container.metadata[key] = json.dumps(value)
# Add streams to the new container
stream_map = {}
for stream in streams:
if isinstance(stream, (av.VideoStream, av.AudioStream, SubtitleStream)):
out_stream = output_container.add_stream_from_template(template=stream, opaque=True)
stream_map[stream] = out_stream
# Write packets to the new container
for packet in container.demux():
if packet.stream in stream_map and packet.dts is not None:
packet.stream = stream_map[packet.stream]
output_container.mux(packet)
class VideoFromComponents(VideoInput):
"""
Class representing video input from tensors.
"""
def __init__(self, components: VideoComponents):
self.__components = components
def get_components(self) -> VideoComponents:
return VideoComponents(
images=self.__components.images,
audio=self.__components.audio,
frame_rate=self.__components.frame_rate
)
def save_to(
self,
path: str,
format: VideoContainer = VideoContainer.AUTO,
codec: VideoCodec = VideoCodec.AUTO,
metadata: Optional[dict] = None
):
if format != VideoContainer.AUTO and format != VideoContainer.MP4:
raise ValueError("Only MP4 format is supported for now")
if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
raise ValueError("Only H264 codec is supported for now")
with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}) as output:
# Add metadata before writing any streams
if metadata is not None:
for key, value in metadata.items():
output.metadata[key] = json.dumps(value)
frame_rate = Fraction(round(self.__components.frame_rate * 1000), 1000)
# Create a video stream
video_stream = output.add_stream('h264', rate=frame_rate)
video_stream.width = self.__components.images.shape[2]
video_stream.height = self.__components.images.shape[1]
video_stream.pix_fmt = 'yuv420p'
# Create an audio stream
audio_sample_rate = 1
audio_stream: Optional[av.AudioStream] = None
if self.__components.audio:
audio_sample_rate = int(self.__components.audio['sample_rate'])
audio_stream = output.add_stream('aac', rate=audio_sample_rate)
audio_stream.sample_rate = audio_sample_rate
audio_stream.format = 'fltp'
# Encode video
for i, frame in enumerate(self.__components.images):
img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3)
frame = av.VideoFrame.from_ndarray(img, format='rgb24')
frame = frame.reformat(format='yuv420p') # Convert to YUV420P as required by h264
packet = video_stream.encode(frame)
output.mux(packet)
# Flush video
packet = video_stream.encode(None)
output.mux(packet)
if audio_stream and self.__components.audio:
# Encode audio
samples_per_frame = int(audio_sample_rate / frame_rate)
num_frames = self.__components.audio['waveform'].shape[2] // samples_per_frame
for i in range(num_frames):
start = i * samples_per_frame
end = start + samples_per_frame
# TODO(Feature) - Add support for stereo audio
chunk = (
self.__components.audio["waveform"][0, 0, start:end]
.unsqueeze(0)
.contiguous()
.numpy()
)
audio_frame = av.AudioFrame.from_ndarray(chunk, format='fltp', layout='mono')
audio_frame.sample_rate = audio_sample_rate
audio_frame.pts = i * samples_per_frame
for packet in audio_stream.encode(audio_frame):
output.mux(packet)
# Flush audio
for packet in audio_stream.encode(None):
output.mux(packet)

1618
comfy_api/latest/_io.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,72 @@
from __future__ import annotations
import comfy.utils
import folder_paths
import logging
from abc import ABC, abstractmethod
from typing import Any
import torch
class ResourceKey(ABC):
Type = Any
def __init__(self):
...
class TorchDictFolderFilename(ResourceKey):
'''Key for requesting a torch file via file_name from a folder category.'''
Type = dict[str, torch.Tensor]
def __init__(self, folder_name: str, file_name: str):
self.folder_name = folder_name
self.file_name = file_name
def __hash__(self):
return hash((self.folder_name, self.file_name))
def __eq__(self, other: object) -> bool:
if not isinstance(other, TorchDictFolderFilename):
return False
return self.folder_name == other.folder_name and self.file_name == other.file_name
def __str__(self):
return f"{self.folder_name} -> {self.file_name}"
class Resources(ABC):
def __init__(self):
...
@abstractmethod
def get(self, key: ResourceKey, default: Any=...) -> Any:
pass
class ResourcesLocal(Resources):
def __init__(self):
super().__init__()
self.local_resources: dict[ResourceKey, Any] = {}
def get(self, key: ResourceKey, default: Any=...) -> Any:
cached = self.local_resources.get(key, None)
if cached is not None:
logging.info(f"Using cached resource '{key}'")
return cached
logging.info(f"Loading resource '{key}'")
to_return = None
if isinstance(key, TorchDictFolderFilename):
if default is ...:
to_return = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise(key.folder_name, key.file_name), safe_load=True)
else:
full_path = folder_paths.get_full_path(key.folder_name, key.file_name)
if full_path is not None:
to_return = comfy.utils.load_torch_file(full_path, safe_load=True)
if to_return is not None:
self.local_resources[key] = to_return
return to_return
if default is not ...:
return default
raise Exception(f"Unsupported resource key type: {type(key)}")
class _RESOURCES:
ResourceKey = ResourceKey
TorchDictFolderFilename = TorchDictFolderFilename
Resources = Resources
ResourcesLocal = ResourcesLocal

457
comfy_api/latest/_ui.py Normal file
View File

@@ -0,0 +1,457 @@
from __future__ import annotations
import json
import os
import random
from io import BytesIO
from typing import Type
import av
import numpy as np
import torch
import torchaudio
from PIL import Image as PILImage
from PIL.PngImagePlugin import PngInfo
import folder_paths
# used for image preview
from comfy.cli_args import args
from comfy_api.latest._io import ComfyNode, FolderType, Image, _UIOutput
class SavedResult(dict):
def __init__(self, filename: str, subfolder: str, type: FolderType):
super().__init__(filename=filename, subfolder=subfolder,type=type.value)
@property
def filename(self) -> str:
return self["filename"]
@property
def subfolder(self) -> str:
return self["subfolder"]
@property
def type(self) -> FolderType:
return FolderType(self["type"])
class SavedImages(_UIOutput):
"""A UI output class to represent one or more saved images, potentially animated."""
def __init__(self, results: list[SavedResult], is_animated: bool = False):
super().__init__()
self.results = results
self.is_animated = is_animated
def as_dict(self) -> dict:
data = {"images": self.results}
if self.is_animated:
data["animated"] = (True,)
return data
class SavedAudios(_UIOutput):
"""UI wrapper around one or more audio files on disk (FLAC / MP3 / Opus)."""
def __init__(self, results: list[SavedResult]):
super().__init__()
self.results = results
def as_dict(self) -> dict:
return {"audio": self.results}
def _get_directory_by_folder_type(folder_type: FolderType) -> str:
if folder_type == FolderType.input:
return folder_paths.get_input_directory()
if folder_type == FolderType.output:
return folder_paths.get_output_directory()
return folder_paths.get_temp_directory()
class ImageSaveHelper:
"""A helper class with static methods to handle image saving and metadata."""
@staticmethod
def _convert_tensor_to_pil(image_tensor: torch.Tensor) -> PILImage.Image:
"""Converts a single torch tensor to a PIL Image."""
return PILImage.fromarray(np.clip(255.0 * image_tensor.cpu().numpy(), 0, 255).astype(np.uint8))
@staticmethod
def _create_png_metadata(cls: Type[ComfyNode] | None) -> PngInfo | None:
"""Creates a PngInfo object with prompt and extra_pnginfo."""
if args.disable_metadata or cls is None or not cls.hidden:
return None
metadata = PngInfo()
if cls.hidden.prompt:
metadata.add_text("prompt", json.dumps(cls.hidden.prompt))
if cls.hidden.extra_pnginfo:
for x in cls.hidden.extra_pnginfo:
metadata.add_text(x, json.dumps(cls.hidden.extra_pnginfo[x]))
return metadata
@staticmethod
def _create_animated_png_metadata(cls: Type[ComfyNode] | None) -> PngInfo | None:
"""Creates a PngInfo object with prompt and extra_pnginfo for animated PNGs (APNG)."""
if args.disable_metadata or cls is None or not cls.hidden:
return None
metadata = PngInfo()
if cls.hidden.prompt:
metadata.add(
b"comf",
"prompt".encode("latin-1", "strict")
+ b"\0"
+ json.dumps(cls.hidden.prompt).encode("latin-1", "strict"),
after_idat=True,
)
if cls.hidden.extra_pnginfo:
for x in cls.hidden.extra_pnginfo:
metadata.add(
b"comf",
x.encode("latin-1", "strict")
+ b"\0"
+ json.dumps(cls.hidden.extra_pnginfo[x]).encode("latin-1", "strict"),
after_idat=True,
)
return metadata
@staticmethod
def _create_webp_metadata(pil_image: PILImage.Image, cls: Type[ComfyNode] | None) -> PILImage.Exif:
"""Creates EXIF metadata bytes for WebP images."""
exif_data = pil_image.getexif()
if args.disable_metadata or cls is None or cls.hidden is None:
return exif_data
if cls.hidden.prompt is not None:
exif_data[0x0110] = "prompt:{}".format(json.dumps(cls.hidden.prompt)) # EXIF 0x0110 = Model
if cls.hidden.extra_pnginfo is not None:
inital_exif_tag = 0x010F # EXIF 0x010f = Make
for key, value in cls.hidden.extra_pnginfo.items():
exif_data[inital_exif_tag] = "{}:{}".format(key, json.dumps(value))
inital_exif_tag -= 1
return exif_data
@staticmethod
def save_images(
images, filename_prefix: str, folder_type: FolderType, cls: Type[ComfyNode] | None, compress_level = 4,
) -> list[SavedResult]:
"""Saves a batch of images as individual PNG files."""
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
filename_prefix, _get_directory_by_folder_type(folder_type), images[0].shape[1], images[0].shape[0]
)
results = []
metadata = ImageSaveHelper._create_png_metadata(cls)
for batch_number, image_tensor in enumerate(images):
img = ImageSaveHelper._convert_tensor_to_pil(image_tensor)
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
file = f"{filename_with_batch_num}_{counter:05}_.png"
img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=compress_level)
results.append(SavedResult(file, subfolder, folder_type))
counter += 1
return results
@staticmethod
def get_save_images_ui(images, filename_prefix: str, cls: Type[ComfyNode] | None, compress_level=4) -> SavedImages:
"""Saves a batch of images and returns a UI object for the node output."""
return SavedImages(
ImageSaveHelper.save_images(
images,
filename_prefix=filename_prefix,
folder_type=FolderType.output,
cls=cls,
compress_level=compress_level,
)
)
@staticmethod
def save_animated_png(
images, filename_prefix: str, folder_type: FolderType, cls: Type[ComfyNode] | None, fps: float, compress_level: int
) -> SavedResult:
"""Saves a batch of images as a single animated PNG."""
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
filename_prefix, _get_directory_by_folder_type(folder_type), images[0].shape[1], images[0].shape[0]
)
pil_images = [ImageSaveHelper._convert_tensor_to_pil(img) for img in images]
metadata = ImageSaveHelper._create_animated_png_metadata(cls)
file = f"{filename}_{counter:05}_.png"
save_path = os.path.join(full_output_folder, file)
pil_images[0].save(
save_path,
pnginfo=metadata,
compress_level=compress_level,
save_all=True,
duration=int(1000.0 / fps),
append_images=pil_images[1:],
)
return SavedResult(file, subfolder, folder_type)
@staticmethod
def get_save_animated_png_ui(
images, filename_prefix: str, cls: Type[ComfyNode] | None, fps: float, compress_level: int
) -> SavedImages:
"""Saves an animated PNG and returns a UI object for the node output."""
result = ImageSaveHelper.save_animated_png(
images,
filename_prefix=filename_prefix,
folder_type=FolderType.output,
cls=cls,
fps=fps,
compress_level=compress_level,
)
return SavedImages([result], is_animated=len(images) > 1)
@staticmethod
def save_animated_webp(
images,
filename_prefix: str,
folder_type: FolderType,
cls: Type[ComfyNode] | None,
fps: float,
lossless: bool,
quality: int,
method: int,
) -> SavedResult:
"""Saves a batch of images as a single animated WebP."""
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
filename_prefix, _get_directory_by_folder_type(folder_type), images[0].shape[1], images[0].shape[0]
)
pil_images = [ImageSaveHelper._convert_tensor_to_pil(img) for img in images]
pil_exif = ImageSaveHelper._create_webp_metadata(pil_images[0], cls)
file = f"{filename}_{counter:05}_.webp"
pil_images[0].save(
os.path.join(full_output_folder, file),
save_all=True,
duration=int(1000.0 / fps),
append_images=pil_images[1:],
exif=pil_exif,
lossless=lossless,
quality=quality,
method=method,
)
return SavedResult(file, subfolder, folder_type)
@staticmethod
def get_save_animated_webp_ui(
images,
filename_prefix: str,
cls: Type[ComfyNode] | None,
fps: float,
lossless: bool,
quality: int,
method: int,
) -> SavedImages:
"""Saves an animated WebP and returns a UI object for the node output."""
result = ImageSaveHelper.save_animated_webp(
images,
filename_prefix=filename_prefix,
folder_type=FolderType.output,
cls=cls,
fps=fps,
lossless=lossless,
quality=quality,
method=method,
)
return SavedImages([result], is_animated=len(images) > 1)
class AudioSaveHelper:
"""A helper class with static methods to handle audio saving and metadata."""
_OPUS_RATES = [8000, 12000, 16000, 24000, 48000]
@staticmethod
def save_audio(
audio: dict,
filename_prefix: str,
folder_type: FolderType,
cls: Type[ComfyNode] | None,
format: str = "flac",
quality: str = "128k",
) -> list[SavedResult]:
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
filename_prefix, _get_directory_by_folder_type(folder_type)
)
metadata = {}
if not args.disable_metadata and cls is not None:
if cls.hidden.prompt is not None:
metadata["prompt"] = json.dumps(cls.hidden.prompt)
if cls.hidden.extra_pnginfo is not None:
for x in cls.hidden.extra_pnginfo:
metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x])
results = []
for batch_number, waveform in enumerate(audio["waveform"].cpu()):
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
file = f"{filename_with_batch_num}_{counter:05}_.{format}"
output_path = os.path.join(full_output_folder, file)
# Use original sample rate initially
sample_rate = audio["sample_rate"]
# Handle Opus sample rate requirements
if format == "opus":
if sample_rate > 48000:
sample_rate = 48000
elif sample_rate not in AudioSaveHelper._OPUS_RATES:
# Find the next highest supported rate
for rate in sorted(AudioSaveHelper._OPUS_RATES):
if rate > sample_rate:
sample_rate = rate
break
if sample_rate not in AudioSaveHelper._OPUS_RATES: # Fallback if still not supported
sample_rate = 48000
# Resample if necessary
if sample_rate != audio["sample_rate"]:
waveform = torchaudio.functional.resample(waveform, audio["sample_rate"], sample_rate)
# Create output with specified format
output_buffer = BytesIO()
output_container = av.open(output_buffer, mode="w", format=format)
# Set metadata on the container
for key, value in metadata.items():
output_container.metadata[key] = value
# Set up the output stream with appropriate properties
if format == "opus":
out_stream = output_container.add_stream("libopus", rate=sample_rate)
if quality == "64k":
out_stream.bit_rate = 64000
elif quality == "96k":
out_stream.bit_rate = 96000
elif quality == "128k":
out_stream.bit_rate = 128000
elif quality == "192k":
out_stream.bit_rate = 192000
elif quality == "320k":
out_stream.bit_rate = 320000
elif format == "mp3":
out_stream = output_container.add_stream("libmp3lame", rate=sample_rate)
if quality == "V0":
# TODO i would really love to support V3 and V5 but there doesn't seem to be a way to set the qscale level, the property below is a bool
out_stream.codec_context.qscale = 1
elif quality == "128k":
out_stream.bit_rate = 128000
elif quality == "320k":
out_stream.bit_rate = 320000
else: # format == "flac":
out_stream = output_container.add_stream("flac", rate=sample_rate)
frame = av.AudioFrame.from_ndarray(
waveform.movedim(0, 1).reshape(1, -1).float().numpy(),
format="flt",
layout="mono" if waveform.shape[0] == 1 else "stereo",
)
frame.sample_rate = sample_rate
frame.pts = 0
output_container.mux(out_stream.encode(frame))
# Flush encoder
output_container.mux(out_stream.encode(None))
# Close containers
output_container.close()
# Write the output to file
output_buffer.seek(0)
with open(output_path, "wb") as f:
f.write(output_buffer.getbuffer())
results.append(SavedResult(file, subfolder, folder_type))
counter += 1
return results
@staticmethod
def get_save_audio_ui(
audio, filename_prefix: str, cls: Type[ComfyNode] | None, format: str = "flac", quality: str = "128k",
) -> SavedAudios:
"""Save and instantly wrap for UI."""
return SavedAudios(
AudioSaveHelper.save_audio(
audio,
filename_prefix=filename_prefix,
folder_type=FolderType.output,
cls=cls,
format=format,
quality=quality,
)
)
class PreviewImage(_UIOutput):
def __init__(self, image: Image.Type, animated: bool = False, cls: Type[ComfyNode] = None, **kwargs):
self.values = ImageSaveHelper.save_images(
image,
filename_prefix="ComfyUI_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for _ in range(5)),
folder_type=FolderType.temp,
cls=cls,
compress_level=1,
)
self.animated = animated
def as_dict(self):
return {
"images": self.values,
"animated": (self.animated,)
}
class PreviewMask(PreviewImage):
def __init__(self, mask: PreviewMask.Type, animated: bool=False, cls: ComfyNode=None, **kwargs):
preview = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
super().__init__(preview, animated, cls, **kwargs)
class PreviewAudio(_UIOutput):
def __init__(self, audio: dict, cls: Type[ComfyNode] = None, **kwargs):
self.values = AudioSaveHelper.save_audio(
audio,
filename_prefix="ComfyUI_temp_" + "".join(random.choice("abcdefghijklmnopqrstuvwxyz") for _ in range(5)),
folder_type=FolderType.temp,
cls=cls,
format="flac",
quality="128k",
)
def as_dict(self) -> dict:
return {"audio": self.values}
class PreviewVideo(_UIOutput):
def __init__(self, values: list[SavedResult | dict], **kwargs):
self.values = values
def as_dict(self):
return {"images": self.values, "animated": (True,)}
class PreviewUI3D(_UIOutput):
def __init__(self, model_file, camera_info, **kwargs):
self.model_file = model_file
self.camera_info = camera_info
def as_dict(self):
return {"result": [self.model_file, self.camera_info]}
class PreviewText(_UIOutput):
def __init__(self, value: str, **kwargs):
self.value = value
def as_dict(self):
return {"text": (self.value,)}
class _UI:
SavedResult = SavedResult
SavedImages = SavedImages
SavedAudios = SavedAudios
ImageSaveHelper = ImageSaveHelper
AudioSaveHelper = AudioSaveHelper
PreviewImage = PreviewImage
PreviewMask = PreviewMask
PreviewAudio = PreviewAudio
PreviewVideo = PreviewVideo
PreviewUI3D = PreviewUI3D
PreviewText = PreviewText

View File

@@ -0,0 +1,8 @@
from .video_types import VideoContainer, VideoCodec, VideoComponents
__all__ = [
# Utility Types
"VideoContainer",
"VideoCodec",
"VideoComponents",
]

View File

@@ -0,0 +1,52 @@
from __future__ import annotations
from dataclasses import dataclass
from enum import Enum
from fractions import Fraction
from typing import Optional
from comfy_api.latest._input import ImageInput, AudioInput
class VideoCodec(str, Enum):
AUTO = "auto"
H264 = "h264"
@classmethod
def as_input(cls) -> list[str]:
"""
Returns a list of codec names that can be used as node input.
"""
return [member.value for member in cls]
class VideoContainer(str, Enum):
AUTO = "auto"
MP4 = "mp4"
@classmethod
def as_input(cls) -> list[str]:
"""
Returns a list of container names that can be used as node input.
"""
return [member.value for member in cls]
@classmethod
def get_extension(cls, value) -> str:
"""
Returns the file extension for the container.
"""
if isinstance(value, str):
value = cls(value)
if value == VideoContainer.MP4 or value == VideoContainer.AUTO:
return "mp4"
return ""
@dataclass
class VideoComponents:
"""
Dataclass representing the components of a video.
"""
images: ImageInput
frame_rate: Fraction
audio: Optional[AudioInput] = None
metadata: Optional[dict] = None

View File

@@ -0,0 +1,20 @@
from typing import Any, Dict, List, Optional, Tuple, Union, Set, Sequence, cast, NamedTuple
from comfy_api.latest import ComfyAPI_latest
from PIL.Image import Image
from torch import Tensor
class ComfyAPISyncStub:
def __init__(self) -> None: ...
class ExecutionSync:
def __init__(self) -> None: ...
"""
Update the progress bar displayed in the ComfyUI interface.
This function allows custom nodes and API calls to report their progress
back to the user interface, providing visual feedback during long operations.
Migration from previous API: comfy.utils.PROGRESS_BAR_HOOK
"""
def set_progress(self, value: float, max_value: float, node_id: Union[str, None] = None, preview_image: Union[Image, Tensor, None] = None, ignore_size_limit: bool = False) -> None: ...
execution: ExecutionSync

8
comfy_api/util.py Normal file
View File

@@ -0,0 +1,8 @@
# This file only exists for backwards compatibility.
from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents
__all__ = [
"VideoCodec",
"VideoContainer",
"VideoComponents",
]

View File

@@ -1,7 +1,7 @@
from .video_types import VideoContainer, VideoCodec, VideoComponents
# This file only exists for backwards compatibility.
from comfy_api.latest._util import VideoContainer, VideoCodec, VideoComponents
__all__ = [
# Utility Types
"VideoContainer",
"VideoCodec",
"VideoComponents",

View File

@@ -1,51 +1,12 @@
from __future__ import annotations
from dataclasses import dataclass
from enum import Enum
from fractions import Fraction
from typing import Optional
from comfy_api.input import ImageInput, AudioInput
class VideoCodec(str, Enum):
AUTO = "auto"
H264 = "h264"
@classmethod
def as_input(cls) -> list[str]:
"""
Returns a list of codec names that can be used as node input.
"""
return [member.value for member in cls]
class VideoContainer(str, Enum):
AUTO = "auto"
MP4 = "mp4"
@classmethod
def as_input(cls) -> list[str]:
"""
Returns a list of container names that can be used as node input.
"""
return [member.value for member in cls]
@classmethod
def get_extension(cls, value) -> str:
"""
Returns the file extension for the container.
"""
if isinstance(value, str):
value = cls(value)
if value == VideoContainer.MP4 or value == VideoContainer.AUTO:
return "mp4"
return ""
@dataclass
class VideoComponents:
"""
Dataclass representing the components of a video.
"""
images: ImageInput
frame_rate: Fraction
audio: Optional[AudioInput] = None
metadata: Optional[dict] = None
# This file only exists for backwards compatibility.
from comfy_api.latest._util.video_types import (
VideoContainer,
VideoCodec,
VideoComponents,
)
__all__ = [
"VideoContainer",
"VideoCodec",
"VideoComponents",
]

View File

@@ -0,0 +1,42 @@
from comfy_api.v0_0_2 import (
ComfyAPIAdapter_v0_0_2,
Input as Input_v0_0_2,
InputImpl as InputImpl_v0_0_2,
Types as Types_v0_0_2,
)
from typing import Type, TYPE_CHECKING
from comfy_api.internal.async_to_sync import create_sync_class
# This version only exists to serve as a template for future version adapters.
# There is no reason anyone should ever use it.
class ComfyAPIAdapter_v0_0_1(ComfyAPIAdapter_v0_0_2):
VERSION = "0.0.1"
STABLE = True
class Input(Input_v0_0_2):
pass
class InputImpl(InputImpl_v0_0_2):
pass
class Types(Types_v0_0_2):
pass
ComfyAPI = ComfyAPIAdapter_v0_0_1
# Create a synchronous version of the API
if TYPE_CHECKING:
from comfy_api.v0_0_1.generated.ComfyAPISyncStub import ComfyAPISyncStub # type: ignore
ComfyAPISync: Type[ComfyAPISyncStub]
ComfyAPISync = create_sync_class(ComfyAPIAdapter_v0_0_1)
__all__ = [
"ComfyAPI",
"ComfyAPISync",
"Input",
"InputImpl",
"Types",
]

View File

@@ -0,0 +1,20 @@
from typing import Any, Dict, List, Optional, Tuple, Union, Set, Sequence, cast, NamedTuple
from comfy_api.v0_0_1 import ComfyAPIAdapter_v0_0_1
from PIL.Image import Image
from torch import Tensor
class ComfyAPISyncStub:
def __init__(self) -> None: ...
class ExecutionSync:
def __init__(self) -> None: ...
"""
Update the progress bar displayed in the ComfyUI interface.
This function allows custom nodes and API calls to report their progress
back to the user interface, providing visual feedback during long operations.
Migration from previous API: comfy.utils.PROGRESS_BAR_HOOK
"""
def set_progress(self, value: float, max_value: float, node_id: Union[str, None] = None, preview_image: Union[Image, Tensor, None] = None, ignore_size_limit: bool = False) -> None: ...
execution: ExecutionSync

View File

@@ -0,0 +1,45 @@
from comfy_api.latest import (
ComfyAPI_latest,
Input as Input_latest,
InputImpl as InputImpl_latest,
Types as Types_latest,
)
from typing import Type, TYPE_CHECKING
from comfy_api.internal.async_to_sync import create_sync_class
from comfy_api.latest import io, ui, ComfyExtension #noqa: F401
class ComfyAPIAdapter_v0_0_2(ComfyAPI_latest):
VERSION = "0.0.2"
STABLE = False
class Input(Input_latest):
pass
class InputImpl(InputImpl_latest):
pass
class Types(Types_latest):
pass
ComfyAPI = ComfyAPIAdapter_v0_0_2
# Create a synchronous version of the API
if TYPE_CHECKING:
from comfy_api.v0_0_2.generated.ComfyAPISyncStub import ComfyAPISyncStub # type: ignore
ComfyAPISync: Type[ComfyAPISyncStub]
ComfyAPISync = create_sync_class(ComfyAPIAdapter_v0_0_2)
__all__ = [
"ComfyAPI",
"ComfyAPISync",
"Input",
"InputImpl",
"Types",
"ComfyExtension",
]

View File

@@ -0,0 +1,20 @@
from typing import Any, Dict, List, Optional, Tuple, Union, Set, Sequence, cast, NamedTuple
from comfy_api.v0_0_2 import ComfyAPIAdapter_v0_0_2
from PIL.Image import Image
from torch import Tensor
class ComfyAPISyncStub:
def __init__(self) -> None: ...
class ExecutionSync:
def __init__(self) -> None: ...
"""
Update the progress bar displayed in the ComfyUI interface.
This function allows custom nodes and API calls to report their progress
back to the user interface, providing visual feedback during long operations.
Migration from previous API: comfy.utils.PROGRESS_BAR_HOOK
"""
def set_progress(self, value: float, max_value: float, node_id: Union[str, None] = None, preview_image: Union[Image, Tensor, None] = None, ignore_size_limit: bool = False) -> None: ...
execution: ExecutionSync

12
comfy_api/version_list.py Normal file
View File

@@ -0,0 +1,12 @@
from comfy_api.latest import ComfyAPI_latest
from comfy_api.v0_0_2 import ComfyAPIAdapter_v0_0_2
from comfy_api.v0_0_1 import ComfyAPIAdapter_v0_0_1
from comfy_api.internal import ComfyAPIBase
from typing import List, Type
supported_versions: List[Type[ComfyAPIBase]] = [
ComfyAPI_latest,
ComfyAPIAdapter_v0_0_2,
ComfyAPIAdapter_v0_0_1,
]

View File

@@ -2,7 +2,7 @@
## Introduction
Below are a collection of nodes that work by calling external APIs. More information available in our [docs](https://docs.comfy.org/tutorials/api-nodes/overview#api-nodes).
Below are a collection of nodes that work by calling external APIs. More information available in our [docs](https://docs.comfy.org/tutorials/api-nodes/overview).
## Development

File diff suppressed because it is too large Load Diff

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
import os
import datetime
import json

View File

@@ -127,7 +127,7 @@ class TripoTextToModelRequest(BaseModel):
type: TripoTaskType = Field(TripoTaskType.TEXT_TO_MODEL, description='Type of task')
prompt: str = Field(..., description='The text prompt describing the model to generate', max_length=1024)
negative_prompt: Optional[str] = Field(None, description='The negative text prompt', max_length=1024)
model_version: Optional[TripoModelVersion] = TripoModelVersion.V2_5
model_version: Optional[TripoModelVersion] = TripoModelVersion.v2_5_20250123
face_limit: Optional[int] = Field(None, description='The number of faces to limit the generation to')
texture: Optional[bool] = Field(True, description='Whether to apply texture to the generated model')
pbr: Optional[bool] = Field(True, description='Whether to apply PBR to the generated model')

View File

@@ -2,6 +2,8 @@
API Nodes for Gemini Multimodal LLM Usage via Remote API
See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference
"""
from __future__ import annotations
import os
from enum import Enum
@@ -406,7 +408,7 @@ class GeminiInputFiles(ComfyNodeABC):
def create_file_part(self, file_path: str) -> GeminiPart:
mime_type = (
GeminiMimeType.pdf
GeminiMimeType.application_pdf
if file_path.endswith(".pdf")
else GeminiMimeType.text_plain
)

View File

@@ -132,6 +132,8 @@ def poll_until_finished(
result_url_extractor=result_url_extractor,
estimated_duration=estimated_duration,
node_id=node_id,
poll_interval=16.0,
max_poll_attempts=256,
).execute()

View File

@@ -0,0 +1,743 @@
import logging
from typing import Any, Callable, Optional, TypeVar
import random
import torch
from comfy_api_nodes.util.validation_utils import (
get_image_dimensions,
validate_image_dimensions,
)
from comfy_api_nodes.apis import (
MoonvalleyTextToVideoRequest,
MoonvalleyTextToVideoInferenceParams,
MoonvalleyVideoToVideoInferenceParams,
MoonvalleyVideoToVideoRequest,
MoonvalleyPromptResponse,
)
from comfy_api_nodes.apis.client import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
PollingOperation,
EmptyRequest,
)
from comfy_api_nodes.apinode_utils import (
download_url_to_video_output,
upload_images_to_comfyapi,
upload_video_to_comfyapi,
)
from comfy_api_nodes.mapper_utils import model_field_to_node_input
from comfy_api.input.video_types import VideoInput
from comfy.comfy_types.node_typing import IO
from comfy_api.input_impl import VideoFromFile
import av
import io
API_UPLOADS_ENDPOINT = "/proxy/moonvalley/uploads"
API_PROMPTS_ENDPOINT = "/proxy/moonvalley/prompts"
API_VIDEO2VIDEO_ENDPOINT = "/proxy/moonvalley/prompts/video-to-video"
API_TXT2VIDEO_ENDPOINT = "/proxy/moonvalley/prompts/text-to-video"
API_IMG2VIDEO_ENDPOINT = "/proxy/moonvalley/prompts/image-to-video"
MIN_WIDTH = 300
MIN_HEIGHT = 300
MAX_WIDTH = 10000
MAX_HEIGHT = 10000
MIN_VID_WIDTH = 300
MIN_VID_HEIGHT = 300
MAX_VID_WIDTH = 10000
MAX_VID_HEIGHT = 10000
MAX_VIDEO_SIZE = 1024 * 1024 * 1024 # 1 GB max for in-memory video processing
MOONVALLEY_MAREY_MAX_PROMPT_LENGTH = 5000
R = TypeVar("R")
class MoonvalleyApiError(Exception):
"""Base exception for Moonvalley API errors."""
pass
def is_valid_task_creation_response(response: MoonvalleyPromptResponse) -> bool:
"""Verifies that the initial response contains a task ID."""
return bool(response.id)
def validate_task_creation_response(response) -> None:
if not is_valid_task_creation_response(response):
error_msg = f"Moonvalley Marey API: Initial request failed. Code: {response.code}, Message: {response.message}, Data: {response}"
logging.error(error_msg)
raise MoonvalleyApiError(error_msg)
def get_video_from_response(response):
video = response.output_url
logging.info(
"Moonvalley Marey API: Task %s succeeded. Video URL: %s", response.id, video
)
return video
def get_video_url_from_response(response) -> Optional[str]:
"""Returns the first video url from the Moonvalley video generation task result.
Will not raise an error if the response is not valid.
"""
if response:
return str(get_video_from_response(response))
else:
return None
def poll_until_finished(
auth_kwargs: dict[str, str],
api_endpoint: ApiEndpoint[Any, R],
result_url_extractor: Optional[Callable[[R], str]] = None,
node_id: Optional[str] = None,
) -> R:
"""Polls the Moonvalley API endpoint until the task reaches a terminal state, then returns the response."""
return PollingOperation(
poll_endpoint=api_endpoint,
completed_statuses=[
"completed",
],
max_poll_attempts=240, # 64 minutes with 16s interval
poll_interval=16.0,
failed_statuses=["error"],
status_extractor=lambda response: (
response.status if response and response.status else None
),
auth_kwargs=auth_kwargs,
result_url_extractor=result_url_extractor,
node_id=node_id,
).execute()
def validate_prompts(
prompt: str, negative_prompt: str, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH
):
"""Verifies that the prompt isn't empty and that neither prompt is too long."""
if not prompt:
raise ValueError("Positive prompt is empty")
if len(prompt) > max_length:
raise ValueError(f"Positive prompt is too long: {len(prompt)} characters")
if negative_prompt and len(negative_prompt) > max_length:
raise ValueError(
f"Negative prompt is too long: {len(negative_prompt)} characters"
)
return True
def validate_input_media(width, height, with_frame_conditioning, num_frames_in=None):
# inference validation
# T = num_frames
# in all cases, the following must be true: T divisible by 16 and H,W by 8. in addition...
# with image conditioning: H*W must be divisible by 8192
# without image conditioning: T divisible by 32
if num_frames_in and not num_frames_in % 16 == 0:
return False, ("The input video total frame count must be divisible by 16!")
if height % 8 != 0 or width % 8 != 0:
return False, (
f"Height ({height}) and width ({width}) must be " "divisible by 8"
)
if with_frame_conditioning:
if (height * width) % 8192 != 0:
return False, (
f"Height * width ({height * width}) must be "
"divisible by 8192 for frame conditioning"
)
else:
if num_frames_in and not num_frames_in % 32 == 0:
return False, ("The input video total frame count must be divisible by 32!")
def validate_input_image(
image: torch.Tensor, with_frame_conditioning: bool = False
) -> None:
"""
Validates the input image adheres to the expectations of the API:
- The image resolution should not be less than 300*300px
- The aspect ratio of the image should be between 1:2.5 ~ 2.5:1
"""
height, width = get_image_dimensions(image)
validate_input_media(width, height, with_frame_conditioning)
validate_image_dimensions(
image, min_width=300, min_height=300, max_height=MAX_HEIGHT, max_width=MAX_WIDTH
)
def validate_video_to_video_input(video: VideoInput) -> VideoInput:
"""
Validates and processes video input for Moonvalley Video-to-Video generation.
Args:
video: Input video to validate
Returns:
Validated and potentially trimmed video
Raises:
ValueError: If video doesn't meet requirements
MoonvalleyApiError: If video duration is too short
"""
width, height = _get_video_dimensions(video)
_validate_video_dimensions(width, height)
_validate_container_format(video)
return _validate_and_trim_duration(video)
def _get_video_dimensions(video: VideoInput) -> tuple[int, int]:
"""Extracts video dimensions with error handling."""
try:
return video.get_dimensions()
except Exception as e:
logging.error("Error getting dimensions of video: %s", e)
raise ValueError(f"Cannot get video dimensions: {e}") from e
def _validate_video_dimensions(width: int, height: int) -> None:
"""Validates video dimensions meet Moonvalley V2V requirements."""
supported_resolutions = {
(1920, 1080), (1080, 1920), (1152, 1152),
(1536, 1152), (1152, 1536)
}
if (width, height) not in supported_resolutions:
supported_list = ', '.join([f'{w}x{h}' for w, h in sorted(supported_resolutions)])
raise ValueError(f"Resolution {width}x{height} not supported. Supported: {supported_list}")
def _validate_container_format(video: VideoInput) -> None:
"""Validates video container format is MP4."""
container_format = video.get_container_format()
if container_format not in ['mp4', 'mov,mp4,m4a,3gp,3g2,mj2']:
raise ValueError(f"Only MP4 container format supported. Got: {container_format}")
def _validate_and_trim_duration(video: VideoInput) -> VideoInput:
"""Validates video duration and trims to 5 seconds if needed."""
duration = video.get_duration()
_validate_minimum_duration(duration)
return _trim_if_too_long(video, duration)
def _validate_minimum_duration(duration: float) -> None:
"""Ensures video is at least 5 seconds long."""
if duration < 5:
raise MoonvalleyApiError("Input video must be at least 5 seconds long.")
def _trim_if_too_long(video: VideoInput, duration: float) -> VideoInput:
"""Trims video to 5 seconds if longer."""
if duration > 5:
return trim_video(video, 5)
return video
def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
"""
Returns a new VideoInput object trimmed from the beginning to the specified duration,
using av to avoid loading entire video into memory.
Args:
video: Input video to trim
duration_sec: Duration in seconds to keep from the beginning
Returns:
VideoFromFile object that owns the output buffer
"""
output_buffer = io.BytesIO()
input_container = None
output_container = None
try:
# Get the stream source - this avoids loading entire video into memory
# when the source is already a file path
input_source = video.get_stream_source()
# Open containers
input_container = av.open(input_source, mode="r")
output_container = av.open(output_buffer, mode="w", format="mp4")
# Set up output streams for re-encoding
video_stream = None
audio_stream = None
for stream in input_container.streams:
logging.info(f"Found stream: type={stream.type}, class={type(stream)}")
if isinstance(stream, av.VideoStream):
# Create output video stream with same parameters
video_stream = output_container.add_stream(
"h264", rate=stream.average_rate
)
video_stream.width = stream.width
video_stream.height = stream.height
video_stream.pix_fmt = "yuv420p"
logging.info(
f"Added video stream: {stream.width}x{stream.height} @ {stream.average_rate}fps"
)
elif isinstance(stream, av.AudioStream):
# Create output audio stream with same parameters
audio_stream = output_container.add_stream(
"aac", rate=stream.sample_rate
)
audio_stream.sample_rate = stream.sample_rate
audio_stream.layout = stream.layout
logging.info(
f"Added audio stream: {stream.sample_rate}Hz, {stream.channels} channels"
)
# Calculate target frame count that's divisible by 16
fps = input_container.streams.video[0].average_rate
estimated_frames = int(duration_sec * fps)
target_frames = (estimated_frames // 16) * 16 # Round down to nearest multiple of 16
if target_frames == 0:
raise ValueError("Video too short: need at least 16 frames for Moonvalley")
frame_count = 0
audio_frame_count = 0
# Decode and re-encode video frames
if video_stream:
for frame in input_container.decode(video=0):
if frame_count >= target_frames:
break
# Re-encode frame
for packet in video_stream.encode(frame):
output_container.mux(packet)
frame_count += 1
# Flush encoder
for packet in video_stream.encode():
output_container.mux(packet)
logging.info(
f"Encoded {frame_count} video frames (target: {target_frames})"
)
# Decode and re-encode audio frames
if audio_stream:
input_container.seek(0) # Reset to beginning for audio
for frame in input_container.decode(audio=0):
if frame.time >= duration_sec:
break
# Re-encode frame
for packet in audio_stream.encode(frame):
output_container.mux(packet)
audio_frame_count += 1
# Flush encoder
for packet in audio_stream.encode():
output_container.mux(packet)
logging.info(f"Encoded {audio_frame_count} audio frames")
# Close containers
output_container.close()
input_container.close()
# Return as VideoFromFile using the buffer
output_buffer.seek(0)
return VideoFromFile(output_buffer)
except Exception as e:
# Clean up on error
if input_container is not None:
input_container.close()
if output_container is not None:
output_container.close()
raise RuntimeError(f"Failed to trim video: {str(e)}") from e
# --- BaseMoonvalleyVideoNode ---
class BaseMoonvalleyVideoNode:
def parseWidthHeightFromRes(self, resolution: str):
# Accepts a string like "16:9 (1920 x 1080)" and returns width, height as a dict
res_map = {
"16:9 (1920 x 1080)": {"width": 1920, "height": 1080},
"9:16 (1080 x 1920)": {"width": 1080, "height": 1920},
"1:1 (1152 x 1152)": {"width": 1152, "height": 1152},
"4:3 (1536 x 1152)": {"width": 1536, "height": 1152},
"3:4 (1152 x 1536)": {"width": 1152, "height": 1536},
"21:9 (2560 x 1080)": {"width": 2560, "height": 1080},
}
if resolution in res_map:
return res_map[resolution]
else:
# Default to 1920x1080 if unknown
return {"width": 1920, "height": 1080}
def parseControlParameter(self, value):
control_map = {
"Motion Transfer": "motion_control",
"Canny": "canny_control",
"Pose Transfer": "pose_control",
"Depth": "depth_control",
}
if value in control_map:
return control_map[value]
else:
return control_map["Motion Transfer"]
def get_response(
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
) -> MoonvalleyPromptResponse:
return poll_until_finished(
auth_kwargs,
ApiEndpoint(
path=f"{API_PROMPTS_ENDPOINT}/{task_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=MoonvalleyPromptResponse,
),
result_url_extractor=get_video_url_from_response,
node_id=node_id,
)
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"prompt": model_field_to_node_input(
IO.STRING,
MoonvalleyTextToVideoRequest,
"prompt_text",
multiline=True,
),
"negative_prompt": model_field_to_node_input(
IO.STRING,
MoonvalleyTextToVideoInferenceParams,
"negative_prompt",
multiline=True,
default="low-poly, flat shader, bad rigging, stiff animation, uncanny eyes, low-quality textures, looping glitch, cheap effect, overbloom, bloom spam, default lighting, game asset, stiff face, ugly specular, AI artifacts",
),
"resolution": (
IO.COMBO,
{
"options": [
"16:9 (1920 x 1080)",
"9:16 (1080 x 1920)",
"1:1 (1152 x 1152)",
"4:3 (1440 x 1080)",
"3:4 (1080 x 1440)",
"21:9 (2560 x 1080)",
],
"default": "16:9 (1920 x 1080)",
"tooltip": "Resolution of the output video",
},
),
# "length": (IO.COMBO,{"options":['5s','10s'], "default": '5s'}),
"prompt_adherence": model_field_to_node_input(
IO.FLOAT,
MoonvalleyTextToVideoInferenceParams,
"guidance_scale",
default=7.0,
step=1,
min=1,
max=20,
),
"seed": model_field_to_node_input(
IO.INT,
MoonvalleyTextToVideoInferenceParams,
"seed",
default=random.randint(0, 2**32 - 1),
min=0,
max=4294967295,
step=1,
display="number",
tooltip="Random seed value",
control_after_generate=True,
),
"steps": model_field_to_node_input(
IO.INT,
MoonvalleyTextToVideoInferenceParams,
"steps",
default=100,
min=1,
max=100,
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
"optional": {
"image": model_field_to_node_input(
IO.IMAGE,
MoonvalleyTextToVideoRequest,
"image_url",
tooltip="The reference image used to generate the video",
),
},
}
RETURN_TYPES = ("STRING",)
FUNCTION = "generate"
CATEGORY = "api node/video/Moonvalley Marey"
API_NODE = True
def generate(self, **kwargs):
return None
# --- MoonvalleyImg2VideoNode ---
class MoonvalleyImg2VideoNode(BaseMoonvalleyVideoNode):
@classmethod
def INPUT_TYPES(cls):
return super().INPUT_TYPES()
RETURN_TYPES = ("VIDEO",)
RETURN_NAMES = ("video",)
DESCRIPTION = "Moonvalley Marey Image to Video Node"
def generate(
self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs
):
image = kwargs.get("image", None)
if image is None:
raise MoonvalleyApiError("image is required")
validate_input_image(image, True)
validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
width_height = self.parseWidthHeightFromRes(kwargs.get("resolution"))
inference_params = MoonvalleyTextToVideoInferenceParams(
negative_prompt=negative_prompt,
steps=kwargs.get("steps"),
seed=kwargs.get("seed"),
guidance_scale=kwargs.get("prompt_adherence"),
num_frames=128,
width=width_height.get("width"),
height=width_height.get("height"),
use_negative_prompts=True,
)
"""Upload image to comfy backend to have a URL available for further processing"""
# Get MIME type from tensor - assuming PNG format for image tensors
mime_type = "image/png"
image_url = upload_images_to_comfyapi(
image, max_images=1, auth_kwargs=kwargs, mime_type=mime_type
)[0]
request = MoonvalleyTextToVideoRequest(
image_url=image_url, prompt_text=prompt, inference_params=inference_params
)
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=API_IMG2VIDEO_ENDPOINT,
method=HttpMethod.POST,
request_model=MoonvalleyTextToVideoRequest,
response_model=MoonvalleyPromptResponse,
),
request=request,
auth_kwargs=kwargs,
)
task_creation_response = initial_operation.execute()
validate_task_creation_response(task_creation_response)
task_id = task_creation_response.id
final_response = self.get_response(
task_id, auth_kwargs=kwargs, node_id=unique_id
)
video = download_url_to_video_output(final_response.output_url)
return (video,)
# --- MoonvalleyVid2VidNode ---
class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
def __init__(self):
super().__init__()
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"prompt": model_field_to_node_input(
IO.STRING, MoonvalleyVideoToVideoRequest, "prompt_text",
multiline=True
),
"negative_prompt": model_field_to_node_input(
IO.STRING,
MoonvalleyVideoToVideoInferenceParams,
"negative_prompt",
multiline=True,
default="low-poly, flat shader, bad rigging, stiff animation, uncanny eyes, low-quality textures, looping glitch, cheap effect, overbloom, bloom spam, default lighting, game asset, stiff face, ugly specular, AI artifacts"
),
"seed": model_field_to_node_input(IO.INT,MoonvalleyVideoToVideoInferenceParams, "seed", default=random.randint(0, 2**32 - 1), min=0, max=4294967295, step=1, display="number", tooltip="Random seed value", control_after_generate=True),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
"optional": {
"video": (IO.VIDEO, {"default": "", "multiline": False, "tooltip": "The reference video used to generate the output video. Must be at least 5 seconds long. Videos longer than 5s will be automatically trimmed. Only MP4 format supported."}),
"control_type": (
["Motion Transfer", "Pose Transfer"],
{"default": "Motion Transfer"},
),
"motion_intensity": (
"INT",
{
"default": 100,
"step": 1,
"min": 0,
"max": 100,
"tooltip": "Only used if control_type is 'Motion Transfer'",
},
)
}
}
RETURN_TYPES = ("VIDEO",)
RETURN_NAMES = ("video",)
def generate(
self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs
):
video = kwargs.get("video")
if not video:
raise MoonvalleyApiError("video is required")
video_url = ""
if video:
validated_video = validate_video_to_video_input(video)
video_url = upload_video_to_comfyapi(validated_video, auth_kwargs=kwargs)
control_type = kwargs.get("control_type")
motion_intensity = kwargs.get("motion_intensity")
"""Validate prompts and inference input"""
validate_prompts(prompt, negative_prompt)
# Only include motion_intensity for Motion Transfer
control_params = {}
if control_type == "Motion Transfer" and motion_intensity is not None:
control_params['motion_intensity'] = motion_intensity
inference_params=MoonvalleyVideoToVideoInferenceParams(
negative_prompt=negative_prompt,
seed=kwargs.get("seed"),
control_params=control_params
)
control = self.parseControlParameter(control_type)
request = MoonvalleyVideoToVideoRequest(
control_type=control,
video_url=video_url,
prompt_text=prompt,
inference_params=inference_params,
)
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=API_VIDEO2VIDEO_ENDPOINT,
method=HttpMethod.POST,
request_model=MoonvalleyVideoToVideoRequest,
response_model=MoonvalleyPromptResponse,
),
request=request,
auth_kwargs=kwargs,
)
task_creation_response = initial_operation.execute()
validate_task_creation_response(task_creation_response)
task_id = task_creation_response.id
final_response = self.get_response(
task_id, auth_kwargs=kwargs, node_id=unique_id
)
video = download_url_to_video_output(final_response.output_url)
return (video,)
# --- MoonvalleyTxt2VideoNode ---
class MoonvalleyTxt2VideoNode(BaseMoonvalleyVideoNode):
def __init__(self):
super().__init__()
RETURN_TYPES = ("VIDEO",)
RETURN_NAMES = ("video",)
@classmethod
def INPUT_TYPES(cls):
input_types = super().INPUT_TYPES()
# Remove image-specific parameters
for param in ["image"]:
if param in input_types["optional"]:
del input_types["optional"][param]
return input_types
def generate(
self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs
):
validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
width_height = self.parseWidthHeightFromRes(kwargs.get("resolution"))
inference_params=MoonvalleyTextToVideoInferenceParams(
negative_prompt=negative_prompt,
steps=kwargs.get("steps"),
seed=kwargs.get("seed"),
guidance_scale=kwargs.get("prompt_adherence"),
num_frames=128,
width=width_height.get("width"),
height=width_height.get("height"),
)
request = MoonvalleyTextToVideoRequest(
prompt_text=prompt, inference_params=inference_params
)
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=API_TXT2VIDEO_ENDPOINT,
method=HttpMethod.POST,
request_model=MoonvalleyTextToVideoRequest,
response_model=MoonvalleyPromptResponse,
),
request=request,
auth_kwargs=kwargs,
)
task_creation_response = initial_operation.execute()
validate_task_creation_response(task_creation_response)
task_id = task_creation_response.id
final_response = self.get_response(
task_id, auth_kwargs=kwargs, node_id=unique_id
)
video = download_url_to_video_output(final_response.output_url)
return (video,)
NODE_CLASS_MAPPINGS = {
"MoonvalleyImg2VideoNode": MoonvalleyImg2VideoNode,
"MoonvalleyTxt2VideoNode": MoonvalleyTxt2VideoNode,
"MoonvalleyVideo2VideoNode": MoonvalleyVideo2VideoNode,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"MoonvalleyImg2VideoNode": "Moonvalley Marey Image to Video",
"MoonvalleyTxt2VideoNode": "Moonvalley Marey Text to Video",
"MoonvalleyVideo2VideoNode": "Moonvalley Marey Video to Video",
}

View File

@@ -8,10 +8,10 @@ from typing import Optional
from comfy.comfy_types.node_typing import IO, ComfyNodeABC
from comfy_api.input_impl.video_types import VideoFromFile
from comfy_api_nodes.apis import (
Veo2GenVidRequest,
Veo2GenVidResponse,
Veo2GenVidPollRequest,
Veo2GenVidPollResponse
VeoGenVidRequest,
VeoGenVidResponse,
VeoGenVidPollRequest,
VeoGenVidPollResponse
)
from comfy_api_nodes.apis.client import (
ApiEndpoint,
@@ -35,7 +35,7 @@ def convert_image_to_base64(image: torch.Tensor):
return tensor_to_base64_string(scaled_image)
def get_video_url_from_response(poll_response: Veo2GenVidPollResponse) -> Optional[str]:
def get_video_url_from_response(poll_response: VeoGenVidPollResponse) -> Optional[str]:
if (
poll_response.response
and hasattr(poll_response.response, "videos")
@@ -130,6 +130,14 @@ class VeoVideoGenerationNode(ComfyNodeABC):
"default": None,
"tooltip": "Optional reference image to guide video generation",
}),
"model": (
IO.COMBO,
{
"options": ["veo-2.0-generate-001"],
"default": "veo-2.0-generate-001",
"tooltip": "Veo 2 model to use for video generation",
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
@@ -141,7 +149,7 @@ class VeoVideoGenerationNode(ComfyNodeABC):
RETURN_TYPES = (IO.VIDEO,)
FUNCTION = "generate_video"
CATEGORY = "api node/video/Veo"
DESCRIPTION = "Generates videos from text prompts using Google's Veo API"
DESCRIPTION = "Generates videos from text prompts using Google's Veo 2 API"
API_NODE = True
def generate_video(
@@ -154,6 +162,8 @@ class VeoVideoGenerationNode(ComfyNodeABC):
person_generation="ALLOW",
seed=0,
image=None,
model="veo-2.0-generate-001",
generate_audio=False,
unique_id: Optional[str] = None,
**kwargs,
):
@@ -188,16 +198,19 @@ class VeoVideoGenerationNode(ComfyNodeABC):
parameters["negativePrompt"] = negative_prompt
if seed > 0:
parameters["seed"] = seed
# Only add generateAudio for Veo 3 models
if "veo-3.0" in model:
parameters["generateAudio"] = generate_audio
# Initial request to start video generation
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/veo/generate",
path=f"/proxy/veo/{model}/generate",
method=HttpMethod.POST,
request_model=Veo2GenVidRequest,
response_model=Veo2GenVidResponse
request_model=VeoGenVidRequest,
response_model=VeoGenVidResponse
),
request=Veo2GenVidRequest(
request=VeoGenVidRequest(
instances=instances,
parameters=parameters
),
@@ -223,16 +236,16 @@ class VeoVideoGenerationNode(ComfyNodeABC):
# Define the polling operation
poll_operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path="/proxy/veo/poll",
path=f"/proxy/veo/{model}/poll",
method=HttpMethod.POST,
request_model=Veo2GenVidPollRequest,
response_model=Veo2GenVidPollResponse
request_model=VeoGenVidPollRequest,
response_model=VeoGenVidPollResponse
),
completed_statuses=["completed"],
failed_statuses=[], # No failed statuses, we'll handle errors after polling
status_extractor=status_extractor,
progress_extractor=progress_extractor,
request=Veo2GenVidPollRequest(
request=VeoGenVidPollRequest(
operationName=operation_name
),
auth_kwargs=kwargs,
@@ -298,11 +311,64 @@ class VeoVideoGenerationNode(ComfyNodeABC):
return (VideoFromFile(video_io),)
# Register the node
class Veo3VideoGenerationNode(VeoVideoGenerationNode):
"""
Generates videos from text prompts using Google's Veo 3 API.
Supported models:
- veo-3.0-generate-001
- veo-3.0-fast-generate-001
This node extends the base Veo node with Veo 3 specific features including
audio generation and fixed 8-second duration.
"""
@classmethod
def INPUT_TYPES(s):
parent_input = super().INPUT_TYPES()
# Update model options for Veo 3
parent_input["optional"]["model"] = (
IO.COMBO,
{
"options": ["veo-3.0-generate-001", "veo-3.0-fast-generate-001"],
"default": "veo-3.0-generate-001",
"tooltip": "Veo 3 model to use for video generation",
},
)
# Add generateAudio parameter
parent_input["optional"]["generate_audio"] = (
IO.BOOLEAN,
{
"default": False,
"tooltip": "Generate audio for the video. Supported by all Veo 3 models.",
}
)
# Update duration constraints for Veo 3 (only 8 seconds supported)
parent_input["optional"]["duration_seconds"] = (
IO.INT,
{
"default": 8,
"min": 8,
"max": 8,
"step": 1,
"display": "number",
"tooltip": "Duration of the output video in seconds (Veo 3 only supports 8 seconds)",
},
)
return parent_input
# Register the nodes
NODE_CLASS_MAPPINGS = {
"VeoVideoGenerationNode": VeoVideoGenerationNode,
"Veo3VideoGenerationNode": Veo3VideoGenerationNode,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"VeoVideoGenerationNode": "Google Veo2 Video Generation",
"VeoVideoGenerationNode": "Google Veo 2 Video Generation",
"Veo3VideoGenerationNode": "Google Veo 3 Video Generation",
}

View File

@@ -11,6 +11,43 @@ from comfy_config.types import (
PyProjectSettings
)
def validate_and_extract_os_classifiers(classifiers: list) -> list:
os_classifiers = [c for c in classifiers if c.startswith("Operating System :: ")]
if not os_classifiers:
return []
os_values = [c[len("Operating System :: ") :] for c in os_classifiers]
valid_os_prefixes = {"Microsoft", "POSIX", "MacOS", "OS Independent"}
for os_value in os_values:
if not any(os_value.startswith(prefix) for prefix in valid_os_prefixes):
return []
return os_values
def validate_and_extract_accelerator_classifiers(classifiers: list) -> list:
accelerator_classifiers = [c for c in classifiers if c.startswith("Environment ::")]
if not accelerator_classifiers:
return []
accelerator_values = [c[len("Environment :: ") :] for c in accelerator_classifiers]
valid_accelerators = {
"GPU :: NVIDIA CUDA",
"GPU :: AMD ROCm",
"GPU :: Intel Arc",
"NPU :: Huawei Ascend",
"GPU :: Apple Metal",
}
for accelerator_value in accelerator_values:
if accelerator_value not in valid_accelerators:
return []
return accelerator_values
"""
Extract configuration from a custom node directory's pyproject.toml file or a Python file.
@@ -78,6 +115,24 @@ def extract_node_configuration(path) -> Optional[PyProjectConfig]:
tool_data = raw_settings.tool
comfy_data = tool_data.get("comfy", {}) if tool_data else {}
dependencies = project_data.get("dependencies", [])
supported_comfyui_frontend_version = ""
for dep in dependencies:
if isinstance(dep, str) and dep.startswith("comfyui-frontend-package"):
supported_comfyui_frontend_version = dep.removeprefix("comfyui-frontend-package")
break
supported_comfyui_version = comfy_data.get("requires-comfyui", "")
classifiers = project_data.get('classifiers', [])
supported_os = validate_and_extract_os_classifiers(classifiers)
supported_accelerators = validate_and_extract_accelerator_classifiers(classifiers)
project_data['supported_os'] = supported_os
project_data['supported_accelerators'] = supported_accelerators
project_data['supported_comfyui_frontend_version'] = supported_comfyui_frontend_version
project_data['supported_comfyui_version'] = supported_comfyui_version
return PyProjectConfig(project=project_data, tool_comfy=comfy_data)

View File

@@ -51,7 +51,7 @@ class ComfyConfig(BaseModel):
models: List[Model] = Field(default_factory=list, alias="Models")
includes: List[str] = Field(default_factory=list)
web: Optional[str] = None
banner_url: str = ""
class License(BaseModel):
file: str = ""
@@ -66,6 +66,10 @@ class ProjectConfig(BaseModel):
dependencies: List[str] = Field(default_factory=list)
license: License = Field(default_factory=License)
urls: URLs = Field(default_factory=URLs)
supported_os: List[str] = Field(default_factory=list)
supported_accelerators: List[str] = Field(default_factory=list)
supported_comfyui_version: str = ""
supported_comfyui_frontend_version: str = ""
@field_validator('license', mode='before')
@classmethod

Some files were not shown because too many files have changed in this diff Show More