Compare commits

..

125 Commits

Author SHA1 Message Date
Jedrzej Kosinski
48deb15c0e Simplify multigpu dispatch: run all devices on pool threads (#13340)
Benchmarked hybrid (main thread + pool) vs all-pool on 2x RTX 4090
with SD1.5 and NetaYume models. No meaningful performance difference
(within noise). All-pool is simpler: eliminates the main_device
special case, main_batch_tuple deferred execution, and the 3-way
branch in the dispatch loop.
2026-04-09 01:15:57 -07:00
Jedrzej Kosinski
4b93c4360f Implement persistent thread pool for multi-GPU CFG splitting (#13329)
Replace per-step thread create/destroy in _calc_cond_batch_multigpu with a
persistent MultiGPUThreadPool. Each worker thread calls torch.cuda.set_device()
once at startup, preserving compiled kernel caches across diffusion steps.

- Add MultiGPUThreadPool class in comfy/multigpu.py
- Create pool in CFGGuider.outer_sample(), shut down in finally block
- Main thread handles its own device batch directly for zero overhead
- Falls back to sequential execution if no pool is available
2026-04-08 05:39:07 -07:00
Jedrzej Kosinski
da3864436c Merge remote-tracking branch 'origin/master' into worksplit-multigpu 2026-04-08 05:08:38 -07:00
Jedrzej Kosinski
b418fb1582 Fix device mismatch: update LoadedModel.device when _switch_parent swaps to parent patcher
When a multigpu clone ModelPatcher is garbage collected, LoadedModel._switch_parent
switches the weakref to point at the parent (main) ModelPatcher. However, it was not
updating LoadedModel.device, leaving it with the old clone's device (e.g., cuda:1).
On subsequent runs, this stale device was passed to ModelPatcherDynamic.load(), causing
an assertion failure (device_to != self.load_device).

Amp-Thread-ID: https://ampcode.com/threads/T-019d3f5c-28c5-72c9-abed-34681f1b54ba
Co-authored-by: Amp <amp@ampcode.com>
2026-03-30 08:59:38 -07:00
Jedrzej Kosinski
20803749c3 Add detailed multigpu debug logging to load_models_gpu
Amp-Thread-ID: https://ampcode.com/threads/T-019d3f5c-28c5-72c9-abed-34681f1b54ba
Co-authored-by: Amp <amp@ampcode.com>
2026-03-30 08:53:36 -07:00
Jedrzej Kosinski
3fab720be9 Add debug logging for device mismatch in ModelPatcherDynamic.load
Amp-Thread-ID: https://ampcode.com/threads/T-019d3f5c-28c5-72c9-abed-34681f1b54ba
Co-authored-by: Amp <amp@ampcode.com>
2026-03-30 08:45:55 -07:00
Jedrzej Kosinski
afdddcee66 Re-enable comfy-kitchen cuda backend for multigpu testing
Amp-Thread-ID: https://ampcode.com/threads/T-019d3f5c-28c5-72c9-abed-34681f1b54ba
Co-authored-by: Amp <amp@ampcode.com>
2026-03-30 08:32:52 -07:00
Jedrzej Kosinski
1d8e379f41 Rename MultiGPU Work Units to MultiGPU CFG Split
Amp-Thread-ID: https://ampcode.com/threads/T-019d3ee9-19d5-767a-9d7a-e50cbbef815b
Co-authored-by: Amp <amp@ampcode.com>
2026-03-30 08:00:20 -07:00
Jedrzej Kosinski
5f4fcd19e7 Simplify multigpu nodes: default max_gpus=2, remove gpu_options input, disable Options node
Amp-Thread-ID: https://ampcode.com/threads/T-019d3ee9-19d5-767a-9d7a-e50cbbef815b
Co-authored-by: Amp <amp@ampcode.com>
2026-03-30 07:30:32 -07:00
Jedrzej Kosinski
d52dcbc88f Rewrite multigpu nodes to V3 format
Amp-Thread-ID: https://ampcode.com/threads/T-019d3ee9-19d5-767a-9d7a-e50cbbef815b
Co-authored-by: Amp <amp@ampcode.com>
2026-03-30 07:23:13 -07:00
Jedrzej Kosinski
84f465e791 Set CUDA device at start of multigpu threads to avoid multithreading bugs
Amp-Thread-ID: https://ampcode.com/threads/T-019d3ee9-19d5-767a-9d7a-e50cbbef815b
Co-authored-by: Amp <amp@ampcode.com>
2026-03-30 07:07:54 -07:00
Jedrzej Kosinski
be35378986 Merge branch 'master' into worksplit-multigpu
Amp-Thread-ID: https://ampcode.com/threads/T-019d3ee9-19d5-767a-9d7a-e50cbbef815b
Co-authored-by: Amp <amp@ampcode.com>

# Conflicts:
#	comfy/samplers.py
2026-03-30 06:24:55 -07:00
Jedrzej Kosinski
f410d28b33 Merge origin/master into worksplit-multigpu
Amp-Thread-ID: https://ampcode.com/threads/T-019d009d-e059-7623-85ca-401168168516
Co-authored-by: Amp <amp@ampcode.com>
2026-03-18 04:21:30 -07:00
Jedrzej Kosinski
f4b99bc623 Made multigpu deepclone load model from disk to avoid needing to deepclone actual model object, fixed issues with merge, turn off cuda backend as it causes device mismatch issue with rope (and potentially other ops), will investigate 2026-02-17 04:55:00 -08:00
Jedrzej Kosinski
df2fd4c869 Merge branch 'master' into worksplit-multigpu 2026-02-17 02:53:06 -08:00
Jedrzej Kosinski
4661d1db5a Bring patches changes from _calc_cond_batch into _calc_cond_batch_multigpu 2025-10-15 17:34:36 -07:00
Jedrzej Kosinski
b326a544d5 Merge branch 'master' into worksplit-multigpu 2025-10-15 17:33:02 -07:00
Jedrzej Kosinski
d89dd5f0b0 Satisfy ruff 2025-10-13 22:00:34 -07:00
Jedrzej Kosinski
8cbbf0be6c Merge branch 'master' into worksplit-multigpu 2025-10-13 21:53:14 -07:00
Jedrzej Kosinski
c2115a4bac Merge branch 'master' into worksplit-multigpu 2025-09-24 23:45:26 -07:00
Jedrzej Kosinski
bb44c2ecb9 Merge branch 'master' into worksplit-multigpu 2025-09-18 14:20:27 -07:00
Jedrzej Kosinski
efcd8280d6 Merge branch 'master' into worksplit-multigpu 2025-09-11 20:59:47 -07:00
Jedrzej Kosinski
9e9c129cd0 Merge remote-tracking branch 'origin/master' into worksplit-multigpu 2025-08-29 23:36:19 -07:00
Jedrzej Kosinski
ac14ee68c0 Merge branch 'master' into worksplit-multigpu 2025-08-18 19:51:24 -07:00
Jedrzej Kosinski
2c8f485434 Merge branch 'master' into worksplit-multigpu 2025-08-18 00:29:52 -07:00
Jedrzej Kosinski
383f9b34cb Merge branch 'master' into worksplit-multigpu 2025-08-17 16:02:44 -07:00
Jedrzej Kosinski
b0741c7e5b Merge branch 'master' into worksplit-multigpu 2025-08-15 16:50:04 -07:00
Jedrzej Kosinski
1489399cb5 Merge branch 'master' into worksplit-multigpu 2025-08-13 19:47:08 -07:00
Jedrzej Kosinski
3677943fa5 Merge branch 'master' into worksplit-multigpu 2025-08-13 14:06:09 -07:00
Jedrzej Kosinski
cfb63bfcd7 Merge branch 'worksplit-multigpu' of https://github.com/comfyanonymous/ComfyUI into worksplit-multigpu 2025-08-11 14:09:58 -07:00
Jedrzej Kosinski
962c3c832c Merge branch 'master' into worksplit-multigpu 2025-08-11 14:09:41 -07:00
Jedrzej Kosinski
6ea69369ce Merge branch 'master' into worksplit-multigpu 2025-08-07 23:24:02 -07:00
Jedrzej Kosinski
b4f559b34d Merge branch 'master' into worksplit-multigpu 2025-08-04 20:23:19 -07:00
Jedrzej Kosinski
df122a7dba Merge branch 'master' into worksplit-multigpu 2025-08-01 12:31:57 -07:00
Jedrzej Kosinski
67e906aa64 Merge branch 'master' into worksplit-multigpu 2025-07-31 04:00:22 -07:00
Jedrzej Kosinski
382f84a826 Merge branch 'master' into worksplit-multigpu 2025-07-29 17:17:29 -07:00
Jedrzej Kosinski
9cca36fa2b Merge branch 'master' into worksplit-multigpu 2025-07-29 12:47:36 -07:00
Jedrzej Kosinski
5d5024296d Merge branch 'master' into worksplit-multigpu 2025-07-28 06:17:24 -07:00
Jedrzej Kosinski
3b90a30178 Merge branch 'master' into worksplit-multigpu-wip 2025-07-27 01:03:25 -07:00
Jedrzej Kosinski
3c4104652b Merge branch 'master' into worksplit-multigpu-wip 2025-07-22 11:42:23 -07:00
kosinkadink1@gmail.com
9855baaab3 Merge branch 'master' into worksplit-multigpu 2025-07-09 03:57:30 -05:00
Jedrzej Kosinski
d53479a197 Merge branch 'master' into worksplit-multigpu 2025-07-01 17:33:05 -05:00
Jedrzej Kosinski
443a795850 Merge branch 'master' into worksplit-multigpu 2025-06-24 00:49:24 -05:00
Jedrzej Kosinski
431dec8e53 Merge branch 'worksplit-multigpu' of https://github.com/comfyanonymous/ComfyUI into worksplit-multigpu 2025-06-24 00:48:58 -05:00
Jedrzej Kosinski
44e053c26d Improve error handling for multigpu threads 2025-06-24 00:48:51 -05:00
Jedrzej Kosinski
1ae98932f1 Merge branch 'master' into worksplit-multigpu 2025-06-17 04:58:56 -05:00
kosinkadink1@gmail.com
0336b0ace8 Merge branch 'master' into worksplit-multigpu 2025-06-01 02:39:26 -07:00
kosinkadink1@gmail.com
8ae25235ec Merge branch 'master' into worksplit-multigpu 2025-05-21 12:01:27 -07:00
Jedrzej Kosinski
9726eac475 Merge branch 'master' into worksplit-multigpu 2025-05-12 19:29:13 -05:00
Jedrzej Kosinski
272e8d42c1 Merge branch 'master' into worksplit-multigpu 2025-04-22 22:40:00 -05:00
Jedrzej Kosinski
6211d2be5a Merge branch 'master' into worksplit-multigpu 2025-04-19 17:36:23 -05:00
Jedrzej Kosinski
8be711715c Make unload_all_models account for all devices 2025-04-19 17:35:54 -05:00
Jedrzej Kosinski
b5cccf1325 Merge branch 'master' into worksplit-multigpu 2025-04-18 15:39:34 -05:00
Jedrzej Kosinski
2a54a904f4 Merge branch 'master' into worksplit-multigpu 2025-04-16 19:26:48 -05:00
Jedrzej Kosinski
ed6f92c975 Merge branch 'master' into worksplit-multigpu 2025-04-16 16:53:57 -05:00
Jedrzej Kosinski
adc66c0698 Merge branch 'master' into worksplit-multigpu 2025-04-16 14:23:56 -05:00
Jedrzej Kosinski
ccd5c01e5a Merge branch 'master' into worksplit-multigpu 2025-04-09 09:17:12 -05:00
Jedrzej Kosinski
2fa9affcc1 Merge branch 'master' into worksplit-multigpu 2025-04-08 22:52:17 -05:00
Jedrzej Kosinski
407a5a656f Rollback core of last commit due to weird behavior 2025-03-28 02:48:11 -05:00
kosinkadink1@gmail.com
9ce9ff8ef8 Allow chained MultiGPU Work Unit nodes to affect max_gpus present on ModelPatcher clone 2025-03-28 15:29:44 +08:00
Jedrzej Kosinski
63567c0ce8 Merge branch 'master' into worksplit-multigpu 2025-03-27 22:36:46 -05:00
Jedrzej Kosinski
a786ce5ead Merge branch 'master' into worksplit-multigpu 2025-03-26 22:26:26 -05:00
Jedrzej Kosinski
4879b47648 Merge branch 'master' into worksplit-multigpu 2025-03-18 22:19:32 -05:00
Jedrzej Kosinski
5ccec33c22 Merge branch 'worksplit-multigpu' of https://github.com/comfyanonymous/ComfyUI into worksplit-multigpu 2025-03-17 14:27:39 -05:00
Jedrzej Kosinski
219d3cd0d0 Merge branch 'master' into worksplit-multigpu 2025-03-17 14:26:35 -05:00
Jedrzej Kosinski
c4ba399475 Merge branch 'master' into worksplit-multigpu 2025-03-15 09:12:09 -05:00
Jedrzej Kosinski
cc928a786d Merge branch 'master' into worksplit-multigpu 2025-03-13 20:59:11 -05:00
Jedrzej Kosinski
6e144b98c4 Merge branch 'master' into worksplit-multigpu 2025-03-09 00:00:38 -06:00
Jedrzej Kosinski
6dca17bd2d Satisfy ruff linting 2025-03-03 23:08:29 -06:00
Jedrzej Kosinski
5080105c23 Merge branch 'master' into worksplit-multigpu 2025-03-03 22:56:53 -06:00
Jedrzej Kosinski
093914a247 Made MultiGPU Work Units node more robust by forcing ModelPatcher clones to match at sample time, reuse loaded MultiGPU clones, finalize MultiGPU Work Units node ID and name, small refactors/cleanup of logging and multigpu-related code 2025-03-03 22:56:13 -06:00
Jedrzej Kosinski
605893d3cf Merge branch 'master' into worksplit-multigpu 2025-02-24 19:23:16 -06:00
Jedrzej Kosinski
048f4f0b3a Merge branch 'master' into worksplit-multigpu 2025-02-17 19:35:58 -06:00
Jedrzej Kosinski
d2504fb701 Merge branch 'master' into worksplit-multigpu 2025-02-11 22:34:51 -06:00
Jedrzej Kosinski
b03763bca6 Merge branch 'multigpu_support' into worksplit-multigpu 2025-02-07 13:27:49 -06:00
Jedrzej Kosinski
476aa79b64 Let --cuda-device take in a string to allow multiple devices (or device order) to be chosen, print available devices on startup, potentially support MultiGPU Intel and Ascend setups 2025-02-06 08:44:07 -06:00
Jedrzej Kosinski
441cfd1a7a Merge branch 'master' into multigpu_support 2025-02-06 08:10:48 -06:00
Jedrzej Kosinski
99a5c1068a Merge branch 'master' into multigpu_support 2025-02-02 03:19:18 -06:00
Jedrzej Kosinski
02747cde7d Carry over change from _calc_cond_batch into _calc_cond_batch_multigpu 2025-01-29 11:10:23 -06:00
Jedrzej Kosinski
0b3233b4e2 Merge remote-tracking branch 'origin/master' into multigpu_support 2025-01-28 06:11:07 -06:00
Jedrzej Kosinski
eda866bf51 Extracted multigpu core code into multigpu.py, added load_balance_devices to get subdivision of work based on available devices and splittable work item count, added MultiGPU Options nodes to set relative_speed of specific devices; does not change behavior yet 2025-01-27 06:25:48 -06:00
Jedrzej Kosinski
e3298b84de Create proper MultiGPU Initialize node, create gpu_options to create scaffolding for asymmetrical GPU support 2025-01-26 09:34:20 -06:00
Jedrzej Kosinski
c7feef9060 Cast transformer_options for multigpu 2025-01-26 05:29:27 -06:00
Jedrzej Kosinski
51af7fa1b4 Fix multigpu ControlBase get_models and cleanup calls to avoid multiple calls of functions on multigpu_clones versions of controlnets 2025-01-25 06:05:01 -06:00
Jedrzej Kosinski
46969c380a Initial MultiGPU support for controlnets 2025-01-24 05:39:38 -06:00
Jedrzej Kosinski
5db4277449 Make sure additional_models are unloaded as well when perform 2025-01-23 19:06:05 -06:00
Jedrzej Kosinski
02a4d0ad7d Added unload_model_and_clones to model_management.py to allow unloading only relevant models 2025-01-23 01:20:00 -06:00
Jedrzej Kosinski
ef137ac0b6 Merge branch 'multigpu_support' of https://github.com/kosinkadink/ComfyUI into multigpu_support 2025-01-20 04:34:39 -06:00
Jedrzej Kosinski
328d4f16a9 Make WeightHooks compatible with MultiGPU, clean up some code 2025-01-20 04:34:26 -06:00
Jedrzej Kosinski
bdbcb85b8d Merge branch 'multigpu_support' of https://github.com/Kosinkadink/ComfyUI into multigpu_support 2025-01-20 00:51:42 -06:00
Jedrzej Kosinski
6c9e94bae7 Merge branch 'master' into multigpu_support 2025-01-20 00:51:37 -06:00
Jedrzej Kosinski
bfce723311 Initial work on multigpu_clone function, which will account for additional_models getting cloned 2025-01-17 03:31:28 -06:00
Jedrzej Kosinski
31f5458938 Merge branch 'master' into multigpu_support 2025-01-16 18:25:05 -06:00
Jedrzej Kosinski
2145a202eb Merge branch 'master' into multigpu_support 2025-01-15 19:58:28 -06:00
Jedrzej Kosinski
25818dc848 Added a 'max_gpus' input 2025-01-14 13:45:14 -06:00
Jedrzej Kosinski
198953cd08 Add nodes_multigpu.py to loaded nodes 2025-01-14 12:24:55 -06:00
Jedrzej Kosinski
ec16ee2f39 Merge branch 'master' into multigpu_support 2025-01-13 20:21:06 -06:00
Jedrzej Kosinski
d5088072fb Make test node for multigpu instead of storing it in just a local __init__.py 2025-01-13 20:20:25 -06:00
Jedrzej Kosinski
8d4b50158e Merge branch 'master' into multigpu_support 2025-01-11 20:16:42 -06:00
Jedrzej Kosinski
e88c6c03ff Fix cond_cat to not try to cast anything that doesn't have a 'to' function 2025-01-10 23:05:24 -06:00
Jedrzej Kosinski
d3cf2b7b24 Merge branch 'comfyanonymous:master' into multigpu_support 2025-01-10 20:24:37 -06:00
Jedrzej Kosinski
7448f02b7c Initial proof of concept of giving splitting cond sampling between multiple GPUs 2025-01-08 03:33:05 -06:00
Jedrzej Kosinski
871258aa72 Add get_all_torch_devices to get detected devices intended for current torch hardware device 2025-01-07 21:06:03 -06:00
Jedrzej Kosinski
66838ebd39 Merge branch 'comfyanonymous:master' into multigpu_support 2025-01-07 20:11:27 -06:00
Jedrzej Kosinski
7333281698 Clean up a typehint 2025-01-07 02:58:59 -06:00
Jedrzej Kosinski
3cd4c5cb0a Rename AddModelsHooks to AdditionalModelsHook, rename SetInjectionsHook to InjectionsHook (not yet implemented, but at least getting the naming figured out) 2025-01-07 02:22:49 -06:00
Jedrzej Kosinski
11c6d56037 Merge branch 'master' into hooks_part2 2025-01-07 01:01:53 -06:00
Jedrzej Kosinski
216fea15ee Made TransformerOptionsHook contribute to registered hooks properly, added some doc strings and removed a so-far unused variable 2025-01-07 00:59:18 -06:00
Jedrzej Kosinski
58bf8815c8 Add a get_injections function to ModelPatcher 2025-01-06 20:34:30 -06:00
Jedrzej Kosinski
1b38f5bf57 removed 4 whitespace lines to satisfy Ruff, 2025-01-06 17:11:12 -06:00
Jedrzej Kosinski
2724ac4a60 Merge branch 'master' into hooks_part2 2025-01-06 17:04:24 -06:00
Jedrzej Kosinski
f48f90e471 Make hook_scope functional for TransformerOptionsHook 2025-01-06 02:23:04 -06:00
Jedrzej Kosinski
6463c39ce0 Merge branch 'master' into hooks_part2 2025-01-06 01:28:26 -06:00
Jedrzej Kosinski
0a7e2ae787 Filter only registered hooks on self.conds in CFGGuider.sample 2025-01-06 01:04:29 -06:00
Jedrzej Kosinski
03a97b604a Fix performance of hooks when hooks are appended via Cond Pair Set Props nodes by properly caching between positive and negative conds, make hook_patches_backup behave as intended (in the case that something pre-registers WeightHooks on the ModelPatcher instead of registering it at sample time) 2025-01-06 01:03:59 -06:00
Jedrzej Kosinski
4446c86052 Made hook clone code sane, made clear ObjectPatchHook and SetInjectionsHook are not yet operational 2025-01-05 22:25:51 -06:00
Jedrzej Kosinski
8270ff312f Refactored 'registered' to be HookGroup instead of a list of Hooks, made AddModelsHook operational and compliant with should_register result, moved TransformerOptionsHook handling out of ModelPatcher.register_all_hook_patches, support patches in TransformerOptionsHook properly by casting any patches/wrappers/hooks to proper device at sample time 2025-01-05 21:07:02 -06:00
Jedrzej Kosinski
db2d7ad9ba Merge branch 'add_sample_sigmas' into hooks_part2 2025-01-05 15:45:13 -06:00
Jedrzej Kosinski
6620d86318 In inner_sample, change "sigmas" to "sampler_sigmas" in transformer_options to not conflict with the "sigmas" that will overwrite "sigmas" in _calc_cond_batch 2025-01-05 15:26:22 -06:00
Jedrzej Kosinski
111fd0cadf Refactored HookGroup to also store a dictionary of hooks separated by hook_type, modified necessary code to no longer need to manually separate out hooks by hook_type 2025-01-04 02:04:07 -06:00
Jedrzej Kosinski
776aa734e1 Refactor WrapperHook into TransformerOptionsHook, as there is no need to separate out Wrappers/Callbacks/Patches into different hook types (all affect transformer_options) 2025-01-04 01:02:21 -06:00
Jedrzej Kosinski
5a2ad032cb Cleaned up hooks.py, refactored Hook.should_register and add_hook_patches to use target_dict instead of target so that more information can be provided about the current execution environment if needed 2025-01-03 20:02:27 -06:00
Jedrzej Kosinski
d44295ef71 Merge branch 'master' into hooks_part2 2025-01-03 18:28:31 -06:00
Jedrzej Kosinski
bf21be066f Merge branch 'master' into hooks_part2 2024-12-30 14:16:22 -06:00
Jedrzej Kosinski
72bbf49349 Add 'sigmas' to transformer_options so that downstream code can know about the full scope of current sampling run, fix Hook Keyframes' guarantee_steps=1 inconsistent behavior with sampling split across different Sampling nodes/sampling runs by referencing 'sigmas' 2024-12-29 15:49:09 -06:00
37 changed files with 963 additions and 1477 deletions

View File

@@ -1,2 +0,0 @@
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build
pause

View File

@@ -139,9 +139,9 @@ Example:
"_quantization_metadata": {
"format_version": "1.0",
"layers": {
"model.layers.0.mlp.up_proj": {"format": "float8_e4m3fn"},
"model.layers.0.mlp.down_proj": {"format": "float8_e4m3fn"},
"model.layers.1.mlp.up_proj": {"format": "float8_e4m3fn"}
"model.layers.0.mlp.up_proj": "float8_e4m3fn",
"model.layers.0.mlp.down_proj": "float8_e4m3fn",
"model.layers.1.mlp.up_proj": "float8_e4m3fn"
}
}
}
@@ -165,4 +165,4 @@ Activation quantization (e.g., for FP8 Tensor Core operations) requires `input_s
3. **Compute scales**: Derive `input_scale` from collected statistics
4. **Store in checkpoint**: Save `input_scale` parameters alongside weights
The calibration dataset should be representative of your target use case. For diffusion models, this typically means a diverse set of prompts and generation parameters.
The calibration dataset should be representative of your target use case. For diffusion models, this typically means a diverse set of prompts and generation parameters.

View File

@@ -182,7 +182,7 @@
]
},
"widgets_values": [
0
50
]
},
{

View File

@@ -316,7 +316,7 @@
"step": 1
},
"widgets_values": [
0
30
]
},
{

View File

@@ -49,7 +49,7 @@ parser.add_argument("--temp-directory", type=str, default=None, help="Set the Co
parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory. Overrides --base-directory.")
parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use. All other devices will not be visible.")
parser.add_argument("--cuda-device", type=str, default=None, metavar="DEVICE_ID", help="Set the ids of cuda devices this instance will use. All other devices will not be visible.")
parser.add_argument("--default-device", type=int, default=None, metavar="DEFAULT_DEVICE_ID", help="Set the id of the default device, all other devices will stay visible.")
cm_group = parser.add_mutually_exclusive_group()
cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")

View File

@@ -15,13 +15,14 @@
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
from __future__ import annotations
import torch
from enum import Enum
import math
import os
import logging
import copy
import comfy.utils
import comfy.model_management
import comfy.model_detection
@@ -38,7 +39,7 @@ import comfy.ldm.hydit.controlnet
import comfy.ldm.flux.controlnet
import comfy.ldm.qwen_image.controlnet
import comfy.cldm.dit_embedder
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Union
if TYPE_CHECKING:
from comfy.hooks import HookGroup
@@ -64,6 +65,18 @@ class StrengthType(Enum):
CONSTANT = 1
LINEAR_UP = 2
class ControlIsolation:
'''Temporarily set a ControlBase object's previous_controlnet to None to prevent cascading calls.'''
def __init__(self, control: ControlBase):
self.control = control
self.orig_previous_controlnet = control.previous_controlnet
def __enter__(self):
self.control.previous_controlnet = None
def __exit__(self, *args):
self.control.previous_controlnet = self.orig_previous_controlnet
class ControlBase:
def __init__(self):
self.cond_hint_original = None
@@ -77,7 +90,7 @@ class ControlBase:
self.compression_ratio = 8
self.upscale_algorithm = 'nearest-exact'
self.extra_args = {}
self.previous_controlnet = None
self.previous_controlnet: Union[ControlBase, None] = None
self.extra_conds = []
self.strength_type = StrengthType.CONSTANT
self.concat_mask = False
@@ -85,6 +98,7 @@ class ControlBase:
self.extra_concat = None
self.extra_hooks: HookGroup = None
self.preprocess_image = lambda a: a
self.multigpu_clones: dict[torch.device, ControlBase] = {}
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
self.cond_hint_original = cond_hint
@@ -111,17 +125,38 @@ class ControlBase:
def cleanup(self):
if self.previous_controlnet is not None:
self.previous_controlnet.cleanup()
for device_cnet in self.multigpu_clones.values():
with ControlIsolation(device_cnet):
device_cnet.cleanup()
self.cond_hint = None
self.extra_concat = None
self.timestep_range = None
def get_models(self):
out = []
for device_cnet in self.multigpu_clones.values():
out += device_cnet.get_models_only_self()
if self.previous_controlnet is not None:
out += self.previous_controlnet.get_models()
return out
def get_models_only_self(self):
'Calls get_models, but temporarily sets previous_controlnet to None.'
with ControlIsolation(self):
return self.get_models()
def get_instance_for_device(self, device):
'Returns instance of this Control object intended for selected device.'
return self.multigpu_clones.get(device, self)
def deepclone_multigpu(self, load_device, autoregister=False):
'''
Create deep clone of Control object where model(s) is set to other devices.
When autoregister is set to True, the deep clone is also added to multigpu_clones dict.
'''
raise NotImplementedError("Classes inheriting from ControlBase should define their own deepclone_multigpu funtion.")
def get_extra_hooks(self):
out = []
if self.extra_hooks is not None:
@@ -130,7 +165,7 @@ class ControlBase:
out += self.previous_controlnet.get_extra_hooks()
return out
def copy_to(self, c):
def copy_to(self, c: ControlBase):
c.cond_hint_original = self.cond_hint_original
c.strength = self.strength
c.timestep_percent_range = self.timestep_percent_range
@@ -284,6 +319,14 @@ class ControlNet(ControlBase):
self.copy_to(c)
return c
def deepclone_multigpu(self, load_device, autoregister=False):
c = self.copy()
c.control_model = copy.deepcopy(c.control_model)
c.control_model_wrapped = comfy.model_patcher.ModelPatcher(c.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
if autoregister:
self.multigpu_clones[load_device] = c
return c
def get_models(self):
out = super().get_models()
out.append(self.control_model_wrapped)
@@ -906,6 +949,14 @@ class T2IAdapter(ControlBase):
self.copy_to(c)
return c
def deepclone_multigpu(self, load_device, autoregister=False):
c = self.copy()
c.t2i_model = copy.deepcopy(c.t2i_model)
c.device = load_device
if autoregister:
self.multigpu_clones[load_device] = c
return c
def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
compression_ratio = 8
upscale_algorithm = 'nearest-exact'

View File

@@ -1,303 +0,0 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.attention import optimized_attention
import comfy.model_management
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
assert dim % 2 == 0
if not comfy.model_management.supports_fp64(pos.device):
device = torch.device("cpu")
else:
device = pos.device
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim
omega = 1.0 / (theta**scale)
out = torch.einsum("...n,d->...nd", pos.to(device), omega)
out = torch.stack([torch.cos(out), torch.sin(out)], dim=0)
return out.to(dtype=torch.float32, device=pos.device)
def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
rot_dim = freqs_cis.shape[-1]
x, x_pass = x_in[..., :rot_dim], x_in[..., rot_dim:]
cos_ = freqs_cis[0]
sin_ = freqs_cis[1]
x1, x2 = x.chunk(2, dim=-1)
x_rotated = torch.cat((-x2, x1), dim=-1)
return torch.cat((x * cos_ + x_rotated * sin_, x_pass), dim=-1)
class ErnieImageEmbedND3(nn.Module):
def __init__(self, dim: int, theta: int, axes_dim: tuple):
super().__init__()
self.dim = dim
self.theta = theta
self.axes_dim = list(axes_dim)
def forward(self, ids: torch.Tensor) -> torch.Tensor:
emb = torch.cat([rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(3)], dim=-1)
emb = emb.unsqueeze(3) # [2, B, S, 1, head_dim//2]
return torch.stack([emb, emb], dim=-1).reshape(*emb.shape[:-1], -1) # [B, S, 1, head_dim]
class ErnieImagePatchEmbedDynamic(nn.Module):
def __init__(self, in_channels: int, embed_dim: int, patch_size: int, operations, device=None, dtype=None):
super().__init__()
self.patch_size = patch_size
self.proj = operations.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size, bias=True, device=device, dtype=dtype)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.proj(x)
batch_size, dim, height, width = x.shape
return x.reshape(batch_size, dim, height * width).transpose(1, 2).contiguous()
class Timesteps(nn.Module):
def __init__(self, num_channels: int, flip_sin_to_cos: bool = False):
super().__init__()
self.num_channels = num_channels
self.flip_sin_to_cos = flip_sin_to_cos
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
half_dim = self.num_channels // 2
exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) / half_dim
emb = torch.exp(exponent)
emb = timesteps[:, None].float() * emb[None, :]
if self.flip_sin_to_cos:
emb = torch.cat([torch.cos(emb), torch.sin(emb)], dim=-1)
else:
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
return emb
class TimestepEmbedding(nn.Module):
def __init__(self, in_channels: int, time_embed_dim: int, operations, device=None, dtype=None):
super().__init__()
Linear = operations.Linear
self.linear_1 = Linear(in_channels, time_embed_dim, bias=True, device=device, dtype=dtype)
self.act = nn.SiLU()
self.linear_2 = Linear(time_embed_dim, time_embed_dim, bias=True, device=device, dtype=dtype)
def forward(self, sample: torch.Tensor) -> torch.Tensor:
sample = self.linear_1(sample)
sample = self.act(sample)
sample = self.linear_2(sample)
return sample
class ErnieImageAttention(nn.Module):
def __init__(self, query_dim: int, heads: int, dim_head: int, eps: float = 1e-6, operations=None, device=None, dtype=None):
super().__init__()
self.heads = heads
self.head_dim = dim_head
self.inner_dim = heads * dim_head
Linear = operations.Linear
RMSNorm = operations.RMSNorm
self.to_q = Linear(query_dim, self.inner_dim, bias=False, device=device, dtype=dtype)
self.to_k = Linear(query_dim, self.inner_dim, bias=False, device=device, dtype=dtype)
self.to_v = Linear(query_dim, self.inner_dim, bias=False, device=device, dtype=dtype)
self.norm_q = RMSNorm(dim_head, eps=eps, elementwise_affine=True, device=device, dtype=dtype)
self.norm_k = RMSNorm(dim_head, eps=eps, elementwise_affine=True, device=device, dtype=dtype)
self.to_out = nn.ModuleList([Linear(self.inner_dim, query_dim, bias=False, device=device, dtype=dtype)])
def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None, image_rotary_emb: torch.Tensor = None) -> torch.Tensor:
B, S, _ = x.shape
q_flat = self.to_q(x)
k_flat = self.to_k(x)
v_flat = self.to_v(x)
query = q_flat.view(B, S, self.heads, self.head_dim)
key = k_flat.view(B, S, self.heads, self.head_dim)
query = self.norm_q(query)
key = self.norm_k(key)
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
query, key = query.to(x.dtype), key.to(x.dtype)
q_flat = query.reshape(B, S, -1)
k_flat = key.reshape(B, S, -1)
hidden_states = optimized_attention(q_flat, k_flat, v_flat, self.heads, mask=attention_mask)
return self.to_out[0](hidden_states)
class ErnieImageFeedForward(nn.Module):
def __init__(self, hidden_size: int, ffn_hidden_size: int, operations, device=None, dtype=None):
super().__init__()
Linear = operations.Linear
self.gate_proj = Linear(hidden_size, ffn_hidden_size, bias=False, device=device, dtype=dtype)
self.up_proj = Linear(hidden_size, ffn_hidden_size, bias=False, device=device, dtype=dtype)
self.linear_fc2 = Linear(ffn_hidden_size, hidden_size, bias=False, device=device, dtype=dtype)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.linear_fc2(self.up_proj(x) * F.gelu(self.gate_proj(x)))
class ErnieImageSharedAdaLNBlock(nn.Module):
def __init__(self, hidden_size: int, num_heads: int, ffn_hidden_size: int, eps: float = 1e-6, operations=None, device=None, dtype=None):
super().__init__()
RMSNorm = operations.RMSNorm
self.adaLN_sa_ln = RMSNorm(hidden_size, eps=eps, device=device, dtype=dtype)
self.self_attention = ErnieImageAttention(
query_dim=hidden_size,
dim_head=hidden_size // num_heads,
heads=num_heads,
eps=eps,
operations=operations,
device=device,
dtype=dtype
)
self.adaLN_mlp_ln = RMSNorm(hidden_size, eps=eps, device=device, dtype=dtype)
self.mlp = ErnieImageFeedForward(hidden_size, ffn_hidden_size, operations=operations, device=device, dtype=dtype)
def forward(self, x, rotary_pos_emb, temb, attention_mask=None):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = temb
residual = x
x_norm = self.adaLN_sa_ln(x)
x_norm = (x_norm.float() * (1 + scale_msa.float()) + shift_msa.float()).to(x.dtype)
attn_out = self.self_attention(x_norm, attention_mask=attention_mask, image_rotary_emb=rotary_pos_emb)
x = residual + (gate_msa.float() * attn_out.float()).to(x.dtype)
residual = x
x_norm = self.adaLN_mlp_ln(x)
x_norm = (x_norm.float() * (1 + scale_mlp.float()) + shift_mlp.float()).to(x.dtype)
return residual + (gate_mlp.float() * self.mlp(x_norm).float()).to(x.dtype)
class ErnieImageAdaLNContinuous(nn.Module):
def __init__(self, hidden_size: int, eps: float = 1e-6, operations=None, device=None, dtype=None):
super().__init__()
LayerNorm = operations.LayerNorm
Linear = operations.Linear
self.norm = LayerNorm(hidden_size, elementwise_affine=False, eps=eps, device=device, dtype=dtype)
self.linear = Linear(hidden_size, hidden_size * 2, device=device, dtype=dtype)
def forward(self, x: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor:
scale, shift = self.linear(conditioning).chunk(2, dim=-1)
x = self.norm(x)
x = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
return x
class ErnieImageModel(nn.Module):
def __init__(
self,
hidden_size: int = 4096,
num_attention_heads: int = 32,
num_layers: int = 36,
ffn_hidden_size: int = 12288,
in_channels: int = 128,
out_channels: int = 128,
patch_size: int = 1,
text_in_dim: int = 3072,
rope_theta: int = 256,
rope_axes_dim: tuple = (32, 48, 48),
eps: float = 1e-6,
qk_layernorm: bool = True,
device=None,
dtype=None,
operations=None,
**kwargs
):
super().__init__()
self.dtype = dtype
self.hidden_size = hidden_size
self.num_heads = num_attention_heads
self.head_dim = hidden_size // num_attention_heads
self.patch_size = patch_size
self.out_channels = out_channels
Linear = operations.Linear
self.x_embedder = ErnieImagePatchEmbedDynamic(in_channels, hidden_size, patch_size, operations, device, dtype)
self.text_proj = Linear(text_in_dim, hidden_size, bias=False, device=device, dtype=dtype) if text_in_dim != hidden_size else None
self.time_proj = Timesteps(hidden_size, flip_sin_to_cos=False)
self.time_embedding = TimestepEmbedding(hidden_size, hidden_size, operations, device, dtype)
self.pos_embed = ErnieImageEmbedND3(dim=self.head_dim, theta=rope_theta, axes_dim=rope_axes_dim)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
Linear(hidden_size, 6 * hidden_size, device=device, dtype=dtype)
)
self.layers = nn.ModuleList([
ErnieImageSharedAdaLNBlock(hidden_size, num_attention_heads, ffn_hidden_size, eps, operations, device, dtype)
for _ in range(num_layers)
])
self.final_norm = ErnieImageAdaLNContinuous(hidden_size, eps, operations, device, dtype)
self.final_linear = Linear(hidden_size, patch_size * patch_size * out_channels, device=device, dtype=dtype)
def forward(self, x, timesteps, context, **kwargs):
device, dtype = x.device, x.dtype
B, C, H, W = x.shape
p, Hp, Wp = self.patch_size, H // self.patch_size, W // self.patch_size
N_img = Hp * Wp
img_bsh = self.x_embedder(x)
text_bth = context
if self.text_proj is not None and text_bth.numel() > 0:
text_bth = self.text_proj(text_bth)
Tmax = text_bth.shape[1]
hidden_states = torch.cat([img_bsh, text_bth], dim=1)
text_ids = torch.zeros((B, Tmax, 3), device=device, dtype=torch.float32)
text_ids[:, :, 0] = torch.linspace(0, Tmax - 1, steps=Tmax, device=x.device, dtype=torch.float32)
index = float(Tmax)
transformer_options = kwargs.get("transformer_options", {})
rope_options = transformer_options.get("rope_options", None)
h_len, w_len = float(Hp), float(Wp)
h_offset, w_offset = 0.0, 0.0
if rope_options is not None:
h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
index += rope_options.get("shift_t", 0.0)
h_offset += rope_options.get("shift_y", 0.0)
w_offset += rope_options.get("shift_x", 0.0)
image_ids = torch.zeros((Hp, Wp, 3), device=device, dtype=torch.float32)
image_ids[:, :, 0] = image_ids[:, :, 1] + index
image_ids[:, :, 1] = image_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=Hp, device=device, dtype=torch.float32).unsqueeze(1)
image_ids[:, :, 2] = image_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=Wp, device=device, dtype=torch.float32).unsqueeze(0)
image_ids = image_ids.view(1, N_img, 3).expand(B, -1, -1)
rotary_pos_emb = self.pos_embed(torch.cat([image_ids, text_ids], dim=1)).to(x.dtype)
del image_ids, text_ids
sample = self.time_proj(timesteps).to(dtype)
c = self.time_embedding(sample)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [
t.unsqueeze(1).contiguous() for t in self.adaLN_modulation(c).chunk(6, dim=-1)
]
temb = [shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp]
for layer in self.layers:
hidden_states = layer(hidden_states, rotary_pos_emb, temb)
hidden_states = self.final_norm(hidden_states, c).type_as(hidden_states)
patches = self.final_linear(hidden_states)[:, :N_img, :]
output = (
patches.view(B, Hp, Wp, p, p, self.out_channels)
.permute(0, 5, 1, 3, 2, 4)
.contiguous()
.view(B, self.out_channels, H, W)
)
return output

View File

@@ -16,7 +16,7 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transforme
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
assert dim % 2 == 0
if not comfy.model_management.supports_fp64(pos.device):
if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu() or comfy.model_management.is_directml_enabled():
device = torch.device("cpu")
else:
device = pos.device

View File

@@ -90,7 +90,7 @@ class HeatmapHead(torch.nn.Module):
origin_max = np.max(hm[k])
dr = np.zeros((H + 2 * border, W + 2 * border), dtype=np.float32)
dr[border:-border, border:-border] = hm[k].copy()
dr = gaussian_filter(dr, sigma=2.0, truncate=2.5)
dr = gaussian_filter(dr, sigma=2.0)
hm[k] = dr[border:-border, border:-border].copy()
cur_max = np.max(hm[k])
if cur_max > 0:

View File

@@ -53,7 +53,6 @@ import comfy.ldm.kandinsky5.model
import comfy.ldm.anima.model
import comfy.ldm.ace.ace_step15
import comfy.ldm.rt_detr.rtdetr_v4
import comfy.ldm.ernie.model
import comfy.model_management
import comfy.patcher_extension
@@ -1963,14 +1962,3 @@ class Kandinsky5Image(Kandinsky5):
class RT_DETR_v4(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.rt_detr.rtdetr_v4.RTv4)
class ErnieImage(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ernie.model.ErnieImageModel)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
return out

View File

@@ -713,11 +713,6 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["enc_h"] = state_dict['{}encoder.pan_blocks.1.cv4.conv.weight'.format(key_prefix)].shape[0]
return dit_config
if '{}layers.0.mlp.linear_fc2.weight'.format(key_prefix) in state_dict_keys: # Ernie Image
dit_config = {}
dit_config["image_model"] = "ernie"
return dit_config
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
return None

View File

@@ -15,6 +15,7 @@
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
from __future__ import annotations
import psutil
import logging
@@ -32,6 +33,11 @@ import comfy.memory_management
import comfy.utils
import comfy.quant_ops
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
class VRAMState(Enum):
DISABLED = 0 #No vram present: no need to move models to vram
NO_VRAM = 1 #Very low vram: enable all the options to save vram
@@ -206,6 +212,25 @@ def get_torch_device():
else:
return torch.device(torch.cuda.current_device())
def get_all_torch_devices(exclude_current=False):
global cpu_state
devices = []
if cpu_state == CPUState.GPU:
if is_nvidia():
for i in range(torch.cuda.device_count()):
devices.append(torch.device(i))
elif is_intel_xpu():
for i in range(torch.xpu.device_count()):
devices.append(torch.device(i))
elif is_ascend_npu():
for i in range(torch.npu.device_count()):
devices.append(torch.device(i))
else:
devices.append(get_torch_device())
if exclude_current:
devices.remove(get_torch_device())
return devices
def get_total_memory(dev=None, torch_total_too=False):
global directml_enabled
if dev is None:
@@ -494,9 +519,13 @@ try:
logging.info("Device: {}".format(get_torch_device_name(get_torch_device())))
except:
logging.warning("Could not pick default device.")
try:
for device in get_all_torch_devices(exclude_current=True):
logging.info("Device: {}".format(get_torch_device_name(device)))
except:
pass
current_loaded_models = []
current_loaded_models: list[LoadedModel] = []
def module_size(module):
module_mem = 0
@@ -529,7 +558,7 @@ def module_mmap_residency(module, free=False):
return mmap_touched_mem, module_mem
class LoadedModel:
def __init__(self, model):
def __init__(self, model: ModelPatcher):
self._set_model(model)
self.device = model.load_device
self.real_model = None
@@ -537,7 +566,7 @@ class LoadedModel:
self.model_finalizer = None
self._patcher_finalizer = None
def _set_model(self, model):
def _set_model(self, model: ModelPatcher):
self._model = weakref.ref(model)
if model.parent is not None:
self._parent_model = weakref.ref(model.parent)
@@ -548,6 +577,7 @@ class LoadedModel:
model = self._parent_model()
if model is not None:
self._set_model(model)
self.device = model.load_device
@property
def model(self):
@@ -1732,21 +1762,6 @@ def supports_mxfp8_compute(device=None):
return True
def supports_fp64(device=None):
if is_device_mps(device):
return False
if is_intel_xpu():
return False
if is_directml_enabled():
return False
if is_ixuca():
return False
return True
def extended_fp16_support():
# TODO: check why some models work with fp16 on newer torch versions but not on older
if torch_version_numeric < (2, 7):
@@ -1794,7 +1809,34 @@ def soft_empty_cache(force=False):
torch.cuda.ipc_collect()
def unload_all_models():
free_memory(1e30, get_torch_device())
for device in get_all_torch_devices():
free_memory(1e30, device)
def unload_model_and_clones(model: ModelPatcher, unload_additional_models=True, all_devices=False):
'Unload only model and its clones - primarily for multigpu cloning purposes.'
initial_keep_loaded: list[LoadedModel] = current_loaded_models.copy()
additional_models = []
if unload_additional_models:
additional_models = model.get_nested_additional_models()
keep_loaded = []
for loaded_model in initial_keep_loaded:
if loaded_model.model is not None:
if model.clone_base_uuid == loaded_model.model.clone_base_uuid:
continue
# check additional models if they are a match
skip = False
for add_model in additional_models:
if add_model.clone_base_uuid == loaded_model.model.clone_base_uuid:
skip = True
break
if skip:
continue
keep_loaded.append(loaded_model)
if not all_devices:
free_memory(1e30, get_torch_device(), keep_loaded)
else:
for device in get_all_torch_devices():
free_memory(1e30, device, keep_loaded)
def debug_memory_summary():
if is_amd() or is_nvidia():

View File

@@ -23,6 +23,7 @@ import inspect
import logging
import math
import uuid
import copy
from typing import Callable, Optional
import torch
@@ -75,12 +76,15 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_
def create_model_options_clone(orig_model_options: dict):
return comfy.patcher_extension.copy_nested_dicts(orig_model_options)
def create_hook_patches_clone(orig_hook_patches):
def create_hook_patches_clone(orig_hook_patches, copy_tuples=False):
new_hook_patches = {}
for hook_ref in orig_hook_patches:
new_hook_patches[hook_ref] = {}
for k in orig_hook_patches[hook_ref]:
new_hook_patches[hook_ref][k] = orig_hook_patches[hook_ref][k][:]
if copy_tuples:
for i in range(len(new_hook_patches[hook_ref][k])):
new_hook_patches[hook_ref][k][i] = tuple(new_hook_patches[hook_ref][k][i])
return new_hook_patches
def wipe_lowvram_weight(m):
@@ -272,7 +276,10 @@ class ModelPatcher:
self.is_clip = False
self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed
self.cached_patcher_init: tuple[Callable, tuple] | None = None
self.cached_patcher_init: tuple[Callable, tuple] | tuple[Callable, tuple, int] | None = None
self.is_multigpu_base_clone = False
self.clone_base_uuid = uuid.uuid4()
if not hasattr(self.model, 'model_loaded_weight_memory'):
self.model.model_loaded_weight_memory = 0
@@ -326,6 +333,8 @@ class ModelPatcher:
if self.cached_patcher_init is None:
raise RuntimeError("Cannot create non-dynamic delegate: cached_patcher_init is not initialized.")
temp_model_patcher = self.cached_patcher_init[0](*self.cached_patcher_init[1], disable_dynamic=True)
if len(self.cached_patcher_init) > 2:
temp_model_patcher = temp_model_patcher[self.cached_patcher_init[2]]
model_override = temp_model_patcher.get_clone_model_override()
if model_override is None:
model_override = self.get_clone_model_override()
@@ -384,19 +393,98 @@ class ModelPatcher:
n.hook_mode = self.hook_mode
n.cached_patcher_init = self.cached_patcher_init
n.is_multigpu_base_clone = self.is_multigpu_base_clone
n.clone_base_uuid = self.clone_base_uuid
for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE):
callback(self, n)
return n
def deepclone_multigpu(self, new_load_device=None, models_cache: dict[uuid.UUID,ModelPatcher]=None):
logging.info(f"Creating deepclone of {self.model.__class__.__name__} for {new_load_device if new_load_device else self.load_device}.")
comfy.model_management.unload_model_and_clones(self)
n = self.clone()
# set load device, if present
if new_load_device is not None:
n.load_device = new_load_device
if self.cached_patcher_init is not None:
temp_model_patcher: ModelPatcher | list[ModelPatcher] = self.cached_patcher_init[0](*self.cached_patcher_init[1])
if len(self.cached_patcher_init) > 2:
temp_model_patcher = temp_model_patcher[self.cached_patcher_init[2]]
n.model = temp_model_patcher.model
else:
n.model = copy.deepcopy(n.model)
# unlike for normal clone, backup dicts that shared same ref should not;
# otherwise, patchers that have deep copies of base models will erroneously influence each other.
n.backup = copy.deepcopy(n.backup)
n.object_patches_backup = copy.deepcopy(n.object_patches_backup)
n.hook_backup = copy.deepcopy(n.hook_backup)
# multigpu clone should not have multigpu additional_models entry
n.remove_additional_models("multigpu")
# multigpu_clone all stored additional_models; make sure circular references are properly handled
if models_cache is None:
models_cache = {}
for key, model_list in n.additional_models.items():
for i in range(len(model_list)):
add_model = n.additional_models[key][i]
if add_model.clone_base_uuid not in models_cache:
models_cache[add_model.clone_base_uuid] = add_model.deepclone_multigpu(new_load_device=new_load_device, models_cache=models_cache)
n.additional_models[key][i] = models_cache[add_model.clone_base_uuid]
for callback in self.get_all_callbacks(CallbacksMP.ON_DEEPCLONE_MULTIGPU):
callback(self, n)
return n
def match_multigpu_clones(self):
multigpu_models = self.get_additional_models_with_key("multigpu")
if len(multigpu_models) > 0:
new_multigpu_models = []
for mm in multigpu_models:
# clone main model, but bring over relevant props from existing multigpu clone
n = self.clone()
n.load_device = mm.load_device
n.backup = mm.backup
n.object_patches_backup = mm.object_patches_backup
n.hook_backup = mm.hook_backup
n.model = mm.model
n.is_multigpu_base_clone = mm.is_multigpu_base_clone
n.remove_additional_models("multigpu")
orig_additional_models: dict[str, list[ModelPatcher]] = comfy.patcher_extension.copy_nested_dicts(n.additional_models)
n.additional_models = comfy.patcher_extension.copy_nested_dicts(mm.additional_models)
# figure out which additional models are not present in multigpu clone
models_cache = {}
for mm_add_model in mm.get_additional_models():
models_cache[mm_add_model.clone_base_uuid] = mm_add_model
remove_models_uuids = set(list(models_cache.keys()))
for key, model_list in orig_additional_models.items():
for orig_add_model in model_list:
if orig_add_model.clone_base_uuid not in models_cache:
models_cache[orig_add_model.clone_base_uuid] = orig_add_model.deepclone_multigpu(new_load_device=n.load_device, models_cache=models_cache)
existing_list = n.get_additional_models_with_key(key)
existing_list.append(models_cache[orig_add_model.clone_base_uuid])
n.set_additional_models(key, existing_list)
if orig_add_model.clone_base_uuid in remove_models_uuids:
remove_models_uuids.remove(orig_add_model.clone_base_uuid)
# remove duplicate additional models
for key, model_list in n.additional_models.items():
new_model_list = [x for x in model_list if x.clone_base_uuid not in remove_models_uuids]
n.set_additional_models(key, new_model_list)
for callback in self.get_all_callbacks(CallbacksMP.ON_MATCH_MULTIGPU_CLONES):
callback(self, n)
new_multigpu_models.append(n)
self.set_additional_models("multigpu", new_multigpu_models)
def is_clone(self, other):
if hasattr(other, 'model') and self.model is other.model:
return True
return False
def clone_has_same_weights(self, clone: 'ModelPatcher'):
if not self.is_clone(clone):
return False
def clone_has_same_weights(self, clone: ModelPatcher, allow_multigpu=False):
if allow_multigpu:
if self.clone_base_uuid != clone.clone_base_uuid:
return False
else:
if not self.is_clone(clone):
return False
if self.current_hooks != clone.current_hooks:
return False
@@ -1167,7 +1255,7 @@ class ModelPatcher:
return self.additional_models.get(key, [])
def get_additional_models(self):
all_models = []
all_models: list[ModelPatcher] = []
for models in self.additional_models.values():
all_models.extend(models)
return all_models
@@ -1221,9 +1309,13 @@ class ModelPatcher:
for callback in self.get_all_callbacks(CallbacksMP.ON_PRE_RUN):
callback(self)
def prepare_state(self, timestep):
def prepare_state(self, timestep, model_options, ignore_multigpu=False):
for callback in self.get_all_callbacks(CallbacksMP.ON_PREPARE_STATE):
callback(self, timestep)
callback(self, timestep, model_options, ignore_multigpu)
if not ignore_multigpu and "multigpu_clones" in model_options:
for p in model_options["multigpu_clones"].values():
p: ModelPatcher
p.prepare_state(timestep, model_options, ignore_multigpu=True)
def restore_hook_patches(self):
if self.hook_patches_backup is not None:
@@ -1236,12 +1328,18 @@ class ModelPatcher:
def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup, model_options: dict[str]):
curr_t = t[0]
reset_current_hooks = False
multigpu_kf_changed_cache = None
transformer_options = model_options.get("transformer_options", {})
for hook in hook_group.hooks:
changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t, transformer_options=transformer_options)
# if keyframe changed, remove any cached HookGroups that contain hook with the same hook_ref;
# this will cause the weights to be recalculated when sampling
if changed:
# cache changed for multigpu usage
if "multigpu_clones" in model_options:
if multigpu_kf_changed_cache is None:
multigpu_kf_changed_cache = []
multigpu_kf_changed_cache.append(hook)
# reset current_hooks if contains hook that changed
if self.current_hooks is not None:
for current_hook in self.current_hooks.hooks:
@@ -1253,6 +1351,28 @@ class ModelPatcher:
self.cached_hook_patches.pop(cached_group)
if reset_current_hooks:
self.patch_hooks(None)
if "multigpu_clones" in model_options:
for p in model_options["multigpu_clones"].values():
p: ModelPatcher
p._handle_changed_hook_keyframes(multigpu_kf_changed_cache)
def _handle_changed_hook_keyframes(self, kf_changed_cache: list[comfy.hooks.Hook]):
'Used to handle multigpu behavior inside prepare_hook_patches_current_keyframe.'
if kf_changed_cache is None:
return
reset_current_hooks = False
# reset current_hooks if contains hook that changed
for hook in kf_changed_cache:
if self.current_hooks is not None:
for current_hook in self.current_hooks.hooks:
if current_hook == hook:
reset_current_hooks = True
break
for cached_group in list(self.cached_hook_patches.keys()):
if cached_group.contains(hook):
self.cached_hook_patches.pop(cached_group)
if reset_current_hooks:
self.patch_hooks(None)
def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None,
registered: comfy.hooks.HookGroup = None):

230
comfy/multigpu.py Normal file
View File

@@ -0,0 +1,230 @@
from __future__ import annotations
import queue
import threading
import torch
import logging
from collections import namedtuple
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
import comfy.utils
import comfy.patcher_extension
import comfy.model_management
class MultiGPUThreadPool:
"""Persistent thread pool for multi-GPU work distribution.
Maintains one worker thread per extra GPU device. Each thread calls
torch.cuda.set_device() once at startup so that compiled kernel caches
(inductor/triton) stay warm across diffusion steps.
"""
def __init__(self, devices: list[torch.device]):
self._workers: list[threading.Thread] = []
self._work_queues: dict[torch.device, queue.Queue] = {}
self._result_queues: dict[torch.device, queue.Queue] = {}
for device in devices:
wq = queue.Queue()
rq = queue.Queue()
self._work_queues[device] = wq
self._result_queues[device] = rq
t = threading.Thread(target=self._worker_loop, args=(device, wq, rq), daemon=True)
t.start()
self._workers.append(t)
def _worker_loop(self, device: torch.device, work_q: queue.Queue, result_q: queue.Queue):
try:
torch.cuda.set_device(device)
except Exception as e:
logging.error(f"MultiGPUThreadPool: failed to set device {device}: {e}")
while True:
item = work_q.get()
if item is None:
return
result_q.put((None, e))
return
while True:
item = work_q.get()
if item is None:
break
fn, args, kwargs = item
try:
result = fn(*args, **kwargs)
result_q.put((result, None))
except Exception as e:
result_q.put((None, e))
def submit(self, device: torch.device, fn, *args, **kwargs):
self._work_queues[device].put((fn, args, kwargs))
def get_result(self, device: torch.device):
return self._result_queues[device].get()
@property
def devices(self) -> list[torch.device]:
return list(self._work_queues.keys())
def shutdown(self):
for wq in self._work_queues.values():
wq.put(None) # sentinel
for t in self._workers:
t.join(timeout=5.0)
class GPUOptions:
def __init__(self, device_index: int, relative_speed: float):
self.device_index = device_index
self.relative_speed = relative_speed
def clone(self):
return GPUOptions(self.device_index, self.relative_speed)
def create_dict(self):
return {
"relative_speed": self.relative_speed
}
class GPUOptionsGroup:
def __init__(self):
self.options: dict[int, GPUOptions] = {}
def add(self, info: GPUOptions):
self.options[info.device_index] = info
def clone(self):
c = GPUOptionsGroup()
for opt in self.options.values():
c.add(opt)
return c
def register(self, model: ModelPatcher):
opts_dict = {}
# get devices that are valid for this model
devices: list[torch.device] = [model.load_device]
for extra_model in model.get_additional_models_with_key("multigpu"):
extra_model: ModelPatcher
devices.append(extra_model.load_device)
# create dictionary with actual device mapped to its GPUOptions
device_opts_list: list[GPUOptions] = []
for device in devices:
device_opts = self.options.get(device.index, GPUOptions(device_index=device.index, relative_speed=1.0))
opts_dict[device] = device_opts.create_dict()
device_opts_list.append(device_opts)
# make relative_speed relative to 1.0
min_speed = min([x.relative_speed for x in device_opts_list])
for value in opts_dict.values():
value['relative_speed'] /= min_speed
model.model_options['multigpu_options'] = opts_dict
def create_multigpu_deepclones(model: ModelPatcher, max_gpus: int, gpu_options: GPUOptionsGroup=None, reuse_loaded=False):
'Prepare ModelPatcher to contain deepclones of its BaseModel and related properties.'
model = model.clone()
# check if multigpu is already prepared - get the load devices from them if possible to exclude
skip_devices = set()
multigpu_models = model.get_additional_models_with_key("multigpu")
if len(multigpu_models) > 0:
for mm in multigpu_models:
skip_devices.add(mm.load_device)
skip_devices = list(skip_devices)
full_extra_devices = comfy.model_management.get_all_torch_devices(exclude_current=True)
limit_extra_devices = full_extra_devices[:max_gpus-1]
extra_devices = limit_extra_devices.copy()
# exclude skipped devices
for skip in skip_devices:
if skip in extra_devices:
extra_devices.remove(skip)
# create new deepclones
if len(extra_devices) > 0:
for device in extra_devices:
device_patcher = None
if reuse_loaded:
# check if there are any ModelPatchers currently loaded that could be referenced here after a clone
loaded_models: list[ModelPatcher] = comfy.model_management.loaded_models()
for lm in loaded_models:
if lm.model is not None and lm.clone_base_uuid == model.clone_base_uuid and lm.load_device == device:
device_patcher = lm.clone()
logging.info(f"Reusing loaded deepclone of {device_patcher.model.__class__.__name__} for {device}")
break
if device_patcher is None:
device_patcher = model.deepclone_multigpu(new_load_device=device)
device_patcher.is_multigpu_base_clone = True
multigpu_models = model.get_additional_models_with_key("multigpu")
multigpu_models.append(device_patcher)
model.set_additional_models("multigpu", multigpu_models)
model.match_multigpu_clones()
if gpu_options is None:
gpu_options = GPUOptionsGroup()
gpu_options.register(model)
else:
logging.info("No extra torch devices need initialization, skipping initializing MultiGPU Work Units.")
# TODO: only keep model clones that don't go 'past' the intended max_gpu count
# multigpu_models = model.get_additional_models_with_key("multigpu")
# new_multigpu_models = []
# for m in multigpu_models:
# if m.load_device in limit_extra_devices:
# new_multigpu_models.append(m)
# model.set_additional_models("multigpu", new_multigpu_models)
# persist skip_devices for use in sampling code
# if len(skip_devices) > 0 or "multigpu_skip_devices" in model.model_options:
# model.model_options["multigpu_skip_devices"] = skip_devices
return model
LoadBalance = namedtuple('LoadBalance', ['work_per_device', 'idle_time'])
def load_balance_devices(model_options: dict[str], total_work: int, return_idle_time=False, work_normalized: int=None):
'Optimize work assigned to different devices, accounting for their relative speeds and splittable work.'
opts_dict = model_options['multigpu_options']
devices = list(model_options['multigpu_clones'].keys())
speed_per_device = []
work_per_device = []
# get sum of each device's relative_speed
total_speed = 0.0
for opts in opts_dict.values():
total_speed += opts['relative_speed']
# get relative work for each device;
# obtained by w = (W*r)/R
for device in devices:
relative_speed = opts_dict[device]['relative_speed']
relative_work = (total_work*relative_speed) / total_speed
speed_per_device.append(relative_speed)
work_per_device.append(relative_work)
# relative work must be expressed in whole numbers, but likely is a decimal;
# perform rounding while maintaining total sum equal to total work (sum of relative works)
work_per_device = round_preserved(work_per_device)
dict_work_per_device = {}
for device, relative_work in zip(devices, work_per_device):
dict_work_per_device[device] = relative_work
if not return_idle_time:
return LoadBalance(dict_work_per_device, None)
# divide relative work by relative speed to get estimated completion time of said work by each device;
# time here is relative and does not correspond to real-world units
completion_time = [w/r for w,r in zip(work_per_device, speed_per_device)]
# calculate relative time spent by the devices waiting on each other after their work is completed
idle_time = abs(min(completion_time) - max(completion_time))
# if need to compare work idle time, need to normalize to a common total work
if work_normalized:
idle_time *= (work_normalized/total_work)
return LoadBalance(dict_work_per_device, idle_time)
def round_preserved(values: list[float]):
'Round all values in a list, preserving the combined sum of values.'
# get floor of values; casting to int does it too
floored = [int(x) for x in values]
total_floored = sum(floored)
# get remainder to distribute
remainder = round(sum(values)) - total_floored
# pair values with fractional portions
fractional = [(i, x-floored[i]) for i, x in enumerate(values)]
# sort by fractional part in descending order
fractional.sort(key=lambda x: x[1], reverse=True)
# distribute the remainder
for i in range(remainder):
index = fractional[i][0]
floored[index] += 1
return floored

View File

@@ -1151,7 +1151,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
if param is None:
continue
p = fn(param)
if (not torch.is_inference_mode_enabled()) and p.is_inference():
if p.is_inference():
p = p.clone()
self.register_parameter(key, torch.nn.Parameter(p, requires_grad=False))
for key, buf in self._buffers.items():

View File

@@ -3,6 +3,8 @@ from typing import Callable
class CallbacksMP:
ON_CLONE = "on_clone"
ON_DEEPCLONE_MULTIGPU = "on_deepclone_multigpu"
ON_MATCH_MULTIGPU_CLONES = "on_match_multigpu_clones"
ON_LOAD = "on_load_after"
ON_DETACH = "on_detach_after"
ON_CLEANUP = "on_cleanup"

View File

@@ -20,7 +20,6 @@ try:
if cuda_version < (13,):
ck.registry.disable("cuda")
logging.warning("WARNING: You need pytorch with cu130 or higher to use optimized CUDA operations.")
ck.registry.disable("triton")
for k, v in ck.list_backends().items():
logging.info(f"Found comfy_kitchen backend {k}: {v}")

View File

@@ -1,16 +1,18 @@
from __future__ import annotations
import torch
import uuid
import math
import collections
import comfy.model_management
import comfy.conds
import comfy.model_patcher
import comfy.utils
import comfy.hooks
import comfy.patcher_extension
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
from comfy.model_base import BaseModel
from comfy.model_patcher import ModelPatcher
from comfy.controlnet import ControlBase
def prepare_mask(noise_mask, shape, device):
@@ -118,6 +120,47 @@ def cleanup_additional_models(models):
if hasattr(m, 'cleanup'):
m.cleanup()
def preprocess_multigpu_conds(conds: dict[str, list[dict[str]]], model: ModelPatcher, model_options: dict[str]):
'''If multigpu acceleration required, creates deepclones of ControlNets and GLIGEN per device.'''
multigpu_models: list[ModelPatcher] = model.get_additional_models_with_key("multigpu")
if len(multigpu_models) == 0:
return
extra_devices = [x.load_device for x in multigpu_models]
# handle controlnets
controlnets: set[ControlBase] = set()
for k in conds:
for kk in conds[k]:
if 'control' in kk:
controlnets.add(kk['control'])
if len(controlnets) > 0:
# first, unload all controlnet clones
for cnet in list(controlnets):
cnet_models = cnet.get_models()
for cm in cnet_models:
comfy.model_management.unload_model_and_clones(cm, unload_additional_models=True)
# next, make sure each controlnet has a deepclone for all relevant devices
for cnet in controlnets:
curr_cnet = cnet
while curr_cnet is not None:
for device in extra_devices:
if device not in curr_cnet.multigpu_clones:
curr_cnet.deepclone_multigpu(device, autoregister=True)
curr_cnet = curr_cnet.previous_controlnet
# since all device clones are now present, recreate the linked list for cloned cnets per device
for cnet in controlnets:
curr_cnet = cnet
while curr_cnet is not None:
prev_cnet = curr_cnet.previous_controlnet
for device in extra_devices:
device_cnet = curr_cnet.get_instance_for_device(device)
prev_device_cnet = None
if prev_cnet is not None:
prev_device_cnet = prev_cnet.get_instance_for_device(device)
device_cnet.set_previous_controlnet(prev_device_cnet)
curr_cnet = prev_cnet
# potentially handle gligen - since not widely used, ignored for now
def estimate_memory(model, noise_shape, conds):
cond_shapes = collections.defaultdict(list)
cond_shapes_min = {}
@@ -142,7 +185,8 @@ def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None
return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load, force_offload=force_offload)
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False, force_offload=False):
real_model: BaseModel = None
model.match_multigpu_clones()
preprocess_multigpu_conds(conds, model, model_options)
models, inference_memory = get_additional_models(conds, model.model_dtype())
models += get_additional_models_from_model_options(model_options)
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
@@ -154,7 +198,7 @@ def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=Non
memory_required += inference_memory
minimum_memory_required += inference_memory
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required, force_full_load=force_full_load)
real_model = model.model
real_model: BaseModel = model.model
return real_model, conds, models
@@ -200,3 +244,18 @@ def prepare_model_patcher(model: ModelPatcher, conds, model_options: dict):
comfy.patcher_extension.merge_nested_dicts(to_load_options.setdefault(wc_name, {}), model_options["transformer_options"][wc_name],
copy_dict1=False)
return to_load_options
def prepare_model_patcher_multigpu_clones(model_patcher: ModelPatcher, loaded_models: list[ModelPatcher], model_options: dict):
'''
In case multigpu acceleration is enabled, prep ModelPatchers for each device.
'''
multigpu_patchers: list[ModelPatcher] = [x for x in loaded_models if x.is_multigpu_base_clone]
if len(multigpu_patchers) > 0:
multigpu_dict: dict[torch.device, ModelPatcher] = {}
multigpu_dict[model_patcher.load_device] = model_patcher
for x in multigpu_patchers:
x.hook_patches = comfy.model_patcher.create_hook_patches_clone(model_patcher.hook_patches, copy_tuples=True)
x.hook_mode = model_patcher.hook_mode # match main model's hook_mode
multigpu_dict[x.load_device] = x
model_options["multigpu_clones"] = multigpu_dict
return multigpu_patchers

View File

@@ -1,7 +1,9 @@
from __future__ import annotations
import comfy.model_management
from .k_diffusion import sampling as k_diffusion_sampling
from .extra_samplers import uni_pc
from typing import TYPE_CHECKING, Callable, NamedTuple
from typing import TYPE_CHECKING, Callable, NamedTuple, Any
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
from comfy.model_base import BaseModel
@@ -16,6 +18,7 @@ import comfy.model_patcher
import comfy.patcher_extension
import comfy.hooks
import comfy.context_windows
import comfy.multigpu
import comfy.utils
import scipy.stats
import numpy
@@ -141,7 +144,7 @@ def can_concat_cond(c1, c2):
return cond_equal_size(c1.conditioning, c2.conditioning)
def cond_cat(c_list):
def cond_cat(c_list, device=None):
temp = {}
for x in c_list:
for k in x:
@@ -153,6 +156,8 @@ def cond_cat(c_list):
for k in temp:
conds = temp[k]
out[k] = conds[0].concat(conds[1:])
if device is not None and hasattr(out[k], 'to'):
out[k] = out[k].to(device)
return out
@@ -212,7 +217,9 @@ def _calc_cond_batch_outer(model: BaseModel, conds: list[list[dict]], x_in: torc
)
return executor.execute(model, conds, x_in, timestep, model_options)
def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
if 'multigpu_clones' in model_options:
return _calc_cond_batch_multigpu(model, conds, x_in, timestep, model_options)
out_conds = []
out_counts = []
# separate conds by matching hooks
@@ -244,7 +251,7 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
if has_default_conds:
finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options)
model.current_patcher.prepare_state(timestep)
model.current_patcher.prepare_state(timestep, model_options)
# run every hooked_to_run separately
for hooks, to_run in hooked_to_run.items():
@@ -345,6 +352,212 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
return out_conds
def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
out_conds = []
out_counts = []
# separate conds by matching hooks
hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]] = {}
default_conds = []
has_default_conds = False
output_device = x_in.device
for i in range(len(conds)):
out_conds.append(torch.zeros_like(x_in))
out_counts.append(torch.ones_like(x_in) * 1e-37)
cond = conds[i]
default_c = []
if cond is not None:
for x in cond:
if 'default' in x:
default_c.append(x)
has_default_conds = True
continue
p = get_area_and_mult(x, x_in, timestep)
if p is None:
continue
if p.hooks is not None:
model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks, model_options)
hooked_to_run.setdefault(p.hooks, list())
hooked_to_run[p.hooks] += [(p, i)]
default_conds.append(default_c)
if has_default_conds:
finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options)
model.current_patcher.prepare_state(timestep, model_options)
devices = [dev_m for dev_m in model_options['multigpu_clones'].keys()]
device_batched_hooked_to_run: dict[torch.device, list[tuple[comfy.hooks.HookGroup, tuple]]] = {}
total_conds = 0
for to_run in hooked_to_run.values():
total_conds += len(to_run)
conds_per_device = max(1, math.ceil(total_conds//len(devices)))
index_device = 0
current_device = devices[index_device]
# run every hooked_to_run separately
for hooks, to_run in hooked_to_run.items():
while len(to_run) > 0:
current_device = devices[index_device % len(devices)]
batched_to_run = device_batched_hooked_to_run.setdefault(current_device, [])
# keep track of conds currently scheduled onto this device
batched_to_run_length = 0
for btr in batched_to_run:
batched_to_run_length += len(btr[1])
first = to_run[0]
first_shape = first[0][0].shape
to_batch_temp = []
# make sure not over conds_per_device limit when creating temp batch
for x in range(len(to_run)):
if can_concat_cond(to_run[x][0], first[0]) and len(to_batch_temp) < (conds_per_device - batched_to_run_length):
to_batch_temp += [x]
to_batch_temp.reverse()
to_batch = to_batch_temp[:1]
free_memory = comfy.model_management.get_free_memory(current_device)
for i in range(1, len(to_batch_temp) + 1):
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
if model.memory_required(input_shape) * 1.5 < free_memory:
to_batch = batch_amount
break
conds_to_batch = []
for x in to_batch:
conds_to_batch.append(to_run.pop(x))
batched_to_run_length += len(conds_to_batch)
batched_to_run.append((hooks, conds_to_batch))
if batched_to_run_length >= conds_per_device:
index_device += 1
class thread_result(NamedTuple):
output: Any
mult: Any
area: Any
batch_chunks: int
cond_or_uncond: Any
error: Exception = None
def _handle_batch(device: torch.device, batch_tuple: tuple[comfy.hooks.HookGroup, tuple], results: list[thread_result]):
try:
torch.cuda.set_device(device)
model_current: BaseModel = model_options["multigpu_clones"][device].model
# run every hooked_to_run separately
with torch.no_grad():
for hooks, to_batch in batch_tuple:
input_x = []
mult = []
c = []
cond_or_uncond = []
uuids = []
area = []
control: ControlBase = None
patches = None
for x in to_batch:
o = x
p = o[0]
input_x.append(p.input_x)
mult.append(p.mult)
c.append(p.conditioning)
area.append(p.area)
cond_or_uncond.append(o[1])
uuids.append(p.uuid)
control = p.control
patches = p.patches
batch_chunks = len(cond_or_uncond)
input_x = torch.cat(input_x).to(device)
c = cond_cat(c, device=device)
timestep_ = torch.cat([timestep.to(device)] * batch_chunks)
transformer_options = model_current.current_patcher.apply_hooks(hooks=hooks)
if 'transformer_options' in model_options:
transformer_options = comfy.patcher_extension.merge_nested_dicts(transformer_options,
model_options['transformer_options'],
copy_dict1=False)
if patches is not None:
transformer_options["patches"] = comfy.patcher_extension.merge_nested_dicts(
transformer_options.get("patches", {}),
patches
)
transformer_options["cond_or_uncond"] = cond_or_uncond[:]
transformer_options["uuids"] = uuids[:]
transformer_options["sigmas"] = timestep.to(device)
transformer_options["sample_sigmas"] = transformer_options["sample_sigmas"].to(device)
transformer_options["multigpu_thread_device"] = device
cast_transformer_options(transformer_options, device=device)
c['transformer_options'] = transformer_options
if control is not None:
device_control = control.get_instance_for_device(device)
c['control'] = device_control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options)
if 'model_function_wrapper' in model_options:
output = model_options['model_function_wrapper'](model_current.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).to(output_device).chunk(batch_chunks)
else:
output = model_current.apply_model(input_x, timestep_, **c).to(output_device).chunk(batch_chunks)
results.append(thread_result(output, mult, area, batch_chunks, cond_or_uncond))
except Exception as e:
results.append(thread_result(None, None, None, None, None, error=e))
raise
def _handle_batch_pooled(device, batch_tuple):
worker_results = []
_handle_batch(device, batch_tuple, worker_results)
return worker_results
results: list[thread_result] = []
thread_pool: comfy.multigpu.MultiGPUThreadPool = model_options.get("multigpu_thread_pool")
# Submit all GPU work to pool threads
pool_devices = []
for device, batch_tuple in device_batched_hooked_to_run.items():
if thread_pool is not None:
thread_pool.submit(device, _handle_batch_pooled, device, batch_tuple)
pool_devices.append(device)
else:
# Fallback: no pool, run everything on main thread
_handle_batch(device, batch_tuple, results)
# Collect results from pool workers
for device in pool_devices:
worker_results, error = thread_pool.get_result(device)
if error is not None:
raise error
results.extend(worker_results)
for output, mult, area, batch_chunks, cond_or_uncond, error in results:
if error is not None:
raise error
for o in range(batch_chunks):
cond_index = cond_or_uncond[o]
a = area[o]
if a is None:
out_conds[cond_index] += output[o] * mult[o]
out_counts[cond_index] += mult[o]
else:
out_c = out_conds[cond_index]
out_cts = out_counts[cond_index]
dims = len(a) // 2
for i in range(dims):
out_c = out_c.narrow(i + 2, a[i + dims], a[i])
out_cts = out_cts.narrow(i + 2, a[i + dims], a[i])
out_c += output[o] * mult[o]
out_cts += mult[o]
for i in range(len(out_conds)):
out_conds[i] /= out_counts[i]
return out_conds
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): #TODO: remove
logging.warning("WARNING: The comfy.samplers.calc_cond_uncond_batch function is deprecated please use the calc_cond_batch one instead.")
return tuple(calc_cond_batch(model, [cond, uncond], x_in, timestep, model_options))
@@ -649,6 +862,8 @@ def pre_run_control(model, conds):
percent_to_timestep_function = lambda a: s.percent_to_sigma(a)
if 'control' in x:
x['control'].pre_run(model, percent_to_timestep_function)
for device_cnet in x['control'].multigpu_clones.values():
device_cnet.pre_run(model, percent_to_timestep_function)
def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
cond_cnets = []
@@ -891,7 +1106,9 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
to_load_options = model_options.get("to_load_options", None)
if to_load_options is None:
return
cast_transformer_options(to_load_options, device, dtype)
def cast_transformer_options(transformer_options: dict[str], device=None, dtype=None):
casts = []
if device is not None:
casts.append(device)
@@ -900,18 +1117,17 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
# if nothing to apply, do nothing
if len(casts) == 0:
return
# try to call .to on patches
if "patches" in to_load_options:
patches = to_load_options["patches"]
if "patches" in transformer_options:
patches = transformer_options["patches"]
for name in patches:
patch_list = patches[name]
for i in range(len(patch_list)):
if hasattr(patch_list[i], "to"):
for cast in casts:
patch_list[i] = patch_list[i].to(cast)
if "patches_replace" in to_load_options:
patches = to_load_options["patches_replace"]
if "patches_replace" in transformer_options:
patches = transformer_options["patches_replace"]
for name in patches:
patch_list = patches[name]
for k in patch_list:
@@ -921,8 +1137,8 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
# try to call .to on any wrappers/callbacks
wrappers_and_callbacks = ["wrappers", "callbacks"]
for wc_name in wrappers_and_callbacks:
if wc_name in to_load_options:
wc: dict[str, list] = to_load_options[wc_name]
if wc_name in transformer_options:
wc: dict[str, list] = transformer_options[wc_name]
for wc_dict in wc.values():
for wc_list in wc_dict.values():
for i in range(len(wc_list)):
@@ -930,7 +1146,6 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
for cast in casts:
wc_list[i] = wc_list[i].to(cast)
class CFGGuider:
def __init__(self, model_patcher: ModelPatcher):
self.model_patcher = model_patcher
@@ -985,16 +1200,31 @@ class CFGGuider:
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
device = self.model_patcher.load_device
noise = noise.to(device=device, dtype=torch.float32)
latent_image = latent_image.to(device=device, dtype=torch.float32)
sigmas = sigmas.to(device)
cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())
multigpu_patchers = comfy.sampler_helpers.prepare_model_patcher_multigpu_clones(self.model_patcher, self.loaded_models, self.model_options)
# Create persistent thread pool for all GPU devices (main + extras)
if multigpu_patchers:
extra_devices = [p.load_device for p in multigpu_patchers]
all_devices = [device] + extra_devices
self.model_options["multigpu_thread_pool"] = comfy.multigpu.MultiGPUThreadPool(all_devices)
try:
noise = noise.to(device=device, dtype=torch.float32)
latent_image = latent_image.to(device=device, dtype=torch.float32)
sigmas = sigmas.to(device)
cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())
self.model_patcher.pre_run()
for multigpu_patcher in multigpu_patchers:
multigpu_patcher.pre_run()
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
finally:
thread_pool = self.model_options.pop("multigpu_thread_pool", None)
if thread_pool is not None:
thread_pool.shutdown()
self.model_patcher.cleanup()
for multigpu_patcher in multigpu_patchers:
multigpu_patcher.cleanup()
comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models)
del self.inner_model

View File

@@ -62,7 +62,6 @@ import comfy.text_encoders.anima
import comfy.text_encoders.ace15
import comfy.text_encoders.longcat_image
import comfy.text_encoders.qwen35
import comfy.text_encoders.ernie
import comfy.model_patcher
import comfy.lora
@@ -1236,7 +1235,6 @@ class TEModel(Enum):
QWEN35_4B = 25
QWEN35_9B = 26
QWEN35_27B = 27
MINISTRAL_3_3B = 28
def detect_te_model(sd):
@@ -1303,8 +1301,6 @@ def detect_te_model(sd):
return TEModel.MISTRAL3_24B
else:
return TEModel.MISTRAL3_24B_PRUNED_FLUX2
if weight.shape[0] == 3072:
return TEModel.MINISTRAL_3_3B
return TEModel.LLAMA3_8
return None
@@ -1462,10 +1458,6 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
elif te_model == TEModel.QWEN3_06B:
clip_target.clip = comfy.text_encoders.anima.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.anima.AnimaTokenizer
elif te_model == TEModel.MINISTRAL_3_3B:
clip_target.clip = comfy.text_encoders.ernie.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.ernie.ErnieTokenizer
tokenizer_data["tekken_model"] = clip_data[0].get("tekken_model", None)
else:
# clip_l
if clip_type == CLIPType.SD3:
@@ -1604,10 +1596,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic)
if out is None:
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd)))
if output_model and out[0] is not None:
out[0].cached_patcher_init = (load_checkpoint_guess_config_model_only, (ckpt_path, embedding_directory, model_options, te_model_options))
if output_clip and out[1] is not None:
out[1].patcher.cached_patcher_init = (load_checkpoint_guess_config_clip_only, (ckpt_path, embedding_directory, model_options, te_model_options))
out[0].cached_patcher_init = (load_checkpoint_guess_config, (ckpt_path, False, False, False, embedding_directory, output_model, model_options, te_model_options), 0)
return out
def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):

View File

@@ -26,7 +26,6 @@ import comfy.text_encoders.z_image
import comfy.text_encoders.anima
import comfy.text_encoders.ace15
import comfy.text_encoders.longcat_image
import comfy.text_encoders.ernie
from . import supported_models_base
from . import latent_formats
@@ -1750,37 +1749,6 @@ class RT_DETR_v4(supported_models_base.BASE):
def clip_target(self, state_dict={}):
return None
class ErnieImage(supported_models_base.BASE):
unet_config = {
"image_model": "ernie",
}
sampling_settings = {
"multiplier": 1000.0,
"shift": 3.0,
}
memory_usage_factor = 10.0
unet_extra_config = {}
latent_format = latent_formats.Flux2
supported_inference_dtypes = [torch.bfloat16, torch.float32]
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.ErnieImage(self, device=device)
return out
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}ministral3_3b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.ernie.ErnieTokenizer, comfy.text_encoders.ernie.te(**hunyuan_detect))
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4, ErnieImage]
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4]
models += [SVD_img2vid]

View File

@@ -1,38 +0,0 @@
from .flux import Mistral3Tokenizer
from comfy import sd1_clip
import comfy.text_encoders.llama
class Ministral3_3BTokenizer(Mistral3Tokenizer):
def __init__(self, embedding_directory=None, embedding_size=5120, embedding_key='ministral3_3b', tokenizer_data={}):
return super().__init__(embedding_directory=embedding_directory, embedding_size=embedding_size, embedding_key=embedding_key, tokenizer_data=tokenizer_data)
class ErnieTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="ministral3_3b", tokenizer=Mistral3Tokenizer)
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs):
tokens = super().tokenize_with_weights(text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
return tokens
class Ministral3_3BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
textmodel_json_config = {}
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 1, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Ministral3_3B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class ErnieTEModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}, name="ministral3_3b", clip_model=Ministral3_3BModel):
super().__init__(device=device, dtype=dtype, name=name, clip_model=clip_model, model_options=model_options)
def te(dtype_llama=None, llama_quantization_metadata=None):
class ErnieTEModel_(ErnieTEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if dtype_llama is not None:
dtype = dtype_llama
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["quantization_metadata"] = llama_quantization_metadata
super().__init__(device=device, dtype=dtype, model_options=model_options)
return ErnieTEModel

View File

@@ -116,9 +116,9 @@ class MistralTokenizerClass:
return LlamaTokenizerFast(**kwargs)
class Mistral3Tokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, embedding_size=5120, embedding_key='mistral3_24b', tokenizer_data={}):
def __init__(self, embedding_directory=None, tokenizer_data={}):
self.tekken_data = tokenizer_data.get("tekken_model", None)
super().__init__("", pad_with_end=False, embedding_directory=embedding_directory, embedding_size=embedding_size, embedding_key=embedding_key, tokenizer_class=MistralTokenizerClass, has_end_token=False, pad_to_max_length=False, pad_token=11, start_token=1, max_length=99999999, min_length=1, pad_left=True, disable_weights=True, tokenizer_args=load_mistral_tokenizer(self.tekken_data), tokenizer_data=tokenizer_data)
super().__init__("", pad_with_end=False, embedding_directory=embedding_directory, embedding_size=5120, embedding_key='mistral3_24b', tokenizer_class=MistralTokenizerClass, has_end_token=False, pad_to_max_length=False, pad_token=11, start_token=1, max_length=99999999, min_length=1, pad_left=True, tokenizer_args=load_mistral_tokenizer(self.tekken_data), tokenizer_data=tokenizer_data)
def state_dict(self):
return {"tekken_model": self.tekken_data}

View File

@@ -60,30 +60,6 @@ class Mistral3Small24BConfig:
final_norm: bool = True
lm_head: bool = False
@dataclass
class Ministral3_3BConfig:
vocab_size: int = 131072
hidden_size: int = 3072
intermediate_size: int = 9216
num_hidden_layers: int = 26
num_attention_heads: int = 32
num_key_value_heads: int = 8
max_position_embeddings: int = 262144
rms_norm_eps: float = 1e-5
rope_theta: float = 1000000.0
transformer_type: str = "llama"
head_dim = 128
rms_norm_add = False
mlp_activation = "silu"
qkv_bias = False
rope_dims = None
q_norm = None
k_norm = None
rope_scale = None
final_norm: bool = True
lm_head: bool = False
stop_tokens = [2]
@dataclass
class Qwen25_3BConfig:
vocab_size: int = 151936
@@ -970,15 +946,6 @@ class Mistral3Small24B(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Ministral3_3B(BaseLlama, BaseQwen3, BaseGenerate, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Ministral3_3BConfig(**config_dict)
self.num_layers = config.num_hidden_layers
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Qwen25_3B(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()

View File

@@ -52,26 +52,6 @@ class TaskImageContent(BaseModel):
role: Literal["first_frame", "last_frame", "reference_image"] | None = Field(None)
class TaskVideoContentUrl(BaseModel):
url: str = Field(...)
class TaskVideoContent(BaseModel):
type: str = Field("video_url")
video_url: TaskVideoContentUrl = Field(...)
role: str = Field("reference_video")
class TaskAudioContentUrl(BaseModel):
url: str = Field(...)
class TaskAudioContent(BaseModel):
type: str = Field("audio_url")
audio_url: TaskAudioContentUrl = Field(...)
role: str = Field("reference_audio")
class Text2VideoTaskCreationRequest(BaseModel):
model: str = Field(...)
content: list[TaskTextContent] = Field(..., min_length=1)
@@ -84,17 +64,6 @@ class Image2VideoTaskCreationRequest(BaseModel):
generate_audio: bool | None = Field(...)
class Seedance2TaskCreationRequest(BaseModel):
model: str = Field(...)
content: list[TaskTextContent | TaskImageContent | TaskVideoContent | TaskAudioContent] = Field(..., min_length=1)
generate_audio: bool | None = Field(None)
resolution: str | None = Field(None)
ratio: str | None = Field(None)
duration: int | None = Field(None, ge=4, le=15)
seed: int | None = Field(None, ge=0, le=2147483647)
watermark: bool | None = Field(None)
class TaskCreationResponse(BaseModel):
id: str = Field(...)
@@ -108,27 +77,12 @@ class TaskStatusResult(BaseModel):
video_url: str = Field(...)
class TaskStatusUsage(BaseModel):
completion_tokens: int = Field(0)
total_tokens: int = Field(0)
class TaskStatusResponse(BaseModel):
id: str = Field(...)
model: str = Field(...)
status: Literal["queued", "running", "cancelled", "succeeded", "failed"] = Field(...)
error: TaskStatusError | None = Field(None)
content: TaskStatusResult | None = Field(None)
usage: TaskStatusUsage | None = Field(None)
# Dollars per 1K tokens, keyed by (model_id, has_video_input).
SEEDANCE2_PRICE_PER_1K_TOKENS = {
("dreamina-seedance-2-0-260128", False): 0.007,
("dreamina-seedance-2-0-260128", True): 0.0043,
("dreamina-seedance-2-0-fast-260128", False): 0.0056,
("dreamina-seedance-2-0-fast-260128", True): 0.0033,
}
RECOMMENDED_PRESETS = [
@@ -158,12 +112,6 @@ RECOMMENDED_PRESETS_SEEDREAM_4 = [
("Custom", None, None),
]
# Seedance 2.0 reference video pixel count limits per model.
SEEDANCE2_REF_VIDEO_PIXEL_LIMITS = {
"dreamina-seedance-2-0-260128": {"min": 409_600, "max": 927_408},
"dreamina-seedance-2-0-fast-260128": {"min": 409_600, "max": 927_408},
}
# The time in this dictionary are given for 10 seconds duration.
VIDEO_TASKS_EXECUTION_TIME = {
"seedance-1-0-lite-t2v-250428": {

View File

@@ -8,23 +8,16 @@ from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.bytedance import (
RECOMMENDED_PRESETS,
RECOMMENDED_PRESETS_SEEDREAM_4,
SEEDANCE2_PRICE_PER_1K_TOKENS,
SEEDANCE2_REF_VIDEO_PIXEL_LIMITS,
VIDEO_TASKS_EXECUTION_TIME,
Image2VideoTaskCreationRequest,
ImageTaskCreationResponse,
Seedance2TaskCreationRequest,
Seedream4Options,
Seedream4TaskCreationRequest,
TaskAudioContent,
TaskAudioContentUrl,
TaskCreationResponse,
TaskImageContent,
TaskImageContentUrl,
TaskStatusResponse,
TaskTextContent,
TaskVideoContent,
TaskVideoContentUrl,
Text2ImageTaskCreationRequest,
Text2VideoTaskCreationRequest,
)
@@ -36,10 +29,7 @@ from comfy_api_nodes.util import (
image_tensor_pair_to_batch,
poll_op,
sync_op,
upload_audio_to_comfyapi,
upload_image_to_comfyapi,
upload_images_to_comfyapi,
upload_video_to_comfyapi,
validate_image_aspect_ratio,
validate_image_dimensions,
validate_string,
@@ -56,56 +46,12 @@ SEEDREAM_MODELS = {
# Long-running tasks endpoints(e.g., video)
BYTEPLUS_TASK_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks"
BYTEPLUS_TASK_STATUS_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" # + /{task_id}
BYTEPLUS_SEEDANCE2_TASK_STATUS_ENDPOINT = "/proxy/byteplus-seedance2/api/v3/contents/generations/tasks" # + /{task_id}
SEEDANCE_MODELS = {
"Seedance 2.0": "dreamina-seedance-2-0-260128",
"Seedance 2.0 Fast": "dreamina-seedance-2-0-fast-260128",
}
DEPRECATED_MODELS = {"seedance-1-0-lite-t2v-250428", "seedance-1-0-lite-i2v-250428"}
logger = logging.getLogger(__name__)
def _validate_ref_video_pixels(video: Input.Video, model_id: str, index: int) -> None:
"""Validate reference video pixel count against Seedance 2.0 model limits."""
limits = SEEDANCE2_REF_VIDEO_PIXEL_LIMITS.get(model_id)
if not limits:
return
try:
w, h = video.get_dimensions()
except Exception:
return
pixels = w * h
min_px = limits.get("min")
max_px = limits.get("max")
if min_px and pixels < min_px:
raise ValueError(
f"Reference video {index} is too small: {w}x{h} = {pixels:,}px. " f"Minimum is {min_px:,}px for this model."
)
if max_px and pixels > max_px:
raise ValueError(
f"Reference video {index} is too large: {w}x{h} = {pixels:,}px. "
f"Maximum is {max_px:,}px for this model. Try downscaling the video."
)
def _seedance2_price_extractor(model_id: str, has_video_input: bool):
"""Returns a price_extractor closure for Seedance 2.0 poll_op."""
rate = SEEDANCE2_PRICE_PER_1K_TOKENS.get((model_id, has_video_input))
if rate is None:
return None
def extractor(response: TaskStatusResponse) -> float | None:
if response.usage is None:
return None
return response.usage.total_tokens * 1.43 * rate / 1_000.0
return extractor
def get_image_url_from_response(response: ImageTaskCreationResponse) -> str:
if response.error:
error_msg = f"ByteDance request failed. Code: {response.error['code']}, message: {response.error['message']}"
@@ -389,7 +335,8 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
mp_provided = out_num_pixels / 1_000_000.0
if ("seedream-4-5" in model or "seedream-5-0" in model) and out_num_pixels < 3686400:
raise ValueError(
f"Minimum image resolution for the selected model is 3.68MP, " f"but {mp_provided:.2f}MP provided."
f"Minimum image resolution for the selected model is 3.68MP, "
f"but {mp_provided:.2f}MP provided."
)
if "seedream-4-0" in model and out_num_pixels < 921600:
raise ValueError(
@@ -1005,6 +952,33 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
)
async def process_video_task(
cls: type[IO.ComfyNode],
payload: Text2VideoTaskCreationRequest | Image2VideoTaskCreationRequest,
estimated_duration: int | None,
) -> IO.NodeOutput:
if payload.model in DEPRECATED_MODELS:
logger.warning(
"Model '%s' is deprecated and will be deactivated on May 13, 2026. "
"Please switch to a newer model. Recommended: seedance-1-0-pro-fast-251015.",
payload.model,
)
initial_response = await sync_op(
cls,
ApiEndpoint(path=BYTEPLUS_TASK_ENDPOINT, method="POST"),
data=payload,
response_model=TaskCreationResponse,
)
response = await poll_op(
cls,
ApiEndpoint(path=f"{BYTEPLUS_TASK_STATUS_ENDPOINT}/{initial_response.id}"),
status_extractor=lambda r: r.status,
estimated_duration=estimated_duration,
response_model=TaskStatusResponse,
)
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
def raise_if_text_params(prompt: str, text_params: list[str]) -> None:
for i in text_params:
if f"--{i} " in prompt:
@@ -1066,530 +1040,6 @@ PRICE_BADGE_VIDEO = IO.PriceBadge(
)
def _seedance2_text_inputs():
return [
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Text prompt for video generation.",
),
IO.Combo.Input(
"resolution",
options=["480p", "720p"],
tooltip="Resolution of the output video.",
),
IO.Combo.Input(
"ratio",
options=["16:9", "4:3", "1:1", "3:4", "9:16", "21:9", "adaptive"],
tooltip="Aspect ratio of the output video.",
),
IO.Int.Input(
"duration",
default=7,
min=4,
max=15,
step=1,
tooltip="Duration of the output video in seconds (4-15).",
display_mode=IO.NumberDisplay.slider,
),
IO.Boolean.Input(
"generate_audio",
default=True,
tooltip="Enable audio generation for the output video.",
),
]
class ByteDance2TextToVideoNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="ByteDance2TextToVideoNode",
display_name="ByteDance Seedance 2.0 Text to Video",
category="api node/video/ByteDance",
description="Generate video using Seedance 2.0 models based on a text prompt.",
inputs=[
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option("Seedance 2.0", _seedance2_text_inputs()),
IO.DynamicCombo.Option("Seedance 2.0 Fast", _seedance2_text_inputs()),
],
tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed controls whether the node should re-run; "
"results are non-deterministic regardless of seed.",
),
IO.Boolean.Input(
"watermark",
default=False,
tooltip="Whether to add a watermark to the video.",
advanced=True,
),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model", "model.resolution", "model.duration"]),
expr="""
(
$rate480 := 10044;
$rate720 := 21600;
$m := widgets.model;
$pricePer1K := $contains($m, "fast") ? 0.008008 : 0.01001;
$res := $lookup(widgets, "model.resolution");
$dur := $lookup(widgets, "model.duration");
$rate := $res = "720p" ? $rate720 : $rate480;
$cost := $dur * $rate * $pricePer1K / 1000;
{"type": "usd", "usd": $cost, "format": {"approximate": true}}
)
""",
),
)
@classmethod
async def execute(
cls,
model: dict,
seed: int,
watermark: bool,
) -> IO.NodeOutput:
validate_string(model["prompt"], strip_whitespace=True, min_length=1)
model_id = SEEDANCE_MODELS[model["model"]]
initial_response = await sync_op(
cls,
ApiEndpoint(path=BYTEPLUS_TASK_ENDPOINT, method="POST"),
data=Seedance2TaskCreationRequest(
model=model_id,
content=[TaskTextContent(text=model["prompt"])],
generate_audio=model["generate_audio"],
resolution=model["resolution"],
ratio=model["ratio"],
duration=model["duration"],
seed=seed,
watermark=watermark,
),
response_model=TaskCreationResponse,
)
response = await poll_op(
cls,
ApiEndpoint(path=f"{BYTEPLUS_SEEDANCE2_TASK_STATUS_ENDPOINT}/{initial_response.id}"),
response_model=TaskStatusResponse,
status_extractor=lambda r: r.status,
price_extractor=_seedance2_price_extractor(model_id, has_video_input=False),
poll_interval=9,
)
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
class ByteDance2FirstLastFrameNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="ByteDance2FirstLastFrameNode",
display_name="ByteDance Seedance 2.0 First-Last-Frame to Video",
category="api node/video/ByteDance",
description="Generate video using Seedance 2.0 from a first frame image and optional last frame image.",
inputs=[
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option("Seedance 2.0", _seedance2_text_inputs()),
IO.DynamicCombo.Option("Seedance 2.0 Fast", _seedance2_text_inputs()),
],
tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.",
),
IO.Image.Input(
"first_frame",
tooltip="First frame image for the video.",
),
IO.Image.Input(
"last_frame",
tooltip="Last frame image for the video.",
optional=True,
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed controls whether the node should re-run; "
"results are non-deterministic regardless of seed.",
),
IO.Boolean.Input(
"watermark",
default=False,
tooltip="Whether to add a watermark to the video.",
advanced=True,
),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model", "model.resolution", "model.duration"]),
expr="""
(
$rate480 := 10044;
$rate720 := 21600;
$m := widgets.model;
$pricePer1K := $contains($m, "fast") ? 0.008008 : 0.01001;
$res := $lookup(widgets, "model.resolution");
$dur := $lookup(widgets, "model.duration");
$rate := $res = "720p" ? $rate720 : $rate480;
$cost := $dur * $rate * $pricePer1K / 1000;
{"type": "usd", "usd": $cost, "format": {"approximate": true}}
)
""",
),
)
@classmethod
async def execute(
cls,
model: dict,
first_frame: Input.Image,
seed: int,
watermark: bool,
last_frame: Input.Image | None = None,
) -> IO.NodeOutput:
validate_string(model["prompt"], strip_whitespace=True, min_length=1)
model_id = SEEDANCE_MODELS[model["model"]]
content: list[TaskTextContent | TaskImageContent] = [
TaskTextContent(text=model["prompt"]),
TaskImageContent(
image_url=TaskImageContentUrl(
url=await upload_image_to_comfyapi(cls, first_frame, wait_label="Uploading first frame.")
),
role="first_frame",
),
]
if last_frame is not None:
content.append(
TaskImageContent(
image_url=TaskImageContentUrl(
url=await upload_image_to_comfyapi(cls, last_frame, wait_label="Uploading last frame.")
),
role="last_frame",
),
)
initial_response = await sync_op(
cls,
ApiEndpoint(path=BYTEPLUS_TASK_ENDPOINT, method="POST"),
data=Seedance2TaskCreationRequest(
model=model_id,
content=content,
generate_audio=model["generate_audio"],
resolution=model["resolution"],
ratio=model["ratio"],
duration=model["duration"],
seed=seed,
watermark=watermark,
),
response_model=TaskCreationResponse,
)
response = await poll_op(
cls,
ApiEndpoint(path=f"{BYTEPLUS_SEEDANCE2_TASK_STATUS_ENDPOINT}/{initial_response.id}"),
response_model=TaskStatusResponse,
status_extractor=lambda r: r.status,
price_extractor=_seedance2_price_extractor(model_id, has_video_input=False),
poll_interval=9,
)
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
def _seedance2_reference_inputs():
return [
*_seedance2_text_inputs(),
IO.Autogrow.Input(
"reference_images",
template=IO.Autogrow.TemplateNames(
IO.Image.Input("reference_image"),
names=[
"image_1",
"image_2",
"image_3",
"image_4",
"image_5",
"image_6",
"image_7",
"image_8",
"image_9",
],
min=0,
),
),
IO.Autogrow.Input(
"reference_videos",
template=IO.Autogrow.TemplateNames(
IO.Video.Input("reference_video"),
names=["video_1", "video_2", "video_3"],
min=0,
),
),
IO.Autogrow.Input(
"reference_audios",
template=IO.Autogrow.TemplateNames(
IO.Audio.Input("reference_audio"),
names=["audio_1", "audio_2", "audio_3"],
min=0,
),
),
]
class ByteDance2ReferenceNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="ByteDance2ReferenceNode",
display_name="ByteDance Seedance 2.0 Reference to Video",
category="api node/video/ByteDance",
description="Generate, edit, or extend video using Seedance 2.0 with reference images, "
"videos, and audio. Supports multimodal reference, video editing, and video extension.",
inputs=[
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option("Seedance 2.0", _seedance2_reference_inputs()),
IO.DynamicCombo.Option("Seedance 2.0 Fast", _seedance2_reference_inputs()),
],
tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed controls whether the node should re-run; "
"results are non-deterministic regardless of seed.",
),
IO.Boolean.Input(
"watermark",
default=False,
tooltip="Whether to add a watermark to the video.",
advanced=True,
),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(
widgets=["model", "model.resolution", "model.duration"],
input_groups=["model.reference_videos"],
),
expr="""
(
$rate480 := 10044;
$rate720 := 21600;
$m := widgets.model;
$hasVideo := $lookup(inputGroups, "model.reference_videos") > 0;
$noVideoPricePer1K := $contains($m, "fast") ? 0.008008 : 0.01001;
$videoPricePer1K := $contains($m, "fast") ? 0.004719 : 0.006149;
$res := $lookup(widgets, "model.resolution");
$dur := $lookup(widgets, "model.duration");
$rate := $res = "720p" ? $rate720 : $rate480;
$noVideoCost := $dur * $rate * $noVideoPricePer1K / 1000;
$minVideoFactor := $ceil($dur * 5 / 3);
$minVideoCost := $minVideoFactor * $rate * $videoPricePer1K / 1000;
$maxVideoCost := (15 + $dur) * $rate * $videoPricePer1K / 1000;
$hasVideo
? {
"type": "range_usd",
"min_usd": $minVideoCost,
"max_usd": $maxVideoCost,
"format": {"approximate": true}
}
: {
"type": "usd",
"usd": $noVideoCost,
"format": {"approximate": true}
}
)
""",
),
)
@classmethod
async def execute(
cls,
model: dict,
seed: int,
watermark: bool,
) -> IO.NodeOutput:
validate_string(model["prompt"], strip_whitespace=True, min_length=1)
reference_images = model.get("reference_images", {})
reference_videos = model.get("reference_videos", {})
reference_audios = model.get("reference_audios", {})
if not reference_images and not reference_videos:
raise ValueError("At least one reference image or video is required.")
model_id = SEEDANCE_MODELS[model["model"]]
has_video_input = len(reference_videos) > 0
total_video_duration = 0.0
for i, key in enumerate(reference_videos, 1):
video = reference_videos[key]
_validate_ref_video_pixels(video, model_id, i)
try:
dur = video.get_duration()
if dur < 1.8:
raise ValueError(f"Reference video {i} is too short: {dur:.1f}s. Minimum duration is 1.8 seconds.")
total_video_duration += dur
except ValueError:
raise
except Exception:
pass
if total_video_duration > 15.1:
raise ValueError(f"Total reference video duration is {total_video_duration:.1f}s. Maximum is 15.1 seconds.")
total_audio_duration = 0.0
for i, key in enumerate(reference_audios, 1):
audio = reference_audios[key]
dur = int(audio["waveform"].shape[-1]) / int(audio["sample_rate"])
if dur < 1.8:
raise ValueError(f"Reference audio {i} is too short: {dur:.1f}s. Minimum duration is 1.8 seconds.")
total_audio_duration += dur
if total_audio_duration > 15.1:
raise ValueError(f"Total reference audio duration is {total_audio_duration:.1f}s. Maximum is 15.1 seconds.")
content: list[TaskTextContent | TaskImageContent | TaskVideoContent | TaskAudioContent] = [
TaskTextContent(text=model["prompt"]),
]
for i, key in enumerate(reference_images, 1):
content.append(
TaskImageContent(
image_url=TaskImageContentUrl(
url=await upload_image_to_comfyapi(
cls,
image=reference_images[key],
wait_label=f"Uploading image {i}",
),
),
role="reference_image",
),
)
for i, key in enumerate(reference_videos, 1):
content.append(
TaskVideoContent(
video_url=TaskVideoContentUrl(
url=await upload_video_to_comfyapi(
cls,
reference_videos[key],
wait_label=f"Uploading video {i}",
),
),
),
)
for key in reference_audios:
content.append(
TaskAudioContent(
audio_url=TaskAudioContentUrl(
url=await upload_audio_to_comfyapi(
cls,
reference_audios[key],
container_format="mp3",
codec_name="libmp3lame",
mime_type="audio/mpeg",
),
),
),
)
initial_response = await sync_op(
cls,
ApiEndpoint(path=BYTEPLUS_TASK_ENDPOINT, method="POST"),
data=Seedance2TaskCreationRequest(
model=model_id,
content=content,
generate_audio=model["generate_audio"],
resolution=model["resolution"],
ratio=model["ratio"],
duration=model["duration"],
seed=seed,
watermark=watermark,
),
response_model=TaskCreationResponse,
)
response = await poll_op(
cls,
ApiEndpoint(path=f"{BYTEPLUS_SEEDANCE2_TASK_STATUS_ENDPOINT}/{initial_response.id}"),
response_model=TaskStatusResponse,
status_extractor=lambda r: r.status,
price_extractor=_seedance2_price_extractor(model_id, has_video_input=has_video_input),
poll_interval=9,
)
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
async def process_video_task(
cls: type[IO.ComfyNode],
payload: Text2VideoTaskCreationRequest | Image2VideoTaskCreationRequest,
estimated_duration: int | None,
) -> IO.NodeOutput:
if payload.model in DEPRECATED_MODELS:
logger.warning(
"Model '%s' is deprecated and will be deactivated on May 13, 2026. "
"Please switch to a newer model. Recommended: seedance-1-0-pro-fast-251015.",
payload.model,
)
initial_response = await sync_op(
cls,
ApiEndpoint(path=BYTEPLUS_TASK_ENDPOINT, method="POST"),
data=payload,
response_model=TaskCreationResponse,
)
response = await poll_op(
cls,
ApiEndpoint(path=f"{BYTEPLUS_TASK_STATUS_ENDPOINT}/{initial_response.id}"),
status_extractor=lambda r: r.status,
estimated_duration=estimated_duration,
response_model=TaskStatusResponse,
)
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
class ByteDanceExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@@ -1600,9 +1050,6 @@ class ByteDanceExtension(ComfyExtension):
ByteDanceImageToVideoNode,
ByteDanceFirstLastFrameNode,
ByteDanceImageReferenceNode,
ByteDance2TextToVideoNode,
ByteDance2FirstLastFrameNode,
ByteDance2ReferenceNode,
]

View File

@@ -558,7 +558,7 @@ class GrokVideoReferenceNode(IO.ComfyNode):
(
$res := $lookup(widgets, "model.resolution");
$dur := $lookup(widgets, "model.duration");
$refs := $lookup(inputGroups, "model.reference_images");
$refs := inputGroups["model.reference_images"];
$rate := $res = "720p" ? 0.07 : 0.05;
$price := ($rate * $dur + 0.002 * $refs) * 1.43;
{"type":"usd","usd": $price}

View File

@@ -1,287 +0,0 @@
import base64
import json
import logging
import time
from urllib.parse import urljoin
import aiohttp
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.util import (
ApiEndpoint,
audio_bytes_to_audio_input,
upload_video_to_comfyapi,
validate_string,
)
from comfy_api_nodes.util._helpers import (
default_base_url,
get_auth_header,
get_node_id,
is_processing_interrupted,
)
from comfy_api_nodes.util.common_exceptions import ProcessingInterrupted
from server import PromptServer
logger = logging.getLogger(__name__)
class SoniloVideoToMusic(IO.ComfyNode):
"""Generate music from video using Sonilo's AI model."""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="SoniloVideoToMusic",
display_name="Sonilo Video to Music",
category="api node/audio/Sonilo",
description="Generate music from video content using Sonilo's AI model. "
"Analyzes the video and creates matching music.",
inputs=[
IO.Video.Input(
"video",
tooltip="Input video to generate music from. Maximum duration: 6 minutes.",
),
IO.String.Input(
"prompt",
default="",
multiline=True,
tooltip="Optional text prompt to guide music generation. "
"Leave empty for best quality - the model will fully analyze the video content.",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=0xFFFFFFFFFFFFFFFF,
control_after_generate=True,
tooltip="Seed for reproducibility. Currently ignored by the Sonilo "
"service but kept for graph consistency.",
),
],
outputs=[IO.Audio.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr='{"type":"usd","usd":0.009,"format":{"suffix":"/second"}}',
),
)
@classmethod
async def execute(
cls,
video: Input.Video,
prompt: str = "",
seed: int = 0,
) -> IO.NodeOutput:
video_url = await upload_video_to_comfyapi(cls, video, max_duration=360)
form = aiohttp.FormData()
form.add_field("video_url", video_url)
if prompt.strip():
form.add_field("prompt", prompt.strip())
audio_bytes = await _stream_sonilo_music(
cls,
ApiEndpoint(path="/proxy/sonilo/v2m/generate", method="POST"),
form,
)
return IO.NodeOutput(audio_bytes_to_audio_input(audio_bytes))
class SoniloTextToMusic(IO.ComfyNode):
"""Generate music from a text prompt using Sonilo's AI model."""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="SoniloTextToMusic",
display_name="Sonilo Text to Music",
category="api node/audio/Sonilo",
description="Generate music from a text prompt using Sonilo's AI model. "
"Leave duration at 0 to let the model infer it from the prompt.",
inputs=[
IO.String.Input(
"prompt",
default="",
multiline=True,
tooltip="Text prompt describing the music to generate.",
),
IO.Int.Input(
"duration",
default=0,
min=0,
max=360,
tooltip="Target duration in seconds. Set to 0 to let the model "
"infer the duration from the prompt. Maximum: 6 minutes.",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=0xFFFFFFFFFFFFFFFF,
control_after_generate=True,
tooltip="Seed for reproducibility. Currently ignored by the Sonilo "
"service but kept for graph consistency.",
),
],
outputs=[IO.Audio.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["duration"]),
expr="""
(
widgets.duration > 0
? {"type":"usd","usd": 0.005 * widgets.duration}
: {"type":"usd","usd": 0.005, "format":{"suffix":"/second"}}
)
""",
),
)
@classmethod
async def execute(
cls,
prompt: str,
duration: int = 0,
seed: int = 0,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)
form = aiohttp.FormData()
form.add_field("prompt", prompt)
if duration > 0:
form.add_field("duration", str(duration))
audio_bytes = await _stream_sonilo_music(
cls,
ApiEndpoint(path="/proxy/sonilo/t2m/generate", method="POST"),
form,
)
return IO.NodeOutput(audio_bytes_to_audio_input(audio_bytes))
async def _stream_sonilo_music(
cls: type[IO.ComfyNode],
endpoint: ApiEndpoint,
form: aiohttp.FormData,
) -> bytes:
"""POST ``form`` to Sonilo, read the NDJSON stream, and return the first stream's audio bytes."""
url = urljoin(default_base_url().rstrip("/") + "/", endpoint.path.lstrip("/"))
headers: dict[str, str] = {}
headers.update(get_auth_header(cls))
headers.update(endpoint.headers)
node_id = get_node_id(cls)
start_ts = time.monotonic()
last_chunk_status_ts = 0.0
audio_streams: dict[int, list[bytes]] = {}
title: str | None = None
timeout = aiohttp.ClientTimeout(total=1200.0, sock_read=300.0)
async with aiohttp.ClientSession(timeout=timeout) as session:
PromptServer.instance.send_progress_text("Status: Queued", node_id)
async with session.post(url, data=form, headers=headers) as resp:
if resp.status >= 400:
msg = await _extract_error_message(resp)
raise Exception(f"Sonilo API error ({resp.status}): {msg}")
while True:
if is_processing_interrupted():
raise ProcessingInterrupted("Task cancelled")
raw_line = await resp.content.readline()
if not raw_line:
break
line = raw_line.decode("utf-8").strip()
if not line:
continue
try:
evt = json.loads(line)
except json.JSONDecodeError:
logger.warning("Sonilo: skipping malformed NDJSON line")
continue
evt_type = evt.get("type")
if evt_type == "error":
code = evt.get("code", "UNKNOWN")
message = evt.get("message", "Unknown error")
raise Exception(f"Sonilo generation error ({code}): {message}")
if evt_type == "duration":
duration_sec = evt.get("duration_sec")
if duration_sec is not None:
PromptServer.instance.send_progress_text(
f"Status: Generating\nVideo duration: {duration_sec:.1f}s",
node_id,
)
elif evt_type in ("titles", "title"):
# v2m sends a "titles" list, t2m sends a scalar "title"
if evt_type == "titles":
titles = evt.get("titles", [])
if titles:
title = titles[0]
else:
title = evt.get("title") or title
if title:
PromptServer.instance.send_progress_text(
f"Status: Generating\nTitle: {title}",
node_id,
)
elif evt_type == "audio_chunk":
stream_idx = evt.get("stream_index", 0)
chunk_data = base64.b64decode(evt["data"])
if stream_idx not in audio_streams:
audio_streams[stream_idx] = []
audio_streams[stream_idx].append(chunk_data)
now = time.monotonic()
if now - last_chunk_status_ts >= 1.0:
total_chunks = sum(len(chunks) for chunks in audio_streams.values())
elapsed = int(now - start_ts)
status_lines = ["Status: Receiving audio"]
if title:
status_lines.append(f"Title: {title}")
status_lines.append(f"Chunks received: {total_chunks}")
status_lines.append(f"Time elapsed: {elapsed}s")
PromptServer.instance.send_progress_text("\n".join(status_lines), node_id)
last_chunk_status_ts = now
elif evt_type == "complete":
break
if not audio_streams:
raise Exception("Sonilo API returned no audio data.")
PromptServer.instance.send_progress_text("Status: Completed", node_id)
selected_stream = 0 if 0 in audio_streams else min(audio_streams)
return b"".join(audio_streams[selected_stream])
async def _extract_error_message(resp: aiohttp.ClientResponse) -> str:
"""Extract a human-readable error message from an HTTP error response."""
try:
error_body = await resp.json()
detail = error_body.get("detail", {})
if isinstance(detail, dict):
return detail.get("message", str(detail))
return str(detail)
except Exception:
return await resp.text()
class SoniloExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [SoniloVideoToMusic, SoniloTextToMusic]
async def comfy_entrypoint() -> SoniloExtension:
return SoniloExtension()

View File

@@ -0,0 +1,89 @@
from __future__ import annotations
from inspect import cleandoc
from typing import TYPE_CHECKING
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
import comfy.multigpu
class MultiGPUCFGSplitNode(io.ComfyNode):
"""
Prepares model to have sampling accelerated via splitting work units.
Should be placed after nodes that modify the model object itself, such as compile or attention-switch nodes.
Other than those exceptions, this node can be placed in any order.
"""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="MultiGPU_WorkUnits",
display_name="MultiGPU CFG Split",
category="advanced/multigpu",
description=cleandoc(cls.__doc__),
inputs=[
io.Model.Input("model"),
io.Int.Input("max_gpus", default=2, min=1, step=1),
],
outputs=[
io.Model.Output(),
],
)
@classmethod
def execute(cls, model: ModelPatcher, max_gpus: int) -> io.NodeOutput:
model = comfy.multigpu.create_multigpu_deepclones(model, max_gpus, reuse_loaded=True)
return io.NodeOutput(model)
class MultiGPUOptionsNode(io.ComfyNode):
"""
Select the relative speed of GPUs in the special case they have significantly different performance from one another.
"""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="MultiGPU_Options",
display_name="MultiGPU Options",
category="advanced/multigpu",
description=cleandoc(cls.__doc__),
inputs=[
io.Int.Input("device_index", default=0, min=0, max=64),
io.Float.Input("relative_speed", default=1.0, min=0.0, step=0.01),
io.Custom("GPU_OPTIONS").Input("gpu_options", optional=True),
],
outputs=[
io.Custom("GPU_OPTIONS").Output(),
],
)
@classmethod
def execute(cls, device_index: int, relative_speed: float, gpu_options: comfy.multigpu.GPUOptionsGroup = None) -> io.NodeOutput:
if not gpu_options:
gpu_options = comfy.multigpu.GPUOptionsGroup()
gpu_options.clone()
opt = comfy.multigpu.GPUOptions(device_index=device_index, relative_speed=relative_speed)
gpu_options.add(opt)
return io.NodeOutput(gpu_options)
class MultiGPUExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
MultiGPUCFGSplitNode,
# MultiGPUOptionsNode,
]
async def comfy_entrypoint() -> MultiGPUExtension:
return MultiGPUExtension()

View File

@@ -11,7 +11,7 @@ class PreviewAny():
"required": {"source": (IO.ANY, {})},
}
RETURN_TYPES = (IO.STRING,)
RETURN_TYPES = ()
FUNCTION = "main"
OUTPUT_NODE = True
@@ -33,7 +33,7 @@ class PreviewAny():
except Exception:
value = 'source exists, but could not be serialized.'
return {"ui": {"text": (value,)}, "result": (value,)}
return {"ui": {"text": (value,)}}
NODE_CLASS_MAPPINGS = {
"PreviewAny": PreviewAny,

View File

@@ -32,12 +32,10 @@ class RTDETR_detect(io.ComfyNode):
def execute(cls, model, image, threshold, class_name, max_detections) -> io.NodeOutput:
B, H, W, C = image.shape
image_in = comfy.utils.common_upscale(image.movedim(-1, 1), 640, 640, "bilinear", crop="disabled")
comfy.model_management.load_model_gpu(model)
results = []
for i in range(0, B, 32):
batch = image[i:i + 32]
image_in = comfy.utils.common_upscale(batch.movedim(-1, 1), 640, 640, "bilinear", crop="disabled")
results.extend(model.model.diffusion_model(image_in, (W, H)))
results = model.model.diffusion_model(image_in, (W, H)) # list of B dicts
all_bbox_dicts = []

View File

@@ -1,6 +1,5 @@
import torch
import comfy.utils
import comfy.model_management
import numpy as np
import math
import colorsys
@@ -411,9 +410,7 @@ class SDPoseDrawKeypoints(io.ComfyNode):
pose_outputs.append(canvas)
pose_outputs_np = np.stack(pose_outputs) if len(pose_outputs) > 1 else np.expand_dims(pose_outputs[0], 0)
final_pose_output = torch.from_numpy(pose_outputs_np).to(
device=comfy.model_management.intermediate_device(),
dtype=comfy.model_management.intermediate_dtype()) / 255.0
final_pose_output = torch.from_numpy(pose_outputs_np).float() / 255.0
return io.NodeOutput(final_pose_output)
class SDPoseKeypointExtractor(io.ComfyNode):
@@ -462,27 +459,6 @@ class SDPoseKeypointExtractor(io.ComfyNode):
model_h = int(head.heatmap_size[0]) * 4 # e.g. 192 * 4 = 768
model_w = int(head.heatmap_size[1]) * 4 # e.g. 256 * 4 = 1024
def _resize_to_model(imgs):
"""Aspect-preserving resize + zero-pad BHWC images to (model_h, model_w). Returns (resized_bhwc, scale, pad_top, pad_left)."""
h, w = imgs.shape[-3], imgs.shape[-2]
scale = min(model_h / h, model_w / w)
sh, sw = int(round(h * scale)), int(round(w * scale))
pt, pl = (model_h - sh) // 2, (model_w - sw) // 2
chw = imgs.permute(0, 3, 1, 2).float()
scaled = comfy.utils.common_upscale(chw, sw, sh, upscale_method="bilinear", crop="disabled")
padded = torch.zeros(scaled.shape[0], scaled.shape[1], model_h, model_w, dtype=scaled.dtype, device=scaled.device)
padded[:, :, pt:pt + sh, pl:pl + sw] = scaled
return padded.permute(0, 2, 3, 1), scale, pt, pl
def _remap_keypoints(kp, scale, pad_top, pad_left, offset_x=0, offset_y=0):
"""Remap keypoints from model space back to original image space."""
kp = kp.copy() if isinstance(kp, np.ndarray) else np.array(kp, dtype=np.float32)
invalid = kp[..., 0] < 0
kp[..., 0] = (kp[..., 0] - pad_left) / scale + offset_x
kp[..., 1] = (kp[..., 1] - pad_top) / scale + offset_y
kp[invalid] = -1
return kp
def _run_on_latent(latent_batch):
"""Run one forward pass and return (keypoints_list, scores_list) for the batch."""
nonlocal captured_feat
@@ -528,19 +504,36 @@ class SDPoseKeypointExtractor(io.ComfyNode):
if x2 <= x1 or y2 <= y1:
continue
crop_h_px, crop_w_px = y2 - y1, x2 - x1
crop = img[:, y1:y2, x1:x2, :] # (1, crop_h, crop_w, C)
crop_resized, scale, pad_top, pad_left = _resize_to_model(crop)
# scale to fit inside (model_h, model_w) while preserving aspect ratio, then pad to exact model size.
scale = min(model_h / crop_h_px, model_w / crop_w_px)
scaled_h, scaled_w = int(round(crop_h_px * scale)), int(round(crop_w_px * scale))
pad_top, pad_left = (model_h - scaled_h) // 2, (model_w - scaled_w) // 2
crop_chw = crop.permute(0, 3, 1, 2).float() # BHWC → BCHW
scaled = comfy.utils.common_upscale(crop_chw, scaled_w, scaled_h, upscale_method="bilinear", crop="disabled")
padded = torch.zeros(1, scaled.shape[1], model_h, model_w, dtype=scaled.dtype, device=scaled.device)
padded[:, :, pad_top:pad_top + scaled_h, pad_left:pad_left + scaled_w] = scaled
crop_resized = padded.permute(0, 2, 3, 1) # BCHW → BHWC
latent_crop = vae.encode(crop_resized)
kp_batch, sc_batch = _run_on_latent(latent_crop)
kp = _remap_keypoints(kp_batch[0], scale, pad_top, pad_left, x1, y1)
kp, sc = kp_batch[0], sc_batch[0] # (K, 2), coords in model pixel space
# remove padding offset, undo scale, offset to full-image coordinates.
kp = kp.copy() if isinstance(kp, np.ndarray) else np.array(kp, dtype=np.float32)
kp[..., 0] = (kp[..., 0] - pad_left) / scale + x1
kp[..., 1] = (kp[..., 1] - pad_top) / scale + y1
img_keypoints.append(kp)
img_scores.append(sc_batch[0])
img_scores.append(sc)
else:
img_resized, scale, pad_top, pad_left = _resize_to_model(img)
latent_img = vae.encode(img_resized)
# No bboxes for this image run on the full image
latent_img = vae.encode(img)
kp_batch, sc_batch = _run_on_latent(latent_img)
img_keypoints.append(_remap_keypoints(kp_batch[0], scale, pad_top, pad_left))
img_keypoints.append(kp_batch[0])
img_scores.append(sc_batch[0])
all_keypoints.append(img_keypoints)
@@ -548,16 +541,19 @@ class SDPoseKeypointExtractor(io.ComfyNode):
pbar.update(1)
else: # full-image mode, batched
for batch_start in tqdm(range(0, total_images, batch_size), desc="Extracting keypoints"):
batch_resized, scale, pad_top, pad_left = _resize_to_model(image[batch_start:batch_start + batch_size])
latent_batch = vae.encode(batch_resized)
tqdm_pbar = tqdm(total=total_images, desc="Extracting keypoints")
for batch_start in range(0, total_images, batch_size):
batch_end = min(batch_start + batch_size, total_images)
latent_batch = vae.encode(image[batch_start:batch_end])
kp_batch, sc_batch = _run_on_latent(latent_batch)
for kp, sc in zip(kp_batch, sc_batch):
all_keypoints.append([_remap_keypoints(kp, scale, pad_top, pad_left)])
all_keypoints.append([kp])
all_scores.append([sc])
tqdm_pbar.update(1)
pbar.update(len(kp_batch))
pbar.update(batch_end - batch_start)
openpose_frames = _to_openpose_frames(all_keypoints, all_scores, height, width)
return io.NodeOutput(openpose_frames)

View File

@@ -6,7 +6,6 @@ import comfy.utils
import folder_paths
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
import comfy.model_management
try:
from spandrel_extra_arches import EXTRA_REGISTRY
@@ -79,15 +78,13 @@ class ImageUpscaleWithModel(io.ComfyNode):
tile = 512
overlap = 32
output_device = comfy.model_management.intermediate_device()
oom = True
try:
while oom:
try:
steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
pbar = comfy.utils.ProgressBar(steps)
s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a.float()), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar, output_device=output_device)
s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
oom = False
except Exception as e:
model_management.raise_non_oom(e)
@@ -97,7 +94,7 @@ class ImageUpscaleWithModel(io.ComfyNode):
finally:
upscale_model.to("cpu")
s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0).to(comfy.model_management.intermediate_dtype())
s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0)
return io.NodeOutput(s)
upscale = execute # TODO: remove

View File

@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.19.1"
__version__ = "0.18.1"

View File

@@ -2412,6 +2412,7 @@ async def init_builtin_extra_nodes():
"nodes_lt_audio.py",
"nodes_lt.py",
"nodes_hooks.py",
"nodes_multigpu.py",
"nodes_load_3d.py",
"nodes_cosmos.py",
"nodes_video.py",

View File

@@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.19.1"
version = "0.18.1"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.10"

View File

@@ -1,5 +1,5 @@
comfyui-frontend-package==1.42.11
comfyui-workflow-templates==0.9.54
comfyui-frontend-package==1.42.8
comfyui-workflow-templates==0.9.44
comfyui-embedded-docs==0.4.3
torch
torchsde