From 9125613b53fc6af219d5a3db1d5b202ccc3f41b3 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 16 Jan 2026 08:09:07 +0200 Subject: [PATCH 01/58] feat(api-nodes): extend ByteDance nodes with seedance-1-5-pro model (#11871) --- comfy_api_nodes/apis/bytedance_api.py | 7 ++ comfy_api_nodes/nodes_bytedance.py | 104 +++++++++++++++++++++++--- 2 files changed, 101 insertions(+), 10 deletions(-) diff --git a/comfy_api_nodes/apis/bytedance_api.py b/comfy_api_nodes/apis/bytedance_api.py index b8c2f618b..400648cca 100644 --- a/comfy_api_nodes/apis/bytedance_api.py +++ b/comfy_api_nodes/apis/bytedance_api.py @@ -65,11 +65,13 @@ class TaskImageContent(BaseModel): class Text2VideoTaskCreationRequest(BaseModel): model: str = Field(...) content: list[TaskTextContent] = Field(..., min_length=1) + generate_audio: bool | None = Field(...) class Image2VideoTaskCreationRequest(BaseModel): model: str = Field(...) content: list[TaskTextContent | TaskImageContent] = Field(..., min_length=2) + generate_audio: bool | None = Field(...) class TaskCreationResponse(BaseModel): @@ -141,4 +143,9 @@ VIDEO_TASKS_EXECUTION_TIME = { "720p": 65, "1080p": 100, }, + "seedance-1-5-pro-251215": { + "480p": 80, + "720p": 100, + "1080p": 150, + }, } diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py index f09a4a0ed..9cb1ca004 100644 --- a/comfy_api_nodes/nodes_bytedance.py +++ b/comfy_api_nodes/nodes_bytedance.py @@ -477,7 +477,12 @@ class ByteDanceTextToVideoNode(IO.ComfyNode): inputs=[ IO.Combo.Input( "model", - options=["seedance-1-0-pro-250528", "seedance-1-0-lite-t2v-250428", "seedance-1-0-pro-fast-251015"], + options=[ + "seedance-1-5-pro-251215", + "seedance-1-0-pro-250528", + "seedance-1-0-lite-t2v-250428", + "seedance-1-0-pro-fast-251015", + ], default="seedance-1-0-pro-fast-251015", ), IO.String.Input( @@ -528,6 +533,12 @@ class ByteDanceTextToVideoNode(IO.ComfyNode): tooltip='Whether to add an "AI generated" watermark to the video.', optional=True, ), + IO.Boolean.Input( + "generate_audio", + default=False, + tooltip="This parameter is ignored for any model except seedance-1-5-pro.", + optional=True, + ), ], outputs=[ IO.Video.Output(), @@ -552,7 +563,10 @@ class ByteDanceTextToVideoNode(IO.ComfyNode): seed: int, camera_fixed: bool, watermark: bool, + generate_audio: bool = False, ) -> IO.NodeOutput: + if model == "seedance-1-5-pro-251215" and duration < 4: + raise ValueError("Minimum supported duration for Seedance 1.5 Pro is 4 seconds.") validate_string(prompt, strip_whitespace=True, min_length=1) raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "camerafixed", "watermark"]) @@ -567,7 +581,11 @@ class ByteDanceTextToVideoNode(IO.ComfyNode): ) return await process_video_task( cls, - payload=Text2VideoTaskCreationRequest(model=model, content=[TaskTextContent(text=prompt)]), + payload=Text2VideoTaskCreationRequest( + model=model, + content=[TaskTextContent(text=prompt)], + generate_audio=generate_audio if model == "seedance-1-5-pro-251215" else None, + ), estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))), ) @@ -584,7 +602,12 @@ class ByteDanceImageToVideoNode(IO.ComfyNode): inputs=[ IO.Combo.Input( "model", - options=["seedance-1-0-pro-250528", "seedance-1-0-lite-t2v-250428", "seedance-1-0-pro-fast-251015"], + options=[ + "seedance-1-5-pro-251215", + "seedance-1-0-pro-250528", + "seedance-1-0-lite-i2v-250428", + "seedance-1-0-pro-fast-251015", + ], default="seedance-1-0-pro-fast-251015", ), IO.String.Input( @@ -639,6 +662,12 @@ class ByteDanceImageToVideoNode(IO.ComfyNode): tooltip='Whether to add an "AI generated" watermark to the video.', optional=True, ), + IO.Boolean.Input( + "generate_audio", + default=False, + tooltip="This parameter is ignored for any model except seedance-1-5-pro.", + optional=True, + ), ], outputs=[ IO.Video.Output(), @@ -664,7 +693,10 @@ class ByteDanceImageToVideoNode(IO.ComfyNode): seed: int, camera_fixed: bool, watermark: bool, + generate_audio: bool = False, ) -> IO.NodeOutput: + if model == "seedance-1-5-pro-251215" and duration < 4: + raise ValueError("Minimum supported duration for Seedance 1.5 Pro is 4 seconds.") validate_string(prompt, strip_whitespace=True, min_length=1) raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "camerafixed", "watermark"]) validate_image_dimensions(image, min_width=300, min_height=300, max_width=6000, max_height=6000) @@ -686,6 +718,7 @@ class ByteDanceImageToVideoNode(IO.ComfyNode): payload=Image2VideoTaskCreationRequest( model=model, content=[TaskTextContent(text=prompt), TaskImageContent(image_url=TaskImageContentUrl(url=image_url))], + generate_audio=generate_audio if model == "seedance-1-5-pro-251215" else None, ), estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))), ) @@ -703,7 +736,7 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode): inputs=[ IO.Combo.Input( "model", - options=["seedance-1-0-pro-250528", "seedance-1-0-lite-i2v-250428"], + options=["seedance-1-5-pro-251215", "seedance-1-0-pro-250528", "seedance-1-0-lite-i2v-250428"], default="seedance-1-0-lite-i2v-250428", ), IO.String.Input( @@ -762,6 +795,12 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode): tooltip='Whether to add an "AI generated" watermark to the video.', optional=True, ), + IO.Boolean.Input( + "generate_audio", + default=False, + tooltip="This parameter is ignored for any model except seedance-1-5-pro.", + optional=True, + ), ], outputs=[ IO.Video.Output(), @@ -788,7 +827,10 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode): seed: int, camera_fixed: bool, watermark: bool, + generate_audio: bool = False, ) -> IO.NodeOutput: + if model == "seedance-1-5-pro-251215" and duration < 4: + raise ValueError("Minimum supported duration for Seedance 1.5 Pro is 4 seconds.") validate_string(prompt, strip_whitespace=True, min_length=1) raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "camerafixed", "watermark"]) for i in (first_frame, last_frame): @@ -821,6 +863,7 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode): TaskImageContent(image_url=TaskImageContentUrl(url=str(download_urls[0])), role="first_frame"), TaskImageContent(image_url=TaskImageContentUrl(url=str(download_urls[1])), role="last_frame"), ], + generate_audio=generate_audio if model == "seedance-1-5-pro-251215" else None, ), estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))), ) @@ -896,7 +939,41 @@ class ByteDanceImageReferenceNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, - price_badge=PRICE_BADGE_VIDEO, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model", "duration", "resolution"]), + expr=""" + ( + $priceByModel := { + "seedance-1-0-pro": { + "480p":[0.23,0.24], + "720p":[0.51,0.56] + }, + "seedance-1-0-lite": { + "480p":[0.17,0.18], + "720p":[0.37,0.41] + } + }; + $model := widgets.model; + $modelKey := + $contains($model, "seedance-1-0-pro") ? "seedance-1-0-pro" : + "seedance-1-0-lite"; + $resolution := widgets.resolution; + $resKey := + $contains($resolution, "720") ? "720p" : + "480p"; + $modelPrices := $lookup($priceByModel, $modelKey); + $baseRange := $lookup($modelPrices, $resKey); + $min10s := $baseRange[0]; + $max10s := $baseRange[1]; + $scale := widgets.duration / 10; + $minCost := $min10s * $scale; + $maxCost := $max10s * $scale; + ($minCost = $maxCost) + ? {"type":"usd","usd": $minCost} + : {"type":"range_usd","min_usd": $minCost, "max_usd": $maxCost} + ) + """, + ), ) @classmethod @@ -967,10 +1044,15 @@ def raise_if_text_params(prompt: str, text_params: list[str]) -> None: PRICE_BADGE_VIDEO = IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["model", "duration", "resolution"]), + depends_on=IO.PriceBadgeDepends(widgets=["model", "duration", "resolution", "generate_audio"]), expr=""" ( $priceByModel := { + "seedance-1-5-pro": { + "480p":[0.12,0.12], + "720p":[0.26,0.26], + "1080p":[0.58,0.59] + }, "seedance-1-0-pro": { "480p":[0.23,0.24], "720p":[0.51,0.56], @@ -989,6 +1071,7 @@ PRICE_BADGE_VIDEO = IO.PriceBadge( }; $model := widgets.model; $modelKey := + $contains($model, "seedance-1-5-pro") ? "seedance-1-5-pro" : $contains($model, "seedance-1-0-pro-fast") ? "seedance-1-0-pro-fast" : $contains($model, "seedance-1-0-pro") ? "seedance-1-0-pro" : "seedance-1-0-lite"; @@ -1002,11 +1085,12 @@ PRICE_BADGE_VIDEO = IO.PriceBadge( $min10s := $baseRange[0]; $max10s := $baseRange[1]; $scale := widgets.duration / 10; - $minCost := $min10s * $scale; - $maxCost := $max10s * $scale; + $audioMultiplier := ($modelKey = "seedance-1-5-pro" and widgets.generate_audio) ? 2 : 1; + $minCost := $min10s * $scale * $audioMultiplier; + $maxCost := $max10s * $scale * $audioMultiplier; ($minCost = $maxCost) - ? {"type":"usd","usd": $minCost} - : {"type":"range_usd","min_usd": $minCost, "max_usd": $maxCost} + ? {"type":"usd","usd": $minCost, "format": { "approximate": true }} + : {"type":"range_usd","min_usd": $minCost, "max_usd": $maxCost, "format": { "approximate": true }} ) """, ) From 0c6b36c6ac1c34515cdf28f777a63074cd6d563d Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Sat, 17 Jan 2026 06:22:50 +0800 Subject: [PATCH 02/58] chore: update workflow templates to v0.8.11 (#11918) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 996701550..3876274f9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.36.14 -comfyui-workflow-templates==0.8.10 +comfyui-workflow-templates==0.8.11 comfyui-embedded-docs==0.4.0 torch torchsde From 7ac999bf3069b06648a749212f59237080a75591 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 16 Jan 2026 20:02:28 -0800 Subject: [PATCH 03/58] Add image sizes to clip vision outputs. (#11923) --- comfy/clip_vision.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index 66f2a9d9c..b28bf636c 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -66,6 +66,7 @@ class ClipVisionModel(): outputs = Output() outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device()) outputs["image_embeds"] = out[2].to(comfy.model_management.intermediate_device()) + outputs["image_sizes"] = [pixel_values.shape[1:]] * pixel_values.shape[0] if self.return_all_hidden_states: all_hs = out[1].to(comfy.model_management.intermediate_device()) outputs["penultimate_hidden_states"] = all_hs[:, -2] From 00c775950aec5c563f532c8db08dae5e6adc24eb Mon Sep 17 00:00:00 2001 From: Alex Butler Date: Sun, 18 Jan 2026 01:18:04 +0000 Subject: [PATCH 04/58] Update readme rdna3 nightly url (#11937) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index e25f3cda7..123cc9472 100644 --- a/README.md +++ b/README.md @@ -240,7 +240,7 @@ These have less hardware support than the builds above but they work on windows. RDNA 3 (RX 7000 series): -```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx110X-dgpu/``` +```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx110X-all/``` RDNA 3.5 (Strix halo/Ryzen AI Max+ 365): From 0fd10ffa09588e0fc7f576ab7d0c93e97ad5fbb0 Mon Sep 17 00:00:00 2001 From: Theephop <144770658+TheephopWS@users.noreply.github.com> Date: Sun, 18 Jan 2026 09:18:24 +0800 Subject: [PATCH 05/58] fix: use .cpu() for waveform conversion in AudioFrame creation (#11787) --- comfy_api/latest/_input_impl/video_types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_api/latest/_input_impl/video_types.py b/comfy_api/latest/_input_impl/video_types.py index ea35c6062..1405d0b81 100644 --- a/comfy_api/latest/_input_impl/video_types.py +++ b/comfy_api/latest/_input_impl/video_types.py @@ -374,7 +374,7 @@ class VideoFromComponents(VideoInput): if audio_stream and self.__components.audio: waveform = self.__components.audio['waveform'] waveform = waveform[:, :, :math.ceil((audio_sample_rate / frame_rate) * self.__components.images.shape[0])] - frame = av.AudioFrame.from_ndarray(waveform.movedim(2, 1).reshape(1, -1).float().numpy(), format='flt', layout='mono' if waveform.shape[1] == 1 else 'stereo') + frame = av.AudioFrame.from_ndarray(waveform.movedim(2, 1).reshape(1, -1).float().cpu().numpy(), format='flt', layout='mono' if waveform.shape[1] == 1 else 'stereo') frame.sample_rate = audio_sample_rate frame.pts = 0 output.mux(audio_stream.encode(frame)) From 190c4416cce3b3b97b628935e001d796d565bfc9 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 17 Jan 2026 18:20:35 -0800 Subject: [PATCH 06/58] Bump comfy-kitchen dependency to version 0.2.7 (#11941) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 3876274f9..622256973 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,7 +21,7 @@ psutil alembic SQLAlchemy av>=14.2.0 -comfy-kitchen>=0.2.6 +comfy-kitchen>=0.2.7 #non essential dependencies: kornia>=0.7.1 From ac26065e6125871e2a742db6960f183fa037a75d Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sun, 18 Jan 2026 04:52:45 +0200 Subject: [PATCH 07/58] chore(api-nodes): remove non-used; extract model to separate files (#11927) * chore(api-nodes): remove non-used; extract model to separate files * chore(api-nodes): remove non-needed prefix in filenames --- comfy_api_nodes/README.md | 65 ---- comfy_api_nodes/apis/{bfl_api.py => bfl.py} | 0 .../apis/{bytedance_api.py => bytedance.py} | 0 .../apis/{gemini_api.py => gemini.py} | 0 comfy_api_nodes/apis/ideogram.py | 292 +++++++++++++++++ .../apis/{kling_api.py => kling.py} | 0 comfy_api_nodes/apis/{luma_api.py => luma.py} | 0 .../apis/{minimax_api.py => minimax.py} | 0 comfy_api_nodes/apis/moonvalley.py | 152 +++++++++ comfy_api_nodes/apis/openai.py | 170 ++++++++++ comfy_api_nodes/apis/openai_api.py | 52 --- .../apis/{pixverse_api.py => pixverse.py} | 0 .../apis/{recraft_api.py => recraft.py} | 0 .../apis/{rodin_api.py => rodin.py} | 0 comfy_api_nodes/apis/runway.py | 127 ++++++++ .../apis/{stability_api.py => stability.py} | 0 .../apis/{topaz_api.py => topaz.py} | 4 +- .../apis/{tripo_api.py => tripo.py} | 0 comfy_api_nodes/apis/{veo_api.py => veo.py} | 0 comfy_api_nodes/mapper_utils.py | 116 ------- comfy_api_nodes/nodes_bfl.py | 2 +- comfy_api_nodes/nodes_bytedance.py | 2 +- comfy_api_nodes/nodes_gemini.py | 2 +- comfy_api_nodes/nodes_ideogram.py | 2 +- comfy_api_nodes/nodes_kling.py | 2 +- comfy_api_nodes/nodes_luma.py | 2 +- comfy_api_nodes/nodes_minimax.py | 2 +- comfy_api_nodes/nodes_moonvalley.py | 2 +- comfy_api_nodes/nodes_openai.py | 86 ++--- comfy_api_nodes/nodes_pixverse.py | 2 +- comfy_api_nodes/nodes_recraft.py | 2 +- comfy_api_nodes/nodes_rodin.py | 2 +- comfy_api_nodes/nodes_runway.py | 2 +- comfy_api_nodes/nodes_stability.py | 2 +- comfy_api_nodes/nodes_topaz.py | 55 ++-- comfy_api_nodes/nodes_tripo.py | 2 +- comfy_api_nodes/nodes_veo2.py | 2 +- comfy_api_nodes/redocly-dev.yaml | 10 - comfy_api_nodes/redocly.yaml | 10 - .../comfy_api_nodes_test/mapper_utils_test.py | 297 ------------------ 40 files changed, 825 insertions(+), 641 deletions(-) delete mode 100644 comfy_api_nodes/README.md rename comfy_api_nodes/apis/{bfl_api.py => bfl.py} (100%) rename comfy_api_nodes/apis/{bytedance_api.py => bytedance.py} (100%) rename comfy_api_nodes/apis/{gemini_api.py => gemini.py} (100%) create mode 100644 comfy_api_nodes/apis/ideogram.py rename comfy_api_nodes/apis/{kling_api.py => kling.py} (100%) rename comfy_api_nodes/apis/{luma_api.py => luma.py} (100%) rename comfy_api_nodes/apis/{minimax_api.py => minimax.py} (100%) create mode 100644 comfy_api_nodes/apis/moonvalley.py create mode 100644 comfy_api_nodes/apis/openai.py delete mode 100644 comfy_api_nodes/apis/openai_api.py rename comfy_api_nodes/apis/{pixverse_api.py => pixverse.py} (100%) rename comfy_api_nodes/apis/{recraft_api.py => recraft.py} (100%) rename comfy_api_nodes/apis/{rodin_api.py => rodin.py} (100%) create mode 100644 comfy_api_nodes/apis/runway.py rename comfy_api_nodes/apis/{stability_api.py => stability.py} (100%) rename comfy_api_nodes/apis/{topaz_api.py => topaz.py} (97%) rename comfy_api_nodes/apis/{tripo_api.py => tripo.py} (100%) rename comfy_api_nodes/apis/{veo_api.py => veo.py} (100%) delete mode 100644 comfy_api_nodes/mapper_utils.py delete mode 100644 comfy_api_nodes/redocly-dev.yaml delete mode 100644 comfy_api_nodes/redocly.yaml delete mode 100644 tests-unit/comfy_api_nodes_test/mapper_utils_test.py diff --git a/comfy_api_nodes/README.md b/comfy_api_nodes/README.md deleted file mode 100644 index f56d6c860..000000000 --- a/comfy_api_nodes/README.md +++ /dev/null @@ -1,65 +0,0 @@ -# ComfyUI API Nodes - -## Introduction - -Below are a collection of nodes that work by calling external APIs. More information available in our [docs](https://docs.comfy.org/tutorials/api-nodes/overview). - -## Development - -While developing, you should be testing against the Staging environment. To test against staging: - -**Install ComfyUI_frontend** - -Follow the instructions [here](https://github.com/Comfy-Org/ComfyUI_frontend) to start the frontend server. By default, it will connect to Staging authentication. - -> **Hint:** If you use --front-end-version argument for ComfyUI, it will use production authentication. - -```bash -python run main.py --comfy-api-base https://stagingapi.comfy.org -``` - -To authenticate to staging, please login and then ask one of Comfy Org team to whitelist you for access to staging. - -API stubs are generated through automatic codegen tools from OpenAPI definitions. Since the Comfy Org OpenAPI definition contains many things from the Comfy Registry as well, we use redocly/cli to filter out only the paths relevant for API nodes. - -### Redocly Instructions - -**Tip** -When developing locally, use the `redocly-dev.yaml` file to generate pydantic models. This lets you use stubs for APIs that are not marked `Released` yet. - -Before your API node PR merges, make sure to add the `Released` tag to the `openapi.yaml` file and test in staging. - -```bash -# Download the OpenAPI file from staging server. -curl -o openapi.yaml https://stagingapi.comfy.org/openapi - -# Filter out unneeded API definitions. -npm install -g @redocly/cli -redocly bundle openapi.yaml --output filtered-openapi.yaml --config comfy_api_nodes/redocly-dev.yaml --remove-unused-components - -# Generate the pydantic datamodels for validation. -datamodel-codegen --use-subclass-enum --field-constraints --strict-types bytes --input filtered-openapi.yaml --output comfy_api_nodes/apis/__init__.py --output-model-type pydantic_v2.BaseModel - -``` - - -# Merging to Master - -Before merging to comfyanonymous/ComfyUI master, follow these steps: - -1. Add the "Released" tag to the ComfyUI OpenAPI yaml file for each endpoint you are using in the nodes. -1. Make sure the ComfyUI API is deployed to prod with your changes. -1. Run the code generation again with `redocly.yaml` and the production OpenAPI yaml file. - -```bash -# Download the OpenAPI file from prod server. -curl -o openapi.yaml https://api.comfy.org/openapi - -# Filter out unneeded API definitions. -npm install -g @redocly/cli -redocly bundle openapi.yaml --output filtered-openapi.yaml --config comfy_api_nodes/redocly.yaml --remove-unused-components - -# Generate the pydantic datamodels for validation. -datamodel-codegen --use-subclass-enum --field-constraints --strict-types bytes --input filtered-openapi.yaml --output comfy_api_nodes/apis/__init__.py --output-model-type pydantic_v2.BaseModel - -``` diff --git a/comfy_api_nodes/apis/bfl_api.py b/comfy_api_nodes/apis/bfl.py similarity index 100% rename from comfy_api_nodes/apis/bfl_api.py rename to comfy_api_nodes/apis/bfl.py diff --git a/comfy_api_nodes/apis/bytedance_api.py b/comfy_api_nodes/apis/bytedance.py similarity index 100% rename from comfy_api_nodes/apis/bytedance_api.py rename to comfy_api_nodes/apis/bytedance.py diff --git a/comfy_api_nodes/apis/gemini_api.py b/comfy_api_nodes/apis/gemini.py similarity index 100% rename from comfy_api_nodes/apis/gemini_api.py rename to comfy_api_nodes/apis/gemini.py diff --git a/comfy_api_nodes/apis/ideogram.py b/comfy_api_nodes/apis/ideogram.py new file mode 100644 index 000000000..737e18e3b --- /dev/null +++ b/comfy_api_nodes/apis/ideogram.py @@ -0,0 +1,292 @@ +from enum import Enum +from typing import Optional, List, Dict, Any, Union +from datetime import datetime + +from pydantic import BaseModel, Field, RootModel, StrictBytes + + +class IdeogramColorPalette1(BaseModel): + name: str = Field(..., description='Name of the preset color palette') + + +class Member(BaseModel): + color: Optional[str] = Field( + None, description='Hexadecimal color code', pattern='^#[0-9A-Fa-f]{6}$' + ) + weight: Optional[float] = Field( + None, description='Optional weight for the color (0-1)', ge=0.0, le=1.0 + ) + + +class IdeogramColorPalette2(BaseModel): + members: List[Member] = Field( + ..., description='Array of color definitions with optional weights' + ) + + +class IdeogramColorPalette( + RootModel[Union[IdeogramColorPalette1, IdeogramColorPalette2]] +): + root: Union[IdeogramColorPalette1, IdeogramColorPalette2] = Field( + ..., + description='A color palette specification that can either use a preset name or explicit color definitions with weights', + ) + + +class ImageRequest(BaseModel): + aspect_ratio: Optional[str] = Field( + None, + description="Optional. The aspect ratio (e.g., 'ASPECT_16_9', 'ASPECT_1_1'). Cannot be used with resolution. Defaults to 'ASPECT_1_1' if unspecified.", + ) + color_palette: Optional[Dict[str, Any]] = Field( + None, description='Optional. Color palette object. Only for V_2, V_2_TURBO.' + ) + magic_prompt_option: Optional[str] = Field( + None, description="Optional. MagicPrompt usage ('AUTO', 'ON', 'OFF')." + ) + model: str = Field(..., description="The model used (e.g., 'V_2', 'V_2A_TURBO')") + negative_prompt: Optional[str] = Field( + None, + description='Optional. Description of what to exclude. Only for V_1, V_1_TURBO, V_2, V_2_TURBO.', + ) + num_images: Optional[int] = Field( + 1, + description='Optional. Number of images to generate (1-8). Defaults to 1.', + ge=1, + le=8, + ) + prompt: str = Field( + ..., description='Required. The prompt to use to generate the image.' + ) + resolution: Optional[str] = Field( + None, + description="Optional. Resolution (e.g., 'RESOLUTION_1024_1024'). Only for model V_2. Cannot be used with aspect_ratio.", + ) + seed: Optional[int] = Field( + None, + description='Optional. A number between 0 and 2147483647.', + ge=0, + le=2147483647, + ) + style_type: Optional[str] = Field( + None, + description="Optional. Style type ('AUTO', 'GENERAL', 'REALISTIC', 'DESIGN', 'RENDER_3D', 'ANIME'). Only for models V_2 and above.", + ) + + +class IdeogramGenerateRequest(BaseModel): + image_request: ImageRequest = Field( + ..., description='The image generation request parameters.' + ) + + +class Datum(BaseModel): + is_image_safe: Optional[bool] = Field( + None, description='Indicates whether the image is considered safe.' + ) + prompt: Optional[str] = Field( + None, description='The prompt used to generate this image.' + ) + resolution: Optional[str] = Field( + None, description="The resolution of the generated image (e.g., '1024x1024')." + ) + seed: Optional[int] = Field( + None, description='The seed value used for this generation.' + ) + style_type: Optional[str] = Field( + None, + description="The style type used for generation (e.g., 'REALISTIC', 'ANIME').", + ) + url: Optional[str] = Field(None, description='URL to the generated image.') + + +class IdeogramGenerateResponse(BaseModel): + created: Optional[datetime] = Field( + None, description='Timestamp when the generation was created.' + ) + data: Optional[List[Datum]] = Field( + None, description='Array of generated image information.' + ) + + +class StyleCode(RootModel[str]): + root: str = Field(..., pattern='^[0-9A-Fa-f]{8}$') + + +class Datum1(BaseModel): + is_image_safe: Optional[bool] = None + prompt: Optional[str] = None + resolution: Optional[str] = None + seed: Optional[int] = None + style_type: Optional[str] = None + url: Optional[str] = None + + +class IdeogramV3IdeogramResponse(BaseModel): + created: Optional[datetime] = None + data: Optional[List[Datum1]] = None + + +class RenderingSpeed1(str, Enum): + TURBO = 'TURBO' + DEFAULT = 'DEFAULT' + QUALITY = 'QUALITY' + + +class IdeogramV3ReframeRequest(BaseModel): + color_palette: Optional[Dict[str, Any]] = None + image: Optional[StrictBytes] = None + num_images: Optional[int] = Field(None, ge=1, le=8) + rendering_speed: Optional[RenderingSpeed1] = None + resolution: str + seed: Optional[int] = Field(None, ge=0, le=2147483647) + style_codes: Optional[List[str]] = None + style_reference_images: Optional[List[StrictBytes]] = None + + +class MagicPrompt(str, Enum): + AUTO = 'AUTO' + ON = 'ON' + OFF = 'OFF' + + +class StyleType(str, Enum): + AUTO = 'AUTO' + GENERAL = 'GENERAL' + REALISTIC = 'REALISTIC' + DESIGN = 'DESIGN' + + +class IdeogramV3RemixRequest(BaseModel): + aspect_ratio: Optional[str] = None + color_palette: Optional[Dict[str, Any]] = None + image: Optional[StrictBytes] = None + image_weight: Optional[int] = Field(50, ge=1, le=100) + magic_prompt: Optional[MagicPrompt] = None + negative_prompt: Optional[str] = None + num_images: Optional[int] = Field(None, ge=1, le=8) + prompt: str + rendering_speed: Optional[RenderingSpeed1] = None + resolution: Optional[str] = None + seed: Optional[int] = Field(None, ge=0, le=2147483647) + style_codes: Optional[List[str]] = None + style_reference_images: Optional[List[StrictBytes]] = None + style_type: Optional[StyleType] = None + + +class IdeogramV3ReplaceBackgroundRequest(BaseModel): + color_palette: Optional[Dict[str, Any]] = None + image: Optional[StrictBytes] = None + magic_prompt: Optional[MagicPrompt] = None + num_images: Optional[int] = Field(None, ge=1, le=8) + prompt: str + rendering_speed: Optional[RenderingSpeed1] = None + seed: Optional[int] = Field(None, ge=0, le=2147483647) + style_codes: Optional[List[str]] = None + style_reference_images: Optional[List[StrictBytes]] = None + + +class ColorPalette(BaseModel): + name: str = Field(..., description='Name of the color palette', examples=['PASTEL']) + + +class MagicPrompt2(str, Enum): + ON = 'ON' + OFF = 'OFF' + + +class StyleType1(str, Enum): + AUTO = 'AUTO' + GENERAL = 'GENERAL' + REALISTIC = 'REALISTIC' + DESIGN = 'DESIGN' + FICTION = 'FICTION' + + +class RenderingSpeed(str, Enum): + DEFAULT = 'DEFAULT' + TURBO = 'TURBO' + QUALITY = 'QUALITY' + + +class IdeogramV3EditRequest(BaseModel): + color_palette: Optional[IdeogramColorPalette] = None + image: Optional[StrictBytes] = Field( + None, + description='The image being edited (max size 10MB); only JPEG, WebP and PNG formats are supported at this time.', + ) + magic_prompt: Optional[str] = Field( + None, + description='Determine if MagicPrompt should be used in generating the request or not.', + ) + mask: Optional[StrictBytes] = Field( + None, + description='A black and white image of the same size as the image being edited (max size 10MB). Black regions in the mask should match up with the regions of the image that you would like to edit; only JPEG, WebP and PNG formats are supported at this time.', + ) + num_images: Optional[int] = Field( + None, description='The number of images to generate.' + ) + prompt: str = Field( + ..., description='The prompt used to describe the edited result.' + ) + rendering_speed: RenderingSpeed + seed: Optional[int] = Field( + None, description='Random seed. Set for reproducible generation.' + ) + style_codes: Optional[List[StyleCode]] = Field( + None, + description='A list of 8 character hexadecimal codes representing the style of the image. Cannot be used in conjunction with style_reference_images or style_type.', + ) + style_reference_images: Optional[List[StrictBytes]] = Field( + None, + description='A set of images to use as style references (maximum total size 10MB across all style references). The images should be in JPEG, PNG or WebP format.', + ) + character_reference_images: Optional[List[str]] = Field( + None, + description='Generations with character reference are subject to the character reference pricing. A set of images to use as character references (maximum total size 10MB across all character references), currently only supports 1 character reference image. The images should be in JPEG, PNG or WebP format.' + ) + character_reference_images_mask: Optional[List[str]] = Field( + None, + description='Optional masks for character reference images. When provided, must match the number of character_reference_images. Each mask should be a grayscale image of the same dimensions as the corresponding character reference image. The images should be in JPEG, PNG or WebP format.' + ) + + +class IdeogramV3Request(BaseModel): + aspect_ratio: Optional[str] = Field( + None, description='Aspect ratio in format WxH', examples=['1x3'] + ) + color_palette: Optional[ColorPalette] = None + magic_prompt: Optional[MagicPrompt2] = Field( + None, description='Whether to enable magic prompt enhancement' + ) + negative_prompt: Optional[str] = Field( + None, description='Text prompt specifying what to avoid in the generation' + ) + num_images: Optional[int] = Field( + None, description='Number of images to generate', ge=1 + ) + prompt: str = Field(..., description='The text prompt for image generation') + rendering_speed: RenderingSpeed + resolution: Optional[str] = Field( + None, description='Image resolution in format WxH', examples=['1280x800'] + ) + seed: Optional[int] = Field( + None, description='Seed value for reproducible generation' + ) + style_codes: Optional[List[StyleCode]] = Field( + None, description='Array of style codes in hexadecimal format' + ) + style_reference_images: Optional[List[str]] = Field( + None, description='Array of reference image URLs or identifiers' + ) + style_type: Optional[StyleType1] = Field( + None, description='The type of style to apply' + ) + character_reference_images: Optional[List[str]] = Field( + None, + description='Generations with character reference are subject to the character reference pricing. A set of images to use as character references (maximum total size 10MB across all character references), currently only supports 1 character reference image. The images should be in JPEG, PNG or WebP format.' + ) + character_reference_images_mask: Optional[List[str]] = Field( + None, + description='Optional masks for character reference images. When provided, must match the number of character_reference_images. Each mask should be a grayscale image of the same dimensions as the corresponding character reference image. The images should be in JPEG, PNG or WebP format.' + ) diff --git a/comfy_api_nodes/apis/kling_api.py b/comfy_api_nodes/apis/kling.py similarity index 100% rename from comfy_api_nodes/apis/kling_api.py rename to comfy_api_nodes/apis/kling.py diff --git a/comfy_api_nodes/apis/luma_api.py b/comfy_api_nodes/apis/luma.py similarity index 100% rename from comfy_api_nodes/apis/luma_api.py rename to comfy_api_nodes/apis/luma.py diff --git a/comfy_api_nodes/apis/minimax_api.py b/comfy_api_nodes/apis/minimax.py similarity index 100% rename from comfy_api_nodes/apis/minimax_api.py rename to comfy_api_nodes/apis/minimax.py diff --git a/comfy_api_nodes/apis/moonvalley.py b/comfy_api_nodes/apis/moonvalley.py new file mode 100644 index 000000000..7ec7a4ade --- /dev/null +++ b/comfy_api_nodes/apis/moonvalley.py @@ -0,0 +1,152 @@ +from enum import Enum +from typing import Optional, Dict, Any + +from pydantic import BaseModel, Field, StrictBytes + + +class MoonvalleyPromptResponse(BaseModel): + error: Optional[Dict[str, Any]] = None + frame_conditioning: Optional[Dict[str, Any]] = None + id: Optional[str] = None + inference_params: Optional[Dict[str, Any]] = None + meta: Optional[Dict[str, Any]] = None + model_params: Optional[Dict[str, Any]] = None + output_url: Optional[str] = None + prompt_text: Optional[str] = None + status: Optional[str] = None + + +class MoonvalleyTextToVideoInferenceParams(BaseModel): + add_quality_guidance: Optional[bool] = Field( + True, description='Whether to add quality guidance' + ) + caching_coefficient: Optional[float] = Field( + 0.3, description='Caching coefficient for optimization' + ) + caching_cooldown: Optional[int] = Field( + 3, description='Number of caching cooldown steps' + ) + caching_warmup: Optional[int] = Field( + 3, description='Number of caching warmup steps' + ) + clip_value: Optional[float] = Field( + 3, description='CLIP value for generation control' + ) + conditioning_frame_index: Optional[int] = Field( + 0, description='Index of the conditioning frame' + ) + cooldown_steps: Optional[int] = Field( + 75, description='Number of cooldown steps (calculated based on num_frames)' + ) + fps: Optional[int] = Field( + 24, description='Frames per second of the generated video' + ) + guidance_scale: Optional[float] = Field( + 10, description='Guidance scale for generation control' + ) + height: Optional[int] = Field( + 1080, description='Height of the generated video in pixels' + ) + negative_prompt: Optional[str] = Field(None, description='Negative prompt text') + num_frames: Optional[int] = Field(64, description='Number of frames to generate') + seed: Optional[int] = Field( + None, description='Random seed for generation (default: random)' + ) + shift_value: Optional[float] = Field( + 3, description='Shift value for generation control' + ) + steps: Optional[int] = Field(80, description='Number of denoising steps') + use_guidance_schedule: Optional[bool] = Field( + True, description='Whether to use guidance scheduling' + ) + use_negative_prompts: Optional[bool] = Field( + False, description='Whether to use negative prompts' + ) + use_timestep_transform: Optional[bool] = Field( + True, description='Whether to use timestep transformation' + ) + warmup_steps: Optional[int] = Field( + 0, description='Number of warmup steps (calculated based on num_frames)' + ) + width: Optional[int] = Field( + 1920, description='Width of the generated video in pixels' + ) + + +class MoonvalleyTextToVideoRequest(BaseModel): + image_url: Optional[str] = None + inference_params: Optional[MoonvalleyTextToVideoInferenceParams] = None + prompt_text: Optional[str] = None + webhook_url: Optional[str] = None + + +class MoonvalleyUploadFileRequest(BaseModel): + file: Optional[StrictBytes] = None + + +class MoonvalleyUploadFileResponse(BaseModel): + access_url: Optional[str] = None + + +class MoonvalleyVideoToVideoInferenceParams(BaseModel): + add_quality_guidance: Optional[bool] = Field( + True, description='Whether to add quality guidance' + ) + caching_coefficient: Optional[float] = Field( + 0.3, description='Caching coefficient for optimization' + ) + caching_cooldown: Optional[int] = Field( + 3, description='Number of caching cooldown steps' + ) + caching_warmup: Optional[int] = Field( + 3, description='Number of caching warmup steps' + ) + clip_value: Optional[float] = Field( + 3, description='CLIP value for generation control' + ) + conditioning_frame_index: Optional[int] = Field( + 0, description='Index of the conditioning frame' + ) + cooldown_steps: Optional[int] = Field( + 36, description='Number of cooldown steps (calculated based on num_frames)' + ) + guidance_scale: Optional[float] = Field( + 15, description='Guidance scale for generation control' + ) + negative_prompt: Optional[str] = Field(None, description='Negative prompt text') + seed: Optional[int] = Field( + None, description='Random seed for generation (default: random)' + ) + shift_value: Optional[float] = Field( + 3, description='Shift value for generation control' + ) + steps: Optional[int] = Field(80, description='Number of denoising steps') + use_guidance_schedule: Optional[bool] = Field( + True, description='Whether to use guidance scheduling' + ) + use_negative_prompts: Optional[bool] = Field( + False, description='Whether to use negative prompts' + ) + use_timestep_transform: Optional[bool] = Field( + True, description='Whether to use timestep transformation' + ) + warmup_steps: Optional[int] = Field( + 24, description='Number of warmup steps (calculated based on num_frames)' + ) + + +class ControlType(str, Enum): + motion_control = 'motion_control' + pose_control = 'pose_control' + + +class MoonvalleyVideoToVideoRequest(BaseModel): + control_type: ControlType = Field( + ..., description='Supported types for video control' + ) + inference_params: Optional[MoonvalleyVideoToVideoInferenceParams] = None + prompt_text: str = Field(..., description='Describes the video to generate') + video_url: str = Field(..., description='Url to control video') + webhook_url: Optional[str] = Field( + None, description='Optional webhook URL for notifications' + ) diff --git a/comfy_api_nodes/apis/openai.py b/comfy_api_nodes/apis/openai.py new file mode 100644 index 000000000..b85ef252b --- /dev/null +++ b/comfy_api_nodes/apis/openai.py @@ -0,0 +1,170 @@ +from pydantic import BaseModel, Field + + +class Datum2(BaseModel): + b64_json: str | None = Field(None, description="Base64 encoded image data") + revised_prompt: str | None = Field(None, description="Revised prompt") + url: str | None = Field(None, description="URL of the image") + + +class InputTokensDetails(BaseModel): + image_tokens: int | None = Field(None) + text_tokens: int | None = Field(None) + + +class Usage(BaseModel): + input_tokens: int | None = Field(None) + input_tokens_details: InputTokensDetails | None = Field(None) + output_tokens: int | None = Field(None) + total_tokens: int | None = Field(None) + + +class OpenAIImageGenerationResponse(BaseModel): + data: list[Datum2] | None = Field(None) + usage: Usage | None = Field(None) + + +class OpenAIImageEditRequest(BaseModel): + background: str | None = Field(None, description="Background transparency") + model: str = Field(...) + moderation: str | None = Field(None) + n: int | None = Field(None, description="The number of images to generate") + output_compression: int | None = Field(None, description="Compression level for JPEG or WebP (0-100)") + output_format: str | None = Field(None) + prompt: str = Field(...) + quality: str | None = Field(None, description="Size of the image (e.g., 1024x1024, 1536x1024, auto)") + size: str | None = Field(None, description="Size of the output image") + + +class OpenAIImageGenerationRequest(BaseModel): + background: str | None = Field(None, description="Background transparency") + model: str | None = Field(None) + moderation: str | None = Field(None) + n: int | None = Field( + None, + description="The number of images to generate.", + ) + output_compression: int | None = Field(None, description="Compression level for JPEG or WebP (0-100)") + output_format: str | None = Field(None) + prompt: str = Field(...) + quality: str | None = Field(None, description="The quality of the generated image") + size: str | None = Field(None, description="Size of the image (e.g., 1024x1024, 1536x1024, auto)") + style: str | None = Field(None, description="Style of the image (only for dall-e-3)") + + +class ModelResponseProperties(BaseModel): + instructions: str | None = Field(None) + max_output_tokens: int | None = Field(None) + model: str | None = Field(None) + temperature: float | None = Field(1, description="Controls randomness in the response", ge=0.0, le=2.0) + top_p: float | None = Field( + 1, + description="Controls diversity of the response via nucleus sampling", + ge=0.0, + le=1.0, + ) + truncation: str | None = Field("disabled", description="Allowed values: 'auto' or 'disabled'") + + +class ResponseProperties(BaseModel): + instructions: str | None = Field(None) + max_output_tokens: int | None = Field(None) + model: str | None = Field(None) + previous_response_id: str | None = Field(None) + truncation: str | None = Field("disabled", description="Allowed values: 'auto' or 'disabled'") + + +class ResponseError(BaseModel): + code: str = Field(...) + message: str = Field(...) + + +class OutputTokensDetails(BaseModel): + reasoning_tokens: int = Field(..., description="The number of reasoning tokens.") + + +class CachedTokensDetails(BaseModel): + cached_tokens: int = Field( + ..., + description="The number of tokens that were retrieved from the cache.", + ) + + +class ResponseUsage(BaseModel): + input_tokens: int = Field(..., description="The number of input tokens.") + input_tokens_details: CachedTokensDetails = Field(...) + output_tokens: int = Field(..., description="The number of output tokens.") + output_tokens_details: OutputTokensDetails = Field(...) + total_tokens: int = Field(..., description="The total number of tokens used.") + + +class InputTextContent(BaseModel): + text: str = Field(..., description="The text input to the model.") + type: str = Field("input_text") + + +class OutputContent(BaseModel): + type: str = Field(..., description="The type of output content") + text: str | None = Field(None, description="The text content") + data: str | None = Field(None, description="Base64-encoded audio data") + transcript: str | None = Field(None, description="Transcript of the audio") + + +class OutputMessage(BaseModel): + type: str = Field(..., description="The type of output item") + content: list[OutputContent] | None = Field(None, description="The content of the message") + role: str | None = Field(None, description="The role of the message") + + +class OpenAIResponse(ModelResponseProperties, ResponseProperties): + created_at: float | None = Field( + None, + description="Unix timestamp (in seconds) of when this Response was created.", + ) + error: ResponseError | None = Field(None) + id: str | None = Field(None, description="Unique identifier for this Response.") + object: str | None = Field(None, description="The object type of this resource - always set to `response`.") + output: list[OutputMessage] | None = Field(None) + parallel_tool_calls: bool | None = Field(True) + status: str | None = Field( + None, + description="One of `completed`, `failed`, `in_progress`, or `incomplete`.", + ) + usage: ResponseUsage | None = Field(None) + + +class InputImageContent(BaseModel): + detail: str = Field(..., description="One of `high`, `low`, or `auto`. Defaults to `auto`.") + file_id: str | None = Field(None) + image_url: str | None = Field(None) + type: str = Field(..., description="The type of the input item. Always `input_image`.") + + +class InputFileContent(BaseModel): + file_data: str | None = Field(None) + file_id: str | None = Field(None) + filename: str | None = Field(None, description="The name of the file to be sent to the model.") + type: str = Field(..., description="The type of the input item. Always `input_file`.") + + +class InputMessage(BaseModel): + content: list[InputTextContent | InputImageContent | InputFileContent] = Field( + ..., + description="A list of one or many input items to the model, containing different content types.", + ) + role: str | None = Field(None) + type: str | None = Field(None) + + +class OpenAICreateResponse(ModelResponseProperties, ResponseProperties): + include: str | None = Field(None) + input: list[InputMessage] = Field(...) + parallel_tool_calls: bool | None = Field( + True, description="Whether to allow the model to run tool calls in parallel." + ) + store: bool | None = Field( + True, + description="Whether to store the generated model response for later retrieval via API.", + ) + stream: bool | None = Field(False) + usage: ResponseUsage | None = Field(None) diff --git a/comfy_api_nodes/apis/openai_api.py b/comfy_api_nodes/apis/openai_api.py deleted file mode 100644 index ae5bb2673..000000000 --- a/comfy_api_nodes/apis/openai_api.py +++ /dev/null @@ -1,52 +0,0 @@ -from pydantic import BaseModel, Field - - -class Datum2(BaseModel): - b64_json: str | None = Field(None, description="Base64 encoded image data") - revised_prompt: str | None = Field(None, description="Revised prompt") - url: str | None = Field(None, description="URL of the image") - - -class InputTokensDetails(BaseModel): - image_tokens: int | None = None - text_tokens: int | None = None - - -class Usage(BaseModel): - input_tokens: int | None = None - input_tokens_details: InputTokensDetails | None = None - output_tokens: int | None = None - total_tokens: int | None = None - - -class OpenAIImageGenerationResponse(BaseModel): - data: list[Datum2] | None = None - usage: Usage | None = None - - -class OpenAIImageEditRequest(BaseModel): - background: str | None = Field(None, description="Background transparency") - model: str = Field(...) - moderation: str | None = Field(None) - n: int | None = Field(None, description="The number of images to generate") - output_compression: int | None = Field(None, description="Compression level for JPEG or WebP (0-100)") - output_format: str | None = Field(None) - prompt: str = Field(...) - quality: str | None = Field(None, description="Size of the image (e.g., 1024x1024, 1536x1024, auto)") - size: str | None = Field(None, description="Size of the output image") - - -class OpenAIImageGenerationRequest(BaseModel): - background: str | None = Field(None, description="Background transparency") - model: str | None = Field(None) - moderation: str | None = Field(None) - n: int | None = Field( - None, - description="The number of images to generate.", - ) - output_compression: int | None = Field(None, description="Compression level for JPEG or WebP (0-100)") - output_format: str | None = Field(None) - prompt: str = Field(...) - quality: str | None = Field(None, description="The quality of the generated image") - size: str | None = Field(None, description="Size of the image (e.g., 1024x1024, 1536x1024, auto)") - style: str | None = Field(None, description="Style of the image (only for dall-e-3)") diff --git a/comfy_api_nodes/apis/pixverse_api.py b/comfy_api_nodes/apis/pixverse.py similarity index 100% rename from comfy_api_nodes/apis/pixverse_api.py rename to comfy_api_nodes/apis/pixverse.py diff --git a/comfy_api_nodes/apis/recraft_api.py b/comfy_api_nodes/apis/recraft.py similarity index 100% rename from comfy_api_nodes/apis/recraft_api.py rename to comfy_api_nodes/apis/recraft.py diff --git a/comfy_api_nodes/apis/rodin_api.py b/comfy_api_nodes/apis/rodin.py similarity index 100% rename from comfy_api_nodes/apis/rodin_api.py rename to comfy_api_nodes/apis/rodin.py diff --git a/comfy_api_nodes/apis/runway.py b/comfy_api_nodes/apis/runway.py new file mode 100644 index 000000000..df6f2b845 --- /dev/null +++ b/comfy_api_nodes/apis/runway.py @@ -0,0 +1,127 @@ +from enum import Enum +from typing import Optional, List, Union +from datetime import datetime + +from pydantic import BaseModel, Field, RootModel + + +class RunwayAspectRatioEnum(str, Enum): + field_1280_720 = '1280:720' + field_720_1280 = '720:1280' + field_1104_832 = '1104:832' + field_832_1104 = '832:1104' + field_960_960 = '960:960' + field_1584_672 = '1584:672' + field_1280_768 = '1280:768' + field_768_1280 = '768:1280' + + +class Position(str, Enum): + first = 'first' + last = 'last' + + +class RunwayPromptImageDetailedObject(BaseModel): + position: Position = Field( + ..., + description="The position of the image in the output video. 'last' is currently supported for gen3a_turbo only.", + ) + uri: str = Field( + ..., description='A HTTPS URL or data URI containing an encoded image.' + ) + + +class RunwayPromptImageObject( + RootModel[Union[str, List[RunwayPromptImageDetailedObject]]] +): + root: Union[str, List[RunwayPromptImageDetailedObject]] = Field( + ..., + description='Image(s) to use for the video generation. Can be a single URI or an array of image objects with positions.', + ) + + +class RunwayModelEnum(str, Enum): + gen4_turbo = 'gen4_turbo' + gen3a_turbo = 'gen3a_turbo' + + +class RunwayDurationEnum(int, Enum): + integer_5 = 5 + integer_10 = 10 + + +class RunwayImageToVideoRequest(BaseModel): + duration: RunwayDurationEnum + model: RunwayModelEnum + promptImage: RunwayPromptImageObject + promptText: Optional[str] = Field( + None, description='Text prompt for the generation', max_length=1000 + ) + ratio: RunwayAspectRatioEnum + seed: int = Field( + ..., description='Random seed for generation', ge=0, le=4294967295 + ) + + +class RunwayImageToVideoResponse(BaseModel): + id: Optional[str] = Field(None, description='Task ID') + + +class RunwayTaskStatusEnum(str, Enum): + SUCCEEDED = 'SUCCEEDED' + RUNNING = 'RUNNING' + FAILED = 'FAILED' + PENDING = 'PENDING' + CANCELLED = 'CANCELLED' + THROTTLED = 'THROTTLED' + + +class RunwayTaskStatusResponse(BaseModel): + createdAt: datetime = Field(..., description='Task creation timestamp') + id: str = Field(..., description='Task ID') + output: Optional[List[str]] = Field(None, description='Array of output video URLs') + progress: Optional[float] = Field( + None, + description='Float value between 0 and 1 representing the progress of the task. Only available if status is RUNNING.', + ge=0.0, + le=1.0, + ) + status: RunwayTaskStatusEnum + + +class Model4(str, Enum): + gen4_image = 'gen4_image' + + +class ReferenceImage(BaseModel): + uri: Optional[str] = Field( + None, description='A HTTPS URL or data URI containing an encoded image' + ) + + +class RunwayTextToImageAspectRatioEnum(str, Enum): + field_1920_1080 = '1920:1080' + field_1080_1920 = '1080:1920' + field_1024_1024 = '1024:1024' + field_1360_768 = '1360:768' + field_1080_1080 = '1080:1080' + field_1168_880 = '1168:880' + field_1440_1080 = '1440:1080' + field_1080_1440 = '1080:1440' + field_1808_768 = '1808:768' + field_2112_912 = '2112:912' + + +class RunwayTextToImageRequest(BaseModel): + model: Model4 = Field(..., description='Model to use for generation') + promptText: str = Field( + ..., description='Text prompt for the image generation', max_length=1000 + ) + ratio: RunwayTextToImageAspectRatioEnum + referenceImages: Optional[List[ReferenceImage]] = Field( + None, description='Array of reference images to guide the generation' + ) + + +class RunwayTextToImageResponse(BaseModel): + id: Optional[str] = Field(None, description='Task ID') diff --git a/comfy_api_nodes/apis/stability_api.py b/comfy_api_nodes/apis/stability.py similarity index 100% rename from comfy_api_nodes/apis/stability_api.py rename to comfy_api_nodes/apis/stability.py diff --git a/comfy_api_nodes/apis/topaz_api.py b/comfy_api_nodes/apis/topaz.py similarity index 97% rename from comfy_api_nodes/apis/topaz_api.py rename to comfy_api_nodes/apis/topaz.py index 4d9e62e72..a9e6235a7 100644 --- a/comfy_api_nodes/apis/topaz_api.py +++ b/comfy_api_nodes/apis/topaz.py @@ -41,7 +41,7 @@ class Resolution(BaseModel): height: int = Field(...) -class CreateCreateVideoRequestSource(BaseModel): +class CreateVideoRequestSource(BaseModel): container: str = Field(...) size: int = Field(..., description="Size of the video file in bytes") duration: int = Field(..., description="Duration of the video file in seconds") @@ -89,7 +89,7 @@ class Overrides(BaseModel): class CreateVideoRequest(BaseModel): - source: CreateCreateVideoRequestSource = Field(...) + source: CreateVideoRequestSource = Field(...) filters: list[Union[VideoFrameInterpolationFilter, VideoEnhancementFilter]] = Field(...) output: OutputInformationVideo = Field(...) overrides: Overrides = Field(Overrides(isPaidDiffusion=True)) diff --git a/comfy_api_nodes/apis/tripo_api.py b/comfy_api_nodes/apis/tripo.py similarity index 100% rename from comfy_api_nodes/apis/tripo_api.py rename to comfy_api_nodes/apis/tripo.py diff --git a/comfy_api_nodes/apis/veo_api.py b/comfy_api_nodes/apis/veo.py similarity index 100% rename from comfy_api_nodes/apis/veo_api.py rename to comfy_api_nodes/apis/veo.py diff --git a/comfy_api_nodes/mapper_utils.py b/comfy_api_nodes/mapper_utils.py deleted file mode 100644 index 6fab8f4bb..000000000 --- a/comfy_api_nodes/mapper_utils.py +++ /dev/null @@ -1,116 +0,0 @@ -from enum import Enum - -from pydantic.fields import FieldInfo -from pydantic import BaseModel -from pydantic_core import PydanticUndefined - -from comfy.comfy_types.node_typing import IO, InputTypeOptions - -NodeInput = tuple[IO, InputTypeOptions] - - -def _create_base_config(field_info: FieldInfo) -> InputTypeOptions: - config = {} - if hasattr(field_info, "default") and field_info.default is not PydanticUndefined: - config["default"] = field_info.default - if hasattr(field_info, "description") and field_info.description is not None: - config["tooltip"] = field_info.description - return config - - -def _get_number_constraints_config(field_info: FieldInfo) -> dict: - config = {} - if hasattr(field_info, "metadata"): - metadata = field_info.metadata - for constraint in metadata: - if hasattr(constraint, "ge"): - config["min"] = constraint.ge - if hasattr(constraint, "le"): - config["max"] = constraint.le - if hasattr(constraint, "multiple_of"): - config["step"] = constraint.multiple_of - return config - - -def _model_field_to_image_input(field_info: FieldInfo, **kwargs) -> NodeInput: - return IO.IMAGE, { - **_create_base_config(field_info), - **kwargs, - } - - -def _model_field_to_string_input(field_info: FieldInfo, **kwargs) -> NodeInput: - return IO.STRING, { - **_create_base_config(field_info), - **kwargs, - } - - -def _model_field_to_float_input(field_info: FieldInfo, **kwargs) -> NodeInput: - return IO.FLOAT, { - **_create_base_config(field_info), - **_get_number_constraints_config(field_info), - **kwargs, - } - - -def _model_field_to_int_input(field_info: FieldInfo, **kwargs) -> NodeInput: - return IO.INT, { - **_create_base_config(field_info), - **_get_number_constraints_config(field_info), - **kwargs, - } - - -def _model_field_to_combo_input( - field_info: FieldInfo, enum_type: type[Enum] = None, **kwargs -) -> NodeInput: - combo_config = {} - if enum_type is not None: - combo_config["options"] = [option.value for option in enum_type] - combo_config = { - **combo_config, - **_create_base_config(field_info), - **kwargs, - } - return IO.COMBO, combo_config - - -def model_field_to_node_input( - input_type: IO, base_model: type[BaseModel], field_name: str, **kwargs -) -> NodeInput: - """ - Maps a field from a Pydantic model to a Comfy node input. - - Args: - input_type: The type of the input. - base_model: The Pydantic model to map the field from. - field_name: The name of the field to map. - **kwargs: Additional key/values to include in the input options. - - Note: - For combo inputs, pass an `Enum` to the `enum_type` keyword argument to populate the options automatically. - - Example: - >>> model_field_to_node_input(IO.STRING, MyModel, "my_field", multiline=True) - >>> model_field_to_node_input(IO.COMBO, MyModel, "my_field", enum_type=MyEnum) - >>> model_field_to_node_input(IO.FLOAT, MyModel, "my_field", slider=True) - """ - field_info: FieldInfo = base_model.model_fields[field_name] - result: NodeInput - - if input_type == IO.IMAGE: - result = _model_field_to_image_input(field_info, **kwargs) - elif input_type == IO.STRING: - result = _model_field_to_string_input(field_info, **kwargs) - elif input_type == IO.FLOAT: - result = _model_field_to_float_input(field_info, **kwargs) - elif input_type == IO.INT: - result = _model_field_to_int_input(field_info, **kwargs) - elif input_type == IO.COMBO: - result = _model_field_to_combo_input(field_info, **kwargs) - else: - message = f"Invalid input type: {input_type}" - raise ValueError(message) - - return result diff --git a/comfy_api_nodes/nodes_bfl.py b/comfy_api_nodes/nodes_bfl.py index 76021ef7f..61c3b4503 100644 --- a/comfy_api_nodes/nodes_bfl.py +++ b/comfy_api_nodes/nodes_bfl.py @@ -3,7 +3,7 @@ from pydantic import BaseModel from typing_extensions import override from comfy_api.latest import IO, ComfyExtension, Input -from comfy_api_nodes.apis.bfl_api import ( +from comfy_api_nodes.apis.bfl import ( BFLFluxExpandImageRequest, BFLFluxFillImageRequest, BFLFluxKontextProGenerateRequest, diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py index 9cb1ca004..486801150 100644 --- a/comfy_api_nodes/nodes_bytedance.py +++ b/comfy_api_nodes/nodes_bytedance.py @@ -5,7 +5,7 @@ import torch from typing_extensions import override from comfy_api.latest import IO, ComfyExtension, Input -from comfy_api_nodes.apis.bytedance_api import ( +from comfy_api_nodes.apis.bytedance import ( RECOMMENDED_PRESETS, RECOMMENDED_PRESETS_SEEDREAM_4, VIDEO_TASKS_EXECUTION_TIME, diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index a2daea50a..3b31caa7b 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -14,7 +14,7 @@ from typing_extensions import override import folder_paths from comfy_api.latest import IO, ComfyExtension, Input, Types -from comfy_api_nodes.apis.gemini_api import ( +from comfy_api_nodes.apis.gemini import ( GeminiContent, GeminiFileData, GeminiGenerateContentRequest, diff --git a/comfy_api_nodes/nodes_ideogram.py b/comfy_api_nodes/nodes_ideogram.py index 827b3523a..feaf7a858 100644 --- a/comfy_api_nodes/nodes_ideogram.py +++ b/comfy_api_nodes/nodes_ideogram.py @@ -4,7 +4,7 @@ from comfy_api.latest import IO, ComfyExtension from PIL import Image import numpy as np import torch -from comfy_api_nodes.apis import ( +from comfy_api_nodes.apis.ideogram import ( IdeogramGenerateRequest, IdeogramGenerateResponse, ImageRequest, diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 05dde88b1..3ec71530b 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -49,7 +49,7 @@ from comfy_api_nodes.apis import ( KlingCharacterEffectModelName, KlingSingleImageEffectModelName, ) -from comfy_api_nodes.apis.kling_api import ( +from comfy_api_nodes.apis.kling import ( ImageToVideoWithAudioRequest, MotionControlRequest, OmniImageParamImage, diff --git a/comfy_api_nodes/nodes_luma.py b/comfy_api_nodes/nodes_luma.py index 95cb442e5..9ed6cd299 100644 --- a/comfy_api_nodes/nodes_luma.py +++ b/comfy_api_nodes/nodes_luma.py @@ -4,7 +4,7 @@ import torch from typing_extensions import override from comfy_api.latest import IO, ComfyExtension -from comfy_api_nodes.apis.luma_api import ( +from comfy_api_nodes.apis.luma import ( LumaAspectRatio, LumaCharacterRef, LumaConceptChain, diff --git a/comfy_api_nodes/nodes_minimax.py b/comfy_api_nodes/nodes_minimax.py index 43a15d50d..b5d0b461f 100644 --- a/comfy_api_nodes/nodes_minimax.py +++ b/comfy_api_nodes/nodes_minimax.py @@ -4,7 +4,7 @@ import torch from typing_extensions import override from comfy_api.latest import IO, ComfyExtension -from comfy_api_nodes.apis.minimax_api import ( +from comfy_api_nodes.apis.minimax import ( MinimaxFileRetrieveResponse, MiniMaxModel, MinimaxTaskResultResponse, diff --git a/comfy_api_nodes/nodes_moonvalley.py b/comfy_api_nodes/nodes_moonvalley.py index 769b171b7..08315fa2b 100644 --- a/comfy_api_nodes/nodes_moonvalley.py +++ b/comfy_api_nodes/nodes_moonvalley.py @@ -3,7 +3,7 @@ import logging from typing_extensions import override from comfy_api.latest import IO, ComfyExtension, Input -from comfy_api_nodes.apis import ( +from comfy_api_nodes.apis.moonvalley import ( MoonvalleyPromptResponse, MoonvalleyTextToVideoInferenceParams, MoonvalleyTextToVideoRequest, diff --git a/comfy_api_nodes/nodes_openai.py b/comfy_api_nodes/nodes_openai.py index 2f144c5c3..a12acc06b 100644 --- a/comfy_api_nodes/nodes_openai.py +++ b/comfy_api_nodes/nodes_openai.py @@ -10,24 +10,18 @@ from typing_extensions import override import folder_paths from comfy_api.latest import IO, ComfyExtension, Input -from comfy_api_nodes.apis import ( - CreateModelResponseProperties, - Detail, - InputContent, +from comfy_api_nodes.apis.openai import ( InputFileContent, InputImageContent, InputMessage, - InputMessageContentList, InputTextContent, - Item, + ModelResponseProperties, OpenAICreateResponse, - OpenAIResponse, - OutputContent, -) -from comfy_api_nodes.apis.openai_api import ( OpenAIImageEditRequest, OpenAIImageGenerationRequest, OpenAIImageGenerationResponse, + OpenAIResponse, + OutputContent, ) from comfy_api_nodes.util import ( ApiEndpoint, @@ -266,7 +260,7 @@ class OpenAIDalle3(IO.ComfyNode): "seed", default=0, min=0, - max=2 ** 31 - 1, + max=2**31 - 1, step=1, display_mode=IO.NumberDisplay.number, control_after_generate=True, @@ -384,7 +378,7 @@ class OpenAIGPTImage1(IO.ComfyNode): "seed", default=0, min=0, - max=2 ** 31 - 1, + max=2**31 - 1, step=1, display_mode=IO.NumberDisplay.number, control_after_generate=True, @@ -500,8 +494,8 @@ class OpenAIGPTImage1(IO.ComfyNode): files = [] batch_size = image.shape[0] for i in range(batch_size): - single_image = image[i: i + 1] - scaled_image = downscale_image_tensor(single_image, total_pixels=2048*2048).squeeze() + single_image = image[i : i + 1] + scaled_image = downscale_image_tensor(single_image, total_pixels=2048 * 2048).squeeze() image_np = (scaled_image.numpy() * 255).astype(np.uint8) img = Image.fromarray(image_np) @@ -523,7 +517,7 @@ class OpenAIGPTImage1(IO.ComfyNode): rgba_mask = torch.zeros(height, width, 4, device="cpu") rgba_mask[:, :, 3] = 1 - mask.squeeze().cpu() - scaled_mask = downscale_image_tensor(rgba_mask.unsqueeze(0), total_pixels=2048*2048).squeeze() + scaled_mask = downscale_image_tensor(rgba_mask.unsqueeze(0), total_pixels=2048 * 2048).squeeze() mask_np = (scaled_mask.numpy() * 255).astype(np.uint8) mask_img = Image.fromarray(mask_np) @@ -696,29 +690,23 @@ class OpenAIChatNode(IO.ComfyNode): ) @classmethod - def get_message_content_from_response( - cls, response: OpenAIResponse - ) -> list[OutputContent]: + def get_message_content_from_response(cls, response: OpenAIResponse) -> list[OutputContent]: """Extract message content from the API response.""" for output in response.output: - if output.root.type == "message": - return output.root.content + if output.type == "message": + return output.content raise TypeError("No output message found in response") @classmethod - def get_text_from_message_content( - cls, message_content: list[OutputContent] - ) -> str: + def get_text_from_message_content(cls, message_content: list[OutputContent]) -> str: """Extract text content from message content.""" for content_item in message_content: - if content_item.root.type == "output_text": - return str(content_item.root.text) + if content_item.type == "output_text": + return str(content_item.text) return "No text output found in response" @classmethod - def tensor_to_input_image_content( - cls, image: torch.Tensor, detail_level: Detail = "auto" - ) -> InputImageContent: + def tensor_to_input_image_content(cls, image: torch.Tensor, detail_level: str = "auto") -> InputImageContent: """Convert a tensor to an input image content object.""" return InputImageContent( detail=detail_level, @@ -732,9 +720,9 @@ class OpenAIChatNode(IO.ComfyNode): prompt: str, image: torch.Tensor | None = None, files: list[InputFileContent] | None = None, - ) -> InputMessageContentList: + ) -> list[InputTextContent | InputImageContent | InputFileContent]: """Create a list of input message contents from prompt and optional image.""" - content_list: list[InputContent | InputTextContent | InputImageContent | InputFileContent] = [ + content_list: list[InputTextContent | InputImageContent | InputFileContent] = [ InputTextContent(text=prompt, type="input_text"), ] if image is not None: @@ -746,13 +734,9 @@ class OpenAIChatNode(IO.ComfyNode): type="input_image", ) ) - if files is not None: content_list.extend(files) - - return InputMessageContentList( - root=content_list, - ) + return content_list @classmethod async def execute( @@ -762,7 +746,7 @@ class OpenAIChatNode(IO.ComfyNode): model: SupportedOpenAIModel = SupportedOpenAIModel.gpt_5.value, images: torch.Tensor | None = None, files: list[InputFileContent] | None = None, - advanced_options: CreateModelResponseProperties | None = None, + advanced_options: ModelResponseProperties | None = None, ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False) @@ -773,36 +757,28 @@ class OpenAIChatNode(IO.ComfyNode): response_model=OpenAIResponse, data=OpenAICreateResponse( input=[ - Item( - root=InputMessage( - content=cls.create_input_message_contents( - prompt, images, files - ), - role="user", - ) + InputMessage( + content=cls.create_input_message_contents(prompt, images, files), + role="user", ), ], store=True, stream=False, model=model, previous_response_id=None, - **( - advanced_options.model_dump(exclude_none=True) - if advanced_options - else {} - ), + **(advanced_options.model_dump(exclude_none=True) if advanced_options else {}), ), ) response_id = create_response.id # Get result output result_response = await poll_op( - cls, - ApiEndpoint(path=f"{RESPONSES_ENDPOINT}/{response_id}"), - response_model=OpenAIResponse, - status_extractor=lambda response: response.status, - completed_statuses=["incomplete", "completed"] - ) + cls, + ApiEndpoint(path=f"{RESPONSES_ENDPOINT}/{response_id}"), + response_model=OpenAIResponse, + status_extractor=lambda response: response.status, + completed_statuses=["incomplete", "completed"], + ) return IO.NodeOutput(cls.get_text_from_message_content(cls.get_message_content_from_response(result_response))) @@ -923,7 +899,7 @@ class OpenAIChatConfig(IO.ComfyNode): remove depending on model choice. """ return IO.NodeOutput( - CreateModelResponseProperties( + ModelResponseProperties( instructions=instructions, truncation=truncation, max_output_tokens=max_output_tokens, diff --git a/comfy_api_nodes/nodes_pixverse.py b/comfy_api_nodes/nodes_pixverse.py index 86ddb3ab9..e17a24ae7 100644 --- a/comfy_api_nodes/nodes_pixverse.py +++ b/comfy_api_nodes/nodes_pixverse.py @@ -1,7 +1,7 @@ import torch from typing_extensions import override from comfy_api.latest import IO, ComfyExtension -from comfy_api_nodes.apis.pixverse_api import ( +from comfy_api_nodes.apis.pixverse import ( PixverseTextVideoRequest, PixverseImageVideoRequest, PixverseTransitionVideoRequest, diff --git a/comfy_api_nodes/nodes_recraft.py b/comfy_api_nodes/nodes_recraft.py index 05dc151ad..c01bcaece 100644 --- a/comfy_api_nodes/nodes_recraft.py +++ b/comfy_api_nodes/nodes_recraft.py @@ -8,7 +8,7 @@ from typing_extensions import override from comfy.utils import ProgressBar from comfy_api.latest import IO, ComfyExtension -from comfy_api_nodes.apis.recraft_api import ( +from comfy_api_nodes.apis.recraft import ( RecraftColor, RecraftColorChain, RecraftControls, diff --git a/comfy_api_nodes/nodes_rodin.py b/comfy_api_nodes/nodes_rodin.py index b4420cb93..3ffdc8b90 100644 --- a/comfy_api_nodes/nodes_rodin.py +++ b/comfy_api_nodes/nodes_rodin.py @@ -14,7 +14,7 @@ from typing import Optional from io import BytesIO from typing_extensions import override from PIL import Image -from comfy_api_nodes.apis.rodin_api import ( +from comfy_api_nodes.apis.rodin import ( Rodin3DGenerateRequest, Rodin3DGenerateResponse, Rodin3DCheckStatusRequest, diff --git a/comfy_api_nodes/nodes_runway.py b/comfy_api_nodes/nodes_runway.py index d19fdb365..573170ba2 100644 --- a/comfy_api_nodes/nodes_runway.py +++ b/comfy_api_nodes/nodes_runway.py @@ -16,7 +16,7 @@ from enum import Enum from typing_extensions import override from comfy_api.latest import IO, ComfyExtension, Input, InputImpl -from comfy_api_nodes.apis import ( +from comfy_api_nodes.apis.runway import ( RunwayImageToVideoRequest, RunwayImageToVideoResponse, RunwayTaskStatusResponse as TaskStatusResponse, diff --git a/comfy_api_nodes/nodes_stability.py b/comfy_api_nodes/nodes_stability.py index 5c48c1f1e..5665109cf 100644 --- a/comfy_api_nodes/nodes_stability.py +++ b/comfy_api_nodes/nodes_stability.py @@ -3,7 +3,7 @@ from typing import Optional from typing_extensions import override from comfy_api.latest import ComfyExtension, Input, IO -from comfy_api_nodes.apis.stability_api import ( +from comfy_api_nodes.apis.stability import ( StabilityUpscaleConservativeRequest, StabilityUpscaleCreativeRequest, StabilityAsyncResponse, diff --git a/comfy_api_nodes/nodes_topaz.py b/comfy_api_nodes/nodes_topaz.py index 9dc5f45bc..c052e7656 100644 --- a/comfy_api_nodes/nodes_topaz.py +++ b/comfy_api_nodes/nodes_topaz.py @@ -5,7 +5,24 @@ import aiohttp from typing_extensions import override from comfy_api.latest import IO, ComfyExtension, Input -from comfy_api_nodes.apis import topaz_api +from comfy_api_nodes.apis.topaz import ( + CreateVideoRequest, + CreateVideoRequestSource, + CreateVideoResponse, + ImageAsyncTaskResponse, + ImageDownloadResponse, + ImageEnhanceRequest, + ImageStatusResponse, + OutputInformationVideo, + Resolution, + VideoAcceptResponse, + VideoCompleteUploadRequest, + VideoCompleteUploadRequestPart, + VideoCompleteUploadResponse, + VideoEnhancementFilter, + VideoFrameInterpolationFilter, + VideoStatusResponse, +) from comfy_api_nodes.util import ( ApiEndpoint, download_url_to_image_tensor, @@ -153,13 +170,13 @@ class TopazImageEnhance(IO.ComfyNode): if get_number_of_images(image) != 1: raise ValueError("Only one input image is supported.") download_url = await upload_images_to_comfyapi( - cls, image, max_images=1, mime_type="image/png", total_pixels=4096*4096 + cls, image, max_images=1, mime_type="image/png", total_pixels=4096 * 4096 ) initial_response = await sync_op( cls, ApiEndpoint(path="/proxy/topaz/image/v1/enhance-gen/async", method="POST"), - response_model=topaz_api.ImageAsyncTaskResponse, - data=topaz_api.ImageEnhanceRequest( + response_model=ImageAsyncTaskResponse, + data=ImageEnhanceRequest( model=model, prompt=prompt, subject_detection=subject_detection, @@ -181,7 +198,7 @@ class TopazImageEnhance(IO.ComfyNode): await poll_op( cls, poll_endpoint=ApiEndpoint(path=f"/proxy/topaz/image/v1/status/{initial_response.process_id}"), - response_model=topaz_api.ImageStatusResponse, + response_model=ImageStatusResponse, status_extractor=lambda x: x.status, progress_extractor=lambda x: getattr(x, "progress", 0), price_extractor=lambda x: x.credits * 0.08, @@ -193,7 +210,7 @@ class TopazImageEnhance(IO.ComfyNode): results = await sync_op( cls, ApiEndpoint(path=f"/proxy/topaz/image/v1/download/{initial_response.process_id}"), - response_model=topaz_api.ImageDownloadResponse, + response_model=ImageDownloadResponse, monitor_progress=False, ) return IO.NodeOutput(await download_url_to_image_tensor(results.download_url)) @@ -331,7 +348,7 @@ class TopazVideoEnhance(IO.ComfyNode): if target_height % 2 != 0: target_height += 1 filters.append( - topaz_api.VideoEnhancementFilter( + VideoEnhancementFilter( model=UPSCALER_MODELS_MAP[upscaler_model], creativity=(upscaler_creativity if UPSCALER_MODELS_MAP[upscaler_model] == "slc-1" else None), isOptimizedMode=(True if UPSCALER_MODELS_MAP[upscaler_model] == "slc-1" else None), @@ -340,7 +357,7 @@ class TopazVideoEnhance(IO.ComfyNode): if interpolation_enabled: target_frame_rate = interpolation_frame_rate filters.append( - topaz_api.VideoFrameInterpolationFilter( + VideoFrameInterpolationFilter( model=interpolation_model, slowmo=interpolation_slowmo, fps=interpolation_frame_rate, @@ -351,19 +368,19 @@ class TopazVideoEnhance(IO.ComfyNode): initial_res = await sync_op( cls, ApiEndpoint(path="/proxy/topaz/video/", method="POST"), - response_model=topaz_api.CreateVideoResponse, - data=topaz_api.CreateVideoRequest( - source=topaz_api.CreateCreateVideoRequestSource( + response_model=CreateVideoResponse, + data=CreateVideoRequest( + source=CreateVideoRequestSource( container="mp4", size=get_fs_object_size(src_video_stream), duration=int(duration_sec), frameCount=video.get_frame_count(), frameRate=src_frame_rate, - resolution=topaz_api.Resolution(width=src_width, height=src_height), + resolution=Resolution(width=src_width, height=src_height), ), filters=filters, - output=topaz_api.OutputInformationVideo( - resolution=topaz_api.Resolution(width=target_width, height=target_height), + output=OutputInformationVideo( + resolution=Resolution(width=target_width, height=target_height), frameRate=target_frame_rate, audioCodec="AAC", audioTransfer="Copy", @@ -379,7 +396,7 @@ class TopazVideoEnhance(IO.ComfyNode): path=f"/proxy/topaz/video/{initial_res.requestId}/accept", method="PATCH", ), - response_model=topaz_api.VideoAcceptResponse, + response_model=VideoAcceptResponse, wait_label="Preparing upload", final_label_on_success="Upload started", ) @@ -402,10 +419,10 @@ class TopazVideoEnhance(IO.ComfyNode): path=f"/proxy/topaz/video/{initial_res.requestId}/complete-upload", method="PATCH", ), - response_model=topaz_api.VideoCompleteUploadResponse, - data=topaz_api.VideoCompleteUploadRequest( + response_model=VideoCompleteUploadResponse, + data=VideoCompleteUploadRequest( uploadResults=[ - topaz_api.VideoCompleteUploadRequestPart( + VideoCompleteUploadRequestPart( partNum=1, eTag=upload_etag, ), @@ -417,7 +434,7 @@ class TopazVideoEnhance(IO.ComfyNode): final_response = await poll_op( cls, ApiEndpoint(path=f"/proxy/topaz/video/{initial_res.requestId}/status"), - response_model=topaz_api.VideoStatusResponse, + response_model=VideoStatusResponse, status_extractor=lambda x: x.status, progress_extractor=lambda x: getattr(x, "progress", 0), price_extractor=lambda x: (x.estimates.cost[0] * 0.08 if x.estimates and x.estimates.cost[0] else None), diff --git a/comfy_api_nodes/nodes_tripo.py b/comfy_api_nodes/nodes_tripo.py index aa790143d..5abf27b4d 100644 --- a/comfy_api_nodes/nodes_tripo.py +++ b/comfy_api_nodes/nodes_tripo.py @@ -5,7 +5,7 @@ import torch from typing_extensions import override from comfy_api.latest import IO, ComfyExtension -from comfy_api_nodes.apis.tripo_api import ( +from comfy_api_nodes.apis.tripo import ( TripoAnimateRetargetRequest, TripoAnimateRigRequest, TripoConvertModelRequest, diff --git a/comfy_api_nodes/nodes_veo2.py b/comfy_api_nodes/nodes_veo2.py index c14d6ad68..2a202fc3b 100644 --- a/comfy_api_nodes/nodes_veo2.py +++ b/comfy_api_nodes/nodes_veo2.py @@ -4,7 +4,7 @@ from io import BytesIO from typing_extensions import override from comfy_api.latest import IO, ComfyExtension, Input, InputImpl -from comfy_api_nodes.apis.veo_api import ( +from comfy_api_nodes.apis.veo import ( VeoGenVidPollRequest, VeoGenVidPollResponse, VeoGenVidRequest, diff --git a/comfy_api_nodes/redocly-dev.yaml b/comfy_api_nodes/redocly-dev.yaml deleted file mode 100644 index d9e3cab70..000000000 --- a/comfy_api_nodes/redocly-dev.yaml +++ /dev/null @@ -1,10 +0,0 @@ -# This file is used to filter the Comfy Org OpenAPI spec for schemas related to API Nodes. -# This is used for development purposes to generate stubs for unreleased API endpoints. -apis: - filter: - root: openapi.yaml - decorators: - filter-in: - property: tags - value: ['API Nodes'] - matchStrategy: all diff --git a/comfy_api_nodes/redocly.yaml b/comfy_api_nodes/redocly.yaml deleted file mode 100644 index d102345b1..000000000 --- a/comfy_api_nodes/redocly.yaml +++ /dev/null @@ -1,10 +0,0 @@ -# This file is used to filter the Comfy Org OpenAPI spec for schemas related to API Nodes. - -apis: - filter: - root: openapi.yaml - decorators: - filter-in: - property: tags - value: ['API Nodes', 'Released'] - matchStrategy: all diff --git a/tests-unit/comfy_api_nodes_test/mapper_utils_test.py b/tests-unit/comfy_api_nodes_test/mapper_utils_test.py deleted file mode 100644 index 69488f691..000000000 --- a/tests-unit/comfy_api_nodes_test/mapper_utils_test.py +++ /dev/null @@ -1,297 +0,0 @@ -from typing import Optional -from enum import Enum - -from pydantic import BaseModel, Field - -from comfy.comfy_types.node_typing import IO -from comfy_api_nodes.mapper_utils import model_field_to_node_input - - -def test_model_field_to_float_input(): - """Tests mapping a float field with constraints.""" - - class ModelWithFloatField(BaseModel): - cfg_scale: Optional[float] = Field( - default=0.5, - description="Flexibility in video generation", - ge=0.0, - le=1.0, - multiple_of=0.001, - ) - - expected_output = ( - IO.FLOAT, - { - "default": 0.5, - "tooltip": "Flexibility in video generation", - "min": 0.0, - "max": 1.0, - "step": 0.001, - }, - ) - - actual_output = model_field_to_node_input( - IO.FLOAT, ModelWithFloatField, "cfg_scale" - ) - - assert actual_output[0] == expected_output[0] - assert actual_output[1] == expected_output[1] - - -def test_model_field_to_float_input_no_constraints(): - """Tests mapping a float field with no constraints.""" - - class ModelWithFloatField(BaseModel): - cfg_scale: Optional[float] = Field(default=0.5) - - expected_output = ( - IO.FLOAT, - { - "default": 0.5, - }, - ) - - actual_output = model_field_to_node_input( - IO.FLOAT, ModelWithFloatField, "cfg_scale" - ) - - assert actual_output[0] == expected_output[0] - assert actual_output[1] == expected_output[1] - - -def test_model_field_to_int_input(): - """Tests mapping an int field with constraints.""" - - class ModelWithIntField(BaseModel): - num_frames: Optional[int] = Field( - default=10, - description="Number of frames to generate", - ge=1, - le=100, - multiple_of=1, - ) - - expected_output = ( - IO.INT, - { - "default": 10, - "tooltip": "Number of frames to generate", - "min": 1, - "max": 100, - "step": 1, - }, - ) - - actual_output = model_field_to_node_input(IO.INT, ModelWithIntField, "num_frames") - - assert actual_output[0] == expected_output[0] - assert actual_output[1] == expected_output[1] - - -def test_model_field_to_string_input(): - """Tests mapping a string field.""" - - class ModelWithStringField(BaseModel): - prompt: Optional[str] = Field( - default="A beautiful sunset over a calm ocean", - description="A prompt for the video generation", - ) - - expected_output = ( - IO.STRING, - { - "default": "A beautiful sunset over a calm ocean", - "tooltip": "A prompt for the video generation", - }, - ) - - actual_output = model_field_to_node_input(IO.STRING, ModelWithStringField, "prompt") - - assert actual_output[0] == expected_output[0] - assert actual_output[1] == expected_output[1] - - -def test_model_field_to_string_input_multiline(): - """Tests mapping a string field.""" - - class ModelWithStringField(BaseModel): - prompt: Optional[str] = Field( - default="A beautiful sunset over a calm ocean", - description="A prompt for the video generation", - ) - - expected_output = ( - IO.STRING, - { - "default": "A beautiful sunset over a calm ocean", - "tooltip": "A prompt for the video generation", - "multiline": True, - }, - ) - - actual_output = model_field_to_node_input( - IO.STRING, ModelWithStringField, "prompt", multiline=True - ) - - assert actual_output[0] == expected_output[0] - assert actual_output[1] == expected_output[1] - - -def test_model_field_to_combo_input(): - """Tests mapping a combo field.""" - - class MockEnum(str, Enum): - option_1 = "option 1" - option_2 = "option 2" - option_3 = "option 3" - - class ModelWithComboField(BaseModel): - model_name: Optional[MockEnum] = Field("option 1", description="Model Name") - - expected_output = ( - IO.COMBO, - { - "options": ["option 1", "option 2", "option 3"], - "default": "option 1", - "tooltip": "Model Name", - }, - ) - - actual_output = model_field_to_node_input( - IO.COMBO, ModelWithComboField, "model_name", enum_type=MockEnum - ) - - assert actual_output[0] == expected_output[0] - assert actual_output[1] == expected_output[1] - - -def test_model_field_to_combo_input_no_options(): - """Tests mapping a combo field with no options.""" - - class ModelWithComboField(BaseModel): - model_name: Optional[str] = Field(description="Model Name") - - expected_output = ( - IO.COMBO, - { - "tooltip": "Model Name", - }, - ) - - actual_output = model_field_to_node_input( - IO.COMBO, ModelWithComboField, "model_name" - ) - - assert actual_output[0] == expected_output[0] - assert actual_output[1] == expected_output[1] - - -def test_model_field_to_image_input(): - """Tests mapping an image field.""" - - class ModelWithImageField(BaseModel): - image: Optional[str] = Field( - default=None, - description="An image for the video generation", - ) - - expected_output = ( - IO.IMAGE, - { - "default": None, - "tooltip": "An image for the video generation", - }, - ) - - actual_output = model_field_to_node_input(IO.IMAGE, ModelWithImageField, "image") - - assert actual_output[0] == expected_output[0] - assert actual_output[1] == expected_output[1] - - -def test_model_field_to_node_input_no_description(): - """Tests mapping a field with no description.""" - - class ModelWithNoDescriptionField(BaseModel): - field: Optional[str] = Field(default="default value") - - expected_output = ( - IO.STRING, - { - "default": "default value", - }, - ) - - actual_output = model_field_to_node_input( - IO.STRING, ModelWithNoDescriptionField, "field" - ) - - assert actual_output[0] == expected_output[0] - assert actual_output[1] == expected_output[1] - - -def test_model_field_to_node_input_no_default(): - """Tests mapping a field with no default.""" - - class ModelWithNoDefaultField(BaseModel): - field: Optional[str] = Field(description="A field with no default") - - expected_output = ( - IO.STRING, - { - "tooltip": "A field with no default", - }, - ) - - actual_output = model_field_to_node_input( - IO.STRING, ModelWithNoDefaultField, "field" - ) - - assert actual_output[0] == expected_output[0] - assert actual_output[1] == expected_output[1] - - -def test_model_field_to_node_input_no_metadata(): - """Tests mapping a field with no metadata or properties defined on the schema.""" - - class ModelWithNoMetadataField(BaseModel): - field: Optional[str] = Field() - - expected_output = ( - IO.STRING, - {}, - ) - - actual_output = model_field_to_node_input( - IO.STRING, ModelWithNoMetadataField, "field" - ) - - assert actual_output[0] == expected_output[0] - assert actual_output[1] == expected_output[1] - - -def test_model_field_to_node_input_default_is_none(): - """ - Tests mapping a field with a default of `None`. - I.e., the default field should be included as the schema explicitly sets it to `None`. - """ - - class ModelWithNoneDefaultField(BaseModel): - field: Optional[str] = Field( - default=None, description="A field with a default of None" - ) - - expected_output = ( - IO.STRING, - { - "default": None, - "tooltip": "A field with a default of None", - }, - ) - - actual_output = model_field_to_node_input( - IO.STRING, ModelWithNoneDefaultField, "field" - ) - - assert actual_output[0] == expected_output[0] - assert actual_output[1] == expected_output[1] From f7ca41ff6226eecbf6c9ee475c1de714cb8f04e9 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sun, 18 Jan 2026 04:57:57 +0200 Subject: [PATCH 08/58] chore(api-nodes): remove check for pyav>=14.2 in code (it was added to requirements.txt long ago) (#11934) --- comfy_api_nodes/canary.py | 10 ---------- nodes.py | 3 --- 2 files changed, 13 deletions(-) delete mode 100644 comfy_api_nodes/canary.py diff --git a/comfy_api_nodes/canary.py b/comfy_api_nodes/canary.py deleted file mode 100644 index 4df7590b6..000000000 --- a/comfy_api_nodes/canary.py +++ /dev/null @@ -1,10 +0,0 @@ -import av - -ver = av.__version__.split(".") -if int(ver[0]) < 14: - raise Exception("INSTALL NEW VERSION OF PYAV TO USE API NODES.") - -if int(ver[0]) == 14 and int(ver[1]) < 2: - raise Exception("INSTALL NEW VERSION OF PYAV TO USE API NODES.") - -NODE_CLASS_MAPPINGS = {} diff --git a/nodes.py b/nodes.py index f19d5fd1c..8b5279b36 100644 --- a/nodes.py +++ b/nodes.py @@ -2409,9 +2409,6 @@ async def init_builtin_api_nodes(): "nodes_wan.py", ] - if not await load_custom_node(os.path.join(api_nodes_dir, "canary.py"), module_parent="comfy_api_nodes"): - return api_nodes_files - import_failed = [] for node_file in api_nodes_files: if not await load_custom_node(os.path.join(api_nodes_dir, node_file), module_parent="comfy_api_nodes"): From a498556d0dcde3d7a7c19e1f5c733c8c2a2ffb10 Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Sat, 17 Jan 2026 19:06:03 -0800 Subject: [PATCH 09/58] feat: add advanced parameter to Input classes for advanced widgets support (#11939) Add 'advanced' boolean parameter to Input and WidgetInput base classes and propagate to all typed Input subclasses (Boolean, Int, Float, String, Combo, MultiCombo, Webcam, MultiType, MatchType, ImageCompare). When set to True, the frontend will hide these inputs by default in a collapsible 'Advanced Inputs' section in the right side panel, reducing visual clutter for power-user options. This enables nodes to expose advanced configuration options (like encoding parameters, quality settings, etc.) without overwhelming typical users. Frontend support: ComfyUI_frontend PR #7812 --- comfy_api/latest/_io.py | 47 ++++++++++++++++++++++------------------- 1 file changed, 25 insertions(+), 22 deletions(-) diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index e6a0d1821..c30d92aaa 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -153,7 +153,7 @@ class Input(_IO_V3): ''' Base class for a V3 Input. ''' - def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None): + def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None): super().__init__() self.id = id self.display_name = display_name @@ -162,6 +162,7 @@ class Input(_IO_V3): self.lazy = lazy self.extra_dict = extra_dict if extra_dict is not None else {} self.rawLink = raw_link + self.advanced = advanced def as_dict(self): return prune_dict({ @@ -170,6 +171,7 @@ class Input(_IO_V3): "tooltip": self.tooltip, "lazy": self.lazy, "rawLink": self.rawLink, + "advanced": self.advanced, }) | prune_dict(self.extra_dict) def get_io_type(self): @@ -184,8 +186,8 @@ class WidgetInput(Input): ''' def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, default: Any=None, - socketless: bool=None, widget_type: str=None, force_input: bool=None, extra_dict=None, raw_link: bool=None): - super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link) + socketless: bool=None, widget_type: str=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link, advanced) self.default = default self.socketless = socketless self.widget_type = widget_type @@ -242,8 +244,8 @@ class Boolean(ComfyTypeIO): '''Boolean input.''' def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, default: bool=None, label_on: str=None, label_off: str=None, - socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None): - super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link) + socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced) self.label_on = label_on self.label_off = label_off self.default: bool @@ -262,8 +264,8 @@ class Int(ComfyTypeIO): '''Integer input.''' def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool=None, - display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None): - super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link) + display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced) self.min = min self.max = max self.step = step @@ -288,8 +290,8 @@ class Float(ComfyTypeIO): '''Float input.''' def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, default: float=None, min: float=None, max: float=None, step: float=None, round: float=None, - display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None): - super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link) + display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced) self.min = min self.max = max self.step = step @@ -314,8 +316,8 @@ class String(ComfyTypeIO): '''String input.''' def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, multiline=False, placeholder: str=None, default: str=None, dynamic_prompts: bool=None, - socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None): - super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link) + socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced) self.multiline = multiline self.placeholder = placeholder self.dynamic_prompts = dynamic_prompts @@ -350,12 +352,13 @@ class Combo(ComfyTypeIO): socketless: bool=None, extra_dict=None, raw_link: bool=None, + advanced: bool=None, ): if isinstance(options, type) and issubclass(options, Enum): options = [v.value for v in options] if isinstance(default, Enum): default = default.value - super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, None, extra_dict, raw_link) + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, None, extra_dict, raw_link, advanced) self.multiselect = False self.options = options self.control_after_generate = control_after_generate @@ -387,8 +390,8 @@ class MultiCombo(ComfyTypeI): class Input(Combo.Input): def __init__(self, id: str, options: list[str], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool=None, - socketless: bool=None, extra_dict=None, raw_link: bool=None): - super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless=socketless, extra_dict=extra_dict, raw_link=raw_link) + socketless: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None): + super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless=socketless, extra_dict=extra_dict, raw_link=raw_link, advanced=advanced) self.multiselect = True self.placeholder = placeholder self.chip = chip @@ -421,9 +424,9 @@ class Webcam(ComfyTypeIO): Type = str def __init__( self, id: str, display_name: str=None, optional=False, - tooltip: str=None, lazy: bool=None, default: str=None, socketless: bool=None, extra_dict=None, raw_link: bool=None + tooltip: str=None, lazy: bool=None, default: str=None, socketless: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None ): - super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, None, extra_dict, raw_link) + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, None, extra_dict, raw_link, advanced) @comfytype(io_type="MASK") @@ -776,7 +779,7 @@ class MultiType: ''' Input that permits more than one input type; if `id` is an instance of `ComfyType.Input`, then that input will be used to create a widget (if applicable) with overridden values. ''' - def __init__(self, id: str | Input, types: list[type[_ComfyType] | _ComfyType], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None): + def __init__(self, id: str | Input, types: list[type[_ComfyType] | _ComfyType], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None): # if id is an Input, then use that Input with overridden values self.input_override = None if isinstance(id, Input): @@ -789,7 +792,7 @@ class MultiType: # if is a widget input, make sure widget_type is set appropriately if isinstance(self.input_override, WidgetInput): self.input_override.widget_type = self.input_override.get_io_type() - super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link) + super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link, advanced) self._io_types = types @property @@ -843,8 +846,8 @@ class MatchType(ComfyTypeIO): class Input(Input): def __init__(self, id: str, template: MatchType.Template, - display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None): - super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link) + display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link, advanced) self.template = template def as_dict(self): @@ -1119,8 +1122,8 @@ class ImageCompare(ComfyTypeI): class Input(WidgetInput): def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, - socketless: bool=True): - super().__init__(id, display_name, optional, tooltip, None, None, socketless) + socketless: bool=True, advanced: bool=None): + super().__init__(id, display_name, optional, tooltip, None, None, socketless, None, None, None, None, advanced) def as_dict(self): return super().as_dict() From 034fac70549dd9c35b155b80a3ff627ad07b1015 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sun, 18 Jan 2026 08:40:39 +0200 Subject: [PATCH 10/58] chore(api-nodes): auto-discover all nodes_*.py files to avoid merge conflicts when adding new API nodes (#11943) --- nodes.py | 30 ++++-------------------------- 1 file changed, 4 insertions(+), 26 deletions(-) diff --git a/nodes.py b/nodes.py index 8b5279b36..cba8eacc2 100644 --- a/nodes.py +++ b/nodes.py @@ -5,6 +5,7 @@ import torch import os import sys import json +import glob import hashlib import inspect import traceback @@ -2384,35 +2385,12 @@ async def init_builtin_extra_nodes(): async def init_builtin_api_nodes(): api_nodes_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_api_nodes") - api_nodes_files = [ - "nodes_ideogram.py", - "nodes_openai.py", - "nodes_minimax.py", - "nodes_veo2.py", - "nodes_kling.py", - "nodes_bfl.py", - "nodes_bytedance.py", - "nodes_ltxv.py", - "nodes_luma.py", - "nodes_recraft.py", - "nodes_pixverse.py", - "nodes_stability.py", - "nodes_runway.py", - "nodes_sora.py", - "nodes_topaz.py", - "nodes_tripo.py", - "nodes_meshy.py", - "nodes_moonvalley.py", - "nodes_rodin.py", - "nodes_gemini.py", - "nodes_vidu.py", - "nodes_wan.py", - ] + api_nodes_files = sorted(glob.glob(os.path.join(api_nodes_dir, "nodes_*.py"))) import_failed = [] for node_file in api_nodes_files: - if not await load_custom_node(os.path.join(api_nodes_dir, node_file), module_parent="comfy_api_nodes"): - import_failed.append(node_file) + if not await load_custom_node(node_file, module_parent="comfy_api_nodes"): + import_failed.append(os.path.basename(node_file)) return import_failed From 1a72bf20469dee31ad156f819c14f0172cbad222 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 18 Jan 2026 19:53:43 -0800 Subject: [PATCH 11/58] Readme update. (#11957) --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 123cc9472..c56e05d07 100644 --- a/README.md +++ b/README.md @@ -108,7 +108,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith - [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/) - Latent previews with [TAESD](#how-to-show-high-quality-previews) - Works fully offline: core will never download anything unless you want to. -- Optional API nodes to use paid models from external providers through the online [Comfy API](https://docs.comfy.org/tutorials/api-nodes/overview). +- Optional API nodes to use paid models from external providers through the online [Comfy API](https://docs.comfy.org/tutorials/api-nodes/overview) disable with: `--disable-api-nodes` - [Config file](extra_model_paths.yaml.example) to set the search paths for models. Workflow examples can be found on the [Examples page](https://comfyanonymous.github.io/ComfyUI_examples/) @@ -212,7 +212,7 @@ Python 3.14 works but you may encounter issues with the torch compile node. The Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12 -torch 2.4 and above is supported but some features might only work on newer versions. We generally recommend using the latest major version of pytorch with the latest cuda version unless it is less than 2 weeks old. +torch 2.4 and above is supported but some features and optimizations might only work on newer versions. We generally recommend using the latest major version of pytorch with the latest cuda version unless it is less than 2 weeks old. ### Instructions: @@ -229,7 +229,7 @@ AMD users can install rocm and pytorch with pip if you don't have it already ins ```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.4``` -This is the command to install the nightly with ROCm 7.0 which might have some performance improvements: +This is the command to install the nightly with ROCm 7.1 which might have some performance improvements: ```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.1``` From 866a4619db2db56c77a86e5fc9968a2454928627 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Tue, 20 Jan 2026 06:21:35 +0800 Subject: [PATCH 12/58] chore: update workflow templates to v0.8.14 (#11974) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 622256973..312c7c137 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.36.14 -comfyui-workflow-templates==0.8.11 +comfyui-workflow-templates==0.8.14 comfyui-embedded-docs==0.4.0 torch torchsde From b931b37e30bb19b6e13ad8623e193ccdaf671a23 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 19 Jan 2026 16:47:14 -0800 Subject: [PATCH 13/58] feat(api-nodes): add Bria Edit node (#11978) Co-authored-by: Alexander Piskun --- comfy_api_nodes/apis/bria.py | 61 +++++++++ comfy_api_nodes/nodes_bria.py | 198 ++++++++++++++++++++++++++++ comfy_api_nodes/util/__init__.py | 2 + comfy_api_nodes/util/conversions.py | 6 + 4 files changed, 267 insertions(+) create mode 100644 comfy_api_nodes/apis/bria.py create mode 100644 comfy_api_nodes/nodes_bria.py diff --git a/comfy_api_nodes/apis/bria.py b/comfy_api_nodes/apis/bria.py new file mode 100644 index 000000000..9119cacc6 --- /dev/null +++ b/comfy_api_nodes/apis/bria.py @@ -0,0 +1,61 @@ +from typing import TypedDict + +from pydantic import BaseModel, Field + + +class InputModerationSettings(TypedDict): + prompt_content_moderation: bool + visual_input_moderation: bool + visual_output_moderation: bool + + +class BriaEditImageRequest(BaseModel): + instruction: str | None = Field(...) + structured_instruction: str | None = Field( + ..., + description="Use this instead of instruction for precise, programmatic control.", + ) + images: list[str] = Field( + ..., + description="Required. Publicly available URL or Base64-encoded. Must contain exactly one item.", + ) + mask: str | None = Field( + None, + description="Mask image (black and white). Black areas will be preserved, white areas will be edited. " + "If omitted, the edit applies to the entire image. " + "The input image and the the input mask must be of the same size.", + ) + negative_prompt: str | None = Field(None) + guidance_scale: float = Field(...) + model_version: str = Field(...) + steps_num: int = Field(...) + seed: int = Field(...) + ip_signal: bool = Field( + False, + description="If true, returns a warning for potential IP content in the instruction.", + ) + prompt_content_moderation: bool = Field( + False, description="If true, returns 422 on instruction moderation failure." + ) + visual_input_content_moderation: bool = Field( + False, description="If true, returns 422 on images or mask moderation failure." + ) + visual_output_content_moderation: bool = Field( + False, description="If true, returns 422 on visual output moderation failure." + ) + + +class BriaStatusResponse(BaseModel): + request_id: str = Field(...) + status_url: str = Field(...) + warning: str | None = Field(None) + + +class BriaResult(BaseModel): + structured_prompt: str = Field(...) + image_url: str = Field(...) + + +class BriaResponse(BaseModel): + status: str = Field(...) + result: BriaResult | None = Field(None) diff --git a/comfy_api_nodes/nodes_bria.py b/comfy_api_nodes/nodes_bria.py new file mode 100644 index 000000000..72a3055a7 --- /dev/null +++ b/comfy_api_nodes/nodes_bria.py @@ -0,0 +1,198 @@ +from typing_extensions import override + +from comfy_api.latest import IO, ComfyExtension, Input +from comfy_api_nodes.apis.bria import ( + BriaEditImageRequest, + BriaResponse, + BriaStatusResponse, + InputModerationSettings, +) +from comfy_api_nodes.util import ( + ApiEndpoint, + convert_mask_to_image, + download_url_to_image_tensor, + get_number_of_images, + poll_op, + sync_op, + upload_images_to_comfyapi, +) + + +class BriaImageEditNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="BriaImageEditNode", + display_name="Bria Image Edit", + category="api node/image/Bria", + description="Edit images using Bria latest model", + inputs=[ + IO.Combo.Input("model", options=["FIBO"]), + IO.Image.Input("image"), + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Instruction to edit image", + ), + IO.String.Input("negative_prompt", multiline=True, default=""), + IO.String.Input( + "structured_prompt", + multiline=True, + default="", + tooltip="A string containing the structured edit prompt in JSON format. " + "Use this instead of usual prompt for precise, programmatic control.", + ), + IO.Int.Input( + "seed", + default=1, + min=1, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + ), + IO.Float.Input( + "guidance_scale", + default=3, + min=3, + max=5, + step=0.01, + display_mode=IO.NumberDisplay.number, + tooltip="Higher value makes the image follow the prompt more closely.", + ), + IO.Int.Input( + "steps", + default=50, + min=20, + max=50, + step=1, + display_mode=IO.NumberDisplay.number, + ), + IO.DynamicCombo.Input( + "moderation", + options=[ + IO.DynamicCombo.Option( + "true", + [ + IO.Boolean.Input( + "prompt_content_moderation", default=False + ), + IO.Boolean.Input( + "visual_input_moderation", default=False + ), + IO.Boolean.Input( + "visual_output_moderation", default=True + ), + ], + ), + IO.DynamicCombo.Option("false", []), + ], + tooltip="Moderation settings", + ), + IO.Mask.Input( + "mask", + tooltip="If omitted, the edit applies to the entire image.", + optional=True, + ), + ], + outputs=[ + IO.Image.Output(), + IO.String.Output(display_name="structured_prompt"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.04}""", + ), + ) + + @classmethod + async def execute( + cls, + model: str, + image: Input.Image, + prompt: str, + negative_prompt: str, + structured_prompt: str, + seed: int, + guidance_scale: float, + steps: int, + moderation: InputModerationSettings, + mask: Input.Image | None = None, + ) -> IO.NodeOutput: + if not prompt and not structured_prompt: + raise ValueError( + "One of prompt or structured_prompt is required to be non-empty." + ) + if get_number_of_images(image) != 1: + raise ValueError("Exactly one input image is required.") + mask_url = None + if mask is not None: + mask_url = ( + await upload_images_to_comfyapi( + cls, + convert_mask_to_image(mask), + max_images=1, + mime_type="image/png", + wait_label="Uploading mask", + ) + )[0] + response = await sync_op( + cls, + ApiEndpoint(path="proxy/bria/v2/image/edit", method="POST"), + data=BriaEditImageRequest( + instruction=prompt if prompt else None, + structured_instruction=structured_prompt if structured_prompt else None, + images=await upload_images_to_comfyapi( + cls, + image, + max_images=1, + mime_type="image/png", + wait_label="Uploading image", + ), + mask=mask_url, + negative_prompt=negative_prompt if negative_prompt else None, + guidance_scale=guidance_scale, + seed=seed, + model_version=model, + steps_num=steps, + prompt_content_moderation=moderation.get( + "prompt_content_moderation", False + ), + visual_input_content_moderation=moderation.get( + "visual_input_moderation", False + ), + visual_output_content_moderation=moderation.get( + "visual_output_moderation", False + ), + ), + response_model=BriaStatusResponse, + ) + response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"), + status_extractor=lambda r: r.status, + response_model=BriaResponse, + ) + return IO.NodeOutput( + await download_url_to_image_tensor(response.result.image_url), + response.result.structured_prompt, + ) + + +class BriaExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + BriaImageEditNode, + ] + + +async def comfy_entrypoint() -> BriaExtension: + return BriaExtension() diff --git a/comfy_api_nodes/util/__init__.py b/comfy_api_nodes/util/__init__.py index 4cc22abfb..364976000 100644 --- a/comfy_api_nodes/util/__init__.py +++ b/comfy_api_nodes/util/__init__.py @@ -11,6 +11,7 @@ from .conversions import ( audio_input_to_mp3, audio_to_base64_string, bytesio_to_image_tensor, + convert_mask_to_image, downscale_image_tensor, image_tensor_pair_to_batch, pil_to_bytesio, @@ -72,6 +73,7 @@ __all__ = [ "audio_input_to_mp3", "audio_to_base64_string", "bytesio_to_image_tensor", + "convert_mask_to_image", "downscale_image_tensor", "image_tensor_pair_to_batch", "pil_to_bytesio", diff --git a/comfy_api_nodes/util/conversions.py b/comfy_api_nodes/util/conversions.py index 99c302a2a..546741b7b 100644 --- a/comfy_api_nodes/util/conversions.py +++ b/comfy_api_nodes/util/conversions.py @@ -451,6 +451,12 @@ def resize_mask_to_image( return mask +def convert_mask_to_image(mask: Input.Image) -> torch.Tensor: + """Make mask have the expected amount of dims (4) and channels (3) to be recognized as an image.""" + mask = mask.unsqueeze(-1) + return torch.cat([mask] * 3, dim=-1) + + def text_filepath_to_base64_string(filepath: str) -> str: """Converts a text file to a base64 string.""" with open(filepath, "rb") as f: From 7458e20465a0efcf91eafc0c65d1929ab7b2238d Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 19 Jan 2026 16:58:30 -0800 Subject: [PATCH 14/58] Make Autogrow validation work properly (#11977) * In-progress autogrow validation fixes - properly looks at required/optional inputs, now working on the edge case that all inputs are optional and nothing is plugged in (should just be an empty dictionary passed into node) * Allow autogrow to work with all inputs being optional * Revert accidentally pushed changes to nodes_logic.py --- comfy_api/latest/_io.py | 53 ++++++++++++++++++++++++++++++++++------- 1 file changed, 44 insertions(+), 9 deletions(-) diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index c30d92aaa..4969d3506 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -1000,20 +1000,38 @@ class Autogrow(ComfyTypeI): names = [f"{prefix}{i}" for i in range(max)] # need to create a new input based on the contents of input template_input = None - for _, dict_input in input.items(): - # for now, get just the first value from dict_input + template_required = True + for _input_type, dict_input in input.items(): + # for now, get just the first value from dict_input; if not required, min can be ignored + if len(dict_input) == 0: + continue template_input = list(dict_input.values())[0] + template_required = _input_type == "required" + break + if template_input is None: + raise Exception("template_input could not be determined from required or optional; this should never happen.") new_dict = {} + new_dict_added_to = False + # first, add possible inputs into out_dict for i, name in enumerate(names): expected_id = finalize_prefix(curr_prefix, name) + # required + if i < min and template_required: + out_dict["required"][expected_id] = template_input + type_dict = new_dict.setdefault("required", {}) + # optional + else: + out_dict["optional"][expected_id] = template_input + type_dict = new_dict.setdefault("optional", {}) if expected_id in live_inputs: - # required - if i < min: - type_dict = new_dict.setdefault("required", {}) - # optional - else: - type_dict = new_dict.setdefault("optional", {}) + # NOTE: prefix gets added in parse_class_inputs type_dict[name] = template_input + new_dict_added_to = True + # account for the edge case that all inputs are optional and no values are received + if not new_dict_added_to: + finalized_prefix = finalize_prefix(curr_prefix) + out_dict["dynamic_paths"][finalized_prefix] = finalized_prefix + out_dict["dynamic_paths_default_value"][finalized_prefix] = DynamicPathsDefaultValue.EMPTY_DICT parse_class_inputs(out_dict, live_inputs, new_dict, curr_prefix) @comfytype(io_type="COMFY_DYNAMICCOMBO_V3") @@ -1151,6 +1169,8 @@ class V3Data(TypedDict): 'Dictionary where the keys are the hidden input ids and the values are the values of the hidden inputs.' dynamic_paths: dict[str, Any] 'Dictionary where the keys are the input ids and the values dictate how to turn the inputs into a nested dictionary.' + dynamic_paths_default_value: dict[str, Any] + 'Dictionary where the keys are the input ids and the values are a string from DynamicPathsDefaultValue for the inputs if value is None.' create_dynamic_tuple: bool 'When True, the value of the dynamic input will be in the format (value, path_key).' @@ -1504,6 +1524,7 @@ def get_finalized_class_inputs(d: dict[str, Any], live_inputs: dict[str, Any], i "required": {}, "optional": {}, "dynamic_paths": {}, + "dynamic_paths_default_value": {}, } d = d.copy() # ignore hidden for parsing @@ -1513,8 +1534,12 @@ def get_finalized_class_inputs(d: dict[str, Any], live_inputs: dict[str, Any], i out_dict["hidden"] = hidden v3_data = {} dynamic_paths = out_dict.pop("dynamic_paths", None) - if dynamic_paths is not None: + if dynamic_paths is not None and len(dynamic_paths) > 0: v3_data["dynamic_paths"] = dynamic_paths + # this list is used for autogrow, in the case all inputs are optional and no values are passed + dynamic_paths_default_value = out_dict.pop("dynamic_paths_default_value", None) + if dynamic_paths_default_value is not None and len(dynamic_paths_default_value) > 0: + v3_data["dynamic_paths_default_value"] = dynamic_paths_default_value return out_dict, hidden, v3_data def parse_class_inputs(out_dict: dict[str, Any], live_inputs: dict[str, Any], curr_dict: dict[str, Any], curr_prefix: list[str] | None=None) -> None: @@ -1551,11 +1576,16 @@ def add_to_dict_v1(i: Input, d: dict): def add_to_dict_v3(io: Input | Output, d: dict): d[io.id] = (io.get_io_type(), io.as_dict()) +class DynamicPathsDefaultValue: + EMPTY_DICT = "empty_dict" + def build_nested_inputs(values: dict[str, Any], v3_data: V3Data): paths = v3_data.get("dynamic_paths", None) + default_value_dict = v3_data.get("dynamic_paths_default_value", {}) if paths is None: return values values = values.copy() + result = {} create_tuple = v3_data.get("create_dynamic_tuple", False) @@ -1569,6 +1599,11 @@ def build_nested_inputs(values: dict[str, Any], v3_data: V3Data): if is_last: value = values.pop(key, None) + if value is None: + # see if a default value was provided for this key + default_option = default_value_dict.get(key, None) + if default_option == DynamicPathsDefaultValue.EMPTY_DICT: + value = {} if create_tuple: value = (value, key) current[p] = value From e0eacb06883c1f7ddf8af249cd461d7c2ebcbaae Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 19 Jan 2026 19:00:36 -0800 Subject: [PATCH 15/58] Simpler way to implement the #11980 loras. (#11981) --- comfy/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/comfy/utils.py b/comfy/utils.py index 2e33a4258..5e79fb449 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -639,6 +639,8 @@ def flux_to_diffusers(mmdit_config, output_prefix=""): "proj_out.bias": "linear2.bias", "attn.norm_q.weight": "norm.query_norm.scale", "attn.norm_k.weight": "norm.key_norm.scale", + "attn.to_qkv_mlp_proj.weight": "linear1.weight", # Flux 2 + "attn.to_out.weight": "linear2.weight", # Flux 2 } for k in block_map: From 0da5a0fe58ae940726a61b94698e303fb39d73c1 Mon Sep 17 00:00:00 2001 From: rkfg Date: Tue, 20 Jan 2026 06:12:02 +0300 Subject: [PATCH 16/58] Convert mono audio to fake stereo for LTXV VAE encoding (#11965) --- comfy/ldm/lightricks/vae/audio_vae.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/comfy/ldm/lightricks/vae/audio_vae.py b/comfy/ldm/lightricks/vae/audio_vae.py index a9111d3bd..29d9e6c29 100644 --- a/comfy/ldm/lightricks/vae/audio_vae.py +++ b/comfy/ldm/lightricks/vae/audio_vae.py @@ -189,9 +189,12 @@ class AudioVAE(torch.nn.Module): waveform = self.device_manager.move_to_load_device(waveform) expected_channels = self.autoencoder.encoder.in_channels if waveform.shape[1] != expected_channels: - raise ValueError( - f"Input audio must have {expected_channels} channels, got {waveform.shape[1]}" - ) + if waveform.shape[1] == 1: + waveform = waveform.expand(-1, expected_channels, *waveform.shape[2:]) + else: + raise ValueError( + f"Input audio must have {expected_channels} channels, got {waveform.shape[1]}" + ) mel_spec = self.preprocessor.waveform_to_mel( waveform, waveform_sample_rate, device=self.device_manager.load_device From 70c91b8248e08492cf16bfebdc83579b801a6ee0 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 19 Jan 2026 19:32:40 -0800 Subject: [PATCH 17/58] Fix #11963 (#11982) --- comfy/text_encoders/ovis.py | 1 + comfy/text_encoders/z_image.py | 1 + 2 files changed, 2 insertions(+) diff --git a/comfy/text_encoders/ovis.py b/comfy/text_encoders/ovis.py index 5754424d2..2cc0867c3 100644 --- a/comfy/text_encoders/ovis.py +++ b/comfy/text_encoders/ovis.py @@ -61,6 +61,7 @@ def te(dtype_llama=None, llama_quantization_metadata=None): if dtype_llama is not None: dtype = dtype_llama if llama_quantization_metadata is not None: + model_options = model_options.copy() model_options["quantization_metadata"] = llama_quantization_metadata super().__init__(device=device, dtype=dtype, model_options=model_options) return OvisTEModel_ diff --git a/comfy/text_encoders/z_image.py b/comfy/text_encoders/z_image.py index 19adde0b7..ad41bfb1e 100644 --- a/comfy/text_encoders/z_image.py +++ b/comfy/text_encoders/z_image.py @@ -40,6 +40,7 @@ def te(dtype_llama=None, llama_quantization_metadata=None): if dtype_llama is not None: dtype = dtype_llama if llama_quantization_metadata is not None: + model_options = model_options.copy() model_options["quantization_metadata"] = llama_quantization_metadata super().__init__(device=device, dtype=dtype, model_options=model_options) return ZImageTEModel_ From 9d273d3ab1fb1d2c8b34de4d54cabe50a5a3e5bc Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 19 Jan 2026 22:40:18 -0500 Subject: [PATCH 18/58] ComfyUI v0.10.0 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index dbb57b4e5..952d413db 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.9.2" +__version__ = "0.10.0" diff --git a/pyproject.toml b/pyproject.toml index 9ea73da05..120b6c751 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.9.2" +version = "0.10.0" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" From 2108167f9f70cfd4874945b31a916680f959a6d7 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 19 Jan 2026 20:17:38 -0800 Subject: [PATCH 19/58] Support zimage omni base model. (#11979) --- comfy/ldm/lumina/model.py | 317 ++++++++++++++++++++++++++++------- comfy/model_base.py | 30 ++++ comfy/model_detection.py | 3 + comfy_extras/nodes_zimage.py | 88 ++++++++++ nodes.py | 1 + 5 files changed, 381 insertions(+), 58 deletions(-) create mode 100644 comfy_extras/nodes_zimage.py diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index afbab2ac7..139f879a1 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -13,10 +13,53 @@ from comfy.ldm.modules.attention import optimized_attention_masked from comfy.ldm.flux.layers import EmbedND from comfy.ldm.flux.math import apply_rope import comfy.patcher_extension +import comfy.utils -def modulate(x, scale): - return x * (1 + scale.unsqueeze(1)) +def invert_slices(slices, length): + sorted_slices = sorted(slices) + result = [] + current = 0 + + for start, end in sorted_slices: + if current < start: + result.append((current, start)) + current = max(current, end) + + if current < length: + result.append((current, length)) + + return result + + +def modulate(x, scale, timestep_zero_index=None): + if timestep_zero_index is None: + return x * (1 + scale.unsqueeze(1)) + else: + scale = (1 + scale.unsqueeze(1)) + actual_batch = scale.size(0) // 2 + slices = timestep_zero_index + invert = invert_slices(timestep_zero_index, x.shape[1]) + for s in slices: + x[:, s[0]:s[1]] *= scale[actual_batch:] + for s in invert: + x[:, s[0]:s[1]] *= scale[:actual_batch] + return x + + +def apply_gate(gate, x, timestep_zero_index=None): + if timestep_zero_index is None: + return gate * x + else: + actual_batch = gate.size(0) // 2 + + slices = timestep_zero_index + invert = invert_slices(timestep_zero_index, x.shape[1]) + for s in slices: + x[:, s[0]:s[1]] *= gate[actual_batch:] + for s in invert: + x[:, s[0]:s[1]] *= gate[:actual_batch] + return x ############################################################################# # Core NextDiT Model # @@ -258,6 +301,7 @@ class JointTransformerBlock(nn.Module): x_mask: torch.Tensor, freqs_cis: torch.Tensor, adaln_input: Optional[torch.Tensor]=None, + timestep_zero_index=None, transformer_options={}, ): """ @@ -276,18 +320,18 @@ class JointTransformerBlock(nn.Module): assert adaln_input is not None scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1) - x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2( + x = x + apply_gate(gate_msa.unsqueeze(1).tanh(), self.attention_norm2( clamp_fp16(self.attention( - modulate(self.attention_norm1(x), scale_msa), + modulate(self.attention_norm1(x), scale_msa, timestep_zero_index=timestep_zero_index), x_mask, freqs_cis, transformer_options=transformer_options, - )) + ))), timestep_zero_index=timestep_zero_index ) - x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2( + x = x + apply_gate(gate_mlp.unsqueeze(1).tanh(), self.ffn_norm2( clamp_fp16(self.feed_forward( - modulate(self.ffn_norm1(x), scale_mlp), - )) + modulate(self.ffn_norm1(x), scale_mlp, timestep_zero_index=timestep_zero_index), + ))), timestep_zero_index=timestep_zero_index ) else: assert adaln_input is None @@ -345,13 +389,37 @@ class FinalLayer(nn.Module): ), ) - def forward(self, x, c): + def forward(self, x, c, timestep_zero_index=None): scale = self.adaLN_modulation(c) - x = modulate(self.norm_final(x), scale) + x = modulate(self.norm_final(x), scale, timestep_zero_index=timestep_zero_index) x = self.linear(x) return x +def pad_zimage(feats, pad_token, pad_tokens_multiple): + pad_extra = (-feats.shape[1]) % pad_tokens_multiple + return torch.cat((feats, pad_token.to(device=feats.device, dtype=feats.dtype, copy=True).unsqueeze(0).repeat(feats.shape[0], pad_extra, 1)), dim=1), pad_extra + + +def pos_ids_x(start_t, H_tokens, W_tokens, batch_size, device, transformer_options={}): + rope_options = transformer_options.get("rope_options", None) + h_scale = 1.0 + w_scale = 1.0 + h_start = 0 + w_start = 0 + if rope_options is not None: + h_scale = rope_options.get("scale_y", 1.0) + w_scale = rope_options.get("scale_x", 1.0) + + h_start = rope_options.get("shift_y", 0.0) + w_start = rope_options.get("shift_x", 0.0) + x_pos_ids = torch.zeros((batch_size, H_tokens * W_tokens, 3), dtype=torch.float32, device=device) + x_pos_ids[:, :, 0] = start_t + x_pos_ids[:, :, 1] = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten() + x_pos_ids[:, :, 2] = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten() + return x_pos_ids + + class NextDiT(nn.Module): """ Diffusion model with a Transformer backbone. @@ -378,6 +446,7 @@ class NextDiT(nn.Module): time_scale=1.0, pad_tokens_multiple=None, clip_text_dim=None, + siglip_feat_dim=None, image_model=None, device=None, dtype=None, @@ -491,6 +560,41 @@ class NextDiT(nn.Module): for layer_id in range(n_layers) ] ) + + if siglip_feat_dim is not None: + self.siglip_embedder = nn.Sequential( + operation_settings.get("operations").RMSNorm(siglip_feat_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")), + operation_settings.get("operations").Linear( + siglip_feat_dim, + dim, + bias=True, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ), + ) + self.siglip_refiner = nn.ModuleList( + [ + JointTransformerBlock( + layer_id, + dim, + n_heads, + n_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + qk_norm, + modulation=False, + operation_settings=operation_settings, + ) + for layer_id in range(n_refiner_layers) + ] + ) + self.siglip_pad_token = nn.Parameter(torch.empty((1, dim), device=device, dtype=dtype)) + else: + self.siglip_embedder = None + self.siglip_refiner = None + self.siglip_pad_token = None + # This norm final is in the lumina 2.0 code but isn't actually used for anything. # self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) self.final_layer = FinalLayer(dim, patch_size, self.out_channels, z_image_modulation=z_image_modulation, operation_settings=operation_settings) @@ -531,70 +635,166 @@ class NextDiT(nn.Module): imgs = torch.stack(imgs, dim=0) return imgs - def patchify_and_embed( - self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens, transformer_options={} - ) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor]: - bsz = len(x) - pH = pW = self.patch_size - device = x[0].device - orig_x = x - - if self.pad_tokens_multiple is not None: - pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple - cap_feats = torch.cat((cap_feats, self.cap_pad_token.to(device=cap_feats.device, dtype=cap_feats.dtype, copy=True).unsqueeze(0).repeat(cap_feats.shape[0], pad_extra, 1)), dim=1) + def embed_cap(self, cap_feats=None, offset=0, bsz=1, device=None, dtype=None): + if cap_feats is not None: + cap_feats = self.cap_embedder(cap_feats) + cap_feats_len = cap_feats.shape[1] + if self.pad_tokens_multiple is not None: + cap_feats, _ = pad_zimage(cap_feats, self.cap_pad_token, self.pad_tokens_multiple) + else: + cap_feats_len = 0 + cap_feats = self.cap_pad_token.to(device=device, dtype=dtype, copy=True).unsqueeze(0).repeat(bsz, self.pad_tokens_multiple, 1) cap_pos_ids = torch.zeros(bsz, cap_feats.shape[1], 3, dtype=torch.float32, device=device) - cap_pos_ids[:, :, 0] = torch.arange(cap_feats.shape[1], dtype=torch.float32, device=device) + 1.0 + cap_pos_ids[:, :, 0] = torch.arange(cap_feats.shape[1], dtype=torch.float32, device=device) + 1.0 + offset + embeds = (cap_feats,) + freqs_cis = (self.rope_embedder(cap_pos_ids).movedim(1, 2),) + return embeds, freqs_cis, cap_feats_len + + def embed_all(self, x, cap_feats=None, siglip_feats=None, offset=0, omni=False, transformer_options={}): + bsz = 1 + pH = pW = self.patch_size + device = x.device + embeds, freqs_cis, cap_feats_len = self.embed_cap(cap_feats, offset=offset, bsz=bsz, device=device, dtype=x.dtype) + + if not omni: + cap_feats_len = embeds[0].shape[1] + offset + embeds += (None,) + freqs_cis += (None,) + else: + cap_feats_len += offset + if siglip_feats is not None: + b, h, w, c = siglip_feats.shape + siglip_feats = siglip_feats.permute(0, 3, 1, 2).reshape(b, h * w, c) + siglip_feats = self.siglip_embedder(siglip_feats) + siglip_pos_ids = torch.zeros((bsz, siglip_feats.shape[1], 3), dtype=torch.float32, device=device) + siglip_pos_ids[:, :, 0] = cap_feats_len + 2 + siglip_pos_ids[:, :, 1] = (torch.linspace(0, h * 8 - 1, steps=h, dtype=torch.float32, device=device).floor()).view(-1, 1).repeat(1, w).flatten() + siglip_pos_ids[:, :, 2] = (torch.linspace(0, w * 8 - 1, steps=w, dtype=torch.float32, device=device).floor()).view(1, -1).repeat(h, 1).flatten() + if self.siglip_pad_token is not None: + siglip_feats, pad_extra = pad_zimage(siglip_feats, self.siglip_pad_token, self.pad_tokens_multiple) # TODO: double check + siglip_pos_ids = torch.nn.functional.pad(siglip_pos_ids, (0, 0, 0, pad_extra)) + else: + siglip_feats = self.siglip_pad_token.to(device=device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(bsz, self.pad_tokens_multiple, 1) + siglip_pos_ids = torch.zeros((bsz, siglip_feats.shape[1], 3), dtype=torch.float32, device=device) + + if siglip_feats is None: + embeds += (None,) + freqs_cis += (None,) + else: + embeds += (siglip_feats,) + freqs_cis += (self.rope_embedder(siglip_pos_ids).movedim(1, 2),) B, C, H, W = x.shape x = self.x_embedder(x.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2)) - - rope_options = transformer_options.get("rope_options", None) - h_scale = 1.0 - w_scale = 1.0 - h_start = 0 - w_start = 0 - if rope_options is not None: - h_scale = rope_options.get("scale_y", 1.0) - w_scale = rope_options.get("scale_x", 1.0) - - h_start = rope_options.get("shift_y", 0.0) - w_start = rope_options.get("shift_x", 0.0) - - H_tokens, W_tokens = H // pH, W // pW - x_pos_ids = torch.zeros((bsz, x.shape[1], 3), dtype=torch.float32, device=device) - x_pos_ids[:, :, 0] = cap_feats.shape[1] + 1 - x_pos_ids[:, :, 1] = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten() - x_pos_ids[:, :, 2] = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten() - + x_pos_ids = pos_ids_x(cap_feats_len + 1, H // pH, W // pW, bsz, device, transformer_options=transformer_options) if self.pad_tokens_multiple is not None: - pad_extra = (-x.shape[1]) % self.pad_tokens_multiple - x = torch.cat((x, self.x_pad_token.to(device=x.device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(x.shape[0], pad_extra, 1)), dim=1) + x, pad_extra = pad_zimage(x, self.x_pad_token, self.pad_tokens_multiple) x_pos_ids = torch.nn.functional.pad(x_pos_ids, (0, 0, 0, pad_extra)) - freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2) + embeds += (x,) + freqs_cis += (self.rope_embedder(x_pos_ids).movedim(1, 2),) + return embeds, freqs_cis, cap_feats_len + len(freqs_cis) - 1 + + + def patchify_and_embed( + self, x: torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens, ref_latents=[], ref_contexts=[], siglip_feats=[], transformer_options={} + ) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor]: + bsz = x.shape[0] + cap_mask = None # TODO? + main_siglip = None + orig_x = x + + embeds = ([], [], []) + freqs_cis = ([], [], []) + leftover_cap = [] + + start_t = 0 + omni = len(ref_latents) > 0 + if omni: + for i, ref in enumerate(ref_latents): + if i < len(ref_contexts): + ref_con = ref_contexts[i] + else: + ref_con = None + if i < len(siglip_feats): + sig_feat = siglip_feats[i] + else: + sig_feat = None + + out = self.embed_all(ref, ref_con, sig_feat, offset=start_t, omni=omni, transformer_options=transformer_options) + for i, e in enumerate(out[0]): + embeds[i].append(comfy.utils.repeat_to_batch_size(e, bsz)) + freqs_cis[i].append(out[1][i]) + start_t = out[2] + leftover_cap = ref_contexts[len(ref_latents):] + + H, W = x.shape[-2], x.shape[-1] + img_sizes = [(H, W)] * bsz + out = self.embed_all(x, cap_feats, main_siglip, offset=start_t, omni=omni, transformer_options=transformer_options) + img_len = out[0][-1].shape[1] + cap_len = out[0][0].shape[1] + for i, e in enumerate(out[0]): + if e is not None: + e = comfy.utils.repeat_to_batch_size(e, bsz) + embeds[i].append(e) + freqs_cis[i].append(out[1][i]) + start_t = out[2] + + for cap in leftover_cap: + out = self.embed_cap(cap, offset=start_t, bsz=bsz, device=x.device, dtype=x.dtype) + cap_len += out[0][0].shape[1] + embeds[0].append(comfy.utils.repeat_to_batch_size(out[0][0], bsz)) + freqs_cis[0].append(out[1][0]) + start_t += out[2] patches = transformer_options.get("patches", {}) # refine context + cap_feats = torch.cat(embeds[0], dim=1) + cap_freqs_cis = torch.cat(freqs_cis[0], dim=1) for layer in self.context_refiner: - cap_feats = layer(cap_feats, cap_mask, freqs_cis[:, :cap_pos_ids.shape[1]], transformer_options=transformer_options) + cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis, transformer_options=transformer_options) + + feats = (cap_feats,) + fc = (cap_freqs_cis,) + + if omni: + siglip_mask = None + siglip_feats_combined = torch.cat(embeds[1], dim=1) + siglip_feats_freqs_cis = torch.cat(freqs_cis[1], dim=1) + if self.siglip_refiner is not None: + for layer in self.siglip_refiner: + siglip_feats_combined = layer(siglip_feats_combined, siglip_mask, siglip_feats_freqs_cis, transformer_options=transformer_options) + feats += (siglip_feats_combined,) + fc += (siglip_feats_freqs_cis,) padded_img_mask = None + x = torch.cat(embeds[-1], dim=1) + fc_x = torch.cat(freqs_cis[-1], dim=1) + if omni: + timestep_zero_index = [(x.shape[1] - img_len, x.shape[1])] + else: + timestep_zero_index = None + x_input = x for i, layer in enumerate(self.noise_refiner): - x = layer(x, padded_img_mask, freqs_cis[:, cap_pos_ids.shape[1]:], t, transformer_options=transformer_options) + x = layer(x, padded_img_mask, fc_x, t, timestep_zero_index=timestep_zero_index, transformer_options=transformer_options) if "noise_refiner" in patches: for p in patches["noise_refiner"]: - out = p({"img": x, "img_input": x_input, "txt": cap_feats, "pe": freqs_cis[:, cap_pos_ids.shape[1]:], "vec": t, "x": orig_x, "block_index": i, "transformer_options": transformer_options, "block_type": "noise_refiner"}) + out = p({"img": x, "img_input": x_input, "txt": cap_feats, "pe": fc_x, "vec": t, "x": orig_x, "block_index": i, "transformer_options": transformer_options, "block_type": "noise_refiner"}) if "img" in out: x = out["img"] - padded_full_embed = torch.cat((cap_feats, x), dim=1) + padded_full_embed = torch.cat(feats + (x,), dim=1) + if timestep_zero_index is not None: + ind = padded_full_embed.shape[1] - x.shape[1] + timestep_zero_index = [(ind + x.shape[1] - img_len, ind + x.shape[1])] + timestep_zero_index.append((feats[0].shape[1] - cap_len, feats[0].shape[1])) + mask = None - img_sizes = [(H, W)] * bsz - l_effective_cap_len = [cap_feats.shape[1]] * bsz - return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis + l_effective_cap_len = [padded_full_embed.shape[1] - img_len] * bsz + return padded_full_embed, mask, img_sizes, l_effective_cap_len, torch.cat(fc + (fc_x,), dim=1), timestep_zero_index def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs): return comfy.patcher_extension.WrapperExecutor.new_class_executor( @@ -604,7 +804,11 @@ class NextDiT(nn.Module): ).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs) # def forward(self, x, t, cap_feats, cap_mask): - def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, transformer_options={}, **kwargs): + def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, ref_latents=[], ref_contexts=[], siglip_feats=[], transformer_options={}, **kwargs): + omni = len(ref_latents) > 0 + if omni: + timesteps = torch.cat([timesteps * 0, timesteps], dim=0) + t = 1.0 - timesteps cap_feats = context cap_mask = attention_mask @@ -619,8 +823,6 @@ class NextDiT(nn.Module): t = self.t_embedder(t * self.time_scale, dtype=x.dtype) # (N, D) adaln_input = t - cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute - if self.clip_text_pooled_proj is not None: pooled = kwargs.get("clip_text_pooled", None) if pooled is not None: @@ -632,7 +834,7 @@ class NextDiT(nn.Module): patches = transformer_options.get("patches", {}) x_is_tensor = isinstance(x, torch.Tensor) - img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, adaln_input, num_tokens, transformer_options=transformer_options) + img, mask, img_size, cap_size, freqs_cis, timestep_zero_index = self.patchify_and_embed(x, cap_feats, cap_mask, adaln_input, num_tokens, ref_latents=ref_latents, ref_contexts=ref_contexts, siglip_feats=siglip_feats, transformer_options=transformer_options) freqs_cis = freqs_cis.to(img.device) transformer_options["total_blocks"] = len(self.layers) @@ -640,7 +842,7 @@ class NextDiT(nn.Module): img_input = img for i, layer in enumerate(self.layers): transformer_options["block_index"] = i - img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options) + img = layer(img, mask, freqs_cis, adaln_input, timestep_zero_index=timestep_zero_index, transformer_options=transformer_options) if "double_block" in patches: for p in patches["double_block"]: out = p({"img": img[:, cap_size[0]:], "img_input": img_input[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options}) @@ -649,8 +851,7 @@ class NextDiT(nn.Module): if "txt" in out: img[:, :cap_size[0]] = out["txt"] - img = self.final_layer(img, adaln_input) + img = self.final_layer(img, adaln_input, timestep_zero_index=timestep_zero_index) img = self.unpatchify(img, img_size, cap_size, return_tensor=x_is_tensor)[:, :, :h, :w] - return -img diff --git a/comfy/model_base.py b/comfy/model_base.py index 49efd700b..28ba2643e 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1150,6 +1150,7 @@ class CosmosPredict2(BaseModel): class Lumina2(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiT) + self.memory_usage_factor_conds = ("ref_latents",) def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) @@ -1169,6 +1170,35 @@ class Lumina2(BaseModel): if clip_text_pooled is not None: out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled) + clip_vision_outputs = kwargs.get("clip_vision_outputs", list(map(lambda a: a.get("clip_vision_output"), kwargs.get("unclip_conditioning", [{}])))) # Z Image omni + if clip_vision_outputs is not None and len(clip_vision_outputs) > 0: + sigfeats = [] + for clip_vision_output in clip_vision_outputs: + if clip_vision_output is not None: + image_size = clip_vision_output.image_sizes[0] + shape = clip_vision_output.last_hidden_state.shape + sigfeats.append(clip_vision_output.last_hidden_state.reshape(shape[0], image_size[1] // 16, image_size[2] // 16, shape[-1])) + if len(sigfeats) > 0: + out['siglip_feats'] = comfy.conds.CONDList(sigfeats) + + ref_latents = kwargs.get("reference_latents", None) + if ref_latents is not None: + latents = [] + for lat in ref_latents: + latents.append(self.process_latent_in(lat)) + out['ref_latents'] = comfy.conds.CONDList(latents) + + ref_contexts = kwargs.get("reference_latents_text_embeds", None) + if ref_contexts is not None: + out['ref_contexts'] = comfy.conds.CONDList(ref_contexts) + + return out + + def extra_conds_shapes(self, **kwargs): + out = {} + ref_latents = kwargs.get("reference_latents", None) + if ref_latents is not None: + out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))]) return out class WAN21(BaseModel): diff --git a/comfy/model_detection.py b/comfy/model_detection.py index aff5a50b9..42884f797 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -446,6 +446,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["time_scale"] = 1000.0 if '{}cap_pad_token'.format(key_prefix) in state_dict_keys: dit_config["pad_tokens_multiple"] = 32 + sig_weight = state_dict.get('{}siglip_embedder.0.weight'.format(key_prefix), None) + if sig_weight is not None: + dit_config["siglip_feat_dim"] = sig_weight.shape[0] return dit_config diff --git a/comfy_extras/nodes_zimage.py b/comfy_extras/nodes_zimage.py new file mode 100644 index 000000000..2ee3c43b1 --- /dev/null +++ b/comfy_extras/nodes_zimage.py @@ -0,0 +1,88 @@ +import node_helpers +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io +import math +import comfy.utils + + +class TextEncodeZImageOmni(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="TextEncodeZImageOmni", + category="advanced/conditioning", + is_experimental=True, + inputs=[ + io.Clip.Input("clip"), + io.ClipVision.Input("image_encoder", optional=True), + io.String.Input("prompt", multiline=True, dynamic_prompts=True), + io.Boolean.Input("auto_resize_images", default=True), + io.Vae.Input("vae", optional=True), + io.Image.Input("image1", optional=True), + io.Image.Input("image2", optional=True), + io.Image.Input("image3", optional=True), + ], + outputs=[ + io.Conditioning.Output(), + ], + ) + + @classmethod + def execute(cls, clip, prompt, image_encoder=None, auto_resize_images=True, vae=None, image1=None, image2=None, image3=None) -> io.NodeOutput: + ref_latents = [] + images = list(filter(lambda a: a is not None, [image1, image2, image3])) + + prompt_list = [] + template = None + if len(images) > 0: + prompt_list = ["<|im_start|>user\n<|vision_start|>"] + prompt_list += ["<|vision_end|><|vision_start|>"] * (len(images) - 1) + prompt_list += ["<|vision_end|><|im_end|>"] + template = "<|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n<|vision_start|>" + + encoded_images = [] + + for i, image in enumerate(images): + if image_encoder is not None: + encoded_images.append(image_encoder.encode_image(image)) + + if vae is not None: + if auto_resize_images: + samples = image.movedim(-1, 1) + total = int(1024 * 1024) + scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2])) + width = round(samples.shape[3] * scale_by / 8.0) * 8 + height = round(samples.shape[2] * scale_by / 8.0) * 8 + + image = comfy.utils.common_upscale(samples, width, height, "area", "disabled").movedim(1, -1) + ref_latents.append(vae.encode(image)) + + tokens = clip.tokenize(prompt, llama_template=template) + conditioning = clip.encode_from_tokens_scheduled(tokens) + + extra_text_embeds = [] + for p in prompt_list: + tokens = clip.tokenize(p, llama_template="{}") + text_embeds = clip.encode_from_tokens_scheduled(tokens) + extra_text_embeds.append(text_embeds[0][0]) + + if len(ref_latents) > 0: + conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": ref_latents}, append=True) + if len(encoded_images) > 0: + conditioning = node_helpers.conditioning_set_values(conditioning, {"clip_vision_outputs": encoded_images}, append=True) + if len(extra_text_embeds) > 0: + conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents_text_embeds": extra_text_embeds}, append=True) + + return io.NodeOutput(conditioning) + + +class ZImageExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + TextEncodeZImageOmni, + ] + + +async def comfy_entrypoint() -> ZImageExtension: + return ZImageExtension() diff --git a/nodes.py b/nodes.py index cba8eacc2..ea5d6e525 100644 --- a/nodes.py +++ b/nodes.py @@ -2373,6 +2373,7 @@ async def init_builtin_extra_nodes(): "nodes_kandinsky5.py", "nodes_wanmove.py", "nodes_image_compare.py", + "nodes_zimage.py", ] import_failed = [] From 0fc3b6e3a6f1d8fdffca3a51cb4d10a06f4e079d Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Tue, 20 Jan 2026 12:17:56 +0800 Subject: [PATCH 20/58] chore: update workflow templates to v0.8.15 (#11984) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 312c7c137..35543525d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.36.14 -comfyui-workflow-templates==0.8.14 +comfyui-workflow-templates==0.8.15 comfyui-embedded-docs==0.4.0 torch torchsde From 4edb87aa50190139a38a2ccd6b6ee35ba9df4da1 Mon Sep 17 00:00:00 2001 From: Comfy Org PR Bot Date: Tue, 20 Jan 2026 13:57:50 +0900 Subject: [PATCH 21/58] Bump comfyui-frontend-package to 1.37.11 (#11976) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 35543525d..ec89dccd2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.36.14 +comfyui-frontend-package==1.37.11 comfyui-workflow-templates==0.8.15 comfyui-embedded-docs==0.4.0 torch From 8ccc0c94fa0d8e43fffe7190e6a36551a53df54a Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 19 Jan 2026 21:32:00 -0800 Subject: [PATCH 22/58] Make omni stuff work on regular z image for easier testing. (#11985) --- comfy/ldm/lumina/model.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index 139f879a1..b114d9e31 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -657,7 +657,7 @@ class NextDiT(nn.Module): device = x.device embeds, freqs_cis, cap_feats_len = self.embed_cap(cap_feats, offset=offset, bsz=bsz, device=device, dtype=x.dtype) - if not omni: + if (not omni) or self.siglip_embedder is None: cap_feats_len = embeds[0].shape[1] + offset embeds += (None,) freqs_cis += (None,) @@ -675,8 +675,9 @@ class NextDiT(nn.Module): siglip_feats, pad_extra = pad_zimage(siglip_feats, self.siglip_pad_token, self.pad_tokens_multiple) # TODO: double check siglip_pos_ids = torch.nn.functional.pad(siglip_pos_ids, (0, 0, 0, pad_extra)) else: - siglip_feats = self.siglip_pad_token.to(device=device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(bsz, self.pad_tokens_multiple, 1) - siglip_pos_ids = torch.zeros((bsz, siglip_feats.shape[1], 3), dtype=torch.float32, device=device) + if self.siglip_pad_token is not None: + siglip_feats = self.siglip_pad_token.to(device=device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(bsz, self.pad_tokens_multiple, 1) + siglip_pos_ids = torch.zeros((bsz, siglip_feats.shape[1], 3), dtype=torch.float32, device=device) if siglip_feats is None: embeds += (None,) @@ -724,8 +725,9 @@ class NextDiT(nn.Module): out = self.embed_all(ref, ref_con, sig_feat, offset=start_t, omni=omni, transformer_options=transformer_options) for i, e in enumerate(out[0]): - embeds[i].append(comfy.utils.repeat_to_batch_size(e, bsz)) - freqs_cis[i].append(out[1][i]) + if e is not None: + embeds[i].append(comfy.utils.repeat_to_batch_size(e, bsz)) + freqs_cis[i].append(out[1][i]) start_t = out[2] leftover_cap = ref_contexts[len(ref_latents):] @@ -759,7 +761,7 @@ class NextDiT(nn.Module): feats = (cap_feats,) fc = (cap_freqs_cis,) - if omni: + if omni and len(embeds[1]) > 0: siglip_mask = None siglip_feats_combined = torch.cat(embeds[1], dim=1) siglip_feats_freqs_cis = torch.cat(freqs_cis[1], dim=1) From ddc541ffdae0fe626de5a33192001f31c6ab93c6 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Tue, 20 Jan 2026 23:05:40 +0200 Subject: [PATCH 23/58] feat(api-nodes): add WaveSpeed nodes (#11945) --- comfy_api_nodes/apis/wavespeed.py | 35 ++++++ comfy_api_nodes/nodes_wavespeed.py | 178 +++++++++++++++++++++++++++++ 2 files changed, 213 insertions(+) create mode 100644 comfy_api_nodes/apis/wavespeed.py create mode 100644 comfy_api_nodes/nodes_wavespeed.py diff --git a/comfy_api_nodes/apis/wavespeed.py b/comfy_api_nodes/apis/wavespeed.py new file mode 100644 index 000000000..07a7bfa5d --- /dev/null +++ b/comfy_api_nodes/apis/wavespeed.py @@ -0,0 +1,35 @@ +from pydantic import BaseModel, Field + + +class SeedVR2ImageRequest(BaseModel): + image: str = Field(...) + target_resolution: str = Field(...) + output_format: str = Field("png") + enable_sync_mode: bool = Field(False) + + +class FlashVSRRequest(BaseModel): + target_resolution: str = Field(...) + video: str = Field(...) + duration: float = Field(...) + + +class TaskCreatedDataResponse(BaseModel): + id: str = Field(...) + + +class TaskCreatedResponse(BaseModel): + code: int = Field(...) + message: str = Field(...) + data: TaskCreatedDataResponse | None = Field(None) + + +class TaskResultDataResponse(BaseModel): + status: str = Field(...) + outputs: list[str] = Field([]) + + +class TaskResultResponse(BaseModel): + code: int = Field(...) + message: str = Field(...) + data: TaskResultDataResponse | None = Field(None) diff --git a/comfy_api_nodes/nodes_wavespeed.py b/comfy_api_nodes/nodes_wavespeed.py new file mode 100644 index 000000000..c59fafd3b --- /dev/null +++ b/comfy_api_nodes/nodes_wavespeed.py @@ -0,0 +1,178 @@ +from typing_extensions import override + +from comfy_api.latest import IO, ComfyExtension, Input +from comfy_api_nodes.apis.wavespeed import ( + FlashVSRRequest, + TaskCreatedResponse, + TaskResultResponse, + SeedVR2ImageRequest, +) +from comfy_api_nodes.util import ( + ApiEndpoint, + download_url_to_video_output, + poll_op, + sync_op, + upload_video_to_comfyapi, + validate_container_format_is_mp4, + validate_video_duration, + upload_images_to_comfyapi, + get_number_of_images, + download_url_to_image_tensor, +) + + +class WavespeedFlashVSRNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="WavespeedFlashVSRNode", + display_name="FlashVSR Video Upscale", + category="api node/video/WaveSpeed", + description="Fast, high-quality video upscaler that " + "boosts resolution and restores clarity for low-resolution or blurry footage.", + inputs=[ + IO.Video.Input("video"), + IO.Combo.Input("target_resolution", options=["720p", "1080p", "2K", "4K"]), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["target_resolution"]), + expr=""" + ( + $price_for_1sec := {"720p": 0.012, "1080p": 0.018, "2k": 0.024, "4k": 0.032}; + { + "type":"usd", + "usd": $lookup($price_for_1sec, widgets.target_resolution), + "format":{"suffix": "/second", "approximate": true} + } + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + video: Input.Video, + target_resolution: str, + ) -> IO.NodeOutput: + validate_container_format_is_mp4(video) + validate_video_duration(video, min_duration=5, max_duration=60 * 10) + initial_res = await sync_op( + cls, + ApiEndpoint(path="/proxy/wavespeed/api/v3/wavespeed-ai/flashvsr", method="POST"), + response_model=TaskCreatedResponse, + data=FlashVSRRequest( + target_resolution=target_resolution.lower(), + video=await upload_video_to_comfyapi(cls, video), + duration=video.get_duration(), + ), + ) + if initial_res.code != 200: + raise ValueError(f"Task creation fails with code={initial_res.code} and message={initial_res.message}") + final_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/wavespeed/api/v3/predictions/{initial_res.data.id}/result"), + response_model=TaskResultResponse, + status_extractor=lambda x: "failed" if x.data is None else x.data.status, + poll_interval=10.0, + max_poll_attempts=480, + ) + if final_response.code != 200: + raise ValueError( + f"Task processing failed with code={final_response.code} and message={final_response.message}" + ) + return IO.NodeOutput(await download_url_to_video_output(final_response.data.outputs[0])) + + +class WavespeedImageUpscaleNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="WavespeedImageUpscaleNode", + display_name="WaveSpeed Image Upscale", + category="api node/image/WaveSpeed", + description="Boost image resolution and quality, upscaling photos to 4K or 8K for sharp, detailed results.", + inputs=[ + IO.Combo.Input("model", options=["SeedVR2", "Ultimate"]), + IO.Image.Input("image"), + IO.Combo.Input("target_resolution", options=["2K", "4K", "8K"]), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model"]), + expr=""" + ( + $prices := {"seedvr2": 0.01, "ultimate": 0.06}; + {"type":"usd", "usd": $lookup($prices, widgets.model)} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + model: str, + image: Input.Image, + target_resolution: str, + ) -> IO.NodeOutput: + if get_number_of_images(image) != 1: + raise ValueError("Exactly one input image is required.") + if model == "SeedVR2": + model_path = "seedvr2/image" + else: + model_path = "ultimate-image-upscaler" + initial_res = await sync_op( + cls, + ApiEndpoint(path=f"/proxy/wavespeed/api/v3/wavespeed-ai/{model_path}", method="POST"), + response_model=TaskCreatedResponse, + data=SeedVR2ImageRequest( + target_resolution=target_resolution.lower(), + image=(await upload_images_to_comfyapi(cls, image, max_images=1))[0], + ), + ) + if initial_res.code != 200: + raise ValueError(f"Task creation fails with code={initial_res.code} and message={initial_res.message}") + final_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/wavespeed/api/v3/predictions/{initial_res.data.id}/result"), + response_model=TaskResultResponse, + status_extractor=lambda x: "failed" if x.data is None else x.data.status, + poll_interval=10.0, + max_poll_attempts=480, + ) + if final_response.code != 200: + raise ValueError( + f"Task processing failed with code={final_response.code} and message={final_response.message}" + ) + return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.outputs[0])) + + +class WavespeedExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + WavespeedFlashVSRNode, + WavespeedImageUpscaleNode, + ] + + +async def comfy_entrypoint() -> WavespeedExtension: + return WavespeedExtension() From 965d0ed509ce46a3328c342aee23a234ba6e4f88 Mon Sep 17 00:00:00 2001 From: Ivan Zorin Date: Wed, 21 Jan 2026 01:44:28 +0200 Subject: [PATCH 24/58] fix: remove normalization of audio in LTX Mel spectrogram creation (#11990) For LTX Audio VAE, remove normalization of audio during MEL spectrogram creation. This aligs inference with training and prevents loud audio from being attenuated. --- comfy/ldm/lightricks/vae/audio_vae.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/comfy/ldm/lightricks/vae/audio_vae.py b/comfy/ldm/lightricks/vae/audio_vae.py index 29d9e6c29..55a074661 100644 --- a/comfy/ldm/lightricks/vae/audio_vae.py +++ b/comfy/ldm/lightricks/vae/audio_vae.py @@ -103,20 +103,10 @@ class AudioPreprocessor: return waveform return torchaudio.functional.resample(waveform, source_rate, self.target_sample_rate) - @staticmethod - def normalize_amplitude( - waveform: torch.Tensor, max_amplitude: float = 0.5, eps: float = 1e-5 - ) -> torch.Tensor: - waveform = waveform - waveform.mean(dim=2, keepdim=True) - peak = torch.max(torch.abs(waveform)) + eps - scale = peak.clamp(max=max_amplitude) / peak - return waveform * scale - def waveform_to_mel( self, waveform: torch.Tensor, waveform_sample_rate: int, device ) -> torch.Tensor: waveform = self.resample(waveform, waveform_sample_rate) - waveform = self.normalize_amplitude(waveform) mel_transform = torchaudio.transforms.MelSpectrogram( sample_rate=self.target_sample_rate, From c4a14df9a35336dbfff096683c5015ce726c269d Mon Sep 17 00:00:00 2001 From: Mylo <36931363+gitmylo@users.noreply.github.com> Date: Wed, 21 Jan 2026 00:46:11 +0100 Subject: [PATCH 25/58] Dynamically detect chroma radiance patch size (#11991) --- comfy/model_detection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 42884f797..dad206a2f 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -253,7 +253,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["image_model"] = "chroma_radiance" dit_config["in_channels"] = 3 dit_config["out_channels"] = 3 - dit_config["patch_size"] = 16 + dit_config["patch_size"] = state_dict.get('{}img_in_patch.weight'.format(key_prefix)).size(dim=-1) dit_config["nerf_hidden_size"] = 64 dit_config["nerf_mlp_ratio"] = 4 dit_config["nerf_depth"] = 4 From e755268e7b7843695f52b87595afcb09c1e9fd87 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 20 Jan 2026 20:08:31 -0800 Subject: [PATCH 26/58] Config for Qwen 3 0.6B model. (#11998) --- comfy/text_encoders/llama.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index 331a30f61..3080a3e09 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -77,6 +77,28 @@ class Qwen25_3BConfig: rope_scale = None final_norm: bool = True +@dataclass +class Qwen3_06BConfig: + vocab_size: int = 151936 + hidden_size: int = 1024 + intermediate_size: int = 3072 + num_hidden_layers: int = 28 + num_attention_heads: int = 16 + num_key_value_heads: int = 8 + max_position_embeddings: int = 32768 + rms_norm_eps: float = 1e-6 + rope_theta: float = 1000000.0 + transformer_type: str = "llama" + head_dim = 128 + rms_norm_add = False + mlp_activation = "silu" + qkv_bias = False + rope_dims = None + q_norm = "gemma3" + k_norm = "gemma3" + rope_scale = None + final_norm: bool = True + @dataclass class Qwen3_4BConfig: vocab_size: int = 151936 @@ -641,6 +663,15 @@ class Qwen25_3B(BaseLlama, torch.nn.Module): self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.dtype = dtype +class Qwen3_06B(BaseLlama, torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + config = Qwen3_06BConfig(**config_dict) + self.num_layers = config.num_hidden_layers + + self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) + self.dtype = dtype + class Qwen3_4B(BaseLlama, torch.nn.Module): def __init__(self, config_dict, dtype, device, operations): super().__init__() From 0fc15700be9b555f351034942b5bd7243bdf6bcc Mon Sep 17 00:00:00 2001 From: Markury Date: Tue, 20 Jan 2026 23:18:33 -0500 Subject: [PATCH 27/58] Add LyCoris LoKr MLP layer support for Flux2 (#11997) --- comfy/utils.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/comfy/utils.py b/comfy/utils.py index 5e79fb449..d97d753e6 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -611,6 +611,14 @@ def flux_to_diffusers(mmdit_config, output_prefix=""): "ff_context.net.0.proj.bias": "txt_mlp.0.bias", "ff_context.net.2.weight": "txt_mlp.2.weight", "ff_context.net.2.bias": "txt_mlp.2.bias", + "ff.linear_in.weight": "img_mlp.0.weight", # LyCoris LoKr + "ff.linear_in.bias": "img_mlp.0.bias", + "ff.linear_out.weight": "img_mlp.2.weight", + "ff.linear_out.bias": "img_mlp.2.bias", + "ff_context.linear_in.weight": "txt_mlp.0.weight", + "ff_context.linear_in.bias": "txt_mlp.0.bias", + "ff_context.linear_out.weight": "txt_mlp.2.weight", + "ff_context.linear_out.bias": "txt_mlp.2.bias", "attn.norm_q.weight": "img_attn.norm.query_norm.scale", "attn.norm_k.weight": "img_attn.norm.key_norm.scale", "attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale", From 451af7015435df22e6313ae79f25fe2ef336a96d Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Wed, 21 Jan 2026 14:03:45 +0200 Subject: [PATCH 28/58] fix(api-nodes-Vidu): allow passing up to 7 subjects in Vidu Reference node (#12002) --- comfy_api_nodes/nodes_vidu.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy_api_nodes/nodes_vidu.py b/comfy_api_nodes/nodes_vidu.py index 8edb02f39..b9114c4bb 100644 --- a/comfy_api_nodes/nodes_vidu.py +++ b/comfy_api_nodes/nodes_vidu.py @@ -703,7 +703,7 @@ class Vidu2ReferenceVideoNode(IO.ComfyNode): "subjects", template=IO.Autogrow.TemplateNames( IO.Image.Input("reference_images"), - names=["subject1", "subject2", "subject3"], + names=["subject1", "subject2", "subject3", "subject4", "subject5", "subject6", "subject7"], min=1, ), tooltip="For each subject, provide up to 3 reference images (7 images total across all subjects). " @@ -738,7 +738,7 @@ class Vidu2ReferenceVideoNode(IO.ComfyNode): control_after_generate=True, ), IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "4:3", "3:4", "1:1"]), - IO.Combo.Input("resolution", options=["720p"]), + IO.Combo.Input("resolution", options=["720p", "1080p"]), IO.Combo.Input( "movement_amplitude", options=["auto", "small", "medium", "large"], From bdeac8897e522b9637a6a427fdc8a50a6abd6b20 Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Wed, 21 Jan 2026 15:36:02 -0800 Subject: [PATCH 29/58] feat: Add search_aliases field to node schema (#12010) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: Add search_aliases field to node schema Adds `search_aliases` field to improve node discoverability. Users can define alternative search terms for nodes (e.g., "text concat" → StringConcatenate). Changes: - Add `search_aliases: list[str]` to V3 Schema - Add `SEARCH_ALIASES` support for V1 nodes - Include field in `/object_info` response - Add aliases to high-priority core nodes V1 usage: ```python class MyNode: SEARCH_ALIASES = ["alt name", "synonym"] ``` V3 usage: ```python io.Schema( node_id="MyNode", search_aliases=["alt name", "synonym"], ... ) ``` ## Related PRs - Frontend: Comfy-Org/ComfyUI_frontend#XXXX (draft - merge after this) - Docs: Comfy-Org/docs#XXXX (draft - merge after stable) * Propagate search_aliases through V3 Schema.get_v1_info to NodeInfoV1 --- comfy_api/latest/_io.py | 4 ++++ comfy_extras/nodes_post_processing.py | 1 + comfy_extras/nodes_preview_any.py | 1 + comfy_extras/nodes_string.py | 1 + comfy_extras/nodes_upscale_model.py | 1 + nodes.py | 15 +++++++++++++++ server.py | 2 ++ 7 files changed, 25 insertions(+) diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 4969d3506..a60020ca8 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -1249,6 +1249,7 @@ class NodeInfoV1: experimental: bool=None api_node: bool=None price_badge: dict | None = None + search_aliases: list[str]=None @dataclass class NodeInfoV3: @@ -1346,6 +1347,8 @@ class Schema: hidden: list[Hidden] = field(default_factory=list) description: str="" """Node description, shown as a tooltip when hovering over the node.""" + search_aliases: list[str] = field(default_factory=list) + """Alternative names for search. Useful for synonyms, abbreviations, or old names after renaming.""" is_input_list: bool = False """A flag indicating if this node implements the additional code necessary to deal with OUTPUT_IS_LIST nodes. @@ -1483,6 +1486,7 @@ class Schema: api_node=self.is_api_node, python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes"), price_badge=self.price_badge.as_dict(self.inputs) if self.price_badge is not None else None, + search_aliases=self.search_aliases if self.search_aliases else None, ) return info diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index 2e559c35c..6011275d6 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -550,6 +550,7 @@ class BatchImagesNode(io.ComfyNode): node_id="BatchImagesNode", display_name="Batch Images", category="image", + search_aliases=["batch", "image batch", "batch images", "combine images", "merge images", "stack images"], inputs=[ io.Autogrow.Input("images", template=autogrow_template) ], diff --git a/comfy_extras/nodes_preview_any.py b/comfy_extras/nodes_preview_any.py index 139b07c93..91502ebf2 100644 --- a/comfy_extras/nodes_preview_any.py +++ b/comfy_extras/nodes_preview_any.py @@ -16,6 +16,7 @@ class PreviewAny(): OUTPUT_NODE = True CATEGORY = "utils" + SEARCH_ALIASES = ["preview", "show", "display", "view", "show text", "display text", "preview text", "show output", "inspect", "debug"] def main(self, source=None): value = 'None' diff --git a/comfy_extras/nodes_string.py b/comfy_extras/nodes_string.py index 571d89f62..a2d5f0d94 100644 --- a/comfy_extras/nodes_string.py +++ b/comfy_extras/nodes_string.py @@ -11,6 +11,7 @@ class StringConcatenate(io.ComfyNode): node_id="StringConcatenate", display_name="Concatenate", category="utils/string", + search_aliases=["text concat", "join text", "merge text", "combine strings", "concat", "concatenate", "append text", "combine text", "string"], inputs=[ io.String.Input("string_a", multiline=True), io.String.Input("string_b", multiline=True), diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py index ed587851c..97b9e948d 100644 --- a/comfy_extras/nodes_upscale_model.py +++ b/comfy_extras/nodes_upscale_model.py @@ -53,6 +53,7 @@ class ImageUpscaleWithModel(io.ComfyNode): node_id="ImageUpscaleWithModel", display_name="Upscale Image (using Model)", category="image/upscaling", + search_aliases=["upscale", "upscaler", "upsc", "enlarge image", "super resolution", "hires", "superres", "increase resolution"], inputs=[ io.UpscaleModel.Input("upscale_model"), io.Image.Input("image"), diff --git a/nodes.py b/nodes.py index ea5d6e525..67b61dcfe 100644 --- a/nodes.py +++ b/nodes.py @@ -70,6 +70,7 @@ class CLIPTextEncode(ComfyNodeABC): CATEGORY = "conditioning" DESCRIPTION = "Encodes a text prompt using a CLIP model into an embedding that can be used to guide the diffusion model towards generating specific images." + SEARCH_ALIASES = ["text", "prompt", "text prompt", "positive prompt", "negative prompt", "encode text", "text encoder", "encode prompt"] def encode(self, clip, text): if clip is None: @@ -86,6 +87,7 @@ class ConditioningCombine: FUNCTION = "combine" CATEGORY = "conditioning" + SEARCH_ALIASES = ["combine", "merge conditioning", "combine prompts", "merge prompts", "mix prompts", "add prompt"] def combine(self, conditioning_1, conditioning_2): return (conditioning_1 + conditioning_2, ) @@ -294,6 +296,7 @@ class VAEDecode: CATEGORY = "latent" DESCRIPTION = "Decodes latent images back into pixel space images." + SEARCH_ALIASES = ["decode", "decode latent", "latent to image", "render latent"] def decode(self, vae, samples): latent = samples["samples"] @@ -346,6 +349,7 @@ class VAEEncode: FUNCTION = "encode" CATEGORY = "latent" + SEARCH_ALIASES = ["encode", "encode image", "image to latent"] def encode(self, vae, pixels): t = vae.encode(pixels) @@ -581,6 +585,7 @@ class CheckpointLoaderSimple: CATEGORY = "loaders" DESCRIPTION = "Loads a diffusion model checkpoint, diffusion models are used to denoise latents." + SEARCH_ALIASES = ["load model", "checkpoint", "model loader", "load checkpoint", "ckpt", "model"] def load_checkpoint(self, ckpt_name): ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name) @@ -667,6 +672,7 @@ class LoraLoader: CATEGORY = "loaders" DESCRIPTION = "LoRAs are used to modify diffusion and CLIP models, altering the way in which latents are denoised such as applying styles. Multiple LoRA nodes can be linked together." + SEARCH_ALIASES = ["lora", "load lora", "apply lora", "lora loader", "lora model"] def load_lora(self, model, clip, lora_name, strength_model, strength_clip): if strength_model == 0 and strength_clip == 0: @@ -814,6 +820,7 @@ class ControlNetLoader: FUNCTION = "load_controlnet" CATEGORY = "loaders" + SEARCH_ALIASES = ["controlnet", "control net", "cn", "load controlnet", "controlnet loader"] def load_controlnet(self, control_net_name): controlnet_path = folder_paths.get_full_path_or_raise("controlnet", control_net_name) @@ -890,6 +897,7 @@ class ControlNetApplyAdvanced: FUNCTION = "apply_controlnet" CATEGORY = "conditioning/controlnet" + SEARCH_ALIASES = ["controlnet", "apply controlnet", "use controlnet", "control net"] def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None, extra_concat=[]): if strength == 0: @@ -1200,6 +1208,7 @@ class EmptyLatentImage: CATEGORY = "latent" DESCRIPTION = "Create a new batch of empty latent images to be denoised via sampling." + SEARCH_ALIASES = ["empty", "empty latent", "new latent", "create latent", "blank latent", "blank"] def generate(self, width, height, batch_size=1): latent = torch.zeros([batch_size, 4, height // 8, width // 8], device=self.device) @@ -1540,6 +1549,7 @@ class KSampler: CATEGORY = "sampling" DESCRIPTION = "Uses the provided model, positive and negative conditioning to denoise the latent image." + SEARCH_ALIASES = ["sampler", "sample", "generate", "denoise", "diffuse", "txt2img", "img2img"] def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0): return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise) @@ -1604,6 +1614,7 @@ class SaveImage: CATEGORY = "image" DESCRIPTION = "Saves the input images to your ComfyUI output directory." + SEARCH_ALIASES = ["save", "save image", "export image", "output image", "write image", "download"] def save_images(self, images, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None): filename_prefix += self.prefix_append @@ -1640,6 +1651,8 @@ class PreviewImage(SaveImage): self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5)) self.compress_level = 1 + SEARCH_ALIASES = ["preview", "preview image", "show image", "view image", "display image", "image viewer"] + @classmethod def INPUT_TYPES(s): return {"required": @@ -1658,6 +1671,7 @@ class LoadImage: } CATEGORY = "image" + SEARCH_ALIASES = ["load image", "open image", "import image", "image input", "upload image", "read image", "image loader"] RETURN_TYPES = ("IMAGE", "MASK") FUNCTION = "load_image" @@ -1810,6 +1824,7 @@ class ImageScale: FUNCTION = "upscale" CATEGORY = "image/upscaling" + SEARCH_ALIASES = ["resize", "resize image", "scale image", "image resize", "zoom", "zoom in", "change size"] def upscale(self, image, upscale_method, width, height, crop): if width == 0 and height == 0: diff --git a/server.py b/server.py index 04a577488..1888745b7 100644 --- a/server.py +++ b/server.py @@ -682,6 +682,8 @@ class PromptServer(): if hasattr(obj_class, 'API_NODE'): info['api_node'] = obj_class.API_NODE + + info['search_aliases'] = getattr(obj_class, 'SEARCH_ALIASES', []) return info @routes.get("/object_info") From abe2ec26a61ff670b9c0e71e4821c873368c8728 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 21 Jan 2026 16:44:28 -0800 Subject: [PATCH 30/58] Support the Anima model. (#12012) --- comfy/ldm/anima/model.py | 202 +++++++++++++++++++++++++++++++++++ comfy/model_base.py | 22 ++++ comfy/model_detection.py | 2 + comfy/sd.py | 7 ++ comfy/supported_models.py | 33 +++++- comfy/text_encoders/anima.py | 61 +++++++++++ 6 files changed, 326 insertions(+), 1 deletion(-) create mode 100644 comfy/ldm/anima/model.py create mode 100644 comfy/text_encoders/anima.py diff --git a/comfy/ldm/anima/model.py b/comfy/ldm/anima/model.py new file mode 100644 index 000000000..2e6ed58fa --- /dev/null +++ b/comfy/ldm/anima/model.py @@ -0,0 +1,202 @@ +from comfy.ldm.cosmos.predict2 import MiniTrainDIT +import torch +from torch import nn +import torch.nn.functional as F + + +def rotate_half(x): + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(x, cos, sin, unsqueeze_dim=1): + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + x_embed = (x * cos) + (rotate_half(x) * sin) + return x_embed + + +class RotaryEmbedding(nn.Module): + def __init__(self, head_dim): + super().__init__() + self.rope_theta = 10000 + inv_freq = 1.0 / (self.rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.int64).to(dtype=torch.float) / head_dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + @torch.no_grad() + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class Attention(nn.Module): + def __init__(self, query_dim, context_dim, n_heads, head_dim, device=None, dtype=None, operations=None): + super().__init__() + + inner_dim = head_dim * n_heads + self.n_heads = n_heads + self.head_dim = head_dim + self.query_dim = query_dim + self.context_dim = context_dim + + self.q_proj = operations.Linear(query_dim, inner_dim, bias=False, device=device, dtype=dtype) + self.q_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype) + + self.k_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype) + self.k_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype) + + self.v_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype) + + self.o_proj = operations.Linear(inner_dim, query_dim, bias=False, device=device, dtype=dtype) + + def forward(self, x, mask=None, context=None, position_embeddings=None, position_embeddings_context=None): + context = x if context is None else context + input_shape = x.shape[:-1] + q_shape = (*input_shape, self.n_heads, self.head_dim) + context_shape = context.shape[:-1] + kv_shape = (*context_shape, self.n_heads, self.head_dim) + + query_states = self.q_norm(self.q_proj(x).view(q_shape)).transpose(1, 2) + key_states = self.k_norm(self.k_proj(context).view(kv_shape)).transpose(1, 2) + value_states = self.v_proj(context).view(kv_shape).transpose(1, 2) + + if position_embeddings is not None: + assert position_embeddings_context is not None + cos, sin = position_embeddings + query_states = apply_rotary_pos_emb(query_states, cos, sin) + cos, sin = position_embeddings_context + key_states = apply_rotary_pos_emb(key_states, cos, sin) + + attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=mask) + + attn_output = attn_output.transpose(1, 2).reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output + + def init_weights(self): + torch.nn.init.zeros_(self.o_proj.weight) + + +class TransformerBlock(nn.Module): + def __init__(self, source_dim, model_dim, num_heads=16, mlp_ratio=4.0, use_self_attn=False, layer_norm=False, device=None, dtype=None, operations=None): + super().__init__() + self.use_self_attn = use_self_attn + + if self.use_self_attn: + self.norm_self_attn = operations.LayerNorm(model_dim, device=device, dtype=dtype) if layer_norm else operations.RMSNorm(model_dim, eps=1e-6, device=device, dtype=dtype) + self.self_attn = Attention( + query_dim=model_dim, + context_dim=model_dim, + n_heads=num_heads, + head_dim=model_dim//num_heads, + device=device, + dtype=dtype, + operations=operations, + ) + + self.norm_cross_attn = operations.LayerNorm(model_dim, device=device, dtype=dtype) if layer_norm else operations.RMSNorm(model_dim, eps=1e-6, device=device, dtype=dtype) + self.cross_attn = Attention( + query_dim=model_dim, + context_dim=source_dim, + n_heads=num_heads, + head_dim=model_dim//num_heads, + device=device, + dtype=dtype, + operations=operations, + ) + + self.norm_mlp = operations.LayerNorm(model_dim, device=device, dtype=dtype) if layer_norm else operations.RMSNorm(model_dim, eps=1e-6, device=device, dtype=dtype) + self.mlp = nn.Sequential( + operations.Linear(model_dim, int(model_dim * mlp_ratio), device=device, dtype=dtype), + nn.GELU(), + operations.Linear(int(model_dim * mlp_ratio), model_dim, device=device, dtype=dtype) + ) + + def forward(self, x, context, target_attention_mask=None, source_attention_mask=None, position_embeddings=None, position_embeddings_context=None): + if self.use_self_attn: + normed = self.norm_self_attn(x) + attn_out = self.self_attn(normed, mask=target_attention_mask, position_embeddings=position_embeddings, position_embeddings_context=position_embeddings) + x = x + attn_out + + normed = self.norm_cross_attn(x) + attn_out = self.cross_attn(normed, mask=source_attention_mask, context=context, position_embeddings=position_embeddings, position_embeddings_context=position_embeddings_context) + x = x + attn_out + + x = x + self.mlp(self.norm_mlp(x)) + return x + + def init_weights(self): + torch.nn.init.zeros_(self.mlp[2].weight) + self.cross_attn.init_weights() + + +class LLMAdapter(nn.Module): + def __init__( + self, + source_dim=1024, + target_dim=1024, + model_dim=1024, + num_layers=6, + num_heads=16, + use_self_attn=True, + layer_norm=False, + device=None, + dtype=None, + operations=None, + ): + super().__init__() + + self.embed = operations.Embedding(32128, target_dim, device=device, dtype=dtype) + if model_dim != target_dim: + self.in_proj = operations.Linear(target_dim, model_dim, device=device, dtype=dtype) + else: + self.in_proj = nn.Identity() + self.rotary_emb = RotaryEmbedding(model_dim//num_heads) + self.blocks = nn.ModuleList([ + TransformerBlock(source_dim, model_dim, num_heads=num_heads, use_self_attn=use_self_attn, layer_norm=layer_norm, device=device, dtype=dtype, operations=operations) for _ in range(num_layers) + ]) + self.out_proj = operations.Linear(model_dim, target_dim, device=device, dtype=dtype) + self.norm = operations.RMSNorm(target_dim, eps=1e-6, device=device, dtype=dtype) + + def forward(self, source_hidden_states, target_input_ids, target_attention_mask=None, source_attention_mask=None): + if target_attention_mask is not None: + target_attention_mask = target_attention_mask.to(torch.bool) + if target_attention_mask.ndim == 2: + target_attention_mask = target_attention_mask.unsqueeze(1).unsqueeze(1) + + if source_attention_mask is not None: + source_attention_mask = source_attention_mask.to(torch.bool) + if source_attention_mask.ndim == 2: + source_attention_mask = source_attention_mask.unsqueeze(1).unsqueeze(1) + + x = self.in_proj(self.embed(target_input_ids)) + context = source_hidden_states + position_ids = torch.arange(x.shape[1], device=x.device).unsqueeze(0) + position_ids_context = torch.arange(context.shape[1], device=x.device).unsqueeze(0) + position_embeddings = self.rotary_emb(x, position_ids) + position_embeddings_context = self.rotary_emb(x, position_ids_context) + for block in self.blocks: + x = block(x, context, target_attention_mask=target_attention_mask, source_attention_mask=source_attention_mask, position_embeddings=position_embeddings, position_embeddings_context=position_embeddings_context) + return self.norm(self.out_proj(x)) + + +class Anima(MiniTrainDIT): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.llm_adapter = LLMAdapter(device=kwargs.get("device"), dtype=kwargs.get("dtype"), operations=kwargs.get("operations")) + + def preprocess_text_embeds(self, text_embeds, text_ids): + if text_ids is not None: + return self.llm_adapter(text_embeds, text_ids) + else: + return text_embeds diff --git a/comfy/model_base.py b/comfy/model_base.py index 28ba2643e..1d57562cc 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -49,6 +49,7 @@ import comfy.ldm.ace.model import comfy.ldm.omnigen.omnigen2 import comfy.ldm.qwen_image.model import comfy.ldm.kandinsky5.model +import comfy.ldm.anima.model import comfy.model_management import comfy.patcher_extension @@ -1147,6 +1148,27 @@ class CosmosPredict2(BaseModel): sigma = (sigma / (sigma + 1)) return latent_image / (1.0 - sigma) +class Anima(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.anima.model.Anima) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + cross_attn = kwargs.get("cross_attn", None) + t5xxl_ids = kwargs.get("t5xxl_ids", None) + t5xxl_weights = kwargs.get("t5xxl_weights", None) + device = kwargs["device"] + if cross_attn is not None: + if t5xxl_ids is not None: + cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype()), t5xxl_ids.unsqueeze(0).to(device=device)) + if t5xxl_weights is not None: + cross_attn *= t5xxl_weights.unsqueeze(0).unsqueeze(-1).to(cross_attn) + + if cross_attn.shape[1] < 512: + cross_attn = torch.nn.functional.pad(cross_attn, (0, 0, 0, 512 - cross_attn.shape[1])) + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + return out + class Lumina2(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiT) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index dad206a2f..b29a033cc 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -550,6 +550,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): if '{}blocks.0.mlp.layer1.weight'.format(key_prefix) in state_dict_keys: # Cosmos predict2 dit_config = {} dit_config["image_model"] = "cosmos_predict2" + if "{}llm_adapter.blocks.0.cross_attn.q_proj.weight".format(key_prefix) in state_dict_keys: + dit_config["image_model"] = "anima" dit_config["max_img_h"] = 240 dit_config["max_img_w"] = 240 dit_config["max_frames"] = 128 diff --git a/comfy/sd.py b/comfy/sd.py index 77700dfd3..f7f6a44a0 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -57,6 +57,7 @@ import comfy.text_encoders.ovis import comfy.text_encoders.kandinsky5 import comfy.text_encoders.jina_clip_2 import comfy.text_encoders.newbie +import comfy.text_encoders.anima import comfy.model_patcher import comfy.lora @@ -1048,6 +1049,7 @@ class TEModel(Enum): GEMMA_3_12B = 18 JINA_CLIP_2 = 19 QWEN3_8B = 20 + QWEN3_06B = 21 def detect_te_model(sd): @@ -1093,6 +1095,8 @@ def detect_te_model(sd): return TEModel.QWEN3_2B elif weight.shape[0] == 4096: return TEModel.QWEN3_8B + elif weight.shape[0] == 1024: + return TEModel.QWEN3_06B if weight.shape[0] == 5120: if "model.layers.39.post_attention_layernorm.weight" in sd: return TEModel.MISTRAL3_24B @@ -1233,6 +1237,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip elif te_model == TEModel.JINA_CLIP_2: clip_target.clip = comfy.text_encoders.jina_clip_2.JinaClip2TextModelWrapper clip_target.tokenizer = comfy.text_encoders.jina_clip_2.JinaClip2TokenizerWrapper + elif te_model == TEModel.QWEN3_06B: + clip_target.clip = comfy.text_encoders.anima.te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.anima.AnimaTokenizer else: # clip_l if clip_type == CLIPType.SD3: diff --git a/comfy/supported_models.py b/comfy/supported_models.py index c8a7f6efb..70abebf46 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -23,6 +23,7 @@ import comfy.text_encoders.qwen_image import comfy.text_encoders.hunyuan_image import comfy.text_encoders.kandinsky5 import comfy.text_encoders.z_image +import comfy.text_encoders.anima from . import supported_models_base from . import latent_formats @@ -992,6 +993,36 @@ class CosmosT2IPredict2(supported_models_base.BASE): t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.cosmos.CosmosT5Tokenizer, comfy.text_encoders.cosmos.te(**t5_detect)) +class Anima(supported_models_base.BASE): + unet_config = { + "image_model": "anima", + } + + sampling_settings = { + "multiplier": 1.0, + "shift": 3.0, + } + + unet_extra_config = {} + latent_format = latent_formats.Wan21 + + memory_usage_factor = 1.0 + + supported_inference_dtypes = [torch.bfloat16, torch.float32] + + def __init__(self, unet_config): + super().__init__(unet_config) + self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.95 + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.Anima(self, device=device) + return out + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_06b.transformer.".format(pref)) + return supported_models_base.ClipTarget(comfy.text_encoders.anima.AnimaTokenizer, comfy.text_encoders.anima.te(**detect)) + class CosmosI2VPredict2(CosmosT2IPredict2): unet_config = { "image_model": "cosmos_predict2", @@ -1551,6 +1582,6 @@ class Kandinsky5Image(Kandinsky5): return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect)) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima] models += [SVD_img2vid] diff --git a/comfy/text_encoders/anima.py b/comfy/text_encoders/anima.py new file mode 100644 index 000000000..41f95bcb6 --- /dev/null +++ b/comfy/text_encoders/anima.py @@ -0,0 +1,61 @@ +from transformers import Qwen2Tokenizer, T5TokenizerFast +import comfy.text_encoders.llama +from comfy import sd1_clip +import os +import torch + + +class Qwen3Tokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer") + super().__init__(tokenizer_path, pad_with_end=False, embedding_size=1024, embedding_key='qwen3_06b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data) + +class T5XXLTokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer") + super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_data=tokenizer_data) + +class AnimaTokenizer: + def __init__(self, embedding_directory=None, tokenizer_data={}): + self.qwen3_06b = Qwen3Tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) + self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) + + def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): + out = {} + qwen_ids = self.qwen3_06b.tokenize_with_weights(text, return_word_ids, **kwargs) + out["qwen3_06b"] = [[(token, 1.0) for token, _ in inner_list] for inner_list in qwen_ids] # Set weights to 1.0 + out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids, **kwargs) + return out + + def untokenize(self, token_weight_pair): + return self.t5xxl.untokenize(token_weight_pair) + + def state_dict(self): + return {} + + +class Qwen3_06BModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}): + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_06B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + + +class AnimaTEModel(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + super().__init__(device=device, dtype=dtype, name="qwen3_06b", clip_model=Qwen3_06BModel, model_options=model_options) + + def encode_token_weights(self, token_weight_pairs): + out = super().encode_token_weights(token_weight_pairs) + out[2]["t5xxl_ids"] = torch.tensor(list(map(lambda a: a[0], token_weight_pairs["t5xxl"][0])), dtype=torch.int) + out[2]["t5xxl_weights"] = torch.tensor(list(map(lambda a: a[1], token_weight_pairs["t5xxl"][0]))) + return out + +def te(dtype_llama=None, llama_quantization_metadata=None): + class AnimaTEModel_(AnimaTEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if dtype_llama is not None: + dtype = dtype_llama + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["quantization_metadata"] = llama_quantization_metadata + super().__init__(device=device, dtype=dtype, model_options=model_options) + return AnimaTEModel_ From f09904720dc8b56bc6823ebdaf5de69465448e46 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Wed, 21 Jan 2026 20:01:35 -0800 Subject: [PATCH 31/58] Fix for edge case of EasyCache when conditionings change during a sampling run (like with timestep scheduling) (#12020) --- comfy_extras/nodes_easycache.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/comfy_extras/nodes_easycache.py b/comfy_extras/nodes_easycache.py index 11b23ffdb..90d730df6 100644 --- a/comfy_extras/nodes_easycache.py +++ b/comfy_extras/nodes_easycache.py @@ -29,8 +29,10 @@ def easycache_forward_wrapper(executor, *args, **kwargs): do_easycache = easycache.should_do_easycache(sigmas) if do_easycache: easycache.check_metadata(x) + # if there isn't a cache diff for current conds, we cannot skip this step + can_apply_cache_diff = easycache.can_apply_cache_diff(uuids) # if first cond marked this step for skipping, skip it and use appropriate cached values - if easycache.skip_current_step: + if easycache.skip_current_step and can_apply_cache_diff: if easycache.verbose: logging.info(f"EasyCache [verbose] - was marked to skip this step by {easycache.first_cond_uuid}. Present uuids: {uuids}") return easycache.apply_cache_diff(x, uuids) @@ -44,7 +46,7 @@ def easycache_forward_wrapper(executor, *args, **kwargs): if easycache.has_output_prev_norm() and easycache.has_relative_transformation_rate(): approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm easycache.cumulative_change_rate += approx_output_change_rate - if easycache.cumulative_change_rate < easycache.reuse_threshold: + if easycache.cumulative_change_rate < easycache.reuse_threshold and can_apply_cache_diff: if easycache.verbose: logging.info(f"EasyCache [verbose] - skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}") # other conds should also skip this step, and instead use their cached values @@ -240,6 +242,9 @@ class EasyCacheHolder: return to_return.clone() return to_return + def can_apply_cache_diff(self, uuids: list[UUID]) -> bool: + return all(uuid in self.uuid_cache_diffs for uuid in uuids) + def apply_cache_diff(self, x: torch.Tensor, uuids: list[UUID]): if self.first_cond_uuid in uuids: self.total_steps_skipped += 1 From 3365ad18a5e0c86b23c6272e5adcedd333fc45cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Thu, 22 Jan 2026 06:03:51 +0200 Subject: [PATCH 32/58] Support LTX2 tiny vae (taeltx_2) (#11929) --- comfy/sd.py | 5 ++--- comfy/taesd/taehv.py | 53 ++++++++++++++++++++++++++++---------------- latent_preview.py | 2 +- nodes.py | 2 +- 4 files changed, 38 insertions(+), 24 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index f7f6a44a0..ce7e6bcff 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -636,14 +636,13 @@ class VAE: self.upscale_index_formula = (4, 16, 16) self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16) self.downscale_index_formula = (4, 16, 16) - if self.latent_channels == 48: # Wan 2.2 + if self.latent_channels in [48, 128]: # Wan 2.2 and LTX2 self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=None) # taehv doesn't need scaling - self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently")) + self.process_input = self.process_output = lambda image: image self.process_output = lambda image: image self.memory_used_decode = lambda shape, dtype: (1800 * (max(1, (shape[-3] ** 0.7 * 0.1)) * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)) elif self.latent_channels == 32 and sd["decoder.22.bias"].shape[0] == 12: # lighttae_hv15 self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=comfy.latent_formats.HunyuanVideo15) - self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently")) self.memory_used_decode = lambda shape, dtype: (1200 * (max(1, (shape[-3] ** 0.7 * 0.05)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype)) else: if sd["decoder.1.weight"].dtype == torch.float16: # taehv currently only available in float16, so assume it's not lighttaew2_1 as otherwise state dicts are identical diff --git a/comfy/taesd/taehv.py b/comfy/taesd/taehv.py index 0e5f9a378..6c06ce19d 100644 --- a/comfy/taesd/taehv.py +++ b/comfy/taesd/taehv.py @@ -112,7 +112,8 @@ def apply_model_with_memblocks(model, x, parallel, show_progress_bar): class TAEHV(nn.Module): - def __init__(self, latent_channels, parallel=False, decoder_time_upscale=(True, True), decoder_space_upscale=(True, True, True), latent_format=None, show_progress_bar=True): + def __init__(self, latent_channels, parallel=False, encoder_time_downscale=(True, True, False), decoder_time_upscale=(False, True, True), decoder_space_upscale=(True, True, True), + latent_format=None, show_progress_bar=False): super().__init__() self.image_channels = 3 self.patch_size = 1 @@ -124,6 +125,9 @@ class TAEHV(nn.Module): self.process_out = latent_format().process_out if latent_format is not None else (lambda x: x) if self.latent_channels in [48, 32]: # Wan 2.2 and HunyuanVideo1.5 self.patch_size = 2 + elif self.latent_channels == 128: # LTX2 + self.patch_size, self.latent_channels, encoder_time_downscale, decoder_time_upscale = 4, 128, (True, True, True), (True, True, True) + if self.latent_channels == 32: # HunyuanVideo1.5 act_func = nn.LeakyReLU(0.2, inplace=True) else: # HunyuanVideo, Wan 2.1 @@ -131,41 +135,52 @@ class TAEHV(nn.Module): self.encoder = nn.Sequential( conv(self.image_channels*self.patch_size**2, 64), act_func, - TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), - TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), - TPool(64, 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), + TPool(64, 2 if encoder_time_downscale[0] else 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), + TPool(64, 2 if encoder_time_downscale[1] else 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), + TPool(64, 2 if encoder_time_downscale[2] else 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), conv(64, self.latent_channels), ) n_f = [256, 128, 64, 64] - self.frames_to_trim = 2**sum(decoder_time_upscale) - 1 + self.decoder = nn.Sequential( Clamp(), conv(self.latent_channels, n_f[0]), act_func, - MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1), TGrow(n_f[0], 1), conv(n_f[0], n_f[1], bias=False), - MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1), TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1), conv(n_f[1], n_f[2], bias=False), - MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1), conv(n_f[2], n_f[3], bias=False), + MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1), TGrow(n_f[0], 2 if decoder_time_upscale[0] else 1), conv(n_f[0], n_f[1], bias=False), + MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1), TGrow(n_f[1], 2 if decoder_time_upscale[1] else 1), conv(n_f[1], n_f[2], bias=False), + MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), TGrow(n_f[2], 2 if decoder_time_upscale[2] else 1), conv(n_f[2], n_f[3], bias=False), act_func, conv(n_f[3], self.image_channels*self.patch_size**2), ) - @property - def show_progress_bar(self): - return self._show_progress_bar - @show_progress_bar.setter - def show_progress_bar(self, value): - self._show_progress_bar = value + self.t_downscale = 2**sum(t.stride == 2 for t in self.encoder if isinstance(t, TPool)) + self.t_upscale = 2**sum(t.stride == 2 for t in self.decoder if isinstance(t, TGrow)) + self.frames_to_trim = self.t_upscale - 1 + self._show_progress_bar = show_progress_bar + + @property + def show_progress_bar(self): + return self._show_progress_bar + + @show_progress_bar.setter + def show_progress_bar(self, value): + self._show_progress_bar = value def encode(self, x, **kwargs): - if self.patch_size > 1: - x = F.pixel_unshuffle(x, self.patch_size) x = x.movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W] - if x.shape[1] % 4 != 0: - # pad at end to multiple of 4 - n_pad = 4 - x.shape[1] % 4 + if self.patch_size > 1: + B, T, C, H, W = x.shape + x = x.reshape(B * T, C, H, W) + x = F.pixel_unshuffle(x, self.patch_size) + x = x.reshape(B, T, C * self.patch_size ** 2, H // self.patch_size, W // self.patch_size) + if x.shape[1] % self.t_downscale != 0: + # pad at end to multiple of t_downscale + n_pad = self.t_downscale - x.shape[1] % self.t_downscale padding = x[:, -1:].repeat_interleave(n_pad, dim=1) x = torch.cat([x, padding], 1) x = apply_model_with_memblocks(self.encoder, x, self.parallel, self.show_progress_bar).movedim(2, 1) return self.process_out(x) def decode(self, x, **kwargs): + x = x.unsqueeze(0) if x.ndim == 4 else x # [T, C, H, W] -> [1, T, C, H, W] + x = x.movedim(1, 2) if x.shape[1] != self.latent_channels else x # [B, T, C, H, W] or [B, C, T, H, W] x = self.process_in(x).movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W] x = apply_model_with_memblocks(self.decoder, x, self.parallel, self.show_progress_bar) if self.patch_size > 1: diff --git a/latent_preview.py b/latent_preview.py index d52e3f7a1..a9d777661 100644 --- a/latent_preview.py +++ b/latent_preview.py @@ -11,7 +11,7 @@ import logging default_preview_method = args.preview_method MAX_PREVIEW_RESOLUTION = args.preview_size -VIDEO_TAES = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5"] +VIDEO_TAES = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5", "taeltx_2"] def preview_to_image(latent_image, do_scale=True): if do_scale: diff --git a/nodes.py b/nodes.py index 67b61dcfe..8864fda60 100644 --- a/nodes.py +++ b/nodes.py @@ -707,7 +707,7 @@ class LoraLoaderModelOnly(LoraLoader): return (self.load_lora(model, None, lora_name, strength_model, 0)[0],) class VAELoader: - video_taes = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5"] + video_taes = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5", "taeltx_2"] image_taes = ["taesd", "taesdxl", "taesd3", "taef1"] @staticmethod def vae_list(s): From 245f6139b65899112d11ff294d36a820f2d69496 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Thu, 22 Jan 2026 06:05:06 +0200 Subject: [PATCH 33/58] More targeted embedding_connector loading for LTX2 text encoder (#11992) Reduces errors --- comfy/text_encoders/lt.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py index c33c77db7..e49161964 100644 --- a/comfy/text_encoders/lt.py +++ b/comfy/text_encoders/lt.py @@ -118,9 +118,18 @@ class LTXAVTEModel(torch.nn.Module): sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight", "model.diffusion_model.video_embeddings_connector.": "video_embeddings_connector.", "model.diffusion_model.audio_embeddings_connector.": "audio_embeddings_connector."}, filter_keys=True) if len(sdo) == 0: sdo = sd - missing, unexpected = self.load_state_dict(sdo, strict=False) - missing = [k for k in missing if not k.startswith("gemma3_12b.")] # filter out keys that belong to the main gemma model - return (missing, unexpected) + + missing_all = [] + unexpected_all = [] + + for prefix, component in [("text_embedding_projection.", self.text_embedding_projection), ("video_embeddings_connector.", self.video_embeddings_connector), ("audio_embeddings_connector.", self.audio_embeddings_connector)]: + component_sd = {k.replace(prefix, ""): v for k, v in sdo.items() if k.startswith(prefix)} + if component_sd: + missing, unexpected = component.load_state_dict(component_sd, strict=False) + missing_all.extend([f"{prefix}{k}" for k in missing]) + unexpected_all.extend([f"{prefix}{k}" for k in unexpected]) + + return (missing_all, unexpected_all) def memory_estimation_function(self, token_weight_pairs, device=None): constant = 6.0 From 16b9aabd52c3b81b365fbf562bbcc4528111ef6b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Thu, 22 Jan 2026 06:09:48 +0200 Subject: [PATCH 34/58] Support Multi/InfiniteTalk (#10179) * re-init * Update model_multitalk.py * whitespace... * Update model_multitalk.py * remove print * this is redundant * remove import * Restore preview functionality * Move block_idx to transformer_options * Remove LoopingSamplerCustomAdvanced * Remove looping functionality, keep extension functionality * Update model_multitalk.py * Handle ref_attn_mask with separate patch to avoid having to always return q and k from self_attn * Chunk attention map calculation for multiple speakers to reduce peak VRAM usage * Update model_multitalk.py * Add ModelPatch type back * Fix for latest upstream * Use DynamicCombo for cleaner node Basically just so that single_speaker mode hides mask inputs and 2nd audio input * Update nodes_wan.py --- comfy/ldm/wan/model.py | 17 +- comfy/ldm/wan/model_multitalk.py | 500 ++++++++++++++++++++++++++++++ comfy_api/latest/_io.py | 3 +- comfy_extras/nodes_model_patch.py | 41 +++ comfy_extras/nodes_wan.py | 169 +++++++++- 5 files changed, 727 insertions(+), 3 deletions(-) create mode 100644 comfy/ldm/wan/model_multitalk.py diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 4216ce831..ea123acb4 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -62,6 +62,8 @@ class WanSelfAttention(nn.Module): x(Tensor): Shape [B, L, num_heads, C / num_heads] freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] """ + patches = transformer_options.get("patches", {}) + b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim def qkv_fn_q(x): @@ -86,6 +88,10 @@ class WanSelfAttention(nn.Module): transformer_options=transformer_options, ) + if "attn1_patch" in patches: + for p in patches["attn1_patch"]: + x = p({"x": x, "q": q, "k": k, "transformer_options": transformer_options}) + x = self.o(x) return x @@ -225,6 +231,8 @@ class WanAttentionBlock(nn.Module): """ # assert e.dtype == torch.float32 + patches = transformer_options.get("patches", {}) + if e.ndim < 4: e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1) else: @@ -242,6 +250,11 @@ class WanAttentionBlock(nn.Module): # cross-attention & ffn x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options) + + if "attn2_patch" in patches: + for p in patches["attn2_patch"]: + x = p({"x": x, "transformer_options": transformer_options}) + y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x))) x = torch.addcmul(x, y, repeat_e(e[5], x)) return x @@ -488,7 +501,7 @@ class WanModel(torch.nn.Module): self.blocks = nn.ModuleList([ wan_attn_block_class(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings) - for _ in range(num_layers) + for i in range(num_layers) ]) # head @@ -541,6 +554,7 @@ class WanModel(torch.nn.Module): # embeddings x = self.patch_embedding(x.float()).to(x.dtype) grid_sizes = x.shape[2:] + transformer_options["grid_sizes"] = grid_sizes x = x.flatten(2).transpose(1, 2) # time embeddings @@ -738,6 +752,7 @@ class VaceWanModel(WanModel): # embeddings x = self.patch_embedding(x.float()).to(x.dtype) grid_sizes = x.shape[2:] + transformer_options["grid_sizes"] = grid_sizes x = x.flatten(2).transpose(1, 2) # time embeddings diff --git a/comfy/ldm/wan/model_multitalk.py b/comfy/ldm/wan/model_multitalk.py new file mode 100644 index 000000000..c9dd98c4d --- /dev/null +++ b/comfy/ldm/wan/model_multitalk.py @@ -0,0 +1,500 @@ +import torch +from einops import rearrange, repeat +import comfy +from comfy.ldm.modules.attention import optimized_attention + + +def calculate_x_ref_attn_map(visual_q, ref_k, ref_target_masks, split_num=8): + scale = 1.0 / visual_q.shape[-1] ** 0.5 + visual_q = visual_q.transpose(1, 2) * scale + + B, H, x_seqlens, K = visual_q.shape + + x_ref_attn_maps = [] + for class_idx, ref_target_mask in enumerate(ref_target_masks): + ref_target_mask = ref_target_mask.view(1, 1, 1, -1) + + x_ref_attnmap = torch.zeros(B, H, x_seqlens, device=visual_q.device, dtype=visual_q.dtype) + chunk_size = min(max(x_seqlens // split_num, 1), x_seqlens) + + for i in range(0, x_seqlens, chunk_size): + end_i = min(i + chunk_size, x_seqlens) + + attn_chunk = visual_q[:, :, i:end_i] @ ref_k.permute(0, 2, 3, 1) # B, H, chunk, ref_seqlens + + # Apply softmax + attn_max = attn_chunk.max(dim=-1, keepdim=True).values + attn_chunk = (attn_chunk - attn_max).exp() + attn_sum = attn_chunk.sum(dim=-1, keepdim=True) + attn_chunk = attn_chunk / (attn_sum + 1e-8) + + # Apply mask and sum + masked_attn = attn_chunk * ref_target_mask + x_ref_attnmap[:, :, i:end_i] = masked_attn.sum(-1) / (ref_target_mask.sum() + 1e-8) + + del attn_chunk, masked_attn + + # Average across heads + x_ref_attnmap = x_ref_attnmap.mean(dim=1) # B, x_seqlens + x_ref_attn_maps.append(x_ref_attnmap) + + del visual_q, ref_k + + return torch.cat(x_ref_attn_maps, dim=0) + +def get_attn_map_with_target(visual_q, ref_k, shape, ref_target_masks=None, split_num=2): + """Args: + query (torch.tensor): B M H K + key (torch.tensor): B M H K + shape (tuple): (N_t, N_h, N_w) + ref_target_masks: [B, N_h * N_w] + """ + + N_t, N_h, N_w = shape + + x_seqlens = N_h * N_w + ref_k = ref_k[:, :x_seqlens] + _, seq_lens, heads, _ = visual_q.shape + class_num, _ = ref_target_masks.shape + x_ref_attn_maps = torch.zeros(class_num, seq_lens).to(visual_q) + + split_chunk = heads // split_num + + for i in range(split_num): + x_ref_attn_maps_perhead = calculate_x_ref_attn_map( + visual_q[:, :, i*split_chunk:(i+1)*split_chunk, :], + ref_k[:, :, i*split_chunk:(i+1)*split_chunk, :], + ref_target_masks + ) + x_ref_attn_maps += x_ref_attn_maps_perhead + + return x_ref_attn_maps / split_num + + +def normalize_and_scale(column, source_range, target_range, epsilon=1e-8): + source_min, source_max = source_range + new_min, new_max = target_range + normalized = (column - source_min) / (source_max - source_min + epsilon) + scaled = normalized * (new_max - new_min) + new_min + return scaled + + +def rotate_half(x): + x = rearrange(x, "... (d r) -> ... d r", r=2) + x1, x2 = x.unbind(dim=-1) + x = torch.stack((-x2, x1), dim=-1) + return rearrange(x, "... d r -> ... (d r)") + + +def get_audio_embeds(encoded_audio, audio_start, audio_end): + audio_embs = [] + human_num = len(encoded_audio) + audio_frames = encoded_audio[0].shape[0] + + indices = (torch.arange(4 + 1) - 2) * 1 + + for human_idx in range(human_num): + if audio_end > audio_frames: # in case of not enough audio for current window, pad with first audio frame as that's most likely silence + pad_len = audio_end - audio_frames + pad_shape = list(encoded_audio[human_idx].shape) + pad_shape[0] = pad_len + pad_tensor = encoded_audio[human_idx][:1].repeat(pad_len, *([1] * (encoded_audio[human_idx].dim() - 1))) + encoded_audio_in = torch.cat([encoded_audio[human_idx], pad_tensor], dim=0) + else: + encoded_audio_in = encoded_audio[human_idx] + center_indices = torch.arange(audio_start, audio_end, 1).unsqueeze(1) + indices.unsqueeze(0) + center_indices = torch.clamp(center_indices, min=0, max=encoded_audio_in.shape[0] - 1) + audio_emb = encoded_audio_in[center_indices].unsqueeze(0) + audio_embs.append(audio_emb) + + return torch.cat(audio_embs, dim=0) + + +def project_audio_features(audio_proj, encoded_audio, audio_start, audio_end): + audio_embs = get_audio_embeds(encoded_audio, audio_start, audio_end) + + first_frame_audio_emb_s = audio_embs[:, :1, ...] + latter_frame_audio_emb = audio_embs[:, 1:, ...] + latter_frame_audio_emb = rearrange(latter_frame_audio_emb, "b (n_t n) w s c -> b n_t n w s c", n=4) + + middle_index = audio_proj.seq_len // 2 + + latter_first_frame_audio_emb = latter_frame_audio_emb[:, :, :1, :middle_index+1, ...] + latter_first_frame_audio_emb = rearrange(latter_first_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c") + latter_last_frame_audio_emb = latter_frame_audio_emb[:, :, -1:, middle_index:, ...] + latter_last_frame_audio_emb = rearrange(latter_last_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c") + latter_middle_frame_audio_emb = latter_frame_audio_emb[:, :, 1:-1, middle_index:middle_index+1, ...] + latter_middle_frame_audio_emb = rearrange(latter_middle_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c") + latter_frame_audio_emb_s = torch.cat([latter_first_frame_audio_emb, latter_middle_frame_audio_emb, latter_last_frame_audio_emb], dim=2) + + audio_emb = audio_proj(first_frame_audio_emb_s, latter_frame_audio_emb_s) + audio_emb = torch.cat(audio_emb.split(1), dim=2) + + return audio_emb + + +class RotaryPositionalEmbedding1D(torch.nn.Module): + def __init__(self, + head_dim, + ): + super().__init__() + self.head_dim = head_dim + self.base = 10000 + + def precompute_freqs_cis_1d(self, pos_indices): + freqs = 1.0 / (self.base ** (torch.arange(0, self.head_dim, 2)[: (self.head_dim // 2)].float() / self.head_dim)) + freqs = freqs.to(pos_indices.device) + freqs = torch.einsum("..., f -> ... f", pos_indices.float(), freqs) + freqs = repeat(freqs, "... n -> ... (n r)", r=2) + return freqs + + def forward(self, x, pos_indices): + freqs_cis = self.precompute_freqs_cis_1d(pos_indices) + + x_ = x.float() + + freqs_cis = freqs_cis.float().to(x.device) + cos, sin = freqs_cis.cos(), freqs_cis.sin() + cos, sin = rearrange(cos, 'n d -> 1 1 n d'), rearrange(sin, 'n d -> 1 1 n d') + x_ = (x_ * cos) + (rotate_half(x_) * sin) + + return x_.type_as(x) + +class SingleStreamAttention(torch.nn.Module): + def __init__( + self, + dim: int, + encoder_hidden_states_dim: int, + num_heads: int, + qkv_bias: bool, + device=None, dtype=None, operations=None + ) -> None: + super().__init__() + self.dim = dim + self.encoder_hidden_states_dim = encoder_hidden_states_dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + + self.q_linear = operations.Linear(dim, dim, bias=qkv_bias, device=device, dtype=dtype) + self.proj = operations.Linear(dim, dim, device=device, dtype=dtype) + self.kv_linear = operations.Linear(encoder_hidden_states_dim, dim * 2, bias=qkv_bias, device=device, dtype=dtype) + + def forward(self, x: torch.Tensor, encoder_hidden_states: torch.Tensor, shape=None) -> torch.Tensor: + N_t, N_h, N_w = shape + + expected_tokens = N_t * N_h * N_w + actual_tokens = x.shape[1] + x_extra = None + + if actual_tokens != expected_tokens: + x_extra = x[:, -N_h * N_w:, :] + x = x[:, :-N_h * N_w, :] + N_t = N_t - 1 + + B = x.shape[0] + S = N_h * N_w + x = x.view(B * N_t, S, self.dim) + + # get q for hidden_state + q = self.q_linear(x).view(B * N_t, S, self.num_heads, self.head_dim) + + # get kv from encoder_hidden_states # shape: (B, N, num_heads, head_dim) + kv = self.kv_linear(encoder_hidden_states) + encoder_k, encoder_v = kv.view(B * N_t, encoder_hidden_states.shape[1], 2, self.num_heads, self.head_dim).unbind(2) + + #print("q.shape", q.shape) #torch.Size([21, 1024, 40, 128]) + x = optimized_attention( + q.transpose(1, 2), + encoder_k.transpose(1, 2), + encoder_v.transpose(1, 2), + heads=self.num_heads, skip_reshape=True, skip_output_reshape=True).transpose(1, 2) + + # linear transform + x = self.proj(x.reshape(B * N_t, S, self.dim)) + x = x.view(B, N_t * S, self.dim) + + if x_extra is not None: + x = torch.cat([x, torch.zeros_like(x_extra)], dim=1) + + return x + +class SingleStreamMultiAttention(SingleStreamAttention): + def __init__( + self, + dim: int, + encoder_hidden_states_dim: int, + num_heads: int, + qkv_bias: bool, + class_range: int = 24, + class_interval: int = 4, + device=None, dtype=None, operations=None + ) -> None: + super().__init__( + dim=dim, + encoder_hidden_states_dim=encoder_hidden_states_dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + device=device, + dtype=dtype, + operations=operations + ) + + # Rotary-embedding layout parameters + self.class_interval = class_interval + self.class_range = class_range + self.max_humans = self.class_range // self.class_interval + + # Constant bucket used for background tokens + self.rope_bak = int(self.class_range // 2) + + self.rope_1d = RotaryPositionalEmbedding1D(self.head_dim) + + def forward( + self, + x: torch.Tensor, + encoder_hidden_states: torch.Tensor, + shape=None, + x_ref_attn_map=None + ) -> torch.Tensor: + encoder_hidden_states = encoder_hidden_states.squeeze(0).to(x.device) + human_num = x_ref_attn_map.shape[0] if x_ref_attn_map is not None else 1 + # Single-speaker fall-through + if human_num <= 1: + return super().forward(x, encoder_hidden_states, shape) + + N_t, N_h, N_w = shape + + x_extra = None + if x.shape[0] * N_t != encoder_hidden_states.shape[0]: + x_extra = x[:, -N_h * N_w:, :] + x = x[:, :-N_h * N_w, :] + N_t = N_t - 1 + x = rearrange(x, "B (N_t S) C -> (B N_t) S C", N_t=N_t) + + # Query projection + B, N, C = x.shape + q = self.q_linear(x) + q = q.view(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3) + + # Use `class_range` logic for 2 speakers + rope_h1 = (0, self.class_interval) + rope_h2 = (self.class_range - self.class_interval, self.class_range) + rope_bak = int(self.class_range // 2) + + # Normalize and scale attention maps for each speaker + max_values = x_ref_attn_map.max(1).values[:, None, None] + min_values = x_ref_attn_map.min(1).values[:, None, None] + max_min_values = torch.cat([max_values, min_values], dim=2) + + human1_max_value, human1_min_value = max_min_values[0, :, 0].max(), max_min_values[0, :, 1].min() + human2_max_value, human2_min_value = max_min_values[1, :, 0].max(), max_min_values[1, :, 1].min() + + human1 = normalize_and_scale(x_ref_attn_map[0], (human1_min_value, human1_max_value), rope_h1) + human2 = normalize_and_scale(x_ref_attn_map[1], (human2_min_value, human2_max_value), rope_h2) + back = torch.full((x_ref_attn_map.size(1),), rope_bak, dtype=human1.dtype, device=human1.device) + + # Token-wise speaker dominance + max_indices = x_ref_attn_map.argmax(dim=0) + normalized_map = torch.stack([human1, human2, back], dim=1) + normalized_pos = normalized_map[torch.arange(x_ref_attn_map.size(1)), max_indices] + + # Apply rotary to Q + q = rearrange(q, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t) + q = self.rope_1d(q, normalized_pos) + q = rearrange(q, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t) + + # Keys / Values + _, N_a, _ = encoder_hidden_states.shape + encoder_kv = self.kv_linear(encoder_hidden_states) + encoder_kv = encoder_kv.view(B, N_a, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + encoder_k, encoder_v = encoder_kv.unbind(0) + + # Rotary for keys – assign centre of each speaker bucket to its context tokens + per_frame = torch.zeros(N_a, dtype=encoder_k.dtype, device=encoder_k.device) + per_frame[: per_frame.size(0) // 2] = (rope_h1[0] + rope_h1[1]) / 2 + per_frame[per_frame.size(0) // 2 :] = (rope_h2[0] + rope_h2[1]) / 2 + encoder_pos = torch.cat([per_frame] * N_t, dim=0) + + encoder_k = rearrange(encoder_k, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t) + encoder_k = self.rope_1d(encoder_k, encoder_pos) + encoder_k = rearrange(encoder_k, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t) + + # Final attention + q = rearrange(q, "B H M K -> B M H K") + encoder_k = rearrange(encoder_k, "B H M K -> B M H K") + encoder_v = rearrange(encoder_v, "B H M K -> B M H K") + + x = optimized_attention( + q.transpose(1, 2), + encoder_k.transpose(1, 2), + encoder_v.transpose(1, 2), + heads=self.num_heads, skip_reshape=True, skip_output_reshape=True).transpose(1, 2) + + # Linear projection + x = x.reshape(B, N, C) + x = self.proj(x) + + # Restore original layout + x = rearrange(x, "(B N_t) S C -> B (N_t S) C", N_t=N_t) + if x_extra is not None: + x = torch.cat([x, torch.zeros_like(x_extra)], dim=1) + + return x + + +class MultiTalkAudioProjModel(torch.nn.Module): + def __init__( + self, + seq_len: int = 5, + seq_len_vf: int = 12, + blocks: int = 12, + channels: int = 768, + intermediate_dim: int = 512, + out_dim: int = 768, + context_tokens: int = 32, + device=None, dtype=None, operations=None + ): + super().__init__() + + self.seq_len = seq_len + self.blocks = blocks + self.channels = channels + self.input_dim = seq_len * blocks * channels + self.input_dim_vf = seq_len_vf * blocks * channels + self.intermediate_dim = intermediate_dim + self.context_tokens = context_tokens + self.out_dim = out_dim + + # define multiple linear layers + self.proj1 = operations.Linear(self.input_dim, intermediate_dim, device=device, dtype=dtype) + self.proj1_vf = operations.Linear(self.input_dim_vf, intermediate_dim, device=device, dtype=dtype) + self.proj2 = operations.Linear(intermediate_dim, intermediate_dim, device=device, dtype=dtype) + self.proj3 = operations.Linear(intermediate_dim, context_tokens * out_dim, device=device, dtype=dtype) + self.norm = operations.LayerNorm(out_dim, device=device, dtype=dtype) + + def forward(self, audio_embeds, audio_embeds_vf): + video_length = audio_embeds.shape[1] + audio_embeds_vf.shape[1] + B, _, _, S, C = audio_embeds.shape + + # process audio of first frame + audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c") + batch_size, window_size, blocks, channels = audio_embeds.shape + audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels) + + # process audio of latter frame + audio_embeds_vf = rearrange(audio_embeds_vf, "bz f w b c -> (bz f) w b c") + batch_size_vf, window_size_vf, blocks_vf, channels_vf = audio_embeds_vf.shape + audio_embeds_vf = audio_embeds_vf.view(batch_size_vf, window_size_vf * blocks_vf * channels_vf) + + # first projection + audio_embeds = torch.relu(self.proj1(audio_embeds)) + audio_embeds_vf = torch.relu(self.proj1_vf(audio_embeds_vf)) + audio_embeds = rearrange(audio_embeds, "(bz f) c -> bz f c", bz=B) + audio_embeds_vf = rearrange(audio_embeds_vf, "(bz f) c -> bz f c", bz=B) + audio_embeds_c = torch.concat([audio_embeds, audio_embeds_vf], dim=1) + batch_size_c, N_t, C_a = audio_embeds_c.shape + audio_embeds_c = audio_embeds_c.view(batch_size_c*N_t, C_a) + + # second projection + audio_embeds_c = torch.relu(self.proj2(audio_embeds_c)) + + context_tokens = self.proj3(audio_embeds_c).reshape(batch_size_c*N_t, self.context_tokens, self.out_dim) + + # normalization and reshape + context_tokens = self.norm(context_tokens) + context_tokens = rearrange(context_tokens, "(bz f) m c -> bz f m c", f=video_length) + + return context_tokens + + +class WanMultiTalkAttentionBlock(torch.nn.Module): + def __init__(self, in_dim=5120, out_dim=768, device=None, dtype=None, operations=None): + super().__init__() + self.audio_cross_attn = SingleStreamMultiAttention(in_dim, out_dim, num_heads=40, qkv_bias=True, device=device, dtype=dtype, operations=operations) + self.norm_x = operations.LayerNorm(in_dim, device=device, dtype=dtype, elementwise_affine=True) + + +class MultiTalkGetAttnMapPatch: + def __init__(self, ref_target_masks=None): + self.ref_target_masks = ref_target_masks + + def __call__(self, kwargs): + transformer_options = kwargs.get("transformer_options", {}) + x = kwargs["x"] + + if self.ref_target_masks is not None: + x_ref_attn_map = get_attn_map_with_target(kwargs["q"], kwargs["k"], transformer_options["grid_sizes"], ref_target_masks=self.ref_target_masks.to(x.device)) + transformer_options["x_ref_attn_map"] = x_ref_attn_map + return x + + +class MultiTalkCrossAttnPatch: + def __init__(self, model_patch, audio_scale=1.0, ref_target_masks=None): + self.model_patch = model_patch + self.audio_scale = audio_scale + self.ref_target_masks = ref_target_masks + + def __call__(self, kwargs): + transformer_options = kwargs.get("transformer_options", {}) + block_idx = transformer_options.get("block_index", None) + x = kwargs["x"] + if block_idx is None: + return torch.zeros_like(x) + + audio_embeds = transformer_options.get("audio_embeds") + x_ref_attn_map = transformer_options.pop("x_ref_attn_map", None) + + norm_x = self.model_patch.model.blocks[block_idx].norm_x(x) + x_audio = self.model_patch.model.blocks[block_idx].audio_cross_attn( + norm_x, audio_embeds.to(x.dtype), + shape=transformer_options["grid_sizes"], + x_ref_attn_map=x_ref_attn_map + ) + x = x + x_audio * self.audio_scale + return x + + def models(self): + return [self.model_patch] + +class MultiTalkApplyModelWrapper: + def __init__(self, init_latents): + self.init_latents = init_latents + + def __call__(self, executor, x, *args, **kwargs): + x[:, :, :self.init_latents.shape[2]] = self.init_latents.to(x) + samples = executor(x, *args, **kwargs) + return samples + + +class InfiniteTalkOuterSampleWrapper: + def __init__(self, motion_frames_latent, model_patch, is_extend=False): + self.motion_frames_latent = motion_frames_latent + self.model_patch = model_patch + self.is_extend = is_extend + + def __call__(self, executor, *args, **kwargs): + model_patcher = executor.class_obj.model_patcher + model_options = executor.class_obj.model_options + process_latent_in = model_patcher.model.process_latent_in + + # for InfiniteTalk, model input first latent(s) need to always be replaced on every step + if self.motion_frames_latent is not None: + wrappers = model_options["transformer_options"]["wrappers"] + w = wrappers.setdefault(comfy.patcher_extension.WrappersMP.APPLY_MODEL, {}) + w["MultiTalk_apply_model"] = [MultiTalkApplyModelWrapper(process_latent_in(self.motion_frames_latent))] + + # run the sampling process + result = executor(*args, **kwargs) + + # insert motion frames before decoding + if self.is_extend: + overlap = self.motion_frames_latent.shape[2] + result = torch.cat([self.motion_frames_latent.to(result), result[:, :, overlap:]], dim=2) + + return result + + def to(self, device_or_dtype): + if isinstance(device_or_dtype, torch.device): + if self.motion_frames_latent is not None: + self.motion_frames_latent = self.motion_frames_latent.to(device_or_dtype) + return self diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index a60020ca8..2ec8d6e4b 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -754,7 +754,7 @@ class AnyType(ComfyTypeIO): Type = Any @comfytype(io_type="MODEL_PATCH") -class MODEL_PATCH(ComfyTypeIO): +class ModelPatch(ComfyTypeIO): Type = Any @comfytype(io_type="AUDIO_ENCODER") @@ -2038,6 +2038,7 @@ __all__ = [ "ControlNet", "Vae", "Model", + "ModelPatch", "ClipVision", "ClipVisionOutput", "AudioEncoder", diff --git a/comfy_extras/nodes_model_patch.py b/comfy_extras/nodes_model_patch.py index f66d28fc9..82c4754a3 100644 --- a/comfy_extras/nodes_model_patch.py +++ b/comfy_extras/nodes_model_patch.py @@ -7,6 +7,7 @@ import comfy.model_management import comfy.ldm.common_dit import comfy.latent_formats import comfy.ldm.lumina.controlnet +from comfy.ldm.wan.model_multitalk import WanMultiTalkAttentionBlock, MultiTalkAudioProjModel class BlockWiseControlBlock(torch.nn.Module): @@ -257,6 +258,14 @@ class ModelPatchLoader: if torch.count_nonzero(ref_weight) == 0: config['broken'] = True model = comfy.ldm.lumina.controlnet.ZImage_Control(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast, **config) + elif "audio_proj.proj1.weight" in sd: + model = MultiTalkModelPatch( + audio_window=5, context_tokens=32, vae_scale=4, + in_dim=sd["blocks.0.audio_cross_attn.proj.weight"].shape[0], + intermediate_dim=sd["audio_proj.proj1.weight"].shape[0], + out_dim=sd["audio_proj.norm.weight"].shape[0], + device=comfy.model_management.unet_offload_device(), + operations=comfy.ops.manual_cast) model.load_state_dict(sd) model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device()) @@ -524,6 +533,38 @@ class USOStyleReference: return (model_patched,) +class MultiTalkModelPatch(torch.nn.Module): + def __init__( + self, + audio_window: int = 5, + intermediate_dim: int = 512, + in_dim: int = 5120, + out_dim: int = 768, + context_tokens: int = 32, + vae_scale: int = 4, + num_layers: int = 40, + + device=None, dtype=None, operations=None + ): + super().__init__() + self.audio_proj = MultiTalkAudioProjModel( + seq_len=audio_window, + seq_len_vf=audio_window+vae_scale-1, + intermediate_dim=intermediate_dim, + out_dim=out_dim, + context_tokens=context_tokens, + device=device, + dtype=dtype, + operations=operations + ) + self.blocks = torch.nn.ModuleList( + [ + WanMultiTalkAttentionBlock(in_dim, out_dim, device=device, dtype=dtype, operations=operations) + for _ in range(num_layers) + ] + ) + + NODE_CLASS_MAPPINGS = { "ModelPatchLoader": ModelPatchLoader, "QwenImageDiffsynthControlnet": QwenImageDiffsynthControlnet, diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index d32aad98e..90deb0077 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -8,9 +8,10 @@ import comfy.latent_formats import comfy.clip_vision import json import numpy as np -from typing import Tuple +from typing import Tuple, TypedDict from typing_extensions import override from comfy_api.latest import ComfyExtension, io +import logging class WanImageToVideo(io.ComfyNode): @classmethod @@ -1288,6 +1289,171 @@ class Wan22ImageToVideoLatent(io.ComfyNode): return io.NodeOutput(out_latent) +from comfy.ldm.wan.model_multitalk import InfiniteTalkOuterSampleWrapper, MultiTalkCrossAttnPatch, MultiTalkGetAttnMapPatch, project_audio_features +class WanInfiniteTalkToVideo(io.ComfyNode): + class DCValues(TypedDict): + mode: str + audio_encoder_output_2: io.AudioEncoderOutput.Type + mask: io.Mask.Type + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="WanInfiniteTalkToVideo", + category="conditioning/video_models", + inputs=[ + io.DynamicCombo.Input("mode", options=[ + io.DynamicCombo.Option("single_speaker", []), + io.DynamicCombo.Option("two_speakers", [ + io.AudioEncoderOutput.Input("audio_encoder_output_2", optional=True), + io.Mask.Input("mask_1", optional=True, tooltip="Mask for the first speaker, required if using two audio inputs."), + io.Mask.Input("mask_2", optional=True, tooltip="Mask for the second speaker, required if using two audio inputs."), + ]), + ]), + io.Model.Input("model"), + io.ModelPatch.Input("model_patch"), + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.ClipVisionOutput.Input("clip_vision_output", optional=True), + io.Image.Input("start_image", optional=True), + io.AudioEncoderOutput.Input("audio_encoder_output_1"), + io.Int.Input("motion_frame_count", default=9, min=1, max=33, step=1, tooltip="Number of previous frames to use as motion context."), + io.Float.Input("audio_scale", default=1.0, min=-10.0, max=10.0, step=0.01), + io.Image.Input("previous_frames", optional=True), + ], + outputs=[ + io.Model.Output(display_name="model"), + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + io.Int.Output(display_name="trim_image"), + ], + ) + + @classmethod + def execute(cls, mode: DCValues, model, model_patch, positive, negative, vae, width, height, length, audio_encoder_output_1, motion_frame_count, + start_image=None, previous_frames=None, audio_scale=None, clip_vision_output=None, audio_encoder_output_2=None, mask_1=None, mask_2=None) -> io.NodeOutput: + + if previous_frames is not None and previous_frames.shape[0] < motion_frame_count: + raise ValueError("Not enough previous frames provided.") + + if mode["mode"] == "two_speakers": + audio_encoder_output_2 = mode["audio_encoder_output_2"] + mask_1 = mode["mask_1"] + mask_2 = mode["mask_2"] + + if audio_encoder_output_2 is not None: + if mask_1 is None or mask_2 is None: + raise ValueError("Masks must be provided if two audio encoder outputs are used.") + + ref_masks = None + if mask_1 is not None and mask_2 is not None: + if audio_encoder_output_2 is None: + raise ValueError("Second audio encoder output must be provided if two masks are used.") + ref_masks = torch.cat([mask_1, mask_2]) + + latent = torch.zeros([1, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + if start_image is not None: + start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + image = torch.ones((length, height, width, start_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) * 0.5 + image[:start_image.shape[0]] = start_image + + concat_latent_image = vae.encode(image[:, :, :, :3]) + concat_mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) + concat_mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0 + + positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": concat_mask}) + negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": concat_mask}) + + if clip_vision_output is not None: + positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output}) + negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output}) + + model_patched = model.clone() + + encoded_audio_list = [] + seq_lengths = [] + + for audio_encoder_output in [audio_encoder_output_1, audio_encoder_output_2]: + if audio_encoder_output is None: + continue + all_layers = audio_encoder_output["encoded_audio_all_layers"] + encoded_audio = torch.stack(all_layers, dim=0).squeeze(1)[1:] # shape: [num_layers, T, 512] + encoded_audio = linear_interpolation(encoded_audio, input_fps=50, output_fps=25).movedim(0, 1) # shape: [T, num_layers, 512] + encoded_audio_list.append(encoded_audio) + seq_lengths.append(encoded_audio.shape[0]) + + # Pad / combine depending on multi_audio_type + multi_audio_type = "add" + if len(encoded_audio_list) > 1: + if multi_audio_type == "para": + max_len = max(seq_lengths) + padded = [] + for emb in encoded_audio_list: + if emb.shape[0] < max_len: + pad = torch.zeros(max_len - emb.shape[0], *emb.shape[1:], dtype=emb.dtype) + emb = torch.cat([emb, pad], dim=0) + padded.append(emb) + encoded_audio_list = padded + elif multi_audio_type == "add": + total_len = sum(seq_lengths) + full_list = [] + offset = 0 + for emb, seq_len in zip(encoded_audio_list, seq_lengths): + full = torch.zeros(total_len, *emb.shape[1:], dtype=emb.dtype) + full[offset:offset+seq_len] = emb + full_list.append(full) + offset += seq_len + encoded_audio_list = full_list + + token_ref_target_masks = None + if ref_masks is not None: + token_ref_target_masks = torch.nn.functional.interpolate( + ref_masks.unsqueeze(0), size=(latent.shape[-2] // 2, latent.shape[-1] // 2), mode='nearest')[0] + token_ref_target_masks = (token_ref_target_masks > 0).view(token_ref_target_masks.shape[0], -1) + + # when extending from previous frames + if previous_frames is not None: + motion_frames = comfy.utils.common_upscale(previous_frames[-motion_frame_count:].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + frame_offset = previous_frames.shape[0] - motion_frame_count + + audio_start = frame_offset + audio_end = audio_start + length + logging.info(f"InfiniteTalk: Processing audio frames {audio_start} - {audio_end}") + + motion_frames_latent = vae.encode(motion_frames[:, :, :, :3]) + trim_image = motion_frame_count + else: + audio_start = trim_image = 0 + audio_end = length + motion_frames_latent = concat_latent_image[:, :, :1] + + audio_embed = project_audio_features(model_patch.model.audio_proj, encoded_audio_list, audio_start, audio_end).to(model_patched.model_dtype()) + model_patched.model_options["transformer_options"]["audio_embeds"] = audio_embed + + # add outer sample wrapper + model_patched.add_wrapper_with_key( + comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, + "infinite_talk_outer_sample", + InfiniteTalkOuterSampleWrapper( + motion_frames_latent, + model_patch, + is_extend=previous_frames is not None, + )) + # add cross-attention patch + model_patched.set_model_patch(MultiTalkCrossAttnPatch(model_patch, audio_scale), "attn2_patch") + if token_ref_target_masks is not None: + model_patched.set_model_patch(MultiTalkGetAttnMapPatch(token_ref_target_masks), "attn1_patch") + + out_latent = {} + out_latent["samples"] = latent + return io.NodeOutput(model_patched, positive, negative, out_latent, trim_image) + + class WanExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: @@ -1307,6 +1473,7 @@ class WanExtension(ComfyExtension): WanHuMoImageToVideo, WanAnimateToVideo, Wan22ImageToVideoLatent, + WanInfiniteTalkToVideo, ] async def comfy_entrypoint() -> WanExtension: From 72f6be1690868af852a624084a949b785fc056ea Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Thu, 22 Jan 2026 09:42:04 +0200 Subject: [PATCH 35/58] chore(api-nodes): rename BriaImage and OpenAIGImage nodes (#12022) --- comfy_api_nodes/nodes_bria.py | 2 +- comfy_api_nodes/nodes_openai.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/comfy_api_nodes/nodes_bria.py b/comfy_api_nodes/nodes_bria.py index 72a3055a7..d3a52bc1b 100644 --- a/comfy_api_nodes/nodes_bria.py +++ b/comfy_api_nodes/nodes_bria.py @@ -24,7 +24,7 @@ class BriaImageEditNode(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="BriaImageEditNode", - display_name="Bria Image Edit", + display_name="Bria FIBO Image Edit", category="api node/image/Bria", description="Edit images using Bria latest model", inputs=[ diff --git a/comfy_api_nodes/nodes_openai.py b/comfy_api_nodes/nodes_openai.py index a12acc06b..f05aaab7b 100644 --- a/comfy_api_nodes/nodes_openai.py +++ b/comfy_api_nodes/nodes_openai.py @@ -364,9 +364,9 @@ class OpenAIGPTImage1(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="OpenAIGPTImage1", - display_name="OpenAI GPT Image 1", + display_name="OpenAI GPT Image 1.5", category="api node/image/OpenAI", - description="Generates images synchronously via OpenAI's GPT Image 1 endpoint.", + description="Generates images synchronously via OpenAI's GPT Image endpoint.", inputs=[ IO.String.Input( "prompt", @@ -429,6 +429,7 @@ class OpenAIGPTImage1(IO.ComfyNode): IO.Combo.Input( "model", options=["gpt-image-1", "gpt-image-1.5"], + default="gpt-image-1.5", optional=True, ), ], From 8490eedadfc0ab00cb131bab681059163c2ebbcd Mon Sep 17 00:00:00 2001 From: Terry Jia Date: Thu, 22 Jan 2026 12:46:56 -0500 Subject: [PATCH 36/58] add ply & 3dgs format in 3d node (#11474) --- comfy_extras/nodes_load_3d.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_load_3d.py b/comfy_extras/nodes_load_3d.py index 545588ef8..a16b8c8f3 100644 --- a/comfy_extras/nodes_load_3d.py +++ b/comfy_extras/nodes_load_3d.py @@ -24,7 +24,7 @@ class Load3D(IO.ComfyNode): files = [ normalize_path(str(file_path.relative_to(base_path))) for file_path in input_path.rglob("*") - if file_path.suffix.lower() in {'.gltf', '.glb', '.obj', '.fbx', '.stl'} + if file_path.suffix.lower() in {'.gltf', '.glb', '.obj', '.fbx', '.stl', '.spz', '.splat', '.ply', '.ksplat'} ] return IO.Schema( node_id="Load3D", From 0fd1b787360a70dd37aa14089ccb5fc1820f9e17 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Thu, 22 Jan 2026 13:54:18 -0800 Subject: [PATCH 37/58] Reduce LTX2 VAE VRAM consumption (#12028) * causal_video_ae: Remove attention ResNet This attention_head_dim argument does not exist on this constructor so this is dead code. Remove as generic attention mid VAE conflicts with temporal roll. * ltx-vae: consoldate causal/non-causal code paths * ltx-vae: add cache rolling adder * ltx-vae: use cached adder for resnet * ltx-vae: Implement rolling VAE Implement a temporal rolling VAE for the LTX2 VAE. Usually when doing temporal rolling VAEs you can just chunk on time relying on causality and cache behind you as you go. The LTX VAE is however non-causal. So go whole hog and implement per layer run ahead and backpressure between the decoder layers using recursive state beween the layers. Operations are ammended with temporal_cache_state{} which they can use to hold any state then need for partial execution. Convolutions cache their inputs behind the up to N-1 frames, and skip connections need to cache the mismatch between convolution input and output that happens due to missing future (non-causal) input. Each call to run_up() processes a layer accross a range on input that may or may not be complete. It goes depth first to process as much as possible to try and digest frames to the final output ASAP. If layers run out of input due to convolution losses, they simply return without action effectively applying back-pressure to the earlier layers. As the earlier layers do more work and caller deeper, the partial states are reconciled and output continues to digest depth first as much as possible. Chunking is done using a size quota rather than a fixed frame length and any layer can initiate chunking, and multiple layers can chunk at different granulatiries. This remove the old limitation of always having to process 1 latent frame to entirety and having to hold 8 full decoded frames as the VRAM peak. --- comfy/ldm/lightricks/vae/causal_conv3d.py | 43 +++-- .../vae/causal_video_autoencoder.py | 176 +++++++++++++----- comfy/ldm/modules/diffusionmodules/model.py | 5 +- 3 files changed, 160 insertions(+), 64 deletions(-) diff --git a/comfy/ldm/lightricks/vae/causal_conv3d.py b/comfy/ldm/lightricks/vae/causal_conv3d.py index 70d612e86..b8341edbc 100644 --- a/comfy/ldm/lightricks/vae/causal_conv3d.py +++ b/comfy/ldm/lightricks/vae/causal_conv3d.py @@ -1,11 +1,11 @@ from typing import Tuple, Union +import threading import torch import torch.nn as nn import comfy.ops ops = comfy.ops.disable_weight_init - class CausalConv3d(nn.Module): def __init__( self, @@ -42,23 +42,34 @@ class CausalConv3d(nn.Module): padding_mode=spatial_padding_mode, groups=groups, ) + self.temporal_cache_state={} def forward(self, x, causal: bool = True): - if causal: - first_frame_pad = x[:, :, :1, :, :].repeat( - (1, 1, self.time_kernel_size - 1, 1, 1) - ) - x = torch.concatenate((first_frame_pad, x), dim=2) - else: - first_frame_pad = x[:, :, :1, :, :].repeat( - (1, 1, (self.time_kernel_size - 1) // 2, 1, 1) - ) - last_frame_pad = x[:, :, -1:, :, :].repeat( - (1, 1, (self.time_kernel_size - 1) // 2, 1, 1) - ) - x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2) - x = self.conv(x) - return x + tid = threading.get_ident() + + cached, is_end = self.temporal_cache_state.get(tid, (None, False)) + if cached is None: + padding_length = self.time_kernel_size - 1 + if not causal: + padding_length = padding_length // 2 + if x.shape[2] == 0: + return x + cached = x[:, :, :1, :, :].repeat((1, 1, padding_length, 1, 1)) + pieces = [ cached, x ] + if is_end and not causal: + pieces.append(x[:, :, -1:, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1))) + + needs_caching = not is_end + if needs_caching and x.shape[2] >= self.time_kernel_size - 1: + needs_caching = False + self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False) + + x = torch.cat(pieces, dim=2) + + if needs_caching: + self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False) + + return self.conv(x) if x.shape[2] >= self.time_kernel_size else x[:, :, :0, :, :] @property def weight(self): diff --git a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py index 75ed069ad..cbfdf412d 100644 --- a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py +++ b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py @@ -1,4 +1,5 @@ from __future__ import annotations +import threading import torch from torch import nn from functools import partial @@ -6,12 +7,35 @@ import math from einops import rearrange from typing import List, Optional, Tuple, Union from .conv_nd_factory import make_conv_nd, make_linear_nd +from .causal_conv3d import CausalConv3d from .pixel_norm import PixelNorm from ..model import PixArtAlphaCombinedTimestepSizeEmbeddings import comfy.ops +from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed ops = comfy.ops.disable_weight_init +def mark_conv3d_ended(module): + tid = threading.get_ident() + for _, m in module.named_modules(): + if isinstance(m, CausalConv3d): + current = m.temporal_cache_state.get(tid, (None, False)) + m.temporal_cache_state[tid] = (current[0], True) + +def split2(tensor, split_point, dim=2): + return torch.split(tensor, [split_point, tensor.shape[dim] - split_point], dim=dim) + +def add_exchange_cache(dest, cache_in, new_input, dim=2): + if dest is not None: + if cache_in is not None: + cache_to_dest = min(dest.shape[dim], cache_in.shape[dim]) + lead_in_dest, dest = split2(dest, cache_to_dest, dim=dim) + lead_in_source, cache_in = split2(cache_in, cache_to_dest, dim=dim) + lead_in_dest.add_(lead_in_source) + body, new_input = split2(new_input, dest.shape[dim], dim) + dest.add_(body) + return torch_cat_if_needed([cache_in, new_input], dim=dim) + class Encoder(nn.Module): r""" The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation. @@ -205,7 +229,7 @@ class Encoder(nn.Module): self.gradient_checkpointing = False - def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor: + def forward_orig(self, sample: torch.FloatTensor) -> torch.FloatTensor: r"""The forward method of the `Encoder` class.""" sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) @@ -254,6 +278,22 @@ class Encoder(nn.Module): return sample + def forward(self, *args, **kwargs): + #No encoder support so just flag the end so it doesnt use the cache. + mark_conv3d_ended(self) + try: + return self.forward_orig(*args, **kwargs) + finally: + tid = threading.get_ident() + for _, module in self.named_modules(): + # ComfyUI doesn't thread this kind of stuff today, but just in case + # we key on the thread to make it thread safe. + tid = threading.get_ident() + if hasattr(module, "temporal_cache_state"): + module.temporal_cache_state.pop(tid, None) + + +MAX_CHUNK_SIZE=(128 * 1024 ** 2) class Decoder(nn.Module): r""" @@ -341,18 +381,6 @@ class Decoder(nn.Module): timestep_conditioning=timestep_conditioning, spatial_padding_mode=spatial_padding_mode, ) - elif block_name == "attn_res_x": - block = UNetMidBlock3D( - dims=dims, - in_channels=input_channel, - num_layers=block_params["num_layers"], - resnet_groups=norm_num_groups, - norm_layer=norm_layer, - inject_noise=block_params.get("inject_noise", False), - timestep_conditioning=timestep_conditioning, - attention_head_dim=block_params["attention_head_dim"], - spatial_padding_mode=spatial_padding_mode, - ) elif block_name == "res_x_y": output_channel = output_channel // block_params.get("multiplier", 2) block = ResnetBlock3D( @@ -428,8 +456,9 @@ class Decoder(nn.Module): ) self.last_scale_shift_table = nn.Parameter(torch.empty(2, output_channel)) + # def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor: - def forward( + def forward_orig( self, sample: torch.FloatTensor, timestep: Optional[torch.Tensor] = None, @@ -437,6 +466,7 @@ class Decoder(nn.Module): r"""The forward method of the `Decoder` class.""" batch_size = sample.shape[0] + mark_conv3d_ended(self.conv_in) sample = self.conv_in(sample, causal=self.causal) checkpoint_fn = ( @@ -445,24 +475,12 @@ class Decoder(nn.Module): else lambda x: x ) - scaled_timestep = None + timestep_shift_scale = None if self.timestep_conditioning: assert ( timestep is not None ), "should pass timestep with timestep_conditioning=True" scaled_timestep = timestep * self.timestep_scale_multiplier.to(dtype=sample.dtype, device=sample.device) - - for up_block in self.up_blocks: - if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D): - sample = checkpoint_fn(up_block)( - sample, causal=self.causal, timestep=scaled_timestep - ) - else: - sample = checkpoint_fn(up_block)(sample, causal=self.causal) - - sample = self.conv_norm_out(sample) - - if self.timestep_conditioning: embedded_timestep = self.last_time_embedder( timestep=scaled_timestep.flatten(), resolution=None, @@ -483,16 +501,62 @@ class Decoder(nn.Module): embedded_timestep.shape[-2], embedded_timestep.shape[-1], ) - shift, scale = ada_values.unbind(dim=1) - sample = sample * (1 + scale) + shift + timestep_shift_scale = ada_values.unbind(dim=1) - sample = self.conv_act(sample) - sample = self.conv_out(sample, causal=self.causal) + output = [] + + def run_up(idx, sample, ended): + if idx >= len(self.up_blocks): + sample = self.conv_norm_out(sample) + if timestep_shift_scale is not None: + shift, scale = timestep_shift_scale + sample = sample * (1 + scale) + shift + sample = self.conv_act(sample) + if ended: + mark_conv3d_ended(self.conv_out) + sample = self.conv_out(sample, causal=self.causal) + if sample is not None and sample.shape[2] > 0: + output.append(sample) + return + + up_block = self.up_blocks[idx] + if (ended): + mark_conv3d_ended(up_block) + if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D): + sample = checkpoint_fn(up_block)( + sample, causal=self.causal, timestep=scaled_timestep + ) + else: + sample = checkpoint_fn(up_block)(sample, causal=self.causal) + + if sample is None or sample.shape[2] == 0: + return + + total_bytes = sample.numel() * sample.element_size() + num_chunks = (total_bytes + MAX_CHUNK_SIZE - 1) // MAX_CHUNK_SIZE + samples = torch.chunk(sample, chunks=num_chunks, dim=2) + + for chunk_idx, sample1 in enumerate(samples): + run_up(idx + 1, sample1, ended and chunk_idx == len(samples) - 1) + + run_up(0, sample, True) + sample = torch.cat(output, dim=2) sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) return sample + def forward(self, *args, **kwargs): + try: + return self.forward_orig(*args, **kwargs) + finally: + for _, module in self.named_modules(): + #ComfyUI doesn't thread this kind of stuff today, but just incase + #we key on the thread to make it thread safe. + tid = threading.get_ident() + if hasattr(module, "temporal_cache_state"): + module.temporal_cache_state.pop(tid, None) + class UNetMidBlock3D(nn.Module): """ @@ -663,8 +727,22 @@ class DepthToSpaceUpsample(nn.Module): ) self.residual = residual self.out_channels_reduction_factor = out_channels_reduction_factor + self.temporal_cache_state = {} def forward(self, x, causal: bool = True, timestep: Optional[torch.Tensor] = None): + tid = threading.get_ident() + cached, drop_first_conv, drop_first_res = self.temporal_cache_state.get(tid, (None, True, True)) + y = self.conv(x, causal=causal) + y = rearrange( + y, + "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)", + p1=self.stride[0], + p2=self.stride[1], + p3=self.stride[2], + ) + if self.stride[0] == 2 and y.shape[2] > 0 and drop_first_conv: + y = y[:, :, 1:, :, :] + drop_first_conv = False if self.residual: # Reshape and duplicate the input to match the output shape x_in = rearrange( @@ -676,21 +754,20 @@ class DepthToSpaceUpsample(nn.Module): ) num_repeat = math.prod(self.stride) // self.out_channels_reduction_factor x_in = x_in.repeat(1, num_repeat, 1, 1, 1) - if self.stride[0] == 2: + if self.stride[0] == 2 and x_in.shape[2] > 0 and drop_first_res: x_in = x_in[:, :, 1:, :, :] - x = self.conv(x, causal=causal) - x = rearrange( - x, - "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)", - p1=self.stride[0], - p2=self.stride[1], - p3=self.stride[2], - ) - if self.stride[0] == 2: - x = x[:, :, 1:, :, :] - if self.residual: - x = x + x_in - return x + drop_first_res = False + + if y.shape[2] == 0: + y = None + + cached = add_exchange_cache(y, cached, x_in, dim=2) + self.temporal_cache_state[tid] = (cached, drop_first_conv, drop_first_res) + + else: + self.temporal_cache_state[tid] = (None, drop_first_conv, False) + + return y class LayerNorm(nn.Module): def __init__(self, dim, eps, elementwise_affine=True) -> None: @@ -807,6 +884,8 @@ class ResnetBlock3D(nn.Module): torch.randn(4, in_channels) / in_channels**0.5 ) + self.temporal_cache_state={} + def _feed_spatial_noise( self, hidden_states: torch.FloatTensor, per_channel_scale: torch.FloatTensor ) -> torch.FloatTensor: @@ -880,9 +959,12 @@ class ResnetBlock3D(nn.Module): input_tensor = self.conv_shortcut(input_tensor) - output_tensor = input_tensor + hidden_states + tid = threading.get_ident() + cached = self.temporal_cache_state.get(tid, None) + cached = add_exchange_cache(hidden_states, cached, input_tensor, dim=2) + self.temporal_cache_state[tid] = cached - return output_tensor + return hidden_states def patchify(x, patch_size_hw, patch_size_t=1): diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index 1ae3ef034..5a22ef030 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -14,10 +14,13 @@ if model_management.xformers_enabled_vae(): import xformers.ops def torch_cat_if_needed(xl, dim): + xl = [x for x in xl if x is not None and x.shape[dim] > 0] if len(xl) > 1: return torch.cat(xl, dim) - else: + elif len(xl) == 1: return xl[0] + else: + return None def get_timestep_embedding(timesteps, embedding_dim): """ From 09a2e67151c6753a0038f6e01f3c3d93fcc3ec98 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 22 Jan 2026 15:20:48 -0800 Subject: [PATCH 38/58] Support loading flux 2 klein checkpoints saved with SaveCheckpoint. (#12033) --- comfy/supported_models.py | 20 +++++++++++++++++--- comfy/text_encoders/hunyuan_video.py | 8 +++++--- 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 70abebf46..45d913fa6 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -771,10 +771,24 @@ class Flux2(Flux): return out def clip_target(self, state_dict={}): - return None # TODO pref = self.text_encoder_key_prefix[0] - t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref)) - return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(**t5_detect)) + detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_4b.transformer.".format(pref)) + if len(detect) > 0: + detect["model_type"] = "qwen3_4b" + return supported_models_base.ClipTarget(comfy.text_encoders.flux.KleinTokenizer, comfy.text_encoders.flux.klein_te(**detect)) + + detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_8b.transformer.".format(pref)) + if len(detect) > 0: + detect["model_type"] = "qwen3_8b" + return supported_models_base.ClipTarget(comfy.text_encoders.flux.KleinTokenizer8B, comfy.text_encoders.flux.klein_te(**detect)) + + detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}mistral3_24b.transformer.".format(pref)) + if len(detect) > 0: + if "{}mistral3_24b.transformer.model.layers.39.post_attention_layernorm.weight".format(pref) not in state_dict: + detect["pruned"] = True + return supported_models_base.ClipTarget(comfy.text_encoders.flux.Flux2Tokenizer, comfy.text_encoders.flux.flux2_te(**detect)) + + return None class GenmoMochi(supported_models_base.BASE): unet_config = { diff --git a/comfy/text_encoders/hunyuan_video.py b/comfy/text_encoders/hunyuan_video.py index a9a6c525e..2ddb4da60 100644 --- a/comfy/text_encoders/hunyuan_video.py +++ b/comfy/text_encoders/hunyuan_video.py @@ -10,9 +10,11 @@ import comfy.utils def llama_detect(state_dict, prefix=""): out = {} - t5_key = "{}model.norm.weight".format(prefix) - if t5_key in state_dict: - out["dtype_llama"] = state_dict[t5_key].dtype + norm_keys = ["{}model.norm.weight".format(prefix), "{}model.layers.0.input_layernorm.weight".format(prefix)] + for norm_key in norm_keys: + if norm_key in state_dict: + out["dtype_llama"] = state_dict[norm_key].dtype + break quant = comfy.utils.detect_layer_quantization(state_dict, prefix) if quant is not None: From d7f3241bf6b11f67ada34c51097fbaad0c01124a Mon Sep 17 00:00:00 2001 From: Omri Marom <110098005+maromri@users.noreply.github.com> Date: Fri, 23 Jan 2026 03:02:31 +0200 Subject: [PATCH 39/58] qwen_image: propagate attention mask. (#11966) --- comfy/ldm/qwen_image/model.py | 11 ++++++++++- comfy/model_base.py | 3 +++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index 00c597535..6eb744286 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -170,8 +170,14 @@ class Attention(nn.Module): joint_query = apply_rope1(joint_query, image_rotary_emb) joint_key = apply_rope1(joint_key, image_rotary_emb) + if encoder_hidden_states_mask is not None: + attn_mask = torch.zeros((batch_size, 1, seq_txt + seq_img), dtype=hidden_states.dtype, device=hidden_states.device) + attn_mask[:, 0, :seq_txt] = encoder_hidden_states_mask + else: + attn_mask = None + joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, - attention_mask, transformer_options=transformer_options, + attn_mask, transformer_options=transformer_options, skip_reshape=True) txt_attn_output = joint_hidden_states[:, :seq_txt, :] @@ -430,6 +436,9 @@ class QwenImageTransformer2DModel(nn.Module): encoder_hidden_states = context encoder_hidden_states_mask = attention_mask + if encoder_hidden_states_mask is not None and not torch.is_floating_point(encoder_hidden_states_mask): + encoder_hidden_states_mask = (encoder_hidden_states_mask - 1).to(x.dtype) * torch.finfo(x.dtype).max + hidden_states, img_ids, orig_shape = self.process_img(x) num_embeds = hidden_states.shape[1] diff --git a/comfy/model_base.py b/comfy/model_base.py index 1d57562cc..66e52864d 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1578,6 +1578,9 @@ class QwenImage(BaseModel): def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) + attention_mask = kwargs.get("attention_mask", None) + if attention_mask is not None: + out['attention_mask'] = comfy.conds.CONDRegular(attention_mask) cross_attn = kwargs.get("cross_attn", None) if cross_attn is not None: out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) From bbb8864778a93eb0fa60c76201383e2b5a63aa38 Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Thu, 22 Jan 2026 18:36:58 -0800 Subject: [PATCH 40/58] add search aliases to all nodes (#12035) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: Add search_aliases field to node schema Adds `search_aliases` field to improve node discoverability. Users can define alternative search terms for nodes (e.g., "text concat" → StringConcatenate). Changes: - Add `search_aliases: list[str]` to V3 Schema - Add `SEARCH_ALIASES` support for V1 nodes - Include field in `/object_info` response - Add aliases to high-priority core nodes V1 usage: ```python class MyNode: SEARCH_ALIASES = ["alt name", "synonym"] ``` V3 usage: ```python io.Schema( node_id="MyNode", search_aliases=["alt name", "synonym"], ... ) ``` ## Related PRs - Frontend: Comfy-Org/ComfyUI_frontend#XXXX (draft - merge after this) - Docs: Comfy-Org/docs#XXXX (draft - merge after stable) * Propagate search_aliases through V3 Schema.get_v1_info to NodeInfoV1 * feat: add SEARCH_ALIASES for core nodes (#12016) Add search aliases to 22 core nodes in nodes.py to improve node discoverability: - Checkpoint/model loaders: CheckpointLoader, DiffusersLoader - Conditioning nodes: ConditioningAverage, ConditioningSetArea, ConditioningSetMask, ConditioningZeroOut - Style nodes: StyleModelApply - Image nodes: LoadImageMask, LoadImageOutput, ImageBatch, ImageInvert, ImagePadForOutpaint - Latent nodes: LoadLatent, SaveLatent, LatentBlend, LatentComposite, LatentCrop, LatentFlip, LatentFromBatch, LatentUpscale, LatentUpscaleBy, RepeatLatentBatch * feat: add SEARCH_ALIASES for image, mask, and string nodes (#12017) Add search aliases to nodes in comfy_extras for better discoverability: - nodes_mask.py: mask manipulation nodes - nodes_images.py: image processing nodes - nodes_post_processing.py: post-processing effect nodes - nodes_string.py: string manipulation nodes - nodes_compositing.py: compositing nodes - nodes_morphology.py: morphological operation nodes - nodes_latent.py: latent space nodes Uses search_aliases parameter in io.Schema() for v3 nodes. * feat: add SEARCH_ALIASES for audio and video nodes (#12018) Add search aliases to audio and video nodes for better discoverability: - nodes_audio.py: audio loading, saving, and processing nodes - nodes_video.py: video loading and processing nodes - nodes_wan.py: WAN model nodes Uses search_aliases parameter in io.Schema() for v3 nodes. * feat: add SEARCH_ALIASES for model and misc nodes (#12019) Add search aliases to model-related and miscellaneous nodes: - Model nodes: nodes_model_merging.py, nodes_model_advanced.py, nodes_lora_extract.py - Sampler nodes: nodes_custom_sampler.py, nodes_align_your_steps.py - Control nodes: nodes_controlnet.py, nodes_attention_multiply.py, nodes_hooks.py - Training nodes: nodes_train.py, nodes_dataset.py - Utility nodes: nodes_logic.py, nodes_canny.py, nodes_differential_diffusion.py - Architecture-specific: nodes_sd3.py, nodes_pixart.py, nodes_lumina2.py, nodes_kandinsky5.py, nodes_hidream.py, nodes_fresca.py, nodes_hunyuan3d.py - Media nodes: nodes_load_3d.py, nodes_webcam.py, nodes_preview_any.py, nodes_wanmove.py Uses search_aliases parameter in io.Schema() for v3 nodes, SEARCH_ALIASES class attribute for legacy nodes. --- comfy_extras/nodes_align_your_steps.py | 1 + comfy_extras/nodes_attention_multiply.py | 1 + comfy_extras/nodes_audio.py | 14 +++++++ comfy_extras/nodes_canny.py | 1 + comfy_extras/nodes_compositing.py | 3 ++ comfy_extras/nodes_controlnet.py | 1 + comfy_extras/nodes_custom_sampler.py | 4 ++ comfy_extras/nodes_dataset.py | 6 +-- comfy_extras/nodes_differential_diffusion.py | 1 + comfy_extras/nodes_fresca.py | 1 + comfy_extras/nodes_hidream.py | 1 + comfy_extras/nodes_hooks.py | 5 +++ comfy_extras/nodes_hunyuan3d.py | 1 + comfy_extras/nodes_images.py | 16 +++++--- comfy_extras/nodes_kandinsky5.py | 1 + comfy_extras/nodes_latent.py | 10 +++++ comfy_extras/nodes_load_3d.py | 1 + comfy_extras/nodes_logic.py | 2 + comfy_extras/nodes_lora_extract.py | 1 + comfy_extras/nodes_lumina2.py | 1 + comfy_extras/nodes_mask.py | 12 ++++++ comfy_extras/nodes_model_advanced.py | 1 + comfy_extras/nodes_model_merging.py | 4 ++ comfy_extras/nodes_morphology.py | 3 ++ comfy_extras/nodes_pixart.py | 1 + comfy_extras/nodes_post_processing.py | 5 ++- comfy_extras/nodes_preview_any.py | 2 +- comfy_extras/nodes_sd3.py | 1 + comfy_extras/nodes_string.py | 10 +++++ comfy_extras/nodes_train.py | 2 + comfy_extras/nodes_video.py | 5 +++ comfy_extras/nodes_wan.py | 2 + comfy_extras/nodes_wanmove.py | 1 + comfy_extras/nodes_webcam.py | 1 + nodes.py | 41 ++++++++++++++++++++ 35 files changed, 152 insertions(+), 11 deletions(-) diff --git a/comfy_extras/nodes_align_your_steps.py b/comfy_extras/nodes_align_your_steps.py index edd5dadd4..4fc511d2c 100644 --- a/comfy_extras/nodes_align_your_steps.py +++ b/comfy_extras/nodes_align_your_steps.py @@ -28,6 +28,7 @@ class AlignYourStepsScheduler(io.ComfyNode): def define_schema(cls) -> io.Schema: return io.Schema( node_id="AlignYourStepsScheduler", + search_aliases=["AYS scheduler"], category="sampling/custom_sampling/schedulers", inputs=[ io.Combo.Input("model_type", options=["SD1", "SDXL", "SVD"]), diff --git a/comfy_extras/nodes_attention_multiply.py b/comfy_extras/nodes_attention_multiply.py index c0e494c2a..67c4e2ed0 100644 --- a/comfy_extras/nodes_attention_multiply.py +++ b/comfy_extras/nodes_attention_multiply.py @@ -71,6 +71,7 @@ class CLIPAttentionMultiply(io.ComfyNode): def define_schema(cls) -> io.Schema: return io.Schema( node_id="CLIPAttentionMultiply", + search_aliases=["clip attention scale", "text encoder attention"], category="_for_testing/attention_experiments", inputs=[ io.Clip.Input("clip"), diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py index 15b3aa401..271b75fbd 100644 --- a/comfy_extras/nodes_audio.py +++ b/comfy_extras/nodes_audio.py @@ -69,6 +69,7 @@ class VAEEncodeAudio(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="VAEEncodeAudio", + search_aliases=["audio to latent"], display_name="VAE Encode Audio", category="latent/audio", inputs=[ @@ -97,6 +98,7 @@ class VAEDecodeAudio(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="VAEDecodeAudio", + search_aliases=["latent to audio"], display_name="VAE Decode Audio", category="latent/audio", inputs=[ @@ -122,6 +124,7 @@ class SaveAudio(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="SaveAudio", + search_aliases=["export flac"], display_name="Save Audio (FLAC)", category="audio", inputs=[ @@ -146,6 +149,7 @@ class SaveAudioMP3(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="SaveAudioMP3", + search_aliases=["export mp3"], display_name="Save Audio (MP3)", category="audio", inputs=[ @@ -173,6 +177,7 @@ class SaveAudioOpus(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="SaveAudioOpus", + search_aliases=["export opus"], display_name="Save Audio (Opus)", category="audio", inputs=[ @@ -200,6 +205,7 @@ class PreviewAudio(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="PreviewAudio", + search_aliases=["play audio"], display_name="Preview Audio", category="audio", inputs=[ @@ -259,6 +265,7 @@ class LoadAudio(IO.ComfyNode): files = folder_paths.filter_files_content_types(os.listdir(input_dir), ["audio", "video"]) return IO.Schema( node_id="LoadAudio", + search_aliases=["import audio", "open audio", "audio file"], display_name="Load Audio", category="audio", inputs=[ @@ -296,6 +303,7 @@ class RecordAudio(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="RecordAudio", + search_aliases=["microphone input", "audio capture", "voice input"], display_name="Record Audio", category="audio", inputs=[ @@ -320,6 +328,7 @@ class TrimAudioDuration(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="TrimAudioDuration", + search_aliases=["cut audio", "audio clip", "shorten audio"], display_name="Trim Audio Duration", description="Trim audio tensor into chosen time range.", category="audio", @@ -372,6 +381,7 @@ class SplitAudioChannels(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="SplitAudioChannels", + search_aliases=["stereo to mono"], display_name="Split Audio Channels", description="Separates the audio into left and right channels.", category="audio", @@ -472,6 +482,7 @@ class AudioConcat(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="AudioConcat", + search_aliases=["join audio", "combine audio", "append audio"], display_name="Audio Concat", description="Concatenates the audio1 to audio2 in the specified direction.", category="audio", @@ -519,6 +530,7 @@ class AudioMerge(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="AudioMerge", + search_aliases=["mix audio", "overlay audio", "layer audio"], display_name="Audio Merge", description="Combine two audio tracks by overlaying their waveforms.", category="audio", @@ -579,6 +591,7 @@ class AudioAdjustVolume(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="AudioAdjustVolume", + search_aliases=["audio gain", "loudness", "audio level"], display_name="Audio Adjust Volume", category="audio", inputs=[ @@ -614,6 +627,7 @@ class EmptyAudio(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="EmptyAudio", + search_aliases=["blank audio"], display_name="Empty Audio", category="audio", inputs=[ diff --git a/comfy_extras/nodes_canny.py b/comfy_extras/nodes_canny.py index 576f3640a..6e0fadca5 100644 --- a/comfy_extras/nodes_canny.py +++ b/comfy_extras/nodes_canny.py @@ -10,6 +10,7 @@ class Canny(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="Canny", + search_aliases=["edge detection", "outline", "contour detection", "line art"], category="image/preprocessors", inputs=[ io.Image.Input("image"), diff --git a/comfy_extras/nodes_compositing.py b/comfy_extras/nodes_compositing.py index e4e4e1cbc..3bc9fccb3 100644 --- a/comfy_extras/nodes_compositing.py +++ b/comfy_extras/nodes_compositing.py @@ -109,6 +109,7 @@ class PorterDuffImageComposite(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="PorterDuffImageComposite", + search_aliases=["alpha composite", "blend modes", "layer blend", "transparency blend"], display_name="Porter-Duff Image Composite", category="mask/compositing", inputs=[ @@ -165,6 +166,7 @@ class SplitImageWithAlpha(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="SplitImageWithAlpha", + search_aliases=["extract alpha", "separate transparency", "remove alpha"], display_name="Split Image with Alpha", category="mask/compositing", inputs=[ @@ -188,6 +190,7 @@ class JoinImageWithAlpha(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="JoinImageWithAlpha", + search_aliases=["add transparency", "apply alpha", "composite alpha", "RGBA"], display_name="Join Image with Alpha", category="mask/compositing", inputs=[ diff --git a/comfy_extras/nodes_controlnet.py b/comfy_extras/nodes_controlnet.py index e835feed7..0c1d7f0d4 100644 --- a/comfy_extras/nodes_controlnet.py +++ b/comfy_extras/nodes_controlnet.py @@ -38,6 +38,7 @@ class ControlNetInpaintingAliMamaApply(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="ControlNetInpaintingAliMamaApply", + search_aliases=["masked controlnet"], category="conditioning/controlnet", inputs=[ io.Conditioning.Input("positive"), diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index f19adf4b9..3eb40e937 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -297,6 +297,7 @@ class ExtendIntermediateSigmas(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="ExtendIntermediateSigmas", + search_aliases=["interpolate sigmas"], category="sampling/custom_sampling/sigmas", inputs=[ io.Sigmas.Input("sigmas"), @@ -856,6 +857,7 @@ class DualCFGGuider(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="DualCFGGuider", + search_aliases=["dual prompt guidance"], category="sampling/custom_sampling/guiders", inputs=[ io.Model.Input("model"), @@ -883,6 +885,7 @@ class DisableNoise(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="DisableNoise", + search_aliases=["zero noise"], category="sampling/custom_sampling/noise", inputs=[], outputs=[io.Noise.Output()] @@ -1019,6 +1022,7 @@ class ManualSigmas(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="ManualSigmas", + search_aliases=["custom noise schedule", "define sigmas"], category="_for_testing/custom_sampling", is_experimental=True, inputs=[ diff --git a/comfy_extras/nodes_dataset.py b/comfy_extras/nodes_dataset.py index 5ef851bd0..fb9409ac3 100644 --- a/comfy_extras/nodes_dataset.py +++ b/comfy_extras/nodes_dataset.py @@ -1223,11 +1223,11 @@ class ResolutionBucket(io.ComfyNode): class MakeTrainingDataset(io.ComfyNode): """Encode images with VAE and texts with CLIP to create a training dataset.""" - @classmethod def define_schema(cls): return io.Schema( node_id="MakeTrainingDataset", + search_aliases=["encode dataset"], display_name="Make Training Dataset", category="dataset", is_experimental=True, @@ -1309,11 +1309,11 @@ class MakeTrainingDataset(io.ComfyNode): class SaveTrainingDataset(io.ComfyNode): """Save encoded training dataset (latents + conditioning) to disk.""" - @classmethod def define_schema(cls): return io.Schema( node_id="SaveTrainingDataset", + search_aliases=["export training data"], display_name="Save Training Dataset", category="dataset", is_experimental=True, @@ -1410,11 +1410,11 @@ class SaveTrainingDataset(io.ComfyNode): class LoadTrainingDataset(io.ComfyNode): """Load encoded training dataset from disk.""" - @classmethod def define_schema(cls): return io.Schema( node_id="LoadTrainingDataset", + search_aliases=["import dataset", "training data"], display_name="Load Training Dataset", category="dataset", is_experimental=True, diff --git a/comfy_extras/nodes_differential_diffusion.py b/comfy_extras/nodes_differential_diffusion.py index 6dfdf466c..34ffb9a89 100644 --- a/comfy_extras/nodes_differential_diffusion.py +++ b/comfy_extras/nodes_differential_diffusion.py @@ -11,6 +11,7 @@ class DifferentialDiffusion(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="DifferentialDiffusion", + search_aliases=["inpaint gradient", "variable denoise strength"], display_name="Differential Diffusion", category="_for_testing", inputs=[ diff --git a/comfy_extras/nodes_fresca.py b/comfy_extras/nodes_fresca.py index f308eb0c1..3d590af4b 100644 --- a/comfy_extras/nodes_fresca.py +++ b/comfy_extras/nodes_fresca.py @@ -58,6 +58,7 @@ class FreSca(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="FreSca", + search_aliases=["frequency guidance"], display_name="FreSca", category="_for_testing", description="Applies frequency-dependent scaling to the guidance", diff --git a/comfy_extras/nodes_hidream.py b/comfy_extras/nodes_hidream.py index eee683ee1..e345fe51d 100644 --- a/comfy_extras/nodes_hidream.py +++ b/comfy_extras/nodes_hidream.py @@ -38,6 +38,7 @@ class CLIPTextEncodeHiDream(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="CLIPTextEncodeHiDream", + search_aliases=["hidream prompt"], category="advanced/conditioning", inputs=[ io.Clip.Input("clip"), diff --git a/comfy_extras/nodes_hooks.py b/comfy_extras/nodes_hooks.py index 1edc06f3d..58e511ef5 100644 --- a/comfy_extras/nodes_hooks.py +++ b/comfy_extras/nodes_hooks.py @@ -259,6 +259,7 @@ class SetClipHooks: return (clip,) class ConditioningTimestepsRange: + SEARCH_ALIASES = ["prompt scheduling", "timestep segments", "conditioning phases"] NodeId = 'ConditioningTimestepsRange' NodeName = 'Timesteps Range' @classmethod @@ -468,6 +469,7 @@ class SetHookKeyframes: return (hooks,) class CreateHookKeyframe: + SEARCH_ALIASES = ["hook scheduling", "strength animation", "timed hook"] NodeId = 'CreateHookKeyframe' NodeName = 'Create Hook Keyframe' @classmethod @@ -497,6 +499,7 @@ class CreateHookKeyframe: return (prev_hook_kf,) class CreateHookKeyframesInterpolated: + SEARCH_ALIASES = ["ease hook strength", "smooth hook transition", "interpolate keyframes"] NodeId = 'CreateHookKeyframesInterpolated' NodeName = 'Create Hook Keyframes Interp.' @classmethod @@ -544,6 +547,7 @@ class CreateHookKeyframesInterpolated: return (prev_hook_kf,) class CreateHookKeyframesFromFloats: + SEARCH_ALIASES = ["batch keyframes", "strength list to keyframes"] NodeId = 'CreateHookKeyframesFromFloats' NodeName = 'Create Hook Keyframes From Floats' @classmethod @@ -618,6 +622,7 @@ class SetModelHooksOnCond: # Combine Hooks #------------------------------------------ class CombineHooks: + SEARCH_ALIASES = ["merge hooks"] NodeId = 'CombineHooks2' NodeName = 'Combine Hooks [2]' @classmethod diff --git a/comfy_extras/nodes_hunyuan3d.py b/comfy_extras/nodes_hunyuan3d.py index adca14f62..5bb5df48e 100644 --- a/comfy_extras/nodes_hunyuan3d.py +++ b/comfy_extras/nodes_hunyuan3d.py @@ -618,6 +618,7 @@ class SaveGLB(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="SaveGLB", + search_aliases=["export 3d model", "save mesh"], category="3d", is_output_node=True, inputs=[ diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py index ce21caade..cb4fb24a1 100644 --- a/comfy_extras/nodes_images.py +++ b/comfy_extras/nodes_images.py @@ -22,6 +22,7 @@ class ImageCrop(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="ImageCrop", + search_aliases=["trim"], display_name="Image Crop", category="image/transform", inputs=[ @@ -51,6 +52,7 @@ class RepeatImageBatch(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="RepeatImageBatch", + search_aliases=["duplicate image", "clone image"], category="image/batch", inputs=[ IO.Image.Input("image"), @@ -72,6 +74,7 @@ class ImageFromBatch(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="ImageFromBatch", + search_aliases=["select image", "pick from batch", "extract image"], category="image/batch", inputs=[ IO.Image.Input("image"), @@ -97,6 +100,7 @@ class ImageAddNoise(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="ImageAddNoise", + search_aliases=["film grain"], category="image", inputs=[ IO.Image.Input("image"), @@ -194,11 +198,11 @@ class SaveAnimatedPNG(IO.ComfyNode): class ImageStitch(IO.ComfyNode): """Upstreamed from https://github.com/kijai/ComfyUI-KJNodes""" - @classmethod def define_schema(cls): return IO.Schema( node_id="ImageStitch", + search_aliases=["combine images", "join images", "concatenate images", "side by side"], display_name="Image Stitch", description="Stitches image2 to image1 in the specified direction.\n" "If image2 is not provided, returns image1 unchanged.\n" @@ -369,11 +373,11 @@ class ImageStitch(IO.ComfyNode): class ResizeAndPadImage(IO.ComfyNode): - @classmethod def define_schema(cls): return IO.Schema( node_id="ResizeAndPadImage", + search_aliases=["fit to size"], category="image/transform", inputs=[ IO.Image.Input("image"), @@ -420,11 +424,11 @@ class ResizeAndPadImage(IO.ComfyNode): class SaveSVGNode(IO.ComfyNode): - @classmethod def define_schema(cls): return IO.Schema( node_id="SaveSVGNode", + search_aliases=["export vector", "save vector graphics"], description="Save SVG files on disk.", category="image/save", inputs=[ @@ -492,11 +496,11 @@ class SaveSVGNode(IO.ComfyNode): class GetImageSize(IO.ComfyNode): - @classmethod def define_schema(cls): return IO.Schema( node_id="GetImageSize", + search_aliases=["dimensions", "resolution", "image info"], display_name="Get Image Size", description="Returns width and height of the image, and passes it through unchanged.", category="image", @@ -527,11 +531,11 @@ class GetImageSize(IO.ComfyNode): class ImageRotate(IO.ComfyNode): - @classmethod def define_schema(cls): return IO.Schema( node_id="ImageRotate", + search_aliases=["turn", "flip orientation"], category="image/transform", inputs=[ IO.Image.Input("image"), @@ -557,11 +561,11 @@ class ImageRotate(IO.ComfyNode): class ImageFlip(IO.ComfyNode): - @classmethod def define_schema(cls): return IO.Schema( node_id="ImageFlip", + search_aliases=["mirror", "reflect"], category="image/transform", inputs=[ IO.Image.Input("image"), diff --git a/comfy_extras/nodes_kandinsky5.py b/comfy_extras/nodes_kandinsky5.py index 9cb234be1..346c50cde 100644 --- a/comfy_extras/nodes_kandinsky5.py +++ b/comfy_extras/nodes_kandinsky5.py @@ -104,6 +104,7 @@ class CLIPTextEncodeKandinsky5(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="CLIPTextEncodeKandinsky5", + search_aliases=["kandinsky prompt"], category="advanced/conditioning/kandinsky5", inputs=[ io.Clip.Input("clip"), diff --git a/comfy_extras/nodes_latent.py b/comfy_extras/nodes_latent.py index 9ba1c4ba8..6aecf1561 100644 --- a/comfy_extras/nodes_latent.py +++ b/comfy_extras/nodes_latent.py @@ -21,6 +21,7 @@ class LatentAdd(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="LatentAdd", + search_aliases=["combine latents", "sum latents"], category="latent/advanced", inputs=[ io.Latent.Input("samples1"), @@ -47,6 +48,7 @@ class LatentSubtract(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="LatentSubtract", + search_aliases=["difference latent", "remove features"], category="latent/advanced", inputs=[ io.Latent.Input("samples1"), @@ -73,6 +75,7 @@ class LatentMultiply(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="LatentMultiply", + search_aliases=["scale latent", "amplify latent", "latent gain"], category="latent/advanced", inputs=[ io.Latent.Input("samples"), @@ -96,6 +99,7 @@ class LatentInterpolate(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="LatentInterpolate", + search_aliases=["blend latent", "mix latent", "lerp latent", "transition"], category="latent/advanced", inputs=[ io.Latent.Input("samples1"), @@ -134,6 +138,7 @@ class LatentConcat(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="LatentConcat", + search_aliases=["join latents", "stitch latents"], category="latent/advanced", inputs=[ io.Latent.Input("samples1"), @@ -173,6 +178,7 @@ class LatentCut(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="LatentCut", + search_aliases=["crop latent", "slice latent", "extract region"], category="latent/advanced", inputs=[ io.Latent.Input("samples"), @@ -213,6 +219,7 @@ class LatentCutToBatch(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="LatentCutToBatch", + search_aliases=["slice to batch", "split latent", "tile latent"], category="latent/advanced", inputs=[ io.Latent.Input("samples"), @@ -254,6 +261,7 @@ class LatentBatch(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="LatentBatch", + search_aliases=["combine latents", "merge latents", "join latents"], category="latent/batch", is_deprecated=True, inputs=[ @@ -310,6 +318,7 @@ class LatentApplyOperation(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="LatentApplyOperation", + search_aliases=["transform latent"], category="latent/advanced/operations", is_experimental=True, inputs=[ @@ -365,6 +374,7 @@ class LatentOperationTonemapReinhard(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="LatentOperationTonemapReinhard", + search_aliases=["hdr latent"], category="latent/advanced/operations", is_experimental=True, inputs=[ diff --git a/comfy_extras/nodes_load_3d.py b/comfy_extras/nodes_load_3d.py index a16b8c8f3..4b8d950ae 100644 --- a/comfy_extras/nodes_load_3d.py +++ b/comfy_extras/nodes_load_3d.py @@ -75,6 +75,7 @@ class Preview3D(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="Preview3D", + search_aliases=["view mesh", "3d viewer"], display_name="Preview 3D & Animation", category="3d", is_experimental=True, diff --git a/comfy_extras/nodes_logic.py b/comfy_extras/nodes_logic.py index eb888316a..1ed060205 100644 --- a/comfy_extras/nodes_logic.py +++ b/comfy_extras/nodes_logic.py @@ -224,6 +224,7 @@ class ConvertStringToComboNode(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="ConvertStringToComboNode", + search_aliases=["string to dropdown", "text to combo"], display_name="Convert String to Combo", category="logic", inputs=[io.String.Input("string")], @@ -239,6 +240,7 @@ class InvertBooleanNode(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="InvertBooleanNode", + search_aliases=["not", "toggle", "negate", "flip boolean"], display_name="Invert Boolean", category="logic", inputs=[io.Boolean.Input("boolean")], diff --git a/comfy_extras/nodes_lora_extract.py b/comfy_extras/nodes_lora_extract.py index a2375cba7..fb89e03f4 100644 --- a/comfy_extras/nodes_lora_extract.py +++ b/comfy_extras/nodes_lora_extract.py @@ -78,6 +78,7 @@ class LoraSave(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="LoraSave", + search_aliases=["export lora"], display_name="Extract and Save Lora", category="_for_testing", inputs=[ diff --git a/comfy_extras/nodes_lumina2.py b/comfy_extras/nodes_lumina2.py index 89ff2397a..2550475ae 100644 --- a/comfy_extras/nodes_lumina2.py +++ b/comfy_extras/nodes_lumina2.py @@ -79,6 +79,7 @@ class CLIPTextEncodeLumina2(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="CLIPTextEncodeLumina2", + search_aliases=["lumina prompt"], display_name="CLIP Text Encode for Lumina2", category="conditioning", description="Encodes a system prompt and a user prompt using a CLIP model into an embedding " diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py index 290e6f55e..98e8fef8f 100644 --- a/comfy_extras/nodes_mask.py +++ b/comfy_extras/nodes_mask.py @@ -50,6 +50,7 @@ class LatentCompositeMasked(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="LatentCompositeMasked", + search_aliases=["overlay latent", "layer latent", "paste latent", "inpaint latent"], category="latent", inputs=[ IO.Latent.Input("destination"), @@ -78,6 +79,7 @@ class ImageCompositeMasked(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="ImageCompositeMasked", + search_aliases=["paste image", "overlay", "layer"], category="image", inputs=[ IO.Image.Input("destination"), @@ -105,6 +107,7 @@ class MaskToImage(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="MaskToImage", + search_aliases=["convert mask"], display_name="Convert Mask to Image", category="mask", inputs=[ @@ -126,6 +129,7 @@ class ImageToMask(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="ImageToMask", + search_aliases=["extract channel", "channel to mask"], display_name="Convert Image to Mask", category="mask", inputs=[ @@ -149,6 +153,7 @@ class ImageColorToMask(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="ImageColorToMask", + search_aliases=["color keying", "chroma key"], category="mask", inputs=[ IO.Image.Input("image"), @@ -194,6 +199,7 @@ class InvertMask(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="InvertMask", + search_aliases=["reverse mask", "flip mask"], category="mask", inputs=[ IO.Mask.Input("mask"), @@ -214,6 +220,7 @@ class CropMask(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="CropMask", + search_aliases=["cut mask", "extract mask region", "mask slice"], category="mask", inputs=[ IO.Mask.Input("mask"), @@ -239,6 +246,7 @@ class MaskComposite(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="MaskComposite", + search_aliases=["combine masks", "blend masks", "layer masks"], category="mask", inputs=[ IO.Mask.Input("destination"), @@ -287,6 +295,7 @@ class FeatherMask(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="FeatherMask", + search_aliases=["soft edge mask", "blur mask edges", "gradient mask edge"], category="mask", inputs=[ IO.Mask.Input("mask"), @@ -333,6 +342,7 @@ class GrowMask(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="GrowMask", + search_aliases=["expand mask", "shrink mask"], display_name="Grow Mask", category="mask", inputs=[ @@ -370,6 +380,7 @@ class ThresholdMask(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="ThresholdMask", + search_aliases=["binary mask"], category="mask", inputs=[ IO.Mask.Input("mask"), @@ -394,6 +405,7 @@ class MaskPreview(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="MaskPreview", + search_aliases=["show mask", "view mask", "inspect mask", "debug mask"], display_name="Preview Mask", category="mask", description="Saves the input images to your ComfyUI output directory.", diff --git a/comfy_extras/nodes_model_advanced.py b/comfy_extras/nodes_model_advanced.py index ae5d2c563..f22b333fc 100644 --- a/comfy_extras/nodes_model_advanced.py +++ b/comfy_extras/nodes_model_advanced.py @@ -299,6 +299,7 @@ class RescaleCFG: return (m, ) class ModelComputeDtype: + SEARCH_ALIASES = ["model precision", "change dtype"] @classmethod def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), diff --git a/comfy_extras/nodes_model_merging.py b/comfy_extras/nodes_model_merging.py index f20beab7d..5384ed531 100644 --- a/comfy_extras/nodes_model_merging.py +++ b/comfy_extras/nodes_model_merging.py @@ -91,6 +91,7 @@ class CLIPMergeSimple: class CLIPSubtract: + SEARCH_ALIASES = ["clip difference", "text encoder subtract"] @classmethod def INPUT_TYPES(s): return {"required": { "clip1": ("CLIP",), @@ -113,6 +114,7 @@ class CLIPSubtract: class CLIPAdd: + SEARCH_ALIASES = ["combine clip"] @classmethod def INPUT_TYPES(s): return {"required": { "clip1": ("CLIP",), @@ -225,6 +227,7 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, clip_vision, metadata=metadata, extra_keys=extra_keys) class CheckpointSave: + SEARCH_ALIASES = ["save model", "export checkpoint", "merge save"] def __init__(self): self.output_dir = folder_paths.get_output_directory() @@ -337,6 +340,7 @@ class VAESave: return {} class ModelSave: + SEARCH_ALIASES = ["export model", "checkpoint save"] def __init__(self): self.output_dir = folder_paths.get_output_directory() diff --git a/comfy_extras/nodes_morphology.py b/comfy_extras/nodes_morphology.py index 67377e1bc..4ab2fb7e8 100644 --- a/comfy_extras/nodes_morphology.py +++ b/comfy_extras/nodes_morphology.py @@ -12,6 +12,7 @@ class Morphology(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="Morphology", + search_aliases=["erode", "dilate"], display_name="ImageMorphology", category="image/postprocessing", inputs=[ @@ -57,6 +58,7 @@ class ImageRGBToYUV(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="ImageRGBToYUV", + search_aliases=["color space conversion"], category="image/batch", inputs=[ io.Image.Input("image"), @@ -78,6 +80,7 @@ class ImageYUVToRGB(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="ImageYUVToRGB", + search_aliases=["color space conversion"], category="image/batch", inputs=[ io.Image.Input("Y"), diff --git a/comfy_extras/nodes_pixart.py b/comfy_extras/nodes_pixart.py index a23e87b1f..2f1b73e60 100644 --- a/comfy_extras/nodes_pixart.py +++ b/comfy_extras/nodes_pixart.py @@ -7,6 +7,7 @@ class CLIPTextEncodePixArtAlpha(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="CLIPTextEncodePixArtAlpha", + search_aliases=["pixart prompt"], category="advanced/conditioning", description="Encodes text and sets the resolution conditioning for PixArt Alpha. Does not apply to PixArt Sigma.", inputs=[ diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index 6011275d6..ab002daca 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -402,7 +402,6 @@ def scale_to_multiple_cover(input: torch.Tensor, multiple: int, scale_method: st return input[:, y0:y1, x0:x1] class ResizeImageMaskNode(io.ComfyNode): - scale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"] crop_methods = ["disabled", "center"] @@ -424,6 +423,7 @@ class ResizeImageMaskNode(io.ComfyNode): crop_combo = io.Combo.Input("crop", options=cls.crop_methods, default="center") return io.Schema( node_id="ResizeImageMaskNode", + search_aliases=["scale image", "scale mask"], display_name="Resize Image/Mask", category="transform", inputs=[ @@ -569,6 +569,7 @@ class BatchMasksNode(io.ComfyNode): autogrow_template = io.Autogrow.TemplatePrefix(io.Mask.Input("mask"), prefix="mask", min=2, max=50) return io.Schema( node_id="BatchMasksNode", + search_aliases=["combine masks", "stack masks", "merge masks"], display_name="Batch Masks", category="mask", inputs=[ @@ -589,6 +590,7 @@ class BatchLatentsNode(io.ComfyNode): autogrow_template = io.Autogrow.TemplatePrefix(io.Latent.Input("latent"), prefix="latent", min=2, max=50) return io.Schema( node_id="BatchLatentsNode", + search_aliases=["combine latents", "stack latents", "merge latents"], display_name="Batch Latents", category="latent", inputs=[ @@ -612,6 +614,7 @@ class BatchImagesMasksLatentsNode(io.ComfyNode): prefix="input", min=1, max=50) return io.Schema( node_id="BatchImagesMasksLatentsNode", + search_aliases=["combine batch", "merge batch", "stack inputs"], display_name="Batch Images/Masks/Latents", category="util", inputs=[ diff --git a/comfy_extras/nodes_preview_any.py b/comfy_extras/nodes_preview_any.py index 91502ebf2..b0a6f279d 100644 --- a/comfy_extras/nodes_preview_any.py +++ b/comfy_extras/nodes_preview_any.py @@ -16,7 +16,7 @@ class PreviewAny(): OUTPUT_NODE = True CATEGORY = "utils" - SEARCH_ALIASES = ["preview", "show", "display", "view", "show text", "display text", "preview text", "show output", "inspect", "debug"] + SEARCH_ALIASES = ["show output", "inspect", "debug", "print value", "show text"] def main(self, source=None): value = 'None' diff --git a/comfy_extras/nodes_sd3.py b/comfy_extras/nodes_sd3.py index 14782cb2b..02e5e7dd8 100644 --- a/comfy_extras/nodes_sd3.py +++ b/comfy_extras/nodes_sd3.py @@ -65,6 +65,7 @@ class CLIPTextEncodeSD3(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="CLIPTextEncodeSD3", + search_aliases=["sd3 prompt"], category="advanced/conditioning", inputs=[ io.Clip.Input("clip"), diff --git a/comfy_extras/nodes_string.py b/comfy_extras/nodes_string.py index a2d5f0d94..8d3e65cc5 100644 --- a/comfy_extras/nodes_string.py +++ b/comfy_extras/nodes_string.py @@ -32,6 +32,7 @@ class StringSubstring(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="StringSubstring", + search_aliases=["extract text", "text portion"], display_name="Substring", category="utils/string", inputs=[ @@ -54,6 +55,7 @@ class StringLength(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="StringLength", + search_aliases=["character count", "text size"], display_name="Length", category="utils/string", inputs=[ @@ -74,6 +76,7 @@ class CaseConverter(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="CaseConverter", + search_aliases=["text case", "uppercase", "lowercase", "capitalize"], display_name="Case Converter", category="utils/string", inputs=[ @@ -106,6 +109,7 @@ class StringTrim(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="StringTrim", + search_aliases=["clean whitespace", "remove whitespace"], display_name="Trim", category="utils/string", inputs=[ @@ -136,6 +140,7 @@ class StringReplace(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="StringReplace", + search_aliases=["find and replace", "substitute", "swap text"], display_name="Replace", category="utils/string", inputs=[ @@ -158,6 +163,7 @@ class StringContains(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="StringContains", + search_aliases=["text includes", "string includes"], display_name="Contains", category="utils/string", inputs=[ @@ -185,6 +191,7 @@ class StringCompare(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="StringCompare", + search_aliases=["text match", "string equals", "starts with", "ends with"], display_name="Compare", category="utils/string", inputs=[ @@ -220,6 +227,7 @@ class RegexMatch(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="RegexMatch", + search_aliases=["pattern match", "text contains", "string match"], display_name="Regex Match", category="utils/string", inputs=[ @@ -260,6 +268,7 @@ class RegexExtract(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="RegexExtract", + search_aliases=["pattern extract", "text parser", "parse text"], display_name="Regex Extract", category="utils/string", inputs=[ @@ -334,6 +343,7 @@ class RegexReplace(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="RegexReplace", + search_aliases=["pattern replace", "find and replace", "substitution"], display_name="Regex Replace", category="utils/string", description="Find and replace text using regex patterns.", diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py index 364804205..68a73cf13 100644 --- a/comfy_extras/nodes_train.py +++ b/comfy_extras/nodes_train.py @@ -1101,6 +1101,7 @@ class SaveLoRA(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="SaveLoRA", + search_aliases=["export lora"], display_name="Save LoRA Weights", category="loaders", is_experimental=True, @@ -1144,6 +1145,7 @@ class LossGraphNode(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="LossGraphNode", + search_aliases=["training chart", "training visualization", "plot loss"], display_name="Plot Loss Graph", category="training", is_experimental=True, diff --git a/comfy_extras/nodes_video.py b/comfy_extras/nodes_video.py index c609e03da..ccf7b63d3 100644 --- a/comfy_extras/nodes_video.py +++ b/comfy_extras/nodes_video.py @@ -16,6 +16,7 @@ class SaveWEBM(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="SaveWEBM", + search_aliases=["export webm"], category="image/video", is_experimental=True, inputs=[ @@ -69,6 +70,7 @@ class SaveVideo(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="SaveVideo", + search_aliases=["export video"], display_name="Save Video", category="image/video", description="Saves the input images to your ComfyUI output directory.", @@ -116,6 +118,7 @@ class CreateVideo(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="CreateVideo", + search_aliases=["images to video"], display_name="Create Video", category="image/video", description="Create a video from images.", @@ -140,6 +143,7 @@ class GetVideoComponents(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="GetVideoComponents", + search_aliases=["extract frames", "split video", "video to images", "demux"], display_name="Get Video Components", category="image/video", description="Extracts all components from a video: frames, audio, and framerate.", @@ -167,6 +171,7 @@ class LoadVideo(io.ComfyNode): files = folder_paths.filter_files_content_types(files, ["video"]) return io.Schema( node_id="LoadVideo", + search_aliases=["import video", "open video", "video file"], display_name="Load Video", category="image/video", inputs=[ diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index 90deb0077..2ff012134 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -287,6 +287,7 @@ class WanVaceToVideo(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="WanVaceToVideo", + search_aliases=["video conditioning", "video control"], category="conditioning/video_models", inputs=[ io.Conditioning.Input("positive"), @@ -705,6 +706,7 @@ class WanTrackToVideo(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="WanTrackToVideo", + search_aliases=["motion tracking", "trajectory video", "point tracking", "keypoint animation"], category="conditioning/video_models", inputs=[ io.Conditioning.Input("positive"), diff --git a/comfy_extras/nodes_wanmove.py b/comfy_extras/nodes_wanmove.py index 5f39afa46..d60baf230 100644 --- a/comfy_extras/nodes_wanmove.py +++ b/comfy_extras/nodes_wanmove.py @@ -324,6 +324,7 @@ class GenerateTracks(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="GenerateTracks", + search_aliases=["motion paths", "camera movement", "trajectory"], category="conditioning/video_models", inputs=[ io.Int.Input("width", default=832, min=16, max=4096, step=16), diff --git a/comfy_extras/nodes_webcam.py b/comfy_extras/nodes_webcam.py index 5bf80b4c6..6349ac017 100644 --- a/comfy_extras/nodes_webcam.py +++ b/comfy_extras/nodes_webcam.py @@ -5,6 +5,7 @@ MAX_RESOLUTION = nodes.MAX_RESOLUTION class WebcamCapture(nodes.LoadImage): + SEARCH_ALIASES = ["camera input", "live capture", "camera feed", "snapshot"] @classmethod def INPUT_TYPES(s): return { diff --git a/nodes.py b/nodes.py index 8864fda60..158106686 100644 --- a/nodes.py +++ b/nodes.py @@ -93,6 +93,8 @@ class ConditioningCombine: return (conditioning_1 + conditioning_2, ) class ConditioningAverage : + SEARCH_ALIASES = ["blend prompts", "interpolate conditioning", "mix prompts", "style fusion", "weighted blend"] + @classmethod def INPUT_TYPES(s): return {"required": {"conditioning_to": ("CONDITIONING", ), "conditioning_from": ("CONDITIONING", ), @@ -159,6 +161,8 @@ class ConditioningConcat: return (out, ) class ConditioningSetArea: + SEARCH_ALIASES = ["regional prompt", "area prompt", "spatial conditioning", "localized prompt"] + @classmethod def INPUT_TYPES(s): return {"required": {"conditioning": ("CONDITIONING", ), @@ -217,6 +221,8 @@ class ConditioningSetAreaStrength: class ConditioningSetMask: + SEARCH_ALIASES = ["masked prompt", "regional inpaint conditioning", "mask conditioning"] + @classmethod def INPUT_TYPES(s): return {"required": {"conditioning": ("CONDITIONING", ), @@ -242,6 +248,8 @@ class ConditioningSetMask: return (c, ) class ConditioningZeroOut: + SEARCH_ALIASES = ["null conditioning", "clear conditioning"] + @classmethod def INPUT_TYPES(s): return {"required": {"conditioning": ("CONDITIONING", )}} @@ -467,6 +475,8 @@ class InpaintModelConditioning: class SaveLatent: + SEARCH_ALIASES = ["export latent"] + def __init__(self): self.output_dir = folder_paths.get_output_directory() @@ -518,6 +528,8 @@ class SaveLatent: class LoadLatent: + SEARCH_ALIASES = ["import latent", "open latent"] + @classmethod def INPUT_TYPES(s): input_dir = folder_paths.get_input_directory() @@ -554,6 +566,8 @@ class LoadLatent: class CheckpointLoader: + SEARCH_ALIASES = ["load model", "model loader"] + @classmethod def INPUT_TYPES(s): return {"required": { "config_name": (folder_paths.get_filename_list("configs"), ), @@ -593,6 +607,8 @@ class CheckpointLoaderSimple: return out[:3] class DiffusersLoader: + SEARCH_ALIASES = ["load diffusers model"] + @classmethod def INPUT_TYPES(cls): paths = [] @@ -1063,6 +1079,8 @@ class StyleModelLoader: class StyleModelApply: + SEARCH_ALIASES = ["style transfer"] + @classmethod def INPUT_TYPES(s): return {"required": {"conditioning": ("CONDITIONING", ), @@ -1216,6 +1234,8 @@ class EmptyLatentImage: class LatentFromBatch: + SEARCH_ALIASES = ["select from batch", "pick latent", "batch subset"] + @classmethod def INPUT_TYPES(s): return {"required": { "samples": ("LATENT",), @@ -1248,6 +1268,8 @@ class LatentFromBatch: return (s,) class RepeatLatentBatch: + SEARCH_ALIASES = ["duplicate latent", "clone latent"] + @classmethod def INPUT_TYPES(s): return {"required": { "samples": ("LATENT",), @@ -1274,6 +1296,8 @@ class RepeatLatentBatch: return (s,) class LatentUpscale: + SEARCH_ALIASES = ["enlarge latent", "resize latent"] + upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "bislerp"] crop_methods = ["disabled", "center"] @@ -1308,6 +1332,8 @@ class LatentUpscale: return (s,) class LatentUpscaleBy: + SEARCH_ALIASES = ["enlarge latent", "resize latent", "scale latent"] + upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "bislerp"] @classmethod @@ -1351,6 +1377,8 @@ class LatentRotate: return (s,) class LatentFlip: + SEARCH_ALIASES = ["mirror latent"] + @classmethod def INPUT_TYPES(s): return {"required": { "samples": ("LATENT",), @@ -1371,6 +1399,8 @@ class LatentFlip: return (s,) class LatentComposite: + SEARCH_ALIASES = ["overlay latent", "layer latent", "paste latent"] + @classmethod def INPUT_TYPES(s): return {"required": { "samples_to": ("LATENT",), @@ -1413,6 +1443,8 @@ class LatentComposite: return (samples_out,) class LatentBlend: + SEARCH_ALIASES = ["mix latents", "interpolate latents"] + @classmethod def INPUT_TYPES(s): return {"required": { @@ -1454,6 +1486,8 @@ class LatentBlend: raise ValueError(f"Unsupported blend mode: {mode}") class LatentCrop: + SEARCH_ALIASES = ["trim latent", "cut latent"] + @classmethod def INPUT_TYPES(s): return {"required": { "samples": ("LATENT",), @@ -1739,6 +1773,8 @@ class LoadImage: return True class LoadImageMask: + SEARCH_ALIASES = ["import mask", "alpha mask", "channel mask"] + _color_channels = ["alpha", "red", "green", "blue"] @classmethod def INPUT_TYPES(s): @@ -1789,6 +1825,8 @@ class LoadImageMask: class LoadImageOutput(LoadImage): + SEARCH_ALIASES = ["output image", "previous generation"] + @classmethod def INPUT_TYPES(s): return { @@ -1862,6 +1900,7 @@ class ImageScaleBy: return (s,) class ImageInvert: + SEARCH_ALIASES = ["reverse colors"] @classmethod def INPUT_TYPES(s): @@ -1877,6 +1916,7 @@ class ImageInvert: return (s,) class ImageBatch: + SEARCH_ALIASES = ["combine images", "merge images", "stack images"] @classmethod def INPUT_TYPES(s): @@ -1922,6 +1962,7 @@ class EmptyImage: return (torch.cat((r, g, b), dim=-1), ) class ImagePadForOutpaint: + SEARCH_ALIASES = ["extend canvas", "expand image"] @classmethod def INPUT_TYPES(s): From 4e3038114a725d15166a726860a29cbab0dda4e3 Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Thu, 22 Jan 2026 18:46:55 -0800 Subject: [PATCH 41/58] feat: Improve ResizeImageMaskNode UX with tooltips and search aliases (#12013) - Add search_aliases for discoverability: resize, scale, dimensions, etc. - Add node description for hover tooltip - Add tooltips to all inputs explaining their behavior - Reorder options: most common (scale dimensions) first, most technical (scale to multiple) last Addresses user feedback that 'resize' search returned nothing useful and options like 'match size' and 'scale to multiple' were not self-explanatory. --- comfy_extras/nodes_post_processing.py | 66 +++++++++++++++++---------- 1 file changed, 41 insertions(+), 25 deletions(-) diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index ab002daca..32ab2f70d 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -420,47 +420,63 @@ class ResizeImageMaskNode(io.ComfyNode): @classmethod def define_schema(cls): template = io.MatchType.Template("input_type", [io.Image, io.Mask]) - crop_combo = io.Combo.Input("crop", options=cls.crop_methods, default="center") + crop_combo = io.Combo.Input( + "crop", + options=cls.crop_methods, + default="center", + tooltip="How to handle aspect ratio mismatch: 'disabled' stretches to fit, 'center' crops to maintain aspect ratio.", + ) return io.Schema( node_id="ResizeImageMaskNode", search_aliases=["scale image", "scale mask"], display_name="Resize Image/Mask", + description="Resize an image or mask using various scaling methods.", category="transform", + search_aliases=["resize", "resize image", "resize mask", "scale", "scale image", "image resize", "change size", "dimensions", "shrink", "enlarge"], inputs=[ io.MatchType.Input("input", template=template), - io.DynamicCombo.Input("resize_type", options=[ - io.DynamicCombo.Option(ResizeType.SCALE_BY, [ - io.Float.Input("multiplier", default=1.00, min=0.01, max=8.0, step=0.01), + io.DynamicCombo.Input( + "resize_type", + tooltip="Select how to resize: by exact dimensions, scale factor, matching another image, etc.", + options=[ + io.DynamicCombo.Option(ResizeType.SCALE_DIMENSIONS, [ + io.Int.Input("width", default=512, min=0, max=MAX_RESOLUTION, step=1, tooltip="Target width in pixels. Set to 0 to auto-calculate from height while preserving aspect ratio."), + io.Int.Input("height", default=512, min=0, max=MAX_RESOLUTION, step=1, tooltip="Target height in pixels. Set to 0 to auto-calculate from width while preserving aspect ratio."), + crop_combo, ]), - io.DynamicCombo.Option(ResizeType.SCALE_DIMENSIONS, [ - io.Int.Input("width", default=512, min=0, max=MAX_RESOLUTION, step=1), - io.Int.Input("height", default=512, min=0, max=MAX_RESOLUTION, step=1), - crop_combo, + io.DynamicCombo.Option(ResizeType.SCALE_BY, [ + io.Float.Input("multiplier", default=1.00, min=0.01, max=8.0, step=0.01, tooltip="Scale factor (e.g., 2.0 doubles size, 0.5 halves size)."), ]), - io.DynamicCombo.Option(ResizeType.SCALE_LONGER_DIMENSION, [ - io.Int.Input("longer_size", default=512, min=0, max=MAX_RESOLUTION, step=1), + io.DynamicCombo.Option(ResizeType.SCALE_LONGER_DIMENSION, [ + io.Int.Input("longer_size", default=512, min=0, max=MAX_RESOLUTION, step=1, tooltip="The longer edge will be resized to this value. Aspect ratio is preserved."), ]), - io.DynamicCombo.Option(ResizeType.SCALE_SHORTER_DIMENSION, [ - io.Int.Input("shorter_size", default=512, min=0, max=MAX_RESOLUTION, step=1), + io.DynamicCombo.Option(ResizeType.SCALE_SHORTER_DIMENSION, [ + io.Int.Input("shorter_size", default=512, min=0, max=MAX_RESOLUTION, step=1, tooltip="The shorter edge will be resized to this value. Aspect ratio is preserved."), ]), - io.DynamicCombo.Option(ResizeType.SCALE_WIDTH, [ - io.Int.Input("width", default=512, min=0, max=MAX_RESOLUTION, step=1), + io.DynamicCombo.Option(ResizeType.SCALE_WIDTH, [ + io.Int.Input("width", default=512, min=0, max=MAX_RESOLUTION, step=1, tooltip="Target width in pixels. Height auto-adjusts to preserve aspect ratio."), ]), - io.DynamicCombo.Option(ResizeType.SCALE_HEIGHT, [ - io.Int.Input("height", default=512, min=0, max=MAX_RESOLUTION, step=1), + io.DynamicCombo.Option(ResizeType.SCALE_HEIGHT, [ + io.Int.Input("height", default=512, min=0, max=MAX_RESOLUTION, step=1, tooltip="Target height in pixels. Width auto-adjusts to preserve aspect ratio."), ]), - io.DynamicCombo.Option(ResizeType.SCALE_TOTAL_PIXELS, [ - io.Float.Input("megapixels", default=1.0, min=0.01, max=16.0, step=0.01), + io.DynamicCombo.Option(ResizeType.SCALE_TOTAL_PIXELS, [ + io.Float.Input("megapixels", default=1.0, min=0.01, max=16.0, step=0.01, tooltip="Target total megapixels (e.g., 1.0 ≈ 1024×1024). Aspect ratio is preserved."), ]), - io.DynamicCombo.Option(ResizeType.MATCH_SIZE, [ - io.MultiType.Input("match", [io.Image, io.Mask]), - crop_combo, + io.DynamicCombo.Option(ResizeType.MATCH_SIZE, [ + io.MultiType.Input("match", [io.Image, io.Mask], tooltip="Resize input to match the dimensions of this reference image or mask."), + crop_combo, ]), - io.DynamicCombo.Option(ResizeType.SCALE_TO_MULTIPLE, [ - io.Int.Input("multiple", default=8, min=1, max=MAX_RESOLUTION, step=1), + io.DynamicCombo.Option(ResizeType.SCALE_TO_MULTIPLE, [ + io.Int.Input("multiple", default=8, min=1, max=MAX_RESOLUTION, step=1, tooltip="Resize so width and height are divisible by this number. Useful for latent alignment (e.g., 8 or 64)."), ]), - ]), - io.Combo.Input("scale_method", options=cls.scale_methods, default="area"), + ], + ), + io.Combo.Input( + "scale_method", + options=cls.scale_methods, + default="area", + tooltip="Interpolation algorithm. 'area' is best for downscaling, 'lanczos' for upscaling, 'nearest-exact' for pixel art.", + ), ], outputs=[io.MatchType.Output(template=template, display_name="resized")] ) From f443b9f2ca3109f7e3ef6c5de3cdd22330fbf34c Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 22 Jan 2026 20:02:37 -0800 Subject: [PATCH 42/58] =?UTF-8?q?Revert=20"feat:=20Improve=20ResizeImageMa?= =?UTF-8?q?skNode=20UX=20with=20tooltips=20and=20search=20aliases=E2=80=A6?= =?UTF-8?q?"=20(#12038)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reverts commit 4e3038114a725d15166a726860a29cbab0dda4e3. --- comfy_extras/nodes_post_processing.py | 66 ++++++++++----------------- 1 file changed, 25 insertions(+), 41 deletions(-) diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index 32ab2f70d..ab002daca 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -420,63 +420,47 @@ class ResizeImageMaskNode(io.ComfyNode): @classmethod def define_schema(cls): template = io.MatchType.Template("input_type", [io.Image, io.Mask]) - crop_combo = io.Combo.Input( - "crop", - options=cls.crop_methods, - default="center", - tooltip="How to handle aspect ratio mismatch: 'disabled' stretches to fit, 'center' crops to maintain aspect ratio.", - ) + crop_combo = io.Combo.Input("crop", options=cls.crop_methods, default="center") return io.Schema( node_id="ResizeImageMaskNode", search_aliases=["scale image", "scale mask"], display_name="Resize Image/Mask", - description="Resize an image or mask using various scaling methods.", category="transform", - search_aliases=["resize", "resize image", "resize mask", "scale", "scale image", "image resize", "change size", "dimensions", "shrink", "enlarge"], inputs=[ io.MatchType.Input("input", template=template), - io.DynamicCombo.Input( - "resize_type", - tooltip="Select how to resize: by exact dimensions, scale factor, matching another image, etc.", - options=[ - io.DynamicCombo.Option(ResizeType.SCALE_DIMENSIONS, [ - io.Int.Input("width", default=512, min=0, max=MAX_RESOLUTION, step=1, tooltip="Target width in pixels. Set to 0 to auto-calculate from height while preserving aspect ratio."), - io.Int.Input("height", default=512, min=0, max=MAX_RESOLUTION, step=1, tooltip="Target height in pixels. Set to 0 to auto-calculate from width while preserving aspect ratio."), - crop_combo, + io.DynamicCombo.Input("resize_type", options=[ + io.DynamicCombo.Option(ResizeType.SCALE_BY, [ + io.Float.Input("multiplier", default=1.00, min=0.01, max=8.0, step=0.01), ]), - io.DynamicCombo.Option(ResizeType.SCALE_BY, [ - io.Float.Input("multiplier", default=1.00, min=0.01, max=8.0, step=0.01, tooltip="Scale factor (e.g., 2.0 doubles size, 0.5 halves size)."), + io.DynamicCombo.Option(ResizeType.SCALE_DIMENSIONS, [ + io.Int.Input("width", default=512, min=0, max=MAX_RESOLUTION, step=1), + io.Int.Input("height", default=512, min=0, max=MAX_RESOLUTION, step=1), + crop_combo, ]), - io.DynamicCombo.Option(ResizeType.SCALE_LONGER_DIMENSION, [ - io.Int.Input("longer_size", default=512, min=0, max=MAX_RESOLUTION, step=1, tooltip="The longer edge will be resized to this value. Aspect ratio is preserved."), + io.DynamicCombo.Option(ResizeType.SCALE_LONGER_DIMENSION, [ + io.Int.Input("longer_size", default=512, min=0, max=MAX_RESOLUTION, step=1), ]), - io.DynamicCombo.Option(ResizeType.SCALE_SHORTER_DIMENSION, [ - io.Int.Input("shorter_size", default=512, min=0, max=MAX_RESOLUTION, step=1, tooltip="The shorter edge will be resized to this value. Aspect ratio is preserved."), + io.DynamicCombo.Option(ResizeType.SCALE_SHORTER_DIMENSION, [ + io.Int.Input("shorter_size", default=512, min=0, max=MAX_RESOLUTION, step=1), ]), - io.DynamicCombo.Option(ResizeType.SCALE_WIDTH, [ - io.Int.Input("width", default=512, min=0, max=MAX_RESOLUTION, step=1, tooltip="Target width in pixels. Height auto-adjusts to preserve aspect ratio."), + io.DynamicCombo.Option(ResizeType.SCALE_WIDTH, [ + io.Int.Input("width", default=512, min=0, max=MAX_RESOLUTION, step=1), ]), - io.DynamicCombo.Option(ResizeType.SCALE_HEIGHT, [ - io.Int.Input("height", default=512, min=0, max=MAX_RESOLUTION, step=1, tooltip="Target height in pixels. Width auto-adjusts to preserve aspect ratio."), + io.DynamicCombo.Option(ResizeType.SCALE_HEIGHT, [ + io.Int.Input("height", default=512, min=0, max=MAX_RESOLUTION, step=1), ]), - io.DynamicCombo.Option(ResizeType.SCALE_TOTAL_PIXELS, [ - io.Float.Input("megapixels", default=1.0, min=0.01, max=16.0, step=0.01, tooltip="Target total megapixels (e.g., 1.0 ≈ 1024×1024). Aspect ratio is preserved."), + io.DynamicCombo.Option(ResizeType.SCALE_TOTAL_PIXELS, [ + io.Float.Input("megapixels", default=1.0, min=0.01, max=16.0, step=0.01), ]), - io.DynamicCombo.Option(ResizeType.MATCH_SIZE, [ - io.MultiType.Input("match", [io.Image, io.Mask], tooltip="Resize input to match the dimensions of this reference image or mask."), - crop_combo, + io.DynamicCombo.Option(ResizeType.MATCH_SIZE, [ + io.MultiType.Input("match", [io.Image, io.Mask]), + crop_combo, ]), - io.DynamicCombo.Option(ResizeType.SCALE_TO_MULTIPLE, [ - io.Int.Input("multiple", default=8, min=1, max=MAX_RESOLUTION, step=1, tooltip="Resize so width and height are divisible by this number. Useful for latent alignment (e.g., 8 or 64)."), + io.DynamicCombo.Option(ResizeType.SCALE_TO_MULTIPLE, [ + io.Int.Input("multiple", default=8, min=1, max=MAX_RESOLUTION, step=1), ]), - ], - ), - io.Combo.Input( - "scale_method", - options=cls.scale_methods, - default="area", - tooltip="Interpolation algorithm. 'area' is best for downscaling, 'lanczos' for upscaling, 'nearest-exact' for pixel art.", - ), + ]), + io.Combo.Input("scale_method", options=cls.scale_methods, default="area"), ], outputs=[io.MatchType.Output(template=template, display_name="resized")] ) From 79cdbc81cb552b363430d1e88c98c4b4b4b4cf62 Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Thu, 22 Jan 2026 22:04:27 -0800 Subject: [PATCH 43/58] feat: Improve ResizeImageMaskNode UX with tooltips and search aliases (#12040) - Add search_aliases for discoverability: resize, scale, dimensions, etc. - Add node description for hover tooltip - Add tooltips to all inputs explaining their behavior - Reorder options: most common (scale dimensions) first, most technical (scale to multiple) last Addresses user feedback that 'resize' search returned nothing useful and options like 'match size' and 'scale to multiple' were not self-explanatory. --- comfy_extras/nodes_post_processing.py | 67 ++++++++++++++++----------- 1 file changed, 41 insertions(+), 26 deletions(-) diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index ab002daca..a52a90e2c 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -420,47 +420,62 @@ class ResizeImageMaskNode(io.ComfyNode): @classmethod def define_schema(cls): template = io.MatchType.Template("input_type", [io.Image, io.Mask]) - crop_combo = io.Combo.Input("crop", options=cls.crop_methods, default="center") + crop_combo = io.Combo.Input( + "crop", + options=cls.crop_methods, + default="center", + tooltip="How to handle aspect ratio mismatch: 'disabled' stretches to fit, 'center' crops to maintain aspect ratio.", + ) return io.Schema( node_id="ResizeImageMaskNode", - search_aliases=["scale image", "scale mask"], display_name="Resize Image/Mask", + description="Resize an image or mask using various scaling methods.", category="transform", + search_aliases=["resize", "resize image", "resize mask", "scale", "scale image", "scale mask", "image resize", "change size", "dimensions", "shrink", "enlarge"], inputs=[ io.MatchType.Input("input", template=template), - io.DynamicCombo.Input("resize_type", options=[ - io.DynamicCombo.Option(ResizeType.SCALE_BY, [ - io.Float.Input("multiplier", default=1.00, min=0.01, max=8.0, step=0.01), + io.DynamicCombo.Input( + "resize_type", + tooltip="Select how to resize: by exact dimensions, scale factor, matching another image, etc.", + options=[ + io.DynamicCombo.Option(ResizeType.SCALE_DIMENSIONS, [ + io.Int.Input("width", default=512, min=0, max=MAX_RESOLUTION, step=1, tooltip="Target width in pixels. Set to 0 to auto-calculate from height while preserving aspect ratio."), + io.Int.Input("height", default=512, min=0, max=MAX_RESOLUTION, step=1, tooltip="Target height in pixels. Set to 0 to auto-calculate from width while preserving aspect ratio."), + crop_combo, ]), - io.DynamicCombo.Option(ResizeType.SCALE_DIMENSIONS, [ - io.Int.Input("width", default=512, min=0, max=MAX_RESOLUTION, step=1), - io.Int.Input("height", default=512, min=0, max=MAX_RESOLUTION, step=1), - crop_combo, + io.DynamicCombo.Option(ResizeType.SCALE_BY, [ + io.Float.Input("multiplier", default=1.00, min=0.01, max=8.0, step=0.01, tooltip="Scale factor (e.g., 2.0 doubles size, 0.5 halves size)."), ]), - io.DynamicCombo.Option(ResizeType.SCALE_LONGER_DIMENSION, [ - io.Int.Input("longer_size", default=512, min=0, max=MAX_RESOLUTION, step=1), + io.DynamicCombo.Option(ResizeType.SCALE_LONGER_DIMENSION, [ + io.Int.Input("longer_size", default=512, min=0, max=MAX_RESOLUTION, step=1, tooltip="The longer edge will be resized to this value. Aspect ratio is preserved."), ]), - io.DynamicCombo.Option(ResizeType.SCALE_SHORTER_DIMENSION, [ - io.Int.Input("shorter_size", default=512, min=0, max=MAX_RESOLUTION, step=1), + io.DynamicCombo.Option(ResizeType.SCALE_SHORTER_DIMENSION, [ + io.Int.Input("shorter_size", default=512, min=0, max=MAX_RESOLUTION, step=1, tooltip="The shorter edge will be resized to this value. Aspect ratio is preserved."), ]), - io.DynamicCombo.Option(ResizeType.SCALE_WIDTH, [ - io.Int.Input("width", default=512, min=0, max=MAX_RESOLUTION, step=1), + io.DynamicCombo.Option(ResizeType.SCALE_WIDTH, [ + io.Int.Input("width", default=512, min=0, max=MAX_RESOLUTION, step=1, tooltip="Target width in pixels. Height auto-adjusts to preserve aspect ratio."), ]), - io.DynamicCombo.Option(ResizeType.SCALE_HEIGHT, [ - io.Int.Input("height", default=512, min=0, max=MAX_RESOLUTION, step=1), + io.DynamicCombo.Option(ResizeType.SCALE_HEIGHT, [ + io.Int.Input("height", default=512, min=0, max=MAX_RESOLUTION, step=1, tooltip="Target height in pixels. Width auto-adjusts to preserve aspect ratio."), ]), - io.DynamicCombo.Option(ResizeType.SCALE_TOTAL_PIXELS, [ - io.Float.Input("megapixels", default=1.0, min=0.01, max=16.0, step=0.01), + io.DynamicCombo.Option(ResizeType.SCALE_TOTAL_PIXELS, [ + io.Float.Input("megapixels", default=1.0, min=0.01, max=16.0, step=0.01, tooltip="Target total megapixels (e.g., 1.0 ≈ 1024×1024). Aspect ratio is preserved."), ]), - io.DynamicCombo.Option(ResizeType.MATCH_SIZE, [ - io.MultiType.Input("match", [io.Image, io.Mask]), - crop_combo, + io.DynamicCombo.Option(ResizeType.MATCH_SIZE, [ + io.MultiType.Input("match", [io.Image, io.Mask], tooltip="Resize input to match the dimensions of this reference image or mask."), + crop_combo, ]), - io.DynamicCombo.Option(ResizeType.SCALE_TO_MULTIPLE, [ - io.Int.Input("multiple", default=8, min=1, max=MAX_RESOLUTION, step=1), + io.DynamicCombo.Option(ResizeType.SCALE_TO_MULTIPLE, [ + io.Int.Input("multiple", default=8, min=1, max=MAX_RESOLUTION, step=1, tooltip="Resize so width and height are divisible by this number. Useful for latent alignment (e.g., 8 or 64)."), ]), - ]), - io.Combo.Input("scale_method", options=cls.scale_methods, default="area"), + ], + ), + io.Combo.Input( + "scale_method", + options=cls.scale_methods, + default="area", + tooltip="Interpolation algorithm. 'area' is best for downscaling, 'lanczos' for upscaling, 'nearest-exact' for pixel art.", + ), ], outputs=[io.MatchType.Output(template=template, display_name="resized")] ) From 55bd606e92ea0a0ef1cc83a7fa4f6decf0128b12 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Fri, 23 Jan 2026 22:26:38 +0200 Subject: [PATCH 44/58] LTX2: Refactor forward function for better VRAM efficiency and fix spatial inpainting (#12046) * Disable timestep embed compression when inpainting Spatial inpainting not compatible with the compression * Reduce crossattn peak VRAM * LTX2: Refactor forward function for better VRAM efficiency --- comfy/ldm/lightricks/av_model.py | 230 +++++++++++++------------------ 1 file changed, 94 insertions(+), 136 deletions(-) diff --git a/comfy/ldm/lightricks/av_model.py b/comfy/ldm/lightricks/av_model.py index c12ace241..2c6954ecd 100644 --- a/comfy/ldm/lightricks/av_model.py +++ b/comfy/ldm/lightricks/av_model.py @@ -18,12 +18,12 @@ class CompressedTimestep: def __init__(self, tensor: torch.Tensor, patches_per_frame: int): """ tensor: [batch_size, num_tokens, feature_dim] tensor where num_tokens = num_frames * patches_per_frame - patches_per_frame: Number of spatial patches per frame (height * width in latent space) + patches_per_frame: Number of spatial patches per frame (height * width in latent space), or None to disable compression """ self.batch_size, num_tokens, self.feature_dim = tensor.shape # Check if compression is valid (num_tokens must be divisible by patches_per_frame) - if num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame: + if patches_per_frame is not None and num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame: self.patches_per_frame = patches_per_frame self.num_frames = num_tokens // patches_per_frame @@ -215,22 +215,9 @@ class BasicAVTransformerBlock(nn.Module): return (*scale_shift_ada_values, *gate_ada_values) def forward( - self, - x: Tuple[torch.Tensor, torch.Tensor], - v_context=None, - a_context=None, - attention_mask=None, - v_timestep=None, - a_timestep=None, - v_pe=None, - a_pe=None, - v_cross_pe=None, - a_cross_pe=None, - v_cross_scale_shift_timestep=None, - a_cross_scale_shift_timestep=None, - v_cross_gate_timestep=None, - a_cross_gate_timestep=None, - transformer_options=None, + self, x: Tuple[torch.Tensor, torch.Tensor], v_context=None, a_context=None, attention_mask=None, v_timestep=None, a_timestep=None, + v_pe=None, a_pe=None, v_cross_pe=None, a_cross_pe=None, v_cross_scale_shift_timestep=None, a_cross_scale_shift_timestep=None, + v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None, ) -> Tuple[torch.Tensor, torch.Tensor]: run_vx = transformer_options.get("run_vx", True) run_ax = transformer_options.get("run_ax", True) @@ -240,144 +227,102 @@ class BasicAVTransformerBlock(nn.Module): run_a2v = run_vx and transformer_options.get("a2v_cross_attn", True) and ax.numel() > 0 run_v2a = run_ax and transformer_options.get("v2a_cross_attn", True) + # video if run_vx: - vshift_msa, vscale_msa, vgate_msa = ( - self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 3)) - ) - + # video self-attention + vshift_msa, vscale_msa = (self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 2))) norm_vx = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_msa) + vshift_msa - vx += self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options) * vgate_msa - vx += self.attn2( - comfy.ldm.common_dit.rms_norm(vx), - context=v_context, - mask=attention_mask, - transformer_options=transformer_options, - ) - - del vshift_msa, vscale_msa, vgate_msa + del vshift_msa, vscale_msa + attn1_out = self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options) + del norm_vx + # video cross-attention + vgate_msa = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(2, 3))[0] + vx.addcmul_(attn1_out, vgate_msa) + del vgate_msa, attn1_out + vx.add_(self.attn2(comfy.ldm.common_dit.rms_norm(vx), context=v_context, mask=attention_mask, transformer_options=transformer_options)) + # audio if run_ax: - ashift_msa, ascale_msa, agate_msa = ( - self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(0, 3)) - ) - + # audio self-attention + ashift_msa, ascale_msa = (self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(0, 2))) norm_ax = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_msa) + ashift_msa - ax += ( - self.audio_attn1(norm_ax, pe=a_pe, transformer_options=transformer_options) - * agate_msa - ) - ax += self.audio_attn2( - comfy.ldm.common_dit.rms_norm(ax), - context=a_context, - mask=attention_mask, - transformer_options=transformer_options, - ) + del ashift_msa, ascale_msa + attn1_out = self.audio_attn1(norm_ax, pe=a_pe, transformer_options=transformer_options) + del norm_ax + # audio cross-attention + agate_msa = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(2, 3))[0] + ax.addcmul_(attn1_out, agate_msa) + del agate_msa, attn1_out + ax.add_(self.audio_attn2(comfy.ldm.common_dit.rms_norm(ax), context=a_context, mask=attention_mask, transformer_options=transformer_options)) - del ashift_msa, ascale_msa, agate_msa - - # Audio - Video cross attention. + # video - audio cross attention. if run_a2v or run_v2a: - # norm3 vx_norm3 = comfy.ldm.common_dit.rms_norm(vx) ax_norm3 = comfy.ldm.common_dit.rms_norm(ax) - ( - scale_ca_audio_hidden_states_a2v, - shift_ca_audio_hidden_states_a2v, - scale_ca_audio_hidden_states_v2a, - shift_ca_audio_hidden_states_v2a, - gate_out_v2a, - ) = self.get_av_ca_ada_values( - self.scale_shift_table_a2v_ca_audio, - ax.shape[0], - a_cross_scale_shift_timestep, - a_cross_gate_timestep, - ) - - ( - scale_ca_video_hidden_states_a2v, - shift_ca_video_hidden_states_a2v, - scale_ca_video_hidden_states_v2a, - shift_ca_video_hidden_states_v2a, - gate_out_a2v, - ) = self.get_av_ca_ada_values( - self.scale_shift_table_a2v_ca_video, - vx.shape[0], - v_cross_scale_shift_timestep, - v_cross_gate_timestep, - ) - + # audio to video cross attention if run_a2v: - vx_scaled = ( - vx_norm3 * (1 + scale_ca_video_hidden_states_a2v) - + shift_ca_video_hidden_states_a2v - ) - ax_scaled = ( - ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v) - + shift_ca_audio_hidden_states_a2v - ) - vx += ( - self.audio_to_video_attn( - vx_scaled, - context=ax_scaled, - pe=v_cross_pe, - k_pe=a_cross_pe, - transformer_options=transformer_options, - ) - * gate_out_a2v - ) + scale_ca_audio_hidden_states_a2v, shift_ca_audio_hidden_states_a2v = self.get_ada_values( + self.scale_shift_table_a2v_ca_audio[:4, :], ax.shape[0], a_cross_scale_shift_timestep)[:2] + scale_ca_video_hidden_states_a2v_v, shift_ca_video_hidden_states_a2v_v = self.get_ada_values( + self.scale_shift_table_a2v_ca_video[:4, :], vx.shape[0], v_cross_scale_shift_timestep)[:2] - del gate_out_a2v - del scale_ca_video_hidden_states_a2v,\ - shift_ca_video_hidden_states_a2v,\ - scale_ca_audio_hidden_states_a2v,\ - shift_ca_audio_hidden_states_a2v,\ + vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_a2v_v) + shift_ca_video_hidden_states_a2v_v + ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v) + shift_ca_audio_hidden_states_a2v + del scale_ca_video_hidden_states_a2v_v, shift_ca_video_hidden_states_a2v_v, scale_ca_audio_hidden_states_a2v, shift_ca_audio_hidden_states_a2v + a2v_out = self.audio_to_video_attn(vx_scaled, context=ax_scaled, pe=v_cross_pe, k_pe=a_cross_pe, transformer_options=transformer_options) + del vx_scaled, ax_scaled + + gate_out_a2v = self.get_ada_values(self.scale_shift_table_a2v_ca_video[4:, :], vx.shape[0], v_cross_gate_timestep)[0] + vx.addcmul_(a2v_out, gate_out_a2v) + del gate_out_a2v, a2v_out + + # video to audio cross attention if run_v2a: - ax_scaled = ( - ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a) - + shift_ca_audio_hidden_states_v2a - ) - vx_scaled = ( - vx_norm3 * (1 + scale_ca_video_hidden_states_v2a) - + shift_ca_video_hidden_states_v2a - ) - ax += ( - self.video_to_audio_attn( - ax_scaled, - context=vx_scaled, - pe=a_cross_pe, - k_pe=v_cross_pe, - transformer_options=transformer_options, - ) - * gate_out_v2a - ) + scale_ca_audio_hidden_states_v2a, shift_ca_audio_hidden_states_v2a = self.get_ada_values( + self.scale_shift_table_a2v_ca_audio[:4, :], ax.shape[0], a_cross_scale_shift_timestep)[2:4] + scale_ca_video_hidden_states_v2a, shift_ca_video_hidden_states_v2a = self.get_ada_values( + self.scale_shift_table_a2v_ca_video[:4, :], vx.shape[0], v_cross_scale_shift_timestep)[2:4] - del gate_out_v2a - del scale_ca_video_hidden_states_v2a,\ - shift_ca_video_hidden_states_v2a,\ - scale_ca_audio_hidden_states_v2a,\ - shift_ca_audio_hidden_states_v2a + ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a) + shift_ca_audio_hidden_states_v2a + vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_v2a) + shift_ca_video_hidden_states_v2a + del scale_ca_video_hidden_states_v2a, shift_ca_video_hidden_states_v2a, scale_ca_audio_hidden_states_v2a, shift_ca_audio_hidden_states_v2a + v2a_out = self.video_to_audio_attn(ax_scaled, context=vx_scaled, pe=a_cross_pe, k_pe=v_cross_pe, transformer_options=transformer_options) + del ax_scaled, vx_scaled + + gate_out_v2a = self.get_ada_values(self.scale_shift_table_a2v_ca_audio[4:, :], ax.shape[0], a_cross_gate_timestep)[0] + ax.addcmul_(v2a_out, gate_out_v2a) + del gate_out_v2a, v2a_out + + del vx_norm3, ax_norm3 + + # video feedforward if run_vx: - vshift_mlp, vscale_mlp, vgate_mlp = ( - self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(3, None)) - ) - + vshift_mlp, vscale_mlp = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(3, 5)) vx_scaled = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_mlp) + vshift_mlp - vx += self.ff(vx_scaled) * vgate_mlp - del vshift_mlp, vscale_mlp, vgate_mlp + del vshift_mlp, vscale_mlp + ff_out = self.ff(vx_scaled) + del vx_scaled + + vgate_mlp = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(5, 6))[0] + vx.addcmul_(ff_out, vgate_mlp) + del vgate_mlp, ff_out + + # audio feedforward if run_ax: - ashift_mlp, ascale_mlp, agate_mlp = ( - self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(3, None)) - ) - + ashift_mlp, ascale_mlp = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(3, 5)) ax_scaled = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_mlp) + ashift_mlp - ax += self.audio_ff(ax_scaled) * agate_mlp + del ashift_mlp, ascale_mlp - del ashift_mlp, ascale_mlp, agate_mlp + ff_out = self.audio_ff(ax_scaled) + del ax_scaled + agate_mlp = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(5, 6))[0] + ax.addcmul_(ff_out, agate_mlp) + del agate_mlp, ff_out return vx, ax @@ -589,9 +534,20 @@ class LTXAVModel(LTXVModel): audio_length = kwargs.get("audio_length", 0) # Separate audio and video latents vx, ax = self.separate_audio_and_video_latents(x, audio_length) + + has_spatial_mask = False + if denoise_mask is not None: + # check if any frame has spatial variation (inpainting) + for frame_idx in range(denoise_mask.shape[2]): + frame_mask = denoise_mask[0, 0, frame_idx] + if frame_mask.numel() > 0 and frame_mask.min() != frame_mask.max(): + has_spatial_mask = True + break + [vx, v_pixel_coords, additional_args] = super()._process_input( vx, keyframe_idxs, denoise_mask, **kwargs ) + additional_args["has_spatial_mask"] = has_spatial_mask ax, a_latent_coords = self.a_patchifier.patchify(ax) ax = self.audio_patchify_proj(ax) @@ -618,8 +574,9 @@ class LTXAVModel(LTXVModel): # Calculate patches_per_frame from orig_shape: [batch, channels, frames, height, width] # Video tokens are arranged as (frames * height * width), so patches_per_frame = height * width orig_shape = kwargs.get("orig_shape") + has_spatial_mask = kwargs.get("has_spatial_mask", None) v_patches_per_frame = None - if orig_shape is not None and len(orig_shape) == 5: + if not has_spatial_mask and orig_shape is not None and len(orig_shape) == 5: # orig_shape[3] = height, orig_shape[4] = width (in latent space) v_patches_per_frame = orig_shape[3] * orig_shape[4] @@ -662,10 +619,11 @@ class LTXAVModel(LTXVModel): ) # Compress cross-attention timesteps (only video side, audio is too small to benefit) + # v_patches_per_frame is None for spatial masks, set for temporal masks or no mask cross_av_timestep_ss = [ av_ca_audio_scale_shift_timestep.view(batch_size, -1, av_ca_audio_scale_shift_timestep.shape[-1]), - CompressedTimestep(av_ca_video_scale_shift_timestep.view(batch_size, -1, av_ca_video_scale_shift_timestep.shape[-1]), v_patches_per_frame), # video - compressed - CompressedTimestep(av_ca_a2v_gate_noise_timestep.view(batch_size, -1, av_ca_a2v_gate_noise_timestep.shape[-1]), v_patches_per_frame), # video - compressed + CompressedTimestep(av_ca_video_scale_shift_timestep.view(batch_size, -1, av_ca_video_scale_shift_timestep.shape[-1]), v_patches_per_frame), # video - compressed if possible + CompressedTimestep(av_ca_a2v_gate_noise_timestep.view(batch_size, -1, av_ca_a2v_gate_noise_timestep.shape[-1]), v_patches_per_frame), # video - compressed if possible av_ca_v2a_gate_noise_timestep.view(batch_size, -1, av_ca_v2a_gate_noise_timestep.shape[-1]), ] From e89b22993aa2e2b27f4ab1585754cee3a7ca1ff5 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Sat, 24 Jan 2026 04:27:49 +0800 Subject: [PATCH 45/58] Support ModelScope-Trainer/DiffSynth LoRA format for Flux.2 Klein models (#12042) --- comfy/lora.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy/lora.py b/comfy/lora.py index e8246bd66..7b31d055c 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -260,6 +260,7 @@ def model_lora_keys_unet(model, key_map={}): key_map["transformer.{}".format(k[:-len(".weight")])] = to #simpletrainer and probably regular diffusers flux lora format key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #onetrainer + key_map[k[:-len(".weight")]] = to #DiffSynth lora format for k in sdk: hidden_size = model.model_config.unet_config.get("hidden_size", 0) if k.endswith(".weight") and ".linear1." in k: From 9cf299a9f9488e4cb9b3f7cef3bc94c185c19f73 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 23 Jan 2026 16:50:48 -0800 Subject: [PATCH 46/58] Make regular empty latent node work properly on flux 2 variants. (#12050) --- comfy/latent_formats.py | 3 +++ comfy/sample.py | 12 +++++++++--- comfy_extras/nodes_custom_sampler.py | 6 ++++-- comfy_extras/nodes_sd3.py | 2 +- nodes.py | 5 +++-- 5 files changed, 20 insertions(+), 8 deletions(-) diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index cb4f52ce1..5600825ed 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -8,6 +8,7 @@ class LatentFormat: latent_rgb_factors_bias = None latent_rgb_factors_reshape = None taesd_decoder_name = None + spacial_downscale_ratio = 8 def process_in(self, latent): return latent * self.scale_factor @@ -181,6 +182,7 @@ class Flux(SD3): class Flux2(LatentFormat): latent_channels = 128 + spacial_downscale_ratio = 16 def __init__(self): self.latent_rgb_factors =[ @@ -749,6 +751,7 @@ class ACEAudio(LatentFormat): class ChromaRadiance(LatentFormat): latent_channels = 3 + spacial_downscale_ratio = 1 def __init__(self): self.latent_rgb_factors = [ diff --git a/comfy/sample.py b/comfy/sample.py index 2f8f3a51c..a2a39b527 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -37,12 +37,18 @@ def prepare_noise(latent_image, seed, noise_inds=None): return noises -def fix_empty_latent_channels(model, latent_image): +def fix_empty_latent_channels(model, latent_image, downscale_ratio_spacial=None): if latent_image.is_nested: return latent_image latent_format = model.get_model_object("latent_format") #Resize the empty latent image so it has the right number of channels - if latent_format.latent_channels != latent_image.shape[1] and torch.count_nonzero(latent_image) == 0: - latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1) + if torch.count_nonzero(latent_image) == 0: + if latent_format.latent_channels != latent_image.shape[1]: + latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1) + if downscale_ratio_spacial is not None: + if downscale_ratio_spacial != latent_format.spacial_downscale_ratio: + ratio = downscale_ratio_spacial / latent_format.spacial_downscale_ratio + latent_image = comfy.utils.common_upscale(latent_image, round(latent_image.shape[-1] * ratio), round(latent_image.shape[-2] * ratio), "nearest-exact", crop="disabled") + if latent_format.latent_dimensions == 3 and latent_image.ndim == 4: latent_image = latent_image.unsqueeze(2) return latent_image diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index 3eb40e937..a4d84ddf7 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -741,7 +741,7 @@ class SamplerCustom(io.ComfyNode): latent = latent_image latent_image = latent["samples"] latent = latent.copy() - latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image) + latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image, latent.get("downscale_ratio_spacial", None)) latent["samples"] = latent_image if not add_noise: @@ -760,6 +760,7 @@ class SamplerCustom(io.ComfyNode): samples = comfy.sample.sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise_seed) out = latent.copy() + out.pop("downscale_ratio_spacial", None) out["samples"] = samples if "x0" in x0_output: x0_out = model.model.process_latent_out(x0_output["x0"].cpu()) @@ -939,7 +940,7 @@ class SamplerCustomAdvanced(io.ComfyNode): latent = latent_image latent_image = latent["samples"] latent = latent.copy() - latent_image = comfy.sample.fix_empty_latent_channels(guider.model_patcher, latent_image) + latent_image = comfy.sample.fix_empty_latent_channels(guider.model_patcher, latent_image, latent.get("downscale_ratio_spacial", None)) latent["samples"] = latent_image noise_mask = None @@ -954,6 +955,7 @@ class SamplerCustomAdvanced(io.ComfyNode): samples = samples.to(comfy.model_management.intermediate_device()) out = latent.copy() + out.pop("downscale_ratio_spacial", None) out["samples"] = samples if "x0" in x0_output: x0_out = guider.model_patcher.model.process_latent_out(x0_output["x0"].cpu()) diff --git a/comfy_extras/nodes_sd3.py b/comfy_extras/nodes_sd3.py index 02e5e7dd8..736213a47 100644 --- a/comfy_extras/nodes_sd3.py +++ b/comfy_extras/nodes_sd3.py @@ -55,7 +55,7 @@ class EmptySD3LatentImage(io.ComfyNode): @classmethod def execute(cls, width, height, batch_size=1) -> io.NodeOutput: latent = torch.zeros([batch_size, 16, height // 8, width // 8], device=comfy.model_management.intermediate_device()) - return io.NodeOutput({"samples":latent}) + return io.NodeOutput({"samples": latent, "downscale_ratio_spacial": 8}) generate = execute # TODO: remove diff --git a/nodes.py b/nodes.py index 158106686..b75247665 100644 --- a/nodes.py +++ b/nodes.py @@ -1230,7 +1230,7 @@ class EmptyLatentImage: def generate(self, width, height, batch_size=1): latent = torch.zeros([batch_size, 4, height // 8, width // 8], device=self.device) - return ({"samples":latent}, ) + return ({"samples": latent, "downscale_ratio_spacial": 8}, ) class LatentFromBatch: @@ -1538,7 +1538,7 @@ class SetLatentNoiseMask: def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False): latent_image = latent["samples"] - latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image) + latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image, latent.get("downscale_ratio_spacial", None)) if disable_noise: noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") @@ -1556,6 +1556,7 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed) out = latent.copy() + out.pop("downscale_ratio_spacial", None) out["samples"] = samples return (out, ) From 4e6a1b66a93ef91848bc4bbf2a84e0ea98efcfc9 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Fri, 23 Jan 2026 16:56:14 -0800 Subject: [PATCH 47/58] speed up and reduce VRAM of QWEN VAE and WAN (less so) (#12036) * ops: introduce autopad for conv3d This works around pytorch missing ability to causal pad as part of the kernel and avoids massive weight duplications for padding. * wan-vae: rework causal padding This currently uses F.pad which takes a full deep copy and is liable to be the VRAM peak. Instead, kick spatial padding back to the op and consolidate the temporal padding with the cat for the cache. * wan-vae: implement zero pad fast path The WAN VAE is also QWEN where it is used single-image. These convolutions are however zero padded 3d convolutions, which means the VAE is actually just 2D down the last element of the conv weight in the temporal dimension. Fast path this, to avoid adding zeros that then just evaporate in convoluton math but cost computation. --- comfy/ldm/wan/vae.py | 27 +++++++++++++++++---------- comfy/ops.py | 10 ++++++---- 2 files changed, 23 insertions(+), 14 deletions(-) diff --git a/comfy/ldm/wan/vae.py b/comfy/ldm/wan/vae.py index 08315f1a8..40e767213 100644 --- a/comfy/ldm/wan/vae.py +++ b/comfy/ldm/wan/vae.py @@ -5,7 +5,7 @@ import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange -from comfy.ldm.modules.diffusionmodules.model import vae_attention +from comfy.ldm.modules.diffusionmodules.model import vae_attention, torch_cat_if_needed import comfy.ops ops = comfy.ops.disable_weight_init @@ -20,22 +20,29 @@ class CausalConv3d(ops.Conv3d): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self._padding = (self.padding[2], self.padding[2], self.padding[1], - self.padding[1], 2 * self.padding[0], 0) - self.padding = (0, 0, 0) + self._padding = 2 * self.padding[0] + self.padding = (0, self.padding[1], self.padding[2]) def forward(self, x, cache_x=None, cache_list=None, cache_idx=None): if cache_list is not None: cache_x = cache_list[cache_idx] cache_list[cache_idx] = None - padding = list(self._padding) - if cache_x is not None and self._padding[4] > 0: - cache_x = cache_x.to(x.device) - x = torch.cat([cache_x, x], dim=2) - padding[4] -= cache_x.shape[2] + if cache_x is None and x.shape[2] == 1: + #Fast path - the op will pad for use by truncating the weight + #and save math on a pile of zeros. + return super().forward(x, autopad="causal_zero") + + if self._padding > 0: + padding_needed = self._padding + if cache_x is not None: + cache_x = cache_x.to(x.device) + padding_needed = max(0, padding_needed - cache_x.shape[2]) + padding_shape = list(x.shape) + padding_shape[2] = padding_needed + padding = torch.zeros(padding_shape, device=x.device, dtype=x.dtype) + x = torch_cat_if_needed([padding, cache_x, x], dim=2) del cache_x - x = F.pad(x, padding) return super().forward(x) diff --git a/comfy/ops.py b/comfy/ops.py index 415c39e92..e406ba7ed 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -203,7 +203,9 @@ class disable_weight_init: def reset_parameters(self): return None - def _conv_forward(self, input, weight, bias, *args, **kwargs): + def _conv_forward(self, input, weight, bias, autopad=None, *args, **kwargs): + if autopad == "causal_zero": + weight = weight[:, :, -input.shape[2]:, :, :] if NVIDIA_MEMORY_CONV_BUG_WORKAROUND and weight.dtype in (torch.float16, torch.bfloat16): out = torch.cudnn_convolution(input, weight, self.padding, self.stride, self.dilation, self.groups, benchmark=False, deterministic=False, allow_tf32=True) if bias is not None: @@ -212,15 +214,15 @@ class disable_weight_init: else: return super()._conv_forward(input, weight, bias, *args, **kwargs) - def forward_comfy_cast_weights(self, input): + def forward_comfy_cast_weights(self, input, autopad=None): weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) - x = self._conv_forward(input, weight, bias) + x = self._conv_forward(input, weight, bias, autopad=autopad) uncast_bias_weight(self, weight, bias, offload_stream) return x def forward(self, *args, **kwargs): run_every_op() - if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: + if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0 or "autopad" in kwargs: return self.forward_comfy_cast_weights(*args, **kwargs) else: return super().forward(*args, **kwargs) From aef4e135889638812fc1ceab6f323d3441b48f5d Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 24 Jan 2026 16:23:20 -0800 Subject: [PATCH 48/58] Make empty latent node work with other models. (#12062) --- comfy/latent_formats.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 5600825ed..38f18a83f 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -594,6 +594,7 @@ class Wan22(Wan21): class HunyuanImage21(LatentFormat): latent_channels = 64 latent_dimensions = 2 + spacial_downscale_ratio = 32 scale_factor = 0.75289 latent_rgb_factors = [ @@ -727,6 +728,7 @@ class HunyuanVideo15(LatentFormat): latent_rgb_factors_bias = [ 0.0456, -0.0202, -0.0644] latent_channels = 32 latent_dimensions = 3 + spacial_downscale_ratio = 16 scale_factor = 1.03682 taesd_decoder_name = "lighttaehy1_5" From bc72d7f8d11a664bc59941affc05a3f515239171 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sun, 25 Jan 2026 03:10:09 +0200 Subject: [PATCH 49/58] [API Nodes] add TencentHunyuan3D nodes (#12026) * feat(api-nodes): add TencentHunyuan3D nodes * add "(Pro)" to display name --------- Co-authored-by: Jedrzej Kosinski --- comfy_api_nodes/apis/hunyuan3d.py | 66 ++++++ comfy_api_nodes/nodes_hunyuan3d.py | 297 +++++++++++++++++++++++++ comfy_api_nodes/nodes_kling.py | 1 - comfy_api_nodes/nodes_sora.py | 1 - comfy_api_nodes/nodes_topaz.py | 1 - comfy_api_nodes/util/__init__.py | 4 + comfy_api_nodes/util/client.py | 4 +- comfy_api_nodes/util/conversions.py | 15 ++ comfy_api_nodes/util/upload_helpers.py | 22 ++ 9 files changed, 406 insertions(+), 5 deletions(-) create mode 100644 comfy_api_nodes/apis/hunyuan3d.py create mode 100644 comfy_api_nodes/nodes_hunyuan3d.py diff --git a/comfy_api_nodes/apis/hunyuan3d.py b/comfy_api_nodes/apis/hunyuan3d.py new file mode 100644 index 000000000..6421c9bd5 --- /dev/null +++ b/comfy_api_nodes/apis/hunyuan3d.py @@ -0,0 +1,66 @@ +from typing import TypedDict + +from pydantic import BaseModel, Field, model_validator + + +class InputGenerateType(TypedDict): + generate_type: str + polygon_type: str + pbr: bool + + +class Hunyuan3DViewImage(BaseModel): + ViewType: str = Field(..., description="Valid values: back, left, right.") + ViewImageUrl: str = Field(...) + + +class To3DProTaskRequest(BaseModel): + Model: str = Field(...) + Prompt: str | None = Field(None) + ImageUrl: str | None = Field(None) + MultiViewImages: list[Hunyuan3DViewImage] | None = Field(None) + EnablePBR: bool | None = Field(...) + FaceCount: int | None = Field(...) + GenerateType: str | None = Field(...) + PolygonType: str | None = Field(...) + + +class RequestError(BaseModel): + Code: str = Field("") + Message: str = Field("") + + +class To3DProTaskCreateResponse(BaseModel): + JobId: str | None = Field(None) + Error: RequestError | None = Field(None) + + @model_validator(mode="before") + @classmethod + def unwrap_data(cls, values: dict) -> dict: + if "Response" in values and isinstance(values["Response"], dict): + return values["Response"] + return values + + +class ResultFile3D(BaseModel): + Type: str = Field(...) + Url: str = Field(...) + PreviewImageUrl: str = Field("") + + +class To3DProTaskResultResponse(BaseModel): + ErrorCode: str = Field("") + ErrorMessage: str = Field("") + ResultFile3Ds: list[ResultFile3D] = Field([]) + Status: str = Field(...) + + @model_validator(mode="before") + @classmethod + def unwrap_data(cls, values: dict) -> dict: + if "Response" in values and isinstance(values["Response"], dict): + return values["Response"] + return values + + +class To3DProTaskQueryRequest(BaseModel): + JobId: str = Field(...) diff --git a/comfy_api_nodes/nodes_hunyuan3d.py b/comfy_api_nodes/nodes_hunyuan3d.py new file mode 100644 index 000000000..b3a736643 --- /dev/null +++ b/comfy_api_nodes/nodes_hunyuan3d.py @@ -0,0 +1,297 @@ +import os + +from typing_extensions import override + +from comfy_api.latest import IO, ComfyExtension, Input +from comfy_api_nodes.apis.hunyuan3d import ( + Hunyuan3DViewImage, + InputGenerateType, + ResultFile3D, + To3DProTaskCreateResponse, + To3DProTaskQueryRequest, + To3DProTaskRequest, + To3DProTaskResultResponse, +) +from comfy_api_nodes.util import ( + ApiEndpoint, + download_url_to_bytesio, + downscale_image_tensor_by_max_side, + poll_op, + sync_op, + upload_image_to_comfyapi, + validate_image_dimensions, + validate_string, +) +from folder_paths import get_output_directory + + +def get_glb_obj_from_response(response_objs: list[ResultFile3D]) -> ResultFile3D: + for i in response_objs: + if i.Type.lower() == "glb": + return i + raise ValueError("No GLB file found in response. Please report this to the developers.") + + +class TencentTextToModelNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="TencentTextToModelNode", + display_name="Hunyuan3D: Text to Model (Pro)", + category="api node/3d/Tencent", + inputs=[ + IO.Combo.Input( + "model", + options=["3.0", "3.1"], + tooltip="The LowPoly option is unavailable for the `3.1` model.", + ), + IO.String.Input("prompt", multiline=True, default="", tooltip="Supports up to 1024 characters."), + IO.Int.Input("face_count", default=500000, min=40000, max=1500000), + IO.DynamicCombo.Input( + "generate_type", + options=[ + IO.DynamicCombo.Option("Normal", [IO.Boolean.Input("pbr", default=False)]), + IO.DynamicCombo.Option( + "LowPoly", + [ + IO.Combo.Input("polygon_type", options=["triangle", "quadrilateral"]), + IO.Boolean.Input("pbr", default=False), + ], + ), + IO.DynamicCombo.Option("Geometry", []), + ], + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + ], + outputs=[ + IO.String.Output(display_name="model_file"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["generate_type", "generate_type.pbr", "face_count"]), + expr=""" + ( + $base := widgets.generate_type = "normal" ? 25 : widgets.generate_type = "lowpoly" ? 30 : 15; + $pbr := $lookup(widgets, "generate_type.pbr") ? 10 : 0; + $face := widgets.face_count != 500000 ? 10 : 0; + {"type":"usd","usd": ($base + $pbr + $face) * 0.02} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + model: str, + prompt: str, + face_count: int, + generate_type: InputGenerateType, + seed: int, + ) -> IO.NodeOutput: + _ = seed + validate_string(prompt, field_name="prompt", min_length=1, max_length=1024) + if model == "3.1" and generate_type["generate_type"].lower() == "lowpoly": + raise ValueError("The LowPoly option is currently unavailable for the 3.1 model.") + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/tencent/hunyuan/3d-pro", method="POST"), + response_model=To3DProTaskCreateResponse, + data=To3DProTaskRequest( + Model=model, + Prompt=prompt, + FaceCount=face_count, + GenerateType=generate_type["generate_type"], + EnablePBR=generate_type.get("pbr", None), + PolygonType=generate_type.get("polygon_type", None), + ), + ) + if response.Error: + raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}") + result = await poll_op( + cls, + ApiEndpoint(path="/proxy/tencent/hunyuan/3d-pro/query", method="POST"), + data=To3DProTaskQueryRequest(JobId=response.JobId), + response_model=To3DProTaskResultResponse, + status_extractor=lambda r: r.Status, + ) + model_file = f"hunyuan_model_{response.JobId}.glb" + await download_url_to_bytesio( + get_glb_obj_from_response(result.ResultFile3Ds).Url, + os.path.join(get_output_directory(), model_file), + ) + return IO.NodeOutput(model_file) + + +class TencentImageToModelNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="TencentImageToModelNode", + display_name="Hunyuan3D: Image(s) to Model (Pro)", + category="api node/3d/Tencent", + inputs=[ + IO.Combo.Input( + "model", + options=["3.0", "3.1"], + tooltip="The LowPoly option is unavailable for the `3.1` model.", + ), + IO.Image.Input("image"), + IO.Image.Input("image_left", optional=True), + IO.Image.Input("image_right", optional=True), + IO.Image.Input("image_back", optional=True), + IO.Int.Input("face_count", default=500000, min=40000, max=1500000), + IO.DynamicCombo.Input( + "generate_type", + options=[ + IO.DynamicCombo.Option("Normal", [IO.Boolean.Input("pbr", default=False)]), + IO.DynamicCombo.Option( + "LowPoly", + [ + IO.Combo.Input("polygon_type", options=["triangle", "quadrilateral"]), + IO.Boolean.Input("pbr", default=False), + ], + ), + IO.DynamicCombo.Option("Geometry", []), + ], + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + ], + outputs=[ + IO.String.Output(display_name="model_file"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends( + widgets=["generate_type", "generate_type.pbr", "face_count"], + inputs=["image_left", "image_right", "image_back"], + ), + expr=""" + ( + $base := widgets.generate_type = "normal" ? 25 : widgets.generate_type = "lowpoly" ? 30 : 15; + $multiview := ( + inputs.image_left.connected or inputs.image_right.connected or inputs.image_back.connected + ) ? 10 : 0; + $pbr := $lookup(widgets, "generate_type.pbr") ? 10 : 0; + $face := widgets.face_count != 500000 ? 10 : 0; + {"type":"usd","usd": ($base + $multiview + $pbr + $face) * 0.02} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + model: str, + image: Input.Image, + face_count: int, + generate_type: InputGenerateType, + seed: int, + image_left: Input.Image | None = None, + image_right: Input.Image | None = None, + image_back: Input.Image | None = None, + ) -> IO.NodeOutput: + _ = seed + if model == "3.1" and generate_type["generate_type"].lower() == "lowpoly": + raise ValueError("The LowPoly option is currently unavailable for the 3.1 model.") + validate_image_dimensions(image, min_width=128, min_height=128) + multiview_images = [] + for k, v in { + "left": image_left, + "right": image_right, + "back": image_back, + }.items(): + if v is None: + continue + validate_image_dimensions(v, min_width=128, min_height=128) + multiview_images.append( + Hunyuan3DViewImage( + ViewType=k, + ViewImageUrl=await upload_image_to_comfyapi( + cls, + downscale_image_tensor_by_max_side(v, max_side=4900), + mime_type="image/webp", + total_pixels=24_010_000, + ), + ) + ) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/tencent/hunyuan/3d-pro", method="POST"), + response_model=To3DProTaskCreateResponse, + data=To3DProTaskRequest( + Model=model, + FaceCount=face_count, + GenerateType=generate_type["generate_type"], + ImageUrl=await upload_image_to_comfyapi( + cls, + downscale_image_tensor_by_max_side(image, max_side=4900), + mime_type="image/webp", + total_pixels=24_010_000, + ), + MultiViewImages=multiview_images if multiview_images else None, + EnablePBR=generate_type.get("pbr", None), + PolygonType=generate_type.get("polygon_type", None), + ), + ) + if response.Error: + raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}") + result = await poll_op( + cls, + ApiEndpoint(path="/proxy/tencent/hunyuan/3d-pro/query", method="POST"), + data=To3DProTaskQueryRequest(JobId=response.JobId), + response_model=To3DProTaskResultResponse, + status_extractor=lambda r: r.Status, + ) + model_file = f"hunyuan_model_{response.JobId}.glb" + await download_url_to_bytesio( + get_glb_obj_from_response(result.ResultFile3Ds).Url, + os.path.join(get_output_directory(), model_file), + ) + return IO.NodeOutput(model_file) + + +class TencentHunyuan3DExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + TencentTextToModelNode, + TencentImageToModelNode, + ] + + +async def comfy_entrypoint() -> TencentHunyuan3DExtension: + return TencentHunyuan3DExtension() diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 3ec71530b..739fe1855 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -249,7 +249,6 @@ async def finish_omni_video_task(cls: type[IO.ComfyNode], response: TaskStatusRe ApiEndpoint(path=f"/proxy/kling/v1/videos/omni-video/{response.data.task_id}"), response_model=TaskStatusResponse, status_extractor=lambda r: (r.data.task_status if r.data else None), - max_poll_attempts=160, ) return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) diff --git a/comfy_api_nodes/nodes_sora.py b/comfy_api_nodes/nodes_sora.py index 87e663845..afc18bb25 100644 --- a/comfy_api_nodes/nodes_sora.py +++ b/comfy_api_nodes/nodes_sora.py @@ -149,7 +149,6 @@ class OpenAIVideoSora2(IO.ComfyNode): response_model=Sora2GenerationResponse, status_extractor=lambda x: x.status, poll_interval=8.0, - max_poll_attempts=160, estimated_duration=int(45 * (duration / 4) * model_time_multiplier), ) return IO.NodeOutput( diff --git a/comfy_api_nodes/nodes_topaz.py b/comfy_api_nodes/nodes_topaz.py index c052e7656..8fccde25a 100644 --- a/comfy_api_nodes/nodes_topaz.py +++ b/comfy_api_nodes/nodes_topaz.py @@ -203,7 +203,6 @@ class TopazImageEnhance(IO.ComfyNode): progress_extractor=lambda x: getattr(x, "progress", 0), price_extractor=lambda x: x.credits * 0.08, poll_interval=8.0, - max_poll_attempts=160, estimated_duration=60, ) diff --git a/comfy_api_nodes/util/__init__.py b/comfy_api_nodes/util/__init__.py index 364976000..c3c9ff4bf 100644 --- a/comfy_api_nodes/util/__init__.py +++ b/comfy_api_nodes/util/__init__.py @@ -13,6 +13,7 @@ from .conversions import ( bytesio_to_image_tensor, convert_mask_to_image, downscale_image_tensor, + downscale_image_tensor_by_max_side, image_tensor_pair_to_batch, pil_to_bytesio, resize_mask_to_image, @@ -33,6 +34,7 @@ from .download_helpers import ( from .upload_helpers import ( upload_audio_to_comfyapi, upload_file_to_comfyapi, + upload_image_to_comfyapi, upload_images_to_comfyapi, upload_video_to_comfyapi, ) @@ -61,6 +63,7 @@ __all__ = [ # Upload helpers "upload_audio_to_comfyapi", "upload_file_to_comfyapi", + "upload_image_to_comfyapi", "upload_images_to_comfyapi", "upload_video_to_comfyapi", # Download helpers @@ -75,6 +78,7 @@ __all__ = [ "bytesio_to_image_tensor", "convert_mask_to_image", "downscale_image_tensor", + "downscale_image_tensor_by_max_side", "image_tensor_pair_to_batch", "pil_to_bytesio", "resize_mask_to_image", diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py index f372ec7b5..8a1259506 100644 --- a/comfy_api_nodes/util/client.py +++ b/comfy_api_nodes/util/client.py @@ -141,7 +141,7 @@ async def poll_op( queued_statuses: list[str | int] | None = None, data: BaseModel | None = None, poll_interval: float = 5.0, - max_poll_attempts: int = 120, + max_poll_attempts: int = 160, timeout_per_poll: float = 120.0, max_retries_per_poll: int = 3, retry_delay_per_poll: float = 1.0, @@ -238,7 +238,7 @@ async def poll_op_raw( queued_statuses: list[str | int] | None = None, data: dict[str, Any] | BaseModel | None = None, poll_interval: float = 5.0, - max_poll_attempts: int = 120, + max_poll_attempts: int = 160, timeout_per_poll: float = 120.0, max_retries_per_poll: int = 3, retry_delay_per_poll: float = 1.0, diff --git a/comfy_api_nodes/util/conversions.py b/comfy_api_nodes/util/conversions.py index 546741b7b..0e15a0efe 100644 --- a/comfy_api_nodes/util/conversions.py +++ b/comfy_api_nodes/util/conversions.py @@ -144,6 +144,21 @@ def downscale_image_tensor(image: torch.Tensor, total_pixels: int = 1536 * 1024) return s +def downscale_image_tensor_by_max_side(image: torch.Tensor, *, max_side: int) -> torch.Tensor: + """Downscale input image tensor so the largest dimension is at most max_side pixels.""" + samples = image.movedim(-1, 1) + height, width = samples.shape[2], samples.shape[3] + max_dim = max(width, height) + if max_dim <= max_side: + return image + scale_by = max_side / max_dim + new_width = round(width * scale_by) + new_height = round(height * scale_by) + s = common_upscale(samples, new_width, new_height, "lanczos", "disabled") + s = s.movedim(1, -1) + return s + + def tensor_to_data_uri( image_tensor: torch.Tensor, total_pixels: int = 2048 * 2048, diff --git a/comfy_api_nodes/util/upload_helpers.py b/comfy_api_nodes/util/upload_helpers.py index 2794be35c..2190f9639 100644 --- a/comfy_api_nodes/util/upload_helpers.py +++ b/comfy_api_nodes/util/upload_helpers.py @@ -88,6 +88,28 @@ async def upload_images_to_comfyapi( return download_urls +async def upload_image_to_comfyapi( + cls: type[IO.ComfyNode], + image: torch.Tensor, + *, + mime_type: str | None = None, + wait_label: str | None = "Uploading", + total_pixels: int = 2048 * 2048, +) -> str: + """Uploads a single image to ComfyUI API and returns its download URL.""" + return ( + await upload_images_to_comfyapi( + cls, + image, + max_images=1, + mime_type=mime_type, + wait_label=wait_label, + show_batch_index=False, + total_pixels=total_pixels, + ) + )[0] + + async def upload_audio_to_comfyapi( cls: type[IO.ComfyNode], audio: Input.Audio, From ed6002cb60e0709a493d4d8f56793ce0bce12e7e Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Sat, 24 Jan 2026 17:30:40 -0800 Subject: [PATCH 50/58] add support for kwargs inputs to allow arbitrary inputs from frontend (#12063) used to output selected combo index Co-authored-by: Jedrzej Kosinski --- comfy_api/latest/_io.py | 12 ++++++++++++ comfy_extras/nodes_logic.py | 12 ++++++++---- execution.py | 2 +- 3 files changed, 21 insertions(+), 5 deletions(-) diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 2ec8d6e4b..03c77a531 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -1383,6 +1383,8 @@ class Schema: """Flags a node as not idempotent; when True, the node will run and not reuse the cached outputs when identical inputs are provided on a different node in the graph.""" enable_expand: bool=False """Flags a node as expandable, allowing NodeOutput to include 'expand' property.""" + accept_all_inputs: bool=False + """When True, all inputs from the prompt will be passed to the node as kwargs, even if not defined in the schema.""" def validate(self): '''Validate the schema: @@ -1853,6 +1855,14 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal): cls.GET_SCHEMA() return cls._NOT_IDEMPOTENT + _ACCEPT_ALL_INPUTS = None + @final + @classproperty + def ACCEPT_ALL_INPUTS(cls): # noqa + if cls._ACCEPT_ALL_INPUTS is None: + cls.GET_SCHEMA() + return cls._ACCEPT_ALL_INPUTS + @final @classmethod def INPUT_TYPES(cls) -> dict[str, dict]: @@ -1891,6 +1901,8 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal): cls._INPUT_IS_LIST = schema.is_input_list if cls._NOT_IDEMPOTENT is None: cls._NOT_IDEMPOTENT = schema.not_idempotent + if cls._ACCEPT_ALL_INPUTS is None: + cls._ACCEPT_ALL_INPUTS = schema.accept_all_inputs if cls._RETURN_TYPES is None: output = [] diff --git a/comfy_extras/nodes_logic.py b/comfy_extras/nodes_logic.py index 1ed060205..c066064ac 100644 --- a/comfy_extras/nodes_logic.py +++ b/comfy_extras/nodes_logic.py @@ -104,19 +104,23 @@ class CustomComboNode(io.ComfyNode): category="utils", is_experimental=True, inputs=[io.Combo.Input("choice", options=[])], - outputs=[io.String.Output()] + outputs=[ + io.String.Output(display_name="STRING"), + io.Int.Output(display_name="INDEX"), + ], + accept_all_inputs=True, ) @classmethod - def validate_inputs(cls, choice: io.Combo.Type) -> bool: + def validate_inputs(cls, choice: io.Combo.Type, index: int = 0, **kwargs) -> bool: # NOTE: DO NOT DO THIS unless you want to skip validation entirely on the node's inputs. # I am doing that here because the widgets (besides the combo dropdown) on this node are fully frontend defined. # I need to skip checking that the chosen combo option is in the options list, since those are defined by the user. return True @classmethod - def execute(cls, choice: io.Combo.Type) -> io.NodeOutput: - return io.NodeOutput(choice) + def execute(cls, choice: io.Combo.Type, index: int = 0, **kwargs) -> io.NodeOutput: + return io.NodeOutput(choice, index) class DCTestNode(io.ComfyNode): diff --git a/execution.py b/execution.py index 648f204ec..4b4f63c80 100644 --- a/execution.py +++ b/execution.py @@ -175,7 +175,7 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt= continue obj = cached.outputs[output_index] input_data_all[x] = obj - elif input_category is not None: + elif input_category is not None or (is_v3 and class_def.ACCEPT_ALL_INPUTS): input_data_all[x] = [input_data] if is_v3: From 635406e283e9c0c8964f2fde3ff1ff4a8b31201e Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 24 Jan 2026 19:32:28 -0800 Subject: [PATCH 51/58] Only enable fp16 on z image models that actually support it. (#12065) --- comfy/ldm/lumina/model.py | 1 + comfy/model_detection.py | 4 ++++ comfy/supported_models.py | 2 +- 3 files changed, 6 insertions(+), 1 deletion(-) diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index b114d9e31..77d1abc97 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -451,6 +451,7 @@ class NextDiT(nn.Module): device=None, dtype=None, operations=None, + **kwargs, ) -> None: super().__init__() self.dtype = dtype diff --git a/comfy/model_detection.py b/comfy/model_detection.py index b29a033cc..8cea16e50 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -444,6 +444,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["ffn_dim_multiplier"] = (8.0 / 3.0) dit_config["z_image_modulation"] = True dit_config["time_scale"] = 1000.0 + try: + dit_config["allow_fp16"] = torch.std(state_dict['{}layers.{}.ffn_norm1.weight'.format(key_prefix, dit_config["n_layers"] - 2)], unbiased=False).item() < 0.42 + except Exception: + pass if '{}cap_pad_token'.format(key_prefix) in state_dict_keys: dit_config["pad_tokens_multiple"] = 32 sig_weight = state_dict.get('{}siglip_embedder.0.weight'.format(key_prefix), None) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 45d913fa6..d25271d6e 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1093,7 +1093,7 @@ class ZImage(Lumina2): def __init__(self, unet_config): super().__init__(unet_config) - if comfy.model_management.extended_fp16_support(): + if comfy.model_management.extended_fp16_support() and unet_config.get("allow_fp16", False): self.supported_inference_dtypes = self.supported_inference_dtypes.copy() self.supported_inference_dtypes.insert(1, torch.float16) From a97c98068f6301b1f87ce89e7bd942ee2db3155d Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sun, 25 Jan 2026 11:56:22 +0800 Subject: [PATCH 52/58] [Weight-adapter/Trainer] Bypass forward mode in Weight adapter system (#11958) * Add API of bypass forward module * bypass implementation * add bypass fwd into nodes list/trainer --- comfy/sd.py | 100 +++++++ comfy/weight_adapter/__init__.py | 8 + comfy/weight_adapter/base.py | 231 +++++++++++++++- comfy/weight_adapter/boft.py | 119 ++++++++- comfy/weight_adapter/bypass.py | 437 +++++++++++++++++++++++++++++++ comfy/weight_adapter/glora.py | 219 +++++++++++++++- comfy/weight_adapter/loha.py | 186 +++++++++++-- comfy/weight_adapter/lokr.py | 311 ++++++++++++++++++++-- comfy/weight_adapter/lora.py | 165 +++++++++++- comfy/weight_adapter/oft.py | 186 ++++++++++++- comfy_extras/nodes_train.py | 111 +++++++- nodes.py | 67 +++++ 12 files changed, 2039 insertions(+), 101 deletions(-) create mode 100644 comfy/weight_adapter/bypass.py diff --git a/comfy/sd.py b/comfy/sd.py index ce7e6bcff..f627f7d55 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -20,6 +20,7 @@ import comfy.ldm.ace.vae.music_dcae_pipeline import comfy.ldm.hunyuan_video.vae import comfy.ldm.mmaudio.vae.autoencoder import comfy.pixel_space_convert +import comfy.weight_adapter import yaml import math import os @@ -101,6 +102,105 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip): return (new_modelpatcher, new_clip) +def load_bypass_lora_for_models(model, clip, lora, strength_model, strength_clip): + """ + Load LoRA in bypass mode without modifying base model weights. + + Instead of patching weights, this injects the LoRA computation into the + forward pass: output = base_forward(x) + lora_path(x) + + Non-adapter patches (bias diff, weight diff, etc.) are applied as regular patches. + + This is useful for training and when model weights are offloaded. + """ + key_map = {} + if model is not None: + key_map = comfy.lora.model_lora_keys_unet(model.model, key_map) + if clip is not None: + key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map) + + logging.debug(f"[BypassLoRA] key_map has {len(key_map)} entries") + + lora = comfy.lora_convert.convert_lora(lora) + loaded = comfy.lora.load_lora(lora, key_map) + + logging.debug(f"[BypassLoRA] loaded has {len(loaded)} entries") + + # Separate adapters (for bypass) from other patches (for regular patching) + bypass_patches = {} # WeightAdapterBase instances -> bypass mode + regular_patches = {} # diff, set, bias patches -> regular weight patching + + for key, patch_data in loaded.items(): + if isinstance(patch_data, comfy.weight_adapter.WeightAdapterBase): + bypass_patches[key] = patch_data + else: + regular_patches[key] = patch_data + + logging.debug(f"[BypassLoRA] {len(bypass_patches)} bypass adapters, {len(regular_patches)} regular patches") + + k = set() + k1 = set() + + if model is not None: + new_modelpatcher = model.clone() + + # Apply regular patches (bias diff, weight diff, etc.) via normal patching + if regular_patches: + patched_keys = new_modelpatcher.add_patches(regular_patches, strength_model) + k.update(patched_keys) + + # Apply adapter patches via bypass injection + manager = comfy.weight_adapter.BypassInjectionManager() + model_sd_keys = set(new_modelpatcher.model.state_dict().keys()) + + for key, adapter in bypass_patches.items(): + if key in model_sd_keys: + manager.add_adapter(key, adapter, strength=strength_model) + k.add(key) + else: + logging.warning(f"[BypassLoRA] Adapter key not in model state_dict: {key}") + + injections = manager.create_injections(new_modelpatcher.model) + + if manager.get_hook_count() > 0: + new_modelpatcher.set_injections("bypass_lora", injections) + else: + new_modelpatcher = None + + if clip is not None: + new_clip = clip.clone() + + # Apply regular patches to clip + if regular_patches: + patched_keys = new_clip.add_patches(regular_patches, strength_clip) + k1.update(patched_keys) + + # Apply adapter patches via bypass injection + clip_manager = comfy.weight_adapter.BypassInjectionManager() + clip_sd_keys = set(new_clip.cond_stage_model.state_dict().keys()) + + for key, adapter in bypass_patches.items(): + if key in clip_sd_keys: + clip_manager.add_adapter(key, adapter, strength=strength_clip) + k1.add(key) + + clip_injections = clip_manager.create_injections(new_clip.cond_stage_model) + if clip_manager.get_hook_count() > 0: + new_clip.patcher.set_injections("bypass_lora", clip_injections) + else: + new_clip = None + + for x in loaded: + if (x not in k) and (x not in k1): + patch_data = loaded[x] + patch_type = type(patch_data).__name__ + if isinstance(patch_data, tuple): + patch_type = f"tuple({patch_data[0]})" + logging.warning(f"NOT LOADED: {x} (type={patch_type})") + + return (new_modelpatcher, new_clip) + + class CLIP: def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, state_dict=[], model_options={}): if no_init: diff --git a/comfy/weight_adapter/__init__.py b/comfy/weight_adapter/__init__.py index b40f920e4..b9fa8d5cf 100644 --- a/comfy/weight_adapter/__init__.py +++ b/comfy/weight_adapter/__init__.py @@ -5,6 +5,11 @@ from .lokr import LoKrAdapter from .glora import GLoRAAdapter from .oft import OFTAdapter from .boft import BOFTAdapter +from .bypass import ( + BypassInjectionManager, + BypassForwardHook, + create_bypass_injections_from_patches, +) adapters: list[type[WeightAdapterBase]] = [ @@ -31,4 +36,7 @@ __all__ = [ "WeightAdapterTrainBase", "adapters", "adapter_maps", + "BypassInjectionManager", + "BypassForwardHook", + "create_bypass_injections_from_patches", ] + [a.__name__ for a in adapters] diff --git a/comfy/weight_adapter/base.py b/comfy/weight_adapter/base.py index 43644b106..bce89a0e2 100644 --- a/comfy/weight_adapter/base.py +++ b/comfy/weight_adapter/base.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Callable, Optional import torch import torch.nn as nn @@ -7,12 +7,35 @@ import comfy.model_management class WeightAdapterBase: + """ + Base class for weight adapters (LoRA, LoHa, LoKr, OFT, etc.) + + Bypass Mode: + All adapters follow the pattern: bypass(f)(x) = g(f(x) + h(x)) + + - h(x): Additive component (LoRA path). Returns delta to add to base output. + - g(y): Output transformation. Applied after base + h(x). + + For LoRA/LoHa/LoKr: g = identity, h = adapter(x) + For OFT/BOFT: g = transform, h = 0 + """ + name: str loaded_keys: set[str] weights: list[torch.Tensor] + # Attributes set by bypass system + multiplier: float = 1.0 + shape: tuple = None # (out_features, in_features) or (out_ch, in_ch, *kernel) + @classmethod - def load(cls, x: str, lora: dict[str, torch.Tensor], alpha: float, dora_scale: torch.Tensor) -> Optional["WeightAdapterBase"]: + def load( + cls, + x: str, + lora: dict[str, torch.Tensor], + alpha: float, + dora_scale: torch.Tensor, + ) -> Optional["WeightAdapterBase"]: raise NotImplementedError def to_train(self) -> "WeightAdapterTrainBase": @@ -39,18 +62,202 @@ class WeightAdapterBase: ): raise NotImplementedError + # ===== Bypass Mode Methods ===== + # + # IMPORTANT: Bypass mode is designed for quantized models where original weights + # may not be accessible in a usable format. Therefore, h() and bypass_forward() + # do NOT take org_weight as a parameter. All necessary information (out_channels, + # in_channels, conv params, etc.) is provided via attributes set by BypassForwardHook. + + def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor: + """ + Additive bypass component: h(x, base_out) + + Computes the adapter's contribution to be added to base forward output. + For adapters that only transform output (OFT/BOFT), returns zeros. + + Note: + This method does NOT access original model weights. Bypass mode is + designed for quantized models where weights may not be in a usable format. + All shape info comes from module attributes set by BypassForwardHook. + + Args: + x: Input tensor + base_out: Output from base forward f(x), can be used for shape reference + + Returns: + Delta tensor to add to base output. Shape matches base output. + + Reference: LyCORIS LoConModule.bypass_forward_diff + """ + # Default: no additive component (for OFT/BOFT) + # Simply return zeros matching base_out shape + return torch.zeros_like(base_out) + + def g(self, y: torch.Tensor) -> torch.Tensor: + """ + Output transformation: g(y) + + Applied after base forward + h(x). For most adapters this is identity. + OFT/BOFT override this to apply orthogonal transformation. + + Args: + y: Combined output (base + h(x)) + + Returns: + Transformed output + + Reference: LyCORIS OFTModule applies orthogonal transform here + """ + # Default: identity (for LoRA/LoHa/LoKr) + return y + + def bypass_forward( + self, + org_forward: Callable, + x: torch.Tensor, + *args, + **kwargs, + ) -> torch.Tensor: + """ + Full bypass forward: g(f(x) + h(x, f(x))) + + Note: + This method does NOT take org_weight/org_bias parameters. Bypass mode + is designed for quantized models where weights may not be accessible. + The original forward function handles weight access internally. + + Args: + org_forward: Original module forward function + x: Input tensor + *args, **kwargs: Additional arguments for org_forward + + Returns: + Output with adapter applied in bypass mode + + Reference: LyCORIS LoConModule.bypass_forward + """ + # Base forward: f(x) + base_out = org_forward(x, *args, **kwargs) + + # Additive component: h(x, base_out) - base_out provided for shape reference + h_out = self.h(x, base_out) + + # Output transformation: g(base + h) + return self.g(base_out + h_out) + class WeightAdapterTrainBase(nn.Module): - # We follow the scheme of PR #7032 + """ + Base class for trainable weight adapters (LoRA, LoHa, LoKr, OFT, etc.) + + Bypass Mode: + All adapters follow the pattern: bypass(f)(x) = g(f(x) + h(x)) + + - h(x): Additive component (LoRA path). Returns delta to add to base output. + - g(y): Output transformation. Applied after base + h(x). + + For LoRA/LoHa/LoKr: g = identity, h = adapter(x) + For OFT: g = transform, h = 0 + + Note: + Unlike WeightAdapterBase, TrainBase classes have simplified weight formats + with fewer branches (e.g., LoKr only has w1/w2, not w1_a/w1_b decomposition). + + We follow the scheme of PR #7032 + """ + + # Attributes set by bypass system (BypassForwardHook) + # These are set before h()/g()/bypass_forward() are called + multiplier: float = 1.0 + is_conv: bool = False + conv_dim: int = 0 # 0=linear, 1=conv1d, 2=conv2d, 3=conv3d + kw_dict: dict = {} # Conv kwargs: stride, padding, dilation, groups + kernel_size: tuple = () + in_channels: int = None + out_channels: int = None + def __init__(self): super().__init__() def __call__(self, w): """ - w: The original weight tensor to be modified. + Weight modification mode: returns modified weight. + + Args: + w: The original weight tensor to be modified. + + Returns: + Modified weight tensor. """ raise NotImplementedError + # ===== Bypass Mode Methods ===== + + def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor: + """ + Additive bypass component: h(x, base_out) + + Computes the adapter's contribution to be added to base forward output. + For adapters that only transform output (OFT), returns zeros. + + Args: + x: Input tensor + base_out: Output from base forward f(x), can be used for shape reference + + Returns: + Delta tensor to add to base output. Shape matches base output. + + Subclasses should override this method. + """ + raise NotImplementedError( + f"{self.__class__.__name__}.h() not implemented. " + "Subclasses must implement h() for bypass mode." + ) + + def g(self, y: torch.Tensor) -> torch.Tensor: + """ + Output transformation: g(y) + + Applied after base forward + h(x). For most adapters this is identity. + OFT overrides this to apply orthogonal transformation. + + Args: + y: Combined output (base + h(x)) + + Returns: + Transformed output + """ + # Default: identity (for LoRA/LoHa/LoKr) + return y + + def bypass_forward( + self, + org_forward: Callable, + x: torch.Tensor, + *args, + **kwargs, + ) -> torch.Tensor: + """ + Full bypass forward: g(f(x) + h(x, f(x))) + + Args: + org_forward: Original module forward function + x: Input tensor + *args, **kwargs: Additional arguments for org_forward + + Returns: + Output with adapter applied in bypass mode + """ + # Base forward: f(x) + base_out = org_forward(x, *args, **kwargs) + + # Additive component: h(x, base_out) - base_out provided for shape reference + h_out = self.h(x, base_out) + + # Output transformation: g(base + h) + return self.g(base_out + h_out) + def passive_memory_usage(self): raise NotImplementedError("passive_memory_usage is not implemented") @@ -59,8 +266,12 @@ class WeightAdapterTrainBase(nn.Module): return self.passive_memory_usage() -def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function): - dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype) +def weight_decompose( + dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function +): + dora_scale = comfy.model_management.cast_to_device( + dora_scale, weight.device, intermediate_dtype + ) lora_diff *= alpha weight_calc = weight + function(lora_diff).type(weight.dtype) @@ -106,10 +317,14 @@ def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Ten the original tensor will be truncated in that dimension. """ if any([new_shape[i] < tensor.shape[i] for i in range(len(new_shape))]): - raise ValueError("The new shape must be larger than the original tensor in all dimensions") + raise ValueError( + "The new shape must be larger than the original tensor in all dimensions" + ) if len(new_shape) != len(tensor.shape): - raise ValueError("The new shape must have the same number of dimensions as the original tensor") + raise ValueError( + "The new shape must have the same number of dimensions as the original tensor" + ) # Create a new tensor filled with zeros padded_tensor = torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) diff --git a/comfy/weight_adapter/boft.py b/comfy/weight_adapter/boft.py index b2a2f1bd4..02a8dc130 100644 --- a/comfy/weight_adapter/boft.py +++ b/comfy/weight_adapter/boft.py @@ -62,9 +62,13 @@ class BOFTAdapter(WeightAdapterBase): alpha = v[2] dora_scale = v[3] - blocks = comfy.model_management.cast_to_device(blocks, weight.device, intermediate_dtype) + blocks = comfy.model_management.cast_to_device( + blocks, weight.device, intermediate_dtype + ) if rescale is not None: - rescale = comfy.model_management.cast_to_device(rescale, weight.device, intermediate_dtype) + rescale = comfy.model_management.cast_to_device( + rescale, weight.device, intermediate_dtype + ) boft_m, block_num, boft_b, *_ = blocks.shape @@ -74,7 +78,7 @@ class BOFTAdapter(WeightAdapterBase): # for Q = -Q^T q = blocks - blocks.transpose(-1, -2) normed_q = q - if alpha > 0: # alpha in boft/bboft is for constraint + if alpha > 0: # alpha in boft/bboft is for constraint q_norm = torch.norm(q) + 1e-8 if q_norm > alpha: normed_q = q * alpha / q_norm @@ -83,13 +87,13 @@ class BOFTAdapter(WeightAdapterBase): r = r.to(weight) inp = org = weight - r_b = boft_b//2 + r_b = boft_b // 2 for i in range(boft_m): bi = r[i] g = 2 k = 2**i * r_b if strength != 1: - bi = bi * strength + (1-strength) * I + bi = bi * strength + (1 - strength) * I inp = ( inp.unflatten(0, (-1, g, k)) .transpose(1, 2) @@ -98,18 +102,117 @@ class BOFTAdapter(WeightAdapterBase): ) inp = torch.einsum("b i j, b j ...-> b i ...", bi, inp) inp = ( - inp.flatten(0, 1).unflatten(0, (-1, k, g)).transpose(1, 2).flatten(0, 2) + inp.flatten(0, 1) + .unflatten(0, (-1, k, g)) + .transpose(1, 2) + .flatten(0, 2) ) if rescale is not None: inp = inp * rescale lora_diff = inp - org - lora_diff = comfy.model_management.cast_to_device(lora_diff, weight.device, intermediate_dtype) + lora_diff = comfy.model_management.cast_to_device( + lora_diff, weight.device, intermediate_dtype + ) if dora_scale is not None: - weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) + weight = weight_decompose( + dora_scale, + weight, + lora_diff, + alpha, + strength, + intermediate_dtype, + function, + ) else: weight += function((strength * lora_diff).type(weight.dtype)) except Exception as e: logging.error("ERROR {} {} {}".format(self.name, key, e)) return weight + + def _get_orthogonal_matrices(self, device, dtype): + """Compute the orthogonal rotation matrices R from BOFT blocks.""" + v = self.weights + blocks = v[0].to(device=device, dtype=dtype) + alpha = v[2] + if alpha is None: + alpha = 0 + + boft_m, block_num, boft_b, _ = blocks.shape + I = torch.eye(boft_b, device=device, dtype=dtype) + + # Q = blocks - blocks^T (skew-symmetric) + q = blocks - blocks.transpose(-1, -2) + normed_q = q + + # Apply constraint if alpha > 0 + if alpha > 0: + q_norm = torch.norm(q) + 1e-8 + if q_norm > alpha: + normed_q = q * alpha / q_norm + + # Cayley transform: R = (I + Q)(I - Q)^-1 + r = (I + normed_q) @ (I - normed_q).float().inverse() + return r, boft_m, boft_b + + def g(self, y: torch.Tensor) -> torch.Tensor: + """ + Output transformation for BOFT: applies butterfly orthogonal transform. + + BOFT uses multiple stages of butterfly-structured orthogonal transforms. + + Reference: LyCORIS ButterflyOFTModule._bypass_forward + """ + v = self.weights + rescale = v[1] + + r, boft_m, boft_b = self._get_orthogonal_matrices(y.device, y.dtype) + r_b = boft_b // 2 + + # Apply multiplier + multiplier = getattr(self, "multiplier", 1.0) + I = torch.eye(boft_b, device=y.device, dtype=y.dtype) + + # Use module info from bypass injection to determine conv vs linear + is_conv = getattr(self, "is_conv", y.dim() > 2) + + if is_conv: + # Conv output: (N, C, H, W, ...) -> transpose to (N, H, W, ..., C) + y = y.transpose(1, -1) + + # Apply butterfly transform stages + inp = y + for i in range(boft_m): + bi = r[i] # (block_num, boft_b, boft_b) + g = 2 + k = 2**i * r_b + + # Interpolate with identity based on multiplier + if multiplier != 1: + bi = bi * multiplier + (1 - multiplier) * I + + # Reshape for butterfly: unflatten last dim, transpose, flatten, unflatten + inp = ( + inp.unflatten(-1, (-1, g, k)) + .transpose(-2, -1) + .flatten(-3) + .unflatten(-1, (-1, boft_b)) + ) + # Apply block-diagonal orthogonal transform + inp = torch.einsum("b i j, ... b j -> ... b i", bi, inp) + # Reshape back + inp = ( + inp.flatten(-2).unflatten(-1, (-1, k, g)).transpose(-2, -1).flatten(-3) + ) + + # Apply rescale if present + if rescale is not None: + rescale = rescale.to(device=y.device, dtype=y.dtype) + inp = inp * rescale.transpose(0, -1) + + if is_conv: + # Transpose back: (N, H, W, ..., C) -> (N, C, H, W, ...) + inp = inp.transpose(1, -1) + + return inp diff --git a/comfy/weight_adapter/bypass.py b/comfy/weight_adapter/bypass.py new file mode 100644 index 000000000..d4aaf98ca --- /dev/null +++ b/comfy/weight_adapter/bypass.py @@ -0,0 +1,437 @@ +""" +Bypass mode implementation for weight adapters (LoRA, LoKr, LoHa, etc.) + +Bypass mode applies adapters during forward pass without modifying base weights: + bypass(f)(x) = g(f(x) + h(x)) + +Where: + - f(x): Original layer forward + - h(x): Additive component from adapter (LoRA path) + - g(y): Output transformation (identity for most adapters) + +This is useful for: + - Training with gradient checkpointing + - Avoiding weight modifications when weights are offloaded + - Supporting multiple adapters with different strengths dynamically +""" + +import logging +from typing import Optional, Union + +import torch +import torch.nn as nn + +from .base import WeightAdapterBase, WeightAdapterTrainBase +from comfy.patcher_extension import PatcherInjection + +# Type alias for adapters that support bypass mode +BypassAdapter = Union[WeightAdapterBase, WeightAdapterTrainBase] + + +def get_module_type_info(module: nn.Module) -> dict: + """ + Determine module type and extract conv parameters from module class. + + This is more reliable than checking weight.ndim, especially for quantized layers + where weight shape might be different. + + Returns: + dict with keys: is_conv, conv_dim, stride, padding, dilation, groups + """ + info = { + "is_conv": False, + "conv_dim": 0, + "stride": (1,), + "padding": (0,), + "dilation": (1,), + "groups": 1, + "kernel_size": (1,), + "in_channels": None, + "out_channels": None, + } + + # Determine conv type + if isinstance(module, nn.Conv1d): + info["is_conv"] = True + info["conv_dim"] = 1 + elif isinstance(module, nn.Conv2d): + info["is_conv"] = True + info["conv_dim"] = 2 + elif isinstance(module, nn.Conv3d): + info["is_conv"] = True + info["conv_dim"] = 3 + elif isinstance(module, nn.Linear): + info["is_conv"] = False + info["conv_dim"] = 0 + else: + # Try to infer from class name for custom/quantized layers + class_name = type(module).__name__.lower() + if "conv3d" in class_name: + info["is_conv"] = True + info["conv_dim"] = 3 + elif "conv2d" in class_name: + info["is_conv"] = True + info["conv_dim"] = 2 + elif "conv1d" in class_name: + info["is_conv"] = True + info["conv_dim"] = 1 + elif "conv" in class_name: + info["is_conv"] = True + info["conv_dim"] = 2 + + # Extract conv parameters if it's a conv layer + if info["is_conv"]: + # Try to get stride, padding, dilation, groups, kernel_size from module + info["stride"] = getattr(module, "stride", (1,) * info["conv_dim"]) + info["padding"] = getattr(module, "padding", (0,) * info["conv_dim"]) + info["dilation"] = getattr(module, "dilation", (1,) * info["conv_dim"]) + info["groups"] = getattr(module, "groups", 1) + info["kernel_size"] = getattr(module, "kernel_size", (1,) * info["conv_dim"]) + info["in_channels"] = getattr(module, "in_channels", None) + info["out_channels"] = getattr(module, "out_channels", None) + + # Ensure they're tuples + if isinstance(info["stride"], int): + info["stride"] = (info["stride"],) * info["conv_dim"] + if isinstance(info["padding"], int): + info["padding"] = (info["padding"],) * info["conv_dim"] + if isinstance(info["dilation"], int): + info["dilation"] = (info["dilation"],) * info["conv_dim"] + if isinstance(info["kernel_size"], int): + info["kernel_size"] = (info["kernel_size"],) * info["conv_dim"] + + return info + + +class BypassForwardHook: + """ + Hook that wraps a layer's forward to apply adapter in bypass mode. + + Stores the original forward and replaces it with bypass version. + + Supports both: + - WeightAdapterBase: Inference adapters (uses self.weights tuple) + - WeightAdapterTrainBase: Training adapters (nn.Module with parameters) + """ + + def __init__( + self, + module: nn.Module, + adapter: BypassAdapter, + multiplier: float = 1.0, + ): + self.module = module + self.adapter = adapter + self.multiplier = multiplier + self.original_forward = None + + # Determine layer type and conv params from module class (works for quantized layers) + module_info = get_module_type_info(module) + + # Set multiplier and layer type info on adapter for use in h() + adapter.multiplier = multiplier + adapter.is_conv = module_info["is_conv"] + adapter.conv_dim = module_info["conv_dim"] + adapter.kernel_size = module_info["kernel_size"] + adapter.in_channels = module_info["in_channels"] + adapter.out_channels = module_info["out_channels"] + # Store kw_dict for conv operations (like LyCORIS extra_args) + if module_info["is_conv"]: + adapter.kw_dict = { + "stride": module_info["stride"], + "padding": module_info["padding"], + "dilation": module_info["dilation"], + "groups": module_info["groups"], + } + else: + adapter.kw_dict = {} + + def _bypass_forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: + """Bypass forward: uses adapter's bypass_forward or default g(f(x) + h(x)) + + Note: + Bypass mode does NOT access original model weights (org_weight). + This is intentional - bypass mode is designed for quantized models + where weights may not be in a usable format. All necessary shape + information is provided via adapter attributes set during inject(). + """ + # Check if adapter has custom bypass_forward (e.g., GLoRA) + adapter_bypass = getattr(self.adapter, "bypass_forward", None) + if adapter_bypass is not None: + # Check if it's overridden (not the base class default) + # Need to check both base classes since adapter could be either type + adapter_type = type(self.adapter) + is_default_bypass = ( + adapter_type.bypass_forward is WeightAdapterBase.bypass_forward + or adapter_type.bypass_forward is WeightAdapterTrainBase.bypass_forward + ) + if not is_default_bypass: + return adapter_bypass(self.original_forward, x, *args, **kwargs) + + # Default bypass: g(f(x) + h(x, f(x))) + base_out = self.original_forward(x, *args, **kwargs) + h_out = self.adapter.h(x, base_out) + return self.adapter.g(base_out + h_out) + + def inject(self): + """Replace module forward with bypass version.""" + if self.original_forward is not None: + logging.debug( + f"[BypassHook] Already injected for {type(self.module).__name__}" + ) + return # Already injected + + # Move adapter weights to module's device to avoid CPU-GPU transfer on every forward + device = None + dtype = None + if hasattr(self.module, "weight") and self.module.weight is not None: + device = self.module.weight.device + dtype = self.module.weight.dtype + elif hasattr(self.module, "W_q"): # Quantized layers might use different attr + device = self.module.W_q.device + dtype = self.module.W_q.dtype + + if device is not None: + self._move_adapter_weights_to_device(device, dtype) + + self.original_forward = self.module.forward + self.module.forward = self._bypass_forward + logging.debug( + f"[BypassHook] Injected bypass forward for {type(self.module).__name__} (adapter={type(self.adapter).__name__})" + ) + + def _move_adapter_weights_to_device(self, device, dtype=None): + """Move adapter weights to specified device to avoid per-forward transfers. + + Handles both: + - WeightAdapterBase: has self.weights tuple of tensors + - WeightAdapterTrainBase: nn.Module with parameters, uses .to() method + """ + adapter = self.adapter + + # Check if adapter is an nn.Module (WeightAdapterTrainBase) + if isinstance(adapter, nn.Module): + # In training mode we don't touch dtype as trainer will handle it + adapter.to(device=device) + logging.debug( + f"[BypassHook] Moved training adapter (nn.Module) to {device}" + ) + return + + # WeightAdapterBase: handle self.weights tuple + if not hasattr(adapter, "weights") or adapter.weights is None: + return + + weights = adapter.weights + if isinstance(weights, (list, tuple)): + new_weights = [] + for w in weights: + if isinstance(w, torch.Tensor): + if dtype is not None: + new_weights.append(w.to(device=device, dtype=dtype)) + else: + new_weights.append(w.to(device=device)) + else: + new_weights.append(w) + adapter.weights = ( + tuple(new_weights) if isinstance(weights, tuple) else new_weights + ) + elif isinstance(weights, torch.Tensor): + if dtype is not None: + adapter.weights = weights.to(device=device, dtype=dtype) + else: + adapter.weights = weights.to(device=device) + + logging.debug(f"[BypassHook] Moved adapter weights to {device}") + + def eject(self): + """Restore original module forward.""" + if self.original_forward is None: + logging.debug(f"[BypassHook] Not injected for {type(self.module).__name__}") + return # Not injected + + self.module.forward = self.original_forward + self.original_forward = None + logging.debug( + f"[BypassHook] Ejected bypass forward for {type(self.module).__name__}" + ) + + +class BypassInjectionManager: + """ + Manages bypass mode injection for a collection of adapters. + + Creates PatcherInjection objects that can be used with ModelPatcher. + + Supports both inference adapters (WeightAdapterBase) and training adapters + (WeightAdapterTrainBase). + + Usage: + manager = BypassInjectionManager() + manager.add_adapter("model.layers.0.self_attn.q_proj", lora_adapter, strength=0.8) + manager.add_adapter("model.layers.0.self_attn.k_proj", lora_adapter, strength=0.8) + + injections = manager.create_injections(model) + model_patcher.set_injections("bypass_lora", injections) + """ + + def __init__(self): + self.adapters: dict[str, tuple[BypassAdapter, float]] = {} + self.hooks: list[BypassForwardHook] = [] + + def add_adapter( + self, + key: str, + adapter: BypassAdapter, + strength: float = 1.0, + ): + """ + Add an adapter for a specific weight key. + + Args: + key: Weight key (e.g., "model.layers.0.self_attn.q_proj.weight") + adapter: The weight adapter (LoRAAdapter, LoKrAdapter, etc.) + strength: Multiplier for adapter effect + """ + # Remove .weight suffix if present for module lookup + module_key = key + if module_key.endswith(".weight"): + module_key = module_key[:-7] + logging.debug( + f"[BypassManager] Stripped .weight suffix: {key} -> {module_key}" + ) + + self.adapters[module_key] = (adapter, strength) + logging.debug( + f"[BypassManager] Added adapter: {module_key} (type={type(adapter).__name__}, strength={strength})" + ) + + def clear_adapters(self): + """Remove all adapters.""" + self.adapters.clear() + + def _get_module_by_key(self, model: nn.Module, key: str) -> Optional[nn.Module]: + """Get a submodule by dot-separated key.""" + parts = key.split(".") + module = model + try: + for i, part in enumerate(parts): + if part.isdigit(): + module = module[int(part)] + else: + module = getattr(module, part) + logging.debug( + f"[BypassManager] Found module for key {key}: {type(module).__name__}" + ) + return module + except (AttributeError, IndexError, KeyError) as e: + logging.error(f"[BypassManager] Failed to find module for key {key}: {e}") + logging.error( + f"[BypassManager] Failed at part index {i}, part={part}, current module type={type(module).__name__}" + ) + return None + + def create_injections(self, model: nn.Module) -> list[PatcherInjection]: + """ + Create PatcherInjection objects for all registered adapters. + + Args: + model: The model to inject into (e.g., model_patcher.model) + + Returns: + List of PatcherInjection objects to use with model_patcher.set_injections() + """ + self.hooks.clear() + + logging.debug( + f"[BypassManager] create_injections called with {len(self.adapters)} adapters" + ) + logging.debug(f"[BypassManager] Model type: {type(model).__name__}") + + for key, (adapter, strength) in self.adapters.items(): + logging.debug(f"[BypassManager] Looking for module: {key}") + module = self._get_module_by_key(model, key) + + if module is None: + logging.warning(f"[BypassManager] Module not found for key {key}") + continue + + if not hasattr(module, "weight"): + logging.warning( + f"[BypassManager] Module {key} has no weight attribute (type={type(module).__name__})" + ) + continue + + logging.debug( + f"[BypassManager] Creating hook for {key} (module type={type(module).__name__}, weight shape={module.weight.shape})" + ) + hook = BypassForwardHook(module, adapter, multiplier=strength) + self.hooks.append(hook) + + logging.debug(f"[BypassManager] Created {len(self.hooks)} hooks") + + # Create single injection that manages all hooks + def inject_all(model_patcher): + logging.debug( + f"[BypassManager] inject_all called, injecting {len(self.hooks)} hooks" + ) + for hook in self.hooks: + hook.inject() + logging.debug( + f"[BypassManager] Injected hook for {type(hook.module).__name__}" + ) + + def eject_all(model_patcher): + logging.debug( + f"[BypassManager] eject_all called, ejecting {len(self.hooks)} hooks" + ) + for hook in self.hooks: + hook.eject() + + return [PatcherInjection(inject=inject_all, eject=eject_all)] + + def get_hook_count(self) -> int: + """Return number of hooks that will be/are injected.""" + return len(self.hooks) + + +def create_bypass_injections_from_patches( + model: nn.Module, + patches: dict, + strength: float = 1.0, +) -> list[PatcherInjection]: + """ + Convenience function to create bypass injections from a patches dict. + + This is useful when you have patches in the format used by model_patcher.add_patches() + and want to apply them in bypass mode instead. + + Args: + model: The model to inject into + patches: Dict mapping weight keys to adapter data + strength: Global strength multiplier + + Returns: + List of PatcherInjection objects + """ + manager = BypassInjectionManager() + + for key, patch_list in patches.items(): + if not patch_list: + continue + + # patches format: list of (strength_patch, patch_data, strength_model, offset, function) + for patch in patch_list: + patch_strength, patch_data, strength_model, offset, function = patch + + # patch_data should be a WeightAdapterBase/WeightAdapterTrainBase or tuple + if isinstance(patch_data, (WeightAdapterBase, WeightAdapterTrainBase)): + adapter = patch_data + else: + # Skip non-adapter patches + continue + + combined_strength = strength * patch_strength + manager.add_adapter(key, adapter, strength=combined_strength) + + return manager.create_injections(model) diff --git a/comfy/weight_adapter/glora.py b/comfy/weight_adapter/glora.py index 939abbba5..d6b97a23b 100644 --- a/comfy/weight_adapter/glora.py +++ b/comfy/weight_adapter/glora.py @@ -1,7 +1,8 @@ import logging -from typing import Optional +from typing import Callable, Optional import torch +import torch.nn.functional as F import comfy.model_management from .base import WeightAdapterBase, weight_decompose @@ -29,7 +30,14 @@ class GLoRAAdapter(WeightAdapterBase): b1_name = "{}.b1.weight".format(x) b2_name = "{}.b2.weight".format(x) if a1_name in lora: - weights = (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale) + weights = ( + lora[a1_name], + lora[a2_name], + lora[b1_name], + lora[b2_name], + alpha, + dora_scale, + ) loaded_keys.add(a1_name) loaded_keys.add(a2_name) loaded_keys.add(b1_name) @@ -58,16 +66,28 @@ class GLoRAAdapter(WeightAdapterBase): old_glora = True if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]: - if old_glora and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]: + if ( + old_glora + and v[1].shape[0] == weight.shape[0] + and weight.shape[0] == weight.shape[1] + ): pass else: old_glora = False rank = v[1].shape[0] - a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype) - a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype) - b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype) - b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype) + a1 = comfy.model_management.cast_to_device( + v[0].flatten(start_dim=1), weight.device, intermediate_dtype + ) + a2 = comfy.model_management.cast_to_device( + v[1].flatten(start_dim=1), weight.device, intermediate_dtype + ) + b1 = comfy.model_management.cast_to_device( + v[2].flatten(start_dim=1), weight.device, intermediate_dtype + ) + b2 = comfy.model_management.cast_to_device( + v[3].flatten(start_dim=1), weight.device, intermediate_dtype + ) if v[4] is not None: alpha = v[4] / rank @@ -76,18 +96,195 @@ class GLoRAAdapter(WeightAdapterBase): try: if old_glora: - lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora + lora_diff = ( + torch.mm(b2, b1) + + torch.mm( + torch.mm( + weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2 + ), + a1, + ) + ).reshape( + weight.shape + ) # old lycoris glora else: if weight.dim() > 2: - lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape) + lora_diff = torch.einsum( + "o i ..., i j -> o j ...", + torch.einsum( + "o i ..., i j -> o j ...", + weight.to(dtype=intermediate_dtype), + a1, + ), + a2, + ).reshape(weight.shape) else: - lora_diff = torch.mm(torch.mm(weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape) + lora_diff = torch.mm( + torch.mm(weight.to(dtype=intermediate_dtype), a1), a2 + ).reshape(weight.shape) lora_diff += torch.mm(b1, b2).reshape(weight.shape) if dora_scale is not None: - weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) + weight = weight_decompose( + dora_scale, + weight, + lora_diff, + alpha, + strength, + intermediate_dtype, + function, + ) else: weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) except Exception as e: logging.error("ERROR {} {} {}".format(self.name, key, e)) return weight + + def _compute_paths(self, x: torch.Tensor): + """ + Compute A path and B path outputs for GLoRA bypass. + + GLoRA: f(x) = Wx + WAx + Bx + - A path: a1(a2(x)) - modifies input to base forward + - B path: b1(b2(x)) - additive component + + Note: + Does not access original model weights - bypass mode is designed + for quantized models where weights may not be accessible. + + Returns: (a_out, b_out) + """ + v = self.weights + # v = (a1, a2, b1, b2, alpha, dora_scale) + a1 = v[0] + a2 = v[1] + b1 = v[2] + b2 = v[3] + alpha = v[4] + + dtype = x.dtype + + # Cast dtype (weights should already be on correct device from inject()) + a1 = a1.to(dtype=dtype) + a2 = a2.to(dtype=dtype) + b1 = b1.to(dtype=dtype) + b2 = b2.to(dtype=dtype) + + # Determine rank and scale + # Check for old vs new glora format + old_glora = False + if b2.shape[1] == b1.shape[0] == a1.shape[0] == a2.shape[1]: + rank = a1.shape[0] + old_glora = True + + if b2.shape[0] == b1.shape[1] == a1.shape[1] == a2.shape[0]: + if old_glora and a2.shape[0] == x.shape[-1] and x.shape[-1] == x.shape[-1]: + pass + else: + old_glora = False + rank = a2.shape[0] + + if alpha is not None: + scale = alpha / rank + else: + scale = 1.0 + + # Apply multiplier + multiplier = getattr(self, "multiplier", 1.0) + scale = scale * multiplier + + # Use module info from bypass injection, not input tensor shape + is_conv = getattr(self, "is_conv", False) + conv_dim = getattr(self, "conv_dim", 0) + kw_dict = getattr(self, "kw_dict", {}) + + if is_conv: + # Conv case - conv_dim is 1/2/3 for conv1d/2d/3d + conv_fn = (F.conv1d, F.conv2d, F.conv3d)[conv_dim - 1] + + # Get module's stride/padding for spatial dimension handling + module_stride = kw_dict.get("stride", (1,) * conv_dim) + module_padding = kw_dict.get("padding", (0,) * conv_dim) + kernel_size = getattr(self, "kernel_size", (1,) * conv_dim) + in_channels = getattr(self, "in_channels", None) + + # Ensure weights are in conv shape + # a1, a2, b1 are always 1x1 kernels + if a1.ndim == 2: + a1 = a1.view(*a1.shape, *([1] * conv_dim)) + if a2.ndim == 2: + a2 = a2.view(*a2.shape, *([1] * conv_dim)) + if b1.ndim == 2: + b1 = b1.view(*b1.shape, *([1] * conv_dim)) + # b2 has actual kernel_size (like LoRA down) + if b2.ndim == 2: + if in_channels is not None: + b2 = b2.view(b2.shape[0], in_channels, *kernel_size) + else: + b2 = b2.view(*b2.shape, *([1] * conv_dim)) + + # A path: a2(x) -> a1(...) - 1x1 convs, no stride/padding needed, a_out is added to x + a2_out = conv_fn(x, a2) + a_out = conv_fn(a2_out, a1) * scale + + # B path: b2(x) with kernel/stride/padding -> b1(...) 1x1 + b2_out = conv_fn(x, b2, stride=module_stride, padding=module_padding) + b_out = conv_fn(b2_out, b1) * scale + else: + # Linear case + if old_glora: + # Old format: a1 @ a2 @ x, b2 @ b1 + a_out = F.linear(F.linear(x, a2), a1) * scale + b_out = F.linear(F.linear(x, b1), b2) * scale + else: + # New format: x @ a1 @ a2, b1 @ b2 + a_out = F.linear(F.linear(x, a1), a2) * scale + b_out = F.linear(F.linear(x, b2), b1) * scale + + return a_out, b_out + + def bypass_forward( + self, + org_forward: Callable, + x: torch.Tensor, + *args, + **kwargs, + ) -> torch.Tensor: + """ + GLoRA bypass forward: f(x + a(x)) + b(x) + + Unlike standard adapters, GLoRA modifies the input to the base forward + AND adds the B path output. + + Note: + Does not access original model weights - bypass mode is designed + for quantized models where weights may not be accessible. + + Reference: LyCORIS GLoRAModule._bypass_forward + """ + a_out, b_out = self._compute_paths(x) + + # Call base forward with modified input + base_out = org_forward(x + a_out, *args, **kwargs) + + # Add B path + return base_out + b_out + + def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor: + """ + For GLoRA, h() returns the B path output. + + Note: + GLoRA's full bypass requires overriding bypass_forward() since + it also modifies the input to org_forward. This h() is provided for + compatibility but bypass_forward() should be used for correct behavior. + + Does not access original model weights - bypass mode is designed + for quantized models where weights may not be accessible. + + Args: + x: Input tensor + base_out: Output from base forward (unused, for API consistency) + """ + _, b_out = self._compute_paths(x) + return b_out diff --git a/comfy/weight_adapter/loha.py b/comfy/weight_adapter/loha.py index 0abb2d403..8007b7b44 100644 --- a/comfy/weight_adapter/loha.py +++ b/comfy/weight_adapter/loha.py @@ -1,11 +1,22 @@ import logging +from functools import cache from typing import Optional import torch +import torch.nn.functional as F import comfy.model_management from .base import WeightAdapterBase, WeightAdapterTrainBase, weight_decompose +@cache +def _warn_loha_bypass_inefficient(): + """One-time warning about LoHa bypass inefficiency.""" + logging.warning( + "LoHa bypass mode is inefficient: full weight diff is computed each forward pass. " + "Consider using LoRA or LoKr for training with bypass mode." + ) + + class HadaWeight(torch.autograd.Function): @staticmethod def forward(ctx, w1u, w1d, w2u, w2d, scale=torch.tensor(1)): @@ -105,9 +116,19 @@ class LohaDiff(WeightAdapterTrainBase): scale = self.alpha / self.rank if self.use_tucker: - diff_weight = HadaWeightTucker.apply(self.hada_t1, self.hada_w1_a, self.hada_w1_b, self.hada_t2, self.hada_w2_a, self.hada_w2_b, scale) + diff_weight = HadaWeightTucker.apply( + self.hada_t1, + self.hada_w1_a, + self.hada_w1_b, + self.hada_t2, + self.hada_w2_a, + self.hada_w2_b, + scale, + ) else: - diff_weight = HadaWeight.apply(self.hada_w1_a, self.hada_w1_b, self.hada_w2_a, self.hada_w2_b, scale) + diff_weight = HadaWeight.apply( + self.hada_w1_a, self.hada_w1_b, self.hada_w2_a, self.hada_w2_b, scale + ) # Add the scaled difference to the original weight weight = w.to(diff_weight) + diff_weight.reshape(w.shape) @@ -138,9 +159,7 @@ class LoHaAdapter(WeightAdapterBase): mat4 = torch.empty(rank, in_dim, device=weight.device, dtype=torch.float32) torch.nn.init.normal_(mat3, 0.1) torch.nn.init.normal_(mat4, 0.01) - return LohaDiff( - (mat1, mat2, alpha, mat3, mat4, None, None, None) - ) + return LohaDiff((mat1, mat2, alpha, mat3, mat4, None, None, None)) def to_train(self): return LohaDiff(self.weights) @@ -172,7 +191,16 @@ class LoHaAdapter(WeightAdapterBase): loaded_keys.add(hada_t1_name) loaded_keys.add(hada_t2_name) - weights = (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale) + weights = ( + lora[hada_w1_a_name], + lora[hada_w1_b_name], + alpha, + lora[hada_w2_a_name], + lora[hada_w2_b_name], + hada_t1, + hada_t2, + dora_scale, + ) loaded_keys.add(hada_w1_a_name) loaded_keys.add(hada_w1_b_name) loaded_keys.add(hada_w2_a_name) @@ -203,30 +231,148 @@ class LoHaAdapter(WeightAdapterBase): w2a = v[3] w2b = v[4] dora_scale = v[7] - if v[5] is not None: #cp decomposition + if v[5] is not None: # cp decomposition t1 = v[5] t2 = v[6] - m1 = torch.einsum('i j k l, j r, i p -> p r k l', - comfy.model_management.cast_to_device(t1, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype)) + m1 = torch.einsum( + "i j k l, j r, i p -> p r k l", + comfy.model_management.cast_to_device( + t1, weight.device, intermediate_dtype + ), + comfy.model_management.cast_to_device( + w1b, weight.device, intermediate_dtype + ), + comfy.model_management.cast_to_device( + w1a, weight.device, intermediate_dtype + ), + ) - m2 = torch.einsum('i j k l, j r, i p -> p r k l', - comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype)) + m2 = torch.einsum( + "i j k l, j r, i p -> p r k l", + comfy.model_management.cast_to_device( + t2, weight.device, intermediate_dtype + ), + comfy.model_management.cast_to_device( + w2b, weight.device, intermediate_dtype + ), + comfy.model_management.cast_to_device( + w2a, weight.device, intermediate_dtype + ), + ) else: - m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype)) - m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype)) + m1 = torch.mm( + comfy.model_management.cast_to_device( + w1a, weight.device, intermediate_dtype + ), + comfy.model_management.cast_to_device( + w1b, weight.device, intermediate_dtype + ), + ) + m2 = torch.mm( + comfy.model_management.cast_to_device( + w2a, weight.device, intermediate_dtype + ), + comfy.model_management.cast_to_device( + w2b, weight.device, intermediate_dtype + ), + ) try: lora_diff = (m1 * m2).reshape(weight.shape) if dora_scale is not None: - weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) + weight = weight_decompose( + dora_scale, + weight, + lora_diff, + alpha, + strength, + intermediate_dtype, + function, + ) else: weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) except Exception as e: logging.error("ERROR {} {} {}".format(self.name, key, e)) return weight + + def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor: + """ + Additive bypass component for LoHa: h(x) = diff_weight @ x + + WARNING: Inefficient - computes full Hadamard product each forward. + + Note: + Does not access original model weights - bypass mode is designed + for quantized models where weights may not be accessible. + + Args: + x: Input tensor + base_out: Output from base forward (unused, for API consistency) + + Reference: LyCORIS functional/loha.py bypass_forward_diff + """ + _warn_loha_bypass_inefficient() + + # FUNC_LIST: [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d] + FUNC_LIST = [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d] + + v = self.weights + # v[0]=w1a, v[1]=w1b, v[2]=alpha, v[3]=w2a, v[4]=w2b, v[5]=t1, v[6]=t2, v[7]=dora + w1a = v[0] + w1b = v[1] + alpha = v[2] + w2a = v[3] + w2b = v[4] + t1 = v[5] + t2 = v[6] + + # Compute scale + rank = w1b.shape[0] + scale = (alpha / rank if alpha is not None else 1.0) * getattr( + self, "multiplier", 1.0 + ) + + # Cast dtype + w1a = w1a.to(dtype=x.dtype) + w1b = w1b.to(dtype=x.dtype) + w2a = w2a.to(dtype=x.dtype) + w2b = w2b.to(dtype=x.dtype) + + # Use module info from bypass injection, not weight dimension + is_conv = getattr(self, "is_conv", False) + conv_dim = getattr(self, "conv_dim", 0) + kw_dict = getattr(self, "kw_dict", {}) + + # Compute diff weight using Hadamard product + if t1 is not None and t2 is not None: + t1 = t1.to(dtype=x.dtype) + t2 = t2.to(dtype=x.dtype) + m1 = torch.einsum("i j k l, j r, i p -> p r k l", t1, w1b, w1a) + m2 = torch.einsum("i j k l, j r, i p -> p r k l", t2, w2b, w2a) + diff_weight = (m1 * m2) * scale + else: + m1 = w1a @ w1b + m2 = w2a @ w2b + diff_weight = (m1 * m2) * scale + + if is_conv: + op = FUNC_LIST[conv_dim + 2] + kernel_size = getattr(self, "kernel_size", (1,) * conv_dim) + in_channels = getattr(self, "in_channels", None) + + # Reshape 2D diff_weight to conv format using kernel_size + # diff_weight: [out_channels, in_channels * prod(kernel_size)] -> [out_channels, in_channels, *kernel_size] + if diff_weight.dim() == 2: + if in_channels is not None: + diff_weight = diff_weight.view( + diff_weight.shape[0], in_channels, *kernel_size + ) + else: + diff_weight = diff_weight.view( + *diff_weight.shape, *([1] * conv_dim) + ) + else: + op = F.linear + kw_dict = {} + + return op(x, diff_weight, **kw_dict) diff --git a/comfy/weight_adapter/lokr.py b/comfy/weight_adapter/lokr.py index 9b2aff2d7..b83750012 100644 --- a/comfy/weight_adapter/lokr.py +++ b/comfy/weight_adapter/lokr.py @@ -2,6 +2,7 @@ import logging from typing import Optional import torch +import torch.nn.functional as F import comfy.model_management from .base import ( WeightAdapterBase, @@ -14,7 +15,17 @@ from .base import ( class LokrDiff(WeightAdapterTrainBase): def __init__(self, weights): super().__init__() - (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale) = weights + ( + lokr_w1, + lokr_w2, + alpha, + lokr_w1_a, + lokr_w1_b, + lokr_w2_a, + lokr_w2_b, + lokr_t2, + dora_scale, + ) = weights self.use_tucker = False if lokr_w1_a is not None: _, rank_a = lokr_w1_a.shape[0], lokr_w1_a.shape[1] @@ -57,10 +68,10 @@ class LokrDiff(WeightAdapterTrainBase): if self.w2_rebuild: if self.use_tucker: w2 = torch.einsum( - 'i j k l, j r, i p -> p r k l', + "i j k l, j r, i p -> p r k l", self.lokr_t2, self.lokr_w2_b, - self.lokr_w2_a + self.lokr_w2_a, ) else: w2 = self.lokr_w2_a @ self.lokr_w2_b @@ -69,9 +80,89 @@ class LokrDiff(WeightAdapterTrainBase): return self.lokr_w2 def __call__(self, w): - diff = torch.kron(self.w1, self.w2) + w1 = self.w1 + w2 = self.w2 + # Unsqueeze w1 to match w2 dims for proper kron product (like LyCORIS make_kron) + for _ in range(w2.dim() - w1.dim()): + w1 = w1.unsqueeze(-1) + diff = torch.kron(w1, w2) return w + diff.reshape(w.shape).to(w) + def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor: + """ + Additive bypass component for LoKr training: efficient Kronecker product. + + Uses w1/w2 properties which handle both direct and decomposed cases. + For create_train (direct w1/w2), no alpha scaling in properties. + For to_train (decomposed), alpha/rank scaling is in properties. + + Args: + x: Input tensor + base_out: Output from base forward (unused, for API consistency) + """ + # Get w1, w2 from properties (handles rebuild vs direct) + w1 = self.w1 + w2 = self.w2 + + # Multiplier from bypass injection + multiplier = getattr(self, "multiplier", 1.0) + + # Get module info from bypass injection + is_conv = getattr(self, "is_conv", False) + conv_dim = getattr(self, "conv_dim", 0) + kw_dict = getattr(self, "kw_dict", {}) + + # Efficient Kronecker application without materializing full weight + # kron(w1, w2) @ x can be computed as nested operations + # w1: [out_l, in_m], w2: [out_k, in_n, *k_size] + # Full weight would be [out_l*out_k, in_m*in_n, *k_size] + + uq = w1.size(1) # in_m - inner grouping dimension + + if is_conv: + conv_fn = (F.conv1d, F.conv2d, F.conv3d)[conv_dim - 1] + + B, C_in, *spatial = x.shape + # Reshape input for grouped application: [B * uq, C_in // uq, *spatial] + h_in_group = x.reshape(B * uq, -1, *spatial) + + # Ensure w2 has conv dims + if w2.dim() == 2: + w2 = w2.view(*w2.shape, *([1] * conv_dim)) + + # Apply w2 path with stride/padding + hb = conv_fn(h_in_group, w2, **kw_dict) + + # Reshape for cross-group operation + hb = hb.view(B, -1, *hb.shape[1:]) + h_cross = hb.transpose(1, -1) + + # Apply w1 (always 2D, applied as linear on channel dim) + hc = F.linear(h_cross, w1) + hc = hc.transpose(1, -1) + + # Reshape to output + out = hc.reshape(B, -1, *hc.shape[3:]) + else: + # Linear case + # Reshape input: [..., in_m * in_n] -> [..., uq (in_m), in_n] + h_in_group = x.reshape(*x.shape[:-1], uq, -1) + + # Apply w2: [..., uq, in_n] @ [out_k, in_n].T -> [..., uq, out_k] + hb = F.linear(h_in_group, w2) + + # Transpose for w1: [..., uq, out_k] -> [..., out_k, uq] + h_cross = hb.transpose(-1, -2) + + # Apply w1: [..., out_k, uq] @ [out_l, uq].T -> [..., out_k, out_l] + hc = F.linear(h_cross, w1) + + # Transpose back and flatten: [..., out_k, out_l] -> [..., out_l * out_k] + hc = hc.transpose(-1, -2) + out = hc.reshape(*hc.shape[:-2], -1) + + return out * multiplier + def passive_memory_usage(self): return sum(param.numel() * param.element_size() for param in self.parameters()) @@ -86,16 +177,22 @@ class LoKrAdapter(WeightAdapterBase): @classmethod def create_train(cls, weight, rank=1, alpha=1.0): out_dim = weight.shape[0] - in_dim = weight.shape[1:].numel() - out1, out2 = factorization(out_dim, rank) - in1, in2 = factorization(in_dim, rank) - mat1 = torch.empty(out1, in1, device=weight.device, dtype=torch.float32) - mat2 = torch.empty(out2, in2, device=weight.device, dtype=torch.float32) + in_dim = weight.shape[1] # Just in_channels, not flattened with kernel + k_size = weight.shape[2:] if weight.dim() > 2 else () + + out_l, out_k = factorization(out_dim, rank) + in_m, in_n = factorization(in_dim, rank) + + # w1: [out_l, in_m] + mat1 = torch.empty(out_l, in_m, device=weight.device, dtype=torch.float32) + # w2: [out_k, in_n, *k_size] for conv, [out_k, in_n] for linear + mat2 = torch.empty( + out_k, in_n, *k_size, device=weight.device, dtype=torch.float32 + ) + torch.nn.init.kaiming_uniform_(mat2, a=5**0.5) torch.nn.init.constant_(mat1, 0.0) - return LokrDiff( - (mat1, mat2, alpha, None, None, None, None, None, None) - ) + return LokrDiff((mat1, mat2, alpha, None, None, None, None, None, None)) def to_train(self): return LokrDiff(self.weights) @@ -154,8 +251,23 @@ class LoKrAdapter(WeightAdapterBase): lokr_t2 = lora[lokr_t2_name] loaded_keys.add(lokr_t2_name) - if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None): - weights = (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale) + if ( + (lokr_w1 is not None) + or (lokr_w2 is not None) + or (lokr_w1_a is not None) + or (lokr_w2_a is not None) + ): + weights = ( + lokr_w1, + lokr_w2, + alpha, + lokr_w1_a, + lokr_w1_b, + lokr_w2_a, + lokr_w2_b, + lokr_t2, + dora_scale, + ) return cls(loaded_keys, weights) else: return None @@ -184,23 +296,47 @@ class LoKrAdapter(WeightAdapterBase): if w1 is None: dim = w1_b.shape[0] - w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w1_b, weight.device, intermediate_dtype)) + w1 = torch.mm( + comfy.model_management.cast_to_device( + w1_a, weight.device, intermediate_dtype + ), + comfy.model_management.cast_to_device( + w1_b, weight.device, intermediate_dtype + ), + ) else: - w1 = comfy.model_management.cast_to_device(w1, weight.device, intermediate_dtype) + w1 = comfy.model_management.cast_to_device( + w1, weight.device, intermediate_dtype + ) if w2 is None: dim = w2_b.shape[0] if t2 is None: - w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype)) + w2 = torch.mm( + comfy.model_management.cast_to_device( + w2_a, weight.device, intermediate_dtype + ), + comfy.model_management.cast_to_device( + w2_b, weight.device, intermediate_dtype + ), + ) else: - w2 = torch.einsum('i j k l, j r, i p -> p r k l', - comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype)) + w2 = torch.einsum( + "i j k l, j r, i p -> p r k l", + comfy.model_management.cast_to_device( + t2, weight.device, intermediate_dtype + ), + comfy.model_management.cast_to_device( + w2_b, weight.device, intermediate_dtype + ), + comfy.model_management.cast_to_device( + w2_a, weight.device, intermediate_dtype + ), + ) else: - w2 = comfy.model_management.cast_to_device(w2, weight.device, intermediate_dtype) + w2 = comfy.model_management.cast_to_device( + w2, weight.device, intermediate_dtype + ) if len(w2.shape) == 4: w1 = w1.unsqueeze(2).unsqueeze(2) @@ -212,9 +348,134 @@ class LoKrAdapter(WeightAdapterBase): try: lora_diff = torch.kron(w1, w2).reshape(weight.shape) if dora_scale is not None: - weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) + weight = weight_decompose( + dora_scale, + weight, + lora_diff, + alpha, + strength, + intermediate_dtype, + function, + ) else: weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) except Exception as e: logging.error("ERROR {} {} {}".format(self.name, key, e)) return weight + + def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor: + """ + Additive bypass component for LoKr: efficient Kronecker product application. + + Note: + Does not access original model weights - bypass mode is designed + for quantized models where weights may not be accessible. + + Args: + x: Input tensor + base_out: Output from base forward (unused, for API consistency) + + Reference: LyCORIS functional/lokr.py bypass_forward_diff + """ + # FUNC_LIST: [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d] + FUNC_LIST = [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d] + + v = self.weights + # v[0]=w1, v[1]=w2, v[2]=alpha, v[3]=w1_a, v[4]=w1_b, v[5]=w2_a, v[6]=w2_b, v[7]=t2, v[8]=dora + w1 = v[0] + w2 = v[1] + alpha = v[2] + w1_a = v[3] + w1_b = v[4] + w2_a = v[5] + w2_b = v[6] + t2 = v[7] + + use_w1 = w1 is not None + use_w2 = w2 is not None + tucker = t2 is not None + + # Use module info from bypass injection, not weight dimension + is_conv = getattr(self, "is_conv", False) + conv_dim = getattr(self, "conv_dim", 0) + kw_dict = getattr(self, "kw_dict", {}) if is_conv else {} + + if is_conv: + op = FUNC_LIST[conv_dim + 2] + else: + op = F.linear + + # Determine rank and scale + rank = w1_b.size(0) if not use_w1 else w2_b.size(0) if not use_w2 else alpha + scale = (alpha / rank if alpha is not None else 1.0) * getattr( + self, "multiplier", 1.0 + ) + + # Build c (w1) + if use_w1: + c = w1.to(dtype=x.dtype) + else: + c = w1_a.to(dtype=x.dtype) @ w1_b.to(dtype=x.dtype) + uq = c.size(1) + + # Build w2 components + if use_w2: + ba = w2.to(dtype=x.dtype) + else: + a = w2_b.to(dtype=x.dtype) + b = w2_a.to(dtype=x.dtype) + if is_conv: + if tucker: + # Tucker: a, b get 1s appended (kernel is in t2) + if a.dim() == 2: + a = a.view(*a.shape, *([1] * conv_dim)) + if b.dim() == 2: + b = b.view(*b.shape, *([1] * conv_dim)) + else: + # Non-tucker conv: b may need 1s appended + if b.dim() == 2: + b = b.view(*b.shape, *([1] * conv_dim)) + + # Reshape input by uq groups + if is_conv: + B, _, *rest = x.shape + h_in_group = x.reshape(B * uq, -1, *rest) + else: + h_in_group = x.reshape(*x.shape[:-1], uq, -1) + + # Apply w2 path + if use_w2: + hb = op(h_in_group, ba, **kw_dict) + else: + if is_conv: + if tucker: + t = t2.to(dtype=x.dtype) + if t.dim() == 2: + t = t.view(*t.shape, *([1] * conv_dim)) + ha = op(h_in_group, a) + ht = op(ha, t, **kw_dict) + hb = op(ht, b) + else: + ha = op(h_in_group, a, **kw_dict) + hb = op(ha, b) + else: + ha = op(h_in_group, a) + hb = op(ha, b) + + # Reshape and apply c (w1) + if is_conv: + hb = hb.view(B, -1, *hb.shape[1:]) + h_cross_group = hb.transpose(1, -1) + else: + h_cross_group = hb.transpose(-1, -2) + + hc = F.linear(h_cross_group, c) + + if is_conv: + hc = hc.transpose(1, -1) + out = hc.reshape(B, -1, *hc.shape[3:]) + else: + hc = hc.transpose(-1, -2) + out = hc.reshape(*hc.shape[:-2], -1) + + return out * scale diff --git a/comfy/weight_adapter/lora.py b/comfy/weight_adapter/lora.py index 3cc60bb1b..bc4260a8f 100644 --- a/comfy/weight_adapter/lora.py +++ b/comfy/weight_adapter/lora.py @@ -2,6 +2,7 @@ import logging from typing import Optional import torch +import torch.nn.functional as F import comfy.model_management from .base import ( WeightAdapterBase, @@ -20,11 +21,7 @@ class LoraDiff(WeightAdapterTrainBase): rank, in_dim = mat2.shape[0], mat2.shape[1] if mid is not None: convdim = mid.ndim - 2 - layer = ( - torch.nn.Conv1d, - torch.nn.Conv2d, - torch.nn.Conv3d - )[convdim] + layer = (torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d)[convdim] else: layer = torch.nn.Linear self.lora_up = layer(rank, out_dim, bias=False) @@ -51,6 +48,78 @@ class LoraDiff(WeightAdapterTrainBase): weight = w + scale * diff.reshape(w.shape) return weight.to(org_dtype) + def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor: + """ + Additive bypass component for LoRA training: h(x) = up(down(x)) * scale + + Simple implementation using the nn.Module weights directly. + No mid/dora/reshape branches (create_train doesn't create them). + + Args: + x: Input tensor + base_out: Output from base forward (unused, for API consistency) + """ + # Compute scale = alpha / rank * multiplier + scale = (self.alpha / self.rank) * getattr(self, "multiplier", 1.0) + + # Get module info from bypass injection + is_conv = getattr(self, "is_conv", False) + conv_dim = getattr(self, "conv_dim", 0) + kw_dict = getattr(self, "kw_dict", {}) + + # Get weights (keep in original dtype for numerical stability) + down_weight = self.lora_down.weight + up_weight = self.lora_up.weight + + if is_conv: + # Conv path: use functional conv + # conv_dim: 1=conv1d, 2=conv2d, 3=conv3d + conv_fn = (F.conv1d, F.conv2d, F.conv3d)[conv_dim - 1] + + # Reshape 2D weights to conv format if needed + # down: [rank, in_features] -> [rank, in_channels, *kernel_size] + # up: [out_features, rank] -> [out_features, rank, 1, 1, ...] + if down_weight.dim() == 2: + kernel_size = getattr(self, "kernel_size", (1,) * conv_dim) + in_channels = getattr(self, "in_channels", None) + if in_channels is not None: + down_weight = down_weight.view( + down_weight.shape[0], in_channels, *kernel_size + ) + else: + # Fallback: assume 1x1 kernel + down_weight = down_weight.view( + *down_weight.shape, *([1] * conv_dim) + ) + if up_weight.dim() == 2: + # up always uses 1x1 kernel + up_weight = up_weight.view(*up_weight.shape, *([1] * conv_dim)) + + # down conv uses stride/padding from module, up is 1x1 + hidden = conv_fn(x, down_weight, **kw_dict) + + # mid layer if exists (tucker decomposition) + if self.lora_mid is not None: + mid_weight = self.lora_mid.weight + if mid_weight.dim() == 2: + mid_weight = mid_weight.view(*mid_weight.shape, *([1] * conv_dim)) + hidden = conv_fn(hidden, mid_weight) + + # up conv is always 1x1 (no stride/padding) + out = conv_fn(hidden, up_weight) + else: + # Linear path: simple matmul chain + hidden = F.linear(x, down_weight) + + # mid layer if exists + if self.lora_mid is not None: + mid_weight = self.lora_mid.weight + hidden = F.linear(hidden, mid_weight) + + out = F.linear(hidden, up_weight) + + return out * scale + def passive_memory_usage(self): return sum(param.numel() * param.element_size() for param in self.parameters()) @@ -70,9 +139,7 @@ class LoRAAdapter(WeightAdapterBase): mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=torch.float32) torch.nn.init.kaiming_uniform_(mat1, a=5**0.5) torch.nn.init.constant_(mat2, 0.0) - return LoraDiff( - (mat1, mat2, alpha, None, None, None) - ) + return LoraDiff((mat1, mat2, alpha, None, None, None)) def to_train(self): return LoraDiff(self.weights) @@ -210,3 +277,85 @@ class LoRAAdapter(WeightAdapterBase): except Exception as e: logging.error("ERROR {} {} {}".format(self.name, key, e)) return weight + + def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor: + """ + Additive bypass component for LoRA: h(x) = up(down(x)) * scale + + Note: + Does not access original model weights - bypass mode is designed + for quantized models where weights may not be accessible. + + Args: + x: Input tensor + base_out: Output from base forward (unused, for API consistency) + + Reference: LyCORIS functional/locon.py bypass_forward_diff + """ + # FUNC_LIST: [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d] + FUNC_LIST = [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d] + + v = self.weights + # v[0]=up, v[1]=down, v[2]=alpha, v[3]=mid, v[4]=dora_scale, v[5]=reshape + up = v[0] + down = v[1] + alpha = v[2] + mid = v[3] + + # Compute scale = alpha / rank + rank = down.shape[0] + if alpha is not None: + scale = alpha / rank + else: + scale = 1.0 + scale = scale * getattr(self, "multiplier", 1.0) + + # Cast dtype + up = up.to(dtype=x.dtype) + down = down.to(dtype=x.dtype) + + # Use module info from bypass injection, not weight dimension + is_conv = getattr(self, "is_conv", False) + conv_dim = getattr(self, "conv_dim", 0) + kw_dict = getattr(self, "kw_dict", {}) + + if is_conv: + op = FUNC_LIST[ + conv_dim + 2 + ] # conv_dim 1->conv1d(3), 2->conv2d(4), 3->conv3d(5) + kernel_size = getattr(self, "kernel_size", (1,) * conv_dim) + in_channels = getattr(self, "in_channels", None) + + # Reshape 2D weights to conv format using kernel_size + # down: [rank, in_channels * prod(kernel_size)] -> [rank, in_channels, *kernel_size] + # up: [out_channels, rank] -> [out_channels, rank, 1, 1, ...] (1x1 kernel) + if down.dim() == 2: + # down.shape[1] = in_channels * prod(kernel_size) + if in_channels is not None: + down = down.view(down.shape[0], in_channels, *kernel_size) + else: + # Fallback: assume 1x1 kernel if in_channels unknown + down = down.view(*down.shape, *([1] * conv_dim)) + if up.dim() == 2: + # up always uses 1x1 kernel + up = up.view(*up.shape, *([1] * conv_dim)) + if mid is not None: + mid = mid.to(dtype=x.dtype) + if mid.dim() == 2: + mid = mid.view(*mid.shape, *([1] * conv_dim)) + else: + op = F.linear + kw_dict = {} # linear doesn't take stride/padding + + # Simple chain: down -> mid (if tucker) -> up + if mid is not None: + if not is_conv: + mid = mid.to(dtype=x.dtype) + hidden = op(x, down) + hidden = op(hidden, mid, **kw_dict) + out = op(hidden, up) + else: + hidden = op(x, down, **kw_dict) + out = op(hidden, up) + + return out * scale diff --git a/comfy/weight_adapter/oft.py b/comfy/weight_adapter/oft.py index c0aab9635..bc83cf8e8 100644 --- a/comfy/weight_adapter/oft.py +++ b/comfy/weight_adapter/oft.py @@ -3,13 +3,18 @@ from typing import Optional import torch import comfy.model_management -from .base import WeightAdapterBase, WeightAdapterTrainBase, weight_decompose, factorization +from .base import ( + WeightAdapterBase, + WeightAdapterTrainBase, + weight_decompose, + factorization, +) class OFTDiff(WeightAdapterTrainBase): def __init__(self, weights): super().__init__() - # Unpack weights tuple from LoHaAdapter + # Unpack weights tuple from OFTAdapter blocks, rescale, alpha, _ = weights # Create trainable parameters @@ -52,6 +57,78 @@ class OFTDiff(WeightAdapterTrainBase): weight = self.rescale * weight return weight.to(org_dtype) + def _get_orthogonal_matrix(self, device, dtype): + """Compute the orthogonal rotation matrix R from OFT blocks.""" + blocks = self.oft_blocks.to(device=device, dtype=dtype) + I = torch.eye(self.block_size, device=device, dtype=dtype) + + # Q = blocks - blocks^T (skew-symmetric) + q = blocks - blocks.transpose(1, 2) + normed_q = q + + # Apply constraint if set + if self.constraint: + q_norm = torch.norm(q) + 1e-8 + if q_norm > self.constraint: + normed_q = q * self.constraint / q_norm + + # Cayley transform: R = (I + Q)(I - Q)^-1 + r = (I + normed_q) @ (I - normed_q).float().inverse() + return r.to(dtype) + + def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor: + """ + OFT has no additive component - returns zeros matching base_out shape. + + OFT only transforms the output via g(), it doesn't add to it. + """ + return torch.zeros_like(base_out) + + def g(self, y: torch.Tensor) -> torch.Tensor: + """ + Output transformation for OFT: applies orthogonal rotation. + + OFT transforms output channels using block-diagonal orthogonal matrices. + """ + r = self._get_orthogonal_matrix(y.device, y.dtype) + + # Apply multiplier to interpolate between identity and full transform + multiplier = getattr(self, "multiplier", 1.0) + I = torch.eye(self.block_size, device=y.device, dtype=y.dtype) + r = r * multiplier + (1 - multiplier) * I + + # Use module info from bypass injection + is_conv = getattr(self, "is_conv", y.dim() > 2) + + if is_conv: + # Conv output: (N, C, H, W, ...) -> transpose to (N, H, W, ..., C) + y = y.transpose(1, -1) + + # y now has channels in last dim + *batch_shape, out_features = y.shape + + # Reshape to apply block-diagonal transform + # (*, out_features) -> (*, block_num, block_size) + y_blocked = y.reshape(*batch_shape, self.block_num, self.block_size) + + # Apply orthogonal transform: R @ y for each block + # r: (block_num, block_size, block_size), y_blocked: (*, block_num, block_size) + out_blocked = torch.einsum("k n m, ... k n -> ... k m", r, y_blocked) + + # Reshape back: (*, block_num, block_size) -> (*, out_features) + out = out_blocked.reshape(*batch_shape, out_features) + + # Apply rescale if present + if self.rescaled: + rescale = self.rescale.to(device=y.device, dtype=y.dtype) + out = out * rescale.view(-1) + + if is_conv: + # Transpose back: (N, H, W, ..., C) -> (N, C, H, W, ...) + out = out.transpose(1, -1) + + return out + def passive_memory_usage(self): """Calculates memory usage of the trainable parameters.""" return sum(param.numel() * param.element_size() for param in self.parameters()) @@ -68,10 +145,10 @@ class OFTAdapter(WeightAdapterBase): def create_train(cls, weight, rank=1, alpha=1.0): out_dim = weight.shape[0] block_size, block_num = factorization(out_dim, rank) - block = torch.zeros(block_num, block_size, block_size, device=weight.device, dtype=torch.float32) - return OFTDiff( - (block, None, alpha, None) + block = torch.zeros( + block_num, block_size, block_size, device=weight.device, dtype=torch.float32 ) + return OFTDiff((block, None, alpha, None)) def to_train(self): return OFTDiff(self.weights) @@ -127,9 +204,13 @@ class OFTAdapter(WeightAdapterBase): alpha = 0 dora_scale = v[3] - blocks = comfy.model_management.cast_to_device(blocks, weight.device, intermediate_dtype) + blocks = comfy.model_management.cast_to_device( + blocks, weight.device, intermediate_dtype + ) if rescale is not None: - rescale = comfy.model_management.cast_to_device(rescale, weight.device, intermediate_dtype) + rescale = comfy.model_management.cast_to_device( + rescale, weight.device, intermediate_dtype + ) block_num, block_size, *_ = blocks.shape @@ -139,23 +220,108 @@ class OFTAdapter(WeightAdapterBase): # for Q = -Q^T q = blocks - blocks.transpose(1, 2) normed_q = q - if alpha > 0: # alpha in oft/boft is for constraint + if alpha > 0: # alpha in oft/boft is for constraint q_norm = torch.norm(q) + 1e-8 if q_norm > alpha: normed_q = q * alpha / q_norm # use float() to prevent unsupported type in .inverse() r = (I + normed_q) @ (I - normed_q).float().inverse() r = r.to(weight) + # Create I in weight's dtype for the einsum + I_w = torch.eye(block_size, device=weight.device, dtype=weight.dtype) _, *shape = weight.shape lora_diff = torch.einsum( "k n m, k n ... -> k m ...", - (r * strength) - strength * I, + (r * strength) - strength * I_w, weight.view(block_num, block_size, *shape), ).view(-1, *shape) if dora_scale is not None: - weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) + weight = weight_decompose( + dora_scale, + weight, + lora_diff, + alpha, + strength, + intermediate_dtype, + function, + ) else: weight += function((strength * lora_diff).type(weight.dtype)) except Exception as e: logging.error("ERROR {} {} {}".format(self.name, key, e)) return weight + + def _get_orthogonal_matrix(self, device, dtype): + """Compute the orthogonal rotation matrix R from OFT blocks.""" + v = self.weights + blocks = v[0].to(device=device, dtype=dtype) + alpha = v[2] + if alpha is None: + alpha = 0 + + block_num, block_size, _ = blocks.shape + I = torch.eye(block_size, device=device, dtype=dtype) + + # Q = blocks - blocks^T (skew-symmetric) + q = blocks - blocks.transpose(1, 2) + normed_q = q + + # Apply constraint if alpha > 0 + if alpha > 0: + q_norm = torch.norm(q) + 1e-8 + if q_norm > alpha: + normed_q = q * alpha / q_norm + + # Cayley transform: R = (I + Q)(I - Q)^-1 + r = (I + normed_q) @ (I - normed_q).float().inverse() + return r, block_num, block_size + + def g(self, y: torch.Tensor) -> torch.Tensor: + """ + Output transformation for OFT: applies orthogonal rotation to output. + + OFT transforms the output channels using block-diagonal orthogonal matrices. + + Reference: LyCORIS DiagOFTModule._bypass_forward + """ + v = self.weights + rescale = v[1] + + r, block_num, block_size = self._get_orthogonal_matrix(y.device, y.dtype) + + # Apply multiplier to interpolate between identity and full transform + multiplier = getattr(self, "multiplier", 1.0) + I = torch.eye(block_size, device=y.device, dtype=y.dtype) + r = r * multiplier + (1 - multiplier) * I + + # Use module info from bypass injection to determine conv vs linear + is_conv = getattr(self, "is_conv", y.dim() > 2) + + if is_conv: + # Conv output: (N, C, H, W, ...) -> transpose to (N, H, W, ..., C) + y = y.transpose(1, -1) + + # y now has channels in last dim + *batch_shape, out_features = y.shape + + # Reshape to apply block-diagonal transform + # (*, out_features) -> (*, block_num, block_size) + y_blocked = y.view(*batch_shape, block_num, block_size) + + # Apply orthogonal transform: R @ y for each block + # r: (block_num, block_size, block_size), y_blocked: (*, block_num, block_size) + out_blocked = torch.einsum("k n m, ... k n -> ... k m", r, y_blocked) + + # Reshape back: (*, block_num, block_size) -> (*, out_features) + out = out_blocked.view(*batch_shape, out_features) + + # Apply rescale if present + if rescale is not None: + rescale = rescale.to(device=y.device, dtype=y.dtype) + out = out * rescale.view(-1) + + if is_conv: + # Transpose back: (N, H, W, ..., C) -> (N, C, H, W, ...) + out = out.transpose(1, -1) + + return out diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py index 68a73cf13..024a89391 100644 --- a/comfy_extras/nodes_train.py +++ b/comfy_extras/nodes_train.py @@ -18,6 +18,7 @@ import comfy_extras.nodes_custom_sampler import folder_paths import node_helpers from comfy.weight_adapter import adapters, adapter_maps +from comfy.weight_adapter.bypass import BypassInjectionManager from comfy_api.latest import ComfyExtension, io, ui from comfy.utils import ProgressBar @@ -339,6 +340,11 @@ class TrainSampler(comfy.samplers.Sampler): self._train_step_multires_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar) if (i + 1) % self.grad_acc == 0: + for param_groups in self.optimizer.param_groups: + for param in param_groups["params"]: + if param.grad is None: + continue + param.grad.data = param.grad.data.to(param.data.dtype) self.optimizer.step() self.optimizer.zero_grad() ui_pbar.update(1) @@ -498,9 +504,9 @@ def _prepare_latents_and_count(latents, dtype, bucket_mode): num_images = sum(t.shape[0] for t in latents) multi_res = False # Not using multi_res path in bucket mode - logging.info(f"Bucket mode: {num_buckets} buckets, {num_images} total samples") + logging.debug(f"Bucket mode: {num_buckets} buckets, {num_images} total samples") for i, lat in enumerate(latents): - logging.info(f" Bucket {i}: shape {lat.shape}") + logging.debug(f" Bucket {i}: shape {lat.shape}") return latents, num_images, multi_res # Non-bucket mode @@ -509,7 +515,7 @@ def _prepare_latents_and_count(latents, dtype, bucket_mode): latents = [t.to(dtype) for t in latents] for latent in latents: all_shapes.add(latent.shape) - logging.info(f"Latent shapes: {all_shapes}") + logging.debug(f"Latent shapes: {all_shapes}") if len(all_shapes) > 1: multi_res = True else: @@ -545,7 +551,7 @@ def _validate_and_expand_conditioning(positive, num_images, bucket_mode): if bucket_mode: return positive # Skip validation in bucket mode - logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}") + logging.debug(f"Total Images: {num_images}, Total Captions: {len(positive)}") if len(positive) == 1 and num_images > 1: return positive * num_images elif len(positive) != num_images: @@ -596,6 +602,8 @@ def _create_weight_adapter( shape = module.weight.shape lora_params = {} + logging.debug(f"Creating weight adapter for {key} with shape {shape}") + if len(shape) >= 2: alpha = float(existing_weights.get(f"{key}.alpha", 1.0)) dora_scale = existing_weights.get(f"{key}.dora_scale", None) @@ -690,6 +698,61 @@ def _setup_lora_adapters(mp, existing_weights, algorithm, lora_dtype, rank): return lora_sd, all_weight_adapters +def _setup_lora_adapters_bypass(mp, existing_weights, algorithm, lora_dtype, rank): + """Setup LoRA adapters in bypass mode. + + In bypass mode: + - Weight adapters (lora/lokr/oft) use bypass injection (forward hook) + - Bias/norm adapters (BiasDiff) still use weight wrapper (direct modification) + + This is useful when the base model weights are quantized and cannot be + directly modified. + + Args: + mp: Model patcher + existing_weights: Dict of existing LoRA weights + algorithm: Algorithm name for new adapters + lora_dtype: dtype for LoRA weights + rank: Rank for new LoRA adapters + + Returns: + tuple: (lora_sd dict, all_weight_adapters list, bypass_manager) + """ + lora_sd = {} + all_weight_adapters = [] + bypass_manager = BypassInjectionManager() + + for n, m in mp.model.named_modules(): + if hasattr(m, "weight_function"): + if m.weight is not None: + adapter, params = _create_weight_adapter( + m, n, existing_weights, algorithm, lora_dtype, rank + ) + lora_sd.update(params) + all_weight_adapters.append(adapter) + + key = f"{n}.weight" + # BiasDiff (for 1D weights like norm) uses weight wrapper, not bypass + # Only use bypass for adapters that have h() method (lora/lokr/oft) + if isinstance(adapter, BiasDiff): + mp.add_weight_wrapper(key, adapter) + logging.debug(f"[BypassMode] Added 1D weight adapter (weight wrapper) for {key}") + else: + bypass_manager.add_adapter(key, adapter, strength=1.0) + logging.debug(f"[BypassMode] Added weight adapter (bypass) for {key}") + + if hasattr(m, "bias") and m.bias is not None: + # Bias adapters still use weight wrapper (bias is usually not quantized) + bias_adapter, bias_params = _create_bias_adapter(m, n, lora_dtype) + lora_sd.update(bias_params) + key = f"{n}.bias" + mp.add_weight_wrapper(key, bias_adapter) + all_weight_adapters.append(bias_adapter) + logging.debug(f"[BypassMode] Added bias adapter (weight wrapper) for {key}") + + return lora_sd, all_weight_adapters, bypass_manager + + def _create_optimizer(optimizer_name, parameters, learning_rate): """Create optimizer based on name. @@ -884,11 +947,13 @@ class TrainLoraNode(io.ComfyNode): default=False, tooltip="Enable resolution bucket mode. When enabled, expects pre-bucketed latents from ResolutionBucket node.", ), + io.Boolean.Input( + "bypass_mode", + default=False, + tooltip="Enable bypass mode for training. When enabled, adapters are applied via forward hooks instead of weight modification. Useful for quantized models where weights cannot be directly modified.", + ), ], outputs=[ - io.Model.Output( - display_name="model", tooltip="Model with LoRA applied" - ), io.Custom("LORA_MODEL").Output( display_name="lora", tooltip="LoRA weights" ), @@ -919,6 +984,7 @@ class TrainLoraNode(io.ComfyNode): gradient_checkpointing, existing_lora, bucket_mode, + bypass_mode, ): # Extract scalars from lists (due to is_input_list=True) model = model[0] @@ -936,6 +1002,7 @@ class TrainLoraNode(io.ComfyNode): gradient_checkpointing = gradient_checkpointing[0] existing_lora = existing_lora[0] bucket_mode = bucket_mode[0] + bypass_mode = bypass_mode[0] # Process latents based on mode if bucket_mode: @@ -968,9 +1035,16 @@ class TrainLoraNode(io.ComfyNode): existing_weights, existing_steps = _load_existing_lora(existing_lora) # Setup LoRA adapters - lora_sd, all_weight_adapters = _setup_lora_adapters( - mp, existing_weights, algorithm, lora_dtype, rank - ) + bypass_manager = None + if bypass_mode: + logging.debug("Using bypass mode for training") + lora_sd, all_weight_adapters, bypass_manager = _setup_lora_adapters_bypass( + mp, existing_weights, algorithm, lora_dtype, rank + ) + else: + lora_sd, all_weight_adapters = _setup_lora_adapters( + mp, existing_weights, algorithm, lora_dtype, rank + ) # Create optimizer and loss function optimizer = _create_optimizer( @@ -1029,6 +1103,14 @@ class TrainLoraNode(io.ComfyNode): guider = TrainGuider(mp) guider.set_conds(positive) + # Inject bypass hooks if bypass mode is enabled + bypass_injections = None + if bypass_manager is not None: + bypass_injections = bypass_manager.create_injections(mp.model) + for injection in bypass_injections: + injection.inject(mp) + logging.debug(f"[BypassMode] Injected {bypass_manager.get_hook_count()} bypass hooks") + # Run training loop try: _run_training_loop( @@ -1041,6 +1123,11 @@ class TrainLoraNode(io.ComfyNode): multi_res, ) finally: + # Eject bypass hooks if they were injected + if bypass_injections is not None: + for injection in bypass_injections: + injection.eject(mp) + logging.debug("[BypassMode] Ejected bypass hooks") for m in mp.model.modules(): unpatch(m) del train_sampler, optimizer @@ -1052,7 +1139,9 @@ class TrainLoraNode(io.ComfyNode): for param in lora_sd: lora_sd[param] = lora_sd[param].to(lora_dtype) - return io.NodeOutput(mp, lora_sd, loss_map, steps + existing_steps) + # mp in train node is highly specialized for training + # use it in inference will result in bad behavior so we don't return it + return io.NodeOutput(lora_sd, loss_map, steps + existing_steps) class LoraModelLoader(io.ComfyNode):# diff --git a/nodes.py b/nodes.py index b75247665..8a8df9246 100644 --- a/nodes.py +++ b/nodes.py @@ -722,6 +722,69 @@ class LoraLoaderModelOnly(LoraLoader): def load_lora_model_only(self, model, lora_name, strength_model): return (self.load_lora(model, None, lora_name, strength_model, 0)[0],) +class LoraLoaderBypass: + """ + Apply LoRA in bypass mode without modifying base model weights. + + Bypass mode computes: output = base_forward(x) + lora_path(x) + This is useful for training and when model weights are offloaded. + """ + + def __init__(self): + self.loaded_lora = None + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model": ("MODEL", {"tooltip": "The diffusion model the LoRA will be applied to."}), + "clip": ("CLIP", {"tooltip": "The CLIP model the LoRA will be applied to."}), + "lora_name": (folder_paths.get_filename_list("loras"), {"tooltip": "The name of the LoRA."}), + "strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the diffusion model. This value can be negative."}), + "strength_clip": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the CLIP model. This value can be negative."}), + } + } + + RETURN_TYPES = ("MODEL", "CLIP") + OUTPUT_TOOLTIPS = ("The modified diffusion model.", "The modified CLIP model.") + FUNCTION = "load_lora" + + CATEGORY = "loaders" + DESCRIPTION = "Apply LoRA in bypass mode. Unlike regular LoRA, this doesn't modify model weights - instead it injects the LoRA computation during forward pass. Useful for training scenarios." + + def load_lora(self, model, clip, lora_name, strength_model, strength_clip): + if strength_model == 0 and strength_clip == 0: + return (model, clip) + + lora_path = folder_paths.get_full_path_or_raise("loras", lora_name) + lora = None + if self.loaded_lora is not None: + if self.loaded_lora[0] == lora_path: + lora = self.loaded_lora[1] + else: + self.loaded_lora = None + + if lora is None: + lora = comfy.utils.load_torch_file(lora_path, safe_load=True) + self.loaded_lora = (lora_path, lora) + + model_lora, clip_lora = comfy.sd.load_bypass_lora_for_models(model, clip, lora, strength_model, strength_clip) + return (model_lora, clip_lora) + + +class LoraLoaderBypassModelOnly(LoraLoaderBypass): + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "lora_name": (folder_paths.get_filename_list("loras"), ), + "strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "load_lora_model_only" + + def load_lora_model_only(self, model, lora_name, strength_model): + return (self.load_lora(model, None, lora_name, strength_model, 0)[0],) + class VAELoader: video_taes = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5", "taeltx_2"] image_taes = ["taesd", "taesdxl", "taesd3", "taef1"] @@ -2067,6 +2130,8 @@ NODE_CLASS_MAPPINGS = { "LatentFlip": LatentFlip, "LatentCrop": LatentCrop, "LoraLoader": LoraLoader, + "LoraLoaderBypass": LoraLoaderBypass, + "LoraLoaderBypassModelOnly": LoraLoaderBypassModelOnly, "CLIPLoader": CLIPLoader, "UNETLoader": UNETLoader, "DualCLIPLoader": DualCLIPLoader, @@ -2106,6 +2171,8 @@ NODE_DISPLAY_NAME_MAPPINGS = { "CheckpointLoaderSimple": "Load Checkpoint", "VAELoader": "Load VAE", "LoraLoader": "Load LoRA", + "LoraLoaderBypass": "Load LoRA (Bypass)", + "LoraLoaderBypassModelOnly": "Load LoRA (Bypass, Model Only)", "CLIPLoader": "Load CLIP", "ControlNetLoader": "Load ControlNet Model", "DiffControlNetLoader": "Load ControlNet Model (diff)", From 26c5bbb8751071cb499b65d48e218b54e856572d Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 24 Jan 2026 20:02:32 -0800 Subject: [PATCH 53/58] Move nodes from previous PR into their own file. (#12066) --- comfy_extras/nodes_lora_debug.py | 79 ++++++++++++++++++++++++++++++++ nodes.py | 68 +-------------------------- 2 files changed, 80 insertions(+), 67 deletions(-) create mode 100644 comfy_extras/nodes_lora_debug.py diff --git a/comfy_extras/nodes_lora_debug.py b/comfy_extras/nodes_lora_debug.py new file mode 100644 index 000000000..937a0fbfb --- /dev/null +++ b/comfy_extras/nodes_lora_debug.py @@ -0,0 +1,79 @@ +import folder_paths +import comfy.utils +import comfy.sd + + +class LoraLoaderBypass: + """ + Apply LoRA in bypass mode without modifying base model weights. + + Bypass mode computes: output = base_forward(x) + lora_path(x) + This is useful for training and when model weights are offloaded. + """ + + def __init__(self): + self.loaded_lora = None + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model": ("MODEL", {"tooltip": "The diffusion model the LoRA will be applied to."}), + "clip": ("CLIP", {"tooltip": "The CLIP model the LoRA will be applied to."}), + "lora_name": (folder_paths.get_filename_list("loras"), {"tooltip": "The name of the LoRA."}), + "strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the diffusion model. This value can be negative."}), + "strength_clip": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the CLIP model. This value can be negative."}), + } + } + + RETURN_TYPES = ("MODEL", "CLIP") + OUTPUT_TOOLTIPS = ("The modified diffusion model.", "The modified CLIP model.") + FUNCTION = "load_lora" + + CATEGORY = "loaders" + DESCRIPTION = "Apply LoRA in bypass mode. Unlike regular LoRA, this doesn't modify model weights - instead it injects the LoRA computation during forward pass. Useful for training scenarios." + EXPERIMENTAL = True + + def load_lora(self, model, clip, lora_name, strength_model, strength_clip): + if strength_model == 0 and strength_clip == 0: + return (model, clip) + + lora_path = folder_paths.get_full_path_or_raise("loras", lora_name) + lora = None + if self.loaded_lora is not None: + if self.loaded_lora[0] == lora_path: + lora = self.loaded_lora[1] + else: + self.loaded_lora = None + + if lora is None: + lora = comfy.utils.load_torch_file(lora_path, safe_load=True) + self.loaded_lora = (lora_path, lora) + + model_lora, clip_lora = comfy.sd.load_bypass_lora_for_models(model, clip, lora, strength_model, strength_clip) + return (model_lora, clip_lora) + + +class LoraLoaderBypassModelOnly(LoraLoaderBypass): + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "lora_name": (folder_paths.get_filename_list("loras"), ), + "strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "load_lora_model_only" + + def load_lora_model_only(self, model, lora_name, strength_model): + return (self.load_lora(model, None, lora_name, strength_model, 0)[0],) + + +NODE_CLASS_MAPPINGS = { + "LoraLoaderBypass": LoraLoaderBypass, + "LoraLoaderBypassModelOnly": LoraLoaderBypassModelOnly, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "LoraLoaderBypass": "Load LoRA (Bypass) (For debugging)", + "LoraLoaderBypassModelOnly": "Load LoRA (Bypass, Model Only) (for debugging)", +} diff --git a/nodes.py b/nodes.py index 8a8df9246..2535b4ec6 100644 --- a/nodes.py +++ b/nodes.py @@ -722,69 +722,6 @@ class LoraLoaderModelOnly(LoraLoader): def load_lora_model_only(self, model, lora_name, strength_model): return (self.load_lora(model, None, lora_name, strength_model, 0)[0],) -class LoraLoaderBypass: - """ - Apply LoRA in bypass mode without modifying base model weights. - - Bypass mode computes: output = base_forward(x) + lora_path(x) - This is useful for training and when model weights are offloaded. - """ - - def __init__(self): - self.loaded_lora = None - - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "model": ("MODEL", {"tooltip": "The diffusion model the LoRA will be applied to."}), - "clip": ("CLIP", {"tooltip": "The CLIP model the LoRA will be applied to."}), - "lora_name": (folder_paths.get_filename_list("loras"), {"tooltip": "The name of the LoRA."}), - "strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the diffusion model. This value can be negative."}), - "strength_clip": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the CLIP model. This value can be negative."}), - } - } - - RETURN_TYPES = ("MODEL", "CLIP") - OUTPUT_TOOLTIPS = ("The modified diffusion model.", "The modified CLIP model.") - FUNCTION = "load_lora" - - CATEGORY = "loaders" - DESCRIPTION = "Apply LoRA in bypass mode. Unlike regular LoRA, this doesn't modify model weights - instead it injects the LoRA computation during forward pass. Useful for training scenarios." - - def load_lora(self, model, clip, lora_name, strength_model, strength_clip): - if strength_model == 0 and strength_clip == 0: - return (model, clip) - - lora_path = folder_paths.get_full_path_or_raise("loras", lora_name) - lora = None - if self.loaded_lora is not None: - if self.loaded_lora[0] == lora_path: - lora = self.loaded_lora[1] - else: - self.loaded_lora = None - - if lora is None: - lora = comfy.utils.load_torch_file(lora_path, safe_load=True) - self.loaded_lora = (lora_path, lora) - - model_lora, clip_lora = comfy.sd.load_bypass_lora_for_models(model, clip, lora, strength_model, strength_clip) - return (model_lora, clip_lora) - - -class LoraLoaderBypassModelOnly(LoraLoaderBypass): - @classmethod - def INPUT_TYPES(s): - return {"required": { "model": ("MODEL",), - "lora_name": (folder_paths.get_filename_list("loras"), ), - "strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01}), - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "load_lora_model_only" - - def load_lora_model_only(self, model, lora_name, strength_model): - return (self.load_lora(model, None, lora_name, strength_model, 0)[0],) - class VAELoader: video_taes = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5", "taeltx_2"] image_taes = ["taesd", "taesdxl", "taesd3", "taef1"] @@ -2130,8 +2067,6 @@ NODE_CLASS_MAPPINGS = { "LatentFlip": LatentFlip, "LatentCrop": LatentCrop, "LoraLoader": LoraLoader, - "LoraLoaderBypass": LoraLoaderBypass, - "LoraLoaderBypassModelOnly": LoraLoaderBypassModelOnly, "CLIPLoader": CLIPLoader, "UNETLoader": UNETLoader, "DualCLIPLoader": DualCLIPLoader, @@ -2171,8 +2106,6 @@ NODE_DISPLAY_NAME_MAPPINGS = { "CheckpointLoaderSimple": "Load Checkpoint", "VAELoader": "Load VAE", "LoraLoader": "Load LoRA", - "LoraLoaderBypass": "Load LoRA (Bypass)", - "LoraLoaderBypassModelOnly": "Load LoRA (Bypass, Model Only)", "CLIPLoader": "Load CLIP", "ControlNetLoader": "Load ControlNet Model", "DiffControlNetLoader": "Load ControlNet Model (diff)", @@ -2498,6 +2431,7 @@ async def init_builtin_extra_nodes(): "nodes_wanmove.py", "nodes_image_compare.py", "nodes_zimage.py", + "nodes_lora_debug.py" ] import_failed = [] From 7ee77ff038937bdfdbea5d603ad8d4c487c14fd6 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 25 Jan 2026 18:01:55 -0800 Subject: [PATCH 54/58] Add name to LoraLoaderModelOnly. (#12078) --- nodes.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/nodes.py b/nodes.py index 2535b4ec6..ad474d3cd 100644 --- a/nodes.py +++ b/nodes.py @@ -2105,7 +2105,8 @@ NODE_DISPLAY_NAME_MAPPINGS = { "CheckpointLoader": "Load Checkpoint With Config (DEPRECATED)", "CheckpointLoaderSimple": "Load Checkpoint", "VAELoader": "Load VAE", - "LoraLoader": "Load LoRA", + "LoraLoader": "Load LoRA (Model and CLIP)", + "LoraLoaderModelOnly": "Load LoRA", "CLIPLoader": "Load CLIP", "ControlNetLoader": "Load ControlNet Model", "DiffControlNetLoader": "Load ControlNet Model (diff)", From 2129e7d27854057737808438ec5b9db195bb81bb Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 26 Jan 2026 08:39:00 -0800 Subject: [PATCH 55/58] Fix mistral 3 tokenizer code failing on latest transformers version and other breakage. (#12095) * Fix mistral 3 tokenizer code failing on latest transformers version. * Add requests to the requirements --- comfy/sd1_clip.py | 15 +++++++++++---- comfy/text_encoders/flux.py | 2 +- requirements.txt | 1 + 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index c512ca5d0..d4f22120b 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -466,7 +466,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No return embed_out class SDTokenizer: - def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, pad_left=False, disable_weights=False, tokenizer_data={}, tokenizer_args={}): + def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, start_token=None, min_padding=None, pad_left=False, disable_weights=False, tokenizer_data={}, tokenizer_args={}): if tokenizer_path is None: tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer") self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args) @@ -479,8 +479,15 @@ class SDTokenizer: empty = self.tokenizer('')["input_ids"] self.tokenizer_adds_end_token = has_end_token if has_start_token: - self.tokens_start = 1 - self.start_token = empty[0] + if len(empty) > 0: + self.tokens_start = 1 + self.start_token = empty[0] + else: + self.tokens_start = 0 + self.start_token = start_token + if start_token is None: + logging.warning("WARNING: There's something wrong with your tokenizers.'") + if end_token is not None: self.end_token = end_token else: @@ -488,7 +495,7 @@ class SDTokenizer: self.end_token = empty[1] else: self.tokens_start = 0 - self.start_token = None + self.start_token = start_token if end_token is not None: self.end_token = end_token else: diff --git a/comfy/text_encoders/flux.py b/comfy/text_encoders/flux.py index 4075afca4..f67a5f805 100644 --- a/comfy/text_encoders/flux.py +++ b/comfy/text_encoders/flux.py @@ -118,7 +118,7 @@ class MistralTokenizerClass: class Mistral3Tokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): self.tekken_data = tokenizer_data.get("tekken_model", None) - super().__init__("", pad_with_end=False, embedding_size=5120, embedding_key='mistral3_24b', tokenizer_class=MistralTokenizerClass, has_end_token=False, pad_to_max_length=False, pad_token=11, max_length=99999999, min_length=1, pad_left=True, tokenizer_args=load_mistral_tokenizer(self.tekken_data), tokenizer_data=tokenizer_data) + super().__init__("", pad_with_end=False, embedding_size=5120, embedding_key='mistral3_24b', tokenizer_class=MistralTokenizerClass, has_end_token=False, pad_to_max_length=False, pad_token=11, start_token=1, max_length=99999999, min_length=1, pad_left=True, tokenizer_args=load_mistral_tokenizer(self.tekken_data), tokenizer_data=tokenizer_data) def state_dict(self): return {"tekken_model": self.tekken_data} diff --git a/requirements.txt b/requirements.txt index ec89dccd2..8d38c114b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,6 +22,7 @@ alembic SQLAlchemy av>=14.2.0 comfy-kitchen>=0.2.7 +requests #non essential dependencies: kornia>=0.7.1 From bfe31d0b9dce106cbb3bc073660d4ab1d7d9e992 Mon Sep 17 00:00:00 2001 From: Tavi Halperin Date: Mon, 26 Jan 2026 22:33:19 +0200 Subject: [PATCH 56/58] IC-LoRA: support small grid (#12074) --- comfy_extras/nodes_lt.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/comfy_extras/nodes_lt.py b/comfy_extras/nodes_lt.py index b91a22309..2aec62f61 100644 --- a/comfy_extras/nodes_lt.py +++ b/comfy_extras/nodes_lt.py @@ -223,11 +223,24 @@ class LTXVAddGuide(io.ComfyNode): return frame_idx, latent_idx @classmethod - def add_keyframe_index(cls, cond, frame_idx, guiding_latent, scale_factors): + def add_keyframe_index(cls, cond, frame_idx, guiding_latent, scale_factors, latent_downscale_factor=1): keyframe_idxs, _ = get_keyframe_idxs(cond) _, latent_coords = cls.PATCHIFIER.patchify(guiding_latent) pixel_coords = latent_to_pixel_coords(latent_coords, scale_factors, causal_fix=frame_idx == 0) # we need the causal fix only if we're placing the new latents at index 0 pixel_coords[:, 0] += frame_idx + + # The following adjusts keyframe end positions for small grid IC-LoRA. + # After dilation, the small grid has the same size and position as the large grid, + # but each token encodes a larger image patch. We adjust the end position (not start) + # so that RoPE represents the correct middle point of each token. + # keyframe_idxs dims: (batch, spatial_dim [t,h,w], token_id, [start, end]) + # We only adjust h,w (not t) in dim 1, and only end (not start) in dim 3. + spatial_end_offset = (latent_downscale_factor - 1) * torch.tensor( + scale_factors[1:], + device=pixel_coords.device, + ).view(1, -1, 1, 1) + pixel_coords[:, 1:, :, 1:] += spatial_end_offset.to(pixel_coords.dtype) + if keyframe_idxs is None: keyframe_idxs = pixel_coords else: @@ -235,12 +248,12 @@ class LTXVAddGuide(io.ComfyNode): return node_helpers.conditioning_set_values(cond, {"keyframe_idxs": keyframe_idxs}) @classmethod - def append_keyframe(cls, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors, guide_mask=None, in_channels=128): + def append_keyframe(cls, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors, guide_mask=None, in_channels=128, latent_downscale_factor=1): if latent_image.shape[1] != in_channels or guiding_latent.shape[1] != in_channels: raise ValueError("Adding guide to a combined AV latent is not supported.") - positive = cls.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors) - negative = cls.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors) + positive = cls.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors, latent_downscale_factor) + negative = cls.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors, latent_downscale_factor) if guide_mask is not None: target_h = max(noise_mask.shape[3], guide_mask.shape[3]) From cd4985e2f33e6c339a6f176b2caed155309a1c6f Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Mon, 26 Jan 2026 23:58:33 +0200 Subject: [PATCH 57/58] chore(api-nodes): remove ByteDanceImageEditNode node (seededit) (#12069) Co-authored-by: Jedrzej Kosinski --- comfy_api_nodes/apis/bytedance.py | 11 ---- comfy_api_nodes/nodes_bytedance.py | 95 ------------------------------ 2 files changed, 106 deletions(-) diff --git a/comfy_api_nodes/apis/bytedance.py b/comfy_api_nodes/apis/bytedance.py index 400648cca..23cbe2372 100644 --- a/comfy_api_nodes/apis/bytedance.py +++ b/comfy_api_nodes/apis/bytedance.py @@ -13,17 +13,6 @@ class Text2ImageTaskCreationRequest(BaseModel): watermark: bool | None = Field(False) -class Image2ImageTaskCreationRequest(BaseModel): - model: str = Field(...) - prompt: str = Field(...) - response_format: str | None = Field("url") - image: str = Field(..., description="Base64 encoded string or image URL") - size: str | None = Field("adaptive") - seed: int | None = Field(..., ge=0, le=2147483647) - guidance_scale: float | None = Field(..., ge=1.0, le=10.0) - watermark: bool | None = Field(False) - - class Seedream4Options(BaseModel): max_images: int = Field(15) diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py index 486801150..0cb5e3be8 100644 --- a/comfy_api_nodes/nodes_bytedance.py +++ b/comfy_api_nodes/nodes_bytedance.py @@ -9,7 +9,6 @@ from comfy_api_nodes.apis.bytedance import ( RECOMMENDED_PRESETS, RECOMMENDED_PRESETS_SEEDREAM_4, VIDEO_TASKS_EXECUTION_TIME, - Image2ImageTaskCreationRequest, Image2VideoTaskCreationRequest, ImageTaskCreationResponse, Seedream4Options, @@ -174,99 +173,6 @@ class ByteDanceImageNode(IO.ComfyNode): return IO.NodeOutput(await download_url_to_image_tensor(get_image_url_from_response(response))) -class ByteDanceImageEditNode(IO.ComfyNode): - - @classmethod - def define_schema(cls): - return IO.Schema( - node_id="ByteDanceImageEditNode", - display_name="ByteDance Image Edit", - category="api node/image/ByteDance", - description="Edit images using ByteDance models via api based on prompt", - inputs=[ - IO.Combo.Input("model", options=["seededit-3-0-i2i-250628"]), - IO.Image.Input( - "image", - tooltip="The base image to edit", - ), - IO.String.Input( - "prompt", - multiline=True, - default="", - tooltip="Instruction to edit image", - ), - IO.Int.Input( - "seed", - default=0, - min=0, - max=2147483647, - step=1, - display_mode=IO.NumberDisplay.number, - control_after_generate=True, - tooltip="Seed to use for generation", - optional=True, - ), - IO.Float.Input( - "guidance_scale", - default=5.5, - min=1.0, - max=10.0, - step=0.01, - display_mode=IO.NumberDisplay.number, - tooltip="Higher value makes the image follow the prompt more closely", - optional=True, - ), - IO.Boolean.Input( - "watermark", - default=False, - tooltip='Whether to add an "AI generated" watermark to the image', - optional=True, - ), - ], - outputs=[ - IO.Image.Output(), - ], - hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, - IO.Hidden.unique_id, - ], - is_api_node=True, - is_deprecated=True, - ) - - @classmethod - async def execute( - cls, - model: str, - image: Input.Image, - prompt: str, - seed: int, - guidance_scale: float, - watermark: bool, - ) -> IO.NodeOutput: - validate_string(prompt, strip_whitespace=True, min_length=1) - if get_number_of_images(image) != 1: - raise ValueError("Exactly one input image is required.") - validate_image_aspect_ratio(image, (1, 3), (3, 1)) - source_url = (await upload_images_to_comfyapi(cls, image, max_images=1, mime_type="image/png"))[0] - payload = Image2ImageTaskCreationRequest( - model=model, - prompt=prompt, - image=source_url, - seed=seed, - guidance_scale=guidance_scale, - watermark=watermark, - ) - response = await sync_op( - cls, - ApiEndpoint(path=BYTEPLUS_IMAGE_ENDPOINT, method="POST"), - data=payload, - response_model=ImageTaskCreationResponse, - ) - return IO.NodeOutput(await download_url_to_image_tensor(get_image_url_from_response(response))) - - class ByteDanceSeedreamNode(IO.ComfyNode): @classmethod @@ -1101,7 +1007,6 @@ class ByteDanceExtension(ComfyExtension): async def get_node_list(self) -> list[type[IO.ComfyNode]]: return [ ByteDanceImageNode, - ByteDanceImageEditNode, ByteDanceSeedreamNode, ByteDanceTextToVideoNode, ByteDanceImageToVideoNode, From 29011ba87eb2131c7943bf0eaf9ac8c0a6ff3c7f Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Tue, 27 Jan 2026 00:10:09 +0200 Subject: [PATCH 58/58] [API Nodes] add Magnific nodes (#11986) * feat(api-nodes): add Magnific nodes * aggressive downscaling should not be performed * disable upscaler nodes --------- Co-authored-by: Jedrzej Kosinski --- comfy_api_nodes/apis/magnific.py | 122 ++++ comfy_api_nodes/nodes_magnific.py | 889 +++++++++++++++++++++++++ comfy_api_nodes/util/conversions.py | 18 +- comfy_api_nodes/util/upload_helpers.py | 2 +- 4 files changed, 1021 insertions(+), 10 deletions(-) create mode 100644 comfy_api_nodes/apis/magnific.py create mode 100644 comfy_api_nodes/nodes_magnific.py diff --git a/comfy_api_nodes/apis/magnific.py b/comfy_api_nodes/apis/magnific.py new file mode 100644 index 000000000..b9f148def --- /dev/null +++ b/comfy_api_nodes/apis/magnific.py @@ -0,0 +1,122 @@ +from typing import TypedDict + +from pydantic import AliasChoices, BaseModel, Field, model_validator + + +class InputPortraitMode(TypedDict): + portrait_mode: str + portrait_style: str + portrait_beautifier: str + + +class InputAdvancedSettings(TypedDict): + advanced_settings: str + whites: int + blacks: int + brightness: int + contrast: int + saturation: int + engine: str + transfer_light_a: str + transfer_light_b: str + fixed_generation: bool + + +class InputSkinEnhancerMode(TypedDict): + mode: str + skin_detail: int + optimized_for: str + + +class ImageUpscalerCreativeRequest(BaseModel): + image: str = Field(...) + scale_factor: str = Field(...) + optimized_for: str = Field(...) + prompt: str | None = Field(None) + creativity: int = Field(...) + hdr: int = Field(...) + resemblance: int = Field(...) + fractality: int = Field(...) + engine: str = Field(...) + + +class ImageUpscalerPrecisionV2Request(BaseModel): + image: str = Field(...) + sharpen: int = Field(...) + smart_grain: int = Field(...) + ultra_detail: int = Field(...) + flavor: str = Field(...) + scale_factor: int = Field(...) + + +class ImageRelightAdvancedSettingsRequest(BaseModel): + whites: int = Field(...) + blacks: int = Field(...) + brightness: int = Field(...) + contrast: int = Field(...) + saturation: int = Field(...) + engine: str = Field(...) + transfer_light_a: str = Field(...) + transfer_light_b: str = Field(...) + fixed_generation: bool = Field(...) + + +class ImageRelightRequest(BaseModel): + image: str = Field(...) + prompt: str | None = Field(None) + transfer_light_from_reference_image: str | None = Field(None) + light_transfer_strength: int = Field(...) + interpolate_from_original: bool = Field(...) + change_background: bool = Field(...) + style: str = Field(...) + preserve_details: bool = Field(...) + advanced_settings: ImageRelightAdvancedSettingsRequest | None = Field(...) + + +class ImageStyleTransferRequest(BaseModel): + image: str = Field(...) + reference_image: str = Field(...) + prompt: str | None = Field(None) + style_strength: int = Field(...) + structure_strength: int = Field(...) + is_portrait: bool = Field(...) + portrait_style: str | None = Field(...) + portrait_beautifier: str | None = Field(...) + flavor: str = Field(...) + engine: str = Field(...) + fixed_generation: bool = Field(...) + + +class ImageSkinEnhancerCreativeRequest(BaseModel): + image: str = Field(...) + sharpen: int = Field(...) + smart_grain: int = Field(...) + + +class ImageSkinEnhancerFaithfulRequest(BaseModel): + image: str = Field(...) + sharpen: int = Field(...) + smart_grain: int = Field(...) + skin_detail: int = Field(...) + + +class ImageSkinEnhancerFlexibleRequest(BaseModel): + image: str = Field(...) + sharpen: int = Field(...) + smart_grain: int = Field(...) + optimized_for: str = Field(...) + + +class TaskResponse(BaseModel): + """Unified response model that handles both wrapped and unwrapped API responses.""" + + task_id: str = Field(...) + status: str = Field(validation_alias=AliasChoices("status", "task_status")) + generated: list[str] | None = Field(None) + + @model_validator(mode="before") + @classmethod + def unwrap_data(cls, values: dict) -> dict: + if "data" in values and isinstance(values["data"], dict): + return values["data"] + return values diff --git a/comfy_api_nodes/nodes_magnific.py b/comfy_api_nodes/nodes_magnific.py new file mode 100644 index 000000000..013e71cc8 --- /dev/null +++ b/comfy_api_nodes/nodes_magnific.py @@ -0,0 +1,889 @@ +import math + +from typing_extensions import override + +from comfy_api.latest import IO, ComfyExtension, Input +from comfy_api_nodes.apis.magnific import ( + ImageRelightAdvancedSettingsRequest, + ImageRelightRequest, + ImageSkinEnhancerCreativeRequest, + ImageSkinEnhancerFaithfulRequest, + ImageSkinEnhancerFlexibleRequest, + ImageStyleTransferRequest, + ImageUpscalerCreativeRequest, + ImageUpscalerPrecisionV2Request, + InputAdvancedSettings, + InputPortraitMode, + InputSkinEnhancerMode, + TaskResponse, +) +from comfy_api_nodes.util import ( + ApiEndpoint, + download_url_to_image_tensor, + downscale_image_tensor, + get_image_dimensions, + get_number_of_images, + poll_op, + sync_op, + upload_images_to_comfyapi, + validate_image_aspect_ratio, + validate_image_dimensions, +) + + +class MagnificImageUpscalerCreativeNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="MagnificImageUpscalerCreativeNode", + display_name="Magnific Image Upscale (Creative)", + category="api node/image/Magnific", + description="Prompt‑guided enhancement, stylization, and 2x/4x/8x/16x upscaling. " + "Maximum output: 25.3 megapixels.", + inputs=[ + IO.Image.Input("image"), + IO.String.Input("prompt", multiline=True, default=""), + IO.Combo.Input("scale_factor", options=["2x", "4x", "8x", "16x"]), + IO.Combo.Input( + "optimized_for", + options=[ + "standard", + "soft_portraits", + "hard_portraits", + "art_n_illustration", + "videogame_assets", + "nature_n_landscapes", + "films_n_photography", + "3d_renders", + "science_fiction_n_horror", + ], + ), + IO.Int.Input("creativity", min=-10, max=10, default=0, display_mode=IO.NumberDisplay.slider), + IO.Int.Input( + "hdr", + min=-10, + max=10, + default=0, + tooltip="The level of definition and detail.", + display_mode=IO.NumberDisplay.slider, + ), + IO.Int.Input( + "resemblance", + min=-10, + max=10, + default=0, + tooltip="The level of resemblance to the original image.", + display_mode=IO.NumberDisplay.slider, + ), + IO.Int.Input( + "fractality", + min=-10, + max=10, + default=0, + tooltip="The strength of the prompt and intricacy per square pixel.", + display_mode=IO.NumberDisplay.slider, + ), + IO.Combo.Input( + "engine", + options=["automatic", "magnific_illusio", "magnific_sharpy", "magnific_sparkle"], + ), + IO.Boolean.Input( + "auto_downscale", + default=False, + tooltip="Automatically downscale input image if output would exceed maximum pixel limit.", + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["scale_factor"]), + expr=""" + ( + $max := widgets.scale_factor = "2x" ? 1.326 : 1.657; + {"type": "range_usd", "min_usd": 0.11, "max_usd": $max} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + image: Input.Image, + prompt: str, + scale_factor: str, + optimized_for: str, + creativity: int, + hdr: int, + resemblance: int, + fractality: int, + engine: str, + auto_downscale: bool, + ) -> IO.NodeOutput: + if get_number_of_images(image) != 1: + raise ValueError("Exactly one input image is required.") + validate_image_aspect_ratio(image, (1, 3), (3, 1), strict=False) + validate_image_dimensions(image, min_height=160, min_width=160) + + max_output_pixels = 25_300_000 + height, width = get_image_dimensions(image) + requested_scale = int(scale_factor.rstrip("x")) + output_pixels = height * width * requested_scale * requested_scale + + if output_pixels > max_output_pixels: + if auto_downscale: + # Find optimal scale factor that doesn't require >2x downscale. + # Server upscales in 2x steps, so aggressive downscaling degrades quality. + input_pixels = width * height + scale = 2 + max_input_pixels = max_output_pixels // 4 + for candidate in [16, 8, 4, 2]: + if candidate > requested_scale: + continue + scale_output_pixels = input_pixels * candidate * candidate + if scale_output_pixels <= max_output_pixels: + scale = candidate + max_input_pixels = None + break + downscale_ratio = math.sqrt(scale_output_pixels / max_output_pixels) + if downscale_ratio <= 2.0: + scale = candidate + max_input_pixels = max_output_pixels // (candidate * candidate) + break + + if max_input_pixels is not None: + image = downscale_image_tensor(image, total_pixels=max_input_pixels) + scale_factor = f"{scale}x" + else: + raise ValueError( + f"Output size ({width * requested_scale}x{height * requested_scale} = {output_pixels:,} pixels) " + f"exceeds maximum allowed size of {max_output_pixels:,} pixels. " + f"Use a smaller input image or lower scale factor." + ) + + initial_res = await sync_op( + cls, + ApiEndpoint(path="/proxy/freepik/v1/ai/image-upscaler", method="POST"), + response_model=TaskResponse, + data=ImageUpscalerCreativeRequest( + image=(await upload_images_to_comfyapi(cls, image, max_images=1, total_pixels=None))[0], + scale_factor=scale_factor, + optimized_for=optimized_for, + creativity=creativity, + hdr=hdr, + resemblance=resemblance, + fractality=fractality, + engine=engine, + prompt=prompt if prompt else None, + ), + ) + final_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/freepik/v1/ai/image-upscaler/{initial_res.task_id}"), + response_model=TaskResponse, + status_extractor=lambda x: x.status, + poll_interval=10.0, + max_poll_attempts=480, + ) + return IO.NodeOutput(await download_url_to_image_tensor(final_response.generated[0])) + + +class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="MagnificImageUpscalerPreciseV2Node", + display_name="Magnific Image Upscale (Precise V2)", + category="api node/image/Magnific", + description="High-fidelity upscaling with fine control over sharpness, grain, and detail. " + "Maximum output: 10060×10060 pixels.", + inputs=[ + IO.Image.Input("image"), + IO.Combo.Input("scale_factor", options=["2x", "4x", "8x", "16x"]), + IO.Combo.Input( + "flavor", + options=["sublime", "photo", "photo_denoiser"], + tooltip="Processing style: " + "sublime for general use, photo for photographs, photo_denoiser for noisy photos.", + ), + IO.Int.Input( + "sharpen", + min=0, + max=100, + default=7, + tooltip="Image sharpness intensity. Higher values increase edge definition and clarity.", + display_mode=IO.NumberDisplay.slider, + ), + IO.Int.Input( + "smart_grain", + min=0, + max=100, + default=7, + tooltip="Intelligent grain/texture enhancement to prevent the image from " + "looking too smooth or artificial.", + display_mode=IO.NumberDisplay.slider, + ), + IO.Int.Input( + "ultra_detail", + min=0, + max=100, + default=30, + tooltip="Controls fine detail, textures, and micro-details added during upscaling.", + display_mode=IO.NumberDisplay.slider, + ), + IO.Boolean.Input( + "auto_downscale", + default=False, + tooltip="Automatically downscale input image if output would exceed maximum resolution.", + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["scale_factor"]), + expr=""" + ( + $max := widgets.scale_factor = "2x" ? 1.326 : 1.657; + {"type": "range_usd", "min_usd": 0.11, "max_usd": $max} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + image: Input.Image, + scale_factor: str, + flavor: str, + sharpen: int, + smart_grain: int, + ultra_detail: int, + auto_downscale: bool, + ) -> IO.NodeOutput: + if get_number_of_images(image) != 1: + raise ValueError("Exactly one input image is required.") + validate_image_aspect_ratio(image, (1, 3), (3, 1), strict=False) + validate_image_dimensions(image, min_height=160, min_width=160) + + max_output_dimension = 10060 + height, width = get_image_dimensions(image) + requested_scale = int(scale_factor.strip("x")) + output_width = width * requested_scale + output_height = height * requested_scale + + if output_width > max_output_dimension or output_height > max_output_dimension: + if auto_downscale: + # Find optimal scale factor that doesn't require >2x downscale. + # Server upscales in 2x steps, so aggressive downscaling degrades quality. + max_dim = max(width, height) + scale = 2 + max_input_dim = max_output_dimension // 2 + scale_ratio = max_input_dim / max_dim + max_input_pixels = int(width * height * scale_ratio * scale_ratio) + for candidate in [16, 8, 4, 2]: + if candidate > requested_scale: + continue + output_dim = max_dim * candidate + if output_dim <= max_output_dimension: + scale = candidate + max_input_pixels = None + break + downscale_ratio = output_dim / max_output_dimension + if downscale_ratio <= 2.0: + scale = candidate + max_input_dim = max_output_dimension // candidate + scale_ratio = max_input_dim / max_dim + max_input_pixels = int(width * height * scale_ratio * scale_ratio) + break + + if max_input_pixels is not None: + image = downscale_image_tensor(image, total_pixels=max_input_pixels) + requested_scale = scale + else: + raise ValueError( + f"Output dimensions ({output_width}x{output_height}) exceed maximum allowed " + f"resolution of {max_output_dimension}x{max_output_dimension} pixels. " + f"Use a smaller input image or lower scale factor." + ) + + initial_res = await sync_op( + cls, + ApiEndpoint(path="/proxy/freepik/v1/ai/image-upscaler-precision-v2", method="POST"), + response_model=TaskResponse, + data=ImageUpscalerPrecisionV2Request( + image=(await upload_images_to_comfyapi(cls, image, max_images=1, total_pixels=None))[0], + scale_factor=requested_scale, + flavor=flavor, + sharpen=sharpen, + smart_grain=smart_grain, + ultra_detail=ultra_detail, + ), + ) + final_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/freepik/v1/ai/image-upscaler-precision-v2/{initial_res.task_id}"), + response_model=TaskResponse, + status_extractor=lambda x: x.status, + poll_interval=10.0, + max_poll_attempts=480, + ) + return IO.NodeOutput(await download_url_to_image_tensor(final_response.generated[0])) + + +class MagnificImageStyleTransferNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="MagnificImageStyleTransferNode", + display_name="Magnific Image Style Transfer", + category="api node/image/Magnific", + description="Transfer the style from a reference image to your input image.", + inputs=[ + IO.Image.Input("image", tooltip="The image to apply style transfer to."), + IO.Image.Input("reference_image", tooltip="The reference image to extract style from."), + IO.String.Input("prompt", multiline=True, default=""), + IO.Int.Input( + "style_strength", + min=0, + max=100, + default=100, + tooltip="Percentage of style strength.", + display_mode=IO.NumberDisplay.slider, + ), + IO.Int.Input( + "structure_strength", + min=0, + max=100, + default=50, + tooltip="Maintains the structure of the original image.", + display_mode=IO.NumberDisplay.slider, + ), + IO.Combo.Input( + "flavor", + options=["faithful", "gen_z", "psychedelia", "detaily", "clear", "donotstyle", "donotstyle_sharp"], + tooltip="Style transfer flavor.", + ), + IO.Combo.Input( + "engine", + options=[ + "balanced", + "definio", + "illusio", + "3d_cartoon", + "colorful_anime", + "caricature", + "real", + "super_real", + "softy", + ], + tooltip="Processing engine selection.", + ), + IO.DynamicCombo.Input( + "portrait_mode", + options=[ + IO.DynamicCombo.Option("disabled", []), + IO.DynamicCombo.Option( + "enabled", + [ + IO.Combo.Input( + "portrait_style", + options=["standard", "pop", "super_pop"], + tooltip="Visual style applied to portrait images.", + ), + IO.Combo.Input( + "portrait_beautifier", + options=["none", "beautify_face", "beautify_face_max"], + tooltip="Facial beautification intensity on portraits.", + ), + ], + ), + ], + tooltip="Enable portrait mode for facial enhancements.", + ), + IO.Boolean.Input( + "fixed_generation", + default=True, + tooltip="When disabled, expect each generation to introduce a degree of randomness, " + "leading to more diverse outcomes.", + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.11}""", + ), + ) + + @classmethod + async def execute( + cls, + image: Input.Image, + reference_image: Input.Image, + prompt: str, + style_strength: int, + structure_strength: int, + flavor: str, + engine: str, + portrait_mode: InputPortraitMode, + fixed_generation: bool, + ) -> IO.NodeOutput: + if get_number_of_images(image) != 1: + raise ValueError("Exactly one input image is required.") + if get_number_of_images(reference_image) != 1: + raise ValueError("Exactly one reference image is required.") + validate_image_aspect_ratio(image, (1, 3), (3, 1), strict=False) + validate_image_aspect_ratio(reference_image, (1, 3), (3, 1), strict=False) + validate_image_dimensions(image, min_height=160, min_width=160) + validate_image_dimensions(reference_image, min_height=160, min_width=160) + + is_portrait = portrait_mode["portrait_mode"] == "enabled" + portrait_style = portrait_mode.get("portrait_style", "standard") + portrait_beautifier = portrait_mode.get("portrait_beautifier", "none") + + uploaded_urls = await upload_images_to_comfyapi(cls, [image, reference_image], max_images=2) + + initial_res = await sync_op( + cls, + ApiEndpoint(path="/proxy/freepik/v1/ai/image-style-transfer", method="POST"), + response_model=TaskResponse, + data=ImageStyleTransferRequest( + image=uploaded_urls[0], + reference_image=uploaded_urls[1], + prompt=prompt if prompt else None, + style_strength=style_strength, + structure_strength=structure_strength, + is_portrait=is_portrait, + portrait_style=portrait_style if is_portrait else None, + portrait_beautifier=portrait_beautifier if is_portrait and portrait_beautifier != "none" else None, + flavor=flavor, + engine=engine, + fixed_generation=fixed_generation, + ), + ) + final_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/freepik/v1/ai/image-style-transfer/{initial_res.task_id}"), + response_model=TaskResponse, + status_extractor=lambda x: x.status, + poll_interval=10.0, + max_poll_attempts=480, + ) + return IO.NodeOutput(await download_url_to_image_tensor(final_response.generated[0])) + + +class MagnificImageRelightNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="MagnificImageRelightNode", + display_name="Magnific Image Relight", + category="api node/image/Magnific", + description="Relight an image with lighting adjustments and optional reference-based light transfer.", + inputs=[ + IO.Image.Input("image", tooltip="The image to relight."), + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Descriptive guidance for lighting. Supports emphasis notation (1-1.4).", + ), + IO.Int.Input( + "light_transfer_strength", + min=0, + max=100, + default=100, + tooltip="Intensity of light transfer application.", + display_mode=IO.NumberDisplay.slider, + ), + IO.Combo.Input( + "style", + options=[ + "standard", + "darker_but_realistic", + "clean", + "smooth", + "brighter", + "contrasted_n_hdr", + "just_composition", + ], + tooltip="Stylistic output preference.", + ), + IO.Boolean.Input( + "interpolate_from_original", + default=False, + tooltip="Restricts generation freedom to match original more closely.", + ), + IO.Boolean.Input( + "change_background", + default=True, + tooltip="Modifies background based on prompt/reference.", + ), + IO.Boolean.Input( + "preserve_details", + default=True, + tooltip="Maintains texture and fine details from original.", + ), + IO.DynamicCombo.Input( + "advanced_settings", + options=[ + IO.DynamicCombo.Option("disabled", []), + IO.DynamicCombo.Option( + "enabled", + [ + IO.Int.Input( + "whites", + min=0, + max=100, + default=50, + tooltip="Adjusts the brightest tones in the image.", + display_mode=IO.NumberDisplay.slider, + ), + IO.Int.Input( + "blacks", + min=0, + max=100, + default=50, + tooltip="Adjusts the darkest tones in the image.", + display_mode=IO.NumberDisplay.slider, + ), + IO.Int.Input( + "brightness", + min=0, + max=100, + default=50, + tooltip="Overall brightness adjustment.", + display_mode=IO.NumberDisplay.slider, + ), + IO.Int.Input( + "contrast", + min=0, + max=100, + default=50, + tooltip="Contrast adjustment.", + display_mode=IO.NumberDisplay.slider, + ), + IO.Int.Input( + "saturation", + min=0, + max=100, + default=50, + tooltip="Color saturation adjustment.", + display_mode=IO.NumberDisplay.slider, + ), + IO.Combo.Input( + "engine", + options=[ + "automatic", + "balanced", + "cool", + "real", + "illusio", + "fairy", + "colorful_anime", + "hard_transform", + "softy", + ], + tooltip="Processing engine selection.", + ), + IO.Combo.Input( + "transfer_light_a", + options=["automatic", "low", "medium", "normal", "high", "high_on_faces"], + tooltip="The intensity of light transfer.", + ), + IO.Combo.Input( + "transfer_light_b", + options=[ + "automatic", + "composition", + "straight", + "smooth_in", + "smooth_out", + "smooth_both", + "reverse_both", + "soft_in", + "soft_out", + "soft_mid", + # "strong_mid", # Commented out because requests fail when this is set. + "style_shift", + "strong_shift", + ], + tooltip="Also modifies light transfer intensity. " + "Can be combined with the previous control for varied effects.", + ), + IO.Boolean.Input( + "fixed_generation", + default=True, + tooltip="Ensures consistent output with the same settings.", + ), + ], + ), + ], + tooltip="Fine-tuning options for advanced lighting control.", + ), + IO.Image.Input( + "reference_image", + optional=True, + tooltip="Optional reference image to transfer lighting from.", + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.11}""", + ), + ) + + @classmethod + async def execute( + cls, + image: Input.Image, + prompt: str, + light_transfer_strength: int, + style: str, + interpolate_from_original: bool, + change_background: bool, + preserve_details: bool, + advanced_settings: InputAdvancedSettings, + reference_image: Input.Image | None = None, + ) -> IO.NodeOutput: + if get_number_of_images(image) != 1: + raise ValueError("Exactly one input image is required.") + if reference_image is not None and get_number_of_images(reference_image) != 1: + raise ValueError("Exactly one reference image is required.") + validate_image_aspect_ratio(image, (1, 3), (3, 1), strict=False) + validate_image_dimensions(image, min_height=160, min_width=160) + if reference_image is not None: + validate_image_aspect_ratio(reference_image, (1, 3), (3, 1), strict=False) + validate_image_dimensions(reference_image, min_height=160, min_width=160) + + image_url = (await upload_images_to_comfyapi(cls, image, max_images=1))[0] + reference_url = None + if reference_image is not None: + reference_url = (await upload_images_to_comfyapi(cls, reference_image, max_images=1))[0] + + adv_settings = None + if advanced_settings["advanced_settings"] == "enabled": + adv_settings = ImageRelightAdvancedSettingsRequest( + whites=advanced_settings["whites"], + blacks=advanced_settings["blacks"], + brightness=advanced_settings["brightness"], + contrast=advanced_settings["contrast"], + saturation=advanced_settings["saturation"], + engine=advanced_settings["engine"], + transfer_light_a=advanced_settings["transfer_light_a"], + transfer_light_b=advanced_settings["transfer_light_b"], + fixed_generation=advanced_settings["fixed_generation"], + ) + + initial_res = await sync_op( + cls, + ApiEndpoint(path="/proxy/freepik/v1/ai/image-relight", method="POST"), + response_model=TaskResponse, + data=ImageRelightRequest( + image=image_url, + prompt=prompt if prompt else None, + transfer_light_from_reference_image=reference_url, + light_transfer_strength=light_transfer_strength, + interpolate_from_original=interpolate_from_original, + change_background=change_background, + style=style, + preserve_details=preserve_details, + advanced_settings=adv_settings, + ), + ) + final_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/freepik/v1/ai/image-relight/{initial_res.task_id}"), + response_model=TaskResponse, + status_extractor=lambda x: x.status, + poll_interval=10.0, + max_poll_attempts=480, + ) + return IO.NodeOutput(await download_url_to_image_tensor(final_response.generated[0])) + + +class MagnificImageSkinEnhancerNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="MagnificImageSkinEnhancerNode", + display_name="Magnific Image Skin Enhancer", + category="api node/image/Magnific", + description="Skin enhancement for portraits with multiple processing modes.", + inputs=[ + IO.Image.Input("image", tooltip="The portrait image to enhance."), + IO.Int.Input( + "sharpen", + min=0, + max=100, + default=0, + tooltip="Sharpening intensity level.", + display_mode=IO.NumberDisplay.slider, + ), + IO.Int.Input( + "smart_grain", + min=0, + max=100, + default=2, + tooltip="Smart grain intensity level.", + display_mode=IO.NumberDisplay.slider, + ), + IO.DynamicCombo.Input( + "mode", + options=[ + IO.DynamicCombo.Option("creative", []), + IO.DynamicCombo.Option( + "faithful", + [ + IO.Int.Input( + "skin_detail", + min=0, + max=100, + default=80, + tooltip="Skin detail enhancement level.", + display_mode=IO.NumberDisplay.slider, + ), + ], + ), + IO.DynamicCombo.Option( + "flexible", + [ + IO.Combo.Input( + "optimized_for", + options=[ + "enhance_skin", + "improve_lighting", + "enhance_everything", + "transform_to_real", + "no_make_up", + ], + tooltip="Enhancement optimization target.", + ), + ], + ), + ], + tooltip="Processing mode: creative for artistic enhancement, " + "faithful for preserving original appearance, " + "flexible for targeted optimization.", + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["mode"]), + expr=""" + ( + $rates := {"creative": 0.29, "faithful": 0.37, "flexible": 0.45}; + {"type":"usd","usd": $lookup($rates, widgets.mode)} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + image: Input.Image, + sharpen: int, + smart_grain: int, + mode: InputSkinEnhancerMode, + ) -> IO.NodeOutput: + if get_number_of_images(image) != 1: + raise ValueError("Exactly one input image is required.") + validate_image_aspect_ratio(image, (1, 3), (3, 1), strict=False) + validate_image_dimensions(image, min_height=160, min_width=160) + + image_url = (await upload_images_to_comfyapi(cls, image, max_images=1, total_pixels=4096 * 4096))[0] + selected_mode = mode["mode"] + + if selected_mode == "creative": + endpoint = "creative" + data = ImageSkinEnhancerCreativeRequest( + image=image_url, + sharpen=sharpen, + smart_grain=smart_grain, + ) + elif selected_mode == "faithful": + endpoint = "faithful" + data = ImageSkinEnhancerFaithfulRequest( + image=image_url, + sharpen=sharpen, + smart_grain=smart_grain, + skin_detail=mode["skin_detail"], + ) + else: # flexible + endpoint = "flexible" + data = ImageSkinEnhancerFlexibleRequest( + image=image_url, + sharpen=sharpen, + smart_grain=smart_grain, + optimized_for=mode["optimized_for"], + ) + + initial_res = await sync_op( + cls, + ApiEndpoint(path=f"/proxy/freepik/v1/ai/skin-enhancer/{endpoint}", method="POST"), + response_model=TaskResponse, + data=data, + ) + final_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/freepik/v1/ai/skin-enhancer/{initial_res.task_id}"), + response_model=TaskResponse, + status_extractor=lambda x: x.status, + poll_interval=10.0, + max_poll_attempts=480, + ) + return IO.NodeOutput(await download_url_to_image_tensor(final_response.generated[0])) + + +class MagnificExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + # MagnificImageUpscalerCreativeNode, + # MagnificImageUpscalerPreciseV2Node, + MagnificImageStyleTransferNode, + MagnificImageRelightNode, + MagnificImageSkinEnhancerNode, + ] + + +async def comfy_entrypoint() -> MagnificExtension: + return MagnificExtension() diff --git a/comfy_api_nodes/util/conversions.py b/comfy_api_nodes/util/conversions.py index 0e15a0efe..3e37e8a8c 100644 --- a/comfy_api_nodes/util/conversions.py +++ b/comfy_api_nodes/util/conversions.py @@ -56,15 +56,14 @@ def image_tensor_pair_to_batch(image1: torch.Tensor, image2: torch.Tensor) -> to def tensor_to_bytesio( image: torch.Tensor, *, - total_pixels: int = 2048 * 2048, + total_pixels: int | None = 2048 * 2048, mime_type: str = "image/png", ) -> BytesIO: """Converts a torch.Tensor image to a named BytesIO object. Args: image: Input torch.Tensor image. - name: Optional filename for the BytesIO object. - total_pixels: Maximum total pixels for potential downscaling. + total_pixels: Maximum total pixels for downscaling. If None, no downscaling is performed. mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4'). Returns: @@ -79,13 +78,14 @@ def tensor_to_bytesio( return img_binary -def tensor_to_pil(image: torch.Tensor, total_pixels: int = 2048 * 2048) -> Image.Image: +def tensor_to_pil(image: torch.Tensor, total_pixels: int | None = 2048 * 2048) -> Image.Image: """Converts a single torch.Tensor image [H, W, C] to a PIL Image, optionally downscaling.""" if len(image.shape) > 3: image = image[0] # TODO: remove alpha if not allowed and present input_tensor = image.cpu() - input_tensor = downscale_image_tensor(input_tensor.unsqueeze(0), total_pixels=total_pixels).squeeze() + if total_pixels is not None: + input_tensor = downscale_image_tensor(input_tensor.unsqueeze(0), total_pixels=total_pixels).squeeze() image_np = (input_tensor.numpy() * 255).astype(np.uint8) img = Image.fromarray(image_np) return img @@ -93,14 +93,14 @@ def tensor_to_pil(image: torch.Tensor, total_pixels: int = 2048 * 2048) -> Image def tensor_to_base64_string( image_tensor: torch.Tensor, - total_pixels: int = 2048 * 2048, + total_pixels: int | None = 2048 * 2048, mime_type: str = "image/png", ) -> str: """Convert [B, H, W, C] or [H, W, C] tensor to a base64 string. Args: image_tensor: Input torch.Tensor image. - total_pixels: Maximum total pixels for potential downscaling. + total_pixels: Maximum total pixels for downscaling. If None, no downscaling is performed. mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4'). Returns: @@ -161,14 +161,14 @@ def downscale_image_tensor_by_max_side(image: torch.Tensor, *, max_side: int) - def tensor_to_data_uri( image_tensor: torch.Tensor, - total_pixels: int = 2048 * 2048, + total_pixels: int | None = 2048 * 2048, mime_type: str = "image/png", ) -> str: """Converts a tensor image to a Data URI string. Args: image_tensor: Input torch.Tensor image. - total_pixels: Maximum total pixels for potential downscaling. + total_pixels: Maximum total pixels for downscaling. If None, no downscaling is performed. mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp'). Returns: diff --git a/comfy_api_nodes/util/upload_helpers.py b/comfy_api_nodes/util/upload_helpers.py index 2190f9639..3153f2b98 100644 --- a/comfy_api_nodes/util/upload_helpers.py +++ b/comfy_api_nodes/util/upload_helpers.py @@ -49,7 +49,7 @@ async def upload_images_to_comfyapi( mime_type: str | None = None, wait_label: str | None = "Uploading", show_batch_index: bool = True, - total_pixels: int = 2048 * 2048, + total_pixels: int | None = 2048 * 2048, ) -> list[str]: """ Uploads images to ComfyUI API and returns download URLs.