diff --git a/comfy/comfy_types/node_typing.py b/comfy/comfy_types/node_typing.py index 92b1acbd5..57126fa4a 100644 --- a/comfy/comfy_types/node_typing.py +++ b/comfy/comfy_types/node_typing.py @@ -176,8 +176,8 @@ class InputTypeOptions(TypedDict): """COMBO type only. Specifies the configuration for a multi-select widget. Available after ComfyUI frontend v1.13.4 https://github.com/Comfy-Org/ComfyUI_frontend/pull/2987""" - gradient_stops: NotRequired[list[list[float]]] - """Gradient color stops for gradientslider display mode. Each stop is [offset, r, g, b] (``FLOAT``).""" + gradient_stops: NotRequired[list[dict]] + """Gradient color stops for gradientslider display mode. Each stop is {"offset": float, "color": [r, g, b]}.""" class HiddenInputTypeDict(TypedDict): diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py index e20d498f8..e28d704b4 100644 --- a/comfy/ldm/flux/layers.py +++ b/comfy/ldm/flux/layers.py @@ -144,9 +144,9 @@ def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None): return tensor * m_mult else: for d in modulation_dims: - tensor[:, d[0]:d[1]] *= m_mult[:, d[2]] + tensor[:, d[0]:d[1]] *= m_mult[:, d[2]:d[2] + 1] if m_add is not None: - tensor[:, d[0]:d[1]] += m_add[:, d[2]] + tensor[:, d[0]:d[1]] += m_add[:, d[2]:d[2] + 1] return tensor diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index 00f12c031..8e7912e6d 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -44,6 +44,22 @@ class FluxParams: txt_norm: bool = False +def invert_slices(slices, length): + sorted_slices = sorted(slices) + result = [] + current = 0 + + for start, end in sorted_slices: + if current < start: + result.append((current, start)) + current = max(current, end) + + if current < length: + result.append((current, length)) + + return result + + class Flux(nn.Module): """ Transformer model for flow matching on sequences. @@ -138,6 +154,7 @@ class Flux(nn.Module): y: Tensor, guidance: Tensor = None, control = None, + timestep_zero_index=None, transformer_options={}, attn_mask: Tensor = None, ) -> Tensor: @@ -164,10 +181,6 @@ class Flux(nn.Module): txt = self.txt_norm(txt) txt = self.txt_in(txt) - vec_orig = vec - if self.params.global_modulation: - vec = (self.double_stream_modulation_img(vec_orig), self.double_stream_modulation_txt(vec_orig)) - if "post_input" in patches: for p in patches["post_input"]: out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids, "transformer_options": transformer_options}) @@ -182,6 +195,24 @@ class Flux(nn.Module): else: pe = None + vec_orig = vec + txt_vec = vec + extra_kwargs = {} + if timestep_zero_index is not None: + modulation_dims = [] + batch = vec.shape[0] // 2 + vec_orig = vec_orig.reshape(2, batch, vec.shape[1]).movedim(0, 1) + invert = invert_slices(timestep_zero_index, img.shape[1]) + for s in invert: + modulation_dims.append((s[0], s[1], 0)) + for s in timestep_zero_index: + modulation_dims.append((s[0], s[1], 1)) + extra_kwargs["modulation_dims_img"] = modulation_dims + txt_vec = vec[:batch] + + if self.params.global_modulation: + vec = (self.double_stream_modulation_img(vec_orig), self.double_stream_modulation_txt(txt_vec)) + blocks_replace = patches_replace.get("dit", {}) transformer_options["total_blocks"] = len(self.double_blocks) transformer_options["block_type"] = "double" @@ -195,7 +226,8 @@ class Flux(nn.Module): vec=args["vec"], pe=args["pe"], attn_mask=args.get("attn_mask"), - transformer_options=args.get("transformer_options")) + transformer_options=args.get("transformer_options"), + **extra_kwargs) return out out = blocks_replace[("double_block", i)]({"img": img, @@ -213,7 +245,8 @@ class Flux(nn.Module): vec=vec, pe=pe, attn_mask=attn_mask, - transformer_options=transformer_options) + transformer_options=transformer_options, + **extra_kwargs) if control is not None: # Controlnet control_i = control.get("input") @@ -230,6 +263,12 @@ class Flux(nn.Module): if self.params.global_modulation: vec, _ = self.single_stream_modulation(vec_orig) + extra_kwargs = {} + if timestep_zero_index is not None: + lambda a: 0 if a == 0 else a + txt.shape[1] + modulation_dims_combined = list(map(lambda x: (0 if x[0] == 0 else x[0] + txt.shape[1], x[1] + txt.shape[1], x[2]), modulation_dims)) + extra_kwargs["modulation_dims"] = modulation_dims_combined + transformer_options["total_blocks"] = len(self.single_blocks) transformer_options["block_type"] = "single" transformer_options["img_slice"] = [txt.shape[1], img.shape[1]] @@ -242,7 +281,8 @@ class Flux(nn.Module): vec=args["vec"], pe=args["pe"], attn_mask=args.get("attn_mask"), - transformer_options=args.get("transformer_options")) + transformer_options=args.get("transformer_options"), + **extra_kwargs) return out out = blocks_replace[("single_block", i)]({"img": img, @@ -253,7 +293,7 @@ class Flux(nn.Module): {"original_block": block_wrap}) img = out["img"] else: - img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options) + img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options, **extra_kwargs) if control is not None: # Controlnet control_o = control.get("output") @@ -264,7 +304,11 @@ class Flux(nn.Module): img = img[:, txt.shape[1] :, ...] - img = self.final_layer(img, vec_orig) # (N, T, patch_size ** 2 * out_channels) + extra_kwargs = {} + if timestep_zero_index is not None: + extra_kwargs["modulation_dims"] = modulation_dims + + img = self.final_layer(img, vec_orig, **extra_kwargs) # (N, T, patch_size ** 2 * out_channels) return img def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}): @@ -312,13 +356,16 @@ class Flux(nn.Module): w_len = ((w_orig + (patch_size // 2)) // patch_size) img, img_ids = self.process_img(x, transformer_options=transformer_options) img_tokens = img.shape[1] + timestep_zero_index = None if ref_latents is not None: + ref_num_tokens = [] h = 0 w = 0 index = 0 ref_latents_method = kwargs.get("ref_latents_method", self.params.default_ref_method) + timestep_zero = ref_latents_method == "index_timestep_zero" for ref in ref_latents: - if ref_latents_method == "index": + if ref_latents_method in ("index", "index_timestep_zero"): index += self.params.ref_index_scale h_offset = 0 w_offset = 0 @@ -342,6 +389,13 @@ class Flux(nn.Module): kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset) img = torch.cat([img, kontext], dim=1) img_ids = torch.cat([img_ids, kontext_ids], dim=1) + ref_num_tokens.append(kontext.shape[1]) + if timestep_zero: + if index > 0: + timestep = torch.cat([timestep, timestep * 0], dim=0) + timestep_zero_index = [[img_tokens, img_ids.shape[1]]] + transformer_options = transformer_options.copy() + transformer_options["reference_image_num_tokens"] = ref_num_tokens txt_ids = torch.zeros((bs, context.shape[1], len(self.params.axes_dim)), device=x.device, dtype=torch.float32) @@ -349,6 +403,6 @@ class Flux(nn.Module): for i in self.params.txt_ids_dims: txt_ids[:, :, i] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32) - out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None)) + out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, timestep_zero_index=timestep_zero_index, transformer_options=transformer_options, attn_mask=kwargs.get("attention_mask", None)) out = out[:, :img_tokens] return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=self.patch_size, pw=self.patch_size)[:,:,:h_orig,:w_orig] diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index 6eb744286..0862f72f7 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -149,6 +149,9 @@ class Attention(nn.Module): seq_img = hidden_states.shape[1] seq_txt = encoder_hidden_states.shape[1] + transformer_patches = transformer_options.get("patches", {}) + extra_options = transformer_options.copy() + # Project and reshape to BHND format (batch, heads, seq, dim) img_query = self.to_q(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous() img_key = self.to_k(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous() @@ -167,15 +170,22 @@ class Attention(nn.Module): joint_key = torch.cat([txt_key, img_key], dim=2) joint_value = torch.cat([txt_value, img_value], dim=2) - joint_query = apply_rope1(joint_query, image_rotary_emb) - joint_key = apply_rope1(joint_key, image_rotary_emb) - if encoder_hidden_states_mask is not None: attn_mask = torch.zeros((batch_size, 1, seq_txt + seq_img), dtype=hidden_states.dtype, device=hidden_states.device) attn_mask[:, 0, :seq_txt] = encoder_hidden_states_mask else: attn_mask = None + extra_options["img_slice"] = [txt_query.shape[2], joint_query.shape[2]] + if "attn1_patch" in transformer_patches: + patch = transformer_patches["attn1_patch"] + for p in patch: + out = p(joint_query, joint_key, joint_value, pe=image_rotary_emb, attn_mask=encoder_hidden_states_mask, extra_options=extra_options) + joint_query, joint_key, joint_value, image_rotary_emb, encoder_hidden_states_mask = out.get("q", joint_query), out.get("k", joint_key), out.get("v", joint_value), out.get("pe", image_rotary_emb), out.get("attn_mask", encoder_hidden_states_mask) + + joint_query = apply_rope1(joint_query, image_rotary_emb) + joint_key = apply_rope1(joint_key, image_rotary_emb) + joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attn_mask, transformer_options=transformer_options, skip_reshape=True) @@ -444,6 +454,7 @@ class QwenImageTransformer2DModel(nn.Module): timestep_zero_index = None if ref_latents is not None: + ref_num_tokens = [] h = 0 w = 0 index = 0 @@ -474,16 +485,16 @@ class QwenImageTransformer2DModel(nn.Module): kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset) hidden_states = torch.cat([hidden_states, kontext], dim=1) img_ids = torch.cat([img_ids, kontext_ids], dim=1) + ref_num_tokens.append(kontext.shape[1]) if timestep_zero: if index > 0: timestep = torch.cat([timestep, timestep * 0], dim=0) timestep_zero_index = num_embeds + transformer_options = transformer_options.copy() + transformer_options["reference_image_num_tokens"] = ref_num_tokens txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2)) txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3) - ids = torch.cat((txt_ids, img_ids), dim=1) - image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous() - del ids, txt_ids, img_ids hidden_states = self.img_in(hidden_states) encoder_hidden_states = self.txt_norm(encoder_hidden_states) @@ -495,6 +506,18 @@ class QwenImageTransformer2DModel(nn.Module): patches = transformer_options.get("patches", {}) blocks_replace = patches_replace.get("dit", {}) + if "post_input" in patches: + for p in patches["post_input"]: + out = p({"img": hidden_states, "txt": encoder_hidden_states, "img_ids": img_ids, "txt_ids": txt_ids, "transformer_options": transformer_options}) + hidden_states = out["img"] + encoder_hidden_states = out["txt"] + img_ids = out["img_ids"] + txt_ids = out["txt_ids"] + + ids = torch.cat((txt_ids, img_ids), dim=1) + image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous() + del ids, txt_ids, img_ids + transformer_options["total_blocks"] = len(self.transformer_blocks) transformer_options["block_type"] = "double" for i, block in enumerate(self.transformer_blocks): diff --git a/comfy/model_management.py b/comfy/model_management.py index 81550c790..81c89b180 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -270,10 +270,15 @@ try: except: OOM_EXCEPTION = Exception +try: + ACCELERATOR_ERROR = torch.AcceleratorError +except AttributeError: + ACCELERATOR_ERROR = RuntimeError + def is_oom(e): if isinstance(e, OOM_EXCEPTION): return True - if isinstance(e, torch.AcceleratorError) and getattr(e, 'error_code', None) == 2: + if isinstance(e, ACCELERATOR_ERROR) and (getattr(e, 'error_code', None) == 2 or "out of memory" in str(e).lower()): discard_cuda_async_error() return True return False @@ -1275,7 +1280,7 @@ def discard_cuda_async_error(): b = torch.tensor([1], dtype=torch.uint8, device=get_torch_device()) _ = a + b synchronize() - except torch.AcceleratorError: + except RuntimeError: #Dump it! We already know about it from the synchronous return pass diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 745384271..bc3a8f446 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -599,6 +599,27 @@ class ModelPatcher: return models + def model_patches_call_function(self, function_name="cleanup", arguments={}): + to = self.model_options["transformer_options"] + if "patches" in to: + patches = to["patches"] + for name in patches: + patch_list = patches[name] + for i in range(len(patch_list)): + if hasattr(patch_list[i], function_name): + getattr(patch_list[i], function_name)(**arguments) + if "patches_replace" in to: + patches = to["patches_replace"] + for name in patches: + patch_list = patches[name] + for k in patch_list: + if hasattr(patch_list[k], function_name): + getattr(patch_list[k], function_name)(**arguments) + if "model_function_wrapper" in self.model_options: + wrap_func = self.model_options["model_function_wrapper"] + if hasattr(wrap_func, function_name): + getattr(wrap_func, function_name)(**arguments) + def model_dtype(self): if hasattr(self.model, "get_dtype"): return self.model.get_dtype() @@ -1062,6 +1083,7 @@ class ModelPatcher: return comfy.lora.calculate_weight(patches, weight, key, intermediate_dtype=intermediate_dtype) def cleanup(self): + self.model_patches_call_function(function_name="cleanup") self.clean_hooks() if hasattr(self.model, "current_patcher"): self.model.current_patcher = None diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 050031dc0..7ca8f4e0c 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -297,7 +297,7 @@ class Float(ComfyTypeIO): '''Float input.''' def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, default: float=None, min: float=None, max: float=None, step: float=None, round: float=None, - display_mode: NumberDisplay=None, gradient_stops: list[list[float]]=None, + display_mode: NumberDisplay=None, gradient_stops: list[dict]=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None): super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced) self.min = min diff --git a/comfy_api_nodes/apis/reve.py b/comfy_api_nodes/apis/reve.py new file mode 100644 index 000000000..c6b5a69d8 --- /dev/null +++ b/comfy_api_nodes/apis/reve.py @@ -0,0 +1,68 @@ +from pydantic import BaseModel, Field + + +class RevePostprocessingOperation(BaseModel): + process: str = Field(..., description="The postprocessing operation: upscale or remove_background.") + upscale_factor: int | None = Field( + None, + description="Upscale factor (2, 3, or 4). Only used when process is upscale.", + ge=2, + le=4, + ) + + +class ReveImageCreateRequest(BaseModel): + prompt: str = Field(...) + aspect_ratio: str | None = Field(...) + version: str = Field(...) + test_time_scaling: int = Field( + ..., + description="If included, the model will spend more effort making better images. Values between 1 and 15.", + ge=1, + le=15, + ) + postprocessing: list[RevePostprocessingOperation] | None = Field( + None, description="Optional postprocessing operations to apply after generation." + ) + + +class ReveImageEditRequest(BaseModel): + edit_instruction: str = Field(...) + reference_image: str = Field(..., description="A base64 encoded image to use as reference for the edit.") + aspect_ratio: str | None = Field(...) + version: str = Field(...) + test_time_scaling: int | None = Field( + ..., + description="If included, the model will spend more effort making better images. Values between 1 and 15.", + ge=1, + le=15, + ) + postprocessing: list[RevePostprocessingOperation] | None = Field( + None, description="Optional postprocessing operations to apply after generation." + ) + + +class ReveImageRemixRequest(BaseModel): + prompt: str = Field(...) + reference_images: list[str] = Field(..., description="A list of 1-6 base64 encoded reference images.") + aspect_ratio: str | None = Field(...) + version: str = Field(...) + test_time_scaling: int | None = Field( + ..., + description="If included, the model will spend more effort making better images. Values between 1 and 15.", + ge=1, + le=15, + ) + postprocessing: list[RevePostprocessingOperation] | None = Field( + None, description="Optional postprocessing operations to apply after generation." + ) + + +class ReveImageResponse(BaseModel): + image: str | None = Field(None, description="The base64 encoded image data.") + request_id: str | None = Field(None, description="A unique id for the request.") + credits_used: float | None = Field(None, description="The number of credits used for this request.") + version: str | None = Field(None, description="The specific model version used.") + content_violation: bool | None = Field( + None, description="Indicates whether the generated image violates the content policy." + ) diff --git a/comfy_api_nodes/nodes_reve.py b/comfy_api_nodes/nodes_reve.py new file mode 100644 index 000000000..608d9f058 --- /dev/null +++ b/comfy_api_nodes/nodes_reve.py @@ -0,0 +1,395 @@ +from io import BytesIO + +from typing_extensions import override + +from comfy_api.latest import IO, ComfyExtension, Input +from comfy_api_nodes.apis.reve import ( + ReveImageCreateRequest, + ReveImageEditRequest, + ReveImageRemixRequest, + RevePostprocessingOperation, +) +from comfy_api_nodes.util import ( + ApiEndpoint, + bytesio_to_image_tensor, + sync_op_raw, + tensor_to_base64_string, + validate_string, +) + + +def _build_postprocessing(upscale: dict, remove_background: bool) -> list[RevePostprocessingOperation] | None: + ops = [] + if upscale["upscale"] == "enabled": + ops.append( + RevePostprocessingOperation( + process="upscale", + upscale_factor=upscale["upscale_factor"], + ) + ) + if remove_background: + ops.append(RevePostprocessingOperation(process="remove_background")) + return ops or None + + +def _postprocessing_inputs(): + return [ + IO.DynamicCombo.Input( + "upscale", + options=[ + IO.DynamicCombo.Option("disabled", []), + IO.DynamicCombo.Option( + "enabled", + [ + IO.Int.Input( + "upscale_factor", + default=2, + min=2, + max=4, + step=1, + tooltip="Upscale factor (2x, 3x, or 4x).", + ), + ], + ), + ], + tooltip="Upscale the generated image. May add additional cost.", + ), + IO.Boolean.Input( + "remove_background", + default=False, + tooltip="Remove the background from the generated image. May add additional cost.", + ), + ] + + +def _reve_price_extractor(headers: dict) -> float | None: + credits_used = headers.get("x-reve-credits-used") + if credits_used is not None: + return float(credits_used) / 524.48 + return None + + +def _reve_response_header_validator(headers: dict) -> None: + error_code = headers.get("x-reve-error-code") + if error_code: + raise ValueError(f"Reve API error: {error_code}") + if headers.get("x-reve-content-violation", "").lower() == "true": + raise ValueError("The generated image was flagged for content policy violation.") + + +def _model_inputs(versions: list[str], aspect_ratios: list[str]): + return [ + IO.DynamicCombo.Option( + version, + [ + IO.Combo.Input( + "aspect_ratio", + options=aspect_ratios, + tooltip="Aspect ratio of the output image.", + ), + IO.Int.Input( + "test_time_scaling", + default=1, + min=1, + max=5, + step=1, + tooltip="Higher values produce better images but cost more credits.", + advanced=True, + ), + ], + ) + for version in versions + ] + + +class ReveImageCreateNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="ReveImageCreateNode", + display_name="Reve Image Create", + category="api node/image/Reve", + description="Generate images from text descriptions using Reve.", + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Text description of the desired image. Maximum 2560 characters.", + ), + IO.DynamicCombo.Input( + "model", + options=_model_inputs( + ["reve-create@20250915"], + aspect_ratios=["3:2", "16:9", "9:16", "2:3", "4:3", "3:4", "1:1"], + ), + tooltip="Model version to use for generation.", + ), + *_postprocessing_inputs(), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + ], + outputs=[IO.Image.Output()], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.03432,"format":{"approximate":true,"note":"(base)"}}""", + ), + ) + + @classmethod + async def execute( + cls, + prompt: str, + model: dict, + upscale: dict, + remove_background: bool, + seed: int, + ) -> IO.NodeOutput: + validate_string(prompt, min_length=1, max_length=2560) + response = await sync_op_raw( + cls, + ApiEndpoint( + path="/proxy/reve/v1/image/create", + method="POST", + headers={"Accept": "image/webp"}, + ), + as_binary=True, + price_extractor=_reve_price_extractor, + response_header_validator=_reve_response_header_validator, + data=ReveImageCreateRequest( + prompt=prompt, + aspect_ratio=model["aspect_ratio"], + version=model["model"], + test_time_scaling=model["test_time_scaling"], + postprocessing=_build_postprocessing(upscale, remove_background), + ), + ) + return IO.NodeOutput(bytesio_to_image_tensor(BytesIO(response))) + + +class ReveImageEditNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="ReveImageEditNode", + display_name="Reve Image Edit", + category="api node/image/Reve", + description="Edit images using natural language instructions with Reve.", + inputs=[ + IO.Image.Input("image", tooltip="The image to edit."), + IO.String.Input( + "edit_instruction", + multiline=True, + default="", + tooltip="Text description of how to edit the image. Maximum 2560 characters.", + ), + IO.DynamicCombo.Input( + "model", + options=_model_inputs( + ["reve-edit@20250915", "reve-edit-fast@20251030"], + aspect_ratios=["auto", "16:9", "9:16", "3:2", "2:3", "4:3", "3:4", "1:1"], + ), + tooltip="Model version to use for editing.", + ), + *_postprocessing_inputs(), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + ], + outputs=[IO.Image.Output()], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends( + widgets=["model"], + ), + expr=""" + ( + $isFast := $contains(widgets.model, "fast"); + $base := $isFast ? 0.01001 : 0.0572; + {"type": "usd", "usd": $base, "format": {"approximate": true, "note": "(base)"}} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + image: Input.Image, + edit_instruction: str, + model: dict, + upscale: dict, + remove_background: bool, + seed: int, + ) -> IO.NodeOutput: + validate_string(edit_instruction, min_length=1, max_length=2560) + tts = model["test_time_scaling"] + ar = model["aspect_ratio"] + response = await sync_op_raw( + cls, + ApiEndpoint( + path="/proxy/reve/v1/image/edit", + method="POST", + headers={"Accept": "image/webp"}, + ), + as_binary=True, + price_extractor=_reve_price_extractor, + response_header_validator=_reve_response_header_validator, + data=ReveImageEditRequest( + edit_instruction=edit_instruction, + reference_image=tensor_to_base64_string(image), + aspect_ratio=ar if ar != "auto" else None, + version=model["model"], + test_time_scaling=tts if tts and tts > 1 else None, + postprocessing=_build_postprocessing(upscale, remove_background), + ), + ) + return IO.NodeOutput(bytesio_to_image_tensor(BytesIO(response))) + + +class ReveImageRemixNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="ReveImageRemixNode", + display_name="Reve Image Remix", + category="api node/image/Reve", + description="Combine reference images with text prompts to create new images using Reve.", + inputs=[ + IO.Autogrow.Input( + "reference_images", + template=IO.Autogrow.TemplatePrefix( + IO.Image.Input("image"), + prefix="image_", + min=1, + max=6, + ), + ), + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Text description of the desired image. " + "May include XML img tags to reference specific images by index, " + "e.g. 0, 1, etc.", + ), + IO.DynamicCombo.Input( + "model", + options=_model_inputs( + ["reve-remix@20250915", "reve-remix-fast@20251030"], + aspect_ratios=["auto", "16:9", "9:16", "3:2", "2:3", "4:3", "3:4", "1:1"], + ), + tooltip="Model version to use for remixing.", + ), + *_postprocessing_inputs(), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + ], + outputs=[IO.Image.Output()], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends( + widgets=["model"], + ), + expr=""" + ( + $isFast := $contains(widgets.model, "fast"); + $base := $isFast ? 0.01001 : 0.0572; + {"type": "usd", "usd": $base, "format": {"approximate": true, "note": "(base)"}} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + reference_images: IO.Autogrow.Type, + prompt: str, + model: dict, + upscale: dict, + remove_background: bool, + seed: int, + ) -> IO.NodeOutput: + validate_string(prompt, min_length=1, max_length=2560) + if not reference_images: + raise ValueError("At least one reference image is required.") + ref_base64_list = [] + for key in reference_images: + ref_base64_list.append(tensor_to_base64_string(reference_images[key])) + if len(ref_base64_list) > 6: + raise ValueError("Maximum 6 reference images are allowed.") + tts = model["test_time_scaling"] + ar = model["aspect_ratio"] + response = await sync_op_raw( + cls, + ApiEndpoint( + path="/proxy/reve/v1/image/remix", + method="POST", + headers={"Accept": "image/webp"}, + ), + as_binary=True, + price_extractor=_reve_price_extractor, + response_header_validator=_reve_response_header_validator, + data=ReveImageRemixRequest( + prompt=prompt, + reference_images=ref_base64_list, + aspect_ratio=ar if ar != "auto" else None, + version=model["model"], + test_time_scaling=tts if tts and tts > 1 else None, + postprocessing=_build_postprocessing(upscale, remove_background), + ), + ) + return IO.NodeOutput(bytesio_to_image_tensor(BytesIO(response))) + + +class ReveExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + ReveImageCreateNode, + ReveImageEditNode, + ReveImageRemixNode, + ] + + +async def comfy_entrypoint() -> ReveExtension: + return ReveExtension() diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py index 79ffb77c1..9d730b81a 100644 --- a/comfy_api_nodes/util/client.py +++ b/comfy_api_nodes/util/client.py @@ -67,6 +67,7 @@ class _RequestConfig: progress_origin_ts: float | None = None price_extractor: Callable[[dict[str, Any]], float | None] | None = None is_rate_limited: Callable[[int, Any], bool] | None = None + response_header_validator: Callable[[dict[str, str]], None] | None = None @dataclass @@ -202,11 +203,13 @@ async def sync_op_raw( monitor_progress: bool = True, max_retries_on_rate_limit: int = 16, is_rate_limited: Callable[[int, Any], bool] | None = None, + response_header_validator: Callable[[dict[str, str]], None] | None = None, ) -> dict[str, Any] | bytes: """ Make a single network request. - If as_binary=False (default): returns JSON dict (or {'_raw': ''} if non-JSON). - If as_binary=True: returns bytes. + - response_header_validator: optional callback receiving response headers dict """ if isinstance(data, BaseModel): data = data.model_dump(exclude_none=True) @@ -232,6 +235,7 @@ async def sync_op_raw( price_extractor=price_extractor, max_retries_on_rate_limit=max_retries_on_rate_limit, is_rate_limited=is_rate_limited, + response_header_validator=response_header_validator, ) return await _request_base(cfg, expect_binary=as_binary) @@ -769,6 +773,12 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): cfg.node_cls, cfg.wait_label, int(now - start_time), cfg.estimated_total ) bytes_payload = bytes(buff) + resp_headers = {k.lower(): v for k, v in resp.headers.items()} + if cfg.price_extractor: + with contextlib.suppress(Exception): + extracted_price = cfg.price_extractor(resp_headers) + if cfg.response_header_validator: + cfg.response_header_validator(resp_headers) operation_succeeded = True final_elapsed_seconds = int(time.monotonic() - start_time) request_logger.log_request_response( @@ -776,7 +786,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): request_method=method, request_url=url, response_status_code=resp.status, - response_headers=dict(resp.headers), + response_headers=resp_headers, response_content=bytes_payload, ) return bytes_payload diff --git a/comfy_extras/nodes_flux.py b/comfy_extras/nodes_flux.py index fe9552022..3a23c7d04 100644 --- a/comfy_extras/nodes_flux.py +++ b/comfy_extras/nodes_flux.py @@ -6,6 +6,7 @@ import comfy.model_management import torch import math import nodes +import comfy.ldm.flux.math class CLIPTextEncodeFlux(io.ComfyNode): @classmethod @@ -231,6 +232,68 @@ class Flux2Scheduler(io.ComfyNode): sigmas = get_schedule(steps, round(seq_len)) return io.NodeOutput(sigmas) +class KV_Attn_Input: + def __init__(self): + self.cache = {} + + def __call__(self, q, k, v, extra_options, **kwargs): + reference_image_num_tokens = extra_options.get("reference_image_num_tokens", []) + if len(reference_image_num_tokens) == 0: + return {} + + ref_toks = sum(reference_image_num_tokens) + cache_key = "{}_{}".format(extra_options["block_type"], extra_options["block_index"]) + if cache_key in self.cache: + kk, vv = self.cache[cache_key] + self.set_cache = False + return {"q": q, "k": torch.cat((k, kk), dim=2), "v": torch.cat((v, vv), dim=2)} + + self.cache[cache_key] = (k[:, :, -ref_toks:].clone(), v[:, :, -ref_toks:].clone()) + self.set_cache = True + return {"q": q, "k": k, "v": v} + + def cleanup(self): + self.cache = {} + + +class FluxKVCache(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="FluxKVCache", + display_name="Flux KV Cache", + description="Enables KV Cache optimization for reference images on Flux family models.", + category="", + is_experimental=True, + inputs=[ + io.Model.Input("model", tooltip="The model to use KV Cache on."), + ], + outputs=[ + io.Model.Output(tooltip="The patched model with KV Cache enabled."), + ], + ) + + @classmethod + def execute(cls, model: io.Model.Type) -> io.NodeOutput: + m = model.clone() + input_patch_obj = KV_Attn_Input() + + def model_input_patch(inputs): + if len(input_patch_obj.cache) > 0: + ref_image_tokens = sum(inputs["transformer_options"].get("reference_image_num_tokens", [])) + if ref_image_tokens > 0: + img = inputs["img"] + inputs["img"] = img[:, :-ref_image_tokens] + return inputs + + m.set_model_attn1_patch(input_patch_obj) + m.set_model_post_input_patch(model_input_patch) + if hasattr(model.model.diffusion_model, "params"): + m.add_object_patch("diffusion_model.params.default_ref_method", "index_timestep_zero") + else: + m.add_object_patch("diffusion_model.default_ref_method", "index_timestep_zero") + + return io.NodeOutput(m) class FluxExtension(ComfyExtension): @override @@ -243,6 +306,7 @@ class FluxExtension(ComfyExtension): FluxKontextMultiReferenceLatentMethod, EmptyFlux2LatentImage, Flux2Scheduler, + FluxKVCache, ] diff --git a/comfy_extras/nodes_painter.py b/comfy_extras/nodes_painter.py new file mode 100644 index 000000000..b9ecdf5ea --- /dev/null +++ b/comfy_extras/nodes_painter.py @@ -0,0 +1,127 @@ +from __future__ import annotations + +import hashlib +import os + +import numpy as np +import torch +from PIL import Image + +import folder_paths +import node_helpers +from comfy_api.latest import ComfyExtension, io, UI +from typing_extensions import override + + +def hex_to_rgb(hex_color: str) -> tuple[float, float, float]: + hex_color = hex_color.lstrip("#") + if len(hex_color) != 6: + return (0.0, 0.0, 0.0) + r = int(hex_color[0:2], 16) / 255.0 + g = int(hex_color[2:4], 16) / 255.0 + b = int(hex_color[4:6], 16) / 255.0 + return (r, g, b) + + +class PainterNode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="Painter", + display_name="Painter", + category="image", + inputs=[ + io.Image.Input( + "image", + optional=True, + tooltip="Optional base image to paint over", + ), + io.String.Input( + "mask", + default="", + socketless=True, + extra_dict={"widgetType": "PAINTER", "image_upload": True}, + ), + io.Int.Input( + "width", + default=512, + min=64, + max=4096, + step=64, + socketless=True, + extra_dict={"hidden": True}, + ), + io.Int.Input( + "height", + default=512, + min=64, + max=4096, + step=64, + socketless=True, + extra_dict={"hidden": True}, + ), + io.Color.Input("bg_color", default="#000000"), + ], + outputs=[ + io.Image.Output("IMAGE"), + io.Mask.Output("MASK"), + ], + ) + + @classmethod + def execute(cls, mask, width, height, bg_color="#000000", image=None) -> io.NodeOutput: + if image is not None: + base_image = image[:1] + h, w = base_image.shape[1], base_image.shape[2] + else: + h, w = height, width + r, g, b = hex_to_rgb(bg_color) + base_image = torch.zeros((1, h, w, 3), dtype=torch.float32) + base_image[0, :, :, 0] = r + base_image[0, :, :, 1] = g + base_image[0, :, :, 2] = b + + if mask and mask.strip(): + mask_path = folder_paths.get_annotated_filepath(mask) + painter_img = node_helpers.pillow(Image.open, mask_path) + painter_img = painter_img.convert("RGBA") + + if painter_img.size != (w, h): + painter_img = painter_img.resize((w, h), Image.LANCZOS) + + painter_np = np.array(painter_img).astype(np.float32) / 255.0 + painter_rgb = painter_np[:, :, :3] + painter_alpha = painter_np[:, :, 3:4] + + mask_tensor = torch.from_numpy(painter_np[:, :, 3]).unsqueeze(0) + + base_np = base_image[0].cpu().numpy() + composited = painter_rgb * painter_alpha + base_np * (1.0 - painter_alpha) + out_image = torch.from_numpy(composited).unsqueeze(0) + else: + mask_tensor = torch.zeros((1, h, w), dtype=torch.float32) + out_image = base_image + + return io.NodeOutput(out_image, mask_tensor, ui=UI.PreviewImage(out_image)) + + @classmethod + def fingerprint_inputs(cls, mask, width, height, bg_color="#000000", image=None): + if mask and mask.strip(): + mask_path = folder_paths.get_annotated_filepath(mask) + if os.path.exists(mask_path): + m = hashlib.sha256() + with open(mask_path, "rb") as f: + m.update(f.read()) + return m.digest().hex() + return "" + + + +class PainterExtension(ComfyExtension): + @override + async def get_node_list(self): + return [PainterNode] + + +async def comfy_entrypoint(): + return PainterExtension() diff --git a/nodes.py b/nodes.py index 0ef23b640..eb63f9d44 100644 --- a/nodes.py +++ b/nodes.py @@ -2450,6 +2450,7 @@ async def init_builtin_extra_nodes(): "nodes_nag.py", "nodes_sdpose.py", "nodes_math.py", + "nodes_painter.py", ] import_failed = [] diff --git a/requirements.txt b/requirements.txt index bb58f8d01..511c62fee 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ -comfyui-frontend-package==1.39.19 -comfyui-workflow-templates==0.9.18 +comfyui-frontend-package==1.41.18 +comfyui-workflow-templates==0.9.21 comfyui-embedded-docs==0.4.3 torch torchsde @@ -22,8 +22,8 @@ alembic SQLAlchemy filelock av>=14.2.0 -comfy-kitchen>=0.2.7 -comfy-aimdo>=0.2.9 +comfy-kitchen>=0.2.8 +comfy-aimdo>=0.2.10 requests simpleeval>=1.0.0 blake3