Compare commits

..

5 Commits

Author SHA1 Message Date
Terry Jia
040e2199ce refactor: move CurveEditor to comfy_extras/nodes_curve.py with V3 schema 2026-03-17 20:12:57 -04:00
Terry Jia
e92aec7f88 linear curve 2026-03-17 20:12:56 -04:00
Christian Byrne
2d54a52b46 feat: add CurveInput ABC with MonotoneCubicCurve implementation (#12986)
CurveInput is an abstract base class so future curve representations
(bezier, LUT-based, analytical functions) can be added without breaking
downstream nodes that type-check against CurveInput.

MonotoneCubicCurve is the concrete implementation that:
- Mirrors frontend createMonotoneInterpolator (curveUtils.ts) exactly
- Pre-computes slopes as numpy arrays at construction time
- Provides vectorised interp_array() using numpy for batch evaluation
- interp() for single-value evaluation
- to_lut() for generating lookup tables

CurveEditor node wraps raw widget points in MonotoneCubicCurve.
2026-03-17 20:12:55 -04:00
Terry Jia
4e1952ec17 remove curve to sigmas node 2026-03-17 20:12:54 -04:00
Terry Jia
b279221c29 CURVE node 2026-03-17 20:12:54 -04:00
15 changed files with 273 additions and 61 deletions

View File

@@ -136,7 +136,16 @@ class ResBlock(nn.Module):
ops.Linear(c_hidden, c),
)
self.gammas = nn.Parameter(torch.zeros(6), requires_grad=False)
self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True)
# Init weights
def _basic_init(module):
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
def _norm(self, x, norm):
return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

View File

@@ -1003,7 +1003,7 @@ def text_encoder_offload_device():
def text_encoder_device():
if args.gpu_only:
return get_torch_device()
elif vram_state in (VRAMState.HIGH_VRAM, VRAMState.NORMAL_VRAM, VRAMState.SHARED) or comfy.memory_management.aimdo_enabled:
elif vram_state in (VRAMState.HIGH_VRAM, VRAMState.NORMAL_VRAM) or comfy.memory_management.aimdo_enabled:
if should_use_fp16(prioritize_performance=False):
return get_torch_device()
else:

View File

@@ -455,7 +455,7 @@ class VAE:
self.output_channels = 3
self.pad_channel_value = None
self.process_input = lambda image: image * 2.0 - 1.0
self.process_output = lambda image: image.add_(1.0).div_(2.0).clamp_(0.0, 1.0)
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
self.working_dtypes = [torch.bfloat16, torch.float32]
self.disable_offload = False
self.not_video = False
@@ -952,8 +952,8 @@ class VAE:
batch_number = max(1, batch_number)
for x in range(0, samples_in.shape[0], batch_number):
samples = samples_in[x:x + batch_number].to(device=self.device, dtype=self.vae_dtype)
out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True))
samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(self.output_device).to(dtype=self.vae_output_dtype()))
if pixel_samples is None:
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
pixel_samples[x:x+batch_number] = out

View File

@@ -5,6 +5,9 @@ from comfy_api.latest._input import (
MaskInput,
LatentInput,
VideoInput,
CurveInput,
MonotoneCubicCurve,
LinearCurve,
)
__all__ = [
@@ -13,4 +16,7 @@ __all__ = [
"MaskInput",
"LatentInput",
"VideoInput",
"CurveInput",
"MonotoneCubicCurve",
"LinearCurve",
]

View File

@@ -1,4 +1,4 @@
from .basic_types import ImageInput, AudioInput, MaskInput, LatentInput
from .basic_types import ImageInput, AudioInput, MaskInput, LatentInput, CurveInput, MonotoneCubicCurve, LinearCurve
from .video_types import VideoInput
__all__ = [
@@ -7,4 +7,7 @@ __all__ = [
"VideoInput",
"MaskInput",
"LatentInput",
"CurveInput",
"MonotoneCubicCurve",
"LinearCurve",
]

View File

@@ -1,3 +1,8 @@
from __future__ import annotations
import math
from abc import ABC, abstractmethod
import numpy as np
import torch
from typing import TypedDict, Optional
@@ -40,3 +45,190 @@ class LatentInput(TypedDict):
"""
batch_index: Optional[list[int]]
CurvePoint = tuple[float, float]
class CurveInput(ABC):
"""Abstract base class for curve inputs.
Subclasses represent different curve representations (control-point
interpolation, analytical functions, LUT-based, etc.) while exposing a
uniform evaluation interface to downstream nodes.
"""
@property
@abstractmethod
def points(self) -> list[CurvePoint]:
"""The control points that define this curve."""
@abstractmethod
def interp(self, x: float) -> float:
"""Evaluate the curve at a single *x* value in [0, 1]."""
def interp_array(self, xs: np.ndarray) -> np.ndarray:
"""Vectorised evaluation over a numpy array of x values.
Subclasses should override this for better performance. The default
falls back to scalar ``interp`` calls.
"""
return np.fromiter((self.interp(float(x)) for x in xs), dtype=np.float64, count=len(xs))
def to_lut(self, size: int = 256) -> np.ndarray:
"""Generate a float64 lookup table of *size* evenly-spaced samples in [0, 1]."""
return self.interp_array(np.linspace(0.0, 1.0, size))
class MonotoneCubicCurve(CurveInput):
"""Monotone cubic Hermite interpolation over control points.
Mirrors the frontend ``createMonotoneInterpolator`` in
``ComfyUI_frontend/src/components/curve/curveUtils.ts`` so that
backend evaluation matches the editor preview exactly.
All heavy work (sorting, slope computation) happens once at construction.
``interp_array`` is fully vectorised with numpy.
"""
def __init__(self, control_points: list[CurvePoint]):
sorted_pts = sorted(control_points, key=lambda p: p[0])
self._points = [(float(x), float(y)) for x, y in sorted_pts]
self._xs = np.array([p[0] for p in self._points], dtype=np.float64)
self._ys = np.array([p[1] for p in self._points], dtype=np.float64)
self._slopes = self._compute_slopes()
@property
def points(self) -> list[CurvePoint]:
return list(self._points)
def _compute_slopes(self) -> np.ndarray:
xs, ys = self._xs, self._ys
n = len(xs)
if n < 2:
return np.zeros(n, dtype=np.float64)
dx = np.diff(xs)
dy = np.diff(ys)
dx_safe = np.where(dx == 0, 1.0, dx)
deltas = np.where(dx == 0, 0.0, dy / dx_safe)
slopes = np.empty(n, dtype=np.float64)
slopes[0] = deltas[0]
slopes[-1] = deltas[-1]
for i in range(1, n - 1):
if deltas[i - 1] * deltas[i] <= 0:
slopes[i] = 0.0
else:
slopes[i] = (deltas[i - 1] + deltas[i]) / 2
for i in range(n - 1):
if deltas[i] == 0:
slopes[i] = 0.0
slopes[i + 1] = 0.0
else:
alpha = slopes[i] / deltas[i]
beta = slopes[i + 1] / deltas[i]
s = alpha * alpha + beta * beta
if s > 9:
t = 3 / math.sqrt(s)
slopes[i] = t * alpha * deltas[i]
slopes[i + 1] = t * beta * deltas[i]
return slopes
def interp(self, x: float) -> float:
xs, ys, slopes = self._xs, self._ys, self._slopes
n = len(xs)
if n == 0:
return 0.0
if n == 1:
return float(ys[0])
if x <= xs[0]:
return float(ys[0])
if x >= xs[-1]:
return float(ys[-1])
hi = int(np.searchsorted(xs, x, side='right'))
hi = min(hi, n - 1)
lo = hi - 1
dx = xs[hi] - xs[lo]
if dx == 0:
return float(ys[lo])
t = (x - xs[lo]) / dx
t2 = t * t
t3 = t2 * t
h00 = 2 * t3 - 3 * t2 + 1
h10 = t3 - 2 * t2 + t
h01 = -2 * t3 + 3 * t2
h11 = t3 - t2
return float(h00 * ys[lo] + h10 * dx * slopes[lo] + h01 * ys[hi] + h11 * dx * slopes[hi])
def interp_array(self, xs_in: np.ndarray) -> np.ndarray:
"""Fully vectorised evaluation using numpy."""
xs, ys, slopes = self._xs, self._ys, self._slopes
n = len(xs)
if n == 0:
return np.zeros_like(xs_in, dtype=np.float64)
if n == 1:
return np.full_like(xs_in, ys[0], dtype=np.float64)
hi = np.searchsorted(xs, xs_in, side='right').clip(1, n - 1)
lo = hi - 1
dx = xs[hi] - xs[lo]
dx_safe = np.where(dx == 0, 1.0, dx)
t = np.where(dx == 0, 0.0, (xs_in - xs[lo]) / dx_safe)
t2 = t * t
t3 = t2 * t
h00 = 2 * t3 - 3 * t2 + 1
h10 = t3 - 2 * t2 + t
h01 = -2 * t3 + 3 * t2
h11 = t3 - t2
result = h00 * ys[lo] + h10 * dx * slopes[lo] + h01 * ys[hi] + h11 * dx * slopes[hi]
result = np.where(xs_in <= xs[0], ys[0], result)
result = np.where(xs_in >= xs[-1], ys[-1], result)
return result
def __repr__(self) -> str:
return f"MonotoneCubicCurve(points={self._points})"
class LinearCurve(CurveInput):
"""Piecewise linear interpolation over control points.
Mirrors the frontend ``createLinearInterpolator`` in
``ComfyUI_frontend/src/components/curve/curveUtils.ts``.
"""
def __init__(self, control_points: list[CurvePoint]):
sorted_pts = sorted(control_points, key=lambda p: p[0])
self._points = [(float(x), float(y)) for x, y in sorted_pts]
self._xs = np.array([p[0] for p in self._points], dtype=np.float64)
self._ys = np.array([p[1] for p in self._points], dtype=np.float64)
@property
def points(self) -> list[CurvePoint]:
return list(self._points)
def interp(self, x: float) -> float:
xs, ys = self._xs, self._ys
n = len(xs)
if n == 0:
return 0.0
if n == 1:
return float(ys[0])
return float(np.interp(x, xs, ys))
def interp_array(self, xs_in: np.ndarray) -> np.ndarray:
if len(self._xs) == 0:
return np.zeros_like(xs_in, dtype=np.float64)
if len(self._xs) == 1:
return np.full_like(xs_in, self._ys[0], dtype=np.float64)
return np.interp(xs_in, self._xs, self._ys)
def __repr__(self) -> str:
return f"LinearCurve(points={self._points})"

View File

@@ -23,7 +23,7 @@ if TYPE_CHECKING:
from comfy.samplers import CFGGuider, Sampler
from comfy.sd import CLIP, VAE
from comfy.sd import StyleModel as StyleModel_
from comfy_api.input import VideoInput
from comfy_api.input import VideoInput, CurveInput as CurveInput_
from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
prune_dict, shallow_clone_class)
from comfy_execution.graph_utils import ExecutionBlocker
@@ -1243,7 +1243,8 @@ class BoundingBox(ComfyTypeIO):
@comfytype(io_type="CURVE")
class Curve(ComfyTypeIO):
CurvePoint = tuple[float, float]
Type = list[CurvePoint]
if TYPE_CHECKING:
Type = CurveInput_
class Input(WidgetInput):
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
@@ -1252,6 +1253,12 @@ class Curve(ComfyTypeIO):
if default is None:
self.default = [(0.0, 0.0), (1.0, 1.0)]
def as_dict(self):
d = super().as_dict()
if self.default is not None:
d["default"] = {"points": [list(p) for p in self.default], "interpolation": "monotone_cubic"}
return d
DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}
def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
@@ -1353,7 +1360,6 @@ class NodeInfoV1:
python_module: Any=None
category: str=None
output_node: bool=None
has_intermediate_output: bool=None
deprecated: bool=None
experimental: bool=None
dev_only: bool=None
@@ -1466,16 +1472,6 @@ class Schema:
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#output-node
"""
has_intermediate_output: bool=False
"""Flags this node as having intermediate output that should persist across page refreshes.
Nodes with this flag behave like output nodes (their UI results are cached and resent
to the frontend) but do NOT automatically get added to the execution list. This means
they will only execute if they are on the dependency path of a real output node.
Use this for nodes with interactive/operable UI regions that produce intermediate outputs
(e.g., Image Crop, Painter) rather than final outputs (e.g., Save Image).
"""
is_deprecated: bool=False
"""Flags a node as deprecated, indicating to users that they should find alternatives to this node."""
is_experimental: bool=False
@@ -1593,7 +1589,6 @@ class Schema:
category=self.category,
description=self.description,
output_node=self.is_output_node,
has_intermediate_output=self.has_intermediate_output,
deprecated=self.is_deprecated,
experimental=self.is_experimental,
dev_only=self.is_dev_only,
@@ -1885,14 +1880,6 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
cls.GET_SCHEMA()
return cls._OUTPUT_NODE
_HAS_INTERMEDIATE_OUTPUT = None
@final
@classproperty
def HAS_INTERMEDIATE_OUTPUT(cls): # noqa
if cls._HAS_INTERMEDIATE_OUTPUT is None:
cls.GET_SCHEMA()
return cls._HAS_INTERMEDIATE_OUTPUT
_INPUT_IS_LIST = None
@final
@classproperty
@@ -1985,8 +1972,6 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
cls._API_NODE = schema.is_api_node
if cls._OUTPUT_NODE is None:
cls._OUTPUT_NODE = schema.is_output_node
if cls._HAS_INTERMEDIATE_OUTPUT is None:
cls._HAS_INTERMEDIATE_OUTPUT = schema.has_intermediate_output
if cls._INPUT_IS_LIST is None:
cls._INPUT_IS_LIST = schema.is_input_list
if cls._NOT_IDEMPOTENT is None:

View File

@@ -67,7 +67,6 @@ class GeminiPart(BaseModel):
inlineData: GeminiInlineData | None = Field(None)
fileData: GeminiFileData | None = Field(None)
text: str | None = Field(None)
thought: bool | None = Field(None)
class GeminiTextPart(BaseModel):

View File

@@ -63,7 +63,7 @@ GEMINI_IMAGE_2_PRICE_BADGE = IO.PriceBadge(
$m := widgets.model;
$r := widgets.resolution;
$isFlash := $contains($m, "nano banana 2");
$flashPrices := {"1k": 0.0696, "2k": 0.1014, "4k": 0.154};
$flashPrices := {"1k": 0.0696, "2k": 0.0696, "4k": 0.123};
$proPrices := {"1k": 0.134, "2k": 0.134, "4k": 0.24};
$prices := $isFlash ? $flashPrices : $proPrices;
{"type":"usd","usd": $lookup($prices, $r), "format":{"suffix":"/Image","approximate":true}}
@@ -188,12 +188,10 @@ def get_text_from_response(response: GeminiGenerateContentResponse) -> str:
return "\n".join([part.text for part in parts])
async def get_image_from_response(response: GeminiGenerateContentResponse, thought: bool = False) -> Input.Image:
async def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image:
image_tensors: list[Input.Image] = []
parts = get_parts_by_type(response, "image/*")
for part in parts:
if (part.thought is True) != thought:
continue
if part.inlineData:
image_data = base64.b64decode(part.inlineData.data)
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
@@ -933,11 +931,6 @@ class GeminiNanoBanana2(IO.ComfyNode):
outputs=[
IO.Image.Output(),
IO.String.Output(),
IO.Image.Output(
display_name="thought_image",
tooltip="First image from the model's thinking process. "
"Only available with thinking_level HIGH and IMAGE+TEXT modality.",
),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
@@ -999,11 +992,7 @@ class GeminiNanoBanana2(IO.ComfyNode):
response_model=GeminiGenerateContentResponse,
price_extractor=calculate_tokens_price,
)
return IO.NodeOutput(
await get_image_from_response(response),
get_text_from_response(response),
await get_image_from_response(response, thought=True),
)
return IO.NodeOutput(await get_image_from_response(response), get_text_from_response(response))
class GeminiExtension(ComfyExtension):

View File

@@ -118,13 +118,6 @@ class TopologicalSort:
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
return get_input_info(class_def, input_name)
def is_intermediate_output(self, node_id):
class_type = self.dynprompt.get_node(node_id)["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS.get(class_type)
if class_def is None:
return False
return hasattr(class_def, 'HAS_INTERMEDIATE_OUTPUT') and class_def.HAS_INTERMEDIATE_OUTPUT == True
def make_input_strong_link(self, to_node_id, to_input):
inputs = self.dynprompt.get_node(to_node_id)["inputs"]
if to_input not in inputs:
@@ -136,7 +129,7 @@ class TopologicalSort:
self.add_strong_link(from_node_id, from_socket, to_node_id)
def add_strong_link(self, from_node_id, from_socket, to_node_id):
if not self.is_cached(from_node_id) or self.is_intermediate_output(from_node_id):
if not self.is_cached(from_node_id):
self.add_node(from_node_id)
if to_node_id not in self.blocking[from_node_id]:
self.blocking[from_node_id][to_node_id] = {}
@@ -166,7 +159,7 @@ class TopologicalSort:
_, _, input_info = self.get_input_info(unique_id, input_name)
is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
if (include_lazy or not is_lazy):
if not self.is_cached(from_node_id) or self.is_intermediate_output(from_node_id):
if not self.is_cached(from_node_id):
node_ids.append(from_node_id)
links.append((from_node_id, from_socket, unique_id))
@@ -284,8 +277,6 @@ class ExecutionList(TopologicalSort):
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True:
return True
if hasattr(class_def, 'HAS_INTERMEDIATE_OUTPUT') and class_def.HAS_INTERMEDIATE_OUTPUT == True:
return True
return False
# If an available node is async, do that first.

View File

@@ -0,0 +1,42 @@
from __future__ import annotations
from comfy_api.latest import ComfyExtension, io
from comfy_api.input import CurveInput, MonotoneCubicCurve, LinearCurve
from typing_extensions import override
class CurveEditor(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="CurveEditor",
display_name="Curve Editor",
category="utils",
inputs=[
io.Curve.Input("curve"),
],
outputs=[
io.Curve.Output("curve"),
],
)
@classmethod
def execute(cls, curve) -> io.NodeOutput:
if isinstance(curve, CurveInput):
return io.NodeOutput(curve)
raw_points = curve["points"] if isinstance(curve, dict) else curve
points = [(float(x), float(y)) for x, y in raw_points]
interpolation = curve.get("interpolation", "monotone_cubic") if isinstance(curve, dict) else "monotone_cubic"
if interpolation == "linear":
return io.NodeOutput(LinearCurve(points))
return io.NodeOutput(MonotoneCubicCurve(points))
class CurveExtension(ComfyExtension):
@override
async def get_node_list(self):
return [CurveEditor]
async def comfy_entrypoint():
return CurveExtension()

View File

@@ -59,7 +59,6 @@ class ImageCropV2(IO.ComfyNode):
display_name="Image Crop",
category="image/transform",
essentials_category="Image Tools",
has_intermediate_output=True,
inputs=[
IO.Image.Input("image"),
IO.BoundingBox.Input("crop_region", component="ImageCrop"),

View File

@@ -30,7 +30,6 @@ class PainterNode(io.ComfyNode):
node_id="Painter",
display_name="Painter",
category="image",
has_intermediate_output=True,
inputs=[
io.Image.Input(
"image",

View File

@@ -2453,6 +2453,7 @@ async def init_builtin_extra_nodes():
"nodes_sdpose.py",
"nodes_math.py",
"nodes_painter.py",
"nodes_curve.py",
]
import_failed = []

View File

@@ -709,9 +709,6 @@ class PromptServer():
else:
info['output_node'] = False
if hasattr(obj_class, 'HAS_INTERMEDIATE_OUTPUT') and obj_class.HAS_INTERMEDIATE_OUTPUT == True:
info['has_intermediate_output'] = True
if hasattr(obj_class, 'CATEGORY'):
info['category'] = obj_class.CATEGORY