mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-08 06:39:56 +00:00
feat: add COLOR_CURVES type and ColorCurves node
This commit is contained in:
@@ -1252,6 +1252,32 @@ class Curve(ComfyTypeIO):
|
|||||||
return super().as_dict()
|
return super().as_dict()
|
||||||
|
|
||||||
|
|
||||||
|
@comfytype(io_type="COLOR_CURVES")
|
||||||
|
class ColorCurves(ComfyTypeIO):
|
||||||
|
class ColorCurvesDict(TypedDict):
|
||||||
|
rgb: list[list[float]]
|
||||||
|
red: list[list[float]]
|
||||||
|
green: list[list[float]]
|
||||||
|
blue: list[list[float]]
|
||||||
|
|
||||||
|
Type = ColorCurvesDict
|
||||||
|
|
||||||
|
class Input(WidgetInput):
|
||||||
|
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
|
||||||
|
socketless: bool=True, default: dict=None, advanced: bool=None):
|
||||||
|
super().__init__(id, display_name, optional, tooltip, None, default, socketless, None, None, None, None, advanced)
|
||||||
|
if default is None:
|
||||||
|
self.default = {
|
||||||
|
"rgb": [[0, 0], [1, 1]],
|
||||||
|
"red": [[0, 0], [1, 1]],
|
||||||
|
"green": [[0, 0], [1, 1]],
|
||||||
|
"blue": [[0, 0], [1, 1]]
|
||||||
|
}
|
||||||
|
|
||||||
|
def as_dict(self):
|
||||||
|
return super().as_dict()
|
||||||
|
|
||||||
|
|
||||||
DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}
|
DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}
|
||||||
def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
|
def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
|
||||||
DYNAMIC_INPUT_LOOKUP[io_type] = func
|
DYNAMIC_INPUT_LOOKUP[io_type] = func
|
||||||
@@ -2239,5 +2265,6 @@ __all__ = [
|
|||||||
"PriceBadge",
|
"PriceBadge",
|
||||||
"BoundingBox",
|
"BoundingBox",
|
||||||
"Curve",
|
"Curve",
|
||||||
|
"ColorCurves",
|
||||||
"NodeReplace",
|
"NodeReplace",
|
||||||
]
|
]
|
||||||
|
|||||||
137
comfy_extras/nodes_color_curves.py
Normal file
137
comfy_extras/nodes_color_curves.py
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
from typing_extensions import override
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from comfy_api.latest import ComfyExtension, io, ui
|
||||||
|
|
||||||
|
|
||||||
|
def _monotone_cubic_hermite(xs, ys, x_query):
|
||||||
|
"""Evaluate monotone cubic Hermite interpolation at x_query points."""
|
||||||
|
n = len(xs)
|
||||||
|
if n == 0:
|
||||||
|
return np.zeros_like(x_query)
|
||||||
|
if n == 1:
|
||||||
|
return np.full_like(x_query, ys[0])
|
||||||
|
|
||||||
|
# Compute slopes
|
||||||
|
deltas = np.diff(ys) / np.maximum(np.diff(xs), 1e-10)
|
||||||
|
|
||||||
|
# Compute tangents (Fritsch-Carlson)
|
||||||
|
slopes = np.zeros(n)
|
||||||
|
slopes[0] = deltas[0]
|
||||||
|
slopes[-1] = deltas[-1]
|
||||||
|
for i in range(1, n - 1):
|
||||||
|
if deltas[i - 1] * deltas[i] <= 0:
|
||||||
|
slopes[i] = 0
|
||||||
|
else:
|
||||||
|
slopes[i] = (deltas[i - 1] + deltas[i]) / 2
|
||||||
|
|
||||||
|
# Enforce monotonicity
|
||||||
|
for i in range(n - 1):
|
||||||
|
if deltas[i] == 0:
|
||||||
|
slopes[i] = 0
|
||||||
|
slopes[i + 1] = 0
|
||||||
|
else:
|
||||||
|
alpha = slopes[i] / deltas[i]
|
||||||
|
beta = slopes[i + 1] / deltas[i]
|
||||||
|
s = alpha ** 2 + beta ** 2
|
||||||
|
if s > 9:
|
||||||
|
t = 3 / np.sqrt(s)
|
||||||
|
slopes[i] = t * alpha * deltas[i]
|
||||||
|
slopes[i + 1] = t * beta * deltas[i]
|
||||||
|
|
||||||
|
# Evaluate
|
||||||
|
result = np.zeros_like(x_query, dtype=np.float64)
|
||||||
|
indices = np.searchsorted(xs, x_query, side='right') - 1
|
||||||
|
indices = np.clip(indices, 0, n - 2)
|
||||||
|
|
||||||
|
for i in range(n - 1):
|
||||||
|
mask = indices == i
|
||||||
|
if not np.any(mask):
|
||||||
|
continue
|
||||||
|
dx = xs[i + 1] - xs[i]
|
||||||
|
if dx == 0:
|
||||||
|
result[mask] = ys[i]
|
||||||
|
continue
|
||||||
|
t = (x_query[mask] - xs[i]) / dx
|
||||||
|
t2 = t * t
|
||||||
|
t3 = t2 * t
|
||||||
|
h00 = 2 * t3 - 3 * t2 + 1
|
||||||
|
h10 = t3 - 2 * t2 + t
|
||||||
|
h01 = -2 * t3 + 3 * t2
|
||||||
|
h11 = t3 - t2
|
||||||
|
result[mask] = h00 * ys[i] + h10 * dx * slopes[i] + h01 * ys[i + 1] + h11 * dx * slopes[i + 1]
|
||||||
|
|
||||||
|
# Clamp edges
|
||||||
|
result[x_query <= xs[0]] = ys[0]
|
||||||
|
result[x_query >= xs[-1]] = ys[-1]
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _build_lut(points):
|
||||||
|
"""Build a 256-entry LUT from curve control points in [0,1] space."""
|
||||||
|
if not points or len(points) < 2:
|
||||||
|
return np.arange(256, dtype=np.float64) / 255.0
|
||||||
|
|
||||||
|
pts = sorted(points, key=lambda p: p[0])
|
||||||
|
xs = np.array([p[0] for p in pts], dtype=np.float64)
|
||||||
|
ys = np.array([p[1] for p in pts], dtype=np.float64)
|
||||||
|
|
||||||
|
x_query = np.linspace(0, 1, 256)
|
||||||
|
lut = _monotone_cubic_hermite(xs, ys, x_query)
|
||||||
|
return np.clip(lut, 0, 1)
|
||||||
|
|
||||||
|
|
||||||
|
class ColorCurvesNode(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="ColorCurves",
|
||||||
|
display_name="Color Curves",
|
||||||
|
category="image/adjustment",
|
||||||
|
inputs=[
|
||||||
|
io.Image.Input("image"),
|
||||||
|
io.ColorCurves.Input("settings"),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Image.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, image: torch.Tensor, settings: dict) -> io.NodeOutput:
|
||||||
|
rgb_pts = settings.get("rgb", [[0, 0], [1, 1]])
|
||||||
|
red_pts = settings.get("red", [[0, 0], [1, 1]])
|
||||||
|
green_pts = settings.get("green", [[0, 0], [1, 1]])
|
||||||
|
blue_pts = settings.get("blue", [[0, 0], [1, 1]])
|
||||||
|
|
||||||
|
rgb_lut = _build_lut(rgb_pts)
|
||||||
|
red_lut = _build_lut(red_pts)
|
||||||
|
green_lut = _build_lut(green_pts)
|
||||||
|
blue_lut = _build_lut(blue_pts)
|
||||||
|
|
||||||
|
# Convert to numpy for LUT application
|
||||||
|
img_np = image.cpu().numpy().copy()
|
||||||
|
|
||||||
|
# Apply per-channel curves then RGB master curve.
|
||||||
|
# Index with floor(val * 256) clamped to [0, 255] to match GPU NEAREST
|
||||||
|
# texture sampling on a 256-wide LUT texture.
|
||||||
|
for ch, ch_lut in enumerate([red_lut, green_lut, blue_lut]):
|
||||||
|
indices = np.clip((img_np[..., ch] * 256).astype(np.int32), 0, 255)
|
||||||
|
img_np[..., ch] = ch_lut[indices]
|
||||||
|
indices = np.clip((img_np[..., ch] * 256).astype(np.int32), 0, 255)
|
||||||
|
img_np[..., ch] = rgb_lut[indices]
|
||||||
|
|
||||||
|
result = torch.from_numpy(np.clip(img_np, 0, 1)).to(image.device, dtype=image.dtype)
|
||||||
|
return io.NodeOutput(result, ui=ui.PreviewImage(result))
|
||||||
|
|
||||||
|
|
||||||
|
class ColorCurvesExtension(ComfyExtension):
|
||||||
|
@override
|
||||||
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
|
return [ColorCurvesNode]
|
||||||
|
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> ColorCurvesExtension:
|
||||||
|
return ColorCurvesExtension()
|
||||||
Reference in New Issue
Block a user