mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-14 17:47:30 +00:00
Compare commits
1 Commits
fix/gradie
...
v3/model_m
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f1c07a72c4 |
@@ -176,8 +176,8 @@ class InputTypeOptions(TypedDict):
|
|||||||
"""COMBO type only. Specifies the configuration for a multi-select widget.
|
"""COMBO type only. Specifies the configuration for a multi-select widget.
|
||||||
Available after ComfyUI frontend v1.13.4
|
Available after ComfyUI frontend v1.13.4
|
||||||
https://github.com/Comfy-Org/ComfyUI_frontend/pull/2987"""
|
https://github.com/Comfy-Org/ComfyUI_frontend/pull/2987"""
|
||||||
gradient_stops: NotRequired[list[dict]]
|
gradient_stops: NotRequired[list[list[float]]]
|
||||||
"""Gradient color stops for gradientslider display mode. Each stop is {"offset": float, "color": [r, g, b]}."""
|
"""Gradient color stops for gradientslider display mode. Each stop is [offset, r, g, b] (``FLOAT``)."""
|
||||||
|
|
||||||
|
|
||||||
class HiddenInputTypeDict(TypedDict):
|
class HiddenInputTypeDict(TypedDict):
|
||||||
|
|||||||
@@ -144,9 +144,9 @@ def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None):
|
|||||||
return tensor * m_mult
|
return tensor * m_mult
|
||||||
else:
|
else:
|
||||||
for d in modulation_dims:
|
for d in modulation_dims:
|
||||||
tensor[:, d[0]:d[1]] *= m_mult[:, d[2]:d[2] + 1]
|
tensor[:, d[0]:d[1]] *= m_mult[:, d[2]]
|
||||||
if m_add is not None:
|
if m_add is not None:
|
||||||
tensor[:, d[0]:d[1]] += m_add[:, d[2]:d[2] + 1]
|
tensor[:, d[0]:d[1]] += m_add[:, d[2]]
|
||||||
return tensor
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -44,22 +44,6 @@ class FluxParams:
|
|||||||
txt_norm: bool = False
|
txt_norm: bool = False
|
||||||
|
|
||||||
|
|
||||||
def invert_slices(slices, length):
|
|
||||||
sorted_slices = sorted(slices)
|
|
||||||
result = []
|
|
||||||
current = 0
|
|
||||||
|
|
||||||
for start, end in sorted_slices:
|
|
||||||
if current < start:
|
|
||||||
result.append((current, start))
|
|
||||||
current = max(current, end)
|
|
||||||
|
|
||||||
if current < length:
|
|
||||||
result.append((current, length))
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
class Flux(nn.Module):
|
class Flux(nn.Module):
|
||||||
"""
|
"""
|
||||||
Transformer model for flow matching on sequences.
|
Transformer model for flow matching on sequences.
|
||||||
@@ -154,7 +138,6 @@ class Flux(nn.Module):
|
|||||||
y: Tensor,
|
y: Tensor,
|
||||||
guidance: Tensor = None,
|
guidance: Tensor = None,
|
||||||
control = None,
|
control = None,
|
||||||
timestep_zero_index=None,
|
|
||||||
transformer_options={},
|
transformer_options={},
|
||||||
attn_mask: Tensor = None,
|
attn_mask: Tensor = None,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
@@ -181,6 +164,10 @@ class Flux(nn.Module):
|
|||||||
txt = self.txt_norm(txt)
|
txt = self.txt_norm(txt)
|
||||||
txt = self.txt_in(txt)
|
txt = self.txt_in(txt)
|
||||||
|
|
||||||
|
vec_orig = vec
|
||||||
|
if self.params.global_modulation:
|
||||||
|
vec = (self.double_stream_modulation_img(vec_orig), self.double_stream_modulation_txt(vec_orig))
|
||||||
|
|
||||||
if "post_input" in patches:
|
if "post_input" in patches:
|
||||||
for p in patches["post_input"]:
|
for p in patches["post_input"]:
|
||||||
out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids, "transformer_options": transformer_options})
|
out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids, "transformer_options": transformer_options})
|
||||||
@@ -195,24 +182,6 @@ class Flux(nn.Module):
|
|||||||
else:
|
else:
|
||||||
pe = None
|
pe = None
|
||||||
|
|
||||||
vec_orig = vec
|
|
||||||
txt_vec = vec
|
|
||||||
extra_kwargs = {}
|
|
||||||
if timestep_zero_index is not None:
|
|
||||||
modulation_dims = []
|
|
||||||
batch = vec.shape[0] // 2
|
|
||||||
vec_orig = vec_orig.reshape(2, batch, vec.shape[1]).movedim(0, 1)
|
|
||||||
invert = invert_slices(timestep_zero_index, img.shape[1])
|
|
||||||
for s in invert:
|
|
||||||
modulation_dims.append((s[0], s[1], 0))
|
|
||||||
for s in timestep_zero_index:
|
|
||||||
modulation_dims.append((s[0], s[1], 1))
|
|
||||||
extra_kwargs["modulation_dims_img"] = modulation_dims
|
|
||||||
txt_vec = vec[:batch]
|
|
||||||
|
|
||||||
if self.params.global_modulation:
|
|
||||||
vec = (self.double_stream_modulation_img(vec_orig), self.double_stream_modulation_txt(txt_vec))
|
|
||||||
|
|
||||||
blocks_replace = patches_replace.get("dit", {})
|
blocks_replace = patches_replace.get("dit", {})
|
||||||
transformer_options["total_blocks"] = len(self.double_blocks)
|
transformer_options["total_blocks"] = len(self.double_blocks)
|
||||||
transformer_options["block_type"] = "double"
|
transformer_options["block_type"] = "double"
|
||||||
@@ -226,8 +195,7 @@ class Flux(nn.Module):
|
|||||||
vec=args["vec"],
|
vec=args["vec"],
|
||||||
pe=args["pe"],
|
pe=args["pe"],
|
||||||
attn_mask=args.get("attn_mask"),
|
attn_mask=args.get("attn_mask"),
|
||||||
transformer_options=args.get("transformer_options"),
|
transformer_options=args.get("transformer_options"))
|
||||||
**extra_kwargs)
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
out = blocks_replace[("double_block", i)]({"img": img,
|
out = blocks_replace[("double_block", i)]({"img": img,
|
||||||
@@ -245,8 +213,7 @@ class Flux(nn.Module):
|
|||||||
vec=vec,
|
vec=vec,
|
||||||
pe=pe,
|
pe=pe,
|
||||||
attn_mask=attn_mask,
|
attn_mask=attn_mask,
|
||||||
transformer_options=transformer_options,
|
transformer_options=transformer_options)
|
||||||
**extra_kwargs)
|
|
||||||
|
|
||||||
if control is not None: # Controlnet
|
if control is not None: # Controlnet
|
||||||
control_i = control.get("input")
|
control_i = control.get("input")
|
||||||
@@ -263,12 +230,6 @@ class Flux(nn.Module):
|
|||||||
if self.params.global_modulation:
|
if self.params.global_modulation:
|
||||||
vec, _ = self.single_stream_modulation(vec_orig)
|
vec, _ = self.single_stream_modulation(vec_orig)
|
||||||
|
|
||||||
extra_kwargs = {}
|
|
||||||
if timestep_zero_index is not None:
|
|
||||||
lambda a: 0 if a == 0 else a + txt.shape[1]
|
|
||||||
modulation_dims_combined = list(map(lambda x: (0 if x[0] == 0 else x[0] + txt.shape[1], x[1] + txt.shape[1], x[2]), modulation_dims))
|
|
||||||
extra_kwargs["modulation_dims"] = modulation_dims_combined
|
|
||||||
|
|
||||||
transformer_options["total_blocks"] = len(self.single_blocks)
|
transformer_options["total_blocks"] = len(self.single_blocks)
|
||||||
transformer_options["block_type"] = "single"
|
transformer_options["block_type"] = "single"
|
||||||
transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
|
transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
|
||||||
@@ -281,8 +242,7 @@ class Flux(nn.Module):
|
|||||||
vec=args["vec"],
|
vec=args["vec"],
|
||||||
pe=args["pe"],
|
pe=args["pe"],
|
||||||
attn_mask=args.get("attn_mask"),
|
attn_mask=args.get("attn_mask"),
|
||||||
transformer_options=args.get("transformer_options"),
|
transformer_options=args.get("transformer_options"))
|
||||||
**extra_kwargs)
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
out = blocks_replace[("single_block", i)]({"img": img,
|
out = blocks_replace[("single_block", i)]({"img": img,
|
||||||
@@ -293,7 +253,7 @@ class Flux(nn.Module):
|
|||||||
{"original_block": block_wrap})
|
{"original_block": block_wrap})
|
||||||
img = out["img"]
|
img = out["img"]
|
||||||
else:
|
else:
|
||||||
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options, **extra_kwargs)
|
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
|
||||||
|
|
||||||
if control is not None: # Controlnet
|
if control is not None: # Controlnet
|
||||||
control_o = control.get("output")
|
control_o = control.get("output")
|
||||||
@@ -304,11 +264,7 @@ class Flux(nn.Module):
|
|||||||
|
|
||||||
img = img[:, txt.shape[1] :, ...]
|
img = img[:, txt.shape[1] :, ...]
|
||||||
|
|
||||||
extra_kwargs = {}
|
img = self.final_layer(img, vec_orig) # (N, T, patch_size ** 2 * out_channels)
|
||||||
if timestep_zero_index is not None:
|
|
||||||
extra_kwargs["modulation_dims"] = modulation_dims
|
|
||||||
|
|
||||||
img = self.final_layer(img, vec_orig, **extra_kwargs) # (N, T, patch_size ** 2 * out_channels)
|
|
||||||
return img
|
return img
|
||||||
|
|
||||||
def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}):
|
def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}):
|
||||||
@@ -356,16 +312,13 @@ class Flux(nn.Module):
|
|||||||
w_len = ((w_orig + (patch_size // 2)) // patch_size)
|
w_len = ((w_orig + (patch_size // 2)) // patch_size)
|
||||||
img, img_ids = self.process_img(x, transformer_options=transformer_options)
|
img, img_ids = self.process_img(x, transformer_options=transformer_options)
|
||||||
img_tokens = img.shape[1]
|
img_tokens = img.shape[1]
|
||||||
timestep_zero_index = None
|
|
||||||
if ref_latents is not None:
|
if ref_latents is not None:
|
||||||
ref_num_tokens = []
|
|
||||||
h = 0
|
h = 0
|
||||||
w = 0
|
w = 0
|
||||||
index = 0
|
index = 0
|
||||||
ref_latents_method = kwargs.get("ref_latents_method", self.params.default_ref_method)
|
ref_latents_method = kwargs.get("ref_latents_method", self.params.default_ref_method)
|
||||||
timestep_zero = ref_latents_method == "index_timestep_zero"
|
|
||||||
for ref in ref_latents:
|
for ref in ref_latents:
|
||||||
if ref_latents_method in ("index", "index_timestep_zero"):
|
if ref_latents_method == "index":
|
||||||
index += self.params.ref_index_scale
|
index += self.params.ref_index_scale
|
||||||
h_offset = 0
|
h_offset = 0
|
||||||
w_offset = 0
|
w_offset = 0
|
||||||
@@ -389,13 +342,6 @@ class Flux(nn.Module):
|
|||||||
kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
|
kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
|
||||||
img = torch.cat([img, kontext], dim=1)
|
img = torch.cat([img, kontext], dim=1)
|
||||||
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
|
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
|
||||||
ref_num_tokens.append(kontext.shape[1])
|
|
||||||
if timestep_zero:
|
|
||||||
if index > 0:
|
|
||||||
timestep = torch.cat([timestep, timestep * 0], dim=0)
|
|
||||||
timestep_zero_index = [[img_tokens, img_ids.shape[1]]]
|
|
||||||
transformer_options = transformer_options.copy()
|
|
||||||
transformer_options["reference_image_num_tokens"] = ref_num_tokens
|
|
||||||
|
|
||||||
txt_ids = torch.zeros((bs, context.shape[1], len(self.params.axes_dim)), device=x.device, dtype=torch.float32)
|
txt_ids = torch.zeros((bs, context.shape[1], len(self.params.axes_dim)), device=x.device, dtype=torch.float32)
|
||||||
|
|
||||||
@@ -403,6 +349,6 @@ class Flux(nn.Module):
|
|||||||
for i in self.params.txt_ids_dims:
|
for i in self.params.txt_ids_dims:
|
||||||
txt_ids[:, :, i] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32)
|
txt_ids[:, :, i] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32)
|
||||||
|
|
||||||
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, timestep_zero_index=timestep_zero_index, transformer_options=transformer_options, attn_mask=kwargs.get("attention_mask", None))
|
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
|
||||||
out = out[:, :img_tokens]
|
out = out[:, :img_tokens]
|
||||||
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=self.patch_size, pw=self.patch_size)[:,:,:h_orig,:w_orig]
|
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=self.patch_size, pw=self.patch_size)[:,:,:h_orig,:w_orig]
|
||||||
|
|||||||
@@ -270,15 +270,10 @@ try:
|
|||||||
except:
|
except:
|
||||||
OOM_EXCEPTION = Exception
|
OOM_EXCEPTION = Exception
|
||||||
|
|
||||||
try:
|
|
||||||
ACCELERATOR_ERROR = torch.AcceleratorError
|
|
||||||
except AttributeError:
|
|
||||||
ACCELERATOR_ERROR = RuntimeError
|
|
||||||
|
|
||||||
def is_oom(e):
|
def is_oom(e):
|
||||||
if isinstance(e, OOM_EXCEPTION):
|
if isinstance(e, OOM_EXCEPTION):
|
||||||
return True
|
return True
|
||||||
if isinstance(e, ACCELERATOR_ERROR) and (getattr(e, 'error_code', None) == 2 or "out of memory" in str(e).lower()):
|
if isinstance(e, torch.AcceleratorError) and getattr(e, 'error_code', None) == 2:
|
||||||
discard_cuda_async_error()
|
discard_cuda_async_error()
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
@@ -1280,7 +1275,7 @@ def discard_cuda_async_error():
|
|||||||
b = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
|
b = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
|
||||||
_ = a + b
|
_ = a + b
|
||||||
synchronize()
|
synchronize()
|
||||||
except RuntimeError:
|
except torch.AcceleratorError:
|
||||||
#Dump it! We already know about it from the synchronous return
|
#Dump it! We already know about it from the synchronous return
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@@ -297,7 +297,7 @@ class Float(ComfyTypeIO):
|
|||||||
'''Float input.'''
|
'''Float input.'''
|
||||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||||
default: float=None, min: float=None, max: float=None, step: float=None, round: float=None,
|
default: float=None, min: float=None, max: float=None, step: float=None, round: float=None,
|
||||||
display_mode: NumberDisplay=None, gradient_stops: list[dict]=None,
|
display_mode: NumberDisplay=None, gradient_stops: list[list[float]]=None,
|
||||||
socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
|
socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
|
||||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced)
|
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced)
|
||||||
self.min = min
|
self.min = min
|
||||||
|
|||||||
@@ -1,68 +0,0 @@
|
|||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
|
|
||||||
class RevePostprocessingOperation(BaseModel):
|
|
||||||
process: str = Field(..., description="The postprocessing operation: upscale or remove_background.")
|
|
||||||
upscale_factor: int | None = Field(
|
|
||||||
None,
|
|
||||||
description="Upscale factor (2, 3, or 4). Only used when process is upscale.",
|
|
||||||
ge=2,
|
|
||||||
le=4,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ReveImageCreateRequest(BaseModel):
|
|
||||||
prompt: str = Field(...)
|
|
||||||
aspect_ratio: str | None = Field(...)
|
|
||||||
version: str = Field(...)
|
|
||||||
test_time_scaling: int = Field(
|
|
||||||
...,
|
|
||||||
description="If included, the model will spend more effort making better images. Values between 1 and 15.",
|
|
||||||
ge=1,
|
|
||||||
le=15,
|
|
||||||
)
|
|
||||||
postprocessing: list[RevePostprocessingOperation] | None = Field(
|
|
||||||
None, description="Optional postprocessing operations to apply after generation."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ReveImageEditRequest(BaseModel):
|
|
||||||
edit_instruction: str = Field(...)
|
|
||||||
reference_image: str = Field(..., description="A base64 encoded image to use as reference for the edit.")
|
|
||||||
aspect_ratio: str | None = Field(...)
|
|
||||||
version: str = Field(...)
|
|
||||||
test_time_scaling: int | None = Field(
|
|
||||||
...,
|
|
||||||
description="If included, the model will spend more effort making better images. Values between 1 and 15.",
|
|
||||||
ge=1,
|
|
||||||
le=15,
|
|
||||||
)
|
|
||||||
postprocessing: list[RevePostprocessingOperation] | None = Field(
|
|
||||||
None, description="Optional postprocessing operations to apply after generation."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ReveImageRemixRequest(BaseModel):
|
|
||||||
prompt: str = Field(...)
|
|
||||||
reference_images: list[str] = Field(..., description="A list of 1-6 base64 encoded reference images.")
|
|
||||||
aspect_ratio: str | None = Field(...)
|
|
||||||
version: str = Field(...)
|
|
||||||
test_time_scaling: int | None = Field(
|
|
||||||
...,
|
|
||||||
description="If included, the model will spend more effort making better images. Values between 1 and 15.",
|
|
||||||
ge=1,
|
|
||||||
le=15,
|
|
||||||
)
|
|
||||||
postprocessing: list[RevePostprocessingOperation] | None = Field(
|
|
||||||
None, description="Optional postprocessing operations to apply after generation."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ReveImageResponse(BaseModel):
|
|
||||||
image: str | None = Field(None, description="The base64 encoded image data.")
|
|
||||||
request_id: str | None = Field(None, description="A unique id for the request.")
|
|
||||||
credits_used: float | None = Field(None, description="The number of credits used for this request.")
|
|
||||||
version: str | None = Field(None, description="The specific model version used.")
|
|
||||||
content_violation: bool | None = Field(
|
|
||||||
None, description="Indicates whether the generated image violates the content policy."
|
|
||||||
)
|
|
||||||
@@ -1,395 +0,0 @@
|
|||||||
from io import BytesIO
|
|
||||||
|
|
||||||
from typing_extensions import override
|
|
||||||
|
|
||||||
from comfy_api.latest import IO, ComfyExtension, Input
|
|
||||||
from comfy_api_nodes.apis.reve import (
|
|
||||||
ReveImageCreateRequest,
|
|
||||||
ReveImageEditRequest,
|
|
||||||
ReveImageRemixRequest,
|
|
||||||
RevePostprocessingOperation,
|
|
||||||
)
|
|
||||||
from comfy_api_nodes.util import (
|
|
||||||
ApiEndpoint,
|
|
||||||
bytesio_to_image_tensor,
|
|
||||||
sync_op_raw,
|
|
||||||
tensor_to_base64_string,
|
|
||||||
validate_string,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _build_postprocessing(upscale: dict, remove_background: bool) -> list[RevePostprocessingOperation] | None:
|
|
||||||
ops = []
|
|
||||||
if upscale["upscale"] == "enabled":
|
|
||||||
ops.append(
|
|
||||||
RevePostprocessingOperation(
|
|
||||||
process="upscale",
|
|
||||||
upscale_factor=upscale["upscale_factor"],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
if remove_background:
|
|
||||||
ops.append(RevePostprocessingOperation(process="remove_background"))
|
|
||||||
return ops or None
|
|
||||||
|
|
||||||
|
|
||||||
def _postprocessing_inputs():
|
|
||||||
return [
|
|
||||||
IO.DynamicCombo.Input(
|
|
||||||
"upscale",
|
|
||||||
options=[
|
|
||||||
IO.DynamicCombo.Option("disabled", []),
|
|
||||||
IO.DynamicCombo.Option(
|
|
||||||
"enabled",
|
|
||||||
[
|
|
||||||
IO.Int.Input(
|
|
||||||
"upscale_factor",
|
|
||||||
default=2,
|
|
||||||
min=2,
|
|
||||||
max=4,
|
|
||||||
step=1,
|
|
||||||
tooltip="Upscale factor (2x, 3x, or 4x).",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
),
|
|
||||||
],
|
|
||||||
tooltip="Upscale the generated image. May add additional cost.",
|
|
||||||
),
|
|
||||||
IO.Boolean.Input(
|
|
||||||
"remove_background",
|
|
||||||
default=False,
|
|
||||||
tooltip="Remove the background from the generated image. May add additional cost.",
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def _reve_price_extractor(headers: dict) -> float | None:
|
|
||||||
credits_used = headers.get("x-reve-credits-used")
|
|
||||||
if credits_used is not None:
|
|
||||||
return float(credits_used) / 524.48
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def _reve_response_header_validator(headers: dict) -> None:
|
|
||||||
error_code = headers.get("x-reve-error-code")
|
|
||||||
if error_code:
|
|
||||||
raise ValueError(f"Reve API error: {error_code}")
|
|
||||||
if headers.get("x-reve-content-violation", "").lower() == "true":
|
|
||||||
raise ValueError("The generated image was flagged for content policy violation.")
|
|
||||||
|
|
||||||
|
|
||||||
def _model_inputs(versions: list[str], aspect_ratios: list[str]):
|
|
||||||
return [
|
|
||||||
IO.DynamicCombo.Option(
|
|
||||||
version,
|
|
||||||
[
|
|
||||||
IO.Combo.Input(
|
|
||||||
"aspect_ratio",
|
|
||||||
options=aspect_ratios,
|
|
||||||
tooltip="Aspect ratio of the output image.",
|
|
||||||
),
|
|
||||||
IO.Int.Input(
|
|
||||||
"test_time_scaling",
|
|
||||||
default=1,
|
|
||||||
min=1,
|
|
||||||
max=5,
|
|
||||||
step=1,
|
|
||||||
tooltip="Higher values produce better images but cost more credits.",
|
|
||||||
advanced=True,
|
|
||||||
),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
for version in versions
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class ReveImageCreateNode(IO.ComfyNode):
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def define_schema(cls):
|
|
||||||
return IO.Schema(
|
|
||||||
node_id="ReveImageCreateNode",
|
|
||||||
display_name="Reve Image Create",
|
|
||||||
category="api node/image/Reve",
|
|
||||||
description="Generate images from text descriptions using Reve.",
|
|
||||||
inputs=[
|
|
||||||
IO.String.Input(
|
|
||||||
"prompt",
|
|
||||||
multiline=True,
|
|
||||||
default="",
|
|
||||||
tooltip="Text description of the desired image. Maximum 2560 characters.",
|
|
||||||
),
|
|
||||||
IO.DynamicCombo.Input(
|
|
||||||
"model",
|
|
||||||
options=_model_inputs(
|
|
||||||
["reve-create@20250915"],
|
|
||||||
aspect_ratios=["3:2", "16:9", "9:16", "2:3", "4:3", "3:4", "1:1"],
|
|
||||||
),
|
|
||||||
tooltip="Model version to use for generation.",
|
|
||||||
),
|
|
||||||
*_postprocessing_inputs(),
|
|
||||||
IO.Int.Input(
|
|
||||||
"seed",
|
|
||||||
default=0,
|
|
||||||
min=0,
|
|
||||||
max=2147483647,
|
|
||||||
control_after_generate=True,
|
|
||||||
tooltip="Seed controls whether the node should re-run; "
|
|
||||||
"results are non-deterministic regardless of seed.",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
outputs=[IO.Image.Output()],
|
|
||||||
hidden=[
|
|
||||||
IO.Hidden.auth_token_comfy_org,
|
|
||||||
IO.Hidden.api_key_comfy_org,
|
|
||||||
IO.Hidden.unique_id,
|
|
||||||
],
|
|
||||||
is_api_node=True,
|
|
||||||
price_badge=IO.PriceBadge(
|
|
||||||
expr="""{"type":"usd","usd":0.03432,"format":{"approximate":true,"note":"(base)"}}""",
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def execute(
|
|
||||||
cls,
|
|
||||||
prompt: str,
|
|
||||||
model: dict,
|
|
||||||
upscale: dict,
|
|
||||||
remove_background: bool,
|
|
||||||
seed: int,
|
|
||||||
) -> IO.NodeOutput:
|
|
||||||
validate_string(prompt, min_length=1, max_length=2560)
|
|
||||||
response = await sync_op_raw(
|
|
||||||
cls,
|
|
||||||
ApiEndpoint(
|
|
||||||
path="/proxy/reve/v1/image/create",
|
|
||||||
method="POST",
|
|
||||||
headers={"Accept": "image/webp"},
|
|
||||||
),
|
|
||||||
as_binary=True,
|
|
||||||
price_extractor=_reve_price_extractor,
|
|
||||||
response_header_validator=_reve_response_header_validator,
|
|
||||||
data=ReveImageCreateRequest(
|
|
||||||
prompt=prompt,
|
|
||||||
aspect_ratio=model["aspect_ratio"],
|
|
||||||
version=model["model"],
|
|
||||||
test_time_scaling=model["test_time_scaling"],
|
|
||||||
postprocessing=_build_postprocessing(upscale, remove_background),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
return IO.NodeOutput(bytesio_to_image_tensor(BytesIO(response)))
|
|
||||||
|
|
||||||
|
|
||||||
class ReveImageEditNode(IO.ComfyNode):
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def define_schema(cls):
|
|
||||||
return IO.Schema(
|
|
||||||
node_id="ReveImageEditNode",
|
|
||||||
display_name="Reve Image Edit",
|
|
||||||
category="api node/image/Reve",
|
|
||||||
description="Edit images using natural language instructions with Reve.",
|
|
||||||
inputs=[
|
|
||||||
IO.Image.Input("image", tooltip="The image to edit."),
|
|
||||||
IO.String.Input(
|
|
||||||
"edit_instruction",
|
|
||||||
multiline=True,
|
|
||||||
default="",
|
|
||||||
tooltip="Text description of how to edit the image. Maximum 2560 characters.",
|
|
||||||
),
|
|
||||||
IO.DynamicCombo.Input(
|
|
||||||
"model",
|
|
||||||
options=_model_inputs(
|
|
||||||
["reve-edit@20250915", "reve-edit-fast@20251030"],
|
|
||||||
aspect_ratios=["auto", "16:9", "9:16", "3:2", "2:3", "4:3", "3:4", "1:1"],
|
|
||||||
),
|
|
||||||
tooltip="Model version to use for editing.",
|
|
||||||
),
|
|
||||||
*_postprocessing_inputs(),
|
|
||||||
IO.Int.Input(
|
|
||||||
"seed",
|
|
||||||
default=0,
|
|
||||||
min=0,
|
|
||||||
max=2147483647,
|
|
||||||
control_after_generate=True,
|
|
||||||
tooltip="Seed controls whether the node should re-run; "
|
|
||||||
"results are non-deterministic regardless of seed.",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
outputs=[IO.Image.Output()],
|
|
||||||
hidden=[
|
|
||||||
IO.Hidden.auth_token_comfy_org,
|
|
||||||
IO.Hidden.api_key_comfy_org,
|
|
||||||
IO.Hidden.unique_id,
|
|
||||||
],
|
|
||||||
is_api_node=True,
|
|
||||||
price_badge=IO.PriceBadge(
|
|
||||||
depends_on=IO.PriceBadgeDepends(
|
|
||||||
widgets=["model"],
|
|
||||||
),
|
|
||||||
expr="""
|
|
||||||
(
|
|
||||||
$isFast := $contains(widgets.model, "fast");
|
|
||||||
$base := $isFast ? 0.01001 : 0.0572;
|
|
||||||
{"type": "usd", "usd": $base, "format": {"approximate": true, "note": "(base)"}}
|
|
||||||
)
|
|
||||||
""",
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def execute(
|
|
||||||
cls,
|
|
||||||
image: Input.Image,
|
|
||||||
edit_instruction: str,
|
|
||||||
model: dict,
|
|
||||||
upscale: dict,
|
|
||||||
remove_background: bool,
|
|
||||||
seed: int,
|
|
||||||
) -> IO.NodeOutput:
|
|
||||||
validate_string(edit_instruction, min_length=1, max_length=2560)
|
|
||||||
tts = model["test_time_scaling"]
|
|
||||||
ar = model["aspect_ratio"]
|
|
||||||
response = await sync_op_raw(
|
|
||||||
cls,
|
|
||||||
ApiEndpoint(
|
|
||||||
path="/proxy/reve/v1/image/edit",
|
|
||||||
method="POST",
|
|
||||||
headers={"Accept": "image/webp"},
|
|
||||||
),
|
|
||||||
as_binary=True,
|
|
||||||
price_extractor=_reve_price_extractor,
|
|
||||||
response_header_validator=_reve_response_header_validator,
|
|
||||||
data=ReveImageEditRequest(
|
|
||||||
edit_instruction=edit_instruction,
|
|
||||||
reference_image=tensor_to_base64_string(image),
|
|
||||||
aspect_ratio=ar if ar != "auto" else None,
|
|
||||||
version=model["model"],
|
|
||||||
test_time_scaling=tts if tts and tts > 1 else None,
|
|
||||||
postprocessing=_build_postprocessing(upscale, remove_background),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
return IO.NodeOutput(bytesio_to_image_tensor(BytesIO(response)))
|
|
||||||
|
|
||||||
|
|
||||||
class ReveImageRemixNode(IO.ComfyNode):
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def define_schema(cls):
|
|
||||||
return IO.Schema(
|
|
||||||
node_id="ReveImageRemixNode",
|
|
||||||
display_name="Reve Image Remix",
|
|
||||||
category="api node/image/Reve",
|
|
||||||
description="Combine reference images with text prompts to create new images using Reve.",
|
|
||||||
inputs=[
|
|
||||||
IO.Autogrow.Input(
|
|
||||||
"reference_images",
|
|
||||||
template=IO.Autogrow.TemplatePrefix(
|
|
||||||
IO.Image.Input("image"),
|
|
||||||
prefix="image_",
|
|
||||||
min=1,
|
|
||||||
max=6,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
IO.String.Input(
|
|
||||||
"prompt",
|
|
||||||
multiline=True,
|
|
||||||
default="",
|
|
||||||
tooltip="Text description of the desired image. "
|
|
||||||
"May include XML img tags to reference specific images by index, "
|
|
||||||
"e.g. <img>0</img>, <img>1</img>, etc.",
|
|
||||||
),
|
|
||||||
IO.DynamicCombo.Input(
|
|
||||||
"model",
|
|
||||||
options=_model_inputs(
|
|
||||||
["reve-remix@20250915", "reve-remix-fast@20251030"],
|
|
||||||
aspect_ratios=["auto", "16:9", "9:16", "3:2", "2:3", "4:3", "3:4", "1:1"],
|
|
||||||
),
|
|
||||||
tooltip="Model version to use for remixing.",
|
|
||||||
),
|
|
||||||
*_postprocessing_inputs(),
|
|
||||||
IO.Int.Input(
|
|
||||||
"seed",
|
|
||||||
default=0,
|
|
||||||
min=0,
|
|
||||||
max=2147483647,
|
|
||||||
control_after_generate=True,
|
|
||||||
tooltip="Seed controls whether the node should re-run; "
|
|
||||||
"results are non-deterministic regardless of seed.",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
outputs=[IO.Image.Output()],
|
|
||||||
hidden=[
|
|
||||||
IO.Hidden.auth_token_comfy_org,
|
|
||||||
IO.Hidden.api_key_comfy_org,
|
|
||||||
IO.Hidden.unique_id,
|
|
||||||
],
|
|
||||||
is_api_node=True,
|
|
||||||
price_badge=IO.PriceBadge(
|
|
||||||
depends_on=IO.PriceBadgeDepends(
|
|
||||||
widgets=["model"],
|
|
||||||
),
|
|
||||||
expr="""
|
|
||||||
(
|
|
||||||
$isFast := $contains(widgets.model, "fast");
|
|
||||||
$base := $isFast ? 0.01001 : 0.0572;
|
|
||||||
{"type": "usd", "usd": $base, "format": {"approximate": true, "note": "(base)"}}
|
|
||||||
)
|
|
||||||
""",
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def execute(
|
|
||||||
cls,
|
|
||||||
reference_images: IO.Autogrow.Type,
|
|
||||||
prompt: str,
|
|
||||||
model: dict,
|
|
||||||
upscale: dict,
|
|
||||||
remove_background: bool,
|
|
||||||
seed: int,
|
|
||||||
) -> IO.NodeOutput:
|
|
||||||
validate_string(prompt, min_length=1, max_length=2560)
|
|
||||||
if not reference_images:
|
|
||||||
raise ValueError("At least one reference image is required.")
|
|
||||||
ref_base64_list = []
|
|
||||||
for key in reference_images:
|
|
||||||
ref_base64_list.append(tensor_to_base64_string(reference_images[key]))
|
|
||||||
if len(ref_base64_list) > 6:
|
|
||||||
raise ValueError("Maximum 6 reference images are allowed.")
|
|
||||||
tts = model["test_time_scaling"]
|
|
||||||
ar = model["aspect_ratio"]
|
|
||||||
response = await sync_op_raw(
|
|
||||||
cls,
|
|
||||||
ApiEndpoint(
|
|
||||||
path="/proxy/reve/v1/image/remix",
|
|
||||||
method="POST",
|
|
||||||
headers={"Accept": "image/webp"},
|
|
||||||
),
|
|
||||||
as_binary=True,
|
|
||||||
price_extractor=_reve_price_extractor,
|
|
||||||
response_header_validator=_reve_response_header_validator,
|
|
||||||
data=ReveImageRemixRequest(
|
|
||||||
prompt=prompt,
|
|
||||||
reference_images=ref_base64_list,
|
|
||||||
aspect_ratio=ar if ar != "auto" else None,
|
|
||||||
version=model["model"],
|
|
||||||
test_time_scaling=tts if tts and tts > 1 else None,
|
|
||||||
postprocessing=_build_postprocessing(upscale, remove_background),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
return IO.NodeOutput(bytesio_to_image_tensor(BytesIO(response)))
|
|
||||||
|
|
||||||
|
|
||||||
class ReveExtension(ComfyExtension):
|
|
||||||
@override
|
|
||||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
|
||||||
return [
|
|
||||||
ReveImageCreateNode,
|
|
||||||
ReveImageEditNode,
|
|
||||||
ReveImageRemixNode,
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
async def comfy_entrypoint() -> ReveExtension:
|
|
||||||
return ReveExtension()
|
|
||||||
@@ -67,7 +67,6 @@ class _RequestConfig:
|
|||||||
progress_origin_ts: float | None = None
|
progress_origin_ts: float | None = None
|
||||||
price_extractor: Callable[[dict[str, Any]], float | None] | None = None
|
price_extractor: Callable[[dict[str, Any]], float | None] | None = None
|
||||||
is_rate_limited: Callable[[int, Any], bool] | None = None
|
is_rate_limited: Callable[[int, Any], bool] | None = None
|
||||||
response_header_validator: Callable[[dict[str, str]], None] | None = None
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -203,13 +202,11 @@ async def sync_op_raw(
|
|||||||
monitor_progress: bool = True,
|
monitor_progress: bool = True,
|
||||||
max_retries_on_rate_limit: int = 16,
|
max_retries_on_rate_limit: int = 16,
|
||||||
is_rate_limited: Callable[[int, Any], bool] | None = None,
|
is_rate_limited: Callable[[int, Any], bool] | None = None,
|
||||||
response_header_validator: Callable[[dict[str, str]], None] | None = None,
|
|
||||||
) -> dict[str, Any] | bytes:
|
) -> dict[str, Any] | bytes:
|
||||||
"""
|
"""
|
||||||
Make a single network request.
|
Make a single network request.
|
||||||
- If as_binary=False (default): returns JSON dict (or {'_raw': '<text>'} if non-JSON).
|
- If as_binary=False (default): returns JSON dict (or {'_raw': '<text>'} if non-JSON).
|
||||||
- If as_binary=True: returns bytes.
|
- If as_binary=True: returns bytes.
|
||||||
- response_header_validator: optional callback receiving response headers dict
|
|
||||||
"""
|
"""
|
||||||
if isinstance(data, BaseModel):
|
if isinstance(data, BaseModel):
|
||||||
data = data.model_dump(exclude_none=True)
|
data = data.model_dump(exclude_none=True)
|
||||||
@@ -235,7 +232,6 @@ async def sync_op_raw(
|
|||||||
price_extractor=price_extractor,
|
price_extractor=price_extractor,
|
||||||
max_retries_on_rate_limit=max_retries_on_rate_limit,
|
max_retries_on_rate_limit=max_retries_on_rate_limit,
|
||||||
is_rate_limited=is_rate_limited,
|
is_rate_limited=is_rate_limited,
|
||||||
response_header_validator=response_header_validator,
|
|
||||||
)
|
)
|
||||||
return await _request_base(cfg, expect_binary=as_binary)
|
return await _request_base(cfg, expect_binary=as_binary)
|
||||||
|
|
||||||
@@ -773,12 +769,6 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
|||||||
cfg.node_cls, cfg.wait_label, int(now - start_time), cfg.estimated_total
|
cfg.node_cls, cfg.wait_label, int(now - start_time), cfg.estimated_total
|
||||||
)
|
)
|
||||||
bytes_payload = bytes(buff)
|
bytes_payload = bytes(buff)
|
||||||
resp_headers = {k.lower(): v for k, v in resp.headers.items()}
|
|
||||||
if cfg.price_extractor:
|
|
||||||
with contextlib.suppress(Exception):
|
|
||||||
extracted_price = cfg.price_extractor(resp_headers)
|
|
||||||
if cfg.response_header_validator:
|
|
||||||
cfg.response_header_validator(resp_headers)
|
|
||||||
operation_succeeded = True
|
operation_succeeded = True
|
||||||
final_elapsed_seconds = int(time.monotonic() - start_time)
|
final_elapsed_seconds = int(time.monotonic() - start_time)
|
||||||
request_logger.log_request_response(
|
request_logger.log_request_response(
|
||||||
@@ -786,7 +776,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
|||||||
request_method=method,
|
request_method=method,
|
||||||
request_url=url,
|
request_url=url,
|
||||||
response_status_code=resp.status,
|
response_status_code=resp.status,
|
||||||
response_headers=resp_headers,
|
response_headers=dict(resp.headers),
|
||||||
response_content=bytes_payload,
|
response_content=bytes_payload,
|
||||||
)
|
)
|
||||||
return bytes_payload
|
return bytes_payload
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ import comfy.model_management
|
|||||||
import torch
|
import torch
|
||||||
import math
|
import math
|
||||||
import nodes
|
import nodes
|
||||||
import comfy.ldm.flux.math
|
|
||||||
|
|
||||||
class CLIPTextEncodeFlux(io.ComfyNode):
|
class CLIPTextEncodeFlux(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -232,68 +231,6 @@ class Flux2Scheduler(io.ComfyNode):
|
|||||||
sigmas = get_schedule(steps, round(seq_len))
|
sigmas = get_schedule(steps, round(seq_len))
|
||||||
return io.NodeOutput(sigmas)
|
return io.NodeOutput(sigmas)
|
||||||
|
|
||||||
class KV_Attn_Input:
|
|
||||||
def __init__(self):
|
|
||||||
self.cache = {}
|
|
||||||
|
|
||||||
def __call__(self, q, k, v, extra_options, **kwargs):
|
|
||||||
reference_image_num_tokens = extra_options.get("reference_image_num_tokens", [])
|
|
||||||
if len(reference_image_num_tokens) == 0:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
ref_toks = sum(reference_image_num_tokens)
|
|
||||||
cache_key = "{}_{}".format(extra_options["block_type"], extra_options["block_index"])
|
|
||||||
if cache_key in self.cache:
|
|
||||||
kk, vv = self.cache[cache_key]
|
|
||||||
self.set_cache = False
|
|
||||||
return {"q": q, "k": torch.cat((k, kk), dim=2), "v": torch.cat((v, vv), dim=2)}
|
|
||||||
|
|
||||||
self.cache[cache_key] = (k[:, :, -ref_toks:], v[:, :, -ref_toks:])
|
|
||||||
self.set_cache = True
|
|
||||||
return {"q": q, "k": k, "v": v}
|
|
||||||
|
|
||||||
def cleanup(self):
|
|
||||||
self.cache = {}
|
|
||||||
|
|
||||||
|
|
||||||
class FluxKVCache(io.ComfyNode):
|
|
||||||
@classmethod
|
|
||||||
def define_schema(cls) -> io.Schema:
|
|
||||||
return io.Schema(
|
|
||||||
node_id="FluxKVCache",
|
|
||||||
display_name="Flux KV Cache",
|
|
||||||
description="Enables KV Cache optimization for reference images on Flux family models.",
|
|
||||||
category="",
|
|
||||||
is_experimental=True,
|
|
||||||
inputs=[
|
|
||||||
io.Model.Input("model", tooltip="The model to use KV Cache on."),
|
|
||||||
],
|
|
||||||
outputs=[
|
|
||||||
io.Model.Output(tooltip="The patched model with KV Cache enabled."),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def execute(cls, model: io.Model.Type) -> io.NodeOutput:
|
|
||||||
m = model.clone()
|
|
||||||
input_patch_obj = KV_Attn_Input()
|
|
||||||
|
|
||||||
def model_input_patch(inputs):
|
|
||||||
if len(input_patch_obj.cache) > 0:
|
|
||||||
ref_image_tokens = sum(inputs["transformer_options"].get("reference_image_num_tokens", []))
|
|
||||||
if ref_image_tokens > 0:
|
|
||||||
img = inputs["img"]
|
|
||||||
inputs["img"] = img[:, :-ref_image_tokens]
|
|
||||||
return inputs
|
|
||||||
|
|
||||||
m.set_model_attn1_patch(input_patch_obj)
|
|
||||||
m.set_model_post_input_patch(model_input_patch)
|
|
||||||
if hasattr(model.model.diffusion_model, "params"):
|
|
||||||
m.add_object_patch("diffusion_model.params.default_ref_method", "index_timestep_zero")
|
|
||||||
else:
|
|
||||||
m.add_object_patch("diffusion_model.default_ref_method", "index_timestep_zero")
|
|
||||||
|
|
||||||
return io.NodeOutput(m)
|
|
||||||
|
|
||||||
class FluxExtension(ComfyExtension):
|
class FluxExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
@@ -306,7 +243,6 @@ class FluxExtension(ComfyExtension):
|
|||||||
FluxKontextMultiReferenceLatentMethod,
|
FluxKontextMultiReferenceLatentMethod,
|
||||||
EmptyFlux2LatentImage,
|
EmptyFlux2LatentImage,
|
||||||
Flux2Scheduler,
|
Flux2Scheduler,
|
||||||
FluxKVCache,
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -10,146 +10,198 @@ import json
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
from comfy.cli_args import args
|
from comfy.cli_args import args
|
||||||
|
from comfy_api.latest import io, ComfyExtension
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
class ModelMergeSimple:
|
|
||||||
|
class ModelMergeSimple(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": { "model1": ("MODEL",),
|
return io.Schema(
|
||||||
"model2": ("MODEL",),
|
node_id="ModelMergeSimple",
|
||||||
"ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
category="advanced/model_merging",
|
||||||
}}
|
inputs=[
|
||||||
RETURN_TYPES = ("MODEL",)
|
io.Model.Input("model1"),
|
||||||
FUNCTION = "merge"
|
io.Model.Input("model2"),
|
||||||
|
io.Float.Input("ratio", default=1.0, min=0.0, max=1.0, step=0.01),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Model.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "advanced/model_merging"
|
@classmethod
|
||||||
|
def execute(cls, model1, model2, ratio) -> io.NodeOutput:
|
||||||
def merge(self, model1, model2, ratio):
|
|
||||||
m = model1.clone()
|
m = model1.clone()
|
||||||
kp = model2.get_key_patches("diffusion_model.")
|
kp = model2.get_key_patches("diffusion_model.")
|
||||||
for k in kp:
|
for k in kp:
|
||||||
m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
|
m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
|
||||||
return (m, )
|
return io.NodeOutput(m)
|
||||||
|
|
||||||
class ModelSubtract:
|
merge = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
|
class ModelSubtract(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": { "model1": ("MODEL",),
|
return io.Schema(
|
||||||
"model2": ("MODEL",),
|
node_id="ModelMergeSubtract",
|
||||||
"multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
|
category="advanced/model_merging",
|
||||||
}}
|
inputs=[
|
||||||
RETURN_TYPES = ("MODEL",)
|
io.Model.Input("model1"),
|
||||||
FUNCTION = "merge"
|
io.Model.Input("model2"),
|
||||||
|
io.Float.Input("multiplier", default=1.0, min=-10.0, max=10.0, step=0.01),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Model.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "advanced/model_merging"
|
@classmethod
|
||||||
|
def execute(cls, model1, model2, multiplier) -> io.NodeOutput:
|
||||||
def merge(self, model1, model2, multiplier):
|
|
||||||
m = model1.clone()
|
m = model1.clone()
|
||||||
kp = model2.get_key_patches("diffusion_model.")
|
kp = model2.get_key_patches("diffusion_model.")
|
||||||
for k in kp:
|
for k in kp:
|
||||||
m.add_patches({k: kp[k]}, - multiplier, multiplier)
|
m.add_patches({k: kp[k]}, - multiplier, multiplier)
|
||||||
return (m, )
|
return io.NodeOutput(m)
|
||||||
|
|
||||||
class ModelAdd:
|
merge = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
|
class ModelAdd(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": { "model1": ("MODEL",),
|
return io.Schema(
|
||||||
"model2": ("MODEL",),
|
node_id="ModelMergeAdd",
|
||||||
}}
|
category="advanced/model_merging",
|
||||||
RETURN_TYPES = ("MODEL",)
|
inputs=[
|
||||||
FUNCTION = "merge"
|
io.Model.Input("model1"),
|
||||||
|
io.Model.Input("model2"),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Model.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "advanced/model_merging"
|
@classmethod
|
||||||
|
def execute(cls, model1, model2) -> io.NodeOutput:
|
||||||
def merge(self, model1, model2):
|
|
||||||
m = model1.clone()
|
m = model1.clone()
|
||||||
kp = model2.get_key_patches("diffusion_model.")
|
kp = model2.get_key_patches("diffusion_model.")
|
||||||
for k in kp:
|
for k in kp:
|
||||||
m.add_patches({k: kp[k]}, 1.0, 1.0)
|
m.add_patches({k: kp[k]}, 1.0, 1.0)
|
||||||
return (m, )
|
return io.NodeOutput(m)
|
||||||
|
|
||||||
|
merge = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
class CLIPMergeSimple:
|
class CLIPMergeSimple(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": { "clip1": ("CLIP",),
|
return io.Schema(
|
||||||
"clip2": ("CLIP",),
|
node_id="CLIPMergeSimple",
|
||||||
"ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
category="advanced/model_merging",
|
||||||
}}
|
inputs=[
|
||||||
RETURN_TYPES = ("CLIP",)
|
io.Clip.Input("clip1"),
|
||||||
FUNCTION = "merge"
|
io.Clip.Input("clip2"),
|
||||||
|
io.Float.Input("ratio", default=1.0, min=0.0, max=1.0, step=0.01),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Clip.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "advanced/model_merging"
|
@classmethod
|
||||||
|
def execute(cls, clip1, clip2, ratio) -> io.NodeOutput:
|
||||||
def merge(self, clip1, clip2, ratio):
|
|
||||||
m = clip1.clone()
|
m = clip1.clone()
|
||||||
kp = clip2.get_key_patches()
|
kp = clip2.get_key_patches()
|
||||||
for k in kp:
|
for k in kp:
|
||||||
if k.endswith(".position_ids") or k.endswith(".logit_scale"):
|
if k.endswith(".position_ids") or k.endswith(".logit_scale"):
|
||||||
continue
|
continue
|
||||||
m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
|
m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
|
||||||
return (m, )
|
return io.NodeOutput(m)
|
||||||
|
|
||||||
|
merge = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
class CLIPSubtract:
|
class CLIPSubtract(io.ComfyNode):
|
||||||
SEARCH_ALIASES = ["clip difference", "text encoder subtract"]
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": { "clip1": ("CLIP",),
|
return io.Schema(
|
||||||
"clip2": ("CLIP",),
|
node_id="CLIPMergeSubtract",
|
||||||
"multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
|
search_aliases=["clip difference", "text encoder subtract"],
|
||||||
}}
|
category="advanced/model_merging",
|
||||||
RETURN_TYPES = ("CLIP",)
|
inputs=[
|
||||||
FUNCTION = "merge"
|
io.Clip.Input("clip1"),
|
||||||
|
io.Clip.Input("clip2"),
|
||||||
|
io.Float.Input("multiplier", default=1.0, min=-10.0, max=10.0, step=0.01),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Clip.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "advanced/model_merging"
|
@classmethod
|
||||||
|
def execute(cls, clip1, clip2, multiplier) -> io.NodeOutput:
|
||||||
def merge(self, clip1, clip2, multiplier):
|
|
||||||
m = clip1.clone()
|
m = clip1.clone()
|
||||||
kp = clip2.get_key_patches()
|
kp = clip2.get_key_patches()
|
||||||
for k in kp:
|
for k in kp:
|
||||||
if k.endswith(".position_ids") or k.endswith(".logit_scale"):
|
if k.endswith(".position_ids") or k.endswith(".logit_scale"):
|
||||||
continue
|
continue
|
||||||
m.add_patches({k: kp[k]}, - multiplier, multiplier)
|
m.add_patches({k: kp[k]}, - multiplier, multiplier)
|
||||||
return (m, )
|
return io.NodeOutput(m)
|
||||||
|
|
||||||
|
merge = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
class CLIPAdd:
|
class CLIPAdd(io.ComfyNode):
|
||||||
SEARCH_ALIASES = ["combine clip"]
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": { "clip1": ("CLIP",),
|
return io.Schema(
|
||||||
"clip2": ("CLIP",),
|
node_id="CLIPMergeAdd",
|
||||||
}}
|
search_aliases=["combine clip"],
|
||||||
RETURN_TYPES = ("CLIP",)
|
category="advanced/model_merging",
|
||||||
FUNCTION = "merge"
|
inputs=[
|
||||||
|
io.Clip.Input("clip1"),
|
||||||
|
io.Clip.Input("clip2"),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Clip.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "advanced/model_merging"
|
@classmethod
|
||||||
|
def execute(cls, clip1, clip2) -> io.NodeOutput:
|
||||||
def merge(self, clip1, clip2):
|
|
||||||
m = clip1.clone()
|
m = clip1.clone()
|
||||||
kp = clip2.get_key_patches()
|
kp = clip2.get_key_patches()
|
||||||
for k in kp:
|
for k in kp:
|
||||||
if k.endswith(".position_ids") or k.endswith(".logit_scale"):
|
if k.endswith(".position_ids") or k.endswith(".logit_scale"):
|
||||||
continue
|
continue
|
||||||
m.add_patches({k: kp[k]}, 1.0, 1.0)
|
m.add_patches({k: kp[k]}, 1.0, 1.0)
|
||||||
return (m, )
|
return io.NodeOutput(m)
|
||||||
|
|
||||||
|
merge = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
class ModelMergeBlocks:
|
class ModelMergeBlocks(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": { "model1": ("MODEL",),
|
return io.Schema(
|
||||||
"model2": ("MODEL",),
|
node_id="ModelMergeBlocks",
|
||||||
"input": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
category="advanced/model_merging",
|
||||||
"middle": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
inputs=[
|
||||||
"out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
|
io.Model.Input("model1"),
|
||||||
}}
|
io.Model.Input("model2"),
|
||||||
RETURN_TYPES = ("MODEL",)
|
io.Float.Input("input", default=1.0, min=0.0, max=1.0, step=0.01),
|
||||||
FUNCTION = "merge"
|
io.Float.Input("middle", default=1.0, min=0.0, max=1.0, step=0.01),
|
||||||
|
io.Float.Input("out", default=1.0, min=0.0, max=1.0, step=0.01),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Model.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "advanced/model_merging"
|
@classmethod
|
||||||
|
def execute(cls, model1, model2, **kwargs) -> io.NodeOutput:
|
||||||
def merge(self, model1, model2, **kwargs):
|
|
||||||
m = model1.clone()
|
m = model1.clone()
|
||||||
kp = model2.get_key_patches("diffusion_model.")
|
kp = model2.get_key_patches("diffusion_model.")
|
||||||
default_ratio = next(iter(kwargs.values()))
|
default_ratio = next(iter(kwargs.values()))
|
||||||
@@ -165,7 +217,10 @@ class ModelMergeBlocks:
|
|||||||
last_arg_size = len(arg)
|
last_arg_size = len(arg)
|
||||||
|
|
||||||
m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
|
m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
|
||||||
return (m, )
|
return io.NodeOutput(m)
|
||||||
|
|
||||||
|
merge = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefix=None, output_dir=None, prompt=None, extra_pnginfo=None):
|
def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefix=None, output_dir=None, prompt=None, extra_pnginfo=None):
|
||||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, output_dir)
|
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, output_dir)
|
||||||
@@ -226,59 +281,65 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi
|
|||||||
|
|
||||||
comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, clip_vision, metadata=metadata, extra_keys=extra_keys)
|
comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, clip_vision, metadata=metadata, extra_keys=extra_keys)
|
||||||
|
|
||||||
class CheckpointSave:
|
|
||||||
SEARCH_ALIASES = ["save model", "export checkpoint", "merge save"]
|
class CheckpointSave(io.ComfyNode):
|
||||||
def __init__(self):
|
@classmethod
|
||||||
self.output_dir = folder_paths.get_output_directory()
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="CheckpointSave",
|
||||||
|
display_name="Save Checkpoint",
|
||||||
|
search_aliases=["save model", "export checkpoint", "merge save"],
|
||||||
|
category="advanced/model_merging",
|
||||||
|
inputs=[
|
||||||
|
io.Model.Input("model"),
|
||||||
|
io.Clip.Input("clip"),
|
||||||
|
io.Vae.Input("vae"),
|
||||||
|
io.String.Input("filename_prefix", default="checkpoints/ComfyUI"),
|
||||||
|
],
|
||||||
|
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
|
||||||
|
is_output_node=True,
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def execute(cls, model, clip, vae, filename_prefix) -> io.NodeOutput:
|
||||||
return {"required": { "model": ("MODEL",),
|
save_checkpoint(model, clip=clip, vae=vae, filename_prefix=filename_prefix, output_dir=folder_paths.get_output_directory(), prompt=cls.hidden.prompt, extra_pnginfo=cls.hidden.extra_pnginfo)
|
||||||
"clip": ("CLIP",),
|
return io.NodeOutput()
|
||||||
"vae": ("VAE",),
|
|
||||||
"filename_prefix": ("STRING", {"default": "checkpoints/ComfyUI"}),},
|
|
||||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
|
|
||||||
RETURN_TYPES = ()
|
|
||||||
FUNCTION = "save"
|
|
||||||
OUTPUT_NODE = True
|
|
||||||
|
|
||||||
CATEGORY = "advanced/model_merging"
|
save = execute # TODO: remove
|
||||||
|
|
||||||
def save(self, model, clip, vae, filename_prefix, prompt=None, extra_pnginfo=None):
|
|
||||||
save_checkpoint(model, clip=clip, vae=vae, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo)
|
|
||||||
return {}
|
|
||||||
|
|
||||||
class CLIPSave:
|
class CLIPSave(io.ComfyNode):
|
||||||
def __init__(self):
|
@classmethod
|
||||||
self.output_dir = folder_paths.get_output_directory()
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="CLIPSave",
|
||||||
|
category="advanced/model_merging",
|
||||||
|
inputs=[
|
||||||
|
io.Clip.Input("clip"),
|
||||||
|
io.String.Input("filename_prefix", default="clip/ComfyUI"),
|
||||||
|
],
|
||||||
|
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
|
||||||
|
is_output_node=True,
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def execute(cls, clip, filename_prefix) -> io.NodeOutput:
|
||||||
return {"required": { "clip": ("CLIP",),
|
|
||||||
"filename_prefix": ("STRING", {"default": "clip/ComfyUI"}),},
|
|
||||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
|
|
||||||
RETURN_TYPES = ()
|
|
||||||
FUNCTION = "save"
|
|
||||||
OUTPUT_NODE = True
|
|
||||||
|
|
||||||
CATEGORY = "advanced/model_merging"
|
|
||||||
|
|
||||||
def save(self, clip, filename_prefix, prompt=None, extra_pnginfo=None):
|
|
||||||
prompt_info = ""
|
prompt_info = ""
|
||||||
if prompt is not None:
|
if cls.hidden.prompt is not None:
|
||||||
prompt_info = json.dumps(prompt)
|
prompt_info = json.dumps(cls.hidden.prompt)
|
||||||
|
|
||||||
metadata = {}
|
metadata = {}
|
||||||
if not args.disable_metadata:
|
if not args.disable_metadata:
|
||||||
metadata["format"] = "pt"
|
metadata["format"] = "pt"
|
||||||
metadata["prompt"] = prompt_info
|
metadata["prompt"] = prompt_info
|
||||||
if extra_pnginfo is not None:
|
if cls.hidden.extra_pnginfo is not None:
|
||||||
for x in extra_pnginfo:
|
for x in cls.hidden.extra_pnginfo:
|
||||||
metadata[x] = json.dumps(extra_pnginfo[x])
|
metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x])
|
||||||
|
|
||||||
comfy.model_management.load_models_gpu([clip.load_model()], force_patch_weights=True)
|
comfy.model_management.load_models_gpu([clip.load_model()], force_patch_weights=True)
|
||||||
clip_sd = clip.get_sd()
|
clip_sd = clip.get_sd()
|
||||||
|
|
||||||
|
output_dir = folder_paths.get_output_directory()
|
||||||
for prefix in ["clip_l.", "clip_g.", "clip_h.", "t5xxl.", "pile_t5xl.", "mt5xl.", "umt5xxl.", "t5base.", "gemma2_2b.", "llama.", "hydit_clip.", ""]:
|
for prefix in ["clip_l.", "clip_g.", "clip_h.", "t5xxl.", "pile_t5xl.", "mt5xl.", "umt5xxl.", "t5base.", "gemma2_2b.", "llama.", "hydit_clip.", ""]:
|
||||||
k = list(filter(lambda a: a.startswith(prefix), clip_sd.keys()))
|
k = list(filter(lambda a: a.startswith(prefix), clip_sd.keys()))
|
||||||
current_clip_sd = {}
|
current_clip_sd = {}
|
||||||
@@ -295,7 +356,7 @@ class CLIPSave:
|
|||||||
replace_prefix[prefix] = ""
|
replace_prefix[prefix] = ""
|
||||||
replace_prefix["transformer."] = ""
|
replace_prefix["transformer."] = ""
|
||||||
|
|
||||||
full_output_folder, filename, counter, subfolder, filename_prefix_ = folder_paths.get_save_image_path(filename_prefix_, self.output_dir)
|
full_output_folder, filename, counter, subfolder, filename_prefix_ = folder_paths.get_save_image_path(filename_prefix_, output_dir)
|
||||||
|
|
||||||
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
|
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
|
||||||
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
|
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
|
||||||
@@ -303,76 +364,88 @@ class CLIPSave:
|
|||||||
current_clip_sd = comfy.utils.state_dict_prefix_replace(current_clip_sd, replace_prefix)
|
current_clip_sd = comfy.utils.state_dict_prefix_replace(current_clip_sd, replace_prefix)
|
||||||
|
|
||||||
comfy.utils.save_torch_file(current_clip_sd, output_checkpoint, metadata=metadata)
|
comfy.utils.save_torch_file(current_clip_sd, output_checkpoint, metadata=metadata)
|
||||||
return {}
|
return io.NodeOutput()
|
||||||
|
|
||||||
class VAESave:
|
save = execute # TODO: remove
|
||||||
def __init__(self):
|
|
||||||
self.output_dir = folder_paths.get_output_directory()
|
|
||||||
|
class VAESave(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="VAESave",
|
||||||
|
category="advanced/model_merging",
|
||||||
|
inputs=[
|
||||||
|
io.Vae.Input("vae"),
|
||||||
|
io.String.Input("filename_prefix", default="vae/ComfyUI_vae"),
|
||||||
|
],
|
||||||
|
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
|
||||||
|
is_output_node=True,
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def execute(cls, vae, filename_prefix) -> io.NodeOutput:
|
||||||
return {"required": { "vae": ("VAE",),
|
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory())
|
||||||
"filename_prefix": ("STRING", {"default": "vae/ComfyUI_vae"}),},
|
|
||||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
|
|
||||||
RETURN_TYPES = ()
|
|
||||||
FUNCTION = "save"
|
|
||||||
OUTPUT_NODE = True
|
|
||||||
|
|
||||||
CATEGORY = "advanced/model_merging"
|
|
||||||
|
|
||||||
def save(self, vae, filename_prefix, prompt=None, extra_pnginfo=None):
|
|
||||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
|
|
||||||
prompt_info = ""
|
prompt_info = ""
|
||||||
if prompt is not None:
|
if cls.hidden.prompt is not None:
|
||||||
prompt_info = json.dumps(prompt)
|
prompt_info = json.dumps(cls.hidden.prompt)
|
||||||
|
|
||||||
metadata = {}
|
metadata = {}
|
||||||
if not args.disable_metadata:
|
if not args.disable_metadata:
|
||||||
metadata["prompt"] = prompt_info
|
metadata["prompt"] = prompt_info
|
||||||
if extra_pnginfo is not None:
|
if cls.hidden.extra_pnginfo is not None:
|
||||||
for x in extra_pnginfo:
|
for x in cls.hidden.extra_pnginfo:
|
||||||
metadata[x] = json.dumps(extra_pnginfo[x])
|
metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x])
|
||||||
|
|
||||||
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
|
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
|
||||||
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
|
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
|
||||||
|
|
||||||
comfy.utils.save_torch_file(vae.get_sd(), output_checkpoint, metadata=metadata)
|
comfy.utils.save_torch_file(vae.get_sd(), output_checkpoint, metadata=metadata)
|
||||||
return {}
|
return io.NodeOutput()
|
||||||
|
|
||||||
class ModelSave:
|
save = execute # TODO: remove
|
||||||
SEARCH_ALIASES = ["export model", "checkpoint save"]
|
|
||||||
def __init__(self):
|
|
||||||
self.output_dir = folder_paths.get_output_directory()
|
class ModelSave(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="ModelSave",
|
||||||
|
search_aliases=["export model", "checkpoint save"],
|
||||||
|
category="advanced/model_merging",
|
||||||
|
inputs=[
|
||||||
|
io.Model.Input("model"),
|
||||||
|
io.String.Input("filename_prefix", default="diffusion_models/ComfyUI"),
|
||||||
|
],
|
||||||
|
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
|
||||||
|
is_output_node=True,
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def execute(cls, model, filename_prefix) -> io.NodeOutput:
|
||||||
return {"required": { "model": ("MODEL",),
|
save_checkpoint(model, filename_prefix=filename_prefix, output_dir=folder_paths.get_output_directory(), prompt=cls.hidden.prompt, extra_pnginfo=cls.hidden.extra_pnginfo)
|
||||||
"filename_prefix": ("STRING", {"default": "diffusion_models/ComfyUI"}),},
|
return io.NodeOutput()
|
||||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
|
|
||||||
RETURN_TYPES = ()
|
|
||||||
FUNCTION = "save"
|
|
||||||
OUTPUT_NODE = True
|
|
||||||
|
|
||||||
CATEGORY = "advanced/model_merging"
|
save = execute # TODO: remove
|
||||||
|
|
||||||
def save(self, model, filename_prefix, prompt=None, extra_pnginfo=None):
|
|
||||||
save_checkpoint(model, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo)
|
|
||||||
return {}
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
class ModelMergingExtension(ComfyExtension):
|
||||||
"ModelMergeSimple": ModelMergeSimple,
|
@override
|
||||||
"ModelMergeBlocks": ModelMergeBlocks,
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
"ModelMergeSubtract": ModelSubtract,
|
return [
|
||||||
"ModelMergeAdd": ModelAdd,
|
ModelMergeSimple,
|
||||||
"CheckpointSave": CheckpointSave,
|
ModelMergeBlocks,
|
||||||
"CLIPMergeSimple": CLIPMergeSimple,
|
ModelSubtract,
|
||||||
"CLIPMergeSubtract": CLIPSubtract,
|
ModelAdd,
|
||||||
"CLIPMergeAdd": CLIPAdd,
|
CheckpointSave,
|
||||||
"CLIPSave": CLIPSave,
|
CLIPMergeSimple,
|
||||||
"VAESave": VAESave,
|
CLIPSubtract,
|
||||||
"ModelSave": ModelSave,
|
CLIPAdd,
|
||||||
}
|
CLIPSave,
|
||||||
|
VAESave,
|
||||||
|
ModelSave,
|
||||||
|
]
|
||||||
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
|
||||||
"CheckpointSave": "Save Checkpoint",
|
async def comfy_entrypoint() -> ModelMergingExtension:
|
||||||
}
|
return ModelMergingExtension()
|
||||||
|
|||||||
@@ -1,356 +1,455 @@
|
|||||||
import comfy_extras.nodes_model_merging
|
import comfy_extras.nodes_model_merging
|
||||||
|
|
||||||
|
from comfy_api.latest import io, ComfyExtension
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
|
|
||||||
class ModelMergeSD1(comfy_extras.nodes_model_merging.ModelMergeBlocks):
|
class ModelMergeSD1(comfy_extras.nodes_model_merging.ModelMergeBlocks):
|
||||||
CATEGORY = "advanced/model_merging/model_specific"
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
arg_dict = { "model1": ("MODEL",),
|
inputs = [
|
||||||
"model2": ("MODEL",)}
|
io.Model.Input("model1"),
|
||||||
|
io.Model.Input("model2"),
|
||||||
|
]
|
||||||
|
|
||||||
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
|
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
|
||||||
|
|
||||||
arg_dict["time_embed."] = argument
|
inputs.append(io.Float.Input("time_embed.", **argument))
|
||||||
arg_dict["label_emb."] = argument
|
inputs.append(io.Float.Input("label_emb.", **argument))
|
||||||
|
|
||||||
for i in range(12):
|
for i in range(12):
|
||||||
arg_dict["input_blocks.{}.".format(i)] = argument
|
inputs.append(io.Float.Input("input_blocks.{}.".format(i), **argument))
|
||||||
|
|
||||||
for i in range(3):
|
for i in range(3):
|
||||||
arg_dict["middle_block.{}.".format(i)] = argument
|
inputs.append(io.Float.Input("middle_block.{}.".format(i), **argument))
|
||||||
|
|
||||||
for i in range(12):
|
for i in range(12):
|
||||||
arg_dict["output_blocks.{}.".format(i)] = argument
|
inputs.append(io.Float.Input("output_blocks.{}.".format(i), **argument))
|
||||||
|
|
||||||
arg_dict["out."] = argument
|
inputs.append(io.Float.Input("out.", **argument))
|
||||||
|
|
||||||
return {"required": arg_dict}
|
return io.Schema(
|
||||||
|
node_id="ModelMergeSD1",
|
||||||
|
category="advanced/model_merging/model_specific",
|
||||||
|
inputs=inputs,
|
||||||
|
outputs=[io.Model.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelMergeSD2(ModelMergeSD1):
|
||||||
|
# SD1 and SD2 have the same blocks
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
schema = ModelMergeSD1.define_schema()
|
||||||
|
schema.node_id = "ModelMergeSD2"
|
||||||
|
return schema
|
||||||
|
|
||||||
|
|
||||||
class ModelMergeSDXL(comfy_extras.nodes_model_merging.ModelMergeBlocks):
|
class ModelMergeSDXL(comfy_extras.nodes_model_merging.ModelMergeBlocks):
|
||||||
CATEGORY = "advanced/model_merging/model_specific"
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
arg_dict = { "model1": ("MODEL",),
|
inputs = [
|
||||||
"model2": ("MODEL",)}
|
io.Model.Input("model1"),
|
||||||
|
io.Model.Input("model2"),
|
||||||
|
]
|
||||||
|
|
||||||
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
|
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
|
||||||
|
|
||||||
arg_dict["time_embed."] = argument
|
inputs.append(io.Float.Input("time_embed.", **argument))
|
||||||
arg_dict["label_emb."] = argument
|
inputs.append(io.Float.Input("label_emb.", **argument))
|
||||||
|
|
||||||
for i in range(9):
|
for i in range(9):
|
||||||
arg_dict["input_blocks.{}".format(i)] = argument
|
inputs.append(io.Float.Input("input_blocks.{}".format(i), **argument))
|
||||||
|
|
||||||
for i in range(3):
|
for i in range(3):
|
||||||
arg_dict["middle_block.{}".format(i)] = argument
|
inputs.append(io.Float.Input("middle_block.{}".format(i), **argument))
|
||||||
|
|
||||||
for i in range(9):
|
for i in range(9):
|
||||||
arg_dict["output_blocks.{}".format(i)] = argument
|
inputs.append(io.Float.Input("output_blocks.{}".format(i), **argument))
|
||||||
|
|
||||||
arg_dict["out."] = argument
|
inputs.append(io.Float.Input("out.", **argument))
|
||||||
|
|
||||||
|
return io.Schema(
|
||||||
|
node_id="ModelMergeSDXL",
|
||||||
|
category="advanced/model_merging/model_specific",
|
||||||
|
inputs=inputs,
|
||||||
|
outputs=[io.Model.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
return {"required": arg_dict}
|
|
||||||
|
|
||||||
class ModelMergeSD3_2B(comfy_extras.nodes_model_merging.ModelMergeBlocks):
|
class ModelMergeSD3_2B(comfy_extras.nodes_model_merging.ModelMergeBlocks):
|
||||||
CATEGORY = "advanced/model_merging/model_specific"
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
arg_dict = { "model1": ("MODEL",),
|
inputs = [
|
||||||
"model2": ("MODEL",)}
|
io.Model.Input("model1"),
|
||||||
|
io.Model.Input("model2"),
|
||||||
|
]
|
||||||
|
|
||||||
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
|
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
|
||||||
|
|
||||||
arg_dict["pos_embed."] = argument
|
inputs.append(io.Float.Input("pos_embed.", **argument))
|
||||||
arg_dict["x_embedder."] = argument
|
inputs.append(io.Float.Input("x_embedder.", **argument))
|
||||||
arg_dict["context_embedder."] = argument
|
inputs.append(io.Float.Input("context_embedder.", **argument))
|
||||||
arg_dict["y_embedder."] = argument
|
inputs.append(io.Float.Input("y_embedder.", **argument))
|
||||||
arg_dict["t_embedder."] = argument
|
inputs.append(io.Float.Input("t_embedder.", **argument))
|
||||||
|
|
||||||
for i in range(24):
|
for i in range(24):
|
||||||
arg_dict["joint_blocks.{}.".format(i)] = argument
|
inputs.append(io.Float.Input("joint_blocks.{}.".format(i), **argument))
|
||||||
|
|
||||||
arg_dict["final_layer."] = argument
|
inputs.append(io.Float.Input("final_layer.", **argument))
|
||||||
|
|
||||||
return {"required": arg_dict}
|
return io.Schema(
|
||||||
|
node_id="ModelMergeSD3_2B",
|
||||||
|
category="advanced/model_merging/model_specific",
|
||||||
|
inputs=inputs,
|
||||||
|
outputs=[io.Model.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ModelMergeAuraflow(comfy_extras.nodes_model_merging.ModelMergeBlocks):
|
class ModelMergeAuraflow(comfy_extras.nodes_model_merging.ModelMergeBlocks):
|
||||||
CATEGORY = "advanced/model_merging/model_specific"
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
arg_dict = { "model1": ("MODEL",),
|
inputs = [
|
||||||
"model2": ("MODEL",)}
|
io.Model.Input("model1"),
|
||||||
|
io.Model.Input("model2"),
|
||||||
|
]
|
||||||
|
|
||||||
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
|
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
|
||||||
|
|
||||||
arg_dict["init_x_linear."] = argument
|
inputs.append(io.Float.Input("init_x_linear.", **argument))
|
||||||
arg_dict["positional_encoding"] = argument
|
inputs.append(io.Float.Input("positional_encoding", **argument))
|
||||||
arg_dict["cond_seq_linear."] = argument
|
inputs.append(io.Float.Input("cond_seq_linear.", **argument))
|
||||||
arg_dict["register_tokens"] = argument
|
inputs.append(io.Float.Input("register_tokens", **argument))
|
||||||
arg_dict["t_embedder."] = argument
|
inputs.append(io.Float.Input("t_embedder.", **argument))
|
||||||
|
|
||||||
for i in range(4):
|
for i in range(4):
|
||||||
arg_dict["double_layers.{}.".format(i)] = argument
|
inputs.append(io.Float.Input("double_layers.{}.".format(i), **argument))
|
||||||
|
|
||||||
for i in range(32):
|
for i in range(32):
|
||||||
arg_dict["single_layers.{}.".format(i)] = argument
|
inputs.append(io.Float.Input("single_layers.{}.".format(i), **argument))
|
||||||
|
|
||||||
arg_dict["modF."] = argument
|
inputs.append(io.Float.Input("modF.", **argument))
|
||||||
arg_dict["final_linear."] = argument
|
inputs.append(io.Float.Input("final_linear.", **argument))
|
||||||
|
|
||||||
|
return io.Schema(
|
||||||
|
node_id="ModelMergeAuraflow",
|
||||||
|
category="advanced/model_merging/model_specific",
|
||||||
|
inputs=inputs,
|
||||||
|
outputs=[io.Model.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
return {"required": arg_dict}
|
|
||||||
|
|
||||||
class ModelMergeFlux1(comfy_extras.nodes_model_merging.ModelMergeBlocks):
|
class ModelMergeFlux1(comfy_extras.nodes_model_merging.ModelMergeBlocks):
|
||||||
CATEGORY = "advanced/model_merging/model_specific"
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
arg_dict = { "model1": ("MODEL",),
|
inputs = [
|
||||||
"model2": ("MODEL",)}
|
io.Model.Input("model1"),
|
||||||
|
io.Model.Input("model2"),
|
||||||
|
]
|
||||||
|
|
||||||
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
|
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
|
||||||
|
|
||||||
arg_dict["img_in."] = argument
|
inputs.append(io.Float.Input("img_in.", **argument))
|
||||||
arg_dict["time_in."] = argument
|
inputs.append(io.Float.Input("time_in.", **argument))
|
||||||
arg_dict["guidance_in"] = argument
|
inputs.append(io.Float.Input("guidance_in", **argument))
|
||||||
arg_dict["vector_in."] = argument
|
inputs.append(io.Float.Input("vector_in.", **argument))
|
||||||
arg_dict["txt_in."] = argument
|
inputs.append(io.Float.Input("txt_in.", **argument))
|
||||||
|
|
||||||
for i in range(19):
|
for i in range(19):
|
||||||
arg_dict["double_blocks.{}.".format(i)] = argument
|
inputs.append(io.Float.Input("double_blocks.{}.".format(i), **argument))
|
||||||
|
|
||||||
for i in range(38):
|
for i in range(38):
|
||||||
arg_dict["single_blocks.{}.".format(i)] = argument
|
inputs.append(io.Float.Input("single_blocks.{}.".format(i), **argument))
|
||||||
|
|
||||||
arg_dict["final_layer."] = argument
|
inputs.append(io.Float.Input("final_layer.", **argument))
|
||||||
|
|
||||||
|
return io.Schema(
|
||||||
|
node_id="ModelMergeFlux1",
|
||||||
|
category="advanced/model_merging/model_specific",
|
||||||
|
inputs=inputs,
|
||||||
|
outputs=[io.Model.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
return {"required": arg_dict}
|
|
||||||
|
|
||||||
class ModelMergeSD35_Large(comfy_extras.nodes_model_merging.ModelMergeBlocks):
|
class ModelMergeSD35_Large(comfy_extras.nodes_model_merging.ModelMergeBlocks):
|
||||||
CATEGORY = "advanced/model_merging/model_specific"
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
arg_dict = { "model1": ("MODEL",),
|
inputs = [
|
||||||
"model2": ("MODEL",)}
|
io.Model.Input("model1"),
|
||||||
|
io.Model.Input("model2"),
|
||||||
|
]
|
||||||
|
|
||||||
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
|
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
|
||||||
|
|
||||||
arg_dict["pos_embed."] = argument
|
inputs.append(io.Float.Input("pos_embed.", **argument))
|
||||||
arg_dict["x_embedder."] = argument
|
inputs.append(io.Float.Input("x_embedder.", **argument))
|
||||||
arg_dict["context_embedder."] = argument
|
inputs.append(io.Float.Input("context_embedder.", **argument))
|
||||||
arg_dict["y_embedder."] = argument
|
inputs.append(io.Float.Input("y_embedder.", **argument))
|
||||||
arg_dict["t_embedder."] = argument
|
inputs.append(io.Float.Input("t_embedder.", **argument))
|
||||||
|
|
||||||
for i in range(38):
|
for i in range(38):
|
||||||
arg_dict["joint_blocks.{}.".format(i)] = argument
|
inputs.append(io.Float.Input("joint_blocks.{}.".format(i), **argument))
|
||||||
|
|
||||||
arg_dict["final_layer."] = argument
|
inputs.append(io.Float.Input("final_layer.", **argument))
|
||||||
|
|
||||||
|
return io.Schema(
|
||||||
|
node_id="ModelMergeSD35_Large",
|
||||||
|
category="advanced/model_merging/model_specific",
|
||||||
|
inputs=inputs,
|
||||||
|
outputs=[io.Model.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
return {"required": arg_dict}
|
|
||||||
|
|
||||||
class ModelMergeMochiPreview(comfy_extras.nodes_model_merging.ModelMergeBlocks):
|
class ModelMergeMochiPreview(comfy_extras.nodes_model_merging.ModelMergeBlocks):
|
||||||
CATEGORY = "advanced/model_merging/model_specific"
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
arg_dict = { "model1": ("MODEL",),
|
inputs = [
|
||||||
"model2": ("MODEL",)}
|
io.Model.Input("model1"),
|
||||||
|
io.Model.Input("model2"),
|
||||||
|
]
|
||||||
|
|
||||||
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
|
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
|
||||||
|
|
||||||
arg_dict["pos_frequencies."] = argument
|
inputs.append(io.Float.Input("pos_frequencies.", **argument))
|
||||||
arg_dict["t_embedder."] = argument
|
inputs.append(io.Float.Input("t_embedder.", **argument))
|
||||||
arg_dict["t5_y_embedder."] = argument
|
inputs.append(io.Float.Input("t5_y_embedder.", **argument))
|
||||||
arg_dict["t5_yproj."] = argument
|
inputs.append(io.Float.Input("t5_yproj.", **argument))
|
||||||
|
|
||||||
for i in range(48):
|
for i in range(48):
|
||||||
arg_dict["blocks.{}.".format(i)] = argument
|
inputs.append(io.Float.Input("blocks.{}.".format(i), **argument))
|
||||||
|
|
||||||
arg_dict["final_layer."] = argument
|
inputs.append(io.Float.Input("final_layer.", **argument))
|
||||||
|
|
||||||
|
return io.Schema(
|
||||||
|
node_id="ModelMergeMochiPreview",
|
||||||
|
category="advanced/model_merging/model_specific",
|
||||||
|
inputs=inputs,
|
||||||
|
outputs=[io.Model.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
return {"required": arg_dict}
|
|
||||||
|
|
||||||
class ModelMergeLTXV(comfy_extras.nodes_model_merging.ModelMergeBlocks):
|
class ModelMergeLTXV(comfy_extras.nodes_model_merging.ModelMergeBlocks):
|
||||||
CATEGORY = "advanced/model_merging/model_specific"
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
arg_dict = { "model1": ("MODEL",),
|
inputs = [
|
||||||
"model2": ("MODEL",)}
|
io.Model.Input("model1"),
|
||||||
|
io.Model.Input("model2"),
|
||||||
|
]
|
||||||
|
|
||||||
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
|
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
|
||||||
|
|
||||||
arg_dict["patchify_proj."] = argument
|
inputs.append(io.Float.Input("patchify_proj.", **argument))
|
||||||
arg_dict["adaln_single."] = argument
|
inputs.append(io.Float.Input("adaln_single.", **argument))
|
||||||
arg_dict["caption_projection."] = argument
|
inputs.append(io.Float.Input("caption_projection.", **argument))
|
||||||
|
|
||||||
for i in range(28):
|
for i in range(28):
|
||||||
arg_dict["transformer_blocks.{}.".format(i)] = argument
|
inputs.append(io.Float.Input("transformer_blocks.{}.".format(i), **argument))
|
||||||
|
|
||||||
arg_dict["scale_shift_table"] = argument
|
inputs.append(io.Float.Input("scale_shift_table", **argument))
|
||||||
arg_dict["proj_out."] = argument
|
inputs.append(io.Float.Input("proj_out.", **argument))
|
||||||
|
|
||||||
|
return io.Schema(
|
||||||
|
node_id="ModelMergeLTXV",
|
||||||
|
category="advanced/model_merging/model_specific",
|
||||||
|
inputs=inputs,
|
||||||
|
outputs=[io.Model.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
return {"required": arg_dict}
|
|
||||||
|
|
||||||
class ModelMergeCosmos7B(comfy_extras.nodes_model_merging.ModelMergeBlocks):
|
class ModelMergeCosmos7B(comfy_extras.nodes_model_merging.ModelMergeBlocks):
|
||||||
CATEGORY = "advanced/model_merging/model_specific"
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
arg_dict = { "model1": ("MODEL",),
|
inputs = [
|
||||||
"model2": ("MODEL",)}
|
io.Model.Input("model1"),
|
||||||
|
io.Model.Input("model2"),
|
||||||
|
]
|
||||||
|
|
||||||
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
|
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
|
||||||
|
|
||||||
arg_dict["pos_embedder."] = argument
|
|
||||||
arg_dict["extra_pos_embedder."] = argument
|
|
||||||
arg_dict["x_embedder."] = argument
|
|
||||||
arg_dict["t_embedder."] = argument
|
|
||||||
arg_dict["affline_norm."] = argument
|
|
||||||
|
|
||||||
|
inputs.append(io.Float.Input("pos_embedder.", **argument))
|
||||||
|
inputs.append(io.Float.Input("extra_pos_embedder.", **argument))
|
||||||
|
inputs.append(io.Float.Input("x_embedder.", **argument))
|
||||||
|
inputs.append(io.Float.Input("t_embedder.", **argument))
|
||||||
|
inputs.append(io.Float.Input("affline_norm.", **argument))
|
||||||
|
|
||||||
for i in range(28):
|
for i in range(28):
|
||||||
arg_dict["blocks.block{}.".format(i)] = argument
|
inputs.append(io.Float.Input("blocks.block{}.".format(i), **argument))
|
||||||
|
|
||||||
arg_dict["final_layer."] = argument
|
inputs.append(io.Float.Input("final_layer.", **argument))
|
||||||
|
|
||||||
|
return io.Schema(
|
||||||
|
node_id="ModelMergeCosmos7B",
|
||||||
|
category="advanced/model_merging/model_specific",
|
||||||
|
inputs=inputs,
|
||||||
|
outputs=[io.Model.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
return {"required": arg_dict}
|
|
||||||
|
|
||||||
class ModelMergeCosmos14B(comfy_extras.nodes_model_merging.ModelMergeBlocks):
|
class ModelMergeCosmos14B(comfy_extras.nodes_model_merging.ModelMergeBlocks):
|
||||||
CATEGORY = "advanced/model_merging/model_specific"
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
arg_dict = { "model1": ("MODEL",),
|
inputs = [
|
||||||
"model2": ("MODEL",)}
|
io.Model.Input("model1"),
|
||||||
|
io.Model.Input("model2"),
|
||||||
|
]
|
||||||
|
|
||||||
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
|
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
|
||||||
|
|
||||||
arg_dict["pos_embedder."] = argument
|
|
||||||
arg_dict["extra_pos_embedder."] = argument
|
|
||||||
arg_dict["x_embedder."] = argument
|
|
||||||
arg_dict["t_embedder."] = argument
|
|
||||||
arg_dict["affline_norm."] = argument
|
|
||||||
|
|
||||||
|
inputs.append(io.Float.Input("pos_embedder.", **argument))
|
||||||
|
inputs.append(io.Float.Input("extra_pos_embedder.", **argument))
|
||||||
|
inputs.append(io.Float.Input("x_embedder.", **argument))
|
||||||
|
inputs.append(io.Float.Input("t_embedder.", **argument))
|
||||||
|
inputs.append(io.Float.Input("affline_norm.", **argument))
|
||||||
|
|
||||||
for i in range(36):
|
for i in range(36):
|
||||||
arg_dict["blocks.block{}.".format(i)] = argument
|
inputs.append(io.Float.Input("blocks.block{}.".format(i), **argument))
|
||||||
|
|
||||||
arg_dict["final_layer."] = argument
|
inputs.append(io.Float.Input("final_layer.", **argument))
|
||||||
|
|
||||||
|
return io.Schema(
|
||||||
|
node_id="ModelMergeCosmos14B",
|
||||||
|
category="advanced/model_merging/model_specific",
|
||||||
|
inputs=inputs,
|
||||||
|
outputs=[io.Model.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
return {"required": arg_dict}
|
|
||||||
|
|
||||||
class ModelMergeWAN2_1(comfy_extras.nodes_model_merging.ModelMergeBlocks):
|
class ModelMergeWAN2_1(comfy_extras.nodes_model_merging.ModelMergeBlocks):
|
||||||
CATEGORY = "advanced/model_merging/model_specific"
|
|
||||||
DESCRIPTION = "1.3B model has 30 blocks, 14B model has 40 blocks. Image to video model has the extra img_emb."
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
arg_dict = { "model1": ("MODEL",),
|
inputs = [
|
||||||
"model2": ("MODEL",)}
|
io.Model.Input("model1"),
|
||||||
|
io.Model.Input("model2"),
|
||||||
|
]
|
||||||
|
|
||||||
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
|
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
|
||||||
|
|
||||||
arg_dict["patch_embedding."] = argument
|
inputs.append(io.Float.Input("patch_embedding.", **argument))
|
||||||
arg_dict["time_embedding."] = argument
|
inputs.append(io.Float.Input("time_embedding.", **argument))
|
||||||
arg_dict["time_projection."] = argument
|
inputs.append(io.Float.Input("time_projection.", **argument))
|
||||||
arg_dict["text_embedding."] = argument
|
inputs.append(io.Float.Input("text_embedding.", **argument))
|
||||||
arg_dict["img_emb."] = argument
|
inputs.append(io.Float.Input("img_emb.", **argument))
|
||||||
|
|
||||||
for i in range(40):
|
for i in range(40):
|
||||||
arg_dict["blocks.{}.".format(i)] = argument
|
inputs.append(io.Float.Input("blocks.{}.".format(i), **argument))
|
||||||
|
|
||||||
arg_dict["head."] = argument
|
inputs.append(io.Float.Input("head.", **argument))
|
||||||
|
|
||||||
|
return io.Schema(
|
||||||
|
node_id="ModelMergeWAN2_1",
|
||||||
|
category="advanced/model_merging/model_specific",
|
||||||
|
description="1.3B model has 30 blocks, 14B model has 40 blocks. Image to video model has the extra img_emb.",
|
||||||
|
inputs=inputs,
|
||||||
|
outputs=[io.Model.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
return {"required": arg_dict}
|
|
||||||
|
|
||||||
class ModelMergeCosmosPredict2_2B(comfy_extras.nodes_model_merging.ModelMergeBlocks):
|
class ModelMergeCosmosPredict2_2B(comfy_extras.nodes_model_merging.ModelMergeBlocks):
|
||||||
CATEGORY = "advanced/model_merging/model_specific"
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
arg_dict = { "model1": ("MODEL",),
|
inputs = [
|
||||||
"model2": ("MODEL",)}
|
io.Model.Input("model1"),
|
||||||
|
io.Model.Input("model2"),
|
||||||
|
]
|
||||||
|
|
||||||
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
|
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
|
||||||
|
|
||||||
arg_dict["pos_embedder."] = argument
|
|
||||||
arg_dict["x_embedder."] = argument
|
|
||||||
arg_dict["t_embedder."] = argument
|
|
||||||
arg_dict["t_embedding_norm."] = argument
|
|
||||||
|
|
||||||
|
inputs.append(io.Float.Input("pos_embedder.", **argument))
|
||||||
|
inputs.append(io.Float.Input("x_embedder.", **argument))
|
||||||
|
inputs.append(io.Float.Input("t_embedder.", **argument))
|
||||||
|
inputs.append(io.Float.Input("t_embedding_norm.", **argument))
|
||||||
|
|
||||||
for i in range(28):
|
for i in range(28):
|
||||||
arg_dict["blocks.{}.".format(i)] = argument
|
inputs.append(io.Float.Input("blocks.{}.".format(i), **argument))
|
||||||
|
|
||||||
arg_dict["final_layer."] = argument
|
inputs.append(io.Float.Input("final_layer.", **argument))
|
||||||
|
|
||||||
|
return io.Schema(
|
||||||
|
node_id="ModelMergeCosmosPredict2_2B",
|
||||||
|
category="advanced/model_merging/model_specific",
|
||||||
|
inputs=inputs,
|
||||||
|
outputs=[io.Model.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
return {"required": arg_dict}
|
|
||||||
|
|
||||||
class ModelMergeCosmosPredict2_14B(comfy_extras.nodes_model_merging.ModelMergeBlocks):
|
class ModelMergeCosmosPredict2_14B(comfy_extras.nodes_model_merging.ModelMergeBlocks):
|
||||||
CATEGORY = "advanced/model_merging/model_specific"
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
arg_dict = { "model1": ("MODEL",),
|
inputs = [
|
||||||
"model2": ("MODEL",)}
|
io.Model.Input("model1"),
|
||||||
|
io.Model.Input("model2"),
|
||||||
|
]
|
||||||
|
|
||||||
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
|
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
|
||||||
|
|
||||||
arg_dict["pos_embedder."] = argument
|
|
||||||
arg_dict["x_embedder."] = argument
|
|
||||||
arg_dict["t_embedder."] = argument
|
|
||||||
arg_dict["t_embedding_norm."] = argument
|
|
||||||
|
|
||||||
|
inputs.append(io.Float.Input("pos_embedder.", **argument))
|
||||||
|
inputs.append(io.Float.Input("x_embedder.", **argument))
|
||||||
|
inputs.append(io.Float.Input("t_embedder.", **argument))
|
||||||
|
inputs.append(io.Float.Input("t_embedding_norm.", **argument))
|
||||||
|
|
||||||
for i in range(36):
|
for i in range(36):
|
||||||
arg_dict["blocks.{}.".format(i)] = argument
|
inputs.append(io.Float.Input("blocks.{}.".format(i), **argument))
|
||||||
|
|
||||||
arg_dict["final_layer."] = argument
|
inputs.append(io.Float.Input("final_layer.", **argument))
|
||||||
|
|
||||||
|
return io.Schema(
|
||||||
|
node_id="ModelMergeCosmosPredict2_14B",
|
||||||
|
category="advanced/model_merging/model_specific",
|
||||||
|
inputs=inputs,
|
||||||
|
outputs=[io.Model.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
return {"required": arg_dict}
|
|
||||||
|
|
||||||
class ModelMergeQwenImage(comfy_extras.nodes_model_merging.ModelMergeBlocks):
|
class ModelMergeQwenImage(comfy_extras.nodes_model_merging.ModelMergeBlocks):
|
||||||
CATEGORY = "advanced/model_merging/model_specific"
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
arg_dict = { "model1": ("MODEL",),
|
inputs = [
|
||||||
"model2": ("MODEL",)}
|
io.Model.Input("model1"),
|
||||||
|
io.Model.Input("model2"),
|
||||||
|
]
|
||||||
|
|
||||||
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
|
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
|
||||||
|
|
||||||
arg_dict["pos_embeds."] = argument
|
inputs.append(io.Float.Input("pos_embeds.", **argument))
|
||||||
arg_dict["img_in."] = argument
|
inputs.append(io.Float.Input("img_in.", **argument))
|
||||||
arg_dict["txt_norm."] = argument
|
inputs.append(io.Float.Input("txt_norm.", **argument))
|
||||||
arg_dict["txt_in."] = argument
|
inputs.append(io.Float.Input("txt_in.", **argument))
|
||||||
arg_dict["time_text_embed."] = argument
|
inputs.append(io.Float.Input("time_text_embed.", **argument))
|
||||||
|
|
||||||
for i in range(60):
|
for i in range(60):
|
||||||
arg_dict["transformer_blocks.{}.".format(i)] = argument
|
inputs.append(io.Float.Input("transformer_blocks.{}.".format(i), **argument))
|
||||||
|
|
||||||
arg_dict["proj_out."] = argument
|
inputs.append(io.Float.Input("proj_out.", **argument))
|
||||||
|
|
||||||
return {"required": arg_dict}
|
return io.Schema(
|
||||||
|
node_id="ModelMergeQwenImage",
|
||||||
|
category="advanced/model_merging/model_specific",
|
||||||
|
inputs=inputs,
|
||||||
|
outputs=[io.Model.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
|
||||||
"ModelMergeSD1": ModelMergeSD1,
|
class ModelMergingModelSpecificExtension(ComfyExtension):
|
||||||
"ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks
|
@override
|
||||||
"ModelMergeSDXL": ModelMergeSDXL,
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
"ModelMergeSD3_2B": ModelMergeSD3_2B,
|
return [
|
||||||
"ModelMergeAuraflow": ModelMergeAuraflow,
|
ModelMergeSD1,
|
||||||
"ModelMergeFlux1": ModelMergeFlux1,
|
ModelMergeSD2,
|
||||||
"ModelMergeSD35_Large": ModelMergeSD35_Large,
|
ModelMergeSDXL,
|
||||||
"ModelMergeMochiPreview": ModelMergeMochiPreview,
|
ModelMergeSD3_2B,
|
||||||
"ModelMergeLTXV": ModelMergeLTXV,
|
ModelMergeAuraflow,
|
||||||
"ModelMergeCosmos7B": ModelMergeCosmos7B,
|
ModelMergeFlux1,
|
||||||
"ModelMergeCosmos14B": ModelMergeCosmos14B,
|
ModelMergeSD35_Large,
|
||||||
"ModelMergeWAN2_1": ModelMergeWAN2_1,
|
ModelMergeMochiPreview,
|
||||||
"ModelMergeCosmosPredict2_2B": ModelMergeCosmosPredict2_2B,
|
ModelMergeLTXV,
|
||||||
"ModelMergeCosmosPredict2_14B": ModelMergeCosmosPredict2_14B,
|
ModelMergeCosmos7B,
|
||||||
"ModelMergeQwenImage": ModelMergeQwenImage,
|
ModelMergeCosmos14B,
|
||||||
}
|
ModelMergeWAN2_1,
|
||||||
|
ModelMergeCosmosPredict2_2B,
|
||||||
|
ModelMergeCosmosPredict2_14B,
|
||||||
|
ModelMergeQwenImage,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> ModelMergingModelSpecificExtension:
|
||||||
|
return ModelMergingModelSpecificExtension()
|
||||||
|
|||||||
@@ -1,127 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import hashlib
|
|
||||||
import os
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
import folder_paths
|
|
||||||
import node_helpers
|
|
||||||
from comfy_api.latest import ComfyExtension, io, UI
|
|
||||||
from typing_extensions import override
|
|
||||||
|
|
||||||
|
|
||||||
def hex_to_rgb(hex_color: str) -> tuple[float, float, float]:
|
|
||||||
hex_color = hex_color.lstrip("#")
|
|
||||||
if len(hex_color) != 6:
|
|
||||||
return (0.0, 0.0, 0.0)
|
|
||||||
r = int(hex_color[0:2], 16) / 255.0
|
|
||||||
g = int(hex_color[2:4], 16) / 255.0
|
|
||||||
b = int(hex_color[4:6], 16) / 255.0
|
|
||||||
return (r, g, b)
|
|
||||||
|
|
||||||
|
|
||||||
class PainterNode(io.ComfyNode):
|
|
||||||
@classmethod
|
|
||||||
def define_schema(cls):
|
|
||||||
return io.Schema(
|
|
||||||
node_id="Painter",
|
|
||||||
display_name="Painter",
|
|
||||||
category="image",
|
|
||||||
inputs=[
|
|
||||||
io.Image.Input(
|
|
||||||
"image",
|
|
||||||
optional=True,
|
|
||||||
tooltip="Optional base image to paint over",
|
|
||||||
),
|
|
||||||
io.String.Input(
|
|
||||||
"mask",
|
|
||||||
default="",
|
|
||||||
socketless=True,
|
|
||||||
extra_dict={"widgetType": "PAINTER", "image_upload": True},
|
|
||||||
),
|
|
||||||
io.Int.Input(
|
|
||||||
"width",
|
|
||||||
default=512,
|
|
||||||
min=64,
|
|
||||||
max=4096,
|
|
||||||
step=64,
|
|
||||||
socketless=True,
|
|
||||||
extra_dict={"hidden": True},
|
|
||||||
),
|
|
||||||
io.Int.Input(
|
|
||||||
"height",
|
|
||||||
default=512,
|
|
||||||
min=64,
|
|
||||||
max=4096,
|
|
||||||
step=64,
|
|
||||||
socketless=True,
|
|
||||||
extra_dict={"hidden": True},
|
|
||||||
),
|
|
||||||
io.Color.Input("bg_color", default="#000000"),
|
|
||||||
],
|
|
||||||
outputs=[
|
|
||||||
io.Image.Output("IMAGE"),
|
|
||||||
io.Mask.Output("MASK"),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def execute(cls, mask, width, height, bg_color="#000000", image=None) -> io.NodeOutput:
|
|
||||||
if image is not None:
|
|
||||||
base_image = image[:1]
|
|
||||||
h, w = base_image.shape[1], base_image.shape[2]
|
|
||||||
else:
|
|
||||||
h, w = height, width
|
|
||||||
r, g, b = hex_to_rgb(bg_color)
|
|
||||||
base_image = torch.zeros((1, h, w, 3), dtype=torch.float32)
|
|
||||||
base_image[0, :, :, 0] = r
|
|
||||||
base_image[0, :, :, 1] = g
|
|
||||||
base_image[0, :, :, 2] = b
|
|
||||||
|
|
||||||
if mask and mask.strip():
|
|
||||||
mask_path = folder_paths.get_annotated_filepath(mask)
|
|
||||||
painter_img = node_helpers.pillow(Image.open, mask_path)
|
|
||||||
painter_img = painter_img.convert("RGBA")
|
|
||||||
|
|
||||||
if painter_img.size != (w, h):
|
|
||||||
painter_img = painter_img.resize((w, h), Image.LANCZOS)
|
|
||||||
|
|
||||||
painter_np = np.array(painter_img).astype(np.float32) / 255.0
|
|
||||||
painter_rgb = painter_np[:, :, :3]
|
|
||||||
painter_alpha = painter_np[:, :, 3:4]
|
|
||||||
|
|
||||||
mask_tensor = torch.from_numpy(painter_np[:, :, 3]).unsqueeze(0)
|
|
||||||
|
|
||||||
base_np = base_image[0].cpu().numpy()
|
|
||||||
composited = painter_rgb * painter_alpha + base_np * (1.0 - painter_alpha)
|
|
||||||
out_image = torch.from_numpy(composited).unsqueeze(0)
|
|
||||||
else:
|
|
||||||
mask_tensor = torch.zeros((1, h, w), dtype=torch.float32)
|
|
||||||
out_image = base_image
|
|
||||||
|
|
||||||
return io.NodeOutput(out_image, mask_tensor, ui=UI.PreviewImage(out_image))
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def fingerprint_inputs(cls, mask, width, height, bg_color="#000000", image=None):
|
|
||||||
if mask and mask.strip():
|
|
||||||
mask_path = folder_paths.get_annotated_filepath(mask)
|
|
||||||
if os.path.exists(mask_path):
|
|
||||||
m = hashlib.sha256()
|
|
||||||
with open(mask_path, "rb") as f:
|
|
||||||
m.update(f.read())
|
|
||||||
return m.digest().hex()
|
|
||||||
return ""
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class PainterExtension(ComfyExtension):
|
|
||||||
@override
|
|
||||||
async def get_node_list(self):
|
|
||||||
return [PainterNode]
|
|
||||||
|
|
||||||
|
|
||||||
async def comfy_entrypoint():
|
|
||||||
return PainterExtension()
|
|
||||||
@@ -6,44 +6,62 @@ import folder_paths
|
|||||||
import comfy_extras.nodes_model_merging
|
import comfy_extras.nodes_model_merging
|
||||||
import node_helpers
|
import node_helpers
|
||||||
|
|
||||||
|
from comfy_api.latest import io, ComfyExtension
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
class ImageOnlyCheckpointLoader:
|
|
||||||
|
class ImageOnlyCheckpointLoader(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
|
return io.Schema(
|
||||||
}}
|
node_id="ImageOnlyCheckpointLoader",
|
||||||
RETURN_TYPES = ("MODEL", "CLIP_VISION", "VAE")
|
display_name="Image Only Checkpoint Loader (img2vid model)",
|
||||||
FUNCTION = "load_checkpoint"
|
category="loaders/video_models",
|
||||||
|
inputs=[
|
||||||
|
io.Combo.Input("ckpt_name", options=folder_paths.get_filename_list("checkpoints")),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Model.Output(),
|
||||||
|
io.ClipVision.Output(),
|
||||||
|
io.Vae.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "loaders/video_models"
|
@classmethod
|
||||||
|
def execute(cls, ckpt_name, output_vae=True, output_clip=True) -> io.NodeOutput:
|
||||||
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
|
|
||||||
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
|
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
|
||||||
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||||
return (out[0], out[3], out[2])
|
return io.NodeOutput(out[0], out[3], out[2])
|
||||||
|
|
||||||
|
load_checkpoint = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
class SVD_img2vid_Conditioning:
|
class SVD_img2vid_Conditioning(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": { "clip_vision": ("CLIP_VISION",),
|
return io.Schema(
|
||||||
"init_image": ("IMAGE",),
|
node_id="SVD_img2vid_Conditioning",
|
||||||
"vae": ("VAE",),
|
category="conditioning/video_models",
|
||||||
"width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
|
inputs=[
|
||||||
"height": ("INT", {"default": 576, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
|
io.ClipVision.Input("clip_vision"),
|
||||||
"video_frames": ("INT", {"default": 14, "min": 1, "max": 4096}),
|
io.Image.Input("init_image"),
|
||||||
"motion_bucket_id": ("INT", {"default": 127, "min": 1, "max": 1023, "advanced": True}),
|
io.Vae.Input("vae"),
|
||||||
"fps": ("INT", {"default": 6, "min": 1, "max": 1024}),
|
io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=8),
|
||||||
"augmentation_level": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.01, "advanced": True})
|
io.Int.Input("height", default=576, min=16, max=nodes.MAX_RESOLUTION, step=8),
|
||||||
}}
|
io.Int.Input("video_frames", default=14, min=1, max=4096),
|
||||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
|
io.Int.Input("motion_bucket_id", default=127, min=1, max=1023, advanced=True),
|
||||||
RETURN_NAMES = ("positive", "negative", "latent")
|
io.Int.Input("fps", default=6, min=1, max=1024),
|
||||||
|
io.Float.Input("augmentation_level", default=0.0, min=0.0, max=10.0, step=0.01, advanced=True),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Conditioning.Output(display_name="positive"),
|
||||||
|
io.Conditioning.Output(display_name="negative"),
|
||||||
|
io.Latent.Output(display_name="latent"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
FUNCTION = "encode"
|
@classmethod
|
||||||
|
def execute(cls, clip_vision, init_image, vae, width, height, video_frames, motion_bucket_id, fps, augmentation_level) -> io.NodeOutput:
|
||||||
CATEGORY = "conditioning/video_models"
|
|
||||||
|
|
||||||
def encode(self, clip_vision, init_image, vae, width, height, video_frames, motion_bucket_id, fps, augmentation_level):
|
|
||||||
output = clip_vision.encode_image(init_image)
|
output = clip_vision.encode_image(init_image)
|
||||||
pooled = output.image_embeds.unsqueeze(0)
|
pooled = output.image_embeds.unsqueeze(0)
|
||||||
pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
|
pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
|
||||||
@@ -54,20 +72,28 @@ class SVD_img2vid_Conditioning:
|
|||||||
positive = [[pooled, {"motion_bucket_id": motion_bucket_id, "fps": fps, "augmentation_level": augmentation_level, "concat_latent_image": t}]]
|
positive = [[pooled, {"motion_bucket_id": motion_bucket_id, "fps": fps, "augmentation_level": augmentation_level, "concat_latent_image": t}]]
|
||||||
negative = [[torch.zeros_like(pooled), {"motion_bucket_id": motion_bucket_id, "fps": fps, "augmentation_level": augmentation_level, "concat_latent_image": torch.zeros_like(t)}]]
|
negative = [[torch.zeros_like(pooled), {"motion_bucket_id": motion_bucket_id, "fps": fps, "augmentation_level": augmentation_level, "concat_latent_image": torch.zeros_like(t)}]]
|
||||||
latent = torch.zeros([video_frames, 4, height // 8, width // 8])
|
latent = torch.zeros([video_frames, 4, height // 8, width // 8])
|
||||||
return (positive, negative, {"samples":latent})
|
return io.NodeOutput(positive, negative, {"samples":latent})
|
||||||
|
|
||||||
class VideoLinearCFGGuidance:
|
encode = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
|
class VideoLinearCFGGuidance(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": { "model": ("MODEL",),
|
return io.Schema(
|
||||||
"min_cfg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01, "advanced": True}),
|
node_id="VideoLinearCFGGuidance",
|
||||||
}}
|
category="sampling/video_models",
|
||||||
RETURN_TYPES = ("MODEL",)
|
inputs=[
|
||||||
FUNCTION = "patch"
|
io.Model.Input("model"),
|
||||||
|
io.Float.Input("min_cfg", default=1.0, min=0.0, max=100.0, step=0.5, round=0.01, advanced=True),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Model.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "sampling/video_models"
|
@classmethod
|
||||||
|
def execute(cls, model, min_cfg) -> io.NodeOutput:
|
||||||
def patch(self, model, min_cfg):
|
|
||||||
def linear_cfg(args):
|
def linear_cfg(args):
|
||||||
cond = args["cond"]
|
cond = args["cond"]
|
||||||
uncond = args["uncond"]
|
uncond = args["uncond"]
|
||||||
@@ -78,20 +104,28 @@ class VideoLinearCFGGuidance:
|
|||||||
|
|
||||||
m = model.clone()
|
m = model.clone()
|
||||||
m.set_model_sampler_cfg_function(linear_cfg)
|
m.set_model_sampler_cfg_function(linear_cfg)
|
||||||
return (m, )
|
return io.NodeOutput(m)
|
||||||
|
|
||||||
class VideoTriangleCFGGuidance:
|
patch = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
|
class VideoTriangleCFGGuidance(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": { "model": ("MODEL",),
|
return io.Schema(
|
||||||
"min_cfg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01, "advanced": True}),
|
node_id="VideoTriangleCFGGuidance",
|
||||||
}}
|
category="sampling/video_models",
|
||||||
RETURN_TYPES = ("MODEL",)
|
inputs=[
|
||||||
FUNCTION = "patch"
|
io.Model.Input("model"),
|
||||||
|
io.Float.Input("min_cfg", default=1.0, min=0.0, max=100.0, step=0.5, round=0.01, advanced=True),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Model.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "sampling/video_models"
|
@classmethod
|
||||||
|
def execute(cls, model, min_cfg) -> io.NodeOutput:
|
||||||
def patch(self, model, min_cfg):
|
|
||||||
def linear_cfg(args):
|
def linear_cfg(args):
|
||||||
cond = args["cond"]
|
cond = args["cond"]
|
||||||
uncond = args["uncond"]
|
uncond = args["uncond"]
|
||||||
@@ -105,57 +139,79 @@ class VideoTriangleCFGGuidance:
|
|||||||
|
|
||||||
m = model.clone()
|
m = model.clone()
|
||||||
m.set_model_sampler_cfg_function(linear_cfg)
|
m.set_model_sampler_cfg_function(linear_cfg)
|
||||||
return (m, )
|
return io.NodeOutput(m)
|
||||||
|
|
||||||
class ImageOnlyCheckpointSave(comfy_extras.nodes_model_merging.CheckpointSave):
|
patch = execute # TODO: remove
|
||||||
CATEGORY = "advanced/model_merging"
|
|
||||||
|
|
||||||
|
class ImageOnlyCheckpointSave(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="ImageOnlyCheckpointSave",
|
||||||
|
search_aliases=["save model", "export checkpoint", "merge save"],
|
||||||
|
category="advanced/model_merging",
|
||||||
|
inputs=[
|
||||||
|
io.Model.Input("model"),
|
||||||
|
io.ClipVision.Input("clip_vision"),
|
||||||
|
io.Vae.Input("vae"),
|
||||||
|
io.String.Input("filename_prefix", default="checkpoints/ComfyUI"),
|
||||||
|
],
|
||||||
|
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
|
||||||
|
is_output_node=True,
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def execute(cls, model, clip_vision, vae, filename_prefix) -> io.NodeOutput:
|
||||||
return {"required": { "model": ("MODEL",),
|
comfy_extras.nodes_model_merging.save_checkpoint(model, clip_vision=clip_vision, vae=vae, filename_prefix=filename_prefix, output_dir=folder_paths.get_output_directory(), prompt=cls.hidden.prompt, extra_pnginfo=cls.hidden.extra_pnginfo)
|
||||||
"clip_vision": ("CLIP_VISION",),
|
return io.NodeOutput()
|
||||||
"vae": ("VAE",),
|
|
||||||
"filename_prefix": ("STRING", {"default": "checkpoints/ComfyUI"}),},
|
|
||||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
|
|
||||||
|
|
||||||
def save(self, model, clip_vision, vae, filename_prefix, prompt=None, extra_pnginfo=None):
|
save = execute # TODO: remove
|
||||||
comfy_extras.nodes_model_merging.save_checkpoint(model, clip_vision=clip_vision, vae=vae, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo)
|
|
||||||
return {}
|
|
||||||
|
|
||||||
|
|
||||||
class ConditioningSetAreaPercentageVideo:
|
class ConditioningSetAreaPercentageVideo(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": {"conditioning": ("CONDITIONING", ),
|
return io.Schema(
|
||||||
"width": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}),
|
node_id="ConditioningSetAreaPercentageVideo",
|
||||||
"height": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}),
|
category="conditioning",
|
||||||
"temporal": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}),
|
inputs=[
|
||||||
"x": ("FLOAT", {"default": 0, "min": 0, "max": 1.0, "step": 0.01}),
|
io.Conditioning.Input("conditioning"),
|
||||||
"y": ("FLOAT", {"default": 0, "min": 0, "max": 1.0, "step": 0.01}),
|
io.Float.Input("width", default=1.0, min=0.0, max=1.0, step=0.01),
|
||||||
"z": ("FLOAT", {"default": 0, "min": 0, "max": 1.0, "step": 0.01}),
|
io.Float.Input("height", default=1.0, min=0.0, max=1.0, step=0.01),
|
||||||
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
io.Float.Input("temporal", default=1.0, min=0.0, max=1.0, step=0.01),
|
||||||
}}
|
io.Float.Input("x", default=0.0, min=0.0, max=1.0, step=0.01),
|
||||||
RETURN_TYPES = ("CONDITIONING",)
|
io.Float.Input("y", default=0.0, min=0.0, max=1.0, step=0.01),
|
||||||
FUNCTION = "append"
|
io.Float.Input("z", default=0.0, min=0.0, max=1.0, step=0.01),
|
||||||
|
io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Conditioning.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "conditioning"
|
@classmethod
|
||||||
|
def execute(cls, conditioning, width, height, temporal, x, y, z, strength) -> io.NodeOutput:
|
||||||
def append(self, conditioning, width, height, temporal, x, y, z, strength):
|
|
||||||
c = node_helpers.conditioning_set_values(conditioning, {"area": ("percentage", temporal, height, width, z, y, x),
|
c = node_helpers.conditioning_set_values(conditioning, {"area": ("percentage", temporal, height, width, z, y, x),
|
||||||
"strength": strength,
|
"strength": strength,
|
||||||
"set_area_to_bounds": False})
|
"set_area_to_bounds": False})
|
||||||
return (c, )
|
return io.NodeOutput(c)
|
||||||
|
|
||||||
|
append = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
class VideoModelExtension(ComfyExtension):
|
||||||
"ImageOnlyCheckpointLoader": ImageOnlyCheckpointLoader,
|
@override
|
||||||
"SVD_img2vid_Conditioning": SVD_img2vid_Conditioning,
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
"VideoLinearCFGGuidance": VideoLinearCFGGuidance,
|
return [
|
||||||
"VideoTriangleCFGGuidance": VideoTriangleCFGGuidance,
|
ImageOnlyCheckpointLoader,
|
||||||
"ImageOnlyCheckpointSave": ImageOnlyCheckpointSave,
|
SVD_img2vid_Conditioning,
|
||||||
"ConditioningSetAreaPercentageVideo": ConditioningSetAreaPercentageVideo,
|
VideoLinearCFGGuidance,
|
||||||
}
|
VideoTriangleCFGGuidance,
|
||||||
|
ImageOnlyCheckpointSave,
|
||||||
|
ConditioningSetAreaPercentageVideo,
|
||||||
|
]
|
||||||
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
|
||||||
"ImageOnlyCheckpointLoader": "Image Only Checkpoint Loader (img2vid model)",
|
async def comfy_entrypoint() -> VideoModelExtension:
|
||||||
}
|
return VideoModelExtension()
|
||||||
|
|||||||
1
nodes.py
1
nodes.py
@@ -2450,7 +2450,6 @@ async def init_builtin_extra_nodes():
|
|||||||
"nodes_nag.py",
|
"nodes_nag.py",
|
||||||
"nodes_sdpose.py",
|
"nodes_sdpose.py",
|
||||||
"nodes_math.py",
|
"nodes_math.py",
|
||||||
"nodes_painter.py",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
import_failed = []
|
import_failed = []
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
comfyui-frontend-package==1.41.16
|
comfyui-frontend-package==1.39.19
|
||||||
comfyui-workflow-templates==0.9.18
|
comfyui-workflow-templates==0.9.18
|
||||||
comfyui-embedded-docs==0.4.3
|
comfyui-embedded-docs==0.4.3
|
||||||
torch
|
torch
|
||||||
@@ -22,8 +22,8 @@ alembic
|
|||||||
SQLAlchemy
|
SQLAlchemy
|
||||||
filelock
|
filelock
|
||||||
av>=14.2.0
|
av>=14.2.0
|
||||||
comfy-kitchen>=0.2.8
|
comfy-kitchen>=0.2.7
|
||||||
comfy-aimdo>=0.2.10
|
comfy-aimdo>=0.2.9
|
||||||
requests
|
requests
|
||||||
simpleeval>=1.0.0
|
simpleeval>=1.0.0
|
||||||
blake3
|
blake3
|
||||||
|
|||||||
Reference in New Issue
Block a user