mirror of
https://github.com/Comfy-Org/ComfyUI_frontend.git
synced 2026-04-20 14:30:41 +00:00
feat: add linear interpolation type to CURVE widget (#10118)
## Summary
Change the CURVE widget value from CurvePoint[] to CurveData ({ points,
interpolation }) to support multiple interpolation types. Add a Select
dropdown in the widget UI for switching between Smooth (monotone cubic)
and Linear interpolation, with the SVG preview updating accordingly.
- Add CurveData type with CURVE_INTERPOLATIONS const enum
- Add createLinearInterpolator with piecewise linear + binary search
- Add createInterpolator factory dispatching by interpolation type
- Add isCurveData type guard in curveUtils
- Update ICurveWidget value type to CurveData
- Add interpolation prop to CurveEditor and useCurveEditor composable
- Linear mode generates direct M...L... SVG path (no sampling)
- Add i18n entries for interpolation labels
- Add unit tests for createLinearInterpolator
BE change is https://github.com/Comfy-Org/ComfyUI/pull/12757
## Screenshots (if applicable)
<img width="1437" height="670" alt="image"
src="https://github.com/user-attachments/assets/550aedec-e5da-425b-8233-86a4f28067fa"
/>
<img width="1445" height="648" alt="image"
src="https://github.com/user-attachments/assets/0a8dc654-3f92-4ca2-9fa2-c1fef3be6d66"
/>
┆Issue is synchronized with this [Notion
page](https://www.notion.so/PR-10118-feat-add-linear-interpolation-type-to-CURVE-widget-3256d73d36508185a86edf73bb555c51)
by [Unito](https://www.unito.io)
This commit is contained in:
@@ -82,23 +82,25 @@
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { computed, useTemplateRef } from 'vue'
|
||||
import { computed, toRef, useTemplateRef } from 'vue'
|
||||
|
||||
import { useCurveEditor } from '@/composables/useCurveEditor'
|
||||
import { cn } from '@/utils/tailwindUtil'
|
||||
|
||||
import type { CurvePoint } from './types'
|
||||
import type { CurveInterpolation, CurvePoint } from './types'
|
||||
|
||||
import { histogramToPath } from './curveUtils'
|
||||
|
||||
const {
|
||||
curveColor = 'white',
|
||||
histogram,
|
||||
disabled = false
|
||||
disabled = false,
|
||||
interpolation = 'monotone_cubic'
|
||||
} = defineProps<{
|
||||
curveColor?: string
|
||||
histogram?: Uint32Array | null
|
||||
disabled?: boolean
|
||||
interpolation?: CurveInterpolation
|
||||
}>()
|
||||
|
||||
const modelValue = defineModel<CurvePoint[]>({
|
||||
@@ -109,7 +111,8 @@ const svgRef = useTemplateRef<SVGSVGElement>('svgRef')
|
||||
|
||||
const { curvePath, handleSvgPointerDown, startDrag } = useCurveEditor({
|
||||
svgRef,
|
||||
modelValue
|
||||
modelValue,
|
||||
interpolation: toRef(() => interpolation)
|
||||
})
|
||||
|
||||
function onSvgPointerDown(e: PointerEvent) {
|
||||
|
||||
@@ -1,9 +1,30 @@
|
||||
<template>
|
||||
<CurveEditor
|
||||
:model-value="effectivePoints"
|
||||
:disabled="isDisabled"
|
||||
@update:model-value="modelValue = $event"
|
||||
/>
|
||||
<div class="flex flex-col gap-1">
|
||||
<Select
|
||||
v-if="!isDisabled"
|
||||
:model-value="modelValue.interpolation"
|
||||
@update:model-value="onInterpolationChange"
|
||||
>
|
||||
<SelectTrigger size="md">
|
||||
<SelectValue />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
<SelectItem
|
||||
v-for="interp in CURVE_INTERPOLATIONS"
|
||||
:key="interp"
|
||||
:value="interp"
|
||||
>
|
||||
{{ $t(`curveWidget.${interp}`) }}
|
||||
</SelectItem>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
<CurveEditor
|
||||
:model-value="effectiveCurve.points"
|
||||
:disabled="isDisabled"
|
||||
:interpolation="effectiveCurve.interpolation"
|
||||
@update:model-value="onPointsChange"
|
||||
/>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
@@ -15,31 +36,53 @@ import {
|
||||
} from '@/composables/useUpstreamValue'
|
||||
import type { SimplifiedWidget } from '@/types/simplifiedWidget'
|
||||
|
||||
import Select from '@/components/ui/select/Select.vue'
|
||||
import SelectContent from '@/components/ui/select/SelectContent.vue'
|
||||
import SelectItem from '@/components/ui/select/SelectItem.vue'
|
||||
import SelectTrigger from '@/components/ui/select/SelectTrigger.vue'
|
||||
import SelectValue from '@/components/ui/select/SelectValue.vue'
|
||||
|
||||
import CurveEditor from './CurveEditor.vue'
|
||||
import { isCurvePointArray } from './curveUtils'
|
||||
import type { CurvePoint } from './types'
|
||||
import { isCurveData } from './curveUtils'
|
||||
import { CURVE_INTERPOLATIONS } from './types'
|
||||
import type { CurveData, CurveInterpolation, CurvePoint } from './types'
|
||||
|
||||
const { widget } = defineProps<{
|
||||
widget: SimplifiedWidget
|
||||
}>()
|
||||
|
||||
const modelValue = defineModel<CurvePoint[]>({
|
||||
default: () => [
|
||||
[0, 0],
|
||||
[1, 1]
|
||||
]
|
||||
const modelValue = defineModel<CurveData>({
|
||||
default: () => ({
|
||||
points: [
|
||||
[0, 0],
|
||||
[1, 1]
|
||||
],
|
||||
interpolation: 'monotone_cubic'
|
||||
})
|
||||
})
|
||||
|
||||
const isDisabled = computed(() => !!widget.options?.disabled)
|
||||
|
||||
const upstreamValue = useUpstreamValue(
|
||||
() => widget.linkedUpstream,
|
||||
singleValueExtractor(isCurvePointArray)
|
||||
singleValueExtractor(isCurveData)
|
||||
)
|
||||
|
||||
const effectivePoints = computed(() =>
|
||||
const effectiveCurve = computed(() =>
|
||||
isDisabled.value && upstreamValue.value
|
||||
? upstreamValue.value
|
||||
: modelValue.value
|
||||
)
|
||||
|
||||
function onPointsChange(points: CurvePoint[]) {
|
||||
modelValue.value = { ...modelValue.value, points }
|
||||
}
|
||||
|
||||
function onInterpolationChange(value: unknown) {
|
||||
if (typeof value !== 'string') return
|
||||
modelValue.value = {
|
||||
...modelValue.value,
|
||||
interpolation: value as CurveInterpolation
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
@@ -3,6 +3,7 @@ import { describe, expect, it } from 'vitest'
|
||||
import type { CurvePoint } from './types'
|
||||
|
||||
import {
|
||||
createLinearInterpolator,
|
||||
createMonotoneInterpolator,
|
||||
curvesToLUT,
|
||||
histogramToPath
|
||||
@@ -73,6 +74,64 @@ describe('createMonotoneInterpolator', () => {
|
||||
})
|
||||
})
|
||||
|
||||
describe('createLinearInterpolator', () => {
|
||||
it('returns 0 for empty points', () => {
|
||||
const interpolate = createLinearInterpolator([])
|
||||
expect(interpolate(0.5)).toBe(0)
|
||||
})
|
||||
|
||||
it('returns constant for single point', () => {
|
||||
const interpolate = createLinearInterpolator([[0.5, 0.7]])
|
||||
expect(interpolate(0)).toBe(0.7)
|
||||
expect(interpolate(1)).toBe(0.7)
|
||||
})
|
||||
|
||||
it('passes through control points exactly', () => {
|
||||
const points: CurvePoint[] = [
|
||||
[0, 0],
|
||||
[0.5, 0.8],
|
||||
[1, 1]
|
||||
]
|
||||
const interpolate = createLinearInterpolator(points)
|
||||
expect(interpolate(0)).toBe(0)
|
||||
expect(interpolate(0.5)).toBeCloseTo(0.8, 10)
|
||||
expect(interpolate(1)).toBe(1)
|
||||
})
|
||||
|
||||
it('linearly interpolates between points', () => {
|
||||
const points: CurvePoint[] = [
|
||||
[0, 0],
|
||||
[1, 1]
|
||||
]
|
||||
const interpolate = createLinearInterpolator(points)
|
||||
expect(interpolate(0.25)).toBeCloseTo(0.25, 10)
|
||||
expect(interpolate(0.5)).toBeCloseTo(0.5, 10)
|
||||
expect(interpolate(0.75)).toBeCloseTo(0.75, 10)
|
||||
})
|
||||
|
||||
it('clamps to endpoint values outside range', () => {
|
||||
const points: CurvePoint[] = [
|
||||
[0.2, 0.3],
|
||||
[0.8, 0.9]
|
||||
]
|
||||
const interpolate = createLinearInterpolator(points)
|
||||
expect(interpolate(0)).toBe(0.3)
|
||||
expect(interpolate(1)).toBe(0.9)
|
||||
})
|
||||
|
||||
it('handles unsorted input points', () => {
|
||||
const points: CurvePoint[] = [
|
||||
[1, 1],
|
||||
[0, 0],
|
||||
[0.5, 0.5]
|
||||
]
|
||||
const interpolate = createLinearInterpolator(points)
|
||||
expect(interpolate(0)).toBe(0)
|
||||
expect(interpolate(0.5)).toBeCloseTo(0.5, 10)
|
||||
expect(interpolate(1)).toBe(1)
|
||||
})
|
||||
})
|
||||
|
||||
describe('curvesToLUT', () => {
|
||||
it('returns a 256-entry Uint8Array', () => {
|
||||
const lut = curvesToLUT([
|
||||
|
||||
@@ -1,19 +1,70 @@
|
||||
import type { CurvePoint } from './types'
|
||||
import { CURVE_INTERPOLATIONS } from './types'
|
||||
import type { CurveData, CurveInterpolation, CurvePoint } from './types'
|
||||
|
||||
export function isCurvePointArray(value: unknown): value is CurvePoint[] {
|
||||
export function isCurveData(value: unknown): value is CurveData {
|
||||
if (typeof value !== 'object' || value === null || Array.isArray(value))
|
||||
return false
|
||||
const v = value as Record<string, unknown>
|
||||
return (
|
||||
Array.isArray(value) &&
|
||||
value.length >= 2 &&
|
||||
value.every(
|
||||
(p) =>
|
||||
Array.isArray(v.points) &&
|
||||
v.points.every(
|
||||
(p: unknown) =>
|
||||
Array.isArray(p) &&
|
||||
p.length === 2 &&
|
||||
typeof p[0] === 'number' &&
|
||||
typeof p[1] === 'number'
|
||||
)
|
||||
) &&
|
||||
typeof v.interpolation === 'string' &&
|
||||
CURVE_INTERPOLATIONS.includes(v.interpolation as CurveInterpolation)
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Piecewise linear interpolation through sorted control points.
|
||||
* Returns a function that evaluates y for any x in [0, 1].
|
||||
*/
|
||||
export function createLinearInterpolator(
|
||||
points: CurvePoint[]
|
||||
): (x: number) => number {
|
||||
if (points.length === 0) return () => 0
|
||||
if (points.length === 1) return () => points[0][1]
|
||||
|
||||
const sorted = [...points].sort((a, b) => a[0] - b[0])
|
||||
const n = sorted.length
|
||||
const xs = sorted.map((p) => p[0])
|
||||
const ys = sorted.map((p) => p[1])
|
||||
|
||||
return (x: number): number => {
|
||||
if (x <= xs[0]) return ys[0]
|
||||
if (x >= xs[n - 1]) return ys[n - 1]
|
||||
|
||||
let lo = 0
|
||||
let hi = n - 1
|
||||
while (lo < hi - 1) {
|
||||
const mid = (lo + hi) >> 1
|
||||
if (xs[mid] <= x) lo = mid
|
||||
else hi = mid
|
||||
}
|
||||
|
||||
const dx = xs[hi] - xs[lo]
|
||||
if (dx === 0) return ys[lo]
|
||||
const t = (x - xs[lo]) / dx
|
||||
return ys[lo] + t * (ys[hi] - ys[lo])
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Factory that dispatches to the correct interpolator based on type.
|
||||
*/
|
||||
export function createInterpolator(
|
||||
points: CurvePoint[],
|
||||
interpolation: CurveInterpolation
|
||||
): (x: number) => number {
|
||||
return interpolation === 'linear'
|
||||
? createLinearInterpolator(points)
|
||||
: createMonotoneInterpolator(points)
|
||||
}
|
||||
|
||||
/**
|
||||
* Monotone cubic Hermite interpolation.
|
||||
* Produces a smooth curve that passes through all control points
|
||||
@@ -120,9 +171,12 @@ export function histogramToPath(histogram: Uint32Array): string {
|
||||
return parts.join(' ')
|
||||
}
|
||||
|
||||
export function curvesToLUT(points: CurvePoint[]): Uint8Array {
|
||||
export function curvesToLUT(
|
||||
points: CurvePoint[],
|
||||
interpolation: CurveInterpolation = 'monotone_cubic'
|
||||
): Uint8Array {
|
||||
const lut = new Uint8Array(256)
|
||||
const interpolate = createMonotoneInterpolator(points)
|
||||
const interpolate = createInterpolator(points, interpolation)
|
||||
|
||||
for (let i = 0; i < 256; i++) {
|
||||
const x = i / 255
|
||||
|
||||
@@ -1 +1,10 @@
|
||||
export type CurvePoint = [x: number, y: number]
|
||||
|
||||
export const CURVE_INTERPOLATIONS = ['monotone_cubic', 'linear'] as const
|
||||
|
||||
export type CurveInterpolation = (typeof CURVE_INTERPOLATIONS)[number]
|
||||
|
||||
export interface CurveData {
|
||||
points: CurvePoint[]
|
||||
interpolation: CurveInterpolation
|
||||
}
|
||||
|
||||
@@ -1,25 +1,37 @@
|
||||
import { computed, onBeforeUnmount, ref } from 'vue'
|
||||
import type { Ref } from 'vue'
|
||||
|
||||
import { createMonotoneInterpolator } from '@/components/curve/curveUtils'
|
||||
import type { CurvePoint } from '@/components/curve/types'
|
||||
import { createInterpolator } from '@/components/curve/curveUtils'
|
||||
import type { CurveInterpolation, CurvePoint } from '@/components/curve/types'
|
||||
|
||||
interface UseCurveEditorOptions {
|
||||
svgRef: Ref<SVGSVGElement | null>
|
||||
modelValue: Ref<CurvePoint[]>
|
||||
interpolation: Ref<CurveInterpolation>
|
||||
}
|
||||
|
||||
export function useCurveEditor({ svgRef, modelValue }: UseCurveEditorOptions) {
|
||||
export function useCurveEditor({
|
||||
svgRef,
|
||||
modelValue,
|
||||
interpolation
|
||||
}: UseCurveEditorOptions) {
|
||||
const dragIndex = ref(-1)
|
||||
let cleanupDrag: (() => void) | null = null
|
||||
|
||||
const curvePath = computed(() => {
|
||||
const points = modelValue.value
|
||||
if (points.length < 2) return ''
|
||||
const sorted = [...points].sort((a, b) => a[0] - b[0])
|
||||
|
||||
const interpolate = createMonotoneInterpolator(points)
|
||||
const xMin = points[0][0]
|
||||
const xMax = points[points.length - 1][0]
|
||||
if (interpolation.value === 'linear') {
|
||||
return sorted
|
||||
.map((p, i) => `${i === 0 ? 'M' : 'L'}${p[0]},${1 - p[1]}`)
|
||||
.join('')
|
||||
}
|
||||
|
||||
const interpolate = createInterpolator(sorted, interpolation.value)
|
||||
const xMin = sorted[0][0]
|
||||
const xMax = sorted[sorted.length - 1][0]
|
||||
const segments = 128
|
||||
const range = xMax - xMin
|
||||
const parts: string[] = []
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import type { Bounds } from '@/renderer/core/layout/types'
|
||||
import type { CurvePoint } from '@/components/curve/types'
|
||||
import type { CurveData } from '@/components/curve/types'
|
||||
|
||||
import type {
|
||||
CanvasColour,
|
||||
@@ -331,9 +331,9 @@ export interface IBoundingBoxWidget extends IBaseWidget<Bounds, 'boundingbox'> {
|
||||
value: Bounds
|
||||
}
|
||||
|
||||
export interface ICurveWidget extends IBaseWidget<CurvePoint[], 'curve'> {
|
||||
export interface ICurveWidget extends IBaseWidget<CurveData, 'curve'> {
|
||||
type: 'curve'
|
||||
value: CurvePoint[]
|
||||
value: CurveData
|
||||
}
|
||||
|
||||
export interface IPainterWidget extends IBaseWidget<string, 'painter'> {
|
||||
|
||||
@@ -1974,6 +1974,10 @@
|
||||
"width": "Width",
|
||||
"height": "Height"
|
||||
},
|
||||
"curveWidget": {
|
||||
"monotone_cubic": "Smooth",
|
||||
"linear": "Linear"
|
||||
},
|
||||
"toastMessages": {
|
||||
"nothingToQueue": "Nothing to queue",
|
||||
"pleaseSelectOutputNodes": "Please select output nodes",
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import type { CurveData } from '@/components/curve/types'
|
||||
import type { LGraphNode } from '@/lib/litegraph/src/litegraph'
|
||||
import type { ICurveWidget } from '@/lib/litegraph/src/types/widgets'
|
||||
import type {
|
||||
@@ -6,20 +7,22 @@ import type {
|
||||
} from '@/schemas/nodeDef/nodeDefSchemaV2'
|
||||
import type { ComfyWidgetConstructorV2 } from '@/scripts/widgets'
|
||||
|
||||
const DEFAULT_CURVE_DATA: CurveData = {
|
||||
points: [
|
||||
[0, 0],
|
||||
[1, 1]
|
||||
],
|
||||
interpolation: 'monotone_cubic'
|
||||
}
|
||||
|
||||
export const useCurveWidget = (): ComfyWidgetConstructorV2 => {
|
||||
return (node: LGraphNode, inputSpec: InputSpecV2): ICurveWidget => {
|
||||
const spec = inputSpec as CurveInputSpec
|
||||
const defaultValue = spec.default ?? [
|
||||
[0, 0],
|
||||
[1, 1]
|
||||
]
|
||||
const defaultValue: CurveData = spec.default
|
||||
? { ...spec.default, points: [...spec.default.points] }
|
||||
: { ...DEFAULT_CURVE_DATA, points: [...DEFAULT_CURVE_DATA.points] }
|
||||
|
||||
const rawWidget = node.addWidget(
|
||||
'curve',
|
||||
spec.name,
|
||||
[...defaultValue],
|
||||
() => {}
|
||||
)
|
||||
const rawWidget = node.addWidget('curve', spec.name, defaultValue, () => {})
|
||||
|
||||
if (rawWidget.type !== 'curve') {
|
||||
throw new Error(`Unexpected widget type: ${rawWidget.type}`)
|
||||
|
||||
@@ -128,11 +128,16 @@ const zTextareaInputSpec = zBaseInputOptions.extend({
|
||||
|
||||
const zCurvePoint = z.tuple([z.number(), z.number()])
|
||||
|
||||
const zCurveData = z.object({
|
||||
points: z.array(zCurvePoint),
|
||||
interpolation: z.enum(['monotone_cubic', 'linear'])
|
||||
})
|
||||
|
||||
const zCurveInputSpec = zBaseInputOptions.extend({
|
||||
type: z.literal('CURVE'),
|
||||
name: z.string(),
|
||||
isOptional: z.boolean().optional(),
|
||||
default: z.array(zCurvePoint).optional()
|
||||
default: zCurveData.optional()
|
||||
})
|
||||
|
||||
const zCustomInputSpec = zBaseInputOptions.extend({
|
||||
|
||||
Reference in New Issue
Block a user