From e333ad459eb1108019d03455c91ebe8450636a10 Mon Sep 17 00:00:00 2001 From: Terry Jia Date: Tue, 24 Feb 2026 00:19:06 -0500 Subject: [PATCH] feat: add CurveEditor component (#8860) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary Prerequisite for upcoming native color correction nodes (ColorCurves). Reusable curve editor with monotone cubic Hermite interpolation, drag-to-add/move/delete control points, and SVG-based rendering. Includes CurvePoint type, LUT generation utility, and useCurveEditor composable for interaction logic. ## Screenshots (if applicable) https://github.com/user-attachments/assets/948352c7-bdf2-40f9-a8f0-35bc2b2f3202 ┆Issue is synchronized with this [Notion page](https://www.notion.so/PR-8860-feat-add-CurveEditor-component-and-d3-shape-dependency-3076d73d3650817f8421f98e349569d0) by [Unito](https://www.unito.io) --- src/components/curve/CurveEditor.test.ts | 113 ++++++++++++++ src/components/curve/CurveEditor.vue | 103 +++++++++++++ src/components/curve/WidgetCurve.vue | 16 ++ src/components/curve/curveUtils.test.ts | 141 ++++++++++++++++++ src/components/curve/curveUtils.ts | 120 +++++++++++++++ src/composables/useCurveEditor.ts | 140 +++++++++++++++++ src/lib/litegraph/src/types/widgets.ts | 8 + src/lib/litegraph/src/widgets/CurveWidget.ts | 16 ++ src/lib/litegraph/src/widgets/widgetMap.ts | 4 + .../widgets/composables/useCurveWidget.ts | 30 ++++ .../widgets/registry/widgetRegistry.ts | 13 +- src/schemas/nodeDef/nodeDefSchemaV2.ts | 11 ++ src/scripts/widgets.ts | 2 + 13 files changed, 716 insertions(+), 1 deletion(-) create mode 100644 src/components/curve/CurveEditor.test.ts create mode 100644 src/components/curve/CurveEditor.vue create mode 100644 src/components/curve/WidgetCurve.vue create mode 100644 src/components/curve/curveUtils.test.ts create mode 100644 src/components/curve/curveUtils.ts create mode 100644 src/composables/useCurveEditor.ts create mode 100644 src/lib/litegraph/src/widgets/CurveWidget.ts create mode 100644 src/renderer/extensions/vueNodes/widgets/composables/useCurveWidget.ts diff --git a/src/components/curve/CurveEditor.test.ts b/src/components/curve/CurveEditor.test.ts new file mode 100644 index 0000000000..f1cd2462b5 --- /dev/null +++ b/src/components/curve/CurveEditor.test.ts @@ -0,0 +1,113 @@ +import { mount } from '@vue/test-utils' +import { describe, expect, it } from 'vitest' + +import type { CurvePoint } from '@/lib/litegraph/src/types/widgets' + +import CurveEditor from './CurveEditor.vue' + +function mountEditor(points: CurvePoint[], extraProps = {}) { + return mount(CurveEditor, { + props: { modelValue: points, ...extraProps } + }) +} + +function getCurvePath(wrapper: ReturnType) { + return wrapper.find('[data-testid="curve-path"]') +} + +describe('CurveEditor', () => { + it('renders SVG with curve path', () => { + const wrapper = mountEditor([ + [0, 0], + [1, 1] + ]) + expect(wrapper.find('svg').exists()).toBe(true) + const curvePath = getCurvePath(wrapper) + expect(curvePath.exists()).toBe(true) + expect(curvePath.attributes('d')).toBeTruthy() + }) + + it('renders a circle for each control point', () => { + const wrapper = mountEditor([ + [0, 0], + [0.5, 0.7], + [1, 1] + ]) + expect(wrapper.findAll('circle')).toHaveLength(3) + }) + + it('renders histogram path when provided', () => { + const histogram = new Uint32Array(256) + for (let i = 0; i < 256; i++) histogram[i] = i + 1 + const wrapper = mountEditor( + [ + [0, 0], + [1, 1] + ], + { histogram } + ) + const histogramPath = wrapper.find('[data-testid="histogram-path"]') + expect(histogramPath.exists()).toBe(true) + expect(histogramPath.attributes('d')).toContain('M0,1') + }) + + it('does not render histogram path when not provided', () => { + const wrapper = mountEditor([ + [0, 0], + [1, 1] + ]) + expect(wrapper.find('[data-testid="histogram-path"]').exists()).toBe(false) + }) + + it('returns empty path with fewer than 2 points', () => { + const wrapper = mountEditor([[0.5, 0.5]]) + expect(getCurvePath(wrapper).attributes('d')).toBe('') + }) + + it('generates path starting with M and containing L segments', () => { + const wrapper = mountEditor([ + [0, 0], + [0.5, 0.8], + [1, 1] + ]) + const d = getCurvePath(wrapper).attributes('d')! + expect(d).toMatch(/^M/) + expect(d).toContain('L') + }) + + it('curve path only spans the x-range of control points', () => { + const wrapper = mountEditor([ + [0.2, 0.3], + [0.8, 0.9] + ]) + const d = getCurvePath(wrapper).attributes('d')! + const xValues = d + .split(/[ML]/) + .filter(Boolean) + .map((s) => parseFloat(s.split(',')[0])) + expect(Math.min(...xValues)).toBeCloseTo(0.2, 2) + expect(Math.max(...xValues)).toBeCloseTo(0.8, 2) + }) + + it('deletes a point on right-click but keeps minimum 2', async () => { + const points: CurvePoint[] = [ + [0, 0], + [0.5, 0.5], + [1, 1] + ] + const wrapper = mountEditor(points) + expect(wrapper.findAll('circle')).toHaveLength(3) + + await wrapper.findAll('circle')[1].trigger('pointerdown', { + button: 2, + pointerId: 1 + }) + expect(wrapper.findAll('circle')).toHaveLength(2) + + await wrapper.findAll('circle')[0].trigger('pointerdown', { + button: 2, + pointerId: 1 + }) + expect(wrapper.findAll('circle')).toHaveLength(2) + }) +}) diff --git a/src/components/curve/CurveEditor.vue b/src/components/curve/CurveEditor.vue new file mode 100644 index 0000000000..751d814b4e --- /dev/null +++ b/src/components/curve/CurveEditor.vue @@ -0,0 +1,103 @@ + + + diff --git a/src/components/curve/WidgetCurve.vue b/src/components/curve/WidgetCurve.vue new file mode 100644 index 0000000000..1617a44b0c --- /dev/null +++ b/src/components/curve/WidgetCurve.vue @@ -0,0 +1,16 @@ + + + diff --git a/src/components/curve/curveUtils.test.ts b/src/components/curve/curveUtils.test.ts new file mode 100644 index 0000000000..c6f6ff9db6 --- /dev/null +++ b/src/components/curve/curveUtils.test.ts @@ -0,0 +1,141 @@ +import { describe, expect, it } from 'vitest' + +import type { CurvePoint } from '@/lib/litegraph/src/types/widgets' + +import { + createMonotoneInterpolator, + curvesToLUT, + histogramToPath +} from './curveUtils' + +describe('createMonotoneInterpolator', () => { + it('returns 0 for empty points', () => { + const interpolate = createMonotoneInterpolator([]) + expect(interpolate(0.5)).toBe(0) + }) + + it('returns constant for single point', () => { + const interpolate = createMonotoneInterpolator([[0.5, 0.7]]) + expect(interpolate(0)).toBe(0.7) + expect(interpolate(1)).toBe(0.7) + }) + + it('passes through control points exactly', () => { + const points: CurvePoint[] = [ + [0, 0], + [0.5, 0.8], + [1, 1] + ] + const interpolate = createMonotoneInterpolator(points) + expect(interpolate(0)).toBeCloseTo(0, 5) + expect(interpolate(0.5)).toBeCloseTo(0.8, 5) + expect(interpolate(1)).toBeCloseTo(1, 5) + }) + + it('clamps to endpoint values outside range', () => { + const points: CurvePoint[] = [ + [0.2, 0.3], + [0.8, 0.9] + ] + const interpolate = createMonotoneInterpolator(points) + expect(interpolate(0)).toBe(0.3) + expect(interpolate(1)).toBe(0.9) + }) + + it('produces monotone output for monotone input', () => { + const points: CurvePoint[] = [ + [0, 0], + [0.25, 0.2], + [0.5, 0.5], + [0.75, 0.8], + [1, 1] + ] + const interpolate = createMonotoneInterpolator(points) + + let prev = -Infinity + for (let x = 0; x <= 1; x += 0.01) { + const y = interpolate(x) + expect(y).toBeGreaterThanOrEqual(prev) + prev = y + } + }) + + it('handles unsorted input points', () => { + const points: CurvePoint[] = [ + [1, 1], + [0, 0], + [0.5, 0.5] + ] + const interpolate = createMonotoneInterpolator(points) + expect(interpolate(0)).toBeCloseTo(0, 5) + expect(interpolate(0.5)).toBeCloseTo(0.5, 5) + expect(interpolate(1)).toBeCloseTo(1, 5) + }) +}) + +describe('curvesToLUT', () => { + it('returns a 256-entry Uint8Array', () => { + const lut = curvesToLUT([ + [0, 0], + [1, 1] + ]) + expect(lut).toBeInstanceOf(Uint8Array) + expect(lut.length).toBe(256) + }) + + it('produces identity LUT for diagonal curve', () => { + const lut = curvesToLUT([ + [0, 0], + [1, 1] + ]) + for (let i = 0; i < 256; i++) { + expect(lut[i]).toBeCloseTo(i, 0) + } + }) + + it('clamps output to [0, 255]', () => { + const lut = curvesToLUT([ + [0, 0], + [0.5, 1.5], + [1, 1] + ]) + for (let i = 0; i < 256; i++) { + expect(lut[i]).toBeGreaterThanOrEqual(0) + expect(lut[i]).toBeLessThanOrEqual(255) + } + }) +}) + +describe('histogramToPath', () => { + it('returns empty string for empty histogram', () => { + expect(histogramToPath(new Uint32Array(0))).toBe('') + }) + + it('returns empty string when all bins are zero', () => { + expect(histogramToPath(new Uint32Array(256))).toBe('') + }) + + it('returns a closed SVG path for valid histogram', () => { + const histogram = new Uint32Array(256) + for (let i = 0; i < 256; i++) histogram[i] = i + 1 + const path = histogramToPath(histogram) + expect(path).toMatch(/^M0,1/) + expect(path).toMatch(/L1,1 Z$/) + }) + + it('normalizes using 99.5th percentile to suppress outliers', () => { + const histogram = new Uint32Array(256) + for (let i = 0; i < 256; i++) histogram[i] = 100 + histogram[255] = 100000 + const path = histogramToPath(histogram) + // Most bins should map to y=0 (1 - 100/100 = 0) since + // the 99.5th percentile is 100, not the outlier 100000 + const yValues = path + .split(/[ML]/) + .filter(Boolean) + .map((s) => parseFloat(s.split(',')[1])) + .filter((y) => !isNaN(y)) + const nearZero = yValues.filter((y) => Math.abs(y) < 0.01) + expect(nearZero.length).toBeGreaterThan(200) + }) +}) diff --git a/src/components/curve/curveUtils.ts b/src/components/curve/curveUtils.ts new file mode 100644 index 0000000000..7a3f1f17ac --- /dev/null +++ b/src/components/curve/curveUtils.ts @@ -0,0 +1,120 @@ +import type { CurvePoint } from '@/lib/litegraph/src/types/widgets' + +/** + * Monotone cubic Hermite interpolation. + * Produces a smooth curve that passes through all control points + * without overshooting (monotone property). + * + * Returns a function that evaluates y for any x in [0, 1]. + */ +export function createMonotoneInterpolator( + points: CurvePoint[] +): (x: number) => number { + if (points.length === 0) return () => 0 + if (points.length === 1) return () => points[0][1] + + const sorted = [...points].sort((a, b) => a[0] - b[0]) + const n = sorted.length + const xs = sorted.map((p) => p[0]) + const ys = sorted.map((p) => p[1]) + + const deltas: number[] = [] + const slopes: number[] = [] + for (let i = 0; i < n - 1; i++) { + const dx = xs[i + 1] - xs[i] + deltas.push(dx === 0 ? 0 : (ys[i + 1] - ys[i]) / dx) + } + + slopes.push(deltas[0] ?? 0) + for (let i = 1; i < n - 1; i++) { + if (deltas[i - 1] * deltas[i] <= 0) { + slopes.push(0) + } else { + slopes.push((deltas[i - 1] + deltas[i]) / 2) + } + } + slopes.push(deltas[n - 2] ?? 0) + + for (let i = 0; i < n - 1; i++) { + if (deltas[i] === 0) { + slopes[i] = 0 + slopes[i + 1] = 0 + } else { + const alpha = slopes[i] / deltas[i] + const beta = slopes[i + 1] / deltas[i] + const s = alpha * alpha + beta * beta + if (s > 9) { + const t = 3 / Math.sqrt(s) + slopes[i] = t * alpha * deltas[i] + slopes[i + 1] = t * beta * deltas[i] + } + } + } + + return (x: number): number => { + if (x <= xs[0]) return ys[0] + if (x >= xs[n - 1]) return ys[n - 1] + + let lo = 0 + let hi = n - 1 + while (lo < hi - 1) { + const mid = (lo + hi) >> 1 + if (xs[mid] <= x) lo = mid + else hi = mid + } + + const dx = xs[hi] - xs[lo] + if (dx === 0) return ys[lo] + + const t = (x - xs[lo]) / dx + const t2 = t * t + const t3 = t2 * t + + const h00 = 2 * t3 - 3 * t2 + 1 + const h10 = t3 - 2 * t2 + t + const h01 = -2 * t3 + 3 * t2 + const h11 = t3 - t2 + + return ( + h00 * ys[lo] + + h10 * dx * slopes[lo] + + h01 * ys[hi] + + h11 * dx * slopes[hi] + ) + } +} + +/** + * Convert a 256-bin histogram into an SVG path string. + * Normalizes using the 99.5th percentile to avoid outlier spikes. + */ +export function histogramToPath(histogram: Uint32Array): string { + if (!histogram.length) return '' + + const sorted = Array.from(histogram).sort((a, b) => a - b) + const max = sorted[Math.floor(255 * 0.995)] + if (max === 0) return '' + + const step = 1 / 255 + let d = 'M0,1' + for (let i = 0; i < 256; i++) { + const x = i * step + const y = 1 - Math.min(1, histogram[i] / max) + d += ` L${x.toFixed(4)},${y.toFixed(4)}` + } + d += ' L1,1 Z' + return d +} + +export function curvesToLUT(points: CurvePoint[]): Uint8Array { + const lut = new Uint8Array(256) + const interpolate = createMonotoneInterpolator(points) + + for (let i = 0; i < 256; i++) { + const x = i / 255 + const y = interpolate(x) + lut[i] = Math.max(0, Math.min(255, Math.round(y * 255))) + } + + return lut +} diff --git a/src/composables/useCurveEditor.ts b/src/composables/useCurveEditor.ts new file mode 100644 index 0000000000..0f6029d884 --- /dev/null +++ b/src/composables/useCurveEditor.ts @@ -0,0 +1,140 @@ +import { computed, onBeforeUnmount, ref } from 'vue' +import type { Ref } from 'vue' + +import { createMonotoneInterpolator } from '@/components/curve/curveUtils' +import type { CurvePoint } from '@/lib/litegraph/src/types/widgets' + +interface UseCurveEditorOptions { + svgRef: Ref + modelValue: Ref +} + +export function useCurveEditor({ svgRef, modelValue }: UseCurveEditorOptions) { + const dragIndex = ref(-1) + let cleanupDrag: (() => void) | null = null + + const curvePath = computed(() => { + const points = modelValue.value + if (points.length < 2) return '' + + const interpolate = createMonotoneInterpolator(points) + const xMin = points[0][0] + const xMax = points[points.length - 1][0] + const segments = 128 + const parts: string[] = [] + for (let i = 0; i <= segments; i++) { + const x = xMin + (xMax - xMin) * (i / segments) + const y = 1 - interpolate(x) + parts.push(`${i === 0 ? 'M' : 'L'}${x.toFixed(4)},${y.toFixed(4)}`) + } + return parts.join('') + }) + + function svgCoords(e: PointerEvent): [number, number] { + const svg = svgRef.value + if (!svg) return [0, 0] + + const ctm = svg.getScreenCTM() + if (!ctm) return [0, 0] + + const svgPt = new DOMPoint(e.clientX, e.clientY).matrixTransform( + ctm.inverse() + ) + return [ + Math.max(0, Math.min(1, svgPt.x)), + Math.max(0, Math.min(1, 1 - svgPt.y)) + ] + } + + function findNearestPoint(x: number, y: number): number { + const threshold2 = 0.04 * 0.04 + let nearest = -1 + let minDist2 = threshold2 + for (let i = 0; i < modelValue.value.length; i++) { + const dx = modelValue.value[i][0] - x + const dy = modelValue.value[i][1] - y + const dist2 = dx * dx + dy * dy + if (dist2 < minDist2) { + minDist2 = dist2 + nearest = i + } + } + return nearest + } + + function handleSvgPointerDown(e: PointerEvent) { + if (e.button !== 0) return + + const [x, y] = svgCoords(e) + + const nearby = findNearestPoint(x, y) + if (nearby >= 0) { + startDrag(nearby, e) + return + } + + if (e.ctrlKey) return + + const newPoint: CurvePoint = [x, y] + const newPoints: CurvePoint[] = [...modelValue.value, newPoint] + newPoints.sort((a, b) => a[0] - b[0]) + modelValue.value = newPoints + + startDrag(newPoints.indexOf(newPoint), e) + } + + function startDrag(index: number, e: PointerEvent) { + cleanupDrag?.() + + if (e.button === 2 || (e.button === 0 && e.ctrlKey)) { + if (modelValue.value.length > 2) { + const newPoints = [...modelValue.value] + newPoints.splice(index, 1) + modelValue.value = newPoints + } + return + } + + dragIndex.value = index + const svg = svgRef.value + if (!svg) return + + svg.setPointerCapture(e.pointerId) + + const onMove = (ev: PointerEvent) => { + if (dragIndex.value < 0) return + const [x, y] = svgCoords(ev) + const movedPoint: CurvePoint = [x, y] + const newPoints = [...modelValue.value] + newPoints[dragIndex.value] = movedPoint + newPoints.sort((a, b) => a[0] - b[0]) + modelValue.value = newPoints + dragIndex.value = newPoints.indexOf(movedPoint) + } + + const endDrag = () => { + if (dragIndex.value < 0) return + dragIndex.value = -1 + svg.removeEventListener('pointermove', onMove) + svg.removeEventListener('pointerup', endDrag) + svg.removeEventListener('lostpointercapture', endDrag) + cleanupDrag = null + } + + cleanupDrag = endDrag + + svg.addEventListener('pointermove', onMove) + svg.addEventListener('pointerup', endDrag) + svg.addEventListener('lostpointercapture', endDrag) + } + + onBeforeUnmount(() => { + cleanupDrag?.() + }) + + return { + curvePath, + handleSvgPointerDown, + startDrag + } +} diff --git a/src/lib/litegraph/src/types/widgets.ts b/src/lib/litegraph/src/types/widgets.ts index b093ca59c1..63e08f02e3 100644 --- a/src/lib/litegraph/src/types/widgets.ts +++ b/src/lib/litegraph/src/types/widgets.ts @@ -136,6 +136,7 @@ export type IWidget = | IAssetWidget | IImageCropWidget | IBoundingBoxWidget + | ICurveWidget export interface IBooleanWidget extends IBaseWidget { type: 'toggle' @@ -328,6 +329,13 @@ export interface IBoundingBoxWidget extends IBaseWidget { value: Bounds } +export type CurvePoint = [x: number, y: number] + +export interface ICurveWidget extends IBaseWidget { + type: 'curve' + value: CurvePoint[] +} + /** * Valid widget types. TS cannot provide easily extensible type safety for this at present. * Override linkedWidgets[] diff --git a/src/lib/litegraph/src/widgets/CurveWidget.ts b/src/lib/litegraph/src/widgets/CurveWidget.ts new file mode 100644 index 0000000000..2879ae0918 --- /dev/null +++ b/src/lib/litegraph/src/widgets/CurveWidget.ts @@ -0,0 +1,16 @@ +import type { ICurveWidget } from '../types/widgets' +import { BaseWidget } from './BaseWidget' +import type { DrawWidgetOptions, WidgetEventOptions } from './BaseWidget' + +export class CurveWidget + extends BaseWidget + implements ICurveWidget +{ + override type = 'curve' as const + + drawWidget(ctx: CanvasRenderingContext2D, options: DrawWidgetOptions): void { + this.drawVueOnlyWarning(ctx, options, 'Curve') + } + + onClick(_options: WidgetEventOptions): void {} +} diff --git a/src/lib/litegraph/src/widgets/widgetMap.ts b/src/lib/litegraph/src/widgets/widgetMap.ts index c145494a19..704fdc48a7 100644 --- a/src/lib/litegraph/src/widgets/widgetMap.ts +++ b/src/lib/litegraph/src/widgets/widgetMap.ts @@ -16,6 +16,7 @@ import { ButtonWidget } from './ButtonWidget' import { ChartWidget } from './ChartWidget' import { ColorWidget } from './ColorWidget' import { ComboWidget } from './ComboWidget' +import { CurveWidget } from './CurveWidget' import { FileUploadWidget } from './FileUploadWidget' import { GalleriaWidget } from './GalleriaWidget' import { GradientSliderWidget } from './GradientSliderWidget' @@ -56,6 +57,7 @@ export type WidgetTypeMap = { asset: AssetWidget imagecrop: ImageCropWidget boundingbox: BoundingBoxWidget + curve: CurveWidget [key: string]: BaseWidget } @@ -132,6 +134,8 @@ export function toConcreteWidget( return toClass(ImageCropWidget, narrowedWidget, node) case 'boundingbox': return toClass(BoundingBoxWidget, narrowedWidget, node) + case 'curve': + return toClass(CurveWidget, narrowedWidget, node) default: { if (wrapLegacyWidgets) return toClass(LegacyWidget, widget, node) } diff --git a/src/renderer/extensions/vueNodes/widgets/composables/useCurveWidget.ts b/src/renderer/extensions/vueNodes/widgets/composables/useCurveWidget.ts new file mode 100644 index 0000000000..42e3b30006 --- /dev/null +++ b/src/renderer/extensions/vueNodes/widgets/composables/useCurveWidget.ts @@ -0,0 +1,30 @@ +import type { LGraphNode } from '@/lib/litegraph/src/litegraph' +import type { ICurveWidget } from '@/lib/litegraph/src/types/widgets' +import type { + CurveInputSpec, + InputSpec as InputSpecV2 +} from '@/schemas/nodeDef/nodeDefSchemaV2' +import type { ComfyWidgetConstructorV2 } from '@/scripts/widgets' + +export const useCurveWidget = (): ComfyWidgetConstructorV2 => { + return (node: LGraphNode, inputSpec: InputSpecV2): ICurveWidget => { + const spec = inputSpec as CurveInputSpec + const defaultValue = spec.default ?? [ + [0, 0], + [1, 1] + ] + + const rawWidget = node.addWidget( + 'curve', + spec.name, + [...defaultValue], + () => {} + ) + + if (rawWidget.type !== 'curve') { + throw new Error(`Unexpected widget type: ${rawWidget.type}`) + } + + return rawWidget as ICurveWidget + } +} diff --git a/src/renderer/extensions/vueNodes/widgets/registry/widgetRegistry.ts b/src/renderer/extensions/vueNodes/widgets/registry/widgetRegistry.ts index 5f8903008e..8d4129948f 100644 --- a/src/renderer/extensions/vueNodes/widgets/registry/widgetRegistry.ts +++ b/src/renderer/extensions/vueNodes/widgets/registry/widgetRegistry.ts @@ -57,6 +57,9 @@ const WidgetImageCrop = defineAsyncComponent( const WidgetBoundingBox = defineAsyncComponent( () => import('@/components/boundingbox/WidgetBoundingBox.vue') ) +const WidgetCurve = defineAsyncComponent( + () => import('@/components/curve/WidgetCurve.vue') +) export const FOR_TESTING = { WidgetButton, @@ -175,6 +178,14 @@ const coreWidgetDefinitions: Array<[string, WidgetDefinition]> = [ aliases: ['BOUNDING_BOX'], essential: false } + ], + [ + 'curve', + { + component: WidgetCurve, + aliases: ['CURVE'], + essential: false + } ] ] @@ -206,7 +217,7 @@ export const shouldRenderAsVue = (widget: Partial): boolean => { return !widget.options?.canvasOnly && !!widget.type } -const EXPANDING_TYPES = ['textarea', 'markdown', 'load3D'] as const +const EXPANDING_TYPES = ['textarea', 'markdown', 'load3D', 'curve'] as const export function shouldExpand(type: string): boolean { const canonicalType = getCanonicalType(type) diff --git a/src/schemas/nodeDef/nodeDefSchemaV2.ts b/src/schemas/nodeDef/nodeDefSchemaV2.ts index c9ecc88a5f..b0be538af1 100644 --- a/src/schemas/nodeDef/nodeDefSchemaV2.ts +++ b/src/schemas/nodeDef/nodeDefSchemaV2.ts @@ -126,6 +126,15 @@ const zTextareaInputSpec = zBaseInputOptions.extend({ .optional() }) +const zCurvePoint = z.tuple([z.number(), z.number()]) + +const zCurveInputSpec = zBaseInputOptions.extend({ + type: z.literal('CURVE'), + name: z.string(), + isOptional: z.boolean().optional(), + default: z.array(zCurvePoint).optional() +}) + const zCustomInputSpec = zBaseInputOptions.extend({ type: z.string(), name: z.string(), @@ -146,6 +155,7 @@ const zInputSpec = z.union([ zChartInputSpec, zGalleriaInputSpec, zTextareaInputSpec, + zCurveInputSpec, zCustomInputSpec ]) @@ -190,6 +200,7 @@ export type BoundingBoxInputSpec = z.infer export type ChartInputSpec = z.infer export type GalleriaInputSpec = z.infer export type TextareaInputSpec = z.infer +export type CurveInputSpec = z.infer export type CustomInputSpec = z.infer export type InputSpec = z.infer diff --git a/src/scripts/widgets.ts b/src/scripts/widgets.ts index 071ebff3ee..3087a00b9d 100644 --- a/src/scripts/widgets.ts +++ b/src/scripts/widgets.ts @@ -9,6 +9,7 @@ import { useSettingStore } from '@/platform/settings/settingStore' import { dynamicWidgets } from '@/core/graph/widgets/dynamicWidgets' import { useBooleanWidget } from '@/renderer/extensions/vueNodes/widgets/composables/useBooleanWidget' import { useBoundingBoxWidget } from '@/renderer/extensions/vueNodes/widgets/composables/useBoundingBoxWidget' +import { useCurveWidget } from '@/renderer/extensions/vueNodes/widgets/composables/useCurveWidget' import { useChartWidget } from '@/renderer/extensions/vueNodes/widgets/composables/useChartWidget' import { useColorWidget } from '@/renderer/extensions/vueNodes/widgets/composables/useColorWidget' import { useComboWidget } from '@/renderer/extensions/vueNodes/widgets/composables/useComboWidget' @@ -306,6 +307,7 @@ export const ComfyWidgets = { CHART: transformWidgetConstructorV2ToV1(useChartWidget()), GALLERIA: transformWidgetConstructorV2ToV1(useGalleriaWidget()), TEXTAREA: transformWidgetConstructorV2ToV1(useTextareaWidget()), + CURVE: transformWidgetConstructorV2ToV1(useCurveWidget()), ...dynamicWidgets } as const