From 6c5f906bf2443c488c409b6998e9da4bb0afe55a Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Wed, 28 Jan 2026 19:46:57 +0200 Subject: [PATCH 001/215] feat(api-nodes): add Grok Imagine nodes (#12136) --- comfy_api_nodes/apis/grok.py | 67 ++++++ comfy_api_nodes/nodes_grok.py | 417 ++++++++++++++++++++++++++++++++++ 2 files changed, 484 insertions(+) create mode 100644 comfy_api_nodes/apis/grok.py create mode 100644 comfy_api_nodes/nodes_grok.py diff --git a/comfy_api_nodes/apis/grok.py b/comfy_api_nodes/apis/grok.py new file mode 100644 index 000000000..8e3c79ab9 --- /dev/null +++ b/comfy_api_nodes/apis/grok.py @@ -0,0 +1,67 @@ +from pydantic import BaseModel, Field + + +class ImageGenerationRequest(BaseModel): + model: str = Field(...) + prompt: str = Field(...) + aspect_ratio: str = Field(...) + n: int = Field(...) + seed: int = Field(...) + response_for: str = Field("url") + + +class InputUrlObject(BaseModel): + url: str = Field(...) + + +class ImageEditRequest(BaseModel): + model: str = Field(...) + image: InputUrlObject = Field(...) + prompt: str = Field(...) + resolution: str = Field(...) + n: int = Field(...) + seed: int = Field(...) + response_for: str = Field("url") + + +class VideoGenerationRequest(BaseModel): + model: str = Field(...) + prompt: str = Field(...) + image: InputUrlObject | None = Field(...) + duration: int = Field(...) + aspect_ratio: str | None = Field(...) + resolution: str = Field(...) + seed: int = Field(...) + + +class VideoEditRequest(BaseModel): + model: str = Field(...) + prompt: str = Field(...) + video: InputUrlObject = Field(...) + seed: int = Field(...) + + +class ImageResponseObject(BaseModel): + url: str | None = Field(None) + b64_json: str | None = Field(None) + revised_prompt: str | None = Field(None) + + +class ImageGenerationResponse(BaseModel): + data: list[ImageResponseObject] = Field(...) + + +class VideoGenerationResponse(BaseModel): + request_id: str = Field(...) + + +class VideoResponseObject(BaseModel): + url: str = Field(...) + upsampled_prompt: str | None = Field(None) + duration: int = Field(...) + + +class VideoStatusResponse(BaseModel): + status: str | None = Field(None) + video: VideoResponseObject | None = Field(None) + model: str | None = Field(None) diff --git a/comfy_api_nodes/nodes_grok.py b/comfy_api_nodes/nodes_grok.py new file mode 100644 index 000000000..da15e97ea --- /dev/null +++ b/comfy_api_nodes/nodes_grok.py @@ -0,0 +1,417 @@ +import torch +from typing_extensions import override + +from comfy_api.latest import IO, ComfyExtension, Input +from comfy_api_nodes.apis.grok import ( + ImageEditRequest, + ImageGenerationRequest, + ImageGenerationResponse, + InputUrlObject, + VideoEditRequest, + VideoGenerationRequest, + VideoGenerationResponse, + VideoStatusResponse, +) +from comfy_api_nodes.util import ( + ApiEndpoint, + download_url_to_image_tensor, + download_url_to_video_output, + get_fs_object_size, + get_number_of_images, + poll_op, + sync_op, + tensor_to_base64_string, + upload_video_to_comfyapi, + validate_string, + validate_video_duration, +) + + +class GrokImageNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="GrokImageNode", + display_name="Grok Image", + category="api node/image/Grok", + description="Generate images using Grok based on a text prompt", + inputs=[ + IO.Combo.Input("model", options=["grok-imagine-image-beta"]), + IO.String.Input( + "prompt", + multiline=True, + tooltip="The text prompt used to generate the image", + ), + IO.Combo.Input( + "aspect_ratio", + options=[ + "1:1", + "2:3", + "3:2", + "3:4", + "4:3", + "9:16", + "16:9", + "9:19.5", + "19.5:9", + "9:20", + "20:9", + "1:2", + "2:1", + ], + ), + IO.Int.Input( + "number_of_images", + default=1, + min=1, + max=10, + step=1, + tooltip="Number of images to generate", + display_mode=IO.NumberDisplay.number, + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; " + "actual results are nondeterministic regardless of seed.", + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["number_of_images"]), + expr="""{"type":"usd","usd":0.033 * widgets.number_of_images}""", + ), + ) + + @classmethod + async def execute( + cls, + model: str, + prompt: str, + aspect_ratio: str, + number_of_images: int, + seed: int, + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/xai/v1/images/generations", method="POST"), + data=ImageGenerationRequest( + model=model, + prompt=prompt, + aspect_ratio=aspect_ratio, + n=number_of_images, + seed=seed, + ), + response_model=ImageGenerationResponse, + ) + if len(response.data) == 1: + return IO.NodeOutput(await download_url_to_image_tensor(response.data[0].url)) + return IO.NodeOutput( + torch.cat( + [await download_url_to_image_tensor(i) for i in [str(d.url) for d in response.data if d.url]], + ) + ) + + +class GrokImageEditNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="GrokImageEditNode", + display_name="Grok Image Edit", + category="api node/image/Grok", + description="Modify an existing image based on a text prompt", + inputs=[ + IO.Combo.Input("model", options=["grok-imagine-image-beta"]), + IO.Image.Input("image"), + IO.String.Input( + "prompt", + multiline=True, + tooltip="The text prompt used to generate the image", + ), + IO.Combo.Input("resolution", options=["1K"]), + IO.Int.Input( + "number_of_images", + default=1, + min=1, + max=10, + step=1, + tooltip="Number of edited images to generate", + display_mode=IO.NumberDisplay.number, + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; " + "actual results are nondeterministic regardless of seed.", + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["number_of_images"]), + expr="""{"type":"usd","usd":0.002 + 0.033 * widgets.number_of_images}""", + ), + ) + + @classmethod + async def execute( + cls, + model: str, + image: Input.Image, + prompt: str, + resolution: str, + number_of_images: int, + seed: int, + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1) + if get_number_of_images(image) != 1: + raise ValueError("Only one input image is supported.") + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/xai/v1/images/edits", method="POST"), + data=ImageEditRequest( + model=model, + image=InputUrlObject(url=f"data:image/png;base64,{tensor_to_base64_string(image)}"), + prompt=prompt, + resolution=resolution.lower(), + n=number_of_images, + seed=seed, + ), + response_model=ImageGenerationResponse, + ) + if len(response.data) == 1: + return IO.NodeOutput(await download_url_to_image_tensor(response.data[0].url)) + return IO.NodeOutput( + torch.cat( + [await download_url_to_image_tensor(i) for i in [str(d.url) for d in response.data if d.url]], + ) + ) + + +class GrokVideoNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="GrokVideoNode", + display_name="Grok Video", + category="api node/video/Grok", + description="Generate video from a prompt or an image", + inputs=[ + IO.Combo.Input("model", options=["grok-imagine-video-beta"]), + IO.String.Input( + "prompt", + multiline=True, + tooltip="Text description of the desired video.", + ), + IO.Combo.Input( + "resolution", + options=["480p", "720p"], + tooltip="The resolution of the output video.", + ), + IO.Combo.Input( + "aspect_ratio", + options=["auto", "16:9", "4:3", "3:2", "1:1", "2:3", "3:4", "9:16"], + tooltip="The aspect ratio of the output video.", + ), + IO.Int.Input( + "duration", + default=6, + min=1, + max=15, + step=1, + tooltip="The duration of the output video in seconds.", + display_mode=IO.NumberDisplay.slider, + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; " + "actual results are nondeterministic regardless of seed.", + ), + IO.Image.Input("image", optional=True), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["duration"], inputs=["image"]), + expr=""" + ( + $base := 0.181 * widgets.duration; + {"type":"usd","usd": inputs.image.connected ? $base + 0.002 : $base} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + model: str, + prompt: str, + resolution: str, + aspect_ratio: str, + duration: int, + seed: int, + image: Input.Image | None = None, + ) -> IO.NodeOutput: + image_url = None + if image is not None: + if get_number_of_images(image) != 1: + raise ValueError("Only one input image is supported.") + image_url = InputUrlObject(url=f"data:image/png;base64,{tensor_to_base64_string(image)}") + validate_string(prompt, strip_whitespace=True, min_length=1) + initial_response = await sync_op( + cls, + ApiEndpoint(path="/proxy/xai/v1/videos/generations", method="POST"), + data=VideoGenerationRequest( + model=model, + image=image_url, + prompt=prompt, + resolution=resolution, + duration=duration, + aspect_ratio=None if aspect_ratio == "auto" else aspect_ratio, + seed=seed, + ), + response_model=VideoGenerationResponse, + ) + response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"), + status_extractor=lambda r: r.status if r.status is not None else "complete", + response_model=VideoStatusResponse, + ) + return IO.NodeOutput(await download_url_to_video_output(response.video.url)) + + +class GrokVideoEditNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="GrokVideoEditNode", + display_name="Grok Video Edit", + category="api node/video/Grok", + description="Edit an existing video based on a text prompt.", + inputs=[ + IO.Combo.Input("model", options=["grok-imagine-video-beta"]), + IO.String.Input( + "prompt", + multiline=True, + tooltip="Text description of the desired video.", + ), + IO.Video.Input("video", tooltip="Maximum supported duration is 8.7 seconds and 50MB file size."), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; " + "actual results are nondeterministic regardless of seed.", + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd": 0.191, "format": {"suffix": "/sec", "approximate": true}}""", + ), + ) + + @classmethod + async def execute( + cls, + model: str, + prompt: str, + video: Input.Video, + seed: int, + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1) + validate_video_duration(video, min_duration=1, max_duration=8.7) + video_stream = video.get_stream_source() + video_size = get_fs_object_size(video_stream) + if video_size > 50 * 1024 * 1024: + raise ValueError(f"Video size ({video_size / 1024 / 1024:.1f}MB) exceeds 50MB limit.") + initial_response = await sync_op( + cls, + ApiEndpoint(path="/proxy/xai/v1/videos/edits", method="POST"), + data=VideoEditRequest( + model=model, + video=InputUrlObject(url=await upload_video_to_comfyapi(cls, video)), + prompt=prompt, + seed=seed, + ), + response_model=VideoGenerationResponse, + ) + response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"), + status_extractor=lambda r: r.status if r.status is not None else "complete", + response_model=VideoStatusResponse, + ) + return IO.NodeOutput(await download_url_to_video_output(response.video.url)) + + +class GrokExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + GrokImageNode, + GrokImageEditNode, + GrokVideoNode, + GrokVideoEditNode, + ] + + +async def comfy_entrypoint() -> GrokExtension: + return GrokExtension() From d9b856754741a680da34095c6e8afa7792ccede3 Mon Sep 17 00:00:00 2001 From: "Dr.Lt.Data" <128333288+ltdrdata@users.noreply.github.com> Date: Thu, 29 Jan 2026 02:47:37 +0900 Subject: [PATCH 002/215] bump manager version to 4.1b1 (#12140) --- manager_requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/manager_requirements.txt b/manager_requirements.txt index bea6d4927..c420cc48e 100644 --- a/manager_requirements.txt +++ b/manager_requirements.txt @@ -1 +1 @@ -comfyui_manager==4.0.5 +comfyui_manager==4.1b1 From 1711020904edd33bad7556bb70ef3ec15d4f8e5a Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Thu, 29 Jan 2026 01:48:02 +0800 Subject: [PATCH 003/215] chore: update workflow templates to v0.8.27 (#12141) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 666a0e35b..4ac94cb16 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.37.11 -comfyui-workflow-templates==0.8.24 +comfyui-workflow-templates==0.8.27 comfyui-embedded-docs==0.4.0 torch torchsde From c9b633d84f1dd9c2e6fb5dbc733a0687e5a9764b Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 28 Jan 2026 17:52:51 -0800 Subject: [PATCH 004/215] Add missing spacial downscale ratios. (#12146) --- comfy/latent_formats.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 38f18a83f..4b3a3798c 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -81,6 +81,7 @@ class SD_X4(LatentFormat): class SC_Prior(LatentFormat): latent_channels = 16 + spacial_downscale_ratio = 42 def __init__(self): self.scale_factor = 1.0 self.latent_rgb_factors = [ @@ -103,6 +104,7 @@ class SC_Prior(LatentFormat): ] class SC_B(LatentFormat): + spacial_downscale_ratio = 4 def __init__(self): self.scale_factor = 1.0 / 0.43 self.latent_rgb_factors = [ @@ -274,6 +276,7 @@ class Mochi(LatentFormat): class LTXV(LatentFormat): latent_channels = 128 latent_dimensions = 3 + spacial_downscale_ratio = 32 def __init__(self): self.latent_rgb_factors = [ @@ -517,6 +520,7 @@ class Wan21(LatentFormat): class Wan22(Wan21): latent_channels = 48 latent_dimensions = 3 + spacial_downscale_ratio = 16 latent_rgb_factors = [ [ 0.0119, 0.0103, 0.0046], From b0d9708974f50fce7d2448ac84e9260c87f7ade3 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 29 Jan 2026 00:27:23 -0500 Subject: [PATCH 005/215] ComfyUI v0.11.1 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index d56466db2..b1ebaa115 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.11.0" +__version__ = "0.11.1" diff --git a/pyproject.toml b/pyproject.toml index c0e787abd..042f124e4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.11.0" +version = "0.11.1" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" From 3aace5c8dc81d7bf1c3020edcbc4e6d37caa0a29 Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Thu, 29 Jan 2026 17:10:08 -0800 Subject: [PATCH 006/215] fix: count non-dict items in outputs_count (#12166) Move count increment before isinstance(item, dict) check so that non-dict output items (like text strings from PreviewAny node) are included in outputs_count. This aligns OSS Python with Cloud's Go implementation which uses len(itemsArray) to count ALL items regardless of type. Amp-Thread-ID: https://ampcode.com/threads/T-019c0bb5-14e0-744f-8808-1e57653f3ae3 Co-authored-by: Amp --- comfy_execution/jobs.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/comfy_execution/jobs.py b/comfy_execution/jobs.py index 97fd803b8..bf091a448 100644 --- a/comfy_execution/jobs.py +++ b/comfy_execution/jobs.py @@ -171,9 +171,10 @@ def get_outputs_summary(outputs: dict) -> tuple[int, Optional[dict]]: continue for item in items: + count += 1 + if not isinstance(item, dict): continue - count += 1 if preview_output is None and is_previewable(media_type, item): enriched = { From bbe2c13a7075bcf4de3b6744f96d84d12c334350 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 29 Jan 2026 20:52:22 -0800 Subject: [PATCH 007/215] Make empty hunyuan latent 1.0 work with the 1.5 model. (#12171) --- comfy_extras/nodes_hunyuan.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy_extras/nodes_hunyuan.py b/comfy_extras/nodes_hunyuan.py index ceff657d3..774da75a3 100644 --- a/comfy_extras/nodes_hunyuan.py +++ b/comfy_extras/nodes_hunyuan.py @@ -56,7 +56,7 @@ class EmptyHunyuanLatentVideo(io.ComfyNode): @classmethod def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput: latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) - return io.NodeOutput({"samples":latent}) + return io.NodeOutput({"samples": latent, "downscale_ratio_spacial": 8}) generate = execute # TODO: remove @@ -73,7 +73,7 @@ class EmptyHunyuanVideo15Latent(EmptyHunyuanLatentVideo): def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput: # Using scale factor of 16 instead of 8 latent = torch.zeros([batch_size, 32, ((length - 1) // 4) + 1, height // 16, width // 16], device=comfy.model_management.intermediate_device()) - return io.NodeOutput({"samples": latent}) + return io.NodeOutput({"samples": latent, "downscale_ratio_spacial": 16}) class HunyuanVideo15ImageToVideo(io.ComfyNode): From 0a7993729c70933e6fe19e26c2d7f60f6ebaf46a Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Fri, 30 Jan 2026 10:21:48 -0800 Subject: [PATCH 008/215] Remove NodeInfoV3-related code; we are almost 100% guaranteed to stick with NodeInfoV1 for the foreseable future (#12147) Co-authored-by: guill --- comfy_api/latest/_io.py | 63 ----------------------------------------- 1 file changed, 63 deletions(-) diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index be759952e..f5445554f 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -1252,23 +1252,6 @@ class NodeInfoV1: price_badge: dict | None = None search_aliases: list[str]=None -@dataclass -class NodeInfoV3: - input: dict=None - output: dict=None - hidden: list[str]=None - name: str=None - display_name: str=None - description: str=None - python_module: Any = None - category: str=None - output_node: bool=None - deprecated: bool=None - experimental: bool=None - dev_only: bool=None - api_node: bool=None - price_badge: dict | None = None - @dataclass class PriceBadgeDepends: @@ -1497,40 +1480,6 @@ class Schema: ) return info - - def get_v3_info(self, cls) -> NodeInfoV3: - input_dict = {} - output_dict = {} - hidden_list = [] - # TODO: make sure dynamic types will be handled correctly - if self.inputs: - for input in self.inputs: - add_to_dict_v3(input, input_dict) - if self.outputs: - for output in self.outputs: - add_to_dict_v3(output, output_dict) - if self.hidden: - for hidden in self.hidden: - hidden_list.append(hidden.value) - - info = NodeInfoV3( - input=input_dict, - output=output_dict, - hidden=hidden_list, - name=self.node_id, - display_name=self.display_name, - description=self.description, - category=self.category, - output_node=self.is_output_node, - deprecated=self.is_deprecated, - experimental=self.is_experimental, - dev_only=self.is_dev_only, - api_node=self.is_api_node, - python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes"), - price_badge=self.price_badge.as_dict(self.inputs) if self.price_badge is not None else None, - ) - return info - def get_finalized_class_inputs(d: dict[str, Any], live_inputs: dict[str, Any], include_hidden=False) -> tuple[dict[str, Any], V3Data]: out_dict = { "required": {}, @@ -1585,9 +1534,6 @@ def add_to_dict_v1(i: Input, d: dict): as_dict.pop("optional", None) d.setdefault(key, {})[i.id] = (i.get_io_type(), as_dict) -def add_to_dict_v3(io: Input | Output, d: dict): - d[io.id] = (io.get_io_type(), io.as_dict()) - class DynamicPathsDefaultValue: EMPTY_DICT = "empty_dict" @@ -1748,13 +1694,6 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal): # set hidden type_clone.hidden = HiddenHolder.from_v3_data(v3_data) return type_clone - - @final - @classmethod - def GET_NODE_INFO_V3(cls) -> dict[str, Any]: - schema = cls.GET_SCHEMA() - info = schema.get_v3_info(cls) - return asdict(info) ############################################# # V1 Backwards Compatibility code #-------------------------------------------- @@ -2107,12 +2046,10 @@ __all__ = [ "HiddenHolder", "Hidden", "NodeInfoV1", - "NodeInfoV3", "Schema", "ComfyNode", "NodeOutput", "add_to_dict_v1", - "add_to_dict_v3", "V3Data", "ImageCompare", "PriceBadgeDepends", From 016765378178cf2a565d4e87a99d2dabd4150d22 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sat, 31 Jan 2026 00:04:43 +0200 Subject: [PATCH 009/215] feat(api-nodes): add RecraftCreateStyleNode node (#12055) Co-authored-by: Jedrzej Kosinski --- comfy_api_nodes/apis/recraft.py | 45 ++++++++++--------- comfy_api_nodes/nodes_recraft.py | 74 +++++++++++++++++++++++++++++++- 2 files changed, 98 insertions(+), 21 deletions(-) diff --git a/comfy_api_nodes/apis/recraft.py b/comfy_api_nodes/apis/recraft.py index c36d95f24..0bd7d23b3 100644 --- a/comfy_api_nodes/apis/recraft.py +++ b/comfy_api_nodes/apis/recraft.py @@ -1,11 +1,8 @@ from __future__ import annotations - - from enum import Enum -from typing import Optional -from pydantic import BaseModel, Field, conint, confloat +from pydantic import BaseModel, Field class RecraftColor: @@ -229,24 +226,24 @@ class RecraftColorObject(BaseModel): class RecraftControlsObject(BaseModel): - colors: Optional[list[RecraftColorObject]] = Field(None, description='An array of preferable colors') - background_color: Optional[RecraftColorObject] = Field(None, description='Use given color as a desired background color') - no_text: Optional[bool] = Field(None, description='Do not embed text layouts') - artistic_level: Optional[conint(ge=0, le=5)] = Field(None, description='Defines artistic tone of your image. At a simple level, the person looks straight at the camera in a static and clean style. Dynamic and eccentric levels introduce movement and creativity. The value should be in range [0..5].') + colors: list[RecraftColorObject] | None = Field(None, description='An array of preferable colors') + background_color: RecraftColorObject | None = Field(None, description='Use given color as a desired background color') + no_text: bool | None = Field(None, description='Do not embed text layouts') + artistic_level: int | None = Field(None, description='Defines artistic tone of your image. At a simple level, the person looks straight at the camera in a static and clean style. Dynamic and eccentric levels introduce movement and creativity. The value should be in range [0..5].') class RecraftImageGenerationRequest(BaseModel): prompt: str = Field(..., description='The text prompt describing the image to generate') - size: Optional[RecraftImageSize] = Field(None, description='The size of the generated image (e.g., "1024x1024")') - n: conint(ge=1, le=6) = Field(..., description='The number of images to generate') - negative_prompt: Optional[str] = Field(None, description='A text description of undesired elements on an image') - model: Optional[RecraftModel] = Field(RecraftModel.recraftv3, description='The model to use for generation (e.g., "recraftv3")') - style: Optional[str] = Field(None, description='The style to apply to the generated image (e.g., "digital_illustration")') - substyle: Optional[str] = Field(None, description='The substyle to apply to the generated image, depending on the style input') - controls: Optional[RecraftControlsObject] = Field(None, description='A set of custom parameters to tweak generation process') - style_id: Optional[str] = Field(None, description='Use a previously uploaded style as a reference; UUID') - strength: Optional[confloat(ge=0.0, le=1.0)] = Field(None, description='Defines the difference with the original image, should lie in [0, 1], where 0 means almost identical, and 1 means miserable similarity') - random_seed: Optional[int] = Field(None, description="Seed for video generation") + size: RecraftImageSize | None = Field(None, description='The size of the generated image (e.g., "1024x1024")') + n: int = Field(..., description='The number of images to generate') + negative_prompt: str | None = Field(None, description='A text description of undesired elements on an image') + model: RecraftModel | None = Field(RecraftModel.recraftv3, description='The model to use for generation (e.g., "recraftv3")') + style: str | None = Field(None, description='The style to apply to the generated image (e.g., "digital_illustration")') + substyle: str | None = Field(None, description='The substyle to apply to the generated image, depending on the style input') + controls: RecraftControlsObject | None = Field(None, description='A set of custom parameters to tweak generation process') + style_id: str | None = Field(None, description='Use a previously uploaded style as a reference; UUID') + strength: float | None = Field(None, description='Defines the difference with the original image, should lie in [0, 1], where 0 means almost identical, and 1 means miserable similarity') + random_seed: int | None = Field(None, description="Seed for video generation") # text_layout @@ -258,5 +255,13 @@ class RecraftReturnedObject(BaseModel): class RecraftImageGenerationResponse(BaseModel): created: int = Field(..., description='Unix timestamp when the generation was created') credits: int = Field(..., description='Number of credits used for the generation') - data: Optional[list[RecraftReturnedObject]] = Field(None, description='Array of generated image information') - image: Optional[RecraftReturnedObject] = Field(None, description='Single generated image') + data: list[RecraftReturnedObject] | None = Field(None, description='Array of generated image information') + image: RecraftReturnedObject | None = Field(None, description='Single generated image') + + +class RecraftCreateStyleRequest(BaseModel): + style: str = Field(..., description="realistic_image, digital_illustration, vector_illustration, or icon") + + +class RecraftCreateStyleResponse(BaseModel): + id: str = Field(..., description="UUID of the created style") diff --git a/comfy_api_nodes/nodes_recraft.py b/comfy_api_nodes/nodes_recraft.py index c01bcaece..3a1f32263 100644 --- a/comfy_api_nodes/nodes_recraft.py +++ b/comfy_api_nodes/nodes_recraft.py @@ -12,6 +12,8 @@ from comfy_api_nodes.apis.recraft import ( RecraftColor, RecraftColorChain, RecraftControls, + RecraftCreateStyleRequest, + RecraftCreateStyleResponse, RecraftImageGenerationRequest, RecraftImageGenerationResponse, RecraftImageSize, @@ -323,6 +325,75 @@ class RecraftStyleInfiniteStyleLibrary(IO.ComfyNode): return IO.NodeOutput(RecraftStyle(style_id=style_id)) +class RecraftCreateStyleNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftCreateStyleNode", + display_name="Recraft Create Style", + category="api node/image/Recraft", + description="Create a custom style from reference images. " + "Upload 1-5 images to use as style references. " + "Total size of all images is limited to 5 MB.", + inputs=[ + IO.Combo.Input( + "style", + options=["realistic_image", "digital_illustration"], + tooltip="The base style of the generated images.", + ), + IO.Autogrow.Input( + "images", + template=IO.Autogrow.TemplatePrefix( + IO.Image.Input("image"), + prefix="image", + min=1, + max=5, + ), + ), + ], + outputs=[ + IO.String.Output(display_name="style_id"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd": 0.04}""", + ), + ) + + @classmethod + async def execute( + cls, + style: str, + images: IO.Autogrow.Type, + ) -> IO.NodeOutput: + files = [] + total_size = 0 + max_total_size = 5 * 1024 * 1024 # 5 MB limit + for i, img in enumerate(list(images.values())): + file_bytes = tensor_to_bytesio(img, total_pixels=2048 * 2048, mime_type="image/webp").read() + total_size += len(file_bytes) + if total_size > max_total_size: + raise Exception("Total size of all images exceeds 5 MB limit.") + files.append((f"file{i + 1}", file_bytes)) + + response = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/recraft/styles", method="POST"), + response_model=RecraftCreateStyleResponse, + files=files, + data=RecraftCreateStyleRequest(style=style), + content_type="multipart/form-data", + max_retries=1, + ) + + return IO.NodeOutput(response.id) + + class RecraftTextToImageNode(IO.ComfyNode): @classmethod def define_schema(cls): @@ -395,7 +466,7 @@ class RecraftTextToImageNode(IO.ComfyNode): negative_prompt: str = None, recraft_controls: RecraftControls = None, ) -> IO.NodeOutput: - validate_string(prompt, strip_whitespace=False, max_length=1000) + validate_string(prompt, strip_whitespace=False, min_length=1, max_length=1000) default_style = RecraftStyle(RecraftStyleV3.realistic_image) if recraft_style is None: recraft_style = default_style @@ -1024,6 +1095,7 @@ class RecraftExtension(ComfyExtension): RecraftStyleV3DigitalIllustrationNode, RecraftStyleV3LogoRasterNode, RecraftStyleInfiniteStyleLibrary, + RecraftCreateStyleNode, RecraftColorRGBNode, RecraftControlsNode, ] From 8aabe2403eb2a182e936d6b9514751ab075a623b Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Fri, 30 Jan 2026 15:01:33 -0800 Subject: [PATCH 010/215] Add color type and Color to RGB Int node (#12145) * add color type and color to rgb int node * review fix for allowing output --------- Co-authored-by: Jedrzej Kosinski --- comfy_api/latest/_io.py | 15 +++++++++++++ comfy_extras/nodes_color.py | 42 +++++++++++++++++++++++++++++++++++++ nodes.py | 3 ++- 3 files changed, 59 insertions(+), 1 deletion(-) create mode 100644 comfy_extras/nodes_color.py diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index f5445554f..78f77d4b2 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -1146,6 +1146,20 @@ class ImageCompare(ComfyTypeI): def as_dict(self): return super().as_dict() + +@comfytype(io_type="COLOR") +class Color(ComfyTypeIO): + Type = str + + class Input(WidgetInput): + def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, + socketless: bool=True, advanced: bool=None, default: str="#ffffff"): + super().__init__(id, display_name, optional, tooltip, None, default, socketless, None, None, None, None, advanced) + self.default: str + + def as_dict(self): + return super().as_dict() + DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {} def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]): DYNAMIC_INPUT_LOOKUP[io_type] = func @@ -2038,6 +2052,7 @@ __all__ = [ "AnyType", "MultiType", "Tracks", + "Color", # Dynamic Types "MatchType", "DynamicCombo", diff --git a/comfy_extras/nodes_color.py b/comfy_extras/nodes_color.py new file mode 100644 index 000000000..80ba121cd --- /dev/null +++ b/comfy_extras/nodes_color.py @@ -0,0 +1,42 @@ +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io + + +class ColorToRGBInt(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="ColorToRGBInt", + display_name="Color to RGB Int", + category="utils", + description="Convert a color to a RGB integer value.", + inputs=[ + io.Color.Input("color"), + ], + outputs=[ + io.Int.Output(display_name="rgb_int"), + ], + ) + + @classmethod + def execute( + cls, + color: str, + ) -> io.NodeOutput: + # expect format #RRGGBB + if len(color) != 7 or color[0] != "#": + raise ValueError("Color must be in format #RRGGBB") + r = int(color[1:3], 16) + g = int(color[3:5], 16) + b = int(color[5:7], 16) + return io.NodeOutput(r * 256 * 256 + g * 256 + b) + + +class ColorExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ColorToRGBInt] + + +async def comfy_entrypoint() -> ColorExtension: + return ColorExtension() diff --git a/nodes.py b/nodes.py index ad474d3cd..1cb43d9e2 100644 --- a/nodes.py +++ b/nodes.py @@ -2432,7 +2432,8 @@ async def init_builtin_extra_nodes(): "nodes_wanmove.py", "nodes_image_compare.py", "nodes_zimage.py", - "nodes_lora_debug.py" + "nodes_lora_debug.py", + "nodes_color.py" ] import_failed = [] From 4064062e7d2d5abdca6767b6944e331b70065ee8 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 30 Jan 2026 17:20:06 -0800 Subject: [PATCH 011/215] Update python patch version in dep workflow. (#12184) --- .github/workflows/windows_release_dependencies.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/windows_release_dependencies.yml b/.github/workflows/windows_release_dependencies.yml index f61ee21a2..93e01ac93 100644 --- a/.github/workflows/windows_release_dependencies.yml +++ b/.github/workflows/windows_release_dependencies.yml @@ -29,7 +29,7 @@ on: description: 'python patch version' required: true type: string - default: "9" + default: "11" # push: # branches: # - master From b8f848bfe38be79b9dba693bc7a214c9aefd597a Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 30 Jan 2026 21:12:48 -0800 Subject: [PATCH 012/215] Fix model not working with any res. (#12186) --- comfy/ldm/cosmos/predict2.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/comfy/ldm/cosmos/predict2.py b/comfy/ldm/cosmos/predict2.py index 07a4fc79f..c270e6333 100644 --- a/comfy/ldm/cosmos/predict2.py +++ b/comfy/ldm/cosmos/predict2.py @@ -13,6 +13,7 @@ from torchvision import transforms import comfy.patcher_extension from comfy.ldm.modules.attention import optimized_attention +import comfy.ldm.common_dit def apply_rotary_pos_emb( t: torch.Tensor, @@ -835,6 +836,8 @@ class MiniTrainDIT(nn.Module): padding_mask: Optional[torch.Tensor] = None, **kwargs, ): + orig_shape = list(x.shape) + x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_temporal, self.patch_spatial, self.patch_spatial)) x_B_C_T_H_W = x timesteps_B_T = timesteps crossattn_emb = context @@ -882,5 +885,5 @@ class MiniTrainDIT(nn.Module): ) x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D) - x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O) + x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)[:, :, :orig_shape[-3], :orig_shape[-2], :orig_shape[-1]] return x_B_C_Tt_Hp_Wp From 6e469a3f355782067b80fa072448a1b5d1322719 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sat, 31 Jan 2026 08:44:08 +0200 Subject: [PATCH 013/215] feat(api-nodes): add Q3 models and support for Extend and MultiFrame Vidu endpoints (#12175) Co-authored-by: Jedrzej Kosinski --- comfy_api_nodes/apis/vidu.py | 24 ++ comfy_api_nodes/nodes_vidu.py | 549 +++++++++++++++++++++++++++++++++- 2 files changed, 571 insertions(+), 2 deletions(-) diff --git a/comfy_api_nodes/apis/vidu.py b/comfy_api_nodes/apis/vidu.py index a9bb6f7ce..469adcdbc 100644 --- a/comfy_api_nodes/apis/vidu.py +++ b/comfy_api_nodes/apis/vidu.py @@ -6,6 +6,30 @@ class SubjectReference(BaseModel): images: list[str] = Field(...) +class FrameSetting(BaseModel): + prompt: str = Field(...) + key_image: str = Field(...) + duration: int = Field(...) + + +class TaskMultiFrameCreationRequest(BaseModel): + model: str = Field(...) + seed: int = Field(..., ge=0, le=2147483647) + resolution: str = Field(...) + start_image: str = Field(...) + image_settings: list[FrameSetting] = Field(...) + + +class TaskExtendCreationRequest(BaseModel): + model: str = Field(...) + prompt: str = Field(..., max_length=2000) + duration: int = Field(...) + seed: int = Field(..., ge=0, le=2147483647) + resolution: str = Field(...) + images: list[str] | None = Field(None, description="Base64 encoded string or image URL") + video_url: str = Field(..., description="URL of the video to extend") + + class TaskCreationRequest(BaseModel): model: str = Field(...) prompt: str = Field(..., max_length=2000) diff --git a/comfy_api_nodes/nodes_vidu.py b/comfy_api_nodes/nodes_vidu.py index b9114c4bb..80de14dfe 100644 --- a/comfy_api_nodes/nodes_vidu.py +++ b/comfy_api_nodes/nodes_vidu.py @@ -2,9 +2,12 @@ from typing_extensions import override from comfy_api.latest import IO, ComfyExtension, Input from comfy_api_nodes.apis.vidu import ( + FrameSetting, SubjectReference, TaskCreationRequest, TaskCreationResponse, + TaskExtendCreationRequest, + TaskMultiFrameCreationRequest, TaskResult, TaskStatusResponse, ) @@ -14,11 +17,14 @@ from comfy_api_nodes.util import ( get_number_of_images, poll_op, sync_op, + upload_image_to_comfyapi, upload_images_to_comfyapi, + upload_video_to_comfyapi, validate_image_aspect_ratio, validate_image_dimensions, validate_images_aspect_ratio_closeness, validate_string, + validate_video_duration, ) VIDU_TEXT_TO_VIDEO = "/proxy/vidu/text2video" @@ -31,7 +37,8 @@ VIDU_GET_GENERATION_STATUS = "/proxy/vidu/tasks/%s/creations" async def execute_task( cls: type[IO.ComfyNode], vidu_endpoint: str, - payload: TaskCreationRequest, + payload: TaskCreationRequest | TaskExtendCreationRequest | TaskMultiFrameCreationRequest, + max_poll_attempts: int = 320, ) -> list[TaskResult]: task_creation_response = await sync_op( cls, @@ -47,7 +54,7 @@ async def execute_task( response_model=TaskStatusResponse, status_extractor=lambda r: r.state, progress_extractor=lambda r: r.progress, - max_poll_attempts=320, + max_poll_attempts=max_poll_attempts, ) if not response.creations: raise RuntimeError( @@ -940,6 +947,540 @@ class Vidu2StartEndToVideoNode(IO.ComfyNode): return IO.NodeOutput(await download_url_to_video_output(results[0].url)) +class ViduExtendVideoNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="ViduExtendVideoNode", + display_name="Vidu Video Extension", + category="api node/video/Vidu", + description="Extend an existing video by generating additional frames.", + inputs=[ + IO.DynamicCombo.Input( + "model", + options=[ + IO.DynamicCombo.Option( + "viduq2-pro", + [ + IO.Int.Input( + "duration", + default=4, + min=1, + max=7, + step=1, + display_mode=IO.NumberDisplay.slider, + tooltip="Duration of the extended video in seconds.", + ), + IO.Combo.Input( + "resolution", + options=["720p", "1080p"], + tooltip="Resolution of the output video.", + ), + ], + ), + IO.DynamicCombo.Option( + "viduq2-turbo", + [ + IO.Int.Input( + "duration", + default=4, + min=1, + max=7, + step=1, + display_mode=IO.NumberDisplay.slider, + tooltip="Duration of the extended video in seconds.", + ), + IO.Combo.Input( + "resolution", + options=["720p", "1080p"], + tooltip="Resolution of the output video.", + ), + ], + ), + ], + tooltip="Model to use for video extension.", + ), + IO.Video.Input( + "video", + tooltip="The source video to extend.", + ), + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="An optional text prompt for the extended video (max 2000 characters).", + ), + IO.Int.Input( + "seed", + default=1, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + ), + IO.Image.Input("end_frame", optional=True), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model", "model.duration", "model.resolution"]), + expr=""" + ( + $m := widgets.model; + $d := $lookup(widgets, "model.duration"); + $res := $lookup(widgets, "model.resolution"); + $contains($m, "pro") + ? ( + $base := $lookup({"720p": 0.15, "1080p": 0.3}, $res); + $perSec := $lookup({"720p": 0.05, "1080p": 0.075}, $res); + {"type":"usd","usd": $base + $perSec * ($d - 1)} + ) + : ( + $base := $lookup({"720p": 0.075, "1080p": 0.2}, $res); + $perSec := $lookup({"720p": 0.025, "1080p": 0.05}, $res); + {"type":"usd","usd": $base + $perSec * ($d - 1)} + ) + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + model: dict, + video: Input.Video, + prompt: str, + seed: int, + end_frame: Input.Image | None = None, + ) -> IO.NodeOutput: + validate_string(prompt, max_length=2000) + validate_video_duration(video, min_duration=4, max_duration=55) + image_url = None + if end_frame is not None: + validate_image_aspect_ratio(end_frame, (1, 4), (4, 1)) + validate_image_dimensions(end_frame, min_width=128, min_height=128) + image_url = await upload_image_to_comfyapi(cls, end_frame, wait_label="Uploading end frame") + results = await execute_task( + cls, + "/proxy/vidu/extend", + TaskExtendCreationRequest( + model=model["model"], + prompt=prompt, + duration=model["duration"], + seed=seed, + resolution=model["resolution"], + video_url=await upload_video_to_comfyapi(cls, video, wait_label="Uploading video"), + images=[image_url] if image_url else None, + ), + max_poll_attempts=480, + ) + return IO.NodeOutput(await download_url_to_video_output(results[0].url)) + + +def _generate_frame_inputs(count: int) -> list: + """Generate input widgets for a given number of frames.""" + inputs = [] + for i in range(1, count + 1): + inputs.extend( + [ + IO.String.Input( + f"prompt{i}", + multiline=True, + default="", + tooltip=f"Text prompt for frame {i} transition.", + ), + IO.Image.Input( + f"end_image{i}", + tooltip=f"End frame image for segment {i}. Aspect ratio must be between 1:4 and 4:1.", + ), + IO.Int.Input( + f"duration{i}", + default=4, + min=2, + max=7, + step=1, + display_mode=IO.NumberDisplay.slider, + tooltip=f"Duration for segment {i} in seconds.", + ), + ] + ) + return inputs + + +class ViduMultiFrameVideoNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="ViduMultiFrameVideoNode", + display_name="Vidu Multi-Frame Video Generation", + category="api node/video/Vidu", + description="Generate a video with multiple keyframe transitions.", + inputs=[ + IO.Combo.Input("model", options=["viduq2-pro", "viduq2-turbo"]), + IO.Image.Input( + "start_image", + tooltip="The starting frame image. Aspect ratio must be between 1:4 and 4:1.", + ), + IO.Int.Input( + "seed", + default=1, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + ), + IO.Combo.Input("resolution", options=["720p", "1080p"]), + IO.DynamicCombo.Input( + "frames", + options=[ + IO.DynamicCombo.Option("2", _generate_frame_inputs(2)), + IO.DynamicCombo.Option("3", _generate_frame_inputs(3)), + IO.DynamicCombo.Option("4", _generate_frame_inputs(4)), + IO.DynamicCombo.Option("5", _generate_frame_inputs(5)), + IO.DynamicCombo.Option("6", _generate_frame_inputs(6)), + IO.DynamicCombo.Option("7", _generate_frame_inputs(7)), + IO.DynamicCombo.Option("8", _generate_frame_inputs(8)), + IO.DynamicCombo.Option("9", _generate_frame_inputs(9)), + ], + tooltip="Number of keyframe transitions (2-9).", + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends( + widgets=[ + "model", + "resolution", + "frames", + "frames.duration1", + "frames.duration2", + "frames.duration3", + "frames.duration4", + "frames.duration5", + "frames.duration6", + "frames.duration7", + "frames.duration8", + "frames.duration9", + ] + ), + expr=""" + ( + $m := widgets.model; + $n := $number(widgets.frames); + $is1080 := widgets.resolution = "1080p"; + $d1 := $lookup(widgets, "frames.duration1"); + $d2 := $lookup(widgets, "frames.duration2"); + $d3 := $n >= 3 ? $lookup(widgets, "frames.duration3") : 0; + $d4 := $n >= 4 ? $lookup(widgets, "frames.duration4") : 0; + $d5 := $n >= 5 ? $lookup(widgets, "frames.duration5") : 0; + $d6 := $n >= 6 ? $lookup(widgets, "frames.duration6") : 0; + $d7 := $n >= 7 ? $lookup(widgets, "frames.duration7") : 0; + $d8 := $n >= 8 ? $lookup(widgets, "frames.duration8") : 0; + $d9 := $n >= 9 ? $lookup(widgets, "frames.duration9") : 0; + $totalDuration := $d1 + $d2 + $d3 + $d4 + $d5 + $d6 + $d7 + $d8 + $d9; + $contains($m, "pro") + ? ( + $base := $is1080 ? 0.3 : 0.15; + $perSec := $is1080 ? 0.075 : 0.05; + {"type":"usd","usd": $n * $base + $perSec * $totalDuration} + ) + : ( + $base := $is1080 ? 0.2 : 0.075; + $perSec := $is1080 ? 0.05 : 0.025; + {"type":"usd","usd": $n * $base + $perSec * $totalDuration} + ) + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + model: str, + start_image: Input.Image, + seed: int, + resolution: str, + frames: dict, + ) -> IO.NodeOutput: + validate_image_aspect_ratio(start_image, (1, 4), (4, 1)) + frame_count = int(frames["frames"]) + image_settings: list[FrameSetting] = [] + for i in range(1, frame_count + 1): + validate_image_aspect_ratio(frames[f"end_image{i}"], (1, 4), (4, 1)) + validate_string(frames[f"prompt{i}"], max_length=2000) + start_image_url = await upload_image_to_comfyapi( + cls, + start_image, + mime_type="image/png", + wait_label="Uploading start image", + ) + for i in range(1, frame_count + 1): + image_settings.append( + FrameSetting( + prompt=frames[f"prompt{i}"], + key_image=await upload_image_to_comfyapi( + cls, + frames[f"end_image{i}"], + mime_type="image/png", + wait_label=f"Uploading end image({i})", + ), + duration=frames[f"duration{i}"], + ) + ) + results = await execute_task( + cls, + "/proxy/vidu/multiframe", + TaskMultiFrameCreationRequest( + model=model, + seed=seed, + resolution=resolution, + start_image=start_image_url, + image_settings=image_settings, + ), + max_poll_attempts=480 * frame_count, + ) + return IO.NodeOutput(await download_url_to_video_output(results[0].url)) + + +class Vidu3TextToVideoNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="Vidu3TextToVideoNode", + display_name="Vidu Q3 Text-to-Video Generation", + category="api node/video/Vidu", + description="Generate video from a text prompt.", + inputs=[ + IO.DynamicCombo.Input( + "model", + options=[ + IO.DynamicCombo.Option( + "viduq3-pro", + [ + IO.Combo.Input( + "aspect_ratio", + options=["16:9", "9:16", "3:4", "4:3", "1:1"], + tooltip="The aspect ratio of the output video.", + ), + IO.Combo.Input( + "resolution", + options=["720p", "1080p"], + tooltip="Resolution of the output video.", + ), + IO.Int.Input( + "duration", + default=5, + min=1, + max=16, + step=1, + display_mode=IO.NumberDisplay.slider, + tooltip="Duration of the output video in seconds.", + ), + IO.Boolean.Input( + "audio", + default=False, + tooltip="When enabled, outputs video with sound " + "(including dialogue and sound effects).", + ), + ], + ), + ], + tooltip="Model to use for video generation.", + ), + IO.String.Input( + "prompt", + multiline=True, + tooltip="A textual description for video generation, with a maximum length of 2000 characters.", + ), + IO.Int.Input( + "seed", + default=1, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model.duration", "model.resolution"]), + expr=""" + ( + $res := $lookup(widgets, "model.resolution"); + $base := $lookup({"720p": 0.075, "1080p": 0.1}, $res); + $perSec := $lookup({"720p": 0.025, "1080p": 0.05}, $res); + {"type":"usd","usd": $base + $perSec * ($lookup(widgets, "model.duration") - 1)} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + model: dict, + prompt: str, + seed: int, + ) -> IO.NodeOutput: + validate_string(prompt, min_length=1, max_length=2000) + results = await execute_task( + cls, + VIDU_TEXT_TO_VIDEO, + TaskCreationRequest( + model=model["model"], + prompt=prompt, + duration=model["duration"], + seed=seed, + aspect_ratio=model["aspect_ratio"], + resolution=model["resolution"], + audio=model["audio"], + ), + max_poll_attempts=640, + ) + return IO.NodeOutput(await download_url_to_video_output(results[0].url)) + + +class Vidu3ImageToVideoNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="Vidu3ImageToVideoNode", + display_name="Vidu Q3 Image-to-Video Generation", + category="api node/video/Vidu", + description="Generate a video from an image and an optional prompt.", + inputs=[ + IO.DynamicCombo.Input( + "model", + options=[ + IO.DynamicCombo.Option( + "viduq3-pro", + [ + IO.Combo.Input( + "resolution", + options=["720p", "1080p", "2K"], + tooltip="Resolution of the output video.", + ), + IO.Int.Input( + "duration", + default=5, + min=1, + max=16, + step=1, + display_mode=IO.NumberDisplay.slider, + tooltip="Duration of the output video in seconds.", + ), + IO.Boolean.Input( + "audio", + default=False, + tooltip="When enabled, outputs video with sound " + "(including dialogue and sound effects).", + ), + ], + ), + ], + tooltip="Model to use for video generation.", + ), + IO.Image.Input( + "image", + tooltip="An image to be used as the start frame of the generated video.", + ), + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="An optional text prompt for video generation (max 2000 characters).", + ), + IO.Int.Input( + "seed", + default=1, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model.duration", "model.resolution"]), + expr=""" + ( + $res := $lookup(widgets, "model.resolution"); + $base := $lookup({"720p": 0.075, "1080p": 0.275, "2k": 0.35}, $res); + $perSec := $lookup({"720p": 0.05, "1080p": 0.075, "2k": 0.075}, $res); + {"type":"usd","usd": $base + $perSec * ($lookup(widgets, "model.duration") - 1)} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + model: dict, + image: Input.Image, + prompt: str, + seed: int, + ) -> IO.NodeOutput: + validate_image_aspect_ratio(image, (1, 4), (4, 1)) + validate_string(prompt, max_length=2000) + results = await execute_task( + cls, + VIDU_IMAGE_TO_VIDEO, + TaskCreationRequest( + model=model["model"], + prompt=prompt, + duration=model["duration"], + seed=seed, + resolution=model["resolution"], + audio=model["audio"], + images=[await upload_image_to_comfyapi(cls, image)], + ), + max_poll_attempts=720, + ) + return IO.NodeOutput(await download_url_to_video_output(results[0].url)) + + class ViduExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: @@ -952,6 +1493,10 @@ class ViduExtension(ComfyExtension): Vidu2ImageToVideoNode, Vidu2ReferenceVideoNode, Vidu2StartEndToVideoNode, + ViduExtendVideoNode, + ViduMultiFrameVideoNode, + Vidu3TextToVideoNode, + Vidu3ImageToVideoNode, ] From 6ea8c128a3770763f150e97aca4b29bd2aec60ad Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Fri, 30 Jan 2026 23:22:05 -0800 Subject: [PATCH 014/215] Assets Part 2 - add more endpoints (#12125) --- app/assets/api/routes.py | 414 +++++++++- app/assets/api/schemas_in.py | 196 ++++- app/assets/api/schemas_out.py | 33 + app/assets/database/queries.py | 719 +++++++++++++++++- app/assets/helpers.py | 97 ++- app/assets/manager.py | 401 +++++++++- app/assets/scanner.py | 36 +- tests-unit/assets_test/conftest.py | 271 +++++++ .../assets_test/test_assets_missing_sync.py | 348 +++++++++ tests-unit/assets_test/test_crud.py | 306 ++++++++ tests-unit/assets_test/test_downloads.py | 166 ++++ tests-unit/assets_test/test_list_filter.py | 342 +++++++++ .../assets_test/test_metadata_filters.py | 395 ++++++++++ .../assets_test/test_prune_orphaned_assets.py | 141 ++++ tests-unit/assets_test/test_tags.py | 225 ++++++ tests-unit/assets_test/test_uploads.py | 281 +++++++ tests-unit/requirements.txt | 1 + 17 files changed, 4347 insertions(+), 25 deletions(-) create mode 100644 tests-unit/assets_test/conftest.py create mode 100644 tests-unit/assets_test/test_assets_missing_sync.py create mode 100644 tests-unit/assets_test/test_crud.py create mode 100644 tests-unit/assets_test/test_downloads.py create mode 100644 tests-unit/assets_test/test_list_filter.py create mode 100644 tests-unit/assets_test/test_metadata_filters.py create mode 100644 tests-unit/assets_test/test_prune_orphaned_assets.py create mode 100644 tests-unit/assets_test/test_tags.py create mode 100644 tests-unit/assets_test/test_uploads.py diff --git a/app/assets/api/routes.py b/app/assets/api/routes.py index 30e87a898..7676e50b4 100644 --- a/app/assets/api/routes.py +++ b/app/assets/api/routes.py @@ -1,5 +1,8 @@ import logging import uuid +import urllib.parse +import os +import contextlib from aiohttp import web from pydantic import ValidationError @@ -8,6 +11,9 @@ import app.assets.manager as manager from app import user_manager from app.assets.api import schemas_in from app.assets.helpers import get_query_dict +from app.assets.scanner import seed_assets + +import folder_paths ROUTES = web.RouteTableDef() USER_MANAGER: user_manager.UserManager | None = None @@ -15,6 +21,9 @@ USER_MANAGER: user_manager.UserManager | None = None # UUID regex (canonical hyphenated form, case-insensitive) UUID_RE = r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}" +# Note to any custom node developers reading this code: +# The assets system is not yet fully implemented, do not rely on the code in /app/assets remaining the same. + def register_assets_system(app: web.Application, user_manager_instance: user_manager.UserManager) -> None: global USER_MANAGER USER_MANAGER = user_manager_instance @@ -28,6 +37,18 @@ def _validation_error_response(code: str, ve: ValidationError) -> web.Response: return _error_response(400, code, "Validation failed.", {"errors": ve.json()}) +@ROUTES.head("/api/assets/hash/{hash}") +async def head_asset_by_hash(request: web.Request) -> web.Response: + hash_str = request.match_info.get("hash", "").strip().lower() + if not hash_str or ":" not in hash_str: + return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:'") + algo, digest = hash_str.split(":", 1) + if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"): + return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:'") + exists = manager.asset_exists(asset_hash=hash_str) + return web.Response(status=200 if exists else 404) + + @ROUTES.get("/api/assets") async def list_assets(request: web.Request) -> web.Response: """ @@ -50,7 +71,7 @@ async def list_assets(request: web.Request) -> web.Response: order=q.order, owner_id=USER_MANAGER.get_request_user_id(request), ) - return web.json_response(payload.model_dump(mode="json")) + return web.json_response(payload.model_dump(mode="json", exclude_none=True)) @ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}") @@ -76,6 +97,314 @@ async def get_asset(request: web.Request) -> web.Response: return web.json_response(result.model_dump(mode="json"), status=200) +@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}/content") +async def download_asset_content(request: web.Request) -> web.Response: + # question: do we need disposition? could we just stick with one of these? + disposition = request.query.get("disposition", "attachment").lower().strip() + if disposition not in {"inline", "attachment"}: + disposition = "attachment" + + try: + abs_path, content_type, filename = manager.resolve_asset_content_for_download( + asset_info_id=str(uuid.UUID(request.match_info["id"])), + owner_id=USER_MANAGER.get_request_user_id(request), + ) + except ValueError as ve: + return _error_response(404, "ASSET_NOT_FOUND", str(ve)) + except NotImplementedError as nie: + return _error_response(501, "BACKEND_UNSUPPORTED", str(nie)) + except FileNotFoundError: + return _error_response(404, "FILE_NOT_FOUND", "Underlying file not found on disk.") + + quoted = (filename or "").replace("\r", "").replace("\n", "").replace('"', "'") + cd = f'{disposition}; filename="{quoted}"; filename*=UTF-8\'\'{urllib.parse.quote(filename)}' + + file_size = os.path.getsize(abs_path) + logging.info( + "download_asset_content: path=%s, size=%d bytes (%.2f MB), content_type=%s, filename=%s", + abs_path, + file_size, + file_size / (1024 * 1024), + content_type, + filename, + ) + + async def file_sender(): + chunk_size = 64 * 1024 + with open(abs_path, "rb") as f: + while True: + chunk = f.read(chunk_size) + if not chunk: + break + yield chunk + + return web.Response( + body=file_sender(), + content_type=content_type, + headers={ + "Content-Disposition": cd, + "Content-Length": str(file_size), + }, + ) + + +@ROUTES.post("/api/assets/from-hash") +async def create_asset_from_hash(request: web.Request) -> web.Response: + try: + payload = await request.json() + body = schemas_in.CreateFromHashBody.model_validate(payload) + except ValidationError as ve: + return _validation_error_response("INVALID_BODY", ve) + except Exception: + return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.") + + result = manager.create_asset_from_hash( + hash_str=body.hash, + name=body.name, + tags=body.tags, + user_metadata=body.user_metadata, + owner_id=USER_MANAGER.get_request_user_id(request), + ) + if result is None: + return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist") + return web.json_response(result.model_dump(mode="json"), status=201) + + +@ROUTES.post("/api/assets") +async def upload_asset(request: web.Request) -> web.Response: + """Multipart/form-data endpoint for Asset uploads.""" + if not (request.content_type or "").lower().startswith("multipart/"): + return _error_response(415, "UNSUPPORTED_MEDIA_TYPE", "Use multipart/form-data for uploads.") + + reader = await request.multipart() + + file_present = False + file_client_name: str | None = None + tags_raw: list[str] = [] + provided_name: str | None = None + user_metadata_raw: str | None = None + provided_hash: str | None = None + provided_hash_exists: bool | None = None + + file_written = 0 + tmp_path: str | None = None + while True: + field = await reader.next() + if field is None: + break + + fname = getattr(field, "name", "") or "" + + if fname == "hash": + try: + s = ((await field.text()) or "").strip().lower() + except Exception: + return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:'") + + if s: + if ":" not in s: + return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:'") + algo, digest = s.split(":", 1) + if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"): + return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:'") + provided_hash = f"{algo}:{digest}" + try: + provided_hash_exists = manager.asset_exists(asset_hash=provided_hash) + except Exception: + provided_hash_exists = None # do not fail the whole request here + + elif fname == "file": + file_present = True + file_client_name = (field.filename or "").strip() + + if provided_hash and provided_hash_exists is True: + # If client supplied a hash that we know exists, drain but do not write to disk + try: + while True: + chunk = await field.read_chunk(8 * 1024 * 1024) + if not chunk: + break + file_written += len(chunk) + except Exception: + return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive uploaded file.") + continue # Do not create temp file; we will create AssetInfo from the existing content + + # Otherwise, store to temp for hashing/ingest + uploads_root = os.path.join(folder_paths.get_temp_directory(), "uploads") + unique_dir = os.path.join(uploads_root, uuid.uuid4().hex) + os.makedirs(unique_dir, exist_ok=True) + tmp_path = os.path.join(unique_dir, ".upload.part") + + try: + with open(tmp_path, "wb") as f: + while True: + chunk = await field.read_chunk(8 * 1024 * 1024) + if not chunk: + break + f.write(chunk) + file_written += len(chunk) + except Exception: + try: + if os.path.exists(tmp_path or ""): + os.remove(tmp_path) + finally: + return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive and store uploaded file.") + elif fname == "tags": + tags_raw.append((await field.text()) or "") + elif fname == "name": + provided_name = (await field.text()) or None + elif fname == "user_metadata": + user_metadata_raw = (await field.text()) or None + + # If client did not send file, and we are not doing a from-hash fast path -> error + if not file_present and not (provided_hash and provided_hash_exists): + return _error_response(400, "MISSING_FILE", "Form must include a 'file' part or a known 'hash'.") + + if file_present and file_written == 0 and not (provided_hash and provided_hash_exists): + # Empty upload is only acceptable if we are fast-pathing from existing hash + try: + if tmp_path and os.path.exists(tmp_path): + os.remove(tmp_path) + finally: + return _error_response(400, "EMPTY_UPLOAD", "Uploaded file is empty.") + + try: + spec = schemas_in.UploadAssetSpec.model_validate({ + "tags": tags_raw, + "name": provided_name, + "user_metadata": user_metadata_raw, + "hash": provided_hash, + }) + except ValidationError as ve: + try: + if tmp_path and os.path.exists(tmp_path): + os.remove(tmp_path) + finally: + return _validation_error_response("INVALID_BODY", ve) + + # Validate models category against configured folders (consistent with previous behavior) + if spec.tags and spec.tags[0] == "models": + if len(spec.tags) < 2 or spec.tags[1] not in folder_paths.folder_names_and_paths: + if tmp_path and os.path.exists(tmp_path): + os.remove(tmp_path) + return _error_response( + 400, "INVALID_BODY", f"unknown models category '{spec.tags[1] if len(spec.tags) >= 2 else ''}'" + ) + + owner_id = USER_MANAGER.get_request_user_id(request) + + # Fast path: if a valid provided hash exists, create AssetInfo without writing anything + if spec.hash and provided_hash_exists is True: + try: + result = manager.create_asset_from_hash( + hash_str=spec.hash, + name=spec.name or (spec.hash.split(":", 1)[1]), + tags=spec.tags, + user_metadata=spec.user_metadata or {}, + owner_id=owner_id, + ) + except Exception: + logging.exception("create_asset_from_hash failed for hash=%s, owner_id=%s", spec.hash, owner_id) + return _error_response(500, "INTERNAL", "Unexpected server error.") + + if result is None: + return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {spec.hash} does not exist") + + # Drain temp if we accidentally saved (e.g., hash field came after file) + if tmp_path and os.path.exists(tmp_path): + with contextlib.suppress(Exception): + os.remove(tmp_path) + + status = 200 if (not result.created_new) else 201 + return web.json_response(result.model_dump(mode="json"), status=status) + + # Otherwise, we must have a temp file path to ingest + if not tmp_path or not os.path.exists(tmp_path): + # The only case we reach here without a temp file is: client sent a hash that does not exist and no file + return _error_response(404, "ASSET_NOT_FOUND", "Provided hash not found and no file uploaded.") + + try: + created = manager.upload_asset_from_temp_path( + spec, + temp_path=tmp_path, + client_filename=file_client_name, + owner_id=owner_id, + expected_asset_hash=spec.hash, + ) + status = 201 if created.created_new else 200 + return web.json_response(created.model_dump(mode="json"), status=status) + except ValueError as e: + if tmp_path and os.path.exists(tmp_path): + os.remove(tmp_path) + msg = str(e) + if "HASH_MISMATCH" in msg or msg.strip().upper() == "HASH_MISMATCH": + return _error_response( + 400, + "HASH_MISMATCH", + "Uploaded file hash does not match provided hash.", + ) + return _error_response(400, "BAD_REQUEST", "Invalid inputs.") + except Exception: + if tmp_path and os.path.exists(tmp_path): + os.remove(tmp_path) + logging.exception("upload_asset_from_temp_path failed for tmp_path=%s, owner_id=%s", tmp_path, owner_id) + return _error_response(500, "INTERNAL", "Unexpected server error.") + + +@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}") +async def update_asset(request: web.Request) -> web.Response: + asset_info_id = str(uuid.UUID(request.match_info["id"])) + try: + body = schemas_in.UpdateAssetBody.model_validate(await request.json()) + except ValidationError as ve: + return _validation_error_response("INVALID_BODY", ve) + except Exception: + return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.") + + try: + result = manager.update_asset( + asset_info_id=asset_info_id, + name=body.name, + user_metadata=body.user_metadata, + owner_id=USER_MANAGER.get_request_user_id(request), + ) + except (ValueError, PermissionError) as ve: + return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}) + except Exception: + logging.exception( + "update_asset failed for asset_info_id=%s, owner_id=%s", + asset_info_id, + USER_MANAGER.get_request_user_id(request), + ) + return _error_response(500, "INTERNAL", "Unexpected server error.") + return web.json_response(result.model_dump(mode="json"), status=200) + + +@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}") +async def delete_asset(request: web.Request) -> web.Response: + asset_info_id = str(uuid.UUID(request.match_info["id"])) + delete_content = request.query.get("delete_content") + delete_content = True if delete_content is None else delete_content.lower() not in {"0", "false", "no"} + + try: + deleted = manager.delete_asset_reference( + asset_info_id=asset_info_id, + owner_id=USER_MANAGER.get_request_user_id(request), + delete_content_if_orphan=delete_content, + ) + except Exception: + logging.exception( + "delete_asset_reference failed for asset_info_id=%s, owner_id=%s", + asset_info_id, + USER_MANAGER.get_request_user_id(request), + ) + return _error_response(500, "INTERNAL", "Unexpected server error.") + + if not deleted: + return _error_response(404, "ASSET_NOT_FOUND", f"AssetInfo {asset_info_id} not found.") + return web.Response(status=204) + + @ROUTES.get("/api/tags") async def get_tags(request: web.Request) -> web.Response: """ @@ -100,3 +429,86 @@ async def get_tags(request: web.Request) -> web.Response: owner_id=USER_MANAGER.get_request_user_id(request), ) return web.json_response(result.model_dump(mode="json")) + + +@ROUTES.post(f"/api/assets/{{id:{UUID_RE}}}/tags") +async def add_asset_tags(request: web.Request) -> web.Response: + asset_info_id = str(uuid.UUID(request.match_info["id"])) + try: + payload = await request.json() + data = schemas_in.TagsAdd.model_validate(payload) + except ValidationError as ve: + return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags add.", {"errors": ve.errors()}) + except Exception: + return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.") + + try: + result = manager.add_tags_to_asset( + asset_info_id=asset_info_id, + tags=data.tags, + origin="manual", + owner_id=USER_MANAGER.get_request_user_id(request), + ) + except (ValueError, PermissionError) as ve: + return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}) + except Exception: + logging.exception( + "add_tags_to_asset failed for asset_info_id=%s, owner_id=%s", + asset_info_id, + USER_MANAGER.get_request_user_id(request), + ) + return _error_response(500, "INTERNAL", "Unexpected server error.") + + return web.json_response(result.model_dump(mode="json"), status=200) + + +@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}/tags") +async def delete_asset_tags(request: web.Request) -> web.Response: + asset_info_id = str(uuid.UUID(request.match_info["id"])) + try: + payload = await request.json() + data = schemas_in.TagsRemove.model_validate(payload) + except ValidationError as ve: + return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags remove.", {"errors": ve.errors()}) + except Exception: + return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.") + + try: + result = manager.remove_tags_from_asset( + asset_info_id=asset_info_id, + tags=data.tags, + owner_id=USER_MANAGER.get_request_user_id(request), + ) + except ValueError as ve: + return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}) + except Exception: + logging.exception( + "remove_tags_from_asset failed for asset_info_id=%s, owner_id=%s", + asset_info_id, + USER_MANAGER.get_request_user_id(request), + ) + return _error_response(500, "INTERNAL", "Unexpected server error.") + + return web.json_response(result.model_dump(mode="json"), status=200) + + +@ROUTES.post("/api/assets/seed") +async def seed_assets_endpoint(request: web.Request) -> web.Response: + """Trigger asset seeding for specified roots (models, input, output).""" + try: + payload = await request.json() + roots = payload.get("roots", ["models", "input", "output"]) + except Exception: + roots = ["models", "input", "output"] + + valid_roots = [r for r in roots if r in ("models", "input", "output")] + if not valid_roots: + return _error_response(400, "INVALID_BODY", "No valid roots specified") + + try: + seed_assets(tuple(valid_roots)) + except Exception: + logging.exception("seed_assets failed for roots=%s", valid_roots) + return _error_response(500, "INTERNAL", "Seed operation failed") + + return web.json_response({"seeded": valid_roots}, status=200) diff --git a/app/assets/api/schemas_in.py b/app/assets/api/schemas_in.py index 200b41aef..6707ffb0c 100644 --- a/app/assets/api/schemas_in.py +++ b/app/assets/api/schemas_in.py @@ -1,5 +1,4 @@ import json -import uuid from typing import Any, Literal from pydantic import ( @@ -8,9 +7,9 @@ from pydantic import ( Field, conint, field_validator, + model_validator, ) - class ListAssetsQuery(BaseModel): include_tags: list[str] = Field(default_factory=list) exclude_tags: list[str] = Field(default_factory=list) @@ -57,6 +56,57 @@ class ListAssetsQuery(BaseModel): return None +class UpdateAssetBody(BaseModel): + name: str | None = None + user_metadata: dict[str, Any] | None = None + + @model_validator(mode="after") + def _at_least_one(self): + if self.name is None and self.user_metadata is None: + raise ValueError("Provide at least one of: name, user_metadata.") + return self + + +class CreateFromHashBody(BaseModel): + model_config = ConfigDict(extra="ignore", str_strip_whitespace=True) + + hash: str + name: str + tags: list[str] = Field(default_factory=list) + user_metadata: dict[str, Any] = Field(default_factory=dict) + + @field_validator("hash") + @classmethod + def _require_blake3(cls, v): + s = (v or "").strip().lower() + if ":" not in s: + raise ValueError("hash must be 'blake3:'") + algo, digest = s.split(":", 1) + if algo != "blake3": + raise ValueError("only canonical 'blake3:' is accepted here") + if not digest or any(c for c in digest if c not in "0123456789abcdef"): + raise ValueError("hash digest must be lowercase hex") + return s + + @field_validator("tags", mode="before") + @classmethod + def _tags_norm(cls, v): + if v is None: + return [] + if isinstance(v, list): + out = [str(t).strip().lower() for t in v if str(t).strip()] + seen = set() + dedup = [] + for t in out: + if t not in seen: + seen.add(t) + dedup.append(t) + return dedup + if isinstance(v, str): + return [t.strip().lower() for t in v.split(",") if t.strip()] + return [] + + class TagsListQuery(BaseModel): model_config = ConfigDict(extra="ignore", str_strip_whitespace=True) @@ -75,20 +125,140 @@ class TagsListQuery(BaseModel): return v.lower() or None -class SetPreviewBody(BaseModel): - """Set or clear the preview for an AssetInfo. Provide an Asset.id or null.""" - preview_id: str | None = None +class TagsAdd(BaseModel): + model_config = ConfigDict(extra="ignore") + tags: list[str] = Field(..., min_length=1) - @field_validator("preview_id", mode="before") + @field_validator("tags") @classmethod - def _norm_uuid(cls, v): + def normalize_tags(cls, v: list[str]) -> list[str]: + out = [] + for t in v: + if not isinstance(t, str): + raise TypeError("tags must be strings") + tnorm = t.strip().lower() + if tnorm: + out.append(tnorm) + seen = set() + deduplicated = [] + for x in out: + if x not in seen: + seen.add(x) + deduplicated.append(x) + return deduplicated + + +class TagsRemove(TagsAdd): + pass + + +class UploadAssetSpec(BaseModel): + """Upload Asset operation. + - tags: ordered; first is root ('models'|'input'|'output'); + if root == 'models', second must be a valid category from folder_paths.folder_names_and_paths + - name: display name + - user_metadata: arbitrary JSON object (optional) + - hash: optional canonical 'blake3:' provided by the client for validation / fast-path + + Files created via this endpoint are stored on disk using the **content hash** as the filename stem + and the original extension is preserved when available. + """ + model_config = ConfigDict(extra="ignore", str_strip_whitespace=True) + + tags: list[str] = Field(..., min_length=1) + name: str | None = Field(default=None, max_length=512, description="Display Name") + user_metadata: dict[str, Any] = Field(default_factory=dict) + hash: str | None = Field(default=None) + + @field_validator("hash", mode="before") + @classmethod + def _parse_hash(cls, v): if v is None: return None - s = str(v).strip() + s = str(v).strip().lower() if not s: return None - try: - uuid.UUID(s) - except Exception: - raise ValueError("preview_id must be a UUID") - return s + if ":" not in s: + raise ValueError("hash must be 'blake3:'") + algo, digest = s.split(":", 1) + if algo != "blake3": + raise ValueError("only canonical 'blake3:' is accepted here") + if not digest or any(c for c in digest if c not in "0123456789abcdef"): + raise ValueError("hash digest must be lowercase hex") + return f"{algo}:{digest}" + + @field_validator("tags", mode="before") + @classmethod + def _parse_tags(cls, v): + """ + Accepts a list of strings (possibly multiple form fields), + where each string can be: + - JSON array (e.g., '["models","loras","foo"]') + - comma-separated ('models, loras, foo') + - single token ('models') + Returns a normalized, deduplicated, ordered list. + """ + items: list[str] = [] + if v is None: + return [] + if isinstance(v, str): + v = [v] + + if isinstance(v, list): + for item in v: + if item is None: + continue + s = str(item).strip() + if not s: + continue + if s.startswith("["): + try: + arr = json.loads(s) + if isinstance(arr, list): + items.extend(str(x) for x in arr) + continue + except Exception: + pass # fallback to CSV parse below + items.extend([p for p in s.split(",") if p.strip()]) + else: + return [] + + # normalize + dedupe + norm = [] + seen = set() + for t in items: + tnorm = str(t).strip().lower() + if tnorm and tnorm not in seen: + seen.add(tnorm) + norm.append(tnorm) + return norm + + @field_validator("user_metadata", mode="before") + @classmethod + def _parse_metadata_json(cls, v): + if v is None or isinstance(v, dict): + return v or {} + if isinstance(v, str): + s = v.strip() + if not s: + return {} + try: + parsed = json.loads(s) + except Exception as e: + raise ValueError(f"user_metadata must be JSON: {e}") from e + if not isinstance(parsed, dict): + raise ValueError("user_metadata must be a JSON object") + return parsed + return {} + + @model_validator(mode="after") + def _validate_order(self): + if not self.tags: + raise ValueError("tags must be provided and non-empty") + root = self.tags[0] + if root not in {"models", "input", "output"}: + raise ValueError("first tag must be one of: models, input, output") + if root == "models": + if len(self.tags) < 2: + raise ValueError("models uploads require a category tag as the second tag") + return self diff --git a/app/assets/api/schemas_out.py b/app/assets/api/schemas_out.py index 9f8184f20..b6fb3da0c 100644 --- a/app/assets/api/schemas_out.py +++ b/app/assets/api/schemas_out.py @@ -29,6 +29,21 @@ class AssetsList(BaseModel): has_more: bool +class AssetUpdated(BaseModel): + id: str + name: str + asset_hash: str | None = None + tags: list[str] = Field(default_factory=list) + user_metadata: dict[str, Any] = Field(default_factory=dict) + updated_at: datetime | None = None + + model_config = ConfigDict(from_attributes=True) + + @field_serializer("updated_at") + def _ser_updated(self, v: datetime | None, _info): + return v.isoformat() if v else None + + class AssetDetail(BaseModel): id: str name: str @@ -48,6 +63,10 @@ class AssetDetail(BaseModel): return v.isoformat() if v else None +class AssetCreated(AssetDetail): + created_new: bool + + class TagUsage(BaseModel): name: str count: int @@ -58,3 +77,17 @@ class TagsList(BaseModel): tags: list[TagUsage] = Field(default_factory=list) total: int has_more: bool + + +class TagsAdd(BaseModel): + model_config = ConfigDict(str_strip_whitespace=True) + added: list[str] = Field(default_factory=list) + already_present: list[str] = Field(default_factory=list) + total_tags: list[str] = Field(default_factory=list) + + +class TagsRemove(BaseModel): + model_config = ConfigDict(str_strip_whitespace=True) + removed: list[str] = Field(default_factory=list) + not_present: list[str] = Field(default_factory=list) + total_tags: list[str] = Field(default_factory=list) diff --git a/app/assets/database/queries.py b/app/assets/database/queries.py index 0824c0c2f..d6b33ec7b 100644 --- a/app/assets/database/queries.py +++ b/app/assets/database/queries.py @@ -1,9 +1,17 @@ +import os +import logging import sqlalchemy as sa from collections import defaultdict -from sqlalchemy import select, exists, func +from datetime import datetime +from typing import Iterable, Any +from sqlalchemy import select, delete, exists, func +from sqlalchemy.dialects import sqlite +from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session, contains_eager, noload -from app.assets.database.models import Asset, AssetInfo, AssetInfoMeta, AssetInfoTag, Tag -from app.assets.helpers import escape_like_prefix, normalize_tags +from app.assets.database.models import Asset, AssetInfo, AssetCacheState, AssetInfoMeta, AssetInfoTag, Tag +from app.assets.helpers import ( + compute_relative_filename, escape_like_prefix, normalize_tags, project_kv, utcnow +) from typing import Sequence @@ -15,6 +23,22 @@ def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement: return AssetInfo.owner_id.in_(["", owner_id]) +def pick_best_live_path(states: Sequence[AssetCacheState]) -> str: + """ + Return the best on-disk path among cache states: + 1) Prefer a path that exists with needs_verify == False (already verified). + 2) Otherwise, pick the first path that exists. + 3) Otherwise return empty string. + """ + alive = [s for s in states if getattr(s, "file_path", None) and os.path.isfile(s.file_path)] + if not alive: + return "" + for s in alive: + if not getattr(s, "needs_verify", False): + return s.file_path + return alive[0].file_path + + def apply_tag_filters( stmt: sa.sql.Select, include_tags: Sequence[str] | None = None, @@ -42,6 +66,7 @@ def apply_tag_filters( ) return stmt + def apply_metadata_filter( stmt: sa.sql.Select, metadata_filter: dict | None = None, @@ -94,7 +119,11 @@ def apply_metadata_filter( return stmt -def asset_exists_by_hash(session: Session, asset_hash: str) -> bool: +def asset_exists_by_hash( + session: Session, + *, + asset_hash: str, +) -> bool: """ Check if an asset with a given hash exists in database. """ @@ -105,9 +134,39 @@ def asset_exists_by_hash(session: Session, asset_hash: str) -> bool: ).first() return row is not None -def get_asset_info_by_id(session: Session, asset_info_id: str) -> AssetInfo | None: + +def asset_info_exists_for_asset_id( + session: Session, + *, + asset_id: str, +) -> bool: + q = ( + select(sa.literal(True)) + .select_from(AssetInfo) + .where(AssetInfo.asset_id == asset_id) + .limit(1) + ) + return (session.execute(q)).first() is not None + + +def get_asset_by_hash( + session: Session, + *, + asset_hash: str, +) -> Asset | None: + return ( + session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1)) + ).scalars().first() + + +def get_asset_info_by_id( + session: Session, + *, + asset_info_id: str, +) -> AssetInfo | None: return session.get(AssetInfo, asset_info_id) + def list_asset_infos_page( session: Session, owner_id: str = "", @@ -171,12 +230,14 @@ def list_asset_infos_page( select(AssetInfoTag.asset_info_id, Tag.name) .join(Tag, Tag.name == AssetInfoTag.tag_name) .where(AssetInfoTag.asset_info_id.in_(id_list)) + .order_by(AssetInfoTag.added_at) ) for aid, tag_name in rows.all(): tag_map[aid].append(tag_name) return infos, tag_map, total + def fetch_asset_info_asset_and_tags( session: Session, asset_info_id: str, @@ -208,6 +269,494 @@ def fetch_asset_info_asset_and_tags( tags.append(tag_name) return first_info, first_asset, tags + +def fetch_asset_info_and_asset( + session: Session, + *, + asset_info_id: str, + owner_id: str = "", +) -> tuple[AssetInfo, Asset] | None: + stmt = ( + select(AssetInfo, Asset) + .join(Asset, Asset.id == AssetInfo.asset_id) + .where( + AssetInfo.id == asset_info_id, + visible_owner_clause(owner_id), + ) + .limit(1) + .options(noload(AssetInfo.tags)) + ) + row = session.execute(stmt) + pair = row.first() + if not pair: + return None + return pair[0], pair[1] + +def list_cache_states_by_asset_id( + session: Session, *, asset_id: str +) -> Sequence[AssetCacheState]: + return ( + session.execute( + select(AssetCacheState) + .where(AssetCacheState.asset_id == asset_id) + .order_by(AssetCacheState.id.asc()) + ) + ).scalars().all() + + +def touch_asset_info_by_id( + session: Session, + *, + asset_info_id: str, + ts: datetime | None = None, + only_if_newer: bool = True, +) -> None: + ts = ts or utcnow() + stmt = sa.update(AssetInfo).where(AssetInfo.id == asset_info_id) + if only_if_newer: + stmt = stmt.where( + sa.or_(AssetInfo.last_access_time.is_(None), AssetInfo.last_access_time < ts) + ) + session.execute(stmt.values(last_access_time=ts)) + + +def create_asset_info_for_existing_asset( + session: Session, + *, + asset_hash: str, + name: str, + user_metadata: dict | None = None, + tags: Sequence[str] | None = None, + tag_origin: str = "manual", + owner_id: str = "", +) -> AssetInfo: + """Create or return an existing AssetInfo for an Asset identified by asset_hash.""" + now = utcnow() + asset = get_asset_by_hash(session, asset_hash=asset_hash) + if not asset: + raise ValueError(f"Unknown asset hash {asset_hash}") + + info = AssetInfo( + owner_id=owner_id, + name=name, + asset_id=asset.id, + preview_id=None, + created_at=now, + updated_at=now, + last_access_time=now, + ) + try: + with session.begin_nested(): + session.add(info) + session.flush() + except IntegrityError: + existing = ( + session.execute( + select(AssetInfo) + .options(noload(AssetInfo.tags)) + .where( + AssetInfo.asset_id == asset.id, + AssetInfo.name == name, + AssetInfo.owner_id == owner_id, + ) + .limit(1) + ) + ).unique().scalars().first() + if not existing: + raise RuntimeError("AssetInfo upsert failed to find existing row after conflict.") + return existing + + # metadata["filename"] hack + new_meta = dict(user_metadata or {}) + computed_filename = None + try: + p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id)) + if p: + computed_filename = compute_relative_filename(p) + except Exception: + computed_filename = None + if computed_filename: + new_meta["filename"] = computed_filename + if new_meta: + replace_asset_info_metadata_projection( + session, + asset_info_id=info.id, + user_metadata=new_meta, + ) + + if tags is not None: + set_asset_info_tags( + session, + asset_info_id=info.id, + tags=tags, + origin=tag_origin, + ) + return info + + +def set_asset_info_tags( + session: Session, + *, + asset_info_id: str, + tags: Sequence[str], + origin: str = "manual", +) -> dict: + desired = normalize_tags(tags) + + current = set( + tag_name for (tag_name,) in ( + session.execute(select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)) + ).all() + ) + + to_add = [t for t in desired if t not in current] + to_remove = [t for t in current if t not in desired] + + if to_add: + ensure_tags_exist(session, to_add, tag_type="user") + session.add_all([ + AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_at=utcnow()) + for t in to_add + ]) + session.flush() + + if to_remove: + session.execute( + delete(AssetInfoTag) + .where(AssetInfoTag.asset_info_id == asset_info_id, AssetInfoTag.tag_name.in_(to_remove)) + ) + session.flush() + + return {"added": to_add, "removed": to_remove, "total": desired} + + +def replace_asset_info_metadata_projection( + session: Session, + *, + asset_info_id: str, + user_metadata: dict | None = None, +) -> None: + info = session.get(AssetInfo, asset_info_id) + if not info: + raise ValueError(f"AssetInfo {asset_info_id} not found") + + info.user_metadata = user_metadata or {} + info.updated_at = utcnow() + session.flush() + + session.execute(delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id)) + session.flush() + + if not user_metadata: + return + + rows: list[AssetInfoMeta] = [] + for k, v in user_metadata.items(): + for r in project_kv(k, v): + rows.append( + AssetInfoMeta( + asset_info_id=asset_info_id, + key=r["key"], + ordinal=int(r["ordinal"]), + val_str=r.get("val_str"), + val_num=r.get("val_num"), + val_bool=r.get("val_bool"), + val_json=r.get("val_json"), + ) + ) + if rows: + session.add_all(rows) + session.flush() + + +def ingest_fs_asset( + session: Session, + *, + asset_hash: str, + abs_path: str, + size_bytes: int, + mtime_ns: int, + mime_type: str | None = None, + info_name: str | None = None, + owner_id: str = "", + preview_id: str | None = None, + user_metadata: dict | None = None, + tags: Sequence[str] = (), + tag_origin: str = "manual", + require_existing_tags: bool = False, +) -> dict: + """ + Idempotently upsert: + - Asset by content hash (create if missing) + - AssetCacheState(file_path) pointing to asset_id + - Optionally AssetInfo + tag links and metadata projection + Returns flags and ids. + """ + locator = os.path.abspath(abs_path) + now = utcnow() + + if preview_id: + if not session.get(Asset, preview_id): + preview_id = None + + out: dict[str, Any] = { + "asset_created": False, + "asset_updated": False, + "state_created": False, + "state_updated": False, + "asset_info_id": None, + } + + # 1) Asset by hash + asset = ( + session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1)) + ).scalars().first() + if not asset: + vals = { + "hash": asset_hash, + "size_bytes": int(size_bytes), + "mime_type": mime_type, + "created_at": now, + } + res = session.execute( + sqlite.insert(Asset) + .values(**vals) + .on_conflict_do_nothing(index_elements=[Asset.hash]) + ) + if int(res.rowcount or 0) > 0: + out["asset_created"] = True + asset = ( + session.execute( + select(Asset).where(Asset.hash == asset_hash).limit(1) + ) + ).scalars().first() + if not asset: + raise RuntimeError("Asset row not found after upsert.") + else: + changed = False + if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0: + asset.size_bytes = int(size_bytes) + changed = True + if mime_type and asset.mime_type != mime_type: + asset.mime_type = mime_type + changed = True + if changed: + out["asset_updated"] = True + + # 2) AssetCacheState upsert by file_path (unique) + vals = { + "asset_id": asset.id, + "file_path": locator, + "mtime_ns": int(mtime_ns), + } + ins = ( + sqlite.insert(AssetCacheState) + .values(**vals) + .on_conflict_do_nothing(index_elements=[AssetCacheState.file_path]) + ) + + res = session.execute(ins) + if int(res.rowcount or 0) > 0: + out["state_created"] = True + else: + upd = ( + sa.update(AssetCacheState) + .where(AssetCacheState.file_path == locator) + .where( + sa.or_( + AssetCacheState.asset_id != asset.id, + AssetCacheState.mtime_ns.is_(None), + AssetCacheState.mtime_ns != int(mtime_ns), + ) + ) + .values(asset_id=asset.id, mtime_ns=int(mtime_ns)) + ) + res2 = session.execute(upd) + if int(res2.rowcount or 0) > 0: + out["state_updated"] = True + + # 3) Optional AssetInfo + tags + metadata + if info_name: + try: + with session.begin_nested(): + info = AssetInfo( + owner_id=owner_id, + name=info_name, + asset_id=asset.id, + preview_id=preview_id, + created_at=now, + updated_at=now, + last_access_time=now, + ) + session.add(info) + session.flush() + out["asset_info_id"] = info.id + except IntegrityError: + pass + + existing_info = ( + session.execute( + select(AssetInfo) + .where( + AssetInfo.asset_id == asset.id, + AssetInfo.name == info_name, + (AssetInfo.owner_id == owner_id), + ) + .limit(1) + ) + ).unique().scalar_one_or_none() + if not existing_info: + raise RuntimeError("Failed to update or insert AssetInfo.") + + if preview_id and existing_info.preview_id != preview_id: + existing_info.preview_id = preview_id + + existing_info.updated_at = now + if existing_info.last_access_time < now: + existing_info.last_access_time = now + session.flush() + out["asset_info_id"] = existing_info.id + + norm = [t.strip().lower() for t in (tags or []) if (t or "").strip()] + if norm and out["asset_info_id"] is not None: + if not require_existing_tags: + ensure_tags_exist(session, norm, tag_type="user") + + existing_tag_names = set( + name for (name,) in (session.execute(select(Tag.name).where(Tag.name.in_(norm)))).all() + ) + missing = [t for t in norm if t not in existing_tag_names] + if missing and require_existing_tags: + raise ValueError(f"Unknown tags: {missing}") + + existing_links = set( + tag_name + for (tag_name,) in ( + session.execute( + select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == out["asset_info_id"]) + ) + ).all() + ) + to_add = [t for t in norm if t in existing_tag_names and t not in existing_links] + if to_add: + session.add_all( + [ + AssetInfoTag( + asset_info_id=out["asset_info_id"], + tag_name=t, + origin=tag_origin, + added_at=now, + ) + for t in to_add + ] + ) + session.flush() + + # metadata["filename"] hack + if out["asset_info_id"] is not None: + primary_path = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id)) + computed_filename = compute_relative_filename(primary_path) if primary_path else None + + current_meta = existing_info.user_metadata or {} + new_meta = dict(current_meta) + if user_metadata is not None: + for k, v in user_metadata.items(): + new_meta[k] = v + if computed_filename: + new_meta["filename"] = computed_filename + + if new_meta != current_meta: + replace_asset_info_metadata_projection( + session, + asset_info_id=out["asset_info_id"], + user_metadata=new_meta, + ) + + try: + remove_missing_tag_for_asset_id(session, asset_id=asset.id) + except Exception: + logging.exception("Failed to clear 'missing' tag for asset %s", asset.id) + return out + + +def update_asset_info_full( + session: Session, + *, + asset_info_id: str, + name: str | None = None, + tags: Sequence[str] | None = None, + user_metadata: dict | None = None, + tag_origin: str = "manual", + asset_info_row: Any = None, +) -> AssetInfo: + if not asset_info_row: + info = session.get(AssetInfo, asset_info_id) + if not info: + raise ValueError(f"AssetInfo {asset_info_id} not found") + else: + info = asset_info_row + + touched = False + if name is not None and name != info.name: + info.name = name + touched = True + + computed_filename = None + try: + p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=info.asset_id)) + if p: + computed_filename = compute_relative_filename(p) + except Exception: + computed_filename = None + + if user_metadata is not None: + new_meta = dict(user_metadata) + if computed_filename: + new_meta["filename"] = computed_filename + replace_asset_info_metadata_projection( + session, asset_info_id=asset_info_id, user_metadata=new_meta + ) + touched = True + else: + if computed_filename: + current_meta = info.user_metadata or {} + if current_meta.get("filename") != computed_filename: + new_meta = dict(current_meta) + new_meta["filename"] = computed_filename + replace_asset_info_metadata_projection( + session, asset_info_id=asset_info_id, user_metadata=new_meta + ) + touched = True + + if tags is not None: + set_asset_info_tags( + session, + asset_info_id=asset_info_id, + tags=tags, + origin=tag_origin, + ) + touched = True + + if touched and user_metadata is None: + info.updated_at = utcnow() + session.flush() + + return info + + +def delete_asset_info_by_id( + session: Session, + *, + asset_info_id: str, + owner_id: str, +) -> bool: + stmt = sa.delete(AssetInfo).where( + AssetInfo.id == asset_info_id, + visible_owner_clause(owner_id), + ) + return int((session.execute(stmt)).rowcount or 0) > 0 + + def list_tags_with_usage( session: Session, prefix: str | None = None, @@ -265,3 +814,163 @@ def list_tags_with_usage( rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows] return rows_norm, int(total or 0) + + +def ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "user") -> None: + wanted = normalize_tags(list(names)) + if not wanted: + return + rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))] + ins = ( + sqlite.insert(Tag) + .values(rows) + .on_conflict_do_nothing(index_elements=[Tag.name]) + ) + session.execute(ins) + + +def get_asset_tags(session: Session, *, asset_info_id: str) -> list[str]: + return [ + tag_name for (tag_name,) in ( + session.execute( + select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id) + ) + ).all() + ] + + +def add_tags_to_asset_info( + session: Session, + *, + asset_info_id: str, + tags: Sequence[str], + origin: str = "manual", + create_if_missing: bool = True, + asset_info_row: Any = None, +) -> dict: + if not asset_info_row: + info = session.get(AssetInfo, asset_info_id) + if not info: + raise ValueError(f"AssetInfo {asset_info_id} not found") + + norm = normalize_tags(tags) + if not norm: + total = get_asset_tags(session, asset_info_id=asset_info_id) + return {"added": [], "already_present": [], "total_tags": total} + + if create_if_missing: + ensure_tags_exist(session, norm, tag_type="user") + + current = { + tag_name + for (tag_name,) in ( + session.execute( + sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id) + ) + ).all() + } + + want = set(norm) + to_add = sorted(want - current) + + if to_add: + with session.begin_nested() as nested: + try: + session.add_all( + [ + AssetInfoTag( + asset_info_id=asset_info_id, + tag_name=t, + origin=origin, + added_at=utcnow(), + ) + for t in to_add + ] + ) + session.flush() + except IntegrityError: + nested.rollback() + + after = set(get_asset_tags(session, asset_info_id=asset_info_id)) + return { + "added": sorted(((after - current) & want)), + "already_present": sorted(want & current), + "total_tags": sorted(after), + } + + +def remove_tags_from_asset_info( + session: Session, + *, + asset_info_id: str, + tags: Sequence[str], +) -> dict: + info = session.get(AssetInfo, asset_info_id) + if not info: + raise ValueError(f"AssetInfo {asset_info_id} not found") + + norm = normalize_tags(tags) + if not norm: + total = get_asset_tags(session, asset_info_id=asset_info_id) + return {"removed": [], "not_present": [], "total_tags": total} + + existing = { + tag_name + for (tag_name,) in ( + session.execute( + sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id) + ) + ).all() + } + + to_remove = sorted(set(t for t in norm if t in existing)) + not_present = sorted(set(t for t in norm if t not in existing)) + + if to_remove: + session.execute( + delete(AssetInfoTag) + .where( + AssetInfoTag.asset_info_id == asset_info_id, + AssetInfoTag.tag_name.in_(to_remove), + ) + ) + session.flush() + + total = get_asset_tags(session, asset_info_id=asset_info_id) + return {"removed": to_remove, "not_present": not_present, "total_tags": total} + + +def remove_missing_tag_for_asset_id( + session: Session, + *, + asset_id: str, +) -> None: + session.execute( + sa.delete(AssetInfoTag).where( + AssetInfoTag.asset_info_id.in_(sa.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)), + AssetInfoTag.tag_name == "missing", + ) + ) + + +def set_asset_info_preview( + session: Session, + *, + asset_info_id: str, + preview_asset_id: str | None = None, +) -> None: + """Set or clear preview_id and bump updated_at. Raises on unknown IDs.""" + info = session.get(AssetInfo, asset_info_id) + if not info: + raise ValueError(f"AssetInfo {asset_info_id} not found") + + if preview_asset_id is None: + info.preview_id = None + else: + # validate preview asset exists + if not session.get(Asset, preview_asset_id): + raise ValueError(f"Preview Asset {preview_asset_id} not found") + info.preview_id = preview_asset_id + + info.updated_at = utcnow() + session.flush() diff --git a/app/assets/helpers.py b/app/assets/helpers.py index 08b465b5a..5030b123a 100644 --- a/app/assets/helpers.py +++ b/app/assets/helpers.py @@ -1,5 +1,6 @@ import contextlib import os +from decimal import Decimal from aiohttp import web from datetime import datetime, timezone from pathlib import Path @@ -87,6 +88,40 @@ def get_comfy_models_folders() -> list[tuple[str, list[str]]]: targets.append((name, paths)) return targets +def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]: + """Validates and maps tags -> (base_dir, subdirs_for_fs)""" + root = tags[0] + if root == "models": + if len(tags) < 2: + raise ValueError("at least two tags required for model asset") + try: + bases = folder_paths.folder_names_and_paths[tags[1]][0] + except KeyError: + raise ValueError(f"unknown model category '{tags[1]}'") + if not bases: + raise ValueError(f"no base path configured for category '{tags[1]}'") + base_dir = os.path.abspath(bases[0]) + raw_subdirs = tags[2:] + else: + base_dir = os.path.abspath( + folder_paths.get_input_directory() if root == "input" else folder_paths.get_output_directory() + ) + raw_subdirs = tags[1:] + for i in raw_subdirs: + if i in (".", ".."): + raise ValueError("invalid path component in tags") + + return base_dir, raw_subdirs if raw_subdirs else [] + +def ensure_within_base(candidate: str, base: str) -> None: + cand_abs = os.path.abspath(candidate) + base_abs = os.path.abspath(base) + try: + if os.path.commonpath([cand_abs, base_abs]) != base_abs: + raise ValueError("destination escapes base directory") + except Exception: + raise ValueError("invalid destination path") + def compute_relative_filename(file_path: str) -> str | None: """ Return the model's path relative to the last well-known folder (the model category), @@ -113,7 +148,6 @@ def compute_relative_filename(file_path: str) -> str | None: return "/".join(inside) return "/".join(parts) # input/output: keep all parts - def get_relative_to_root_category_path_of_asset(file_path: str) -> tuple[Literal["input", "output", "models"], str]: """Given an absolute or relative file path, determine which root category the path belongs to: - 'input' if the file resides under `folder_paths.get_input_directory()` @@ -215,3 +249,64 @@ def collect_models_files() -> list[str]: if allowed: out.append(abs_path) return out + +def is_scalar(v): + if v is None: + return True + if isinstance(v, bool): + return True + if isinstance(v, (int, float, Decimal, str)): + return True + return False + +def project_kv(key: str, value): + """ + Turn a metadata key/value into typed projection rows. + Returns list[dict] with keys: + key, ordinal, and one of val_str / val_num / val_bool / val_json (others None) + """ + rows: list[dict] = [] + + def _null_row(ordinal: int) -> dict: + return { + "key": key, "ordinal": ordinal, + "val_str": None, "val_num": None, "val_bool": None, "val_json": None + } + + if value is None: + rows.append(_null_row(0)) + return rows + + if is_scalar(value): + if isinstance(value, bool): + rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)}) + elif isinstance(value, (int, float, Decimal)): + num = value if isinstance(value, Decimal) else Decimal(str(value)) + rows.append({"key": key, "ordinal": 0, "val_num": num}) + elif isinstance(value, str): + rows.append({"key": key, "ordinal": 0, "val_str": value}) + else: + rows.append({"key": key, "ordinal": 0, "val_json": value}) + return rows + + if isinstance(value, list): + if all(is_scalar(x) for x in value): + for i, x in enumerate(value): + if x is None: + rows.append(_null_row(i)) + elif isinstance(x, bool): + rows.append({"key": key, "ordinal": i, "val_bool": bool(x)}) + elif isinstance(x, (int, float, Decimal)): + num = x if isinstance(x, Decimal) else Decimal(str(x)) + rows.append({"key": key, "ordinal": i, "val_num": num}) + elif isinstance(x, str): + rows.append({"key": key, "ordinal": i, "val_str": x}) + else: + rows.append({"key": key, "ordinal": i, "val_json": x}) + return rows + for i, x in enumerate(value): + rows.append({"key": key, "ordinal": i, "val_json": x}) + return rows + + rows.append({"key": key, "ordinal": 0, "val_json": value}) + return rows diff --git a/app/assets/manager.py b/app/assets/manager.py index 6425e7aa2..a68c8c8ae 100644 --- a/app/assets/manager.py +++ b/app/assets/manager.py @@ -1,13 +1,33 @@ +import os +import mimetypes +import contextlib from typing import Sequence from app.database.db import create_session -from app.assets.api import schemas_out +from app.assets.api import schemas_out, schemas_in from app.assets.database.queries import ( asset_exists_by_hash, + asset_info_exists_for_asset_id, + get_asset_by_hash, + get_asset_info_by_id, fetch_asset_info_asset_and_tags, + fetch_asset_info_and_asset, + create_asset_info_for_existing_asset, + touch_asset_info_by_id, + update_asset_info_full, + delete_asset_info_by_id, + list_cache_states_by_asset_id, list_asset_infos_page, list_tags_with_usage, + get_asset_tags, + add_tags_to_asset_info, + remove_tags_from_asset_info, + pick_best_live_path, + ingest_fs_asset, + set_asset_info_preview, ) +from app.assets.helpers import resolve_destination_from_tags, ensure_within_base +from app.assets.database.models import Asset def _safe_sort_field(requested: str | None) -> str: @@ -19,11 +39,28 @@ def _safe_sort_field(requested: str | None) -> str: return "created_at" -def asset_exists(asset_hash: str) -> bool: +def _get_size_mtime_ns(path: str) -> tuple[int, int]: + st = os.stat(path, follow_symlinks=True) + return st.st_size, getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000)) + + +def _safe_filename(name: str | None, fallback: str) -> str: + n = os.path.basename((name or "").strip() or fallback) + if n: + return n + return fallback + + +def asset_exists(*, asset_hash: str) -> bool: + """ + Check if an asset with a given hash exists in database. + """ with create_session() as session: return asset_exists_by_hash(session, asset_hash=asset_hash) + def list_assets( + *, include_tags: Sequence[str] | None = None, exclude_tags: Sequence[str] | None = None, name_contains: str | None = None, @@ -63,7 +100,6 @@ def list_assets( size=int(asset.size_bytes) if asset else None, mime_type=asset.mime_type if asset else None, tags=tags, - preview_url=f"/api/assets/{info.id}/content", created_at=info.created_at, updated_at=info.updated_at, last_access_time=info.last_access_time, @@ -76,7 +112,12 @@ def list_assets( has_more=(offset + len(summaries)) < total, ) -def get_asset(asset_info_id: str, owner_id: str = "") -> schemas_out.AssetDetail: + +def get_asset( + *, + asset_info_id: str, + owner_id: str = "", +) -> schemas_out.AssetDetail: with create_session() as session: res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id) if not res: @@ -97,6 +138,358 @@ def get_asset(asset_info_id: str, owner_id: str = "") -> schemas_out.AssetDetail last_access_time=info.last_access_time, ) + +def resolve_asset_content_for_download( + *, + asset_info_id: str, + owner_id: str = "", +) -> tuple[str, str, str]: + with create_session() as session: + pair = fetch_asset_info_and_asset(session, asset_info_id=asset_info_id, owner_id=owner_id) + if not pair: + raise ValueError(f"AssetInfo {asset_info_id} not found") + + info, asset = pair + states = list_cache_states_by_asset_id(session, asset_id=asset.id) + abs_path = pick_best_live_path(states) + if not abs_path: + raise FileNotFoundError + + touch_asset_info_by_id(session, asset_info_id=asset_info_id) + session.commit() + + ctype = asset.mime_type or mimetypes.guess_type(info.name or abs_path)[0] or "application/octet-stream" + download_name = info.name or os.path.basename(abs_path) + return abs_path, ctype, download_name + + +def upload_asset_from_temp_path( + spec: schemas_in.UploadAssetSpec, + *, + temp_path: str, + client_filename: str | None = None, + owner_id: str = "", + expected_asset_hash: str | None = None, +) -> schemas_out.AssetCreated: + """ + Create new asset or update existing asset from a temporary file path. + """ + try: + # NOTE: blake3 is not required right now, so this will fail if blake3 is not installed in local environment + import app.assets.hashing as hashing + digest = hashing.blake3_hash(temp_path) + except Exception as e: + raise RuntimeError(f"failed to hash uploaded file: {e}") + asset_hash = "blake3:" + digest + + if expected_asset_hash and asset_hash != expected_asset_hash.strip().lower(): + raise ValueError("HASH_MISMATCH") + + with create_session() as session: + existing = get_asset_by_hash(session, asset_hash=asset_hash) + if existing is not None: + with contextlib.suppress(Exception): + if temp_path and os.path.exists(temp_path): + os.remove(temp_path) + + display_name = _safe_filename(spec.name or (client_filename or ""), fallback=digest) + info = create_asset_info_for_existing_asset( + session, + asset_hash=asset_hash, + name=display_name, + user_metadata=spec.user_metadata or {}, + tags=spec.tags or [], + tag_origin="manual", + owner_id=owner_id, + ) + tag_names = get_asset_tags(session, asset_info_id=info.id) + session.commit() + + return schemas_out.AssetCreated( + id=info.id, + name=info.name, + asset_hash=existing.hash, + size=int(existing.size_bytes) if existing.size_bytes is not None else None, + mime_type=existing.mime_type, + tags=tag_names, + user_metadata=info.user_metadata or {}, + preview_id=info.preview_id, + created_at=info.created_at, + last_access_time=info.last_access_time, + created_new=False, + ) + + base_dir, subdirs = resolve_destination_from_tags(spec.tags) + dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir + os.makedirs(dest_dir, exist_ok=True) + + src_for_ext = (client_filename or spec.name or "").strip() + _ext = os.path.splitext(os.path.basename(src_for_ext))[1] if src_for_ext else "" + ext = _ext if 0 < len(_ext) <= 16 else "" + hashed_basename = f"{digest}{ext}" + dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename)) + ensure_within_base(dest_abs, base_dir) + + content_type = ( + mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0] + or mimetypes.guess_type(hashed_basename, strict=False)[0] + or "application/octet-stream" + ) + + try: + os.replace(temp_path, dest_abs) + except Exception as e: + raise RuntimeError(f"failed to move uploaded file into place: {e}") + + try: + size_bytes, mtime_ns = _get_size_mtime_ns(dest_abs) + except OSError as e: + raise RuntimeError(f"failed to stat destination file: {e}") + + with create_session() as session: + result = ingest_fs_asset( + session, + asset_hash=asset_hash, + abs_path=dest_abs, + size_bytes=size_bytes, + mtime_ns=mtime_ns, + mime_type=content_type, + info_name=_safe_filename(spec.name or (client_filename or ""), fallback=digest), + owner_id=owner_id, + preview_id=None, + user_metadata=spec.user_metadata or {}, + tags=spec.tags, + tag_origin="manual", + require_existing_tags=False, + ) + info_id = result["asset_info_id"] + if not info_id: + raise RuntimeError("failed to create asset metadata") + + pair = fetch_asset_info_and_asset(session, asset_info_id=info_id, owner_id=owner_id) + if not pair: + raise RuntimeError("inconsistent DB state after ingest") + info, asset = pair + tag_names = get_asset_tags(session, asset_info_id=info.id) + created_result = schemas_out.AssetCreated( + id=info.id, + name=info.name, + asset_hash=asset.hash, + size=int(asset.size_bytes), + mime_type=asset.mime_type, + tags=tag_names, + user_metadata=info.user_metadata or {}, + preview_id=info.preview_id, + created_at=info.created_at, + last_access_time=info.last_access_time, + created_new=result["asset_created"], + ) + session.commit() + + return created_result + + +def update_asset( + *, + asset_info_id: str, + name: str | None = None, + tags: list[str] | None = None, + user_metadata: dict | None = None, + owner_id: str = "", +) -> schemas_out.AssetUpdated: + with create_session() as session: + info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id) + if not info_row: + raise ValueError(f"AssetInfo {asset_info_id} not found") + if info_row.owner_id and info_row.owner_id != owner_id: + raise PermissionError("not owner") + + info = update_asset_info_full( + session, + asset_info_id=asset_info_id, + name=name, + tags=tags, + user_metadata=user_metadata, + tag_origin="manual", + asset_info_row=info_row, + ) + + tag_names = get_asset_tags(session, asset_info_id=asset_info_id) + result = schemas_out.AssetUpdated( + id=info.id, + name=info.name, + asset_hash=info.asset.hash if info.asset else None, + tags=tag_names, + user_metadata=info.user_metadata or {}, + updated_at=info.updated_at, + ) + session.commit() + + return result + + +def set_asset_preview( + *, + asset_info_id: str, + preview_asset_id: str | None = None, + owner_id: str = "", +) -> schemas_out.AssetDetail: + with create_session() as session: + info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id) + if not info_row: + raise ValueError(f"AssetInfo {asset_info_id} not found") + if info_row.owner_id and info_row.owner_id != owner_id: + raise PermissionError("not owner") + + set_asset_info_preview( + session, + asset_info_id=asset_info_id, + preview_asset_id=preview_asset_id, + ) + + res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id) + if not res: + raise RuntimeError("State changed during preview update") + info, asset, tags = res + result = schemas_out.AssetDetail( + id=info.id, + name=info.name, + asset_hash=asset.hash if asset else None, + size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None, + mime_type=asset.mime_type if asset else None, + tags=tags, + user_metadata=info.user_metadata or {}, + preview_id=info.preview_id, + created_at=info.created_at, + last_access_time=info.last_access_time, + ) + session.commit() + + return result + + +def delete_asset_reference(*, asset_info_id: str, owner_id: str, delete_content_if_orphan: bool = True) -> bool: + with create_session() as session: + info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id) + asset_id = info_row.asset_id if info_row else None + deleted = delete_asset_info_by_id(session, asset_info_id=asset_info_id, owner_id=owner_id) + if not deleted: + session.commit() + return False + + if not delete_content_if_orphan or not asset_id: + session.commit() + return True + + still_exists = asset_info_exists_for_asset_id(session, asset_id=asset_id) + if still_exists: + session.commit() + return True + + states = list_cache_states_by_asset_id(session, asset_id=asset_id) + file_paths = [s.file_path for s in (states or []) if getattr(s, "file_path", None)] + + asset_row = session.get(Asset, asset_id) + if asset_row is not None: + session.delete(asset_row) + + session.commit() + for p in file_paths: + with contextlib.suppress(Exception): + if p and os.path.isfile(p): + os.remove(p) + return True + + +def create_asset_from_hash( + *, + hash_str: str, + name: str, + tags: list[str] | None = None, + user_metadata: dict | None = None, + owner_id: str = "", +) -> schemas_out.AssetCreated | None: + canonical = hash_str.strip().lower() + with create_session() as session: + asset = get_asset_by_hash(session, asset_hash=canonical) + if not asset: + return None + + info = create_asset_info_for_existing_asset( + session, + asset_hash=canonical, + name=_safe_filename(name, fallback=canonical.split(":", 1)[1]), + user_metadata=user_metadata or {}, + tags=tags or [], + tag_origin="manual", + owner_id=owner_id, + ) + tag_names = get_asset_tags(session, asset_info_id=info.id) + result = schemas_out.AssetCreated( + id=info.id, + name=info.name, + asset_hash=asset.hash, + size=int(asset.size_bytes), + mime_type=asset.mime_type, + tags=tag_names, + user_metadata=info.user_metadata or {}, + preview_id=info.preview_id, + created_at=info.created_at, + last_access_time=info.last_access_time, + created_new=False, + ) + session.commit() + + return result + + +def add_tags_to_asset( + *, + asset_info_id: str, + tags: list[str], + origin: str = "manual", + owner_id: str = "", +) -> schemas_out.TagsAdd: + with create_session() as session: + info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id) + if not info_row: + raise ValueError(f"AssetInfo {asset_info_id} not found") + if info_row.owner_id and info_row.owner_id != owner_id: + raise PermissionError("not owner") + data = add_tags_to_asset_info( + session, + asset_info_id=asset_info_id, + tags=tags, + origin=origin, + create_if_missing=True, + asset_info_row=info_row, + ) + session.commit() + return schemas_out.TagsAdd(**data) + + +def remove_tags_from_asset( + *, + asset_info_id: str, + tags: list[str], + owner_id: str = "", +) -> schemas_out.TagsRemove: + with create_session() as session: + info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id) + if not info_row: + raise ValueError(f"AssetInfo {asset_info_id} not found") + if info_row.owner_id and info_row.owner_id != owner_id: + raise PermissionError("not owner") + + data = remove_tags_from_asset_info( + session, + asset_info_id=asset_info_id, + tags=tags, + ) + session.commit() + return schemas_out.TagsRemove(**data) + + def list_tags( prefix: str | None = None, limit: int = 100, diff --git a/app/assets/scanner.py b/app/assets/scanner.py index a16e41d94..0172a5c2f 100644 --- a/app/assets/scanner.py +++ b/app/assets/scanner.py @@ -27,6 +27,7 @@ def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> No t_start = time.perf_counter() created = 0 skipped_existing = 0 + orphans_pruned = 0 paths: list[str] = [] try: existing_paths: set[str] = set() @@ -38,6 +39,11 @@ def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> No except Exception as e: logging.exception("fast DB scan failed for %s: %s", r, e) + try: + orphans_pruned = _prune_orphaned_assets(roots) + except Exception as e: + logging.exception("orphan pruning failed: %s", e) + if "models" in roots: paths.extend(collect_models_files()) if "input" in roots: @@ -85,15 +91,43 @@ def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> No finally: if enable_logging: logging.info( - "Assets scan(roots=%s) completed in %.3fs (created=%d, skipped_existing=%d, total_seen=%d)", + "Assets scan(roots=%s) completed in %.3fs (created=%d, skipped_existing=%d, orphans_pruned=%d, total_seen=%d)", roots, time.perf_counter() - t_start, created, skipped_existing, + orphans_pruned, len(paths), ) +def _prune_orphaned_assets(roots: tuple[RootType, ...]) -> int: + """Prune cache states outside configured prefixes, then delete orphaned seed assets.""" + all_prefixes = [os.path.abspath(p) for r in roots for p in prefixes_for_root(r)] + if not all_prefixes: + return 0 + + def make_prefix_condition(prefix: str): + base = prefix if prefix.endswith(os.sep) else prefix + os.sep + escaped, esc = escape_like_prefix(base) + return AssetCacheState.file_path.like(escaped + "%", escape=esc) + + matches_valid_prefix = sqlalchemy.or_(*[make_prefix_condition(p) for p in all_prefixes]) + + orphan_subq = ( + sqlalchemy.select(Asset.id) + .outerjoin(AssetCacheState, AssetCacheState.asset_id == Asset.id) + .where(Asset.hash.is_(None), AssetCacheState.id.is_(None)) + ).scalar_subquery() + + with create_session() as sess: + sess.execute(sqlalchemy.delete(AssetCacheState).where(~matches_valid_prefix)) + sess.execute(sqlalchemy.delete(AssetInfo).where(AssetInfo.asset_id.in_(orphan_subq))) + result = sess.execute(sqlalchemy.delete(Asset).where(Asset.id.in_(orphan_subq))) + sess.commit() + return result.rowcount + + def _fast_db_consistency_pass( root: RootType, *, diff --git a/tests-unit/assets_test/conftest.py b/tests-unit/assets_test/conftest.py new file mode 100644 index 000000000..0a57dd7b5 --- /dev/null +++ b/tests-unit/assets_test/conftest.py @@ -0,0 +1,271 @@ +import contextlib +import json +import os +import socket +import subprocess +import sys +import tempfile +import time +from pathlib import Path +from typing import Callable, Iterator, Optional + +import pytest +import requests + + +def pytest_addoption(parser: pytest.Parser) -> None: + """ + Allow overriding the database URL used by the spawned ComfyUI process. + Priority: + 1) --db-url command line option + 2) ASSETS_TEST_DB_URL environment variable (used by CI) + 3) default: None (will use file-backed sqlite in temp dir) + """ + parser.addoption( + "--db-url", + action="store", + default=os.environ.get("ASSETS_TEST_DB_URL"), + help="SQLAlchemy DB URL (e.g. sqlite:///path/to/db.sqlite3)", + ) + + +def _free_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + +def _make_base_dirs(root: Path) -> None: + for sub in ("models", "custom_nodes", "input", "output", "temp", "user"): + (root / sub).mkdir(parents=True, exist_ok=True) + + +def _wait_http_ready(base: str, session: requests.Session, timeout: float = 90.0) -> None: + start = time.time() + last_err = None + while time.time() - start < timeout: + try: + r = session.get(base + "/api/assets", timeout=5) + if r.status_code in (200, 400): + return + except Exception as e: + last_err = e + time.sleep(0.25) + raise RuntimeError(f"ComfyUI HTTP did not become ready: {last_err}") + + +@pytest.fixture(scope="session") +def comfy_tmp_base_dir() -> Path: + env_base = os.environ.get("ASSETS_TEST_BASE_DIR") + created_by_fixture = False + if env_base: + tmp = Path(env_base) + tmp.mkdir(parents=True, exist_ok=True) + else: + tmp = Path(tempfile.mkdtemp(prefix="comfyui-assets-tests-")) + created_by_fixture = True + _make_base_dirs(tmp) + yield tmp + if created_by_fixture: + with contextlib.suppress(Exception): + for p in sorted(tmp.rglob("*"), reverse=True): + if p.is_file() or p.is_symlink(): + p.unlink(missing_ok=True) + for p in sorted(tmp.glob("**/*"), reverse=True): + with contextlib.suppress(Exception): + p.rmdir() + tmp.rmdir() + + +@pytest.fixture(scope="session") +def comfy_url_and_proc(comfy_tmp_base_dir: Path, request: pytest.FixtureRequest): + """ + Boot ComfyUI subprocess with: + - sandbox base dir + - file-backed sqlite DB in temp dir + - autoscan disabled + Returns (base_url, process, port) + """ + port = _free_port() + db_url = request.config.getoption("--db-url") + if not db_url: + # Use a file-backed sqlite database in the temp directory + db_path = comfy_tmp_base_dir / "assets-test.sqlite3" + db_url = f"sqlite:///{db_path}" + + logs_dir = comfy_tmp_base_dir / "logs" + logs_dir.mkdir(exist_ok=True) + out_log = open(logs_dir / "stdout.log", "w", buffering=1) + err_log = open(logs_dir / "stderr.log", "w", buffering=1) + + comfy_root = Path(__file__).resolve().parent.parent.parent + if not (comfy_root / "main.py").is_file(): + raise FileNotFoundError(f"main.py not found under {comfy_root}") + + proc = subprocess.Popen( + args=[ + sys.executable, + "main.py", + f"--base-directory={str(comfy_tmp_base_dir)}", + f"--database-url={db_url}", + "--disable-assets-autoscan", + "--listen", + "127.0.0.1", + "--port", + str(port), + "--cpu", + ], + stdout=out_log, + stderr=err_log, + cwd=str(comfy_root), + env={**os.environ}, + ) + + for _ in range(50): + if proc.poll() is not None: + out_log.flush() + err_log.flush() + raise RuntimeError(f"ComfyUI exited early with code {proc.returncode}") + time.sleep(0.1) + + base_url = f"http://127.0.0.1:{port}" + try: + with requests.Session() as s: + _wait_http_ready(base_url, s, timeout=90.0) + yield base_url, proc, port + except Exception as e: + with contextlib.suppress(Exception): + proc.terminate() + proc.wait(timeout=10) + with contextlib.suppress(Exception): + out_log.flush() + err_log.flush() + raise RuntimeError(f"ComfyUI did not become ready: {e}") + + if proc and proc.poll() is None: + with contextlib.suppress(Exception): + proc.terminate() + proc.wait(timeout=15) + out_log.close() + err_log.close() + + +@pytest.fixture +def http() -> Iterator[requests.Session]: + with requests.Session() as s: + s.timeout = 120 + yield s + + +@pytest.fixture +def api_base(comfy_url_and_proc) -> str: + base_url, _proc, _port = comfy_url_and_proc + return base_url + + +def _post_multipart_asset( + session: requests.Session, + base: str, + *, + name: str, + tags: list[str], + meta: dict, + data: bytes, + extra_fields: Optional[dict] = None, +) -> tuple[int, dict]: + files = {"file": (name, data, "application/octet-stream")} + form_data = { + "tags": json.dumps(tags), + "name": name, + "user_metadata": json.dumps(meta), + } + if extra_fields: + for k, v in extra_fields.items(): + form_data[k] = v + r = session.post(base + "/api/assets", files=files, data=form_data, timeout=120) + return r.status_code, r.json() + + +@pytest.fixture +def make_asset_bytes() -> Callable[[str, int], bytes]: + def _make(name: str, size: int = 8192) -> bytes: + seed = sum(ord(c) for c in name) % 251 + return bytes((i * 31 + seed) % 256 for i in range(size)) + return _make + + +@pytest.fixture +def asset_factory(http: requests.Session, api_base: str): + """ + Returns create(name, tags, meta, data) -> response dict + Tracks created ids and deletes them after the test. + """ + created: list[str] = [] + + def create(name: str, tags: list[str], meta: dict, data: bytes) -> dict: + status, body = _post_multipart_asset(http, api_base, name=name, tags=tags, meta=meta, data=data) + assert status in (200, 201), body + created.append(body["id"]) + return body + + yield create + + for aid in created: + with contextlib.suppress(Exception): + http.delete(f"{api_base}/api/assets/{aid}", timeout=30) + + +@pytest.fixture +def seeded_asset(request: pytest.FixtureRequest, http: requests.Session, api_base: str) -> dict: + """ + Upload one asset with ".safetensors" extension into models/checkpoints/unit-tests/. + Returns response dict with id, asset_hash, tags, etc. + """ + name = "unit_1_example.safetensors" + p = getattr(request, "param", {}) or {} + tags: Optional[list[str]] = p.get("tags") + if tags is None: + tags = ["models", "checkpoints", "unit-tests", "alpha"] + meta = {"purpose": "test", "epoch": 1, "flags": ["x", "y"], "nullable": None} + files = {"file": (name, b"A" * 4096, "application/octet-stream")} + form_data = { + "tags": json.dumps(tags), + "name": name, + "user_metadata": json.dumps(meta), + } + r = http.post(api_base + "/api/assets", files=files, data=form_data, timeout=120) + body = r.json() + assert r.status_code == 201, body + return body + + +@pytest.fixture(autouse=True) +def autoclean_unit_test_assets(http: requests.Session, api_base: str): + """Ensure isolation by removing all AssetInfo rows tagged with 'unit-tests' after each test.""" + yield + + while True: + r = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests", "limit": "500", "sort": "name"}, + timeout=30, + ) + if r.status_code != 200: + break + body = r.json() + ids = [a["id"] for a in body.get("assets", [])] + if not ids: + break + for aid in ids: + with contextlib.suppress(Exception): + http.delete(f"{api_base}/api/assets/{aid}", timeout=30) + + +def trigger_sync_seed_assets(session: requests.Session, base_url: str) -> None: + """Force a fast sync/seed pass by calling the seed endpoint.""" + session.post(base_url + "/api/assets/seed", json={"roots": ["models", "input", "output"]}, timeout=30) + time.sleep(0.2) + + +def get_asset_filename(asset_hash: str, extension: str) -> str: + return asset_hash.removeprefix("blake3:") + extension diff --git a/tests-unit/assets_test/test_assets_missing_sync.py b/tests-unit/assets_test/test_assets_missing_sync.py new file mode 100644 index 000000000..78fa7b404 --- /dev/null +++ b/tests-unit/assets_test/test_assets_missing_sync.py @@ -0,0 +1,348 @@ +import os +import uuid +from pathlib import Path + +import pytest +import requests +from conftest import get_asset_filename, trigger_sync_seed_assets + + + + +@pytest.mark.parametrize("root", ["input", "output"]) +def test_seed_asset_removed_when_file_is_deleted( + root: str, + http: requests.Session, + api_base: str, + comfy_tmp_base_dir: Path, +): + """Asset without hash (seed) whose file disappears: + after triggering sync_seed_assets, Asset + AssetInfo disappear. + """ + # Create a file directly under input/unit-tests/ so tags include "unit-tests" + case_dir = comfy_tmp_base_dir / root / "unit-tests" / "syncseed" + case_dir.mkdir(parents=True, exist_ok=True) + name = f"seed_{uuid.uuid4().hex[:8]}.bin" + fp = case_dir / name + fp.write_bytes(b"Z" * 2048) + + # Trigger a seed sync so DB sees this path (seed asset => hash is NULL) + trigger_sync_seed_assets(http, api_base) + + # Verify it is visible via API and carries no hash (seed) + r1 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,syncseed", "name_contains": name}, + timeout=120, + ) + body1 = r1.json() + assert r1.status_code == 200 + # there should be exactly one with that name + matches = [a for a in body1.get("assets", []) if a.get("name") == name] + assert matches + assert matches[0].get("asset_hash") is None + asset_info_id = matches[0]["id"] + + # Remove the underlying file and sync again + if fp.exists(): + fp.unlink() + + trigger_sync_seed_assets(http, api_base) + + # It should disappear (AssetInfo and seed Asset gone) + r2 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,syncseed", "name_contains": name}, + timeout=120, + ) + body2 = r2.json() + assert r2.status_code == 200 + matches2 = [a for a in body2.get("assets", []) if a.get("name") == name] + assert not matches2, f"Seed asset {asset_info_id} should be gone after sync" + + +@pytest.mark.skip(reason="Requires computing hashes of files in directories to verify and clear missing tags") +def test_hashed_asset_missing_tag_added_then_removed_after_scan( + http: requests.Session, + api_base: str, + comfy_tmp_base_dir: Path, + asset_factory, + make_asset_bytes, +): + """Hashed asset with a single cache_state: + 1. delete its file -> sync adds 'missing' + 2. restore file -> sync removes 'missing' + """ + name = "missing_tag_test.png" + tags = ["input", "unit-tests", "msync2"] + data = make_asset_bytes(name, 4096) + a = asset_factory(name, tags, {}, data) + + # Compute its on-disk path and remove it + dest = comfy_tmp_base_dir / "input" / "unit-tests" / "msync2" / get_asset_filename(a["asset_hash"], ".png") + assert dest.exists(), f"Expected asset file at {dest}" + dest.unlink() + + # Fast sync should add 'missing' to the AssetInfo + trigger_sync_seed_assets(http, api_base) + + g1 = http.get(f"{api_base}/api/assets/{a['id']}", timeout=120) + d1 = g1.json() + assert g1.status_code == 200, d1 + assert "missing" in set(d1.get("tags", [])), "Expected 'missing' tag after deletion" + + # Restore the file with the exact same content and sync again + dest.parent.mkdir(parents=True, exist_ok=True) + dest.write_bytes(data) + + trigger_sync_seed_assets(http, api_base) + + g2 = http.get(f"{api_base}/api/assets/{a['id']}", timeout=120) + d2 = g2.json() + assert g2.status_code == 200, d2 + assert "missing" not in set(d2.get("tags", [])), "Missing tag should be cleared after verify" + + +def test_hashed_asset_two_asset_infos_both_get_missing( + http: requests.Session, + api_base: str, + comfy_tmp_base_dir: Path, + asset_factory, +): + """Hashed asset with a single cache_state, but two AssetInfo rows: + deleting the single file then syncing should add 'missing' to both infos. + """ + # Upload one hashed asset + name = "two_infos_one_path.png" + base_tags = ["input", "unit-tests", "multiinfo"] + created = asset_factory(name, base_tags, {}, b"A" * 2048) + + # Create second AssetInfo for the same Asset via from-hash + payload = { + "hash": created["asset_hash"], + "name": "two_infos_one_path_copy.png", + "tags": base_tags, # keep it in our unit-tests scope for cleanup + "user_metadata": {"k": "v"}, + } + r2 = http.post(api_base + "/api/assets/from-hash", json=payload, timeout=120) + b2 = r2.json() + assert r2.status_code == 201, b2 + second_id = b2["id"] + + # Remove the single underlying file + p = comfy_tmp_base_dir / "input" / "unit-tests" / "multiinfo" / get_asset_filename(b2["asset_hash"], ".png") + assert p.exists() + p.unlink() + + r0 = http.get(api_base + "/api/tags", params={"limit": "1000", "include_zero": "false"}, timeout=120) + tags0 = r0.json() + assert r0.status_code == 200, tags0 + byname0 = {t["name"]: t for t in tags0.get("tags", [])} + old_missing = int(byname0.get("missing", {}).get("count", 0)) + + # Sync -> both AssetInfos for this asset must receive 'missing' + trigger_sync_seed_assets(http, api_base) + + ga = http.get(f"{api_base}/api/assets/{created['id']}", timeout=120) + da = ga.json() + assert ga.status_code == 200, da + assert "missing" in set(da.get("tags", [])) + + gb = http.get(f"{api_base}/api/assets/{second_id}", timeout=120) + db = gb.json() + assert gb.status_code == 200, db + assert "missing" in set(db.get("tags", [])) + + # Tag usage for 'missing' increased by exactly 2 (two AssetInfos) + r1 = http.get(api_base + "/api/tags", params={"limit": "1000", "include_zero": "false"}, timeout=120) + tags1 = r1.json() + assert r1.status_code == 200, tags1 + byname1 = {t["name"]: t for t in tags1.get("tags", [])} + new_missing = int(byname1.get("missing", {}).get("count", 0)) + assert new_missing == old_missing + 2 + + +@pytest.mark.skip(reason="Requires computing hashes of files in directories to deduplicate into multiple cache states") +def test_hashed_asset_two_cache_states_partial_delete_then_full_delete( + http: requests.Session, + api_base: str, + comfy_tmp_base_dir: Path, + asset_factory, + make_asset_bytes, + run_scan_and_wait, +): + """Hashed asset with two cache_state rows: + 1. delete one file -> sync should NOT add 'missing' + 2. delete second file -> sync should add 'missing' + """ + name = "two_cache_states_partial_delete.png" + tags = ["input", "unit-tests", "dual"] + data = make_asset_bytes(name, 3072) + + created = asset_factory(name, tags, {}, data) + path1 = comfy_tmp_base_dir / "input" / "unit-tests" / "dual" / get_asset_filename(created["asset_hash"], ".png") + assert path1.exists() + + # Create a second on-disk copy under the same root but different subfolder + path2 = comfy_tmp_base_dir / "input" / "unit-tests" / "dual_copy" / name + path2.parent.mkdir(parents=True, exist_ok=True) + path2.write_bytes(data) + + # Fast seed so the second path appears (as a seed initially) + trigger_sync_seed_assets(http, api_base) + + # Deduplication of AssetInfo-s will not happen as first AssetInfo has owner='default' and second has empty owner. + run_scan_and_wait("input") + + # Remove only one file and sync -> asset should still be healthy (no 'missing') + path1.unlink() + trigger_sync_seed_assets(http, api_base) + + g1 = http.get(f"{api_base}/api/assets/{created['id']}", timeout=120) + d1 = g1.json() + assert g1.status_code == 200, d1 + assert "missing" not in set(d1.get("tags", [])), "Should not be missing while one valid path remains" + + # Baseline 'missing' usage count just before last file removal + r0 = http.get(api_base + "/api/tags", params={"limit": "1000", "include_zero": "false"}, timeout=120) + tags0 = r0.json() + assert r0.status_code == 200, tags0 + old_missing = int({t["name"]: t for t in tags0.get("tags", [])}.get("missing", {}).get("count", 0)) + + # Remove the second (last) file and sync -> now we expect 'missing' on this AssetInfo + path2.unlink() + trigger_sync_seed_assets(http, api_base) + + g2 = http.get(f"{api_base}/api/assets/{created['id']}", timeout=120) + d2 = g2.json() + assert g2.status_code == 200, d2 + assert "missing" in set(d2.get("tags", [])), "Missing must be set once no valid paths remain" + + # Tag usage for 'missing' increased by exactly 2 (two AssetInfo for one Asset) + r1 = http.get(api_base + "/api/tags", params={"limit": "1000", "include_zero": "false"}, timeout=120) + tags1 = r1.json() + assert r1.status_code == 200, tags1 + new_missing = int({t["name"]: t for t in tags1.get("tags", [])}.get("missing", {}).get("count", 0)) + assert new_missing == old_missing + 2 + + +@pytest.mark.parametrize("root", ["input", "output"]) +def test_missing_tag_clears_on_fastpass_when_mtime_and_size_match( + root: str, + http: requests.Session, + api_base: str, + comfy_tmp_base_dir: Path, + asset_factory, + make_asset_bytes, +): + """ + Fast pass alone clears 'missing' when size and mtime match exactly: + 1) upload (hashed), record original mtime_ns + 2) delete -> fast pass adds 'missing' + 3) restore same bytes and set mtime back to the original value + 4) run fast pass again -> 'missing' is removed (no slow scan) + """ + scope = f"fastclear-{uuid.uuid4().hex[:6]}" + name = "fastpass_clear.bin" + data = make_asset_bytes(name, 3072) + + a = asset_factory(name, [root, "unit-tests", scope], {}, data) + aid = a["id"] + base = comfy_tmp_base_dir / root / "unit-tests" / scope + p = base / get_asset_filename(a["asset_hash"], ".bin") + st0 = p.stat() + orig_mtime_ns = getattr(st0, "st_mtime_ns", int(st0.st_mtime * 1_000_000_000)) + + # Delete -> fast pass adds 'missing' + p.unlink() + trigger_sync_seed_assets(http, api_base) + g1 = http.get(f"{api_base}/api/assets/{aid}", timeout=120) + d1 = g1.json() + assert g1.status_code == 200, d1 + assert "missing" in set(d1.get("tags", [])) + + # Restore same bytes and revert mtime to the original value + p.parent.mkdir(parents=True, exist_ok=True) + p.write_bytes(data) + # set both atime and mtime in ns to ensure exact match + os.utime(p, ns=(orig_mtime_ns, orig_mtime_ns)) + + # Fast pass should clear 'missing' without a scan + trigger_sync_seed_assets(http, api_base) + g2 = http.get(f"{api_base}/api/assets/{aid}", timeout=120) + d2 = g2.json() + assert g2.status_code == 200, d2 + assert "missing" not in set(d2.get("tags", [])), "Fast pass should clear 'missing' when size+mtime match" + + +@pytest.mark.skip(reason="Requires computing hashes of files in directories to deduplicate into multiple cache states") +@pytest.mark.parametrize("root", ["input", "output"]) +def test_fastpass_removes_stale_state_row_no_missing( + root: str, + http: requests.Session, + api_base: str, + comfy_tmp_base_dir: Path, + asset_factory, + make_asset_bytes, + run_scan_and_wait, +): + """ + Hashed asset with two states: + - delete one file + - run fast pass only + Expect: + - asset stays healthy (no 'missing') + - stale AssetCacheState row for the deleted path is removed. + We verify this behaviorally by recreating the deleted path and running fast pass again: + a new *seed* AssetInfo is created, which proves the old state row was not reused. + """ + scope = f"stale-{uuid.uuid4().hex[:6]}" + name = "two_states.bin" + data = make_asset_bytes(name, 2048) + + # Upload hashed asset at path1 + a = asset_factory(name, [root, "unit-tests", scope], {}, data) + base = comfy_tmp_base_dir / root / "unit-tests" / scope + a1_filename = get_asset_filename(a["asset_hash"], ".bin") + p1 = base / a1_filename + assert p1.exists() + + aid = a["id"] + h = a["asset_hash"] + + # Create second state path2, seed+scan to dedupe into the same Asset + p2 = base / "copy" / name + p2.parent.mkdir(parents=True, exist_ok=True) + p2.write_bytes(data) + trigger_sync_seed_assets(http, api_base) + run_scan_and_wait(root) + + # Delete path1 and run fast pass -> no 'missing' and stale state row should be removed + p1.unlink() + trigger_sync_seed_assets(http, api_base) + g1 = http.get(f"{api_base}/api/assets/{aid}", timeout=120) + d1 = g1.json() + assert g1.status_code == 200, d1 + assert "missing" not in set(d1.get("tags", [])) + + # Recreate path1 and run fast pass again. + # If the stale state row was removed, a NEW seed AssetInfo will appear for this path. + p1.write_bytes(data) + trigger_sync_seed_assets(http, api_base) + + rl = http.get( + api_base + "/api/assets", + params={"include_tags": f"unit-tests,{scope}"}, + timeout=120, + ) + bl = rl.json() + assert rl.status_code == 200, bl + items = bl.get("assets", []) + # one hashed AssetInfo (asset_hash == h) + one seed AssetInfo (asset_hash == null) + hashes = [it.get("asset_hash") for it in items if it.get("name") in (name, a1_filename)] + assert h in hashes + assert any(x is None for x in hashes), "Expected a new seed AssetInfo for the recreated path" + + # Asset identity still healthy + rh = http.head(f"{api_base}/api/assets/hash/{h}", timeout=120) + assert rh.status_code == 200 diff --git a/tests-unit/assets_test/test_crud.py b/tests-unit/assets_test/test_crud.py new file mode 100644 index 000000000..d2b69f475 --- /dev/null +++ b/tests-unit/assets_test/test_crud.py @@ -0,0 +1,306 @@ +import uuid +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path + +import pytest +import requests +from conftest import get_asset_filename, trigger_sync_seed_assets + + +def test_create_from_hash_success( + http: requests.Session, api_base: str, seeded_asset: dict +): + h = seeded_asset["asset_hash"] + payload = { + "hash": h, + "name": "from_hash_ok.safetensors", + "tags": ["models", "checkpoints", "unit-tests", "from-hash"], + "user_metadata": {"k": "v"}, + } + r1 = http.post(f"{api_base}/api/assets/from-hash", json=payload, timeout=120) + b1 = r1.json() + assert r1.status_code == 201, b1 + assert b1["asset_hash"] == h + assert b1["created_new"] is False + aid = b1["id"] + + # Calling again with the same name should return the same AssetInfo id + r2 = http.post(f"{api_base}/api/assets/from-hash", json=payload, timeout=120) + b2 = r2.json() + assert r2.status_code == 201, b2 + assert b2["id"] == aid + + +def test_get_and_delete_asset(http: requests.Session, api_base: str, seeded_asset: dict): + aid = seeded_asset["id"] + + # GET detail + rg = http.get(f"{api_base}/api/assets/{aid}", timeout=120) + detail = rg.json() + assert rg.status_code == 200, detail + assert detail["id"] == aid + assert "user_metadata" in detail + assert "filename" in detail["user_metadata"] + + # DELETE + rd = http.delete(f"{api_base}/api/assets/{aid}", timeout=120) + assert rd.status_code == 204 + + # GET again -> 404 + rg2 = http.get(f"{api_base}/api/assets/{aid}", timeout=120) + body = rg2.json() + assert rg2.status_code == 404 + assert body["error"]["code"] == "ASSET_NOT_FOUND" + + +def test_delete_upon_reference_count( + http: requests.Session, api_base: str, seeded_asset: dict +): + # Create a second reference to the same asset via from-hash + src_hash = seeded_asset["asset_hash"] + payload = { + "hash": src_hash, + "name": "unit_ref_copy.safetensors", + "tags": ["models", "checkpoints", "unit-tests", "del-flow"], + "user_metadata": {"note": "copy"}, + } + r2 = http.post(f"{api_base}/api/assets/from-hash", json=payload, timeout=120) + copy = r2.json() + assert r2.status_code == 201, copy + assert copy["asset_hash"] == src_hash + assert copy["created_new"] is False + + # Delete original reference -> asset identity must remain + aid1 = seeded_asset["id"] + rd1 = http.delete(f"{api_base}/api/assets/{aid1}", timeout=120) + assert rd1.status_code == 204 + + rh1 = http.head(f"{api_base}/api/assets/hash/{src_hash}", timeout=120) + assert rh1.status_code == 200 # identity still present + + # Delete the last reference with default semantics -> identity and cached files removed + aid2 = copy["id"] + rd2 = http.delete(f"{api_base}/api/assets/{aid2}", timeout=120) + assert rd2.status_code == 204 + + rh2 = http.head(f"{api_base}/api/assets/hash/{src_hash}", timeout=120) + assert rh2.status_code == 404 # orphan content removed + + +def test_update_asset_fields(http: requests.Session, api_base: str, seeded_asset: dict): + aid = seeded_asset["id"] + original_tags = seeded_asset["tags"] + + payload = { + "name": "unit_1_renamed.safetensors", + "user_metadata": {"purpose": "updated", "epoch": 2}, + } + ru = http.put(f"{api_base}/api/assets/{aid}", json=payload, timeout=120) + body = ru.json() + assert ru.status_code == 200, body + assert body["name"] == payload["name"] + assert body["tags"] == original_tags # tags unchanged + assert body["user_metadata"]["purpose"] == "updated" + # filename should still be present and normalized by server + assert "filename" in body["user_metadata"] + + +def test_head_asset_by_hash(http: requests.Session, api_base: str, seeded_asset: dict): + h = seeded_asset["asset_hash"] + + # Existing + rh1 = http.head(f"{api_base}/api/assets/hash/{h}", timeout=120) + assert rh1.status_code == 200 + + # Non-existent + rh2 = http.head(f"{api_base}/api/assets/hash/blake3:{'0'*64}", timeout=120) + assert rh2.status_code == 404 + + +def test_head_asset_bad_hash_returns_400_and_no_body(http: requests.Session, api_base: str): + # Invalid format; handler returns a JSON error, but HEAD responses must not carry a payload. + # requests exposes an empty body for HEAD, so validate status and that there is no payload. + rh = http.head(f"{api_base}/api/assets/hash/not_a_hash", timeout=120) + assert rh.status_code == 400 + body = rh.content + assert body == b"" + + +def test_delete_nonexistent_returns_404(http: requests.Session, api_base: str): + bogus = str(uuid.uuid4()) + r = http.delete(f"{api_base}/api/assets/{bogus}", timeout=120) + body = r.json() + assert r.status_code == 404 + assert body["error"]["code"] == "ASSET_NOT_FOUND" + + +def test_create_from_hash_invalids(http: requests.Session, api_base: str): + # Bad hash algorithm + bad = { + "hash": "sha256:" + "0" * 64, + "name": "x.bin", + "tags": ["models", "checkpoints", "unit-tests"], + } + r1 = http.post(f"{api_base}/api/assets/from-hash", json=bad, timeout=120) + b1 = r1.json() + assert r1.status_code == 400 + assert b1["error"]["code"] == "INVALID_BODY" + + # Invalid JSON body + r2 = http.post(f"{api_base}/api/assets/from-hash", data=b"{not json}", timeout=120) + b2 = r2.json() + assert r2.status_code == 400 + assert b2["error"]["code"] == "INVALID_JSON" + + +def test_get_update_download_bad_ids(http: requests.Session, api_base: str): + # All endpoints should be not found, as we UUID regex directly in the route definition. + bad_id = "not-a-uuid" + + r1 = http.get(f"{api_base}/api/assets/{bad_id}", timeout=120) + assert r1.status_code == 404 + + r3 = http.get(f"{api_base}/api/assets/{bad_id}/content", timeout=120) + assert r3.status_code == 404 + + +def test_update_requires_at_least_one_field(http: requests.Session, api_base: str, seeded_asset: dict): + aid = seeded_asset["id"] + r = http.put(f"{api_base}/api/assets/{aid}", json={}, timeout=120) + body = r.json() + assert r.status_code == 400 + assert body["error"]["code"] == "INVALID_BODY" + + +@pytest.mark.parametrize("root", ["input", "output"]) +def test_concurrent_delete_same_asset_info_single_204( + root: str, + http: requests.Session, + api_base: str, + asset_factory, + make_asset_bytes, +): + """ + Many concurrent DELETE for the same AssetInfo should result in: + - exactly one 204 No Content (the one that actually deleted) + - all others 404 Not Found (row already gone) + """ + scope = f"conc-del-{uuid.uuid4().hex[:6]}" + name = "to_delete.bin" + data = make_asset_bytes(name, 1536) + + created = asset_factory(name, [root, "unit-tests", scope], {}, data) + aid = created["id"] + + # Hit the same endpoint N times in parallel. + n_tests = 4 + url = f"{api_base}/api/assets/{aid}?delete_content=false" + + def _do_delete(delete_url): + with requests.Session() as s: + return s.delete(delete_url, timeout=120).status_code + + with ThreadPoolExecutor(max_workers=n_tests) as ex: + statuses = list(ex.map(_do_delete, [url] * n_tests)) + + # Exactly one actual delete, the rest must be 404 + assert statuses.count(204) == 1, f"Expected exactly one 204; got: {statuses}" + assert statuses.count(404) == n_tests - 1, f"Expected {n_tests-1} 404; got: {statuses}" + + # The resource must be gone. + rg = http.get(f"{api_base}/api/assets/{aid}", timeout=120) + assert rg.status_code == 404 + + +@pytest.mark.parametrize("root", ["input", "output"]) +def test_metadata_filename_is_set_for_seed_asset_without_hash( + root: str, + http: requests.Session, + api_base: str, + comfy_tmp_base_dir: Path, +): + """Seed ingest (no hash yet) must compute user_metadata['filename'] immediately.""" + scope = f"seedmeta-{uuid.uuid4().hex[:6]}" + name = "seed_filename.bin" + + base = comfy_tmp_base_dir / root / "unit-tests" / scope / "a" / "b" + base.mkdir(parents=True, exist_ok=True) + fp = base / name + fp.write_bytes(b"Z" * 2048) + + trigger_sync_seed_assets(http, api_base) + + r1 = http.get( + api_base + "/api/assets", + params={"include_tags": f"unit-tests,{scope}", "name_contains": name}, + timeout=120, + ) + body = r1.json() + assert r1.status_code == 200, body + matches = [a for a in body.get("assets", []) if a.get("name") == name] + assert matches, "Seed asset should be visible after sync" + assert matches[0].get("asset_hash") is None # still a seed + aid = matches[0]["id"] + + r2 = http.get(f"{api_base}/api/assets/{aid}", timeout=120) + detail = r2.json() + assert r2.status_code == 200, detail + filename = (detail.get("user_metadata") or {}).get("filename") + expected = str(fp.relative_to(comfy_tmp_base_dir / root)).replace("\\", "/") + assert filename == expected, f"expected filename={expected}, got {filename!r}" + + +@pytest.mark.skip(reason="Requires computing hashes of files in directories to retarget cache states") +@pytest.mark.parametrize("root", ["input", "output"]) +def test_metadata_filename_computed_and_updated_on_retarget( + root: str, + http: requests.Session, + api_base: str, + comfy_tmp_base_dir: Path, + asset_factory, + make_asset_bytes, + run_scan_and_wait, +): + """ + 1) Ingest under {root}/unit-tests//a/b/ -> filename reflects relative path. + 2) Retarget by copying to {root}/unit-tests//x/, remove old file, + run fast pass + scan -> filename updates to new relative path. + """ + scope = f"meta-fn-{uuid.uuid4().hex[:6]}" + name1 = "compute_metadata_filename.png" + name2 = "compute_changed_metadata_filename.png" + data = make_asset_bytes(name1, 2100) + + # Upload into nested path a/b + a = asset_factory(name1, [root, "unit-tests", scope, "a", "b"], {}, data) + aid = a["id"] + + root_base = comfy_tmp_base_dir / root + p1 = (root_base / "unit-tests" / scope / "a" / "b" / get_asset_filename(a["asset_hash"], ".png")) + assert p1.exists() + + # filename at ingest should be the path relative to root + rel1 = str(p1.relative_to(root_base)).replace("\\", "/") + g1 = http.get(f"{api_base}/api/assets/{aid}", timeout=120) + d1 = g1.json() + assert g1.status_code == 200, d1 + fn1 = d1["user_metadata"].get("filename") + assert fn1 == rel1 + + # Retarget: copy to x/, remove old, then sync+scan + p2 = root_base / "unit-tests" / scope / "x" / name2 + p2.parent.mkdir(parents=True, exist_ok=True) + p2.write_bytes(data) + if p1.exists(): + p1.unlink() + + trigger_sync_seed_assets(http, api_base) # seed the new path + run_scan_and_wait(root) # verify/hash and reconcile + + # filename should now point at x/ + rel2 = str(p2.relative_to(root_base)).replace("\\", "/") + g2 = http.get(f"{api_base}/api/assets/{aid}", timeout=120) + d2 = g2.json() + assert g2.status_code == 200, d2 + fn2 = d2["user_metadata"].get("filename") + assert fn2 == rel2 diff --git a/tests-unit/assets_test/test_downloads.py b/tests-unit/assets_test/test_downloads.py new file mode 100644 index 000000000..cdebf9082 --- /dev/null +++ b/tests-unit/assets_test/test_downloads.py @@ -0,0 +1,166 @@ +import time +import uuid +from datetime import datetime +from pathlib import Path +from typing import Optional + +import pytest +import requests +from conftest import get_asset_filename, trigger_sync_seed_assets + + +def test_download_attachment_and_inline(http: requests.Session, api_base: str, seeded_asset: dict): + aid = seeded_asset["id"] + + # default attachment + r1 = http.get(f"{api_base}/api/assets/{aid}/content", timeout=120) + data = r1.content + assert r1.status_code == 200 + cd = r1.headers.get("Content-Disposition", "") + assert "attachment" in cd + assert data and len(data) == 4096 + + # inline requested + r2 = http.get(f"{api_base}/api/assets/{aid}/content?disposition=inline", timeout=120) + r2.content + assert r2.status_code == 200 + cd2 = r2.headers.get("Content-Disposition", "") + assert "inline" in cd2 + + +@pytest.mark.skip(reason="Requires computing hashes of files in directories to deduplicate into multiple cache states") +@pytest.mark.parametrize("root", ["input", "output"]) +def test_download_chooses_existing_state_and_updates_access_time( + root: str, + http: requests.Session, + api_base: str, + comfy_tmp_base_dir: Path, + asset_factory, + make_asset_bytes, + run_scan_and_wait, +): + """ + Hashed asset with two state paths: if the first one disappears, + GET /content still serves from the remaining path and bumps last_access_time. + """ + scope = f"dl-first-{uuid.uuid4().hex[:6]}" + name = "first_existing_state.bin" + data = make_asset_bytes(name, 3072) + + # Upload -> path1 + a = asset_factory(name, [root, "unit-tests", scope], {}, data) + aid = a["id"] + + base = comfy_tmp_base_dir / root / "unit-tests" / scope + path1 = base / get_asset_filename(a["asset_hash"], ".bin") + assert path1.exists() + + # Seed path2 by copying, then scan to dedupe into a second state + path2 = base / "alt" / name + path2.parent.mkdir(parents=True, exist_ok=True) + path2.write_bytes(data) + trigger_sync_seed_assets(http, api_base) + run_scan_and_wait(root) + + # Remove path1 so server must fall back to path2 + path1.unlink() + + # last_access_time before + rg0 = http.get(f"{api_base}/api/assets/{aid}", timeout=120) + d0 = rg0.json() + assert rg0.status_code == 200, d0 + ts0 = d0.get("last_access_time") + + time.sleep(0.05) + r = http.get(f"{api_base}/api/assets/{aid}/content", timeout=120) + blob = r.content + assert r.status_code == 200 + assert blob == data # must serve from the surviving state (same bytes) + + rg1 = http.get(f"{api_base}/api/assets/{aid}", timeout=120) + d1 = rg1.json() + assert rg1.status_code == 200, d1 + ts1 = d1.get("last_access_time") + + def _parse_iso8601(s: Optional[str]) -> Optional[float]: + if not s: + return None + s = s[:-1] if s.endswith("Z") else s + return datetime.fromisoformat(s).timestamp() + + t0 = _parse_iso8601(ts0) + t1 = _parse_iso8601(ts1) + assert t1 is not None + if t0 is not None: + assert t1 > t0 + + +@pytest.mark.parametrize("seeded_asset", [{"tags": ["models", "checkpoints"]}], indirect=True) +def test_download_missing_file_returns_404( + http: requests.Session, api_base: str, comfy_tmp_base_dir: Path, seeded_asset: dict +): + # Remove the underlying file then attempt download. + # We initialize fixture without additional tags to know exactly the asset file path. + try: + aid = seeded_asset["id"] + rg = http.get(f"{api_base}/api/assets/{aid}", timeout=120) + detail = rg.json() + assert rg.status_code == 200 + asset_filename = get_asset_filename(detail["asset_hash"], ".safetensors") + abs_path = comfy_tmp_base_dir / "models" / "checkpoints" / asset_filename + assert abs_path.exists() + abs_path.unlink() + + r2 = http.get(f"{api_base}/api/assets/{aid}/content", timeout=120) + assert r2.status_code == 404 + body = r2.json() + assert body["error"]["code"] == "FILE_NOT_FOUND" + finally: + # We created asset without the "unit-tests" tag(see `autoclean_unit_test_assets`), we need to clear it manually. + dr = http.delete(f"{api_base}/api/assets/{aid}", timeout=120) + dr.content + + +@pytest.mark.skip(reason="Requires computing hashes of files in directories to deduplicate into multiple cache states") +@pytest.mark.parametrize("root", ["input", "output"]) +def test_download_404_if_all_states_missing( + root: str, + http: requests.Session, + api_base: str, + comfy_tmp_base_dir: Path, + asset_factory, + make_asset_bytes, + run_scan_and_wait, +): + """Multi-state asset: after the last remaining on-disk file is removed, download must return 404.""" + scope = f"dl-404-{uuid.uuid4().hex[:6]}" + name = "missing_all_states.bin" + data = make_asset_bytes(name, 2048) + + # Upload -> path1 + a = asset_factory(name, [root, "unit-tests", scope], {}, data) + aid = a["id"] + + base = comfy_tmp_base_dir / root / "unit-tests" / scope + p1 = base / get_asset_filename(a["asset_hash"], ".bin") + assert p1.exists() + + # Seed a second state and dedupe + p2 = base / "copy" / name + p2.parent.mkdir(parents=True, exist_ok=True) + p2.write_bytes(data) + trigger_sync_seed_assets(http, api_base) + run_scan_and_wait(root) + + # Remove first file -> download should still work via the second state + p1.unlink() + ok1 = http.get(f"{api_base}/api/assets/{aid}/content", timeout=120) + b1 = ok1.content + assert ok1.status_code == 200 and b1 == data + + # Remove the last file -> download must 404 + p2.unlink() + r2 = http.get(f"{api_base}/api/assets/{aid}/content", timeout=120) + body = r2.json() + assert r2.status_code == 404 + assert body["error"]["code"] == "FILE_NOT_FOUND" diff --git a/tests-unit/assets_test/test_list_filter.py b/tests-unit/assets_test/test_list_filter.py new file mode 100644 index 000000000..82e109832 --- /dev/null +++ b/tests-unit/assets_test/test_list_filter.py @@ -0,0 +1,342 @@ +import time +import uuid + +import requests + + +def test_list_assets_paging_and_sort(http: requests.Session, api_base: str, asset_factory, make_asset_bytes): + names = ["a1_u.safetensors", "a2_u.safetensors", "a3_u.safetensors"] + for n in names: + asset_factory( + n, + ["models", "checkpoints", "unit-tests", "paging"], + {"epoch": 1}, + make_asset_bytes(n, size=2048), + ) + + # name ascending for stable order + r1 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,paging", "sort": "name", "order": "asc", "limit": "2", "offset": "0"}, + timeout=120, + ) + b1 = r1.json() + assert r1.status_code == 200 + got1 = [a["name"] for a in b1["assets"]] + assert got1 == sorted(names)[:2] + assert b1["has_more"] is True + + r2 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,paging", "sort": "name", "order": "asc", "limit": "2", "offset": "2"}, + timeout=120, + ) + b2 = r2.json() + assert r2.status_code == 200 + got2 = [a["name"] for a in b2["assets"]] + assert got2 == sorted(names)[2:] + assert b2["has_more"] is False + + +def test_list_assets_include_exclude_and_name_contains(http: requests.Session, api_base: str, asset_factory): + a = asset_factory("inc_a.safetensors", ["models", "checkpoints", "unit-tests", "alpha"], {}, b"X" * 1024) + b = asset_factory("inc_b.safetensors", ["models", "checkpoints", "unit-tests", "beta"], {}, b"Y" * 1024) + + r = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,alpha", "exclude_tags": "beta", "limit": "50"}, + timeout=120, + ) + body = r.json() + assert r.status_code == 200 + names = [x["name"] for x in body["assets"]] + assert a["name"] in names + assert b["name"] not in names + + r2 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests", "name_contains": "inc_"}, + timeout=120, + ) + body2 = r2.json() + assert r2.status_code == 200 + names2 = [x["name"] for x in body2["assets"]] + assert a["name"] in names2 + assert b["name"] in names2 + + r2 = http.get( + api_base + "/api/assets", + params={"include_tags": "non-existing-tag"}, + timeout=120, + ) + body3 = r2.json() + assert r2.status_code == 200 + assert not body3["assets"] + + +def test_list_assets_sort_by_size_both_orders(http, api_base, asset_factory, make_asset_bytes): + t = ["models", "checkpoints", "unit-tests", "lf-size"] + n1, n2, n3 = "sz1.safetensors", "sz2.safetensors", "sz3.safetensors" + asset_factory(n1, t, {}, make_asset_bytes(n1, 1024)) + asset_factory(n2, t, {}, make_asset_bytes(n2, 2048)) + asset_factory(n3, t, {}, make_asset_bytes(n3, 3072)) + + r1 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,lf-size", "sort": "size", "order": "asc"}, + timeout=120, + ) + b1 = r1.json() + names = [a["name"] for a in b1["assets"]] + assert names[:3] == [n1, n2, n3] + + r2 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,lf-size", "sort": "size", "order": "desc"}, + timeout=120, + ) + b2 = r2.json() + names2 = [a["name"] for a in b2["assets"]] + assert names2[:3] == [n3, n2, n1] + + + +def test_list_assets_sort_by_updated_at_desc(http, api_base, asset_factory, make_asset_bytes): + t = ["models", "checkpoints", "unit-tests", "lf-upd"] + a1 = asset_factory("upd_a.safetensors", t, {}, make_asset_bytes("upd_a", 1200)) + a2 = asset_factory("upd_b.safetensors", t, {}, make_asset_bytes("upd_b", 1200)) + + # Rename the second asset to bump updated_at + rp = http.put(f"{api_base}/api/assets/{a2['id']}", json={"name": "upd_b_renamed.safetensors"}, timeout=120) + upd = rp.json() + assert rp.status_code == 200, upd + + r = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,lf-upd", "sort": "updated_at", "order": "desc"}, + timeout=120, + ) + body = r.json() + assert r.status_code == 200 + names = [x["name"] for x in body["assets"]] + assert names[0] == "upd_b_renamed.safetensors" + assert a1["name"] in names + + + +def test_list_assets_sort_by_last_access_time_desc(http, api_base, asset_factory, make_asset_bytes): + t = ["models", "checkpoints", "unit-tests", "lf-access"] + asset_factory("acc_a.safetensors", t, {}, make_asset_bytes("acc_a", 1100)) + time.sleep(0.02) + a2 = asset_factory("acc_b.safetensors", t, {}, make_asset_bytes("acc_b", 1100)) + + # Touch last_access_time of b by downloading its content + time.sleep(0.02) + dl = http.get(f"{api_base}/api/assets/{a2['id']}/content", timeout=120) + assert dl.status_code == 200 + dl.content + + r = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,lf-access", "sort": "last_access_time", "order": "desc"}, + timeout=120, + ) + body = r.json() + assert r.status_code == 200 + names = [x["name"] for x in body["assets"]] + assert names[0] == a2["name"] + + +def test_list_assets_include_tags_variants_and_case(http, api_base, asset_factory, make_asset_bytes): + t = ["models", "checkpoints", "unit-tests", "lf-include"] + a = asset_factory("incvar_alpha.safetensors", [*t, "alpha"], {}, make_asset_bytes("iva")) + asset_factory("incvar_beta.safetensors", [*t, "beta"], {}, make_asset_bytes("ivb")) + + # CSV + case-insensitive + r1 = http.get( + api_base + "/api/assets", + params={"include_tags": "UNIT-TESTS,LF-INCLUDE,alpha"}, + timeout=120, + ) + b1 = r1.json() + assert r1.status_code == 200 + names1 = [x["name"] for x in b1["assets"]] + assert a["name"] in names1 + assert not any("beta" in x for x in names1) + + # Repeated query params for include_tags + params_multi = [ + ("include_tags", "unit-tests"), + ("include_tags", "lf-include"), + ("include_tags", "alpha"), + ] + r2 = http.get(api_base + "/api/assets", params=params_multi, timeout=120) + b2 = r2.json() + assert r2.status_code == 200 + names2 = [x["name"] for x in b2["assets"]] + assert a["name"] in names2 + assert not any("beta" in x for x in names2) + + # Duplicates and spaces in CSV + r3 = http.get( + api_base + "/api/assets", + params={"include_tags": " unit-tests , lf-include , alpha , alpha "}, + timeout=120, + ) + b3 = r3.json() + assert r3.status_code == 200 + names3 = [x["name"] for x in b3["assets"]] + assert a["name"] in names3 + + +def test_list_assets_exclude_tags_dedup_and_case(http, api_base, asset_factory, make_asset_bytes): + t = ["models", "checkpoints", "unit-tests", "lf-exclude"] + a = asset_factory("ex_a_alpha.safetensors", [*t, "alpha"], {}, make_asset_bytes("exa", 900)) + asset_factory("ex_b_beta.safetensors", [*t, "beta"], {}, make_asset_bytes("exb", 900)) + + # Exclude uppercase should work + r1 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,lf-exclude", "exclude_tags": "BETA"}, + timeout=120, + ) + b1 = r1.json() + assert r1.status_code == 200 + names1 = [x["name"] for x in b1["assets"]] + assert a["name"] in names1 + # Repeated excludes with duplicates + params_multi = [ + ("include_tags", "unit-tests"), + ("include_tags", "lf-exclude"), + ("exclude_tags", "beta"), + ("exclude_tags", "beta"), + ] + r2 = http.get(api_base + "/api/assets", params=params_multi, timeout=120) + b2 = r2.json() + assert r2.status_code == 200 + names2 = [x["name"] for x in b2["assets"]] + assert all("beta" not in x for x in names2) + + +def test_list_assets_name_contains_case_and_specials(http, api_base, asset_factory, make_asset_bytes): + t = ["models", "checkpoints", "unit-tests", "lf-name"] + a1 = asset_factory("CaseMix.SAFE", t, {}, make_asset_bytes("cm", 800)) + a2 = asset_factory("case-other.safetensors", t, {}, make_asset_bytes("co", 800)) + + r1 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,lf-name", "name_contains": "casemix"}, + timeout=120, + ) + b1 = r1.json() + assert r1.status_code == 200 + names1 = [x["name"] for x in b1["assets"]] + assert a1["name"] in names1 + + r2 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,lf-name", "name_contains": ".SAFE"}, + timeout=120, + ) + b2 = r2.json() + assert r2.status_code == 200 + names2 = [x["name"] for x in b2["assets"]] + assert a1["name"] in names2 + + r3 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,lf-name", "name_contains": "case-"}, + timeout=120, + ) + b3 = r3.json() + assert r3.status_code == 200 + names3 = [x["name"] for x in b3["assets"]] + assert a2["name"] in names3 + + +def test_list_assets_offset_beyond_total_and_limit_boundary(http, api_base, asset_factory, make_asset_bytes): + t = ["models", "checkpoints", "unit-tests", "lf-pagelimits"] + asset_factory("pl1.safetensors", t, {}, make_asset_bytes("pl1", 600)) + asset_factory("pl2.safetensors", t, {}, make_asset_bytes("pl2", 600)) + asset_factory("pl3.safetensors", t, {}, make_asset_bytes("pl3", 600)) + + # Offset far beyond total + r1 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,lf-pagelimits", "limit": "2", "offset": "10"}, + timeout=120, + ) + b1 = r1.json() + assert r1.status_code == 200 + assert not b1["assets"] + assert b1["has_more"] is False + + # Boundary large limit (<=500 is valid) + r2 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,lf-pagelimits", "limit": "500"}, + timeout=120, + ) + b2 = r2.json() + assert r2.status_code == 200 + assert len(b2["assets"]) == 3 + assert b2["has_more"] is False + + +def test_list_assets_offset_negative_and_limit_nonint_rejected(http, api_base): + r1 = http.get(api_base + "/api/assets", params={"offset": "-1"}, timeout=120) + b1 = r1.json() + assert r1.status_code == 400 + assert b1["error"]["code"] == "INVALID_QUERY" + + r2 = http.get(api_base + "/api/assets", params={"limit": "abc"}, timeout=120) + b2 = r2.json() + assert r2.status_code == 400 + assert b2["error"]["code"] == "INVALID_QUERY" + + +def test_list_assets_invalid_query_rejected(http: requests.Session, api_base: str): + # limit too small + r1 = http.get(api_base + "/api/assets", params={"limit": "0"}, timeout=120) + b1 = r1.json() + assert r1.status_code == 400 + assert b1["error"]["code"] == "INVALID_QUERY" + + # bad metadata JSON + r2 = http.get(api_base + "/api/assets", params={"metadata_filter": "{not json"}, timeout=120) + b2 = r2.json() + assert r2.status_code == 400 + assert b2["error"]["code"] == "INVALID_QUERY" + + +def test_list_assets_name_contains_literal_underscore( + http, + api_base, + asset_factory, + make_asset_bytes, +): + """'name_contains' must treat '_' literally, not as a SQL wildcard. + We create: + - foo_bar.safetensors (should match) + - fooxbar.safetensors (must NOT match if '_' is escaped) + - foobar.safetensors (must NOT match) + """ + scope = f"lf-underscore-{uuid.uuid4().hex[:6]}" + tags = ["models", "checkpoints", "unit-tests", scope] + + a = asset_factory("foo_bar.safetensors", tags, {}, make_asset_bytes("a", 700)) + b = asset_factory("fooxbar.safetensors", tags, {}, make_asset_bytes("b", 700)) + c = asset_factory("foobar.safetensors", tags, {}, make_asset_bytes("c", 700)) + + r = http.get( + api_base + "/api/assets", + params={"include_tags": f"unit-tests,{scope}", "name_contains": "foo_bar"}, + timeout=120, + ) + body = r.json() + assert r.status_code == 200, body + names = [x["name"] for x in body["assets"]] + assert a["name"] in names, f"Expected literal underscore match to include {a['name']}" + assert b["name"] not in names, "Underscore must be escaped — should not match 'fooxbar'" + assert c["name"] not in names, "Underscore must be escaped — should not match 'foobar'" + assert body["total"] == 1 diff --git a/tests-unit/assets_test/test_metadata_filters.py b/tests-unit/assets_test/test_metadata_filters.py new file mode 100644 index 000000000..20285a3b3 --- /dev/null +++ b/tests-unit/assets_test/test_metadata_filters.py @@ -0,0 +1,395 @@ +import json + + +def test_meta_and_across_keys_and_types( + http, api_base: str, asset_factory, make_asset_bytes +): + name = "mf_and_mix.safetensors" + tags = ["models", "checkpoints", "unit-tests", "mf-and"] + meta = {"purpose": "mix", "epoch": 1, "active": True, "score": 1.23} + asset_factory(name, tags, meta, make_asset_bytes(name, 4096)) + + # All keys must match (AND semantics) + f_ok = {"purpose": "mix", "epoch": 1, "active": True, "score": 1.23} + r1 = http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,mf-and", + "metadata_filter": json.dumps(f_ok), + }, + timeout=120, + ) + b1 = r1.json() + assert r1.status_code == 200 + names = [a["name"] for a in b1["assets"]] + assert name in names + + # One key mismatched -> no result + f_bad = {"purpose": "mix", "epoch": 2, "active": True} + r2 = http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,mf-and", + "metadata_filter": json.dumps(f_bad), + }, + timeout=120, + ) + b2 = r2.json() + assert r2.status_code == 200 + assert not b2["assets"] + + +def test_meta_type_strictness_int_vs_str_and_bool(http, api_base, asset_factory, make_asset_bytes): + name = "mf_types.safetensors" + tags = ["models", "checkpoints", "unit-tests", "mf-types"] + meta = {"epoch": 1, "active": True} + asset_factory(name, tags, meta, make_asset_bytes(name)) + + # int filter matches numeric + r1 = http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,mf-types", + "metadata_filter": json.dumps({"epoch": 1}), + }, + timeout=120, + ) + b1 = r1.json() + assert r1.status_code == 200 and any(a["name"] == name for a in b1["assets"]) + + # string "1" must NOT match numeric 1 + r2 = http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,mf-types", + "metadata_filter": json.dumps({"epoch": "1"}), + }, + timeout=120, + ) + b2 = r2.json() + assert r2.status_code == 200 and not b2["assets"] + + # bool True matches, string "true" must NOT match + r3 = http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,mf-types", + "metadata_filter": json.dumps({"active": True}), + }, + timeout=120, + ) + b3 = r3.json() + assert r3.status_code == 200 and any(a["name"] == name for a in b3["assets"]) + + r4 = http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,mf-types", + "metadata_filter": json.dumps({"active": "true"}), + }, + timeout=120, + ) + b4 = r4.json() + assert r4.status_code == 200 and not b4["assets"] + + +def test_meta_any_of_list_of_scalars(http, api_base, asset_factory, make_asset_bytes): + name = "mf_list_scalars.safetensors" + tags = ["models", "checkpoints", "unit-tests", "mf-list"] + meta = {"flags": ["red", "green"]} + asset_factory(name, tags, meta, make_asset_bytes(name, 3000)) + + # Any-of should match because "green" is present + filt_ok = {"flags": ["blue", "green"]} + r1 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-list", "metadata_filter": json.dumps(filt_ok)}, + timeout=120, + ) + b1 = r1.json() + assert r1.status_code == 200 and any(a["name"] == name for a in b1["assets"]) + + # None of provided flags present -> no match + filt_miss = {"flags": ["blue", "yellow"]} + r2 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-list", "metadata_filter": json.dumps(filt_miss)}, + timeout=120, + ) + b2 = r2.json() + assert r2.status_code == 200 and not b2["assets"] + + # Duplicates in list should not break matching + filt_dup = {"flags": ["green", "green", "green"]} + r3 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-list", "metadata_filter": json.dumps(filt_dup)}, + timeout=120, + ) + b3 = r3.json() + assert r3.status_code == 200 and any(a["name"] == name for a in b3["assets"]) + + +def test_meta_none_semantics_missing_or_null_and_any_of_with_none( + http, api_base, asset_factory, make_asset_bytes +): + # a1: key missing; a2: explicit null; a3: concrete value + t = ["models", "checkpoints", "unit-tests", "mf-none"] + a1 = asset_factory("mf_none_missing.safetensors", t, {"x": 1}, make_asset_bytes("a1")) + a2 = asset_factory("mf_none_null.safetensors", t, {"maybe": None}, make_asset_bytes("a2")) + a3 = asset_factory("mf_none_value.safetensors", t, {"maybe": "x"}, make_asset_bytes("a3")) + + # Filter {maybe: None} must match a1 and a2, not a3 + filt = {"maybe": None} + r1 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-none", "metadata_filter": json.dumps(filt), "sort": "name"}, + timeout=120, + ) + b1 = r1.json() + assert r1.status_code == 200 + got = [a["name"] for a in b1["assets"]] + assert a1["name"] in got and a2["name"] in got and a3["name"] not in got + + # Any-of with None should include missing/null plus value matches + filt_any = {"maybe": [None, "x"]} + r2 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-none", "metadata_filter": json.dumps(filt_any), "sort": "name"}, + timeout=120, + ) + b2 = r2.json() + assert r2.status_code == 200 + got2 = [a["name"] for a in b2["assets"]] + assert a1["name"] in got2 and a2["name"] in got2 and a3["name"] in got2 + + +def test_meta_nested_json_object_equality(http, api_base, asset_factory, make_asset_bytes): + name = "mf_nested_json.safetensors" + tags = ["models", "checkpoints", "unit-tests", "mf-nested"] + cfg = {"optimizer": "adam", "lr": 0.001, "schedule": {"type": "cosine", "warmup": 100}} + asset_factory(name, tags, {"config": cfg}, make_asset_bytes(name, 2200)) + + # Exact JSON object equality (same structure) + r1 = http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,mf-nested", + "metadata_filter": json.dumps({"config": cfg}), + }, + timeout=120, + ) + b1 = r1.json() + assert r1.status_code == 200 and any(a["name"] == name for a in b1["assets"]) + + # Different JSON object should not match + r2 = http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,mf-nested", + "metadata_filter": json.dumps({"config": {"optimizer": "sgd"}}), + }, + timeout=120, + ) + b2 = r2.json() + assert r2.status_code == 200 and not b2["assets"] + + +def test_meta_list_of_objects_any_of(http, api_base, asset_factory, make_asset_bytes): + name = "mf_list_objects.safetensors" + tags = ["models", "checkpoints", "unit-tests", "mf-objlist"] + transforms = [{"type": "crop", "size": 128}, {"type": "flip", "p": 0.5}] + asset_factory(name, tags, {"transforms": transforms}, make_asset_bytes(name, 2048)) + + # Any-of for list of objects should match when one element equals the filter object + r1 = http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,mf-objlist", + "metadata_filter": json.dumps({"transforms": {"type": "flip", "p": 0.5}}), + }, + timeout=120, + ) + b1 = r1.json() + assert r1.status_code == 200 and any(a["name"] == name for a in b1["assets"]) + + # Non-matching object -> no match + r2 = http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,mf-objlist", + "metadata_filter": json.dumps({"transforms": {"type": "rotate", "deg": 90}}), + }, + timeout=120, + ) + b2 = r2.json() + assert r2.status_code == 200 and not b2["assets"] + + +def test_meta_with_special_and_unicode_keys(http, api_base, asset_factory, make_asset_bytes): + name = "mf_keys_unicode.safetensors" + tags = ["models", "checkpoints", "unit-tests", "mf-keys"] + meta = { + "weird.key": "v1", + "path/like": 7, + "with:colon": True, + "ключ": "значение", + "emoji": "🐍", + } + asset_factory(name, tags, meta, make_asset_bytes(name, 1500)) + + # Match all the special keys + filt = {"weird.key": "v1", "path/like": 7, "with:colon": True, "emoji": "🐍"} + r1 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-keys", "metadata_filter": json.dumps(filt)}, + timeout=120, + ) + b1 = r1.json() + assert r1.status_code == 200 and any(a["name"] == name for a in b1["assets"]) + + # Unicode key match + r2 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-keys", "metadata_filter": json.dumps({"ключ": "значение"})}, + timeout=120, + ) + b2 = r2.json() + assert r2.status_code == 200 and any(a["name"] == name for a in b2["assets"]) + + +def test_meta_with_zero_and_boolean_lists(http, api_base, asset_factory, make_asset_bytes): + t = ["models", "checkpoints", "unit-tests", "mf-zero-bool"] + a0 = asset_factory("mf_zero_count.safetensors", t, {"count": 0}, make_asset_bytes("z", 1025)) + a1 = asset_factory("mf_bool_list.safetensors", t, {"choices": [True, False]}, make_asset_bytes("b", 1026)) + + # count == 0 must match only a0 + r1 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-zero-bool", "metadata_filter": json.dumps({"count": 0})}, + timeout=120, + ) + b1 = r1.json() + assert r1.status_code == 200 + names1 = [a["name"] for a in b1["assets"]] + assert a0["name"] in names1 and a1["name"] not in names1 + + # Any-of list of booleans: True matches second asset + r2 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-zero-bool", "metadata_filter": json.dumps({"choices": True})}, + timeout=120, + ) + b2 = r2.json() + assert r2.status_code == 200 and any(a["name"] == a1["name"] for a in b2["assets"]) + + +def test_meta_mixed_list_types_and_strictness(http, api_base, asset_factory, make_asset_bytes): + name = "mf_mixed_list.safetensors" + tags = ["models", "checkpoints", "unit-tests", "mf-mixed"] + meta = {"mix": ["1", 1, True, None]} + asset_factory(name, tags, meta, make_asset_bytes(name, 1999)) + + # Should match because 1 is present + r1 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-mixed", "metadata_filter": json.dumps({"mix": [2, 1]})}, + timeout=120, + ) + b1 = r1.json() + assert r1.status_code == 200 and any(a["name"] == name for a in b1["assets"]) + + # Should NOT match for False + r2 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-mixed", "metadata_filter": json.dumps({"mix": False})}, + timeout=120, + ) + b2 = r2.json() + assert r2.status_code == 200 and not b2["assets"] + + +def test_meta_unknown_key_and_none_behavior_with_scope_tags(http, api_base, asset_factory, make_asset_bytes): + # Use a unique scope tag to avoid interference + t = ["models", "checkpoints", "unit-tests", "mf-unknown-scope"] + x = asset_factory("mf_unknown_a.safetensors", t, {"k1": 1}, make_asset_bytes("ua")) + y = asset_factory("mf_unknown_b.safetensors", t, {"k2": 2}, make_asset_bytes("ub")) + + # Filtering by unknown key with None should return both (missing key OR null) + r1 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-unknown-scope", "metadata_filter": json.dumps({"unknown": None})}, + timeout=120, + ) + b1 = r1.json() + assert r1.status_code == 200 + names = {a["name"] for a in b1["assets"]} + assert x["name"] in names and y["name"] in names + + # Filtering by unknown key with concrete value should return none + r2 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-unknown-scope", "metadata_filter": json.dumps({"unknown": "x"})}, + timeout=120, + ) + b2 = r2.json() + assert r2.status_code == 200 and not b2["assets"] + + +def test_meta_with_tags_include_exclude_and_name_contains(http, api_base, asset_factory, make_asset_bytes): + # alpha matches epoch=1; beta has epoch=2 + a = asset_factory( + "mf_tag_alpha.safetensors", + ["models", "checkpoints", "unit-tests", "mf-tag", "alpha"], + {"epoch": 1}, + make_asset_bytes("alpha"), + ) + b = asset_factory( + "mf_tag_beta.safetensors", + ["models", "checkpoints", "unit-tests", "mf-tag", "beta"], + {"epoch": 2}, + make_asset_bytes("beta"), + ) + + params = { + "include_tags": "unit-tests,mf-tag,alpha", + "exclude_tags": "beta", + "name_contains": "mf_tag_", + "metadata_filter": json.dumps({"epoch": 1}), + } + r = http.get(api_base + "/api/assets", params=params, timeout=120) + body = r.json() + assert r.status_code == 200 + names = [x["name"] for x in body["assets"]] + assert a["name"] in names + assert b["name"] not in names + + +def test_meta_sort_and_paging_under_filter(http, api_base, asset_factory, make_asset_bytes): + # Three assets in same scope with different sizes and a common filter key + t = ["models", "checkpoints", "unit-tests", "mf-sort"] + n1, n2, n3 = "mf_sort_1.safetensors", "mf_sort_2.safetensors", "mf_sort_3.safetensors" + asset_factory(n1, t, {"group": "g"}, make_asset_bytes(n1, 1024)) + asset_factory(n2, t, {"group": "g"}, make_asset_bytes(n2, 2048)) + asset_factory(n3, t, {"group": "g"}, make_asset_bytes(n3, 3072)) + + # Sort by size ascending with paging + q = { + "include_tags": "unit-tests,mf-sort", + "metadata_filter": json.dumps({"group": "g"}), + "sort": "size", "order": "asc", "limit": "2", + } + r1 = http.get(api_base + "/api/assets", params=q, timeout=120) + b1 = r1.json() + assert r1.status_code == 200 + got1 = [a["name"] for a in b1["assets"]] + assert got1 == [n1, n2] + assert b1["has_more"] is True + + q2 = {**q, "offset": "2"} + r2 = http.get(api_base + "/api/assets", params=q2, timeout=120) + b2 = r2.json() + assert r2.status_code == 200 + got2 = [a["name"] for a in b2["assets"]] + assert got2 == [n3] + assert b2["has_more"] is False diff --git a/tests-unit/assets_test/test_prune_orphaned_assets.py b/tests-unit/assets_test/test_prune_orphaned_assets.py new file mode 100644 index 000000000..f602e5a77 --- /dev/null +++ b/tests-unit/assets_test/test_prune_orphaned_assets.py @@ -0,0 +1,141 @@ +import uuid +from pathlib import Path + +import pytest +import requests +from conftest import get_asset_filename, trigger_sync_seed_assets + + +@pytest.fixture +def create_seed_file(comfy_tmp_base_dir: Path): + """Create a file on disk that will become a seed asset after sync.""" + created: list[Path] = [] + + def _create(root: str, scope: str, name: str | None = None, data: bytes = b"TEST") -> Path: + name = name or f"seed_{uuid.uuid4().hex[:8]}.bin" + path = comfy_tmp_base_dir / root / "unit-tests" / scope / name + path.parent.mkdir(parents=True, exist_ok=True) + path.write_bytes(data) + created.append(path) + return path + + yield _create + + for p in created: + p.unlink(missing_ok=True) + + +@pytest.fixture +def find_asset(http: requests.Session, api_base: str): + """Query API for assets matching scope and optional name.""" + def _find(scope: str, name: str | None = None) -> list[dict]: + params = {"include_tags": f"unit-tests,{scope}"} + if name: + params["name_contains"] = name + r = http.get(f"{api_base}/api/assets", params=params, timeout=120) + assert r.status_code == 200 + assets = r.json().get("assets", []) + if name: + return [a for a in assets if a.get("name") == name] + return assets + + return _find + + +@pytest.mark.parametrize("root", ["input", "output"]) +def test_orphaned_seed_asset_is_pruned( + root: str, + create_seed_file, + find_asset, + http: requests.Session, + api_base: str, +): + """Seed asset with deleted file is removed; with file present, it survives.""" + scope = f"prune-{uuid.uuid4().hex[:6]}" + fp = create_seed_file(root, scope) + name = fp.name + + trigger_sync_seed_assets(http, api_base) + assert find_asset(scope, name), "Seed asset should exist" + + fp.unlink() + trigger_sync_seed_assets(http, api_base) + assert not find_asset(scope, name), "Orphaned seed should be pruned" + + +def test_seed_asset_with_file_survives_prune( + create_seed_file, + find_asset, + http: requests.Session, + api_base: str, +): + """Seed asset with file still on disk is NOT pruned.""" + scope = f"keep-{uuid.uuid4().hex[:6]}" + fp = create_seed_file("input", scope) + + trigger_sync_seed_assets(http, api_base) + trigger_sync_seed_assets(http, api_base) + + assert find_asset(scope, fp.name), "Seed with valid file should survive" + + +def test_hashed_asset_not_pruned_when_file_missing( + http: requests.Session, + api_base: str, + comfy_tmp_base_dir: Path, + asset_factory, + make_asset_bytes, +): + """Hashed assets are never deleted by prune, even without file.""" + scope = f"hashed-{uuid.uuid4().hex[:6]}" + data = make_asset_bytes("test", 2048) + a = asset_factory("test.bin", ["input", "unit-tests", scope], {}, data) + + path = comfy_tmp_base_dir / "input" / "unit-tests" / scope / get_asset_filename(a["asset_hash"], ".bin") + path.unlink() + + trigger_sync_seed_assets(http, api_base) + + r = http.get(f"{api_base}/api/assets/{a['id']}", timeout=120) + assert r.status_code == 200, "Hashed asset should NOT be pruned" + + +def test_prune_across_multiple_roots( + create_seed_file, + find_asset, + http: requests.Session, + api_base: str, +): + """Prune correctly handles assets across input and output roots.""" + scope = f"multi-{uuid.uuid4().hex[:6]}" + input_fp = create_seed_file("input", scope, "input.bin") + create_seed_file("output", scope, "output.bin") + + trigger_sync_seed_assets(http, api_base) + assert len(find_asset(scope)) == 2 + + input_fp.unlink() + trigger_sync_seed_assets(http, api_base) + + remaining = find_asset(scope) + assert len(remaining) == 1 + assert remaining[0]["name"] == "output.bin" + + +@pytest.mark.parametrize("dirname", ["100%_done", "my_folder_name", "has spaces"]) +def test_special_chars_in_path_escaped_correctly( + dirname: str, + create_seed_file, + find_asset, + http: requests.Session, + api_base: str, + comfy_tmp_base_dir: Path, +): + """SQL LIKE wildcards (%, _) and spaces in paths don't cause false matches.""" + scope = f"special-{uuid.uuid4().hex[:6]}/{dirname}" + fp = create_seed_file("input", scope) + + trigger_sync_seed_assets(http, api_base) + trigger_sync_seed_assets(http, api_base) + + assert find_asset(scope.split("/")[0], fp.name), "Asset with special chars should survive" diff --git a/tests-unit/assets_test/test_tags.py b/tests-unit/assets_test/test_tags.py new file mode 100644 index 000000000..6b1047802 --- /dev/null +++ b/tests-unit/assets_test/test_tags.py @@ -0,0 +1,225 @@ +import json +import uuid + +import requests + + +def test_tags_present(http: requests.Session, api_base: str, seeded_asset: dict): + # Include zero-usage tags by default + r1 = http.get(api_base + "/api/tags", params={"limit": "50"}, timeout=120) + body1 = r1.json() + assert r1.status_code == 200 + names = [t["name"] for t in body1["tags"]] + # A few system tags from migration should exist: + assert "models" in names + assert "checkpoints" in names + + # Only used tags before we add anything new from this test cycle + r2 = http.get(api_base + "/api/tags", params={"include_zero": "false"}, timeout=120) + body2 = r2.json() + assert r2.status_code == 200 + # We already seeded one asset via fixture, so used tags must be non-empty + used_names = [t["name"] for t in body2["tags"]] + assert "models" in used_names + assert "checkpoints" in used_names + + # Prefix filter should refine the list + r3 = http.get(api_base + "/api/tags", params={"include_zero": "false", "prefix": "uni"}, timeout=120) + b3 = r3.json() + assert r3.status_code == 200 + names3 = [t["name"] for t in b3["tags"]] + assert "unit-tests" in names3 + assert "models" not in names3 # filtered out by prefix + + # Order by name ascending should be stable + r4 = http.get(api_base + "/api/tags", params={"include_zero": "false", "order": "name_asc"}, timeout=120) + b4 = r4.json() + assert r4.status_code == 200 + names4 = [t["name"] for t in b4["tags"]] + assert names4 == sorted(names4) + + +def test_tags_empty_usage(http: requests.Session, api_base: str, asset_factory, make_asset_bytes): + # Baseline: system tags exist when include_zero (default) is true + r1 = http.get(api_base + "/api/tags", params={"limit": "500"}, timeout=120) + body1 = r1.json() + assert r1.status_code == 200 + names = [t["name"] for t in body1["tags"]] + assert "models" in names and "checkpoints" in names + + # Create a short-lived asset under input with a unique custom tag + scope = f"tags-empty-usage-{uuid.uuid4().hex[:6]}" + custom_tag = f"temp-{uuid.uuid4().hex[:8]}" + name = "tag_seed.bin" + _asset = asset_factory( + name, + ["input", "unit-tests", scope, custom_tag], + {}, + make_asset_bytes(name, 512), + ) + + # While the asset exists, the custom tag must appear when include_zero=false + r2 = http.get( + api_base + "/api/tags", + params={"include_zero": "false", "prefix": custom_tag, "limit": "50"}, + timeout=120, + ) + body2 = r2.json() + assert r2.status_code == 200 + used_names = [t["name"] for t in body2["tags"]] + assert custom_tag in used_names + + # Delete the asset so the tag usage drops to zero + rd = http.delete(f"{api_base}/api/assets/{_asset['id']}", timeout=120) + assert rd.status_code == 204 + + # Now the custom tag must not be returned when include_zero=false + r3 = http.get( + api_base + "/api/tags", + params={"include_zero": "false", "prefix": custom_tag, "limit": "50"}, + timeout=120, + ) + body3 = r3.json() + assert r3.status_code == 200 + names_after = [t["name"] for t in body3["tags"]] + assert custom_tag not in names_after + assert not names_after # filtered view should be empty now + + +def test_add_and_remove_tags(http: requests.Session, api_base: str, seeded_asset: dict): + aid = seeded_asset["id"] + + # Add tags with duplicates and mixed case + payload_add = {"tags": ["NewTag", "unit-tests", "newtag", "BETA"]} + r1 = http.post(f"{api_base}/api/assets/{aid}/tags", json=payload_add, timeout=120) + b1 = r1.json() + assert r1.status_code == 200, b1 + # normalized, deduplicated; 'unit-tests' was already present from the seed + assert set(b1["added"]) == {"newtag", "beta"} + assert set(b1["already_present"]) == {"unit-tests"} + assert "newtag" in b1["total_tags"] and "beta" in b1["total_tags"] + + rg = http.get(f"{api_base}/api/assets/{aid}", timeout=120) + g = rg.json() + assert rg.status_code == 200 + tags_now = set(g["tags"]) + assert {"newtag", "beta"}.issubset(tags_now) + + # Remove a tag and a non-existent tag + payload_del = {"tags": ["newtag", "does-not-exist"]} + r2 = http.delete(f"{api_base}/api/assets/{aid}/tags", json=payload_del, timeout=120) + b2 = r2.json() + assert r2.status_code == 200 + assert set(b2["removed"]) == {"newtag"} + assert set(b2["not_present"]) == {"does-not-exist"} + + # Verify remaining tags after deletion + rg2 = http.get(f"{api_base}/api/assets/{aid}", timeout=120) + g2 = rg2.json() + assert rg2.status_code == 200 + tags_later = set(g2["tags"]) + assert "newtag" not in tags_later + assert "beta" in tags_later # still present + + +def test_tags_list_order_and_prefix(http: requests.Session, api_base: str, seeded_asset: dict): + aid = seeded_asset["id"] + h = seeded_asset["asset_hash"] + + # Add both tags to the seeded asset (usage: orderaaa=1, orderbbb=1) + r_add = http.post(f"{api_base}/api/assets/{aid}/tags", json={"tags": ["orderaaa", "orderbbb"]}, timeout=120) + add_body = r_add.json() + assert r_add.status_code == 200, add_body + + # Create another AssetInfo from the same content but tagged ONLY with 'orderbbb'. + payload = { + "hash": h, + "name": "order_only_bbb.safetensors", + "tags": ["input", "unit-tests", "orderbbb"], + "user_metadata": {}, + } + r_copy = http.post(f"{api_base}/api/assets/from-hash", json=payload, timeout=120) + copy_body = r_copy.json() + assert r_copy.status_code == 201, copy_body + + # 1) Default order (count_desc): 'orderbbb' should come before 'orderaaa' + # because it has higher usage (2 vs 1). + r1 = http.get(api_base + "/api/tags", params={"prefix": "order", "include_zero": "false"}, timeout=120) + b1 = r1.json() + assert r1.status_code == 200, b1 + names1 = [t["name"] for t in b1["tags"]] + counts1 = {t["name"]: t["count"] for t in b1["tags"]} + # Both must be present within the prefix subset + assert "orderaaa" in names1 and "orderbbb" in names1 + # Usage of 'orderbbb' must be >= 'orderaaa'; in our setup it's 2 vs 1 + assert counts1["orderbbb"] >= counts1["orderaaa"] + # And with count_desc, 'orderbbb' appears earlier than 'orderaaa' + assert names1.index("orderbbb") < names1.index("orderaaa") + + # 2) name_asc: lexical order should flip the relative order + r2 = http.get( + api_base + "/api/tags", + params={"prefix": "order", "include_zero": "false", "order": "name_asc"}, + timeout=120, + ) + b2 = r2.json() + assert r2.status_code == 200, b2 + names2 = [t["name"] for t in b2["tags"]] + assert "orderaaa" in names2 and "orderbbb" in names2 + assert names2.index("orderaaa") < names2.index("orderbbb") + + # 3) invalid limit rejected (existing negative case retained) + r3 = http.get(api_base + "/api/tags", params={"limit": "1001"}, timeout=120) + b3 = r3.json() + assert r3.status_code == 400 + assert b3["error"]["code"] == "INVALID_QUERY" + + +def test_tags_endpoints_invalid_bodies(http: requests.Session, api_base: str, seeded_asset: dict): + aid = seeded_asset["id"] + + # Add with empty list + r1 = http.post(f"{api_base}/api/assets/{aid}/tags", json={"tags": []}, timeout=120) + b1 = r1.json() + assert r1.status_code == 400 + assert b1["error"]["code"] == "INVALID_BODY" + + # Remove with wrong type + r2 = http.delete(f"{api_base}/api/assets/{aid}/tags", json={"tags": [123]}, timeout=120) + b2 = r2.json() + assert r2.status_code == 400 + assert b2["error"]["code"] == "INVALID_BODY" + + # metadata_filter provided as JSON array should be rejected (must be object) + r3 = http.get( + api_base + "/api/assets", + params={"metadata_filter": json.dumps([{"x": 1}])}, + timeout=120, + ) + b3 = r3.json() + assert r3.status_code == 400 + assert b3["error"]["code"] == "INVALID_QUERY" + + +def test_tags_prefix_treats_underscore_literal( + http, + api_base, + asset_factory, + make_asset_bytes, +): + """'prefix' for /api/tags must treat '_' literally, not as a wildcard.""" + base = f"pref_{uuid.uuid4().hex[:6]}" + tag_ok = f"{base}_ok" # should match prefix=f"{base}_" + tag_bad = f"{base}xok" # must NOT match if '_' is escaped + scope = f"tags-underscore-{uuid.uuid4().hex[:6]}" + + asset_factory("t1.bin", ["input", "unit-tests", scope, tag_ok], {}, make_asset_bytes("t1", 512)) + asset_factory("t2.bin", ["input", "unit-tests", scope, tag_bad], {}, make_asset_bytes("t2", 512)) + + r = http.get(api_base + "/api/tags", params={"include_zero": "false", "prefix": f"{base}_"}, timeout=120) + body = r.json() + assert r.status_code == 200, body + names = [t["name"] for t in body["tags"]] + assert tag_ok in names, f"Expected {tag_ok} to be returned for prefix '{base}_'" + assert tag_bad not in names, f"'{tag_bad}' must not match — '_' is not a wildcard" + assert body["total"] == 1 diff --git a/tests-unit/assets_test/test_uploads.py b/tests-unit/assets_test/test_uploads.py new file mode 100644 index 000000000..137d7391a --- /dev/null +++ b/tests-unit/assets_test/test_uploads.py @@ -0,0 +1,281 @@ +import json +import uuid +from concurrent.futures import ThreadPoolExecutor + +import requests +import pytest + + +def test_upload_ok_duplicate_reference(http: requests.Session, api_base: str, make_asset_bytes): + name = "dup_a.safetensors" + tags = ["models", "checkpoints", "unit-tests", "alpha"] + meta = {"purpose": "dup"} + data = make_asset_bytes(name) + files = {"file": (name, data, "application/octet-stream")} + form = {"tags": json.dumps(tags), "name": name, "user_metadata": json.dumps(meta)} + r1 = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) + a1 = r1.json() + assert r1.status_code == 201, a1 + assert a1["created_new"] is True + + # Second upload with the same data and name should return created_new == False and the same asset + files = {"file": (name, data, "application/octet-stream")} + form = {"tags": json.dumps(tags), "name": name, "user_metadata": json.dumps(meta)} + r2 = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) + a2 = r2.json() + assert r2.status_code == 200, a2 + assert a2["created_new"] is False + assert a2["asset_hash"] == a1["asset_hash"] + assert a2["id"] == a1["id"] # old reference + + # Third upload with the same data but new name should return created_new == False and the new AssetReference + files = {"file": (name, data, "application/octet-stream")} + form = {"tags": json.dumps(tags), "name": name + "_d", "user_metadata": json.dumps(meta)} + r2 = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) + a3 = r2.json() + assert r2.status_code == 200, a3 + assert a3["created_new"] is False + assert a3["asset_hash"] == a1["asset_hash"] + assert a3["id"] != a1["id"] # old reference + + +def test_upload_fastpath_from_existing_hash_no_file(http: requests.Session, api_base: str): + # Seed a small file first + name = "fastpath_seed.safetensors" + tags = ["models", "checkpoints", "unit-tests"] + meta = {} + files = {"file": (name, b"B" * 1024, "application/octet-stream")} + form = {"tags": json.dumps(tags), "name": name, "user_metadata": json.dumps(meta)} + r1 = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) + b1 = r1.json() + assert r1.status_code == 201, b1 + h = b1["asset_hash"] + + # Now POST /api/assets with only hash and no file + files = [ + ("hash", (None, h)), + ("tags", (None, json.dumps(tags))), + ("name", (None, "fastpath_copy.safetensors")), + ("user_metadata", (None, json.dumps({"purpose": "copy"}))), + ] + r2 = http.post(api_base + "/api/assets", files=files, timeout=120) + b2 = r2.json() + assert r2.status_code == 200, b2 # fast path returns 200 with created_new == False + assert b2["created_new"] is False + assert b2["asset_hash"] == h + + +def test_upload_fastpath_with_known_hash_and_file( + http: requests.Session, api_base: str +): + # Seed + files = {"file": ("seed.safetensors", b"C" * 128, "application/octet-stream")} + form = {"tags": json.dumps(["models", "checkpoints", "unit-tests", "fp"]), "name": "seed.safetensors", "user_metadata": json.dumps({})} + r1 = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) + b1 = r1.json() + assert r1.status_code == 201, b1 + h = b1["asset_hash"] + + # Send both file and hash of existing content -> server must drain file and create from hash (200) + files = {"file": ("ignored.bin", b"ignored" * 10, "application/octet-stream")} + form = {"hash": h, "tags": json.dumps(["models", "checkpoints", "unit-tests", "fp"]), "name": "copy_from_hash.safetensors", "user_metadata": json.dumps({})} + r2 = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) + b2 = r2.json() + assert r2.status_code == 200, b2 + assert b2["created_new"] is False + assert b2["asset_hash"] == h + + +def test_upload_multiple_tags_fields_are_merged(http: requests.Session, api_base: str): + data = [ + ("tags", "models,checkpoints"), + ("tags", json.dumps(["unit-tests", "alpha"])), + ("name", "merge.safetensors"), + ("user_metadata", json.dumps({"u": 1})), + ] + files = {"file": ("merge.safetensors", b"B" * 256, "application/octet-stream")} + r1 = http.post(api_base + "/api/assets", data=data, files=files, timeout=120) + created = r1.json() + assert r1.status_code in (200, 201), created + aid = created["id"] + + # Verify all tags are present on the resource + rg = http.get(f"{api_base}/api/assets/{aid}", timeout=120) + detail = rg.json() + assert rg.status_code == 200, detail + tags = set(detail["tags"]) + assert {"models", "checkpoints", "unit-tests", "alpha"}.issubset(tags) + + +@pytest.mark.parametrize("root", ["input", "output"]) +def test_concurrent_upload_identical_bytes_different_names( + root: str, + http: requests.Session, + api_base: str, + make_asset_bytes, +): + """ + Two concurrent uploads of identical bytes but different names. + Expect a single Asset (same hash), two AssetInfo rows, and exactly one created_new=True. + """ + scope = f"concupload-{uuid.uuid4().hex[:6]}" + name1, name2 = "cu_a.bin", "cu_b.bin" + data = make_asset_bytes("concurrent", 4096) + tags = [root, "unit-tests", scope] + + def _do_upload(args): + url, form_data, files_data = args + with requests.Session() as s: + return s.post(url, data=form_data, files=files_data, timeout=120) + + url = api_base + "/api/assets" + form1 = {"tags": json.dumps(tags), "name": name1, "user_metadata": json.dumps({})} + files1 = {"file": (name1, data, "application/octet-stream")} + form2 = {"tags": json.dumps(tags), "name": name2, "user_metadata": json.dumps({})} + files2 = {"file": (name2, data, "application/octet-stream")} + + with ThreadPoolExecutor(max_workers=2) as executor: + futures = list(executor.map(_do_upload, [(url, form1, files1), (url, form2, files2)])) + r1, r2 = futures + + b1, b2 = r1.json(), r2.json() + assert r1.status_code in (200, 201), b1 + assert r2.status_code in (200, 201), b2 + assert b1["asset_hash"] == b2["asset_hash"] + assert b1["id"] != b2["id"] + + created_flags = sorted([bool(b1.get("created_new")), bool(b2.get("created_new"))]) + assert created_flags == [False, True] + + rl = http.get( + api_base + "/api/assets", + params={"include_tags": f"unit-tests,{scope}", "sort": "name"}, + timeout=120, + ) + bl = rl.json() + assert rl.status_code == 200, bl + names = [a["name"] for a in bl.get("assets", [])] + assert set([name1, name2]).issubset(names) + + +def test_create_from_hash_endpoint_404(http: requests.Session, api_base: str): + payload = { + "hash": "blake3:" + "0" * 64, + "name": "nonexistent.bin", + "tags": ["models", "checkpoints", "unit-tests"], + } + r = http.post(api_base + "/api/assets/from-hash", json=payload, timeout=120) + body = r.json() + assert r.status_code == 404 + assert body["error"]["code"] == "ASSET_NOT_FOUND" + + +def test_upload_zero_byte_rejected(http: requests.Session, api_base: str): + files = {"file": ("empty.safetensors", b"", "application/octet-stream")} + form = {"tags": json.dumps(["models", "checkpoints", "unit-tests", "edge"]), "name": "empty.safetensors", "user_metadata": json.dumps({})} + r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) + body = r.json() + assert r.status_code == 400 + assert body["error"]["code"] == "EMPTY_UPLOAD" + + +def test_upload_invalid_root_tag_rejected(http: requests.Session, api_base: str): + files = {"file": ("badroot.bin", b"A" * 64, "application/octet-stream")} + form = {"tags": json.dumps(["not-a-root", "whatever"]), "name": "badroot.bin", "user_metadata": json.dumps({})} + r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) + body = r.json() + assert r.status_code == 400 + assert body["error"]["code"] == "INVALID_BODY" + + +def test_upload_user_metadata_must_be_json(http: requests.Session, api_base: str): + files = {"file": ("badmeta.bin", b"A" * 128, "application/octet-stream")} + form = {"tags": json.dumps(["models", "checkpoints", "unit-tests", "edge"]), "name": "badmeta.bin", "user_metadata": "{not json}"} + r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) + body = r.json() + assert r.status_code == 400 + assert body["error"]["code"] == "INVALID_BODY" + + +def test_upload_requires_multipart(http: requests.Session, api_base: str): + r = http.post(api_base + "/api/assets", json={"foo": "bar"}, timeout=120) + body = r.json() + assert r.status_code == 415 + assert body["error"]["code"] == "UNSUPPORTED_MEDIA_TYPE" + + +def test_upload_missing_file_and_hash(http: requests.Session, api_base: str): + files = [ + ("tags", (None, json.dumps(["models", "checkpoints", "unit-tests"]))), + ("name", (None, "x.safetensors")), + ] + r = http.post(api_base + "/api/assets", files=files, timeout=120) + body = r.json() + assert r.status_code == 400 + assert body["error"]["code"] == "MISSING_FILE" + + +def test_upload_models_unknown_category(http: requests.Session, api_base: str): + files = {"file": ("m.safetensors", b"A" * 128, "application/octet-stream")} + form = {"tags": json.dumps(["models", "no_such_category", "unit-tests"]), "name": "m.safetensors"} + r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) + body = r.json() + assert r.status_code == 400 + assert body["error"]["code"] == "INVALID_BODY" + assert body["error"]["message"].startswith("unknown models category") + + +def test_upload_models_requires_category(http: requests.Session, api_base: str): + files = {"file": ("nocat.safetensors", b"A" * 64, "application/octet-stream")} + form = {"tags": json.dumps(["models"]), "name": "nocat.safetensors", "user_metadata": json.dumps({})} + r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) + body = r.json() + assert r.status_code == 400 + assert body["error"]["code"] == "INVALID_BODY" + + +def test_upload_tags_traversal_guard(http: requests.Session, api_base: str): + files = {"file": ("evil.safetensors", b"A" * 256, "application/octet-stream")} + form = {"tags": json.dumps(["models", "checkpoints", "unit-tests", "..", "zzz"]), "name": "evil.safetensors"} + r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) + body = r.json() + assert r.status_code == 400 + assert body["error"]["code"] in ("BAD_REQUEST", "INVALID_BODY") + + +@pytest.mark.parametrize("root", ["input", "output"]) +def test_duplicate_upload_same_display_name_does_not_clobber( + root: str, + http: requests.Session, + api_base: str, + asset_factory, + make_asset_bytes, +): + """ + Two uploads use the same tags and the same display name but different bytes. + With hash-based filenames, they must NOT overwrite each other. Both assets + remain accessible and serve their original content. + """ + scope = f"dup-path-{uuid.uuid4().hex[:6]}" + display_name = "same_display.bin" + + d1 = make_asset_bytes(scope + "-v1", 1536) + d2 = make_asset_bytes(scope + "-v2", 2048) + tags = [root, "unit-tests", scope] + + first = asset_factory(display_name, tags, {}, d1) + second = asset_factory(display_name, tags, {}, d2) + + assert first["id"] != second["id"] + assert first["asset_hash"] != second["asset_hash"] # different content + assert first["name"] == second["name"] == display_name + + # Both must be independently retrievable + r1 = http.get(f"{api_base}/api/assets/{first['id']}/content", timeout=120) + b1 = r1.content + assert r1.status_code == 200 + assert b1 == d1 + r2 = http.get(f"{api_base}/api/assets/{second['id']}/content", timeout=120) + b2 = r2.content + assert r2.status_code == 200 + assert b2 == d2 diff --git a/tests-unit/requirements.txt b/tests-unit/requirements.txt index 3a6790ee0..2355b8000 100644 --- a/tests-unit/requirements.txt +++ b/tests-unit/requirements.txt @@ -2,3 +2,4 @@ pytest>=7.8.0 pytest-aiohttp pytest-asyncio websocket-client +blake3 From aa6f7a83bb313f121ee032a8b21aa72ea32cfee1 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sat, 31 Jan 2026 17:05:11 -0800 Subject: [PATCH 015/215] Send is_input_list on v1 and v3 schema to frontend (#12188) --- comfy_api/latest/_io.py | 2 ++ server.py | 1 + 2 files changed, 3 insertions(+) diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 78f77d4b2..eeea9781a 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -1248,6 +1248,7 @@ class Hidden(str, Enum): class NodeInfoV1: input: dict=None input_order: dict[str, list[str]]=None + is_input_list: bool=None output: list[str]=None output_is_list: list[bool]=None output_name: list[str]=None @@ -1474,6 +1475,7 @@ class Schema: info = NodeInfoV1( input=input, input_order={key: list(value.keys()) for (key, value) in input.items()}, + is_input_list=self.is_input_list, output=output, output_is_list=output_is_list, output_name=output_name, diff --git a/server.py b/server.py index 2aee5cc06..2300393b2 100644 --- a/server.py +++ b/server.py @@ -656,6 +656,7 @@ class PromptServer(): info = {} info['input'] = obj_class.INPUT_TYPES() info['input_order'] = {key: list(value.keys()) for (key, value) in obj_class.INPUT_TYPES().items()} + info['is_input_list'] = getattr(obj_class, "INPUT_IS_LIST", False) info['output'] = obj_class.RETURN_TYPES info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES) info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output'] From 873de5f37a1f4c39c0203b2989e6736b0cb12b98 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 31 Jan 2026 18:11:11 -0800 Subject: [PATCH 016/215] KV cache implementation for using llama models for text generation. (#12195) --- comfy/text_encoders/llama.py | 91 +++++++++++++++++++++++++++++------- 1 file changed, 74 insertions(+), 17 deletions(-) diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index 3080a3e09..68ac1e804 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn from dataclasses import dataclass -from typing import Optional, Any +from typing import Optional, Any, Tuple import math from comfy.ldm.modules.attention import optimized_attention_for_device @@ -32,6 +32,7 @@ class Llama2Config: k_norm = None rope_scale = None final_norm: bool = True + lm_head: bool = False @dataclass class Mistral3Small24BConfig: @@ -54,6 +55,7 @@ class Mistral3Small24BConfig: k_norm = None rope_scale = None final_norm: bool = True + lm_head: bool = False @dataclass class Qwen25_3BConfig: @@ -76,6 +78,7 @@ class Qwen25_3BConfig: k_norm = None rope_scale = None final_norm: bool = True + lm_head: bool = False @dataclass class Qwen3_06BConfig: @@ -98,6 +101,7 @@ class Qwen3_06BConfig: k_norm = "gemma3" rope_scale = None final_norm: bool = True + lm_head: bool = False @dataclass class Qwen3_4BConfig: @@ -120,6 +124,7 @@ class Qwen3_4BConfig: k_norm = "gemma3" rope_scale = None final_norm: bool = True + lm_head: bool = False @dataclass class Qwen3_8BConfig: @@ -142,6 +147,7 @@ class Qwen3_8BConfig: k_norm = "gemma3" rope_scale = None final_norm: bool = True + lm_head: bool = False @dataclass class Ovis25_2BConfig: @@ -164,6 +170,7 @@ class Ovis25_2BConfig: k_norm = "gemma3" rope_scale = None final_norm: bool = True + lm_head: bool = False @dataclass class Qwen25_7BVLI_Config: @@ -186,6 +193,7 @@ class Qwen25_7BVLI_Config: k_norm = None rope_scale = None final_norm: bool = True + lm_head: bool = False @dataclass class Gemma2_2B_Config: @@ -209,6 +217,7 @@ class Gemma2_2B_Config: sliding_attention = None rope_scale = None final_norm: bool = True + lm_head: bool = False @dataclass class Gemma3_4B_Config: @@ -232,6 +241,7 @@ class Gemma3_4B_Config: sliding_attention = [1024, 1024, 1024, 1024, 1024, False] rope_scale = [8.0, 1.0] final_norm: bool = True + lm_head: bool = False @dataclass class Gemma3_12B_Config: @@ -255,6 +265,7 @@ class Gemma3_12B_Config: sliding_attention = [1024, 1024, 1024, 1024, 1024, False] rope_scale = [8.0, 1.0] final_norm: bool = True + lm_head: bool = False vision_config = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "model_type": "siglip_vision_model", "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 14} mm_tokens_per_image = 256 @@ -356,6 +367,7 @@ class Attention(nn.Module): attention_mask: Optional[torch.Tensor] = None, freqs_cis: Optional[torch.Tensor] = None, optimized_attention=None, + past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ): batch_size, seq_length, _ = hidden_states.shape xq = self.q_proj(hidden_states) @@ -373,11 +385,30 @@ class Attention(nn.Module): xq, xk = apply_rope(xq, xk, freqs_cis=freqs_cis) + present_key_value = None + if past_key_value is not None: + index = 0 + num_tokens = xk.shape[2] + if len(past_key_value) > 0: + past_key, past_value, index = past_key_value + if past_key.shape[2] >= (index + num_tokens): + past_key[:, :, index:index + xk.shape[2]] = xk + past_value[:, :, index:index + xv.shape[2]] = xv + xk = past_key[:, :, :index + xk.shape[2]] + xv = past_value[:, :, :index + xv.shape[2]] + present_key_value = (past_key, past_value, index + num_tokens) + else: + xk = torch.cat((past_key[:, :, :index], xk), dim=2) + xv = torch.cat((past_value[:, :, :index], xv), dim=2) + present_key_value = (xk, xv, index + num_tokens) + else: + present_key_value = (xk, xv, index + num_tokens) + xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) output = optimized_attention(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True) - return self.o_proj(output) + return self.o_proj(output), present_key_value class MLP(nn.Module): def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None): @@ -408,15 +439,17 @@ class TransformerBlock(nn.Module): attention_mask: Optional[torch.Tensor] = None, freqs_cis: Optional[torch.Tensor] = None, optimized_attention=None, + past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ): # Self Attention residual = x x = self.input_layernorm(x) - x = self.self_attn( + x, present_key_value = self.self_attn( hidden_states=x, attention_mask=attention_mask, freqs_cis=freqs_cis, optimized_attention=optimized_attention, + past_key_value=past_key_value, ) x = residual + x @@ -426,7 +459,7 @@ class TransformerBlock(nn.Module): x = self.mlp(x) x = residual + x - return x + return x, present_key_value class TransformerBlockGemma2(nn.Module): def __init__(self, config: Llama2Config, index, device=None, dtype=None, ops: Any = None): @@ -451,6 +484,7 @@ class TransformerBlockGemma2(nn.Module): attention_mask: Optional[torch.Tensor] = None, freqs_cis: Optional[torch.Tensor] = None, optimized_attention=None, + past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ): if self.transformer_type == 'gemma3': if self.sliding_attention: @@ -468,11 +502,12 @@ class TransformerBlockGemma2(nn.Module): # Self Attention residual = x x = self.input_layernorm(x) - x = self.self_attn( + x, present_key_value = self.self_attn( hidden_states=x, attention_mask=attention_mask, freqs_cis=freqs_cis, optimized_attention=optimized_attention, + past_key_value=past_key_value, ) x = self.post_attention_layernorm(x) @@ -485,7 +520,7 @@ class TransformerBlockGemma2(nn.Module): x = self.post_feedforward_layernorm(x) x = residual + x - return x + return x, present_key_value class Llama2_(nn.Module): def __init__(self, config, device=None, dtype=None, ops=None): @@ -516,9 +551,10 @@ class Llama2_(nn.Module): else: self.norm = None - # self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype) + if config.lm_head: + self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype) - def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[]): + def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None): if embeds is not None: x = embeds else: @@ -527,8 +563,13 @@ class Llama2_(nn.Module): if self.normalize_in: x *= self.config.hidden_size ** 0.5 + seq_len = x.shape[1] + past_len = 0 + if past_key_values is not None and len(past_key_values) > 0: + past_len = past_key_values[0][2] + if position_ids is None: - position_ids = torch.arange(0, x.shape[1], device=x.device).unsqueeze(0) + position_ids = torch.arange(past_len, past_len + seq_len, device=x.device).unsqueeze(0) freqs_cis = precompute_freqs_cis(self.config.head_dim, position_ids, @@ -539,14 +580,16 @@ class Llama2_(nn.Module): mask = None if attention_mask is not None: - mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]) + mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, seq_len, attention_mask.shape[-1]) mask = mask.masked_fill(mask.to(torch.bool), float("-inf")) - causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1) - if mask is not None: - mask += causal_mask - else: - mask = causal_mask + if seq_len > 1: + causal_mask = torch.empty(past_len + seq_len, past_len + seq_len, dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1) + if mask is not None: + mask += causal_mask + else: + mask = causal_mask + optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True) intermediate = None @@ -562,16 +605,27 @@ class Llama2_(nn.Module): elif intermediate_output < 0: intermediate_output = len(self.layers) + intermediate_output + next_key_values = [] for i, layer in enumerate(self.layers): if all_intermediate is not None: if only_layers is None or (i in only_layers): all_intermediate.append(x.unsqueeze(1).clone()) - x = layer( + + past_kv = None + if past_key_values is not None: + past_kv = past_key_values[i] if len(past_key_values) > 0 else [] + + x, current_kv = layer( x=x, attention_mask=mask, freqs_cis=freqs_cis, optimized_attention=optimized_attention, + past_key_value=past_kv, ) + + if current_kv is not None: + next_key_values.append(current_kv) + if i == intermediate_output: intermediate = x.clone() @@ -588,7 +642,10 @@ class Llama2_(nn.Module): if intermediate is not None and final_layer_norm_intermediate and self.norm is not None: intermediate = self.norm(intermediate) - return x, intermediate + if len(next_key_values) > 0: + return x, intermediate, next_key_values + else: + return x, intermediate class Gemma3MultiModalProjector(torch.nn.Module): From f8acd9c402f41efc2c748503e4c5d9e0027965ab Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Sat, 31 Jan 2026 22:01:11 -0800 Subject: [PATCH 017/215] Reduce RAM usage, fix VRAM OOMs, and fix Windows shared memory spilling with adaptive model loading (#11845) --- comfy/audio_encoders/audio_encoders.py | 4 +- comfy/cli_args.py | 4 + comfy/clip_vision.py | 4 +- comfy/controlnet.py | 2 +- comfy/k_diffusion/sampling.py | 33 ++- comfy/ldm/hunyuan_video/upsampler.py | 4 +- comfy/memory_management.py | 81 +++++++ comfy/model_base.py | 15 +- comfy/model_management.py | 170 ++++++++++++-- comfy/model_patcher.py | 313 +++++++++++++++++++++++-- comfy/ops.py | 221 +++++++++++++++-- comfy/pinned_memory.py | 30 +++ comfy/samplers.py | 3 +- comfy/sd.py | 59 +++-- comfy/sd1_clip.py | 2 +- comfy/text_encoders/lt.py | 2 +- comfy/utils.py | 80 ++++++- comfy/windows.py | 52 ++++ comfy_extras/nodes_model_patch.py | 6 +- cuda_malloc.py | 12 +- execution.py | 16 +- main.py | 30 ++- requirements.txt | 1 + 23 files changed, 1030 insertions(+), 114 deletions(-) create mode 100644 comfy/memory_management.py create mode 100644 comfy/pinned_memory.py create mode 100644 comfy/windows.py diff --git a/comfy/audio_encoders/audio_encoders.py b/comfy/audio_encoders/audio_encoders.py index 46ef21c95..16998af94 100644 --- a/comfy/audio_encoders/audio_encoders.py +++ b/comfy/audio_encoders/audio_encoders.py @@ -25,11 +25,11 @@ class AudioEncoderModel(): elif model_type == "whisper3": self.model = WhisperLargeV3(**model_config) self.model.eval() - self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) + self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) self.model_sample_rate = 16000 def load_sd(self, sd): - return self.model.load_state_dict(sd, strict=False) + return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic()) def get_sd(self): return self.model.state_dict() diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 1716c3de7..63daca861 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -159,6 +159,7 @@ class PerformanceFeature(enum.Enum): Fp8MatrixMultiplication = "fp8_matrix_mult" CublasOps = "cublas_ops" AutoTune = "autotune" + DynamicVRAM = "dynamic_vram" parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature)))) @@ -257,3 +258,6 @@ elif args.fast == []: # '--fast' is provided with a list of performance features, use that list else: args.fast = set(args.fast) + +def enables_dynamic_vram(): + return PerformanceFeature.DynamicVRAM in args.fast and not args.highvram and not args.gpu_only diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index b28bf636c..1691fca81 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -47,10 +47,10 @@ class ClipVisionModel(): self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast) self.model.eval() - self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) + self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) def load_sd(self, sd): - return self.model.load_state_dict(sd, strict=False) + return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic()) def get_sd(self): return self.model.state_dict() diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 0b5e30f52..9e1e704e0 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -203,7 +203,7 @@ class ControlNet(ControlBase): self.control_model = control_model self.load_device = load_device if control_model is not None: - self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device()) + self.control_model_wrapped = comfy.model_patcher.CoreModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device()) self.compression_ratio = compression_ratio self.global_average_pooling = global_average_pooling diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 0949dee44..c0c51d51a 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -1,11 +1,12 @@ import math +import time from functools import partial from scipy import integrate import torch from torch import nn import torchsde -from tqdm.auto import trange, tqdm +from tqdm.auto import trange as trange_, tqdm from . import utils from . import deis @@ -13,6 +14,36 @@ from . import sa_solver import comfy.model_patcher import comfy.model_sampling +import comfy.memory_management + + +def trange(*args, **kwargs): + if comfy.memory_management.aimdo_allocator is None: + return trange_(*args, **kwargs) + + pbar = trange_(*args, **kwargs, smoothing=1.0) + pbar._i = 0 + pbar.set_postfix_str(" Model Initializing ... ") + + _update = pbar.update + + def warmup_update(n=1): + pbar._i += 1 + if pbar._i == 1: + pbar.i1_time = time.time() + pbar.set_postfix_str(" Model Initialization complete! ") + elif pbar._i == 2: + #bring forward the effective start time based the the diff between first and second iteration + #to attempt to remove load overhead from the final step rate estimate. + pbar.start_t = pbar.i1_time - (time.time() - pbar.i1_time) + pbar.set_postfix_str("") + + _update(n) + + pbar.update = warmup_update + return pbar + + def append_zero(x): return torch.cat([x, x.new_zeros([1])]) diff --git a/comfy/ldm/hunyuan_video/upsampler.py b/comfy/ldm/hunyuan_video/upsampler.py index 51b6d1da8..1f68144e2 100644 --- a/comfy/ldm/hunyuan_video/upsampler.py +++ b/comfy/ldm/hunyuan_video/upsampler.py @@ -109,10 +109,10 @@ class HunyuanVideo15SRModel(): self.model_class = UPSAMPLERS.get(model_type) self.model = self.model_class(**config).eval() - self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) + self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) def load_sd(self, sd): - return self.model.load_state_dict(sd, strict=True) + return self.model.load_state_dict(sd, strict=True, assign=self.patcher.is_dynamic()) def get_sd(self): return self.model.state_dict() diff --git a/comfy/memory_management.py b/comfy/memory_management.py new file mode 100644 index 000000000..858bd4cc7 --- /dev/null +++ b/comfy/memory_management.py @@ -0,0 +1,81 @@ +import math +import torch +from typing import NamedTuple + +from comfy.quant_ops import QuantizedTensor + +class TensorGeometry(NamedTuple): + shape: any + dtype: torch.dtype + + def element_size(self): + info = torch.finfo(self.dtype) if self.dtype.is_floating_point else torch.iinfo(self.dtype) + return info.bits // 8 + + def numel(self): + return math.prod(self.shape) + +def tensors_to_geometries(tensors, dtype=None): + geometries = [] + for t in tensors: + if t is None or isinstance(t, QuantizedTensor): + geometries.append(t) + continue + tdtype = t.dtype + if hasattr(t, "_model_dtype"): + tdtype = t._model_dtype + if dtype is not None: + tdtype = dtype + geometries.append(TensorGeometry(shape=t.shape, dtype=tdtype)) + return geometries + +def vram_aligned_size(tensor): + if isinstance(tensor, list): + return sum([vram_aligned_size(t) for t in tensor]) + + if isinstance(tensor, QuantizedTensor): + inner_tensors, _ = tensor.__tensor_flatten__() + return vram_aligned_size([ getattr(tensor, attr) for attr in inner_tensors ]) + + if tensor is None: + return 0 + + size = tensor.numel() * tensor.element_size() + aligment_req = 1024 + return (size + aligment_req - 1) // aligment_req * aligment_req + +def interpret_gathered_like(tensors, gathered): + offset = 0 + dest_views = [] + + if gathered.dim() != 1 or gathered.element_size() != 1: + raise ValueError(f"Buffer must be 1D and single-byte (got {gathered.dim()}D {gathered.dtype})") + + for tensor in tensors: + + if tensor is None: + dest_views.append(None) + continue + + if isinstance(tensor, QuantizedTensor): + inner_tensors, qt_ctx = tensor.__tensor_flatten__() + templates = { attr: getattr(tensor, attr) for attr in inner_tensors } + else: + templates = { "data": tensor } + + actuals = {} + for attr, template in templates.items(): + size = template.numel() * template.element_size() + if offset + size > gathered.numel(): + raise ValueError(f"Buffer too small: needs {offset + size} bytes, but only has {gathered.numel()}. ") + actuals[attr] = gathered[offset:offset+size].view(dtype=template.dtype).view(template.shape) + offset += vram_aligned_size(template) + + if isinstance(tensor, QuantizedTensor): + dest_views.append(QuantizedTensor.__tensor_unflatten__(actuals, qt_ctx, 0, 0)) + else: + dest_views.append(actuals["data"]) + + return dest_views + +aimdo_allocator = None diff --git a/comfy/model_base.py b/comfy/model_base.py index 66e52864d..85acdb66a 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -149,6 +149,8 @@ class BaseModel(torch.nn.Module): self.model_type = model_type self.model_sampling = model_sampling(model_config, model_type) + comfy.model_management.archive_model_dtypes(self.diffusion_model) + self.adm_channels = unet_config.get("adm_in_channels", None) if self.adm_channels is None: self.adm_channels = 0 @@ -299,7 +301,7 @@ class BaseModel(torch.nn.Module): return out - def load_model_weights(self, sd, unet_prefix=""): + def load_model_weights(self, sd, unet_prefix="", assign=False): to_load = {} keys = list(sd.keys()) for k in keys: @@ -307,7 +309,7 @@ class BaseModel(torch.nn.Module): to_load[k[len(unet_prefix):]] = sd.pop(k) to_load = self.model_config.process_unet_state_dict(to_load) - m, u = self.diffusion_model.load_state_dict(to_load, strict=False) + m, u = self.diffusion_model.load_state_dict(to_load, strict=False, assign=assign) if len(m) > 0: logging.warning("unet missing: {}".format(m)) @@ -322,7 +324,7 @@ class BaseModel(torch.nn.Module): def process_latent_out(self, latent): return self.latent_format.process_out(latent) - def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None): + def state_dict_for_saving(self, unet_state_dict, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None): extra_sds = [] if clip_state_dict is not None: extra_sds.append(self.model_config.process_clip_state_dict_for_saving(clip_state_dict)) @@ -330,10 +332,7 @@ class BaseModel(torch.nn.Module): extra_sds.append(self.model_config.process_vae_state_dict_for_saving(vae_state_dict)) if clip_vision_state_dict is not None: extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict)) - - unet_state_dict = self.diffusion_model.state_dict() unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict) - if self.model_type == ModelType.V_PREDICTION: unet_state_dict["v_pred"] = torch.tensor([]) @@ -776,8 +775,8 @@ class StableAudio1(BaseModel): out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) return out - def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None): - sd = super().state_dict_for_saving(clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict) + def state_dict_for_saving(self, unet_state_dict, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None): + sd = super().state_dict_for_saving(unet_state_dict, clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict) d = {"conditioner.conditioners.seconds_start.": self.seconds_start_embedder.state_dict(), "conditioner.conditioners.seconds_total.": self.seconds_total_embedder.state_dict()} for k in d: s = d[k] diff --git a/comfy/model_management.py b/comfy/model_management.py index 9d39be7b2..758e718e8 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -26,6 +26,13 @@ import platform import weakref import gc import os +from contextlib import nullcontext +import comfy.memory_management +import comfy.utils +import comfy.quant_ops + +import comfy_aimdo.torch +import comfy_aimdo.model_vbar class VRAMState(Enum): DISABLED = 0 #No vram present: no need to move models to vram @@ -578,9 +585,15 @@ WINDOWS = any(platform.win32_ver()) EXTRA_RESERVED_VRAM = 400 * 1024 * 1024 if WINDOWS: + import comfy.windows EXTRA_RESERVED_VRAM = 600 * 1024 * 1024 #Windows is higher because of the shared vram issue if total_vram > (15 * 1024): # more extra reserved vram on 16GB+ cards EXTRA_RESERVED_VRAM += 100 * 1024 * 1024 + def get_free_ram(): + return comfy.windows.get_free_ram() +else: + def get_free_ram(): + return psutil.virtual_memory().available if args.reserve_vram is not None: EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024 @@ -592,7 +605,7 @@ def extra_reserved_memory(): def minimum_inference_memory(): return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory() -def free_memory(memory_required, device, keep_loaded=[]): +def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_required=0): cleanup_models_gc() unloaded_model = [] can_unload = [] @@ -607,15 +620,23 @@ def free_memory(memory_required, device, keep_loaded=[]): for x in sorted(can_unload): i = x[-1] - memory_to_free = None + memory_to_free = 1e32 + ram_to_free = 1e32 if not DISABLE_SMART_MEMORY: - free_mem = get_free_memory(device) - if free_mem > memory_required: - break - memory_to_free = memory_required - free_mem - logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}") - if current_loaded_models[i].model_unload(memory_to_free): + memory_to_free = memory_required - get_free_memory(device) + ram_to_free = ram_required - get_free_ram() + + if current_loaded_models[i].model.is_dynamic() and for_dynamic: + #don't actually unload dynamic models for the sake of other dynamic models + #as that works on-demand. + memory_required -= current_loaded_models[i].model.loaded_size() + memory_to_free = 0 + if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free): + logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}") unloaded_model.append(i) + if ram_to_free > 0: + logging.debug(f"RAM Unloading {current_loaded_models[i].model.model.__class__.__name__}") + current_loaded_models[i].model.partially_unload_ram(ram_to_free) for i in sorted(unloaded_model, reverse=True): unloaded_models.append(current_loaded_models.pop(i)) @@ -650,7 +671,10 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu models_to_load = [] + free_for_dynamic=True for x in models: + if not x.is_dynamic(): + free_for_dynamic = False loaded_model = LoadedModel(x) try: loaded_model_index = current_loaded_models.index(loaded_model) @@ -676,19 +700,25 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu model_to_unload.model.detach(unpatch_all=False) model_to_unload.model_finalizer.detach() + total_memory_required = {} + total_ram_required = {} for loaded_model in models_to_load: total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device) + #x2, one to make sure the OS can fit the model for loading in disk cache, and for us to do any pinning we + #want to do. + #FIXME: This should subtract off the to_load current pin consumption. + total_ram_required[loaded_model.device] = total_ram_required.get(loaded_model.device, 0) + loaded_model.model_memory() * 2 for device in total_memory_required: if device != torch.device("cpu"): - free_memory(total_memory_required[device] * 1.1 + extra_mem, device) + free_memory(total_memory_required[device] * 1.1 + extra_mem, device, for_dynamic=free_for_dynamic, ram_required=total_ram_required[device]) for device in total_memory_required: if device != torch.device("cpu"): free_mem = get_free_memory(device) if free_mem < minimum_memory_required: - models_l = free_memory(minimum_memory_required, device) + models_l = free_memory(minimum_memory_required, device, for_dynamic=free_for_dynamic) logging.info("{} models unloaded.".format(len(models_l))) for loaded_model in models_to_load: @@ -732,6 +762,9 @@ def loaded_models(only_currently_used=False): def cleanup_models_gc(): do_gc = False + + reset_cast_buffers() + for i in range(len(current_loaded_models)): cur = current_loaded_models[i] if cur.is_dead(): @@ -749,6 +782,11 @@ def cleanup_models_gc(): logging.warning("WARNING, memory leak with model {}. Please make sure it is not being referenced from somewhere.".format(cur.real_model().__class__.__name__)) +def archive_model_dtypes(model): + for name, module in model.named_modules(): + for param_name, param in module.named_parameters(recurse=False): + setattr(module, f"{param_name}_comfy_model_dtype", param.dtype) + def cleanup_models(): to_delete = [] @@ -792,7 +830,7 @@ def unet_inital_load_device(parameters, dtype): mem_dev = get_free_memory(torch_dev) mem_cpu = get_free_memory(cpu_dev) - if mem_dev > mem_cpu and model_size < mem_dev: + if mem_dev > mem_cpu and model_size < mem_dev and comfy.memory_management.aimdo_allocator is None: return torch_dev else: return cpu_dev @@ -1051,6 +1089,53 @@ def current_stream(device): return None stream_counters = {} + +STREAM_CAST_BUFFERS = {} +LARGEST_CASTED_WEIGHT = (None, 0) + +def get_cast_buffer(offload_stream, device, size, ref): + global LARGEST_CASTED_WEIGHT + + if offload_stream is not None: + wf_context = offload_stream + if hasattr(wf_context, "as_context"): + wf_context = wf_context.as_context(offload_stream) + else: + wf_context = nullcontext() + + cast_buffer = STREAM_CAST_BUFFERS.get(offload_stream, None) + if cast_buffer is None or cast_buffer.numel() < size: + if ref is LARGEST_CASTED_WEIGHT[0]: + #If there is one giant weight we do not want both streams to + #allocate a buffer for it. It's up to the caster to get the other + #offload stream in this corner case + return None + if cast_buffer is not None and cast_buffer.numel() > 50 * (1024 ** 2): + #I want my wrongly sized 50MB+ of VRAM back from the caching allocator right now + torch.cuda.synchronize() + del STREAM_CAST_BUFFERS[offload_stream] + del cast_buffer + #FIXME: This doesn't work in Aimdo because mempool cant clear cache + torch.cuda.empty_cache() + with wf_context: + cast_buffer = torch.empty((size), dtype=torch.int8, device=device) + STREAM_CAST_BUFFERS[offload_stream] = cast_buffer + + if size > LARGEST_CASTED_WEIGHT[1]: + LARGEST_CASTED_WEIGHT = (ref, size) + + return cast_buffer + +def reset_cast_buffers(): + global LARGEST_CASTED_WEIGHT + LARGEST_CASTED_WEIGHT = (None, 0) + for offload_stream in STREAM_CAST_BUFFERS: + offload_stream.synchronize() + STREAM_CAST_BUFFERS.clear() + if comfy.memory_management.aimdo_allocator is None: + #Pytorch 2.7 and earlier crashes if you try and empty_cache when mempools exist + torch.cuda.empty_cache() + def get_offload_stream(device): stream_counter = stream_counters.get(device, 0) if NUM_STREAMS == 0: @@ -1093,7 +1178,53 @@ def sync_stream(device, stream): return current_stream(device).wait_stream(stream) -def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None): + +def cast_to_gathered(tensors, r, non_blocking=False, stream=None): + wf_context = nullcontext() + if stream is not None: + wf_context = stream + if hasattr(wf_context, "as_context"): + wf_context = wf_context.as_context(stream) + + dest_views = comfy.memory_management.interpret_gathered_like(tensors, r) + with wf_context: + for tensor in tensors: + dest_view = dest_views.pop(0) + if tensor is None: + continue + dest_view.copy_(tensor, non_blocking=non_blocking) + + +def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None): + if hasattr(weight, "_v"): + #Unexpected usage patterns. There is no reason these don't work but they + #have no testing and no callers do this. + assert r is None + assert stream is None + + r = torch.empty_like(weight, dtype=weight._model_dtype, device=device) + + signature = comfy_aimdo.model_vbar.vbar_fault(weight._v) + if signature is not None: + raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device) + v_tensor = comfy.memory_management.interpret_gathered_like([r], raw_tensor)[0] + + if comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature): + #always take a deep copy even if _v is good, as we have no reasonable point to unpin + #a non comfy weight + r.copy_(v_tensor) + comfy_aimdo.model_vbar.vbar_unpin(weight._v) + return r + + r.copy_(weight, non_blocking=non_blocking) + + if signature is not None: + weight._v_signature = signature + v_tensor.copy_(r) + comfy_aimdo.model_vbar.vbar_unpin(weight._v) + + return r + if device is None or weight.device == device: if not copy: if dtype is None or weight.dtype == dtype: @@ -1112,10 +1243,12 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str if hasattr(wf_context, "as_context"): wf_context = wf_context.as_context(stream) with wf_context: - r = torch.empty_like(weight, dtype=dtype, device=device) + if r is None: + r = torch.empty_like(weight, dtype=dtype, device=device) r.copy_(weight, non_blocking=non_blocking) else: - r = torch.empty_like(weight, dtype=dtype, device=device) + if r is None: + r = torch.empty_like(weight, dtype=dtype, device=device) r.copy_(weight, non_blocking=non_blocking) return r @@ -1135,7 +1268,7 @@ if not args.disable_pinned_memory: MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95 logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024))) -PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"]) +PINNING_ALLOWED_TYPES = set(["Tensor", "Parameter", "QuantizedTensor"]) def discard_cuda_async_error(): try: @@ -1557,8 +1690,11 @@ def soft_empty_cache(force=False): elif is_mlu(): torch.mlu.empty_cache() elif torch.cuda.is_available(): - torch.cuda.empty_cache() - torch.cuda.ipc_collect() + if comfy.memory_management.aimdo_allocator is None: + #Pytorch 2.7 and earlier crashes if you try and empty_cache when mempools exist + torch.cuda.synchronize() + torch.cuda.empty_cache() + torch.cuda.ipc_collect() def unload_all_models(): free_memory(1e30, get_torch_device()) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index f6b80a40f..57b53d8c5 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -38,19 +38,7 @@ from comfy.comfy_types import UnetWrapperFunction from comfy.quant_ops import QuantizedTensor from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP - -def string_to_seed(data): - crc = 0xFFFFFFFF - for byte in data: - if isinstance(byte, str): - byte = ord(byte) - crc ^= byte - for _ in range(8): - if crc & 1: - crc = (crc >> 1) ^ 0xEDB88320 - else: - crc >>= 1 - return crc ^ 0xFFFFFFFF +import comfy_aimdo.model_vbar def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None): to = model_options["transformer_options"].copy() @@ -212,6 +200,27 @@ class MemoryCounter: def decrement(self, used: int): self.value -= used +CustomTorchDevice = collections.namedtuple("FakeDevice", ["type", "index"])("comfy-lazy-caster", 0) + +class LazyCastingParam(torch.nn.Parameter): + def __new__(cls, model, key, tensor): + return super().__new__(cls, tensor) + + def __init__(self, model, key, tensor): + self.model = model + self.key = key + + @property + def device(self): + return CustomTorchDevice + + #safetensors will .to() us to the cpu which we catch here to cast on demand. The returned tensor is + #then just a short lived thing in the safetensors serialization logic inside its big for loop over + #all weights getting garbage collected per-weight + def to(self, *args, **kwargs): + return self.model.patch_weight_to_device(self.key, device_to=self.model.load_device, return_weight=True).to("cpu") + + class ModelPatcher: def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False): self.size = size @@ -269,6 +278,9 @@ class ModelPatcher: if not hasattr(self.model, 'model_offload_buffer_memory'): self.model.model_offload_buffer_memory = 0 + def is_dynamic(self): + return False + def model_size(self): if self.size > 0: return self.size @@ -284,6 +296,9 @@ class ModelPatcher: def lowvram_patch_counter(self): return self.model.lowvram_patch_counter + def get_free_memory(self, device): + return comfy.model_management.get_free_memory(device) + def clone(self): n = self.__class__(self.model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update) n.patches = {} @@ -611,14 +626,14 @@ class ModelPatcher: sd.pop(k) return sd - def patch_weight_to_device(self, key, device_to=None, inplace_update=False): - if key not in self.patches: - return - + def patch_weight_to_device(self, key, device_to=None, inplace_update=False, return_weight=False): weight, set_func, convert_func = get_key_weight(self.model, key) + if key not in self.patches: + return weight + inplace_update = self.weight_inplace_update or inplace_update - if key not in self.backup: + if key not in self.backup and not return_weight: self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update) temp_dtype = comfy.model_management.lora_compute_dtype(device_to) @@ -631,13 +646,15 @@ class ModelPatcher: out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key) if set_func is None: - out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key)) - if inplace_update: + out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=comfy.utils.string_to_seed(key)) + if return_weight: + return out_weight + elif inplace_update: comfy.utils.copy_to_param(self.model, key, out_weight) else: comfy.utils.set_attr_param(self.model, key, out_weight) else: - set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key)) + return set_func(out_weight, inplace_update=inplace_update, seed=comfy.utils.string_to_seed(key), return_weight=return_weight) def pin_weight_to_device(self, key): weight, set_func, convert_func = get_key_weight(self.model, key) @@ -654,7 +671,7 @@ class ModelPatcher: for key in list(self.pinned): self.unpin_weight(key) - def _load_list(self): + def _load_list(self, prio_comfy_cast_weights=False): loading = [] for n, m in self.model.named_modules(): params = [] @@ -681,7 +698,8 @@ class ModelPatcher: return 0 module_offload_mem += check_module_offload_mem("{}.weight".format(n)) module_offload_mem += check_module_offload_mem("{}.bias".format(n)) - loading.append((module_offload_mem, module_mem, n, m, params)) + prepend = (not hasattr(m, "comfy_cast_weights"),) if prio_comfy_cast_weights else () + loading.append(prepend + (module_offload_mem, module_mem, n, m, params)) return loading def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False): @@ -984,6 +1002,9 @@ class ModelPatcher: return self.model.model_loaded_weight_memory - current_used + def partially_unload_ram(self, ram_to_unload): + pass + def detach(self, unpatch_all=True): self.eject_model() self.model_patches_to(self.offload_device) @@ -1317,10 +1338,10 @@ class ModelPatcher: key, original_weights=original_weights) del original_weights[key] if set_func is None: - out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key)) + out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=comfy.utils.string_to_seed(key)) comfy.utils.copy_to_param(self.model, key, out_weight) else: - set_func(out_weight, inplace_update=True, seed=string_to_seed(key)) + set_func(out_weight, inplace_update=True, seed=comfy.utils.string_to_seed(key)) if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed: # TODO: disable caching if not enough system RAM to do so target_device = self.offload_device @@ -1355,7 +1376,249 @@ class ModelPatcher: self.unpatch_hooks() self.clear_cached_hook_weights() + def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None): + unet_state_dict = self.model.diffusion_model.state_dict() + for k, v in unet_state_dict.items(): + op_keys = k.rsplit('.', 1) + if (len(op_keys) < 2) or op_keys[1] not in ["weight", "bias"]: + continue + try: + op = comfy.utils.get_attr(self.model.diffusion_model, op_keys[0]) + except: + continue + if not op or not hasattr(op, "comfy_cast_weights") or \ + (hasattr(op, "comfy_patched_weights") and op.comfy_patched_weights == True): + continue + key = "diffusion_model." + k + unet_state_dict[k] = LazyCastingParam(self, key, comfy.utils.get_attr(self.model, key)) + return self.model.state_dict_for_saving(unet_state_dict) + def __del__(self): self.unpin_all_weights() self.detach(unpatch_all=False) +class ModelPatcherDynamic(ModelPatcher): + + def __new__(cls, model=None, load_device=None, offload_device=None, size=0, weight_inplace_update=False): + if load_device is not None and comfy.model_management.is_device_cpu(load_device): + #reroute to default MP for CPUs + return ModelPatcher(model, load_device, offload_device, size, weight_inplace_update) + return super().__new__(cls) + + def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False): + super().__init__(model, load_device, offload_device, size, weight_inplace_update) + #this is now way more dynamic and we dont support the same base model for both Dynamic + #and non-dynamic patchers. + if hasattr(self.model, "model_loaded_weight_memory"): + del self.model.model_loaded_weight_memory + if not hasattr(self.model, "dynamic_vbars"): + self.model.dynamic_vbars = {} + assert load_device is not None + + def is_dynamic(self): + return True + + def _vbar_get(self, create=False): + if self.load_device == torch.device("cpu"): + return None + vbar = self.model.dynamic_vbars.get(self.load_device, None) + if create and vbar is None: + # x10. We dont know what model defined type casts we have in the vbar, but virtual address + # space is pretty free. This will cover someone casting an entire model from FP4 to FP32 + # with some left over. + vbar = comfy_aimdo.model_vbar.ModelVBAR(self.model_size() * 10, self.load_device.index) + self.model.dynamic_vbars[self.load_device] = vbar + return vbar + + def loaded_size(self): + vbar = self._vbar_get() + if vbar is None: + return 0 + return vbar.loaded_size() + + def get_free_memory(self, device): + #NOTE: on high condition / batch counts, estimate should have already vacated + #all non-dynamic models so this is safe even if its not 100% true that this + #would all be avaiable for inference use. + return comfy.model_management.get_total_memory(device) - self.model_size() + + #Pinning is deferred to ops time. Assert against this API to avoid pin leaks. + + def pin_weight_to_device(self, key): + raise RuntimeError("pin_weight_to_device invalid for dymamic weight loading") + + def unpin_weight(self, key): + raise RuntimeError("unpin_weight invalid for dymamic weight loading") + + def unpin_all_weights(self): + self.partially_unload_ram(1e32) + + def memory_required(self, input_shape): + #Pad this significantly. We are trying to get away from precise estimates. This + #estimate is only used when using the ModelPatcherDynamic after ModelPatcher. If you + #use all ModelPatcherDynamic this is ignored and its all done dynamically. + return super().memory_required(input_shape=input_shape) * 1.3 + (1024 ** 3) + + + def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False, dirty=False): + + #Force patching doesn't make sense in Dynamic loading, as you dont know what does and + #doesn't need to be forced at this stage. The only thing you could do would be patch + #it all on CPU which consumes huge RAM. + assert not force_patch_weights + + #Full load doesn't make sense as we dont actually have any loader capability here and + #now. + assert not full_load + + assert device_to == self.load_device + + num_patches = 0 + allocated_size = 0 + + with self.use_ejected(): + self.unpatch_hooks() + + vbar = self._vbar_get(create=True) + if vbar is not None: + vbar.prioritize() + + #We have way more tools for acceleration on comfy weight offloading, so always + #prioritize the non-comfy weights (note the order reverse). + loading = self._load_list(prio_comfy_cast_weights=True) + loading.sort(reverse=True) + + for x in loading: + _, _, _, n, m, params = x + + def set_dirty(item, dirty): + if dirty or not hasattr(item, "_v_signature"): + item._v_signature = None + + def setup_param(self, m, n, param_key): + nonlocal num_patches + key = "{}.{}".format(n, param_key) + + weight_function = [] + + weight, _, _ = get_key_weight(self.model, key) + if weight is None: + return 0 + if key in self.patches: + setattr(m, param_key + "_lowvram_function", LowVramPatch(key, self.patches)) + num_patches += 1 + else: + setattr(m, param_key + "_lowvram_function", None) + + if key in self.weight_wrapper_patches: + weight_function.extend(self.weight_wrapper_patches[key]) + setattr(m, param_key + "_function", weight_function) + geometry = weight + if not isinstance(weight, QuantizedTensor): + model_dtype = getattr(m, param_key + "_comfy_model_dtype", weight.dtype) + weight._model_dtype = model_dtype + geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype) + return comfy.memory_management.vram_aligned_size(geometry) + + if hasattr(m, "comfy_cast_weights"): + m.comfy_cast_weights = True + m.pin_failed = False + m.seed_key = n + set_dirty(m, dirty) + + v_weight_size = 0 + v_weight_size += setup_param(self, m, n, "weight") + v_weight_size += setup_param(self, m, n, "bias") + + if vbar is not None and not hasattr(m, "_v"): + m._v = vbar.alloc(v_weight_size) + allocated_size += v_weight_size + + else: + for param in params: + key = "{}.{}".format(n, param) + weight, _, _ = get_key_weight(self.model, key) + weight.seed_key = key + set_dirty(weight, dirty) + geometry = weight + model_dtype = getattr(m, param + "_comfy_model_dtype", weight.dtype) + geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype) + weight_size = geometry.numel() * geometry.element_size() + if vbar is not None and not hasattr(weight, "_v"): + weight._v = vbar.alloc(weight_size) + weight._model_dtype = model_dtype + allocated_size += weight_size + + logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.") + + self.model.device = device_to + self.model.current_weight_patches_uuid = self.patches_uuid + + for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD): + #These are all super dangerous. Who knows what the custom nodes actually do here... + callback(self, device_to, lowvram_model_memory, force_patch_weights, full_load) + + self.apply_hooks(self.forced_hooks, force_apply=True) + + def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=False): + assert not force_patch_weights #See above + assert self.load_device != torch.device("cpu") + + vbar = self._vbar_get() + return 0 if vbar is None else vbar.free_memory(memory_to_free) + + def partially_unload_ram(self, ram_to_unload): + loading = self._load_list(prio_comfy_cast_weights=True) + for x in loading: + _, _, _, _, m, _ = x + ram_to_unload -= comfy.pinned_memory.unpin_memory(m) + if ram_to_unload <= 0: + return + + def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False): + #This isn't used by the core at all and can only be to load a model out of + #the control of proper model_managment. If you are a custom node author reading + #this, the correct pattern is to call load_models_gpu() to get a proper + #managed load of your model. + assert not load_weights + return super().patch_model(load_weights=load_weights, force_patch_weights=force_patch_weights) + + def unpatch_model(self, device_to=None, unpatch_weights=True): + super().unpatch_model(device_to=None, unpatch_weights=False) + + if unpatch_weights: + self.partially_unload_ram(1e32) + self.partially_unload(None) + + def partially_load(self, device_to, extra_memory=0, force_patch_weights=False): + assert not force_patch_weights #See above + with self.use_ejected(skip_and_inject_on_exit_only=True): + dirty = self.model.current_weight_patches_uuid is not None and (self.model.current_weight_patches_uuid != self.patches_uuid) + + self.unpatch_model(self.offload_device, unpatch_weights=False) + self.patch_model(load_weights=False) + + try: + self.load(device_to, dirty=dirty) + except Exception as e: + self.detach() + raise e + #ModelPatcher::partially_load returns a number on what got loaded but + #nothing in core uses this and we have no data in the Dynamic world. Hit + #the custom node devs with a None rather than a 0 that would mislead any + #logic they might have. + return None + + def patch_cached_hook_weights(self, cached_weights: dict, key: str, memory_counter: MemoryCounter): + assert False #Should be unreachable - we dont ever cache in the new implementation + + def patch_hook_weight_to_device(self, hooks: comfy.hooks.HookGroup, combined_patches: dict, key: str, original_weights: dict, memory_counter: MemoryCounter): + if key not in combined_patches: + return + + raise RuntimeError("Hooks not implemented in ModelPatcherDynamic. Please remove --fast arguments form ComfyUI startup") + + def unpatch_hooks(self, whitelist_keys_set: set[str]=None) -> None: + pass + +CoreModelPatcher = ModelPatcher diff --git a/comfy/ops.py b/comfy/ops.py index e406ba7ed..c3a1825ce 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -19,10 +19,16 @@ import torch import logging import comfy.model_management -from comfy.cli_args import args, PerformanceFeature +from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram import comfy.float import comfy.rmsnorm import json +import comfy.memory_management +import comfy.pinned_memory +import comfy.utils + +import comfy_aimdo.model_vbar +import comfy_aimdo.torch def run_every_op(): if torch.compiler.is_compiling(): @@ -72,7 +78,115 @@ def cast_to_input(weight, input, non_blocking=False, copy=True): return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy) -def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False): +def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype): + offload_stream = None + xfer_dest = None + cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ]) + + signature = comfy_aimdo.model_vbar.vbar_fault(s._v) + if signature is not None: + xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device) + resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature) + + if not resident: + cast_dest = None + + xfer_source = [ s.weight, s.bias ] + + pin = comfy.pinned_memory.get_pin(s) + if pin is not None: + xfer_source = [ pin ] + else: + for data, geometry in zip([ s.weight, s.bias ], cast_geometry): + if data is None: + continue + if data.dtype != geometry.dtype: + cast_dest = xfer_dest + if cast_dest is None: + cast_dest = torch.empty((comfy.memory_management.vram_aligned_size(cast_geometry),), dtype=torch.uint8, device=device) + xfer_dest = None + break + + dest_size = comfy.memory_management.vram_aligned_size(xfer_source) + offload_stream = comfy.model_management.get_offload_stream(device) + if xfer_dest is None and offload_stream is not None: + xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s) + if xfer_dest is None: + offload_stream = comfy.model_management.get_offload_stream(device) + xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s) + if xfer_dest is None: + xfer_dest = torch.empty((dest_size,), dtype=torch.uint8, device=device) + offload_stream = None + + if signature is None and pin is None: + comfy.pinned_memory.pin_memory(s) + pin = comfy.pinned_memory.get_pin(s) + else: + pin = None + + if pin is not None: + comfy.model_management.cast_to_gathered(xfer_source, pin) + xfer_source = [ pin ] + #send it over + comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream) + comfy.model_management.sync_stream(device, offload_stream) + + if cast_dest is not None: + for pre_cast, post_cast in zip(comfy.memory_management.interpret_gathered_like(xfer_source, xfer_dest), + comfy.memory_management.interpret_gathered_like(cast_geometry, cast_dest)): + if post_cast is not None: + post_cast.copy_(pre_cast) + xfer_dest = cast_dest + + params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest) + weight = params[0] + bias = params[1] + + def post_cast(s, param_key, x, dtype, resident, update_weight): + lowvram_fn = getattr(s, param_key + "_lowvram_function", None) + fns = getattr(s, param_key + "_function", []) + + orig = x + + def to_dequant(tensor, dtype): + tensor = tensor.to(dtype=dtype) + if isinstance(tensor, QuantizedTensor): + tensor = tensor.dequantize() + return tensor + + if orig.dtype != dtype or len(fns) > 0: + x = to_dequant(x, dtype) + if not resident and lowvram_fn is not None: + x = to_dequant(x, dtype if compute_dtype is None else compute_dtype) + #FIXME: this is not accurate, we need to be sensitive to the compute dtype + x = lowvram_fn(x) + if (isinstance(orig, QuantizedTensor) and + (orig.dtype == dtype and len(fns) == 0 or update_weight)): + seed = comfy.utils.string_to_seed(s.seed_key) + y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed) + if orig.dtype == dtype and len(fns) == 0: + #The layer actually wants our freshly saved QT + x = y + else: + y = x + if update_weight: + orig.copy_(y) + for f in fns: + x = f(x) + return x + + update_weight = signature is not None + + weight = post_cast(s, "weight", weight, dtype, resident, update_weight) + if s.bias is not None: + bias = post_cast(s, "bias", bias, bias_dtype, resident, update_weight) + s._v_signature=signature + + #FIXME: weird offload return protocol + return weight, bias, (offload_stream, device if signature is not None else None, None) + + +def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None): # NOTE: offloadable=False is a a legacy and if you are a custom node author reading this please pass # offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This # will add async-offload support to your cast and improve performance. @@ -87,22 +201,38 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of if device is None: device = input.device + non_blocking = comfy.model_management.device_supports_non_blocking(device) + + if hasattr(s, "_v"): + return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype) + if offloadable and (device != s.weight.device or (s.bias is not None and device != s.bias.device)): offload_stream = comfy.model_management.get_offload_stream(device) else: offload_stream = None - non_blocking = comfy.model_management.device_supports_non_blocking(device) + bias = None + weight = None + + if offload_stream is not None and not args.cuda_malloc: + cast_buffer_size = comfy.memory_management.vram_aligned_size([ s.weight, s.bias ]) + cast_buffer = comfy.model_management.get_cast_buffer(offload_stream, device, cast_buffer_size, s) + #The streams can be uneven in buffer capability and reject us. Retry to get the other stream + if cast_buffer is None: + offload_stream = comfy.model_management.get_offload_stream(device) + cast_buffer = comfy.model_management.get_cast_buffer(offload_stream, device, cast_buffer_size, s) + params = comfy.memory_management.interpret_gathered_like([ s.weight, s.bias ], cast_buffer) + weight = params[0] + bias = params[1] weight_has_function = len(s.weight_function) > 0 bias_has_function = len(s.bias_function) > 0 - weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream) + weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream, r=weight) - bias = None if s.bias is not None: - bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream) + bias = comfy.model_management.cast_to(s.bias, None, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream, r=bias) comfy.model_management.sync_stream(device, offload_stream) @@ -110,6 +240,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of weight_a = weight if s.bias is not None: + bias = bias.to(dtype=bias_dtype) for f in s.bias_function: bias = f(bias) @@ -131,14 +262,20 @@ def uncast_bias_weight(s, weight, bias, offload_stream): if offload_stream is None: return os, weight_a, bias_a = offload_stream + device=None + #FIXME: This is not good RTTI + if not isinstance(weight_a, torch.Tensor): + comfy_aimdo.model_vbar.vbar_unpin(s._v) + device = weight_a if os is None: return - if weight_a is not None: - device = weight_a.device - else: - if bias_a is None: - return - device = bias_a.device + if device is None: + if weight_a is not None: + device = weight_a.device + else: + if bias_a is None: + return + device = bias_a.device os.wait_stream(comfy.model_management.current_stream(device)) @@ -149,6 +286,57 @@ class CastWeightBiasOp: class disable_weight_init: class Linear(torch.nn.Linear, CastWeightBiasOp): + + def __init__(self, in_features, out_features, bias=True, device=None, dtype=None): + if not comfy.model_management.WINDOWS or not enables_dynamic_vram(): + super().__init__(in_features, out_features, bias, device, dtype) + return + + # Issue is with `torch.empty` still reserving the full memory for the layer. + # Windows doesn't over-commit memory so without this, We are momentarily commit + # charged for the weight even though we might zero-copy it when we load the + # state dict. If the commit charge exceeds the ceiling we can destabilize the + # system. + torch.nn.Module.__init__(self) + self.in_features = in_features + self.out_features = out_features + self.weight = None + self.bias = None + self.comfy_need_lazy_init_bias=bias + self.weight_comfy_model_dtype = dtype + self.bias_comfy_model_dtype = dtype + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, + strict, missing_keys, unexpected_keys, error_msgs): + + if not comfy.model_management.WINDOWS or not enables_dynamic_vram(): + return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs) + assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False) + prefix_len = len(prefix) + for k,v in state_dict.items(): + if k[prefix_len:] == "weight": + if not assign_to_params_buffers: + v = v.clone() + self.weight = torch.nn.Parameter(v, requires_grad=False) + elif k[prefix_len:] == "bias" and v is not None: + if not assign_to_params_buffers: + v = v.clone() + self.bias = torch.nn.Parameter(v, requires_grad=False) + else: + unexpected_keys.append(k) + + #Reconcile default construction of the weight if its missing. + if self.weight is None: + v = torch.zeros(self.in_features, self.out_features) + self.weight = torch.nn.Parameter(v, requires_grad=False) + missing_keys.append(prefix+"weight") + if self.bias is None and self.comfy_need_lazy_init_bias: + v = torch.zeros(self.out_features,) + self.bias = torch.nn.Parameter(v, requires_grad=False) + missing_keys.append(prefix+"bias") + + def reset_parameters(self): return None @@ -655,8 +843,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec def _forward(self, input, weight, bias): return torch.nn.functional.linear(input, weight, bias) - def forward_comfy_cast_weights(self, input): - weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + def forward_comfy_cast_weights(self, input, compute_dtype=None): + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True, compute_dtype=compute_dtype) x = self._forward(input, weight, bias) uncast_bias_weight(self, weight, bias, offload_stream) return x @@ -666,6 +854,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec input_shape = input.shape reshaped_3d = False + #If cast needs to apply lora, it should be done in the compute dtype + compute_dtype = input.dtype if (getattr(self, 'layout_type', None) is not None and not isinstance(input, QuantizedTensor) and not self._full_precision_mm and @@ -684,7 +874,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec scale = comfy.model_management.cast_to_device(scale, input.device, None) input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale) - output = self.forward_comfy_cast_weights(input) + + output = self.forward_comfy_cast_weights(input, compute_dtype) # Reshape output back to 3D if input was 3D if reshaped_3d: diff --git a/comfy/pinned_memory.py b/comfy/pinned_memory.py new file mode 100644 index 000000000..0650e4d1a --- /dev/null +++ b/comfy/pinned_memory.py @@ -0,0 +1,30 @@ +import torch +import comfy.model_management +import comfy.memory_management + +from comfy.cli_args import args + +def get_pin(module): + return getattr(module, "_pin", None) + +def pin_memory(module): + if module.pin_failed or args.disable_pinned_memory or get_pin(module) is not None: + return + #FIXME: This is a RAM cache trigger event + params = comfy.memory_management.tensors_to_geometries([ module.weight, module.bias ]) + size = comfy.memory_management.vram_aligned_size(params) + pin = torch.empty((size,), dtype=torch.uint8) + if comfy.model_management.pin_memory(pin): + module._pin = pin + else: + module.pin_failed = True + return False + return True + +def unpin_memory(module): + if get_pin(module) is None: + return 0 + size = module._pin.numel() * module._pin.element_size() + comfy.model_management.unpin_memory(module._pin) + del module._pin + return size diff --git a/comfy/samplers.py b/comfy/samplers.py index 1989ef107..8b9782956 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -9,7 +9,6 @@ if TYPE_CHECKING: import torch from functools import partial import collections -from comfy import model_management import math import logging import comfy.sampler_helpers @@ -260,7 +259,7 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens to_batch_temp.reverse() to_batch = to_batch_temp[:1] - free_memory = model_management.get_free_memory(x_in.device) + free_memory = model.current_patcher.get_free_memory(x_in.device) for i in range(1, len(to_batch_temp) + 1): batch_amount = to_batch_temp[:len(to_batch_temp)//i] input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:] diff --git a/comfy/sd.py b/comfy/sd.py index f627f7d55..fd0ac85e7 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -228,8 +228,10 @@ class CLIP: self.cond_stage_model.to(offload_device) logging.warning("Had to shift TE back.") + model_management.archive_model_dtypes(self.cond_stage_model) + self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) - self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device) + self.patcher = comfy.model_patcher.CoreModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device) #Match torch.float32 hardcode upcast in TE implemention self.patcher.set_model_compute_dtype(torch.float32) self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram @@ -389,8 +391,18 @@ class CLIP: def load_sd(self, sd, full_model=False): if full_model: - return self.cond_stage_model.load_state_dict(sd, strict=False) + return self.cond_stage_model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic()) else: + can_assign = self.patcher.is_dynamic() + self.cond_stage_model.can_assign_sd = can_assign + + # The CLIP models are a pretty complex web of wrappers and its + # a bit of an API change to plumb this all the way through. + # So spray paint the model with this flag that the loading + # nn.Module can then inspect for itself. + for m in self.cond_stage_model.modules(): + m.can_assign_sd = can_assign + return self.cond_stage_model.load_sd(sd) def get_sd(self): @@ -765,12 +777,7 @@ class VAE: self.first_stage_model = AutoencoderKL(**(config['params'])) self.first_stage_model = self.first_stage_model.eval() - m, u = self.first_stage_model.load_state_dict(sd, strict=False) - if len(m) > 0: - logging.warning("Missing VAE keys {}".format(m)) - - if len(u) > 0: - logging.debug("Leftover VAE keys {}".format(u)) + model_management.archive_model_dtypes(self.first_stage_model) if device is None: device = model_management.vae_device() @@ -782,7 +789,18 @@ class VAE: self.first_stage_model.to(self.vae_dtype) self.output_device = model_management.intermediate_device() - self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device) + mp = comfy.model_patcher.CoreModelPatcher + if self.disable_offload: + mp = comfy.model_patcher.ModelPatcher + self.patcher = mp(self.first_stage_model, load_device=self.device, offload_device=offload_device) + + m, u = self.first_stage_model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic()) + if len(m) > 0: + logging.warning("Missing VAE keys {}".format(m)) + + if len(u) > 0: + logging.debug("Leftover VAE keys {}".format(u)) + logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype)) self.model_size() @@ -897,7 +915,7 @@ class VAE: try: memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype) model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) - free_memory = model_management.get_free_memory(self.device) + free_memory = self.patcher.get_free_memory(self.device) batch_number = int(free_memory / memory_used) batch_number = max(1, batch_number) @@ -971,7 +989,7 @@ class VAE: try: memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) - free_memory = model_management.get_free_memory(self.device) + free_memory = self.patcher.get_free_memory(self.device) batch_number = int(free_memory / max(1, memory_used)) batch_number = max(1, batch_number) samples = None @@ -1432,7 +1450,7 @@ def load_gligen(ckpt_path): model = gligen.load_gligen(data) if model_management.should_use_fp16(): model = model.half() - return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device()) + return comfy.model_patcher.CoreModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device()) def model_detection_error_hint(path, state_dict): filename = os.path.basename(path) @@ -1520,7 +1538,8 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c if output_model: inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype) model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device) - model.load_model_weights(sd, diffusion_model_prefix) + model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device()) + model.load_model_weights(sd, diffusion_model_prefix, assign=model_patcher.is_dynamic()) if output_vae: vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True) @@ -1563,7 +1582,6 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c logging.debug("left over keys: {}".format(left_over)) if output_model: - model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device()) if inital_load_device != torch.device("cpu"): logging.info("loaded diffusion model directly to GPU") model_management.load_models_gpu([model_patcher], force_full_load=True) @@ -1655,13 +1673,14 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None): model_config.optimizations["fp8"] = True model = model_config.get_model(new_sd, "") - model = model.to(offload_device) - model.load_model_weights(new_sd, "") + model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=load_device, offload_device=offload_device) + if not model_management.is_device_cpu(offload_device): + model.to(offload_device) + model.load_model_weights(new_sd, "", assign=model_patcher.is_dynamic()) left_over = sd.keys() if len(left_over) > 0: logging.info("left over keys in diffusion model: {}".format(left_over)) - return comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device) - + return model_patcher def load_diffusion_model(unet_path, model_options={}): sd, metadata = comfy.utils.load_torch_file(unet_path, return_metadata=True) @@ -1692,9 +1711,9 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m if metadata is None: metadata = {} - model_management.load_models_gpu(load_models, force_patch_weights=True) + model_management.load_models_gpu(load_models) clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None - sd = model.model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd) + sd = model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd) for k in extra_keys: sd[k] = extra_keys[k] diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index d4f22120b..9ecfc9c55 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -297,7 +297,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): return self(tokens) def load_sd(self, sd): - return self.transformer.load_state_dict(sd, strict=False) + return self.transformer.load_state_dict(sd, strict=False, assign=getattr(self, "can_assign_sd", False)) def parse_parentheses(string): result = [] diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py index e49161964..26573fb12 100644 --- a/comfy/text_encoders/lt.py +++ b/comfy/text_encoders/lt.py @@ -125,7 +125,7 @@ class LTXAVTEModel(torch.nn.Module): for prefix, component in [("text_embedding_projection.", self.text_embedding_projection), ("video_embeddings_connector.", self.video_embeddings_connector), ("audio_embeddings_connector.", self.audio_embeddings_connector)]: component_sd = {k.replace(prefix, ""): v for k, v in sdo.items() if k.startswith(prefix)} if component_sd: - missing, unexpected = component.load_state_dict(component_sd, strict=False) + missing, unexpected = component.load_state_dict(component_sd, strict=False, assign=getattr(self, "can_assign_sd", False)) missing_all.extend([f"{prefix}{k}" for k in missing]) unexpected_all.extend([f"{prefix}{k}" for k in unexpected]) diff --git a/comfy/utils.py b/comfy/utils.py index d97d753e6..9e98eb176 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -28,9 +28,10 @@ import logging import itertools from torch.nn.functional import interpolate from einops import rearrange -from comfy.cli_args import args +from comfy.cli_args import args, enables_dynamic_vram import json import time +import mmap MMAP_TORCH_FILES = args.mmap_torch_files DISABLE_MMAP = args.disable_mmap @@ -56,21 +57,67 @@ if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in else: logging.warning("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended as older versions of pytorch are no longer supported.") +# Current as of safetensors 0.7.0 +_TYPES = { + "F64": torch.float64, + "F32": torch.float32, + "F16": torch.float16, + "BF16": torch.bfloat16, + "I64": torch.int64, + "I32": torch.int32, + "I16": torch.int16, + "I8": torch.int8, + "U8": torch.uint8, + "BOOL": torch.bool, + "F8_E4M3": torch.float8_e4m3fn, + "F8_E5M2": torch.float8_e5m2, + "C64": torch.complex64, + + "U64": torch.uint64, + "U32": torch.uint32, + "U16": torch.uint16, +} + +def load_safetensors(ckpt): + f = open(ckpt, "rb") + mapping = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) + + header_size = struct.unpack(" 0: message = e.args[0] @@ -1308,3 +1355,16 @@ def convert_old_quants(state_dict, model_prefix="", metadata={}): state_dict["{}.comfy_quant".format(k)] = torch.tensor(list(json.dumps(v).encode('utf-8')), dtype=torch.uint8) return state_dict, metadata + +def string_to_seed(data): + crc = 0xFFFFFFFF + for byte in data: + if isinstance(byte, str): + byte = ord(byte) + crc ^= byte + for _ in range(8): + if crc & 1: + crc = (crc >> 1) ^ 0xEDB88320 + else: + crc >>= 1 + return crc ^ 0xFFFFFFFF diff --git a/comfy/windows.py b/comfy/windows.py new file mode 100644 index 000000000..213dc481d --- /dev/null +++ b/comfy/windows.py @@ -0,0 +1,52 @@ +import ctypes +import logging +import psutil +from ctypes import wintypes + +import comfy_aimdo.control + +psapi = ctypes.WinDLL("psapi") +kernel32 = ctypes.WinDLL("kernel32") + +class PERFORMANCE_INFORMATION(ctypes.Structure): + _fields_ = [ + ("cb", wintypes.DWORD), + ("CommitTotal", ctypes.c_size_t), + ("CommitLimit", ctypes.c_size_t), + ("CommitPeak", ctypes.c_size_t), + ("PhysicalTotal", ctypes.c_size_t), + ("PhysicalAvailable", ctypes.c_size_t), + ("SystemCache", ctypes.c_size_t), + ("KernelTotal", ctypes.c_size_t), + ("KernelPaged", ctypes.c_size_t), + ("KernelNonpaged", ctypes.c_size_t), + ("PageSize", ctypes.c_size_t), + ("HandleCount", wintypes.DWORD), + ("ProcessCount", wintypes.DWORD), + ("ThreadCount", wintypes.DWORD), + ] + +def get_free_ram(): + #Windows is way too conservative and chalks recently used uncommitted model RAM + #as "in-use". So, calculate free RAM for the sake of general use as the greater of: + # + #1: What psutil says + #2: Total Memory - (Committed Memory - VRAM in use) + # + #We have to subtract VRAM in use from the comitted memory as WDDM creates a naked + #commit charge for all VRAM used just incase it wants to page it all out. This just + #isn't realistic so "overcommit" on our calculations by just subtracting it off. + + pi = PERFORMANCE_INFORMATION() + pi.cb = ctypes.sizeof(pi) + + if not psapi.GetPerformanceInfo(ctypes.byref(pi), pi.cb): + logging.warning("WARNING: Failed to query windows performance info. RAM usage may be sub optimal") + return psutil.virtual_memory().available + + committed = pi.CommitTotal * pi.PageSize + total = pi.PhysicalTotal * pi.PageSize + + return max(psutil.virtual_memory().available, + total - (committed - comfy_aimdo.control.get_total_vram_usage())) + diff --git a/comfy_extras/nodes_model_patch.py b/comfy_extras/nodes_model_patch.py index 82c4754a3..176e6bc2f 100644 --- a/comfy_extras/nodes_model_patch.py +++ b/comfy_extras/nodes_model_patch.py @@ -267,9 +267,9 @@ class ModelPatchLoader: device=comfy.model_management.unet_offload_device(), operations=comfy.ops.manual_cast) - model.load_state_dict(sd) - model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device()) - return (model,) + model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device()) + model.load_state_dict(sd, assign=model_patcher.is_dynamic()) + return (model_patcher,) class DiffSynthCnetPatch: diff --git a/cuda_malloc.py b/cuda_malloc.py index ee2bc4b69..b2182df37 100644 --- a/cuda_malloc.py +++ b/cuda_malloc.py @@ -1,8 +1,10 @@ import os import importlib.util -from comfy.cli_args import args, PerformanceFeature +from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram import subprocess +import comfy_aimdo.control + #Can't use pytorch to get the GPU names because the cuda malloc has to be set before the first import. def get_gpu_names(): if os.name == 'nt': @@ -85,8 +87,14 @@ if not args.cuda_malloc: except: pass +if enables_dynamic_vram() and comfy_aimdo.control.init(): + args.cuda_malloc = False + os.environ['PYTORCH_CUDA_ALLOC_CONF'] = "" -if args.cuda_malloc and not args.disable_cuda_malloc: +if args.disable_cuda_malloc: + args.cuda_malloc = False + +if args.cuda_malloc: env_var = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None) if env_var is None: env_var = "backend:cudaMallocAsync" diff --git a/execution.py b/execution.py index 4b4f63c80..93fafc4a2 100644 --- a/execution.py +++ b/execution.py @@ -9,9 +9,11 @@ import traceback from enum import Enum from typing import List, Literal, NamedTuple, Optional, Union import asyncio +from contextlib import nullcontext import torch +import comfy.memory_management import comfy.model_management from latent_preview import set_preview_method import nodes @@ -515,7 +517,19 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, def pre_execute_cb(call_index): # TODO - How to handle this with async functions without contextvars (which requires Python 3.12)? GraphBuilder.set_default_prefix(unique_id, call_index, 0) - output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data) + + #Do comfy_aimdo mempool chunking here on the per-node level. Multi-model workflows + #will cause all sorts of incompatible memory shapes to fragment the pytorch alloc + #that we just want to cull out each model run. + allocator = comfy.memory_management.aimdo_allocator + with nullcontext() if allocator is None else torch.cuda.use_mem_pool(torch.cuda.MemPool(allocator.allocator())): + try: + output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data) + finally: + if allocator is not None: + comfy.model_management.reset_cast_buffers() + torch.cuda.synchronize() + if has_pending_tasks: pending_async_nodes[unique_id] = output_data unblock = execution_list.add_external_block(unique_id) diff --git a/main.py b/main.py index 37b06c1fa..b8c951375 100644 --- a/main.py +++ b/main.py @@ -5,7 +5,7 @@ import os import importlib.util import folder_paths import time -from comfy.cli_args import args +from comfy.cli_args import args, enables_dynamic_vram from app.logger import setup_logger from app.assets.scanner import seed_assets import itertools @@ -173,6 +173,7 @@ import gc if 'torch' in sys.modules: logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.") + import comfy.utils import execution @@ -184,6 +185,33 @@ import comfyui_version import app.logger import hook_breaker_ac10a0 +import comfy.memory_management +import comfy.model_patcher + +import comfy_aimdo.control +import comfy_aimdo.torch + +if enables_dynamic_vram(): + if comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index): + if args.verbose == 'DEBUG': + comfy_aimdo.control.set_log_debug() + elif args.verbose == 'CRITICAL': + comfy_aimdo.control.set_log_critical() + elif args.verbose == 'ERROR': + comfy_aimdo.control.set_log_error() + elif args.verbose == 'WARNING': + comfy_aimdo.control.set_log_warning() + else: #INFO + comfy_aimdo.control.set_log_info() + + comfy.model_patcher.CoreModelPatcher = comfy.model_patcher.ModelPatcherDynamic + comfy.memory_management.aimdo_allocator = comfy_aimdo.torch.get_torch_allocator() + logging.info("DynamicVRAM support detected and enabled") + else: + logging.info("No working comfy-aimdo install detected. DynamicVRAM support disabled. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows") + comfy.memory_management.aimdo_allocator = None + + def cuda_malloc_warning(): device = comfy.model_management.get_torch_device() device_name = comfy.model_management.get_torch_device_name(device) diff --git a/requirements.txt b/requirements.txt index 4ac94cb16..be0bc537e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,6 +22,7 @@ alembic SQLAlchemy av>=14.2.0 comfy-kitchen>=0.2.7 +comfy-aimdo>=0.1.6 requests #non essential dependencies: From 32621c6a11f6ce1c0dd46aeb08e35dd64ec804f0 Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Sat, 31 Jan 2026 22:13:48 -0800 Subject: [PATCH 018/215] fix: improve error message when node type is missing (#12194) - Change error type from 'invalid_prompt' to 'missing_node_type' for frontend detection - Add extra_info with node_id, class_type, and node_title (from _meta.title) - Improve user-facing message: 'Node X not found. The custom node may not be installed.' --- execution.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/execution.py b/execution.py index 93fafc4a2..3dbab82e6 100644 --- a/execution.py +++ b/execution.py @@ -1014,22 +1014,34 @@ async def validate_prompt(prompt_id, prompt, partial_execution_list: Union[list[ outputs = set() for x in prompt: if 'class_type' not in prompt[x]: + node_data = prompt[x] + node_title = node_data.get('_meta', {}).get('title') error = { - "type": "invalid_prompt", - "message": "Cannot execute because a node is missing the class_type property.", + "type": "missing_node_type", + "message": f"Node '{node_title or f'ID #{x}'}' has no class_type. The workflow may be corrupted or a custom node is missing.", "details": f"Node ID '#{x}'", - "extra_info": {} + "extra_info": { + "node_id": x, + "class_type": None, + "node_title": node_title + } } return (False, error, [], {}) class_type = prompt[x]['class_type'] class_ = nodes.NODE_CLASS_MAPPINGS.get(class_type, None) if class_ is None: + node_data = prompt[x] + node_title = node_data.get('_meta', {}).get('title', class_type) error = { - "type": "invalid_prompt", - "message": f"Cannot execute because node {class_type} does not exist.", + "type": "missing_node_type", + "message": f"Node '{node_title}' not found. The custom node may not be installed.", "details": f"Node ID '#{x}'", - "extra_info": {} + "extra_info": { + "node_id": x, + "class_type": class_type, + "node_title": node_title + } } return (False, error, [], {}) From 667a1b8878187f7a5c8955b272e33c16e4c46bb7 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 31 Jan 2026 22:55:18 -0800 Subject: [PATCH 019/215] Fix some custom nodes breaking. (#12203) --- comfy/model_patcher.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 57b53d8c5..c8e6f088f 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -111,6 +111,10 @@ def move_weight_functions(m, device): memory += f.move_to(device=device) return memory +def string_to_seed(data): + logging.warning("WARNING: string_to_seed has moved from comfy.model_patcher to comfy.utils") + return comfy.utils.string_to_seed(data) + class LowVramPatch: def __init__(self, key, patches, convert_func=None, set_func=None): self.key = key From 361b9a82a3e2445fc22e352df187138b0c8f67fb Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Sun, 1 Feb 2026 08:42:32 -0800 Subject: [PATCH 020/215] fix pinning with model defined dtype (#12208) pinned memory was converted back to pinning the CPU side weight without any changes. Fix the pinner to use the CPU weight and not the model defined geometry. This will either save RAM or stop buffer overruns when the types mismatch. Fix the model defined weight caster to use the [ s.weight, s.bias ] interpretation, as xfer_dest might be the flattened pin now. Fix the detection of needing to cast to not be conditional on !pin. --- comfy/ops.py | 22 +++++++++++----------- comfy/pinned_memory.py | 3 +-- 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/comfy/ops.py b/comfy/ops.py index c3a1825ce..53c5e4dc3 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -96,16 +96,16 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu pin = comfy.pinned_memory.get_pin(s) if pin is not None: xfer_source = [ pin ] - else: - for data, geometry in zip([ s.weight, s.bias ], cast_geometry): - if data is None: - continue - if data.dtype != geometry.dtype: - cast_dest = xfer_dest - if cast_dest is None: - cast_dest = torch.empty((comfy.memory_management.vram_aligned_size(cast_geometry),), dtype=torch.uint8, device=device) - xfer_dest = None - break + + for data, geometry in zip([ s.weight, s.bias ], cast_geometry): + if data is None: + continue + if data.dtype != geometry.dtype: + cast_dest = xfer_dest + if cast_dest is None: + cast_dest = torch.empty((comfy.memory_management.vram_aligned_size(cast_geometry),), dtype=torch.uint8, device=device) + xfer_dest = None + break dest_size = comfy.memory_management.vram_aligned_size(xfer_source) offload_stream = comfy.model_management.get_offload_stream(device) @@ -132,7 +132,7 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu comfy.model_management.sync_stream(device, offload_stream) if cast_dest is not None: - for pre_cast, post_cast in zip(comfy.memory_management.interpret_gathered_like(xfer_source, xfer_dest), + for pre_cast, post_cast in zip(comfy.memory_management.interpret_gathered_like([s.weight, s.bias ], xfer_dest), comfy.memory_management.interpret_gathered_like(cast_geometry, cast_dest)): if post_cast is not None: post_cast.copy_(pre_cast) diff --git a/comfy/pinned_memory.py b/comfy/pinned_memory.py index 0650e4d1a..8acc327a7 100644 --- a/comfy/pinned_memory.py +++ b/comfy/pinned_memory.py @@ -11,8 +11,7 @@ def pin_memory(module): if module.pin_failed or args.disable_pinned_memory or get_pin(module) is not None: return #FIXME: This is a RAM cache trigger event - params = comfy.memory_management.tensors_to_geometries([ module.weight, module.bias ]) - size = comfy.memory_management.vram_aligned_size(params) + size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ]) pin = torch.empty((size,), dtype=torch.uint8) if comfy.model_management.pin_memory(pin): module._pin = pin From 794d05bdb12be95f0e58dbb1728542ac4c7998aa Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Sun, 1 Feb 2026 17:09:21 -0800 Subject: [PATCH 021/215] dynamic_vram: respect argument cast dtypes in non-comfy weights (#12209) This function has a dtype argument that allows the caller to set the dtype in the cast. TIL Some models override this on weight casts, which means its the highest priority. Priority scheme is: argument > model dtype > state dict dtype --- comfy/model_management.py | 27 ++++++++++++++++++--------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 758e718e8..6b1166b94 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1202,27 +1202,36 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str assert r is None assert stream is None - r = torch.empty_like(weight, dtype=weight._model_dtype, device=device) + cast_geometry = comfy.memory_management.tensors_to_geometries([ weight ]) + + if dtype is None: + dtype = weight._model_dtype + + r = torch.empty_like(weight, dtype=dtype, device=device) signature = comfy_aimdo.model_vbar.vbar_fault(weight._v) if signature is not None: raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device) - v_tensor = comfy.memory_management.interpret_gathered_like([r], raw_tensor)[0] - - if comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature): + v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, raw_tensor)[0] + if not comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature): + weight._v_signature = signature + #Send it over + v_tensor.copy_(weight, non_blocking=non_blocking) #always take a deep copy even if _v is good, as we have no reasonable point to unpin #a non comfy weight r.copy_(v_tensor) comfy_aimdo.model_vbar.vbar_unpin(weight._v) return r + if weight.dtype != r.dtype and weight.dtype != weight._model_dtype: + #Offloaded casting could skip this, however it would make the quantizations + #inconsistent between loaded and offloaded weights. So force the double casting + #that would happen in regular flow to make offload deterministic. + cast_buffer = torch.empty_like(weight, dtype=weight._model_dtype, device=device) + cast_buffer.copy_(weight, non_blocking=non_blocking) + weight = cast_buffer r.copy_(weight, non_blocking=non_blocking) - if signature is not None: - weight._v_signature = signature - v_tensor.copy_(r) - comfy_aimdo.model_vbar.vbar_unpin(weight._v) - return r if device is None or weight.device == device: From 2b5da3b72ebb8017661b0d58b548d4aeb53b7e42 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Sun, 1 Feb 2026 17:09:55 -0800 Subject: [PATCH 022/215] dynamic_vram: silence pytorch buffer warning (#12210) This is log clutter and concerning to users. Its a false alarm. --- comfy/utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/comfy/utils.py b/comfy/utils.py index 9e98eb176..c1b536833 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -32,6 +32,7 @@ from comfy.cli_args import args, enables_dynamic_vram import json import time import mmap +import warnings MMAP_TORCH_FILES = args.mmap_torch_files DISABLE_MMAP = args.disable_mmap @@ -85,7 +86,10 @@ def load_safetensors(ckpt): header_size = struct.unpack(" Date: Sun, 1 Feb 2026 17:10:15 -0800 Subject: [PATCH 023/215] requirements: bump comfy-aimdo to 0.1.7 (#12211) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index be0bc537e..3ca417dd8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,7 +22,7 @@ alembic SQLAlchemy av>=14.2.0 comfy-kitchen>=0.2.7 -comfy-aimdo>=0.1.6 +comfy-aimdo>=0.1.7 requests #non essential dependencies: From 021ba2071985f9e3b2984f396a238095c4a64832 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 1 Feb 2026 17:12:52 -0800 Subject: [PATCH 024/215] Fix issue with parameters on root model object. (#12216) --- comfy/model_patcher.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index c8e6f088f..b70c031bf 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -161,6 +161,11 @@ def get_key_weight(model, key): return weight, set_func, convert_func +def key_param_name_to_key(key, param): + if len(key) == 0: + return param + return "{}.{}".format(key, param) + class AutoPatcherEjector: def __init__(self, model: 'ModelPatcher', skip_and_inject_on_exit_only=False): self.model = model @@ -795,7 +800,7 @@ class ModelPatcher: continue for param in params: - key = "{}.{}".format(n, param) + key = key_param_name_to_key(n, param) self.unpin_weight(key) self.patch_weight_to_device(key, device_to=device_to) if comfy.model_management.is_device_cuda(device_to): @@ -811,7 +816,7 @@ class ModelPatcher: n = x[1] params = x[3] for param in params: - self.pin_weight_to_device("{}.{}".format(n, param)) + self.pin_weight_to_device(key_param_name_to_key(n, param)) usable_stat = "{:.2f} MB usable,".format(lowvram_model_memory / (1024 * 1024)) if lowvram_model_memory < 1e32 else "" if lowvram_counter > 0: @@ -917,7 +922,7 @@ class ModelPatcher: if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True: move_weight = True for param in params: - key = "{}.{}".format(n, param) + key = key_param_name_to_key(n, param) bk = self.backup.get(key, None) if bk is not None: if not lowvram_possible: @@ -968,7 +973,7 @@ class ModelPatcher: logging.debug("freed {}".format(n)) for param in params: - self.pin_weight_to_device("{}.{}".format(n, param)) + self.pin_weight_to_device(key_param_name_to_key(n, param)) self.model.model_lowvram = True @@ -1501,7 +1506,7 @@ class ModelPatcherDynamic(ModelPatcher): def setup_param(self, m, n, param_key): nonlocal num_patches - key = "{}.{}".format(n, param_key) + key = key_param_name_to_key(n, param_key) weight_function = [] @@ -1540,7 +1545,7 @@ class ModelPatcherDynamic(ModelPatcher): else: for param in params: - key = "{}.{}".format(n, param) + key = key_param_name_to_key(n, param) weight, _, _ = get_key_weight(self.model, key) weight.seed_key = key set_dirty(weight, dirty) From dd86b155210df9b34f479d70dad675aa782a30ef Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 2 Feb 2026 00:51:09 -0800 Subject: [PATCH 025/215] Enable embeddings for some qwen 3 models. (#12218) --- comfy/text_encoders/anima.py | 2 +- comfy/text_encoders/flux.py | 6 +++--- comfy/text_encoders/z_image.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/comfy/text_encoders/anima.py b/comfy/text_encoders/anima.py index 41f95bcb6..b6f58cb25 100644 --- a/comfy/text_encoders/anima.py +++ b/comfy/text_encoders/anima.py @@ -8,7 +8,7 @@ import torch class Qwen3Tokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer") - super().__init__(tokenizer_path, pad_with_end=False, embedding_size=1024, embedding_key='qwen3_06b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data) + super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1024, embedding_key='qwen3_06b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data) class T5XXLTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): diff --git a/comfy/text_encoders/flux.py b/comfy/text_encoders/flux.py index f67a5f805..1ae398789 100644 --- a/comfy/text_encoders/flux.py +++ b/comfy/text_encoders/flux.py @@ -118,7 +118,7 @@ class MistralTokenizerClass: class Mistral3Tokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): self.tekken_data = tokenizer_data.get("tekken_model", None) - super().__init__("", pad_with_end=False, embedding_size=5120, embedding_key='mistral3_24b', tokenizer_class=MistralTokenizerClass, has_end_token=False, pad_to_max_length=False, pad_token=11, start_token=1, max_length=99999999, min_length=1, pad_left=True, tokenizer_args=load_mistral_tokenizer(self.tekken_data), tokenizer_data=tokenizer_data) + super().__init__("", pad_with_end=False, embedding_directory=embedding_directory, embedding_size=5120, embedding_key='mistral3_24b', tokenizer_class=MistralTokenizerClass, has_end_token=False, pad_to_max_length=False, pad_token=11, start_token=1, max_length=99999999, min_length=1, pad_left=True, tokenizer_args=load_mistral_tokenizer(self.tekken_data), tokenizer_data=tokenizer_data) def state_dict(self): return {"tekken_model": self.tekken_data} @@ -176,12 +176,12 @@ def flux2_te(dtype_llama=None, llama_quantization_metadata=None, pruned=False): class Qwen3Tokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer") - super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2560, embedding_key='qwen3_4b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=151643, tokenizer_data=tokenizer_data) + super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=2560, embedding_key='qwen3_4b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=151643, tokenizer_data=tokenizer_data) class Qwen3Tokenizer8B(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer") - super().__init__(tokenizer_path, pad_with_end=False, embedding_size=4096, embedding_key='qwen3_8b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=151643, tokenizer_data=tokenizer_data) + super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=4096, embedding_key='qwen3_8b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=151643, tokenizer_data=tokenizer_data) class KleinTokenizer(sd1_clip.SD1Tokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}, name="qwen3_4b"): diff --git a/comfy/text_encoders/z_image.py b/comfy/text_encoders/z_image.py index ad41bfb1e..33b7cf594 100644 --- a/comfy/text_encoders/z_image.py +++ b/comfy/text_encoders/z_image.py @@ -6,7 +6,7 @@ import os class Qwen3Tokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer") - super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2560, embedding_key='qwen3_4b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data) + super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=2560, embedding_key='qwen3_4b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data) class ZImageTokenizer(sd1_clip.SD1Tokenizer): From 37f711d4a1d429e6b390b01729510525155385e1 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Mon, 2 Feb 2026 14:34:46 -0800 Subject: [PATCH 026/215] mm: Fix cast buffers with intel offloading (#12229) Intel has offloading support but there were some nvidia calls in the new cast buffer stuff. --- comfy/model_management.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 6b1166b94..2167f81bf 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1112,11 +1112,11 @@ def get_cast_buffer(offload_stream, device, size, ref): return None if cast_buffer is not None and cast_buffer.numel() > 50 * (1024 ** 2): #I want my wrongly sized 50MB+ of VRAM back from the caching allocator right now - torch.cuda.synchronize() + synchronize() del STREAM_CAST_BUFFERS[offload_stream] del cast_buffer #FIXME: This doesn't work in Aimdo because mempool cant clear cache - torch.cuda.empty_cache() + soft_empty_cache() with wf_context: cast_buffer = torch.empty((size), dtype=torch.int8, device=device) STREAM_CAST_BUFFERS[offload_stream] = cast_buffer @@ -1132,9 +1132,7 @@ def reset_cast_buffers(): for offload_stream in STREAM_CAST_BUFFERS: offload_stream.synchronize() STREAM_CAST_BUFFERS.clear() - if comfy.memory_management.aimdo_allocator is None: - #Pytorch 2.7 and earlier crashes if you try and empty_cache when mempools exist - torch.cuda.empty_cache() + soft_empty_cache() def get_offload_stream(device): stream_counter = stream_counters.get(device, 0) @@ -1284,7 +1282,7 @@ def discard_cuda_async_error(): a = torch.tensor([1], dtype=torch.uint8, device=get_torch_device()) b = torch.tensor([1], dtype=torch.uint8, device=get_torch_device()) _ = a + b - torch.cuda.synchronize() + synchronize() except torch.AcceleratorError: #Dump it! We already know about it from the synchronous return pass @@ -1688,6 +1686,12 @@ def lora_compute_dtype(device): LORA_COMPUTE_DTYPES[device] = dtype return dtype +def synchronize(): + if is_intel_xpu(): + torch.xpu.synchronize() + elif torch.cuda.is_available(): + torch.cuda.synchronize() + def soft_empty_cache(force=False): global cpu_state if cpu_state == CPUState.MPS: From de9ada6a4147f4626f903cd975a5d2c134af3915 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Mon, 2 Feb 2026 14:35:20 -0800 Subject: [PATCH 027/215] Dynamic VRAM unloading fix (#12227) * mp: fix full dynamic unloading This was not unloading dynamic models when requesting a full unload via the unpatch() code path. This was ok, i your workflow was all dynamic models but fails with big VRAM leaks if you need to fully unload something for a regular ModelPatcher It also fices the "unload models" button. * mm: load models outside of Aimdo Mempool In dynamic_vram mode, escape the Aimdo mempool and load into the regular mempool. Use a dummy thread to do it. --- comfy/model_management.py | 29 ++++++++++++++++++++++------- comfy/model_patcher.py | 2 +- 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 2167f81bf..cd035f017 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -19,7 +19,8 @@ import psutil import logging from enum import Enum -from comfy.cli_args import args, PerformanceFeature +from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram +import threading import torch import sys import platform @@ -650,7 +651,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_ soft_empty_cache() return unloaded_models -def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False): +def load_models_gpu_orig(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False): cleanup_models_gc() global vram_state @@ -746,8 +747,25 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu current_loaded_models.insert(0, loaded_model) return -def load_model_gpu(model): - return load_models_gpu([model]) +def load_models_gpu_thread(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load): + with torch.inference_mode(): + load_models_gpu_orig(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load) + soft_empty_cache() + +def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False): + #Deliberately load models outside of the Aimdo mempool so they can be retained accross + #nodes. Use a dummy thread to do it as pytorch documents that mempool contexts are + #thread local. So exploit that to escape context + if enables_dynamic_vram(): + t = threading.Thread( + target=load_models_gpu_thread, + args=(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load) + ) + t.start() + t.join() + else: + load_models_gpu_orig(models, memory_required=memory_required, force_patch_weights=force_patch_weights, + minimum_memory_required=minimum_memory_required, force_full_load=force_full_load) def loaded_models(only_currently_used=False): output = [] @@ -1717,9 +1735,6 @@ def debug_memory_summary(): return torch.cuda.memory.memory_summary() return "" -#TODO: might be cleaner to put this somewhere else -import threading - class InterruptProcessingException(Exception): pass diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index b70c031bf..cdf289395 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -1597,7 +1597,7 @@ class ModelPatcherDynamic(ModelPatcher): if unpatch_weights: self.partially_unload_ram(1e32) - self.partially_unload(None) + self.partially_unload(None, 1e32) def partially_load(self, device_to, extra_memory=0, force_patch_weights=False): assert not force_patch_weights #See above From c05a08ae66d8114c30f8d3606b9de3117cb61ef8 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 2 Feb 2026 16:52:07 -0800 Subject: [PATCH 028/215] Add back function. (#12234) --- comfy/model_management.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/comfy/model_management.py b/comfy/model_management.py index cd035f017..72348258b 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -767,6 +767,9 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu load_models_gpu_orig(models, memory_required=memory_required, force_patch_weights=force_patch_weights, minimum_memory_required=minimum_memory_required, force_full_load=force_full_load) +def load_model_gpu(model): + return load_models_gpu([model]) + def loaded_models(only_currently_used=False): output = [] for m in current_loaded_models: From ba5bf3f1a801c6d4c24f015215314b6a653f8752 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Tue, 3 Feb 2026 05:17:59 +0200 Subject: [PATCH 029/215] [API Nodes] HitPaw API nodes (#12117) * feat(api-nodes): add HitPaw API nodes * remove face_soft_2x model as not working --------- Co-authored-by: Robin Huang --- comfy_api_nodes/apis/hitpaw.py | 51 ++++ comfy_api_nodes/nodes_hitpaw.py | 342 +++++++++++++++++++++++++ comfy_api_nodes/util/upload_helpers.py | 2 +- 3 files changed, 394 insertions(+), 1 deletion(-) create mode 100644 comfy_api_nodes/apis/hitpaw.py create mode 100644 comfy_api_nodes/nodes_hitpaw.py diff --git a/comfy_api_nodes/apis/hitpaw.py b/comfy_api_nodes/apis/hitpaw.py new file mode 100644 index 000000000..b23c5d9eb --- /dev/null +++ b/comfy_api_nodes/apis/hitpaw.py @@ -0,0 +1,51 @@ +from typing import TypedDict + +from pydantic import BaseModel, Field + + +class InputVideoModel(TypedDict): + model: str + resolution: str + + +class ImageEnhanceTaskCreateRequest(BaseModel): + model_name: str = Field(...) + img_url: str = Field(...) + extension: str = Field(".png") + exif: bool = Field(False) + DPI: int | None = Field(None) + + +class VideoEnhanceTaskCreateRequest(BaseModel): + video_url: str = Field(...) + extension: str = Field(".mp4") + model_name: str | None = Field(...) + resolution: list[int] = Field(..., description="Target resolution [width, height]") + original_resolution: list[int] = Field(..., description="Original video resolution [width, height]") + + +class TaskCreateDataResponse(BaseModel): + job_id: str = Field(...) + consume_coins: int | None = Field(None) + + +class TaskStatusPollRequest(BaseModel): + job_id: str = Field(...) + + +class TaskCreateResponse(BaseModel): + code: int = Field(...) + message: str = Field(...) + data: TaskCreateDataResponse | None = Field(None) + + +class TaskStatusDataResponse(BaseModel): + job_id: str = Field(...) + status: str = Field(...) + res_url: str = Field("") + + +class TaskStatusResponse(BaseModel): + code: int = Field(...) + message: str = Field(...) + data: TaskStatusDataResponse = Field(...) diff --git a/comfy_api_nodes/nodes_hitpaw.py b/comfy_api_nodes/nodes_hitpaw.py new file mode 100644 index 000000000..488080a74 --- /dev/null +++ b/comfy_api_nodes/nodes_hitpaw.py @@ -0,0 +1,342 @@ +import math + +from typing_extensions import override + +from comfy_api.latest import IO, ComfyExtension, Input +from comfy_api_nodes.apis.hitpaw import ( + ImageEnhanceTaskCreateRequest, + InputVideoModel, + TaskCreateDataResponse, + TaskCreateResponse, + TaskStatusPollRequest, + TaskStatusResponse, + VideoEnhanceTaskCreateRequest, +) +from comfy_api_nodes.util import ( + ApiEndpoint, + download_url_to_image_tensor, + download_url_to_video_output, + downscale_image_tensor, + get_image_dimensions, + poll_op, + sync_op, + upload_image_to_comfyapi, + upload_video_to_comfyapi, + validate_video_duration, +) + +VIDEO_MODELS_MODELS_MAP = { + "Portrait Restore Model (1x)": "portrait_restore_1x", + "Portrait Restore Model (2x)": "portrait_restore_2x", + "General Restore Model (1x)": "general_restore_1x", + "General Restore Model (2x)": "general_restore_2x", + "General Restore Model (4x)": "general_restore_4x", + "Ultra HD Model (2x)": "ultrahd_restore_2x", + "Generative Model (1x)": "generative_1x", +} + +# Resolution name to target dimension (shorter side) in pixels +RESOLUTION_TARGET_MAP = { + "720p": 720, + "1080p": 1080, + "2K/QHD": 1440, + "4K/UHD": 2160, + "8K": 4320, +} + +# Square (1:1) resolutions use standard square dimensions +RESOLUTION_SQUARE_MAP = { + "720p": 720, + "1080p": 1080, + "2K/QHD": 1440, + "4K/UHD": 2048, # DCI 4K square + "8K": 4096, # DCI 8K square +} + +# Models with limited resolution support (no 8K) +LIMITED_RESOLUTION_MODELS = {"Generative Model (1x)"} + +# Resolution options for different model types +RESOLUTIONS_LIMITED = ["original", "720p", "1080p", "2K/QHD", "4K/UHD"] +RESOLUTIONS_FULL = ["original", "720p", "1080p", "2K/QHD", "4K/UHD", "8K"] + +# Maximum output resolution in pixels +MAX_PIXELS_GENERATIVE = 32_000_000 +MAX_MP_GENERATIVE = MAX_PIXELS_GENERATIVE // 1_000_000 + + +class HitPawGeneralImageEnhance(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="HitPawGeneralImageEnhance", + display_name="HitPaw General Image Enhance", + category="api node/image/HitPaw", + description="Upscale low-resolution images to super-resolution, eliminate artifacts and noise. " + f"Maximum output: {MAX_MP_GENERATIVE} megapixels.", + inputs=[ + IO.Combo.Input("model", options=["generative_portrait", "generative"]), + IO.Image.Input("image"), + IO.Combo.Input("upscale_factor", options=[1, 2, 4]), + IO.Boolean.Input( + "auto_downscale", + default=False, + tooltip="Automatically downscale input image if output would exceed the limit.", + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model"]), + expr=""" + ( + $prices := { + "generative_portrait": {"min": 0.02, "max": 0.06}, + "generative": {"min": 0.05, "max": 0.15} + }; + $price := $lookup($prices, widgets.model); + { + "type": "range_usd", + "min_usd": $price.min, + "max_usd": $price.max + } + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + model: str, + image: Input.Image, + upscale_factor: int, + auto_downscale: bool, + ) -> IO.NodeOutput: + height, width = get_image_dimensions(image) + requested_scale = upscale_factor + output_pixels = height * width * requested_scale * requested_scale + if output_pixels > MAX_PIXELS_GENERATIVE: + if auto_downscale: + input_pixels = width * height + scale = 1 + max_input_pixels = MAX_PIXELS_GENERATIVE + + for candidate in [4, 2, 1]: + if candidate > requested_scale: + continue + scale_output_pixels = input_pixels * candidate * candidate + if scale_output_pixels <= MAX_PIXELS_GENERATIVE: + scale = candidate + max_input_pixels = None + break + # Check if we can downscale input by at most 2x to fit + downscale_ratio = math.sqrt(scale_output_pixels / MAX_PIXELS_GENERATIVE) + if downscale_ratio <= 2.0: + scale = candidate + max_input_pixels = MAX_PIXELS_GENERATIVE // (candidate * candidate) + break + + if max_input_pixels is not None: + image = downscale_image_tensor(image, total_pixels=max_input_pixels) + upscale_factor = scale + else: + output_width = width * requested_scale + output_height = height * requested_scale + raise ValueError( + f"Output size ({output_width}x{output_height} = {output_pixels:,} pixels) " + f"exceeds maximum allowed size of {MAX_PIXELS_GENERATIVE:,} pixels ({MAX_MP_GENERATIVE}MP). " + f"Enable auto_downscale or use a smaller input image or a lower upscale factor." + ) + + initial_res = await sync_op( + cls, + ApiEndpoint(path="/proxy/hitpaw/api/photo-enhancer", method="POST"), + response_model=TaskCreateResponse, + data=ImageEnhanceTaskCreateRequest( + model_name=f"{model}_{upscale_factor}x", + img_url=await upload_image_to_comfyapi(cls, image, total_pixels=None), + ), + wait_label="Creating task", + final_label_on_success="Task created", + ) + if initial_res.code != 200: + raise ValueError(f"Task creation failed with code {initial_res.code}: {initial_res.message}") + request_price = initial_res.data.consume_coins / 1000 + final_response = await poll_op( + cls, + ApiEndpoint(path="/proxy/hitpaw/api/task-status", method="POST"), + data=TaskCreateDataResponse(job_id=initial_res.data.job_id), + response_model=TaskStatusResponse, + status_extractor=lambda x: x.data.status, + price_extractor=lambda x: request_price, + poll_interval=10.0, + max_poll_attempts=480, + ) + return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.res_url)) + + +class HitPawVideoEnhance(IO.ComfyNode): + @classmethod + def define_schema(cls): + model_options = [] + for model_name in VIDEO_MODELS_MODELS_MAP: + if model_name in LIMITED_RESOLUTION_MODELS: + resolutions = RESOLUTIONS_LIMITED + else: + resolutions = RESOLUTIONS_FULL + model_options.append( + IO.DynamicCombo.Option( + model_name, + [IO.Combo.Input("resolution", options=resolutions)], + ) + ) + + return IO.Schema( + node_id="HitPawVideoEnhance", + display_name="HitPaw Video Enhance", + category="api node/video/HitPaw", + description="Upscale low-resolution videos to high resolution, eliminate artifacts and noise. " + "Prices shown are per second of video.", + inputs=[ + IO.DynamicCombo.Input("model", options=model_options), + IO.Video.Input("video"), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model", "model.resolution"]), + expr=""" + ( + $m := $lookup(widgets, "model"); + $res := $lookup(widgets, "model.resolution"); + $standard_model_prices := { + "original": {"min": 0.01, "max": 0.198}, + "720p": {"min": 0.01, "max": 0.06}, + "1080p": {"min": 0.015, "max": 0.09}, + "2k/qhd": {"min": 0.02, "max": 0.117}, + "4k/uhd": {"min": 0.025, "max": 0.152}, + "8k": {"min": 0.033, "max": 0.198} + }; + $ultra_hd_model_prices := { + "original": {"min": 0.015, "max": 0.264}, + "720p": {"min": 0.015, "max": 0.092}, + "1080p": {"min": 0.02, "max": 0.12}, + "2k/qhd": {"min": 0.026, "max": 0.156}, + "4k/uhd": {"min": 0.034, "max": 0.203}, + "8k": {"min": 0.044, "max": 0.264} + }; + $generative_model_prices := { + "original": {"min": 0.015, "max": 0.338}, + "720p": {"min": 0.008, "max": 0.090}, + "1080p": {"min": 0.05, "max": 0.15}, + "2k/qhd": {"min": 0.038, "max": 0.225}, + "4k/uhd": {"min": 0.056, "max": 0.338} + }; + $prices := $contains($m, "ultra hd") ? $ultra_hd_model_prices : + $contains($m, "generative") ? $generative_model_prices : + $standard_model_prices; + $price := $lookup($prices, $res); + { + "type": "range_usd", + "min_usd": $price.min, + "max_usd": $price.max, + "format": {"approximate": true, "suffix": "/second"} + } + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + model: InputVideoModel, + video: Input.Video, + ) -> IO.NodeOutput: + validate_video_duration(video, min_duration=0.5, max_duration=60 * 60) + resolution = model["resolution"] + src_width, src_height = video.get_dimensions() + + if resolution == "original": + output_width = src_width + output_height = src_height + else: + if src_width == src_height: + target_size = RESOLUTION_SQUARE_MAP[resolution] + if target_size < src_width: + raise ValueError( + f"Selected resolution {resolution} ({target_size}x{target_size}) is smaller than " + f"the input video ({src_width}x{src_height}). Please select a higher resolution or 'original'." + ) + output_width = target_size + output_height = target_size + else: + min_dimension = min(src_width, src_height) + target_size = RESOLUTION_TARGET_MAP[resolution] + if target_size < min_dimension: + raise ValueError( + f"Selected resolution {resolution} ({target_size}p) is smaller than " + f"the input video's shorter dimension ({min_dimension}p). " + f"Please select a higher resolution or 'original'." + ) + if src_width > src_height: + output_height = target_size + output_width = int(target_size * (src_width / src_height)) + else: + output_width = target_size + output_height = int(target_size * (src_height / src_width)) + initial_res = await sync_op( + cls, + ApiEndpoint(path="/proxy/hitpaw/api/video-enhancer", method="POST"), + response_model=TaskCreateResponse, + data=VideoEnhanceTaskCreateRequest( + video_url=await upload_video_to_comfyapi(cls, video), + resolution=[output_width, output_height], + original_resolution=[src_width, src_height], + model_name=VIDEO_MODELS_MODELS_MAP[model["model"]], + ), + wait_label="Creating task", + final_label_on_success="Task created", + ) + request_price = initial_res.data.consume_coins / 1000 + if initial_res.code != 200: + raise ValueError(f"Task creation failed with code {initial_res.code}: {initial_res.message}") + final_response = await poll_op( + cls, + ApiEndpoint(path="/proxy/hitpaw/api/task-status", method="POST"), + data=TaskStatusPollRequest(job_id=initial_res.data.job_id), + response_model=TaskStatusResponse, + status_extractor=lambda x: x.data.status, + price_extractor=lambda x: request_price, + poll_interval=10.0, + max_poll_attempts=320, + ) + return IO.NodeOutput(await download_url_to_video_output(final_response.data.res_url)) + + +class HitPawExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + HitPawGeneralImageEnhance, + HitPawVideoEnhance, + ] + + +async def comfy_entrypoint() -> HitPawExtension: + return HitPawExtension() diff --git a/comfy_api_nodes/util/upload_helpers.py b/comfy_api_nodes/util/upload_helpers.py index 3153f2b98..83d936ce1 100644 --- a/comfy_api_nodes/util/upload_helpers.py +++ b/comfy_api_nodes/util/upload_helpers.py @@ -94,7 +94,7 @@ async def upload_image_to_comfyapi( *, mime_type: str | None = None, wait_label: str | None = "Uploading", - total_pixels: int = 2048 * 2048, + total_pixels: int | None = 2048 * 2048, ) -> str: """Uploads a single image to ComfyUI API and returns its download URL.""" return ( From 3c1a1a2df82fc6c7aa20a3c8301d1c632e1a1d87 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 2 Feb 2026 21:06:18 -0800 Subject: [PATCH 030/215] Basic support for the ace step 1.5 model. (#12237) --- comfy/latent_formats.py | 4 + comfy/ldm/ace/ace_step15.py | 1093 ++++++++++++++++++++++++++++++++++ comfy/model_base.py | 42 ++ comfy/model_detection.py | 5 + comfy/sd.py | 21 +- comfy/sd1_clip.py | 2 + comfy/supported_models.py | 35 +- comfy/text_encoders/ace15.py | 218 +++++++ comfy/text_encoders/llama.py | 67 +++ comfy_extras/nodes_ace.py | 75 +++ comfy_extras/nodes_audio.py | 10 +- nodes.py | 2 +- 12 files changed, 1566 insertions(+), 8 deletions(-) create mode 100644 comfy/ldm/ace/ace_step15.py create mode 100644 comfy/text_encoders/ace15.py diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 4b3a3798c..f59999af6 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -755,6 +755,10 @@ class ACEAudio(LatentFormat): latent_channels = 8 latent_dimensions = 2 +class ACEAudio15(LatentFormat): + latent_channels = 64 + latent_dimensions = 1 + class ChromaRadiance(LatentFormat): latent_channels = 3 spacial_downscale_ratio = 1 diff --git a/comfy/ldm/ace/ace_step15.py b/comfy/ldm/ace/ace_step15.py new file mode 100644 index 000000000..d90549658 --- /dev/null +++ b/comfy/ldm/ace/ace_step15.py @@ -0,0 +1,1093 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import itertools +from comfy.ldm.modules.attention import optimized_attention +import comfy.model_management +from comfy.ldm.flux.layers import timestep_embedding + +def get_layer_class(operations, layer_name): + if operations is not None and hasattr(operations, layer_name): + return getattr(operations, layer_name) + return getattr(nn, layer_name) + +class RotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=32768, base=1000000.0, dtype=None, device=None, operations=None): + super().__init__() + self.dim = dim + self.base = base + self.max_position_embeddings = max_position_embeddings + + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.float32, device=device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._set_cos_sin_cache(max_position_embeddings, device=device, dtype=torch.get_default_dtype() if dtype is None else dtype) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.float32) + freqs = torch.outer(t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len, x.device, x.dtype) + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype, device=x.device), + self.sin_cached[:seq_len].to(dtype=x.dtype, device=x.device), + ) + +def rotate_half(x): + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + +def apply_rotary_pos_emb(q, k, cos, sin): + cos = cos.unsqueeze(0).unsqueeze(0) + sin = sin.unsqueeze(0).unsqueeze(0) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + +class MLP(nn.Module): + def __init__(self, hidden_size, intermediate_size, dtype=None, device=None, operations=None): + super().__init__() + Linear = get_layer_class(operations, "Linear") + self.gate_proj = Linear(hidden_size, intermediate_size, bias=False, dtype=dtype, device=device) + self.up_proj = Linear(hidden_size, intermediate_size, bias=False, dtype=dtype, device=device) + self.down_proj = Linear(intermediate_size, hidden_size, bias=False, dtype=dtype, device=device) + self.act_fn = nn.SiLU() + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + +class TimestepEmbedding(nn.Module): + def __init__(self, in_channels: int, time_embed_dim: int, scale: float = 1000, dtype=None, device=None, operations=None): + super().__init__() + Linear = get_layer_class(operations, "Linear") + self.linear_1 = Linear(in_channels, time_embed_dim, bias=True, dtype=dtype, device=device) + self.act1 = nn.SiLU() + self.linear_2 = Linear(time_embed_dim, time_embed_dim, bias=True, dtype=dtype, device=device) + self.in_channels = in_channels + self.act2 = nn.SiLU() + self.time_proj = Linear(time_embed_dim, time_embed_dim * 6, dtype=dtype, device=device) + self.scale = scale + + def forward(self, t, dtype=None): + t_freq = timestep_embedding(t, self.in_channels, time_factor=self.scale) + temb = self.linear_1(t_freq.to(dtype=dtype)) + temb = self.act1(temb) + temb = self.linear_2(temb) + timestep_proj = self.time_proj(self.act2(temb)).view(-1, 6, temb.shape[-1]) + return temb, timestep_proj + +class AceStepAttention(nn.Module): + def __init__( + self, + hidden_size, + num_heads, + num_kv_heads, + head_dim, + rms_norm_eps=1e-6, + is_cross_attention=False, + sliding_window=None, + dtype=None, + device=None, + operations=None + ): + super().__init__() + self.hidden_size = hidden_size + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = head_dim + self.is_cross_attention = is_cross_attention + self.sliding_window = sliding_window + + Linear = get_layer_class(operations, "Linear") + + self.q_proj = Linear(hidden_size, num_heads * head_dim, bias=False, dtype=dtype, device=device) + self.k_proj = Linear(hidden_size, num_kv_heads * head_dim, bias=False, dtype=dtype, device=device) + self.v_proj = Linear(hidden_size, num_kv_heads * head_dim, bias=False, dtype=dtype, device=device) + self.o_proj = Linear(num_heads * head_dim, hidden_size, bias=False, dtype=dtype, device=device) + + self.q_norm = operations.RMSNorm(head_dim, eps=rms_norm_eps, dtype=dtype, device=device) + self.k_norm = operations.RMSNorm(head_dim, eps=rms_norm_eps, dtype=dtype, device=device) + + def forward( + self, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + position_embeddings=None, + ): + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) + query_states = self.q_norm(query_states) + query_states = query_states.transpose(1, 2) + + if self.is_cross_attention and encoder_hidden_states is not None: + bsz_enc, kv_len, _ = encoder_hidden_states.size() + key_states = self.k_proj(encoder_hidden_states) + value_states = self.v_proj(encoder_hidden_states) + + key_states = key_states.view(bsz_enc, kv_len, self.num_kv_heads, self.head_dim) + key_states = self.k_norm(key_states) + value_states = value_states.view(bsz_enc, kv_len, self.num_kv_heads, self.head_dim) + + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + else: + kv_len = q_len + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + key_states = key_states.view(bsz, q_len, self.num_kv_heads, self.head_dim) + key_states = self.k_norm(key_states) + value_states = value_states.view(bsz, q_len, self.num_kv_heads, self.head_dim) + + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + if position_embeddings is not None: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + n_rep = self.num_heads // self.num_kv_heads + if n_rep > 1: + key_states = key_states.repeat_interleave(n_rep, dim=1) + value_states = value_states.repeat_interleave(n_rep, dim=1) + + attn_bias = None + if self.sliding_window is not None and not self.is_cross_attention: + indices = torch.arange(q_len, device=query_states.device) + diff = indices.unsqueeze(1) - indices.unsqueeze(0) + in_window = torch.abs(diff) <= self.sliding_window + + window_bias = torch.zeros((q_len, kv_len), device=query_states.device, dtype=query_states.dtype) + min_value = torch.finfo(query_states.dtype).min + window_bias.masked_fill_(~in_window, min_value) + + window_bias = window_bias.unsqueeze(0).unsqueeze(0) + + if attn_bias is not None: + if attn_bias.dtype == torch.bool: + base_bias = torch.zeros_like(window_bias) + base_bias.masked_fill_(~attn_bias, min_value) + attn_bias = base_bias + window_bias + else: + attn_bias = attn_bias + window_bias + else: + attn_bias = window_bias + + attn_output = optimized_attention(query_states, key_states, value_states, self.num_heads, attn_bias, skip_reshape=True) + attn_output = self.o_proj(attn_output) + + return attn_output + +class AceStepDiTLayer(nn.Module): + def __init__( + self, + hidden_size, + num_heads, + num_kv_heads, + head_dim, + intermediate_size, + rms_norm_eps=1e-6, + layer_type="full_attention", + sliding_window=128, + dtype=None, + device=None, + operations=None + ): + super().__init__() + + self_attn_window = sliding_window if layer_type == "sliding_attention" else None + + self.self_attn_norm = operations.RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device) + self.self_attn = AceStepAttention( + hidden_size, num_heads, num_kv_heads, head_dim, rms_norm_eps, + is_cross_attention=False, sliding_window=self_attn_window, + dtype=dtype, device=device, operations=operations + ) + + self.cross_attn_norm = operations.RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device) + self.cross_attn = AceStepAttention( + hidden_size, num_heads, num_kv_heads, head_dim, rms_norm_eps, + is_cross_attention=True, dtype=dtype, device=device, operations=operations + ) + + self.mlp_norm = operations.RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device) + self.mlp = MLP(hidden_size, intermediate_size, dtype=dtype, device=device, operations=operations) + + self.scale_shift_table = nn.Parameter(torch.empty(1, 6, hidden_size, dtype=dtype, device=device)) + + def forward( + self, + hidden_states, + temb, + encoder_hidden_states, + position_embeddings, + attention_mask=None, + encoder_attention_mask=None + ): + modulation = comfy.model_management.cast_to(self.scale_shift_table, dtype=temb.dtype, device=temb.device) + temb + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = modulation.chunk(6, dim=1) + + norm_hidden = self.self_attn_norm(hidden_states) + norm_hidden = norm_hidden * (1 + scale_msa) + shift_msa + + attn_out = self.self_attn( + norm_hidden, + position_embeddings=position_embeddings, + attention_mask=attention_mask + ) + hidden_states = hidden_states + attn_out * gate_msa + + norm_hidden = self.cross_attn_norm(hidden_states) + attn_out = self.cross_attn( + norm_hidden, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask + ) + hidden_states = hidden_states + attn_out + + norm_hidden = self.mlp_norm(hidden_states) + norm_hidden = norm_hidden * (1 + c_scale_msa) + c_shift_msa + + mlp_out = self.mlp(norm_hidden) + hidden_states = hidden_states + mlp_out * c_gate_msa + + return hidden_states + +class AceStepEncoderLayer(nn.Module): + def __init__( + self, + hidden_size, + num_heads, + num_kv_heads, + head_dim, + intermediate_size, + rms_norm_eps=1e-6, + dtype=None, + device=None, + operations=None + ): + super().__init__() + self.self_attn = AceStepAttention( + hidden_size, num_heads, num_kv_heads, head_dim, rms_norm_eps, + is_cross_attention=False, dtype=dtype, device=device, operations=operations + ) + self.input_layernorm = operations.RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device) + self.post_attention_layernorm = operations.RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device) + self.mlp = MLP(hidden_size, intermediate_size, dtype=dtype, device=device, operations=operations) + + def forward(self, hidden_states, position_embeddings, attention_mask=None): + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + +class AceStepLyricEncoder(nn.Module): + def __init__( + self, + text_hidden_dim, + hidden_size, + num_layers, + num_heads, + num_kv_heads, + head_dim, + intermediate_size, + rms_norm_eps=1e-6, + dtype=None, + device=None, + operations=None + ): + super().__init__() + Linear = get_layer_class(operations, "Linear") + self.embed_tokens = Linear(text_hidden_dim, hidden_size, dtype=dtype, device=device) + self.norm = operations.RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device) + + self.rotary_emb = RotaryEmbedding( + head_dim, + base=1000000.0, + dtype=dtype, + device=device, + operations=operations + ) + + self.layers = nn.ModuleList([ + AceStepEncoderLayer( + hidden_size, num_heads, num_kv_heads, head_dim, intermediate_size, rms_norm_eps, + dtype=dtype, device=device, operations=operations + ) + for _ in range(num_layers) + ]) + + def forward(self, inputs_embeds, attention_mask=None): + hidden_states = self.embed_tokens(inputs_embeds) + seq_len = hidden_states.shape[1] + cos, sin = self.rotary_emb(hidden_states, seq_len=seq_len) + position_embeddings = (cos, sin) + + for layer in self.layers: + hidden_states = layer( + hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask + ) + + hidden_states = self.norm(hidden_states) + return hidden_states + +class AceStepTimbreEncoder(nn.Module): + def __init__( + self, + timbre_hidden_dim, + hidden_size, + num_layers, + num_heads, + num_kv_heads, + head_dim, + intermediate_size, + rms_norm_eps=1e-6, + dtype=None, + device=None, + operations=None + ): + super().__init__() + Linear = get_layer_class(operations, "Linear") + self.embed_tokens = Linear(timbre_hidden_dim, hidden_size, dtype=dtype, device=device) + self.norm = operations.RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device) + + self.rotary_emb = RotaryEmbedding( + head_dim, + base=1000000.0, + dtype=dtype, + device=device, + operations=operations + ) + + self.layers = nn.ModuleList([ + AceStepEncoderLayer( + hidden_size, num_heads, num_kv_heads, head_dim, intermediate_size, rms_norm_eps, + dtype=dtype, device=device, operations=operations + ) + for _ in range(num_layers) + ]) + self.special_token = nn.Parameter(torch.empty(1, 1, hidden_size, device=device, dtype=dtype)) + + def unpack_timbre_embeddings(self, timbre_embs_packed, refer_audio_order_mask): + N, d = timbre_embs_packed.shape + device = timbre_embs_packed.device + B = N + counts = torch.bincount(refer_audio_order_mask, minlength=B) + max_count = counts.max().item() + + sorted_indices = torch.argsort( + refer_audio_order_mask * N + torch.arange(N, device=device), + stable=True + ) + sorted_batch_ids = refer_audio_order_mask[sorted_indices] + + positions = torch.arange(N, device=device) + batch_starts = torch.cat([torch.tensor([0], device=device), torch.cumsum(counts, dim=0)[:-1]]) + positions_in_sorted = positions - batch_starts[sorted_batch_ids] + + inverse_indices = torch.empty_like(sorted_indices) + inverse_indices[sorted_indices] = torch.arange(N, device=device) + positions_in_batch = positions_in_sorted[inverse_indices] + + indices_2d = refer_audio_order_mask * max_count + positions_in_batch + one_hot = F.one_hot(indices_2d, num_classes=B * max_count).to(timbre_embs_packed.dtype) + + timbre_embs_flat = one_hot.t() @ timbre_embs_packed + timbre_embs_unpack = timbre_embs_flat.view(B, max_count, d) + + mask_flat = (one_hot.sum(dim=0) > 0).long() + new_mask = mask_flat.view(B, max_count) + return timbre_embs_unpack, new_mask + + def forward(self, refer_audio_acoustic_hidden_states_packed, refer_audio_order_mask, attention_mask=None): + hidden_states = self.embed_tokens(refer_audio_acoustic_hidden_states_packed) + if hidden_states.dim() == 2: + hidden_states = hidden_states.unsqueeze(0) + + seq_len = hidden_states.shape[1] + cos, sin = self.rotary_emb(hidden_states, seq_len=seq_len) + + for layer in self.layers: + hidden_states = layer( + hidden_states, + position_embeddings=(cos, sin), + attention_mask=attention_mask + ) + hidden_states = self.norm(hidden_states) + + flat_states = hidden_states[:, 0, :] + unpacked_embs, unpacked_mask = self.unpack_timbre_embeddings(flat_states, refer_audio_order_mask) + return unpacked_embs, unpacked_mask + + +def pack_sequences(hidden1, hidden2, mask1, mask2): + hidden_cat = torch.cat([hidden1, hidden2], dim=1) + B, L, D = hidden_cat.shape + + if mask1 is not None and mask2 is not None: + mask_cat = torch.cat([mask1, mask2], dim=1) + sort_idx = mask_cat.argsort(dim=1, descending=True, stable=True) + gather_idx = sort_idx.unsqueeze(-1).expand(B, L, D) + hidden_sorted = torch.gather(hidden_cat, 1, gather_idx) + lengths = mask_cat.sum(dim=1) + new_mask = (torch.arange(L, device=hidden_cat.device).unsqueeze(0) < lengths.unsqueeze(1)) + else: + new_mask = None + hidden_sorted = hidden_cat + + return hidden_sorted, new_mask + +class AceStepConditionEncoder(nn.Module): + def __init__( + self, + text_hidden_dim, + timbre_hidden_dim, + hidden_size, + num_lyric_layers, + num_timbre_layers, + num_heads, + num_kv_heads, + head_dim, + intermediate_size, + rms_norm_eps=1e-6, + dtype=None, + device=None, + operations=None + ): + super().__init__() + Linear = get_layer_class(operations, "Linear") + self.text_projector = Linear(text_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device) + + self.lyric_encoder = AceStepLyricEncoder( + text_hidden_dim=text_hidden_dim, + hidden_size=hidden_size, + num_layers=num_lyric_layers, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + intermediate_size=intermediate_size, + rms_norm_eps=rms_norm_eps, + dtype=dtype, + device=device, + operations=operations + ) + + self.timbre_encoder = AceStepTimbreEncoder( + timbre_hidden_dim=timbre_hidden_dim, + hidden_size=hidden_size, + num_layers=num_timbre_layers, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + intermediate_size=intermediate_size, + rms_norm_eps=rms_norm_eps, + dtype=dtype, + device=device, + operations=operations + ) + + def forward( + self, + text_hidden_states=None, + text_attention_mask=None, + lyric_hidden_states=None, + lyric_attention_mask=None, + refer_audio_acoustic_hidden_states_packed=None, + refer_audio_order_mask=None + ): + text_emb = self.text_projector(text_hidden_states) + + lyric_emb = self.lyric_encoder( + inputs_embeds=lyric_hidden_states, + attention_mask=lyric_attention_mask + ) + + timbre_emb, timbre_mask = self.timbre_encoder( + refer_audio_acoustic_hidden_states_packed, + refer_audio_order_mask + ) + + merged_emb, merged_mask = pack_sequences(lyric_emb, timbre_emb, lyric_attention_mask, timbre_mask) + final_emb, final_mask = pack_sequences(merged_emb, text_emb, merged_mask, text_attention_mask) + + return final_emb, final_mask + +# -------------------------------------------------------------------------------- +# Main Diffusion Model (DiT) +# -------------------------------------------------------------------------------- + +class AceStepDiTModel(nn.Module): + def __init__( + self, + in_channels, + hidden_size, + num_layers, + num_heads, + num_kv_heads, + head_dim, + intermediate_size, + patch_size, + audio_acoustic_hidden_dim, + layer_types=None, + sliding_window=128, + rms_norm_eps=1e-6, + dtype=None, + device=None, + operations=None + ): + super().__init__() + self.patch_size = patch_size + self.rotary_emb = RotaryEmbedding( + head_dim, + base=1000000.0, + dtype=dtype, + device=device, + operations=operations + ) + + Conv1d = get_layer_class(operations, "Conv1d") + ConvTranspose1d = get_layer_class(operations, "ConvTranspose1d") + Linear = get_layer_class(operations, "Linear") + + self.proj_in = nn.Sequential( + nn.Identity(), + Conv1d( + in_channels, hidden_size, kernel_size=patch_size, stride=patch_size, + dtype=dtype, device=device)) + + self.time_embed = TimestepEmbedding(256, hidden_size, dtype=dtype, device=device, operations=operations) + self.time_embed_r = TimestepEmbedding(256, hidden_size, dtype=dtype, device=device, operations=operations) + self.condition_embedder = Linear(hidden_size, hidden_size, dtype=dtype, device=device) + + if layer_types is None: + layer_types = ["full_attention"] * num_layers + + if len(layer_types) < num_layers: + layer_types = list(itertools.islice(itertools.cycle(layer_types), num_layers)) + + self.layers = nn.ModuleList([ + AceStepDiTLayer( + hidden_size, num_heads, num_kv_heads, head_dim, intermediate_size, rms_norm_eps, + layer_type=layer_types[i], + sliding_window=sliding_window, + dtype=dtype, device=device, operations=operations + ) for i in range(num_layers) + ]) + + self.norm_out = operations.RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device) + self.proj_out = nn.Sequential( + nn.Identity(), + ConvTranspose1d(hidden_size, audio_acoustic_hidden_dim, kernel_size=patch_size, stride=patch_size, dtype=dtype, device=device) + ) + + self.scale_shift_table = nn.Parameter(torch.empty(1, 2, hidden_size, dtype=dtype, device=device)) + + def forward( + self, + hidden_states, + timestep, + timestep_r, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + context_latents + ): + temb_t, proj_t = self.time_embed(timestep, dtype=hidden_states.dtype) + temb_r, proj_r = self.time_embed_r(timestep - timestep_r, dtype=hidden_states.dtype) + temb = temb_t + temb_r + timestep_proj = proj_t + proj_r + + x = torch.cat([context_latents, hidden_states], dim=-1) + original_seq_len = x.shape[1] + + pad_length = 0 + if x.shape[1] % self.patch_size != 0: + pad_length = self.patch_size - (x.shape[1] % self.patch_size) + x = F.pad(x, (0, 0, 0, pad_length), mode='constant', value=0) + + x = x.transpose(1, 2) + x = self.proj_in(x) + x = x.transpose(1, 2) + + encoder_hidden_states = self.condition_embedder(encoder_hidden_states) + + seq_len = x.shape[1] + cos, sin = self.rotary_emb(x, seq_len=seq_len) + + for layer in self.layers: + x = layer( + hidden_states=x, + temb=timestep_proj, + encoder_hidden_states=encoder_hidden_states, + position_embeddings=(cos, sin), + attention_mask=None, + encoder_attention_mask=None + ) + + shift, scale = (comfy.model_management.cast_to(self.scale_shift_table, dtype=temb.dtype, device=temb.device) + temb.unsqueeze(1)).chunk(2, dim=1) + x = self.norm_out(x) * (1 + scale) + shift + + x = x.transpose(1, 2) + x = self.proj_out(x) + x = x.transpose(1, 2) + + x = x[:, :original_seq_len, :] + return x + + +class AttentionPooler(nn.Module): + def __init__(self, hidden_size, num_layers, head_dim, rms_norm_eps, dtype=None, device=None, operations=None): + super().__init__() + Linear = get_layer_class(operations, "Linear") + self.embed_tokens = Linear(hidden_size, hidden_size, dtype=dtype, device=device) + self.norm = operations.RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device) + self.rotary_emb = RotaryEmbedding(head_dim, dtype=dtype, device=device, operations=operations) + self.special_token = nn.Parameter(torch.empty(1, 1, hidden_size, dtype=dtype, device=device)) + + self.layers = nn.ModuleList([ + AceStepEncoderLayer( + hidden_size, 16, 8, head_dim, hidden_size * 3, rms_norm_eps, + dtype=dtype, device=device, operations=operations + ) + for _ in range(num_layers) + ]) + + def forward(self, x): + B, T, P, D = x.shape + x = self.embed_tokens(x) + special = self.special_token.expand(B, T, 1, -1) + x = torch.cat([special, x], dim=2) + x = x.view(B * T, P + 1, D) + + cos, sin = self.rotary_emb(x, seq_len=P + 1) + for layer in self.layers: + x = layer(x, (cos, sin)) + + x = self.norm(x) + return x[:, 0, :].view(B, T, D) + + +class FSQ(nn.Module): + def __init__( + self, + levels, + dim=None, + device=None, + dtype=None, + operations=None + ): + super().__init__() + + _levels = torch.tensor(levels, dtype=torch.int32, device=device) + self.register_buffer('_levels', _levels, persistent=False) + + _basis = torch.cumprod(torch.tensor([1] + levels[:-1], dtype=torch.int32, device=device), dim=0) + self.register_buffer('_basis', _basis, persistent=False) + + self.codebook_dim = len(levels) + self.dim = dim if dim is not None else self.codebook_dim + + requires_projection = self.dim != self.codebook_dim + if requires_projection: + self.project_in = operations.Linear(self.dim, self.codebook_dim, device=device, dtype=dtype) + self.project_out = operations.Linear(self.codebook_dim, self.dim, device=device, dtype=dtype) + else: + self.project_in = nn.Identity() + self.project_out = nn.Identity() + + self.codebook_size = self._levels.prod().item() + + indices = torch.arange(self.codebook_size, device=device) + implicit_codebook = self._indices_to_codes(indices) + + if dtype is not None: + implicit_codebook = implicit_codebook.to(dtype) + + self.register_buffer('implicit_codebook', implicit_codebook, persistent=False) + + def bound(self, z): + levels_minus_1 = (self._levels - 1).to(z.dtype) + scale = 2. / levels_minus_1 + bracket = (levels_minus_1 * (torch.tanh(z) + 1) / 2.) + 0.5 + + zhat = bracket.floor() + bracket_ste = bracket + (zhat - bracket).detach() + + return scale * bracket_ste - 1. + + def _indices_to_codes(self, indices): + indices = indices.unsqueeze(-1) + codes_non_centered = (indices // self._basis) % self._levels + return codes_non_centered.float() * (2. / (self._levels.float() - 1)) - 1. + + def codes_to_indices(self, zhat): + zhat_normalized = (zhat + 1.) / (2. / (self._levels.to(zhat.dtype) - 1)) + return (zhat_normalized * self._basis.to(zhat.dtype)).sum(dim=-1).round().to(torch.int32) + + def forward(self, z): + orig_dtype = z.dtype + z = self.project_in(z) + + codes = self.bound(z) + indices = self.codes_to_indices(codes) + + out = self.project_out(codes) + return out.to(orig_dtype), indices + + +class ResidualFSQ(nn.Module): + def __init__( + self, + levels, + num_quantizers, + dim=None, + bound_hard_clamp=True, + device=None, + dtype=None, + operations=None, + **kwargs + ): + super().__init__() + + codebook_dim = len(levels) + dim = dim if dim is not None else codebook_dim + + requires_projection = codebook_dim != dim + if requires_projection: + self.project_in = operations.Linear(dim, codebook_dim, device=device, dtype=dtype) + self.project_out = operations.Linear(codebook_dim, dim, device=device, dtype=dtype) + else: + self.project_in = nn.Identity() + self.project_out = nn.Identity() + + self.layers = nn.ModuleList() + levels_tensor = torch.tensor(levels, device=device) + scales = [] + + for ind in range(num_quantizers): + scale_val = levels_tensor.float() ** -ind + scales.append(scale_val) + + self.layers.append(FSQ( + levels=levels, + dim=codebook_dim, + device=device, + dtype=dtype, + operations=operations + )) + + scales_tensor = torch.stack(scales) + if dtype is not None: + scales_tensor = scales_tensor.to(dtype) + self.register_buffer('scales', scales_tensor, persistent=False) + + if bound_hard_clamp: + val = 1 + (1 / (levels_tensor.float() - 1)) + if dtype is not None: + val = val.to(dtype) + self.register_buffer('soft_clamp_input_value', val, persistent=False) + + def get_output_from_indices(self, indices, dtype=torch.float32): + if indices.dim() == 2: + indices = indices.unsqueeze(-1) + + all_codes = [] + for i, layer in enumerate(self.layers): + idx = indices[..., i].long() + codes = F.embedding(idx, comfy.model_management.cast_to(layer.implicit_codebook, device=idx.device, dtype=dtype)) + all_codes.append(codes * comfy.model_management.cast_to(self.scales[i], device=idx.device, dtype=dtype)) + + codes_summed = torch.stack(all_codes).sum(dim=0) + return self.project_out(codes_summed) + + def forward(self, x): + x = self.project_in(x) + + if hasattr(self, 'soft_clamp_input_value'): + sc_val = self.soft_clamp_input_value.to(x.dtype) + x = (x / sc_val).tanh() * sc_val + + quantized_out = torch.tensor(0., device=x.device, dtype=x.dtype) + residual = x + all_indices = [] + + for layer, scale in zip(self.layers, self.scales): + scale = scale.to(residual.dtype) + + quantized, indices = layer(residual / scale) + quantized = quantized * scale + + residual = residual - quantized.detach() + quantized_out = quantized_out + quantized + all_indices.append(indices) + + quantized_out = self.project_out(quantized_out) + all_indices = torch.stack(all_indices, dim=-1) + + return quantized_out, all_indices + + +class AceStepAudioTokenizer(nn.Module): + def __init__( + self, + audio_acoustic_hidden_dim, + hidden_size, + pool_window_size, + fsq_dim, + fsq_levels, + fsq_input_num_quantizers, + num_layers, + head_dim, + rms_norm_eps, + dtype=None, + device=None, + operations=None + ): + super().__init__() + Linear = get_layer_class(operations, "Linear") + self.audio_acoustic_proj = Linear(audio_acoustic_hidden_dim, hidden_size, dtype=dtype, device=device) + self.attention_pooler = AttentionPooler( + hidden_size, num_layers, head_dim, rms_norm_eps, dtype=dtype, device=device, operations=operations + ) + self.pool_window_size = pool_window_size + self.fsq_dim = fsq_dim + self.quantizer = ResidualFSQ( + dim=fsq_dim, + levels=fsq_levels, + num_quantizers=fsq_input_num_quantizers, + bound_hard_clamp=True, + dtype=dtype, device=device, operations=operations + ) + + def forward(self, hidden_states): + hidden_states = self.audio_acoustic_proj(hidden_states) + hidden_states = self.attention_pooler(hidden_states) + quantized, indices = self.quantizer(hidden_states) + return quantized, indices + + def tokenize(self, x): + B, T, D = x.shape + P = self.pool_window_size + + if T % P != 0: + pad = P - (T % P) + x = F.pad(x, (0, 0, 0, pad)) + T = x.shape[1] + + T_patch = T // P + x = x.view(B, T_patch, P, D) + + quantized, indices = self.forward(x) + return quantized, indices + + +class AudioTokenDetokenizer(nn.Module): + def __init__( + self, + hidden_size, + pool_window_size, + audio_acoustic_hidden_dim, + num_layers, + head_dim, + dtype=None, + device=None, + operations=None + ): + super().__init__() + Linear = get_layer_class(operations, "Linear") + self.pool_window_size = pool_window_size + self.embed_tokens = Linear(hidden_size, hidden_size, dtype=dtype, device=device) + self.special_tokens = nn.Parameter(torch.empty(1, pool_window_size, hidden_size, dtype=dtype, device=device)) + self.rotary_emb = RotaryEmbedding(head_dim, dtype=dtype, device=device, operations=operations) + self.layers = nn.ModuleList([ + AceStepEncoderLayer( + hidden_size, 16, 8, head_dim, hidden_size * 3, 1e-6, + dtype=dtype, device=device, operations=operations + ) + for _ in range(num_layers) + ]) + self.norm = operations.RMSNorm(hidden_size, dtype=dtype, device=device) + self.proj_out = Linear(hidden_size, audio_acoustic_hidden_dim, dtype=dtype, device=device) + + def forward(self, x): + B, T, D = x.shape + x = self.embed_tokens(x) + x = x.unsqueeze(2).repeat(1, 1, self.pool_window_size, 1) + x = x + comfy.model_management.cast_to(self.special_tokens.expand(B, T, -1, -1), device=x.device, dtype=x.dtype) + x = x.view(B * T, self.pool_window_size, D) + + cos, sin = self.rotary_emb(x, seq_len=self.pool_window_size) + for layer in self.layers: + x = layer(x, (cos, sin)) + + x = self.norm(x) + x = self.proj_out(x) + return x.view(B, T * self.pool_window_size, -1) + + +class AceStepConditionGenerationModel(nn.Module): + def __init__( + self, + in_channels=192, + hidden_size=2048, + text_hidden_dim=1024, + timbre_hidden_dim=64, + audio_acoustic_hidden_dim=64, + num_dit_layers=24, + num_lyric_layers=8, + num_timbre_layers=4, + num_tokenizer_layers=2, + num_heads=16, + num_kv_heads=8, + head_dim=128, + intermediate_size=6144, + patch_size=2, + pool_window_size=5, + rms_norm_eps=1e-06, + timestep_mu=-0.4, + timestep_sigma=1.0, + data_proportion=0.5, + sliding_window=128, + layer_types=None, + fsq_dim=2048, + fsq_levels=[8, 8, 8, 5, 5, 5], + fsq_input_num_quantizers=1, + audio_model=None, + dtype=None, + device=None, + operations=None + ): + super().__init__() + self.dtype = dtype + self.timestep_mu = timestep_mu + self.timestep_sigma = timestep_sigma + self.data_proportion = data_proportion + self.pool_window_size = pool_window_size + + if layer_types is None: + layer_types = [] + for i in range(num_dit_layers): + layer_types.append("sliding_attention" if i % 2 == 0 else "full_attention") + + self.decoder = AceStepDiTModel( + in_channels, hidden_size, num_dit_layers, num_heads, num_kv_heads, head_dim, + intermediate_size, patch_size, audio_acoustic_hidden_dim, + layer_types=layer_types, sliding_window=sliding_window, rms_norm_eps=rms_norm_eps, + dtype=dtype, device=device, operations=operations + ) + self.encoder = AceStepConditionEncoder( + text_hidden_dim, timbre_hidden_dim, hidden_size, num_lyric_layers, num_timbre_layers, + num_heads, num_kv_heads, head_dim, intermediate_size, rms_norm_eps, + dtype=dtype, device=device, operations=operations + ) + self.tokenizer = AceStepAudioTokenizer( + audio_acoustic_hidden_dim, hidden_size, pool_window_size, fsq_dim=fsq_dim, fsq_levels=fsq_levels, fsq_input_num_quantizers=fsq_input_num_quantizers, num_layers=num_tokenizer_layers, head_dim=head_dim, rms_norm_eps=rms_norm_eps, + dtype=dtype, device=device, operations=operations + ) + self.detokenizer = AudioTokenDetokenizer( + hidden_size, pool_window_size, audio_acoustic_hidden_dim, num_layers=2, head_dim=head_dim, + dtype=dtype, device=device, operations=operations + ) + self.null_condition_emb = nn.Parameter(torch.empty(1, 1, hidden_size, dtype=dtype, device=device)) + + def prepare_condition( + self, + text_hidden_states, text_attention_mask, + lyric_hidden_states, lyric_attention_mask, + refer_audio_acoustic_hidden_states_packed, refer_audio_order_mask, + src_latents, chunk_masks, is_covers, + precomputed_lm_hints_25Hz=None, + audio_codes=None + ): + encoder_hidden, encoder_mask = self.encoder( + text_hidden_states, text_attention_mask, + lyric_hidden_states, lyric_attention_mask, + refer_audio_acoustic_hidden_states_packed, refer_audio_order_mask + ) + + if precomputed_lm_hints_25Hz is not None: + lm_hints = precomputed_lm_hints_25Hz + else: + if audio_codes is not None: + if audio_codes.shape[1] * 5 < src_latents.shape[1]: + audio_codes = torch.nn.functional.pad(audio_codes, (0, math.ceil(src_latents.shape[1] / 5) - audio_codes.shape[1]), "constant", 35847) + lm_hints_5Hz = self.tokenizer.quantizer.get_output_from_indices(audio_codes, dtype=text_hidden_states.dtype) + else: + assert False + # TODO ? + + lm_hints = self.detokenizer(lm_hints_5Hz) + + lm_hints = lm_hints[:, :src_latents.shape[1], :] + if is_covers is None: + src_latents = lm_hints + else: + src_latents = torch.where(is_covers.unsqueeze(-1).unsqueeze(-1) > 0, lm_hints, src_latents) + + context_latents = torch.cat([src_latents, chunk_masks.to(src_latents.dtype)], dim=-1) + + return encoder_hidden, encoder_mask, context_latents + + def forward(self, x, timestep, context, lyric_embed=None, refer_audio=None, audio_codes=None, **kwargs): + text_attention_mask = None + lyric_attention_mask = None + refer_audio_order_mask = None + attention_mask = None + chunk_masks = None + is_covers = None + src_latents = None + precomputed_lm_hints_25Hz = None + lyric_hidden_states = lyric_embed + text_hidden_states = context + refer_audio_acoustic_hidden_states_packed = refer_audio.movedim(-1, -2) + + x = x.movedim(-1, -2) + + if refer_audio_order_mask is None: + refer_audio_order_mask = torch.zeros((x.shape[0],), device=x.device, dtype=torch.long) + + if src_latents is None and is_covers is None: + src_latents = x + + if chunk_masks is None: + chunk_masks = torch.ones_like(x) + + enc_hidden, enc_mask, context_latents = self.prepare_condition( + text_hidden_states, text_attention_mask, + lyric_hidden_states, lyric_attention_mask, + refer_audio_acoustic_hidden_states_packed, refer_audio_order_mask, + src_latents, chunk_masks, is_covers, precomputed_lm_hints_25Hz=precomputed_lm_hints_25Hz, audio_codes=audio_codes + ) + + out = self.decoder(hidden_states=x, + timestep=timestep, + timestep_r=timestep, + attention_mask=attention_mask, + encoder_hidden_states=enc_hidden, + encoder_attention_mask=enc_mask, + context_latents=context_latents + ) + + return out.movedim(-1, -2) diff --git a/comfy/model_base.py b/comfy/model_base.py index 85acdb66a..89944548c 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -50,6 +50,7 @@ import comfy.ldm.omnigen.omnigen2 import comfy.ldm.qwen_image.model import comfy.ldm.kandinsky5.model import comfy.ldm.anima.model +import comfy.ldm.ace.ace_step15 import comfy.model_management import comfy.patcher_extension @@ -1540,6 +1541,47 @@ class ACEStep(BaseModel): out['lyrics_strength'] = comfy.conds.CONDConstant(kwargs.get("lyrics_strength", 1.0)) return out +class ACEStep15(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ace.ace_step15.AceStepConditionGenerationModel) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + device = kwargs["device"] + + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + + conditioning_lyrics = kwargs.get("conditioning_lyrics", None) + if cross_attn is not None: + out['lyric_embed'] = comfy.conds.CONDRegular(conditioning_lyrics) + + refer_audio = kwargs.get("reference_audio_timbre_latents", None) + if refer_audio is None or len(refer_audio) == 0: + refer_audio = torch.tensor([[[-1.3672e-01, -1.5820e-01, 5.8594e-01, -5.7422e-01, 3.0273e-02, + 2.7930e-01, -2.5940e-03, -2.0703e-01, -1.6113e-01, -1.4746e-01, + -2.7710e-02, -1.8066e-01, -2.9688e-01, 1.6016e+00, -2.6719e+00, + 7.7734e-01, -1.3516e+00, -1.9434e-01, -7.1289e-02, -5.0938e+00, + 2.4316e-01, 4.7266e-01, 4.6387e-02, -6.6406e-01, -2.1973e-01, + -6.7578e-01, -1.5723e-01, 9.5312e-01, -2.0020e-01, -1.7109e+00, + 5.8984e-01, -5.7422e-01, 5.1562e-01, 2.8320e-01, 1.4551e-01, + -1.8750e-01, -5.9814e-02, 3.6719e-01, -1.0059e-01, -1.5723e-01, + 2.0605e-01, -4.3359e-01, -8.2812e-01, 4.5654e-02, -6.6016e-01, + 1.4844e-01, 9.4727e-02, 3.8477e-01, -1.2578e+00, -3.3203e-01, + -8.5547e-01, 4.3359e-01, 4.2383e-01, -8.9453e-01, -5.0391e-01, + -5.6152e-02, -2.9219e+00, -2.4658e-02, 5.0391e-01, 9.8438e-01, + 7.2754e-02, -2.1582e-01, 6.3672e-01, 1.0000e+00]]], device=device).movedim(-1, 1).repeat(1, 1, 750) + else: + refer_audio = refer_audio[-1] + out['refer_audio'] = comfy.conds.CONDRegular(refer_audio) + + audio_codes = kwargs.get("audio_codes", None) + if audio_codes is not None: + out['audio_codes'] = comfy.conds.CONDRegular(torch.tensor(audio_codes, device=device)) + + return out + class Omnigen2(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.omnigen.omnigen2.OmniGen2Transformer2DModel) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 8cea16e50..e8ad725df 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -655,6 +655,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["num_visual_blocks"] = count_blocks(state_dict_keys, '{}visual_transformer_blocks.'.format(key_prefix) + '{}.') return dit_config + if '{}encoder.lyric_encoder.layers.0.input_layernorm.weight'.format(key_prefix) in state_dict_keys: + dit_config = {} + dit_config["audio_model"] = "ace1.5" + return dit_config + if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys: return None diff --git a/comfy/sd.py b/comfy/sd.py index fd0ac85e7..722c0c154 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -59,6 +59,7 @@ import comfy.text_encoders.kandinsky5 import comfy.text_encoders.jina_clip_2 import comfy.text_encoders.newbie import comfy.text_encoders.anima +import comfy.text_encoders.ace15 import comfy.model_patcher import comfy.lora @@ -452,6 +453,8 @@ class VAE: self.extra_1d_channel = None self.crop_input = True + self.audio_sample_rate = 44100 + if config is None: if "decoder.mid.block_1.mix_factor" in sd: encoder_config = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} @@ -549,14 +552,25 @@ class VAE: encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig}, decoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Decoder", 'params': ddconfig}) elif "decoder.layers.1.layers.0.beta" in sd: - self.first_stage_model = AudioOobleckVAE() + config = {} + param_key = None + if "decoder.layers.2.layers.1.weight_v" in sd: + param_key = "decoder.layers.2.layers.1.weight_v" + if "decoder.layers.2.layers.1.parametrizations.weight.original1" in sd: + param_key = "decoder.layers.2.layers.1.parametrizations.weight.original1" + if param_key is not None: + if sd[param_key].shape[-1] == 12: + config["strides"] = [2, 4, 4, 6, 10] + self.audio_sample_rate = 48000 + + self.first_stage_model = AudioOobleckVAE(**config) self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype) self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * 2048) * model_management.dtype_size(dtype) self.latent_channels = 64 self.output_channels = 2 self.pad_channel_value = "replicate" self.upscale_ratio = 2048 - self.downscale_ratio = 2048 + self.downscale_ratio = 2048 self.latent_dim = 1 self.process_output = lambda audio: audio self.process_input = lambda audio: audio @@ -1427,6 +1441,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_data_jina = clip_data[0] tokenizer_data["gemma_spiece_model"] = clip_data_gemma.get("spiece_model", None) tokenizer_data["jina_spiece_model"] = clip_data_jina.get("spiece_model", None) + elif clip_type == CLIPType.ACE: + clip_target.clip = comfy.text_encoders.ace15.te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.ace15.ACE15Tokenizer else: clip_target.clip = sdxl_clip.SDXLClipModel clip_target.tokenizer = sdxl_clip.SDXLTokenizer diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 9ecfc9c55..4c817d468 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -155,6 +155,8 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): self.execution_device = options.get("execution_device", self.execution_device) if isinstance(self.layer, list) or self.layer == "all": pass + elif isinstance(layer_idx, list): + self.layer = layer_idx elif layer_idx is None or abs(layer_idx) > self.num_layers: self.layer = "last" else: diff --git a/comfy/supported_models.py b/comfy/supported_models.py index d25271d6e..6b7d831cb 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -24,6 +24,7 @@ import comfy.text_encoders.hunyuan_image import comfy.text_encoders.kandinsky5 import comfy.text_encoders.z_image import comfy.text_encoders.anima +import comfy.text_encoders.ace15 from . import supported_models_base from . import latent_formats @@ -1596,6 +1597,38 @@ class Kandinsky5Image(Kandinsky5): return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect)) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima] +class ACEStep15(supported_models_base.BASE): + unet_config = { + "audio_model": "ace1.5", + } + + unet_extra_config = { + } + + sampling_settings = { + "multiplier": 1.0, + "shift": 3.0, + } + + latent_format = comfy.latent_formats.ACEAudio15 + + memory_usage_factor = 4.7 + + supported_inference_dtypes = [torch.bfloat16, torch.float32] + + vae_key_prefix = ["vae."] + text_encoder_key_prefix = ["text_encoders."] + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.ACEStep15(self, device=device) + return out + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_2b.transformer.".format(pref)) + return supported_models_base.ClipTarget(comfy.text_encoders.ace15.ACE15Tokenizer, comfy.text_encoders.ace15.te(**hunyuan_detect)) + + +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima] models += [SVD_img2vid] diff --git a/comfy/text_encoders/ace15.py b/comfy/text_encoders/ace15.py new file mode 100644 index 000000000..9070cb577 --- /dev/null +++ b/comfy/text_encoders/ace15.py @@ -0,0 +1,218 @@ +from .anima import Qwen3Tokenizer +import comfy.text_encoders.llama +from comfy import sd1_clip +import torch +import math + + +def sample_manual_loop_no_classes( + model, + ids=None, + paddings=[], + execution_dtype=None, + cfg_scale: float = 2.0, + temperature: float = 0.85, + top_p: float = 0.9, + top_k: int = None, + seed: int = 1, + min_tokens: int = 1, + max_new_tokens: int = 2048, + audio_start_id: int = 151669, # The cutoff ID for audio codes + eos_token_id: int = 151645, +): + device = model.execution_device + + if execution_dtype is None: + if comfy.model_management.should_use_bf16(device): + execution_dtype = torch.bfloat16 + else: + execution_dtype = torch.float32 + + embeds, attention_mask, num_tokens, embeds_info = model.process_tokens(ids, device) + for i, t in enumerate(paddings): + attention_mask[i, :t] = 0 + attention_mask[i, t:] = 1 + + output_audio_codes = [] + past_key_values = [] + generator = torch.Generator(device=device) + generator.manual_seed(seed) + model_config = model.transformer.model.config + + for x in range(model_config.num_hidden_layers): + past_key_values.append((torch.empty([embeds.shape[0], model_config.num_key_value_heads, embeds.shape[1] + min_tokens, model_config.head_dim], device=device, dtype=execution_dtype), torch.empty([embeds.shape[0], model_config.num_key_value_heads, embeds.shape[1] + min_tokens, model_config.head_dim], device=device, dtype=execution_dtype), 0)) + + for step in range(max_new_tokens): + outputs = model.transformer(None, attention_mask, embeds=embeds.to(execution_dtype), num_tokens=num_tokens, intermediate_output=None, dtype=execution_dtype, embeds_info=embeds_info, past_key_values=past_key_values) + next_token_logits = model.transformer.logits(outputs[0])[:, -1] + past_key_values = outputs[2] + + cond_logits = next_token_logits[0:1] + uncond_logits = next_token_logits[1:2] + cfg_logits = uncond_logits + cfg_scale * (cond_logits - uncond_logits) + + if eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step: + eos_score = cfg_logits[:, eos_token_id].clone() + + # Only generate audio tokens + cfg_logits[:, :audio_start_id] = float('-inf') + + if eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step: + cfg_logits[:, eos_token_id] = eos_score + + if top_k is not None and top_k > 0: + top_k_vals, _ = torch.topk(cfg_logits, top_k) + min_val = top_k_vals[..., -1, None] + cfg_logits[cfg_logits < min_val] = float('-inf') + + if top_p is not None and top_p < 1.0: + sorted_logits, sorted_indices = torch.sort(cfg_logits, descending=True) + cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) + sorted_indices_to_remove = cumulative_probs > top_p + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) + cfg_logits[indices_to_remove] = float('-inf') + + if temperature > 0: + cfg_logits = cfg_logits / temperature + next_token = torch.multinomial(torch.softmax(cfg_logits, dim=-1), num_samples=1, generator=generator).squeeze(1) + else: + next_token = torch.argmax(cfg_logits, dim=-1) + + token = next_token.item() + + if token == eos_token_id: + break + + embed, _, _, _ = model.process_tokens([[token]], device) + embeds = embed.repeat(2, 1, 1) + attention_mask = torch.cat([attention_mask, torch.ones((2, 1), device=device, dtype=attention_mask.dtype)], dim=1) + + output_audio_codes.append(token - audio_start_id) + + return output_audio_codes + + +def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=1024, seed=0): + cfg_scale = 2.0 + + positive = [[token for token, _ in inner_list] for inner_list in positive] + negative = [[token for token, _ in inner_list] for inner_list in negative] + positive = positive[0] + negative = negative[0] + + neg_pad = 0 + if len(negative) < len(positive): + neg_pad = (len(positive) - len(negative)) + negative = [model.special_tokens["pad"]] * neg_pad + negative + + pos_pad = 0 + if len(negative) > len(positive): + pos_pad = (len(negative) - len(positive)) + positive = [model.special_tokens["pad"]] * pos_pad + positive + + paddings = [pos_pad, neg_pad] + return sample_manual_loop_no_classes(model, [positive, negative], paddings, cfg_scale=cfg_scale, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens) + + +class ACE15Tokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen3_06b", tokenizer=Qwen3Tokenizer) + + def tokenize_with_weights(self, text, return_word_ids=False, **kwargs): + out = {} + lyrics = kwargs.get("lyrics", "") + bpm = kwargs.get("bpm", 120) + duration = kwargs.get("duration", 120) + keyscale = kwargs.get("keyscale", "C major") + timesignature = kwargs.get("timesignature", 2) + language = kwargs.get("language", "en") + seed = kwargs.get("seed", 0) + + duration = math.ceil(duration) + meta_lm = 'bpm: {}\nduration: {}\nkeyscale: {}\ntimesignature: {}'.format(bpm, duration, keyscale, timesignature) + lm_template = "<|im_start|>system\n# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n<|im_end|>\n<|im_start|>user\n# Caption\n{}\n{}\n<|im_end|>\n<|im_start|>assistant\n\n{}\n\n\n<|im_end|>\n" + + meta_cap = '- bpm: {}\n- timesignature: {}\n- keyscale: {}\n- duration: {}\n'.format(bpm, timesignature, keyscale, duration) + out["lm_prompt"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, meta_lm), disable_weights=True) + out["lm_prompt_negative"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, ""), disable_weights=True) + + out["lyrics"] = self.qwen3_06b.tokenize_with_weights("# Languages\n{}\n\n# Lyric{}<|endoftext|><|endoftext|>".format(language, lyrics), return_word_ids, disable_weights=True, **kwargs) + out["qwen3_06b"] = self.qwen3_06b.tokenize_with_weights("# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n# Caption\n{}# Metas\n{}<|endoftext|>\n<|endoftext|>".format(text, meta_cap), return_word_ids, **kwargs) + out["lm_metadata"] = {"min_tokens": duration * 5, "seed": seed} + return out + + +class Qwen3_06BModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}): + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_06B_ACE15, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + +class Qwen3_2B_ACE15(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}): + llama_quantization_metadata = model_options.get("llama_quantization_metadata", None) + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["quantization_metadata"] = llama_quantization_metadata + + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_2B_ACE15_lm, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + +class ACE15TEModel(torch.nn.Module): + def __init__(self, device="cpu", dtype=None, dtype_llama=None, model_options={}): + super().__init__() + if dtype_llama is None: + dtype_llama = dtype + + self.qwen3_06b = Qwen3_06BModel(device=device, dtype=dtype, model_options=model_options) + self.qwen3_2b = Qwen3_2B_ACE15(device=device, dtype=dtype_llama, model_options=model_options) + self.dtypes = set([dtype, dtype_llama]) + + def encode_token_weights(self, token_weight_pairs): + token_weight_pairs_base = token_weight_pairs["qwen3_06b"] + token_weight_pairs_lyrics = token_weight_pairs["lyrics"] + + self.qwen3_06b.set_clip_options({"layer": None}) + base_out, _, extra = self.qwen3_06b.encode_token_weights(token_weight_pairs_base) + self.qwen3_06b.set_clip_options({"layer": [0]}) + lyrics_embeds, _, extra_l = self.qwen3_06b.encode_token_weights(token_weight_pairs_lyrics) + + lm_metadata = token_weight_pairs["lm_metadata"] + audio_codes = generate_audio_codes(self.qwen3_2b, token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["min_tokens"], seed=lm_metadata["seed"]) + + return base_out, None, {"conditioning_lyrics": lyrics_embeds[:, 0], "audio_codes": [audio_codes]} + + def set_clip_options(self, options): + self.qwen3_06b.set_clip_options(options) + self.qwen3_2b.set_clip_options(options) + + def reset_clip_options(self): + self.qwen3_06b.reset_clip_options() + self.qwen3_2b.reset_clip_options() + + def load_sd(self, sd): + if "model.layers.0.post_attention_layernorm.weight" in sd: + shape = sd["model.layers.0.post_attention_layernorm.weight"].shape + if shape[0] == 1024: + return self.qwen3_06b.load_sd(sd) + else: + return self.qwen3_2b.load_sd(sd) + + def memory_estimation_function(self, token_weight_pairs, device=None): + lm_metadata = token_weight_pairs["lm_metadata"] + constant = 0.4375 + if comfy.model_management.should_use_bf16(device): + constant *= 0.5 + + token_weight_pairs = token_weight_pairs.get("lm_prompt", []) + num_tokens = sum(map(lambda a: len(a), token_weight_pairs)) + num_tokens += lm_metadata['min_tokens'] + return num_tokens * constant * 1024 * 1024 + +def te(dtype_llama=None, llama_quantization_metadata=None): + class ACE15TEModel_(ACE15TEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["llama_quantization_metadata"] = llama_quantization_metadata + super().__init__(device=device, dtype_llama=dtype_llama, dtype=dtype, model_options=model_options) + return ACE15TEModel_ diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index 68ac1e804..d2324ffc5 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -103,6 +103,52 @@ class Qwen3_06BConfig: final_norm: bool = True lm_head: bool = False +@dataclass +class Qwen3_06B_ACE15_Config: + vocab_size: int = 151669 + hidden_size: int = 1024 + intermediate_size: int = 3072 + num_hidden_layers: int = 28 + num_attention_heads: int = 16 + num_key_value_heads: int = 8 + max_position_embeddings: int = 32768 + rms_norm_eps: float = 1e-6 + rope_theta: float = 1000000.0 + transformer_type: str = "llama" + head_dim = 128 + rms_norm_add = False + mlp_activation = "silu" + qkv_bias = False + rope_dims = None + q_norm = "gemma3" + k_norm = "gemma3" + rope_scale = None + final_norm: bool = True + lm_head: bool = False + +@dataclass +class Qwen3_2B_ACE15_lm_Config: + vocab_size: int = 217204 + hidden_size: int = 2048 + intermediate_size: int = 6144 + num_hidden_layers: int = 28 + num_attention_heads: int = 16 + num_key_value_heads: int = 8 + max_position_embeddings: int = 40960 + rms_norm_eps: float = 1e-6 + rope_theta: float = 1000000.0 + transformer_type: str = "llama" + head_dim = 128 + rms_norm_add = False + mlp_activation = "silu" + qkv_bias = False + rope_dims = None + q_norm = "gemma3" + k_norm = "gemma3" + rope_scale = None + final_norm: bool = True + lm_head: bool = False + @dataclass class Qwen3_4BConfig: vocab_size: int = 151936 @@ -729,6 +775,27 @@ class Qwen3_06B(BaseLlama, torch.nn.Module): self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.dtype = dtype +class Qwen3_06B_ACE15(BaseLlama, torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + config = Qwen3_06B_ACE15_Config(**config_dict) + self.num_layers = config.num_hidden_layers + + self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) + self.dtype = dtype + +class Qwen3_2B_ACE15_lm(BaseLlama, torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + config = Qwen3_2B_ACE15_lm_Config(**config_dict) + self.num_layers = config.num_hidden_layers + + self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) + self.dtype = dtype + + def logits(self, x): + return torch.nn.functional.linear(x[:, -1:], self.model.embed_tokens.weight.to(x), None) + class Qwen3_4B(BaseLlama, torch.nn.Module): def __init__(self, config_dict, dtype, device, operations): super().__init__() diff --git a/comfy_extras/nodes_ace.py b/comfy_extras/nodes_ace.py index 1409233c9..376584e5c 100644 --- a/comfy_extras/nodes_ace.py +++ b/comfy_extras/nodes_ace.py @@ -28,12 +28,39 @@ class TextEncodeAceStepAudio(io.ComfyNode): conditioning = node_helpers.conditioning_set_values(conditioning, {"lyrics_strength": lyrics_strength}) return io.NodeOutput(conditioning) +class TextEncodeAceStepAudio15(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="TextEncodeAceStepAudio1.5", + category="conditioning", + inputs=[ + io.Clip.Input("clip"), + io.String.Input("tags", multiline=True, dynamic_prompts=True), + io.String.Input("lyrics", multiline=True, dynamic_prompts=True), + io.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True), + io.Int.Input("bpm", default=120, min=10, max=300), + io.Float.Input("duration", default=120.0, min=0.0, max=2000.0, step=0.1), + io.Combo.Input("timesignature", options=['2', '3', '4', '6']), + io.Combo.Input("language", options=["en", "ja", "zh", "es", "de", "fr", "pt", "ru", "it", "nl", "pl", "tr", "vi", "cs", "fa", "id", "ko", "uk", "hu", "ar", "sv", "ro", "el"]), + io.Combo.Input("keyscale", options=[f"{root} {quality}" for quality in ["major", "minor"] for root in ["C", "C#", "Db", "D", "D#", "Eb", "E", "F", "F#", "Gb", "G", "G#", "Ab", "A", "A#", "Bb", "B"]]), + ], + outputs=[io.Conditioning.Output()], + ) + + @classmethod + def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale) -> io.NodeOutput: + tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed) + conditioning = clip.encode_from_tokens_scheduled(tokens) + return io.NodeOutput(conditioning) + class EmptyAceStepLatentAudio(io.ComfyNode): @classmethod def define_schema(cls): return io.Schema( node_id="EmptyAceStepLatentAudio", + display_name="Empty Ace Step 1.0 Latent Audio", category="latent/audio", inputs=[ io.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.1), @@ -51,12 +78,60 @@ class EmptyAceStepLatentAudio(io.ComfyNode): return io.NodeOutput({"samples": latent, "type": "audio"}) +class EmptyAceStep15LatentAudio(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="EmptyAceStep1.5LatentAudio", + display_name="Empty Ace Step 1.5 Latent Audio", + category="latent/audio", + inputs=[ + io.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.01), + io.Int.Input( + "batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch." + ), + ], + outputs=[io.Latent.Output()], + ) + + @classmethod + def execute(cls, seconds, batch_size) -> io.NodeOutput: + length = round((seconds * 48000 / 1920)) + latent = torch.zeros([batch_size, 64, length], device=comfy.model_management.intermediate_device()) + return io.NodeOutput({"samples": latent, "type": "audio"}) + +class ReferenceTimbreAudio(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ReferenceTimbreAudio", + category="advanced/conditioning/audio", + is_experimental=True, + description="This node sets the reference audio for timbre (for ace step 1.5)", + inputs=[ + io.Conditioning.Input("conditioning"), + io.Latent.Input("latent", optional=True), + ], + outputs=[ + io.Conditioning.Output(), + ] + ) + + @classmethod + def execute(cls, conditioning, latent=None) -> io.NodeOutput: + if latent is not None: + conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_audio_timbre_latents": [latent["samples"]]}, append=True) + return io.NodeOutput(conditioning) + class AceExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: return [ TextEncodeAceStepAudio, EmptyAceStepLatentAudio, + TextEncodeAceStepAudio15, + EmptyAceStep15LatentAudio, + ReferenceTimbreAudio, ] async def comfy_entrypoint() -> AceExtension: diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py index 271b75fbd..bef723dce 100644 --- a/comfy_extras/nodes_audio.py +++ b/comfy_extras/nodes_audio.py @@ -82,13 +82,14 @@ class VAEEncodeAudio(IO.ComfyNode): @classmethod def execute(cls, vae, audio) -> IO.NodeOutput: sample_rate = audio["sample_rate"] - if 44100 != sample_rate: - waveform = torchaudio.functional.resample(audio["waveform"], sample_rate, 44100) + vae_sample_rate = getattr(vae, "audio_sample_rate", 44100) + if vae_sample_rate != sample_rate: + waveform = torchaudio.functional.resample(audio["waveform"], sample_rate, vae_sample_rate) else: waveform = audio["waveform"] t = vae.encode(waveform.movedim(1, -1)) - return IO.NodeOutput({"samples":t}) + return IO.NodeOutput({"samples": t}) encode = execute # TODO: remove @@ -114,7 +115,8 @@ class VAEDecodeAudio(IO.ComfyNode): std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0 std[std < 1.0] = 1.0 audio /= std - return IO.NodeOutput({"waveform": audio, "sample_rate": 44100 if "sample_rate" not in samples else samples["sample_rate"]}) + vae_sample_rate = getattr(vae, "audio_sample_rate", 44100) + return IO.NodeOutput({"waveform": audio, "sample_rate": vae_sample_rate if "sample_rate" not in samples else samples["sample_rate"]}) decode = execute # TODO: remove diff --git a/nodes.py b/nodes.py index 1cb43d9e2..e11a8ed80 100644 --- a/nodes.py +++ b/nodes.py @@ -1001,7 +1001,7 @@ class DualCLIPLoader: def INPUT_TYPES(s): return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ), "clip_name2": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image", "ltxv", "newbie"], ), + "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image", "ltxv", "newbie", "ace"], ), }, "optional": { "device": (["default", "cpu"], {"advanced": True}), From be4345d1c9327c7a9292bc8f4153d35717020ef3 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Tue, 3 Feb 2026 15:08:43 +0800 Subject: [PATCH 031/215] chore: update workflow templates to v0.8.31 (#12239) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 3ca417dd8..0c401873a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.37.11 -comfyui-workflow-templates==0.8.27 +comfyui-workflow-templates==0.8.31 comfyui-embedded-docs==0.4.0 torch torchsde From 66e1b07402a84363b174c35d09acfcab0af7b0e5 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 3 Feb 2026 02:20:59 -0500 Subject: [PATCH 032/215] ComfyUI v0.12.0 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index b1ebaa115..bc6076b67 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.11.1" +__version__ = "0.12.0" diff --git a/pyproject.toml b/pyproject.toml index 042f124e4..28aa03067 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.11.1" +version = "0.12.0" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" From f5030e26fd72bdf19ab2b6463dc502fa0b2e34ba Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 3 Feb 2026 01:09:30 -0800 Subject: [PATCH 033/215] Add progress bar to ace step. (#12242) --- comfy/text_encoders/ace15.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/comfy/text_encoders/ace15.py b/comfy/text_encoders/ace15.py index 9070cb577..75e77151d 100644 --- a/comfy/text_encoders/ace15.py +++ b/comfy/text_encoders/ace15.py @@ -3,6 +3,7 @@ import comfy.text_encoders.llama from comfy import sd1_clip import torch import math +import comfy.utils def sample_manual_loop_no_classes( @@ -42,6 +43,8 @@ def sample_manual_loop_no_classes( for x in range(model_config.num_hidden_layers): past_key_values.append((torch.empty([embeds.shape[0], model_config.num_key_value_heads, embeds.shape[1] + min_tokens, model_config.head_dim], device=device, dtype=execution_dtype), torch.empty([embeds.shape[0], model_config.num_key_value_heads, embeds.shape[1] + min_tokens, model_config.head_dim], device=device, dtype=execution_dtype), 0)) + progress_bar = comfy.utils.ProgressBar(max_new_tokens) + for step in range(max_new_tokens): outputs = model.transformer(None, attention_mask, embeds=embeds.to(execution_dtype), num_tokens=num_tokens, intermediate_output=None, dtype=execution_dtype, embeds_info=embeds_info, past_key_values=past_key_values) next_token_logits = model.transformer.logits(outputs[0])[:, -1] @@ -90,6 +93,7 @@ def sample_manual_loop_no_classes( attention_mask = torch.cat([attention_mask, torch.ones((2, 1), device=device, dtype=attention_mask.dtype)], dim=1) output_audio_codes.append(token - audio_start_id) + progress_bar.update_absolute(step) return output_audio_codes From affe881354d1e41df9ac9394b2d8257faebfdba1 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 3 Feb 2026 08:07:04 -0800 Subject: [PATCH 034/215] Fix some issues with mac. (#12247) --- comfy/text_encoders/ace15.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/comfy/text_encoders/ace15.py b/comfy/text_encoders/ace15.py index 75e77151d..48c29fa9a 100644 --- a/comfy/text_encoders/ace15.py +++ b/comfy/text_encoders/ace15.py @@ -57,8 +57,9 @@ def sample_manual_loop_no_classes( if eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step: eos_score = cfg_logits[:, eos_token_id].clone() + remove_logit_value = torch.finfo(cfg_logits.dtype).min # Only generate audio tokens - cfg_logits[:, :audio_start_id] = float('-inf') + cfg_logits[:, :audio_start_id] = remove_logit_value if eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step: cfg_logits[:, eos_token_id] = eos_score @@ -66,7 +67,7 @@ def sample_manual_loop_no_classes( if top_k is not None and top_k > 0: top_k_vals, _ = torch.topk(cfg_logits, top_k) min_val = top_k_vals[..., -1, None] - cfg_logits[cfg_logits < min_val] = float('-inf') + cfg_logits[cfg_logits < min_val] = remove_logit_value if top_p is not None and top_p < 1.0: sorted_logits, sorted_indices = torch.sort(cfg_logits, descending=True) @@ -75,7 +76,7 @@ def sample_manual_loop_no_classes( sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() sorted_indices_to_remove[..., 0] = 0 indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) - cfg_logits[indices_to_remove] = float('-inf') + cfg_logits[indices_to_remove] = remove_logit_value if temperature > 0: cfg_logits = cfg_logits / temperature From 223364743c35d6e1dc4f5ebc3796234d0f8484cf Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 3 Feb 2026 08:31:36 -0800 Subject: [PATCH 035/215] llama: cast logits as a comfy-weight (#12248) This is using a different layers weight with .to(). Change it to use the ops caster if the original layer is a comfy weight so that it picks up dynamic_vram and async_offload functionality in full. Co-authored-by: Rattus --- comfy/text_encoders/llama.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index d2324ffc5..d1c628d20 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -6,6 +6,7 @@ import math from comfy.ldm.modules.attention import optimized_attention_for_device import comfy.model_management +import comfy.ops import comfy.ldm.common_dit import comfy.clip_model @@ -794,7 +795,19 @@ class Qwen3_2B_ACE15_lm(BaseLlama, torch.nn.Module): self.dtype = dtype def logits(self, x): - return torch.nn.functional.linear(x[:, -1:], self.model.embed_tokens.weight.to(x), None) + input = x[:, -1:] + module = self.model.embed_tokens + + offload_stream = None + if module.comfy_cast_weights: + weight, _, offload_stream = comfy.ops.cast_bias_weight(module, input, offloadable=True) + else: + weight = self.model.embed_tokens.weight.to(x) + + x = torch.nn.functional.linear(input, weight, None) + + comfy.ops.uncast_bias_weight(module, weight, None, offload_stream) + return x class Qwen3_4B(BaseLlama, torch.nn.Module): def __init__(self, config_dict, dtype, device, operations): From 85fc35e8fa44c6174425acb4f9167792bcc903a8 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 3 Feb 2026 09:19:39 -0800 Subject: [PATCH 036/215] Fix mac issue. (#12250) --- comfy/text_encoders/llama.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index d1c628d20..f4ac224c6 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -628,10 +628,10 @@ class Llama2_(nn.Module): mask = None if attention_mask is not None: mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, seq_len, attention_mask.shape[-1]) - mask = mask.masked_fill(mask.to(torch.bool), float("-inf")) + mask = mask.masked_fill(mask.to(torch.bool), torch.finfo(x.dtype).min) if seq_len > 1: - causal_mask = torch.empty(past_len + seq_len, past_len + seq_len, dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1) + causal_mask = torch.empty(past_len + seq_len, past_len + seq_len, dtype=x.dtype, device=x.device).fill_(torch.finfo(x.dtype).min).triu_(1) if mask is not None: mask += causal_mask else: From fb23935c1139875d47e7b6e239f657bae73bda5b Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Tue, 3 Feb 2026 20:31:46 +0200 Subject: [PATCH 037/215] feat(comfy_api): add basic 3D Model file types (#12129) * feat(comfy_api): add basic 3D Model file types * update Tripo nodes to use File3DGLB * update Rodin3D nodes to use File3DGLB * address PR review feedback: - Rename File3D parameter 'path' to 'source' - Convert File3D.data property to get_data() - Make .glb extension check case-insensitive in nodes_rodin.py - Restrict SaveGLB node to only accept File3DGLB * Fixed a bug in the Meshy Rig and Animation nodes * Fix backward compatability --- comfy_api/latest/__init__.py | 3 +- comfy_api/latest/_io.py | 52 +++++++++- comfy_api/latest/_util/__init__.py | 3 +- comfy_api/latest/_util/geometry_types.py | 77 ++++++++++++++ comfy_api_nodes/apis/meshy.py | 5 + comfy_api_nodes/nodes_hunyuan3d.py | 45 ++++---- comfy_api_nodes/nodes_meshy.py | 116 ++++++++++++++------- comfy_api_nodes/nodes_rodin.py | 67 ++++++++---- comfy_api_nodes/nodes_tripo.py | 126 +++++++++++------------ comfy_api_nodes/util/__init__.py | 2 + comfy_api_nodes/util/download_helpers.py | 38 ++++++- comfy_extras/nodes_hunyuan3d.py | 27 ++++- comfy_extras/nodes_load_3d.py | 26 ++++- 13 files changed, 427 insertions(+), 160 deletions(-) diff --git a/comfy_api/latest/__init__.py b/comfy_api/latest/__init__.py index b0fa14ff6..8542a1dbc 100644 --- a/comfy_api/latest/__init__.py +++ b/comfy_api/latest/__init__.py @@ -7,7 +7,7 @@ from comfy_api.internal.singleton import ProxiedSingleton from comfy_api.internal.async_to_sync import create_sync_class from ._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput from ._input_impl import VideoFromFile, VideoFromComponents -from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL +from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL, File3D from . import _io_public as io from . import _ui_public as ui from comfy_execution.utils import get_executing_context @@ -105,6 +105,7 @@ class Types: VideoComponents = VideoComponents MESH = MESH VOXEL = VOXEL + File3D = File3D ComfyAPI = ComfyAPI_latest diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index eeea9781a..93cf482ca 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -27,7 +27,7 @@ if TYPE_CHECKING: from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class, prune_dict, shallow_clone_class) from comfy_execution.graph_utils import ExecutionBlocker -from ._util import MESH, VOXEL, SVG as _SVG +from ._util import MESH, VOXEL, SVG as _SVG, File3D class FolderType(str, Enum): @@ -667,6 +667,49 @@ class Voxel(ComfyTypeIO): class Mesh(ComfyTypeIO): Type = MESH + +@comfytype(io_type="FILE_3D") +class File3DAny(ComfyTypeIO): + """General 3D file type - accepts any supported 3D format.""" + Type = File3D + + +@comfytype(io_type="FILE_3D_GLB") +class File3DGLB(ComfyTypeIO): + """GLB format 3D file - binary glTF, best for web and cross-platform.""" + Type = File3D + + +@comfytype(io_type="FILE_3D_GLTF") +class File3DGLTF(ComfyTypeIO): + """GLTF format 3D file - JSON-based glTF with external resources.""" + Type = File3D + + +@comfytype(io_type="FILE_3D_FBX") +class File3DFBX(ComfyTypeIO): + """FBX format 3D file - best for game engines and animation.""" + Type = File3D + + +@comfytype(io_type="FILE_3D_OBJ") +class File3DOBJ(ComfyTypeIO): + """OBJ format 3D file - simple geometry format.""" + Type = File3D + + +@comfytype(io_type="FILE_3D_STL") +class File3DSTL(ComfyTypeIO): + """STL format 3D file - best for 3D printing.""" + Type = File3D + + +@comfytype(io_type="FILE_3D_USDZ") +class File3DUSDZ(ComfyTypeIO): + """USDZ format 3D file - Apple AR format.""" + Type = File3D + + @comfytype(io_type="HOOKS") class Hooks(ComfyTypeIO): if TYPE_CHECKING: @@ -2037,6 +2080,13 @@ __all__ = [ "LossMap", "Voxel", "Mesh", + "File3DAny", + "File3DGLB", + "File3DGLTF", + "File3DFBX", + "File3DOBJ", + "File3DSTL", + "File3DUSDZ", "Hooks", "HookKeyframes", "TimestepsRange", diff --git a/comfy_api/latest/_util/__init__.py b/comfy_api/latest/_util/__init__.py index 6313eb01b..115baf392 100644 --- a/comfy_api/latest/_util/__init__.py +++ b/comfy_api/latest/_util/__init__.py @@ -1,5 +1,5 @@ from .video_types import VideoContainer, VideoCodec, VideoComponents -from .geometry_types import VOXEL, MESH +from .geometry_types import VOXEL, MESH, File3D from .image_types import SVG __all__ = [ @@ -9,5 +9,6 @@ __all__ = [ "VideoComponents", "VOXEL", "MESH", + "File3D", "SVG", ] diff --git a/comfy_api/latest/_util/geometry_types.py b/comfy_api/latest/_util/geometry_types.py index 385122778..b586fceb3 100644 --- a/comfy_api/latest/_util/geometry_types.py +++ b/comfy_api/latest/_util/geometry_types.py @@ -1,3 +1,8 @@ +import shutil +from io import BytesIO +from pathlib import Path +from typing import IO + import torch @@ -10,3 +15,75 @@ class MESH: def __init__(self, vertices: torch.Tensor, faces: torch.Tensor): self.vertices = vertices self.faces = faces + + +class File3D: + """Class representing a 3D file from a file path or binary stream. + + Supports both disk-backed (file path) and memory-backed (BytesIO) storage. + """ + + def __init__(self, source: str | IO[bytes], file_format: str = ""): + self._source = source + self._format = file_format or self._infer_format() + + def _infer_format(self) -> str: + if isinstance(self._source, str): + return Path(self._source).suffix.lstrip(".").lower() + return "" + + @property + def format(self) -> str: + return self._format + + @format.setter + def format(self, value: str) -> None: + self._format = value.lstrip(".").lower() if value else "" + + @property + def is_disk_backed(self) -> bool: + return isinstance(self._source, str) + + def get_source(self) -> str | IO[bytes]: + if isinstance(self._source, str): + return self._source + if hasattr(self._source, "seek"): + self._source.seek(0) + return self._source + + def get_data(self) -> BytesIO: + if isinstance(self._source, str): + with open(self._source, "rb") as f: + result = BytesIO(f.read()) + return result + if hasattr(self._source, "seek"): + self._source.seek(0) + if isinstance(self._source, BytesIO): + return self._source + return BytesIO(self._source.read()) + + def save_to(self, path: str) -> str: + dest = Path(path) + dest.parent.mkdir(parents=True, exist_ok=True) + + if isinstance(self._source, str): + if Path(self._source).resolve() != dest.resolve(): + shutil.copy2(self._source, dest) + else: + if hasattr(self._source, "seek"): + self._source.seek(0) + with open(dest, "wb") as f: + f.write(self._source.read()) + return str(dest) + + def get_bytes(self) -> bytes: + if isinstance(self._source, str): + return Path(self._source).read_bytes() + if hasattr(self._source, "seek"): + self._source.seek(0) + return self._source.read() + + def __repr__(self) -> str: + if isinstance(self._source, str): + return f"File3D(source={self._source!r}, format={self._format!r})" + return f"File3D(, format={self._format!r})" diff --git a/comfy_api_nodes/apis/meshy.py b/comfy_api_nodes/apis/meshy.py index be46d0d58..7d72e6e91 100644 --- a/comfy_api_nodes/apis/meshy.py +++ b/comfy_api_nodes/apis/meshy.py @@ -109,14 +109,19 @@ class MeshyTextureRequest(BaseModel): class MeshyModelsUrls(BaseModel): glb: str = Field("") + fbx: str = Field("") + usdz: str = Field("") + obj: str = Field("") class MeshyRiggedModelsUrls(BaseModel): rigged_character_glb_url: str = Field("") + rigged_character_fbx_url: str = Field("") class MeshyAnimatedModelsUrls(BaseModel): animation_glb_url: str = Field("") + animation_fbx_url: str = Field("") class MeshyResultTextureUrls(BaseModel): diff --git a/comfy_api_nodes/nodes_hunyuan3d.py b/comfy_api_nodes/nodes_hunyuan3d.py index b3a736643..813a7c809 100644 --- a/comfy_api_nodes/nodes_hunyuan3d.py +++ b/comfy_api_nodes/nodes_hunyuan3d.py @@ -1,5 +1,3 @@ -import os - from typing_extensions import override from comfy_api.latest import IO, ComfyExtension, Input @@ -14,7 +12,7 @@ from comfy_api_nodes.apis.hunyuan3d import ( ) from comfy_api_nodes.util import ( ApiEndpoint, - download_url_to_bytesio, + download_url_to_file_3d, downscale_image_tensor_by_max_side, poll_op, sync_op, @@ -22,14 +20,13 @@ from comfy_api_nodes.util import ( validate_image_dimensions, validate_string, ) -from folder_paths import get_output_directory -def get_glb_obj_from_response(response_objs: list[ResultFile3D]) -> ResultFile3D: +def get_file_from_response(response_objs: list[ResultFile3D], file_type: str) -> ResultFile3D | None: for i in response_objs: - if i.Type.lower() == "glb": + if i.Type.lower() == file_type.lower(): return i - raise ValueError("No GLB file found in response. Please report this to the developers.") + return None class TencentTextToModelNode(IO.ComfyNode): @@ -74,7 +71,9 @@ class TencentTextToModelNode(IO.ComfyNode): ), ], outputs=[ - IO.String.Output(display_name="model_file"), + IO.String.Output(display_name="model_file"), # for backward compatibility only + IO.File3DGLB.Output(display_name="GLB"), + IO.File3DOBJ.Output(display_name="OBJ"), ], hidden=[ IO.Hidden.auth_token_comfy_org, @@ -124,19 +123,20 @@ class TencentTextToModelNode(IO.ComfyNode): ) if response.Error: raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}") + task_id = response.JobId result = await poll_op( cls, ApiEndpoint(path="/proxy/tencent/hunyuan/3d-pro/query", method="POST"), - data=To3DProTaskQueryRequest(JobId=response.JobId), + data=To3DProTaskQueryRequest(JobId=task_id), response_model=To3DProTaskResultResponse, status_extractor=lambda r: r.Status, ) - model_file = f"hunyuan_model_{response.JobId}.glb" - await download_url_to_bytesio( - get_glb_obj_from_response(result.ResultFile3Ds).Url, - os.path.join(get_output_directory(), model_file), + glb_result = get_file_from_response(result.ResultFile3Ds, "glb") + obj_result = get_file_from_response(result.ResultFile3Ds, "obj") + file_glb = await download_url_to_file_3d(glb_result.Url, "glb", task_id=task_id) if glb_result else None + return IO.NodeOutput( + file_glb, file_glb, await download_url_to_file_3d(obj_result.Url, "obj", task_id=task_id) if obj_result else None ) - return IO.NodeOutput(model_file) class TencentImageToModelNode(IO.ComfyNode): @@ -184,7 +184,9 @@ class TencentImageToModelNode(IO.ComfyNode): ), ], outputs=[ - IO.String.Output(display_name="model_file"), + IO.String.Output(display_name="model_file"), # for backward compatibility only + IO.File3DGLB.Output(display_name="GLB"), + IO.File3DOBJ.Output(display_name="OBJ"), ], hidden=[ IO.Hidden.auth_token_comfy_org, @@ -269,19 +271,20 @@ class TencentImageToModelNode(IO.ComfyNode): ) if response.Error: raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}") + task_id = response.JobId result = await poll_op( cls, ApiEndpoint(path="/proxy/tencent/hunyuan/3d-pro/query", method="POST"), - data=To3DProTaskQueryRequest(JobId=response.JobId), + data=To3DProTaskQueryRequest(JobId=task_id), response_model=To3DProTaskResultResponse, status_extractor=lambda r: r.Status, ) - model_file = f"hunyuan_model_{response.JobId}.glb" - await download_url_to_bytesio( - get_glb_obj_from_response(result.ResultFile3Ds).Url, - os.path.join(get_output_directory(), model_file), + glb_result = get_file_from_response(result.ResultFile3Ds, "glb") + obj_result = get_file_from_response(result.ResultFile3Ds, "obj") + file_glb = await download_url_to_file_3d(glb_result.Url, "glb", task_id=task_id) if glb_result else None + return IO.NodeOutput( + file_glb, file_glb, await download_url_to_file_3d(obj_result.Url, "obj", task_id=task_id) if obj_result else None ) - return IO.NodeOutput(model_file) class TencentHunyuan3DExtension(ComfyExtension): diff --git a/comfy_api_nodes/nodes_meshy.py b/comfy_api_nodes/nodes_meshy.py index 740607983..65f6f0d2d 100644 --- a/comfy_api_nodes/nodes_meshy.py +++ b/comfy_api_nodes/nodes_meshy.py @@ -1,5 +1,3 @@ -import os - from typing_extensions import override from comfy_api.latest import IO, ComfyExtension, Input @@ -20,13 +18,12 @@ from comfy_api_nodes.apis.meshy import ( ) from comfy_api_nodes.util import ( ApiEndpoint, - download_url_to_bytesio, + download_url_to_file_3d, poll_op, sync_op, upload_images_to_comfyapi, validate_string, ) -from folder_paths import get_output_directory class MeshyTextToModelNode(IO.ComfyNode): @@ -79,8 +76,10 @@ class MeshyTextToModelNode(IO.ComfyNode): ), ], outputs=[ - IO.String.Output(display_name="model_file"), + IO.String.Output(display_name="model_file"), # for backward compatibility only IO.Custom("MESHY_TASK_ID").Output(display_name="meshy_task_id"), + IO.File3DGLB.Output(display_name="GLB"), + IO.File3DFBX.Output(display_name="FBX"), ], hidden=[ IO.Hidden.auth_token_comfy_org, @@ -122,16 +121,20 @@ class MeshyTextToModelNode(IO.ComfyNode): seed=seed, ), ) + task_id = response.result result = await poll_op( cls, - ApiEndpoint(path=f"/proxy/meshy/openapi/v2/text-to-3d/{response.result}"), + ApiEndpoint(path=f"/proxy/meshy/openapi/v2/text-to-3d/{task_id}"), response_model=MeshyModelResult, status_extractor=lambda r: r.status, progress_extractor=lambda r: r.progress, ) - model_file = f"meshy_model_{response.result}.glb" - await download_url_to_bytesio(result.model_urls.glb, os.path.join(get_output_directory(), model_file)) - return IO.NodeOutput(model_file, response.result) + return IO.NodeOutput( + f"{task_id}.glb", + task_id, + await download_url_to_file_3d(result.model_urls.glb, "glb", task_id=task_id), + await download_url_to_file_3d(result.model_urls.fbx, "fbx", task_id=task_id), + ) class MeshyRefineNode(IO.ComfyNode): @@ -167,8 +170,10 @@ class MeshyRefineNode(IO.ComfyNode): ), ], outputs=[ - IO.String.Output(display_name="model_file"), + IO.String.Output(display_name="model_file"), # for backward compatibility only IO.Custom("MESHY_TASK_ID").Output(display_name="meshy_task_id"), + IO.File3DGLB.Output(display_name="GLB"), + IO.File3DFBX.Output(display_name="FBX"), ], hidden=[ IO.Hidden.auth_token_comfy_org, @@ -210,16 +215,20 @@ class MeshyRefineNode(IO.ComfyNode): ai_model=model, ), ) + task_id = response.result result = await poll_op( cls, - ApiEndpoint(path=f"/proxy/meshy/openapi/v2/text-to-3d/{response.result}"), + ApiEndpoint(path=f"/proxy/meshy/openapi/v2/text-to-3d/{task_id}"), response_model=MeshyModelResult, status_extractor=lambda r: r.status, progress_extractor=lambda r: r.progress, ) - model_file = f"meshy_model_{response.result}.glb" - await download_url_to_bytesio(result.model_urls.glb, os.path.join(get_output_directory(), model_file)) - return IO.NodeOutput(model_file, response.result) + return IO.NodeOutput( + f"{task_id}.glb", + task_id, + await download_url_to_file_3d(result.model_urls.glb, "glb", task_id=task_id), + await download_url_to_file_3d(result.model_urls.fbx, "fbx", task_id=task_id), + ) class MeshyImageToModelNode(IO.ComfyNode): @@ -303,8 +312,10 @@ class MeshyImageToModelNode(IO.ComfyNode): ), ], outputs=[ - IO.String.Output(display_name="model_file"), + IO.String.Output(display_name="model_file"), # for backward compatibility only IO.Custom("MESHY_TASK_ID").Output(display_name="meshy_task_id"), + IO.File3DGLB.Output(display_name="GLB"), + IO.File3DFBX.Output(display_name="FBX"), ], hidden=[ IO.Hidden.auth_token_comfy_org, @@ -368,16 +379,20 @@ class MeshyImageToModelNode(IO.ComfyNode): seed=seed, ), ) + task_id = response.result result = await poll_op( cls, - ApiEndpoint(path=f"/proxy/meshy/openapi/v1/image-to-3d/{response.result}"), + ApiEndpoint(path=f"/proxy/meshy/openapi/v1/image-to-3d/{task_id}"), response_model=MeshyModelResult, status_extractor=lambda r: r.status, progress_extractor=lambda r: r.progress, ) - model_file = f"meshy_model_{response.result}.glb" - await download_url_to_bytesio(result.model_urls.glb, os.path.join(get_output_directory(), model_file)) - return IO.NodeOutput(model_file, response.result) + return IO.NodeOutput( + f"{task_id}.glb", + task_id, + await download_url_to_file_3d(result.model_urls.glb, "glb", task_id=task_id), + await download_url_to_file_3d(result.model_urls.fbx, "fbx", task_id=task_id), + ) class MeshyMultiImageToModelNode(IO.ComfyNode): @@ -464,8 +479,10 @@ class MeshyMultiImageToModelNode(IO.ComfyNode): ), ], outputs=[ - IO.String.Output(display_name="model_file"), + IO.String.Output(display_name="model_file"), # for backward compatibility only IO.Custom("MESHY_TASK_ID").Output(display_name="meshy_task_id"), + IO.File3DGLB.Output(display_name="GLB"), + IO.File3DFBX.Output(display_name="FBX"), ], hidden=[ IO.Hidden.auth_token_comfy_org, @@ -531,16 +548,20 @@ class MeshyMultiImageToModelNode(IO.ComfyNode): seed=seed, ), ) + task_id = response.result result = await poll_op( cls, - ApiEndpoint(path=f"/proxy/meshy/openapi/v1/multi-image-to-3d/{response.result}"), + ApiEndpoint(path=f"/proxy/meshy/openapi/v1/multi-image-to-3d/{task_id}"), response_model=MeshyModelResult, status_extractor=lambda r: r.status, progress_extractor=lambda r: r.progress, ) - model_file = f"meshy_model_{response.result}.glb" - await download_url_to_bytesio(result.model_urls.glb, os.path.join(get_output_directory(), model_file)) - return IO.NodeOutput(model_file, response.result) + return IO.NodeOutput( + f"{task_id}.glb", + task_id, + await download_url_to_file_3d(result.model_urls.glb, "glb", task_id=task_id), + await download_url_to_file_3d(result.model_urls.fbx, "fbx", task_id=task_id), + ) class MeshyRigModelNode(IO.ComfyNode): @@ -571,8 +592,10 @@ class MeshyRigModelNode(IO.ComfyNode): ), ], outputs=[ - IO.String.Output(display_name="model_file"), + IO.String.Output(display_name="model_file"), # for backward compatibility only IO.Custom("MESHY_RIGGED_TASK_ID").Output(display_name="rig_task_id"), + IO.File3DGLB.Output(display_name="GLB"), + IO.File3DFBX.Output(display_name="FBX"), ], hidden=[ IO.Hidden.auth_token_comfy_org, @@ -606,18 +629,20 @@ class MeshyRigModelNode(IO.ComfyNode): texture_image_url=texture_image_url, ), ) + task_id = response.result result = await poll_op( cls, - ApiEndpoint(path=f"/proxy/meshy/openapi/v1/rigging/{response.result}"), + ApiEndpoint(path=f"/proxy/meshy/openapi/v1/rigging/{task_id}"), response_model=MeshyRiggedResult, status_extractor=lambda r: r.status, progress_extractor=lambda r: r.progress, ) - model_file = f"meshy_model_{response.result}.glb" - await download_url_to_bytesio( - result.result.rigged_character_glb_url, os.path.join(get_output_directory(), model_file) + return IO.NodeOutput( + f"{task_id}.glb", + task_id, + await download_url_to_file_3d(result.result.rigged_character_glb_url, "glb", task_id=task_id), + await download_url_to_file_3d(result.result.rigged_character_fbx_url, "fbx", task_id=task_id), ) - return IO.NodeOutput(model_file, response.result) class MeshyAnimateModelNode(IO.ComfyNode): @@ -640,7 +665,9 @@ class MeshyAnimateModelNode(IO.ComfyNode): ), ], outputs=[ - IO.String.Output(display_name="model_file"), + IO.String.Output(display_name="model_file"), # for backward compatibility only + IO.File3DGLB.Output(display_name="GLB"), + IO.File3DFBX.Output(display_name="FBX"), ], hidden=[ IO.Hidden.auth_token_comfy_org, @@ -669,16 +696,19 @@ class MeshyAnimateModelNode(IO.ComfyNode): action_id=action_id, ), ) + task_id = response.result result = await poll_op( cls, - ApiEndpoint(path=f"/proxy/meshy/openapi/v1/animations/{response.result}"), + ApiEndpoint(path=f"/proxy/meshy/openapi/v1/animations/{task_id}"), response_model=MeshyAnimationResult, status_extractor=lambda r: r.status, progress_extractor=lambda r: r.progress, ) - model_file = f"meshy_model_{response.result}.glb" - await download_url_to_bytesio(result.result.animation_glb_url, os.path.join(get_output_directory(), model_file)) - return IO.NodeOutput(model_file, response.result) + return IO.NodeOutput( + f"{task_id}.glb", + await download_url_to_file_3d(result.result.animation_glb_url, "glb", task_id=task_id), + await download_url_to_file_3d(result.result.animation_fbx_url, "fbx", task_id=task_id), + ) class MeshyTextureNode(IO.ComfyNode): @@ -715,8 +745,10 @@ class MeshyTextureNode(IO.ComfyNode): ), ], outputs=[ - IO.String.Output(display_name="model_file"), + IO.String.Output(display_name="model_file"), # for backward compatibility only IO.Custom("MODEL_TASK_ID").Output(display_name="meshy_task_id"), + IO.File3DGLB.Output(display_name="GLB"), + IO.File3DFBX.Output(display_name="FBX"), ], hidden=[ IO.Hidden.auth_token_comfy_org, @@ -760,16 +792,20 @@ class MeshyTextureNode(IO.ComfyNode): image_style_url=image_style_url, ), ) + task_id = response.result result = await poll_op( cls, - ApiEndpoint(path=f"/proxy/meshy/openapi/v1/retexture/{response.result}"), + ApiEndpoint(path=f"/proxy/meshy/openapi/v1/retexture/{task_id}"), response_model=MeshyModelResult, status_extractor=lambda r: r.status, progress_extractor=lambda r: r.progress, ) - model_file = f"meshy_model_{response.result}.glb" - await download_url_to_bytesio(result.model_urls.glb, os.path.join(get_output_directory(), model_file)) - return IO.NodeOutput(model_file, response.result) + return IO.NodeOutput( + f"{task_id}.glb", + task_id, + await download_url_to_file_3d(result.model_urls.glb, "glb", task_id=task_id), + await download_url_to_file_3d(result.model_urls.fbx, "fbx", task_id=task_id), + ) class MeshyExtension(ComfyExtension): diff --git a/comfy_api_nodes/nodes_rodin.py b/comfy_api_nodes/nodes_rodin.py index 3ffdc8b90..f9cff121f 100644 --- a/comfy_api_nodes/nodes_rodin.py +++ b/comfy_api_nodes/nodes_rodin.py @@ -10,7 +10,6 @@ import folder_paths as comfy_paths import os import logging import math -from typing import Optional from io import BytesIO from typing_extensions import override from PIL import Image @@ -28,8 +27,9 @@ from comfy_api_nodes.util import ( poll_op, ApiEndpoint, download_url_to_bytesio, + download_url_to_file_3d, ) -from comfy_api.latest import ComfyExtension, IO +from comfy_api.latest import ComfyExtension, IO, Types COMMON_PARAMETERS = [ @@ -177,7 +177,7 @@ def check_rodin_status(response: Rodin3DCheckStatusResponse) -> str: return "DONE" return "Generating" -def extract_progress(response: Rodin3DCheckStatusResponse) -> Optional[int]: +def extract_progress(response: Rodin3DCheckStatusResponse) -> int | None: if not response.jobs: return None completed_count = sum(1 for job in response.jobs if job.status == JobStatus.Done) @@ -207,17 +207,25 @@ async def get_rodin_download_list(uuid: str, cls: type[IO.ComfyNode]) -> Rodin3D ) -async def download_files(url_list, task_uuid: str): +async def download_files(url_list, task_uuid: str) -> tuple[str | None, Types.File3D | None]: result_folder_name = f"Rodin3D_{task_uuid}" save_path = os.path.join(comfy_paths.get_output_directory(), result_folder_name) os.makedirs(save_path, exist_ok=True) model_file_path = None + file_3d = None + for i in url_list.list: file_path = os.path.join(save_path, i.name) - if file_path.endswith(".glb"): + if i.name.lower().endswith(".glb"): model_file_path = os.path.join(result_folder_name, i.name) - await download_url_to_bytesio(i.url, file_path) - return model_file_path + file_3d = await download_url_to_file_3d(i.url, "glb") + # Save to disk for backward compatibility + with open(file_path, "wb") as f: + f.write(file_3d.get_bytes()) + else: + await download_url_to_bytesio(i.url, file_path) + + return model_file_path, file_3d class Rodin3D_Regular(IO.ComfyNode): @@ -234,7 +242,10 @@ class Rodin3D_Regular(IO.ComfyNode): IO.Image.Input("Images"), *COMMON_PARAMETERS, ], - outputs=[IO.String.Output(display_name="3D Model Path")], + outputs=[ + IO.String.Output(display_name="3D Model Path"), # for backward compatibility only + IO.File3DGLB.Output(display_name="GLB"), + ], hidden=[ IO.Hidden.auth_token_comfy_org, IO.Hidden.api_key_comfy_org, @@ -271,9 +282,9 @@ class Rodin3D_Regular(IO.ComfyNode): ) await poll_for_task_status(subscription_key, cls) download_list = await get_rodin_download_list(task_uuid, cls) - model = await download_files(download_list, task_uuid) + model_path, file_3d = await download_files(download_list, task_uuid) - return IO.NodeOutput(model) + return IO.NodeOutput(model_path, file_3d) class Rodin3D_Detail(IO.ComfyNode): @@ -290,7 +301,10 @@ class Rodin3D_Detail(IO.ComfyNode): IO.Image.Input("Images"), *COMMON_PARAMETERS, ], - outputs=[IO.String.Output(display_name="3D Model Path")], + outputs=[ + IO.String.Output(display_name="3D Model Path"), # for backward compatibility only + IO.File3DGLB.Output(display_name="GLB"), + ], hidden=[ IO.Hidden.auth_token_comfy_org, IO.Hidden.api_key_comfy_org, @@ -327,9 +341,9 @@ class Rodin3D_Detail(IO.ComfyNode): ) await poll_for_task_status(subscription_key, cls) download_list = await get_rodin_download_list(task_uuid, cls) - model = await download_files(download_list, task_uuid) + model_path, file_3d = await download_files(download_list, task_uuid) - return IO.NodeOutput(model) + return IO.NodeOutput(model_path, file_3d) class Rodin3D_Smooth(IO.ComfyNode): @@ -346,7 +360,10 @@ class Rodin3D_Smooth(IO.ComfyNode): IO.Image.Input("Images"), *COMMON_PARAMETERS, ], - outputs=[IO.String.Output(display_name="3D Model Path")], + outputs=[ + IO.String.Output(display_name="3D Model Path"), # for backward compatibility only + IO.File3DGLB.Output(display_name="GLB"), + ], hidden=[ IO.Hidden.auth_token_comfy_org, IO.Hidden.api_key_comfy_org, @@ -382,9 +399,9 @@ class Rodin3D_Smooth(IO.ComfyNode): ) await poll_for_task_status(subscription_key, cls) download_list = await get_rodin_download_list(task_uuid, cls) - model = await download_files(download_list, task_uuid) + model_path, file_3d = await download_files(download_list, task_uuid) - return IO.NodeOutput(model) + return IO.NodeOutput(model_path, file_3d) class Rodin3D_Sketch(IO.ComfyNode): @@ -408,7 +425,10 @@ class Rodin3D_Sketch(IO.ComfyNode): optional=True, ), ], - outputs=[IO.String.Output(display_name="3D Model Path")], + outputs=[ + IO.String.Output(display_name="3D Model Path"), # for backward compatibility only + IO.File3DGLB.Output(display_name="GLB"), + ], hidden=[ IO.Hidden.auth_token_comfy_org, IO.Hidden.api_key_comfy_org, @@ -441,9 +461,9 @@ class Rodin3D_Sketch(IO.ComfyNode): ) await poll_for_task_status(subscription_key, cls) download_list = await get_rodin_download_list(task_uuid, cls) - model = await download_files(download_list, task_uuid) + model_path, file_3d = await download_files(download_list, task_uuid) - return IO.NodeOutput(model) + return IO.NodeOutput(model_path, file_3d) class Rodin3D_Gen2(IO.ComfyNode): @@ -475,7 +495,10 @@ class Rodin3D_Gen2(IO.ComfyNode): ), IO.Boolean.Input("TAPose", default=False), ], - outputs=[IO.String.Output(display_name="3D Model Path")], + outputs=[ + IO.String.Output(display_name="3D Model Path"), # for backward compatibility only + IO.File3DGLB.Output(display_name="GLB"), + ], hidden=[ IO.Hidden.auth_token_comfy_org, IO.Hidden.api_key_comfy_org, @@ -511,9 +534,9 @@ class Rodin3D_Gen2(IO.ComfyNode): ) await poll_for_task_status(subscription_key, cls) download_list = await get_rodin_download_list(task_uuid, cls) - model = await download_files(download_list, task_uuid) + model_path, file_3d = await download_files(download_list, task_uuid) - return IO.NodeOutput(model) + return IO.NodeOutput(model_path, file_3d) class Rodin3DExtension(ComfyExtension): diff --git a/comfy_api_nodes/nodes_tripo.py b/comfy_api_nodes/nodes_tripo.py index 5abf27b4d..67c7f59fc 100644 --- a/comfy_api_nodes/nodes_tripo.py +++ b/comfy_api_nodes/nodes_tripo.py @@ -1,10 +1,6 @@ -import os -from typing import Optional - -import torch from typing_extensions import override -from comfy_api.latest import IO, ComfyExtension +from comfy_api.latest import IO, ComfyExtension, Input from comfy_api_nodes.apis.tripo import ( TripoAnimateRetargetRequest, TripoAnimateRigRequest, @@ -26,12 +22,11 @@ from comfy_api_nodes.apis.tripo import ( ) from comfy_api_nodes.util import ( ApiEndpoint, - download_url_as_bytesio, + download_url_to_file_3d, poll_op, sync_op, upload_images_to_comfyapi, ) -from folder_paths import get_output_directory def get_model_url_from_response(response: TripoTaskResponse) -> str: @@ -45,7 +40,7 @@ def get_model_url_from_response(response: TripoTaskResponse) -> str: async def poll_until_finished( node_cls: type[IO.ComfyNode], response: TripoTaskResponse, - average_duration: Optional[int] = None, + average_duration: int | None = None, ) -> IO.NodeOutput: """Polls the Tripo API endpoint until the task reaches a terminal state, then returns the response.""" if response.code != 0: @@ -69,12 +64,8 @@ async def poll_until_finished( ) if response_poll.data.status == TripoTaskStatus.SUCCESS: url = get_model_url_from_response(response_poll) - bytesio = await download_url_as_bytesio(url) - # Save the downloaded model file - model_file = f"tripo_model_{task_id}.glb" - with open(os.path.join(get_output_directory(), model_file), "wb") as f: - f.write(bytesio.getvalue()) - return IO.NodeOutput(model_file, task_id) + file_glb = await download_url_to_file_3d(url, "glb", task_id=task_id) + return IO.NodeOutput(f"{task_id}.glb", task_id, file_glb) raise RuntimeError(f"Failed to generate mesh: {response_poll}") @@ -107,8 +98,9 @@ class TripoTextToModelNode(IO.ComfyNode): IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True), ], outputs=[ - IO.String.Output(display_name="model_file"), + IO.String.Output(display_name="model_file"), # for backward compatibility only IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"), + IO.File3DGLB.Output(display_name="GLB"), ], hidden=[ IO.Hidden.auth_token_comfy_org, @@ -155,18 +147,18 @@ class TripoTextToModelNode(IO.ComfyNode): async def execute( cls, prompt: str, - negative_prompt: Optional[str] = None, + negative_prompt: str | None = None, model_version=None, - style: Optional[str] = None, - texture: Optional[bool] = None, - pbr: Optional[bool] = None, - image_seed: Optional[int] = None, - model_seed: Optional[int] = None, - texture_seed: Optional[int] = None, - texture_quality: Optional[str] = None, - geometry_quality: Optional[str] = None, - face_limit: Optional[int] = None, - quad: Optional[bool] = None, + style: str | None = None, + texture: bool | None = None, + pbr: bool | None = None, + image_seed: int | None = None, + model_seed: int | None = None, + texture_seed: int | None = None, + texture_quality: str | None = None, + geometry_quality: str | None = None, + face_limit: int | None = None, + quad: bool | None = None, ) -> IO.NodeOutput: style_enum = None if style == "None" else style if not prompt: @@ -232,8 +224,9 @@ class TripoImageToModelNode(IO.ComfyNode): IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True), ], outputs=[ - IO.String.Output(display_name="model_file"), + IO.String.Output(display_name="model_file"), # for backward compatibility only IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"), + IO.File3DGLB.Output(display_name="GLB"), ], hidden=[ IO.Hidden.auth_token_comfy_org, @@ -279,19 +272,19 @@ class TripoImageToModelNode(IO.ComfyNode): @classmethod async def execute( cls, - image: torch.Tensor, - model_version: Optional[str] = None, - style: Optional[str] = None, - texture: Optional[bool] = None, - pbr: Optional[bool] = None, - model_seed: Optional[int] = None, + image: Input.Image, + model_version: str | None = None, + style: str | None = None, + texture: bool | None = None, + pbr: bool | None = None, + model_seed: int | None = None, orientation=None, - texture_seed: Optional[int] = None, - texture_quality: Optional[str] = None, - geometry_quality: Optional[str] = None, - texture_alignment: Optional[str] = None, - face_limit: Optional[int] = None, - quad: Optional[bool] = None, + texture_seed: int | None = None, + texture_quality: str | None = None, + geometry_quality: str | None = None, + texture_alignment: str | None = None, + face_limit: int | None = None, + quad: bool | None = None, ) -> IO.NodeOutput: style_enum = None if style == "None" else style if image is None: @@ -368,8 +361,9 @@ class TripoMultiviewToModelNode(IO.ComfyNode): IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True), ], outputs=[ - IO.String.Output(display_name="model_file"), + IO.String.Output(display_name="model_file"), # for backward compatibility only IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"), + IO.File3DGLB.Output(display_name="GLB"), ], hidden=[ IO.Hidden.auth_token_comfy_org, @@ -411,21 +405,21 @@ class TripoMultiviewToModelNode(IO.ComfyNode): @classmethod async def execute( cls, - image: torch.Tensor, - image_left: Optional[torch.Tensor] = None, - image_back: Optional[torch.Tensor] = None, - image_right: Optional[torch.Tensor] = None, - model_version: Optional[str] = None, - orientation: Optional[str] = None, - texture: Optional[bool] = None, - pbr: Optional[bool] = None, - model_seed: Optional[int] = None, - texture_seed: Optional[int] = None, - texture_quality: Optional[str] = None, - geometry_quality: Optional[str] = None, - texture_alignment: Optional[str] = None, - face_limit: Optional[int] = None, - quad: Optional[bool] = None, + image: Input.Image, + image_left: Input.Image | None = None, + image_back: Input.Image | None = None, + image_right: Input.Image | None = None, + model_version: str | None = None, + orientation: str | None = None, + texture: bool | None = None, + pbr: bool | None = None, + model_seed: int | None = None, + texture_seed: int | None = None, + texture_quality: str | None = None, + geometry_quality: str | None = None, + texture_alignment: str | None = None, + face_limit: int | None = None, + quad: bool | None = None, ) -> IO.NodeOutput: if image is None: raise RuntimeError("front image for multiview is required") @@ -487,8 +481,9 @@ class TripoTextureNode(IO.ComfyNode): ), ], outputs=[ - IO.String.Output(display_name="model_file"), + IO.String.Output(display_name="model_file"), # for backward compatibility only IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"), + IO.File3DGLB.Output(display_name="GLB"), ], hidden=[ IO.Hidden.auth_token_comfy_org, @@ -512,11 +507,11 @@ class TripoTextureNode(IO.ComfyNode): async def execute( cls, model_task_id, - texture: Optional[bool] = None, - pbr: Optional[bool] = None, - texture_seed: Optional[int] = None, - texture_quality: Optional[str] = None, - texture_alignment: Optional[str] = None, + texture: bool | None = None, + pbr: bool | None = None, + texture_seed: int | None = None, + texture_quality: str | None = None, + texture_alignment: str | None = None, ) -> IO.NodeOutput: response = await sync_op( cls, @@ -547,8 +542,9 @@ class TripoRefineNode(IO.ComfyNode): IO.Custom("MODEL_TASK_ID").Input("model_task_id", tooltip="Must be a v1.4 Tripo model"), ], outputs=[ - IO.String.Output(display_name="model_file"), + IO.String.Output(display_name="model_file"), # for backward compatibility only IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"), + IO.File3DGLB.Output(display_name="GLB"), ], hidden=[ IO.Hidden.auth_token_comfy_org, @@ -583,8 +579,9 @@ class TripoRigNode(IO.ComfyNode): category="api node/3d/Tripo", inputs=[IO.Custom("MODEL_TASK_ID").Input("original_model_task_id")], outputs=[ - IO.String.Output(display_name="model_file"), + IO.String.Output(display_name="model_file"), # for backward compatibility only IO.Custom("RIG_TASK_ID").Output(display_name="rig task_id"), + IO.File3DGLB.Output(display_name="GLB"), ], hidden=[ IO.Hidden.auth_token_comfy_org, @@ -642,8 +639,9 @@ class TripoRetargetNode(IO.ComfyNode): ), ], outputs=[ - IO.String.Output(display_name="model_file"), + IO.String.Output(display_name="model_file"), # for backward compatibility only IO.Custom("RETARGET_TASK_ID").Output(display_name="retarget task_id"), + IO.File3DGLB.Output(display_name="GLB"), ], hidden=[ IO.Hidden.auth_token_comfy_org, diff --git a/comfy_api_nodes/util/__init__.py b/comfy_api_nodes/util/__init__.py index c3c9ff4bf..18b020eef 100644 --- a/comfy_api_nodes/util/__init__.py +++ b/comfy_api_nodes/util/__init__.py @@ -28,6 +28,7 @@ from .conversions import ( from .download_helpers import ( download_url_as_bytesio, download_url_to_bytesio, + download_url_to_file_3d, download_url_to_image_tensor, download_url_to_video_output, ) @@ -69,6 +70,7 @@ __all__ = [ # Download helpers "download_url_as_bytesio", "download_url_to_bytesio", + "download_url_to_file_3d", "download_url_to_image_tensor", "download_url_to_video_output", # Conversions diff --git a/comfy_api_nodes/util/download_helpers.py b/comfy_api_nodes/util/download_helpers.py index 4668d14a9..78bcf1fa1 100644 --- a/comfy_api_nodes/util/download_helpers.py +++ b/comfy_api_nodes/util/download_helpers.py @@ -11,7 +11,8 @@ import torch from aiohttp.client_exceptions import ClientError, ContentTypeError from comfy_api.latest import IO as COMFY_IO -from comfy_api.latest import InputImpl +from comfy_api.latest import InputImpl, Types +from folder_paths import get_output_directory from . import request_logger from ._helpers import ( @@ -261,3 +262,38 @@ def _generate_operation_id(method: str, url: str, attempt: int) -> str: except Exception: slug = "download" return f"{method}_{slug}_try{attempt}_{uuid.uuid4().hex[:8]}" + + +async def download_url_to_file_3d( + url: str, + file_format: str, + *, + task_id: str | None = None, + timeout: float | None = None, + max_retries: int = 5, + cls: type[COMFY_IO.ComfyNode] = None, +) -> Types.File3D: + """Downloads a 3D model file from a URL into memory as BytesIO. + + If task_id is provided, also writes the file to disk in the output directory + for backward compatibility with the old save-to-disk behavior. + """ + file_format = file_format.lstrip(".").lower() + data = BytesIO() + await download_url_to_bytesio( + url, + data, + timeout=timeout, + max_retries=max_retries, + cls=cls, + ) + + if task_id is not None: + # This is only for backward compatability with current behavior when every 3D node is output node + # All new API nodes should not use "task_id" and instead users should use "SaveGLB" node to save results + output_dir = Path(get_output_directory()) + output_path = output_dir / f"{task_id}.{file_format}" + output_path.write_bytes(data.getvalue()) + data.seek(0) + + return Types.File3D(source=data, file_format=file_format) diff --git a/comfy_extras/nodes_hunyuan3d.py b/comfy_extras/nodes_hunyuan3d.py index 5bb5df48e..eda1639ab 100644 --- a/comfy_extras/nodes_hunyuan3d.py +++ b/comfy_extras/nodes_hunyuan3d.py @@ -622,14 +622,20 @@ class SaveGLB(IO.ComfyNode): category="3d", is_output_node=True, inputs=[ - IO.Mesh.Input("mesh"), + IO.MultiType.Input( + IO.Mesh.Input("mesh"), + types=[ + IO.File3DGLB, + ], + tooltip="Mesh or GLB file to save", + ), IO.String.Input("filename_prefix", default="mesh/ComfyUI"), ], hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo] ) @classmethod - def execute(cls, mesh, filename_prefix) -> IO.NodeOutput: + def execute(cls, mesh: Types.MESH | Types.File3D, filename_prefix: str) -> IO.NodeOutput: full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory()) results = [] @@ -641,15 +647,26 @@ class SaveGLB(IO.ComfyNode): for x in cls.hidden.extra_pnginfo: metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x]) - for i in range(mesh.vertices.shape[0]): + if isinstance(mesh, Types.File3D): + # Handle File3D input - save BytesIO data to output folder f = f"{filename}_{counter:05}_.glb" - save_glb(mesh.vertices[i], mesh.faces[i], os.path.join(full_output_folder, f), metadata) + mesh.save_to(os.path.join(full_output_folder, f)) results.append({ "filename": f, "subfolder": subfolder, "type": "output" }) - counter += 1 + else: + # Handle Mesh input - save vertices and faces as GLB + for i in range(mesh.vertices.shape[0]): + f = f"{filename}_{counter:05}_.glb" + save_glb(mesh.vertices[i], mesh.faces[i], os.path.join(full_output_folder, f), metadata) + results.append({ + "filename": f, + "subfolder": subfolder, + "type": "output" + }) + counter += 1 return IO.NodeOutput(ui={"3d": results}) diff --git a/comfy_extras/nodes_load_3d.py b/comfy_extras/nodes_load_3d.py index 4b8d950ae..f29510488 100644 --- a/comfy_extras/nodes_load_3d.py +++ b/comfy_extras/nodes_load_3d.py @@ -1,9 +1,10 @@ import nodes import folder_paths import os +import uuid from typing_extensions import override -from comfy_api.latest import IO, ComfyExtension, InputImpl, UI +from comfy_api.latest import IO, UI, ComfyExtension, InputImpl, Types from pathlib import Path @@ -81,7 +82,19 @@ class Preview3D(IO.ComfyNode): is_experimental=True, is_output_node=True, inputs=[ - IO.String.Input("model_file", default="", multiline=False), + IO.MultiType.Input( + IO.String.Input("model_file", default="", multiline=False), + types=[ + IO.File3DGLB, + IO.File3DGLTF, + IO.File3DFBX, + IO.File3DOBJ, + IO.File3DSTL, + IO.File3DUSDZ, + IO.File3DAny, + ], + tooltip="3D model file or path string", + ), IO.Load3DCamera.Input("camera_info", optional=True), IO.Image.Input("bg_image", optional=True), ], @@ -89,10 +102,15 @@ class Preview3D(IO.ComfyNode): ) @classmethod - def execute(cls, model_file, **kwargs) -> IO.NodeOutput: + def execute(cls, model_file: str | Types.File3D, **kwargs) -> IO.NodeOutput: + if isinstance(model_file, Types.File3D): + filename = f"preview3d_{uuid.uuid4().hex}.{model_file.format}" + model_file.save_to(os.path.join(folder_paths.get_output_directory(), filename)) + else: + filename = model_file camera_info = kwargs.get("camera_info", None) bg_image = kwargs.get("bg_image", None) - return IO.NodeOutput(ui=UI.PreviewUI3D(model_file, camera_info, bg_image=bg_image)) + return IO.NodeOutput(ui=UI.PreviewUI3D(filename, camera_info, bg_image=bg_image)) process = execute # TODO: remove From ab1050bec3ecb89d1bc06c13e67626805bd8e1ee Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 3 Feb 2026 10:54:23 -0800 Subject: [PATCH 038/215] Support ace step 1.5 base model loras. (#12252) --- comfy/lora.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/comfy/lora.py b/comfy/lora.py index 7b31d055c..44030bcab 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -332,6 +332,12 @@ def model_lora_keys_unet(model, key_map={}): key_map["{}".format(key_lora)] = k key_map["transformer.{}".format(key_lora)] = k + if isinstance(model, comfy.model_base.ACEStep15): + for k in sdk: + if k.startswith("diffusion_model.decoder.") and k.endswith(".weight"): + key_lora = k[len("diffusion_model.decoder."):-len(".weight")] + key_map["base_model.model.{}".format(key_lora)] = k # Official base model loras + return key_map From b8315e66cbcf675d683ed83892314cbd3a3d1bf7 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 3 Feb 2026 11:40:45 -0800 Subject: [PATCH 039/215] Fix tiled vae for ace step 1.5 (#12253) --- comfy/sd.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index 722c0c154..aa93e101d 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -554,6 +554,8 @@ class VAE: elif "decoder.layers.1.layers.0.beta" in sd: config = {} param_key = None + self.upscale_ratio = 2048 + self.downscale_ratio = 2048 if "decoder.layers.2.layers.1.weight_v" in sd: param_key = "decoder.layers.2.layers.1.weight_v" if "decoder.layers.2.layers.1.parametrizations.weight.original1" in sd: @@ -562,6 +564,8 @@ class VAE: if sd[param_key].shape[-1] == 12: config["strides"] = [2, 4, 4, 6, 10] self.audio_sample_rate = 48000 + self.upscale_ratio = 1920 + self.downscale_ratio = 1920 self.first_stage_model = AudioOobleckVAE(**config) self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype) @@ -569,8 +573,6 @@ class VAE: self.latent_channels = 64 self.output_channels = 2 self.pad_channel_value = "replicate" - self.upscale_ratio = 2048 - self.downscale_ratio = 2048 self.latent_dim = 1 self.process_output = lambda audio: audio self.process_input = lambda audio: audio @@ -870,7 +872,7 @@ class VAE: / 3.0) return output - def decode_tiled_1d(self, samples, tile_x=128, overlap=32): + def decode_tiled_1d(self, samples, tile_x=256, overlap=32): if samples.ndim == 3: decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float() else: From 3be017516643d9a8d8ed85904512ed4b240ef8c6 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 3 Feb 2026 14:13:18 -0500 Subject: [PATCH 040/215] ComfyUI v0.12.1 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index bc6076b67..2e2c12ced 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.12.0" +__version__ = "0.12.1" diff --git a/pyproject.toml b/pyproject.toml index 28aa03067..c21ee03f1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.12.0" +version = "0.12.1" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" From fe2511468d6b3f7a1581a711ae480f3313f2b39d Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 3 Feb 2026 16:01:38 -0800 Subject: [PATCH 041/215] Support the 4B ace step 1.5 lm model. (#12257) Can be used as an alternative to the 1.7B --- comfy/sd.py | 7 +++- comfy/supported_models.py | 12 +++++- comfy/text_encoders/ace15.py | 42 ++++++++++++++++----- comfy/text_encoders/llama.py | 72 ++++++++++++++++++++++++++---------- 4 files changed, 101 insertions(+), 32 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index aa93e101d..bc63d6ced 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1444,7 +1444,12 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip tokenizer_data["gemma_spiece_model"] = clip_data_gemma.get("spiece_model", None) tokenizer_data["jina_spiece_model"] = clip_data_jina.get("spiece_model", None) elif clip_type == CLIPType.ACE: - clip_target.clip = comfy.text_encoders.ace15.te(**llama_detect(clip_data)) + te_models = [detect_te_model(clip_data[0]), detect_te_model(clip_data[1])] + if TEModel.QWEN3_4B in te_models: + model_type = "qwen3_4b" + else: + model_type = "qwen3_2b" + clip_target.clip = comfy.text_encoders.ace15.te(lm_model=model_type, **llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.ace15.ACE15Tokenizer else: clip_target.clip = sdxl_clip.SDXLClipModel diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 6b7d831cb..77264ed28 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1625,8 +1625,16 @@ class ACEStep15(supported_models_base.BASE): def clip_target(self, state_dict={}): pref = self.text_encoder_key_prefix[0] - hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_2b.transformer.".format(pref)) - return supported_models_base.ClipTarget(comfy.text_encoders.ace15.ACE15Tokenizer, comfy.text_encoders.ace15.te(**hunyuan_detect)) + detect_2b = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_2b.transformer.".format(pref)) + detect_4b = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_4b.transformer.".format(pref)) + if "dtype_llama" in detect_2b: + detect = detect_2b + detect["lm_model"] = "qwen3_2b" + elif "dtype_llama" in detect_4b: + detect = detect_4b + detect["lm_model"] = "qwen3_4b" + + return supported_models_base.ClipTarget(comfy.text_encoders.ace15.ACE15Tokenizer, comfy.text_encoders.ace15.te(**detect)) models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima] diff --git a/comfy/text_encoders/ace15.py b/comfy/text_encoders/ace15.py index 48c29fa9a..73d710671 100644 --- a/comfy/text_encoders/ace15.py +++ b/comfy/text_encoders/ace15.py @@ -162,14 +162,34 @@ class Qwen3_2B_ACE15(sd1_clip.SDClipModel): super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_2B_ACE15_lm, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) +class Qwen3_4B_ACE15(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}): + llama_quantization_metadata = model_options.get("llama_quantization_metadata", None) + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["quantization_metadata"] = llama_quantization_metadata + + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_4B_ACE15_lm, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + class ACE15TEModel(torch.nn.Module): - def __init__(self, device="cpu", dtype=None, dtype_llama=None, model_options={}): + def __init__(self, device="cpu", dtype=None, dtype_llama=None, lm_model=None, model_options={}): super().__init__() if dtype_llama is None: dtype_llama = dtype + model = None + self.constant = 0.4375 + if lm_model == "qwen3_4b": + model = Qwen3_4B_ACE15 + self.constant = 0.5625 + elif lm_model == "qwen3_2b": + model = Qwen3_2B_ACE15 + + self.lm_model = lm_model self.qwen3_06b = Qwen3_06BModel(device=device, dtype=dtype, model_options=model_options) - self.qwen3_2b = Qwen3_2B_ACE15(device=device, dtype=dtype_llama, model_options=model_options) + if model is not None: + setattr(self, self.lm_model, model(device=device, dtype=dtype_llama, model_options=model_options)) + self.dtypes = set([dtype, dtype_llama]) def encode_token_weights(self, token_weight_pairs): @@ -182,17 +202,21 @@ class ACE15TEModel(torch.nn.Module): lyrics_embeds, _, extra_l = self.qwen3_06b.encode_token_weights(token_weight_pairs_lyrics) lm_metadata = token_weight_pairs["lm_metadata"] - audio_codes = generate_audio_codes(self.qwen3_2b, token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["min_tokens"], seed=lm_metadata["seed"]) + audio_codes = generate_audio_codes(getattr(self, self.lm_model, self.qwen3_06b), token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["min_tokens"], seed=lm_metadata["seed"]) return base_out, None, {"conditioning_lyrics": lyrics_embeds[:, 0], "audio_codes": [audio_codes]} def set_clip_options(self, options): self.qwen3_06b.set_clip_options(options) - self.qwen3_2b.set_clip_options(options) + lm_model = getattr(self, self.lm_model, None) + if lm_model is not None: + lm_model.set_clip_options(options) def reset_clip_options(self): self.qwen3_06b.reset_clip_options() - self.qwen3_2b.reset_clip_options() + lm_model = getattr(self, self.lm_model, None) + if lm_model is not None: + lm_model.reset_clip_options() def load_sd(self, sd): if "model.layers.0.post_attention_layernorm.weight" in sd: @@ -200,11 +224,11 @@ class ACE15TEModel(torch.nn.Module): if shape[0] == 1024: return self.qwen3_06b.load_sd(sd) else: - return self.qwen3_2b.load_sd(sd) + return getattr(self, self.lm_model).load_sd(sd) def memory_estimation_function(self, token_weight_pairs, device=None): lm_metadata = token_weight_pairs["lm_metadata"] - constant = 0.4375 + constant = self.constant if comfy.model_management.should_use_bf16(device): constant *= 0.5 @@ -213,11 +237,11 @@ class ACE15TEModel(torch.nn.Module): num_tokens += lm_metadata['min_tokens'] return num_tokens * constant * 1024 * 1024 -def te(dtype_llama=None, llama_quantization_metadata=None): +def te(dtype_llama=None, llama_quantization_metadata=None, lm_model="qwen3_2b"): class ACE15TEModel_(ACE15TEModel): def __init__(self, device="cpu", dtype=None, model_options={}): if llama_quantization_metadata is not None: model_options = model_options.copy() model_options["llama_quantization_metadata"] = llama_quantization_metadata - super().__init__(device=device, dtype_llama=dtype_llama, dtype=dtype, model_options=model_options) + super().__init__(device=device, dtype_llama=dtype_llama, lm_model=lm_model, dtype=dtype, model_options=model_options) return ACE15TEModel_ diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index f4ac224c6..3afd094d1 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -150,6 +150,29 @@ class Qwen3_2B_ACE15_lm_Config: final_norm: bool = True lm_head: bool = False +@dataclass +class Qwen3_4B_ACE15_lm_Config: + vocab_size: int = 217204 + hidden_size: int = 2560 + intermediate_size: int = 9728 + num_hidden_layers: int = 36 + num_attention_heads: int = 32 + num_key_value_heads: int = 8 + max_position_embeddings: int = 40960 + rms_norm_eps: float = 1e-6 + rope_theta: float = 1000000.0 + transformer_type: str = "llama" + head_dim = 128 + rms_norm_add = False + mlp_activation = "silu" + qkv_bias = False + rope_dims = None + q_norm = "gemma3" + k_norm = "gemma3" + rope_scale = None + final_norm: bool = True + lm_head: bool = False + @dataclass class Qwen3_4BConfig: vocab_size: int = 151936 @@ -739,6 +762,21 @@ class BaseLlama: def forward(self, input_ids, *args, **kwargs): return self.model(input_ids, *args, **kwargs) +class BaseQwen3: + def logits(self, x): + input = x[:, -1:] + module = self.model.embed_tokens + + offload_stream = None + if module.comfy_cast_weights: + weight, _, offload_stream = comfy.ops.cast_bias_weight(module, input, offloadable=True) + else: + weight = self.model.embed_tokens.weight.to(x) + + x = torch.nn.functional.linear(input, weight, None) + + comfy.ops.uncast_bias_weight(module, weight, None, offload_stream) + return x class Llama2(BaseLlama, torch.nn.Module): def __init__(self, config_dict, dtype, device, operations): @@ -767,7 +805,7 @@ class Qwen25_3B(BaseLlama, torch.nn.Module): self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.dtype = dtype -class Qwen3_06B(BaseLlama, torch.nn.Module): +class Qwen3_06B(BaseLlama, BaseQwen3, torch.nn.Module): def __init__(self, config_dict, dtype, device, operations): super().__init__() config = Qwen3_06BConfig(**config_dict) @@ -776,7 +814,7 @@ class Qwen3_06B(BaseLlama, torch.nn.Module): self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.dtype = dtype -class Qwen3_06B_ACE15(BaseLlama, torch.nn.Module): +class Qwen3_06B_ACE15(BaseLlama, BaseQwen3, torch.nn.Module): def __init__(self, config_dict, dtype, device, operations): super().__init__() config = Qwen3_06B_ACE15_Config(**config_dict) @@ -785,7 +823,7 @@ class Qwen3_06B_ACE15(BaseLlama, torch.nn.Module): self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.dtype = dtype -class Qwen3_2B_ACE15_lm(BaseLlama, torch.nn.Module): +class Qwen3_2B_ACE15_lm(BaseLlama, BaseQwen3, torch.nn.Module): def __init__(self, config_dict, dtype, device, operations): super().__init__() config = Qwen3_2B_ACE15_lm_Config(**config_dict) @@ -794,22 +832,7 @@ class Qwen3_2B_ACE15_lm(BaseLlama, torch.nn.Module): self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.dtype = dtype - def logits(self, x): - input = x[:, -1:] - module = self.model.embed_tokens - - offload_stream = None - if module.comfy_cast_weights: - weight, _, offload_stream = comfy.ops.cast_bias_weight(module, input, offloadable=True) - else: - weight = self.model.embed_tokens.weight.to(x) - - x = torch.nn.functional.linear(input, weight, None) - - comfy.ops.uncast_bias_weight(module, weight, None, offload_stream) - return x - -class Qwen3_4B(BaseLlama, torch.nn.Module): +class Qwen3_4B(BaseLlama, BaseQwen3, torch.nn.Module): def __init__(self, config_dict, dtype, device, operations): super().__init__() config = Qwen3_4BConfig(**config_dict) @@ -818,7 +841,16 @@ class Qwen3_4B(BaseLlama, torch.nn.Module): self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.dtype = dtype -class Qwen3_8B(BaseLlama, torch.nn.Module): +class Qwen3_4B_ACE15_lm(BaseLlama, BaseQwen3, torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + config = Qwen3_4B_ACE15_lm_Config(**config_dict) + self.num_layers = config.num_hidden_layers + + self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) + self.dtype = dtype + +class Qwen3_8B(BaseLlama, BaseQwen3, torch.nn.Module): def __init__(self, config_dict, dtype, device, operations): super().__init__() config = Qwen3_8BConfig(**config_dict) From 855849c6588180fec88186127aae1a3299387fa6 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Tue, 3 Feb 2026 18:39:19 -0800 Subject: [PATCH 042/215] mm: Remove Aimdo exemption for empty_cache (#12260) Its more important to get the torch caching allocator GC up and running than supporting the pyt2.7 bug. Switch it on. Defeature dynamic_vram + pyt2.7. --- comfy/model_management.py | 8 +++----- main.py | 7 +++++-- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 72348258b..b6291f340 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1724,11 +1724,9 @@ def soft_empty_cache(force=False): elif is_mlu(): torch.mlu.empty_cache() elif torch.cuda.is_available(): - if comfy.memory_management.aimdo_allocator is None: - #Pytorch 2.7 and earlier crashes if you try and empty_cache when mempools exist - torch.cuda.synchronize() - torch.cuda.empty_cache() - torch.cuda.ipc_collect() + torch.cuda.synchronize() + torch.cuda.empty_cache() + torch.cuda.ipc_collect() def unload_all_models(): free_memory(1e30, get_torch_device()) diff --git a/main.py b/main.py index b8c951375..92d705b4d 100644 --- a/main.py +++ b/main.py @@ -192,7 +192,10 @@ import comfy_aimdo.control import comfy_aimdo.torch if enables_dynamic_vram(): - if comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index): + if comfy.model_management.torch_version_numeric < (2, 8): + logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows") + comfy.memory_management.aimdo_allocator = None + elif comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index): if args.verbose == 'DEBUG': comfy_aimdo.control.set_log_debug() elif args.verbose == 'CRITICAL': @@ -208,7 +211,7 @@ if enables_dynamic_vram(): comfy.memory_management.aimdo_allocator = comfy_aimdo.torch.get_torch_allocator() logging.info("DynamicVRAM support detected and enabled") else: - logging.info("No working comfy-aimdo install detected. DynamicVRAM support disabled. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows") + logging.warning("No working comfy-aimdo install detected. DynamicVRAM support disabled. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows") comfy.memory_management.aimdo_allocator = None From a31681564d9bd49530df593d87048232acecd842 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 3 Feb 2026 21:03:21 -0800 Subject: [PATCH 043/215] Fix crash with ace step 1.5 (#12264) --- comfy/text_encoders/ace15.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/comfy/text_encoders/ace15.py b/comfy/text_encoders/ace15.py index 73d710671..fce2b67ce 100644 --- a/comfy/text_encoders/ace15.py +++ b/comfy/text_encoders/ace15.py @@ -19,6 +19,7 @@ def sample_manual_loop_no_classes( min_tokens: int = 1, max_new_tokens: int = 2048, audio_start_id: int = 151669, # The cutoff ID for audio codes + audio_end_id: int = 215669, eos_token_id: int = 151645, ): device = model.execution_device @@ -60,6 +61,7 @@ def sample_manual_loop_no_classes( remove_logit_value = torch.finfo(cfg_logits.dtype).min # Only generate audio tokens cfg_logits[:, :audio_start_id] = remove_logit_value + cfg_logits[:, audio_end_id:] = remove_logit_value if eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step: cfg_logits[:, eos_token_id] = eos_score From 5087f1d497c5b615fbb5d1ff03fcc1df308bd025 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 4 Feb 2026 00:08:59 -0500 Subject: [PATCH 044/215] ComfyUI v0.12.2 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 2e2c12ced..5d296cd1b 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.12.1" +__version__ = "0.12.2" diff --git a/pyproject.toml b/pyproject.toml index c21ee03f1..1ddcc3596 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.12.1" +version = "0.12.2" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" From d30c609f5aa1d0cc75852815308adbb6ae21c644 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Tue, 3 Feb 2026 22:48:47 -0800 Subject: [PATCH 045/215] utils: safetensors: dont slice data on torch level (#12266) Torch has alignment enforcement when viewing with data type changes but only relative to itself. Do all tensor constructions straight off the memory-view individually so pytorch doesnt see an alignment problem. The is needed for handling misaligned safetensors weights, which are reasonably common in third party models. This limits usage of this safetensors loader to GPU compute only as CPUs kernnel are very likely to bus error. But it works for dynamic_vram, where we really dont want to take a deep copy and we always use GPU copy_ which disentangles the misalignment. --- comfy/utils.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/comfy/utils.py b/comfy/utils.py index c1b536833..1337e2205 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -82,14 +82,12 @@ _TYPES = { def load_safetensors(ckpt): f = open(ckpt, "rb") mapping = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) + mv = memoryview(mapping) header_size = struct.unpack(" Date: Tue, 3 Feb 2026 23:08:45 -0800 Subject: [PATCH 046/215] mp: Fix checkpoint saving (#12268) Fix regression in the recent model saving refactor. Pass the non unet pieces down the layers so that checkpoints are complete. --- comfy/model_patcher.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index cdf289395..d888dbcfb 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -1400,7 +1400,7 @@ class ModelPatcher: continue key = "diffusion_model." + k unet_state_dict[k] = LazyCastingParam(self, key, comfy.utils.get_attr(self.model, key)) - return self.model.state_dict_for_saving(unet_state_dict) + return self.model.state_dict_for_saving(unet_state_dict, clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict) def __del__(self): self.unpin_all_weights() From e77b34dfead9758cf5e32b4cffeada0d0c56ab7d Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Wed, 4 Feb 2026 21:35:38 +0200 Subject: [PATCH 047/215] add File3DAny output to Load3D node; extend SaveGLB to accept File3DAny as input (#12276) * add File3DAny output to Load3D node; extend SaveGLB node to accept File3DAny as input * fix(grammar): capitalize letter --- comfy_extras/nodes_hunyuan3d.py | 12 ++++++++++-- comfy_extras/nodes_load_3d.py | 4 +++- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/comfy_extras/nodes_hunyuan3d.py b/comfy_extras/nodes_hunyuan3d.py index eda1639ab..c2df3e859 100644 --- a/comfy_extras/nodes_hunyuan3d.py +++ b/comfy_extras/nodes_hunyuan3d.py @@ -618,6 +618,7 @@ class SaveGLB(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="SaveGLB", + display_name="Save 3D Model", search_aliases=["export 3d model", "save mesh"], category="3d", is_output_node=True, @@ -626,8 +627,14 @@ class SaveGLB(IO.ComfyNode): IO.Mesh.Input("mesh"), types=[ IO.File3DGLB, + IO.File3DGLTF, + IO.File3DOBJ, + IO.File3DFBX, + IO.File3DSTL, + IO.File3DUSDZ, + IO.File3DAny, ], - tooltip="Mesh or GLB file to save", + tooltip="Mesh or 3D file to save", ), IO.String.Input("filename_prefix", default="mesh/ComfyUI"), ], @@ -649,7 +656,8 @@ class SaveGLB(IO.ComfyNode): if isinstance(mesh, Types.File3D): # Handle File3D input - save BytesIO data to output folder - f = f"{filename}_{counter:05}_.glb" + ext = mesh.format or "glb" + f = f"{filename}_{counter:05}_.{ext}" mesh.save_to(os.path.join(full_output_folder, f)) results.append({ "filename": f, diff --git a/comfy_extras/nodes_load_3d.py b/comfy_extras/nodes_load_3d.py index f29510488..edbb5cd40 100644 --- a/comfy_extras/nodes_load_3d.py +++ b/comfy_extras/nodes_load_3d.py @@ -45,6 +45,7 @@ class Load3D(IO.ComfyNode): IO.Image.Output(display_name="normal"), IO.Load3DCamera.Output(display_name="camera_info"), IO.Video.Output(display_name="recording_video"), + IO.File3DAny.Output(display_name="model_3d"), ], ) @@ -66,7 +67,8 @@ class Load3D(IO.ComfyNode): video = InputImpl.VideoFromFile(recording_video_path) - return IO.NodeOutput(output_image, output_mask, model_file, normal_image, image['camera_info'], video) + file_3d = Types.File3D(folder_paths.get_annotated_filepath(model_file)) + return IO.NodeOutput(output_image, output_mask, model_file, normal_image, image['camera_info'], video, file_3d) process = execute # TODO: remove From 26dd7eb42180fb57c9da47e60d0a2bac659e47ad Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 4 Feb 2026 15:25:06 -0800 Subject: [PATCH 048/215] Fix ace step nan issue on some hardware/pytorch configs. (#12289) --- comfy/text_encoders/llama.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index 3afd094d1..b6735d210 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -651,10 +651,10 @@ class Llama2_(nn.Module): mask = None if attention_mask is not None: mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, seq_len, attention_mask.shape[-1]) - mask = mask.masked_fill(mask.to(torch.bool), torch.finfo(x.dtype).min) + mask = mask.masked_fill(mask.to(torch.bool), torch.finfo(x.dtype).min / 4) if seq_len > 1: - causal_mask = torch.empty(past_len + seq_len, past_len + seq_len, dtype=x.dtype, device=x.device).fill_(torch.finfo(x.dtype).min).triu_(1) + causal_mask = torch.empty(past_len + seq_len, past_len + seq_len, dtype=x.dtype, device=x.device).fill_(torch.finfo(x.dtype).min / 4).triu_(1) if mask is not None: mask += causal_mask else: From c8fcbd66eef0ab48d9fe7e4ee35c683a193af46b Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 4 Feb 2026 16:37:05 -0800 Subject: [PATCH 049/215] Try to fix ace text encoder slowness on some configs. (#12290) --- comfy/ops.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/comfy/ops.py b/comfy/ops.py index 53c5e4dc3..0f4eca7c7 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -54,6 +54,8 @@ try: SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION) def scaled_dot_product_attention(q, k, v, *args, **kwargs): + if q.nelement() < 1024 * 128: # arbitrary number, for small inputs cudnn attention seems slower + return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs) with sdpa_kernel(SDPA_BACKEND_PRIORITY, set_priority=True): return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs) else: From 6125b8097952a374009af39639ff45da85f65500 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 4 Feb 2026 18:29:22 -0800 Subject: [PATCH 050/215] Add llm sampling options and make reference audio work on ace step 1.5 (#12295) --- comfy/ldm/ace/ace_step15.py | 3 +-- comfy/model_base.py | 19 ++++++++++++------- comfy/text_encoders/ace15.py | 31 +++++++++++++++++++++++-------- comfy_extras/nodes_ace.py | 16 +++++++++++----- 4 files changed, 47 insertions(+), 22 deletions(-) diff --git a/comfy/ldm/ace/ace_step15.py b/comfy/ldm/ace/ace_step15.py index d90549658..17a37e573 100644 --- a/comfy/ldm/ace/ace_step15.py +++ b/comfy/ldm/ace/ace_step15.py @@ -1035,8 +1035,7 @@ class AceStepConditionGenerationModel(nn.Module): audio_codes = torch.nn.functional.pad(audio_codes, (0, math.ceil(src_latents.shape[1] / 5) - audio_codes.shape[1]), "constant", 35847) lm_hints_5Hz = self.tokenizer.quantizer.get_output_from_indices(audio_codes, dtype=text_hidden_states.dtype) else: - assert False - # TODO ? + lm_hints_5Hz, indices = self.tokenizer.tokenize(refer_audio_acoustic_hidden_states_packed) lm_hints = self.detokenizer(lm_hints_5Hz) diff --git a/comfy/model_base.py b/comfy/model_base.py index 89944548c..a2a34f191 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1548,6 +1548,7 @@ class ACEStep15(BaseModel): def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) device = kwargs["device"] + noise = kwargs["noise"] cross_attn = kwargs.get("cross_attn", None) if cross_attn is not None: @@ -1571,15 +1572,19 @@ class ACEStep15(BaseModel): 1.4844e-01, 9.4727e-02, 3.8477e-01, -1.2578e+00, -3.3203e-01, -8.5547e-01, 4.3359e-01, 4.2383e-01, -8.9453e-01, -5.0391e-01, -5.6152e-02, -2.9219e+00, -2.4658e-02, 5.0391e-01, 9.8438e-01, - 7.2754e-02, -2.1582e-01, 6.3672e-01, 1.0000e+00]]], device=device).movedim(-1, 1).repeat(1, 1, 750) + 7.2754e-02, -2.1582e-01, 6.3672e-01, 1.0000e+00]]], device=device).movedim(-1, 1).repeat(1, 1, noise.shape[2]) + pass_audio_codes = True else: - refer_audio = refer_audio[-1] + refer_audio = refer_audio[-1][:, :, :noise.shape[2]] + pass_audio_codes = False + + if pass_audio_codes: + audio_codes = kwargs.get("audio_codes", None) + if audio_codes is not None: + out['audio_codes'] = comfy.conds.CONDRegular(torch.tensor(audio_codes, device=device)) + refer_audio = refer_audio[:, :, :750] + out['refer_audio'] = comfy.conds.CONDRegular(refer_audio) - - audio_codes = kwargs.get("audio_codes", None) - if audio_codes is not None: - out['audio_codes'] = comfy.conds.CONDRegular(torch.tensor(audio_codes, device=device)) - return out class Omnigen2(BaseModel): diff --git a/comfy/text_encoders/ace15.py b/comfy/text_encoders/ace15.py index fce2b67ce..74e62733e 100644 --- a/comfy/text_encoders/ace15.py +++ b/comfy/text_encoders/ace15.py @@ -101,9 +101,7 @@ def sample_manual_loop_no_classes( return output_audio_codes -def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=1024, seed=0): - cfg_scale = 2.0 - +def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=1024, seed=0, cfg_scale=2.0, temperature=0.85, top_p=0.9, top_k=0): positive = [[token for token, _ in inner_list] for inner_list in positive] negative = [[token for token, _ in inner_list] for inner_list in negative] positive = positive[0] @@ -120,7 +118,7 @@ def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=102 positive = [model.special_tokens["pad"]] * pos_pad + positive paddings = [pos_pad, neg_pad] - return sample_manual_loop_no_classes(model, [positive, negative], paddings, cfg_scale=cfg_scale, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens) + return sample_manual_loop_no_classes(model, [positive, negative], paddings, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens) class ACE15Tokenizer(sd1_clip.SD1Tokenizer): @@ -137,6 +135,12 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer): language = kwargs.get("language", "en") seed = kwargs.get("seed", 0) + generate_audio_codes = kwargs.get("generate_audio_codes", True) + cfg_scale = kwargs.get("cfg_scale", 2.0) + temperature = kwargs.get("temperature", 0.85) + top_p = kwargs.get("top_p", 0.9) + top_k = kwargs.get("top_k", 0.0) + duration = math.ceil(duration) meta_lm = 'bpm: {}\nduration: {}\nkeyscale: {}\ntimesignature: {}'.format(bpm, duration, keyscale, timesignature) lm_template = "<|im_start|>system\n# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n<|im_end|>\n<|im_start|>user\n# Caption\n{}\n{}\n<|im_end|>\n<|im_start|>assistant\n\n{}\n\n\n<|im_end|>\n" @@ -147,7 +151,14 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer): out["lyrics"] = self.qwen3_06b.tokenize_with_weights("# Languages\n{}\n\n# Lyric{}<|endoftext|><|endoftext|>".format(language, lyrics), return_word_ids, disable_weights=True, **kwargs) out["qwen3_06b"] = self.qwen3_06b.tokenize_with_weights("# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n# Caption\n{}# Metas\n{}<|endoftext|>\n<|endoftext|>".format(text, meta_cap), return_word_ids, **kwargs) - out["lm_metadata"] = {"min_tokens": duration * 5, "seed": seed} + out["lm_metadata"] = {"min_tokens": duration * 5, + "seed": seed, + "generate_audio_codes": generate_audio_codes, + "cfg_scale": cfg_scale, + "temperature": temperature, + "top_p": top_p, + "top_k": top_k, + } return out @@ -203,10 +214,14 @@ class ACE15TEModel(torch.nn.Module): self.qwen3_06b.set_clip_options({"layer": [0]}) lyrics_embeds, _, extra_l = self.qwen3_06b.encode_token_weights(token_weight_pairs_lyrics) - lm_metadata = token_weight_pairs["lm_metadata"] - audio_codes = generate_audio_codes(getattr(self, self.lm_model, self.qwen3_06b), token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["min_tokens"], seed=lm_metadata["seed"]) + out = {"conditioning_lyrics": lyrics_embeds[:, 0]} - return base_out, None, {"conditioning_lyrics": lyrics_embeds[:, 0], "audio_codes": [audio_codes]} + lm_metadata = token_weight_pairs["lm_metadata"] + if lm_metadata["generate_audio_codes"]: + audio_codes = generate_audio_codes(getattr(self, self.lm_model, self.qwen3_06b), token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["min_tokens"], seed=lm_metadata["seed"], cfg_scale=lm_metadata["cfg_scale"], temperature=lm_metadata["temperature"], top_p=lm_metadata["top_p"], top_k=lm_metadata["top_k"]) + out["audio_codes"] = [audio_codes] + + return base_out, None, out def set_clip_options(self, options): self.qwen3_06b.set_clip_options(options) diff --git a/comfy_extras/nodes_ace.py b/comfy_extras/nodes_ace.py index 376584e5c..dde5bbd2a 100644 --- a/comfy_extras/nodes_ace.py +++ b/comfy_extras/nodes_ace.py @@ -44,13 +44,18 @@ class TextEncodeAceStepAudio15(io.ComfyNode): io.Combo.Input("timesignature", options=['2', '3', '4', '6']), io.Combo.Input("language", options=["en", "ja", "zh", "es", "de", "fr", "pt", "ru", "it", "nl", "pl", "tr", "vi", "cs", "fa", "id", "ko", "uk", "hu", "ar", "sv", "ro", "el"]), io.Combo.Input("keyscale", options=[f"{root} {quality}" for quality in ["major", "minor"] for root in ["C", "C#", "Db", "D", "D#", "Eb", "E", "F", "F#", "Gb", "G", "G#", "Ab", "A", "A#", "Bb", "B"]]), + io.Boolean.Input("generate_audio_codes", default=True, tooltip="Enable the LLM that generates audio codes. This can be slow but will increase the quality of the generated audio. Turn this off if you are giving the model an audio reference.", advanced=True), + io.Float.Input("cfg_scale", default=2.0, min=0.0, max=100.0, step=0.1, advanced=True), + io.Float.Input("temperature", default=0.85, min=0.0, max=2.0, step=0.01, advanced=True), + io.Float.Input("top_p", default=0.9, min=0.0, max=2000.0, step=0.01, advanced=True), + io.Int.Input("top_k", default=0, min=0, max=100, advanced=True), ], outputs=[io.Conditioning.Output()], ) @classmethod - def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale) -> io.NodeOutput: - tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed) + def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale, generate_audio_codes, cfg_scale, temperature, top_p, top_k) -> io.NodeOutput: + tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed, generate_audio_codes=generate_audio_codes, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k) conditioning = clip.encode_from_tokens_scheduled(tokens) return io.NodeOutput(conditioning) @@ -100,14 +105,15 @@ class EmptyAceStep15LatentAudio(io.ComfyNode): latent = torch.zeros([batch_size, 64, length], device=comfy.model_management.intermediate_device()) return io.NodeOutput({"samples": latent, "type": "audio"}) -class ReferenceTimbreAudio(io.ComfyNode): +class ReferenceAudio(io.ComfyNode): @classmethod def define_schema(cls): return io.Schema( node_id="ReferenceTimbreAudio", + display_name="Reference Audio", category="advanced/conditioning/audio", is_experimental=True, - description="This node sets the reference audio for timbre (for ace step 1.5)", + description="This node sets the reference audio for ace step 1.5", inputs=[ io.Conditioning.Input("conditioning"), io.Latent.Input("latent", optional=True), @@ -131,7 +137,7 @@ class AceExtension(ComfyExtension): EmptyAceStepLatentAudio, TextEncodeAceStepAudio15, EmptyAceStep15LatentAudio, - ReferenceTimbreAudio, + ReferenceAudio, ] async def comfy_entrypoint() -> AceExtension: From a50c32d63fe55d073edd7af2242f0536f50b362e Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 4 Feb 2026 19:15:30 -0800 Subject: [PATCH 051/215] Disable sage attention on ace step 1.5 (#12297) --- comfy/ldm/ace/ace_step15.py | 2 +- comfy/ldm/modules/attention.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/comfy/ldm/ace/ace_step15.py b/comfy/ldm/ace/ace_step15.py index 17a37e573..f2b130bc1 100644 --- a/comfy/ldm/ace/ace_step15.py +++ b/comfy/ldm/ace/ace_step15.py @@ -183,7 +183,7 @@ class AceStepAttention(nn.Module): else: attn_bias = window_bias - attn_output = optimized_attention(query_states, key_states, value_states, self.num_heads, attn_bias, skip_reshape=True) + attn_output = optimized_attention(query_states, key_states, value_states, self.num_heads, attn_bias, skip_reshape=True, low_precision_attention=False) attn_output = self.o_proj(attn_output) return attn_output diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index ccf690945..10d051325 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -524,6 +524,9 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha @wrap_attn def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): + if kwargs.get("low_precision_attention", True) is False: + return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=skip_reshape, skip_output_reshape=skip_output_reshape, **kwargs) + exception_fallback = False if skip_reshape: b, _, _, dim_head = q.shape From a246cc02b274104d5f656b68ce505354c164aef8 Mon Sep 17 00:00:00 2001 From: blepping <157360029+blepping@users.noreply.github.com> Date: Wed, 4 Feb 2026 22:17:37 -0700 Subject: [PATCH 052/215] Improvements to ACE-Steps 1.5 text encoding (#12283) --- comfy/text_encoders/ace15.py | 56 +++++++++++++++++++++++++++++------- 1 file changed, 45 insertions(+), 11 deletions(-) diff --git a/comfy/text_encoders/ace15.py b/comfy/text_encoders/ace15.py index 74e62733e..00dd5ba90 100644 --- a/comfy/text_encoders/ace15.py +++ b/comfy/text_encoders/ace15.py @@ -3,6 +3,7 @@ import comfy.text_encoders.llama from comfy import sd1_clip import torch import math +import yaml import comfy.utils @@ -125,14 +126,43 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen3_06b", tokenizer=Qwen3Tokenizer) + def _metas_to_cot(self, *, return_yaml: bool = False, **kwargs) -> str: + user_metas = { + k: kwargs.pop(k) + for k in ("bpm", "duration", "keyscale", "timesignature", "language", "caption") + if k in kwargs + } + timesignature = user_metas.get("timesignature") + if isinstance(timesignature, str) and timesignature.endswith("/4"): + user_metas["timesignature"] = timesignature.rsplit("/", 1)[0] + user_metas = { + k: v if not isinstance(v, str) or not v.isdigit() else int(v) + for k, v in user_metas.items() + if v not in {"unspecified", None} + } + if len(user_metas): + meta_yaml = yaml.dump(user_metas, allow_unicode=True, sort_keys=True).strip() + else: + meta_yaml = "" + return f"\n{meta_yaml}\n" if not return_yaml else meta_yaml + + def _metas_to_cap(self, **kwargs) -> str: + use_keys = ("bpm", "duration", "keyscale", "timesignature") + user_metas = { k: kwargs.pop(k, "N/A") for k in use_keys } + duration = user_metas["duration"] + if duration == "N/A": + user_metas["duration"] = "30 seconds" + elif isinstance(duration, (str, int, float)): + user_metas["duration"] = f"{math.ceil(float(duration))} seconds" + else: + raise TypeError("Unexpected type for duration key, must be str, int or float") + return "\n".join(f"- {k}: {user_metas[k]}" for k in use_keys) + def tokenize_with_weights(self, text, return_word_ids=False, **kwargs): out = {} lyrics = kwargs.get("lyrics", "") - bpm = kwargs.get("bpm", 120) duration = kwargs.get("duration", 120) - keyscale = kwargs.get("keyscale", "C major") - timesignature = kwargs.get("timesignature", 2) - language = kwargs.get("language", "en") + language = kwargs.get("language") seed = kwargs.get("seed", 0) generate_audio_codes = kwargs.get("generate_audio_codes", True) @@ -141,16 +171,20 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer): top_p = kwargs.get("top_p", 0.9) top_k = kwargs.get("top_k", 0.0) + duration = math.ceil(duration) - meta_lm = 'bpm: {}\nduration: {}\nkeyscale: {}\ntimesignature: {}'.format(bpm, duration, keyscale, timesignature) - lm_template = "<|im_start|>system\n# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n<|im_end|>\n<|im_start|>user\n# Caption\n{}\n{}\n<|im_end|>\n<|im_start|>assistant\n\n{}\n\n\n<|im_end|>\n" + kwargs["duration"] = duration - meta_cap = '- bpm: {}\n- timesignature: {}\n- keyscale: {}\n- duration: {}\n'.format(bpm, timesignature, keyscale, duration) - out["lm_prompt"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, meta_lm), disable_weights=True) - out["lm_prompt_negative"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, ""), disable_weights=True) + cot_text = self._metas_to_cot(caption = text, **kwargs) + meta_cap = self._metas_to_cap(**kwargs) - out["lyrics"] = self.qwen3_06b.tokenize_with_weights("# Languages\n{}\n\n# Lyric{}<|endoftext|><|endoftext|>".format(language, lyrics), return_word_ids, disable_weights=True, **kwargs) - out["qwen3_06b"] = self.qwen3_06b.tokenize_with_weights("# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n# Caption\n{}# Metas\n{}<|endoftext|>\n<|endoftext|>".format(text, meta_cap), return_word_ids, **kwargs) + lm_template = "<|im_start|>system\n# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n<|im_end|>\n<|im_start|>user\n# Caption\n{}\n# Lyric\n{}\n<|im_end|>\n<|im_start|>assistant\n{}\n<|im_end|>\n" + + out["lm_prompt"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, cot_text), disable_weights=True) + out["lm_prompt_negative"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, "\n"), disable_weights=True) + + out["lyrics"] = self.qwen3_06b.tokenize_with_weights("# Languages\n{}\n\n# Lyric\n{}<|endoftext|><|endoftext|>".format(language if language is not None else "", lyrics), return_word_ids, disable_weights=True, **kwargs) + out["qwen3_06b"] = self.qwen3_06b.tokenize_with_weights("# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n# Caption\n{}\n# Metas\n{}\n<|endoftext|>\n<|endoftext|>".format(text, meta_cap), return_word_ids, **kwargs) out["lm_metadata"] = {"min_tokens": duration * 5, "seed": seed, "generate_audio_codes": generate_audio_codes, From 35183543e004d8b7509c043e7a680bee07171622 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 4 Feb 2026 22:12:04 -0800 Subject: [PATCH 053/215] Add VAE tiled decode node for audio. (#12299) --- comfy/sd.py | 2 +- comfy_extras/nodes_audio.py | 43 +++++++++++++++++++++++++++++++------ 2 files changed, 38 insertions(+), 7 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index bc63d6ced..bc9407405 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -976,7 +976,7 @@ class VAE: if overlap is not None: args["overlap"] = overlap - if dims == 1: + if dims == 1 or self.extra_1d_channel is not None: args.pop("tile_y") output = self.decode_tiled_1d(samples, **args) elif dims == 2: diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py index bef723dce..b63dd8e97 100644 --- a/comfy_extras/nodes_audio.py +++ b/comfy_extras/nodes_audio.py @@ -94,6 +94,19 @@ class VAEEncodeAudio(IO.ComfyNode): encode = execute # TODO: remove +def vae_decode_audio(vae, samples, tile=None, overlap=None): + if tile is not None: + audio = vae.decode_tiled(samples["samples"], tile_y=tile, overlap=overlap).movedim(-1, 1) + else: + audio = vae.decode(samples["samples"]).movedim(-1, 1) + + std = torch.std(audio, dim=[1, 2], keepdim=True) * 5.0 + std[std < 1.0] = 1.0 + audio /= std + vae_sample_rate = getattr(vae, "audio_sample_rate", 44100) + return {"waveform": audio, "sample_rate": vae_sample_rate if "sample_rate" not in samples else samples["sample_rate"]} + + class VAEDecodeAudio(IO.ComfyNode): @classmethod def define_schema(cls): @@ -111,16 +124,33 @@ class VAEDecodeAudio(IO.ComfyNode): @classmethod def execute(cls, vae, samples) -> IO.NodeOutput: - audio = vae.decode(samples["samples"]).movedim(-1, 1) - std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0 - std[std < 1.0] = 1.0 - audio /= std - vae_sample_rate = getattr(vae, "audio_sample_rate", 44100) - return IO.NodeOutput({"waveform": audio, "sample_rate": vae_sample_rate if "sample_rate" not in samples else samples["sample_rate"]}) + return IO.NodeOutput(vae_decode_audio(vae, samples)) decode = execute # TODO: remove +class VAEDecodeAudioTiled(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="VAEDecodeAudioTiled", + search_aliases=["latent to audio"], + display_name="VAE Decode Audio (Tiled)", + category="latent/audio", + inputs=[ + IO.Latent.Input("samples"), + IO.Vae.Input("vae"), + IO.Int.Input("tile_size", default=512, min=32, max=8192, step=8), + IO.Int.Input("overlap", default=64, min=0, max=1024, step=8), + ], + outputs=[IO.Audio.Output()], + ) + + @classmethod + def execute(cls, vae, samples, tile_size, overlap) -> IO.NodeOutput: + return IO.NodeOutput(vae_decode_audio(vae, samples, tile_size, overlap)) + + class SaveAudio(IO.ComfyNode): @classmethod def define_schema(cls): @@ -675,6 +705,7 @@ class AudioExtension(ComfyExtension): EmptyLatentAudio, VAEEncodeAudio, VAEDecodeAudio, + VAEDecodeAudioTiled, SaveAudio, SaveAudioMP3, SaveAudioOpus, From cb459573c8fa025bbf9ecf312f6af376d659f567 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 5 Feb 2026 01:13:35 -0500 Subject: [PATCH 054/215] ComfyUI v0.12.3 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 5d296cd1b..706b37763 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.12.2" +__version__ = "0.12.3" diff --git a/pyproject.toml b/pyproject.toml index 1ddcc3596..f7925b92a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.12.2" +version = "0.12.3" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" From 00efcc6cd028206ad81a90dec177c9a470a20a2a Mon Sep 17 00:00:00 2001 From: Comfy Org PR Bot Date: Thu, 5 Feb 2026 15:17:37 +0900 Subject: [PATCH 055/215] Bump comfyui-frontend-package to 1.38.13 (#12238) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 0c401873a..41cc9174b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.37.11 +comfyui-frontend-package==1.38.13 comfyui-workflow-templates==0.8.31 comfyui-embedded-docs==0.4.0 torch From 2b70ab9ad0fd6a38b11546a18c546ce40cc176a1 Mon Sep 17 00:00:00 2001 From: AustinMroz Date: Wed, 4 Feb 2026 22:18:21 -0800 Subject: [PATCH 056/215] Add a Create List node (#12173) --- comfy_extras/nodes_toolkit.py | 47 +++++++++++++++++++++++++++++++++++ nodes.py | 3 ++- 2 files changed, 49 insertions(+), 1 deletion(-) create mode 100644 comfy_extras/nodes_toolkit.py diff --git a/comfy_extras/nodes_toolkit.py b/comfy_extras/nodes_toolkit.py new file mode 100644 index 000000000..71faf7226 --- /dev/null +++ b/comfy_extras/nodes_toolkit.py @@ -0,0 +1,47 @@ +from __future__ import annotations +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io + + +class CreateList(io.ComfyNode): + @classmethod + def define_schema(cls): + template_matchtype = io.MatchType.Template("type") + template_autogrow = io.Autogrow.TemplatePrefix( + input=io.MatchType.Input("input", template=template_matchtype), + prefix="input", + ) + return io.Schema( + node_id="CreateList", + display_name="Create List", + category="logic", + is_input_list=True, + search_aliases=["Image Iterator", "Text Iterator", "Iterator"], + inputs=[io.Autogrow.Input("inputs", template=template_autogrow)], + outputs=[ + io.MatchType.Output( + template=template_matchtype, + is_output_list=True, + display_name="list", + ), + ], + ) + + @classmethod + def execute(cls, inputs: io.Autogrow.Type) -> io.NodeOutput: + output_list = [] + for input in inputs.values(): + output_list += input + return io.NodeOutput(output_list) + + +class ToolkitExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + CreateList, + ] + + +async def comfy_entrypoint() -> ToolkitExtension: + return ToolkitExtension() diff --git a/nodes.py b/nodes.py index e11a8ed80..91de7a9d7 100644 --- a/nodes.py +++ b/nodes.py @@ -2433,7 +2433,8 @@ async def init_builtin_extra_nodes(): "nodes_image_compare.py", "nodes_zimage.py", "nodes_lora_debug.py", - "nodes_color.py" + "nodes_color.py", + "nodes_toolkit.py", ] import_failed = [] From 6555dc65b82c5f072dcad87f0dbccb4fc5f85e6b Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 5 Feb 2026 13:43:45 -0800 Subject: [PATCH 057/215] Make ace step 1.5 work without the llm. (#12311) --- comfy/ldm/ace/ace_step15.py | 72 +++++++++++++++++++++++++++++++++---- comfy/model_base.py | 17 +++------ 2 files changed, 70 insertions(+), 19 deletions(-) diff --git a/comfy/ldm/ace/ace_step15.py b/comfy/ldm/ace/ace_step15.py index f2b130bc1..7fc7f1e8e 100644 --- a/comfy/ldm/ace/ace_step15.py +++ b/comfy/ldm/ace/ace_step15.py @@ -7,6 +7,67 @@ from comfy.ldm.modules.attention import optimized_attention import comfy.model_management from comfy.ldm.flux.layers import timestep_embedding +def get_silence_latent(length, device): + head = torch.tensor([[[ 0.5707, 0.0982, 0.6909, -0.5658, 0.6266, 0.6996, -0.1365, -0.1291, + -0.0776, -0.1171, -0.2743, -0.8422, -0.1168, 1.5539, -4.6936, 0.7436, + -1.1846, -0.2637, 0.6933, -6.7266, 0.0966, -0.1187, -0.3501, -1.1736, + 0.0587, -2.0517, -1.3651, 0.7508, -0.2490, -1.3548, -0.1290, -0.7261, + 1.1132, -0.3249, 0.2337, 0.3004, 0.6605, -0.0298, -0.1989, -0.4041, + 0.2843, -1.0963, -0.5519, 0.2639, -1.0436, -0.1183, 0.0640, 0.4460, + -1.1001, -0.6172, -1.3241, 1.1379, 0.5623, -0.1507, -0.1963, -0.4742, + -2.4697, 0.5302, 0.5381, 0.4636, -0.1782, -0.0687, 1.0333, 0.4202], + [ 0.3040, -0.1367, 0.6200, 0.0665, -0.0642, 0.4655, -0.1187, -0.0440, + 0.2941, -0.2753, 0.0173, -0.2421, -0.0147, 1.5603, -2.7025, 0.7907, + -0.9736, -0.0682, 0.1294, -5.0707, -0.2167, 0.3302, -0.1513, -0.8100, + -0.3894, -0.2884, -0.3149, 0.8660, -0.3817, -1.7061, 0.5824, -0.4840, + 0.6938, 0.1859, 0.1753, 0.3081, 0.0195, 0.1403, -0.0754, -0.2091, + 0.1251, -0.1578, -0.4968, -0.1052, -0.4554, -0.0320, 0.1284, 0.4974, + -1.1889, -0.0344, -0.8313, 0.2953, 0.5445, -0.6249, -0.1595, -0.0682, + -3.1412, 0.0484, 0.4153, 0.8260, -0.1526, -0.0625, 0.5366, 0.8473], + [ 5.3524e-02, -1.7534e-01, 5.4443e-01, -4.3501e-01, -2.1317e-03, + 3.7200e-01, -4.0143e-03, -1.5516e-01, -1.2968e-01, -1.5375e-01, + -7.7107e-02, -2.0593e-01, -3.2780e-01, 1.5142e+00, -2.6101e+00, + 5.8698e-01, -1.2716e+00, -2.4773e-01, -2.7933e-02, -5.0799e+00, + 1.1601e-01, 4.0987e-01, -2.2030e-02, -6.6495e-01, -2.0995e-01, + -6.3474e-01, -1.5893e-01, 8.2745e-01, -2.2992e-01, -1.6816e+00, + 5.4440e-01, -4.9579e-01, 5.5128e-01, 3.0477e-01, 8.3052e-02, + -6.1782e-02, 5.9036e-03, 2.9553e-01, -8.0645e-02, -1.0060e-01, + 1.9144e-01, -3.8124e-01, -7.2949e-01, 2.4520e-02, -5.0814e-01, + 2.3977e-01, 9.2943e-02, 3.9256e-01, -1.1993e+00, -3.2752e-01, + -7.2707e-01, 2.9476e-01, 4.3542e-01, -8.8597e-01, -4.1686e-01, + -8.5390e-02, -2.9018e+00, 6.4988e-02, 5.3945e-01, 9.1988e-01, + 5.8762e-02, -7.0098e-02, 6.4772e-01, 8.9118e-01], + [-3.2225e-02, -1.3195e-01, 5.6411e-01, -5.4766e-01, -5.2170e-03, + 3.1425e-01, -5.4367e-02, -1.9419e-01, -1.3059e-01, -1.3660e-01, + -9.0984e-02, -1.9540e-01, -2.5590e-01, 1.5440e+00, -2.6349e+00, + 6.8273e-01, -1.2532e+00, -1.9810e-01, -2.2793e-02, -5.0506e+00, + 1.8818e-01, 5.0109e-01, 7.3546e-03, -6.8771e-01, -3.0676e-01, + -7.3257e-01, -1.6687e-01, 9.2232e-01, -1.8987e-01, -1.7267e+00, + 5.3355e-01, -5.3179e-01, 4.4953e-01, 2.8820e-01, 1.3012e-01, + -2.0943e-01, -1.1348e-01, 3.3929e-01, -1.5069e-01, -1.2919e-01, + 1.8929e-01, -3.6166e-01, -8.0756e-01, 6.6387e-02, -5.8867e-01, + 1.6978e-01, 1.0134e-01, 3.3877e-01, -1.2133e+00, -3.2492e-01, + -8.1237e-01, 3.8101e-01, 4.3765e-01, -8.0596e-01, -4.4531e-01, + -4.7513e-02, -2.9266e+00, 1.1741e-03, 4.5123e-01, 9.3075e-01, + 5.3688e-02, -1.9621e-01, 6.4530e-01, 9.3870e-01]]], device=device).movedim(-1, 1) + + silence_latent = torch.tensor([[[-1.3672e-01, -1.5820e-01, 5.8594e-01, -5.7422e-01, 3.0273e-02, + 2.7930e-01, -2.5940e-03, -2.0703e-01, -1.6113e-01, -1.4746e-01, + -2.7710e-02, -1.8066e-01, -2.9688e-01, 1.6016e+00, -2.6719e+00, + 7.7734e-01, -1.3516e+00, -1.9434e-01, -7.1289e-02, -5.0938e+00, + 2.4316e-01, 4.7266e-01, 4.6387e-02, -6.6406e-01, -2.1973e-01, + -6.7578e-01, -1.5723e-01, 9.5312e-01, -2.0020e-01, -1.7109e+00, + 5.8984e-01, -5.7422e-01, 5.1562e-01, 2.8320e-01, 1.4551e-01, + -1.8750e-01, -5.9814e-02, 3.6719e-01, -1.0059e-01, -1.5723e-01, + 2.0605e-01, -4.3359e-01, -8.2812e-01, 4.5654e-02, -6.6016e-01, + 1.4844e-01, 9.4727e-02, 3.8477e-01, -1.2578e+00, -3.3203e-01, + -8.5547e-01, 4.3359e-01, 4.2383e-01, -8.9453e-01, -5.0391e-01, + -5.6152e-02, -2.9219e+00, -2.4658e-02, 5.0391e-01, 9.8438e-01, + 7.2754e-02, -2.1582e-01, 6.3672e-01, 1.0000e+00]]], device=device).movedim(-1, 1).repeat(1, 1, length) + silence_latent[:, :, :head.shape[-1]] = head + return silence_latent + + def get_layer_class(operations, layer_name): if operations is not None and hasattr(operations, layer_name): return getattr(operations, layer_name) @@ -1040,22 +1101,21 @@ class AceStepConditionGenerationModel(nn.Module): lm_hints = self.detokenizer(lm_hints_5Hz) lm_hints = lm_hints[:, :src_latents.shape[1], :] - if is_covers is None: + if is_covers is None or is_covers is True: src_latents = lm_hints - else: - src_latents = torch.where(is_covers.unsqueeze(-1).unsqueeze(-1) > 0, lm_hints, src_latents) + elif is_covers is False: + src_latents = refer_audio_acoustic_hidden_states_packed context_latents = torch.cat([src_latents, chunk_masks.to(src_latents.dtype)], dim=-1) return encoder_hidden, encoder_mask, context_latents - def forward(self, x, timestep, context, lyric_embed=None, refer_audio=None, audio_codes=None, **kwargs): + def forward(self, x, timestep, context, lyric_embed=None, refer_audio=None, audio_codes=None, is_covers=None, **kwargs): text_attention_mask = None lyric_attention_mask = None refer_audio_order_mask = None attention_mask = None chunk_masks = None - is_covers = None src_latents = None precomputed_lm_hints_25Hz = None lyric_hidden_states = lyric_embed @@ -1067,7 +1127,7 @@ class AceStepConditionGenerationModel(nn.Module): if refer_audio_order_mask is None: refer_audio_order_mask = torch.zeros((x.shape[0],), device=x.device, dtype=torch.long) - if src_latents is None and is_covers is None: + if src_latents is None: src_latents = x if chunk_masks is None: diff --git a/comfy/model_base.py b/comfy/model_base.py index a2a34f191..dcbf12074 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1560,22 +1560,11 @@ class ACEStep15(BaseModel): refer_audio = kwargs.get("reference_audio_timbre_latents", None) if refer_audio is None or len(refer_audio) == 0: - refer_audio = torch.tensor([[[-1.3672e-01, -1.5820e-01, 5.8594e-01, -5.7422e-01, 3.0273e-02, - 2.7930e-01, -2.5940e-03, -2.0703e-01, -1.6113e-01, -1.4746e-01, - -2.7710e-02, -1.8066e-01, -2.9688e-01, 1.6016e+00, -2.6719e+00, - 7.7734e-01, -1.3516e+00, -1.9434e-01, -7.1289e-02, -5.0938e+00, - 2.4316e-01, 4.7266e-01, 4.6387e-02, -6.6406e-01, -2.1973e-01, - -6.7578e-01, -1.5723e-01, 9.5312e-01, -2.0020e-01, -1.7109e+00, - 5.8984e-01, -5.7422e-01, 5.1562e-01, 2.8320e-01, 1.4551e-01, - -1.8750e-01, -5.9814e-02, 3.6719e-01, -1.0059e-01, -1.5723e-01, - 2.0605e-01, -4.3359e-01, -8.2812e-01, 4.5654e-02, -6.6016e-01, - 1.4844e-01, 9.4727e-02, 3.8477e-01, -1.2578e+00, -3.3203e-01, - -8.5547e-01, 4.3359e-01, 4.2383e-01, -8.9453e-01, -5.0391e-01, - -5.6152e-02, -2.9219e+00, -2.4658e-02, 5.0391e-01, 9.8438e-01, - 7.2754e-02, -2.1582e-01, 6.3672e-01, 1.0000e+00]]], device=device).movedim(-1, 1).repeat(1, 1, noise.shape[2]) + refer_audio = comfy.ldm.ace.ace_step15.get_silence_latent(noise.shape[2], device) pass_audio_codes = True else: refer_audio = refer_audio[-1][:, :, :noise.shape[2]] + out['is_covers'] = comfy.conds.CONDConstant(True) pass_audio_codes = False if pass_audio_codes: @@ -1583,6 +1572,8 @@ class ACEStep15(BaseModel): if audio_codes is not None: out['audio_codes'] = comfy.conds.CONDRegular(torch.tensor(audio_codes, device=device)) refer_audio = refer_audio[:, :, :750] + else: + out['is_covers'] = comfy.conds.CONDConstant(False) out['refer_audio'] = comfy.conds.CONDRegular(refer_audio) return out From 458292fef0077470f5675ba52555e7bb4c28102e Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 5 Feb 2026 16:15:04 -0800 Subject: [PATCH 058/215] Fix some lowvram stuff with ace step 1.5 (#12312) --- comfy/ldm/ace/ace_step15.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/comfy/ldm/ace/ace_step15.py b/comfy/ldm/ace/ace_step15.py index 7fc7f1e8e..69338336d 100644 --- a/comfy/ldm/ace/ace_step15.py +++ b/comfy/ldm/ace/ace_step15.py @@ -738,7 +738,7 @@ class AttentionPooler(nn.Module): def forward(self, x): B, T, P, D = x.shape x = self.embed_tokens(x) - special = self.special_token.expand(B, T, 1, -1) + special = comfy.model_management.cast_to(self.special_token, device=x.device, dtype=x.dtype).expand(B, T, 1, -1) x = torch.cat([special, x], dim=2) x = x.view(B * T, P + 1, D) @@ -789,7 +789,7 @@ class FSQ(nn.Module): self.register_buffer('implicit_codebook', implicit_codebook, persistent=False) def bound(self, z): - levels_minus_1 = (self._levels - 1).to(z.dtype) + levels_minus_1 = (comfy.model_management.cast_to(self._levels, device=z.device, dtype=z.dtype) - 1) scale = 2. / levels_minus_1 bracket = (levels_minus_1 * (torch.tanh(z) + 1) / 2.) + 0.5 @@ -804,8 +804,8 @@ class FSQ(nn.Module): return codes_non_centered.float() * (2. / (self._levels.float() - 1)) - 1. def codes_to_indices(self, zhat): - zhat_normalized = (zhat + 1.) / (2. / (self._levels.to(zhat.dtype) - 1)) - return (zhat_normalized * self._basis.to(zhat.dtype)).sum(dim=-1).round().to(torch.int32) + zhat_normalized = (zhat + 1.) / (2. / (comfy.model_management.cast_to(self._levels, device=zhat.device, dtype=zhat.dtype) - 1)) + return (zhat_normalized * comfy.model_management.cast_to(self._basis, device=zhat.device, dtype=zhat.dtype)).sum(dim=-1).round().to(torch.int32) def forward(self, z): orig_dtype = z.dtype @@ -887,7 +887,7 @@ class ResidualFSQ(nn.Module): x = self.project_in(x) if hasattr(self, 'soft_clamp_input_value'): - sc_val = self.soft_clamp_input_value.to(x.dtype) + sc_val = comfy.model_management.cast_to(self.soft_clamp_input_value, device=x.device, dtype=x.dtype) x = (x / sc_val).tanh() * sc_val quantized_out = torch.tensor(0., device=x.device, dtype=x.dtype) @@ -895,7 +895,7 @@ class ResidualFSQ(nn.Module): all_indices = [] for layer, scale in zip(self.layers, self.scales): - scale = scale.to(residual.dtype) + scale = comfy.model_management.cast_to(scale, device=x.device, dtype=x.dtype) quantized, indices = layer(residual / scale) quantized = quantized * scale From c2d7f07dbf312ef9034c65102f1a45c4a3355c1a Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 5 Feb 2026 16:24:09 -0800 Subject: [PATCH 059/215] Fix issue when using disable_unet_model_creation (#12315) --- comfy/model_base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index dcbf12074..3bb54f59e 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -147,11 +147,11 @@ class BaseModel(torch.nn.Module): self.diffusion_model.to(memory_format=torch.channels_last) logging.debug("using channels last mode for diffusion model") logging.info("model weight dtype {}, manual cast: {}".format(self.get_dtype(), self.manual_cast_dtype)) + comfy.model_management.archive_model_dtypes(self.diffusion_model) + self.model_type = model_type self.model_sampling = model_sampling(model_config, model_type) - comfy.model_management.archive_model_dtypes(self.diffusion_model) - self.adm_channels = unet_config.get("adm_in_channels", None) if self.adm_channels is None: self.adm_channels = 0 From a1c101f861681ff18df5bdb0605e63c1ba9e8a96 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Fri, 6 Feb 2026 07:43:09 +0200 Subject: [PATCH 060/215] EasyCache: Support LTX2 (#12231) --- comfy_extras/nodes_easycache.py | 50 ++++++++++++++++++++++++--------- 1 file changed, 37 insertions(+), 13 deletions(-) diff --git a/comfy_extras/nodes_easycache.py b/comfy_extras/nodes_easycache.py index 90d730df6..51d1e5b9c 100644 --- a/comfy_extras/nodes_easycache.py +++ b/comfy_extras/nodes_easycache.py @@ -9,6 +9,14 @@ if TYPE_CHECKING: from uuid import UUID +def _extract_tensor(data, output_channels): + """Extract tensor from data, handling both single tensors and lists.""" + if isinstance(data, list): + # LTX2 AV tensors: [video, audio] + return data[0][:, :output_channels], data[1][:, :output_channels] + return data[:, :output_channels], None + + def easycache_forward_wrapper(executor, *args, **kwargs): # get values from args transformer_options: dict[str] = args[-1] @@ -17,7 +25,7 @@ def easycache_forward_wrapper(executor, *args, **kwargs): if not transformer_options: transformer_options = args[-2] easycache: EasyCacheHolder = transformer_options["easycache"] - x: torch.Tensor = args[0][:, :easycache.output_channels] + x, ax = _extract_tensor(args[0], easycache.output_channels) sigmas = transformer_options["sigmas"] uuids = transformer_options["uuids"] if sigmas is not None and easycache.is_past_end_timestep(sigmas): @@ -35,7 +43,11 @@ def easycache_forward_wrapper(executor, *args, **kwargs): if easycache.skip_current_step and can_apply_cache_diff: if easycache.verbose: logging.info(f"EasyCache [verbose] - was marked to skip this step by {easycache.first_cond_uuid}. Present uuids: {uuids}") - return easycache.apply_cache_diff(x, uuids) + result = easycache.apply_cache_diff(x, uuids) + if ax is not None: + result_audio = easycache.apply_cache_diff(ax, uuids, is_audio=True) + return [result, result_audio] + return result if easycache.initial_step: easycache.first_cond_uuid = uuids[0] has_first_cond_uuid = easycache.has_first_cond_uuid(uuids) @@ -51,13 +63,18 @@ def easycache_forward_wrapper(executor, *args, **kwargs): logging.info(f"EasyCache [verbose] - skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}") # other conds should also skip this step, and instead use their cached values easycache.skip_current_step = True - return easycache.apply_cache_diff(x, uuids) + result = easycache.apply_cache_diff(x, uuids) + if ax is not None: + result_audio = easycache.apply_cache_diff(ax, uuids, is_audio=True) + return [result, result_audio] + return result else: if easycache.verbose: logging.info(f"EasyCache [verbose] - NOT skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}") easycache.cumulative_change_rate = 0.0 - output: torch.Tensor = executor(*args, **kwargs) + full_output: torch.Tensor = executor(*args, **kwargs) + output, audio_output = _extract_tensor(full_output, easycache.output_channels) if has_first_cond_uuid and easycache.has_output_prev_norm(): output_change = (easycache.subsample(output, uuids, clone=False) - easycache.output_prev_subsampled).flatten().abs().mean() if easycache.verbose: @@ -74,13 +91,15 @@ def easycache_forward_wrapper(executor, *args, **kwargs): logging.info(f"EasyCache [verbose] - output_change_rate: {output_change_rate}") # TODO: allow cache_diff to be offloaded easycache.update_cache_diff(output, next_x_prev, uuids) + if audio_output is not None: + easycache.update_cache_diff(audio_output, ax, uuids, is_audio=True) if has_first_cond_uuid: easycache.x_prev_subsampled = easycache.subsample(next_x_prev, uuids) easycache.output_prev_subsampled = easycache.subsample(output, uuids) easycache.output_prev_norm = output.flatten().abs().mean() if easycache.verbose: logging.info(f"EasyCache [verbose] - x_prev_subsampled: {easycache.x_prev_subsampled.shape}") - return output + return full_output def lazycache_predict_noise_wrapper(executor, *args, **kwargs): # get values from args @@ -89,8 +108,8 @@ def lazycache_predict_noise_wrapper(executor, *args, **kwargs): easycache: LazyCacheHolder = model_options["transformer_options"]["easycache"] if easycache.is_past_end_timestep(timestep): return executor(*args, **kwargs) + x: torch.Tensor = _extract_tensor(args[0], easycache.output_channels) # prepare next x_prev - x: torch.Tensor = args[0][:, :easycache.output_channels] next_x_prev = x input_change = None do_easycache = easycache.should_do_easycache(timestep) @@ -197,6 +216,7 @@ class EasyCacheHolder: self.output_prev_subsampled: torch.Tensor = None self.output_prev_norm: torch.Tensor = None self.uuid_cache_diffs: dict[UUID, torch.Tensor] = {} + self.uuid_cache_diffs_audio: dict[UUID, torch.Tensor] = {} self.output_change_rates = [] self.approx_output_change_rates = [] self.total_steps_skipped = 0 @@ -245,20 +265,21 @@ class EasyCacheHolder: def can_apply_cache_diff(self, uuids: list[UUID]) -> bool: return all(uuid in self.uuid_cache_diffs for uuid in uuids) - def apply_cache_diff(self, x: torch.Tensor, uuids: list[UUID]): - if self.first_cond_uuid in uuids: + def apply_cache_diff(self, x: torch.Tensor, uuids: list[UUID], is_audio: bool = False): + if self.first_cond_uuid in uuids and not is_audio: self.total_steps_skipped += 1 + cache_diffs = self.uuid_cache_diffs_audio if is_audio else self.uuid_cache_diffs batch_offset = x.shape[0] // len(uuids) for i, uuid in enumerate(uuids): # slice out only what is relevant to this cond batch_slice = [slice(i*batch_offset,(i+1)*batch_offset)] # if cached dims don't match x dims, cut off excess and hope for the best (cosmos world2video) - if x.shape[1:] != self.uuid_cache_diffs[uuid].shape[1:]: + if x.shape[1:] != cache_diffs[uuid].shape[1:]: if not self.allow_mismatch: raise ValueError(f"Cached dims {self.uuid_cache_diffs[uuid].shape} don't match x dims {x.shape} - this is no good") slicing = [] skip_this_dim = True - for dim_u, dim_x in zip(self.uuid_cache_diffs[uuid].shape, x.shape): + for dim_u, dim_x in zip(cache_diffs[uuid].shape, x.shape): if skip_this_dim: skip_this_dim = False continue @@ -270,10 +291,11 @@ class EasyCacheHolder: else: slicing.append(slice(None)) batch_slice = batch_slice + slicing - x[tuple(batch_slice)] += self.uuid_cache_diffs[uuid].to(x.device) + x[tuple(batch_slice)] += cache_diffs[uuid].to(x.device) return x - def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor, uuids: list[UUID]): + def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor, uuids: list[UUID], is_audio: bool = False): + cache_diffs = self.uuid_cache_diffs_audio if is_audio else self.uuid_cache_diffs # if output dims don't match x dims, cut off excess and hope for the best (cosmos world2video) if output.shape[1:] != x.shape[1:]: if not self.allow_mismatch: @@ -293,7 +315,7 @@ class EasyCacheHolder: diff = output - x batch_offset = diff.shape[0] // len(uuids) for i, uuid in enumerate(uuids): - self.uuid_cache_diffs[uuid] = diff[i*batch_offset:(i+1)*batch_offset, ...] + cache_diffs[uuid] = diff[i*batch_offset:(i+1)*batch_offset, ...] def has_first_cond_uuid(self, uuids: list[UUID]) -> bool: return self.first_cond_uuid in uuids @@ -324,6 +346,8 @@ class EasyCacheHolder: self.output_prev_norm = None del self.uuid_cache_diffs self.uuid_cache_diffs = {} + del self.uuid_cache_diffs_audio + self.uuid_cache_diffs_audio = {} self.total_steps_skipped = 0 self.state_metadata = None return self From eba6c940fd04483fedec6b47bb93fa669e77fe8a Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 6 Feb 2026 16:14:56 -0800 Subject: [PATCH 061/215] Make ace step 1.5 base model work properly with default workflow. (#12337) --- comfy/ldm/ace/ace_step15.py | 5 ++++- comfy/model_base.py | 2 ++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/comfy/ldm/ace/ace_step15.py b/comfy/ldm/ace/ace_step15.py index 69338336d..1d7dc59a8 100644 --- a/comfy/ldm/ace/ace_step15.py +++ b/comfy/ldm/ace/ace_step15.py @@ -1110,7 +1110,7 @@ class AceStepConditionGenerationModel(nn.Module): return encoder_hidden, encoder_mask, context_latents - def forward(self, x, timestep, context, lyric_embed=None, refer_audio=None, audio_codes=None, is_covers=None, **kwargs): + def forward(self, x, timestep, context, lyric_embed=None, refer_audio=None, audio_codes=None, is_covers=None, replace_with_null_embeds=False, **kwargs): text_attention_mask = None lyric_attention_mask = None refer_audio_order_mask = None @@ -1140,6 +1140,9 @@ class AceStepConditionGenerationModel(nn.Module): src_latents, chunk_masks, is_covers, precomputed_lm_hints_25Hz=precomputed_lm_hints_25Hz, audio_codes=audio_codes ) + if replace_with_null_embeds: + enc_hidden[:] = self.null_condition_emb.to(enc_hidden) + out = self.decoder(hidden_states=x, timestep=timestep, timestep_r=timestep, diff --git a/comfy/model_base.py b/comfy/model_base.py index 3bb54f59e..3aa345254 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1552,6 +1552,8 @@ class ACEStep15(BaseModel): cross_attn = kwargs.get("cross_attn", None) if cross_attn is not None: + if torch.count_nonzero(cross_attn) == 0: + out['replace_with_null_embeds'] = comfy.conds.CONDConstant(True) out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) conditioning_lyrics = kwargs.get("conditioning_lyrics", None) From a831c19b703693f561e32780248514eeaa9e832e Mon Sep 17 00:00:00 2001 From: asagi4 <130366179+asagi4@users.noreply.github.com> Date: Sat, 7 Feb 2026 02:38:04 +0200 Subject: [PATCH 062/215] Fix return_word_ids=True with Anima tokenizer (#12328) --- comfy/text_encoders/anima.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/text_encoders/anima.py b/comfy/text_encoders/anima.py index b6f58cb25..fcba097cb 100644 --- a/comfy/text_encoders/anima.py +++ b/comfy/text_encoders/anima.py @@ -23,7 +23,7 @@ class AnimaTokenizer: def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): out = {} qwen_ids = self.qwen3_06b.tokenize_with_weights(text, return_word_ids, **kwargs) - out["qwen3_06b"] = [[(token, 1.0) for token, _ in inner_list] for inner_list in qwen_ids] # Set weights to 1.0 + out["qwen3_06b"] = [[(token, 1.0, id) if return_word_ids else (token, 1.0) for token, _, id in inner_list] for inner_list in qwen_ids] # Set weights to 1.0 out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids, **kwargs) return out From 204e65b8dcb2db2014f40e1b4c8def3a00150cde Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 6 Feb 2026 16:48:20 -0800 Subject: [PATCH 063/215] Fix bug with last pr (#12338) --- comfy/text_encoders/anima.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/text_encoders/anima.py b/comfy/text_encoders/anima.py index fcba097cb..d8c5a6f92 100644 --- a/comfy/text_encoders/anima.py +++ b/comfy/text_encoders/anima.py @@ -23,7 +23,7 @@ class AnimaTokenizer: def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): out = {} qwen_ids = self.qwen3_06b.tokenize_with_weights(text, return_word_ids, **kwargs) - out["qwen3_06b"] = [[(token, 1.0, id) if return_word_ids else (token, 1.0) for token, _, id in inner_list] for inner_list in qwen_ids] # Set weights to 1.0 + out["qwen3_06b"] = [[(k[0], 1.0, k[2]) if return_word_ids else (k[0], 1.0) for k in inner_list] for inner_list in qwen_ids] # Set weights to 1.0 out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids, **kwargs) return out From 6a263288427a9998086603db0e7078ebcb56f0c4 Mon Sep 17 00:00:00 2001 From: tdrussell Date: Fri, 6 Feb 2026 19:12:15 -0600 Subject: [PATCH 064/215] Support fp16 for Cosmos-Predict2 and Anima (#12249) --- comfy/ldm/cosmos/predict2.py | 24 +++++++++++++++++------- comfy/supported_models.py | 4 ++-- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/comfy/ldm/cosmos/predict2.py b/comfy/ldm/cosmos/predict2.py index c270e6333..6491e486b 100644 --- a/comfy/ldm/cosmos/predict2.py +++ b/comfy/ldm/cosmos/predict2.py @@ -335,7 +335,7 @@ class FinalLayer(nn.Module): device=None, dtype=None, operations=None ): super().__init__() - self.layer_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.layer_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.linear = operations.Linear( hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False, device=device, dtype=dtype ) @@ -463,6 +463,8 @@ class Block(nn.Module): extra_per_block_pos_emb: Optional[torch.Tensor] = None, transformer_options: Optional[dict] = {}, ) -> torch.Tensor: + residual_dtype = x_B_T_H_W_D.dtype + compute_dtype = emb_B_T_D.dtype if extra_per_block_pos_emb is not None: x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb @@ -512,7 +514,7 @@ class Block(nn.Module): result_B_T_H_W_D = rearrange( self.self_attn( # normalized_x_B_T_HW_D, - rearrange(normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"), + rearrange(normalized_x_B_T_H_W_D.to(compute_dtype), "b t h w d -> b (t h w) d"), None, rope_emb=rope_emb_L_1_1_D, transformer_options=transformer_options, @@ -522,7 +524,7 @@ class Block(nn.Module): h=H, w=W, ) - x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D * result_B_T_H_W_D + x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype) def _x_fn( _x_B_T_H_W_D: torch.Tensor, @@ -536,7 +538,7 @@ class Block(nn.Module): ) _result_B_T_H_W_D = rearrange( self.cross_attn( - rearrange(_normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"), + rearrange(_normalized_x_B_T_H_W_D.to(compute_dtype), "b t h w d -> b (t h w) d"), crossattn_emb, rope_emb=rope_emb_L_1_1_D, transformer_options=transformer_options, @@ -555,7 +557,7 @@ class Block(nn.Module): shift_cross_attn_B_T_1_1_D, transformer_options=transformer_options, ) - x_B_T_H_W_D = result_B_T_H_W_D * gate_cross_attn_B_T_1_1_D + x_B_T_H_W_D + x_B_T_H_W_D = result_B_T_H_W_D.to(residual_dtype) * gate_cross_attn_B_T_1_1_D.to(residual_dtype) + x_B_T_H_W_D normalized_x_B_T_H_W_D = _fn( x_B_T_H_W_D, @@ -563,8 +565,8 @@ class Block(nn.Module): scale_mlp_B_T_1_1_D, shift_mlp_B_T_1_1_D, ) - result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D) - x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D * result_B_T_H_W_D + result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D.to(compute_dtype)) + x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype) return x_B_T_H_W_D @@ -876,6 +878,14 @@ class MiniTrainDIT(nn.Module): "extra_per_block_pos_emb": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, "transformer_options": kwargs.get("transformer_options", {}), } + + # The residual stream for this model has large values. To make fp16 compute_dtype work, we keep the residual stream + # in fp32, but run attention and MLP modules in fp16. + # An alternate method that clamps fp16 values "works" in the sense that it makes coherent images, but there is noticeable + # quality degradation and visual artifacts. + if x_B_T_H_W_D.dtype == torch.float16: + x_B_T_H_W_D = x_B_T_H_W_D.float() + for block in self.blocks: x_B_T_H_W_D = block( x_B_T_H_W_D, diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 77264ed28..56a21b0ef 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -993,7 +993,7 @@ class CosmosT2IPredict2(supported_models_base.BASE): memory_usage_factor = 1.0 - supported_inference_dtypes = [torch.bfloat16, torch.float32] + supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] def __init__(self, unet_config): super().__init__(unet_config) @@ -1023,7 +1023,7 @@ class Anima(supported_models_base.BASE): memory_usage_factor = 1.0 - supported_inference_dtypes = [torch.bfloat16, torch.float32] + supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] def __init__(self, unet_config): super().__init__(unet_config) From 039955c52744909107dc68ade0698a55d81a8886 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 6 Feb 2026 17:14:52 -0800 Subject: [PATCH 065/215] Some fixes to previous pr. (#12339) --- comfy/ldm/cosmos/predict2.py | 2 +- comfy/supported_models.py | 10 ++++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/comfy/ldm/cosmos/predict2.py b/comfy/ldm/cosmos/predict2.py index 6491e486b..2268bff38 100644 --- a/comfy/ldm/cosmos/predict2.py +++ b/comfy/ldm/cosmos/predict2.py @@ -894,6 +894,6 @@ class MiniTrainDIT(nn.Module): **block_kwargs, ) - x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D) + x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D.to(crossattn_emb.dtype), t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D) x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)[:, :, :orig_shape[-3], :orig_shape[-2], :orig_shape[-1]] return x_B_C_Tt_Hp_Wp diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 56a21b0ef..d33db7507 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1025,10 +1025,6 @@ class Anima(supported_models_base.BASE): supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] - def __init__(self, unet_config): - super().__init__(unet_config) - self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.95 - def get_model(self, state_dict, prefix="", device=None): out = model_base.Anima(self, device=device) return out @@ -1038,6 +1034,12 @@ class Anima(supported_models_base.BASE): detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_06b.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.anima.AnimaTokenizer, comfy.text_encoders.anima.te(**detect)) + def set_inference_dtype(self, dtype, manual_cast_dtype, **kwargs): + self.memory_usage_factor = (self.unet_config.get("model_channels", 2048) / 2048) * 0.95 + if dtype is torch.float16: + self.memory_usage_factor *= 1.4 + return super().set_inference_dtype(dtype, manual_cast_dtype, **kwargs) + class CosmosI2VPredict2(CosmosT2IPredict2): unet_config = { "image_model": "cosmos_predict2", From 17e7df43d19bde49efa46a32b89f5153b9cb0ded Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 6 Feb 2026 21:02:11 -0800 Subject: [PATCH 066/215] Pad ace step 1.5 ref audio if not long enough. (#12341) --- comfy/model_base.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/comfy/model_base.py b/comfy/model_base.py index 3aa345254..858789b30 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1577,6 +1577,10 @@ class ACEStep15(BaseModel): else: out['is_covers'] = comfy.conds.CONDConstant(False) + if refer_audio.shape[2] < noise.shape[2]: + pad = comfy.ldm.ace.ace_step15.get_silence_latent(noise.shape[2], device) + refer_audio = torch.cat([refer_audio.to(pad), pad[:, :, refer_audio.shape[2]:]], dim=2) + out['refer_audio'] = comfy.conds.CONDRegular(refer_audio) return out From 5ff4fdedba2e72ffabf2948799a1c656d9002b52 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Sat, 7 Feb 2026 21:25:30 +0200 Subject: [PATCH 067/215] Fix LazyCache (#12344) --- comfy_extras/nodes_easycache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_easycache.py b/comfy_extras/nodes_easycache.py index 51d1e5b9c..b1912392c 100644 --- a/comfy_extras/nodes_easycache.py +++ b/comfy_extras/nodes_easycache.py @@ -108,7 +108,7 @@ def lazycache_predict_noise_wrapper(executor, *args, **kwargs): easycache: LazyCacheHolder = model_options["transformer_options"]["easycache"] if easycache.is_past_end_timestep(timestep): return executor(*args, **kwargs) - x: torch.Tensor = _extract_tensor(args[0], easycache.output_channels) + x: torch.Tensor = args[0][:, :easycache.output_channels] # prepare next x_prev next_x_prev = x input_change = None From 9bf5aa54dbda5b5de36812cfc10b123ae0930283 Mon Sep 17 00:00:00 2001 From: chaObserv <154517000+chaObserv@users.noreply.github.com> Date: Sun, 8 Feb 2026 06:38:51 +0800 Subject: [PATCH 068/215] Add search_aliases to sa-solver and seeds-2 node (#12327) --- comfy_extras/nodes_custom_sampler.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index 8afd13acf..61a234634 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -622,6 +622,7 @@ class SamplerSASolver(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="SamplerSASolver", + search_aliases=["sde"], category="sampling/custom_sampling/samplers", inputs=[ io.Model.Input("model"), @@ -666,6 +667,7 @@ class SamplerSEEDS2(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="SamplerSEEDS2", + search_aliases=["sde", "exp heun"], category="sampling/custom_sampling/samplers", inputs=[ io.Combo.Input("solver_type", options=["phi_1", "phi_2"]), From 3760d74005a6954f54657dc59d9e57fd4c44b3fd Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Sun, 8 Feb 2026 07:34:52 +0800 Subject: [PATCH 069/215] chore: update embedded docs to v0.4.1 (#12346) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 41cc9174b..5e34a2a49 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ comfyui-frontend-package==1.38.13 comfyui-workflow-templates==0.8.31 -comfyui-embedded-docs==0.4.0 +comfyui-embedded-docs==0.4.1 torch torchsde torchvision From f350a842611f4d75da7104c2d2965f45989089b9 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 7 Feb 2026 16:16:28 -0800 Subject: [PATCH 070/215] Disable prompt weights for ltxv2. (#12354) --- comfy/text_encoders/lt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py index 26573fb12..3f87dfd6a 100644 --- a/comfy/text_encoders/lt.py +++ b/comfy/text_encoders/lt.py @@ -25,7 +25,7 @@ def ltxv_te(*args, **kwargs): class Gemma3_12BTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer = tokenizer_data.get("spiece_model", None) - super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data) + super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data) def state_dict(self): return {"spiece_model": self.tokenizer.serialize_model()} From a0302cc6a85dcb950a7308f7a31a224ef54f3d58 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 8 Feb 2026 18:16:40 -0800 Subject: [PATCH 071/215] Make tonemap latent work on any dim latents. (#12363) --- comfy_extras/nodes_latent.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/comfy_extras/nodes_latent.py b/comfy_extras/nodes_latent.py index 6aecf1561..8d2d7297a 100644 --- a/comfy_extras/nodes_latent.py +++ b/comfy_extras/nodes_latent.py @@ -391,8 +391,9 @@ class LatentOperationTonemapReinhard(io.ComfyNode): latent_vector_magnitude = (torch.linalg.vector_norm(latent, dim=(1)) + 0.0000000001)[:,None] normalized_latent = latent / latent_vector_magnitude - mean = torch.mean(latent_vector_magnitude, dim=(1,2,3), keepdim=True) - std = torch.std(latent_vector_magnitude, dim=(1,2,3), keepdim=True) + dims = list(range(1, latent_vector_magnitude.ndim)) + mean = torch.mean(latent_vector_magnitude, dim=dims, keepdim=True) + std = torch.std(latent_vector_magnitude, dim=dims, keepdim=True) top = (std * 5 + mean) * multiplier From 62315fbb15861e64b917d0a072dad5dc9a15173c Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Mon, 9 Feb 2026 13:16:08 -0800 Subject: [PATCH 072/215] Dynamic VRAM fixes - Ace 1.5 performance + a VRAM leak (#12368) * revert threaded model loader change This change was only needed to get around the pytorch 2.7 mempool bugs, and should have been reverted along with #12260. This fixes a different memory leak where pytorch gets confused about cache emptying. * load non comfy weights * MPDynamic: Pre-generate the tensors for vbars Apparently this is an expensive operation that slows down things. * bump to aimdo 1.8 New features: watermark limit feature logging enhancements -O2 build on linux --- comfy/model_management.py | 37 ++++++------------------------------- comfy/model_patcher.py | 7 ++++++- comfy/ops.py | 2 +- execution.py | 7 ++++++- requirements.txt | 2 +- 5 files changed, 20 insertions(+), 35 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index b6291f340..6018c1ab6 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -19,7 +19,7 @@ import psutil import logging from enum import Enum -from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram +from comfy.cli_args import args, PerformanceFeature import threading import torch import sys @@ -651,7 +651,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_ soft_empty_cache() return unloaded_models -def load_models_gpu_orig(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False): +def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False): cleanup_models_gc() global vram_state @@ -747,26 +747,6 @@ def load_models_gpu_orig(models, memory_required=0, force_patch_weights=False, m current_loaded_models.insert(0, loaded_model) return -def load_models_gpu_thread(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load): - with torch.inference_mode(): - load_models_gpu_orig(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load) - soft_empty_cache() - -def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False): - #Deliberately load models outside of the Aimdo mempool so they can be retained accross - #nodes. Use a dummy thread to do it as pytorch documents that mempool contexts are - #thread local. So exploit that to escape context - if enables_dynamic_vram(): - t = threading.Thread( - target=load_models_gpu_thread, - args=(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load) - ) - t.start() - t.join() - else: - load_models_gpu_orig(models, memory_required=memory_required, force_patch_weights=force_patch_weights, - minimum_memory_required=minimum_memory_required, force_full_load=force_full_load) - def load_model_gpu(model): return load_models_gpu([model]) @@ -1226,21 +1206,16 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str if dtype is None: dtype = weight._model_dtype - r = torch.empty_like(weight, dtype=dtype, device=device) - signature = comfy_aimdo.model_vbar.vbar_fault(weight._v) if signature is not None: - raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device) - v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, raw_tensor)[0] + v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, weight._v_tensor)[0] if not comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature): weight._v_signature = signature #Send it over v_tensor.copy_(weight, non_blocking=non_blocking) - #always take a deep copy even if _v is good, as we have no reasonable point to unpin - #a non comfy weight - r.copy_(v_tensor) - comfy_aimdo.model_vbar.vbar_unpin(weight._v) - return r + return v_tensor.to(dtype=dtype) + + r = torch.empty_like(weight, dtype=dtype, device=device) if weight.dtype != r.dtype and weight.dtype != weight._model_dtype: #Offloaded casting could skip this, however it would make the quantizations diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index d888dbcfb..b9a117a7c 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -1492,7 +1492,9 @@ class ModelPatcherDynamic(ModelPatcher): if vbar is not None: vbar.prioritize() - #We have way more tools for acceleration on comfy weight offloading, so always + #We force reserve VRAM for the non comfy-weight so we dont have to deal + #with pin and unpin syncrhonization which can be expensive for small weights + #with a high layer rate (e.g. autoregressive LLMs). #prioritize the non-comfy weights (note the order reverse). loading = self._load_list(prio_comfy_cast_weights=True) loading.sort(reverse=True) @@ -1541,6 +1543,7 @@ class ModelPatcherDynamic(ModelPatcher): if vbar is not None and not hasattr(m, "_v"): m._v = vbar.alloc(v_weight_size) + m._v_tensor = comfy_aimdo.torch.aimdo_to_tensor(m._v, device_to) allocated_size += v_weight_size else: @@ -1555,8 +1558,10 @@ class ModelPatcherDynamic(ModelPatcher): weight_size = geometry.numel() * geometry.element_size() if vbar is not None and not hasattr(weight, "_v"): weight._v = vbar.alloc(weight_size) + weight._v_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device_to) weight._model_dtype = model_dtype allocated_size += weight_size + vbar.set_watermark_limit(allocated_size) logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.") diff --git a/comfy/ops.py b/comfy/ops.py index 0f4eca7c7..ea0d70702 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -87,7 +87,7 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu signature = comfy_aimdo.model_vbar.vbar_fault(s._v) if signature is not None: - xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device) + xfer_dest = s._v_tensor resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature) if not resident: diff --git a/execution.py b/execution.py index 3dbab82e6..896862c6b 100644 --- a/execution.py +++ b/execution.py @@ -13,8 +13,11 @@ from contextlib import nullcontext import torch +from comfy.cli_args import args import comfy.memory_management import comfy.model_management +import comfy_aimdo.model_vbar + from latent_preview import set_preview_method import nodes from comfy_execution.caching import ( @@ -527,8 +530,10 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data) finally: if allocator is not None: + if args.verbose == "DEBUG": + comfy_aimdo.model_vbar.vbars_analyze() comfy.model_management.reset_cast_buffers() - torch.cuda.synchronize() + comfy_aimdo.model_vbar.vbars_reset_watermark_limits() if has_pending_tasks: pending_async_nodes[unique_id] = output_data diff --git a/requirements.txt b/requirements.txt index 5e34a2a49..4fda07fde 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,7 +22,7 @@ alembic SQLAlchemy av>=14.2.0 comfy-kitchen>=0.2.7 -comfy-aimdo>=0.1.7 +comfy-aimdo>=0.1.8 requests #non essential dependencies: From baf8c874557f1522a99d47d94faad12b0257c8f1 Mon Sep 17 00:00:00 2001 From: blepping <157360029+blepping@users.noreply.github.com> Date: Mon, 9 Feb 2026 17:41:49 -0700 Subject: [PATCH 073/215] Iimprovements to ACE-Steps 1.5 text encoding (part 2) (#12350) --- comfy/text_encoders/ace15.py | 114 +++++++++++++++++++++++++---------- 1 file changed, 81 insertions(+), 33 deletions(-) diff --git a/comfy/text_encoders/ace15.py b/comfy/text_encoders/ace15.py index 00dd5ba90..5dac644c2 100644 --- a/comfy/text_encoders/ace15.py +++ b/comfy/text_encoders/ace15.py @@ -3,6 +3,7 @@ import comfy.text_encoders.llama from comfy import sd1_clip import torch import math +from tqdm.auto import trange import yaml import comfy.utils @@ -23,6 +24,8 @@ def sample_manual_loop_no_classes( audio_end_id: int = 215669, eos_token_id: int = 151645, ): + if ids is None: + return [] device = model.execution_device if execution_dtype is None: @@ -32,6 +35,7 @@ def sample_manual_loop_no_classes( execution_dtype = torch.float32 embeds, attention_mask, num_tokens, embeds_info = model.process_tokens(ids, device) + embeds_batch = embeds.shape[0] for i, t in enumerate(paddings): attention_mask[i, :t] = 0 attention_mask[i, t:] = 1 @@ -41,22 +45,27 @@ def sample_manual_loop_no_classes( generator = torch.Generator(device=device) generator.manual_seed(seed) model_config = model.transformer.model.config + past_kv_shape = [embeds_batch, model_config.num_key_value_heads, embeds.shape[1] + min_tokens, model_config.head_dim] for x in range(model_config.num_hidden_layers): - past_key_values.append((torch.empty([embeds.shape[0], model_config.num_key_value_heads, embeds.shape[1] + min_tokens, model_config.head_dim], device=device, dtype=execution_dtype), torch.empty([embeds.shape[0], model_config.num_key_value_heads, embeds.shape[1] + min_tokens, model_config.head_dim], device=device, dtype=execution_dtype), 0)) + past_key_values.append((torch.empty(past_kv_shape, device=device, dtype=execution_dtype), torch.empty(past_kv_shape, device=device, dtype=execution_dtype), 0)) progress_bar = comfy.utils.ProgressBar(max_new_tokens) - for step in range(max_new_tokens): + for step in trange(max_new_tokens, desc="LM sampling"): outputs = model.transformer(None, attention_mask, embeds=embeds.to(execution_dtype), num_tokens=num_tokens, intermediate_output=None, dtype=execution_dtype, embeds_info=embeds_info, past_key_values=past_key_values) next_token_logits = model.transformer.logits(outputs[0])[:, -1] past_key_values = outputs[2] - cond_logits = next_token_logits[0:1] - uncond_logits = next_token_logits[1:2] - cfg_logits = uncond_logits + cfg_scale * (cond_logits - uncond_logits) + if cfg_scale != 1.0: + cond_logits = next_token_logits[0:1] + uncond_logits = next_token_logits[1:2] + cfg_logits = uncond_logits + cfg_scale * (cond_logits - uncond_logits) + else: + cfg_logits = next_token_logits[0:1] - if eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step: + use_eos_score = eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step + if use_eos_score: eos_score = cfg_logits[:, eos_token_id].clone() remove_logit_value = torch.finfo(cfg_logits.dtype).min @@ -64,7 +73,7 @@ def sample_manual_loop_no_classes( cfg_logits[:, :audio_start_id] = remove_logit_value cfg_logits[:, audio_end_id:] = remove_logit_value - if eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step: + if use_eos_score: cfg_logits[:, eos_token_id] = eos_score if top_k is not None and top_k > 0: @@ -93,8 +102,8 @@ def sample_manual_loop_no_classes( break embed, _, _, _ = model.process_tokens([[token]], device) - embeds = embed.repeat(2, 1, 1) - attention_mask = torch.cat([attention_mask, torch.ones((2, 1), device=device, dtype=attention_mask.dtype)], dim=1) + embeds = embed.repeat(embeds_batch, 1, 1) + attention_mask = torch.cat([attention_mask, torch.ones((embeds_batch, 1), device=device, dtype=attention_mask.dtype)], dim=1) output_audio_codes.append(token - audio_start_id) progress_bar.update_absolute(step) @@ -104,22 +113,29 @@ def sample_manual_loop_no_classes( def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=1024, seed=0, cfg_scale=2.0, temperature=0.85, top_p=0.9, top_k=0): positive = [[token for token, _ in inner_list] for inner_list in positive] - negative = [[token for token, _ in inner_list] for inner_list in negative] positive = positive[0] - negative = negative[0] - neg_pad = 0 - if len(negative) < len(positive): - neg_pad = (len(positive) - len(negative)) - negative = [model.special_tokens["pad"]] * neg_pad + negative + if cfg_scale != 1.0: + negative = [[token for token, _ in inner_list] for inner_list in negative] + negative = negative[0] - pos_pad = 0 - if len(negative) > len(positive): - pos_pad = (len(negative) - len(positive)) - positive = [model.special_tokens["pad"]] * pos_pad + positive + neg_pad = 0 + if len(negative) < len(positive): + neg_pad = (len(positive) - len(negative)) + negative = [model.special_tokens["pad"]] * neg_pad + negative - paddings = [pos_pad, neg_pad] - return sample_manual_loop_no_classes(model, [positive, negative], paddings, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens) + pos_pad = 0 + if len(negative) > len(positive): + pos_pad = (len(negative) - len(positive)) + positive = [model.special_tokens["pad"]] * pos_pad + positive + + paddings = [pos_pad, neg_pad] + ids = [positive, negative] + else: + paddings = [] + ids = [positive] + + return sample_manual_loop_no_classes(model, ids, paddings, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens) class ACE15Tokenizer(sd1_clip.SD1Tokenizer): @@ -129,12 +145,12 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer): def _metas_to_cot(self, *, return_yaml: bool = False, **kwargs) -> str: user_metas = { k: kwargs.pop(k) - for k in ("bpm", "duration", "keyscale", "timesignature", "language", "caption") + for k in ("bpm", "duration", "keyscale", "timesignature", "language") if k in kwargs } timesignature = user_metas.get("timesignature") if isinstance(timesignature, str) and timesignature.endswith("/4"): - user_metas["timesignature"] = timesignature.rsplit("/", 1)[0] + user_metas["timesignature"] = timesignature[:-2] user_metas = { k: v if not isinstance(v, str) or not v.isdigit() else int(v) for k, v in user_metas.items() @@ -147,8 +163,11 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer): return f"\n{meta_yaml}\n" if not return_yaml else meta_yaml def _metas_to_cap(self, **kwargs) -> str: - use_keys = ("bpm", "duration", "keyscale", "timesignature") + use_keys = ("bpm", "timesignature", "keyscale", "duration") user_metas = { k: kwargs.pop(k, "N/A") for k in use_keys } + timesignature = user_metas.get("timesignature") + if isinstance(timesignature, str) and timesignature.endswith("/4"): + user_metas["timesignature"] = timesignature[:-2] duration = user_metas["duration"] if duration == "N/A": user_metas["duration"] = "30 seconds" @@ -159,9 +178,13 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer): return "\n".join(f"- {k}: {user_metas[k]}" for k in use_keys) def tokenize_with_weights(self, text, return_word_ids=False, **kwargs): - out = {} + text = text.strip() + text_negative = kwargs.get("caption_negative", text).strip() lyrics = kwargs.get("lyrics", "") + lyrics_negative = kwargs.get("lyrics_negative", lyrics) duration = kwargs.get("duration", 120) + if isinstance(duration, str): + duration = float(duration.split(None, 1)[0]) language = kwargs.get("language") seed = kwargs.get("seed", 0) @@ -171,21 +194,46 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer): top_p = kwargs.get("top_p", 0.9) top_k = kwargs.get("top_k", 0.0) - duration = math.ceil(duration) kwargs["duration"] = duration + tokens_duration = duration * 5 + min_tokens = int(kwargs.get("min_tokens", tokens_duration)) + max_tokens = int(kwargs.get("max_tokens", tokens_duration)) + + metas_negative = { + k.rsplit("_", 1)[0]: kwargs.pop(k) + for k in ("bpm_negative", "duration_negative", "keyscale_negative", "timesignature_negative", "language_negative", "caption_negative") + if k in kwargs + } + if not kwargs.get("use_negative_caption"): + _ = metas_negative.pop("caption", None) cot_text = self._metas_to_cot(caption = text, **kwargs) + cot_text_negative = "\n" if not metas_negative else self._metas_to_cot(**metas_negative) meta_cap = self._metas_to_cap(**kwargs) - lm_template = "<|im_start|>system\n# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n<|im_end|>\n<|im_start|>user\n# Caption\n{}\n# Lyric\n{}\n<|im_end|>\n<|im_start|>assistant\n{}\n<|im_end|>\n" + lm_template = "<|im_start|>system\n# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n<|im_end|>\n<|im_start|>user\n# Caption\n{}\n\n# Lyric\n{}\n<|im_end|>\n<|im_start|>assistant\n{}\n\n<|im_end|>\n" + lyrics_template = "# Languages\n{}\n\n# Lyric\n{}<|endoftext|><|endoftext|>" + qwen3_06b_template = "# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n# Caption\n{}\n\n# Metas\n{}\n<|endoftext|>\n<|endoftext|>" - out["lm_prompt"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, cot_text), disable_weights=True) - out["lm_prompt_negative"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, "\n"), disable_weights=True) + llm_prompts = { + "lm_prompt": lm_template.format(text, lyrics.strip(), cot_text), + "lm_prompt_negative": lm_template.format(text_negative, lyrics_negative.strip(), cot_text_negative), + "lyrics": lyrics_template.format(language if language is not None else "", lyrics), + "qwen3_06b": qwen3_06b_template.format(text, meta_cap), + } - out["lyrics"] = self.qwen3_06b.tokenize_with_weights("# Languages\n{}\n\n# Lyric\n{}<|endoftext|><|endoftext|>".format(language if language is not None else "", lyrics), return_word_ids, disable_weights=True, **kwargs) - out["qwen3_06b"] = self.qwen3_06b.tokenize_with_weights("# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n# Caption\n{}\n# Metas\n{}\n<|endoftext|>\n<|endoftext|>".format(text, meta_cap), return_word_ids, **kwargs) - out["lm_metadata"] = {"min_tokens": duration * 5, + out = { + prompt_key: self.qwen3_06b.tokenize_with_weights( + prompt, + prompt_key == "qwen3_06b" and return_word_ids, + disable_weights = True, + **kwargs, + ) + for prompt_key, prompt in llm_prompts.items() + } + out["lm_metadata"] = {"min_tokens": min_tokens, + "max_tokens": max_tokens, "seed": seed, "generate_audio_codes": generate_audio_codes, "cfg_scale": cfg_scale, @@ -252,7 +300,7 @@ class ACE15TEModel(torch.nn.Module): lm_metadata = token_weight_pairs["lm_metadata"] if lm_metadata["generate_audio_codes"]: - audio_codes = generate_audio_codes(getattr(self, self.lm_model, self.qwen3_06b), token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["min_tokens"], seed=lm_metadata["seed"], cfg_scale=lm_metadata["cfg_scale"], temperature=lm_metadata["temperature"], top_p=lm_metadata["top_p"], top_k=lm_metadata["top_k"]) + audio_codes = generate_audio_codes(getattr(self, self.lm_model, self.qwen3_06b), token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["max_tokens"], seed=lm_metadata["seed"], cfg_scale=lm_metadata["cfg_scale"], temperature=lm_metadata["temperature"], top_p=lm_metadata["top_p"], top_k=lm_metadata["top_k"]) out["audio_codes"] = [audio_codes] return base_out, None, out From a4be04c5d750cc5d62256f7f86bb5a7c0a78e28d Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 9 Feb 2026 16:45:56 -0800 Subject: [PATCH 074/215] Ace step prompts match now. (#12376) --- comfy/text_encoders/ace15.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy/text_encoders/ace15.py b/comfy/text_encoders/ace15.py index 5dac644c2..73697b3c1 100644 --- a/comfy/text_encoders/ace15.py +++ b/comfy/text_encoders/ace15.py @@ -145,7 +145,7 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer): def _metas_to_cot(self, *, return_yaml: bool = False, **kwargs) -> str: user_metas = { k: kwargs.pop(k) - for k in ("bpm", "duration", "keyscale", "timesignature", "language") + for k in ("bpm", "duration", "keyscale", "timesignature") if k in kwargs } timesignature = user_metas.get("timesignature") @@ -208,8 +208,8 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer): if not kwargs.get("use_negative_caption"): _ = metas_negative.pop("caption", None) - cot_text = self._metas_to_cot(caption = text, **kwargs) - cot_text_negative = "\n" if not metas_negative else self._metas_to_cot(**metas_negative) + cot_text = self._metas_to_cot(caption=text, **kwargs) + cot_text_negative = "\n\n" if not metas_negative else self._metas_to_cot(**metas_negative) meta_cap = self._metas_to_cap(**kwargs) lm_template = "<|im_start|>system\n# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n<|im_end|>\n<|im_start|>user\n# Caption\n{}\n\n# Lyric\n{}\n<|im_end|>\n<|im_start|>assistant\n{}\n\n<|im_end|>\n" From 349a636a2b0f15aba2930b9af905bb805d2fe30b Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Tue, 10 Feb 2026 10:25:34 +0800 Subject: [PATCH 075/215] chore: update workflow templates to v0.8.37 (#12377) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 4fda07fde..4e2773f5d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.38.13 -comfyui-workflow-templates==0.8.31 +comfyui-workflow-templates==0.8.37 comfyui-embedded-docs==0.4.1 torch torchsde From c1b63a7e78b606bc14cd49a02e9338274db28a60 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Tue, 10 Feb 2026 04:58:27 +0200 Subject: [PATCH 076/215] fix(Moonvalley-API-Nodes): adjust "steps" parameter to not raise exception (#12370) --- comfy_api_nodes/nodes_moonvalley.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/comfy_api_nodes/nodes_moonvalley.py b/comfy_api_nodes/nodes_moonvalley.py index 08315fa2b..78a230529 100644 --- a/comfy_api_nodes/nodes_moonvalley.py +++ b/comfy_api_nodes/nodes_moonvalley.py @@ -219,8 +219,8 @@ class MoonvalleyImg2VideoNode(IO.ComfyNode): ), IO.Int.Input( "steps", - default=33, - min=1, + default=80, + min=75, # steps should be greater or equal to cooldown_steps(75) + warmup_steps(0) max=100, step=1, tooltip="Number of denoising steps", @@ -340,8 +340,8 @@ class MoonvalleyVideo2VideoNode(IO.ComfyNode): ), IO.Int.Input( "steps", - default=33, - min=1, + default=60, + min=60, # steps should be greater or equal to cooldown_steps(36) + warmup_steps(24) max=100, step=1, display_mode=IO.NumberDisplay.number, @@ -370,7 +370,7 @@ class MoonvalleyVideo2VideoNode(IO.ComfyNode): video: Input.Video | None = None, control_type: str = "Motion Transfer", motion_intensity: int | None = 100, - steps=33, + steps=60, prompt_adherence=4.5, ) -> IO.NodeOutput: validated_video = validate_video_to_video_input(video) @@ -465,8 +465,8 @@ class MoonvalleyTxt2VideoNode(IO.ComfyNode): ), IO.Int.Input( "steps", - default=33, - min=1, + default=80, + min=75, # steps should be greater or equal to cooldown_steps(75) + warmup_steps(0) max=100, step=1, tooltip="Inference steps", From 8ca842a8edb26006e730e631ec1153cd42f46d3b Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Tue, 10 Feb 2026 19:34:54 +0200 Subject: [PATCH 077/215] feat(api-nodes-Kling): add new models (V3, O3) (#12389) * feat(api-nodes-Kling): add new models (V3, O3) * remove storyboard from VideoToVideo node * added check for total duration of storyboards * fixed other small things * updated display name for nodes * added "fake" seed --- comfy_api_nodes/apis/__init__.py | 8 +- comfy_api_nodes/apis/kling.py | 46 +- comfy_api_nodes/nodes_kling.py | 764 ++++++++++++++++++++++++++++--- 3 files changed, 750 insertions(+), 68 deletions(-) diff --git a/comfy_api_nodes/apis/__init__.py b/comfy_api_nodes/apis/__init__.py index ee2aa1ce6..46a583b5e 100644 --- a/comfy_api_nodes/apis/__init__.py +++ b/comfy_api_nodes/apis/__init__.py @@ -1197,12 +1197,6 @@ class KlingImageGenImageReferenceType(str, Enum): face = 'face' -class KlingImageGenModelName(str, Enum): - kling_v1 = 'kling-v1' - kling_v1_5 = 'kling-v1-5' - kling_v2 = 'kling-v2' - - class KlingImageGenerationsRequest(BaseModel): aspect_ratio: Optional[KlingImageGenAspectRatio] = '16:9' callback_url: Optional[AnyUrl] = Field( @@ -1218,7 +1212,7 @@ class KlingImageGenerationsRequest(BaseModel): 0.5, description='Reference intensity for user-uploaded images', ge=0.0, le=1.0 ) image_reference: Optional[KlingImageGenImageReferenceType] = None - model_name: Optional[KlingImageGenModelName] = 'kling-v1' + model_name: str = Field(...) n: Optional[int] = Field(1, description='Number of generated images', ge=1, le=9) negative_prompt: Optional[str] = Field( None, description='Negative text prompt', max_length=200 diff --git a/comfy_api_nodes/apis/kling.py b/comfy_api_nodes/apis/kling.py index bf54ede3e..9c0446075 100644 --- a/comfy_api_nodes/apis/kling.py +++ b/comfy_api_nodes/apis/kling.py @@ -1,12 +1,22 @@ from pydantic import BaseModel, Field +class MultiPromptEntry(BaseModel): + index: int = Field(...) + prompt: str = Field(...) + duration: str = Field(...) + + class OmniProText2VideoRequest(BaseModel): model_name: str = Field(..., description="kling-video-o1") aspect_ratio: str = Field(..., description="'16:9', '9:16' or '1:1'") duration: str = Field(..., description="'5' or '10'") prompt: str = Field(...) mode: str = Field("pro") + multi_shot: bool | None = Field(None) + multi_prompt: list[MultiPromptEntry] | None = Field(None) + shot_type: str | None = Field(None) + sound: str = Field(..., description="'on' or 'off'") class OmniParamImage(BaseModel): @@ -26,6 +36,10 @@ class OmniProFirstLastFrameRequest(BaseModel): duration: str = Field(..., description="'5' or '10'") prompt: str = Field(...) mode: str = Field("pro") + sound: str | None = Field(None, description="'on' or 'off'") + multi_shot: bool | None = Field(None) + multi_prompt: list[MultiPromptEntry] | None = Field(None) + shot_type: str | None = Field(None) class OmniProReferences2VideoRequest(BaseModel): @@ -38,6 +52,10 @@ class OmniProReferences2VideoRequest(BaseModel): duration: str | None = Field(..., description="From 3 to 10.") prompt: str = Field(...) mode: str = Field("pro") + sound: str | None = Field(None, description="'on' or 'off'") + multi_shot: bool | None = Field(None) + multi_prompt: list[MultiPromptEntry] | None = Field(None) + shot_type: str | None = Field(None) class TaskStatusVideoResult(BaseModel): @@ -54,6 +72,7 @@ class TaskStatusImageResult(BaseModel): class TaskStatusResults(BaseModel): videos: list[TaskStatusVideoResult] | None = Field(None) images: list[TaskStatusImageResult] | None = Field(None) + series_images: list[TaskStatusImageResult] | None = Field(None) class TaskStatusResponseData(BaseModel): @@ -77,31 +96,42 @@ class OmniImageParamImage(BaseModel): class OmniProImageRequest(BaseModel): - model_name: str = Field(..., description="kling-image-o1") - resolution: str = Field(..., description="'1k' or '2k'") + model_name: str = Field(...) + resolution: str = Field(...) aspect_ratio: str | None = Field(...) prompt: str = Field(...) mode: str = Field("pro") n: int | None = Field(1, le=9) image_list: list[OmniImageParamImage] | None = Field(..., max_length=10) + result_type: str | None = Field(None, description="Set to 'series' for series generation") + series_amount: int | None = Field(None, ge=2, le=9, description="Number of images in a series") class TextToVideoWithAudioRequest(BaseModel): - model_name: str = Field(..., description="kling-v2-6") + model_name: str = Field(...) aspect_ratio: str = Field(..., description="'16:9', '9:16' or '1:1'") - duration: str = Field(..., description="'5' or '10'") - prompt: str = Field(...) + duration: str = Field(...) + prompt: str | None = Field(...) + negative_prompt: str | None = Field(None) mode: str = Field("pro") sound: str = Field(..., description="'on' or 'off'") + multi_shot: bool | None = Field(None) + multi_prompt: list[MultiPromptEntry] | None = Field(None) + shot_type: str | None = Field(None) class ImageToVideoWithAudioRequest(BaseModel): - model_name: str = Field(..., description="kling-v2-6") + model_name: str = Field(...) image: str = Field(...) - duration: str = Field(..., description="'5' or '10'") - prompt: str = Field(...) + image_tail: str | None = Field(None) + duration: str = Field(...) + prompt: str | None = Field(...) + negative_prompt: str | None = Field(None) mode: str = Field("pro") sound: str = Field(..., description="'on' or 'off'") + multi_shot: bool | None = Field(None) + multi_prompt: list[MultiPromptEntry] | None = Field(None) + shot_type: str | None = Field(None) class MotionControlRequest(BaseModel): diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 739fe1855..b89c85561 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -38,7 +38,6 @@ from comfy_api_nodes.apis import ( KlingImageGenerationsRequest, KlingImageGenerationsResponse, KlingImageGenImageReferenceType, - KlingImageGenModelName, KlingImageGenAspectRatio, KlingVideoEffectsRequest, KlingVideoEffectsResponse, @@ -52,6 +51,7 @@ from comfy_api_nodes.apis import ( from comfy_api_nodes.apis.kling import ( ImageToVideoWithAudioRequest, MotionControlRequest, + MultiPromptEntry, OmniImageParamImage, OmniParamImage, OmniParamVideo, @@ -71,6 +71,7 @@ from comfy_api_nodes.util import ( sync_op, tensor_to_base64_string, upload_audio_to_comfyapi, + upload_image_to_comfyapi, upload_images_to_comfyapi, upload_video_to_comfyapi, validate_image_aspect_ratio, @@ -80,6 +81,31 @@ from comfy_api_nodes.util import ( validate_video_duration, ) + +def _generate_storyboard_inputs(count: int) -> list: + inputs = [] + for i in range(1, count + 1): + inputs.extend( + [ + IO.String.Input( + f"storyboard_{i}_prompt", + multiline=True, + default="", + tooltip=f"Prompt for storyboard segment {i}. Max 512 characters.", + ), + IO.Int.Input( + f"storyboard_{i}_duration", + default=4, + min=1, + max=15, + display_mode=IO.NumberDisplay.slider, + tooltip=f"Duration for storyboard segment {i} in seconds.", + ), + ] + ) + return inputs + + KLING_API_VERSION = "v1" PATH_TEXT_TO_VIDEO = f"/proxy/kling/{KLING_API_VERSION}/videos/text2video" PATH_IMAGE_TO_VIDEO = f"/proxy/kling/{KLING_API_VERSION}/videos/image2video" @@ -820,20 +846,48 @@ class OmniProTextToVideoNode(IO.ComfyNode): def define_schema(cls) -> IO.Schema: return IO.Schema( node_id="KlingOmniProTextToVideoNode", - display_name="Kling Omni Text to Video (Pro)", + display_name="Kling 3.0 Omni Text to Video", category="api node/video/Kling", description="Use text prompts to generate videos with the latest Kling model.", inputs=[ - IO.Combo.Input("model_name", options=["kling-video-o1"]), + IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]), IO.String.Input( "prompt", multiline=True, tooltip="A text prompt describing the video content. " - "This can include both positive and negative descriptions.", + "This can include both positive and negative descriptions. " + "Ignored when storyboards are enabled.", ), IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]), - IO.Combo.Input("duration", options=[5, 10]), + IO.Int.Input("duration", default=5, min=3, max=15, display_mode=IO.NumberDisplay.slider), IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True), + IO.DynamicCombo.Input( + "storyboards", + options=[ + IO.DynamicCombo.Option("disabled", []), + IO.DynamicCombo.Option("1 storyboard", _generate_storyboard_inputs(1)), + IO.DynamicCombo.Option("2 storyboards", _generate_storyboard_inputs(2)), + IO.DynamicCombo.Option("3 storyboards", _generate_storyboard_inputs(3)), + IO.DynamicCombo.Option("4 storyboards", _generate_storyboard_inputs(4)), + IO.DynamicCombo.Option("5 storyboards", _generate_storyboard_inputs(5)), + IO.DynamicCombo.Option("6 storyboards", _generate_storyboard_inputs(6)), + ], + tooltip="Generate a series of video segments with individual prompts and durations. " + "Ignored for o1 model.", + optional=True, + ), + IO.Boolean.Input("generate_audio", default=False, optional=True), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + optional=True, + ), ], outputs=[ IO.Video.Output(), @@ -845,11 +899,15 @@ class OmniProTextToVideoNode(IO.ComfyNode): ], is_api_node=True, price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"]), + depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution", "model_name", "generate_audio"]), expr=""" ( $mode := (widgets.resolution = "720p") ? "std" : "pro"; - $rates := {"std": 0.084, "pro": 0.112}; + $isV3 := $contains(widgets.model_name, "v3"); + $audio := $isV3 and widgets.generate_audio; + $rates := $audio + ? {"std": 0.112, "pro": 0.14} + : {"std": 0.084, "pro": 0.112}; {"type":"usd","usd": $lookup($rates, $mode) * widgets.duration} ) """, @@ -864,8 +922,45 @@ class OmniProTextToVideoNode(IO.ComfyNode): aspect_ratio: str, duration: int, resolution: str = "1080p", + storyboards: dict | None = None, + generate_audio: bool = False, + seed: int = 0, ) -> IO.NodeOutput: - validate_string(prompt, min_length=1, max_length=2500) + _ = seed + if model_name == "kling-video-o1": + if duration not in (5, 10): + raise ValueError("kling-video-o1 only supports durations of 5 or 10 seconds.") + if generate_audio: + raise ValueError("kling-video-o1 does not support audio generation.") + stories_enabled = storyboards is not None and storyboards["storyboards"] != "disabled" + if stories_enabled and model_name == "kling-video-o1": + raise ValueError("kling-video-o1 does not support storyboards.") + validate_string(prompt, strip_whitespace=True, min_length=0 if stories_enabled else 1, max_length=2500) + + multi_shot = None + multi_prompt_list = None + if stories_enabled: + count = int(storyboards["storyboards"].split()[0]) + multi_shot = True + multi_prompt_list = [] + for i in range(1, count + 1): + sb_prompt = storyboards[f"storyboard_{i}_prompt"] + sb_duration = storyboards[f"storyboard_{i}_duration"] + validate_string(sb_prompt, field_name=f"storyboard_{i}_prompt", min_length=1, max_length=512) + multi_prompt_list.append( + MultiPromptEntry( + index=i, + prompt=sb_prompt, + duration=str(sb_duration), + ) + ) + total_storyboard_duration = sum(int(e.duration) for e in multi_prompt_list) + if total_storyboard_duration != duration: + raise ValueError( + f"Total storyboard duration ({total_storyboard_duration}s) " + f"must equal the global duration ({duration}s)." + ) + response = await sync_op( cls, ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), @@ -876,6 +971,10 @@ class OmniProTextToVideoNode(IO.ComfyNode): aspect_ratio=aspect_ratio, duration=str(duration), mode="pro" if resolution == "1080p" else "std", + multi_shot=multi_shot, + multi_prompt=multi_prompt_list, + shot_type="customize" if multi_shot else None, + sound="on" if generate_audio else "off", ), ) return await finish_omni_video_task(cls, response) @@ -887,24 +986,26 @@ class OmniProFirstLastFrameNode(IO.ComfyNode): def define_schema(cls) -> IO.Schema: return IO.Schema( node_id="KlingOmniProFirstLastFrameNode", - display_name="Kling Omni First-Last-Frame to Video (Pro)", + display_name="Kling 3.0 Omni First-Last-Frame to Video", category="api node/video/Kling", description="Use a start frame, an optional end frame, or reference images with the latest Kling model.", inputs=[ - IO.Combo.Input("model_name", options=["kling-video-o1"]), + IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]), IO.String.Input( "prompt", multiline=True, tooltip="A text prompt describing the video content. " - "This can include both positive and negative descriptions.", + "This can include both positive and negative descriptions. " + "Ignored when storyboards are enabled.", ), - IO.Int.Input("duration", default=5, min=3, max=10, display_mode=IO.NumberDisplay.slider), + IO.Int.Input("duration", default=5, min=3, max=15, display_mode=IO.NumberDisplay.slider), IO.Image.Input("first_frame"), IO.Image.Input( "end_frame", optional=True, tooltip="An optional end frame for the video. " - "This cannot be used simultaneously with 'reference_images'.", + "This cannot be used simultaneously with 'reference_images'. " + "Does not work with storyboards.", ), IO.Image.Input( "reference_images", @@ -912,6 +1013,38 @@ class OmniProFirstLastFrameNode(IO.ComfyNode): tooltip="Up to 6 additional reference images.", ), IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True), + IO.DynamicCombo.Input( + "storyboards", + options=[ + IO.DynamicCombo.Option("disabled", []), + IO.DynamicCombo.Option("1 storyboard", _generate_storyboard_inputs(1)), + IO.DynamicCombo.Option("2 storyboards", _generate_storyboard_inputs(2)), + IO.DynamicCombo.Option("3 storyboards", _generate_storyboard_inputs(3)), + IO.DynamicCombo.Option("4 storyboards", _generate_storyboard_inputs(4)), + IO.DynamicCombo.Option("5 storyboards", _generate_storyboard_inputs(5)), + IO.DynamicCombo.Option("6 storyboards", _generate_storyboard_inputs(6)), + ], + tooltip="Generate a series of video segments with individual prompts and durations. " + "Only supported for kling-v3-omni.", + optional=True, + ), + IO.Boolean.Input( + "generate_audio", + default=False, + optional=True, + tooltip="Generate audio for the video. Only supported for kling-v3-omni.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + optional=True, + ), ], outputs=[ IO.Video.Output(), @@ -923,11 +1056,15 @@ class OmniProFirstLastFrameNode(IO.ComfyNode): ], is_api_node=True, price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"]), + depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution", "model_name", "generate_audio"]), expr=""" ( $mode := (widgets.resolution = "720p") ? "std" : "pro"; - $rates := {"std": 0.084, "pro": 0.112}; + $isV3 := $contains(widgets.model_name, "v3"); + $audio := $isV3 and widgets.generate_audio; + $rates := $audio + ? {"std": 0.112, "pro": 0.14} + : {"std": 0.084, "pro": 0.112}; {"type":"usd","usd": $lookup($rates, $mode) * widgets.duration} ) """, @@ -944,15 +1081,59 @@ class OmniProFirstLastFrameNode(IO.ComfyNode): end_frame: Input.Image | None = None, reference_images: Input.Image | None = None, resolution: str = "1080p", + storyboards: dict | None = None, + generate_audio: bool = False, + seed: int = 0, ) -> IO.NodeOutput: + _ = seed + if model_name == "kling-video-o1": + if duration > 10: + raise ValueError("kling-video-o1 does not support durations greater than 10 seconds.") + if generate_audio: + raise ValueError("kling-video-o1 does not support audio generation.") + stories_enabled = storyboards is not None and storyboards["storyboards"] != "disabled" + if stories_enabled and model_name == "kling-video-o1": + raise ValueError("kling-video-o1 does not support storyboards.") prompt = normalize_omni_prompt_references(prompt) - validate_string(prompt, min_length=1, max_length=2500) + validate_string(prompt, strip_whitespace=True, min_length=0 if stories_enabled else 1, max_length=2500) if end_frame is not None and reference_images is not None: raise ValueError("The 'end_frame' input cannot be used simultaneously with 'reference_images'.") - if duration not in (5, 10) and end_frame is None and reference_images is None: + if end_frame is not None and stories_enabled: + raise ValueError("The 'end_frame' input cannot be used simultaneously with storyboards.") + if ( + model_name == "kling-video-o1" + and duration not in (5, 10) + and end_frame is None + and reference_images is None + ): raise ValueError( "Duration is only supported for 5 or 10 seconds if there is no end frame or reference images." ) + + multi_shot = None + multi_prompt_list = None + if stories_enabled: + count = int(storyboards["storyboards"].split()[0]) + multi_shot = True + multi_prompt_list = [] + for i in range(1, count + 1): + sb_prompt = storyboards[f"storyboard_{i}_prompt"] + sb_duration = storyboards[f"storyboard_{i}_duration"] + validate_string(sb_prompt, field_name=f"storyboard_{i}_prompt", min_length=1, max_length=512) + multi_prompt_list.append( + MultiPromptEntry( + index=i, + prompt=sb_prompt, + duration=str(sb_duration), + ) + ) + total_storyboard_duration = sum(int(e.duration) for e in multi_prompt_list) + if total_storyboard_duration != duration: + raise ValueError( + f"Total storyboard duration ({total_storyboard_duration}s) " + f"must equal the global duration ({duration}s)." + ) + validate_image_dimensions(first_frame, min_width=300, min_height=300) validate_image_aspect_ratio(first_frame, (1, 2.5), (2.5, 1)) image_list: list[OmniParamImage] = [ @@ -988,6 +1169,10 @@ class OmniProFirstLastFrameNode(IO.ComfyNode): duration=str(duration), image_list=image_list, mode="pro" if resolution == "1080p" else "std", + sound="on" if generate_audio else "off", + multi_shot=multi_shot, + multi_prompt=multi_prompt_list, + shot_type="customize" if multi_shot else None, ), ) return await finish_omni_video_task(cls, response) @@ -999,24 +1184,57 @@ class OmniProImageToVideoNode(IO.ComfyNode): def define_schema(cls) -> IO.Schema: return IO.Schema( node_id="KlingOmniProImageToVideoNode", - display_name="Kling Omni Image to Video (Pro)", + display_name="Kling 3.0 Omni Image to Video", category="api node/video/Kling", description="Use up to 7 reference images to generate a video with the latest Kling model.", inputs=[ - IO.Combo.Input("model_name", options=["kling-video-o1"]), + IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]), IO.String.Input( "prompt", multiline=True, tooltip="A text prompt describing the video content. " - "This can include both positive and negative descriptions.", + "This can include both positive and negative descriptions. " + "Ignored when storyboards are enabled.", ), IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]), - IO.Int.Input("duration", default=3, min=3, max=10, display_mode=IO.NumberDisplay.slider), + IO.Int.Input("duration", default=5, min=3, max=15, display_mode=IO.NumberDisplay.slider), IO.Image.Input( "reference_images", tooltip="Up to 7 reference images.", ), IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True), + IO.DynamicCombo.Input( + "storyboards", + options=[ + IO.DynamicCombo.Option("disabled", []), + IO.DynamicCombo.Option("1 storyboard", _generate_storyboard_inputs(1)), + IO.DynamicCombo.Option("2 storyboards", _generate_storyboard_inputs(2)), + IO.DynamicCombo.Option("3 storyboards", _generate_storyboard_inputs(3)), + IO.DynamicCombo.Option("4 storyboards", _generate_storyboard_inputs(4)), + IO.DynamicCombo.Option("5 storyboards", _generate_storyboard_inputs(5)), + IO.DynamicCombo.Option("6 storyboards", _generate_storyboard_inputs(6)), + ], + tooltip="Generate a series of video segments with individual prompts and durations. " + "Only supported for kling-v3-omni.", + optional=True, + ), + IO.Boolean.Input( + "generate_audio", + default=False, + optional=True, + tooltip="Generate audio for the video. Only supported for kling-v3-omni.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + optional=True, + ), ], outputs=[ IO.Video.Output(), @@ -1028,11 +1246,15 @@ class OmniProImageToVideoNode(IO.ComfyNode): ], is_api_node=True, price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"]), + depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution", "model_name", "generate_audio"]), expr=""" ( $mode := (widgets.resolution = "720p") ? "std" : "pro"; - $rates := {"std": 0.084, "pro": 0.112}; + $isV3 := $contains(widgets.model_name, "v3"); + $audio := $isV3 and widgets.generate_audio; + $rates := $audio + ? {"std": 0.112, "pro": 0.14} + : {"std": 0.084, "pro": 0.112}; {"type":"usd","usd": $lookup($rates, $mode) * widgets.duration} ) """, @@ -1048,9 +1270,46 @@ class OmniProImageToVideoNode(IO.ComfyNode): duration: int, reference_images: Input.Image, resolution: str = "1080p", + storyboards: dict | None = None, + generate_audio: bool = False, + seed: int = 0, ) -> IO.NodeOutput: + _ = seed + if model_name == "kling-video-o1": + if duration > 10: + raise ValueError("kling-video-o1 does not support durations greater than 10 seconds.") + if generate_audio: + raise ValueError("kling-video-o1 does not support audio generation.") + stories_enabled = storyboards is not None and storyboards["storyboards"] != "disabled" + if stories_enabled and model_name == "kling-video-o1": + raise ValueError("kling-video-o1 does not support storyboards.") prompt = normalize_omni_prompt_references(prompt) - validate_string(prompt, min_length=1, max_length=2500) + validate_string(prompt, strip_whitespace=True, min_length=0 if stories_enabled else 1, max_length=2500) + + multi_shot = None + multi_prompt_list = None + if stories_enabled: + count = int(storyboards["storyboards"].split()[0]) + multi_shot = True + multi_prompt_list = [] + for i in range(1, count + 1): + sb_prompt = storyboards[f"storyboard_{i}_prompt"] + sb_duration = storyboards[f"storyboard_{i}_duration"] + validate_string(sb_prompt, field_name=f"storyboard_{i}_prompt", min_length=1, max_length=512) + multi_prompt_list.append( + MultiPromptEntry( + index=i, + prompt=sb_prompt, + duration=str(sb_duration), + ) + ) + total_storyboard_duration = sum(int(e.duration) for e in multi_prompt_list) + if total_storyboard_duration != duration: + raise ValueError( + f"Total storyboard duration ({total_storyboard_duration}s) " + f"must equal the global duration ({duration}s)." + ) + if get_number_of_images(reference_images) > 7: raise ValueError("The maximum number of reference images is 7.") for i in reference_images: @@ -1070,6 +1329,10 @@ class OmniProImageToVideoNode(IO.ComfyNode): duration=str(duration), image_list=image_list, mode="pro" if resolution == "1080p" else "std", + sound="on" if generate_audio else "off", + multi_shot=multi_shot, + multi_prompt=multi_prompt_list, + shot_type="customize" if multi_shot else None, ), ) return await finish_omni_video_task(cls, response) @@ -1081,11 +1344,11 @@ class OmniProVideoToVideoNode(IO.ComfyNode): def define_schema(cls) -> IO.Schema: return IO.Schema( node_id="KlingOmniProVideoToVideoNode", - display_name="Kling Omni Video to Video (Pro)", + display_name="Kling 3.0 Omni Video to Video", category="api node/video/Kling", description="Use a video and up to 4 reference images to generate a video with the latest Kling model.", inputs=[ - IO.Combo.Input("model_name", options=["kling-video-o1"]), + IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]), IO.String.Input( "prompt", multiline=True, @@ -1102,6 +1365,17 @@ class OmniProVideoToVideoNode(IO.ComfyNode): optional=True, ), IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + optional=True, + ), ], outputs=[ IO.Video.Output(), @@ -1135,7 +1409,9 @@ class OmniProVideoToVideoNode(IO.ComfyNode): keep_original_sound: bool, reference_images: Input.Image | None = None, resolution: str = "1080p", + seed: int = 0, ) -> IO.NodeOutput: + _ = seed prompt = normalize_omni_prompt_references(prompt) validate_string(prompt, min_length=1, max_length=2500) validate_video_duration(reference_video, min_duration=3.0, max_duration=10.05) @@ -1179,11 +1455,11 @@ class OmniProEditVideoNode(IO.ComfyNode): def define_schema(cls) -> IO.Schema: return IO.Schema( node_id="KlingOmniProEditVideoNode", - display_name="Kling Omni Edit Video (Pro)", + display_name="Kling 3.0 Omni Edit Video", category="api node/video/Kling", description="Edit an existing video with the latest model from Kling.", inputs=[ - IO.Combo.Input("model_name", options=["kling-video-o1"]), + IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]), IO.String.Input( "prompt", multiline=True, @@ -1198,6 +1474,17 @@ class OmniProEditVideoNode(IO.ComfyNode): optional=True, ), IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + optional=True, + ), ], outputs=[ IO.Video.Output(), @@ -1229,7 +1516,9 @@ class OmniProEditVideoNode(IO.ComfyNode): keep_original_sound: bool, reference_images: Input.Image | None = None, resolution: str = "1080p", + seed: int = 0, ) -> IO.NodeOutput: + _ = seed prompt = normalize_omni_prompt_references(prompt) validate_string(prompt, min_length=1, max_length=2500) validate_video_duration(video, min_duration=3.0, max_duration=10.05) @@ -1273,27 +1562,43 @@ class OmniProImageNode(IO.ComfyNode): def define_schema(cls) -> IO.Schema: return IO.Schema( node_id="KlingOmniProImageNode", - display_name="Kling Omni Image (Pro)", + display_name="Kling 3.0 Omni Image", category="api node/image/Kling", description="Create or edit images with the latest model from Kling.", inputs=[ - IO.Combo.Input("model_name", options=["kling-image-o1"]), + IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-image-o1"]), IO.String.Input( "prompt", multiline=True, tooltip="A text prompt describing the image content. " "This can include both positive and negative descriptions.", ), - IO.Combo.Input("resolution", options=["1K", "2K"]), + IO.Combo.Input("resolution", options=["1K", "2K", "4K"]), IO.Combo.Input( "aspect_ratio", options=["16:9", "9:16", "1:1", "4:3", "3:4", "3:2", "2:3", "21:9"], ), + IO.Combo.Input( + "series_amount", + options=["disabled", "2", "3", "4", "5", "6", "7", "8", "9"], + tooltip="Generate a series of images. Not supported for kling-image-o1.", + ), IO.Image.Input( "reference_images", tooltip="Up to 10 additional reference images.", optional=True, ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + optional=True, + ), ], outputs=[ IO.Image.Output(), @@ -1305,7 +1610,16 @@ class OmniProImageNode(IO.ComfyNode): ], is_api_node=True, price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.028}""", + depends_on=IO.PriceBadgeDepends(widgets=["resolution", "series_amount", "model_name"]), + expr=""" + ( + $prices := {"1k": 0.028, "2k": 0.028, "4k": 0.056}; + $base := $lookup($prices, widgets.resolution); + $isO1 := widgets.model_name = "kling-image-o1"; + $mult := ($isO1 or widgets.series_amount = "disabled") ? 1 : $number(widgets.series_amount); + {"type":"usd","usd": $base * $mult} + ) + """, ), ) @@ -1316,8 +1630,13 @@ class OmniProImageNode(IO.ComfyNode): prompt: str, resolution: str, aspect_ratio: str, + series_amount: str = "disabled", reference_images: Input.Image | None = None, + seed: int = 0, ) -> IO.NodeOutput: + _ = seed + if model_name == "kling-image-o1" and resolution == "4K": + raise ValueError("4K resolution is not supported for kling-image-o1 model.") prompt = normalize_omni_prompt_references(prompt) validate_string(prompt, min_length=1, max_length=2500) image_list: list[OmniImageParamImage] = [] @@ -1329,6 +1648,9 @@ class OmniProImageNode(IO.ComfyNode): validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1)) for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"): image_list.append(OmniImageParamImage(image=i)) + use_series = series_amount != "disabled" + if use_series and model_name == "kling-image-o1": + raise ValueError("kling-image-o1 does not support series generation.") response = await sync_op( cls, ApiEndpoint(path="/proxy/kling/v1/images/omni-image", method="POST"), @@ -1339,6 +1661,8 @@ class OmniProImageNode(IO.ComfyNode): resolution=resolution.lower(), aspect_ratio=aspect_ratio, image_list=image_list if image_list else None, + result_type="series" if use_series else None, + series_amount=int(series_amount) if use_series else None, ), ) if response.code: @@ -1351,7 +1675,9 @@ class OmniProImageNode(IO.ComfyNode): response_model=TaskStatusResponse, status_extractor=lambda r: (r.data.task_status if r.data else None), ) - return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.task_result.images[0].url)) + images = final_response.data.task_result.series_images or final_response.data.task_result.images + tensors = [await download_url_to_image_tensor(img.url) for img in images] + return IO.NodeOutput(torch.cat(tensors, dim=0)) class KlingCameraControlT2VNode(IO.ComfyNode): @@ -2119,7 +2445,7 @@ class KlingImageGenerationNode(IO.ComfyNode): def define_schema(cls) -> IO.Schema: return IO.Schema( node_id="KlingImageGenerationNode", - display_name="Kling Image Generation", + display_name="Kling 3.0 Image", category="api node/image/Kling", description="Kling Image Generation Node. Generate an image from a text prompt with an optional reference image.", inputs=[ @@ -2147,11 +2473,7 @@ class KlingImageGenerationNode(IO.ComfyNode): display_mode=IO.NumberDisplay.slider, tooltip="Subject reference similarity", ), - IO.Combo.Input( - "model_name", - options=[i.value for i in KlingImageGenModelName], - default="kling-v2", - ), + IO.Combo.Input("model_name", options=["kling-v3", "kling-v2", "kling-v1-5"]), IO.Combo.Input( "aspect_ratio", options=[i.value for i in KlingImageGenAspectRatio], @@ -2165,6 +2487,17 @@ class KlingImageGenerationNode(IO.ComfyNode): tooltip="Number of generated images", ), IO.Image.Input("image", optional=True), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + optional=True, + ), ], outputs=[ IO.Image.Output(), @@ -2183,7 +2516,7 @@ class KlingImageGenerationNode(IO.ComfyNode): $base := $contains($m,"kling-v1-5") ? (inputs.image.connected ? 0.028 : 0.014) - : ($contains($m,"kling-v1") ? 0.0035 : 0.014); + : $contains($m,"kling-v3") ? 0.028 : 0.014; {"type":"usd","usd": $base * widgets.n} ) """, @@ -2193,7 +2526,7 @@ class KlingImageGenerationNode(IO.ComfyNode): @classmethod async def execute( cls, - model_name: KlingImageGenModelName, + model_name: str, prompt: str, negative_prompt: str, image_type: KlingImageGenImageReferenceType, @@ -2202,17 +2535,11 @@ class KlingImageGenerationNode(IO.ComfyNode): n: int, aspect_ratio: KlingImageGenAspectRatio, image: torch.Tensor | None = None, + seed: int = 0, ) -> IO.NodeOutput: + _ = seed validate_string(prompt, field_name="prompt", min_length=1, max_length=MAX_PROMPT_LENGTH_IMAGE_GEN) validate_string(negative_prompt, field_name="negative_prompt", max_length=MAX_PROMPT_LENGTH_IMAGE_GEN) - - if image is None: - image_type = None - elif model_name == KlingImageGenModelName.kling_v1: - raise ValueError(f"The model {KlingImageGenModelName.kling_v1.value} does not support reference images.") - else: - image = tensor_to_base64_string(image) - task_creation_response = await sync_op( cls, ApiEndpoint(path=PATH_IMAGE_GENERATIONS, method="POST"), @@ -2221,8 +2548,8 @@ class KlingImageGenerationNode(IO.ComfyNode): model_name=model_name, prompt=prompt, negative_prompt=negative_prompt, - image=image, - image_reference=image_type, + image=tensor_to_base64_string(image) if image is not None else None, + image_reference=image_type if image is not None else None, image_fidelity=image_fidelity, human_fidelity=human_fidelity, n=n, @@ -2252,7 +2579,7 @@ class TextToVideoWithAudio(IO.ComfyNode): def define_schema(cls) -> IO.Schema: return IO.Schema( node_id="KlingTextToVideoWithAudio", - display_name="Kling Text to Video with Audio", + display_name="Kling 2.6 Text to Video with Audio", category="api node/video/Kling", inputs=[ IO.Combo.Input("model_name", options=["kling-v2-6"]), @@ -2320,7 +2647,7 @@ class ImageToVideoWithAudio(IO.ComfyNode): def define_schema(cls) -> IO.Schema: return IO.Schema( node_id="KlingImageToVideoWithAudio", - display_name="Kling Image(First Frame) to Video with Audio", + display_name="Kling 2.6 Image(First Frame) to Video with Audio", category="api node/video/Kling", inputs=[ IO.Combo.Input("model_name", options=["kling-v2-6"]), @@ -2478,6 +2805,335 @@ class MotionControl(IO.ComfyNode): return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) +class KlingVideoNode(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="KlingVideoNode", + display_name="Kling 3.0 Video", + category="api node/video/Kling", + description="Generate videos with Kling V3. " + "Supports text-to-video and image-to-video with optional storyboard multi-prompt and audio generation.", + inputs=[ + IO.DynamicCombo.Input( + "multi_shot", + options=[ + IO.DynamicCombo.Option( + "disabled", + [ + IO.String.Input("prompt", multiline=True, default=""), + IO.String.Input("negative_prompt", multiline=True, default=""), + IO.Int.Input( + "duration", + default=5, + min=3, + max=15, + display_mode=IO.NumberDisplay.slider, + ), + ], + ), + IO.DynamicCombo.Option("1 storyboard", _generate_storyboard_inputs(1)), + IO.DynamicCombo.Option("2 storyboards", _generate_storyboard_inputs(2)), + IO.DynamicCombo.Option("3 storyboards", _generate_storyboard_inputs(3)), + IO.DynamicCombo.Option("4 storyboards", _generate_storyboard_inputs(4)), + IO.DynamicCombo.Option("5 storyboards", _generate_storyboard_inputs(5)), + IO.DynamicCombo.Option("6 storyboards", _generate_storyboard_inputs(6)), + ], + tooltip="Generate a series of video segments with individual prompts and durations.", + ), + IO.Boolean.Input("generate_audio", default=True), + IO.DynamicCombo.Input( + "model", + options=[ + IO.DynamicCombo.Option( + "kling-v3", + [ + IO.Combo.Input("resolution", options=["1080p", "720p"]), + IO.Combo.Input( + "aspect_ratio", + options=["16:9", "9:16", "1:1"], + tooltip="Ignored in image-to-video mode.", + ), + ], + ), + ], + tooltip="Model and generation settings.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + IO.Image.Input( + "start_frame", + optional=True, + tooltip="Optional start frame image. When connected, switches to image-to-video mode.", + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends( + widgets=[ + "model.resolution", + "generate_audio", + "multi_shot", + "multi_shot.duration", + "multi_shot.storyboard_1_duration", + "multi_shot.storyboard_2_duration", + "multi_shot.storyboard_3_duration", + "multi_shot.storyboard_4_duration", + "multi_shot.storyboard_5_duration", + "multi_shot.storyboard_6_duration", + ], + ), + expr=""" + ( + $rates := {"1080p": {"off": 0.112, "on": 0.168}, "720p": {"off": 0.084, "on": 0.126}}; + $res := $lookup(widgets, "model.resolution"); + $audio := widgets.generate_audio ? "on" : "off"; + $rate := $lookup($lookup($rates, $res), $audio); + $ms := widgets.multi_shot; + $isSb := $ms != "disabled"; + $n := $isSb ? $number($substring($ms, 0, 1)) : 0; + $d1 := $lookup(widgets, "multi_shot.storyboard_1_duration"); + $d2 := $n >= 2 ? $lookup(widgets, "multi_shot.storyboard_2_duration") : 0; + $d3 := $n >= 3 ? $lookup(widgets, "multi_shot.storyboard_3_duration") : 0; + $d4 := $n >= 4 ? $lookup(widgets, "multi_shot.storyboard_4_duration") : 0; + $d5 := $n >= 5 ? $lookup(widgets, "multi_shot.storyboard_5_duration") : 0; + $d6 := $n >= 6 ? $lookup(widgets, "multi_shot.storyboard_6_duration") : 0; + $dur := $isSb ? $d1 + $d2 + $d3 + $d4 + $d5 + $d6 : $lookup(widgets, "multi_shot.duration"); + {"type":"usd","usd": $rate * $dur} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + multi_shot: dict, + generate_audio: bool, + model: dict, + seed: int, + start_frame: Input.Image | None = None, + ) -> IO.NodeOutput: + _ = seed + mode = "pro" if model["resolution"] == "1080p" else "std" + custom_multi_shot = False + if multi_shot["multi_shot"] == "disabled": + shot_type = None + else: + shot_type = "customize" + custom_multi_shot = True + + multi_prompt_list = None + if shot_type == "customize": + count = int(multi_shot["multi_shot"].split()[0]) + multi_prompt_list = [] + for i in range(1, count + 1): + sb_prompt = multi_shot[f"storyboard_{i}_prompt"] + sb_duration = multi_shot[f"storyboard_{i}_duration"] + validate_string(sb_prompt, field_name=f"storyboard_{i}_prompt", min_length=1, max_length=512) + multi_prompt_list.append( + MultiPromptEntry( + index=i, + prompt=sb_prompt, + duration=str(sb_duration), + ) + ) + duration = sum(int(e.duration) for e in multi_prompt_list) + if duration < 3 or duration > 15: + raise ValueError( + f"Total storyboard duration ({duration}s) must be between 3 and 15 seconds." + ) + else: + duration = multi_shot["duration"] + validate_string(multi_shot["prompt"], min_length=1, max_length=2500) + + if start_frame is not None: + validate_image_dimensions(start_frame, min_width=300, min_height=300) + validate_image_aspect_ratio(start_frame, (1, 2.5), (2.5, 1)) + image_url = await upload_image_to_comfyapi(cls, start_frame, wait_label="Uploading start frame") + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/kling/v1/videos/image2video", method="POST"), + response_model=TaskStatusResponse, + data=ImageToVideoWithAudioRequest( + model_name=model["model"], + image=image_url, + prompt=None if custom_multi_shot else multi_shot["prompt"], + negative_prompt=None if custom_multi_shot else multi_shot["negative_prompt"], + mode=mode, + duration=str(duration), + sound="on" if generate_audio else "off", + multi_shot=True if shot_type else None, + multi_prompt=multi_prompt_list, + shot_type=shot_type, + ), + ) + poll_path = f"/proxy/kling/v1/videos/image2video/{response.data.task_id}" + else: + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/kling/v1/videos/text2video", method="POST"), + response_model=TaskStatusResponse, + data=TextToVideoWithAudioRequest( + model_name=model["model"], + aspect_ratio=model["aspect_ratio"], + prompt=None if custom_multi_shot else multi_shot["prompt"], + negative_prompt=None if custom_multi_shot else multi_shot["negative_prompt"], + mode=mode, + duration=str(duration), + sound="on" if generate_audio else "off", + multi_shot=True if shot_type else None, + multi_prompt=multi_prompt_list, + shot_type=shot_type, + ), + ) + poll_path = f"/proxy/kling/v1/videos/text2video/{response.data.task_id}" + + if response.code: + raise RuntimeError( + f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}" + ) + final_response = await poll_op( + cls, + ApiEndpoint(path=poll_path), + response_model=TaskStatusResponse, + status_extractor=lambda r: (r.data.task_status if r.data else None), + ) + return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) + + +class KlingFirstLastFrameNode(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="KlingFirstLastFrameNode", + display_name="Kling 3.0 First-Last-Frame to Video", + category="api node/video/Kling", + description="Generate videos with Kling V3 using first and last frames.", + inputs=[ + IO.String.Input("prompt", multiline=True, default=""), + IO.Int.Input( + "duration", + default=5, + min=3, + max=15, + display_mode=IO.NumberDisplay.slider, + ), + IO.Image.Input("first_frame"), + IO.Image.Input("end_frame"), + IO.Boolean.Input("generate_audio", default=True), + IO.DynamicCombo.Input( + "model", + options=[ + IO.DynamicCombo.Option( + "kling-v3", + [ + IO.Combo.Input("resolution", options=["1080p", "720p"]), + ], + ), + ], + tooltip="Model and generation settings.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends( + widgets=["model.resolution", "generate_audio", "duration"], + ), + expr=""" + ( + $rates := {"1080p": {"off": 0.112, "on": 0.168}, "720p": {"off": 0.084, "on": 0.126}}; + $res := $lookup(widgets, "model.resolution"); + $audio := widgets.generate_audio ? "on" : "off"; + $rate := $lookup($lookup($rates, $res), $audio); + {"type":"usd","usd": $rate * widgets.duration} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + prompt: str, + duration: int, + first_frame: Input.Image, + end_frame: Input.Image, + generate_audio: bool, + model: dict, + seed: int, + ) -> IO.NodeOutput: + _ = seed + validate_string(prompt, min_length=1, max_length=2500) + validate_image_dimensions(first_frame, min_width=300, min_height=300) + validate_image_aspect_ratio(first_frame, (1, 2.5), (2.5, 1)) + validate_image_dimensions(end_frame, min_width=300, min_height=300) + validate_image_aspect_ratio(end_frame, (1, 2.5), (2.5, 1)) + image_url = await upload_image_to_comfyapi(cls, first_frame, wait_label="Uploading first frame") + image_tail_url = await upload_image_to_comfyapi(cls, end_frame, wait_label="Uploading end frame") + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/kling/v1/videos/image2video", method="POST"), + response_model=TaskStatusResponse, + data=ImageToVideoWithAudioRequest( + model_name=model["model"], + image=image_url, + image_tail=image_tail_url, + prompt=prompt, + mode="pro" if model["resolution"] == "1080p" else "std", + duration=str(duration), + sound="on" if generate_audio else "off", + ), + ) + if response.code: + raise RuntimeError( + f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}" + ) + final_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/kling/v1/videos/image2video/{response.data.task_id}"), + response_model=TaskStatusResponse, + status_extractor=lambda r: (r.data.task_status if r.data else None), + ) + return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) + + class KlingExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: @@ -2504,6 +3160,8 @@ class KlingExtension(ComfyExtension): TextToVideoWithAudio, ImageToVideoWithAudio, MotionControl, + KlingVideoNode, + KlingFirstLastFrameNode, ] From 6615db925c9f84843e29db118852e14b643a1a03 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Wed, 11 Feb 2026 02:24:56 +0800 Subject: [PATCH 078/215] chore: update workflow templates to v0.8.38 (#12394) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 4e2773f5d..7de6a413c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.38.13 -comfyui-workflow-templates==0.8.37 +comfyui-workflow-templates==0.8.38 comfyui-embedded-docs==0.4.1 torch torchsde From 6648ab68bc934a185c90a2a872c87dc64d093751 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 10 Feb 2026 13:26:29 -0500 Subject: [PATCH 079/215] ComfyUI v0.13.0 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 706b37763..cf4e89816 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.12.3" +__version__ = "0.13.0" diff --git a/pyproject.toml b/pyproject.toml index f7925b92a..9dab9a50c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.12.3" +version = "0.13.0" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" From fe053ba5eb34c8abcc5d17a25c114340af1833aa Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Tue, 10 Feb 2026 10:37:17 -0800 Subject: [PATCH 080/215] mp: dont deep-clone objects from model_options (#12382) If there are non-trivial python objects nested in the model_options, this causes all sorts of issues. Traverse lists and dicts so clones can safely overide settings and BYO objects but stop there on the deepclone. --- comfy/model_patcher.py | 3 +-- comfy/utils.py | 18 ++++++++++++++++++ 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index b9a117a7c..19c9031ea 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -19,7 +19,6 @@ from __future__ import annotations import collections -import copy import inspect import logging import math @@ -317,7 +316,7 @@ class ModelPatcher: n.object_patches = self.object_patches.copy() n.weight_wrapper_patches = self.weight_wrapper_patches.copy() - n.model_options = copy.deepcopy(self.model_options) + n.model_options = comfy.utils.deepcopy_list_dict(self.model_options) n.backup = self.backup n.object_patches_backup = self.object_patches_backup n.parent = self diff --git a/comfy/utils.py b/comfy/utils.py index 1337e2205..edd80cebe 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -1376,3 +1376,21 @@ def string_to_seed(data): else: crc >>= 1 return crc ^ 0xFFFFFFFF + +def deepcopy_list_dict(obj, memo=None): + if memo is None: + memo = {} + + obj_id = id(obj) + if obj_id in memo: + return memo[obj_id] + + if isinstance(obj, dict): + res = {deepcopy_list_dict(k, memo): deepcopy_list_dict(v, memo) for k, v in obj.items()} + elif isinstance(obj, list): + res = [deepcopy_list_dict(i, memo) for i in obj] + else: + res = obj + + memo[obj_id] = res + return res From f719f9c06266e7944683009b403e995d4c61d5f0 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Tue, 10 Feb 2026 10:37:46 -0800 Subject: [PATCH 081/215] sd: delay VAE dtype archive until after override (#12388) VAEs have host specific dtype logic that should override the dynamic _model_dtype. Defer the archiving of model dtypes until after. --- comfy/sd.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index bc9407405..f65e7cadd 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -793,8 +793,6 @@ class VAE: self.first_stage_model = AutoencoderKL(**(config['params'])) self.first_stage_model = self.first_stage_model.eval() - model_management.archive_model_dtypes(self.first_stage_model) - if device is None: device = model_management.vae_device() self.device = device @@ -803,6 +801,7 @@ class VAE: dtype = model_management.vae_dtype(self.device, self.working_dtypes) self.vae_dtype = dtype self.first_stage_model.to(self.vae_dtype) + model_management.archive_model_dtypes(self.first_stage_model) self.output_device = model_management.intermediate_device() mp = comfy.model_patcher.CoreModelPatcher From 123a7874a97c4a8b8f06d4b7c2b1a566b8f0d057 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Tue, 10 Feb 2026 10:38:28 -0800 Subject: [PATCH 082/215] ops: Fix vanilla-fp8 loaded lora quality (#12390) This was missing the stochastic rounding required for fp8 downcast to be consistent with model_patcher.patch_weight_to_device. Missed in testing as I spend too much time with quantized tensors and overlooked the simpler ones. --- comfy/ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/ops.py b/comfy/ops.py index ea0d70702..33803b223 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -169,8 +169,8 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu if orig.dtype == dtype and len(fns) == 0: #The layer actually wants our freshly saved QT x = y - else: - y = x + elif update_weight: + y = comfy.float.stochastic_rounding(x, orig.dtype, seed = comfy.utils.string_to_seed(s.seed_key)) if update_weight: orig.copy_(y) for f in fns: From 00fff6019ecf0f4306005579e93cef0cd51a3a1c Mon Sep 17 00:00:00 2001 From: guill Date: Tue, 10 Feb 2026 14:37:14 -0800 Subject: [PATCH 083/215] feat(jobs): add 3d to PREVIEWABLE_MEDIA_TYPES for first-class 3D output support (#12381) Co-authored-by: Jedrzej Kosinski --- comfy_execution/jobs.py | 79 +++++++++++-- tests/execution/test_jobs.py | 208 ++++++++++++++++++++++++++++++++++- 2 files changed, 271 insertions(+), 16 deletions(-) diff --git a/comfy_execution/jobs.py b/comfy_execution/jobs.py index bf091a448..370014fb6 100644 --- a/comfy_execution/jobs.py +++ b/comfy_execution/jobs.py @@ -20,10 +20,60 @@ class JobStatus: # Media types that can be previewed in the frontend -PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio'}) +PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio', '3d'}) # 3D file extensions for preview fallback (no dedicated media_type exists) -THREE_D_EXTENSIONS = frozenset({'.obj', '.fbx', '.gltf', '.glb'}) +THREE_D_EXTENSIONS = frozenset({'.obj', '.fbx', '.gltf', '.glb', '.usdz'}) + + +def has_3d_extension(filename: str) -> bool: + lower = filename.lower() + return any(lower.endswith(ext) for ext in THREE_D_EXTENSIONS) + + +def normalize_output_item(item): + """Normalize a single output list item for the jobs API. + + Returns the normalized item, or None to exclude it. + String items with 3D extensions become {filename, type, subfolder} dicts. + """ + if item is None: + return None + if isinstance(item, str): + if has_3d_extension(item): + return {'filename': item, 'type': 'output', 'subfolder': '', 'mediaType': '3d'} + return None + if isinstance(item, dict): + return item + return None + + +def normalize_outputs(outputs: dict) -> dict: + """Normalize raw node outputs for the jobs API. + + Transforms string 3D filenames into file output dicts and removes + None items. All other items (non-3D strings, dicts, etc.) are + preserved as-is. + """ + normalized = {} + for node_id, node_outputs in outputs.items(): + if not isinstance(node_outputs, dict): + normalized[node_id] = node_outputs + continue + normalized_node = {} + for media_type, items in node_outputs.items(): + if media_type == 'animated' or not isinstance(items, list): + normalized_node[media_type] = items + continue + normalized_items = [] + for item in items: + if item is None: + continue + norm = normalize_output_item(item) + normalized_items.append(norm if norm is not None else item) + normalized_node[media_type] = normalized_items + normalized[node_id] = normalized_node + return normalized def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str]]: @@ -45,9 +95,9 @@ def is_previewable(media_type: str, item: dict) -> bool: Maintains backwards compatibility with existing logic. Priority: - 1. media_type is 'images', 'video', or 'audio' + 1. media_type is 'images', 'video', 'audio', or '3d' 2. format field starts with 'video/' or 'audio/' - 3. filename has a 3D extension (.obj, .fbx, .gltf, .glb) + 3. filename has a 3D extension (.obj, .fbx, .gltf, .glb, .usdz) """ if media_type in PREVIEWABLE_MEDIA_TYPES: return True @@ -139,7 +189,7 @@ def normalize_history_item(prompt_id: str, history_item: dict, include_outputs: }) if include_outputs: - job['outputs'] = outputs + job['outputs'] = normalize_outputs(outputs) job['execution_status'] = status_info job['workflow'] = { 'prompt': prompt, @@ -171,18 +221,23 @@ def get_outputs_summary(outputs: dict) -> tuple[int, Optional[dict]]: continue for item in items: - count += 1 - - if not isinstance(item, dict): + normalized = normalize_output_item(item) + if normalized is None: continue - if preview_output is None and is_previewable(media_type, item): + count += 1 + + if preview_output is not None: + continue + + if isinstance(normalized, dict) and is_previewable(media_type, normalized): enriched = { - **item, + **normalized, 'nodeId': node_id, - 'mediaType': media_type } - if item.get('type') == 'output': + if 'mediaType' not in normalized: + enriched['mediaType'] = media_type + if normalized.get('type') == 'output': preview_output = enriched elif fallback_preview is None: fallback_preview = enriched diff --git a/tests/execution/test_jobs.py b/tests/execution/test_jobs.py index 4d2f9ed36..83c36fe48 100644 --- a/tests/execution/test_jobs.py +++ b/tests/execution/test_jobs.py @@ -5,8 +5,11 @@ from comfy_execution.jobs import ( is_previewable, normalize_queue_item, normalize_history_item, + normalize_output_item, + normalize_outputs, get_outputs_summary, apply_sorting, + has_3d_extension, ) @@ -35,8 +38,8 @@ class TestIsPreviewable: """Unit tests for is_previewable()""" def test_previewable_media_types(self): - """Images, video, audio media types should be previewable.""" - for media_type in ['images', 'video', 'audio']: + """Images, video, audio, 3d media types should be previewable.""" + for media_type in ['images', 'video', 'audio', '3d']: assert is_previewable(media_type, {}) is True def test_non_previewable_media_types(self): @@ -46,7 +49,7 @@ class TestIsPreviewable: def test_3d_extensions_previewable(self): """3D file extensions should be previewable regardless of media_type.""" - for ext in ['.obj', '.fbx', '.gltf', '.glb']: + for ext in ['.obj', '.fbx', '.gltf', '.glb', '.usdz']: item = {'filename': f'model{ext}'} assert is_previewable('files', item) is True @@ -160,7 +163,7 @@ class TestGetOutputsSummary: def test_3d_files_previewable(self): """3D file extensions should be previewable.""" - for ext in ['.obj', '.fbx', '.gltf', '.glb']: + for ext in ['.obj', '.fbx', '.gltf', '.glb', '.usdz']: outputs = { 'node1': { 'files': [{'filename': f'model{ext}', 'type': 'output'}] @@ -192,6 +195,64 @@ class TestGetOutputsSummary: assert preview['mediaType'] == 'images' assert preview['subfolder'] == 'outputs' + def test_string_3d_filename_creates_preview(self): + """String items with 3D extensions should synthesize a preview (Preview3D node output). + Only the .glb counts — nulls and non-file strings are excluded.""" + outputs = { + 'node1': { + 'result': ['preview3d_abc123.glb', None, None] + } + } + count, preview = get_outputs_summary(outputs) + assert count == 1 + assert preview is not None + assert preview['filename'] == 'preview3d_abc123.glb' + assert preview['mediaType'] == '3d' + assert preview['nodeId'] == 'node1' + assert preview['type'] == 'output' + + def test_string_non_3d_filename_no_preview(self): + """String items without 3D extensions should not create a preview.""" + outputs = { + 'node1': { + 'result': ['data.json', None] + } + } + count, preview = get_outputs_summary(outputs) + assert count == 0 + assert preview is None + + def test_string_3d_filename_used_as_fallback(self): + """String 3D preview should be used when no dict items are previewable.""" + outputs = { + 'node1': { + 'latents': [{'filename': 'latent.safetensors'}], + }, + 'node2': { + 'result': ['model.glb', None] + } + } + count, preview = get_outputs_summary(outputs) + assert preview is not None + assert preview['filename'] == 'model.glb' + assert preview['mediaType'] == '3d' + + +class TestHas3DExtension: + """Unit tests for has_3d_extension()""" + + def test_recognized_extensions(self): + for ext in ['.obj', '.fbx', '.gltf', '.glb', '.usdz']: + assert has_3d_extension(f'model{ext}') is True + + def test_case_insensitive(self): + assert has_3d_extension('MODEL.GLB') is True + assert has_3d_extension('Scene.GLTF') is True + + def test_non_3d_extensions(self): + for name in ['photo.png', 'video.mp4', 'data.json', 'model']: + assert has_3d_extension(name) is False + class TestApplySorting: """Unit tests for apply_sorting()""" @@ -395,3 +456,142 @@ class TestNormalizeHistoryItem: 'prompt': {'nodes': {'1': {}}}, 'extra_data': {'create_time': 1234567890, 'client_id': 'abc'}, } + + def test_include_outputs_normalizes_3d_strings(self): + """Detail view should transform string 3D filenames into file output dicts.""" + history_item = { + 'prompt': ( + 5, + 'prompt-3d', + {'nodes': {}}, + {'create_time': 1234567890}, + ['node1'], + ), + 'status': {'status_str': 'success', 'completed': True, 'messages': []}, + 'outputs': { + 'node1': { + 'result': ['preview3d_abc123.glb', None, None] + } + }, + } + job = normalize_history_item('prompt-3d', history_item, include_outputs=True) + + assert job['outputs_count'] == 1 + result_items = job['outputs']['node1']['result'] + assert len(result_items) == 1 + assert result_items[0] == { + 'filename': 'preview3d_abc123.glb', + 'type': 'output', + 'subfolder': '', + 'mediaType': '3d', + } + + def test_include_outputs_preserves_dict_items(self): + """Detail view normalization should pass dict items through unchanged.""" + history_item = { + 'prompt': ( + 5, + 'prompt-img', + {'nodes': {}}, + {'create_time': 1234567890}, + ['node1'], + ), + 'status': {'status_str': 'success', 'completed': True, 'messages': []}, + 'outputs': { + 'node1': { + 'images': [ + {'filename': 'photo.png', 'type': 'output', 'subfolder': ''}, + ] + } + }, + } + job = normalize_history_item('prompt-img', history_item, include_outputs=True) + + assert job['outputs_count'] == 1 + assert job['outputs']['node1']['images'] == [ + {'filename': 'photo.png', 'type': 'output', 'subfolder': ''}, + ] + + +class TestNormalizeOutputItem: + """Unit tests for normalize_output_item()""" + + def test_none_returns_none(self): + assert normalize_output_item(None) is None + + def test_string_3d_extension_synthesizes_dict(self): + result = normalize_output_item('model.glb') + assert result == {'filename': 'model.glb', 'type': 'output', 'subfolder': '', 'mediaType': '3d'} + + def test_string_non_3d_extension_returns_none(self): + assert normalize_output_item('data.json') is None + + def test_string_no_extension_returns_none(self): + assert normalize_output_item('camera_info_string') is None + + def test_dict_passes_through(self): + item = {'filename': 'test.png', 'type': 'output'} + assert normalize_output_item(item) is item + + def test_other_types_return_none(self): + assert normalize_output_item(42) is None + assert normalize_output_item(True) is None + + +class TestNormalizeOutputs: + """Unit tests for normalize_outputs()""" + + def test_empty_outputs(self): + assert normalize_outputs({}) == {} + + def test_dict_items_pass_through(self): + outputs = { + 'node1': { + 'images': [{'filename': 'a.png', 'type': 'output'}], + } + } + result = normalize_outputs(outputs) + assert result == outputs + + def test_3d_string_synthesized(self): + outputs = { + 'node1': { + 'result': ['model.glb', None, None], + } + } + result = normalize_outputs(outputs) + assert result == { + 'node1': { + 'result': [ + {'filename': 'model.glb', 'type': 'output', 'subfolder': '', 'mediaType': '3d'}, + ], + } + } + + def test_animated_key_preserved(self): + outputs = { + 'node1': { + 'images': [{'filename': 'a.png', 'type': 'output'}], + 'animated': [True], + } + } + result = normalize_outputs(outputs) + assert result['node1']['animated'] == [True] + + def test_non_dict_node_outputs_preserved(self): + outputs = {'node1': 'unexpected_value'} + result = normalize_outputs(outputs) + assert result == {'node1': 'unexpected_value'} + + def test_none_items_filtered_but_other_types_preserved(self): + outputs = { + 'node1': { + 'result': ['data.json', None, [1, 2, 3]], + } + } + result = normalize_outputs(outputs) + assert result == { + 'node1': { + 'result': ['data.json', [1, 2, 3]], + } + } From dbe70b6821994ce92d9cf211cc685862d0b6c0ca Mon Sep 17 00:00:00 2001 From: AustinMroz Date: Tue, 10 Feb 2026 14:42:21 -0800 Subject: [PATCH 084/215] Add a VideoSlice node (#12107) * Base TrimVideo implementation * Raise error if as_trimmed call fails * Bigger max start_time, tooltips, and formatting * Count packets unless codec has subframes * Remove incorrect nested decode * Add null check for audio streams * Support non-strict duration * Added strict_duration bool to node definition * Empty commit for approval * Fix duration * Support 5.1 audio layout on save --------- Co-authored-by: Jedrzej Kosinski --- comfy_api/latest/_input/video_types.py | 15 ++ comfy_api/latest/_input_impl/video_types.py | 201 ++++++++++++++------ comfy_extras/nodes_video.py | 51 +++++ 3 files changed, 207 insertions(+), 60 deletions(-) diff --git a/comfy_api/latest/_input/video_types.py b/comfy_api/latest/_input/video_types.py index e634a0311..451e9526e 100644 --- a/comfy_api/latest/_input/video_types.py +++ b/comfy_api/latest/_input/video_types.py @@ -34,6 +34,21 @@ class VideoInput(ABC): """ pass + @abstractmethod + def as_trimmed( + self, + start_time: float | None = None, + duration: float | None = None, + strict_duration: bool = False, + ) -> VideoInput | None: + """ + Create a new VideoInput which is trimmed to have the corresponding start_time and duration + + Returns: + A new VideoInput, or None if the result would have negative duration + """ + pass + def get_stream_source(self) -> Union[str, io.BytesIO]: """ Get a streamable source for the video. This allows processing without diff --git a/comfy_api/latest/_input_impl/video_types.py b/comfy_api/latest/_input_impl/video_types.py index 1405d0b81..3463ed1c9 100644 --- a/comfy_api/latest/_input_impl/video_types.py +++ b/comfy_api/latest/_input_impl/video_types.py @@ -6,6 +6,7 @@ from typing import Optional from .._input import AudioInput, VideoInput import av import io +import itertools import json import numpy as np import math @@ -29,7 +30,6 @@ def container_to_output_format(container_format: str | None) -> str | None: formats = container_format.split(",") return formats[0] - def get_open_write_kwargs( dest: str | io.BytesIO, container_format: str, to_format: str | None ) -> dict: @@ -57,12 +57,14 @@ class VideoFromFile(VideoInput): Class representing video input from a file. """ - def __init__(self, file: str | io.BytesIO): + def __init__(self, file: str | io.BytesIO, *, start_time: float=0, duration: float=0): """ Initialize the VideoFromFile object based off of either a path on disk or a BytesIO object containing the file contents. """ self.__file = file + self.__start_time = start_time + self.__duration = duration def get_stream_source(self) -> str | io.BytesIO: """ @@ -96,6 +98,16 @@ class VideoFromFile(VideoInput): Returns: Duration in seconds """ + raw_duration = self._get_raw_duration() + if self.__start_time < 0: + duration_from_start = min(raw_duration, -self.__start_time) + else: + duration_from_start = raw_duration - self.__start_time + if self.__duration: + return min(self.__duration, duration_from_start) + return duration_from_start + + def _get_raw_duration(self) -> float: if isinstance(self.__file, io.BytesIO): self.__file.seek(0) with av.open(self.__file, mode="r") as container: @@ -113,9 +125,13 @@ class VideoFromFile(VideoInput): if video_stream and video_stream.average_rate: frame_count = 0 container.seek(0) - for packet in container.demux(video_stream): - for _ in packet.decode(): - frame_count += 1 + frame_iterator = ( + container.decode(video_stream) + if video_stream.codec.capabilities & 0x100 + else container.demux(video_stream) + ) + for packet in frame_iterator: + frame_count += 1 if frame_count > 0: return float(frame_count / video_stream.average_rate) @@ -131,36 +147,54 @@ class VideoFromFile(VideoInput): with av.open(self.__file, mode="r") as container: video_stream = self._get_first_video_stream(container) - # 1. Prefer the frames field if available - if video_stream.frames and video_stream.frames > 0: + # 1. Prefer the frames field if available and usable + if ( + video_stream.frames + and video_stream.frames > 0 + and not self.__start_time + and not self.__duration + ): return int(video_stream.frames) # 2. Try to estimate from duration and average_rate using only metadata - if container.duration is not None and video_stream.average_rate: - duration_seconds = float(container.duration / av.time_base) - estimated_frames = int(round(duration_seconds * float(video_stream.average_rate))) - if estimated_frames > 0: - return estimated_frames - if ( getattr(video_stream, "duration", None) is not None and getattr(video_stream, "time_base", None) is not None and video_stream.average_rate ): - duration_seconds = float(video_stream.duration * video_stream.time_base) + raw_duration = float(video_stream.duration * video_stream.time_base) + if self.__start_time < 0: + duration_from_start = min(raw_duration, -self.__start_time) + else: + duration_from_start = raw_duration - self.__start_time + duration_seconds = min(self.__duration, duration_from_start) estimated_frames = int(round(duration_seconds * float(video_stream.average_rate))) if estimated_frames > 0: return estimated_frames # 3. Last resort: decode frames and count them (streaming) - frame_count = 0 - container.seek(0) - for packet in container.demux(video_stream): - for _ in packet.decode(): - frame_count += 1 - - if frame_count == 0: - raise ValueError(f"Could not determine frame count for file '{self.__file}'") + if self.__start_time < 0: + start_time = max(self._get_raw_duration() + self.__start_time, 0) + else: + start_time = self.__start_time + frame_count = 1 + start_pts = int(start_time / video_stream.time_base) + end_pts = int((start_time + self.__duration) / video_stream.time_base) + container.seek(start_pts, stream=video_stream) + frame_iterator = ( + container.decode(video_stream) + if video_stream.codec.capabilities & 0x100 + else container.demux(video_stream) + ) + for frame in frame_iterator: + if frame.pts >= start_pts: + break + else: + raise ValueError(f"Could not determine frame count for file '{self.__file}'\nNo frames exist for start_time {self.__start_time}") + for frame in frame_iterator: + if frame.pts >= end_pts: + break + frame_count += 1 return frame_count def get_frame_rate(self) -> Fraction: @@ -199,9 +233,21 @@ class VideoFromFile(VideoInput): return container.format.name def get_components_internal(self, container: InputContainer) -> VideoComponents: + video_stream = self._get_first_video_stream(container) + if self.__start_time < 0: + start_time = max(self._get_raw_duration() + self.__start_time, 0) + else: + start_time = self.__start_time # Get video frames frames = [] - for frame in container.decode(video=0): + start_pts = int(start_time / video_stream.time_base) + end_pts = int((start_time + self.__duration) / video_stream.time_base) + container.seek(start_pts, stream=video_stream) + for frame in container.decode(video_stream): + if frame.pts < start_pts: + continue + if self.__duration and frame.pts >= end_pts: + break img = frame.to_ndarray(format='rgb24') # shape: (H, W, 3) img = torch.from_numpy(img) / 255.0 # shape: (H, W, 3) frames.append(img) @@ -209,31 +255,44 @@ class VideoFromFile(VideoInput): images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 3, 0, 0) # Get frame rate - video_stream = next(s for s in container.streams if s.type == 'video') - frame_rate = Fraction(video_stream.average_rate) if video_stream and video_stream.average_rate else Fraction(1) + frame_rate = Fraction(video_stream.average_rate) if video_stream.average_rate else Fraction(1) # Get audio if available audio = None - try: - container.seek(0) # Reset the container to the beginning - for stream in container.streams: - if stream.type != 'audio': - continue - assert isinstance(stream, av.AudioStream) - audio_frames = [] - for packet in container.demux(stream): - for frame in packet.decode(): - assert isinstance(frame, av.AudioFrame) - audio_frames.append(frame.to_ndarray()) # shape: (channels, samples) - if len(audio_frames) > 0: - audio_data = np.concatenate(audio_frames, axis=1) # shape: (channels, total_samples) - audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) # shape: (1, channels, total_samples) - audio = AudioInput({ - "waveform": audio_tensor, - "sample_rate": int(stream.sample_rate) if stream.sample_rate else 1, - }) - except StopIteration: - pass # No audio stream + container.seek(start_pts, stream=video_stream) + # Use last stream for consistency + if len(container.streams.audio): + audio_stream = container.streams.audio[-1] + audio_frames = [] + resample = av.audio.resampler.AudioResampler(format='fltp').resample + frames = itertools.chain.from_iterable( + map(resample, container.decode(audio_stream)) + ) + + has_first_frame = False + for frame in frames: + offset_seconds = start_time - frame.pts * audio_stream.time_base + to_skip = int(offset_seconds * audio_stream.sample_rate) + if to_skip < frame.samples: + has_first_frame = True + break + if has_first_frame: + audio_frames.append(frame.to_ndarray()[..., to_skip:]) + + for frame in frames: + if frame.time > start_time + self.__duration: + break + audio_frames.append(frame.to_ndarray()) # shape: (channels, samples) + if len(audio_frames) > 0: + audio_data = np.concatenate(audio_frames, axis=1) # shape: (channels, total_samples) + if self.__duration: + audio_data = audio_data[..., :int(self.__duration * audio_stream.sample_rate)] + + audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) # shape: (1, channels, total_samples) + audio = AudioInput({ + "waveform": audio_tensor, + "sample_rate": int(audio_stream.sample_rate) if audio_stream.sample_rate else 1, + }) metadata = container.metadata return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata) @@ -250,7 +309,7 @@ class VideoFromFile(VideoInput): path: str | io.BytesIO, format: VideoContainer = VideoContainer.AUTO, codec: VideoCodec = VideoCodec.AUTO, - metadata: Optional[dict] = None + metadata: Optional[dict] = None, ): if isinstance(self.__file, io.BytesIO): self.__file.seek(0) # Reset the BytesIO object to the beginning @@ -262,15 +321,14 @@ class VideoFromFile(VideoInput): reuse_streams = False if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None: reuse_streams = False + if self.__start_time or self.__duration: + reuse_streams = False if not reuse_streams: components = self.get_components_internal(container) video = VideoFromComponents(components) return video.save_to( - path, - format=format, - codec=codec, - metadata=metadata + path, format=format, codec=codec, metadata=metadata ) streams = container.streams @@ -304,10 +362,21 @@ class VideoFromFile(VideoInput): output_container.mux(packet) def _get_first_video_stream(self, container: InputContainer): - video_stream = next((s for s in container.streams if s.type == "video"), None) - if video_stream is None: - raise ValueError(f"No video stream found in file '{self.__file}'") - return video_stream + if len(container.streams.video): + return container.streams.video[0] + raise ValueError(f"No video stream found in file '{self.__file}'") + + def as_trimmed( + self, start_time: float = 0, duration: float = 0, strict_duration: bool = True + ) -> VideoInput | None: + trimmed = VideoFromFile( + self.get_stream_source(), + start_time=start_time + self.__start_time, + duration=duration, + ) + if trimmed.get_duration() < duration and strict_duration: + return None + return trimmed class VideoFromComponents(VideoInput): @@ -322,7 +391,7 @@ class VideoFromComponents(VideoInput): return VideoComponents( images=self.__components.images, audio=self.__components.audio, - frame_rate=self.__components.frame_rate + frame_rate=self.__components.frame_rate, ) def save_to( @@ -330,7 +399,7 @@ class VideoFromComponents(VideoInput): path: str, format: VideoContainer = VideoContainer.AUTO, codec: VideoCodec = VideoCodec.AUTO, - metadata: Optional[dict] = None + metadata: Optional[dict] = None, ): if format != VideoContainer.AUTO and format != VideoContainer.MP4: raise ValueError("Only MP4 format is supported for now") @@ -357,7 +426,10 @@ class VideoFromComponents(VideoInput): audio_stream: Optional[av.AudioStream] = None if self.__components.audio: audio_sample_rate = int(self.__components.audio['sample_rate']) - audio_stream = output.add_stream('aac', rate=audio_sample_rate) + waveform = self.__components.audio['waveform'] + waveform = waveform[0, :, :math.ceil((audio_sample_rate / frame_rate) * self.__components.images.shape[0])] + layout = {1: 'mono', 2: 'stereo', 6: '5.1'}.get(waveform.shape[0], 'stereo') + audio_stream = output.add_stream('aac', rate=audio_sample_rate, layout=layout) # Encode video for i, frame in enumerate(self.__components.images): @@ -372,12 +444,21 @@ class VideoFromComponents(VideoInput): output.mux(packet) if audio_stream and self.__components.audio: - waveform = self.__components.audio['waveform'] - waveform = waveform[:, :, :math.ceil((audio_sample_rate / frame_rate) * self.__components.images.shape[0])] - frame = av.AudioFrame.from_ndarray(waveform.movedim(2, 1).reshape(1, -1).float().cpu().numpy(), format='flt', layout='mono' if waveform.shape[1] == 1 else 'stereo') + frame = av.AudioFrame.from_ndarray(waveform.float().cpu().numpy(), format='fltp', layout=layout) frame.sample_rate = audio_sample_rate frame.pts = 0 output.mux(audio_stream.encode(frame)) # Flush encoder output.mux(audio_stream.encode(None)) + + def as_trimmed( + self, + start_time: float | None = None, + duration: float | None = None, + strict_duration: bool = True, + ) -> VideoInput | None: + if self.get_duration() < start_time + duration: + return None + #TODO Consider tracking duration and trimming at time of save? + return VideoFromFile(self.get_stream_source(), start_time=start_time, duration=duration) diff --git a/comfy_extras/nodes_video.py b/comfy_extras/nodes_video.py index ccf7b63d3..cd765a7c1 100644 --- a/comfy_extras/nodes_video.py +++ b/comfy_extras/nodes_video.py @@ -202,6 +202,56 @@ class LoadVideo(io.ComfyNode): return True +class VideoSlice(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="Video Slice", + display_name="Video Slice", + search_aliases=[ + "trim video duration", + "skip first frames", + "frame load cap", + "start time", + ], + category="image/video", + inputs=[ + io.Video.Input("video"), + io.Float.Input( + "start_time", + default=0.0, + max=1e5, + min=-1e5, + step=0.001, + tooltip="Start time in seconds", + ), + io.Float.Input( + "duration", + default=0.0, + min=0.0, + step=0.001, + tooltip="Duration in seconds, or 0 for unlimited duration", + ), + io.Boolean.Input( + "strict_duration", + default=False, + tooltip="If True, when the specified duration is not possible, an error will be raised.", + ), + ], + outputs=[ + io.Video.Output(), + ], + ) + + @classmethod + def execute(cls, video: io.Video.Type, start_time: float, duration: float, strict_duration: bool) -> io.NodeOutput: + trimmed = video.as_trimmed(start_time, duration, strict_duration=strict_duration) + if trimmed is not None: + return io.NodeOutput(trimmed) + raise ValueError( + f"Failed to slice video:\nSource duration: {video.get_duration()}\nStart time: {start_time}\nTarget duration: {duration}" + ) + class VideoExtension(ComfyExtension): @override @@ -212,6 +262,7 @@ class VideoExtension(ComfyExtension): CreateVideo, GetVideoComponents, LoadVideo, + VideoSlice, ] async def comfy_entrypoint() -> VideoExtension: From cdcf4119b3e826bd69fa986772485fb5b44a54cd Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Wed, 11 Feb 2026 10:45:19 +0800 Subject: [PATCH 085/215] [Trainer] training with proper offloading (#12189) * Fix bypass dtype/device moving * Force offloading mode for training * training context var * offloading implementation in training node * fix wrong input type * Support bypass load lora model, correct adapter/offloading handling --- comfy/ldm/flux/math.py | 39 +++++--- comfy/model_management.py | 5 + comfy/sampler_helpers.py | 16 +++- comfy/weight_adapter/bypass.py | 20 ++-- comfy_extras/nodes_train.py | 162 ++++++++++++++++++++++++++++----- 5 files changed, 196 insertions(+), 46 deletions(-) diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py index f9597de5b..5e764bb46 100644 --- a/comfy/ldm/flux/math.py +++ b/comfy/ldm/flux/math.py @@ -29,19 +29,34 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor: return out.to(dtype=torch.float32, device=pos.device) +def _apply_rope1(x: Tensor, freqs_cis: Tensor): + x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2) + + x_out = freqs_cis[..., 0] * x_[..., 0] + x_out.addcmul_(freqs_cis[..., 1], x_[..., 1]) + + return x_out.reshape(*x.shape).type_as(x) + + +def _apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor): + return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis) + + try: import comfy.quant_ops - apply_rope = comfy.quant_ops.ck.apply_rope - apply_rope1 = comfy.quant_ops.ck.apply_rope1 + q_apply_rope = comfy.quant_ops.ck.apply_rope + q_apply_rope1 = comfy.quant_ops.ck.apply_rope1 + def apply_rope(xq, xk, freqs_cis): + if comfy.model_management.in_training: + return _apply_rope(xq, xk, freqs_cis) + else: + return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis) + def apply_rope1(x, freqs_cis): + if comfy.model_management.in_training: + return _apply_rope1(x, freqs_cis) + else: + return q_apply_rope1(x, freqs_cis) except: logging.warning("No comfy kitchen, using old apply_rope functions.") - def apply_rope1(x: Tensor, freqs_cis: Tensor): - x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2) - - x_out = freqs_cis[..., 0] * x_[..., 0] - x_out.addcmul_(freqs_cis[..., 1], x_[..., 1]) - - return x_out.reshape(*x.shape).type_as(x) - - def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor): - return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis) + apply_rope = _apply_rope + apply_rope1 = _apply_rope1 diff --git a/comfy/model_management.py b/comfy/model_management.py index 6018c1ab6..304931eb0 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -55,6 +55,11 @@ cpu_state = CPUState.GPU total_vram = 0 + +# Training Related State +in_training = False + + def get_supported_float8_types(): float8_types = [] try: diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py index 9134e6d71..1f75f2ba7 100644 --- a/comfy/sampler_helpers.py +++ b/comfy/sampler_helpers.py @@ -122,20 +122,26 @@ def estimate_memory(model, noise_shape, conds): minimum_memory_required = model.model.memory_required([noise_shape[0]] + list(noise_shape[1:]), cond_shapes=cond_shapes_min) return memory_required, minimum_memory_required -def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False): +def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False, force_offload=False): executor = comfy.patcher_extension.WrapperExecutor.new_executor( _prepare_sampling, comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING, model_options, is_model_options=True) ) - return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load) + return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load, force_offload=force_offload) -def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False): +def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False, force_offload=False): real_model: BaseModel = None models, inference_memory = get_additional_models(conds, model.model_dtype()) models += get_additional_models_from_model_options(model_options) models += model.get_nested_additional_models() # TODO: does this require inference_memory update? - memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds) - comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory, force_full_load=force_full_load) + if force_offload: # In training + offload enabled, we want to force prepare sampling to trigger partial load + memory_required = 1e20 + minimum_memory_required = None + else: + memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds) + memory_required += inference_memory + minimum_memory_required += inference_memory + comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required, force_full_load=force_full_load) real_model = model.model return real_model, conds, models diff --git a/comfy/weight_adapter/bypass.py b/comfy/weight_adapter/bypass.py index d4aaf98ca..b9d5ec7d9 100644 --- a/comfy/weight_adapter/bypass.py +++ b/comfy/weight_adapter/bypass.py @@ -21,6 +21,7 @@ from typing import Optional, Union import torch import torch.nn as nn +import comfy.model_management from .base import WeightAdapterBase, WeightAdapterTrainBase from comfy.patcher_extension import PatcherInjection @@ -181,18 +182,21 @@ class BypassForwardHook: ) return # Already injected - # Move adapter weights to module's device to avoid CPU-GPU transfer on every forward - device = None + # Move adapter weights to compute device (GPU) + # Use get_torch_device() instead of module.weight.device because + # with offloading, module weights may be on CPU while compute happens on GPU + device = comfy.model_management.get_torch_device() + + # Get dtype from module weight if available dtype = None if hasattr(self.module, "weight") and self.module.weight is not None: - device = self.module.weight.device dtype = self.module.weight.dtype - elif hasattr(self.module, "W_q"): # Quantized layers might use different attr - device = self.module.W_q.device - dtype = self.module.W_q.dtype - if device is not None: - self._move_adapter_weights_to_device(device, dtype) + # Only use dtype if it's a standard float type, not quantized + if dtype is not None and dtype not in (torch.float32, torch.float16, torch.bfloat16): + dtype = None + + self._move_adapter_weights_to_device(device, dtype) self.original_forward = self.module.forward self.module.forward = self._bypass_forward diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py index 024a89391..630eedc9f 100644 --- a/comfy_extras/nodes_train.py +++ b/comfy_extras/nodes_train.py @@ -4,6 +4,7 @@ import os import numpy as np import safetensors import torch +import torch.nn as nn import torch.utils.checkpoint from tqdm.auto import trange from PIL import Image, ImageDraw, ImageFont @@ -27,6 +28,11 @@ class TrainGuider(comfy_extras.nodes_custom_sampler.Guider_Basic): """ CFGGuider with modifications for training specific logic """ + + def __init__(self, *args, offloading=False, **kwargs): + super().__init__(*args, **kwargs) + self.offloading = offloading + def outer_sample( self, noise, @@ -45,9 +51,11 @@ class TrainGuider(comfy_extras.nodes_custom_sampler.Guider_Basic): noise.shape, self.conds, self.model_options, - force_full_load=True, # mirror behavior in TrainLoraNode.execute() to keep model loaded + force_full_load=not self.offloading, + force_offload=self.offloading, ) ) + torch.cuda.empty_cache() device = self.model_patcher.load_device if denoise_mask is not None: @@ -404,16 +412,97 @@ def find_all_highest_child_module_with_forward( return result -def patch(m): +def find_modules_at_depth( + model: nn.Module, depth: int = 1, result=None, current_depth=0, name=None +) -> list[nn.Module]: + """ + Find modules at a specific depth level for gradient checkpointing. + + Args: + model: The model to search + depth: Target depth level (1 = top-level blocks, 2 = their children, etc.) + result: Accumulator for results + current_depth: Current recursion depth + name: Current module name for logging + + Returns: + List of modules at the target depth + """ + if result is None: + result = [] + name = name or "root" + + # Skip container modules (they don't have meaningful forward) + is_container = isinstance(model, (nn.ModuleList, nn.Sequential, nn.ModuleDict)) + has_forward = hasattr(model, "forward") and not is_container + + if has_forward: + current_depth += 1 + if current_depth == depth: + result.append(model) + logging.debug(f"Found module at depth {depth}: {name} ({model.__class__.__name__})") + return result + + # Recurse into children + for next_name, child in model.named_children(): + find_modules_at_depth(child, depth, result, current_depth, f"{name}.{next_name}") + + return result + + +class OffloadCheckpointFunction(torch.autograd.Function): + """ + Gradient checkpointing that works with weight offloading. + + Forward: no_grad -> compute -> weights can be freed + Backward: enable_grad -> recompute -> backward -> weights can be freed + + For single input, single output modules (Linear, Conv*). + """ + + @staticmethod + def forward(ctx, x: torch.Tensor, forward_fn): + ctx.save_for_backward(x) + ctx.forward_fn = forward_fn + with torch.no_grad(): + return forward_fn(x) + + @staticmethod + def backward(ctx, grad_out: torch.Tensor): + x, = ctx.saved_tensors + forward_fn = ctx.forward_fn + + # Clear context early + ctx.forward_fn = None + + with torch.enable_grad(): + x_detached = x.detach().requires_grad_(True) + y = forward_fn(x_detached) + y.backward(grad_out) + grad_x = x_detached.grad + + # Explicit cleanup + del y, x_detached, forward_fn + + return grad_x, None + + +def patch(m, offloading=False): if not hasattr(m, "forward"): return org_forward = m.forward - def fwd(args, kwargs): - return org_forward(*args, **kwargs) + # Branch 1: Linear/Conv* -> offload-compatible checkpoint (single input/output) + if offloading and isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)): + def checkpointing_fwd(x): + return OffloadCheckpointFunction.apply(x, org_forward) + # Branch 2: Others -> standard checkpoint + else: + def fwd(args, kwargs): + return org_forward(*args, **kwargs) - def checkpointing_fwd(*args, **kwargs): - return torch.utils.checkpoint.checkpoint(fwd, args, kwargs, use_reentrant=False) + def checkpointing_fwd(*args, **kwargs): + return torch.utils.checkpoint.checkpoint(fwd, args, kwargs, use_reentrant=False) m.org_forward = org_forward m.forward = checkpointing_fwd @@ -936,6 +1025,18 @@ class TrainLoraNode(io.ComfyNode): default=True, tooltip="Use gradient checkpointing for training.", ), + io.Int.Input( + "checkpoint_depth", + default=1, + min=1, + max=5, + tooltip="Depth level for gradient checkpointing.", + ), + io.Boolean.Input( + "offloading", + default=False, + tooltip="Depth level for gradient checkpointing.", + ), io.Combo.Input( "existing_lora", options=folder_paths.get_filename_list("loras") + ["[None]"], @@ -982,6 +1083,8 @@ class TrainLoraNode(io.ComfyNode): lora_dtype, algorithm, gradient_checkpointing, + checkpoint_depth, + offloading, existing_lora, bucket_mode, bypass_mode, @@ -1000,6 +1103,8 @@ class TrainLoraNode(io.ComfyNode): lora_dtype = lora_dtype[0] algorithm = algorithm[0] gradient_checkpointing = gradient_checkpointing[0] + offloading = offloading[0] + checkpoint_depth = checkpoint_depth[0] existing_lora = existing_lora[0] bucket_mode = bucket_mode[0] bypass_mode = bypass_mode[0] @@ -1054,16 +1159,18 @@ class TrainLoraNode(io.ComfyNode): # Setup gradient checkpointing if gradient_checkpointing: - for m in find_all_highest_child_module_with_forward( - mp.model.diffusion_model - ): - patch(m) + modules_to_patch = find_modules_at_depth( + mp.model.diffusion_model, depth=checkpoint_depth + ) + logging.info(f"Gradient checkpointing: patching {len(modules_to_patch)} modules at depth {checkpoint_depth}") + for m in modules_to_patch: + patch(m, offloading=offloading) torch.cuda.empty_cache() # With force_full_load=False we should be able to have offloading # But for offloading in training we need custom AutoGrad hooks for fwd/bwd comfy.model_management.load_models_gpu( - [mp], memory_required=1e20, force_full_load=True + [mp], memory_required=1e20, force_full_load=not offloading ) torch.cuda.empty_cache() @@ -1100,7 +1207,7 @@ class TrainLoraNode(io.ComfyNode): ) # Setup guider - guider = TrainGuider(mp) + guider = TrainGuider(mp, offloading=offloading) guider.set_conds(positive) # Inject bypass hooks if bypass mode is enabled @@ -1113,6 +1220,7 @@ class TrainLoraNode(io.ComfyNode): # Run training loop try: + comfy.model_management.in_training = True _run_training_loop( guider, train_sampler, @@ -1123,6 +1231,7 @@ class TrainLoraNode(io.ComfyNode): multi_res, ) finally: + comfy.model_management.in_training = False # Eject bypass hooks if they were injected if bypass_injections is not None: for injection in bypass_injections: @@ -1132,19 +1241,20 @@ class TrainLoraNode(io.ComfyNode): unpatch(m) del train_sampler, optimizer - # Finalize adapters + for param in lora_sd: + lora_sd[param] = lora_sd[param].to(lora_dtype).detach() + for adapter in all_weight_adapters: adapter.requires_grad_(False) - - for param in lora_sd: - lora_sd[param] = lora_sd[param].to(lora_dtype) + del adapter + del all_weight_adapters # mp in train node is highly specialized for training # use it in inference will result in bad behavior so we don't return it return io.NodeOutput(lora_sd, loss_map, steps + existing_steps) -class LoraModelLoader(io.ComfyNode):# +class LoraModelLoader(io.ComfyNode): @classmethod def define_schema(cls): return io.Schema( @@ -1166,6 +1276,11 @@ class LoraModelLoader(io.ComfyNode):# max=100.0, tooltip="How strongly to modify the diffusion model. This value can be negative.", ), + io.Boolean.Input( + "bypass", + default=False, + tooltip="When enabled, applies LoRA in bypass mode without modifying base model weights. Useful for training and when model weights are offloaded.", + ), ], outputs=[ io.Model.Output( @@ -1175,13 +1290,18 @@ class LoraModelLoader(io.ComfyNode):# ) @classmethod - def execute(cls, model, lora, strength_model): + def execute(cls, model, lora, strength_model, bypass=False): if strength_model == 0: return io.NodeOutput(model) - model_lora, _ = comfy.sd.load_lora_for_models( - model, None, lora, strength_model, 0 - ) + if bypass: + model_lora, _ = comfy.sd.load_bypass_lora_for_models( + model, None, lora, strength_model, 0 + ) + else: + model_lora, _ = comfy.sd.load_lora_for_models( + model, None, lora, strength_model, 0 + ) return io.NodeOutput(model_lora) From 76a7fa96dbdc2eda89218601fe3aed5997df055f Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 10 Feb 2026 19:04:32 -0800 Subject: [PATCH 086/215] Make built in lora training work on anima. (#12402) --- comfy/ldm/anima/model.py | 16 ++++++++++++++-- comfy/model_base.py | 12 ++++++++---- 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/comfy/ldm/anima/model.py b/comfy/ldm/anima/model.py index 2e6ed58fa..6fb51c4a4 100644 --- a/comfy/ldm/anima/model.py +++ b/comfy/ldm/anima/model.py @@ -195,8 +195,20 @@ class Anima(MiniTrainDIT): super().__init__(*args, **kwargs) self.llm_adapter = LLMAdapter(device=kwargs.get("device"), dtype=kwargs.get("dtype"), operations=kwargs.get("operations")) - def preprocess_text_embeds(self, text_embeds, text_ids): + def preprocess_text_embeds(self, text_embeds, text_ids, t5xxl_weights=None): if text_ids is not None: - return self.llm_adapter(text_embeds, text_ids) + out = self.llm_adapter(text_embeds, text_ids) + if t5xxl_weights is not None: + out = out * t5xxl_weights + + if out.shape[1] < 512: + out = torch.nn.functional.pad(out, (0, 0, 0, 512 - out.shape[1])) + return out else: return text_embeds + + def forward(self, x, timesteps, context, **kwargs): + t5xxl_ids = kwargs.pop("t5xxl_ids", None) + if t5xxl_ids is not None: + context = self.preprocess_text_embeds(context, t5xxl_ids, t5xxl_weights=kwargs.pop("t5xxl_weights", None)) + return super().forward(x, timesteps, context, **kwargs) diff --git a/comfy/model_base.py b/comfy/model_base.py index 858789b30..4a74cb1ce 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1160,12 +1160,16 @@ class Anima(BaseModel): device = kwargs["device"] if cross_attn is not None: if t5xxl_ids is not None: - cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype()), t5xxl_ids.unsqueeze(0).to(device=device)) if t5xxl_weights is not None: - cross_attn *= t5xxl_weights.unsqueeze(0).unsqueeze(-1).to(cross_attn) + t5xxl_weights = t5xxl_weights.unsqueeze(0).unsqueeze(-1).to(cross_attn) + t5xxl_ids = t5xxl_ids.unsqueeze(0) + + if torch.is_inference_mode_enabled(): # if not we are training + cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype()), t5xxl_ids.to(device=device), t5xxl_weights=t5xxl_weights.to(device=device, dtype=self.get_dtype())) + else: + out['t5xxl_ids'] = comfy.conds.CONDRegular(t5xxl_ids) + out['t5xxl_weights'] = comfy.conds.CONDRegular(t5xxl_weights) - if cross_attn.shape[1] < 512: - cross_attn = torch.nn.functional.pad(cross_attn, (0, 0, 0, 512 - cross_attn.shape[1])) out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) return out From 2c7cef4a23c08e3f02a33c693d927158a15a11f6 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Wed, 11 Feb 2026 20:51:49 +0200 Subject: [PATCH 087/215] fix(api-nodes): retry on connection errors during polling instead of aborting (#12393) --- comfy_api_nodes/util/client.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py index 8a1259506..391748e7a 100644 --- a/comfy_api_nodes/util/client.py +++ b/comfy_api_nodes/util/client.py @@ -143,9 +143,9 @@ async def poll_op( poll_interval: float = 5.0, max_poll_attempts: int = 160, timeout_per_poll: float = 120.0, - max_retries_per_poll: int = 3, + max_retries_per_poll: int = 10, retry_delay_per_poll: float = 1.0, - retry_backoff_per_poll: float = 2.0, + retry_backoff_per_poll: float = 1.4, estimated_duration: int | None = None, cancel_endpoint: ApiEndpoint | None = None, cancel_timeout: float = 10.0, @@ -240,9 +240,9 @@ async def poll_op_raw( poll_interval: float = 5.0, max_poll_attempts: int = 160, timeout_per_poll: float = 120.0, - max_retries_per_poll: int = 3, + max_retries_per_poll: int = 10, retry_delay_per_poll: float = 1.0, - retry_backoff_per_poll: float = 2.0, + retry_backoff_per_poll: float = 1.4, estimated_duration: int | None = None, cancel_endpoint: ApiEndpoint | None = None, cancel_timeout: float = 10.0, From 4993411fd9a43d642971925272c3748d9e058131 Mon Sep 17 00:00:00 2001 From: Benjamin Lu Date: Wed, 11 Feb 2026 11:15:13 -0800 Subject: [PATCH 088/215] Dispatch desktop auto-bump when a ComfyUI release is published (#12398) * Dispatch desktop auto-bump on ComfyUI release publish * Fix release webhook secret checks in step conditions * Require desktop dispatch token in release webhook * Apply suggestion from @Copilot Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: Luke Mino-Altherr Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Jedrzej Kosinski --- .github/workflows/release-webhook.yml | 36 +++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/.github/workflows/release-webhook.yml b/.github/workflows/release-webhook.yml index 6fceb7560..737e4c488 100644 --- a/.github/workflows/release-webhook.yml +++ b/.github/workflows/release-webhook.yml @@ -7,6 +7,8 @@ on: jobs: send-webhook: runs-on: ubuntu-latest + env: + DESKTOP_REPO_DISPATCH_TOKEN: ${{ secrets.DESKTOP_REPO_DISPATCH_TOKEN }} steps: - name: Send release webhook env: @@ -106,3 +108,37 @@ jobs: --fail --silent --show-error echo "✅ Release webhook sent successfully" + + - name: Send repository dispatch to desktop + env: + DISPATCH_TOKEN: ${{ env.DESKTOP_REPO_DISPATCH_TOKEN }} + RELEASE_TAG: ${{ github.event.release.tag_name }} + RELEASE_URL: ${{ github.event.release.html_url }} + run: | + set -euo pipefail + + if [ -z "${DISPATCH_TOKEN:-}" ]; then + echo "::error::DESKTOP_REPO_DISPATCH_TOKEN is required but not set." + exit 1 + fi + + PAYLOAD="$(jq -n \ + --arg release_tag "$RELEASE_TAG" \ + --arg release_url "$RELEASE_URL" \ + '{ + event_type: "comfyui_release_published", + client_payload: { + release_tag: $release_tag, + release_url: $release_url + } + }')" + + curl -fsSL \ + -X POST \ + -H "Accept: application/vnd.github+json" \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer ${DISPATCH_TOKEN}" \ + https://api.github.com/repos/Comfy-Org/desktop/dispatches \ + -d "$PAYLOAD" + + echo "✅ Dispatched ComfyUI release ${RELEASE_TAG} to Comfy-Org/desktop" From 2b7cc7e3b69127a81b9232d4e8305eb678fa3d0c Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Wed, 11 Feb 2026 21:30:19 +0200 Subject: [PATCH 089/215] [API Nodes] enable Magnific Upscalers (#12179) * feat(api-nodes): enable Magnific Upscalers * update price badges --------- Co-authored-by: Jedrzej Kosinski --- comfy_api_nodes/nodes_magnific.py | 62 +++++++++++++++++++++++++++---- 1 file changed, 55 insertions(+), 7 deletions(-) diff --git a/comfy_api_nodes/nodes_magnific.py b/comfy_api_nodes/nodes_magnific.py index 013e71cc8..83a581c5d 100644 --- a/comfy_api_nodes/nodes_magnific.py +++ b/comfy_api_nodes/nodes_magnific.py @@ -30,6 +30,30 @@ from comfy_api_nodes.util import ( validate_image_dimensions, ) +_EUR_TO_USD = 1.19 + + +def _tier_price_eur(megapixels: float) -> float: + """Price in EUR for a single Magnific upscaling step based on input megapixels.""" + if megapixels <= 1.3: + return 0.143 + if megapixels <= 3.0: + return 0.286 + if megapixels <= 6.4: + return 0.429 + return 1.716 + + +def _calculate_magnific_upscale_price_usd(width: int, height: int, scale: int) -> float: + """Calculate total Magnific upscale price in USD for given input dimensions and scale factor.""" + num_steps = int(math.log2(scale)) + total_eur = 0.0 + pixels = width * height + for _ in range(num_steps): + total_eur += _tier_price_eur(pixels / 1_000_000) + pixels *= 4 + return round(total_eur * _EUR_TO_USD, 2) + class MagnificImageUpscalerCreativeNode(IO.ComfyNode): @classmethod @@ -103,11 +127,20 @@ class MagnificImageUpscalerCreativeNode(IO.ComfyNode): ], is_api_node=True, price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["scale_factor"]), + depends_on=IO.PriceBadgeDepends(widgets=["scale_factor", "auto_downscale"]), expr=""" ( - $max := widgets.scale_factor = "2x" ? 1.326 : 1.657; - {"type": "range_usd", "min_usd": 0.11, "max_usd": $max} + $ad := widgets.auto_downscale; + $mins := $ad + ? {"2x": 0.172, "4x": 0.343, "8x": 0.515, "16x": 0.515} + : {"2x": 0.172, "4x": 0.343, "8x": 0.515, "16x": 0.844}; + $maxs := {"2x": 0.515, "4x": 0.844, "8x": 1.015, "16x": 1.187}; + { + "type": "range_usd", + "min_usd": $lookup($mins, widgets.scale_factor), + "max_usd": $lookup($maxs, widgets.scale_factor), + "format": { "approximate": true } + } ) """, ), @@ -168,6 +201,10 @@ class MagnificImageUpscalerCreativeNode(IO.ComfyNode): f"Use a smaller input image or lower scale factor." ) + final_height, final_width = get_image_dimensions(image) + actual_scale = int(scale_factor.rstrip("x")) + price_usd = _calculate_magnific_upscale_price_usd(final_width, final_height, actual_scale) + initial_res = await sync_op( cls, ApiEndpoint(path="/proxy/freepik/v1/ai/image-upscaler", method="POST"), @@ -189,6 +226,7 @@ class MagnificImageUpscalerCreativeNode(IO.ComfyNode): ApiEndpoint(path=f"/proxy/freepik/v1/ai/image-upscaler/{initial_res.task_id}"), response_model=TaskResponse, status_extractor=lambda x: x.status, + price_extractor=lambda _: price_usd, poll_interval=10.0, max_poll_attempts=480, ) @@ -257,8 +295,14 @@ class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode): depends_on=IO.PriceBadgeDepends(widgets=["scale_factor"]), expr=""" ( - $max := widgets.scale_factor = "2x" ? 1.326 : 1.657; - {"type": "range_usd", "min_usd": 0.11, "max_usd": $max} + $mins := {"2x": 0.172, "4x": 0.343, "8x": 0.515, "16x": 0.844}; + $maxs := {"2x": 2.045, "4x": 2.545, "8x": 2.889, "16x": 3.06}; + { + "type": "range_usd", + "min_usd": $lookup($mins, widgets.scale_factor), + "max_usd": $lookup($maxs, widgets.scale_factor), + "format": { "approximate": true } + } ) """, ), @@ -321,6 +365,9 @@ class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode): f"Use a smaller input image or lower scale factor." ) + final_height, final_width = get_image_dimensions(image) + price_usd = _calculate_magnific_upscale_price_usd(final_width, final_height, requested_scale) + initial_res = await sync_op( cls, ApiEndpoint(path="/proxy/freepik/v1/ai/image-upscaler-precision-v2", method="POST"), @@ -339,6 +386,7 @@ class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode): ApiEndpoint(path=f"/proxy/freepik/v1/ai/image-upscaler-precision-v2/{initial_res.task_id}"), response_model=TaskResponse, status_extractor=lambda x: x.status, + price_extractor=lambda _: price_usd, poll_interval=10.0, max_poll_attempts=480, ) @@ -877,8 +925,8 @@ class MagnificExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: return [ - # MagnificImageUpscalerCreativeNode, - # MagnificImageUpscalerPreciseV2Node, + MagnificImageUpscalerCreativeNode, + MagnificImageUpscalerPreciseV2Node, MagnificImageStyleTransferNode, MagnificImageRelightNode, MagnificImageSkinEnhancerNode, From d297a749a2fa3a34ebff898797feef161bcd64c6 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Wed, 11 Feb 2026 11:50:16 -0800 Subject: [PATCH 090/215] dynamic_vram: Fix windows Aimdo crash + Fix LLM performance (#12408) * model_management: lazy-cache aimdo_tensor These tensors cosntructed from aimdo-allocations are CPU expensive to make on the pytorch side. Add a cache version that will be valid with signature match to fast path past whatever torch is doing. * dynamic_vram: Minimize fast path CPU work Move as much as possible inside the not resident if block and cache the formed weight and bias rather than the flat intermediates. In extreme layer weight rates this adds up. --- comfy/model_management.py | 8 ++++++-- comfy/model_patcher.py | 2 -- comfy/ops.py | 21 ++++++++++++++------- 3 files changed, 20 insertions(+), 11 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 304931eb0..38c3e482b 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1213,8 +1213,12 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str signature = comfy_aimdo.model_vbar.vbar_fault(weight._v) if signature is not None: - v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, weight._v_tensor)[0] - if not comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature): + if comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature): + v_tensor = weight._v_tensor + else: + raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device) + v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, raw_tensor)[0] + weight._v_tensor = v_tensor weight._v_signature = signature #Send it over v_tensor.copy_(weight, non_blocking=non_blocking) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 19c9031ea..224e218e3 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -1542,7 +1542,6 @@ class ModelPatcherDynamic(ModelPatcher): if vbar is not None and not hasattr(m, "_v"): m._v = vbar.alloc(v_weight_size) - m._v_tensor = comfy_aimdo.torch.aimdo_to_tensor(m._v, device_to) allocated_size += v_weight_size else: @@ -1557,7 +1556,6 @@ class ModelPatcherDynamic(ModelPatcher): weight_size = geometry.numel() * geometry.element_size() if vbar is not None and not hasattr(weight, "_v"): weight._v = vbar.alloc(weight_size) - weight._v_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device_to) weight._model_dtype = model_dtype allocated_size += weight_size vbar.set_watermark_limit(allocated_size) diff --git a/comfy/ops.py b/comfy/ops.py index 33803b223..688937e43 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -83,14 +83,18 @@ def cast_to_input(weight, input, non_blocking=False, copy=True): def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype): offload_stream = None xfer_dest = None - cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ]) signature = comfy_aimdo.model_vbar.vbar_fault(s._v) - if signature is not None: - xfer_dest = s._v_tensor resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature) + if signature is not None: + if resident: + weight = s._v_weight + bias = s._v_bias + else: + xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device) if not resident: + cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ]) cast_dest = None xfer_source = [ s.weight, s.bias ] @@ -140,9 +144,13 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu post_cast.copy_(pre_cast) xfer_dest = cast_dest - params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest) - weight = params[0] - bias = params[1] + params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest) + weight = params[0] + bias = params[1] + if signature is not None: + s._v_weight = weight + s._v_bias = bias + s._v_signature=signature def post_cast(s, param_key, x, dtype, resident, update_weight): lowvram_fn = getattr(s, param_key + "_lowvram_function", None) @@ -182,7 +190,6 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu weight = post_cast(s, "weight", weight, dtype, resident, update_weight) if s.bias is not None: bias = post_cast(s, "bias", bias, bias_dtype, resident, update_weight) - s._v_signature=signature #FIXME: weird offload return protocol return weight, bias, (offload_stream, device if signature is not None else None, None) From 2a4328d639810858aa625c7bfedb974a13a57abe Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Wed, 11 Feb 2026 11:53:42 -0800 Subject: [PATCH 091/215] ace15: Use dynamic_vram friendly trange (#12409) Factor out the ksampler trange and use it in ACE LLM to prevent the silent stall at 0 and rate distortion due to first-step model load. --- comfy/k_diffusion/sampling.py | 32 ++------------------------------ comfy/text_encoders/ace15.py | 3 +-- comfy/utils.py | 27 +++++++++++++++++++++++++++ 3 files changed, 30 insertions(+), 32 deletions(-) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index c0c51d51a..6978eb717 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -1,12 +1,11 @@ import math -import time from functools import partial from scipy import integrate import torch from torch import nn import torchsde -from tqdm.auto import trange as trange_, tqdm +from tqdm.auto import tqdm from . import utils from . import deis @@ -15,34 +14,7 @@ import comfy.model_patcher import comfy.model_sampling import comfy.memory_management - - -def trange(*args, **kwargs): - if comfy.memory_management.aimdo_allocator is None: - return trange_(*args, **kwargs) - - pbar = trange_(*args, **kwargs, smoothing=1.0) - pbar._i = 0 - pbar.set_postfix_str(" Model Initializing ... ") - - _update = pbar.update - - def warmup_update(n=1): - pbar._i += 1 - if pbar._i == 1: - pbar.i1_time = time.time() - pbar.set_postfix_str(" Model Initialization complete! ") - elif pbar._i == 2: - #bring forward the effective start time based the the diff between first and second iteration - #to attempt to remove load overhead from the final step rate estimate. - pbar.start_t = pbar.i1_time - (time.time() - pbar.i1_time) - pbar.set_postfix_str("") - - _update(n) - - pbar.update = warmup_update - return pbar - +from comfy.utils import model_trange as trange def append_zero(x): return torch.cat([x, x.new_zeros([1])]) diff --git a/comfy/text_encoders/ace15.py b/comfy/text_encoders/ace15.py index 73697b3c1..b8198a820 100644 --- a/comfy/text_encoders/ace15.py +++ b/comfy/text_encoders/ace15.py @@ -3,7 +3,6 @@ import comfy.text_encoders.llama from comfy import sd1_clip import torch import math -from tqdm.auto import trange import yaml import comfy.utils @@ -52,7 +51,7 @@ def sample_manual_loop_no_classes( progress_bar = comfy.utils.ProgressBar(max_new_tokens) - for step in trange(max_new_tokens, desc="LM sampling"): + for step in comfy.utils.model_trange(max_new_tokens, desc="LM sampling"): outputs = model.transformer(None, attention_mask, embeds=embeds.to(execution_dtype), num_tokens=num_tokens, intermediate_output=None, dtype=execution_dtype, embeds_info=embeds_info, past_key_values=past_key_values) next_token_logits = model.transformer.logits(outputs[0])[:, -1] past_key_values = outputs[2] diff --git a/comfy/utils.py b/comfy/utils.py index edd80cebe..e0a94e2e1 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -27,6 +27,7 @@ from PIL import Image import logging import itertools from torch.nn.functional import interpolate +from tqdm.auto import trange from einops import rearrange from comfy.cli_args import args, enables_dynamic_vram import json @@ -1155,6 +1156,32 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None): return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar) +def model_trange(*args, **kwargs): + if comfy.memory_management.aimdo_allocator is None: + return trange(*args, **kwargs) + + pbar = trange(*args, **kwargs, smoothing=1.0) + pbar._i = 0 + pbar.set_postfix_str(" Model Initializing ... ") + + _update = pbar.update + + def warmup_update(n=1): + pbar._i += 1 + if pbar._i == 1: + pbar.i1_time = time.time() + pbar.set_postfix_str(" Model Initialization complete! ") + elif pbar._i == 2: + #bring forward the effective start time based the the diff between first and second iteration + #to attempt to remove load overhead from the final step rate estimate. + pbar.start_t = pbar.i1_time - (time.time() - pbar.i1_time) + pbar.set_postfix_str("") + + _update(n) + + pbar.update = warmup_update + return pbar + PROGRESS_BAR_ENABLED = True def set_progress_bar_enabled(enabled): global PROGRESS_BAR_ENABLED From 3fe61cedda090c744dcf6f579ed48744fa66ef5f Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Wed, 11 Feb 2026 11:54:02 -0800 Subject: [PATCH 092/215] model_patcher: guard against none model_dtype (#12410) Handle the case where the _model_dtype exists but is none with the intended fallback. --- comfy/model_patcher.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 224e218e3..f278fccac 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -1525,7 +1525,7 @@ class ModelPatcherDynamic(ModelPatcher): setattr(m, param_key + "_function", weight_function) geometry = weight if not isinstance(weight, QuantizedTensor): - model_dtype = getattr(m, param_key + "_comfy_model_dtype", weight.dtype) + model_dtype = getattr(m, param_key + "_comfy_model_dtype", None) or weight.dtype weight._model_dtype = model_dtype geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype) return comfy.memory_management.vram_aligned_size(geometry) @@ -1551,7 +1551,7 @@ class ModelPatcherDynamic(ModelPatcher): weight.seed_key = key set_dirty(weight, dirty) geometry = weight - model_dtype = getattr(m, param + "_comfy_model_dtype", weight.dtype) + model_dtype = getattr(m, param + "_comfy_model_dtype", None) or weight.dtype geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype) weight_size = geometry.numel() * geometry.element_size() if vbar is not None and not hasattr(weight, "_v"): From e5ae670a4016d3698a806e7f840fcecc50639848 Mon Sep 17 00:00:00 2001 From: askmyteapot <62238146+askmyteapot@users.noreply.github.com> Date: Thu, 12 Feb 2026 11:28:48 +1000 Subject: [PATCH 093/215] Update ace15.py to allow min_p sampling (#12373) --- comfy/text_encoders/ace15.py | 15 ++++++++++++--- comfy_extras/nodes_ace.py | 5 +++-- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/comfy/text_encoders/ace15.py b/comfy/text_encoders/ace15.py index b8198a820..0fdd4669f 100644 --- a/comfy/text_encoders/ace15.py +++ b/comfy/text_encoders/ace15.py @@ -16,6 +16,7 @@ def sample_manual_loop_no_classes( temperature: float = 0.85, top_p: float = 0.9, top_k: int = None, + min_p: float = 0.000, seed: int = 1, min_tokens: int = 1, max_new_tokens: int = 2048, @@ -80,6 +81,12 @@ def sample_manual_loop_no_classes( min_val = top_k_vals[..., -1, None] cfg_logits[cfg_logits < min_val] = remove_logit_value + if min_p is not None and min_p > 0: + probs = torch.softmax(cfg_logits, dim=-1) + p_max = probs.max(dim=-1, keepdim=True).values + indices_to_remove = probs < (min_p * p_max) + cfg_logits[indices_to_remove] = remove_logit_value + if top_p is not None and top_p < 1.0: sorted_logits, sorted_indices = torch.sort(cfg_logits, descending=True) cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) @@ -110,7 +117,7 @@ def sample_manual_loop_no_classes( return output_audio_codes -def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=1024, seed=0, cfg_scale=2.0, temperature=0.85, top_p=0.9, top_k=0): +def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=1024, seed=0, cfg_scale=2.0, temperature=0.85, top_p=0.9, top_k=0, min_p=0.000): positive = [[token for token, _ in inner_list] for inner_list in positive] positive = positive[0] @@ -134,7 +141,7 @@ def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=102 paddings = [] ids = [positive] - return sample_manual_loop_no_classes(model, ids, paddings, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens) + return sample_manual_loop_no_classes(model, ids, paddings, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens) class ACE15Tokenizer(sd1_clip.SD1Tokenizer): @@ -192,6 +199,7 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer): temperature = kwargs.get("temperature", 0.85) top_p = kwargs.get("top_p", 0.9) top_k = kwargs.get("top_k", 0.0) + min_p = kwargs.get("min_p", 0.000) duration = math.ceil(duration) kwargs["duration"] = duration @@ -239,6 +247,7 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer): "temperature": temperature, "top_p": top_p, "top_k": top_k, + "min_p": min_p, } return out @@ -299,7 +308,7 @@ class ACE15TEModel(torch.nn.Module): lm_metadata = token_weight_pairs["lm_metadata"] if lm_metadata["generate_audio_codes"]: - audio_codes = generate_audio_codes(getattr(self, self.lm_model, self.qwen3_06b), token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["max_tokens"], seed=lm_metadata["seed"], cfg_scale=lm_metadata["cfg_scale"], temperature=lm_metadata["temperature"], top_p=lm_metadata["top_p"], top_k=lm_metadata["top_k"]) + audio_codes = generate_audio_codes(getattr(self, self.lm_model, self.qwen3_06b), token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["min_tokens"], seed=lm_metadata["seed"], cfg_scale=lm_metadata["cfg_scale"], temperature=lm_metadata["temperature"], top_p=lm_metadata["top_p"], top_k=lm_metadata["top_k"], min_p=lm_metadata["min_p"]) out["audio_codes"] = [audio_codes] return base_out, None, out diff --git a/comfy_extras/nodes_ace.py b/comfy_extras/nodes_ace.py index dde5bbd2a..9cf84ab4d 100644 --- a/comfy_extras/nodes_ace.py +++ b/comfy_extras/nodes_ace.py @@ -49,13 +49,14 @@ class TextEncodeAceStepAudio15(io.ComfyNode): io.Float.Input("temperature", default=0.85, min=0.0, max=2.0, step=0.01, advanced=True), io.Float.Input("top_p", default=0.9, min=0.0, max=2000.0, step=0.01, advanced=True), io.Int.Input("top_k", default=0, min=0, max=100, advanced=True), + io.Float.Input("min_p", default=0.000, min=0.0, max=1.0, step=0.001, advanced=True), ], outputs=[io.Conditioning.Output()], ) @classmethod - def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale, generate_audio_codes, cfg_scale, temperature, top_p, top_k) -> io.NodeOutput: - tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed, generate_audio_codes=generate_audio_codes, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k) + def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale, generate_audio_codes, cfg_scale, temperature, top_p, top_k, min_p) -> io.NodeOutput: + tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed, generate_audio_codes=generate_audio_codes, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p) conditioning = clip.encode_from_tokens_scheduled(tokens) return io.NodeOutput(conditioning) From 66c18522fbcde5b62731e3fb080a84b14e3dacfc Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 11 Feb 2026 19:12:16 -0800 Subject: [PATCH 094/215] Add a tip for common error. (#12414) --- execution.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/execution.py b/execution.py index 896862c6b..f549a2f0f 100644 --- a/execution.py +++ b/execution.py @@ -623,6 +623,8 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, logging.info("Memory summary: {}".format(comfy.model_management.debug_memory_summary())) logging.error("Got an OOM, unloading all loaded models.") comfy.model_management.unload_all_models() + elif isinstance(ex, RuntimeError) and ("mat1 and mat2 shapes" in str(ex)) and "Sampler" in class_type: + tips = "\n\nTIPS: If you have any \"Load CLIP\" or \"*CLIP Loader\" nodes in your workflow connected to this sampler node make sure the correct file(s) and type is selected." error_details = { "node_id": real_node_id, From 4a93a62371b64f9d11a140a09faf985c48902d2e Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Thu, 12 Feb 2026 11:38:51 +0200 Subject: [PATCH 095/215] fix(api-nodes): add separate retry budget for 429 rate limit responses (#12421) --- comfy_api_nodes/util/client.py | 207 ++++++++++++----------- comfy_api_nodes/util/download_helpers.py | 30 ++-- comfy_api_nodes/util/request_logger.py | 66 ++++---- comfy_api_nodes/util/upload_helpers.py | 55 +++--- 4 files changed, 177 insertions(+), 181 deletions(-) diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py index 391748e7a..94886af7b 100644 --- a/comfy_api_nodes/util/client.py +++ b/comfy_api_nodes/util/client.py @@ -57,6 +57,7 @@ class _RequestConfig: files: dict[str, Any] | list[tuple[str, Any]] | None multipart_parser: Callable | None max_retries: int + max_retries_on_rate_limit: int retry_delay: float retry_backoff: float wait_label: str = "Waiting" @@ -65,6 +66,7 @@ class _RequestConfig: final_label_on_success: str | None = "Completed" progress_origin_ts: float | None = None price_extractor: Callable[[dict[str, Any]], float | None] | None = None + is_rate_limited: Callable[[int, Any], bool] | None = None @dataclass @@ -78,7 +80,7 @@ class _PollUIState: active_since: float | None = None # start time of current active interval (None if queued) -_RETRY_STATUS = {408, 429, 500, 502, 503, 504} +_RETRY_STATUS = {408, 500, 502, 503, 504} # status 429 is handled separately COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done", "complete"] FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"] QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing"] @@ -103,6 +105,8 @@ async def sync_op( final_label_on_success: str | None = "Completed", progress_origin_ts: float | None = None, monitor_progress: bool = True, + max_retries_on_rate_limit: int = 16, + is_rate_limited: Callable[[int, Any], bool] | None = None, ) -> M: raw = await sync_op_raw( cls, @@ -122,6 +126,8 @@ async def sync_op( final_label_on_success=final_label_on_success, progress_origin_ts=progress_origin_ts, monitor_progress=monitor_progress, + max_retries_on_rate_limit=max_retries_on_rate_limit, + is_rate_limited=is_rate_limited, ) if not isinstance(raw, dict): raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).") @@ -194,6 +200,8 @@ async def sync_op_raw( final_label_on_success: str | None = "Completed", progress_origin_ts: float | None = None, monitor_progress: bool = True, + max_retries_on_rate_limit: int = 16, + is_rate_limited: Callable[[int, Any], bool] | None = None, ) -> dict[str, Any] | bytes: """ Make a single network request. @@ -222,6 +230,8 @@ async def sync_op_raw( final_label_on_success=final_label_on_success, progress_origin_ts=progress_origin_ts, price_extractor=price_extractor, + max_retries_on_rate_limit=max_retries_on_rate_limit, + is_rate_limited=is_rate_limited, ) return await _request_base(cfg, expect_binary=as_binary) @@ -506,7 +516,7 @@ def _friendly_http_message(status: int, body: Any) -> str: if status == 409: return "There is a problem with your account. Please contact support@comfy.org." if status == 429: - return "Rate Limit Exceeded: Please try again later." + return "Rate Limit Exceeded: The server returned 429 after all retry attempts. Please wait and try again." try: if isinstance(body, dict): err = body.get("error") @@ -586,6 +596,8 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): start_time = cfg.progress_origin_ts if cfg.progress_origin_ts is not None else time.monotonic() attempt = 0 delay = cfg.retry_delay + rate_limit_attempts = 0 + rate_limit_delay = cfg.retry_delay operation_succeeded: bool = False final_elapsed_seconds: int | None = None extracted_price: float | None = None @@ -653,17 +665,14 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): payload_headers["Content-Type"] = "application/json" payload_kw["json"] = cfg.data or {} - try: - request_logger.log_request_response( - operation_id=operation_id, - request_method=method, - request_url=url, - request_headers=dict(payload_headers) if payload_headers else None, - request_params=dict(params) if params else None, - request_data=request_body_log, - ) - except Exception as _log_e: - logging.debug("[DEBUG] request logging failed: %s", _log_e) + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, + request_url=url, + request_headers=dict(payload_headers) if payload_headers else None, + request_params=dict(params) if params else None, + request_data=request_body_log, + ) req_coro = sess.request(method, url, params=params, **payload_kw) req_task = asyncio.create_task(req_coro) @@ -688,41 +697,33 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): body = await resp.json() except (ContentTypeError, json.JSONDecodeError): body = await resp.text() - if resp.status in _RETRY_STATUS and attempt <= cfg.max_retries: + should_retry = False + wait_time = 0.0 + retry_label = "" + is_rl = resp.status == 429 or ( + cfg.is_rate_limited is not None and cfg.is_rate_limited(resp.status, body) + ) + if is_rl and rate_limit_attempts < cfg.max_retries_on_rate_limit: + rate_limit_attempts += 1 + wait_time = min(rate_limit_delay, 30.0) + rate_limit_delay *= cfg.retry_backoff + retry_label = f"rate-limit retry {rate_limit_attempts} of {cfg.max_retries_on_rate_limit}" + should_retry = True + elif resp.status in _RETRY_STATUS and (attempt - rate_limit_attempts) <= cfg.max_retries: + wait_time = delay + delay *= cfg.retry_backoff + retry_label = f"retry {attempt - rate_limit_attempts} of {cfg.max_retries}" + should_retry = True + + if should_retry: logging.warning( - "HTTP %s %s -> %s. Retrying in %.2fs (retry %d of %d).", + "HTTP %s %s -> %s. Waiting %.2fs (%s).", method, url, resp.status, - delay, - attempt, - cfg.max_retries, + wait_time, + retry_label, ) - try: - request_logger.log_request_response( - operation_id=operation_id, - request_method=method, - request_url=url, - response_status_code=resp.status, - response_headers=dict(resp.headers), - response_content=body, - error_message=_friendly_http_message(resp.status, body), - ) - except Exception as _log_e: - logging.debug("[DEBUG] response logging failed: %s", _log_e) - - await sleep_with_interrupt( - delay, - cfg.node_cls, - cfg.wait_label if cfg.monitor_progress else None, - start_time if cfg.monitor_progress else None, - cfg.estimated_total, - display_callback=_display_time_progress if cfg.monitor_progress else None, - ) - delay *= cfg.retry_backoff - continue - msg = _friendly_http_message(resp.status, body) - try: request_logger.log_request_response( operation_id=operation_id, request_method=method, @@ -730,10 +731,27 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): response_status_code=resp.status, response_headers=dict(resp.headers), response_content=body, - error_message=msg, + error_message=f"HTTP {resp.status} ({retry_label}, will retry in {wait_time:.1f}s)", ) - except Exception as _log_e: - logging.debug("[DEBUG] response logging failed: %s", _log_e) + await sleep_with_interrupt( + wait_time, + cfg.node_cls, + cfg.wait_label if cfg.monitor_progress else None, + start_time if cfg.monitor_progress else None, + cfg.estimated_total, + display_callback=_display_time_progress if cfg.monitor_progress else None, + ) + continue + msg = _friendly_http_message(resp.status, body) + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, + request_url=url, + response_status_code=resp.status, + response_headers=dict(resp.headers), + response_content=body, + error_message=msg, + ) raise Exception(msg) if expect_binary: @@ -753,17 +771,14 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): bytes_payload = bytes(buff) operation_succeeded = True final_elapsed_seconds = int(time.monotonic() - start_time) - try: - request_logger.log_request_response( - operation_id=operation_id, - request_method=method, - request_url=url, - response_status_code=resp.status, - response_headers=dict(resp.headers), - response_content=bytes_payload, - ) - except Exception as _log_e: - logging.debug("[DEBUG] response logging failed: %s", _log_e) + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, + request_url=url, + response_status_code=resp.status, + response_headers=dict(resp.headers), + response_content=bytes_payload, + ) return bytes_payload else: try: @@ -780,45 +795,39 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): extracted_price = cfg.price_extractor(payload) if cfg.price_extractor else None operation_succeeded = True final_elapsed_seconds = int(time.monotonic() - start_time) - try: - request_logger.log_request_response( - operation_id=operation_id, - request_method=method, - request_url=url, - response_status_code=resp.status, - response_headers=dict(resp.headers), - response_content=response_content_to_log, - ) - except Exception as _log_e: - logging.debug("[DEBUG] response logging failed: %s", _log_e) + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, + request_url=url, + response_status_code=resp.status, + response_headers=dict(resp.headers), + response_content=response_content_to_log, + ) return payload except ProcessingInterrupted: logging.debug("Polling was interrupted by user") raise except (ClientError, OSError) as e: - if attempt <= cfg.max_retries: + if (attempt - rate_limit_attempts) <= cfg.max_retries: logging.warning( "Connection error calling %s %s. Retrying in %.2fs (%d/%d): %s", method, url, delay, - attempt, + attempt - rate_limit_attempts, cfg.max_retries, str(e), ) - try: - request_logger.log_request_response( - operation_id=operation_id, - request_method=method, - request_url=url, - request_headers=dict(payload_headers) if payload_headers else None, - request_params=dict(params) if params else None, - request_data=request_body_log, - error_message=f"{type(e).__name__}: {str(e)} (will retry)", - ) - except Exception as _log_e: - logging.debug("[DEBUG] request error logging failed: %s", _log_e) + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, + request_url=url, + request_headers=dict(payload_headers) if payload_headers else None, + request_params=dict(params) if params else None, + request_data=request_body_log, + error_message=f"{type(e).__name__}: {str(e)} (will retry)", + ) await sleep_with_interrupt( delay, cfg.node_cls, @@ -831,23 +840,6 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): continue diag = await _diagnose_connectivity() if not diag["internet_accessible"]: - try: - request_logger.log_request_response( - operation_id=operation_id, - request_method=method, - request_url=url, - request_headers=dict(payload_headers) if payload_headers else None, - request_params=dict(params) if params else None, - request_data=request_body_log, - error_message=f"LocalNetworkError: {str(e)}", - ) - except Exception as _log_e: - logging.debug("[DEBUG] final error logging failed: %s", _log_e) - raise LocalNetworkError( - "Unable to connect to the API server due to local network issues. " - "Please check your internet connection and try again." - ) from e - try: request_logger.log_request_response( operation_id=operation_id, request_method=method, @@ -855,10 +847,21 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): request_headers=dict(payload_headers) if payload_headers else None, request_params=dict(params) if params else None, request_data=request_body_log, - error_message=f"ApiServerError: {str(e)}", + error_message=f"LocalNetworkError: {str(e)}", ) - except Exception as _log_e: - logging.debug("[DEBUG] final error logging failed: %s", _log_e) + raise LocalNetworkError( + "Unable to connect to the API server due to local network issues. " + "Please check your internet connection and try again." + ) from e + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, + request_url=url, + request_headers=dict(payload_headers) if payload_headers else None, + request_params=dict(params) if params else None, + request_data=request_body_log, + error_message=f"ApiServerError: {str(e)}", + ) raise ApiServerError( f"The API server at {default_base_url()} is currently unreachable. " f"The service may be experiencing issues." diff --git a/comfy_api_nodes/util/download_helpers.py b/comfy_api_nodes/util/download_helpers.py index 78bcf1fa1..aa588d038 100644 --- a/comfy_api_nodes/util/download_helpers.py +++ b/comfy_api_nodes/util/download_helpers.py @@ -167,27 +167,25 @@ async def download_url_to_bytesio( with contextlib.suppress(Exception): dest.seek(0) - with contextlib.suppress(Exception): - request_logger.log_request_response( - operation_id=op_id, - request_method="GET", - request_url=url, - response_status_code=resp.status, - response_headers=dict(resp.headers), - response_content=f"[streamed {written} bytes to dest]", - ) + request_logger.log_request_response( + operation_id=op_id, + request_method="GET", + request_url=url, + response_status_code=resp.status, + response_headers=dict(resp.headers), + response_content=f"[streamed {written} bytes to dest]", + ) return except asyncio.CancelledError: raise ProcessingInterrupted("Task cancelled") from None except (ClientError, OSError) as e: if attempt <= max_retries: - with contextlib.suppress(Exception): - request_logger.log_request_response( - operation_id=op_id, - request_method="GET", - request_url=url, - error_message=f"{type(e).__name__}: {str(e)} (will retry)", - ) + request_logger.log_request_response( + operation_id=op_id, + request_method="GET", + request_url=url, + error_message=f"{type(e).__name__}: {str(e)} (will retry)", + ) await sleep_with_interrupt(delay, cls, None, None, None) delay *= retry_backoff continue diff --git a/comfy_api_nodes/util/request_logger.py b/comfy_api_nodes/util/request_logger.py index e0cb4428d..fe0543d9b 100644 --- a/comfy_api_nodes/util/request_logger.py +++ b/comfy_api_nodes/util/request_logger.py @@ -8,7 +8,6 @@ from typing import Any import folder_paths -# Get the logger instance logger = logging.getLogger(__name__) @@ -91,38 +90,41 @@ def log_request_response( Filenames are sanitized and length-limited for cross-platform safety. If we still fail to write, we fall back to appending into api.log. """ - log_dir = get_log_directory() - filepath = _build_log_filepath(log_dir, operation_id, request_url) - - log_content: list[str] = [] - log_content.append(f"Timestamp: {datetime.datetime.now().isoformat()}") - log_content.append(f"Operation ID: {operation_id}") - log_content.append("-" * 30 + " REQUEST " + "-" * 30) - log_content.append(f"Method: {request_method}") - log_content.append(f"URL: {request_url}") - if request_headers: - log_content.append(f"Headers:\n{_format_data_for_logging(request_headers)}") - if request_params: - log_content.append(f"Params:\n{_format_data_for_logging(request_params)}") - if request_data is not None: - log_content.append(f"Data/Body:\n{_format_data_for_logging(request_data)}") - - log_content.append("\n" + "-" * 30 + " RESPONSE " + "-" * 30) - if response_status_code is not None: - log_content.append(f"Status Code: {response_status_code}") - if response_headers: - log_content.append(f"Headers:\n{_format_data_for_logging(response_headers)}") - if response_content is not None: - log_content.append(f"Content:\n{_format_data_for_logging(response_content)}") - if error_message: - log_content.append(f"Error:\n{error_message}") - try: - with open(filepath, "w", encoding="utf-8") as f: - f.write("\n".join(log_content)) - logger.debug("API log saved to: %s", filepath) - except Exception as e: - logger.error("Error writing API log to %s: %s", filepath, str(e)) + log_dir = get_log_directory() + filepath = _build_log_filepath(log_dir, operation_id, request_url) + + log_content: list[str] = [] + log_content.append(f"Timestamp: {datetime.datetime.now().isoformat()}") + log_content.append(f"Operation ID: {operation_id}") + log_content.append("-" * 30 + " REQUEST " + "-" * 30) + log_content.append(f"Method: {request_method}") + log_content.append(f"URL: {request_url}") + if request_headers: + log_content.append(f"Headers:\n{_format_data_for_logging(request_headers)}") + if request_params: + log_content.append(f"Params:\n{_format_data_for_logging(request_params)}") + if request_data is not None: + log_content.append(f"Data/Body:\n{_format_data_for_logging(request_data)}") + + log_content.append("\n" + "-" * 30 + " RESPONSE " + "-" * 30) + if response_status_code is not None: + log_content.append(f"Status Code: {response_status_code}") + if response_headers: + log_content.append(f"Headers:\n{_format_data_for_logging(response_headers)}") + if response_content is not None: + log_content.append(f"Content:\n{_format_data_for_logging(response_content)}") + if error_message: + log_content.append(f"Error:\n{error_message}") + + try: + with open(filepath, "w", encoding="utf-8") as f: + f.write("\n".join(log_content)) + logger.debug("API log saved to: %s", filepath) + except Exception as e: + logger.error("Error writing API log to %s: %s", filepath, str(e)) + except Exception as _log_e: + logging.debug("[DEBUG] log_request_response failed: %s", _log_e) if __name__ == '__main__': diff --git a/comfy_api_nodes/util/upload_helpers.py b/comfy_api_nodes/util/upload_helpers.py index 83d936ce1..7cc565263 100644 --- a/comfy_api_nodes/util/upload_helpers.py +++ b/comfy_api_nodes/util/upload_helpers.py @@ -255,17 +255,14 @@ async def upload_file( monitor_task = asyncio.create_task(_monitor()) sess: aiohttp.ClientSession | None = None try: - try: - request_logger.log_request_response( - operation_id=operation_id, - request_method="PUT", - request_url=upload_url, - request_headers=headers or None, - request_params=None, - request_data=f"[File data {len(data)} bytes]", - ) - except Exception as e: - logging.debug("[DEBUG] upload request logging failed: %s", e) + request_logger.log_request_response( + operation_id=operation_id, + request_method="PUT", + request_url=upload_url, + request_headers=headers or None, + request_params=None, + request_data=f"[File data {len(data)} bytes]", + ) sess = aiohttp.ClientSession(timeout=timeout) req = sess.put(upload_url, data=data, headers=headers, skip_auto_headers=skip_auto_headers) @@ -311,31 +308,27 @@ async def upload_file( delay *= retry_backoff continue raise Exception(f"Failed to upload (HTTP {resp.status}).") - try: - request_logger.log_request_response( - operation_id=operation_id, - request_method="PUT", - request_url=upload_url, - response_status_code=resp.status, - response_headers=dict(resp.headers), - response_content="File uploaded successfully.", - ) - except Exception as e: - logging.debug("[DEBUG] upload response logging failed: %s", e) + request_logger.log_request_response( + operation_id=operation_id, + request_method="PUT", + request_url=upload_url, + response_status_code=resp.status, + response_headers=dict(resp.headers), + response_content="File uploaded successfully.", + ) return except asyncio.CancelledError: raise ProcessingInterrupted("Task cancelled") from None except (aiohttp.ClientError, OSError) as e: if attempt <= max_retries: - with contextlib.suppress(Exception): - request_logger.log_request_response( - operation_id=operation_id, - request_method="PUT", - request_url=upload_url, - request_headers=headers or None, - request_data=f"[File data {len(data)} bytes]", - error_message=f"{type(e).__name__}: {str(e)} (will retry)", - ) + request_logger.log_request_response( + operation_id=operation_id, + request_method="PUT", + request_url=upload_url, + request_headers=headers or None, + request_data=f"[File data {len(data)} bytes]", + error_message=f"{type(e).__name__}: {str(e)} (will retry)", + ) await sleep_with_interrupt( delay, cls, From 117e2143543dd649d47345e183748a82d48d12d3 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Thu, 12 Feb 2026 16:51:50 -0800 Subject: [PATCH 096/215] ModelPatcherDynamic: force load non leaf weights (#12433) The current behaviour of the default ModelPatcher is to .to a model only if its fully loaded, which is how random non-leaf weights get loaded in non-LowVRAM conditions. The however means they never get loaded in dynamic_vram. In the dynamic_vram case, force load them to the GPU. --- comfy/model_patcher.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index f278fccac..b1d907ba4 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -679,18 +679,19 @@ class ModelPatcher: for key in list(self.pinned): self.unpin_weight(key) - def _load_list(self, prio_comfy_cast_weights=False): + def _load_list(self, prio_comfy_cast_weights=False, default_device=None): loading = [] for n, m in self.model.named_modules(): - params = [] - skip = False - for name, param in m.named_parameters(recurse=False): - params.append(name) + default = False + params = { name: param for name, param in m.named_parameters(recurse=False) } for name, param in m.named_parameters(recurse=True): if name not in params: - skip = True # skip random weights in non leaf modules + default = True # default random weights in non leaf modules break - if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0): + if default and default_device is not None: + for param in params.values(): + param.data = param.data.to(device=default_device) + if not default and (hasattr(m, "comfy_cast_weights") or len(params) > 0): module_mem = comfy.model_management.module_size(m) module_offload_mem = module_mem if hasattr(m, "comfy_cast_weights"): @@ -1495,7 +1496,7 @@ class ModelPatcherDynamic(ModelPatcher): #with pin and unpin syncrhonization which can be expensive for small weights #with a high layer rate (e.g. autoregressive LLMs). #prioritize the non-comfy weights (note the order reverse). - loading = self._load_list(prio_comfy_cast_weights=True) + loading = self._load_list(prio_comfy_cast_weights=True, default_device=device_to) loading.sort(reverse=True) for x in loading: @@ -1579,7 +1580,7 @@ class ModelPatcherDynamic(ModelPatcher): return 0 if vbar is None else vbar.free_memory(memory_to_free) def partially_unload_ram(self, ram_to_unload): - loading = self._load_list(prio_comfy_cast_weights=True) + loading = self._load_list(prio_comfy_cast_weights=True, default_device=self.offload_device) for x in loading: _, _, _, _, m, _ = x ram_to_unload -= comfy.pinned_memory.unpin_memory(m) From ae79e33345dc893f8e0632c380c0e91dc09ac6e8 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Thu, 12 Feb 2026 16:56:42 -0800 Subject: [PATCH 097/215] llama: use a more efficient rope implementation (#12434) Get rid of the cat and unary negation and inplace add-cmul the two halves of the rope. Precompute -sin once at the start of the model rather than every transformer block. This is slightly faster on both GPU and CPU bound setups. --- comfy/text_encoders/llama.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index b6735d210..54f3d5595 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -355,13 +355,6 @@ class RMSNorm(nn.Module): -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_dims=None, device=None): if not isinstance(theta, list): theta = [theta] @@ -390,20 +383,30 @@ def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_di else: cos = cos.unsqueeze(1) sin = sin.unsqueeze(1) - out.append((cos, sin)) + sin_split = sin.shape[-1] // 2 + out.append((cos, sin[..., : sin_split], -sin[..., sin_split :])) if len(out) == 1: return out[0] return out - def apply_rope(xq, xk, freqs_cis): org_dtype = xq.dtype cos = freqs_cis[0] sin = freqs_cis[1] - q_embed = (xq * cos) + (rotate_half(xq) * sin) - k_embed = (xk * cos) + (rotate_half(xk) * sin) + nsin = freqs_cis[2] + + q_embed = (xq * cos) + q_split = q_embed.shape[-1] // 2 + q_embed[..., : q_split].addcmul_(xq[..., q_split :], nsin) + q_embed[..., q_split :].addcmul_(xq[..., : q_split], sin) + + k_embed = (xk * cos) + k_split = k_embed.shape[-1] // 2 + k_embed[..., : k_split].addcmul_(xk[..., k_split :], nsin) + k_embed[..., k_split :].addcmul_(xk[..., : k_split], sin) + return q_embed.to(org_dtype), k_embed.to(org_dtype) From e03fe8b5919a23a473cea6e53f916f7403c082a5 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 12 Feb 2026 20:29:12 -0800 Subject: [PATCH 098/215] Update command to install AMD stable linux pytorch. (#12437) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 96dc2904b..3ccdc9c19 100644 --- a/README.md +++ b/README.md @@ -227,7 +227,7 @@ Put your VAE in: models/vae AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version: -```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.4``` +```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm7.1``` This is the command to install the nightly with ROCm 7.1 which might have some performance improvements: From 8902907d7ab949ce42dd9b658b4a4582ed9fb630 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Fri, 13 Feb 2026 12:29:37 -0800 Subject: [PATCH 099/215] dynamic_vram: Training fixes (#12442) --- comfy/model_patcher.py | 4 ++++ comfy_extras/nodes_train.py | 11 ++++++++++- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index b1d907ba4..67dce088e 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -1561,6 +1561,8 @@ class ModelPatcherDynamic(ModelPatcher): allocated_size += weight_size vbar.set_watermark_limit(allocated_size) + move_weight_functions(m, device_to) + logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.") self.model.device = device_to @@ -1601,6 +1603,8 @@ class ModelPatcherDynamic(ModelPatcher): if unpatch_weights: self.partially_unload_ram(1e32) self.partially_unload(None, 1e32) + for m in self.model.modules(): + move_weight_functions(m, device_to) def partially_load(self, device_to, extra_memory=0, force_patch_weights=False): assert not force_patch_weights #See above diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py index 630eedc9f..aa2d88673 100644 --- a/comfy_extras/nodes_train.py +++ b/comfy_extras/nodes_train.py @@ -1035,7 +1035,7 @@ class TrainLoraNode(io.ComfyNode): io.Boolean.Input( "offloading", default=False, - tooltip="Depth level for gradient checkpointing.", + tooltip="Offload the Model to RAM. Requires Bypass Mode.", ), io.Combo.Input( "existing_lora", @@ -1124,6 +1124,15 @@ class TrainLoraNode(io.ComfyNode): lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype) mp.set_model_compute_dtype(dtype) + if mp.is_dynamic(): + if not bypass_mode: + logging.info("Training MP is Dynamic - forcing bypass mode. Start comfy with --highvram to force weight diff mode") + bypass_mode = True + offloading = True + elif offloading: + if not bypass_mode: + logging.info("Training Offload selected - forcing bypass mode. Set bypass = True to remove this message") + # Prepare latents and compute counts latents, num_images, multi_res = _prepare_latents_and_count( latents, dtype, bucket_mode From e1add563f9e89026e8c4e8825a2b279fbd67d23a Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 13 Feb 2026 12:35:13 -0800 Subject: [PATCH 100/215] Use torch RMSNorm for flux models and refactor hunyuan video code. (#12432) --- comfy/controlnet.py | 1 + comfy/ldm/chroma/layers.py | 3 +- comfy/ldm/chroma_radiance/layers.py | 8 ++--- comfy/ldm/flux/layers.py | 53 +++++++---------------------- comfy/ldm/flux/model.py | 3 +- comfy/ldm/hunyuan_video/model.py | 11 +++--- comfy/lora_convert.py | 2 +- comfy/model_detection.py | 18 +++++++--- comfy/supported_models.py | 32 +++++++++++++++-- comfy/utils.py | 12 +++---- 10 files changed, 74 insertions(+), 69 deletions(-) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 9e1e704e0..8336412f2 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -560,6 +560,7 @@ def load_controlnet_hunyuandit(controlnet_data, model_options={}): def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False, model_options={}): model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options) control_model = comfy.ldm.flux.controlnet.ControlNetFlux(mistoline=mistoline, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config) + sd = model_config.process_unet_state_dict(sd) control_model = controlnet_load_state_dict(control_model, sd) extra_conds = ['y', 'guidance'] control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds) diff --git a/comfy/ldm/chroma/layers.py b/comfy/ldm/chroma/layers.py index 2d5684348..df348a8ed 100644 --- a/comfy/ldm/chroma/layers.py +++ b/comfy/ldm/chroma/layers.py @@ -3,7 +3,6 @@ from torch import Tensor, nn from comfy.ldm.flux.layers import ( MLPEmbedder, - RMSNorm, ModulationOut, ) @@ -29,7 +28,7 @@ class Approximator(nn.Module): super().__init__() self.in_proj = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device) self.layers = nn.ModuleList([MLPEmbedder(hidden_dim, hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)]) - self.norms = nn.ModuleList([RMSNorm(hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)]) + self.norms = nn.ModuleList([operations.RMSNorm(hidden_dim, dtype=dtype, device=device) for x in range( n_layers)]) self.out_proj = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device) @property diff --git a/comfy/ldm/chroma_radiance/layers.py b/comfy/ldm/chroma_radiance/layers.py index 3c7bc9b6b..08d31e0ba 100644 --- a/comfy/ldm/chroma_radiance/layers.py +++ b/comfy/ldm/chroma_radiance/layers.py @@ -4,8 +4,6 @@ from functools import lru_cache import torch from torch import nn -from comfy.ldm.flux.layers import RMSNorm - class NerfEmbedder(nn.Module): """ @@ -145,7 +143,7 @@ class NerfGLUBlock(nn.Module): # We now need to generate parameters for 3 matrices. total_params = 3 * hidden_size_x**2 * mlp_ratio self.param_generator = operations.Linear(hidden_size_s, total_params, dtype=dtype, device=device) - self.norm = RMSNorm(hidden_size_x, dtype=dtype, device=device, operations=operations) + self.norm = operations.RMSNorm(hidden_size_x, dtype=dtype, device=device) self.mlp_ratio = mlp_ratio @@ -178,7 +176,7 @@ class NerfGLUBlock(nn.Module): class NerfFinalLayer(nn.Module): def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None): super().__init__() - self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations) + self.norm = operations.RMSNorm(hidden_size, dtype=dtype, device=device) self.linear = operations.Linear(hidden_size, out_channels, dtype=dtype, device=device) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -190,7 +188,7 @@ class NerfFinalLayer(nn.Module): class NerfFinalLayerConv(nn.Module): def __init__(self, hidden_size: int, out_channels: int, dtype=None, device=None, operations=None): super().__init__() - self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations) + self.norm = operations.RMSNorm(hidden_size, dtype=dtype, device=device) self.conv = operations.Conv2d( in_channels=hidden_size, out_channels=out_channels, diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py index 60f2bdae2..1f2975fb1 100644 --- a/comfy/ldm/flux/layers.py +++ b/comfy/ldm/flux/layers.py @@ -5,8 +5,6 @@ import torch from torch import Tensor, nn from .math import attention, rope -import comfy.ops -import comfy.ldm.common_dit class EmbedND(nn.Module): @@ -87,20 +85,12 @@ def build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=False, yak_mlp=False, dt operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), ) -class RMSNorm(torch.nn.Module): - def __init__(self, dim: int, dtype=None, device=None, operations=None): - super().__init__() - self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device)) - - def forward(self, x: Tensor): - return comfy.ldm.common_dit.rms_norm(x, self.scale, 1e-6) - class QKNorm(torch.nn.Module): def __init__(self, dim: int, dtype=None, device=None, operations=None): super().__init__() - self.query_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations) - self.key_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations) + self.query_norm = operations.RMSNorm(dim, dtype=dtype, device=device) + self.key_norm = operations.RMSNorm(dim, dtype=dtype, device=device) def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple: q = self.query_norm(q) @@ -169,7 +159,7 @@ class SiLUActivation(nn.Module): class DoubleStreamBlock(nn.Module): - def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None): + def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None): super().__init__() mlp_hidden_dim = int(hidden_size * mlp_ratio) @@ -197,8 +187,6 @@ class DoubleStreamBlock(nn.Module): self.txt_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations) - self.flipped_img_txt = flipped_img_txt - def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}): if self.modulation: img_mod1, img_mod2 = self.img_mod(vec) @@ -224,32 +212,17 @@ class DoubleStreamBlock(nn.Module): del txt_qkv txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v) - if self.flipped_img_txt: - q = torch.cat((img_q, txt_q), dim=2) - del img_q, txt_q - k = torch.cat((img_k, txt_k), dim=2) - del img_k, txt_k - v = torch.cat((img_v, txt_v), dim=2) - del img_v, txt_v - # run actual attention - attn = attention(q, k, v, - pe=pe, mask=attn_mask, transformer_options=transformer_options) - del q, k, v + q = torch.cat((txt_q, img_q), dim=2) + del txt_q, img_q + k = torch.cat((txt_k, img_k), dim=2) + del txt_k, img_k + v = torch.cat((txt_v, img_v), dim=2) + del txt_v, img_v + # run actual attention + attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options) + del q, k, v - img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:] - else: - q = torch.cat((txt_q, img_q), dim=2) - del txt_q, img_q - k = torch.cat((txt_k, img_k), dim=2) - del txt_k, img_k - v = torch.cat((txt_v, img_v), dim=2) - del txt_v, img_v - # run actual attention - attn = attention(q, k, v, - pe=pe, mask=attn_mask, transformer_options=transformer_options) - del q, k, v - - txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:] + txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:] # calculate the img bloks img += apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img) diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index f40c2a7a9..260ccad7e 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -16,7 +16,6 @@ from .layers import ( SingleStreamBlock, timestep_embedding, Modulation, - RMSNorm ) @dataclass @@ -81,7 +80,7 @@ class Flux(nn.Module): self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device) if params.txt_norm: - self.txt_norm = RMSNorm(params.context_in_dim, dtype=dtype, device=device, operations=operations) + self.txt_norm = operations.RMSNorm(params.context_in_dim, dtype=dtype, device=device) else: self.txt_norm = None diff --git a/comfy/ldm/hunyuan_video/model.py b/comfy/ldm/hunyuan_video/model.py index 55ab550f8..563f28f6b 100644 --- a/comfy/ldm/hunyuan_video/model.py +++ b/comfy/ldm/hunyuan_video/model.py @@ -241,7 +241,6 @@ class HunyuanVideo(nn.Module): self.num_heads, mlp_ratio=params.mlp_ratio, qkv_bias=params.qkv_bias, - flipped_img_txt=True, dtype=dtype, device=device, operations=operations ) for _ in range(params.depth) @@ -378,14 +377,14 @@ class HunyuanVideo(nn.Module): extra_txt_ids = torch.zeros((txt_ids.shape[0], txt_vision_states.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype) txt_ids = torch.cat((txt_ids, extra_txt_ids), dim=1) - ids = torch.cat((img_ids, txt_ids), dim=1) + ids = torch.cat((txt_ids, img_ids), dim=1) pe = self.pe_embedder(ids) img_len = img.shape[1] if txt_mask is not None: attn_mask_len = img_len + txt.shape[1] attn_mask = torch.zeros((1, 1, attn_mask_len), dtype=img.dtype, device=img.device) - attn_mask[:, 0, img_len:] = txt_mask + attn_mask[:, 0, :txt.shape[1]] = txt_mask else: attn_mask = None @@ -413,7 +412,7 @@ class HunyuanVideo(nn.Module): if add is not None: img += add - img = torch.cat((img, txt), 1) + img = torch.cat((txt, img), 1) transformer_options["total_blocks"] = len(self.single_blocks) transformer_options["block_type"] = "single" @@ -435,9 +434,9 @@ class HunyuanVideo(nn.Module): if i < len(control_o): add = control_o[i] if add is not None: - img[:, : img_len] += add + img[:, txt.shape[1]: img_len + txt.shape[1]] += add - img = img[:, : img_len] + img = img[:, txt.shape[1]: img_len + txt.shape[1]] if ref_latent is not None: img = img[:, ref_latent.shape[1]:] diff --git a/comfy/lora_convert.py b/comfy/lora_convert.py index 9d8d21efe..749e81df3 100644 --- a/comfy/lora_convert.py +++ b/comfy/lora_convert.py @@ -5,7 +5,7 @@ import comfy.utils def convert_lora_bfl_control(sd): #BFL loras for Flux sd_out = {} for k in sd: - k_to = "diffusion_model.{}".format(k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.scale.set_weight")) + k_to = "diffusion_model.{}".format(k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.set_weight")) sd_out[k_to] = sd[k] sd_out["diffusion_model.img_in.reshape_weight"] = torch.tensor([sd["img_in.lora_B.weight"].shape[0], sd["img_in.lora_A.weight"].shape[1]]) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index e8ad725df..30ea03e8e 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -19,6 +19,12 @@ def count_blocks(state_dict_keys, prefix_string): count += 1 return count +def any_suffix_in(keys, prefix, main, suffix_list=[]): + for x in suffix_list: + if "{}{}{}".format(prefix, main, x) in keys: + return True + return False + def calculate_transformer_depth(prefix, state_dict_keys, state_dict): context_dim = None use_linear_in_transformer = False @@ -186,7 +192,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["meanflow_sum"] = False return dit_config - if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight) + if any_suffix_in(state_dict_keys, key_prefix, 'double_blocks.0.img_attn.norm.key_norm.', ["weight", "scale"]) and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or any_suffix_in(state_dict_keys, key_prefix, 'distilled_guidance_layer.norms.0.', ["weight", "scale"])): #Flux, Chroma or Chroma Radiance (has no img_in.weight) dit_config = {} if '{}double_stream_modulation_img.lin.weight'.format(key_prefix) in state_dict_keys: dit_config["image_model"] = "flux2" @@ -241,7 +247,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.') dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.') - if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: #Chroma + + if any_suffix_in(state_dict_keys, key_prefix, 'distilled_guidance_layer.0.norms.0.', ["weight", "scale"]) or any_suffix_in(state_dict_keys, key_prefix, 'distilled_guidance_layer.norms.0.', ["weight", "scale"]): #Chroma dit_config["image_model"] = "chroma" dit_config["in_channels"] = 64 dit_config["out_channels"] = 64 @@ -249,7 +256,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["out_dim"] = 3072 dit_config["hidden_dim"] = 5120 dit_config["n_layers"] = 5 - if f"{key_prefix}nerf_blocks.0.norm.scale" in state_dict_keys: #Chroma Radiance + + if any_suffix_in(state_dict_keys, key_prefix, 'nerf_blocks.0.norm.', ["weight", "scale"]): #Chroma Radiance dit_config["image_model"] = "chroma_radiance" dit_config["in_channels"] = 3 dit_config["out_channels"] = 3 @@ -259,7 +267,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["nerf_depth"] = 4 dit_config["nerf_max_freqs"] = 8 dit_config["nerf_tile_size"] = 512 - dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear" + dit_config["nerf_final_head_type"] = "conv" if any_suffix_in(state_dict_keys, key_prefix, 'nerf_final_layer_conv.norm.', ["weight", "scale"]) else "linear" dit_config["nerf_embedder_dtype"] = torch.float32 if "{}__x0__".format(key_prefix) in state_dict_keys: # x0 pred dit_config["use_x0"] = True @@ -268,7 +276,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): else: dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys - dit_config["txt_norm"] = "{}txt_norm.scale".format(key_prefix) in state_dict_keys + dit_config["txt_norm"] = any_suffix_in(state_dict_keys, key_prefix, 'txt_norm.', ["weight", "scale"]) if dit_config["yak_mlp"] and dit_config["txt_norm"]: # Ovis model dit_config["txt_ids_dims"] = [1, 2] diff --git a/comfy/supported_models.py b/comfy/supported_models.py index d33db7507..c28be1716 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -710,6 +710,15 @@ class Flux(supported_models_base.BASE): supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] + def process_unet_state_dict(self, state_dict): + out_sd = {} + for k in list(state_dict.keys()): + key_out = k + if key_out.endswith("_norm.scale"): + key_out = "{}.weight".format(key_out[:-len(".scale")]) + out_sd[key_out] = state_dict[k] + return out_sd + vae_key_prefix = ["vae."] text_encoder_key_prefix = ["text_encoders."] @@ -898,11 +907,13 @@ class HunyuanVideo(supported_models_base.BASE): key_out = key_out.replace("txt_in.c_embedder.linear_1.", "txt_in.c_embedder.in_layer.").replace("txt_in.c_embedder.linear_2.", "txt_in.c_embedder.out_layer.") key_out = key_out.replace("_mod.linear.", "_mod.lin.").replace("_attn_qkv.", "_attn.qkv.") key_out = key_out.replace("mlp.fc1.", "mlp.0.").replace("mlp.fc2.", "mlp.2.") - key_out = key_out.replace("_attn_q_norm.weight", "_attn.norm.query_norm.scale").replace("_attn_k_norm.weight", "_attn.norm.key_norm.scale") - key_out = key_out.replace(".q_norm.weight", ".norm.query_norm.scale").replace(".k_norm.weight", ".norm.key_norm.scale") + key_out = key_out.replace("_attn_q_norm.weight", "_attn.norm.query_norm.weight").replace("_attn_k_norm.weight", "_attn.norm.key_norm.weight") + key_out = key_out.replace(".q_norm.weight", ".norm.query_norm.weight").replace(".k_norm.weight", ".norm.key_norm.weight") key_out = key_out.replace("_attn_proj.", "_attn.proj.") key_out = key_out.replace(".modulation.linear.", ".modulation.lin.") key_out = key_out.replace("_in.mlp.2.", "_in.out_layer.").replace("_in.mlp.0.", "_in.in_layer.") + if key_out.endswith(".scale"): + key_out = "{}.weight".format(key_out[:-len(".scale")]) out_sd[key_out] = state_dict[k] return out_sd @@ -1264,6 +1275,15 @@ class Hunyuan3Dv2(supported_models_base.BASE): latent_format = latent_formats.Hunyuan3Dv2 + def process_unet_state_dict(self, state_dict): + out_sd = {} + for k in list(state_dict.keys()): + key_out = k + if key_out.endswith(".scale"): + key_out = "{}.weight".format(key_out[:-len(".scale")]) + out_sd[key_out] = state_dict[k] + return out_sd + def process_unet_state_dict_for_saving(self, state_dict): replace_prefix = {"": "model."} return utils.state_dict_prefix_replace(state_dict, replace_prefix) @@ -1341,6 +1361,14 @@ class Chroma(supported_models_base.BASE): supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] + def process_unet_state_dict(self, state_dict): + out_sd = {} + for k in list(state_dict.keys()): + key_out = k + if key_out.endswith(".scale"): + key_out = "{}.weight".format(key_out[:-len(".scale")]) + out_sd[key_out] = state_dict[k] + return out_sd def get_model(self, state_dict, prefix="", device=None): out = model_base.Chroma(self, device=device) diff --git a/comfy/utils.py b/comfy/utils.py index e0a94e2e1..d553a7c1b 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -675,10 +675,10 @@ def flux_to_diffusers(mmdit_config, output_prefix=""): "ff_context.linear_in.bias": "txt_mlp.0.bias", "ff_context.linear_out.weight": "txt_mlp.2.weight", "ff_context.linear_out.bias": "txt_mlp.2.bias", - "attn.norm_q.weight": "img_attn.norm.query_norm.scale", - "attn.norm_k.weight": "img_attn.norm.key_norm.scale", - "attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale", - "attn.norm_added_k.weight": "txt_attn.norm.key_norm.scale", + "attn.norm_q.weight": "img_attn.norm.query_norm.weight", + "attn.norm_k.weight": "img_attn.norm.key_norm.weight", + "attn.norm_added_q.weight": "txt_attn.norm.query_norm.weight", + "attn.norm_added_k.weight": "txt_attn.norm.key_norm.weight", } for k in block_map: @@ -701,8 +701,8 @@ def flux_to_diffusers(mmdit_config, output_prefix=""): "norm.linear.bias": "modulation.lin.bias", "proj_out.weight": "linear2.weight", "proj_out.bias": "linear2.bias", - "attn.norm_q.weight": "norm.query_norm.scale", - "attn.norm_k.weight": "norm.key_norm.scale", + "attn.norm_q.weight": "norm.query_norm.weight", + "attn.norm_k.weight": "norm.key_norm.weight", "attn.to_qkv_mlp_proj.weight": "linear1.weight", # Flux 2 "attn.to_out.weight": "linear2.weight", # Flux 2 } From 831351a29e91ea758437227c2f3c915a6be6d1a6 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 13 Feb 2026 17:15:23 -0800 Subject: [PATCH 101/215] Support generating attention masks for left padded text encoders. (#12454) --- comfy/sd1_clip.py | 15 +++++++++++---- comfy/text_encoders/ace15.py | 8 +------- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 4c817d468..b564d1529 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -171,8 +171,9 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): def process_tokens(self, tokens, device): end_token = self.special_tokens.get("end", None) + pad_token = self.special_tokens.get("pad", -1) if end_token is None: - cmp_token = self.special_tokens.get("pad", -1) + cmp_token = pad_token else: cmp_token = end_token @@ -186,15 +187,21 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): other_embeds = [] eos = False index = 0 + left_pad = False for y in x: if isinstance(y, numbers.Integral): - if eos: + token = int(y) + if index == 0 and token == pad_token: + left_pad = True + + if eos or (left_pad and token == pad_token): attention_mask.append(0) else: attention_mask.append(1) - token = int(y) + left_pad = False + tokens_temp += [token] - if not eos and token == cmp_token: + if not eos and token == cmp_token and not left_pad: if end_token is None: attention_mask[-1] = 0 eos = True diff --git a/comfy/text_encoders/ace15.py b/comfy/text_encoders/ace15.py index 0fdd4669f..f135d74c1 100644 --- a/comfy/text_encoders/ace15.py +++ b/comfy/text_encoders/ace15.py @@ -10,7 +10,6 @@ import comfy.utils def sample_manual_loop_no_classes( model, ids=None, - paddings=[], execution_dtype=None, cfg_scale: float = 2.0, temperature: float = 0.85, @@ -36,9 +35,6 @@ def sample_manual_loop_no_classes( embeds, attention_mask, num_tokens, embeds_info = model.process_tokens(ids, device) embeds_batch = embeds.shape[0] - for i, t in enumerate(paddings): - attention_mask[i, :t] = 0 - attention_mask[i, t:] = 1 output_audio_codes = [] past_key_values = [] @@ -135,13 +131,11 @@ def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=102 pos_pad = (len(negative) - len(positive)) positive = [model.special_tokens["pad"]] * pos_pad + positive - paddings = [pos_pad, neg_pad] ids = [positive, negative] else: - paddings = [] ids = [positive] - return sample_manual_loop_no_classes(model, ids, paddings, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens) + return sample_manual_loop_no_classes(model, ids, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens) class ACE15Tokenizer(sd1_clip.SD1Tokenizer): From 726af73867c18c5ca8b980a2c28401d77e5b365a Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 13 Feb 2026 17:21:10 -0800 Subject: [PATCH 102/215] Fix some custom nodes. (#12455) --- comfy/ldm/flux/layers.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py index 1f2975fb1..3518a1922 100644 --- a/comfy/ldm/flux/layers.py +++ b/comfy/ldm/flux/layers.py @@ -6,6 +6,8 @@ from torch import Tensor, nn from .math import attention, rope +# Fix import for some custom nodes, TODO: delete eventually. +RMSNorm = None class EmbedND(nn.Module): def __init__(self, dim: int, theta: int, axes_dim: list): From 712efb466b9379e6761802c44027783d37d96a87 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 13 Feb 2026 18:56:54 -0800 Subject: [PATCH 103/215] Add left padding to LTXAV text encoder. (#12456) --- comfy/text_encoders/lt.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py index 3f87dfd6a..9cf87c0b2 100644 --- a/comfy/text_encoders/lt.py +++ b/comfy/text_encoders/lt.py @@ -25,7 +25,7 @@ def ltxv_te(*args, **kwargs): class Gemma3_12BTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer = tokenizer_data.get("spiece_model", None) - super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data) + super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_left=True, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data) def state_dict(self): return {"spiece_model": self.tokenizer.serialize_model()} @@ -97,6 +97,7 @@ class LTXAVTEModel(torch.nn.Module): token_weight_pairs = token_weight_pairs["gemma3_12b"] out, pooled, extra = self.gemma3_12b.encode_token_weights(token_weight_pairs) + out = out[:, :, -torch.sum(extra["attention_mask"]).item():] out_device = out.device if comfy.model_management.should_use_bf16(self.execution_device): out = out.to(device=self.execution_device, dtype=torch.bfloat16) @@ -138,6 +139,7 @@ class LTXAVTEModel(torch.nn.Module): token_weight_pairs = token_weight_pairs.get("gemma3_12b", []) num_tokens = sum(map(lambda a: len(a), token_weight_pairs)) + num_tokens = max(num_tokens, 64) return num_tokens * constant * 1024 * 1024 def ltxav_te(dtype_llama=None, llama_quantization_metadata=None): From dc9822b7df4785e77690e93b0e09feaff01e2e12 Mon Sep 17 00:00:00 2001 From: krigeta <75309361+krigeta@users.noreply.github.com> Date: Sat, 14 Feb 2026 08:53:52 +0530 Subject: [PATCH 104/215] Add working Qwen 2512 ControlNet (Fun ControlNet) support (#12359) --- comfy/controlnet.py | 73 +++++++++++ comfy/ldm/qwen_image/controlnet.py | 190 +++++++++++++++++++++++++++++ 2 files changed, 263 insertions(+) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 8336412f2..ba670b16d 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -297,6 +297,30 @@ class ControlNet(ControlBase): self.model_sampling_current = None super().cleanup() + +class QwenFunControlNet(ControlNet): + def get_control(self, x_noisy, t, cond, batched_number, transformer_options): + # Fun checkpoints are more sensitive to high strengths in the generic + # ControlNet merge path. Use a soft response curve so strength=1.0 stays + # unchanged while >1 grows more gently. + original_strength = self.strength + self.strength = math.sqrt(max(self.strength, 0.0)) + try: + return super().get_control(x_noisy, t, cond, batched_number, transformer_options) + finally: + self.strength = original_strength + + def pre_run(self, model, percent_to_timestep_function): + super().pre_run(model, percent_to_timestep_function) + self.set_extra_arg("base_model", model.diffusion_model) + + def copy(self): + c = QwenFunControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype) + c.control_model = self.control_model + c.control_model_wrapped = self.control_model_wrapped + self.copy_to(c) + return c + class ControlLoraOps: class Linear(torch.nn.Module, comfy.ops.CastWeightBiasOp): def __init__(self, in_features: int, out_features: int, bias: bool = True, @@ -606,6 +630,53 @@ def load_controlnet_qwen_instantx(sd, model_options={}): control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds) return control + +def load_controlnet_qwen_fun(sd, model_options={}): + load_device = comfy.model_management.get_torch_device() + weight_dtype = comfy.utils.weight_dtype(sd) + unet_dtype = model_options.get("dtype", weight_dtype) + manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device) + + operations = model_options.get("custom_operations", None) + if operations is None: + operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True) + + in_features = sd["control_img_in.weight"].shape[1] + inner_dim = sd["control_img_in.weight"].shape[0] + + block_weight = sd["control_blocks.0.attn.to_q.weight"] + attention_head_dim = sd["control_blocks.0.attn.norm_q.weight"].shape[0] + num_attention_heads = max(1, block_weight.shape[0] // max(1, attention_head_dim)) + + model = comfy.ldm.qwen_image.controlnet.QwenImageFunControlNetModel( + control_in_features=in_features, + inner_dim=inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + num_control_blocks=5, + main_model_double=60, + injection_layers=(0, 12, 24, 36, 48), + operations=operations, + device=comfy.model_management.unet_offload_device(), + dtype=unet_dtype, + ) + model = controlnet_load_state_dict(model, sd) + + latent_format = comfy.latent_formats.Wan21() + control = QwenFunControlNet( + model, + compression_ratio=1, + latent_format=latent_format, + # Fun checkpoints already expect their own 33-channel context handling. + # Enabling generic concat_mask injects an extra mask channel at apply-time + # and breaks the intended fallback packing path. + concat_mask=False, + load_device=load_device, + manual_cast_dtype=manual_cast_dtype, + extra_conds=[], + ) + return control + def convert_mistoline(sd): return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."}) @@ -683,6 +754,8 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}): return load_controlnet_qwen_instantx(controlnet_data, model_options=model_options) elif "controlnet_x_embedder.weight" in controlnet_data: return load_controlnet_flux_instantx(controlnet_data, model_options=model_options) + elif "control_blocks.0.after_proj.weight" in controlnet_data and "control_img_in.weight" in controlnet_data: + return load_controlnet_qwen_fun(controlnet_data, model_options=model_options) elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options) diff --git a/comfy/ldm/qwen_image/controlnet.py b/comfy/ldm/qwen_image/controlnet.py index a6d408104..c0aae9240 100644 --- a/comfy/ldm/qwen_image/controlnet.py +++ b/comfy/ldm/qwen_image/controlnet.py @@ -2,6 +2,196 @@ import torch import math from .model import QwenImageTransformer2DModel +from .model import QwenImageTransformerBlock + + +class QwenImageFunControlBlock(QwenImageTransformerBlock): + def __init__(self, dim, num_attention_heads, attention_head_dim, has_before_proj=False, dtype=None, device=None, operations=None): + super().__init__( + dim=dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + dtype=dtype, + device=device, + operations=operations, + ) + self.has_before_proj = has_before_proj + if has_before_proj: + self.before_proj = operations.Linear(dim, dim, device=device, dtype=dtype) + self.after_proj = operations.Linear(dim, dim, device=device, dtype=dtype) + + +class QwenImageFunControlNetModel(torch.nn.Module): + def __init__( + self, + control_in_features=132, + inner_dim=3072, + num_attention_heads=24, + attention_head_dim=128, + num_control_blocks=5, + main_model_double=60, + injection_layers=(0, 12, 24, 36, 48), + dtype=None, + device=None, + operations=None, + ): + super().__init__() + self.dtype = dtype + self.main_model_double = main_model_double + self.injection_layers = tuple(injection_layers) + # Keep base hint scaling at 1.0 so user-facing strength behaves similarly + # to the reference Gen2/VideoX implementation around strength=1. + self.hint_scale = 1.0 + self.control_img_in = operations.Linear(control_in_features, inner_dim, device=device, dtype=dtype) + + self.control_blocks = torch.nn.ModuleList([]) + for i in range(num_control_blocks): + self.control_blocks.append( + QwenImageFunControlBlock( + dim=inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + has_before_proj=(i == 0), + dtype=dtype, + device=device, + operations=operations, + ) + ) + + def _process_hint_tokens(self, hint): + if hint is None: + return None + if hint.ndim == 4: + hint = hint.unsqueeze(2) + + # Fun checkpoints are trained with 33 latent channels before 2x2 packing: + # [control_latent(16), mask(1), inpaint_latent(16)] -> 132 features. + # Default behavior (no inpaint input in stock Apply ControlNet) should use + # zeros for mask/inpaint branches, matching VideoX fallback semantics. + expected_c = self.control_img_in.weight.shape[1] // 4 + if hint.shape[1] == 16 and expected_c == 33: + zeros_mask = torch.zeros_like(hint[:, :1]) + zeros_inpaint = torch.zeros_like(hint) + hint = torch.cat([hint, zeros_mask, zeros_inpaint], dim=1) + + bs, c, t, h, w = hint.shape + hidden_states = torch.nn.functional.pad(hint, (0, w % 2, 0, h % 2)) + orig_shape = hidden_states.shape + hidden_states = hidden_states.view( + orig_shape[0], + orig_shape[1], + orig_shape[-3], + orig_shape[-2] // 2, + 2, + orig_shape[-1] // 2, + 2, + ) + hidden_states = hidden_states.permute(0, 2, 3, 5, 1, 4, 6) + hidden_states = hidden_states.reshape( + bs, + t * ((h + 1) // 2) * ((w + 1) // 2), + c * 4, + ) + + expected_in = self.control_img_in.weight.shape[1] + cur_in = hidden_states.shape[-1] + if cur_in < expected_in: + pad = torch.zeros( + (hidden_states.shape[0], hidden_states.shape[1], expected_in - cur_in), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + hidden_states = torch.cat([hidden_states, pad], dim=-1) + elif cur_in > expected_in: + hidden_states = hidden_states[:, :, :expected_in] + + return hidden_states + + def forward( + self, + x, + timesteps, + context, + attention_mask=None, + guidance: torch.Tensor = None, + hint=None, + transformer_options={}, + base_model=None, + **kwargs, + ): + if base_model is None: + raise RuntimeError("Qwen Fun ControlNet requires a QwenImage base model at runtime.") + + encoder_hidden_states_mask = attention_mask + # Keep attention mask disabled inside Fun control blocks to mirror + # VideoX behavior (they rely on seq lengths for RoPE, not masked attention). + encoder_hidden_states_mask = None + + hidden_states, img_ids, _ = base_model.process_img(x) + hint_tokens = self._process_hint_tokens(hint) + if hint_tokens is None: + raise RuntimeError("Qwen Fun ControlNet requires a control hint image.") + + if hint_tokens.shape[1] != hidden_states.shape[1]: + max_tokens = min(hint_tokens.shape[1], hidden_states.shape[1]) + hint_tokens = hint_tokens[:, :max_tokens] + hidden_states = hidden_states[:, :max_tokens] + img_ids = img_ids[:, :max_tokens] + + txt_start = round( + max( + ((x.shape[-1] + (base_model.patch_size // 2)) // base_model.patch_size) // 2, + ((x.shape[-2] + (base_model.patch_size // 2)) // base_model.patch_size) // 2, + ) + ) + txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3) + ids = torch.cat((txt_ids, img_ids), dim=1) + image_rotary_emb = base_model.pe_embedder(ids).to(x.dtype).contiguous() + + hidden_states = base_model.img_in(hidden_states) + encoder_hidden_states = base_model.txt_norm(context) + encoder_hidden_states = base_model.txt_in(encoder_hidden_states) + + if guidance is not None: + guidance = guidance * 1000 + + temb = ( + base_model.time_text_embed(timesteps, hidden_states) + if guidance is None + else base_model.time_text_embed(timesteps, guidance, hidden_states) + ) + + c = self.control_img_in(hint_tokens) + + for i, block in enumerate(self.control_blocks): + if i == 0: + c_in = block.before_proj(c) + hidden_states + all_c = [] + else: + all_c = list(torch.unbind(c, dim=0)) + c_in = all_c.pop(-1) + + encoder_hidden_states, c_out = block( + hidden_states=c_in, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_states_mask=encoder_hidden_states_mask, + temb=temb, + image_rotary_emb=image_rotary_emb, + transformer_options=transformer_options, + ) + + c_skip = block.after_proj(c_out) * self.hint_scale + all_c += [c_skip, c_out] + c = torch.stack(all_c, dim=0) + + hints = torch.unbind(c, dim=0)[:-1] + + controlnet_block_samples = [None] * self.main_model_double + for local_idx, base_idx in enumerate(self.injection_layers): + if local_idx < len(hints) and base_idx < len(controlnet_block_samples): + controlnet_block_samples[base_idx] = hints[local_idx] + + return {"input": controlnet_block_samples} class QwenImageControlNetModel(QwenImageTransformer2DModel): From df1e5e85142746a745a56572b705406b273a594c Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Sat, 14 Feb 2026 11:01:10 -0800 Subject: [PATCH 105/215] Update frontend package to 1.38.14 (#12469) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 7de6a413c..e939e486a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.38.13 +comfyui-frontend-package==1.38.14 comfyui-workflow-templates==0.8.38 comfyui-embedded-docs==0.4.1 torch From e1ede29d827d573262caede8aeb6cbc98c323c81 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 14 Feb 2026 19:53:52 -0800 Subject: [PATCH 106/215] Remove unsafe pickle loading code that was used on pytorch older than 2.4 (#12473) ComfyUI hasn't started on pytorch 2.4 since last month. --- comfy/checkpoint_pickle.py | 13 ------------- comfy/utils.py | 25 +++++++++++-------------- 2 files changed, 11 insertions(+), 27 deletions(-) delete mode 100644 comfy/checkpoint_pickle.py diff --git a/comfy/checkpoint_pickle.py b/comfy/checkpoint_pickle.py deleted file mode 100644 index 206551d3c..000000000 --- a/comfy/checkpoint_pickle.py +++ /dev/null @@ -1,13 +0,0 @@ -import pickle - -load = pickle.load - -class Empty: - pass - -class Unpickler(pickle.Unpickler): - def find_class(self, module, name): - #TODO: safe unpickle - if module.startswith("pytorch_lightning"): - return Empty - return super().find_class(module, name) diff --git a/comfy/utils.py b/comfy/utils.py index d553a7c1b..c1ce540b5 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -20,7 +20,7 @@ import torch import math import struct -import comfy.checkpoint_pickle +import comfy.memory_management import safetensors.torch import numpy as np from PIL import Image @@ -38,26 +38,26 @@ import warnings MMAP_TORCH_FILES = args.mmap_torch_files DISABLE_MMAP = args.disable_mmap -ALWAYS_SAFE_LOAD = False -if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in pytorch 2.4, the unsafe path should be removed once earlier versions are deprecated + +if True: # ckpt/pt file whitelist for safe loading of old sd files class ModelCheckpoint: pass ModelCheckpoint.__module__ = "pytorch_lightning.callbacks.model_checkpoint" def scalar(*args, **kwargs): - from numpy.core.multiarray import scalar as sc - return sc(*args, **kwargs) + return None scalar.__module__ = "numpy.core.multiarray" from numpy import dtype from numpy.dtypes import Float64DType - from _codecs import encode + + def encode(*args, **kwargs): # no longer necessary on newer torch + return None + encode.__module__ = "_codecs" torch.serialization.add_safe_globals([ModelCheckpoint, scalar, dtype, Float64DType, encode]) - ALWAYS_SAFE_LOAD = True logging.info("Checkpoint files will always be loaded safely.") -else: - logging.warning("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended as older versions of pytorch are no longer supported.") + # Current as of safetensors 0.7.0 _TYPES = { @@ -140,11 +140,8 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False): if MMAP_TORCH_FILES: torch_args["mmap"] = True - if safe_load or ALWAYS_SAFE_LOAD: - pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args) - else: - logging.warning("WARNING: loading {} unsafely, upgrade your pytorch to 2.4 or newer to load this file safely.".format(ckpt)) - pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle) + pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args) + if "state_dict" in pl_sd: sd = pl_sd["state_dict"] else: From ce4a1ab48d9f723eeaac37f88dde55086b1f233f Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sun, 15 Feb 2026 11:31:59 +0200 Subject: [PATCH 107/215] chore(api-nodes): remove "gpt-4o" model (#12467) --- comfy_api_nodes/nodes_openai.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/comfy_api_nodes/nodes_openai.py b/comfy_api_nodes/nodes_openai.py index f05aaab7b..332107a82 100644 --- a/comfy_api_nodes/nodes_openai.py +++ b/comfy_api_nodes/nodes_openai.py @@ -43,7 +43,6 @@ class SupportedOpenAIModel(str, Enum): o1 = "o1" o3 = "o3" o1_pro = "o1-pro" - gpt_4o = "gpt-4o" gpt_4_1 = "gpt-4.1" gpt_4_1_mini = "gpt-4.1-mini" gpt_4_1_nano = "gpt-4.1-nano" @@ -649,11 +648,6 @@ class OpenAIChatNode(IO.ComfyNode): "usd": [0.01, 0.04], "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } } - : $contains($m, "gpt-4o") ? { - "type": "list_usd", - "usd": [0.0025, 0.01], - "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } - } : $contains($m, "gpt-4.1-nano") ? { "type": "list_usd", "usd": [0.0001, 0.0004], From 596ed686919f11f75be3cf9a79977a07d64002c5 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sun, 15 Feb 2026 02:12:30 -0800 Subject: [PATCH 108/215] Node Replacement API (#12014) --- app/node_replace_manager.py | 105 ++++++++++++++++++++++++++ comfy_api/feature_flags.py | 1 + comfy_api/latest/__init__.py | 13 +++- comfy_api/latest/_io.py | 63 ++++++++++++++++ comfy_extras/nodes_post_processing.py | 1 + comfy_extras/nodes_replacements.py | 103 +++++++++++++++++++++++++ nodes.py | 2 + server.py | 5 ++ 8 files changed, 291 insertions(+), 2 deletions(-) create mode 100644 app/node_replace_manager.py create mode 100644 comfy_extras/nodes_replacements.py diff --git a/app/node_replace_manager.py b/app/node_replace_manager.py new file mode 100644 index 000000000..03b603c70 --- /dev/null +++ b/app/node_replace_manager.py @@ -0,0 +1,105 @@ +from __future__ import annotations + +from aiohttp import web + +from typing import TYPE_CHECKING, TypedDict +if TYPE_CHECKING: + from comfy_api.latest._io_public import NodeReplace + +from comfy_execution.graph_utils import is_link +import nodes + +class NodeStruct(TypedDict): + inputs: dict[str, str | int | float | bool | tuple[str, int]] + class_type: str + _meta: dict[str, str] + +def copy_node_struct(node_struct: NodeStruct, empty_inputs: bool = False) -> NodeStruct: + new_node_struct = node_struct.copy() + if empty_inputs: + new_node_struct["inputs"] = {} + else: + new_node_struct["inputs"] = node_struct["inputs"].copy() + new_node_struct["_meta"] = node_struct["_meta"].copy() + return new_node_struct + + +class NodeReplaceManager: + """Manages node replacement registrations.""" + + def __init__(self): + self._replacements: dict[str, list[NodeReplace]] = {} + + def register(self, node_replace: NodeReplace): + """Register a node replacement mapping.""" + self._replacements.setdefault(node_replace.old_node_id, []).append(node_replace) + + def get_replacement(self, old_node_id: str) -> list[NodeReplace] | None: + """Get replacements for an old node ID.""" + return self._replacements.get(old_node_id) + + def has_replacement(self, old_node_id: str) -> bool: + """Check if a replacement exists for an old node ID.""" + return old_node_id in self._replacements + + def apply_replacements(self, prompt: dict[str, NodeStruct]): + connections: dict[str, list[tuple[str, str, int]]] = {} + need_replacement: set[str] = set() + for node_number, node_struct in prompt.items(): + class_type = node_struct["class_type"] + # need replacement if not in NODE_CLASS_MAPPINGS and has replacement + if class_type not in nodes.NODE_CLASS_MAPPINGS.keys() and self.has_replacement(class_type): + need_replacement.add(node_number) + # keep track of connections + for input_id, input_value in node_struct["inputs"].items(): + if is_link(input_value): + conn_number = input_value[0] + connections.setdefault(conn_number, []).append((node_number, input_id, input_value[1])) + for node_number in need_replacement: + node_struct = prompt[node_number] + class_type = node_struct["class_type"] + replacements = self.get_replacement(class_type) + if replacements is None: + continue + # just use the first replacement + replacement = replacements[0] + new_node_id = replacement.new_node_id + # if replacement is not a valid node, skip trying to replace it as will only cause confusion + if new_node_id not in nodes.NODE_CLASS_MAPPINGS.keys(): + continue + # first, replace node id (class_type) + new_node_struct = copy_node_struct(node_struct, empty_inputs=True) + new_node_struct["class_type"] = new_node_id + # TODO: consider replacing display_name in _meta as well for error reporting purposes; would need to query node schema + # second, replace inputs + if replacement.input_mapping is not None: + for input_map in replacement.input_mapping: + if "set_value" in input_map: + new_node_struct["inputs"][input_map["new_id"]] = input_map["set_value"] + elif "old_id" in input_map: + new_node_struct["inputs"][input_map["new_id"]] = node_struct["inputs"][input_map["old_id"]] + # finalize input replacement + prompt[node_number] = new_node_struct + # third, replace outputs + if replacement.output_mapping is not None: + # re-mapping outputs requires changing the input values of nodes that receive connections from this one + if node_number in connections: + for conns in connections[node_number]: + conn_node_number, conn_input_id, old_output_idx = conns + for output_map in replacement.output_mapping: + if output_map["old_idx"] == old_output_idx: + new_output_idx = output_map["new_idx"] + previous_input = prompt[conn_node_number]["inputs"][conn_input_id] + previous_input[1] = new_output_idx + + def as_dict(self): + """Serialize all replacements to dict.""" + return { + k: [v.as_dict() for v in v_list] + for k, v_list in self._replacements.items() + } + + def add_routes(self, routes): + @routes.get("/node_replacements") + async def get_node_replacements(request): + return web.json_response(self.as_dict()) diff --git a/comfy_api/feature_flags.py b/comfy_api/feature_flags.py index de167f037..a90a5ca40 100644 --- a/comfy_api/feature_flags.py +++ b/comfy_api/feature_flags.py @@ -14,6 +14,7 @@ SERVER_FEATURE_FLAGS: dict[str, Any] = { "supports_preview_metadata": True, "max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes "extension": {"manager": {"supports_v4": True}}, + "node_replacements": True, } diff --git a/comfy_api/latest/__init__.py b/comfy_api/latest/__init__.py index 8542a1dbc..f2399422b 100644 --- a/comfy_api/latest/__init__.py +++ b/comfy_api/latest/__init__.py @@ -21,6 +21,17 @@ class ComfyAPI_latest(ComfyAPIBase): VERSION = "latest" STABLE = False + def __init__(self): + super().__init__() + self.node_replacement = self.NodeReplacement() + self.execution = self.Execution() + + class NodeReplacement(ProxiedSingleton): + async def register(self, node_replace: io.NodeReplace) -> None: + """Register a node replacement mapping.""" + from server import PromptServer + PromptServer.instance.node_replace_manager.register(node_replace) + class Execution(ProxiedSingleton): async def set_progress( self, @@ -73,8 +84,6 @@ class ComfyAPI_latest(ComfyAPIBase): image=to_display, ) - execution: Execution - class ComfyExtension(ABC): async def on_load(self) -> None: """ diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 93cf482ca..95d79c035 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -2030,6 +2030,68 @@ class _UIOutput(ABC): ... +class InputMapOldId(TypedDict): + """Map an old node input to a new node input by ID.""" + new_id: str + old_id: str + +class InputMapSetValue(TypedDict): + """Set a specific value for a new node input.""" + new_id: str + set_value: Any + +InputMap = InputMapOldId | InputMapSetValue +""" +Input mapping for node replacement. Type is inferred by dictionary keys: +- {"new_id": str, "old_id": str} - maps old input to new input +- {"new_id": str, "set_value": Any} - sets a specific value for new input +""" + +class OutputMap(TypedDict): + """Map outputs of node replacement via indexes.""" + new_idx: int + old_idx: int + +class NodeReplace: + """ + Defines a possible node replacement, mapping inputs and outputs of the old node to the new node. + + Also supports assigning specific values to the input widgets of the new node. + + Args: + new_node_id: The class name of the new replacement node. + old_node_id: The class name of the deprecated node. + old_widget_ids: Ordered list of input IDs for widgets that may not have an input slot + connected. The workflow JSON stores widget values by their relative position index, + not by ID. This list maps those positional indexes to input IDs, enabling the + replacement system to correctly identify widget values during node migration. + input_mapping: List of input mappings from old node to new node. + output_mapping: List of output mappings from old node to new node. + """ + def __init__(self, + new_node_id: str, + old_node_id: str, + old_widget_ids: list[str] | None=None, + input_mapping: list[InputMap] | None=None, + output_mapping: list[OutputMap] | None=None, + ): + self.new_node_id = new_node_id + self.old_node_id = old_node_id + self.old_widget_ids = old_widget_ids + self.input_mapping = input_mapping + self.output_mapping = output_mapping + + def as_dict(self): + """Create serializable representation of the node replacement.""" + return { + "new_node_id": self.new_node_id, + "old_node_id": self.old_node_id, + "old_widget_ids": self.old_widget_ids, + "input_mapping": list(self.input_mapping) if self.input_mapping else None, + "output_mapping": list(self.output_mapping) if self.output_mapping else None, + } + + __all__ = [ "FolderType", "UploadType", @@ -2121,4 +2183,5 @@ __all__ = [ "ImageCompare", "PriceBadgeDepends", "PriceBadge", + "NodeReplace", ] diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index a52a90e2c..66dac10b1 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -655,6 +655,7 @@ class BatchImagesMasksLatentsNode(io.ComfyNode): batched = batch_masks(values) return io.NodeOutput(batched) + class PostProcessingExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: diff --git a/comfy_extras/nodes_replacements.py b/comfy_extras/nodes_replacements.py new file mode 100644 index 000000000..7684e854c --- /dev/null +++ b/comfy_extras/nodes_replacements.py @@ -0,0 +1,103 @@ +from comfy_api.latest import ComfyExtension, io, ComfyAPI + +api = ComfyAPI() + + +async def register_replacements(): + """Register all built-in node replacements.""" + await register_replacements_longeredge() + await register_replacements_batchimages() + await register_replacements_upscaleimage() + await register_replacements_controlnet() + await register_replacements_load3d() + await register_replacements_preview3d() + await register_replacements_svdimg2vid() + await register_replacements_conditioningavg() + +async def register_replacements_longeredge(): + # No dynamic inputs here + await api.node_replacement.register(io.NodeReplace( + new_node_id="ImageScaleToMaxDimension", + old_node_id="ResizeImagesByLongerEdge", + old_widget_ids=["longer_edge"], + input_mapping=[ + {"new_id": "image", "old_id": "images"}, + {"new_id": "largest_size", "old_id": "longer_edge"}, + {"new_id": "upscale_method", "set_value": "lanczos"}, + ], + # just to test the frontend output_mapping code, does nothing really here + output_mapping=[{"new_idx": 0, "old_idx": 0}], + )) + +async def register_replacements_batchimages(): + # BatchImages node uses Autogrow + await api.node_replacement.register(io.NodeReplace( + new_node_id="BatchImagesNode", + old_node_id="ImageBatch", + input_mapping=[ + {"new_id": "images.image0", "old_id": "image1"}, + {"new_id": "images.image1", "old_id": "image2"}, + ], + )) + +async def register_replacements_upscaleimage(): + # ResizeImageMaskNode uses DynamicCombo + await api.node_replacement.register(io.NodeReplace( + new_node_id="ResizeImageMaskNode", + old_node_id="ImageScaleBy", + old_widget_ids=["upscale_method", "scale_by"], + input_mapping=[ + {"new_id": "input", "old_id": "image"}, + {"new_id": "resize_type", "set_value": "scale by multiplier"}, + {"new_id": "resize_type.multiplier", "old_id": "scale_by"}, + {"new_id": "scale_method", "old_id": "upscale_method"}, + ], + )) + +async def register_replacements_controlnet(): + # T2IAdapterLoader → ControlNetLoader + await api.node_replacement.register(io.NodeReplace( + new_node_id="ControlNetLoader", + old_node_id="T2IAdapterLoader", + input_mapping=[ + {"new_id": "control_net_name", "old_id": "t2i_adapter_name"}, + ], + )) + +async def register_replacements_load3d(): + # Load3DAnimation merged into Load3D + await api.node_replacement.register(io.NodeReplace( + new_node_id="Load3D", + old_node_id="Load3DAnimation", + )) + +async def register_replacements_preview3d(): + # Preview3DAnimation merged into Preview3D + await api.node_replacement.register(io.NodeReplace( + new_node_id="Preview3D", + old_node_id="Preview3DAnimation", + )) + +async def register_replacements_svdimg2vid(): + # Typo fix: SDV → SVD + await api.node_replacement.register(io.NodeReplace( + new_node_id="SVD_img2vid_Conditioning", + old_node_id="SDV_img2vid_Conditioning", + )) + +async def register_replacements_conditioningavg(): + # Typo fix: trailing space in node name + await api.node_replacement.register(io.NodeReplace( + new_node_id="ConditioningAverage", + old_node_id="ConditioningAverage ", + )) + +class NodeReplacementsExtension(ComfyExtension): + async def on_load(self) -> None: + await register_replacements() + + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [] + +async def comfy_entrypoint() -> NodeReplacementsExtension: + return NodeReplacementsExtension() diff --git a/nodes.py b/nodes.py index 91de7a9d7..db5f98408 100644 --- a/nodes.py +++ b/nodes.py @@ -2264,6 +2264,7 @@ async def load_custom_node(module_path: str, ignore=set(), module_parent="custom if not isinstance(extension, ComfyExtension): logging.warning(f"comfy_entrypoint in {module_path} did not return a ComfyExtension, skipping.") return False + await extension.on_load() node_list = await extension.get_node_list() if not isinstance(node_list, list): logging.warning(f"comfy_entrypoint in {module_path} did not return a list of nodes, skipping.") @@ -2435,6 +2436,7 @@ async def init_builtin_extra_nodes(): "nodes_lora_debug.py", "nodes_color.py", "nodes_toolkit.py", + "nodes_replacements.py", ] import_failed = [] diff --git a/server.py b/server.py index 2300393b2..8882e43c4 100644 --- a/server.py +++ b/server.py @@ -40,6 +40,7 @@ from app.user_manager import UserManager from app.model_manager import ModelFileManager from app.custom_node_manager import CustomNodeManager from app.subgraph_manager import SubgraphManager +from app.node_replace_manager import NodeReplaceManager from typing import Optional, Union from api_server.routes.internal.internal_routes import InternalRoutes from protocol import BinaryEventTypes @@ -204,6 +205,7 @@ class PromptServer(): self.model_file_manager = ModelFileManager() self.custom_node_manager = CustomNodeManager() self.subgraph_manager = SubgraphManager() + self.node_replace_manager = NodeReplaceManager() self.internal_routes = InternalRoutes(self) self.supports = ["custom_nodes_from_web"] self.prompt_queue = execution.PromptQueue(self) @@ -887,6 +889,8 @@ class PromptServer(): if "partial_execution_targets" in json_data: partial_execution_targets = json_data["partial_execution_targets"] + self.node_replace_manager.apply_replacements(prompt) + valid = await execution.validate_prompt(prompt_id, prompt, partial_execution_targets) extra_data = {} if "extra_data" in json_data: @@ -995,6 +999,7 @@ class PromptServer(): self.model_file_manager.add_routes(self.routes) self.custom_node_manager.add_routes(self.routes, self.app, nodes.LOADED_MODULE_DIRS.items()) self.subgraph_manager.add_routes(self.routes, nodes.LOADED_MODULE_DIRS.items()) + self.node_replace_manager.add_routes(self.routes) self.app.add_subapp('/internal', self.internal_routes.get_app()) # Prefix every route with /api for easier matching for delegation. From e2c71ceb0004da0d8a33dc9e79b31c2324241173 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sun, 15 Feb 2026 15:33:18 +0200 Subject: [PATCH 109/215] feat(api-nodes-Tencent): add ModelTo3DUV, 3DTextureEdit, 3DParts nodes (#12428) * feat(api-nodes-Tencent): add ModelTo3DUV, 3DTextureEdit, 3DParts nodes * add image output to TencentModelTo3DUV node * commented out two nodes * added rate_limit check to other hunyuan3d nodes --- comfy_api_nodes/apis/hunyuan3d.py | 20 ++ comfy_api_nodes/nodes_hunyuan3d.py | 295 ++++++++++++++++++++++++- comfy_api_nodes/util/__init__.py | 2 + comfy_api_nodes/util/upload_helpers.py | 21 ++ 4 files changed, 326 insertions(+), 12 deletions(-) diff --git a/comfy_api_nodes/apis/hunyuan3d.py b/comfy_api_nodes/apis/hunyuan3d.py index 6421c9bd5..e84eba31e 100644 --- a/comfy_api_nodes/apis/hunyuan3d.py +++ b/comfy_api_nodes/apis/hunyuan3d.py @@ -64,3 +64,23 @@ class To3DProTaskResultResponse(BaseModel): class To3DProTaskQueryRequest(BaseModel): JobId: str = Field(...) + + +class To3DUVFileInput(BaseModel): + Type: str = Field(..., description="File type: GLB, OBJ, or FBX") + Url: str = Field(...) + + +class To3DUVTaskRequest(BaseModel): + File: To3DUVFileInput = Field(...) + + +class TextureEditImageInfo(BaseModel): + Url: str = Field(...) + + +class TextureEditTaskRequest(BaseModel): + File3D: To3DUVFileInput = Field(...) + Image: TextureEditImageInfo | None = Field(None) + Prompt: str | None = Field(None) + EnablePBR: bool | None = Field(None) diff --git a/comfy_api_nodes/nodes_hunyuan3d.py b/comfy_api_nodes/nodes_hunyuan3d.py index 813a7c809..ca002cc60 100644 --- a/comfy_api_nodes/nodes_hunyuan3d.py +++ b/comfy_api_nodes/nodes_hunyuan3d.py @@ -1,31 +1,48 @@ from typing_extensions import override -from comfy_api.latest import IO, ComfyExtension, Input +from comfy_api.latest import IO, ComfyExtension, Input, Types from comfy_api_nodes.apis.hunyuan3d import ( Hunyuan3DViewImage, InputGenerateType, ResultFile3D, + TextureEditTaskRequest, To3DProTaskCreateResponse, To3DProTaskQueryRequest, To3DProTaskRequest, To3DProTaskResultResponse, + To3DUVFileInput, + To3DUVTaskRequest, ) from comfy_api_nodes.util import ( ApiEndpoint, download_url_to_file_3d, + download_url_to_image_tensor, downscale_image_tensor_by_max_side, poll_op, sync_op, + upload_3d_model_to_comfyapi, upload_image_to_comfyapi, validate_image_dimensions, validate_string, ) -def get_file_from_response(response_objs: list[ResultFile3D], file_type: str) -> ResultFile3D | None: +def _is_tencent_rate_limited(status: int, body: object) -> bool: + return ( + status == 400 + and isinstance(body, dict) + and "RequestLimitExceeded" in str(body.get("Response", {}).get("Error", {}).get("Code", "")) + ) + + +def get_file_from_response( + response_objs: list[ResultFile3D], file_type: str, raise_if_not_found: bool = True +) -> ResultFile3D | None: for i in response_objs: if i.Type.lower() == file_type.lower(): return i + if raise_if_not_found: + raise ValueError(f"'{file_type}' file type is not found in the response.") return None @@ -35,7 +52,7 @@ class TencentTextToModelNode(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="TencentTextToModelNode", - display_name="Hunyuan3D: Text to Model (Pro)", + display_name="Hunyuan3D: Text to Model", category="api node/3d/Tencent", inputs=[ IO.Combo.Input( @@ -120,6 +137,7 @@ class TencentTextToModelNode(IO.ComfyNode): EnablePBR=generate_type.get("pbr", None), PolygonType=generate_type.get("polygon_type", None), ), + is_rate_limited=_is_tencent_rate_limited, ) if response.Error: raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}") @@ -131,11 +149,14 @@ class TencentTextToModelNode(IO.ComfyNode): response_model=To3DProTaskResultResponse, status_extractor=lambda r: r.Status, ) - glb_result = get_file_from_response(result.ResultFile3Ds, "glb") - obj_result = get_file_from_response(result.ResultFile3Ds, "obj") - file_glb = await download_url_to_file_3d(glb_result.Url, "glb", task_id=task_id) if glb_result else None return IO.NodeOutput( - file_glb, file_glb, await download_url_to_file_3d(obj_result.Url, "obj", task_id=task_id) if obj_result else None + f"{task_id}.glb", + await download_url_to_file_3d( + get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id + ), + await download_url_to_file_3d( + get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj", task_id=task_id + ), ) @@ -145,7 +166,7 @@ class TencentImageToModelNode(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="TencentImageToModelNode", - display_name="Hunyuan3D: Image(s) to Model (Pro)", + display_name="Hunyuan3D: Image(s) to Model", category="api node/3d/Tencent", inputs=[ IO.Combo.Input( @@ -268,6 +289,7 @@ class TencentImageToModelNode(IO.ComfyNode): EnablePBR=generate_type.get("pbr", None), PolygonType=generate_type.get("polygon_type", None), ), + is_rate_limited=_is_tencent_rate_limited, ) if response.Error: raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}") @@ -279,11 +301,257 @@ class TencentImageToModelNode(IO.ComfyNode): response_model=To3DProTaskResultResponse, status_extractor=lambda r: r.Status, ) - glb_result = get_file_from_response(result.ResultFile3Ds, "glb") - obj_result = get_file_from_response(result.ResultFile3Ds, "obj") - file_glb = await download_url_to_file_3d(glb_result.Url, "glb", task_id=task_id) if glb_result else None return IO.NodeOutput( - file_glb, file_glb, await download_url_to_file_3d(obj_result.Url, "obj", task_id=task_id) if obj_result else None + f"{task_id}.glb", + await download_url_to_file_3d( + get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id + ), + await download_url_to_file_3d( + get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj", task_id=task_id + ), + ) + + +class TencentModelTo3DUVNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="TencentModelTo3DUVNode", + display_name="Hunyuan3D: Model to UV", + category="api node/3d/Tencent", + description="Perform UV unfolding on a 3D model to generate UV texture. " + "Input model must have less than 30000 faces.", + inputs=[ + IO.MultiType.Input( + "model_3d", + types=[IO.File3DGLB, IO.File3DOBJ, IO.File3DFBX, IO.File3DAny], + tooltip="Input 3D model (GLB, OBJ, or FBX)", + ), + IO.Int.Input( + "seed", + default=1, + min=0, + max=2147483647, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + ], + outputs=[ + IO.File3DOBJ.Output(display_name="OBJ"), + IO.File3DFBX.Output(display_name="FBX"), + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge(expr='{"type":"usd","usd":0.2}'), + ) + + SUPPORTED_FORMATS = {"glb", "obj", "fbx"} + + @classmethod + async def execute( + cls, + model_3d: Types.File3D, + seed: int, + ) -> IO.NodeOutput: + _ = seed + file_format = model_3d.format.lower() + if file_format not in cls.SUPPORTED_FORMATS: + raise ValueError( + f"Unsupported file format: '{file_format}'. " + f"Supported formats: {', '.join(sorted(cls.SUPPORTED_FORMATS))}." + ) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/tencent/hunyuan/3d-uv", method="POST"), + response_model=To3DProTaskCreateResponse, + data=To3DUVTaskRequest( + File=To3DUVFileInput( + Type=file_format.upper(), + Url=await upload_3d_model_to_comfyapi(cls, model_3d, file_format), + ) + ), + is_rate_limited=_is_tencent_rate_limited, + ) + if response.Error: + raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}") + result = await poll_op( + cls, + ApiEndpoint(path="/proxy/tencent/hunyuan/3d-uv/query", method="POST"), + data=To3DProTaskQueryRequest(JobId=response.JobId), + response_model=To3DProTaskResultResponse, + status_extractor=lambda r: r.Status, + ) + return IO.NodeOutput( + await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj"), + await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "fbx").Url, "fbx"), + await download_url_to_image_tensor(get_file_from_response(result.ResultFile3Ds, "image").Url), + ) + + +class Tencent3DTextureEditNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="Tencent3DTextureEditNode", + display_name="Hunyuan3D: 3D Texture Edit", + category="api node/3d/Tencent", + description="After inputting the 3D model, perform 3D model texture redrawing.", + inputs=[ + IO.MultiType.Input( + "model_3d", + types=[IO.File3DFBX, IO.File3DAny], + tooltip="3D model in FBX format. Model should have less than 100000 faces.", + ), + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Describes texture editing. Supports up to 1024 UTF-8 characters.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + ], + outputs=[ + IO.File3DGLB.Output(display_name="GLB"), + IO.File3DFBX.Output(display_name="FBX"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd": 0.6}""", + ), + ) + + @classmethod + async def execute( + cls, + model_3d: Types.File3D, + prompt: str, + seed: int, + ) -> IO.NodeOutput: + _ = seed + file_format = model_3d.format.lower() + if file_format != "fbx": + raise ValueError(f"Unsupported file format: '{file_format}'. Only FBX format is supported.") + validate_string(prompt, field_name="prompt", min_length=1, max_length=1024) + model_url = await upload_3d_model_to_comfyapi(cls, model_3d, file_format) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/tencent/hunyuan/3d-texture-edit", method="POST"), + response_model=To3DProTaskCreateResponse, + data=TextureEditTaskRequest( + File3D=To3DUVFileInput(Type=file_format.upper(), Url=model_url), + Prompt=prompt, + EnablePBR=True, + ), + is_rate_limited=_is_tencent_rate_limited, + ) + if response.Error: + raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}") + + result = await poll_op( + cls, + ApiEndpoint(path="/proxy/tencent/hunyuan/3d-texture-edit/query", method="POST"), + data=To3DProTaskQueryRequest(JobId=response.JobId), + response_model=To3DProTaskResultResponse, + status_extractor=lambda r: r.Status, + ) + return IO.NodeOutput( + await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb"), + await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "fbx").Url, "fbx"), + ) + + +class Tencent3DPartNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="Tencent3DPartNode", + display_name="Hunyuan3D: 3D Part", + category="api node/3d/Tencent", + description="Automatically perform component identification and generation based on the model structure.", + inputs=[ + IO.MultiType.Input( + "model_3d", + types=[IO.File3DFBX, IO.File3DAny], + tooltip="3D model in FBX format. Model should have less than 30000 faces.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + ], + outputs=[ + IO.File3DFBX.Output(display_name="FBX"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge(expr='{"type":"usd","usd":0.6}'), + ) + + @classmethod + async def execute( + cls, + model_3d: Types.File3D, + seed: int, + ) -> IO.NodeOutput: + _ = seed + file_format = model_3d.format.lower() + if file_format != "fbx": + raise ValueError(f"Unsupported file format: '{file_format}'. Only FBX format is supported.") + model_url = await upload_3d_model_to_comfyapi(cls, model_3d, file_format) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/tencent/hunyuan/3d-part", method="POST"), + response_model=To3DProTaskCreateResponse, + data=To3DUVTaskRequest( + File=To3DUVFileInput(Type=file_format.upper(), Url=model_url), + ), + is_rate_limited=_is_tencent_rate_limited, + ) + if response.Error: + raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}") + result = await poll_op( + cls, + ApiEndpoint(path="/proxy/tencent/hunyuan/3d-part/query", method="POST"), + data=To3DProTaskQueryRequest(JobId=response.JobId), + response_model=To3DProTaskResultResponse, + status_extractor=lambda r: r.Status, + ) + return IO.NodeOutput( + await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "fbx").Url, "fbx"), ) @@ -293,6 +561,9 @@ class TencentHunyuan3DExtension(ComfyExtension): return [ TencentTextToModelNode, TencentImageToModelNode, + # TencentModelTo3DUVNode, + # Tencent3DTextureEditNode, + Tencent3DPartNode, ] diff --git a/comfy_api_nodes/util/__init__.py b/comfy_api_nodes/util/__init__.py index 18b020eef..f8a0ba8af 100644 --- a/comfy_api_nodes/util/__init__.py +++ b/comfy_api_nodes/util/__init__.py @@ -33,6 +33,7 @@ from .download_helpers import ( download_url_to_video_output, ) from .upload_helpers import ( + upload_3d_model_to_comfyapi, upload_audio_to_comfyapi, upload_file_to_comfyapi, upload_image_to_comfyapi, @@ -62,6 +63,7 @@ __all__ = [ "sync_op", "sync_op_raw", # Upload helpers + "upload_3d_model_to_comfyapi", "upload_audio_to_comfyapi", "upload_file_to_comfyapi", "upload_image_to_comfyapi", diff --git a/comfy_api_nodes/util/upload_helpers.py b/comfy_api_nodes/util/upload_helpers.py index 7cc565263..6d1d107a1 100644 --- a/comfy_api_nodes/util/upload_helpers.py +++ b/comfy_api_nodes/util/upload_helpers.py @@ -164,6 +164,27 @@ async def upload_video_to_comfyapi( return await upload_file_to_comfyapi(cls, video_bytes_io, filename, upload_mime_type, wait_label) +_3D_MIME_TYPES = { + "glb": "model/gltf-binary", + "obj": "model/obj", + "fbx": "application/octet-stream", +} + + +async def upload_3d_model_to_comfyapi( + cls: type[IO.ComfyNode], + model_3d: Types.File3D, + file_format: str, +) -> str: + """Uploads a 3D model file to ComfyUI API and returns its download URL.""" + return await upload_file_to_comfyapi( + cls, + model_3d.get_data(), + f"{uuid.uuid4()}.{file_format}", + _3D_MIME_TYPES.get(file_format, "application/octet-stream"), + ) + + async def upload_file_to_comfyapi( cls: type[IO.ComfyNode], file_bytes_io: BytesIO, From 2c1d06a4e32900c260bcb3d0888f20edc1e3e5ab Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Mon, 16 Feb 2026 03:22:30 +0200 Subject: [PATCH 110/215] feat(api-nodes): add Bria RMBG nodes (#12465) Co-authored-by: Jedrzej Kosinski --- comfy_api_nodes/apis/bria.py | 44 +++++- comfy_api_nodes/nodes_bria.py | 220 ++++++++++++++++++++++------ comfy_api_nodes/util/conversions.py | 2 +- 3 files changed, 218 insertions(+), 48 deletions(-) diff --git a/comfy_api_nodes/apis/bria.py b/comfy_api_nodes/apis/bria.py index 9119cacc6..8c496b56c 100644 --- a/comfy_api_nodes/apis/bria.py +++ b/comfy_api_nodes/apis/bria.py @@ -45,17 +45,55 @@ class BriaEditImageRequest(BaseModel): ) +class BriaRemoveBackgroundRequest(BaseModel): + image: str = Field(...) + sync: bool = Field(False) + visual_input_content_moderation: bool = Field( + False, description="If true, returns 422 on input image moderation failure." + ) + visual_output_content_moderation: bool = Field( + False, description="If true, returns 422 on visual output moderation failure." + ) + seed: int = Field(...) + + class BriaStatusResponse(BaseModel): request_id: str = Field(...) status_url: str = Field(...) warning: str | None = Field(None) -class BriaResult(BaseModel): +class BriaRemoveBackgroundResult(BaseModel): + image_url: str = Field(...) + + +class BriaRemoveBackgroundResponse(BaseModel): + status: str = Field(...) + result: BriaRemoveBackgroundResult | None = Field(None) + + +class BriaImageEditResult(BaseModel): structured_prompt: str = Field(...) image_url: str = Field(...) -class BriaResponse(BaseModel): +class BriaImageEditResponse(BaseModel): status: str = Field(...) - result: BriaResult | None = Field(None) + result: BriaImageEditResult | None = Field(None) + + +class BriaRemoveVideoBackgroundRequest(BaseModel): + video: str = Field(...) + background_color: str = Field(default="transparent", description="Background color for the output video.") + output_container_and_codec: str = Field(...) + preserve_audio: bool = Field(True) + seed: int = Field(...) + + +class BriaRemoveVideoBackgroundResult(BaseModel): + video_url: str = Field(...) + + +class BriaRemoveVideoBackgroundResponse(BaseModel): + status: str = Field(...) + result: BriaRemoveVideoBackgroundResult | None = Field(None) diff --git a/comfy_api_nodes/nodes_bria.py b/comfy_api_nodes/nodes_bria.py index d3a52bc1b..4044ee3ea 100644 --- a/comfy_api_nodes/nodes_bria.py +++ b/comfy_api_nodes/nodes_bria.py @@ -3,7 +3,11 @@ from typing_extensions import override from comfy_api.latest import IO, ComfyExtension, Input from comfy_api_nodes.apis.bria import ( BriaEditImageRequest, - BriaResponse, + BriaRemoveBackgroundRequest, + BriaRemoveBackgroundResponse, + BriaRemoveVideoBackgroundRequest, + BriaRemoveVideoBackgroundResponse, + BriaImageEditResponse, BriaStatusResponse, InputModerationSettings, ) @@ -11,10 +15,12 @@ from comfy_api_nodes.util import ( ApiEndpoint, convert_mask_to_image, download_url_to_image_tensor, - get_number_of_images, + download_url_to_video_output, poll_op, sync_op, - upload_images_to_comfyapi, + upload_image_to_comfyapi, + upload_video_to_comfyapi, + validate_video_duration, ) @@ -73,21 +79,15 @@ class BriaImageEditNode(IO.ComfyNode): IO.DynamicCombo.Input( "moderation", options=[ + IO.DynamicCombo.Option("false", []), IO.DynamicCombo.Option( "true", [ - IO.Boolean.Input( - "prompt_content_moderation", default=False - ), - IO.Boolean.Input( - "visual_input_moderation", default=False - ), - IO.Boolean.Input( - "visual_output_moderation", default=True - ), + IO.Boolean.Input("prompt_content_moderation", default=False), + IO.Boolean.Input("visual_input_moderation", default=False), + IO.Boolean.Input("visual_output_moderation", default=True), ], ), - IO.DynamicCombo.Option("false", []), ], tooltip="Moderation settings", ), @@ -127,50 +127,26 @@ class BriaImageEditNode(IO.ComfyNode): mask: Input.Image | None = None, ) -> IO.NodeOutput: if not prompt and not structured_prompt: - raise ValueError( - "One of prompt or structured_prompt is required to be non-empty." - ) - if get_number_of_images(image) != 1: - raise ValueError("Exactly one input image is required.") + raise ValueError("One of prompt or structured_prompt is required to be non-empty.") mask_url = None if mask is not None: - mask_url = ( - await upload_images_to_comfyapi( - cls, - convert_mask_to_image(mask), - max_images=1, - mime_type="image/png", - wait_label="Uploading mask", - ) - )[0] + mask_url = await upload_image_to_comfyapi(cls, convert_mask_to_image(mask), wait_label="Uploading mask") response = await sync_op( cls, ApiEndpoint(path="proxy/bria/v2/image/edit", method="POST"), data=BriaEditImageRequest( instruction=prompt if prompt else None, structured_instruction=structured_prompt if structured_prompt else None, - images=await upload_images_to_comfyapi( - cls, - image, - max_images=1, - mime_type="image/png", - wait_label="Uploading image", - ), + images=[await upload_image_to_comfyapi(cls, image, wait_label="Uploading image")], mask=mask_url, negative_prompt=negative_prompt if negative_prompt else None, guidance_scale=guidance_scale, seed=seed, model_version=model, steps_num=steps, - prompt_content_moderation=moderation.get( - "prompt_content_moderation", False - ), - visual_input_content_moderation=moderation.get( - "visual_input_moderation", False - ), - visual_output_content_moderation=moderation.get( - "visual_output_moderation", False - ), + prompt_content_moderation=moderation.get("prompt_content_moderation", False), + visual_input_content_moderation=moderation.get("visual_input_moderation", False), + visual_output_content_moderation=moderation.get("visual_output_moderation", False), ), response_model=BriaStatusResponse, ) @@ -178,7 +154,7 @@ class BriaImageEditNode(IO.ComfyNode): cls, ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"), status_extractor=lambda r: r.status, - response_model=BriaResponse, + response_model=BriaImageEditResponse, ) return IO.NodeOutput( await download_url_to_image_tensor(response.result.image_url), @@ -186,11 +162,167 @@ class BriaImageEditNode(IO.ComfyNode): ) +class BriaRemoveImageBackground(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="BriaRemoveImageBackground", + display_name="Bria Remove Image Background", + category="api node/image/Bria", + description="Remove the background from an image using Bria RMBG 2.0.", + inputs=[ + IO.Image.Input("image"), + IO.DynamicCombo.Input( + "moderation", + options=[ + IO.DynamicCombo.Option("false", []), + IO.DynamicCombo.Option( + "true", + [ + IO.Boolean.Input("visual_input_moderation", default=False), + IO.Boolean.Input("visual_output_moderation", default=True), + ], + ), + ], + tooltip="Moderation settings", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + ], + outputs=[IO.Image.Output()], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.018}""", + ), + ) + + @classmethod + async def execute( + cls, + image: Input.Image, + moderation: dict, + seed: int, + ) -> IO.NodeOutput: + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/bria/v2/image/edit/remove_background", method="POST"), + data=BriaRemoveBackgroundRequest( + image=await upload_image_to_comfyapi(cls, image, wait_label="Uploading image"), + sync=False, + visual_input_content_moderation=moderation.get("visual_input_moderation", False), + visual_output_content_moderation=moderation.get("visual_output_moderation", False), + seed=seed, + ), + response_model=BriaStatusResponse, + ) + response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"), + status_extractor=lambda r: r.status, + response_model=BriaRemoveBackgroundResponse, + ) + return IO.NodeOutput(await download_url_to_image_tensor(response.result.image_url)) + + +class BriaRemoveVideoBackground(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="BriaRemoveVideoBackground", + display_name="Bria Remove Video Background", + category="api node/video/Bria", + description="Remove the background from a video using Bria. ", + inputs=[ + IO.Video.Input("video"), + IO.Combo.Input( + "background_color", + options=[ + "Black", + "White", + "Gray", + "Red", + "Green", + "Blue", + "Yellow", + "Cyan", + "Magenta", + "Orange", + ], + tooltip="Background color for the output video.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + ], + outputs=[IO.Video.Output()], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.14,"format":{"suffix":"/second"}}""", + ), + ) + + @classmethod + async def execute( + cls, + video: Input.Video, + background_color: str, + seed: int, + ) -> IO.NodeOutput: + validate_video_duration(video, max_duration=60.0) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/bria/v2/video/edit/remove_background", method="POST"), + data=BriaRemoveVideoBackgroundRequest( + video=await upload_video_to_comfyapi(cls, video), + background_color=background_color, + output_container_and_codec="mp4_h264", + seed=seed, + ), + response_model=BriaStatusResponse, + ) + response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"), + status_extractor=lambda r: r.status, + response_model=BriaRemoveVideoBackgroundResponse, + ) + return IO.NodeOutput(await download_url_to_video_output(response.result.video_url)) + + class BriaExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: return [ BriaImageEditNode, + BriaRemoveImageBackground, + BriaRemoveVideoBackground, ] diff --git a/comfy_api_nodes/util/conversions.py b/comfy_api_nodes/util/conversions.py index 3e37e8a8c..82b6d22a5 100644 --- a/comfy_api_nodes/util/conversions.py +++ b/comfy_api_nodes/util/conversions.py @@ -57,7 +57,7 @@ def tensor_to_bytesio( image: torch.Tensor, *, total_pixels: int | None = 2048 * 2048, - mime_type: str = "image/png", + mime_type: str | None = "image/png", ) -> BytesIO: """Converts a torch.Tensor image to a named BytesIO object. From ecd2a19661ecccd96e26f111af21781f3e613f59 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Sun, 15 Feb 2026 17:28:51 -0800 Subject: [PATCH 111/215] Fix lora Extraction in offload conditions (+ dynamic_vram mode) (#12479) * lora_extract: Add a trange If you bite off more than your GPU can chew, this kinda just hangs. Give a rough indication of progress counting the weights in a trange. * lora_extract: Support on-the-fly patching Use the on-the-fly approach from the regular model saving logic for lora extraction too. Switch off force_cast_weights accordingly. This gets extraction working in dynamic vram while also supporting extraction on GPU offloaded. --- comfy_extras/nodes_lora_extract.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/comfy_extras/nodes_lora_extract.py b/comfy_extras/nodes_lora_extract.py index fb89e03f4..1542d0a88 100644 --- a/comfy_extras/nodes_lora_extract.py +++ b/comfy_extras/nodes_lora_extract.py @@ -7,6 +7,7 @@ import logging from enum import Enum from typing_extensions import override from comfy_api.latest import ComfyExtension, io +from tqdm.auto import trange CLAMP_QUANTILE = 0.99 @@ -49,12 +50,22 @@ LORA_TYPES = {"standard": LORAType.STANDARD, "full_diff": LORAType.FULL_DIFF} def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora_type, bias_diff=False): - comfy.model_management.load_models_gpu([model_diff], force_patch_weights=True) + comfy.model_management.load_models_gpu([model_diff]) sd = model_diff.model_state_dict(filter_prefix=prefix_model) - for k in sd: - if k.endswith(".weight"): + sd_keys = list(sd.keys()) + for index in trange(len(sd_keys), unit="weight"): + k = sd_keys[index] + op_keys = sd_keys[index].rsplit('.', 1) + if len(op_keys) < 2 or op_keys[1] not in ["weight", "bias"] or (op_keys[1] == "bias" and not bias_diff): + continue + op = comfy.utils.get_attr(model_diff.model, op_keys[0]) + if hasattr(op, "comfy_cast_weights") and not getattr(op, "comfy_patched_weights", False): + weight_diff = model_diff.patch_weight_to_device(k, model_diff.load_device, return_weight=True) + else: weight_diff = sd[k] + + if op_keys[1] == "weight": if lora_type == LORAType.STANDARD: if weight_diff.ndim < 2: if bias_diff: @@ -69,8 +80,8 @@ def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora elif lora_type == LORAType.FULL_DIFF: output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().half().cpu() - elif bias_diff and k.endswith(".bias"): - output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = sd[k].contiguous().half().cpu() + elif bias_diff and op_keys[1] == "bias": + output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = weight_diff.contiguous().half().cpu() return output_sd class LoraSave(io.ComfyNode): From c0370044cd467b92f4db63b88029ebc700388d36 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Sun, 15 Feb 2026 17:30:09 -0800 Subject: [PATCH 112/215] MPDynamic: force load flux img_in weight (Fixes flux1 canny+depth lora crash) (#12446) * lora: add weight shape calculations. This lets the loader know if a lora will change the shape of a weight so it can take appropriate action. * MPDynamic: force load flux img_in weight This weight is a bit special, in that the lora changes its geometry. This is rather unique, not handled by existing estimate and doesn't work for either offloading or dynamic_vram. Fix for dynamic_vram as a special case. Ideally we can fully precalculate these lora geometry changes at load time, but just get these models working first. --- comfy/lora.py | 25 +++++++++++++++++++++++++ comfy/model_patcher.py | 35 +++++++++++++++++++++++++++-------- comfy/weight_adapter/base.py | 6 ++++++ comfy/weight_adapter/lora.py | 7 +++++++ 4 files changed, 65 insertions(+), 8 deletions(-) diff --git a/comfy/lora.py b/comfy/lora.py index 44030bcab..279cf38bb 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -374,6 +374,31 @@ def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Ten return padded_tensor +def calculate_shape(patches, weight, key, original_weights=None): + current_shape = weight.shape + + for p in patches: + v = p[1] + offset = p[3] + + # Offsets restore the old shape; lists force a diff without metadata + if offset is not None or isinstance(v, list): + continue + + if isinstance(v, weight_adapter.WeightAdapterBase): + adapter_shape = v.calculate_shape(key) + if adapter_shape is not None: + current_shape = adapter_shape + continue + + # Standard diff logic with padding + if len(v) == 2: + patch_type, patch_data = v[0], v[1] + if patch_type == "diff" and len(patch_data) > 1 and patch_data[1]['pad_weight']: + current_shape = patch_data[0].shape + + return current_shape + def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, original_weights=None): for p in patches: strength = p[0] diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 67dce088e..f01818f50 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -1514,8 +1514,10 @@ class ModelPatcherDynamic(ModelPatcher): weight, _, _ = get_key_weight(self.model, key) if weight is None: - return 0 + return (False, 0) if key in self.patches: + if comfy.lora.calculate_shape(self.patches[key], weight, key) != weight.shape: + return (True, 0) setattr(m, param_key + "_lowvram_function", LowVramPatch(key, self.patches)) num_patches += 1 else: @@ -1529,7 +1531,13 @@ class ModelPatcherDynamic(ModelPatcher): model_dtype = getattr(m, param_key + "_comfy_model_dtype", None) or weight.dtype weight._model_dtype = model_dtype geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype) - return comfy.memory_management.vram_aligned_size(geometry) + return (False, comfy.memory_management.vram_aligned_size(geometry)) + + def force_load_param(self, param_key, device_to): + key = key_param_name_to_key(n, param_key) + if key in self.backup: + comfy.utils.set_attr_param(self.model, key, self.backup[key].weight) + self.patch_weight_to_device(key, device_to=device_to) if hasattr(m, "comfy_cast_weights"): m.comfy_cast_weights = True @@ -1537,13 +1545,19 @@ class ModelPatcherDynamic(ModelPatcher): m.seed_key = n set_dirty(m, dirty) - v_weight_size = 0 - v_weight_size += setup_param(self, m, n, "weight") - v_weight_size += setup_param(self, m, n, "bias") + force_load, v_weight_size = setup_param(self, m, n, "weight") + force_load_bias, v_weight_bias = setup_param(self, m, n, "bias") + force_load = force_load or force_load_bias + v_weight_size += v_weight_bias - if vbar is not None and not hasattr(m, "_v"): - m._v = vbar.alloc(v_weight_size) - allocated_size += v_weight_size + if force_load: + logging.info(f"Module {n} has resizing Lora - force loading") + force_load_param(self, "weight", device_to) + force_load_param(self, "bias", device_to) + else: + if vbar is not None and not hasattr(m, "_v"): + m._v = vbar.alloc(v_weight_size) + allocated_size += v_weight_size else: for param in params: @@ -1606,6 +1620,11 @@ class ModelPatcherDynamic(ModelPatcher): for m in self.model.modules(): move_weight_functions(m, device_to) + keys = list(self.backup.keys()) + for k in keys: + bk = self.backup[k] + comfy.utils.set_attr_param(self.model, k, bk.weight) + def partially_load(self, device_to, extra_memory=0, force_patch_weights=False): assert not force_patch_weights #See above with self.use_ejected(skip_and_inject_on_exit_only=True): diff --git a/comfy/weight_adapter/base.py b/comfy/weight_adapter/base.py index bce89a0e2..d352e066b 100644 --- a/comfy/weight_adapter/base.py +++ b/comfy/weight_adapter/base.py @@ -49,6 +49,12 @@ class WeightAdapterBase: """ raise NotImplementedError + def calculate_shape( + self, + key + ): + return None + def calculate_weight( self, weight, diff --git a/comfy/weight_adapter/lora.py b/comfy/weight_adapter/lora.py index bc4260a8f..8e1261a12 100644 --- a/comfy/weight_adapter/lora.py +++ b/comfy/weight_adapter/lora.py @@ -214,6 +214,13 @@ class LoRAAdapter(WeightAdapterBase): else: return None + def calculate_shape( + self, + key + ): + reshape = self.weights[5] + return tuple(reshape) if reshape is not None else None + def calculate_weight( self, weight, From 88e6370527dbd602851de07d957a8f17b3ca9447 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 15 Feb 2026 17:43:53 -0800 Subject: [PATCH 113/215] Remove workaround for old pytorch. (#12480) --- comfy/ldm/modules/diffusionmodules/model.py | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index 5a22ef030..805592aa5 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -102,19 +102,7 @@ class VideoConv3d(nn.Module): return self.conv(x) def interpolate_up(x, scale_factor): - try: - return torch.nn.functional.interpolate(x, scale_factor=scale_factor, mode="nearest") - except: #operation not implemented for bf16 - orig_shape = list(x.shape) - out_shape = orig_shape[:2] - for i in range(len(orig_shape) - 2): - out_shape.append(round(orig_shape[i + 2] * scale_factor[i])) - out = torch.empty(out_shape, dtype=x.dtype, layout=x.layout, device=x.device) - split = 8 - l = out.shape[1] // split - for i in range(0, out.shape[1], l): - out[:,i:i+l] = torch.nn.functional.interpolate(x[:,i:i+l].to(torch.float32), scale_factor=scale_factor, mode="nearest").to(x.dtype) - return out + return torch.nn.functional.interpolate(x, scale_factor=scale_factor, mode="nearest") class Upsample(nn.Module): def __init__(self, in_channels, with_conv, conv_op=ops.Conv2d, scale_factor=2.0): From 1978f59ffdf242389ded3eec76274a4cbed9cc3d Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Tue, 17 Feb 2026 06:33:43 +0800 Subject: [PATCH 114/215] chore: update workflow templates to v0.8.42 (#12491) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index e939e486a..0930bbbb8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.38.14 -comfyui-workflow-templates==0.8.38 +comfyui-workflow-templates==0.8.42 comfyui-embedded-docs==0.4.1 torch torchsde From 4454fab7f003c655e07f059c315e2aae0e5fb087 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 16 Feb 2026 17:09:24 -0800 Subject: [PATCH 115/215] Remove code to support RMSNorm on old pytorch. (#12499) --- comfy/ops.py | 6 ++---- comfy/rmsnorm.py | 55 ++++-------------------------------------------- 2 files changed, 6 insertions(+), 55 deletions(-) diff --git a/comfy/ops.py b/comfy/ops.py index 688937e43..026062f56 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -21,7 +21,6 @@ import logging import comfy.model_management from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram import comfy.float -import comfy.rmsnorm import json import comfy.memory_management import comfy.pinned_memory @@ -463,7 +462,7 @@ class disable_weight_init: else: return super().forward(*args, **kwargs) - class RMSNorm(comfy.rmsnorm.RMSNorm, CastWeightBiasOp): + class RMSNorm(torch.nn.RMSNorm, CastWeightBiasOp): def reset_parameters(self): self.bias = None return None @@ -475,8 +474,7 @@ class disable_weight_init: weight = None bias = None offload_stream = None - x = comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated - # x = torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps) + x = torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps) uncast_bias_weight(self, weight, bias, offload_stream) return x diff --git a/comfy/rmsnorm.py b/comfy/rmsnorm.py index 555542a46..ab7cf14fa 100644 --- a/comfy/rmsnorm.py +++ b/comfy/rmsnorm.py @@ -1,57 +1,10 @@ import torch import comfy.model_management -import numbers -import logging - -RMSNorm = None - -try: - rms_norm_torch = torch.nn.functional.rms_norm - RMSNorm = torch.nn.RMSNorm -except: - rms_norm_torch = None - logging.warning("Please update pytorch to use native RMSNorm") +RMSNorm = torch.nn.RMSNorm def rms_norm(x, weight=None, eps=1e-6): - if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()): - if weight is None: - return rms_norm_torch(x, (x.shape[-1],), eps=eps) - else: - return rms_norm_torch(x, weight.shape, weight=comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps) + if weight is None: + return torch.nn.functional.rms_norm(x, (x.shape[-1],), eps=eps) else: - r = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps) - if weight is None: - return r - else: - return r * comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device) - - -if RMSNorm is None: - class RMSNorm(torch.nn.Module): - def __init__( - self, - normalized_shape, - eps=1e-6, - elementwise_affine=True, - device=None, - dtype=None, - ): - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - if isinstance(normalized_shape, numbers.Integral): - # mypy error: incompatible types in assignment - normalized_shape = (normalized_shape,) # type: ignore[assignment] - self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type] - self.eps = eps - self.elementwise_affine = elementwise_affine - if self.elementwise_affine: - self.weight = torch.nn.Parameter( - torch.empty(self.normalized_shape, **factory_kwargs) - ) - else: - self.register_parameter("weight", None) - self.bias = None - - def forward(self, x): - return rms_norm(x, self.weight, self.eps) + return torch.nn.functional.rms_norm(x, weight.shape, weight=comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps) From b44fc4c589c66e39686239d6eff7d6088668c9a8 Mon Sep 17 00:00:00 2001 From: Alex Butler Date: Tue, 17 Feb 2026 03:16:19 +0000 Subject: [PATCH 116/215] add venv* to gitignore (#12431) --- .gitignore | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 4e8cea71e..2700ad5c2 100644 --- a/.gitignore +++ b/.gitignore @@ -11,7 +11,7 @@ extra_model_paths.yaml /.vs .vscode/ .idea/ -venv/ +venv*/ .venv/ /web/extensions/* !/web/extensions/logging.js.example From 8a6fbc2dc29d0b15c1e9655c24e7501829249995 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 16 Feb 2026 19:20:21 -0800 Subject: [PATCH 117/215] Allow control_after_generate to be type ControlAfterGenerate in v3 schema (#12187) --- comfy_api/latest/_io.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 95d79c035..d18330d0b 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -75,6 +75,12 @@ class NumberDisplay(str, Enum): slider = "slider" +class ControlAfterGenerate(str, Enum): + fixed = "fixed" + increment = "increment" + decrement = "decrement" + randomize = "randomize" + class _ComfyType(ABC): Type = Any io_type: str = None @@ -263,7 +269,7 @@ class Int(ComfyTypeIO): class Input(WidgetInput): '''Integer input.''' def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, - default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool=None, + default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool | ControlAfterGenerate=None, display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None): super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced) self.min = min @@ -345,7 +351,7 @@ class Combo(ComfyTypeIO): tooltip: str=None, lazy: bool=None, default: str | int | Enum = None, - control_after_generate: bool=None, + control_after_generate: bool | ControlAfterGenerate=None, upload: UploadType=None, image_folder: FolderType=None, remote: RemoteOptions=None, @@ -389,7 +395,7 @@ class MultiCombo(ComfyTypeI): Type = list[str] class Input(Combo.Input): def __init__(self, id: str, options: list[str], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, - default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool=None, + default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool | ControlAfterGenerate=None, socketless: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None): super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless=socketless, extra_dict=extra_dict, raw_link=raw_link, advanced=advanced) self.multiselect = True @@ -2097,6 +2103,7 @@ __all__ = [ "UploadType", "RemoteOptions", "NumberDisplay", + "ControlAfterGenerate", "comfytype", "Custom", From 18927538a15d44c734653513e9fdbbe1e79a9f0c Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 16 Feb 2026 20:30:34 -0800 Subject: [PATCH 118/215] Implement NAG on all the models based on the Flux code. (#12500) Use the Normalized Attention Guidance node. Flux, Flux2, Klein, Chroma, Chroma radiance, Hunyuan Video, etc.. --- comfy/ldm/chroma/model.py | 2 + comfy/ldm/flux/layers.py | 18 ++++++ comfy/ldm/flux/model.py | 2 + comfy/ldm/hunyuan_video/model.py | 2 + comfy/model_patcher.py | 5 +- comfy_extras/nodes_nag.py | 99 ++++++++++++++++++++++++++++++++ nodes.py | 1 + 7 files changed, 128 insertions(+), 1 deletion(-) create mode 100644 comfy_extras/nodes_nag.py diff --git a/comfy/ldm/chroma/model.py b/comfy/ldm/chroma/model.py index 2e8ef0687..9fd865f20 100644 --- a/comfy/ldm/chroma/model.py +++ b/comfy/ldm/chroma/model.py @@ -152,6 +152,7 @@ class Chroma(nn.Module): transformer_options={}, attn_mask: Tensor = None, ) -> Tensor: + transformer_options = transformer_options.copy() patches_replace = transformer_options.get("patches_replace", {}) # running on sequences img @@ -228,6 +229,7 @@ class Chroma(nn.Module): transformer_options["total_blocks"] = len(self.single_blocks) transformer_options["block_type"] = "single" + transformer_options["img_slice"] = [txt.shape[1], img.shape[1]] for i, block in enumerate(self.single_blocks): transformer_options["block_index"] = i if i not in self.skip_dit: diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py index 3518a1922..8b3f500d7 100644 --- a/comfy/ldm/flux/layers.py +++ b/comfy/ldm/flux/layers.py @@ -196,6 +196,9 @@ class DoubleStreamBlock(nn.Module): else: (img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec + transformer_patches = transformer_options.get("patches", {}) + extra_options = transformer_options.copy() + # prepare image for attention img_modulated = self.img_norm1(img) img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img) @@ -224,6 +227,12 @@ class DoubleStreamBlock(nn.Module): attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options) del q, k, v + if "attn1_output_patch" in transformer_patches: + extra_options["img_slice"] = [txt.shape[1], attn.shape[1]] + patch = transformer_patches["attn1_output_patch"] + for p in patch: + attn = p(attn, extra_options) + txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:] # calculate the img bloks @@ -303,6 +312,9 @@ class SingleStreamBlock(nn.Module): else: mod = vec + transformer_patches = transformer_options.get("patches", {}) + extra_options = transformer_options.copy() + qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim_first], dim=-1) q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) @@ -312,6 +324,12 @@ class SingleStreamBlock(nn.Module): # compute attention attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options) del q, k, v + + if "attn1_output_patch" in transformer_patches: + patch = transformer_patches["attn1_output_patch"] + for p in patch: + attn = p(attn, extra_options) + # compute activation in mlp stream, cat again and run second linear layer if self.yak_mlp: mlp = self.mlp_act(mlp[..., self.mlp_hidden_dim_first // 2:]) * mlp[..., :self.mlp_hidden_dim_first // 2] diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index 260ccad7e..ef4dcf7c5 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -142,6 +142,7 @@ class Flux(nn.Module): attn_mask: Tensor = None, ) -> Tensor: + transformer_options = transformer_options.copy() patches = transformer_options.get("patches", {}) patches_replace = transformer_options.get("patches_replace", {}) if img.ndim != 3 or txt.ndim != 3: @@ -231,6 +232,7 @@ class Flux(nn.Module): transformer_options["total_blocks"] = len(self.single_blocks) transformer_options["block_type"] = "single" + transformer_options["img_slice"] = [txt.shape[1], img.shape[1]] for i, block in enumerate(self.single_blocks): transformer_options["block_index"] = i if ("single_block", i) in blocks_replace: diff --git a/comfy/ldm/hunyuan_video/model.py b/comfy/ldm/hunyuan_video/model.py index 563f28f6b..b94cdfa87 100644 --- a/comfy/ldm/hunyuan_video/model.py +++ b/comfy/ldm/hunyuan_video/model.py @@ -304,6 +304,7 @@ class HunyuanVideo(nn.Module): control=None, transformer_options={}, ) -> Tensor: + transformer_options = transformer_options.copy() patches_replace = transformer_options.get("patches_replace", {}) initial_shape = list(img.shape) @@ -416,6 +417,7 @@ class HunyuanVideo(nn.Module): transformer_options["total_blocks"] = len(self.single_blocks) transformer_options["block_type"] = "single" + transformer_options["img_slice"] = [txt.shape[1], img.shape[1]] for i, block in enumerate(self.single_blocks): transformer_options["block_index"] = i if ("single_block", i) in blocks_replace: diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index f01818f50..21b4ce53e 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -406,13 +406,16 @@ class ModelPatcher: def memory_required(self, input_shape): return self.model.memory_required(input_shape=input_shape) + def disable_model_cfg1_optimization(self): + self.model_options["disable_cfg1_optimization"] = True + def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False): if len(inspect.signature(sampler_cfg_function).parameters) == 3: self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way else: self.model_options["sampler_cfg_function"] = sampler_cfg_function if disable_cfg1_optimization: - self.model_options["disable_cfg1_optimization"] = True + self.disable_model_cfg1_optimization() def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False): self.model_options = set_model_options_post_cfg_function(self.model_options, post_cfg_function, disable_cfg1_optimization) diff --git a/comfy_extras/nodes_nag.py b/comfy_extras/nodes_nag.py new file mode 100644 index 000000000..033e40eb9 --- /dev/null +++ b/comfy_extras/nodes_nag.py @@ -0,0 +1,99 @@ +import torch +from comfy_api.latest import ComfyExtension, io +from typing_extensions import override + + +class NAGuidance(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="NAGuidance", + display_name="Normalized Attention Guidance", + description="Applies Normalized Attention Guidance to models, enabling negative prompts on distilled/schnell models.", + category="", + is_experimental=True, + inputs=[ + io.Model.Input("model", tooltip="The model to apply NAG to."), + io.Float.Input("nag_scale", min=0.0, default=5.0, max=50.0, step=0.1, tooltip="The guidance scale factor. Higher values push further from the negative prompt."), + io.Float.Input("nag_alpha", min=0.0, default=0.5, max=1.0, step=0.01, tooltip="Blending factor for the normalized attention. 1.0 is full replacement, 0.0 is no effect."), + io.Float.Input("nag_tau", min=1.0, default=1.5, max=10.0, step=0.01), + # io.Float.Input("start_percent", min=0.0, default=0.0, max=1.0, step=0.01, tooltip="The relative sampling step to begin applying NAG."), + # io.Float.Input("end_percent", min=0.0, default=1.0, max=1.0, step=0.01, tooltip="The relative sampling step to stop applying NAG."), + ], + outputs=[ + io.Model.Output(tooltip="The patched model with NAG enabled."), + ], + ) + + @classmethod + def execute(cls, model: io.Model.Type, nag_scale: float, nag_alpha: float, nag_tau: float) -> io.NodeOutput: + m = model.clone() + + # sigma_start = m.get_model_object("model_sampling").percent_to_sigma(start_percent) + # sigma_end = m.get_model_object("model_sampling").percent_to_sigma(end_percent) + + def nag_attention_output_patch(out, extra_options): + cond_or_uncond = extra_options.get("cond_or_uncond", None) + if cond_or_uncond is None: + return out + + if not (1 in cond_or_uncond and 0 in cond_or_uncond): + return out + + # sigma = extra_options.get("sigmas", None) + # if sigma is not None and len(sigma) > 0: + # sigma = sigma[0].item() + # if sigma > sigma_start or sigma < sigma_end: + # return out + + img_slice = extra_options.get("img_slice", None) + + if img_slice is not None: + orig_out = out + out = out[:, img_slice[0]:img_slice[1]] # only apply on img part + + batch_size = out.shape[0] + half_size = batch_size // len(cond_or_uncond) + + ind_neg = cond_or_uncond.index(1) + ind_pos = cond_or_uncond.index(0) + z_pos = out[half_size * ind_pos:half_size * (ind_pos + 1)] + z_neg = out[half_size * ind_neg:half_size * (ind_neg + 1)] + + guided = z_pos * nag_scale - z_neg * (nag_scale - 1.0) + + eps = 1e-6 + norm_pos = torch.norm(z_pos, p=1, dim=-1, keepdim=True).clamp_min(eps) + norm_guided = torch.norm(guided, p=1, dim=-1, keepdim=True).clamp_min(eps) + + ratio = norm_guided / norm_pos + scale_factor = torch.minimum(ratio, torch.full_like(ratio, nag_tau)) / ratio + + guided_normalized = guided * scale_factor + + z_final = guided_normalized * nag_alpha + z_pos * (1.0 - nag_alpha) + + if img_slice is not None: + orig_out[half_size * ind_neg:half_size * (ind_neg + 1), img_slice[0]:img_slice[1]] = z_final + orig_out[half_size * ind_pos:half_size * (ind_pos + 1), img_slice[0]:img_slice[1]] = z_final + return orig_out + else: + out[half_size * ind_pos:half_size * (ind_pos + 1)] = z_final + return out + + m.set_model_attn1_output_patch(nag_attention_output_patch) + m.disable_model_cfg1_optimization() + + return io.NodeOutput(m) + + +class NagExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + NAGuidance, + ] + + +async def comfy_entrypoint() -> NagExtension: + return NagExtension() diff --git a/nodes.py b/nodes.py index db5f98408..dff56b79c 100644 --- a/nodes.py +++ b/nodes.py @@ -2437,6 +2437,7 @@ async def init_builtin_extra_nodes(): "nodes_color.py", "nodes_toolkit.py", "nodes_replacements.py", + "nodes_nag.py", ] import_failed = [] From c39653163d77161b2df2d57419129a4d6d081aa1 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 16 Feb 2026 21:29:20 -0800 Subject: [PATCH 119/215] Fix anima preprocess text embeds not using right inference dtype. (#12501) --- comfy/model_base.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index 4a74cb1ce..9dcef8741 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -178,10 +178,7 @@ class BaseModel(torch.nn.Module): xc = torch.cat([xc] + [comfy.model_management.cast_to_device(c_concat, xc.device, xc.dtype)], dim=1) context = c_crossattn - dtype = self.get_dtype() - - if self.manual_cast_dtype is not None: - dtype = self.manual_cast_dtype + dtype = self.get_dtype_inference() xc = xc.to(dtype) device = xc.device @@ -218,6 +215,13 @@ class BaseModel(torch.nn.Module): def get_dtype(self): return self.diffusion_model.dtype + def get_dtype_inference(self): + dtype = self.get_dtype() + + if self.manual_cast_dtype is not None: + dtype = self.manual_cast_dtype + return dtype + def encode_adm(self, **kwargs): return None @@ -372,9 +376,7 @@ class BaseModel(torch.nn.Module): input_shapes += shape if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention(): - dtype = self.get_dtype() - if self.manual_cast_dtype is not None: - dtype = self.manual_cast_dtype + dtype = self.get_dtype_inference() #TODO: this needs to be tweaked area = sum(map(lambda input_shape: input_shape[0] * math.prod(input_shape[2:]), input_shapes)) return (area * comfy.model_management.dtype_size(dtype) * 0.01 * self.memory_usage_factor) * (1024 * 1024) @@ -1165,7 +1167,7 @@ class Anima(BaseModel): t5xxl_ids = t5xxl_ids.unsqueeze(0) if torch.is_inference_mode_enabled(): # if not we are training - cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype()), t5xxl_ids.to(device=device), t5xxl_weights=t5xxl_weights.to(device=device, dtype=self.get_dtype())) + cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference()), t5xxl_ids.to(device=device), t5xxl_weights=t5xxl_weights.to(device=device, dtype=self.get_dtype_inference())) else: out['t5xxl_ids'] = comfy.conds.CONDRegular(t5xxl_ids) out['t5xxl_weights'] = comfy.conds.CONDRegular(t5xxl_weights) From fe52843fe55b92dedaabff684294dd7a115d2204 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 17 Feb 2026 00:39:54 -0500 Subject: [PATCH 120/215] ComfyUI v0.14.0 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index cf4e89816..8f7f3228e 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.13.0" +__version__ = "0.14.0" diff --git a/pyproject.toml b/pyproject.toml index 9dab9a50c..b132bb9c4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.13.0" +version = "0.14.0" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" From 44f85985215b4d819665e4cec84c00ef87aa9a7a Mon Sep 17 00:00:00 2001 From: chaObserv <154517000+chaObserv@users.noreply.github.com> Date: Tue, 17 Feb 2026 23:56:44 +0800 Subject: [PATCH 121/215] Fix anima LLM adapter forward when manual cast (#12504) --- comfy/ldm/anima/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/ldm/anima/model.py b/comfy/ldm/anima/model.py index 6fb51c4a4..6fcf8df90 100644 --- a/comfy/ldm/anima/model.py +++ b/comfy/ldm/anima/model.py @@ -179,8 +179,8 @@ class LLMAdapter(nn.Module): if source_attention_mask.ndim == 2: source_attention_mask = source_attention_mask.unsqueeze(1).unsqueeze(1) - x = self.in_proj(self.embed(target_input_ids)) context = source_hidden_states + x = self.in_proj(self.embed(target_input_ids, out_dtype=context.dtype)) position_ids = torch.arange(x.shape[1], device=x.device).unsqueeze(0) position_ids_context = torch.arange(context.shape[1], device=x.device).unsqueeze(0) position_embeddings = self.rotary_emb(x, position_ids) From 5284e6bf69b6e2e856c672595fd413fd505377ee Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Tue, 17 Feb 2026 20:07:14 +0200 Subject: [PATCH 122/215] feat(api-nodes): add "viduq3-turbo" model and Vidu3StartEnd node; fix the price badges (#12482) --- comfy_api_nodes/nodes_vidu.py | 226 ++++++++++++++++++++++++++++++++-- 1 file changed, 218 insertions(+), 8 deletions(-) diff --git a/comfy_api_nodes/nodes_vidu.py b/comfy_api_nodes/nodes_vidu.py index 80de14dfe..bbe7ebba2 100644 --- a/comfy_api_nodes/nodes_vidu.py +++ b/comfy_api_nodes/nodes_vidu.py @@ -54,6 +54,7 @@ async def execute_task( response_model=TaskStatusResponse, status_extractor=lambda r: r.state, progress_extractor=lambda r: r.progress, + price_extractor=lambda r: r.credits * 0.005 if r.credits is not None else None, max_poll_attempts=max_poll_attempts, ) if not response.creations: @@ -1306,6 +1307,36 @@ class Vidu3TextToVideoNode(IO.ComfyNode): ), ], ), + IO.DynamicCombo.Option( + "viduq3-turbo", + [ + IO.Combo.Input( + "aspect_ratio", + options=["16:9", "9:16", "3:4", "4:3", "1:1"], + tooltip="The aspect ratio of the output video.", + ), + IO.Combo.Input( + "resolution", + options=["720p", "1080p"], + tooltip="Resolution of the output video.", + ), + IO.Int.Input( + "duration", + default=5, + min=1, + max=16, + step=1, + display_mode=IO.NumberDisplay.slider, + tooltip="Duration of the output video in seconds.", + ), + IO.Boolean.Input( + "audio", + default=False, + tooltip="When enabled, outputs video with sound " + "(including dialogue and sound effects).", + ), + ], + ), ], tooltip="Model to use for video generation.", ), @@ -1334,13 +1365,20 @@ class Vidu3TextToVideoNode(IO.ComfyNode): ], is_api_node=True, price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["model.duration", "model.resolution"]), + depends_on=IO.PriceBadgeDepends(widgets=["model", "model.duration", "model.resolution"]), expr=""" ( $res := $lookup(widgets, "model.resolution"); - $base := $lookup({"720p": 0.075, "1080p": 0.1}, $res); - $perSec := $lookup({"720p": 0.025, "1080p": 0.05}, $res); - {"type":"usd","usd": $base + $perSec * ($lookup(widgets, "model.duration") - 1)} + $d := $lookup(widgets, "model.duration"); + $contains(widgets.model, "turbo") + ? ( + $rate := $lookup({"720p": 0.06, "1080p": 0.08}, $res); + {"type":"usd","usd": $rate * $d} + ) + : ( + $rate := $lookup({"720p": 0.15, "1080p": 0.16}, $res); + {"type":"usd","usd": $rate * $d} + ) ) """, ), @@ -1409,6 +1447,31 @@ class Vidu3ImageToVideoNode(IO.ComfyNode): ), ], ), + IO.DynamicCombo.Option( + "viduq3-turbo", + [ + IO.Combo.Input( + "resolution", + options=["720p", "1080p"], + tooltip="Resolution of the output video.", + ), + IO.Int.Input( + "duration", + default=5, + min=1, + max=16, + step=1, + display_mode=IO.NumberDisplay.slider, + tooltip="Duration of the output video in seconds.", + ), + IO.Boolean.Input( + "audio", + default=False, + tooltip="When enabled, outputs video with sound " + "(including dialogue and sound effects).", + ), + ], + ), ], tooltip="Model to use for video generation.", ), @@ -1442,13 +1505,20 @@ class Vidu3ImageToVideoNode(IO.ComfyNode): ], is_api_node=True, price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["model.duration", "model.resolution"]), + depends_on=IO.PriceBadgeDepends(widgets=["model", "model.duration", "model.resolution"]), expr=""" ( $res := $lookup(widgets, "model.resolution"); - $base := $lookup({"720p": 0.075, "1080p": 0.275, "2k": 0.35}, $res); - $perSec := $lookup({"720p": 0.05, "1080p": 0.075, "2k": 0.075}, $res); - {"type":"usd","usd": $base + $perSec * ($lookup(widgets, "model.duration") - 1)} + $d := $lookup(widgets, "model.duration"); + $contains(widgets.model, "turbo") + ? ( + $rate := $lookup({"720p": 0.06, "1080p": 0.08}, $res); + {"type":"usd","usd": $rate * $d} + ) + : ( + $rate := $lookup({"720p": 0.15, "1080p": 0.16, "2k": 0.2}, $res); + {"type":"usd","usd": $rate * $d} + ) ) """, ), @@ -1481,6 +1551,145 @@ class Vidu3ImageToVideoNode(IO.ComfyNode): return IO.NodeOutput(await download_url_to_video_output(results[0].url)) +class Vidu3StartEndToVideoNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="Vidu3StartEndToVideoNode", + display_name="Vidu Q3 Start/End Frame-to-Video Generation", + category="api node/video/Vidu", + description="Generate a video from a start frame, an end frame, and a prompt.", + inputs=[ + IO.DynamicCombo.Input( + "model", + options=[ + IO.DynamicCombo.Option( + "viduq3-pro", + [ + IO.Combo.Input( + "resolution", + options=["720p", "1080p"], + tooltip="Resolution of the output video.", + ), + IO.Int.Input( + "duration", + default=5, + min=1, + max=16, + step=1, + display_mode=IO.NumberDisplay.slider, + tooltip="Duration of the output video in seconds.", + ), + IO.Boolean.Input( + "audio", + default=False, + tooltip="When enabled, outputs video with sound " + "(including dialogue and sound effects).", + ), + ], + ), + IO.DynamicCombo.Option( + "viduq3-turbo", + [ + IO.Combo.Input( + "resolution", + options=["720p", "1080p"], + tooltip="Resolution of the output video.", + ), + IO.Int.Input( + "duration", + default=5, + min=1, + max=16, + step=1, + display_mode=IO.NumberDisplay.slider, + tooltip="Duration of the output video in seconds.", + ), + IO.Boolean.Input( + "audio", + default=False, + tooltip="When enabled, outputs video with sound " + "(including dialogue and sound effects).", + ), + ], + ), + ], + tooltip="Model to use for video generation.", + ), + IO.Image.Input("first_frame"), + IO.Image.Input("end_frame"), + IO.String.Input( + "prompt", + multiline=True, + tooltip="Prompt description (max 2000 characters).", + ), + IO.Int.Input( + "seed", + default=1, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model", "model.duration", "model.resolution"]), + expr=""" + ( + $res := $lookup(widgets, "model.resolution"); + $d := $lookup(widgets, "model.duration"); + $contains(widgets.model, "turbo") + ? ( + $rate := $lookup({"720p": 0.06, "1080p": 0.08}, $res); + {"type":"usd","usd": $rate * $d} + ) + : ( + $rate := $lookup({"720p": 0.15, "1080p": 0.16}, $res); + {"type":"usd","usd": $rate * $d} + ) + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + model: dict, + first_frame: Input.Image, + end_frame: Input.Image, + prompt: str, + seed: int, + ) -> IO.NodeOutput: + validate_string(prompt, max_length=2000) + validate_images_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False) + payload = TaskCreationRequest( + model=model["model"], + prompt=prompt, + duration=model["duration"], + seed=seed, + resolution=model["resolution"], + audio=model["audio"], + images=[ + (await upload_images_to_comfyapi(cls, frame, max_images=1, mime_type="image/png"))[0] + for frame in (first_frame, end_frame) + ], + ) + results = await execute_task(cls, VIDU_START_END_VIDEO, payload) + return IO.NodeOutput(await download_url_to_video_output(results[0].url)) + + class ViduExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: @@ -1497,6 +1706,7 @@ class ViduExtension(ComfyExtension): ViduMultiFrameVideoNode, Vidu3TextToVideoNode, Vidu3ImageToVideoNode, + Vidu3StartEndToVideoNode, ] From 262abf437b0666f3d00d1f335a526073503e59e4 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Tue, 17 Feb 2026 20:25:44 +0200 Subject: [PATCH 123/215] feat(api-nodes): add Recraft V4 nodes (#12502) --- comfy_api_nodes/apis/recraft.py | 45 ++++- comfy_api_nodes/nodes_recraft.py | 272 +++++++++++++++++++++++++++++-- 2 files changed, 297 insertions(+), 20 deletions(-) diff --git a/comfy_api_nodes/apis/recraft.py b/comfy_api_nodes/apis/recraft.py index 0bd7d23b3..78ededd94 100644 --- a/comfy_api_nodes/apis/recraft.py +++ b/comfy_api_nodes/apis/recraft.py @@ -198,11 +198,6 @@ dict_recraft_substyles_v3 = { } -class RecraftModel(str, Enum): - recraftv3 = 'recraftv3' - recraftv2 = 'recraftv2' - - class RecraftImageSize(str, Enum): res_1024x1024 = '1024x1024' res_1365x1024 = '1365x1024' @@ -221,6 +216,41 @@ class RecraftImageSize(str, Enum): res_1707x1024 = '1707x1024' +RECRAFT_V4_SIZES = [ + "1024x1024", + "1536x768", + "768x1536", + "1280x832", + "832x1280", + "1216x896", + "896x1216", + "1152x896", + "896x1152", + "832x1344", + "1280x896", + "896x1280", + "1344x768", + "768x1344", +] + +RECRAFT_V4_PRO_SIZES = [ + "2048x2048", + "3072x1536", + "1536x3072", + "2560x1664", + "1664x2560", + "2432x1792", + "1792x2432", + "2304x1792", + "1792x2304", + "1664x2688", + "1434x1024", + "1024x1434", + "2560x1792", + "1792x2560", +] + + class RecraftColorObject(BaseModel): rgb: list[int] = Field(..., description='An array of 3 integer values in range of 0...255 defining RGB Color Model') @@ -234,17 +264,16 @@ class RecraftControlsObject(BaseModel): class RecraftImageGenerationRequest(BaseModel): prompt: str = Field(..., description='The text prompt describing the image to generate') - size: RecraftImageSize | None = Field(None, description='The size of the generated image (e.g., "1024x1024")') + size: str | None = Field(None, description='The size of the generated image (e.g., "1024x1024")') n: int = Field(..., description='The number of images to generate') negative_prompt: str | None = Field(None, description='A text description of undesired elements on an image') - model: RecraftModel | None = Field(RecraftModel.recraftv3, description='The model to use for generation (e.g., "recraftv3")') + model: str = Field(...) style: str | None = Field(None, description='The style to apply to the generated image (e.g., "digital_illustration")') substyle: str | None = Field(None, description='The substyle to apply to the generated image, depending on the style input') controls: RecraftControlsObject | None = Field(None, description='A set of custom parameters to tweak generation process') style_id: str | None = Field(None, description='Use a previously uploaded style as a reference; UUID') strength: float | None = Field(None, description='Defines the difference with the original image, should lie in [0, 1], where 0 means almost identical, and 1 means miserable similarity') random_seed: int | None = Field(None, description="Seed for video generation") - # text_layout class RecraftReturnedObject(BaseModel): diff --git a/comfy_api_nodes/nodes_recraft.py b/comfy_api_nodes/nodes_recraft.py index 3a1f32263..773cb7dbe 100644 --- a/comfy_api_nodes/nodes_recraft.py +++ b/comfy_api_nodes/nodes_recraft.py @@ -1,5 +1,4 @@ from io import BytesIO -from typing import Optional, Union import aiohttp import torch @@ -9,6 +8,8 @@ from typing_extensions import override from comfy.utils import ProgressBar from comfy_api.latest import IO, ComfyExtension from comfy_api_nodes.apis.recraft import ( + RECRAFT_V4_PRO_SIZES, + RECRAFT_V4_SIZES, RecraftColor, RecraftColorChain, RecraftControls, @@ -18,7 +19,6 @@ from comfy_api_nodes.apis.recraft import ( RecraftImageGenerationResponse, RecraftImageSize, RecraftIO, - RecraftModel, RecraftStyle, RecraftStyleV3, get_v3_substyles, @@ -39,7 +39,7 @@ async def handle_recraft_file_request( cls: type[IO.ComfyNode], image: torch.Tensor, path: str, - mask: Optional[torch.Tensor] = None, + mask: torch.Tensor | None = None, total_pixels: int = 4096 * 4096, timeout: int = 1024, request=None, @@ -73,11 +73,11 @@ async def handle_recraft_file_request( def recraft_multipart_parser( data, parent_key=None, - formatter: Optional[type[callable]] = None, - converted_to_check: Optional[list[list]] = None, + formatter: type[callable] | None = None, + converted_to_check: list[list] | None = None, is_list: bool = False, return_mode: str = "formdata", # "dict" | "formdata" -) -> Union[dict, aiohttp.FormData]: +) -> dict | aiohttp.FormData: """ Formats data such that multipart/form-data will work with aiohttp library when both files and data are present. @@ -309,7 +309,7 @@ class RecraftStyleInfiniteStyleLibrary(IO.ComfyNode): node_id="RecraftStyleV3InfiniteStyleLibrary", display_name="Recraft Style - Infinite Style Library", category="api node/image/Recraft", - description="Select style based on preexisting UUID from Recraft's Infinite Style Library.", + description="Choose style based on preexisting UUID from Recraft's Infinite Style Library.", inputs=[ IO.String.Input("style_id", default="", tooltip="UUID of style from Infinite Style Library."), ], @@ -485,7 +485,7 @@ class RecraftTextToImageNode(IO.ComfyNode): data=RecraftImageGenerationRequest( prompt=prompt, negative_prompt=negative_prompt, - model=RecraftModel.recraftv3, + model="recraftv3", size=size, n=n, style=recraft_style.style, @@ -598,7 +598,7 @@ class RecraftImageToImageNode(IO.ComfyNode): request = RecraftImageGenerationRequest( prompt=prompt, negative_prompt=negative_prompt, - model=RecraftModel.recraftv3, + model="recraftv3", n=n, strength=round(strength, 2), style=recraft_style.style, @@ -698,7 +698,7 @@ class RecraftImageInpaintingNode(IO.ComfyNode): request = RecraftImageGenerationRequest( prompt=prompt, negative_prompt=negative_prompt, - model=RecraftModel.recraftv3, + model="recraftv3", n=n, style=recraft_style.style, substyle=recraft_style.substyle, @@ -810,7 +810,7 @@ class RecraftTextToVectorNode(IO.ComfyNode): data=RecraftImageGenerationRequest( prompt=prompt, negative_prompt=negative_prompt, - model=RecraftModel.recraftv3, + model="recraftv3", size=size, n=n, style=recraft_style.style, @@ -933,7 +933,7 @@ class RecraftReplaceBackgroundNode(IO.ComfyNode): request = RecraftImageGenerationRequest( prompt=prompt, negative_prompt=negative_prompt, - model=RecraftModel.recraftv3, + model="recraftv3", n=n, style=recraft_style.style, substyle=recraft_style.substyle, @@ -1078,6 +1078,252 @@ class RecraftCreativeUpscaleNode(RecraftCrispUpscaleNode): ) +class RecraftV4TextToImageNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftV4TextToImageNode", + display_name="Recraft V4 Text to Image", + category="api node/image/Recraft", + description="Generates images using Recraft V4 or V4 Pro models.", + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + tooltip="Prompt for the image generation. Maximum 10,000 characters.", + ), + IO.String.Input( + "negative_prompt", + multiline=True, + tooltip="An optional text description of undesired elements on an image.", + ), + IO.DynamicCombo.Input( + "model", + options=[ + IO.DynamicCombo.Option( + "recraftv4", + [ + IO.Combo.Input( + "size", + options=RECRAFT_V4_SIZES, + default="1024x1024", + tooltip="The size of the generated image.", + ), + ], + ), + IO.DynamicCombo.Option( + "recraftv4_pro", + [ + IO.Combo.Input( + "size", + options=RECRAFT_V4_PRO_SIZES, + default="2048x2048", + tooltip="The size of the generated image.", + ), + ], + ), + ], + tooltip="The model to use for generation.", + ), + IO.Int.Input( + "n", + default=1, + min=1, + max=6, + tooltip="The number of images to generate.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; " + "actual results are nondeterministic regardless of seed.", + ), + IO.Custom(RecraftIO.CONTROLS).Input( + "recraft_controls", + tooltip="Optional additional controls over the generation via the Recraft Controls node.", + optional=True, + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model", "n"]), + expr=""" + ( + $prices := {"recraftv4": 0.04, "recraftv4_pro": 0.25}; + {"type":"usd","usd": $lookup($prices, widgets.model) * widgets.n} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + prompt: str, + negative_prompt: str, + model: dict, + n: int, + seed: int, + recraft_controls: RecraftControls | None = None, + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=False, min_length=1, max_length=10000) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/recraft/image_generation", method="POST"), + response_model=RecraftImageGenerationResponse, + data=RecraftImageGenerationRequest( + prompt=prompt, + negative_prompt=negative_prompt if negative_prompt else None, + model=model["model"], + size=model["size"], + n=n, + controls=recraft_controls.create_api_model() if recraft_controls else None, + ), + max_retries=1, + ) + images = [] + for data in response.data: + with handle_recraft_image_output(): + image = bytesio_to_image_tensor(await download_url_as_bytesio(data.url, timeout=1024)) + if len(image.shape) < 4: + image = image.unsqueeze(0) + images.append(image) + return IO.NodeOutput(torch.cat(images, dim=0)) + + +class RecraftV4TextToVectorNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftV4TextToVectorNode", + display_name="Recraft V4 Text to Vector", + category="api node/image/Recraft", + description="Generates SVG using Recraft V4 or V4 Pro models.", + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + tooltip="Prompt for the image generation. Maximum 10,000 characters.", + ), + IO.String.Input( + "negative_prompt", + multiline=True, + tooltip="An optional text description of undesired elements on an image.", + ), + IO.DynamicCombo.Input( + "model", + options=[ + IO.DynamicCombo.Option( + "recraftv4", + [ + IO.Combo.Input( + "size", + options=RECRAFT_V4_SIZES, + default="1024x1024", + tooltip="The size of the generated image.", + ), + ], + ), + IO.DynamicCombo.Option( + "recraftv4_pro", + [ + IO.Combo.Input( + "size", + options=RECRAFT_V4_PRO_SIZES, + default="2048x2048", + tooltip="The size of the generated image.", + ), + ], + ), + ], + tooltip="The model to use for generation.", + ), + IO.Int.Input( + "n", + default=1, + min=1, + max=6, + tooltip="The number of images to generate.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; " + "actual results are nondeterministic regardless of seed.", + ), + IO.Custom(RecraftIO.CONTROLS).Input( + "recraft_controls", + tooltip="Optional additional controls over the generation via the Recraft Controls node.", + optional=True, + ), + ], + outputs=[ + IO.SVG.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model", "n"]), + expr=""" + ( + $prices := {"recraftv4": 0.08, "recraftv4_pro": 0.30}; + {"type":"usd","usd": $lookup($prices, widgets.model) * widgets.n} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + prompt: str, + negative_prompt: str, + model: dict, + n: int, + seed: int, + recraft_controls: RecraftControls | None = None, + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=False, min_length=1, max_length=10000) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/recraft/image_generation", method="POST"), + response_model=RecraftImageGenerationResponse, + data=RecraftImageGenerationRequest( + prompt=prompt, + negative_prompt=negative_prompt if negative_prompt else None, + model=model["model"], + size=model["size"], + n=n, + style="vector_illustration", + substyle=None, + controls=recraft_controls.create_api_model() if recraft_controls else None, + ), + max_retries=1, + ) + svg_data = [] + for data in response.data: + svg_data.append(await download_url_as_bytesio(data.url, timeout=1024)) + return IO.NodeOutput(SVG(svg_data)) + + class RecraftExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: @@ -1098,6 +1344,8 @@ class RecraftExtension(ComfyExtension): RecraftCreateStyleNode, RecraftColorRGBNode, RecraftControlsNode, + RecraftV4TextToImageNode, + RecraftV4TextToVectorNode, ] From 73c3f869737bbb1035f6b72b2e1068a1a5642764 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Wed, 18 Feb 2026 02:25:55 +0800 Subject: [PATCH 124/215] chore: update workflow templates to v0.8.43 (#12507) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 0930bbbb8..881d6bd58 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.38.14 -comfyui-workflow-templates==0.8.42 +comfyui-workflow-templates==0.8.43 comfyui-embedded-docs==0.4.1 torch torchsde From 19236edfa4d2f66070d66a6b3aee592c9c2ad574 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 17 Feb 2026 13:28:06 -0500 Subject: [PATCH 125/215] ComfyUI v0.14.1 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 8f7f3228e..f24c15cc5 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.14.0" +__version__ = "0.14.1" diff --git a/pyproject.toml b/pyproject.toml index b132bb9c4..51c3d224d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.14.0" +version = "0.14.1" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" From 58dcc97dcfadc548ac8d8d5e80741ddfb807d213 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Tue, 17 Feb 2026 12:32:27 -0800 Subject: [PATCH 126/215] ops: limit return of requants (#12506) This check was far too broad and the dtype is not a reliable indicator of wanting the requant (as QT returns the compute dtype as the dtype). So explictly plumb whether fp8mm wants the requant or not. --- comfy/ops.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/comfy/ops.py b/comfy/ops.py index 026062f56..a6c642795 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -79,7 +79,7 @@ def cast_to_input(weight, input, non_blocking=False, copy=True): return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy) -def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype): +def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant): offload_stream = None xfer_dest = None @@ -170,10 +170,10 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu #FIXME: this is not accurate, we need to be sensitive to the compute dtype x = lowvram_fn(x) if (isinstance(orig, QuantizedTensor) and - (orig.dtype == dtype and len(fns) == 0 or update_weight)): + (want_requant and len(fns) == 0 or update_weight)): seed = comfy.utils.string_to_seed(s.seed_key) y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed) - if orig.dtype == dtype and len(fns) == 0: + if want_requant and len(fns) == 0: #The layer actually wants our freshly saved QT x = y elif update_weight: @@ -194,7 +194,7 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu return weight, bias, (offload_stream, device if signature is not None else None, None) -def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None): +def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None, want_requant=False): # NOTE: offloadable=False is a a legacy and if you are a custom node author reading this please pass # offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This # will add async-offload support to your cast and improve performance. @@ -212,7 +212,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of non_blocking = comfy.model_management.device_supports_non_blocking(device) if hasattr(s, "_v"): - return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype) + return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant) if offloadable and (device != s.weight.device or (s.bias is not None and device != s.bias.device)): @@ -850,8 +850,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec def _forward(self, input, weight, bias): return torch.nn.functional.linear(input, weight, bias) - def forward_comfy_cast_weights(self, input, compute_dtype=None): - weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True, compute_dtype=compute_dtype) + def forward_comfy_cast_weights(self, input, compute_dtype=None, want_requant=False): + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True, compute_dtype=compute_dtype, want_requant=want_requant) x = self._forward(input, weight, bias) uncast_bias_weight(self, weight, bias, offload_stream) return x @@ -881,8 +881,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec scale = comfy.model_management.cast_to_device(scale, input.device, None) input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale) - - output = self.forward_comfy_cast_weights(input, compute_dtype) + output = self.forward_comfy_cast_weights(input, compute_dtype, want_requant=isinstance(input, QuantizedTensor)) # Reshape output back to 3D if input was 3D if reshaped_3d: From 6c14f129af4fd94c4197644e6950bddbba0c9e51 Mon Sep 17 00:00:00 2001 From: Comfy Org PR Bot Date: Wed, 18 Feb 2026 06:41:34 +0900 Subject: [PATCH 127/215] Bump comfyui-frontend-package to 1.39.14 (#12494) * Bump comfyui-frontend-package to 1.39.13 * Update requirements.txt --------- Co-authored-by: Christian Byrne --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 881d6bd58..807fea5e0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.38.14 +comfyui-frontend-package==1.39.14 comfyui-workflow-templates==0.8.43 comfyui-embedded-docs==0.4.1 torch From 8ad38d2073b019a204f730182dcf5456fb260858 Mon Sep 17 00:00:00 2001 From: Terry Jia Date: Tue, 17 Feb 2026 20:13:39 -0500 Subject: [PATCH 128/215] BBox widget (#11594) * Boundingbox widget * code improve --------- Co-authored-by: Jedrzej Kosinski Co-authored-by: Christian Byrne --- comfy_api/latest/_io.py | 25 ++++++++++++++++ comfy_extras/nodes_images.py | 56 +++++++++++++++++++++++++++++++++++- 2 files changed, 80 insertions(+), 1 deletion(-) diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index d18330d0b..312681249 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -1209,6 +1209,30 @@ class Color(ComfyTypeIO): def as_dict(self): return super().as_dict() +@comfytype(io_type="BOUNDING_BOX") +class BoundingBox(ComfyTypeIO): + class BoundingBoxDict(TypedDict): + x: int + y: int + width: int + height: int + Type = BoundingBoxDict + + class Input(WidgetInput): + def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, + socketless: bool=True, default: dict=None, component: str=None): + super().__init__(id, display_name, optional, tooltip, None, default, socketless) + self.component = component + if default is None: + self.default = {"x": 0, "y": 0, "width": 512, "height": 512} + + def as_dict(self): + d = super().as_dict() + if self.component: + d["component"] = self.component + return d + + DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {} def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]): DYNAMIC_INPUT_LOOKUP[io_type] = func @@ -2190,5 +2214,6 @@ __all__ = [ "ImageCompare", "PriceBadgeDepends", "PriceBadge", + "BoundingBox", "NodeReplace", ] diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py index cb4fb24a1..23419a65d 100644 --- a/comfy_extras/nodes_images.py +++ b/comfy_extras/nodes_images.py @@ -23,8 +23,9 @@ class ImageCrop(IO.ComfyNode): return IO.Schema( node_id="ImageCrop", search_aliases=["trim"], - display_name="Image Crop", + display_name="Image Crop (Deprecated)", category="image/transform", + is_deprecated=True, inputs=[ IO.Image.Input("image"), IO.Int.Input("width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1), @@ -47,6 +48,57 @@ class ImageCrop(IO.ComfyNode): crop = execute # TODO: remove +class ImageCropV2(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="ImageCropV2", + search_aliases=["trim"], + display_name="Image Crop", + category="image/transform", + inputs=[ + IO.Image.Input("image"), + IO.BoundingBox.Input("crop_region", component="ImageCrop"), + ], + outputs=[IO.Image.Output()], + ) + + @classmethod + def execute(cls, image, crop_region) -> IO.NodeOutput: + x = crop_region.get("x", 0) + y = crop_region.get("y", 0) + width = crop_region.get("width", 512) + height = crop_region.get("height", 512) + + x = min(x, image.shape[2] - 1) + y = min(y, image.shape[1] - 1) + to_x = width + x + to_y = height + y + img = image[:,y:to_y, x:to_x, :] + return IO.NodeOutput(img, ui=UI.PreviewImage(img)) + + +class BoundingBox(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="PrimitiveBoundingBox", + display_name="Bounding Box", + category="utils/primitive", + inputs=[ + IO.Int.Input("x", default=0, min=0, max=MAX_RESOLUTION), + IO.Int.Input("y", default=0, min=0, max=MAX_RESOLUTION), + IO.Int.Input("width", default=512, min=1, max=MAX_RESOLUTION), + IO.Int.Input("height", default=512, min=1, max=MAX_RESOLUTION), + ], + outputs=[IO.BoundingBox.Output()], + ) + + @classmethod + def execute(cls, x, y, width, height) -> IO.NodeOutput: + return IO.NodeOutput({"x": x, "y": y, "width": width, "height": height}) + + class RepeatImageBatch(IO.ComfyNode): @classmethod def define_schema(cls): @@ -632,6 +684,8 @@ class ImagesExtension(ComfyExtension): async def get_node_list(self) -> list[type[IO.ComfyNode]]: return [ ImageCrop, + ImageCropV2, + BoundingBox, RepeatImageBatch, ImageFromBatch, ImageAddNoise, From 83dd65f23ae78186df5be7f579af5c0cdb61f0f9 Mon Sep 17 00:00:00 2001 From: Hunter Date: Wed, 18 Feb 2026 00:03:54 -0500 Subject: [PATCH 129/215] fix: use glob matching for Gemini image MIME types (#12511) gemini-3-pro-image-preview nondeterministically returns image/jpeg instead of image/png. get_image_from_response() hardcoded get_parts_by_type(response, "image/png"), silently dropping JPEG responses and falling back to torch.zeros (all-black output). Add _mime_matches() helper using fnmatch for glob-style MIME matching. Change get_image_from_response() to request "image/*" so any image format returned by the API is correctly captured. --- comfy_api_nodes/nodes_gemini.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index 3b31caa7b..5287a777a 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -6,6 +6,7 @@ See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/infer import base64 import os from enum import Enum +from fnmatch import fnmatch from io import BytesIO from typing import Literal @@ -119,6 +120,13 @@ async def create_image_parts( return image_parts +def _mime_matches(mime: GeminiMimeType | None, pattern: str) -> bool: + """Check if a MIME type matches a pattern. Supports fnmatch globs (e.g. 'image/*').""" + if mime is None: + return False + return fnmatch(mime.value, pattern) + + def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Literal["text"] | str) -> list[GeminiPart]: """ Filter response parts by their type. @@ -151,9 +159,9 @@ def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Litera for part in candidate.content.parts: if part_type == "text" and part.text: parts.append(part) - elif part.inlineData and part.inlineData.mimeType == part_type: + elif part.inlineData and _mime_matches(part.inlineData.mimeType, part_type): parts.append(part) - elif part.fileData and part.fileData.mimeType == part_type: + elif part.fileData and _mime_matches(part.fileData.mimeType, part_type): parts.append(part) if not parts and blocked_reasons: @@ -178,7 +186,7 @@ def get_text_from_response(response: GeminiGenerateContentResponse) -> str: async def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image: image_tensors: list[Input.Image] = [] - parts = get_parts_by_type(response, "image/png") + parts = get_parts_by_type(response, "image/*") for part in parts: if part.inlineData: image_data = base64.b64decode(part.inlineData.data) From 239ddd332724c63934bf517cfc6d0026214d8aee Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Wed, 18 Feb 2026 09:15:23 +0200 Subject: [PATCH 130/215] fix(api-nodes): add price badge for Rodin Gen-2 node (#12512) --- comfy_api_nodes/nodes_rodin.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/comfy_api_nodes/nodes_rodin.py b/comfy_api_nodes/nodes_rodin.py index f9cff121f..9c1adaa51 100644 --- a/comfy_api_nodes/nodes_rodin.py +++ b/comfy_api_nodes/nodes_rodin.py @@ -505,6 +505,9 @@ class Rodin3D_Gen2(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.4}""", + ), ) @classmethod From f262444dd4818b6acdbc1350856679dd6245f7f5 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 18 Feb 2026 15:36:35 -0800 Subject: [PATCH 131/215] Add simple 3 band equalizer node for audio. (#12519) --- comfy_extras/nodes_audio.py | 62 +++++++++++++++++++++++++++++++++++++ 1 file changed, 62 insertions(+) diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py index b63dd8e97..7e74169f2 100644 --- a/comfy_extras/nodes_audio.py +++ b/comfy_extras/nodes_audio.py @@ -698,6 +698,67 @@ class EmptyAudio(IO.ComfyNode): create_empty_audio = execute # TODO: remove +class AudioEqualizer3Band(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="AudioEqualizer3Band", + search_aliases=["eq", "bass boost", "treble boost", "equalizer"], + display_name="Audio Equalizer (3-Band)", + category="audio", + is_experimental=True, + inputs=[ + IO.Audio.Input("audio"), + IO.Float.Input("low_gain_dB", default=0.0, min=-24.0, max=24.0, step=0.1, tooltip="Gain for Low frequencies (Bass)"), + IO.Int.Input("low_freq", default=100, min=20, max=500, tooltip="Cutoff frequency for Low shelf"), + IO.Float.Input("mid_gain_dB", default=0.0, min=-24.0, max=24.0, step=0.1, tooltip="Gain for Mid frequencies"), + IO.Int.Input("mid_freq", default=1000, min=200, max=4000, tooltip="Center frequency for Mids"), + IO.Float.Input("mid_q", default=0.707, min=0.1, max=10.0, step=0.1, tooltip="Q factor (bandwidth) for Mids"), + IO.Float.Input("high_gain_dB", default=0.0, min=-24.0, max=24.0, step=0.1, tooltip="Gain for High frequencies (Treble)"), + IO.Int.Input("high_freq", default=5000, min=1000, max=15000, tooltip="Cutoff frequency for High shelf"), + ], + outputs=[IO.Audio.Output()], + ) + + @classmethod + def execute(cls, audio, low_gain_dB, low_freq, mid_gain_dB, mid_freq, mid_q, high_gain_dB, high_freq) -> IO.NodeOutput: + waveform = audio["waveform"] + sample_rate = audio["sample_rate"] + eq_waveform = waveform.clone() + + # 1. Apply Low Shelf (Bass) + if low_gain_dB != 0: + eq_waveform = torchaudio.functional.bass_biquad( + eq_waveform, + sample_rate, + gain=low_gain_dB, + central_freq=float(low_freq), + Q=0.707 + ) + + # 2. Apply Peaking EQ (Mids) + if mid_gain_dB != 0: + eq_waveform = torchaudio.functional.equalizer_biquad( + eq_waveform, + sample_rate, + center_freq=float(mid_freq), + gain=mid_gain_dB, + Q=mid_q + ) + + # 3. Apply High Shelf (Treble) + if high_gain_dB != 0: + eq_waveform = torchaudio.functional.treble_biquad( + eq_waveform, + sample_rate, + gain=high_gain_dB, + central_freq=float(high_freq), + Q=0.707 + ) + + return IO.NodeOutput({"waveform": eq_waveform, "sample_rate": sample_rate}) + + class AudioExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: @@ -720,6 +781,7 @@ class AudioExtension(ComfyExtension): AudioMerge, AudioAdjustVolume, EmptyAudio, + AudioEqualizer3Band, ] async def comfy_entrypoint() -> AudioExtension: From 6d11cc73549e14a0a31e9ff8c90bfd71b380fe2c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Thu, 19 Feb 2026 03:49:43 +0200 Subject: [PATCH 132/215] feat: Add basic text generation support with native models, initially supporting Gemma3 (#12392) --- comfy/sd.py | 29 +++- comfy/sd1_clip.py | 18 +++ comfy/text_encoders/llama.py | 148 +++++++++++++++++++- comfy/text_encoders/lt.py | 92 ++++++++++--- comfy/text_encoders/lumina2.py | 36 ++++- comfy/text_encoders/spiece_tokenizer.py | 27 +++- comfy/utils.py | 8 ++ comfy_extras/nodes_textgen.py | 176 ++++++++++++++++++++++++ nodes.py | 1 + 9 files changed, 502 insertions(+), 33 deletions(-) create mode 100644 comfy_extras/nodes_textgen.py diff --git a/comfy/sd.py b/comfy/sd.py index f65e7cadd..164f30803 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -423,6 +423,19 @@ class CLIP: def get_key_patches(self): return self.patcher.get_key_patches() + def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None): + self.cond_stage_model.reset_clip_options() + + if self.layer_idx is not None: + self.cond_stage_model.set_clip_options({"layer": self.layer_idx}) + + self.load_model() + self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device}) + return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed) + + def decode(self, token_ids, skip_special_tokens=True): + return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens) + class VAE: def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None): if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format @@ -1182,6 +1195,7 @@ class TEModel(Enum): JINA_CLIP_2 = 19 QWEN3_8B = 20 QWEN3_06B = 21 + GEMMA_3_4B_VISION = 22 def detect_te_model(sd): @@ -1210,7 +1224,10 @@ def detect_te_model(sd): if 'model.layers.47.self_attn.q_norm.weight' in sd: return TEModel.GEMMA_3_12B if 'model.layers.0.self_attn.q_norm.weight' in sd: - return TEModel.GEMMA_3_4B + if 'vision_model.embeddings.patch_embedding.weight' in sd: + return TEModel.GEMMA_3_4B_VISION + else: + return TEModel.GEMMA_3_4B return TEModel.GEMMA_2_2B if 'model.layers.0.self_attn.k_proj.bias' in sd: weight = sd['model.layers.0.self_attn.k_proj.bias'] @@ -1270,6 +1287,8 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip else: if "text_projection" in clip_data[i]: clip_data[i]["text_projection.weight"] = clip_data[i]["text_projection"].transpose(0, 1) #old models saved with the CLIPSave node + if "lm_head.weight" in clip_data[i]: + clip_data[i]["model.lm_head.weight"] = clip_data[i].pop("lm_head.weight") # prefix missing in some models tokenizer_data = {} clip_target = EmptyClass() @@ -1335,6 +1354,14 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data), model_type="gemma3_4b") clip_target.tokenizer = comfy.text_encoders.lumina2.NTokenizer tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) + elif te_model == TEModel.GEMMA_3_4B_VISION: + clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data), model_type="gemma3_4b_vision") + clip_target.tokenizer = comfy.text_encoders.lumina2.NTokenizer + tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) + elif te_model == TEModel.GEMMA_3_12B: + clip_target.clip = comfy.text_encoders.lt.gemma3_te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.lt.Gemma3_12BTokenizer + tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) elif te_model == TEModel.LLAMA3_8: clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data), clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index b564d1529..d9d014055 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -308,6 +308,15 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): def load_sd(self, sd): return self.transformer.load_state_dict(sd, strict=False, assign=getattr(self, "can_assign_sd", False)) + def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[]): + if isinstance(tokens, dict): + tokens_only = next(iter(tokens.values())) # todo: get this better? + else: + tokens_only = tokens + tokens_only = [[t[0] for t in b] for b in tokens_only] + embeds = self.process_tokens(tokens_only, device=self.execution_device)[0] + return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens) + def parse_parentheses(string): result = [] current_item = "" @@ -663,6 +672,9 @@ class SDTokenizer: def state_dict(self): return {} + def decode(self, token_ids, skip_special_tokens=True): + return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens) + class SD1Tokenizer: def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer, name=None): if name is not None: @@ -686,6 +698,9 @@ class SD1Tokenizer: def state_dict(self): return getattr(self, self.clip).state_dict() + def decode(self, token_ids, skip_special_tokens=True): + return getattr(self, self.clip).decode(token_ids, skip_special_tokens=skip_special_tokens) + class SD1CheckpointClipModel(SDClipModel): def __init__(self, device="cpu", dtype=None, model_options={}): super().__init__(device=device, return_projected_pooled=False, dtype=dtype, model_options=model_options) @@ -722,3 +737,6 @@ class SD1ClipModel(torch.nn.Module): def load_sd(self, sd): return getattr(self, self.clip).load_sd(sd) + + def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None): + return getattr(self, self.clip).generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed) diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index 54f3d5595..e5d21fa74 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -3,6 +3,8 @@ import torch.nn as nn from dataclasses import dataclass from typing import Optional, Any, Tuple import math +from tqdm import tqdm +import comfy.utils from comfy.ldm.modules.attention import optimized_attention_for_device import comfy.model_management @@ -313,6 +315,13 @@ class Gemma3_4B_Config: final_norm: bool = True lm_head: bool = False +GEMMA3_VISION_CONFIG = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "model_type": "siglip_vision_model", "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 14} + +@dataclass +class Gemma3_4B_Vision_Config(Gemma3_4B_Config): + vision_config = GEMMA3_VISION_CONFIG + mm_tokens_per_image = 256 + @dataclass class Gemma3_12B_Config: vocab_size: int = 262208 @@ -336,7 +345,7 @@ class Gemma3_12B_Config: rope_scale = [8.0, 1.0] final_norm: bool = True lm_head: bool = False - vision_config = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "model_type": "siglip_vision_model", "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 14} + vision_config = GEMMA3_VISION_CONFIG mm_tokens_per_image = 256 class RMSNorm(nn.Module): @@ -441,8 +450,10 @@ class Attention(nn.Module): freqs_cis: Optional[torch.Tensor] = None, optimized_attention=None, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + sliding_window: Optional[int] = None, ): batch_size, seq_length, _ = hidden_states.shape + xq = self.q_proj(hidden_states) xk = self.k_proj(hidden_states) xv = self.v_proj(hidden_states) @@ -477,6 +488,11 @@ class Attention(nn.Module): else: present_key_value = (xk, xv, index + num_tokens) + if sliding_window is not None and xk.shape[2] > sliding_window: + xk = xk[:, :, -sliding_window:] + xv = xv[:, :, -sliding_window:] + attention_mask = attention_mask[..., -sliding_window:] if attention_mask is not None else None + xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) @@ -559,10 +575,12 @@ class TransformerBlockGemma2(nn.Module): optimized_attention=None, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ): + sliding_window = None if self.transformer_type == 'gemma3': if self.sliding_attention: + sliding_window = self.sliding_attention if x.shape[1] > self.sliding_attention: - sliding_mask = torch.full((x.shape[1], x.shape[1]), float("-inf"), device=x.device, dtype=x.dtype) + sliding_mask = torch.full((x.shape[1], x.shape[1]), torch.finfo(x.dtype).min, device=x.device, dtype=x.dtype) sliding_mask.tril_(diagonal=-self.sliding_attention) if attention_mask is not None: attention_mask = attention_mask + sliding_mask @@ -581,6 +599,7 @@ class TransformerBlockGemma2(nn.Module): freqs_cis=freqs_cis, optimized_attention=optimized_attention, past_key_value=past_key_value, + sliding_window=sliding_window, ) x = self.post_attention_layernorm(x) @@ -765,6 +784,104 @@ class BaseLlama: def forward(self, input_ids, *args, **kwargs): return self.model(input_ids, *args, **kwargs) +class BaseGenerate: + def logits(self, x): + input = x[:, -1:] + if hasattr(self.model, "lm_head"): + module = self.model.lm_head + else: + module = self.model.embed_tokens + + offload_stream = None + if module.comfy_cast_weights: + weight, _, offload_stream = comfy.ops.cast_bias_weight(module, input, offloadable=True) + else: + weight = self.model.embed_tokens.weight.to(x) + + x = torch.nn.functional.linear(input, weight, None) + + comfy.ops.uncast_bias_weight(module, weight, None, offload_stream) + return x + + def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=[], initial_tokens=[], execution_dtype=None, min_tokens=0): + device = embeds.device + model_config = self.model.config + + if execution_dtype is None: + if comfy.model_management.should_use_bf16(device): + execution_dtype = torch.bfloat16 + else: + execution_dtype = torch.float32 + embeds = embeds.to(execution_dtype) + + if embeds.ndim == 2: + embeds = embeds.unsqueeze(0) + + past_key_values = [] #kv_cache init + max_cache_len = embeds.shape[1] + max_length + for x in range(model_config.num_hidden_layers): + past_key_values.append((torch.empty([embeds.shape[0], model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), + torch.empty([embeds.shape[0], model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0)) + + generator = torch.Generator(device=device).manual_seed(seed) if do_sample else None + + generated_token_ids = [] + pbar = comfy.utils.ProgressBar(max_length) + + # Generation loop + for step in tqdm(range(max_length), desc="Generating tokens"): + x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values) + logits = self.logits(x)[:, -1] + next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample) + token_id = next_token[0].item() + generated_token_ids.append(token_id) + + embeds = self.model.embed_tokens(next_token).to(execution_dtype) + pbar.update(1) + + if token_id in stop_tokens: + break + + return generated_token_ids + + def sample_token(self, logits, temperature, top_k, top_p, min_p, repetition_penalty, token_history, generator, do_sample=True): + + if not do_sample or temperature == 0.0: + return torch.argmax(logits, dim=-1, keepdim=True) + + # Sampling mode + if repetition_penalty != 1.0: + for i in range(logits.shape[0]): + for token_id in set(token_history): + logits[i, token_id] *= repetition_penalty if logits[i, token_id] < 0 else 1/repetition_penalty + + if temperature != 1.0: + logits = logits / temperature + + if top_k > 0: + indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] + logits[indices_to_remove] = torch.finfo(logits.dtype).min + + if min_p > 0.0: + probs_before_filter = torch.nn.functional.softmax(logits, dim=-1) + top_probs, _ = probs_before_filter.max(dim=-1, keepdim=True) + min_threshold = min_p * top_probs + indices_to_remove = probs_before_filter < min_threshold + logits[indices_to_remove] = torch.finfo(logits.dtype).min + + if top_p < 1.0: + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cumulative_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1) + sorted_indices_to_remove = cumulative_probs > top_p + sorted_indices_to_remove[..., 0] = False + indices_to_remove = torch.zeros_like(logits, dtype=torch.bool) + indices_to_remove.scatter_(1, sorted_indices, sorted_indices_to_remove) + logits[indices_to_remove] = torch.finfo(logits.dtype).min + + probs = torch.nn.functional.softmax(logits, dim=-1) + + return torch.multinomial(probs, num_samples=1, generator=generator) + class BaseQwen3: def logits(self, x): input = x[:, -1:] @@ -871,7 +988,7 @@ class Ovis25_2B(BaseLlama, torch.nn.Module): self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.dtype = dtype -class Qwen25_7BVLI(BaseLlama, torch.nn.Module): +class Qwen25_7BVLI(BaseLlama, BaseGenerate, torch.nn.Module): def __init__(self, config_dict, dtype, device, operations): super().__init__() config = Qwen25_7BVLI_Config(**config_dict) @@ -881,6 +998,9 @@ class Qwen25_7BVLI(BaseLlama, torch.nn.Module): self.visual = qwen_vl.Qwen2VLVisionTransformer(hidden_size=1280, output_hidden_size=config.hidden_size, device=device, dtype=dtype, ops=operations) self.dtype = dtype + # todo: should this be tied or not? + #self.lm_head = operations.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype) + def preprocess_embed(self, embed, device): if embed["type"] == "image": image, grid = qwen_vl.process_qwen2vl_images(embed["data"]) @@ -923,7 +1043,7 @@ class Gemma2_2B(BaseLlama, torch.nn.Module): self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.dtype = dtype -class Gemma3_4B(BaseLlama, torch.nn.Module): +class Gemma3_4B(BaseLlama, BaseGenerate, torch.nn.Module): def __init__(self, config_dict, dtype, device, operations): super().__init__() config = Gemma3_4B_Config(**config_dict) @@ -932,7 +1052,25 @@ class Gemma3_4B(BaseLlama, torch.nn.Module): self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.dtype = dtype -class Gemma3_12B(BaseLlama, torch.nn.Module): +class Gemma3_4B_Vision(BaseLlama, BaseGenerate, torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + config = Gemma3_4B_Vision_Config(**config_dict) + self.num_layers = config.num_hidden_layers + + self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) + self.dtype = dtype + self.multi_modal_projector = Gemma3MultiModalProjector(config, dtype, device, operations) + self.vision_model = comfy.clip_model.CLIPVision(config.vision_config, dtype, device, operations) + self.image_size = config.vision_config["image_size"] + + def preprocess_embed(self, embed, device): + if embed["type"] == "image": + image = comfy.clip_model.clip_preprocess(embed["data"], size=self.image_size, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], crop=True) + return self.multi_modal_projector(self.vision_model(image.to(device, dtype=torch.float32))[0]), None + return None, None + +class Gemma3_12B(BaseLlama, BaseGenerate, torch.nn.Module): def __init__(self, config_dict, dtype, device, operations): super().__init__() config = Gemma3_12B_Config(**config_dict) diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py index 9cf87c0b2..82fbacf59 100644 --- a/comfy/text_encoders/lt.py +++ b/comfy/text_encoders/lt.py @@ -6,6 +6,7 @@ import comfy.text_encoders.genmo from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector import torch import comfy.utils +import math class T5XXLTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): @@ -22,40 +23,79 @@ def ltxv_te(*args, **kwargs): return comfy.text_encoders.genmo.mochi_te(*args, **kwargs) -class Gemma3_12BTokenizer(sd1_clip.SDTokenizer): - def __init__(self, embedding_directory=None, tokenizer_data={}): - tokenizer = tokenizer_data.get("spiece_model", None) - super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_left=True, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data) - +class Gemma3_Tokenizer(): def state_dict(self): return {"spiece_model": self.tokenizer.serialize_model()} + def tokenize_with_weights(self, text, return_word_ids=False, image=None, llama_template=None, skip_template=True, **kwargs): + self.llama_template = "system\nYou are a helpful assistant.\nuser\n{}\nmodel\n" + self.llama_template_images = "system\nYou are a helpful assistant.\nuser\n\n{}\n\nmodel\n" + + if image is None: + images = [] + else: + samples = image.movedim(-1, 1) + total = int(896 * 896) + + scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2])) + width = round(samples.shape[3] * scale_by) + height = round(samples.shape[2] * scale_by) + + s = comfy.utils.common_upscale(samples, width, height, "area", "disabled").movedim(1, -1) + images = [s[:, :, :, :3]] + + if text.startswith(''): + skip_template = True + + if skip_template: + llama_text = text + else: + if llama_template is None: + if len(images) > 0: + llama_text = self.llama_template_images.format(text) + else: + llama_text = self.llama_template.format(text) + else: + llama_text = llama_template.format(text) + + text_tokens = super().tokenize_with_weights(llama_text, return_word_ids) + + if len(images) > 0: + embed_count = 0 + for r in text_tokens: + for i, token in enumerate(r): + if token[0] == 262144 and embed_count < len(images): + r[i] = ({"type": "image", "data": images[embed_count]},) + token[1:] + embed_count += 1 + return text_tokens + +class Gemma3_12BTokenizer(Gemma3_Tokenizer, sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer = tokenizer_data.get("spiece_model", None) + special_tokens = {"": 262144, "": 106} + super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_left=True, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False, "special_tokens": special_tokens}, tokenizer_data=tokenizer_data) + + class LTXAVGemmaTokenizer(sd1_clip.SD1Tokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma3_12b", tokenizer=Gemma3_12BTokenizer) + class Gemma3_12BModel(sd1_clip.SDClipModel): def __init__(self, device="cpu", layer="all", layer_idx=None, dtype=None, attention_mask=True, model_options={}): llama_quantization_metadata = model_options.get("llama_quantization_metadata", None) if llama_quantization_metadata is not None: model_options = model_options.copy() model_options["quantization_metadata"] = llama_quantization_metadata - + self.dtypes = set() + self.dtypes.add(dtype) super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_12B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) - def tokenize_with_weights(self, text, return_word_ids=False, llama_template="{}", image_embeds=None, **kwargs): - text = llama_template.format(text) - text_tokens = super().tokenize_with_weights(text, return_word_ids) - embed_count = 0 - for k in text_tokens: - tt = text_tokens[k] - for r in tt: - for i in range(len(r)): - if r[i][0] == 262144: - if image_embeds is not None and embed_count < image_embeds.shape[0]: - r[i] = ({"type": "embedding", "data": image_embeds[embed_count], "original_type": "image"},) + r[i][1:] - embed_count += 1 - return text_tokens + def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed): + tokens_only = [[t[0] for t in b] for b in tokens] + embeds, _, _, embeds_info = self.process_tokens(tokens_only, self.execution_device) + comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5) + return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106]) # 106 is class LTXAVTEModel(torch.nn.Module): def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}): @@ -112,6 +152,9 @@ class LTXAVTEModel(torch.nn.Module): return out.to(out_device), pooled + def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed): + return self.gemma3_12b.generate(tokens["gemma3_12b"], do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed) + def load_sd(self, sd): if "model.layers.47.self_attn.q_norm.weight" in sd: return self.gemma3_12b.load_sd(sd) @@ -152,3 +195,14 @@ def ltxav_te(dtype_llama=None, llama_quantization_metadata=None): dtype = dtype_llama super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options) return LTXAVTEModel_ + +def gemma3_te(dtype_llama=None, llama_quantization_metadata=None): + class Gemma3_12BModel_(Gemma3_12BModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["llama_quantization_metadata"] = llama_quantization_metadata + if dtype_llama is not None: + dtype = dtype_llama + super().__init__(device=device, dtype=dtype, model_options=model_options) + return Gemma3_12BModel_ diff --git a/comfy/text_encoders/lumina2.py b/comfy/text_encoders/lumina2.py index b29a7cc87..1b731e094 100644 --- a/comfy/text_encoders/lumina2.py +++ b/comfy/text_encoders/lumina2.py @@ -1,23 +1,23 @@ from comfy import sd1_clip from .spiece_tokenizer import SPieceTokenizer import comfy.text_encoders.llama - +from comfy.text_encoders.lt import Gemma3_Tokenizer +import comfy.utils class Gemma2BTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer = tokenizer_data.get("spiece_model", None) - super().__init__(tokenizer, pad_with_end=False, embedding_size=2304, embedding_key='gemma2_2b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data) + special_tokens = {"": 107} + super().__init__(tokenizer, pad_with_end=False, embedding_size=2304, embedding_key='gemma2_2b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False, "special_tokens": special_tokens}, tokenizer_data=tokenizer_data) def state_dict(self): return {"spiece_model": self.tokenizer.serialize_model()} -class Gemma3_4BTokenizer(sd1_clip.SDTokenizer): +class Gemma3_4BTokenizer(Gemma3_Tokenizer, sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer = tokenizer_data.get("spiece_model", None) - super().__init__(tokenizer, pad_with_end=False, embedding_size=2560, embedding_key='gemma3_4b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, disable_weights=True, tokenizer_data=tokenizer_data) - - def state_dict(self): - return {"spiece_model": self.tokenizer.serialize_model()} + special_tokens = {"": 262144, "": 106} + super().__init__(tokenizer, pad_with_end=False, embedding_size=2560, embedding_key='gemma3_4b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False, "special_tokens": special_tokens}, disable_weights=True, tokenizer_data=tokenizer_data) class LuminaTokenizer(sd1_clip.SD1Tokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): @@ -31,6 +31,9 @@ class Gemma2_2BModel(sd1_clip.SDClipModel): def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}): super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma2_2B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + def generate(self, embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed): + return super().generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[107]) + class Gemma3_4BModel(sd1_clip.SDClipModel): def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}): llama_quantization_metadata = model_options.get("llama_quantization_metadata", None) @@ -40,6 +43,23 @@ class Gemma3_4BModel(sd1_clip.SDClipModel): super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + def generate(self, embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed): + return super().generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106]) + +class Gemma3_4B_Vision_Model(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}): + llama_quantization_metadata = model_options.get("llama_quantization_metadata", None) + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["quantization_metadata"] = llama_quantization_metadata + + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B_Vision, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + + def process_tokens(self, tokens, device): + embeds, _, _, embeds_info = super().process_tokens(tokens, device) + comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5) + return embeds + class LuminaModel(sd1_clip.SD1ClipModel): def __init__(self, device="cpu", dtype=None, model_options={}, name="gemma2_2b", clip_model=Gemma2_2BModel): super().__init__(device=device, dtype=dtype, name=name, clip_model=clip_model, model_options=model_options) @@ -50,6 +70,8 @@ def te(dtype_llama=None, llama_quantization_metadata=None, model_type="gemma2_2b model = Gemma2_2BModel elif model_type == "gemma3_4b": model = Gemma3_4BModel + elif model_type == "gemma3_4b_vision": + model = Gemma3_4B_Vision_Model class LuminaTEModel_(LuminaModel): def __init__(self, device="cpu", dtype=None, model_options={}): diff --git a/comfy/text_encoders/spiece_tokenizer.py b/comfy/text_encoders/spiece_tokenizer.py index caccb3ca2..099d8d2d9 100644 --- a/comfy/text_encoders/spiece_tokenizer.py +++ b/comfy/text_encoders/spiece_tokenizer.py @@ -6,9 +6,10 @@ class SPieceTokenizer: def from_pretrained(path, **kwargs): return SPieceTokenizer(path, **kwargs) - def __init__(self, tokenizer_path, add_bos=False, add_eos=True): + def __init__(self, tokenizer_path, add_bos=False, add_eos=True, special_tokens=None): self.add_bos = add_bos self.add_eos = add_eos + self.special_tokens = special_tokens import sentencepiece if torch.is_tensor(tokenizer_path): tokenizer_path = tokenizer_path.numpy().tobytes() @@ -27,8 +28,32 @@ class SPieceTokenizer: return out def __call__(self, string): + if self.special_tokens is not None: + import re + special_tokens_pattern = '|'.join(re.escape(token) for token in self.special_tokens.keys()) + if special_tokens_pattern and re.search(special_tokens_pattern, string): + parts = re.split(f'({special_tokens_pattern})', string) + result = [] + for part in parts: + if not part: + continue + if part in self.special_tokens: + result.append(self.special_tokens[part]) + else: + encoded = self.tokenizer.encode(part, add_bos=False, add_eos=False) + result.extend(encoded) + return {"input_ids": result} + out = self.tokenizer.encode(string) return {"input_ids": out} + def decode(self, token_ids, skip_special_tokens=False): + + if skip_special_tokens and self.special_tokens: + special_token_ids = set(self.special_tokens.values()) + token_ids = [tid for tid in token_ids if tid not in special_token_ids] + + return self.tokenizer.decode(token_ids) + def serialize_model(self): return torch.ByteTensor(list(self.tokenizer.serialized_model_proto())) diff --git a/comfy/utils.py b/comfy/utils.py index c1ce540b5..17443b4cc 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -1418,3 +1418,11 @@ def deepcopy_list_dict(obj, memo=None): memo[obj_id] = res return res + +def normalize_image_embeddings(embeds, embeds_info, scale_factor): + """Normalize image embeddings to match text embedding scale""" + for info in embeds_info: + if info.get("type") == "image": + start_idx = info["index"] + end_idx = start_idx + info["size"] + embeds[:, start_idx:end_idx, :] /= scale_factor diff --git a/comfy_extras/nodes_textgen.py b/comfy_extras/nodes_textgen.py new file mode 100644 index 000000000..dd4f6b0d3 --- /dev/null +++ b/comfy_extras/nodes_textgen.py @@ -0,0 +1,176 @@ +from comfy_api.latest import ComfyExtension, io +from typing_extensions import override + +class TextGenerate(io.ComfyNode): + @classmethod + def define_schema(cls): + # Define dynamic combo options for sampling mode + sampling_options = [ + io.DynamicCombo.Option( + key="on", + inputs=[ + io.Float.Input("temperature", default=0.7, min=0.01, max=2.0, step=0.000001), + io.Int.Input("top_k", default=64, min=0, max=1000), + io.Float.Input("top_p", default=0.95, min=0.0, max=1.0, step=0.01), + io.Float.Input("min_p", default=0.05, min=0.0, max=1.0, step=0.01), + io.Float.Input("repetition_penalty", default=1.05, min=0.0, max=5.0, step=0.01), + io.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff), + ] + ), + io.DynamicCombo.Option( + key="off", + inputs=[] + ), + ] + + return io.Schema( + node_id="TextGenerate", + category="textgen/", + search_aliases=["LLM", "gemma"], + inputs=[ + io.Clip.Input("clip"), + io.String.Input("prompt", multiline=True, dynamic_prompts=True, default=""), + io.Image.Input("image", optional=True), + io.Int.Input("max_length", default=256, min=1, max=2048), + io.DynamicCombo.Input("sampling_mode", options=sampling_options, display_name="Sampling Mode"), + ], + outputs=[ + io.String.Output(display_name="generated_text"), + ], + ) + + @classmethod + def execute(cls, clip, prompt, max_length, sampling_mode, image=None) -> io.NodeOutput: + + tokens = clip.tokenize(prompt, image=image, skip_template=False) + + # Get sampling parameters from dynamic combo + do_sample = sampling_mode.get("sampling_mode") == "on" + temperature = sampling_mode.get("temperature", 1.0) + top_k = sampling_mode.get("top_k", 50) + top_p = sampling_mode.get("top_p", 1.0) + min_p = sampling_mode.get("min_p", 0.0) + seed = sampling_mode.get("seed", None) + repetition_penalty = sampling_mode.get("repetition_penalty", 1.0) + + generated_ids = clip.generate( + tokens, + do_sample=do_sample, + max_length=max_length, + temperature=temperature, + top_k=top_k, + top_p=top_p, + min_p=min_p, + repetition_penalty=repetition_penalty, + seed=seed + ) + + generated_text = clip.decode(generated_ids, skip_special_tokens=True) + return io.NodeOutput(generated_text) + + +LTX2_T2V_SYSTEM_PROMPT = """You are a Creative Assistant. Given a user's raw input prompt describing a scene or concept, expand it into a detailed video generation prompt with specific visuals and integrated audio to guide a text-to-video model. +#### Guidelines +- Strictly follow all aspects of the user's raw input: include every element requested (style, visuals, motions, actions, camera movement, audio). + - If the input is vague, invent concrete details: lighting, textures, materials, scene settings, etc. + - For characters: describe gender, clothing, hair, expressions. DO NOT invent unrequested characters. +- Use active language: present-progressive verbs ("is walking," "speaking"). If no action specified, describe natural movements. +- Maintain chronological flow: use temporal connectors ("as," "then," "while"). +- Audio layer: Describe complete soundscape (background audio, ambient sounds, SFX, speech/music when requested). Integrate sounds chronologically alongside actions. Be specific (e.g., "soft footsteps on tile"), not vague (e.g., "ambient sound is present"). +- Speech (only when requested): + - For ANY speech-related input (talking, conversation, singing, etc.), ALWAYS include exact words in quotes with voice characteristics (e.g., "The man says in an excited voice: 'You won't believe what I just saw!'"). + - Specify language if not English and accent if relevant. +- Style: Include visual style at the beginning: "Style: