Compare commits

...

155 Commits

Author SHA1 Message Date
comfyanonymous
9a470e073e ComfyUI version 0.3.45 2025-07-21 14:05:43 -04:00
ComfyUI Wiki
7d627f764c Update template to 0.1.39 (#8981) 2025-07-20 15:58:35 -04:00
comfyanonymous
a0c0785635 Document what the fast_fp16_accumulation is in the portable. (#8973) 2025-07-20 01:24:09 -04:00
chaObserv
100c2478ea Add SamplingPercentToSigma node (#8963)
It's helpful to adjust start_percent or end_percent based on the corresponding sigma.
2025-07-19 23:09:11 -04:00
ComfyUI Wiki
1da5639e86 Update template to 0.1.37 (#8967) 2025-07-19 06:08:00 -04:00
comfyanonymous
1b96fae1d4 Add nested style of dual cfg to DualCFGGuider node. (#8965) 2025-07-19 04:55:23 -04:00
comfyanonymous
7f492522b6 Forgot this (#8957) 2025-07-18 05:43:02 -04:00
comfyanonymous
650838fd6f Experimental CFGNorm node. (#8942)
This is from the new hidream e1 1 model code. Figured it might be useful as a generic cfg trick.
2025-07-17 04:11:07 -04:00
comfyanonymous
491fafbd64 Silence clip tokenizer warning. (#8934) 2025-07-16 14:42:07 -04:00
Harel Cain
9bc2798f72 LTXV VAE decoder: switch default padding mode (#8930) 2025-07-16 13:54:38 -04:00
comfyanonymous
50afba747c Add attempt to work around the safetensors mmap issue. (#8928) 2025-07-16 03:42:17 -04:00
Brandon Wallace
6b8062f414 Fix MaskComposite error when destination has 2 dimensions (#8915)
Fix code that is using the original `destination` input instead of the reshaped value.
2025-07-15 21:08:27 -04:00
comfyanonymous
b1ae4126c3 Add action to detect windows line endings. (#8917) 2025-07-15 02:27:18 -04:00
Yoland Yan
9dabda19f0 Update nodes_gemini.py (#8912) 2025-07-14 20:59:35 -04:00
Yoland Yan
543c24108c Fix wrong reference bug (#8910) 2025-07-14 20:45:55 -04:00
FeepingCreature
260a5ca5d9 Allow the prompt request to specify the prompt ID. (#8189)
This makes it easier to write asynchronous clients that submit requests, because they can store the task immediately.
Duplicate prompt IDs are rejected by the job queue.
2025-07-14 14:48:31 -04:00
ComfyUI Wiki
861c3bbb3d Upate template to 0.1.36 (#8904) 2025-07-14 13:27:57 -04:00
comfyanonymous
9ca581c941 Remove windows line endings. (#8902) 2025-07-14 13:10:20 -04:00
comfyanonymous
4831e9c2c4 Refactor previous pr. (#8893) 2025-07-13 04:59:17 -04:00
Christian Byrne
480375f349 Remove auth tokens from history storage (#8889)
Remove auth_token_comfy_org and api_key_comfy_org from extra_data before
storing prompt history to prevent sensitive authentication tokens from
being persisted in the history endpoint response.
2025-07-13 04:46:27 -04:00
comfyanonymous
b40143984c Add model detection error hint for lora. (#8880) 2025-07-12 03:49:26 -04:00
chaObserv
b43916a134 Fix fresca's input and output (#8871) 2025-07-11 12:52:58 -04:00
JettHu
7bc7dd2aa2 Execute async node earlier (#8865) 2025-07-11 12:51:06 -04:00
comfyanonymous
938d3e8216 Remove windows line endings. (#8866) 2025-07-11 02:37:51 -04:00
Christian Byrne
8f05fb48ea [fix] increase Kling API polling timeout to prevent user timeouts (#8860)
Extends polling duration from 10 minutes to ~68 minutes (256 attempts × 16 seconds) to accommodate longer Kling API operations that were frequently timing out for users.
2025-07-10 18:00:29 -04:00
comfyanonymous
b7ff5bd14d Fix python3.9 (#8858) 2025-07-10 15:21:18 -04:00
guill
2b653e8c18 Support for async node functions (#8830)
* Support for async execution functions

This commit adds support for node execution functions defined as async. When
a node's execution function is defined as async, we can continue
executing other nodes while it is processing.

Standard uses of `await` should "just work", but people will still have
to be careful if they spawn actual threads. Because torch doesn't really
have async/await versions of functions, this won't particularly help
with most locally-executing nodes, but it does work for e.g. web
requests to other machines.

In addition to the execute function, the `VALIDATE_INPUTS` and
`check_lazy_status` functions can also be defined as async, though we'll
only resolve one node at a time right now for those.

* Add the execution model tests to CI

* Add a missing file

It looks like this got caught by .gitignore? There's probably a better
place to put it, but I'm not sure what that is.

* Add the websocket library for automated tests

* Add additional tests for async error cases

Also fixes one bug that was found when an async function throws an error
after being scheduled on a task.

* Add a feature flags message to reduce bandwidth

We now only send 1 preview message of the latest type the client can
support.

We'll add a console warning when the client fails to send a feature
flags message at some point in the future.

* Add async tests to CI

* Don't actually add new tests in this PR

Will do it in a separate PR

* Resolve unit test in GPU-less runner

* Just remove the tests that GHA can't handle

* Change line endings to UNIX-style

* Avoid loading model_management.py so early

Because model_management.py has a top-level `logging.info`, we have to
be careful not to import that file before we call `setup_logging`. If we
do, we end up having the default logging handler registered in addition
to our custom one.
2025-07-10 14:46:19 -04:00
comfyanonymous
1fd306824d Add warning to catch torch import mistakes. (#8852) 2025-07-10 01:03:27 -04:00
Kohaku-Blueleaf
1205afc708 Better training loop implementation (#8820) 2025-07-09 11:41:22 -04:00
comfyanonymous
5612670ee4 Remove unmaintained notebook. (#8845) 2025-07-09 03:45:48 -04:00
Kohaku-Blueleaf
181a9bf26d Support Multi Image-Caption dataset in lora training node (#8819)
* initial impl of multi img/text dataset

* Update nodes_train.py

* Support Kohya-ss structure
2025-07-08 20:18:04 -04:00
chaObserv
aac10ad23a Add SA-Solver sampler (#8834) 2025-07-08 16:17:06 -04:00
josephrocca
974254218a Un-hardcode chroma patch_size (#8840) 2025-07-08 15:56:59 -04:00
comfyanonymous
c5de4955bb ComfyUI version 0.3.44 2025-07-08 08:56:38 -04:00
Christian Byrne
9fd0cd7cf7 Add Moonvalley nodes (#8832) 2025-07-08 08:54:30 -04:00
ComfyUI Wiki
b5e97db9ac Update template to 0.1.35 (#8831) 2025-07-08 08:52:02 -04:00
Christian Byrne
1359c969e4 Update template to 0.1.34 (#8829) 2025-07-07 23:35:41 -04:00
ComfyUI Wiki
059cd38aa2 Update template and node docs package version (#8825) 2025-07-07 20:43:56 -04:00
comfyanonymous
e740dfd806 Fix warning in audio save nodes. (#8818) 2025-07-07 03:16:00 -04:00
comfyanonymous
7eab7d2944 Remove dependency on deprecated torchaudio.save function (#8815) 2025-07-06 14:01:32 -04:00
comfyanonymous
75d327abd5 Remove some useless code. (#8812) 2025-07-06 07:07:39 -04:00
comfyanonymous
ee615ac269 Add warning when loading file unsafely. (#8800) 2025-07-05 14:34:57 -04:00
comfyanonymous
27870ec3c3 Add that ckpt files are loaded safely to README. (#8791) 2025-07-04 04:49:11 -04:00
chaObserv
f41f323c52 Add the denoising step to several samplers (#8780) 2025-07-03 19:20:53 -04:00
comfyanonymous
f74fc4d927 Add ImageRotate and ImageFlip nodes. (#8789) 2025-07-03 19:16:30 -04:00
ComfyUI Wiki
ae26cd99b5 Update template to 0.1.32 (#8782) 2025-07-03 14:41:16 -04:00
comfyanonymous
e9af97ba1a Use torch cu129 for nvidia pytorch nightly. (#8786)
* update nightly workflow with cu129

* Remove unused file to lower standalone size.
2025-07-03 14:39:11 -04:00
City
d9277301d2 Initial code for new SLG node (#8759) 2025-07-02 20:13:43 -04:00
comfyanonymous
34c8eeec06 Fix ImageColorToMask not returning right mask values. (#8771) 2025-07-02 15:35:11 -04:00
Harel Cain
9f1069290c nodes_lt: fixes to latent conditioning at index > 0 (#8769) 2025-07-02 15:34:51 -04:00
comfyanonymous
111f583e00 Fallback to regular op when fp8 op throws exception. (#8761) 2025-07-02 00:57:13 -04:00
Terry Jia
79ed752748 support upload 3d model to custom subfolder (#8597) 2025-07-01 20:43:48 -04:00
comfyanonymous
772de7c006 PerpNeg Guider optimizations. (#8753) 2025-07-01 03:09:07 -04:00
chaObserv
b22e97dcfa Migrate ER-SDE from VE to VP algorithm and add its sampler node (#8744)
Apply alpha scaling in the algorithm for reverse-time SDE and add custom ER-SDE sampler node for other solver types (SDE, ODE).
2025-07-01 02:38:52 -04:00
chaObserv
f02de13316 Add TCFG node (#8730) 2025-07-01 02:33:07 -04:00
ComfyUI Wiki
c46268bf60 Update requirements.txt (#8741) 2025-06-30 14:18:43 -04:00
comfyanonymous
cf49a2c5b5 Dual cfg node optimizations when cfg is 1.0 (#8747) 2025-06-30 14:18:25 -04:00
comfyanonymous
170c7bb90c Fix contiguous issue with pytorch nightly. (#8729) 2025-06-29 06:38:40 -04:00
bmcomfy
2a0b138feb build: add gh action to process releases (#8652) 2025-06-28 19:11:40 -04:00
comfyanonymous
e195c1b13f Make stable release workflow publish drafts. (#8723) 2025-06-28 19:11:16 -04:00
chaObserv
5b4eb021cb Perpneg guider with updated pre and post-cfg (#8698) 2025-06-28 18:13:13 -04:00
comfyanonymous
396454fa41 Reorder the schedulers so simple is the default one. (#8722) 2025-06-28 18:12:56 -04:00
comfyanonymous
a3cf272522 Skip custom node logic completely if disabled and no whitelisted nodes. (#8719) 2025-06-28 15:53:40 -04:00
xufeng
ba9548f756 “--whitelist-custom-nodes” args for comfy core to go with “--disable-all-custom-nodes” for development purposes (#8592)
* feat: “--whitelist-custom-nodes” args for comfy core to go with “--disable-all-custom-nodes” for development purposes

* feat: Simplify custom nodes whitelist logic to use consistent code paths
2025-06-28 15:24:02 -04:00
comfyanonymous
e18f53cca9 ComfyUI version 0.3.43 2025-06-27 17:22:02 -04:00
comfyanonymous
c36be0ea09 Fix memory estimation bug with kontext. (#8709) 2025-06-27 17:21:12 -04:00
comfyanonymous
9093301a49 Don't add tiny bit of random noise when VAE encoding. (#8705)
Shouldn't change outputs but might make things a tiny bit more
deterministic.
2025-06-27 14:14:56 -04:00
comfyanonymous
bd951a714f Add Flux Kontext and Omnigen 2 models to readme. (#8682) 2025-06-26 12:26:29 -04:00
comfyanonymous
6493709d6a ComfyUI version 0.3.42 2025-06-26 11:47:07 -04:00
filtered
b976f934ae Update frontend to 1.23.4 (#8681) 2025-06-26 11:44:12 -04:00
comfyanonymous
7d8cf4cacc Update requirements.txt (#8680) 2025-06-26 11:39:40 -04:00
filtered
68f4496b8e Update frontend to 1.23.3 (#8678) 2025-06-26 11:29:03 -04:00
comfyanonymous
ef5266b1c1 Support Flux Kontext Dev model. (#8679) 2025-06-26 11:28:41 -04:00
comfyanonymous
a96e65df18 Disable omnigen2 fp16 on older pytorch versions. (#8672) 2025-06-26 03:39:09 -04:00
comfyanonymous
93a49a45de Bump minimum transformers version. (#8671) 2025-06-26 02:33:02 -04:00
comfyanonymous
ec70ed6aea Omnigen2 model implementation. (#8669) 2025-06-25 19:35:57 -04:00
comfyanonymous
7a13f74220 unet -> diffusion model (#8659) 2025-06-25 04:52:34 -04:00
chaObserv
8042eb20c6 Singlestep DPM++ SDE for RF (#8627)
Refactor the algorithm, and apply alpha scaling.
2025-06-24 14:59:09 -04:00
comfyanonymous
bd9f166c12 Cosmos predict2 model merging nodes. (#8647) 2025-06-24 05:17:16 -04:00
comfyanonymous
dd94416db2 Indicate that directml is not recommended in the README. (#8644) 2025-06-23 14:04:49 -04:00
comfyanonymous
ae0e7c4dff Resize and pad image node. (#8636) 2025-06-22 17:59:31 -04:00
comfyanonymous
78f79266a9 Allow padding in ImageStitch node to be white. (#8631) 2025-06-22 00:19:41 -04:00
comfyanonymous
1883e70b43 Fix exception when using a noise mask with cosmos predict2. (#8621)
* Fix exception when using a noise mask with cosmos predict2.

* Fix ruff.
2025-06-21 03:30:39 -04:00
Lucas - BLOCK33
31ca603ccb Improve the log time function for 10 minute + renders (#6207)
* modified:   main.py

* Update main.py
2025-06-20 23:04:55 -04:00
comfyanonymous
f7fb193712 Small flux optimization. (#8611) 2025-06-20 05:37:32 -04:00
comfyanonymous
7e9267fa77 Make flux controlnet work with sd3 text enc. (#8599) 2025-06-19 18:50:05 -04:00
comfyanonymous
91d40086db Fix pytorch warning. (#8593) 2025-06-19 11:04:52 -04:00
coderfromthenorth93
5b12b55e32 Add new fields to the config types (#8507) 2025-06-18 15:12:29 -04:00
comfyanonymous
e9e9a031a8 Show a better error when the workflow OOMs. (#8574) 2025-06-18 06:55:21 -04:00
filtered
d7430c529a Update frontend to 1.22.2 (#8567) 2025-06-17 18:58:28 -04:00
ComfyUI Wiki
cd88f709ab Update template version (#8563) 2025-06-17 04:11:59 -07:00
comfyanonymous
4459a17e82 Add Cosmos Predict2 to README. (#8562) 2025-06-17 05:18:01 -04:00
comfyanonymous
483b3e62e0 ComfyUI version v0.3.41 2025-06-16 23:34:46 -04:00
chaObserv
8e81c507d2 Multistep DPM++ SDE samplers for RF (#8541)
Include alpha in sampling and minor refactoring
2025-06-16 14:47:10 -04:00
comfyanonymous
e1c6dc720e Allow setting min_length with tokenizer_data. (#8547) 2025-06-16 13:43:52 -04:00
comfyanonymous
7ea79ebb9d Add correct eps to ltxv rmsnorm. (#8542) 2025-06-15 12:21:25 -04:00
comfyanonymous
ae75a084df SaveLora now saves in the same filename format as all the other nodes. (#8538) 2025-06-15 03:44:59 -04:00
comfyanonymous
d6a2137fc3 Support Cosmos predict2 image to video models. (#8535)
Use the CosmosPredict2ImageToVideoLatent node.
2025-06-14 21:37:07 -04:00
chaObserv
53e8d8193c Generalize SEEDS samplers (#8529)
Restore VP algorithm for RF and refactor noise_coeffs and half-logSNR calculations
2025-06-14 16:58:16 -04:00
comfyanonymous
29596bd53f Small cosmos attention code refactor. (#8530) 2025-06-14 05:02:05 -04:00
Terry Jia
803af1e0c3 allow extra settings from pyproject.toml (#8526) 2025-06-13 23:11:55 -04:00
ComfyUI Wiki
6673939e76 Bump template to 0.1.28 (#8510) 2025-06-13 23:11:00 -04:00
ComfyUI Wiki
f74778e75d Bump embedded docs to 0.2.2 (#8512) 2025-06-13 23:06:28 -04:00
Kohaku-Blueleaf
520eb77b72 LoRA Trainer: LoRA training node in weight adapter scheme (#8446) 2025-06-13 19:25:59 -04:00
comfyanonymous
5bf69bde35 Add cosmos_rflow option to ModelSamplingContinuousEDM node. (#8523)
This is for the cosmos predict2 model.
2025-06-13 17:47:52 -04:00
comfyanonymous
c69af655aa Uncap cosmos predict2 res and fix mem estimation. (#8518) 2025-06-13 07:30:18 -04:00
comfyanonymous
251f54a2ad Basic initial support for cosmos predict2 text to image 2B and 14B models. (#8517) 2025-06-13 07:05:23 -04:00
Christian Byrne
c6529c0d77 don't validate string inputs with VALIDATE_INPUTS (#8508) 2025-06-12 20:17:10 -04:00
filtered
baa8c8cdd3 Add '@prerelease' to use latest test frontend (#8501)
* Add '@prerelease' to use latest test frontend

Allows download of pre-release versions.

Will always get the latest pre-release version - even if it's older than the latest stable release.

* nit
2025-06-12 17:03:27 -07:00
comfyanonymous
40fd39c7cb debug -> warning (#8506) 2025-06-12 17:14:59 -04:00
Terry Jia
4d1c4b9797 Auto register web folder (#8505)
* auto register web folder from pyproject

* need pydantic-settings as dependency

* wrapped try/except for config_parser

* sf
2025-06-12 16:24:39 -04:00
comfyanonymous
d2566eb4b2 Add a warning for old python versions. (#8504) 2025-06-12 15:38:33 -04:00
filtered
ef7e885fe4 Revert "Update requirements.txt (#8487)" (#8502)
This reverts commit 373a9386a4.
2025-06-12 14:10:48 -04:00
filtered
ecb8d15e7a Allow specifying any frontend semver suffixes (#8498) 2025-06-11 21:41:30 -04:00
comfyanonymous
365f9ed157 Revert "auto register web folder from pyproject (#8478)" (#8497)
This reverts commit 9685d4f3c3.
2025-06-11 17:28:04 -04:00
pythongosssss
50c605e957 Add support for sqlite database (#8444)
* Add support for sqlite database

* fix
2025-06-11 16:43:39 -04:00
Terry Jia
9685d4f3c3 auto register web folder from pyproject (#8478)
* auto register web folder from pyproject

* need pydantic-settings as dependency
2025-06-11 16:21:28 -04:00
comfyanonymous
8a4ff747bd Fix mistake in last commit. (#8496)
* Move to right place.
2025-06-11 15:13:29 -04:00
comfyanonymous
af1eb58be8 Fix black images on some flux models in fp16. (#8495) 2025-06-11 15:09:11 -04:00
ComfyUI Wiki
373a9386a4 Update requirements.txt (#8487) 2025-06-11 05:10:46 -04:00
comfyanonymous
6e28a46454 Apple most likely is never fixing the fp16 attention bug. (#8485) 2025-06-10 13:06:24 -04:00
Kent Mewhort
c7b25784b1 Fix WebcamCapture IS_CHANGED signature (#8413) 2025-06-09 13:05:54 -04:00
comfyanonymous
7f800d04fa Enable AMD fp8 and pytorch attention on some GPUs. (#8474)
Information is from the pytorch source code.
2025-06-09 12:50:39 -04:00
comfyanonymous
97755eed46 Enable fp8 ops by default on gfx1201 (#8464) 2025-06-08 14:15:34 -04:00
comfyanonymous
daf9d25ee2 Cleaner torch version comparisons. (#8453) 2025-06-07 10:01:15 -04:00
comfyanonymous
3b4b171e18 Alternate fix for #8435 (#8442) 2025-06-06 09:43:27 -04:00
Olexandr88
d8759c772b Update README.md (#8427) 2025-06-05 10:44:29 -07:00
comfyanonymous
4248b1618f Let chroma TE work on regular flux. (#8429) 2025-06-05 10:07:17 -04:00
comfyanonymous
866f6cdab4 ComfyUI version 0.3.40 2025-06-04 22:18:54 -04:00
Christian Byrne
3aa83feeec [refactor] remove version prefixes from Ideogram node categories (#8418)
Simplifies node organization by consolidating all Ideogram nodes under a single category instead of version-specific subcategories.
2025-06-04 21:56:38 -04:00
comfyanonymous
871749c208 Add batch to GetImageSize node. (#8419) 2025-06-04 09:40:21 -04:00
SD
fcc1643c52 Sub call to deprecated pillow API Image.ANTIALIAS (#8415)
ANTIALIAS was removed in Pillow 10.0.0
2025-06-04 09:03:42 -04:00
filtered
20687293fe Update frontend to 1.21.7 (#8410) 2025-06-04 08:57:13 -04:00
Terry Jia
47d55b8b45 add support to read pyproject.toml from custom node (#8357)
* add support to read pyproject.toml from custom node

* sf

* use pydantic instead

* sf

* use pydantic_settings

* remove unnecessary try/catch and handle single-file python node

* sf
2025-06-03 19:59:13 -04:00
comfyanonymous
310f4b6ef8 Add api nodes to readme. (#8402) 2025-06-03 04:26:44 -04:00
Christian Byrne
856448060c [feat] Add GetImageSize node (#8386)
* [feat] Add GetImageSize node to return image dimensions

Added a simple GetImageSize node in comfy_extras/nodes_images.py that returns width and height of input images. The node displays dimensions on the UI via PromptServer and provides width/height as outputs for further processing.

* add display name mapping

* [fix] Add server module mock to unit tests for PromptServer import

Updated test to mock server module preventing import errors from the new PromptServer usage in GetImageSize node. Uses direct import pattern consistent with rest of codebase.
2025-06-02 21:57:50 -04:00
comfyanonymous
312d511630 Style fix. (#8390) 2025-06-02 07:22:02 -04:00
Jesse Gonyou
4f4f1c642a Update fix for potential XSS on /view (#8384)
* Update fix for potential XSS on /view

This commit uses mimetypes to add more restricted filetypes to prevent from being served, since mimetypes are what browsers use to determine how to serve files.

* Fix typo

Fixed a typo that prevented the program from running
2025-06-02 06:52:44 -04:00
filtered
010954d277 [BugFix] Update frontend to 1.21.6 (#8383) 2025-06-02 14:57:44 +10:00
filtered
6d46bb4b4c [BugFix] Update frontend to 1.21.5 (#8382) 2025-06-01 16:47:14 -04:00
Christian Byrne
67f57c5bcc [feat] add custom node testing requirement to issue templates (#8374)
Adds mandatory checkbox to bug report and user support templates requiring users to confirm they've tested with custom nodes disabled before submitting issues.
2025-06-01 15:47:07 -04:00
filtered
fd943c928f [BugFix] Update frontend to 1.21.4 (#8377) 2025-06-01 13:57:53 -04:00
ComfyUI Wiki
d3bd983b91 Bump template to 0.1.25 (#8372) 2025-06-01 05:41:17 -04:00
comfyanonymous
fb4754624d Make the casting in lists the same as regular inputs. (#8373) 2025-06-01 05:39:54 -04:00
Benjamin Lu
180db6753f Add Help Menu in NodeLibrarySidebarTab (#8179) 2025-06-01 04:32:32 -04:00
Christian Byrne
d062fcc5c0 [feat] Add ImageStitch node for concatenating images (#8369)
* [feat] Add ImageStitch node for concatenating images with borders

Add ImageStitch node that concatenates images in four directions with optional borders and intelligent size handling. Features include optional second image input, configurable borders with color selection, automatic batch size matching, and dimension alignment via padding or resizing.

Upstreamed from https://github.com/kijai/ComfyUI-KJNodes with enhancements for better error handling and comprehensive test coverage.

* [fix] Fix CI issues with CUDA dependencies and linting

- Mock CUDA-dependent modules in tests to avoid CI failures on CPU-only runners
- Fix ruff linting issues for code style compliance

* [fix] Improve CI compatibility by mocking nodes module import

Prevent CUDA initialization chain by mocking the nodes module at import time,
which is cleaner than deep mocking of CUDA-specific functions.

* [refactor] Clean up ImageStitch tests

- Remove unnecessary sys.path manipulation (pythonpath set in pytest.ini)
- Remove metadata tests that test framework internals rather than functionality
- Rename complex scenario test to be more descriptive of what it tests

* [refactor] Rename 'border' to 'spacing' for semantic accuracy

- Change border_width/border_color to spacing_width/spacing_color in API
- Update all tests to use spacing terminology
- Update comments and variable names throughout
- More accurately describes the gap/separator between images
2025-06-01 04:28:52 -04:00
filtered
456abad834 Update frontend to 1.21 (#8366) 2025-06-01 01:10:04 -04:00
comfyanonymous
19e45e9b0e Make it easier to pass lists of tensors to models. (#8358) 2025-05-31 20:00:20 -04:00
ComfyUI Wiki
97f23b81f3 Bump template to 0.1.23 (#8353)
Correct some error settings in VACE
2025-05-30 23:05:42 -07:00
drhead
08b7cc7506 use fused multiply-add pointwise ops in chroma (#8279) 2025-05-30 18:09:54 -04:00
BennyKok
6c319cbb4e fix: custom comfy-api-base works with subpath (#8332) 2025-05-30 17:51:28 -04:00
Chenlei Hu
df1aebe52e Remove huchenlei from CODEOWNERS (#8350) 2025-05-30 17:27:52 -04:00
comfyanonymous
704fc78854 Put ROCm version in tuple to make it easier to enable stuff based on it. (#8348) 2025-05-30 15:41:02 -04:00
JettHu
1d9fee79fd Add node for regex replace(sub) operation (#8340)
* Add node for regex replace(sub) operation

* Apply suggestions from code review

add tooltips

Co-authored-by: Christian Byrne <abolkonsky.rem@gmail.com>

* Fix indentation

---------

Co-authored-by: Christian Byrne <abolkonsky.rem@gmail.com>
2025-05-30 15:08:59 -04:00
Jedrzej Kosinski
aeba0b3a26 Reduce code duplication for [pro] and [max], rename Pro and Max to [pro] and [max] to be consistent with other BFL nodes, make default seed for Kontext nodes be 1234. since 0 is interpreted by API as 'choose random seed' (#8337) 2025-05-29 17:14:27 -04:00
121 changed files with 160236 additions and 1333 deletions

View File

@@ -4,6 +4,9 @@ if you have a NVIDIA gpu:
run_nvidia_gpu.bat
if you want to enable the fast fp16 accumulation (faster for fp16 models with slightly less quality):
run_nvidia_gpu_fast_fp16_accumulation.bat
To run it in slow CPU mode:

View File

@@ -15,6 +15,14 @@ body:
steps to replicate what went wrong and others will be able to repeat your steps and see the same issue happen.
If unsure, ask on the [ComfyUI Matrix Space](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) or the [Comfy Org Discord](https://discord.gg/comfyorg) first.
- type: checkboxes
id: custom-nodes-test
attributes:
label: Custom Node Testing
description: Please confirm you have tried to reproduce the issue with all custom nodes disabled.
options:
- label: I have tried disabling custom nodes and the issue persists (see [how to disable custom nodes](https://docs.comfy.org/troubleshooting/custom-node-issues#step-1%3A-test-with-all-custom-nodes-disabled) if you need help)
required: true
- type: textarea
attributes:
label: Expected Behavior

View File

@@ -11,6 +11,14 @@ body:
**2:** You have made an effort to find public answers to your question before asking here. In other words, you googled it first, and scrolled through recent help topics.
If unsure, ask on the [ComfyUI Matrix Space](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) or the [Comfy Org Discord](https://discord.gg/comfyorg) first.
- type: checkboxes
id: custom-nodes-test
attributes:
label: Custom Node Testing
description: Please confirm you have tried to reproduce the issue with all custom nodes disabled.
options:
- label: I have tried disabling custom nodes and the issue persists (see [how to disable custom nodes](https://docs.comfy.org/troubleshooting/custom-node-issues#step-1%3A-test-with-all-custom-nodes-disabled) if you need help)
required: true
- type: textarea
attributes:
label: Your question

View File

@@ -0,0 +1,40 @@
name: Check for Windows Line Endings
on:
pull_request:
branches: ['*'] # Trigger on all pull requests to any branch
jobs:
check-line-endings:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0 # Fetch all history to compare changes
- name: Check for Windows line endings (CRLF)
run: |
# Get the list of changed files in the PR
CHANGED_FILES=$(git diff --name-only origin/${{ github.base_ref }}..HEAD)
# Flag to track if CRLF is found
CRLF_FOUND=false
# Loop through each changed file
for FILE in $CHANGED_FILES; do
# Check if the file exists and is a text file
if [ -f "$FILE" ] && file "$FILE" | grep -q "text"; then
# Check for CRLF line endings
if grep -UP '\r$' "$FILE"; then
echo "Error: Windows line endings (CRLF) detected in $FILE"
CRLF_FOUND=true
fi
fi
done
# Exit with error if CRLF was found
if [ "$CRLF_FOUND" = true ]; then
exit 1
fi

108
.github/workflows/release-webhook.yml vendored Normal file
View File

@@ -0,0 +1,108 @@
name: Release Webhook
on:
release:
types: [published]
jobs:
send-webhook:
runs-on: ubuntu-latest
steps:
- name: Send release webhook
env:
WEBHOOK_URL: ${{ secrets.RELEASE_GITHUB_WEBHOOK_URL }}
WEBHOOK_SECRET: ${{ secrets.RELEASE_GITHUB_WEBHOOK_SECRET }}
run: |
# Generate UUID for delivery ID
DELIVERY_ID=$(uuidgen)
HOOK_ID="release-webhook-$(date +%s)"
# Create webhook payload matching GitHub release webhook format
PAYLOAD=$(cat <<EOF
{
"action": "published",
"release": {
"id": ${{ github.event.release.id }},
"node_id": "${{ github.event.release.node_id }}",
"url": "${{ github.event.release.url }}",
"html_url": "${{ github.event.release.html_url }}",
"assets_url": "${{ github.event.release.assets_url }}",
"upload_url": "${{ github.event.release.upload_url }}",
"tag_name": "${{ github.event.release.tag_name }}",
"target_commitish": "${{ github.event.release.target_commitish }}",
"name": ${{ toJSON(github.event.release.name) }},
"body": ${{ toJSON(github.event.release.body) }},
"draft": ${{ github.event.release.draft }},
"prerelease": ${{ github.event.release.prerelease }},
"created_at": "${{ github.event.release.created_at }}",
"published_at": "${{ github.event.release.published_at }}",
"author": {
"login": "${{ github.event.release.author.login }}",
"id": ${{ github.event.release.author.id }},
"node_id": "${{ github.event.release.author.node_id }}",
"avatar_url": "${{ github.event.release.author.avatar_url }}",
"url": "${{ github.event.release.author.url }}",
"html_url": "${{ github.event.release.author.html_url }}",
"type": "${{ github.event.release.author.type }}",
"site_admin": ${{ github.event.release.author.site_admin }}
},
"tarball_url": "${{ github.event.release.tarball_url }}",
"zipball_url": "${{ github.event.release.zipball_url }}",
"assets": ${{ toJSON(github.event.release.assets) }}
},
"repository": {
"id": ${{ github.event.repository.id }},
"node_id": "${{ github.event.repository.node_id }}",
"name": "${{ github.event.repository.name }}",
"full_name": "${{ github.event.repository.full_name }}",
"private": ${{ github.event.repository.private }},
"owner": {
"login": "${{ github.event.repository.owner.login }}",
"id": ${{ github.event.repository.owner.id }},
"node_id": "${{ github.event.repository.owner.node_id }}",
"avatar_url": "${{ github.event.repository.owner.avatar_url }}",
"url": "${{ github.event.repository.owner.url }}",
"html_url": "${{ github.event.repository.owner.html_url }}",
"type": "${{ github.event.repository.owner.type }}",
"site_admin": ${{ github.event.repository.owner.site_admin }}
},
"html_url": "${{ github.event.repository.html_url }}",
"clone_url": "${{ github.event.repository.clone_url }}",
"git_url": "${{ github.event.repository.git_url }}",
"ssh_url": "${{ github.event.repository.ssh_url }}",
"url": "${{ github.event.repository.url }}",
"created_at": "${{ github.event.repository.created_at }}",
"updated_at": "${{ github.event.repository.updated_at }}",
"pushed_at": "${{ github.event.repository.pushed_at }}",
"default_branch": "${{ github.event.repository.default_branch }}",
"fork": ${{ github.event.repository.fork }}
},
"sender": {
"login": "${{ github.event.sender.login }}",
"id": ${{ github.event.sender.id }},
"node_id": "${{ github.event.sender.node_id }}",
"avatar_url": "${{ github.event.sender.avatar_url }}",
"url": "${{ github.event.sender.url }}",
"html_url": "${{ github.event.sender.html_url }}",
"type": "${{ github.event.sender.type }}",
"site_admin": ${{ github.event.sender.site_admin }}
}
}
EOF
)
# Generate HMAC-SHA256 signature
SIGNATURE=$(echo -n "$PAYLOAD" | openssl dgst -sha256 -hmac "$WEBHOOK_SECRET" -hex | cut -d' ' -f2)
# Send webhook with required headers
curl -X POST "$WEBHOOK_URL" \
-H "Content-Type: application/json" \
-H "X-GitHub-Event: release" \
-H "X-GitHub-Delivery: $DELIVERY_ID" \
-H "X-GitHub-Hook-ID: $HOOK_ID" \
-H "X-Hub-Signature-256: sha256=$SIGNATURE" \
-H "User-Agent: GitHub-Actions-Webhook/1.0" \
-d "$PAYLOAD" \
--fail --silent --show-error
echo "✅ Release webhook sent successfully"

View File

@@ -102,5 +102,4 @@ jobs:
file: ComfyUI_windows_portable_nvidia.7z
tag: ${{ inputs.git_tag }}
overwrite: true
prerelease: true
make_latest: false
draft: true

View File

@@ -7,7 +7,7 @@ on:
description: 'cuda version'
required: true
type: string
default: "128"
default: "129"
python_minor:
description: 'python minor version'
@@ -19,7 +19,7 @@ on:
description: 'python patch version'
required: true
type: string
default: "2"
default: "5"
# push:
# branches:
# - master
@@ -53,6 +53,8 @@ jobs:
ls ../temp_wheel_dir
./python.exe -s -m pip install --pre ../temp_wheel_dir/*
sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
rm ./Lib/site-packages/torch/lib/dnnl.lib #I don't think this is actually used and I need the space
cd ..
git clone --depth 1 https://github.com/comfyanonymous/taesd

View File

@@ -5,20 +5,20 @@
# Inlined the team members for now.
# Maintainers
*.md @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/tests/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/tests-unit/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/notebooks/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/script_examples/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/.github/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/requirements.txt @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/pyproject.toml @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
*.md @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/tests/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/tests-unit/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/notebooks/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/script_examples/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/.github/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/requirements.txt @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/pyproject.toml @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
# Python web server
/api_server/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @christian-byrne
/app/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @christian-byrne
/utils/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @christian-byrne
/api_server/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
/app/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
/utils/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
# Node developers
/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
/comfy/comfy_types/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
/comfy_extras/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
/comfy/comfy_types/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne

View File

@@ -6,6 +6,7 @@
[![Website][website-shield]][website-url]
[![Dynamic JSON Badge][discord-shield]][discord-url]
[![Twitter][twitter-shield]][twitter-url]
[![Matrix][matrix-shield]][matrix-url]
<br>
[![][github-release-shield]][github-release-link]
@@ -20,6 +21,8 @@
<!-- Workaround to display total user from https://github.com/badges/shields/issues/4500#issuecomment-2060079995 -->
[discord-shield]: https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fdiscord.com%2Fapi%2Finvites%2Fcomfyorg%3Fwith_counts%3Dtrue&query=%24.approximate_member_count&logo=discord&logoColor=white&label=Discord&color=green&suffix=%20total
[discord-url]: https://www.comfy.org/discord
[twitter-shield]: https://img.shields.io/twitter/follow/ComfyUI
[twitter-url]: https://x.com/ComfyUI
[github-release-shield]: https://img.shields.io/github/v/release/comfyanonymous/ComfyUI?style=flat&sort=semver
[github-release-link]: https://github.com/comfyanonymous/ComfyUI/releases
@@ -62,12 +65,16 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
- [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
- [Lumina Image 2.0](https://comfyanonymous.github.io/ComfyUI_examples/lumina2/)
- [HiDream](https://comfyanonymous.github.io/ComfyUI_examples/hidream/)
- [Cosmos Predict2](https://comfyanonymous.github.io/ComfyUI_examples/cosmos_predict2/)
- Image Editing Models
- [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/)
- [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model)
- Video Models
- [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/)
- [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
- [LTX-Video](https://comfyanonymous.github.io/ComfyUI_examples/ltxv/)
- [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
- [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/)
- [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/) and [Cosmos Predict2](https://comfyanonymous.github.io/ComfyUI_examples/cosmos_predict2/)
- [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/)
- Audio Models
- [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
@@ -79,6 +86,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
- Smart memory management: can automatically run models on GPUs with as low as 1GB vram.
- Works even if you don't have a GPU with: ```--cpu``` (slow)
- Can load ckpt, safetensors and diffusers models/checkpoints. Standalone VAEs and CLIP models.
- Safe loading of ckpt, pt, pth, etc.. files.
- Embeddings/Textual inversion
- [Loras (regular, locon and loha)](https://comfyanonymous.github.io/ComfyUI_examples/lora/)
- [Hypernetworks](https://comfyanonymous.github.io/ComfyUI_examples/hypernetworks/)
@@ -94,8 +102,8 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
- [Model Merging](https://comfyanonymous.github.io/ComfyUI_examples/model_merging/)
- [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/)
- Latent previews with [TAESD](#how-to-show-high-quality-previews)
- Starts up very fast.
- Works fully offline: will never download anything.
- Works fully offline: core will never download anything unless you want to.
- Optional API nodes to use paid models from external providers through the online [Comfy API](https://docs.comfy.org/tutorials/api-nodes/overview).
- [Config file](extra_model_paths.yaml.example) to set the search paths for models.
Workflow examples can be found on the [Examples page](https://comfyanonymous.github.io/ComfyUI_examples/)
@@ -170,10 +178,6 @@ If you have trouble extracting it, right click the file -> properties -> unblock
See the [Config file](extra_model_paths.yaml.example) to set the search paths for models. In the standalone windows build you can find this file in the ComfyUI directory. Rename this file to extra_model_paths.yaml and edit it with your favorite text editor.
## Jupyter Notebook
To run it on services like paperspace, kaggle or colab you can use my [Jupyter Notebook](notebooks/comfyui_colab.ipynb)
## [comfy-cli](https://docs.comfy.org/comfy-cli/getting-started)
@@ -235,7 +239,7 @@ Nvidia users should install stable pytorch using this command:
This is the command to install pytorch nightly instead which might have performance improvements.
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128```
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu129```
#### Troubleshooting
@@ -268,6 +272,8 @@ You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS ve
#### DirectML (AMD Cards on Windows)
This is very badly supported and is not recommended. There are some unofficial builds of pytorch ROCm on windows that exist that will give you a much better experience than this. This readme will be updated once official pytorch ROCm builds for windows come out.
```pip install torch-directml``` Then you can launch ComfyUI with: ```python main.py --directml```
#### Ascend NPUs

84
alembic.ini Normal file
View File

@@ -0,0 +1,84 @@
# A generic, single database configuration.
[alembic]
# path to migration scripts
# Use forward slashes (/) also on windows to provide an os agnostic path
script_location = alembic_db
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
# for all available tokens
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory.
prepend_sys_path = .
# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library.
# Any required deps can installed by adding `alembic[tz]` to the pip requirements
# string value is passed to ZoneInfo()
# leave blank for localtime
# timezone =
# max length of characters to apply to the "slug" field
# truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false
# version location specification; This defaults
# to alembic_db/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "version_path_separator" below.
# version_locations = %(here)s/bar:%(here)s/bat:alembic_db/versions
# version path separator; As mentioned above, this is the character used to split
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
# Valid values for version_path_separator are:
#
# version_path_separator = :
# version_path_separator = ;
# version_path_separator = space
# version_path_separator = newline
#
# Use os.pathsep. Default configuration used for new projects.
version_path_separator = os
# set to 'true' to search source files recursively
# in each "version_locations" directory
# new in Alembic version 1.10
# recursive_version_locations = false
# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8
sqlalchemy.url = sqlite:///user/comfyui.db
[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples
# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME
# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
# hooks = ruff
# ruff.type = exec
# ruff.executable = %(here)s/.venv/bin/ruff
# ruff.options = check --fix REVISION_SCRIPT_FILENAME

4
alembic_db/README.md Normal file
View File

@@ -0,0 +1,4 @@
## Generate new revision
1. Update models in `/app/database/models.py`
2. Run `alembic revision --autogenerate -m "{your message}"`

64
alembic_db/env.py Normal file
View File

@@ -0,0 +1,64 @@
from sqlalchemy import engine_from_config
from sqlalchemy import pool
from alembic import context
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
from app.database.models import Base
target_metadata = Base.metadata
# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.
This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
"""
url = config.get_main_option("sqlalchemy.url")
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
def run_migrations_online() -> None:
"""Run migrations in 'online' mode.
In this scenario we need to create an Engine
and associate a connection with the context.
"""
connectable = engine_from_config(
config.get_section(config.config_ini_section, {}),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
with connectable.connect() as connection:
context.configure(
connection=connection, target_metadata=target_metadata
)
with context.begin_transaction():
context.run_migrations()
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()

28
alembic_db/script.py.mako Normal file
View File

@@ -0,0 +1,28 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
def upgrade() -> None:
"""Upgrade schema."""
${upgrades if upgrades else "pass"}
def downgrade() -> None:
"""Downgrade schema."""
${downgrades if downgrades else "pass"}

112
app/database/db.py Normal file
View File

@@ -0,0 +1,112 @@
import logging
import os
import shutil
from app.logger import log_startup_warning
from utils.install_util import get_missing_requirements_message
from comfy.cli_args import args
_DB_AVAILABLE = False
Session = None
try:
from alembic import command
from alembic.config import Config
from alembic.runtime.migration import MigrationContext
from alembic.script import ScriptDirectory
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
_DB_AVAILABLE = True
except ImportError as e:
log_startup_warning(
f"""
------------------------------------------------------------------------
Error importing dependencies: {e}
{get_missing_requirements_message()}
This error is happening because ComfyUI now uses a local sqlite database.
------------------------------------------------------------------------
""".strip()
)
def dependencies_available():
"""
Temporary function to check if the dependencies are available
"""
return _DB_AVAILABLE
def can_create_session():
"""
Temporary function to check if the database is available to create a session
During initial release there may be environmental issues (or missing dependencies) that prevent the database from being created
"""
return dependencies_available() and Session is not None
def get_alembic_config():
root_path = os.path.join(os.path.dirname(__file__), "../..")
config_path = os.path.abspath(os.path.join(root_path, "alembic.ini"))
scripts_path = os.path.abspath(os.path.join(root_path, "alembic_db"))
config = Config(config_path)
config.set_main_option("script_location", scripts_path)
config.set_main_option("sqlalchemy.url", args.database_url)
return config
def get_db_path():
url = args.database_url
if url.startswith("sqlite:///"):
return url.split("///")[1]
else:
raise ValueError(f"Unsupported database URL '{url}'.")
def init_db():
db_url = args.database_url
logging.debug(f"Database URL: {db_url}")
db_path = get_db_path()
db_exists = os.path.exists(db_path)
config = get_alembic_config()
# Check if we need to upgrade
engine = create_engine(db_url)
conn = engine.connect()
context = MigrationContext.configure(conn)
current_rev = context.get_current_revision()
script = ScriptDirectory.from_config(config)
target_rev = script.get_current_head()
if target_rev is None:
logging.warning("No target revision found.")
elif current_rev != target_rev:
# Backup the database pre upgrade
backup_path = db_path + ".bkp"
if db_exists:
shutil.copy(db_path, backup_path)
else:
backup_path = None
try:
command.upgrade(config, target_rev)
logging.info(f"Database upgraded from {current_rev} to {target_rev}")
except Exception as e:
if backup_path:
# Restore the database from backup if upgrade fails
shutil.copy(backup_path, db_path)
os.remove(backup_path)
logging.exception("Error upgrading database: ")
raise e
global Session
Session = sessionmaker(bind=engine)
def create_session():
return Session()

14
app/database/models.py Normal file
View File

@@ -0,0 +1,14 @@
from sqlalchemy.orm import declarative_base
Base = declarative_base()
def to_dict(obj):
fields = obj.__table__.columns.keys()
return {
field: (val.to_dict() if hasattr(val, "to_dict") else val)
for field in fields
if (val := getattr(obj, field))
}
# TODO: Define models here

View File

@@ -16,26 +16,17 @@ from importlib.metadata import version
import requests
from typing_extensions import NotRequired
from utils.install_util import get_missing_requirements_message, requirements_path
from comfy.cli_args import DEFAULT_VERSION_STRING
import app.logger
# The path to the requirements.txt file
req_path = Path(__file__).parents[1] / "requirements.txt"
def frontend_install_warning_message():
"""The warning message to display when the frontend version is not up to date."""
extra = ""
if sys.flags.no_user_site:
extra = "-s "
return f"""
Please install the updated requirements.txt file by running:
{sys.executable} {extra}-m pip install -r {req_path}
{get_missing_requirements_message()}
This error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.
If you are on the portable package you can run: update\\update_comfyui.bat to solve this problem
""".strip()
@@ -48,7 +39,7 @@ def check_frontend_version():
try:
frontend_version_str = version("comfyui-frontend-package")
frontend_version = parse_version(frontend_version_str)
with open(req_path, "r", encoding="utf-8") as f:
with open(requirements_path, "r", encoding="utf-8") as f:
required_frontend = parse_version(f.readline().split("=")[-1])
if frontend_version < required_frontend:
app.logger.log_startup_warning(
@@ -121,9 +112,22 @@ class FrontEndProvider:
response.raise_for_status() # Raises an HTTPError if the response was an error
return response.json()
@cached_property
def latest_prerelease(self) -> Release:
"""Get the latest pre-release version - even if it's older than the latest release"""
release = [release for release in self.all_releases if release["prerelease"]]
if not release:
raise ValueError("No pre-releases found")
# GitHub returns releases in reverse chronological order, so first is latest
return release[0]
def get_release(self, version: str) -> Release:
if version == "latest":
return self.latest_release
elif version == "prerelease":
return self.latest_prerelease
else:
for release in self.all_releases:
if release["tag_name"] in [version, f"v{version}"]:
@@ -205,6 +209,19 @@ comfyui-workflow-templates is not installed.
""".strip()
)
@classmethod
def embedded_docs_path(cls) -> str:
"""Get the path to embedded documentation"""
try:
import comfyui_embedded_docs
return str(
importlib.resources.files(comfyui_embedded_docs) / "docs"
)
except ImportError:
logging.info("comfyui-embedded-docs package not found")
return None
@classmethod
def parse_version_string(cls, value: str) -> tuple[str, str, str]:
"""
@@ -217,7 +234,7 @@ comfyui-workflow-templates is not installed.
Raises:
argparse.ArgumentTypeError: If the version string is invalid.
"""
VERSION_PATTERN = r"^([a-zA-Z0-9][a-zA-Z0-9-]{0,38})/([a-zA-Z0-9_.-]+)@(v?\d+\.\d+\.\d+|latest)$"
VERSION_PATTERN = r"^([a-zA-Z0-9][a-zA-Z0-9-]{0,38})/([a-zA-Z0-9_.-]+)@(v?\d+\.\d+\.\d+[-._a-zA-Z0-9]*|latest|prerelease)$"
match_result = re.match(VERSION_PATTERN, value)
if match_result is None:
raise argparse.ArgumentTypeError(f"Invalid version string: {value}")

View File

@@ -144,6 +144,7 @@ class PerformanceFeature(enum.Enum):
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult cublas_ops")
parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.")
parser.add_argument("--disable-mmap", action="store_true", help="Don't use mmap when loading safetensors.")
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
@@ -151,6 +152,7 @@ parser.add_argument("--windows-standalone-build", action="store_true", help="Win
parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.")
parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Disable loading all custom nodes.")
parser.add_argument("--whitelist-custom-nodes", type=str, nargs='+', default=[], help="Specify custom node folders to load even when --disable-all-custom-nodes is enabled.")
parser.add_argument("--disable-api-nodes", action="store_true", help="Disable loading all api nodes.")
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
@@ -203,6 +205,11 @@ parser.add_argument(
help="Set the base URL for the ComfyUI API. (default: https://api.comfy.org)",
)
database_default_path = os.path.abspath(
os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db")
)
parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
if comfy.options.args_parsing:
args = parser.parse_args()
else:

View File

@@ -37,6 +37,8 @@ class IO(StrEnum):
CONTROL_NET = "CONTROL_NET"
VAE = "VAE"
MODEL = "MODEL"
LORA_MODEL = "LORA_MODEL"
LOSS_MAP = "LOSS_MAP"
CLIP_VISION = "CLIP_VISION"
CLIP_VISION_OUTPUT = "CLIP_VISION_OUTPUT"
STYLE_MODEL = "STYLE_MODEL"

View File

@@ -86,3 +86,45 @@ class CONDConstant(CONDRegular):
def size(self):
return [1]
class CONDList(CONDRegular):
def __init__(self, cond):
self.cond = cond
def process_cond(self, batch_size, device, **kwargs):
out = []
for c in self.cond:
out.append(comfy.utils.repeat_to_batch_size(c, batch_size).to(device))
return self._copy_with(out)
def can_concat(self, other):
if len(self.cond) != len(other.cond):
return False
for i in range(len(self.cond)):
if self.cond[i].shape != other.cond[i].shape:
return False
return True
def concat(self, others):
out = []
for i in range(len(self.cond)):
o = [self.cond[i]]
for x in others:
o.append(x.cond[i])
out.append(torch.cat(o))
return out
def size(self): # hackish implementation to make the mem estimation work
o = 0
c = 1
for c in self.cond:
size = c.size()
o += math.prod(size)
if len(size) > 1:
c = size[1]
return [1, c, o // c]

View File

@@ -390,8 +390,9 @@ class ControlLora(ControlNet):
pass
for k in self.control_weights:
if k not in {"lora_controlnet"}:
comfy.utils.set_attr_param(self.control_model, k, self.control_weights[k].to(dtype).to(comfy.model_management.get_torch_device()))
if (k not in {"lora_controlnet"}):
if (k.endswith(".up") or k.endswith(".down") or k.endswith(".weight") or k.endswith(".bias")) and ("__" not in k):
comfy.utils.set_attr_param(self.control_model, k, self.control_weights[k].to(dtype).to(comfy.model_management.get_torch_device()))
def copy(self):
c = ControlLora(self.control_weights, global_average_pooling=self.global_average_pooling)

View File

@@ -1,55 +1,10 @@
import math
import torch
from torch import nn
from .ldm.modules.attention import CrossAttention
from inspect import isfunction
from .ldm.modules.attention import CrossAttention, FeedForward
import comfy.ops
ops = comfy.ops.manual_cast
def exists(val):
return val is not None
def uniq(arr):
return{el: True for el in arr}.keys()
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
# feedforward
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = ops.Linear(dim_in, dim_out * 2)
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
return x * torch.nn.functional.gelu(gate)
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = nn.Sequential(
ops.Linear(dim, inner_dim),
nn.GELU()
) if not glu else GEGLU(dim, inner_dim)
self.net = nn.Sequential(
project_in,
nn.Dropout(dropout),
ops.Linear(inner_dim, dim_out)
)
def forward(self, x):
return self.net(x)
class GatedCrossAttentionDense(nn.Module):
def __init__(self, query_dim, context_dim, n_heads, d_head):

View File

@@ -0,0 +1,121 @@
# SA-Solver: Stochastic Adams Solver (NeurIPS 2023, arXiv:2309.05019)
# Conference: https://proceedings.neurips.cc/paper_files/paper/2023/file/f4a6806490d31216a3ba667eb240c897-Paper-Conference.pdf
# Codebase ref: https://github.com/scxue/SA-Solver
import math
from typing import Union, Callable
import torch
def compute_exponential_coeffs(s: torch.Tensor, t: torch.Tensor, solver_order: int, tau_t: float) -> torch.Tensor:
"""Compute (1 + tau^2) * integral of exp((1 + tau^2) * x) * x^p dx from s to t with exp((1 + tau^2) * t) factored out, using integration by parts.
Integral of exp((1 + tau^2) * x) * x^p dx
= product_terms[p] - (p / (1 + tau^2)) * integral of exp((1 + tau^2) * x) * x^(p-1) dx,
with base case p=0 where integral equals product_terms[0].
where
product_terms[p] = x^p * exp((1 + tau^2) * x) / (1 + tau^2).
Construct a recursive coefficient matrix following the above recursive relation to compute all integral terms up to p = (solver_order - 1).
Return coefficients used by the SA-Solver in data prediction mode.
Args:
s: Start time s.
t: End time t.
solver_order: Current order of the solver.
tau_t: Stochastic strength parameter in the SDE.
Returns:
Exponential coefficients used in data prediction, with exp((1 + tau^2) * t) factored out, ordered from p=0 to p=solver_order1, shape (solver_order,).
"""
tau_mul = 1 + tau_t ** 2
h = t - s
p = torch.arange(solver_order, dtype=s.dtype, device=s.device)
# product_terms after factoring out exp((1 + tau^2) * t)
# Includes (1 + tau^2) factor from outside the integral
product_terms_factored = (t ** p - s ** p * (-tau_mul * h).exp())
# Lower triangular recursive coefficient matrix
# Accumulates recursive coefficients based on p / (1 + tau^2)
recursive_depth_mat = p.unsqueeze(1) - p.unsqueeze(0)
log_factorial = (p + 1).lgamma()
recursive_coeff_mat = log_factorial.unsqueeze(1) - log_factorial.unsqueeze(0)
if tau_t > 0:
recursive_coeff_mat = recursive_coeff_mat - (recursive_depth_mat * math.log(tau_mul))
signs = torch.where(recursive_depth_mat % 2 == 0, 1.0, -1.0)
recursive_coeff_mat = (recursive_coeff_mat.exp() * signs).tril()
return recursive_coeff_mat @ product_terms_factored
def compute_simple_stochastic_adams_b_coeffs(sigma_next: torch.Tensor, curr_lambdas: torch.Tensor, lambda_s: torch.Tensor, lambda_t: torch.Tensor, tau_t: float, is_corrector_step: bool = False) -> torch.Tensor:
"""Compute simple order-2 b coefficients from SA-Solver paper (Appendix D. Implementation Details)."""
tau_mul = 1 + tau_t ** 2
h = lambda_t - lambda_s
alpha_t = sigma_next * lambda_t.exp()
if is_corrector_step:
# Simplified 1-step (order-2) corrector
b_1 = alpha_t * (0.5 * tau_mul * h)
b_2 = alpha_t * (-h * tau_mul).expm1().neg() - b_1
else:
# Simplified 2-step predictor
b_2 = alpha_t * (0.5 * tau_mul * h ** 2) / (curr_lambdas[-2] - lambda_s)
b_1 = alpha_t * (-h * tau_mul).expm1().neg() - b_2
return torch.stack([b_2, b_1])
def compute_stochastic_adams_b_coeffs(sigma_next: torch.Tensor, curr_lambdas: torch.Tensor, lambda_s: torch.Tensor, lambda_t: torch.Tensor, tau_t: float, simple_order_2: bool = False, is_corrector_step: bool = False) -> torch.Tensor:
"""Compute b_i coefficients for the SA-Solver (see eqs. 15 and 18).
The solver order corresponds to the number of input lambdas (half-logSNR points).
Args:
sigma_next: Sigma at end time t.
curr_lambdas: Lambda time points used to construct the Lagrange basis, shape (N,).
lambda_s: Lambda at start time s.
lambda_t: Lambda at end time t.
tau_t: Stochastic strength parameter in the SDE.
simple_order_2: Whether to enable the simple order-2 scheme.
is_corrector_step: Flag for corrector step in simple order-2 mode.
Returns:
b_i coefficients for the SA-Solver, shape (N,), where N is the solver order.
"""
num_timesteps = curr_lambdas.shape[0]
if simple_order_2 and num_timesteps == 2:
return compute_simple_stochastic_adams_b_coeffs(sigma_next, curr_lambdas, lambda_s, lambda_t, tau_t, is_corrector_step)
# Compute coefficients by solving a linear system from Lagrange basis interpolation
exp_integral_coeffs = compute_exponential_coeffs(lambda_s, lambda_t, num_timesteps, tau_t)
vandermonde_matrix_T = torch.vander(curr_lambdas, num_timesteps, increasing=True).T
lagrange_integrals = torch.linalg.solve(vandermonde_matrix_T, exp_integral_coeffs)
# (sigma_t * exp(-tau^2 * lambda_t)) * exp((1 + tau^2) * lambda_t)
# = sigma_t * exp(lambda_t) = alpha_t
# exp((1 + tau^2) * lambda_t) is extracted from the integral
alpha_t = sigma_next * lambda_t.exp()
return alpha_t * lagrange_integrals
def get_tau_interval_func(start_sigma: float, end_sigma: float, eta: float = 1.0) -> Callable[[Union[torch.Tensor, float]], float]:
"""Return a function that controls the stochasticity of SA-Solver.
When eta = 0, SA-Solver runs as ODE. The official approach uses
time t to determine the SDE interval, while here we use sigma instead.
See:
https://github.com/scxue/SA-Solver/blob/main/README.md
"""
def tau_func(sigma: Union[torch.Tensor, float]) -> float:
if eta <= 0:
return 0.0 # ODE
if isinstance(sigma, torch.Tensor):
sigma = sigma.item()
return eta if start_sigma >= sigma >= end_sigma else 0.0
return tau_func

View File

@@ -1,4 +1,5 @@
import math
from functools import partial
from scipy import integrate
import torch
@@ -8,6 +9,7 @@ from tqdm.auto import trange, tqdm
from . import utils
from . import deis
from . import sa_solver
import comfy.model_patcher
import comfy.model_sampling
@@ -142,6 +144,33 @@ class BrownianTreeNoiseSampler:
return self.tree(t0, t1) / (t1 - t0).abs().sqrt()
def sigma_to_half_log_snr(sigma, model_sampling):
"""Convert sigma to half-logSNR log(alpha_t / sigma_t)."""
if isinstance(model_sampling, comfy.model_sampling.CONST):
# log((1 - t) / t) = log((1 - sigma) / sigma)
return sigma.logit().neg()
return sigma.log().neg()
def half_log_snr_to_sigma(half_log_snr, model_sampling):
"""Convert half-logSNR log(alpha_t / sigma_t) to sigma."""
if isinstance(model_sampling, comfy.model_sampling.CONST):
# 1 / (1 + exp(half_log_snr))
return half_log_snr.neg().sigmoid()
return half_log_snr.neg().exp()
def offset_first_sigma_for_snr(sigmas, model_sampling, percent_offset=1e-4):
"""Adjust the first sigma to avoid invalid logSNR."""
if len(sigmas) <= 1:
return sigmas
if isinstance(model_sampling, comfy.model_sampling.CONST):
if sigmas[0] >= 1:
sigmas = sigmas.clone()
sigmas[0] = model_sampling.percent_to_sigma(percent_offset)
return sigmas
@torch.no_grad()
def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
@@ -384,9 +413,13 @@ def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, o
ds.pop(0)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
cur_order = min(i + 1, order)
coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)]
x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
if sigmas[i + 1] == 0:
# Denoising step
x = denoised
else:
cur_order = min(i + 1, order)
coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)]
x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
return x
@@ -682,6 +715,7 @@ def sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args=None, callback=Non
# logged_x = torch.cat((logged_x, x.unsqueeze(0)), dim=0)
return x
@torch.no_grad()
def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
"""DPM-Solver++ (stochastic)."""
@@ -693,38 +727,49 @@ def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=N
seed = extra_args.get("seed", None)
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
sigma_fn = lambda t: t.neg().exp()
t_fn = lambda sigma: sigma.log().neg()
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling)
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if sigmas[i + 1] == 0:
# Euler method
d = to_d(x, sigmas[i], denoised)
dt = sigmas[i + 1] - sigmas[i]
x = x + d * dt
# Denoising step
x = denoised
else:
# DPM-Solver++
t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
h = t_next - t
s = t + h * r
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
h = lambda_t - lambda_s
lambda_s_1 = lambda_s + r * h
fac = 1 / (2 * r)
sigma_s_1 = sigma_fn(lambda_s_1)
alpha_s = sigmas[i] * lambda_s.exp()
alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
alpha_t = sigmas[i + 1] * lambda_t.exp()
# Step 1
sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(s), eta)
s_ = t_fn(sd)
x_2 = (sigma_fn(s_) / sigma_fn(t)) * x - (t - s_).expm1() * denoised
x_2 = x_2 + noise_sampler(sigma_fn(t), sigma_fn(s)) * s_noise * su
denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
sd, su = get_ancestral_step(lambda_s.neg().exp(), lambda_s_1.neg().exp(), eta)
lambda_s_1_ = sd.log().neg()
h_ = lambda_s_1_ - lambda_s
x_2 = (alpha_s_1 / alpha_s) * (-h_).exp() * x - alpha_s_1 * (-h_).expm1() * denoised
if eta > 0 and s_noise > 0:
x_2 = x_2 + alpha_s_1 * noise_sampler(sigmas[i], sigma_s_1) * s_noise * su
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
# Step 2
sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(t_next), eta)
t_next_ = t_fn(sd)
sd, su = get_ancestral_step(lambda_s.neg().exp(), lambda_t.neg().exp(), eta)
lambda_t_ = sd.log().neg()
h_ = lambda_t_ - lambda_s
denoised_d = (1 - fac) * denoised + fac * denoised_2
x = (sigma_fn(t_next_) / sigma_fn(t)) * x - (t - t_next_).expm1() * denoised_d
x = x + noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * su
x = (alpha_t / alpha_s) * (-h_).exp() * x - alpha_t * (-h_).expm1() * denoised_d
if eta > 0 and s_noise > 0:
x = x + alpha_t * noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * su
return x
@@ -753,6 +798,7 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No
old_denoised = denoised
return x
@torch.no_grad()
def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
"""DPM-Solver++(2M) SDE."""
@@ -768,9 +814,12 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
old_denoised = None
h_last = None
h = None
h, h_last = None, None
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
@@ -781,26 +830,29 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
x = denoised
else:
# DPM-Solver++(2M) SDE
t, s = -sigmas[i].log(), -sigmas[i + 1].log()
h = s - t
eta_h = eta * h
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
h = lambda_t - lambda_s
h_eta = h * (eta + 1)
x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + (-h - eta_h).expm1().neg() * denoised
alpha_t = sigmas[i + 1] * lambda_t.exp()
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x + alpha_t * (-h_eta).expm1().neg() * denoised
if old_denoised is not None:
r = h_last / h
if solver_type == 'heun':
x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * (1 / r) * (denoised - old_denoised)
x = x + alpha_t * ((-h_eta).expm1().neg() / (-h_eta) + 1) * (1 / r) * (denoised - old_denoised)
elif solver_type == 'midpoint':
x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised)
x = x + 0.5 * alpha_t * (-h_eta).expm1().neg() * (1 / r) * (denoised - old_denoised)
if eta:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise
if eta > 0 and s_noise > 0:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise
old_denoised = denoised
h_last = h
return x
@torch.no_grad()
def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""DPM-Solver++(3M) SDE."""
@@ -814,6 +866,10 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
denoised_1, denoised_2 = None, None
h, h_1, h_2 = None, None, None
@@ -825,13 +881,16 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
# Denoising step
x = denoised
else:
t, s = -sigmas[i].log(), -sigmas[i + 1].log()
h = s - t
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
h = lambda_t - lambda_s
h_eta = h * (eta + 1)
x = torch.exp(-h_eta) * x + (-h_eta).expm1().neg() * denoised
alpha_t = sigmas[i + 1] * lambda_t.exp()
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x + alpha_t * (-h_eta).expm1().neg() * denoised
if h_2 is not None:
# DPM-Solver++(3M) SDE
r0 = h_1 / h
r1 = h_2 / h
d1_0 = (denoised - denoised_1) / r0
@@ -840,20 +899,22 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
d2 = (d1_0 - d1_1) / (r0 + r1)
phi_2 = h_eta.neg().expm1() / h_eta + 1
phi_3 = phi_2 / h_eta - 0.5
x = x + phi_2 * d1 - phi_3 * d2
x = x + (alpha_t * phi_2) * d1 - (alpha_t * phi_3) * d2
elif h_1 is not None:
# DPM-Solver++(2M) SDE
r = h_1 / h
d = (denoised - denoised_1) / r
phi_2 = h_eta.neg().expm1() / h_eta + 1
x = x + phi_2 * d
x = x + (alpha_t * phi_2) * d
if eta:
if eta > 0 and s_noise > 0:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise
denoised_1, denoised_2 = denoised, denoised_1
h_1, h_2 = h, h_1
return x
@torch.no_grad()
def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
if len(sigmas) <= 1:
@@ -863,6 +924,7 @@ def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler)
@torch.no_grad()
def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
if len(sigmas) <= 1:
@@ -872,6 +934,7 @@ def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
@torch.no_grad()
def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
if len(sigmas) <= 1:
@@ -1009,7 +1072,9 @@ def sample_ipndm(model, x, sigmas, extra_args=None, callback=None, disable=None,
d_cur = (x_cur - denoised) / t_cur
order = min(max_order, i+1)
if order == 1: # First Euler step.
if t_next == 0: # Denoising step
x_next = denoised
elif order == 1: # First Euler step.
x_next = x_cur + (t_next - t_cur) * d_cur
elif order == 2: # Use one history point.
x_next = x_cur + (t_next - t_cur) * (3 * d_cur - buffer_model[-1]) / 2
@@ -1027,6 +1092,7 @@ def sample_ipndm(model, x, sigmas, extra_args=None, callback=None, disable=None,
return x_next
#From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
#under Apache 2 license
def sample_ipndm_v(model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=4):
@@ -1050,7 +1116,9 @@ def sample_ipndm_v(model, x, sigmas, extra_args=None, callback=None, disable=Non
d_cur = (x_cur - denoised) / t_cur
order = min(max_order, i+1)
if order == 1: # First Euler step.
if t_next == 0: # Denoising step
x_next = denoised
elif order == 1: # First Euler step.
x_next = x_cur + (t_next - t_cur) * d_cur
elif order == 2: # Use one history point.
h_n = (t_next - t_cur)
@@ -1090,6 +1158,7 @@ def sample_ipndm_v(model, x, sigmas, extra_args=None, callback=None, disable=Non
return x_next
#From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
#under Apache 2 license
@torch.no_grad()
@@ -1140,6 +1209,7 @@ def sample_deis(model, x, sigmas, extra_args=None, callback=None, disable=None,
return x_next
@torch.no_grad()
def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
extra_args = {} if extra_args is None else extra_args
@@ -1346,6 +1416,7 @@ def sample_res_multistep_ancestral(model, x, sigmas, extra_args=None, callback=N
def sample_res_multistep_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=eta, cfg_pp=True)
@torch.no_grad()
def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2., cfg_pp=False):
"""Gradient-estimation sampler. Paper: https://openreview.net/pdf?id=o2ND9v0CeK"""
@@ -1372,31 +1443,32 @@ def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None,
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
dt = sigmas[i + 1] - sigmas[i]
if i == 0:
if sigmas[i + 1] == 0:
# Denoising step
x = denoised
else:
# Euler method
if cfg_pp:
x = denoised + d * sigmas[i + 1]
else:
x = x + d * dt
else:
# Gradient estimation
if cfg_pp:
if i >= 1:
# Gradient estimation
d_bar = (ge_gamma - 1) * (d - old_d)
x = denoised + d * sigmas[i + 1] + d_bar * dt
else:
d_bar = ge_gamma * d + (1 - ge_gamma) * old_d
x = x + d_bar * dt
old_d = d
return x
@torch.no_grad()
def sample_gradient_estimation_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2.):
return sample_gradient_estimation(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, ge_gamma=ge_gamma, cfg_pp=True)
@torch.no_grad()
def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None, noise_scaler=None, max_stage=3):
"""
Extended Reverse-Time SDE solver (VE ER-SDE-Solver-3). Arxiv: https://arxiv.org/abs/2309.06169.
def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1.0, noise_sampler=None, noise_scaler=None, max_stage=3):
"""Extended Reverse-Time SDE solver (VP ER-SDE-Solver-3). arXiv: https://arxiv.org/abs/2309.06169.
Code reference: https://github.com/QinpengCui/ER-SDE-Solver/blob/main/er_sde_solver.py.
"""
extra_args = {} if extra_args is None else extra_args
@@ -1404,12 +1476,18 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
def default_noise_scaler(sigma):
return sigma * ((sigma ** 0.3).exp() + 10.0)
noise_scaler = default_noise_scaler if noise_scaler is None else noise_scaler
def default_er_sde_noise_scaler(x):
return x * ((x ** 0.3).exp() + 10.0)
noise_scaler = default_er_sde_noise_scaler if noise_scaler is None else noise_scaler
num_integration_points = 200.0
point_indice = torch.arange(0, num_integration_points, dtype=torch.float32, device=x.device)
model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling")
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
half_log_snrs = sigma_to_half_log_snr(sigmas, model_sampling)
er_lambdas = half_log_snrs.neg().exp() # er_lambda_t = sigma_t / alpha_t
old_denoised = None
old_denoised_d = None
@@ -1420,41 +1498,45 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
stage_used = min(max_stage, i + 1)
if sigmas[i + 1] == 0:
x = denoised
elif stage_used == 1:
r = noise_scaler(sigmas[i + 1]) / noise_scaler(sigmas[i])
x = r * x + (1 - r) * denoised
else:
r = noise_scaler(sigmas[i + 1]) / noise_scaler(sigmas[i])
x = r * x + (1 - r) * denoised
er_lambda_s, er_lambda_t = er_lambdas[i], er_lambdas[i + 1]
alpha_s = sigmas[i] / er_lambda_s
alpha_t = sigmas[i + 1] / er_lambda_t
r_alpha = alpha_t / alpha_s
r = noise_scaler(er_lambda_t) / noise_scaler(er_lambda_s)
dt = sigmas[i + 1] - sigmas[i]
sigma_step_size = -dt / num_integration_points
sigma_pos = sigmas[i + 1] + point_indice * sigma_step_size
scaled_pos = noise_scaler(sigma_pos)
# Stage 1 Euler
x = r_alpha * r * x + alpha_t * (1 - r) * denoised
# Stage 2
s = torch.sum(1 / scaled_pos) * sigma_step_size
denoised_d = (denoised - old_denoised) / (sigmas[i] - sigmas[i - 1])
x = x + (dt + s * noise_scaler(sigmas[i + 1])) * denoised_d
if stage_used >= 2:
dt = er_lambda_t - er_lambda_s
lambda_step_size = -dt / num_integration_points
lambda_pos = er_lambda_t + point_indice * lambda_step_size
scaled_pos = noise_scaler(lambda_pos)
if stage_used >= 3:
# Stage 3
s_u = torch.sum((sigma_pos - sigmas[i]) / scaled_pos) * sigma_step_size
denoised_u = (denoised_d - old_denoised_d) / ((sigmas[i] - sigmas[i - 2]) / 2)
x = x + ((dt ** 2) / 2 + s_u * noise_scaler(sigmas[i + 1])) * denoised_u
old_denoised_d = denoised_d
# Stage 2
s = torch.sum(1 / scaled_pos) * lambda_step_size
denoised_d = (denoised - old_denoised) / (er_lambda_s - er_lambdas[i - 1])
x = x + alpha_t * (dt + s * noise_scaler(er_lambda_t)) * denoised_d
if s_noise != 0 and sigmas[i + 1] > 0:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * (sigmas[i + 1] ** 2 - sigmas[i] ** 2 * r ** 2).sqrt().nan_to_num(nan=0.0)
if stage_used >= 3:
# Stage 3
s_u = torch.sum((lambda_pos - er_lambda_s) / scaled_pos) * lambda_step_size
denoised_u = (denoised_d - old_denoised_d) / ((er_lambda_s - er_lambdas[i - 2]) / 2)
x = x + alpha_t * ((dt ** 2) / 2 + s_u * noise_scaler(er_lambda_t)) * denoised_u
old_denoised_d = denoised_d
if s_noise > 0:
x = x + alpha_t * noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * (er_lambda_t ** 2 - er_lambda_s ** 2 * r ** 2).sqrt().nan_to_num(nan=0.0)
old_denoised = denoised
return x
@torch.no_grad()
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5):
'''
SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VE Data Prediction) stage 2
Arxiv: https://arxiv.org/abs/2305.14267
'''
"""SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2.
arXiv: https://arxiv.org/abs/2305.14267
"""
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
@@ -1462,6 +1544,11 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
inject_noise = eta > 0 and s_noise > 0
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling)
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
@@ -1469,80 +1556,206 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
if sigmas[i + 1] == 0:
x = denoised
else:
t, t_next = -sigmas[i].log(), -sigmas[i + 1].log()
h = t_next - t
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
h = lambda_t - lambda_s
h_eta = h * (eta + 1)
s = t + r * h
lambda_s_1 = lambda_s + r * h
fac = 1 / (2 * r)
sigma_s = s.neg().exp()
sigma_s_1 = sigma_fn(lambda_s_1)
# alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t)
alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
alpha_t = sigmas[i + 1] * lambda_t.exp()
coeff_1, coeff_2 = (-r * h_eta).expm1(), (-h_eta).expm1()
if inject_noise:
# 0 < r < 1
noise_coeff_1 = (-2 * r * h * eta).expm1().neg().sqrt()
noise_coeff_2 = ((-2 * r * h * eta).expm1() - (-2 * h * eta).expm1()).sqrt()
noise_1, noise_2 = noise_sampler(sigmas[i], sigma_s), noise_sampler(sigma_s, sigmas[i + 1])
noise_coeff_2 = (-r * h * eta).exp() * (-2 * (1 - r) * h * eta).expm1().neg().sqrt()
noise_1, noise_2 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigmas[i + 1])
# Step 1
x_2 = (coeff_1 + 1) * x - coeff_1 * denoised
if inject_noise:
x_2 = x_2 + sigma_s * (noise_coeff_1 * noise_1) * s_noise
denoised_2 = model(x_2, sigma_s * s_in, **extra_args)
# Step 2
denoised_d = (1 - fac) * denoised + fac * denoised_2
x = (coeff_2 + 1) * x - coeff_2 * denoised_d
if inject_noise:
x = x + sigmas[i + 1] * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
return x
@torch.no_grad()
def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3):
'''
SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VE Data Prediction) stage 3
Arxiv: https://arxiv.org/abs/2305.14267
'''
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
inject_noise = eta > 0 and s_noise > 0
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if sigmas[i + 1] == 0:
x = denoised
else:
t, t_next = -sigmas[i].log(), -sigmas[i + 1].log()
h = t_next - t
h_eta = h * (eta + 1)
s_1 = t + r_1 * h
s_2 = t + r_2 * h
sigma_s_1, sigma_s_2 = s_1.neg().exp(), s_2.neg().exp()
coeff_1, coeff_2, coeff_3 = (-r_1 * h_eta).expm1(), (-r_2 * h_eta).expm1(), (-h_eta).expm1()
if inject_noise:
noise_coeff_1 = (-2 * r_1 * h * eta).expm1().neg().sqrt()
noise_coeff_2 = ((-2 * r_1 * h * eta).expm1() - (-2 * r_2 * h * eta).expm1()).sqrt()
noise_coeff_3 = ((-2 * r_2 * h * eta).expm1() - (-2 * h * eta).expm1()).sqrt()
noise_1, noise_2, noise_3 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigma_s_2), noise_sampler(sigma_s_2, sigmas[i + 1])
# Step 1
x_2 = (coeff_1 + 1) * x - coeff_1 * denoised
x_2 = sigma_s_1 / sigmas[i] * (-r * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised
if inject_noise:
x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
# Step 2
x_3 = (coeff_2 + 1) * x - coeff_2 * denoised + (r_2 / r_1) * (coeff_2 / (r_2 * h_eta) + 1) * (denoised_2 - denoised)
denoised_d = (1 - fac) * denoised + fac * denoised_2
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_2 * denoised_d
if inject_noise:
x = x + sigmas[i + 1] * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
return x
@torch.no_grad()
def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3):
"""SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 3.
arXiv: https://arxiv.org/abs/2305.14267
"""
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
inject_noise = eta > 0 and s_noise > 0
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling)
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if sigmas[i + 1] == 0:
x = denoised
else:
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
h = lambda_t - lambda_s
h_eta = h * (eta + 1)
lambda_s_1 = lambda_s + r_1 * h
lambda_s_2 = lambda_s + r_2 * h
sigma_s_1, sigma_s_2 = sigma_fn(lambda_s_1), sigma_fn(lambda_s_2)
# alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t)
alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
alpha_s_2 = sigma_s_2 * lambda_s_2.exp()
alpha_t = sigmas[i + 1] * lambda_t.exp()
coeff_1, coeff_2, coeff_3 = (-r_1 * h_eta).expm1(), (-r_2 * h_eta).expm1(), (-h_eta).expm1()
if inject_noise:
# 0 < r_1 < r_2 < 1
noise_coeff_1 = (-2 * r_1 * h * eta).expm1().neg().sqrt()
noise_coeff_2 = (-r_1 * h * eta).exp() * (-2 * (r_2 - r_1) * h * eta).expm1().neg().sqrt()
noise_coeff_3 = (-r_2 * h * eta).exp() * (-2 * (1 - r_2) * h * eta).expm1().neg().sqrt()
noise_1, noise_2, noise_3 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigma_s_2), noise_sampler(sigma_s_2, sigmas[i + 1])
# Step 1
x_2 = sigma_s_1 / sigmas[i] * (-r_1 * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised
if inject_noise:
x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
# Step 2
x_3 = sigma_s_2 / sigmas[i] * (-r_2 * h * eta).exp() * x - alpha_s_2 * coeff_2 * denoised + (r_2 / r_1) * alpha_s_2 * (coeff_2 / (r_2 * h_eta) + 1) * (denoised_2 - denoised)
if inject_noise:
x_3 = x_3 + sigma_s_2 * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args)
# Step 3
x = (coeff_3 + 1) * x - coeff_3 * denoised + (1. / r_2) * (coeff_3 / h_eta + 1) * (denoised_3 - denoised)
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_3 * denoised + (1. / r_2) * alpha_t * (coeff_3 / h_eta + 1) * (denoised_3 - denoised)
if inject_noise:
x = x + sigmas[i + 1] * (noise_coeff_3 * noise_1 + noise_coeff_2 * noise_2 + noise_coeff_1 * noise_3) * s_noise
return x
@torch.no_grad()
def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=False, tau_func=None, s_noise=1.0, noise_sampler=None, predictor_order=3, corrector_order=4, use_pece=False, simple_order_2=False):
"""Stochastic Adams Solver with predictor-corrector method (NeurIPS 2023)."""
if len(sigmas) <= 1:
return x
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling")
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
lambdas = sigma_to_half_log_snr(sigmas, model_sampling=model_sampling)
if tau_func is None:
# Use default interval for stochastic sampling
start_sigma = model_sampling.percent_to_sigma(0.2)
end_sigma = model_sampling.percent_to_sigma(0.8)
tau_func = sa_solver.get_tau_interval_func(start_sigma, end_sigma, eta=1.0)
max_used_order = max(predictor_order, corrector_order)
x_pred = x # x: current state, x_pred: predicted next state
h = 0.0
tau_t = 0.0
noise = 0.0
pred_list = []
# Lower order near the end to improve stability
lower_order_to_end = sigmas[-1].item() == 0
for i in trange(len(sigmas) - 1, disable=disable):
# Evaluation
denoised = model(x_pred, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({"x": x_pred, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "denoised": denoised})
pred_list.append(denoised)
pred_list = pred_list[-max_used_order:]
predictor_order_used = min(predictor_order, len(pred_list))
if i == 0 or (sigmas[i + 1] == 0 and not use_pece):
corrector_order_used = 0
else:
corrector_order_used = min(corrector_order, len(pred_list))
if lower_order_to_end:
predictor_order_used = min(predictor_order_used, len(sigmas) - 2 - i)
corrector_order_used = min(corrector_order_used, len(sigmas) - 1 - i)
# Corrector
if corrector_order_used == 0:
# Update by the predicted state
x = x_pred
else:
curr_lambdas = lambdas[i - corrector_order_used + 1:i + 1]
b_coeffs = sa_solver.compute_stochastic_adams_b_coeffs(
sigmas[i],
curr_lambdas,
lambdas[i - 1],
lambdas[i],
tau_t,
simple_order_2,
is_corrector_step=True,
)
pred_mat = torch.stack(pred_list[-corrector_order_used:], dim=1) # (B, K, ...)
corr_res = torch.tensordot(pred_mat, b_coeffs, dims=([1], [0])) # (B, ...)
x = sigmas[i] / sigmas[i - 1] * (-(tau_t ** 2) * h).exp() * x + corr_res
if tau_t > 0 and s_noise > 0:
# The noise from the previous predictor step
x = x + noise
if use_pece:
# Evaluate the corrected state
denoised = model(x, sigmas[i] * s_in, **extra_args)
pred_list[-1] = denoised
# Predictor
if sigmas[i + 1] == 0:
# Denoising step
x = denoised
else:
tau_t = tau_func(sigmas[i + 1])
curr_lambdas = lambdas[i - predictor_order_used + 1:i + 1]
b_coeffs = sa_solver.compute_stochastic_adams_b_coeffs(
sigmas[i + 1],
curr_lambdas,
lambdas[i],
lambdas[i + 1],
tau_t,
simple_order_2,
is_corrector_step=False,
)
pred_mat = torch.stack(pred_list[-predictor_order_used:], dim=1) # (B, K, ...)
pred_res = torch.tensordot(pred_mat, b_coeffs, dims=([1], [0])) # (B, ...)
h = lambdas[i + 1] - lambdas[i]
x_pred = sigmas[i + 1] / sigmas[i] * (-(tau_t ** 2) * h).exp() * x + pred_res
if tau_t > 0 and s_noise > 0:
noise = noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * tau_t ** 2 * h).expm1().neg().sqrt() * s_noise
x_pred = x_pred + noise
return x
@torch.no_grad()
def sample_sa_solver_pece(model, x, sigmas, extra_args=None, callback=None, disable=False, tau_func=None, s_noise=1.0, noise_sampler=None, predictor_order=3, corrector_order=4, simple_order_2=False):
"""Stochastic Adams Solver with PECE (PredictEvaluateCorrectEvaluate) mode (NeurIPS 2023)."""
return sample_sa_solver(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, tau_func=tau_func, s_noise=s_noise, noise_sampler=noise_sampler, predictor_order=predictor_order, corrector_order=corrector_order, use_pece=True, simple_order_2=simple_order_2)

View File

@@ -80,15 +80,13 @@ class DoubleStreamBlock(nn.Module):
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
# prepare image for attention
img_modulated = self.img_norm1(img)
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
img_modulated = torch.addcmul(img_mod1.shift, 1 + img_mod1.scale, self.img_norm1(img))
img_qkv = self.img_attn.qkv(img_modulated)
img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
# prepare txt for attention
txt_modulated = self.txt_norm1(txt)
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
txt_modulated = torch.addcmul(txt_mod1.shift, 1 + txt_mod1.scale, self.txt_norm1(txt))
txt_qkv = self.txt_attn.qkv(txt_modulated)
txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
@@ -102,12 +100,12 @@ class DoubleStreamBlock(nn.Module):
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
# calculate the img bloks
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
img.addcmul_(img_mod1.gate, self.img_attn.proj(img_attn))
img.addcmul_(img_mod2.gate, self.img_mlp(torch.addcmul(img_mod2.shift, 1 + img_mod2.scale, self.img_norm2(img))))
# calculate the txt bloks
txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
txt.addcmul_(txt_mod1.gate, self.txt_attn.proj(txt_attn))
txt.addcmul_(txt_mod2.gate, self.txt_mlp(torch.addcmul(txt_mod2.shift, 1 + txt_mod2.scale, self.txt_norm2(txt))))
if txt.dtype == torch.float16:
txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
@@ -152,7 +150,7 @@ class SingleStreamBlock(nn.Module):
def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None) -> Tensor:
mod = vec
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
x_mod = torch.addcmul(mod.shift, 1 + mod.scale, self.pre_norm(x))
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
@@ -162,7 +160,7 @@ class SingleStreamBlock(nn.Module):
attn = attention(q, k, v, pe=pe, mask=attn_mask)
# compute activation in mlp stream, cat again and run second linear layer
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
x += mod.gate * output
x.addcmul_(mod.gate, output)
if x.dtype == torch.float16:
x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
return x
@@ -178,6 +176,6 @@ class LastLayer(nn.Module):
shift, scale = vec
shift = shift.squeeze(1)
scale = scale.squeeze(1)
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
x = torch.addcmul(shift[:, None, :], 1 + scale[:, None, :], self.norm_final(x))
x = self.linear(x)
return x

View File

@@ -254,13 +254,12 @@ class Chroma(nn.Module):
def forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
bs, c, h, w = x.shape
patch_size = 2
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=self.patch_size, pw=self.patch_size)
h_len = ((h + (patch_size // 2)) // patch_size)
w_len = ((w + (patch_size // 2)) // patch_size)
h_len = ((h + (self.patch_size // 2)) // self.patch_size)
w_len = ((w + (self.patch_size // 2)) // self.patch_size)
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
@@ -268,4 +267,4 @@ class Chroma(nn.Module):
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=self.patch_size, pw=self.patch_size)[:,:,:h,:w]

View File

@@ -26,16 +26,6 @@ from torch import nn
from comfy.ldm.modules.attention import optimized_attention
def apply_rotary_pos_emb(
t: torch.Tensor,
freqs: torch.Tensor,
) -> torch.Tensor:
t_ = t.reshape(*t.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2).float()
t_out = freqs[..., 0] * t_[..., 0] + freqs[..., 1] * t_[..., 1]
t_out = t_out.movedim(-1, -2).reshape(*t.shape).type_as(t)
return t_out
def get_normalization(name: str, channels: int, weight_args={}, operations=None):
if name == "I":
return nn.Identity()

View File

@@ -66,15 +66,16 @@ class VideoRopePosition3DEmb(VideoPositionEmb):
h_extrapolation_ratio: float = 1.0,
w_extrapolation_ratio: float = 1.0,
t_extrapolation_ratio: float = 1.0,
enable_fps_modulation: bool = True,
device=None,
**kwargs, # used for compatibility with other positional embeddings; unused in this class
):
del kwargs
super().__init__()
self.register_buffer("seq", torch.arange(max(len_h, len_w, len_t), dtype=torch.float, device=device))
self.base_fps = base_fps
self.max_h = len_h
self.max_w = len_w
self.enable_fps_modulation = enable_fps_modulation
dim = head_dim
dim_h = dim // 6 * 2
@@ -132,21 +133,19 @@ class VideoRopePosition3DEmb(VideoPositionEmb):
temporal_freqs = 1.0 / (t_theta**self.dim_temporal_range.to(device=device))
B, T, H, W, _ = B_T_H_W_C
seq = torch.arange(max(H, W, T), dtype=torch.float, device=device)
uniform_fps = (fps is None) or isinstance(fps, (int, float)) or (fps.min() == fps.max())
assert (
uniform_fps or B == 1 or T == 1
), "For video batch, batch size should be 1 for non-uniform fps. For image batch, T should be 1"
assert (
H <= self.max_h and W <= self.max_w
), f"Input dimensions (H={H}, W={W}) exceed the maximum dimensions (max_h={self.max_h}, max_w={self.max_w})"
half_emb_h = torch.outer(self.seq[:H].to(device=device), h_spatial_freqs)
half_emb_w = torch.outer(self.seq[:W].to(device=device), w_spatial_freqs)
half_emb_h = torch.outer(seq[:H].to(device=device), h_spatial_freqs)
half_emb_w = torch.outer(seq[:W].to(device=device), w_spatial_freqs)
# apply sequence scaling in temporal dimension
if fps is None: # image case
half_emb_t = torch.outer(self.seq[:T].to(device=device), temporal_freqs)
if fps is None or self.enable_fps_modulation is False: # image case
half_emb_t = torch.outer(seq[:T].to(device=device), temporal_freqs)
else:
half_emb_t = torch.outer(self.seq[:T].to(device=device) / fps * self.base_fps, temporal_freqs)
half_emb_t = torch.outer(seq[:T].to(device=device) / fps * self.base_fps, temporal_freqs)
half_emb_h = torch.stack([torch.cos(half_emb_h), -torch.sin(half_emb_h), torch.sin(half_emb_h), torch.cos(half_emb_h)], dim=-1)
half_emb_w = torch.stack([torch.cos(half_emb_w), -torch.sin(half_emb_w), torch.sin(half_emb_w), torch.cos(half_emb_w)], dim=-1)

View File

@@ -0,0 +1,864 @@
# original code from: https://github.com/nvidia-cosmos/cosmos-predict2
import torch
from torch import nn
from einops import rearrange
from einops.layers.torch import Rearrange
import logging
from typing import Callable, Optional, Tuple
import math
from .position_embedding import VideoRopePosition3DEmb, LearnablePosEmbAxis
from torchvision import transforms
from comfy.ldm.modules.attention import optimized_attention
def apply_rotary_pos_emb(
t: torch.Tensor,
freqs: torch.Tensor,
) -> torch.Tensor:
t_ = t.reshape(*t.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2).float()
t_out = freqs[..., 0] * t_[..., 0] + freqs[..., 1] * t_[..., 1]
t_out = t_out.movedim(-1, -2).reshape(*t.shape).type_as(t)
return t_out
# ---------------------- Feed Forward Network -----------------------
class GPT2FeedForward(nn.Module):
def __init__(self, d_model: int, d_ff: int, device=None, dtype=None, operations=None) -> None:
super().__init__()
self.activation = nn.GELU()
self.layer1 = operations.Linear(d_model, d_ff, bias=False, device=device, dtype=dtype)
self.layer2 = operations.Linear(d_ff, d_model, bias=False, device=device, dtype=dtype)
self._layer_id = None
self._dim = d_model
self._hidden_dim = d_ff
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.layer1(x)
x = self.activation(x)
x = self.layer2(x)
return x
def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor) -> torch.Tensor:
"""Computes multi-head attention using PyTorch's native implementation.
This function provides a PyTorch backend alternative to Transformer Engine's attention operation.
It rearranges the input tensors to match PyTorch's expected format, computes scaled dot-product
attention, and rearranges the output back to the original format.
The input tensor names use the following dimension conventions:
- B: batch size
- S: sequence length
- H: number of attention heads
- D: head dimension
Args:
q_B_S_H_D: Query tensor with shape (batch, seq_len, n_heads, head_dim)
k_B_S_H_D: Key tensor with shape (batch, seq_len, n_heads, head_dim)
v_B_S_H_D: Value tensor with shape (batch, seq_len, n_heads, head_dim)
Returns:
Attention output tensor with shape (batch, seq_len, n_heads * head_dim)
"""
in_q_shape = q_B_S_H_D.shape
in_k_shape = k_B_S_H_D.shape
q_B_H_S_D = rearrange(q_B_S_H_D, "b ... h k -> b h ... k").view(in_q_shape[0], in_q_shape[-2], -1, in_q_shape[-1])
k_B_H_S_D = rearrange(k_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
v_B_H_S_D = rearrange(v_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
return optimized_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, in_q_shape[-2], skip_reshape=True)
class Attention(nn.Module):
"""
A flexible attention module supporting both self-attention and cross-attention mechanisms.
This module implements a multi-head attention layer that can operate in either self-attention
or cross-attention mode. The mode is determined by whether a context dimension is provided.
The implementation uses scaled dot-product attention and supports optional bias terms and
dropout regularization.
Args:
query_dim (int): The dimensionality of the query vectors.
context_dim (int, optional): The dimensionality of the context (key/value) vectors.
If None, the module operates in self-attention mode using query_dim. Default: None
n_heads (int, optional): Number of attention heads for multi-head attention. Default: 8
head_dim (int, optional): The dimension of each attention head. Default: 64
dropout (float, optional): Dropout probability applied to the output. Default: 0.0
qkv_format (str, optional): Format specification for QKV tensors. Default: "bshd"
backend (str, optional): Backend to use for the attention operation. Default: "transformer_engine"
Examples:
>>> # Self-attention with 512 dimensions and 8 heads
>>> self_attn = Attention(query_dim=512)
>>> x = torch.randn(32, 16, 512) # (batch_size, seq_len, dim)
>>> out = self_attn(x) # (32, 16, 512)
>>> # Cross-attention
>>> cross_attn = Attention(query_dim=512, context_dim=256)
>>> query = torch.randn(32, 16, 512)
>>> context = torch.randn(32, 8, 256)
>>> out = cross_attn(query, context) # (32, 16, 512)
"""
def __init__(
self,
query_dim: int,
context_dim: Optional[int] = None,
n_heads: int = 8,
head_dim: int = 64,
dropout: float = 0.0,
device=None,
dtype=None,
operations=None,
) -> None:
super().__init__()
logging.debug(
f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
f"{n_heads} heads with a dimension of {head_dim}."
)
self.is_selfattn = context_dim is None # self attention
context_dim = query_dim if context_dim is None else context_dim
inner_dim = head_dim * n_heads
self.n_heads = n_heads
self.head_dim = head_dim
self.query_dim = query_dim
self.context_dim = context_dim
self.q_proj = operations.Linear(query_dim, inner_dim, bias=False, device=device, dtype=dtype)
self.q_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype)
self.k_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype)
self.k_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype)
self.v_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype)
self.v_norm = nn.Identity()
self.output_proj = operations.Linear(inner_dim, query_dim, bias=False, device=device, dtype=dtype)
self.output_dropout = nn.Dropout(dropout) if dropout > 1e-4 else nn.Identity()
self.attn_op = torch_attention_op
self._query_dim = query_dim
self._context_dim = context_dim
self._inner_dim = inner_dim
def compute_qkv(
self,
x: torch.Tensor,
context: Optional[torch.Tensor] = None,
rope_emb: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
q = self.q_proj(x)
context = x if context is None else context
k = self.k_proj(context)
v = self.v_proj(context)
q, k, v = map(
lambda t: rearrange(t, "b ... (h d) -> b ... h d", h=self.n_heads, d=self.head_dim),
(q, k, v),
)
def apply_norm_and_rotary_pos_emb(
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, rope_emb: Optional[torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
q = self.q_norm(q)
k = self.k_norm(k)
v = self.v_norm(v)
if self.is_selfattn and rope_emb is not None: # only apply to self-attention!
q = apply_rotary_pos_emb(q, rope_emb)
k = apply_rotary_pos_emb(k, rope_emb)
return q, k, v
q, k, v = apply_norm_and_rotary_pos_emb(q, k, v, rope_emb)
return q, k, v
def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
result = self.attn_op(q, k, v) # [B, S, H, D]
return self.output_dropout(self.output_proj(result))
def forward(
self,
x: torch.Tensor,
context: Optional[torch.Tensor] = None,
rope_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Args:
x (Tensor): The query tensor of shape [B, Mq, K]
context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
"""
q, k, v = self.compute_qkv(x, context, rope_emb=rope_emb)
return self.compute_attention(q, k, v)
class Timesteps(nn.Module):
def __init__(self, num_channels: int):
super().__init__()
self.num_channels = num_channels
def forward(self, timesteps_B_T: torch.Tensor) -> torch.Tensor:
assert timesteps_B_T.ndim == 2, f"Expected 2D input, got {timesteps_B_T.ndim}"
timesteps = timesteps_B_T.flatten().float()
half_dim = self.num_channels // 2
exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
exponent = exponent / (half_dim - 0.0)
emb = torch.exp(exponent)
emb = timesteps[:, None].float() * emb[None, :]
sin_emb = torch.sin(emb)
cos_emb = torch.cos(emb)
emb = torch.cat([cos_emb, sin_emb], dim=-1)
return rearrange(emb, "(b t) d -> b t d", b=timesteps_B_T.shape[0], t=timesteps_B_T.shape[1])
class TimestepEmbedding(nn.Module):
def __init__(self, in_features: int, out_features: int, use_adaln_lora: bool = False, device=None, dtype=None, operations=None):
super().__init__()
logging.debug(
f"Using AdaLN LoRA Flag: {use_adaln_lora}. We enable bias if no AdaLN LoRA for backward compatibility."
)
self.in_dim = in_features
self.out_dim = out_features
self.linear_1 = operations.Linear(in_features, out_features, bias=not use_adaln_lora, device=device, dtype=dtype)
self.activation = nn.SiLU()
self.use_adaln_lora = use_adaln_lora
if use_adaln_lora:
self.linear_2 = operations.Linear(out_features, 3 * out_features, bias=False, device=device, dtype=dtype)
else:
self.linear_2 = operations.Linear(out_features, out_features, bias=False, device=device, dtype=dtype)
def forward(self, sample: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
emb = self.linear_1(sample)
emb = self.activation(emb)
emb = self.linear_2(emb)
if self.use_adaln_lora:
adaln_lora_B_T_3D = emb
emb_B_T_D = sample
else:
adaln_lora_B_T_3D = None
emb_B_T_D = emb
return emb_B_T_D, adaln_lora_B_T_3D
class PatchEmbed(nn.Module):
"""
PatchEmbed is a module for embedding patches from an input tensor by applying either 3D or 2D convolutional layers,
depending on the . This module can process inputs with temporal (video) and spatial (image) dimensions,
making it suitable for video and image processing tasks. It supports dividing the input into patches
and embedding each patch into a vector of size `out_channels`.
Parameters:
- spatial_patch_size (int): The size of each spatial patch.
- temporal_patch_size (int): The size of each temporal patch.
- in_channels (int): Number of input channels. Default: 3.
- out_channels (int): The dimension of the embedding vector for each patch. Default: 768.
- bias (bool): If True, adds a learnable bias to the output of the convolutional layers. Default: True.
"""
def __init__(
self,
spatial_patch_size: int,
temporal_patch_size: int,
in_channels: int = 3,
out_channels: int = 768,
device=None, dtype=None, operations=None
):
super().__init__()
self.spatial_patch_size = spatial_patch_size
self.temporal_patch_size = temporal_patch_size
self.proj = nn.Sequential(
Rearrange(
"b c (t r) (h m) (w n) -> b t h w (c r m n)",
r=temporal_patch_size,
m=spatial_patch_size,
n=spatial_patch_size,
),
operations.Linear(
in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size, out_channels, bias=False, device=device, dtype=dtype
),
)
self.dim = in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass of the PatchEmbed module.
Parameters:
- x (torch.Tensor): The input tensor of shape (B, C, T, H, W) where
B is the batch size,
C is the number of channels,
T is the temporal dimension,
H is the height, and
W is the width of the input.
Returns:
- torch.Tensor: The embedded patches as a tensor, with shape b t h w c.
"""
assert x.dim() == 5
_, _, T, H, W = x.shape
assert (
H % self.spatial_patch_size == 0 and W % self.spatial_patch_size == 0
), f"H,W {(H, W)} should be divisible by spatial_patch_size {self.spatial_patch_size}"
assert T % self.temporal_patch_size == 0
x = self.proj(x)
return x
class FinalLayer(nn.Module):
"""
The final layer of video DiT.
"""
def __init__(
self,
hidden_size: int,
spatial_patch_size: int,
temporal_patch_size: int,
out_channels: int,
use_adaln_lora: bool = False,
adaln_lora_dim: int = 256,
device=None, dtype=None, operations=None
):
super().__init__()
self.layer_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = operations.Linear(
hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False, device=device, dtype=dtype
)
self.hidden_size = hidden_size
self.n_adaln_chunks = 2
self.use_adaln_lora = use_adaln_lora
self.adaln_lora_dim = adaln_lora_dim
if use_adaln_lora:
self.adaln_modulation = nn.Sequential(
nn.SiLU(),
operations.Linear(hidden_size, adaln_lora_dim, bias=False, device=device, dtype=dtype),
operations.Linear(adaln_lora_dim, self.n_adaln_chunks * hidden_size, bias=False, device=device, dtype=dtype),
)
else:
self.adaln_modulation = nn.Sequential(
nn.SiLU(), operations.Linear(hidden_size, self.n_adaln_chunks * hidden_size, bias=False, device=device, dtype=dtype)
)
def forward(
self,
x_B_T_H_W_D: torch.Tensor,
emb_B_T_D: torch.Tensor,
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
):
if self.use_adaln_lora:
assert adaln_lora_B_T_3D is not None
shift_B_T_D, scale_B_T_D = (
self.adaln_modulation(emb_B_T_D) + adaln_lora_B_T_3D[:, :, : 2 * self.hidden_size]
).chunk(2, dim=-1)
else:
shift_B_T_D, scale_B_T_D = self.adaln_modulation(emb_B_T_D).chunk(2, dim=-1)
shift_B_T_1_1_D, scale_B_T_1_1_D = rearrange(shift_B_T_D, "b t d -> b t 1 1 d"), rearrange(
scale_B_T_D, "b t d -> b t 1 1 d"
)
def _fn(
_x_B_T_H_W_D: torch.Tensor,
_norm_layer: nn.Module,
_scale_B_T_1_1_D: torch.Tensor,
_shift_B_T_1_1_D: torch.Tensor,
) -> torch.Tensor:
return _norm_layer(_x_B_T_H_W_D) * (1 + _scale_B_T_1_1_D) + _shift_B_T_1_1_D
x_B_T_H_W_D = _fn(x_B_T_H_W_D, self.layer_norm, scale_B_T_1_1_D, shift_B_T_1_1_D)
x_B_T_H_W_O = self.linear(x_B_T_H_W_D)
return x_B_T_H_W_O
class Block(nn.Module):
"""
A transformer block that combines self-attention, cross-attention and MLP layers with AdaLN modulation.
Each component (self-attention, cross-attention, MLP) has its own layer normalization and AdaLN modulation.
Parameters:
x_dim (int): Dimension of input features
context_dim (int): Dimension of context features for cross-attention
num_heads (int): Number of attention heads
mlp_ratio (float): Multiplier for MLP hidden dimension. Default: 4.0
use_adaln_lora (bool): Whether to use AdaLN-LoRA modulation. Default: False
adaln_lora_dim (int): Hidden dimension for AdaLN-LoRA layers. Default: 256
The block applies the following sequence:
1. Self-attention with AdaLN modulation
2. Cross-attention with AdaLN modulation
3. MLP with AdaLN modulation
Each component uses skip connections and layer normalization.
"""
def __init__(
self,
x_dim: int,
context_dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
use_adaln_lora: bool = False,
adaln_lora_dim: int = 256,
device=None,
dtype=None,
operations=None,
):
super().__init__()
self.x_dim = x_dim
self.layer_norm_self_attn = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
self.self_attn = Attention(x_dim, None, num_heads, x_dim // num_heads, device=device, dtype=dtype, operations=operations)
self.layer_norm_cross_attn = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
self.cross_attn = Attention(
x_dim, context_dim, num_heads, x_dim // num_heads, device=device, dtype=dtype, operations=operations
)
self.layer_norm_mlp = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
self.mlp = GPT2FeedForward(x_dim, int(x_dim * mlp_ratio), device=device, dtype=dtype, operations=operations)
self.use_adaln_lora = use_adaln_lora
if self.use_adaln_lora:
self.adaln_modulation_self_attn = nn.Sequential(
nn.SiLU(),
operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype),
operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype),
)
self.adaln_modulation_cross_attn = nn.Sequential(
nn.SiLU(),
operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype),
operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype),
)
self.adaln_modulation_mlp = nn.Sequential(
nn.SiLU(),
operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype),
operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype),
)
else:
self.adaln_modulation_self_attn = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype))
self.adaln_modulation_cross_attn = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype))
self.adaln_modulation_mlp = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype))
def forward(
self,
x_B_T_H_W_D: torch.Tensor,
emb_B_T_D: torch.Tensor,
crossattn_emb: torch.Tensor,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if extra_per_block_pos_emb is not None:
x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb
if self.use_adaln_lora:
shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = (
self.adaln_modulation_self_attn(emb_B_T_D) + adaln_lora_B_T_3D
).chunk(3, dim=-1)
shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = (
self.adaln_modulation_cross_attn(emb_B_T_D) + adaln_lora_B_T_3D
).chunk(3, dim=-1)
shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = (
self.adaln_modulation_mlp(emb_B_T_D) + adaln_lora_B_T_3D
).chunk(3, dim=-1)
else:
shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = self.adaln_modulation_self_attn(
emb_B_T_D
).chunk(3, dim=-1)
shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = self.adaln_modulation_cross_attn(
emb_B_T_D
).chunk(3, dim=-1)
shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = self.adaln_modulation_mlp(emb_B_T_D).chunk(3, dim=-1)
# Reshape tensors from (B, T, D) to (B, T, 1, 1, D) for broadcasting
shift_self_attn_B_T_1_1_D = rearrange(shift_self_attn_B_T_D, "b t d -> b t 1 1 d")
scale_self_attn_B_T_1_1_D = rearrange(scale_self_attn_B_T_D, "b t d -> b t 1 1 d")
gate_self_attn_B_T_1_1_D = rearrange(gate_self_attn_B_T_D, "b t d -> b t 1 1 d")
shift_cross_attn_B_T_1_1_D = rearrange(shift_cross_attn_B_T_D, "b t d -> b t 1 1 d")
scale_cross_attn_B_T_1_1_D = rearrange(scale_cross_attn_B_T_D, "b t d -> b t 1 1 d")
gate_cross_attn_B_T_1_1_D = rearrange(gate_cross_attn_B_T_D, "b t d -> b t 1 1 d")
shift_mlp_B_T_1_1_D = rearrange(shift_mlp_B_T_D, "b t d -> b t 1 1 d")
scale_mlp_B_T_1_1_D = rearrange(scale_mlp_B_T_D, "b t d -> b t 1 1 d")
gate_mlp_B_T_1_1_D = rearrange(gate_mlp_B_T_D, "b t d -> b t 1 1 d")
B, T, H, W, D = x_B_T_H_W_D.shape
def _fn(_x_B_T_H_W_D, _norm_layer, _scale_B_T_1_1_D, _shift_B_T_1_1_D):
return _norm_layer(_x_B_T_H_W_D) * (1 + _scale_B_T_1_1_D) + _shift_B_T_1_1_D
normalized_x_B_T_H_W_D = _fn(
x_B_T_H_W_D,
self.layer_norm_self_attn,
scale_self_attn_B_T_1_1_D,
shift_self_attn_B_T_1_1_D,
)
result_B_T_H_W_D = rearrange(
self.self_attn(
# normalized_x_B_T_HW_D,
rearrange(normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
None,
rope_emb=rope_emb_L_1_1_D,
),
"b (t h w) d -> b t h w d",
t=T,
h=H,
w=W,
)
x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D * result_B_T_H_W_D
def _x_fn(
_x_B_T_H_W_D: torch.Tensor,
layer_norm_cross_attn: Callable,
_scale_cross_attn_B_T_1_1_D: torch.Tensor,
_shift_cross_attn_B_T_1_1_D: torch.Tensor,
) -> torch.Tensor:
_normalized_x_B_T_H_W_D = _fn(
_x_B_T_H_W_D, layer_norm_cross_attn, _scale_cross_attn_B_T_1_1_D, _shift_cross_attn_B_T_1_1_D
)
_result_B_T_H_W_D = rearrange(
self.cross_attn(
rearrange(_normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
crossattn_emb,
rope_emb=rope_emb_L_1_1_D,
),
"b (t h w) d -> b t h w d",
t=T,
h=H,
w=W,
)
return _result_B_T_H_W_D
result_B_T_H_W_D = _x_fn(
x_B_T_H_W_D,
self.layer_norm_cross_attn,
scale_cross_attn_B_T_1_1_D,
shift_cross_attn_B_T_1_1_D,
)
x_B_T_H_W_D = result_B_T_H_W_D * gate_cross_attn_B_T_1_1_D + x_B_T_H_W_D
normalized_x_B_T_H_W_D = _fn(
x_B_T_H_W_D,
self.layer_norm_mlp,
scale_mlp_B_T_1_1_D,
shift_mlp_B_T_1_1_D,
)
result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D)
x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D * result_B_T_H_W_D
return x_B_T_H_W_D
class MiniTrainDIT(nn.Module):
"""
A clean impl of DIT that can load and reproduce the training results of the original DIT model in~(cosmos 1)
A general implementation of adaln-modulated VIT-like~(DiT) transformer for video processing.
Args:
max_img_h (int): Maximum height of the input images.
max_img_w (int): Maximum width of the input images.
max_frames (int): Maximum number of frames in the video sequence.
in_channels (int): Number of input channels (e.g., RGB channels for color images).
out_channels (int): Number of output channels.
patch_spatial (tuple): Spatial resolution of patches for input processing.
patch_temporal (int): Temporal resolution of patches for input processing.
concat_padding_mask (bool): If True, includes a mask channel in the input to handle padding.
model_channels (int): Base number of channels used throughout the model.
num_blocks (int): Number of transformer blocks.
num_heads (int): Number of heads in the multi-head attention layers.
mlp_ratio (float): Expansion ratio for MLP blocks.
crossattn_emb_channels (int): Number of embedding channels for cross-attention.
pos_emb_cls (str): Type of positional embeddings.
pos_emb_learnable (bool): Whether positional embeddings are learnable.
pos_emb_interpolation (str): Method for interpolating positional embeddings.
min_fps (int): Minimum frames per second.
max_fps (int): Maximum frames per second.
use_adaln_lora (bool): Whether to use AdaLN-LoRA.
adaln_lora_dim (int): Dimension for AdaLN-LoRA.
rope_h_extrapolation_ratio (float): Height extrapolation ratio for RoPE.
rope_w_extrapolation_ratio (float): Width extrapolation ratio for RoPE.
rope_t_extrapolation_ratio (float): Temporal extrapolation ratio for RoPE.
extra_per_block_abs_pos_emb (bool): Whether to use extra per-block absolute positional embeddings.
extra_h_extrapolation_ratio (float): Height extrapolation ratio for extra embeddings.
extra_w_extrapolation_ratio (float): Width extrapolation ratio for extra embeddings.
extra_t_extrapolation_ratio (float): Temporal extrapolation ratio for extra embeddings.
"""
def __init__(
self,
max_img_h: int,
max_img_w: int,
max_frames: int,
in_channels: int,
out_channels: int,
patch_spatial: int, # tuple,
patch_temporal: int,
concat_padding_mask: bool = True,
# attention settings
model_channels: int = 768,
num_blocks: int = 10,
num_heads: int = 16,
mlp_ratio: float = 4.0,
# cross attention settings
crossattn_emb_channels: int = 1024,
# positional embedding settings
pos_emb_cls: str = "sincos",
pos_emb_learnable: bool = False,
pos_emb_interpolation: str = "crop",
min_fps: int = 1,
max_fps: int = 30,
use_adaln_lora: bool = False,
adaln_lora_dim: int = 256,
rope_h_extrapolation_ratio: float = 1.0,
rope_w_extrapolation_ratio: float = 1.0,
rope_t_extrapolation_ratio: float = 1.0,
extra_per_block_abs_pos_emb: bool = False,
extra_h_extrapolation_ratio: float = 1.0,
extra_w_extrapolation_ratio: float = 1.0,
extra_t_extrapolation_ratio: float = 1.0,
rope_enable_fps_modulation: bool = True,
image_model=None,
device=None,
dtype=None,
operations=None,
) -> None:
super().__init__()
self.dtype = dtype
self.max_img_h = max_img_h
self.max_img_w = max_img_w
self.max_frames = max_frames
self.in_channels = in_channels
self.out_channels = out_channels
self.patch_spatial = patch_spatial
self.patch_temporal = patch_temporal
self.num_heads = num_heads
self.num_blocks = num_blocks
self.model_channels = model_channels
self.concat_padding_mask = concat_padding_mask
# positional embedding settings
self.pos_emb_cls = pos_emb_cls
self.pos_emb_learnable = pos_emb_learnable
self.pos_emb_interpolation = pos_emb_interpolation
self.min_fps = min_fps
self.max_fps = max_fps
self.rope_h_extrapolation_ratio = rope_h_extrapolation_ratio
self.rope_w_extrapolation_ratio = rope_w_extrapolation_ratio
self.rope_t_extrapolation_ratio = rope_t_extrapolation_ratio
self.extra_per_block_abs_pos_emb = extra_per_block_abs_pos_emb
self.extra_h_extrapolation_ratio = extra_h_extrapolation_ratio
self.extra_w_extrapolation_ratio = extra_w_extrapolation_ratio
self.extra_t_extrapolation_ratio = extra_t_extrapolation_ratio
self.rope_enable_fps_modulation = rope_enable_fps_modulation
self.build_pos_embed(device=device, dtype=dtype)
self.use_adaln_lora = use_adaln_lora
self.adaln_lora_dim = adaln_lora_dim
self.t_embedder = nn.Sequential(
Timesteps(model_channels),
TimestepEmbedding(model_channels, model_channels, use_adaln_lora=use_adaln_lora, device=device, dtype=dtype, operations=operations,),
)
in_channels = in_channels + 1 if concat_padding_mask else in_channels
self.x_embedder = PatchEmbed(
spatial_patch_size=patch_spatial,
temporal_patch_size=patch_temporal,
in_channels=in_channels,
out_channels=model_channels,
device=device, dtype=dtype, operations=operations,
)
self.blocks = nn.ModuleList(
[
Block(
x_dim=model_channels,
context_dim=crossattn_emb_channels,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
use_adaln_lora=use_adaln_lora,
adaln_lora_dim=adaln_lora_dim,
device=device, dtype=dtype, operations=operations,
)
for _ in range(num_blocks)
]
)
self.final_layer = FinalLayer(
hidden_size=self.model_channels,
spatial_patch_size=self.patch_spatial,
temporal_patch_size=self.patch_temporal,
out_channels=self.out_channels,
use_adaln_lora=self.use_adaln_lora,
adaln_lora_dim=self.adaln_lora_dim,
device=device, dtype=dtype, operations=operations,
)
self.t_embedding_norm = operations.RMSNorm(model_channels, eps=1e-6, device=device, dtype=dtype)
def build_pos_embed(self, device=None, dtype=None) -> None:
if self.pos_emb_cls == "rope3d":
cls_type = VideoRopePosition3DEmb
else:
raise ValueError(f"Unknown pos_emb_cls {self.pos_emb_cls}")
logging.debug(f"Building positional embedding with {self.pos_emb_cls} class, impl {cls_type}")
kwargs = dict(
model_channels=self.model_channels,
len_h=self.max_img_h // self.patch_spatial,
len_w=self.max_img_w // self.patch_spatial,
len_t=self.max_frames // self.patch_temporal,
max_fps=self.max_fps,
min_fps=self.min_fps,
is_learnable=self.pos_emb_learnable,
interpolation=self.pos_emb_interpolation,
head_dim=self.model_channels // self.num_heads,
h_extrapolation_ratio=self.rope_h_extrapolation_ratio,
w_extrapolation_ratio=self.rope_w_extrapolation_ratio,
t_extrapolation_ratio=self.rope_t_extrapolation_ratio,
enable_fps_modulation=self.rope_enable_fps_modulation,
device=device,
)
self.pos_embedder = cls_type(
**kwargs, # type: ignore
)
if self.extra_per_block_abs_pos_emb:
kwargs["h_extrapolation_ratio"] = self.extra_h_extrapolation_ratio
kwargs["w_extrapolation_ratio"] = self.extra_w_extrapolation_ratio
kwargs["t_extrapolation_ratio"] = self.extra_t_extrapolation_ratio
kwargs["device"] = device
kwargs["dtype"] = dtype
self.extra_pos_embedder = LearnablePosEmbAxis(
**kwargs, # type: ignore
)
def prepare_embedded_sequence(
self,
x_B_C_T_H_W: torch.Tensor,
fps: Optional[torch.Tensor] = None,
padding_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
Prepares an embedded sequence tensor by applying positional embeddings and handling padding masks.
Args:
x_B_C_T_H_W (torch.Tensor): video
fps (Optional[torch.Tensor]): Frames per second tensor to be used for positional embedding when required.
If None, a default value (`self.base_fps`) will be used.
padding_mask (Optional[torch.Tensor]): current it is not used
Returns:
Tuple[torch.Tensor, Optional[torch.Tensor]]:
- A tensor of shape (B, T, H, W, D) with the embedded sequence.
- An optional positional embedding tensor, returned only if the positional embedding class
(`self.pos_emb_cls`) includes 'rope'. Otherwise, None.
Notes:
- If `self.concat_padding_mask` is True, a padding mask channel is concatenated to the input tensor.
- The method of applying positional embeddings depends on the value of `self.pos_emb_cls`.
- If 'rope' is in `self.pos_emb_cls` (case insensitive), the positional embeddings are generated using
the `self.pos_embedder` with the shape [T, H, W].
- If "fps_aware" is in `self.pos_emb_cls`, the positional embeddings are generated using the
`self.pos_embedder` with the fps tensor.
- Otherwise, the positional embeddings are generated without considering fps.
"""
if self.concat_padding_mask:
if padding_mask is None:
padding_mask = torch.zeros(x_B_C_T_H_W.shape[0], 1, x_B_C_T_H_W.shape[3], x_B_C_T_H_W.shape[4], dtype=x_B_C_T_H_W.dtype, device=x_B_C_T_H_W.device)
else:
padding_mask = transforms.functional.resize(
padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
)
x_B_C_T_H_W = torch.cat(
[x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1
)
x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W)
if self.extra_per_block_abs_pos_emb:
extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device, dtype=x_B_C_T_H_W.dtype)
else:
extra_pos_emb = None
if "rope" in self.pos_emb_cls.lower():
return x_B_T_H_W_D, self.pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device), extra_pos_emb
x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, device=x_B_C_T_H_W.device) # [B, T, H, W, D]
return x_B_T_H_W_D, None, extra_pos_emb
def unpatchify(self, x_B_T_H_W_M: torch.Tensor) -> torch.Tensor:
x_B_C_Tt_Hp_Wp = rearrange(
x_B_T_H_W_M,
"B T H W (p1 p2 t C) -> B C (T t) (H p1) (W p2)",
p1=self.patch_spatial,
p2=self.patch_spatial,
t=self.patch_temporal,
)
return x_B_C_Tt_Hp_Wp
def forward(
self,
x: torch.Tensor,
timesteps: torch.Tensor,
context: torch.Tensor,
fps: Optional[torch.Tensor] = None,
padding_mask: Optional[torch.Tensor] = None,
**kwargs,
):
x_B_C_T_H_W = x
timesteps_B_T = timesteps
crossattn_emb = context
"""
Args:
x: (B, C, T, H, W) tensor of spatial-temp inputs
timesteps: (B, ) tensor of timesteps
crossattn_emb: (B, N, D) tensor of cross-attention embeddings
"""
x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence(
x_B_C_T_H_W,
fps=fps,
padding_mask=padding_mask,
)
if timesteps_B_T.ndim == 1:
timesteps_B_T = timesteps_B_T.unsqueeze(1)
t_embedding_B_T_D, adaln_lora_B_T_3D = self.t_embedder[1](self.t_embedder[0](timesteps_B_T).to(x_B_T_H_W_D.dtype))
t_embedding_B_T_D = self.t_embedding_norm(t_embedding_B_T_D)
# for logging purpose
affline_scale_log_info = {}
affline_scale_log_info["t_embedding_B_T_D"] = t_embedding_B_T_D.detach()
self.affline_scale_log_info = affline_scale_log_info
self.affline_emb = t_embedding_B_T_D
self.crossattn_emb = crossattn_emb
if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
assert (
x_B_T_H_W_D.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape
), f"{x_B_T_H_W_D.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape}"
block_kwargs = {
"rope_emb_L_1_1_D": rope_emb_L_1_1_D.unsqueeze(1).unsqueeze(0),
"adaln_lora_B_T_3D": adaln_lora_B_T_3D,
"extra_per_block_pos_emb": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
}
for block in self.blocks:
x_B_T_H_W_D = block(
x_B_T_H_W_D,
t_embedding_B_T_D,
crossattn_emb,
**block_kwargs,
)
x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D)
x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)
return x_B_C_Tt_Hp_Wp

View File

@@ -121,6 +121,11 @@ class ControlNetFlux(Flux):
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
if y is None:
y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
else:
y = y[:, :self.params.vec_in_dim]
# running on sequences img
img = self.img_in(img)
@@ -174,7 +179,7 @@ class ControlNetFlux(Flux):
out["output"] = out_output[:self.main_model_single]
return out
def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs):
def forward(self, x, timesteps, context, y=None, guidance=None, hint=None, **kwargs):
patch_size = 2
if self.latent_input:
hint = comfy.ldm.common_dit.pad_to_patch_size(hint, (patch_size, patch_size))

View File

@@ -118,7 +118,7 @@ class Modulation(nn.Module):
def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None):
if modulation_dims is None:
if m_add is not None:
return tensor * m_mult + m_add
return torch.addcmul(m_add, tensor, m_mult)
else:
return tensor * m_mult
else:

View File

@@ -101,6 +101,10 @@ class Flux(nn.Module):
transformer_options={},
attn_mask: Tensor = None,
) -> Tensor:
if y is None:
y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
patches_replace = transformer_options.get("patches_replace", {})
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
@@ -155,6 +159,9 @@ class Flux(nn.Module):
if add is not None:
img += add
if img.dtype == torch.float16:
img = torch.nan_to_num(img, nan=0.0, posinf=65504, neginf=-65504)
img = torch.cat((txt, img), 1)
for i, block in enumerate(self.single_blocks):
@@ -188,20 +195,50 @@ class Flux(nn.Module):
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
return img
def forward(self, x, timestep, context, y, guidance=None, control=None, transformer_options={}, **kwargs):
def process_img(self, x, index=0, h_offset=0, w_offset=0):
bs, c, h, w = x.shape
patch_size = self.patch_size
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
h_len = ((h + (patch_size // 2)) // patch_size)
w_len = ((w + (patch_size // 2)) // patch_size)
h_offset = ((h_offset + (patch_size // 2)) // patch_size)
w_offset = ((w_offset + (patch_size // 2)) // patch_size)
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
img_ids[:, :, 0] = img_ids[:, :, 1] + index
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
return img, repeat(img_ids, "h w c -> b (h w) c", b=bs)
def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
bs, c, h_orig, w_orig = x.shape
patch_size = self.patch_size
h_len = ((h_orig + (patch_size // 2)) // patch_size)
w_len = ((w_orig + (patch_size // 2)) // patch_size)
img, img_ids = self.process_img(x)
img_tokens = img.shape[1]
if ref_latents is not None:
h = 0
w = 0
for ref in ref_latents:
h_offset = 0
w_offset = 0
if ref.shape[-2] + h > ref.shape[-1] + w:
w_offset = w
else:
h_offset = h
kontext, kontext_ids = self.process_img(ref, index=1, h_offset=h_offset, w_offset=w_offset)
img = torch.cat([img, kontext], dim=1)
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
h = max(h, ref.shape[-2] + h_offset)
w = max(w, ref.shape[-1] + w_offset)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]
out = out[:, :img_tokens]
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h_orig,:w_orig]

View File

@@ -261,8 +261,8 @@ class CrossAttention(nn.Module):
self.heads = heads
self.dim_head = dim_head
self.q_norm = operations.RMSNorm(inner_dim, dtype=dtype, device=device)
self.k_norm = operations.RMSNorm(inner_dim, dtype=dtype, device=device)
self.q_norm = operations.RMSNorm(inner_dim, eps=1e-5, dtype=dtype, device=device)
self.k_norm = operations.RMSNorm(inner_dim, eps=1e-5, dtype=dtype, device=device)
self.to_q = operations.Linear(query_dim, inner_dim, bias=True, dtype=dtype, device=device)
self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)

View File

@@ -973,7 +973,7 @@ class VideoVAE(nn.Module):
norm_layer=config.get("norm_layer", "group_norm"),
causal=config.get("causal_decoder", False),
timestep_conditioning=self.timestep_conditioning,
spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
spatial_padding_mode=config.get("spatial_padding_mode", "reflect"),
)
self.per_channel_statistics = processor()

View File

@@ -11,7 +11,7 @@ from comfy.ldm.modules.ema import LitEma
import comfy.ops
class DiagonalGaussianRegularizer(torch.nn.Module):
def __init__(self, sample: bool = True):
def __init__(self, sample: bool = False):
super().__init__()
self.sample = sample
@@ -19,16 +19,12 @@ class DiagonalGaussianRegularizer(torch.nn.Module):
yield from ()
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
log = dict()
posterior = DiagonalGaussianDistribution(z)
if self.sample:
z = posterior.sample()
else:
z = posterior.mode()
kl_loss = posterior.kl()
kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
log["kl_loss"] = kl_loss
return z, log
return z, None
class AbstractAutoencoder(torch.nn.Module):

View File

@@ -753,7 +753,7 @@ class BasicTransformerBlock(nn.Module):
for p in patch:
n = p(n, extra_options)
x += n
x = n + x
if "middle_patch" in transformer_patches:
patch = transformer_patches["middle_patch"]
for p in patch:
@@ -793,12 +793,12 @@ class BasicTransformerBlock(nn.Module):
for p in patch:
n = p(n, extra_options)
x += n
x = n + x
if self.is_res:
x_skip = x
x = self.ff(self.norm3(x))
if self.is_res:
x += x_skip
x = x_skip + x
return x

View File

@@ -31,7 +31,7 @@ def dynamic_slice(
starts: List[int],
sizes: List[int],
) -> Tensor:
slicing = [slice(start, start + size) for start, size in zip(starts, sizes)]
slicing = tuple(slice(start, start + size) for start, size in zip(starts, sizes))
return x[slicing]
class AttnChunk(NamedTuple):

View File

@@ -0,0 +1,469 @@
# Original code: https://github.com/VectorSpaceLab/OmniGen2
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from comfy.ldm.lightricks.model import Timesteps
from comfy.ldm.flux.layers import EmbedND
from comfy.ldm.modules.attention import optimized_attention_masked
import comfy.model_management
import comfy.ldm.common_dit
def apply_rotary_emb(x, freqs_cis):
if x.shape[1] == 0:
return x
t_ = x.reshape(*x.shape[:-1], -1, 1, 2)
t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
return t_out.reshape(*x.shape).to(dtype=x.dtype)
def swiglu(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return F.silu(x) * y
class TimestepEmbedding(nn.Module):
def __init__(self, in_channels: int, time_embed_dim: int, dtype=None, device=None, operations=None):
super().__init__()
self.linear_1 = operations.Linear(in_channels, time_embed_dim, dtype=dtype, device=device)
self.act = nn.SiLU()
self.linear_2 = operations.Linear(time_embed_dim, time_embed_dim, dtype=dtype, device=device)
def forward(self, sample: torch.Tensor) -> torch.Tensor:
sample = self.linear_1(sample)
sample = self.act(sample)
sample = self.linear_2(sample)
return sample
class LuminaRMSNormZero(nn.Module):
def __init__(self, embedding_dim: int, norm_eps: float = 1e-5, dtype=None, device=None, operations=None):
super().__init__()
self.silu = nn.SiLU()
self.linear = operations.Linear(min(embedding_dim, 1024), 4 * embedding_dim, dtype=dtype, device=device)
self.norm = operations.RMSNorm(embedding_dim, eps=norm_eps, dtype=dtype, device=device)
def forward(self, x: torch.Tensor, emb: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
emb = self.linear(self.silu(emb))
scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
x = self.norm(x) * (1 + scale_msa[:, None])
return x, gate_msa, scale_mlp, gate_mlp
class LuminaLayerNormContinuous(nn.Module):
def __init__(self, embedding_dim: int, conditioning_embedding_dim: int, elementwise_affine: bool = False, eps: float = 1e-6, out_dim: Optional[int] = None, dtype=None, device=None, operations=None):
super().__init__()
self.silu = nn.SiLU()
self.linear_1 = operations.Linear(conditioning_embedding_dim, embedding_dim, dtype=dtype, device=device)
self.norm = operations.LayerNorm(embedding_dim, eps, elementwise_affine, dtype=dtype, device=device)
self.linear_2 = operations.Linear(embedding_dim, out_dim, bias=True, dtype=dtype, device=device) if out_dim is not None else None
def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
emb = self.linear_1(self.silu(conditioning_embedding).to(x.dtype))
x = self.norm(x) * (1 + emb)[:, None, :]
if self.linear_2 is not None:
x = self.linear_2(x)
return x
class LuminaFeedForward(nn.Module):
def __init__(self, dim: int, inner_dim: int, multiple_of: int = 256, dtype=None, device=None, operations=None):
super().__init__()
inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of)
self.linear_1 = operations.Linear(dim, inner_dim, bias=False, dtype=dtype, device=device)
self.linear_2 = operations.Linear(inner_dim, dim, bias=False, dtype=dtype, device=device)
self.linear_3 = operations.Linear(dim, inner_dim, bias=False, dtype=dtype, device=device)
def forward(self, x: torch.Tensor) -> torch.Tensor:
h1, h2 = self.linear_1(x), self.linear_3(x)
return self.linear_2(swiglu(h1, h2))
class Lumina2CombinedTimestepCaptionEmbedding(nn.Module):
def __init__(self, hidden_size: int = 4096, text_feat_dim: int = 2048, frequency_embedding_size: int = 256, norm_eps: float = 1e-5, timestep_scale: float = 1.0, dtype=None, device=None, operations=None):
super().__init__()
self.time_proj = Timesteps(num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0, scale=timestep_scale)
self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=min(hidden_size, 1024), dtype=dtype, device=device, operations=operations)
self.caption_embedder = nn.Sequential(
operations.RMSNorm(text_feat_dim, eps=norm_eps, dtype=dtype, device=device),
operations.Linear(text_feat_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
def forward(self, timestep: torch.Tensor, text_hidden_states: torch.Tensor, dtype: torch.dtype) -> Tuple[torch.Tensor, torch.Tensor]:
timestep_proj = self.time_proj(timestep).to(dtype=dtype)
time_embed = self.timestep_embedder(timestep_proj)
caption_embed = self.caption_embedder(text_hidden_states)
return time_embed, caption_embed
class Attention(nn.Module):
def __init__(self, query_dim: int, dim_head: int, heads: int, kv_heads: int, eps: float = 1e-5, bias: bool = False, dtype=None, device=None, operations=None):
super().__init__()
self.heads = heads
self.kv_heads = kv_heads
self.dim_head = dim_head
self.scale = dim_head ** -0.5
self.to_q = operations.Linear(query_dim, heads * dim_head, bias=bias, dtype=dtype, device=device)
self.to_k = operations.Linear(query_dim, kv_heads * dim_head, bias=bias, dtype=dtype, device=device)
self.to_v = operations.Linear(query_dim, kv_heads * dim_head, bias=bias, dtype=dtype, device=device)
self.norm_q = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
self.norm_k = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
self.to_out = nn.Sequential(
operations.Linear(heads * dim_head, query_dim, bias=bias, dtype=dtype, device=device),
nn.Dropout(0.0)
)
def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None) -> torch.Tensor:
batch_size, sequence_length, _ = hidden_states.shape
query = self.to_q(hidden_states)
key = self.to_k(encoder_hidden_states)
value = self.to_v(encoder_hidden_states)
query = query.view(batch_size, -1, self.heads, self.dim_head)
key = key.view(batch_size, -1, self.kv_heads, self.dim_head)
value = value.view(batch_size, -1, self.kv_heads, self.dim_head)
query = self.norm_q(query)
key = self.norm_k(key)
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
if self.kv_heads < self.heads:
key = key.repeat_interleave(self.heads // self.kv_heads, dim=1)
value = value.repeat_interleave(self.heads // self.kv_heads, dim=1)
hidden_states = optimized_attention_masked(query, key, value, self.heads, attention_mask, skip_reshape=True)
hidden_states = self.to_out[0](hidden_states)
return hidden_states
class OmniGen2TransformerBlock(nn.Module):
def __init__(self, dim: int, num_attention_heads: int, num_kv_heads: int, multiple_of: int, ffn_dim_multiplier: float, norm_eps: float, modulation: bool = True, dtype=None, device=None, operations=None):
super().__init__()
self.modulation = modulation
self.attn = Attention(
query_dim=dim,
dim_head=dim // num_attention_heads,
heads=num_attention_heads,
kv_heads=num_kv_heads,
eps=1e-5,
bias=False,
dtype=dtype, device=device, operations=operations,
)
self.feed_forward = LuminaFeedForward(
dim=dim,
inner_dim=4 * dim,
multiple_of=multiple_of,
dtype=dtype, device=device, operations=operations
)
if modulation:
self.norm1 = LuminaRMSNormZero(embedding_dim=dim, norm_eps=norm_eps, dtype=dtype, device=device, operations=operations)
else:
self.norm1 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
self.ffn_norm1 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
self.norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
self.ffn_norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, image_rotary_emb: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
if self.modulation:
norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb)
hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output)
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
else:
norm_hidden_states = self.norm1(hidden_states)
attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb)
hidden_states = hidden_states + self.norm2(attn_output)
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
hidden_states = hidden_states + self.ffn_norm2(mlp_output)
return hidden_states
class OmniGen2RotaryPosEmbed(nn.Module):
def __init__(self, theta: int, axes_dim: Tuple[int, int, int], axes_lens: Tuple[int, int, int] = (300, 512, 512), patch_size: int = 2):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim
self.axes_lens = axes_lens
self.patch_size = patch_size
self.rope_embedder = EmbedND(dim=sum(axes_dim), theta=self.theta, axes_dim=axes_dim)
def forward(self, batch_size, encoder_seq_len, l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len, ref_img_sizes, img_sizes, device):
p = self.patch_size
seq_lengths = [cap_len + sum(ref_img_len) + img_len for cap_len, ref_img_len, img_len in zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len)]
max_seq_len = max(seq_lengths)
max_ref_img_len = max([sum(ref_img_len) for ref_img_len in l_effective_ref_img_len])
max_img_len = max(l_effective_img_len)
position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device)
for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
position_ids[i, :cap_seq_len] = repeat(torch.arange(cap_seq_len, dtype=torch.int32, device=device), "l -> l 3")
pe_shift = cap_seq_len
pe_shift_len = cap_seq_len
if ref_img_sizes[i] is not None:
for ref_img_size, ref_img_len in zip(ref_img_sizes[i], l_effective_ref_img_len[i]):
H, W = ref_img_size
ref_H_tokens, ref_W_tokens = H // p, W // p
row_ids = repeat(torch.arange(ref_H_tokens, dtype=torch.int32, device=device), "h -> h w", w=ref_W_tokens).flatten()
col_ids = repeat(torch.arange(ref_W_tokens, dtype=torch.int32, device=device), "w -> h w", h=ref_H_tokens).flatten()
position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 0] = pe_shift
position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 1] = row_ids
position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 2] = col_ids
pe_shift += max(ref_H_tokens, ref_W_tokens)
pe_shift_len += ref_img_len
H, W = img_sizes[i]
H_tokens, W_tokens = H // p, W // p
row_ids = repeat(torch.arange(H_tokens, dtype=torch.int32, device=device), "h -> h w", w=W_tokens).flatten()
col_ids = repeat(torch.arange(W_tokens, dtype=torch.int32, device=device), "w -> h w", h=H_tokens).flatten()
position_ids[i, pe_shift_len: seq_len, 0] = pe_shift
position_ids[i, pe_shift_len: seq_len, 1] = row_ids
position_ids[i, pe_shift_len: seq_len, 2] = col_ids
freqs_cis = self.rope_embedder(position_ids).movedim(1, 2)
cap_freqs_cis_shape = list(freqs_cis.shape)
cap_freqs_cis_shape[1] = encoder_seq_len
cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
ref_img_freqs_cis_shape = list(freqs_cis.shape)
ref_img_freqs_cis_shape[1] = max_ref_img_len
ref_img_freqs_cis = torch.zeros(*ref_img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
img_freqs_cis_shape = list(freqs_cis.shape)
img_freqs_cis_shape[1] = max_img_len
img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
for i, (cap_seq_len, ref_img_len, img_len, seq_len) in enumerate(zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len, seq_lengths)):
cap_freqs_cis[i, :cap_seq_len] = freqs_cis[i, :cap_seq_len]
ref_img_freqs_cis[i, :sum(ref_img_len)] = freqs_cis[i, cap_seq_len:cap_seq_len + sum(ref_img_len)]
img_freqs_cis[i, :img_len] = freqs_cis[i, cap_seq_len + sum(ref_img_len):cap_seq_len + sum(ref_img_len) + img_len]
return cap_freqs_cis, ref_img_freqs_cis, img_freqs_cis, freqs_cis, l_effective_cap_len, seq_lengths
class OmniGen2Transformer2DModel(nn.Module):
def __init__(
self,
patch_size: int = 2,
in_channels: int = 16,
out_channels: Optional[int] = None,
hidden_size: int = 2304,
num_layers: int = 26,
num_refiner_layers: int = 2,
num_attention_heads: int = 24,
num_kv_heads: int = 8,
multiple_of: int = 256,
ffn_dim_multiplier: Optional[float] = None,
norm_eps: float = 1e-5,
axes_dim_rope: Tuple[int, int, int] = (32, 32, 32),
axes_lens: Tuple[int, int, int] = (300, 512, 512),
text_feat_dim: int = 1024,
timestep_scale: float = 1.0,
image_model=None,
device=None,
dtype=None,
operations=None,
):
super().__init__()
self.patch_size = patch_size
self.out_channels = out_channels or in_channels
self.hidden_size = hidden_size
self.dtype = dtype
self.rope_embedder = OmniGen2RotaryPosEmbed(
theta=10000,
axes_dim=axes_dim_rope,
axes_lens=axes_lens,
patch_size=patch_size,
)
self.x_embedder = operations.Linear(patch_size * patch_size * in_channels, hidden_size, dtype=dtype, device=device)
self.ref_image_patch_embedder = operations.Linear(patch_size * patch_size * in_channels, hidden_size, dtype=dtype, device=device)
self.time_caption_embed = Lumina2CombinedTimestepCaptionEmbedding(
hidden_size=hidden_size,
text_feat_dim=text_feat_dim,
norm_eps=norm_eps,
timestep_scale=timestep_scale, dtype=dtype, device=device, operations=operations
)
self.noise_refiner = nn.ModuleList([
OmniGen2TransformerBlock(
hidden_size, num_attention_heads, num_kv_heads,
multiple_of, ffn_dim_multiplier, norm_eps, modulation=True, dtype=dtype, device=device, operations=operations
) for _ in range(num_refiner_layers)
])
self.ref_image_refiner = nn.ModuleList([
OmniGen2TransformerBlock(
hidden_size, num_attention_heads, num_kv_heads,
multiple_of, ffn_dim_multiplier, norm_eps, modulation=True, dtype=dtype, device=device, operations=operations
) for _ in range(num_refiner_layers)
])
self.context_refiner = nn.ModuleList([
OmniGen2TransformerBlock(
hidden_size, num_attention_heads, num_kv_heads,
multiple_of, ffn_dim_multiplier, norm_eps, modulation=False, dtype=dtype, device=device, operations=operations
) for _ in range(num_refiner_layers)
])
self.layers = nn.ModuleList([
OmniGen2TransformerBlock(
hidden_size, num_attention_heads, num_kv_heads,
multiple_of, ffn_dim_multiplier, norm_eps, modulation=True, dtype=dtype, device=device, operations=operations
) for _ in range(num_layers)
])
self.norm_out = LuminaLayerNormContinuous(
embedding_dim=hidden_size,
conditioning_embedding_dim=min(hidden_size, 1024),
elementwise_affine=False,
eps=1e-6,
out_dim=patch_size * patch_size * self.out_channels, dtype=dtype, device=device, operations=operations
)
self.image_index_embedding = nn.Parameter(torch.empty(5, hidden_size, device=device, dtype=dtype))
def flat_and_pad_to_seq(self, hidden_states, ref_image_hidden_states):
batch_size = len(hidden_states)
p = self.patch_size
img_sizes = [(img.size(1), img.size(2)) for img in hidden_states]
l_effective_img_len = [(H // p) * (W // p) for (H, W) in img_sizes]
if ref_image_hidden_states is not None:
ref_image_hidden_states = list(map(lambda ref: comfy.ldm.common_dit.pad_to_patch_size(ref, (p, p)), ref_image_hidden_states))
ref_img_sizes = [[(imgs.size(2), imgs.size(3)) if imgs is not None else None for imgs in ref_image_hidden_states]] * batch_size
l_effective_ref_img_len = [[(ref_img_size[0] // p) * (ref_img_size[1] // p) for ref_img_size in _ref_img_sizes] if _ref_img_sizes is not None else [0] for _ref_img_sizes in ref_img_sizes]
else:
ref_img_sizes = [None for _ in range(batch_size)]
l_effective_ref_img_len = [[0] for _ in range(batch_size)]
flat_ref_img_hidden_states = None
if ref_image_hidden_states is not None:
imgs = []
for ref_img in ref_image_hidden_states:
B, C, H, W = ref_img.size()
ref_img = rearrange(ref_img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
imgs.append(ref_img)
flat_ref_img_hidden_states = torch.cat(imgs, dim=1)
img = hidden_states
B, C, H, W = img.size()
flat_hidden_states = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
return (
flat_hidden_states, flat_ref_img_hidden_states,
None, None,
l_effective_ref_img_len, l_effective_img_len,
ref_img_sizes, img_sizes,
)
def img_patch_embed_and_refine(self, hidden_states, ref_image_hidden_states, padded_img_mask, padded_ref_img_mask, noise_rotary_emb, ref_img_rotary_emb, l_effective_ref_img_len, l_effective_img_len, temb):
batch_size = len(hidden_states)
hidden_states = self.x_embedder(hidden_states)
if ref_image_hidden_states is not None:
ref_image_hidden_states = self.ref_image_patch_embedder(ref_image_hidden_states)
image_index_embedding = comfy.model_management.cast_to(self.image_index_embedding, dtype=hidden_states.dtype, device=hidden_states.device)
for i in range(batch_size):
shift = 0
for j, ref_img_len in enumerate(l_effective_ref_img_len[i]):
ref_image_hidden_states[i, shift:shift + ref_img_len, :] = ref_image_hidden_states[i, shift:shift + ref_img_len, :] + image_index_embedding[j]
shift += ref_img_len
for layer in self.noise_refiner:
hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb)
if ref_image_hidden_states is not None:
for layer in self.ref_image_refiner:
ref_image_hidden_states = layer(ref_image_hidden_states, padded_ref_img_mask, ref_img_rotary_emb, temb)
hidden_states = torch.cat([ref_image_hidden_states, hidden_states], dim=1)
return hidden_states
def forward(self, x, timesteps, context, num_tokens, ref_latents=None, attention_mask=None, **kwargs):
B, C, H, W = x.shape
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
_, _, H_padded, W_padded = hidden_states.shape
timestep = 1.0 - timesteps
text_hidden_states = context
text_attention_mask = attention_mask
ref_image_hidden_states = ref_latents
device = hidden_states.device
temb, text_hidden_states = self.time_caption_embed(timestep, text_hidden_states, hidden_states[0].dtype)
(
hidden_states, ref_image_hidden_states,
img_mask, ref_img_mask,
l_effective_ref_img_len, l_effective_img_len,
ref_img_sizes, img_sizes,
) = self.flat_and_pad_to_seq(hidden_states, ref_image_hidden_states)
(
context_rotary_emb, ref_img_rotary_emb, noise_rotary_emb,
rotary_emb, encoder_seq_lengths, seq_lengths,
) = self.rope_embedder(
hidden_states.shape[0], text_hidden_states.shape[1], [num_tokens] * text_hidden_states.shape[0],
l_effective_ref_img_len, l_effective_img_len,
ref_img_sizes, img_sizes, device,
)
for layer in self.context_refiner:
text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb)
img_len = hidden_states.shape[1]
combined_img_hidden_states = self.img_patch_embed_and_refine(
hidden_states, ref_image_hidden_states,
img_mask, ref_img_mask,
noise_rotary_emb, ref_img_rotary_emb,
l_effective_ref_img_len, l_effective_img_len,
temb,
)
hidden_states = torch.cat([text_hidden_states, combined_img_hidden_states], dim=1)
attention_mask = None
for layer in self.layers:
hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb)
hidden_states = self.norm_out(hidden_states, temb)
p = self.patch_size
output = rearrange(hidden_states[:, -img_len:], 'b (h w) (p1 p2 c) -> b c (h p1) (w p2)', h=H_padded // p, w=W_padded// p, p1=p, p2=p)[:, :, :H, :W]
return -output

View File

@@ -1,256 +1,256 @@
# Based on:
# https://github.com/PixArt-alpha/PixArt-alpha [Apache 2.0 license]
# https://github.com/PixArt-alpha/PixArt-sigma [Apache 2.0 license]
import torch
import torch.nn as nn
from .blocks import (
t2i_modulate,
CaptionEmbedder,
AttentionKVCompress,
MultiHeadCrossAttention,
T2IFinalLayer,
SizeEmbedder,
)
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, PatchEmbed, Mlp, get_1d_sincos_pos_embed_from_grid_torch
def get_2d_sincos_pos_embed_torch(embed_dim, w, h, pe_interpolation=1.0, base_size=16, device=None, dtype=torch.float32):
grid_h, grid_w = torch.meshgrid(
torch.arange(h, device=device, dtype=dtype) / (h/base_size) / pe_interpolation,
torch.arange(w, device=device, dtype=dtype) / (w/base_size) / pe_interpolation,
indexing='ij'
)
emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_h, device=device, dtype=dtype)
emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_w, device=device, dtype=dtype)
emb = torch.cat([emb_w, emb_h], dim=1) # (H*W, D)
return emb
class PixArtMSBlock(nn.Module):
"""
A PixArt block with adaptive layer norm zero (adaLN-Zero) conditioning.
"""
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0., input_size=None,
sampling=None, sr_ratio=1, qk_norm=False, dtype=None, device=None, operations=None, **block_kwargs):
super().__init__()
self.hidden_size = hidden_size
self.norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.attn = AttentionKVCompress(
hidden_size, num_heads=num_heads, qkv_bias=True, sampling=sampling, sr_ratio=sr_ratio,
qk_norm=qk_norm, dtype=dtype, device=device, operations=operations, **block_kwargs
)
self.cross_attn = MultiHeadCrossAttention(
hidden_size, num_heads, dtype=dtype, device=device, operations=operations, **block_kwargs
)
self.norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
# to be compatible with lower version pytorch
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.mlp = Mlp(
in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu,
dtype=dtype, device=device, operations=operations
)
self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size ** 0.5)
def forward(self, x, y, t, mask=None, HW=None, **kwargs):
B, N, C = x.shape
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None].to(dtype=x.dtype, device=x.device) + t.reshape(B, 6, -1)).chunk(6, dim=1)
x = x + (gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa), HW=HW))
x = x + self.cross_attn(x, y, mask)
x = x + (gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))
return x
### Core PixArt Model ###
class PixArtMS(nn.Module):
"""
Diffusion model with a Transformer backbone.
"""
def __init__(
self,
input_size=32,
patch_size=2,
in_channels=4,
hidden_size=1152,
depth=28,
num_heads=16,
mlp_ratio=4.0,
class_dropout_prob=0.1,
learn_sigma=True,
pred_sigma=True,
drop_path: float = 0.,
caption_channels=4096,
pe_interpolation=None,
pe_precision=None,
config=None,
model_max_length=120,
micro_condition=True,
qk_norm=False,
kv_compress_config=None,
dtype=None,
device=None,
operations=None,
**kwargs,
):
nn.Module.__init__(self)
self.dtype = dtype
self.pred_sigma = pred_sigma
self.in_channels = in_channels
self.out_channels = in_channels * 2 if pred_sigma else in_channels
self.patch_size = patch_size
self.num_heads = num_heads
self.pe_interpolation = pe_interpolation
self.pe_precision = pe_precision
self.hidden_size = hidden_size
self.depth = depth
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.t_block = nn.Sequential(
nn.SiLU(),
operations.Linear(hidden_size, 6 * hidden_size, bias=True, dtype=dtype, device=device)
)
self.x_embedder = PatchEmbed(
patch_size=patch_size,
in_chans=in_channels,
embed_dim=hidden_size,
bias=True,
dtype=dtype,
device=device,
operations=operations
)
self.t_embedder = TimestepEmbedder(
hidden_size, dtype=dtype, device=device, operations=operations,
)
self.y_embedder = CaptionEmbedder(
in_channels=caption_channels, hidden_size=hidden_size, uncond_prob=class_dropout_prob,
act_layer=approx_gelu, token_num=model_max_length,
dtype=dtype, device=device, operations=operations,
)
self.micro_conditioning = micro_condition
if self.micro_conditioning:
self.csize_embedder = SizeEmbedder(hidden_size//3, dtype=dtype, device=device, operations=operations)
self.ar_embedder = SizeEmbedder(hidden_size//3, dtype=dtype, device=device, operations=operations)
# For fixed sin-cos embedding:
# num_patches = (input_size // patch_size) * (input_size // patch_size)
# self.base_size = input_size // self.patch_size
# self.register_buffer("pos_embed", torch.zeros(1, num_patches, hidden_size))
drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule
if kv_compress_config is None:
kv_compress_config = {
'sampling': None,
'scale_factor': 1,
'kv_compress_layer': [],
}
self.blocks = nn.ModuleList([
PixArtMSBlock(
hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i],
sampling=kv_compress_config['sampling'],
sr_ratio=int(kv_compress_config['scale_factor']) if i in kv_compress_config['kv_compress_layer'] else 1,
qk_norm=qk_norm,
dtype=dtype,
device=device,
operations=operations,
)
for i in range(depth)
])
self.final_layer = T2IFinalLayer(
hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations
)
def forward_orig(self, x, timestep, y, mask=None, c_size=None, c_ar=None, **kwargs):
"""
Original forward pass of PixArt.
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
t: (N,) tensor of diffusion timesteps
y: (N, 1, 120, C) conditioning
ar: (N, 1): aspect ratio
cs: (N ,2) size conditioning for height/width
"""
B, C, H, W = x.shape
c_res = (H + W) // 2
pe_interpolation = self.pe_interpolation
if pe_interpolation is None or self.pe_precision is not None:
# calculate pe_interpolation on-the-fly
pe_interpolation = round(c_res / (512/8.0), self.pe_precision or 0)
pos_embed = get_2d_sincos_pos_embed_torch(
self.hidden_size,
h=(H // self.patch_size),
w=(W // self.patch_size),
pe_interpolation=pe_interpolation,
base_size=((round(c_res / 64) * 64) // self.patch_size),
device=x.device,
dtype=x.dtype,
).unsqueeze(0)
x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2
t = self.t_embedder(timestep, x.dtype) # (N, D)
if self.micro_conditioning and (c_size is not None and c_ar is not None):
bs = x.shape[0]
c_size = self.csize_embedder(c_size, bs) # (N, D)
c_ar = self.ar_embedder(c_ar, bs) # (N, D)
t = t + torch.cat([c_size, c_ar], dim=1)
t0 = self.t_block(t)
y = self.y_embedder(y, self.training) # (N, D)
if mask is not None:
if mask.shape[0] != y.shape[0]:
mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
mask = mask.squeeze(1).squeeze(1)
y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
y_lens = mask.sum(dim=1).tolist()
else:
y_lens = None
y = y.squeeze(1).view(1, -1, x.shape[-1])
for block in self.blocks:
x = block(x, y, t0, y_lens, (H, W), **kwargs) # (N, T, D)
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
x = self.unpatchify(x, H, W) # (N, out_channels, H, W)
return x
def forward(self, x, timesteps, context, c_size=None, c_ar=None, **kwargs):
B, C, H, W = x.shape
# Fallback for missing microconds
if self.micro_conditioning:
if c_size is None:
c_size = torch.tensor([H*8, W*8], dtype=x.dtype, device=x.device).repeat(B, 1)
if c_ar is None:
c_ar = torch.tensor([H/W], dtype=x.dtype, device=x.device).repeat(B, 1)
## Still accepts the input w/o that dim but returns garbage
if len(context.shape) == 3:
context = context.unsqueeze(1)
## run original forward pass
out = self.forward_orig(x, timesteps, context, c_size=c_size, c_ar=c_ar)
## only return EPS
if self.pred_sigma:
return out[:, :self.in_channels]
return out
def unpatchify(self, x, h, w):
"""
x: (N, T, patch_size**2 * C)
imgs: (N, H, W, C)
"""
c = self.out_channels
p = self.x_embedder.patch_size[0]
h = h // self.patch_size
w = w // self.patch_size
assert h * w == x.shape[1]
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
x = torch.einsum('nhwpqc->nchpwq', x)
imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
return imgs
# Based on:
# https://github.com/PixArt-alpha/PixArt-alpha [Apache 2.0 license]
# https://github.com/PixArt-alpha/PixArt-sigma [Apache 2.0 license]
import torch
import torch.nn as nn
from .blocks import (
t2i_modulate,
CaptionEmbedder,
AttentionKVCompress,
MultiHeadCrossAttention,
T2IFinalLayer,
SizeEmbedder,
)
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, PatchEmbed, Mlp, get_1d_sincos_pos_embed_from_grid_torch
def get_2d_sincos_pos_embed_torch(embed_dim, w, h, pe_interpolation=1.0, base_size=16, device=None, dtype=torch.float32):
grid_h, grid_w = torch.meshgrid(
torch.arange(h, device=device, dtype=dtype) / (h/base_size) / pe_interpolation,
torch.arange(w, device=device, dtype=dtype) / (w/base_size) / pe_interpolation,
indexing='ij'
)
emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_h, device=device, dtype=dtype)
emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_w, device=device, dtype=dtype)
emb = torch.cat([emb_w, emb_h], dim=1) # (H*W, D)
return emb
class PixArtMSBlock(nn.Module):
"""
A PixArt block with adaptive layer norm zero (adaLN-Zero) conditioning.
"""
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0., input_size=None,
sampling=None, sr_ratio=1, qk_norm=False, dtype=None, device=None, operations=None, **block_kwargs):
super().__init__()
self.hidden_size = hidden_size
self.norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.attn = AttentionKVCompress(
hidden_size, num_heads=num_heads, qkv_bias=True, sampling=sampling, sr_ratio=sr_ratio,
qk_norm=qk_norm, dtype=dtype, device=device, operations=operations, **block_kwargs
)
self.cross_attn = MultiHeadCrossAttention(
hidden_size, num_heads, dtype=dtype, device=device, operations=operations, **block_kwargs
)
self.norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
# to be compatible with lower version pytorch
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.mlp = Mlp(
in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu,
dtype=dtype, device=device, operations=operations
)
self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size ** 0.5)
def forward(self, x, y, t, mask=None, HW=None, **kwargs):
B, N, C = x.shape
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None].to(dtype=x.dtype, device=x.device) + t.reshape(B, 6, -1)).chunk(6, dim=1)
x = x + (gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa), HW=HW))
x = x + self.cross_attn(x, y, mask)
x = x + (gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))
return x
### Core PixArt Model ###
class PixArtMS(nn.Module):
"""
Diffusion model with a Transformer backbone.
"""
def __init__(
self,
input_size=32,
patch_size=2,
in_channels=4,
hidden_size=1152,
depth=28,
num_heads=16,
mlp_ratio=4.0,
class_dropout_prob=0.1,
learn_sigma=True,
pred_sigma=True,
drop_path: float = 0.,
caption_channels=4096,
pe_interpolation=None,
pe_precision=None,
config=None,
model_max_length=120,
micro_condition=True,
qk_norm=False,
kv_compress_config=None,
dtype=None,
device=None,
operations=None,
**kwargs,
):
nn.Module.__init__(self)
self.dtype = dtype
self.pred_sigma = pred_sigma
self.in_channels = in_channels
self.out_channels = in_channels * 2 if pred_sigma else in_channels
self.patch_size = patch_size
self.num_heads = num_heads
self.pe_interpolation = pe_interpolation
self.pe_precision = pe_precision
self.hidden_size = hidden_size
self.depth = depth
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.t_block = nn.Sequential(
nn.SiLU(),
operations.Linear(hidden_size, 6 * hidden_size, bias=True, dtype=dtype, device=device)
)
self.x_embedder = PatchEmbed(
patch_size=patch_size,
in_chans=in_channels,
embed_dim=hidden_size,
bias=True,
dtype=dtype,
device=device,
operations=operations
)
self.t_embedder = TimestepEmbedder(
hidden_size, dtype=dtype, device=device, operations=operations,
)
self.y_embedder = CaptionEmbedder(
in_channels=caption_channels, hidden_size=hidden_size, uncond_prob=class_dropout_prob,
act_layer=approx_gelu, token_num=model_max_length,
dtype=dtype, device=device, operations=operations,
)
self.micro_conditioning = micro_condition
if self.micro_conditioning:
self.csize_embedder = SizeEmbedder(hidden_size//3, dtype=dtype, device=device, operations=operations)
self.ar_embedder = SizeEmbedder(hidden_size//3, dtype=dtype, device=device, operations=operations)
# For fixed sin-cos embedding:
# num_patches = (input_size // patch_size) * (input_size // patch_size)
# self.base_size = input_size // self.patch_size
# self.register_buffer("pos_embed", torch.zeros(1, num_patches, hidden_size))
drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule
if kv_compress_config is None:
kv_compress_config = {
'sampling': None,
'scale_factor': 1,
'kv_compress_layer': [],
}
self.blocks = nn.ModuleList([
PixArtMSBlock(
hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i],
sampling=kv_compress_config['sampling'],
sr_ratio=int(kv_compress_config['scale_factor']) if i in kv_compress_config['kv_compress_layer'] else 1,
qk_norm=qk_norm,
dtype=dtype,
device=device,
operations=operations,
)
for i in range(depth)
])
self.final_layer = T2IFinalLayer(
hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations
)
def forward_orig(self, x, timestep, y, mask=None, c_size=None, c_ar=None, **kwargs):
"""
Original forward pass of PixArt.
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
t: (N,) tensor of diffusion timesteps
y: (N, 1, 120, C) conditioning
ar: (N, 1): aspect ratio
cs: (N ,2) size conditioning for height/width
"""
B, C, H, W = x.shape
c_res = (H + W) // 2
pe_interpolation = self.pe_interpolation
if pe_interpolation is None or self.pe_precision is not None:
# calculate pe_interpolation on-the-fly
pe_interpolation = round(c_res / (512/8.0), self.pe_precision or 0)
pos_embed = get_2d_sincos_pos_embed_torch(
self.hidden_size,
h=(H // self.patch_size),
w=(W // self.patch_size),
pe_interpolation=pe_interpolation,
base_size=((round(c_res / 64) * 64) // self.patch_size),
device=x.device,
dtype=x.dtype,
).unsqueeze(0)
x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2
t = self.t_embedder(timestep, x.dtype) # (N, D)
if self.micro_conditioning and (c_size is not None and c_ar is not None):
bs = x.shape[0]
c_size = self.csize_embedder(c_size, bs) # (N, D)
c_ar = self.ar_embedder(c_ar, bs) # (N, D)
t = t + torch.cat([c_size, c_ar], dim=1)
t0 = self.t_block(t)
y = self.y_embedder(y, self.training) # (N, D)
if mask is not None:
if mask.shape[0] != y.shape[0]:
mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
mask = mask.squeeze(1).squeeze(1)
y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
y_lens = mask.sum(dim=1).tolist()
else:
y_lens = None
y = y.squeeze(1).view(1, -1, x.shape[-1])
for block in self.blocks:
x = block(x, y, t0, y_lens, (H, W), **kwargs) # (N, T, D)
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
x = self.unpatchify(x, H, W) # (N, out_channels, H, W)
return x
def forward(self, x, timesteps, context, c_size=None, c_ar=None, **kwargs):
B, C, H, W = x.shape
# Fallback for missing microconds
if self.micro_conditioning:
if c_size is None:
c_size = torch.tensor([H*8, W*8], dtype=x.dtype, device=x.device).repeat(B, 1)
if c_ar is None:
c_ar = torch.tensor([H/W], dtype=x.dtype, device=x.device).repeat(B, 1)
## Still accepts the input w/o that dim but returns garbage
if len(context.shape) == 3:
context = context.unsqueeze(1)
## run original forward pass
out = self.forward_orig(x, timesteps, context, c_size=c_size, c_ar=c_ar)
## only return EPS
if self.pred_sigma:
return out[:, :self.in_channels]
return out
def unpatchify(self, x, h, w):
"""
x: (N, T, patch_size**2 * C)
imgs: (N, H, W, C)
"""
c = self.out_channels
p = self.x_embedder.patch_size[0]
h = h // self.patch_size
w = w // self.patch_size
assert h * w == x.shape[1]
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
x = torch.einsum('nhwpqc->nchpwq', x)
imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
return imgs

View File

@@ -34,12 +34,14 @@ import comfy.ldm.flux.model
import comfy.ldm.lightricks.model
import comfy.ldm.hunyuan_video.model
import comfy.ldm.cosmos.model
import comfy.ldm.cosmos.predict2
import comfy.ldm.lumina.model
import comfy.ldm.wan.model
import comfy.ldm.hunyuan3d.model
import comfy.ldm.hidream.model
import comfy.ldm.chroma.model
import comfy.ldm.ace.model
import comfy.ldm.omnigen.omnigen2
import comfy.model_management
import comfy.patcher_extension
@@ -48,6 +50,7 @@ import comfy.ops
from enum import Enum
from . import utils
import comfy.latent_formats
import comfy.model_sampling
import math
from typing import TYPE_CHECKING
if TYPE_CHECKING:
@@ -63,38 +66,39 @@ class ModelType(Enum):
V_PREDICTION_CONTINUOUS = 7
FLUX = 8
IMG_TO_IMG = 9
from comfy.model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling, ModelSamplingContinuousV
FLOW_COSMOS = 10
def model_sampling(model_config, model_type):
s = ModelSamplingDiscrete
s = comfy.model_sampling.ModelSamplingDiscrete
if model_type == ModelType.EPS:
c = EPS
c = comfy.model_sampling.EPS
elif model_type == ModelType.V_PREDICTION:
c = V_PREDICTION
c = comfy.model_sampling.V_PREDICTION
elif model_type == ModelType.V_PREDICTION_EDM:
c = V_PREDICTION
s = ModelSamplingContinuousEDM
c = comfy.model_sampling.V_PREDICTION
s = comfy.model_sampling.ModelSamplingContinuousEDM
elif model_type == ModelType.FLOW:
c = comfy.model_sampling.CONST
s = comfy.model_sampling.ModelSamplingDiscreteFlow
elif model_type == ModelType.STABLE_CASCADE:
c = EPS
s = StableCascadeSampling
c = comfy.model_sampling.EPS
s = comfy.model_sampling.StableCascadeSampling
elif model_type == ModelType.EDM:
c = EDM
s = ModelSamplingContinuousEDM
c = comfy.model_sampling.EDM
s = comfy.model_sampling.ModelSamplingContinuousEDM
elif model_type == ModelType.V_PREDICTION_CONTINUOUS:
c = V_PREDICTION
s = ModelSamplingContinuousV
c = comfy.model_sampling.V_PREDICTION
s = comfy.model_sampling.ModelSamplingContinuousV
elif model_type == ModelType.FLUX:
c = comfy.model_sampling.CONST
s = comfy.model_sampling.ModelSamplingFlux
elif model_type == ModelType.IMG_TO_IMG:
c = comfy.model_sampling.IMG_TO_IMG
elif model_type == ModelType.FLOW_COSMOS:
c = comfy.model_sampling.COSMOS_RFLOW
s = comfy.model_sampling.ModelSamplingCosmosRFlow
class ModelSampling(s, c):
pass
@@ -102,6 +106,13 @@ def model_sampling(model_config, model_type):
return ModelSampling(model_config)
def convert_tensor(extra, dtype):
if hasattr(extra, "dtype"):
if extra.dtype != torch.int and extra.dtype != torch.long:
extra = extra.to(dtype)
return extra
class BaseModel(torch.nn.Module):
def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_model=UNetModel):
super().__init__()
@@ -165,9 +176,14 @@ class BaseModel(torch.nn.Module):
extra_conds = {}
for o in kwargs:
extra = kwargs[o]
if hasattr(extra, "dtype"):
if extra.dtype != torch.int and extra.dtype != torch.long:
extra = extra.to(dtype)
extra = convert_tensor(extra, dtype)
elif isinstance(extra, list):
ex = []
for ext in extra:
ex.append(convert_tensor(ext, dtype))
extra = ex
extra_conds[o] = extra
t = self.process_timestep(t, x=x, **extra_conds)
@@ -800,6 +816,7 @@ class PixArt(BaseModel):
class Flux(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLUX, device=None, unet_model=comfy.ldm.flux.model.Flux):
super().__init__(model_config, model_type, device=device, unet_model=unet_model)
self.memory_usage_factor_conds = ("ref_latents",)
def concat_cond(self, **kwargs):
try:
@@ -860,8 +877,23 @@ class Flux(BaseModel):
guidance = kwargs.get("guidance", 3.5)
if guidance is not None:
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
ref_latents = kwargs.get("reference_latents", None)
if ref_latents is not None:
latents = []
for lat in ref_latents:
latents.append(self.process_latent_in(lat))
out['ref_latents'] = comfy.conds.CONDList(latents)
return out
def extra_conds_shapes(self, **kwargs):
out = {}
ref_latents = kwargs.get("reference_latents", None)
if ref_latents is not None:
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
return out
class GenmoMochi(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.genmo.joint_model.asymm_models_joint.AsymmDiTJoint)
@@ -986,6 +1018,45 @@ class CosmosVideo(BaseModel):
latent_image = self.model_sampling.calculate_input(torch.tensor([sigma_noise_augmentation], device=latent_image.device, dtype=latent_image.dtype), latent_image)
return latent_image * ((sigma ** 2 + self.model_sampling.sigma_data ** 2) ** 0.5)
class CosmosPredict2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW_COSMOS, image_to_video=False, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.cosmos.predict2.MiniTrainDIT)
self.image_to_video = image_to_video
if self.image_to_video:
self.concat_keys = ("mask_inverted",)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
if denoise_mask is not None:
out["denoise_mask"] = comfy.conds.CONDRegular(denoise_mask)
out['fps'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", None))
return out
def process_timestep(self, timestep, x, denoise_mask=None, **kwargs):
if denoise_mask is None:
return timestep
if denoise_mask.ndim <= 4:
return timestep
condition_video_mask_B_1_T_1_1 = denoise_mask.mean(dim=[1, 3, 4], keepdim=True)
c_noise_B_1_T_1_1 = 0.0 * (1.0 - condition_video_mask_B_1_T_1_1) + timestep.reshape(timestep.shape[0], 1, 1, 1, 1) * condition_video_mask_B_1_T_1_1
out = c_noise_B_1_T_1_1.squeeze(dim=[1, 3, 4])
return out
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
sigma = sigma.reshape([sigma.shape[0]] + [1] * (len(noise.shape) - 1))
sigma_noise_augmentation = 0 #TODO
if sigma_noise_augmentation != 0:
latent_image = latent_image + noise
latent_image = self.model_sampling.calculate_input(torch.tensor([sigma_noise_augmentation], device=latent_image.device, dtype=latent_image.dtype), latent_image)
sigma = (sigma / (sigma + 1))
return latent_image / (1.0 - sigma)
class Lumina2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiT)
@@ -1176,3 +1247,33 @@ class ACEStep(BaseModel):
out['speaker_embeds'] = comfy.conds.CONDRegular(torch.zeros(noise.shape[0], 512, device=noise.device, dtype=noise.dtype))
out['lyrics_strength'] = comfy.conds.CONDConstant(kwargs.get("lyrics_strength", 1.0))
return out
class Omnigen2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.omnigen.omnigen2.OmniGen2Transformer2DModel)
self.memory_usage_factor_conds = ("ref_latents",)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None:
if torch.numel(attention_mask) != attention_mask.sum():
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
out['num_tokens'] = comfy.conds.CONDConstant(max(1, torch.sum(attention_mask).item()))
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
ref_latents = kwargs.get("reference_latents", None)
if ref_latents is not None:
latents = []
for lat in ref_latents:
latents.append(self.process_latent_in(lat))
out['ref_latents'] = comfy.conds.CONDList(latents)
return out
def extra_conds_shapes(self, **kwargs):
out = {}
ref_latents = kwargs.get("reference_latents", None)
if ref_latents is not None:
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
return out

View File

@@ -407,6 +407,78 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["text_emb_dim"] = 2048
return dit_config
if '{}blocks.0.mlp.layer1.weight'.format(key_prefix) in state_dict_keys: # Cosmos predict2
dit_config = {}
dit_config["image_model"] = "cosmos_predict2"
dit_config["max_img_h"] = 240
dit_config["max_img_w"] = 240
dit_config["max_frames"] = 128
concat_padding_mask = True
dit_config["in_channels"] = (state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[1] // 4) - int(concat_padding_mask)
dit_config["out_channels"] = 16
dit_config["patch_spatial"] = 2
dit_config["patch_temporal"] = 1
dit_config["model_channels"] = state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[0]
dit_config["concat_padding_mask"] = concat_padding_mask
dit_config["crossattn_emb_channels"] = 1024
dit_config["pos_emb_cls"] = "rope3d"
dit_config["pos_emb_learnable"] = True
dit_config["pos_emb_interpolation"] = "crop"
dit_config["min_fps"] = 1
dit_config["max_fps"] = 30
dit_config["use_adaln_lora"] = True
dit_config["adaln_lora_dim"] = 256
if dit_config["model_channels"] == 2048:
dit_config["num_blocks"] = 28
dit_config["num_heads"] = 16
elif dit_config["model_channels"] == 5120:
dit_config["num_blocks"] = 36
dit_config["num_heads"] = 40
if dit_config["in_channels"] == 16:
dit_config["extra_per_block_abs_pos_emb"] = False
dit_config["rope_h_extrapolation_ratio"] = 4.0
dit_config["rope_w_extrapolation_ratio"] = 4.0
dit_config["rope_t_extrapolation_ratio"] = 1.0
elif dit_config["in_channels"] == 17: # img to video
if dit_config["model_channels"] == 2048:
dit_config["extra_per_block_abs_pos_emb"] = False
dit_config["rope_h_extrapolation_ratio"] = 3.0
dit_config["rope_w_extrapolation_ratio"] = 3.0
dit_config["rope_t_extrapolation_ratio"] = 1.0
elif dit_config["model_channels"] == 5120:
dit_config["rope_h_extrapolation_ratio"] = 2.0
dit_config["rope_w_extrapolation_ratio"] = 2.0
dit_config["rope_t_extrapolation_ratio"] = 0.8333333333333334
dit_config["extra_h_extrapolation_ratio"] = 1.0
dit_config["extra_w_extrapolation_ratio"] = 1.0
dit_config["extra_t_extrapolation_ratio"] = 1.0
dit_config["rope_enable_fps_modulation"] = False
return dit_config
if '{}time_caption_embed.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: # Omnigen2
dit_config = {}
dit_config["image_model"] = "omnigen2"
dit_config["axes_dim_rope"] = [40, 40, 40]
dit_config["axes_lens"] = [1024, 1664, 1664]
dit_config["ffn_dim_multiplier"] = None
dit_config["hidden_size"] = 2520
dit_config["in_channels"] = 16
dit_config["multiple_of"] = 256
dit_config["norm_eps"] = 1e-05
dit_config["num_attention_heads"] = 21
dit_config["num_kv_heads"] = 7
dit_config["num_layers"] = 32
dit_config["num_refiner_layers"] = 2
dit_config["out_channels"] = None
dit_config["patch_size"] = 2
dit_config["text_feat_dim"] = 2048
dit_config["timestep_scale"] = 1000.0
return dit_config
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
return None

View File

@@ -295,14 +295,24 @@ except:
pass
SUPPORT_FP8_OPS = args.supports_fp8_compute
try:
if is_amd():
try:
rocm_version = tuple(map(int, str(torch.version.hip).split(".")[:2]))
except:
rocm_version = (6, -1)
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
logging.info("AMD arch: {}".format(arch))
logging.info("ROCm version: {}".format(rocm_version))
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
if torch_version_numeric[0] >= 2 and torch_version_numeric[1] >= 7: # works on 2.6 but doesn't actually seem to improve much
if any((a in arch) for a in ["gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx1201 and gfx950
ENABLE_PYTORCH_ATTENTION = True
if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4):
if any((a in arch) for a in ["gfx1201", "gfx942", "gfx950"]): # TODO: more arches
SUPPORT_FP8_OPS = True
except:
pass
@@ -323,7 +333,7 @@ except:
pass
try:
if torch_version_numeric[0] == 2 and torch_version_numeric[1] >= 5:
if torch_version_numeric >= (2, 5):
torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)
except:
logging.warning("Warning, could not set allow_fp16_bf16_reduction_math_sdp")
@@ -1042,7 +1052,7 @@ def pytorch_attention_flash_attention():
global ENABLE_PYTORCH_ATTENTION
if ENABLE_PYTORCH_ATTENTION:
#TODO: more reliable way of checking for flash attention?
if is_nvidia(): #pytorch flash attention only works on Nvidia
if is_nvidia():
return True
if is_intel_xpu():
return True
@@ -1058,7 +1068,7 @@ def force_upcast_attention_dtype():
upcast = args.force_upcast_attention
macos_version = mac_version()
if macos_version is not None and ((14, 5) <= macos_version < (16,)): # black image bug on recent versions of macOS
if macos_version is not None and ((14, 5) <= macos_version): # black image bug on recent versions of macOS, I don't think it's ever getting fixed
upcast = True
if upcast:
@@ -1257,7 +1267,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
return False
def supports_fp8_compute(device=None):
if args.supports_fp8_compute:
if SUPPORT_FP8_OPS:
return True
if not is_nvidia():
@@ -1271,15 +1281,22 @@ def supports_fp8_compute(device=None):
if props.minor < 9:
return False
if torch_version_numeric[0] < 2 or (torch_version_numeric[0] == 2 and torch_version_numeric[1] < 3):
if torch_version_numeric < (2, 3):
return False
if WINDOWS:
if (torch_version_numeric[0] == 2 and torch_version_numeric[1] < 4):
if torch_version_numeric < (2, 4):
return False
return True
def extended_fp16_support():
# TODO: check why some models work with fp16 on newer torch versions but not on older
if torch_version_numeric < (2, 7):
return False
return True
def soft_empty_cache(force=False):
global cpu_state
if cpu_state == CPUState.MPS:

View File

@@ -17,23 +17,26 @@
"""
from __future__ import annotations
from typing import Optional, Callable
import torch
import collections
import copy
import inspect
import logging
import uuid
import collections
import math
import uuid
from typing import Callable, Optional
import torch
import comfy.utils
import comfy.float
import comfy.model_management
import comfy.lora
import comfy.hooks
import comfy.lora
import comfy.model_management
import comfy.patcher_extension
from comfy.patcher_extension import CallbacksMP, WrappersMP, PatcherInjection
import comfy.utils
from comfy.comfy_types import UnetWrapperFunction
from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
def string_to_seed(data):
crc = 0xFFFFFFFF
@@ -376,6 +379,9 @@ class ModelPatcher:
def set_model_sampler_pre_cfg_function(self, pre_cfg_function, disable_cfg1_optimization=False):
self.model_options = set_model_options_pre_cfg_function(self.model_options, pre_cfg_function, disable_cfg1_optimization)
def set_model_sampler_calc_cond_batch_function(self, sampler_calc_cond_batch_function):
self.model_options["sampler_calc_cond_batch_function"] = sampler_calc_cond_batch_function
def set_model_unet_function_wrapper(self, unet_wrapper_function: UnetWrapperFunction):
self.model_options["model_function_wrapper"] = unet_wrapper_function

View File

@@ -77,6 +77,25 @@ class IMG_TO_IMG(X0):
def calculate_input(self, sigma, noise):
return noise
class COSMOS_RFLOW:
def calculate_input(self, sigma, noise):
sigma = (sigma / (sigma + 1))
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
return noise * (1.0 - sigma)
def calculate_denoised(self, sigma, model_output, model_input):
sigma = (sigma / (sigma + 1))
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
return model_input * (1.0 - sigma) - model_output * sigma
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
noise = noise * sigma
noise += latent_image
return noise
def inverse_noise_scaling(self, sigma, latent):
return latent
class ModelSamplingDiscrete(torch.nn.Module):
def __init__(self, model_config=None, zsnr=None):
@@ -350,3 +369,15 @@ class ModelSamplingFlux(torch.nn.Module):
if percent >= 1.0:
return 0.0
return flux_time_shift(self.shift, 1.0, 1.0 - percent)
class ModelSamplingCosmosRFlow(ModelSamplingContinuousEDM):
def timestep(self, sigma):
return sigma / (sigma + 1)
def sigma(self, timestep):
sigma_max = self.sigma_max
if timestep >= (sigma_max / (sigma_max + 1)):
return sigma_max
return timestep / (1 - timestep)

View File

@@ -336,9 +336,12 @@ class fp8_ops(manual_cast):
return None
def forward_comfy_cast_weights(self, input):
out = fp8_linear(self, input)
if out is not None:
return out
try:
out = fp8_linear(self, input)
if out is not None:
return out
except Exception as e:
logging.info("Exception during fp8 op: {}".format(e))
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.linear(input, weight, bias)

View File

@@ -373,7 +373,11 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
uncond_ = uncond
conds = [cond, uncond_]
out = calc_cond_batch(model, conds, x, timestep, model_options)
if "sampler_calc_cond_batch_function" in model_options:
args = {"conds": conds, "input": x, "sigma": timestep, "model": model, "model_options": model_options}
out = model_options["sampler_calc_cond_batch_function"](args)
else:
out = calc_cond_batch(model, conds, x, timestep, model_options)
for fn in model_options.get("sampler_pre_cfg_function", []):
args = {"conds":conds, "conds_out": out, "cond_scale": cond_scale, "timestep": timestep,
@@ -716,7 +720,7 @@ KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_c
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
"ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
"gradient_estimation", "gradient_estimation_cfg_pp", "er_sde", "seeds_2", "seeds_3"]
"gradient_estimation", "gradient_estimation_cfg_pp", "er_sde", "seeds_2", "seeds_3", "sa_solver", "sa_solver_pece"]
class KSAMPLER(Sampler):
def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
@@ -1039,13 +1043,13 @@ class SchedulerHandler(NamedTuple):
use_ms: bool = True
SCHEDULER_HANDLERS = {
"normal": SchedulerHandler(normal_scheduler),
"simple": SchedulerHandler(simple_scheduler),
"sgm_uniform": SchedulerHandler(partial(normal_scheduler, sgm=True)),
"karras": SchedulerHandler(k_diffusion_sampling.get_sigmas_karras, use_ms=False),
"exponential": SchedulerHandler(k_diffusion_sampling.get_sigmas_exponential, use_ms=False),
"sgm_uniform": SchedulerHandler(partial(normal_scheduler, sgm=True)),
"simple": SchedulerHandler(simple_scheduler),
"ddim_uniform": SchedulerHandler(ddim_scheduler),
"beta": SchedulerHandler(beta_scheduler),
"normal": SchedulerHandler(normal_scheduler),
"linear_quadratic": SchedulerHandler(linear_quadratic_schedule),
"kl_optimal": SchedulerHandler(kl_optimal_scheduler, use_ms=False),
}

View File

@@ -18,6 +18,7 @@ import comfy.ldm.hunyuan3d.vae
import comfy.ldm.ace.vae.music_dcae_pipeline
import yaml
import math
import os
import comfy.utils
@@ -44,6 +45,7 @@ import comfy.text_encoders.lumina2
import comfy.text_encoders.wan
import comfy.text_encoders.hidream
import comfy.text_encoders.ace
import comfy.text_encoders.omnigen2
import comfy.model_patcher
import comfy.lora
@@ -754,6 +756,7 @@ class CLIPType(Enum):
HIDREAM = 14
CHROMA = 15
ACE = 16
OMNIGEN2 = 17
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
@@ -773,6 +776,7 @@ class TEModel(Enum):
LLAMA3_8 = 7
T5_XXL_OLD = 8
GEMMA_2_2B = 9
QWEN25_3B = 10
def detect_te_model(sd):
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
@@ -793,6 +797,8 @@ def detect_te_model(sd):
return TEModel.T5_BASE
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
return TEModel.GEMMA_2_2B
if 'model.layers.0.self_attn.k_proj.bias' in sd:
return TEModel.QWEN25_3B
if "model.layers.0.post_attention_layernorm.weight" in sd:
return TEModel.LLAMA3_8
return None
@@ -894,6 +900,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data),
clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None, t5xxl_scaled_fp8=None)
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
elif te_model == TEModel.QWEN25_3B:
clip_target.clip = comfy.text_encoders.omnigen2.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.omnigen2.Omnigen2Tokenizer
else:
# clip_l
if clip_type == CLIPType.SD3:
@@ -969,6 +978,12 @@ def load_gligen(ckpt_path):
model = model.half()
return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())
def model_detection_error_hint(path, state_dict):
filename = os.path.basename(path)
if 'lora' in filename.lower():
return "\nHINT: This seems to be a Lora file and Lora files should be put in the lora folder and loaded with a lora loader node.."
return ""
def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, embedding_directory=None, state_dict=None, config=None):
logging.warning("Warning: The load checkpoint with config function is deprecated and will eventually be removed, please use the other one.")
model, clip, vae, _ = load_checkpoint_guess_config(ckpt_path, output_vae=output_vae, output_clip=output_clip, output_clipvision=False, embedding_directory=embedding_directory, output_model=True)
@@ -997,7 +1012,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True)
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata)
if out is None:
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd)))
return out
def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None):
@@ -1081,7 +1096,28 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
return (model_patcher, clip, vae, clipvision)
def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffusers or regular format
def load_diffusion_model_state_dict(sd, model_options={}):
"""
Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats.
Args:
sd (dict): State dictionary containing model weights and configuration
model_options (dict, optional): Additional options for model loading. Supports:
- dtype: Override model data type
- custom_operations: Custom model operations
- fp8_optimizations: Enable FP8 optimizations
Returns:
ModelPatcher: A wrapped model instance that handles device management and weight loading.
Returns None if the model configuration cannot be detected.
The function:
1. Detects and handles different model formats (regular, diffusers, mmdit)
2. Configures model dtype based on parameters and device capabilities
3. Handles weight conversion and device placement
4. Manages model optimization settings
5. Loads weights and returns a device-managed model instance
"""
dtype = model_options.get("dtype", None)
#Allow loading unets from checkpoint files
@@ -1139,7 +1175,7 @@ def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffuse
model.load_model_weights(new_sd, "")
left_over = sd.keys()
if len(left_over) > 0:
logging.info("left over keys in unet: {}".format(left_over))
logging.info("left over keys in diffusion model: {}".format(left_over))
return comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)
@@ -1147,8 +1183,8 @@ def load_diffusion_model(unet_path, model_options={}):
sd = comfy.utils.load_torch_file(unet_path)
model = load_diffusion_model_state_dict(sd, model_options=model_options)
if model is None:
logging.error("ERROR UNSUPPORTED UNET {}".format(unet_path))
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path))
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd)))
return model
def load_unet(unet_path, dtype=None):

View File

@@ -462,7 +462,7 @@ class SDTokenizer:
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args)
self.max_length = tokenizer_data.get("{}_max_length".format(embedding_key), max_length)
self.min_length = min_length
self.min_length = tokenizer_data.get("{}_min_length".format(embedding_key), min_length)
self.end_token = None
self.min_padding = min_padding
@@ -482,7 +482,8 @@ class SDTokenizer:
if end_token is not None:
self.end_token = end_token
else:
self.end_token = empty[0]
if has_end_token:
self.end_token = empty[0]
if pad_token is not None:
self.pad_token = pad_token

View File

@@ -18,7 +18,7 @@
"single_word": false
},
"errors": "replace",
"model_max_length": 77,
"model_max_length": 8192,
"name_or_path": "openai/clip-vit-large-patch14",
"pad_token": "<|endoftext|>",
"special_tokens_map_file": "./special_tokens_map.json",

View File

@@ -18,6 +18,7 @@ import comfy.text_encoders.cosmos
import comfy.text_encoders.lumina2
import comfy.text_encoders.wan
import comfy.text_encoders.ace
import comfy.text_encoders.omnigen2
from . import supported_models_base
from . import latent_formats
@@ -908,6 +909,48 @@ class CosmosI2V(CosmosT2V):
out = model_base.CosmosVideo(self, image_to_video=True, device=device)
return out
class CosmosT2IPredict2(supported_models_base.BASE):
unet_config = {
"image_model": "cosmos_predict2",
"in_channels": 16,
}
sampling_settings = {
"sigma_data": 1.0,
"sigma_max": 80.0,
"sigma_min": 0.002,
}
unet_extra_config = {}
latent_format = latent_formats.Wan21
memory_usage_factor = 1.0
supported_inference_dtypes = [torch.bfloat16, torch.float32]
def __init__(self, unet_config):
super().__init__(unet_config)
self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.9
def get_model(self, state_dict, prefix="", device=None):
out = model_base.CosmosPredict2(self, device=device)
return out
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.cosmos.CosmosT5Tokenizer, comfy.text_encoders.cosmos.te(**t5_detect))
class CosmosI2VPredict2(CosmosT2IPredict2):
unet_config = {
"image_model": "cosmos_predict2",
"in_channels": 17,
}
def get_model(self, state_dict, prefix="", device=None):
out = model_base.CosmosPredict2(self, image_to_video=True, device=device)
return out
class Lumina2(supported_models_base.BASE):
unet_config = {
"image_model": "lumina2",
@@ -1139,6 +1182,41 @@ class ACEStep(supported_models_base.BASE):
def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(comfy.text_encoders.ace.AceT5Tokenizer, comfy.text_encoders.ace.AceT5Model)
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep]
class Omnigen2(supported_models_base.BASE):
unet_config = {
"image_model": "omnigen2",
}
sampling_settings = {
"multiplier": 1.0,
"shift": 2.6,
}
memory_usage_factor = 1.65 #TODO
unet_extra_config = {}
latent_format = latent_formats.Flux
supported_inference_dtypes = [torch.bfloat16, torch.float32]
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def __init__(self, unet_config):
super().__init__(unet_config)
if comfy.model_management.extended_fp16_support():
self.supported_inference_dtypes = [torch.float16] + self.supported_inference_dtypes
def get_model(self, state_dict, prefix="", device=None):
out = model_base.Omnigen2(self, device=device)
return out
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_3b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.omnigen2.Omnigen2Tokenizer, comfy.text_encoders.omnigen2.te(**hunyuan_detect))
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2]
models += [SVD_img2vid]

View File

@@ -24,6 +24,24 @@ class Llama2Config:
head_dim = 128
rms_norm_add = False
mlp_activation = "silu"
qkv_bias = False
@dataclass
class Qwen25_3BConfig:
vocab_size: int = 151936
hidden_size: int = 2048
intermediate_size: int = 11008
num_hidden_layers: int = 36
num_attention_heads: int = 16
num_key_value_heads: int = 2
max_position_embeddings: int = 128000
rms_norm_eps: float = 1e-6
rope_theta: float = 1000000.0
transformer_type: str = "llama"
head_dim = 128
rms_norm_add = False
mlp_activation = "silu"
qkv_bias = True
@dataclass
class Gemma2_2B_Config:
@@ -40,6 +58,7 @@ class Gemma2_2B_Config:
head_dim = 256
rms_norm_add = True
mlp_activation = "gelu_pytorch_tanh"
qkv_bias = False
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
@@ -98,9 +117,9 @@ class Attention(nn.Module):
self.inner_size = self.num_heads * self.head_dim
ops = ops or nn
self.q_proj = ops.Linear(config.hidden_size, self.inner_size, bias=False, device=device, dtype=dtype)
self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False, device=device, dtype=dtype)
self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False, device=device, dtype=dtype)
self.q_proj = ops.Linear(config.hidden_size, self.inner_size, bias=config.qkv_bias, device=device, dtype=dtype)
self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
self.o_proj = ops.Linear(self.inner_size, config.hidden_size, bias=False, device=device, dtype=dtype)
def forward(
@@ -320,6 +339,14 @@ class Llama2(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Qwen25_3B(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Qwen25_3BConfig(**config_dict)
self.num_layers = config.num_hidden_layers
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Gemma2_2B(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):

View File

@@ -0,0 +1,44 @@
from transformers import Qwen2Tokenizer
from comfy import sd1_clip
import comfy.text_encoders.llama
import os
class Qwen25_3BTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='qwen25_3b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)
class Omnigen2Tokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen25_3b", tokenizer=Qwen25_3BTokenizer)
self.llama_template = '<|im_start|>system\nYou are a helpful assistant that generates high-quality images based on user instructions.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n'
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None,**kwargs):
if llama_template is None:
llama_text = self.llama_template.format(text)
else:
llama_text = llama_template.format(text)
return super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, **kwargs)
class Qwen25_3BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen25_3B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class Omnigen2Model(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, name="qwen25_3b", clip_model=Qwen25_3BModel, model_options=model_options)
def te(dtype_llama=None, llama_scaled_fp8=None):
class Omnigen2TEModel_(Omnigen2Model):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["scaled_fp8"] = llama_scaled_fp8
if dtype_llama is not None:
dtype = dtype_llama
super().__init__(device=device, dtype=dtype, model_options=model_options)
return Omnigen2TEModel_

View File

@@ -1,42 +1,42 @@
import os
from comfy import sd1_clip
import comfy.text_encoders.t5
import comfy.text_encoders.sd3_clip
from comfy.sd1_clip import gen_empty_tokens
from transformers import T5TokenizerFast
class T5XXLModel(comfy.text_encoders.sd3_clip.T5XXLModel):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def gen_empty_tokens(self, special_tokens, *args, **kwargs):
# PixArt expects the negative to be all pad tokens
special_tokens = special_tokens.copy()
special_tokens.pop("end")
return gen_empty_tokens(special_tokens, *args, **kwargs)
class PixArtT5XXL(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, name="t5xxl", clip_model=T5XXLModel, model_options=model_options)
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_data=tokenizer_data) # no padding
class PixArtTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
def pixart_te(dtype_t5=None, t5xxl_scaled_fp8=None):
class PixArtTEModel_(PixArtT5XXL):
def __init__(self, device="cpu", dtype=None, model_options={}):
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
if dtype is None:
dtype = dtype_t5
super().__init__(device=device, dtype=dtype, model_options=model_options)
return PixArtTEModel_
import os
from comfy import sd1_clip
import comfy.text_encoders.t5
import comfy.text_encoders.sd3_clip
from comfy.sd1_clip import gen_empty_tokens
from transformers import T5TokenizerFast
class T5XXLModel(comfy.text_encoders.sd3_clip.T5XXLModel):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def gen_empty_tokens(self, special_tokens, *args, **kwargs):
# PixArt expects the negative to be all pad tokens
special_tokens = special_tokens.copy()
special_tokens.pop("end")
return gen_empty_tokens(special_tokens, *args, **kwargs)
class PixArtT5XXL(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, name="t5xxl", clip_model=T5XXLModel, model_options=model_options)
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_data=tokenizer_data) # no padding
class PixArtTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
def pixart_te(dtype_t5=None, t5xxl_scaled_fp8=None):
class PixArtTEModel_(PixArtT5XXL):
def __init__(self, device="cpu", dtype=None, model_options={}):
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
if dtype is None:
dtype = dtype_t5
super().__init__(device=device, dtype=dtype, model_options=model_options)
return PixArtTEModel_

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,241 @@
{
"add_bos_token": false,
"add_prefix_space": false,
"added_tokens_decoder": {
"151643": {
"content": "<|endoftext|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151644": {
"content": "<|im_start|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151645": {
"content": "<|im_end|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151646": {
"content": "<|object_ref_start|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151647": {
"content": "<|object_ref_end|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151648": {
"content": "<|box_start|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151649": {
"content": "<|box_end|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151650": {
"content": "<|quad_start|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151651": {
"content": "<|quad_end|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151652": {
"content": "<|vision_start|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151653": {
"content": "<|vision_end|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151654": {
"content": "<|vision_pad|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151655": {
"content": "<|image_pad|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151656": {
"content": "<|video_pad|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151657": {
"content": "<tool_call>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151658": {
"content": "</tool_call>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151659": {
"content": "<|fim_prefix|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151660": {
"content": "<|fim_middle|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151661": {
"content": "<|fim_suffix|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151662": {
"content": "<|fim_pad|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151663": {
"content": "<|repo_name|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151664": {
"content": "<|file_sep|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151665": {
"content": "<|img|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151666": {
"content": "<|endofimg|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151667": {
"content": "<|meta|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151668": {
"content": "<|endofmeta|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
}
},
"additional_special_tokens": [
"<|im_start|>",
"<|im_end|>",
"<|object_ref_start|>",
"<|object_ref_end|>",
"<|box_start|>",
"<|box_end|>",
"<|quad_start|>",
"<|quad_end|>",
"<|vision_start|>",
"<|vision_end|>",
"<|vision_pad|>",
"<|image_pad|>",
"<|video_pad|>"
],
"bos_token": null,
"chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are a helpful assistant.' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n",
"clean_up_tokenization_spaces": false,
"eos_token": "<|im_end|>",
"errors": "replace",
"extra_special_tokens": {},
"model_max_length": 131072,
"pad_token": "<|endoftext|>",
"processor_class": "Qwen2_5_VLProcessor",
"split_special_tokens": false,
"tokenizer_class": "Qwen2Tokenizer",
"unk_token": null
}

File diff suppressed because one or more lines are too long

View File

@@ -146,7 +146,7 @@ class T5Attention(torch.nn.Module):
)
values = self.relative_attention_bias(relative_position_bucket, out_dtype=dtype) # shape (query_length, key_length, num_heads)
values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
return values
return values.contiguous()
def forward(self, x, mask=None, past_bias=None, optimized_attention=None):
q = self.q(x)

View File

@@ -31,6 +31,7 @@ from einops import rearrange
from comfy.cli_args import args
MMAP_TORCH_FILES = args.mmap_torch_files
DISABLE_MMAP = args.disable_mmap
ALWAYS_SAFE_LOAD = False
if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in pytorch 2.4, the unsafe path should be removed once earlier versions are deprecated
@@ -58,7 +59,10 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
sd = {}
for k in f.keys():
sd[k] = f.get_tensor(k)
tensor = f.get_tensor(k)
if DISABLE_MMAP: # TODO: Not sure if this is the best way to bypass the mmap issues
tensor = tensor.to(device=device, copy=True)
sd[k] = tensor
if return_metadata:
metadata = f.metadata()
except Exception as e:
@@ -77,6 +81,7 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
if safe_load or ALWAYS_SAFE_LOAD:
pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args)
else:
logging.warning("WARNING: loading {} unsafely, upgrade your pytorch to 2.4 or newer to load this file safely.".format(ckpt))
pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle)
if "state_dict" in pl_sd:
sd = pl_sd["state_dict"]
@@ -997,11 +1002,12 @@ def set_progress_bar_global_hook(function):
PROGRESS_BAR_HOOK = function
class ProgressBar:
def __init__(self, total):
def __init__(self, total, node_id=None):
global PROGRESS_BAR_HOOK
self.total = total
self.current = 0
self.hook = PROGRESS_BAR_HOOK
self.node_id = node_id
def update_absolute(self, value, total=None, preview=None):
if total is not None:
@@ -1010,7 +1016,7 @@ class ProgressBar:
value = self.total
self.current = value
if self.hook is not None:
self.hook(self.current, self.total, preview)
self.hook(self.current, self.total, preview, node_id=self.node_id)
def update(self, value):
self.update_absolute(self.current + value)

View File

@@ -1,4 +1,4 @@
from .base import WeightAdapterBase
from .base import WeightAdapterBase, WeightAdapterTrainBase
from .lora import LoRAAdapter
from .loha import LoHaAdapter
from .lokr import LoKrAdapter
@@ -15,3 +15,9 @@ adapters: list[type[WeightAdapterBase]] = [
OFTAdapter,
BOFTAdapter,
]
__all__ = [
"WeightAdapterBase",
"WeightAdapterTrainBase",
"adapters"
] + [a.__name__ for a in adapters]

View File

@@ -12,12 +12,20 @@ class WeightAdapterBase:
weights: list[torch.Tensor]
@classmethod
def load(cls, x: str, lora: dict[str, torch.Tensor]) -> Optional["WeightAdapterBase"]:
def load(cls, x: str, lora: dict[str, torch.Tensor], alpha: float, dora_scale: torch.Tensor) -> Optional["WeightAdapterBase"]:
raise NotImplementedError
def to_train(self) -> "WeightAdapterTrainBase":
raise NotImplementedError
@classmethod
def create_train(cls, weight, *args) -> "WeightAdapterTrainBase":
"""
weight: The original weight tensor to be modified.
*args: Additional arguments for configuration, such as rank, alpha etc.
"""
raise NotImplementedError
def calculate_weight(
self,
weight,
@@ -33,10 +41,22 @@ class WeightAdapterBase:
class WeightAdapterTrainBase(nn.Module):
# We follow the scheme of PR #7032
def __init__(self):
super().__init__()
# [TODO] Collaborate with LoRA training PR #7032
def __call__(self, w):
"""
w: The original weight tensor to be modified.
"""
raise NotImplementedError
def passive_memory_usage(self):
raise NotImplementedError("passive_memory_usage is not implemented")
def move_to(self, device):
self.to(device)
return self.passive_memory_usage()
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function):
@@ -102,3 +122,14 @@ def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Ten
padded_tensor[new_slices] = tensor[orig_slices]
return padded_tensor
def tucker_weight_from_conv(up, down, mid):
up = up.reshape(up.size(0), up.size(1))
down = down.reshape(down.size(0), down.size(1))
return torch.einsum("m n ..., i m, n j -> i j ...", mid, up, down)
def tucker_weight(wa, wb, t):
temp = torch.einsum("i j ..., j r -> i r ...", t, wb)
return torch.einsum("i j ..., i r -> r j ...", temp, wa)

View File

@@ -3,7 +3,56 @@ from typing import Optional
import torch
import comfy.model_management
from .base import WeightAdapterBase, weight_decompose, pad_tensor_to_shape
from .base import (
WeightAdapterBase,
WeightAdapterTrainBase,
weight_decompose,
pad_tensor_to_shape,
tucker_weight_from_conv,
)
class LoraDiff(WeightAdapterTrainBase):
def __init__(self, weights):
super().__init__()
mat1, mat2, alpha, mid, dora_scale, reshape = weights
out_dim, rank = mat1.shape[0], mat1.shape[1]
rank, in_dim = mat2.shape[0], mat2.shape[1]
if mid is not None:
convdim = mid.ndim - 2
layer = (
torch.nn.Conv1d,
torch.nn.Conv2d,
torch.nn.Conv3d
)[convdim]
else:
layer = torch.nn.Linear
self.lora_up = layer(rank, out_dim, bias=False)
self.lora_down = layer(in_dim, rank, bias=False)
self.lora_up.weight.data.copy_(mat1)
self.lora_down.weight.data.copy_(mat2)
if mid is not None:
self.lora_mid = layer(mid, rank, bias=False)
self.lora_mid.weight.data.copy_(mid)
else:
self.lora_mid = None
self.rank = rank
self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False)
def __call__(self, w):
org_dtype = w.dtype
if self.lora_mid is None:
diff = self.lora_up.weight @ self.lora_down.weight
else:
diff = tucker_weight_from_conv(
self.lora_up.weight, self.lora_down.weight, self.lora_mid.weight
)
scale = self.alpha / self.rank
weight = w + scale * diff.reshape(w.shape)
return weight.to(org_dtype)
def passive_memory_usage(self):
return sum(param.numel() * param.element_size() for param in self.parameters())
class LoRAAdapter(WeightAdapterBase):
@@ -13,6 +62,21 @@ class LoRAAdapter(WeightAdapterBase):
self.loaded_keys = loaded_keys
self.weights = weights
@classmethod
def create_train(cls, weight, rank=1, alpha=1.0):
out_dim = weight.shape[0]
in_dim = weight.shape[1:].numel()
mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype)
mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype)
torch.nn.init.kaiming_uniform_(mat1, a=5**0.5)
torch.nn.init.constant_(mat2, 0.0)
return LoraDiff(
(mat1, mat2, alpha, None, None, None)
)
def to_train(self):
return LoraDiff(self.weights)
@classmethod
def load(
cls,

View File

@@ -0,0 +1,69 @@
"""
Feature flags module for ComfyUI WebSocket protocol negotiation.
This module handles capability negotiation between frontend and backend,
allowing graceful protocol evolution while maintaining backward compatibility.
"""
from typing import Any, Dict
from comfy.cli_args import args
# Default server capabilities
SERVER_FEATURE_FLAGS: Dict[str, Any] = {
"supports_preview_metadata": True,
"max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes
}
def get_connection_feature(
sockets_metadata: Dict[str, Dict[str, Any]],
sid: str,
feature_name: str,
default: Any = False
) -> Any:
"""
Get a feature flag value for a specific connection.
Args:
sockets_metadata: Dictionary of socket metadata
sid: Session ID of the connection
feature_name: Name of the feature to check
default: Default value if feature not found
Returns:
Feature value or default if not found
"""
if sid not in sockets_metadata:
return default
return sockets_metadata[sid].get("feature_flags", {}).get(feature_name, default)
def supports_feature(
sockets_metadata: Dict[str, Dict[str, Any]],
sid: str,
feature_name: str
) -> bool:
"""
Check if a connection supports a specific feature.
Args:
sockets_metadata: Dictionary of socket metadata
sid: Session ID of the connection
feature_name: Name of the feature to check
Returns:
Boolean indicating if feature is supported
"""
return get_connection_feature(sockets_metadata, sid, feature_name, False) is True
def get_server_features() -> Dict[str, Any]:
"""
Get the server's feature flags.
Returns:
Dictionary of server feature flags
"""
return SERVER_FEATURE_FLAGS.copy()

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Optional
from typing import Optional, Union
import io
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
class VideoInput(ABC):
@@ -31,6 +32,22 @@ class VideoInput(ABC):
"""
pass
def get_stream_source(self) -> Union[str, io.BytesIO]:
"""
Get a streamable source for the video. This allows processing without
loading the entire video into memory.
Returns:
Either a file path (str) or a BytesIO object that can be opened with av.
Default implementation creates a BytesIO buffer, but subclasses should
override this for better performance when possible.
"""
buffer = io.BytesIO()
self.save_to(buffer)
buffer.seek(0)
return buffer
# Provide a default implementation, but subclasses can provide optimized versions
# if possible.
def get_dimensions(self) -> tuple[int, int]:

View File

@@ -64,6 +64,15 @@ class VideoFromFile(VideoInput):
"""
self.__file = file
def get_stream_source(self) -> str | io.BytesIO:
"""
Return the underlying file source for efficient streaming.
This avoids unnecessary memory copies when the source is already a file path.
"""
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0)
return self.__file
def get_dimensions(self) -> tuple[int, int]:
"""
Returns the dimensions of the video input.

View File

@@ -1,6 +1,6 @@
# generated by datamodel-codegen:
# filename: filtered-openapi.yaml
# timestamp: 2025-05-19T21:38:55+00:00
# timestamp: 2025-07-06T09:47:31+00:00
from __future__ import annotations
@@ -1355,6 +1355,158 @@ class ModelResponseProperties(BaseModel):
)
class Keyframes(BaseModel):
image_url: Optional[str] = None
class MoonvalleyPromptResponse(BaseModel):
error: Optional[Dict[str, Any]] = None
frame_conditioning: Optional[Dict[str, Any]] = None
id: Optional[str] = None
inference_params: Optional[Dict[str, Any]] = None
meta: Optional[Dict[str, Any]] = None
model_params: Optional[Dict[str, Any]] = None
output_url: Optional[str] = None
prompt_text: Optional[str] = None
status: Optional[str] = None
class MoonvalleyTextToVideoInferenceParams(BaseModel):
add_quality_guidance: Optional[bool] = Field(
True, description='Whether to add quality guidance'
)
caching_coefficient: Optional[float] = Field(
0.3, description='Caching coefficient for optimization'
)
caching_cooldown: Optional[int] = Field(
3, description='Number of caching cooldown steps'
)
caching_warmup: Optional[int] = Field(
3, description='Number of caching warmup steps'
)
clip_value: Optional[float] = Field(
3, description='CLIP value for generation control'
)
conditioning_frame_index: Optional[int] = Field(
0, description='Index of the conditioning frame'
)
cooldown_steps: Optional[int] = Field(
None, description='Number of cooldown steps (calculated based on num_frames)'
)
fps: Optional[int] = Field(
24, description='Frames per second of the generated video'
)
guidance_scale: Optional[float] = Field(
12.5, description='Guidance scale for generation control'
)
height: Optional[int] = Field(
1080, description='Height of the generated video in pixels'
)
negative_prompt: Optional[str] = Field(None, description='Negative prompt text')
num_frames: Optional[int] = Field(64, description='Number of frames to generate')
seed: Optional[int] = Field(
None, description='Random seed for generation (default: random)'
)
shift_value: Optional[float] = Field(
3, description='Shift value for generation control'
)
steps: Optional[int] = Field(80, description='Number of denoising steps')
use_guidance_schedule: Optional[bool] = Field(
True, description='Whether to use guidance scheduling'
)
use_negative_prompts: Optional[bool] = Field(
False, description='Whether to use negative prompts'
)
use_timestep_transform: Optional[bool] = Field(
True, description='Whether to use timestep transformation'
)
warmup_steps: Optional[int] = Field(
None, description='Number of warmup steps (calculated based on num_frames)'
)
width: Optional[int] = Field(
1920, description='Width of the generated video in pixels'
)
class MoonvalleyTextToVideoRequest(BaseModel):
image_url: Optional[str] = None
inference_params: Optional[MoonvalleyTextToVideoInferenceParams] = None
prompt_text: Optional[str] = None
webhook_url: Optional[str] = None
class MoonvalleyUploadFileRequest(BaseModel):
file: Optional[StrictBytes] = None
class MoonvalleyUploadFileResponse(BaseModel):
access_url: Optional[str] = None
class MoonvalleyVideoToVideoInferenceParams(BaseModel):
add_quality_guidance: Optional[bool] = Field(
True, description='Whether to add quality guidance'
)
caching_coefficient: Optional[float] = Field(
0.3, description='Caching coefficient for optimization'
)
caching_cooldown: Optional[int] = Field(
3, description='Number of caching cooldown steps'
)
caching_warmup: Optional[int] = Field(
3, description='Number of caching warmup steps'
)
clip_value: Optional[float] = Field(
3, description='CLIP value for generation control'
)
conditioning_frame_index: Optional[int] = Field(
0, description='Index of the conditioning frame'
)
cooldown_steps: Optional[int] = Field(
None, description='Number of cooldown steps (calculated based on num_frames)'
)
guidance_scale: Optional[float] = Field(
12.5, description='Guidance scale for generation control'
)
negative_prompt: Optional[str] = Field(None, description='Negative prompt text')
seed: Optional[int] = Field(
None, description='Random seed for generation (default: random)'
)
shift_value: Optional[float] = Field(
3, description='Shift value for generation control'
)
steps: Optional[int] = Field(80, description='Number of denoising steps')
use_guidance_schedule: Optional[bool] = Field(
True, description='Whether to use guidance scheduling'
)
use_negative_prompts: Optional[bool] = Field(
False, description='Whether to use negative prompts'
)
use_timestep_transform: Optional[bool] = Field(
True, description='Whether to use timestep transformation'
)
warmup_steps: Optional[int] = Field(
None, description='Number of warmup steps (calculated based on num_frames)'
)
class ControlType(str, Enum):
motion_control = 'motion_control'
pose_control = 'pose_control'
class MoonvalleyVideoToVideoRequest(BaseModel):
control_type: ControlType = Field(
..., description='Supported types for video control'
)
inference_params: Optional[MoonvalleyVideoToVideoInferenceParams] = None
prompt_text: str = Field(..., description='Describes the video to generate')
video_url: str = Field(..., description='Url to control video')
webhook_url: Optional[str] = Field(
None, description='Optional webhook URL for notifications'
)
class Moderation(str, Enum):
low = 'low'
auto = 'auto'
@@ -3107,6 +3259,23 @@ class LumaUpscaleVideoGenerationRequest(BaseModel):
resolution: Optional[LumaVideoModelOutputResolution] = None
class MoonvalleyImageToVideoRequest(MoonvalleyTextToVideoRequest):
keyframes: Optional[Dict[str, Keyframes]] = None
class MoonvalleyResizeVideoRequest(MoonvalleyVideoToVideoRequest):
frame_position: Optional[List[int]] = Field(None, max_length=2, min_length=2)
frame_resolution: Optional[List[int]] = Field(None, max_length=2, min_length=2)
scale: Optional[List[int]] = Field(None, max_length=2, min_length=2)
class MoonvalleyTextToImageRequest(BaseModel):
image_url: Optional[str] = None
inference_params: Optional[MoonvalleyTextToVideoInferenceParams] = None
prompt_text: Optional[str] = None
webhook_url: Optional[str] = None
class OutputContent(RootModel[Union[OutputTextContent, OutputAudioContent]]):
root: Union[OutputTextContent, OutputAudioContent]

View File

@@ -125,22 +125,6 @@ class BFLFluxKontextProGenerateRequest(BaseModel):
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
)
class BFLFluxKontextMaxGenerateRequest(BaseModel):
prompt: str = Field(..., description='The text prompt for what you wannt to edit.')
input_image: Optional[str] = Field(None, description='Image to edit in base64 format')
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
guidance: confloat(ge=0.1, le=99.0) = Field(..., description='Guidance strength for the image generation process')
steps: conint(ge=1, le=150) = Field(..., description='Number of steps for the image generation process')
safety_tolerance: Optional[conint(ge=0, le=2)] = Field(
2, description='Tolerance level for input and output moderation. Between 0 and 2, 0 being most strict, 6 being least strict. Defaults to 2.'
)
output_format: Optional[BFLOutputFormat] = Field(
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
)
aspect_ratio: Optional[str] = Field(None, description='Aspect ratio of the image between 21:9 and 9:21.')
prompt_upsampling: Optional[bool] = Field(
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
)
class BFLFluxProUltraGenerateRequest(BaseModel):
prompt: str = Field(..., description='The text prompt for image generation.')

View File

@@ -327,7 +327,9 @@ class ApiClient:
ApiServerError: If the API server is unreachable but internet is working
Exception: For other request failures
"""
url = urljoin(self.base_url, path)
# Use urljoin but ensure path is relative to avoid absolute path behavior
relative_path = path.lstrip('/')
url = urljoin(self.base_url, relative_path)
self.check_auth(self.auth_token, self.comfy_api_key)
# Combine default headers with any provided headers
request_headers = self.get_headers()

View File

@@ -272,7 +272,7 @@ class FluxProUltraImageNode(ComfyNodeABC):
class FluxKontextProImageNode(ComfyNodeABC):
"""
Edits images using Flux.1 Kontext Pro via api based on prompt and resolution.
Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio.
"""
MINIMUM_RATIO = 1 / 4
@@ -321,7 +321,7 @@ class FluxKontextProImageNode(ComfyNodeABC):
"seed": (
IO.INT,
{
"default": 0,
"default": 1234,
"min": 0,
"max": 0xFFFFFFFFFFFFFFFF,
"control_after_generate": True,
@@ -346,26 +346,14 @@ class FluxKontextProImageNode(ComfyNodeABC):
},
}
@classmethod
def VALIDATE_INPUTS(cls, aspect_ratio: str):
try:
validate_aspect_ratio(
aspect_ratio,
minimum_ratio=cls.MINIMUM_RATIO,
maximum_ratio=cls.MAXIMUM_RATIO,
minimum_ratio_str=cls.MINIMUM_RATIO_STR,
maximum_ratio_str=cls.MAXIMUM_RATIO_STR,
)
except Exception as e:
return str(e)
return True
RETURN_TYPES = (IO.IMAGE,)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "api_call"
API_NODE = True
CATEGORY = "api node/image/BFL"
BFL_PATH = "/proxy/bfl/flux-kontext-pro/generate"
def api_call(
self,
prompt: str,
@@ -378,11 +366,18 @@ class FluxKontextProImageNode(ComfyNodeABC):
unique_id: Union[str, None] = None,
**kwargs,
):
aspect_ratio = validate_aspect_ratio(
aspect_ratio,
minimum_ratio=self.MINIMUM_RATIO,
maximum_ratio=self.MAXIMUM_RATIO,
minimum_ratio_str=self.MINIMUM_RATIO_STR,
maximum_ratio_str=self.MAXIMUM_RATIO_STR,
)
if input_image is None:
validate_string(prompt, strip_whitespace=False)
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/bfl/flux-kontext-pro/generate",
path=self.BFL_PATH,
method=HttpMethod.POST,
request_model=BFLFluxKontextProGenerateRequest,
response_model=BFLFluxProGenerateResponse,
@@ -393,13 +388,7 @@ class FluxKontextProImageNode(ComfyNodeABC):
guidance=round(guidance, 1),
steps=steps,
seed=seed,
aspect_ratio=validate_aspect_ratio(
aspect_ratio,
minimum_ratio=self.MINIMUM_RATIO,
maximum_ratio=self.MAXIMUM_RATIO,
minimum_ratio_str=self.MINIMUM_RATIO_STR,
maximum_ratio_str=self.MAXIMUM_RATIO_STR,
),
aspect_ratio=aspect_ratio,
input_image=(
input_image
if input_image is None
@@ -411,146 +400,15 @@ class FluxKontextProImageNode(ComfyNodeABC):
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
return (output_image,)
class FluxKontextMaxImageNode(ComfyNodeABC):
class FluxKontextMaxImageNode(FluxKontextProImageNode):
"""
Edits images using Flux.1 Kontext Max via api based on prompt and resolution.
Edits images using Flux.1 Kontext [max] via api based on prompt and aspect ratio.
"""
MINIMUM_RATIO = 1 / 4
MAXIMUM_RATIO = 4 / 1
MINIMUM_RATIO_STR = "1:4"
MAXIMUM_RATIO_STR = "4:1"
DESCRIPTION = cleandoc(__doc__ or "")
BFL_PATH = "/proxy/bfl/flux-kontext-max/generate"
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Prompt for the image generation - specify what and how to edit.",
},
),
"aspect_ratio": (
IO.STRING,
{
"default": "16:9",
"tooltip": "Aspect ratio of image; must be between 1:4 and 4:1.",
},
),
"guidance": (
IO.FLOAT,
{
"default": 3.0,
"min": 0.1,
"max": 99.0,
"step": 0.1,
"tooltip": "Guidance strength for the image generation process"
},
),
"steps": (
IO.INT,
{
"default": 50,
"min": 1,
"max": 150,
"tooltip": "Number of steps for the image generation process"
},
),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 0xFFFFFFFFFFFFFFFF,
"control_after_generate": True,
"tooltip": "The random seed used for creating the noise.",
},
),
"prompt_upsampling": (
IO.BOOLEAN,
{
"default": False,
"tooltip": "Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).",
},
),
},
"optional": {
"input_image": (IO.IMAGE,),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@classmethod
def VALIDATE_INPUTS(cls, aspect_ratio: str):
try:
validate_aspect_ratio(
aspect_ratio,
minimum_ratio=cls.MINIMUM_RATIO,
maximum_ratio=cls.MAXIMUM_RATIO,
minimum_ratio_str=cls.MINIMUM_RATIO_STR,
maximum_ratio_str=cls.MAXIMUM_RATIO_STR,
)
except Exception as e:
return str(e)
return True
RETURN_TYPES = (IO.IMAGE,)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "api_call"
API_NODE = True
CATEGORY = "api node/image/BFL"
def api_call(
self,
prompt: str,
aspect_ratio: str,
guidance: float,
steps: int,
input_image: Optional[torch.Tensor]=None,
seed=0,
prompt_upsampling=False,
unique_id: Union[str, None] = None,
**kwargs,
):
if input_image is None:
validate_string(prompt, strip_whitespace=False)
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/bfl/flux-kontext-max/generate",
method=HttpMethod.POST,
request_model=BFLFluxKontextProGenerateRequest,
response_model=BFLFluxProGenerateResponse,
),
request=BFLFluxKontextProGenerateRequest(
prompt=prompt,
prompt_upsampling=prompt_upsampling,
guidance=round(guidance, 1),
steps=steps,
seed=seed,
aspect_ratio=validate_aspect_ratio(
aspect_ratio,
minimum_ratio=self.MINIMUM_RATIO,
maximum_ratio=self.MAXIMUM_RATIO,
minimum_ratio_str=self.MINIMUM_RATIO_STR,
maximum_ratio_str=self.MAXIMUM_RATIO_STR,
),
input_image=(
input_image
if input_image is None
else convert_image_to_base64(input_image)
)
),
auth_kwargs=kwargs,
)
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
return (output_image,)
class FluxProImageNode(ComfyNodeABC):
"""
@@ -1208,8 +1066,8 @@ NODE_CLASS_MAPPINGS = {
NODE_DISPLAY_NAME_MAPPINGS = {
"FluxProUltraImageNode": "Flux 1.1 [pro] Ultra Image",
# "FluxProImageNode": "Flux 1.1 [pro] Image",
"FluxKontextProImageNode": "Flux.1 Kontext Pro Image",
"FluxKontextMaxImageNode": "Flux.1 Kontext Max Image",
"FluxKontextProImageNode": "Flux.1 Kontext [pro] Image",
"FluxKontextMaxImageNode": "Flux.1 Kontext [max] Image",
"FluxProExpandNode": "Flux.1 Expand Image",
"FluxProFillNode": "Flux.1 Fill Image",
"FluxProCannyNode": "Flux.1 Canny Control Image",

View File

@@ -406,7 +406,7 @@ class GeminiInputFiles(ComfyNodeABC):
def create_file_part(self, file_path: str) -> GeminiPart:
mime_type = (
GeminiMimeType.pdf
GeminiMimeType.application_pdf
if file_path.endswith(".pdf")
else GeminiMimeType.text_plain
)

View File

@@ -324,7 +324,7 @@ class IdeogramV1(ComfyNodeABC):
RETURN_TYPES = (IO.IMAGE,)
FUNCTION = "api_call"
CATEGORY = "api node/image/Ideogram/v1"
CATEGORY = "api node/image/Ideogram"
DESCRIPTION = cleandoc(__doc__ or "")
API_NODE = True
@@ -483,7 +483,7 @@ class IdeogramV2(ComfyNodeABC):
RETURN_TYPES = (IO.IMAGE,)
FUNCTION = "api_call"
CATEGORY = "api node/image/Ideogram/v2"
CATEGORY = "api node/image/Ideogram"
DESCRIPTION = cleandoc(__doc__ or "")
API_NODE = True
@@ -649,7 +649,7 @@ class IdeogramV3(ComfyNodeABC):
RETURN_TYPES = (IO.IMAGE,)
FUNCTION = "api_call"
CATEGORY = "api node/image/Ideogram/v3"
CATEGORY = "api node/image/Ideogram"
DESCRIPTION = cleandoc(__doc__ or "")
API_NODE = True

View File

@@ -132,6 +132,8 @@ def poll_until_finished(
result_url_extractor=result_url_extractor,
estimated_duration=estimated_duration,
node_id=node_id,
poll_interval=16.0,
max_poll_attempts=256,
).execute()

View File

@@ -0,0 +1,639 @@
import logging
from typing import Any, Callable, Optional, TypeVar
import random
import torch
from comfy_api_nodes.util.validation_utils import get_image_dimensions, validate_image_dimensions, validate_video_dimensions
from comfy_api_nodes.apis import (
MoonvalleyTextToVideoRequest,
MoonvalleyTextToVideoInferenceParams,
MoonvalleyVideoToVideoInferenceParams,
MoonvalleyVideoToVideoRequest,
MoonvalleyPromptResponse
)
from comfy_api_nodes.apis.client import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
PollingOperation,
EmptyRequest,
)
from comfy_api_nodes.apinode_utils import (
download_url_to_video_output,
upload_images_to_comfyapi,
upload_video_to_comfyapi,
)
from comfy_api_nodes.mapper_utils import model_field_to_node_input
from comfy_api.input.video_types import VideoInput
from comfy.comfy_types.node_typing import IO
from comfy_api.input_impl import VideoFromFile
import av
import io
API_UPLOADS_ENDPOINT = "/proxy/moonvalley/uploads"
API_PROMPTS_ENDPOINT = "/proxy/moonvalley/prompts"
API_VIDEO2VIDEO_ENDPOINT = "/proxy/moonvalley/prompts/video-to-video"
API_TXT2VIDEO_ENDPOINT = "/proxy/moonvalley/prompts/text-to-video"
API_IMG2VIDEO_ENDPOINT = "/proxy/moonvalley/prompts/image-to-video"
MIN_WIDTH = 300
MIN_HEIGHT = 300
MAX_WIDTH = 10000
MAX_HEIGHT = 10000
MIN_VID_WIDTH = 300
MIN_VID_HEIGHT = 300
MAX_VID_WIDTH = 10000
MAX_VID_HEIGHT = 10000
MAX_VIDEO_SIZE = 1024 * 1024 * 1024 # 1 GB max for in-memory video processing
MOONVALLEY_MAREY_MAX_PROMPT_LENGTH = 5000
R = TypeVar("R")
class MoonvalleyApiError(Exception):
"""Base exception for Moonvalley API errors."""
pass
def is_valid_task_creation_response(response: MoonvalleyPromptResponse) -> bool:
"""Verifies that the initial response contains a task ID."""
return bool(response.id)
def validate_task_creation_response(response) -> None:
if not is_valid_task_creation_response(response):
error_msg = f"Moonvalley Marey API: Initial request failed. Code: {response.code}, Message: {response.message}, Data: {response}"
logging.error(error_msg)
raise MoonvalleyApiError(error_msg)
def get_video_from_response(response):
video = response.output_url
logging.info(
"Moonvalley Marey API: Task %s succeeded. Video URL: %s", response.id, video
)
return video
def get_video_url_from_response(response) -> Optional[str]:
"""Returns the first video url from the Moonvalley video generation task result.
Will not raise an error if the response is not valid.
"""
if response:
return str(get_video_from_response(response))
else:
return None
def poll_until_finished(
auth_kwargs: dict[str, str],
api_endpoint: ApiEndpoint[Any, R],
result_url_extractor: Optional[Callable[[R], str]] = None,
node_id: Optional[str] = None,
) -> R:
"""Polls the Moonvalley API endpoint until the task reaches a terminal state, then returns the response."""
return PollingOperation(
poll_endpoint=api_endpoint,
completed_statuses=[
"completed",
],
max_poll_attempts=240, # 64 minutes with 16s interval
poll_interval=16.0,
failed_statuses=["error"],
status_extractor=lambda response: (
response.status
if response and response.status
else None
),
auth_kwargs=auth_kwargs,
result_url_extractor=result_url_extractor,
node_id=node_id,
).execute()
def validate_prompts(prompt:str, negative_prompt: str, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH):
"""Verifies that the prompt isn't empty and that neither prompt is too long."""
if not prompt:
raise ValueError("Positive prompt is empty")
if len(prompt) > max_length:
raise ValueError(f"Positive prompt is too long: {len(prompt)} characters")
if negative_prompt and len(negative_prompt) > max_length:
raise ValueError(
f"Negative prompt is too long: {len(negative_prompt)} characters"
)
return True
def validate_input_media(width, height, with_frame_conditioning, num_frames_in=None):
# inference validation
# T = num_frames
# in all cases, the following must be true: T divisible by 16 and H,W by 8. in addition...
# with image conditioning: H*W must be divisible by 8192
# without image conditioning: T divisible by 32
if num_frames_in and not num_frames_in % 16 == 0 :
return False, (
"The input video total frame count must be divisible by 16!"
)
if height % 8 != 0 or width % 8 != 0:
return False, (
f"Height ({height}) and width ({width}) must be " "divisible by 8"
)
if with_frame_conditioning:
if (height * width) % 8192 != 0:
return False, (
f"Height * width ({height * width}) must be "
"divisible by 8192 for frame conditioning"
)
else:
if num_frames_in and not num_frames_in % 32 == 0 :
return False, (
"The input video total frame count must be divisible by 32!"
)
def validate_input_image(image: torch.Tensor, with_frame_conditioning: bool=False) -> None:
"""
Validates the input image adheres to the expectations of the API:
- The image resolution should not be less than 300*300px
- The aspect ratio of the image should be between 1:2.5 ~ 2.5:1
"""
height, width = get_image_dimensions(image)
validate_input_media(width, height, with_frame_conditioning )
validate_image_dimensions(image, min_width=300, min_height=300, max_height=MAX_HEIGHT, max_width=MAX_WIDTH)
def validate_input_video(video: VideoInput, num_frames_out: int, with_frame_conditioning: bool=False):
try:
width, height = video.get_dimensions()
except Exception as e:
logging.error("Error getting dimensions of video: %s", e)
raise ValueError(f"Cannot get video dimensions: {e}") from e
validate_input_media(width, height, with_frame_conditioning)
validate_video_dimensions(video, min_width=MIN_VID_WIDTH, min_height=MIN_VID_HEIGHT, max_width=MAX_VID_WIDTH, max_height=MAX_VID_HEIGHT)
trimmed_video = validate_input_video_length(video, num_frames_out)
return trimmed_video
def validate_input_video_length(video: VideoInput, num_frames: int):
if video.get_duration() > 60:
raise MoonvalleyApiError("Input Video lenth should be less than 1min. Please trim.")
if num_frames == 128:
if video.get_duration() < 5:
raise MoonvalleyApiError("Input Video length is less than 5s. Please use a video longer than or equal to 5s.")
if video.get_duration() > 5:
# trim video to 5s
video = trim_video(video, 5)
if num_frames == 256:
if video.get_duration() < 10:
raise MoonvalleyApiError("Input Video length is less than 10s. Please use a video longer than or equal to 10s.")
if video.get_duration() > 10:
# trim video to 10s
video = trim_video(video, 10)
return video
def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
"""
Returns a new VideoInput object trimmed from the beginning to the specified duration,
using av to avoid loading entire video into memory.
Args:
video: Input video to trim
duration_sec: Duration in seconds to keep from the beginning
Returns:
VideoFromFile object that owns the output buffer
"""
output_buffer = io.BytesIO()
input_container = None
output_container = None
try:
# Get the stream source - this avoids loading entire video into memory
# when the source is already a file path
input_source = video.get_stream_source()
# Open containers
input_container = av.open(input_source, mode='r')
output_container = av.open(output_buffer, mode='w', format='mp4')
# Set up output streams for re-encoding
video_stream = None
audio_stream = None
for stream in input_container.streams:
logging.info(f"Found stream: type={stream.type}, class={type(stream)}")
if isinstance(stream, av.VideoStream):
# Create output video stream with same parameters
video_stream = output_container.add_stream('h264', rate=stream.average_rate)
video_stream.width = stream.width
video_stream.height = stream.height
video_stream.pix_fmt = 'yuv420p'
logging.info(f"Added video stream: {stream.width}x{stream.height} @ {stream.average_rate}fps")
elif isinstance(stream, av.AudioStream):
# Create output audio stream with same parameters
audio_stream = output_container.add_stream('aac', rate=stream.sample_rate)
audio_stream.sample_rate = stream.sample_rate
audio_stream.layout = stream.layout
logging.info(f"Added audio stream: {stream.sample_rate}Hz, {stream.channels} channels")
# Calculate target frame count that's divisible by 32
fps = input_container.streams.video[0].average_rate
estimated_frames = int(duration_sec * fps)
target_frames = (estimated_frames // 32) * 32 # Round down to nearest multiple of 32
if target_frames == 0:
raise ValueError("Video too short: need at least 32 frames for Moonvalley")
frame_count = 0
audio_frame_count = 0
# Decode and re-encode video frames
if video_stream:
for frame in input_container.decode(video=0):
if frame_count >= target_frames:
break
# Re-encode frame
for packet in video_stream.encode(frame):
output_container.mux(packet)
frame_count += 1
# Flush encoder
for packet in video_stream.encode():
output_container.mux(packet)
logging.info(f"Encoded {frame_count} video frames (target: {target_frames})")
# Decode and re-encode audio frames
if audio_stream:
input_container.seek(0) # Reset to beginning for audio
for frame in input_container.decode(audio=0):
if frame.time >= duration_sec:
break
# Re-encode frame
for packet in audio_stream.encode(frame):
output_container.mux(packet)
audio_frame_count += 1
# Flush encoder
for packet in audio_stream.encode():
output_container.mux(packet)
logging.info(f"Encoded {audio_frame_count} audio frames")
# Close containers
output_container.close()
input_container.close()
# Return as VideoFromFile using the buffer
output_buffer.seek(0)
return VideoFromFile(output_buffer)
except Exception as e:
# Clean up on error
if input_container is not None:
input_container.close()
if output_container is not None:
output_container.close()
raise RuntimeError(f"Failed to trim video: {str(e)}") from e
# --- BaseMoonvalleyVideoNode ---
class BaseMoonvalleyVideoNode:
def parseWidthHeightFromRes(self, resolution: str):
# Accepts a string like "16:9 (1920 x 1080)" and returns width, height as a dict
res_map = {
"16:9 (1920 x 1080)": {"width": 1920, "height": 1080},
"9:16 (1080 x 1920)": {"width": 1080, "height": 1920},
"1:1 (1152 x 1152)": {"width": 1152, "height": 1152},
"4:3 (1440 x 1080)": {"width": 1440, "height": 1080},
"3:4 (1080 x 1440)": {"width": 1080, "height": 1440},
"21:9 (2560 x 1080)": {"width": 2560, "height": 1080},
}
if resolution in res_map:
return res_map[resolution]
else:
# Default to 1920x1080 if unknown
return {"width": 1920, "height": 1080}
def parseControlParameter(self, value):
control_map = {
"Motion Transfer": "motion_control",
"Canny": "canny_control",
"Pose Transfer": "pose_control",
"Depth": "depth_control"
}
if value in control_map:
return control_map[value]
else:
return control_map["Motion Transfer"]
def get_response(
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
) -> MoonvalleyPromptResponse:
return poll_until_finished(
auth_kwargs,
ApiEndpoint(
path=f"{API_PROMPTS_ENDPOINT}/{task_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=MoonvalleyPromptResponse,
),
result_url_extractor=get_video_url_from_response,
node_id=node_id,
)
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"prompt": model_field_to_node_input(
IO.STRING, MoonvalleyTextToVideoRequest, "prompt_text",
multiline=True
),
"negative_prompt": model_field_to_node_input(
IO.STRING,
MoonvalleyTextToVideoInferenceParams,
"negative_prompt",
multiline=True,
default="gopro, bright, contrast, static, overexposed, bright, vignette, artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, flare, saturation, distorted, warped, wide angle, contrast, saturated, vibrant, glowing, cross dissolve, texture, videogame, saturation, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, blown out, horrible, blurry, worst quality, bad, transition, dissolve, cross-dissolve, melt, fade in, fade out, wobbly, weird, low quality, plastic, stock footage, video camera, boring, static",
),
"resolution": (IO.COMBO, {
"options": ["16:9 (1920 x 1080)",
"9:16 (1080 x 1920)",
"1:1 (1152 x 1152)",
"4:3 (1440 x 1080)",
"3:4 (1080 x 1440)",
"21:9 (2560 x 1080)"],
"default": "16:9 (1920 x 1080)",
"tooltip": "Resolution of the output video",
}),
# "length": (IO.COMBO,{"options":['5s','10s'], "default": '5s'}),
"prompt_adherence": model_field_to_node_input(IO.FLOAT,MoonvalleyTextToVideoInferenceParams,"guidance_scale",default=7.0, step=1, min=1, max=20),
"seed": model_field_to_node_input(IO.INT,MoonvalleyTextToVideoInferenceParams, "seed", default=random.randint(0, 2**32 - 1), min=0, max=4294967295, step=1, display="number", tooltip="Random seed value", control_after_generate=True),
"steps": model_field_to_node_input(IO.INT, MoonvalleyTextToVideoInferenceParams, "steps", default=100, min=1, max=100),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
"optional": {
"image": model_field_to_node_input(
IO.IMAGE,
MoonvalleyTextToVideoRequest,
"image_url",
tooltip="The reference image used to generate the video",
),
}
}
RETURN_TYPES = ("STRING",)
FUNCTION = "generate"
CATEGORY = "api node/video/Moonvalley Marey"
API_NODE = True
def generate(self, **kwargs):
return None
# --- MoonvalleyImg2VideoNode ---
class MoonvalleyImg2VideoNode(BaseMoonvalleyVideoNode):
@classmethod
def INPUT_TYPES(cls):
return super().INPUT_TYPES()
RETURN_TYPES = ("VIDEO",)
RETURN_NAMES = ("video",)
DESCRIPTION = "Moonvalley Marey Image to Video Node"
def generate(self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs):
image = kwargs.get("image", None)
if (image is None):
raise MoonvalleyApiError("image is required")
total_frames = get_total_frames_from_length()
validate_input_image(image,True)
validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
width_height = self.parseWidthHeightFromRes(kwargs.get("resolution"))
inference_params=MoonvalleyTextToVideoInferenceParams(
negative_prompt=negative_prompt,
steps=kwargs.get("steps"),
seed=kwargs.get("seed"),
guidance_scale=kwargs.get("prompt_adherence"),
num_frames=total_frames,
width=width_height.get("width"),
height=width_height.get("height"),
use_negative_prompts=True
)
"""Upload image to comfy backend to have a URL available for further processing"""
# Get MIME type from tensor - assuming PNG format for image tensors
mime_type = "image/png"
image_url = upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs, mime_type=mime_type)[0]
request = MoonvalleyTextToVideoRequest(
image_url=image_url,
prompt_text=prompt,
inference_params=inference_params
)
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(path=API_IMG2VIDEO_ENDPOINT,
method=HttpMethod.POST,
request_model=MoonvalleyTextToVideoRequest,
response_model=MoonvalleyPromptResponse
),
request=request,
auth_kwargs=kwargs,
)
task_creation_response = initial_operation.execute()
validate_task_creation_response(task_creation_response)
task_id = task_creation_response.id
final_response = self.get_response(
task_id, auth_kwargs=kwargs, node_id=unique_id
)
video = download_url_to_video_output(final_response.output_url)
return (video, )
# --- MoonvalleyVid2VidNode ---
class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
def __init__(self):
super().__init__()
@classmethod
def INPUT_TYPES(cls):
input_types = super().INPUT_TYPES()
for param in ["resolution", "image"]:
if param in input_types["required"]:
del input_types["required"][param]
if param in input_types["optional"]:
del input_types["optional"][param]
input_types["optional"] = {
"video": (IO.VIDEO, {"default": "", "multiline": False, "tooltip": "The reference video used to generate the output video. Input a 5s video for 128 frames and a 10s video for 256 frames. Longer videos will be trimmed automatically."}),
"control_type": (
["Motion Transfer", "Pose Transfer"],
{"default": "Motion Transfer"},
),
"motion_intensity": (
"INT",
{
"default": 100,
"step": 1,
"min": 0,
"max": 100,
"tooltip": "Only used if control_type is 'Motion Transfer'",
},
)
}
return input_types
RETURN_TYPES = ("VIDEO",)
RETURN_NAMES = ("video",)
def generate(self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs):
video = kwargs.get("video")
num_frames = get_total_frames_from_length()
if not video :
raise MoonvalleyApiError("video is required")
"""Validate video input"""
video_url=""
if video:
validated_video = validate_input_video(video, num_frames, False)
video_url = upload_video_to_comfyapi(validated_video, auth_kwargs=kwargs)
control_type = kwargs.get("control_type")
motion_intensity = kwargs.get("motion_intensity")
"""Validate prompts and inference input"""
validate_prompts(prompt, negative_prompt)
inference_params=MoonvalleyVideoToVideoInferenceParams(
negative_prompt=negative_prompt,
steps=kwargs.get("steps"),
seed=kwargs.get("seed"),
guidance_scale=kwargs.get("prompt_adherence"),
control_params={'motion_intensity': motion_intensity}
)
control = self.parseControlParameter(control_type)
request = MoonvalleyVideoToVideoRequest(
control_type=control,
video_url=video_url,
prompt_text=prompt,
inference_params=inference_params
)
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(path=API_VIDEO2VIDEO_ENDPOINT,
method=HttpMethod.POST,
request_model=MoonvalleyVideoToVideoRequest,
response_model=MoonvalleyPromptResponse
),
request=request,
auth_kwargs=kwargs,
)
task_creation_response = initial_operation.execute()
validate_task_creation_response(task_creation_response)
task_id = task_creation_response.id
final_response = self.get_response(
task_id, auth_kwargs=kwargs, node_id=unique_id
)
video = download_url_to_video_output(final_response.output_url)
return (video, )
# --- MoonvalleyTxt2VideoNode ---
class MoonvalleyTxt2VideoNode(BaseMoonvalleyVideoNode):
def __init__(self):
super().__init__()
RETURN_TYPES = ("VIDEO",)
RETURN_NAMES = ("video",)
@classmethod
def INPUT_TYPES(cls):
input_types = super().INPUT_TYPES()
# Remove image-specific parameters
for param in ["image"]:
if param in input_types["optional"]:
del input_types["optional"][param]
return input_types
def generate(self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs):
validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
width_height = self.parseWidthHeightFromRes(kwargs.get("resolution"))
num_frames = get_total_frames_from_length()
inference_params=MoonvalleyTextToVideoInferenceParams(
negative_prompt=negative_prompt,
steps=kwargs.get("steps"),
seed=kwargs.get("seed"),
guidance_scale=kwargs.get("prompt_adherence"),
num_frames=num_frames,
width=width_height.get("width"),
height=width_height.get("height"),
)
request = MoonvalleyTextToVideoRequest(
prompt_text=prompt,
inference_params=inference_params
)
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(path=API_TXT2VIDEO_ENDPOINT,
method=HttpMethod.POST,
request_model=MoonvalleyTextToVideoRequest,
response_model=MoonvalleyPromptResponse
),
request=request,
auth_kwargs=kwargs,
)
task_creation_response = initial_operation.execute()
validate_task_creation_response(task_creation_response)
task_id = task_creation_response.id
final_response = self.get_response(
task_id, auth_kwargs=kwargs, node_id=unique_id
)
video = download_url_to_video_output(final_response.output_url)
return (video, )
NODE_CLASS_MAPPINGS = {
"MoonvalleyImg2VideoNode": MoonvalleyImg2VideoNode,
"MoonvalleyTxt2VideoNode": MoonvalleyTxt2VideoNode,
# "MoonvalleyVideo2VideoNode": MoonvalleyVideo2VideoNode,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"MoonvalleyImg2VideoNode": "Moonvalley Marey Image to Video",
"MoonvalleyTxt2VideoNode": "Moonvalley Marey Text to Video",
# "MoonvalleyVideo2VideoNode": "Moonvalley Marey Video to Video",
}
def get_total_frames_from_length(length="5s"):
# if length == '5s':
# return 128
# elif length == '10s':
# return 256
return 128
# else:
# raise MoonvalleyApiError("length is required")

View File

@@ -0,0 +1,152 @@
import os
from pathlib import Path
from typing import Optional
from pydantic_settings import PydanticBaseSettingsSource, TomlConfigSettingsSource
from comfy_config.types import (
ComfyConfig,
ProjectConfig,
PyProjectConfig,
PyProjectSettings
)
def validate_and_extract_os_classifiers(classifiers: list) -> list:
os_classifiers = [c for c in classifiers if c.startswith("Operating System :: ")]
if not os_classifiers:
return []
os_values = [c[len("Operating System :: ") :] for c in os_classifiers]
valid_os_prefixes = {"Microsoft", "POSIX", "MacOS", "OS Independent"}
for os_value in os_values:
if not any(os_value.startswith(prefix) for prefix in valid_os_prefixes):
return []
return os_values
def validate_and_extract_accelerator_classifiers(classifiers: list) -> list:
accelerator_classifiers = [c for c in classifiers if c.startswith("Environment ::")]
if not accelerator_classifiers:
return []
accelerator_values = [c[len("Environment :: ") :] for c in accelerator_classifiers]
valid_accelerators = {
"GPU :: NVIDIA CUDA",
"GPU :: AMD ROCm",
"GPU :: Intel Arc",
"NPU :: Huawei Ascend",
"GPU :: Apple Metal",
}
for accelerator_value in accelerator_values:
if accelerator_value not in valid_accelerators:
return []
return accelerator_values
"""
Extract configuration from a custom node directory's pyproject.toml file or a Python file.
This function reads and parses the pyproject.toml file in the specified directory
to extract project and ComfyUI-specific configuration information. If no
pyproject.toml file is found, it creates a minimal configuration using the
folder name as the project name. If a Python file is provided, it uses the
file name (without extension) as the project name.
Args:
path (str): Path to the directory containing the pyproject.toml file, or
path to a .py file. If pyproject.toml doesn't exist in a directory,
the folder name will be used as the default project name. If a .py
file is provided, the filename (without .py extension) will be used
as the project name.
Returns:
Optional[PyProjectConfig]: A PyProjectConfig object containing:
- project: Basic project information (name, version, dependencies, etc.)
- tool_comfy: ComfyUI-specific configuration (publisher_id, models, etc.)
Returns None if configuration extraction fails or if the provided file
is not a Python file.
Notes:
- If pyproject.toml is missing in a directory, creates a default config with folder name
- If a .py file is provided, creates a default config with filename (without extension)
- Returns None for non-Python files
Example:
>>> from comfy_config import config_parser
>>> # For directory
>>> custom_node_dir = os.path.dirname(os.path.realpath(__file__))
>>> project_config = config_parser.extract_node_configuration(custom_node_dir)
>>> print(project_config.project.name) # "my_custom_node" or name from pyproject.toml
>>>
>>> # For single-file Python node file
>>> py_file_path = os.path.realpath(__file__) # "/path/to/my_node.py"
>>> project_config = config_parser.extract_node_configuration(py_file_path)
>>> print(project_config.project.name) # "my_node"
"""
def extract_node_configuration(path) -> Optional[PyProjectConfig]:
if os.path.isfile(path):
file_path = Path(path)
if file_path.suffix.lower() != '.py':
return None
project_name = file_path.stem
project = ProjectConfig(name=project_name)
comfy = ComfyConfig()
return PyProjectConfig(project=project, tool_comfy=comfy)
folder_name = os.path.basename(path)
toml_path = Path(path) / "pyproject.toml"
if not toml_path.exists():
project = ProjectConfig(name=folder_name)
comfy = ComfyConfig()
return PyProjectConfig(project=project, tool_comfy=comfy)
raw_settings = load_pyproject_settings(toml_path)
project_data = raw_settings.project
tool_data = raw_settings.tool
comfy_data = tool_data.get("comfy", {}) if tool_data else {}
dependencies = project_data.get("dependencies", [])
supported_comfyui_frontend_version = ""
for dep in dependencies:
if isinstance(dep, str) and dep.startswith("comfyui-frontend-package"):
supported_comfyui_frontend_version = dep.removeprefix("comfyui-frontend-package")
break
supported_comfyui_version = comfy_data.get("requires-comfyui", "")
classifiers = project_data.get('classifiers', [])
supported_os = validate_and_extract_os_classifiers(classifiers)
supported_accelerators = validate_and_extract_accelerator_classifiers(classifiers)
project_data['supported_os'] = supported_os
project_data['supported_accelerators'] = supported_accelerators
project_data['supported_comfyui_frontend_version'] = supported_comfyui_frontend_version
project_data['supported_comfyui_version'] = supported_comfyui_version
return PyProjectConfig(project=project_data, tool_comfy=comfy_data)
def load_pyproject_settings(toml_path: Path) -> PyProjectSettings:
class PyProjectLoader(PyProjectSettings):
@classmethod
def settings_customise_sources(
cls,
settings_cls,
init_settings: PydanticBaseSettingsSource,
env_settings: PydanticBaseSettingsSource,
dotenv_settings: PydanticBaseSettingsSource,
file_secret_settings: PydanticBaseSettingsSource,
):
return (TomlConfigSettingsSource(settings_cls, toml_path),)
return PyProjectLoader()

97
comfy_config/types.py Normal file
View File

@@ -0,0 +1,97 @@
from pydantic import BaseModel, Field, field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
from typing import List, Optional
# IMPORTANT: The type definitions specified in pyproject.toml for custom nodes
# must remain synchronized with the corresponding files in the https://github.com/Comfy-Org/comfy-cli/blob/main/comfy_cli/registry/types.py.
# Any changes to one must be reflected in the other to maintain consistency.
class NodeVersion(BaseModel):
changelog: str
dependencies: List[str]
deprecated: bool
id: str
version: str
download_url: str
class Node(BaseModel):
id: str
name: str
description: str
author: Optional[str] = None
license: Optional[str] = None
icon: Optional[str] = None
repository: Optional[str] = None
tags: List[str] = Field(default_factory=list)
latest_version: Optional[NodeVersion] = None
class PublishNodeVersionResponse(BaseModel):
node_version: NodeVersion
signedUrl: str
class URLs(BaseModel):
homepage: str = Field(default="", alias="Homepage")
documentation: str = Field(default="", alias="Documentation")
repository: str = Field(default="", alias="Repository")
issues: str = Field(default="", alias="Issues")
class Model(BaseModel):
location: str
model_url: str
class ComfyConfig(BaseModel):
publisher_id: str = Field(default="", alias="PublisherId")
display_name: str = Field(default="", alias="DisplayName")
icon: str = Field(default="", alias="Icon")
models: List[Model] = Field(default_factory=list, alias="Models")
includes: List[str] = Field(default_factory=list)
web: Optional[str] = None
banner_url: str = ""
class License(BaseModel):
file: str = ""
text: str = ""
class ProjectConfig(BaseModel):
name: str = ""
description: str = ""
version: str = "1.0.0"
requires_python: str = Field(default=">= 3.9", alias="requires-python")
dependencies: List[str] = Field(default_factory=list)
license: License = Field(default_factory=License)
urls: URLs = Field(default_factory=URLs)
supported_os: List[str] = Field(default_factory=list)
supported_accelerators: List[str] = Field(default_factory=list)
supported_comfyui_version: str = ""
supported_comfyui_frontend_version: str = ""
@field_validator('license', mode='before')
@classmethod
def validate_license(cls, v):
if isinstance(v, str):
return License(text=v)
elif isinstance(v, dict):
return License(**v)
elif isinstance(v, License):
return v
else:
return License()
class PyProjectConfig(BaseModel):
project: ProjectConfig = Field(default_factory=ProjectConfig)
tool_comfy: ComfyConfig = Field(default_factory=ComfyConfig)
class PyProjectSettings(BaseSettings):
project: dict = Field(default_factory=dict)
tool: dict = Field(default_factory=dict)
model_config = SettingsConfigDict(extra='allow')

View File

@@ -1,6 +1,7 @@
import itertools
from typing import Sequence, Mapping, Dict
from comfy_execution.graph import DynamicPrompt
from abc import ABC, abstractmethod
import nodes
@@ -16,12 +17,13 @@ def include_unique_id_in_input(class_type: str) -> bool:
NODE_CLASS_CONTAINS_UNIQUE_ID[class_type] = "UNIQUE_ID" in class_def.INPUT_TYPES().get("hidden", {}).values()
return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
class CacheKeySet:
class CacheKeySet(ABC):
def __init__(self, dynprompt, node_ids, is_changed_cache):
self.keys = {}
self.subcache_keys = {}
def add_keys(self, node_ids):
@abstractmethod
async def add_keys(self, node_ids):
raise NotImplementedError()
def all_node_ids(self):
@@ -60,9 +62,8 @@ class CacheKeySetID(CacheKeySet):
def __init__(self, dynprompt, node_ids, is_changed_cache):
super().__init__(dynprompt, node_ids, is_changed_cache)
self.dynprompt = dynprompt
self.add_keys(node_ids)
def add_keys(self, node_ids):
async def add_keys(self, node_ids):
for node_id in node_ids:
if node_id in self.keys:
continue
@@ -77,37 +78,36 @@ class CacheKeySetInputSignature(CacheKeySet):
super().__init__(dynprompt, node_ids, is_changed_cache)
self.dynprompt = dynprompt
self.is_changed_cache = is_changed_cache
self.add_keys(node_ids)
def include_node_id_in_input(self) -> bool:
return False
def add_keys(self, node_ids):
async def add_keys(self, node_ids):
for node_id in node_ids:
if node_id in self.keys:
continue
if not self.dynprompt.has_node(node_id):
continue
node = self.dynprompt.get_node(node_id)
self.keys[node_id] = self.get_node_signature(self.dynprompt, node_id)
self.keys[node_id] = await self.get_node_signature(self.dynprompt, node_id)
self.subcache_keys[node_id] = (node_id, node["class_type"])
def get_node_signature(self, dynprompt, node_id):
async def get_node_signature(self, dynprompt, node_id):
signature = []
ancestors, order_mapping = self.get_ordered_ancestry(dynprompt, node_id)
signature.append(self.get_immediate_node_signature(dynprompt, node_id, order_mapping))
signature.append(await self.get_immediate_node_signature(dynprompt, node_id, order_mapping))
for ancestor_id in ancestors:
signature.append(self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping))
signature.append(await self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping))
return to_hashable(signature)
def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
async def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
if not dynprompt.has_node(node_id):
# This node doesn't exist -- we can't cache it.
return [float("NaN")]
node = dynprompt.get_node(node_id)
class_type = node["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
signature = [class_type, self.is_changed_cache.get(node_id)]
signature = [class_type, await self.is_changed_cache.get(node_id)]
if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type):
signature.append(node_id)
inputs = node["inputs"]
@@ -150,9 +150,10 @@ class BasicCache:
self.cache = {}
self.subcaches = {}
def set_prompt(self, dynprompt, node_ids, is_changed_cache):
async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
self.dynprompt = dynprompt
self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache)
await self.cache_key_set.add_keys(node_ids)
self.is_changed_cache = is_changed_cache
self.initialized = True
@@ -201,13 +202,13 @@ class BasicCache:
else:
return None
def _ensure_subcache(self, node_id, children_ids):
async def _ensure_subcache(self, node_id, children_ids):
subcache_key = self.cache_key_set.get_subcache_key(node_id)
subcache = self.subcaches.get(subcache_key, None)
if subcache is None:
subcache = BasicCache(self.key_class)
self.subcaches[subcache_key] = subcache
subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache)
await subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache)
return subcache
def _get_subcache(self, node_id):
@@ -259,10 +260,10 @@ class HierarchicalCache(BasicCache):
assert cache is not None
cache._set_immediate(node_id, value)
def ensure_subcache_for(self, node_id, children_ids):
async def ensure_subcache_for(self, node_id, children_ids):
cache = self._get_cache_for(node_id)
assert cache is not None
return cache._ensure_subcache(node_id, children_ids)
return await cache._ensure_subcache(node_id, children_ids)
class LRUCache(BasicCache):
def __init__(self, key_class, max_size=100):
@@ -273,8 +274,8 @@ class LRUCache(BasicCache):
self.used_generation = {}
self.children = {}
def set_prompt(self, dynprompt, node_ids, is_changed_cache):
super().set_prompt(dynprompt, node_ids, is_changed_cache)
async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
await super().set_prompt(dynprompt, node_ids, is_changed_cache)
self.generation += 1
for node_id in node_ids:
self._mark_used(node_id)
@@ -303,11 +304,11 @@ class LRUCache(BasicCache):
self._mark_used(node_id)
return self._set_immediate(node_id, value)
def ensure_subcache_for(self, node_id, children_ids):
async def ensure_subcache_for(self, node_id, children_ids):
# Just uses subcaches for tracking 'live' nodes
super()._ensure_subcache(node_id, children_ids)
await super()._ensure_subcache(node_id, children_ids)
self.cache_key_set.add_keys(children_ids)
await self.cache_key_set.add_keys(children_ids)
self._mark_used(node_id)
cache_key = self.cache_key_set.get_data_key(node_id)
self.children[cache_key] = []
@@ -337,7 +338,7 @@ class DependencyAwareCache(BasicCache):
self.ancestors = {} # Maps node_id -> set of ancestor node_ids
self.executed_nodes = set() # Tracks nodes that have been executed
def set_prompt(self, dynprompt, node_ids, is_changed_cache):
async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
"""
Clear the entire cache and rebuild the dependency graph.
@@ -354,7 +355,7 @@ class DependencyAwareCache(BasicCache):
self.executed_nodes.clear()
# Call the parent method to initialize the cache with the new prompt
super().set_prompt(dynprompt, node_ids, is_changed_cache)
await super().set_prompt(dynprompt, node_ids, is_changed_cache)
# Rebuild the dependency graph
self._build_dependency_graph(dynprompt, node_ids)
@@ -405,7 +406,7 @@ class DependencyAwareCache(BasicCache):
"""
return self._get_immediate(node_id)
def ensure_subcache_for(self, node_id, children_ids):
async def ensure_subcache_for(self, node_id, children_ids):
"""
Ensure a subcache exists for a node and update dependencies.
@@ -416,7 +417,7 @@ class DependencyAwareCache(BasicCache):
Returns:
The subcache object for the node.
"""
subcache = super()._ensure_subcache(node_id, children_ids)
subcache = await super()._ensure_subcache(node_id, children_ids)
for child_id in children_ids:
self.descendants[node_id].add(child_id)
self.ancestors[child_id].add(node_id)

View File

@@ -2,6 +2,8 @@ from __future__ import annotations
from typing import Type, Literal
import nodes
import asyncio
import inspect
from comfy_execution.graph_utils import is_link
from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputTypeOptions
@@ -100,6 +102,8 @@ class TopologicalSort:
self.pendingNodes = {}
self.blockCount = {} # Number of nodes this node is directly blocked by
self.blocking = {} # Which nodes are blocked by this node
self.externalBlocks = 0
self.unblockedEvent = asyncio.Event()
def get_input_info(self, unique_id, input_name):
class_type = self.dynprompt.get_node(unique_id)["class_type"]
@@ -153,6 +157,16 @@ class TopologicalSort:
for link in links:
self.add_strong_link(*link)
def add_external_block(self, node_id):
assert node_id in self.blockCount, "Can't add external block to a node that isn't pending"
self.externalBlocks += 1
self.blockCount[node_id] += 1
def unblock():
self.externalBlocks -= 1
self.blockCount[node_id] -= 1
self.unblockedEvent.set()
return unblock
def is_cached(self, node_id):
return False
@@ -181,11 +195,16 @@ class ExecutionList(TopologicalSort):
def is_cached(self, node_id):
return self.output_cache.get(node_id) is not None
def stage_node_execution(self):
async def stage_node_execution(self):
assert self.staged_node_id is None
if self.is_empty():
return None, None, None
available = self.get_ready_nodes()
while len(available) == 0 and self.externalBlocks > 0:
# Wait for an external block to be released
await self.unblockedEvent.wait()
self.unblockedEvent.clear()
available = self.get_ready_nodes()
if len(available) == 0:
cycled_nodes = self.get_nodes_in_cycle()
# Because cycles composed entirely of static nodes are caught during initial validation,
@@ -221,8 +240,15 @@ class ExecutionList(TopologicalSort):
return True
return False
# If an available node is async, do that first.
# This will execute the asynchronous function earlier, reducing the overall time.
def is_async(node_id):
class_type = self.dynprompt.get_node(node_id)["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
return inspect.iscoroutinefunction(getattr(class_def, class_def.FUNCTION))
for node_id in node_list:
if is_output(node_id):
if is_output(node_id) or is_async(node_id):
return node_id
#This should handle the VAEDecode -> preview case

347
comfy_execution/progress.py Normal file
View File

@@ -0,0 +1,347 @@
from typing import TypedDict, Dict, Optional
from typing_extensions import override
from PIL import Image
from enum import Enum
from abc import ABC
from tqdm import tqdm
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from comfy_execution.graph import DynamicPrompt
from protocol import BinaryEventTypes
from comfy_api import feature_flags
class NodeState(Enum):
Pending = "pending"
Running = "running"
Finished = "finished"
Error = "error"
class NodeProgressState(TypedDict):
"""
A class to represent the state of a node's progress.
"""
state: NodeState
value: float
max: float
class ProgressHandler(ABC):
"""
Abstract base class for progress handlers.
Progress handlers receive progress updates and display them in various ways.
"""
def __init__(self, name: str):
self.name = name
self.enabled = True
def set_registry(self, registry: "ProgressRegistry"):
pass
def start_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
"""Called when a node starts processing"""
pass
def update_handler(
self,
node_id: str,
value: float,
max_value: float,
state: NodeProgressState,
prompt_id: str,
image: Optional[Image.Image] = None,
):
"""Called when a node's progress is updated"""
pass
def finish_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
"""Called when a node finishes processing"""
pass
def reset(self):
"""Called when the progress registry is reset"""
pass
def enable(self):
"""Enable this handler"""
self.enabled = True
def disable(self):
"""Disable this handler"""
self.enabled = False
class CLIProgressHandler(ProgressHandler):
"""
Handler that displays progress using tqdm progress bars in the CLI.
"""
def __init__(self):
super().__init__("cli")
self.progress_bars: Dict[str, tqdm] = {}
@override
def start_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
# Create a new tqdm progress bar
if node_id not in self.progress_bars:
self.progress_bars[node_id] = tqdm(
total=state["max"],
desc=f"Node {node_id}",
unit="steps",
leave=True,
position=len(self.progress_bars),
)
@override
def update_handler(
self,
node_id: str,
value: float,
max_value: float,
state: NodeProgressState,
prompt_id: str,
image: Optional[Image.Image] = None,
):
# Handle case where start_handler wasn't called
if node_id not in self.progress_bars:
self.progress_bars[node_id] = tqdm(
total=max_value,
desc=f"Node {node_id}",
unit="steps",
leave=True,
position=len(self.progress_bars),
)
self.progress_bars[node_id].update(value)
else:
# Update existing progress bar
if max_value != self.progress_bars[node_id].total:
self.progress_bars[node_id].total = max_value
# Calculate the update amount (difference from current position)
current_position = self.progress_bars[node_id].n
update_amount = value - current_position
if update_amount > 0:
self.progress_bars[node_id].update(update_amount)
@override
def finish_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
# Complete and close the progress bar if it exists
if node_id in self.progress_bars:
# Ensure the bar shows 100% completion
remaining = state["max"] - self.progress_bars[node_id].n
if remaining > 0:
self.progress_bars[node_id].update(remaining)
self.progress_bars[node_id].close()
del self.progress_bars[node_id]
@override
def reset(self):
# Close all progress bars
for bar in self.progress_bars.values():
bar.close()
self.progress_bars.clear()
class WebUIProgressHandler(ProgressHandler):
"""
Handler that sends progress updates to the WebUI via WebSockets.
"""
def __init__(self, server_instance):
super().__init__("webui")
self.server_instance = server_instance
def set_registry(self, registry: "ProgressRegistry"):
self.registry = registry
def _send_progress_state(self, prompt_id: str, nodes: Dict[str, NodeProgressState]):
"""Send the current progress state to the client"""
if self.server_instance is None:
return
# Only send info for non-pending nodes
active_nodes = {
node_id: {
"value": state["value"],
"max": state["max"],
"state": state["state"].value,
"node_id": node_id,
"prompt_id": prompt_id,
"display_node_id": self.registry.dynprompt.get_display_node_id(node_id),
"parent_node_id": self.registry.dynprompt.get_parent_node_id(node_id),
"real_node_id": self.registry.dynprompt.get_real_node_id(node_id),
}
for node_id, state in nodes.items()
if state["state"] != NodeState.Pending
}
# Send a combined progress_state message with all node states
self.server_instance.send_sync(
"progress_state", {"prompt_id": prompt_id, "nodes": active_nodes}
)
@override
def start_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
# Send progress state of all nodes
if self.registry:
self._send_progress_state(prompt_id, self.registry.nodes)
@override
def update_handler(
self,
node_id: str,
value: float,
max_value: float,
state: NodeProgressState,
prompt_id: str,
image: Optional[Image.Image] = None,
):
# Send progress state of all nodes
if self.registry:
self._send_progress_state(prompt_id, self.registry.nodes)
if image:
# Only send new format if client supports it
if feature_flags.supports_feature(
self.server_instance.sockets_metadata,
self.server_instance.client_id,
"supports_preview_metadata",
):
metadata = {
"node_id": node_id,
"prompt_id": prompt_id,
"display_node_id": self.registry.dynprompt.get_display_node_id(
node_id
),
"parent_node_id": self.registry.dynprompt.get_parent_node_id(
node_id
),
"real_node_id": self.registry.dynprompt.get_real_node_id(node_id),
}
self.server_instance.send_sync(
BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA,
(image, metadata),
self.server_instance.client_id,
)
@override
def finish_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
# Send progress state of all nodes
if self.registry:
self._send_progress_state(prompt_id, self.registry.nodes)
class ProgressRegistry:
"""
Registry that maintains node progress state and notifies registered handlers.
"""
def __init__(self, prompt_id: str, dynprompt: "DynamicPrompt"):
self.prompt_id = prompt_id
self.dynprompt = dynprompt
self.nodes: Dict[str, NodeProgressState] = {}
self.handlers: Dict[str, ProgressHandler] = {}
def register_handler(self, handler: ProgressHandler) -> None:
"""Register a progress handler"""
self.handlers[handler.name] = handler
def unregister_handler(self, handler_name: str) -> None:
"""Unregister a progress handler"""
if handler_name in self.handlers:
# Allow handler to clean up resources
self.handlers[handler_name].reset()
del self.handlers[handler_name]
def enable_handler(self, handler_name: str) -> None:
"""Enable a progress handler"""
if handler_name in self.handlers:
self.handlers[handler_name].enable()
def disable_handler(self, handler_name: str) -> None:
"""Disable a progress handler"""
if handler_name in self.handlers:
self.handlers[handler_name].disable()
def ensure_entry(self, node_id: str) -> NodeProgressState:
"""Ensure a node entry exists"""
if node_id not in self.nodes:
self.nodes[node_id] = NodeProgressState(
state=NodeState.Pending, value=0, max=1
)
return self.nodes[node_id]
def start_progress(self, node_id: str) -> None:
"""Start progress tracking for a node"""
entry = self.ensure_entry(node_id)
entry["state"] = NodeState.Running
entry["value"] = 0.0
entry["max"] = 1.0
# Notify all enabled handlers
for handler in self.handlers.values():
if handler.enabled:
handler.start_handler(node_id, entry, self.prompt_id)
def update_progress(
self, node_id: str, value: float, max_value: float, image: Optional[Image.Image]
) -> None:
"""Update progress for a node"""
entry = self.ensure_entry(node_id)
entry["state"] = NodeState.Running
entry["value"] = value
entry["max"] = max_value
# Notify all enabled handlers
for handler in self.handlers.values():
if handler.enabled:
handler.update_handler(
node_id, value, max_value, entry, self.prompt_id, image
)
def finish_progress(self, node_id: str) -> None:
"""Finish progress tracking for a node"""
entry = self.ensure_entry(node_id)
entry["state"] = NodeState.Finished
entry["value"] = entry["max"]
# Notify all enabled handlers
for handler in self.handlers.values():
if handler.enabled:
handler.finish_handler(node_id, entry, self.prompt_id)
def reset_handlers(self) -> None:
"""Reset all handlers"""
for handler in self.handlers.values():
handler.reset()
# Global registry instance
global_progress_registry: ProgressRegistry = None
def reset_progress_state(prompt_id: str, dynprompt: "DynamicPrompt") -> None:
global global_progress_registry
# Reset existing handlers if registry exists
if global_progress_registry is not None:
global_progress_registry.reset_handlers()
# Create new registry
global_progress_registry = ProgressRegistry(prompt_id, dynprompt)
def add_progress_handler(handler: ProgressHandler) -> None:
registry = get_progress_state()
handler.set_registry(registry)
registry.register_handler(handler)
def get_progress_state() -> ProgressRegistry:
global global_progress_registry
if global_progress_registry is None:
from comfy_execution.graph import DynamicPrompt
global_progress_registry = ProgressRegistry(
prompt_id="", dynprompt=DynamicPrompt({})
)
return global_progress_registry

46
comfy_execution/utils.py Normal file
View File

@@ -0,0 +1,46 @@
import contextvars
from typing import Optional, NamedTuple
class ExecutionContext(NamedTuple):
"""
Context information about the currently executing node.
Attributes:
node_id: The ID of the currently executing node
list_index: The index in a list being processed (for operations on batches/lists)
"""
prompt_id: str
node_id: str
list_index: Optional[int]
current_executing_context: contextvars.ContextVar[Optional[ExecutionContext]] = contextvars.ContextVar("current_executing_context", default=None)
def get_executing_context() -> Optional[ExecutionContext]:
return current_executing_context.get(None)
class CurrentNodeContext:
"""
Context manager for setting the current executing node context.
Sets the current_executing_context on enter and resets it on exit.
Example:
with CurrentNodeContext(node_id="123", list_index=0):
# Code that should run with the current node context set
process_image()
"""
def __init__(self, prompt_id: str, node_id: str, list_index: Optional[int] = None):
self.context = ExecutionContext(
prompt_id= prompt_id,
node_id= node_id,
list_index= list_index
)
self.token = None
def __enter__(self):
self.token = current_executing_context.set(self.context)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if self.token is not None:
current_executing_context.reset(self.token)

View File

@@ -133,14 +133,6 @@ def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=Non
if sample_rate != audio["sample_rate"]:
waveform = torchaudio.functional.resample(waveform, audio["sample_rate"], sample_rate)
# Create in-memory WAV buffer
wav_buffer = io.BytesIO()
torchaudio.save(wav_buffer, waveform, sample_rate, format="WAV")
wav_buffer.seek(0) # Rewind for reading
# Use PyAV to convert and add metadata
input_container = av.open(wav_buffer)
# Create output with specified format
output_buffer = io.BytesIO()
output_container = av.open(output_buffer, mode='w', format=format)
@@ -150,7 +142,6 @@ def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=Non
output_container.metadata[key] = value
# Set up the output stream with appropriate properties
input_container.streams.audio[0]
if format == "opus":
out_stream = output_container.add_stream("libopus", rate=sample_rate)
if quality == "64k":
@@ -175,18 +166,16 @@ def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=Non
else: #format == "flac":
out_stream = output_container.add_stream("flac", rate=sample_rate)
# Copy frames from input to output
for frame in input_container.decode(audio=0):
frame.pts = None # Let PyAV handle timestamps
output_container.mux(out_stream.encode(frame))
frame = av.AudioFrame.from_ndarray(waveform.movedim(0, 1).reshape(1, -1).float().numpy(), format='flt', layout='mono' if waveform.shape[0] == 1 else 'stereo')
frame.sample_rate = sample_rate
frame.pts = 0
output_container.mux(out_stream.encode(frame))
# Flush encoder
output_container.mux(out_stream.encode(None))
# Close containers
output_container.close()
input_container.close()
# Write the output to file
output_buffer.seek(0)

View File

@@ -40,6 +40,33 @@ class CFGZeroStar:
m.set_model_sampler_post_cfg_function(cfg_zero_star)
return (m, )
class CFGNorm:
@classmethod
def INPUT_TYPES(s):
return {"required": {"model": ("MODEL",),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
RETURN_NAMES = ("patched_model",)
FUNCTION = "patch"
CATEGORY = "advanced/guidance"
EXPERIMENTAL = True
def patch(self, model, strength):
m = model.clone()
def cfg_norm(args):
cond_p = args['cond_denoised']
pred_text_ = args["denoised"]
norm_full_cond = torch.norm(cond_p, dim=1, keepdim=True)
norm_pred_text = torch.norm(pred_text_, dim=1, keepdim=True)
scale = (norm_full_cond / (norm_pred_text + 1e-8)).clamp(min=0.0, max=1.0)
return pred_text_ * scale * strength
m.set_model_sampler_post_cfg_function(cfg_norm)
return (m, )
NODE_CLASS_MAPPINGS = {
"CFGZeroStar": CFGZeroStar
"CFGZeroStar": CFGZeroStar,
"CFGNorm": CFGNorm,
}

View File

@@ -2,6 +2,7 @@ import nodes
import torch
import comfy.model_management
import comfy.utils
import comfy.latent_formats
class EmptyCosmosLatentVideo:
@@ -75,8 +76,53 @@ class CosmosImageToVideoLatent:
out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1))
return (out_latent,)
class CosmosPredict2ImageToVideoLatent:
@classmethod
def INPUT_TYPES(s):
return {"required": {"vae": ("VAE", ),
"width": ("INT", {"default": 848, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 93, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
},
"optional": {"start_image": ("IMAGE", ),
"end_image": ("IMAGE", ),
}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "encode"
CATEGORY = "conditioning/inpaint"
def encode(self, vae, width, height, length, batch_size, start_image=None, end_image=None):
latent = torch.zeros([1, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
if start_image is None and end_image is None:
out_latent = {}
out_latent["samples"] = latent
return (out_latent,)
mask = torch.ones([latent.shape[0], 1, ((length - 1) // 4) + 1, latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device())
if start_image is not None:
latent_temp = vae_encode_with_padding(vae, start_image, width, height, length, padding=1)
latent[:, :, :latent_temp.shape[-3]] = latent_temp
mask[:, :, :latent_temp.shape[-3]] *= 0.0
if end_image is not None:
latent_temp = vae_encode_with_padding(vae, end_image, width, height, length, padding=0)
latent[:, :, -latent_temp.shape[-3]:] = latent_temp
mask[:, :, -latent_temp.shape[-3]:] *= 0.0
out_latent = {}
latent_format = comfy.latent_formats.Wan21()
latent = latent_format.process_out(latent) * mask + latent * (1.0 - mask)
out_latent["samples"] = latent.repeat((batch_size, ) + (1,) * (latent.ndim - 1))
out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1))
return (out_latent,)
NODE_CLASS_MAPPINGS = {
"EmptyCosmosLatentVideo": EmptyCosmosLatentVideo,
"CosmosImageToVideoLatent": CosmosImageToVideoLatent,
"CosmosPredict2ImageToVideoLatent": CosmosPredict2ImageToVideoLatent,
}

View File

@@ -2,6 +2,8 @@ import math
import comfy.samplers
import comfy.sample
from comfy.k_diffusion import sampling as k_diffusion_sampling
from comfy.k_diffusion import sa_solver
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict
import latent_preview
import torch
import comfy.utils
@@ -299,6 +301,35 @@ class ExtendIntermediateSigmas:
return (extended_sigmas,)
class SamplingPercentToSigma:
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {
"model": (IO.MODEL, {}),
"sampling_percent": (IO.FLOAT, {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.0001}),
"return_actual_sigma": (IO.BOOLEAN, {"default": False, "tooltip": "Return the actual sigma value instead of the value used for interval checks.\nThis only affects results at 0.0 and 1.0."}),
}
}
RETURN_TYPES = (IO.FLOAT,)
RETURN_NAMES = ("sigma_value",)
CATEGORY = "sampling/custom_sampling/sigmas"
FUNCTION = "get_sigma"
def get_sigma(self, model, sampling_percent, return_actual_sigma):
model_sampling = model.get_model_object("model_sampling")
sigma_val = model_sampling.percent_to_sigma(sampling_percent)
if return_actual_sigma:
if sampling_percent == 0.0:
sigma_val = model_sampling.sigma_max.item()
elif sampling_percent == 1.0:
sigma_val = model_sampling.sigma_min.item()
return (sigma_val,)
class KSamplerSelect:
@classmethod
def INPUT_TYPES(s):
@@ -480,6 +511,89 @@ class SamplerDPMAdaptative:
"s_noise":s_noise })
return (sampler, )
class SamplerER_SDE(ComfyNodeABC):
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {
"solver_type": (IO.COMBO, {"options": ["ER-SDE", "Reverse-time SDE", "ODE"]}),
"max_stage": (IO.INT, {"default": 3, "min": 1, "max": 3}),
"eta": (
IO.FLOAT,
{"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": False, "tooltip": "Stochastic strength of reverse-time SDE.\nWhen eta=0, it reduces to deterministic ODE. This setting doesn't apply to ER-SDE solver type."},
),
"s_noise": (IO.FLOAT, {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": False}),
}
}
RETURN_TYPES = (IO.SAMPLER,)
CATEGORY = "sampling/custom_sampling/samplers"
FUNCTION = "get_sampler"
def get_sampler(self, solver_type, max_stage, eta, s_noise):
if solver_type == "ODE" or (solver_type == "Reverse-time SDE" and eta == 0):
eta = 0
s_noise = 0
def reverse_time_sde_noise_scaler(x):
return x ** (eta + 1)
if solver_type == "ER-SDE":
# Use the default one in sample_er_sde()
noise_scaler = None
else:
noise_scaler = reverse_time_sde_noise_scaler
sampler_name = "er_sde"
sampler = comfy.samplers.ksampler(sampler_name, {"s_noise": s_noise, "noise_scaler": noise_scaler, "max_stage": max_stage})
return (sampler,)
class SamplerSASolver(ComfyNodeABC):
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {
"model": (IO.MODEL, {}),
"eta": (IO.FLOAT, {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01, "round": False},),
"sde_start_percent": (IO.FLOAT, {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.001},),
"sde_end_percent": (IO.FLOAT, {"default": 0.8, "min": 0.0, "max": 1.0, "step": 0.001},),
"s_noise": (IO.FLOAT, {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": False},),
"predictor_order": (IO.INT, {"default": 3, "min": 1, "max": 6}),
"corrector_order": (IO.INT, {"default": 4, "min": 0, "max": 6}),
"use_pece": (IO.BOOLEAN, {}),
"simple_order_2": (IO.BOOLEAN, {}),
}
}
RETURN_TYPES = (IO.SAMPLER,)
CATEGORY = "sampling/custom_sampling/samplers"
FUNCTION = "get_sampler"
def get_sampler(self, model, eta, sde_start_percent, sde_end_percent, s_noise, predictor_order, corrector_order, use_pece, simple_order_2):
model_sampling = model.get_model_object("model_sampling")
start_sigma = model_sampling.percent_to_sigma(sde_start_percent)
end_sigma = model_sampling.percent_to_sigma(sde_end_percent)
tau_func = sa_solver.get_tau_interval_func(start_sigma, end_sigma, eta=eta)
sampler_name = "sa_solver"
sampler = comfy.samplers.ksampler(
sampler_name,
{
"tau_func": tau_func,
"s_noise": s_noise,
"predictor_order": predictor_order,
"corrector_order": corrector_order,
"use_pece": use_pece,
"simple_order_2": simple_order_2,
},
)
return (sampler,)
class Noise_EmptyNoise:
def __init__(self):
self.seed = 0
@@ -598,9 +712,10 @@ class CFGGuider:
return (guider,)
class Guider_DualCFG(comfy.samplers.CFGGuider):
def set_cfg(self, cfg1, cfg2):
def set_cfg(self, cfg1, cfg2, nested=False):
self.cfg1 = cfg1
self.cfg2 = cfg2
self.nested = nested
def set_conds(self, positive, middle, negative):
middle = node_helpers.conditioning_set_values(middle, {"prompt_type": "negative"})
@@ -609,9 +724,21 @@ class Guider_DualCFG(comfy.samplers.CFGGuider):
def predict_noise(self, x, timestep, model_options={}, seed=None):
negative_cond = self.conds.get("negative", None)
middle_cond = self.conds.get("middle", None)
positive_cond = self.conds.get("positive", None)
out = comfy.samplers.calc_cond_batch(self.inner_model, [negative_cond, middle_cond, self.conds.get("positive", None)], x, timestep, model_options)
return comfy.samplers.cfg_function(self.inner_model, out[1], out[0], self.cfg2, x, timestep, model_options=model_options, cond=middle_cond, uncond=negative_cond) + (out[2] - out[1]) * self.cfg1
if self.nested:
out = comfy.samplers.calc_cond_batch(self.inner_model, [negative_cond, middle_cond, positive_cond], x, timestep, model_options)
pred_text = comfy.samplers.cfg_function(self.inner_model, out[2], out[1], self.cfg1, x, timestep, model_options=model_options, cond=positive_cond, uncond=middle_cond)
return out[0] + self.cfg2 * (pred_text - out[0])
else:
if model_options.get("disable_cfg1_optimization", False) == False:
if math.isclose(self.cfg2, 1.0):
negative_cond = None
if math.isclose(self.cfg1, 1.0):
middle_cond = None
out = comfy.samplers.calc_cond_batch(self.inner_model, [negative_cond, middle_cond, positive_cond], x, timestep, model_options)
return comfy.samplers.cfg_function(self.inner_model, out[1], out[0], self.cfg2, x, timestep, model_options=model_options, cond=middle_cond, uncond=negative_cond) + (out[2] - out[1]) * self.cfg1
class DualCFGGuider:
@classmethod
@@ -623,6 +750,7 @@ class DualCFGGuider:
"negative": ("CONDITIONING", ),
"cfg_conds": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
"cfg_cond2_negative": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
"style": (["regular", "nested"],),
}
}
@@ -631,10 +759,10 @@ class DualCFGGuider:
FUNCTION = "get_guider"
CATEGORY = "sampling/custom_sampling/guiders"
def get_guider(self, model, cond1, cond2, negative, cfg_conds, cfg_cond2_negative):
def get_guider(self, model, cond1, cond2, negative, cfg_conds, cfg_cond2_negative, style):
guider = Guider_DualCFG(model)
guider.set_conds(cond1, cond2, negative)
guider.set_cfg(cfg_conds, cfg_cond2_negative)
guider.set_cfg(cfg_conds, cfg_cond2_negative, nested=(style == "nested"))
return (guider,)
class DisableNoise:
@@ -781,11 +909,14 @@ NODE_CLASS_MAPPINGS = {
"SamplerDPMPP_SDE": SamplerDPMPP_SDE,
"SamplerDPMPP_2S_Ancestral": SamplerDPMPP_2S_Ancestral,
"SamplerDPMAdaptative": SamplerDPMAdaptative,
"SamplerER_SDE": SamplerER_SDE,
"SamplerSASolver": SamplerSASolver,
"SplitSigmas": SplitSigmas,
"SplitSigmasDenoise": SplitSigmasDenoise,
"FlipSigmas": FlipSigmas,
"SetFirstSigma": SetFirstSigma,
"ExtendIntermediateSigmas": ExtendIntermediateSigmas,
"SamplingPercentToSigma": SamplingPercentToSigma,
"CFGGuider": CFGGuider,
"DualCFGGuider": DualCFGGuider,

View File

@@ -0,0 +1,26 @@
import node_helpers
class ReferenceLatent:
@classmethod
def INPUT_TYPES(s):
return {"required": {"conditioning": ("CONDITIONING", ),
},
"optional": {"latent": ("LATENT", ),}
}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "append"
CATEGORY = "advanced/conditioning/edit_models"
DESCRIPTION = "This node sets the guiding latent for an edit model. If the model supports it you can chain multiple to set multiple reference images."
def append(self, conditioning, latent=None):
if latent is not None:
conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": [latent["samples"]]}, append=True)
return (conditioning, )
NODE_CLASS_MAPPINGS = {
"ReferenceLatent": ReferenceLatent,
}

View File

@@ -1,4 +1,5 @@
import node_helpers
import comfy.utils
class CLIPTextEncodeFlux:
@classmethod
@@ -56,8 +57,52 @@ class FluxDisableGuidance:
return (c, )
PREFERED_KONTEXT_RESOLUTIONS = [
(672, 1568),
(688, 1504),
(720, 1456),
(752, 1392),
(800, 1328),
(832, 1248),
(880, 1184),
(944, 1104),
(1024, 1024),
(1104, 944),
(1184, 880),
(1248, 832),
(1328, 800),
(1392, 752),
(1456, 720),
(1504, 688),
(1568, 672),
]
class FluxKontextImageScale:
@classmethod
def INPUT_TYPES(s):
return {"required": {"image": ("IMAGE", ),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "scale"
CATEGORY = "advanced/conditioning/flux"
DESCRIPTION = "This node resizes the image to one that is more optimal for flux kontext."
def scale(self, image):
width = image.shape[2]
height = image.shape[1]
aspect_ratio = width / height
_, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS)
image = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "lanczos", "center").movedim(1, -1)
return (image, )
NODE_CLASS_MAPPINGS = {
"CLIPTextEncodeFlux": CLIPTextEncodeFlux,
"FluxGuidance": FluxGuidance,
"FluxDisableGuidance": FluxDisableGuidance,
"FluxKontextImageScale": FluxKontextImageScale,
}

View File

@@ -71,8 +71,11 @@ class FreSca:
DESCRIPTION = "Applies frequency-dependent scaling to the guidance"
def patch(self, model, scale_low, scale_high, freq_cutoff):
def custom_cfg_function(args):
cond = args["conds_out"][0]
uncond = args["conds_out"][1]
conds_out = args["conds_out"]
if len(conds_out) <= 1 or None in args["conds"][:2]:
return conds_out
cond = conds_out[0]
uncond = conds_out[1]
guidance = cond - uncond
filtered_guidance = Fourier_filter(
@@ -83,7 +86,7 @@ class FreSca:
)
filtered_cond = filtered_guidance + uncond
return [filtered_cond, uncond]
return [filtered_cond, uncond] + conds_out[2:]
m = model.clone()
m.set_model_sampler_pre_cfg_function(custom_cfg_function)

View File

@@ -14,8 +14,10 @@ import re
from io import BytesIO
from inspect import cleandoc
import torch
import comfy.utils
from comfy.comfy_types import FileLocator
from comfy.comfy_types import FileLocator, IO
from server import PromptServer
MAX_RESOLUTION = nodes.MAX_RESOLUTION
@@ -229,6 +231,246 @@ class SVG:
all_svgs_list.extend(svg_item.data)
return SVG(all_svgs_list)
class ImageStitch:
"""Upstreamed from https://github.com/kijai/ComfyUI-KJNodes"""
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image1": ("IMAGE",),
"direction": (["right", "down", "left", "up"], {"default": "right"}),
"match_image_size": ("BOOLEAN", {"default": True}),
"spacing_width": (
"INT",
{"default": 0, "min": 0, "max": 1024, "step": 2},
),
"spacing_color": (
["white", "black", "red", "green", "blue"],
{"default": "white"},
),
},
"optional": {
"image2": ("IMAGE",),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "stitch"
CATEGORY = "image/transform"
DESCRIPTION = """
Stitches image2 to image1 in the specified direction.
If image2 is not provided, returns image1 unchanged.
Optional spacing can be added between images.
"""
def stitch(
self,
image1,
direction,
match_image_size,
spacing_width,
spacing_color,
image2=None,
):
if image2 is None:
return (image1,)
# Handle batch size differences
if image1.shape[0] != image2.shape[0]:
max_batch = max(image1.shape[0], image2.shape[0])
if image1.shape[0] < max_batch:
image1 = torch.cat(
[image1, image1[-1:].repeat(max_batch - image1.shape[0], 1, 1, 1)]
)
if image2.shape[0] < max_batch:
image2 = torch.cat(
[image2, image2[-1:].repeat(max_batch - image2.shape[0], 1, 1, 1)]
)
# Match image sizes if requested
if match_image_size:
h1, w1 = image1.shape[1:3]
h2, w2 = image2.shape[1:3]
aspect_ratio = w2 / h2
if direction in ["left", "right"]:
target_h, target_w = h1, int(h1 * aspect_ratio)
else: # up, down
target_w, target_h = w1, int(w1 / aspect_ratio)
image2 = comfy.utils.common_upscale(
image2.movedim(-1, 1), target_w, target_h, "lanczos", "disabled"
).movedim(1, -1)
color_map = {
"white": 1.0,
"black": 0.0,
"red": (1.0, 0.0, 0.0),
"green": (0.0, 1.0, 0.0),
"blue": (0.0, 0.0, 1.0),
}
color_val = color_map[spacing_color]
# When not matching sizes, pad to align non-concat dimensions
if not match_image_size:
h1, w1 = image1.shape[1:3]
h2, w2 = image2.shape[1:3]
pad_value = 0.0
if not isinstance(color_val, tuple):
pad_value = color_val
if direction in ["left", "right"]:
# For horizontal concat, pad heights to match
if h1 != h2:
target_h = max(h1, h2)
if h1 < target_h:
pad_h = target_h - h1
pad_top, pad_bottom = pad_h // 2, pad_h - pad_h // 2
image1 = torch.nn.functional.pad(image1, (0, 0, 0, 0, pad_top, pad_bottom), mode='constant', value=pad_value)
if h2 < target_h:
pad_h = target_h - h2
pad_top, pad_bottom = pad_h // 2, pad_h - pad_h // 2
image2 = torch.nn.functional.pad(image2, (0, 0, 0, 0, pad_top, pad_bottom), mode='constant', value=pad_value)
else: # up, down
# For vertical concat, pad widths to match
if w1 != w2:
target_w = max(w1, w2)
if w1 < target_w:
pad_w = target_w - w1
pad_left, pad_right = pad_w // 2, pad_w - pad_w // 2
image1 = torch.nn.functional.pad(image1, (0, 0, pad_left, pad_right), mode='constant', value=pad_value)
if w2 < target_w:
pad_w = target_w - w2
pad_left, pad_right = pad_w // 2, pad_w - pad_w // 2
image2 = torch.nn.functional.pad(image2, (0, 0, pad_left, pad_right), mode='constant', value=pad_value)
# Ensure same number of channels
if image1.shape[-1] != image2.shape[-1]:
max_channels = max(image1.shape[-1], image2.shape[-1])
if image1.shape[-1] < max_channels:
image1 = torch.cat(
[
image1,
torch.ones(
*image1.shape[:-1],
max_channels - image1.shape[-1],
device=image1.device,
),
],
dim=-1,
)
if image2.shape[-1] < max_channels:
image2 = torch.cat(
[
image2,
torch.ones(
*image2.shape[:-1],
max_channels - image2.shape[-1],
device=image2.device,
),
],
dim=-1,
)
# Add spacing if specified
if spacing_width > 0:
spacing_width = spacing_width + (spacing_width % 2) # Ensure even
if direction in ["left", "right"]:
spacing_shape = (
image1.shape[0],
max(image1.shape[1], image2.shape[1]),
spacing_width,
image1.shape[-1],
)
else:
spacing_shape = (
image1.shape[0],
spacing_width,
max(image1.shape[2], image2.shape[2]),
image1.shape[-1],
)
spacing = torch.full(spacing_shape, 0.0, device=image1.device)
if isinstance(color_val, tuple):
for i, c in enumerate(color_val):
if i < spacing.shape[-1]:
spacing[..., i] = c
if spacing.shape[-1] == 4: # Add alpha
spacing[..., 3] = 1.0
else:
spacing[..., : min(3, spacing.shape[-1])] = color_val
if spacing.shape[-1] == 4:
spacing[..., 3] = 1.0
# Concatenate images
images = [image2, image1] if direction in ["left", "up"] else [image1, image2]
if spacing_width > 0:
images.insert(1, spacing)
concat_dim = 2 if direction in ["left", "right"] else 1
return (torch.cat(images, dim=concat_dim),)
class ResizeAndPadImage:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": ("IMAGE",),
"target_width": ("INT", {
"default": 512,
"min": 1,
"max": MAX_RESOLUTION,
"step": 1
}),
"target_height": ("INT", {
"default": 512,
"min": 1,
"max": MAX_RESOLUTION,
"step": 1
}),
"padding_color": (["white", "black"],),
"interpolation": (["area", "bicubic", "nearest-exact", "bilinear", "lanczos"],),
}
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "resize_and_pad"
CATEGORY = "image/transform"
def resize_and_pad(self, image, target_width, target_height, padding_color, interpolation):
batch_size, orig_height, orig_width, channels = image.shape
scale_w = target_width / orig_width
scale_h = target_height / orig_height
scale = min(scale_w, scale_h)
new_width = int(orig_width * scale)
new_height = int(orig_height * scale)
image_permuted = image.permute(0, 3, 1, 2)
resized = comfy.utils.common_upscale(image_permuted, new_width, new_height, interpolation, "disabled")
pad_value = 0.0 if padding_color == "black" else 1.0
padded = torch.full(
(batch_size, channels, target_height, target_width),
pad_value,
dtype=image.dtype,
device=image.device
)
y_offset = (target_height - new_height) // 2
x_offset = (target_width - new_width) // 2
padded[:, :, y_offset:y_offset + new_height, x_offset:x_offset + new_width] = resized
output = padded.permute(0, 2, 3, 1)
return (output,)
class SaveSVGNode:
"""
Save SVG files on disk.
@@ -310,6 +552,80 @@ class SaveSVGNode:
counter += 1
return { "ui": { "images": results } }
class GetImageSize:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": (IO.IMAGE,),
},
"hidden": {
"unique_id": "UNIQUE_ID",
}
}
RETURN_TYPES = (IO.INT, IO.INT, IO.INT)
RETURN_NAMES = ("width", "height", "batch_size")
FUNCTION = "get_size"
CATEGORY = "image"
DESCRIPTION = """Returns width and height of the image, and passes it through unchanged."""
def get_size(self, image, unique_id=None) -> tuple[int, int]:
height = image.shape[1]
width = image.shape[2]
batch_size = image.shape[0]
# Send progress text to display size on the node
if unique_id:
PromptServer.instance.send_progress_text(f"width: {width}, height: {height}\n batch size: {batch_size}", unique_id)
return width, height, batch_size
class ImageRotate:
@classmethod
def INPUT_TYPES(s):
return {"required": { "image": (IO.IMAGE,),
"rotation": (["none", "90 degrees", "180 degrees", "270 degrees"],),
}}
RETURN_TYPES = (IO.IMAGE,)
FUNCTION = "rotate"
CATEGORY = "image/transform"
def rotate(self, image, rotation):
rotate_by = 0
if rotation.startswith("90"):
rotate_by = 1
elif rotation.startswith("180"):
rotate_by = 2
elif rotation.startswith("270"):
rotate_by = 3
image = torch.rot90(image, k=rotate_by, dims=[2, 1])
return (image,)
class ImageFlip:
@classmethod
def INPUT_TYPES(s):
return {"required": { "image": (IO.IMAGE,),
"flip_method": (["x-axis: vertically", "y-axis: horizontally"],),
}}
RETURN_TYPES = (IO.IMAGE,)
FUNCTION = "flip"
CATEGORY = "image/transform"
def flip(self, image, flip_method):
if flip_method.startswith("x"):
image = torch.flip(image, dims=[1])
elif flip_method.startswith("y"):
image = torch.flip(image, dims=[2])
return (image,)
NODE_CLASS_MAPPINGS = {
"ImageCrop": ImageCrop,
"RepeatImageBatch": RepeatImageBatch,
@@ -318,4 +634,9 @@ NODE_CLASS_MAPPINGS = {
"SaveAnimatedWEBP": SaveAnimatedWEBP,
"SaveAnimatedPNG": SaveAnimatedPNG,
"SaveSVGNode": SaveSVGNode,
"ImageStitch": ImageStitch,
"ResizeAndPadImage": ResizeAndPadImage,
"GetImageSize": GetImageSize,
"ImageRotate": ImageRotate,
"ImageFlip": ImageFlip,
}

View File

@@ -5,6 +5,8 @@ import os
from comfy.comfy_types import IO
from comfy_api.input_impl import VideoFromFile
from pathlib import Path
def normalize_path(path):
return path.replace('\\', '/')
@@ -16,7 +18,14 @@ class Load3D():
os.makedirs(input_dir, exist_ok=True)
files = [normalize_path(os.path.join("3d", f)) for f in os.listdir(input_dir) if f.endswith(('.gltf', '.glb', '.obj', '.fbx', '.stl'))]
input_path = Path(input_dir)
base_path = Path(folder_paths.get_input_directory())
files = [
normalize_path(str(file_path.relative_to(base_path)))
for file_path in input_path.rglob("*")
if file_path.suffix.lower() in {'.gltf', '.glb', '.obj', '.fbx', '.stl'}
]
return {"required": {
"model_file": (sorted(files), {"file_upload": True}),
@@ -61,7 +70,14 @@ class Load3DAnimation():
os.makedirs(input_dir, exist_ok=True)
files = [normalize_path(os.path.join("3d", f)) for f in os.listdir(input_dir) if f.endswith(('.gltf', '.glb', '.fbx'))]
input_path = Path(input_dir)
base_path = Path(folder_paths.get_input_directory())
files = [
normalize_path(str(file_path.relative_to(base_path)))
for file_path in input_path.rglob("*")
if file_path.suffix.lower() in {'.gltf', '.glb', '.fbx'}
]
return {"required": {
"model_file": (sorted(files), {"file_upload": True}),

View File

@@ -134,8 +134,8 @@ class LTXVAddGuide:
_, num_keyframes = get_keyframe_idxs(cond)
latent_count = latent_length - num_keyframes
frame_idx = frame_idx if frame_idx >= 0 else max((latent_count - 1) * time_scale_factor + 1 + frame_idx, 0)
if guide_length > 1:
frame_idx = frame_idx // time_scale_factor * time_scale_factor # frame index must be divisible by 8
if guide_length > 1 and frame_idx != 0:
frame_idx = (frame_idx - 1) // time_scale_factor * time_scale_factor + 1 # frame index - 1 must be divisible by 8 or frame_idx == 0
latent_idx = (frame_idx + time_scale_factor - 1) // time_scale_factor
@@ -144,7 +144,7 @@ class LTXVAddGuide:
def add_keyframe_index(self, cond, frame_idx, guiding_latent, scale_factors):
keyframe_idxs, _ = get_keyframe_idxs(cond)
_, latent_coords = self._patchifier.patchify(guiding_latent)
pixel_coords = latent_to_pixel_coords(latent_coords, scale_factors, True)
pixel_coords = latent_to_pixel_coords(latent_coords, scale_factors, causal_fix=frame_idx == 0) # we need the causal fix only if we're placing the new latents at index 0
pixel_coords[:, 0] += frame_idx
if keyframe_idxs is None:
keyframe_idxs = pixel_coords

View File

@@ -152,7 +152,7 @@ class ImageColorToMask:
def image_to_mask(self, image, color):
temp = (torch.clamp(image, 0, 1.0) * 255.0).round().to(torch.int)
temp = torch.bitwise_left_shift(temp[:,:,:,0], 16) + torch.bitwise_left_shift(temp[:,:,:,1], 8) + temp[:,:,:,2]
mask = torch.where(temp == color, 255, 0).float()
mask = torch.where(temp == color, 1.0, 0).float()
return (mask,)
class SolidMask:
@@ -247,7 +247,7 @@ class MaskComposite:
visible_width, visible_height = (right - left, bottom - top,)
source_portion = source[:, :visible_height, :visible_width]
destination_portion = destination[:, top:bottom, left:right]
destination_portion = output[:, top:bottom, left:right]
if operation == "multiply":
output[:, top:bottom, left:right] = destination_portion * source_portion

View File

@@ -189,7 +189,7 @@ class ModelSamplingContinuousEDM:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"sampling": (["v_prediction", "edm", "edm_playground_v2.5", "eps"],),
"sampling": (["v_prediction", "edm", "edm_playground_v2.5", "eps", "cosmos_rflow"],),
"sigma_max": ("FLOAT", {"default": 120.0, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
"sigma_min": ("FLOAT", {"default": 0.002, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
}}
@@ -202,6 +202,7 @@ class ModelSamplingContinuousEDM:
def patch(self, model, sampling, sigma_max, sigma_min):
m = model.clone()
sampling_base = comfy.model_sampling.ModelSamplingContinuousEDM
latent_format = None
sigma_data = 1.0
if sampling == "eps":
@@ -215,8 +216,11 @@ class ModelSamplingContinuousEDM:
sampling_type = comfy.model_sampling.EDM
sigma_data = 0.5
latent_format = comfy.latent_formats.SDXL_Playground_2_5()
elif sampling == "cosmos_rflow":
sampling_type = comfy.model_sampling.COSMOS_RFLOW
sampling_base = comfy.model_sampling.ModelSamplingCosmosRFlow
class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingContinuousEDM, sampling_type):
class ModelSamplingAdvanced(sampling_base, sampling_type):
pass
model_sampling = ModelSamplingAdvanced(model.model.model_config)

View File

@@ -268,6 +268,52 @@ class ModelMergeWAN2_1(comfy_extras.nodes_model_merging.ModelMergeBlocks):
return {"required": arg_dict}
class ModelMergeCosmosPredict2_2B(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
arg_dict["pos_embedder."] = argument
arg_dict["x_embedder."] = argument
arg_dict["t_embedder."] = argument
arg_dict["t_embedding_norm."] = argument
for i in range(28):
arg_dict["blocks.{}.".format(i)] = argument
arg_dict["final_layer."] = argument
return {"required": arg_dict}
class ModelMergeCosmosPredict2_14B(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
arg_dict["pos_embedder."] = argument
arg_dict["x_embedder."] = argument
arg_dict["t_embedder."] = argument
arg_dict["t_embedding_norm."] = argument
for i in range(36):
arg_dict["blocks.{}.".format(i)] = argument
arg_dict["final_layer."] = argument
return {"required": arg_dict}
NODE_CLASS_MAPPINGS = {
"ModelMergeSD1": ModelMergeSD1,
"ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks
@@ -281,4 +327,6 @@ NODE_CLASS_MAPPINGS = {
"ModelMergeCosmos7B": ModelMergeCosmos7B,
"ModelMergeCosmos14B": ModelMergeCosmos14B,
"ModelMergeWAN2_1": ModelMergeWAN2_1,
"ModelMergeCosmosPredict2_2B": ModelMergeCosmosPredict2_2B,
"ModelMergeCosmosPredict2_14B": ModelMergeCosmosPredict2_14B,
}

View File

@@ -4,6 +4,7 @@ import comfy.sampler_helpers
import comfy.samplers
import comfy.utils
import node_helpers
import math
def perp_neg(x, noise_pred_pos, noise_pred_neg, noise_pred_nocond, neg_scale, cond_scale):
pos = noise_pred_pos - noise_pred_nocond
@@ -69,8 +70,23 @@ class Guider_PerpNeg(comfy.samplers.CFGGuider):
negative_cond = self.conds.get("negative", None)
empty_cond = self.conds.get("empty_negative_prompt", None)
(noise_pred_pos, noise_pred_neg, noise_pred_empty) = \
comfy.samplers.calc_cond_batch(self.inner_model, [positive_cond, negative_cond, empty_cond], x, timestep, model_options)
if model_options.get("disable_cfg1_optimization", False) == False:
if math.isclose(self.neg_scale, 0.0):
negative_cond = None
if math.isclose(self.cfg, 1.0):
empty_cond = None
conds = [positive_cond, negative_cond, empty_cond]
out = comfy.samplers.calc_cond_batch(self.inner_model, conds, x, timestep, model_options)
# Apply pre_cfg_functions since sampling_function() is skipped
for fn in model_options.get("sampler_pre_cfg_function", []):
args = {"conds":conds, "conds_out": out, "cond_scale": self.cfg, "timestep": timestep,
"input": x, "sigma": timestep, "model": self.inner_model, "model_options": model_options}
out = fn(args)
noise_pred_pos, noise_pred_neg, noise_pred_empty = out
cfg_result = perp_neg(x, noise_pred_pos, noise_pred_neg, noise_pred_empty, self.neg_scale, self.cfg)
# normally this would be done in cfg_function, but we skipped
@@ -82,6 +98,7 @@ class Guider_PerpNeg(comfy.samplers.CFGGuider):
"denoised": cfg_result,
"cond": positive_cond,
"uncond": negative_cond,
"cond_scale": self.cfg,
"model": self.inner_model,
"uncond_denoised": noise_pred_neg,
"cond_denoised": noise_pred_pos,

View File

@@ -1,24 +1,24 @@
from nodes import MAX_RESOLUTION
class CLIPTextEncodePixArtAlpha:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
"height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
# "aspect_ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"text": ("STRING", {"multiline": True, "dynamicPrompts": True}), "clip": ("CLIP", ),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "encode"
CATEGORY = "advanced/conditioning"
DESCRIPTION = "Encodes text and sets the resolution conditioning for PixArt Alpha. Does not apply to PixArt Sigma."
def encode(self, clip, width, height, text):
tokens = clip.tokenize(text)
return (clip.encode_from_tokens_scheduled(tokens, add_dict={"width": width, "height": height}),)
NODE_CLASS_MAPPINGS = {
"CLIPTextEncodePixArtAlpha": CLIPTextEncodePixArtAlpha,
}
from nodes import MAX_RESOLUTION
class CLIPTextEncodePixArtAlpha:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
"height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
# "aspect_ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"text": ("STRING", {"multiline": True, "dynamicPrompts": True}), "clip": ("CLIP", ),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "encode"
CATEGORY = "advanced/conditioning"
DESCRIPTION = "Encodes text and sets the resolution conditioning for PixArt Alpha. Does not apply to PixArt Sigma."
def encode(self, clip, width, height, text):
tokens = clip.tokenize(text)
return (clip.encode_from_tokens_scheduled(tokens, add_dict={"width": width, "height": height}),)
NODE_CLASS_MAPPINGS = {
"CLIPTextEncodePixArtAlpha": CLIPTextEncodePixArtAlpha,
}

View File

@@ -78,7 +78,75 @@ class SkipLayerGuidanceDiT:
return (m, )
class SkipLayerGuidanceDiTSimple:
'''
Simple version of the SkipLayerGuidanceDiT node that only modifies the uncond pass.
'''
@classmethod
def INPUT_TYPES(s):
return {"required": {"model": ("MODEL", ),
"double_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}),
"single_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}),
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "skip_guidance"
EXPERIMENTAL = True
DESCRIPTION = "Simple version of the SkipLayerGuidanceDiT node that only modifies the uncond pass."
CATEGORY = "advanced/guidance"
def skip_guidance(self, model, start_percent, end_percent, double_layers="", single_layers=""):
def skip(args, extra_args):
return args
model_sampling = model.get_model_object("model_sampling")
sigma_start = model_sampling.percent_to_sigma(start_percent)
sigma_end = model_sampling.percent_to_sigma(end_percent)
double_layers = re.findall(r'\d+', double_layers)
double_layers = [int(i) for i in double_layers]
single_layers = re.findall(r'\d+', single_layers)
single_layers = [int(i) for i in single_layers]
if len(double_layers) == 0 and len(single_layers) == 0:
return (model, )
def calc_cond_batch_function(args):
x = args["input"]
model = args["model"]
conds = args["conds"]
sigma = args["sigma"]
model_options = args["model_options"]
slg_model_options = model_options.copy()
for layer in double_layers:
slg_model_options = comfy.model_patcher.set_model_options_patch_replace(slg_model_options, skip, "dit", "double_block", layer)
for layer in single_layers:
slg_model_options = comfy.model_patcher.set_model_options_patch_replace(slg_model_options, skip, "dit", "single_block", layer)
cond, uncond = conds
sigma_ = sigma[0].item()
if sigma_ >= sigma_end and sigma_ <= sigma_start and uncond is not None:
cond_out, _ = comfy.samplers.calc_cond_batch(model, [cond, None], x, sigma, model_options)
_, uncond_out = comfy.samplers.calc_cond_batch(model, [None, uncond], x, sigma, slg_model_options)
out = [cond_out, uncond_out]
else:
out = comfy.samplers.calc_cond_batch(model, conds, x, sigma, model_options)
return out
m = model.clone()
m.set_model_sampler_calc_cond_batch_function(calc_cond_batch_function)
return (m, )
NODE_CLASS_MAPPINGS = {
"SkipLayerGuidanceDiT": SkipLayerGuidanceDiT,
"SkipLayerGuidanceDiTSimple": SkipLayerGuidanceDiTSimple,
}

View File

@@ -296,6 +296,41 @@ class RegexExtract():
return result,
class RegexReplace():
DESCRIPTION = "Find and replace text using regex patterns."
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"string": (IO.STRING, {"multiline": True}),
"regex_pattern": (IO.STRING, {"multiline": True}),
"replace": (IO.STRING, {"multiline": True}),
},
"optional": {
"case_insensitive": (IO.BOOLEAN, {"default": True}),
"multiline": (IO.BOOLEAN, {"default": False}),
"dotall": (IO.BOOLEAN, {"default": False, "tooltip": "When enabled, the dot (.) character will match any character including newline characters. When disabled, dots won't match newlines."}),
"count": (IO.INT, {"default": 0, "min": 0, "max": 100, "tooltip": "Maximum number of replacements to make. Set to 0 to replace all occurrences (default). Set to 1 to replace only the first match, 2 for the first two matches, etc."}),
}
}
RETURN_TYPES = (IO.STRING,)
FUNCTION = "execute"
CATEGORY = "utils/string"
def execute(self, string, regex_pattern, replace, case_insensitive=True, multiline=False, dotall=False, count=0, **kwargs):
flags = 0
if case_insensitive:
flags |= re.IGNORECASE
if multiline:
flags |= re.MULTILINE
if dotall:
flags |= re.DOTALL
result = re.sub(regex_pattern, replace, string, count=count, flags=flags)
return result,
NODE_CLASS_MAPPINGS = {
"StringConcatenate": StringConcatenate,
"StringSubstring": StringSubstring,
@@ -306,7 +341,8 @@ NODE_CLASS_MAPPINGS = {
"StringContains": StringContains,
"StringCompare": StringCompare,
"RegexMatch": RegexMatch,
"RegexExtract": RegexExtract
"RegexExtract": RegexExtract,
"RegexReplace": RegexReplace,
}
NODE_DISPLAY_NAME_MAPPINGS = {
@@ -319,5 +355,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"StringContains": "Contains",
"StringCompare": "Compare",
"RegexMatch": "Regex Match",
"RegexExtract": "Regex Extract"
"RegexExtract": "Regex Extract",
"RegexReplace": "Regex Replace",
}

View File

@@ -0,0 +1,71 @@
# TCFG: Tangential Damping Classifier-free Guidance - (arXiv: https://arxiv.org/abs/2503.18137)
import torch
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict
def score_tangential_damping(cond_score: torch.Tensor, uncond_score: torch.Tensor) -> torch.Tensor:
"""Drop tangential components from uncond score to align with cond score."""
# (B, 1, ...)
batch_num = cond_score.shape[0]
cond_score_flat = cond_score.reshape(batch_num, 1, -1).float()
uncond_score_flat = uncond_score.reshape(batch_num, 1, -1).float()
# Score matrix A (B, 2, ...)
score_matrix = torch.cat((uncond_score_flat, cond_score_flat), dim=1)
try:
_, _, Vh = torch.linalg.svd(score_matrix, full_matrices=False)
except RuntimeError:
# Fallback to CPU
_, _, Vh = torch.linalg.svd(score_matrix.cpu(), full_matrices=False)
# Drop the tangential components
v1 = Vh[:, 0:1, :].to(uncond_score_flat.device) # (B, 1, ...)
uncond_score_td = (uncond_score_flat @ v1.transpose(-2, -1)) * v1
return uncond_score_td.reshape_as(uncond_score).to(uncond_score.dtype)
class TCFG(ComfyNodeABC):
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {
"model": (IO.MODEL, {}),
}
}
RETURN_TYPES = (IO.MODEL,)
RETURN_NAMES = ("patched_model",)
FUNCTION = "patch"
CATEGORY = "advanced/guidance"
DESCRIPTION = "TCFG Tangential Damping CFG (2503.18137)\n\nRefine the uncond (negative) to align with the cond (positive) for improving quality."
def patch(self, model):
m = model.clone()
def tangential_damping_cfg(args):
# Assume [cond, uncond, ...]
x = args["input"]
conds_out = args["conds_out"]
if len(conds_out) <= 1 or None in args["conds"][:2]:
# Skip when either cond or uncond is None
return conds_out
cond_pred = conds_out[0]
uncond_pred = conds_out[1]
uncond_td = score_tangential_damping(x - cond_pred, x - uncond_pred)
uncond_pred_td = x - uncond_td
return [cond_pred, uncond_pred_td] + conds_out[2:]
m.set_model_sampler_pre_cfg_function(tangential_damping_cfg)
return (m,)
NODE_CLASS_MAPPINGS = {
"TCFG": TCFG,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"TCFG": "Tangential Damping CFG",
}

852
comfy_extras/nodes_train.py Normal file
View File

@@ -0,0 +1,852 @@
import datetime
import json
import logging
import os
import numpy as np
import safetensors
import torch
from PIL import Image, ImageDraw, ImageFont
from PIL.PngImagePlugin import PngInfo
import torch.utils.checkpoint
import tqdm
import comfy.samplers
import comfy.sd
import comfy.utils
import comfy.model_management
import comfy_extras.nodes_custom_sampler
import folder_paths
import node_helpers
from comfy.cli_args import args
from comfy.comfy_types.node_typing import IO
from comfy.weight_adapter import adapters
def make_batch_extra_option_dict(d, indicies, full_size=None):
new_dict = {}
for k, v in d.items():
newv = v
if isinstance(v, dict):
newv = make_batch_extra_option_dict(v, indicies, full_size=full_size)
elif isinstance(v, torch.Tensor):
if full_size is None or v.size(0) == full_size:
newv = v[indicies]
elif isinstance(v, (list, tuple)) and len(v) == full_size:
newv = [v[i] for i in indicies]
new_dict[k] = newv
return new_dict
class TrainSampler(comfy.samplers.Sampler):
def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, total_steps=1, seed=0, training_dtype=torch.bfloat16):
self.loss_fn = loss_fn
self.optimizer = optimizer
self.loss_callback = loss_callback
self.batch_size = batch_size
self.total_steps = total_steps
self.seed = seed
self.training_dtype = training_dtype
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
cond = model_wrap.conds["positive"]
dataset_size = sigmas.size(0)
torch.cuda.empty_cache()
for i in (pbar:=tqdm.trange(self.total_steps, desc="Training LoRA", smoothing=0.01, disable=not comfy.utils.PROGRESS_BAR_ENABLED)):
noisegen = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(self.seed + i * 1000)
indicies = torch.randperm(dataset_size)[:self.batch_size].tolist()
batch_latent = torch.stack([latent_image[i] for i in indicies])
batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(batch_latent.device)
batch_sigmas = [
model_wrap.inner_model.model_sampling.percent_to_sigma(
torch.rand((1,)).item()
) for _ in range(min(self.batch_size, dataset_size))
]
batch_sigmas = torch.tensor(batch_sigmas).to(batch_latent.device)
xt = model_wrap.inner_model.model_sampling.noise_scaling(
batch_sigmas,
batch_noise,
batch_latent,
False
)
x0 = model_wrap.inner_model.model_sampling.noise_scaling(
torch.zeros_like(batch_sigmas),
torch.zeros_like(batch_noise),
batch_latent,
False
)
model_wrap.conds["positive"] = [
cond[i] for i in indicies
]
batch_extra_args = make_batch_extra_option_dict(extra_args, indicies, full_size=dataset_size)
with torch.autocast(xt.device.type, dtype=self.training_dtype):
x0_pred = model_wrap(xt, batch_sigmas, **batch_extra_args)
loss = self.loss_fn(x0_pred, x0)
loss.backward()
if self.loss_callback:
self.loss_callback(loss.item())
pbar.set_postfix({"loss": f"{loss.item():.4f}"})
self.optimizer.step()
self.optimizer.zero_grad()
torch.cuda.empty_cache()
return torch.zeros_like(latent_image)
class BiasDiff(torch.nn.Module):
def __init__(self, bias):
super().__init__()
self.bias = bias
def __call__(self, b):
org_dtype = b.dtype
return (b.to(self.bias) + self.bias).to(org_dtype)
def passive_memory_usage(self):
return self.bias.nelement() * self.bias.element_size()
def move_to(self, device):
self.to(device=device)
return self.passive_memory_usage()
def load_and_process_images(image_files, input_dir, resize_method="None", w=None, h=None):
"""Utility function to load and process a list of images.
Args:
image_files: List of image filenames
input_dir: Base directory containing the images
resize_method: How to handle images of different sizes ("None", "Stretch", "Crop", "Pad")
Returns:
torch.Tensor: Batch of processed images
"""
if not image_files:
raise ValueError("No valid images found in input")
output_images = []
for file in image_files:
image_path = os.path.join(input_dir, file)
img = node_helpers.pillow(Image.open, image_path)
if img.mode == "I":
img = img.point(lambda i: i * (1 / 255))
img = img.convert("RGB")
if w is None and h is None:
w, h = img.size[0], img.size[1]
# Resize image to first image
if img.size[0] != w or img.size[1] != h:
if resize_method == "Stretch":
img = img.resize((w, h), Image.Resampling.LANCZOS)
elif resize_method == "Crop":
img = img.crop((0, 0, w, h))
elif resize_method == "Pad":
img = img.resize((w, h), Image.Resampling.LANCZOS)
elif resize_method == "None":
raise ValueError(
"Your input image size does not match the first image in the dataset. Either select a valid resize method or use the same size for all images."
)
img_array = np.array(img).astype(np.float32) / 255.0
img_tensor = torch.from_numpy(img_array)[None,]
output_images.append(img_tensor)
return torch.cat(output_images, dim=0)
class LoadImageSetNode:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"images": (
[
f
for f in os.listdir(folder_paths.get_input_directory())
if f.endswith((".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif", ".jpe", ".apng", ".tif", ".tiff"))
],
{"image_upload": True, "allow_batch": True},
)
},
"optional": {
"resize_method": (
["None", "Stretch", "Crop", "Pad"],
{"default": "None"},
),
},
}
INPUT_IS_LIST = True
RETURN_TYPES = ("IMAGE",)
FUNCTION = "load_images"
CATEGORY = "loaders"
EXPERIMENTAL = True
DESCRIPTION = "Loads a batch of images from a directory for training."
@classmethod
def VALIDATE_INPUTS(s, images, resize_method):
filenames = images[0] if isinstance(images[0], list) else images
for image in filenames:
if not folder_paths.exists_annotated_filepath(image):
return "Invalid image file: {}".format(image)
return True
def load_images(self, input_files, resize_method):
input_dir = folder_paths.get_input_directory()
valid_extensions = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif", ".jpe", ".apng", ".tif", ".tiff"]
image_files = [
f
for f in input_files
if any(f.lower().endswith(ext) for ext in valid_extensions)
]
output_tensor = load_and_process_images(image_files, input_dir, resize_method)
return (output_tensor,)
class LoadImageSetFromFolderNode:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"folder": (folder_paths.get_input_subfolders(), {"tooltip": "The folder to load images from."})
},
"optional": {
"resize_method": (
["None", "Stretch", "Crop", "Pad"],
{"default": "None"},
),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "load_images"
CATEGORY = "loaders"
EXPERIMENTAL = True
DESCRIPTION = "Loads a batch of images from a directory for training."
def load_images(self, folder, resize_method):
sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder)
valid_extensions = [".png", ".jpg", ".jpeg", ".webp"]
image_files = [
f
for f in os.listdir(sub_input_dir)
if any(f.lower().endswith(ext) for ext in valid_extensions)
]
output_tensor = load_and_process_images(image_files, sub_input_dir, resize_method)
return (output_tensor,)
class LoadImageTextSetFromFolderNode:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"folder": (folder_paths.get_input_subfolders(), {"tooltip": "The folder to load images from."}),
"clip": (IO.CLIP, {"tooltip": "The CLIP model used for encoding the text."}),
},
"optional": {
"resize_method": (
["None", "Stretch", "Crop", "Pad"],
{"default": "None"},
),
"width": (
IO.INT,
{
"default": -1,
"min": -1,
"max": 10000,
"step": 1,
"tooltip": "The width to resize the images to. -1 means use the original width.",
},
),
"height": (
IO.INT,
{
"default": -1,
"min": -1,
"max": 10000,
"step": 1,
"tooltip": "The height to resize the images to. -1 means use the original height.",
},
)
},
}
RETURN_TYPES = ("IMAGE", IO.CONDITIONING,)
FUNCTION = "load_images"
CATEGORY = "loaders"
EXPERIMENTAL = True
DESCRIPTION = "Loads a batch of images and caption from a directory for training."
def load_images(self, folder, clip, resize_method, width=None, height=None):
if clip is None:
raise RuntimeError("ERROR: clip input is invalid: None\n\nIf the clip is from a checkpoint loader node your checkpoint does not contain a valid clip or text encoder model.")
logging.info(f"Loading images from folder: {folder}")
sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder)
valid_extensions = [".png", ".jpg", ".jpeg", ".webp"]
image_files = []
for item in os.listdir(sub_input_dir):
path = os.path.join(sub_input_dir, item)
if any(item.lower().endswith(ext) for ext in valid_extensions):
image_files.append(path)
elif os.path.isdir(path):
# Support kohya-ss/sd-scripts folder structure
repeat = 1
if item.split("_")[0].isdigit():
repeat = int(item.split("_")[0])
image_files.extend([
os.path.join(path, f) for f in os.listdir(path) if any(f.lower().endswith(ext) for ext in valid_extensions)
] * repeat)
caption_file_path = [
f.replace(os.path.splitext(f)[1], ".txt")
for f in image_files
]
captions = []
for caption_file in caption_file_path:
caption_path = os.path.join(sub_input_dir, caption_file)
if os.path.exists(caption_path):
with open(caption_path, "r", encoding="utf-8") as f:
caption = f.read().strip()
captions.append(caption)
else:
captions.append("")
width = width if width != -1 else None
height = height if height != -1 else None
output_tensor = load_and_process_images(image_files, sub_input_dir, resize_method, width, height)
logging.info(f"Loaded {len(output_tensor)} images from {sub_input_dir}.")
logging.info(f"Encoding captions from {sub_input_dir}.")
conditions = []
empty_cond = clip.encode_from_tokens_scheduled(clip.tokenize(""))
for text in captions:
if text == "":
conditions.append(empty_cond)
tokens = clip.tokenize(text)
conditions.extend(clip.encode_from_tokens_scheduled(tokens))
logging.info(f"Encoded {len(conditions)} captions from {sub_input_dir}.")
return (output_tensor, conditions)
def draw_loss_graph(loss_map, steps):
width, height = 500, 300
img = Image.new("RGB", (width, height), "white")
draw = ImageDraw.Draw(img)
min_loss, max_loss = min(loss_map.values()), max(loss_map.values())
scaled_loss = [(l - min_loss) / (max_loss - min_loss) for l in loss_map.values()]
prev_point = (0, height - int(scaled_loss[0] * height))
for i, l in enumerate(scaled_loss[1:], start=1):
x = int(i / (steps - 1) * width)
y = height - int(l * height)
draw.line([prev_point, (x, y)], fill="blue", width=2)
prev_point = (x, y)
return img
def find_all_highest_child_module_with_forward(model: torch.nn.Module, result = None, name = None):
if result is None:
result = []
elif hasattr(model, "forward") and not isinstance(model, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict)):
result.append(model)
logging.debug(f"Found module with forward: {name} ({model.__class__.__name__})")
return result
name = name or "root"
for next_name, child in model.named_children():
find_all_highest_child_module_with_forward(child, result, f"{name}.{next_name}")
return result
def patch(m):
if not hasattr(m, "forward"):
return
org_forward = m.forward
def fwd(args, kwargs):
return org_forward(*args, **kwargs)
def checkpointing_fwd(*args, **kwargs):
return torch.utils.checkpoint.checkpoint(
fwd, args, kwargs, use_reentrant=False
)
m.org_forward = org_forward
m.forward = checkpointing_fwd
def unpatch(m):
if hasattr(m, "org_forward"):
m.forward = m.org_forward
del m.org_forward
class TrainLoraNode:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model": (IO.MODEL, {"tooltip": "The model to train the LoRA on."}),
"latents": (
"LATENT",
{
"tooltip": "The Latents to use for training, serve as dataset/input of the model."
},
),
"positive": (
IO.CONDITIONING,
{"tooltip": "The positive conditioning to use for training."},
),
"batch_size": (
IO.INT,
{
"default": 1,
"min": 1,
"max": 10000,
"step": 1,
"tooltip": "The batch size to use for training.",
},
),
"steps": (
IO.INT,
{
"default": 16,
"min": 1,
"max": 100000,
"tooltip": "The number of steps to train the LoRA for.",
},
),
"learning_rate": (
IO.FLOAT,
{
"default": 0.0005,
"min": 0.0000001,
"max": 1.0,
"step": 0.000001,
"tooltip": "The learning rate to use for training.",
},
),
"rank": (
IO.INT,
{
"default": 8,
"min": 1,
"max": 128,
"tooltip": "The rank of the LoRA layers.",
},
),
"optimizer": (
["AdamW", "Adam", "SGD", "RMSprop"],
{
"default": "AdamW",
"tooltip": "The optimizer to use for training.",
},
),
"loss_function": (
["MSE", "L1", "Huber", "SmoothL1"],
{
"default": "MSE",
"tooltip": "The loss function to use for training.",
},
),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 0xFFFFFFFFFFFFFFFF,
"tooltip": "The seed to use for training (used in generator for LoRA weight initialization and noise sampling)",
},
),
"training_dtype": (
["bf16", "fp32"],
{"default": "bf16", "tooltip": "The dtype to use for training."},
),
"lora_dtype": (
["bf16", "fp32"],
{"default": "bf16", "tooltip": "The dtype to use for lora."},
),
"existing_lora": (
folder_paths.get_filename_list("loras") + ["[None]"],
{
"default": "[None]",
"tooltip": "The existing LoRA to append to. Set to None for new LoRA.",
},
),
},
}
RETURN_TYPES = (IO.MODEL, IO.LORA_MODEL, IO.LOSS_MAP, IO.INT)
RETURN_NAMES = ("model_with_lora", "lora", "loss", "steps")
FUNCTION = "train"
CATEGORY = "training"
EXPERIMENTAL = True
def train(
self,
model,
latents,
positive,
batch_size,
steps,
learning_rate,
rank,
optimizer,
loss_function,
seed,
training_dtype,
lora_dtype,
existing_lora,
):
mp = model.clone()
dtype = node_helpers.string_to_torch_dtype(training_dtype)
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
mp.set_model_compute_dtype(dtype)
latents = latents["samples"].to(dtype)
num_images = latents.shape[0]
logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}")
if len(positive) == 1 and num_images > 1:
positive = positive * num_images
elif len(positive) != num_images:
raise ValueError(
f"Number of positive conditions ({len(positive)}) does not match number of images ({num_images})."
)
with torch.inference_mode(False):
lora_sd = {}
generator = torch.Generator()
generator.manual_seed(seed)
# Load existing LoRA weights if provided
existing_weights = {}
existing_steps = 0
if existing_lora != "[None]":
lora_path = folder_paths.get_full_path_or_raise("loras", existing_lora)
# Extract steps from filename like "trained_lora_10_steps_20250225_203716"
existing_steps = int(existing_lora.split("_steps_")[0].split("_")[-1])
if lora_path:
existing_weights = comfy.utils.load_torch_file(lora_path)
all_weight_adapters = []
for n, m in mp.model.named_modules():
if hasattr(m, "weight_function"):
if m.weight is not None:
key = "{}.weight".format(n)
shape = m.weight.shape
if len(shape) >= 2:
alpha = float(existing_weights.get(f"{key}.alpha", 1.0))
dora_scale = existing_weights.get(
f"{key}.dora_scale", None
)
for adapter_cls in adapters:
existing_adapter = adapter_cls.load(
n, existing_weights, alpha, dora_scale
)
if existing_adapter is not None:
break
else:
# If no existing adapter found, use LoRA
# We will add algo option in the future
existing_adapter = None
adapter_cls = adapters[0]
if existing_adapter is not None:
train_adapter = existing_adapter.to_train().to(lora_dtype)
else:
# Use LoRA with alpha=1.0 by default
train_adapter = adapter_cls.create_train(
m.weight, rank=rank, alpha=1.0
).to(lora_dtype)
for name, parameter in train_adapter.named_parameters():
lora_sd[f"{n}.{name}"] = parameter
mp.add_weight_wrapper(key, train_adapter)
all_weight_adapters.append(train_adapter)
else:
diff = torch.nn.Parameter(
torch.zeros(
m.weight.shape, dtype=lora_dtype, requires_grad=True
)
)
diff_module = BiasDiff(diff)
mp.add_weight_wrapper(key, BiasDiff(diff))
all_weight_adapters.append(diff_module)
lora_sd["{}.diff".format(n)] = diff
if hasattr(m, "bias") and m.bias is not None:
key = "{}.bias".format(n)
bias = torch.nn.Parameter(
torch.zeros(m.bias.shape, dtype=lora_dtype, requires_grad=True)
)
bias_module = BiasDiff(bias)
lora_sd["{}.diff_b".format(n)] = bias
mp.add_weight_wrapper(key, BiasDiff(bias))
all_weight_adapters.append(bias_module)
if optimizer == "Adam":
optimizer = torch.optim.Adam(lora_sd.values(), lr=learning_rate)
elif optimizer == "AdamW":
optimizer = torch.optim.AdamW(lora_sd.values(), lr=learning_rate)
elif optimizer == "SGD":
optimizer = torch.optim.SGD(lora_sd.values(), lr=learning_rate)
elif optimizer == "RMSprop":
optimizer = torch.optim.RMSprop(lora_sd.values(), lr=learning_rate)
# Setup loss function based on selection
if loss_function == "MSE":
criterion = torch.nn.MSELoss()
elif loss_function == "L1":
criterion = torch.nn.L1Loss()
elif loss_function == "Huber":
criterion = torch.nn.HuberLoss()
elif loss_function == "SmoothL1":
criterion = torch.nn.SmoothL1Loss()
# setup models
for m in find_all_highest_child_module_with_forward(mp.model.diffusion_model):
patch(m)
mp.model.requires_grad_(False)
comfy.model_management.load_models_gpu([mp], memory_required=1e20, force_full_load=True)
# Setup sampler and guider like in test script
loss_map = {"loss": []}
def loss_callback(loss):
loss_map["loss"].append(loss)
train_sampler = TrainSampler(
criterion,
optimizer,
loss_callback=loss_callback,
batch_size=batch_size,
total_steps=steps,
seed=seed,
training_dtype=dtype
)
guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp)
guider.set_conds(positive) # Set conditioning from input
# Training loop
try:
# Generate dummy sigmas and noise
sigmas = torch.tensor(range(num_images))
noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed)
guider.sample(
noise.generate_noise({"samples": latents}),
latents,
train_sampler,
sigmas,
seed=noise.seed
)
finally:
for m in mp.model.modules():
unpatch(m)
del train_sampler, optimizer
for adapter in all_weight_adapters:
adapter.requires_grad_(False)
for param in lora_sd:
lora_sd[param] = lora_sd[param].to(lora_dtype)
return (mp, lora_sd, loss_map, steps + existing_steps)
class LoraModelLoader:
def __init__(self):
self.loaded_lora = None
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model": ("MODEL", {"tooltip": "The diffusion model the LoRA will be applied to."}),
"lora": (IO.LORA_MODEL, {"tooltip": "The LoRA model to apply to the diffusion model."}),
"strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the diffusion model. This value can be negative."}),
}
}
RETURN_TYPES = ("MODEL",)
OUTPUT_TOOLTIPS = ("The modified diffusion model.",)
FUNCTION = "load_lora_model"
CATEGORY = "loaders"
DESCRIPTION = "Load Trained LoRA weights from Train LoRA node."
EXPERIMENTAL = True
def load_lora_model(self, model, lora, strength_model):
if strength_model == 0:
return (model, )
model_lora, _ = comfy.sd.load_lora_for_models(model, None, lora, strength_model, 0)
return (model_lora, )
class SaveLoRA:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"lora": (
IO.LORA_MODEL,
{
"tooltip": "The LoRA model to save. Do not use the model with LoRA layers."
},
),
"prefix": (
"STRING",
{
"default": "loras/ComfyUI_trained_lora",
"tooltip": "The prefix to use for the saved LoRA file.",
},
),
},
"optional": {
"steps": (
IO.INT,
{
"forceInput": True,
"tooltip": "Optional: The number of steps to LoRA has been trained for, used to name the saved file.",
},
),
},
}
RETURN_TYPES = ()
FUNCTION = "save"
CATEGORY = "loaders"
EXPERIMENTAL = True
OUTPUT_NODE = True
def save(self, lora, prefix, steps=None):
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(prefix, self.output_dir)
if steps is None:
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
else:
output_checkpoint = f"{filename}_{steps}_steps_{counter:05}_.safetensors"
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
safetensors.torch.save_file(lora, output_checkpoint)
return {}
class LossGraphNode:
def __init__(self):
self.output_dir = folder_paths.get_temp_directory()
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"loss": (IO.LOSS_MAP, {"default": {}}),
"filename_prefix": (IO.STRING, {"default": "loss_graph"}),
},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
RETURN_TYPES = ()
FUNCTION = "plot_loss"
OUTPUT_NODE = True
CATEGORY = "training"
EXPERIMENTAL = True
DESCRIPTION = "Plots the loss graph and saves it to the output directory."
def plot_loss(self, loss, filename_prefix, prompt=None, extra_pnginfo=None):
loss_values = loss["loss"]
width, height = 800, 480
margin = 40
img = Image.new(
"RGB", (width + margin, height + margin), "white"
) # Extend canvas
draw = ImageDraw.Draw(img)
min_loss, max_loss = min(loss_values), max(loss_values)
scaled_loss = [(l - min_loss) / (max_loss - min_loss) for l in loss_values]
steps = len(loss_values)
prev_point = (margin, height - int(scaled_loss[0] * height))
for i, l in enumerate(scaled_loss[1:], start=1):
x = margin + int(i / steps * width) # Scale X properly
y = height - int(l * height)
draw.line([prev_point, (x, y)], fill="blue", width=2)
prev_point = (x, y)
draw.line([(margin, 0), (margin, height)], fill="black", width=2) # Y-axis
draw.line(
[(margin, height), (width + margin, height)], fill="black", width=2
) # X-axis
font = None
try:
font = ImageFont.truetype("arial.ttf", 12)
except IOError:
font = ImageFont.load_default()
# Add axis labels
draw.text((5, height // 2), "Loss", font=font, fill="black")
draw.text((width // 2, height + 10), "Steps", font=font, fill="black")
# Add min/max loss values
draw.text((margin - 30, 0), f"{max_loss:.2f}", font=font, fill="black")
draw.text(
(margin - 30, height - 10), f"{min_loss:.2f}", font=font, fill="black"
)
metadata = None
if not args.disable_metadata:
metadata = PngInfo()
if prompt is not None:
metadata.add_text("prompt", json.dumps(prompt))
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata.add_text(x, json.dumps(extra_pnginfo[x]))
date = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
img.save(
os.path.join(self.output_dir, f"{filename_prefix}_{date}.png"),
pnginfo=metadata,
)
return {
"ui": {
"images": [
{
"filename": f"{filename_prefix}_{date}.png",
"subfolder": "",
"type": "temp",
}
]
}
}
NODE_CLASS_MAPPINGS = {
"TrainLoraNode": TrainLoraNode,
"SaveLoRANode": SaveLoRA,
"LoraModelLoader": LoraModelLoader,
"LoadImageSetFromFolderNode": LoadImageSetFromFolderNode,
"LoadImageTextSetFromFolderNode": LoadImageTextSetFromFolderNode,
"LossGraphNode": LossGraphNode,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"TrainLoraNode": "Train LoRA",
"SaveLoRANode": "Save LoRA Weights",
"LoraModelLoader": "Load LoRA Model",
"LoadImageSetFromFolderNode": "Load Image Dataset from Folder",
"LoadImageTextSetFromFolderNode": "Load Image and Text Dataset from Folder",
"LossGraphNode": "Plot Loss Graph",
}

View File

@@ -23,6 +23,10 @@ class WebcamCapture(nodes.LoadImage):
def load_capture(self, image, **kwargs):
return super().load_image(folder_paths.get_annotated_filepath(image))
@classmethod
def IS_CHANGED(cls, image, width, height, capture_on_queue):
return super().IS_CHANGED(image)
NODE_CLASS_MAPPINGS = {
"WebcamCapture": WebcamCapture,

View File

@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.3.39"
__version__ = "0.3.45"

View File

@@ -1,22 +1,38 @@
import sys
import copy
import logging
import threading
import heapq
import inspect
import logging
import sys
import threading
import time
import traceback
from enum import Enum
import inspect
from typing import List, Literal, NamedTuple, Optional
import asyncio
import torch
import nodes
import comfy.model_management
from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
from comfy_execution.graph_utils import is_link, GraphBuilder
from comfy_execution.caching import HierarchicalCache, LRUCache, DependencyAwareCache, CacheKeySetInputSignature, CacheKeySetID
import nodes
from comfy_execution.caching import (
BasicCache,
CacheKeySetID,
CacheKeySetInputSignature,
DependencyAwareCache,
HierarchicalCache,
LRUCache,
)
from comfy_execution.graph import (
DynamicPrompt,
ExecutionBlocker,
ExecutionList,
get_input_info,
)
from comfy_execution.graph_utils import GraphBuilder, is_link
from comfy_execution.validation import validate_node_input
from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler
from comfy_execution.utils import CurrentNodeContext
class ExecutionResult(Enum):
SUCCESS = 0
@@ -27,12 +43,13 @@ class DuplicateNodeError(Exception):
pass
class IsChangedCache:
def __init__(self, dynprompt, outputs_cache):
def __init__(self, prompt_id: str, dynprompt: DynamicPrompt, outputs_cache: BasicCache):
self.prompt_id = prompt_id
self.dynprompt = dynprompt
self.outputs_cache = outputs_cache
self.is_changed = {}
def get(self, node_id):
async def get(self, node_id):
if node_id in self.is_changed:
return self.is_changed[node_id]
@@ -50,7 +67,8 @@ class IsChangedCache:
# Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
input_data_all, _ = get_input_data(node["inputs"], class_def, node_id, None)
try:
is_changed = _map_node_over_list(class_def, input_data_all, "IS_CHANGED")
is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, "IS_CHANGED")
is_changed = await resolve_map_node_over_list_results(is_changed)
node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
except Exception as e:
logging.warning("WARNING: {}".format(e))
@@ -105,6 +123,8 @@ class CacheSet:
}
return result
SENSITIVE_EXTRA_DATA_KEYS = ("auth_token_comfy_org", "api_key_comfy_org")
def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, extra_data={}):
valid_inputs = class_def.INPUT_TYPES()
input_data_all = {}
@@ -152,7 +172,19 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
map_node_over_list = None #Don't hook this please
def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
async def resolve_map_node_over_list_results(results):
remaining = [x for x in results if isinstance(x, asyncio.Task) and not x.done()]
if len(remaining) == 0:
return [x.result() if isinstance(x, asyncio.Task) else x for x in results]
else:
done, pending = await asyncio.wait(remaining)
for task in done:
exc = task.exception()
if exc is not None:
raise exc
return [x.result() if isinstance(x, asyncio.Task) else x for x in results]
async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
# check if node wants the lists
input_is_list = getattr(obj, "INPUT_IS_LIST", False)
@@ -166,7 +198,7 @@ def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execut
return {k: v[i if len(v) > i else -1] for k, v in d.items()}
results = []
def process_inputs(inputs, index=None, input_is_list=False):
async def process_inputs(inputs, index=None, input_is_list=False):
if allow_interrupt:
nodes.before_node_execution()
execution_block = None
@@ -182,20 +214,37 @@ def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execut
if execution_block is None:
if pre_execute_cb is not None and index is not None:
pre_execute_cb(index)
results.append(getattr(obj, func)(**inputs))
f = getattr(obj, func)
if inspect.iscoroutinefunction(f):
async def async_wrapper(f, prompt_id, unique_id, list_index, args):
with CurrentNodeContext(prompt_id, unique_id, list_index):
return await f(**args)
task = asyncio.create_task(async_wrapper(f, prompt_id, unique_id, index, args=inputs))
# Give the task a chance to execute without yielding
await asyncio.sleep(0)
if task.done():
result = task.result()
results.append(result)
else:
results.append(task)
else:
with CurrentNodeContext(prompt_id, unique_id, index):
result = f(**inputs)
results.append(result)
else:
results.append(execution_block)
if input_is_list:
process_inputs(input_data_all, 0, input_is_list=input_is_list)
await process_inputs(input_data_all, 0, input_is_list=input_is_list)
elif max_len_input == 0:
process_inputs({})
await process_inputs({})
else:
for i in range(max_len_input):
input_dict = slice_dict(input_data_all, i)
process_inputs(input_dict, i)
await process_inputs(input_dict, i)
return results
def merge_result_data(results, obj):
# check which outputs need concatenating
output = []
@@ -217,11 +266,18 @@ def merge_result_data(results, obj):
output.append([o[i] for o in results])
return output
def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb=None):
async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None):
return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
has_pending_task = any(isinstance(r, asyncio.Task) and not r.done() for r in return_values)
if has_pending_task:
return return_values, {}, False, has_pending_task
output, ui, has_subgraph = get_output_from_returns(return_values, obj)
return output, ui, has_subgraph, False
def get_output_from_returns(return_values, obj):
results = []
uis = []
subgraph_results = []
return_values = _map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
has_subgraph = False
for i in range(len(return_values)):
r = return_values[i]
@@ -255,6 +311,10 @@ def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb
else:
output = []
ui = dict()
# TODO: Think there's an existing bug here
# If we're performing a subgraph expansion, we probably shouldn't be returning UI values yet.
# They'll get cached without the completed subgraphs. It's an edge case and I'm not aware of
# any nodes that use both subgraph expansion and custom UI outputs, but might be a problem in the future.
if len(uis) > 0:
ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()}
return output, ui, has_subgraph
@@ -267,7 +327,7 @@ def format_value(x):
else:
return str(x)
def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results):
async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes):
unique_id = current_item
real_node_id = dynprompt.get_real_node_id(unique_id)
display_node_id = dynprompt.get_display_node_id(unique_id)
@@ -279,11 +339,26 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
if server.client_id is not None:
cached_output = caches.ui.get(unique_id) or {}
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id)
get_progress_state().finish_progress(unique_id)
return (ExecutionResult.SUCCESS, None, None)
input_data_all = None
try:
if unique_id in pending_subgraph_results:
if unique_id in pending_async_nodes:
results = []
for r in pending_async_nodes[unique_id]:
if isinstance(r, asyncio.Task):
try:
results.append(r.result())
except Exception as ex:
# An async task failed - propagate the exception up
del pending_async_nodes[unique_id]
raise ex
else:
results.append(r)
del pending_async_nodes[unique_id]
output_data, output_ui, has_subgraph = get_output_from_returns(results, class_def)
elif unique_id in pending_subgraph_results:
cached_results = pending_subgraph_results[unique_id]
resolved_outputs = []
for is_subgraph, result in cached_results:
@@ -305,6 +380,7 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
output_ui = []
has_subgraph = False
else:
get_progress_state().start_progress(unique_id)
input_data_all, missing_keys = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data)
if server.client_id is not None:
server.last_node_id = display_node_id
@@ -316,7 +392,8 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
caches.objects.set(unique_id, obj)
if hasattr(obj, "check_lazy_status"):
required_inputs = _map_node_over_list(obj, input_data_all, "check_lazy_status", allow_interrupt=True)
required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True)
required_inputs = await resolve_map_node_over_list_results(required_inputs)
required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], []))
required_inputs = [x for x in required_inputs if isinstance(x,str) and (
x not in input_data_all or x in missing_keys
@@ -345,8 +422,18 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
else:
return block
def pre_execute_cb(call_index):
# TODO - How to handle this with async functions without contextvars (which requires Python 3.12)?
GraphBuilder.set_default_prefix(unique_id, call_index, 0)
output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
if has_pending_tasks:
pending_async_nodes[unique_id] = output_data
unblock = execution_list.add_external_block(unique_id)
async def await_completion():
tasks = [x for x in output_data if isinstance(x, asyncio.Task)]
await asyncio.gather(*tasks, return_exceptions=True)
unblock()
asyncio.create_task(await_completion())
return (ExecutionResult.PENDING, None, None)
if len(output_ui) > 0:
caches.ui.set(unique_id, {
"meta": {
@@ -389,7 +476,8 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
cached_outputs.append((True, node_outputs))
new_node_ids = set(new_node_ids)
for cache in caches.all:
cache.ensure_subcache_for(unique_id, new_node_ids).clean_unused()
subcache = await cache.ensure_subcache_for(unique_id, new_node_ids)
subcache.clean_unused()
for node_id in new_output_ids:
execution_list.add_node(node_id)
for link in new_output_links:
@@ -417,20 +505,24 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
logging.error(f"!!! Exception during processing !!! {ex}")
logging.error(traceback.format_exc())
tips = ""
if isinstance(ex, comfy.model_management.OOM_EXCEPTION):
tips = "This error means you ran out of memory on your GPU.\n\nTIPS: If the workflow worked before you might have accidentally set the batch_size to a large number."
logging.error("Got an OOM, unloading all loaded models.")
comfy.model_management.unload_all_models()
error_details = {
"node_id": real_node_id,
"exception_message": str(ex),
"exception_message": "{}\n{}".format(ex, tips),
"exception_type": exception_type,
"traceback": traceback.format_tb(tb),
"current_inputs": input_data_formatted
}
if isinstance(ex, comfy.model_management.OOM_EXCEPTION):
logging.error("Got an OOM, unloading all loaded models.")
comfy.model_management.unload_all_models()
return (ExecutionResult.FAILURE, error_details, ex)
get_progress_state().finish_progress(unique_id)
executed.add(unique_id)
return (ExecutionResult.SUCCESS, None, None)
@@ -485,6 +577,11 @@ class PromptExecutor:
self.add_message("execution_error", mes, broadcast=False)
def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
asyncio_loop = asyncio.new_event_loop()
asyncio.set_event_loop(asyncio_loop)
asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs))
async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
nodes.interrupt_processing(False)
if "client_id" in extra_data:
@@ -497,9 +594,11 @@ class PromptExecutor:
with torch.inference_mode():
dynamic_prompt = DynamicPrompt(prompt)
is_changed_cache = IsChangedCache(dynamic_prompt, self.caches.outputs)
reset_progress_state(prompt_id, dynamic_prompt)
add_progress_handler(WebUIProgressHandler(self.server))
is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs)
for cache in self.caches.all:
cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
cache.clean_unused()
cached_nodes = []
@@ -512,6 +611,7 @@ class PromptExecutor:
{ "nodes": cached_nodes, "prompt_id": prompt_id},
broadcast=False)
pending_subgraph_results = {}
pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
executed = set()
execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
current_outputs = self.caches.outputs.all_node_ids()
@@ -519,12 +619,13 @@ class PromptExecutor:
execution_list.add_node(node_id)
while not execution_list.is_empty():
node_id, error, ex = execution_list.stage_node_execution()
node_id, error, ex = await execution_list.stage_node_execution()
if error is not None:
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
break
result, error, ex = execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results)
assert node_id is not None, "Node ID should not be None at this point"
result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes)
self.success = result != ExecutionResult.FAILURE
if result == ExecutionResult.FAILURE:
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
@@ -554,7 +655,7 @@ class PromptExecutor:
comfy.model_management.unload_all_models()
def validate_inputs(prompt, item, validated):
async def validate_inputs(prompt_id, prompt, item, validated):
unique_id = item
if unique_id in validated:
return validated[unique_id]
@@ -631,7 +732,7 @@ def validate_inputs(prompt, item, validated):
errors.append(error)
continue
try:
r = validate_inputs(prompt, o_id, validated)
r = await validate_inputs(prompt_id, prompt, o_id, validated)
if r[0] is False:
# `r` will be set in `validated[o_id]` already
valid = False
@@ -756,7 +857,8 @@ def validate_inputs(prompt, item, validated):
input_filtered['input_types'] = [received_types]
#ret = obj_class.VALIDATE_INPUTS(**input_filtered)
ret = _map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS")
ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, "VALIDATE_INPUTS")
ret = await resolve_map_node_over_list_results(ret)
for x in input_filtered:
for i, r in enumerate(ret):
if r is not True and not isinstance(r, ExecutionBlocker):
@@ -789,7 +891,7 @@ def full_type_name(klass):
return klass.__qualname__
return module + '.' + klass.__qualname__
def validate_prompt(prompt):
async def validate_prompt(prompt_id, prompt):
outputs = set()
for x in prompt:
if 'class_type' not in prompt[x]:
@@ -832,7 +934,7 @@ def validate_prompt(prompt):
valid = False
reasons = []
try:
m = validate_inputs(prompt, o, validated)
m = await validate_inputs(prompt_id, prompt, o, validated)
valid = m[0]
reasons = m[1]
except Exception as ex:
@@ -945,6 +1047,11 @@ class PromptQueue:
if status is not None:
status_dict = copy.deepcopy(status._asdict())
# Remove sensitive data from extra_data before storing in history
for sensitive_val in SENSITIVE_EXTRA_DATA_KEYS:
if sensitive_val in prompt[3]:
prompt[3].pop(sensitive_val)
self.history[prompt[1]] = {
"prompt": prompt,
"outputs": {},

View File

@@ -276,6 +276,9 @@ def filter_files_extensions(files: Collection[str], extensions: Collection[str])
def get_full_path(folder_name: str, filename: str) -> str | None:
"""
Get the full path of a file in a folder, has to be a file
"""
global folder_names_and_paths
folder_name = map_legacy(folder_name)
if folder_name not in folder_names_and_paths:
@@ -293,6 +296,9 @@ def get_full_path(folder_name: str, filename: str) -> str | None:
def get_full_path_or_raise(folder_name: str, filename: str) -> str:
"""
Get the full path of a file in a folder, has to be a file
"""
full_path = get_full_path(folder_name, filename)
if full_path is None:
raise FileNotFoundError(f"Model in folder '{folder_name}' with filename '{filename}' not found.")
@@ -394,3 +400,26 @@ def get_save_image_path(filename_prefix: str, output_dir: str, image_width=0, im
os.makedirs(full_output_folder, exist_ok=True)
counter = 1
return full_output_folder, filename, counter, subfolder, filename_prefix
def get_input_subfolders() -> list[str]:
"""Returns a list of all subfolder paths in the input directory, recursively.
Returns:
List of folder paths relative to the input directory, excluding the root directory
"""
input_dir = get_input_directory()
folders = []
try:
if not os.path.exists(input_dir):
return []
for root, dirs, _ in os.walk(input_dir):
rel_path = os.path.relpath(root, input_dir)
if rel_path != ".": # Only include non-root directories
# Normalize path separators to forward slashes
folders.append(rel_path.replace(os.sep, '/'))
return sorted(folders)
except FileNotFoundError:
return []

Some files were not shown because too many files have changed in this diff Show More