mirror of
https://github.com/Comfy-Org/ComfyUI_frontend.git
synced 2026-02-26 01:34:07 +00:00
feat: add CurveEditor component (#8860)
## Summary Prerequisite for upcoming native color correction nodes (ColorCurves). Reusable curve editor with monotone cubic Hermite interpolation, drag-to-add/move/delete control points, and SVG-based rendering. Includes CurvePoint type, LUT generation utility, and useCurveEditor composable for interaction logic. ## Screenshots (if applicable) https://github.com/user-attachments/assets/948352c7-bdf2-40f9-a8f0-35bc2b2f3202 ┆Issue is synchronized with this [Notion page](https://www.notion.so/PR-8860-feat-add-CurveEditor-component-and-d3-shape-dependency-3076d73d3650817f8421f98e349569d0) by [Unito](https://www.unito.io)
This commit is contained in:
113
src/components/curve/CurveEditor.test.ts
Normal file
113
src/components/curve/CurveEditor.test.ts
Normal file
@@ -0,0 +1,113 @@
|
||||
import { mount } from '@vue/test-utils'
|
||||
import { describe, expect, it } from 'vitest'
|
||||
|
||||
import type { CurvePoint } from '@/lib/litegraph/src/types/widgets'
|
||||
|
||||
import CurveEditor from './CurveEditor.vue'
|
||||
|
||||
function mountEditor(points: CurvePoint[], extraProps = {}) {
|
||||
return mount(CurveEditor, {
|
||||
props: { modelValue: points, ...extraProps }
|
||||
})
|
||||
}
|
||||
|
||||
function getCurvePath(wrapper: ReturnType<typeof mount>) {
|
||||
return wrapper.find('[data-testid="curve-path"]')
|
||||
}
|
||||
|
||||
describe('CurveEditor', () => {
|
||||
it('renders SVG with curve path', () => {
|
||||
const wrapper = mountEditor([
|
||||
[0, 0],
|
||||
[1, 1]
|
||||
])
|
||||
expect(wrapper.find('svg').exists()).toBe(true)
|
||||
const curvePath = getCurvePath(wrapper)
|
||||
expect(curvePath.exists()).toBe(true)
|
||||
expect(curvePath.attributes('d')).toBeTruthy()
|
||||
})
|
||||
|
||||
it('renders a circle for each control point', () => {
|
||||
const wrapper = mountEditor([
|
||||
[0, 0],
|
||||
[0.5, 0.7],
|
||||
[1, 1]
|
||||
])
|
||||
expect(wrapper.findAll('circle')).toHaveLength(3)
|
||||
})
|
||||
|
||||
it('renders histogram path when provided', () => {
|
||||
const histogram = new Uint32Array(256)
|
||||
for (let i = 0; i < 256; i++) histogram[i] = i + 1
|
||||
const wrapper = mountEditor(
|
||||
[
|
||||
[0, 0],
|
||||
[1, 1]
|
||||
],
|
||||
{ histogram }
|
||||
)
|
||||
const histogramPath = wrapper.find('[data-testid="histogram-path"]')
|
||||
expect(histogramPath.exists()).toBe(true)
|
||||
expect(histogramPath.attributes('d')).toContain('M0,1')
|
||||
})
|
||||
|
||||
it('does not render histogram path when not provided', () => {
|
||||
const wrapper = mountEditor([
|
||||
[0, 0],
|
||||
[1, 1]
|
||||
])
|
||||
expect(wrapper.find('[data-testid="histogram-path"]').exists()).toBe(false)
|
||||
})
|
||||
|
||||
it('returns empty path with fewer than 2 points', () => {
|
||||
const wrapper = mountEditor([[0.5, 0.5]])
|
||||
expect(getCurvePath(wrapper).attributes('d')).toBe('')
|
||||
})
|
||||
|
||||
it('generates path starting with M and containing L segments', () => {
|
||||
const wrapper = mountEditor([
|
||||
[0, 0],
|
||||
[0.5, 0.8],
|
||||
[1, 1]
|
||||
])
|
||||
const d = getCurvePath(wrapper).attributes('d')!
|
||||
expect(d).toMatch(/^M/)
|
||||
expect(d).toContain('L')
|
||||
})
|
||||
|
||||
it('curve path only spans the x-range of control points', () => {
|
||||
const wrapper = mountEditor([
|
||||
[0.2, 0.3],
|
||||
[0.8, 0.9]
|
||||
])
|
||||
const d = getCurvePath(wrapper).attributes('d')!
|
||||
const xValues = d
|
||||
.split(/[ML]/)
|
||||
.filter(Boolean)
|
||||
.map((s) => parseFloat(s.split(',')[0]))
|
||||
expect(Math.min(...xValues)).toBeCloseTo(0.2, 2)
|
||||
expect(Math.max(...xValues)).toBeCloseTo(0.8, 2)
|
||||
})
|
||||
|
||||
it('deletes a point on right-click but keeps minimum 2', async () => {
|
||||
const points: CurvePoint[] = [
|
||||
[0, 0],
|
||||
[0.5, 0.5],
|
||||
[1, 1]
|
||||
]
|
||||
const wrapper = mountEditor(points)
|
||||
expect(wrapper.findAll('circle')).toHaveLength(3)
|
||||
|
||||
await wrapper.findAll('circle')[1].trigger('pointerdown', {
|
||||
button: 2,
|
||||
pointerId: 1
|
||||
})
|
||||
expect(wrapper.findAll('circle')).toHaveLength(2)
|
||||
|
||||
await wrapper.findAll('circle')[0].trigger('pointerdown', {
|
||||
button: 2,
|
||||
pointerId: 1
|
||||
})
|
||||
expect(wrapper.findAll('circle')).toHaveLength(2)
|
||||
})
|
||||
})
|
||||
103
src/components/curve/CurveEditor.vue
Normal file
103
src/components/curve/CurveEditor.vue
Normal file
@@ -0,0 +1,103 @@
|
||||
<template>
|
||||
<svg
|
||||
ref="svgRef"
|
||||
viewBox="-0.04 -0.04 1.08 1.08"
|
||||
preserveAspectRatio="xMidYMid meet"
|
||||
class="aspect-square w-full cursor-crosshair rounded-[5px] bg-node-component-surface"
|
||||
@pointerdown.stop="handleSvgPointerDown"
|
||||
@contextmenu.prevent.stop
|
||||
>
|
||||
<line
|
||||
v-for="v in [0.25, 0.5, 0.75]"
|
||||
:key="'h' + v"
|
||||
:x1="0"
|
||||
:y1="v"
|
||||
:x2="1"
|
||||
:y2="v"
|
||||
stroke="currentColor"
|
||||
stroke-opacity="0.1"
|
||||
stroke-width="0.003"
|
||||
/>
|
||||
<line
|
||||
v-for="v in [0.25, 0.5, 0.75]"
|
||||
:key="'v' + v"
|
||||
:x1="v"
|
||||
:y1="0"
|
||||
:x2="v"
|
||||
:y2="1"
|
||||
stroke="currentColor"
|
||||
stroke-opacity="0.1"
|
||||
stroke-width="0.003"
|
||||
/>
|
||||
|
||||
<line
|
||||
x1="0"
|
||||
y1="1"
|
||||
x2="1"
|
||||
y2="0"
|
||||
stroke="currentColor"
|
||||
stroke-opacity="0.15"
|
||||
stroke-width="0.003"
|
||||
/>
|
||||
|
||||
<path
|
||||
v-if="histogramPath"
|
||||
data-testid="histogram-path"
|
||||
:d="histogramPath"
|
||||
:fill="curveColor"
|
||||
fill-opacity="0.15"
|
||||
stroke="none"
|
||||
/>
|
||||
|
||||
<path
|
||||
data-testid="curve-path"
|
||||
:d="curvePath"
|
||||
fill="none"
|
||||
:stroke="curveColor"
|
||||
stroke-width="0.008"
|
||||
stroke-linecap="round"
|
||||
/>
|
||||
|
||||
<circle
|
||||
v-for="(point, i) in modelValue"
|
||||
:key="i"
|
||||
:cx="point[0]"
|
||||
:cy="1 - point[1]"
|
||||
r="0.02"
|
||||
:fill="curveColor"
|
||||
stroke="white"
|
||||
stroke-width="0.004"
|
||||
class="cursor-grab"
|
||||
@pointerdown.stop="startDrag(i, $event)"
|
||||
/>
|
||||
</svg>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { computed, useTemplateRef } from 'vue'
|
||||
|
||||
import { useCurveEditor } from '@/composables/useCurveEditor'
|
||||
import type { CurvePoint } from '@/lib/litegraph/src/types/widgets'
|
||||
|
||||
import { histogramToPath } from './curveUtils'
|
||||
|
||||
const { curveColor = 'white', histogram } = defineProps<{
|
||||
curveColor?: string
|
||||
histogram?: Uint32Array | null
|
||||
}>()
|
||||
|
||||
const modelValue = defineModel<CurvePoint[]>({
|
||||
required: true
|
||||
})
|
||||
|
||||
const svgRef = useTemplateRef<SVGSVGElement>('svgRef')
|
||||
|
||||
const { curvePath, handleSvgPointerDown, startDrag } = useCurveEditor({
|
||||
svgRef,
|
||||
modelValue
|
||||
})
|
||||
|
||||
const histogramPath = computed(() =>
|
||||
histogram ? histogramToPath(histogram) : ''
|
||||
)
|
||||
</script>
|
||||
16
src/components/curve/WidgetCurve.vue
Normal file
16
src/components/curve/WidgetCurve.vue
Normal file
@@ -0,0 +1,16 @@
|
||||
<template>
|
||||
<CurveEditor v-model="modelValue" />
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import type { CurvePoint } from '@/lib/litegraph/src/types/widgets'
|
||||
|
||||
import CurveEditor from './CurveEditor.vue'
|
||||
|
||||
const modelValue = defineModel<CurvePoint[]>({
|
||||
default: () => [
|
||||
[0, 0],
|
||||
[1, 1]
|
||||
]
|
||||
})
|
||||
</script>
|
||||
141
src/components/curve/curveUtils.test.ts
Normal file
141
src/components/curve/curveUtils.test.ts
Normal file
@@ -0,0 +1,141 @@
|
||||
import { describe, expect, it } from 'vitest'
|
||||
|
||||
import type { CurvePoint } from '@/lib/litegraph/src/types/widgets'
|
||||
|
||||
import {
|
||||
createMonotoneInterpolator,
|
||||
curvesToLUT,
|
||||
histogramToPath
|
||||
} from './curveUtils'
|
||||
|
||||
describe('createMonotoneInterpolator', () => {
|
||||
it('returns 0 for empty points', () => {
|
||||
const interpolate = createMonotoneInterpolator([])
|
||||
expect(interpolate(0.5)).toBe(0)
|
||||
})
|
||||
|
||||
it('returns constant for single point', () => {
|
||||
const interpolate = createMonotoneInterpolator([[0.5, 0.7]])
|
||||
expect(interpolate(0)).toBe(0.7)
|
||||
expect(interpolate(1)).toBe(0.7)
|
||||
})
|
||||
|
||||
it('passes through control points exactly', () => {
|
||||
const points: CurvePoint[] = [
|
||||
[0, 0],
|
||||
[0.5, 0.8],
|
||||
[1, 1]
|
||||
]
|
||||
const interpolate = createMonotoneInterpolator(points)
|
||||
expect(interpolate(0)).toBeCloseTo(0, 5)
|
||||
expect(interpolate(0.5)).toBeCloseTo(0.8, 5)
|
||||
expect(interpolate(1)).toBeCloseTo(1, 5)
|
||||
})
|
||||
|
||||
it('clamps to endpoint values outside range', () => {
|
||||
const points: CurvePoint[] = [
|
||||
[0.2, 0.3],
|
||||
[0.8, 0.9]
|
||||
]
|
||||
const interpolate = createMonotoneInterpolator(points)
|
||||
expect(interpolate(0)).toBe(0.3)
|
||||
expect(interpolate(1)).toBe(0.9)
|
||||
})
|
||||
|
||||
it('produces monotone output for monotone input', () => {
|
||||
const points: CurvePoint[] = [
|
||||
[0, 0],
|
||||
[0.25, 0.2],
|
||||
[0.5, 0.5],
|
||||
[0.75, 0.8],
|
||||
[1, 1]
|
||||
]
|
||||
const interpolate = createMonotoneInterpolator(points)
|
||||
|
||||
let prev = -Infinity
|
||||
for (let x = 0; x <= 1; x += 0.01) {
|
||||
const y = interpolate(x)
|
||||
expect(y).toBeGreaterThanOrEqual(prev)
|
||||
prev = y
|
||||
}
|
||||
})
|
||||
|
||||
it('handles unsorted input points', () => {
|
||||
const points: CurvePoint[] = [
|
||||
[1, 1],
|
||||
[0, 0],
|
||||
[0.5, 0.5]
|
||||
]
|
||||
const interpolate = createMonotoneInterpolator(points)
|
||||
expect(interpolate(0)).toBeCloseTo(0, 5)
|
||||
expect(interpolate(0.5)).toBeCloseTo(0.5, 5)
|
||||
expect(interpolate(1)).toBeCloseTo(1, 5)
|
||||
})
|
||||
})
|
||||
|
||||
describe('curvesToLUT', () => {
|
||||
it('returns a 256-entry Uint8Array', () => {
|
||||
const lut = curvesToLUT([
|
||||
[0, 0],
|
||||
[1, 1]
|
||||
])
|
||||
expect(lut).toBeInstanceOf(Uint8Array)
|
||||
expect(lut.length).toBe(256)
|
||||
})
|
||||
|
||||
it('produces identity LUT for diagonal curve', () => {
|
||||
const lut = curvesToLUT([
|
||||
[0, 0],
|
||||
[1, 1]
|
||||
])
|
||||
for (let i = 0; i < 256; i++) {
|
||||
expect(lut[i]).toBeCloseTo(i, 0)
|
||||
}
|
||||
})
|
||||
|
||||
it('clamps output to [0, 255]', () => {
|
||||
const lut = curvesToLUT([
|
||||
[0, 0],
|
||||
[0.5, 1.5],
|
||||
[1, 1]
|
||||
])
|
||||
for (let i = 0; i < 256; i++) {
|
||||
expect(lut[i]).toBeGreaterThanOrEqual(0)
|
||||
expect(lut[i]).toBeLessThanOrEqual(255)
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
describe('histogramToPath', () => {
|
||||
it('returns empty string for empty histogram', () => {
|
||||
expect(histogramToPath(new Uint32Array(0))).toBe('')
|
||||
})
|
||||
|
||||
it('returns empty string when all bins are zero', () => {
|
||||
expect(histogramToPath(new Uint32Array(256))).toBe('')
|
||||
})
|
||||
|
||||
it('returns a closed SVG path for valid histogram', () => {
|
||||
const histogram = new Uint32Array(256)
|
||||
for (let i = 0; i < 256; i++) histogram[i] = i + 1
|
||||
const path = histogramToPath(histogram)
|
||||
expect(path).toMatch(/^M0,1/)
|
||||
expect(path).toMatch(/L1,1 Z$/)
|
||||
})
|
||||
|
||||
it('normalizes using 99.5th percentile to suppress outliers', () => {
|
||||
const histogram = new Uint32Array(256)
|
||||
for (let i = 0; i < 256; i++) histogram[i] = 100
|
||||
histogram[255] = 100000
|
||||
const path = histogramToPath(histogram)
|
||||
// Most bins should map to y=0 (1 - 100/100 = 0) since
|
||||
// the 99.5th percentile is 100, not the outlier 100000
|
||||
const yValues = path
|
||||
.split(/[ML]/)
|
||||
.filter(Boolean)
|
||||
.map((s) => parseFloat(s.split(',')[1]))
|
||||
.filter((y) => !isNaN(y))
|
||||
const nearZero = yValues.filter((y) => Math.abs(y) < 0.01)
|
||||
expect(nearZero.length).toBeGreaterThan(200)
|
||||
})
|
||||
})
|
||||
120
src/components/curve/curveUtils.ts
Normal file
120
src/components/curve/curveUtils.ts
Normal file
@@ -0,0 +1,120 @@
|
||||
import type { CurvePoint } from '@/lib/litegraph/src/types/widgets'
|
||||
|
||||
/**
|
||||
* Monotone cubic Hermite interpolation.
|
||||
* Produces a smooth curve that passes through all control points
|
||||
* without overshooting (monotone property).
|
||||
*
|
||||
* Returns a function that evaluates y for any x in [0, 1].
|
||||
*/
|
||||
export function createMonotoneInterpolator(
|
||||
points: CurvePoint[]
|
||||
): (x: number) => number {
|
||||
if (points.length === 0) return () => 0
|
||||
if (points.length === 1) return () => points[0][1]
|
||||
|
||||
const sorted = [...points].sort((a, b) => a[0] - b[0])
|
||||
const n = sorted.length
|
||||
const xs = sorted.map((p) => p[0])
|
||||
const ys = sorted.map((p) => p[1])
|
||||
|
||||
const deltas: number[] = []
|
||||
const slopes: number[] = []
|
||||
for (let i = 0; i < n - 1; i++) {
|
||||
const dx = xs[i + 1] - xs[i]
|
||||
deltas.push(dx === 0 ? 0 : (ys[i + 1] - ys[i]) / dx)
|
||||
}
|
||||
|
||||
slopes.push(deltas[0] ?? 0)
|
||||
for (let i = 1; i < n - 1; i++) {
|
||||
if (deltas[i - 1] * deltas[i] <= 0) {
|
||||
slopes.push(0)
|
||||
} else {
|
||||
slopes.push((deltas[i - 1] + deltas[i]) / 2)
|
||||
}
|
||||
}
|
||||
slopes.push(deltas[n - 2] ?? 0)
|
||||
|
||||
for (let i = 0; i < n - 1; i++) {
|
||||
if (deltas[i] === 0) {
|
||||
slopes[i] = 0
|
||||
slopes[i + 1] = 0
|
||||
} else {
|
||||
const alpha = slopes[i] / deltas[i]
|
||||
const beta = slopes[i + 1] / deltas[i]
|
||||
const s = alpha * alpha + beta * beta
|
||||
if (s > 9) {
|
||||
const t = 3 / Math.sqrt(s)
|
||||
slopes[i] = t * alpha * deltas[i]
|
||||
slopes[i + 1] = t * beta * deltas[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return (x: number): number => {
|
||||
if (x <= xs[0]) return ys[0]
|
||||
if (x >= xs[n - 1]) return ys[n - 1]
|
||||
|
||||
let lo = 0
|
||||
let hi = n - 1
|
||||
while (lo < hi - 1) {
|
||||
const mid = (lo + hi) >> 1
|
||||
if (xs[mid] <= x) lo = mid
|
||||
else hi = mid
|
||||
}
|
||||
|
||||
const dx = xs[hi] - xs[lo]
|
||||
if (dx === 0) return ys[lo]
|
||||
|
||||
const t = (x - xs[lo]) / dx
|
||||
const t2 = t * t
|
||||
const t3 = t2 * t
|
||||
|
||||
const h00 = 2 * t3 - 3 * t2 + 1
|
||||
const h10 = t3 - 2 * t2 + t
|
||||
const h01 = -2 * t3 + 3 * t2
|
||||
const h11 = t3 - t2
|
||||
|
||||
return (
|
||||
h00 * ys[lo] +
|
||||
h10 * dx * slopes[lo] +
|
||||
h01 * ys[hi] +
|
||||
h11 * dx * slopes[hi]
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert a 256-bin histogram into an SVG path string.
|
||||
* Normalizes using the 99.5th percentile to avoid outlier spikes.
|
||||
*/
|
||||
export function histogramToPath(histogram: Uint32Array): string {
|
||||
if (!histogram.length) return ''
|
||||
|
||||
const sorted = Array.from(histogram).sort((a, b) => a - b)
|
||||
const max = sorted[Math.floor(255 * 0.995)]
|
||||
if (max === 0) return ''
|
||||
|
||||
const step = 1 / 255
|
||||
let d = 'M0,1'
|
||||
for (let i = 0; i < 256; i++) {
|
||||
const x = i * step
|
||||
const y = 1 - Math.min(1, histogram[i] / max)
|
||||
d += ` L${x.toFixed(4)},${y.toFixed(4)}`
|
||||
}
|
||||
d += ' L1,1 Z'
|
||||
return d
|
||||
}
|
||||
|
||||
export function curvesToLUT(points: CurvePoint[]): Uint8Array {
|
||||
const lut = new Uint8Array(256)
|
||||
const interpolate = createMonotoneInterpolator(points)
|
||||
|
||||
for (let i = 0; i < 256; i++) {
|
||||
const x = i / 255
|
||||
const y = interpolate(x)
|
||||
lut[i] = Math.max(0, Math.min(255, Math.round(y * 255)))
|
||||
}
|
||||
|
||||
return lut
|
||||
}
|
||||
140
src/composables/useCurveEditor.ts
Normal file
140
src/composables/useCurveEditor.ts
Normal file
@@ -0,0 +1,140 @@
|
||||
import { computed, onBeforeUnmount, ref } from 'vue'
|
||||
import type { Ref } from 'vue'
|
||||
|
||||
import { createMonotoneInterpolator } from '@/components/curve/curveUtils'
|
||||
import type { CurvePoint } from '@/lib/litegraph/src/types/widgets'
|
||||
|
||||
interface UseCurveEditorOptions {
|
||||
svgRef: Ref<SVGSVGElement | null>
|
||||
modelValue: Ref<CurvePoint[]>
|
||||
}
|
||||
|
||||
export function useCurveEditor({ svgRef, modelValue }: UseCurveEditorOptions) {
|
||||
const dragIndex = ref(-1)
|
||||
let cleanupDrag: (() => void) | null = null
|
||||
|
||||
const curvePath = computed(() => {
|
||||
const points = modelValue.value
|
||||
if (points.length < 2) return ''
|
||||
|
||||
const interpolate = createMonotoneInterpolator(points)
|
||||
const xMin = points[0][0]
|
||||
const xMax = points[points.length - 1][0]
|
||||
const segments = 128
|
||||
const parts: string[] = []
|
||||
for (let i = 0; i <= segments; i++) {
|
||||
const x = xMin + (xMax - xMin) * (i / segments)
|
||||
const y = 1 - interpolate(x)
|
||||
parts.push(`${i === 0 ? 'M' : 'L'}${x.toFixed(4)},${y.toFixed(4)}`)
|
||||
}
|
||||
return parts.join('')
|
||||
})
|
||||
|
||||
function svgCoords(e: PointerEvent): [number, number] {
|
||||
const svg = svgRef.value
|
||||
if (!svg) return [0, 0]
|
||||
|
||||
const ctm = svg.getScreenCTM()
|
||||
if (!ctm) return [0, 0]
|
||||
|
||||
const svgPt = new DOMPoint(e.clientX, e.clientY).matrixTransform(
|
||||
ctm.inverse()
|
||||
)
|
||||
return [
|
||||
Math.max(0, Math.min(1, svgPt.x)),
|
||||
Math.max(0, Math.min(1, 1 - svgPt.y))
|
||||
]
|
||||
}
|
||||
|
||||
function findNearestPoint(x: number, y: number): number {
|
||||
const threshold2 = 0.04 * 0.04
|
||||
let nearest = -1
|
||||
let minDist2 = threshold2
|
||||
for (let i = 0; i < modelValue.value.length; i++) {
|
||||
const dx = modelValue.value[i][0] - x
|
||||
const dy = modelValue.value[i][1] - y
|
||||
const dist2 = dx * dx + dy * dy
|
||||
if (dist2 < minDist2) {
|
||||
minDist2 = dist2
|
||||
nearest = i
|
||||
}
|
||||
}
|
||||
return nearest
|
||||
}
|
||||
|
||||
function handleSvgPointerDown(e: PointerEvent) {
|
||||
if (e.button !== 0) return
|
||||
|
||||
const [x, y] = svgCoords(e)
|
||||
|
||||
const nearby = findNearestPoint(x, y)
|
||||
if (nearby >= 0) {
|
||||
startDrag(nearby, e)
|
||||
return
|
||||
}
|
||||
|
||||
if (e.ctrlKey) return
|
||||
|
||||
const newPoint: CurvePoint = [x, y]
|
||||
const newPoints: CurvePoint[] = [...modelValue.value, newPoint]
|
||||
newPoints.sort((a, b) => a[0] - b[0])
|
||||
modelValue.value = newPoints
|
||||
|
||||
startDrag(newPoints.indexOf(newPoint), e)
|
||||
}
|
||||
|
||||
function startDrag(index: number, e: PointerEvent) {
|
||||
cleanupDrag?.()
|
||||
|
||||
if (e.button === 2 || (e.button === 0 && e.ctrlKey)) {
|
||||
if (modelValue.value.length > 2) {
|
||||
const newPoints = [...modelValue.value]
|
||||
newPoints.splice(index, 1)
|
||||
modelValue.value = newPoints
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
dragIndex.value = index
|
||||
const svg = svgRef.value
|
||||
if (!svg) return
|
||||
|
||||
svg.setPointerCapture(e.pointerId)
|
||||
|
||||
const onMove = (ev: PointerEvent) => {
|
||||
if (dragIndex.value < 0) return
|
||||
const [x, y] = svgCoords(ev)
|
||||
const movedPoint: CurvePoint = [x, y]
|
||||
const newPoints = [...modelValue.value]
|
||||
newPoints[dragIndex.value] = movedPoint
|
||||
newPoints.sort((a, b) => a[0] - b[0])
|
||||
modelValue.value = newPoints
|
||||
dragIndex.value = newPoints.indexOf(movedPoint)
|
||||
}
|
||||
|
||||
const endDrag = () => {
|
||||
if (dragIndex.value < 0) return
|
||||
dragIndex.value = -1
|
||||
svg.removeEventListener('pointermove', onMove)
|
||||
svg.removeEventListener('pointerup', endDrag)
|
||||
svg.removeEventListener('lostpointercapture', endDrag)
|
||||
cleanupDrag = null
|
||||
}
|
||||
|
||||
cleanupDrag = endDrag
|
||||
|
||||
svg.addEventListener('pointermove', onMove)
|
||||
svg.addEventListener('pointerup', endDrag)
|
||||
svg.addEventListener('lostpointercapture', endDrag)
|
||||
}
|
||||
|
||||
onBeforeUnmount(() => {
|
||||
cleanupDrag?.()
|
||||
})
|
||||
|
||||
return {
|
||||
curvePath,
|
||||
handleSvgPointerDown,
|
||||
startDrag
|
||||
}
|
||||
}
|
||||
@@ -136,6 +136,7 @@ export type IWidget =
|
||||
| IAssetWidget
|
||||
| IImageCropWidget
|
||||
| IBoundingBoxWidget
|
||||
| ICurveWidget
|
||||
|
||||
export interface IBooleanWidget extends IBaseWidget<boolean, 'toggle'> {
|
||||
type: 'toggle'
|
||||
@@ -328,6 +329,13 @@ export interface IBoundingBoxWidget extends IBaseWidget<Bounds, 'boundingbox'> {
|
||||
value: Bounds
|
||||
}
|
||||
|
||||
export type CurvePoint = [x: number, y: number]
|
||||
|
||||
export interface ICurveWidget extends IBaseWidget<CurvePoint[], 'curve'> {
|
||||
type: 'curve'
|
||||
value: CurvePoint[]
|
||||
}
|
||||
|
||||
/**
|
||||
* Valid widget types. TS cannot provide easily extensible type safety for this at present.
|
||||
* Override linkedWidgets[]
|
||||
|
||||
16
src/lib/litegraph/src/widgets/CurveWidget.ts
Normal file
16
src/lib/litegraph/src/widgets/CurveWidget.ts
Normal file
@@ -0,0 +1,16 @@
|
||||
import type { ICurveWidget } from '../types/widgets'
|
||||
import { BaseWidget } from './BaseWidget'
|
||||
import type { DrawWidgetOptions, WidgetEventOptions } from './BaseWidget'
|
||||
|
||||
export class CurveWidget
|
||||
extends BaseWidget<ICurveWidget>
|
||||
implements ICurveWidget
|
||||
{
|
||||
override type = 'curve' as const
|
||||
|
||||
drawWidget(ctx: CanvasRenderingContext2D, options: DrawWidgetOptions): void {
|
||||
this.drawVueOnlyWarning(ctx, options, 'Curve')
|
||||
}
|
||||
|
||||
onClick(_options: WidgetEventOptions): void {}
|
||||
}
|
||||
@@ -16,6 +16,7 @@ import { ButtonWidget } from './ButtonWidget'
|
||||
import { ChartWidget } from './ChartWidget'
|
||||
import { ColorWidget } from './ColorWidget'
|
||||
import { ComboWidget } from './ComboWidget'
|
||||
import { CurveWidget } from './CurveWidget'
|
||||
import { FileUploadWidget } from './FileUploadWidget'
|
||||
import { GalleriaWidget } from './GalleriaWidget'
|
||||
import { GradientSliderWidget } from './GradientSliderWidget'
|
||||
@@ -56,6 +57,7 @@ export type WidgetTypeMap = {
|
||||
asset: AssetWidget
|
||||
imagecrop: ImageCropWidget
|
||||
boundingbox: BoundingBoxWidget
|
||||
curve: CurveWidget
|
||||
[key: string]: BaseWidget
|
||||
}
|
||||
|
||||
@@ -132,6 +134,8 @@ export function toConcreteWidget<TWidget extends IWidget | IBaseWidget>(
|
||||
return toClass(ImageCropWidget, narrowedWidget, node)
|
||||
case 'boundingbox':
|
||||
return toClass(BoundingBoxWidget, narrowedWidget, node)
|
||||
case 'curve':
|
||||
return toClass(CurveWidget, narrowedWidget, node)
|
||||
default: {
|
||||
if (wrapLegacyWidgets) return toClass(LegacyWidget, widget, node)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
import type { LGraphNode } from '@/lib/litegraph/src/litegraph'
|
||||
import type { ICurveWidget } from '@/lib/litegraph/src/types/widgets'
|
||||
import type {
|
||||
CurveInputSpec,
|
||||
InputSpec as InputSpecV2
|
||||
} from '@/schemas/nodeDef/nodeDefSchemaV2'
|
||||
import type { ComfyWidgetConstructorV2 } from '@/scripts/widgets'
|
||||
|
||||
export const useCurveWidget = (): ComfyWidgetConstructorV2 => {
|
||||
return (node: LGraphNode, inputSpec: InputSpecV2): ICurveWidget => {
|
||||
const spec = inputSpec as CurveInputSpec
|
||||
const defaultValue = spec.default ?? [
|
||||
[0, 0],
|
||||
[1, 1]
|
||||
]
|
||||
|
||||
const rawWidget = node.addWidget(
|
||||
'curve',
|
||||
spec.name,
|
||||
[...defaultValue],
|
||||
() => {}
|
||||
)
|
||||
|
||||
if (rawWidget.type !== 'curve') {
|
||||
throw new Error(`Unexpected widget type: ${rawWidget.type}`)
|
||||
}
|
||||
|
||||
return rawWidget as ICurveWidget
|
||||
}
|
||||
}
|
||||
@@ -57,6 +57,9 @@ const WidgetImageCrop = defineAsyncComponent(
|
||||
const WidgetBoundingBox = defineAsyncComponent(
|
||||
() => import('@/components/boundingbox/WidgetBoundingBox.vue')
|
||||
)
|
||||
const WidgetCurve = defineAsyncComponent(
|
||||
() => import('@/components/curve/WidgetCurve.vue')
|
||||
)
|
||||
|
||||
export const FOR_TESTING = {
|
||||
WidgetButton,
|
||||
@@ -175,6 +178,14 @@ const coreWidgetDefinitions: Array<[string, WidgetDefinition]> = [
|
||||
aliases: ['BOUNDING_BOX'],
|
||||
essential: false
|
||||
}
|
||||
],
|
||||
[
|
||||
'curve',
|
||||
{
|
||||
component: WidgetCurve,
|
||||
aliases: ['CURVE'],
|
||||
essential: false
|
||||
}
|
||||
]
|
||||
]
|
||||
|
||||
@@ -206,7 +217,7 @@ export const shouldRenderAsVue = (widget: Partial<SafeWidgetData>): boolean => {
|
||||
return !widget.options?.canvasOnly && !!widget.type
|
||||
}
|
||||
|
||||
const EXPANDING_TYPES = ['textarea', 'markdown', 'load3D'] as const
|
||||
const EXPANDING_TYPES = ['textarea', 'markdown', 'load3D', 'curve'] as const
|
||||
|
||||
export function shouldExpand(type: string): boolean {
|
||||
const canonicalType = getCanonicalType(type)
|
||||
|
||||
@@ -126,6 +126,15 @@ const zTextareaInputSpec = zBaseInputOptions.extend({
|
||||
.optional()
|
||||
})
|
||||
|
||||
const zCurvePoint = z.tuple([z.number(), z.number()])
|
||||
|
||||
const zCurveInputSpec = zBaseInputOptions.extend({
|
||||
type: z.literal('CURVE'),
|
||||
name: z.string(),
|
||||
isOptional: z.boolean().optional(),
|
||||
default: z.array(zCurvePoint).optional()
|
||||
})
|
||||
|
||||
const zCustomInputSpec = zBaseInputOptions.extend({
|
||||
type: z.string(),
|
||||
name: z.string(),
|
||||
@@ -146,6 +155,7 @@ const zInputSpec = z.union([
|
||||
zChartInputSpec,
|
||||
zGalleriaInputSpec,
|
||||
zTextareaInputSpec,
|
||||
zCurveInputSpec,
|
||||
zCustomInputSpec
|
||||
])
|
||||
|
||||
@@ -190,6 +200,7 @@ export type BoundingBoxInputSpec = z.infer<typeof zBoundingBoxInputSpec>
|
||||
export type ChartInputSpec = z.infer<typeof zChartInputSpec>
|
||||
export type GalleriaInputSpec = z.infer<typeof zGalleriaInputSpec>
|
||||
export type TextareaInputSpec = z.infer<typeof zTextareaInputSpec>
|
||||
export type CurveInputSpec = z.infer<typeof zCurveInputSpec>
|
||||
export type CustomInputSpec = z.infer<typeof zCustomInputSpec>
|
||||
|
||||
export type InputSpec = z.infer<typeof zInputSpec>
|
||||
|
||||
@@ -9,6 +9,7 @@ import { useSettingStore } from '@/platform/settings/settingStore'
|
||||
import { dynamicWidgets } from '@/core/graph/widgets/dynamicWidgets'
|
||||
import { useBooleanWidget } from '@/renderer/extensions/vueNodes/widgets/composables/useBooleanWidget'
|
||||
import { useBoundingBoxWidget } from '@/renderer/extensions/vueNodes/widgets/composables/useBoundingBoxWidget'
|
||||
import { useCurveWidget } from '@/renderer/extensions/vueNodes/widgets/composables/useCurveWidget'
|
||||
import { useChartWidget } from '@/renderer/extensions/vueNodes/widgets/composables/useChartWidget'
|
||||
import { useColorWidget } from '@/renderer/extensions/vueNodes/widgets/composables/useColorWidget'
|
||||
import { useComboWidget } from '@/renderer/extensions/vueNodes/widgets/composables/useComboWidget'
|
||||
@@ -306,6 +307,7 @@ export const ComfyWidgets = {
|
||||
CHART: transformWidgetConstructorV2ToV1(useChartWidget()),
|
||||
GALLERIA: transformWidgetConstructorV2ToV1(useGalleriaWidget()),
|
||||
TEXTAREA: transformWidgetConstructorV2ToV1(useTextareaWidget()),
|
||||
CURVE: transformWidgetConstructorV2ToV1(useCurveWidget()),
|
||||
...dynamicWidgets
|
||||
} as const
|
||||
|
||||
|
||||
Reference in New Issue
Block a user