Compare commits

...

178 Commits

Author SHA1 Message Date
Jedrzej Kosinski
48deb15c0e Simplify multigpu dispatch: run all devices on pool threads (#13340)
Benchmarked hybrid (main thread + pool) vs all-pool on 2x RTX 4090
with SD1.5 and NetaYume models. No meaningful performance difference
(within noise). All-pool is simpler: eliminates the main_device
special case, main_batch_tuple deferred execution, and the 3-way
branch in the dispatch loop.
2026-04-09 01:15:57 -07:00
Jedrzej Kosinski
4b93c4360f Implement persistent thread pool for multi-GPU CFG splitting (#13329)
Replace per-step thread create/destroy in _calc_cond_batch_multigpu with a
persistent MultiGPUThreadPool. Each worker thread calls torch.cuda.set_device()
once at startup, preserving compiled kernel caches across diffusion steps.

- Add MultiGPUThreadPool class in comfy/multigpu.py
- Create pool in CFGGuider.outer_sample(), shut down in finally block
- Main thread handles its own device batch directly for zero overhead
- Falls back to sequential execution if no pool is available
2026-04-08 05:39:07 -07:00
Jedrzej Kosinski
da3864436c Merge remote-tracking branch 'origin/master' into worksplit-multigpu 2026-04-08 05:08:38 -07:00
huemin
b615af1c65 Add support for small flux.2 decoder (#13314) 2026-04-07 03:44:18 -04:00
comfyanonymous
40862c0776 Support Ace Step 1.5 XL model. (#13317) 2026-04-07 03:13:47 -04:00
Terry Jia
50076f3439 format blueprint (#13315)
Co-authored-by: guill <jacob.e.segal@gmail.com>
2026-04-06 23:33:55 -04:00
comfyanonymous
61c2387436 Ace step empty latent nodes follow intermediate dtype. (#13313) 2026-04-06 18:12:16 -07:00
Terry Jia
7083484a48 image histogram node (#13153)
* image histogram node

* update color curve blueprint using image histogram node

---------

Co-authored-by: guill <jacob.e.segal@gmail.com>
2026-04-06 14:54:02 -07:00
comfyanonymous
4b1444fc7a Update README.md with new frontend release cycle. (#13301) 2026-04-05 16:37:27 -07:00
Daxiong (Lin)
8cbbea8f6a chore: update workflow templates to v0.9.44 (#13290) 2026-04-05 13:31:11 +08:00
comfyanonymous
13917b3880 Nightly Nvidia pytorch is now cu132 (#13288) 2026-04-04 16:02:47 -07:00
comfyanonymous
f21f6b2212 Add portable release for intel XPU. (#13272) 2026-04-03 15:29:06 -04:00
Daxiong (Lin)
eb0686bbb6 Update template to 0.9.43 (#13265) 2026-04-02 23:52:10 -07:00
Alexander Piskun
5de94e70ec feat(api-nodes): new Partner nodes for Wan2.7 (#13264)
Signed-off-by: bigcat88 <bigcat88@icloud.com>
2026-04-02 23:51:47 -07:00
comfyanonymous
76b75f3ad7 Fix some issue with insecure browsers. (#13261)
If you are on a recent chromium or chrome based browser this doesn't affect you.

This is to give time for the lazy firefox devs to implement PNA.
2026-04-02 16:39:34 -04:00
comfyanonymous
0c63b4f6e3 Remove dead code. (#13251) 2026-04-01 20:22:06 -04:00
Daxiong (Lin)
7d437687c2 chore: update workflow templates to v0.9.41 (#13242) 2026-03-31 20:23:25 -07:00
comfyanonymous
e2ddf28d78 Fix some fp8 scaled checkpoints no longer working. (#13239) 2026-03-31 14:27:17 -07:00
comfyanonymous
076639fed9 Update README with note on model support (#13235)
Added note about additional supported models in ComfyUI.
2026-03-30 23:11:02 -04:00
Jedrzej Kosinski
b418fb1582 Fix device mismatch: update LoadedModel.device when _switch_parent swaps to parent patcher
When a multigpu clone ModelPatcher is garbage collected, LoadedModel._switch_parent
switches the weakref to point at the parent (main) ModelPatcher. However, it was not
updating LoadedModel.device, leaving it with the old clone's device (e.g., cuda:1).
On subsequent runs, this stale device was passed to ModelPatcherDynamic.load(), causing
an assertion failure (device_to != self.load_device).

Amp-Thread-ID: https://ampcode.com/threads/T-019d3f5c-28c5-72c9-abed-34681f1b54ba
Co-authored-by: Amp <amp@ampcode.com>
2026-03-30 08:59:38 -07:00
Jedrzej Kosinski
20803749c3 Add detailed multigpu debug logging to load_models_gpu
Amp-Thread-ID: https://ampcode.com/threads/T-019d3f5c-28c5-72c9-abed-34681f1b54ba
Co-authored-by: Amp <amp@ampcode.com>
2026-03-30 08:53:36 -07:00
Jedrzej Kosinski
3fab720be9 Add debug logging for device mismatch in ModelPatcherDynamic.load
Amp-Thread-ID: https://ampcode.com/threads/T-019d3f5c-28c5-72c9-abed-34681f1b54ba
Co-authored-by: Amp <amp@ampcode.com>
2026-03-30 08:45:55 -07:00
Jedrzej Kosinski
afdddcee66 Re-enable comfy-kitchen cuda backend for multigpu testing
Amp-Thread-ID: https://ampcode.com/threads/T-019d3f5c-28c5-72c9-abed-34681f1b54ba
Co-authored-by: Amp <amp@ampcode.com>
2026-03-30 08:32:52 -07:00
Jedrzej Kosinski
1d8e379f41 Rename MultiGPU Work Units to MultiGPU CFG Split
Amp-Thread-ID: https://ampcode.com/threads/T-019d3ee9-19d5-767a-9d7a-e50cbbef815b
Co-authored-by: Amp <amp@ampcode.com>
2026-03-30 08:00:20 -07:00
Jedrzej Kosinski
5f4fcd19e7 Simplify multigpu nodes: default max_gpus=2, remove gpu_options input, disable Options node
Amp-Thread-ID: https://ampcode.com/threads/T-019d3ee9-19d5-767a-9d7a-e50cbbef815b
Co-authored-by: Amp <amp@ampcode.com>
2026-03-30 07:30:32 -07:00
Jedrzej Kosinski
d52dcbc88f Rewrite multigpu nodes to V3 format
Amp-Thread-ID: https://ampcode.com/threads/T-019d3ee9-19d5-767a-9d7a-e50cbbef815b
Co-authored-by: Amp <amp@ampcode.com>
2026-03-30 07:23:13 -07:00
Jedrzej Kosinski
84f465e791 Set CUDA device at start of multigpu threads to avoid multithreading bugs
Amp-Thread-ID: https://ampcode.com/threads/T-019d3ee9-19d5-767a-9d7a-e50cbbef815b
Co-authored-by: Amp <amp@ampcode.com>
2026-03-30 07:07:54 -07:00
Jedrzej Kosinski
be35378986 Merge branch 'master' into worksplit-multigpu
Amp-Thread-ID: https://ampcode.com/threads/T-019d3ee9-19d5-767a-9d7a-e50cbbef815b
Co-authored-by: Amp <amp@ampcode.com>

# Conflicts:
#	comfy/samplers.py
2026-03-30 06:24:55 -07:00
Christian Byrne
55e6478526 Rename utils/string nodes with Text prefix and add search aliases (#13227)
Rename all 11 nodes in the utils/string category to include a "Text"
prefix for better discoverability and natural sorting. Regex nodes get
user-friendly names without "Regex" in the display name.

Renames:
- Concatenate → Text Concatenate
- Substring → Text Substring
- Length → Text Length
- Case Converter → Text Case Converter
- Trim → Text Trim
- Replace → Text Replace
- Contains → Text Contains
- Compare → Text Compare
- Regex Match → Text Match
- Regex Extract → Text Extract Substring
- Regex Replace → Text Replace (Regex)

All renamed nodes include their old display name as a search alias so
users can still find them by searching the original name. Regex nodes
also include "regex" as a search alias.
2026-03-29 21:02:44 -07:00
comfyanonymous
537c10d231 Update README.md with latest AMD Linux pytorch. (#13228) 2026-03-29 19:07:38 -07:00
rattus
8d723d2caa Fix/tweak pinned memory accounting (#13221)
* mm: Lower windows pin threshold

Some workflows have more extranous use of shared GPU memory than is
accounted for in the 5% pin headroom. Lower this for safety.

* mm: Remove pin count clearing threshold.

TOTAL_PINNED_MEMORY is shared between the legacy and aimdo pinning
systems, however this catch-all assumes only the legacy system exists.
Remove the catch-all as the PINNED_MEMORY buffer is coherent already.
2026-03-29 16:43:24 -07:00
Alexander Piskun
d113d1cc32 feat(api-nodes-Tencent3D): allow smaller possible face_count; add uv_image output (#13207)
Signed-off-by: bigcat88 <bigcat88@icloud.com>
2026-03-29 14:11:30 -07:00
Jukka Seppänen
a500f1edac CORE-13 feat: Support RT-DETRv4 detection model (#12748) 2026-03-28 23:34:10 -04:00
comfyanonymous
3f77450ef1 Fix #13214 (#13216) 2026-03-28 22:35:59 -04:00
Terry Jia
fc1fdf3389 fix: avoid nested sampler function calls in Color Curves shader (#13209) 2026-03-28 13:13:05 -04:00
rattus
b353a7c863 Integrate RAM cache with model RAM management (#13173) 2026-03-27 21:34:16 -04:00
Terry Jia
3696c5bad6 Add has_intermediate_output flag for nodes with interactive UI (#13048) 2026-03-27 21:06:38 -04:00
comfyanonymous
3a56201da5 Allow flux conditioning without a pooled output. (#13198) 2026-03-27 20:36:26 -04:00
Alexander Piskun
6a2cdb817d fix(api-nodes-nanobana): raise error when not output image is present (#13167)
Signed-off-by: bigcat88 <bigcat88@icloud.com>
2026-03-27 12:11:41 -07:00
ComfyUI Wiki
85b7495135 chore: update workflow templates to v0.9.39 (#13196) 2026-03-27 10:13:02 -07:00
Jin Yi
225c52f6a4 fix: register image/svg+xml MIME type for .svg files (#13186)
The /view endpoint returns text/plain for .svg files on some platforms
because Python's mimetypes module does not always include SVG by default.
Explicitly register image/svg+xml so <img> tags can render SVGs correctly.

Amp-Thread-ID: https://ampcode.com/threads/T-019d2da7-6a64-726a-af91-bd9c44e7f43c
2026-03-26 22:13:29 -07:00
comfyanonymous
b1fdbeb9a7 Fix blur and sharpen nodes not working with fp16 intermediates. (#13181) 2026-03-26 22:18:16 -04:00
Terry Jia
1dc64f3526 feat: add curve inputs and raise uniform limit for GLSL shader node (#13158)
* feat: add curve inputs and raise uniform limit for GLSL shader node

* allow arbitrary size for curve
2026-03-26 21:45:05 -04:00
ComfyUI Wiki
359559c913 chore: update workflow templates to v0.9.38 (#13176) 2026-03-26 12:07:38 -07:00
Alexander Piskun
8165485a17 feat(api-nodes): added new Topaz model (#13175)
Signed-off-by: bigcat88 <bigcat88@icloud.com>
2026-03-26 12:02:04 -07:00
Jukka Seppänen
b0fd65e884 fix: regression in text generate with LTXAV model (#13170) 2026-03-26 09:55:05 -07:00
comfyanonymous
2a1f402601 Make Qwen 8B work with TextGenerate node. (#13160) 2026-03-25 23:21:44 -04:00
Luke Mino-Altherr
3eba2dcf2d fix(assets): recognize temp directory in asset category resolution (#13159) 2026-03-25 19:59:59 -07:00
Jukka Seppänen
404d7b9978 feat: Support Qwen3.5 text generation models (#12771) 2026-03-25 22:48:28 -04:00
Dante
6580a6bc01 fix(number-convert): preserve int precision for large numbers (#13147) 2026-03-25 18:06:34 -04:00
Dr.Lt.Data
3b15651bc6 bump manager version to 4.1 (#13156) 2026-03-25 16:49:29 -04:00
Alexander Piskun
a55835f10c fix(api-nodes): made Reve node price badges more precise (#13154)
Signed-off-by: bigcat88 <bigcat88@icloud.com>
2026-03-25 11:05:49 -07:00
Krishna Chaitanya
b53b10ea61 Fix Train LoRA crash when training_dtype is "none" with bfloat16 LoRA weights (#13145)
When training_dtype is set to "none" and the model's native dtype is
float16, GradScaler was unconditionally enabled. However, GradScaler
does not support bfloat16 gradients (only float16/float32), causing a
NotImplementedError when lora_dtype is "bf16" (the default).

Fix by only enabling GradScaler when LoRA parameters are not in
bfloat16, since bfloat16 has the same exponent range as float32 and
does not need gradient scaling to avoid underflow.

Fixes #13124
2026-03-24 23:53:44 -04:00
Luke Mino-Altherr
7d5534d8e5 feat(assets): register output files as assets after prompt execution (#12812) 2026-03-24 20:48:55 -07:00
Kohaku-Blueleaf
5ebb0c2e0b FP8 bwd training (#13121) 2026-03-24 20:39:04 -04:00
Dante
a0a64c679f Add Number Convert node (#13041)
* Add Number Convert node for unified numeric type conversion

Consolidates fragmented IntToFloat/FloatToInt nodes (previously only
available via third-party packs like ComfyMath, FillNodes, etc.) into
a single core node.

- Single input accepting INT, FLOAT, STRING, and BOOL types
- Two outputs: FLOAT and INT
- Conversion: bool→0/1, string→parsed number, float↔int standard cast
- Follows Math Expression node patterns (comfy_api, io.Schema, etc.)

Refs: COM-16925

* Register nodes_number_convert.py in extras_files list

Without this entry in nodes.py, the Number Convert node file
would not be discovered and loaded at startup.

* Add isfinite guard, exception chaining, and unit tests for Number Convert node

- Add math.isfinite() check to prevent int() crash on inf/nan string inputs
- Use 'from None' for cleaner exception chaining on string parse failure
- Add 21 unit tests covering all input types and error paths
2026-03-24 15:38:08 -07:00
Terry Jia
8e73678dae CURVE node (#12757)
* CURVE node

* remove curve to sigmas node

* feat: add CurveInput ABC with MonotoneCubicCurve implementation (#12986)

CurveInput is an abstract base class so future curve representations
(bezier, LUT-based, analytical functions) can be added without breaking
downstream nodes that type-check against CurveInput.

MonotoneCubicCurve is the concrete implementation that:
- Mirrors frontend createMonotoneInterpolator (curveUtils.ts) exactly
- Pre-computes slopes as numpy arrays at construction time
- Provides vectorised interp_array() using numpy for batch evaluation
- interp() for single-value evaluation
- to_lut() for generating lookup tables

CurveEditor node wraps raw widget points in MonotoneCubicCurve.

* linear curve

* refactor: move CurveEditor to comfy_extras/nodes_curve.py with V3 schema

* feat: add HISTOGRAM type and histogram support to CurveEditor

* code improve

---------

Co-authored-by: Christian Byrne <cbyrne@comfy.org>
2026-03-24 17:47:28 -04:00
comfyanonymous
c2862b24af Update templates package version. (#13141) 2026-03-24 17:36:12 -04:00
Alexander Piskun
f9ec85f739 feat(api-nodes): update xAI Grok nodes (#13140) 2026-03-24 13:27:39 -07:00
Kelly Yang
2d5fd3f5dd fix: set default values of Color Adjustment node to zero (#13084)
Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2026-03-24 14:22:30 -04:00
comfyanonymous
2d4970ff67 Update frontend version to 1.42.8 (#13126) 2026-03-23 20:43:41 -04:00
Jukka Seppänen
e87858e974 feat: LTX2: Support reference audio (ID-LoRA) (#13111) 2026-03-23 18:22:24 -04:00
Dr.Lt.Data
da6edb5a4e bump manager version to 4.1b8 (#13108) 2026-03-23 12:59:21 -04:00
comfyanonymous
6265a239f3 Add warning for users who disable dynamic vram. (#13113) 2026-03-22 18:46:18 -04:00
Talmaj
d49420b3c7 LongCat-Image edit (#13003) 2026-03-21 23:51:05 -04:00
Jedrzej Kosinski
f410d28b33 Merge origin/master into worksplit-multigpu
Amp-Thread-ID: https://ampcode.com/threads/T-019d009d-e059-7623-85ca-401168168516
Co-authored-by: Amp <amp@ampcode.com>
2026-03-18 04:21:30 -07:00
Jedrzej Kosinski
f4b99bc623 Made multigpu deepclone load model from disk to avoid needing to deepclone actual model object, fixed issues with merge, turn off cuda backend as it causes device mismatch issue with rope (and potentially other ops), will investigate 2026-02-17 04:55:00 -08:00
Jedrzej Kosinski
df2fd4c869 Merge branch 'master' into worksplit-multigpu 2026-02-17 02:53:06 -08:00
Jedrzej Kosinski
4661d1db5a Bring patches changes from _calc_cond_batch into _calc_cond_batch_multigpu 2025-10-15 17:34:36 -07:00
Jedrzej Kosinski
b326a544d5 Merge branch 'master' into worksplit-multigpu 2025-10-15 17:33:02 -07:00
Jedrzej Kosinski
d89dd5f0b0 Satisfy ruff 2025-10-13 22:00:34 -07:00
Jedrzej Kosinski
8cbbf0be6c Merge branch 'master' into worksplit-multigpu 2025-10-13 21:53:14 -07:00
Jedrzej Kosinski
c2115a4bac Merge branch 'master' into worksplit-multigpu 2025-09-24 23:45:26 -07:00
Jedrzej Kosinski
bb44c2ecb9 Merge branch 'master' into worksplit-multigpu 2025-09-18 14:20:27 -07:00
Jedrzej Kosinski
efcd8280d6 Merge branch 'master' into worksplit-multigpu 2025-09-11 20:59:47 -07:00
Jedrzej Kosinski
9e9c129cd0 Merge remote-tracking branch 'origin/master' into worksplit-multigpu 2025-08-29 23:36:19 -07:00
Jedrzej Kosinski
ac14ee68c0 Merge branch 'master' into worksplit-multigpu 2025-08-18 19:51:24 -07:00
Jedrzej Kosinski
2c8f485434 Merge branch 'master' into worksplit-multigpu 2025-08-18 00:29:52 -07:00
Jedrzej Kosinski
383f9b34cb Merge branch 'master' into worksplit-multigpu 2025-08-17 16:02:44 -07:00
Jedrzej Kosinski
b0741c7e5b Merge branch 'master' into worksplit-multigpu 2025-08-15 16:50:04 -07:00
Jedrzej Kosinski
1489399cb5 Merge branch 'master' into worksplit-multigpu 2025-08-13 19:47:08 -07:00
Jedrzej Kosinski
3677943fa5 Merge branch 'master' into worksplit-multigpu 2025-08-13 14:06:09 -07:00
Jedrzej Kosinski
cfb63bfcd7 Merge branch 'worksplit-multigpu' of https://github.com/comfyanonymous/ComfyUI into worksplit-multigpu 2025-08-11 14:09:58 -07:00
Jedrzej Kosinski
962c3c832c Merge branch 'master' into worksplit-multigpu 2025-08-11 14:09:41 -07:00
Jedrzej Kosinski
6ea69369ce Merge branch 'master' into worksplit-multigpu 2025-08-07 23:24:02 -07:00
Jedrzej Kosinski
b4f559b34d Merge branch 'master' into worksplit-multigpu 2025-08-04 20:23:19 -07:00
Jedrzej Kosinski
df122a7dba Merge branch 'master' into worksplit-multigpu 2025-08-01 12:31:57 -07:00
Jedrzej Kosinski
67e906aa64 Merge branch 'master' into worksplit-multigpu 2025-07-31 04:00:22 -07:00
Jedrzej Kosinski
382f84a826 Merge branch 'master' into worksplit-multigpu 2025-07-29 17:17:29 -07:00
Jedrzej Kosinski
9cca36fa2b Merge branch 'master' into worksplit-multigpu 2025-07-29 12:47:36 -07:00
Jedrzej Kosinski
5d5024296d Merge branch 'master' into worksplit-multigpu 2025-07-28 06:17:24 -07:00
Jedrzej Kosinski
3b90a30178 Merge branch 'master' into worksplit-multigpu-wip 2025-07-27 01:03:25 -07:00
Jedrzej Kosinski
3c4104652b Merge branch 'master' into worksplit-multigpu-wip 2025-07-22 11:42:23 -07:00
kosinkadink1@gmail.com
9855baaab3 Merge branch 'master' into worksplit-multigpu 2025-07-09 03:57:30 -05:00
Jedrzej Kosinski
d53479a197 Merge branch 'master' into worksplit-multigpu 2025-07-01 17:33:05 -05:00
Jedrzej Kosinski
443a795850 Merge branch 'master' into worksplit-multigpu 2025-06-24 00:49:24 -05:00
Jedrzej Kosinski
431dec8e53 Merge branch 'worksplit-multigpu' of https://github.com/comfyanonymous/ComfyUI into worksplit-multigpu 2025-06-24 00:48:58 -05:00
Jedrzej Kosinski
44e053c26d Improve error handling for multigpu threads 2025-06-24 00:48:51 -05:00
Jedrzej Kosinski
1ae98932f1 Merge branch 'master' into worksplit-multigpu 2025-06-17 04:58:56 -05:00
kosinkadink1@gmail.com
0336b0ace8 Merge branch 'master' into worksplit-multigpu 2025-06-01 02:39:26 -07:00
kosinkadink1@gmail.com
8ae25235ec Merge branch 'master' into worksplit-multigpu 2025-05-21 12:01:27 -07:00
Jedrzej Kosinski
9726eac475 Merge branch 'master' into worksplit-multigpu 2025-05-12 19:29:13 -05:00
Jedrzej Kosinski
272e8d42c1 Merge branch 'master' into worksplit-multigpu 2025-04-22 22:40:00 -05:00
Jedrzej Kosinski
6211d2be5a Merge branch 'master' into worksplit-multigpu 2025-04-19 17:36:23 -05:00
Jedrzej Kosinski
8be711715c Make unload_all_models account for all devices 2025-04-19 17:35:54 -05:00
Jedrzej Kosinski
b5cccf1325 Merge branch 'master' into worksplit-multigpu 2025-04-18 15:39:34 -05:00
Jedrzej Kosinski
2a54a904f4 Merge branch 'master' into worksplit-multigpu 2025-04-16 19:26:48 -05:00
Jedrzej Kosinski
ed6f92c975 Merge branch 'master' into worksplit-multigpu 2025-04-16 16:53:57 -05:00
Jedrzej Kosinski
adc66c0698 Merge branch 'master' into worksplit-multigpu 2025-04-16 14:23:56 -05:00
Jedrzej Kosinski
ccd5c01e5a Merge branch 'master' into worksplit-multigpu 2025-04-09 09:17:12 -05:00
Jedrzej Kosinski
2fa9affcc1 Merge branch 'master' into worksplit-multigpu 2025-04-08 22:52:17 -05:00
Jedrzej Kosinski
407a5a656f Rollback core of last commit due to weird behavior 2025-03-28 02:48:11 -05:00
kosinkadink1@gmail.com
9ce9ff8ef8 Allow chained MultiGPU Work Unit nodes to affect max_gpus present on ModelPatcher clone 2025-03-28 15:29:44 +08:00
Jedrzej Kosinski
63567c0ce8 Merge branch 'master' into worksplit-multigpu 2025-03-27 22:36:46 -05:00
Jedrzej Kosinski
a786ce5ead Merge branch 'master' into worksplit-multigpu 2025-03-26 22:26:26 -05:00
Jedrzej Kosinski
4879b47648 Merge branch 'master' into worksplit-multigpu 2025-03-18 22:19:32 -05:00
Jedrzej Kosinski
5ccec33c22 Merge branch 'worksplit-multigpu' of https://github.com/comfyanonymous/ComfyUI into worksplit-multigpu 2025-03-17 14:27:39 -05:00
Jedrzej Kosinski
219d3cd0d0 Merge branch 'master' into worksplit-multigpu 2025-03-17 14:26:35 -05:00
Jedrzej Kosinski
c4ba399475 Merge branch 'master' into worksplit-multigpu 2025-03-15 09:12:09 -05:00
Jedrzej Kosinski
cc928a786d Merge branch 'master' into worksplit-multigpu 2025-03-13 20:59:11 -05:00
Jedrzej Kosinski
6e144b98c4 Merge branch 'master' into worksplit-multigpu 2025-03-09 00:00:38 -06:00
Jedrzej Kosinski
6dca17bd2d Satisfy ruff linting 2025-03-03 23:08:29 -06:00
Jedrzej Kosinski
5080105c23 Merge branch 'master' into worksplit-multigpu 2025-03-03 22:56:53 -06:00
Jedrzej Kosinski
093914a247 Made MultiGPU Work Units node more robust by forcing ModelPatcher clones to match at sample time, reuse loaded MultiGPU clones, finalize MultiGPU Work Units node ID and name, small refactors/cleanup of logging and multigpu-related code 2025-03-03 22:56:13 -06:00
Jedrzej Kosinski
605893d3cf Merge branch 'master' into worksplit-multigpu 2025-02-24 19:23:16 -06:00
Jedrzej Kosinski
048f4f0b3a Merge branch 'master' into worksplit-multigpu 2025-02-17 19:35:58 -06:00
Jedrzej Kosinski
d2504fb701 Merge branch 'master' into worksplit-multigpu 2025-02-11 22:34:51 -06:00
Jedrzej Kosinski
b03763bca6 Merge branch 'multigpu_support' into worksplit-multigpu 2025-02-07 13:27:49 -06:00
Jedrzej Kosinski
476aa79b64 Let --cuda-device take in a string to allow multiple devices (or device order) to be chosen, print available devices on startup, potentially support MultiGPU Intel and Ascend setups 2025-02-06 08:44:07 -06:00
Jedrzej Kosinski
441cfd1a7a Merge branch 'master' into multigpu_support 2025-02-06 08:10:48 -06:00
Jedrzej Kosinski
99a5c1068a Merge branch 'master' into multigpu_support 2025-02-02 03:19:18 -06:00
Jedrzej Kosinski
02747cde7d Carry over change from _calc_cond_batch into _calc_cond_batch_multigpu 2025-01-29 11:10:23 -06:00
Jedrzej Kosinski
0b3233b4e2 Merge remote-tracking branch 'origin/master' into multigpu_support 2025-01-28 06:11:07 -06:00
Jedrzej Kosinski
eda866bf51 Extracted multigpu core code into multigpu.py, added load_balance_devices to get subdivision of work based on available devices and splittable work item count, added MultiGPU Options nodes to set relative_speed of specific devices; does not change behavior yet 2025-01-27 06:25:48 -06:00
Jedrzej Kosinski
e3298b84de Create proper MultiGPU Initialize node, create gpu_options to create scaffolding for asymmetrical GPU support 2025-01-26 09:34:20 -06:00
Jedrzej Kosinski
c7feef9060 Cast transformer_options for multigpu 2025-01-26 05:29:27 -06:00
Jedrzej Kosinski
51af7fa1b4 Fix multigpu ControlBase get_models and cleanup calls to avoid multiple calls of functions on multigpu_clones versions of controlnets 2025-01-25 06:05:01 -06:00
Jedrzej Kosinski
46969c380a Initial MultiGPU support for controlnets 2025-01-24 05:39:38 -06:00
Jedrzej Kosinski
5db4277449 Make sure additional_models are unloaded as well when perform 2025-01-23 19:06:05 -06:00
Jedrzej Kosinski
02a4d0ad7d Added unload_model_and_clones to model_management.py to allow unloading only relevant models 2025-01-23 01:20:00 -06:00
Jedrzej Kosinski
ef137ac0b6 Merge branch 'multigpu_support' of https://github.com/kosinkadink/ComfyUI into multigpu_support 2025-01-20 04:34:39 -06:00
Jedrzej Kosinski
328d4f16a9 Make WeightHooks compatible with MultiGPU, clean up some code 2025-01-20 04:34:26 -06:00
Jedrzej Kosinski
bdbcb85b8d Merge branch 'multigpu_support' of https://github.com/Kosinkadink/ComfyUI into multigpu_support 2025-01-20 00:51:42 -06:00
Jedrzej Kosinski
6c9e94bae7 Merge branch 'master' into multigpu_support 2025-01-20 00:51:37 -06:00
Jedrzej Kosinski
bfce723311 Initial work on multigpu_clone function, which will account for additional_models getting cloned 2025-01-17 03:31:28 -06:00
Jedrzej Kosinski
31f5458938 Merge branch 'master' into multigpu_support 2025-01-16 18:25:05 -06:00
Jedrzej Kosinski
2145a202eb Merge branch 'master' into multigpu_support 2025-01-15 19:58:28 -06:00
Jedrzej Kosinski
25818dc848 Added a 'max_gpus' input 2025-01-14 13:45:14 -06:00
Jedrzej Kosinski
198953cd08 Add nodes_multigpu.py to loaded nodes 2025-01-14 12:24:55 -06:00
Jedrzej Kosinski
ec16ee2f39 Merge branch 'master' into multigpu_support 2025-01-13 20:21:06 -06:00
Jedrzej Kosinski
d5088072fb Make test node for multigpu instead of storing it in just a local __init__.py 2025-01-13 20:20:25 -06:00
Jedrzej Kosinski
8d4b50158e Merge branch 'master' into multigpu_support 2025-01-11 20:16:42 -06:00
Jedrzej Kosinski
e88c6c03ff Fix cond_cat to not try to cast anything that doesn't have a 'to' function 2025-01-10 23:05:24 -06:00
Jedrzej Kosinski
d3cf2b7b24 Merge branch 'comfyanonymous:master' into multigpu_support 2025-01-10 20:24:37 -06:00
Jedrzej Kosinski
7448f02b7c Initial proof of concept of giving splitting cond sampling between multiple GPUs 2025-01-08 03:33:05 -06:00
Jedrzej Kosinski
871258aa72 Add get_all_torch_devices to get detected devices intended for current torch hardware device 2025-01-07 21:06:03 -06:00
Jedrzej Kosinski
66838ebd39 Merge branch 'comfyanonymous:master' into multigpu_support 2025-01-07 20:11:27 -06:00
Jedrzej Kosinski
7333281698 Clean up a typehint 2025-01-07 02:58:59 -06:00
Jedrzej Kosinski
3cd4c5cb0a Rename AddModelsHooks to AdditionalModelsHook, rename SetInjectionsHook to InjectionsHook (not yet implemented, but at least getting the naming figured out) 2025-01-07 02:22:49 -06:00
Jedrzej Kosinski
11c6d56037 Merge branch 'master' into hooks_part2 2025-01-07 01:01:53 -06:00
Jedrzej Kosinski
216fea15ee Made TransformerOptionsHook contribute to registered hooks properly, added some doc strings and removed a so-far unused variable 2025-01-07 00:59:18 -06:00
Jedrzej Kosinski
58bf8815c8 Add a get_injections function to ModelPatcher 2025-01-06 20:34:30 -06:00
Jedrzej Kosinski
1b38f5bf57 removed 4 whitespace lines to satisfy Ruff, 2025-01-06 17:11:12 -06:00
Jedrzej Kosinski
2724ac4a60 Merge branch 'master' into hooks_part2 2025-01-06 17:04:24 -06:00
Jedrzej Kosinski
f48f90e471 Make hook_scope functional for TransformerOptionsHook 2025-01-06 02:23:04 -06:00
Jedrzej Kosinski
6463c39ce0 Merge branch 'master' into hooks_part2 2025-01-06 01:28:26 -06:00
Jedrzej Kosinski
0a7e2ae787 Filter only registered hooks on self.conds in CFGGuider.sample 2025-01-06 01:04:29 -06:00
Jedrzej Kosinski
03a97b604a Fix performance of hooks when hooks are appended via Cond Pair Set Props nodes by properly caching between positive and negative conds, make hook_patches_backup behave as intended (in the case that something pre-registers WeightHooks on the ModelPatcher instead of registering it at sample time) 2025-01-06 01:03:59 -06:00
Jedrzej Kosinski
4446c86052 Made hook clone code sane, made clear ObjectPatchHook and SetInjectionsHook are not yet operational 2025-01-05 22:25:51 -06:00
Jedrzej Kosinski
8270ff312f Refactored 'registered' to be HookGroup instead of a list of Hooks, made AddModelsHook operational and compliant with should_register result, moved TransformerOptionsHook handling out of ModelPatcher.register_all_hook_patches, support patches in TransformerOptionsHook properly by casting any patches/wrappers/hooks to proper device at sample time 2025-01-05 21:07:02 -06:00
Jedrzej Kosinski
db2d7ad9ba Merge branch 'add_sample_sigmas' into hooks_part2 2025-01-05 15:45:13 -06:00
Jedrzej Kosinski
6620d86318 In inner_sample, change "sigmas" to "sampler_sigmas" in transformer_options to not conflict with the "sigmas" that will overwrite "sigmas" in _calc_cond_batch 2025-01-05 15:26:22 -06:00
Jedrzej Kosinski
111fd0cadf Refactored HookGroup to also store a dictionary of hooks separated by hook_type, modified necessary code to no longer need to manually separate out hooks by hook_type 2025-01-04 02:04:07 -06:00
Jedrzej Kosinski
776aa734e1 Refactor WrapperHook into TransformerOptionsHook, as there is no need to separate out Wrappers/Callbacks/Patches into different hook types (all affect transformer_options) 2025-01-04 01:02:21 -06:00
Jedrzej Kosinski
5a2ad032cb Cleaned up hooks.py, refactored Hook.should_register and add_hook_patches to use target_dict instead of target so that more information can be provided about the current execution environment if needed 2025-01-03 20:02:27 -06:00
Jedrzej Kosinski
d44295ef71 Merge branch 'master' into hooks_part2 2025-01-03 18:28:31 -06:00
Jedrzej Kosinski
bf21be066f Merge branch 'master' into hooks_part2 2024-12-30 14:16:22 -06:00
Jedrzej Kosinski
72bbf49349 Add 'sigmas' to transformer_options so that downstream code can know about the full scope of current sampling run, fix Hook Keyframes' guarantee_steps=1 inconsistent behavior with sampling split across different Sampling nodes/sampling runs by referencing 'sigmas' 2024-12-29 15:49:09 -06:00
123 changed files with 550979 additions and 446 deletions

View File

@@ -20,29 +20,12 @@ jobs:
git_tag: ${{ inputs.git_tag }}
cache_tag: "cu130"
python_minor: "13"
python_patch: "11"
python_patch: "12"
rel_name: "nvidia"
rel_extra_name: ""
test_release: true
secrets: inherit
release_nvidia_cu128:
permissions:
contents: "write"
packages: "write"
pull-requests: "read"
name: "Release NVIDIA cu128"
uses: ./.github/workflows/stable-release.yml
with:
git_tag: ${{ inputs.git_tag }}
cache_tag: "cu128"
python_minor: "12"
python_patch: "10"
rel_name: "nvidia"
rel_extra_name: "_cu128"
test_release: true
secrets: inherit
release_nvidia_cu126:
permissions:
contents: "write"
@@ -76,3 +59,20 @@ jobs:
rel_extra_name: ""
test_release: false
secrets: inherit
release_xpu:
permissions:
contents: "write"
packages: "write"
pull-requests: "read"
name: "Release Intel XPU"
uses: ./.github/workflows/stable-release.yml
with:
git_tag: ${{ inputs.git_tag }}
cache_tag: "xpu"
python_minor: "13"
python_patch: "12"
rel_name: "intel"
rel_extra_name: ""
test_release: true
secrets: inherit

View File

@@ -61,6 +61,7 @@ See what ComfyUI can do with the [newer template workflows](https://comfy.org/wo
## Features
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
- NOTE: There are many more models supported than the list below, if you want to see what is supported see our templates list inside ComfyUI.
- Image Models
- SD1.x, SD2.x ([unCLIP](https://comfyanonymous.github.io/ComfyUI_examples/unclip/))
- [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [SDXL Turbo](https://comfyanonymous.github.io/ComfyUI_examples/sdturbo/)
@@ -136,7 +137,7 @@ ComfyUI follows a weekly release cycle targeting Monday but this regularly chang
- Builds a new release using the latest stable core version
3. **[ComfyUI Frontend](https://github.com/Comfy-Org/ComfyUI_frontend)**
- Weekly frontend updates are merged into the core repository
- Every 2+ weeks frontend updates are merged into the core repository
- Features are frozen for the upcoming core release
- Development continues for the next release cycle
@@ -232,7 +233,7 @@ Put your VAE in: models/vae
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm7.1```
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm7.2```
This is the command to install the nightly with ROCm 7.2 which might have some performance improvements:
@@ -275,7 +276,7 @@ Nvidia users should install stable pytorch using this command:
This is the command to install pytorch nightly instead which might have performance improvements.
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu130```
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu132```
#### Troubleshooting

View File

@@ -1,6 +1,7 @@
from app.assets.database.queries.asset import (
asset_exists_by_hash,
bulk_insert_assets,
create_stub_asset,
get_asset_by_hash,
get_existing_asset_ids,
reassign_asset_references,
@@ -12,6 +13,7 @@ from app.assets.database.queries.asset_reference import (
UnenrichedReferenceRow,
bulk_insert_references_ignore_conflicts,
bulk_update_enrichment_level,
count_active_siblings,
bulk_update_is_missing,
bulk_update_needs_verify,
convert_metadata_to_rows,
@@ -80,6 +82,8 @@ __all__ = [
"bulk_insert_references_ignore_conflicts",
"bulk_insert_tags_and_meta",
"bulk_update_enrichment_level",
"count_active_siblings",
"create_stub_asset",
"bulk_update_is_missing",
"bulk_update_needs_verify",
"convert_metadata_to_rows",

View File

@@ -78,6 +78,18 @@ def upsert_asset(
return asset, created, updated
def create_stub_asset(
session: Session,
size_bytes: int,
mime_type: str | None = None,
) -> Asset:
"""Create a new asset with no hash (stub for later enrichment)."""
asset = Asset(size_bytes=size_bytes, mime_type=mime_type, hash=None)
session.add(asset)
session.flush()
return asset
def bulk_insert_assets(
session: Session,
rows: list[dict],

View File

@@ -114,6 +114,23 @@ def get_reference_by_file_path(
)
def count_active_siblings(
session: Session,
asset_id: str,
exclude_reference_id: str,
) -> int:
"""Count active (non-deleted) references to an asset, excluding one reference."""
return (
session.query(AssetReference)
.filter(
AssetReference.asset_id == asset_id,
AssetReference.id != exclude_reference_id,
AssetReference.deleted_at.is_(None),
)
.count()
)
def reference_exists_for_asset_id(
session: Session,
asset_id: str,

View File

@@ -13,6 +13,7 @@ from app.assets.database.queries import (
delete_references_by_ids,
ensure_tags_exist,
get_asset_by_hash,
get_reference_by_id,
get_references_for_prefixes,
get_unenriched_references,
mark_references_missing_outside_prefixes,
@@ -338,6 +339,7 @@ def build_asset_specs(
"metadata": metadata,
"hash": asset_hash,
"mime_type": mime_type,
"job_id": None,
}
)
tag_pool.update(tags)
@@ -426,6 +428,7 @@ def enrich_asset(
except OSError:
return new_level
initial_mtime_ns = get_mtime_ns(stat_p)
rel_fname = compute_relative_filename(file_path)
mime_type: str | None = None
metadata = None
@@ -489,6 +492,18 @@ def enrich_asset(
except Exception as e:
logging.warning("Failed to hash %s: %s", file_path, e)
# Optimistic guard: if the reference's mtime_ns changed since we
# started (e.g. ingest_existing_file updated it), our results are
# stale — discard them to avoid overwriting fresh registration data.
ref = get_reference_by_id(session, reference_id)
if ref is None or ref.mtime_ns != initial_mtime_ns:
session.rollback()
logging.info(
"Ref %s mtime changed during enrichment, discarding stale result",
reference_id,
)
return ENRICHMENT_STUB
if extract_metadata and metadata:
system_metadata = metadata.to_user_metadata()
set_reference_system_metadata(session, reference_id, system_metadata)

View File

@@ -77,7 +77,9 @@ class _AssetSeeder:
"""
def __init__(self) -> None:
self._lock = threading.Lock()
# RLock is required because _run_scan() drains pending work while
# holding _lock and re-enters start() which also acquires _lock.
self._lock = threading.RLock()
self._state = State.IDLE
self._progress: Progress | None = None
self._last_progress: Progress | None = None
@@ -92,6 +94,7 @@ class _AssetSeeder:
self._prune_first: bool = False
self._progress_callback: ProgressCallback | None = None
self._disabled: bool = False
self._pending_enrich: dict | None = None
def disable(self) -> None:
"""Disable the asset seeder, preventing any scans from starting."""
@@ -196,6 +199,42 @@ class _AssetSeeder:
compute_hashes=compute_hashes,
)
def enqueue_enrich(
self,
roots: tuple[RootType, ...] = ("models", "input", "output"),
compute_hashes: bool = False,
) -> bool:
"""Start an enrichment scan now, or queue it for after the current scan.
If the seeder is idle, starts immediately. Otherwise, the enrich
request is stored and will run automatically when the current scan
finishes.
Args:
roots: Tuple of root types to scan
compute_hashes: If True, compute blake3 hashes
Returns:
True if started immediately, False if queued for later
"""
with self._lock:
if self.start_enrich(roots=roots, compute_hashes=compute_hashes):
return True
if self._pending_enrich is not None:
existing_roots = set(self._pending_enrich["roots"])
existing_roots.update(roots)
self._pending_enrich["roots"] = tuple(existing_roots)
self._pending_enrich["compute_hashes"] = (
self._pending_enrich["compute_hashes"] or compute_hashes
)
else:
self._pending_enrich = {
"roots": roots,
"compute_hashes": compute_hashes,
}
logging.info("Enrich scan queued (roots=%s)", self._pending_enrich["roots"])
return False
def cancel(self) -> bool:
"""Request cancellation of the current scan.
@@ -381,9 +420,13 @@ class _AssetSeeder:
return marked
finally:
with self._lock:
self._last_progress = self._progress
self._state = State.IDLE
self._progress = None
self._reset_to_idle()
def _reset_to_idle(self) -> None:
"""Reset state to IDLE, preserving last progress. Caller must hold _lock."""
self._last_progress = self._progress
self._state = State.IDLE
self._progress = None
def _is_cancelled(self) -> bool:
"""Check if cancellation has been requested."""
@@ -594,9 +637,18 @@ class _AssetSeeder:
},
)
with self._lock:
self._last_progress = self._progress
self._state = State.IDLE
self._progress = None
self._reset_to_idle()
pending = self._pending_enrich
if pending is not None:
self._pending_enrich = None
if not self.start_enrich(
roots=pending["roots"],
compute_hashes=pending["compute_hashes"],
):
logging.warning(
"Pending enrich scan could not start (roots=%s)",
pending["roots"],
)
def _run_fast_phase(self, roots: tuple[RootType, ...]) -> tuple[int, int, int]:
"""Run phase 1: fast scan to create stub records.

View File

@@ -23,6 +23,8 @@ from app.assets.services.ingest import (
DependencyMissingError,
HashMismatchError,
create_from_hash,
ingest_existing_file,
register_output_files,
upload_from_temp_path,
)
from app.assets.database.queries import (
@@ -72,6 +74,8 @@ __all__ = [
"delete_asset_reference",
"get_asset_by_hash",
"get_asset_detail",
"ingest_existing_file",
"register_output_files",
"get_mtime_ns",
"get_size_and_mtime_ns",
"list_assets_page",

View File

@@ -37,6 +37,7 @@ class SeedAssetSpec(TypedDict):
metadata: ExtractedMetadata | None
hash: str | None
mime_type: str | None
job_id: str | None
class AssetRow(TypedDict):
@@ -60,6 +61,7 @@ class ReferenceRow(TypedDict):
name: str
preview_id: str | None
user_metadata: dict[str, Any] | None
job_id: str | None
created_at: datetime
updated_at: datetime
last_access_time: datetime
@@ -167,6 +169,7 @@ def batch_insert_seed_assets(
"name": spec["info_name"],
"preview_id": None,
"user_metadata": user_metadata,
"job_id": spec.get("job_id"),
"created_at": current_time,
"updated_at": current_time,
"last_access_time": current_time,

View File

@@ -9,6 +9,9 @@ from sqlalchemy.orm import Session
import app.assets.services.hashing as hashing
from app.assets.database.queries import (
add_tags_to_reference,
count_active_siblings,
create_stub_asset,
ensure_tags_exist,
fetch_reference_and_asset,
get_asset_by_hash,
get_reference_by_file_path,
@@ -23,7 +26,8 @@ from app.assets.database.queries import (
upsert_reference,
validate_tags_exist,
)
from app.assets.helpers import normalize_tags
from app.assets.helpers import get_utc_now, normalize_tags
from app.assets.services.bulk_ingest import batch_insert_seed_assets
from app.assets.services.file_utils import get_size_and_mtime_ns
from app.assets.services.path_utils import (
compute_relative_filename,
@@ -130,6 +134,102 @@ def _ingest_file_from_path(
)
def register_output_files(
file_paths: Sequence[str],
user_metadata: UserMetadata = None,
job_id: str | None = None,
) -> int:
"""Register a batch of output file paths as assets.
Returns the number of files successfully registered.
"""
registered = 0
for abs_path in file_paths:
if not os.path.isfile(abs_path):
continue
try:
if ingest_existing_file(
abs_path, user_metadata=user_metadata, job_id=job_id
):
registered += 1
except Exception:
logging.exception("Failed to register output: %s", abs_path)
return registered
def ingest_existing_file(
abs_path: str,
user_metadata: UserMetadata = None,
extra_tags: Sequence[str] = (),
owner_id: str = "",
job_id: str | None = None,
) -> bool:
"""Register an existing on-disk file as an asset stub.
If a reference already exists for this path, updates mtime_ns, job_id,
size_bytes, and resets enrichment so the enricher will re-hash it.
For brand-new paths, inserts a stub record (hash=NULL) for immediate
UX visibility.
Returns True if a row was inserted or updated, False otherwise.
"""
locator = os.path.abspath(abs_path)
size_bytes, mtime_ns = get_size_and_mtime_ns(abs_path)
mime_type = mimetypes.guess_type(abs_path, strict=False)[0]
name, path_tags = get_name_and_tags_from_asset_path(abs_path)
tags = list(dict.fromkeys(path_tags + list(extra_tags)))
with create_session() as session:
existing_ref = get_reference_by_file_path(session, locator)
if existing_ref is not None:
now = get_utc_now()
existing_ref.mtime_ns = mtime_ns
existing_ref.job_id = job_id
existing_ref.is_missing = False
existing_ref.deleted_at = None
existing_ref.updated_at = now
existing_ref.enrichment_level = 0
asset = existing_ref.asset
if asset:
# If other refs share this asset, detach to a new stub
# instead of mutating the shared row.
siblings = count_active_siblings(session, asset.id, existing_ref.id)
if siblings > 0:
new_asset = create_stub_asset(
session,
size_bytes=size_bytes,
mime_type=mime_type or asset.mime_type,
)
existing_ref.asset_id = new_asset.id
else:
asset.hash = None
asset.size_bytes = size_bytes
if mime_type:
asset.mime_type = mime_type
session.commit()
return True
spec = {
"abs_path": abs_path,
"size_bytes": size_bytes,
"mtime_ns": mtime_ns,
"info_name": name,
"tags": tags,
"fname": os.path.basename(abs_path),
"metadata": None,
"hash": None,
"mime_type": mime_type,
"job_id": job_id,
}
if tags:
ensure_tags_exist(session, tags)
result = batch_insert_seed_assets(session, [spec], owner_id=owner_id)
session.commit()
return result.won_paths > 0
def _register_existing_asset(
asset_hash: str,
name: str,

View File

@@ -93,12 +93,13 @@ def compute_relative_filename(file_path: str) -> str | None:
def get_asset_category_and_relative_path(
file_path: str,
) -> tuple[Literal["input", "output", "models"], str]:
) -> tuple[Literal["input", "output", "temp", "models"], str]:
"""Determine which root category a file path belongs to.
Categories:
- 'input': under folder_paths.get_input_directory()
- 'output': under folder_paths.get_output_directory()
- 'temp': under folder_paths.get_temp_directory()
- 'models': under any base path from get_comfy_models_folders()
Returns:
@@ -129,7 +130,12 @@ def get_asset_category_and_relative_path(
if _check_is_within(fp_abs, output_base):
return "output", _compute_relative(fp_abs, output_base)
# 3) models (check deepest matching base to avoid ambiguity)
# 3) temp
temp_base = os.path.abspath(folder_paths.get_temp_directory())
if _check_is_within(fp_abs, temp_base):
return "temp", _compute_relative(fp_abs, temp_base)
# 4) models (check deepest matching base to avoid ambiguity)
best: tuple[int, str, str] | None = None # (base_len, bucket, rel_inside_bucket)
for bucket, bases in get_comfy_models_folders():
for b in bases:
@@ -146,7 +152,7 @@ def get_asset_category_and_relative_path(
return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep)
raise ValueError(
f"Path is not within input, output, or configured model bases: {file_path}"
f"Path is not within input, output, temp, or configured model bases: {file_path}"
)

View File

@@ -0,0 +1,90 @@
#version 300 es
precision highp float;
uniform sampler2D u_image0;
uniform float u_float0;
uniform float u_float1;
uniform float u_float2;
uniform float u_float3;
uniform float u_float4;
uniform float u_float5;
uniform float u_float6;
uniform float u_float7;
uniform float u_float8;
uniform bool u_bool0;
in vec2 v_texCoord;
out vec4 fragColor;
vec3 rgb2hsl(vec3 c) {
float maxC = max(c.r, max(c.g, c.b));
float minC = min(c.r, min(c.g, c.b));
float l = (maxC + minC) * 0.5;
if (maxC == minC) return vec3(0.0, 0.0, l);
float d = maxC - minC;
float s = l > 0.5 ? d / (2.0 - maxC - minC) : d / (maxC + minC);
float h;
if (maxC == c.r) {
h = (c.g - c.b) / d + (c.g < c.b ? 6.0 : 0.0);
} else if (maxC == c.g) {
h = (c.b - c.r) / d + 2.0;
} else {
h = (c.r - c.g) / d + 4.0;
}
h /= 6.0;
return vec3(h, s, l);
}
float hue2rgb(float p, float q, float t) {
if (t < 0.0) t += 1.0;
if (t > 1.0) t -= 1.0;
if (t < 1.0 / 6.0) return p + (q - p) * 6.0 * t;
if (t < 1.0 / 2.0) return q;
if (t < 2.0 / 3.0) return p + (q - p) * (2.0 / 3.0 - t) * 6.0;
return p;
}
vec3 hsl2rgb(vec3 hsl) {
float h = hsl.x, s = hsl.y, l = hsl.z;
if (s == 0.0) return vec3(l);
float q = l < 0.5 ? l * (1.0 + s) : l + s - l * s;
float p = 2.0 * l - q;
return vec3(
hue2rgb(p, q, h + 1.0 / 3.0),
hue2rgb(p, q, h),
hue2rgb(p, q, h - 1.0 / 3.0)
);
}
void main() {
vec4 tex = texture(u_image0, v_texCoord);
vec3 color = tex.rgb;
vec3 shadows = vec3(u_float0, u_float1, u_float2) * 0.01;
vec3 midtones = vec3(u_float3, u_float4, u_float5) * 0.01;
vec3 highlights = vec3(u_float6, u_float7, u_float8) * 0.01;
float maxC = max(color.r, max(color.g, color.b));
float minC = min(color.r, min(color.g, color.b));
float lightness = (maxC + minC) * 0.5;
// GIMP weight curves: linear ramps with constants a=0.25, b=0.333, scale=0.7
const float a = 0.25;
const float b = 0.333;
const float scale = 0.7;
float sw = clamp((lightness - b) / -a + 0.5, 0.0, 1.0) * scale;
float mw = clamp((lightness - b) / a + 0.5, 0.0, 1.0) *
clamp((lightness + b - 1.0) / -a + 0.5, 0.0, 1.0) * scale;
float hw = clamp((lightness + b - 1.0) / a + 0.5, 0.0, 1.0) * scale;
color += sw * shadows + mw * midtones + hw * highlights;
if (u_bool0) {
vec3 hsl = rgb2hsl(clamp(color, 0.0, 1.0));
hsl.z = lightness;
color = hsl2rgb(hsl);
}
fragColor = vec4(clamp(color, 0.0, 1.0), tex.a);
}

View File

@@ -0,0 +1,49 @@
#version 300 es
precision highp float;
uniform sampler2D u_image0;
uniform sampler2D u_curve0; // RGB master curve (256x1 LUT)
uniform sampler2D u_curve1; // Red channel curve
uniform sampler2D u_curve2; // Green channel curve
uniform sampler2D u_curve3; // Blue channel curve
in vec2 v_texCoord;
layout(location = 0) out vec4 fragColor0;
// GIMP-compatible curve lookup with manual linear interpolation.
// Matches gimp_curve_map_value_inline() from gimpcurve-map.c:
// index = value * (n_samples - 1)
// f = fract(index)
// result = (1-f) * samples[floor] + f * samples[ceil]
//
// Uses texelFetch (NEAREST) to avoid GPU half-texel offset issues
// that occur with texture() + GL_LINEAR on small 256x1 LUTs.
float applyCurve(sampler2D curve, float value) {
value = clamp(value, 0.0, 1.0);
float pos = value * 255.0;
int lo = int(floor(pos));
int hi = min(lo + 1, 255);
float f = pos - float(lo);
float a = texelFetch(curve, ivec2(lo, 0), 0).r;
float b = texelFetch(curve, ivec2(hi, 0), 0).r;
return a + f * (b - a);
}
void main() {
vec4 color = texture(u_image0, v_texCoord);
// GIMP order: per-channel curves first, then RGB master curve.
// See gimp_curve_map_pixels() default case in gimpcurve-map.c:
// dest = colors_curve( channel_curve( src ) )
float tmp_r = applyCurve(u_curve1, color.r);
float tmp_g = applyCurve(u_curve2, color.g);
float tmp_b = applyCurve(u_curve3, color.b);
color.r = applyCurve(u_curve0, tmp_r);
color.g = applyCurve(u_curve0, tmp_g);
color.b = applyCurve(u_curve0, tmp_b);
fragColor0 = vec4(color.rgb, color.a);
}

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,615 @@
{
"revision": 0,
"last_node_id": 10,
"last_link_id": 0,
"nodes": [
{
"id": 10,
"type": "d5c462c8-1372-4af8-84f2-547c83470d04",
"pos": [
3610,
-2630
],
"size": [
270,
420
],
"flags": {},
"order": 0,
"mode": 0,
"inputs": [
{
"label": "image",
"localized_name": "images.image0",
"name": "images.image0",
"type": "IMAGE",
"link": null
}
],
"outputs": [
{
"label": "IMAGE",
"localized_name": "IMAGE0",
"name": "IMAGE0",
"type": "IMAGE",
"links": []
}
],
"properties": {
"proxyWidgets": [
[
"4",
"curve"
],
[
"5",
"curve"
],
[
"6",
"curve"
],
[
"7",
"curve"
]
]
},
"widgets_values": [],
"title": "Color Curves"
}
],
"links": [],
"version": 0.4,
"definitions": {
"subgraphs": [
{
"id": "d5c462c8-1372-4af8-84f2-547c83470d04",
"version": 1,
"state": {
"lastGroupId": 0,
"lastNodeId": 9,
"lastLinkId": 38,
"lastRerouteId": 0
},
"revision": 0,
"config": {},
"name": "Color Curves",
"inputNode": {
"id": -10,
"bounding": [
2660,
-4500,
120,
60
]
},
"outputNode": {
"id": -20,
"bounding": [
4270,
-4500,
120,
60
]
},
"inputs": [
{
"id": "abc345b7-f55e-4f32-a11d-3aa4c2b0936b",
"name": "images.image0",
"type": "IMAGE",
"linkIds": [
29,
34
],
"localized_name": "images.image0",
"label": "image",
"pos": [
2760,
-4480
]
}
],
"outputs": [
{
"id": "eb0ec079-46da-4408-8263-9ef85569d33d",
"name": "IMAGE0",
"type": "IMAGE",
"linkIds": [
28
],
"localized_name": "IMAGE0",
"label": "IMAGE",
"pos": [
4290,
-4480
]
}
],
"widgets": [],
"nodes": [
{
"id": 4,
"type": "CurveEditor",
"pos": [
3060,
-4500
],
"size": [
270,
200
],
"flags": {},
"order": 0,
"mode": 0,
"inputs": [
{
"label": "curve",
"localized_name": "curve",
"name": "curve",
"type": "CURVE",
"widget": {
"name": "curve"
},
"link": null
},
{
"label": "histogram",
"localized_name": "histogram",
"name": "histogram",
"type": "HISTOGRAM",
"shape": 7,
"link": 35
}
],
"outputs": [
{
"localized_name": "CURVE",
"name": "CURVE",
"type": "CURVE",
"links": [
30
]
}
],
"title": "RGB Master",
"properties": {
"Node name for S&R": "CurveEditor"
},
"widgets_values": []
},
{
"id": 5,
"type": "CurveEditor",
"pos": [
3060,
-4250
],
"size": [
270,
200
],
"flags": {},
"order": 1,
"mode": 0,
"inputs": [
{
"label": "curve",
"localized_name": "curve",
"name": "curve",
"type": "CURVE",
"widget": {
"name": "curve"
},
"link": null
},
{
"label": "histogram",
"localized_name": "histogram",
"name": "histogram",
"type": "HISTOGRAM",
"shape": 7,
"link": 36
}
],
"outputs": [
{
"localized_name": "CURVE",
"name": "CURVE",
"type": "CURVE",
"links": [
31
]
}
],
"title": "Red",
"properties": {
"Node name for S&R": "CurveEditor"
},
"widgets_values": []
},
{
"id": 6,
"type": "CurveEditor",
"pos": [
3060,
-4000
],
"size": [
270,
200
],
"flags": {},
"order": 2,
"mode": 0,
"inputs": [
{
"label": "curve",
"localized_name": "curve",
"name": "curve",
"type": "CURVE",
"widget": {
"name": "curve"
},
"link": null
},
{
"label": "histogram",
"localized_name": "histogram",
"name": "histogram",
"type": "HISTOGRAM",
"shape": 7,
"link": 37
}
],
"outputs": [
{
"localized_name": "CURVE",
"name": "CURVE",
"type": "CURVE",
"links": [
32
]
}
],
"title": "Green",
"properties": {
"Node name for S&R": "CurveEditor"
},
"widgets_values": []
},
{
"id": 7,
"type": "CurveEditor",
"pos": [
3060,
-3750
],
"size": [
270,
200
],
"flags": {},
"order": 3,
"mode": 0,
"inputs": [
{
"label": "curve",
"localized_name": "curve",
"name": "curve",
"type": "CURVE",
"widget": {
"name": "curve"
},
"link": null
},
{
"label": "histogram",
"localized_name": "histogram",
"name": "histogram",
"type": "HISTOGRAM",
"shape": 7,
"link": 38
}
],
"outputs": [
{
"localized_name": "CURVE",
"name": "CURVE",
"type": "CURVE",
"links": [
33
]
}
],
"title": "Blue",
"properties": {
"Node name for S&R": "CurveEditor"
},
"widgets_values": []
},
{
"id": 8,
"type": "GLSLShader",
"pos": [
3590,
-4500
],
"size": [
420,
500
],
"flags": {},
"order": 4,
"mode": 0,
"inputs": [
{
"label": "image0",
"localized_name": "images.image0",
"name": "images.image0",
"type": "IMAGE",
"link": 29
},
{
"label": "image1",
"localized_name": "images.image1",
"name": "images.image1",
"shape": 7,
"type": "IMAGE",
"link": null
},
{
"label": "u_curve0",
"localized_name": "curves.u_curve0",
"name": "curves.u_curve0",
"shape": 7,
"type": "CURVE",
"link": 30
},
{
"label": "u_curve1",
"localized_name": "curves.u_curve1",
"name": "curves.u_curve1",
"shape": 7,
"type": "CURVE",
"link": 31
},
{
"label": "u_curve2",
"localized_name": "curves.u_curve2",
"name": "curves.u_curve2",
"shape": 7,
"type": "CURVE",
"link": 32
},
{
"label": "u_curve3",
"localized_name": "curves.u_curve3",
"name": "curves.u_curve3",
"shape": 7,
"type": "CURVE",
"link": 33
},
{
"localized_name": "fragment_shader",
"name": "fragment_shader",
"type": "STRING",
"widget": {
"name": "fragment_shader"
},
"link": null
},
{
"localized_name": "size_mode",
"name": "size_mode",
"type": "COMFY_DYNAMICCOMBO_V3",
"widget": {
"name": "size_mode"
},
"link": null
}
],
"outputs": [
{
"localized_name": "IMAGE0",
"name": "IMAGE0",
"type": "IMAGE",
"links": [
28
]
},
{
"localized_name": "IMAGE1",
"name": "IMAGE1",
"type": "IMAGE",
"links": null
},
{
"localized_name": "IMAGE2",
"name": "IMAGE2",
"type": "IMAGE",
"links": null
},
{
"localized_name": "IMAGE3",
"name": "IMAGE3",
"type": "IMAGE",
"links": null
}
],
"properties": {
"Node name for S&R": "GLSLShader"
},
"widgets_values": [
"#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\nuniform sampler2D u_curve0; // RGB master curve (256x1 LUT)\nuniform sampler2D u_curve1; // Red channel curve\nuniform sampler2D u_curve2; // Green channel curve\nuniform sampler2D u_curve3; // Blue channel curve\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\n\n// GIMP-compatible curve lookup with manual linear interpolation.\n// Matches gimp_curve_map_value_inline() from gimpcurve-map.c:\n// index = value * (n_samples - 1)\n// f = fract(index)\n// result = (1-f) * samples[floor] + f * samples[ceil]\n//\n// Uses texelFetch (NEAREST) to avoid GPU half-texel offset issues\n// that occur with texture() + GL_LINEAR on small 256x1 LUTs.\nfloat applyCurve(sampler2D curve, float value) {\n value = clamp(value, 0.0, 1.0);\n\n float pos = value * 255.0;\n int lo = int(floor(pos));\n int hi = min(lo + 1, 255);\n float f = pos - float(lo);\n\n float a = texelFetch(curve, ivec2(lo, 0), 0).r;\n float b = texelFetch(curve, ivec2(hi, 0), 0).r;\n\n return a + f * (b - a);\n}\n\nvoid main() {\n vec4 color = texture(u_image0, v_texCoord);\n\n // GIMP order: per-channel curves first, then RGB master curve.\n // See gimp_curve_map_pixels() default case in gimpcurve-map.c:\n // dest = colors_curve( channel_curve( src ) )\n float tmp_r = applyCurve(u_curve1, color.r);\n float tmp_g = applyCurve(u_curve2, color.g);\n float tmp_b = applyCurve(u_curve3, color.b);\n color.r = applyCurve(u_curve0, tmp_r);\n color.g = applyCurve(u_curve0, tmp_g);\n color.b = applyCurve(u_curve0, tmp_b);\n\n fragColor0 = vec4(color.rgb, color.a);\n}\n",
"from_input"
]
},
{
"id": 9,
"type": "ImageHistogram",
"pos": [
2800,
-4300
],
"size": [
210,
150
],
"flags": {},
"order": 5,
"mode": 0,
"inputs": [
{
"label": "image",
"localized_name": "image",
"name": "image",
"type": "IMAGE",
"link": 34
}
],
"outputs": [
{
"localized_name": "HISTOGRAM",
"name": "rgb",
"type": "HISTOGRAM",
"links": [
35
]
},
{
"localized_name": "HISTOGRAM",
"name": "luminance",
"type": "HISTOGRAM",
"links": []
},
{
"localized_name": "HISTOGRAM",
"name": "red",
"type": "HISTOGRAM",
"links": [
36
]
},
{
"localized_name": "HISTOGRAM",
"name": "green",
"type": "HISTOGRAM",
"links": [
37
]
},
{
"localized_name": "HISTOGRAM",
"name": "blue",
"type": "HISTOGRAM",
"links": [
38
]
}
],
"properties": {
"Node name for S&R": "ImageHistogram"
},
"widgets_values": []
}
],
"groups": [],
"links": [
{
"id": 29,
"origin_id": -10,
"origin_slot": 0,
"target_id": 8,
"target_slot": 0,
"type": "IMAGE"
},
{
"id": 28,
"origin_id": 8,
"origin_slot": 0,
"target_id": -20,
"target_slot": 0,
"type": "IMAGE"
},
{
"id": 30,
"origin_id": 4,
"origin_slot": 0,
"target_id": 8,
"target_slot": 2,
"type": "CURVE"
},
{
"id": 31,
"origin_id": 5,
"origin_slot": 0,
"target_id": 8,
"target_slot": 3,
"type": "CURVE"
},
{
"id": 32,
"origin_id": 6,
"origin_slot": 0,
"target_id": 8,
"target_slot": 4,
"type": "CURVE"
},
{
"id": 33,
"origin_id": 7,
"origin_slot": 0,
"target_id": 8,
"target_slot": 5,
"type": "CURVE"
},
{
"id": 34,
"origin_id": -10,
"origin_slot": 0,
"target_id": 9,
"target_slot": 0,
"type": "IMAGE"
},
{
"id": 35,
"origin_id": 9,
"origin_slot": 0,
"target_id": 4,
"target_slot": 1,
"type": "HISTOGRAM"
},
{
"id": 36,
"origin_id": 9,
"origin_slot": 2,
"target_id": 5,
"target_slot": 1,
"type": "HISTOGRAM"
},
{
"id": 37,
"origin_id": 9,
"origin_slot": 3,
"target_id": 6,
"target_slot": 1,
"type": "HISTOGRAM"
},
{
"id": 38,
"origin_id": 9,
"origin_slot": 4,
"target_id": 7,
"target_slot": 1,
"type": "HISTOGRAM"
}
],
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Image Tools/Color adjust"
}
]
}
}

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -1 +1,322 @@
{"revision": 0, "last_node_id": 29, "last_link_id": 0, "nodes": [{"id": 29, "type": "4c9d6ea4-b912-40e5-8766-6793a9758c53", "pos": [1970, -230], "size": [180, 86], "flags": {}, "order": 5, "mode": 0, "inputs": [{"label": "image", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": null}], "outputs": [{"label": "R", "localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": []}, {"label": "G", "localized_name": "IMAGE1", "name": "IMAGE1", "type": "IMAGE", "links": []}, {"label": "B", "localized_name": "IMAGE2", "name": "IMAGE2", "type": "IMAGE", "links": []}, {"label": "A", "localized_name": "IMAGE3", "name": "IMAGE3", "type": "IMAGE", "links": []}], "title": "Image Channels", "properties": {"proxyWidgets": []}, "widgets_values": []}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "4c9d6ea4-b912-40e5-8766-6793a9758c53", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 28, "lastLinkId": 39, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Image Channels", "inputNode": {"id": -10, "bounding": [1820, -185, 120, 60]}, "outputNode": {"id": -20, "bounding": [2460, -215, 120, 120]}, "inputs": [{"id": "3522932b-2d86-4a1f-a02a-cb29f3a9d7fe", "name": "images.image0", "type": "IMAGE", "linkIds": [39], "localized_name": "images.image0", "label": "image", "pos": [1920, -165]}], "outputs": [{"id": "605cb9c3-b065-4d9b-81d2-3ec331889b2b", "name": "IMAGE0", "type": "IMAGE", "linkIds": [26], "localized_name": "IMAGE0", "label": "R", "pos": [2480, -195]}, {"id": "fb44a77e-0522-43e9-9527-82e7465b3596", "name": "IMAGE1", "type": "IMAGE", "linkIds": [27], "localized_name": "IMAGE1", "label": "G", "pos": [2480, -175]}, {"id": "81460ee6-0131-402a-874f-6bf3001fc4ff", "name": "IMAGE2", "type": "IMAGE", "linkIds": [28], "localized_name": "IMAGE2", "label": "B", "pos": [2480, -155]}, {"id": "ae690246-80d4-4951-b1d9-9306d8a77417", "name": "IMAGE3", "type": "IMAGE", "linkIds": [29], "localized_name": "IMAGE3", "label": "A", "pos": [2480, -135]}], "widgets": [], "nodes": [{"id": 23, "type": "GLSLShader", "pos": [2000, -330], "size": [400, 172], "flags": {}, "order": 0, "mode": 0, "inputs": [{"label": "image", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": 39}, {"localized_name": "fragment_shader", "name": "fragment_shader", "type": "STRING", "widget": {"name": "fragment_shader"}, "link": null}, {"localized_name": "size_mode", "name": "size_mode", "type": "COMFY_DYNAMICCOMBO_V3", "widget": {"name": "size_mode"}, "link": null}, {"label": "image1", "localized_name": "images.image1", "name": "images.image1", "shape": 7, "type": "IMAGE", "link": null}], "outputs": [{"label": "R", "localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": [26]}, {"label": "G", "localized_name": "IMAGE1", "name": "IMAGE1", "type": "IMAGE", "links": [27]}, {"label": "B", "localized_name": "IMAGE2", "name": "IMAGE2", "type": "IMAGE", "links": [28]}, {"label": "A", "localized_name": "IMAGE3", "name": "IMAGE3", "type": "IMAGE", "links": [29]}], "properties": {"Node name for S&R": "GLSLShader"}, "widgets_values": ["#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\nlayout(location = 1) out vec4 fragColor1;\nlayout(location = 2) out vec4 fragColor2;\nlayout(location = 3) out vec4 fragColor3;\n\nvoid main() {\n vec4 color = texture(u_image0, v_texCoord);\n // Output each channel as grayscale to separate render targets\n fragColor0 = vec4(vec3(color.r), 1.0); // Red channel\n fragColor1 = vec4(vec3(color.g), 1.0); // Green channel\n fragColor2 = vec4(vec3(color.b), 1.0); // Blue channel\n fragColor3 = vec4(vec3(color.a), 1.0); // Alpha channel\n}\n", "from_input"]}], "groups": [], "links": [{"id": 39, "origin_id": -10, "origin_slot": 0, "target_id": 23, "target_slot": 0, "type": "IMAGE"}, {"id": 26, "origin_id": 23, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "IMAGE"}, {"id": 27, "origin_id": 23, "origin_slot": 1, "target_id": -20, "target_slot": 1, "type": "IMAGE"}, {"id": 28, "origin_id": 23, "origin_slot": 2, "target_id": -20, "target_slot": 2, "type": "IMAGE"}, {"id": 29, "origin_id": 23, "origin_slot": 3, "target_id": -20, "target_slot": 3, "type": "IMAGE"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Image Tools/Color adjust"}]}}
{
"revision": 0,
"last_node_id": 29,
"last_link_id": 0,
"nodes": [
{
"id": 29,
"type": "4c9d6ea4-b912-40e5-8766-6793a9758c53",
"pos": [
1970,
-230
],
"size": [
180,
86
],
"flags": {},
"order": 5,
"mode": 0,
"inputs": [
{
"label": "image",
"localized_name": "images.image0",
"name": "images.image0",
"type": "IMAGE",
"link": null
}
],
"outputs": [
{
"label": "R",
"localized_name": "IMAGE0",
"name": "IMAGE0",
"type": "IMAGE",
"links": []
},
{
"label": "G",
"localized_name": "IMAGE1",
"name": "IMAGE1",
"type": "IMAGE",
"links": []
},
{
"label": "B",
"localized_name": "IMAGE2",
"name": "IMAGE2",
"type": "IMAGE",
"links": []
},
{
"label": "A",
"localized_name": "IMAGE3",
"name": "IMAGE3",
"type": "IMAGE",
"links": []
}
],
"title": "Image Channels",
"properties": {
"proxyWidgets": []
},
"widgets_values": []
}
],
"links": [],
"version": 0.4,
"definitions": {
"subgraphs": [
{
"id": "4c9d6ea4-b912-40e5-8766-6793a9758c53",
"version": 1,
"state": {
"lastGroupId": 0,
"lastNodeId": 28,
"lastLinkId": 39,
"lastRerouteId": 0
},
"revision": 0,
"config": {},
"name": "Image Channels",
"inputNode": {
"id": -10,
"bounding": [
1820,
-185,
120,
60
]
},
"outputNode": {
"id": -20,
"bounding": [
2460,
-215,
120,
120
]
},
"inputs": [
{
"id": "3522932b-2d86-4a1f-a02a-cb29f3a9d7fe",
"name": "images.image0",
"type": "IMAGE",
"linkIds": [
39
],
"localized_name": "images.image0",
"label": "image",
"pos": [
1920,
-165
]
}
],
"outputs": [
{
"id": "605cb9c3-b065-4d9b-81d2-3ec331889b2b",
"name": "IMAGE0",
"type": "IMAGE",
"linkIds": [
26
],
"localized_name": "IMAGE0",
"label": "R",
"pos": [
2480,
-195
]
},
{
"id": "fb44a77e-0522-43e9-9527-82e7465b3596",
"name": "IMAGE1",
"type": "IMAGE",
"linkIds": [
27
],
"localized_name": "IMAGE1",
"label": "G",
"pos": [
2480,
-175
]
},
{
"id": "81460ee6-0131-402a-874f-6bf3001fc4ff",
"name": "IMAGE2",
"type": "IMAGE",
"linkIds": [
28
],
"localized_name": "IMAGE2",
"label": "B",
"pos": [
2480,
-155
]
},
{
"id": "ae690246-80d4-4951-b1d9-9306d8a77417",
"name": "IMAGE3",
"type": "IMAGE",
"linkIds": [
29
],
"localized_name": "IMAGE3",
"label": "A",
"pos": [
2480,
-135
]
}
],
"widgets": [],
"nodes": [
{
"id": 23,
"type": "GLSLShader",
"pos": [
2000,
-330
],
"size": [
400,
172
],
"flags": {},
"order": 0,
"mode": 0,
"inputs": [
{
"label": "image",
"localized_name": "images.image0",
"name": "images.image0",
"type": "IMAGE",
"link": 39
},
{
"localized_name": "fragment_shader",
"name": "fragment_shader",
"type": "STRING",
"widget": {
"name": "fragment_shader"
},
"link": null
},
{
"localized_name": "size_mode",
"name": "size_mode",
"type": "COMFY_DYNAMICCOMBO_V3",
"widget": {
"name": "size_mode"
},
"link": null
},
{
"label": "image1",
"localized_name": "images.image1",
"name": "images.image1",
"shape": 7,
"type": "IMAGE",
"link": null
}
],
"outputs": [
{
"label": "R",
"localized_name": "IMAGE0",
"name": "IMAGE0",
"type": "IMAGE",
"links": [
26
]
},
{
"label": "G",
"localized_name": "IMAGE1",
"name": "IMAGE1",
"type": "IMAGE",
"links": [
27
]
},
{
"label": "B",
"localized_name": "IMAGE2",
"name": "IMAGE2",
"type": "IMAGE",
"links": [
28
]
},
{
"label": "A",
"localized_name": "IMAGE3",
"name": "IMAGE3",
"type": "IMAGE",
"links": [
29
]
}
],
"properties": {
"Node name for S&R": "GLSLShader"
},
"widgets_values": [
"#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\nlayout(location = 1) out vec4 fragColor1;\nlayout(location = 2) out vec4 fragColor2;\nlayout(location = 3) out vec4 fragColor3;\n\nvoid main() {\n vec4 color = texture(u_image0, v_texCoord);\n // Output each channel as grayscale to separate render targets\n fragColor0 = vec4(vec3(color.r), 1.0); // Red channel\n fragColor1 = vec4(vec3(color.g), 1.0); // Green channel\n fragColor2 = vec4(vec3(color.b), 1.0); // Blue channel\n fragColor3 = vec4(vec3(color.a), 1.0); // Alpha channel\n}\n",
"from_input"
]
}
],
"groups": [],
"links": [
{
"id": 39,
"origin_id": -10,
"origin_slot": 0,
"target_id": 23,
"target_slot": 0,
"type": "IMAGE"
},
{
"id": 26,
"origin_id": 23,
"origin_slot": 0,
"target_id": -20,
"target_slot": 0,
"type": "IMAGE"
},
{
"id": 27,
"origin_id": 23,
"origin_slot": 1,
"target_id": -20,
"target_slot": 1,
"type": "IMAGE"
},
{
"id": 28,
"origin_id": 23,
"origin_slot": 2,
"target_id": -20,
"target_slot": 2,
"type": "IMAGE"
},
{
"id": 29,
"origin_id": 23,
"origin_slot": 3,
"target_id": -20,
"target_slot": 3,
"type": "IMAGE"
}
],
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Image Tools/Color adjust"
}
]
}
}

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -1 +1,278 @@
{"revision": 0, "last_node_id": 15, "last_link_id": 0, "nodes": [{"id": 15, "type": "24d8bbfd-39d4-4774-bff0-3de40cc7a471", "pos": [-1490, 2040], "size": [400, 260], "flags": {}, "order": 0, "mode": 0, "inputs": [{"name": "prompt", "type": "STRING", "widget": {"name": "prompt"}, "link": null}, {"label": "reference images", "name": "images", "type": "IMAGE", "link": null}], "outputs": [{"name": "STRING", "type": "STRING", "links": null}], "title": "Prompt Enhance", "properties": {"proxyWidgets": [["-1", "prompt"]], "cnr_id": "comfy-core", "ver": "0.14.1"}, "widgets_values": [""]}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "24d8bbfd-39d4-4774-bff0-3de40cc7a471", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 15, "lastLinkId": 14, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Prompt Enhance", "inputNode": {"id": -10, "bounding": [-2170, 2110, 138.876953125, 80]}, "outputNode": {"id": -20, "bounding": [-640, 2110, 120, 60]}, "inputs": [{"id": "aeab7216-00e0-4528-a09b-bba50845c5a6", "name": "prompt", "type": "STRING", "linkIds": [11], "pos": [-2051.123046875, 2130]}, {"id": "7b73fd36-aa31-4771-9066-f6c83879994b", "name": "images", "type": "IMAGE", "linkIds": [14], "label": "reference images", "pos": [-2051.123046875, 2150]}], "outputs": [{"id": "c7b0d930-68a1-48d1-b496-0519e5837064", "name": "STRING", "type": "STRING", "linkIds": [13], "pos": [-620, 2130]}], "widgets": [], "nodes": [{"id": 11, "type": "GeminiNode", "pos": [-1560, 1990], "size": [470, 470], "flags": {}, "order": 0, "mode": 0, "inputs": [{"localized_name": "images", "name": "images", "shape": 7, "type": "IMAGE", "link": 14}, {"localized_name": "audio", "name": "audio", "shape": 7, "type": "AUDIO", "link": null}, {"localized_name": "video", "name": "video", "shape": 7, "type": "VIDEO", "link": null}, {"localized_name": "files", "name": "files", "shape": 7, "type": "GEMINI_INPUT_FILES", "link": null}, {"localized_name": "prompt", "name": "prompt", "type": "STRING", "widget": {"name": "prompt"}, "link": 11}, {"localized_name": "model", "name": "model", "type": "COMBO", "widget": {"name": "model"}, "link": null}, {"localized_name": "seed", "name": "seed", "type": "INT", "widget": {"name": "seed"}, "link": null}, {"localized_name": "system_prompt", "name": "system_prompt", "shape": 7, "type": "STRING", "widget": {"name": "system_prompt"}, "link": null}], "outputs": [{"localized_name": "STRING", "name": "STRING", "type": "STRING", "links": [13]}], "properties": {"cnr_id": "comfy-core", "ver": "0.14.1", "Node name for S&R": "GeminiNode"}, "widgets_values": ["", "gemini-3-pro-preview", 42, "randomize", "You are an expert in prompt writing.\nBased on the input, rewrite the user's input into a detailed prompt.\nincluding camera settings, lighting, composition, and style.\nReturn the prompt only"], "color": "#432", "bgcolor": "#653"}], "groups": [], "links": [{"id": 11, "origin_id": -10, "origin_slot": 0, "target_id": 11, "target_slot": 4, "type": "STRING"}, {"id": 13, "origin_id": 11, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "STRING"}, {"id": 14, "origin_id": -10, "origin_slot": 1, "target_id": 11, "target_slot": 0, "type": "IMAGE"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Text generation/Prompt enhance"}]}, "extra": {}}
{
"revision": 0,
"last_node_id": 15,
"last_link_id": 0,
"nodes": [
{
"id": 15,
"type": "24d8bbfd-39d4-4774-bff0-3de40cc7a471",
"pos": [
-1490,
2040
],
"size": [
400,
260
],
"flags": {},
"order": 0,
"mode": 0,
"inputs": [
{
"name": "prompt",
"type": "STRING",
"widget": {
"name": "prompt"
},
"link": null
},
{
"label": "reference images",
"name": "images",
"type": "IMAGE",
"link": null
}
],
"outputs": [
{
"name": "STRING",
"type": "STRING",
"links": null
}
],
"title": "Prompt Enhance",
"properties": {
"proxyWidgets": [
[
"-1",
"prompt"
]
],
"cnr_id": "comfy-core",
"ver": "0.14.1"
},
"widgets_values": [
""
]
}
],
"links": [],
"version": 0.4,
"definitions": {
"subgraphs": [
{
"id": "24d8bbfd-39d4-4774-bff0-3de40cc7a471",
"version": 1,
"state": {
"lastGroupId": 0,
"lastNodeId": 15,
"lastLinkId": 14,
"lastRerouteId": 0
},
"revision": 0,
"config": {},
"name": "Prompt Enhance",
"inputNode": {
"id": -10,
"bounding": [
-2170,
2110,
138.876953125,
80
]
},
"outputNode": {
"id": -20,
"bounding": [
-640,
2110,
120,
60
]
},
"inputs": [
{
"id": "aeab7216-00e0-4528-a09b-bba50845c5a6",
"name": "prompt",
"type": "STRING",
"linkIds": [
11
],
"pos": [
-2051.123046875,
2130
]
},
{
"id": "7b73fd36-aa31-4771-9066-f6c83879994b",
"name": "images",
"type": "IMAGE",
"linkIds": [
14
],
"label": "reference images",
"pos": [
-2051.123046875,
2150
]
}
],
"outputs": [
{
"id": "c7b0d930-68a1-48d1-b496-0519e5837064",
"name": "STRING",
"type": "STRING",
"linkIds": [
13
],
"pos": [
-620,
2130
]
}
],
"widgets": [],
"nodes": [
{
"id": 11,
"type": "GeminiNode",
"pos": [
-1560,
1990
],
"size": [
470,
470
],
"flags": {},
"order": 0,
"mode": 0,
"inputs": [
{
"localized_name": "images",
"name": "images",
"shape": 7,
"type": "IMAGE",
"link": 14
},
{
"localized_name": "audio",
"name": "audio",
"shape": 7,
"type": "AUDIO",
"link": null
},
{
"localized_name": "video",
"name": "video",
"shape": 7,
"type": "VIDEO",
"link": null
},
{
"localized_name": "files",
"name": "files",
"shape": 7,
"type": "GEMINI_INPUT_FILES",
"link": null
},
{
"localized_name": "prompt",
"name": "prompt",
"type": "STRING",
"widget": {
"name": "prompt"
},
"link": 11
},
{
"localized_name": "model",
"name": "model",
"type": "COMBO",
"widget": {
"name": "model"
},
"link": null
},
{
"localized_name": "seed",
"name": "seed",
"type": "INT",
"widget": {
"name": "seed"
},
"link": null
},
{
"localized_name": "system_prompt",
"name": "system_prompt",
"shape": 7,
"type": "STRING",
"widget": {
"name": "system_prompt"
},
"link": null
}
],
"outputs": [
{
"localized_name": "STRING",
"name": "STRING",
"type": "STRING",
"links": [
13
]
}
],
"properties": {
"cnr_id": "comfy-core",
"ver": "0.14.1",
"Node name for S&R": "GeminiNode"
},
"widgets_values": [
"",
"gemini-3-pro-preview",
42,
"randomize",
"You are an expert in prompt writing.\nBased on the input, rewrite the user's input into a detailed prompt.\nincluding camera settings, lighting, composition, and style.\nReturn the prompt only"
],
"color": "#432",
"bgcolor": "#653"
}
],
"groups": [],
"links": [
{
"id": 11,
"origin_id": -10,
"origin_slot": 0,
"target_id": 11,
"target_slot": 4,
"type": "STRING"
},
{
"id": 13,
"origin_id": 11,
"origin_slot": 0,
"target_id": -20,
"target_slot": 0,
"type": "STRING"
},
{
"id": 14,
"origin_id": -10,
"origin_slot": 1,
"target_id": 11,
"target_slot": 0,
"type": "IMAGE"
}
],
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Text generation/Prompt enhance"
}
]
},
"extra": {}
}

View File

@@ -1 +1,309 @@
{"revision": 0, "last_node_id": 25, "last_link_id": 0, "nodes": [{"id": 25, "type": "621ba4e2-22a8-482d-a369-023753198b7b", "pos": [4610, -790], "size": [230, 58], "flags": {}, "order": 4, "mode": 0, "inputs": [{"label": "image", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": null}], "outputs": [{"label": "IMAGE", "localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": []}], "title": "Sharpen", "properties": {"proxyWidgets": [["24", "value"]]}, "widgets_values": []}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "621ba4e2-22a8-482d-a369-023753198b7b", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 24, "lastLinkId": 36, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Sharpen", "inputNode": {"id": -10, "bounding": [4090, -825, 120, 60]}, "outputNode": {"id": -20, "bounding": [5150, -825, 120, 60]}, "inputs": [{"id": "37011fb7-14b7-4e0e-b1a0-6a02e8da1fd7", "name": "images.image0", "type": "IMAGE", "linkIds": [34], "localized_name": "images.image0", "label": "image", "pos": [4190, -805]}], "outputs": [{"id": "e9182b3f-635c-4cd4-a152-4b4be17ae4b9", "name": "IMAGE0", "type": "IMAGE", "linkIds": [35], "localized_name": "IMAGE0", "label": "IMAGE", "pos": [5170, -805]}], "widgets": [], "nodes": [{"id": 24, "type": "PrimitiveFloat", "pos": [4280, -1240], "size": [270, 58], "flags": {}, "order": 0, "mode": 0, "inputs": [{"label": "strength", "localized_name": "value", "name": "value", "type": "FLOAT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "FLOAT", "name": "FLOAT", "type": "FLOAT", "links": [36]}], "properties": {"Node name for S&R": "PrimitiveFloat", "min": 0, "max": 3, "precision": 2, "step": 0.05}, "widgets_values": [0.5]}, {"id": 23, "type": "GLSLShader", "pos": [4570, -1240], "size": [370, 192], "flags": {}, "order": 1, "mode": 0, "inputs": [{"label": "image0", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": 34}, {"label": "image1", "localized_name": "images.image1", "name": "images.image1", "shape": 7, "type": "IMAGE", "link": null}, {"label": "u_float0", "localized_name": "floats.u_float0", "name": "floats.u_float0", "shape": 7, "type": "FLOAT", "link": 36}, {"label": "u_float1", "localized_name": "floats.u_float1", "name": "floats.u_float1", "shape": 7, "type": "FLOAT", "link": null}, {"label": "u_int0", "localized_name": "ints.u_int0", "name": "ints.u_int0", "shape": 7, "type": "INT", "link": null}, {"localized_name": "fragment_shader", "name": "fragment_shader", "type": "STRING", "widget": {"name": "fragment_shader"}, "link": null}, {"localized_name": "size_mode", "name": "size_mode", "type": "COMFY_DYNAMICCOMBO_V3", "widget": {"name": "size_mode"}, "link": null}], "outputs": [{"localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": [35]}, {"localized_name": "IMAGE1", "name": "IMAGE1", "type": "IMAGE", "links": null}, {"localized_name": "IMAGE2", "name": "IMAGE2", "type": "IMAGE", "links": null}, {"localized_name": "IMAGE3", "name": "IMAGE3", "type": "IMAGE", "links": null}], "properties": {"Node name for S&R": "GLSLShader"}, "widgets_values": ["#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\nuniform vec2 u_resolution;\nuniform float u_float0; // strength [0.0 2.0] typical: 0.31.0\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\n\nvoid main() {\n vec2 texel = 1.0 / u_resolution;\n \n // Sample center and neighbors\n vec4 center = texture(u_image0, v_texCoord);\n vec4 top = texture(u_image0, v_texCoord + vec2( 0.0, -texel.y));\n vec4 bottom = texture(u_image0, v_texCoord + vec2( 0.0, texel.y));\n vec4 left = texture(u_image0, v_texCoord + vec2(-texel.x, 0.0));\n vec4 right = texture(u_image0, v_texCoord + vec2( texel.x, 0.0));\n \n // Edge enhancement (Laplacian)\n vec4 edges = center * 4.0 - top - bottom - left - right;\n \n // Add edges back scaled by strength\n vec4 sharpened = center + edges * u_float0;\n \n fragColor0 = vec4(clamp(sharpened.rgb, 0.0, 1.0), center.a);\n}", "from_input"]}], "groups": [], "links": [{"id": 36, "origin_id": 24, "origin_slot": 0, "target_id": 23, "target_slot": 2, "type": "FLOAT"}, {"id": 34, "origin_id": -10, "origin_slot": 0, "target_id": 23, "target_slot": 0, "type": "IMAGE"}, {"id": 35, "origin_id": 23, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "IMAGE"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Image Tools/Sharpen"}]}}
{
"revision": 0,
"last_node_id": 25,
"last_link_id": 0,
"nodes": [
{
"id": 25,
"type": "621ba4e2-22a8-482d-a369-023753198b7b",
"pos": [
4610,
-790
],
"size": [
230,
58
],
"flags": {},
"order": 4,
"mode": 0,
"inputs": [
{
"label": "image",
"localized_name": "images.image0",
"name": "images.image0",
"type": "IMAGE",
"link": null
}
],
"outputs": [
{
"label": "IMAGE",
"localized_name": "IMAGE0",
"name": "IMAGE0",
"type": "IMAGE",
"links": []
}
],
"title": "Sharpen",
"properties": {
"proxyWidgets": [
[
"24",
"value"
]
]
},
"widgets_values": []
}
],
"links": [],
"version": 0.4,
"definitions": {
"subgraphs": [
{
"id": "621ba4e2-22a8-482d-a369-023753198b7b",
"version": 1,
"state": {
"lastGroupId": 0,
"lastNodeId": 24,
"lastLinkId": 36,
"lastRerouteId": 0
},
"revision": 0,
"config": {},
"name": "Sharpen",
"inputNode": {
"id": -10,
"bounding": [
4090,
-825,
120,
60
]
},
"outputNode": {
"id": -20,
"bounding": [
5150,
-825,
120,
60
]
},
"inputs": [
{
"id": "37011fb7-14b7-4e0e-b1a0-6a02e8da1fd7",
"name": "images.image0",
"type": "IMAGE",
"linkIds": [
34
],
"localized_name": "images.image0",
"label": "image",
"pos": [
4190,
-805
]
}
],
"outputs": [
{
"id": "e9182b3f-635c-4cd4-a152-4b4be17ae4b9",
"name": "IMAGE0",
"type": "IMAGE",
"linkIds": [
35
],
"localized_name": "IMAGE0",
"label": "IMAGE",
"pos": [
5170,
-805
]
}
],
"widgets": [],
"nodes": [
{
"id": 24,
"type": "PrimitiveFloat",
"pos": [
4280,
-1240
],
"size": [
270,
58
],
"flags": {},
"order": 0,
"mode": 0,
"inputs": [
{
"label": "strength",
"localized_name": "value",
"name": "value",
"type": "FLOAT",
"widget": {
"name": "value"
},
"link": null
}
],
"outputs": [
{
"localized_name": "FLOAT",
"name": "FLOAT",
"type": "FLOAT",
"links": [
36
]
}
],
"properties": {
"Node name for S&R": "PrimitiveFloat",
"min": 0,
"max": 3,
"precision": 2,
"step": 0.05
},
"widgets_values": [
0.5
]
},
{
"id": 23,
"type": "GLSLShader",
"pos": [
4570,
-1240
],
"size": [
370,
192
],
"flags": {},
"order": 1,
"mode": 0,
"inputs": [
{
"label": "image0",
"localized_name": "images.image0",
"name": "images.image0",
"type": "IMAGE",
"link": 34
},
{
"label": "image1",
"localized_name": "images.image1",
"name": "images.image1",
"shape": 7,
"type": "IMAGE",
"link": null
},
{
"label": "u_float0",
"localized_name": "floats.u_float0",
"name": "floats.u_float0",
"shape": 7,
"type": "FLOAT",
"link": 36
},
{
"label": "u_float1",
"localized_name": "floats.u_float1",
"name": "floats.u_float1",
"shape": 7,
"type": "FLOAT",
"link": null
},
{
"label": "u_int0",
"localized_name": "ints.u_int0",
"name": "ints.u_int0",
"shape": 7,
"type": "INT",
"link": null
},
{
"localized_name": "fragment_shader",
"name": "fragment_shader",
"type": "STRING",
"widget": {
"name": "fragment_shader"
},
"link": null
},
{
"localized_name": "size_mode",
"name": "size_mode",
"type": "COMFY_DYNAMICCOMBO_V3",
"widget": {
"name": "size_mode"
},
"link": null
}
],
"outputs": [
{
"localized_name": "IMAGE0",
"name": "IMAGE0",
"type": "IMAGE",
"links": [
35
]
},
{
"localized_name": "IMAGE1",
"name": "IMAGE1",
"type": "IMAGE",
"links": null
},
{
"localized_name": "IMAGE2",
"name": "IMAGE2",
"type": "IMAGE",
"links": null
},
{
"localized_name": "IMAGE3",
"name": "IMAGE3",
"type": "IMAGE",
"links": null
}
],
"properties": {
"Node name for S&R": "GLSLShader"
},
"widgets_values": [
"#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\nuniform vec2 u_resolution;\nuniform float u_float0; // strength [0.0 2.0] typical: 0.31.0\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\n\nvoid main() {\n vec2 texel = 1.0 / u_resolution;\n \n // Sample center and neighbors\n vec4 center = texture(u_image0, v_texCoord);\n vec4 top = texture(u_image0, v_texCoord + vec2( 0.0, -texel.y));\n vec4 bottom = texture(u_image0, v_texCoord + vec2( 0.0, texel.y));\n vec4 left = texture(u_image0, v_texCoord + vec2(-texel.x, 0.0));\n vec4 right = texture(u_image0, v_texCoord + vec2( texel.x, 0.0));\n \n // Edge enhancement (Laplacian)\n vec4 edges = center * 4.0 - top - bottom - left - right;\n \n // Add edges back scaled by strength\n vec4 sharpened = center + edges * u_float0;\n \n fragColor0 = vec4(clamp(sharpened.rgb, 0.0, 1.0), center.a);\n}",
"from_input"
]
}
],
"groups": [],
"links": [
{
"id": 36,
"origin_id": 24,
"origin_slot": 0,
"target_id": 23,
"target_slot": 2,
"type": "FLOAT"
},
{
"id": 34,
"origin_id": -10,
"origin_slot": 0,
"target_id": 23,
"target_slot": 0,
"type": "IMAGE"
},
{
"id": 35,
"origin_id": 23,
"origin_slot": 0,
"target_id": -20,
"target_slot": 0,
"type": "IMAGE"
}
],
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Image Tools/Sharpen"
}
]
}
}

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -1 +1,420 @@
{"revision": 0, "last_node_id": 13, "last_link_id": 0, "nodes": [{"id": 13, "type": "cf95b747-3e17-46cb-8097-cac60ff9b2e1", "pos": [1120, 330], "size": [240, 58], "flags": {}, "order": 3, "mode": 0, "inputs": [{"localized_name": "video", "name": "video", "type": "VIDEO", "link": null}, {"name": "model_name", "type": "COMBO", "widget": {"name": "model_name"}, "link": null}], "outputs": [{"localized_name": "VIDEO", "name": "VIDEO", "type": "VIDEO", "links": []}], "title": "Video Upscale(GAN x4)", "properties": {"proxyWidgets": [["-1", "model_name"]], "cnr_id": "comfy-core", "ver": "0.14.1"}, "widgets_values": ["RealESRGAN_x4plus.safetensors"]}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "cf95b747-3e17-46cb-8097-cac60ff9b2e1", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 13, "lastLinkId": 19, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Video Upscale(GAN x4)", "inputNode": {"id": -10, "bounding": [550, 460, 120, 80]}, "outputNode": {"id": -20, "bounding": [1490, 460, 120, 60]}, "inputs": [{"id": "666d633e-93e7-42dc-8d11-2b7b99b0f2a6", "name": "video", "type": "VIDEO", "linkIds": [10], "localized_name": "video", "pos": [650, 480]}, {"id": "2e23a087-caa8-4d65-99e6-662761aa905a", "name": "model_name", "type": "COMBO", "linkIds": [19], "pos": [650, 500]}], "outputs": [{"id": "0c1768ea-3ec2-412f-9af6-8e0fa36dae70", "name": "VIDEO", "type": "VIDEO", "linkIds": [15], "localized_name": "VIDEO", "pos": [1510, 480]}], "widgets": [], "nodes": [{"id": 2, "type": "ImageUpscaleWithModel", "pos": [1110, 450], "size": [320, 46], "flags": {}, "order": 1, "mode": 0, "inputs": [{"localized_name": "upscale_model", "name": "upscale_model", "type": "UPSCALE_MODEL", "link": 1}, {"localized_name": "image", "name": "image", "type": "IMAGE", "link": 14}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "links": [13]}], "properties": {"cnr_id": "comfy-core", "ver": "0.10.0", "Node name for S&R": "ImageUpscaleWithModel"}}, {"id": 11, "type": "CreateVideo", "pos": [1110, 550], "size": [320, 78], "flags": {}, "order": 3, "mode": 0, "inputs": [{"localized_name": "images", "name": "images", "type": "IMAGE", "link": 13}, {"localized_name": "audio", "name": "audio", "shape": 7, "type": "AUDIO", "link": 16}, {"localized_name": "fps", "name": "fps", "type": "FLOAT", "widget": {"name": "fps"}, "link": 12}], "outputs": [{"localized_name": "VIDEO", "name": "VIDEO", "type": "VIDEO", "links": [15]}], "properties": {"cnr_id": "comfy-core", "ver": "0.10.0", "Node name for S&R": "CreateVideo"}, "widgets_values": [30]}, {"id": 10, "type": "GetVideoComponents", "pos": [1110, 330], "size": [320, 70], "flags": {}, "order": 2, "mode": 0, "inputs": [{"localized_name": "video", "name": "video", "type": "VIDEO", "link": 10}], "outputs": [{"localized_name": "images", "name": "images", "type": "IMAGE", "links": [14]}, {"localized_name": "audio", "name": "audio", "type": "AUDIO", "links": [16]}, {"localized_name": "fps", "name": "fps", "type": "FLOAT", "links": [12]}], "properties": {"cnr_id": "comfy-core", "ver": "0.10.0", "Node name for S&R": "GetVideoComponents"}}, {"id": 1, "type": "UpscaleModelLoader", "pos": [750, 450], "size": [280, 60], "flags": {}, "order": 0, "mode": 0, "inputs": [{"localized_name": "model_name", "name": "model_name", "type": "COMBO", "widget": {"name": "model_name"}, "link": 19}], "outputs": [{"localized_name": "UPSCALE_MODEL", "name": "UPSCALE_MODEL", "type": "UPSCALE_MODEL", "links": [1]}], "properties": {"cnr_id": "comfy-core", "ver": "0.10.0", "Node name for S&R": "UpscaleModelLoader", "models": [{"name": "RealESRGAN_x4plus.safetensors", "url": "https://huggingface.co/Comfy-Org/Real-ESRGAN_repackaged/resolve/main/RealESRGAN_x4plus.safetensors", "directory": "upscale_models"}]}, "widgets_values": ["RealESRGAN_x4plus.safetensors"]}], "groups": [], "links": [{"id": 1, "origin_id": 1, "origin_slot": 0, "target_id": 2, "target_slot": 0, "type": "UPSCALE_MODEL"}, {"id": 14, "origin_id": 10, "origin_slot": 0, "target_id": 2, "target_slot": 1, "type": "IMAGE"}, {"id": 13, "origin_id": 2, "origin_slot": 0, "target_id": 11, "target_slot": 0, "type": "IMAGE"}, {"id": 16, "origin_id": 10, "origin_slot": 1, "target_id": 11, "target_slot": 1, "type": "AUDIO"}, {"id": 12, "origin_id": 10, "origin_slot": 2, "target_id": 11, "target_slot": 2, "type": "FLOAT"}, {"id": 10, "origin_id": -10, "origin_slot": 0, "target_id": 10, "target_slot": 0, "type": "VIDEO"}, {"id": 15, "origin_id": 11, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "VIDEO"}, {"id": 19, "origin_id": -10, "origin_slot": 1, "target_id": 1, "target_slot": 0, "type": "COMBO"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Video generation and editing/Enhance video"}]}, "extra": {}}
{
"revision": 0,
"last_node_id": 13,
"last_link_id": 0,
"nodes": [
{
"id": 13,
"type": "cf95b747-3e17-46cb-8097-cac60ff9b2e1",
"pos": [
1120,
330
],
"size": [
240,
58
],
"flags": {},
"order": 3,
"mode": 0,
"inputs": [
{
"localized_name": "video",
"name": "video",
"type": "VIDEO",
"link": null
},
{
"name": "model_name",
"type": "COMBO",
"widget": {
"name": "model_name"
},
"link": null
}
],
"outputs": [
{
"localized_name": "VIDEO",
"name": "VIDEO",
"type": "VIDEO",
"links": []
}
],
"title": "Video Upscale(GAN x4)",
"properties": {
"proxyWidgets": [
[
"-1",
"model_name"
]
],
"cnr_id": "comfy-core",
"ver": "0.14.1"
},
"widgets_values": [
"RealESRGAN_x4plus.safetensors"
]
}
],
"links": [],
"version": 0.4,
"definitions": {
"subgraphs": [
{
"id": "cf95b747-3e17-46cb-8097-cac60ff9b2e1",
"version": 1,
"state": {
"lastGroupId": 0,
"lastNodeId": 13,
"lastLinkId": 19,
"lastRerouteId": 0
},
"revision": 0,
"config": {},
"name": "Video Upscale(GAN x4)",
"inputNode": {
"id": -10,
"bounding": [
550,
460,
120,
80
]
},
"outputNode": {
"id": -20,
"bounding": [
1490,
460,
120,
60
]
},
"inputs": [
{
"id": "666d633e-93e7-42dc-8d11-2b7b99b0f2a6",
"name": "video",
"type": "VIDEO",
"linkIds": [
10
],
"localized_name": "video",
"pos": [
650,
480
]
},
{
"id": "2e23a087-caa8-4d65-99e6-662761aa905a",
"name": "model_name",
"type": "COMBO",
"linkIds": [
19
],
"pos": [
650,
500
]
}
],
"outputs": [
{
"id": "0c1768ea-3ec2-412f-9af6-8e0fa36dae70",
"name": "VIDEO",
"type": "VIDEO",
"linkIds": [
15
],
"localized_name": "VIDEO",
"pos": [
1510,
480
]
}
],
"widgets": [],
"nodes": [
{
"id": 2,
"type": "ImageUpscaleWithModel",
"pos": [
1110,
450
],
"size": [
320,
46
],
"flags": {},
"order": 1,
"mode": 0,
"inputs": [
{
"localized_name": "upscale_model",
"name": "upscale_model",
"type": "UPSCALE_MODEL",
"link": 1
},
{
"localized_name": "image",
"name": "image",
"type": "IMAGE",
"link": 14
}
],
"outputs": [
{
"localized_name": "IMAGE",
"name": "IMAGE",
"type": "IMAGE",
"links": [
13
]
}
],
"properties": {
"cnr_id": "comfy-core",
"ver": "0.10.0",
"Node name for S&R": "ImageUpscaleWithModel"
}
},
{
"id": 11,
"type": "CreateVideo",
"pos": [
1110,
550
],
"size": [
320,
78
],
"flags": {},
"order": 3,
"mode": 0,
"inputs": [
{
"localized_name": "images",
"name": "images",
"type": "IMAGE",
"link": 13
},
{
"localized_name": "audio",
"name": "audio",
"shape": 7,
"type": "AUDIO",
"link": 16
},
{
"localized_name": "fps",
"name": "fps",
"type": "FLOAT",
"widget": {
"name": "fps"
},
"link": 12
}
],
"outputs": [
{
"localized_name": "VIDEO",
"name": "VIDEO",
"type": "VIDEO",
"links": [
15
]
}
],
"properties": {
"cnr_id": "comfy-core",
"ver": "0.10.0",
"Node name for S&R": "CreateVideo"
},
"widgets_values": [
30
]
},
{
"id": 10,
"type": "GetVideoComponents",
"pos": [
1110,
330
],
"size": [
320,
70
],
"flags": {},
"order": 2,
"mode": 0,
"inputs": [
{
"localized_name": "video",
"name": "video",
"type": "VIDEO",
"link": 10
}
],
"outputs": [
{
"localized_name": "images",
"name": "images",
"type": "IMAGE",
"links": [
14
]
},
{
"localized_name": "audio",
"name": "audio",
"type": "AUDIO",
"links": [
16
]
},
{
"localized_name": "fps",
"name": "fps",
"type": "FLOAT",
"links": [
12
]
}
],
"properties": {
"cnr_id": "comfy-core",
"ver": "0.10.0",
"Node name for S&R": "GetVideoComponents"
}
},
{
"id": 1,
"type": "UpscaleModelLoader",
"pos": [
750,
450
],
"size": [
280,
60
],
"flags": {},
"order": 0,
"mode": 0,
"inputs": [
{
"localized_name": "model_name",
"name": "model_name",
"type": "COMBO",
"widget": {
"name": "model_name"
},
"link": 19
}
],
"outputs": [
{
"localized_name": "UPSCALE_MODEL",
"name": "UPSCALE_MODEL",
"type": "UPSCALE_MODEL",
"links": [
1
]
}
],
"properties": {
"cnr_id": "comfy-core",
"ver": "0.10.0",
"Node name for S&R": "UpscaleModelLoader",
"models": [
{
"name": "RealESRGAN_x4plus.safetensors",
"url": "https://huggingface.co/Comfy-Org/Real-ESRGAN_repackaged/resolve/main/RealESRGAN_x4plus.safetensors",
"directory": "upscale_models"
}
]
},
"widgets_values": [
"RealESRGAN_x4plus.safetensors"
]
}
],
"groups": [],
"links": [
{
"id": 1,
"origin_id": 1,
"origin_slot": 0,
"target_id": 2,
"target_slot": 0,
"type": "UPSCALE_MODEL"
},
{
"id": 14,
"origin_id": 10,
"origin_slot": 0,
"target_id": 2,
"target_slot": 1,
"type": "IMAGE"
},
{
"id": 13,
"origin_id": 2,
"origin_slot": 0,
"target_id": 11,
"target_slot": 0,
"type": "IMAGE"
},
{
"id": 16,
"origin_id": 10,
"origin_slot": 1,
"target_id": 11,
"target_slot": 1,
"type": "AUDIO"
},
{
"id": 12,
"origin_id": 10,
"origin_slot": 2,
"target_id": 11,
"target_slot": 2,
"type": "FLOAT"
},
{
"id": 10,
"origin_id": -10,
"origin_slot": 0,
"target_id": 10,
"target_slot": 0,
"type": "VIDEO"
},
{
"id": 15,
"origin_id": 11,
"origin_slot": 0,
"target_id": -20,
"target_slot": 0,
"type": "VIDEO"
},
{
"id": 19,
"origin_id": -10,
"origin_slot": 1,
"target_id": 1,
"target_slot": 0,
"type": "COMBO"
}
],
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Video generation and editing/Enhance video"
}
]
},
"extra": {}
}

View File

@@ -49,7 +49,7 @@ parser.add_argument("--temp-directory", type=str, default=None, help="Set the Co
parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory. Overrides --base-directory.")
parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use. All other devices will not be visible.")
parser.add_argument("--cuda-device", type=str, default=None, metavar="DEVICE_ID", help="Set the ids of cuda devices this instance will use. All other devices will not be visible.")
parser.add_argument("--default-device", type=int, default=None, metavar="DEFAULT_DEVICE_ID", help="Set the id of the default device, all other devices will stay visible.")
cm_group = parser.add_mutually_exclusive_group()
cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
@@ -110,11 +110,13 @@ parser.add_argument("--preview-method", type=LatentPreviewMethod, default=Latent
parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")
CACHE_RAM_AUTO_GB = -1.0
cache_group = parser.add_mutually_exclusive_group()
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
cache_group.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threhold the cache remove large items to free RAM. Default 4GB")
cache_group.add_argument("--cache-ram", nargs='?', const=CACHE_RAM_AUTO_GB, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threshold the cache removes large items to free RAM. Default (when no value is provided): 25%% of system RAM (min 4GB, max 32GB).")
attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")

View File

@@ -15,13 +15,14 @@
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
from __future__ import annotations
import torch
from enum import Enum
import math
import os
import logging
import copy
import comfy.utils
import comfy.model_management
import comfy.model_detection
@@ -38,7 +39,7 @@ import comfy.ldm.hydit.controlnet
import comfy.ldm.flux.controlnet
import comfy.ldm.qwen_image.controlnet
import comfy.cldm.dit_embedder
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Union
if TYPE_CHECKING:
from comfy.hooks import HookGroup
@@ -64,6 +65,18 @@ class StrengthType(Enum):
CONSTANT = 1
LINEAR_UP = 2
class ControlIsolation:
'''Temporarily set a ControlBase object's previous_controlnet to None to prevent cascading calls.'''
def __init__(self, control: ControlBase):
self.control = control
self.orig_previous_controlnet = control.previous_controlnet
def __enter__(self):
self.control.previous_controlnet = None
def __exit__(self, *args):
self.control.previous_controlnet = self.orig_previous_controlnet
class ControlBase:
def __init__(self):
self.cond_hint_original = None
@@ -77,7 +90,7 @@ class ControlBase:
self.compression_ratio = 8
self.upscale_algorithm = 'nearest-exact'
self.extra_args = {}
self.previous_controlnet = None
self.previous_controlnet: Union[ControlBase, None] = None
self.extra_conds = []
self.strength_type = StrengthType.CONSTANT
self.concat_mask = False
@@ -85,6 +98,7 @@ class ControlBase:
self.extra_concat = None
self.extra_hooks: HookGroup = None
self.preprocess_image = lambda a: a
self.multigpu_clones: dict[torch.device, ControlBase] = {}
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
self.cond_hint_original = cond_hint
@@ -111,17 +125,38 @@ class ControlBase:
def cleanup(self):
if self.previous_controlnet is not None:
self.previous_controlnet.cleanup()
for device_cnet in self.multigpu_clones.values():
with ControlIsolation(device_cnet):
device_cnet.cleanup()
self.cond_hint = None
self.extra_concat = None
self.timestep_range = None
def get_models(self):
out = []
for device_cnet in self.multigpu_clones.values():
out += device_cnet.get_models_only_self()
if self.previous_controlnet is not None:
out += self.previous_controlnet.get_models()
return out
def get_models_only_self(self):
'Calls get_models, but temporarily sets previous_controlnet to None.'
with ControlIsolation(self):
return self.get_models()
def get_instance_for_device(self, device):
'Returns instance of this Control object intended for selected device.'
return self.multigpu_clones.get(device, self)
def deepclone_multigpu(self, load_device, autoregister=False):
'''
Create deep clone of Control object where model(s) is set to other devices.
When autoregister is set to True, the deep clone is also added to multigpu_clones dict.
'''
raise NotImplementedError("Classes inheriting from ControlBase should define their own deepclone_multigpu funtion.")
def get_extra_hooks(self):
out = []
if self.extra_hooks is not None:
@@ -130,7 +165,7 @@ class ControlBase:
out += self.previous_controlnet.get_extra_hooks()
return out
def copy_to(self, c):
def copy_to(self, c: ControlBase):
c.cond_hint_original = self.cond_hint_original
c.strength = self.strength
c.timestep_percent_range = self.timestep_percent_range
@@ -284,6 +319,14 @@ class ControlNet(ControlBase):
self.copy_to(c)
return c
def deepclone_multigpu(self, load_device, autoregister=False):
c = self.copy()
c.control_model = copy.deepcopy(c.control_model)
c.control_model_wrapped = comfy.model_patcher.ModelPatcher(c.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
if autoregister:
self.multigpu_clones[load_device] = c
return c
def get_models(self):
out = super().get_models()
out.append(self.control_model_wrapped)
@@ -906,6 +949,14 @@ class T2IAdapter(ControlBase):
self.copy_to(c)
return c
def deepclone_multigpu(self, load_device, autoregister=False):
c = self.copy()
c.t2i_model = copy.deepcopy(c.t2i_model)
c.device = load_device
if autoregister:
self.multigpu_clones[load_device] = c
return c
def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
compression_ratio = 8
upscale_algorithm = 'nearest-exact'

View File

@@ -611,6 +611,7 @@ class AceStepDiTModel(nn.Module):
intermediate_size,
patch_size,
audio_acoustic_hidden_dim,
condition_dim=None,
layer_types=None,
sliding_window=128,
rms_norm_eps=1e-6,
@@ -640,7 +641,7 @@ class AceStepDiTModel(nn.Module):
self.time_embed = TimestepEmbedding(256, hidden_size, dtype=dtype, device=device, operations=operations)
self.time_embed_r = TimestepEmbedding(256, hidden_size, dtype=dtype, device=device, operations=operations)
self.condition_embedder = Linear(hidden_size, hidden_size, dtype=dtype, device=device)
self.condition_embedder = Linear(condition_dim, hidden_size, dtype=dtype, device=device)
if layer_types is None:
layer_types = ["full_attention"] * num_layers
@@ -1035,6 +1036,9 @@ class AceStepConditionGenerationModel(nn.Module):
fsq_dim=2048,
fsq_levels=[8, 8, 8, 5, 5, 5],
fsq_input_num_quantizers=1,
encoder_hidden_size=2048,
encoder_intermediate_size=6144,
encoder_num_heads=16,
audio_model=None,
dtype=None,
device=None,
@@ -1054,24 +1058,24 @@ class AceStepConditionGenerationModel(nn.Module):
self.decoder = AceStepDiTModel(
in_channels, hidden_size, num_dit_layers, num_heads, num_kv_heads, head_dim,
intermediate_size, patch_size, audio_acoustic_hidden_dim,
intermediate_size, patch_size, audio_acoustic_hidden_dim, condition_dim=encoder_hidden_size,
layer_types=layer_types, sliding_window=sliding_window, rms_norm_eps=rms_norm_eps,
dtype=dtype, device=device, operations=operations
)
self.encoder = AceStepConditionEncoder(
text_hidden_dim, timbre_hidden_dim, hidden_size, num_lyric_layers, num_timbre_layers,
num_heads, num_kv_heads, head_dim, intermediate_size, rms_norm_eps,
text_hidden_dim, timbre_hidden_dim, encoder_hidden_size, num_lyric_layers, num_timbre_layers,
encoder_num_heads, num_kv_heads, head_dim, encoder_intermediate_size, rms_norm_eps,
dtype=dtype, device=device, operations=operations
)
self.tokenizer = AceStepAudioTokenizer(
audio_acoustic_hidden_dim, hidden_size, pool_window_size, fsq_dim=fsq_dim, fsq_levels=fsq_levels, fsq_input_num_quantizers=fsq_input_num_quantizers, num_layers=num_tokenizer_layers, head_dim=head_dim, rms_norm_eps=rms_norm_eps,
audio_acoustic_hidden_dim, encoder_hidden_size, pool_window_size, fsq_dim=fsq_dim, fsq_levels=fsq_levels, fsq_input_num_quantizers=fsq_input_num_quantizers, num_layers=num_tokenizer_layers, head_dim=head_dim, rms_norm_eps=rms_norm_eps,
dtype=dtype, device=device, operations=operations
)
self.detokenizer = AudioTokenDetokenizer(
hidden_size, pool_window_size, audio_acoustic_hidden_dim, num_layers=2, head_dim=head_dim,
encoder_hidden_size, pool_window_size, audio_acoustic_hidden_dim, num_layers=2, head_dim=head_dim,
dtype=dtype, device=device, operations=operations
)
self.null_condition_emb = nn.Parameter(torch.empty(1, 1, hidden_size, dtype=dtype, device=device))
self.null_condition_emb = nn.Parameter(torch.empty(1, 1, encoder_hidden_size, dtype=dtype, device=device))
def prepare_condition(
self,

View File

@@ -386,7 +386,7 @@ class Flux(nn.Module):
h = max(h, ref.shape[-2] + h_offset)
w = max(w, ref.shape[-1] + w_offset)
kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset, transformer_options=transformer_options)
img = torch.cat([img, kontext], dim=1)
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
ref_num_tokens.append(kontext.shape[1])

View File

@@ -681,6 +681,33 @@ class LTXAVModel(LTXVModel):
additional_args["has_spatial_mask"] = has_spatial_mask
ax, a_latent_coords = self.a_patchifier.patchify(ax)
# Inject reference audio for ID-LoRA in-context conditioning
ref_audio = kwargs.get("ref_audio", None)
ref_audio_seq_len = 0
if ref_audio is not None:
ref_tokens = ref_audio["tokens"].to(dtype=ax.dtype, device=ax.device)
if ref_tokens.shape[0] < ax.shape[0]:
ref_tokens = ref_tokens.expand(ax.shape[0], -1, -1)
ref_audio_seq_len = ref_tokens.shape[1]
B = ax.shape[0]
# Compute negative temporal positions matching ID-LoRA convention:
# offset by -(end_of_last_token + time_per_latent) so reference ends just before t=0
p = self.a_patchifier
tpl = p.hop_length * p.audio_latent_downsample_factor / p.sample_rate
ref_start = p._get_audio_latent_time_in_sec(0, ref_audio_seq_len, torch.float32, ax.device)
ref_end = p._get_audio_latent_time_in_sec(1, ref_audio_seq_len + 1, torch.float32, ax.device)
time_offset = ref_end[-1].item() + tpl
ref_start = (ref_start - time_offset).unsqueeze(0).expand(B, -1).unsqueeze(1)
ref_end = (ref_end - time_offset).unsqueeze(0).expand(B, -1).unsqueeze(1)
ref_pos = torch.stack([ref_start, ref_end], dim=-1)
additional_args["ref_audio_seq_len"] = ref_audio_seq_len
additional_args["target_audio_seq_len"] = ax.shape[1]
ax = torch.cat([ref_tokens, ax], dim=1)
a_latent_coords = torch.cat([ref_pos.to(a_latent_coords), a_latent_coords], dim=2)
ax = self.audio_patchify_proj(ax)
# additional_args.update({"av_orig_shape": list(x.shape)})
@@ -721,6 +748,14 @@ class LTXAVModel(LTXVModel):
# Prepare audio timestep
a_timestep = kwargs.get("a_timestep")
ref_audio_seq_len = kwargs.get("ref_audio_seq_len", 0)
if ref_audio_seq_len > 0 and a_timestep is not None:
# Reference tokens must have timestep=0, expand scalar/1D timestep to per-token so ref=0 and target=sigma.
target_len = kwargs.get("target_audio_seq_len")
if a_timestep.dim() <= 1:
a_timestep = a_timestep.view(-1, 1).expand(batch_size, target_len)
ref_ts = torch.zeros(batch_size, ref_audio_seq_len, *a_timestep.shape[2:], device=a_timestep.device, dtype=a_timestep.dtype)
a_timestep = torch.cat([ref_ts, a_timestep], dim=1)
if a_timestep is not None:
a_timestep_scaled = a_timestep * self.timestep_scale_multiplier
a_timestep_flat = a_timestep_scaled.flatten()
@@ -955,6 +990,13 @@ class LTXAVModel(LTXVModel):
v_embedded_timestep = embedded_timestep[0]
a_embedded_timestep = embedded_timestep[1]
# Trim reference audio tokens before unpatchification
ref_audio_seq_len = kwargs.get("ref_audio_seq_len", 0)
if ref_audio_seq_len > 0:
ax = ax[:, ref_audio_seq_len:]
if a_embedded_timestep.shape[1] > 1:
a_embedded_timestep = a_embedded_timestep[:, ref_audio_seq_len:]
# Expand compressed video timestep if needed
if isinstance(v_embedded_timestep, CompressedTimestep):
v_embedded_timestep = v_embedded_timestep.expand()

View File

@@ -155,6 +155,7 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
def __init__(self, embed_dim: int, **kwargs):
self.max_batch_size = kwargs.pop("max_batch_size", None)
ddconfig = kwargs.pop("ddconfig")
decoder_ddconfig = kwargs.pop("decoder_ddconfig", ddconfig)
super().__init__(
encoder_config={
"target": "comfy.ldm.modules.diffusionmodules.model.Encoder",
@@ -162,7 +163,7 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
},
decoder_config={
"target": "comfy.ldm.modules.diffusionmodules.model.Decoder",
"params": ddconfig,
"params": decoder_ddconfig,
},
**kwargs,
)

View File

@@ -3,12 +3,9 @@ from ..diffusionmodules.openaimodel import Timestep
import torch
class CLIPEmbeddingNoiseAugmentation(ImageConcatWithNoiseAugmentation):
def __init__(self, *args, clip_stats_path=None, timestep_dim=256, **kwargs):
def __init__(self, *args, timestep_dim=256, **kwargs):
super().__init__(*args, **kwargs)
if clip_stats_path is None:
clip_mean, clip_std = torch.zeros(timestep_dim), torch.ones(timestep_dim)
else:
clip_mean, clip_std = torch.load(clip_stats_path, map_location="cpu")
clip_mean, clip_std = torch.zeros(timestep_dim), torch.ones(timestep_dim)
self.register_buffer("data_mean", clip_mean[None, :], persistent=False)
self.register_buffer("data_std", clip_std[None, :], persistent=False)
self.time_embed = Timestep(timestep_dim)

View File

@@ -0,0 +1,725 @@
from collections import OrderedDict
from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import comfy.model_management
from comfy.ldm.modules.attention import optimized_attention_for_device
COCO_CLASSES = [
'person','bicycle','car','motorcycle','airplane','bus','train','truck','boat',
'traffic light','fire hydrant','stop sign','parking meter','bench','bird','cat',
'dog','horse','sheep','cow','elephant','bear','zebra','giraffe','backpack',
'umbrella','handbag','tie','suitcase','frisbee','skis','snowboard','sports ball',
'kite','baseball bat','baseball glove','skateboard','surfboard','tennis racket',
'bottle','wine glass','cup','fork','knife','spoon','bowl','banana','apple',
'sandwich','orange','broccoli','carrot','hot dog','pizza','donut','cake','chair',
'couch','potted plant','bed','dining table','toilet','tv','laptop','mouse',
'remote','keyboard','cell phone','microwave','oven','toaster','sink',
'refrigerator','book','clock','vase','scissors','teddy bear','hair drier','toothbrush',
]
# ---------------------------------------------------------------------------
# HGNetv2 backbone
# ---------------------------------------------------------------------------
class ConvBNAct(nn.Module):
"""Conv→BN→ReLU. padding='same' adds asymmetric zero-pad (stem)."""
def __init__(self, ic, oc, k=3, s=1, groups=1, use_act=True, device=None, dtype=None, operations=None):
super().__init__()
self.conv = operations.Conv2d(ic, oc, k, s, (k - 1) // 2, groups=groups, bias=False, device=device, dtype=dtype)
self.bn = nn.BatchNorm2d(oc, device=device, dtype=dtype)
self.act = nn.ReLU() if use_act else nn.Identity()
def forward(self, x):
return self.act(self.bn(self.conv(x)))
class LightConvBNAct(nn.Module):
def __init__(self, ic, oc, k, device=None, dtype=None, operations=None):
super().__init__()
self.conv1 = ConvBNAct(ic, oc, 1, use_act=False, device=device, dtype=dtype, operations=operations)
self.conv2 = ConvBNAct(oc, oc, k, groups=oc, use_act=True, device=device, dtype=dtype, operations=operations)
def forward(self, x):
return self.conv2(self.conv1(x))
class _StemBlock(nn.Module):
def __init__(self, ic, mc, oc, device=None, dtype=None, operations=None):
super().__init__()
self.stem1 = ConvBNAct(ic, mc, 3, 2, device=device, dtype=dtype, operations=operations)
# stem2a/stem2b use kernel=2, stride=1, no internal padding;
# padding is applied manually in forward (matching PaddlePaddle original)
self.stem2a = ConvBNAct(mc, mc//2, 2, 1, device=device, dtype=dtype, operations=operations)
self.stem2b = ConvBNAct(mc//2, mc, 2, 1, device=device, dtype=dtype, operations=operations)
self.stem3 = ConvBNAct(mc*2, mc, 3, 2, device=device, dtype=dtype, operations=operations)
self.stem4 = ConvBNAct(mc, oc, 1, device=device, dtype=dtype, operations=operations)
self.pool = nn.MaxPool2d(2, 1, ceil_mode=True)
def forward(self, x):
x = self.stem1(x)
x = F.pad(x, (0, 1, 0, 1)) # pad before pool and stem2a
x2 = self.stem2a(x)
x2 = F.pad(x2, (0, 1, 0, 1)) # pad before stem2b
x2 = self.stem2b(x2)
x1 = self.pool(x)
return self.stem4(self.stem3(torch.cat([x1, x2], 1)))
class _HG_Block(nn.Module):
def __init__(self, ic, mc, oc, layer_num, k=3, residual=False, light=False, device=None, dtype=None, operations=None):
super().__init__()
self.residual = residual
if light:
self.layers = nn.ModuleList(
[LightConvBNAct(ic if i == 0 else mc, mc, k, device=device, dtype=dtype, operations=operations) for i in range(layer_num)])
else:
self.layers = nn.ModuleList(
[ConvBNAct(ic if i == 0 else mc, mc, k, device=device, dtype=dtype, operations=operations) for i in range(layer_num)])
total = ic + layer_num * mc
self.aggregation = nn.Sequential(
ConvBNAct(total, oc // 2, 1, device=device, dtype=dtype, operations=operations),
ConvBNAct(oc // 2, oc, 1, device=device, dtype=dtype, operations=operations))
def forward(self, x):
identity = x
outs = [x]
for layer in self.layers:
x = layer(x)
outs.append(x)
x = self.aggregation(torch.cat(outs, 1))
return x + identity if self.residual else x
class _HG_Stage(nn.Module):
# config order: ic, mc, oc, num_blocks, downsample, light, k, layer_num
def __init__(self, ic, mc, oc, num_blocks, downsample=True, light=False, k=3, layer_num=6, device=None, dtype=None, operations=None):
super().__init__()
if downsample:
self.downsample = ConvBNAct(ic, ic, 3, 2, groups=ic, use_act=False, device=device, dtype=dtype, operations=operations)
else:
self.downsample = nn.Identity()
self.blocks = nn.Sequential(*[
_HG_Block(ic if i == 0 else oc, mc, oc, layer_num,
k=k, residual=(i != 0), light=light, device=device, dtype=dtype, operations=operations)
for i in range(num_blocks)
])
def forward(self, x):
return self.blocks(self.downsample(x))
class HGNetv2(nn.Module):
# B5 config: stem=[3,32,64], stages=[ic, mc, oc, blocks, down, light, k, layers]
_STAGE_CFGS = [[64, 64, 128, 1, False, False, 3, 6],
[128, 128, 512, 2, True, False, 3, 6],
[512, 256, 1024, 5, True, True, 5, 6],
[1024,512, 2048, 2, True, True, 5, 6]]
def __init__(self, return_idx=(1, 2, 3), device=None, dtype=None, operations=None):
super().__init__()
self.stem = _StemBlock(3, 32, 64, device=device, dtype=dtype, operations=operations)
self.stages = nn.ModuleList([_HG_Stage(*cfg, device=device, dtype=dtype, operations=operations) for cfg in self._STAGE_CFGS])
self.return_idx = list(return_idx)
self.out_channels = [self._STAGE_CFGS[i][2] for i in return_idx]
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
x = self.stem(x)
outs = []
for i, stage in enumerate(self.stages):
x = stage(x)
if i in self.return_idx:
outs.append(x)
return outs
# ---------------------------------------------------------------------------
# Encoder — HybridEncoder (dfine version: RepNCSPELAN4 + SCDown PAN)
# ---------------------------------------------------------------------------
class ConvNormLayer(nn.Module):
"""Conv→act (expects pre-fused BN weights)."""
def __init__(self, ic, oc, k, s, g=1, padding=None, act=None, device=None, dtype=None, operations=None):
super().__init__()
p = (k - 1) // 2 if padding is None else padding
self.conv = operations.Conv2d(ic, oc, k, s, p, groups=g, bias=True, device=device, dtype=dtype)
self.act = nn.SiLU() if act == 'silu' else nn.Identity()
def forward(self, x):
return self.act(self.conv(x))
class VGGBlock(nn.Module):
"""Rep-VGG block (expects pre-fused weights)."""
def __init__(self, ic, oc, device=None, dtype=None, operations=None):
super().__init__()
self.conv = operations.Conv2d(ic, oc, 3, 1, padding=1, bias=True, device=device, dtype=dtype)
self.act = nn.SiLU()
def forward(self, x):
return self.act(self.conv(x))
class CSPLayer(nn.Module):
def __init__(self, ic, oc, num_blocks=3, expansion=1.0, act='silu', device=None, dtype=None, operations=None):
super().__init__()
h = int(oc * expansion)
self.conv1 = ConvNormLayer(ic, h, 1, 1, act=act, device=device, dtype=dtype, operations=operations)
self.conv2 = ConvNormLayer(ic, h, 1, 1, act=act, device=device, dtype=dtype, operations=operations)
self.bottlenecks = nn.Sequential(*[VGGBlock(h, h, device=device, dtype=dtype, operations=operations) for _ in range(num_blocks)])
self.conv3 = ConvNormLayer(h, oc, 1, 1, act=act, device=device, dtype=dtype, operations=operations) if h != oc else nn.Identity()
def forward(self, x):
return self.conv3(self.bottlenecks(self.conv1(x)) + self.conv2(x))
class RepNCSPELAN4(nn.Module):
"""CSP-ELAN block — the FPN/PAN block in RTv4's HybridEncoder."""
def __init__(self, c1, c2, c3, c4, n=3, act='silu', device=None, dtype=None, operations=None):
super().__init__()
self.c = c3 // 2
self.cv1 = ConvNormLayer(c1, c3, 1, 1, act=act, device=device, dtype=dtype, operations=operations)
self.cv2 = nn.Sequential(CSPLayer(c3 // 2, c4, n, 1.0, act=act, device=device, dtype=dtype, operations=operations), ConvNormLayer(c4, c4, 3, 1, act=act, device=device, dtype=dtype, operations=operations))
self.cv3 = nn.Sequential(CSPLayer(c4, c4, n, 1.0, act=act, device=device, dtype=dtype, operations=operations), ConvNormLayer(c4, c4, 3, 1, act=act, device=device, dtype=dtype, operations=operations))
self.cv4 = ConvNormLayer(c3 + 2 * c4, c2, 1, 1, act=act, device=device, dtype=dtype, operations=operations)
def forward(self, x):
y = list(self.cv1(x).split((self.c, self.c), 1))
y.extend(m(y[-1]) for m in [self.cv2, self.cv3])
return self.cv4(torch.cat(y, 1))
class SCDown(nn.Module):
"""Separable conv downsampling used in HybridEncoder PAN bottom-up path."""
def __init__(self, ic, oc, k, s, device=None, dtype=None, operations=None):
super().__init__()
self.cv1 = ConvNormLayer(ic, oc, 1, 1, device=device, dtype=dtype, operations=operations)
self.cv2 = ConvNormLayer(oc, oc, k, s, g=oc, device=device, dtype=dtype, operations=operations)
def forward(self, x):
return self.cv2(self.cv1(x))
class SelfAttention(nn.Module):
def __init__(self, embed_dim, num_heads, device=None, dtype=None, operations=None):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.q_proj = operations.Linear(embed_dim, embed_dim, device=device, dtype=dtype)
self.k_proj = operations.Linear(embed_dim, embed_dim, device=device, dtype=dtype)
self.v_proj = operations.Linear(embed_dim, embed_dim, device=device, dtype=dtype)
self.out_proj = operations.Linear(embed_dim, embed_dim, device=device, dtype=dtype)
def forward(self, query, key, value, attn_mask=None):
optimized_attention = optimized_attention_for_device(query.device, False, small_input=True)
q, k, v = self.q_proj(query), self.k_proj(key), self.v_proj(value)
out = optimized_attention(q, k, v, heads=self.num_heads, mask=attn_mask)
return self.out_proj(out)
class _TransformerEncoderLayer(nn.Module):
"""Single AIFI encoder layer (pre- or post-norm, GELU by default)."""
def __init__(self, d_model, nhead, dim_feedforward, device=None, dtype=None, operations=None):
super().__init__()
self.self_attn = SelfAttention(d_model, nhead, device=device, dtype=dtype, operations=operations)
self.linear1 = operations.Linear(d_model, dim_feedforward, device=device, dtype=dtype)
self.linear2 = operations.Linear(dim_feedforward, d_model, device=device, dtype=dtype)
self.norm1 = operations.LayerNorm(d_model, device=device, dtype=dtype)
self.norm2 = operations.LayerNorm(d_model, device=device, dtype=dtype)
self.activation = nn.GELU()
def forward(self, src, src_mask=None, pos_embed=None):
q = k = src if pos_embed is None else src + pos_embed
src2 = self.self_attn(q, k, value=src, attn_mask=src_mask)
src = self.norm1(src + src2)
src2 = self.linear2(self.activation(self.linear1(src)))
return self.norm2(src + src2)
class _TransformerEncoder(nn.Module):
"""Thin wrapper so state-dict keys are encoder.0.layers.N.*"""
def __init__(self, num_layers, d_model, nhead, dim_feedforward, device=None, dtype=None, operations=None):
super().__init__()
self.layers = nn.ModuleList([
_TransformerEncoderLayer(d_model, nhead, dim_feedforward, device=device, dtype=dtype, operations=operations)
for _ in range(num_layers)
])
def forward(self, src, src_mask=None, pos_embed=None):
for layer in self.layers:
src = layer(src, src_mask=src_mask, pos_embed=pos_embed)
return src
class HybridEncoder(nn.Module):
def __init__(self, in_channels=(512, 1024, 2048), feat_strides=(8, 16, 32), hidden_dim=256, nhead=8, dim_feedforward=2048, use_encoder_idx=(2,), num_encoder_layers=1,
pe_temperature=10000, expansion=1.0, depth_mult=1.0, act='silu', eval_spatial_size=(640, 640), device=None, dtype=None, operations=None):
super().__init__()
self.in_channels = list(in_channels)
self.feat_strides = list(feat_strides)
self.hidden_dim = hidden_dim
self.use_encoder_idx = list(use_encoder_idx)
self.pe_temperature = pe_temperature
self.eval_spatial_size = eval_spatial_size
self.out_channels = [hidden_dim] * len(in_channels)
self.out_strides = list(feat_strides)
# channel projection (expects pre-fused weights)
self.input_proj = nn.ModuleList([
nn.Sequential(OrderedDict([('conv', operations.Conv2d(ch, hidden_dim, 1, bias=True, device=device, dtype=dtype))]))
for ch in in_channels
])
# AIFI transformer — use _TransformerEncoder so keys are encoder.0.layers.N.*
self.encoder = nn.ModuleList([
_TransformerEncoder(num_encoder_layers, hidden_dim, nhead, dim_feedforward, device=device, dtype=dtype, operations=operations)
for _ in range(len(use_encoder_idx))
])
nb = round(3 * depth_mult)
exp = expansion
# top-down FPN (dfine: lateral conv has no act)
self.lateral_convs = nn.ModuleList(
[ConvNormLayer(hidden_dim, hidden_dim, 1, 1, device=device, dtype=dtype, operations=operations)
for _ in range(len(in_channels) - 1)])
self.fpn_blocks = nn.ModuleList(
[RepNCSPELAN4(hidden_dim * 2, hidden_dim, hidden_dim * 2, round(exp * hidden_dim // 2), nb, act=act, device=device, dtype=dtype, operations=operations)
for _ in range(len(in_channels) - 1)])
# bottom-up PAN (dfine: nn.Sequential(SCDown) — keeps checkpoint key .0.cv1/.0.cv2)
self.downsample_convs = nn.ModuleList(
[nn.Sequential(SCDown(hidden_dim, hidden_dim, 3, 2, device=device, dtype=dtype, operations=operations))
for _ in range(len(in_channels) - 1)])
self.pan_blocks = nn.ModuleList(
[RepNCSPELAN4(hidden_dim * 2, hidden_dim, hidden_dim * 2, round(exp * hidden_dim // 2), nb, act=act, device=device, dtype=dtype, operations=operations)
for _ in range(len(in_channels) - 1)])
# cache positional embeddings for fixed spatial size
if eval_spatial_size:
for idx in self.use_encoder_idx:
stride = self.feat_strides[idx]
pe = self._build_pe(eval_spatial_size[1] // stride,
eval_spatial_size[0] // stride,
hidden_dim, pe_temperature)
setattr(self, f'pos_embed{idx}', pe)
@staticmethod
def _build_pe(w, h, dim=256, temp=10000.):
assert dim % 4 == 0
gw = torch.arange(w, dtype=torch.float32)
gh = torch.arange(h, dtype=torch.float32)
gw, gh = torch.meshgrid(gw, gh, indexing='ij')
pdim = dim // 4
omega = 1. / (temp ** (torch.arange(pdim, dtype=torch.float32) / pdim))
ow = gw.flatten()[:, None] @ omega[None]
oh = gh.flatten()[:, None] @ omega[None]
return torch.cat([ow.sin(), ow.cos(), oh.sin(), oh.cos()], 1)[None]
def forward(self, feats: List[torch.Tensor]) -> List[torch.Tensor]:
proj = [self.input_proj[i](f) for i, f in enumerate(feats)]
for i, enc_idx in enumerate(self.use_encoder_idx):
h, w = proj[enc_idx].shape[2:]
src = proj[enc_idx].flatten(2).permute(0, 2, 1)
pe = getattr(self, f'pos_embed{enc_idx}').to(device=src.device, dtype=src.dtype)
for layer in self.encoder[i].layers:
src = layer(src, pos_embed=pe)
proj[enc_idx] = src.permute(0, 2, 1).reshape(-1, self.hidden_dim, h, w).contiguous()
n = len(self.in_channels)
inner = [proj[-1]]
for k in range(n - 1, 0, -1):
j = n - 1 - k
top = self.lateral_convs[j](inner[0])
inner[0] = top
up = F.interpolate(top, scale_factor=2., mode='nearest')
inner.insert(0, self.fpn_blocks[j](torch.cat([up, proj[k - 1]], 1)))
outs = [inner[0]]
for k in range(n - 1):
outs.append(self.pan_blocks[k](
torch.cat([self.downsample_convs[k](outs[-1]), inner[k + 1]], 1)))
return outs
# ---------------------------------------------------------------------------
# Decoder — DFINETransformer
# ---------------------------------------------------------------------------
def _deformable_attn_v2(value: list, spatial_shapes, sampling_locations: torch.Tensor, attention_weights: torch.Tensor, num_points_list: List[int]) -> torch.Tensor:
"""
value : list of per-level tensors [bs*n_head, c, h_l, w_l]
sampling_locations: [bs, Lq, n_head, sum(pts), 2] in [0,1]
attention_weights : [bs, Lq, n_head, sum(pts)]
"""
_, c = value[0].shape[:2] # bs*n_head, c
_, Lq, n_head, _, _ = sampling_locations.shape
bs = sampling_locations.shape[0]
n_h = n_head
grids = (2 * sampling_locations - 1) # [bs, Lq, n_head, sum_pts, 2]
grids = grids.permute(0, 2, 1, 3, 4).flatten(0, 1) # [bs*n_head, Lq, sum_pts, 2]
grids_per_lvl = grids.split(num_points_list, dim=2) # list of [bs*n_head, Lq, pts_l, 2]
sampled = []
for lvl, (h, w) in enumerate(spatial_shapes):
val_l = value[lvl].reshape(bs * n_h, c, h, w)
sv = F.grid_sample(val_l, grids_per_lvl[lvl], mode='bilinear', padding_mode='zeros', align_corners=False)
sampled.append(sv) # sv: [bs*n_head, c, Lq, pts_l]
attn = attention_weights.permute(0, 2, 1, 3) # [bs, n_head, Lq, sum_pts]
attn = attn.flatten(0, 1).unsqueeze(1) # [bs*n_head, 1, Lq, sum_pts]
out = (torch.cat(sampled, -1) * attn).sum(-1) # [bs*n_head, c, Lq]
out = out.reshape(bs, n_h * c, Lq)
return out.permute(0, 2, 1) # [bs, Lq, hidden]
class MSDeformableAttention(nn.Module):
def __init__(self, embed_dim=256, num_heads=8, num_levels=3, num_points=4, offset_scale=0.5, device=None, dtype=None, operations=None):
super().__init__()
self.embed_dim, self.num_heads = embed_dim, num_heads
self.head_dim = embed_dim // num_heads
pts = num_points if isinstance(num_points, list) else [num_points] * num_levels
self.num_points_list = pts
self.offset_scale = offset_scale
total = num_heads * sum(pts)
self.register_buffer('num_points_scale', torch.tensor([1. / n for n in pts for _ in range(n)], dtype=torch.float32))
self.sampling_offsets = operations.Linear(embed_dim, total * 2, device=device, dtype=dtype)
self.attention_weights = operations.Linear(embed_dim, total, device=device, dtype=dtype)
def forward(self, query, ref_pts, value, spatial_shapes):
bs, Lq = query.shape[:2]
offsets = self.sampling_offsets(query).reshape(
bs, Lq, self.num_heads, sum(self.num_points_list), 2)
attn_w = F.softmax(
self.attention_weights(query).reshape(
bs, Lq, self.num_heads, sum(self.num_points_list)), -1)
scale = self.num_points_scale.to(query).unsqueeze(-1)
offset = offsets * scale * ref_pts[:, :, None, :, 2:] * self.offset_scale
locs = ref_pts[:, :, None, :, :2] + offset # [bs, Lq, n_head, sum_pts, 2]
return _deformable_attn_v2(value, spatial_shapes, locs, attn_w, self.num_points_list)
class Gate(nn.Module):
def __init__(self, d_model, device=None, dtype=None, operations=None):
super().__init__()
self.gate = operations.Linear(2 * d_model, 2 * d_model, device=device, dtype=dtype)
self.norm = operations.LayerNorm(d_model, device=device, dtype=dtype)
def forward(self, x1, x2):
g1, g2 = torch.sigmoid(self.gate(torch.cat([x1, x2], -1))).chunk(2, -1)
return self.norm(g1 * x1 + g2 * x2)
class MLP(nn.Module):
def __init__(self, in_dim, hidden_dim, out_dim, num_layers, device=None, dtype=None, operations=None):
super().__init__()
dims = [in_dim] + [hidden_dim] * (num_layers - 1) + [out_dim]
self.layers = nn.ModuleList(operations.Linear(dims[i], dims[i + 1], device=device, dtype=dtype) for i in range(num_layers))
def forward(self, x):
for i, layer in enumerate(self.layers):
x = nn.SiLU()(layer(x)) if i < len(self.layers) - 1 else layer(x)
return x
class TransformerDecoderLayer(nn.Module):
def __init__(self, d_model=256, nhead=8, dim_feedforward=1024, num_levels=3, num_points=4, device=None, dtype=None, operations=None):
super().__init__()
self.self_attn = SelfAttention(d_model, nhead, device=device, dtype=dtype, operations=operations)
self.norm1 = operations.LayerNorm(d_model, device=device, dtype=dtype)
self.cross_attn = MSDeformableAttention(d_model, nhead, num_levels, num_points, device=device, dtype=dtype, operations=operations)
self.gateway = Gate(d_model, device=device, dtype=dtype, operations=operations)
self.linear1 = operations.Linear(d_model, dim_feedforward, device=device, dtype=dtype)
self.activation = nn.ReLU()
self.linear2 = operations.Linear(dim_feedforward, d_model, device=device, dtype=dtype)
self.norm3 = operations.LayerNorm(d_model, device=device, dtype=dtype)
def forward(self, target, ref_pts, value, spatial_shapes, attn_mask=None, query_pos=None):
q = k = target if query_pos is None else target + query_pos
t2 = self.self_attn(q, k, value=target, attn_mask=attn_mask)
target = self.norm1(target + t2)
t2 = self.cross_attn(
target if query_pos is None else target + query_pos,
ref_pts, value, spatial_shapes)
target = self.gateway(target, t2)
t2 = self.linear2(self.activation(self.linear1(target)))
target = self.norm3((target + t2).clamp(-65504, 65504))
return target
# ---------------------------------------------------------------------------
# FDR utilities
# ---------------------------------------------------------------------------
def weighting_function(reg_max, up, reg_scale):
"""Non-uniform weighting function W(n) for FDR box regression."""
ub1 = (abs(up[0]) * abs(reg_scale)).item()
ub2 = ub1 * 2
step = (ub1 + 1) ** (2 / (reg_max - 2))
left = [-(step ** i) + 1 for i in range(reg_max // 2 - 1, 0, -1)]
right = [ (step ** i) - 1 for i in range(1, reg_max // 2)]
vals = [-ub2] + left + [0] + right + [ub2]
return torch.tensor(vals, dtype=up.dtype, device=up.device)
def distance2bbox(points, distance, reg_scale):
"""Decode edge-distances → cxcywh boxes."""
rs = abs(reg_scale).to(dtype=points.dtype)
x1 = points[..., 0] - (0.5 * rs + distance[..., 0]) * (points[..., 2] / rs)
y1 = points[..., 1] - (0.5 * rs + distance[..., 1]) * (points[..., 3] / rs)
x2 = points[..., 0] + (0.5 * rs + distance[..., 2]) * (points[..., 2] / rs)
y2 = points[..., 1] + (0.5 * rs + distance[..., 3]) * (points[..., 3] / rs)
x0, y0, x1_, y1_ = (x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1
return torch.stack([x0, y0, x1_, y1_], -1)
class Integral(nn.Module):
"""Sum Pr(n)·W(n) over the distribution bins."""
def __init__(self, reg_max=32):
super().__init__()
self.reg_max = reg_max
def forward(self, x, project):
shape = x.shape
x = F.softmax(x.reshape(-1, self.reg_max + 1), 1)
x = F.linear(x, project.to(device=x.device, dtype=x.dtype)).reshape(-1, 4)
return x.reshape(list(shape[:-1]) + [-1])
class LQE(nn.Module):
"""Location Quality Estimator — refines class scores using corner distribution."""
def __init__(self, k=4, hidden_dim=64, num_layers=2, reg_max=32, device=None, dtype=None, operations=None):
super().__init__()
self.k, self.reg_max = k, reg_max
self.reg_conf = MLP(4 * (k + 1), hidden_dim, 1, num_layers, device=device, dtype=dtype, operations=operations)
def forward(self, scores, pred_corners):
B, L, _ = pred_corners.shape
prob = F.softmax(pred_corners.reshape(B, L, 4, self.reg_max + 1), -1)
topk, _ = prob.topk(self.k, -1)
stat = torch.cat([topk, topk.mean(-1, keepdim=True)], -1)
return scores + self.reg_conf(stat.reshape(B, L, -1))
class TransformerDecoder(nn.Module):
def __init__(self, hidden_dim, nhead, dim_feedforward, num_levels, num_points, num_layers, reg_max, reg_scale, up, eval_idx=-1, device=None, dtype=None, operations=None):
super().__init__()
self.hidden_dim = hidden_dim
self.num_layers = num_layers
self.nhead = nhead
self.eval_idx = eval_idx if eval_idx >= 0 else num_layers + eval_idx
self.up, self.reg_scale, self.reg_max = up, reg_scale, reg_max
self.layers = nn.ModuleList([
TransformerDecoderLayer(hidden_dim, nhead, dim_feedforward, num_levels, num_points, device=device, dtype=dtype, operations=operations)
for _ in range(self.eval_idx + 1)
])
self.lqe_layers = nn.ModuleList([LQE(4, 64, 2, reg_max, device=device, dtype=dtype, operations=operations) for _ in range(self.eval_idx + 1)])
self.register_buffer('project', weighting_function(reg_max, up, reg_scale))
def _value_op(self, memory, spatial_shapes):
"""Reshape memory to per-level value tensors for deformable attention."""
c = self.hidden_dim // self.nhead
split = [h * w for h, w in spatial_shapes]
val = memory.reshape(memory.shape[0], memory.shape[1], self.nhead, c) # memory: [bs, sum(h*w), hidden_dim]
# → [bs, n_head, c, sum_hw]
val = val.permute(0, 2, 3, 1).flatten(0, 1) # [bs*n_head, c, sum_hw]
return val.split(split, dim=-1) # list of [bs*n_head, c, h_l*w_l]
def forward(self, target, ref_pts_unact, memory, spatial_shapes, bbox_head, score_head, query_pos_head, pre_bbox_head, integral):
val_split_flat = self._value_op(memory, spatial_shapes) # pre-split value for deformable attention
# reshape to [bs*n_head, c, h_l, w_l]
value = []
for lvl, (h, w) in enumerate(spatial_shapes):
v = val_split_flat[lvl] # [bs*n_head, c, h*w]
value.append(v.reshape(v.shape[0], v.shape[1], h, w))
ref_pts = F.sigmoid(ref_pts_unact)
output = target
output_detach = pred_corners_undetach = 0
dec_bboxes, dec_logits = [], []
for i, layer in enumerate(self.layers):
ref_input = ref_pts.unsqueeze(2) # [bs, Lq, 1, 4]
query_pos = query_pos_head(ref_pts).clamp(-10, 10)
output = layer(output, ref_input, value, spatial_shapes, query_pos=query_pos)
if i == 0:
ref_unact = ref_pts.clamp(1e-5, 1 - 1e-5)
ref_unact = torch.log(ref_unact / (1 - ref_unact))
pre_bboxes = F.sigmoid(pre_bbox_head(output) + ref_unact)
ref_pts_initial = pre_bboxes.detach()
pred_corners = bbox_head[i](output + output_detach) + pred_corners_undetach
inter_ref_bbox = distance2bbox(ref_pts_initial, integral(pred_corners, self.project), self.reg_scale)
if i == self.eval_idx:
scores = score_head[i](output)
scores = self.lqe_layers[i](scores, pred_corners)
dec_bboxes.append(inter_ref_bbox)
dec_logits.append(scores)
break
pred_corners_undetach = pred_corners
ref_pts = inter_ref_bbox.detach()
output_detach = output.detach()
return torch.stack(dec_bboxes), torch.stack(dec_logits)
class DFINETransformer(nn.Module):
def __init__(self, num_classes=80, hidden_dim=256, num_queries=300, feat_channels=[256, 256, 256], feat_strides=[8, 16, 32],
num_levels=3, num_points=[3, 6, 3], nhead=8, num_layers=6, dim_feedforward=1024, eval_idx=-1, eps=1e-2, reg_max=32,
reg_scale=8.0, eval_spatial_size=(640, 640), device=None, dtype=None, operations=None):
super().__init__()
assert len(feat_strides) == len(feat_channels)
self.hidden_dim = hidden_dim
self.num_queries = num_queries
self.num_levels = num_levels
self.eps = eps
self.eval_spatial_size = eval_spatial_size
self.feat_strides = list(feat_strides)
for i in range(num_levels - len(feat_strides)):
self.feat_strides.append(feat_strides[-1] * 2 ** (i + 1))
# input projection (expects pre-fused weights)
self.input_proj = nn.ModuleList()
for ch in feat_channels:
if ch == hidden_dim:
self.input_proj.append(nn.Identity())
else:
self.input_proj.append(nn.Sequential(OrderedDict([
('conv', operations.Conv2d(ch, hidden_dim, 1, bias=True, device=device, dtype=dtype))])))
in_ch = feat_channels[-1]
for i in range(num_levels - len(feat_channels)):
self.input_proj.append(nn.Sequential(OrderedDict([
('conv', operations.Conv2d(in_ch if i == 0 else hidden_dim,
hidden_dim, 3, 2, 1, bias=True, device=device, dtype=dtype))])))
in_ch = hidden_dim
# FDR parameters (non-trainable placeholders, set from config)
self.up = nn.Parameter(torch.tensor([0.5]), requires_grad=False)
self.reg_scale = nn.Parameter(torch.tensor([reg_scale]), requires_grad=False)
pts = num_points if isinstance(num_points, (list, tuple)) else [num_points] * num_levels
self.decoder = TransformerDecoder(hidden_dim, nhead, dim_feedforward, num_levels, pts,
num_layers, reg_max, self.reg_scale, self.up, eval_idx, device=device, dtype=dtype, operations=operations)
self.query_pos_head = MLP(4, 2 * hidden_dim, hidden_dim, 2, device=device, dtype=dtype, operations=operations)
self.enc_output = nn.Sequential(OrderedDict([
('proj', operations.Linear(hidden_dim, hidden_dim, device=device, dtype=dtype)),
('norm', operations.LayerNorm(hidden_dim, device=device, dtype=dtype))]))
self.enc_score_head = operations.Linear(hidden_dim, num_classes, device=device, dtype=dtype)
self.enc_bbox_head = MLP(hidden_dim, hidden_dim, 4, 3, device=device, dtype=dtype, operations=operations)
self.eval_idx_ = eval_idx if eval_idx >= 0 else num_layers + eval_idx
self.dec_score_head = nn.ModuleList(
[operations.Linear(hidden_dim, num_classes, device=device, dtype=dtype) for _ in range(self.eval_idx_ + 1)])
self.pre_bbox_head = MLP(hidden_dim, hidden_dim, 4, 3, device=device, dtype=dtype, operations=operations)
self.dec_bbox_head = nn.ModuleList(
[MLP(hidden_dim, hidden_dim, 4 * (reg_max + 1), 3, device=device, dtype=dtype, operations=operations)
for _ in range(self.eval_idx_ + 1)])
self.integral = Integral(reg_max)
if eval_spatial_size:
# Register as buffers so checkpoint values override the freshly-computed defaults
anchors, valid_mask = self._gen_anchors()
self.register_buffer('anchors', anchors)
self.register_buffer('valid_mask', valid_mask)
def _gen_anchors(self, spatial_shapes=None, grid_size=0.05, dtype=torch.float32, device='cpu'):
if spatial_shapes is None:
h0, w0 = self.eval_spatial_size
spatial_shapes = [[int(h0 / s), int(w0 / s)] for s in self.feat_strides]
anchors = []
for lvl, (h, w) in enumerate(spatial_shapes):
gy, gx = torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij')
gxy = (torch.stack([gx, gy], -1).float() + 0.5) / torch.tensor([w, h], dtype=dtype)
wh = torch.ones_like(gxy) * grid_size * (2. ** lvl)
anchors.append(torch.cat([gxy, wh], -1).reshape(-1, h * w, 4))
anchors = torch.cat(anchors, 1).to(device)
valid_mask = ((anchors > self.eps) & (anchors < 1 - self.eps)).all(-1, keepdim=True)
anchors = torch.log(anchors / (1 - anchors))
anchors = torch.where(valid_mask, anchors, torch.full_like(anchors, float('inf')))
return anchors, valid_mask
def _encoder_input(self, feats: List[torch.Tensor]):
proj = [self.input_proj[i](f) for i, f in enumerate(feats)]
for i in range(len(feats), self.num_levels):
proj.append(self.input_proj[i](feats[-1] if i == len(feats) else proj[-1]))
flat, shapes = [], []
for f in proj:
_, _, h, w = f.shape
flat.append(f.flatten(2).permute(0, 2, 1))
shapes.append([h, w])
return torch.cat(flat, 1), shapes
def _decoder_input(self, memory: torch.Tensor):
anchors, valid_mask = self.anchors.to(memory), self.valid_mask
if memory.shape[0] > 1:
anchors = anchors.repeat(memory.shape[0], 1, 1)
mem = valid_mask.to(memory) * memory
out_mem = self.enc_output(mem)
logits = self.enc_score_head(out_mem)
_, idx = torch.topk(logits.max(-1).values, self.num_queries, dim=-1)
idx_e = idx.unsqueeze(-1)
topk_mem = out_mem.gather(1, idx_e.expand(-1, -1, out_mem.shape[-1]))
topk_anc = anchors.gather(1, idx_e.expand(-1, -1, anchors.shape[-1]))
topk_ref = self.enc_bbox_head(topk_mem) + topk_anc
return topk_mem.detach(), topk_ref.detach()
def forward(self, feats: List[torch.Tensor]):
memory, shapes = self._encoder_input(feats)
content, ref = self._decoder_input(memory)
out_bboxes, out_logits = self.decoder(
content, ref, memory, shapes,
self.dec_bbox_head, self.dec_score_head,
self.query_pos_head, self.pre_bbox_head, self.integral)
return {'pred_logits': out_logits[-1], 'pred_boxes': out_bboxes[-1]}
# ---------------------------------------------------------------------------
# Main model
# ---------------------------------------------------------------------------
class RTv4(nn.Module):
def __init__(self, num_classes=80, num_queries=300, enc_h=256, dec_h=256, enc_ff=2048, dec_ff=1024, feat_strides=[8, 16, 32], device=None, dtype=None, operations=None, **kwargs):
super().__init__()
self.device = device
self.dtype = dtype
self.operations = operations
self.backbone = HGNetv2(device=device, dtype=dtype, operations=operations)
self.encoder = HybridEncoder(hidden_dim=enc_h, dim_feedforward=enc_ff, device=device, dtype=dtype, operations=operations)
self.decoder = DFINETransformer(num_classes=num_classes, hidden_dim=dec_h, num_queries=num_queries,
feat_channels=[enc_h] * len(feat_strides), feat_strides=feat_strides, dim_feedforward=dec_ff, device=device, dtype=dtype, operations=operations)
self.num_classes = num_classes
self.num_queries = num_queries
self.load_device = comfy.model_management.get_torch_device()
def _forward(self, x: torch.Tensor):
return self.decoder(self.encoder(self.backbone(x)))
def postprocess(self, outputs, orig_size: tuple = (640, 640)) -> List[dict]:
logits = outputs['pred_logits']
boxes = torchvision.ops.box_convert(outputs['pred_boxes'], 'cxcywh', 'xyxy')
boxes = boxes * torch.tensor(orig_size, device=boxes.device, dtype=boxes.dtype).repeat(1, 2).unsqueeze(1)
scores = F.sigmoid(logits)
scores, idx = torch.topk(scores.flatten(1), self.num_queries, dim=-1)
labels = idx % self.num_classes
boxes = boxes.gather(1, (idx // self.num_classes).unsqueeze(-1).expand(-1, -1, 4))
return [{'labels': lbl, 'boxes': b, 'scores': s} for lbl, b, s in zip(labels, boxes, scores)]
def forward(self, x: torch.Tensor, orig_size: tuple = (640, 640), **kwargs):
outputs = self._forward(x.to(device=self.load_device, dtype=self.dtype))
return self.postprocess(outputs, orig_size)

View File

@@ -141,3 +141,17 @@ def interpret_gathered_like(tensors, gathered):
return dest_views
aimdo_enabled = False
extra_ram_release_callback = None
RAM_CACHE_HEADROOM = 0
def set_ram_cache_release_state(callback, headroom):
global extra_ram_release_callback
global RAM_CACHE_HEADROOM
extra_ram_release_callback = callback
RAM_CACHE_HEADROOM = max(0, int(headroom))
def extra_ram_release(target):
if extra_ram_release_callback is None:
return 0
return extra_ram_release_callback(target)

View File

@@ -52,6 +52,7 @@ import comfy.ldm.qwen_image.model
import comfy.ldm.kandinsky5.model
import comfy.ldm.anima.model
import comfy.ldm.ace.ace_step15
import comfy.ldm.rt_detr.rtdetr_v4
import comfy.model_management
import comfy.patcher_extension
@@ -890,7 +891,7 @@ class Flux(BaseModel):
return torch.cat((image, mask), dim=1)
def encode_adm(self, **kwargs):
return kwargs["pooled_output"]
return kwargs.get("pooled_output", None)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
@@ -937,9 +938,10 @@ class LongCatImage(Flux):
transformer_options = transformer_options.copy()
rope_opts = transformer_options.get("rope_options", {})
rope_opts = dict(rope_opts)
pe_len = float(c_crossattn.shape[1]) if c_crossattn is not None else 512.0
rope_opts.setdefault("shift_t", 1.0)
rope_opts.setdefault("shift_y", 512.0)
rope_opts.setdefault("shift_x", 512.0)
rope_opts.setdefault("shift_y", pe_len)
rope_opts.setdefault("shift_x", pe_len)
transformer_options["rope_options"] = rope_opts
return super()._apply_model(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs)
@@ -1060,6 +1062,10 @@ class LTXAV(BaseModel):
if guide_attention_entries is not None:
out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries)
ref_audio = kwargs.get("ref_audio", None)
if ref_audio is not None:
out['ref_audio'] = comfy.conds.CONDConstant(ref_audio)
return out
def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs):
@@ -1952,3 +1958,7 @@ class Kandinsky5Image(Kandinsky5):
def concat_cond(self, **kwargs):
return None
class RT_DETR_v4(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.rt_detr.rtdetr_v4.RTv4)

View File

@@ -696,6 +696,21 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
if '{}encoder.lyric_encoder.layers.0.input_layernorm.weight'.format(key_prefix) in state_dict_keys:
dit_config = {}
dit_config["audio_model"] = "ace1.5"
head_dim = 128
dit_config["hidden_size"] = state_dict['{}decoder.layers.0.self_attn_norm.weight'.format(key_prefix)].shape[0]
dit_config["intermediate_size"] = state_dict['{}decoder.layers.0.mlp.gate_proj.weight'.format(key_prefix)].shape[0]
dit_config["num_heads"] = state_dict['{}decoder.layers.0.self_attn.q_proj.weight'.format(key_prefix)].shape[0] // head_dim
dit_config["encoder_hidden_size"] = state_dict['{}encoder.lyric_encoder.layers.0.input_layernorm.weight'.format(key_prefix)].shape[0]
dit_config["encoder_num_heads"] = state_dict['{}encoder.lyric_encoder.layers.0.self_attn.q_proj.weight'.format(key_prefix)].shape[0] // head_dim
dit_config["encoder_intermediate_size"] = state_dict['{}encoder.lyric_encoder.layers.0.mlp.gate_proj.weight'.format(key_prefix)].shape[0]
dit_config["num_dit_layers"] = count_blocks(state_dict_keys, '{}decoder.layers.'.format(key_prefix) + '{}.')
return dit_config
if '{}encoder.pan_blocks.1.cv4.conv.weight'.format(key_prefix) in state_dict_keys: # RT-DETR_v4
dit_config = {}
dit_config["image_model"] = "RT_DETR_v4"
dit_config["enc_h"] = state_dict['{}encoder.pan_blocks.1.cv4.conv.weight'.format(key_prefix)].shape[0]
return dit_config
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:

View File

@@ -15,6 +15,7 @@
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
from __future__ import annotations
import psutil
import logging
@@ -32,6 +33,11 @@ import comfy.memory_management
import comfy.utils
import comfy.quant_ops
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
class VRAMState(Enum):
DISABLED = 0 #No vram present: no need to move models to vram
NO_VRAM = 1 #Very low vram: enable all the options to save vram
@@ -55,6 +61,7 @@ total_vram = 0
# Training Related State
in_training = False
training_fp8_bwd = False
def get_supported_float8_types():
@@ -205,6 +212,25 @@ def get_torch_device():
else:
return torch.device(torch.cuda.current_device())
def get_all_torch_devices(exclude_current=False):
global cpu_state
devices = []
if cpu_state == CPUState.GPU:
if is_nvidia():
for i in range(torch.cuda.device_count()):
devices.append(torch.device(i))
elif is_intel_xpu():
for i in range(torch.xpu.device_count()):
devices.append(torch.device(i))
elif is_ascend_npu():
for i in range(torch.npu.device_count()):
devices.append(torch.device(i))
else:
devices.append(get_torch_device())
if exclude_current:
devices.remove(get_torch_device())
return devices
def get_total_memory(dev=None, torch_total_too=False):
global directml_enabled
if dev is None:
@@ -493,9 +519,13 @@ try:
logging.info("Device: {}".format(get_torch_device_name(get_torch_device())))
except:
logging.warning("Could not pick default device.")
try:
for device in get_all_torch_devices(exclude_current=True):
logging.info("Device: {}".format(get_torch_device_name(device)))
except:
pass
current_loaded_models = []
current_loaded_models: list[LoadedModel] = []
def module_size(module):
module_mem = 0
@@ -528,7 +558,7 @@ def module_mmap_residency(module, free=False):
return mmap_touched_mem, module_mem
class LoadedModel:
def __init__(self, model):
def __init__(self, model: ModelPatcher):
self._set_model(model)
self.device = model.load_device
self.real_model = None
@@ -536,7 +566,7 @@ class LoadedModel:
self.model_finalizer = None
self._patcher_finalizer = None
def _set_model(self, model):
def _set_model(self, model: ModelPatcher):
self._model = weakref.ref(model)
if model.parent is not None:
self._parent_model = weakref.ref(model.parent)
@@ -547,6 +577,7 @@ class LoadedModel:
model = self._parent_model()
if model is not None:
self._set_model(model)
self.device = model.load_device
@property
def model(self):
@@ -668,7 +699,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
for i in range(len(current_loaded_models) -1, -1, -1):
shift_model = current_loaded_models[i]
if shift_model.device == device:
if device is None or shift_model.device == device:
if shift_model not in keep_loaded and not shift_model.is_dead():
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
shift_model.currently_used = False
@@ -678,8 +709,8 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
i = x[-1]
memory_to_free = 1e32
pins_to_free = 1e32
if not DISABLE_SMART_MEMORY:
memory_to_free = memory_required - get_free_memory(device)
if not DISABLE_SMART_MEMORY or device is None:
memory_to_free = 0 if device is None else memory_required - get_free_memory(device)
pins_to_free = pins_required - get_free_ram()
if current_loaded_models[i].model.is_dynamic() and for_dynamic:
#don't actually unload dynamic models for the sake of other dynamic models
@@ -707,7 +738,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
if len(unloaded_model) > 0:
soft_empty_cache()
else:
elif device is not None:
if vram_state != VRAMState.HIGH_VRAM:
mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True)
if mem_free_torch > mem_free_total * 0.25:
@@ -1325,9 +1356,9 @@ MAX_PINNED_MEMORY = -1
if not args.disable_pinned_memory:
if is_nvidia() or is_amd():
if WINDOWS:
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.45 # Windows limit is apparently 50%
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.40 # Windows limit is apparently 50%
else:
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.90
logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024)))
PINNING_ALLOWED_TYPES = set(["Tensor", "Parameter", "QuantizedTensor"])
@@ -1402,8 +1433,6 @@ def unpin_memory(tensor):
if torch.cuda.cudart().cudaHostUnregister(ptr) == 0:
TOTAL_PINNED_MEMORY -= PINNED_MEMORY.pop(ptr)
if len(PINNED_MEMORY) == 0:
TOTAL_PINNED_MEMORY = 0
return True
else:
logging.warning("Unpin error.")
@@ -1780,7 +1809,34 @@ def soft_empty_cache(force=False):
torch.cuda.ipc_collect()
def unload_all_models():
free_memory(1e30, get_torch_device())
for device in get_all_torch_devices():
free_memory(1e30, device)
def unload_model_and_clones(model: ModelPatcher, unload_additional_models=True, all_devices=False):
'Unload only model and its clones - primarily for multigpu cloning purposes.'
initial_keep_loaded: list[LoadedModel] = current_loaded_models.copy()
additional_models = []
if unload_additional_models:
additional_models = model.get_nested_additional_models()
keep_loaded = []
for loaded_model in initial_keep_loaded:
if loaded_model.model is not None:
if model.clone_base_uuid == loaded_model.model.clone_base_uuid:
continue
# check additional models if they are a match
skip = False
for add_model in additional_models:
if add_model.clone_base_uuid == loaded_model.model.clone_base_uuid:
skip = True
break
if skip:
continue
keep_loaded.append(loaded_model)
if not all_devices:
free_memory(1e30, get_torch_device(), keep_loaded)
else:
for device in get_all_torch_devices():
free_memory(1e30, device, keep_loaded)
def debug_memory_summary():
if is_amd() or is_nvidia():

View File

@@ -23,6 +23,7 @@ import inspect
import logging
import math
import uuid
import copy
from typing import Callable, Optional
import torch
@@ -75,12 +76,15 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_
def create_model_options_clone(orig_model_options: dict):
return comfy.patcher_extension.copy_nested_dicts(orig_model_options)
def create_hook_patches_clone(orig_hook_patches):
def create_hook_patches_clone(orig_hook_patches, copy_tuples=False):
new_hook_patches = {}
for hook_ref in orig_hook_patches:
new_hook_patches[hook_ref] = {}
for k in orig_hook_patches[hook_ref]:
new_hook_patches[hook_ref][k] = orig_hook_patches[hook_ref][k][:]
if copy_tuples:
for i in range(len(new_hook_patches[hook_ref][k])):
new_hook_patches[hook_ref][k][i] = tuple(new_hook_patches[hook_ref][k][i])
return new_hook_patches
def wipe_lowvram_weight(m):
@@ -272,7 +276,10 @@ class ModelPatcher:
self.is_clip = False
self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed
self.cached_patcher_init: tuple[Callable, tuple] | None = None
self.cached_patcher_init: tuple[Callable, tuple] | tuple[Callable, tuple, int] | None = None
self.is_multigpu_base_clone = False
self.clone_base_uuid = uuid.uuid4()
if not hasattr(self.model, 'model_loaded_weight_memory'):
self.model.model_loaded_weight_memory = 0
@@ -300,9 +307,6 @@ class ModelPatcher:
def model_mmap_residency(self, free=False):
return comfy.model_management.module_mmap_residency(self.model, free=free)
def get_ram_usage(self):
return self.model_size()
def loaded_size(self):
return self.model.model_loaded_weight_memory
@@ -329,6 +333,8 @@ class ModelPatcher:
if self.cached_patcher_init is None:
raise RuntimeError("Cannot create non-dynamic delegate: cached_patcher_init is not initialized.")
temp_model_patcher = self.cached_patcher_init[0](*self.cached_patcher_init[1], disable_dynamic=True)
if len(self.cached_patcher_init) > 2:
temp_model_patcher = temp_model_patcher[self.cached_patcher_init[2]]
model_override = temp_model_patcher.get_clone_model_override()
if model_override is None:
model_override = self.get_clone_model_override()
@@ -387,19 +393,98 @@ class ModelPatcher:
n.hook_mode = self.hook_mode
n.cached_patcher_init = self.cached_patcher_init
n.is_multigpu_base_clone = self.is_multigpu_base_clone
n.clone_base_uuid = self.clone_base_uuid
for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE):
callback(self, n)
return n
def deepclone_multigpu(self, new_load_device=None, models_cache: dict[uuid.UUID,ModelPatcher]=None):
logging.info(f"Creating deepclone of {self.model.__class__.__name__} for {new_load_device if new_load_device else self.load_device}.")
comfy.model_management.unload_model_and_clones(self)
n = self.clone()
# set load device, if present
if new_load_device is not None:
n.load_device = new_load_device
if self.cached_patcher_init is not None:
temp_model_patcher: ModelPatcher | list[ModelPatcher] = self.cached_patcher_init[0](*self.cached_patcher_init[1])
if len(self.cached_patcher_init) > 2:
temp_model_patcher = temp_model_patcher[self.cached_patcher_init[2]]
n.model = temp_model_patcher.model
else:
n.model = copy.deepcopy(n.model)
# unlike for normal clone, backup dicts that shared same ref should not;
# otherwise, patchers that have deep copies of base models will erroneously influence each other.
n.backup = copy.deepcopy(n.backup)
n.object_patches_backup = copy.deepcopy(n.object_patches_backup)
n.hook_backup = copy.deepcopy(n.hook_backup)
# multigpu clone should not have multigpu additional_models entry
n.remove_additional_models("multigpu")
# multigpu_clone all stored additional_models; make sure circular references are properly handled
if models_cache is None:
models_cache = {}
for key, model_list in n.additional_models.items():
for i in range(len(model_list)):
add_model = n.additional_models[key][i]
if add_model.clone_base_uuid not in models_cache:
models_cache[add_model.clone_base_uuid] = add_model.deepclone_multigpu(new_load_device=new_load_device, models_cache=models_cache)
n.additional_models[key][i] = models_cache[add_model.clone_base_uuid]
for callback in self.get_all_callbacks(CallbacksMP.ON_DEEPCLONE_MULTIGPU):
callback(self, n)
return n
def match_multigpu_clones(self):
multigpu_models = self.get_additional_models_with_key("multigpu")
if len(multigpu_models) > 0:
new_multigpu_models = []
for mm in multigpu_models:
# clone main model, but bring over relevant props from existing multigpu clone
n = self.clone()
n.load_device = mm.load_device
n.backup = mm.backup
n.object_patches_backup = mm.object_patches_backup
n.hook_backup = mm.hook_backup
n.model = mm.model
n.is_multigpu_base_clone = mm.is_multigpu_base_clone
n.remove_additional_models("multigpu")
orig_additional_models: dict[str, list[ModelPatcher]] = comfy.patcher_extension.copy_nested_dicts(n.additional_models)
n.additional_models = comfy.patcher_extension.copy_nested_dicts(mm.additional_models)
# figure out which additional models are not present in multigpu clone
models_cache = {}
for mm_add_model in mm.get_additional_models():
models_cache[mm_add_model.clone_base_uuid] = mm_add_model
remove_models_uuids = set(list(models_cache.keys()))
for key, model_list in orig_additional_models.items():
for orig_add_model in model_list:
if orig_add_model.clone_base_uuid not in models_cache:
models_cache[orig_add_model.clone_base_uuid] = orig_add_model.deepclone_multigpu(new_load_device=n.load_device, models_cache=models_cache)
existing_list = n.get_additional_models_with_key(key)
existing_list.append(models_cache[orig_add_model.clone_base_uuid])
n.set_additional_models(key, existing_list)
if orig_add_model.clone_base_uuid in remove_models_uuids:
remove_models_uuids.remove(orig_add_model.clone_base_uuid)
# remove duplicate additional models
for key, model_list in n.additional_models.items():
new_model_list = [x for x in model_list if x.clone_base_uuid not in remove_models_uuids]
n.set_additional_models(key, new_model_list)
for callback in self.get_all_callbacks(CallbacksMP.ON_MATCH_MULTIGPU_CLONES):
callback(self, n)
new_multigpu_models.append(n)
self.set_additional_models("multigpu", new_multigpu_models)
def is_clone(self, other):
if hasattr(other, 'model') and self.model is other.model:
return True
return False
def clone_has_same_weights(self, clone: 'ModelPatcher'):
if not self.is_clone(clone):
return False
def clone_has_same_weights(self, clone: ModelPatcher, allow_multigpu=False):
if allow_multigpu:
if self.clone_base_uuid != clone.clone_base_uuid:
return False
else:
if not self.is_clone(clone):
return False
if self.current_hooks != clone.current_hooks:
return False
@@ -1170,7 +1255,7 @@ class ModelPatcher:
return self.additional_models.get(key, [])
def get_additional_models(self):
all_models = []
all_models: list[ModelPatcher] = []
for models in self.additional_models.values():
all_models.extend(models)
return all_models
@@ -1224,9 +1309,13 @@ class ModelPatcher:
for callback in self.get_all_callbacks(CallbacksMP.ON_PRE_RUN):
callback(self)
def prepare_state(self, timestep):
def prepare_state(self, timestep, model_options, ignore_multigpu=False):
for callback in self.get_all_callbacks(CallbacksMP.ON_PREPARE_STATE):
callback(self, timestep)
callback(self, timestep, model_options, ignore_multigpu)
if not ignore_multigpu and "multigpu_clones" in model_options:
for p in model_options["multigpu_clones"].values():
p: ModelPatcher
p.prepare_state(timestep, model_options, ignore_multigpu=True)
def restore_hook_patches(self):
if self.hook_patches_backup is not None:
@@ -1239,12 +1328,18 @@ class ModelPatcher:
def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup, model_options: dict[str]):
curr_t = t[0]
reset_current_hooks = False
multigpu_kf_changed_cache = None
transformer_options = model_options.get("transformer_options", {})
for hook in hook_group.hooks:
changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t, transformer_options=transformer_options)
# if keyframe changed, remove any cached HookGroups that contain hook with the same hook_ref;
# this will cause the weights to be recalculated when sampling
if changed:
# cache changed for multigpu usage
if "multigpu_clones" in model_options:
if multigpu_kf_changed_cache is None:
multigpu_kf_changed_cache = []
multigpu_kf_changed_cache.append(hook)
# reset current_hooks if contains hook that changed
if self.current_hooks is not None:
for current_hook in self.current_hooks.hooks:
@@ -1256,6 +1351,28 @@ class ModelPatcher:
self.cached_hook_patches.pop(cached_group)
if reset_current_hooks:
self.patch_hooks(None)
if "multigpu_clones" in model_options:
for p in model_options["multigpu_clones"].values():
p: ModelPatcher
p._handle_changed_hook_keyframes(multigpu_kf_changed_cache)
def _handle_changed_hook_keyframes(self, kf_changed_cache: list[comfy.hooks.Hook]):
'Used to handle multigpu behavior inside prepare_hook_patches_current_keyframe.'
if kf_changed_cache is None:
return
reset_current_hooks = False
# reset current_hooks if contains hook that changed
for hook in kf_changed_cache:
if self.current_hooks is not None:
for current_hook in self.current_hooks.hooks:
if current_hook == hook:
reset_current_hooks = True
break
for cached_group in list(self.cached_hook_patches.keys()):
if cached_group.contains(hook):
self.cached_hook_patches.pop(cached_group)
if reset_current_hooks:
self.patch_hooks(None)
def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None,
registered: comfy.hooks.HookGroup = None):

230
comfy/multigpu.py Normal file
View File

@@ -0,0 +1,230 @@
from __future__ import annotations
import queue
import threading
import torch
import logging
from collections import namedtuple
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
import comfy.utils
import comfy.patcher_extension
import comfy.model_management
class MultiGPUThreadPool:
"""Persistent thread pool for multi-GPU work distribution.
Maintains one worker thread per extra GPU device. Each thread calls
torch.cuda.set_device() once at startup so that compiled kernel caches
(inductor/triton) stay warm across diffusion steps.
"""
def __init__(self, devices: list[torch.device]):
self._workers: list[threading.Thread] = []
self._work_queues: dict[torch.device, queue.Queue] = {}
self._result_queues: dict[torch.device, queue.Queue] = {}
for device in devices:
wq = queue.Queue()
rq = queue.Queue()
self._work_queues[device] = wq
self._result_queues[device] = rq
t = threading.Thread(target=self._worker_loop, args=(device, wq, rq), daemon=True)
t.start()
self._workers.append(t)
def _worker_loop(self, device: torch.device, work_q: queue.Queue, result_q: queue.Queue):
try:
torch.cuda.set_device(device)
except Exception as e:
logging.error(f"MultiGPUThreadPool: failed to set device {device}: {e}")
while True:
item = work_q.get()
if item is None:
return
result_q.put((None, e))
return
while True:
item = work_q.get()
if item is None:
break
fn, args, kwargs = item
try:
result = fn(*args, **kwargs)
result_q.put((result, None))
except Exception as e:
result_q.put((None, e))
def submit(self, device: torch.device, fn, *args, **kwargs):
self._work_queues[device].put((fn, args, kwargs))
def get_result(self, device: torch.device):
return self._result_queues[device].get()
@property
def devices(self) -> list[torch.device]:
return list(self._work_queues.keys())
def shutdown(self):
for wq in self._work_queues.values():
wq.put(None) # sentinel
for t in self._workers:
t.join(timeout=5.0)
class GPUOptions:
def __init__(self, device_index: int, relative_speed: float):
self.device_index = device_index
self.relative_speed = relative_speed
def clone(self):
return GPUOptions(self.device_index, self.relative_speed)
def create_dict(self):
return {
"relative_speed": self.relative_speed
}
class GPUOptionsGroup:
def __init__(self):
self.options: dict[int, GPUOptions] = {}
def add(self, info: GPUOptions):
self.options[info.device_index] = info
def clone(self):
c = GPUOptionsGroup()
for opt in self.options.values():
c.add(opt)
return c
def register(self, model: ModelPatcher):
opts_dict = {}
# get devices that are valid for this model
devices: list[torch.device] = [model.load_device]
for extra_model in model.get_additional_models_with_key("multigpu"):
extra_model: ModelPatcher
devices.append(extra_model.load_device)
# create dictionary with actual device mapped to its GPUOptions
device_opts_list: list[GPUOptions] = []
for device in devices:
device_opts = self.options.get(device.index, GPUOptions(device_index=device.index, relative_speed=1.0))
opts_dict[device] = device_opts.create_dict()
device_opts_list.append(device_opts)
# make relative_speed relative to 1.0
min_speed = min([x.relative_speed for x in device_opts_list])
for value in opts_dict.values():
value['relative_speed'] /= min_speed
model.model_options['multigpu_options'] = opts_dict
def create_multigpu_deepclones(model: ModelPatcher, max_gpus: int, gpu_options: GPUOptionsGroup=None, reuse_loaded=False):
'Prepare ModelPatcher to contain deepclones of its BaseModel and related properties.'
model = model.clone()
# check if multigpu is already prepared - get the load devices from them if possible to exclude
skip_devices = set()
multigpu_models = model.get_additional_models_with_key("multigpu")
if len(multigpu_models) > 0:
for mm in multigpu_models:
skip_devices.add(mm.load_device)
skip_devices = list(skip_devices)
full_extra_devices = comfy.model_management.get_all_torch_devices(exclude_current=True)
limit_extra_devices = full_extra_devices[:max_gpus-1]
extra_devices = limit_extra_devices.copy()
# exclude skipped devices
for skip in skip_devices:
if skip in extra_devices:
extra_devices.remove(skip)
# create new deepclones
if len(extra_devices) > 0:
for device in extra_devices:
device_patcher = None
if reuse_loaded:
# check if there are any ModelPatchers currently loaded that could be referenced here after a clone
loaded_models: list[ModelPatcher] = comfy.model_management.loaded_models()
for lm in loaded_models:
if lm.model is not None and lm.clone_base_uuid == model.clone_base_uuid and lm.load_device == device:
device_patcher = lm.clone()
logging.info(f"Reusing loaded deepclone of {device_patcher.model.__class__.__name__} for {device}")
break
if device_patcher is None:
device_patcher = model.deepclone_multigpu(new_load_device=device)
device_patcher.is_multigpu_base_clone = True
multigpu_models = model.get_additional_models_with_key("multigpu")
multigpu_models.append(device_patcher)
model.set_additional_models("multigpu", multigpu_models)
model.match_multigpu_clones()
if gpu_options is None:
gpu_options = GPUOptionsGroup()
gpu_options.register(model)
else:
logging.info("No extra torch devices need initialization, skipping initializing MultiGPU Work Units.")
# TODO: only keep model clones that don't go 'past' the intended max_gpu count
# multigpu_models = model.get_additional_models_with_key("multigpu")
# new_multigpu_models = []
# for m in multigpu_models:
# if m.load_device in limit_extra_devices:
# new_multigpu_models.append(m)
# model.set_additional_models("multigpu", new_multigpu_models)
# persist skip_devices for use in sampling code
# if len(skip_devices) > 0 or "multigpu_skip_devices" in model.model_options:
# model.model_options["multigpu_skip_devices"] = skip_devices
return model
LoadBalance = namedtuple('LoadBalance', ['work_per_device', 'idle_time'])
def load_balance_devices(model_options: dict[str], total_work: int, return_idle_time=False, work_normalized: int=None):
'Optimize work assigned to different devices, accounting for their relative speeds and splittable work.'
opts_dict = model_options['multigpu_options']
devices = list(model_options['multigpu_clones'].keys())
speed_per_device = []
work_per_device = []
# get sum of each device's relative_speed
total_speed = 0.0
for opts in opts_dict.values():
total_speed += opts['relative_speed']
# get relative work for each device;
# obtained by w = (W*r)/R
for device in devices:
relative_speed = opts_dict[device]['relative_speed']
relative_work = (total_work*relative_speed) / total_speed
speed_per_device.append(relative_speed)
work_per_device.append(relative_work)
# relative work must be expressed in whole numbers, but likely is a decimal;
# perform rounding while maintaining total sum equal to total work (sum of relative works)
work_per_device = round_preserved(work_per_device)
dict_work_per_device = {}
for device, relative_work in zip(devices, work_per_device):
dict_work_per_device[device] = relative_work
if not return_idle_time:
return LoadBalance(dict_work_per_device, None)
# divide relative work by relative speed to get estimated completion time of said work by each device;
# time here is relative and does not correspond to real-world units
completion_time = [w/r for w,r in zip(work_per_device, speed_per_device)]
# calculate relative time spent by the devices waiting on each other after their work is completed
idle_time = abs(min(completion_time) - max(completion_time))
# if need to compare work idle time, need to normalize to a common total work
if work_normalized:
idle_time *= (work_normalized/total_work)
return LoadBalance(dict_work_per_device, idle_time)
def round_preserved(values: list[float]):
'Round all values in a list, preserving the combined sum of values.'
# get floor of values; casting to int does it too
floored = [int(x) for x in values]
total_floored = sum(floored)
# get remainder to distribute
remainder = round(sum(values)) - total_floored
# pair values with fractional portions
fractional = [(i, x-floored[i]) for i, x in enumerate(values)]
# sort by fractional part in descending order
fractional.sort(key=lambda x: x[1], reverse=True)
# distribute the remainder
for i in range(remainder):
index = fractional[i][0]
floored[index] += 1
return floored

View File

@@ -777,8 +777,16 @@ from .quant_ops import (
class QuantLinearFunc(torch.autograd.Function):
"""Custom autograd function for quantized linear: quantized forward, compute_dtype backward.
Handles any input rank by flattening to 2D for matmul and restoring shape after.
"""Custom autograd function for quantized linear: quantized forward, optionally FP8 backward.
When training_fp8_bwd is enabled:
- Forward: quantize input per layout (FP8/NVFP4), use quantized matmul
- Backward: all matmuls use FP8 tensor cores via torch.mm dispatch
- Cached input is FP8 (half the memory of bf16)
When training_fp8_bwd is disabled:
- Forward: quantize input per layout, use quantized matmul
- Backward: dequantize weight to compute_dtype, use standard matmul
"""
@staticmethod
@@ -786,7 +794,7 @@ class QuantLinearFunc(torch.autograd.Function):
input_shape = input_float.shape
inp = input_float.detach().flatten(0, -2) # zero-cost view to 2D
# Quantize input (same as inference path)
# Quantize input for forward (same layout as weight)
if layout_type is not None:
q_input = QuantizedTensor.from_float(inp, layout_type, scale=input_scale)
else:
@@ -797,43 +805,68 @@ class QuantLinearFunc(torch.autograd.Function):
output = torch.nn.functional.linear(q_input, w, b)
# Restore original input shape
# Unflatten output to match original input shape
if len(input_shape) > 2:
output = output.unflatten(0, input_shape[:-1])
ctx.save_for_backward(input_float, weight)
# Save for backward
ctx.input_shape = input_shape
ctx.has_bias = bias is not None
ctx.compute_dtype = compute_dtype
ctx.weight_requires_grad = weight.requires_grad
ctx.fp8_bwd = comfy.model_management.training_fp8_bwd
if ctx.fp8_bwd:
# Cache FP8 quantized input — half the memory of bf16
if isinstance(q_input, QuantizedTensor) and layout_type.startswith('TensorCoreFP8'):
ctx.q_input = q_input # already FP8, reuse
else:
# NVFP4 or other layout — quantize input to FP8 for backward
ctx.q_input = QuantizedTensor.from_float(inp, "TensorCoreFP8E4M3Layout")
ctx.save_for_backward(weight)
else:
ctx.q_input = None
ctx.save_for_backward(input_float, weight)
return output
@staticmethod
@torch.autograd.function.once_differentiable
def backward(ctx, grad_output):
input_float, weight = ctx.saved_tensors
compute_dtype = ctx.compute_dtype
grad_2d = grad_output.flatten(0, -2).to(compute_dtype)
# Dequantize weight to compute dtype for backward matmul
if isinstance(weight, QuantizedTensor):
weight_f = weight.dequantize().to(compute_dtype)
# Value casting — only difference between fp8 and non-fp8 paths
if ctx.fp8_bwd:
weight, = ctx.saved_tensors
# Wrap as FP8 QuantizedTensors → torch.mm dispatches to _scaled_mm
grad_mm = QuantizedTensor.from_float(grad_2d, "TensorCoreFP8E5M2Layout")
if isinstance(weight, QuantizedTensor) and weight._layout_cls.startswith("TensorCoreFP8"):
weight_mm = weight
elif isinstance(weight, QuantizedTensor):
weight_mm = QuantizedTensor.from_float(weight.dequantize().to(compute_dtype), "TensorCoreFP8E4M3Layout")
else:
weight_mm = QuantizedTensor.from_float(weight.to(compute_dtype), "TensorCoreFP8E4M3Layout")
input_mm = ctx.q_input
else:
weight_f = weight.to(compute_dtype)
input_float, weight = ctx.saved_tensors
# Standard tensors → torch.mm does regular matmul
grad_mm = grad_2d
if isinstance(weight, QuantizedTensor):
weight_mm = weight.dequantize().to(compute_dtype)
else:
weight_mm = weight.to(compute_dtype)
input_mm = input_float.flatten(0, -2).to(compute_dtype) if ctx.weight_requires_grad else None
# grad_input = grad_output @ weight
grad_input = torch.mm(grad_2d, weight_f)
# Computation — same for both paths, dispatch handles the rest
grad_input = torch.mm(grad_mm, weight_mm)
if len(ctx.input_shape) > 2:
grad_input = grad_input.unflatten(0, ctx.input_shape[:-1])
# grad_weight (only if weight requires grad, typically frozen for quantized training)
grad_weight = None
if ctx.weight_requires_grad:
input_f = input_float.flatten(0, -2).to(compute_dtype)
grad_weight = torch.mm(grad_2d.t(), input_f)
grad_weight = torch.mm(grad_mm.t(), input_mm)
# grad_bias
grad_bias = None
if ctx.has_bias:
grad_bias = grad_2d.sum(dim=0)
@@ -895,6 +928,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
weight = state_dict.pop(weight_key, None)
if weight is None:
logging.warning(f"Missing weight for layer {layer_name}")
self.weight = None
return
manually_loaded_keys = [weight_key]
@@ -1001,6 +1035,9 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
if self.bias is not None:
sd["{}bias".format(prefix)] = self.bias
if self.weight is None:
return sd
if isinstance(self.weight, QuantizedTensor):
sd_out = self.weight.state_dict("{}weight".format(prefix))
for k in sd_out:

View File

@@ -3,6 +3,8 @@ from typing import Callable
class CallbacksMP:
ON_CLONE = "on_clone"
ON_DEEPCLONE_MULTIGPU = "on_deepclone_multigpu"
ON_MATCH_MULTIGPU_CLONES = "on_match_multigpu_clones"
ON_LOAD = "on_load_after"
ON_DETACH = "on_detach_after"
ON_CLEANUP = "on_cleanup"

View File

@@ -2,6 +2,7 @@ import comfy.model_management
import comfy.memory_management
import comfy_aimdo.host_buffer
import comfy_aimdo.torch
import psutil
from comfy.cli_args import args
@@ -12,6 +13,11 @@ def pin_memory(module):
if module.pin_failed or args.disable_pinned_memory or get_pin(module) is not None:
return
#FIXME: This is a RAM cache trigger event
ram_headroom = comfy.memory_management.RAM_CACHE_HEADROOM
#we split the difference and assume half the RAM cache headroom is for us
if ram_headroom > 0 and psutil.virtual_memory().available < (ram_headroom * 0.5):
comfy.memory_management.extra_ram_release(ram_headroom)
size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
if comfy.model_management.MAX_PINNED_MEMORY <= 0 or (comfy.model_management.TOTAL_PINNED_MEMORY + size) > comfy.model_management.MAX_PINNED_MEMORY:

View File

@@ -20,7 +20,6 @@ try:
if cuda_version < (13,):
ck.registry.disable("cuda")
logging.warning("WARNING: You need pytorch with cu130 or higher to use optimized CUDA operations.")
ck.registry.disable("triton")
for k, v in ck.list_backends().items():
logging.info(f"Found comfy_kitchen backend {k}: {v}")

View File

@@ -1,16 +1,18 @@
from __future__ import annotations
import torch
import uuid
import math
import collections
import comfy.model_management
import comfy.conds
import comfy.model_patcher
import comfy.utils
import comfy.hooks
import comfy.patcher_extension
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
from comfy.model_base import BaseModel
from comfy.model_patcher import ModelPatcher
from comfy.controlnet import ControlBase
def prepare_mask(noise_mask, shape, device):
@@ -118,6 +120,47 @@ def cleanup_additional_models(models):
if hasattr(m, 'cleanup'):
m.cleanup()
def preprocess_multigpu_conds(conds: dict[str, list[dict[str]]], model: ModelPatcher, model_options: dict[str]):
'''If multigpu acceleration required, creates deepclones of ControlNets and GLIGEN per device.'''
multigpu_models: list[ModelPatcher] = model.get_additional_models_with_key("multigpu")
if len(multigpu_models) == 0:
return
extra_devices = [x.load_device for x in multigpu_models]
# handle controlnets
controlnets: set[ControlBase] = set()
for k in conds:
for kk in conds[k]:
if 'control' in kk:
controlnets.add(kk['control'])
if len(controlnets) > 0:
# first, unload all controlnet clones
for cnet in list(controlnets):
cnet_models = cnet.get_models()
for cm in cnet_models:
comfy.model_management.unload_model_and_clones(cm, unload_additional_models=True)
# next, make sure each controlnet has a deepclone for all relevant devices
for cnet in controlnets:
curr_cnet = cnet
while curr_cnet is not None:
for device in extra_devices:
if device not in curr_cnet.multigpu_clones:
curr_cnet.deepclone_multigpu(device, autoregister=True)
curr_cnet = curr_cnet.previous_controlnet
# since all device clones are now present, recreate the linked list for cloned cnets per device
for cnet in controlnets:
curr_cnet = cnet
while curr_cnet is not None:
prev_cnet = curr_cnet.previous_controlnet
for device in extra_devices:
device_cnet = curr_cnet.get_instance_for_device(device)
prev_device_cnet = None
if prev_cnet is not None:
prev_device_cnet = prev_cnet.get_instance_for_device(device)
device_cnet.set_previous_controlnet(prev_device_cnet)
curr_cnet = prev_cnet
# potentially handle gligen - since not widely used, ignored for now
def estimate_memory(model, noise_shape, conds):
cond_shapes = collections.defaultdict(list)
cond_shapes_min = {}
@@ -142,7 +185,8 @@ def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None
return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load, force_offload=force_offload)
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False, force_offload=False):
real_model: BaseModel = None
model.match_multigpu_clones()
preprocess_multigpu_conds(conds, model, model_options)
models, inference_memory = get_additional_models(conds, model.model_dtype())
models += get_additional_models_from_model_options(model_options)
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
@@ -154,7 +198,7 @@ def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=Non
memory_required += inference_memory
minimum_memory_required += inference_memory
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required, force_full_load=force_full_load)
real_model = model.model
real_model: BaseModel = model.model
return real_model, conds, models
@@ -200,3 +244,18 @@ def prepare_model_patcher(model: ModelPatcher, conds, model_options: dict):
comfy.patcher_extension.merge_nested_dicts(to_load_options.setdefault(wc_name, {}), model_options["transformer_options"][wc_name],
copy_dict1=False)
return to_load_options
def prepare_model_patcher_multigpu_clones(model_patcher: ModelPatcher, loaded_models: list[ModelPatcher], model_options: dict):
'''
In case multigpu acceleration is enabled, prep ModelPatchers for each device.
'''
multigpu_patchers: list[ModelPatcher] = [x for x in loaded_models if x.is_multigpu_base_clone]
if len(multigpu_patchers) > 0:
multigpu_dict: dict[torch.device, ModelPatcher] = {}
multigpu_dict[model_patcher.load_device] = model_patcher
for x in multigpu_patchers:
x.hook_patches = comfy.model_patcher.create_hook_patches_clone(model_patcher.hook_patches, copy_tuples=True)
x.hook_mode = model_patcher.hook_mode # match main model's hook_mode
multigpu_dict[x.load_device] = x
model_options["multigpu_clones"] = multigpu_dict
return multigpu_patchers

View File

@@ -1,7 +1,9 @@
from __future__ import annotations
import comfy.model_management
from .k_diffusion import sampling as k_diffusion_sampling
from .extra_samplers import uni_pc
from typing import TYPE_CHECKING, Callable, NamedTuple
from typing import TYPE_CHECKING, Callable, NamedTuple, Any
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
from comfy.model_base import BaseModel
@@ -16,6 +18,7 @@ import comfy.model_patcher
import comfy.patcher_extension
import comfy.hooks
import comfy.context_windows
import comfy.multigpu
import comfy.utils
import scipy.stats
import numpy
@@ -141,7 +144,7 @@ def can_concat_cond(c1, c2):
return cond_equal_size(c1.conditioning, c2.conditioning)
def cond_cat(c_list):
def cond_cat(c_list, device=None):
temp = {}
for x in c_list:
for k in x:
@@ -153,6 +156,8 @@ def cond_cat(c_list):
for k in temp:
conds = temp[k]
out[k] = conds[0].concat(conds[1:])
if device is not None and hasattr(out[k], 'to'):
out[k] = out[k].to(device)
return out
@@ -212,7 +217,9 @@ def _calc_cond_batch_outer(model: BaseModel, conds: list[list[dict]], x_in: torc
)
return executor.execute(model, conds, x_in, timestep, model_options)
def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
if 'multigpu_clones' in model_options:
return _calc_cond_batch_multigpu(model, conds, x_in, timestep, model_options)
out_conds = []
out_counts = []
# separate conds by matching hooks
@@ -244,7 +251,7 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
if has_default_conds:
finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options)
model.current_patcher.prepare_state(timestep)
model.current_patcher.prepare_state(timestep, model_options)
# run every hooked_to_run separately
for hooks, to_run in hooked_to_run.items():
@@ -345,6 +352,212 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
return out_conds
def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
out_conds = []
out_counts = []
# separate conds by matching hooks
hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]] = {}
default_conds = []
has_default_conds = False
output_device = x_in.device
for i in range(len(conds)):
out_conds.append(torch.zeros_like(x_in))
out_counts.append(torch.ones_like(x_in) * 1e-37)
cond = conds[i]
default_c = []
if cond is not None:
for x in cond:
if 'default' in x:
default_c.append(x)
has_default_conds = True
continue
p = get_area_and_mult(x, x_in, timestep)
if p is None:
continue
if p.hooks is not None:
model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks, model_options)
hooked_to_run.setdefault(p.hooks, list())
hooked_to_run[p.hooks] += [(p, i)]
default_conds.append(default_c)
if has_default_conds:
finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options)
model.current_patcher.prepare_state(timestep, model_options)
devices = [dev_m for dev_m in model_options['multigpu_clones'].keys()]
device_batched_hooked_to_run: dict[torch.device, list[tuple[comfy.hooks.HookGroup, tuple]]] = {}
total_conds = 0
for to_run in hooked_to_run.values():
total_conds += len(to_run)
conds_per_device = max(1, math.ceil(total_conds//len(devices)))
index_device = 0
current_device = devices[index_device]
# run every hooked_to_run separately
for hooks, to_run in hooked_to_run.items():
while len(to_run) > 0:
current_device = devices[index_device % len(devices)]
batched_to_run = device_batched_hooked_to_run.setdefault(current_device, [])
# keep track of conds currently scheduled onto this device
batched_to_run_length = 0
for btr in batched_to_run:
batched_to_run_length += len(btr[1])
first = to_run[0]
first_shape = first[0][0].shape
to_batch_temp = []
# make sure not over conds_per_device limit when creating temp batch
for x in range(len(to_run)):
if can_concat_cond(to_run[x][0], first[0]) and len(to_batch_temp) < (conds_per_device - batched_to_run_length):
to_batch_temp += [x]
to_batch_temp.reverse()
to_batch = to_batch_temp[:1]
free_memory = comfy.model_management.get_free_memory(current_device)
for i in range(1, len(to_batch_temp) + 1):
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
if model.memory_required(input_shape) * 1.5 < free_memory:
to_batch = batch_amount
break
conds_to_batch = []
for x in to_batch:
conds_to_batch.append(to_run.pop(x))
batched_to_run_length += len(conds_to_batch)
batched_to_run.append((hooks, conds_to_batch))
if batched_to_run_length >= conds_per_device:
index_device += 1
class thread_result(NamedTuple):
output: Any
mult: Any
area: Any
batch_chunks: int
cond_or_uncond: Any
error: Exception = None
def _handle_batch(device: torch.device, batch_tuple: tuple[comfy.hooks.HookGroup, tuple], results: list[thread_result]):
try:
torch.cuda.set_device(device)
model_current: BaseModel = model_options["multigpu_clones"][device].model
# run every hooked_to_run separately
with torch.no_grad():
for hooks, to_batch in batch_tuple:
input_x = []
mult = []
c = []
cond_or_uncond = []
uuids = []
area = []
control: ControlBase = None
patches = None
for x in to_batch:
o = x
p = o[0]
input_x.append(p.input_x)
mult.append(p.mult)
c.append(p.conditioning)
area.append(p.area)
cond_or_uncond.append(o[1])
uuids.append(p.uuid)
control = p.control
patches = p.patches
batch_chunks = len(cond_or_uncond)
input_x = torch.cat(input_x).to(device)
c = cond_cat(c, device=device)
timestep_ = torch.cat([timestep.to(device)] * batch_chunks)
transformer_options = model_current.current_patcher.apply_hooks(hooks=hooks)
if 'transformer_options' in model_options:
transformer_options = comfy.patcher_extension.merge_nested_dicts(transformer_options,
model_options['transformer_options'],
copy_dict1=False)
if patches is not None:
transformer_options["patches"] = comfy.patcher_extension.merge_nested_dicts(
transformer_options.get("patches", {}),
patches
)
transformer_options["cond_or_uncond"] = cond_or_uncond[:]
transformer_options["uuids"] = uuids[:]
transformer_options["sigmas"] = timestep.to(device)
transformer_options["sample_sigmas"] = transformer_options["sample_sigmas"].to(device)
transformer_options["multigpu_thread_device"] = device
cast_transformer_options(transformer_options, device=device)
c['transformer_options'] = transformer_options
if control is not None:
device_control = control.get_instance_for_device(device)
c['control'] = device_control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options)
if 'model_function_wrapper' in model_options:
output = model_options['model_function_wrapper'](model_current.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).to(output_device).chunk(batch_chunks)
else:
output = model_current.apply_model(input_x, timestep_, **c).to(output_device).chunk(batch_chunks)
results.append(thread_result(output, mult, area, batch_chunks, cond_or_uncond))
except Exception as e:
results.append(thread_result(None, None, None, None, None, error=e))
raise
def _handle_batch_pooled(device, batch_tuple):
worker_results = []
_handle_batch(device, batch_tuple, worker_results)
return worker_results
results: list[thread_result] = []
thread_pool: comfy.multigpu.MultiGPUThreadPool = model_options.get("multigpu_thread_pool")
# Submit all GPU work to pool threads
pool_devices = []
for device, batch_tuple in device_batched_hooked_to_run.items():
if thread_pool is not None:
thread_pool.submit(device, _handle_batch_pooled, device, batch_tuple)
pool_devices.append(device)
else:
# Fallback: no pool, run everything on main thread
_handle_batch(device, batch_tuple, results)
# Collect results from pool workers
for device in pool_devices:
worker_results, error = thread_pool.get_result(device)
if error is not None:
raise error
results.extend(worker_results)
for output, mult, area, batch_chunks, cond_or_uncond, error in results:
if error is not None:
raise error
for o in range(batch_chunks):
cond_index = cond_or_uncond[o]
a = area[o]
if a is None:
out_conds[cond_index] += output[o] * mult[o]
out_counts[cond_index] += mult[o]
else:
out_c = out_conds[cond_index]
out_cts = out_counts[cond_index]
dims = len(a) // 2
for i in range(dims):
out_c = out_c.narrow(i + 2, a[i + dims], a[i])
out_cts = out_cts.narrow(i + 2, a[i + dims], a[i])
out_c += output[o] * mult[o]
out_cts += mult[o]
for i in range(len(out_conds)):
out_conds[i] /= out_counts[i]
return out_conds
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): #TODO: remove
logging.warning("WARNING: The comfy.samplers.calc_cond_uncond_batch function is deprecated please use the calc_cond_batch one instead.")
return tuple(calc_cond_batch(model, [cond, uncond], x_in, timestep, model_options))
@@ -649,6 +862,8 @@ def pre_run_control(model, conds):
percent_to_timestep_function = lambda a: s.percent_to_sigma(a)
if 'control' in x:
x['control'].pre_run(model, percent_to_timestep_function)
for device_cnet in x['control'].multigpu_clones.values():
device_cnet.pre_run(model, percent_to_timestep_function)
def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
cond_cnets = []
@@ -891,7 +1106,9 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
to_load_options = model_options.get("to_load_options", None)
if to_load_options is None:
return
cast_transformer_options(to_load_options, device, dtype)
def cast_transformer_options(transformer_options: dict[str], device=None, dtype=None):
casts = []
if device is not None:
casts.append(device)
@@ -900,18 +1117,17 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
# if nothing to apply, do nothing
if len(casts) == 0:
return
# try to call .to on patches
if "patches" in to_load_options:
patches = to_load_options["patches"]
if "patches" in transformer_options:
patches = transformer_options["patches"]
for name in patches:
patch_list = patches[name]
for i in range(len(patch_list)):
if hasattr(patch_list[i], "to"):
for cast in casts:
patch_list[i] = patch_list[i].to(cast)
if "patches_replace" in to_load_options:
patches = to_load_options["patches_replace"]
if "patches_replace" in transformer_options:
patches = transformer_options["patches_replace"]
for name in patches:
patch_list = patches[name]
for k in patch_list:
@@ -921,8 +1137,8 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
# try to call .to on any wrappers/callbacks
wrappers_and_callbacks = ["wrappers", "callbacks"]
for wc_name in wrappers_and_callbacks:
if wc_name in to_load_options:
wc: dict[str, list] = to_load_options[wc_name]
if wc_name in transformer_options:
wc: dict[str, list] = transformer_options[wc_name]
for wc_dict in wc.values():
for wc_list in wc_dict.values():
for i in range(len(wc_list)):
@@ -930,7 +1146,6 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
for cast in casts:
wc_list[i] = wc_list[i].to(cast)
class CFGGuider:
def __init__(self, model_patcher: ModelPatcher):
self.model_patcher = model_patcher
@@ -985,16 +1200,31 @@ class CFGGuider:
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
device = self.model_patcher.load_device
noise = noise.to(device=device, dtype=torch.float32)
latent_image = latent_image.to(device=device, dtype=torch.float32)
sigmas = sigmas.to(device)
cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())
multigpu_patchers = comfy.sampler_helpers.prepare_model_patcher_multigpu_clones(self.model_patcher, self.loaded_models, self.model_options)
# Create persistent thread pool for all GPU devices (main + extras)
if multigpu_patchers:
extra_devices = [p.load_device for p in multigpu_patchers]
all_devices = [device] + extra_devices
self.model_options["multigpu_thread_pool"] = comfy.multigpu.MultiGPUThreadPool(all_devices)
try:
noise = noise.to(device=device, dtype=torch.float32)
latent_image = latent_image.to(device=device, dtype=torch.float32)
sigmas = sigmas.to(device)
cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())
self.model_patcher.pre_run()
for multigpu_patcher in multigpu_patchers:
multigpu_patcher.pre_run()
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
finally:
thread_pool = self.model_options.pop("multigpu_thread_pool", None)
if thread_pool is not None:
thread_pool.shutdown()
self.model_patcher.cleanup()
for multigpu_patcher in multigpu_patchers:
multigpu_patcher.cleanup()
comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models)
del self.inner_model

View File

@@ -61,6 +61,7 @@ import comfy.text_encoders.newbie
import comfy.text_encoders.anima
import comfy.text_encoders.ace15
import comfy.text_encoders.longcat_image
import comfy.text_encoders.qwen35
import comfy.model_patcher
import comfy.lora
@@ -279,9 +280,6 @@ class CLIP:
n.apply_hooks_to_conds = self.apply_hooks_to_conds
return n
def get_ram_usage(self):
return self.patcher.get_ram_usage()
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
return self.patcher.add_patches(patches, strength_patch, strength_model)
@@ -425,13 +423,13 @@ class CLIP:
def get_key_patches(self):
return self.patcher.get_key_patches()
def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None):
def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None, presence_penalty=0.0):
self.cond_stage_model.reset_clip_options()
self.load_model(tokens)
self.cond_stage_model.set_clip_options({"layer": None})
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed)
return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed, presence_penalty=presence_penalty)
def decode(self, token_ids, skip_special_tokens=True):
return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
@@ -558,12 +556,19 @@ class VAE:
old_memory_used_decode = self.memory_used_decode
self.memory_used_decode = lambda shape, dtype: old_memory_used_decode(shape, dtype) * 4.0
decoder_ch = sd['decoder.conv_in.weight'].shape[0] // ddconfig['ch_mult'][-1]
if decoder_ch != ddconfig['ch']:
decoder_ddconfig = ddconfig.copy()
decoder_ddconfig['ch'] = decoder_ch
else:
decoder_ddconfig = None
if 'post_quant_conv.weight' in sd:
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1], **({"decoder_ddconfig": decoder_ddconfig} if decoder_ddconfig is not None else {}))
else:
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig},
decoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Decoder", 'params': ddconfig})
decoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Decoder", 'params': decoder_ddconfig if decoder_ddconfig is not None else ddconfig})
elif "decoder.layers.1.layers.0.beta" in sd:
config = {}
param_key = None
@@ -839,9 +844,6 @@ class VAE:
self.size = comfy.model_management.module_size(self.first_stage_model)
return self.size
def get_ram_usage(self):
return self.model_size()
def throw_exception_if_invalid(self):
if self.first_stage_model is None:
raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.")
@@ -1228,6 +1230,11 @@ class TEModel(Enum):
QWEN3_8B = 20
QWEN3_06B = 21
GEMMA_3_4B_VISION = 22
QWEN35_08B = 23
QWEN35_2B = 24
QWEN35_4B = 25
QWEN35_9B = 26
QWEN35_27B = 27
def detect_te_model(sd):
@@ -1267,6 +1274,17 @@ def detect_te_model(sd):
return TEModel.QWEN25_3B
if weight.shape[0] == 512:
return TEModel.QWEN25_7B
if "model.language_model.layers.0.linear_attn.A_log" in sd and "model.language_model.layers.0.input_layernorm.weight" in sd:
weight = sd['model.language_model.layers.0.input_layernorm.weight']
if weight.shape[0] == 1024:
return TEModel.QWEN35_08B
if weight.shape[0] == 2560:
return TEModel.QWEN35_4B
if weight.shape[0] == 4096:
return TEModel.QWEN35_9B
if weight.shape[0] == 5120:
return TEModel.QWEN35_27B
return TEModel.QWEN35_2B
if "model.layers.0.post_attention_layernorm.weight" in sd:
weight = sd['model.layers.0.post_attention_layernorm.weight']
if 'model.layers.0.self_attn.q_norm.weight' in sd:
@@ -1299,11 +1317,12 @@ def t5xxl_detect(clip_data):
return {}
def llama_detect(clip_data):
weight_name = "model.layers.0.self_attn.k_proj.weight"
weight_names = ["model.layers.0.self_attn.k_proj.weight", "model.layers.0.linear_attn.in_proj_a.weight"]
for sd in clip_data:
if weight_name in sd:
return comfy.text_encoders.hunyuan_video.llama_detect(sd)
for weight_name in weight_names:
if weight_name in sd:
return comfy.text_encoders.hunyuan_video.llama_detect(sd)
return {}
@@ -1431,6 +1450,11 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
elif te_model == TEModel.JINA_CLIP_2:
clip_target.clip = comfy.text_encoders.jina_clip_2.JinaClip2TextModelWrapper
clip_target.tokenizer = comfy.text_encoders.jina_clip_2.JinaClip2TokenizerWrapper
elif te_model in (TEModel.QWEN35_08B, TEModel.QWEN35_2B, TEModel.QWEN35_4B, TEModel.QWEN35_9B, TEModel.QWEN35_27B):
clip_data[0] = comfy.utils.state_dict_prefix_replace(clip_data[0], {"model.language_model.": "model.", "model.visual.": "visual.", "lm_head.": "model.lm_head."})
qwen35_type = {TEModel.QWEN35_08B: "qwen35_08b", TEModel.QWEN35_2B: "qwen35_2b", TEModel.QWEN35_4B: "qwen35_4b", TEModel.QWEN35_9B: "qwen35_9b", TEModel.QWEN35_27B: "qwen35_27b"}[te_model]
clip_target.clip = comfy.text_encoders.qwen35.te(**llama_detect(clip_data), model_type=qwen35_type)
clip_target.tokenizer = comfy.text_encoders.qwen35.tokenizer(model_type=qwen35_type)
elif te_model == TEModel.QWEN3_06B:
clip_target.clip = comfy.text_encoders.anima.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.anima.AnimaTokenizer
@@ -1572,10 +1596,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic)
if out is None:
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd)))
if output_model and out[0] is not None:
out[0].cached_patcher_init = (load_checkpoint_guess_config_model_only, (ckpt_path, embedding_directory, model_options, te_model_options))
if output_clip and out[1] is not None:
out[1].patcher.cached_patcher_init = (load_checkpoint_guess_config_clip_only, (ckpt_path, embedding_directory, model_options, te_model_options))
out[0].cached_patcher_init = (load_checkpoint_guess_config, (ckpt_path, False, False, False, embedding_directory, output_model, model_options, te_model_options), 0)
return out
def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
@@ -1719,15 +1740,18 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable
"""
dtype = model_options.get("dtype", None)
custom_operations = model_options.get("custom_operations", None)
if custom_operations is None:
sd, metadata = comfy.utils.convert_old_quants(sd, "", metadata=metadata)
#Allow loading unets from checkpoint files
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
temp_sd = comfy.utils.state_dict_prefix_replace(sd, {diffusion_model_prefix: ""}, filter_keys=True)
if len(temp_sd) > 0:
sd = temp_sd
if custom_operations is None:
sd, metadata = comfy.utils.convert_old_quants(sd, "", metadata=metadata)
custom_operations = model_options.get("custom_operations", None)
if custom_operations is None:
sd, metadata = comfy.utils.convert_old_quants(sd, "", metadata=metadata)
parameters = comfy.utils.calculate_parameters(sd)
weight_dtype = comfy.utils.weight_dtype(sd)

View File

@@ -308,14 +308,14 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
def load_sd(self, sd):
return self.transformer.load_state_dict(sd, strict=False, assign=getattr(self, "can_assign_sd", False))
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty=0.0):
if isinstance(tokens, dict):
tokens_only = next(iter(tokens.values())) # todo: get this better?
else:
tokens_only = tokens
tokens_only = [[t[0] for t in b] for b in tokens_only]
embeds = self.process_tokens(tokens_only, device=self.execution_device)[0]
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed)
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty=presence_penalty)
def parse_parentheses(string):
result = []
@@ -740,5 +740,5 @@ class SD1ClipModel(torch.nn.Module):
def load_sd(self, sd):
return getattr(self, self.clip).load_sd(sd)
def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None):
return getattr(self, self.clip).generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed)
def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None, presence_penalty=0.0):
return getattr(self, self.clip).generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed, presence_penalty=presence_penalty)

View File

@@ -1734,6 +1734,21 @@ class LongCatImage(supported_models_base.BASE):
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.longcat_image.LongCatImageTokenizer, comfy.text_encoders.longcat_image.te(**hunyuan_detect))
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
class RT_DETR_v4(supported_models_base.BASE):
unet_config = {
"image_model": "RT_DETR_v4",
}
supported_inference_dtypes = [torch.float16, torch.float32]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.RT_DETR_v4(self, device=device)
return out
def clip_target(self, state_dict={}):
return None
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4]
models += [SVD_img2vid]

View File

@@ -224,7 +224,7 @@ class Qwen3_8BConfig:
k_norm = "gemma3"
rope_scale = None
final_norm: bool = True
lm_head: bool = False
lm_head: bool = True
stop_tokens = [151643, 151645]
@dataclass
@@ -655,6 +655,17 @@ class Llama2_(nn.Module):
if config.lm_head:
self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
def get_past_len(self, past_key_values):
return past_key_values[0][2]
def compute_freqs_cis(self, position_ids, device):
return precompute_freqs_cis(self.config.head_dim,
position_ids,
self.config.rope_theta,
self.config.rope_scale,
self.config.rope_dims,
device=device)
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None):
if embeds is not None:
x = embeds
@@ -667,17 +678,12 @@ class Llama2_(nn.Module):
seq_len = x.shape[1]
past_len = 0
if past_key_values is not None and len(past_key_values) > 0:
past_len = past_key_values[0][2]
past_len = self.get_past_len(past_key_values)
if position_ids is None:
position_ids = torch.arange(past_len, past_len + seq_len, device=x.device).unsqueeze(0)
freqs_cis = precompute_freqs_cis(self.config.head_dim,
position_ids,
self.config.rope_theta,
self.config.rope_scale,
self.config.rope_dims,
device=x.device)
freqs_cis = self.compute_freqs_cis(position_ids, x.device)
mask = None
if attention_mask is not None:
@@ -812,9 +818,16 @@ class BaseGenerate:
comfy.ops.uncast_bias_weight(module, weight, None, offload_stream)
return x
def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0):
device = embeds.device
def init_kv_cache(self, batch, max_cache_len, device, execution_dtype):
model_config = self.model.config
past_key_values = []
for x in range(model_config.num_hidden_layers):
past_key_values.append((torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype),
torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0))
return past_key_values
def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0):
device = embeds.device
if stop_tokens is None:
stop_tokens = self.model.config.stop_tokens
@@ -829,11 +842,8 @@ class BaseGenerate:
if embeds.ndim == 2:
embeds = embeds.unsqueeze(0)
past_key_values = [] #kv_cache init
max_cache_len = embeds.shape[1] + max_length
for x in range(model_config.num_hidden_layers):
past_key_values.append((torch.empty([embeds.shape[0], model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype),
torch.empty([embeds.shape[0], model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0))
past_key_values = self.init_kv_cache(embeds.shape[0], max_cache_len, device, execution_dtype)
generator = torch.Generator(device=device).manual_seed(seed) if do_sample else None
@@ -844,7 +854,7 @@ class BaseGenerate:
for step in tqdm(range(max_length), desc="Generating tokens"):
x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values)
logits = self.logits(x)[:, -1]
next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample)
next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample, presence_penalty=presence_penalty)
token_id = next_token[0].item()
generated_token_ids.append(token_id)
@@ -856,7 +866,7 @@ class BaseGenerate:
return generated_token_ids
def sample_token(self, logits, temperature, top_k, top_p, min_p, repetition_penalty, token_history, generator, do_sample=True):
def sample_token(self, logits, temperature, top_k, top_p, min_p, repetition_penalty, token_history, generator, do_sample=True, presence_penalty=0.0):
if not do_sample or temperature == 0.0:
return torch.argmax(logits, dim=-1, keepdim=True)
@@ -867,6 +877,11 @@ class BaseGenerate:
for token_id in set(token_history):
logits[i, token_id] *= repetition_penalty if logits[i, token_id] < 0 else 1/repetition_penalty
if presence_penalty is not None and presence_penalty != 0.0:
for i in range(logits.shape[0]):
for token_id in set(token_history):
logits[i, token_id] -= presence_penalty
if temperature != 1.0:
logits = logits / temperature
@@ -897,6 +912,9 @@ class BaseGenerate:
class BaseQwen3:
def logits(self, x):
input = x[:, -1:]
if self.model.config.lm_head:
return self.model.lm_head(input)
module = self.model.embed_tokens
offload_stream = None
@@ -1028,12 +1046,19 @@ class Qwen25_7BVLI(BaseLlama, BaseGenerate, torch.nn.Module):
grid = e.get("extra", None)
start = e.get("index")
if position_ids is None:
position_ids = torch.zeros((3, embeds.shape[1]), device=embeds.device)
position_ids = torch.ones((3, embeds.shape[1]), device=embeds.device, dtype=torch.long)
position_ids[:, :start] = torch.arange(0, start, device=embeds.device)
end = e.get("size") + start
len_max = int(grid.max()) // 2
start_next = len_max + start
position_ids[:, end:] = torch.arange(start_next + offset, start_next + (embeds.shape[1] - end) + offset, device=embeds.device)
if attention_mask is not None:
# Assign compact sequential positions to attended tokens only,
# skipping over padding so post-padding tokens aren't inflated.
after_mask = attention_mask[0, end:]
text_positions = after_mask.cumsum(0) - 1 + start_next + offset
position_ids[:, end:] = torch.where(after_mask.bool(), text_positions, position_ids[0, end:])
else:
position_ids[:, end:] = torch.arange(start_next + offset, start_next + (embeds.shape[1] - end) + offset, device=embeds.device)
position_ids[0, start:end] = start + offset
max_d = int(grid[0][1]) // 2
position_ids[1, start:end] = torch.arange(start + offset, start + max_d + offset, device=embeds.device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start]

View File

@@ -64,7 +64,13 @@ class LongCatImageBaseTokenizer(Qwen25_7BVLITokenizer):
return [output]
IMAGE_PAD_TOKEN_ID = 151655
class LongCatImageTokenizer(sd1_clip.SD1Tokenizer):
T2I_PREFIX = "<|im_start|>system\nAs an image captioning expert, generate a descriptive text prompt based on an image content, suitable for input to a text-to-image model.<|im_end|>\n<|im_start|>user\n"
EDIT_PREFIX = "<|im_start|>system\nAs an image editing expert, first analyze the content and attributes of the input image(s). Then, based on the user's editing instructions, clearly and precisely determine how to modify the given image(s), ensuring that only the specified parts are altered and all other aspects remain consistent with the original(s).<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
SUFFIX = "<|im_end|>\n<|im_start|>assistant\n"
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(
embedding_directory=embedding_directory,
@@ -72,10 +78,8 @@ class LongCatImageTokenizer(sd1_clip.SD1Tokenizer):
name="qwen25_7b",
tokenizer=LongCatImageBaseTokenizer,
)
self.longcat_template_prefix = "<|im_start|>system\nAs an image captioning expert, generate a descriptive text prompt based on an image content, suitable for input to a text-to-image model.<|im_end|>\n<|im_start|>user\n"
self.longcat_template_suffix = "<|im_end|>\n<|im_start|>assistant\n"
def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
def tokenize_with_weights(self, text, return_word_ids=False, images=None, **kwargs):
skip_template = False
if text.startswith("<|im_start|>"):
skip_template = True
@@ -90,11 +94,14 @@ class LongCatImageTokenizer(sd1_clip.SD1Tokenizer):
text, return_word_ids=return_word_ids, disable_weights=True, **kwargs
)
else:
has_images = images is not None and len(images) > 0
template_prefix = self.EDIT_PREFIX if has_images else self.T2I_PREFIX
prefix_ids = base_tok.tokenizer(
self.longcat_template_prefix, add_special_tokens=False
template_prefix, add_special_tokens=False
)["input_ids"]
suffix_ids = base_tok.tokenizer(
self.longcat_template_suffix, add_special_tokens=False
self.SUFFIX, add_special_tokens=False
)["input_ids"]
prompt_tokens = base_tok.tokenize_with_weights(
@@ -106,6 +113,14 @@ class LongCatImageTokenizer(sd1_clip.SD1Tokenizer):
suffix_pairs = [(t, 1.0) for t in suffix_ids]
combined = prefix_pairs + prompt_pairs + suffix_pairs
if has_images:
embed_count = 0
for i in range(len(combined)):
if combined[i][0] == IMAGE_PAD_TOKEN_ID and embed_count < len(images):
combined[i] = ({"type": "image", "data": images[embed_count], "original_type": "image"}, combined[i][1])
embed_count += 1
tokens = {"qwen25_7b": [combined]}
return tokens

View File

@@ -91,11 +91,11 @@ class Gemma3_12BModel(sd1_clip.SDClipModel):
self.dtypes.add(dtype)
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_12B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty):
tokens_only = [[t[0] for t in b] for b in tokens]
embeds, _, _, embeds_info = self.process_tokens(tokens_only, self.execution_device)
comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106]) # 106 is <end_of_turn>
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106], presence_penalty=presence_penalty) # 106 is <end_of_turn>
class DualLinearProjection(torch.nn.Module):
def __init__(self, in_dim, out_dim_video, out_dim_audio, dtype=None, device=None, operations=None):
@@ -189,8 +189,8 @@ class LTXAVTEModel(torch.nn.Module):
return out.to(device=out_device, dtype=torch.float), pooled, extra
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
return self.gemma3_12b.generate(tokens["gemma3_12b"], do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed)
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty):
return self.gemma3_12b.generate(tokens["gemma3_12b"], do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty)
def load_sd(self, sd):
if "model.layers.47.self_attn.q_norm.weight" in sd:

View File

@@ -0,0 +1,833 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass, field
import os
import math
import comfy.model_management
from comfy.ldm.modules.attention import optimized_attention_for_device
from comfy import sd1_clip
import comfy.text_encoders.qwen_vl
from .llama import BaseLlama, BaseGenerate, Llama2_, MLP, RMSNorm, apply_rope
def _qwen35_layer_types(n):
return [("full_attention" if (i + 1) % 4 == 0 else "linear_attention") for i in range(n)]
@dataclass
class Qwen35Config:
vocab_size: int = 248320
hidden_size: int = 2048
intermediate_size: int = 6144
num_hidden_layers: int = 24
# Full attention params
num_attention_heads: int = 8
num_key_value_heads: int = 2
head_dim: int = 256
partial_rotary_factor: float = 0.25
# Linear attention (DeltaNet) params
linear_num_key_heads: int = 16
linear_num_value_heads: int = 16
linear_key_head_dim: int = 128
linear_value_head_dim: int = 128
conv_kernel_size: int = 4
# Shared params
max_position_embeddings: int = 32768
rms_norm_eps: float = 1e-6
rope_theta: float = 10000000.0
mrope_section: list = field(default_factory=lambda: [11, 11, 10])
layer_types: list = field(default_factory=lambda: _qwen35_layer_types(24))
rms_norm_add: bool = True
mlp_activation: str = "silu"
qkv_bias: bool = False
final_norm: bool = True
lm_head: bool = False
stop_tokens: list = field(default_factory=lambda: [248044, 248046])
# These are needed for BaseLlama/BaseGenerate compatibility but unused directly
transformer_type: str = "qwen35_2b"
rope_dims: list = None
rope_scale: float = None
QWEN35_VISION_DEFAULTS = dict(hidden_size=1024, num_heads=16, intermediate_size=4096, depth=24, patch_size=16, temporal_patch_size=2, in_channels=3, spatial_merge_size=2, num_position_embeddings=2304)
QWEN35_MODELS = {
"qwen35_08b": dict(hidden_size=1024, intermediate_size=3584, vision=dict(hidden_size=768, num_heads=12, intermediate_size=3072, depth=12)),
"qwen35_2b": dict(hidden_size=2048, intermediate_size=6144, num_hidden_layers=24, num_attention_heads=8, num_key_value_heads=2, linear_num_value_heads=16),
"qwen35_4b": dict(hidden_size=2560, intermediate_size=9216, num_hidden_layers=32, num_attention_heads=16, num_key_value_heads=4, linear_num_value_heads=32),
"qwen35_9b": dict(hidden_size=4096, intermediate_size=12288, num_hidden_layers=32, num_attention_heads=16, num_key_value_heads=4, linear_num_value_heads=32, lm_head=True, vision=dict(hidden_size=1152, intermediate_size=4304, depth=27)),
"qwen35_27b": dict(hidden_size=5120, intermediate_size=17408, num_hidden_layers=64, num_attention_heads=24, num_key_value_heads=4, linear_num_value_heads=48, lm_head=True, vision=dict(hidden_size=1152, intermediate_size=4304, depth=27)),
}
def _make_config(model_type, config_dict={}):
overrides = QWEN35_MODELS.get(model_type, {}).copy()
overrides.pop("vision", None)
if "num_hidden_layers" in overrides:
overrides["layer_types"] = _qwen35_layer_types(overrides["num_hidden_layers"])
overrides.update(config_dict)
return Qwen35Config(**overrides)
class RMSNormGated(RMSNorm):
def forward(self, x, gate):
return super().forward(x) * F.silu(gate.to(x.dtype))
def torch_chunk_gated_delta_rule(query, key, value, g, beta, chunk_size=64, initial_state=None, output_final_state=False):
initial_dtype = query.dtype
query = F.normalize(query, dim=-1)
key = F.normalize(key, dim=-1)
query, key, value, beta, g = [x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g)]
batch_size, num_heads, sequence_length, k_head_dim = key.shape
v_head_dim = value.shape[-1]
pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size
query = F.pad(query, (0, 0, 0, pad_size))
key = F.pad(key, (0, 0, 0, pad_size))
value = F.pad(value, (0, 0, 0, pad_size))
beta = F.pad(beta, (0, pad_size))
g = F.pad(g, (0, pad_size))
total_sequence_length = sequence_length + pad_size
scale = 1 / (query.shape[-1] ** 0.5)
query = query * scale
v_beta = value * beta.unsqueeze(-1)
k_beta = key * beta.unsqueeze(-1)
query, key, value, k_beta, v_beta = [x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) for x in (query, key, value, k_beta, v_beta)]
g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)
mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=0)
g = g.cumsum(dim=-1)
decay_mask = ((g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().float()).tril()
attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0)
for i in range(1, chunk_size):
row = attn[..., i, :i].clone()
sub = attn[..., :i, :i].clone()
attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
value = attn @ v_beta
k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
last_recurrent_state = (
torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value)
if initial_state is None
else initial_state.to(value)
)
core_attn_out = torch.zeros_like(value)
mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=1)
for i in range(0, total_sequence_length // chunk_size):
q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i]
attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
v_new = v_i - v_prime
attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
core_attn_out[:, :, i] = attn_inter + attn @ v_new
last_recurrent_state = (
last_recurrent_state * g[:, :, i, -1, None, None].exp()
+ (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(-1, -2) @ v_new
)
if not output_final_state:
last_recurrent_state = None
core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1])
core_attn_out = core_attn_out[:, :, :sequence_length]
core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
return core_attn_out, last_recurrent_state
def torch_causal_conv1d_update(x, conv_state, weight, bias=None):
# conv_state: [B, channels, kernel_size-1], x: [B, channels, 1]
# weight: [channels, kernel_size]
state_len = conv_state.shape[-1]
combined = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # [B, channels, kernel_size]
conv_state.copy_(combined[:, :, -state_len:])
out = (combined * weight).sum(dim=-1, keepdim=True) # [B, channels, 1]
if bias is not None:
out = out + bias.unsqueeze(0).unsqueeze(-1)
return F.silu(out).to(x.dtype)
# GatedDeltaNet - Linear Attention Layer
class GatedDeltaNet(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
hidden = config.hidden_size
self.num_key_heads = config.linear_num_key_heads
self.num_value_heads = config.linear_num_value_heads
self.key_head_dim = config.linear_key_head_dim
self.value_head_dim = config.linear_value_head_dim
self.conv_kernel_size = config.conv_kernel_size
key_dim = self.num_key_heads * self.key_head_dim
value_dim = self.num_value_heads * self.value_head_dim
self.key_dim = key_dim
self.value_dim = value_dim
conv_dim = key_dim * 2 + value_dim
self.in_proj_qkv = ops.Linear(hidden, conv_dim, bias=False, device=device, dtype=dtype)
self.in_proj_z = ops.Linear(hidden, value_dim, bias=False, device=device, dtype=dtype)
self.in_proj_b = ops.Linear(hidden, self.num_value_heads, bias=False, device=device, dtype=dtype)
self.in_proj_a = ops.Linear(hidden, self.num_value_heads, bias=False, device=device, dtype=dtype)
self.out_proj = ops.Linear(value_dim, hidden, bias=False, device=device, dtype=dtype)
self.dt_bias = nn.Parameter(torch.empty(self.num_value_heads, device=device, dtype=dtype))
self.A_log = nn.Parameter(torch.empty(self.num_value_heads, device=device, dtype=dtype))
self.conv1d = ops.Conv1d(in_channels=conv_dim, out_channels=conv_dim, bias=False, kernel_size=self.conv_kernel_size,
groups=conv_dim, padding=self.conv_kernel_size - 1, device=device, dtype=dtype)
self.norm = RMSNormGated(self.value_head_dim, eps=config.rms_norm_eps, device=device, dtype=dtype)
def forward(self, x, past_key_value=None, **kwargs):
batch_size, seq_len, _ = x.shape
use_recurrent = (
past_key_value is not None
and past_key_value[2] > 0
and seq_len == 1
)
# Projections (shared)
mixed_qkv = self.in_proj_qkv(x).transpose(1, 2) # [B, conv_dim, seq_len]
z = self.in_proj_z(x)
b = self.in_proj_b(x)
a = self.in_proj_a(x)
# Conv1d
if use_recurrent:
recurrent_state, conv_state, step_index = past_key_value
conv_weight = comfy.model_management.cast_to_device(self.conv1d.weight, mixed_qkv.device, mixed_qkv.dtype).squeeze(1)
conv_bias = comfy.model_management.cast_to_device(self.conv1d.bias, mixed_qkv.device, mixed_qkv.dtype) if self.conv1d.bias is not None else None
mixed_qkv = torch_causal_conv1d_update(mixed_qkv, conv_state, conv_weight, conv_bias)
else:
if past_key_value is not None:
recurrent_state, conv_state, step_index = past_key_value
conv_state_init = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0))
conv_state.copy_(conv_state_init[:, :, -conv_state.shape[-1]:])
mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len])
# Split QKV and compute beta/g
mixed_qkv = mixed_qkv.transpose(1, 2) # [B, seq_len, conv_dim]
query, key, value = mixed_qkv.split([self.key_dim, self.key_dim, self.value_dim], dim=-1)
beta = b.sigmoid()
g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias.float())
# Delta rule
if use_recurrent:
# single-token path: work in [B, heads, dim] without seq dim
query = query.reshape(batch_size, self.num_key_heads, self.key_head_dim)
key = key.reshape(batch_size, self.num_key_heads, self.key_head_dim)
value = value.reshape(batch_size, self.num_value_heads, self.value_head_dim)
if self.num_value_heads != self.num_key_heads:
rep = self.num_value_heads // self.num_key_heads
query = query.repeat_interleave(rep, dim=1)
key = key.repeat_interleave(rep, dim=1)
scale = self.key_head_dim ** -0.5
q = F.normalize(query.float(), dim=-1) * scale
k = F.normalize(key.float(), dim=-1)
v = value.float()
beta_t = beta.reshape(batch_size, -1)
g_t = g.reshape(batch_size, -1).exp()
# In-place state update: [B, heads, k_dim, v_dim]
recurrent_state.mul_(g_t[:, :, None, None])
kv_mem = torch.einsum('bhk,bhkv->bhv', k, recurrent_state)
delta = (v - kv_mem) * beta_t[:, :, None]
recurrent_state.add_(k.unsqueeze(-1) * delta.unsqueeze(-2))
core_attn_out = torch.einsum('bhk,bhkv->bhv', q, recurrent_state)
core_attn_out = core_attn_out.to(x.dtype).unsqueeze(1)
present_key_value = (recurrent_state, conv_state, step_index + 1)
else:
query = query.reshape(batch_size, seq_len, -1, self.key_head_dim)
key = key.reshape(batch_size, seq_len, -1, self.key_head_dim)
value = value.reshape(batch_size, seq_len, -1, self.value_head_dim)
if self.num_value_heads != self.num_key_heads:
rep = self.num_value_heads // self.num_key_heads
query = query.repeat_interleave(rep, dim=2)
key = key.repeat_interleave(rep, dim=2)
core_attn_out, last_recurrent_state = torch_chunk_gated_delta_rule(
query, key, value, g=g, beta=beta,
initial_state=None,
output_final_state=past_key_value is not None,
)
present_key_value = None
if past_key_value is not None:
if last_recurrent_state is not None:
recurrent_state.copy_(last_recurrent_state.to(recurrent_state.dtype))
present_key_value = (recurrent_state, conv_state, step_index + seq_len)
# Gated norm + output projection (shared)
core_attn_out = self.norm(core_attn_out.reshape(-1, self.value_head_dim), z.reshape(-1, self.value_head_dim))
output = self.out_proj(core_attn_out.reshape(batch_size, seq_len, -1))
return output, present_key_value
# GatedAttention - Full Attention with output gating
def precompute_partial_rope(head_dim, rotary_dim, position_ids, theta, device=None, mrope_section=None):
"""Compute RoPE frequencies for partial rotary embeddings."""
theta_numerator = torch.arange(0, rotary_dim, 2, device=device).float()
inv_freq = 1.0 / (theta ** (theta_numerator / rotary_dim))
inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
if mrope_section is not None and position_ids.shape[0] == 3:
mrope_section_2 = [s * 2 for s in mrope_section]
cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section_2, dim=-1))], dim=-1).unsqueeze(0)
sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section_2, dim=-1))], dim=-1).unsqueeze(0)
cos = cos.unsqueeze(1)
sin = sin.unsqueeze(1)
sin_split = sin.shape[-1] // 2
return (cos, sin[..., :sin_split], -sin[..., sin_split:])
def apply_partial_rope(xq, xk, freqs_cis, rotary_dim):
"""Apply RoPE to only the first rotary_dim dimensions."""
xq_rot = xq[..., :rotary_dim]
xq_pass = xq[..., rotary_dim:]
xk_rot = xk[..., :rotary_dim]
xk_pass = xk[..., rotary_dim:]
xq_rot, xk_rot = apply_rope(xq_rot, xk_rot, freqs_cis)
xq = torch.cat([xq_rot, xq_pass], dim=-1)
xk = torch.cat([xk_rot, xk_pass], dim=-1)
return xq, xk
class GatedAttention(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
self.num_heads = config.num_attention_heads
self.num_kv_heads = config.num_key_value_heads
self.head_dim = config.head_dim
self.hidden_size = config.hidden_size
self.inner_size = self.num_heads * self.head_dim
self.rotary_dim = int(self.head_dim * config.partial_rotary_factor)
# q_proj outputs 2x: query + gate
self.q_proj = ops.Linear(config.hidden_size, self.inner_size * 2, bias=config.qkv_bias, device=device, dtype=dtype)
self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
self.o_proj = ops.Linear(self.inner_size, config.hidden_size, bias=False, device=device, dtype=dtype)
# QK norms with (1+weight) scaling
self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
def forward(self, x, attention_mask=None, freqs_cis=None, optimized_attention=None, past_key_value=None):
batch_size, seq_length, _ = x.shape
# Project Q (with gate), K, V
qg = self.q_proj(x)
# Split into query and gate: each is [B, seq, inner_size]
qg = qg.view(batch_size, seq_length, self.num_heads, self.head_dim * 2)
xq, gate = qg[..., :self.head_dim], qg[..., self.head_dim:]
gate = gate.reshape(batch_size, seq_length, -1) # [B, seq, inner_size]
xk = self.k_proj(x)
xv = self.v_proj(x)
xq = self.q_norm(xq).transpose(1, 2) # [B, heads, seq, head_dim]
xk = self.k_norm(xk.view(batch_size, seq_length, self.num_kv_heads, self.head_dim)).transpose(1, 2)
xv = xv.view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2)
# Apply partial RoPE
xq, xk = apply_partial_rope(xq, xk, freqs_cis, self.rotary_dim)
# KV cache
present_key_value = None
if past_key_value is not None:
past_key, past_value, index = past_key_value
num_tokens = xk.shape[2]
if past_key.shape[2] >= (index + num_tokens):
past_key[:, :, index:index + num_tokens] = xk
past_value[:, :, index:index + num_tokens] = xv
xk = past_key[:, :, :index + num_tokens]
xv = past_value[:, :, :index + num_tokens]
present_key_value = (past_key, past_value, index + num_tokens)
else:
if index > 0:
xk = torch.cat((past_key[:, :, :index], xk), dim=2)
xv = torch.cat((past_value[:, :, :index], xv), dim=2)
present_key_value = (xk, xv, index + num_tokens)
# Expand KV heads for GQA
if self.num_heads != self.num_kv_heads:
xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
output = optimized_attention(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True)
output = output * gate.sigmoid()
return self.o_proj(output), present_key_value
# Hybrid Transformer Block
class Qwen35TransformerBlock(nn.Module):
def __init__(self, config, index, device=None, dtype=None, ops=None):
super().__init__()
self.layer_type = config.layer_types[index]
if self.layer_type == "linear_attention":
self.linear_attn = GatedDeltaNet(config, device=device, dtype=dtype, ops=ops)
else:
self.self_attn = GatedAttention(config, device=device, dtype=dtype, ops=ops)
self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
def forward(self, x, attention_mask=None, freqs_cis=None, optimized_attention=None, past_key_value=None):
if self.layer_type == "linear_attention":
h, present_key_value = self.linear_attn(self.input_layernorm(x), attention_mask=attention_mask, past_key_value=past_key_value)
else:
h, present_key_value = self.self_attn(self.input_layernorm(x), attention_mask=attention_mask, freqs_cis=freqs_cis, optimized_attention=optimized_attention, past_key_value=past_key_value)
x = x + h
x = x + self.mlp(self.post_attention_layernorm(x))
return x, present_key_value
# Qwen35 Transformer Backbone
class Qwen35Transformer(Llama2_):
def __init__(self, config, device=None, dtype=None, ops=None):
nn.Module.__init__(self)
self.config = config
self.vocab_size = config.vocab_size
self.normalize_in = False
self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype)
self.layers = nn.ModuleList([
Qwen35TransformerBlock(config, index=i, device=device, dtype=dtype, ops=ops)
for i in range(config.num_hidden_layers)
])
if config.final_norm:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
else:
self.norm = None
if config.lm_head:
self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
def get_past_len(self, past_key_values):
for i, layer in enumerate(self.layers):
if layer.layer_type == "full_attention":
if len(past_key_values) > i:
return past_key_values[i][2]
break
return 0
def compute_freqs_cis(self, position_ids, device):
rotary_dim = int(self.config.head_dim * self.config.partial_rotary_factor)
return precompute_partial_rope(
self.config.head_dim, rotary_dim, position_ids,
self.config.rope_theta, device=device,
mrope_section=self.config.mrope_section,
)
# Vision Encoder
class Qwen35VisionPatchEmbed(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
self.patch_size = config["patch_size"]
self.temporal_patch_size = config["temporal_patch_size"]
self.in_channels = config["in_channels"]
self.embed_dim = config["hidden_size"]
kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size]
self.proj = ops.Conv3d(self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=True, device=device, dtype=dtype)
def forward(self, x):
target_dtype = self.proj.weight.dtype
x = x.view(-1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size)
return self.proj(x.to(target_dtype)).view(-1, self.embed_dim)
class Qwen35VisionMLP(nn.Module):
def __init__(self, hidden_size, intermediate_size, device=None, dtype=None, ops=None):
super().__init__()
self.linear_fc1 = ops.Linear(hidden_size, intermediate_size, bias=True, device=device, dtype=dtype)
self.linear_fc2 = ops.Linear(intermediate_size, hidden_size, bias=True, device=device, dtype=dtype)
def forward(self, hidden_state):
return self.linear_fc2(F.gelu(self.linear_fc1(hidden_state), approximate="tanh"))
class Qwen35VisionRotaryEmbedding(nn.Module):
def __init__(self, dim, theta=10000.0):
super().__init__()
self.dim = dim
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
def forward(self, seqlen):
seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
freqs = torch.outer(seq, self.inv_freq)
return freqs
class Qwen35VisionAttention(nn.Module):
def __init__(self, hidden_size, num_heads, device=None, dtype=None, ops=None):
super().__init__()
self.dim = hidden_size
self.num_heads = num_heads
self.head_dim = self.dim // self.num_heads
self.qkv = ops.Linear(self.dim, self.dim * 3, bias=True, device=device, dtype=dtype)
self.proj = ops.Linear(self.dim, self.dim, device=device, dtype=dtype)
def forward(self, x, cu_seqlens, position_embeddings, optimized_attention=None):
seq_length = x.shape[0]
query_states, key_states, value_states = (
self.qkv(x).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
)
query_states, key_states = apply_rope(query_states, key_states, position_embeddings)
# Process per-sequence attention
lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
q_splits = torch.split(query_states, lengths, dim=0)
k_splits = torch.split(key_states, lengths, dim=0)
v_splits = torch.split(value_states, lengths, dim=0)
attn_outputs = []
for q, k, v in zip(q_splits, k_splits, v_splits):
q = q.transpose(0, 1).unsqueeze(0)
k = k.transpose(0, 1).unsqueeze(0)
v = v.transpose(0, 1).unsqueeze(0)
attn_outputs.append(optimized_attention(q, k, v, self.num_heads, skip_reshape=True))
attn_output = torch.cat(attn_outputs, dim=1)
attn_output = attn_output.reshape(seq_length, -1)
return self.proj(attn_output)
class Qwen35VisionBlock(nn.Module):
def __init__(self, hidden_size, num_heads, intermediate_size, device=None, dtype=None, ops=None):
super().__init__()
self.norm1 = ops.LayerNorm(hidden_size, eps=1e-6, device=device, dtype=dtype)
self.norm2 = ops.LayerNorm(hidden_size, eps=1e-6, device=device, dtype=dtype)
self.attn = Qwen35VisionAttention(hidden_size, num_heads, device=device, dtype=dtype, ops=ops)
self.mlp = Qwen35VisionMLP(hidden_size, intermediate_size, device=device, dtype=dtype, ops=ops)
def forward(self, x, cu_seqlens, position_embeddings, optimized_attention=None):
x = x + self.attn(self.norm1(x), cu_seqlens=cu_seqlens, position_embeddings=position_embeddings, optimized_attention=optimized_attention)
return x + self.mlp(self.norm2(x))
class Qwen35VisionPatchMerger(nn.Module):
def __init__(self, hidden_size, spatial_merge_size, out_hidden_size, device=None, dtype=None, ops=None):
super().__init__()
merge_dim = hidden_size * (spatial_merge_size ** 2)
self.norm = ops.LayerNorm(hidden_size, eps=1e-6, device=device, dtype=dtype)
self.linear_fc1 = ops.Linear(merge_dim, merge_dim, device=device, dtype=dtype)
self.linear_fc2 = ops.Linear(merge_dim, out_hidden_size, device=device, dtype=dtype)
self.merge_dim = merge_dim
def forward(self, x):
x = self.norm(x).view(-1, self.merge_dim)
return self.linear_fc2(F.gelu(self.linear_fc1(x)))
class Qwen35VisionModel(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
self.spatial_merge_size = config["spatial_merge_size"]
self.patch_size = config["patch_size"]
self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size
self.hidden_size = config["hidden_size"]
self.num_heads = config["num_heads"]
self.num_position_embeddings = config["num_position_embeddings"]
self.patch_embed = Qwen35VisionPatchEmbed(config, device=device, dtype=dtype, ops=ops)
self.pos_embed = ops.Embedding(self.num_position_embeddings, self.hidden_size, device=device, dtype=dtype)
self.num_grid_per_side = int(self.num_position_embeddings ** 0.5)
self.rotary_pos_emb = Qwen35VisionRotaryEmbedding(self.hidden_size // self.num_heads // 2)
self.blocks = nn.ModuleList([
Qwen35VisionBlock(self.hidden_size, self.num_heads, config["intermediate_size"], device=device, dtype=dtype, ops=ops)
for _ in range(config["depth"])
])
self.merger = Qwen35VisionPatchMerger(self.hidden_size, self.spatial_merge_size, config["out_hidden_size"], device=device, dtype=dtype, ops=ops)
def rot_pos_emb(self, grid_thw):
merge_size = self.spatial_merge_size
grid_thw_list = grid_thw.tolist()
max_hw = max(max(h, w) for _, h, w in grid_thw_list)
freq_table = self.rotary_pos_emb(max_hw)
device = freq_table.device
total_tokens = sum(int(t * h * w) for t, h, w in grid_thw_list)
pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device)
offset = 0
for num_frames, height, width in grid_thw_list:
num_frames, height, width = int(num_frames), int(height), int(width)
merged_h, merged_w = height // merge_size, width // merge_size
block_rows = torch.arange(merged_h, device=device)
block_cols = torch.arange(merged_w, device=device)
intra_row = torch.arange(merge_size, device=device)
intra_col = torch.arange(merge_size, device=device)
row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None]
col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :]
row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1)
col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1)
coords = torch.stack((row_idx, col_idx), dim=-1)
if num_frames > 1:
coords = coords.repeat(num_frames, 1)
num_tokens = coords.shape[0]
pos_ids[offset:offset + num_tokens] = coords
offset += num_tokens
embeddings = freq_table[pos_ids]
embeddings = embeddings.flatten(1)
return embeddings
def fast_pos_embed_interpolate(self, grid_thw):
grid_thw_list = grid_thw.tolist()
grid_ts = [int(row[0]) for row in grid_thw_list]
grid_hs = [int(row[1]) for row in grid_thw_list]
grid_ws = [int(row[2]) for row in grid_thw_list]
device = self.pos_embed.weight.device
idx_list = [[] for _ in range(4)]
weight_list = [[] for _ in range(4)]
for t, h, w in grid_thw_list:
h, w = int(h), int(w)
h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h, device=device)
w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w, device=device)
h_idxs_floor = h_idxs.int()
w_idxs_floor = w_idxs.int()
h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
dh = h_idxs - h_idxs_floor
dw = w_idxs - w_idxs_floor
base_h = h_idxs_floor * self.num_grid_per_side
base_h_ceil = h_idxs_ceil * self.num_grid_per_side
indices = [
(base_h[None].T + w_idxs_floor[None]).flatten(),
(base_h[None].T + w_idxs_ceil[None]).flatten(),
(base_h_ceil[None].T + w_idxs_floor[None]).flatten(),
(base_h_ceil[None].T + w_idxs_ceil[None]).flatten(),
]
weights = [
((1 - dh)[None].T * (1 - dw)[None]).flatten(),
((1 - dh)[None].T * dw[None]).flatten(),
(dh[None].T * (1 - dw)[None]).flatten(),
(dh[None].T * dw[None]).flatten(),
]
for j in range(4):
idx_list[j].extend(indices[j].tolist())
weight_list[j].extend(weights[j].tolist())
idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=device)
weight_tensor = torch.tensor(weight_list, dtype=self.pos_embed.weight.dtype, device=device)
pos_embeds = self.pos_embed(idx_tensor).to(device) * weight_tensor[:, :, None]
patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3]
patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)])
patch_pos_embeds_permute = []
merge_size = self.spatial_merge_size
for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws):
pos_embed = pos_embed.repeat(t, 1)
pos_embed = (
pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1)
.permute(0, 1, 3, 2, 4, 5)
.flatten(0, 4)
)
patch_pos_embeds_permute.append(pos_embed)
return torch.cat(patch_pos_embeds_permute)
def forward(self, x, grid_thw):
x = self.patch_embed(x)
pos_embeds = self.fast_pos_embed_interpolate(grid_thw).to(x.device)
x = x + pos_embeds
rotary_pos_emb = self.rot_pos_emb(grid_thw)
seq_len = x.shape[0]
x = x.reshape(seq_len, -1)
rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
cos = emb.cos().unsqueeze(-2)
sin = emb.sin().unsqueeze(-2)
sin_half = sin.shape[-1] // 2
position_embeddings = (cos, sin[..., :sin_half], -sin[..., sin_half:])
cu_seqlens = torch.repeat_interleave(
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
).cumsum(dim=0, dtype=torch.int32)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
optimized_attention = optimized_attention_for_device(x.device, mask=False, small_input=True)
for blk in self.blocks:
x = blk(x, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings, optimized_attention=optimized_attention)
merged = self.merger(x)
return merged
# Model Wrapper
class Qwen35(BaseLlama, BaseGenerate, torch.nn.Module):
model_type = "qwen35_2b"
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = _make_config(self.model_type, config_dict)
self.num_layers = config.num_hidden_layers
self.model = Qwen35Transformer(config, device=device, dtype=dtype, ops=operations)
vision_overrides = QWEN35_MODELS.get(self.model_type, {}).get("vision", {})
vision_config = {**QWEN35_VISION_DEFAULTS, **vision_overrides, "out_hidden_size": config.hidden_size}
self.visual = Qwen35VisionModel(vision_config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
def preprocess_embed(self, embed, device):
if embed["type"] == "image":
image, grid = comfy.text_encoders.qwen_vl.process_qwen2vl_images(embed["data"], patch_size=16)
return self.visual(image.to(device, dtype=torch.float32), grid), grid
return None, None
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[], past_key_values=None):
grid = None
position_ids = None
offset = 0
for e in embeds_info:
if e.get("type") == "image":
grid = e.get("extra", None)
start = e.get("index")
if position_ids is None:
position_ids = torch.zeros((3, embeds.shape[1]), device=embeds.device)
position_ids[:, :start] = torch.arange(0, start, device=embeds.device)
end = e.get("size") + start
len_max = int(grid.max()) // 2
start_next = len_max + start
position_ids[:, end:] = torch.arange(start_next + offset, start_next + (embeds.shape[1] - end) + offset, device=embeds.device)
position_ids[0, start:end] = start + offset
max_d = int(grid[0][1]) // 2
position_ids[1, start:end] = torch.arange(start + offset, start + max_d + offset, device=embeds.device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start]
max_d = int(grid[0][2]) // 2
position_ids[2, start:end] = torch.arange(start + offset, start + max_d + offset, device=embeds.device).unsqueeze(0).repeat(math.ceil((end - start) / max_d), 1).flatten(0)[:end - start]
offset += len_max - (end - start)
if grid is None:
position_ids = None
return super().forward(x, attention_mask=attention_mask, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=final_layer_norm_intermediate, dtype=dtype, position_ids=position_ids, past_key_values=past_key_values)
def init_kv_cache(self, batch, max_cache_len, device, execution_dtype):
model_config = self.model.config
past_key_values = []
for i in range(model_config.num_hidden_layers):
if model_config.layer_types[i] == "linear_attention":
recurrent_state = torch.zeros(
[batch, model_config.linear_num_value_heads, model_config.linear_key_head_dim, model_config.linear_value_head_dim],
device=device, dtype=torch.float32
)
conv_dim = model_config.linear_num_key_heads * model_config.linear_key_head_dim * 2 + model_config.linear_num_value_heads * model_config.linear_value_head_dim
conv_state = torch.zeros(
[batch, conv_dim, model_config.conv_kernel_size - 1],
device=device, dtype=execution_dtype
)
past_key_values.append((recurrent_state, conv_state, 0))
else:
past_key_values.append((
torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype),
torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype),
0
))
return past_key_values
# Tokenizer and Text Encoder Wrappers
class Qwen35Tokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}, embedding_size=2048, embedding_key="qwen35_2b"):
from transformers import Qwen2Tokenizer
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen35_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=embedding_size, embedding_key=embedding_key, tokenizer_class=Qwen2Tokenizer,
has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=248044, tokenizer_data=tokenizer_data)
class Qwen35ImageTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}, model_type="qwen35_2b"):
embedding_size = QWEN35_MODELS.get(model_type, {}).get("hidden_size", 2048)
tokenizer = lambda *a, **kw: Qwen35Tokenizer(*a, **kw, embedding_size=embedding_size, embedding_key=model_type)
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name=model_type, tokenizer=tokenizer)
self.llama_template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
self.llama_template_images = "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], prevent_empty_text=False, thinking=False, **kwargs):
image = kwargs.get("image", None)
if image is not None and len(images) == 0:
images = [image]
skip_template = False
if text.startswith('<|im_start|>'):
skip_template = True
if prevent_empty_text and text == '':
text = ' '
if skip_template:
llama_text = text
else:
if llama_template is None:
if len(images) > 0:
llama_text = self.llama_template_images.format(text)
else:
llama_text = self.llama_template.format(text)
else:
llama_text = llama_template.format(text)
if not thinking:
llama_text += "<think>\n</think>\n"
tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
key_name = next(iter(tokens))
embed_count = 0
qwen_tokens = tokens[key_name]
for r in qwen_tokens:
for i in range(len(r)):
if r[i][0] == 248056: # <|image_pad|>
if len(images) > embed_count:
r[i] = ({"type": "image", "data": images[embed_count], "original_type": "image"},) + r[i][1:]
embed_count += 1
return tokens
class Qwen35ClipModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}, model_type="qwen35_2b"):
class Qwen35_(Qwen35):
pass
Qwen35_.model_type = model_type
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={},
dtype=dtype, special_tokens={"pad": 248044}, layer_norm_hidden_state=False,
model_class=Qwen35_, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class Qwen35TEModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}, model_type="qwen35_2b"):
clip_model = lambda **kw: Qwen35ClipModel(**kw, model_type=model_type)
super().__init__(device=device, dtype=dtype, name=model_type, clip_model=clip_model, model_options=model_options)
def tokenizer(model_type="qwen35_2b"):
class Qwen35ImageTokenizer_(Qwen35ImageTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, model_type=model_type)
return Qwen35ImageTokenizer_
def te(dtype_llama=None, llama_quantization_metadata=None, model_type="qwen35_2b"):
class Qwen35TEModel_(Qwen35TEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if dtype_llama is not None:
dtype = dtype_llama
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["quantization_metadata"] = llama_quantization_metadata
super().__init__(device=device, dtype=dtype, model_options=model_options, model_type=model_type)
return Qwen35TEModel_

File diff suppressed because it is too large Load Diff

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large Load Diff

View File

@@ -425,4 +425,7 @@ class Qwen2VLVisionTransformer(nn.Module):
hidden_states = block(hidden_states, position_embeddings, cu_seqlens_now, optimized_attention=optimized_attention)
hidden_states = self.merger(hidden_states)
# Potentially important for spatially precise edits. This is present in the HF implementation.
reverse_indices = torch.argsort(window_index)
hidden_states = hidden_states[reverse_indices, :]
return hidden_states

View File

@@ -5,6 +5,10 @@ from comfy_api.latest._input import (
MaskInput,
LatentInput,
VideoInput,
CurvePoint,
CurveInput,
MonotoneCubicCurve,
LinearCurve,
)
__all__ = [
@@ -13,4 +17,8 @@ __all__ = [
"MaskInput",
"LatentInput",
"VideoInput",
"CurvePoint",
"CurveInput",
"MonotoneCubicCurve",
"LinearCurve",
]

View File

@@ -1,4 +1,5 @@
from .basic_types import ImageInput, AudioInput, MaskInput, LatentInput
from .curve_types import CurvePoint, CurveInput, MonotoneCubicCurve, LinearCurve
from .video_types import VideoInput
__all__ = [
@@ -7,4 +8,8 @@ __all__ = [
"VideoInput",
"MaskInput",
"LatentInput",
"CurvePoint",
"CurveInput",
"MonotoneCubicCurve",
"LinearCurve",
]

View File

@@ -0,0 +1,219 @@
from __future__ import annotations
import logging
import math
from abc import ABC, abstractmethod
import numpy as np
logger = logging.getLogger(__name__)
CurvePoint = tuple[float, float]
class CurveInput(ABC):
"""Abstract base class for curve inputs.
Subclasses represent different curve representations (control-point
interpolation, analytical functions, LUT-based, etc.) while exposing a
uniform evaluation interface to downstream nodes.
"""
@property
@abstractmethod
def points(self) -> list[CurvePoint]:
"""The control points that define this curve."""
@abstractmethod
def interp(self, x: float) -> float:
"""Evaluate the curve at a single *x* value in [0, 1]."""
def interp_array(self, xs: np.ndarray) -> np.ndarray:
"""Vectorised evaluation over a numpy array of x values.
Subclasses should override this for better performance. The default
falls back to scalar ``interp`` calls.
"""
return np.fromiter((self.interp(float(x)) for x in xs), dtype=np.float64, count=len(xs))
def to_lut(self, size: int = 256) -> np.ndarray:
"""Generate a float64 lookup table of *size* evenly-spaced samples in [0, 1]."""
return self.interp_array(np.linspace(0.0, 1.0, size))
@staticmethod
def from_raw(data) -> CurveInput:
"""Convert raw curve data (dict or point list) to a CurveInput instance.
Accepts:
- A ``CurveInput`` instance (returned as-is).
- A dict with ``"points"`` and optional ``"interpolation"`` keys.
- A bare list/sequence of ``(x, y)`` pairs (defaults to monotone cubic).
"""
if isinstance(data, CurveInput):
return data
if isinstance(data, dict):
raw_points = data["points"]
interpolation = data.get("interpolation", "monotone_cubic")
else:
raw_points = data
interpolation = "monotone_cubic"
points = [(float(x), float(y)) for x, y in raw_points]
if interpolation == "linear":
return LinearCurve(points)
if interpolation != "monotone_cubic":
logger.warning("Unknown curve interpolation %r, falling back to monotone_cubic", interpolation)
return MonotoneCubicCurve(points)
class MonotoneCubicCurve(CurveInput):
"""Monotone cubic Hermite interpolation over control points.
Mirrors the frontend ``createMonotoneInterpolator`` in
``ComfyUI_frontend/src/components/curve/curveUtils.ts`` so that
backend evaluation matches the editor preview exactly.
All heavy work (sorting, slope computation) happens once at construction.
``interp_array`` is fully vectorised with numpy.
"""
def __init__(self, control_points: list[CurvePoint]):
sorted_pts = sorted(control_points, key=lambda p: p[0])
self._points = [(float(x), float(y)) for x, y in sorted_pts]
self._xs = np.array([p[0] for p in self._points], dtype=np.float64)
self._ys = np.array([p[1] for p in self._points], dtype=np.float64)
self._slopes = self._compute_slopes()
@property
def points(self) -> list[CurvePoint]:
return list(self._points)
def _compute_slopes(self) -> np.ndarray:
xs, ys = self._xs, self._ys
n = len(xs)
if n < 2:
return np.zeros(n, dtype=np.float64)
dx = np.diff(xs)
dy = np.diff(ys)
dx_safe = np.where(dx == 0, 1.0, dx)
deltas = np.where(dx == 0, 0.0, dy / dx_safe)
slopes = np.empty(n, dtype=np.float64)
slopes[0] = deltas[0]
slopes[-1] = deltas[-1]
for i in range(1, n - 1):
if deltas[i - 1] * deltas[i] <= 0:
slopes[i] = 0.0
else:
slopes[i] = (deltas[i - 1] + deltas[i]) / 2
for i in range(n - 1):
if deltas[i] == 0:
slopes[i] = 0.0
slopes[i + 1] = 0.0
else:
alpha = slopes[i] / deltas[i]
beta = slopes[i + 1] / deltas[i]
s = alpha * alpha + beta * beta
if s > 9:
t = 3 / math.sqrt(s)
slopes[i] = t * alpha * deltas[i]
slopes[i + 1] = t * beta * deltas[i]
return slopes
def interp(self, x: float) -> float:
xs, ys, slopes = self._xs, self._ys, self._slopes
n = len(xs)
if n == 0:
return 0.0
if n == 1:
return float(ys[0])
if x <= xs[0]:
return float(ys[0])
if x >= xs[-1]:
return float(ys[-1])
hi = int(np.searchsorted(xs, x, side='right'))
hi = min(hi, n - 1)
lo = hi - 1
dx = xs[hi] - xs[lo]
if dx == 0:
return float(ys[lo])
t = (x - xs[lo]) / dx
t2 = t * t
t3 = t2 * t
h00 = 2 * t3 - 3 * t2 + 1
h10 = t3 - 2 * t2 + t
h01 = -2 * t3 + 3 * t2
h11 = t3 - t2
return float(h00 * ys[lo] + h10 * dx * slopes[lo] + h01 * ys[hi] + h11 * dx * slopes[hi])
def interp_array(self, xs_in: np.ndarray) -> np.ndarray:
"""Fully vectorised evaluation using numpy."""
xs, ys, slopes = self._xs, self._ys, self._slopes
n = len(xs)
if n == 0:
return np.zeros_like(xs_in, dtype=np.float64)
if n == 1:
return np.full_like(xs_in, ys[0], dtype=np.float64)
hi = np.searchsorted(xs, xs_in, side='right').clip(1, n - 1)
lo = hi - 1
dx = xs[hi] - xs[lo]
dx_safe = np.where(dx == 0, 1.0, dx)
t = np.where(dx == 0, 0.0, (xs_in - xs[lo]) / dx_safe)
t2 = t * t
t3 = t2 * t
h00 = 2 * t3 - 3 * t2 + 1
h10 = t3 - 2 * t2 + t
h01 = -2 * t3 + 3 * t2
h11 = t3 - t2
result = h00 * ys[lo] + h10 * dx * slopes[lo] + h01 * ys[hi] + h11 * dx * slopes[hi]
result = np.where(xs_in <= xs[0], ys[0], result)
result = np.where(xs_in >= xs[-1], ys[-1], result)
return result
def __repr__(self) -> str:
return f"MonotoneCubicCurve(points={self._points})"
class LinearCurve(CurveInput):
"""Piecewise linear interpolation over control points.
Mirrors the frontend ``createLinearInterpolator`` in
``ComfyUI_frontend/src/components/curve/curveUtils.ts``.
"""
def __init__(self, control_points: list[CurvePoint]):
sorted_pts = sorted(control_points, key=lambda p: p[0])
self._points = [(float(x), float(y)) for x, y in sorted_pts]
self._xs = np.array([p[0] for p in self._points], dtype=np.float64)
self._ys = np.array([p[1] for p in self._points], dtype=np.float64)
@property
def points(self) -> list[CurvePoint]:
return list(self._points)
def interp(self, x: float) -> float:
xs, ys = self._xs, self._ys
n = len(xs)
if n == 0:
return 0.0
if n == 1:
return float(ys[0])
return float(np.interp(x, xs, ys))
def interp_array(self, xs_in: np.ndarray) -> np.ndarray:
if len(self._xs) == 0:
return np.zeros_like(xs_in, dtype=np.float64)
if len(self._xs) == 1:
return np.full_like(xs_in, self._ys[0], dtype=np.float64)
return np.interp(xs_in, self._xs, self._ys)
def __repr__(self) -> str:
return f"LinearCurve(points={self._points})"

View File

@@ -23,7 +23,7 @@ if TYPE_CHECKING:
from comfy.samplers import CFGGuider, Sampler
from comfy.sd import CLIP, VAE
from comfy.sd import StyleModel as StyleModel_
from comfy_api.input import VideoInput
from comfy_api.input import VideoInput, CurveInput as CurveInput_
from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
prune_dict, shallow_clone_class)
from comfy_execution.graph_utils import ExecutionBlocker
@@ -1242,8 +1242,9 @@ class BoundingBox(ComfyTypeIO):
@comfytype(io_type="CURVE")
class Curve(ComfyTypeIO):
CurvePoint = tuple[float, float]
Type = list[CurvePoint]
from comfy_api.input import CurvePoint
if TYPE_CHECKING:
Type = CurveInput_
class Input(WidgetInput):
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
@@ -1252,6 +1253,18 @@ class Curve(ComfyTypeIO):
if default is None:
self.default = [(0.0, 0.0), (1.0, 1.0)]
def as_dict(self):
d = super().as_dict()
if self.default is not None:
d["default"] = {"points": [list(p) for p in self.default], "interpolation": "monotone_cubic"}
return d
@comfytype(io_type="HISTOGRAM")
class Histogram(ComfyTypeIO):
"""A histogram represented as a list of bin counts."""
Type = list[int]
DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}
def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
@@ -1360,6 +1373,7 @@ class NodeInfoV1:
price_badge: dict | None = None
search_aliases: list[str]=None
essentials_category: str=None
has_intermediate_output: bool=None
@dataclass
@@ -1483,6 +1497,16 @@ class Schema:
"""When True, all inputs from the prompt will be passed to the node as kwargs, even if not defined in the schema."""
essentials_category: str | None = None
"""Optional category for the Essentials tab. Path-based like category field (e.g., 'Basic', 'Image Tools/Editing')."""
has_intermediate_output: bool=False
"""Flags this node as having intermediate output that should persist across page refreshes.
Nodes with this flag behave like output nodes (their UI results are cached and resent
to the frontend) but do NOT automatically get added to the execution list. This means
they will only execute if they are on the dependency path of a real output node.
Use this for nodes with interactive/operable UI regions that produce intermediate outputs
(e.g., Image Crop, Painter) rather than final outputs (e.g., Save Image).
"""
def validate(self):
'''Validate the schema:
@@ -1582,6 +1606,7 @@ class Schema:
category=self.category,
description=self.description,
output_node=self.is_output_node,
has_intermediate_output=self.has_intermediate_output,
deprecated=self.is_deprecated,
experimental=self.is_experimental,
dev_only=self.is_dev_only,
@@ -1873,6 +1898,14 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
cls.GET_SCHEMA()
return cls._OUTPUT_NODE
_HAS_INTERMEDIATE_OUTPUT = None
@final
@classproperty
def HAS_INTERMEDIATE_OUTPUT(cls): # noqa
if cls._HAS_INTERMEDIATE_OUTPUT is None:
cls.GET_SCHEMA()
return cls._HAS_INTERMEDIATE_OUTPUT
_INPUT_IS_LIST = None
@final
@classproperty
@@ -1965,6 +1998,8 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
cls._API_NODE = schema.is_api_node
if cls._OUTPUT_NODE is None:
cls._OUTPUT_NODE = schema.is_output_node
if cls._HAS_INTERMEDIATE_OUTPUT is None:
cls._HAS_INTERMEDIATE_OUTPUT = schema.has_intermediate_output
if cls._INPUT_IS_LIST is None:
cls._INPUT_IS_LIST = schema.is_input_list
if cls._NOT_IDEMPOTENT is None:
@@ -2240,5 +2275,6 @@ __all__ = [
"PriceBadge",
"BoundingBox",
"Curve",
"Histogram",
"NodeReplace",
]

View File

@@ -29,13 +29,21 @@ class ImageEditRequest(BaseModel):
class VideoGenerationRequest(BaseModel):
model: str = Field(...)
prompt: str = Field(...)
image: InputUrlObject | None = Field(...)
image: InputUrlObject | None = Field(None)
reference_images: list[InputUrlObject] | None = Field(None)
duration: int = Field(...)
aspect_ratio: str | None = Field(...)
resolution: str = Field(...)
seed: int = Field(...)
class VideoExtensionRequest(BaseModel):
prompt: str = Field(...)
video: InputUrlObject = Field(...)
duration: int = Field(default=6)
model: str | None = Field(default=None)
class VideoEditRequest(BaseModel):
model: str = Field(...)
prompt: str = Field(...)

226
comfy_api_nodes/apis/wan.py Normal file
View File

@@ -0,0 +1,226 @@
from pydantic import BaseModel, Field
class Text2ImageInputField(BaseModel):
prompt: str = Field(...)
negative_prompt: str | None = Field(None)
class Image2ImageInputField(BaseModel):
prompt: str = Field(...)
negative_prompt: str | None = Field(None)
images: list[str] = Field(..., min_length=1, max_length=2)
class Text2VideoInputField(BaseModel):
prompt: str = Field(...)
negative_prompt: str | None = Field(None)
audio_url: str | None = Field(None)
class Image2VideoInputField(BaseModel):
prompt: str = Field(...)
negative_prompt: str | None = Field(None)
img_url: str = Field(...)
audio_url: str | None = Field(None)
class Reference2VideoInputField(BaseModel):
prompt: str = Field(...)
negative_prompt: str | None = Field(None)
reference_video_urls: list[str] = Field(...)
class Txt2ImageParametersField(BaseModel):
size: str = Field(...)
n: int = Field(1, description="Number of images to generate.") # we support only value=1
seed: int = Field(..., ge=0, le=2147483647)
prompt_extend: bool = Field(True)
watermark: bool = Field(False)
class Image2ImageParametersField(BaseModel):
size: str | None = Field(None)
n: int = Field(1, description="Number of images to generate.") # we support only value=1
seed: int = Field(..., ge=0, le=2147483647)
watermark: bool = Field(False)
class Text2VideoParametersField(BaseModel):
size: str = Field(...)
seed: int = Field(..., ge=0, le=2147483647)
duration: int = Field(5, ge=5, le=15)
prompt_extend: bool = Field(True)
watermark: bool = Field(False)
audio: bool = Field(False, description="Whether to generate audio automatically.")
shot_type: str = Field("single")
class Image2VideoParametersField(BaseModel):
resolution: str = Field(...)
seed: int = Field(..., ge=0, le=2147483647)
duration: int = Field(5, ge=5, le=15)
prompt_extend: bool = Field(True)
watermark: bool = Field(False)
audio: bool = Field(False, description="Whether to generate audio automatically.")
shot_type: str = Field("single")
class Reference2VideoParametersField(BaseModel):
size: str = Field(...)
duration: int = Field(5, ge=5, le=15)
shot_type: str = Field("single")
seed: int = Field(..., ge=0, le=2147483647)
watermark: bool = Field(False)
class Text2ImageTaskCreationRequest(BaseModel):
model: str = Field(...)
input: Text2ImageInputField = Field(...)
parameters: Txt2ImageParametersField = Field(...)
class Image2ImageTaskCreationRequest(BaseModel):
model: str = Field(...)
input: Image2ImageInputField = Field(...)
parameters: Image2ImageParametersField = Field(...)
class Text2VideoTaskCreationRequest(BaseModel):
model: str = Field(...)
input: Text2VideoInputField = Field(...)
parameters: Text2VideoParametersField = Field(...)
class Image2VideoTaskCreationRequest(BaseModel):
model: str = Field(...)
input: Image2VideoInputField = Field(...)
parameters: Image2VideoParametersField = Field(...)
class Reference2VideoTaskCreationRequest(BaseModel):
model: str = Field(...)
input: Reference2VideoInputField = Field(...)
parameters: Reference2VideoParametersField = Field(...)
class Wan27MediaItem(BaseModel):
type: str = Field(...)
url: str = Field(...)
class Wan27ReferenceVideoInputField(BaseModel):
prompt: str = Field(...)
negative_prompt: str | None = Field(None)
media: list[Wan27MediaItem] = Field(...)
class Wan27ReferenceVideoParametersField(BaseModel):
resolution: str = Field(...)
ratio: str | None = Field(None)
duration: int = Field(5, ge=2, le=10)
watermark: bool = Field(False)
seed: int = Field(..., ge=0, le=2147483647)
class Wan27ReferenceVideoTaskCreationRequest(BaseModel):
model: str = Field(...)
input: Wan27ReferenceVideoInputField = Field(...)
parameters: Wan27ReferenceVideoParametersField = Field(...)
class Wan27ImageToVideoInputField(BaseModel):
prompt: str | None = Field(None)
negative_prompt: str | None = Field(None)
media: list[Wan27MediaItem] = Field(...)
class Wan27ImageToVideoParametersField(BaseModel):
resolution: str = Field(...)
duration: int = Field(5, ge=2, le=15)
prompt_extend: bool = Field(True)
watermark: bool = Field(False)
seed: int = Field(..., ge=0, le=2147483647)
class Wan27ImageToVideoTaskCreationRequest(BaseModel):
model: str = Field(...)
input: Wan27ImageToVideoInputField = Field(...)
parameters: Wan27ImageToVideoParametersField = Field(...)
class Wan27VideoEditInputField(BaseModel):
prompt: str = Field(...)
media: list[Wan27MediaItem] = Field(...)
class Wan27VideoEditParametersField(BaseModel):
resolution: str = Field(...)
ratio: str | None = Field(None)
duration: int = Field(0)
audio_setting: str = Field("auto")
watermark: bool = Field(False)
seed: int = Field(..., ge=0, le=2147483647)
class Wan27VideoEditTaskCreationRequest(BaseModel):
model: str = Field(...)
input: Wan27VideoEditInputField = Field(...)
parameters: Wan27VideoEditParametersField = Field(...)
class Wan27Text2VideoParametersField(BaseModel):
resolution: str = Field(...)
ratio: str | None = Field(None)
duration: int = Field(5, ge=2, le=15)
prompt_extend: bool = Field(True)
watermark: bool = Field(False)
seed: int = Field(..., ge=0, le=2147483647)
class Wan27Text2VideoTaskCreationRequest(BaseModel):
model: str = Field(...)
input: Text2VideoInputField = Field(...)
parameters: Wan27Text2VideoParametersField = Field(...)
class TaskCreationOutputField(BaseModel):
task_id: str = Field(...)
task_status: str = Field(...)
class TaskCreationResponse(BaseModel):
output: TaskCreationOutputField | None = Field(None)
request_id: str = Field(...)
code: str | None = Field(None, description="Error code for the failed request.")
message: str | None = Field(None, description="Details about the failed request.")
class TaskResult(BaseModel):
url: str | None = Field(None)
code: str | None = Field(None)
message: str | None = Field(None)
class ImageTaskStatusOutputField(TaskCreationOutputField):
task_id: str = Field(...)
task_status: str = Field(...)
results: list[TaskResult] | None = Field(None)
class VideoTaskStatusOutputField(TaskCreationOutputField):
task_id: str = Field(...)
task_status: str = Field(...)
video_url: str | None = Field(None)
code: str | None = Field(None)
message: str | None = Field(None)
class ImageTaskStatusResponse(BaseModel):
output: ImageTaskStatusOutputField | None = Field(None)
request_id: str = Field(...)
class VideoTaskStatusResponse(BaseModel):
output: VideoTaskStatusOutputField | None = Field(None)
request_id: str = Field(...)

View File

@@ -201,6 +201,16 @@ async def get_image_from_response(response: GeminiGenerateContentResponse, thoug
returned_image = await download_url_to_image_tensor(part.fileData.fileUri)
image_tensors.append(returned_image)
if len(image_tensors) == 0:
if not thought:
# No images generated --> extract text response for a meaningful error
model_message = get_text_from_response(response).strip()
if model_message:
raise ValueError(f"Gemini did not generate an image. Model response: {model_message}")
raise ValueError(
"Gemini did not generate an image. "
"Try rephrasing your prompt or changing the response modality to 'IMAGE+TEXT' "
"to see the model's reasoning."
)
return torch.zeros((1, 1024, 1024, 4))
return torch.cat(image_tensors, dim=0)

View File

@@ -8,6 +8,7 @@ from comfy_api_nodes.apis.grok import (
ImageGenerationResponse,
InputUrlObject,
VideoEditRequest,
VideoExtensionRequest,
VideoGenerationRequest,
VideoGenerationResponse,
VideoStatusResponse,
@@ -21,6 +22,7 @@ from comfy_api_nodes.util import (
poll_op,
sync_op,
tensor_to_base64_string,
upload_images_to_comfyapi,
upload_video_to_comfyapi,
validate_string,
validate_video_duration,
@@ -33,6 +35,13 @@ def _extract_grok_price(response) -> float | None:
return None
def _extract_grok_video_price(response) -> float | None:
price = _extract_grok_price(response)
if price is not None:
return price * 1.43
return None
class GrokImageNode(IO.ComfyNode):
@classmethod
@@ -354,6 +363,8 @@ class GrokVideoNode(IO.ComfyNode):
seed: int,
image: Input.Image | None = None,
) -> IO.NodeOutput:
if model == "grok-imagine-video-beta":
model = "grok-imagine-video"
image_url = None
if image is not None:
if get_number_of_images(image) != 1:
@@ -462,6 +473,244 @@ class GrokVideoEditNode(IO.ComfyNode):
return IO.NodeOutput(await download_url_to_video_output(response.video.url))
class GrokVideoReferenceNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="GrokVideoReferenceNode",
display_name="Grok Reference-to-Video",
category="api node/video/Grok",
description="Generate video guided by reference images as style and content references.",
inputs=[
IO.String.Input(
"prompt",
multiline=True,
tooltip="Text description of the desired video.",
),
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option(
"grok-imagine-video",
[
IO.Autogrow.Input(
"reference_images",
template=IO.Autogrow.TemplatePrefix(
IO.Image.Input("image"),
prefix="reference_",
min=1,
max=7,
),
tooltip="Up to 7 reference images to guide the video generation.",
),
IO.Combo.Input(
"resolution",
options=["480p", "720p"],
tooltip="The resolution of the output video.",
),
IO.Combo.Input(
"aspect_ratio",
options=["16:9", "4:3", "3:2", "1:1", "2:3", "3:4", "9:16"],
tooltip="The aspect ratio of the output video.",
),
IO.Int.Input(
"duration",
default=6,
min=2,
max=10,
step=1,
tooltip="The duration of the output video in seconds.",
display_mode=IO.NumberDisplay.slider,
),
],
),
],
tooltip="The model to use for video generation.",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed to determine if node should re-run; "
"actual results are nondeterministic regardless of seed.",
),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(
widgets=["model.duration", "model.resolution"],
input_groups=["model.reference_images"],
),
expr="""
(
$res := $lookup(widgets, "model.resolution");
$dur := $lookup(widgets, "model.duration");
$refs := inputGroups["model.reference_images"];
$rate := $res = "720p" ? 0.07 : 0.05;
$price := ($rate * $dur + 0.002 * $refs) * 1.43;
{"type":"usd","usd": $price}
)
""",
),
)
@classmethod
async def execute(
cls,
prompt: str,
model: dict,
seed: int,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)
ref_image_urls = await upload_images_to_comfyapi(
cls,
list(model["reference_images"].values()),
mime_type="image/png",
wait_label="Uploading base images",
max_images=7,
)
initial_response = await sync_op(
cls,
ApiEndpoint(path="/proxy/xai/v1/videos/generations", method="POST"),
data=VideoGenerationRequest(
model=model["model"],
reference_images=[InputUrlObject(url=i) for i in ref_image_urls],
prompt=prompt,
resolution=model["resolution"],
duration=model["duration"],
aspect_ratio=model["aspect_ratio"],
seed=seed,
),
response_model=VideoGenerationResponse,
)
response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"),
status_extractor=lambda r: r.status if r.status is not None else "complete",
response_model=VideoStatusResponse,
price_extractor=_extract_grok_video_price,
)
return IO.NodeOutput(await download_url_to_video_output(response.video.url))
class GrokVideoExtendNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="GrokVideoExtendNode",
display_name="Grok Video Extend",
category="api node/video/Grok",
description="Extend an existing video with a seamless continuation based on a text prompt.",
inputs=[
IO.String.Input(
"prompt",
multiline=True,
tooltip="Text description of what should happen next in the video.",
),
IO.Video.Input("video", tooltip="Source video to extend. MP4 format, 2-15 seconds."),
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option(
"grok-imagine-video",
[
IO.Int.Input(
"duration",
default=8,
min=2,
max=10,
step=1,
tooltip="Length of the extension in seconds.",
display_mode=IO.NumberDisplay.slider,
),
],
),
],
tooltip="The model to use for video extension.",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed to determine if node should re-run; "
"actual results are nondeterministic regardless of seed.",
),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model.duration"]),
expr="""
(
$dur := $lookup(widgets, "model.duration");
{
"type": "range_usd",
"min_usd": (0.02 + 0.05 * $dur) * 1.43,
"max_usd": (0.15 + 0.05 * $dur) * 1.43
}
)
""",
),
)
@classmethod
async def execute(
cls,
prompt: str,
video: Input.Video,
model: dict,
seed: int,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)
validate_video_duration(video, min_duration=2, max_duration=15)
video_size = get_fs_object_size(video.get_stream_source())
if video_size > 50 * 1024 * 1024:
raise ValueError(f"Video size ({video_size / 1024 / 1024:.1f}MB) exceeds 50MB limit.")
initial_response = await sync_op(
cls,
ApiEndpoint(path="/proxy/xai/v1/videos/extensions", method="POST"),
data=VideoExtensionRequest(
prompt=prompt,
video=InputUrlObject(url=await upload_video_to_comfyapi(cls, video)),
duration=model["duration"],
),
response_model=VideoGenerationResponse,
)
response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"),
status_extractor=lambda r: r.status if r.status is not None else "complete",
response_model=VideoStatusResponse,
price_extractor=_extract_grok_video_price,
)
return IO.NodeOutput(await download_url_to_video_output(response.video.url))
class GrokExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@@ -469,7 +718,9 @@ class GrokExtension(ComfyExtension):
GrokImageNode,
GrokImageEditNode,
GrokVideoNode,
GrokVideoReferenceNode,
GrokVideoEditNode,
GrokVideoExtendNode,
]

View File

@@ -132,7 +132,7 @@ class TencentTextToModelNode(IO.ComfyNode):
tooltip="The LowPoly option is unavailable for the `3.1` model.",
),
IO.String.Input("prompt", multiline=True, default="", tooltip="Supports up to 1024 characters."),
IO.Int.Input("face_count", default=500000, min=40000, max=1500000),
IO.Int.Input("face_count", default=500000, min=3000, max=1500000),
IO.DynamicCombo.Input(
"generate_type",
options=[
@@ -251,7 +251,7 @@ class TencentImageToModelNode(IO.ComfyNode):
IO.Image.Input("image_left", optional=True),
IO.Image.Input("image_right", optional=True),
IO.Image.Input("image_back", optional=True),
IO.Int.Input("face_count", default=500000, min=40000, max=1500000),
IO.Int.Input("face_count", default=500000, min=3000, max=1500000),
IO.DynamicCombo.Input(
"generate_type",
options=[
@@ -422,6 +422,7 @@ class TencentModelTo3DUVNode(IO.ComfyNode):
outputs=[
IO.File3DOBJ.Output(display_name="OBJ"),
IO.File3DFBX.Output(display_name="FBX"),
IO.Image.Output(display_name="uv_image"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
@@ -468,9 +469,16 @@ class TencentModelTo3DUVNode(IO.ComfyNode):
response_model=To3DProTaskResultResponse,
status_extractor=lambda r: r.Status,
)
uv_image_file = get_file_from_response(result.ResultFile3Ds, "uv_image", raise_if_not_found=False)
uv_image = (
await download_url_to_image_tensor(uv_image_file.Url)
if uv_image_file is not None
else torch.zeros(1, 1, 1, 3)
)
return IO.NodeOutput(
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj"),
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "fbx").Url, "fbx"),
uv_image,
)

View File

@@ -145,7 +145,20 @@ class ReveImageCreateNode(IO.ComfyNode):
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.03432,"format":{"approximate":true,"note":"(base)"}}""",
depends_on=IO.PriceBadgeDepends(
widgets=["upscale", "upscale.upscale_factor"],
),
expr="""
(
$factor := $lookup(widgets, "upscale.upscale_factor");
$fmt := {"approximate": true, "note": "(base)"};
widgets.upscale = "enabled" ? (
$factor = 4 ? {"type": "usd", "usd": 0.0762, "format": $fmt}
: $factor = 3 ? {"type": "usd", "usd": 0.0591, "format": $fmt}
: {"type": "usd", "usd": 0.0457, "format": $fmt}
) : {"type": "usd", "usd": 0.03432, "format": $fmt}
)
""",
),
)
@@ -225,13 +238,21 @@ class ReveImageEditNode(IO.ComfyNode):
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(
widgets=["model"],
widgets=["model", "upscale", "upscale.upscale_factor"],
),
expr="""
(
$fmt := {"approximate": true, "note": "(base)"};
$isFast := $contains(widgets.model, "fast");
$base := $isFast ? 0.01001 : 0.0572;
{"type": "usd", "usd": $base, "format": {"approximate": true, "note": "(base)"}}
$enabled := widgets.upscale = "enabled";
$factor := $lookup(widgets, "upscale.upscale_factor");
$isFast
? {"type": "usd", "usd": 0.01001, "format": $fmt}
: $enabled ? (
$factor = 4 ? {"type": "usd", "usd": 0.0991, "format": $fmt}
: $factor = 3 ? {"type": "usd", "usd": 0.0819, "format": $fmt}
: {"type": "usd", "usd": 0.0686, "format": $fmt}
) : {"type": "usd", "usd": 0.0572, "format": $fmt}
)
""",
),
@@ -327,13 +348,21 @@ class ReveImageRemixNode(IO.ComfyNode):
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(
widgets=["model"],
widgets=["model", "upscale", "upscale.upscale_factor"],
),
expr="""
(
$fmt := {"approximate": true, "note": "(base)"};
$isFast := $contains(widgets.model, "fast");
$base := $isFast ? 0.01001 : 0.0572;
{"type": "usd", "usd": $base, "format": {"approximate": true, "note": "(base)"}}
$enabled := widgets.upscale = "enabled";
$factor := $lookup(widgets, "upscale.upscale_factor");
$isFast
? {"type": "usd", "usd": 0.01001, "format": $fmt}
: $enabled ? (
$factor = 4 ? {"type": "usd", "usd": 0.0991, "format": $fmt}
: $factor = 3 ? {"type": "usd", "usd": 0.0819, "format": $fmt}
: {"type": "usd", "usd": 0.0686, "format": $fmt}
) : {"type": "usd", "usd": 0.0572, "format": $fmt}
)
""",
),

View File

@@ -38,6 +38,7 @@ from comfy_api_nodes.util import (
UPSCALER_MODELS_MAP = {
"Starlight (Astra) Fast": "slf-1",
"Starlight (Astra) Creative": "slc-1",
"Starlight Precise 2.5": "slp-2.5",
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,6 +1,5 @@
import asyncio
import bisect
import gc
import itertools
import psutil
import time
@@ -475,6 +474,10 @@ class LRUCache(BasicCache):
self._mark_used(node_id)
return await self._set_immediate(node_id, value)
def set_local(self, node_id, value):
self._mark_used(node_id)
BasicCache.set_local(self, node_id, value)
async def ensure_subcache_for(self, node_id, children_ids):
# Just uses subcaches for tracking 'live' nodes
await super()._ensure_subcache(node_id, children_ids)
@@ -489,15 +492,10 @@ class LRUCache(BasicCache):
return self
#Iterating the cache for usage analysis might be expensive, so if we trigger make sure
#to take a chunk out to give breathing space on high-node / low-ram-per-node flows.
#Small baseline weight used when a cache entry has no measurable CPU tensors.
#Keeps unknown-sized entries in eviction scoring without dominating tensor-backed entries.
RAM_CACHE_HYSTERESIS = 1.1
#This is kinda in GB but not really. It needs to be non-zero for the below heuristic
#and as long as Multi GB models dwarf this it will approximate OOM scoring OK
RAM_CACHE_DEFAULT_RAM_USAGE = 0.1
RAM_CACHE_DEFAULT_RAM_USAGE = 0.05
#Exponential bias towards evicting older workflows so garbage will be taken out
#in constantly changing setups.
@@ -521,19 +519,17 @@ class RAMPressureCache(LRUCache):
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
return await super().get(node_id)
def poll(self, ram_headroom):
def _ram_gb():
return psutil.virtual_memory().available / (1024**3)
def set_local(self, node_id, value):
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
super().set_local(node_id, value)
if _ram_gb() > ram_headroom:
return
gc.collect()
if _ram_gb() > ram_headroom:
def ram_release(self, target):
if psutil.virtual_memory().available >= target:
return
clean_list = []
for key, (outputs, _), in self.cache.items():
for key, cache_entry in self.cache.items():
oom_score = RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** (self.generation - self.used_generation[key])
ram_usage = RAM_CACHE_DEFAULT_RAM_USAGE
@@ -542,22 +538,20 @@ class RAMPressureCache(LRUCache):
if outputs is None:
return
for output in outputs:
if isinstance(output, list):
if isinstance(output, (list, tuple)):
scan_list_for_ram_usage(output)
elif isinstance(output, torch.Tensor) and output.device.type == 'cpu':
#score Tensors at a 50% discount for RAM usage as they are likely to
#be high value intermediates
ram_usage += (output.numel() * output.element_size()) * 0.5
elif hasattr(output, "get_ram_usage"):
ram_usage += output.get_ram_usage()
scan_list_for_ram_usage(outputs)
ram_usage += output.numel() * output.element_size()
scan_list_for_ram_usage(cache_entry.outputs)
oom_score *= ram_usage
#In the case where we have no information on the node ram usage at all,
#break OOM score ties on the last touch timestamp (pure LRU)
bisect.insort(clean_list, (oom_score, self.timestamps[key], key))
while _ram_gb() < ram_headroom * RAM_CACHE_HYSTERESIS and clean_list:
while psutil.virtual_memory().available < target and clean_list:
_, _, key = clean_list.pop()
del self.cache[key]
gc.collect()
self.used_generation.pop(key, None)
self.timestamps.pop(key, None)
self.children.pop(key, None)

View File

@@ -80,7 +80,7 @@ class EmptyAceStepLatentAudio(io.ComfyNode):
@classmethod
def execute(cls, seconds, batch_size) -> io.NodeOutput:
length = int(seconds * 44100 / 512 / 8)
latent = torch.zeros([batch_size, 8, 16, length], device=comfy.model_management.intermediate_device())
latent = torch.zeros([batch_size, 8, 16, length], device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
return io.NodeOutput({"samples": latent, "type": "audio"})
@@ -103,7 +103,7 @@ class EmptyAceStep15LatentAudio(io.ComfyNode):
@classmethod
def execute(cls, seconds, batch_size) -> io.NodeOutput:
length = round((seconds * 48000 / 1920))
latent = torch.zeros([batch_size, 64, length], device=comfy.model_management.intermediate_device())
latent = torch.zeros([batch_size, 64, length], device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
return io.NodeOutput({"samples": latent, "type": "audio"})
class ReferenceAudio(io.ComfyNode):

View File

@@ -0,0 +1,92 @@
from __future__ import annotations
import numpy as np
from comfy_api.latest import ComfyExtension, io
from comfy_api.input import CurveInput
from typing_extensions import override
class CurveEditor(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="CurveEditor",
display_name="Curve Editor",
category="utils",
inputs=[
io.Curve.Input("curve"),
io.Histogram.Input("histogram", optional=True),
],
outputs=[
io.Curve.Output("curve"),
],
)
@classmethod
def execute(cls, curve, histogram=None) -> io.NodeOutput:
result = CurveInput.from_raw(curve)
ui = {}
if histogram is not None:
ui["histogram"] = histogram if isinstance(histogram, list) else list(histogram)
return io.NodeOutput(result, ui=ui) if ui else io.NodeOutput(result)
class ImageHistogram(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ImageHistogram",
display_name="Image Histogram",
category="utils",
inputs=[
io.Image.Input("image"),
],
outputs=[
io.Histogram.Output("rgb"),
io.Histogram.Output("luminance"),
io.Histogram.Output("red"),
io.Histogram.Output("green"),
io.Histogram.Output("blue"),
],
)
@classmethod
def execute(cls, image) -> io.NodeOutput:
img = image[0].cpu().numpy()
img_uint8 = np.clip(img * 255, 0, 255).astype(np.uint8)
def bincount(data):
return np.bincount(data.ravel(), minlength=256)[:256]
hist_r = bincount(img_uint8[:, :, 0])
hist_g = bincount(img_uint8[:, :, 1])
hist_b = bincount(img_uint8[:, :, 2])
# Average of R, G, B histograms (same as Photoshop's RGB composite)
rgb = ((hist_r + hist_g + hist_b) // 3).tolist()
# ITU-R BT.709-6, Item 3.2 (p.6) — Derivation of luminance signal
# https://www.itu.int/rec/R-REC-BT.709-6-201506-I/en
lum = 0.2126 * img[:, :, 0] + 0.7152 * img[:, :, 1] + 0.0722 * img[:, :, 2]
luminance = bincount(np.clip(lum * 255, 0, 255).astype(np.uint8)).tolist()
return io.NodeOutput(
rgb,
luminance,
hist_r.tolist(),
hist_g.tolist(),
hist_b.tolist(),
)
class CurveExtension(ComfyExtension):
@override
async def get_node_list(self):
return [CurveEditor, ImageHistogram]
async def comfy_entrypoint():
return CurveExtension()

View File

@@ -87,7 +87,9 @@ class SizeModeInput(TypedDict):
MAX_IMAGES = 5 # u_image0-4
MAX_UNIFORMS = 5 # u_float0-4, u_int0-4
MAX_UNIFORMS = 20 # u_float0-19, u_int0-19
MAX_BOOLS = 10 # u_bool0-9
MAX_CURVES = 4 # u_curve0-3 (1D LUT textures)
MAX_OUTPUTS = 4 # fragColor0-3 (MRT)
# Vertex shader using gl_VertexID trick - no VBO needed.
@@ -497,6 +499,8 @@ def _render_shader_batch(
image_batches: list[list[np.ndarray]],
floats: list[float],
ints: list[int],
bools: list[bool] | None = None,
curves: list[np.ndarray] | None = None,
) -> list[list[np.ndarray]]:
"""
Render a fragment shader for multiple batches efficiently.
@@ -511,6 +515,8 @@ def _render_shader_batch(
image_batches: List of batches, each batch is a list of input images (H, W, C) float32 [0,1]
floats: List of float uniforms
ints: List of int uniforms
bools: List of bool uniforms (passed as int 0/1 to GLSL bool uniforms)
curves: List of 1D LUT arrays (float32) of arbitrary size for u_curve0-N
Returns:
List of batch outputs, each is a list of output images (H, W, 4) float32 [0,1]
@@ -533,11 +539,17 @@ def _render_shader_batch(
# Detect multi-pass rendering
num_passes = _detect_pass_count(fragment_code)
if bools is None:
bools = []
if curves is None:
curves = []
# Track resources for cleanup
program = None
fbo = None
output_textures = []
input_textures = []
curve_textures = []
ping_pong_textures = []
ping_pong_fbos = []
@@ -624,6 +636,28 @@ def _render_shader_batch(
if loc >= 0:
gl.glUniform1i(loc, v)
for i, v in enumerate(bools):
loc = gl.glGetUniformLocation(program, f"u_bool{i}")
if loc >= 0:
gl.glUniform1i(loc, 1 if v else 0)
# Create 1D LUT textures for curves (bound after image texture units)
for i, lut in enumerate(curves):
tex = gl.glGenTextures(1)
curve_textures.append(tex)
unit = MAX_IMAGES + i
gl.glActiveTexture(gl.GL_TEXTURE0 + unit)
gl.glBindTexture(gl.GL_TEXTURE_2D, tex)
gl.glTexImage2D(gl.GL_TEXTURE_2D, 0, gl.GL_R32F, len(lut), 1, 0, gl.GL_RED, gl.GL_FLOAT, lut)
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR)
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR)
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_S, gl.GL_CLAMP_TO_EDGE)
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_T, gl.GL_CLAMP_TO_EDGE)
loc = gl.glGetUniformLocation(program, f"u_curve{i}")
if loc >= 0:
gl.glUniform1i(loc, unit)
# Get u_pass uniform location for multi-pass
pass_loc = gl.glGetUniformLocation(program, "u_pass")
@@ -718,6 +752,8 @@ def _render_shader_batch(
for tex in input_textures:
gl.glDeleteTextures(int(tex))
for tex in curve_textures:
gl.glDeleteTextures(int(tex))
for tex in output_textures:
gl.glDeleteTextures(int(tex))
for tex in ping_pong_textures:
@@ -754,6 +790,20 @@ class GLSLShader(io.ComfyNode):
max=MAX_UNIFORMS,
)
bool_template = io.Autogrow.TemplatePrefix(
io.Boolean.Input("bool", default=False),
prefix="u_bool",
min=0,
max=MAX_BOOLS,
)
curve_template = io.Autogrow.TemplatePrefix(
io.Curve.Input("curve"),
prefix="u_curve",
min=0,
max=MAX_CURVES,
)
return io.Schema(
node_id="GLSLShader",
display_name="GLSL Shader",
@@ -762,6 +812,8 @@ class GLSLShader(io.ComfyNode):
"Apply GLSL ES fragment shaders to images. "
"u_resolution (vec2) is always available."
),
is_experimental=True,
has_intermediate_output=True,
inputs=[
io.String.Input(
"fragment_shader",
@@ -796,6 +848,8 @@ class GLSLShader(io.ComfyNode):
io.Autogrow.Input("images", template=image_template, tooltip=f"Images are available as u_image0-{MAX_IMAGES-1} (sampler2D) in the shader code"),
io.Autogrow.Input("floats", template=float_template, tooltip=f"Floats are available as u_float0-{MAX_UNIFORMS-1} in the shader code"),
io.Autogrow.Input("ints", template=int_template, tooltip=f"Ints are available as u_int0-{MAX_UNIFORMS-1} in the shader code"),
io.Autogrow.Input("bools", template=bool_template, tooltip=f"Booleans are available as u_bool0-{MAX_BOOLS-1} (bool) in the shader code"),
io.Autogrow.Input("curves", template=curve_template, tooltip=f"Curves are available as u_curve0-{MAX_CURVES-1} (sampler2D, 1D LUT) in the shader code. Sample with texture(u_curve0, vec2(x, 0.5)).r"),
],
outputs=[
io.Image.Output(display_name="IMAGE0", tooltip="Available via layout(location = 0) out vec4 fragColor0 in the shader code"),
@@ -813,13 +867,19 @@ class GLSLShader(io.ComfyNode):
images: io.Autogrow.Type,
floats: io.Autogrow.Type = None,
ints: io.Autogrow.Type = None,
bools: io.Autogrow.Type = None,
curves: io.Autogrow.Type = None,
**kwargs,
) -> io.NodeOutput:
image_list = [v for v in images.values() if v is not None]
float_list = (
[v if v is not None else 0.0 for v in floats.values()] if floats else []
)
int_list = [v if v is not None else 0 for v in ints.values()] if ints else []
bool_list = [v if v is not None else False for v in bools.values()] if bools else []
curve_luts = [v.to_lut().astype(np.float32) for v in curves.values() if v is not None] if curves else []
if not image_list:
raise ValueError("At least one input image is required")
@@ -846,6 +906,8 @@ class GLSLShader(io.ComfyNode):
image_batches,
float_list,
int_list,
bool_list,
curve_luts,
)
# Collect outputs into tensors

View File

@@ -59,6 +59,7 @@ class ImageCropV2(IO.ComfyNode):
display_name="Image Crop",
category="image/transform",
essentials_category="Image Tools",
has_intermediate_output=True,
inputs=[
IO.Image.Input("image"),
IO.BoundingBox.Input("crop_region", component="ImageCrop"),

View File

@@ -3,6 +3,7 @@ import node_helpers
import torch
import comfy.model_management
import comfy.model_sampling
import comfy.samplers
import comfy.utils
import math
import numpy as np
@@ -682,6 +683,84 @@ class LTXVSeparateAVLatent(io.ComfyNode):
return io.NodeOutput(video_latent, audio_latent)
class LTXVReferenceAudio(io.ComfyNode):
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="LTXVReferenceAudio",
display_name="LTXV Reference Audio (ID-LoRA)",
category="conditioning/audio",
description="Set reference audio for ID-LoRA speaker identity transfer. Encodes a reference audio clip into the conditioning and optionally patches the model with identity guidance (extra forward pass without reference, amplifying the speaker identity effect).",
inputs=[
io.Model.Input("model"),
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Audio.Input("reference_audio", tooltip="Reference audio clip whose speaker identity to transfer. ~5 seconds recommended (training duration). Shorter or longer clips may degrade voice identity transfer."),
io.Vae.Input(id="audio_vae", display_name="Audio VAE", tooltip="LTXV Audio VAE for encoding."),
io.Float.Input("identity_guidance_scale", default=3.0, min=0.0, max=100.0, step=0.01, round=0.01, tooltip="Strength of identity guidance. Runs an extra forward pass without reference each step to amplify speaker identity. Set to 0 to disable (no extra pass)."),
io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001, advanced=True, tooltip="Start of the sigma range where identity guidance is active."),
io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001, advanced=True, tooltip="End of the sigma range where identity guidance is active."),
],
outputs=[
io.Model.Output(),
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
],
)
@classmethod
def execute(cls, model, positive, negative, reference_audio, audio_vae, identity_guidance_scale, start_percent, end_percent) -> io.NodeOutput:
# Encode reference audio to latents and patchify
audio_latents = audio_vae.encode(reference_audio)
b, c, t, f = audio_latents.shape
ref_tokens = audio_latents.permute(0, 2, 1, 3).reshape(b, t, c * f)
ref_audio = {"tokens": ref_tokens}
positive = node_helpers.conditioning_set_values(positive, {"ref_audio": ref_audio})
negative = node_helpers.conditioning_set_values(negative, {"ref_audio": ref_audio})
# Patch model with identity guidance
m = model.clone()
scale = identity_guidance_scale
model_sampling = m.get_model_object("model_sampling")
sigma_start = model_sampling.percent_to_sigma(start_percent)
sigma_end = model_sampling.percent_to_sigma(end_percent)
def post_cfg_function(args):
if scale == 0:
return args["denoised"]
sigma = args["sigma"]
sigma_ = sigma[0].item()
if sigma_ > sigma_start or sigma_ < sigma_end:
return args["denoised"]
cond_pred = args["cond_denoised"]
cond = args["cond"]
cfg_result = args["denoised"]
model_options = args["model_options"].copy()
x = args["input"]
# Strip ref_audio from conditioning for the no-reference pass
noref_cond = []
for entry in cond:
new_entry = entry.copy()
mc = new_entry.get("model_conds", {}).copy()
mc.pop("ref_audio", None)
new_entry["model_conds"] = mc
noref_cond.append(new_entry)
(pred_noref,) = comfy.samplers.calc_cond_batch(
args["model"], [noref_cond], x, sigma, model_options
)
return cfg_result + (cond_pred - pred_noref) * scale
m.set_model_sampler_post_cfg_function(post_cfg_function)
return io.NodeOutput(m, positive, negative)
class LtxvExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
@@ -697,6 +776,7 @@ class LtxvExtension(ComfyExtension):
LTXVCropGuides,
LTXVConcatAVLatent,
LTXVSeparateAVLatent,
LTXVReferenceAudio,
]

Some files were not shown because too many files have changed in this diff Show More