From 30c87e2a375b84208f285403ccd69d9076cda98d Mon Sep 17 00:00:00 2001 From: Terry Jia Date: Tue, 10 Feb 2026 20:26:54 -0500 Subject: [PATCH] color correct --- comfy_api/latest/_io.py | 67 ++++++++++++++ comfy_extras/nodes_color_balance.py | 78 ++++++++++++++++ comfy_extras/nodes_color_correct.py | 88 ++++++++++++++++++ comfy_extras/nodes_color_curves.py | 137 ++++++++++++++++++++++++++++ nodes.py | 3 + 5 files changed, 373 insertions(+) create mode 100644 comfy_extras/nodes_color_balance.py create mode 100644 comfy_extras/nodes_color_correct.py create mode 100644 comfy_extras/nodes_color_curves.py diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 79068fca2..7606a5404 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -1203,6 +1203,70 @@ class Color(ComfyTypeIO): def as_dict(self): return super().as_dict() +@comfytype(io_type="COLOR_CORRECT") +class ColorCorrect(ComfyTypeIO): + Type = dict + + class Input(WidgetInput): + def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, + socketless: bool=True, default: dict=None, advanced: bool=None): + super().__init__(id, display_name, optional, tooltip, None, default, socketless, None, None, None, None, advanced) + if default is None: + self.default = { + "temperature": 0, + "hue": 0, + "brightness": 0, + "contrast": 0, + "saturation": 0, + "gamma": 1.0 + } + + def as_dict(self): + return super().as_dict() + +@comfytype(io_type="COLOR_BALANCE") +class ColorBalance(ComfyTypeIO): + Type = dict + + class Input(WidgetInput): + def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, + socketless: bool=True, default: dict=None, advanced: bool=None): + super().__init__(id, display_name, optional, tooltip, None, default, socketless, None, None, None, None, advanced) + if default is None: + self.default = { + "shadows_red": 0, + "shadows_green": 0, + "shadows_blue": 0, + "midtones_red": 0, + "midtones_green": 0, + "midtones_blue": 0, + "highlights_red": 0, + "highlights_green": 0, + "highlights_blue": 0 + } + + def as_dict(self): + return super().as_dict() + +@comfytype(io_type="COLOR_CURVES") +class ColorCurves(ComfyTypeIO): + Type = dict + + class Input(WidgetInput): + def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, + socketless: bool=True, default: dict=None, advanced: bool=None): + super().__init__(id, display_name, optional, tooltip, None, default, socketless, None, None, None, None, advanced) + if default is None: + self.default = { + "rgb": [[0, 0], [1, 1]], + "red": [[0, 0], [1, 1]], + "green": [[0, 0], [1, 1]], + "blue": [[0, 0], [1, 1]] + } + + def as_dict(self): + return super().as_dict() + @comfytype(io_type="BOUNDING_BOX") class BoundingBox(ComfyTypeIO): Type = dict @@ -2141,4 +2205,7 @@ __all__ = [ "PriceBadgeDepends", "PriceBadge", "BoundingBox", + "ColorCorrect", + "ColorBalance", + "ColorCurves" ] diff --git a/comfy_extras/nodes_color_balance.py b/comfy_extras/nodes_color_balance.py new file mode 100644 index 000000000..e9ca9f12a --- /dev/null +++ b/comfy_extras/nodes_color_balance.py @@ -0,0 +1,78 @@ +from typing_extensions import override +import torch + +from comfy_api.latest import ComfyExtension, io, ui + + +def _smoothstep(edge0: float, edge1: float, x: torch.Tensor) -> torch.Tensor: + t = torch.clamp((x - edge0) / (edge1 - edge0), 0.0, 1.0) + return t * t * (3.0 - 2.0 * t) + + +class ColorBalanceNode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ColorBalance", + display_name="Color Balance", + category="image/adjustment", + inputs=[ + io.Image.Input("image"), + io.ColorBalance.Input("settings"), + ], + outputs=[ + io.Image.Output(), + ], + ) + + @classmethod + def execute(cls, image: torch.Tensor, settings: dict) -> io.NodeOutput: + shadows_red = settings.get("shadows_red", 0) + shadows_green = settings.get("shadows_green", 0) + shadows_blue = settings.get("shadows_blue", 0) + midtones_red = settings.get("midtones_red", 0) + midtones_green = settings.get("midtones_green", 0) + midtones_blue = settings.get("midtones_blue", 0) + highlights_red = settings.get("highlights_red", 0) + highlights_green = settings.get("highlights_green", 0) + highlights_blue = settings.get("highlights_blue", 0) + + result = image.clone().float() + + # Compute per-pixel luminance + luminance = ( + 0.2126 * result[..., 0] + + 0.7152 * result[..., 1] + + 0.0722 * result[..., 2] + ) + + # Compute tonal range weights + shadow_weight = 1.0 - _smoothstep(0.0, 0.5, luminance) + highlight_weight = _smoothstep(0.5, 1.0, luminance) + midtone_weight = 1.0 - shadow_weight - highlight_weight + + # Apply offsets per channel + for ch, (s, m, h) in enumerate([ + (shadows_red, midtones_red, highlights_red), + (shadows_green, midtones_green, highlights_green), + (shadows_blue, midtones_blue, highlights_blue), + ]): + offset = ( + shadow_weight * (s / 100.0) + + midtone_weight * (m / 100.0) + + highlight_weight * (h / 100.0) + ) + result[..., ch] = result[..., ch] + offset + + result = torch.clamp(result, 0, 1) + return io.NodeOutput(result, ui=ui.PreviewImage(result)) + + +class ColorBalanceExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ColorBalanceNode] + + +async def comfy_entrypoint() -> ColorBalanceExtension: + return ColorBalanceExtension() diff --git a/comfy_extras/nodes_color_correct.py b/comfy_extras/nodes_color_correct.py new file mode 100644 index 000000000..6e0c7fba0 --- /dev/null +++ b/comfy_extras/nodes_color_correct.py @@ -0,0 +1,88 @@ +from typing_extensions import override +import torch +import numpy as np + +from comfy_api.latest import ComfyExtension, io, ui + + +class ColorCorrectNode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ColorCorrect", + display_name="Color Correct", + category="image/adjustment", + inputs=[ + io.Image.Input("image"), + io.ColorCorrect.Input("settings"), + ], + outputs=[ + io.Image.Output(), + ], + ) + + @classmethod + def execute(cls, image: torch.Tensor, settings: dict) -> io.NodeOutput: + temperature = settings.get("temperature", 0) + hue = settings.get("hue", 0) + brightness = settings.get("brightness", 0) + contrast = settings.get("contrast", 0) + saturation = settings.get("saturation", 0) + gamma = settings.get("gamma", 1.0) + + result = image.clone() + + # Brightness: scale RGB values + if brightness != 0: + factor = 1.0 + brightness / 100.0 + result = result * factor + + # Contrast: adjust around midpoint + if contrast != 0: + factor = 1.0 + contrast / 100.0 + mean = result[..., :3].mean() + result[..., :3] = (result[..., :3] - mean) * factor + mean + + # Temperature: shift warm (red+) / cool (blue+) + if temperature != 0: + temp_factor = temperature / 100.0 + result[..., 0] = result[..., 0] + temp_factor * 0.1 # Red + result[..., 2] = result[..., 2] - temp_factor * 0.1 # Blue + + # Gamma correction + if gamma != 1.0: + result[..., :3] = torch.pow(torch.clamp(result[..., :3], 0, 1), 1.0 / gamma) + + # Saturation: convert to HSV-like space + if saturation != 0: + factor = 1.0 + saturation / 100.0 + gray = result[..., :3].mean(dim=-1, keepdim=True) + result[..., :3] = gray + (result[..., :3] - gray) * factor + + # Hue rotation: rotate in RGB space using rotation matrix + if hue != 0: + angle = np.radians(hue) + cos_a = np.cos(angle) + sin_a = np.sin(angle) + # Rodrigues' rotation formula around (1,1,1)/sqrt(3) axis + k = 1.0 / 3.0 + rotation = torch.tensor([ + [cos_a + k * (1 - cos_a), k * (1 - cos_a) - sin_a / np.sqrt(3), k * (1 - cos_a) + sin_a / np.sqrt(3)], + [k * (1 - cos_a) + sin_a / np.sqrt(3), cos_a + k * (1 - cos_a), k * (1 - cos_a) - sin_a / np.sqrt(3)], + [k * (1 - cos_a) - sin_a / np.sqrt(3), k * (1 - cos_a) + sin_a / np.sqrt(3), cos_a + k * (1 - cos_a)] + ], dtype=result.dtype, device=result.device) + rgb = result[..., :3] + result[..., :3] = torch.matmul(rgb, rotation.T) + + result = torch.clamp(result, 0, 1) + return io.NodeOutput(result, ui=ui.PreviewImage(result)) + + +class ColorCorrectExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ColorCorrectNode] + + +async def comfy_entrypoint() -> ColorCorrectExtension: + return ColorCorrectExtension() diff --git a/comfy_extras/nodes_color_curves.py b/comfy_extras/nodes_color_curves.py new file mode 100644 index 000000000..de9f14b87 --- /dev/null +++ b/comfy_extras/nodes_color_curves.py @@ -0,0 +1,137 @@ +from typing_extensions import override +import torch +import numpy as np + +from comfy_api.latest import ComfyExtension, io, ui + + +def _monotone_cubic_hermite(xs, ys, x_query): + """Evaluate monotone cubic Hermite interpolation at x_query points.""" + n = len(xs) + if n == 0: + return np.zeros_like(x_query) + if n == 1: + return np.full_like(x_query, ys[0]) + + # Compute slopes + deltas = np.diff(ys) / np.maximum(np.diff(xs), 1e-10) + + # Compute tangents (Fritsch-Carlson) + slopes = np.zeros(n) + slopes[0] = deltas[0] + slopes[-1] = deltas[-1] + for i in range(1, n - 1): + if deltas[i - 1] * deltas[i] <= 0: + slopes[i] = 0 + else: + slopes[i] = (deltas[i - 1] + deltas[i]) / 2 + + # Enforce monotonicity + for i in range(n - 1): + if deltas[i] == 0: + slopes[i] = 0 + slopes[i + 1] = 0 + else: + alpha = slopes[i] / deltas[i] + beta = slopes[i + 1] / deltas[i] + s = alpha ** 2 + beta ** 2 + if s > 9: + t = 3 / np.sqrt(s) + slopes[i] = t * alpha * deltas[i] + slopes[i + 1] = t * beta * deltas[i] + + # Evaluate + result = np.zeros_like(x_query, dtype=np.float64) + indices = np.searchsorted(xs, x_query, side='right') - 1 + indices = np.clip(indices, 0, n - 2) + + for i in range(n - 1): + mask = indices == i + if not np.any(mask): + continue + dx = xs[i + 1] - xs[i] + if dx == 0: + result[mask] = ys[i] + continue + t = (x_query[mask] - xs[i]) / dx + t2 = t * t + t3 = t2 * t + h00 = 2 * t3 - 3 * t2 + 1 + h10 = t3 - 2 * t2 + t + h01 = -2 * t3 + 3 * t2 + h11 = t3 - t2 + result[mask] = h00 * ys[i] + h10 * dx * slopes[i] + h01 * ys[i + 1] + h11 * dx * slopes[i + 1] + + # Clamp edges + result[x_query <= xs[0]] = ys[0] + result[x_query >= xs[-1]] = ys[-1] + + return result + + +def _build_lut(points): + """Build a 256-entry LUT from curve control points in [0,1] space.""" + if not points or len(points) < 2: + return np.arange(256, dtype=np.float64) / 255.0 + + pts = sorted(points, key=lambda p: p[0]) + xs = np.array([p[0] for p in pts], dtype=np.float64) + ys = np.array([p[1] for p in pts], dtype=np.float64) + + x_query = np.linspace(0, 1, 256) + lut = _monotone_cubic_hermite(xs, ys, x_query) + return np.clip(lut, 0, 1) + + +class ColorCurvesNode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ColorCurves", + display_name="Color Curves", + category="image/adjustment", + inputs=[ + io.Image.Input("image"), + io.ColorCurves.Input("settings"), + ], + outputs=[ + io.Image.Output(), + ], + ) + + @classmethod + def execute(cls, image: torch.Tensor, settings: dict) -> io.NodeOutput: + rgb_pts = settings.get("rgb", [[0, 0], [1, 1]]) + red_pts = settings.get("red", [[0, 0], [1, 1]]) + green_pts = settings.get("green", [[0, 0], [1, 1]]) + blue_pts = settings.get("blue", [[0, 0], [1, 1]]) + + rgb_lut = _build_lut(rgb_pts) + red_lut = _build_lut(red_pts) + green_lut = _build_lut(green_pts) + blue_lut = _build_lut(blue_pts) + + # Convert to numpy for LUT application + img_np = image.cpu().numpy().copy() + + # Apply per-channel curves then RGB master curve + for ch, ch_lut in enumerate([red_lut, green_lut, blue_lut]): + # Per-channel curve + indices = np.clip(img_np[..., ch] * 255, 0, 255).astype(np.int32) + img_np[..., ch] = ch_lut[indices] + # RGB master curve + indices = np.clip(img_np[..., ch] * 255, 0, 255).astype(np.int32) + img_np[..., ch] = rgb_lut[indices] + + result = torch.from_numpy(np.clip(img_np, 0, 1)).to(image.device, dtype=image.dtype) + return io.NodeOutput(result, ui=ui.PreviewImage(result)) + + +class ColorCurvesExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ColorCurvesNode] + + +async def comfy_entrypoint() -> ColorCurvesExtension: + return ColorCurvesExtension() diff --git a/nodes.py b/nodes.py index 91de7a9d7..9cf167b25 100644 --- a/nodes.py +++ b/nodes.py @@ -2435,6 +2435,9 @@ async def init_builtin_extra_nodes(): "nodes_lora_debug.py", "nodes_color.py", "nodes_toolkit.py", + "nodes_color_correct.py", + "nodes_color_balance.py", + "nodes_color_curves.py" ] import_failed = []