Compare commits

...

1 Commits

Author SHA1 Message Date
Terry Jia
176467594d feat: add native Color Correct node and curve node 2026-02-14 16:13:35 -05:00
34 changed files with 2682 additions and 2 deletions

View File

@@ -86,6 +86,7 @@
"axios": "catalog:",
"chart.js": "^4.5.0",
"cva": "catalog:",
"d3-shape": "catalog:",
"dompurify": "^3.2.5",
"dotenv": "catalog:",
"es-toolkit": "^1.39.9",
@@ -131,6 +132,7 @@
"@storybook/vue3": "catalog:",
"@storybook/vue3-vite": "catalog:",
"@tailwindcss/vite": "catalog:",
"@types/d3-shape": "catalog:",
"@types/fs-extra": "catalog:",
"@types/jsdom": "catalog:",
"@types/node": "catalog:",

40
pnpm-lock.yaml generated
View File

@@ -93,6 +93,9 @@ catalogs:
'@tailwindcss/vite':
specifier: ^4.1.12
version: 4.1.12
'@types/d3-shape':
specifier: ^3.1.8
version: 3.1.8
'@types/fs-extra':
specifier: ^11.0.4
version: 11.0.4
@@ -141,6 +144,9 @@ catalogs:
cva:
specifier: 1.0.0-beta.4
version: 1.0.0-beta.4
d3-shape:
specifier: ^3.2.0
version: 3.2.0
dotenv:
specifier: ^16.4.5
version: 16.6.1
@@ -428,6 +434,9 @@ importers:
cva:
specifier: 'catalog:'
version: 1.0.0-beta.4(typescript@5.9.3)
d3-shape:
specifier: 'catalog:'
version: 3.2.0
dompurify:
specifier: ^3.2.5
version: 3.2.5
@@ -558,6 +567,9 @@ importers:
'@tailwindcss/vite':
specifier: 'catalog:'
version: 4.1.12(vite@8.0.0-beta.13(@types/node@24.10.4)(esbuild@0.27.1)(jiti@2.6.1)(terser@5.39.2)(tsx@4.19.4)(yaml@2.8.2))
'@types/d3-shape':
specifier: 'catalog:'
version: 3.1.8
'@types/fs-extra':
specifier: 'catalog:'
version: 11.0.4
@@ -2803,7 +2815,7 @@ packages:
'@primevue/themes@4.2.5':
resolution: {integrity: sha512-8F7yA36xYIKtNuAuyBdZZEks/bKDwlhH5WjpqGGB0FdwfAEoBYsynQ5sdqcT2Lb/NsajHmS5lc++Ttlvr1g1Lw==}
engines: {node: '>=12.11.0'}
deprecated: 'Deprecated. This package is no longer maintained. Please migrate to @primeuix/themes: https://www.npmjs.com/package/@primeuix/themes'
deprecated: This package is deprecated. Use @primeuix/themes instead.
'@protobufjs/aspromise@1.1.2':
resolution: {integrity: sha512-j+gKExEuLmKwvz3OgROXtrJ2UG2x8Ch2YZUxahh+s1F2HZ+wAceUNLkvy6zKCPVRkU++ZWQrdxsUeQXmcg4uoQ==}
@@ -3538,6 +3550,12 @@ packages:
'@types/css-font-loading-module@0.0.7':
resolution: {integrity: sha512-nl09VhutdjINdWyXxHWN/w9zlNCfr60JUqJbd24YXUuCwgeL0TpFSdElCwb6cxfB6ybE19Gjj4g0jsgkXxKv1Q==}
'@types/d3-path@3.1.1':
resolution: {integrity: sha512-VMZBYyQvbGmWyWVea0EHs/BwLgxc+MKi1zLDCONksozI4YJMcTt8ZEuIR4Sb1MMTE8MMW49v0IwI5+b7RmfWlg==}
'@types/d3-shape@3.1.8':
resolution: {integrity: sha512-lae0iWfcDeR7qt7rA88BNiqdvPS5pFVPpo5OfjElwNaT2yyekbM0C9vK+yqBqEmHr6lDkRnYNoTBYlAgJa7a4w==}
'@types/debug@4.1.12':
resolution: {integrity: sha512-vIChWdVG3LG1SMxEvI/AK+FWJthlrqlTu7fbrlywTkkaONwk/UAGaULXRlf8vkzFBLVm0zkMdCquhL5aOjhXPQ==}
@@ -4772,6 +4790,14 @@ packages:
typescript:
optional: true
d3-path@3.1.0:
resolution: {integrity: sha512-p3KP5HCf/bvjBSSKuXid6Zqijx7wIfNW+J/maPs+iwR35at5JCbLUT0LzF1cnjbCHWhqzQTIN2Jpe8pRebIEFQ==}
engines: {node: '>=12'}
d3-shape@3.2.0:
resolution: {integrity: sha512-SaLBuwGm3MOViRq2ABk3eLoxwZELpH6zhl3FbAoJ7Vm1gofKx6El1Ib5z23NUEhF9AsGl7y+dzLe5Cw2AArGTA==}
engines: {node: '>=12'}
data-urls@6.0.0:
resolution: {integrity: sha512-BnBS08aLUM+DKamupXs3w2tJJoqU+AkaE/+6vQxi/G/DPmIZFJJp9Dkb1kM03AZx8ADehDUZgsNxju3mPXZYIA==}
engines: {node: '>=20'}
@@ -11561,6 +11587,12 @@ snapshots:
'@types/css-font-loading-module@0.0.7': {}
'@types/d3-path@3.1.1': {}
'@types/d3-shape@3.1.8':
dependencies:
'@types/d3-path': 3.1.1
'@types/debug@4.1.12':
dependencies:
'@types/ms': 2.1.0
@@ -12988,6 +13020,12 @@ snapshots:
optionalDependencies:
typescript: 5.9.3
d3-path@3.1.0: {}
d3-shape@3.2.0:
dependencies:
d3-path: 3.1.0
data-urls@6.0.0:
dependencies:
whatwg-mimetype: 4.0.0

View File

@@ -32,6 +32,7 @@ catalog:
'@storybook/vue3': ^10.1.9
'@storybook/vue3-vite': ^10.1.9
'@tailwindcss/vite': ^4.1.12
'@types/d3-shape': ^3.1.8
'@types/fs-extra': ^11.0.4
'@types/jsdom': ^21.1.7
'@types/node': ^24.1.0
@@ -48,6 +49,7 @@ catalog:
axios: ^1.8.2
cross-env: ^10.1.0
cva: 1.0.0-beta.4
d3-shape: ^3.2.0
dotenv: ^16.4.5
eslint: ^9.39.1
eslint-config-prettier: ^10.1.8

View File

@@ -0,0 +1,220 @@
<template>
<div
class="widget-expands relative flex h-full w-full flex-col gap-1"
@pointerdown.stop
@pointermove.stop
@pointerup.stop
>
<div
class="relative min-h-0 flex-1 overflow-hidden rounded-[5px] bg-node-component-surface"
>
<div
v-if="!imageUrl"
class="flex size-full flex-col items-center justify-center text-center"
>
<i class="mb-2 icon-[lucide--image] h-12 w-12" />
<p class="text-sm">{{ $t('colorBalance.noInputImage') }}</p>
</div>
<canvas
v-show="imageUrl"
ref="glCanvas"
class="block size-full select-none object-contain"
/>
</div>
<div
class="flex h-8 shrink-0 items-center gap-1 rounded-sm bg-component-node-widget-background p-1"
>
<Button
v-for="tab in TABS"
:key="tab.value"
variant="textonly"
size="unset"
:class="
cn(
'flex-1 self-stretch px-2 text-xs transition-colors',
activeTab === tab.value
? 'rounded-sm bg-component-node-widget-background-selected text-base-foreground'
: 'text-node-text-muted hover:text-node-text'
)
"
@click="activeTab = tab.value"
>
{{ $t(tab.label) }}
</Button>
</div>
<template v-for="tab in TABS" :key="tab.value">
<div
v-show="activeTab === tab.value"
class="grid shrink-0 grid-cols-[auto_1fr_auto] gap-x-2 gap-y-1"
>
<label class="content-center text-xs text-node-component-slot-text">
{{ $t('colorBalance.cyanRed') }}
</label>
<GradientSlider
v-model="fields[tab.value].red.value"
:min="-100"
:max="100"
:step="1"
:stops="CYAN_RED_STOPS"
class="h-7"
/>
<input
v-model.number="fields[tab.value].red.value"
type="number"
:min="-100"
:max="100"
:step="1"
class="h-7 w-14 rounded-lg border-none bg-component-node-widget-background px-2 text-xs text-component-node-foreground focus:outline-0"
/>
<label class="content-center text-xs text-node-component-slot-text">
{{ $t('colorBalance.magentaGreen') }}
</label>
<GradientSlider
v-model="fields[tab.value].green.value"
:min="-100"
:max="100"
:step="1"
:stops="MAGENTA_GREEN_STOPS"
class="h-7"
/>
<input
v-model.number="fields[tab.value].green.value"
type="number"
:min="-100"
:max="100"
:step="1"
class="h-7 w-14 rounded-lg border-none bg-component-node-widget-background px-2 text-xs text-component-node-foreground focus:outline-0"
/>
<label class="content-center text-xs text-node-component-slot-text">
{{ $t('colorBalance.yellowBlue') }}
</label>
<GradientSlider
v-model="fields[tab.value].blue.value"
:min="-100"
:max="100"
:step="1"
:stops="YELLOW_BLUE_STOPS"
class="h-7"
/>
<input
v-model.number="fields[tab.value].blue.value"
type="number"
:min="-100"
:max="100"
:step="1"
class="h-7 w-14 rounded-lg border-none bg-component-node-widget-background px-2 text-xs text-component-node-foreground focus:outline-0"
/>
</div>
</template>
<Button
variant="secondary"
size="md"
class="shrink-0 gap-2 rounded-lg border border-component-node-border bg-component-node-background text-xs text-muted-foreground hover:text-base-foreground"
@click="resetAll"
>
<i class="icon-[lucide--undo-2]" />
{{ $t('colorBalance.reset') }}
</Button>
</div>
</template>
<script setup lang="ts">
import { computed, ref, shallowRef } from 'vue'
import GradientSlider from '@/components/colorcorrect/GradientSlider.vue'
import {
CYAN_RED_STOPS,
MAGENTA_GREEN_STOPS,
YELLOW_BLUE_STOPS
} from '@/components/colorbalance/gradients'
import Button from '@/components/ui/button/Button.vue'
import { useColorBalance } from '@/composables/useColorBalance'
import { useWebGLColorBalance } from '@/composables/useWebGLColorBalance'
import type { ColorBalanceSettings } from '@/lib/litegraph/src/types/widgets'
import type { NodeId } from '@/platform/workflow/validation/schemas/workflowSchema'
import { cn } from '@/utils/tailwindUtil'
const props = defineProps<{
nodeId: NodeId
}>()
const modelValue = defineModel<ColorBalanceSettings>({
default: () => ({
shadows_red: 0,
shadows_green: 0,
shadows_blue: 0,
midtones_red: 0,
midtones_green: 0,
midtones_blue: 0,
highlights_red: 0,
highlights_green: 0,
highlights_blue: 0
})
})
const TABS = [
{ value: 'shadows' as const, label: 'colorBalance.shadows' },
{ value: 'midtones' as const, label: 'colorBalance.midtones' },
{ value: 'highlights' as const, label: 'colorBalance.highlights' }
]
type TonalRange = 'shadows' | 'midtones' | 'highlights'
const activeTab = ref<string>('shadows')
function rangeFieldAccessor(
range: TonalRange,
channel: 'red' | 'green' | 'blue'
) {
const key = `${range}_${channel}` as keyof ColorBalanceSettings
return computed({
get: () => modelValue.value[key],
set: (v) => {
modelValue.value = { ...modelValue.value, [key]: v }
}
})
}
const fields = {
shadows: {
red: rangeFieldAccessor('shadows', 'red'),
green: rangeFieldAccessor('shadows', 'green'),
blue: rangeFieldAccessor('shadows', 'blue')
},
midtones: {
red: rangeFieldAccessor('midtones', 'red'),
green: rangeFieldAccessor('midtones', 'green'),
blue: rangeFieldAccessor('midtones', 'blue')
},
highlights: {
red: rangeFieldAccessor('highlights', 'red'),
green: rangeFieldAccessor('highlights', 'green'),
blue: rangeFieldAccessor('highlights', 'blue')
}
}
const glCanvas = shallowRef<HTMLCanvasElement | null>(null)
const { imageUrl } = useColorBalance(props.nodeId)
useWebGLColorBalance(glCanvas, imageUrl, modelValue)
function resetAll() {
modelValue.value = {
shadows_red: 0,
shadows_green: 0,
shadows_blue: 0,
midtones_red: 0,
midtones_green: 0,
midtones_blue: 0,
highlights_red: 0,
highlights_green: 0,
highlights_blue: 0
}
}
</script>

View File

@@ -0,0 +1,16 @@
import type { ColorStop } from '@/components/colorcorrect/gradients'
export const CYAN_RED_STOPS: ColorStop[] = [
[0, 0, 255, 255],
[1, 255, 0, 0]
]
export const MAGENTA_GREEN_STOPS: ColorStop[] = [
[0, 255, 0, 255],
[1, 0, 255, 0]
]
export const YELLOW_BLUE_STOPS: ColorStop[] = [
[0, 255, 255, 0],
[1, 0, 0, 255]
]

View File

@@ -0,0 +1,70 @@
import { mount } from '@vue/test-utils'
import { describe, expect, it } from 'vitest'
import GradientSlider from './GradientSlider.vue'
import type { ColorStop } from './gradients'
import { BRIGHTNESS_STOPS, interpolateStops } from './gradients'
const TEST_STOPS: ColorStop[] = [
[0, 0, 0, 0],
[1, 255, 255, 255]
]
function mountSlider(props: {
stops?: ColorStop[]
modelValue: number
min?: number
max?: number
step?: number
}) {
return mount(GradientSlider, {
props: { stops: TEST_STOPS, ...props }
})
}
describe('GradientSlider', () => {
it('passes min, max, step to SliderRoot', () => {
const wrapper = mountSlider({
modelValue: 50,
min: -100,
max: 100,
step: 5
})
const thumb = wrapper.find('[role="slider"]')
expect(thumb.attributes('aria-valuemin')).toBe('-100')
expect(thumb.attributes('aria-valuemax')).toBe('100')
})
it('renders slider root with track and thumb', () => {
const wrapper = mountSlider({ modelValue: 0 })
expect(wrapper.find('[data-slider-impl]').exists()).toBe(true)
expect(wrapper.find('[role="slider"]').exists()).toBe(true)
})
it('does not render SliderRange', () => {
const wrapper = mountSlider({ modelValue: 50 })
expect(wrapper.find('[data-slot="slider-range"]').exists()).toBe(false)
})
})
describe('interpolateStops', () => {
it('returns start color at t=0', () => {
expect(interpolateStops(BRIGHTNESS_STOPS, 0)).toBe('rgb(0,0,0)')
})
it('returns end color at t=1', () => {
expect(interpolateStops(BRIGHTNESS_STOPS, 1)).toBe('rgb(255,255,255)')
})
it('returns midpoint color at t=0.5', () => {
expect(interpolateStops(BRIGHTNESS_STOPS, 0.5)).toBe('rgb(128,128,128)')
})
it('clamps values below 0', () => {
expect(interpolateStops(BRIGHTNESS_STOPS, -1)).toBe('rgb(0,0,0)')
})
it('clamps values above 1', () => {
expect(interpolateStops(BRIGHTNESS_STOPS, 2)).toBe('rgb(255,255,255)')
})
})

View File

@@ -0,0 +1,87 @@
<script setup lang="ts">
import { SliderRoot, SliderThumb, SliderTrack } from 'reka-ui'
import { computed, ref } from 'vue'
import type { ColorStop } from '@/components/colorcorrect/gradients'
import {
interpolateStops,
stopsToGradient
} from '@/components/colorcorrect/gradients'
import { cn } from '@/utils/tailwindUtil'
const {
stops,
min = 0,
max = 100,
step = 1,
disabled = false
} = defineProps<{
stops: ColorStop[]
min?: number
max?: number
step?: number
disabled?: boolean
}>()
const modelValue = defineModel<number>({ required: true })
const sliderValue = computed({
get: () => [modelValue.value],
set: (v: number[]) => {
if (v.length) modelValue.value = v[0]
}
})
const gradient = computed(() => stopsToGradient(stops))
const thumbColor = computed(() => {
const t = max === min ? 0 : (modelValue.value - min) / (max - min)
return interpolateStops(stops, t)
})
const pressed = ref(false)
</script>
<template>
<SliderRoot
v-model="sliderValue"
:min="min"
:max="max"
:step="step"
:disabled="disabled"
:class="
cn(
'relative flex w-full touch-none items-center select-none',
'data-[disabled]:opacity-50'
)
"
:style="{ '--reka-slider-thumb-transform': 'translate(-50%, -50%)' }"
@slide-start="pressed = true"
@slide-move="pressed = true"
@slide-end="pressed = false"
>
<SliderTrack
:class="
cn(
'relative h-2.5 w-full grow cursor-pointer overflow-visible rounded-full',
'before:absolute before:-inset-2 before:block before:bg-transparent'
)
"
:style="{ background: gradient }"
>
<SliderThumb
:class="
cn(
'block size-4 shrink-0 cursor-grab rounded-full shadow-md ring-1 ring-black/25',
'transition-[color,box-shadow,background-color]',
'before:absolute before:-inset-1.5 before:block before:rounded-full before:bg-transparent',
'hover:ring-2 hover:ring-black/40 focus-visible:ring-2 focus-visible:ring-black/40 focus-visible:outline-hidden',
'disabled:pointer-events-none disabled:opacity-50',
{ 'cursor-grabbing': pressed }
)
"
:style="{ backgroundColor: thumbColor, top: '50%' }"
/>
</SliderTrack>
</SliderRoot>
</template>

View File

@@ -0,0 +1,243 @@
<template>
<div
class="widget-expands relative flex h-full w-full flex-col gap-1"
@pointerdown.stop
@pointermove.stop
@pointerup.stop
>
<div
class="relative min-h-0 flex-1 overflow-hidden rounded-[5px] bg-node-component-surface"
>
<div v-if="isLoading" class="flex size-full items-center justify-center">
<span class="text-sm">{{ $t('colorCorrect.loading') }}</span>
</div>
<div
v-else-if="!imageUrl"
class="flex size-full flex-col items-center justify-center text-center"
>
<i class="mb-2 icon-[lucide--image] h-12 w-12" />
<p class="text-sm">{{ $t('colorCorrect.noInputImage') }}</p>
</div>
<template v-else>
<img
:src="imageUrl"
:alt="$t('colorCorrect.previewAlt')"
draggable="false"
:style="{ filter: filterStyle }"
class="block size-full select-none object-contain"
@load="handleImageLoad"
@error="handleImageError"
@dragstart.prevent
/>
<div
v-if="temperatureOverlayStyle"
class="pointer-events-none absolute inset-0"
:style="temperatureOverlayStyle"
/>
</template>
</div>
<div class="grid shrink-0 grid-cols-[auto_1fr_auto] gap-x-2 gap-y-1">
<label class="content-center text-xs text-node-component-slot-text">
{{ $t('colorCorrect.temperature') }}
</label>
<GradientSlider
v-model="temperature"
:min="-100"
:max="100"
:step="1"
:stops="TEMPERATURE_STOPS"
class="h-7"
/>
<input
v-model.number="temperature"
type="number"
:min="-100"
:max="100"
:step="1"
class="h-7 w-14 rounded-lg border-none bg-component-node-widget-background px-2 text-xs text-component-node-foreground focus:outline-0"
/>
<label class="content-center text-xs text-node-component-slot-text">
{{ $t('colorCorrect.hue') }}
</label>
<GradientSlider
v-model="hue"
:min="-90"
:max="90"
:step="1"
:stops="HUE_STOPS"
class="h-7"
/>
<input
v-model.number="hue"
type="number"
:min="-90"
:max="90"
:step="1"
class="h-7 w-14 rounded-lg border-none bg-component-node-widget-background px-2 text-xs text-component-node-foreground focus:outline-0"
/>
<label class="content-center text-xs text-node-component-slot-text">
{{ $t('colorCorrect.brightness') }}
</label>
<GradientSlider
v-model="brightness"
:min="-100"
:max="100"
:step="1"
:stops="BRIGHTNESS_STOPS"
class="h-7"
/>
<input
v-model.number="brightness"
type="number"
:min="-100"
:max="100"
:step="1"
class="h-7 w-14 rounded-lg border-none bg-component-node-widget-background px-2 text-xs text-component-node-foreground focus:outline-0"
/>
<label class="content-center text-xs text-node-component-slot-text">
{{ $t('colorCorrect.contrast') }}
</label>
<GradientSlider
v-model="contrast"
:min="-100"
:max="100"
:step="1"
:stops="CONTRAST_STOPS"
class="h-7"
/>
<input
v-model.number="contrast"
type="number"
:min="-100"
:max="100"
:step="1"
class="h-7 w-14 rounded-lg border-none bg-component-node-widget-background px-2 text-xs text-component-node-foreground focus:outline-0"
/>
<label class="content-center text-xs text-node-component-slot-text">
{{ $t('colorCorrect.saturation') }}
</label>
<GradientSlider
v-model="saturation"
:min="-100"
:max="100"
:step="1"
:stops="SATURATION_STOPS"
class="h-7"
/>
<input
v-model.number="saturation"
type="number"
:min="-100"
:max="100"
:step="1"
class="h-7 w-14 rounded-lg border-none bg-component-node-widget-background px-2 text-xs text-component-node-foreground focus:outline-0"
/>
<label class="content-center text-xs text-node-component-slot-text">
{{ $t('colorCorrect.gamma') }}
</label>
<GradientSlider
v-model="gamma"
:min="0.2"
:max="2.2"
:step="1"
:stops="GAMMA_STOPS"
class="h-7"
/>
<input
v-model.number="gamma"
type="number"
:min="0.2"
:max="2.2"
:step="1"
class="h-7 w-14 rounded-lg border-none bg-component-node-widget-background px-2 text-xs text-component-node-foreground focus:outline-0"
/>
</div>
<Button
variant="secondary"
size="md"
class="shrink-0 gap-2 rounded-lg border border-component-node-border bg-component-node-background text-xs text-muted-foreground hover:text-base-foreground"
@click="resetAll"
>
<i class="icon-[lucide--undo-2]" />
{{ $t('colorCorrect.reset') }}
</Button>
</div>
</template>
<script setup lang="ts">
import { computed } from 'vue'
import GradientSlider from '@/components/colorcorrect/GradientSlider.vue'
import {
BRIGHTNESS_STOPS,
CONTRAST_STOPS,
GAMMA_STOPS,
HUE_STOPS,
SATURATION_STOPS,
TEMPERATURE_STOPS
} from '@/components/colorcorrect/gradients'
import Button from '@/components/ui/button/Button.vue'
import { useColorCorrect } from '@/composables/useColorCorrect'
import type { ColorCorrectSettings } from '@/lib/litegraph/src/types/widgets'
import type { NodeId } from '@/platform/workflow/validation/schemas/workflowSchema'
const props = defineProps<{
nodeId: NodeId
}>()
const modelValue = defineModel<ColorCorrectSettings>({
default: () => ({
temperature: 0,
hue: 0,
brightness: 0,
contrast: 0,
saturation: 0,
gamma: 1.0
})
})
function fieldAccessor(key: keyof ColorCorrectSettings) {
return computed({
get: () => modelValue.value[key],
set: (v) => {
modelValue.value = { ...modelValue.value, [key]: v }
}
})
}
const temperature = fieldAccessor('temperature')
const hue = fieldAccessor('hue')
const brightness = fieldAccessor('brightness')
const contrast = fieldAccessor('contrast')
const saturation = fieldAccessor('saturation')
const gamma = fieldAccessor('gamma')
const {
imageUrl,
isLoading,
filterStyle,
temperatureOverlayStyle,
handleImageLoad,
handleImageError
} = useColorCorrect(props.nodeId, modelValue)
function resetAll() {
modelValue.value = {
temperature: 0,
hue: 0,
brightness: 0,
contrast: 0,
saturation: 0,
gamma: 1.0
}
}
</script>

View File

@@ -0,0 +1,79 @@
export type ColorStop = readonly [
offset: number,
r: number,
g: number,
b: number
]
export const HUE_STOPS: ColorStop[] = [
[0, 255, 0, 0],
[1 / 6, 255, 255, 0],
[2 / 6, 0, 255, 0],
[3 / 6, 0, 255, 255],
[4 / 6, 0, 0, 255],
[5 / 6, 255, 0, 255],
[1, 255, 0, 0]
]
export const SATURATION_STOPS: ColorStop[] = [
[0, 128, 128, 128],
[1, 255, 0, 0]
]
export const BRIGHTNESS_STOPS: ColorStop[] = [
[0, 0, 0, 0],
[1, 255, 255, 255]
]
export const TEMPERATURE_STOPS: ColorStop[] = [
[0, 68, 136, 255],
[0.5, 255, 255, 255],
[1, 255, 136, 0]
]
export const CONTRAST_STOPS: ColorStop[] = [
[0, 136, 136, 136],
[0.4, 68, 68, 68],
[0.6, 187, 187, 187],
[0.8, 0, 0, 0],
[1, 255, 255, 255]
]
export const GAMMA_STOPS: ColorStop[] = [
[0, 34, 34, 34],
[0.3, 85, 85, 85],
[0.5, 153, 153, 153],
[0.7, 204, 204, 204],
[1, 255, 255, 255]
]
export function stopsToGradient(stops: ColorStop[]): string {
const colors = stops.map(
([offset, r, g, b]) => `rgb(${r},${g},${b}) ${offset * 100}%`
)
return `linear-gradient(to right, ${colors.join(', ')})`
}
export function interpolateStops(stops: ColorStop[], t: number): string {
const clamped = Math.max(0, Math.min(1, t))
if (clamped <= stops[0][0]) {
const [, r, g, b] = stops[0]
return `rgb(${r},${g},${b})`
}
for (let i = 0; i < stops.length - 1; i++) {
const [o1, r1, g1, b1] = stops[i]
const [o2, r2, g2, b2] = stops[i + 1]
if (clamped >= o1 && clamped <= o2) {
const f = o2 === o1 ? 0 : (clamped - o1) / (o2 - o1)
const r = Math.round(r1 + (r2 - r1) * f)
const g = Math.round(g1 + (g2 - g1) * f)
const b = Math.round(b1 + (b2 - b1) * f)
return `rgb(${r},${g},${b})`
}
}
const [, r, g, b] = stops[stops.length - 1]
return `rgb(${r},${g},${b})`
}

View File

@@ -0,0 +1,113 @@
<template>
<svg
ref="svgRef"
viewBox="0 0 1 1"
preserveAspectRatio="xMidYMid meet"
class="aspect-square w-full cursor-crosshair rounded-[5px] bg-node-component-surface"
@pointerdown="handleSvgPointerDown"
@contextmenu.prevent.stop
>
<line
v-for="v in [0.25, 0.5, 0.75]"
:key="'h' + v"
:x1="0"
:y1="v"
:x2="1"
:y2="v"
stroke="currentColor"
stroke-opacity="0.1"
stroke-width="0.003"
/>
<line
v-for="v in [0.25, 0.5, 0.75]"
:key="'v' + v"
:x1="v"
:y1="0"
:x2="v"
:y2="1"
stroke="currentColor"
stroke-opacity="0.1"
stroke-width="0.003"
/>
<line
x1="0"
y1="1"
x2="1"
y2="0"
stroke="currentColor"
stroke-opacity="0.15"
stroke-width="0.003"
/>
<path
v-if="histogramPath"
:d="histogramPath"
:fill="curveColor"
fill-opacity="0.15"
stroke="none"
/>
<path
:d="curvePath"
fill="none"
:stroke="curveColor"
stroke-width="0.008"
stroke-linecap="round"
/>
<circle
v-for="(point, i) in modelValue"
:key="i"
:cx="point[0]"
:cy="1 - point[1]"
r="0.02"
:fill="curveColor"
stroke="white"
stroke-width="0.004"
class="cursor-grab"
@pointerdown.stop="startDrag(i, $event)"
/>
</svg>
</template>
<script setup lang="ts">
import { computed, useTemplateRef } from 'vue'
import { useCurveEditor } from '@/composables/useCurveEditor'
import type { CurvePoint } from '@/lib/litegraph/src/types/widgets'
const { curveColor = 'white', histogram } = defineProps<{
curveColor?: string
histogram?: Uint32Array | null
}>()
const modelValue = defineModel<CurvePoint[]>({
required: true
})
const svgRef = useTemplateRef<SVGSVGElement>('svgRef')
const { curvePath, handleSvgPointerDown, startDrag } = useCurveEditor({
svgRef,
modelValue
})
const histogramPath = computed(() => {
if (!histogram?.length) return ''
const sorted = Array.from(histogram).sort((a, b) => a - b)
const max = sorted[Math.floor(255 * 0.995)]
if (max === 0) return ''
const step = 1 / 255
let d = 'M0,1'
for (let i = 0; i < 256; i++) {
const x = i * step
const y = 1 - Math.min(1, histogram[i] / max)
d += ` L${x.toFixed(4)},${y.toFixed(4)}`
}
d += ' L1,1 Z'
return d
})
</script>

View File

@@ -0,0 +1,180 @@
<template>
<div
class="widget-expands relative flex h-full w-full flex-col gap-1"
@pointerdown.stop
@pointermove.stop
@pointerup.stop
>
<div
class="relative min-h-0 flex-1 overflow-hidden rounded-[5px] bg-node-component-surface"
>
<div
v-if="!imageUrl"
class="flex size-full flex-col items-center justify-center text-center"
>
<i class="mb-2 icon-[lucide--image] h-12 w-12" />
<p class="text-sm">{{ $t('colorCurves.noInputImage') }}</p>
</div>
<canvas
v-show="imageUrl"
ref="glCanvas"
class="block size-full select-none object-contain"
/>
</div>
<div
class="flex h-8 shrink-0 items-center gap-1 rounded-sm bg-component-node-widget-background p-1"
>
<Button
v-for="tab in TABS"
:key="tab.value"
variant="textonly"
size="unset"
:class="
cn(
'flex-1 self-stretch px-2 text-xs transition-colors',
activeTab === tab.value
? 'rounded-sm bg-component-node-widget-background-selected text-base-foreground'
: 'text-node-text-muted hover:text-node-text'
)
"
@click="activeTab = tab.value"
>
{{ $t(tab.label) }}
</Button>
</div>
<template v-for="tab in TABS" :key="tab.value">
<div v-show="activeTab === tab.value" class="mt-1 shrink-0">
<CurveEditor
v-model="channels[tab.value].value"
:curve-color="tab.color"
:histogram="histogram?.[tab.histogramChannel]"
/>
</div>
</template>
<Button
variant="secondary"
size="md"
class="shrink-0 gap-2 rounded-lg border border-component-node-border bg-component-node-background text-xs text-muted-foreground hover:text-base-foreground"
@click="resetAll"
>
<i class="icon-[lucide--undo-2]" />
{{ $t('colorCurves.reset') }}
</Button>
</div>
</template>
<script setup lang="ts">
import { computed, ref, shallowRef } from 'vue'
import CurveEditor from '@/components/colorcurves/CurveEditor.vue'
import Button from '@/components/ui/button/Button.vue'
import { useColorCurves } from '@/composables/useColorCurves'
import { useWebGLColorCurves } from '@/composables/useWebGLColorCurves'
import type {
ColorCurvesSettings,
CurvePoint
} from '@/lib/litegraph/src/types/widgets'
import type { NodeId } from '@/platform/workflow/validation/schemas/workflowSchema'
import { cn } from '@/utils/tailwindUtil'
const DEFAULT_CURVE: CurvePoint[] = [
[0, 0],
[1, 1]
]
type Channel = keyof ColorCurvesSettings
type HistogramChannel = 'luminance' | 'red' | 'green' | 'blue'
const TABS: {
value: Channel
label: string
color: string
histogramChannel: HistogramChannel
}[] = [
{
value: 'rgb',
label: 'colorCurves.rgb',
color: 'white',
histogramChannel: 'luminance'
},
{
value: 'red',
label: 'colorCurves.red',
color: '#ff4444',
histogramChannel: 'red'
},
{
value: 'green',
label: 'colorCurves.green',
color: '#44cc44',
histogramChannel: 'green'
},
{
value: 'blue',
label: 'colorCurves.blue',
color: '#4488ff',
histogramChannel: 'blue'
}
]
const props = defineProps<{
nodeId: NodeId
}>()
const modelValue = defineModel<ColorCurvesSettings>({
default: () => ({
rgb: [
[0, 0],
[1, 1]
] as CurvePoint[],
red: [
[0, 0],
[1, 1]
] as CurvePoint[],
green: [
[0, 0],
[1, 1]
] as CurvePoint[],
blue: [
[0, 0],
[1, 1]
] as CurvePoint[]
})
})
const activeTab = ref<string>('rgb')
function channelAccessor(channel: Channel) {
return computed({
get: () => modelValue.value[channel],
set: (v) => {
modelValue.value = { ...modelValue.value, [channel]: v }
}
})
}
const channels = {
rgb: channelAccessor('rgb'),
red: channelAccessor('red'),
green: channelAccessor('green'),
blue: channelAccessor('blue')
}
const glCanvas = shallowRef<HTMLCanvasElement | null>(null)
const { imageUrl, histogram } = useColorCurves(props.nodeId)
useWebGLColorCurves(glCanvas, imageUrl, modelValue)
function resetAll() {
modelValue.value = {
rgb: [...DEFAULT_CURVE],
red: [...DEFAULT_CURVE],
green: [...DEFAULT_CURVE],
blue: [...DEFAULT_CURVE]
}
}
</script>

View File

@@ -0,0 +1,102 @@
import type { CurvePoint } from '@/lib/litegraph/src/types/widgets'
/**
* Monotone cubic Hermite interpolation.
* Produces a smooth curve that passes through all control points
* without overshooting (monotone property).
*
* Returns a function that evaluates y for any x in [0, 1].
*/
function createMonotoneInterpolator(
points: CurvePoint[]
): (x: number) => number {
if (points.length === 0) return () => 0
if (points.length === 1) return () => points[0][1]
const sorted = [...points].sort((a, b) => a[0] - b[0])
const n = sorted.length
const xs = sorted.map((p) => p[0])
const ys = sorted.map((p) => p[1])
const deltas: number[] = []
const slopes: number[] = []
for (let i = 0; i < n - 1; i++) {
const dx = xs[i + 1] - xs[i]
deltas.push(dx === 0 ? 0 : (ys[i + 1] - ys[i]) / dx)
}
slopes.push(deltas[0] ?? 0)
for (let i = 1; i < n - 1; i++) {
if (deltas[i - 1] * deltas[i] <= 0) {
slopes.push(0)
} else {
slopes.push((deltas[i - 1] + deltas[i]) / 2)
}
}
slopes.push(deltas[n - 2] ?? 0)
for (let i = 0; i < n - 1; i++) {
if (deltas[i] === 0) {
slopes[i] = 0
slopes[i + 1] = 0
} else {
const alpha = slopes[i] / deltas[i]
const beta = slopes[i + 1] / deltas[i]
const s = alpha * alpha + beta * beta
if (s > 9) {
const t = 3 / Math.sqrt(s)
slopes[i] = t * alpha * deltas[i]
slopes[i + 1] = t * beta * deltas[i]
}
}
}
return (x: number): number => {
if (x <= xs[0]) return ys[0]
if (x >= xs[n - 1]) return ys[n - 1]
let lo = 0
let hi = n - 1
while (lo < hi - 1) {
const mid = (lo + hi) >> 1
if (xs[mid] <= x) lo = mid
else hi = mid
}
const dx = xs[hi] - xs[lo]
if (dx === 0) return ys[lo]
const t = (x - xs[lo]) / dx
const t2 = t * t
const t3 = t2 * t
const h00 = 2 * t3 - 3 * t2 + 1
const h10 = t3 - 2 * t2 + t
const h01 = -2 * t3 + 3 * t2
const h11 = t3 - t2
return (
h00 * ys[lo] +
h10 * dx * slopes[lo] +
h01 * ys[hi] +
h11 * dx * slopes[hi]
)
}
}
/**
* Generate a 256-entry lookup table from curve control points.
* Points are in [0, 1] space; output is clamped to [0, 255] as Uint8.
*/
export function curvesToLUT(points: CurvePoint[]): Uint8Array {
const lut = new Uint8Array(256)
const interpolate = createMonotoneInterpolator(points)
for (let i = 0; i < 256; i++) {
const x = i / 255
const y = interpolate(x)
lut[i] = Math.max(0, Math.min(255, Math.round(y * 255)))
}
return lut
}

View File

@@ -0,0 +1,49 @@
import { onMounted, ref, watch } from 'vue'
import type { LGraphNode, NodeId } from '@/lib/litegraph/src/LGraphNode'
import { app } from '@/scripts/app'
import { useNodeOutputStore } from '@/stores/imagePreviewStore'
export function useColorBalance(nodeId: NodeId) {
const nodeOutputStore = useNodeOutputStore()
const node = ref<LGraphNode | null>(null)
const imageUrl = ref<string | null>(null)
function getInputImageUrl(): string | null {
if (!node.value) return null
const inputNode = node.value.getInputNode(0)
if (!inputNode) return null
const urls = nodeOutputStore.getNodeImageUrls(inputNode)
if (urls?.length) return urls[0]
return null
}
function updateImageUrl() {
imageUrl.value = getInputImageUrl()
}
watch(
() => nodeOutputStore.nodeOutputs,
() => updateImageUrl(),
{ deep: true }
)
watch(
() => nodeOutputStore.nodePreviewImages,
() => updateImageUrl(),
{ deep: true }
)
onMounted(() => {
if (nodeId != null) {
node.value = app.rootGraph?.getNodeById(nodeId) || null
}
updateImageUrl()
})
return { imageUrl }
}

View File

@@ -0,0 +1,108 @@
import type { Ref } from 'vue'
import { computed, onMounted, ref, watch } from 'vue'
import type { LGraphNode, NodeId } from '@/lib/litegraph/src/LGraphNode'
import type { ColorCorrectSettings } from '@/lib/litegraph/src/types/widgets'
import { app } from '@/scripts/app'
import { useNodeOutputStore } from '@/stores/imagePreviewStore'
export function useColorCorrect(
nodeId: NodeId,
modelValue: Ref<ColorCorrectSettings>
) {
const nodeOutputStore = useNodeOutputStore()
const node = ref<LGraphNode | null>(null)
const imageUrl = ref<string | null>(null)
const isLoading = ref(false)
const getInputImageUrl = (): string | null => {
if (!node.value) return null
const inputNode = node.value.getInputNode(0)
if (!inputNode) return null
const urls = nodeOutputStore.getNodeImageUrls(inputNode)
if (urls?.length) return urls[0]
return null
}
const updateImageUrl = () => {
imageUrl.value = getInputImageUrl()
}
const filterStyle = computed(() => {
const v = modelValue.value
const filters: string[] = []
if (v.brightness !== 0) {
filters.push(`brightness(${1 + v.brightness / 100})`)
}
if (v.contrast !== 0) {
filters.push(`contrast(${1 + v.contrast / 100})`)
}
if (v.saturation !== 0) {
filters.push(`saturate(${1 + v.saturation / 100})`)
}
if (v.hue !== 0) {
filters.push(`hue-rotate(${v.hue}deg)`)
}
return filters.length > 0 ? filters.join(' ') : 'none'
})
const temperatureOverlayStyle = computed(() => {
const temp = modelValue.value.temperature
if (temp === 0) return null
const opacity = (Math.abs(temp) / 100) * 0.15
const color =
temp > 0
? `rgba(255, 140, 0, ${opacity})`
: `rgba(0, 100, 255, ${opacity})`
return {
backgroundColor: color
}
})
const handleImageLoad = () => {
isLoading.value = false
}
const handleImageError = () => {
isLoading.value = false
imageUrl.value = null
}
const initialize = () => {
if (nodeId != null) {
node.value = app.rootGraph?.getNodeById(nodeId) || null
}
updateImageUrl()
}
watch(
() => nodeOutputStore.nodeOutputs,
() => updateImageUrl(),
{ deep: true }
)
watch(
() => nodeOutputStore.nodePreviewImages,
() => updateImageUrl(),
{ deep: true }
)
onMounted(initialize)
return {
imageUrl,
isLoading,
filterStyle,
temperatureOverlayStyle,
handleImageLoad,
handleImageError
}
}

View File

@@ -0,0 +1,108 @@
import { onMounted, ref, shallowRef, watch } from 'vue'
import type { LGraphNode, NodeId } from '@/lib/litegraph/src/LGraphNode'
import { app } from '@/scripts/app'
import { useNodeOutputStore } from '@/stores/imagePreviewStore'
interface ImageHistogram {
red: Uint32Array
green: Uint32Array
blue: Uint32Array
luminance: Uint32Array
}
function computeHistogram(imageUrl: string): Promise<ImageHistogram> {
return new Promise((resolve, reject) => {
const img = new Image()
img.crossOrigin = 'anonymous'
img.onload = () => {
const canvas = document.createElement('canvas')
const maxDim = 512
const scale = Math.min(1, maxDim / Math.max(img.width, img.height))
canvas.width = Math.round(img.width * scale)
canvas.height = Math.round(img.height * scale)
const ctx = canvas.getContext('2d')
if (!ctx) return reject(new Error('No 2d context'))
ctx.drawImage(img, 0, 0, canvas.width, canvas.height)
const { data } = ctx.getImageData(0, 0, canvas.width, canvas.height)
const red = new Uint32Array(256)
const green = new Uint32Array(256)
const blue = new Uint32Array(256)
const luminance = new Uint32Array(256)
for (let i = 0; i < data.length; i += 4) {
const r = data[i]
const g = data[i + 1]
const b = data[i + 2]
red[r]++
green[g]++
blue[b]++
luminance[Math.round(0.2126 * r + 0.7152 * g + 0.0722 * b)]++
}
resolve({ red, green, blue, luminance })
}
img.onerror = () => reject(new Error('Failed to load image'))
img.src = imageUrl
})
}
export function useColorCurves(nodeId: NodeId) {
const nodeOutputStore = useNodeOutputStore()
const node = ref<LGraphNode | null>(null)
const imageUrl = ref<string | null>(null)
const histogram = shallowRef<ImageHistogram | null>(null)
function getInputImageUrl(): string | null {
if (!node.value) return null
const inputNode = node.value.getInputNode(0)
if (!inputNode) return null
const urls = nodeOutputStore.getNodeImageUrls(inputNode)
if (urls?.length) return urls[0]
return null
}
function updateImageUrl() {
imageUrl.value = getInputImageUrl()
}
watch(imageUrl, async (url) => {
if (!url) {
histogram.value = null
return
}
try {
histogram.value = await computeHistogram(url)
} catch {
histogram.value = null
}
})
watch(
() => nodeOutputStore.nodeOutputs,
() => updateImageUrl(),
{ deep: true }
)
watch(
() => nodeOutputStore.nodePreviewImages,
() => updateImageUrl(),
{ deep: true }
)
onMounted(() => {
if (nodeId != null) {
node.value = app.rootGraph?.getNodeById(nodeId) || null
}
updateImageUrl()
})
return { imageUrl, histogram }
}

View File

@@ -0,0 +1,101 @@
import { computed, ref } from 'vue'
import type { Ref } from 'vue'
import { curveMonotoneX, line } from 'd3-shape'
import type { CurvePoint } from '@/lib/litegraph/src/types/widgets'
interface UseCurveEditorOptions {
svgRef: Ref<SVGSVGElement | null>
modelValue: Ref<CurvePoint[]>
}
export function useCurveEditor({ svgRef, modelValue }: UseCurveEditorOptions) {
const dragIndex = ref(-1)
const sortedPoints = computed(() => {
const points = modelValue.value
if (!Array.isArray(points)) return []
return [...points].sort((a, b) => a[0] - b[0])
})
const curvePath = computed(() => {
const points = sortedPoints.value
if (points.length < 2) return ''
const lineGenerator = line<CurvePoint>()
.x((d) => d[0])
.y((d) => 1 - d[1])
.curve(curveMonotoneX)
return lineGenerator(points) ?? ''
})
function svgCoords(e: PointerEvent): [number, number] {
const svg = svgRef.value
if (!svg) return [0, 0]
const pt = svg.createSVGPoint()
pt.x = e.clientX
pt.y = e.clientY
const ctm = svg.getScreenCTM()
if (!ctm) return [0, 0]
const svgPt = pt.matrixTransform(ctm.inverse())
return [
Math.max(0, Math.min(1, svgPt.x)),
Math.max(0, Math.min(1, 1 - svgPt.y))
]
}
function handleSvgPointerDown(e: PointerEvent) {
if (e.button !== 0) return
const [x, y] = svgCoords(e)
const newPoints: CurvePoint[] = [...modelValue.value, [x, y]]
modelValue.value = newPoints
const newIndex = newPoints.length - 1
startDrag(newIndex, e)
}
function startDrag(index: number, e: PointerEvent) {
if (e.button === 2 || (e.button === 0 && e.ctrlKey)) {
if (modelValue.value.length > 2) {
const newPoints = [...modelValue.value]
newPoints.splice(index, 1)
modelValue.value = newPoints
}
return
}
dragIndex.value = index
const svg = svgRef.value
if (!svg) return
svg.setPointerCapture(e.pointerId)
const onMove = (ev: PointerEvent) => {
if (dragIndex.value < 0) return
const [x, y] = svgCoords(ev)
const newPoints = [...modelValue.value]
newPoints[dragIndex.value] = [x, y]
modelValue.value = newPoints
}
const onUp = () => {
dragIndex.value = -1
svg.removeEventListener('pointermove', onMove)
svg.removeEventListener('pointerup', onUp)
}
svg.addEventListener('pointermove', onMove)
svg.addEventListener('pointerup', onUp)
}
return {
curvePath,
handleSvgPointerDown,
startDrag
}
}

View File

@@ -0,0 +1,225 @@
import type { Ref, ShallowRef } from 'vue'
import { onBeforeUnmount, watch } from 'vue'
import type { ColorBalanceSettings } from '@/lib/litegraph/src/types/widgets'
const VERTEX_SRC = `#version 300 es
in vec2 a_position;
in vec2 a_texCoord;
out vec2 v_texCoord;
void main() {
gl_Position = vec4(a_position, 0.0, 1.0);
v_texCoord = a_texCoord;
}
`
const FRAGMENT_SRC = `#version 300 es
precision highp float;
uniform sampler2D u_image;
uniform vec3 u_shadows;
uniform vec3 u_midtones;
uniform vec3 u_highlights;
in vec2 v_texCoord;
out vec4 outColor;
void main() {
vec4 texel = texture(u_image, v_texCoord);
vec3 color = texel.rgb;
float luminance = dot(color, vec3(0.2126, 0.7152, 0.0722));
float st = clamp(luminance * 2.0, 0.0, 1.0);
float shadowWeight = 1.0 - st * st * (3.0 - 2.0 * st);
float ht = clamp(luminance * 2.0 - 1.0, 0.0, 1.0);
float highlightWeight = ht * ht * (3.0 - 2.0 * ht);
float midtoneWeight = 1.0 - shadowWeight - highlightWeight;
vec3 offset = shadowWeight * u_shadows
+ midtoneWeight * u_midtones
+ highlightWeight * u_highlights;
outColor = vec4(clamp(color + offset, 0.0, 1.0), texel.a);
}
`
function compileShader(
gl: WebGL2RenderingContext,
type: number,
source: string
): WebGLShader {
const shader = gl.createShader(type)!
gl.shaderSource(shader, source)
gl.compileShader(shader)
return shader
}
interface GLState {
gl: WebGL2RenderingContext
program: WebGLProgram
texture: WebGLTexture
uShadows: WebGLUniformLocation
uMidtones: WebGLUniformLocation
uHighlights: WebGLUniformLocation
vao: WebGLVertexArrayObject
}
function initGL(canvas: HTMLCanvasElement): GLState | null {
const gl = canvas.getContext('webgl2', {
premultipliedAlpha: false,
alpha: true
})
if (!gl) return null
const vs = compileShader(gl, gl.VERTEX_SHADER, VERTEX_SRC)
const fs = compileShader(gl, gl.FRAGMENT_SHADER, FRAGMENT_SRC)
const program = gl.createProgram()!
gl.attachShader(program, vs)
gl.attachShader(program, fs)
gl.linkProgram(program)
gl.useProgram(program)
gl.deleteShader(vs)
gl.deleteShader(fs)
const vao = gl.createVertexArray()!
gl.bindVertexArray(vao)
const positions = new Float32Array([
-1, -1, 0, 1, 1, -1, 1, 1, -1, 1, 0, 0, -1, 1, 0, 0, 1, -1, 1, 1, 1, 1, 1, 0
])
const buf = gl.createBuffer()!
gl.bindBuffer(gl.ARRAY_BUFFER, buf)
gl.bufferData(gl.ARRAY_BUFFER, positions, gl.STATIC_DRAW)
const aPos = gl.getAttribLocation(program, 'a_position')
gl.enableVertexAttribArray(aPos)
gl.vertexAttribPointer(aPos, 2, gl.FLOAT, false, 16, 0)
const aTex = gl.getAttribLocation(program, 'a_texCoord')
gl.enableVertexAttribArray(aTex)
gl.vertexAttribPointer(aTex, 2, gl.FLOAT, false, 16, 8)
const texture = gl.createTexture()!
gl.activeTexture(gl.TEXTURE0)
gl.bindTexture(gl.TEXTURE_2D, texture)
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE)
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE)
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.LINEAR)
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, gl.LINEAR)
gl.uniform1i(gl.getUniformLocation(program, 'u_image'), 0)
return {
gl,
program,
texture,
uShadows: gl.getUniformLocation(program, 'u_shadows')!,
uMidtones: gl.getUniformLocation(program, 'u_midtones')!,
uHighlights: gl.getUniformLocation(program, 'u_highlights')!,
vao
}
}
function uploadTexture(state: GLState, img: HTMLImageElement) {
const { gl, texture } = state
gl.canvas.width = img.naturalWidth
gl.canvas.height = img.naturalHeight
gl.viewport(0, 0, img.naturalWidth, img.naturalHeight)
gl.bindTexture(gl.TEXTURE_2D, texture)
gl.texImage2D(gl.TEXTURE_2D, 0, gl.RGBA, gl.RGBA, gl.UNSIGNED_BYTE, img)
}
function render(state: GLState, settings: ColorBalanceSettings) {
const { gl, uShadows, uMidtones, uHighlights } = state
gl.uniform3f(
uShadows,
settings.shadows_red / 100,
settings.shadows_green / 100,
settings.shadows_blue / 100
)
gl.uniform3f(
uMidtones,
settings.midtones_red / 100,
settings.midtones_green / 100,
settings.midtones_blue / 100
)
gl.uniform3f(
uHighlights,
settings.highlights_red / 100,
settings.highlights_green / 100,
settings.highlights_blue / 100
)
gl.drawArrays(gl.TRIANGLES, 0, 6)
}
export function useWebGLColorBalance(
canvasRef: ShallowRef<HTMLCanvasElement | null>,
imageUrl: Ref<string | null>,
settings: Ref<ColorBalanceSettings>
) {
let state: GLState | null = null
let currentImage: HTMLImageElement | null = null
let pendingRaf = 0
function scheduleRender() {
if (!state || !currentImage?.complete) return
if (pendingRaf) return
pendingRaf = requestAnimationFrame(() => {
pendingRaf = 0
if (state) render(state, settings.value)
})
}
function loadImage(url: string | null) {
currentImage = null
if (!url || !state) return
const img = new Image()
img.crossOrigin = 'anonymous'
img.onload = () => {
if (!state) return
currentImage = img
uploadTexture(state, img)
render(state, settings.value)
}
img.src = url
}
watch(
canvasRef,
(canvas) => {
if (!canvas) {
cleanup()
return
}
state = initGL(canvas)
if (state && imageUrl.value) loadImage(imageUrl.value)
},
{ immediate: true }
)
watch(imageUrl, (url) => loadImage(url))
watch(settings, () => scheduleRender(), { deep: true })
function cleanup() {
if (pendingRaf) {
cancelAnimationFrame(pendingRaf)
pendingRaf = 0
}
if (state) {
const { gl, program, texture, vao } = state
gl.deleteTexture(texture)
gl.deleteVertexArray(vao)
gl.deleteProgram(program)
state = null
}
currentImage = null
}
onBeforeUnmount(cleanup)
}

View File

@@ -0,0 +1,252 @@
import type { Ref, ShallowRef } from 'vue'
import { onBeforeUnmount, watch } from 'vue'
import { curvesToLUT } from '@/components/colorcurves/curveUtils'
import type { ColorCurvesSettings } from '@/lib/litegraph/src/types/widgets'
const VERTEX_SRC = `#version 300 es
in vec2 a_position;
in vec2 a_texCoord;
out vec2 v_texCoord;
void main() {
gl_Position = vec4(a_position, 0.0, 1.0);
v_texCoord = a_texCoord;
}
`
const FRAGMENT_SRC = `#version 300 es
precision highp float;
uniform sampler2D u_image;
uniform sampler2D u_lut;
in vec2 v_texCoord;
out vec4 outColor;
void main() {
vec4 texel = texture(u_image, v_texCoord);
float r = texture(u_lut, vec2(texel.r, 0.25)).r;
float g = texture(u_lut, vec2(texel.g, 0.25)).g;
float b = texture(u_lut, vec2(texel.b, 0.25)).b;
r = texture(u_lut, vec2(r, 0.75)).a;
g = texture(u_lut, vec2(g, 0.75)).a;
b = texture(u_lut, vec2(b, 0.75)).a;
outColor = vec4(r, g, b, texel.a);
}
`
function compileShader(
gl: WebGL2RenderingContext,
type: number,
source: string
): WebGLShader {
const shader = gl.createShader(type)!
gl.shaderSource(shader, source)
gl.compileShader(shader)
return shader
}
interface GLState {
gl: WebGL2RenderingContext
program: WebGLProgram
imageTexture: WebGLTexture
lutTexture: WebGLTexture
vao: WebGLVertexArrayObject
}
function initGL(canvas: HTMLCanvasElement): GLState | null {
const gl = canvas.getContext('webgl2', {
premultipliedAlpha: false,
alpha: true
})
if (!gl) return null
const vs = compileShader(gl, gl.VERTEX_SHADER, VERTEX_SRC)
const fs = compileShader(gl, gl.FRAGMENT_SHADER, FRAGMENT_SRC)
const program = gl.createProgram()!
gl.attachShader(program, vs)
gl.attachShader(program, fs)
gl.linkProgram(program)
gl.useProgram(program)
gl.deleteShader(vs)
gl.deleteShader(fs)
const vao = gl.createVertexArray()!
gl.bindVertexArray(vao)
const positions = new Float32Array([
-1, -1, 0, 1, 1, -1, 1, 1, -1, 1, 0, 0, -1, 1, 0, 0, 1, -1, 1, 1, 1, 1, 1, 0
])
const buf = gl.createBuffer()!
gl.bindBuffer(gl.ARRAY_BUFFER, buf)
gl.bufferData(gl.ARRAY_BUFFER, positions, gl.STATIC_DRAW)
const aPos = gl.getAttribLocation(program, 'a_position')
gl.enableVertexAttribArray(aPos)
gl.vertexAttribPointer(aPos, 2, gl.FLOAT, false, 16, 0)
const aTex = gl.getAttribLocation(program, 'a_texCoord')
gl.enableVertexAttribArray(aTex)
gl.vertexAttribPointer(aTex, 2, gl.FLOAT, false, 16, 8)
const imageTexture = gl.createTexture()!
gl.activeTexture(gl.TEXTURE0)
gl.bindTexture(gl.TEXTURE_2D, imageTexture)
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE)
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE)
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.LINEAR)
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, gl.LINEAR)
const lutTexture = gl.createTexture()!
gl.activeTexture(gl.TEXTURE1)
gl.bindTexture(gl.TEXTURE_2D, lutTexture)
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE)
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE)
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.NEAREST)
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, gl.NEAREST)
gl.uniform1i(gl.getUniformLocation(program, 'u_image'), 0)
gl.uniform1i(gl.getUniformLocation(program, 'u_lut'), 1)
return { gl, program, imageTexture, lutTexture, vao }
}
function uploadImageTexture(state: GLState, img: HTMLImageElement) {
const { gl, imageTexture } = state
gl.canvas.width = img.naturalWidth
gl.canvas.height = img.naturalHeight
gl.viewport(0, 0, img.naturalWidth, img.naturalHeight)
gl.activeTexture(gl.TEXTURE0)
gl.bindTexture(gl.TEXTURE_2D, imageTexture)
gl.texImage2D(gl.TEXTURE_2D, 0, gl.RGBA, gl.RGBA, gl.UNSIGNED_BYTE, img)
}
const DEFAULT_POINTS: [number, number][] = [
[0, 0],
[1, 1]
]
function uploadLUT(state: GLState, settings: ColorCurvesSettings) {
const { gl, lutTexture } = state
const redLUT = curvesToLUT(settings.red ?? DEFAULT_POINTS)
const greenLUT = curvesToLUT(settings.green ?? DEFAULT_POINTS)
const blueLUT = curvesToLUT(settings.blue ?? DEFAULT_POINTS)
const rgbLUT = curvesToLUT(settings.rgb ?? DEFAULT_POINTS)
// 256x2 RGBA texture
// Row 0 (y=0.25): R=redLUT, G=greenLUT, B=blueLUT, A=unused
// Row 1 (y=0.75): R/G/B/A = rgbLUT (master curve)
const data = new Uint8Array(256 * 2 * 4)
for (let i = 0; i < 256; i++) {
// Row 0: per-channel curves
const offset0 = i * 4
data[offset0] = redLUT[i]
data[offset0 + 1] = greenLUT[i]
data[offset0 + 2] = blueLUT[i]
data[offset0 + 3] = 255
// Row 1: RGB master curve
const offset1 = (256 + i) * 4
data[offset1] = rgbLUT[i]
data[offset1 + 1] = rgbLUT[i]
data[offset1 + 2] = rgbLUT[i]
data[offset1 + 3] = rgbLUT[i]
}
gl.activeTexture(gl.TEXTURE1)
gl.bindTexture(gl.TEXTURE_2D, lutTexture)
gl.texImage2D(
gl.TEXTURE_2D,
0,
gl.RGBA,
256,
2,
0,
gl.RGBA,
gl.UNSIGNED_BYTE,
data
)
}
function render(state: GLState) {
state.gl.drawArrays(state.gl.TRIANGLES, 0, 6)
}
export function useWebGLColorCurves(
canvasRef: ShallowRef<HTMLCanvasElement | null>,
imageUrl: Ref<string | null>,
settings: Ref<ColorCurvesSettings>
) {
let state: GLState | null = null
let currentImage: HTMLImageElement | null = null
let pendingRaf = 0
function scheduleRender() {
if (!state || !currentImage?.complete) return
if (pendingRaf) return
pendingRaf = requestAnimationFrame(() => {
pendingRaf = 0
if (state) {
uploadLUT(state, settings.value)
render(state)
}
})
}
function loadImage(url: string | null) {
currentImage = null
if (!url || !state) return
const img = new Image()
img.crossOrigin = 'anonymous'
img.onload = () => {
if (!state) return
currentImage = img
uploadImageTexture(state, img)
uploadLUT(state, settings.value)
render(state)
}
img.src = url
}
watch(
canvasRef,
(canvas) => {
if (!canvas) {
cleanup()
return
}
state = initGL(canvas)
if (state && imageUrl.value) loadImage(imageUrl.value)
},
{ immediate: true }
)
watch(imageUrl, (url) => loadImage(url))
watch(settings, () => scheduleRender(), { deep: true })
function cleanup() {
if (pendingRaf) {
cancelAnimationFrame(pendingRaf)
pendingRaf = 0
}
if (state) {
const { gl, program, imageTexture, lutTexture, vao } = state
gl.deleteTexture(imageTexture)
gl.deleteTexture(lutTexture)
gl.deleteVertexArray(vao)
gl.deleteProgram(program)
state = null
}
currentImage = null
}
onBeforeUnmount(cleanup)
}

View File

@@ -0,0 +1,13 @@
import { useExtensionService } from '@/services/extensionService'
useExtensionService().registerExtension({
name: 'Comfy.ColorBalance',
async nodeCreated(node) {
if (node.constructor.comfyClass !== 'ColorBalance') return
node.hideOutputImages = true
const [oldWidth, oldHeight] = node.size
node.setSize([Math.max(oldWidth, 350), Math.max(oldHeight, 450)])
}
})

View File

@@ -0,0 +1,13 @@
import { useExtensionService } from '@/services/extensionService'
useExtensionService().registerExtension({
name: 'Comfy.ColorCorrect',
async nodeCreated(node) {
if (node.constructor.comfyClass !== 'ColorCorrect') return
node.hideOutputImages = true
const [oldWidth, oldHeight] = node.size
node.setSize([Math.max(oldWidth, 350), Math.max(oldHeight, 500)])
}
})

View File

@@ -0,0 +1,13 @@
import { useExtensionService } from '@/services/extensionService'
useExtensionService().registerExtension({
name: 'Comfy.ColorCurves',
async nodeCreated(node) {
if (node.constructor.comfyClass !== 'ColorCurves') return
node.hideOutputImages = true
const [oldWidth, oldHeight] = node.size
node.setSize([Math.max(oldWidth, 350), Math.max(oldHeight, 500)])
}
})

View File

@@ -10,6 +10,9 @@ import './groupNode'
import './groupNodeManage'
import './groupOptions'
import './imageCompare'
import './colorBalance'
import './colorCorrect'
import './colorCurves'
import './imageCrop'
// load3d and saveMesh are loaded on-demand to defer THREE.js (~1.8MB)
// The lazy loader triggers loading when a 3D node is used

View File

@@ -109,6 +109,9 @@ export type IWidget =
| IAssetWidget
| IImageCropWidget
| IBoundingBoxWidget
| IColorCorrectWidget
| IColorBalanceWidget
| IColorCurvesWidget
export interface IBooleanWidget extends IBaseWidget<boolean, 'toggle'> {
type: 'toggle'
@@ -292,6 +295,63 @@ export interface IBoundingBoxWidget extends IBaseWidget<Bounds, 'boundingbox'> {
value: Bounds
}
export interface ColorCorrectSettings {
temperature: number
hue: number
brightness: number
contrast: number
saturation: number
gamma: number
}
/** Color correction widget for adjusting image parameters */
export interface IColorCorrectWidget extends IBaseWidget<
ColorCorrectSettings,
'colorcorrect'
> {
type: 'colorcorrect'
value: ColorCorrectSettings
}
export interface ColorBalanceSettings {
shadows_red: number
shadows_green: number
shadows_blue: number
midtones_red: number
midtones_green: number
midtones_blue: number
highlights_red: number
highlights_green: number
highlights_blue: number
}
/** Color balance widget for adjusting RGB channels per tonal range */
export interface IColorBalanceWidget extends IBaseWidget<
ColorBalanceSettings,
'colorbalance'
> {
type: 'colorbalance'
value: ColorBalanceSettings
}
export type CurvePoint = [x: number, y: number]
export interface ColorCurvesSettings {
rgb: CurvePoint[]
red: CurvePoint[]
green: CurvePoint[]
blue: CurvePoint[]
}
/** Color curves widget for per-channel tone curve adjustment */
export interface IColorCurvesWidget extends IBaseWidget<
ColorCurvesSettings,
'colorcurves'
> {
type: 'colorcurves'
value: ColorCurvesSettings
}
/**
* Valid widget types. TS cannot provide easily extensible type safety for this at present.
* Override linkedWidgets[]

View File

@@ -0,0 +1,22 @@
import type { IColorBalanceWidget } from '../types/widgets'
import { BaseWidget } from './BaseWidget'
import type { DrawWidgetOptions, WidgetEventOptions } from './BaseWidget'
/**
* Widget for color balance controls.
* This widget only has a Vue implementation.
*/
export class ColorBalanceWidget
extends BaseWidget<IColorBalanceWidget>
implements IColorBalanceWidget
{
override type = 'colorbalance' as const
drawWidget(ctx: CanvasRenderingContext2D, options: DrawWidgetOptions): void {
this.drawVueOnlyWarning(ctx, options, 'ColorBalance')
}
onClick(_options: WidgetEventOptions): void {
// This widget only has a Vue implementation
}
}

View File

@@ -0,0 +1,22 @@
import type { IColorCorrectWidget } from '../types/widgets'
import { BaseWidget } from './BaseWidget'
import type { DrawWidgetOptions, WidgetEventOptions } from './BaseWidget'
/**
* Widget for color correction controls.
* This widget only has a Vue implementation.
*/
export class ColorCorrectWidget
extends BaseWidget<IColorCorrectWidget>
implements IColorCorrectWidget
{
override type = 'colorcorrect' as const
drawWidget(ctx: CanvasRenderingContext2D, options: DrawWidgetOptions): void {
this.drawVueOnlyWarning(ctx, options, 'ColorCorrect')
}
onClick(_options: WidgetEventOptions): void {
// This widget only has a Vue implementation
}
}

View File

@@ -0,0 +1,22 @@
import type { IColorCurvesWidget } from '../types/widgets'
import { BaseWidget } from './BaseWidget'
import type { DrawWidgetOptions, WidgetEventOptions } from './BaseWidget'
/**
* Widget for color curves controls.
* This widget only has a Vue implementation.
*/
export class ColorCurvesWidget
extends BaseWidget<IColorCurvesWidget>
implements IColorCurvesWidget
{
override type = 'colorcurves' as const
drawWidget(ctx: CanvasRenderingContext2D, options: DrawWidgetOptions): void {
this.drawVueOnlyWarning(ctx, options, 'ColorCurves')
}
onClick(_options: WidgetEventOptions): void {
// This widget only has a Vue implementation
}
}

View File

@@ -12,6 +12,9 @@ import { AssetWidget } from './AssetWidget'
import { BaseWidget } from './BaseWidget'
import { BooleanWidget } from './BooleanWidget'
import { BoundingBoxWidget } from './BoundingBoxWidget'
import { ColorBalanceWidget } from './ColorBalanceWidget'
import { ColorCorrectWidget } from './ColorCorrectWidget'
import { ColorCurvesWidget } from './ColorCurvesWidget'
import { ButtonWidget } from './ButtonWidget'
import { ChartWidget } from './ChartWidget'
import { ColorWidget } from './ColorWidget'
@@ -54,6 +57,9 @@ export type WidgetTypeMap = {
asset: AssetWidget
imagecrop: ImageCropWidget
boundingbox: BoundingBoxWidget
colorcorrect: ColorCorrectWidget
colorbalance: ColorBalanceWidget
colorcurves: ColorCurvesWidget
[key: string]: BaseWidget
}
@@ -128,6 +134,12 @@ export function toConcreteWidget<TWidget extends IWidget | IBaseWidget>(
return toClass(ImageCropWidget, narrowedWidget, node)
case 'boundingbox':
return toClass(BoundingBoxWidget, narrowedWidget, node)
case 'colorcorrect':
return toClass(ColorCorrectWidget, narrowedWidget, node)
case 'colorbalance':
return toClass(ColorBalanceWidget, narrowedWidget, node)
case 'colorcurves':
return toClass(ColorCurvesWidget, narrowedWidget, node)
default: {
if (wrapLegacyWidgets) return toClass(LegacyWidget, widget, node)
}

View File

@@ -1824,6 +1824,38 @@
"width": "Width",
"height": "Height"
},
"colorCurves": {
"noInputImage": "No input image connected",
"reset": "Reset All",
"rgb": "RGB",
"red": "Red",
"green": "Green",
"blue": "Blue"
},
"colorBalance": {
"loading": "Loading...",
"noInputImage": "No input image connected",
"previewAlt": "Color balance preview",
"reset": "Reset All",
"shadows": "Shadows",
"midtones": "Midtones",
"highlights": "Highlights",
"cyanRed": "Cyan / Red",
"magentaGreen": "Magenta / Green",
"yellowBlue": "Yellow / Blue"
},
"colorCorrect": {
"loading": "Loading...",
"noInputImage": "No input image connected",
"previewAlt": "Color correction preview",
"reset": "Reset All",
"temperature": "Temperature",
"hue": "Hue",
"brightness": "Brightness",
"contrast": "Contrast",
"saturation": "Saturation",
"gamma": "Gamma"
},
"toastMessages": {
"nothingToQueue": "Nothing to queue",
"pleaseSelectOutputNodes": "Please select output nodes",

View File

@@ -0,0 +1,182 @@
import type { LGraphNode } from '@/lib/litegraph/src/litegraph'
import type {
ColorBalanceSettings,
IBaseWidget,
IColorBalanceWidget,
INumericWidget
} from '@/lib/litegraph/src/types/widgets'
import type {
ColorBalanceInputSpec,
InputSpec as InputSpecV2
} from '@/schemas/nodeDef/nodeDefSchemaV2'
import type { ComfyWidgetConstructorV2 } from '@/scripts/widgets'
interface FieldConfig {
key: keyof ColorBalanceSettings
min: number
max: number
step: number
step2: number
precision: number
}
const FIELDS: FieldConfig[] = [
{ key: 'shadows_red', min: -100, max: 100, step: 50, step2: 1, precision: 0 },
{
key: 'shadows_green',
min: -100,
max: 100,
step: 50,
step2: 1,
precision: 0
},
{
key: 'shadows_blue',
min: -100,
max: 100,
step: 50,
step2: 1,
precision: 0
},
{
key: 'midtones_red',
min: -100,
max: 100,
step: 50,
step2: 1,
precision: 0
},
{
key: 'midtones_green',
min: -100,
max: 100,
step: 50,
step2: 1,
precision: 0
},
{
key: 'midtones_blue',
min: -100,
max: 100,
step: 50,
step2: 1,
precision: 0
},
{
key: 'highlights_red',
min: -100,
max: 100,
step: 50,
step2: 1,
precision: 0
},
{
key: 'highlights_green',
min: -100,
max: 100,
step: 50,
step2: 1,
precision: 0
},
{
key: 'highlights_blue',
min: -100,
max: 100,
step: 50,
step2: 1,
precision: 0
}
]
const DEFAULT_SETTINGS: ColorBalanceSettings = {
shadows_red: 0,
shadows_green: 0,
shadows_blue: 0,
midtones_red: 0,
midtones_green: 0,
midtones_blue: 0,
highlights_red: 0,
highlights_green: 0,
highlights_blue: 0
}
function isColorBalanceWidget(
widget: IBaseWidget
): widget is IColorBalanceWidget {
return widget.type === 'colorbalance'
}
function isNumericWidget(widget: IBaseWidget): widget is INumericWidget {
return widget.type === 'number'
}
export const useColorBalanceWidget = (): ComfyWidgetConstructorV2 => {
return (node: LGraphNode, inputSpec: InputSpecV2): IColorBalanceWidget => {
const spec = inputSpec as ColorBalanceInputSpec
const { name } = spec
const defaultValue: ColorBalanceSettings = spec.default ?? {
...DEFAULT_SETTINGS
}
const subWidgets: INumericWidget[] = []
const rawWidget = node.addWidget(
'colorbalance',
name,
{ ...defaultValue },
null,
{
serialize: true,
canvasOnly: false
}
)
if (!isColorBalanceWidget(rawWidget)) {
throw new Error(`Unexpected widget type: ${rawWidget.type}`)
}
const widget = rawWidget
widget.callback = () => {
for (let i = 0; i < FIELDS.length; i++) {
const field = FIELDS[i]
const subWidget = subWidgets[i]
if (subWidget) {
subWidget.value = widget.value[field.key]
}
}
}
for (const field of FIELDS) {
const subWidget = node.addWidget(
'number',
field.key,
defaultValue[field.key],
function (this: INumericWidget, v: number) {
this.value = Math.round(v)
widget.value[field.key] = this.value
widget.callback?.(widget.value)
},
{
min: field.min,
max: field.max,
step: field.step,
step2: field.step2,
precision: field.precision,
serialize: false,
canvasOnly: true
}
)
if (!isNumericWidget(subWidget)) {
throw new Error(`Unexpected widget type: ${subWidget.type}`)
}
subWidgets.push(subWidget)
}
widget.linkedWidgets = subWidgets
return widget
}
}

View File

@@ -0,0 +1,121 @@
import type { LGraphNode } from '@/lib/litegraph/src/litegraph'
import type {
ColorCorrectSettings,
IBaseWidget,
IColorCorrectWidget,
INumericWidget
} from '@/lib/litegraph/src/types/widgets'
import type {
ColorCorrectInputSpec,
InputSpec as InputSpecV2
} from '@/schemas/nodeDef/nodeDefSchemaV2'
import type { ComfyWidgetConstructorV2 } from '@/scripts/widgets'
interface FieldConfig {
key: keyof ColorCorrectSettings
min: number
max: number
step: number
step2: number
precision: number
}
const FIELDS: FieldConfig[] = [
{ key: 'temperature', min: -100, max: 100, step: 50, step2: 5, precision: 0 },
{ key: 'hue', min: -90, max: 90, step: 50, step2: 5, precision: 0 },
{ key: 'brightness', min: -100, max: 100, step: 50, step2: 5, precision: 0 },
{ key: 'contrast', min: -100, max: 100, step: 50, step2: 5, precision: 0 },
{ key: 'saturation', min: -100, max: 100, step: 50, step2: 5, precision: 0 },
{ key: 'gamma', min: 0.2, max: 2.2, step: 1, step2: 0.1, precision: 1 }
]
const DEFAULT_SETTINGS: ColorCorrectSettings = {
temperature: 0,
hue: 0,
brightness: 0,
contrast: 0,
saturation: 0,
gamma: 1.0
}
function isColorCorrectWidget(
widget: IBaseWidget
): widget is IColorCorrectWidget {
return widget.type === 'colorcorrect'
}
function isNumericWidget(widget: IBaseWidget): widget is INumericWidget {
return widget.type === 'number'
}
export const useColorCorrectWidget = (): ComfyWidgetConstructorV2 => {
return (node: LGraphNode, inputSpec: InputSpecV2): IColorCorrectWidget => {
const spec = inputSpec as ColorCorrectInputSpec
const { name } = spec
const defaultValue: ColorCorrectSettings = spec.default ?? {
...DEFAULT_SETTINGS
}
const subWidgets: INumericWidget[] = []
const rawWidget = node.addWidget(
'colorcorrect',
name,
{ ...defaultValue },
null,
{
serialize: true,
canvasOnly: false
}
)
if (!isColorCorrectWidget(rawWidget)) {
throw new Error(`Unexpected widget type: ${rawWidget.type}`)
}
const widget = rawWidget
widget.callback = () => {
for (let i = 0; i < FIELDS.length; i++) {
const field = FIELDS[i]
const subWidget = subWidgets[i]
if (subWidget) {
subWidget.value = widget.value[field.key]
}
}
}
for (const field of FIELDS) {
const subWidget = node.addWidget(
'number',
field.key,
defaultValue[field.key],
function (this: INumericWidget, v: number) {
this.value =
field.precision === 0 ? Math.round(v) : Math.round(v * 10) / 10
widget.value[field.key] = this.value
widget.callback?.(widget.value)
},
{
min: field.min,
max: field.max,
step: field.step,
step2: field.step2,
precision: field.precision,
serialize: false,
canvasOnly: true
}
)
if (!isNumericWidget(subWidget)) {
throw new Error(`Unexpected widget type: ${subWidget.type}`)
}
subWidgets.push(subWidget)
}
widget.linkedWidgets = subWidgets
return widget
}
}

View File

@@ -0,0 +1,67 @@
import type { LGraphNode } from '@/lib/litegraph/src/litegraph'
import type {
ColorCurvesSettings,
IBaseWidget,
IColorCurvesWidget
} from '@/lib/litegraph/src/types/widgets'
import type {
ColorCurvesInputSpec,
InputSpec as InputSpecV2
} from '@/schemas/nodeDef/nodeDefSchemaV2'
import type { ComfyWidgetConstructorV2 } from '@/scripts/widgets'
const DEFAULT_SETTINGS: ColorCurvesSettings = {
rgb: [
[0, 0],
[1, 1]
],
red: [
[0, 0],
[1, 1]
],
green: [
[0, 0],
[1, 1]
],
blue: [
[0, 0],
[1, 1]
]
}
function isColorCurvesWidget(
widget: IBaseWidget
): widget is IColorCurvesWidget {
return widget.type === 'colorcurves'
}
export const useColorCurvesWidget = (): ComfyWidgetConstructorV2 => {
return (node: LGraphNode, inputSpec: InputSpecV2): IColorCurvesWidget => {
const spec = inputSpec as ColorCurvesInputSpec
const { name } = spec
const defaultValue: ColorCurvesSettings = spec.default ?? {
...DEFAULT_SETTINGS,
rgb: [...DEFAULT_SETTINGS.rgb],
red: [...DEFAULT_SETTINGS.red],
green: [...DEFAULT_SETTINGS.green],
blue: [...DEFAULT_SETTINGS.blue]
}
const rawWidget = node.addWidget(
'colorcurves',
name,
{ ...defaultValue },
null,
{
serialize: true,
canvasOnly: false
}
)
if (!isColorCurvesWidget(rawWidget)) {
throw new Error(`Unexpected widget type: ${rawWidget.type}`)
}
return rawWidget
}
}

View File

@@ -57,6 +57,15 @@ const WidgetImageCrop = defineAsyncComponent(
const WidgetBoundingBox = defineAsyncComponent(
() => import('@/components/boundingbox/WidgetBoundingBox.vue')
)
const WidgetColorBalance = defineAsyncComponent(
() => import('@/components/colorbalance/WidgetColorBalance.vue')
)
const WidgetColorCorrect = defineAsyncComponent(
() => import('@/components/colorcorrect/WidgetColorCorrect.vue')
)
const WidgetColorCurves = defineAsyncComponent(
() => import('@/components/colorcurves/WidgetColorCurves.vue')
)
export const FOR_TESTING = {
WidgetButton,
@@ -175,6 +184,30 @@ const coreWidgetDefinitions: Array<[string, WidgetDefinition]> = [
aliases: ['BOUNDING_BOX'],
essential: false
}
],
[
'colorbalance',
{
component: WidgetColorBalance,
aliases: ['COLOR_BALANCE'],
essential: false
}
],
[
'colorcorrect',
{
component: WidgetColorCorrect,
aliases: ['COLOR_CORRECT'],
essential: false
}
],
[
'colorcurves',
{
component: WidgetColorCurves,
aliases: ['COLOR_CURVES'],
essential: false
}
]
]
@@ -206,7 +239,14 @@ export const shouldRenderAsVue = (widget: Partial<SafeWidgetData>): boolean => {
return !widget.options?.canvasOnly && !!widget.type
}
const EXPANDING_TYPES = ['textarea', 'markdown', 'load3D'] as const
const EXPANDING_TYPES = [
'textarea',
'markdown',
'load3D',
'colorbalance',
'colorcorrect',
'colorcurves'
] as const
export function shouldExpand(type: string): boolean {
const canonicalType = getCanonicalType(type)

View File

@@ -126,6 +126,57 @@ const zTextareaInputSpec = zBaseInputOptions.extend({
.optional()
})
const zColorCorrectInputSpec = zBaseInputOptions.extend({
type: z.literal('COLOR_CORRECT'),
name: z.string(),
isOptional: z.boolean().optional(),
default: z
.object({
temperature: z.number(),
hue: z.number(),
brightness: z.number(),
contrast: z.number(),
saturation: z.number(),
gamma: z.number()
})
.optional()
})
const zColorBalanceInputSpec = zBaseInputOptions.extend({
type: z.literal('COLOR_BALANCE'),
name: z.string(),
isOptional: z.boolean().optional(),
default: z
.object({
shadows_red: z.number(),
shadows_green: z.number(),
shadows_blue: z.number(),
midtones_red: z.number(),
midtones_green: z.number(),
midtones_blue: z.number(),
highlights_red: z.number(),
highlights_green: z.number(),
highlights_blue: z.number()
})
.optional()
})
const zCurvePoint = z.tuple([z.number(), z.number()])
const zColorCurvesInputSpec = zBaseInputOptions.extend({
type: z.literal('COLOR_CURVES'),
name: z.string(),
isOptional: z.boolean().optional(),
default: z
.object({
rgb: z.array(zCurvePoint),
red: z.array(zCurvePoint),
green: z.array(zCurvePoint),
blue: z.array(zCurvePoint)
})
.optional()
})
const zCustomInputSpec = zBaseInputOptions.extend({
type: z.string(),
name: z.string(),
@@ -142,6 +193,9 @@ const zInputSpec = z.union([
zImageInputSpec,
zImageCompareInputSpec,
zBoundingBoxInputSpec,
zColorCorrectInputSpec,
zColorBalanceInputSpec,
zColorCurvesInputSpec,
zMarkdownInputSpec,
zChartInputSpec,
zGalleriaInputSpec,
@@ -187,6 +241,9 @@ export type ComboInputSpec = z.infer<typeof zComboInputSpec>
export type ColorInputSpec = z.infer<typeof zColorInputSpec>
export type ImageCompareInputSpec = z.infer<typeof zImageCompareInputSpec>
export type BoundingBoxInputSpec = z.infer<typeof zBoundingBoxInputSpec>
export type ColorCorrectInputSpec = z.infer<typeof zColorCorrectInputSpec>
export type ColorBalanceInputSpec = z.infer<typeof zColorBalanceInputSpec>
export type ColorCurvesInputSpec = z.infer<typeof zColorCurvesInputSpec>
export type ChartInputSpec = z.infer<typeof zChartInputSpec>
export type GalleriaInputSpec = z.infer<typeof zGalleriaInputSpec>
export type TextareaInputSpec = z.infer<typeof zTextareaInputSpec>

View File

@@ -9,6 +9,9 @@ import { useSettingStore } from '@/platform/settings/settingStore'
import { dynamicWidgets } from '@/core/graph/widgets/dynamicWidgets'
import { useBooleanWidget } from '@/renderer/extensions/vueNodes/widgets/composables/useBooleanWidget'
import { useBoundingBoxWidget } from '@/renderer/extensions/vueNodes/widgets/composables/useBoundingBoxWidget'
import { useColorBalanceWidget } from '@/renderer/extensions/vueNodes/widgets/composables/useColorBalanceWidget'
import { useColorCorrectWidget } from '@/renderer/extensions/vueNodes/widgets/composables/useColorCorrectWidget'
import { useColorCurvesWidget } from '@/renderer/extensions/vueNodes/widgets/composables/useColorCurvesWidget'
import { useChartWidget } from '@/renderer/extensions/vueNodes/widgets/composables/useChartWidget'
import { useColorWidget } from '@/renderer/extensions/vueNodes/widgets/composables/useColorWidget'
import { useComboWidget } from '@/renderer/extensions/vueNodes/widgets/composables/useComboWidget'
@@ -303,6 +306,9 @@ export const ComfyWidgets = {
COLOR: transformWidgetConstructorV2ToV1(useColorWidget()),
IMAGECOMPARE: transformWidgetConstructorV2ToV1(useImageCompareWidget()),
BOUNDING_BOX: transformWidgetConstructorV2ToV1(useBoundingBoxWidget()),
COLOR_CORRECT: transformWidgetConstructorV2ToV1(useColorCorrectWidget()),
COLOR_BALANCE: transformWidgetConstructorV2ToV1(useColorBalanceWidget()),
COLOR_CURVES: transformWidgetConstructorV2ToV1(useColorCurvesWidget()),
CHART: transformWidgetConstructorV2ToV1(useChartWidget()),
GALLERIA: transformWidgetConstructorV2ToV1(useGalleriaWidget()),
TEXTAREA: transformWidgetConstructorV2ToV1(useTextareaWidget()),