Compare commits

...

123 Commits

Author SHA1 Message Date
comfyanonymous
e18f53cca9 ComfyUI version 0.3.43 2025-06-27 17:22:02 -04:00
comfyanonymous
c36be0ea09 Fix memory estimation bug with kontext. (#8709) 2025-06-27 17:21:12 -04:00
comfyanonymous
9093301a49 Don't add tiny bit of random noise when VAE encoding. (#8705)
Shouldn't change outputs but might make things a tiny bit more
deterministic.
2025-06-27 14:14:56 -04:00
comfyanonymous
bd951a714f Add Flux Kontext and Omnigen 2 models to readme. (#8682) 2025-06-26 12:26:29 -04:00
comfyanonymous
6493709d6a ComfyUI version 0.3.42 2025-06-26 11:47:07 -04:00
filtered
b976f934ae Update frontend to 1.23.4 (#8681) 2025-06-26 11:44:12 -04:00
comfyanonymous
7d8cf4cacc Update requirements.txt (#8680) 2025-06-26 11:39:40 -04:00
filtered
68f4496b8e Update frontend to 1.23.3 (#8678) 2025-06-26 11:29:03 -04:00
comfyanonymous
ef5266b1c1 Support Flux Kontext Dev model. (#8679) 2025-06-26 11:28:41 -04:00
comfyanonymous
a96e65df18 Disable omnigen2 fp16 on older pytorch versions. (#8672) 2025-06-26 03:39:09 -04:00
comfyanonymous
93a49a45de Bump minimum transformers version. (#8671) 2025-06-26 02:33:02 -04:00
comfyanonymous
ec70ed6aea Omnigen2 model implementation. (#8669) 2025-06-25 19:35:57 -04:00
comfyanonymous
7a13f74220 unet -> diffusion model (#8659) 2025-06-25 04:52:34 -04:00
chaObserv
8042eb20c6 Singlestep DPM++ SDE for RF (#8627)
Refactor the algorithm, and apply alpha scaling.
2025-06-24 14:59:09 -04:00
comfyanonymous
bd9f166c12 Cosmos predict2 model merging nodes. (#8647) 2025-06-24 05:17:16 -04:00
comfyanonymous
dd94416db2 Indicate that directml is not recommended in the README. (#8644) 2025-06-23 14:04:49 -04:00
comfyanonymous
ae0e7c4dff Resize and pad image node. (#8636) 2025-06-22 17:59:31 -04:00
comfyanonymous
78f79266a9 Allow padding in ImageStitch node to be white. (#8631) 2025-06-22 00:19:41 -04:00
comfyanonymous
1883e70b43 Fix exception when using a noise mask with cosmos predict2. (#8621)
* Fix exception when using a noise mask with cosmos predict2.

* Fix ruff.
2025-06-21 03:30:39 -04:00
Lucas - BLOCK33
31ca603ccb Improve the log time function for 10 minute + renders (#6207)
* modified:   main.py

* Update main.py
2025-06-20 23:04:55 -04:00
comfyanonymous
f7fb193712 Small flux optimization. (#8611) 2025-06-20 05:37:32 -04:00
comfyanonymous
7e9267fa77 Make flux controlnet work with sd3 text enc. (#8599) 2025-06-19 18:50:05 -04:00
comfyanonymous
91d40086db Fix pytorch warning. (#8593) 2025-06-19 11:04:52 -04:00
coderfromthenorth93
5b12b55e32 Add new fields to the config types (#8507) 2025-06-18 15:12:29 -04:00
comfyanonymous
e9e9a031a8 Show a better error when the workflow OOMs. (#8574) 2025-06-18 06:55:21 -04:00
filtered
d7430c529a Update frontend to 1.22.2 (#8567) 2025-06-17 18:58:28 -04:00
ComfyUI Wiki
cd88f709ab Update template version (#8563) 2025-06-17 04:11:59 -07:00
comfyanonymous
4459a17e82 Add Cosmos Predict2 to README. (#8562) 2025-06-17 05:18:01 -04:00
comfyanonymous
483b3e62e0 ComfyUI version v0.3.41 2025-06-16 23:34:46 -04:00
chaObserv
8e81c507d2 Multistep DPM++ SDE samplers for RF (#8541)
Include alpha in sampling and minor refactoring
2025-06-16 14:47:10 -04:00
comfyanonymous
e1c6dc720e Allow setting min_length with tokenizer_data. (#8547) 2025-06-16 13:43:52 -04:00
comfyanonymous
7ea79ebb9d Add correct eps to ltxv rmsnorm. (#8542) 2025-06-15 12:21:25 -04:00
comfyanonymous
ae75a084df SaveLora now saves in the same filename format as all the other nodes. (#8538) 2025-06-15 03:44:59 -04:00
comfyanonymous
d6a2137fc3 Support Cosmos predict2 image to video models. (#8535)
Use the CosmosPredict2ImageToVideoLatent node.
2025-06-14 21:37:07 -04:00
chaObserv
53e8d8193c Generalize SEEDS samplers (#8529)
Restore VP algorithm for RF and refactor noise_coeffs and half-logSNR calculations
2025-06-14 16:58:16 -04:00
comfyanonymous
29596bd53f Small cosmos attention code refactor. (#8530) 2025-06-14 05:02:05 -04:00
Terry Jia
803af1e0c3 allow extra settings from pyproject.toml (#8526) 2025-06-13 23:11:55 -04:00
ComfyUI Wiki
6673939e76 Bump template to 0.1.28 (#8510) 2025-06-13 23:11:00 -04:00
ComfyUI Wiki
f74778e75d Bump embedded docs to 0.2.2 (#8512) 2025-06-13 23:06:28 -04:00
Kohaku-Blueleaf
520eb77b72 LoRA Trainer: LoRA training node in weight adapter scheme (#8446) 2025-06-13 19:25:59 -04:00
comfyanonymous
5bf69bde35 Add cosmos_rflow option to ModelSamplingContinuousEDM node. (#8523)
This is for the cosmos predict2 model.
2025-06-13 17:47:52 -04:00
comfyanonymous
c69af655aa Uncap cosmos predict2 res and fix mem estimation. (#8518) 2025-06-13 07:30:18 -04:00
comfyanonymous
251f54a2ad Basic initial support for cosmos predict2 text to image 2B and 14B models. (#8517) 2025-06-13 07:05:23 -04:00
Christian Byrne
c6529c0d77 don't validate string inputs with VALIDATE_INPUTS (#8508) 2025-06-12 20:17:10 -04:00
filtered
baa8c8cdd3 Add '@prerelease' to use latest test frontend (#8501)
* Add '@prerelease' to use latest test frontend

Allows download of pre-release versions.

Will always get the latest pre-release version - even if it's older than the latest stable release.

* nit
2025-06-12 17:03:27 -07:00
comfyanonymous
40fd39c7cb debug -> warning (#8506) 2025-06-12 17:14:59 -04:00
Terry Jia
4d1c4b9797 Auto register web folder (#8505)
* auto register web folder from pyproject

* need pydantic-settings as dependency

* wrapped try/except for config_parser

* sf
2025-06-12 16:24:39 -04:00
comfyanonymous
d2566eb4b2 Add a warning for old python versions. (#8504) 2025-06-12 15:38:33 -04:00
filtered
ef7e885fe4 Revert "Update requirements.txt (#8487)" (#8502)
This reverts commit 373a9386a4.
2025-06-12 14:10:48 -04:00
filtered
ecb8d15e7a Allow specifying any frontend semver suffixes (#8498) 2025-06-11 21:41:30 -04:00
comfyanonymous
365f9ed157 Revert "auto register web folder from pyproject (#8478)" (#8497)
This reverts commit 9685d4f3c3.
2025-06-11 17:28:04 -04:00
pythongosssss
50c605e957 Add support for sqlite database (#8444)
* Add support for sqlite database

* fix
2025-06-11 16:43:39 -04:00
Terry Jia
9685d4f3c3 auto register web folder from pyproject (#8478)
* auto register web folder from pyproject

* need pydantic-settings as dependency
2025-06-11 16:21:28 -04:00
comfyanonymous
8a4ff747bd Fix mistake in last commit. (#8496)
* Move to right place.
2025-06-11 15:13:29 -04:00
comfyanonymous
af1eb58be8 Fix black images on some flux models in fp16. (#8495) 2025-06-11 15:09:11 -04:00
ComfyUI Wiki
373a9386a4 Update requirements.txt (#8487) 2025-06-11 05:10:46 -04:00
comfyanonymous
6e28a46454 Apple most likely is never fixing the fp16 attention bug. (#8485) 2025-06-10 13:06:24 -04:00
Kent Mewhort
c7b25784b1 Fix WebcamCapture IS_CHANGED signature (#8413) 2025-06-09 13:05:54 -04:00
comfyanonymous
7f800d04fa Enable AMD fp8 and pytorch attention on some GPUs. (#8474)
Information is from the pytorch source code.
2025-06-09 12:50:39 -04:00
comfyanonymous
97755eed46 Enable fp8 ops by default on gfx1201 (#8464) 2025-06-08 14:15:34 -04:00
comfyanonymous
daf9d25ee2 Cleaner torch version comparisons. (#8453) 2025-06-07 10:01:15 -04:00
comfyanonymous
3b4b171e18 Alternate fix for #8435 (#8442) 2025-06-06 09:43:27 -04:00
Olexandr88
d8759c772b Update README.md (#8427) 2025-06-05 10:44:29 -07:00
comfyanonymous
4248b1618f Let chroma TE work on regular flux. (#8429) 2025-06-05 10:07:17 -04:00
comfyanonymous
866f6cdab4 ComfyUI version 0.3.40 2025-06-04 22:18:54 -04:00
Christian Byrne
3aa83feeec [refactor] remove version prefixes from Ideogram node categories (#8418)
Simplifies node organization by consolidating all Ideogram nodes under a single category instead of version-specific subcategories.
2025-06-04 21:56:38 -04:00
comfyanonymous
871749c208 Add batch to GetImageSize node. (#8419) 2025-06-04 09:40:21 -04:00
SD
fcc1643c52 Sub call to deprecated pillow API Image.ANTIALIAS (#8415)
ANTIALIAS was removed in Pillow 10.0.0
2025-06-04 09:03:42 -04:00
filtered
20687293fe Update frontend to 1.21.7 (#8410) 2025-06-04 08:57:13 -04:00
Terry Jia
47d55b8b45 add support to read pyproject.toml from custom node (#8357)
* add support to read pyproject.toml from custom node

* sf

* use pydantic instead

* sf

* use pydantic_settings

* remove unnecessary try/catch and handle single-file python node

* sf
2025-06-03 19:59:13 -04:00
comfyanonymous
310f4b6ef8 Add api nodes to readme. (#8402) 2025-06-03 04:26:44 -04:00
Christian Byrne
856448060c [feat] Add GetImageSize node (#8386)
* [feat] Add GetImageSize node to return image dimensions

Added a simple GetImageSize node in comfy_extras/nodes_images.py that returns width and height of input images. The node displays dimensions on the UI via PromptServer and provides width/height as outputs for further processing.

* add display name mapping

* [fix] Add server module mock to unit tests for PromptServer import

Updated test to mock server module preventing import errors from the new PromptServer usage in GetImageSize node. Uses direct import pattern consistent with rest of codebase.
2025-06-02 21:57:50 -04:00
comfyanonymous
312d511630 Style fix. (#8390) 2025-06-02 07:22:02 -04:00
Jesse Gonyou
4f4f1c642a Update fix for potential XSS on /view (#8384)
* Update fix for potential XSS on /view

This commit uses mimetypes to add more restricted filetypes to prevent from being served, since mimetypes are what browsers use to determine how to serve files.

* Fix typo

Fixed a typo that prevented the program from running
2025-06-02 06:52:44 -04:00
filtered
010954d277 [BugFix] Update frontend to 1.21.6 (#8383) 2025-06-02 14:57:44 +10:00
filtered
6d46bb4b4c [BugFix] Update frontend to 1.21.5 (#8382) 2025-06-01 16:47:14 -04:00
Christian Byrne
67f57c5bcc [feat] add custom node testing requirement to issue templates (#8374)
Adds mandatory checkbox to bug report and user support templates requiring users to confirm they've tested with custom nodes disabled before submitting issues.
2025-06-01 15:47:07 -04:00
filtered
fd943c928f [BugFix] Update frontend to 1.21.4 (#8377) 2025-06-01 13:57:53 -04:00
ComfyUI Wiki
d3bd983b91 Bump template to 0.1.25 (#8372) 2025-06-01 05:41:17 -04:00
comfyanonymous
fb4754624d Make the casting in lists the same as regular inputs. (#8373) 2025-06-01 05:39:54 -04:00
Benjamin Lu
180db6753f Add Help Menu in NodeLibrarySidebarTab (#8179) 2025-06-01 04:32:32 -04:00
Christian Byrne
d062fcc5c0 [feat] Add ImageStitch node for concatenating images (#8369)
* [feat] Add ImageStitch node for concatenating images with borders

Add ImageStitch node that concatenates images in four directions with optional borders and intelligent size handling. Features include optional second image input, configurable borders with color selection, automatic batch size matching, and dimension alignment via padding or resizing.

Upstreamed from https://github.com/kijai/ComfyUI-KJNodes with enhancements for better error handling and comprehensive test coverage.

* [fix] Fix CI issues with CUDA dependencies and linting

- Mock CUDA-dependent modules in tests to avoid CI failures on CPU-only runners
- Fix ruff linting issues for code style compliance

* [fix] Improve CI compatibility by mocking nodes module import

Prevent CUDA initialization chain by mocking the nodes module at import time,
which is cleaner than deep mocking of CUDA-specific functions.

* [refactor] Clean up ImageStitch tests

- Remove unnecessary sys.path manipulation (pythonpath set in pytest.ini)
- Remove metadata tests that test framework internals rather than functionality
- Rename complex scenario test to be more descriptive of what it tests

* [refactor] Rename 'border' to 'spacing' for semantic accuracy

- Change border_width/border_color to spacing_width/spacing_color in API
- Update all tests to use spacing terminology
- Update comments and variable names throughout
- More accurately describes the gap/separator between images
2025-06-01 04:28:52 -04:00
filtered
456abad834 Update frontend to 1.21 (#8366) 2025-06-01 01:10:04 -04:00
comfyanonymous
19e45e9b0e Make it easier to pass lists of tensors to models. (#8358) 2025-05-31 20:00:20 -04:00
ComfyUI Wiki
97f23b81f3 Bump template to 0.1.23 (#8353)
Correct some error settings in VACE
2025-05-30 23:05:42 -07:00
drhead
08b7cc7506 use fused multiply-add pointwise ops in chroma (#8279) 2025-05-30 18:09:54 -04:00
BennyKok
6c319cbb4e fix: custom comfy-api-base works with subpath (#8332) 2025-05-30 17:51:28 -04:00
Chenlei Hu
df1aebe52e Remove huchenlei from CODEOWNERS (#8350) 2025-05-30 17:27:52 -04:00
comfyanonymous
704fc78854 Put ROCm version in tuple to make it easier to enable stuff based on it. (#8348) 2025-05-30 15:41:02 -04:00
JettHu
1d9fee79fd Add node for regex replace(sub) operation (#8340)
* Add node for regex replace(sub) operation

* Apply suggestions from code review

add tooltips

Co-authored-by: Christian Byrne <abolkonsky.rem@gmail.com>

* Fix indentation

---------

Co-authored-by: Christian Byrne <abolkonsky.rem@gmail.com>
2025-05-30 15:08:59 -04:00
Jedrzej Kosinski
aeba0b3a26 Reduce code duplication for [pro] and [max], rename Pro and Max to [pro] and [max] to be consistent with other BFL nodes, make default seed for Kontext nodes be 1234. since 0 is interpreted by API as 'choose random seed' (#8337) 2025-05-29 17:14:27 -04:00
comfyanonymous
094306b626 ComfyUI version 0.3.39 2025-05-29 14:26:39 -04:00
filtered
31260f0275 Update templates 0.1.22 (#8334) 2025-05-30 03:52:27 +10:00
Robin Huang
f1c9ca816a Add BFL Kontext API Nodes. (#8333)
* Added initial Flux.1 Kontext Pro Image node - recreated branch to save myself sanity from rebase crap after master got rebased

* Add safety filter to Kontext.

* Make safety = 2 and input image is optional.

* Add BFL kontext API nodes.

---------

Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2025-05-29 13:27:40 -04:00
comfyanonymous
f2289a1f59 Delete useless file. (#8327) 2025-05-29 08:29:37 -04:00
Robin Huang
fb83eda287 Revert "Add support for Veo3 API node." (#8322)
This reverts commit 592d056100.
2025-05-29 03:03:11 -04:00
comfyanonymous
5e5e46d40c Not really tested WAN Phantom Support. (#8321) 2025-05-28 23:46:15 -04:00
Yoland Yan
4eba3161cf Refactor Pika API node imports and fix unique_id issue. (#8319)
Added unique_id to hidden parameters and corrected description formatting in PikAdditionsNode.
2025-05-28 23:42:25 -04:00
Robin Huang
592d056100 Add support for Veo3 API node. (#8320) 2025-05-28 23:42:02 -04:00
comfyanonymous
1c1687ab1c Support HiDream SimpleTuner loras. (#8318) 2025-05-28 18:47:15 -04:00
comfyanonymous
e6609dacde ComfyUI version 0.3.38 2025-05-28 02:15:11 -04:00
Christian Byrne
ba37e67964 update frontend patch 1.20.7 (#8312) 2025-05-28 01:42:18 -04:00
comfyanonymous
06c661004e Memory estimation code can now take into account conds. (#8307) 2025-05-27 15:09:05 -04:00
comfyanonymous
c9e1821a7b ComfyUI version 0.3.37 2025-05-27 07:07:44 -04:00
Robin Huang
f58f0f5696 More API nodes: Gemini/Open AI Chat, Tripo, Rodin, Runway Image (#8295)
* Add Ideogram generate node.

* Add staging api.

* Add API_NODE and common error for missing auth token (#5)

* Add Minimax Video Generation + Async Task queue polling example (#6)

* [Minimax] Show video preview and embed workflow in ouput (#7)

* Remove uv.lock

* Remove polling operations.

* Revert "Remove polling operations."

This reverts commit 8415404ce8fbc0262b7de54fc700c5c8854a34fc.

* Update stubs.

* Added Ideogram and Minimax back in.

* Added initial BFL Flux 1.1 [pro] Ultra node (#11)

* Manually add BFL polling status response schema (#15)

* Add function for uploading files. (#18)

* Add Luma nodes (#16)

Co-authored-by: Robin Huang <robin.j.huang@gmail.com>

* Refactor util functions (#20)

* Add rest of Luma node functionality (#19)

Co-authored-by: Robin Huang <robin.j.huang@gmail.com>

* Fix image_luma_ref not working (#28)

Co-authored-by: Robin Huang <robin.j.huang@gmail.com>

* [Bug] Remove duplicated option T2V-01 in MinimaxTextToVideoNode (#31)

* add veo2, bump av req (#32)

* Add Recraft nodes (#29)

* Add Kling Nodes (#12)

* Add Camera Concepts (luma_concepts) to Luma Video nodes (#33)

Co-authored-by: Robin Huang <robin.j.huang@gmail.com>

* Add Runway nodes (#17)

* Convert Minimax node to use VIDEO output type (#34)

* Standard `CATEGORY` system for api nodes (#35)

* Set `Content-Type` header when uploading files (#36)

* add better error propagation to veo2 (#37)

* Add Realistic Image and Logo Raster styles for Recraft v3 (#38)

* Fix runway image upload and progress polling (#39)

* Fix image upload for Luma: only include `Content-Type` header field if it's set explicitly (#40)

* Moved Luma nodes to nodes_luma.py (#47)

* Moved Recraft nodes to nodes_recraft.py (#48)

* Move and fix BFL nodes to node_bfl.py (#49)

* Move and edit Minimax node to nodes_minimax.py (#50)

* Add Recraft Text to Vector node, add Save SVG node to handle its output (#53)

* Added pixverse_template support to Pixverse Text to Video node (#54)

* Added Recraft Controls + Recraft Color RGB nodes (#57)

* split remaining nodes out of nodes_api, make utility lib, refactor ideogram (#61)

* Set request type explicitly (#66)

* Add `control_after_generate` to all seed inputs (#69)

* Fix bug: deleting `Content-Type` when property does not exist (#73)

* Add Pixverse and updated Kling types (#75)

* Added Recraft Style - Infinite Style Library node (#82)

* add ideogram v3 (#83)

* [Kling] Split Camera Control config to its own node (#81)

* Add Pika i2v and t2v nodes (#52)

* Remove Runway nodes (#88)

* Fix: Prompt text can't be validated in Kling nodes when using primitive nodes (#90)

* Update Pika Duration and Resolution options (#94)

* Removed Infinite Style Library until later (#99)

* fix multi image return (#101)

close #96

* Serve SVG files directly (#107)

* Add a bunch of nodes, 3 ready to use, the rest waiting for endpoint support (#108)

* Revert "Serve SVG files directly" (#111)

* Expose 4 remaining Recraft nodes (#112)

* [Kling] Add `Duration` and `Video ID` outputs (#105)

* Add Kling nodes: camera control, start-end frame, lip-sync, video extend (#115)

* Fix error for Recraft ImageToImage error for nonexistent random_seed param (#118)

* Add remaining Pika nodes (#119)

* Make controls input work for Recraft Image to Image node (#120)

* Fix: Nested `AnyUrl` in request model cannot be serialized (Kling, Runway) (#129)

* Show errors and API output URLs to the user (change log levels) (#131)

* Apply small fixes and most prompt validation (if needed to avoid API error) (#135)

* Node name/category modifications (#140)

* Add back Recraft Style - Infinite Style Library node (#141)

* [Kling] Fix: Correct/verify supported subset of input combos in Kling nodes (#149)

* Remove pixverse_template from PixVerse Transition Video node (#155)

* Use 3.9 compat syntax (#164)

* Handle Comfy API key based authorizaton (#167)

Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>

* [BFL] Print download URL of successful task result directly on nodes (#175)

* Show output URL and progress text on Pika nodes (#168)

* [Ideogram] Print download URL of successful task result directly on nodes (#176)

* [Kling] Print download URL of successful task result directly on nodes (#181)

* Merge upstream may 14 25 (#186)

Co-authored-by: comfyanonymous <comfyanonymous@protonmail.com>
Co-authored-by: AustinMroz <austinmroz@utexas.edu>
Co-authored-by: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Co-authored-by: Benjamin Lu <benceruleanlu@proton.me>
Co-authored-by: Andrew Kvochko <kvochko@users.noreply.github.com>
Co-authored-by: Pam <42671363+pamparamm@users.noreply.github.com>
Co-authored-by: chaObserv <154517000+chaObserv@users.noreply.github.com>
Co-authored-by: Yoland Yan <4950057+yoland68@users.noreply.github.com>
Co-authored-by: guill <guill@users.noreply.github.com>
Co-authored-by: Chenlei Hu <hcl@comfy.org>
Co-authored-by: Terry Jia <terryjia88@gmail.com>
Co-authored-by: Silver <65376327+silveroxides@users.noreply.github.com>
Co-authored-by: catboxanon <122327233+catboxanon@users.noreply.github.com>
Co-authored-by: liesen <liesen.dev@gmail.com>
Co-authored-by: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>
Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
Co-authored-by: Robin Huang <robin.j.huang@gmail.com>
Co-authored-by: thot experiment <94414189+thot-experiment@users.noreply.github.com>
Co-authored-by: blepping <157360029+blepping@users.noreply.github.com>

* Update instructions on how to develop API Nodes. (#171)

* Add Runway FLF and I2V nodes (#187)

* Add OpenAI chat node (#188)

* Update README.

* Add Google Gemini API node (#191)

* Add Runway Gen 4 Text to Image Node (#193)

* [Runway, Gemini] Update node display names and attributes (#194)

* Update path from "image-to-video" to "image_to_video" (#197)

* [Runway] Split I2V nodes into separate gen3 and gen4 nodes (#198)

* Update runway i2v ratio enum (#201)

* Rodin3D: implement Rodin3D API Nodes (#190)

Co-authored-by: WhiteGiven <c15838568211@163.com>
Co-authored-by: Robin Huang <robin.j.huang@gmail.com>

* Add Tripo Nodes. (#189)

Co-authored-by: Robin Huang <robin.j.huang@gmail.com>

* Change casing of categories "3D"  => "3d" (#208)

* [tripo] fix negtive_prompt and mv2model (#212)

* [tripo] set default param to None (#215)

* Add description and tooltip to Tripo Refine model. (#218)

* Update.

* Fix rebase errors.

* Fix rebase errors.

* Update templates.

* Bump frontend.

* Add file type info for file inputs.

---------

Co-authored-by: Christian Byrne <cbyrne@comfy.org>
Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
Co-authored-by: Chenlei Hu <hcl@comfy.org>
Co-authored-by: thot experiment <94414189+thot-experiment@users.noreply.github.com>
Co-authored-by: comfyanonymous <comfyanonymous@protonmail.com>
Co-authored-by: AustinMroz <austinmroz@utexas.edu>
Co-authored-by: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Co-authored-by: Benjamin Lu <benceruleanlu@proton.me>
Co-authored-by: Andrew Kvochko <kvochko@users.noreply.github.com>
Co-authored-by: Pam <42671363+pamparamm@users.noreply.github.com>
Co-authored-by: chaObserv <154517000+chaObserv@users.noreply.github.com>
Co-authored-by: Yoland Yan <4950057+yoland68@users.noreply.github.com>
Co-authored-by: guill <guill@users.noreply.github.com>
Co-authored-by: Terry Jia <terryjia88@gmail.com>
Co-authored-by: Silver <65376327+silveroxides@users.noreply.github.com>
Co-authored-by: catboxanon <122327233+catboxanon@users.noreply.github.com>
Co-authored-by: liesen <liesen.dev@gmail.com>
Co-authored-by: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>
Co-authored-by: blepping <157360029+blepping@users.noreply.github.com>
Co-authored-by: Changrz <51637999+WhiteGiven@users.noreply.github.com>
Co-authored-by: WhiteGiven <c15838568211@163.com>
Co-authored-by: seed93 <liangding1990@163.com>
2025-05-27 03:00:58 -04:00
filtered
3a10b9641c [BugFix] Update frontend to 1.20.6 (#8296) 2025-05-27 02:47:06 -04:00
comfyanonymous
89a84e32d2 Disable initial GPU load when novram is used. (#8294) 2025-05-26 16:39:27 -04:00
comfyanonymous
e5799c4899 Enable pytorch attention by default on AMD gfx1151 (#8282) 2025-05-26 04:29:25 -04:00
comfyanonymous
a0651359d7 Return proper error if diffusion model not detected properly. (#8272) 2025-05-25 05:28:11 -04:00
comfyanonymous
ad3bd8aa49 ComfyUI version 0.3.36 2025-05-24 17:30:37 -04:00
comfyanonymous
5a87757ef9 Better error if sageattention is installed but a dependency is missing. (#8264) 2025-05-24 06:43:12 -04:00
Christian Byrne
464aece92b update frontend package to v1.20.5 (#8260) 2025-05-23 21:53:49 -07:00
comfyanonymous
0b50d4c0db Add argument to explicitly enable fp8 compute support. (#8257)
This can be used to test if your current GPU/pytorch version supports fp8 matrix mult in combination with --fast or the fp8_e4m3fn_fast dtype.
2025-05-23 17:43:50 -04:00
drhead
30b2eb8a93 create arange on-device (#8255) 2025-05-23 16:15:06 -04:00
comfyanonymous
f85c08df06 Make VACE conditionings stackable. (#8240) 2025-05-22 19:22:26 -04:00
comfyanonymous
4202e956a0 Add append feature to conditioning_set_values (#8239)
Refactor unclipconditioning node.
2025-05-22 08:11:13 -04:00
Terry Jia
b838c36720 remove mtl from 3d model file list (#8192) 2025-05-22 08:08:36 -04:00
Chenlei Hu
fc39184ea9 Update frontend to 1.20 (#8232) 2025-05-22 02:24:36 -04:00
ComfyUI Wiki
ded60c33a0 Update templates to 0.1.18 (#8224) 2025-05-21 11:40:08 -07:00
Michael Abrahams
8bb858e4d3 Improve performance with large number of queued prompts (#8176)
* get_current_queue_volatile

* restore get_current_queue method

* remove extra import
2025-05-21 05:14:17 -04:00
编程界的小学生
57893c843f Code Optimization and Issues Fixes in ComfyUI server (#8196)
* Update server.py

* Update server.py
2025-05-21 04:59:42 -04:00
Jedrzej Kosinski
65da29aaa9 Make torch.compile LoRA/key-compatible (#8213)
* Make torch compile node use wrapper instead of object_patch for the entire diffusion_models object, allowing key assotiations on diffusion_models to not break (loras, getting attributes, etc.)

* Moved torch compile code into comfy_api so it can be used by custom nodes with a degree of confidence

* Refactor set_torch_compile_wrapper to support a list of keys instead of just diffusion_model, as well as additional torch.compile args

* remove unused import

* Moved torch compile kwargs to be stored in model_options instead of attachments; attachments are more intended for things to be 'persisted', AKA not deepcopied

* Add some comments

* Remove random line of code, not sure how it got there
2025-05-21 04:56:56 -04:00
comfyanonymous
10024a38ea ComfyUI version v0.3.35 2025-05-21 04:50:37 -04:00
94 changed files with 162374 additions and 3519 deletions

View File

@@ -15,6 +15,14 @@ body:
steps to replicate what went wrong and others will be able to repeat your steps and see the same issue happen.
If unsure, ask on the [ComfyUI Matrix Space](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) or the [Comfy Org Discord](https://discord.gg/comfyorg) first.
- type: checkboxes
id: custom-nodes-test
attributes:
label: Custom Node Testing
description: Please confirm you have tried to reproduce the issue with all custom nodes disabled.
options:
- label: I have tried disabling custom nodes and the issue persists (see [how to disable custom nodes](https://docs.comfy.org/troubleshooting/custom-node-issues#step-1%3A-test-with-all-custom-nodes-disabled) if you need help)
required: true
- type: textarea
attributes:
label: Expected Behavior

View File

@@ -11,6 +11,14 @@ body:
**2:** You have made an effort to find public answers to your question before asking here. In other words, you googled it first, and scrolled through recent help topics.
If unsure, ask on the [ComfyUI Matrix Space](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) or the [Comfy Org Discord](https://discord.gg/comfyorg) first.
- type: checkboxes
id: custom-nodes-test
attributes:
label: Custom Node Testing
description: Please confirm you have tried to reproduce the issue with all custom nodes disabled.
options:
- label: I have tried disabling custom nodes and the issue persists (see [how to disable custom nodes](https://docs.comfy.org/troubleshooting/custom-node-issues#step-1%3A-test-with-all-custom-nodes-disabled) if you need help)
required: true
- type: textarea
attributes:
label: Your question

View File

@@ -5,20 +5,20 @@
# Inlined the team members for now.
# Maintainers
*.md @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/tests/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/tests-unit/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/notebooks/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/script_examples/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/.github/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/requirements.txt @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/pyproject.toml @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
*.md @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/tests/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/tests-unit/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/notebooks/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/script_examples/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/.github/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/requirements.txt @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/pyproject.toml @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
# Python web server
/api_server/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @christian-byrne
/app/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @christian-byrne
/utils/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @christian-byrne
/api_server/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
/app/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
/utils/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
# Node developers
/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
/comfy/comfy_types/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
/comfy_extras/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
/comfy/comfy_types/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne

View File

@@ -6,6 +6,7 @@
[![Website][website-shield]][website-url]
[![Dynamic JSON Badge][discord-shield]][discord-url]
[![Twitter][twitter-shield]][twitter-url]
[![Matrix][matrix-shield]][matrix-url]
<br>
[![][github-release-shield]][github-release-link]
@@ -20,6 +21,8 @@
<!-- Workaround to display total user from https://github.com/badges/shields/issues/4500#issuecomment-2060079995 -->
[discord-shield]: https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fdiscord.com%2Fapi%2Finvites%2Fcomfyorg%3Fwith_counts%3Dtrue&query=%24.approximate_member_count&logo=discord&logoColor=white&label=Discord&color=green&suffix=%20total
[discord-url]: https://www.comfy.org/discord
[twitter-shield]: https://img.shields.io/twitter/follow/ComfyUI
[twitter-url]: https://x.com/ComfyUI
[github-release-shield]: https://img.shields.io/github/v/release/comfyanonymous/ComfyUI?style=flat&sort=semver
[github-release-link]: https://github.com/comfyanonymous/ComfyUI/releases
@@ -62,12 +65,16 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
- [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
- [Lumina Image 2.0](https://comfyanonymous.github.io/ComfyUI_examples/lumina2/)
- [HiDream](https://comfyanonymous.github.io/ComfyUI_examples/hidream/)
- [Cosmos Predict2](https://comfyanonymous.github.io/ComfyUI_examples/cosmos_predict2/)
- Image Editing Models
- [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/)
- [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model)
- Video Models
- [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/)
- [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
- [LTX-Video](https://comfyanonymous.github.io/ComfyUI_examples/ltxv/)
- [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
- [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/)
- [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/) and [Cosmos Predict2](https://comfyanonymous.github.io/ComfyUI_examples/cosmos_predict2/)
- [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/)
- Audio Models
- [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
@@ -95,7 +102,8 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
- [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/)
- Latent previews with [TAESD](#how-to-show-high-quality-previews)
- Starts up very fast.
- Works fully offline: will never download anything.
- Works fully offline: core will never download anything unless you want to.
- Optional API nodes to use paid models from external providers through the online [Comfy API](https://docs.comfy.org/tutorials/api-nodes/overview).
- [Config file](extra_model_paths.yaml.example) to set the search paths for models.
Workflow examples can be found on the [Examples page](https://comfyanonymous.github.io/ComfyUI_examples/)
@@ -268,6 +276,8 @@ You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS ve
#### DirectML (AMD Cards on Windows)
This is very badly supported and is not recommended. There are some unofficial builds of pytorch ROCm on windows that exist that will give you a much better experience than this. This readme will be updated once official pytorch ROCm builds for windows come out.
```pip install torch-directml``` Then you can launch ComfyUI with: ```python main.py --directml```
#### Ascend NPUs

84
alembic.ini Normal file
View File

@@ -0,0 +1,84 @@
# A generic, single database configuration.
[alembic]
# path to migration scripts
# Use forward slashes (/) also on windows to provide an os agnostic path
script_location = alembic_db
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
# for all available tokens
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory.
prepend_sys_path = .
# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library.
# Any required deps can installed by adding `alembic[tz]` to the pip requirements
# string value is passed to ZoneInfo()
# leave blank for localtime
# timezone =
# max length of characters to apply to the "slug" field
# truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false
# version location specification; This defaults
# to alembic_db/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "version_path_separator" below.
# version_locations = %(here)s/bar:%(here)s/bat:alembic_db/versions
# version path separator; As mentioned above, this is the character used to split
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
# Valid values for version_path_separator are:
#
# version_path_separator = :
# version_path_separator = ;
# version_path_separator = space
# version_path_separator = newline
#
# Use os.pathsep. Default configuration used for new projects.
version_path_separator = os
# set to 'true' to search source files recursively
# in each "version_locations" directory
# new in Alembic version 1.10
# recursive_version_locations = false
# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8
sqlalchemy.url = sqlite:///user/comfyui.db
[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples
# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME
# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
# hooks = ruff
# ruff.type = exec
# ruff.executable = %(here)s/.venv/bin/ruff
# ruff.options = check --fix REVISION_SCRIPT_FILENAME

4
alembic_db/README.md Normal file
View File

@@ -0,0 +1,4 @@
## Generate new revision
1. Update models in `/app/database/models.py`
2. Run `alembic revision --autogenerate -m "{your message}"`

64
alembic_db/env.py Normal file
View File

@@ -0,0 +1,64 @@
from sqlalchemy import engine_from_config
from sqlalchemy import pool
from alembic import context
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
from app.database.models import Base
target_metadata = Base.metadata
# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.
This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
"""
url = config.get_main_option("sqlalchemy.url")
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
def run_migrations_online() -> None:
"""Run migrations in 'online' mode.
In this scenario we need to create an Engine
and associate a connection with the context.
"""
connectable = engine_from_config(
config.get_section(config.config_ini_section, {}),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
with connectable.connect() as connection:
context.configure(
connection=connection, target_metadata=target_metadata
)
with context.begin_transaction():
context.run_migrations()
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()

28
alembic_db/script.py.mako Normal file
View File

@@ -0,0 +1,28 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
def upgrade() -> None:
"""Upgrade schema."""
${upgrades if upgrades else "pass"}
def downgrade() -> None:
"""Downgrade schema."""
${downgrades if downgrades else "pass"}

112
app/database/db.py Normal file
View File

@@ -0,0 +1,112 @@
import logging
import os
import shutil
from app.logger import log_startup_warning
from utils.install_util import get_missing_requirements_message
from comfy.cli_args import args
_DB_AVAILABLE = False
Session = None
try:
from alembic import command
from alembic.config import Config
from alembic.runtime.migration import MigrationContext
from alembic.script import ScriptDirectory
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
_DB_AVAILABLE = True
except ImportError as e:
log_startup_warning(
f"""
------------------------------------------------------------------------
Error importing dependencies: {e}
{get_missing_requirements_message()}
This error is happening because ComfyUI now uses a local sqlite database.
------------------------------------------------------------------------
""".strip()
)
def dependencies_available():
"""
Temporary function to check if the dependencies are available
"""
return _DB_AVAILABLE
def can_create_session():
"""
Temporary function to check if the database is available to create a session
During initial release there may be environmental issues (or missing dependencies) that prevent the database from being created
"""
return dependencies_available() and Session is not None
def get_alembic_config():
root_path = os.path.join(os.path.dirname(__file__), "../..")
config_path = os.path.abspath(os.path.join(root_path, "alembic.ini"))
scripts_path = os.path.abspath(os.path.join(root_path, "alembic_db"))
config = Config(config_path)
config.set_main_option("script_location", scripts_path)
config.set_main_option("sqlalchemy.url", args.database_url)
return config
def get_db_path():
url = args.database_url
if url.startswith("sqlite:///"):
return url.split("///")[1]
else:
raise ValueError(f"Unsupported database URL '{url}'.")
def init_db():
db_url = args.database_url
logging.debug(f"Database URL: {db_url}")
db_path = get_db_path()
db_exists = os.path.exists(db_path)
config = get_alembic_config()
# Check if we need to upgrade
engine = create_engine(db_url)
conn = engine.connect()
context = MigrationContext.configure(conn)
current_rev = context.get_current_revision()
script = ScriptDirectory.from_config(config)
target_rev = script.get_current_head()
if target_rev is None:
logging.warning("No target revision found.")
elif current_rev != target_rev:
# Backup the database pre upgrade
backup_path = db_path + ".bkp"
if db_exists:
shutil.copy(db_path, backup_path)
else:
backup_path = None
try:
command.upgrade(config, target_rev)
logging.info(f"Database upgraded from {current_rev} to {target_rev}")
except Exception as e:
if backup_path:
# Restore the database from backup if upgrade fails
shutil.copy(backup_path, db_path)
os.remove(backup_path)
logging.exception("Error upgrading database: ")
raise e
global Session
Session = sessionmaker(bind=engine)
def create_session():
return Session()

14
app/database/models.py Normal file
View File

@@ -0,0 +1,14 @@
from sqlalchemy.orm import declarative_base
Base = declarative_base()
def to_dict(obj):
fields = obj.__table__.columns.keys()
return {
field: (val.to_dict() if hasattr(val, "to_dict") else val)
for field in fields
if (val := getattr(obj, field))
}
# TODO: Define models here

View File

@@ -16,26 +16,17 @@ from importlib.metadata import version
import requests
from typing_extensions import NotRequired
from utils.install_util import get_missing_requirements_message, requirements_path
from comfy.cli_args import DEFAULT_VERSION_STRING
import app.logger
# The path to the requirements.txt file
req_path = Path(__file__).parents[1] / "requirements.txt"
def frontend_install_warning_message():
"""The warning message to display when the frontend version is not up to date."""
extra = ""
if sys.flags.no_user_site:
extra = "-s "
return f"""
Please install the updated requirements.txt file by running:
{sys.executable} {extra}-m pip install -r {req_path}
{get_missing_requirements_message()}
This error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.
If you are on the portable package you can run: update\\update_comfyui.bat to solve this problem
""".strip()
@@ -48,7 +39,7 @@ def check_frontend_version():
try:
frontend_version_str = version("comfyui-frontend-package")
frontend_version = parse_version(frontend_version_str)
with open(req_path, "r", encoding="utf-8") as f:
with open(requirements_path, "r", encoding="utf-8") as f:
required_frontend = parse_version(f.readline().split("=")[-1])
if frontend_version < required_frontend:
app.logger.log_startup_warning(
@@ -121,9 +112,22 @@ class FrontEndProvider:
response.raise_for_status() # Raises an HTTPError if the response was an error
return response.json()
@cached_property
def latest_prerelease(self) -> Release:
"""Get the latest pre-release version - even if it's older than the latest release"""
release = [release for release in self.all_releases if release["prerelease"]]
if not release:
raise ValueError("No pre-releases found")
# GitHub returns releases in reverse chronological order, so first is latest
return release[0]
def get_release(self, version: str) -> Release:
if version == "latest":
return self.latest_release
elif version == "prerelease":
return self.latest_prerelease
else:
for release in self.all_releases:
if release["tag_name"] in [version, f"v{version}"]:
@@ -205,6 +209,19 @@ comfyui-workflow-templates is not installed.
""".strip()
)
@classmethod
def embedded_docs_path(cls) -> str:
"""Get the path to embedded documentation"""
try:
import comfyui_embedded_docs
return str(
importlib.resources.files(comfyui_embedded_docs) / "docs"
)
except ImportError:
logging.info("comfyui-embedded-docs package not found")
return None
@classmethod
def parse_version_string(cls, value: str) -> tuple[str, str, str]:
"""
@@ -217,7 +234,7 @@ comfyui-workflow-templates is not installed.
Raises:
argparse.ArgumentTypeError: If the version string is invalid.
"""
VERSION_PATTERN = r"^([a-zA-Z0-9][a-zA-Z0-9-]{0,38})/([a-zA-Z0-9_.-]+)@(v?\d+\.\d+\.\d+|latest)$"
VERSION_PATTERN = r"^([a-zA-Z0-9][a-zA-Z0-9-]{0,38})/([a-zA-Z0-9_.-]+)@(v?\d+\.\d+\.\d+[-._a-zA-Z0-9]*|latest|prerelease)$"
match_result = re.match(VERSION_PATTERN, value)
if match_result is None:
raise argparse.ArgumentTypeError(f"Invalid version string: {value}")

View File

@@ -88,6 +88,7 @@ parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE"
parser.add_argument("--oneapi-device-selector", type=str, default=None, metavar="SELECTOR_STRING", help="Sets the oneAPI device(s) this instance will use.")
parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize default when loading models with Intel's Extension for Pytorch.")
parser.add_argument("--supports-fp8-compute", action="store_true", help="ComfyUI will act like if the device supports fp8 compute.")
class LatentPreviewMethod(enum.Enum):
NoPreviews = "none"
@@ -202,6 +203,11 @@ parser.add_argument(
help="Set the base URL for the ComfyUI API. (default: https://api.comfy.org)",
)
database_default_path = os.path.abspath(
os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db")
)
parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
if comfy.options.args_parsing:
args = parser.parse_args()
else:

View File

@@ -37,6 +37,8 @@ class IO(StrEnum):
CONTROL_NET = "CONTROL_NET"
VAE = "VAE"
MODEL = "MODEL"
LORA_MODEL = "LORA_MODEL"
LOSS_MAP = "LOSS_MAP"
CLIP_VISION = "CLIP_VISION"
CLIP_VISION_OUTPUT = "CLIP_VISION_OUTPUT"
STYLE_MODEL = "STYLE_MODEL"

View File

@@ -24,6 +24,10 @@ class CONDRegular:
conds.append(x.cond)
return torch.cat(conds)
def size(self):
return list(self.cond.size())
class CONDNoiseShape(CONDRegular):
def process_cond(self, batch_size, device, area, **kwargs):
data = self.cond
@@ -64,6 +68,7 @@ class CONDCrossAttn(CONDRegular):
out.append(c)
return torch.cat(out)
class CONDConstant(CONDRegular):
def __init__(self, cond):
self.cond = cond
@@ -78,3 +83,48 @@ class CONDConstant(CONDRegular):
def concat(self, others):
return self.cond
def size(self):
return [1]
class CONDList(CONDRegular):
def __init__(self, cond):
self.cond = cond
def process_cond(self, batch_size, device, **kwargs):
out = []
for c in self.cond:
out.append(comfy.utils.repeat_to_batch_size(c, batch_size).to(device))
return self._copy_with(out)
def can_concat(self, other):
if len(self.cond) != len(other.cond):
return False
for i in range(len(self.cond)):
if self.cond[i].shape != other.cond[i].shape:
return False
return True
def concat(self, others):
out = []
for i in range(len(self.cond)):
o = [self.cond[i]]
for x in others:
o.append(x.cond[i])
out.append(torch.cat(o))
return out
def size(self): # hackish implementation to make the mem estimation work
o = 0
c = 1
for c in self.cond:
size = c.size()
o += math.prod(size)
if len(size) > 1:
c = size[1]
return [1, c, o // c]

View File

@@ -390,8 +390,9 @@ class ControlLora(ControlNet):
pass
for k in self.control_weights:
if k not in {"lora_controlnet"}:
comfy.utils.set_attr_param(self.control_model, k, self.control_weights[k].to(dtype).to(comfy.model_management.get_torch_device()))
if (k not in {"lora_controlnet"}):
if (k.endswith(".up") or k.endswith(".down") or k.endswith(".weight") or k.endswith(".bias")) and ("__" not in k):
comfy.utils.set_attr_param(self.control_model, k, self.control_weights[k].to(dtype).to(comfy.model_management.get_torch_device()))
def copy(self):
c = ControlLora(self.control_weights, global_average_pooling=self.global_average_pooling)

View File

@@ -1,4 +1,5 @@
import math
from functools import partial
from scipy import integrate
import torch
@@ -142,6 +143,33 @@ class BrownianTreeNoiseSampler:
return self.tree(t0, t1) / (t1 - t0).abs().sqrt()
def sigma_to_half_log_snr(sigma, model_sampling):
"""Convert sigma to half-logSNR log(alpha_t / sigma_t)."""
if isinstance(model_sampling, comfy.model_sampling.CONST):
# log((1 - t) / t) = log((1 - sigma) / sigma)
return sigma.logit().neg()
return sigma.log().neg()
def half_log_snr_to_sigma(half_log_snr, model_sampling):
"""Convert half-logSNR log(alpha_t / sigma_t) to sigma."""
if isinstance(model_sampling, comfy.model_sampling.CONST):
# 1 / (1 + exp(half_log_snr))
return half_log_snr.neg().sigmoid()
return half_log_snr.neg().exp()
def offset_first_sigma_for_snr(sigmas, model_sampling, percent_offset=1e-4):
"""Adjust the first sigma to avoid invalid logSNR."""
if len(sigmas) <= 1:
return sigmas
if isinstance(model_sampling, comfy.model_sampling.CONST):
if sigmas[0] >= 1:
sigmas = sigmas.clone()
sigmas[0] = model_sampling.percent_to_sigma(percent_offset)
return sigmas
@torch.no_grad()
def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
@@ -682,6 +710,7 @@ def sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args=None, callback=Non
# logged_x = torch.cat((logged_x, x.unsqueeze(0)), dim=0)
return x
@torch.no_grad()
def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
"""DPM-Solver++ (stochastic)."""
@@ -693,38 +722,49 @@ def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=N
seed = extra_args.get("seed", None)
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
sigma_fn = lambda t: t.neg().exp()
t_fn = lambda sigma: sigma.log().neg()
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling)
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if sigmas[i + 1] == 0:
# Euler method
d = to_d(x, sigmas[i], denoised)
dt = sigmas[i + 1] - sigmas[i]
x = x + d * dt
# Denoising step
x = denoised
else:
# DPM-Solver++
t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
h = t_next - t
s = t + h * r
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
h = lambda_t - lambda_s
lambda_s_1 = lambda_s + r * h
fac = 1 / (2 * r)
sigma_s_1 = sigma_fn(lambda_s_1)
alpha_s = sigmas[i] * lambda_s.exp()
alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
alpha_t = sigmas[i + 1] * lambda_t.exp()
# Step 1
sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(s), eta)
s_ = t_fn(sd)
x_2 = (sigma_fn(s_) / sigma_fn(t)) * x - (t - s_).expm1() * denoised
x_2 = x_2 + noise_sampler(sigma_fn(t), sigma_fn(s)) * s_noise * su
denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
sd, su = get_ancestral_step(lambda_s.neg().exp(), lambda_s_1.neg().exp(), eta)
lambda_s_1_ = sd.log().neg()
h_ = lambda_s_1_ - lambda_s
x_2 = (alpha_s_1 / alpha_s) * (-h_).exp() * x - alpha_s_1 * (-h_).expm1() * denoised
if eta > 0 and s_noise > 0:
x_2 = x_2 + alpha_s_1 * noise_sampler(sigmas[i], sigma_s_1) * s_noise * su
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
# Step 2
sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(t_next), eta)
t_next_ = t_fn(sd)
sd, su = get_ancestral_step(lambda_s.neg().exp(), lambda_t.neg().exp(), eta)
lambda_t_ = sd.log().neg()
h_ = lambda_t_ - lambda_s
denoised_d = (1 - fac) * denoised + fac * denoised_2
x = (sigma_fn(t_next_) / sigma_fn(t)) * x - (t - t_next_).expm1() * denoised_d
x = x + noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * su
x = (alpha_t / alpha_s) * (-h_).exp() * x - alpha_t * (-h_).expm1() * denoised_d
if eta > 0 and s_noise > 0:
x = x + alpha_t * noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * su
return x
@@ -753,6 +793,7 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No
old_denoised = denoised
return x
@torch.no_grad()
def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
"""DPM-Solver++(2M) SDE."""
@@ -768,9 +809,12 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
old_denoised = None
h_last = None
h = None
h, h_last = None, None
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
@@ -781,26 +825,29 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
x = denoised
else:
# DPM-Solver++(2M) SDE
t, s = -sigmas[i].log(), -sigmas[i + 1].log()
h = s - t
eta_h = eta * h
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
h = lambda_t - lambda_s
h_eta = h * (eta + 1)
x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + (-h - eta_h).expm1().neg() * denoised
alpha_t = sigmas[i + 1] * lambda_t.exp()
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x + alpha_t * (-h_eta).expm1().neg() * denoised
if old_denoised is not None:
r = h_last / h
if solver_type == 'heun':
x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * (1 / r) * (denoised - old_denoised)
x = x + alpha_t * ((-h_eta).expm1().neg() / (-h_eta) + 1) * (1 / r) * (denoised - old_denoised)
elif solver_type == 'midpoint':
x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised)
x = x + 0.5 * alpha_t * (-h_eta).expm1().neg() * (1 / r) * (denoised - old_denoised)
if eta:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise
if eta > 0 and s_noise > 0:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise
old_denoised = denoised
h_last = h
return x
@torch.no_grad()
def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""DPM-Solver++(3M) SDE."""
@@ -814,6 +861,10 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
denoised_1, denoised_2 = None, None
h, h_1, h_2 = None, None, None
@@ -825,13 +876,16 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
# Denoising step
x = denoised
else:
t, s = -sigmas[i].log(), -sigmas[i + 1].log()
h = s - t
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
h = lambda_t - lambda_s
h_eta = h * (eta + 1)
x = torch.exp(-h_eta) * x + (-h_eta).expm1().neg() * denoised
alpha_t = sigmas[i + 1] * lambda_t.exp()
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x + alpha_t * (-h_eta).expm1().neg() * denoised
if h_2 is not None:
# DPM-Solver++(3M) SDE
r0 = h_1 / h
r1 = h_2 / h
d1_0 = (denoised - denoised_1) / r0
@@ -840,20 +894,22 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
d2 = (d1_0 - d1_1) / (r0 + r1)
phi_2 = h_eta.neg().expm1() / h_eta + 1
phi_3 = phi_2 / h_eta - 0.5
x = x + phi_2 * d1 - phi_3 * d2
x = x + (alpha_t * phi_2) * d1 - (alpha_t * phi_3) * d2
elif h_1 is not None:
# DPM-Solver++(2M) SDE
r = h_1 / h
d = (denoised - denoised_1) / r
phi_2 = h_eta.neg().expm1() / h_eta + 1
x = x + phi_2 * d
x = x + (alpha_t * phi_2) * d
if eta:
if eta > 0 and s_noise > 0:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise
denoised_1, denoised_2 = denoised, denoised_1
h_1, h_2 = h, h_1
return x
@torch.no_grad()
def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
if len(sigmas) <= 1:
@@ -863,6 +919,7 @@ def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler)
@torch.no_grad()
def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
if len(sigmas) <= 1:
@@ -872,6 +929,7 @@ def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
@torch.no_grad()
def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
if len(sigmas) <= 1:
@@ -1449,12 +1507,12 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
old_denoised = denoised
return x
@torch.no_grad()
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5):
'''
SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VE Data Prediction) stage 2
Arxiv: https://arxiv.org/abs/2305.14267
'''
"""SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2.
arXiv: https://arxiv.org/abs/2305.14267
"""
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
@@ -1462,6 +1520,11 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
inject_noise = eta > 0 and s_noise > 0
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling)
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
@@ -1469,80 +1532,96 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
if sigmas[i + 1] == 0:
x = denoised
else:
t, t_next = -sigmas[i].log(), -sigmas[i + 1].log()
h = t_next - t
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
h = lambda_t - lambda_s
h_eta = h * (eta + 1)
s = t + r * h
lambda_s_1 = lambda_s + r * h
fac = 1 / (2 * r)
sigma_s = s.neg().exp()
sigma_s_1 = sigma_fn(lambda_s_1)
# alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t)
alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
alpha_t = sigmas[i + 1] * lambda_t.exp()
coeff_1, coeff_2 = (-r * h_eta).expm1(), (-h_eta).expm1()
if inject_noise:
# 0 < r < 1
noise_coeff_1 = (-2 * r * h * eta).expm1().neg().sqrt()
noise_coeff_2 = ((-2 * r * h * eta).expm1() - (-2 * h * eta).expm1()).sqrt()
noise_1, noise_2 = noise_sampler(sigmas[i], sigma_s), noise_sampler(sigma_s, sigmas[i + 1])
noise_coeff_2 = (-r * h * eta).exp() * (-2 * (1 - r) * h * eta).expm1().neg().sqrt()
noise_1, noise_2 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigmas[i + 1])
# Step 1
x_2 = (coeff_1 + 1) * x - coeff_1 * denoised
if inject_noise:
x_2 = x_2 + sigma_s * (noise_coeff_1 * noise_1) * s_noise
denoised_2 = model(x_2, sigma_s * s_in, **extra_args)
# Step 2
denoised_d = (1 - fac) * denoised + fac * denoised_2
x = (coeff_2 + 1) * x - coeff_2 * denoised_d
if inject_noise:
x = x + sigmas[i + 1] * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
return x
@torch.no_grad()
def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3):
'''
SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VE Data Prediction) stage 3
Arxiv: https://arxiv.org/abs/2305.14267
'''
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
inject_noise = eta > 0 and s_noise > 0
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if sigmas[i + 1] == 0:
x = denoised
else:
t, t_next = -sigmas[i].log(), -sigmas[i + 1].log()
h = t_next - t
h_eta = h * (eta + 1)
s_1 = t + r_1 * h
s_2 = t + r_2 * h
sigma_s_1, sigma_s_2 = s_1.neg().exp(), s_2.neg().exp()
coeff_1, coeff_2, coeff_3 = (-r_1 * h_eta).expm1(), (-r_2 * h_eta).expm1(), (-h_eta).expm1()
if inject_noise:
noise_coeff_1 = (-2 * r_1 * h * eta).expm1().neg().sqrt()
noise_coeff_2 = ((-2 * r_1 * h * eta).expm1() - (-2 * r_2 * h * eta).expm1()).sqrt()
noise_coeff_3 = ((-2 * r_2 * h * eta).expm1() - (-2 * h * eta).expm1()).sqrt()
noise_1, noise_2, noise_3 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigma_s_2), noise_sampler(sigma_s_2, sigmas[i + 1])
# Step 1
x_2 = (coeff_1 + 1) * x - coeff_1 * denoised
x_2 = sigma_s_1 / sigmas[i] * (-r * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised
if inject_noise:
x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
# Step 2
x_3 = (coeff_2 + 1) * x - coeff_2 * denoised + (r_2 / r_1) * (coeff_2 / (r_2 * h_eta) + 1) * (denoised_2 - denoised)
denoised_d = (1 - fac) * denoised + fac * denoised_2
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_2 * denoised_d
if inject_noise:
x = x + sigmas[i + 1] * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
return x
@torch.no_grad()
def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3):
"""SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 3.
arXiv: https://arxiv.org/abs/2305.14267
"""
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
inject_noise = eta > 0 and s_noise > 0
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling)
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if sigmas[i + 1] == 0:
x = denoised
else:
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
h = lambda_t - lambda_s
h_eta = h * (eta + 1)
lambda_s_1 = lambda_s + r_1 * h
lambda_s_2 = lambda_s + r_2 * h
sigma_s_1, sigma_s_2 = sigma_fn(lambda_s_1), sigma_fn(lambda_s_2)
# alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t)
alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
alpha_s_2 = sigma_s_2 * lambda_s_2.exp()
alpha_t = sigmas[i + 1] * lambda_t.exp()
coeff_1, coeff_2, coeff_3 = (-r_1 * h_eta).expm1(), (-r_2 * h_eta).expm1(), (-h_eta).expm1()
if inject_noise:
# 0 < r_1 < r_2 < 1
noise_coeff_1 = (-2 * r_1 * h * eta).expm1().neg().sqrt()
noise_coeff_2 = (-r_1 * h * eta).exp() * (-2 * (r_2 - r_1) * h * eta).expm1().neg().sqrt()
noise_coeff_3 = (-r_2 * h * eta).exp() * (-2 * (1 - r_2) * h * eta).expm1().neg().sqrt()
noise_1, noise_2, noise_3 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigma_s_2), noise_sampler(sigma_s_2, sigmas[i + 1])
# Step 1
x_2 = sigma_s_1 / sigmas[i] * (-r_1 * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised
if inject_noise:
x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
# Step 2
x_3 = sigma_s_2 / sigmas[i] * (-r_2 * h * eta).exp() * x - alpha_s_2 * coeff_2 * denoised + (r_2 / r_1) * alpha_s_2 * (coeff_2 / (r_2 * h_eta) + 1) * (denoised_2 - denoised)
if inject_noise:
x_3 = x_3 + sigma_s_2 * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args)
# Step 3
x = (coeff_3 + 1) * x - coeff_3 * denoised + (1. / r_2) * (coeff_3 / h_eta + 1) * (denoised_3 - denoised)
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_3 * denoised + (1. / r_2) * alpha_t * (coeff_3 / h_eta + 1) * (denoised_3 - denoised)
if inject_noise:
x = x + sigmas[i + 1] * (noise_coeff_3 * noise_1 + noise_coeff_2 * noise_2 + noise_coeff_1 * noise_3) * s_noise
return x

View File

@@ -80,15 +80,13 @@ class DoubleStreamBlock(nn.Module):
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
# prepare image for attention
img_modulated = self.img_norm1(img)
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
img_modulated = torch.addcmul(img_mod1.shift, 1 + img_mod1.scale, self.img_norm1(img))
img_qkv = self.img_attn.qkv(img_modulated)
img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
# prepare txt for attention
txt_modulated = self.txt_norm1(txt)
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
txt_modulated = torch.addcmul(txt_mod1.shift, 1 + txt_mod1.scale, self.txt_norm1(txt))
txt_qkv = self.txt_attn.qkv(txt_modulated)
txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
@@ -102,12 +100,12 @@ class DoubleStreamBlock(nn.Module):
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
# calculate the img bloks
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
img.addcmul_(img_mod1.gate, self.img_attn.proj(img_attn))
img.addcmul_(img_mod2.gate, self.img_mlp(torch.addcmul(img_mod2.shift, 1 + img_mod2.scale, self.img_norm2(img))))
# calculate the txt bloks
txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
txt.addcmul_(txt_mod1.gate, self.txt_attn.proj(txt_attn))
txt.addcmul_(txt_mod2.gate, self.txt_mlp(torch.addcmul(txt_mod2.shift, 1 + txt_mod2.scale, self.txt_norm2(txt))))
if txt.dtype == torch.float16:
txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
@@ -152,7 +150,7 @@ class SingleStreamBlock(nn.Module):
def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None) -> Tensor:
mod = vec
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
x_mod = torch.addcmul(mod.shift, 1 + mod.scale, self.pre_norm(x))
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
@@ -162,7 +160,7 @@ class SingleStreamBlock(nn.Module):
attn = attention(q, k, v, pe=pe, mask=attn_mask)
# compute activation in mlp stream, cat again and run second linear layer
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
x += mod.gate * output
x.addcmul_(mod.gate, output)
if x.dtype == torch.float16:
x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
return x
@@ -178,6 +176,6 @@ class LastLayer(nn.Module):
shift, scale = vec
shift = shift.squeeze(1)
scale = scale.squeeze(1)
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
x = torch.addcmul(shift[:, None, :], 1 + scale[:, None, :], self.norm_final(x))
x = self.linear(x)
return x

View File

@@ -163,7 +163,7 @@ class Chroma(nn.Module):
distil_guidance = timestep_embedding(guidance.detach().clone(), 16).to(img.device, img.dtype)
# get all modulation index
modulation_index = timestep_embedding(torch.arange(mod_index_length), 32).to(img.device, img.dtype)
modulation_index = timestep_embedding(torch.arange(mod_index_length, device=img.device), 32).to(img.device, img.dtype)
# we need to broadcast the modulation index here so each batch has all of the index
modulation_index = modulation_index.unsqueeze(0).repeat(img.shape[0], 1, 1).to(img.device, img.dtype)
# and we need to broadcast timestep and guidance along too

View File

@@ -26,16 +26,6 @@ from torch import nn
from comfy.ldm.modules.attention import optimized_attention
def apply_rotary_pos_emb(
t: torch.Tensor,
freqs: torch.Tensor,
) -> torch.Tensor:
t_ = t.reshape(*t.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2).float()
t_out = freqs[..., 0] * t_[..., 0] + freqs[..., 1] * t_[..., 1]
t_out = t_out.movedim(-1, -2).reshape(*t.shape).type_as(t)
return t_out
def get_normalization(name: str, channels: int, weight_args={}, operations=None):
if name == "I":
return nn.Identity()

View File

@@ -66,15 +66,16 @@ class VideoRopePosition3DEmb(VideoPositionEmb):
h_extrapolation_ratio: float = 1.0,
w_extrapolation_ratio: float = 1.0,
t_extrapolation_ratio: float = 1.0,
enable_fps_modulation: bool = True,
device=None,
**kwargs, # used for compatibility with other positional embeddings; unused in this class
):
del kwargs
super().__init__()
self.register_buffer("seq", torch.arange(max(len_h, len_w, len_t), dtype=torch.float, device=device))
self.base_fps = base_fps
self.max_h = len_h
self.max_w = len_w
self.enable_fps_modulation = enable_fps_modulation
dim = head_dim
dim_h = dim // 6 * 2
@@ -132,21 +133,19 @@ class VideoRopePosition3DEmb(VideoPositionEmb):
temporal_freqs = 1.0 / (t_theta**self.dim_temporal_range.to(device=device))
B, T, H, W, _ = B_T_H_W_C
seq = torch.arange(max(H, W, T), dtype=torch.float, device=device)
uniform_fps = (fps is None) or isinstance(fps, (int, float)) or (fps.min() == fps.max())
assert (
uniform_fps or B == 1 or T == 1
), "For video batch, batch size should be 1 for non-uniform fps. For image batch, T should be 1"
assert (
H <= self.max_h and W <= self.max_w
), f"Input dimensions (H={H}, W={W}) exceed the maximum dimensions (max_h={self.max_h}, max_w={self.max_w})"
half_emb_h = torch.outer(self.seq[:H].to(device=device), h_spatial_freqs)
half_emb_w = torch.outer(self.seq[:W].to(device=device), w_spatial_freqs)
half_emb_h = torch.outer(seq[:H].to(device=device), h_spatial_freqs)
half_emb_w = torch.outer(seq[:W].to(device=device), w_spatial_freqs)
# apply sequence scaling in temporal dimension
if fps is None: # image case
half_emb_t = torch.outer(self.seq[:T].to(device=device), temporal_freqs)
if fps is None or self.enable_fps_modulation is False: # image case
half_emb_t = torch.outer(seq[:T].to(device=device), temporal_freqs)
else:
half_emb_t = torch.outer(self.seq[:T].to(device=device) / fps * self.base_fps, temporal_freqs)
half_emb_t = torch.outer(seq[:T].to(device=device) / fps * self.base_fps, temporal_freqs)
half_emb_h = torch.stack([torch.cos(half_emb_h), -torch.sin(half_emb_h), torch.sin(half_emb_h), torch.cos(half_emb_h)], dim=-1)
half_emb_w = torch.stack([torch.cos(half_emb_w), -torch.sin(half_emb_w), torch.sin(half_emb_w), torch.cos(half_emb_w)], dim=-1)

View File

@@ -0,0 +1,864 @@
# original code from: https://github.com/nvidia-cosmos/cosmos-predict2
import torch
from torch import nn
from einops import rearrange
from einops.layers.torch import Rearrange
import logging
from typing import Callable, Optional, Tuple
import math
from .position_embedding import VideoRopePosition3DEmb, LearnablePosEmbAxis
from torchvision import transforms
from comfy.ldm.modules.attention import optimized_attention
def apply_rotary_pos_emb(
t: torch.Tensor,
freqs: torch.Tensor,
) -> torch.Tensor:
t_ = t.reshape(*t.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2).float()
t_out = freqs[..., 0] * t_[..., 0] + freqs[..., 1] * t_[..., 1]
t_out = t_out.movedim(-1, -2).reshape(*t.shape).type_as(t)
return t_out
# ---------------------- Feed Forward Network -----------------------
class GPT2FeedForward(nn.Module):
def __init__(self, d_model: int, d_ff: int, device=None, dtype=None, operations=None) -> None:
super().__init__()
self.activation = nn.GELU()
self.layer1 = operations.Linear(d_model, d_ff, bias=False, device=device, dtype=dtype)
self.layer2 = operations.Linear(d_ff, d_model, bias=False, device=device, dtype=dtype)
self._layer_id = None
self._dim = d_model
self._hidden_dim = d_ff
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.layer1(x)
x = self.activation(x)
x = self.layer2(x)
return x
def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor) -> torch.Tensor:
"""Computes multi-head attention using PyTorch's native implementation.
This function provides a PyTorch backend alternative to Transformer Engine's attention operation.
It rearranges the input tensors to match PyTorch's expected format, computes scaled dot-product
attention, and rearranges the output back to the original format.
The input tensor names use the following dimension conventions:
- B: batch size
- S: sequence length
- H: number of attention heads
- D: head dimension
Args:
q_B_S_H_D: Query tensor with shape (batch, seq_len, n_heads, head_dim)
k_B_S_H_D: Key tensor with shape (batch, seq_len, n_heads, head_dim)
v_B_S_H_D: Value tensor with shape (batch, seq_len, n_heads, head_dim)
Returns:
Attention output tensor with shape (batch, seq_len, n_heads * head_dim)
"""
in_q_shape = q_B_S_H_D.shape
in_k_shape = k_B_S_H_D.shape
q_B_H_S_D = rearrange(q_B_S_H_D, "b ... h k -> b h ... k").view(in_q_shape[0], in_q_shape[-2], -1, in_q_shape[-1])
k_B_H_S_D = rearrange(k_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
v_B_H_S_D = rearrange(v_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
return optimized_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, in_q_shape[-2], skip_reshape=True)
class Attention(nn.Module):
"""
A flexible attention module supporting both self-attention and cross-attention mechanisms.
This module implements a multi-head attention layer that can operate in either self-attention
or cross-attention mode. The mode is determined by whether a context dimension is provided.
The implementation uses scaled dot-product attention and supports optional bias terms and
dropout regularization.
Args:
query_dim (int): The dimensionality of the query vectors.
context_dim (int, optional): The dimensionality of the context (key/value) vectors.
If None, the module operates in self-attention mode using query_dim. Default: None
n_heads (int, optional): Number of attention heads for multi-head attention. Default: 8
head_dim (int, optional): The dimension of each attention head. Default: 64
dropout (float, optional): Dropout probability applied to the output. Default: 0.0
qkv_format (str, optional): Format specification for QKV tensors. Default: "bshd"
backend (str, optional): Backend to use for the attention operation. Default: "transformer_engine"
Examples:
>>> # Self-attention with 512 dimensions and 8 heads
>>> self_attn = Attention(query_dim=512)
>>> x = torch.randn(32, 16, 512) # (batch_size, seq_len, dim)
>>> out = self_attn(x) # (32, 16, 512)
>>> # Cross-attention
>>> cross_attn = Attention(query_dim=512, context_dim=256)
>>> query = torch.randn(32, 16, 512)
>>> context = torch.randn(32, 8, 256)
>>> out = cross_attn(query, context) # (32, 16, 512)
"""
def __init__(
self,
query_dim: int,
context_dim: Optional[int] = None,
n_heads: int = 8,
head_dim: int = 64,
dropout: float = 0.0,
device=None,
dtype=None,
operations=None,
) -> None:
super().__init__()
logging.debug(
f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
f"{n_heads} heads with a dimension of {head_dim}."
)
self.is_selfattn = context_dim is None # self attention
context_dim = query_dim if context_dim is None else context_dim
inner_dim = head_dim * n_heads
self.n_heads = n_heads
self.head_dim = head_dim
self.query_dim = query_dim
self.context_dim = context_dim
self.q_proj = operations.Linear(query_dim, inner_dim, bias=False, device=device, dtype=dtype)
self.q_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype)
self.k_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype)
self.k_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype)
self.v_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype)
self.v_norm = nn.Identity()
self.output_proj = operations.Linear(inner_dim, query_dim, bias=False, device=device, dtype=dtype)
self.output_dropout = nn.Dropout(dropout) if dropout > 1e-4 else nn.Identity()
self.attn_op = torch_attention_op
self._query_dim = query_dim
self._context_dim = context_dim
self._inner_dim = inner_dim
def compute_qkv(
self,
x: torch.Tensor,
context: Optional[torch.Tensor] = None,
rope_emb: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
q = self.q_proj(x)
context = x if context is None else context
k = self.k_proj(context)
v = self.v_proj(context)
q, k, v = map(
lambda t: rearrange(t, "b ... (h d) -> b ... h d", h=self.n_heads, d=self.head_dim),
(q, k, v),
)
def apply_norm_and_rotary_pos_emb(
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, rope_emb: Optional[torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
q = self.q_norm(q)
k = self.k_norm(k)
v = self.v_norm(v)
if self.is_selfattn and rope_emb is not None: # only apply to self-attention!
q = apply_rotary_pos_emb(q, rope_emb)
k = apply_rotary_pos_emb(k, rope_emb)
return q, k, v
q, k, v = apply_norm_and_rotary_pos_emb(q, k, v, rope_emb)
return q, k, v
def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
result = self.attn_op(q, k, v) # [B, S, H, D]
return self.output_dropout(self.output_proj(result))
def forward(
self,
x: torch.Tensor,
context: Optional[torch.Tensor] = None,
rope_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Args:
x (Tensor): The query tensor of shape [B, Mq, K]
context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
"""
q, k, v = self.compute_qkv(x, context, rope_emb=rope_emb)
return self.compute_attention(q, k, v)
class Timesteps(nn.Module):
def __init__(self, num_channels: int):
super().__init__()
self.num_channels = num_channels
def forward(self, timesteps_B_T: torch.Tensor) -> torch.Tensor:
assert timesteps_B_T.ndim == 2, f"Expected 2D input, got {timesteps_B_T.ndim}"
timesteps = timesteps_B_T.flatten().float()
half_dim = self.num_channels // 2
exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
exponent = exponent / (half_dim - 0.0)
emb = torch.exp(exponent)
emb = timesteps[:, None].float() * emb[None, :]
sin_emb = torch.sin(emb)
cos_emb = torch.cos(emb)
emb = torch.cat([cos_emb, sin_emb], dim=-1)
return rearrange(emb, "(b t) d -> b t d", b=timesteps_B_T.shape[0], t=timesteps_B_T.shape[1])
class TimestepEmbedding(nn.Module):
def __init__(self, in_features: int, out_features: int, use_adaln_lora: bool = False, device=None, dtype=None, operations=None):
super().__init__()
logging.debug(
f"Using AdaLN LoRA Flag: {use_adaln_lora}. We enable bias if no AdaLN LoRA for backward compatibility."
)
self.in_dim = in_features
self.out_dim = out_features
self.linear_1 = operations.Linear(in_features, out_features, bias=not use_adaln_lora, device=device, dtype=dtype)
self.activation = nn.SiLU()
self.use_adaln_lora = use_adaln_lora
if use_adaln_lora:
self.linear_2 = operations.Linear(out_features, 3 * out_features, bias=False, device=device, dtype=dtype)
else:
self.linear_2 = operations.Linear(out_features, out_features, bias=False, device=device, dtype=dtype)
def forward(self, sample: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
emb = self.linear_1(sample)
emb = self.activation(emb)
emb = self.linear_2(emb)
if self.use_adaln_lora:
adaln_lora_B_T_3D = emb
emb_B_T_D = sample
else:
adaln_lora_B_T_3D = None
emb_B_T_D = emb
return emb_B_T_D, adaln_lora_B_T_3D
class PatchEmbed(nn.Module):
"""
PatchEmbed is a module for embedding patches from an input tensor by applying either 3D or 2D convolutional layers,
depending on the . This module can process inputs with temporal (video) and spatial (image) dimensions,
making it suitable for video and image processing tasks. It supports dividing the input into patches
and embedding each patch into a vector of size `out_channels`.
Parameters:
- spatial_patch_size (int): The size of each spatial patch.
- temporal_patch_size (int): The size of each temporal patch.
- in_channels (int): Number of input channels. Default: 3.
- out_channels (int): The dimension of the embedding vector for each patch. Default: 768.
- bias (bool): If True, adds a learnable bias to the output of the convolutional layers. Default: True.
"""
def __init__(
self,
spatial_patch_size: int,
temporal_patch_size: int,
in_channels: int = 3,
out_channels: int = 768,
device=None, dtype=None, operations=None
):
super().__init__()
self.spatial_patch_size = spatial_patch_size
self.temporal_patch_size = temporal_patch_size
self.proj = nn.Sequential(
Rearrange(
"b c (t r) (h m) (w n) -> b t h w (c r m n)",
r=temporal_patch_size,
m=spatial_patch_size,
n=spatial_patch_size,
),
operations.Linear(
in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size, out_channels, bias=False, device=device, dtype=dtype
),
)
self.dim = in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass of the PatchEmbed module.
Parameters:
- x (torch.Tensor): The input tensor of shape (B, C, T, H, W) where
B is the batch size,
C is the number of channels,
T is the temporal dimension,
H is the height, and
W is the width of the input.
Returns:
- torch.Tensor: The embedded patches as a tensor, with shape b t h w c.
"""
assert x.dim() == 5
_, _, T, H, W = x.shape
assert (
H % self.spatial_patch_size == 0 and W % self.spatial_patch_size == 0
), f"H,W {(H, W)} should be divisible by spatial_patch_size {self.spatial_patch_size}"
assert T % self.temporal_patch_size == 0
x = self.proj(x)
return x
class FinalLayer(nn.Module):
"""
The final layer of video DiT.
"""
def __init__(
self,
hidden_size: int,
spatial_patch_size: int,
temporal_patch_size: int,
out_channels: int,
use_adaln_lora: bool = False,
adaln_lora_dim: int = 256,
device=None, dtype=None, operations=None
):
super().__init__()
self.layer_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = operations.Linear(
hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False, device=device, dtype=dtype
)
self.hidden_size = hidden_size
self.n_adaln_chunks = 2
self.use_adaln_lora = use_adaln_lora
self.adaln_lora_dim = adaln_lora_dim
if use_adaln_lora:
self.adaln_modulation = nn.Sequential(
nn.SiLU(),
operations.Linear(hidden_size, adaln_lora_dim, bias=False, device=device, dtype=dtype),
operations.Linear(adaln_lora_dim, self.n_adaln_chunks * hidden_size, bias=False, device=device, dtype=dtype),
)
else:
self.adaln_modulation = nn.Sequential(
nn.SiLU(), operations.Linear(hidden_size, self.n_adaln_chunks * hidden_size, bias=False, device=device, dtype=dtype)
)
def forward(
self,
x_B_T_H_W_D: torch.Tensor,
emb_B_T_D: torch.Tensor,
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
):
if self.use_adaln_lora:
assert adaln_lora_B_T_3D is not None
shift_B_T_D, scale_B_T_D = (
self.adaln_modulation(emb_B_T_D) + adaln_lora_B_T_3D[:, :, : 2 * self.hidden_size]
).chunk(2, dim=-1)
else:
shift_B_T_D, scale_B_T_D = self.adaln_modulation(emb_B_T_D).chunk(2, dim=-1)
shift_B_T_1_1_D, scale_B_T_1_1_D = rearrange(shift_B_T_D, "b t d -> b t 1 1 d"), rearrange(
scale_B_T_D, "b t d -> b t 1 1 d"
)
def _fn(
_x_B_T_H_W_D: torch.Tensor,
_norm_layer: nn.Module,
_scale_B_T_1_1_D: torch.Tensor,
_shift_B_T_1_1_D: torch.Tensor,
) -> torch.Tensor:
return _norm_layer(_x_B_T_H_W_D) * (1 + _scale_B_T_1_1_D) + _shift_B_T_1_1_D
x_B_T_H_W_D = _fn(x_B_T_H_W_D, self.layer_norm, scale_B_T_1_1_D, shift_B_T_1_1_D)
x_B_T_H_W_O = self.linear(x_B_T_H_W_D)
return x_B_T_H_W_O
class Block(nn.Module):
"""
A transformer block that combines self-attention, cross-attention and MLP layers with AdaLN modulation.
Each component (self-attention, cross-attention, MLP) has its own layer normalization and AdaLN modulation.
Parameters:
x_dim (int): Dimension of input features
context_dim (int): Dimension of context features for cross-attention
num_heads (int): Number of attention heads
mlp_ratio (float): Multiplier for MLP hidden dimension. Default: 4.0
use_adaln_lora (bool): Whether to use AdaLN-LoRA modulation. Default: False
adaln_lora_dim (int): Hidden dimension for AdaLN-LoRA layers. Default: 256
The block applies the following sequence:
1. Self-attention with AdaLN modulation
2. Cross-attention with AdaLN modulation
3. MLP with AdaLN modulation
Each component uses skip connections and layer normalization.
"""
def __init__(
self,
x_dim: int,
context_dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
use_adaln_lora: bool = False,
adaln_lora_dim: int = 256,
device=None,
dtype=None,
operations=None,
):
super().__init__()
self.x_dim = x_dim
self.layer_norm_self_attn = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
self.self_attn = Attention(x_dim, None, num_heads, x_dim // num_heads, device=device, dtype=dtype, operations=operations)
self.layer_norm_cross_attn = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
self.cross_attn = Attention(
x_dim, context_dim, num_heads, x_dim // num_heads, device=device, dtype=dtype, operations=operations
)
self.layer_norm_mlp = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
self.mlp = GPT2FeedForward(x_dim, int(x_dim * mlp_ratio), device=device, dtype=dtype, operations=operations)
self.use_adaln_lora = use_adaln_lora
if self.use_adaln_lora:
self.adaln_modulation_self_attn = nn.Sequential(
nn.SiLU(),
operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype),
operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype),
)
self.adaln_modulation_cross_attn = nn.Sequential(
nn.SiLU(),
operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype),
operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype),
)
self.adaln_modulation_mlp = nn.Sequential(
nn.SiLU(),
operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype),
operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype),
)
else:
self.adaln_modulation_self_attn = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype))
self.adaln_modulation_cross_attn = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype))
self.adaln_modulation_mlp = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype))
def forward(
self,
x_B_T_H_W_D: torch.Tensor,
emb_B_T_D: torch.Tensor,
crossattn_emb: torch.Tensor,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if extra_per_block_pos_emb is not None:
x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb
if self.use_adaln_lora:
shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = (
self.adaln_modulation_self_attn(emb_B_T_D) + adaln_lora_B_T_3D
).chunk(3, dim=-1)
shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = (
self.adaln_modulation_cross_attn(emb_B_T_D) + adaln_lora_B_T_3D
).chunk(3, dim=-1)
shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = (
self.adaln_modulation_mlp(emb_B_T_D) + adaln_lora_B_T_3D
).chunk(3, dim=-1)
else:
shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = self.adaln_modulation_self_attn(
emb_B_T_D
).chunk(3, dim=-1)
shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = self.adaln_modulation_cross_attn(
emb_B_T_D
).chunk(3, dim=-1)
shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = self.adaln_modulation_mlp(emb_B_T_D).chunk(3, dim=-1)
# Reshape tensors from (B, T, D) to (B, T, 1, 1, D) for broadcasting
shift_self_attn_B_T_1_1_D = rearrange(shift_self_attn_B_T_D, "b t d -> b t 1 1 d")
scale_self_attn_B_T_1_1_D = rearrange(scale_self_attn_B_T_D, "b t d -> b t 1 1 d")
gate_self_attn_B_T_1_1_D = rearrange(gate_self_attn_B_T_D, "b t d -> b t 1 1 d")
shift_cross_attn_B_T_1_1_D = rearrange(shift_cross_attn_B_T_D, "b t d -> b t 1 1 d")
scale_cross_attn_B_T_1_1_D = rearrange(scale_cross_attn_B_T_D, "b t d -> b t 1 1 d")
gate_cross_attn_B_T_1_1_D = rearrange(gate_cross_attn_B_T_D, "b t d -> b t 1 1 d")
shift_mlp_B_T_1_1_D = rearrange(shift_mlp_B_T_D, "b t d -> b t 1 1 d")
scale_mlp_B_T_1_1_D = rearrange(scale_mlp_B_T_D, "b t d -> b t 1 1 d")
gate_mlp_B_T_1_1_D = rearrange(gate_mlp_B_T_D, "b t d -> b t 1 1 d")
B, T, H, W, D = x_B_T_H_W_D.shape
def _fn(_x_B_T_H_W_D, _norm_layer, _scale_B_T_1_1_D, _shift_B_T_1_1_D):
return _norm_layer(_x_B_T_H_W_D) * (1 + _scale_B_T_1_1_D) + _shift_B_T_1_1_D
normalized_x_B_T_H_W_D = _fn(
x_B_T_H_W_D,
self.layer_norm_self_attn,
scale_self_attn_B_T_1_1_D,
shift_self_attn_B_T_1_1_D,
)
result_B_T_H_W_D = rearrange(
self.self_attn(
# normalized_x_B_T_HW_D,
rearrange(normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
None,
rope_emb=rope_emb_L_1_1_D,
),
"b (t h w) d -> b t h w d",
t=T,
h=H,
w=W,
)
x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D * result_B_T_H_W_D
def _x_fn(
_x_B_T_H_W_D: torch.Tensor,
layer_norm_cross_attn: Callable,
_scale_cross_attn_B_T_1_1_D: torch.Tensor,
_shift_cross_attn_B_T_1_1_D: torch.Tensor,
) -> torch.Tensor:
_normalized_x_B_T_H_W_D = _fn(
_x_B_T_H_W_D, layer_norm_cross_attn, _scale_cross_attn_B_T_1_1_D, _shift_cross_attn_B_T_1_1_D
)
_result_B_T_H_W_D = rearrange(
self.cross_attn(
rearrange(_normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
crossattn_emb,
rope_emb=rope_emb_L_1_1_D,
),
"b (t h w) d -> b t h w d",
t=T,
h=H,
w=W,
)
return _result_B_T_H_W_D
result_B_T_H_W_D = _x_fn(
x_B_T_H_W_D,
self.layer_norm_cross_attn,
scale_cross_attn_B_T_1_1_D,
shift_cross_attn_B_T_1_1_D,
)
x_B_T_H_W_D = result_B_T_H_W_D * gate_cross_attn_B_T_1_1_D + x_B_T_H_W_D
normalized_x_B_T_H_W_D = _fn(
x_B_T_H_W_D,
self.layer_norm_mlp,
scale_mlp_B_T_1_1_D,
shift_mlp_B_T_1_1_D,
)
result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D)
x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D * result_B_T_H_W_D
return x_B_T_H_W_D
class MiniTrainDIT(nn.Module):
"""
A clean impl of DIT that can load and reproduce the training results of the original DIT model in~(cosmos 1)
A general implementation of adaln-modulated VIT-like~(DiT) transformer for video processing.
Args:
max_img_h (int): Maximum height of the input images.
max_img_w (int): Maximum width of the input images.
max_frames (int): Maximum number of frames in the video sequence.
in_channels (int): Number of input channels (e.g., RGB channels for color images).
out_channels (int): Number of output channels.
patch_spatial (tuple): Spatial resolution of patches for input processing.
patch_temporal (int): Temporal resolution of patches for input processing.
concat_padding_mask (bool): If True, includes a mask channel in the input to handle padding.
model_channels (int): Base number of channels used throughout the model.
num_blocks (int): Number of transformer blocks.
num_heads (int): Number of heads in the multi-head attention layers.
mlp_ratio (float): Expansion ratio for MLP blocks.
crossattn_emb_channels (int): Number of embedding channels for cross-attention.
pos_emb_cls (str): Type of positional embeddings.
pos_emb_learnable (bool): Whether positional embeddings are learnable.
pos_emb_interpolation (str): Method for interpolating positional embeddings.
min_fps (int): Minimum frames per second.
max_fps (int): Maximum frames per second.
use_adaln_lora (bool): Whether to use AdaLN-LoRA.
adaln_lora_dim (int): Dimension for AdaLN-LoRA.
rope_h_extrapolation_ratio (float): Height extrapolation ratio for RoPE.
rope_w_extrapolation_ratio (float): Width extrapolation ratio for RoPE.
rope_t_extrapolation_ratio (float): Temporal extrapolation ratio for RoPE.
extra_per_block_abs_pos_emb (bool): Whether to use extra per-block absolute positional embeddings.
extra_h_extrapolation_ratio (float): Height extrapolation ratio for extra embeddings.
extra_w_extrapolation_ratio (float): Width extrapolation ratio for extra embeddings.
extra_t_extrapolation_ratio (float): Temporal extrapolation ratio for extra embeddings.
"""
def __init__(
self,
max_img_h: int,
max_img_w: int,
max_frames: int,
in_channels: int,
out_channels: int,
patch_spatial: int, # tuple,
patch_temporal: int,
concat_padding_mask: bool = True,
# attention settings
model_channels: int = 768,
num_blocks: int = 10,
num_heads: int = 16,
mlp_ratio: float = 4.0,
# cross attention settings
crossattn_emb_channels: int = 1024,
# positional embedding settings
pos_emb_cls: str = "sincos",
pos_emb_learnable: bool = False,
pos_emb_interpolation: str = "crop",
min_fps: int = 1,
max_fps: int = 30,
use_adaln_lora: bool = False,
adaln_lora_dim: int = 256,
rope_h_extrapolation_ratio: float = 1.0,
rope_w_extrapolation_ratio: float = 1.0,
rope_t_extrapolation_ratio: float = 1.0,
extra_per_block_abs_pos_emb: bool = False,
extra_h_extrapolation_ratio: float = 1.0,
extra_w_extrapolation_ratio: float = 1.0,
extra_t_extrapolation_ratio: float = 1.0,
rope_enable_fps_modulation: bool = True,
image_model=None,
device=None,
dtype=None,
operations=None,
) -> None:
super().__init__()
self.dtype = dtype
self.max_img_h = max_img_h
self.max_img_w = max_img_w
self.max_frames = max_frames
self.in_channels = in_channels
self.out_channels = out_channels
self.patch_spatial = patch_spatial
self.patch_temporal = patch_temporal
self.num_heads = num_heads
self.num_blocks = num_blocks
self.model_channels = model_channels
self.concat_padding_mask = concat_padding_mask
# positional embedding settings
self.pos_emb_cls = pos_emb_cls
self.pos_emb_learnable = pos_emb_learnable
self.pos_emb_interpolation = pos_emb_interpolation
self.min_fps = min_fps
self.max_fps = max_fps
self.rope_h_extrapolation_ratio = rope_h_extrapolation_ratio
self.rope_w_extrapolation_ratio = rope_w_extrapolation_ratio
self.rope_t_extrapolation_ratio = rope_t_extrapolation_ratio
self.extra_per_block_abs_pos_emb = extra_per_block_abs_pos_emb
self.extra_h_extrapolation_ratio = extra_h_extrapolation_ratio
self.extra_w_extrapolation_ratio = extra_w_extrapolation_ratio
self.extra_t_extrapolation_ratio = extra_t_extrapolation_ratio
self.rope_enable_fps_modulation = rope_enable_fps_modulation
self.build_pos_embed(device=device, dtype=dtype)
self.use_adaln_lora = use_adaln_lora
self.adaln_lora_dim = adaln_lora_dim
self.t_embedder = nn.Sequential(
Timesteps(model_channels),
TimestepEmbedding(model_channels, model_channels, use_adaln_lora=use_adaln_lora, device=device, dtype=dtype, operations=operations,),
)
in_channels = in_channels + 1 if concat_padding_mask else in_channels
self.x_embedder = PatchEmbed(
spatial_patch_size=patch_spatial,
temporal_patch_size=patch_temporal,
in_channels=in_channels,
out_channels=model_channels,
device=device, dtype=dtype, operations=operations,
)
self.blocks = nn.ModuleList(
[
Block(
x_dim=model_channels,
context_dim=crossattn_emb_channels,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
use_adaln_lora=use_adaln_lora,
adaln_lora_dim=adaln_lora_dim,
device=device, dtype=dtype, operations=operations,
)
for _ in range(num_blocks)
]
)
self.final_layer = FinalLayer(
hidden_size=self.model_channels,
spatial_patch_size=self.patch_spatial,
temporal_patch_size=self.patch_temporal,
out_channels=self.out_channels,
use_adaln_lora=self.use_adaln_lora,
adaln_lora_dim=self.adaln_lora_dim,
device=device, dtype=dtype, operations=operations,
)
self.t_embedding_norm = operations.RMSNorm(model_channels, eps=1e-6, device=device, dtype=dtype)
def build_pos_embed(self, device=None, dtype=None) -> None:
if self.pos_emb_cls == "rope3d":
cls_type = VideoRopePosition3DEmb
else:
raise ValueError(f"Unknown pos_emb_cls {self.pos_emb_cls}")
logging.debug(f"Building positional embedding with {self.pos_emb_cls} class, impl {cls_type}")
kwargs = dict(
model_channels=self.model_channels,
len_h=self.max_img_h // self.patch_spatial,
len_w=self.max_img_w // self.patch_spatial,
len_t=self.max_frames // self.patch_temporal,
max_fps=self.max_fps,
min_fps=self.min_fps,
is_learnable=self.pos_emb_learnable,
interpolation=self.pos_emb_interpolation,
head_dim=self.model_channels // self.num_heads,
h_extrapolation_ratio=self.rope_h_extrapolation_ratio,
w_extrapolation_ratio=self.rope_w_extrapolation_ratio,
t_extrapolation_ratio=self.rope_t_extrapolation_ratio,
enable_fps_modulation=self.rope_enable_fps_modulation,
device=device,
)
self.pos_embedder = cls_type(
**kwargs, # type: ignore
)
if self.extra_per_block_abs_pos_emb:
kwargs["h_extrapolation_ratio"] = self.extra_h_extrapolation_ratio
kwargs["w_extrapolation_ratio"] = self.extra_w_extrapolation_ratio
kwargs["t_extrapolation_ratio"] = self.extra_t_extrapolation_ratio
kwargs["device"] = device
kwargs["dtype"] = dtype
self.extra_pos_embedder = LearnablePosEmbAxis(
**kwargs, # type: ignore
)
def prepare_embedded_sequence(
self,
x_B_C_T_H_W: torch.Tensor,
fps: Optional[torch.Tensor] = None,
padding_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
Prepares an embedded sequence tensor by applying positional embeddings and handling padding masks.
Args:
x_B_C_T_H_W (torch.Tensor): video
fps (Optional[torch.Tensor]): Frames per second tensor to be used for positional embedding when required.
If None, a default value (`self.base_fps`) will be used.
padding_mask (Optional[torch.Tensor]): current it is not used
Returns:
Tuple[torch.Tensor, Optional[torch.Tensor]]:
- A tensor of shape (B, T, H, W, D) with the embedded sequence.
- An optional positional embedding tensor, returned only if the positional embedding class
(`self.pos_emb_cls`) includes 'rope'. Otherwise, None.
Notes:
- If `self.concat_padding_mask` is True, a padding mask channel is concatenated to the input tensor.
- The method of applying positional embeddings depends on the value of `self.pos_emb_cls`.
- If 'rope' is in `self.pos_emb_cls` (case insensitive), the positional embeddings are generated using
the `self.pos_embedder` with the shape [T, H, W].
- If "fps_aware" is in `self.pos_emb_cls`, the positional embeddings are generated using the
`self.pos_embedder` with the fps tensor.
- Otherwise, the positional embeddings are generated without considering fps.
"""
if self.concat_padding_mask:
if padding_mask is None:
padding_mask = torch.zeros(x_B_C_T_H_W.shape[0], 1, x_B_C_T_H_W.shape[3], x_B_C_T_H_W.shape[4], dtype=x_B_C_T_H_W.dtype, device=x_B_C_T_H_W.device)
else:
padding_mask = transforms.functional.resize(
padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
)
x_B_C_T_H_W = torch.cat(
[x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1
)
x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W)
if self.extra_per_block_abs_pos_emb:
extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device, dtype=x_B_C_T_H_W.dtype)
else:
extra_pos_emb = None
if "rope" in self.pos_emb_cls.lower():
return x_B_T_H_W_D, self.pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device), extra_pos_emb
x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, device=x_B_C_T_H_W.device) # [B, T, H, W, D]
return x_B_T_H_W_D, None, extra_pos_emb
def unpatchify(self, x_B_T_H_W_M: torch.Tensor) -> torch.Tensor:
x_B_C_Tt_Hp_Wp = rearrange(
x_B_T_H_W_M,
"B T H W (p1 p2 t C) -> B C (T t) (H p1) (W p2)",
p1=self.patch_spatial,
p2=self.patch_spatial,
t=self.patch_temporal,
)
return x_B_C_Tt_Hp_Wp
def forward(
self,
x: torch.Tensor,
timesteps: torch.Tensor,
context: torch.Tensor,
fps: Optional[torch.Tensor] = None,
padding_mask: Optional[torch.Tensor] = None,
**kwargs,
):
x_B_C_T_H_W = x
timesteps_B_T = timesteps
crossattn_emb = context
"""
Args:
x: (B, C, T, H, W) tensor of spatial-temp inputs
timesteps: (B, ) tensor of timesteps
crossattn_emb: (B, N, D) tensor of cross-attention embeddings
"""
x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence(
x_B_C_T_H_W,
fps=fps,
padding_mask=padding_mask,
)
if timesteps_B_T.ndim == 1:
timesteps_B_T = timesteps_B_T.unsqueeze(1)
t_embedding_B_T_D, adaln_lora_B_T_3D = self.t_embedder[1](self.t_embedder[0](timesteps_B_T).to(x_B_T_H_W_D.dtype))
t_embedding_B_T_D = self.t_embedding_norm(t_embedding_B_T_D)
# for logging purpose
affline_scale_log_info = {}
affline_scale_log_info["t_embedding_B_T_D"] = t_embedding_B_T_D.detach()
self.affline_scale_log_info = affline_scale_log_info
self.affline_emb = t_embedding_B_T_D
self.crossattn_emb = crossattn_emb
if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
assert (
x_B_T_H_W_D.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape
), f"{x_B_T_H_W_D.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape}"
block_kwargs = {
"rope_emb_L_1_1_D": rope_emb_L_1_1_D.unsqueeze(1).unsqueeze(0),
"adaln_lora_B_T_3D": adaln_lora_B_T_3D,
"extra_per_block_pos_emb": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
}
for block in self.blocks:
x_B_T_H_W_D = block(
x_B_T_H_W_D,
t_embedding_B_T_D,
crossattn_emb,
**block_kwargs,
)
x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D)
x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)
return x_B_C_Tt_Hp_Wp

View File

@@ -121,6 +121,11 @@ class ControlNetFlux(Flux):
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
if y is None:
y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
else:
y = y[:, :self.params.vec_in_dim]
# running on sequences img
img = self.img_in(img)
@@ -174,7 +179,7 @@ class ControlNetFlux(Flux):
out["output"] = out_output[:self.main_model_single]
return out
def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs):
def forward(self, x, timesteps, context, y=None, guidance=None, hint=None, **kwargs):
patch_size = 2
if self.latent_input:
hint = comfy.ldm.common_dit.pad_to_patch_size(hint, (patch_size, patch_size))

View File

@@ -118,7 +118,7 @@ class Modulation(nn.Module):
def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None):
if modulation_dims is None:
if m_add is not None:
return tensor * m_mult + m_add
return torch.addcmul(m_add, tensor, m_mult)
else:
return tensor * m_mult
else:

View File

@@ -101,6 +101,10 @@ class Flux(nn.Module):
transformer_options={},
attn_mask: Tensor = None,
) -> Tensor:
if y is None:
y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
patches_replace = transformer_options.get("patches_replace", {})
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
@@ -155,6 +159,9 @@ class Flux(nn.Module):
if add is not None:
img += add
if img.dtype == torch.float16:
img = torch.nan_to_num(img, nan=0.0, posinf=65504, neginf=-65504)
img = torch.cat((txt, img), 1)
for i, block in enumerate(self.single_blocks):
@@ -188,20 +195,50 @@ class Flux(nn.Module):
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
return img
def forward(self, x, timestep, context, y, guidance=None, control=None, transformer_options={}, **kwargs):
def process_img(self, x, index=0, h_offset=0, w_offset=0):
bs, c, h, w = x.shape
patch_size = self.patch_size
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
h_len = ((h + (patch_size // 2)) // patch_size)
w_len = ((w + (patch_size // 2)) // patch_size)
h_offset = ((h_offset + (patch_size // 2)) // patch_size)
w_offset = ((w_offset + (patch_size // 2)) // patch_size)
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
img_ids[:, :, 0] = img_ids[:, :, 1] + index
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
return img, repeat(img_ids, "h w c -> b (h w) c", b=bs)
def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
bs, c, h_orig, w_orig = x.shape
patch_size = self.patch_size
h_len = ((h_orig + (patch_size // 2)) // patch_size)
w_len = ((w_orig + (patch_size // 2)) // patch_size)
img, img_ids = self.process_img(x)
img_tokens = img.shape[1]
if ref_latents is not None:
h = 0
w = 0
for ref in ref_latents:
h_offset = 0
w_offset = 0
if ref.shape[-2] + h > ref.shape[-1] + w:
w_offset = w
else:
h_offset = h
kontext, kontext_ids = self.process_img(ref, index=1, h_offset=h_offset, w_offset=w_offset)
img = torch.cat([img, kontext], dim=1)
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
h = max(h, ref.shape[-2] + h_offset)
w = max(w, ref.shape[-1] + w_offset)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]
out = out[:, :img_tokens]
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h_orig,:w_orig]

View File

@@ -261,8 +261,8 @@ class CrossAttention(nn.Module):
self.heads = heads
self.dim_head = dim_head
self.q_norm = operations.RMSNorm(inner_dim, dtype=dtype, device=device)
self.k_norm = operations.RMSNorm(inner_dim, dtype=dtype, device=device)
self.q_norm = operations.RMSNorm(inner_dim, eps=1e-5, dtype=dtype, device=device)
self.k_norm = operations.RMSNorm(inner_dim, eps=1e-5, dtype=dtype, device=device)
self.to_q = operations.Linear(query_dim, inner_dim, bias=True, dtype=dtype, device=device)
self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)

View File

@@ -11,7 +11,7 @@ from comfy.ldm.modules.ema import LitEma
import comfy.ops
class DiagonalGaussianRegularizer(torch.nn.Module):
def __init__(self, sample: bool = True):
def __init__(self, sample: bool = False):
super().__init__()
self.sample = sample
@@ -19,16 +19,12 @@ class DiagonalGaussianRegularizer(torch.nn.Module):
yield from ()
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
log = dict()
posterior = DiagonalGaussianDistribution(z)
if self.sample:
z = posterior.sample()
else:
z = posterior.mode()
kl_loss = posterior.kl()
kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
log["kl_loss"] = kl_loss
return z, log
return z, None
class AbstractAutoencoder(torch.nn.Module):

View File

@@ -20,8 +20,11 @@ if model_management.xformers_enabled():
if model_management.sage_attention_enabled():
try:
from sageattention import sageattn
except ModuleNotFoundError:
logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention")
except ModuleNotFoundError as e:
if e.name == "sageattention":
logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention")
else:
raise e
exit(-1)
if model_management.flash_attention_enabled():
@@ -750,7 +753,7 @@ class BasicTransformerBlock(nn.Module):
for p in patch:
n = p(n, extra_options)
x += n
x = n + x
if "middle_patch" in transformer_patches:
patch = transformer_patches["middle_patch"]
for p in patch:
@@ -790,12 +793,12 @@ class BasicTransformerBlock(nn.Module):
for p in patch:
n = p(n, extra_options)
x += n
x = n + x
if self.is_res:
x_skip = x
x = self.ff(self.norm3(x))
if self.is_res:
x += x_skip
x = x_skip + x
return x

View File

@@ -31,7 +31,7 @@ def dynamic_slice(
starts: List[int],
sizes: List[int],
) -> Tensor:
slicing = [slice(start, start + size) for start, size in zip(starts, sizes)]
slicing = tuple(slice(start, start + size) for start, size in zip(starts, sizes))
return x[slicing]
class AttnChunk(NamedTuple):

View File

@@ -0,0 +1,469 @@
# Original code: https://github.com/VectorSpaceLab/OmniGen2
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from comfy.ldm.lightricks.model import Timesteps
from comfy.ldm.flux.layers import EmbedND
from comfy.ldm.modules.attention import optimized_attention_masked
import comfy.model_management
import comfy.ldm.common_dit
def apply_rotary_emb(x, freqs_cis):
if x.shape[1] == 0:
return x
t_ = x.reshape(*x.shape[:-1], -1, 1, 2)
t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
return t_out.reshape(*x.shape).to(dtype=x.dtype)
def swiglu(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return F.silu(x) * y
class TimestepEmbedding(nn.Module):
def __init__(self, in_channels: int, time_embed_dim: int, dtype=None, device=None, operations=None):
super().__init__()
self.linear_1 = operations.Linear(in_channels, time_embed_dim, dtype=dtype, device=device)
self.act = nn.SiLU()
self.linear_2 = operations.Linear(time_embed_dim, time_embed_dim, dtype=dtype, device=device)
def forward(self, sample: torch.Tensor) -> torch.Tensor:
sample = self.linear_1(sample)
sample = self.act(sample)
sample = self.linear_2(sample)
return sample
class LuminaRMSNormZero(nn.Module):
def __init__(self, embedding_dim: int, norm_eps: float = 1e-5, dtype=None, device=None, operations=None):
super().__init__()
self.silu = nn.SiLU()
self.linear = operations.Linear(min(embedding_dim, 1024), 4 * embedding_dim, dtype=dtype, device=device)
self.norm = operations.RMSNorm(embedding_dim, eps=norm_eps, dtype=dtype, device=device)
def forward(self, x: torch.Tensor, emb: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
emb = self.linear(self.silu(emb))
scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
x = self.norm(x) * (1 + scale_msa[:, None])
return x, gate_msa, scale_mlp, gate_mlp
class LuminaLayerNormContinuous(nn.Module):
def __init__(self, embedding_dim: int, conditioning_embedding_dim: int, elementwise_affine: bool = False, eps: float = 1e-6, out_dim: Optional[int] = None, dtype=None, device=None, operations=None):
super().__init__()
self.silu = nn.SiLU()
self.linear_1 = operations.Linear(conditioning_embedding_dim, embedding_dim, dtype=dtype, device=device)
self.norm = operations.LayerNorm(embedding_dim, eps, elementwise_affine, dtype=dtype, device=device)
self.linear_2 = operations.Linear(embedding_dim, out_dim, bias=True, dtype=dtype, device=device) if out_dim is not None else None
def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
emb = self.linear_1(self.silu(conditioning_embedding).to(x.dtype))
x = self.norm(x) * (1 + emb)[:, None, :]
if self.linear_2 is not None:
x = self.linear_2(x)
return x
class LuminaFeedForward(nn.Module):
def __init__(self, dim: int, inner_dim: int, multiple_of: int = 256, dtype=None, device=None, operations=None):
super().__init__()
inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of)
self.linear_1 = operations.Linear(dim, inner_dim, bias=False, dtype=dtype, device=device)
self.linear_2 = operations.Linear(inner_dim, dim, bias=False, dtype=dtype, device=device)
self.linear_3 = operations.Linear(dim, inner_dim, bias=False, dtype=dtype, device=device)
def forward(self, x: torch.Tensor) -> torch.Tensor:
h1, h2 = self.linear_1(x), self.linear_3(x)
return self.linear_2(swiglu(h1, h2))
class Lumina2CombinedTimestepCaptionEmbedding(nn.Module):
def __init__(self, hidden_size: int = 4096, text_feat_dim: int = 2048, frequency_embedding_size: int = 256, norm_eps: float = 1e-5, timestep_scale: float = 1.0, dtype=None, device=None, operations=None):
super().__init__()
self.time_proj = Timesteps(num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0, scale=timestep_scale)
self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=min(hidden_size, 1024), dtype=dtype, device=device, operations=operations)
self.caption_embedder = nn.Sequential(
operations.RMSNorm(text_feat_dim, eps=norm_eps, dtype=dtype, device=device),
operations.Linear(text_feat_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
def forward(self, timestep: torch.Tensor, text_hidden_states: torch.Tensor, dtype: torch.dtype) -> Tuple[torch.Tensor, torch.Tensor]:
timestep_proj = self.time_proj(timestep).to(dtype=dtype)
time_embed = self.timestep_embedder(timestep_proj)
caption_embed = self.caption_embedder(text_hidden_states)
return time_embed, caption_embed
class Attention(nn.Module):
def __init__(self, query_dim: int, dim_head: int, heads: int, kv_heads: int, eps: float = 1e-5, bias: bool = False, dtype=None, device=None, operations=None):
super().__init__()
self.heads = heads
self.kv_heads = kv_heads
self.dim_head = dim_head
self.scale = dim_head ** -0.5
self.to_q = operations.Linear(query_dim, heads * dim_head, bias=bias, dtype=dtype, device=device)
self.to_k = operations.Linear(query_dim, kv_heads * dim_head, bias=bias, dtype=dtype, device=device)
self.to_v = operations.Linear(query_dim, kv_heads * dim_head, bias=bias, dtype=dtype, device=device)
self.norm_q = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
self.norm_k = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
self.to_out = nn.Sequential(
operations.Linear(heads * dim_head, query_dim, bias=bias, dtype=dtype, device=device),
nn.Dropout(0.0)
)
def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None) -> torch.Tensor:
batch_size, sequence_length, _ = hidden_states.shape
query = self.to_q(hidden_states)
key = self.to_k(encoder_hidden_states)
value = self.to_v(encoder_hidden_states)
query = query.view(batch_size, -1, self.heads, self.dim_head)
key = key.view(batch_size, -1, self.kv_heads, self.dim_head)
value = value.view(batch_size, -1, self.kv_heads, self.dim_head)
query = self.norm_q(query)
key = self.norm_k(key)
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
if self.kv_heads < self.heads:
key = key.repeat_interleave(self.heads // self.kv_heads, dim=1)
value = value.repeat_interleave(self.heads // self.kv_heads, dim=1)
hidden_states = optimized_attention_masked(query, key, value, self.heads, attention_mask, skip_reshape=True)
hidden_states = self.to_out[0](hidden_states)
return hidden_states
class OmniGen2TransformerBlock(nn.Module):
def __init__(self, dim: int, num_attention_heads: int, num_kv_heads: int, multiple_of: int, ffn_dim_multiplier: float, norm_eps: float, modulation: bool = True, dtype=None, device=None, operations=None):
super().__init__()
self.modulation = modulation
self.attn = Attention(
query_dim=dim,
dim_head=dim // num_attention_heads,
heads=num_attention_heads,
kv_heads=num_kv_heads,
eps=1e-5,
bias=False,
dtype=dtype, device=device, operations=operations,
)
self.feed_forward = LuminaFeedForward(
dim=dim,
inner_dim=4 * dim,
multiple_of=multiple_of,
dtype=dtype, device=device, operations=operations
)
if modulation:
self.norm1 = LuminaRMSNormZero(embedding_dim=dim, norm_eps=norm_eps, dtype=dtype, device=device, operations=operations)
else:
self.norm1 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
self.ffn_norm1 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
self.norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
self.ffn_norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, image_rotary_emb: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
if self.modulation:
norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb)
hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output)
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
else:
norm_hidden_states = self.norm1(hidden_states)
attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb)
hidden_states = hidden_states + self.norm2(attn_output)
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
hidden_states = hidden_states + self.ffn_norm2(mlp_output)
return hidden_states
class OmniGen2RotaryPosEmbed(nn.Module):
def __init__(self, theta: int, axes_dim: Tuple[int, int, int], axes_lens: Tuple[int, int, int] = (300, 512, 512), patch_size: int = 2):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim
self.axes_lens = axes_lens
self.patch_size = patch_size
self.rope_embedder = EmbedND(dim=sum(axes_dim), theta=self.theta, axes_dim=axes_dim)
def forward(self, batch_size, encoder_seq_len, l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len, ref_img_sizes, img_sizes, device):
p = self.patch_size
seq_lengths = [cap_len + sum(ref_img_len) + img_len for cap_len, ref_img_len, img_len in zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len)]
max_seq_len = max(seq_lengths)
max_ref_img_len = max([sum(ref_img_len) for ref_img_len in l_effective_ref_img_len])
max_img_len = max(l_effective_img_len)
position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device)
for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
position_ids[i, :cap_seq_len] = repeat(torch.arange(cap_seq_len, dtype=torch.int32, device=device), "l -> l 3")
pe_shift = cap_seq_len
pe_shift_len = cap_seq_len
if ref_img_sizes[i] is not None:
for ref_img_size, ref_img_len in zip(ref_img_sizes[i], l_effective_ref_img_len[i]):
H, W = ref_img_size
ref_H_tokens, ref_W_tokens = H // p, W // p
row_ids = repeat(torch.arange(ref_H_tokens, dtype=torch.int32, device=device), "h -> h w", w=ref_W_tokens).flatten()
col_ids = repeat(torch.arange(ref_W_tokens, dtype=torch.int32, device=device), "w -> h w", h=ref_H_tokens).flatten()
position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 0] = pe_shift
position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 1] = row_ids
position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 2] = col_ids
pe_shift += max(ref_H_tokens, ref_W_tokens)
pe_shift_len += ref_img_len
H, W = img_sizes[i]
H_tokens, W_tokens = H // p, W // p
row_ids = repeat(torch.arange(H_tokens, dtype=torch.int32, device=device), "h -> h w", w=W_tokens).flatten()
col_ids = repeat(torch.arange(W_tokens, dtype=torch.int32, device=device), "w -> h w", h=H_tokens).flatten()
position_ids[i, pe_shift_len: seq_len, 0] = pe_shift
position_ids[i, pe_shift_len: seq_len, 1] = row_ids
position_ids[i, pe_shift_len: seq_len, 2] = col_ids
freqs_cis = self.rope_embedder(position_ids).movedim(1, 2)
cap_freqs_cis_shape = list(freqs_cis.shape)
cap_freqs_cis_shape[1] = encoder_seq_len
cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
ref_img_freqs_cis_shape = list(freqs_cis.shape)
ref_img_freqs_cis_shape[1] = max_ref_img_len
ref_img_freqs_cis = torch.zeros(*ref_img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
img_freqs_cis_shape = list(freqs_cis.shape)
img_freqs_cis_shape[1] = max_img_len
img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
for i, (cap_seq_len, ref_img_len, img_len, seq_len) in enumerate(zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len, seq_lengths)):
cap_freqs_cis[i, :cap_seq_len] = freqs_cis[i, :cap_seq_len]
ref_img_freqs_cis[i, :sum(ref_img_len)] = freqs_cis[i, cap_seq_len:cap_seq_len + sum(ref_img_len)]
img_freqs_cis[i, :img_len] = freqs_cis[i, cap_seq_len + sum(ref_img_len):cap_seq_len + sum(ref_img_len) + img_len]
return cap_freqs_cis, ref_img_freqs_cis, img_freqs_cis, freqs_cis, l_effective_cap_len, seq_lengths
class OmniGen2Transformer2DModel(nn.Module):
def __init__(
self,
patch_size: int = 2,
in_channels: int = 16,
out_channels: Optional[int] = None,
hidden_size: int = 2304,
num_layers: int = 26,
num_refiner_layers: int = 2,
num_attention_heads: int = 24,
num_kv_heads: int = 8,
multiple_of: int = 256,
ffn_dim_multiplier: Optional[float] = None,
norm_eps: float = 1e-5,
axes_dim_rope: Tuple[int, int, int] = (32, 32, 32),
axes_lens: Tuple[int, int, int] = (300, 512, 512),
text_feat_dim: int = 1024,
timestep_scale: float = 1.0,
image_model=None,
device=None,
dtype=None,
operations=None,
):
super().__init__()
self.patch_size = patch_size
self.out_channels = out_channels or in_channels
self.hidden_size = hidden_size
self.dtype = dtype
self.rope_embedder = OmniGen2RotaryPosEmbed(
theta=10000,
axes_dim=axes_dim_rope,
axes_lens=axes_lens,
patch_size=patch_size,
)
self.x_embedder = operations.Linear(patch_size * patch_size * in_channels, hidden_size, dtype=dtype, device=device)
self.ref_image_patch_embedder = operations.Linear(patch_size * patch_size * in_channels, hidden_size, dtype=dtype, device=device)
self.time_caption_embed = Lumina2CombinedTimestepCaptionEmbedding(
hidden_size=hidden_size,
text_feat_dim=text_feat_dim,
norm_eps=norm_eps,
timestep_scale=timestep_scale, dtype=dtype, device=device, operations=operations
)
self.noise_refiner = nn.ModuleList([
OmniGen2TransformerBlock(
hidden_size, num_attention_heads, num_kv_heads,
multiple_of, ffn_dim_multiplier, norm_eps, modulation=True, dtype=dtype, device=device, operations=operations
) for _ in range(num_refiner_layers)
])
self.ref_image_refiner = nn.ModuleList([
OmniGen2TransformerBlock(
hidden_size, num_attention_heads, num_kv_heads,
multiple_of, ffn_dim_multiplier, norm_eps, modulation=True, dtype=dtype, device=device, operations=operations
) for _ in range(num_refiner_layers)
])
self.context_refiner = nn.ModuleList([
OmniGen2TransformerBlock(
hidden_size, num_attention_heads, num_kv_heads,
multiple_of, ffn_dim_multiplier, norm_eps, modulation=False, dtype=dtype, device=device, operations=operations
) for _ in range(num_refiner_layers)
])
self.layers = nn.ModuleList([
OmniGen2TransformerBlock(
hidden_size, num_attention_heads, num_kv_heads,
multiple_of, ffn_dim_multiplier, norm_eps, modulation=True, dtype=dtype, device=device, operations=operations
) for _ in range(num_layers)
])
self.norm_out = LuminaLayerNormContinuous(
embedding_dim=hidden_size,
conditioning_embedding_dim=min(hidden_size, 1024),
elementwise_affine=False,
eps=1e-6,
out_dim=patch_size * patch_size * self.out_channels, dtype=dtype, device=device, operations=operations
)
self.image_index_embedding = nn.Parameter(torch.empty(5, hidden_size, device=device, dtype=dtype))
def flat_and_pad_to_seq(self, hidden_states, ref_image_hidden_states):
batch_size = len(hidden_states)
p = self.patch_size
img_sizes = [(img.size(1), img.size(2)) for img in hidden_states]
l_effective_img_len = [(H // p) * (W // p) for (H, W) in img_sizes]
if ref_image_hidden_states is not None:
ref_image_hidden_states = list(map(lambda ref: comfy.ldm.common_dit.pad_to_patch_size(ref, (p, p)), ref_image_hidden_states))
ref_img_sizes = [[(imgs.size(2), imgs.size(3)) if imgs is not None else None for imgs in ref_image_hidden_states]] * batch_size
l_effective_ref_img_len = [[(ref_img_size[0] // p) * (ref_img_size[1] // p) for ref_img_size in _ref_img_sizes] if _ref_img_sizes is not None else [0] for _ref_img_sizes in ref_img_sizes]
else:
ref_img_sizes = [None for _ in range(batch_size)]
l_effective_ref_img_len = [[0] for _ in range(batch_size)]
flat_ref_img_hidden_states = None
if ref_image_hidden_states is not None:
imgs = []
for ref_img in ref_image_hidden_states:
B, C, H, W = ref_img.size()
ref_img = rearrange(ref_img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
imgs.append(ref_img)
flat_ref_img_hidden_states = torch.cat(imgs, dim=1)
img = hidden_states
B, C, H, W = img.size()
flat_hidden_states = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
return (
flat_hidden_states, flat_ref_img_hidden_states,
None, None,
l_effective_ref_img_len, l_effective_img_len,
ref_img_sizes, img_sizes,
)
def img_patch_embed_and_refine(self, hidden_states, ref_image_hidden_states, padded_img_mask, padded_ref_img_mask, noise_rotary_emb, ref_img_rotary_emb, l_effective_ref_img_len, l_effective_img_len, temb):
batch_size = len(hidden_states)
hidden_states = self.x_embedder(hidden_states)
if ref_image_hidden_states is not None:
ref_image_hidden_states = self.ref_image_patch_embedder(ref_image_hidden_states)
image_index_embedding = comfy.model_management.cast_to(self.image_index_embedding, dtype=hidden_states.dtype, device=hidden_states.device)
for i in range(batch_size):
shift = 0
for j, ref_img_len in enumerate(l_effective_ref_img_len[i]):
ref_image_hidden_states[i, shift:shift + ref_img_len, :] = ref_image_hidden_states[i, shift:shift + ref_img_len, :] + image_index_embedding[j]
shift += ref_img_len
for layer in self.noise_refiner:
hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb)
if ref_image_hidden_states is not None:
for layer in self.ref_image_refiner:
ref_image_hidden_states = layer(ref_image_hidden_states, padded_ref_img_mask, ref_img_rotary_emb, temb)
hidden_states = torch.cat([ref_image_hidden_states, hidden_states], dim=1)
return hidden_states
def forward(self, x, timesteps, context, num_tokens, ref_latents=None, attention_mask=None, **kwargs):
B, C, H, W = x.shape
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
_, _, H_padded, W_padded = hidden_states.shape
timestep = 1.0 - timesteps
text_hidden_states = context
text_attention_mask = attention_mask
ref_image_hidden_states = ref_latents
device = hidden_states.device
temb, text_hidden_states = self.time_caption_embed(timestep, text_hidden_states, hidden_states[0].dtype)
(
hidden_states, ref_image_hidden_states,
img_mask, ref_img_mask,
l_effective_ref_img_len, l_effective_img_len,
ref_img_sizes, img_sizes,
) = self.flat_and_pad_to_seq(hidden_states, ref_image_hidden_states)
(
context_rotary_emb, ref_img_rotary_emb, noise_rotary_emb,
rotary_emb, encoder_seq_lengths, seq_lengths,
) = self.rope_embedder(
hidden_states.shape[0], text_hidden_states.shape[1], [num_tokens] * text_hidden_states.shape[0],
l_effective_ref_img_len, l_effective_img_len,
ref_img_sizes, img_sizes, device,
)
for layer in self.context_refiner:
text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb)
img_len = hidden_states.shape[1]
combined_img_hidden_states = self.img_patch_embed_and_refine(
hidden_states, ref_image_hidden_states,
img_mask, ref_img_mask,
noise_rotary_emb, ref_img_rotary_emb,
l_effective_ref_img_len, l_effective_img_len,
temb,
)
hidden_states = torch.cat([text_hidden_states, combined_img_hidden_states], dim=1)
attention_mask = None
for layer in self.layers:
hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb)
hidden_states = self.norm_out(hidden_states, temb)
p = self.patch_size
output = rearrange(hidden_states[:, -img_len:], 'b (h w) (p1 p2 c) -> b c (h p1) (w p2)', h=H_padded // p, w=W_padded// p, p1=p, p2=p)[:, :, :H, :W]
return -output

View File

@@ -539,13 +539,20 @@ class WanModel(torch.nn.Module):
x = self.unpatchify(x, grid_sizes)
return x
def forward(self, x, timestep, context, clip_fea=None, transformer_options={}, **kwargs):
def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs):
bs, c, t, h, w = x.shape
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
patch_size = self.patch_size
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
if time_dim_concat is not None:
time_dim_concat = comfy.ldm.common_dit.pad_to_patch_size(time_dim_concat, self.patch_size)
x = torch.cat([x, time_dim_concat], dim=2)
t_len = ((x.shape[2] + (patch_size[0] // 2)) // patch_size[0])
img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
@@ -635,7 +642,7 @@ class VaceWanModel(WanModel):
t,
context,
vace_context,
vace_strength=1.0,
vace_strength,
clip_fea=None,
freqs=None,
transformer_options={},
@@ -661,8 +668,11 @@ class VaceWanModel(WanModel):
context = torch.concat([context_clip, context], dim=1)
context_img_len = clip_fea.shape[-2]
orig_shape = list(vace_context.shape)
vace_context = vace_context.movedim(0, 1).reshape([-1] + orig_shape[2:])
c = self.vace_patch_embedding(vace_context.float()).to(vace_context.dtype)
c = c.flatten(2).transpose(1, 2)
c = list(c.split(orig_shape[0], dim=0))
# arguments
x_orig = x
@@ -682,8 +692,9 @@ class VaceWanModel(WanModel):
ii = self.vace_layers_mapping.get(i, None)
if ii is not None:
c_skip, c = self.vace_blocks[ii](c, x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
x += c_skip * vace_strength
for iii in range(len(c)):
c_skip, c[iii] = self.vace_blocks[ii](c[iii], x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
x += c_skip * vace_strength[iii]
del c_skip
# head
x = self.head(x, e)

View File

@@ -283,8 +283,9 @@ def model_lora_keys_unet(model, key_map={}):
for k in sdk:
if k.startswith("diffusion_model."):
if k.endswith(".weight"):
key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
key_map["lycoris_{}".format(key_lora)] = k #SimpleTuner lycoris format
key_lora = k[len("diffusion_model."):-len(".weight")]
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k #SimpleTuner lycoris format
key_map["transformer.{}".format(key_lora)] = k #SimpleTuner regular format
if isinstance(model, comfy.model_base.ACEStep):
for k in sdk:

View File

@@ -34,12 +34,14 @@ import comfy.ldm.flux.model
import comfy.ldm.lightricks.model
import comfy.ldm.hunyuan_video.model
import comfy.ldm.cosmos.model
import comfy.ldm.cosmos.predict2
import comfy.ldm.lumina.model
import comfy.ldm.wan.model
import comfy.ldm.hunyuan3d.model
import comfy.ldm.hidream.model
import comfy.ldm.chroma.model
import comfy.ldm.ace.model
import comfy.ldm.omnigen.omnigen2
import comfy.model_management
import comfy.patcher_extension
@@ -48,6 +50,7 @@ import comfy.ops
from enum import Enum
from . import utils
import comfy.latent_formats
import comfy.model_sampling
import math
from typing import TYPE_CHECKING
if TYPE_CHECKING:
@@ -63,38 +66,39 @@ class ModelType(Enum):
V_PREDICTION_CONTINUOUS = 7
FLUX = 8
IMG_TO_IMG = 9
from comfy.model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling, ModelSamplingContinuousV
FLOW_COSMOS = 10
def model_sampling(model_config, model_type):
s = ModelSamplingDiscrete
s = comfy.model_sampling.ModelSamplingDiscrete
if model_type == ModelType.EPS:
c = EPS
c = comfy.model_sampling.EPS
elif model_type == ModelType.V_PREDICTION:
c = V_PREDICTION
c = comfy.model_sampling.V_PREDICTION
elif model_type == ModelType.V_PREDICTION_EDM:
c = V_PREDICTION
s = ModelSamplingContinuousEDM
c = comfy.model_sampling.V_PREDICTION
s = comfy.model_sampling.ModelSamplingContinuousEDM
elif model_type == ModelType.FLOW:
c = comfy.model_sampling.CONST
s = comfy.model_sampling.ModelSamplingDiscreteFlow
elif model_type == ModelType.STABLE_CASCADE:
c = EPS
s = StableCascadeSampling
c = comfy.model_sampling.EPS
s = comfy.model_sampling.StableCascadeSampling
elif model_type == ModelType.EDM:
c = EDM
s = ModelSamplingContinuousEDM
c = comfy.model_sampling.EDM
s = comfy.model_sampling.ModelSamplingContinuousEDM
elif model_type == ModelType.V_PREDICTION_CONTINUOUS:
c = V_PREDICTION
s = ModelSamplingContinuousV
c = comfy.model_sampling.V_PREDICTION
s = comfy.model_sampling.ModelSamplingContinuousV
elif model_type == ModelType.FLUX:
c = comfy.model_sampling.CONST
s = comfy.model_sampling.ModelSamplingFlux
elif model_type == ModelType.IMG_TO_IMG:
c = comfy.model_sampling.IMG_TO_IMG
elif model_type == ModelType.FLOW_COSMOS:
c = comfy.model_sampling.COSMOS_RFLOW
s = comfy.model_sampling.ModelSamplingCosmosRFlow
class ModelSampling(s, c):
pass
@@ -102,6 +106,13 @@ def model_sampling(model_config, model_type):
return ModelSampling(model_config)
def convert_tensor(extra, dtype):
if hasattr(extra, "dtype"):
if extra.dtype != torch.int and extra.dtype != torch.long:
extra = extra.to(dtype)
return extra
class BaseModel(torch.nn.Module):
def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_model=UNetModel):
super().__init__()
@@ -135,6 +146,7 @@ class BaseModel(torch.nn.Module):
logging.info("model_type {}".format(model_type.name))
logging.debug("adm {}".format(self.adm_channels))
self.memory_usage_factor = model_config.memory_usage_factor
self.memory_usage_factor_conds = ()
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
@@ -164,9 +176,14 @@ class BaseModel(torch.nn.Module):
extra_conds = {}
for o in kwargs:
extra = kwargs[o]
if hasattr(extra, "dtype"):
if extra.dtype != torch.int and extra.dtype != torch.long:
extra = extra.to(dtype)
extra = convert_tensor(extra, dtype)
elif isinstance(extra, list):
ex = []
for ext in extra:
ex.append(convert_tensor(ext, dtype))
extra = ex
extra_conds[o] = extra
t = self.process_timestep(t, x=x, **extra_conds)
@@ -325,19 +342,28 @@ class BaseModel(torch.nn.Module):
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
return self.model_sampling.noise_scaling(sigma.reshape([sigma.shape[0]] + [1] * (len(noise.shape) - 1)), noise, latent_image)
def memory_required(self, input_shape):
def memory_required(self, input_shape, cond_shapes={}):
input_shapes = [input_shape]
for c in self.memory_usage_factor_conds:
shape = cond_shapes.get(c, None)
if shape is not None and len(shape) > 0:
input_shapes += shape
if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
dtype = self.get_dtype()
if self.manual_cast_dtype is not None:
dtype = self.manual_cast_dtype
#TODO: this needs to be tweaked
area = input_shape[0] * math.prod(input_shape[2:])
area = sum(map(lambda input_shape: input_shape[0] * math.prod(input_shape[2:]), input_shapes))
return (area * comfy.model_management.dtype_size(dtype) * 0.01 * self.memory_usage_factor) * (1024 * 1024)
else:
#TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
area = input_shape[0] * math.prod(input_shape[2:])
area = sum(map(lambda input_shape: input_shape[0] * math.prod(input_shape[2:]), input_shapes))
return (area * 0.15 * self.memory_usage_factor) * (1024 * 1024)
def extra_conds_shapes(self, **kwargs):
return {}
def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0, seed=None):
adm_inputs = []
@@ -790,6 +816,7 @@ class PixArt(BaseModel):
class Flux(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLUX, device=None, unet_model=comfy.ldm.flux.model.Flux):
super().__init__(model_config, model_type, device=device, unet_model=unet_model)
self.memory_usage_factor_conds = ("ref_latents",)
def concat_cond(self, **kwargs):
try:
@@ -850,8 +877,23 @@ class Flux(BaseModel):
guidance = kwargs.get("guidance", 3.5)
if guidance is not None:
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
ref_latents = kwargs.get("reference_latents", None)
if ref_latents is not None:
latents = []
for lat in ref_latents:
latents.append(self.process_latent_in(lat))
out['ref_latents'] = comfy.conds.CONDList(latents)
return out
def extra_conds_shapes(self, **kwargs):
out = {}
ref_latents = kwargs.get("reference_latents", None)
if ref_latents is not None:
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
return out
class GenmoMochi(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.genmo.joint_model.asymm_models_joint.AsymmDiTJoint)
@@ -976,6 +1018,45 @@ class CosmosVideo(BaseModel):
latent_image = self.model_sampling.calculate_input(torch.tensor([sigma_noise_augmentation], device=latent_image.device, dtype=latent_image.dtype), latent_image)
return latent_image * ((sigma ** 2 + self.model_sampling.sigma_data ** 2) ** 0.5)
class CosmosPredict2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW_COSMOS, image_to_video=False, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.cosmos.predict2.MiniTrainDIT)
self.image_to_video = image_to_video
if self.image_to_video:
self.concat_keys = ("mask_inverted",)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
if denoise_mask is not None:
out["denoise_mask"] = comfy.conds.CONDRegular(denoise_mask)
out['fps'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", None))
return out
def process_timestep(self, timestep, x, denoise_mask=None, **kwargs):
if denoise_mask is None:
return timestep
if denoise_mask.ndim <= 4:
return timestep
condition_video_mask_B_1_T_1_1 = denoise_mask.mean(dim=[1, 3, 4], keepdim=True)
c_noise_B_1_T_1_1 = 0.0 * (1.0 - condition_video_mask_B_1_T_1_1) + timestep.reshape(timestep.shape[0], 1, 1, 1, 1) * condition_video_mask_B_1_T_1_1
out = c_noise_B_1_T_1_1.squeeze(dim=[1, 3, 4])
return out
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
sigma = sigma.reshape([sigma.shape[0]] + [1] * (len(noise.shape) - 1))
sigma_noise_augmentation = 0 #TODO
if sigma_noise_augmentation != 0:
latent_image = latent_image + noise
latent_image = self.model_sampling.calculate_input(torch.tensor([sigma_noise_augmentation], device=latent_image.device, dtype=latent_image.dtype), latent_image)
sigma = (sigma / (sigma + 1))
return latent_image / (1.0 - sigma)
class Lumina2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiT)
@@ -1047,6 +1128,11 @@ class WAN21(BaseModel):
clip_vision_output = kwargs.get("clip_vision_output", None)
if clip_vision_output is not None:
out['clip_fea'] = comfy.conds.CONDRegular(clip_vision_output.penultimate_hidden_states)
time_dim_concat = kwargs.get("time_dim_concat", None)
if time_dim_concat is not None:
out['time_dim_concat'] = comfy.conds.CONDRegular(self.process_latent_in(time_dim_concat))
return out
@@ -1062,20 +1148,25 @@ class WAN21_Vace(WAN21):
vace_frames = kwargs.get("vace_frames", None)
if vace_frames is None:
noise_shape[1] = 32
vace_frames = torch.zeros(noise_shape, device=noise.device, dtype=noise.dtype)
for i in range(0, vace_frames.shape[1], 16):
vace_frames = vace_frames.clone()
vace_frames[:, i:i + 16] = self.process_latent_in(vace_frames[:, i:i + 16])
vace_frames = [torch.zeros(noise_shape, device=noise.device, dtype=noise.dtype)]
mask = kwargs.get("vace_mask", None)
if mask is None:
noise_shape[1] = 64
mask = torch.ones(noise_shape, device=noise.device, dtype=noise.dtype)
mask = [torch.ones(noise_shape, device=noise.device, dtype=noise.dtype)] * len(vace_frames)
out['vace_context'] = comfy.conds.CONDRegular(torch.cat([vace_frames.to(noise), mask.to(noise)], dim=1))
vace_frames_out = []
for j in range(len(vace_frames)):
vf = vace_frames[j].clone()
for i in range(0, vf.shape[1], 16):
vf[:, i:i + 16] = self.process_latent_in(vf[:, i:i + 16])
vf = torch.cat([vf, mask[j]], dim=1)
vace_frames_out.append(vf)
vace_strength = kwargs.get("vace_strength", 1.0)
vace_frames = torch.stack(vace_frames_out, dim=1)
out['vace_context'] = comfy.conds.CONDRegular(vace_frames)
vace_strength = kwargs.get("vace_strength", [1.0] * len(vace_frames_out))
out['vace_strength'] = comfy.conds.CONDConstant(vace_strength)
return out
@@ -1156,3 +1247,33 @@ class ACEStep(BaseModel):
out['speaker_embeds'] = comfy.conds.CONDRegular(torch.zeros(noise.shape[0], 512, device=noise.device, dtype=noise.dtype))
out['lyrics_strength'] = comfy.conds.CONDConstant(kwargs.get("lyrics_strength", 1.0))
return out
class Omnigen2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.omnigen.omnigen2.OmniGen2Transformer2DModel)
self.memory_usage_factor_conds = ("ref_latents",)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None:
if torch.numel(attention_mask) != attention_mask.sum():
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
out['num_tokens'] = comfy.conds.CONDConstant(max(1, torch.sum(attention_mask).item()))
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
ref_latents = kwargs.get("reference_latents", None)
if ref_latents is not None:
latents = []
for lat in ref_latents:
latents.append(self.process_latent_in(lat))
out['ref_latents'] = comfy.conds.CONDList(latents)
return out
def extra_conds_shapes(self, **kwargs):
out = {}
ref_latents = kwargs.get("reference_latents", None)
if ref_latents is not None:
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
return out

View File

@@ -407,6 +407,78 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["text_emb_dim"] = 2048
return dit_config
if '{}blocks.0.mlp.layer1.weight'.format(key_prefix) in state_dict_keys: # Cosmos predict2
dit_config = {}
dit_config["image_model"] = "cosmos_predict2"
dit_config["max_img_h"] = 240
dit_config["max_img_w"] = 240
dit_config["max_frames"] = 128
concat_padding_mask = True
dit_config["in_channels"] = (state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[1] // 4) - int(concat_padding_mask)
dit_config["out_channels"] = 16
dit_config["patch_spatial"] = 2
dit_config["patch_temporal"] = 1
dit_config["model_channels"] = state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[0]
dit_config["concat_padding_mask"] = concat_padding_mask
dit_config["crossattn_emb_channels"] = 1024
dit_config["pos_emb_cls"] = "rope3d"
dit_config["pos_emb_learnable"] = True
dit_config["pos_emb_interpolation"] = "crop"
dit_config["min_fps"] = 1
dit_config["max_fps"] = 30
dit_config["use_adaln_lora"] = True
dit_config["adaln_lora_dim"] = 256
if dit_config["model_channels"] == 2048:
dit_config["num_blocks"] = 28
dit_config["num_heads"] = 16
elif dit_config["model_channels"] == 5120:
dit_config["num_blocks"] = 36
dit_config["num_heads"] = 40
if dit_config["in_channels"] == 16:
dit_config["extra_per_block_abs_pos_emb"] = False
dit_config["rope_h_extrapolation_ratio"] = 4.0
dit_config["rope_w_extrapolation_ratio"] = 4.0
dit_config["rope_t_extrapolation_ratio"] = 1.0
elif dit_config["in_channels"] == 17: # img to video
if dit_config["model_channels"] == 2048:
dit_config["extra_per_block_abs_pos_emb"] = False
dit_config["rope_h_extrapolation_ratio"] = 3.0
dit_config["rope_w_extrapolation_ratio"] = 3.0
dit_config["rope_t_extrapolation_ratio"] = 1.0
elif dit_config["model_channels"] == 5120:
dit_config["rope_h_extrapolation_ratio"] = 2.0
dit_config["rope_w_extrapolation_ratio"] = 2.0
dit_config["rope_t_extrapolation_ratio"] = 0.8333333333333334
dit_config["extra_h_extrapolation_ratio"] = 1.0
dit_config["extra_w_extrapolation_ratio"] = 1.0
dit_config["extra_t_extrapolation_ratio"] = 1.0
dit_config["rope_enable_fps_modulation"] = False
return dit_config
if '{}time_caption_embed.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: # Omnigen2
dit_config = {}
dit_config["image_model"] = "omnigen2"
dit_config["axes_dim_rope"] = [40, 40, 40]
dit_config["axes_lens"] = [1024, 1664, 1664]
dit_config["ffn_dim_multiplier"] = None
dit_config["hidden_size"] = 2520
dit_config["in_channels"] = 16
dit_config["multiple_of"] = 256
dit_config["norm_eps"] = 1e-05
dit_config["num_attention_heads"] = 21
dit_config["num_kv_heads"] = 7
dit_config["num_layers"] = 32
dit_config["num_refiner_layers"] = 2
dit_config["out_channels"] = None
dit_config["patch_size"] = 2
dit_config["text_feat_dim"] = 2048
dit_config["timestep_scale"] = 1000.0
return dit_config
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
return None
@@ -620,6 +692,9 @@ def convert_config(unet_config):
def unet_config_from_diffusers_unet(state_dict, dtype=None):
if "conv_in.weight" not in state_dict:
return None
match = {}
transformer_depth = []

View File

@@ -295,14 +295,24 @@ except:
pass
SUPPORT_FP8_OPS = args.supports_fp8_compute
try:
if is_amd():
try:
rocm_version = tuple(map(int, str(torch.version.hip).split(".")[:2]))
except:
rocm_version = (6, -1)
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
logging.info("AMD arch: {}".format(arch))
logging.info("ROCm version: {}".format(rocm_version))
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
if torch_version_numeric[0] >= 2 and torch_version_numeric[1] >= 7: # works on 2.6 but doesn't actually seem to improve much
if any((a in arch) for a in ["gfx1100", "gfx1101"]): # TODO: more arches
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx1201 and gfx950
ENABLE_PYTORCH_ATTENTION = True
if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4):
if any((a in arch) for a in ["gfx1201", "gfx942", "gfx950"]): # TODO: more arches
SUPPORT_FP8_OPS = True
except:
pass
@@ -323,7 +333,7 @@ except:
pass
try:
if torch_version_numeric[0] == 2 and torch_version_numeric[1] >= 5:
if torch_version_numeric >= (2, 5):
torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)
except:
logging.warning("Warning, could not set allow_fp16_bf16_reduction_math_sdp")
@@ -695,7 +705,7 @@ def unet_inital_load_device(parameters, dtype):
return torch_dev
cpu_dev = torch.device("cpu")
if DISABLE_SMART_MEMORY:
if DISABLE_SMART_MEMORY or vram_state == VRAMState.NO_VRAM:
return cpu_dev
model_size = dtype_size(dtype) * parameters
@@ -1042,7 +1052,7 @@ def pytorch_attention_flash_attention():
global ENABLE_PYTORCH_ATTENTION
if ENABLE_PYTORCH_ATTENTION:
#TODO: more reliable way of checking for flash attention?
if is_nvidia(): #pytorch flash attention only works on Nvidia
if is_nvidia():
return True
if is_intel_xpu():
return True
@@ -1058,7 +1068,7 @@ def force_upcast_attention_dtype():
upcast = args.force_upcast_attention
macos_version = mac_version()
if macos_version is not None and ((14, 5) <= macos_version < (16,)): # black image bug on recent versions of macOS
if macos_version is not None and ((14, 5) <= macos_version): # black image bug on recent versions of macOS, I don't think it's ever getting fixed
upcast = True
if upcast:
@@ -1257,6 +1267,9 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
return False
def supports_fp8_compute(device=None):
if SUPPORT_FP8_OPS:
return True
if not is_nvidia():
return False
@@ -1268,15 +1281,22 @@ def supports_fp8_compute(device=None):
if props.minor < 9:
return False
if torch_version_numeric[0] < 2 or (torch_version_numeric[0] == 2 and torch_version_numeric[1] < 3):
if torch_version_numeric < (2, 3):
return False
if WINDOWS:
if (torch_version_numeric[0] == 2 and torch_version_numeric[1] < 4):
if torch_version_numeric < (2, 4):
return False
return True
def extended_fp16_support():
# TODO: check why some models work with fp16 on newer torch versions but not on older
if torch_version_numeric < (2, 7):
return False
return True
def soft_empty_cache(force=False):
global cpu_state
if cpu_state == CPUState.MPS:

View File

@@ -17,23 +17,26 @@
"""
from __future__ import annotations
from typing import Optional, Callable
import torch
import collections
import copy
import inspect
import logging
import uuid
import collections
import math
import uuid
from typing import Callable, Optional
import torch
import comfy.utils
import comfy.float
import comfy.model_management
import comfy.lora
import comfy.hooks
import comfy.lora
import comfy.model_management
import comfy.patcher_extension
from comfy.patcher_extension import CallbacksMP, WrappersMP, PatcherInjection
import comfy.utils
from comfy.comfy_types import UnetWrapperFunction
from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
def string_to_seed(data):
crc = 0xFFFFFFFF

View File

@@ -77,6 +77,25 @@ class IMG_TO_IMG(X0):
def calculate_input(self, sigma, noise):
return noise
class COSMOS_RFLOW:
def calculate_input(self, sigma, noise):
sigma = (sigma / (sigma + 1))
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
return noise * (1.0 - sigma)
def calculate_denoised(self, sigma, model_output, model_input):
sigma = (sigma / (sigma + 1))
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
return model_input * (1.0 - sigma) - model_output * sigma
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
noise = noise * sigma
noise += latent_image
return noise
def inverse_noise_scaling(self, sigma, latent):
return latent
class ModelSamplingDiscrete(torch.nn.Module):
def __init__(self, model_config=None, zsnr=None):
@@ -350,3 +369,15 @@ class ModelSamplingFlux(torch.nn.Module):
if percent >= 1.0:
return 0.0
return flux_time_shift(self.shift, 1.0, 1.0 - percent)
class ModelSamplingCosmosRFlow(ModelSamplingContinuousEDM):
def timestep(self, sigma):
return sigma / (sigma + 1)
def sigma(self, timestep):
sigma_max = self.sigma_max
if timestep >= (sigma_max / (sigma_max + 1)):
return sigma_max
return timestep / (1 - timestep)

View File

@@ -1,5 +1,7 @@
from __future__ import annotations
import uuid
import math
import collections
import comfy.model_management
import comfy.conds
import comfy.utils
@@ -104,6 +106,21 @@ def cleanup_additional_models(models):
if hasattr(m, 'cleanup'):
m.cleanup()
def estimate_memory(model, noise_shape, conds):
cond_shapes = collections.defaultdict(list)
cond_shapes_min = {}
for _, cs in conds.items():
for cond in cs:
for k, v in model.model.extra_conds_shapes(**cond).items():
cond_shapes[k].append(v)
if cond_shapes_min.get(k, None) is None:
cond_shapes_min[k] = [v]
elif math.prod(v) > math.prod(cond_shapes_min[k][0]):
cond_shapes_min[k] = [v]
memory_required = model.model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:]), cond_shapes=cond_shapes)
minimum_memory_required = model.model.memory_required([noise_shape[0]] + list(noise_shape[1:]), cond_shapes=cond_shapes_min)
return memory_required, minimum_memory_required
def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
executor = comfy.patcher_extension.WrapperExecutor.new_executor(
@@ -117,9 +134,8 @@ def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=Non
models, inference_memory = get_additional_models(conds, model.model_dtype())
models += get_additional_models_from_model_options(model_options)
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
memory_required = model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory
minimum_memory_required = model.memory_required([noise_shape[0]] + list(noise_shape[1:])) + inference_memory
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required)
memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory)
real_model = model.model
return real_model, conds, models

View File

@@ -256,7 +256,13 @@ def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Te
for i in range(1, len(to_batch_temp) + 1):
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
if model.memory_required(input_shape) * 1.5 < free_memory:
cond_shapes = collections.defaultdict(list)
for tt in batch_amount:
cond = {k: v.size() for k, v in to_run[tt][0].conditioning.items()}
for k, v in to_run[tt][0].conditioning.items():
cond_shapes[k].append(v.size())
if model.memory_required(input_shape, cond_shapes=cond_shapes) * 1.5 < free_memory:
to_batch = batch_amount
break

View File

@@ -44,6 +44,7 @@ import comfy.text_encoders.lumina2
import comfy.text_encoders.wan
import comfy.text_encoders.hidream
import comfy.text_encoders.ace
import comfy.text_encoders.omnigen2
import comfy.model_patcher
import comfy.lora
@@ -754,6 +755,7 @@ class CLIPType(Enum):
HIDREAM = 14
CHROMA = 15
ACE = 16
OMNIGEN2 = 17
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
@@ -773,6 +775,7 @@ class TEModel(Enum):
LLAMA3_8 = 7
T5_XXL_OLD = 8
GEMMA_2_2B = 9
QWEN25_3B = 10
def detect_te_model(sd):
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
@@ -793,6 +796,8 @@ def detect_te_model(sd):
return TEModel.T5_BASE
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
return TEModel.GEMMA_2_2B
if 'model.layers.0.self_attn.k_proj.bias' in sd:
return TEModel.QWEN25_3B
if "model.layers.0.post_attention_layernorm.weight" in sd:
return TEModel.LLAMA3_8
return None
@@ -894,6 +899,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data),
clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None, t5xxl_scaled_fp8=None)
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
elif te_model == TEModel.QWEN25_3B:
clip_target.clip = comfy.text_encoders.omnigen2.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.omnigen2.Omnigen2Tokenizer
else:
# clip_l
if clip_type == CLIPType.SD3:
@@ -1081,7 +1089,28 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
return (model_patcher, clip, vae, clipvision)
def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffusers or regular format
def load_diffusion_model_state_dict(sd, model_options={}):
"""
Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats.
Args:
sd (dict): State dictionary containing model weights and configuration
model_options (dict, optional): Additional options for model loading. Supports:
- dtype: Override model data type
- custom_operations: Custom model operations
- fp8_optimizations: Enable FP8 optimizations
Returns:
ModelPatcher: A wrapped model instance that handles device management and weight loading.
Returns None if the model configuration cannot be detected.
The function:
1. Detects and handles different model formats (regular, diffusers, mmdit)
2. Configures model dtype based on parameters and device capabilities
3. Handles weight conversion and device placement
4. Manages model optimization settings
5. Loads weights and returns a device-managed model instance
"""
dtype = model_options.get("dtype", None)
#Allow loading unets from checkpoint files
@@ -1139,7 +1168,7 @@ def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffuse
model.load_model_weights(new_sd, "")
left_over = sd.keys()
if len(left_over) > 0:
logging.info("left over keys in unet: {}".format(left_over))
logging.info("left over keys in diffusion model: {}".format(left_over))
return comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)
@@ -1147,7 +1176,7 @@ def load_diffusion_model(unet_path, model_options={}):
sd = comfy.utils.load_torch_file(unet_path)
model = load_diffusion_model_state_dict(sd, model_options=model_options)
if model is None:
logging.error("ERROR UNSUPPORTED UNET {}".format(unet_path))
logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path))
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
return model

View File

@@ -462,7 +462,7 @@ class SDTokenizer:
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args)
self.max_length = tokenizer_data.get("{}_max_length".format(embedding_key), max_length)
self.min_length = min_length
self.min_length = tokenizer_data.get("{}_min_length".format(embedding_key), min_length)
self.end_token = None
self.min_padding = min_padding
@@ -482,7 +482,8 @@ class SDTokenizer:
if end_token is not None:
self.end_token = end_token
else:
self.end_token = empty[0]
if has_end_token:
self.end_token = empty[0]
if pad_token is not None:
self.pad_token = pad_token

View File

@@ -18,6 +18,7 @@ import comfy.text_encoders.cosmos
import comfy.text_encoders.lumina2
import comfy.text_encoders.wan
import comfy.text_encoders.ace
import comfy.text_encoders.omnigen2
from . import supported_models_base
from . import latent_formats
@@ -908,6 +909,48 @@ class CosmosI2V(CosmosT2V):
out = model_base.CosmosVideo(self, image_to_video=True, device=device)
return out
class CosmosT2IPredict2(supported_models_base.BASE):
unet_config = {
"image_model": "cosmos_predict2",
"in_channels": 16,
}
sampling_settings = {
"sigma_data": 1.0,
"sigma_max": 80.0,
"sigma_min": 0.002,
}
unet_extra_config = {}
latent_format = latent_formats.Wan21
memory_usage_factor = 1.0
supported_inference_dtypes = [torch.bfloat16, torch.float32]
def __init__(self, unet_config):
super().__init__(unet_config)
self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.9
def get_model(self, state_dict, prefix="", device=None):
out = model_base.CosmosPredict2(self, device=device)
return out
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.cosmos.CosmosT5Tokenizer, comfy.text_encoders.cosmos.te(**t5_detect))
class CosmosI2VPredict2(CosmosT2IPredict2):
unet_config = {
"image_model": "cosmos_predict2",
"in_channels": 17,
}
def get_model(self, state_dict, prefix="", device=None):
out = model_base.CosmosPredict2(self, image_to_video=True, device=device)
return out
class Lumina2(supported_models_base.BASE):
unet_config = {
"image_model": "lumina2",
@@ -1139,6 +1182,41 @@ class ACEStep(supported_models_base.BASE):
def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(comfy.text_encoders.ace.AceT5Tokenizer, comfy.text_encoders.ace.AceT5Model)
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep]
class Omnigen2(supported_models_base.BASE):
unet_config = {
"image_model": "omnigen2",
}
sampling_settings = {
"multiplier": 1.0,
"shift": 2.6,
}
memory_usage_factor = 1.65 #TODO
unet_extra_config = {}
latent_format = latent_formats.Flux
supported_inference_dtypes = [torch.bfloat16, torch.float32]
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def __init__(self, unet_config):
super().__init__(unet_config)
if comfy.model_management.extended_fp16_support():
self.supported_inference_dtypes = [torch.float16] + self.supported_inference_dtypes
def get_model(self, state_dict, prefix="", device=None):
out = model_base.Omnigen2(self, device=device)
return out
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_3b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.omnigen2.LuminaTokenizer, comfy.text_encoders.omnigen2.te(**hunyuan_detect))
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2]
models += [SVD_img2vid]

View File

@@ -24,6 +24,24 @@ class Llama2Config:
head_dim = 128
rms_norm_add = False
mlp_activation = "silu"
qkv_bias = False
@dataclass
class Qwen25_3BConfig:
vocab_size: int = 151936
hidden_size: int = 2048
intermediate_size: int = 11008
num_hidden_layers: int = 36
num_attention_heads: int = 16
num_key_value_heads: int = 2
max_position_embeddings: int = 128000
rms_norm_eps: float = 1e-6
rope_theta: float = 1000000.0
transformer_type: str = "llama"
head_dim = 128
rms_norm_add = False
mlp_activation = "silu"
qkv_bias = True
@dataclass
class Gemma2_2B_Config:
@@ -40,6 +58,7 @@ class Gemma2_2B_Config:
head_dim = 256
rms_norm_add = True
mlp_activation = "gelu_pytorch_tanh"
qkv_bias = False
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
@@ -98,9 +117,9 @@ class Attention(nn.Module):
self.inner_size = self.num_heads * self.head_dim
ops = ops or nn
self.q_proj = ops.Linear(config.hidden_size, self.inner_size, bias=False, device=device, dtype=dtype)
self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False, device=device, dtype=dtype)
self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False, device=device, dtype=dtype)
self.q_proj = ops.Linear(config.hidden_size, self.inner_size, bias=config.qkv_bias, device=device, dtype=dtype)
self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
self.o_proj = ops.Linear(self.inner_size, config.hidden_size, bias=False, device=device, dtype=dtype)
def forward(
@@ -320,6 +339,14 @@ class Llama2(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Qwen25_3B(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Qwen25_3BConfig(**config_dict)
self.num_layers = config.num_hidden_layers
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Gemma2_2B(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):

View File

@@ -1,25 +0,0 @@
{
"_name_or_path": "openai/clip-vit-large-patch14",
"architectures": [
"CLIPTextModel"
],
"attention_dropout": 0.0,
"bos_token_id": 0,
"dropout": 0.0,
"eos_token_id": 49407,
"hidden_act": "quick_gelu",
"hidden_size": 768,
"initializer_factor": 1.0,
"initializer_range": 0.02,
"intermediate_size": 3072,
"layer_norm_eps": 1e-05,
"max_position_embeddings": 248,
"model_type": "clip_text_model",
"num_attention_heads": 12,
"num_hidden_layers": 12,
"pad_token_id": 1,
"projection_dim": 768,
"torch_dtype": "float32",
"transformers_version": "4.24.0",
"vocab_size": 49408
}

View File

@@ -0,0 +1,44 @@
from transformers import Qwen2Tokenizer
from comfy import sd1_clip
import comfy.text_encoders.llama
import os
class Qwen25_3BTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='qwen25_3b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)
class Omnigen2Tokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen25_3b", tokenizer=Qwen25_3BTokenizer)
self.llama_template = '<|im_start|>system\nYou are a helpful assistant that generates high-quality images based on user instructions.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n'
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None,**kwargs):
if llama_template is None:
llama_text = self.llama_template.format(text)
else:
llama_text = llama_template.format(text)
return super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, **kwargs)
class Qwen25_3BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen25_3B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class Omnigen2Model(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, name="qwen25_3b", clip_model=Qwen25_3BModel, model_options=model_options)
def te(dtype_llama=None, llama_scaled_fp8=None):
class Omnigen2TEModel_(Omnigen2Model):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["scaled_fp8"] = llama_scaled_fp8
if dtype_llama is not None:
dtype = dtype_llama
super().__init__(device=device, dtype=dtype, model_options=model_options)
return Omnigen2TEModel_

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,241 @@
{
"add_bos_token": false,
"add_prefix_space": false,
"added_tokens_decoder": {
"151643": {
"content": "<|endoftext|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151644": {
"content": "<|im_start|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151645": {
"content": "<|im_end|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151646": {
"content": "<|object_ref_start|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151647": {
"content": "<|object_ref_end|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151648": {
"content": "<|box_start|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151649": {
"content": "<|box_end|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151650": {
"content": "<|quad_start|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151651": {
"content": "<|quad_end|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151652": {
"content": "<|vision_start|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151653": {
"content": "<|vision_end|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151654": {
"content": "<|vision_pad|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151655": {
"content": "<|image_pad|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151656": {
"content": "<|video_pad|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151657": {
"content": "<tool_call>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151658": {
"content": "</tool_call>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151659": {
"content": "<|fim_prefix|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151660": {
"content": "<|fim_middle|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151661": {
"content": "<|fim_suffix|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151662": {
"content": "<|fim_pad|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151663": {
"content": "<|repo_name|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151664": {
"content": "<|file_sep|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151665": {
"content": "<|img|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151666": {
"content": "<|endofimg|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151667": {
"content": "<|meta|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151668": {
"content": "<|endofmeta|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
}
},
"additional_special_tokens": [
"<|im_start|>",
"<|im_end|>",
"<|object_ref_start|>",
"<|object_ref_end|>",
"<|box_start|>",
"<|box_end|>",
"<|quad_start|>",
"<|quad_end|>",
"<|vision_start|>",
"<|vision_end|>",
"<|vision_pad|>",
"<|image_pad|>",
"<|video_pad|>"
],
"bos_token": null,
"chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are a helpful assistant.' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n",
"clean_up_tokenization_spaces": false,
"eos_token": "<|im_end|>",
"errors": "replace",
"extra_special_tokens": {},
"model_max_length": 131072,
"pad_token": "<|endoftext|>",
"processor_class": "Qwen2_5_VLProcessor",
"split_special_tokens": false,
"tokenizer_class": "Qwen2Tokenizer",
"unk_token": null
}

File diff suppressed because one or more lines are too long

View File

@@ -1,4 +1,4 @@
from .base import WeightAdapterBase
from .base import WeightAdapterBase, WeightAdapterTrainBase
from .lora import LoRAAdapter
from .loha import LoHaAdapter
from .lokr import LoKrAdapter
@@ -15,3 +15,9 @@ adapters: list[type[WeightAdapterBase]] = [
OFTAdapter,
BOFTAdapter,
]
__all__ = [
"WeightAdapterBase",
"WeightAdapterTrainBase",
"adapters"
] + [a.__name__ for a in adapters]

View File

@@ -12,12 +12,20 @@ class WeightAdapterBase:
weights: list[torch.Tensor]
@classmethod
def load(cls, x: str, lora: dict[str, torch.Tensor]) -> Optional["WeightAdapterBase"]:
def load(cls, x: str, lora: dict[str, torch.Tensor], alpha: float, dora_scale: torch.Tensor) -> Optional["WeightAdapterBase"]:
raise NotImplementedError
def to_train(self) -> "WeightAdapterTrainBase":
raise NotImplementedError
@classmethod
def create_train(cls, weight, *args) -> "WeightAdapterTrainBase":
"""
weight: The original weight tensor to be modified.
*args: Additional arguments for configuration, such as rank, alpha etc.
"""
raise NotImplementedError
def calculate_weight(
self,
weight,
@@ -33,10 +41,22 @@ class WeightAdapterBase:
class WeightAdapterTrainBase(nn.Module):
# We follow the scheme of PR #7032
def __init__(self):
super().__init__()
# [TODO] Collaborate with LoRA training PR #7032
def __call__(self, w):
"""
w: The original weight tensor to be modified.
"""
raise NotImplementedError
def passive_memory_usage(self):
raise NotImplementedError("passive_memory_usage is not implemented")
def move_to(self, device):
self.to(device)
return self.passive_memory_usage()
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function):
@@ -102,3 +122,14 @@ def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Ten
padded_tensor[new_slices] = tensor[orig_slices]
return padded_tensor
def tucker_weight_from_conv(up, down, mid):
up = up.reshape(up.size(0), up.size(1))
down = down.reshape(down.size(0), down.size(1))
return torch.einsum("m n ..., i m, n j -> i j ...", mid, up, down)
def tucker_weight(wa, wb, t):
temp = torch.einsum("i j ..., j r -> i r ...", t, wb)
return torch.einsum("i j ..., i r -> r j ...", temp, wa)

View File

@@ -3,7 +3,56 @@ from typing import Optional
import torch
import comfy.model_management
from .base import WeightAdapterBase, weight_decompose, pad_tensor_to_shape
from .base import (
WeightAdapterBase,
WeightAdapterTrainBase,
weight_decompose,
pad_tensor_to_shape,
tucker_weight_from_conv,
)
class LoraDiff(WeightAdapterTrainBase):
def __init__(self, weights):
super().__init__()
mat1, mat2, alpha, mid, dora_scale, reshape = weights
out_dim, rank = mat1.shape[0], mat1.shape[1]
rank, in_dim = mat2.shape[0], mat2.shape[1]
if mid is not None:
convdim = mid.ndim - 2
layer = (
torch.nn.Conv1d,
torch.nn.Conv2d,
torch.nn.Conv3d
)[convdim]
else:
layer = torch.nn.Linear
self.lora_up = layer(rank, out_dim, bias=False)
self.lora_down = layer(in_dim, rank, bias=False)
self.lora_up.weight.data.copy_(mat1)
self.lora_down.weight.data.copy_(mat2)
if mid is not None:
self.lora_mid = layer(mid, rank, bias=False)
self.lora_mid.weight.data.copy_(mid)
else:
self.lora_mid = None
self.rank = rank
self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False)
def __call__(self, w):
org_dtype = w.dtype
if self.lora_mid is None:
diff = self.lora_up.weight @ self.lora_down.weight
else:
diff = tucker_weight_from_conv(
self.lora_up.weight, self.lora_down.weight, self.lora_mid.weight
)
scale = self.alpha / self.rank
weight = w + scale * diff.reshape(w.shape)
return weight.to(org_dtype)
def passive_memory_usage(self):
return sum(param.numel() * param.element_size() for param in self.parameters())
class LoRAAdapter(WeightAdapterBase):
@@ -13,6 +62,21 @@ class LoRAAdapter(WeightAdapterBase):
self.loaded_keys = loaded_keys
self.weights = weights
@classmethod
def create_train(cls, weight, rank=1, alpha=1.0):
out_dim = weight.shape[0]
in_dim = weight.shape[1:].numel()
mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype)
mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype)
torch.nn.init.kaiming_uniform_(mat1, a=5**0.5)
torch.nn.init.constant_(mat2, 0.0)
return LoraDiff(
(mat1, mat2, alpha, None, None, None)
)
def to_train(self):
return LoraDiff(self.weights)
@classmethod
def load(
cls,

View File

@@ -0,0 +1,5 @@
from .torch_compile import set_torch_compile_wrapper
__all__ = [
"set_torch_compile_wrapper",
]

View File

@@ -0,0 +1,69 @@
from __future__ import annotations
import torch
import comfy.utils
from comfy.patcher_extension import WrappersMP
from typing import TYPE_CHECKING, Callable, Optional
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
from comfy.patcher_extension import WrapperExecutor
COMPILE_KEY = "torch.compile"
TORCH_COMPILE_KWARGS = "torch_compile_kwargs"
def apply_torch_compile_factory(compiled_module_dict: dict[str, Callable]) -> Callable:
'''
Create a wrapper that will refer to the compiled_diffusion_model.
'''
def apply_torch_compile_wrapper(executor: WrapperExecutor, *args, **kwargs):
try:
orig_modules = {}
for key, value in compiled_module_dict.items():
orig_modules[key] = comfy.utils.get_attr(executor.class_obj, key)
comfy.utils.set_attr(executor.class_obj, key, value)
return executor(*args, **kwargs)
finally:
for key, value in orig_modules.items():
comfy.utils.set_attr(executor.class_obj, key, value)
return apply_torch_compile_wrapper
def set_torch_compile_wrapper(model: ModelPatcher, backend: str, options: Optional[dict[str,str]]=None,
mode: Optional[str]=None, fullgraph=False, dynamic: Optional[bool]=None,
keys: list[str]=["diffusion_model"], *args, **kwargs):
'''
Perform torch.compile that will be applied at sample time for either the whole model or specific params of the BaseModel instance.
When keys is None, it will default to using ["diffusion_model"], compiling the whole diffusion_model.
When a list of keys is provided, it will perform torch.compile on only the selected modules.
'''
# clear out any other torch.compile wrappers
model.remove_wrappers_with_key(WrappersMP.APPLY_MODEL, COMPILE_KEY)
# if no keys, default to 'diffusion_model'
if not keys:
keys = ["diffusion_model"]
# create kwargs dict that can be referenced later
compile_kwargs = {
"backend": backend,
"options": options,
"mode": mode,
"fullgraph": fullgraph,
"dynamic": dynamic,
}
# get a dict of compiled keys
compiled_modules = {}
for key in keys:
compiled_modules[key] = torch.compile(
model=model.get_model_object(key),
**compile_kwargs,
)
# add torch.compile wrapper
wrapper_func = apply_torch_compile_factory(
compiled_module_dict=compiled_modules,
)
# store wrapper to run on BaseModel's apply_model function
model.add_wrapper_with_key(WrappersMP.APPLY_MODEL, COMPILE_KEY, wrapper_func)
# keep compile kwargs for reference
model.model_options[TORCH_COMPILE_KWARGS] = compile_kwargs

View File

@@ -18,6 +18,8 @@ Follow the instructions [here](https://github.com/Comfy-Org/ComfyUI_frontend) to
python run main.py --comfy-api-base https://stagingapi.comfy.org
```
To authenticate to staging, please login and then ask one of Comfy Org team to whitelist you for access to staging.
API stubs are generated through automatic codegen tools from OpenAPI definitions. Since the Comfy Org OpenAPI definition contains many things from the Comfy Registry as well, we use redocly/cli to filter out only the paths relevant for API nodes.
### Redocly Instructions
@@ -28,7 +30,7 @@ When developing locally, use the `redocly-dev.yaml` file to generate pydantic mo
Before your API node PR merges, make sure to add the `Released` tag to the `openapi.yaml` file and test in staging.
```bash
# Download the OpenAPI file from prod server.
# Download the OpenAPI file from staging server.
curl -o openapi.yaml https://stagingapi.comfy.org/openapi
# Filter out unneeded API definitions.
@@ -39,3 +41,25 @@ redocly bundle openapi.yaml --output filtered-openapi.yaml --config comfy_api_no
datamodel-codegen --use-subclass-enum --field-constraints --strict-types bytes --input filtered-openapi.yaml --output comfy_api_nodes/apis/__init__.py --output-model-type pydantic_v2.BaseModel
```
# Merging to Master
Before merging to comfyanonymous/ComfyUI master, follow these steps:
1. Add the "Released" tag to the ComfyUI OpenAPI yaml file for each endpoint you are using in the nodes.
1. Make sure the ComfyUI API is deployed to prod with your changes.
1. Run the code generation again with `redocly.yaml` and the production OpenAPI yaml file.
```bash
# Download the OpenAPI file from prod server.
curl -o openapi.yaml https://api.comfy.org/openapi
# Filter out unneeded API definitions.
npm install -g @redocly/cli
redocly bundle openapi.yaml --output filtered-openapi.yaml --config comfy_api_nodes/redocly.yaml --remove-unused-components
# Generate the pydantic datamodels for validation.
datamodel-codegen --use-subclass-enum --field-constraints --strict-types bytes --input filtered-openapi.yaml --output comfy_api_nodes/apis/__init__.py --output-model-type pydantic_v2.BaseModel
```

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
import io
import logging
import mimetypes
from typing import Optional, Union
from comfy.utils import common_upscale
from comfy_api.input_impl import VideoFromFile
@@ -214,6 +215,7 @@ def download_url_to_image_tensor(url: str, timeout: int = None) -> torch.Tensor:
image_bytesio = download_url_to_bytesio(url, timeout)
return bytesio_to_image_tensor(image_bytesio)
def process_image_response(response: requests.Response) -> torch.Tensor:
"""Uses content from a Response object and converts it to a torch.Tensor"""
return bytesio_to_image_tensor(BytesIO(response.content))
@@ -318,11 +320,27 @@ def tensor_to_data_uri(
return f"data:{mime_type};base64,{base64_string}"
def text_filepath_to_base64_string(filepath: str) -> str:
"""Converts a text file to a base64 string."""
with open(filepath, "rb") as f:
file_content = f.read()
return base64.b64encode(file_content).decode("utf-8")
def text_filepath_to_data_uri(filepath: str) -> str:
"""Converts a text file to a data URI."""
base64_string = text_filepath_to_base64_string(filepath)
mime_type, _ = mimetypes.guess_type(filepath)
if mime_type is None:
mime_type = "application/octet-stream"
return f"data:{mime_type};base64,{base64_string}"
def upload_file_to_comfyapi(
file_bytes_io: BytesIO,
filename: str,
upload_mime_type: str,
auth_kwargs: Optional[dict[str,str]] = None,
auth_kwargs: Optional[dict[str, str]] = None,
) -> str:
"""
Uploads a single file to ComfyUI API and returns its download URL.
@@ -357,9 +375,33 @@ def upload_file_to_comfyapi(
return response.download_url
def video_to_base64_string(
video: VideoInput,
container_format: VideoContainer = None,
codec: VideoCodec = None
) -> str:
"""
Converts a video input to a base64 string.
Args:
video: The video input to convert
container_format: Optional container format to use (defaults to video.container if available)
codec: Optional codec to use (defaults to video.codec if available)
"""
video_bytes_io = io.BytesIO()
# Use provided format/codec if specified, otherwise use video's own if available
format_to_use = container_format if container_format is not None else getattr(video, 'container', VideoContainer.MP4)
codec_to_use = codec if codec is not None else getattr(video, 'codec', VideoCodec.H264)
video.save_to(video_bytes_io, format=format_to_use, codec=codec_to_use)
video_bytes_io.seek(0)
return base64.b64encode(video_bytes_io.getvalue()).decode("utf-8")
def upload_video_to_comfyapi(
video: VideoInput,
auth_kwargs: Optional[dict[str,str]] = None,
auth_kwargs: Optional[dict[str, str]] = None,
container: VideoContainer = VideoContainer.MP4,
codec: VideoCodec = VideoCodec.H264,
max_duration: Optional[int] = None,
@@ -461,7 +503,7 @@ def audio_ndarray_to_bytesio(
def upload_audio_to_comfyapi(
audio: AudioInput,
auth_kwargs: Optional[dict[str,str]] = None,
auth_kwargs: Optional[dict[str, str]] = None,
container_format: str = "mp4",
codec_name: str = "aac",
mime_type: str = "audio/mp4",
@@ -488,8 +530,25 @@ def upload_audio_to_comfyapi(
return upload_file_to_comfyapi(audio_bytes_io, filename, mime_type, auth_kwargs)
def audio_to_base64_string(
audio: AudioInput, container_format: str = "mp4", codec_name: str = "aac"
) -> str:
"""Converts an audio input to a base64 string."""
sample_rate: int = audio["sample_rate"]
waveform: torch.Tensor = audio["waveform"]
audio_data_np = audio_tensor_to_contiguous_ndarray(waveform)
audio_bytes_io = audio_ndarray_to_bytesio(
audio_data_np, sample_rate, container_format, codec_name
)
audio_bytes = audio_bytes_io.getvalue()
return base64.b64encode(audio_bytes).decode("utf-8")
def upload_images_to_comfyapi(
image: torch.Tensor, max_images=8, auth_kwargs: Optional[dict[str,str]] = None, mime_type: Optional[str] = None
image: torch.Tensor,
max_images=8,
auth_kwargs: Optional[dict[str, str]] = None,
mime_type: Optional[str] = None,
) -> list[str]:
"""
Uploads images to ComfyUI API and returns download URLs.
@@ -554,17 +613,24 @@ def upload_images_to_comfyapi(
return download_urls
def resize_mask_to_image(mask: torch.Tensor, image: torch.Tensor,
upscale_method="nearest-exact", crop="disabled",
allow_gradient=True, add_channel_dim=False):
def resize_mask_to_image(
mask: torch.Tensor,
image: torch.Tensor,
upscale_method="nearest-exact",
crop="disabled",
allow_gradient=True,
add_channel_dim=False,
):
"""
Resize mask to be the same dimensions as an image, while maintaining proper format for API calls.
"""
_, H, W, _ = image.shape
mask = mask.unsqueeze(-1)
mask = mask.movedim(-1,1)
mask = common_upscale(mask, width=W, height=H, upscale_method=upscale_method, crop=crop)
mask = mask.movedim(1,-1)
mask = mask.movedim(-1, 1)
mask = common_upscale(
mask, width=W, height=H, upscale_method=upscale_method, crop=crop
)
mask = mask.movedim(1, -1)
if not add_channel_dim:
mask = mask.squeeze(-1)
if not allow_gradient:
@@ -572,12 +638,41 @@ def resize_mask_to_image(mask: torch.Tensor, image: torch.Tensor,
return mask
def validate_string(string: str, strip_whitespace=True, field_name="prompt", min_length=None, max_length=None):
def validate_string(
string: str,
strip_whitespace=True,
field_name="prompt",
min_length=None,
max_length=None,
):
if string is None:
raise Exception(f"Field '{field_name}' cannot be empty.")
if strip_whitespace:
string = string.strip()
if min_length and len(string) < min_length:
raise Exception(f"Field '{field_name}' cannot be shorter than {min_length} characters; was {len(string)} characters long.")
raise Exception(
f"Field '{field_name}' cannot be shorter than {min_length} characters; was {len(string)} characters long."
)
if max_length and len(string) > max_length:
raise Exception(f" Field '{field_name} cannot be longer than {max_length} characters; was {len(string)} characters long.")
if not string:
raise Exception(f"Field '{field_name}' cannot be empty.")
raise Exception(
f" Field '{field_name} cannot be longer than {max_length} characters; was {len(string)} characters long."
)
def image_tensor_pair_to_batch(
image1: torch.Tensor, image2: torch.Tensor
) -> torch.Tensor:
"""
Converts a pair of image tensors to a batch tensor.
If the images are not the same size, the smaller image is resized to
match the larger image.
"""
if image1.shape[1:] != image2.shape[1:]:
image2 = common_upscale(
image2.movedim(-1, 1),
image1.shape[2],
image1.shape[1],
"bilinear",
"center",
).movedim(1, -1)
return torch.cat((image1, image2), dim=0)

File diff suppressed because it is too large Load Diff

View File

@@ -108,6 +108,24 @@ class BFLFluxProGenerateRequest(BaseModel):
# )
class BFLFluxKontextProGenerateRequest(BaseModel):
prompt: str = Field(..., description='The text prompt for what you wannt to edit.')
input_image: Optional[str] = Field(None, description='Image to edit in base64 format')
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
guidance: confloat(ge=0.1, le=99.0) = Field(..., description='Guidance strength for the image generation process')
steps: conint(ge=1, le=150) = Field(..., description='Number of steps for the image generation process')
safety_tolerance: Optional[conint(ge=0, le=2)] = Field(
2, description='Tolerance level for input and output moderation. Between 0 and 2, 0 being most strict, 6 being least strict. Defaults to 2.'
)
output_format: Optional[BFLOutputFormat] = Field(
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
)
aspect_ratio: Optional[str] = Field(None, description='Aspect ratio of the image between 21:9 and 9:21.')
prompt_upsampling: Optional[bool] = Field(
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
)
class BFLFluxProUltraGenerateRequest(BaseModel):
prompt: str = Field(..., description='The text prompt for image generation.')
prompt_upsampling: Optional[bool] = Field(

View File

@@ -139,7 +139,7 @@ class EmptyRequest(BaseModel):
class UploadRequest(BaseModel):
file_name: str = Field(..., description="Filename to upload")
content_type: str | None = Field(
content_type: Optional[str] = Field(
None,
description="Mime type of the file. For example: image/png, image/jpeg, video/mp4, etc.",
)
@@ -327,7 +327,9 @@ class ApiClient:
ApiServerError: If the API server is unreachable but internet is working
Exception: For other request failures
"""
url = urljoin(self.base_url, path)
# Use urljoin but ensure path is relative to avoid absolute path behavior
relative_path = path.lstrip('/')
url = urljoin(self.base_url, relative_path)
self.check_auth(self.auth_token, self.comfy_api_key)
# Combine default headers with any provided headers
request_headers = self.get_headers()

View File

@@ -0,0 +1,57 @@
from __future__ import annotations
from enum import Enum
from typing import Optional, List
from pydantic import BaseModel, Field
class Rodin3DGenerateRequest(BaseModel):
seed: int = Field(..., description="seed_")
tier: str = Field(..., description="Tier of generation.")
material: str = Field(..., description="The material type.")
quality: str = Field(..., description="The generation quality of the mesh.")
mesh_mode: str = Field(..., description="It controls the type of faces of generated models.")
class GenerateJobsData(BaseModel):
uuids: List[str] = Field(..., description="str LIST")
subscription_key: str = Field(..., description="subscription key")
class Rodin3DGenerateResponse(BaseModel):
message: Optional[str] = Field(None, description="Return message.")
prompt: Optional[str] = Field(None, description="Generated Prompt from image.")
submit_time: Optional[str] = Field(None, description="Submit Time")
uuid: Optional[str] = Field(None, description="Task str")
jobs: Optional[GenerateJobsData] = Field(None, description="Details of jobs")
class JobStatus(str, Enum):
"""
Status for jobs
"""
Done = "Done"
Failed = "Failed"
Generating = "Generating"
Waiting = "Waiting"
class Rodin3DCheckStatusRequest(BaseModel):
subscription_key: str = Field(..., description="subscription from generate endpoint")
class JobItem(BaseModel):
uuid: str = Field(..., description="uuid")
status: JobStatus = Field(...,description="Status Currently")
class Rodin3DCheckStatusResponse(BaseModel):
jobs: List[JobItem] = Field(..., description="Job status List")
class Rodin3DDownloadRequest(BaseModel):
task_uuid: str = Field(..., description="Task str")
class RodinResourceItem(BaseModel):
url: str = Field(..., description="Download Url")
name: str = Field(..., description="File name with ext")
class Rodin3DDownloadResponse(BaseModel):
list: List[RodinResourceItem] = Field(..., description="Source List")

View File

@@ -0,0 +1,275 @@
from __future__ import annotations
from comfy_api_nodes.apis import (
TripoModelVersion,
TripoTextureQuality,
)
from enum import Enum
from typing import Optional, List, Dict, Any, Union
from pydantic import BaseModel, Field, RootModel
class TripoStyle(str, Enum):
PERSON_TO_CARTOON = "person:person2cartoon"
ANIMAL_VENOM = "animal:venom"
OBJECT_CLAY = "object:clay"
OBJECT_STEAMPUNK = "object:steampunk"
OBJECT_CHRISTMAS = "object:christmas"
OBJECT_BARBIE = "object:barbie"
GOLD = "gold"
ANCIENT_BRONZE = "ancient_bronze"
NONE = "None"
class TripoTaskType(str, Enum):
TEXT_TO_MODEL = "text_to_model"
IMAGE_TO_MODEL = "image_to_model"
MULTIVIEW_TO_MODEL = "multiview_to_model"
TEXTURE_MODEL = "texture_model"
REFINE_MODEL = "refine_model"
ANIMATE_PRERIGCHECK = "animate_prerigcheck"
ANIMATE_RIG = "animate_rig"
ANIMATE_RETARGET = "animate_retarget"
STYLIZE_MODEL = "stylize_model"
CONVERT_MODEL = "convert_model"
class TripoTextureAlignment(str, Enum):
ORIGINAL_IMAGE = "original_image"
GEOMETRY = "geometry"
class TripoOrientation(str, Enum):
ALIGN_IMAGE = "align_image"
DEFAULT = "default"
class TripoOutFormat(str, Enum):
GLB = "glb"
FBX = "fbx"
class TripoTopology(str, Enum):
BIP = "bip"
QUAD = "quad"
class TripoSpec(str, Enum):
MIXAMO = "mixamo"
TRIPO = "tripo"
class TripoAnimation(str, Enum):
IDLE = "preset:idle"
WALK = "preset:walk"
CLIMB = "preset:climb"
JUMP = "preset:jump"
RUN = "preset:run"
SLASH = "preset:slash"
SHOOT = "preset:shoot"
HURT = "preset:hurt"
FALL = "preset:fall"
TURN = "preset:turn"
class TripoStylizeStyle(str, Enum):
LEGO = "lego"
VOXEL = "voxel"
VORONOI = "voronoi"
MINECRAFT = "minecraft"
class TripoConvertFormat(str, Enum):
GLTF = "GLTF"
USDZ = "USDZ"
FBX = "FBX"
OBJ = "OBJ"
STL = "STL"
_3MF = "3MF"
class TripoTextureFormat(str, Enum):
BMP = "BMP"
DPX = "DPX"
HDR = "HDR"
JPEG = "JPEG"
OPEN_EXR = "OPEN_EXR"
PNG = "PNG"
TARGA = "TARGA"
TIFF = "TIFF"
WEBP = "WEBP"
class TripoTaskStatus(str, Enum):
QUEUED = "queued"
RUNNING = "running"
SUCCESS = "success"
FAILED = "failed"
CANCELLED = "cancelled"
UNKNOWN = "unknown"
BANNED = "banned"
EXPIRED = "expired"
class TripoFileTokenReference(BaseModel):
type: Optional[str] = Field(None, description='The type of the reference')
file_token: str
class TripoUrlReference(BaseModel):
type: Optional[str] = Field(None, description='The type of the reference')
url: str
class TripoObjectStorage(BaseModel):
bucket: str
key: str
class TripoObjectReference(BaseModel):
type: str
object: TripoObjectStorage
class TripoFileEmptyReference(BaseModel):
pass
class TripoFileReference(RootModel):
root: Union[TripoFileTokenReference, TripoUrlReference, TripoObjectReference, TripoFileEmptyReference]
class TripoGetStsTokenRequest(BaseModel):
format: str = Field(..., description='The format of the image')
class TripoTextToModelRequest(BaseModel):
type: TripoTaskType = Field(TripoTaskType.TEXT_TO_MODEL, description='Type of task')
prompt: str = Field(..., description='The text prompt describing the model to generate', max_length=1024)
negative_prompt: Optional[str] = Field(None, description='The negative text prompt', max_length=1024)
model_version: Optional[TripoModelVersion] = TripoModelVersion.V2_5
face_limit: Optional[int] = Field(None, description='The number of faces to limit the generation to')
texture: Optional[bool] = Field(True, description='Whether to apply texture to the generated model')
pbr: Optional[bool] = Field(True, description='Whether to apply PBR to the generated model')
image_seed: Optional[int] = Field(None, description='The seed for the text')
model_seed: Optional[int] = Field(None, description='The seed for the model')
texture_seed: Optional[int] = Field(None, description='The seed for the texture')
texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
style: Optional[TripoStyle] = None
auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
quad: Optional[bool] = Field(False, description='Whether to apply quad to the generated model')
class TripoImageToModelRequest(BaseModel):
type: TripoTaskType = Field(TripoTaskType.IMAGE_TO_MODEL, description='Type of task')
file: TripoFileReference = Field(..., description='The file reference to convert to a model')
model_version: Optional[TripoModelVersion] = Field(None, description='The model version to use for generation')
face_limit: Optional[int] = Field(None, description='The number of faces to limit the generation to')
texture: Optional[bool] = Field(True, description='Whether to apply texture to the generated model')
pbr: Optional[bool] = Field(True, description='Whether to apply PBR to the generated model')
model_seed: Optional[int] = Field(None, description='The seed for the model')
texture_seed: Optional[int] = Field(None, description='The seed for the texture')
texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
texture_alignment: Optional[TripoTextureAlignment] = Field(TripoTextureAlignment.ORIGINAL_IMAGE, description='The texture alignment method')
style: Optional[TripoStyle] = Field(None, description='The style to apply to the generated model')
auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
orientation: Optional[TripoOrientation] = TripoOrientation.DEFAULT
quad: Optional[bool] = Field(False, description='Whether to apply quad to the generated model')
class TripoMultiviewToModelRequest(BaseModel):
type: TripoTaskType = TripoTaskType.MULTIVIEW_TO_MODEL
files: List[TripoFileReference] = Field(..., description='The file references to convert to a model')
model_version: Optional[TripoModelVersion] = Field(None, description='The model version to use for generation')
orthographic_projection: Optional[bool] = Field(False, description='Whether to use orthographic projection')
face_limit: Optional[int] = Field(None, description='The number of faces to limit the generation to')
texture: Optional[bool] = Field(True, description='Whether to apply texture to the generated model')
pbr: Optional[bool] = Field(True, description='Whether to apply PBR to the generated model')
model_seed: Optional[int] = Field(None, description='The seed for the model')
texture_seed: Optional[int] = Field(None, description='The seed for the texture')
texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
texture_alignment: Optional[TripoTextureAlignment] = TripoTextureAlignment.ORIGINAL_IMAGE
auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
orientation: Optional[TripoOrientation] = Field(TripoOrientation.DEFAULT, description='The orientation for the model')
quad: Optional[bool] = Field(False, description='Whether to apply quad to the generated model')
class TripoTextureModelRequest(BaseModel):
type: TripoTaskType = Field(TripoTaskType.TEXTURE_MODEL, description='Type of task')
original_model_task_id: str = Field(..., description='The task ID of the original model')
texture: Optional[bool] = Field(True, description='Whether to apply texture to the model')
pbr: Optional[bool] = Field(True, description='Whether to apply PBR to the model')
model_seed: Optional[int] = Field(None, description='The seed for the model')
texture_seed: Optional[int] = Field(None, description='The seed for the texture')
texture_quality: Optional[TripoTextureQuality] = Field(None, description='The quality of the texture')
texture_alignment: Optional[TripoTextureAlignment] = Field(TripoTextureAlignment.ORIGINAL_IMAGE, description='The texture alignment method')
class TripoRefineModelRequest(BaseModel):
type: TripoTaskType = Field(TripoTaskType.REFINE_MODEL, description='Type of task')
draft_model_task_id: str = Field(..., description='The task ID of the draft model')
class TripoAnimatePrerigcheckRequest(BaseModel):
type: TripoTaskType = Field(TripoTaskType.ANIMATE_PRERIGCHECK, description='Type of task')
original_model_task_id: str = Field(..., description='The task ID of the original model')
class TripoAnimateRigRequest(BaseModel):
type: TripoTaskType = Field(TripoTaskType.ANIMATE_RIG, description='Type of task')
original_model_task_id: str = Field(..., description='The task ID of the original model')
out_format: Optional[TripoOutFormat] = Field(TripoOutFormat.GLB, description='The output format')
spec: Optional[TripoSpec] = Field(TripoSpec.TRIPO, description='The specification for rigging')
class TripoAnimateRetargetRequest(BaseModel):
type: TripoTaskType = Field(TripoTaskType.ANIMATE_RETARGET, description='Type of task')
original_model_task_id: str = Field(..., description='The task ID of the original model')
animation: TripoAnimation = Field(..., description='The animation to apply')
out_format: Optional[TripoOutFormat] = Field(TripoOutFormat.GLB, description='The output format')
bake_animation: Optional[bool] = Field(True, description='Whether to bake the animation')
class TripoStylizeModelRequest(BaseModel):
type: TripoTaskType = Field(TripoTaskType.STYLIZE_MODEL, description='Type of task')
style: TripoStylizeStyle = Field(..., description='The style to apply to the model')
original_model_task_id: str = Field(..., description='The task ID of the original model')
block_size: Optional[int] = Field(80, description='The block size for stylization')
class TripoConvertModelRequest(BaseModel):
type: TripoTaskType = Field(TripoTaskType.CONVERT_MODEL, description='Type of task')
format: TripoConvertFormat = Field(..., description='The format to convert to')
original_model_task_id: str = Field(..., description='The task ID of the original model')
quad: Optional[bool] = Field(False, description='Whether to apply quad to the model')
force_symmetry: Optional[bool] = Field(False, description='Whether to force symmetry')
face_limit: Optional[int] = Field(10000, description='The number of faces to limit the conversion to')
flatten_bottom: Optional[bool] = Field(False, description='Whether to flatten the bottom of the model')
flatten_bottom_threshold: Optional[float] = Field(0.01, description='The threshold for flattening the bottom')
texture_size: Optional[int] = Field(4096, description='The size of the texture')
texture_format: Optional[TripoTextureFormat] = Field(TripoTextureFormat.JPEG, description='The format of the texture')
pivot_to_center_bottom: Optional[bool] = Field(False, description='Whether to pivot to the center bottom')
class TripoTaskRequest(RootModel):
root: Union[
TripoTextToModelRequest,
TripoImageToModelRequest,
TripoMultiviewToModelRequest,
TripoTextureModelRequest,
TripoRefineModelRequest,
TripoAnimatePrerigcheckRequest,
TripoAnimateRigRequest,
TripoAnimateRetargetRequest,
TripoStylizeModelRequest,
TripoConvertModelRequest
]
class TripoTaskOutput(BaseModel):
model: Optional[str] = Field(None, description='URL to the model')
base_model: Optional[str] = Field(None, description='URL to the base model')
pbr_model: Optional[str] = Field(None, description='URL to the PBR model')
rendered_image: Optional[str] = Field(None, description='URL to the rendered image')
riggable: Optional[bool] = Field(None, description='Whether the model is riggable')
class TripoTask(BaseModel):
task_id: str = Field(..., description='The task ID')
type: Optional[str] = Field(None, description='The type of task')
status: Optional[TripoTaskStatus] = Field(None, description='The status of the task')
input: Optional[Dict[str, Any]] = Field(None, description='The input parameters for the task')
output: Optional[TripoTaskOutput] = Field(None, description='The output of the task')
progress: Optional[int] = Field(None, description='The progress of the task', ge=0, le=100)
create_time: Optional[int] = Field(None, description='The creation time of the task')
running_left_time: Optional[int] = Field(None, description='The estimated time left for the task')
queue_position: Optional[int] = Field(None, description='The position in the queue')
class TripoTaskResponse(BaseModel):
code: int = Field(0, description='The response code')
data: TripoTask = Field(..., description='The task data')
class TripoGeneralResponse(BaseModel):
code: int = Field(0, description='The response code')
data: Dict[str, str] = Field(..., description='The task ID data')
class TripoBalanceData(BaseModel):
balance: float = Field(..., description='The account balance')
frozen: float = Field(..., description='The frozen balance')
class TripoBalanceResponse(BaseModel):
code: int = Field(0, description='The response code')
data: TripoBalanceData = Field(..., description='The balance data')
class TripoErrorResponse(BaseModel):
code: int = Field(..., description='The error code')
message: str = Field(..., description='The error message')
suggestion: str = Field(..., description='The suggestion for fixing the error')

View File

@@ -1,6 +1,6 @@
import io
from inspect import cleandoc
from typing import Union
from typing import Union, Optional
from comfy.comfy_types.node_typing import IO, ComfyNodeABC
from comfy_api_nodes.apis.bfl_api import (
BFLStatus,
@@ -9,6 +9,7 @@ from comfy_api_nodes.apis.bfl_api import (
BFLFluxCannyImageRequest,
BFLFluxDepthImageRequest,
BFLFluxProGenerateRequest,
BFLFluxKontextProGenerateRequest,
BFLFluxProUltraGenerateRequest,
BFLFluxProGenerateResponse,
)
@@ -269,6 +270,145 @@ class FluxProUltraImageNode(ComfyNodeABC):
return (output_image,)
class FluxKontextProImageNode(ComfyNodeABC):
"""
Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio.
"""
MINIMUM_RATIO = 1 / 4
MAXIMUM_RATIO = 4 / 1
MINIMUM_RATIO_STR = "1:4"
MAXIMUM_RATIO_STR = "4:1"
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Prompt for the image generation - specify what and how to edit.",
},
),
"aspect_ratio": (
IO.STRING,
{
"default": "16:9",
"tooltip": "Aspect ratio of image; must be between 1:4 and 4:1.",
},
),
"guidance": (
IO.FLOAT,
{
"default": 3.0,
"min": 0.1,
"max": 99.0,
"step": 0.1,
"tooltip": "Guidance strength for the image generation process"
},
),
"steps": (
IO.INT,
{
"default": 50,
"min": 1,
"max": 150,
"tooltip": "Number of steps for the image generation process"
},
),
"seed": (
IO.INT,
{
"default": 1234,
"min": 0,
"max": 0xFFFFFFFFFFFFFFFF,
"control_after_generate": True,
"tooltip": "The random seed used for creating the noise.",
},
),
"prompt_upsampling": (
IO.BOOLEAN,
{
"default": False,
"tooltip": "Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).",
},
),
},
"optional": {
"input_image": (IO.IMAGE,),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
RETURN_TYPES = (IO.IMAGE,)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "api_call"
API_NODE = True
CATEGORY = "api node/image/BFL"
BFL_PATH = "/proxy/bfl/flux-kontext-pro/generate"
def api_call(
self,
prompt: str,
aspect_ratio: str,
guidance: float,
steps: int,
input_image: Optional[torch.Tensor]=None,
seed=0,
prompt_upsampling=False,
unique_id: Union[str, None] = None,
**kwargs,
):
aspect_ratio = validate_aspect_ratio(
aspect_ratio,
minimum_ratio=self.MINIMUM_RATIO,
maximum_ratio=self.MAXIMUM_RATIO,
minimum_ratio_str=self.MINIMUM_RATIO_STR,
maximum_ratio_str=self.MAXIMUM_RATIO_STR,
)
if input_image is None:
validate_string(prompt, strip_whitespace=False)
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=self.BFL_PATH,
method=HttpMethod.POST,
request_model=BFLFluxKontextProGenerateRequest,
response_model=BFLFluxProGenerateResponse,
),
request=BFLFluxKontextProGenerateRequest(
prompt=prompt,
prompt_upsampling=prompt_upsampling,
guidance=round(guidance, 1),
steps=steps,
seed=seed,
aspect_ratio=aspect_ratio,
input_image=(
input_image
if input_image is None
else convert_image_to_base64(input_image)
)
),
auth_kwargs=kwargs,
)
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
return (output_image,)
class FluxKontextMaxImageNode(FluxKontextProImageNode):
"""
Edits images using Flux.1 Kontext [max] via api based on prompt and aspect ratio.
"""
DESCRIPTION = cleandoc(__doc__ or "")
BFL_PATH = "/proxy/bfl/flux-kontext-max/generate"
class FluxProImageNode(ComfyNodeABC):
"""
@@ -914,6 +1054,8 @@ class FluxProDepthNode(ComfyNodeABC):
NODE_CLASS_MAPPINGS = {
"FluxProUltraImageNode": FluxProUltraImageNode,
# "FluxProImageNode": FluxProImageNode,
"FluxKontextProImageNode": FluxKontextProImageNode,
"FluxKontextMaxImageNode": FluxKontextMaxImageNode,
"FluxProExpandNode": FluxProExpandNode,
"FluxProFillNode": FluxProFillNode,
"FluxProCannyNode": FluxProCannyNode,
@@ -924,6 +1066,8 @@ NODE_CLASS_MAPPINGS = {
NODE_DISPLAY_NAME_MAPPINGS = {
"FluxProUltraImageNode": "Flux 1.1 [pro] Ultra Image",
# "FluxProImageNode": "Flux 1.1 [pro] Image",
"FluxKontextProImageNode": "Flux.1 Kontext [pro] Image",
"FluxKontextMaxImageNode": "Flux.1 Kontext [max] Image",
"FluxProExpandNode": "Flux.1 Expand Image",
"FluxProFillNode": "Flux.1 Fill Image",
"FluxProCannyNode": "Flux.1 Canny Control Image",

View File

@@ -0,0 +1,446 @@
"""
API Nodes for Gemini Multimodal LLM Usage via Remote API
See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference
"""
import os
from enum import Enum
from typing import Optional, Literal
import torch
import folder_paths
from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict
from server import PromptServer
from comfy_api_nodes.apis import (
GeminiContent,
GeminiGenerateContentRequest,
GeminiGenerateContentResponse,
GeminiInlineData,
GeminiPart,
GeminiMimeType,
)
from comfy_api_nodes.apis.client import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
)
from comfy_api_nodes.apinode_utils import (
validate_string,
audio_to_base64_string,
video_to_base64_string,
tensor_to_base64_string,
)
GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini"
GEMINI_MAX_INPUT_FILE_SIZE = 20 * 1024 * 1024 # 20 MB
class GeminiModel(str, Enum):
"""
Gemini Model Names allowed by comfy-api
"""
gemini_2_5_pro_preview_05_06 = "gemini-2.5-pro-preview-05-06"
gemini_2_5_flash_preview_04_17 = "gemini-2.5-flash-preview-04-17"
def get_gemini_endpoint(
model: GeminiModel,
) -> ApiEndpoint[GeminiGenerateContentRequest, GeminiGenerateContentResponse]:
"""
Get the API endpoint for a given Gemini model.
Args:
model: The Gemini model to use, either as enum or string value.
Returns:
ApiEndpoint configured for the specific Gemini model.
"""
if isinstance(model, str):
model = GeminiModel(model)
return ApiEndpoint(
path=f"{GEMINI_BASE_ENDPOINT}/{model.value}",
method=HttpMethod.POST,
request_model=GeminiGenerateContentRequest,
response_model=GeminiGenerateContentResponse,
)
class GeminiNode(ComfyNodeABC):
"""
Node to generate text responses from a Gemini model.
This node allows users to interact with Google's Gemini AI models, providing
multimodal inputs (text, images, audio, video, files) to generate coherent
text responses. The node works with the latest Gemini models, handling the
API communication and response parsing.
"""
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Text inputs to the model, used to generate a response. You can include detailed instructions, questions, or context for the model.",
},
),
"model": (
IO.COMBO,
{
"tooltip": "The Gemini model to use for generating responses.",
"options": [model.value for model in GeminiModel],
"default": GeminiModel.gemini_2_5_pro_preview_05_06.value,
},
),
"seed": (
IO.INT,
{
"default": 42,
"min": 0,
"max": 0xFFFFFFFFFFFFFFFF,
"control_after_generate": True,
"tooltip": "When seed is fixed to a specific value, the model makes a best effort to provide the same response for repeated requests. Deterministic output isn't guaranteed. Also, changing the model or parameter settings, such as the temperature, can cause variations in the response even when you use the same seed value. By default, a random seed value is used.",
},
),
},
"optional": {
"images": (
IO.IMAGE,
{
"default": None,
"tooltip": "Optional image(s) to use as context for the model. To include multiple images, you can use the Batch Images node.",
},
),
"audio": (
IO.AUDIO,
{
"tooltip": "Optional audio to use as context for the model.",
"default": None,
},
),
"video": (
IO.VIDEO,
{
"tooltip": "Optional video to use as context for the model.",
"default": None,
},
),
"files": (
"GEMINI_INPUT_FILES",
{
"default": None,
"tooltip": "Optional file(s) to use as context for the model. Accepts inputs from the Gemini Generate Content Input Files node.",
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
DESCRIPTION = "Generate text responses with Google's Gemini AI model. You can provide multiple types of inputs (text, images, audio, video) as context for generating more relevant and meaningful responses."
RETURN_TYPES = ("STRING",)
FUNCTION = "api_call"
CATEGORY = "api node/text/Gemini"
API_NODE = True
def get_parts_from_response(
self, response: GeminiGenerateContentResponse
) -> list[GeminiPart]:
"""
Extract all parts from the Gemini API response.
Args:
response: The API response from Gemini.
Returns:
List of response parts from the first candidate.
"""
return response.candidates[0].content.parts
def get_parts_by_type(
self, response: GeminiGenerateContentResponse, part_type: Literal["text"] | str
) -> list[GeminiPart]:
"""
Filter response parts by their type.
Args:
response: The API response from Gemini.
part_type: Type of parts to extract ("text" or a MIME type).
Returns:
List of response parts matching the requested type.
"""
parts = []
for part in self.get_parts_from_response(response):
if part_type == "text" and hasattr(part, "text") and part.text:
parts.append(part)
elif (
hasattr(part, "inlineData")
and part.inlineData
and part.inlineData.mimeType == part_type
):
parts.append(part)
# Skip parts that don't match the requested type
return parts
def get_text_from_response(self, response: GeminiGenerateContentResponse) -> str:
"""
Extract and concatenate all text parts from the response.
Args:
response: The API response from Gemini.
Returns:
Combined text from all text parts in the response.
"""
parts = self.get_parts_by_type(response, "text")
return "\n".join([part.text for part in parts])
def create_video_parts(self, video_input: IO.VIDEO, **kwargs) -> list[GeminiPart]:
"""
Convert video input to Gemini API compatible parts.
Args:
video_input: Video tensor from ComfyUI.
**kwargs: Additional arguments to pass to the conversion function.
Returns:
List of GeminiPart objects containing the encoded video.
"""
from comfy_api.util import VideoContainer, VideoCodec
base_64_string = video_to_base64_string(
video_input,
container_format=VideoContainer.MP4,
codec=VideoCodec.H264
)
return [
GeminiPart(
inlineData=GeminiInlineData(
mimeType=GeminiMimeType.video_mp4,
data=base_64_string,
)
)
]
def create_audio_parts(self, audio_input: IO.AUDIO) -> list[GeminiPart]:
"""
Convert audio input to Gemini API compatible parts.
Args:
audio_input: Audio input from ComfyUI, containing waveform tensor and sample rate.
Returns:
List of GeminiPart objects containing the encoded audio.
"""
audio_parts: list[GeminiPart] = []
for batch_index in range(audio_input["waveform"].shape[0]):
# Recreate an IO.AUDIO object for the given batch dimension index
audio_at_index = {
"waveform": audio_input["waveform"][batch_index].unsqueeze(0),
"sample_rate": audio_input["sample_rate"],
}
# Convert to MP3 format for compatibility with Gemini API
audio_bytes = audio_to_base64_string(
audio_at_index,
container_format="mp3",
codec_name="libmp3lame",
)
audio_parts.append(
GeminiPart(
inlineData=GeminiInlineData(
mimeType=GeminiMimeType.audio_mp3,
data=audio_bytes,
)
)
)
return audio_parts
def create_image_parts(self, image_input: torch.Tensor) -> list[GeminiPart]:
"""
Convert image tensor input to Gemini API compatible parts.
Args:
image_input: Batch of image tensors from ComfyUI.
Returns:
List of GeminiPart objects containing the encoded images.
"""
image_parts: list[GeminiPart] = []
for image_index in range(image_input.shape[0]):
image_as_b64 = tensor_to_base64_string(
image_input[image_index].unsqueeze(0)
)
image_parts.append(
GeminiPart(
inlineData=GeminiInlineData(
mimeType=GeminiMimeType.image_png,
data=image_as_b64,
)
)
)
return image_parts
def create_text_part(self, text: str) -> GeminiPart:
"""
Create a text part for the Gemini API request.
Args:
text: The text content to include in the request.
Returns:
A GeminiPart object with the text content.
"""
return GeminiPart(text=text)
def api_call(
self,
prompt: str,
model: GeminiModel,
images: Optional[IO.IMAGE] = None,
audio: Optional[IO.AUDIO] = None,
video: Optional[IO.VIDEO] = None,
files: Optional[list[GeminiPart]] = None,
unique_id: Optional[str] = None,
**kwargs,
) -> tuple[str]:
# Validate inputs
validate_string(prompt, strip_whitespace=False)
# Create parts list with text prompt as the first part
parts: list[GeminiPart] = [self.create_text_part(prompt)]
# Add other modal parts
if images is not None:
image_parts = self.create_image_parts(images)
parts.extend(image_parts)
if audio is not None:
parts.extend(self.create_audio_parts(audio))
if video is not None:
parts.extend(self.create_video_parts(video))
if files is not None:
parts.extend(files)
# Create response
response = SynchronousOperation(
endpoint=get_gemini_endpoint(model),
request=GeminiGenerateContentRequest(
contents=[
GeminiContent(
role="user",
parts=parts,
)
]
),
auth_kwargs=kwargs,
).execute()
# Get result output
output_text = self.get_text_from_response(response)
if unique_id and output_text:
PromptServer.instance.send_progress_text(output_text, node_id=unique_id)
return (output_text or "Empty response from Gemini model...",)
class GeminiInputFiles(ComfyNodeABC):
"""
Loads and formats input files for use with the Gemini API.
This node allows users to include text (.txt) and PDF (.pdf) files as input
context for the Gemini model. Files are converted to the appropriate format
required by the API and can be chained together to include multiple files
in a single request.
"""
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
"""
For details about the supported file input types, see:
https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference
"""
input_dir = folder_paths.get_input_directory()
input_files = [
f
for f in os.scandir(input_dir)
if f.is_file()
and (f.name.endswith(".txt") or f.name.endswith(".pdf"))
and f.stat().st_size < GEMINI_MAX_INPUT_FILE_SIZE
]
input_files = sorted(input_files, key=lambda x: x.name)
input_files = [f.name for f in input_files]
return {
"required": {
"file": (
IO.COMBO,
{
"tooltip": "Input files to include as context for the model. Only accepts text (.txt) and PDF (.pdf) files for now.",
"options": input_files,
"default": input_files[0] if input_files else None,
},
),
},
"optional": {
"GEMINI_INPUT_FILES": (
"GEMINI_INPUT_FILES",
{
"tooltip": "An optional additional file(s) to batch together with the file loaded from this node. Allows chaining of input files so that a single message can include multiple input files.",
"default": None,
},
),
},
}
DESCRIPTION = "Loads and prepares input files to include as inputs for Gemini LLM nodes. The files will be read by the Gemini model when generating a response. The contents of the text file count toward the token limit. 🛈 TIP: Can be chained together with other Gemini Input File nodes."
RETURN_TYPES = ("GEMINI_INPUT_FILES",)
FUNCTION = "prepare_files"
CATEGORY = "api node/text/Gemini"
def create_file_part(self, file_path: str) -> GeminiPart:
mime_type = (
GeminiMimeType.pdf
if file_path.endswith(".pdf")
else GeminiMimeType.text_plain
)
# Use base64 string directly, not the data URI
with open(file_path, "rb") as f:
file_content = f.read()
import base64
base64_str = base64.b64encode(file_content).decode("utf-8")
return GeminiPart(
inlineData=GeminiInlineData(
mimeType=mime_type,
data=base64_str,
)
)
def prepare_files(
self, file: str, GEMINI_INPUT_FILES: list[GeminiPart] = []
) -> tuple[list[GeminiPart]]:
"""
Loads and formats input files for Gemini API.
"""
file_path = folder_paths.get_annotated_filepath(file)
input_file_content = self.create_file_part(file_path)
files = [input_file_content] + GEMINI_INPUT_FILES
return (files,)
NODE_CLASS_MAPPINGS = {
"GeminiNode": GeminiNode,
"GeminiInputFiles": GeminiInputFiles,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"GeminiNode": "Google Gemini",
"GeminiInputFiles": "Gemini Input Files",
}

View File

@@ -324,7 +324,7 @@ class IdeogramV1(ComfyNodeABC):
RETURN_TYPES = (IO.IMAGE,)
FUNCTION = "api_call"
CATEGORY = "api node/image/Ideogram/v1"
CATEGORY = "api node/image/Ideogram"
DESCRIPTION = cleandoc(__doc__ or "")
API_NODE = True
@@ -483,7 +483,7 @@ class IdeogramV2(ComfyNodeABC):
RETURN_TYPES = (IO.IMAGE,)
FUNCTION = "api_call"
CATEGORY = "api node/image/Ideogram/v2"
CATEGORY = "api node/image/Ideogram"
DESCRIPTION = cleandoc(__doc__ or "")
API_NODE = True
@@ -649,7 +649,7 @@ class IdeogramV3(ComfyNodeABC):
RETURN_TYPES = (IO.IMAGE,)
FUNCTION = "api_call"
CATEGORY = "api node/image/Ideogram/v3"
CATEGORY = "api node/image/Ideogram"
DESCRIPTION = cleandoc(__doc__ or "")
API_NODE = True

View File

@@ -1,29 +1,86 @@
import io
from typing import TypedDict, Optional
import json
import os
import time
import re
import uuid
from enum import Enum
from inspect import cleandoc
import numpy as np
import torch
from PIL import Image
from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict
from server import PromptServer
import folder_paths
from comfy_api_nodes.apis import (
OpenAIImageGenerationRequest,
OpenAIImageEditRequest,
OpenAIImageGenerationResponse,
OpenAICreateResponse,
OpenAIResponse,
CreateModelResponseProperties,
Item,
Includable,
OutputContent,
InputImageContent,
Detail,
InputTextContent,
InputMessage,
InputMessageContentList,
InputContent,
InputFileContent,
)
from comfy_api_nodes.apis.client import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
PollingOperation,
EmptyRequest,
)
from comfy_api_nodes.apinode_utils import (
downscale_image_tensor,
validate_and_cast_response,
validate_string,
tensor_to_base64_string,
text_filepath_to_data_uri,
)
from comfy_api_nodes.mapper_utils import model_field_to_node_input
RESPONSES_ENDPOINT = "/proxy/openai/v1/responses"
STARTING_POINT_ID_PATTERN = r"<starting_point_id:(.*)>"
class HistoryEntry(TypedDict):
"""Type definition for a single history entry in the chat."""
prompt: str
response: str
response_id: str
timestamp: float
class ChatHistory(TypedDict):
"""Type definition for the chat history dictionary."""
__annotations__: dict[str, list[HistoryEntry]]
class SupportedOpenAIModel(str, Enum):
o4_mini = "o4-mini"
o1 = "o1"
o3 = "o3"
o1_pro = "o1-pro"
gpt_4o = "gpt-4o"
gpt_4_1 = "gpt-4.1"
gpt_4_1_mini = "gpt-4.1-mini"
gpt_4_1_nano = "gpt-4.1-nano"
class OpenAIDalle2(ComfyNodeABC):
"""
@@ -115,7 +172,7 @@ class OpenAIDalle2(ComfyNodeABC):
n=1,
size="1024x1024",
unique_id=None,
**kwargs
**kwargs,
):
validate_string(prompt, strip_whitespace=False)
model = "dall-e-2"
@@ -262,7 +319,7 @@ class OpenAIDalle3(ComfyNodeABC):
quality="standard",
size="1024x1024",
unique_id=None,
**kwargs
**kwargs,
):
validate_string(prompt, strip_whitespace=False)
model = "dall-e-3"
@@ -400,12 +457,12 @@ class OpenAIGPTImage1(ComfyNodeABC):
n=1,
size="1024x1024",
unique_id=None,
**kwargs
**kwargs,
):
validate_string(prompt, strip_whitespace=False)
model = "gpt-image-1"
path = "/proxy/openai/images/generations"
content_type="application/json"
content_type = "application/json"
request_class = OpenAIImageGenerationRequest
img_binaries = []
mask_binary = None
@@ -414,7 +471,7 @@ class OpenAIGPTImage1(ComfyNodeABC):
if image is not None:
path = "/proxy/openai/images/edits"
request_class = OpenAIImageEditRequest
content_type ="multipart/form-data"
content_type = "multipart/form-data"
batch_size = image.shape[0]
@@ -486,17 +543,466 @@ class OpenAIGPTImage1(ComfyNodeABC):
return (img_tensor,)
# A dictionary that contains all nodes you want to export with their names
# NOTE: names should be globally unique
class OpenAITextNode(ComfyNodeABC):
"""
Base class for OpenAI text generation nodes.
"""
RETURN_TYPES = (IO.STRING,)
FUNCTION = "api_call"
CATEGORY = "api node/text/OpenAI"
API_NODE = True
class OpenAIChatNode(OpenAITextNode):
"""
Node to generate text responses from an OpenAI model.
"""
def __init__(self) -> None:
"""Initialize the chat node with a new session ID and empty history."""
self.current_session_id: str = str(uuid.uuid4())
self.history: dict[str, list[HistoryEntry]] = {}
self.previous_response_id: Optional[str] = None
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Text inputs to the model, used to generate a response.",
},
),
"persist_context": (
IO.BOOLEAN,
{
"default": True,
"tooltip": "Persist chat context between calls (multi-turn conversation)",
},
),
"model": model_field_to_node_input(
IO.COMBO,
OpenAICreateResponse,
"model",
enum_type=SupportedOpenAIModel,
),
},
"optional": {
"images": (
IO.IMAGE,
{
"default": None,
"tooltip": "Optional image(s) to use as context for the model. To include multiple images, you can use the Batch Images node.",
},
),
"files": (
"OPENAI_INPUT_FILES",
{
"default": None,
"tooltip": "Optional file(s) to use as context for the model. Accepts inputs from the OpenAI Chat Input Files node.",
},
),
"advanced_options": (
"OPENAI_CHAT_CONFIG",
{
"default": None,
"tooltip": "Optional configuration for the model. Accepts inputs from the OpenAI Chat Advanced Options node.",
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
DESCRIPTION = "Generate text responses from an OpenAI model."
def get_result_response(
self,
response_id: str,
include: Optional[list[Includable]] = None,
auth_kwargs: Optional[dict[str, str]] = None,
) -> OpenAIResponse:
"""
Retrieve a model response with the given ID from the OpenAI API.
Args:
response_id (str): The ID of the response to retrieve.
include (Optional[List[Includable]]): Additional fields to include
in the response. See the `include` parameter for Response
creation above for more information.
"""
return PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"{RESPONSES_ENDPOINT}/{response_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=OpenAIResponse,
query_params={"include": include},
),
completed_statuses=["completed"],
failed_statuses=["failed"],
status_extractor=lambda response: response.status,
auth_kwargs=auth_kwargs,
).execute()
def get_message_content_from_response(
self, response: OpenAIResponse
) -> list[OutputContent]:
"""Extract message content from the API response."""
for output in response.output:
if output.root.type == "message":
return output.root.content
raise TypeError("No output message found in response")
def get_text_from_message_content(
self, message_content: list[OutputContent]
) -> str:
"""Extract text content from message content."""
for content_item in message_content:
if content_item.root.type == "output_text":
return str(content_item.root.text)
return "No text output found in response"
def get_history_text(self, session_id: str) -> str:
"""Convert the entire history for a given session to JSON string."""
return json.dumps(self.history[session_id])
def display_history_on_node(self, session_id: str, node_id: str) -> None:
"""Display formatted chat history on the node UI."""
render_spec = {
"node_id": node_id,
"component": "ChatHistoryWidget",
"props": {
"history": self.get_history_text(session_id),
},
}
PromptServer.instance.send_sync(
"display_component",
render_spec,
)
def add_to_history(
self, session_id: str, prompt: str, output_text: str, response_id: str
) -> None:
"""Add a new entry to the chat history."""
if session_id not in self.history:
self.history[session_id] = []
self.history[session_id].append(
{
"prompt": prompt,
"response": output_text,
"response_id": response_id,
"timestamp": time.time(),
}
)
def parse_output_text_from_response(self, response: OpenAIResponse) -> str:
"""Extract text output from the API response."""
message_contents = self.get_message_content_from_response(response)
return self.get_text_from_message_content(message_contents)
def generate_new_session_id(self) -> str:
"""Generate a new unique session ID."""
return str(uuid.uuid4())
def get_session_id(self, persist_context: bool) -> str:
"""Get the current or generate a new session ID based on context persistence."""
return (
self.current_session_id
if persist_context
else self.generate_new_session_id()
)
def tensor_to_input_image_content(
self, image: torch.Tensor, detail_level: Detail = "auto"
) -> InputImageContent:
"""Convert a tensor to an input image content object."""
return InputImageContent(
detail=detail_level,
image_url=f"data:image/png;base64,{tensor_to_base64_string(image)}",
type="input_image",
)
def create_input_message_contents(
self,
prompt: str,
image: Optional[torch.Tensor] = None,
files: Optional[list[InputFileContent]] = None,
) -> InputMessageContentList:
"""Create a list of input message contents from prompt and optional image."""
content_list: list[InputContent] = [
InputTextContent(text=prompt, type="input_text"),
]
if image is not None:
for i in range(image.shape[0]):
content_list.append(
self.tensor_to_input_image_content(image[i].unsqueeze(0))
)
if files is not None:
content_list.extend(files)
return InputMessageContentList(
root=content_list,
)
def parse_response_id_from_prompt(self, prompt: str) -> Optional[str]:
"""Extract response ID from prompt if it exists."""
parsed_id = re.search(STARTING_POINT_ID_PATTERN, prompt)
return parsed_id.group(1) if parsed_id else None
def strip_response_tag_from_prompt(self, prompt: str) -> str:
"""Remove the response ID tag from the prompt."""
return re.sub(STARTING_POINT_ID_PATTERN, "", prompt.strip())
def delete_history_after_response_id(
self, new_start_id: str, session_id: str
) -> None:
"""Delete history entries after a specific response ID."""
if session_id not in self.history:
return
new_history = []
i = 0
while (
i < len(self.history[session_id])
and self.history[session_id][i]["response_id"] != new_start_id
):
new_history.append(self.history[session_id][i])
i += 1
# Since it's the new starting point (not the response being edited), we include it as well
if i < len(self.history[session_id]):
new_history.append(self.history[session_id][i])
self.history[session_id] = new_history
def api_call(
self,
prompt: str,
persist_context: bool,
model: SupportedOpenAIModel,
unique_id: Optional[str] = None,
images: Optional[torch.Tensor] = None,
files: Optional[list[InputFileContent]] = None,
advanced_options: Optional[CreateModelResponseProperties] = None,
**kwargs,
) -> tuple[str]:
# Validate inputs
validate_string(prompt, strip_whitespace=False)
session_id = self.get_session_id(persist_context)
response_id_override = self.parse_response_id_from_prompt(prompt)
if response_id_override:
is_starting_from_beginning = response_id_override == "start"
if is_starting_from_beginning:
self.history[session_id] = []
previous_response_id = None
else:
previous_response_id = response_id_override
self.delete_history_after_response_id(response_id_override, session_id)
prompt = self.strip_response_tag_from_prompt(prompt)
elif persist_context:
previous_response_id = self.previous_response_id
else:
previous_response_id = None
# Create response
create_response = SynchronousOperation(
endpoint=ApiEndpoint(
path=RESPONSES_ENDPOINT,
method=HttpMethod.POST,
request_model=OpenAICreateResponse,
response_model=OpenAIResponse,
),
request=OpenAICreateResponse(
input=[
Item(
root=InputMessage(
content=self.create_input_message_contents(
prompt, images, files
),
role="user",
)
),
],
store=True,
stream=False,
model=model,
previous_response_id=previous_response_id,
**(
advanced_options.model_dump(exclude_none=True)
if advanced_options
else {}
),
),
auth_kwargs=kwargs,
).execute()
response_id = create_response.id
# Get result output
result_response = self.get_result_response(response_id, auth_kwargs=kwargs)
output_text = self.parse_output_text_from_response(result_response)
# Update history
self.add_to_history(session_id, prompt, output_text, response_id)
self.display_history_on_node(session_id, unique_id)
self.previous_response_id = response_id
return (output_text,)
class OpenAIInputFiles(ComfyNodeABC):
"""
Loads and formats input files for OpenAI API.
"""
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
"""
For details about the supported file input types, see:
https://platform.openai.com/docs/guides/pdf-files?api-mode=responses
"""
input_dir = folder_paths.get_input_directory()
input_files = [
f
for f in os.scandir(input_dir)
if f.is_file()
and (f.name.endswith(".txt") or f.name.endswith(".pdf"))
and f.stat().st_size < 32 * 1024 * 1024
]
input_files = sorted(input_files, key=lambda x: x.name)
input_files = [f.name for f in input_files]
return {
"required": {
"file": (
IO.COMBO,
{
"tooltip": "Input files to include as context for the model. Only accepts text (.txt) and PDF (.pdf) files for now.",
"options": input_files,
"default": input_files[0] if input_files else None,
},
),
},
"optional": {
"OPENAI_INPUT_FILES": (
"OPENAI_INPUT_FILES",
{
"tooltip": "An optional additional file(s) to batch together with the file loaded from this node. Allows chaining of input files so that a single message can include multiple input files.",
"default": None,
},
),
},
}
DESCRIPTION = "Loads and prepares input files (text, pdf, etc.) to include as inputs for the OpenAI Chat Node. The files will be read by the OpenAI model when generating a response. 🛈 TIP: Can be chained together with other OpenAI Input File nodes."
RETURN_TYPES = ("OPENAI_INPUT_FILES",)
FUNCTION = "prepare_files"
CATEGORY = "api node/text/OpenAI"
def create_input_file_content(self, file_path: str) -> InputFileContent:
return InputFileContent(
file_data=text_filepath_to_data_uri(file_path),
filename=os.path.basename(file_path),
type="input_file",
)
def prepare_files(
self, file: str, OPENAI_INPUT_FILES: list[InputFileContent] = []
) -> tuple[list[InputFileContent]]:
"""
Loads and formats input files for OpenAI API.
"""
file_path = folder_paths.get_annotated_filepath(file)
input_file_content = self.create_input_file_content(file_path)
files = [input_file_content] + OPENAI_INPUT_FILES
return (files,)
class OpenAIChatConfig(ComfyNodeABC):
"""Allows setting additional configuration for the OpenAI Chat Node."""
RETURN_TYPES = ("OPENAI_CHAT_CONFIG",)
FUNCTION = "configure"
DESCRIPTION = (
"Allows specifying advanced configuration options for the OpenAI Chat Nodes."
)
CATEGORY = "api node/text/OpenAI"
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {
"truncation": (
IO.COMBO,
{
"options": ["auto", "disabled"],
"default": "auto",
"tooltip": "The truncation strategy to use for the model response. auto: If the context of this response and previous ones exceeds the model's context window size, the model will truncate the response to fit the context window by dropping input items in the middle of the conversation.disabled: If a model response will exceed the context window size for a model, the request will fail with a 400 error",
},
),
},
"optional": {
"max_output_tokens": model_field_to_node_input(
IO.INT,
OpenAICreateResponse,
"max_output_tokens",
min=16,
default=4096,
max=16384,
tooltip="An upper bound for the number of tokens that can be generated for a response, including visible output tokens",
),
"instructions": model_field_to_node_input(
IO.STRING, OpenAICreateResponse, "instructions", multiline=True
),
},
}
def configure(
self,
truncation: bool,
instructions: Optional[str] = None,
max_output_tokens: Optional[int] = None,
) -> tuple[CreateModelResponseProperties]:
"""
Configure advanced options for the OpenAI Chat Node.
Note:
While `top_p` and `temperature` are listed as properties in the
spec, they are not supported for all models (e.g., o4-mini).
They are not exposed as inputs at all to avoid having to manually
remove depending on model choice.
"""
return (
CreateModelResponseProperties(
instructions=instructions,
truncation=truncation,
max_output_tokens=max_output_tokens,
),
)
NODE_CLASS_MAPPINGS = {
"OpenAIDalle2": OpenAIDalle2,
"OpenAIDalle3": OpenAIDalle3,
"OpenAIGPTImage1": OpenAIGPTImage1,
"OpenAIChatNode": OpenAIChatNode,
"OpenAIInputFiles": OpenAIInputFiles,
"OpenAIChatConfig": OpenAIChatConfig,
}
# A dictionary that contains the friendly/humanly readable titles for the nodes
NODE_DISPLAY_NAME_MAPPINGS = {
"OpenAIDalle2": "OpenAI DALL·E 2",
"OpenAIDalle3": "OpenAI DALL·E 3",
"OpenAIGPTImage1": "OpenAI GPT Image 1",
"OpenAIChatNode": "OpenAI Chat",
"OpenAIInputFiles": "OpenAI Chat Input Files",
"OpenAIChatConfig": "OpenAI Chat Advanced Options",
}

View File

@@ -6,40 +6,42 @@ Pika API docs: https://pika-827374fb.mintlify.app/api-reference
from __future__ import annotations
import io
from typing import Optional, TypeVar
import logging
import torch
from typing import Optional, TypeVar
import numpy as np
import torch
from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeOptions
from comfy_api.input_impl import VideoFromFile
from comfy_api.input_impl.video_types import VideoCodec, VideoContainer, VideoInput
from comfy_api_nodes.apinode_utils import (
download_url_to_video_output,
tensor_to_bytesio,
)
from comfy_api_nodes.apis import (
PikaBodyGenerate22T2vGenerate22T2vPost,
PikaGenerateResponse,
PikaBodyGenerate22I2vGenerate22I2vPost,
PikaVideoResponse,
PikaBodyGenerate22C2vGenerate22PikascenesPost,
IngredientsMode,
PikaDurationEnum,
PikaResolutionEnum,
PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
PikaBodyGeneratePikaswapsGeneratePikaswapsPost,
PikaBodyGenerate22C2vGenerate22PikascenesPost,
PikaBodyGenerate22I2vGenerate22I2vPost,
PikaBodyGenerate22KeyframeGenerate22PikaframesPost,
PikaBodyGenerate22T2vGenerate22T2vPost,
PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
PikaBodyGeneratePikaswapsGeneratePikaswapsPost,
PikaDurationEnum,
Pikaffect,
PikaGenerateResponse,
PikaResolutionEnum,
PikaVideoResponse,
)
from comfy_api_nodes.apis.client import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
PollingOperation,
EmptyRequest,
)
from comfy_api_nodes.apinode_utils import (
tensor_to_bytesio,
download_url_to_video_output,
HttpMethod,
PollingOperation,
SynchronousOperation,
)
from comfy_api_nodes.mapper_utils import model_field_to_node_input
from comfy_api.input_impl.video_types import VideoInput, VideoContainer, VideoCodec
from comfy_api.input_impl import VideoFromFile
from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeOptions
R = TypeVar("R")
@@ -204,6 +206,7 @@ class PikaImageToVideoV2_2(PikaNodeBase):
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@@ -457,7 +460,7 @@ class PikAdditionsNode(PikaNodeBase):
},
}
DESCRIPTION = "Add any object or image into your video. Upload a video and specify what youd like to add to create a seamlessly integrated result."
DESCRIPTION = "Add any object or image into your video. Upload a video and specify what you'd like to add to create a seamlessly integrated result."
def api_call(
self,

View File

@@ -0,0 +1,462 @@
"""
ComfyUI X Rodin3D(Deemos) API Nodes
Rodin API docs: https://developer.hyper3d.ai/
"""
from __future__ import annotations
from inspect import cleandoc
from comfy.comfy_types.node_typing import IO
import folder_paths as comfy_paths
import requests
import os
import datetime
import shutil
import time
import io
import logging
import math
from PIL import Image
from comfy_api_nodes.apis.rodin_api import (
Rodin3DGenerateRequest,
Rodin3DGenerateResponse,
Rodin3DCheckStatusRequest,
Rodin3DCheckStatusResponse,
Rodin3DDownloadRequest,
Rodin3DDownloadResponse,
JobStatus,
)
from comfy_api_nodes.apis.client import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
PollingOperation,
)
COMMON_PARAMETERS = {
"Seed": (
IO.INT,
{
"default":0,
"min":0,
"max":65535,
"display":"number"
}
),
"Material_Type": (
IO.COMBO,
{
"options": ["PBR", "Shaded"],
"default": "PBR"
}
),
"Polygon_count": (
IO.COMBO,
{
"options": ["4K-Quad", "8K-Quad", "18K-Quad", "50K-Quad", "200K-Triangle"],
"default": "18K-Quad"
}
)
}
def create_task_error(response: Rodin3DGenerateResponse):
"""Check if the response has error"""
return hasattr(response, "error")
class Rodin3DAPI:
"""
Generate 3D Assets using Rodin API
"""
RETURN_TYPES = (IO.STRING,)
RETURN_NAMES = ("3D Model Path",)
CATEGORY = "api node/3d/Rodin"
DESCRIPTION = cleandoc(__doc__ or "")
FUNCTION = "api_call"
API_NODE = True
def tensor_to_filelike(self, tensor, max_pixels: int = 2048*2048):
"""
Converts a PyTorch tensor to a file-like object.
Args:
- tensor (torch.Tensor): A tensor representing an image of shape (H, W, C)
where C is the number of channels (3 for RGB), H is height, and W is width.
Returns:
- io.BytesIO: A file-like object containing the image data.
"""
array = tensor.cpu().numpy()
array = (array * 255).astype('uint8')
image = Image.fromarray(array, 'RGB')
original_width, original_height = image.size
original_pixels = original_width * original_height
if original_pixels > max_pixels:
scale = math.sqrt(max_pixels / original_pixels)
new_width = int(original_width * scale)
new_height = int(original_height * scale)
else:
new_width, new_height = original_width, original_height
if new_width != original_width or new_height != original_height:
image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
img_byte_arr = io.BytesIO()
image.save(img_byte_arr, format='PNG') # PNG is used for lossless compression
img_byte_arr.seek(0)
return img_byte_arr
def check_rodin_status(self, response: Rodin3DCheckStatusResponse) -> str:
has_failed = any(job.status == JobStatus.Failed for job in response.jobs)
all_done = all(job.status == JobStatus.Done for job in response.jobs)
status_list = [str(job.status) for job in response.jobs]
logging.info(f"[ Rodin3D API - CheckStatus ] Generate Status: {status_list}")
if has_failed:
logging.error(f"[ Rodin3D API - CheckStatus ] Generate Failed: {status_list}, Please try again.")
raise Exception("[ Rodin3D API ] Generate Failed, Please Try again.")
elif all_done:
return "DONE"
else:
return "Generating"
def CreateGenerateTask(self, images=None, seed=1, material="PBR", quality="medium", tier="Regular", mesh_mode="Quad", **kwargs):
if images == None:
raise Exception("Rodin 3D generate requires at least 1 image.")
if len(images) >= 5:
raise Exception("Rodin 3D generate requires up to 5 image.")
path = "/proxy/rodin/api/v2/rodin"
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=path,
method=HttpMethod.POST,
request_model=Rodin3DGenerateRequest,
response_model=Rodin3DGenerateResponse,
),
request=Rodin3DGenerateRequest(
seed=seed,
tier=tier,
material=material,
quality=quality,
mesh_mode=mesh_mode
),
files=[
(
"images",
open(image, "rb") if isinstance(image, str) else self.tensor_to_filelike(image)
)
for image in images if image is not None
],
content_type = "multipart/form-data",
auth_kwargs=kwargs,
)
response = operation.execute()
if create_task_error(response):
error_message = f"Rodin3D Create 3D generate Task Failed. Message: {response.message}, error: {response.error}"
logging.error(error_message)
raise Exception(error_message)
logging.info("[ Rodin3D API - Submit Jobs ] Submit Generate Task Success!")
subscription_key = response.jobs.subscription_key
task_uuid = response.uuid
logging.info(f"[ Rodin3D API - Submit Jobs ] UUID: {task_uuid}")
return task_uuid, subscription_key
def poll_for_task_status(self, subscription_key, **kwargs) -> Rodin3DCheckStatusResponse:
path = "/proxy/rodin/api/v2/status"
poll_operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path = path,
method=HttpMethod.POST,
request_model=Rodin3DCheckStatusRequest,
response_model=Rodin3DCheckStatusResponse,
),
request=Rodin3DCheckStatusRequest(
subscription_key = subscription_key
),
completed_statuses=["DONE"],
failed_statuses=["FAILED"],
status_extractor=self.check_rodin_status,
poll_interval=3.0,
auth_kwargs=kwargs,
)
logging.info("[ Rodin3D API - CheckStatus ] Generate Start!")
return poll_operation.execute()
def GetRodinDownloadList(self, uuid, **kwargs) -> Rodin3DDownloadResponse:
logging.info("[ Rodin3D API - Downloading ] Generate Successfully!")
path = "/proxy/rodin/api/v2/download"
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=path,
method=HttpMethod.POST,
request_model=Rodin3DDownloadRequest,
response_model=Rodin3DDownloadResponse,
),
request=Rodin3DDownloadRequest(
task_uuid=uuid
),
auth_kwargs=kwargs
)
return operation.execute()
def GetQualityAndMode(self, PolyCount):
if PolyCount == "200K-Triangle":
mesh_mode = "Raw"
quality = "medium"
else:
mesh_mode = "Quad"
if PolyCount == "4K-Quad":
quality = "extra-low"
elif PolyCount == "8K-Quad":
quality = "low"
elif PolyCount == "18K-Quad":
quality = "medium"
elif PolyCount == "50K-Quad":
quality = "high"
else:
quality = "medium"
return mesh_mode, quality
def DownLoadFiles(self, Url_List):
Save_path = os.path.join(comfy_paths.get_output_directory(), "Rodin3D", datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
os.makedirs(Save_path, exist_ok=True)
model_file_path = None
for Item in Url_List.list:
url = Item.url
file_name = Item.name
file_path = os.path.join(Save_path, file_name)
if file_path.endswith(".glb"):
model_file_path = file_path
logging.info(f"[ Rodin3D API - download_files ] Downloading file: {file_path}")
max_retries = 5
for attempt in range(max_retries):
try:
with requests.get(url, stream=True) as r:
r.raise_for_status()
with open(file_path, "wb") as f:
shutil.copyfileobj(r.raw, f)
break
except Exception as e:
logging.info(f"[ Rodin3D API - download_files ] Error downloading {file_path}:{e}")
if attempt < max_retries - 1:
logging.info("Retrying...")
time.sleep(2)
else:
logging.info(f"[ Rodin3D API - download_files ] Failed to download {file_path} after {max_retries} attempts.")
return model_file_path
class Rodin3D_Regular(Rodin3DAPI):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"Images":
(
IO.IMAGE,
{
"forceInput":True,
}
)
},
"optional": {
**COMMON_PARAMETERS
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
def api_call(
self,
Images,
Seed,
Material_Type,
Polygon_count,
**kwargs
):
tier = "Regular"
num_images = Images.shape[0]
m_images = []
for i in range(num_images):
m_images.append(Images[i])
mesh_mode, quality = self.GetQualityAndMode(Polygon_count)
task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=Material_Type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs)
self.poll_for_task_status(subscription_key, **kwargs)
Download_List = self.GetRodinDownloadList(task_uuid, **kwargs)
model = self.DownLoadFiles(Download_List)
return (model,)
class Rodin3D_Detail(Rodin3DAPI):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"Images":
(
IO.IMAGE,
{
"forceInput":True,
}
)
},
"optional": {
**COMMON_PARAMETERS
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
def api_call(
self,
Images,
Seed,
Material_Type,
Polygon_count,
**kwargs
):
tier = "Detail"
num_images = Images.shape[0]
m_images = []
for i in range(num_images):
m_images.append(Images[i])
mesh_mode, quality = self.GetQualityAndMode(Polygon_count)
task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=Material_Type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs)
self.poll_for_task_status(subscription_key, **kwargs)
Download_List = self.GetRodinDownloadList(task_uuid, **kwargs)
model = self.DownLoadFiles(Download_List)
return (model,)
class Rodin3D_Smooth(Rodin3DAPI):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"Images":
(
IO.IMAGE,
{
"forceInput":True,
}
)
},
"optional": {
**COMMON_PARAMETERS
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
def api_call(
self,
Images,
Seed,
Material_Type,
Polygon_count,
**kwargs
):
tier = "Smooth"
num_images = Images.shape[0]
m_images = []
for i in range(num_images):
m_images.append(Images[i])
mesh_mode, quality = self.GetQualityAndMode(Polygon_count)
task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=Material_Type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs)
self.poll_for_task_status(subscription_key, **kwargs)
Download_List = self.GetRodinDownloadList(task_uuid, **kwargs)
model = self.DownLoadFiles(Download_List)
return (model,)
class Rodin3D_Sketch(Rodin3DAPI):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"Images":
(
IO.IMAGE,
{
"forceInput":True,
}
)
},
"optional": {
"Seed":
(
IO.INT,
{
"default":0,
"min":0,
"max":65535,
"display":"number"
}
)
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
def api_call(
self,
Images,
Seed,
**kwargs
):
tier = "Sketch"
num_images = Images.shape[0]
m_images = []
for i in range(num_images):
m_images.append(Images[i])
material_type = "PBR"
quality = "medium"
mesh_mode = "Quad"
task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=material_type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs)
self.poll_for_task_status(subscription_key, **kwargs)
Download_List = self.GetRodinDownloadList(task_uuid, **kwargs)
model = self.DownLoadFiles(Download_List)
return (model,)
# A dictionary that contains all nodes you want to export with their names
# NOTE: names should be globally unique
NODE_CLASS_MAPPINGS = {
"Rodin3D_Regular": Rodin3D_Regular,
"Rodin3D_Detail": Rodin3D_Detail,
"Rodin3D_Smooth": Rodin3D_Smooth,
"Rodin3D_Sketch": Rodin3D_Sketch,
}
# A dictionary that contains the friendly/humanly readable titles for the nodes
NODE_DISPLAY_NAME_MAPPINGS = {
"Rodin3D_Regular": "Rodin 3D Generate - Regular Generate",
"Rodin3D_Detail": "Rodin 3D Generate - Detail Generate",
"Rodin3D_Smooth": "Rodin 3D Generate - Smooth Generate",
"Rodin3D_Sketch": "Rodin 3D Generate - Sketch Generate",
}

View File

@@ -0,0 +1,635 @@
"""Runway API Nodes
API Docs:
- https://docs.dev.runwayml.com/api/#tag/Task-management/paths/~1v1~1tasks~1%7Bid%7D/delete
User Guides:
- https://help.runwayml.com/hc/en-us/sections/30265301423635-Gen-3-Alpha
- https://help.runwayml.com/hc/en-us/articles/37327109429011-Creating-with-Gen-4-Video
- https://help.runwayml.com/hc/en-us/articles/33927968552339-Creating-with-Act-One-on-Gen-3-Alpha-and-Turbo
- https://help.runwayml.com/hc/en-us/articles/34170748696595-Creating-with-Keyframes-on-Gen-3
"""
from typing import Union, Optional, Any
from enum import Enum
import torch
from comfy_api_nodes.apis import (
RunwayImageToVideoRequest,
RunwayImageToVideoResponse,
RunwayTaskStatusResponse as TaskStatusResponse,
RunwayTaskStatusEnum as TaskStatus,
RunwayModelEnum as Model,
RunwayDurationEnum as Duration,
RunwayAspectRatioEnum as AspectRatio,
RunwayPromptImageObject,
RunwayPromptImageDetailedObject,
RunwayTextToImageRequest,
RunwayTextToImageResponse,
Model4,
ReferenceImage,
RunwayTextToImageAspectRatioEnum,
)
from comfy_api_nodes.apis.client import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
PollingOperation,
EmptyRequest,
)
from comfy_api_nodes.apinode_utils import (
upload_images_to_comfyapi,
download_url_to_video_output,
image_tensor_pair_to_batch,
validate_string,
download_url_to_image_tensor,
)
from comfy_api_nodes.mapper_utils import model_field_to_node_input
from comfy_api.input_impl import VideoFromFile
from comfy.comfy_types.node_typing import IO, ComfyNodeABC
PATH_IMAGE_TO_VIDEO = "/proxy/runway/image_to_video"
PATH_TEXT_TO_IMAGE = "/proxy/runway/text_to_image"
PATH_GET_TASK_STATUS = "/proxy/runway/tasks"
AVERAGE_DURATION_I2V_SECONDS = 64
AVERAGE_DURATION_FLF_SECONDS = 256
AVERAGE_DURATION_T2I_SECONDS = 41
class RunwayApiError(Exception):
"""Base exception for Runway API errors."""
pass
class RunwayGen4TurboAspectRatio(str, Enum):
"""Aspect ratios supported for Image to Video API when using gen4_turbo model."""
field_1280_720 = "1280:720"
field_720_1280 = "720:1280"
field_1104_832 = "1104:832"
field_832_1104 = "832:1104"
field_960_960 = "960:960"
field_1584_672 = "1584:672"
class RunwayGen3aAspectRatio(str, Enum):
"""Aspect ratios supported for Image to Video API when using gen3a_turbo model."""
field_768_1280 = "768:1280"
field_1280_768 = "1280:768"
def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]:
"""Returns the video URL from the task status response if it exists."""
if response.output and len(response.output) > 0:
return response.output[0]
return None
# TODO: replace with updated image validation utils (upstream)
def validate_input_image(image: torch.Tensor) -> bool:
"""
Validate the input image is within the size limits for the Runway API.
See: https://docs.dev.runwayml.com/assets/inputs/#common-error-reasons
"""
return image.shape[2] < 8000 and image.shape[1] < 8000
def poll_until_finished(
auth_kwargs: dict[str, str],
api_endpoint: ApiEndpoint[Any, TaskStatusResponse],
estimated_duration: Optional[int] = None,
node_id: Optional[str] = None,
) -> TaskStatusResponse:
"""Polls the Runway API endpoint until the task reaches a terminal state, then returns the response."""
return PollingOperation(
poll_endpoint=api_endpoint,
completed_statuses=[
TaskStatus.SUCCEEDED.value,
],
failed_statuses=[
TaskStatus.FAILED.value,
TaskStatus.CANCELLED.value,
],
status_extractor=lambda response: (response.status.value),
auth_kwargs=auth_kwargs,
result_url_extractor=get_video_url_from_task_status,
estimated_duration=estimated_duration,
node_id=node_id,
progress_extractor=extract_progress_from_task_status,
).execute()
def extract_progress_from_task_status(
response: TaskStatusResponse,
) -> Union[float, None]:
if hasattr(response, "progress") and response.progress is not None:
return response.progress * 100
return None
def get_image_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]:
"""Returns the image URL from the task status response if it exists."""
if response.output and len(response.output) > 0:
return response.output[0]
return None
class RunwayVideoGenNode(ComfyNodeABC):
"""Runway Video Node Base."""
RETURN_TYPES = ("VIDEO",)
FUNCTION = "api_call"
CATEGORY = "api node/video/Runway"
API_NODE = True
def validate_task_created(self, response: RunwayImageToVideoResponse) -> bool:
"""
Validate the task creation response from the Runway API matches
expected format.
"""
if not bool(response.id):
raise RunwayApiError("Invalid initial response from Runway API.")
return True
def validate_response(self, response: RunwayImageToVideoResponse) -> bool:
"""
Validate the successful task status response from the Runway API
matches expected format.
"""
if not response.output or len(response.output) == 0:
raise RunwayApiError(
"Runway task succeeded but no video data found in response."
)
return True
def get_response(
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
) -> RunwayImageToVideoResponse:
"""Poll the task status until it is finished then get the response."""
return poll_until_finished(
auth_kwargs,
ApiEndpoint(
path=f"{PATH_GET_TASK_STATUS}/{task_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=TaskStatusResponse,
),
estimated_duration=AVERAGE_DURATION_FLF_SECONDS,
node_id=node_id,
)
def generate_video(
self,
request: RunwayImageToVideoRequest,
auth_kwargs: dict[str, str],
node_id: Optional[str] = None,
) -> tuple[VideoFromFile]:
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_IMAGE_TO_VIDEO,
method=HttpMethod.POST,
request_model=RunwayImageToVideoRequest,
response_model=RunwayImageToVideoResponse,
),
request=request,
auth_kwargs=auth_kwargs,
)
initial_response = initial_operation.execute()
self.validate_task_created(initial_response)
task_id = initial_response.id
final_response = self.get_response(task_id, auth_kwargs, node_id)
self.validate_response(final_response)
video_url = get_video_url_from_task_status(final_response)
return (download_url_to_video_output(video_url),)
class RunwayImageToVideoNodeGen3a(RunwayVideoGenNode):
"""Runway Image to Video Node using Gen3a Turbo model."""
DESCRIPTION = "Generate a video from a single starting frame using Gen3a Turbo model. Before diving in, review these best practices to ensure that your input selections will set your generation up for success: https://help.runwayml.com/hc/en-us/articles/33927968552339-Creating-with-Act-One-on-Gen-3-Alpha-and-Turbo."
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"prompt": model_field_to_node_input(
IO.STRING, RunwayImageToVideoRequest, "promptText", multiline=True
),
"start_frame": (
IO.IMAGE,
{"tooltip": "Start frame to be used for the video"},
),
"duration": model_field_to_node_input(
IO.COMBO, RunwayImageToVideoRequest, "duration", enum_type=Duration
),
"ratio": model_field_to_node_input(
IO.COMBO,
RunwayImageToVideoRequest,
"ratio",
enum_type=RunwayGen3aAspectRatio,
),
"seed": model_field_to_node_input(
IO.INT,
RunwayImageToVideoRequest,
"seed",
control_after_generate=True,
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
def api_call(
self,
prompt: str,
start_frame: torch.Tensor,
duration: str,
ratio: str,
seed: int,
unique_id: Optional[str] = None,
**kwargs,
) -> tuple[VideoFromFile]:
# Validate inputs
validate_string(prompt, min_length=1)
validate_input_image(start_frame)
# Upload image
download_urls = upload_images_to_comfyapi(
start_frame,
max_images=1,
mime_type="image/png",
auth_kwargs=kwargs,
)
if len(download_urls) != 1:
raise RunwayApiError("Failed to upload one or more images to comfy api.")
return self.generate_video(
RunwayImageToVideoRequest(
promptText=prompt,
seed=seed,
model=Model("gen3a_turbo"),
duration=Duration(duration),
ratio=AspectRatio(ratio),
promptImage=RunwayPromptImageObject(
root=[
RunwayPromptImageDetailedObject(
uri=str(download_urls[0]), position="first"
)
]
),
),
auth_kwargs=kwargs,
node_id=unique_id,
)
class RunwayImageToVideoNodeGen4(RunwayVideoGenNode):
"""Runway Image to Video Node using Gen4 Turbo model."""
DESCRIPTION = "Generate a video from a single starting frame using Gen4 Turbo model. Before diving in, review these best practices to ensure that your input selections will set your generation up for success: https://help.runwayml.com/hc/en-us/articles/37327109429011-Creating-with-Gen-4-Video."
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"prompt": model_field_to_node_input(
IO.STRING, RunwayImageToVideoRequest, "promptText", multiline=True
),
"start_frame": (
IO.IMAGE,
{"tooltip": "Start frame to be used for the video"},
),
"duration": model_field_to_node_input(
IO.COMBO, RunwayImageToVideoRequest, "duration", enum_type=Duration
),
"ratio": model_field_to_node_input(
IO.COMBO,
RunwayImageToVideoRequest,
"ratio",
enum_type=RunwayGen4TurboAspectRatio,
),
"seed": model_field_to_node_input(
IO.INT,
RunwayImageToVideoRequest,
"seed",
control_after_generate=True,
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
def api_call(
self,
prompt: str,
start_frame: torch.Tensor,
duration: str,
ratio: str,
seed: int,
unique_id: Optional[str] = None,
**kwargs,
) -> tuple[VideoFromFile]:
# Validate inputs
validate_string(prompt, min_length=1)
validate_input_image(start_frame)
# Upload image
download_urls = upload_images_to_comfyapi(
start_frame,
max_images=1,
mime_type="image/png",
auth_kwargs=kwargs,
)
if len(download_urls) != 1:
raise RunwayApiError("Failed to upload one or more images to comfy api.")
return self.generate_video(
RunwayImageToVideoRequest(
promptText=prompt,
seed=seed,
model=Model("gen4_turbo"),
duration=Duration(duration),
ratio=AspectRatio(ratio),
promptImage=RunwayPromptImageObject(
root=[
RunwayPromptImageDetailedObject(
uri=str(download_urls[0]), position="first"
)
]
),
),
auth_kwargs=kwargs,
node_id=unique_id,
)
class RunwayFirstLastFrameNode(RunwayVideoGenNode):
"""Runway First-Last Frame Node."""
DESCRIPTION = "Upload first and last keyframes, draft a prompt, and generate a video. More complex transitions, such as cases where the Last frame is completely different from the First frame, may benefit from the longer 10s duration. This would give the generation more time to smoothly transition between the two inputs. Before diving in, review these best practices to ensure that your input selections will set your generation up for success: https://help.runwayml.com/hc/en-us/articles/34170748696595-Creating-with-Keyframes-on-Gen-3."
def get_response(
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
) -> RunwayImageToVideoResponse:
return poll_until_finished(
auth_kwargs,
ApiEndpoint(
path=f"{PATH_GET_TASK_STATUS}/{task_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=TaskStatusResponse,
),
estimated_duration=AVERAGE_DURATION_FLF_SECONDS,
node_id=node_id,
)
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"prompt": model_field_to_node_input(
IO.STRING, RunwayImageToVideoRequest, "promptText", multiline=True
),
"start_frame": (
IO.IMAGE,
{"tooltip": "Start frame to be used for the video"},
),
"end_frame": (
IO.IMAGE,
{
"tooltip": "End frame to be used for the video. Supported for gen3a_turbo only."
},
),
"duration": model_field_to_node_input(
IO.COMBO, RunwayImageToVideoRequest, "duration", enum_type=Duration
),
"ratio": model_field_to_node_input(
IO.COMBO,
RunwayImageToVideoRequest,
"ratio",
enum_type=RunwayGen3aAspectRatio,
),
"seed": model_field_to_node_input(
IO.INT,
RunwayImageToVideoRequest,
"seed",
control_after_generate=True,
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"unique_id": "UNIQUE_ID",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
def api_call(
self,
prompt: str,
start_frame: torch.Tensor,
end_frame: torch.Tensor,
duration: str,
ratio: str,
seed: int,
unique_id: Optional[str] = None,
**kwargs,
) -> tuple[VideoFromFile]:
# Validate inputs
validate_string(prompt, min_length=1)
validate_input_image(start_frame)
validate_input_image(end_frame)
# Upload images
stacked_input_images = image_tensor_pair_to_batch(start_frame, end_frame)
download_urls = upload_images_to_comfyapi(
stacked_input_images,
max_images=2,
mime_type="image/png",
auth_kwargs=kwargs,
)
if len(download_urls) != 2:
raise RunwayApiError("Failed to upload one or more images to comfy api.")
return self.generate_video(
RunwayImageToVideoRequest(
promptText=prompt,
seed=seed,
model=Model("gen3a_turbo"),
duration=Duration(duration),
ratio=AspectRatio(ratio),
promptImage=RunwayPromptImageObject(
root=[
RunwayPromptImageDetailedObject(
uri=str(download_urls[0]), position="first"
),
RunwayPromptImageDetailedObject(
uri=str(download_urls[1]), position="last"
),
]
),
),
auth_kwargs=kwargs,
node_id=unique_id,
)
class RunwayTextToImageNode(ComfyNodeABC):
"""Runway Text to Image Node."""
RETURN_TYPES = ("IMAGE",)
FUNCTION = "api_call"
CATEGORY = "api node/image/Runway"
API_NODE = True
DESCRIPTION = "Generate an image from a text prompt using Runway's Gen 4 model. You can also include reference images to guide the generation."
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"prompt": model_field_to_node_input(
IO.STRING, RunwayTextToImageRequest, "promptText", multiline=True
),
"ratio": model_field_to_node_input(
IO.COMBO,
RunwayTextToImageRequest,
"ratio",
enum_type=RunwayTextToImageAspectRatioEnum,
),
},
"optional": {
"reference_image": (
IO.IMAGE,
{"tooltip": "Optional reference image to guide the generation"},
)
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
def validate_task_created(self, response: RunwayTextToImageResponse) -> bool:
"""
Validate the task creation response from the Runway API matches
expected format.
"""
if not bool(response.id):
raise RunwayApiError("Invalid initial response from Runway API.")
return True
def validate_response(self, response: TaskStatusResponse) -> bool:
"""
Validate the successful task status response from the Runway API
matches expected format.
"""
if not response.output or len(response.output) == 0:
raise RunwayApiError(
"Runway task succeeded but no image data found in response."
)
return True
def get_response(
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
) -> TaskStatusResponse:
"""Poll the task status until it is finished then get the response."""
return poll_until_finished(
auth_kwargs,
ApiEndpoint(
path=f"{PATH_GET_TASK_STATUS}/{task_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=TaskStatusResponse,
),
estimated_duration=AVERAGE_DURATION_T2I_SECONDS,
node_id=node_id,
)
def api_call(
self,
prompt: str,
ratio: str,
reference_image: Optional[torch.Tensor] = None,
unique_id: Optional[str] = None,
**kwargs,
) -> tuple[torch.Tensor]:
# Validate inputs
validate_string(prompt, min_length=1)
# Prepare reference images if provided
reference_images = None
if reference_image is not None:
validate_input_image(reference_image)
download_urls = upload_images_to_comfyapi(
reference_image,
max_images=1,
mime_type="image/png",
auth_kwargs=kwargs,
)
if len(download_urls) != 1:
raise RunwayApiError("Failed to upload reference image to comfy api.")
reference_images = [ReferenceImage(uri=str(download_urls[0]))]
# Create request
request = RunwayTextToImageRequest(
promptText=prompt,
model=Model4.gen4_image,
ratio=ratio,
referenceImages=reference_images,
)
# Execute initial request
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_TEXT_TO_IMAGE,
method=HttpMethod.POST,
request_model=RunwayTextToImageRequest,
response_model=RunwayTextToImageResponse,
),
request=request,
auth_kwargs=kwargs,
)
initial_response = initial_operation.execute()
self.validate_task_created(initial_response)
task_id = initial_response.id
# Poll for completion
final_response = self.get_response(
task_id, auth_kwargs=kwargs, node_id=unique_id
)
self.validate_response(final_response)
# Download and return image
image_url = get_image_url_from_task_status(final_response)
return (download_url_to_image_tensor(image_url),)
NODE_CLASS_MAPPINGS = {
"RunwayFirstLastFrameNode": RunwayFirstLastFrameNode,
"RunwayImageToVideoNodeGen3a": RunwayImageToVideoNodeGen3a,
"RunwayImageToVideoNodeGen4": RunwayImageToVideoNodeGen4,
"RunwayTextToImageNode": RunwayTextToImageNode,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"RunwayFirstLastFrameNode": "Runway First-Last-Frame to Video",
"RunwayImageToVideoNodeGen3a": "Runway Image to Video (Gen3a Turbo)",
"RunwayImageToVideoNodeGen4": "Runway Image to Video (Gen4 Turbo)",
"RunwayTextToImageNode": "Runway Text to Image",
}

View File

@@ -0,0 +1,574 @@
import os
from folder_paths import get_output_directory
from comfy_api_nodes.mapper_utils import model_field_to_node_input
from comfy.comfy_types.node_typing import IO
from comfy_api_nodes.apis import (
TripoOrientation,
TripoModelVersion,
)
from comfy_api_nodes.apis.tripo_api import (
TripoTaskType,
TripoStyle,
TripoFileReference,
TripoFileEmptyReference,
TripoUrlReference,
TripoTaskResponse,
TripoTaskStatus,
TripoTextToModelRequest,
TripoImageToModelRequest,
TripoMultiviewToModelRequest,
TripoTextureModelRequest,
TripoRefineModelRequest,
TripoAnimateRigRequest,
TripoAnimateRetargetRequest,
TripoConvertModelRequest,
)
from comfy_api_nodes.apis.client import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
PollingOperation,
EmptyRequest,
)
from comfy_api_nodes.apinode_utils import (
upload_images_to_comfyapi,
download_url_to_bytesio,
)
def upload_image_to_tripo(image, **kwargs):
urls = upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs)
return TripoFileReference(TripoUrlReference(url=urls[0], type="jpeg"))
def get_model_url_from_response(response: TripoTaskResponse) -> str:
if response.data is not None:
for key in ["pbr_model", "model", "base_model"]:
if getattr(response.data.output, key, None) is not None:
return getattr(response.data.output, key)
raise RuntimeError(f"Failed to get model url from response: {response}")
def poll_until_finished(
kwargs: dict[str, str],
response: TripoTaskResponse,
) -> tuple[str, str]:
"""Polls the Tripo API endpoint until the task reaches a terminal state, then returns the response."""
if response.code != 0:
raise RuntimeError(f"Failed to generate mesh: {response.error}")
task_id = response.data.task_id
response_poll = PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"/proxy/tripo/v2/openapi/task/{task_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=TripoTaskResponse,
),
completed_statuses=[TripoTaskStatus.SUCCESS],
failed_statuses=[
TripoTaskStatus.FAILED,
TripoTaskStatus.CANCELLED,
TripoTaskStatus.UNKNOWN,
TripoTaskStatus.BANNED,
TripoTaskStatus.EXPIRED,
],
status_extractor=lambda x: x.data.status,
auth_kwargs=kwargs,
node_id=kwargs["unique_id"],
result_url_extractor=get_model_url_from_response,
progress_extractor=lambda x: x.data.progress,
).execute()
if response_poll.data.status == TripoTaskStatus.SUCCESS:
url = get_model_url_from_response(response_poll)
bytesio = download_url_to_bytesio(url)
# Save the downloaded model file
model_file = f"tripo_model_{task_id}.glb"
with open(os.path.join(get_output_directory(), model_file), "wb") as f:
f.write(bytesio.getvalue())
return model_file, task_id
raise RuntimeError(f"Failed to generate mesh: {response_poll}")
class TripoTextToModelNode:
"""
Generates 3D models synchronously based on a text prompt using Tripo's API.
"""
AVERAGE_DURATION = 80
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"prompt": ("STRING", {"multiline": True}),
},
"optional": {
"negative_prompt": ("STRING", {"multiline": True}),
"model_version": model_field_to_node_input(IO.COMBO, TripoTextToModelRequest, "model_version", enum_type=TripoModelVersion),
"style": model_field_to_node_input(IO.COMBO, TripoTextToModelRequest, "style", enum_type=TripoStyle, default="None"),
"texture": ("BOOLEAN", {"default": True}),
"pbr": ("BOOLEAN", {"default": True}),
"image_seed": ("INT", {"default": 42}),
"model_seed": ("INT", {"default": 42}),
"texture_seed": ("INT", {"default": 42}),
"texture_quality": (["standard", "detailed"], {"default": "standard"}),
"face_limit": ("INT", {"min": -1, "max": 500000, "default": -1}),
"quad": ("BOOLEAN", {"default": False})
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
RETURN_TYPES = ("STRING", "MODEL_TASK_ID",)
RETURN_NAMES = ("model_file", "model task_id")
FUNCTION = "generate_mesh"
CATEGORY = "api node/3d/Tripo"
API_NODE = True
OUTPUT_NODE = True
def generate_mesh(self, prompt, negative_prompt=None, model_version=None, style=None, texture=None, pbr=None, image_seed=None, model_seed=None, texture_seed=None, texture_quality=None, face_limit=None, quad=None, **kwargs):
style_enum = None if style == "None" else style
if not prompt:
raise RuntimeError("Prompt is required")
response = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/tripo/v2/openapi/task",
method=HttpMethod.POST,
request_model=TripoTextToModelRequest,
response_model=TripoTaskResponse,
),
request=TripoTextToModelRequest(
type=TripoTaskType.TEXT_TO_MODEL,
prompt=prompt,
negative_prompt=negative_prompt if negative_prompt else None,
model_version=model_version,
style=style_enum,
texture=texture,
pbr=pbr,
image_seed=image_seed,
model_seed=model_seed,
texture_seed=texture_seed,
texture_quality=texture_quality,
face_limit=face_limit,
auto_size=True,
quad=quad
),
auth_kwargs=kwargs,
).execute()
return poll_until_finished(kwargs, response)
class TripoImageToModelNode:
"""
Generates 3D models synchronously based on a single image using Tripo's API.
"""
AVERAGE_DURATION = 80
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": ("IMAGE",),
},
"optional": {
"model_version": model_field_to_node_input(IO.COMBO, TripoImageToModelRequest, "model_version", enum_type=TripoModelVersion),
"style": model_field_to_node_input(IO.COMBO, TripoTextToModelRequest, "style", enum_type=TripoStyle, default="None"),
"texture": ("BOOLEAN", {"default": True}),
"pbr": ("BOOLEAN", {"default": True}),
"model_seed": ("INT", {"default": 42}),
"orientation": model_field_to_node_input(IO.COMBO, TripoImageToModelRequest, "orientation", enum_type=TripoOrientation),
"texture_seed": ("INT", {"default": 42}),
"texture_quality": (["standard", "detailed"], {"default": "standard"}),
"texture_alignment": (["original_image", "geometry"], {"default": "original_image"}),
"face_limit": ("INT", {"min": -1, "max": 500000, "default": -1}),
"quad": ("BOOLEAN", {"default": False})
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
RETURN_TYPES = ("STRING", "MODEL_TASK_ID",)
RETURN_NAMES = ("model_file", "model task_id")
FUNCTION = "generate_mesh"
CATEGORY = "api node/3d/Tripo"
API_NODE = True
OUTPUT_NODE = True
def generate_mesh(self, image, model_version=None, style=None, texture=None, pbr=None, model_seed=None, orientation=None, texture_alignment=None, texture_seed=None, texture_quality=None, face_limit=None, quad=None, **kwargs):
style_enum = None if style == "None" else style
if image is None:
raise RuntimeError("Image is required")
tripo_file = upload_image_to_tripo(image, **kwargs)
response = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/tripo/v2/openapi/task",
method=HttpMethod.POST,
request_model=TripoImageToModelRequest,
response_model=TripoTaskResponse,
),
request=TripoImageToModelRequest(
type=TripoTaskType.IMAGE_TO_MODEL,
file=tripo_file,
model_version=model_version,
style=style_enum,
texture=texture,
pbr=pbr,
model_seed=model_seed,
orientation=orientation,
texture_alignment=texture_alignment,
texture_seed=texture_seed,
texture_quality=texture_quality,
face_limit=face_limit,
auto_size=True,
quad=quad
),
auth_kwargs=kwargs,
).execute()
return poll_until_finished(kwargs, response)
class TripoMultiviewToModelNode:
"""
Generates 3D models synchronously based on up to four images (front, left, back, right) using Tripo's API.
"""
AVERAGE_DURATION = 80
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": ("IMAGE",),
},
"optional": {
"image_left": ("IMAGE",),
"image_back": ("IMAGE",),
"image_right": ("IMAGE",),
"model_version": model_field_to_node_input(IO.COMBO, TripoMultiviewToModelRequest, "model_version", enum_type=TripoModelVersion),
"orientation": model_field_to_node_input(IO.COMBO, TripoImageToModelRequest, "orientation", enum_type=TripoOrientation),
"texture": ("BOOLEAN", {"default": True}),
"pbr": ("BOOLEAN", {"default": True}),
"model_seed": ("INT", {"default": 42}),
"texture_seed": ("INT", {"default": 42}),
"texture_quality": (["standard", "detailed"], {"default": "standard"}),
"texture_alignment": (["original_image", "geometry"], {"default": "original_image"}),
"face_limit": ("INT", {"min": -1, "max": 500000, "default": -1}),
"quad": ("BOOLEAN", {"default": False})
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
RETURN_TYPES = ("STRING", "MODEL_TASK_ID",)
RETURN_NAMES = ("model_file", "model task_id")
FUNCTION = "generate_mesh"
CATEGORY = "api node/3d/Tripo"
API_NODE = True
OUTPUT_NODE = True
def generate_mesh(self, image, image_left=None, image_back=None, image_right=None, model_version=None, orientation=None, texture=None, pbr=None, model_seed=None, texture_seed=None, texture_quality=None, texture_alignment=None, face_limit=None, quad=None, **kwargs):
if image is None:
raise RuntimeError("front image for multiview is required")
images = []
image_dict = {
"image": image,
"image_left": image_left,
"image_back": image_back,
"image_right": image_right
}
if image_left is None and image_back is None and image_right is None:
raise RuntimeError("At least one of left, back, or right image must be provided for multiview")
for image_name in ["image", "image_left", "image_back", "image_right"]:
image_ = image_dict[image_name]
if image_ is not None:
tripo_file = upload_image_to_tripo(image_, **kwargs)
images.append(tripo_file)
else:
images.append(TripoFileEmptyReference())
response = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/tripo/v2/openapi/task",
method=HttpMethod.POST,
request_model=TripoMultiviewToModelRequest,
response_model=TripoTaskResponse,
),
request=TripoMultiviewToModelRequest(
type=TripoTaskType.MULTIVIEW_TO_MODEL,
files=images,
model_version=model_version,
orientation=orientation,
texture=texture,
pbr=pbr,
model_seed=model_seed,
texture_seed=texture_seed,
texture_quality=texture_quality,
texture_alignment=texture_alignment,
face_limit=face_limit,
quad=quad,
),
auth_kwargs=kwargs,
).execute()
return poll_until_finished(kwargs, response)
class TripoTextureNode:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model_task_id": ("MODEL_TASK_ID",),
},
"optional": {
"texture": ("BOOLEAN", {"default": True}),
"pbr": ("BOOLEAN", {"default": True}),
"texture_seed": ("INT", {"default": 42}),
"texture_quality": (["standard", "detailed"], {"default": "standard"}),
"texture_alignment": (["original_image", "geometry"], {"default": "original_image"}),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
RETURN_TYPES = ("STRING", "MODEL_TASK_ID",)
RETURN_NAMES = ("model_file", "model task_id")
FUNCTION = "generate_mesh"
CATEGORY = "api node/3d/Tripo"
API_NODE = True
OUTPUT_NODE = True
AVERAGE_DURATION = 80
def generate_mesh(self, model_task_id, texture=None, pbr=None, texture_seed=None, texture_quality=None, texture_alignment=None, **kwargs):
response = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/tripo/v2/openapi/task",
method=HttpMethod.POST,
request_model=TripoTextureModelRequest,
response_model=TripoTaskResponse,
),
request=TripoTextureModelRequest(
original_model_task_id=model_task_id,
texture=texture,
pbr=pbr,
texture_seed=texture_seed,
texture_quality=texture_quality,
texture_alignment=texture_alignment
),
auth_kwargs=kwargs,
).execute()
return poll_until_finished(kwargs, response)
class TripoRefineNode:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model_task_id": ("MODEL_TASK_ID", {
"tooltip": "Must be a v1.4 Tripo model"
}),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
DESCRIPTION = "Refine a draft model created by v1.4 Tripo models only."
RETURN_TYPES = ("STRING", "MODEL_TASK_ID",)
RETURN_NAMES = ("model_file", "model task_id")
FUNCTION = "generate_mesh"
CATEGORY = "api node/3d/Tripo"
API_NODE = True
OUTPUT_NODE = True
AVERAGE_DURATION = 240
def generate_mesh(self, model_task_id, **kwargs):
response = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/tripo/v2/openapi/task",
method=HttpMethod.POST,
request_model=TripoRefineModelRequest,
response_model=TripoTaskResponse,
),
request=TripoRefineModelRequest(
draft_model_task_id=model_task_id
),
auth_kwargs=kwargs,
).execute()
return poll_until_finished(kwargs, response)
class TripoRigNode:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"original_model_task_id": ("MODEL_TASK_ID",),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
RETURN_TYPES = ("STRING", "RIG_TASK_ID")
RETURN_NAMES = ("model_file", "rig task_id")
FUNCTION = "generate_mesh"
CATEGORY = "api node/3d/Tripo"
API_NODE = True
OUTPUT_NODE = True
AVERAGE_DURATION = 180
def generate_mesh(self, original_model_task_id, **kwargs):
response = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/tripo/v2/openapi/task",
method=HttpMethod.POST,
request_model=TripoAnimateRigRequest,
response_model=TripoTaskResponse,
),
request=TripoAnimateRigRequest(
original_model_task_id=original_model_task_id,
out_format="glb",
spec="tripo"
),
auth_kwargs=kwargs,
).execute()
return poll_until_finished(kwargs, response)
class TripoRetargetNode:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"original_model_task_id": ("RIG_TASK_ID",),
"animation": ([
"preset:idle",
"preset:walk",
"preset:climb",
"preset:jump",
"preset:slash",
"preset:shoot",
"preset:hurt",
"preset:fall",
"preset:turn",
],),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
RETURN_TYPES = ("STRING", "RETARGET_TASK_ID")
RETURN_NAMES = ("model_file", "retarget task_id")
FUNCTION = "generate_mesh"
CATEGORY = "api node/3d/Tripo"
API_NODE = True
OUTPUT_NODE = True
AVERAGE_DURATION = 30
def generate_mesh(self, animation, original_model_task_id, **kwargs):
response = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/tripo/v2/openapi/task",
method=HttpMethod.POST,
request_model=TripoAnimateRetargetRequest,
response_model=TripoTaskResponse,
),
request=TripoAnimateRetargetRequest(
original_model_task_id=original_model_task_id,
animation=animation,
out_format="glb",
bake_animation=True
),
auth_kwargs=kwargs,
).execute()
return poll_until_finished(kwargs, response)
class TripoConversionNode:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"original_model_task_id": ("MODEL_TASK_ID,RIG_TASK_ID,RETARGET_TASK_ID",),
"format": (["GLTF", "USDZ", "FBX", "OBJ", "STL", "3MF"],),
},
"optional": {
"quad": ("BOOLEAN", {"default": False}),
"face_limit": ("INT", {"min": -1, "max": 500000, "default": -1}),
"texture_size": ("INT", {"min": 128, "max": 4096, "default": 4096}),
"texture_format": (["BMP", "DPX", "HDR", "JPEG", "OPEN_EXR", "PNG", "TARGA", "TIFF", "WEBP"], {"default": "JPEG"})
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@classmethod
def VALIDATE_INPUTS(cls, input_types):
# The min and max of input1 and input2 are still validated because
# we didn't take `input1` or `input2` as arguments
if input_types["original_model_task_id"] not in ("MODEL_TASK_ID", "RIG_TASK_ID", "RETARGET_TASK_ID"):
return "original_model_task_id must be MODEL_TASK_ID, RIG_TASK_ID or RETARGET_TASK_ID type"
return True
RETURN_TYPES = ()
FUNCTION = "generate_mesh"
CATEGORY = "api node/3d/Tripo"
API_NODE = True
OUTPUT_NODE = True
AVERAGE_DURATION = 30
def generate_mesh(self, original_model_task_id, format, quad, face_limit, texture_size, texture_format, **kwargs):
if not original_model_task_id:
raise RuntimeError("original_model_task_id is required")
response = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/tripo/v2/openapi/task",
method=HttpMethod.POST,
request_model=TripoConvertModelRequest,
response_model=TripoTaskResponse,
),
request=TripoConvertModelRequest(
original_model_task_id=original_model_task_id,
format=format,
quad=quad if quad else None,
face_limit=face_limit if face_limit != -1 else None,
texture_size=texture_size if texture_size != 4096 else None,
texture_format=texture_format if texture_format != "JPEG" else None
),
auth_kwargs=kwargs,
).execute()
return poll_until_finished(kwargs, response)
NODE_CLASS_MAPPINGS = {
"TripoTextToModelNode": TripoTextToModelNode,
"TripoImageToModelNode": TripoImageToModelNode,
"TripoMultiviewToModelNode": TripoMultiviewToModelNode,
"TripoTextureNode": TripoTextureNode,
"TripoRefineNode": TripoRefineNode,
"TripoRigNode": TripoRigNode,
"TripoRetargetNode": TripoRetargetNode,
"TripoConversionNode": TripoConversionNode,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"TripoTextToModelNode": "Tripo: Text to Model",
"TripoImageToModelNode": "Tripo: Image to Model",
"TripoMultiviewToModelNode": "Tripo: Multiview to Model",
"TripoTextureNode": "Tripo: Texture model",
"TripoRefineNode": "Tripo: Refine Draft model",
"TripoRigNode": "Tripo: Rig model",
"TripoRetargetNode": "Tripo: Retarget rigged model",
"TripoConversionNode": "Tripo: Convert model",
}

View File

@@ -0,0 +1,152 @@
import os
from pathlib import Path
from typing import Optional
from pydantic_settings import PydanticBaseSettingsSource, TomlConfigSettingsSource
from comfy_config.types import (
ComfyConfig,
ProjectConfig,
PyProjectConfig,
PyProjectSettings
)
def validate_and_extract_os_classifiers(classifiers: list) -> list:
os_classifiers = [c for c in classifiers if c.startswith("Operating System :: ")]
if not os_classifiers:
return []
os_values = [c[len("Operating System :: ") :] for c in os_classifiers]
valid_os_prefixes = {"Microsoft", "POSIX", "MacOS", "OS Independent"}
for os_value in os_values:
if not any(os_value.startswith(prefix) for prefix in valid_os_prefixes):
return []
return os_values
def validate_and_extract_accelerator_classifiers(classifiers: list) -> list:
accelerator_classifiers = [c for c in classifiers if c.startswith("Environment ::")]
if not accelerator_classifiers:
return []
accelerator_values = [c[len("Environment :: ") :] for c in accelerator_classifiers]
valid_accelerators = {
"GPU :: NVIDIA CUDA",
"GPU :: AMD ROCm",
"GPU :: Intel Arc",
"NPU :: Huawei Ascend",
"GPU :: Apple Metal",
}
for accelerator_value in accelerator_values:
if accelerator_value not in valid_accelerators:
return []
return accelerator_values
"""
Extract configuration from a custom node directory's pyproject.toml file or a Python file.
This function reads and parses the pyproject.toml file in the specified directory
to extract project and ComfyUI-specific configuration information. If no
pyproject.toml file is found, it creates a minimal configuration using the
folder name as the project name. If a Python file is provided, it uses the
file name (without extension) as the project name.
Args:
path (str): Path to the directory containing the pyproject.toml file, or
path to a .py file. If pyproject.toml doesn't exist in a directory,
the folder name will be used as the default project name. If a .py
file is provided, the filename (without .py extension) will be used
as the project name.
Returns:
Optional[PyProjectConfig]: A PyProjectConfig object containing:
- project: Basic project information (name, version, dependencies, etc.)
- tool_comfy: ComfyUI-specific configuration (publisher_id, models, etc.)
Returns None if configuration extraction fails or if the provided file
is not a Python file.
Notes:
- If pyproject.toml is missing in a directory, creates a default config with folder name
- If a .py file is provided, creates a default config with filename (without extension)
- Returns None for non-Python files
Example:
>>> from comfy_config import config_parser
>>> # For directory
>>> custom_node_dir = os.path.dirname(os.path.realpath(__file__))
>>> project_config = config_parser.extract_node_configuration(custom_node_dir)
>>> print(project_config.project.name) # "my_custom_node" or name from pyproject.toml
>>>
>>> # For single-file Python node file
>>> py_file_path = os.path.realpath(__file__) # "/path/to/my_node.py"
>>> project_config = config_parser.extract_node_configuration(py_file_path)
>>> print(project_config.project.name) # "my_node"
"""
def extract_node_configuration(path) -> Optional[PyProjectConfig]:
if os.path.isfile(path):
file_path = Path(path)
if file_path.suffix.lower() != '.py':
return None
project_name = file_path.stem
project = ProjectConfig(name=project_name)
comfy = ComfyConfig()
return PyProjectConfig(project=project, tool_comfy=comfy)
folder_name = os.path.basename(path)
toml_path = Path(path) / "pyproject.toml"
if not toml_path.exists():
project = ProjectConfig(name=folder_name)
comfy = ComfyConfig()
return PyProjectConfig(project=project, tool_comfy=comfy)
raw_settings = load_pyproject_settings(toml_path)
project_data = raw_settings.project
tool_data = raw_settings.tool
comfy_data = tool_data.get("comfy", {}) if tool_data else {}
dependencies = project_data.get("dependencies", [])
supported_comfyui_frontend_version = ""
for dep in dependencies:
if isinstance(dep, str) and dep.startswith("comfyui-frontend-package"):
supported_comfyui_frontend_version = dep.removeprefix("comfyui-frontend-package")
break
supported_comfyui_version = comfy_data.get("requires-comfyui", "")
classifiers = project_data.get('classifiers', [])
supported_os = validate_and_extract_os_classifiers(classifiers)
supported_accelerators = validate_and_extract_accelerator_classifiers(classifiers)
project_data['supported_os'] = supported_os
project_data['supported_accelerators'] = supported_accelerators
project_data['supported_comfyui_frontend_version'] = supported_comfyui_frontend_version
project_data['supported_comfyui_version'] = supported_comfyui_version
return PyProjectConfig(project=project_data, tool_comfy=comfy_data)
def load_pyproject_settings(toml_path: Path) -> PyProjectSettings:
class PyProjectLoader(PyProjectSettings):
@classmethod
def settings_customise_sources(
cls,
settings_cls,
init_settings: PydanticBaseSettingsSource,
env_settings: PydanticBaseSettingsSource,
dotenv_settings: PydanticBaseSettingsSource,
file_secret_settings: PydanticBaseSettingsSource,
):
return (TomlConfigSettingsSource(settings_cls, toml_path),)
return PyProjectLoader()

97
comfy_config/types.py Normal file
View File

@@ -0,0 +1,97 @@
from pydantic import BaseModel, Field, field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
from typing import List, Optional
# IMPORTANT: The type definitions specified in pyproject.toml for custom nodes
# must remain synchronized with the corresponding files in the https://github.com/Comfy-Org/comfy-cli/blob/main/comfy_cli/registry/types.py.
# Any changes to one must be reflected in the other to maintain consistency.
class NodeVersion(BaseModel):
changelog: str
dependencies: List[str]
deprecated: bool
id: str
version: str
download_url: str
class Node(BaseModel):
id: str
name: str
description: str
author: Optional[str] = None
license: Optional[str] = None
icon: Optional[str] = None
repository: Optional[str] = None
tags: List[str] = Field(default_factory=list)
latest_version: Optional[NodeVersion] = None
class PublishNodeVersionResponse(BaseModel):
node_version: NodeVersion
signedUrl: str
class URLs(BaseModel):
homepage: str = Field(default="", alias="Homepage")
documentation: str = Field(default="", alias="Documentation")
repository: str = Field(default="", alias="Repository")
issues: str = Field(default="", alias="Issues")
class Model(BaseModel):
location: str
model_url: str
class ComfyConfig(BaseModel):
publisher_id: str = Field(default="", alias="PublisherId")
display_name: str = Field(default="", alias="DisplayName")
icon: str = Field(default="", alias="Icon")
models: List[Model] = Field(default_factory=list, alias="Models")
includes: List[str] = Field(default_factory=list)
web: Optional[str] = None
banner_url: str = ""
class License(BaseModel):
file: str = ""
text: str = ""
class ProjectConfig(BaseModel):
name: str = ""
description: str = ""
version: str = "1.0.0"
requires_python: str = Field(default=">= 3.9", alias="requires-python")
dependencies: List[str] = Field(default_factory=list)
license: License = Field(default_factory=License)
urls: URLs = Field(default_factory=URLs)
supported_os: List[str] = Field(default_factory=list)
supported_accelerators: List[str] = Field(default_factory=list)
supported_comfyui_version: str = ""
supported_comfyui_frontend_version: str = ""
@field_validator('license', mode='before')
@classmethod
def validate_license(cls, v):
if isinstance(v, str):
return License(text=v)
elif isinstance(v, dict):
return License(**v)
elif isinstance(v, License):
return v
else:
return License()
class PyProjectConfig(BaseModel):
project: ProjectConfig = Field(default_factory=ProjectConfig)
tool_comfy: ComfyConfig = Field(default_factory=ComfyConfig)
class PyProjectSettings(BaseSettings):
project: dict = Field(default_factory=dict)
tool: dict = Field(default_factory=dict)
model_config = SettingsConfigDict(extra='allow')

View File

@@ -2,6 +2,7 @@ import nodes
import torch
import comfy.model_management
import comfy.utils
import comfy.latent_formats
class EmptyCosmosLatentVideo:
@@ -75,8 +76,53 @@ class CosmosImageToVideoLatent:
out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1))
return (out_latent,)
class CosmosPredict2ImageToVideoLatent:
@classmethod
def INPUT_TYPES(s):
return {"required": {"vae": ("VAE", ),
"width": ("INT", {"default": 848, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 93, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
},
"optional": {"start_image": ("IMAGE", ),
"end_image": ("IMAGE", ),
}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "encode"
CATEGORY = "conditioning/inpaint"
def encode(self, vae, width, height, length, batch_size, start_image=None, end_image=None):
latent = torch.zeros([1, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
if start_image is None and end_image is None:
out_latent = {}
out_latent["samples"] = latent
return (out_latent,)
mask = torch.ones([latent.shape[0], 1, ((length - 1) // 4) + 1, latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device())
if start_image is not None:
latent_temp = vae_encode_with_padding(vae, start_image, width, height, length, padding=1)
latent[:, :, :latent_temp.shape[-3]] = latent_temp
mask[:, :, :latent_temp.shape[-3]] *= 0.0
if end_image is not None:
latent_temp = vae_encode_with_padding(vae, end_image, width, height, length, padding=0)
latent[:, :, -latent_temp.shape[-3]:] = latent_temp
mask[:, :, -latent_temp.shape[-3]:] *= 0.0
out_latent = {}
latent_format = comfy.latent_formats.Wan21()
latent = latent_format.process_out(latent) * mask + latent * (1.0 - mask)
out_latent["samples"] = latent.repeat((batch_size, ) + (1,) * (latent.ndim - 1))
out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1))
return (out_latent,)
NODE_CLASS_MAPPINGS = {
"EmptyCosmosLatentVideo": EmptyCosmosLatentVideo,
"CosmosImageToVideoLatent": CosmosImageToVideoLatent,
"CosmosPredict2ImageToVideoLatent": CosmosPredict2ImageToVideoLatent,
}

View File

@@ -0,0 +1,26 @@
import node_helpers
class ReferenceLatent:
@classmethod
def INPUT_TYPES(s):
return {"required": {"conditioning": ("CONDITIONING", ),
},
"optional": {"latent": ("LATENT", ),}
}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "append"
CATEGORY = "advanced/conditioning/edit_models"
DESCRIPTION = "This node sets the guiding latent for an edit model. If the model supports it you can chain multiple to set multiple reference images."
def append(self, conditioning, latent=None):
if latent is not None:
conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": [latent["samples"]]}, append=True)
return (conditioning, )
NODE_CLASS_MAPPINGS = {
"ReferenceLatent": ReferenceLatent,
}

View File

@@ -1,4 +1,5 @@
import node_helpers
import comfy.utils
class CLIPTextEncodeFlux:
@classmethod
@@ -56,8 +57,52 @@ class FluxDisableGuidance:
return (c, )
PREFERED_KONTEXT_RESOLUTIONS = [
(672, 1568),
(688, 1504),
(720, 1456),
(752, 1392),
(800, 1328),
(832, 1248),
(880, 1184),
(944, 1104),
(1024, 1024),
(1104, 944),
(1184, 880),
(1248, 832),
(1328, 800),
(1392, 752),
(1456, 720),
(1504, 688),
(1568, 672),
]
class FluxKontextImageScale:
@classmethod
def INPUT_TYPES(s):
return {"required": {"image": ("IMAGE", ),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "scale"
CATEGORY = "advanced/conditioning/flux"
DESCRIPTION = "This node resizes the image to one that is more optimal for flux kontext."
def scale(self, image):
width = image.shape[2]
height = image.shape[1]
aspect_ratio = width / height
_, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS)
image = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "lanczos", "center").movedim(1, -1)
return (image, )
NODE_CLASS_MAPPINGS = {
"CLIPTextEncodeFlux": CLIPTextEncodeFlux,
"FluxGuidance": FluxGuidance,
"FluxDisableGuidance": FluxDisableGuidance,
"FluxKontextImageScale": FluxKontextImageScale,
}

View File

@@ -14,8 +14,10 @@ import re
from io import BytesIO
from inspect import cleandoc
import torch
import comfy.utils
from comfy.comfy_types import FileLocator
from comfy.comfy_types import FileLocator, IO
from server import PromptServer
MAX_RESOLUTION = nodes.MAX_RESOLUTION
@@ -229,6 +231,246 @@ class SVG:
all_svgs_list.extend(svg_item.data)
return SVG(all_svgs_list)
class ImageStitch:
"""Upstreamed from https://github.com/kijai/ComfyUI-KJNodes"""
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image1": ("IMAGE",),
"direction": (["right", "down", "left", "up"], {"default": "right"}),
"match_image_size": ("BOOLEAN", {"default": True}),
"spacing_width": (
"INT",
{"default": 0, "min": 0, "max": 1024, "step": 2},
),
"spacing_color": (
["white", "black", "red", "green", "blue"],
{"default": "white"},
),
},
"optional": {
"image2": ("IMAGE",),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "stitch"
CATEGORY = "image/transform"
DESCRIPTION = """
Stitches image2 to image1 in the specified direction.
If image2 is not provided, returns image1 unchanged.
Optional spacing can be added between images.
"""
def stitch(
self,
image1,
direction,
match_image_size,
spacing_width,
spacing_color,
image2=None,
):
if image2 is None:
return (image1,)
# Handle batch size differences
if image1.shape[0] != image2.shape[0]:
max_batch = max(image1.shape[0], image2.shape[0])
if image1.shape[0] < max_batch:
image1 = torch.cat(
[image1, image1[-1:].repeat(max_batch - image1.shape[0], 1, 1, 1)]
)
if image2.shape[0] < max_batch:
image2 = torch.cat(
[image2, image2[-1:].repeat(max_batch - image2.shape[0], 1, 1, 1)]
)
# Match image sizes if requested
if match_image_size:
h1, w1 = image1.shape[1:3]
h2, w2 = image2.shape[1:3]
aspect_ratio = w2 / h2
if direction in ["left", "right"]:
target_h, target_w = h1, int(h1 * aspect_ratio)
else: # up, down
target_w, target_h = w1, int(w1 / aspect_ratio)
image2 = comfy.utils.common_upscale(
image2.movedim(-1, 1), target_w, target_h, "lanczos", "disabled"
).movedim(1, -1)
color_map = {
"white": 1.0,
"black": 0.0,
"red": (1.0, 0.0, 0.0),
"green": (0.0, 1.0, 0.0),
"blue": (0.0, 0.0, 1.0),
}
color_val = color_map[spacing_color]
# When not matching sizes, pad to align non-concat dimensions
if not match_image_size:
h1, w1 = image1.shape[1:3]
h2, w2 = image2.shape[1:3]
pad_value = 0.0
if not isinstance(color_val, tuple):
pad_value = color_val
if direction in ["left", "right"]:
# For horizontal concat, pad heights to match
if h1 != h2:
target_h = max(h1, h2)
if h1 < target_h:
pad_h = target_h - h1
pad_top, pad_bottom = pad_h // 2, pad_h - pad_h // 2
image1 = torch.nn.functional.pad(image1, (0, 0, 0, 0, pad_top, pad_bottom), mode='constant', value=pad_value)
if h2 < target_h:
pad_h = target_h - h2
pad_top, pad_bottom = pad_h // 2, pad_h - pad_h // 2
image2 = torch.nn.functional.pad(image2, (0, 0, 0, 0, pad_top, pad_bottom), mode='constant', value=pad_value)
else: # up, down
# For vertical concat, pad widths to match
if w1 != w2:
target_w = max(w1, w2)
if w1 < target_w:
pad_w = target_w - w1
pad_left, pad_right = pad_w // 2, pad_w - pad_w // 2
image1 = torch.nn.functional.pad(image1, (0, 0, pad_left, pad_right), mode='constant', value=pad_value)
if w2 < target_w:
pad_w = target_w - w2
pad_left, pad_right = pad_w // 2, pad_w - pad_w // 2
image2 = torch.nn.functional.pad(image2, (0, 0, pad_left, pad_right), mode='constant', value=pad_value)
# Ensure same number of channels
if image1.shape[-1] != image2.shape[-1]:
max_channels = max(image1.shape[-1], image2.shape[-1])
if image1.shape[-1] < max_channels:
image1 = torch.cat(
[
image1,
torch.ones(
*image1.shape[:-1],
max_channels - image1.shape[-1],
device=image1.device,
),
],
dim=-1,
)
if image2.shape[-1] < max_channels:
image2 = torch.cat(
[
image2,
torch.ones(
*image2.shape[:-1],
max_channels - image2.shape[-1],
device=image2.device,
),
],
dim=-1,
)
# Add spacing if specified
if spacing_width > 0:
spacing_width = spacing_width + (spacing_width % 2) # Ensure even
if direction in ["left", "right"]:
spacing_shape = (
image1.shape[0],
max(image1.shape[1], image2.shape[1]),
spacing_width,
image1.shape[-1],
)
else:
spacing_shape = (
image1.shape[0],
spacing_width,
max(image1.shape[2], image2.shape[2]),
image1.shape[-1],
)
spacing = torch.full(spacing_shape, 0.0, device=image1.device)
if isinstance(color_val, tuple):
for i, c in enumerate(color_val):
if i < spacing.shape[-1]:
spacing[..., i] = c
if spacing.shape[-1] == 4: # Add alpha
spacing[..., 3] = 1.0
else:
spacing[..., : min(3, spacing.shape[-1])] = color_val
if spacing.shape[-1] == 4:
spacing[..., 3] = 1.0
# Concatenate images
images = [image2, image1] if direction in ["left", "up"] else [image1, image2]
if spacing_width > 0:
images.insert(1, spacing)
concat_dim = 2 if direction in ["left", "right"] else 1
return (torch.cat(images, dim=concat_dim),)
class ResizeAndPadImage:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": ("IMAGE",),
"target_width": ("INT", {
"default": 512,
"min": 1,
"max": MAX_RESOLUTION,
"step": 1
}),
"target_height": ("INT", {
"default": 512,
"min": 1,
"max": MAX_RESOLUTION,
"step": 1
}),
"padding_color": (["white", "black"],),
"interpolation": (["area", "bicubic", "nearest-exact", "bilinear", "lanczos"],),
}
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "resize_and_pad"
CATEGORY = "image/transform"
def resize_and_pad(self, image, target_width, target_height, padding_color, interpolation):
batch_size, orig_height, orig_width, channels = image.shape
scale_w = target_width / orig_width
scale_h = target_height / orig_height
scale = min(scale_w, scale_h)
new_width = int(orig_width * scale)
new_height = int(orig_height * scale)
image_permuted = image.permute(0, 3, 1, 2)
resized = comfy.utils.common_upscale(image_permuted, new_width, new_height, interpolation, "disabled")
pad_value = 0.0 if padding_color == "black" else 1.0
padded = torch.full(
(batch_size, channels, target_height, target_width),
pad_value,
dtype=image.dtype,
device=image.device
)
y_offset = (target_height - new_height) // 2
x_offset = (target_width - new_width) // 2
padded[:, :, y_offset:y_offset + new_height, x_offset:x_offset + new_width] = resized
output = padded.permute(0, 2, 3, 1)
return (output,)
class SaveSVGNode:
"""
Save SVG files on disk.
@@ -310,6 +552,37 @@ class SaveSVGNode:
counter += 1
return { "ui": { "images": results } }
class GetImageSize:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": (IO.IMAGE,),
},
"hidden": {
"unique_id": "UNIQUE_ID",
}
}
RETURN_TYPES = (IO.INT, IO.INT, IO.INT)
RETURN_NAMES = ("width", "height", "batch_size")
FUNCTION = "get_size"
CATEGORY = "image"
DESCRIPTION = """Returns width and height of the image, and passes it through unchanged."""
def get_size(self, image, unique_id=None) -> tuple[int, int]:
height = image.shape[1]
width = image.shape[2]
batch_size = image.shape[0]
# Send progress text to display size on the node
if unique_id:
PromptServer.instance.send_progress_text(f"width: {width}, height: {height}\n batch size: {batch_size}", unique_id)
return width, height, batch_size
NODE_CLASS_MAPPINGS = {
"ImageCrop": ImageCrop,
"RepeatImageBatch": RepeatImageBatch,
@@ -318,4 +591,7 @@ NODE_CLASS_MAPPINGS = {
"SaveAnimatedWEBP": SaveAnimatedWEBP,
"SaveAnimatedPNG": SaveAnimatedPNG,
"SaveSVGNode": SaveSVGNode,
"ImageStitch": ImageStitch,
"ResizeAndPadImage": ResizeAndPadImage,
"GetImageSize": GetImageSize,
}

View File

@@ -16,7 +16,7 @@ class Load3D():
os.makedirs(input_dir, exist_ok=True)
files = [normalize_path(os.path.join("3d", f)) for f in os.listdir(input_dir) if f.endswith(('.gltf', '.glb', '.obj', '.mtl', '.fbx', '.stl'))]
files = [normalize_path(os.path.join("3d", f)) for f in os.listdir(input_dir) if f.endswith(('.gltf', '.glb', '.obj', '.fbx', '.stl'))]
return {"required": {
"model_file": (sorted(files), {"file_upload": True}),

View File

@@ -189,7 +189,7 @@ class ModelSamplingContinuousEDM:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"sampling": (["v_prediction", "edm", "edm_playground_v2.5", "eps"],),
"sampling": (["v_prediction", "edm", "edm_playground_v2.5", "eps", "cosmos_rflow"],),
"sigma_max": ("FLOAT", {"default": 120.0, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
"sigma_min": ("FLOAT", {"default": 0.002, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
}}
@@ -202,6 +202,7 @@ class ModelSamplingContinuousEDM:
def patch(self, model, sampling, sigma_max, sigma_min):
m = model.clone()
sampling_base = comfy.model_sampling.ModelSamplingContinuousEDM
latent_format = None
sigma_data = 1.0
if sampling == "eps":
@@ -215,8 +216,11 @@ class ModelSamplingContinuousEDM:
sampling_type = comfy.model_sampling.EDM
sigma_data = 0.5
latent_format = comfy.latent_formats.SDXL_Playground_2_5()
elif sampling == "cosmos_rflow":
sampling_type = comfy.model_sampling.COSMOS_RFLOW
sampling_base = comfy.model_sampling.ModelSamplingCosmosRFlow
class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingContinuousEDM, sampling_type):
class ModelSamplingAdvanced(sampling_base, sampling_type):
pass
model_sampling = ModelSamplingAdvanced(model.model.model_config)

View File

@@ -268,6 +268,52 @@ class ModelMergeWAN2_1(comfy_extras.nodes_model_merging.ModelMergeBlocks):
return {"required": arg_dict}
class ModelMergeCosmosPredict2_2B(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
arg_dict["pos_embedder."] = argument
arg_dict["x_embedder."] = argument
arg_dict["t_embedder."] = argument
arg_dict["t_embedding_norm."] = argument
for i in range(28):
arg_dict["blocks.{}.".format(i)] = argument
arg_dict["final_layer."] = argument
return {"required": arg_dict}
class ModelMergeCosmosPredict2_14B(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
arg_dict["pos_embedder."] = argument
arg_dict["x_embedder."] = argument
arg_dict["t_embedder."] = argument
arg_dict["t_embedding_norm."] = argument
for i in range(36):
arg_dict["blocks.{}.".format(i)] = argument
arg_dict["final_layer."] = argument
return {"required": arg_dict}
NODE_CLASS_MAPPINGS = {
"ModelMergeSD1": ModelMergeSD1,
"ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks
@@ -281,4 +327,6 @@ NODE_CLASS_MAPPINGS = {
"ModelMergeCosmos7B": ModelMergeCosmos7B,
"ModelMergeCosmos14B": ModelMergeCosmos14B,
"ModelMergeWAN2_1": ModelMergeWAN2_1,
"ModelMergeCosmosPredict2_2B": ModelMergeCosmosPredict2_2B,
"ModelMergeCosmosPredict2_14B": ModelMergeCosmosPredict2_14B,
}

View File

@@ -296,6 +296,41 @@ class RegexExtract():
return result,
class RegexReplace():
DESCRIPTION = "Find and replace text using regex patterns."
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"string": (IO.STRING, {"multiline": True}),
"regex_pattern": (IO.STRING, {"multiline": True}),
"replace": (IO.STRING, {"multiline": True}),
},
"optional": {
"case_insensitive": (IO.BOOLEAN, {"default": True}),
"multiline": (IO.BOOLEAN, {"default": False}),
"dotall": (IO.BOOLEAN, {"default": False, "tooltip": "When enabled, the dot (.) character will match any character including newline characters. When disabled, dots won't match newlines."}),
"count": (IO.INT, {"default": 0, "min": 0, "max": 100, "tooltip": "Maximum number of replacements to make. Set to 0 to replace all occurrences (default). Set to 1 to replace only the first match, 2 for the first two matches, etc."}),
}
}
RETURN_TYPES = (IO.STRING,)
FUNCTION = "execute"
CATEGORY = "utils/string"
def execute(self, string, regex_pattern, replace, case_insensitive=True, multiline=False, dotall=False, count=0, **kwargs):
flags = 0
if case_insensitive:
flags |= re.IGNORECASE
if multiline:
flags |= re.MULTILINE
if dotall:
flags |= re.DOTALL
result = re.sub(regex_pattern, replace, string, count=count, flags=flags)
return result,
NODE_CLASS_MAPPINGS = {
"StringConcatenate": StringConcatenate,
"StringSubstring": StringSubstring,
@@ -306,7 +341,8 @@ NODE_CLASS_MAPPINGS = {
"StringContains": StringContains,
"StringCompare": StringCompare,
"RegexMatch": RegexMatch,
"RegexExtract": RegexExtract
"RegexExtract": RegexExtract,
"RegexReplace": RegexReplace,
}
NODE_DISPLAY_NAME_MAPPINGS = {
@@ -319,5 +355,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"StringContains": "Contains",
"StringCompare": "Compare",
"RegexMatch": "Regex Match",
"RegexExtract": "Regex Extract"
"RegexExtract": "Regex Extract",
"RegexReplace": "Regex Replace",
}

View File

@@ -1,4 +1,5 @@
import torch
from comfy_api.torch_helpers import set_torch_compile_wrapper
class TorchCompileModel:
@classmethod
@@ -14,7 +15,7 @@ class TorchCompileModel:
def patch(self, model, backend):
m = model.clone()
m.add_object_patch("diffusion_model", torch.compile(model=m.get_model_object("diffusion_model"), backend=backend))
set_torch_compile_wrapper(model=m, backend=backend)
return (m, )
NODE_CLASS_MAPPINGS = {

709
comfy_extras/nodes_train.py Normal file
View File

@@ -0,0 +1,709 @@
import datetime
import json
import logging
import os
import numpy as np
import safetensors
import torch
from PIL import Image, ImageDraw, ImageFont
from PIL.PngImagePlugin import PngInfo
import torch.utils.checkpoint
import tqdm
import comfy.samplers
import comfy.sd
import comfy.utils
import comfy.model_management
import comfy_extras.nodes_custom_sampler
import folder_paths
import node_helpers
from comfy.cli_args import args
from comfy.comfy_types.node_typing import IO
from comfy.weight_adapter import adapters
class TrainSampler(comfy.samplers.Sampler):
def __init__(self, loss_fn, optimizer, loss_callback=None):
self.loss_fn = loss_fn
self.optimizer = optimizer
self.loss_callback = loss_callback
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
self.optimizer.zero_grad()
noise = model_wrap.inner_model.model_sampling.noise_scaling(sigmas, noise, latent_image, False)
latent = model_wrap.inner_model.model_sampling.noise_scaling(
torch.zeros_like(sigmas),
torch.zeros_like(noise, requires_grad=True),
latent_image,
False
)
# Ensure model is in training mode and computing gradients
# x0 pred
denoised = model_wrap(noise, sigmas, **extra_args)
try:
loss = self.loss_fn(denoised, latent.clone())
except RuntimeError as e:
if "does not require grad and does not have a grad_fn" in str(e):
logging.info("WARNING: This is likely due to the model is loaded in inference mode.")
loss.backward()
if self.loss_callback:
self.loss_callback(loss.item())
self.optimizer.step()
# torch.cuda.memory._dump_snapshot("trainn.pickle")
# torch.cuda.memory._record_memory_history(enabled=None)
return torch.zeros_like(latent_image)
class BiasDiff(torch.nn.Module):
def __init__(self, bias):
super().__init__()
self.bias = bias
def __call__(self, b):
org_dtype = b.dtype
return (b.to(self.bias) + self.bias).to(org_dtype)
def passive_memory_usage(self):
return self.bias.nelement() * self.bias.element_size()
def move_to(self, device):
self.to(device=device)
return self.passive_memory_usage()
def load_and_process_images(image_files, input_dir, resize_method="None"):
"""Utility function to load and process a list of images.
Args:
image_files: List of image filenames
input_dir: Base directory containing the images
resize_method: How to handle images of different sizes ("None", "Stretch", "Crop", "Pad")
Returns:
torch.Tensor: Batch of processed images
"""
if not image_files:
raise ValueError("No valid images found in input")
output_images = []
w, h = None, None
for file in image_files:
image_path = os.path.join(input_dir, file)
img = node_helpers.pillow(Image.open, image_path)
if img.mode == "I":
img = img.point(lambda i: i * (1 / 255))
img = img.convert("RGB")
if w is None and h is None:
w, h = img.size[0], img.size[1]
# Resize image to first image
if img.size[0] != w or img.size[1] != h:
if resize_method == "Stretch":
img = img.resize((w, h), Image.Resampling.LANCZOS)
elif resize_method == "Crop":
img = img.crop((0, 0, w, h))
elif resize_method == "Pad":
img = img.resize((w, h), Image.Resampling.LANCZOS)
elif resize_method == "None":
raise ValueError(
"Your input image size does not match the first image in the dataset. Either select a valid resize method or use the same size for all images."
)
img_array = np.array(img).astype(np.float32) / 255.0
img_tensor = torch.from_numpy(img_array)[None,]
output_images.append(img_tensor)
return torch.cat(output_images, dim=0)
class LoadImageSetNode:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"images": (
[
f
for f in os.listdir(folder_paths.get_input_directory())
if f.endswith((".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif", ".jpe", ".apng", ".tif", ".tiff"))
],
{"image_upload": True, "allow_batch": True},
)
},
"optional": {
"resize_method": (
["None", "Stretch", "Crop", "Pad"],
{"default": "None"},
),
},
}
INPUT_IS_LIST = True
RETURN_TYPES = ("IMAGE",)
FUNCTION = "load_images"
CATEGORY = "loaders"
EXPERIMENTAL = True
DESCRIPTION = "Loads a batch of images from a directory for training."
@classmethod
def VALIDATE_INPUTS(s, images, resize_method):
filenames = images[0] if isinstance(images[0], list) else images
for image in filenames:
if not folder_paths.exists_annotated_filepath(image):
return "Invalid image file: {}".format(image)
return True
def load_images(self, input_files, resize_method):
input_dir = folder_paths.get_input_directory()
valid_extensions = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif", ".jpe", ".apng", ".tif", ".tiff"]
image_files = [
f
for f in input_files
if any(f.lower().endswith(ext) for ext in valid_extensions)
]
output_tensor = load_and_process_images(image_files, input_dir, resize_method)
return (output_tensor,)
class LoadImageSetFromFolderNode:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"folder": (folder_paths.get_input_subfolders(), {"tooltip": "The folder to load images from."})
},
"optional": {
"resize_method": (
["None", "Stretch", "Crop", "Pad"],
{"default": "None"},
),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "load_images"
CATEGORY = "loaders"
EXPERIMENTAL = True
DESCRIPTION = "Loads a batch of images from a directory for training."
def load_images(self, folder, resize_method):
sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder)
valid_extensions = [".png", ".jpg", ".jpeg", ".webp"]
image_files = [
f
for f in os.listdir(sub_input_dir)
if any(f.lower().endswith(ext) for ext in valid_extensions)
]
output_tensor = load_and_process_images(image_files, sub_input_dir, resize_method)
return (output_tensor,)
def draw_loss_graph(loss_map, steps):
width, height = 500, 300
img = Image.new("RGB", (width, height), "white")
draw = ImageDraw.Draw(img)
min_loss, max_loss = min(loss_map.values()), max(loss_map.values())
scaled_loss = [(l - min_loss) / (max_loss - min_loss) for l in loss_map.values()]
prev_point = (0, height - int(scaled_loss[0] * height))
for i, l in enumerate(scaled_loss[1:], start=1):
x = int(i / (steps - 1) * width)
y = height - int(l * height)
draw.line([prev_point, (x, y)], fill="blue", width=2)
prev_point = (x, y)
return img
def find_all_highest_child_module_with_forward(model: torch.nn.Module, result = None, name = None):
if result is None:
result = []
elif hasattr(model, "forward") and not isinstance(model, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict)):
result.append(model)
logging.debug(f"Found module with forward: {name} ({model.__class__.__name__})")
return result
name = name or "root"
for next_name, child in model.named_children():
find_all_highest_child_module_with_forward(child, result, f"{name}.{next_name}")
return result
def patch(m):
if not hasattr(m, "forward"):
return
org_forward = m.forward
def fwd(args, kwargs):
return org_forward(*args, **kwargs)
def checkpointing_fwd(*args, **kwargs):
return torch.utils.checkpoint.checkpoint(
fwd, args, kwargs, use_reentrant=False
)
m.org_forward = org_forward
m.forward = checkpointing_fwd
def unpatch(m):
if hasattr(m, "org_forward"):
m.forward = m.org_forward
del m.org_forward
class TrainLoraNode:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model": (IO.MODEL, {"tooltip": "The model to train the LoRA on."}),
"latents": (
"LATENT",
{
"tooltip": "The Latents to use for training, serve as dataset/input of the model."
},
),
"positive": (
IO.CONDITIONING,
{"tooltip": "The positive conditioning to use for training."},
),
"batch_size": (
IO.INT,
{
"default": 1,
"min": 1,
"max": 10000,
"step": 1,
"tooltip": "The batch size to use for training.",
},
),
"steps": (
IO.INT,
{
"default": 16,
"min": 1,
"max": 100000,
"tooltip": "The number of steps to train the LoRA for.",
},
),
"learning_rate": (
IO.FLOAT,
{
"default": 0.0005,
"min": 0.0000001,
"max": 1.0,
"step": 0.000001,
"tooltip": "The learning rate to use for training.",
},
),
"rank": (
IO.INT,
{
"default": 8,
"min": 1,
"max": 128,
"tooltip": "The rank of the LoRA layers.",
},
),
"optimizer": (
["AdamW", "Adam", "SGD", "RMSprop"],
{
"default": "AdamW",
"tooltip": "The optimizer to use for training.",
},
),
"loss_function": (
["MSE", "L1", "Huber", "SmoothL1"],
{
"default": "MSE",
"tooltip": "The loss function to use for training.",
},
),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 0xFFFFFFFFFFFFFFFF,
"tooltip": "The seed to use for training (used in generator for LoRA weight initialization and noise sampling)",
},
),
"training_dtype": (
["bf16", "fp32"],
{"default": "bf16", "tooltip": "The dtype to use for training."},
),
"lora_dtype": (
["bf16", "fp32"],
{"default": "bf16", "tooltip": "The dtype to use for lora."},
),
"existing_lora": (
folder_paths.get_filename_list("loras") + ["[None]"],
{
"default": "[None]",
"tooltip": "The existing LoRA to append to. Set to None for new LoRA.",
},
),
},
}
RETURN_TYPES = (IO.MODEL, IO.LORA_MODEL, IO.LOSS_MAP, IO.INT)
RETURN_NAMES = ("model_with_lora", "lora", "loss", "steps")
FUNCTION = "train"
CATEGORY = "training"
EXPERIMENTAL = True
def train(
self,
model,
latents,
positive,
batch_size,
steps,
learning_rate,
rank,
optimizer,
loss_function,
seed,
training_dtype,
lora_dtype,
existing_lora,
):
mp = model.clone()
dtype = node_helpers.string_to_torch_dtype(training_dtype)
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
mp.set_model_compute_dtype(dtype)
latents = latents["samples"].to(dtype)
num_images = latents.shape[0]
with torch.inference_mode(False):
lora_sd = {}
generator = torch.Generator()
generator.manual_seed(seed)
# Load existing LoRA weights if provided
existing_weights = {}
existing_steps = 0
if existing_lora != "[None]":
lora_path = folder_paths.get_full_path_or_raise("loras", existing_lora)
# Extract steps from filename like "trained_lora_10_steps_20250225_203716"
existing_steps = int(existing_lora.split("_steps_")[0].split("_")[-1])
if lora_path:
existing_weights = comfy.utils.load_torch_file(lora_path)
all_weight_adapters = []
for n, m in mp.model.named_modules():
if hasattr(m, "weight_function"):
if m.weight is not None:
key = "{}.weight".format(n)
shape = m.weight.shape
if len(shape) >= 2:
alpha = float(existing_weights.get(f"{key}.alpha", 1.0))
dora_scale = existing_weights.get(
f"{key}.dora_scale", None
)
for adapter_cls in adapters:
existing_adapter = adapter_cls.load(
n, existing_weights, alpha, dora_scale
)
if existing_adapter is not None:
break
else:
# If no existing adapter found, use LoRA
# We will add algo option in the future
existing_adapter = None
adapter_cls = adapters[0]
if existing_adapter is not None:
train_adapter = existing_adapter.to_train().to(lora_dtype)
else:
# Use LoRA with alpha=1.0 by default
train_adapter = adapter_cls.create_train(
m.weight, rank=rank, alpha=1.0
).to(lora_dtype)
for name, parameter in train_adapter.named_parameters():
lora_sd[f"{n}.{name}"] = parameter
mp.add_weight_wrapper(key, train_adapter)
all_weight_adapters.append(train_adapter)
else:
diff = torch.nn.Parameter(
torch.zeros(
m.weight.shape, dtype=lora_dtype, requires_grad=True
)
)
diff_module = BiasDiff(diff)
mp.add_weight_wrapper(key, BiasDiff(diff))
all_weight_adapters.append(diff_module)
lora_sd["{}.diff".format(n)] = diff
if hasattr(m, "bias") and m.bias is not None:
key = "{}.bias".format(n)
bias = torch.nn.Parameter(
torch.zeros(m.bias.shape, dtype=lora_dtype, requires_grad=True)
)
bias_module = BiasDiff(bias)
lora_sd["{}.diff_b".format(n)] = bias
mp.add_weight_wrapper(key, BiasDiff(bias))
all_weight_adapters.append(bias_module)
if optimizer == "Adam":
optimizer = torch.optim.Adam(lora_sd.values(), lr=learning_rate)
elif optimizer == "AdamW":
optimizer = torch.optim.AdamW(lora_sd.values(), lr=learning_rate)
elif optimizer == "SGD":
optimizer = torch.optim.SGD(lora_sd.values(), lr=learning_rate)
elif optimizer == "RMSprop":
optimizer = torch.optim.RMSprop(lora_sd.values(), lr=learning_rate)
# Setup loss function based on selection
if loss_function == "MSE":
criterion = torch.nn.MSELoss()
elif loss_function == "L1":
criterion = torch.nn.L1Loss()
elif loss_function == "Huber":
criterion = torch.nn.HuberLoss()
elif loss_function == "SmoothL1":
criterion = torch.nn.SmoothL1Loss()
# setup models
for m in find_all_highest_child_module_with_forward(mp.model.diffusion_model):
patch(m)
comfy.model_management.load_models_gpu([mp], memory_required=1e20, force_full_load=True)
# Setup sampler and guider like in test script
loss_map = {"loss": []}
def loss_callback(loss):
loss_map["loss"].append(loss)
pbar.set_postfix({"loss": f"{loss:.4f}"})
train_sampler = TrainSampler(
criterion, optimizer, loss_callback=loss_callback
)
guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp)
guider.set_conds(positive) # Set conditioning from input
ss = comfy_extras.nodes_custom_sampler.SamplerCustomAdvanced()
# yoland: this currently resize to the first image in the dataset
# Training loop
torch.cuda.empty_cache()
try:
for step in (pbar:=tqdm.trange(steps, desc="Training LoRA", smoothing=0.01, disable=not comfy.utils.PROGRESS_BAR_ENABLED)):
# Generate random sigma
sigma = mp.model.model_sampling.percent_to_sigma(
torch.rand((1,)).item()
)
sigma = torch.tensor([sigma])
noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(step * 1000 + seed)
indices = torch.randperm(num_images)[:batch_size]
ss.sample(
noise, guider, train_sampler, sigma, {"samples": latents[indices].clone()}
)
finally:
for m in mp.model.modules():
unpatch(m)
del ss, train_sampler, optimizer
torch.cuda.empty_cache()
for adapter in all_weight_adapters:
adapter.requires_grad_(False)
for param in lora_sd:
lora_sd[param] = lora_sd[param].to(lora_dtype)
return (mp, lora_sd, loss_map, steps + existing_steps)
class LoraModelLoader:
def __init__(self):
self.loaded_lora = None
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model": ("MODEL", {"tooltip": "The diffusion model the LoRA will be applied to."}),
"lora": (IO.LORA_MODEL, {"tooltip": "The LoRA model to apply to the diffusion model."}),
"strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the diffusion model. This value can be negative."}),
}
}
RETURN_TYPES = ("MODEL",)
OUTPUT_TOOLTIPS = ("The modified diffusion model.",)
FUNCTION = "load_lora_model"
CATEGORY = "loaders"
DESCRIPTION = "Load Trained LoRA weights from Train LoRA node."
EXPERIMENTAL = True
def load_lora_model(self, model, lora, strength_model):
if strength_model == 0:
return (model, )
model_lora, _ = comfy.sd.load_lora_for_models(model, None, lora, strength_model, 0)
return (model_lora, )
class SaveLoRA:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"lora": (
IO.LORA_MODEL,
{
"tooltip": "The LoRA model to save. Do not use the model with LoRA layers."
},
),
"prefix": (
"STRING",
{
"default": "loras/ComfyUI_trained_lora",
"tooltip": "The prefix to use for the saved LoRA file.",
},
),
},
"optional": {
"steps": (
IO.INT,
{
"forceInput": True,
"tooltip": "Optional: The number of steps to LoRA has been trained for, used to name the saved file.",
},
),
},
}
RETURN_TYPES = ()
FUNCTION = "save"
CATEGORY = "loaders"
EXPERIMENTAL = True
OUTPUT_NODE = True
def save(self, lora, prefix, steps=None):
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(prefix, self.output_dir)
if steps is None:
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
else:
output_checkpoint = f"{filename}_{steps}_steps_{counter:05}_.safetensors"
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
safetensors.torch.save_file(lora, output_checkpoint)
return {}
class LossGraphNode:
def __init__(self):
self.output_dir = folder_paths.get_temp_directory()
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"loss": (IO.LOSS_MAP, {"default": {}}),
"filename_prefix": (IO.STRING, {"default": "loss_graph"}),
},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
RETURN_TYPES = ()
FUNCTION = "plot_loss"
OUTPUT_NODE = True
CATEGORY = "training"
EXPERIMENTAL = True
DESCRIPTION = "Plots the loss graph and saves it to the output directory."
def plot_loss(self, loss, filename_prefix, prompt=None, extra_pnginfo=None):
loss_values = loss["loss"]
width, height = 800, 480
margin = 40
img = Image.new(
"RGB", (width + margin, height + margin), "white"
) # Extend canvas
draw = ImageDraw.Draw(img)
min_loss, max_loss = min(loss_values), max(loss_values)
scaled_loss = [(l - min_loss) / (max_loss - min_loss) for l in loss_values]
steps = len(loss_values)
prev_point = (margin, height - int(scaled_loss[0] * height))
for i, l in enumerate(scaled_loss[1:], start=1):
x = margin + int(i / steps * width) # Scale X properly
y = height - int(l * height)
draw.line([prev_point, (x, y)], fill="blue", width=2)
prev_point = (x, y)
draw.line([(margin, 0), (margin, height)], fill="black", width=2) # Y-axis
draw.line(
[(margin, height), (width + margin, height)], fill="black", width=2
) # X-axis
font = None
try:
font = ImageFont.truetype("arial.ttf", 12)
except IOError:
font = ImageFont.load_default()
# Add axis labels
draw.text((5, height // 2), "Loss", font=font, fill="black")
draw.text((width // 2, height + 10), "Steps", font=font, fill="black")
# Add min/max loss values
draw.text((margin - 30, 0), f"{max_loss:.2f}", font=font, fill="black")
draw.text(
(margin - 30, height - 10), f"{min_loss:.2f}", font=font, fill="black"
)
metadata = None
if not args.disable_metadata:
metadata = PngInfo()
if prompt is not None:
metadata.add_text("prompt", json.dumps(prompt))
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata.add_text(x, json.dumps(extra_pnginfo[x]))
date = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
img.save(
os.path.join(self.output_dir, f"{filename_prefix}_{date}.png"),
pnginfo=metadata,
)
return {
"ui": {
"images": [
{
"filename": f"{filename_prefix}_{date}.png",
"subfolder": "",
"type": "temp",
}
]
}
}
NODE_CLASS_MAPPINGS = {
"TrainLoraNode": TrainLoraNode,
"SaveLoRANode": SaveLoRA,
"LoraModelLoader": LoraModelLoader,
"LoadImageSetFromFolderNode": LoadImageSetFromFolderNode,
"LossGraphNode": LossGraphNode,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"TrainLoraNode": "Train LoRA",
"SaveLoRANode": "Save LoRA Weights",
"LoraModelLoader": "Load LoRA Model",
"LoadImageSetFromFolderNode": "Load Image Dataset from Folder",
"LossGraphNode": "Plot Loss Graph",
}

View File

@@ -268,8 +268,9 @@ class WanVaceToVideo:
trim_latent = reference_image.shape[2]
mask = mask.unsqueeze(0)
positive = node_helpers.conditioning_set_values(positive, {"vace_frames": control_video_latent, "vace_mask": mask, "vace_strength": strength})
negative = node_helpers.conditioning_set_values(negative, {"vace_frames": control_video_latent, "vace_mask": mask, "vace_strength": strength})
positive = node_helpers.conditioning_set_values(positive, {"vace_frames": [control_video_latent], "vace_mask": [mask], "vace_strength": [strength]}, append=True)
negative = node_helpers.conditioning_set_values(negative, {"vace_frames": [control_video_latent], "vace_mask": [mask], "vace_strength": [strength]}, append=True)
latent = torch.zeros([batch_size, 16, latent_length, height // 8, width // 8], device=comfy.model_management.intermediate_device())
out_latent = {}
@@ -344,6 +345,44 @@ class WanCameraImageToVideo:
out_latent["samples"] = latent
return (positive, negative, out_latent)
class WanPhantomSubjectToVideo:
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"vae": ("VAE", ),
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
},
"optional": {"images": ("IMAGE", ),
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative_text", "negative_img_text", "latent")
FUNCTION = "encode"
CATEGORY = "conditioning/video_models"
def encode(self, positive, negative, vae, width, height, length, batch_size, images):
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
cond2 = negative
if images is not None:
images = comfy.utils.common_upscale(images[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
latent_images = []
for i in images:
latent_images += [vae.encode(i.unsqueeze(0)[:, :, :, :3])]
concat_latent_image = torch.cat(latent_images, dim=2)
positive = node_helpers.conditioning_set_values(positive, {"time_dim_concat": concat_latent_image})
cond2 = node_helpers.conditioning_set_values(negative, {"time_dim_concat": concat_latent_image})
negative = node_helpers.conditioning_set_values(negative, {"time_dim_concat": comfy.latent_formats.Wan21().process_out(torch.zeros_like(concat_latent_image))})
out_latent = {}
out_latent["samples"] = latent
return (positive, cond2, negative, out_latent)
NODE_CLASS_MAPPINGS = {
"WanImageToVideo": WanImageToVideo,
"WanFunControlToVideo": WanFunControlToVideo,
@@ -352,4 +391,5 @@ NODE_CLASS_MAPPINGS = {
"WanVaceToVideo": WanVaceToVideo,
"TrimVideoLatent": TrimVideoLatent,
"WanCameraImageToVideo": WanCameraImageToVideo,
"WanPhantomSubjectToVideo": WanPhantomSubjectToVideo,
}

View File

@@ -23,6 +23,10 @@ class WebcamCapture(nodes.LoadImage):
def load_capture(self, image, **kwargs):
return super().load_image(folder_paths.get_annotated_filepath(image))
@classmethod
def IS_CHANGED(cls, image, width, height, capture_on_queue):
return super().IS_CHANGED(image)
NODE_CLASS_MAPPINGS = {
"WebcamCapture": WebcamCapture,

View File

@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.3.34"
__version__ = "0.3.43"

View File

@@ -1,23 +1,35 @@
import sys
import copy
import logging
import threading
import heapq
import inspect
import logging
import sys
import threading
import time
import traceback
from enum import Enum
import inspect
from typing import List, Literal, NamedTuple, Optional
import torch
import nodes
import comfy.model_management
from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
from comfy_execution.graph_utils import is_link, GraphBuilder
from comfy_execution.caching import HierarchicalCache, LRUCache, DependencyAwareCache, CacheKeySetInputSignature, CacheKeySetID
import nodes
from comfy_execution.caching import (
CacheKeySetID,
CacheKeySetInputSignature,
DependencyAwareCache,
HierarchicalCache,
LRUCache,
)
from comfy_execution.graph import (
DynamicPrompt,
ExecutionBlocker,
ExecutionList,
get_input_info,
)
from comfy_execution.graph_utils import GraphBuilder, is_link
from comfy_execution.validation import validate_node_input
class ExecutionResult(Enum):
SUCCESS = 0
FAILURE = 1
@@ -417,17 +429,20 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
logging.error(f"!!! Exception during processing !!! {ex}")
logging.error(traceback.format_exc())
tips = ""
if isinstance(ex, comfy.model_management.OOM_EXCEPTION):
tips = "This error means you ran out of memory on your GPU.\n\nTIPS: If the workflow worked before you might have accidentally set the batch_size to a large number."
logging.error("Got an OOM, unloading all loaded models.")
comfy.model_management.unload_all_models()
error_details = {
"node_id": real_node_id,
"exception_message": str(ex),
"exception_message": "{}\n{}".format(ex, tips),
"exception_type": exception_type,
"traceback": traceback.format_tb(tb),
"current_inputs": input_data_formatted
}
if isinstance(ex, comfy.model_management.OOM_EXCEPTION):
logging.error("Got an OOM, unloading all loaded models.")
comfy.model_management.unload_all_models()
return (ExecutionResult.FAILURE, error_details, ex)
@@ -909,7 +924,6 @@ class PromptQueue:
self.currently_running = {}
self.history = {}
self.flags = {}
server.prompt_queue = self
def put(self, item):
with self.mutex:
@@ -954,6 +968,7 @@ class PromptQueue:
self.history[prompt[1]].update(history_result)
self.server.queue_updated()
# Note: slow
def get_current_queue(self):
with self.mutex:
out = []
@@ -961,6 +976,13 @@ class PromptQueue:
out += [x]
return (out, copy.deepcopy(self.queue))
# read-safe as long as queue items are immutable
def get_current_queue_volatile(self):
with self.mutex:
running = [x for x in self.currently_running.values()]
queued = copy.copy(self.queue)
return (running, queued)
def get_tasks_remaining(self):
with self.mutex:
return len(self.queue) + len(self.currently_running)

View File

@@ -276,6 +276,9 @@ def filter_files_extensions(files: Collection[str], extensions: Collection[str])
def get_full_path(folder_name: str, filename: str) -> str | None:
"""
Get the full path of a file in a folder, has to be a file
"""
global folder_names_and_paths
folder_name = map_legacy(folder_name)
if folder_name not in folder_names_and_paths:
@@ -293,6 +296,9 @@ def get_full_path(folder_name: str, filename: str) -> str | None:
def get_full_path_or_raise(folder_name: str, filename: str) -> str:
"""
Get the full path of a file in a folder, has to be a file
"""
full_path = get_full_path(folder_name, filename)
if full_path is None:
raise FileNotFoundError(f"Model in folder '{folder_name}' with filename '{filename}' not found.")
@@ -394,3 +400,26 @@ def get_save_image_path(filename_prefix: str, output_dir: str, image_width=0, im
os.makedirs(full_output_folder, exist_ok=True)
counter = 1
return full_output_folder, filename, counter, subfolder, filename_prefix
def get_input_subfolders() -> list[str]:
"""Returns a list of all subfolder paths in the input directory, recursively.
Returns:
List of folder paths relative to the input directory, excluding the root directory
"""
input_dir = get_input_directory()
folders = []
try:
if not os.path.exists(input_dir):
return []
for root, dirs, _ in os.walk(input_dir):
rel_path = os.path.relpath(root, input_dir)
if rel_path != ".": # Only include non-root directories
# Normalize path separators to forward slashes
folders.append(rel_path.replace(os.sep, '/'))
return sorted(folders)
except FileNotFoundError:
return []

25
main.py
View File

@@ -17,7 +17,6 @@ if __name__ == "__main__":
os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
os.environ['DO_NOT_TRACK'] = '1'
setup_logger(log_level=args.verbose, use_stdout=args.log_stdout)
def apply_custom_paths():
@@ -186,7 +185,13 @@ def prompt_worker(q, server_instance):
current_time = time.perf_counter()
execution_time = current_time - execution_start_time
logging.info("Prompt executed in {:.2f} seconds".format(execution_time))
# Log Time in a more readable way after 10 minutes
if execution_time > 600:
execution_time = time.strftime("%H:%M:%S", time.gmtime(execution_time))
logging.info(f"Prompt executed in {execution_time}")
else:
logging.info("Prompt executed in {:.2f} seconds".format(execution_time))
flags = q.get_flags()
free_memory = flags.get("free_memory", False)
@@ -238,6 +243,15 @@ def cleanup_temp():
shutil.rmtree(temp_dir, ignore_errors=True)
def setup_database():
try:
from app.database.db import init_db, dependencies_available
if dependencies_available():
init_db()
except Exception as e:
logging.error(f"Failed to initialize database. Please ensure you have installed the latest requirements. If the error persists, please report this as in future the database will be required: {e}")
def start_comfyui(asyncio_loop=None):
"""
Starts the ComfyUI server using the provided asyncio event loop or creates a new one.
@@ -260,18 +274,18 @@ def start_comfyui(asyncio_loop=None):
asyncio_loop = asyncio.new_event_loop()
asyncio.set_event_loop(asyncio_loop)
prompt_server = server.PromptServer(asyncio_loop)
q = execution.PromptQueue(prompt_server)
hook_breaker_ac10a0.save_functions()
nodes.init_extra_nodes(init_custom_nodes=not args.disable_all_custom_nodes, init_api_nodes=not args.disable_api_nodes)
hook_breaker_ac10a0.restore_functions()
cuda_malloc_warning()
setup_database()
prompt_server.add_routes()
hijack_progress(prompt_server)
threading.Thread(target=prompt_worker, daemon=True, args=(q, prompt_server,)).start()
threading.Thread(target=prompt_worker, daemon=True, args=(prompt_server.prompt_queue, prompt_server,)).start()
if args.quick_test_for_ci:
exit(0)
@@ -301,6 +315,9 @@ if __name__ == "__main__":
logging.info("Python version: {}".format(sys.version))
logging.info("ComfyUI version: {}".format(comfyui_version.__version__))
if sys.version_info.major == 3 and sys.version_info.minor < 10:
logging.warning("WARNING: You are using a python version older than 3.10, please upgrade to a newer one. 3.12 and above is recommended.")
event_loop, _, start_all_func = start_comfyui()
try:
x = start_all_func()

View File

@@ -5,12 +5,18 @@ from comfy.cli_args import args
from PIL import ImageFile, UnidentifiedImageError
def conditioning_set_values(conditioning, values={}):
def conditioning_set_values(conditioning, values={}, append=False):
c = []
for t in conditioning:
n = [t[0], t[1].copy()]
for k in values:
n[1][k] = values[k]
val = values[k]
if append:
old_val = n[1].get(k, None)
if old_val is not None:
val = old_val + val
n[1][k] = val
c.append(n)
return c

View File

@@ -920,7 +920,7 @@ class CLIPLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace"], ),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2"], ),
},
"optional": {
"device": (["default", "cpu"], {"advanced": True}),
@@ -930,7 +930,7 @@ class CLIPLoader:
CATEGORY = "advanced/loaders"
DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\n hidream: llama-3.1 (Recommend) or t5"
DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\n hidream: llama-3.1 (Recommend) or t5\nomnigen2: qwen vl 2.5 3B"
def load_clip(self, clip_name, type="stable_diffusion", device="default"):
clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)
@@ -1103,16 +1103,7 @@ class unCLIPConditioning:
if strength == 0:
return (conditioning, )
c = []
for t in conditioning:
o = t[1].copy()
x = {"clip_vision_output": clip_vision_output, "strength": strength, "noise_augmentation": noise_augmentation}
if "unclip_conditioning" in o:
o["unclip_conditioning"] = o["unclip_conditioning"][:] + [x]
else:
o["unclip_conditioning"] = [x]
n = [t[0], o]
c.append(n)
c = node_helpers.conditioning_set_values(conditioning, {"unclip_conditioning": [{"clip_vision_output": clip_vision_output, "strength": strength, "noise_augmentation": noise_augmentation}]}, append=True)
return (c, )
class GLIGENLoader:
@@ -2070,11 +2061,13 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"ImagePadForOutpaint": "Pad Image for Outpainting",
"ImageBatch": "Batch Images",
"ImageCrop": "Image Crop",
"ImageStitch": "Image Stitch",
"ImageBlend": "Image Blend",
"ImageBlur": "Image Blur",
"ImageQuantize": "Image Quantize",
"ImageSharpen": "Image Sharpen",
"ImageScaleToTotalPixels": "Scale Image to Total Pixels",
"GetImageSize": "Get Image Size",
# _for_testing
"VAEDecodeTiled": "VAE Decode (Tiled)",
"VAEEncodeTiled": "VAE Encode (Tiled)",
@@ -2132,6 +2125,25 @@ def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes
LOADED_MODULE_DIRS[module_name] = os.path.abspath(module_dir)
try:
from comfy_config import config_parser
project_config = config_parser.extract_node_configuration(module_path)
web_dir_name = project_config.tool_comfy.web
if web_dir_name:
web_dir_path = os.path.join(module_path, web_dir_name)
if os.path.isdir(web_dir_path):
project_name = project_config.project.name
EXTENSION_WEB_DIRS[project_name] = web_dir_path
logging.info("Automatically register web folder {} for {}".format(web_dir_name, project_name))
except Exception as e:
logging.warning(f"Unable to parse pyproject.toml due to lack dependency pydantic-settings, please run 'pip install -r requirements.txt': {e}")
if hasattr(module, "WEB_DIRECTORY") and getattr(module, "WEB_DIRECTORY") is not None:
web_dir = os.path.abspath(os.path.join(module_dir, getattr(module, "WEB_DIRECTORY")))
if os.path.isdir(web_dir):
@@ -2219,6 +2231,7 @@ def init_builtin_extra_nodes():
"nodes_model_downscale.py",
"nodes_images.py",
"nodes_video_model.py",
"nodes_train.py",
"nodes_sag.py",
"nodes_perpneg.py",
"nodes_stable3d.py",
@@ -2266,6 +2279,7 @@ def init_builtin_extra_nodes():
"nodes_ace.py",
"nodes_string.py",
"nodes_camera_trajectory.py",
"nodes_edit_model.py",
]
import_failed = []
@@ -2290,6 +2304,10 @@ def init_builtin_api_nodes():
"nodes_pixverse.py",
"nodes_stability.py",
"nodes_pika.py",
"nodes_runway.py",
"nodes_tripo.py",
"nodes_rodin.py",
"nodes_gemini.py",
]
if not load_custom_node(os.path.join(api_nodes_dir, "canary.py"), module_parent="comfy_api_nodes"):

View File

@@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.3.34"
version = "0.3.43"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.9"

View File

@@ -1,12 +1,13 @@
comfyui-frontend-package==1.19.9
comfyui-workflow-templates==0.1.14
comfyui-frontend-package==1.23.4
comfyui-workflow-templates==0.1.30
comfyui-embedded-docs==0.2.3
torch
torchsde
torchvision
torchaudio
numpy>=1.25.0
einops
transformers>=4.28.1
transformers>=4.37.2
tokenizers>=0.13.3
sentencepiece
safetensors>=0.4.2
@@ -17,6 +18,8 @@ Pillow
scipy
tqdm
psutil
alembic
SQLAlchemy
#non essential dependencies:
kornia>=0.7.1
@@ -24,3 +27,4 @@ spandrel
soundfile
av>=14.2.0
pydantic~=2.0
pydantic-settings~=2.0

View File

@@ -29,6 +29,7 @@ import comfy.model_management
import node_helpers
from comfyui_version import __version__
from app.frontend_management import FrontendManager
from app.user_manager import UserManager
from app.model_manager import ModelFileManager
from app.custom_node_manager import CustomNodeManager
@@ -159,7 +160,7 @@ class PromptServer():
self.custom_node_manager = CustomNodeManager()
self.internal_routes = InternalRoutes(self)
self.supports = ["custom_nodes_from_web"]
self.prompt_queue = None
self.prompt_queue = execution.PromptQueue(self)
self.loop = loop
self.messages = asyncio.Queue()
self.client_session:Optional[aiohttp.ClientSession] = None
@@ -226,7 +227,7 @@ class PromptServer():
return response
@routes.get("/embeddings")
def get_embeddings(self):
def get_embeddings(request):
embeddings = folder_paths.get_filename_list("embeddings")
return web.json_response(list(map(lambda a: os.path.splitext(a)[0], embeddings)))
@@ -282,7 +283,6 @@ class PromptServer():
a.update(f.read())
b.update(image.file.read())
image.file.seek(0)
f.close()
return a.hexdigest() == b.hexdigest()
return False
@@ -390,7 +390,7 @@ class PromptServer():
async def view_image(request):
if "filename" in request.rel_url.query:
filename = request.rel_url.query["filename"]
filename,output_dir = folder_paths.annotated_filepath(filename)
filename, output_dir = folder_paths.annotated_filepath(filename)
if not filename:
return web.Response(status=400)
@@ -476,9 +476,8 @@ class PromptServer():
# Get content type from mimetype, defaulting to 'application/octet-stream'
content_type = mimetypes.guess_type(filename)[0] or 'application/octet-stream'
# For security, force certain extensions to download instead of display
file_extension = os.path.splitext(filename)[1].lower()
if file_extension in {'.html', '.htm', '.js', '.css'}:
# For security, force certain mimetypes to download instead of display
if content_type in {'text/html', 'text/html-sandboxed', 'application/xhtml+xml', 'text/javascript', 'text/css'}:
content_type = 'application/octet-stream' # Forces download
return web.FileResponse(
@@ -621,7 +620,7 @@ class PromptServer():
@routes.get("/queue")
async def get_queue(request):
queue_info = {}
current_queue = self.prompt_queue.get_current_queue()
current_queue = self.prompt_queue.get_current_queue_volatile()
queue_info['queue_running'] = current_queue[0]
queue_info['queue_pending'] = current_queue[1]
return web.json_response(queue_info)
@@ -746,6 +745,13 @@ class PromptServer():
web.static('/templates', workflow_templates_path)
])
# Serve embedded documentation from the package
embedded_docs_path = FrontendManager.embedded_docs_path()
if embedded_docs_path:
self.app.add_routes([
web.static('/docs', embedded_docs_path)
])
self.app.add_routes([
web.static('/', self.web_root),
])
@@ -782,7 +788,7 @@ class PromptServer():
if hasattr(Image, 'Resampling'):
resampling = Image.Resampling.BILINEAR
else:
resampling = Image.ANTIALIAS
resampling = Image.Resampling.LANCZOS
image = ImageOps.contain(image, (max_size, max_size), resampling)
type_num = 1

View File

View File

@@ -0,0 +1,243 @@
import torch
from unittest.mock import patch, MagicMock
# Mock nodes module to prevent CUDA initialization during import
mock_nodes = MagicMock()
mock_nodes.MAX_RESOLUTION = 16384
# Mock server module for PromptServer
mock_server = MagicMock()
with patch.dict('sys.modules', {'nodes': mock_nodes, 'server': mock_server}):
from comfy_extras.nodes_images import ImageStitch
class TestImageStitch:
def create_test_image(self, batch_size=1, height=64, width=64, channels=3):
"""Helper to create test images with specific dimensions"""
return torch.rand(batch_size, height, width, channels)
def test_no_image2_passthrough(self):
"""Test that when image2 is None, image1 is returned unchanged"""
node = ImageStitch()
image1 = self.create_test_image()
result = node.stitch(image1, "right", True, 0, "white", image2=None)
assert len(result) == 1
assert torch.equal(result[0], image1)
def test_basic_horizontal_stitch_right(self):
"""Test basic horizontal stitching to the right"""
node = ImageStitch()
image1 = self.create_test_image(height=32, width=32)
image2 = self.create_test_image(height=32, width=24)
result = node.stitch(image1, "right", False, 0, "white", image2)
assert result[0].shape == (1, 32, 56, 3) # 32 + 24 width
def test_basic_horizontal_stitch_left(self):
"""Test basic horizontal stitching to the left"""
node = ImageStitch()
image1 = self.create_test_image(height=32, width=32)
image2 = self.create_test_image(height=32, width=24)
result = node.stitch(image1, "left", False, 0, "white", image2)
assert result[0].shape == (1, 32, 56, 3) # 24 + 32 width
def test_basic_vertical_stitch_down(self):
"""Test basic vertical stitching downward"""
node = ImageStitch()
image1 = self.create_test_image(height=32, width=32)
image2 = self.create_test_image(height=24, width=32)
result = node.stitch(image1, "down", False, 0, "white", image2)
assert result[0].shape == (1, 56, 32, 3) # 32 + 24 height
def test_basic_vertical_stitch_up(self):
"""Test basic vertical stitching upward"""
node = ImageStitch()
image1 = self.create_test_image(height=32, width=32)
image2 = self.create_test_image(height=24, width=32)
result = node.stitch(image1, "up", False, 0, "white", image2)
assert result[0].shape == (1, 56, 32, 3) # 24 + 32 height
def test_size_matching_horizontal(self):
"""Test size matching for horizontal concatenation"""
node = ImageStitch()
image1 = self.create_test_image(height=64, width=64)
image2 = self.create_test_image(height=32, width=32) # Different aspect ratio
result = node.stitch(image1, "right", True, 0, "white", image2)
# image2 should be resized to match image1's height (64) with preserved aspect ratio
expected_width = 64 + 64 # original + resized (32*64/32 = 64)
assert result[0].shape == (1, 64, expected_width, 3)
def test_size_matching_vertical(self):
"""Test size matching for vertical concatenation"""
node = ImageStitch()
image1 = self.create_test_image(height=64, width=64)
image2 = self.create_test_image(height=32, width=32)
result = node.stitch(image1, "down", True, 0, "white", image2)
# image2 should be resized to match image1's width (64) with preserved aspect ratio
expected_height = 64 + 64 # original + resized (32*64/32 = 64)
assert result[0].shape == (1, expected_height, 64, 3)
def test_padding_for_mismatched_heights_horizontal(self):
"""Test padding when heights don't match in horizontal concatenation"""
node = ImageStitch()
image1 = self.create_test_image(height=64, width=32)
image2 = self.create_test_image(height=48, width=24) # Shorter height
result = node.stitch(image1, "right", False, 0, "white", image2)
# Both images should be padded to height 64
assert result[0].shape == (1, 64, 56, 3) # 32 + 24 width, max(64,48) height
def test_padding_for_mismatched_widths_vertical(self):
"""Test padding when widths don't match in vertical concatenation"""
node = ImageStitch()
image1 = self.create_test_image(height=32, width=64)
image2 = self.create_test_image(height=24, width=48) # Narrower width
result = node.stitch(image1, "down", False, 0, "white", image2)
# Both images should be padded to width 64
assert result[0].shape == (1, 56, 64, 3) # 32 + 24 height, max(64,48) width
def test_spacing_horizontal(self):
"""Test spacing addition in horizontal concatenation"""
node = ImageStitch()
image1 = self.create_test_image(height=32, width=32)
image2 = self.create_test_image(height=32, width=24)
spacing_width = 16
result = node.stitch(image1, "right", False, spacing_width, "white", image2)
# Expected width: 32 + 16 (spacing) + 24 = 72
assert result[0].shape == (1, 32, 72, 3)
def test_spacing_vertical(self):
"""Test spacing addition in vertical concatenation"""
node = ImageStitch()
image1 = self.create_test_image(height=32, width=32)
image2 = self.create_test_image(height=24, width=32)
spacing_width = 16
result = node.stitch(image1, "down", False, spacing_width, "white", image2)
# Expected height: 32 + 16 (spacing) + 24 = 72
assert result[0].shape == (1, 72, 32, 3)
def test_spacing_color_values(self):
"""Test that spacing colors are applied correctly"""
node = ImageStitch()
image1 = self.create_test_image(height=32, width=32)
image2 = self.create_test_image(height=32, width=32)
# Test white spacing
result_white = node.stitch(image1, "right", False, 16, "white", image2)
# Check that spacing region contains white values (close to 1.0)
spacing_region = result_white[0][:, :, 32:48, :] # Middle 16 pixels
assert torch.all(spacing_region >= 0.9) # Should be close to white
# Test black spacing
result_black = node.stitch(image1, "right", False, 16, "black", image2)
spacing_region = result_black[0][:, :, 32:48, :]
assert torch.all(spacing_region <= 0.1) # Should be close to black
def test_odd_spacing_width_made_even(self):
"""Test that odd spacing widths are made even"""
node = ImageStitch()
image1 = self.create_test_image(height=32, width=32)
image2 = self.create_test_image(height=32, width=32)
# Use odd spacing width
result = node.stitch(image1, "right", False, 15, "white", image2)
# Should be made even (16), so total width = 32 + 16 + 32 = 80
assert result[0].shape == (1, 32, 80, 3)
def test_batch_size_matching(self):
"""Test that different batch sizes are handled correctly"""
node = ImageStitch()
image1 = self.create_test_image(batch_size=2, height=32, width=32)
image2 = self.create_test_image(batch_size=1, height=32, width=32)
result = node.stitch(image1, "right", False, 0, "white", image2)
# Should match larger batch size
assert result[0].shape == (2, 32, 64, 3)
def test_channel_matching_rgb_to_rgba(self):
"""Test that channel differences are handled (RGB + alpha)"""
node = ImageStitch()
image1 = self.create_test_image(channels=3) # RGB
image2 = self.create_test_image(channels=4) # RGBA
result = node.stitch(image1, "right", False, 0, "white", image2)
# Should have 4 channels (RGBA)
assert result[0].shape[-1] == 4
def test_channel_matching_rgba_to_rgb(self):
"""Test that channel differences are handled (RGBA + RGB)"""
node = ImageStitch()
image1 = self.create_test_image(channels=4) # RGBA
image2 = self.create_test_image(channels=3) # RGB
result = node.stitch(image1, "right", False, 0, "white", image2)
# Should have 4 channels (RGBA)
assert result[0].shape[-1] == 4
def test_all_color_options(self):
"""Test all available color options"""
node = ImageStitch()
image1 = self.create_test_image(height=32, width=32)
image2 = self.create_test_image(height=32, width=32)
colors = ["white", "black", "red", "green", "blue"]
for color in colors:
result = node.stitch(image1, "right", False, 16, color, image2)
assert result[0].shape == (1, 32, 80, 3) # Basic shape check
def test_all_directions(self):
"""Test all direction options"""
node = ImageStitch()
image1 = self.create_test_image(height=32, width=32)
image2 = self.create_test_image(height=32, width=32)
directions = ["right", "left", "up", "down"]
for direction in directions:
result = node.stitch(image1, direction, False, 0, "white", image2)
assert result[0].shape == (1, 32, 64, 3) if direction in ["right", "left"] else (1, 64, 32, 3)
def test_batch_size_channel_spacing_integration(self):
"""Test integration of batch matching, channel matching, size matching, and spacings"""
node = ImageStitch()
image1 = self.create_test_image(batch_size=2, height=64, width=48, channels=3)
image2 = self.create_test_image(batch_size=1, height=32, width=32, channels=4)
result = node.stitch(image1, "right", True, 8, "red", image2)
# Should handle: batch matching, size matching, channel matching, spacing
assert result[0].shape[0] == 2 # Batch size matched
assert result[0].shape[-1] == 4 # Channels matched to max
assert result[0].shape[1] == 64 # Height from image1 (size matching)
# Width should be: 48 + 8 (spacing) + resized_image2_width
expected_image2_width = int(64 * (32/32)) # Resized to height 64
expected_total_width = 48 + 8 + expected_image2_width
assert result[0].shape[2] == expected_total_width

View File

@@ -0,0 +1,51 @@
import pytest
import os
import tempfile
from folder_paths import get_input_subfolders, set_input_directory
@pytest.fixture(scope="module")
def mock_folder_structure():
with tempfile.TemporaryDirectory() as temp_dir:
# Create a nested folder structure
folders = [
"folder1",
"folder1/subfolder1",
"folder1/subfolder2",
"folder2",
"folder2/deep",
"folder2/deep/nested",
"empty_folder"
]
# Create the folders
for folder in folders:
os.makedirs(os.path.join(temp_dir, folder))
# Add some files to test they're not included
with open(os.path.join(temp_dir, "root_file.txt"), "w") as f:
f.write("test")
with open(os.path.join(temp_dir, "folder1", "test.txt"), "w") as f:
f.write("test")
set_input_directory(temp_dir)
yield temp_dir
def test_gets_all_folders(mock_folder_structure):
folders = get_input_subfolders()
expected = ["folder1", "folder1/subfolder1", "folder1/subfolder2",
"folder2", "folder2/deep", "folder2/deep/nested", "empty_folder"]
assert sorted(folders) == sorted(expected)
def test_handles_nonexistent_input_directory():
with tempfile.TemporaryDirectory() as temp_dir:
nonexistent = os.path.join(temp_dir, "nonexistent")
set_input_directory(nonexistent)
assert get_input_subfolders() == []
def test_empty_input_directory():
with tempfile.TemporaryDirectory() as temp_dir:
set_input_directory(temp_dir)
assert get_input_subfolders() == [] # Empty since we don't include root

18
utils/install_util.py Normal file
View File

@@ -0,0 +1,18 @@
from pathlib import Path
import sys
# The path to the requirements.txt file
requirements_path = Path(__file__).parents[1] / "requirements.txt"
def get_missing_requirements_message():
"""The warning message to display when a package is missing."""
extra = ""
if sys.flags.no_user_site:
extra = "-s "
return f"""
Please install the updated requirements.txt file by running:
{sys.executable} {extra}-m pip install -r {requirements_path}
If you are on the portable package you can run: update\\update_comfyui.bat to solve this problem.
""".strip()