mirror of
https://github.com/Comfy-Org/ComfyUI_frontend.git
synced 2026-07-03 13:48:49 +00:00
Compare commits
6 Commits
shihchi/co
...
codex/cove
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
caa535e906 | ||
|
|
3a9470018f | ||
|
|
08da6fae72 | ||
|
|
2ec2a0e091 | ||
|
|
9cf5c9a93f | ||
|
|
9e5fb67b76 |
@@ -15,7 +15,7 @@ const { categories } = defineProps<{
|
||||
|
||||
const activeSection = ref(categories[0]?.value ?? '')
|
||||
|
||||
const HEADER_OFFSET = -144
|
||||
const HEADER_OFFSET_PX = -144
|
||||
const BOTTOM_THRESHOLD_PX = 4
|
||||
const SCROLL_SAFETY_MS = 1500
|
||||
|
||||
@@ -52,7 +52,7 @@ function scrollToSection(id: string) {
|
||||
const el = document.getElementById(id)
|
||||
if (el) {
|
||||
scrollTo(el, {
|
||||
offset: HEADER_OFFSET,
|
||||
offset: HEADER_OFFSET_PX,
|
||||
duration: 0.8,
|
||||
immediate: prefersReducedMotion(),
|
||||
onComplete: clearScrollLock
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
<li
|
||||
class="flex items-start gap-2 text-primary-comfy-canvas before:mt-1.5 before:size-1.5 before:shrink-0 before:rounded-full before:bg-primary-comfy-yellow before:content-['']"
|
||||
class="flex items-start gap-2 text-primary-comfy-canvas before:mt-1.5 before:size-1.5 before:shrink-0 before:rounded-full before:bg-primary-comfy-yellow"
|
||||
>
|
||||
<slot />
|
||||
</li>
|
||||
|
||||
45
browser_tests/assets/linear-validation-warning.json
Normal file
45
browser_tests/assets/linear-validation-warning.json
Normal file
@@ -0,0 +1,45 @@
|
||||
{
|
||||
"last_node_id": 9,
|
||||
"last_link_id": 9,
|
||||
"nodes": [
|
||||
{
|
||||
"id": 9,
|
||||
"type": "SaveImage",
|
||||
"pos": {
|
||||
"0": 64,
|
||||
"1": 104
|
||||
},
|
||||
"size": {
|
||||
"0": 210,
|
||||
"1": 58
|
||||
},
|
||||
"flags": {},
|
||||
"order": 0,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"name": "images",
|
||||
"type": "IMAGE",
|
||||
"link": null
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"properties": {},
|
||||
"widgets_values": ["ComfyUI"]
|
||||
}
|
||||
],
|
||||
"links": [],
|
||||
"groups": [],
|
||||
"config": {},
|
||||
"extra": {
|
||||
"ds": {
|
||||
"scale": 1,
|
||||
"offset": [0, 0]
|
||||
},
|
||||
"linearData": {
|
||||
"inputs": [],
|
||||
"outputs": ["9"]
|
||||
}
|
||||
},
|
||||
"version": 0.4
|
||||
}
|
||||
@@ -34,6 +34,10 @@ export class AppModeHelper {
|
||||
public readonly outputPlaceholder: Locator
|
||||
/** The linear-mode widget list container (visible in app mode). */
|
||||
public readonly linearWidgets: Locator
|
||||
/** The validation warning shown above the app mode run button. */
|
||||
public readonly validationWarning: Locator
|
||||
/** The action that opens graph mode errors from the validation warning. */
|
||||
public readonly viewErrorsInGraphButton: Locator
|
||||
/** The PrimeVue Popover for the image picker (renders with role="dialog"). */
|
||||
public readonly imagePickerPopover: Locator
|
||||
/** The Run button in the app mode footer. */
|
||||
@@ -92,13 +96,19 @@ export class AppModeHelper {
|
||||
this.outputPlaceholder = this.page.getByTestId(
|
||||
TestIds.builder.outputPlaceholder
|
||||
)
|
||||
this.linearWidgets = this.page.getByTestId('linear-widgets')
|
||||
this.linearWidgets = this.page.getByTestId(TestIds.linear.widgetContainer)
|
||||
this.validationWarning = this.page.getByTestId(
|
||||
TestIds.linear.validationWarning
|
||||
)
|
||||
this.viewErrorsInGraphButton = this.validationWarning.getByTestId(
|
||||
TestIds.linear.viewErrorsInGraph
|
||||
)
|
||||
this.imagePickerPopover = this.page
|
||||
.getByRole('dialog')
|
||||
.filter({ has: this.page.getByRole('button', { name: 'All' }) })
|
||||
.first()
|
||||
this.runButton = this.page
|
||||
.getByTestId('linear-run-button')
|
||||
.getByTestId(TestIds.linear.runButton)
|
||||
.getByRole('button', { name: /run/i })
|
||||
this.welcome = this.page.getByTestId(TestIds.appMode.welcome)
|
||||
this.emptyWorkflowText = this.page.getByTestId(
|
||||
|
||||
@@ -172,6 +172,9 @@ export const TestIds = {
|
||||
mobileNavigation: 'linear-mobile-navigation',
|
||||
mobileWorkflows: 'linear-mobile-workflows',
|
||||
outputInfo: 'linear-output-info',
|
||||
runButton: 'linear-run-button',
|
||||
validationWarning: 'linear-validation-warning',
|
||||
viewErrorsInGraph: 'linear-view-errors',
|
||||
widgetContainer: 'linear-widgets'
|
||||
},
|
||||
builder: {
|
||||
|
||||
106
browser_tests/tests/appModeValidationWarning.spec.ts
Normal file
106
browser_tests/tests/appModeValidationWarning.spec.ts
Normal file
@@ -0,0 +1,106 @@
|
||||
import {
|
||||
comfyExpect as expect,
|
||||
comfyPageFixture as test
|
||||
} from '@e2e/fixtures/ComfyPage'
|
||||
import type { NodeError, PromptResponse } from '@/schemas/apiSchema'
|
||||
import { ExecutionHelper } from '@e2e/fixtures/helpers/ExecutionHelper'
|
||||
import { enableErrorsOverlay } from '@e2e/fixtures/helpers/ErrorsTabHelper'
|
||||
import { TestIds } from '@e2e/fixtures/selectors'
|
||||
|
||||
const SAVE_IMAGE_NODE_ID = '9'
|
||||
|
||||
function buildSaveImageRequiredInputError(): NodeError {
|
||||
return {
|
||||
class_type: 'SaveImage',
|
||||
dependent_outputs: [],
|
||||
errors: [
|
||||
{
|
||||
type: 'required_input_missing',
|
||||
message: 'Required input is missing: images',
|
||||
details: '',
|
||||
extra_info: { input_name: 'images' }
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
test.describe(
|
||||
'App mode validation warning',
|
||||
{ tag: ['@ui', '@workflow'] },
|
||||
() => {
|
||||
test.beforeEach(async ({ comfyPage }) => {
|
||||
await enableErrorsOverlay(comfyPage)
|
||||
await comfyPage.workflow.loadWorkflow('linear-validation-warning')
|
||||
await comfyPage.appMode.toggleAppMode()
|
||||
await expect(comfyPage.appMode.linearWidgets).toBeVisible()
|
||||
})
|
||||
|
||||
test('opens graph errors from the app mode validation warning', async ({
|
||||
comfyPage
|
||||
}) => {
|
||||
await expect(comfyPage.appMode.validationWarning).toBeHidden()
|
||||
|
||||
const exec = new ExecutionHelper(comfyPage)
|
||||
await exec.mockValidationFailure({
|
||||
[SAVE_IMAGE_NODE_ID]: buildSaveImageRequiredInputError()
|
||||
})
|
||||
|
||||
await comfyPage.appMode.runButton.click()
|
||||
const appModeOverlay = comfyPage.appMode.centerPanel.getByTestId(
|
||||
TestIds.dialogs.errorOverlay
|
||||
)
|
||||
await expect(appModeOverlay).toBeHidden()
|
||||
|
||||
await expect(comfyPage.appMode.validationWarning).toBeVisible()
|
||||
await expect(comfyPage.appMode.validationWarning).toContainText(
|
||||
/Required input missing/i
|
||||
)
|
||||
await expect(comfyPage.appMode.viewErrorsInGraphButton).toBeVisible()
|
||||
|
||||
await comfyPage.appMode.viewErrorsInGraphButton.click()
|
||||
|
||||
await expect(comfyPage.appMode.linearWidgets).toBeHidden()
|
||||
await expect(
|
||||
comfyPage.page.getByTestId(TestIds.propertiesPanel.root)
|
||||
).toBeVisible()
|
||||
await expect(
|
||||
comfyPage.page.getByTestId(TestIds.propertiesPanel.errorsTab)
|
||||
).toBeVisible()
|
||||
})
|
||||
|
||||
test('keeps the app mode run button enabled when the warning is visible', async ({
|
||||
comfyPage
|
||||
}) => {
|
||||
const exec = new ExecutionHelper(comfyPage)
|
||||
await exec.mockValidationFailure({
|
||||
[SAVE_IMAGE_NODE_ID]: buildSaveImageRequiredInputError()
|
||||
})
|
||||
|
||||
await comfyPage.appMode.runButton.click()
|
||||
await expect(comfyPage.appMode.validationWarning).toBeVisible()
|
||||
await expect(comfyPage.appMode.runButton).toBeEnabled()
|
||||
|
||||
let promptQueued = false
|
||||
const mockResponse: PromptResponse = {
|
||||
prompt_id: 'test-id',
|
||||
node_errors: {},
|
||||
error: ''
|
||||
}
|
||||
await comfyPage.page.route(
|
||||
'**/api/prompt',
|
||||
async (route) => {
|
||||
promptQueued = true
|
||||
await route.fulfill({
|
||||
status: 200,
|
||||
body: JSON.stringify(mockResponse)
|
||||
})
|
||||
},
|
||||
{ times: 1 }
|
||||
)
|
||||
|
||||
await comfyPage.appMode.runButton.click()
|
||||
|
||||
await expect.poll(() => promptQueued).toBe(true)
|
||||
})
|
||||
}
|
||||
)
|
||||
@@ -1,5 +1,6 @@
|
||||
import { expect } from '@playwright/test'
|
||||
|
||||
import { toLinkId } from '@/types/linkId'
|
||||
import { toNodeId } from '@/types/nodeId'
|
||||
|
||||
import { comfyPageFixture as test } from '@e2e/fixtures/ComfyPage'
|
||||
@@ -15,9 +16,10 @@ test.describe('Graph', { tag: ['@smoke', '@canvas'] }, () => {
|
||||
await comfyPage.workflow.loadWorkflow('inputs/input_order_swap')
|
||||
await expect
|
||||
.poll(() =>
|
||||
comfyPage.page.evaluate(() => {
|
||||
return window.app!.graph!.links.get(1)?.target_slot
|
||||
})
|
||||
comfyPage.page.evaluate(
|
||||
(linkId) => window.app!.graph!.links.get(linkId)?.target_slot,
|
||||
toLinkId(1)
|
||||
)
|
||||
)
|
||||
.toBe(1)
|
||||
})
|
||||
|
||||
@@ -3,6 +3,7 @@ import {
|
||||
comfyPageFixture as test,
|
||||
comfyExpect as expect
|
||||
} from '@e2e/fixtures/ComfyPage'
|
||||
import { TestIds } from '@e2e/fixtures/selectors'
|
||||
|
||||
test.describe('Linear Mode', { tag: '@ui' }, () => {
|
||||
test('Displays linear controls when app mode active', async ({
|
||||
@@ -16,7 +17,9 @@ test.describe('Linear Mode', { tag: '@ui' }, () => {
|
||||
test('Run button visible in linear mode', async ({ comfyPage }) => {
|
||||
await comfyPage.appMode.enterAppModeWithInputs([])
|
||||
|
||||
await expect(comfyPage.page.getByTestId('linear-run-button')).toBeVisible()
|
||||
await expect(
|
||||
comfyPage.page.getByTestId(TestIds.linear.runButton)
|
||||
).toBeVisible()
|
||||
})
|
||||
|
||||
test('Workflow info section visible', async ({ comfyPage }) => {
|
||||
|
||||
@@ -37,7 +37,7 @@
|
||||
size="unset"
|
||||
class="min-h-8 rounded-lg px-3 py-2 text-xs font-normal"
|
||||
data-testid="error-overlay-see-errors"
|
||||
@click="seeErrors"
|
||||
@click="viewErrorsInGraph"
|
||||
>
|
||||
{{
|
||||
appMode
|
||||
@@ -67,31 +67,18 @@ import { useI18n } from 'vue-i18n'
|
||||
|
||||
import Button from '@/components/ui/button/Button.vue'
|
||||
import { useExecutionErrorStore } from '@/stores/executionErrorStore'
|
||||
import { useRightSidePanelStore } from '@/stores/workspace/rightSidePanelStore'
|
||||
import { useCanvasStore } from '@/renderer/core/canvas/canvasStore'
|
||||
import { useErrorOverlayState } from '@/components/error/useErrorOverlayState'
|
||||
import { useViewErrorsInGraph } from '@/composables/useViewErrorsInGraph'
|
||||
|
||||
const { appMode = false } = defineProps<{ appMode?: boolean }>()
|
||||
|
||||
const { t } = useI18n()
|
||||
const executionErrorStore = useExecutionErrorStore()
|
||||
const rightSidePanelStore = useRightSidePanelStore()
|
||||
const canvasStore = useCanvasStore()
|
||||
const { viewErrorsInGraph } = useViewErrorsInGraph()
|
||||
|
||||
const { isVisible, overlayMessage, overlayTitle } = useErrorOverlayState()
|
||||
|
||||
function dismiss() {
|
||||
executionErrorStore.dismissErrorOverlay()
|
||||
}
|
||||
|
||||
function seeErrors() {
|
||||
canvasStore.linearMode = false
|
||||
if (canvasStore.canvas) {
|
||||
canvasStore.canvas.deselectAll()
|
||||
canvasStore.updateSelectedItems()
|
||||
}
|
||||
|
||||
rightSidePanelStore.openPanel('errors')
|
||||
executionErrorStore.dismissErrorOverlay()
|
||||
}
|
||||
</script>
|
||||
|
||||
@@ -224,7 +224,7 @@ const handleOpenUserSettings = () => {
|
||||
}
|
||||
|
||||
const handleOpenPlansAndPricing = () => {
|
||||
subscriptionDialog.showPricingTable()
|
||||
subscriptionDialog.showPricingTable({ reason: 'avatar_menu_plans' })
|
||||
emit('close')
|
||||
}
|
||||
|
||||
@@ -239,8 +239,7 @@ const handleOpenPlanAndCreditsSettings = () => {
|
||||
}
|
||||
|
||||
const handleTopUp = () => {
|
||||
// Track purchase credits entry from avatar popover
|
||||
useTelemetry()?.trackAddApiCreditButtonClicked()
|
||||
useTelemetry()?.trackAddApiCreditButtonClicked({ source: 'avatar_menu' })
|
||||
dialogService.showTopUpCreditsDialog()
|
||||
emit('close')
|
||||
}
|
||||
@@ -254,7 +253,7 @@ const handleOpenPartnerNodesInfo = () => {
|
||||
}
|
||||
|
||||
const handleUpgradeToAddCredits = () => {
|
||||
subscriptionDialog.showPricingTable()
|
||||
subscriptionDialog.showPricingTable({ reason: 'upgrade_to_add_credits' })
|
||||
emit('close')
|
||||
}
|
||||
|
||||
|
||||
@@ -21,6 +21,6 @@ const { isFreeTier } = useBillingContext()
|
||||
const subscriptionDialog = useSubscriptionDialog()
|
||||
|
||||
function handleClick() {
|
||||
subscriptionDialog.showPricingTable()
|
||||
subscriptionDialog.showPricingTable({ reason: 'subscribe_now_button' })
|
||||
}
|
||||
</script>
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import type { ComputedRef, Ref } from 'vue'
|
||||
|
||||
import type { SubscriptionDialogOptions } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type { TierKey } from '@/platform/cloud/subscription/constants/tierPricing'
|
||||
import type {
|
||||
BillingStatus,
|
||||
@@ -75,9 +76,10 @@ export interface BillingActions {
|
||||
*/
|
||||
requireActiveSubscription: () => Promise<void>
|
||||
/**
|
||||
* Shows the subscription dialog.
|
||||
* Shows the subscription dialog. Pass a reason so the paywall open and any
|
||||
* downstream checkout stay attributed to the triggering product moment.
|
||||
*/
|
||||
showSubscriptionDialog: () => void
|
||||
showSubscriptionDialog: (options?: SubscriptionDialogOptions) => void
|
||||
}
|
||||
|
||||
export interface BillingState {
|
||||
|
||||
@@ -7,6 +7,7 @@ import {
|
||||
getTierFeatures
|
||||
} from '@/platform/cloud/subscription/constants/tierPricing'
|
||||
import type { TierKey } from '@/platform/cloud/subscription/constants/tierPricing'
|
||||
import type { SubscriptionDialogOptions } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type {
|
||||
PreviewSubscribeOptions,
|
||||
SubscribeOptions
|
||||
@@ -281,8 +282,8 @@ function useBillingContextInternal(): BillingContext {
|
||||
return activeContext.value.requireActiveSubscription()
|
||||
}
|
||||
|
||||
function showSubscriptionDialog() {
|
||||
return activeContext.value.showSubscriptionDialog()
|
||||
function showSubscriptionDialog(options?: SubscriptionDialogOptions) {
|
||||
return activeContext.value.showSubscriptionDialog(options)
|
||||
}
|
||||
|
||||
return {
|
||||
|
||||
@@ -2,6 +2,7 @@ import { computed, ref } from 'vue'
|
||||
|
||||
import { useAuthActions } from '@/composables/auth/useAuthActions'
|
||||
import { useSubscription } from '@/platform/cloud/subscription/composables/useSubscription'
|
||||
import type { SubscriptionDialogOptions } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type {
|
||||
BillingStatus,
|
||||
BillingSubscriptionStatus,
|
||||
@@ -189,12 +190,12 @@ export function useLegacyBilling(): BillingState & BillingActions {
|
||||
async function requireActiveSubscription(): Promise<void> {
|
||||
await fetchStatus()
|
||||
if (!isActiveSubscription.value) {
|
||||
legacyShowSubscriptionDialog()
|
||||
legacyShowSubscriptionDialog({ reason: 'subscription_required' })
|
||||
}
|
||||
}
|
||||
|
||||
function showSubscriptionDialog(): void {
|
||||
legacyShowSubscriptionDialog()
|
||||
function showSubscriptionDialog(options?: SubscriptionDialogOptions): void {
|
||||
legacyShowSubscriptionDialog(options)
|
||||
}
|
||||
|
||||
return {
|
||||
|
||||
@@ -503,7 +503,7 @@ export function useCoreCommands(): ComfyCommand[] {
|
||||
}) => {
|
||||
trackRunButton(metadata)
|
||||
if (!isActiveSubscription.value) {
|
||||
showSubscriptionDialog()
|
||||
showSubscriptionDialog({ reason: 'subscribe_to_run' })
|
||||
return
|
||||
}
|
||||
|
||||
@@ -526,7 +526,7 @@ export function useCoreCommands(): ComfyCommand[] {
|
||||
}) => {
|
||||
trackRunButton(metadata)
|
||||
if (!isActiveSubscription.value) {
|
||||
showSubscriptionDialog()
|
||||
showSubscriptionDialog({ reason: 'subscribe_to_run' })
|
||||
return
|
||||
}
|
||||
|
||||
@@ -548,7 +548,7 @@ export function useCoreCommands(): ComfyCommand[] {
|
||||
}) => {
|
||||
trackRunButton(metadata)
|
||||
if (!isActiveSubscription.value) {
|
||||
showSubscriptionDialog()
|
||||
showSubscriptionDialog({ reason: 'subscribe_to_run' })
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
105
src/composables/useViewErrorsInGraph.test.ts
Normal file
105
src/composables/useViewErrorsInGraph.test.ts
Normal file
@@ -0,0 +1,105 @@
|
||||
import { createPinia, setActivePinia } from 'pinia'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { useCanvasStore } from '@/renderer/core/canvas/canvasStore'
|
||||
import { useExecutionErrorStore } from '@/stores/executionErrorStore'
|
||||
import { useRightSidePanelStore } from '@/stores/workspace/rightSidePanelStore'
|
||||
import { LGraph, LGraphCanvas, LGraphNode } from '@/lib/litegraph/src/litegraph'
|
||||
import { createMockCanvasRenderingContext2D } from '@/utils/__tests__/litegraphTestUtils'
|
||||
import { useWorkflowStore } from '@/platform/workflow/management/stores/workflowStore'
|
||||
|
||||
import { useViewErrorsInGraph } from './useViewErrorsInGraph'
|
||||
|
||||
const apiMock = vi.hoisted(() => ({
|
||||
getSettings: vi.fn(),
|
||||
storeSetting: vi.fn(),
|
||||
storeSettings: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('@/scripts/api', () => ({
|
||||
api: apiMock
|
||||
}))
|
||||
|
||||
const appMock = vi.hoisted(() => ({
|
||||
ui: {
|
||||
settings: {
|
||||
dispatchChange: vi.fn()
|
||||
}
|
||||
},
|
||||
rootGraph: {
|
||||
events: new EventTarget(),
|
||||
nodes: []
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('@/scripts/app', () => ({
|
||||
app: appMock
|
||||
}))
|
||||
|
||||
function createSelectedCanvas() {
|
||||
const graph = new LGraph()
|
||||
const canvasElement = document.createElement('canvas')
|
||||
canvasElement.width = 800
|
||||
canvasElement.height = 600
|
||||
canvasElement.getContext = vi
|
||||
.fn()
|
||||
.mockReturnValue(createMockCanvasRenderingContext2D())
|
||||
|
||||
const canvas = new LGraphCanvas(canvasElement, graph, {
|
||||
skip_events: true,
|
||||
skip_render: true
|
||||
})
|
||||
const node = new LGraphNode('Selected Node')
|
||||
graph.add(node)
|
||||
canvas.selectedItems.add(node)
|
||||
node.selected = true
|
||||
|
||||
return { canvas, node }
|
||||
}
|
||||
|
||||
describe('useViewErrorsInGraph', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
setActivePinia(createPinia())
|
||||
apiMock.getSettings.mockResolvedValue({})
|
||||
apiMock.storeSetting.mockResolvedValue(undefined)
|
||||
apiMock.storeSettings.mockResolvedValue(undefined)
|
||||
})
|
||||
|
||||
it('opens graph errors and clears app-mode error UI state', () => {
|
||||
const canvasStore = useCanvasStore()
|
||||
const executionErrorStore = useExecutionErrorStore()
|
||||
const rightSidePanelStore = useRightSidePanelStore()
|
||||
const workflowStore = useWorkflowStore()
|
||||
const { canvas, node } = createSelectedCanvas()
|
||||
workflowStore.activeWorkflow = {
|
||||
activeMode: 'app'
|
||||
} as typeof workflowStore.activeWorkflow
|
||||
canvasStore.canvas = canvas
|
||||
canvasStore.selectedItems = [node]
|
||||
executionErrorStore.showErrorOverlay()
|
||||
|
||||
useViewErrorsInGraph().viewErrorsInGraph()
|
||||
|
||||
expect(node.selected).toBe(false)
|
||||
expect(canvasStore.linearMode).toBe(false)
|
||||
expect(canvasStore.selectedItems).toEqual([])
|
||||
expect(rightSidePanelStore.activeTab).toBe('errors')
|
||||
expect(rightSidePanelStore.isOpen).toBe(true)
|
||||
expect(executionErrorStore.isErrorOverlayOpen).toBe(false)
|
||||
})
|
||||
|
||||
it('opens graph errors when the canvas is not initialized', () => {
|
||||
const canvasStore = useCanvasStore()
|
||||
const executionErrorStore = useExecutionErrorStore()
|
||||
const rightSidePanelStore = useRightSidePanelStore()
|
||||
canvasStore.canvas = null
|
||||
executionErrorStore.showErrorOverlay()
|
||||
|
||||
expect(() => useViewErrorsInGraph().viewErrorsInGraph()).not.toThrow()
|
||||
|
||||
expect(rightSidePanelStore.activeTab).toBe('errors')
|
||||
expect(rightSidePanelStore.isOpen).toBe(true)
|
||||
expect(executionErrorStore.isErrorOverlayOpen).toBe(false)
|
||||
})
|
||||
})
|
||||
22
src/composables/useViewErrorsInGraph.ts
Normal file
22
src/composables/useViewErrorsInGraph.ts
Normal file
@@ -0,0 +1,22 @@
|
||||
import { useCanvasStore } from '@/renderer/core/canvas/canvasStore'
|
||||
import { useExecutionErrorStore } from '@/stores/executionErrorStore'
|
||||
import { useRightSidePanelStore } from '@/stores/workspace/rightSidePanelStore'
|
||||
|
||||
export function useViewErrorsInGraph() {
|
||||
const canvasStore = useCanvasStore()
|
||||
const executionErrorStore = useExecutionErrorStore()
|
||||
const rightSidePanelStore = useRightSidePanelStore()
|
||||
|
||||
function viewErrorsInGraph() {
|
||||
canvasStore.linearMode = false
|
||||
if (canvasStore.canvas) {
|
||||
canvasStore.canvas.deselectAll()
|
||||
canvasStore.updateSelectedItems()
|
||||
}
|
||||
|
||||
rightSidePanelStore.openPanel('errors')
|
||||
executionErrorStore.dismissErrorOverlay()
|
||||
}
|
||||
|
||||
return { viewErrorsInGraph }
|
||||
}
|
||||
@@ -16,12 +16,14 @@ import {
|
||||
} from '@/lib/litegraph/src/subgraph/__fixtures__/subgraphHelpers'
|
||||
|
||||
import {
|
||||
appendQuarantine,
|
||||
flushProxyWidgetMigration,
|
||||
normalizeLegacyProxyWidgetEntry,
|
||||
readHostQuarantine
|
||||
} from '@/core/graph/subgraph/migration/proxyWidgetMigration'
|
||||
import { usePreviewExposureStore } from '@/stores/previewExposureStore'
|
||||
import { toLinkId } from '@/types/linkId'
|
||||
import { UNASSIGNED_NODE_ID, toNodeId } from '@/types/nodeId'
|
||||
import { useWidgetValueStore } from '@/stores/widgetValueStore'
|
||||
|
||||
vi.mock('@/renderer/core/canvas/canvasStore', () => ({
|
||||
@@ -179,6 +181,33 @@ describe('flushProxyWidgetMigration', () => {
|
||||
expect(getPromotedInputValue(outerHost, 'text')).toBe('22222222222')
|
||||
})
|
||||
|
||||
it('createSubgraphInput: resolves a nested promoted input by host input name', () => {
|
||||
const rootGraph = new LGraph()
|
||||
const innerSubgraph = createTestSubgraph({ rootGraph })
|
||||
const source = new LGraphNode('CLIPTextEncode')
|
||||
const sourceSlot = source.addInput('text', 'STRING')
|
||||
sourceSlot.widget = { name: 'text' }
|
||||
source.addWidget('text', 'text', 'nested value', () => {})
|
||||
innerSubgraph.add(source)
|
||||
|
||||
const nestedHost = createTestSubgraphNode(innerSubgraph, {
|
||||
parentGraph: rootGraph
|
||||
})
|
||||
nestedHost.properties.proxyWidgets = [[String(source.id), 'text']]
|
||||
flushProxyWidgetMigration({ hostNode: nestedHost })
|
||||
|
||||
const outerSubgraph = createTestSubgraph({ rootGraph })
|
||||
outerSubgraph.add(nestedHost)
|
||||
const outerHost = createTestSubgraphNode(outerSubgraph, {
|
||||
parentGraph: rootGraph
|
||||
})
|
||||
outerHost.properties.proxyWidgets = [[String(nestedHost.id), 'text']]
|
||||
|
||||
flushProxyWidgetMigration({ hostNode: outerHost })
|
||||
|
||||
expect(getPromotedInputValue(outerHost, 'text')).toBe('nested value')
|
||||
})
|
||||
|
||||
it('alreadyLinked: leaves widget value unchanged when host value is a sparse hole', () => {
|
||||
const subgraph = createTestSubgraph({
|
||||
inputs: [{ name: 'seed', type: 'INT' }]
|
||||
@@ -240,6 +269,41 @@ describe('flushProxyWidgetMigration', () => {
|
||||
).toBe('renamed_from_sidepanel')
|
||||
})
|
||||
|
||||
it('createSubgraphInput: falls back to the source widget type when the slot type is missing', () => {
|
||||
const host = buildHost()
|
||||
const inner = addInnerNode(host, 'Inner', (n) => {
|
||||
const slot = n.addInput('seed', 'INT')
|
||||
slot.type = undefined as never
|
||||
slot.widget = { name: 'seed' }
|
||||
n.addWidget('number', 'seed', 0, () => {})
|
||||
})
|
||||
|
||||
host.properties.proxyWidgets = [[String(inner.id), 'seed']]
|
||||
flushProxyWidgetMigration({ hostNode: host })
|
||||
|
||||
expect(
|
||||
host.subgraph.inputs.find((input) => input.name === 'seed')?.type
|
||||
).toBe('number')
|
||||
})
|
||||
|
||||
it('createSubgraphInput: falls back to wildcard type when slot and widget type are missing', () => {
|
||||
const host = buildHost()
|
||||
const inner = addInnerNode(host, 'Inner', (n) => {
|
||||
const slot = n.addInput('seed', 'INT')
|
||||
slot.type = undefined as never
|
||||
slot.widget = { name: 'seed' }
|
||||
const widget = n.addWidget('number', 'seed', 0, () => {})
|
||||
widget.type = undefined as never
|
||||
})
|
||||
|
||||
host.properties.proxyWidgets = [[String(inner.id), 'seed']]
|
||||
flushProxyWidgetMigration({ hostNode: host })
|
||||
|
||||
expect(
|
||||
host.subgraph.inputs.find((input) => input.name === 'seed')?.type
|
||||
).toBe('*')
|
||||
})
|
||||
|
||||
it('createSubgraphInput: quarantines missingSubgraphInput when source widget has no backing input slot', () => {
|
||||
const host = buildHost()
|
||||
const inner = addInnerNode(host, 'Inner', (n) => {
|
||||
@@ -328,6 +392,88 @@ describe('flushProxyWidgetMigration', () => {
|
||||
expect(getPromotedInputValue(host, 'value')).toBe(11)
|
||||
})
|
||||
|
||||
it('uses the primitive title as the promoted input name when it was renamed', () => {
|
||||
const host = buildHost()
|
||||
const { primitive } = addPrimitiveWithTargets(host, {
|
||||
targetCount: 1
|
||||
})
|
||||
primitive.title = 'Batch Size'
|
||||
|
||||
host.properties.proxyWidgets = [[String(primitive.id), 'value']]
|
||||
flushProxyWidgetMigration({ hostNode: host })
|
||||
|
||||
expect(
|
||||
host.inputs.find((input) => input.name === 'Batch Size')
|
||||
).toBeDefined()
|
||||
})
|
||||
|
||||
it('skips a stale primitive bypass marker when the host input is absent', () => {
|
||||
const host = buildHost()
|
||||
const { primitive, targets } = addPrimitiveWithTargets(host, {
|
||||
targetCount: 1
|
||||
})
|
||||
primitive.properties = {
|
||||
proxyBypassedToSubgraphInput: 'deleted_input'
|
||||
}
|
||||
|
||||
host.properties.proxyWidgets = [[String(primitive.id), 'value']]
|
||||
flushProxyWidgetMigration({ hostNode: host })
|
||||
|
||||
const slot = targets[0].inputs[0]
|
||||
const link = host.subgraph.links.get(slot.link!)
|
||||
expect(link?.origin_id).not.toBe(primitive.id)
|
||||
expect(host.inputs.find((input) => input.name === 'value')).toBeDefined()
|
||||
})
|
||||
|
||||
it('quarantines a stale primitive bypass marker that points to a plain input', () => {
|
||||
const host = buildHost()
|
||||
const { primitive } = addPrimitiveWithTargets(host, {
|
||||
targetCount: 1
|
||||
})
|
||||
primitive.properties = {
|
||||
proxyBypassedToSubgraphInput: 'plain'
|
||||
}
|
||||
host.addInput('plain', 'INT')
|
||||
|
||||
host.properties.proxyWidgets = [[String(primitive.id), 'value']]
|
||||
flushProxyWidgetMigration({
|
||||
hostNode: host,
|
||||
hostWidgetValues: [12]
|
||||
})
|
||||
|
||||
expect(readHostQuarantine(host)).toEqual([
|
||||
expect.objectContaining({
|
||||
originalEntry: [String(primitive.id), 'value'],
|
||||
reason: 'missingSubgraphInput'
|
||||
})
|
||||
])
|
||||
})
|
||||
|
||||
it('quarantines a stale primitive bypass marker that matches ambiguous host inputs', () => {
|
||||
const host = buildHost()
|
||||
const { primitive } = addPrimitiveWithTargets(host, {
|
||||
targetCount: 1
|
||||
})
|
||||
primitive.properties = {
|
||||
proxyBypassedToSubgraphInput: 'plain'
|
||||
}
|
||||
host.addInput('plain', 'INT')
|
||||
host.addInput('plain', 'INT')
|
||||
|
||||
host.properties.proxyWidgets = [[String(primitive.id), 'value']]
|
||||
flushProxyWidgetMigration({
|
||||
hostNode: host,
|
||||
hostWidgetValues: [12]
|
||||
})
|
||||
|
||||
expect(readHostQuarantine(host)).toEqual([
|
||||
expect.objectContaining({
|
||||
originalEntry: [String(primitive.id), 'value'],
|
||||
reason: 'ambiguousSubgraphInput'
|
||||
})
|
||||
])
|
||||
})
|
||||
|
||||
it('quarantines an unlinked primitive node with no fan-out', () => {
|
||||
const host = buildHost()
|
||||
const primitive = new LGraphNode('Primitive')
|
||||
@@ -346,6 +492,64 @@ describe('flushProxyWidgetMigration', () => {
|
||||
])
|
||||
})
|
||||
|
||||
it('quarantines primitive cohorts that disagree on source widget name', () => {
|
||||
const host = buildHost()
|
||||
const { primitive } = addPrimitiveWithTargets(host, {
|
||||
targetCount: 1
|
||||
})
|
||||
|
||||
host.properties.proxyWidgets = [
|
||||
[String(primitive.id), 'value'],
|
||||
[String(primitive.id), 'other']
|
||||
]
|
||||
flushProxyWidgetMigration({ hostNode: host })
|
||||
|
||||
expect(readHostQuarantine(host)).toEqual([
|
||||
expect.objectContaining({
|
||||
originalEntry: [String(primitive.id), 'value'],
|
||||
reason: 'primitiveBypassFailed'
|
||||
}),
|
||||
expect.objectContaining({
|
||||
originalEntry: [String(primitive.id), 'other'],
|
||||
reason: 'primitiveBypassFailed'
|
||||
})
|
||||
])
|
||||
})
|
||||
|
||||
it('quarantines duplicate primitive entries with no fan-out targets', () => {
|
||||
const host = buildHost()
|
||||
const primitive = new LGraphNode('PrimitiveNode')
|
||||
primitive.type = 'PrimitiveNode'
|
||||
primitive.addOutput('value', 'INT')
|
||||
host.subgraph.add(primitive)
|
||||
|
||||
host.properties.proxyWidgets = [
|
||||
[String(primitive.id), 'value'],
|
||||
[String(primitive.id), 'value']
|
||||
]
|
||||
flushProxyWidgetMigration({ hostNode: host })
|
||||
|
||||
expect(readHostQuarantine(host)).toEqual([
|
||||
expect.objectContaining({
|
||||
originalEntry: [String(primitive.id), 'value'],
|
||||
reason: 'primitiveBypassFailed'
|
||||
})
|
||||
])
|
||||
})
|
||||
|
||||
it('keeps the target default when the primitive source widget has no value', () => {
|
||||
const host = buildHost()
|
||||
const { primitive } = addPrimitiveWithTargets(host, {
|
||||
targetCount: 1
|
||||
})
|
||||
primitive.widgets = []
|
||||
|
||||
host.properties.proxyWidgets = [[String(primitive.id), 'value']]
|
||||
flushProxyWidgetMigration({ hostNode: host })
|
||||
|
||||
expect(getPromotedInputValue(host, 'value')).toBe(0)
|
||||
})
|
||||
|
||||
it('quarantines all cohort entries when a target slot type is incompatible', () => {
|
||||
const host = buildHost()
|
||||
const { primitive, targets } = addPrimitiveWithTargets(host, {
|
||||
@@ -366,6 +570,73 @@ describe('flushProxyWidgetMigration', () => {
|
||||
])
|
||||
})
|
||||
|
||||
it('quarantines primitive repair when the target slot disappeared', () => {
|
||||
const host = buildHost()
|
||||
const { primitive, targets } = addPrimitiveWithTargets(host, {
|
||||
targetCount: 1
|
||||
})
|
||||
targets[0].inputs = []
|
||||
|
||||
const inputCountBefore = host.subgraph.inputs.length
|
||||
host.properties.proxyWidgets = [[String(primitive.id), 'value']]
|
||||
flushProxyWidgetMigration({ hostNode: host })
|
||||
|
||||
expect(host.subgraph.inputs).toHaveLength(inputCountBefore)
|
||||
expect(readHostQuarantine(host)).toEqual([
|
||||
expect.objectContaining({
|
||||
originalEntry: [String(primitive.id), 'value'],
|
||||
reason: 'primitiveBypassFailed'
|
||||
})
|
||||
])
|
||||
})
|
||||
|
||||
it('quarantines primitive repair when the target node id is stale', () => {
|
||||
const host = buildHost()
|
||||
const { primitive } = addPrimitiveWithTargets(host, {
|
||||
targetCount: 1
|
||||
})
|
||||
const linkId = primitive.outputs[0].links?.[0]
|
||||
if (!linkId) throw new Error('Missing primitive link')
|
||||
const link = host.subgraph.links.get(linkId)
|
||||
if (!link) throw new Error('Missing primitive link record')
|
||||
link.target_id = toNodeId(999_999)
|
||||
|
||||
host.properties.proxyWidgets = [[String(primitive.id), 'value']]
|
||||
flushProxyWidgetMigration({ hostNode: host })
|
||||
|
||||
expect(readHostQuarantine(host)).toEqual([
|
||||
expect.objectContaining({
|
||||
originalEntry: [String(primitive.id), 'value'],
|
||||
reason: 'primitiveBypassFailed'
|
||||
})
|
||||
])
|
||||
})
|
||||
|
||||
it('quarantines duplicate primitive entries when the fan-out target is unassigned', () => {
|
||||
const host = buildHost()
|
||||
const { primitive } = addPrimitiveWithTargets(host, {
|
||||
targetCount: 1
|
||||
})
|
||||
const linkId = primitive.outputs[0].links?.[0]
|
||||
if (!linkId) throw new Error('Missing primitive link')
|
||||
const link = host.subgraph.links.get(linkId)
|
||||
if (!link) throw new Error('Missing primitive link record')
|
||||
link.target_id = UNASSIGNED_NODE_ID
|
||||
|
||||
host.properties.proxyWidgets = [
|
||||
[String(primitive.id), 'value'],
|
||||
[String(primitive.id), 'value']
|
||||
]
|
||||
flushProxyWidgetMigration({ hostNode: host })
|
||||
|
||||
expect(readHostQuarantine(host)).toEqual([
|
||||
expect.objectContaining({
|
||||
originalEntry: [String(primitive.id), 'value'],
|
||||
reason: 'primitiveBypassFailed'
|
||||
})
|
||||
])
|
||||
})
|
||||
|
||||
it('keeps surviving primitive targets when one fan-out link is dangling', () => {
|
||||
const host = buildHost()
|
||||
const { primitive } = addPrimitiveWithTargets(host, { targetCount: 1 })
|
||||
@@ -572,6 +843,22 @@ describe('flushProxyWidgetMigration', () => {
|
||||
])
|
||||
})
|
||||
|
||||
it('does not preserve non-widget host values on quarantine rows', () => {
|
||||
const host = buildHost()
|
||||
host.properties.proxyWidgets = [['9999', 'seed']]
|
||||
|
||||
flushProxyWidgetMigration({
|
||||
hostNode: host,
|
||||
hostWidgetValues: [null]
|
||||
})
|
||||
|
||||
expect(readHostQuarantine(host)).toEqual([
|
||||
expect.not.objectContaining({
|
||||
hostValue: expect.anything()
|
||||
})
|
||||
])
|
||||
})
|
||||
|
||||
it('round-trips appended entries via the public read helper', () => {
|
||||
const host = buildHost()
|
||||
host.properties.proxyWidgets = [['9999', 'seed']]
|
||||
@@ -602,6 +889,14 @@ describe('flushProxyWidgetMigration', () => {
|
||||
|
||||
expect(readHostQuarantine(host)).toEqual(firstQuarantine)
|
||||
})
|
||||
|
||||
it('ignores empty quarantine append requests', () => {
|
||||
const host = buildHost()
|
||||
|
||||
appendQuarantine(host, [])
|
||||
|
||||
expect(host.properties.proxyWidgetErrorQuarantine).toBeUndefined()
|
||||
})
|
||||
})
|
||||
|
||||
describe('idempotency', () => {
|
||||
@@ -824,6 +1119,22 @@ describe('normalizeLegacyProxyWidgetEntry', () => {
|
||||
expect(result.disambiguatingSourceNodeId).toBe(String(samplerNode.id))
|
||||
})
|
||||
|
||||
it('strips nested legacy prefixes from widget name', () => {
|
||||
const { hostNode, innerNode } = createHostWithInnerWidget('seed')
|
||||
|
||||
const result = normalizeLegacyProxyWidgetEntry(
|
||||
hostNode,
|
||||
String(innerNode.id),
|
||||
'111: 222: seed'
|
||||
)
|
||||
|
||||
expect(result).toEqual({
|
||||
sourceNodeId: String(innerNode.id),
|
||||
sourceWidgetName: 'seed',
|
||||
disambiguatingSourceNodeId: '222'
|
||||
})
|
||||
})
|
||||
|
||||
it('strips legacy prefix and surfaces it as disambiguator even when the bare name does not resolve', () => {
|
||||
const { hostNode, innerNode } = createHostWithInnerWidget('seed')
|
||||
|
||||
|
||||
179
src/core/graph/subgraph/promotedInputWidget.test.ts
Normal file
179
src/core/graph/subgraph/promotedInputWidget.test.ts
Normal file
@@ -0,0 +1,179 @@
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import type { INodeInputSlot } from '@/lib/litegraph/src/interfaces'
|
||||
import type { LGraphNode } from '@/lib/litegraph/src/litegraph'
|
||||
import type { IBaseWidget } from '@/lib/litegraph/src/types/widgets'
|
||||
import type { WidgetId } from '@/types/widgetId'
|
||||
|
||||
import {
|
||||
inputForWidget,
|
||||
promotedInputSource,
|
||||
promotedInputWidget,
|
||||
promotedInputWidgets,
|
||||
widgetPromotedSource
|
||||
} from './promotedInputWidget'
|
||||
import { resolveSubgraphInputTarget } from './resolveSubgraphInputTarget'
|
||||
|
||||
const mocks = vi.hoisted(() => ({
|
||||
widgets: new Map<string, Record<string, unknown>>(),
|
||||
setValue: vi.fn(),
|
||||
resolveSubgraphInputTarget: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('@/stores/widgetValueStore', () => ({
|
||||
useWidgetValueStore: () => ({
|
||||
getWidget: (id: string) => mocks.widgets.get(id),
|
||||
setValue: mocks.setValue
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('./resolveSubgraphInputTarget', () => ({
|
||||
resolveSubgraphInputTarget: mocks.resolveSubgraphInputTarget
|
||||
}))
|
||||
|
||||
function input(overrides: Partial<INodeInputSlot> = {}): INodeInputSlot {
|
||||
return {
|
||||
name: 'prompt',
|
||||
type: 'STRING',
|
||||
label: 'Prompt',
|
||||
...overrides
|
||||
} as INodeInputSlot
|
||||
}
|
||||
|
||||
function node(overrides: Record<string, unknown> = {}): LGraphNode {
|
||||
return {
|
||||
inputs: [],
|
||||
isSubgraphNode: () => true,
|
||||
getSlotFromWidget: vi.fn(),
|
||||
...overrides
|
||||
} as unknown as LGraphNode
|
||||
}
|
||||
|
||||
describe('promotedInputWidget helpers', () => {
|
||||
beforeEach(() => {
|
||||
mocks.widgets.clear()
|
||||
mocks.setValue.mockClear()
|
||||
mocks.resolveSubgraphInputTarget.mockReset()
|
||||
})
|
||||
|
||||
it('resolves promoted input sources only for widget-backed inputs', () => {
|
||||
const graphNode = node()
|
||||
mocks.resolveSubgraphInputTarget.mockReturnValue({
|
||||
nodeId: '12',
|
||||
widgetName: 'prompt'
|
||||
})
|
||||
|
||||
expect(promotedInputSource(graphNode, input())).toBeUndefined()
|
||||
expect(
|
||||
promotedInputSource(
|
||||
graphNode,
|
||||
input({ widgetId: 'graph:12:prompt' as WidgetId })
|
||||
)
|
||||
).toEqual({
|
||||
nodeId: '12',
|
||||
widgetName: 'prompt'
|
||||
})
|
||||
expect(resolveSubgraphInputTarget).toHaveBeenCalledWith(graphNode, 'prompt')
|
||||
})
|
||||
|
||||
it('resolves promoted widget sources only on subgraph nodes with matching inputs', () => {
|
||||
const widget = { name: 'prompt' } as IBaseWidget
|
||||
const backingInput = input({ widgetId: 'graph:12:prompt' as WidgetId })
|
||||
mocks.resolveSubgraphInputTarget.mockReturnValue({
|
||||
nodeId: '12',
|
||||
widgetName: 'prompt'
|
||||
})
|
||||
|
||||
expect(
|
||||
widgetPromotedSource(node({ isSubgraphNode: () => false }), widget)
|
||||
).toBeUndefined()
|
||||
expect(
|
||||
widgetPromotedSource(node({ getSlotFromWidget: () => undefined }), widget)
|
||||
).toBeUndefined()
|
||||
expect(
|
||||
widgetPromotedSource(
|
||||
node({ getSlotFromWidget: () => backingInput }),
|
||||
widget
|
||||
)
|
||||
).toEqual({
|
||||
nodeId: '12',
|
||||
widgetName: 'prompt'
|
||||
})
|
||||
})
|
||||
|
||||
it('projects store-backed widget fields with input fallbacks', () => {
|
||||
const widgetId = 'graph:12:prompt' as WidgetId
|
||||
const widget = promotedInputWidget(input({ widgetId }))
|
||||
|
||||
expect(widget?.name).toBe('prompt')
|
||||
expect(widget?.label).toBe('Prompt')
|
||||
expect(widget?.y).toBe(0)
|
||||
expect(widget?.type).toBe('text')
|
||||
expect(widget?.options).toEqual({})
|
||||
expect(widget?.value).toBeUndefined()
|
||||
|
||||
widget!.label = 'Ignored'
|
||||
widget!.y = 12
|
||||
widget!.value = 'next'
|
||||
widget!.callback?.('callback')
|
||||
|
||||
expect(mocks.setValue).toHaveBeenCalledWith(widgetId, 'next')
|
||||
expect(mocks.setValue).toHaveBeenCalledWith(widgetId, 'callback')
|
||||
})
|
||||
|
||||
it('projects live widget store fields and mutates store state', () => {
|
||||
const widgetId = 'graph:12:prompt' as WidgetId
|
||||
const state = {
|
||||
name: 'store-name',
|
||||
label: 'Store Label',
|
||||
y: 42,
|
||||
type: 'combo',
|
||||
options: { values: ['a'] },
|
||||
value: 'a'
|
||||
}
|
||||
mocks.widgets.set(widgetId, state)
|
||||
|
||||
const widget = promotedInputWidget(input({ widgetId, label: undefined }))
|
||||
|
||||
expect(widget?.name).toBe('store-name')
|
||||
expect(widget?.label).toBe('Store Label')
|
||||
expect(widget?.y).toBe(42)
|
||||
expect(widget?.type).toBe('combo')
|
||||
expect(widget?.options).toEqual({ values: ['a'] })
|
||||
expect(widget?.value).toBe('a')
|
||||
|
||||
widget!.label = 'New Label'
|
||||
widget!.y = 52
|
||||
|
||||
expect(state.label).toBe('New Label')
|
||||
expect(state.y).toBe(52)
|
||||
})
|
||||
|
||||
it('returns null for non-promoted inputs and filters projected widget lists', () => {
|
||||
const widgetId = 'graph:12:prompt' as WidgetId
|
||||
const graphNode = node({
|
||||
inputs: [input(), input({ widgetId })]
|
||||
})
|
||||
|
||||
expect(promotedInputWidget(input())).toBeNull()
|
||||
expect(promotedInputWidgets(graphNode)).toHaveLength(1)
|
||||
})
|
||||
|
||||
it('returns undefined for null stored values', () => {
|
||||
const widgetId = 'graph:12:prompt' as WidgetId
|
||||
mocks.widgets.set(widgetId, { value: null })
|
||||
|
||||
expect(promotedInputWidget(input({ widgetId }))?.value).toBeUndefined()
|
||||
})
|
||||
|
||||
it('delegates input lookup to the graph node', () => {
|
||||
const widget = { name: 'prompt' } as IBaseWidget
|
||||
const backingInput = input({ widgetId: 'graph:12:prompt' as WidgetId })
|
||||
const graphNode = node({
|
||||
getSlotFromWidget: vi.fn(() => backingInput)
|
||||
})
|
||||
|
||||
expect(inputForWidget(graphNode, widget)).toBe(backingInput)
|
||||
expect(graphNode.getSlotFromWidget).toHaveBeenCalledWith(widget)
|
||||
})
|
||||
})
|
||||
@@ -15,6 +15,10 @@ import { usePreviewExposureStore } from '@/stores/previewExposureStore'
|
||||
import { useWidgetValueStore } from '@/stores/widgetValueStore'
|
||||
import { toLinkId } from '@/types/linkId'
|
||||
import type { WidgetId } from '@/types/widgetId'
|
||||
import { useCanvasStore } from '@/renderer/core/canvas/canvasStore'
|
||||
import { useToastStore } from '@/platform/updates/common/toastStore'
|
||||
import type { Subgraph } from '@/lib/litegraph/src/litegraph'
|
||||
import { toNodeId } from '@/types/nodeId'
|
||||
|
||||
function promotedInputNames(host: {
|
||||
inputs: Array<{ widgetId?: unknown; name: string }>
|
||||
@@ -51,19 +55,37 @@ vi.mock('@/services/litegraphService', () => ({
|
||||
useLitegraphService: () => ({ updatePreviews: updatePreviewsMock })
|
||||
}))
|
||||
|
||||
const addBreadcrumbMock = vi.hoisted(() => vi.fn())
|
||||
vi.mock('@sentry/vue', () => ({
|
||||
addBreadcrumb: addBreadcrumbMock
|
||||
}))
|
||||
|
||||
const mockNavigation = vi.hoisted(() => ({
|
||||
stack: [] as Subgraph[]
|
||||
}))
|
||||
vi.mock('@/stores/subgraphNavigationStore', () => ({
|
||||
useSubgraphNavigationStore: () => ({
|
||||
navigationStack: mockNavigation.stack
|
||||
})
|
||||
}))
|
||||
|
||||
import {
|
||||
CANVAS_IMAGE_PREVIEW_WIDGET,
|
||||
addWidgetPromotionOptions,
|
||||
autoExposeKnownPreviewNodes,
|
||||
demoteWidget,
|
||||
getPromotableWidgets,
|
||||
hasUnpromotedWidgets,
|
||||
isLinkedPromotion,
|
||||
isPreviewPseudoWidget,
|
||||
isWidgetPromotedOnSubgraphNode,
|
||||
promoteWidget,
|
||||
promoteValueWidgetViaSubgraphInput,
|
||||
promoteRecommendedWidgets,
|
||||
pruneDisconnected,
|
||||
reorderSubgraphInputsByName,
|
||||
reorderSubgraphInputsByWidgetOrder
|
||||
reorderSubgraphInputsByWidgetOrder,
|
||||
tryToggleWidgetPromotion
|
||||
} from './promotionUtils'
|
||||
|
||||
function widget(
|
||||
@@ -102,6 +124,11 @@ function buildDuplicateNamePromotion() {
|
||||
return { subgraph, host, nodeA, widgetA, nodeB, widgetB }
|
||||
}
|
||||
|
||||
function setupNavigation(host: SubgraphNode) {
|
||||
host.subgraph.rootGraph.add(host)
|
||||
mockNavigation.stack = [host.subgraph]
|
||||
}
|
||||
|
||||
describe('isPreviewPseudoWidget', () => {
|
||||
beforeEach(() => {
|
||||
setActivePinia(createTestingPinia({ stubActions: false }))
|
||||
@@ -303,6 +330,284 @@ describe('getPromotableWidgets', () => {
|
||||
})
|
||||
})
|
||||
|
||||
describe('widget promotion actions', () => {
|
||||
beforeEach(() => {
|
||||
setActivePinia(createTestingPinia({ stubActions: false }))
|
||||
addBreadcrumbMock.mockReset()
|
||||
mockNavigation.stack = []
|
||||
})
|
||||
|
||||
function setupPromotableWidget() {
|
||||
const subgraph = createTestSubgraph()
|
||||
const host = createTestSubgraphNode(subgraph)
|
||||
setupNavigation(host)
|
||||
const node = new LGraphNode('Prompt')
|
||||
subgraph.add(node)
|
||||
const input = node.addInput('text', 'STRING')
|
||||
input.label = 'Prompt text'
|
||||
const callback = vi.fn()
|
||||
const textWidget = node.addWidget('text', 'text', 'value', callback)
|
||||
textWidget.label = 'Prompt'
|
||||
input.widget = { name: textWidget.name }
|
||||
return { host, node, textWidget, callback }
|
||||
}
|
||||
|
||||
it('adds a promote menu option and runs the widget callback after promotion', () => {
|
||||
const { host, node, textWidget, callback } = setupPromotableWidget()
|
||||
const options: Parameters<typeof addWidgetPromotionOptions>[0] = []
|
||||
|
||||
addWidgetPromotionOptions(options, textWidget, node)
|
||||
const menuCallback = options[0]?.callback as
|
||||
| ((...args: unknown[]) => unknown)
|
||||
| undefined
|
||||
void menuCallback?.(null, undefined, undefined)
|
||||
|
||||
expect(options[0]?.content).toContain('Prompt')
|
||||
expect(isLinkedPromotion(host, String(node.id), textWidget.name)).toBe(true)
|
||||
expect(callback).toHaveBeenCalledWith('value')
|
||||
})
|
||||
|
||||
it('adds an unpromote menu option when the widget is already promoted', () => {
|
||||
const { host, node, textWidget, callback } = setupPromotableWidget()
|
||||
expect(promoteValueWidgetViaSubgraphInput(host, node, textWidget).ok).toBe(
|
||||
true
|
||||
)
|
||||
const options: Parameters<typeof addWidgetPromotionOptions>[0] = []
|
||||
|
||||
addWidgetPromotionOptions(options, textWidget, node)
|
||||
const menuCallback = options[0]?.callback as
|
||||
| ((...args: unknown[]) => unknown)
|
||||
| undefined
|
||||
void menuCallback?.(null, undefined, undefined)
|
||||
|
||||
expect(isLinkedPromotion(host, String(node.id), textWidget.name)).toBe(
|
||||
false
|
||||
)
|
||||
expect(callback).toHaveBeenCalledWith('value')
|
||||
})
|
||||
|
||||
it('reports outside-subgraph promotion attempts through the toast store', () => {
|
||||
const node = new LGraphNode('Prompt')
|
||||
const textWidget = node.addWidget('text', 'text', 'value', () => {})
|
||||
const options: Parameters<typeof addWidgetPromotionOptions>[0] = []
|
||||
|
||||
addWidgetPromotionOptions(options, textWidget, node)
|
||||
|
||||
expect(useToastStore().messagesToAdd).toHaveLength(1)
|
||||
expect(options).toHaveLength(1)
|
||||
})
|
||||
|
||||
it('toggles promotion for the widget under the canvas pointer', () => {
|
||||
const { host, node, textWidget } = setupPromotableWidget()
|
||||
const canvas = fromPartial<ReturnType<typeof useCanvasStore>['canvas']>({
|
||||
graph_mouse: [10, 20],
|
||||
visible_nodes: [node],
|
||||
setDirty: vi.fn(),
|
||||
graph: {
|
||||
getNodeOnPos: vi.fn(() => node)
|
||||
}
|
||||
})
|
||||
vi.spyOn(node, 'getWidgetOnPos').mockReturnValue(textWidget)
|
||||
useCanvasStore().canvas = canvas
|
||||
|
||||
tryToggleWidgetPromotion()
|
||||
expect(isLinkedPromotion(host, String(node.id), textWidget.name)).toBe(true)
|
||||
|
||||
tryToggleWidgetPromotion()
|
||||
expect(isLinkedPromotion(host, String(node.id), textWidget.name)).toBe(
|
||||
false
|
||||
)
|
||||
})
|
||||
|
||||
it('leaves state unchanged when toggle has no node or widget target', () => {
|
||||
const { host, node, textWidget } = setupPromotableWidget()
|
||||
useCanvasStore().canvas = fromPartial<
|
||||
ReturnType<typeof useCanvasStore>['canvas']
|
||||
>({
|
||||
graph_mouse: [0, 0],
|
||||
visible_nodes: [],
|
||||
setDirty: vi.fn(),
|
||||
graph: {
|
||||
getNodeOnPos: vi.fn(() => null)
|
||||
}
|
||||
})
|
||||
|
||||
tryToggleWidgetPromotion()
|
||||
expect(isLinkedPromotion(host, String(node.id), textWidget.name)).toBe(
|
||||
false
|
||||
)
|
||||
|
||||
useCanvasStore().canvas = fromPartial<
|
||||
ReturnType<typeof useCanvasStore>['canvas']
|
||||
>({
|
||||
graph_mouse: [0, 0],
|
||||
visible_nodes: [node],
|
||||
setDirty: vi.fn(),
|
||||
graph: {
|
||||
getNodeOnPos: vi.fn(() => node)
|
||||
}
|
||||
})
|
||||
vi.spyOn(node, 'getWidgetOnPos').mockReturnValue(undefined)
|
||||
|
||||
tryToggleWidgetPromotion()
|
||||
expect(isLinkedPromotion(host, String(node.id), textWidget.name)).toBe(
|
||||
false
|
||||
)
|
||||
})
|
||||
|
||||
it('records a breadcrumb when value promotion has no source slot', () => {
|
||||
const subgraph = createTestSubgraph()
|
||||
const host = createTestSubgraphNode(subgraph)
|
||||
const node = new LGraphNode('LooseWidgetNode')
|
||||
subgraph.add(node)
|
||||
const looseWidget = node.addWidget('text', 'loose', 'value', () => {})
|
||||
|
||||
promoteWidget(node, looseWidget, [host])
|
||||
|
||||
expect(addBreadcrumbMock).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
level: 'warning',
|
||||
message: expect.stringContaining('missingSourceSlot')
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('ignores promotion calls for node-shaped values that are not graph nodes', () => {
|
||||
const subgraph = createTestSubgraph()
|
||||
const host = createTestSubgraphNode(subgraph)
|
||||
const partialNode = {
|
||||
id: toNodeId(123),
|
||||
title: 'Partial',
|
||||
type: 'Partial'
|
||||
}
|
||||
|
||||
promoteWidget(partialNode, widget({ name: 'seed', type: 'number' }), [host])
|
||||
|
||||
expect(host.subgraph.inputs).toEqual([])
|
||||
expect(addBreadcrumbMock).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('uses the widget name in menu text when label is absent', () => {
|
||||
const { node, textWidget } = setupPromotableWidget()
|
||||
textWidget.label = undefined
|
||||
const options: Parameters<typeof addWidgetPromotionOptions>[0] = []
|
||||
|
||||
addWidgetPromotionOptions(options, textWidget, node)
|
||||
|
||||
expect(options[0]?.content).toContain('text')
|
||||
})
|
||||
})
|
||||
|
||||
describe('preview promotion actions', () => {
|
||||
beforeEach(() => {
|
||||
setActivePinia(createTestingPinia({ stubActions: false }))
|
||||
addBreadcrumbMock.mockReset()
|
||||
mockNavigation.stack = []
|
||||
})
|
||||
|
||||
it('identifies preview exposure as promotion only for preview pseudo widgets', () => {
|
||||
const subgraph = createTestSubgraph()
|
||||
const host = createTestSubgraphNode(subgraph)
|
||||
const previewNode = new LGraphNode('PreviewImage')
|
||||
previewNode.type = 'PreviewImage'
|
||||
subgraph.add(previewNode)
|
||||
const previewWidget = widget({
|
||||
name: CANVAS_IMAGE_PREVIEW_WIDGET,
|
||||
serialize: false,
|
||||
type: 'preview'
|
||||
})
|
||||
usePreviewExposureStore().addExposure(host.rootGraph.id, String(host.id), {
|
||||
sourceNodeId: previewNode.id,
|
||||
sourcePreviewName: CANVAS_IMAGE_PREVIEW_WIDGET
|
||||
})
|
||||
|
||||
expect(
|
||||
isWidgetPromotedOnSubgraphNode(
|
||||
host,
|
||||
{
|
||||
sourceNodeId: previewNode.id,
|
||||
sourceWidgetName: CANVAS_IMAGE_PREVIEW_WIDGET
|
||||
},
|
||||
previewWidget
|
||||
)
|
||||
).toBe(true)
|
||||
expect(
|
||||
isWidgetPromotedOnSubgraphNode(
|
||||
host,
|
||||
{
|
||||
sourceNodeId: previewNode.id,
|
||||
sourceWidgetName: 'other'
|
||||
},
|
||||
previewWidget
|
||||
)
|
||||
).toBe(false)
|
||||
})
|
||||
|
||||
it('deduplicates preview exposures when the same preview is promoted twice', () => {
|
||||
const subgraph = createTestSubgraph()
|
||||
const host = createTestSubgraphNode(subgraph)
|
||||
const previewNode = new LGraphNode('PreviewImage')
|
||||
previewNode.type = 'PreviewImage'
|
||||
subgraph.add(previewNode)
|
||||
const previewWidget = widget({
|
||||
name: CANVAS_IMAGE_PREVIEW_WIDGET,
|
||||
serialize: false,
|
||||
type: 'preview'
|
||||
})
|
||||
|
||||
promoteWidget(previewNode, previewWidget, [host])
|
||||
promoteWidget(previewNode, previewWidget, [host])
|
||||
|
||||
expect(
|
||||
usePreviewExposureStore().getExposures(host.rootGraph.id, String(host.id))
|
||||
).toHaveLength(1)
|
||||
})
|
||||
|
||||
it('demotes preview exposures when no linked value promotion exists', () => {
|
||||
const subgraph = createTestSubgraph()
|
||||
const host = createTestSubgraphNode(subgraph)
|
||||
const previewNode = new LGraphNode('PreviewImage')
|
||||
previewNode.type = 'PreviewImage'
|
||||
subgraph.add(previewNode)
|
||||
const previewWidget = widget({
|
||||
name: CANVAS_IMAGE_PREVIEW_WIDGET,
|
||||
serialize: false,
|
||||
type: 'preview'
|
||||
})
|
||||
promoteWidget(previewNode, previewWidget, [host])
|
||||
|
||||
demoteWidget(previewNode, previewWidget, [host])
|
||||
|
||||
expect(
|
||||
usePreviewExposureStore().getExposures(host.rootGraph.id, String(host.id))
|
||||
).toEqual([])
|
||||
})
|
||||
|
||||
it('leaves unexposed preview widgets unchanged when demoted', () => {
|
||||
const subgraph = createTestSubgraph()
|
||||
const host = createTestSubgraphNode(subgraph)
|
||||
const previewNode = new LGraphNode('PreviewImage')
|
||||
previewNode.type = 'PreviewImage'
|
||||
subgraph.add(previewNode)
|
||||
const previewWidget = widget({
|
||||
name: CANVAS_IMAGE_PREVIEW_WIDGET,
|
||||
serialize: false,
|
||||
type: 'preview'
|
||||
})
|
||||
|
||||
demoteWidget(previewNode, previewWidget, [host])
|
||||
|
||||
expect(
|
||||
usePreviewExposureStore().getExposures(host.rootGraph.id, String(host.id))
|
||||
).toEqual([])
|
||||
expect(addBreadcrumbMock).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
message: expect.stringContaining(CANVAS_IMAGE_PREVIEW_WIDGET)
|
||||
})
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('promoteRecommendedWidgets', () => {
|
||||
beforeEach(() => {
|
||||
setActivePinia(createTestingPinia({ stubActions: false }))
|
||||
@@ -346,6 +651,49 @@ describe('promoteRecommendedWidgets', () => {
|
||||
)
|
||||
})
|
||||
|
||||
it('keeps value promotion idempotent when the widget is already linked', () => {
|
||||
const subgraph = createTestSubgraph()
|
||||
const subgraphNode = createTestSubgraphNode(subgraph)
|
||||
const interiorNode = new LGraphNode('Prompt')
|
||||
const input = interiorNode.addInput('text', 'STRING')
|
||||
const textWidget = interiorNode.addWidget('text', 'text', '', () => {})
|
||||
input.widget = { name: textWidget.name }
|
||||
subgraph.add(interiorNode)
|
||||
|
||||
expect(
|
||||
promoteValueWidgetViaSubgraphInput(subgraphNode, interiorNode, textWidget)
|
||||
.ok
|
||||
).toBe(true)
|
||||
expect(
|
||||
promoteValueWidgetViaSubgraphInput(subgraphNode, interiorNode, textWidget)
|
||||
.ok
|
||||
).toBe(true)
|
||||
|
||||
expect(subgraph.inputs.map((slot) => slot.name)).toEqual(['text'])
|
||||
})
|
||||
|
||||
it('seeds outer promoted widget state from a nested promoted input', () => {
|
||||
const { host: innerHost } = buildDuplicateNamePromotion()
|
||||
writePromotedInputValue(innerHost, 'text', 'inner value')
|
||||
const outerSubgraph = createTestSubgraph()
|
||||
const outerHost = createTestSubgraphNode(outerSubgraph)
|
||||
outerSubgraph.add(innerHost)
|
||||
|
||||
expect(
|
||||
promoteValueWidgetViaSubgraphInput(
|
||||
outerHost,
|
||||
innerHost,
|
||||
promotedWidgetRef(innerHost, 'text')
|
||||
).ok
|
||||
).toBe(true)
|
||||
|
||||
const hostInput = outerHost.inputs.find((input) => input.name === 'text')
|
||||
if (!hostInput?.widgetId) throw new Error('Missing promoted host widget id')
|
||||
expect(useWidgetValueStore().getWidget(hostInput.widgetId)?.value).toBe(
|
||||
'inner value'
|
||||
)
|
||||
})
|
||||
|
||||
it('promotes virtual previews through preview exposures', () => {
|
||||
const subgraph = createTestSubgraph()
|
||||
const subgraphNode = createTestSubgraphNode(subgraph)
|
||||
@@ -414,6 +762,24 @@ describe('promoteRecommendedWidgets', () => {
|
||||
})
|
||||
expect(updatePreviewsMock).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('records a breadcrumb when a recommended value widget has no source slot', () => {
|
||||
const subgraph = createTestSubgraph()
|
||||
const subgraphNode = createTestSubgraphNode(subgraph)
|
||||
const interiorNode = new LGraphNode('CLIPTextEncode')
|
||||
interiorNode.type = 'CLIPTextEncode'
|
||||
interiorNode.addWidget('text', 'text', '', () => {})
|
||||
subgraph.add(interiorNode)
|
||||
|
||||
promoteRecommendedWidgets(subgraphNode)
|
||||
|
||||
expect(addBreadcrumbMock).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
level: 'warning',
|
||||
message: expect.stringContaining('missingSourceSlot')
|
||||
})
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('autoExposeKnownPreviewNodes', () => {
|
||||
@@ -482,6 +848,52 @@ describe('autoExposeKnownPreviewNodes', () => {
|
||||
.map((e) => e.sourceNodeId)
|
||||
).not.toContain(String(glslNode.id))
|
||||
})
|
||||
|
||||
it('defers preview discovery for nodes without eager preview widgets', () => {
|
||||
const subgraph = createTestSubgraph()
|
||||
const subgraphNode = createTestSubgraphNode(subgraph)
|
||||
const interiorNode = new LGraphNode('DeferredPreview')
|
||||
const rafCallbacks: FrameRequestCallback[] = []
|
||||
const requestAnimationFrameSpy = vi
|
||||
.spyOn(window, 'requestAnimationFrame')
|
||||
.mockImplementation((callback) => {
|
||||
rafCallbacks.push(callback)
|
||||
return rafCallbacks.length
|
||||
})
|
||||
subgraph.add(interiorNode)
|
||||
|
||||
try {
|
||||
autoExposeKnownPreviewNodes(subgraphNode)
|
||||
rafCallbacks[0]?.(0)
|
||||
const updateCallback = updatePreviewsMock.mock.calls[0]?.[1]
|
||||
const previewWidget = interiorNode.addWidget(
|
||||
'preview' as Parameters<typeof interiorNode.addWidget>[0],
|
||||
'preview',
|
||||
'',
|
||||
() => {}
|
||||
)
|
||||
previewWidget.serialize = false
|
||||
previewWidget.type = 'preview'
|
||||
updateCallback?.()
|
||||
|
||||
expect(updatePreviewsMock).toHaveBeenCalledWith(
|
||||
interiorNode,
|
||||
expect.any(Function)
|
||||
)
|
||||
expect(
|
||||
usePreviewExposureStore().getExposures(
|
||||
subgraphNode.rootGraph.id,
|
||||
String(subgraphNode.id)
|
||||
)
|
||||
).toContainEqual({
|
||||
name: 'preview',
|
||||
sourceNodeId: String(interiorNode.id),
|
||||
sourcePreviewName: 'preview'
|
||||
})
|
||||
} finally {
|
||||
requestAnimationFrameSpy.mockRestore()
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
describe('hasUnpromotedWidgets', () => {
|
||||
@@ -673,6 +1085,25 @@ describe('reorderSubgraphInputsByName', () => {
|
||||
])
|
||||
})
|
||||
|
||||
it('leaves unordered names after explicitly ordered inputs', () => {
|
||||
const subgraph = createTestSubgraph({
|
||||
inputs: [
|
||||
{ name: 'first', type: 'number' },
|
||||
{ name: 'second', type: 'number' },
|
||||
{ name: 'third', type: 'number' }
|
||||
]
|
||||
})
|
||||
const host = createTestSubgraphNode(subgraph)
|
||||
|
||||
reorderSubgraphInputsByName(host, ['second'])
|
||||
|
||||
expect(host.subgraph.inputs.map((input) => input.name)).toEqual([
|
||||
'second',
|
||||
'first',
|
||||
'third'
|
||||
])
|
||||
})
|
||||
|
||||
it('updates subgraph input link slot indices after reordering', () => {
|
||||
const subgraph = createTestSubgraph()
|
||||
const host = createTestSubgraphNode(subgraph)
|
||||
@@ -768,6 +1199,33 @@ describe('reorderSubgraphInputsByWidgetOrder', () => {
|
||||
'first value'
|
||||
])
|
||||
})
|
||||
|
||||
it('appends promoted inputs that are absent from the widget order', () => {
|
||||
const subgraph = createTestSubgraph()
|
||||
const host = createTestSubgraphNode(subgraph)
|
||||
const firstNode = new LGraphNode('First')
|
||||
const secondNode = new LGraphNode('Second')
|
||||
subgraph.add(firstNode)
|
||||
subgraph.add(secondNode)
|
||||
|
||||
const firstInput = firstNode.addInput('first', 'STRING')
|
||||
const firstWidget = firstNode.addWidget('text', 'first', '', () => {})
|
||||
firstInput.widget = { name: firstWidget.name }
|
||||
const secondInput = secondNode.addInput('second', 'STRING')
|
||||
const secondWidget = secondNode.addWidget('text', 'second', '', () => {})
|
||||
secondInput.widget = { name: secondWidget.name }
|
||||
promoteValueWidgetViaSubgraphInput(host, firstNode, firstWidget)
|
||||
promoteValueWidgetViaSubgraphInput(host, secondNode, secondWidget)
|
||||
|
||||
reorderSubgraphInputsByWidgetOrder(host, [
|
||||
promotedWidgetRef(host, 'second')
|
||||
])
|
||||
|
||||
expect(host.subgraph.inputs.map((input) => input.name)).toEqual([
|
||||
'second',
|
||||
'first'
|
||||
])
|
||||
})
|
||||
})
|
||||
|
||||
describe('demoteWidget — axiomatic projection retraction', () => {
|
||||
@@ -798,6 +1256,23 @@ describe('demoteWidget — axiomatic projection retraction', () => {
|
||||
return { host, interiorNode, interiorWidget }
|
||||
}
|
||||
|
||||
it('runs as a no-op for an unpromoted non-preview widget', () => {
|
||||
const subgraph = createTestSubgraph()
|
||||
const host = createTestSubgraphNode(subgraph)
|
||||
const interiorNode = new LGraphNode('TestNode')
|
||||
host.subgraph.add(interiorNode)
|
||||
const widget = interiorNode.addWidget('text', 'value', 'initial', () => {})
|
||||
|
||||
demoteWidget(interiorNode, widget, [host])
|
||||
|
||||
expect(host.subgraph.inputs).toEqual([])
|
||||
expect(addBreadcrumbMock).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
message: expect.stringContaining('Demoted widget "value"')
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('drops projection but keeps slot and external link when host slot is externally connected', () => {
|
||||
const { host, interiorNode, interiorWidget } = setupPromotedWidget()
|
||||
const hostInput = host.inputs[0]
|
||||
@@ -943,4 +1418,54 @@ describe('disambiguated nested promotion identity', () => {
|
||||
|
||||
expect(outerHost.subgraph.inputs).toHaveLength(beforeCount)
|
||||
})
|
||||
|
||||
it('promotes a widget whose source widget state is missing', () => {
|
||||
const subgraph = createTestSubgraph()
|
||||
const host = createTestSubgraphNode(subgraph)
|
||||
const interiorNode = new LGraphNode('Source')
|
||||
subgraph.add(interiorNode)
|
||||
const interiorInput = interiorNode.addInput('text', 'STRING')
|
||||
const interiorWidget = interiorNode.addWidget('text', 'text', '', () => {})
|
||||
interiorInput.widget = { name: interiorWidget.name }
|
||||
interiorInput.widgetId = 'missing-widget-state' as WidgetId
|
||||
|
||||
expect(
|
||||
promoteValueWidgetViaSubgraphInput(host, interiorNode, interiorWidget).ok
|
||||
).toBe(true)
|
||||
expect(host.subgraph.inputs.map((input) => input.name)).toEqual(['text'])
|
||||
})
|
||||
|
||||
it('keeps plain inputs after ordered promoted widgets', () => {
|
||||
const subgraph = createTestSubgraph({
|
||||
inputs: [{ name: 'plain', type: 'STRING' }]
|
||||
})
|
||||
const host = createTestSubgraphNode(subgraph)
|
||||
|
||||
reorderSubgraphInputsByWidgetOrder(host, [
|
||||
{ widgetId: 'missing-widget-state' as WidgetId }
|
||||
])
|
||||
|
||||
expect(host.inputs.map((input) => input.name)).toEqual(['plain'])
|
||||
})
|
||||
|
||||
it('falls back to append order when promoted input links are stale', () => {
|
||||
const subgraph = createTestSubgraph()
|
||||
const host = createTestSubgraphNode(subgraph)
|
||||
const interiorNode = new LGraphNode('Source')
|
||||
subgraph.add(interiorNode)
|
||||
const interiorInput = interiorNode.addInput('text', 'STRING')
|
||||
const interiorWidget = interiorNode.addWidget('text', 'text', '', () => {})
|
||||
interiorInput.widget = { name: interiorWidget.name }
|
||||
|
||||
expect(
|
||||
promoteValueWidgetViaSubgraphInput(host, interiorNode, interiorWidget).ok
|
||||
).toBe(true)
|
||||
const promotedInput = host.subgraph.inputs[0]
|
||||
const linkId = promotedInput.linkIds[0]
|
||||
host.subgraph.links.delete(linkId)
|
||||
|
||||
reorderSubgraphInputsByWidgetOrder(host, [promotedWidgetRef(host, 'text')])
|
||||
|
||||
expect(host.inputs.map((input) => input.name)).toEqual(['text'])
|
||||
})
|
||||
})
|
||||
|
||||
52
src/core/graph/widgets/dynamicTypes.test.ts
Normal file
52
src/core/graph/widgets/dynamicTypes.test.ts
Normal file
@@ -0,0 +1,52 @@
|
||||
import { describe, expect, it } from 'vitest'
|
||||
|
||||
import { resolveInputType } from './dynamicTypes'
|
||||
|
||||
describe('resolveInputType', () => {
|
||||
it('splits concrete comma-delimited input types', () => {
|
||||
expect(resolveInputType({ type: 'MODEL,CLIP' } as never)).toEqual([
|
||||
'MODEL',
|
||||
'CLIP'
|
||||
])
|
||||
})
|
||||
|
||||
it('resolves match-type templates from allowed types', () => {
|
||||
expect(
|
||||
resolveInputType({
|
||||
type: 'COMFY_MATCHTYPE_V3',
|
||||
template: {
|
||||
allowed_types: 'IMAGE,MASK',
|
||||
template_id: 'image'
|
||||
}
|
||||
} as never)
|
||||
).toEqual(['IMAGE', 'MASK'])
|
||||
})
|
||||
|
||||
it('returns an empty type list for invalid match-type templates', () => {
|
||||
expect(resolveInputType({ type: 'COMFY_MATCHTYPE_V3' } as never)).toEqual(
|
||||
[]
|
||||
)
|
||||
})
|
||||
|
||||
it('resolves autogrow templates from required and optional inputs', () => {
|
||||
expect(
|
||||
resolveInputType({
|
||||
type: 'COMFY_AUTOGROW_V3',
|
||||
template: {
|
||||
input: {
|
||||
required: {
|
||||
image: ['IMAGE', {}]
|
||||
},
|
||||
optional: {
|
||||
mask: ['MASK,IMAGE', {}]
|
||||
}
|
||||
}
|
||||
}
|
||||
} as never)
|
||||
).toEqual(['IMAGE', 'MASK', 'IMAGE'])
|
||||
})
|
||||
|
||||
it('returns an empty type list for invalid autogrow templates', () => {
|
||||
expect(resolveInputType({ type: 'COMFY_AUTOGROW_V3' } as never)).toEqual([])
|
||||
})
|
||||
})
|
||||
@@ -1,13 +1,19 @@
|
||||
import { setActivePinia } from 'pinia'
|
||||
import { createTestingPinia } from '@pinia/testing'
|
||||
import { describe, expect, test, vi } from 'vitest'
|
||||
import { LGraph, LGraphNode } from '@/lib/litegraph/src/litegraph'
|
||||
import { fromAny } from '@total-typescript/shoehorn'
|
||||
import { beforeEach, describe, expect, test, vi } from 'vitest'
|
||||
import { LGraph, LGraphNode, LiteGraph } from '@/lib/litegraph/src/litegraph'
|
||||
import { transformInputSpecV1ToV2 } from '@/schemas/nodeDef/migration'
|
||||
import { app } from '@/scripts/app'
|
||||
import type { InputSpec } from '@/schemas/nodeDefSchema'
|
||||
import type { InputSpec as InputSpecV2 } from '@/schemas/nodeDef/nodeDefSchemaV2'
|
||||
import { useLitegraphService } from '@/services/litegraphService'
|
||||
import type { HasInitialMinSize } from '@/services/litegraphService'
|
||||
import { useWidgetValueStore } from '@/stores/widgetValueStore'
|
||||
import { toLinkId } from '@/types/linkId'
|
||||
import { applyDynamicInputs, dynamicWidgets } from './dynamicWidgets'
|
||||
|
||||
setActivePinia(createTestingPinia())
|
||||
setActivePinia(createTestingPinia({ stubActions: false }))
|
||||
type DynamicInputs = ('INT' | 'STRING' | 'IMAGE' | DynamicInputs)[][]
|
||||
type TestAutogrowNode = LGraphNode & {
|
||||
comfyDynamic: { autogrow: Record<string, unknown> }
|
||||
@@ -15,6 +21,13 @@ type TestAutogrowNode = LGraphNode & {
|
||||
|
||||
const { addNodeInput } = useLitegraphService()
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
fromAny<{ configuringGraphLevel: number }, unknown>(
|
||||
app
|
||||
).configuringGraphLevel = 0
|
||||
})
|
||||
|
||||
function nextTick() {
|
||||
return new Promise<void>((r) => requestAnimationFrame(() => r()))
|
||||
}
|
||||
@@ -56,6 +69,23 @@ function addAutogrow(node: LGraphNode, template: unknown) {
|
||||
})
|
||||
)
|
||||
}
|
||||
function addMatchType(
|
||||
node: LGraphNode,
|
||||
name: string,
|
||||
allowedTypes = '*',
|
||||
templateId = 'a'
|
||||
) {
|
||||
addNodeInput(
|
||||
node,
|
||||
transformInputSpecV1ToV2(
|
||||
[
|
||||
'COMFY_MATCHTYPE_V3',
|
||||
{ template: { allowed_types: allowedTypes, template_id: templateId } }
|
||||
],
|
||||
{ name, isOptional: false }
|
||||
)
|
||||
)
|
||||
}
|
||||
function connectInput(node: LGraphNode, inputIndex: number, graph: LGraph) {
|
||||
const node2 = testNode()
|
||||
node2.addOutput('out', '*')
|
||||
@@ -116,7 +146,312 @@ describe('Dynamic Combos', () => {
|
||||
node.widgets[0].value = '1'
|
||||
expect.soft(node.widgets[1].tooltip).toBe('1')
|
||||
})
|
||||
|
||||
test('throws for malformed dynamic combo specs before creating a widget', () => {
|
||||
const node = testNode()
|
||||
const comboApp = { widgets: { COMBO: vi.fn() } } as unknown as Parameters<
|
||||
typeof dynamicWidgets.COMFY_DYNAMICCOMBO_V3
|
||||
>[3]
|
||||
|
||||
expect(() =>
|
||||
dynamicWidgets.COMFY_DYNAMICCOMBO_V3(
|
||||
node,
|
||||
'bad',
|
||||
['COMFY_DYNAMICCOMBO_V3', {}] as InputSpec,
|
||||
comboApp
|
||||
)
|
||||
).toThrow('invalid DynamicCombo spec')
|
||||
expect(comboApp.widgets.COMBO).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
test('clears grouped widgets when selection becomes empty', () => {
|
||||
const node = testNode()
|
||||
addDynamicCombo(node, [['INT'], ['INT', 'STRING']])
|
||||
node.widgets[0].value = '1'
|
||||
const onRemove = vi.fn()
|
||||
node.widgets[1].onRemove = onRemove
|
||||
|
||||
node.widgets[0].value = undefined
|
||||
|
||||
expect(onRemove).toHaveBeenCalled()
|
||||
expect(node.widgets).toHaveLength(1)
|
||||
})
|
||||
|
||||
test('deletes widget state when removing grouped dynamic widgets', () => {
|
||||
const graph = new LGraph()
|
||||
const node = testNode()
|
||||
graph.add(node)
|
||||
addDynamicCombo(node, [['INT'], ['STRING']])
|
||||
const childWidget = node.widgets[1]
|
||||
const childWidgetId = childWidget.widgetId
|
||||
if (!childWidgetId) throw new Error('Missing child widget id')
|
||||
const deleteWidget = vi.mocked(useWidgetValueStore().deleteWidget)
|
||||
|
||||
node.widgets[0].value = undefined
|
||||
|
||||
expect(deleteWidget).toHaveBeenCalledWith(childWidgetId)
|
||||
})
|
||||
|
||||
test('preserves an existing dynamic input link when refreshing a selection', () => {
|
||||
const graph = new LGraph()
|
||||
const node = testNode()
|
||||
const onConnectionsChange = vi.fn()
|
||||
node.onConnectionsChange = onConnectionsChange
|
||||
graph.add(node)
|
||||
addDynamicCombo(node, [['IMAGE'], ['STRING']])
|
||||
node.widgets[0].value = '0'
|
||||
|
||||
connectInput(node, 1, graph)
|
||||
const linkId = node.inputs[1].link
|
||||
expect(linkId).not.toBeNull()
|
||||
onConnectionsChange.mockClear()
|
||||
|
||||
node.widgets[0].value = '0'
|
||||
|
||||
expect(node.inputs[1].link).toBe(linkId)
|
||||
expect(graph.links[linkId!].target_slot).toBe(1)
|
||||
expect(onConnectionsChange).toHaveBeenCalledWith(
|
||||
LiteGraph.INPUT,
|
||||
1,
|
||||
true,
|
||||
graph.links[linkId!],
|
||||
node.inputs[1]
|
||||
)
|
||||
})
|
||||
|
||||
test('throws if the backing widgets array disappears during update', () => {
|
||||
const node = testNode()
|
||||
addDynamicCombo(node, [['INT'], ['STRING']])
|
||||
const controller = node.widgets[0]
|
||||
node.widgets = undefined as unknown as typeof node.widgets
|
||||
|
||||
expect(() => {
|
||||
controller.value = '1'
|
||||
}).toThrow('Not Reachable')
|
||||
})
|
||||
|
||||
test('throws when the dynamic controller widget is missing during update', () => {
|
||||
const node = testNode()
|
||||
addDynamicCombo(node, [['INT'], ['STRING']])
|
||||
const controller = node.widgets[0]
|
||||
node.widgets = node.widgets.slice(1)
|
||||
|
||||
expect(() => {
|
||||
controller.value = '1'
|
||||
}).toThrow("Dynamic widget doesn't exist on node")
|
||||
})
|
||||
|
||||
test('throws when input-only dynamic sockets have no insertion point', () => {
|
||||
const node = testNode()
|
||||
addDynamicCombo(node, [['INT'], ['IMAGE']])
|
||||
const controller = node.widgets[0]
|
||||
node.inputs = []
|
||||
|
||||
expect(() => {
|
||||
controller.value = '1'
|
||||
}).toThrow('Failed to find input socket for 0')
|
||||
})
|
||||
|
||||
test('updates dynamic inputs without requiring a graph', () => {
|
||||
const node = testNode()
|
||||
addDynamicCombo(node, [['INT'], ['IMAGE']])
|
||||
|
||||
node.widgets[0].value = '1'
|
||||
|
||||
expect(node.inputs[1].type).toBe('IMAGE')
|
||||
})
|
||||
|
||||
test('reads dynamic combo values from widget state when available', () => {
|
||||
const graph = new LGraph()
|
||||
const node = testNode()
|
||||
graph.add(node)
|
||||
addDynamicCombo(node, [['INT'], ['STRING']])
|
||||
const controller = node.widgets[0]
|
||||
const controllerId = controller.widgetId
|
||||
if (!controllerId) throw new Error('Missing controller widget id')
|
||||
|
||||
controller.value = '1'
|
||||
useWidgetValueStore().setValue(controllerId, '0')
|
||||
|
||||
expect(controller.value).toBe('0')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Dynamic input dispatch', () => {
|
||||
test('returns false for unknown dynamic input types', () => {
|
||||
const node = testNode()
|
||||
|
||||
expect(
|
||||
applyDynamicInputs(node, {
|
||||
name: 'plain',
|
||||
type: 'STRING',
|
||||
isOptional: false
|
||||
})
|
||||
).toBe(false)
|
||||
})
|
||||
|
||||
test('returns true after applying a known dynamic input type', () => {
|
||||
const node = testNode()
|
||||
|
||||
expect(
|
||||
applyDynamicInputs(
|
||||
node,
|
||||
transformInputSpecV1ToV2(
|
||||
[
|
||||
'COMFY_AUTOGROW_V3',
|
||||
{ template: { input: { required: { image: ['IMAGE', {}] } } } }
|
||||
],
|
||||
{ name: 'grow', isOptional: false }
|
||||
)
|
||||
)
|
||||
).toBe(true)
|
||||
})
|
||||
|
||||
test('throws when an autogrow input spec is malformed', () => {
|
||||
const node = testNode()
|
||||
const inputSpec = {
|
||||
name: 'bad',
|
||||
type: 'COMFY_AUTOGROW_V3'
|
||||
} as InputSpecV2
|
||||
|
||||
expect(() => addNodeInput(node, inputSpec)).toThrow('invalid Autogrow spec')
|
||||
})
|
||||
|
||||
test('ignores malformed match type specs', () => {
|
||||
const node = testNode()
|
||||
|
||||
expect(
|
||||
applyDynamicInputs(node, {
|
||||
name: 'bad',
|
||||
type: 'COMFY_MATCHTYPE_V3',
|
||||
isOptional: false
|
||||
})
|
||||
).toBe(true)
|
||||
expect(node.inputs).toHaveLength(0)
|
||||
})
|
||||
})
|
||||
|
||||
describe('MatchType inputs', () => {
|
||||
function createMatchTypeNode(graph: LGraph, outputMatchTypes = ['a']) {
|
||||
const node = testNode()
|
||||
node.constructor.nodeData = {
|
||||
name: 'testnode',
|
||||
output_matchtypes: outputMatchTypes
|
||||
} as typeof node.constructor.nodeData
|
||||
node.addOutput('out', '*')
|
||||
graph.add(node)
|
||||
addMatchType(node, 'on_true')
|
||||
addMatchType(node, 'on_false')
|
||||
return node
|
||||
}
|
||||
|
||||
function createSourceNode(graph: LGraph, type: string) {
|
||||
const node = testNode()
|
||||
node.addOutput('out', type)
|
||||
graph.add(node)
|
||||
return node
|
||||
}
|
||||
|
||||
test('ignores match type notifications outside registered inputs', () => {
|
||||
const graph = new LGraph()
|
||||
const node = createMatchTypeNode(graph)
|
||||
node.addInput('plain', 'STRING')
|
||||
|
||||
node.onConnectionsChange?.(LiteGraph.OUTPUT, 0, true, null, node.inputs[0])
|
||||
node.onConnectionsChange?.(LiteGraph.INPUT, 2, true, null, node.inputs[2])
|
||||
|
||||
expect(node.outputs[0].type).toBe('*')
|
||||
})
|
||||
|
||||
test('uses wildcard types for stale match type links', () => {
|
||||
const graph = new LGraph()
|
||||
const node = createMatchTypeNode(graph)
|
||||
node.inputs[0].link = toLinkId(999)
|
||||
|
||||
node.onConnectionsChange?.(LiteGraph.INPUT, 1, false, null, node.inputs[1])
|
||||
|
||||
expect(node.outputs[0].type).toBe('*')
|
||||
})
|
||||
|
||||
test('leaves unmatched output groups unchanged', () => {
|
||||
const graph = new LGraph()
|
||||
const node = createMatchTypeNode(graph, ['other'])
|
||||
const source = createSourceNode(graph, 'IMAGE')
|
||||
|
||||
source.connect(0, node, 0)
|
||||
|
||||
expect(node.outputs[0].type).toBe('*')
|
||||
})
|
||||
|
||||
test('throws when match group input constraints cannot overlap', () => {
|
||||
const graph = new LGraph()
|
||||
const node = testNode()
|
||||
const requestAnimationFrameSpy = vi
|
||||
.spyOn(window, 'requestAnimationFrame')
|
||||
.mockImplementation(() => 1)
|
||||
node.constructor.nodeData = {
|
||||
name: 'testnode',
|
||||
output_matchtypes: ['a']
|
||||
} as typeof node.constructor.nodeData
|
||||
node.addOutput('out', '*')
|
||||
graph.add(node)
|
||||
addMatchType(node, 'image', 'IMAGE')
|
||||
addMatchType(node, 'latent', 'LATENT')
|
||||
const source = createSourceNode(graph, 'IMAGE')
|
||||
|
||||
try {
|
||||
expect(() => source.connect(0, node, 0)).toThrow('invalid connection')
|
||||
} finally {
|
||||
requestAnimationFrameSpy.mockRestore()
|
||||
}
|
||||
})
|
||||
|
||||
test('disconnects downstream links when a match type output narrows', () => {
|
||||
const graph = new LGraph()
|
||||
const node = createMatchTypeNode(graph)
|
||||
const downstream = testNode()
|
||||
downstream.addInput('latent', 'LATENT')
|
||||
downstream.onConnectionsChange = vi.fn()
|
||||
graph.add(downstream)
|
||||
node.connect(0, downstream, 0)
|
||||
const source = createSourceNode(graph, 'IMAGE')
|
||||
|
||||
source.connect(0, node, 0)
|
||||
|
||||
expect(downstream.inputs[0].link).toBeNull()
|
||||
expect(downstream.onConnectionsChange).toHaveBeenCalledWith(
|
||||
LiteGraph.INPUT,
|
||||
0,
|
||||
false,
|
||||
expect.anything(),
|
||||
downstream.inputs[0]
|
||||
)
|
||||
})
|
||||
|
||||
test('ignores deferred match type refresh after the input is removed', () => {
|
||||
const graph = new LGraph()
|
||||
const node = testNode()
|
||||
const rafCallbacks: FrameRequestCallback[] = []
|
||||
const requestAnimationFrameSpy = vi
|
||||
.spyOn(window, 'requestAnimationFrame')
|
||||
.mockImplementation((callback) => {
|
||||
rafCallbacks.push(callback)
|
||||
return rafCallbacks.length
|
||||
})
|
||||
graph.add(node)
|
||||
|
||||
try {
|
||||
addMatchType(node, 'removed')
|
||||
node.inputs.pop()
|
||||
rafCallbacks[0]?.(0)
|
||||
|
||||
expect(node.inputs).toHaveLength(0)
|
||||
} finally {
|
||||
requestAnimationFrameSpy.mockRestore()
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
describe('Autogrow', () => {
|
||||
const inputsSpec = { required: { image: ['IMAGE', {}] } }
|
||||
test('Can name by prefix', () => {
|
||||
@@ -162,6 +497,259 @@ describe('Autogrow', () => {
|
||||
connectInput(node, 2, graph)
|
||||
expect(node.inputs.length).toBe(3)
|
||||
})
|
||||
|
||||
test('ignores autogrow notifications that cannot affect a known input group', () => {
|
||||
const graph = new LGraph()
|
||||
const node = testNode()
|
||||
graph.add(node)
|
||||
addAutogrow(node, { min: 1, input: inputsSpec, prefix: 'test' })
|
||||
const inputCount = node.inputs.length
|
||||
const unknownInput = node.addInput('outside.0', 'IMAGE')
|
||||
|
||||
node.onConnectionsChange?.(LiteGraph.OUTPUT, 0, true, null, node.inputs[0])
|
||||
node.onConnectionsChange?.(
|
||||
LiteGraph.INPUT,
|
||||
99,
|
||||
true,
|
||||
null,
|
||||
fromAny<
|
||||
Parameters<NonNullable<typeof node.onConnectionsChange>>[4],
|
||||
unknown
|
||||
>(undefined)
|
||||
)
|
||||
node.onConnectionsChange?.(LiteGraph.INPUT, 2, true, null, unknownInput)
|
||||
|
||||
expect(node.inputs).toHaveLength(inputCount + 1)
|
||||
})
|
||||
|
||||
test('does not grow autogrow inputs when connection metadata is missing', () => {
|
||||
const graph = new LGraph()
|
||||
const node = testNode()
|
||||
graph.add(node)
|
||||
addAutogrow(node, { min: 1, input: inputsSpec, prefix: 'test' })
|
||||
|
||||
node.onConnectionsChange?.(LiteGraph.INPUT, 1, true, null, node.inputs[1])
|
||||
|
||||
expect(node.inputs).toHaveLength(2)
|
||||
})
|
||||
|
||||
test('keeps minimum autogrow rows when disconnecting early ordinals', async () => {
|
||||
const graph = new LGraph()
|
||||
const node = testNode()
|
||||
graph.add(node)
|
||||
addAutogrow(node, { min: 2, input: inputsSpec, prefix: 'test' })
|
||||
|
||||
node.onConnectionsChange?.(LiteGraph.INPUT, 0, false, null, node.inputs[0])
|
||||
await nextTick()
|
||||
|
||||
expect(node.inputs).toHaveLength(3)
|
||||
})
|
||||
|
||||
test('restores a configure-time autogrow widget shim', () => {
|
||||
const graph = new LGraph()
|
||||
const node = testNode()
|
||||
graph.add(node)
|
||||
addAutogrow(node, { min: 1, input: inputsSpec, prefix: 'test' })
|
||||
node.inputs[1].widget = { name: node.inputs[1].name }
|
||||
fromAny<{ configuringGraphLevel: number }, unknown>(
|
||||
app
|
||||
).configuringGraphLevel = 1
|
||||
|
||||
connectInput(node, 1, graph)
|
||||
|
||||
expect(node.widgets.some((widget) => widget.name === '0.test1')).toBe(true)
|
||||
})
|
||||
|
||||
test('draws configure-time autogrow shim text from the input name', () => {
|
||||
const graph = new LGraph()
|
||||
const node = testNode()
|
||||
graph.add(node)
|
||||
addAutogrow(node, { min: 1, input: inputsSpec, prefix: 'test' })
|
||||
node.inputs[1].widget = { name: node.inputs[1].name }
|
||||
fromAny<{ configuringGraphLevel: number }, unknown>(
|
||||
app
|
||||
).configuringGraphLevel = 1
|
||||
|
||||
connectInput(node, 1, graph)
|
||||
const shim = node.widgets.find((widget) => widget.name === '0.test1')
|
||||
if (!shim?.draw) throw new Error('Missing shim widget')
|
||||
node.inputs[1].label = undefined
|
||||
const ctx = fromAny<CanvasRenderingContext2D, unknown>({
|
||||
save: vi.fn(),
|
||||
fillText: vi.fn(),
|
||||
restore: vi.fn()
|
||||
})
|
||||
|
||||
shim.draw(ctx, node, 100, 10, 20)
|
||||
|
||||
expect(ctx.fillText).toHaveBeenCalledWith('0.test1', 20, 25)
|
||||
})
|
||||
|
||||
test('keeps an existing configure-time autogrow widget shim', () => {
|
||||
const graph = new LGraph()
|
||||
const node = testNode()
|
||||
graph.add(node)
|
||||
addAutogrow(node, { min: 1, input: inputsSpec, prefix: 'test' })
|
||||
node.inputs[1].widget = { name: node.inputs[1].name }
|
||||
node.widgets.push({
|
||||
name: node.inputs[1].name,
|
||||
type: 'shim',
|
||||
y: 0,
|
||||
options: {},
|
||||
serialize: false,
|
||||
draw: vi.fn()
|
||||
})
|
||||
fromAny<{ configuringGraphLevel: number }, unknown>(
|
||||
app
|
||||
).configuringGraphLevel = 1
|
||||
|
||||
connectInput(node, 1, graph)
|
||||
|
||||
expect(
|
||||
node.widgets.filter((widget) => widget.name === '0.test1')
|
||||
).toHaveLength(1)
|
||||
})
|
||||
|
||||
test('defers disconnect handling during an input swap', () => {
|
||||
const graph = new LGraph()
|
||||
const node = testNode()
|
||||
const rafCallbacks: FrameRequestCallback[] = []
|
||||
const requestAnimationFrameSpy = vi
|
||||
.spyOn(window, 'requestAnimationFrame')
|
||||
.mockImplementation((callback) => {
|
||||
rafCallbacks.push(callback)
|
||||
return rafCallbacks.length
|
||||
})
|
||||
graph.add(node)
|
||||
addAutogrow(node, { min: 1, input: inputsSpec, prefix: 'test' })
|
||||
|
||||
try {
|
||||
connectInput(node, 0, graph)
|
||||
node.disconnectInput(0)
|
||||
|
||||
expect(node.inputs).toHaveLength(2)
|
||||
expect(rafCallbacks).toHaveLength(2)
|
||||
} finally {
|
||||
requestAnimationFrameSpy.mockRestore()
|
||||
}
|
||||
})
|
||||
|
||||
test('stops cleanup for uneven multi-input autogrow groups', async () => {
|
||||
const graph = new LGraph()
|
||||
const node = testNode()
|
||||
const consoleErrorSpy = vi
|
||||
.spyOn(console, 'error')
|
||||
.mockImplementation(() => undefined)
|
||||
graph.add(node)
|
||||
addAutogrow(node, {
|
||||
min: 1,
|
||||
input: { required: { image: ['IMAGE', {}], mask: ['MASK', {}] } }
|
||||
})
|
||||
node.inputs.pop()
|
||||
|
||||
try {
|
||||
node.onConnectionsChange?.(
|
||||
LiteGraph.INPUT,
|
||||
0,
|
||||
false,
|
||||
null,
|
||||
node.inputs[0]
|
||||
)
|
||||
await nextTick()
|
||||
|
||||
expect(consoleErrorSpy).toHaveBeenCalledWith(
|
||||
'Failed to group multi-input autogrow inputs'
|
||||
)
|
||||
} finally {
|
||||
consoleErrorSpy.mockRestore()
|
||||
}
|
||||
})
|
||||
|
||||
test('keeps trailing autogrow row when disconnecting the last slot', async () => {
|
||||
const graph = new LGraph()
|
||||
const node = testNode()
|
||||
graph.add(node)
|
||||
addAutogrow(node, { min: 1, input: inputsSpec, prefix: 'test' })
|
||||
|
||||
node.onConnectionsChange?.(LiteGraph.INPUT, 1, false, null, node.inputs[1])
|
||||
await nextTick()
|
||||
|
||||
expect(node.inputs.map((input) => input.name)).toEqual([
|
||||
'0.test0',
|
||||
'0.test1'
|
||||
])
|
||||
})
|
||||
|
||||
test('ignores named autogrow input names outside the configured list', async () => {
|
||||
const graph = new LGraph()
|
||||
const node = testNode()
|
||||
graph.add(node)
|
||||
addAutogrow(node, { min: 1, input: inputsSpec, names: ['a', 'b'] })
|
||||
const unknownInput = node.addInput('0.c', 'IMAGE')
|
||||
|
||||
node.onConnectionsChange?.(
|
||||
LiteGraph.INPUT,
|
||||
node.inputs.length - 1,
|
||||
false,
|
||||
null,
|
||||
unknownInput
|
||||
)
|
||||
await nextTick()
|
||||
|
||||
expect(node.inputs.map((input) => input.name)).toEqual([
|
||||
'0.a',
|
||||
'0.b',
|
||||
'0.c'
|
||||
])
|
||||
})
|
||||
|
||||
test('ignores autogrow input names without numeric ordinals', async () => {
|
||||
const graph = new LGraph()
|
||||
const node = testNode()
|
||||
graph.add(node)
|
||||
addAutogrow(node, { min: 1, input: inputsSpec, prefix: 'test' })
|
||||
const unknownInput = node.addInput('0.testx', 'IMAGE')
|
||||
|
||||
node.onConnectionsChange?.(
|
||||
LiteGraph.INPUT,
|
||||
node.inputs.length - 1,
|
||||
false,
|
||||
null,
|
||||
unknownInput
|
||||
)
|
||||
await nextTick()
|
||||
|
||||
expect(node.inputs.map((input) => input.name)).toEqual([
|
||||
'0.test0',
|
||||
'0.test1',
|
||||
'0.testx'
|
||||
])
|
||||
})
|
||||
|
||||
test('marks optional autogrow inputs as optional after required inputs', () => {
|
||||
const node = testNode()
|
||||
|
||||
addAutogrow(node, {
|
||||
min: 1,
|
||||
input: {
|
||||
required: { image: ['IMAGE', {}] },
|
||||
optional: { mask: ['MASK', {}] }
|
||||
}
|
||||
})
|
||||
|
||||
expect(node.inputs.map((input) => input.name)).toEqual([
|
||||
'0.image0',
|
||||
'0.mask0',
|
||||
'0.image1',
|
||||
'0.mask1'
|
||||
])
|
||||
expect(node.inputs.map((input) => input.type)).toEqual([
|
||||
'IMAGE',
|
||||
'MASK',
|
||||
'IMAGE',
|
||||
'MASK'
|
||||
])
|
||||
})
|
||||
test('Removing connections decreases to min + 1', async () => {
|
||||
const graph = new LGraph()
|
||||
const node = testNode()
|
||||
@@ -258,6 +846,42 @@ describe('Autogrow', () => {
|
||||
expect(vid0Link).not.toBeNull()
|
||||
expect(graph.links[vid0Link!].target_slot).toBe(vid0Index)
|
||||
})
|
||||
|
||||
test('removes shim widgets when multi-input autogrow rows shrink', async () => {
|
||||
const graph = new LGraph()
|
||||
const node = testNode()
|
||||
graph.add(node)
|
||||
addAutogrow(node, {
|
||||
min: 1,
|
||||
input: { required: { image: ['IMAGE', {}], mask: ['MASK', {}] } }
|
||||
})
|
||||
connectInput(node, 2, graph)
|
||||
await nextTick()
|
||||
expect(node.inputs).toHaveLength(6)
|
||||
|
||||
const removedWidgetNames = ['0.image2', '0.mask2']
|
||||
const onRemove = vi.fn()
|
||||
for (const widget of node.widgets.filter((widget) =>
|
||||
removedWidgetNames.includes(widget.name)
|
||||
)) {
|
||||
widget.onRemove = onRemove
|
||||
}
|
||||
|
||||
node.disconnectInput(2)
|
||||
await nextTick()
|
||||
|
||||
expect(node.inputs.map((input) => input.name)).toEqual([
|
||||
'0.image0',
|
||||
'0.mask0',
|
||||
'0.image1',
|
||||
'0.mask1'
|
||||
])
|
||||
expect(onRemove).toHaveBeenCalledTimes(2)
|
||||
expect(
|
||||
node.widgets.some((widget) => removedWidgetNames.includes(widget.name))
|
||||
).toBe(false)
|
||||
})
|
||||
|
||||
test('Can deserialize a complex node', async () => {
|
||||
const graph = new LGraph()
|
||||
const node = testNode()
|
||||
|
||||
@@ -127,4 +127,45 @@ describe('MatchType during configure', () => {
|
||||
expect(switchNode.inputs[1].link).not.toBeNull()
|
||||
expect(switchNode.outputs[0].type).toBe('IMAGE')
|
||||
})
|
||||
|
||||
test('keeps compatible downstream links after output type recalculation', () => {
|
||||
const graph = new LGraph()
|
||||
const switchNode = createMatchTypeNode(graph)
|
||||
const target = new LGraphNode('target')
|
||||
target.addInput('image', 'IMAGE')
|
||||
target.onConnectionsChange = vi.fn()
|
||||
graph.add(target)
|
||||
const source = createSourceNode(graph, 'IMAGE')
|
||||
|
||||
switchNode.connect(0, target, 0)
|
||||
vi.mocked(target.onConnectionsChange).mockClear()
|
||||
source.connect(0, switchNode, 0)
|
||||
|
||||
expect(switchNode.outputs[0].type).toBe('IMAGE')
|
||||
expect(target.inputs[0].link).not.toBeNull()
|
||||
expect(target.onConnectionsChange).toHaveBeenCalledWith(
|
||||
LiteGraph.INPUT,
|
||||
0,
|
||||
true,
|
||||
expect.anything(),
|
||||
target.inputs[0]
|
||||
)
|
||||
})
|
||||
|
||||
test('disconnects incompatible downstream links after output type recalculation', () => {
|
||||
const graph = new LGraph()
|
||||
const switchNode = createMatchTypeNode(graph)
|
||||
const target = new LGraphNode('target')
|
||||
target.addInput('image', 'IMAGE')
|
||||
graph.add(target)
|
||||
const source = createSourceNode(graph, 'LATENT')
|
||||
|
||||
switchNode.connect(0, target, 0)
|
||||
expect(target.inputs[0].link).not.toBeNull()
|
||||
|
||||
source.connect(0, switchNode, 0)
|
||||
|
||||
expect(switchNode.outputs[0].type).toBe('LATENT')
|
||||
expect(target.inputs[0].link).toBeNull()
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,16 +1,48 @@
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
import { createTestingPinia } from '@pinia/testing'
|
||||
import { setActivePinia } from 'pinia'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import type { SerialisedLLinkArray } from '@/lib/litegraph/src/LLink'
|
||||
import { LGraphNode, LiteGraph } from '@/lib/litegraph/src/litegraph'
|
||||
import type { ComfyNode } from '@/platform/workflow/validation/schemas/workflowSchema'
|
||||
import type { ComfyNodeDef } from '@/schemas/nodeDefSchema'
|
||||
import type { ComfyApp } from '@/scripts/app'
|
||||
import type { ComfyExtension } from '@/types/comfy'
|
||||
|
||||
import type { GroupNodeWorkflowData } from './groupNode'
|
||||
|
||||
vi.mock('@/scripts/app', () => ({
|
||||
app: {
|
||||
registerExtension: vi.fn()
|
||||
const appMock = vi.hoisted(() => ({
|
||||
canvas: {
|
||||
emitAfterChange: vi.fn(),
|
||||
emitBeforeChange: vi.fn(),
|
||||
selected_nodes: {}
|
||||
},
|
||||
registerExtension: vi.fn(),
|
||||
registerNodeDef: vi.fn(),
|
||||
rootGraph: {
|
||||
convertToSubgraph: vi.fn(),
|
||||
extra: {},
|
||||
getNodeById: vi.fn(),
|
||||
links: {},
|
||||
nodes: [],
|
||||
remove: vi.fn()
|
||||
}
|
||||
}))
|
||||
|
||||
const widgetStoreMock = vi.hoisted(() => ({
|
||||
inputIsWidget: vi.fn((spec: unknown[]) =>
|
||||
['BOOLEAN', 'COMBO', 'FLOAT', 'INT', 'STRING'].includes(String(spec[0]))
|
||||
)
|
||||
}))
|
||||
|
||||
vi.mock('@/scripts/app', () => ({
|
||||
app: appMock
|
||||
}))
|
||||
|
||||
vi.mock('@/stores/widgetStore', () => ({
|
||||
useWidgetStore: () => widgetStoreMock
|
||||
}))
|
||||
|
||||
import { GroupNodeConfig, replaceLegacySeparators } from './groupNode'
|
||||
|
||||
function makeNode(type: string): ComfyNode {
|
||||
@@ -26,6 +58,46 @@ function makeNode(type: string): ComfyNode {
|
||||
}
|
||||
}
|
||||
|
||||
function makeNodeDef(overrides: Partial<ComfyNodeDef> = {}): ComfyNodeDef {
|
||||
return {
|
||||
name: 'TestNode',
|
||||
display_name: 'Test Node',
|
||||
description: '',
|
||||
category: 'test',
|
||||
input: { required: {}, optional: {} },
|
||||
output: [],
|
||||
output_name: [],
|
||||
output_is_list: [],
|
||||
output_node: false,
|
||||
python_module: 'test',
|
||||
...overrides
|
||||
} as ComfyNodeDef
|
||||
}
|
||||
|
||||
function extension(): ComfyExtension {
|
||||
const groupExtension = appMock.registerExtension.mock.calls.find(
|
||||
([registered]) => registered.name === 'Comfy.GroupNode'
|
||||
)?.[0]
|
||||
if (!groupExtension) throw new Error('GroupNode extension was not registered')
|
||||
return groupExtension as ComfyExtension
|
||||
}
|
||||
|
||||
function addCustomNodeDefs(defs: Record<string, ComfyNodeDef>) {
|
||||
const groupExtension = extension()
|
||||
if (!groupExtension.addCustomNodeDefs) {
|
||||
throw new Error('GroupNode extension does not implement addCustomNodeDefs')
|
||||
}
|
||||
groupExtension.addCustomNodeDefs(defs, appMock as unknown as ComfyApp)
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
setActivePinia(createTestingPinia({ stubActions: false }))
|
||||
appMock.registerNodeDef.mockReset()
|
||||
widgetStoreMock.inputIsWidget.mockClear()
|
||||
LiteGraph.registered_node_types = {}
|
||||
addCustomNodeDefs({})
|
||||
})
|
||||
|
||||
describe('replaceLegacySeparators', () => {
|
||||
it('rewrites the legacy "workflow/" prefix to "workflow>"', () => {
|
||||
const nodes = [makeNode('workflow/My Group')]
|
||||
@@ -104,4 +176,390 @@ describe('GroupNodeConfig.getLinks', () => {
|
||||
const config = configFrom([], [[0, 1, 'IMAGE']])
|
||||
expect(config.externalFrom[0][1]).toBe('IMAGE')
|
||||
})
|
||||
|
||||
it('ignores external links without a type and accumulates multiple slots', () => {
|
||||
const config = configFrom(
|
||||
[],
|
||||
[
|
||||
[0, 1, null as unknown as string],
|
||||
[0, 2, 'LATENT'],
|
||||
[0, 3, 'IMAGE']
|
||||
]
|
||||
)
|
||||
|
||||
expect(config.externalFrom[0]).toEqual({ 2: 'LATENT', 3: 'IMAGE' })
|
||||
})
|
||||
})
|
||||
|
||||
describe('GroupNodeConfig.getNodeDef', () => {
|
||||
const imageNodeDef = makeNodeDef({
|
||||
name: 'ImageNode',
|
||||
input: {
|
||||
required: {
|
||||
image: ['IMAGE', {}],
|
||||
mode: [['fast', 'slow'], {}]
|
||||
},
|
||||
optional: {
|
||||
strength: ['FLOAT', { default: 1 }]
|
||||
}
|
||||
},
|
||||
output: ['IMAGE'],
|
||||
output_name: ['image'],
|
||||
output_is_list: [false]
|
||||
})
|
||||
|
||||
beforeEach(() => {
|
||||
addCustomNodeDefs({ ImageNode: imageNodeDef })
|
||||
})
|
||||
|
||||
it('returns registered definitions for normal node types', () => {
|
||||
const config = new GroupNodeConfig('group', {
|
||||
nodes: [{ index: 0, type: 'ImageNode' }],
|
||||
links: [],
|
||||
external: []
|
||||
})
|
||||
|
||||
expect(config.getNodeDef({ index: 0, type: 'ImageNode' })).toBe(
|
||||
imageNodeDef
|
||||
)
|
||||
})
|
||||
|
||||
it('returns undefined for nodes without an index or a known type', () => {
|
||||
const config = new GroupNodeConfig('group', {
|
||||
nodes: [{ type: 'UnknownNode' }],
|
||||
links: [],
|
||||
external: []
|
||||
})
|
||||
|
||||
expect(config.getNodeDef({ type: 'UnknownNode' })).toBeUndefined()
|
||||
})
|
||||
|
||||
it('skips unlinked primitive nodes', () => {
|
||||
const config = new GroupNodeConfig('group', {
|
||||
nodes: [{ index: 0, type: 'PrimitiveNode' }],
|
||||
links: [],
|
||||
external: []
|
||||
})
|
||||
|
||||
expect(
|
||||
config.getNodeDef({ index: 0, type: 'PrimitiveNode' })
|
||||
).toBeUndefined()
|
||||
})
|
||||
|
||||
it('derives primitive node type from the outgoing link type', () => {
|
||||
const config = new GroupNodeConfig('group', {
|
||||
nodes: [
|
||||
{ index: 0, type: 'PrimitiveNode' },
|
||||
{ index: 1, type: 'ImageNode' }
|
||||
],
|
||||
links: [[0, 0, 1, 0, 1, 'IMAGE'] as SerialisedLLinkArray],
|
||||
external: []
|
||||
})
|
||||
|
||||
expect(
|
||||
config.getNodeDef({ index: 0, type: 'PrimitiveNode' })
|
||||
).toMatchObject({
|
||||
input: { required: { value: ['IMAGE', {}] } },
|
||||
output: ['IMAGE']
|
||||
})
|
||||
})
|
||||
|
||||
it('falls back to null when primitive combo target spec is not primitive', () => {
|
||||
const config = new GroupNodeConfig('group', {
|
||||
nodes: [
|
||||
{
|
||||
index: 0,
|
||||
type: 'PrimitiveNode',
|
||||
outputs: [{ name: 'mode', widget: { name: 'mode' } }]
|
||||
},
|
||||
{ index: 1, type: 'ImageNode' }
|
||||
],
|
||||
links: [[0, 0, 1, 0, 1, 'COMBO'] as SerialisedLLinkArray],
|
||||
external: []
|
||||
})
|
||||
|
||||
expect(config.getNodeDef(config.nodeData.nodes[0])).toMatchObject({
|
||||
input: { required: { value: [null, {}] } },
|
||||
output: [null]
|
||||
})
|
||||
})
|
||||
|
||||
it('returns null for reroutes used only inside the group', () => {
|
||||
const config = new GroupNodeConfig('group', {
|
||||
nodes: [
|
||||
{ index: 0, type: 'ImageNode' },
|
||||
{ index: 1, type: 'Reroute' },
|
||||
{ index: 2, type: 'ImageNode' }
|
||||
],
|
||||
links: [
|
||||
[0, 0, 1, 0, 1, 'IMAGE'],
|
||||
[1, 0, 2, 0, 2, 'IMAGE']
|
||||
] as SerialisedLLinkArray[],
|
||||
external: []
|
||||
})
|
||||
|
||||
expect(config.getNodeDef({ index: 1, type: 'Reroute' })).toBeNull()
|
||||
})
|
||||
|
||||
it('derives reroute type from outgoing target inputs', () => {
|
||||
const config = new GroupNodeConfig('group', {
|
||||
nodes: [
|
||||
{ index: 0, type: 'Reroute' },
|
||||
{
|
||||
index: 1,
|
||||
type: 'ImageNode',
|
||||
inputs: [{ name: 'image', type: 'IMAGE' }]
|
||||
}
|
||||
],
|
||||
links: [[0, 0, 1, 0, 1, 'IMAGE'] as SerialisedLLinkArray],
|
||||
external: [[0, 0, 'IMAGE']]
|
||||
})
|
||||
|
||||
expect(config.getNodeDef({ index: 0, type: 'Reroute' })).toMatchObject({
|
||||
input: { required: { IMAGE: ['IMAGE', { forceInput: true }] } },
|
||||
output: ['IMAGE']
|
||||
})
|
||||
})
|
||||
|
||||
it('derives reroute type from incoming output metadata', () => {
|
||||
const config = new GroupNodeConfig('group', {
|
||||
nodes: [
|
||||
{ index: 0, type: 'ImageNode', outputs: [{ type: 'LATENT' }] },
|
||||
{ index: 1, type: 'Reroute' }
|
||||
],
|
||||
links: [[0, 0, 1, 0, 1, 'LATENT'] as SerialisedLLinkArray],
|
||||
external: [[1, 0, 'LATENT']]
|
||||
})
|
||||
|
||||
expect(config.getNodeDef({ index: 1, type: 'Reroute' })).toMatchObject({
|
||||
input: { required: { LATENT: ['LATENT', { forceInput: true }] } },
|
||||
output: ['LATENT']
|
||||
})
|
||||
})
|
||||
|
||||
it('derives pipe reroute type from external metadata when links omit it', () => {
|
||||
const config = new GroupNodeConfig('group', {
|
||||
nodes: [{ index: 0, type: 'Reroute' }],
|
||||
links: [],
|
||||
external: [[0, 0, 'MASK']]
|
||||
})
|
||||
|
||||
expect(config.getNodeDef({ index: 0, type: 'Reroute' })).toMatchObject({
|
||||
input: { required: { MASK: ['MASK', { forceInput: true }] } },
|
||||
output: ['MASK']
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('GroupNodeConfig input and output mapping', () => {
|
||||
function configWithNode(node: GroupNodeWorkflowData['nodes'][number]) {
|
||||
const config = new GroupNodeConfig('group', {
|
||||
nodes: [node],
|
||||
links: [],
|
||||
external: [],
|
||||
config: {
|
||||
0: {
|
||||
input: {
|
||||
hidden: { visible: false },
|
||||
renamed: { name: 'Custom Name' }
|
||||
},
|
||||
output: {
|
||||
1: { name: 'Custom Output' },
|
||||
2: { visible: false }
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
config.nodeDef = makeNodeDef({
|
||||
input: { required: {} },
|
||||
output: [],
|
||||
output_name: [],
|
||||
output_is_list: []
|
||||
})
|
||||
return config
|
||||
}
|
||||
|
||||
it('renames duplicate inputs and adds seed control metadata', () => {
|
||||
const config = configWithNode({
|
||||
index: 0,
|
||||
type: 'Sampler',
|
||||
title: 'Sampler A',
|
||||
inputs: [{ name: 'seed', label: 'Seed Label' }]
|
||||
})
|
||||
const seenInputs = { seed: 1, 'Sampler A seed': 1 }
|
||||
const result = config.getInputConfig(
|
||||
{ index: 0, type: 'Sampler', title: 'Sampler A' },
|
||||
'seed',
|
||||
seenInputs,
|
||||
['INT', {}]
|
||||
)
|
||||
|
||||
expect(result.name).toBe('Sampler A 1 seed')
|
||||
expect(result.config).toEqual([
|
||||
'INT',
|
||||
{ control_after_generate: 'Sampler A control_after_generate' }
|
||||
])
|
||||
})
|
||||
|
||||
it('maps image upload widget aliases through converted widget names', () => {
|
||||
const config = configWithNode({ index: 0, type: 'LoadImage' })
|
||||
config.oldToNewWidgetMap[0] = { customImage: 'Uploaded Image' }
|
||||
|
||||
expect(
|
||||
config.getInputConfig({ index: 0, type: 'LoadImage' }, 'renamed', {}, [
|
||||
'IMAGEUPLOAD',
|
||||
{ widget: 'customImage' }
|
||||
])
|
||||
).toMatchObject({
|
||||
name: 'Custom Name',
|
||||
config: ['IMAGEUPLOAD', { widget: 'Uploaded Image' }]
|
||||
})
|
||||
})
|
||||
|
||||
it('splits widget inputs, socket inputs, and converted widget slots', () => {
|
||||
const config = configWithNode({
|
||||
index: 0,
|
||||
type: 'MixedNode',
|
||||
inputs: [{ name: 'mode', widget: { name: 'mode' } }]
|
||||
})
|
||||
|
||||
const result = config.processWidgetInputs(
|
||||
{
|
||||
mode: ['COMBO', {}],
|
||||
image: ['IMAGE', {}]
|
||||
},
|
||||
{
|
||||
index: 0,
|
||||
type: 'MixedNode',
|
||||
inputs: [{ name: 'mode', widget: { name: 'mode' } }]
|
||||
},
|
||||
['mode', 'image'],
|
||||
{}
|
||||
)
|
||||
|
||||
expect(result.slots).toEqual(['image'])
|
||||
expect(result.converted.get(0)).toBe('mode')
|
||||
expect(config.oldToNewWidgetMap[0].mode).toBeNull()
|
||||
})
|
||||
|
||||
it('adds visible unlinked input slots and skips hidden configured inputs', () => {
|
||||
const config = configWithNode({
|
||||
index: 0,
|
||||
type: 'InputNode'
|
||||
})
|
||||
const inputMap: Record<number, number> = {}
|
||||
config.processInputSlots(
|
||||
{
|
||||
image: ['IMAGE', {}],
|
||||
hidden: ['LATENT', {}]
|
||||
},
|
||||
{ index: 0, type: 'InputNode' },
|
||||
['image', 'hidden'],
|
||||
{},
|
||||
inputMap,
|
||||
{}
|
||||
)
|
||||
|
||||
expect(config.nodeDef?.input?.required).toEqual({ image: ['IMAGE', {}] })
|
||||
expect(inputMap).toEqual({ 0: 0 })
|
||||
})
|
||||
|
||||
it('adds output metadata, hides linked/internal outputs, and dedupes labels', () => {
|
||||
const config = configWithNode({
|
||||
index: 0,
|
||||
type: 'OutputNode',
|
||||
title: 'Output A',
|
||||
outputs: [{ name: 'image', label: 'Rendered' }]
|
||||
})
|
||||
config.linksFrom[0] = {
|
||||
0: [[0, 0, 1, 0, 1, 'IMAGE'] as SerialisedLLinkArray]
|
||||
}
|
||||
config.processNodeOutputs(
|
||||
{ index: 0, type: 'OutputNode', title: 'Output A' },
|
||||
{ Rendered: 1 },
|
||||
{
|
||||
input: { required: {} },
|
||||
output: ['IMAGE', 'LATENT', 'MASK'],
|
||||
output_name: ['image', 'latent', 'mask'],
|
||||
output_is_list: [false, true, false]
|
||||
}
|
||||
)
|
||||
|
||||
expect(config.outputVisibility).toEqual([false, true, false])
|
||||
expect(config.nodeDef?.output).toEqual(['LATENT'])
|
||||
expect(config.nodeDef?.output_is_list).toEqual([true])
|
||||
expect(config.nodeDef?.output_name).toEqual(['Custom Output'])
|
||||
})
|
||||
})
|
||||
|
||||
describe('GroupNodeConfig.registerFromWorkflow', () => {
|
||||
it('adds missing type actions and skips registration for incomplete groups', async () => {
|
||||
const groupNodes: Record<string, GroupNodeWorkflowData> = {
|
||||
Broken: {
|
||||
nodes: [{ index: 0, type: 'MissingNode' }],
|
||||
links: [],
|
||||
external: []
|
||||
}
|
||||
}
|
||||
const missingNodeTypes: Parameters<
|
||||
typeof GroupNodeConfig.registerFromWorkflow
|
||||
>[1] = []
|
||||
|
||||
await GroupNodeConfig.registerFromWorkflow(groupNodes, missingNodeTypes)
|
||||
|
||||
expect(appMock.registerNodeDef).not.toHaveBeenCalled()
|
||||
expect(missingNodeTypes).toHaveLength(2)
|
||||
expect(missingNodeTypes[0]).toMatchObject({
|
||||
type: 'MissingNode',
|
||||
hint: " (In group node 'workflow>Broken')"
|
||||
})
|
||||
|
||||
const action = missingNodeTypes[1]
|
||||
if (typeof action === 'string') {
|
||||
throw new Error('Expected an action entry for the broken group node')
|
||||
}
|
||||
const target = document.createElement('button')
|
||||
const { callback } = action.action as {
|
||||
callback: (event: MouseEvent) => void
|
||||
}
|
||||
const event = new MouseEvent('click')
|
||||
Object.defineProperty(event, 'target', { value: target })
|
||||
callback(event)
|
||||
expect(groupNodes.Broken).toBeUndefined()
|
||||
expect(target.textContent).toBe('Removed')
|
||||
expect(target.style.pointerEvents).toBe('none')
|
||||
})
|
||||
|
||||
it('registers complete group node types and stores their generated node defs', async () => {
|
||||
addCustomNodeDefs({
|
||||
ImageNode: makeNodeDef({
|
||||
name: 'ImageNode',
|
||||
input: { required: { image: ['IMAGE', {}] } },
|
||||
output: ['IMAGE'],
|
||||
output_name: ['image'],
|
||||
output_is_list: [false]
|
||||
})
|
||||
})
|
||||
LiteGraph.registered_node_types.ImageNode = class extends LGraphNode {}
|
||||
|
||||
await GroupNodeConfig.registerFromWorkflow(
|
||||
{
|
||||
Complete: {
|
||||
nodes: [{ index: 0, type: 'ImageNode' }],
|
||||
links: [],
|
||||
external: [[0, 0, 'IMAGE']]
|
||||
}
|
||||
},
|
||||
[]
|
||||
)
|
||||
|
||||
expect(appMock.registerNodeDef).toHaveBeenCalledWith(
|
||||
'workflow>Complete',
|
||||
expect.objectContaining({
|
||||
category: 'group nodes>workflow',
|
||||
display_name: 'Complete',
|
||||
name: 'workflow>Complete'
|
||||
})
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,18 +1,89 @@
|
||||
import { describe, expect, it } from 'vitest'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import type {
|
||||
INodeInputSlot,
|
||||
INodeOutputSlot
|
||||
} from '@/lib/litegraph/src/litegraph'
|
||||
import {
|
||||
NodeInputSlot,
|
||||
NodeOutputSlot,
|
||||
inputAsSerialisable,
|
||||
outputAsSerialisable
|
||||
} from '@/lib/litegraph/src/litegraph'
|
||||
import type { ReadOnlyRect } from '@/lib/litegraph/src/interfaces'
|
||||
import { SlotType } from '@/lib/litegraph/src/draw'
|
||||
import type {
|
||||
DefaultConnectionColors,
|
||||
ReadOnlyRect
|
||||
} from '@/lib/litegraph/src/interfaces'
|
||||
import type { LGraphNode } from '@/lib/litegraph/src/LGraphNode'
|
||||
import {
|
||||
LinkDirection,
|
||||
RenderShape
|
||||
} from '@/lib/litegraph/src/types/globalEnums'
|
||||
import { toLinkId } from '@/types/linkId'
|
||||
|
||||
const boundingRect: ReadOnlyRect = [0, 0, 10, 10]
|
||||
|
||||
type MockCanvasContext = CanvasRenderingContext2D & {
|
||||
arc: ReturnType<typeof vi.fn>
|
||||
beginPath: ReturnType<typeof vi.fn>
|
||||
clip: ReturnType<typeof vi.fn>
|
||||
closePath: ReturnType<typeof vi.fn>
|
||||
fill: ReturnType<typeof vi.fn>
|
||||
fillText: ReturnType<typeof vi.fn>
|
||||
lineTo: ReturnType<typeof vi.fn>
|
||||
moveTo: ReturnType<typeof vi.fn>
|
||||
rect: ReturnType<typeof vi.fn>
|
||||
restore: ReturnType<typeof vi.fn>
|
||||
save: ReturnType<typeof vi.fn>
|
||||
stroke: ReturnType<typeof vi.fn>
|
||||
}
|
||||
|
||||
function createContext(): MockCanvasContext {
|
||||
return {
|
||||
fillStyle: '#initial-fill',
|
||||
strokeStyle: '#initial-stroke',
|
||||
lineWidth: 7,
|
||||
textAlign: 'start',
|
||||
arc: vi.fn(),
|
||||
beginPath: vi.fn(),
|
||||
clip: vi.fn(),
|
||||
closePath: vi.fn(),
|
||||
fill: vi.fn(),
|
||||
fillText: vi.fn(),
|
||||
lineTo: vi.fn(),
|
||||
moveTo: vi.fn(),
|
||||
rect: vi.fn(),
|
||||
restore: vi.fn(),
|
||||
save: vi.fn(),
|
||||
stroke: vi.fn()
|
||||
} as unknown as MockCanvasContext
|
||||
}
|
||||
|
||||
function createColors(): DefaultConnectionColors {
|
||||
return {
|
||||
getConnectedColor: vi.fn((type) => `connected-${type}`),
|
||||
getDisconnectedColor: vi.fn((type) => `disconnected-${type}`)
|
||||
}
|
||||
}
|
||||
|
||||
function createNode(): LGraphNode {
|
||||
return {
|
||||
pos: [100, 200],
|
||||
_collapsed_width: 80
|
||||
} as LGraphNode
|
||||
}
|
||||
|
||||
describe('NodeSlot', () => {
|
||||
beforeEach(() => {
|
||||
vi.stubGlobal(
|
||||
'Path2D',
|
||||
class {
|
||||
arc = vi.fn()
|
||||
}
|
||||
)
|
||||
})
|
||||
|
||||
describe('inputAsSerialisable', () => {
|
||||
it('removes _data from serialized slot', () => {
|
||||
const slot: INodeOutputSlot = {
|
||||
@@ -74,4 +145,328 @@ describe('NodeSlot', () => {
|
||||
expect(serialized.widget).not.toHaveProperty('options')
|
||||
})
|
||||
})
|
||||
|
||||
describe('rendering', () => {
|
||||
it('draws an input label on the right and restores canvas styles', () => {
|
||||
const ctx = createContext()
|
||||
const slot = new NodeInputSlot(
|
||||
{
|
||||
name: 'input',
|
||||
label: 'Input label',
|
||||
type: 'FLOAT',
|
||||
link: null,
|
||||
boundingRect: [110, 210, 10, 10]
|
||||
},
|
||||
createNode()
|
||||
)
|
||||
|
||||
slot.draw(ctx, { colorContext: createColors(), highlight: true })
|
||||
|
||||
expect(ctx.arc).toHaveBeenCalledWith(15, 15, 5, 0, Math.PI * 2)
|
||||
expect(ctx.fillText).toHaveBeenCalledWith('Input label', 25, 20)
|
||||
expect(ctx.fillStyle).toBe('#initial-fill')
|
||||
expect(ctx.strokeStyle).toBe('#initial-stroke')
|
||||
expect(ctx.lineWidth).toBe(7)
|
||||
expect(ctx.textAlign).toBe('start')
|
||||
})
|
||||
|
||||
it('draws output labels on the left and strokes output slots', () => {
|
||||
const ctx = createContext()
|
||||
const slot = new NodeOutputSlot(
|
||||
{
|
||||
name: 'output',
|
||||
localized_name: 'Localized output',
|
||||
type: 'FLOAT',
|
||||
links: [toLinkId(1)],
|
||||
boundingRect: [110, 210, 10, 10]
|
||||
},
|
||||
createNode()
|
||||
)
|
||||
|
||||
slot.draw(ctx, { colorContext: createColors() })
|
||||
|
||||
expect(ctx.stroke).toHaveBeenCalled()
|
||||
expect(ctx.fillText).toHaveBeenCalledWith('Localized output', 5, 20)
|
||||
expect(ctx.textAlign).toBe('start')
|
||||
expect(ctx.strokeStyle).toBe('#initial-stroke')
|
||||
})
|
||||
|
||||
it('draws event, box, arrow, grid, and low-quality slot shapes', () => {
|
||||
const colorContext = createColors()
|
||||
const node = createNode()
|
||||
const eventCtx = createContext()
|
||||
const boxCtx = createContext()
|
||||
const arrowCtx = createContext()
|
||||
const gridCtx = createContext()
|
||||
const lowQualityCtx = createContext()
|
||||
|
||||
new NodeInputSlot(
|
||||
{
|
||||
name: 'event',
|
||||
type: SlotType.Event,
|
||||
link: null,
|
||||
boundingRect: [110, 210, 10, 10]
|
||||
},
|
||||
node
|
||||
).draw(eventCtx, { colorContext })
|
||||
new NodeInputSlot(
|
||||
{
|
||||
name: 'box',
|
||||
type: 'FLOAT',
|
||||
shape: RenderShape.BOX,
|
||||
link: null,
|
||||
boundingRect: [110, 210, 10, 10]
|
||||
},
|
||||
node
|
||||
).draw(boxCtx, { colorContext })
|
||||
new NodeOutputSlot(
|
||||
{
|
||||
name: 'arrow',
|
||||
type: 'FLOAT',
|
||||
shape: RenderShape.ARROW,
|
||||
links: null,
|
||||
boundingRect: [110, 210, 10, 10]
|
||||
},
|
||||
node
|
||||
).draw(arrowCtx, { colorContext })
|
||||
new NodeInputSlot(
|
||||
{
|
||||
name: 'grid',
|
||||
type: SlotType.Array,
|
||||
link: null,
|
||||
boundingRect: [110, 210, 10, 10]
|
||||
},
|
||||
node
|
||||
).draw(gridCtx, { colorContext })
|
||||
new NodeInputSlot(
|
||||
{
|
||||
name: 'low',
|
||||
type: 'FLOAT',
|
||||
link: null,
|
||||
boundingRect: [110, 210, 10, 10]
|
||||
},
|
||||
node
|
||||
).draw(lowQualityCtx, { colorContext, lowQuality: true })
|
||||
|
||||
expect(eventCtx.rect).toHaveBeenCalledWith(9.5, 10.5, 14, 10)
|
||||
expect(boxCtx.rect).toHaveBeenCalledWith(9.5, 10.5, 14, 10)
|
||||
expect(arrowCtx.moveTo).toHaveBeenCalledWith(23, 15.5)
|
||||
expect(gridCtx.rect).toHaveBeenCalledTimes(9)
|
||||
expect(lowQualityCtx.rect).toHaveBeenCalledWith(11, 11, 8, 8)
|
||||
expect(lowQualityCtx.fillText).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('draws hollow and multi-type slots', () => {
|
||||
const colorContext = createColors()
|
||||
const hollowCtx = createContext()
|
||||
const multiCtx = createContext()
|
||||
|
||||
new NodeInputSlot(
|
||||
{
|
||||
name: 'hollow',
|
||||
type: 'FLOAT',
|
||||
shape: RenderShape.HollowCircle,
|
||||
link: null,
|
||||
boundingRect: [110, 210, 10, 10]
|
||||
},
|
||||
createNode()
|
||||
).draw(hollowCtx, { colorContext, highlight: true })
|
||||
new NodeInputSlot(
|
||||
{
|
||||
name: 'multi',
|
||||
type: 'A,B,C,D,E',
|
||||
link: toLinkId(1),
|
||||
boundingRect: [110, 210, 10, 10]
|
||||
},
|
||||
createNode()
|
||||
).draw(multiCtx, { colorContext })
|
||||
|
||||
expect(hollowCtx.clip).toHaveBeenCalledWith(expect.any(Object), 'evenodd')
|
||||
expect(
|
||||
vi
|
||||
.mocked(colorContext.getConnectedColor)
|
||||
.mock.calls.some(([type]) => type === 'A')
|
||||
).toBe(true)
|
||||
expect(multiCtx.fill.mock.calls.length).toBeGreaterThan(1)
|
||||
expect(multiCtx.stroke).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('hides widget input labels and draws error rings', () => {
|
||||
const ctx = createContext()
|
||||
const slot = new NodeInputSlot(
|
||||
{
|
||||
name: 'widget-input',
|
||||
label: 'Hidden label',
|
||||
type: 'FLOAT',
|
||||
link: null,
|
||||
widget: { name: 'widget' },
|
||||
hasErrors: true,
|
||||
boundingRect: [110, 210, 10, 10]
|
||||
},
|
||||
createNode()
|
||||
)
|
||||
|
||||
slot.draw(ctx, { colorContext: createColors() })
|
||||
|
||||
expect(ctx.fillText).not.toHaveBeenCalled()
|
||||
expect(ctx.arc).toHaveBeenCalledWith(15, 15, 12, 0, Math.PI * 2)
|
||||
expect(ctx.stroke).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('places directional labels above vertical slots', () => {
|
||||
const rightCtx = createContext()
|
||||
const leftCtx = createContext()
|
||||
const node = createNode()
|
||||
const input = new NodeInputSlot(
|
||||
{
|
||||
name: 'up',
|
||||
type: 'FLOAT',
|
||||
link: null,
|
||||
dir: LinkDirection.UP,
|
||||
boundingRect: [110, 210, 10, 10]
|
||||
},
|
||||
node
|
||||
)
|
||||
const output = new NodeOutputSlot(
|
||||
{
|
||||
name: 'down',
|
||||
type: 'FLOAT',
|
||||
links: null,
|
||||
dir: LinkDirection.DOWN,
|
||||
boundingRect: [110, 210, 10, 10]
|
||||
},
|
||||
node
|
||||
)
|
||||
|
||||
input.draw(rightCtx, { colorContext: createColors() })
|
||||
output.draw(leftCtx, { colorContext: createColors() })
|
||||
|
||||
expect(rightCtx.fillText).toHaveBeenCalledWith('up', 15, 5)
|
||||
expect(leftCtx.fillText).toHaveBeenCalledWith('down', 15, 7)
|
||||
})
|
||||
})
|
||||
|
||||
describe('collapsed rendering', () => {
|
||||
it('draws collapsed input and output arrows in their own directions', () => {
|
||||
const inputCtx = createContext()
|
||||
const outputCtx = createContext()
|
||||
|
||||
new NodeInputSlot(
|
||||
{
|
||||
name: 'input',
|
||||
type: 'FLOAT',
|
||||
shape: RenderShape.ARROW,
|
||||
link: null,
|
||||
boundingRect
|
||||
},
|
||||
createNode()
|
||||
).drawCollapsed(inputCtx)
|
||||
new NodeOutputSlot(
|
||||
{
|
||||
name: 'output',
|
||||
type: 'FLOAT',
|
||||
shape: RenderShape.ARROW,
|
||||
links: null,
|
||||
boundingRect
|
||||
},
|
||||
createNode()
|
||||
).drawCollapsed(outputCtx)
|
||||
|
||||
expect(inputCtx.moveTo).toHaveBeenCalledWith(8, -15)
|
||||
expect(inputCtx.lineTo).toHaveBeenCalledWith(-4, -19)
|
||||
expect(outputCtx.moveTo).toHaveBeenCalledWith(86, -15)
|
||||
expect(outputCtx.lineTo).toHaveBeenCalledWith(74, -19)
|
||||
})
|
||||
|
||||
it('draws collapsed event and circle slots', () => {
|
||||
const eventCtx = createContext()
|
||||
const circleCtx = createContext()
|
||||
|
||||
new NodeInputSlot(
|
||||
{
|
||||
name: 'event',
|
||||
type: SlotType.Event,
|
||||
link: null,
|
||||
boundingRect
|
||||
},
|
||||
createNode()
|
||||
).drawCollapsed(eventCtx)
|
||||
new NodeInputSlot(
|
||||
{
|
||||
name: 'circle',
|
||||
type: 'FLOAT',
|
||||
link: null,
|
||||
boundingRect
|
||||
},
|
||||
createNode()
|
||||
).drawCollapsed(circleCtx)
|
||||
|
||||
expect(eventCtx.rect).toHaveBeenCalledWith(-6.5, -19, 14, 8)
|
||||
expect(circleCtx.arc).toHaveBeenCalledWith(0, -15, 4, 0, Math.PI * 2)
|
||||
expect(circleCtx.fillStyle).toBe('#initial-fill')
|
||||
})
|
||||
})
|
||||
|
||||
describe('serialization and validation', () => {
|
||||
it('serializes slot fields without the node reference', () => {
|
||||
const slot = new NodeOutputSlot(
|
||||
{
|
||||
name: 'out',
|
||||
type: 'FLOAT',
|
||||
label: 'Output',
|
||||
color_on: '#fff',
|
||||
color_off: '#000',
|
||||
shape: RenderShape.BOX,
|
||||
dir: LinkDirection.RIGHT,
|
||||
localized_name: 'Localized',
|
||||
pos: [1, 2],
|
||||
links: [toLinkId(3)],
|
||||
slot_index: 4,
|
||||
boundingRect: [1, 2, 3, 4]
|
||||
},
|
||||
createNode()
|
||||
)
|
||||
|
||||
expect(slot.toJSON()).toEqual({
|
||||
name: 'out',
|
||||
type: 'FLOAT',
|
||||
label: 'Output',
|
||||
color_on: '#fff',
|
||||
color_off: '#000',
|
||||
shape: RenderShape.BOX,
|
||||
dir: LinkDirection.RIGHT,
|
||||
localized_name: 'Localized',
|
||||
pos: [1, 2],
|
||||
boundingRect: [1, 2, 3, 4],
|
||||
links: [toLinkId(3)],
|
||||
slot_index: 4
|
||||
})
|
||||
})
|
||||
|
||||
it('validates input and output targets by slot direction', () => {
|
||||
const input = new NodeInputSlot(
|
||||
{
|
||||
name: 'input',
|
||||
type: 'FLOAT',
|
||||
link: null,
|
||||
boundingRect
|
||||
},
|
||||
createNode()
|
||||
)
|
||||
const output = new NodeOutputSlot(
|
||||
{
|
||||
name: 'output',
|
||||
type: 'FLOAT',
|
||||
links: null,
|
||||
boundingRect
|
||||
},
|
||||
createNode()
|
||||
)
|
||||
|
||||
expect(input.isValidTarget(output)).toBe(true)
|
||||
expect(output.isValidTarget(input)).toBe(true)
|
||||
expect(input.isValidTarget(input)).toBe(false)
|
||||
expect(output.isValidTarget(output)).toBe(false)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -6,6 +6,7 @@ import {
|
||||
ExecutableNodeDTO,
|
||||
LGraph,
|
||||
LGraphEventMode,
|
||||
LLink,
|
||||
LGraphNode
|
||||
} from '@/lib/litegraph/src/litegraph'
|
||||
import { toLinkId } from '@/types/linkId'
|
||||
@@ -24,6 +25,14 @@ beforeEach(() => {
|
||||
})
|
||||
|
||||
describe('ExecutableNodeDTO Creation', () => {
|
||||
it('should throw when the node has no graph', () => {
|
||||
const node = new LGraphNode('Detached')
|
||||
|
||||
expect(() => new ExecutableNodeDTO(node, [], new Map(), undefined)).toThrow(
|
||||
'Attempted to access LGraph reference that was null or undefined.'
|
||||
)
|
||||
})
|
||||
|
||||
it('should create DTO from regular node', () => {
|
||||
const graph = new LGraph()
|
||||
const node = new LGraphNode('Test Node')
|
||||
@@ -207,6 +216,74 @@ describe('ExecutableNodeDTO Input Resolution', () => {
|
||||
const resolved = dto.resolveInput(0)
|
||||
expect(resolved).toBeUndefined()
|
||||
})
|
||||
|
||||
it('should throw when resolving a repeated input path', () => {
|
||||
const graph = new LGraph()
|
||||
const node = new LGraphNode('Looped')
|
||||
node.id = toNodeId(8)
|
||||
node.title = 'Loop title'
|
||||
node.addInput('in', 'IMAGE')
|
||||
graph.add(node)
|
||||
const dto = new ExecutableNodeDTO(node, ['parent'], new Map(), undefined)
|
||||
|
||||
expect(() =>
|
||||
dto.resolveInput(0, new Set([`undefined:${node.id}[I]0`]))
|
||||
).toThrow('Circular reference detected while resolving input 0')
|
||||
})
|
||||
|
||||
it('should report repeated root inputs without title or path details', () => {
|
||||
const graph = new LGraph()
|
||||
const node = new LGraphNode('')
|
||||
node.id = toNodeId(8)
|
||||
node.title = ''
|
||||
node.addInput('in', 'IMAGE')
|
||||
graph.add(node)
|
||||
const dto = new ExecutableNodeDTO(node, [], new Map(), undefined)
|
||||
|
||||
expect(() =>
|
||||
dto.resolveInput(0, new Set([`undefined:${node.id}[I]0`]))
|
||||
).toThrow('Circular reference detected while resolving input 0 of node 8')
|
||||
})
|
||||
|
||||
it('should throw when an input points at a missing link', () => {
|
||||
const graph = new LGraph()
|
||||
const node = new LGraphNode('Target')
|
||||
node.addInput('in', 'IMAGE')
|
||||
node.inputs[0].link = toLinkId(99)
|
||||
graph.add(node)
|
||||
const dto = new ExecutableNodeDTO(node, [], new Map(), undefined)
|
||||
|
||||
expect(() => dto.resolveInput(0)).toThrow('No link found in parent graph')
|
||||
})
|
||||
|
||||
it('should throw when an input link points at a missing source node', () => {
|
||||
const graph = new LGraph()
|
||||
const node = new LGraphNode('Target')
|
||||
node.id = toNodeId(2)
|
||||
node.addInput('in', 'IMAGE')
|
||||
graph.add(node)
|
||||
const link = new LLink(toLinkId(1), 'IMAGE', '404', 0, '2', 0)
|
||||
graph.links.set(link.id, link)
|
||||
node.inputs[0].link = link.id
|
||||
const dto = new ExecutableNodeDTO(node, [], new Map(), undefined)
|
||||
|
||||
expect(() => dto.resolveInput(0)).toThrow('No input node found')
|
||||
})
|
||||
|
||||
it('should throw when an input source has no DTO', () => {
|
||||
const graph = new LGraph()
|
||||
const source = new LGraphNode('Source')
|
||||
source.addOutput('out', 'IMAGE')
|
||||
graph.add(source)
|
||||
const target = new LGraphNode('Target')
|
||||
target.addInput('in', 'IMAGE')
|
||||
graph.add(target)
|
||||
source.connect(0, target, 0)
|
||||
|
||||
const dto = new ExecutableNodeDTO(target, [], new Map(), undefined)
|
||||
|
||||
expect(() => dto.resolveInput(0)).toThrow('No output node DTO found')
|
||||
})
|
||||
})
|
||||
|
||||
describe('ExecutableNodeDTO Output Resolution', () => {
|
||||
@@ -257,6 +334,34 @@ describe('ExecutableNodeDTO Output Resolution', () => {
|
||||
expect(resolved?.node).toBe(dto)
|
||||
expect(resolved?.origin_slot).toBe(0)
|
||||
})
|
||||
|
||||
it('should throw when resolving a repeated output path', () => {
|
||||
const graph = new LGraph()
|
||||
const node = new LGraphNode('Looped')
|
||||
node.id = toNodeId(9)
|
||||
node.title = 'Loop title'
|
||||
node.addOutput('out', 'IMAGE')
|
||||
graph.add(node)
|
||||
const dto = new ExecutableNodeDTO(node, ['parent'], new Map(), undefined)
|
||||
|
||||
expect(() =>
|
||||
dto.resolveOutput(0, 'IMAGE', new Set([`undefined:${node.id}[O]0`]))
|
||||
).toThrow('Circular reference detected while resolving output 0')
|
||||
})
|
||||
|
||||
it('should report repeated root outputs without title or path details', () => {
|
||||
const graph = new LGraph()
|
||||
const node = new LGraphNode('')
|
||||
node.id = toNodeId(9)
|
||||
node.title = ''
|
||||
node.addOutput('out', 'IMAGE')
|
||||
graph.add(node)
|
||||
const dto = new ExecutableNodeDTO(node, [], new Map(), undefined)
|
||||
|
||||
expect(() =>
|
||||
dto.resolveOutput(0, 'IMAGE', new Set([`undefined:${node.id}[O]0`]))
|
||||
).toThrow('Circular reference detected while resolving output 0 of node 9')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Muted node output resolution', () => {
|
||||
@@ -368,6 +473,135 @@ describe('Bypass node output resolution', () => {
|
||||
expect(resolved).toBeDefined()
|
||||
expect(resolved?.node).toBe(upstreamDto)
|
||||
})
|
||||
|
||||
it('should use the first input when bypassing an any-type output', () => {
|
||||
const graph = new LGraph()
|
||||
|
||||
const upstreamNode = new LGraphNode('Upstream')
|
||||
upstreamNode.addOutput('out', 'IMAGE')
|
||||
graph.add(upstreamNode)
|
||||
|
||||
const bypassedNode = new LGraphNode('Bypassed')
|
||||
bypassedNode.addInput('fallback', 'IMAGE')
|
||||
bypassedNode.addOutput('first', 'IMAGE')
|
||||
bypassedNode.addOutput('second', 'IMAGE')
|
||||
bypassedNode.mode = LGraphEventMode.BYPASS
|
||||
graph.add(bypassedNode)
|
||||
|
||||
upstreamNode.connect(0, bypassedNode, 0)
|
||||
|
||||
const nodeDtoMap = new Map()
|
||||
const upstreamDto = new ExecutableNodeDTO(
|
||||
upstreamNode,
|
||||
[],
|
||||
nodeDtoMap,
|
||||
undefined
|
||||
)
|
||||
nodeDtoMap.set(upstreamDto.id, upstreamDto)
|
||||
|
||||
const bypassedDto = new ExecutableNodeDTO(
|
||||
bypassedNode,
|
||||
[],
|
||||
nodeDtoMap,
|
||||
undefined
|
||||
)
|
||||
nodeDtoMap.set(bypassedDto.id, bypassedDto)
|
||||
|
||||
const resolved = bypassedDto.resolveOutput(1, '*', new Set())
|
||||
expect(resolved?.node).toBe(upstreamDto)
|
||||
})
|
||||
|
||||
it('should use the same slot when bypassing an empty-type output', () => {
|
||||
const graph = new LGraph()
|
||||
|
||||
const upstreamNode = new LGraphNode('Upstream')
|
||||
upstreamNode.addOutput('out', 'IMAGE')
|
||||
graph.add(upstreamNode)
|
||||
|
||||
const bypassedNode = new LGraphNode('Bypassed')
|
||||
bypassedNode.addInput('image', 'IMAGE')
|
||||
bypassedNode.addOutput('out', 'IMAGE')
|
||||
bypassedNode.mode = LGraphEventMode.BYPASS
|
||||
graph.add(bypassedNode)
|
||||
|
||||
upstreamNode.connect(0, bypassedNode, 0)
|
||||
|
||||
const nodeDtoMap = new Map()
|
||||
const upstreamDto = new ExecutableNodeDTO(
|
||||
upstreamNode,
|
||||
[],
|
||||
nodeDtoMap,
|
||||
undefined
|
||||
)
|
||||
nodeDtoMap.set(upstreamDto.id, upstreamDto)
|
||||
|
||||
const bypassedDto = new ExecutableNodeDTO(
|
||||
bypassedNode,
|
||||
[],
|
||||
nodeDtoMap,
|
||||
undefined
|
||||
)
|
||||
nodeDtoMap.set(bypassedDto.id, bypassedDto)
|
||||
|
||||
const resolved = bypassedDto.resolveOutput(0, '', new Set())
|
||||
expect(resolved?.node).toBe(upstreamDto)
|
||||
})
|
||||
|
||||
it('should use an exact matching input when bypassing different slot types', () => {
|
||||
const graph = new LGraph()
|
||||
|
||||
const upstreamNode = new LGraphNode('Upstream')
|
||||
upstreamNode.addOutput('out', 'IMAGE')
|
||||
graph.add(upstreamNode)
|
||||
|
||||
const bypassedNode = new LGraphNode('Bypassed')
|
||||
bypassedNode.addInput('string', 'STRING')
|
||||
bypassedNode.addInput('image', 'IMAGE')
|
||||
bypassedNode.addOutput('latent', 'LATENT')
|
||||
bypassedNode.mode = LGraphEventMode.BYPASS
|
||||
graph.add(bypassedNode)
|
||||
|
||||
upstreamNode.connect(0, bypassedNode, 1)
|
||||
|
||||
const nodeDtoMap = new Map()
|
||||
const upstreamDto = new ExecutableNodeDTO(
|
||||
upstreamNode,
|
||||
[],
|
||||
nodeDtoMap,
|
||||
undefined
|
||||
)
|
||||
nodeDtoMap.set(upstreamDto.id, upstreamDto)
|
||||
|
||||
const bypassedDto = new ExecutableNodeDTO(
|
||||
bypassedNode,
|
||||
[],
|
||||
nodeDtoMap,
|
||||
undefined
|
||||
)
|
||||
nodeDtoMap.set(bypassedDto.id, bypassedDto)
|
||||
|
||||
const resolved = bypassedDto.resolveOutput(0, 'IMAGE', new Set())
|
||||
expect(resolved?.node).toBe(upstreamDto)
|
||||
})
|
||||
|
||||
it('should return undefined when no bypass input matches', () => {
|
||||
const graph = new LGraph()
|
||||
const bypassedNode = new LGraphNode('Bypassed')
|
||||
bypassedNode.addInput('string', 'STRING')
|
||||
bypassedNode.addOutput('out', 'LATENT')
|
||||
bypassedNode.mode = LGraphEventMode.BYPASS
|
||||
graph.add(bypassedNode)
|
||||
const dto = new ExecutableNodeDTO(bypassedNode, [], new Map(), undefined)
|
||||
vi.spyOn(console, 'warn').mockImplementation(() => {})
|
||||
|
||||
const resolved = dto.resolveOutput(0, 'IMAGE', new Set())
|
||||
|
||||
expect(resolved).toBeUndefined()
|
||||
expect(console.warn).toHaveBeenCalledWith(
|
||||
expect.stringContaining('No input types match'),
|
||||
dto
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('ALWAYS mode node output resolution', () => {
|
||||
@@ -483,6 +717,94 @@ describe('Virtual node resolveVirtualOutput', () => {
|
||||
expect(resolved).toBeUndefined()
|
||||
expect(spy).toHaveBeenCalledWith(0)
|
||||
})
|
||||
|
||||
it('should resolve through a virtual input link', () => {
|
||||
const graph = new LGraph()
|
||||
|
||||
const sourceNode = new LGraphNode('Source')
|
||||
sourceNode.addOutput('out', 'IMAGE')
|
||||
graph.add(sourceNode)
|
||||
|
||||
const passthroughNode = new LGraphNode('Passthrough')
|
||||
passthroughNode.addInput('in', 'IMAGE')
|
||||
graph.add(passthroughNode)
|
||||
sourceNode.connect(0, passthroughNode, 0)
|
||||
|
||||
const virtualNode = new LGraphNode('Virtual Get')
|
||||
virtualNode.addOutput('out', 'IMAGE')
|
||||
virtualNode.isVirtualNode = true
|
||||
virtualNode.resolveVirtualOutput = () => undefined
|
||||
graph.add(virtualNode)
|
||||
vi.spyOn(virtualNode, 'getInputLink').mockReturnValue({
|
||||
target_slot: 0,
|
||||
resolve: () => ({ inputNode: passthroughNode })
|
||||
} as unknown as LLink)
|
||||
|
||||
const nodeDtoMap = new Map()
|
||||
const sourceDto = new ExecutableNodeDTO(
|
||||
sourceNode,
|
||||
[],
|
||||
nodeDtoMap,
|
||||
undefined
|
||||
)
|
||||
nodeDtoMap.set(sourceDto.id, sourceDto)
|
||||
const passthroughDto = new ExecutableNodeDTO(
|
||||
passthroughNode,
|
||||
[],
|
||||
nodeDtoMap,
|
||||
undefined
|
||||
)
|
||||
nodeDtoMap.set(passthroughDto.id, passthroughDto)
|
||||
const virtualDto = new ExecutableNodeDTO(
|
||||
virtualNode,
|
||||
[],
|
||||
nodeDtoMap,
|
||||
undefined
|
||||
)
|
||||
|
||||
const resolved = virtualDto.resolveOutput(0, 'IMAGE', new Set())
|
||||
expect(resolved?.node).toBe(sourceDto)
|
||||
})
|
||||
|
||||
it('should throw when a virtual input link has no parent node', () => {
|
||||
const graph = new LGraph()
|
||||
const virtualNode = new LGraphNode('Virtual Get')
|
||||
virtualNode.addOutput('out', 'IMAGE')
|
||||
virtualNode.isVirtualNode = true
|
||||
virtualNode.resolveVirtualOutput = () => undefined
|
||||
graph.add(virtualNode)
|
||||
vi.spyOn(virtualNode, 'getInputLink').mockReturnValue({
|
||||
target_slot: 0,
|
||||
resolve: () => ({ inputNode: undefined })
|
||||
} as unknown as LLink)
|
||||
|
||||
const dto = new ExecutableNodeDTO(virtualNode, [], new Map(), undefined)
|
||||
|
||||
expect(() => dto.resolveOutput(0, 'IMAGE', new Set())).toThrow(
|
||||
'Virtual node failed to resolve parent'
|
||||
)
|
||||
})
|
||||
|
||||
it('should throw when a virtual input link parent has no DTO', () => {
|
||||
const graph = new LGraph()
|
||||
const sourceNode = new LGraphNode('Source')
|
||||
graph.add(sourceNode)
|
||||
const virtualNode = new LGraphNode('Virtual Get')
|
||||
virtualNode.addOutput('out', 'IMAGE')
|
||||
virtualNode.isVirtualNode = true
|
||||
virtualNode.resolveVirtualOutput = () => undefined
|
||||
graph.add(virtualNode)
|
||||
vi.spyOn(virtualNode, 'getInputLink').mockReturnValue({
|
||||
target_slot: 0,
|
||||
resolve: () => ({ inputNode: sourceNode })
|
||||
} as unknown as LLink)
|
||||
|
||||
const dto = new ExecutableNodeDTO(virtualNode, [], new Map(), undefined)
|
||||
|
||||
expect(() => dto.resolveOutput(0, 'IMAGE', new Set())).toThrow(
|
||||
'No input node DTO found'
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('ExecutableNodeDTO Properties', () => {
|
||||
@@ -588,6 +910,23 @@ describe('ExecutableNodeDTO Memory Efficiency', () => {
|
||||
})
|
||||
|
||||
describe('ExecutableNodeDTO Integration', () => {
|
||||
it('should delegate getInnerNodes for subgraph nodes', () => {
|
||||
const subgraph = createTestSubgraph({ nodeCount: 2 })
|
||||
const subgraphNode = createTestSubgraphNode(subgraph)
|
||||
const executableNodes = new Map()
|
||||
const dto = new ExecutableNodeDTO(
|
||||
subgraphNode,
|
||||
[],
|
||||
executableNodes,
|
||||
undefined
|
||||
)
|
||||
|
||||
const innerNodes = dto.getInnerNodes()
|
||||
|
||||
expect(innerNodes).toHaveLength(2)
|
||||
expect(innerNodes[0]).toBeInstanceOf(ExecutableNodeDTO)
|
||||
})
|
||||
|
||||
it('should work with SubgraphNode flattening', () => {
|
||||
const subgraph = createTestSubgraph({ nodeCount: 3 })
|
||||
const subgraphNode = createTestSubgraphNode(subgraph)
|
||||
@@ -660,6 +999,65 @@ describe('ExecutableNodeDTO Integration', () => {
|
||||
expect(Number(dto.node.id)).toBe(55) // Original node ID preserved
|
||||
expect(Number(dto.subgraphNode?.id)).toBe(99) // Subgraph context
|
||||
})
|
||||
|
||||
it('should throw when a subgraph output slot is missing', () => {
|
||||
const subgraph = createTestSubgraph()
|
||||
const subgraphNode = createTestSubgraphNode(subgraph)
|
||||
const dto = new ExecutableNodeDTO(subgraphNode, [], new Map(), undefined)
|
||||
|
||||
expect(() => dto.resolveOutput(0, 'IMAGE', new Set())).toThrow(
|
||||
'No output found for flattened id'
|
||||
)
|
||||
})
|
||||
|
||||
it('should return undefined when a subgraph output has no inner link', () => {
|
||||
const subgraph = createTestSubgraph({
|
||||
outputs: [{ name: 'out', type: 'IMAGE' }]
|
||||
})
|
||||
const subgraphNode = createTestSubgraphNode(subgraph)
|
||||
vi.spyOn(subgraphNode, 'resolveSubgraphOutputLink').mockReturnValue(
|
||||
undefined
|
||||
)
|
||||
const dto = new ExecutableNodeDTO(subgraphNode, [], new Map(), undefined)
|
||||
|
||||
const resolved = dto.resolveOutput(0, 'IMAGE', new Set())
|
||||
|
||||
expect(resolved).toBeUndefined()
|
||||
})
|
||||
|
||||
it('should throw when a subgraph output link has no inner node', () => {
|
||||
const subgraph = createTestSubgraph({
|
||||
outputs: [{ name: 'out', type: 'IMAGE' }]
|
||||
})
|
||||
const subgraphNode = createTestSubgraphNode(subgraph)
|
||||
vi.spyOn(subgraphNode, 'resolveSubgraphOutputLink').mockReturnValue({
|
||||
outputNode: undefined,
|
||||
link: new LLink(toLinkId(1), 'IMAGE', '1', 0, '2', 0)
|
||||
} as never)
|
||||
const dto = new ExecutableNodeDTO(subgraphNode, [], new Map(), undefined)
|
||||
|
||||
expect(() => dto.resolveOutput(0, 'IMAGE', new Set())).toThrow(
|
||||
'No output node found'
|
||||
)
|
||||
})
|
||||
|
||||
it('should throw when a subgraph output inner node has no DTO', () => {
|
||||
const subgraph = createTestSubgraph({
|
||||
outputs: [{ name: 'out', type: 'IMAGE' }],
|
||||
nodeCount: 1
|
||||
})
|
||||
const subgraphNode = createTestSubgraphNode(subgraph)
|
||||
const innerNode = subgraph.nodes[0]
|
||||
vi.spyOn(subgraphNode, 'resolveSubgraphOutputLink').mockReturnValue({
|
||||
outputNode: innerNode,
|
||||
link: new LLink(toLinkId(1), 'IMAGE', String(innerNode.id), 0, '2', 0)
|
||||
} as never)
|
||||
const dto = new ExecutableNodeDTO(subgraphNode, [], new Map(), undefined)
|
||||
|
||||
expect(() => dto.resolveOutput(0, 'IMAGE', new Set())).toThrow(
|
||||
'No inner node DTO found'
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('ExecutableNodeDTO Scale Testing', () => {
|
||||
|
||||
277
src/lib/litegraph/src/subgraph/SubgraphIONodeBase.test.ts
Normal file
277
src/lib/litegraph/src/subgraph/SubgraphIONodeBase.test.ts
Normal file
@@ -0,0 +1,277 @@
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { Rectangle } from '@/lib/litegraph/src/infrastructure/Rectangle'
|
||||
import type { DefaultConnectionColors } from '@/lib/litegraph/src/interfaces'
|
||||
import { LiteGraph } from '@/lib/litegraph/src/litegraph'
|
||||
import type { CanvasPointerEvent } from '@/lib/litegraph/src/litegraph'
|
||||
import { CanvasItem } from '@/lib/litegraph/src/types/globalEnums'
|
||||
import type { Subgraph } from '@/lib/litegraph/src/subgraph/Subgraph'
|
||||
import type { SubgraphInput } from '@/lib/litegraph/src/subgraph/SubgraphInput'
|
||||
import { SubgraphIONodeBase } from '@/lib/litegraph/src/subgraph/SubgraphIONodeBase'
|
||||
import type { NodeId } from '@/types/nodeId'
|
||||
|
||||
type MenuConfig = {
|
||||
title?: string
|
||||
callback?: (item: { content: string; value: string }) => void
|
||||
}
|
||||
|
||||
const { contextMenus, MockContextMenu } = vi.hoisted(() => {
|
||||
const contextMenus: Array<{
|
||||
options: unknown[]
|
||||
config: MenuConfig
|
||||
}> = []
|
||||
|
||||
class MockContextMenu {
|
||||
constructor(options: unknown[], config: MenuConfig) {
|
||||
contextMenus.push({ options, config })
|
||||
}
|
||||
}
|
||||
|
||||
return { contextMenus, MockContextMenu }
|
||||
})
|
||||
|
||||
type TestSlot = SubgraphInput & {
|
||||
arrange: ReturnType<typeof vi.fn>
|
||||
disconnect: ReturnType<typeof vi.fn>
|
||||
draw: ReturnType<typeof vi.fn>
|
||||
measure: ReturnType<typeof vi.fn>
|
||||
onPointerMove: ReturnType<typeof vi.fn>
|
||||
}
|
||||
|
||||
class TestIONode extends SubgraphIONodeBase<SubgraphInput> {
|
||||
readonly id = 'subgraph-io' as NodeId
|
||||
readonly emptySlot: SubgraphInput
|
||||
readonly slots: SubgraphInput[]
|
||||
readonly renameSlot = vi.fn()
|
||||
readonly removeSlot = vi.fn()
|
||||
|
||||
constructor(
|
||||
subgraph: Subgraph,
|
||||
slots: SubgraphInput[],
|
||||
emptySlot: SubgraphInput
|
||||
) {
|
||||
super(subgraph)
|
||||
this.slots = slots
|
||||
this.emptySlot = emptySlot
|
||||
}
|
||||
|
||||
get allSlots(): SubgraphInput[] {
|
||||
return [...this.slots, this.emptySlot]
|
||||
}
|
||||
|
||||
get slotAnchorX(): number {
|
||||
return this.pos[0] + this.size[0] - SubgraphIONodeBase.roundedRadius
|
||||
}
|
||||
|
||||
onPointerDown(): void {}
|
||||
|
||||
openMenu(slot: SubgraphInput, event: CanvasPointerEvent): void {
|
||||
this.showSlotContextMenu(slot, event)
|
||||
}
|
||||
|
||||
renameByDoubleClick(slot: SubgraphInput, event: CanvasPointerEvent): void {
|
||||
this.handleSlotDoubleClick(slot, event)
|
||||
}
|
||||
|
||||
drawProtected(
|
||||
ctx: CanvasRenderingContext2D,
|
||||
colorContext: DefaultConnectionColors,
|
||||
fromSlot?: SubgraphInput,
|
||||
editorAlpha?: number
|
||||
): void {
|
||||
ctx.lineWidth = 99
|
||||
ctx.strokeStyle = 'red'
|
||||
ctx.fillStyle = 'blue'
|
||||
ctx.font = '20px serif'
|
||||
ctx.textBaseline = 'top'
|
||||
this.drawSlots(ctx, colorContext, fromSlot, editorAlpha)
|
||||
}
|
||||
}
|
||||
|
||||
function createSlot(
|
||||
name: string,
|
||||
rect: [number, number, number, number],
|
||||
links: number[] = []
|
||||
): TestSlot {
|
||||
const slot = {
|
||||
name,
|
||||
displayName: `${name} label`,
|
||||
linkIds: links,
|
||||
boundingRect: new Rectangle(...rect),
|
||||
isPointerOver: false,
|
||||
measure: vi.fn(() => [rect[2], rect[3]]),
|
||||
arrange: vi.fn((nextRect: [number, number, number, number]) => {
|
||||
slot.boundingRect.set(nextRect)
|
||||
}),
|
||||
onPointerMove: vi.fn((event: CanvasPointerEvent) => {
|
||||
slot.isPointerOver = slot.boundingRect.containsXy(
|
||||
event.canvasX,
|
||||
event.canvasY
|
||||
)
|
||||
}),
|
||||
disconnect: vi.fn(),
|
||||
draw: vi.fn()
|
||||
}
|
||||
return slot as unknown as TestSlot
|
||||
}
|
||||
|
||||
function createSubgraph() {
|
||||
const prompt = vi.fn(
|
||||
(_title: string, _value: string, callback: (value: string) => void) =>
|
||||
callback('renamed')
|
||||
)
|
||||
return {
|
||||
prompt,
|
||||
subgraph: {
|
||||
setDirtyCanvas: vi.fn(),
|
||||
canvasAction: vi.fn(
|
||||
(callback: (canvas: { prompt: typeof prompt }) => void) =>
|
||||
callback({ prompt })
|
||||
)
|
||||
} as unknown as Subgraph
|
||||
}
|
||||
}
|
||||
|
||||
function createNode() {
|
||||
const filled = createSlot('value', [20, 30, 80, 20], [1])
|
||||
const empty = createSlot('', [20, 60, 80, 20])
|
||||
const { subgraph, prompt } = createSubgraph()
|
||||
const node = new TestIONode(subgraph, [filled], empty)
|
||||
node.configure({
|
||||
id: 'subgraph-io',
|
||||
bounding: [10, 20, 100, 80],
|
||||
pinned: false
|
||||
})
|
||||
return { node, filled, empty, subgraph, prompt }
|
||||
}
|
||||
|
||||
function eventAt(x: number, y: number): CanvasPointerEvent {
|
||||
return { canvasX: x, canvasY: y } as CanvasPointerEvent
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
contextMenus.length = 0
|
||||
Object.assign(LiteGraph, { ContextMenu: MockContextMenu })
|
||||
})
|
||||
|
||||
describe('SubgraphIONodeBase', () => {
|
||||
it('moves, snaps, hit-tests, and serializes node bounds', () => {
|
||||
const { node } = createNode()
|
||||
|
||||
node.move(5, -10)
|
||||
|
||||
expect(Array.from(node.pos)).toEqual([15, 10])
|
||||
expect(node.containsPoint([20, 20])).toBe(true)
|
||||
expect(node.asSerialisable()).toEqual({
|
||||
id: 'subgraph-io',
|
||||
bounding: [15, 10, 100, 80],
|
||||
pinned: undefined
|
||||
})
|
||||
|
||||
node.pinned = true
|
||||
expect(node.snapToGrid(10)).toBe(false)
|
||||
expect(node.asSerialisable().pinned).toBe(true)
|
||||
})
|
||||
|
||||
it('tracks pointer entry, slot hover, and pointer leave', () => {
|
||||
const { node, filled } = createNode()
|
||||
|
||||
const overResult = node.onPointerMove(eventAt(25, 35))
|
||||
|
||||
expect(overResult & CanvasItem.SubgraphIoNode).toBeTruthy()
|
||||
expect(overResult & CanvasItem.SubgraphIoSlot).toBeTruthy()
|
||||
expect(node.isPointerOver).toBe(true)
|
||||
expect(filled.isPointerOver).toBe(true)
|
||||
|
||||
const outResult = node.onPointerMove(eventAt(500, 500))
|
||||
|
||||
expect(outResult).toBe(CanvasItem.Nothing)
|
||||
expect(node.isPointerOver).toBe(false)
|
||||
expect(filled.isPointerOver).toBe(false)
|
||||
})
|
||||
|
||||
it('finds slots, arranges them, and restores drawing context state', () => {
|
||||
const { node, filled } = createNode()
|
||||
const ctx = {
|
||||
lineWidth: 1,
|
||||
strokeStyle: 'black',
|
||||
fillStyle: 'white',
|
||||
font: '12px sans-serif',
|
||||
textBaseline: 'middle'
|
||||
} as CanvasRenderingContext2D
|
||||
|
||||
node.arrange()
|
||||
node.draw(ctx, {} as DefaultConnectionColors, filled)
|
||||
|
||||
expect(node.getSlotInPosition(100, 40)).toBe(filled)
|
||||
expect(node.getSlotInPosition(500, 500)).toBeUndefined()
|
||||
expect(filled.arrange).toHaveBeenCalled()
|
||||
expect(node.size[0]).toBeGreaterThanOrEqual(108)
|
||||
expect(ctx.lineWidth).toBe(1)
|
||||
expect(ctx.strokeStyle).toBe('black')
|
||||
expect(ctx.fillStyle).toBe('white')
|
||||
expect(ctx.font).toBe('12px sans-serif')
|
||||
expect(ctx.textBaseline).toBe('middle')
|
||||
expect(filled.draw).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ ctx, fromSlot: filled })
|
||||
)
|
||||
})
|
||||
|
||||
it('prompts for non-empty slot rename on double click', () => {
|
||||
const { node, filled, empty, prompt } = createNode()
|
||||
|
||||
node.renameByDoubleClick(empty, eventAt(0, 0))
|
||||
expect(prompt).not.toHaveBeenCalled()
|
||||
|
||||
node.renameByDoubleClick(filled, eventAt(20, 30))
|
||||
|
||||
expect(prompt).toHaveBeenCalledWith(
|
||||
'Slot name',
|
||||
'value label',
|
||||
expect.any(Function),
|
||||
expect.any(Object)
|
||||
)
|
||||
expect(node.renameSlot).toHaveBeenCalledWith(filled, 'renamed')
|
||||
})
|
||||
|
||||
it('opens slot context menu actions for connected non-empty slots', () => {
|
||||
const { node, filled, subgraph } = createNode()
|
||||
|
||||
node.openMenu(filled, eventAt(20, 30))
|
||||
|
||||
expect(contextMenus).toHaveLength(1)
|
||||
expect(contextMenus[0].config.title).toBe('value')
|
||||
expect(contextMenus[0].options).toMatchObject([
|
||||
{ value: 'disconnect' },
|
||||
{ value: 'rename' },
|
||||
null,
|
||||
{ value: 'remove', className: 'danger' }
|
||||
])
|
||||
|
||||
contextMenus[0].config.callback?.({
|
||||
content: 'Disconnect Links',
|
||||
value: 'disconnect'
|
||||
})
|
||||
contextMenus[0].config.callback?.({
|
||||
content: 'Rename Slot',
|
||||
value: 'rename'
|
||||
})
|
||||
contextMenus[0].config.callback?.({
|
||||
content: 'Remove Slot',
|
||||
value: 'remove'
|
||||
})
|
||||
|
||||
expect(filled.disconnect).toHaveBeenCalled()
|
||||
expect(node.renameSlot).toHaveBeenCalledWith(filled, 'renamed')
|
||||
expect(node.removeSlot).toHaveBeenCalledWith(filled)
|
||||
expect(subgraph.setDirtyCanvas).toHaveBeenCalledWith(true, true)
|
||||
})
|
||||
|
||||
it('does not open a context menu for the empty slot', () => {
|
||||
const { node, empty } = createNode()
|
||||
|
||||
node.openMenu(empty, eventAt(20, 60))
|
||||
|
||||
expect(contextMenus).toHaveLength(0)
|
||||
})
|
||||
})
|
||||
273
src/lib/litegraph/src/subgraph/SubgraphInputNode.test.ts
Normal file
273
src/lib/litegraph/src/subgraph/SubgraphInputNode.test.ts
Normal file
@@ -0,0 +1,273 @@
|
||||
import { fromPartial } from '@total-typescript/shoehorn'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import type { CanvasPointer } from '@/lib/litegraph/src/CanvasPointer'
|
||||
import { LLink } from '@/lib/litegraph/src/LLink'
|
||||
import type {
|
||||
DefaultConnectionColors,
|
||||
INodeInputSlot
|
||||
} from '@/lib/litegraph/src/interfaces'
|
||||
import { LGraphNode } from '@/lib/litegraph/src/LGraphNode'
|
||||
import type { LinkConnector } from '@/lib/litegraph/src/canvas/LinkConnector'
|
||||
import type { NodeLike } from '@/lib/litegraph/src/types/NodeLike'
|
||||
import type { CanvasPointerEvent } from '@/lib/litegraph/src/types/events'
|
||||
import { toLinkId } from '@/types/linkId'
|
||||
import { toNodeId } from '@/types/nodeId'
|
||||
import { toRerouteId } from '@/types/rerouteId'
|
||||
|
||||
import { createTestSubgraph } from './__fixtures__/subgraphHelpers'
|
||||
|
||||
function eventAt(x: number, y: number, button = 0): CanvasPointerEvent {
|
||||
return { canvasX: x, canvasY: y, button } as CanvasPointerEvent
|
||||
}
|
||||
|
||||
function createCanvasContext() {
|
||||
return {
|
||||
getTransform: vi.fn(() => new DOMMatrix()),
|
||||
translate: vi.fn(),
|
||||
beginPath: vi.fn(),
|
||||
arc: vi.fn(),
|
||||
moveTo: vi.fn(),
|
||||
lineTo: vi.fn(),
|
||||
stroke: vi.fn(),
|
||||
setTransform: vi.fn(),
|
||||
rect: vi.fn(),
|
||||
fill: vi.fn(),
|
||||
fillText: vi.fn(),
|
||||
strokeStyle: '',
|
||||
lineWidth: 1,
|
||||
font: '',
|
||||
fillStyle: '',
|
||||
textBaseline: '',
|
||||
globalAlpha: 1
|
||||
} as unknown as CanvasRenderingContext2D
|
||||
}
|
||||
|
||||
describe('SubgraphInputNode', () => {
|
||||
beforeEach(() => {
|
||||
vi.restoreAllMocks()
|
||||
})
|
||||
|
||||
it('exposes input slots plus the empty slot and computes its anchor', () => {
|
||||
const subgraph = createTestSubgraph({
|
||||
inputs: [{ name: 'image', type: 'IMAGE' }]
|
||||
})
|
||||
subgraph.inputNode.configure({
|
||||
id: subgraph.inputNode.id,
|
||||
bounding: [10, 20, 100, 80],
|
||||
pinned: false
|
||||
})
|
||||
|
||||
expect(subgraph.inputNode.slots).toBe(subgraph.inputs)
|
||||
expect(subgraph.inputNode.allSlots).toEqual([
|
||||
subgraph.inputs[0],
|
||||
subgraph.inputNode.emptySlot
|
||||
])
|
||||
expect(subgraph.inputNode.slotAnchorX).toBe(96)
|
||||
})
|
||||
|
||||
it('sets link connector drag callbacks for left-clicked slots', () => {
|
||||
const subgraph = createTestSubgraph({
|
||||
inputs: [{ name: 'image', type: 'IMAGE' }]
|
||||
})
|
||||
const slot = subgraph.inputs[0]
|
||||
slot.boundingRect.updateTo([10, 20, 100, 30])
|
||||
const pointer = {} as CanvasPointer
|
||||
const linkConnector = {
|
||||
dragNewFromSubgraphInput: vi.fn(),
|
||||
dropLinks: vi.fn(),
|
||||
reset: vi.fn()
|
||||
} as unknown as LinkConnector
|
||||
|
||||
subgraph.inputNode.onPointerDown(eventAt(20, 25), pointer, linkConnector)
|
||||
|
||||
pointer.onDragStart?.(pointer)
|
||||
pointer.onDragEnd?.(eventAt(40, 45))
|
||||
pointer.finally?.()
|
||||
|
||||
expect(linkConnector.dragNewFromSubgraphInput).toHaveBeenCalledWith(
|
||||
subgraph,
|
||||
subgraph.inputNode,
|
||||
slot
|
||||
)
|
||||
expect(linkConnector.dropLinks).toHaveBeenCalledWith(
|
||||
subgraph,
|
||||
expect.objectContaining({ canvasX: 40 })
|
||||
)
|
||||
expect(linkConnector.reset).toHaveBeenCalledWith(true)
|
||||
})
|
||||
|
||||
it('opens the slot context menu for right-clicked slots', () => {
|
||||
const subgraph = createTestSubgraph({
|
||||
inputs: [{ name: 'image', type: 'IMAGE' }]
|
||||
})
|
||||
const slot = subgraph.inputs[0]
|
||||
slot.boundingRect.updateTo([10, 20, 100, 30])
|
||||
const menuSpy = vi.spyOn(
|
||||
subgraph.inputNode as unknown as {
|
||||
showSlotContextMenu(slot: unknown, event: unknown): void
|
||||
},
|
||||
'showSlotContextMenu'
|
||||
)
|
||||
|
||||
subgraph.inputNode.onPointerDown(
|
||||
eventAt(20, 25, 2),
|
||||
{} as CanvasPointer,
|
||||
{} as LinkConnector
|
||||
)
|
||||
subgraph.inputNode.onPointerDown(
|
||||
eventAt(500, 500, 2),
|
||||
{} as CanvasPointer,
|
||||
{} as LinkConnector
|
||||
)
|
||||
|
||||
expect(menuSpy).toHaveBeenCalledOnce()
|
||||
expect(menuSpy).toHaveBeenCalledWith(
|
||||
slot,
|
||||
expect.objectContaining({ button: 2 })
|
||||
)
|
||||
})
|
||||
|
||||
it('renames and removes input slots through the parent subgraph', () => {
|
||||
const subgraph = createTestSubgraph({
|
||||
inputs: [{ name: 'image', type: 'IMAGE' }]
|
||||
})
|
||||
const slot = subgraph.inputs[0]
|
||||
const renameSpy = vi.spyOn(subgraph, 'renameInput')
|
||||
const removeSpy = vi.spyOn(subgraph, 'removeInput')
|
||||
|
||||
subgraph.inputNode.renameSlot(slot, 'preview')
|
||||
subgraph.inputNode.removeSlot(slot)
|
||||
|
||||
expect(renameSpy).toHaveBeenCalledWith(slot, 'preview')
|
||||
expect(removeSpy).toHaveBeenCalledWith(slot)
|
||||
})
|
||||
|
||||
it('delegates connection checks and input-type connections', () => {
|
||||
const subgraph = createTestSubgraph({
|
||||
inputs: [{ name: 'image', type: 'IMAGE' }]
|
||||
})
|
||||
const slot = subgraph.inputs[0]
|
||||
const inputSlot = {
|
||||
index: 0,
|
||||
slot: { name: 'in', type: 'IMAGE' }
|
||||
} as unknown as { index: number; slot: INodeInputSlot }
|
||||
const targetNode = new LGraphNode('Target')
|
||||
targetNode.id = toNodeId(99)
|
||||
vi.spyOn(targetNode, 'findInputByType').mockReturnValue(inputSlot)
|
||||
const link = new LLink(toLinkId(1), 'IMAGE', toNodeId(1), 0, toNodeId(2), 0)
|
||||
const connectSpy = vi.spyOn(slot, 'connect').mockReturnValue(link)
|
||||
const inputNode = fromPartial<NodeLike>({
|
||||
canConnectTo: vi.fn(() => true)
|
||||
})
|
||||
|
||||
expect(
|
||||
subgraph.inputNode.canConnectTo(inputNode, inputSlot.slot, slot)
|
||||
).toBe(true)
|
||||
expect(
|
||||
subgraph.inputNode.connectByType(0, targetNode, 'IMAGE', {
|
||||
afterRerouteId: toRerouteId(7)
|
||||
})
|
||||
).toBe(link)
|
||||
expect(connectSpy).toHaveBeenCalledWith(
|
||||
inputSlot.slot,
|
||||
targetNode,
|
||||
toRerouteId(7)
|
||||
)
|
||||
|
||||
vi.mocked(targetNode.findInputByType).mockReturnValue(undefined)
|
||||
expect(
|
||||
subgraph.inputNode.connectByType(0, targetNode, 'LATENT')
|
||||
).toBeUndefined()
|
||||
})
|
||||
|
||||
it('finds input slots by name and the first free slot by type', () => {
|
||||
const subgraph = createTestSubgraph({
|
||||
inputs: [
|
||||
{ name: 'used', type: 'IMAGE' },
|
||||
{ name: 'free', type: 'IMAGE' }
|
||||
]
|
||||
})
|
||||
subgraph.inputs[0].linkIds.push(toLinkId(1))
|
||||
|
||||
expect(subgraph.inputNode.findOutputSlot('free')).toBe(subgraph.inputs[1])
|
||||
expect(subgraph.inputNode.findOutputByType('IMAGE')).toBe(
|
||||
subgraph.inputs[0]
|
||||
)
|
||||
expect(subgraph.inputNode.findOutputByType('LATENT')).toBeUndefined()
|
||||
})
|
||||
|
||||
it('disconnects node inputs and clears floating links', () => {
|
||||
const subgraph = createTestSubgraph({
|
||||
inputs: [{ name: 'image', type: 'IMAGE' }]
|
||||
})
|
||||
const targetNode = new LGraphNode('Target')
|
||||
targetNode.id = toNodeId(99)
|
||||
const input = targetNode.addInput('image', 'IMAGE')
|
||||
const floatingLink = new LLink(
|
||||
toLinkId(9),
|
||||
'IMAGE',
|
||||
subgraph.inputNode.id,
|
||||
0,
|
||||
targetNode.id,
|
||||
0
|
||||
)
|
||||
input._floatingLinks = new Set([floatingLink])
|
||||
input.link = toLinkId(3)
|
||||
const removeFloatingLinkSpy = vi.spyOn(subgraph, 'removeFloatingLink')
|
||||
const setDirtyCanvasSpy = vi.spyOn(subgraph, 'setDirtyCanvas')
|
||||
|
||||
subgraph.inputNode._disconnectNodeInput(targetNode, input, undefined)
|
||||
|
||||
expect(removeFloatingLinkSpy).toHaveBeenCalledWith(floatingLink)
|
||||
expect(input.link).toBeNull()
|
||||
expect(setDirtyCanvasSpy).toHaveBeenCalledWith(false, true)
|
||||
})
|
||||
|
||||
it('draws the side rail and input slots', () => {
|
||||
const subgraph = createTestSubgraph({
|
||||
inputs: [{ name: 'image', type: 'IMAGE' }]
|
||||
})
|
||||
subgraph.inputNode.configure({
|
||||
id: subgraph.inputNode.id,
|
||||
bounding: [10, 20, 100, 80],
|
||||
pinned: false
|
||||
})
|
||||
const ctx = createCanvasContext()
|
||||
const drawSlotsSpy = vi.spyOn(
|
||||
subgraph.inputNode as unknown as {
|
||||
drawSlots(
|
||||
ctx: unknown,
|
||||
colorContext: unknown,
|
||||
fromSlot: unknown,
|
||||
editorAlpha: unknown
|
||||
): void
|
||||
},
|
||||
'drawSlots'
|
||||
)
|
||||
|
||||
subgraph.inputNode.drawProtected(
|
||||
ctx,
|
||||
{
|
||||
getConnectedColor: vi.fn(() => '#fff'),
|
||||
getDisconnectedColor: vi.fn(() => '#000')
|
||||
} as unknown as DefaultConnectionColors,
|
||||
subgraph.inputs[0],
|
||||
0.5
|
||||
)
|
||||
|
||||
expect(ctx.translate).toHaveBeenCalledWith(10, 20)
|
||||
expect(ctx.beginPath).toHaveBeenCalled()
|
||||
expect(ctx.stroke).toHaveBeenCalled()
|
||||
expect(ctx.setTransform).toHaveBeenCalled()
|
||||
expect(drawSlotsSpy).toHaveBeenCalledWith(
|
||||
ctx,
|
||||
expect.objectContaining({
|
||||
getConnectedColor: expect.any(Function),
|
||||
getDisconnectedColor: expect.any(Function)
|
||||
}),
|
||||
subgraph.inputs[0],
|
||||
0.5
|
||||
)
|
||||
})
|
||||
})
|
||||
225
src/lib/litegraph/src/subgraph/SubgraphInputSlot.test.ts
Normal file
225
src/lib/litegraph/src/subgraph/SubgraphInputSlot.test.ts
Normal file
@@ -0,0 +1,225 @@
|
||||
import { fromPartial } from '@total-typescript/shoehorn'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { LLink } from '@/lib/litegraph/src/LLink'
|
||||
import type {
|
||||
INodeInputSlot,
|
||||
INodeOutputSlot
|
||||
} from '@/lib/litegraph/src/interfaces'
|
||||
import { LGraphNode } from '@/lib/litegraph/src/LGraphNode'
|
||||
import { NodeSlotType } from '@/lib/litegraph/src/types/globalEnums'
|
||||
import type { IBaseWidget } from '@/lib/litegraph/src/types/widgets'
|
||||
import { toLinkId } from '@/types/linkId'
|
||||
import { toNodeId } from '@/types/nodeId'
|
||||
import { toRerouteId } from '@/types/rerouteId'
|
||||
|
||||
import { createTestSubgraph } from './__fixtures__/subgraphHelpers'
|
||||
|
||||
function createWidget(
|
||||
overrides: Partial<Pick<IBaseWidget, 'name' | 'type' | 'options'>> = {}
|
||||
): IBaseWidget {
|
||||
return {
|
||||
name: overrides.name ?? 'strength',
|
||||
type: overrides.type ?? 'FLOAT',
|
||||
options: {
|
||||
min: 0,
|
||||
max: 1,
|
||||
step: 0.1,
|
||||
step2: 0.01,
|
||||
precision: 2,
|
||||
...overrides.options
|
||||
}
|
||||
} as IBaseWidget
|
||||
}
|
||||
|
||||
describe('SubgraphInput', () => {
|
||||
beforeEach(() => {
|
||||
vi.restoreAllMocks()
|
||||
})
|
||||
|
||||
it('connects subgraph inputs to node inputs', () => {
|
||||
const subgraph = createTestSubgraph({
|
||||
inputs: [{ name: 'image', type: 'IMAGE' }]
|
||||
})
|
||||
const targetNode = new LGraphNode('Target')
|
||||
targetNode.id = toNodeId(10)
|
||||
subgraph.add(targetNode)
|
||||
const input = targetNode.addInput('image', 'IMAGE')
|
||||
const afterChangeSpy = vi.spyOn(subgraph, 'afterChange')
|
||||
const triggerSpy = vi.spyOn(subgraph, 'trigger')
|
||||
const connectionSpy = vi.fn()
|
||||
targetNode.onConnectionsChange = connectionSpy
|
||||
|
||||
const link = subgraph.inputs[0].connect(input, targetNode, toRerouteId(5))
|
||||
|
||||
expect(link).toBeInstanceOf(LLink)
|
||||
expect(link?.origin_id).toBe(subgraph.inputNode.id)
|
||||
expect(link?.target_id).toBe(targetNode.id)
|
||||
expect(link?.parentId).toBe(toRerouteId(5))
|
||||
expect(subgraph.inputs[0].linkIds).toEqual([link?.id])
|
||||
expect(input.link).toBe(link?.id)
|
||||
expect(triggerSpy).toHaveBeenCalledWith('node:slot-links:changed', {
|
||||
nodeId: targetNode.id,
|
||||
slotType: NodeSlotType.INPUT,
|
||||
slotIndex: 0,
|
||||
connected: true,
|
||||
linkId: link?.id
|
||||
})
|
||||
expect(connectionSpy).toHaveBeenCalledWith(
|
||||
NodeSlotType.INPUT,
|
||||
0,
|
||||
true,
|
||||
link,
|
||||
input
|
||||
)
|
||||
expect(afterChangeSpy).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('does not connect when the target node blocks the input', () => {
|
||||
const subgraph = createTestSubgraph({
|
||||
inputs: [{ name: 'image', type: 'IMAGE' }]
|
||||
})
|
||||
const targetNode = new LGraphNode('Target')
|
||||
const input = targetNode.addInput('image', 'IMAGE')
|
||||
targetNode.onConnectInput = vi.fn(() => false)
|
||||
|
||||
expect(subgraph.inputs[0].connect(input, targetNode)).toBeUndefined()
|
||||
})
|
||||
|
||||
it('rejects widget inputs that do not match the promoted widget', () => {
|
||||
const subgraph = createTestSubgraph({
|
||||
inputs: [{ name: 'strength', type: 'FLOAT' }]
|
||||
})
|
||||
const targetNode = new LGraphNode('Target')
|
||||
const input = targetNode.addInput('strength', 'FLOAT')
|
||||
const currentWidget = createWidget()
|
||||
const otherWidget = createWidget({ options: { min: 1 } })
|
||||
input.widget = { name: otherWidget.name }
|
||||
targetNode.widgets = [otherWidget]
|
||||
subgraph.inputs[0]._widget = currentWidget
|
||||
const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {})
|
||||
|
||||
expect(subgraph.inputs[0].connect(input, targetNode)).toBeUndefined()
|
||||
expect(warnSpy).toHaveBeenCalledWith(
|
||||
'Target input has invalid widget.',
|
||||
input,
|
||||
targetNode
|
||||
)
|
||||
})
|
||||
|
||||
it('tracks connected widgets and clears them on disconnect', () => {
|
||||
const subgraph = createTestSubgraph({
|
||||
inputs: [{ name: 'strength', type: 'FLOAT' }]
|
||||
})
|
||||
const targetNode = new LGraphNode('Target')
|
||||
targetNode.id = toNodeId(10)
|
||||
subgraph.add(targetNode)
|
||||
const input = targetNode.addInput('strength', 'FLOAT')
|
||||
const widget = createWidget()
|
||||
input.widget = { name: widget.name }
|
||||
targetNode.widgets = [widget]
|
||||
const connectedSpy = vi.fn()
|
||||
const disconnectedSpy = vi.fn()
|
||||
subgraph.inputs[0].events.addEventListener('input-connected', connectedSpy)
|
||||
subgraph.inputs[0].events.addEventListener(
|
||||
'input-disconnected',
|
||||
disconnectedSpy
|
||||
)
|
||||
|
||||
const link = subgraph.inputs[0].connect(input, targetNode)
|
||||
|
||||
expect(subgraph.inputs[0]._widget).toBe(widget)
|
||||
expect(subgraph.inputs[0].getConnectedWidgets()).toEqual([widget])
|
||||
expect(connectedSpy).toHaveBeenCalledOnce()
|
||||
|
||||
subgraph.inputs[0].disconnect()
|
||||
|
||||
expect(subgraph.inputs[0]._widget).toBeUndefined()
|
||||
expect(subgraph.inputs[0].linkIds).toEqual([])
|
||||
expect(disconnectedSpy).toHaveBeenCalledTimes(2)
|
||||
expect(subgraph.getLink(link?.id ?? toLinkId(-1))).toBeUndefined()
|
||||
})
|
||||
|
||||
it('arranges and labels from the right edge', () => {
|
||||
const subgraph = createTestSubgraph({
|
||||
inputs: [{ name: 'image', type: 'IMAGE' }]
|
||||
})
|
||||
const input = subgraph.inputs[0]
|
||||
|
||||
input.arrange([140, 30, 120, 40])
|
||||
|
||||
expect(Array.from(input.boundingRect)).toEqual([20, 30, 120, 40])
|
||||
expect(input.pos).toEqual([120, 50])
|
||||
expect(input.labelPos).toEqual([20, 50])
|
||||
})
|
||||
|
||||
it('validates node inputs and subgraph outputs as targets', () => {
|
||||
const subgraph = createTestSubgraph({
|
||||
inputs: [{ name: 'source', type: 'IMAGE' }],
|
||||
outputs: [{ name: 'preview', type: 'IMAGE' }]
|
||||
})
|
||||
const input = subgraph.inputs[0]
|
||||
const imageInput = { name: 'image', type: 'IMAGE', link: null }
|
||||
const latentInput = { name: 'latent', type: 'LATENT', link: null }
|
||||
const imageOutput = fromPartial<INodeOutputSlot>({
|
||||
name: 'image',
|
||||
type: 'IMAGE',
|
||||
links: []
|
||||
})
|
||||
|
||||
expect(input.isValidTarget(imageInput as INodeInputSlot)).toBe(true)
|
||||
expect(input.isValidTarget(latentInput as INodeInputSlot)).toBe(false)
|
||||
expect(input.isValidTarget(imageOutput)).toBe(false)
|
||||
expect(input.isValidTarget(subgraph.outputs[0])).toBe(true)
|
||||
})
|
||||
|
||||
it('matches widget options by type and numeric constraints', () => {
|
||||
const subgraph = createTestSubgraph({
|
||||
inputs: [{ name: 'strength', type: 'FLOAT' }]
|
||||
})
|
||||
const input = subgraph.inputs[0]
|
||||
input._widget = createWidget()
|
||||
|
||||
expect(input.matchesWidget(createWidget())).toBe(true)
|
||||
expect(input.matchesWidget(createWidget({ type: 'INT' }))).toBe(false)
|
||||
expect(input.matchesWidget(createWidget({ options: { max: 2 } }))).toBe(
|
||||
false
|
||||
)
|
||||
|
||||
input._widget = undefined
|
||||
expect(input.matchesWidget(createWidget({ type: 'INT' }))).toBe(true)
|
||||
})
|
||||
|
||||
it('disconnects node inputs and removes link references', () => {
|
||||
const subgraph = createTestSubgraph({
|
||||
inputs: [{ name: 'image', type: 'IMAGE' }]
|
||||
})
|
||||
const targetNode = new LGraphNode('Target')
|
||||
targetNode.id = toNodeId(10)
|
||||
subgraph.add(targetNode)
|
||||
const input = targetNode.addInput('image', 'IMAGE')
|
||||
const link = subgraph.inputs[0].connect(input, targetNode)
|
||||
const triggerSpy = vi.spyOn(subgraph, 'trigger')
|
||||
const connectionSpy = vi.fn()
|
||||
targetNode.onConnectionsChange = connectionSpy
|
||||
|
||||
subgraph.inputNode._disconnectNodeInput(targetNode, input, link)
|
||||
|
||||
expect(input.link).toBeNull()
|
||||
expect(subgraph.inputs[0].linkIds).toEqual([])
|
||||
expect(connectionSpy).toHaveBeenCalledWith(
|
||||
NodeSlotType.INPUT,
|
||||
0,
|
||||
false,
|
||||
link,
|
||||
subgraph.inputs[0]
|
||||
)
|
||||
expect(triggerSpy).toHaveBeenCalledWith('node:slot-links:changed', {
|
||||
nodeId: targetNode.id,
|
||||
slotType: NodeSlotType.INPUT,
|
||||
slotIndex: 0,
|
||||
connected: false,
|
||||
linkId: link?.id
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -6,7 +6,7 @@
|
||||
*/
|
||||
import { createTestingPinia } from '@pinia/testing'
|
||||
import { setActivePinia } from 'pinia'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { fromPartial } from '@total-typescript/shoehorn'
|
||||
|
||||
import {
|
||||
@@ -17,9 +17,13 @@ import {
|
||||
SubgraphNode
|
||||
} from '@/lib/litegraph/src/litegraph'
|
||||
import type { ExportedSubgraphInstance } from '@/lib/litegraph/src/types/serialisation'
|
||||
import type { IBaseWidget } from '@/lib/litegraph/src/types/widgets'
|
||||
import { NodeSlotType } from '@/lib/litegraph/src/types/globalEnums'
|
||||
import { usePreviewExposureStore } from '@/stores/previewExposureStore'
|
||||
import { useWidgetValueStore } from '@/stores/widgetValueStore'
|
||||
import { createNodeLocatorId } from '@/types/nodeIdentification'
|
||||
import { toNodeId } from '@/types/nodeId'
|
||||
import type { WidgetId } from '@/types/widgetId'
|
||||
|
||||
import { subgraphTest } from './__fixtures__/subgraphFixtures'
|
||||
import {
|
||||
@@ -33,6 +37,10 @@ beforeEach(() => {
|
||||
resetSubgraphFixtureState()
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks()
|
||||
})
|
||||
|
||||
describe('SubgraphNode Construction', () => {
|
||||
it('should create a SubgraphNode from a subgraph definition', () => {
|
||||
const subgraph = createTestSubgraph({
|
||||
@@ -100,6 +108,18 @@ describe('SubgraphNode Construction', () => {
|
||||
expect(subgraphNode.widgets).toEqual([])
|
||||
})
|
||||
|
||||
it('warns when external code assigns widgets directly', () => {
|
||||
const warn = vi.spyOn(console, 'warn').mockImplementation(() => undefined)
|
||||
const subgraph = createTestSubgraph()
|
||||
const subgraphNode = createTestSubgraphNode(subgraph)
|
||||
|
||||
subgraphNode.widgets = []
|
||||
|
||||
expect(warn).toHaveBeenCalledWith(
|
||||
'Cannot manually set widgets on SubgraphNode; use the promotion system.'
|
||||
)
|
||||
})
|
||||
|
||||
subgraphTest(
|
||||
'should synchronize slots with subgraph definition',
|
||||
({ subgraphWithNode }) => {
|
||||
@@ -220,6 +240,38 @@ describe('SubgraphNode Synchronization', () => {
|
||||
expect(subgraphNode.outputs[0].label).toBe('newOutput')
|
||||
})
|
||||
|
||||
it('throws when input rename events reference a missing slot', () => {
|
||||
const subgraph = createTestSubgraph({
|
||||
inputs: [{ name: 'input', type: 'number' }]
|
||||
})
|
||||
createTestSubgraphNode(subgraph)
|
||||
|
||||
expect(() =>
|
||||
subgraph.events.dispatch('renaming-input', {
|
||||
input: subgraph.inputs[0],
|
||||
index: 99,
|
||||
oldName: 'input',
|
||||
newName: 'missing'
|
||||
})
|
||||
).toThrow('Subgraph input not found')
|
||||
})
|
||||
|
||||
it('throws when output rename events reference a missing slot', () => {
|
||||
const subgraph = createTestSubgraph({
|
||||
outputs: [{ name: 'output', type: 'number' }]
|
||||
})
|
||||
createTestSubgraphNode(subgraph)
|
||||
|
||||
expect(() =>
|
||||
subgraph.events.dispatch('renaming-output', {
|
||||
output: subgraph.outputs[0],
|
||||
index: 99,
|
||||
oldName: 'output',
|
||||
newName: 'missing'
|
||||
})
|
||||
).toThrow('Subgraph output not found')
|
||||
})
|
||||
|
||||
it('represents promoted host widgets by input widgetId and WidgetState', () => {
|
||||
const subgraph = createTestSubgraph({
|
||||
inputs: [{ name: 'text', type: 'STRING' }]
|
||||
@@ -362,6 +414,41 @@ describe('SubgraphNode Synchronization', () => {
|
||||
})
|
||||
})
|
||||
|
||||
it('falls back projected widget fields when WidgetState is missing', () => {
|
||||
const subgraph = createTestSubgraph({
|
||||
inputs: [{ name: 'text', type: 'STRING' }]
|
||||
})
|
||||
|
||||
const interiorNode = new LGraphNode('Interior')
|
||||
const input = interiorNode.addInput('value', 'STRING')
|
||||
input.widget = { name: 'value' }
|
||||
interiorNode.addOutput('out', 'STRING')
|
||||
interiorNode.addWidget('text', 'value', 'initial', () => {})
|
||||
subgraph.add(interiorNode)
|
||||
subgraph.inputNode.slots[0].connect(interiorNode.inputs[0], interiorNode)
|
||||
|
||||
const subgraphNode = createTestSubgraphNode(subgraph)
|
||||
const promotedInput = subgraphNode.inputs[0]
|
||||
const widget = subgraphNode.widgets[0]
|
||||
const id = promotedInput.widgetId
|
||||
if (!id) throw new Error('Missing widgetId')
|
||||
if (!widget) throw new Error('Missing projected widget')
|
||||
|
||||
useWidgetValueStore().deleteWidget(id)
|
||||
|
||||
expect(widget.name).toBe('text')
|
||||
expect(widget.label).toBe('text')
|
||||
expect(widget.y).toBe(0)
|
||||
expect(widget.type).toBe('text')
|
||||
expect(widget.options).toEqual({})
|
||||
expect(widget.value).toBeUndefined()
|
||||
expect(() => {
|
||||
widget.label = 'Label'
|
||||
widget.y = 12
|
||||
widget.callback?.('updated')
|
||||
}).not.toThrow()
|
||||
})
|
||||
|
||||
it('should keep input.widget.name stable after rename (onGraphConfigured safety)', () => {
|
||||
const subgraph = createTestSubgraph({
|
||||
inputs: [{ name: 'text', type: 'STRING' }]
|
||||
@@ -443,6 +530,111 @@ describe('SubgraphNode Synchronization', () => {
|
||||
'My Seed'
|
||||
)
|
||||
})
|
||||
|
||||
it('keeps rename behavior when widget state has been removed', () => {
|
||||
const subgraph = createTestSubgraph({
|
||||
inputs: [{ name: 'text', type: 'STRING' }]
|
||||
})
|
||||
|
||||
const interiorNode = new LGraphNode('Interior')
|
||||
const input = interiorNode.addInput('value', 'STRING')
|
||||
input.widget = { name: 'value' }
|
||||
interiorNode.addWidget('text', 'value', 'initial', () => {})
|
||||
subgraph.add(interiorNode)
|
||||
subgraph.inputNode.slots[0].connect(interiorNode.inputs[0], interiorNode)
|
||||
|
||||
const subgraphNode = createTestSubgraphNode(subgraph)
|
||||
const promotedInput = subgraphNode.inputs[0]
|
||||
const widgetId = promotedInput.widgetId
|
||||
if (!widgetId) throw new Error('Missing widgetId')
|
||||
useWidgetValueStore().deleteWidget(widgetId)
|
||||
|
||||
subgraph.renameInput(subgraph.inputs[0], 'Renamed Text')
|
||||
|
||||
expect(promotedInput.label).toBe('Renamed Text')
|
||||
expect(useWidgetValueStore().getWidget(widgetId)).toBeUndefined()
|
||||
})
|
||||
|
||||
it('rebinds promoted widgets when subgraph input objects are recreated', () => {
|
||||
const subgraph = createTestSubgraph({
|
||||
inputs: [{ name: 'text', type: 'STRING' }]
|
||||
})
|
||||
|
||||
const interiorNode = new LGraphNode('Interior')
|
||||
interiorNode.id = toNodeId(5)
|
||||
const input = interiorNode.addInput('value', 'STRING')
|
||||
input.widget = { name: 'value' }
|
||||
interiorNode.addWidget('text', 'value', 'initial', () => {})
|
||||
subgraph.add(interiorNode)
|
||||
subgraph.inputNode.slots[0].connect(interiorNode.inputs[0], interiorNode)
|
||||
|
||||
const subgraphNode = createTestSubgraphNode(subgraph)
|
||||
const originalSlot = subgraphNode.inputs[0]._subgraphSlot
|
||||
const originalWidgetId = subgraphNode.inputs[0].widgetId
|
||||
const serialized = subgraph.asSerialisable()
|
||||
|
||||
subgraph.configure(serialized)
|
||||
|
||||
expect(subgraphNode.inputs).toHaveLength(1)
|
||||
expect(subgraphNode.inputs[0]._subgraphSlot).toBe(subgraph.inputs[0])
|
||||
expect(subgraphNode.inputs[0]._subgraphSlot).not.toBe(originalSlot)
|
||||
expect(subgraphNode.inputs[0].widgetId).toBe(originalWidgetId)
|
||||
expect(subgraphNode.widgets[0]).toMatchObject({
|
||||
name: 'text',
|
||||
value: 'initial'
|
||||
})
|
||||
})
|
||||
|
||||
it('stores DOM widget metadata from custom promoted host widgets', () => {
|
||||
const subgraph = createTestSubgraph({
|
||||
inputs: [{ name: 'dom', type: 'STRING' }]
|
||||
})
|
||||
|
||||
const interiorNode = new LGraphNode('Interior')
|
||||
const input = interiorNode.addInput('value', 'STRING')
|
||||
input.widget = { name: 'value' }
|
||||
const interiorWidget = interiorNode.addWidget(
|
||||
'text',
|
||||
'value',
|
||||
'initial',
|
||||
() => {}
|
||||
)
|
||||
Object.assign(interiorWidget, { isDOMWidget: true })
|
||||
subgraph.add(interiorNode)
|
||||
subgraph.inputNode.slots[0].connect(interiorNode.inputs[0], interiorNode)
|
||||
const hostWidget = fromPartial<IBaseWidget>({
|
||||
name: 'host',
|
||||
type: 'text',
|
||||
value: 'host value',
|
||||
options: {},
|
||||
y: 0
|
||||
})
|
||||
|
||||
class HostWidgetSubgraphNode extends SubgraphNode {
|
||||
protected override createPromotedHostWidget() {
|
||||
return hostWidget
|
||||
}
|
||||
}
|
||||
|
||||
const subgraphNode = new HostWidgetSubgraphNode(
|
||||
subgraph.rootGraph,
|
||||
subgraph,
|
||||
fromPartial<ExportedSubgraphInstance>({
|
||||
id: 10,
|
||||
type: subgraph.id,
|
||||
pos: [0, 0],
|
||||
size: [200, 100],
|
||||
properties: {}
|
||||
})
|
||||
)
|
||||
const widgetId = subgraphNode.inputs[0].widgetId
|
||||
if (!widgetId) throw new Error('Missing widgetId')
|
||||
|
||||
expect(subgraphNode.widgets).toEqual([hostWidget])
|
||||
expect(useWidgetValueStore().getWidget(widgetId)).toMatchObject({
|
||||
isDOMWidget: true
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('SubgraphNode widget name collision on rename', () => {
|
||||
@@ -658,6 +850,31 @@ describe('SubgraphNode Lifecycle', () => {
|
||||
})
|
||||
|
||||
describe('SubgraphNode Basic Functionality', () => {
|
||||
it('opens subgraphs from the title button and delegates other buttons', () => {
|
||||
const subgraph = createTestSubgraph()
|
||||
const subgraphNode = createTestSubgraphNode(subgraph)
|
||||
const canvas = fromPartial<
|
||||
Parameters<SubgraphNode['onTitleButtonClick']>[1]
|
||||
>({
|
||||
openSubgraph: vi.fn()
|
||||
})
|
||||
const fallback = vi
|
||||
.spyOn(LGraphNode.prototype, 'onTitleButtonClick')
|
||||
.mockImplementation(() => undefined)
|
||||
|
||||
subgraphNode.onTitleButtonClick(
|
||||
fromPartial({ name: 'enter_subgraph' }),
|
||||
canvas
|
||||
)
|
||||
subgraphNode.onTitleButtonClick(fromPartial({ name: 'other' }), canvas)
|
||||
|
||||
expect(canvas.openSubgraph).toHaveBeenCalledWith(subgraph, subgraphNode)
|
||||
expect(fallback).toHaveBeenCalledWith(
|
||||
fromPartial({ name: 'other' }),
|
||||
canvas
|
||||
)
|
||||
})
|
||||
|
||||
it('should inherit input types correctly', () => {
|
||||
const subgraph = createTestSubgraph({
|
||||
inputs: [
|
||||
@@ -687,6 +904,157 @@ describe('SubgraphNode Basic Functionality', () => {
|
||||
expect(subgraphNode.outputs[1].type).toBe('string')
|
||||
expect(subgraphNode.outputs[2].type).toBe('*')
|
||||
})
|
||||
|
||||
it('delegates title box drawing to a custom handler', () => {
|
||||
const subgraph = createTestSubgraph()
|
||||
const subgraphNode = createTestSubgraphNode(subgraph)
|
||||
const onDrawTitleBox = vi.fn()
|
||||
subgraphNode.onDrawTitleBox = onDrawTitleBox
|
||||
const ctx = fromPartial<CanvasRenderingContext2D>({})
|
||||
|
||||
subgraphNode.drawTitleBox(ctx, {
|
||||
scale: 2,
|
||||
low_quality: false,
|
||||
title_height: 30,
|
||||
box_size: 12
|
||||
})
|
||||
|
||||
expect(onDrawTitleBox).toHaveBeenCalledWith(
|
||||
ctx,
|
||||
30,
|
||||
subgraphNode.renderingSize,
|
||||
2
|
||||
)
|
||||
})
|
||||
|
||||
it('draws the default title box with and without the bitmap icon', () => {
|
||||
const subgraph = createTestSubgraph()
|
||||
const subgraphNode = createTestSubgraphNode(subgraph)
|
||||
const ctx = fromPartial<CanvasRenderingContext2D>({
|
||||
save: vi.fn(),
|
||||
beginPath: vi.fn(),
|
||||
roundRect: vi.fn(),
|
||||
fill: vi.fn(),
|
||||
translate: vi.fn(),
|
||||
scale: vi.fn(),
|
||||
drawImage: vi.fn(),
|
||||
restore: vi.fn()
|
||||
})
|
||||
|
||||
subgraphNode.drawTitleBox(ctx, { scale: 1 })
|
||||
subgraphNode.drawTitleBox(ctx, { scale: 1, low_quality: true })
|
||||
|
||||
expect(ctx.roundRect).toHaveBeenCalledWith(6, -24.5, 22, 20, 5)
|
||||
expect(ctx.drawImage).toHaveBeenCalledTimes(1)
|
||||
expect(ctx.restore).toHaveBeenCalledTimes(2)
|
||||
})
|
||||
|
||||
it('returns undefined when a widgetId does not match a promoted input', () => {
|
||||
const subgraph = createTestSubgraph({
|
||||
inputs: [{ name: 'text', type: 'STRING' }]
|
||||
})
|
||||
const subgraphNode = createTestSubgraphNode(subgraph)
|
||||
|
||||
expect(
|
||||
subgraphNode.getSlotFromWidget(
|
||||
fromPartial<IBaseWidget>({
|
||||
name: 'missing',
|
||||
type: 'text',
|
||||
value: '',
|
||||
widgetId: 'missing-widget' as WidgetId
|
||||
})
|
||||
)
|
||||
).toBeUndefined()
|
||||
})
|
||||
|
||||
it('returns null for missing inner input links', () => {
|
||||
const subgraph = createTestSubgraph({
|
||||
outputs: [{ name: 'output', type: 'IMAGE' }]
|
||||
})
|
||||
const subgraphNode = createTestSubgraphNode(subgraph)
|
||||
vi.spyOn(console, 'warn').mockImplementation(() => undefined)
|
||||
|
||||
expect(subgraphNode.getInputLink(0)).toBeNull()
|
||||
})
|
||||
|
||||
it('returns a translated input link for connected subgraph outputs', () => {
|
||||
const subgraph = createTestSubgraph({
|
||||
outputs: [{ name: 'output', type: 'IMAGE' }]
|
||||
})
|
||||
const subgraphNode = createTestSubgraphNode(subgraph)
|
||||
const inner = new LGraphNode('Inner')
|
||||
inner.id = toNodeId(9)
|
||||
inner.addOutput('image', 'IMAGE')
|
||||
subgraph.add(inner)
|
||||
subgraph.outputNode.slots[0].connect(inner.outputs[0], inner)
|
||||
|
||||
const link = subgraphNode.getInputLink(0)
|
||||
|
||||
expect(link?.origin_id).toBe(toNodeId(`${subgraphNode.id}:${inner.id}`))
|
||||
expect(link?.origin_slot).toBe(0)
|
||||
})
|
||||
|
||||
it('returns empty resolved input links when the subgraph input is isolated', () => {
|
||||
const subgraph = createTestSubgraph({
|
||||
inputs: [{ name: 'input', type: 'IMAGE' }]
|
||||
})
|
||||
const subgraphNode = createTestSubgraphNode(subgraph)
|
||||
vi.spyOn(console, 'warn').mockImplementation(() => undefined)
|
||||
|
||||
expect(subgraphNode.resolveSubgraphInputLinks(0)).toEqual([])
|
||||
})
|
||||
|
||||
it('returns resolved input links when the subgraph input is connected', () => {
|
||||
const subgraph = createTestSubgraph({
|
||||
inputs: [{ name: 'input', type: 'IMAGE' }]
|
||||
})
|
||||
const subgraphNode = createTestSubgraphNode(subgraph)
|
||||
const inner = new LGraphNode('Inner')
|
||||
inner.id = toNodeId(9)
|
||||
const input = inner.addInput('image', 'IMAGE')
|
||||
subgraph.add(inner)
|
||||
subgraph.inputNode.slots[0].connect(input, inner)
|
||||
|
||||
expect(subgraphNode.resolveSubgraphInputLinks(0)).toEqual([
|
||||
expect.objectContaining({
|
||||
input,
|
||||
inputNode: inner
|
||||
})
|
||||
])
|
||||
})
|
||||
|
||||
it('returns resolved output links when the subgraph output is connected', () => {
|
||||
const subgraph = createTestSubgraph({
|
||||
outputs: [{ name: 'output', type: 'IMAGE' }]
|
||||
})
|
||||
const subgraphNode = createTestSubgraphNode(subgraph)
|
||||
const inner = new LGraphNode('Inner')
|
||||
inner.addOutput('image', 'IMAGE')
|
||||
subgraph.add(inner)
|
||||
subgraph.outputNode.slots[0].connect(inner.outputs[0], inner)
|
||||
|
||||
expect(subgraphNode.resolveSubgraphOutputLink(0)?.outputNode).toBe(inner)
|
||||
})
|
||||
|
||||
it('returns a consistent slot shape only when all inner shapes match', () => {
|
||||
const subgraph = createTestSubgraph({
|
||||
inputs: [{ name: 'input', type: 'IMAGE' }]
|
||||
})
|
||||
const subgraphNode = createTestSubgraphNode(subgraph)
|
||||
const slot = subgraph.inputs[0]
|
||||
|
||||
expect(subgraphNode.getSlotShape(slot, fromPartial({ shape: 4 }))).toBe(4)
|
||||
|
||||
const node = new LGraphNode('ShapeTarget')
|
||||
const rounded = node.addInput('rounded', 'IMAGE')
|
||||
const boxed = node.addInput('boxed', 'IMAGE')
|
||||
rounded.shape = 4
|
||||
boxed.shape = 3
|
||||
subgraph.add(node)
|
||||
slot.connect(rounded, node)
|
||||
|
||||
expect(subgraphNode.getSlotShape(slot, boxed)).toBeUndefined()
|
||||
})
|
||||
})
|
||||
|
||||
describe('SubgraphNode Execution', () => {
|
||||
@@ -776,6 +1144,27 @@ describe('SubgraphNode Execution', () => {
|
||||
expect(() => subgraph.add(subgraphNode)).toThrow()
|
||||
})
|
||||
|
||||
it('throws a recursion error when traversal revisits the same subgraph node', () => {
|
||||
const subgraph = createTestSubgraph({ name: '' })
|
||||
const subgraphNode = createTestSubgraphNode(subgraph)
|
||||
subgraphNode.title = 'Recursive Host'
|
||||
|
||||
expect(() =>
|
||||
subgraphNode.getInnerNodes(new Map(), [], [], new Set([subgraphNode]))
|
||||
).toThrow('Circular reference detected')
|
||||
})
|
||||
|
||||
it('describes unnamed recursive subgraph nodes', () => {
|
||||
const subgraph = createTestSubgraph()
|
||||
subgraph.name = ''
|
||||
const subgraphNode = createTestSubgraphNode(subgraph)
|
||||
subgraphNode.title = ''
|
||||
|
||||
expect(() =>
|
||||
subgraphNode.getInnerNodes(new Map(), [], [], new Set([subgraphNode]))
|
||||
).toThrow("node 1 of subgraph 'Unnamed Subgraph'")
|
||||
})
|
||||
|
||||
it('should resolve cross-boundary links', () => {
|
||||
// This test verifies that links can cross subgraph boundaries
|
||||
// Currently this is a basic test - full cross-boundary linking
|
||||
@@ -801,6 +1190,171 @@ describe('SubgraphNode Execution', () => {
|
||||
})
|
||||
})
|
||||
|
||||
describe('SubgraphNode preview exposure hydration', () => {
|
||||
it('hydrates explicit preview exposure properties', () => {
|
||||
const subgraph = createTestSubgraph()
|
||||
const subgraphNode = createTestSubgraphNode(subgraph)
|
||||
const store = usePreviewExposureStore()
|
||||
|
||||
subgraphNode.configure({
|
||||
...subgraphNode.serialize(),
|
||||
properties: {
|
||||
previewExposures: [
|
||||
{
|
||||
name: 'preview',
|
||||
sourceNodeId: '12',
|
||||
sourcePreviewName: '$$preview'
|
||||
}
|
||||
]
|
||||
}
|
||||
} as ExportedSubgraphInstance)
|
||||
|
||||
expect(
|
||||
store.getExposures(subgraphNode.rootGraph.id, String(subgraphNode.id))
|
||||
).toEqual([
|
||||
{
|
||||
name: 'preview',
|
||||
sourceNodeId: toNodeId(12),
|
||||
sourcePreviewName: '$$preview'
|
||||
}
|
||||
])
|
||||
})
|
||||
|
||||
it('clears exposures when an explicit empty property is serialized', () => {
|
||||
const subgraph = createTestSubgraph()
|
||||
const subgraphNode = createTestSubgraphNode(subgraph)
|
||||
const store = usePreviewExposureStore()
|
||||
store.addExposure(subgraphNode.rootGraph.id, String(subgraphNode.id), {
|
||||
sourceNodeId: '12',
|
||||
sourcePreviewName: '$$preview'
|
||||
})
|
||||
|
||||
subgraphNode.configure({
|
||||
...subgraphNode.serialize(),
|
||||
properties: { previewExposures: [] }
|
||||
} as ExportedSubgraphInstance)
|
||||
|
||||
expect(
|
||||
store.getExposures(subgraphNode.rootGraph.id, String(subgraphNode.id))
|
||||
).toEqual([])
|
||||
})
|
||||
|
||||
it('hydrates legacy locator exposures when no explicit property exists', () => {
|
||||
const subgraph = createTestSubgraph()
|
||||
const subgraphNode = createTestSubgraphNode(subgraph)
|
||||
const store = usePreviewExposureStore()
|
||||
const legacyLocator = createNodeLocatorId(null, subgraphNode.id)
|
||||
store.addExposure(subgraphNode.rootGraph.id, legacyLocator, {
|
||||
sourceNodeId: '12',
|
||||
sourcePreviewName: '$$legacy'
|
||||
})
|
||||
|
||||
subgraphNode.configure({
|
||||
...subgraphNode.serialize(),
|
||||
properties: {}
|
||||
} as ExportedSubgraphInstance)
|
||||
|
||||
expect(
|
||||
store.getExposures(subgraphNode.rootGraph.id, String(subgraphNode.id))
|
||||
).toEqual([
|
||||
expect.objectContaining({
|
||||
sourceNodeId: toNodeId(12),
|
||||
sourcePreviewName: '$$legacy'
|
||||
})
|
||||
])
|
||||
})
|
||||
})
|
||||
|
||||
describe('SubgraphNode serialization', () => {
|
||||
it('serializes promoted widget values and valid quarantine entries', () => {
|
||||
const subgraph = createTestSubgraph({
|
||||
inputs: [{ name: 'seed', type: 'INT' }]
|
||||
})
|
||||
const interiorNode = new LGraphNode('Interior')
|
||||
const input = interiorNode.addInput('value', 'INT')
|
||||
input.widget = { name: 'value' }
|
||||
interiorNode.addWidget('number', 'value', 3, () => {})
|
||||
subgraph.add(interiorNode)
|
||||
subgraph.inputNode.slots[0].connect(interiorNode.inputs[0], interiorNode)
|
||||
const subgraphNode = createTestSubgraphNode(subgraph)
|
||||
const widgetId = subgraphNode.inputs[0].widgetId
|
||||
if (!widgetId) throw new Error('Missing widgetId')
|
||||
useWidgetValueStore().setValue(widgetId, 42)
|
||||
subgraphNode.properties.proxyWidgetErrorQuarantine = [
|
||||
{
|
||||
originalEntry: ['-1', 'seed'],
|
||||
reason: 'missingSourceNode',
|
||||
attemptedAtVersion: 1,
|
||||
hostValue: 7
|
||||
}
|
||||
]
|
||||
|
||||
const serialized = subgraphNode.serialize()
|
||||
|
||||
expect(serialized.widgets_values).toEqual([42])
|
||||
expect(serialized.properties?.proxyWidgetErrorQuarantine).toEqual([
|
||||
{
|
||||
originalEntry: ['-1', 'seed'],
|
||||
reason: 'missingSourceNode',
|
||||
attemptedAtVersion: 1,
|
||||
hostValue: 7
|
||||
}
|
||||
])
|
||||
})
|
||||
|
||||
it('uses quarantined host values before serialized widget values', () => {
|
||||
const subgraph = createTestSubgraph({
|
||||
inputs: [{ name: 'seed', type: 'INT' }]
|
||||
})
|
||||
const interiorNode = new LGraphNode('Interior')
|
||||
const input = interiorNode.addInput('value', 'INT')
|
||||
input.widget = { name: 'value' }
|
||||
interiorNode.addWidget('number', 'value', 3, () => {})
|
||||
subgraph.add(interiorNode)
|
||||
subgraph.inputNode.slots[0].connect(interiorNode.inputs[0], interiorNode)
|
||||
const subgraphNode = createTestSubgraphNode(subgraph)
|
||||
const widgetId = subgraphNode.inputs[0].widgetId
|
||||
if (!widgetId) throw new Error('Missing widgetId')
|
||||
|
||||
subgraphNode.configure({
|
||||
...subgraphNode.serialize(),
|
||||
widgets_values: [11],
|
||||
properties: {
|
||||
proxyWidgetErrorQuarantine: [
|
||||
{
|
||||
originalEntry: ['-1', 'seed'],
|
||||
reason: 'missingSourceNode',
|
||||
attemptedAtVersion: 1,
|
||||
hostValue: 55
|
||||
}
|
||||
]
|
||||
}
|
||||
} as ExportedSubgraphInstance)
|
||||
|
||||
expect(useWidgetValueStore().getWidget(widgetId)?.value).toBe(55)
|
||||
})
|
||||
|
||||
it('omits widget values when promoted widget state is non-serializable', () => {
|
||||
const subgraph = createTestSubgraph({
|
||||
inputs: [{ name: 'seed', type: 'INT' }]
|
||||
})
|
||||
const interiorNode = new LGraphNode('Interior')
|
||||
const input = interiorNode.addInput('value', 'INT')
|
||||
input.widget = { name: 'value' }
|
||||
interiorNode.addWidget('number', 'value', 3, () => {})
|
||||
subgraph.add(interiorNode)
|
||||
subgraph.inputNode.slots[0].connect(interiorNode.inputs[0], interiorNode)
|
||||
const subgraphNode = createTestSubgraphNode(subgraph)
|
||||
const widgetId = subgraphNode.inputs[0].widgetId
|
||||
if (!widgetId) throw new Error('Missing widgetId')
|
||||
useWidgetValueStore().getWidget(widgetId)!.value = undefined
|
||||
|
||||
const serialized = subgraphNode.serialize()
|
||||
|
||||
expect(serialized.widgets_values).toBeUndefined()
|
||||
})
|
||||
})
|
||||
|
||||
describe('SubgraphNode Edge Cases', () => {
|
||||
it('should handle deep nesting', () => {
|
||||
// Create a simpler deep nesting test that works with current implementation
|
||||
@@ -951,6 +1505,26 @@ describe('SubgraphNode Cleanup', () => {
|
||||
expect(abortSpy1).toHaveBeenCalledTimes(1)
|
||||
expect(abortSpy2).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('removes promoted widgets even when an input listener is absent', () => {
|
||||
const subgraph = createTestSubgraph({
|
||||
inputs: [{ name: 'input', type: 'number' }]
|
||||
})
|
||||
const subgraphNode = createTestSubgraphNode(subgraph)
|
||||
const onRemove = vi.fn()
|
||||
subgraphNode.inputs[0]._widget = fromPartial<IBaseWidget>({
|
||||
name: 'input',
|
||||
type: 'number',
|
||||
options: {},
|
||||
y: 0,
|
||||
onRemove
|
||||
})
|
||||
delete subgraphNode.inputs[0]._listenerController
|
||||
|
||||
subgraphNode.onRemoved()
|
||||
|
||||
expect(onRemove).toHaveBeenCalledOnce()
|
||||
})
|
||||
})
|
||||
|
||||
describe('SubgraphNode duplicate input pruning (#9977)', () => {
|
||||
@@ -1076,6 +1650,49 @@ describe('Nested SubgraphNode duplicate input prevention', () => {
|
||||
expect(node.inputs).toHaveLength(2)
|
||||
expect(node.inputs.map((i) => i.name)).toEqual(['x', 'y'])
|
||||
})
|
||||
|
||||
it('rebinds duplicate serialized inputs by signature and then by name', () => {
|
||||
const subgraph = createTestSubgraph({
|
||||
inputs: [
|
||||
{ name: 'same', type: 'STRING' },
|
||||
{ name: 'same', type: 'STRING' },
|
||||
{ name: 'loose', type: 'INT' }
|
||||
]
|
||||
})
|
||||
|
||||
const node = new SubgraphNode(
|
||||
subgraph.rootGraph,
|
||||
subgraph,
|
||||
fromPartial<ExportedSubgraphInstance>({
|
||||
id: 1,
|
||||
type: subgraph.id,
|
||||
pos: [0, 0],
|
||||
size: [200, 100],
|
||||
inputs: [
|
||||
{ name: 'same', type: 'STRING', link: null },
|
||||
{ name: 'same', type: 'STRING', link: null },
|
||||
{ name: 'loose', type: 'FLOAT', link: null },
|
||||
{ name: 'missing', type: 'BOOLEAN', link: null }
|
||||
],
|
||||
outputs: [],
|
||||
properties: {},
|
||||
flags: {},
|
||||
mode: 0,
|
||||
order: 0
|
||||
})
|
||||
)
|
||||
|
||||
expect(node.inputs.map((input) => input.name)).toEqual([
|
||||
'same',
|
||||
'same',
|
||||
'loose'
|
||||
])
|
||||
expect(node.inputs.map((input) => input._subgraphSlot)).toEqual([
|
||||
subgraph.inputs[0],
|
||||
subgraph.inputs[1],
|
||||
subgraph.inputs[2]
|
||||
])
|
||||
})
|
||||
})
|
||||
|
||||
describe('SubgraphNode label propagation', () => {
|
||||
|
||||
245
src/lib/litegraph/src/subgraph/SubgraphOutputNode.test.ts
Normal file
245
src/lib/litegraph/src/subgraph/SubgraphOutputNode.test.ts
Normal file
@@ -0,0 +1,245 @@
|
||||
import { fromPartial } from '@total-typescript/shoehorn'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import type { CanvasPointer } from '@/lib/litegraph/src/CanvasPointer'
|
||||
import { LLink } from '@/lib/litegraph/src/LLink'
|
||||
import type {
|
||||
DefaultConnectionColors,
|
||||
INodeOutputSlot
|
||||
} from '@/lib/litegraph/src/interfaces'
|
||||
import { LGraphNode } from '@/lib/litegraph/src/LGraphNode'
|
||||
import type { LinkConnector } from '@/lib/litegraph/src/canvas/LinkConnector'
|
||||
import type { NodeLike } from '@/lib/litegraph/src/types/NodeLike'
|
||||
import type { CanvasPointerEvent } from '@/lib/litegraph/src/types/events'
|
||||
import { toLinkId } from '@/types/linkId'
|
||||
import { toNodeId } from '@/types/nodeId'
|
||||
import { toRerouteId } from '@/types/rerouteId'
|
||||
|
||||
import { createTestSubgraph } from './__fixtures__/subgraphHelpers'
|
||||
|
||||
function eventAt(x: number, y: number, button = 0): CanvasPointerEvent {
|
||||
return { canvasX: x, canvasY: y, button } as CanvasPointerEvent
|
||||
}
|
||||
|
||||
function createCanvasContext() {
|
||||
return {
|
||||
getTransform: vi.fn(() => new DOMMatrix()),
|
||||
translate: vi.fn(),
|
||||
beginPath: vi.fn(),
|
||||
arc: vi.fn(),
|
||||
moveTo: vi.fn(),
|
||||
lineTo: vi.fn(),
|
||||
stroke: vi.fn(),
|
||||
setTransform: vi.fn(),
|
||||
rect: vi.fn(),
|
||||
fill: vi.fn(),
|
||||
fillText: vi.fn(),
|
||||
strokeStyle: '',
|
||||
lineWidth: 1,
|
||||
font: '',
|
||||
fillStyle: '',
|
||||
textBaseline: '',
|
||||
globalAlpha: 1
|
||||
} as unknown as CanvasRenderingContext2D
|
||||
}
|
||||
|
||||
describe('SubgraphOutputNode', () => {
|
||||
beforeEach(() => {
|
||||
vi.restoreAllMocks()
|
||||
})
|
||||
|
||||
it('exposes output slots plus the empty slot and computes its anchor', () => {
|
||||
const subgraph = createTestSubgraph({
|
||||
outputs: [{ name: 'image', type: 'IMAGE' }]
|
||||
})
|
||||
subgraph.outputNode.configure({
|
||||
id: subgraph.outputNode.id,
|
||||
bounding: [10, 20, 100, 80],
|
||||
pinned: false
|
||||
})
|
||||
|
||||
expect(subgraph.outputNode.slots).toBe(subgraph.outputs)
|
||||
expect(subgraph.outputNode.allSlots).toEqual([
|
||||
subgraph.outputs[0],
|
||||
subgraph.outputNode.emptySlot
|
||||
])
|
||||
expect(subgraph.outputNode.slotAnchorX).toBe(24)
|
||||
})
|
||||
|
||||
it('sets link connector drag callbacks for left-clicked slots', () => {
|
||||
const subgraph = createTestSubgraph({
|
||||
outputs: [{ name: 'image', type: 'IMAGE' }]
|
||||
})
|
||||
const slot = subgraph.outputs[0]
|
||||
slot.boundingRect.updateTo([10, 20, 100, 30])
|
||||
const pointer = {} as CanvasPointer
|
||||
const linkConnector = {
|
||||
dragNewFromSubgraphOutput: vi.fn(),
|
||||
dropLinks: vi.fn(),
|
||||
reset: vi.fn()
|
||||
} as unknown as LinkConnector
|
||||
|
||||
subgraph.outputNode.onPointerDown(eventAt(20, 25), pointer, linkConnector)
|
||||
|
||||
pointer.onDragStart?.(pointer)
|
||||
pointer.onDragEnd?.(eventAt(40, 45))
|
||||
pointer.finally?.()
|
||||
|
||||
expect(linkConnector.dragNewFromSubgraphOutput).toHaveBeenCalledWith(
|
||||
subgraph,
|
||||
subgraph.outputNode,
|
||||
slot
|
||||
)
|
||||
expect(linkConnector.dropLinks).toHaveBeenCalledWith(
|
||||
subgraph,
|
||||
expect.objectContaining({ canvasX: 40 })
|
||||
)
|
||||
expect(linkConnector.reset).toHaveBeenCalledWith(true)
|
||||
})
|
||||
|
||||
it('opens the slot context menu for right-clicked slots', () => {
|
||||
const subgraph = createTestSubgraph({
|
||||
outputs: [{ name: 'image', type: 'IMAGE' }]
|
||||
})
|
||||
const slot = subgraph.outputs[0]
|
||||
slot.boundingRect.updateTo([10, 20, 100, 30])
|
||||
const menuSpy = vi.spyOn(
|
||||
subgraph.outputNode as unknown as {
|
||||
showSlotContextMenu(slot: unknown, event: unknown): void
|
||||
},
|
||||
'showSlotContextMenu'
|
||||
)
|
||||
|
||||
subgraph.outputNode.onPointerDown(
|
||||
eventAt(20, 25, 2),
|
||||
{} as CanvasPointer,
|
||||
{} as LinkConnector
|
||||
)
|
||||
subgraph.outputNode.onPointerDown(
|
||||
eventAt(500, 500, 2),
|
||||
{} as CanvasPointer,
|
||||
{} as LinkConnector
|
||||
)
|
||||
|
||||
expect(menuSpy).toHaveBeenCalledOnce()
|
||||
expect(menuSpy).toHaveBeenCalledWith(
|
||||
slot,
|
||||
expect.objectContaining({ button: 2 })
|
||||
)
|
||||
})
|
||||
|
||||
it('renames and removes output slots through the parent subgraph', () => {
|
||||
const subgraph = createTestSubgraph({
|
||||
outputs: [{ name: 'image', type: 'IMAGE' }]
|
||||
})
|
||||
const slot = subgraph.outputs[0]
|
||||
const renameSpy = vi.spyOn(subgraph, 'renameOutput')
|
||||
const removeSpy = vi.spyOn(subgraph, 'removeOutput')
|
||||
|
||||
subgraph.outputNode.renameSlot(slot, 'preview')
|
||||
subgraph.outputNode.removeSlot(slot)
|
||||
|
||||
expect(renameSpy).toHaveBeenCalledWith(slot, 'preview')
|
||||
expect(removeSpy).toHaveBeenCalledWith(slot)
|
||||
})
|
||||
|
||||
it('delegates connection checks and output-type connections', () => {
|
||||
const subgraph = createTestSubgraph({
|
||||
outputs: [{ name: 'image', type: 'IMAGE' }]
|
||||
})
|
||||
const slot = subgraph.outputs[0]
|
||||
const outputSlot = {
|
||||
index: 0,
|
||||
slot: { name: 'out', type: 'IMAGE' }
|
||||
} as unknown as { index: number; slot: INodeOutputSlot }
|
||||
const targetNode = new LGraphNode('Target')
|
||||
targetNode.id = toNodeId(99)
|
||||
vi.spyOn(targetNode, 'findOutputByType').mockReturnValue(outputSlot)
|
||||
const link = new LLink(toLinkId(1), 'IMAGE', toNodeId(1), 0, toNodeId(2), 0)
|
||||
const connectSpy = vi.spyOn(slot, 'connect').mockReturnValue(link)
|
||||
const outputNode = fromPartial<NodeLike>({
|
||||
canConnectTo: vi.fn(() => true)
|
||||
})
|
||||
|
||||
expect(
|
||||
subgraph.outputNode.canConnectTo(outputNode, slot, outputSlot.slot)
|
||||
).toBe(true)
|
||||
expect(
|
||||
subgraph.outputNode.connectByTypeOutput(0, targetNode, 'IMAGE', {
|
||||
afterRerouteId: toRerouteId(7)
|
||||
})
|
||||
).toBe(link)
|
||||
expect(connectSpy).toHaveBeenCalledWith(
|
||||
outputSlot.slot,
|
||||
targetNode,
|
||||
toRerouteId(7)
|
||||
)
|
||||
|
||||
vi.mocked(targetNode.findOutputByType).mockReturnValue(undefined)
|
||||
expect(
|
||||
subgraph.outputNode.connectByTypeOutput(0, targetNode, 'LATENT')
|
||||
).toBeUndefined()
|
||||
})
|
||||
|
||||
it('finds the first free output slot of a matching type', () => {
|
||||
const subgraph = createTestSubgraph({
|
||||
outputs: [
|
||||
{ name: 'used', type: 'IMAGE' },
|
||||
{ name: 'free', type: 'IMAGE' }
|
||||
]
|
||||
})
|
||||
subgraph.outputs[0].linkIds.push(toLinkId(1))
|
||||
|
||||
expect(subgraph.outputNode.findInputByType('IMAGE')).toBe(
|
||||
subgraph.outputs[0]
|
||||
)
|
||||
expect(subgraph.outputNode.findInputByType('LATENT')).toBeUndefined()
|
||||
})
|
||||
|
||||
it('draws the side rail and output slots', () => {
|
||||
const subgraph = createTestSubgraph({
|
||||
outputs: [{ name: 'image', type: 'IMAGE' }]
|
||||
})
|
||||
subgraph.outputNode.configure({
|
||||
id: subgraph.outputNode.id,
|
||||
bounding: [10, 20, 100, 80],
|
||||
pinned: false
|
||||
})
|
||||
const ctx = createCanvasContext()
|
||||
const drawSlotsSpy = vi.spyOn(
|
||||
subgraph.outputNode as unknown as {
|
||||
drawSlots(
|
||||
ctx: unknown,
|
||||
colorContext: unknown,
|
||||
fromSlot: unknown,
|
||||
editorAlpha: unknown
|
||||
): void
|
||||
},
|
||||
'drawSlots'
|
||||
)
|
||||
|
||||
subgraph.outputNode.drawProtected(
|
||||
ctx,
|
||||
{
|
||||
getConnectedColor: vi.fn(() => '#fff'),
|
||||
getDisconnectedColor: vi.fn(() => '#000')
|
||||
} as unknown as DefaultConnectionColors,
|
||||
subgraph.outputs[0],
|
||||
0.5
|
||||
)
|
||||
|
||||
expect(ctx.translate).toHaveBeenCalledWith(10, 20)
|
||||
expect(ctx.beginPath).toHaveBeenCalled()
|
||||
expect(ctx.stroke).toHaveBeenCalled()
|
||||
expect(ctx.setTransform).toHaveBeenCalled()
|
||||
expect(drawSlotsSpy).toHaveBeenCalledWith(
|
||||
ctx,
|
||||
expect.objectContaining({
|
||||
getConnectedColor: expect.any(Function),
|
||||
getDisconnectedColor: expect.any(Function)
|
||||
}),
|
||||
subgraph.outputs[0],
|
||||
0.5
|
||||
)
|
||||
})
|
||||
})
|
||||
168
src/lib/litegraph/src/subgraph/SubgraphOutputSlot.test.ts
Normal file
168
src/lib/litegraph/src/subgraph/SubgraphOutputSlot.test.ts
Normal file
@@ -0,0 +1,168 @@
|
||||
import { fromPartial } from '@total-typescript/shoehorn'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { LLink } from '@/lib/litegraph/src/LLink'
|
||||
import type {
|
||||
INodeInputSlot,
|
||||
INodeOutputSlot
|
||||
} from '@/lib/litegraph/src/interfaces'
|
||||
import { NodeSlotType } from '@/lib/litegraph/src/types/globalEnums'
|
||||
import { LGraphNode } from '@/lib/litegraph/src/LGraphNode'
|
||||
import { toLinkId } from '@/types/linkId'
|
||||
import { toNodeId } from '@/types/nodeId'
|
||||
import { toRerouteId } from '@/types/rerouteId'
|
||||
|
||||
import { createTestSubgraph } from './__fixtures__/subgraphHelpers'
|
||||
|
||||
describe('SubgraphOutput', () => {
|
||||
beforeEach(() => {
|
||||
vi.restoreAllMocks()
|
||||
})
|
||||
|
||||
it('connects node outputs to subgraph outputs', () => {
|
||||
const subgraph = createTestSubgraph({
|
||||
outputs: [{ name: 'preview', type: 'IMAGE' }]
|
||||
})
|
||||
const sourceNode = new LGraphNode('Source')
|
||||
sourceNode.id = toNodeId(10)
|
||||
subgraph.add(sourceNode)
|
||||
const output = sourceNode.addOutput('image', 'IMAGE')
|
||||
const afterChangeSpy = vi.spyOn(subgraph, 'afterChange')
|
||||
const connectionSpy = vi.fn()
|
||||
sourceNode.onConnectionsChange = connectionSpy
|
||||
|
||||
const link = subgraph.outputs[0].connect(output, sourceNode, toRerouteId(5))
|
||||
|
||||
expect(link).toBeInstanceOf(LLink)
|
||||
expect(link?.origin_id).toBe(sourceNode.id)
|
||||
expect(link?.target_id).toBe(subgraph.outputNode.id)
|
||||
expect(link?.parentId).toBe(toRerouteId(5))
|
||||
expect(subgraph.outputs[0].linkIds).toEqual([link?.id])
|
||||
expect(output.links).toEqual([link?.id])
|
||||
expect(subgraph.getLink(link?.id ?? toLinkId(-1))).toBe(link)
|
||||
expect(connectionSpy).toHaveBeenCalledWith(
|
||||
NodeSlotType.OUTPUT,
|
||||
0,
|
||||
true,
|
||||
link,
|
||||
output
|
||||
)
|
||||
expect(afterChangeSpy).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('does not connect incompatible or blocked node outputs', () => {
|
||||
const subgraph = createTestSubgraph({
|
||||
outputs: [{ name: 'preview', type: 'IMAGE' }]
|
||||
})
|
||||
const sourceNode = new LGraphNode('Source')
|
||||
const latentOutput = sourceNode.addOutput('latent', 'LATENT')
|
||||
|
||||
expect(
|
||||
subgraph.outputs[0].connect(latentOutput, sourceNode)
|
||||
).toBeUndefined()
|
||||
|
||||
const imageOutput = sourceNode.addOutput('image', 'IMAGE')
|
||||
sourceNode.onConnectOutput = vi.fn(() => false)
|
||||
|
||||
expect(subgraph.outputs[0].connect(imageOutput, sourceNode)).toBeUndefined()
|
||||
})
|
||||
|
||||
it('throws when the output slot is not owned by the node', () => {
|
||||
const subgraph = createTestSubgraph({
|
||||
outputs: [{ name: 'preview', type: 'IMAGE' }]
|
||||
})
|
||||
const sourceNode = new LGraphNode('Source')
|
||||
const foreignOutput = { name: 'image', type: 'IMAGE' } as INodeOutputSlot
|
||||
|
||||
expect(() =>
|
||||
subgraph.outputs[0].connect(foreignOutput, sourceNode)
|
||||
).toThrow('Slot is not an output of the given node')
|
||||
})
|
||||
|
||||
it('disconnects existing links before accepting a replacement', () => {
|
||||
const subgraph = createTestSubgraph({
|
||||
outputs: [{ name: 'preview', type: 'IMAGE' }]
|
||||
})
|
||||
const firstNode = new LGraphNode('First')
|
||||
firstNode.id = toNodeId(10)
|
||||
subgraph.add(firstNode)
|
||||
const firstOutput = firstNode.addOutput('image', 'IMAGE')
|
||||
const firstLink = subgraph.outputs[0].connect(firstOutput, firstNode)
|
||||
const secondNode = new LGraphNode('Second')
|
||||
secondNode.id = toNodeId(11)
|
||||
subgraph.add(secondNode)
|
||||
const secondOutput = secondNode.addOutput('image', 'IMAGE')
|
||||
const beforeChangeSpy = vi.spyOn(subgraph, 'beforeChange')
|
||||
|
||||
const secondLink = subgraph.outputs[0].connect(secondOutput, secondNode)
|
||||
|
||||
expect(beforeChangeSpy).toHaveBeenCalled()
|
||||
expect(firstOutput.links).not.toContain(firstLink?.id)
|
||||
expect(subgraph.outputs[0].linkIds).toEqual([secondLink?.id])
|
||||
expect(secondOutput.links).toEqual([secondLink?.id])
|
||||
})
|
||||
|
||||
it('arranges and labels from the left edge', () => {
|
||||
const subgraph = createTestSubgraph({
|
||||
outputs: [{ name: 'preview', type: 'IMAGE' }]
|
||||
})
|
||||
const output = subgraph.outputs[0]
|
||||
|
||||
output.arrange([20, 30, 120, 40])
|
||||
|
||||
expect(Array.from(output.boundingRect)).toEqual([20, 30, 120, 40])
|
||||
expect(output.pos).toEqual([40, 50])
|
||||
expect(output.labelPos).toEqual([60, 50])
|
||||
})
|
||||
|
||||
it('validates output slots and subgraph inputs as targets', () => {
|
||||
const subgraph = createTestSubgraph({
|
||||
inputs: [{ name: 'source', type: 'IMAGE' }],
|
||||
outputs: [{ name: 'preview', type: 'IMAGE' }]
|
||||
})
|
||||
const output = subgraph.outputs[0]
|
||||
const imageOutput = fromPartial<INodeOutputSlot>({
|
||||
name: 'image',
|
||||
type: 'IMAGE',
|
||||
links: []
|
||||
})
|
||||
const latentOutput = fromPartial<INodeOutputSlot>({
|
||||
name: 'latent',
|
||||
type: 'LATENT',
|
||||
links: []
|
||||
})
|
||||
const imageInput = { name: 'image', type: 'IMAGE', link: null }
|
||||
|
||||
expect(output.isValidTarget(imageOutput)).toBe(true)
|
||||
expect(output.isValidTarget(latentOutput)).toBe(false)
|
||||
expect(output.isValidTarget(imageInput as INodeInputSlot)).toBe(false)
|
||||
expect(output.isValidTarget(subgraph.inputs[0])).toBe(true)
|
||||
})
|
||||
|
||||
it('disconnects links and notifies output nodes', () => {
|
||||
const subgraph = createTestSubgraph({
|
||||
outputs: [{ name: 'preview', type: 'IMAGE' }]
|
||||
})
|
||||
const sourceNode = new LGraphNode('Source')
|
||||
sourceNode.id = toNodeId(10)
|
||||
subgraph.add(sourceNode)
|
||||
const output = sourceNode.addOutput('image', 'IMAGE')
|
||||
const link = subgraph.outputs[0].connect(output, sourceNode)
|
||||
const removeLinkSpy = vi.spyOn(subgraph, 'removeLink')
|
||||
const connectionSpy = vi.fn()
|
||||
sourceNode.onConnectionsChange = connectionSpy
|
||||
|
||||
subgraph.outputs[0].disconnect()
|
||||
|
||||
expect(removeLinkSpy).toHaveBeenCalledWith(link?.id)
|
||||
expect(output.links).not.toContain(link?.id)
|
||||
expect(connectionSpy).toHaveBeenCalledWith(
|
||||
NodeSlotType.OUTPUT,
|
||||
0,
|
||||
false,
|
||||
link,
|
||||
subgraph.outputs[0]
|
||||
)
|
||||
expect(subgraph.outputs[0].linkIds).toEqual([])
|
||||
})
|
||||
})
|
||||
@@ -2,11 +2,18 @@ import {
|
||||
SUBGRAPH_INPUT_ID,
|
||||
SUBGRAPH_OUTPUT_ID
|
||||
} from '@/lib/litegraph/src/constants'
|
||||
import { describe, expect, it } from 'vitest'
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import type { ExportedSubgraph } from '../types/serialisation'
|
||||
import type { LGraphState } from '@/lib/litegraph/src/LGraph'
|
||||
import { toLinkId } from '@/types/linkId'
|
||||
import { toRerouteId } from '@/types/rerouteId'
|
||||
|
||||
import { topologicalSortSubgraphs } from './subgraphDeduplication'
|
||||
import type { ExportedSubgraph, ISerialisedNode } from '../types/serialisation'
|
||||
|
||||
import {
|
||||
deduplicateSubgraphNodeIds,
|
||||
topologicalSortSubgraphs
|
||||
} from './subgraphDeduplication'
|
||||
|
||||
function makeSubgraph(id: string, nodeTypes: string[] = []): ExportedSubgraph {
|
||||
return {
|
||||
@@ -32,6 +39,196 @@ function makeSubgraph(id: string, nodeTypes: string[] = []): ExportedSubgraph {
|
||||
} as ExportedSubgraph
|
||||
}
|
||||
|
||||
describe('deduplicateSubgraphNodeIds', () => {
|
||||
it('remaps duplicate IDs in nodes, links, promoted widgets, and root proxy widgets', () => {
|
||||
const warn = vi.spyOn(console, 'warn').mockImplementation(() => {})
|
||||
const subgraph = makeSubgraph('inner')
|
||||
subgraph.nodes = [
|
||||
{
|
||||
id: 1,
|
||||
type: 'Source',
|
||||
pos: [0, 0],
|
||||
size: [100, 100],
|
||||
flags: {},
|
||||
order: 0,
|
||||
mode: 0,
|
||||
inputs: [],
|
||||
outputs: [],
|
||||
properties: {}
|
||||
},
|
||||
{
|
||||
id: 2,
|
||||
type: 'Target',
|
||||
pos: [0, 0],
|
||||
size: [100, 100],
|
||||
flags: {},
|
||||
order: 1,
|
||||
mode: 0,
|
||||
inputs: [],
|
||||
outputs: [],
|
||||
properties: {}
|
||||
}
|
||||
]
|
||||
subgraph.links = [
|
||||
{
|
||||
id: 1,
|
||||
origin_id: 1,
|
||||
origin_slot: 0,
|
||||
target_id: 2,
|
||||
target_slot: 0,
|
||||
type: '*'
|
||||
}
|
||||
]
|
||||
subgraph.widgets = [
|
||||
{
|
||||
id: 1,
|
||||
name: 'text'
|
||||
}
|
||||
]
|
||||
const rootNodes: ISerialisedNode[] = [
|
||||
{
|
||||
id: 10,
|
||||
type: 'inner',
|
||||
pos: [0, 0],
|
||||
size: [100, 100],
|
||||
flags: {},
|
||||
order: 0,
|
||||
mode: 0,
|
||||
inputs: [],
|
||||
outputs: [],
|
||||
properties: {
|
||||
proxyWidgets: [[1, 'text'], 'not-an-entry']
|
||||
}
|
||||
},
|
||||
{
|
||||
id: 11,
|
||||
type: 'Other',
|
||||
pos: [0, 0],
|
||||
size: [100, 100],
|
||||
flags: {},
|
||||
order: 1,
|
||||
mode: 0,
|
||||
inputs: [],
|
||||
outputs: [],
|
||||
properties: {
|
||||
proxyWidgets: [[1, 'text']]
|
||||
}
|
||||
}
|
||||
]
|
||||
const state: LGraphState = {
|
||||
lastNodeId: 2,
|
||||
lastLinkId: toLinkId(0),
|
||||
lastGroupId: 0,
|
||||
lastRerouteId: toRerouteId(0)
|
||||
}
|
||||
|
||||
const result = deduplicateSubgraphNodeIds(
|
||||
[subgraph],
|
||||
new Set([1]),
|
||||
state,
|
||||
rootNodes
|
||||
)
|
||||
|
||||
expect(result.subgraphs[0].nodes?.[0].id).toBe(3)
|
||||
expect(result.subgraphs[0].links?.[0]).toMatchObject({
|
||||
origin_id: 3,
|
||||
target_id: 2
|
||||
})
|
||||
expect(result.subgraphs[0].widgets?.[0].id).toBe(3)
|
||||
expect(result.rootNodes?.[0].properties?.proxyWidgets).toEqual([
|
||||
['3', 'text'],
|
||||
'not-an-entry'
|
||||
])
|
||||
expect(result.rootNodes?.[1].properties?.proxyWidgets).toEqual([
|
||||
[1, 'text']
|
||||
])
|
||||
expect(subgraph.nodes?.[0].id).toBe(1)
|
||||
expect(rootNodes[0].properties?.proxyWidgets).toEqual([
|
||||
[1, 'text'],
|
||||
'not-an-entry'
|
||||
])
|
||||
expect(state.lastNodeId).toBe(3)
|
||||
expect(warn).toHaveBeenCalledWith(
|
||||
'LiteGraph: duplicate subgraph node ID 1 remapped to 3'
|
||||
)
|
||||
|
||||
warn.mockRestore()
|
||||
})
|
||||
|
||||
it('tracks numeric IDs without root nodes and ignores non-numeric IDs', () => {
|
||||
const subgraph = makeSubgraph('ids')
|
||||
subgraph.nodes = [
|
||||
{
|
||||
id: '9',
|
||||
type: 'NumericString',
|
||||
pos: [0, 0],
|
||||
size: [100, 100],
|
||||
flags: {},
|
||||
order: 0,
|
||||
mode: 0,
|
||||
inputs: [],
|
||||
outputs: [],
|
||||
properties: {}
|
||||
},
|
||||
{
|
||||
id: 'alpha',
|
||||
type: 'NamedNode',
|
||||
pos: [0, 0],
|
||||
size: [100, 100],
|
||||
flags: {},
|
||||
order: 1,
|
||||
mode: 0,
|
||||
inputs: [],
|
||||
outputs: [],
|
||||
properties: {}
|
||||
}
|
||||
]
|
||||
const state: LGraphState = {
|
||||
lastNodeId: 1,
|
||||
lastLinkId: toLinkId(0),
|
||||
lastGroupId: 0,
|
||||
lastRerouteId: toRerouteId(0)
|
||||
}
|
||||
|
||||
const result = deduplicateSubgraphNodeIds([subgraph], new Set(), state)
|
||||
|
||||
expect(result.rootNodes).toBeUndefined()
|
||||
expect(result.subgraphs[0].nodes?.map((node) => node.id)).toEqual([
|
||||
'9',
|
||||
'alpha'
|
||||
])
|
||||
expect(state.lastNodeId).toBe(9)
|
||||
})
|
||||
|
||||
it('throws when the numeric node ID space is exhausted', () => {
|
||||
const subgraph = makeSubgraph('full')
|
||||
subgraph.nodes = [
|
||||
{
|
||||
id: 1,
|
||||
type: 'Duplicate',
|
||||
pos: [0, 0],
|
||||
size: [100, 100],
|
||||
flags: {},
|
||||
order: 0,
|
||||
mode: 0,
|
||||
inputs: [],
|
||||
outputs: [],
|
||||
properties: {}
|
||||
}
|
||||
]
|
||||
const state: LGraphState = {
|
||||
lastNodeId: 100_000_000,
|
||||
lastLinkId: toLinkId(0),
|
||||
lastGroupId: 0,
|
||||
lastRerouteId: toRerouteId(0)
|
||||
}
|
||||
|
||||
expect(() =>
|
||||
deduplicateSubgraphNodeIds([subgraph], new Set([1]), state)
|
||||
).toThrow('Node ID space exhausted')
|
||||
})
|
||||
})
|
||||
|
||||
describe('topologicalSortSubgraphs', () => {
|
||||
it('returns original order when there are no dependencies', () => {
|
||||
const a = makeSubgraph('a')
|
||||
@@ -77,4 +274,11 @@ describe('topologicalSortSubgraphs', () => {
|
||||
it('returns original order for empty array', () => {
|
||||
expect(topologicalSortSubgraphs([])).toEqual([])
|
||||
})
|
||||
|
||||
it('returns original order when dependencies contain a cycle', () => {
|
||||
const a = makeSubgraph('a', ['b'])
|
||||
const b = makeSubgraph('b', ['a'])
|
||||
|
||||
expect(topologicalSortSubgraphs([a, b])).toEqual([a, b])
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,26 +1,57 @@
|
||||
import { fromPartial } from '@total-typescript/shoehorn'
|
||||
import { createTestingPinia } from '@pinia/testing'
|
||||
import { setActivePinia } from 'pinia'
|
||||
import { beforeEach, describe, expect, it } from 'vitest'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import {
|
||||
LGraph,
|
||||
LGraphGroup,
|
||||
findUsedSubgraphIds,
|
||||
getDirectSubgraphIds
|
||||
getDirectSubgraphIds,
|
||||
LGraphNode,
|
||||
LLink,
|
||||
Reroute
|
||||
} from '@/lib/litegraph/src/litegraph'
|
||||
import type { ResolvedConnection } from '@/lib/litegraph/src/LLink'
|
||||
import type { Positionable } from '@/lib/litegraph/src/interfaces'
|
||||
import type { UUID } from '@/lib/litegraph/src/litegraph'
|
||||
import type { SerialisableLLink } from '@/lib/litegraph/src/types/serialisation'
|
||||
import { SUBGRAPH_INPUT_ID } from '@/lib/litegraph/src/constants'
|
||||
import { toLinkId } from '@/types/linkId'
|
||||
import { toRerouteId } from '@/types/rerouteId'
|
||||
|
||||
import {
|
||||
createTestSubgraph,
|
||||
createTestSubgraphNode,
|
||||
resetSubgraphFixtureState
|
||||
} from './__fixtures__/subgraphHelpers'
|
||||
import {
|
||||
getBoundaryLinks,
|
||||
groupResolvedByOutput,
|
||||
isNodeSlot,
|
||||
isSubgraphInput,
|
||||
isSubgraphOutput,
|
||||
mapSubgraphInputsAndLinks,
|
||||
mapSubgraphOutputsAndLinks,
|
||||
multiClone,
|
||||
reorderSubgraphInputs,
|
||||
splitPositionables
|
||||
} from './subgraphUtils'
|
||||
|
||||
describe('subgraphUtils', () => {
|
||||
beforeEach(() => {
|
||||
setActivePinia(createTestingPinia({ stubActions: false }))
|
||||
resetSubgraphFixtureState()
|
||||
vi.restoreAllMocks()
|
||||
})
|
||||
|
||||
function makeNode(title: string): LGraphNode {
|
||||
const node = new LGraphNode(title)
|
||||
node.addInput('in', 'STRING')
|
||||
node.addOutput('out', 'STRING')
|
||||
return node
|
||||
}
|
||||
|
||||
describe('getDirectSubgraphIds', () => {
|
||||
it('should return empty set for graph with no subgraph nodes', () => {
|
||||
const graph = new LGraph()
|
||||
@@ -144,5 +175,446 @@ describe('subgraphUtils', () => {
|
||||
expect(result.has(subgraph1.id)).toBe(true)
|
||||
expect(result.has(subgraph2.id)).toBe(true) // Still found, just can't recurse into it
|
||||
})
|
||||
|
||||
it('does not revisit subgraphs that were already discovered', () => {
|
||||
const rootGraph = new LGraph()
|
||||
const shared = createTestSubgraph({ name: 'Shared' })
|
||||
const nestedParent = createTestSubgraph({ name: 'Nested parent' })
|
||||
rootGraph.add(createTestSubgraphNode(shared))
|
||||
rootGraph.add(createTestSubgraphNode(nestedParent))
|
||||
nestedParent.add(createTestSubgraphNode(shared))
|
||||
|
||||
const result = findUsedSubgraphIds(
|
||||
rootGraph,
|
||||
new Map([
|
||||
[shared.id, shared],
|
||||
[nestedParent.id, nestedParent]
|
||||
])
|
||||
)
|
||||
|
||||
expect([...result]).toEqual([shared.id, nestedParent.id])
|
||||
})
|
||||
})
|
||||
|
||||
describe('splitPositionables', () => {
|
||||
it('places each known positionable type into its own set', () => {
|
||||
const subgraph = createTestSubgraph({ inputCount: 1, outputCount: 1 })
|
||||
const node = new LGraphNode('Node')
|
||||
const group = new LGraphGroup('Group')
|
||||
const reroute = new Reroute(toRerouteId(1), new LGraph())
|
||||
const unknown = fromPartial<Positionable>({ boundingRect: [0, 0, 1, 1] })
|
||||
|
||||
const result = splitPositionables([
|
||||
node,
|
||||
group,
|
||||
reroute,
|
||||
subgraph.inputNode,
|
||||
subgraph.outputNode,
|
||||
unknown
|
||||
])
|
||||
|
||||
expect(result.nodes.has(node)).toBe(true)
|
||||
expect(result.groups.has(group)).toBe(true)
|
||||
expect(result.reroutes.has(reroute)).toBe(true)
|
||||
expect(result.subgraphInputNodes.has(subgraph.inputNode)).toBe(true)
|
||||
expect(result.subgraphOutputNodes.has(subgraph.outputNode)).toBe(true)
|
||||
expect(result.unknown.has(unknown)).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
describe('getBoundaryLinks', () => {
|
||||
it('classifies selected node links by internal and boundary direction', () => {
|
||||
const graph = new LGraph()
|
||||
const source = makeNode('Source')
|
||||
const selected = makeNode('Selected')
|
||||
const selectedTarget = makeNode('Selected target')
|
||||
const externalTarget = makeNode('External target')
|
||||
graph.add(source)
|
||||
graph.add(selected)
|
||||
graph.add(selectedTarget)
|
||||
graph.add(externalTarget)
|
||||
|
||||
const boundaryInput = source.connect(0, selected, 0)!
|
||||
const internal = selected.connect(0, selectedTarget, 0)!
|
||||
const boundaryOutput = selected.connect(0, externalTarget, 0)!
|
||||
|
||||
const result = getBoundaryLinks(
|
||||
graph,
|
||||
new Set([selected, selectedTarget])
|
||||
)
|
||||
|
||||
expect(result.boundaryInputLinks).toEqual([boundaryInput])
|
||||
expect(result.internalLinks).toEqual([internal])
|
||||
expect(result.boundaryOutputLinks).toEqual([boundaryOutput])
|
||||
expect(result.boundaryLinks).toEqual([])
|
||||
expect(result.boundaryFloatingLinks).toEqual([])
|
||||
})
|
||||
|
||||
it('ignores unresolved input links and warns with the missing id', () => {
|
||||
const graph = new LGraph()
|
||||
const node = makeNode('Node')
|
||||
graph.add(node)
|
||||
node.inputs[0].link = toLinkId(404)
|
||||
const warn = vi.spyOn(console, 'warn').mockImplementation(() => {})
|
||||
|
||||
const result = getBoundaryLinks(graph, new Set([node]))
|
||||
|
||||
expect(result.internalLinks).toEqual([])
|
||||
expect(warn).toHaveBeenCalledWith('Failed to resolve link ID [404]')
|
||||
})
|
||||
|
||||
it('treats reroutes with outside participants as boundary links', () => {
|
||||
const graph = new LGraph()
|
||||
const source = makeNode('Source')
|
||||
const target = makeNode('Target')
|
||||
graph.add(source)
|
||||
graph.add(target)
|
||||
const link = source.connect(0, target, 0)!
|
||||
const reroute = new Reroute(toRerouteId(1), graph, [10, 10], undefined, [
|
||||
link.id
|
||||
])
|
||||
link.parentId = reroute.id
|
||||
graph.reroutes.set(reroute.id, reroute)
|
||||
|
||||
const result = getBoundaryLinks(graph, new Set([reroute]))
|
||||
|
||||
expect(result.boundaryLinks).toEqual([link])
|
||||
})
|
||||
|
||||
it('handles unlinked nodes, groups, subgraph-input links, and floating links', () => {
|
||||
const graph = new LGraph()
|
||||
const selected = makeNode('Selected')
|
||||
const group = new LGraphGroup('Group')
|
||||
graph.add(selected)
|
||||
const subgraphInputLink = new LLink(
|
||||
toLinkId(80),
|
||||
'STRING',
|
||||
SUBGRAPH_INPUT_ID,
|
||||
0,
|
||||
selected.id,
|
||||
0
|
||||
)
|
||||
graph.links.set(subgraphInputLink.id, subgraphInputLink)
|
||||
selected.inputs[0].link = subgraphInputLink.id
|
||||
const floatingLink = new LLink(toLinkId(81), 'STRING', 1, 0, 2, 0)
|
||||
const outsideReroute = new Reroute(toRerouteId(8), graph, [0, 0])
|
||||
floatingLink.parentId = outsideReroute.id
|
||||
graph.reroutes.set(outsideReroute.id, outsideReroute)
|
||||
selected.outputs[0]._floatingLinks = new Set([floatingLink])
|
||||
|
||||
const result = getBoundaryLinks(graph, new Set([selected, group]))
|
||||
|
||||
expect(result.boundaryInputLinks).toEqual([subgraphInputLink])
|
||||
expect(result.boundaryFloatingLinks).toEqual([floatingLink])
|
||||
expect(result.boundaryOutputLinks).toEqual([])
|
||||
})
|
||||
})
|
||||
|
||||
describe('multiClone', () => {
|
||||
it('falls back to cloned serialized data when a node type cannot be created', () => {
|
||||
const warn = vi.spyOn(console, 'warn').mockImplementation(() => {})
|
||||
const node = new LGraphNode('Fallback')
|
||||
node.type = 'missing/type'
|
||||
node.properties = { nested: { value: 1 } }
|
||||
|
||||
const result = multiClone([node])
|
||||
|
||||
expect(result).toHaveLength(1)
|
||||
expect(result[0]).toMatchObject({ type: 'missing/type' })
|
||||
expect(result[0].properties).toEqual({ nested: { value: 1 } })
|
||||
expect(result[0].properties).not.toBe(node.serialize().properties)
|
||||
expect(warn).toHaveBeenCalledWith('Failed to create node', 'missing/type')
|
||||
})
|
||||
})
|
||||
|
||||
describe('groupResolvedByOutput', () => {
|
||||
it('groups connections by subgraph input before regular output', () => {
|
||||
const subgraph = createTestSubgraph({ inputCount: 1 })
|
||||
const output = { name: 'out' }
|
||||
const first = {
|
||||
subgraphInput: subgraph.inputs[0],
|
||||
output,
|
||||
link: new LLink(toLinkId(1), 'STRING', 1, 0, 2, 0)
|
||||
} as ResolvedConnection
|
||||
const second = {
|
||||
subgraphInput: subgraph.inputs[0],
|
||||
link: new LLink(toLinkId(2), 'STRING', 1, 0, 3, 0)
|
||||
} as ResolvedConnection
|
||||
|
||||
const result = groupResolvedByOutput([first, second])
|
||||
|
||||
expect(result.get(subgraph.inputs[0])).toEqual([first, second])
|
||||
expect(result.has(output)).toBe(false)
|
||||
})
|
||||
|
||||
it('keeps unresolved output connections in separate groups', () => {
|
||||
const first = {
|
||||
link: new LLink(toLinkId(1), 'STRING', 1, 0, 2, 0)
|
||||
} as ResolvedConnection
|
||||
const second = {
|
||||
link: new LLink(toLinkId(2), 'STRING', 1, 0, 3, 0)
|
||||
} as ResolvedConnection
|
||||
|
||||
const result = groupResolvedByOutput([first, second])
|
||||
|
||||
expect(result.size).toBe(2)
|
||||
expect([...result.values()]).toEqual([[first], [second]])
|
||||
})
|
||||
})
|
||||
|
||||
describe('mapSubgraphInputsAndLinks', () => {
|
||||
it('creates unique input metadata and rewrites link origins', () => {
|
||||
const targetInput = makeNode('Target').inputs[0]
|
||||
targetInput.localized_name = 'Prompt'
|
||||
targetInput.label = 'Prompt label'
|
||||
const link = new LLink(toLinkId(1), 'STRING', 10, 0, 20, 0)
|
||||
const connection = {
|
||||
link,
|
||||
input: targetInput
|
||||
} as ResolvedConnection
|
||||
const links: SerialisableLLink[] = []
|
||||
|
||||
const inputs = mapSubgraphInputsAndLinks([connection], links, new Map())
|
||||
|
||||
expect(inputs).toHaveLength(1)
|
||||
expect(inputs[0]).toMatchObject({
|
||||
name: 'in',
|
||||
localized_name: 'Prompt',
|
||||
label: 'Prompt label',
|
||||
type: 'STRING',
|
||||
linkIds: [toLinkId(1)]
|
||||
})
|
||||
expect(links[0]).toMatchObject({
|
||||
origin_id: '-10',
|
||||
origin_slot: 0,
|
||||
target_id: 20,
|
||||
target_slot: 0
|
||||
})
|
||||
})
|
||||
|
||||
it('restores the original link parent while mapping reroutes', () => {
|
||||
const targetInput = makeNode('Target').inputs[0]
|
||||
const link = new LLink(
|
||||
toLinkId(1),
|
||||
'STRING',
|
||||
10,
|
||||
0,
|
||||
20,
|
||||
0,
|
||||
toRerouteId(2)
|
||||
)
|
||||
const first = new Reroute(
|
||||
toRerouteId(1),
|
||||
new LGraph(),
|
||||
undefined,
|
||||
toRerouteId(99)
|
||||
)
|
||||
const second = new Reroute(
|
||||
toRerouteId(2),
|
||||
new LGraph(),
|
||||
undefined,
|
||||
first.id
|
||||
)
|
||||
const links: SerialisableLLink[] = []
|
||||
|
||||
mapSubgraphInputsAndLinks(
|
||||
[{ link, input: targetInput } as ResolvedConnection],
|
||||
links,
|
||||
new Map([
|
||||
[first.id, first],
|
||||
[second.id, second]
|
||||
])
|
||||
)
|
||||
|
||||
expect(link.parentId).toBe(toRerouteId(99))
|
||||
expect(links[0].parentId).toBe(second.id)
|
||||
expect(first.parentId).toBeUndefined()
|
||||
expect(second.parentId).toBe(first.id)
|
||||
})
|
||||
|
||||
it('skips unresolved input connections and uniquifies duplicate names', () => {
|
||||
const firstInput = makeNode('First').inputs[0]
|
||||
firstInput.localized_name = 'Prompt'
|
||||
const secondInput = makeNode('Second').inputs[0]
|
||||
secondInput.localized_name = 'Prompt'
|
||||
const links: SerialisableLLink[] = []
|
||||
|
||||
const inputs = mapSubgraphInputsAndLinks(
|
||||
[
|
||||
{ link: new LLink(toLinkId(1), 'STRING', 1, 0, 2, 0) },
|
||||
{
|
||||
link: new LLink(toLinkId(2), 'STRING', 1, 0, 3, 0),
|
||||
input: firstInput
|
||||
},
|
||||
{
|
||||
link: new LLink(toLinkId(3), 'STRING', 1, 0, 4, 0),
|
||||
input: secondInput
|
||||
}
|
||||
] as ResolvedConnection[],
|
||||
links,
|
||||
new Map()
|
||||
)
|
||||
|
||||
expect(inputs.map((input) => input.name)).toEqual(['in', 'in_1'])
|
||||
expect(inputs.map((input) => input.localized_name)).toEqual([
|
||||
'Prompt',
|
||||
'Prompt_1'
|
||||
])
|
||||
expect(links.map((link) => link.id)).toEqual([toLinkId(2), toLinkId(3)])
|
||||
})
|
||||
})
|
||||
|
||||
describe('mapSubgraphOutputsAndLinks', () => {
|
||||
it('creates unique output metadata and rewrites link targets', () => {
|
||||
const output = makeNode('Source').outputs[0]
|
||||
output.type = 'IMAGE'
|
||||
output.localized_name = 'Image'
|
||||
output.label = 'Image label'
|
||||
const link = new LLink(toLinkId(1), 'IMAGE', 10, 0, 20, 0)
|
||||
const links: SerialisableLLink[] = []
|
||||
|
||||
const outputs = mapSubgraphOutputsAndLinks(
|
||||
[{ link, output } as ResolvedConnection],
|
||||
links,
|
||||
new Map()
|
||||
)
|
||||
|
||||
expect(outputs).toHaveLength(1)
|
||||
expect(outputs[0]).toMatchObject({
|
||||
name: 'out',
|
||||
localized_name: 'Image',
|
||||
label: 'Image label',
|
||||
type: 'IMAGE',
|
||||
linkIds: [toLinkId(1)]
|
||||
})
|
||||
expect(links[0]).toMatchObject({
|
||||
origin_id: 10,
|
||||
origin_slot: 0,
|
||||
target_id: '-20',
|
||||
target_slot: 0
|
||||
})
|
||||
})
|
||||
|
||||
it('skips unresolved output connections and uniquifies duplicate names', () => {
|
||||
const firstOutput = makeNode('First').outputs[0]
|
||||
firstOutput.localized_name = 'Image'
|
||||
const secondOutput = makeNode('Second').outputs[0]
|
||||
secondOutput.localized_name = 'Image'
|
||||
const links: SerialisableLLink[] = []
|
||||
|
||||
const outputs = mapSubgraphOutputsAndLinks(
|
||||
[
|
||||
{ link: new LLink(toLinkId(1), 'IMAGE', 1, 0, 2, 0) },
|
||||
{
|
||||
link: new LLink(toLinkId(2), 'IMAGE', 1, 0, 3, 0),
|
||||
output: firstOutput
|
||||
},
|
||||
{
|
||||
link: new LLink(toLinkId(3), 'IMAGE', 1, 0, 4, 0),
|
||||
output: secondOutput
|
||||
}
|
||||
] as ResolvedConnection[],
|
||||
links,
|
||||
new Map()
|
||||
)
|
||||
|
||||
expect(outputs.map((output) => output.name)).toEqual(['out', 'out_1'])
|
||||
expect(outputs.map((output) => output.localized_name)).toEqual([
|
||||
'Image',
|
||||
'Image_1'
|
||||
])
|
||||
expect(links.map((link) => link.id)).toEqual([toLinkId(2), toLinkId(3)])
|
||||
})
|
||||
})
|
||||
|
||||
describe('reorderSubgraphInputs', () => {
|
||||
it('returns when the host has no subgraph', () => {
|
||||
expect(() =>
|
||||
reorderSubgraphInputs(
|
||||
{ subgraph: null } as unknown as Parameters<
|
||||
typeof reorderSubgraphInputs
|
||||
>[0],
|
||||
[]
|
||||
)
|
||||
).not.toThrow()
|
||||
})
|
||||
|
||||
it('logs and leaves inputs unchanged for invalid permutations', () => {
|
||||
const subgraph = createTestSubgraph({
|
||||
inputs: [
|
||||
{ name: 'first', type: 'STRING' },
|
||||
{ name: 'second', type: 'STRING' }
|
||||
]
|
||||
})
|
||||
const host = createTestSubgraphNode(subgraph)
|
||||
const error = vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
|
||||
reorderSubgraphInputs(host, [1, 1])
|
||||
|
||||
expect(subgraph.inputs.map((input) => input.name)).toEqual([
|
||||
'first',
|
||||
'second'
|
||||
])
|
||||
expect(error).toHaveBeenCalledWith(
|
||||
'reorderSubgraphInputs: orderedIndices must be a permutation of 0..1',
|
||||
[1, 1]
|
||||
)
|
||||
})
|
||||
|
||||
it('dispatches reorder details when the input order changes', () => {
|
||||
const subgraph = createTestSubgraph({
|
||||
inputs: [
|
||||
{ name: 'first', type: 'STRING' },
|
||||
{ name: 'second', type: 'STRING' }
|
||||
]
|
||||
})
|
||||
const host = createTestSubgraphNode(subgraph)
|
||||
const dispatch = vi.spyOn(subgraph.events, 'dispatch')
|
||||
|
||||
reorderSubgraphInputs(host, [1, 0])
|
||||
|
||||
expect(dispatch).toHaveBeenCalledWith('inputs-reordered', {
|
||||
subgraph,
|
||||
oldOrder: expect.any(Array),
|
||||
newOrder: expect.any(Array)
|
||||
})
|
||||
expect(subgraph.inputs.map((input) => input.name)).toEqual([
|
||||
'second',
|
||||
'first'
|
||||
])
|
||||
})
|
||||
|
||||
it('does not dispatch when the input order is unchanged', () => {
|
||||
const subgraph = createTestSubgraph({
|
||||
inputs: [
|
||||
{ name: 'first', type: 'STRING' },
|
||||
{ name: 'second', type: 'STRING' }
|
||||
]
|
||||
})
|
||||
const host = createTestSubgraphNode(subgraph)
|
||||
subgraph.inputs[0].linkIds.push(toLinkId(404))
|
||||
host.inputs[0].link = toLinkId(405)
|
||||
const dispatch = vi.spyOn(subgraph.events, 'dispatch')
|
||||
|
||||
reorderSubgraphInputs(host, [0, 1])
|
||||
|
||||
expect(dispatch).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('slot type guards', () => {
|
||||
it('identifies subgraph slots and node slots', () => {
|
||||
const subgraph = createTestSubgraph({ inputCount: 1, outputCount: 1 })
|
||||
const node = makeNode('Node')
|
||||
|
||||
expect(isSubgraphInput(subgraph.inputs[0])).toBe(true)
|
||||
expect(isSubgraphInput(subgraph.outputs[0])).toBe(false)
|
||||
expect(isSubgraphOutput(subgraph.outputs[0])).toBe(true)
|
||||
expect(isSubgraphOutput(node.outputs[0])).toBe(false)
|
||||
expect(isNodeSlot(node.inputs[0])).toBe(true)
|
||||
expect(isNodeSlot(node.outputs[0])).toBe(true)
|
||||
expect(isNodeSlot(null)).toBe(false)
|
||||
expect(isNodeSlot({})).toBe(false)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
116
src/lib/litegraph/src/utils/arrange.test.ts
Normal file
116
src/lib/litegraph/src/utils/arrange.test.ts
Normal file
@@ -0,0 +1,116 @@
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import type { LGraphNode } from '../LGraphNode'
|
||||
import { alignNodes, distributeNodes, getBoundaryNodes } from './arrange'
|
||||
|
||||
type ArrangeNode = LGraphNode & { title: string }
|
||||
|
||||
function nodeFixture(
|
||||
title: string,
|
||||
pos: [number, number],
|
||||
size: [number, number]
|
||||
): ArrangeNode {
|
||||
const graphNode = {
|
||||
title,
|
||||
pos,
|
||||
size,
|
||||
setPos: vi.fn((x: number, y: number) => {
|
||||
graphNode.pos = [x, y]
|
||||
})
|
||||
}
|
||||
return graphNode as unknown as ArrangeNode
|
||||
}
|
||||
|
||||
describe('arrange utilities', () => {
|
||||
it('returns null when no boundary node is available', () => {
|
||||
expect(getBoundaryNodes([])).toBeNull()
|
||||
expect(getBoundaryNodes(undefined as unknown as LGraphNode[])).toBeNull()
|
||||
})
|
||||
|
||||
it('finds the furthest node in each direction', () => {
|
||||
const top = nodeFixture('top', [10, -10], [20, 20])
|
||||
const right = nodeFixture('right', [100, 0], [50, 20])
|
||||
const bottom = nodeFixture('bottom', [0, 80], [20, 60])
|
||||
const left = nodeFixture('left', [-20, 0], [10, 10])
|
||||
|
||||
expect(getBoundaryNodes([top, right, bottom, left])).toEqual({
|
||||
top,
|
||||
right,
|
||||
bottom,
|
||||
left
|
||||
})
|
||||
})
|
||||
|
||||
it('does not distribute zero or one node', () => {
|
||||
expect(distributeNodes([])).toEqual([])
|
||||
expect(distributeNodes([nodeFixture('single', [0, 0], [10, 10])])).toEqual(
|
||||
[]
|
||||
)
|
||||
})
|
||||
|
||||
it('distributes nodes horizontally by sorted position', () => {
|
||||
const first = nodeFixture('first', [0, 10], [10, 10])
|
||||
const middle = nodeFixture('middle', [30, 20], [10, 10])
|
||||
const last = nodeFixture('last', [60, 30], [20, 10])
|
||||
|
||||
const result = distributeNodes([last, first, middle], true)
|
||||
|
||||
expect(result.map(({ node: resultNode }) => resultNode.title)).toEqual([
|
||||
'first',
|
||||
'middle',
|
||||
'last'
|
||||
])
|
||||
expect(first.pos).toEqual([0, 10])
|
||||
expect(middle.pos).toEqual([30, 20])
|
||||
expect(last.pos).toEqual([60, 30])
|
||||
})
|
||||
|
||||
it('distributes nodes vertically by sorted position', () => {
|
||||
const first = nodeFixture('first', [10, 0], [10, 10])
|
||||
const middle = nodeFixture('middle', [20, 30], [10, 10])
|
||||
const last = nodeFixture('last', [30, 60], [10, 20])
|
||||
|
||||
distributeNodes([last, first, middle])
|
||||
|
||||
expect(first.pos).toEqual([10, 0])
|
||||
expect(middle.pos).toEqual([20, 30])
|
||||
expect(last.pos).toEqual([30, 60])
|
||||
})
|
||||
|
||||
it('aligns nodes to each boundary edge', () => {
|
||||
const nodesForAlign = () => [
|
||||
nodeFixture('top', [10, 0], [10, 10]),
|
||||
nodeFixture('right', [40, 10], [30, 10]),
|
||||
nodeFixture('bottom', [20, 50], [10, 30]),
|
||||
nodeFixture('left', [-10, 20], [10, 10])
|
||||
]
|
||||
|
||||
expect(
|
||||
alignNodes(nodesForAlign(), 'left').map(({ newPos }) => newPos.x)
|
||||
).toEqual([-10, -10, -10, -10])
|
||||
expect(
|
||||
alignNodes(nodesForAlign(), 'right').map(({ newPos }) => newPos.x)
|
||||
).toEqual([60, 40, 60, 60])
|
||||
expect(
|
||||
alignNodes(nodesForAlign(), 'top').map(({ newPos }) => newPos.y)
|
||||
).toEqual([0, 0, 0, 0])
|
||||
expect(
|
||||
alignNodes(nodesForAlign(), 'bottom').map(({ newPos }) => newPos.y)
|
||||
).toEqual([70, 70, 50, 70])
|
||||
})
|
||||
|
||||
it('aligns to an explicit node when provided', () => {
|
||||
const anchor = nodeFixture('anchor', [100, 200], [50, 60])
|
||||
const target = nodeFixture('target', [0, 0], [10, 20])
|
||||
|
||||
const result = alignNodes([target], 'bottom', anchor)
|
||||
|
||||
expect(result[0].newPos).toEqual({ x: 0, y: 240 })
|
||||
expect(target.setPos).toHaveBeenCalledWith(0, 240)
|
||||
})
|
||||
|
||||
it('returns no positions when alignment has no usable nodes', () => {
|
||||
expect(alignNodes([], 'left')).toEqual([])
|
||||
expect(alignNodes(undefined as unknown as LGraphNode[], 'left')).toEqual([])
|
||||
})
|
||||
})
|
||||
119
src/lib/litegraph/src/utils/collections.test.ts
Normal file
119
src/lib/litegraph/src/utils/collections.test.ts
Normal file
@@ -0,0 +1,119 @@
|
||||
import { fromPartial } from '@total-typescript/shoehorn'
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { LGraphNode } from '@/lib/litegraph/src/LGraphNode'
|
||||
|
||||
import {
|
||||
findFirstNode,
|
||||
findFreeSlotOfType,
|
||||
getAllNestedItems
|
||||
} from './collections'
|
||||
|
||||
import type { Positionable } from '../interfaces'
|
||||
|
||||
const graphNodeMock = vi.hoisted(() => ({
|
||||
LGraphNode: class TestLGraphNode {
|
||||
constructor(readonly title: string) {}
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('@/lib/litegraph/src/LGraphNode', () => graphNodeMock)
|
||||
|
||||
describe('getAllNestedItems', () => {
|
||||
it('returns empty for an undefined input set', () => {
|
||||
expect(
|
||||
getAllNestedItems(undefined as unknown as ReadonlySet<Positionable>)
|
||||
).toEqual(new Set())
|
||||
})
|
||||
|
||||
it('flattens nested children while skipping pinned and repeated items', () => {
|
||||
const leaf = fromPartial<Positionable>({ pinned: false })
|
||||
const hiddenChild = fromPartial<Positionable>({ pinned: false })
|
||||
const pinned = fromPartial<Positionable>({
|
||||
pinned: true,
|
||||
children: new Set([leaf, hiddenChild])
|
||||
})
|
||||
const parent = fromPartial<Positionable>({
|
||||
pinned: false,
|
||||
children: new Set([leaf, pinned])
|
||||
})
|
||||
|
||||
const result = getAllNestedItems(new Set([parent, leaf]))
|
||||
|
||||
expect(result).toEqual(new Set([parent, leaf]))
|
||||
expect(result.has(hiddenChild)).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('findFirstNode', () => {
|
||||
it('returns the first graph node from a mixed collection', () => {
|
||||
const node = new LGraphNode('node')
|
||||
|
||||
expect(findFirstNode([{ pinned: false } as Positionable, node])).toBe(node)
|
||||
})
|
||||
|
||||
it('returns undefined when no graph node is present', () => {
|
||||
expect(findFirstNode([{ pinned: false } as Positionable])).toBeUndefined()
|
||||
})
|
||||
})
|
||||
|
||||
describe('findFreeSlotOfType', () => {
|
||||
interface Slot {
|
||||
type: string
|
||||
links: number[]
|
||||
}
|
||||
|
||||
const hasNoLinks = (slot: Slot) => slot.links.length === 0
|
||||
|
||||
it('returns undefined for an empty slot list', () => {
|
||||
expect(findFreeSlotOfType([], 'IMAGE', hasNoLinks)).toBeUndefined()
|
||||
})
|
||||
|
||||
it('prefers the first free exact type match', () => {
|
||||
const slots = [
|
||||
{ type: 'IMAGE', links: [1] },
|
||||
{ type: 'IMAGE', links: [] }
|
||||
]
|
||||
|
||||
expect(findFreeSlotOfType(slots, 'IMAGE', hasNoLinks)).toEqual({
|
||||
index: 1,
|
||||
slot: slots[1]
|
||||
})
|
||||
})
|
||||
|
||||
it('falls back to a free wildcard before an occupied exact slot', () => {
|
||||
const slots = [
|
||||
{ type: 'IMAGE', links: [1] },
|
||||
{ type: '*', links: [] }
|
||||
]
|
||||
|
||||
expect(findFreeSlotOfType(slots, 'IMAGE', hasNoLinks)).toEqual({
|
||||
index: 1,
|
||||
slot: slots[1]
|
||||
})
|
||||
})
|
||||
|
||||
it('falls back to an occupied exact slot before an occupied wildcard', () => {
|
||||
const slots = [
|
||||
{ type: '*', links: [1] },
|
||||
{ type: 'IMAGE', links: [2] }
|
||||
]
|
||||
|
||||
expect(findFreeSlotOfType(slots, 'IMAGE', hasNoLinks)).toEqual({
|
||||
index: 1,
|
||||
slot: slots[1]
|
||||
})
|
||||
})
|
||||
|
||||
it('falls back to an occupied wildcard when no exact slot matches', () => {
|
||||
const slots = [
|
||||
{ type: 'LATENT', links: [1] },
|
||||
{ type: '*', links: [2] }
|
||||
]
|
||||
|
||||
expect(findFreeSlotOfType(slots, 'IMAGE', hasNoLinks)).toEqual({
|
||||
index: 1,
|
||||
slot: slots[1]
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -25,6 +25,6 @@ function handleClose() {
|
||||
}
|
||||
|
||||
function handleSubscribe() {
|
||||
showSubscriptionDialog()
|
||||
showSubscriptionDialog({ reason: 'upload_model_upgrade' })
|
||||
}
|
||||
</script>
|
||||
|
||||
@@ -140,7 +140,10 @@ describe('CloudSubscriptionRedirectView', () => {
|
||||
expect(mockPerformSubscriptionCheckout).toHaveBeenCalledWith(
|
||||
'creator',
|
||||
'monthly',
|
||||
false
|
||||
{
|
||||
openInNewTab: false,
|
||||
paymentIntentSource: 'deep_link'
|
||||
}
|
||||
)
|
||||
|
||||
// Shows loading affordances
|
||||
@@ -169,7 +172,10 @@ describe('CloudSubscriptionRedirectView', () => {
|
||||
expect(mockPerformSubscriptionCheckout).toHaveBeenCalledWith(
|
||||
'creator',
|
||||
'monthly',
|
||||
false
|
||||
{
|
||||
openInNewTab: false,
|
||||
paymentIntentSource: 'deep_link'
|
||||
}
|
||||
)
|
||||
})
|
||||
|
||||
@@ -180,7 +186,8 @@ describe('CloudSubscriptionRedirectView', () => {
|
||||
expect(screen.getByText('Subscribe to Team Plan')).toBeInTheDocument()
|
||||
expect(mockPerformTeamSubscriptionCheckout).toHaveBeenCalledWith(
|
||||
'team_700',
|
||||
'yearly'
|
||||
'yearly',
|
||||
{ paymentIntentSource: 'deep_link' }
|
||||
)
|
||||
// Team never goes through the personal checkout path
|
||||
expect(mockPerformSubscriptionCheckout).not.toHaveBeenCalled()
|
||||
|
||||
@@ -94,7 +94,9 @@ const runRedirect = wrapWithErrorHandlingAsync(async () => {
|
||||
return
|
||||
}
|
||||
isTeamCheckout.value = true
|
||||
await performTeamSubscriptionCheckout(stopId, billingCycle)
|
||||
await performTeamSubscriptionCheckout(stopId, billingCycle, {
|
||||
paymentIntentSource: 'deep_link'
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -112,7 +114,10 @@ const runRedirect = wrapWithErrorHandlingAsync(async () => {
|
||||
if (isActiveSubscription.value) {
|
||||
await accessBillingPortal(undefined, false)
|
||||
} else {
|
||||
await performSubscriptionCheckout(tierKeyParam, billingCycle, false)
|
||||
await performSubscriptionCheckout(tierKeyParam, billingCycle, {
|
||||
openInNewTab: false,
|
||||
paymentIntentSource: 'deep_link'
|
||||
})
|
||||
}
|
||||
}, reportError)
|
||||
|
||||
|
||||
@@ -351,12 +351,12 @@ const handleRefresh = wrapWithErrorHandlingAsync(async () => {
|
||||
})
|
||||
|
||||
function handleAddCredits() {
|
||||
telemetry?.trackAddApiCreditButtonClicked()
|
||||
telemetry?.trackAddApiCreditButtonClicked({ source: 'credits_panel' })
|
||||
void dialogService.showTopUpCreditsDialog()
|
||||
}
|
||||
|
||||
function handleUpgradeToAddCredits() {
|
||||
showPricingTable()
|
||||
showPricingTable({ reason: 'upgrade_to_add_credits' })
|
||||
}
|
||||
|
||||
async function handleWindowFocus() {
|
||||
|
||||
@@ -5,6 +5,8 @@ import { render, screen } from '@testing-library/vue'
|
||||
|
||||
import enMessages from '@/locales/en/main.json' with { type: 'json' }
|
||||
|
||||
import type { PaymentIntentSource } from '@/platform/telemetry/types'
|
||||
|
||||
import FreeTierDialogContent from './FreeTierDialogContent.vue'
|
||||
|
||||
const mockRenewalDate = vi.hoisted(() => ({ value: null as string | null }))
|
||||
@@ -15,7 +17,7 @@ vi.mock('@/composables/billing/useBillingContext', () => ({
|
||||
}))
|
||||
}))
|
||||
|
||||
function renderComponent() {
|
||||
function renderComponent(props?: { reason?: PaymentIntentSource }) {
|
||||
const i18n = createI18n({
|
||||
legacy: false,
|
||||
locale: 'en',
|
||||
@@ -23,6 +25,7 @@ function renderComponent() {
|
||||
})
|
||||
|
||||
return render(FreeTierDialogContent, {
|
||||
props,
|
||||
global: {
|
||||
plugins: [i18n]
|
||||
}
|
||||
@@ -43,4 +46,18 @@ describe('FreeTierDialogContent', () => {
|
||||
renderComponent()
|
||||
expect(screen.queryByText(/credits refresh on/)).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('keeps the generic copy for intent reasons outside the credits variants', () => {
|
||||
mockRenewalDate.value = '2026-07-15T10:00:00Z'
|
||||
renderComponent({ reason: 'subscribe_to_run' })
|
||||
expect(
|
||||
screen.getByText('Your credits refresh on Jul 15, 2026.')
|
||||
).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('swaps to the out-of-credits copy without the refresh line', () => {
|
||||
mockRenewalDate.value = '2026-07-15T10:00:00Z'
|
||||
renderComponent({ reason: 'out_of_credits' })
|
||||
expect(screen.queryByText(/credits refresh on/)).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
@@ -52,7 +52,7 @@
|
||||
</p>
|
||||
|
||||
<p
|
||||
v-if="!reason || reason === 'subscription_required'"
|
||||
v-if="!isCreditsBlockedVariant"
|
||||
class="m-0 text-sm text-text-secondary"
|
||||
>
|
||||
{{
|
||||
@@ -65,10 +65,7 @@
|
||||
</p>
|
||||
|
||||
<p
|
||||
v-if="
|
||||
(!reason || reason === 'subscription_required') &&
|
||||
formattedRenewalDate
|
||||
"
|
||||
v-if="!isCreditsBlockedVariant && formattedRenewalDate"
|
||||
class="m-0 text-sm text-text-secondary"
|
||||
>
|
||||
{{
|
||||
@@ -88,7 +85,7 @@
|
||||
@click="$emit('upgrade')"
|
||||
>
|
||||
{{
|
||||
reason === 'out_of_credits' || reason === 'top_up_blocked'
|
||||
isCreditsBlockedVariant
|
||||
? $t('subscription.freeTier.upgradeCta')
|
||||
: $t('subscription.freeTier.subscribeCta')
|
||||
}}
|
||||
@@ -103,12 +100,12 @@ import { computed } from 'vue'
|
||||
|
||||
import Button from '@/components/ui/button/Button.vue'
|
||||
import { useBillingContext } from '@/composables/billing/useBillingContext'
|
||||
import type { SubscriptionDialogReason } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type { PaymentIntentSource } from '@/platform/telemetry/types'
|
||||
import SubscriptionBenefits from '@/platform/cloud/subscription/components/SubscriptionBenefits.vue'
|
||||
import { getTierCredits } from '@/platform/cloud/subscription/constants/tierPricing'
|
||||
|
||||
defineProps<{
|
||||
reason?: SubscriptionDialogReason
|
||||
const { reason } = defineProps<{
|
||||
reason?: PaymentIntentSource
|
||||
}>()
|
||||
|
||||
defineEmits<{
|
||||
@@ -129,4 +126,10 @@ const formattedRenewalDate = computed(() => {
|
||||
})
|
||||
|
||||
const freeTierCredits = computed(() => getTierCredits('free'))
|
||||
|
||||
// Only these two variants replace the generic free-tier copy; any other
|
||||
// intent reason (subscribe_to_run, deep_link, ...) keeps the default pitch.
|
||||
const isCreditsBlockedVariant = computed(
|
||||
() => reason === 'out_of_credits' || reason === 'top_up_blocked'
|
||||
)
|
||||
</script>
|
||||
|
||||
@@ -261,6 +261,7 @@ describe('PricingTable', () => {
|
||||
tier: 'creator',
|
||||
cycle: 'yearly',
|
||||
checkout_type: 'change',
|
||||
checkout_attempt_id: expect.any(String),
|
||||
previous_tier: 'standard'
|
||||
})
|
||||
expect(mockAccessBillingPortal).toHaveBeenCalledWith('creator-yearly')
|
||||
@@ -341,6 +342,7 @@ describe('PricingTable', () => {
|
||||
expect(
|
||||
window.localStorage.getItem(PENDING_SUBSCRIPTION_CHECKOUT_STORAGE_KEY)
|
||||
).toBeNull()
|
||||
expect(mockTrackBeginCheckout).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should use the latest userId value when it changes after mount', async () => {
|
||||
@@ -366,6 +368,7 @@ describe('PricingTable', () => {
|
||||
tier: 'creator',
|
||||
cycle: 'yearly',
|
||||
checkout_type: 'change',
|
||||
checkout_attempt_id: expect.any(String),
|
||||
previous_tier: 'standard'
|
||||
})
|
||||
})
|
||||
|
||||
@@ -277,13 +277,19 @@ import type {
|
||||
TierKey,
|
||||
TierPricing
|
||||
} from '@/platform/cloud/subscription/constants/tierPricing'
|
||||
import { recordPendingSubscriptionCheckoutAttempt } from '@/platform/cloud/subscription/utils/subscriptionCheckoutTracker'
|
||||
import {
|
||||
recordPendingSubscriptionCheckoutAttempt,
|
||||
withPendingCheckoutAttemptId
|
||||
} from '@/platform/cloud/subscription/utils/subscriptionCheckoutTracker'
|
||||
import { performSubscriptionCheckout } from '@/platform/cloud/subscription/utils/subscriptionCheckoutUtil'
|
||||
import { isPlanDowngrade } from '@/platform/cloud/subscription/utils/subscriptionTierRank'
|
||||
import type { BillingCycle } from '@/platform/cloud/subscription/utils/subscriptionTierRank'
|
||||
import { isCloud } from '@/platform/distribution/types'
|
||||
import { useTelemetry } from '@/platform/telemetry'
|
||||
import type { CheckoutAttributionMetadata } from '@/platform/telemetry/types'
|
||||
import type {
|
||||
CheckoutAttributionMetadata,
|
||||
PaymentIntentSource
|
||||
} from '@/platform/telemetry/types'
|
||||
import { useAuthStore } from '@/stores/authStore'
|
||||
|
||||
type CheckoutTierKey = Exclude<TierKey, 'free' | 'founder'>
|
||||
@@ -321,6 +327,10 @@ interface PricingTierConfig {
|
||||
isPopular?: boolean
|
||||
}
|
||||
|
||||
const { reason } = defineProps<{
|
||||
reason?: PaymentIntentSource
|
||||
}>()
|
||||
|
||||
const emit = defineEmits<{
|
||||
chooseTeamWorkspace: []
|
||||
}>()
|
||||
@@ -463,16 +473,17 @@ const handleSubscribe = wrapWithErrorHandlingAsync(
|
||||
} as const
|
||||
const previousPlan = currentPlanDescriptor.value
|
||||
const checkoutAttribution = await getCheckoutAttributionForCloud()
|
||||
if (userId.value) {
|
||||
telemetry?.trackBeginCheckout({
|
||||
user_id: userId.value,
|
||||
tier: targetPlan.tierKey,
|
||||
cycle: targetPlan.billingCycle,
|
||||
checkout_type: 'change',
|
||||
...checkoutAttribution,
|
||||
...(previousPlan ? { previous_tier: previousPlan.tierKey } : {})
|
||||
})
|
||||
}
|
||||
const beginCheckoutMetadata = userId.value
|
||||
? {
|
||||
user_id: userId.value,
|
||||
tier: targetPlan.tierKey,
|
||||
cycle: targetPlan.billingCycle,
|
||||
checkout_type: 'change' as const,
|
||||
...(reason ? { payment_intent_source: reason } : {}),
|
||||
...checkoutAttribution,
|
||||
...(previousPlan ? { previous_tier: previousPlan.tierKey } : {})
|
||||
}
|
||||
: null
|
||||
// Pass the target tier to create a deep link to subscription update confirmation
|
||||
const checkoutTier = getCheckoutTier(
|
||||
targetPlan.tierKey,
|
||||
@@ -487,29 +498,39 @@ const handleSubscribe = wrapWithErrorHandlingAsync(
|
||||
|
||||
if (downgrade) {
|
||||
// TODO(COMFY-StripeProration): Remove once backend checkout creation mirrors portal proration ("change at billing end")
|
||||
await accessBillingPortal()
|
||||
const didOpenPortal = await accessBillingPortal()
|
||||
if (didOpenPortal && beginCheckoutMetadata) {
|
||||
telemetry?.trackBeginCheckout(beginCheckoutMetadata)
|
||||
}
|
||||
} else {
|
||||
const didOpenPortal = await accessBillingPortal(checkoutTier)
|
||||
if (!didOpenPortal) {
|
||||
return
|
||||
}
|
||||
|
||||
recordPendingSubscriptionCheckoutAttempt({
|
||||
const pendingAttempt = recordPendingSubscriptionCheckoutAttempt({
|
||||
tier: targetPlan.tierKey,
|
||||
cycle: targetPlan.billingCycle,
|
||||
checkout_type: 'change',
|
||||
payment_intent_source: reason,
|
||||
...(previousPlan ? { previous_tier: previousPlan.tierKey } : {}),
|
||||
...(previousPlan
|
||||
? { previous_cycle: previousPlan.billingCycle }
|
||||
: {})
|
||||
})
|
||||
if (beginCheckoutMetadata) {
|
||||
telemetry?.trackBeginCheckout(
|
||||
withPendingCheckoutAttemptId(
|
||||
beginCheckoutMetadata,
|
||||
pendingAttempt
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
await performSubscriptionCheckout(
|
||||
tierKey,
|
||||
currentBillingCycle.value,
|
||||
true
|
||||
)
|
||||
await performSubscriptionCheckout(tierKey, currentBillingCycle.value, {
|
||||
paymentIntentSource: reason
|
||||
})
|
||||
}
|
||||
} finally {
|
||||
isLoading.value = false
|
||||
|
||||
@@ -56,7 +56,7 @@ const handleSubscribe = () => {
|
||||
current_tier: tier.value?.toLowerCase()
|
||||
})
|
||||
isAwaitingStripeSubscription.value = true
|
||||
showSubscriptionDialog()
|
||||
showSubscriptionDialog({ reason: 'subscribe_now_button' })
|
||||
}
|
||||
|
||||
onBeforeUnmount(() => {
|
||||
|
||||
@@ -54,6 +54,6 @@ function handleSubscribeToRun() {
|
||||
trackRunButton({ subscribe_to_run: true })
|
||||
}
|
||||
|
||||
showSubscriptionDialog()
|
||||
showSubscriptionDialog({ reason: 'subscribe_to_run' })
|
||||
}
|
||||
</script>
|
||||
|
||||
@@ -48,7 +48,9 @@
|
||||
v-if="isActiveSubscription"
|
||||
variant="primary"
|
||||
class="rounded-lg px-4 py-2 text-sm font-normal text-text-primary"
|
||||
@click="showSubscriptionDialog"
|
||||
@click="
|
||||
showSubscriptionDialog({ reason: 'settings_billing_panel' })
|
||||
"
|
||||
>
|
||||
{{ $t('subscription.upgradePlan') }}
|
||||
</Button>
|
||||
|
||||
@@ -33,7 +33,11 @@
|
||||
</i18n-t>
|
||||
</div>
|
||||
|
||||
<PricingTable class="flex-1" @choose-team-workspace="handleChooseTeam" />
|
||||
<PricingTable
|
||||
:reason
|
||||
class="flex-1"
|
||||
@choose-team-workspace="handleChooseTeam"
|
||||
/>
|
||||
|
||||
<!-- Contact and Enterprise Links -->
|
||||
<div class="flex flex-col items-center gap-2">
|
||||
@@ -157,11 +161,11 @@ import { useBillingContext } from '@/composables/billing/useBillingContext'
|
||||
import { isCloud } from '@/platform/distribution/types'
|
||||
import { useTelemetry } from '@/platform/telemetry'
|
||||
import { useCommandStore } from '@/stores/commandStore'
|
||||
import type { SubscriptionDialogReason } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type { PaymentIntentSource } from '@/platform/telemetry/types'
|
||||
|
||||
const { onClose, reason, onChooseTeam } = defineProps<{
|
||||
onClose: () => void
|
||||
reason?: SubscriptionDialogReason
|
||||
reason?: PaymentIntentSource
|
||||
onChooseTeam?: () => void
|
||||
}>()
|
||||
|
||||
|
||||
@@ -24,7 +24,9 @@ export function useAccountPreconditionDialog() {
|
||||
)
|
||||
return
|
||||
case 'subscription':
|
||||
void dialogService.showSubscriptionRequiredDialog()
|
||||
void dialogService.showSubscriptionRequiredDialog({
|
||||
reason: 'subscription_required'
|
||||
})
|
||||
return
|
||||
case 'credits':
|
||||
void dialogService.showTopUpCreditsDialog({
|
||||
|
||||
@@ -55,12 +55,6 @@ vi.mock('@/platform/workspace/stores/teamWorkspaceStore', () => ({
|
||||
})
|
||||
}))
|
||||
|
||||
const mockTrackSubscription = vi.hoisted(() => vi.fn())
|
||||
|
||||
vi.mock('@/platform/telemetry', () => ({
|
||||
useTelemetry: () => ({ trackSubscription: mockTrackSubscription })
|
||||
}))
|
||||
|
||||
describe('usePricingTableUrlLoader', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
@@ -96,9 +90,6 @@ describe('usePricingTableUrlLoader', () => {
|
||||
reason: 'deep_link',
|
||||
planMode: undefined
|
||||
})
|
||||
expect(mockTrackSubscription).toHaveBeenCalledWith('modal_opened', {
|
||||
reason: 'deep_link'
|
||||
})
|
||||
expect(mockRouterReplace).toHaveBeenCalledWith({ query: {} })
|
||||
})
|
||||
|
||||
@@ -150,7 +141,6 @@ describe('usePricingTableUrlLoader', () => {
|
||||
await loadPricingTableFromUrl()
|
||||
|
||||
expect(mockShowPricingTable).not.toHaveBeenCalled()
|
||||
expect(mockTrackSubscription).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('denies, strips, and clears together when the user is not eligible', async () => {
|
||||
@@ -161,7 +151,6 @@ describe('usePricingTableUrlLoader', () => {
|
||||
await loadPricingTableFromUrl()
|
||||
|
||||
expect(mockShowPricingTable).not.toHaveBeenCalled()
|
||||
expect(mockTrackSubscription).not.toHaveBeenCalled()
|
||||
expect(mockRouterReplace).toHaveBeenCalledWith({
|
||||
query: { other: 'param' }
|
||||
})
|
||||
@@ -230,7 +219,6 @@ describe('usePricingTableUrlLoader', () => {
|
||||
)
|
||||
|
||||
expect(mockShowPricingTable).not.toHaveBeenCalled()
|
||||
expect(mockTrackSubscription).not.toHaveBeenCalled()
|
||||
expect(mockRouterReplace).toHaveBeenCalledWith({ query: {} })
|
||||
expect(preservedQueryMocks.clearPreservedQuery).toHaveBeenCalledWith(
|
||||
'pricing'
|
||||
|
||||
@@ -7,7 +7,6 @@ import {
|
||||
mergePreservedQueryIntoQuery
|
||||
} from '@/platform/navigation/preservedQueryManager'
|
||||
import { PRESERVED_QUERY_NAMESPACES } from '@/platform/navigation/preservedQueryNamespaces'
|
||||
import { useTelemetry } from '@/platform/telemetry'
|
||||
import { useWorkspaceUI } from '@/platform/workspace/composables/useWorkspaceUI'
|
||||
import { useTeamWorkspaceStore } from '@/platform/workspace/stores/teamWorkspaceStore'
|
||||
|
||||
@@ -62,7 +61,6 @@ export function usePricingTableUrlLoader() {
|
||||
const planMode =
|
||||
param === 'team' || param === 'personal' ? param : undefined
|
||||
|
||||
useTelemetry()?.trackSubscription('modal_opened', { reason: 'deep_link' })
|
||||
subscriptionDialog.showPricingTable({ reason: 'deep_link', planMode })
|
||||
}
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ import { t } from '@/i18n'
|
||||
import { fetchWithUnifiedRemint } from '@/platform/auth/unified/remintRetry'
|
||||
import { isCloud } from '@/platform/distribution/types'
|
||||
import { useTelemetry } from '@/platform/telemetry'
|
||||
import type { SubscriptionDialogReason } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type { SubscriptionDialogOptions } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type { CheckoutAttributionMetadata } from '@/platform/telemetry/types'
|
||||
import { AuthStoreError, useAuthStore } from '@/stores/authStore'
|
||||
import { useDialogService } from '@/services/dialogService'
|
||||
@@ -237,14 +237,7 @@ function useSubscriptionInternal() {
|
||||
})
|
||||
}, reportError)
|
||||
|
||||
const showSubscriptionDialog = (options?: {
|
||||
reason?: SubscriptionDialogReason
|
||||
}) => {
|
||||
useTelemetry()?.trackSubscription('modal_opened', {
|
||||
current_tier: subscriptionTier.value?.toLowerCase(),
|
||||
reason: options?.reason
|
||||
})
|
||||
|
||||
const showSubscriptionDialog = (options?: SubscriptionDialogOptions) => {
|
||||
void showSubscriptionRequiredDialog(options)
|
||||
}
|
||||
|
||||
@@ -277,7 +270,7 @@ function useSubscriptionInternal() {
|
||||
await fetchSubscriptionStatus()
|
||||
|
||||
if (!isSubscribedOrIsNotCloud.value) {
|
||||
showSubscriptionDialog()
|
||||
showSubscriptionDialog({ reason: 'subscription_required' })
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -39,15 +39,23 @@ vi.mock('@/stores/commandStore', () => ({
|
||||
}))
|
||||
|
||||
// useTelemetry() returns null in OSS, a dispatcher in cloud — toggle via mockIsCloud.
|
||||
const { mockIsCloud, mockTrackHelpResourceClicked } = vi.hoisted(() => ({
|
||||
const {
|
||||
mockIsCloud,
|
||||
mockTrackHelpResourceClicked,
|
||||
mockTrackAddApiCreditButtonClicked
|
||||
} = vi.hoisted(() => ({
|
||||
mockIsCloud: { value: true },
|
||||
mockTrackHelpResourceClicked: vi.fn()
|
||||
mockTrackHelpResourceClicked: vi.fn(),
|
||||
mockTrackAddApiCreditButtonClicked: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/telemetry', () => ({
|
||||
useTelemetry: () =>
|
||||
mockIsCloud.value
|
||||
? { trackHelpResourceClicked: mockTrackHelpResourceClicked }
|
||||
? {
|
||||
trackHelpResourceClicked: mockTrackHelpResourceClicked,
|
||||
trackAddApiCreditButtonClicked: mockTrackAddApiCreditButtonClicked
|
||||
}
|
||||
: null
|
||||
}))
|
||||
|
||||
@@ -69,6 +77,9 @@ describe('useSubscriptionActions', () => {
|
||||
const { handleAddApiCredits } = useSubscriptionActions()
|
||||
handleAddApiCredits()
|
||||
expect(mockShowTopUpCreditsDialog).toHaveBeenCalledOnce()
|
||||
expect(mockTrackAddApiCreditButtonClicked).toHaveBeenCalledWith({
|
||||
source: 'settings_billing_panel'
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
|
||||
@@ -21,6 +21,9 @@ export function useSubscriptionActions() {
|
||||
})
|
||||
|
||||
const handleAddApiCredits = () => {
|
||||
telemetry?.trackAddApiCreditButtonClicked({
|
||||
source: 'settings_billing_panel'
|
||||
})
|
||||
void dialogService.showTopUpCreditsDialog()
|
||||
}
|
||||
|
||||
|
||||
@@ -5,8 +5,10 @@ import { useSubscriptionDialog } from './useSubscriptionDialog'
|
||||
const mockCloseDialog = vi.fn()
|
||||
const mockShowLayoutDialog = vi.fn()
|
||||
const mockShowTeamWorkspacesDialog = vi.fn()
|
||||
const mockTrackSubscription = vi.hoisted(() => vi.fn())
|
||||
const mockIsInPersonalWorkspace = vi.hoisted(() => ({ value: true }))
|
||||
const mockIsFreeTier = vi.hoisted(() => ({ value: false }))
|
||||
const mockTier = vi.hoisted(() => ({ value: 'FREE' as string | null }))
|
||||
const mockTeamWorkspacesEnabled = vi.hoisted(() => ({ value: false }))
|
||||
const mockIsCloud = vi.hoisted(() => ({ value: true }))
|
||||
const mockIsLegacyTeamPlan = vi.hoisted(() => ({ value: false }))
|
||||
@@ -60,10 +62,15 @@ vi.mock('@/platform/workspace/stores/teamWorkspaceStore', () => ({
|
||||
vi.mock('@/composables/billing/useBillingContext', () => ({
|
||||
useBillingContext: () => ({
|
||||
isFreeTier: mockIsFreeTier,
|
||||
isLegacyTeamPlan: mockIsLegacyTeamPlan
|
||||
isLegacyTeamPlan: mockIsLegacyTeamPlan,
|
||||
tier: mockTier
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/telemetry', () => ({
|
||||
useTelemetry: () => ({ trackSubscription: mockTrackSubscription })
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/workspace/composables/useWorkspaceUI', () => ({
|
||||
useWorkspaceUI: () => ({
|
||||
permissions: {
|
||||
@@ -80,6 +87,7 @@ describe('useSubscriptionDialog', () => {
|
||||
mockIsCloud.value = true
|
||||
mockIsInPersonalWorkspace.value = true
|
||||
mockIsFreeTier.value = false
|
||||
mockTier.value = 'FREE'
|
||||
mockTeamWorkspacesEnabled.value = false
|
||||
mockIsLegacyTeamPlan.value = false
|
||||
mockCanManageSubscription.value = true
|
||||
@@ -198,6 +206,51 @@ describe('useSubscriptionDialog', () => {
|
||||
const props = mockShowLayoutDialog.mock.calls[0][0].props
|
||||
expect(props.initialPlanMode).toBe('team')
|
||||
})
|
||||
|
||||
it('tracks modal_opened with the caller reason and current tier', () => {
|
||||
mockTier.value = 'STANDARD'
|
||||
const { showPricingTable } = useSubscriptionDialog()
|
||||
|
||||
showPricingTable({ reason: 'upgrade_to_add_credits' })
|
||||
|
||||
expect(mockTrackSubscription).toHaveBeenCalledWith('modal_opened', {
|
||||
current_tier: 'standard',
|
||||
reason: 'upgrade_to_add_credits'
|
||||
})
|
||||
})
|
||||
|
||||
it('tracks modal_opened on the workspace (unified) path too', () => {
|
||||
mockTeamWorkspacesEnabled.value = true
|
||||
const { showPricingTable } = useSubscriptionDialog()
|
||||
|
||||
showPricingTable({ reason: 'subscribe_to_run' })
|
||||
|
||||
expect(mockTrackSubscription).toHaveBeenCalledWith(
|
||||
'modal_opened',
|
||||
expect.objectContaining({ reason: 'subscribe_to_run' })
|
||||
)
|
||||
})
|
||||
|
||||
it('does not track modal_opened for the inactive member dialog', () => {
|
||||
mockTeamWorkspacesEnabled.value = true
|
||||
mockIsInPersonalWorkspace.value = false
|
||||
mockCanManageSubscription.value = false
|
||||
const { showPricingTable } = useSubscriptionDialog()
|
||||
|
||||
showPricingTable({ reason: 'subscribe_to_run' })
|
||||
|
||||
expect(mockShowLayoutDialog).toHaveBeenCalledTimes(1)
|
||||
expect(mockTrackSubscription).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('does not track on non-cloud', () => {
|
||||
mockIsCloud.value = false
|
||||
const { showPricingTable } = useSubscriptionDialog()
|
||||
|
||||
showPricingTable({ reason: 'subscribe_to_run' })
|
||||
|
||||
expect(mockTrackSubscription).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('show', () => {
|
||||
@@ -235,6 +288,20 @@ describe('useSubscriptionDialog', () => {
|
||||
expect.objectContaining({ key: 'subscription-required' })
|
||||
)
|
||||
})
|
||||
|
||||
it('tracks modal_opened with the reason for the free-tier dialog', () => {
|
||||
mockIsFreeTier.value = true
|
||||
mockIsInPersonalWorkspace.value = true
|
||||
const { show } = useSubscriptionDialog()
|
||||
|
||||
show({ reason: 'out_of_credits' })
|
||||
|
||||
expect(mockTrackSubscription).toHaveBeenCalledTimes(1)
|
||||
expect(mockTrackSubscription).toHaveBeenCalledWith(
|
||||
'modal_opened',
|
||||
expect.objectContaining({ reason: 'out_of_credits' })
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('startTeamWorkspaceUpgradeFlow', () => {
|
||||
|
||||
@@ -4,6 +4,8 @@ import { useDialogStore } from '@/stores/dialogStore'
|
||||
import { useBillingContext } from '@/composables/billing/useBillingContext'
|
||||
import { useFeatureFlags } from '@/composables/useFeatureFlags'
|
||||
import { isCloud } from '@/platform/distribution/types'
|
||||
import { useTelemetry } from '@/platform/telemetry'
|
||||
import type { PaymentIntentSource } from '@/platform/telemetry/types'
|
||||
import { useWorkspaceUI } from '@/platform/workspace/composables/useWorkspaceUI'
|
||||
import { useTeamWorkspaceStore } from '@/platform/workspace/stores/teamWorkspaceStore'
|
||||
|
||||
@@ -11,14 +13,8 @@ const DIALOG_KEY = 'subscription-required'
|
||||
const FREE_TIER_DIALOG_KEY = 'free-tier-info'
|
||||
const RESUME_PRICING_KEY = 'comfy:resume-team-pricing'
|
||||
|
||||
export type SubscriptionDialogReason =
|
||||
| 'subscription_required'
|
||||
| 'out_of_credits'
|
||||
| 'top_up_blocked'
|
||||
| 'deep_link'
|
||||
|
||||
interface SubscriptionDialogOptions {
|
||||
reason?: SubscriptionDialogReason
|
||||
export interface SubscriptionDialogOptions {
|
||||
reason?: PaymentIntentSource
|
||||
/**
|
||||
* Forces the unified pricing dialog to open on a specific plan tab,
|
||||
* overriding the workspace-derived default (e.g. an "Upgrade to Team" CTA
|
||||
@@ -38,6 +34,17 @@ export const useSubscriptionDialog = () => {
|
||||
dialogStore.closeDialog({ key: FREE_TIER_DIALOG_KEY })
|
||||
}
|
||||
|
||||
// Fired here — the choke point every paywall/pricing dialog variant passes
|
||||
// through — so both the legacy and workspace billing paths emit it.
|
||||
function trackModalOpened(reason?: PaymentIntentSource) {
|
||||
// Resolved lazily to avoid the useBillingContext import cycle (see below).
|
||||
const { tier } = useBillingContext()
|
||||
useTelemetry()?.trackSubscription('modal_opened', {
|
||||
current_tier: tier.value?.toLowerCase(),
|
||||
reason
|
||||
})
|
||||
}
|
||||
|
||||
function showPricingTable(options?: SubscriptionDialogOptions) {
|
||||
if (!isCloud) return
|
||||
|
||||
@@ -71,6 +78,8 @@ export const useSubscriptionDialog = () => {
|
||||
return
|
||||
}
|
||||
|
||||
trackModalOpened(options?.reason)
|
||||
|
||||
// Shared dialog shell styling for both variants.
|
||||
const dialogComponentProps = {
|
||||
style: 'width: min(1328px, 95vw); max-height: 958px;',
|
||||
@@ -167,6 +176,8 @@ export const useSubscriptionDialog = () => {
|
||||
// (not at composable setup) to avoid the useBillingContext import cycle.
|
||||
const { isFreeTier } = useBillingContext()
|
||||
if (isFreeTier.value && workspaceStore.isInPersonalWorkspace) {
|
||||
trackModalOpened(options?.reason)
|
||||
|
||||
const component = defineAsyncComponent(
|
||||
() =>
|
||||
import('@/platform/cloud/subscription/components/FreeTierDialogContent.vue')
|
||||
@@ -236,7 +247,7 @@ export const useSubscriptionDialog = () => {
|
||||
sessionStorage.removeItem(RESUME_PRICING_KEY)
|
||||
|
||||
if (!workspaceStore.isInPersonalWorkspace) {
|
||||
showPricingTable()
|
||||
showPricingTable({ reason: 'team_upgrade_resume' })
|
||||
}
|
||||
} catch {
|
||||
// sessionStorage may be unavailable
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
import { beforeEach, describe, expect, it } from 'vitest'
|
||||
|
||||
import {
|
||||
clearPendingSubscriptionCheckoutAttempt,
|
||||
consumePendingSubscriptionCheckoutSuccess,
|
||||
recordPendingSubscriptionCheckoutAttempt
|
||||
} from './subscriptionCheckoutTracker'
|
||||
|
||||
const activeProStatus = {
|
||||
is_active: true,
|
||||
subscription_tier: 'PRO',
|
||||
subscription_duration: 'MONTHLY'
|
||||
} as const
|
||||
|
||||
describe('subscriptionCheckoutTracker', () => {
|
||||
beforeEach(() => {
|
||||
clearPendingSubscriptionCheckoutAttempt()
|
||||
})
|
||||
|
||||
it('round-trips payment_intent_source from attempt to success metadata', () => {
|
||||
recordPendingSubscriptionCheckoutAttempt({
|
||||
tier: 'pro',
|
||||
cycle: 'monthly',
|
||||
checkout_type: 'new',
|
||||
payment_intent_source: 'subscribe_to_run'
|
||||
})
|
||||
|
||||
const metadata = consumePendingSubscriptionCheckoutSuccess(activeProStatus)
|
||||
|
||||
expect(metadata).toMatchObject({
|
||||
tier: 'pro',
|
||||
checkout_type: 'new',
|
||||
payment_intent_source: 'subscribe_to_run'
|
||||
})
|
||||
})
|
||||
|
||||
it('omits payment_intent_source when the attempt had none', () => {
|
||||
recordPendingSubscriptionCheckoutAttempt({
|
||||
tier: 'pro',
|
||||
cycle: 'monthly',
|
||||
checkout_type: 'new'
|
||||
})
|
||||
|
||||
const metadata = consumePendingSubscriptionCheckoutSuccess(activeProStatus)
|
||||
|
||||
expect(metadata).not.toBeNull()
|
||||
expect(metadata).not.toHaveProperty('payment_intent_source')
|
||||
})
|
||||
})
|
||||
@@ -7,7 +7,12 @@ import type {
|
||||
TierKey
|
||||
} from '@/platform/cloud/subscription/constants/tierPricing'
|
||||
import type { BillingCycle } from '@/platform/cloud/subscription/utils/subscriptionTierRank'
|
||||
import type { SubscriptionSuccessMetadata } from '@/platform/telemetry/types'
|
||||
import type {
|
||||
BeginCheckoutMetadata,
|
||||
PaymentIntentSource,
|
||||
SubscriptionCheckoutType,
|
||||
SubscriptionSuccessMetadata
|
||||
} from '@/platform/telemetry/types'
|
||||
|
||||
const PENDING_SUBSCRIPTION_CHECKOUT_MAX_AGE_MS = 6 * 60 * 60 * 1000
|
||||
const VALID_TIER_KEYS = new Set<TierKey>([
|
||||
@@ -23,7 +28,6 @@ export const PENDING_SUBSCRIPTION_CHECKOUT_STORAGE_KEY =
|
||||
export const PENDING_SUBSCRIPTION_CHECKOUT_EVENT =
|
||||
'comfy:subscription-checkout-attempt-changed'
|
||||
|
||||
type CheckoutType = 'new' | 'change'
|
||||
type SubscriptionDuration = 'MONTHLY' | 'ANNUAL'
|
||||
|
||||
interface SubscriptionStatusSnapshot {
|
||||
@@ -32,22 +36,24 @@ interface SubscriptionStatusSnapshot {
|
||||
subscription_duration?: SubscriptionDuration | null
|
||||
}
|
||||
|
||||
interface PendingSubscriptionCheckoutAttempt {
|
||||
export interface PendingSubscriptionCheckoutAttempt {
|
||||
attempt_id: string
|
||||
started_at_ms: number
|
||||
tier: TierKey
|
||||
cycle: BillingCycle
|
||||
checkout_type: CheckoutType
|
||||
checkout_type: SubscriptionCheckoutType
|
||||
previous_tier?: TierKey
|
||||
previous_cycle?: BillingCycle
|
||||
payment_intent_source?: PaymentIntentSource
|
||||
}
|
||||
|
||||
interface RecordPendingSubscriptionCheckoutAttemptInput {
|
||||
interface PendingSubscriptionCheckoutAttemptInput {
|
||||
tier: TierKey
|
||||
cycle: BillingCycle
|
||||
checkout_type: CheckoutType
|
||||
checkout_type: SubscriptionCheckoutType
|
||||
previous_tier?: TierKey
|
||||
previous_cycle?: BillingCycle
|
||||
payment_intent_source?: PaymentIntentSource
|
||||
}
|
||||
|
||||
const dispatchPendingCheckoutChangeEvent = () => {
|
||||
@@ -168,6 +174,9 @@ const normalizeAttempt = (
|
||||
...(candidate.previous_cycle === 'monthly' ||
|
||||
candidate.previous_cycle === 'yearly'
|
||||
? { previous_cycle: candidate.previous_cycle }
|
||||
: {}),
|
||||
...(typeof candidate.payment_intent_source === 'string'
|
||||
? { payment_intent_source: candidate.payment_intent_source }
|
||||
: {})
|
||||
}
|
||||
}
|
||||
@@ -224,20 +233,27 @@ const getPendingSubscriptionCheckoutAttempt =
|
||||
export const hasPendingSubscriptionCheckoutAttempt = (): boolean =>
|
||||
getPendingSubscriptionCheckoutAttempt() !== null
|
||||
|
||||
export const recordPendingSubscriptionCheckoutAttempt = (
|
||||
input: RecordPendingSubscriptionCheckoutAttemptInput
|
||||
export const createPendingSubscriptionCheckoutAttempt = (
|
||||
input: PendingSubscriptionCheckoutAttemptInput
|
||||
): PendingSubscriptionCheckoutAttempt => {
|
||||
const storage = getStorage()
|
||||
const attempt: PendingSubscriptionCheckoutAttempt = {
|
||||
return {
|
||||
attempt_id: createAttemptId(),
|
||||
started_at_ms: Date.now(),
|
||||
tier: input.tier,
|
||||
cycle: input.cycle,
|
||||
checkout_type: input.checkout_type,
|
||||
...(input.previous_tier ? { previous_tier: input.previous_tier } : {}),
|
||||
...(input.previous_cycle ? { previous_cycle: input.previous_cycle } : {})
|
||||
...(input.previous_cycle ? { previous_cycle: input.previous_cycle } : {}),
|
||||
...(input.payment_intent_source
|
||||
? { payment_intent_source: input.payment_intent_source }
|
||||
: {})
|
||||
}
|
||||
}
|
||||
|
||||
export const persistPendingSubscriptionCheckoutAttempt = (
|
||||
attempt: PendingSubscriptionCheckoutAttempt
|
||||
): PendingSubscriptionCheckoutAttempt => {
|
||||
const storage = getStorage()
|
||||
if (!storage) {
|
||||
return attempt
|
||||
}
|
||||
@@ -255,6 +271,21 @@ export const recordPendingSubscriptionCheckoutAttempt = (
|
||||
return attempt
|
||||
}
|
||||
|
||||
export const recordPendingSubscriptionCheckoutAttempt = (
|
||||
input: PendingSubscriptionCheckoutAttemptInput
|
||||
): PendingSubscriptionCheckoutAttempt =>
|
||||
persistPendingSubscriptionCheckoutAttempt(
|
||||
createPendingSubscriptionCheckoutAttempt(input)
|
||||
)
|
||||
|
||||
export const withPendingCheckoutAttemptId = (
|
||||
metadata: BeginCheckoutMetadata,
|
||||
attempt: PendingSubscriptionCheckoutAttempt
|
||||
): BeginCheckoutMetadata => ({
|
||||
...metadata,
|
||||
checkout_attempt_id: attempt.attempt_id
|
||||
})
|
||||
|
||||
const didAttemptSucceed = (
|
||||
attempt: PendingSubscriptionCheckoutAttempt,
|
||||
status: SubscriptionStatusSnapshot
|
||||
@@ -287,6 +318,9 @@ export const consumePendingSubscriptionCheckoutSuccess = (
|
||||
cycle: attempt.cycle,
|
||||
checkout_type: attempt.checkout_type,
|
||||
...(attempt.previous_tier ? { previous_tier: attempt.previous_tier } : {}),
|
||||
...(attempt.payment_intent_source
|
||||
? { payment_intent_source: attempt.payment_intent_source }
|
||||
: {}),
|
||||
value,
|
||||
currency: 'USD',
|
||||
ecommerce: {
|
||||
|
||||
@@ -132,13 +132,14 @@ describe('performSubscriptionCheckout', () => {
|
||||
json: async () => ({ checkout_url: checkoutUrl })
|
||||
} as Response)
|
||||
|
||||
await performSubscriptionCheckout('pro', 'yearly', true)
|
||||
await performSubscriptionCheckout('pro', 'yearly')
|
||||
|
||||
expect(mockTelemetry.trackBeginCheckout).toHaveBeenCalledWith({
|
||||
user_id: 'user-123',
|
||||
tier: 'pro',
|
||||
cycle: 'yearly',
|
||||
checkout_type: 'new',
|
||||
checkout_attempt_id: expect.any(String),
|
||||
ga_client_id: 'ga-client-id',
|
||||
ga_session_id: 'ga-session-id',
|
||||
ga_session_number: 'ga-session-number',
|
||||
@@ -150,6 +151,12 @@ describe('performSubscriptionCheckout', () => {
|
||||
gbraid: 'gbraid-456',
|
||||
wbraid: 'wbraid-789'
|
||||
})
|
||||
const beginCheckoutMetadata =
|
||||
mockTelemetry.trackBeginCheckout.mock.calls[0][0]
|
||||
const [, storedAttempt] = mockLocalStorage.setItem.mock.calls[0]
|
||||
expect(beginCheckoutMetadata.checkout_attempt_id).toBe(
|
||||
JSON.parse(storedAttempt).attempt_id
|
||||
)
|
||||
expect(global.fetch).toHaveBeenCalledWith(
|
||||
expect.stringContaining(
|
||||
'/customers/cloud-subscription-checkout/pro-yearly'
|
||||
@@ -186,7 +193,7 @@ describe('performSubscriptionCheckout', () => {
|
||||
json: async () => ({ checkout_url: checkoutUrl })
|
||||
} as Response)
|
||||
|
||||
await performSubscriptionCheckout('pro', 'monthly', true)
|
||||
await performSubscriptionCheckout('pro', 'monthly')
|
||||
|
||||
expect(warnSpy).toHaveBeenCalledWith(
|
||||
'[SubscriptionCheckout] Failed to collect checkout attribution',
|
||||
@@ -203,11 +210,43 @@ describe('performSubscriptionCheckout', () => {
|
||||
user_id: 'user-123',
|
||||
tier: 'pro',
|
||||
cycle: 'monthly',
|
||||
checkout_type: 'new'
|
||||
checkout_type: 'new',
|
||||
checkout_attempt_id: expect.any(String)
|
||||
})
|
||||
expect(openSpy).toHaveBeenCalledWith(checkoutUrl, '_blank')
|
||||
})
|
||||
|
||||
it('carries the payment intent source into begin_checkout and the pending attempt', async () => {
|
||||
const checkoutUrl = 'https://checkout.stripe.com/test'
|
||||
const openSpy = vi
|
||||
.spyOn(window, 'open')
|
||||
.mockImplementation(() => window as unknown as Window)
|
||||
|
||||
vi.mocked(global.fetch).mockResolvedValue({
|
||||
ok: true,
|
||||
json: async () => ({ checkout_url: checkoutUrl })
|
||||
} as Response)
|
||||
|
||||
await performSubscriptionCheckout('pro', 'monthly', {
|
||||
paymentIntentSource: 'out_of_credits'
|
||||
})
|
||||
|
||||
expect(mockTelemetry.trackBeginCheckout).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ payment_intent_source: 'out_of_credits' })
|
||||
)
|
||||
const beginCheckoutMetadata =
|
||||
mockTelemetry.trackBeginCheckout.mock.calls[0][0]
|
||||
const [, storedAttempt] = mockLocalStorage.setItem.mock.calls[0]
|
||||
const pendingAttempt = JSON.parse(storedAttempt)
|
||||
expect(pendingAttempt).toMatchObject({
|
||||
payment_intent_source: 'out_of_credits'
|
||||
})
|
||||
expect(beginCheckoutMetadata.checkout_attempt_id).toBe(
|
||||
pendingAttempt.attempt_id
|
||||
)
|
||||
openSpy.mockRestore()
|
||||
})
|
||||
|
||||
it('uses the latest userId when it changes after checkout starts', async () => {
|
||||
const checkoutUrl = 'https://checkout.stripe.com/test'
|
||||
const openSpy = vi
|
||||
@@ -222,7 +261,7 @@ describe('performSubscriptionCheckout', () => {
|
||||
json: async () => ({ checkout_url: checkoutUrl })
|
||||
} as Response)
|
||||
|
||||
const checkoutPromise = performSubscriptionCheckout('pro', 'yearly', true)
|
||||
const checkoutPromise = performSubscriptionCheckout('pro', 'yearly')
|
||||
|
||||
mockUserId.value = 'user-late'
|
||||
authHeader.resolve({ Authorization: 'Bearer test-token' })
|
||||
@@ -235,13 +274,14 @@ describe('performSubscriptionCheckout', () => {
|
||||
user_id: 'user-late',
|
||||
tier: 'pro',
|
||||
cycle: 'yearly',
|
||||
checkout_type: 'new'
|
||||
checkout_type: 'new',
|
||||
checkout_attempt_id: expect.any(String)
|
||||
})
|
||||
)
|
||||
expect(openSpy).toHaveBeenCalledWith(checkoutUrl, '_blank')
|
||||
})
|
||||
|
||||
it('does not persist a pending attempt when the checkout popup is blocked', async () => {
|
||||
it('does not persist the pending attempt when the checkout popup is blocked', async () => {
|
||||
const checkoutUrl = 'https://checkout.stripe.com/test'
|
||||
const openSpy = vi.spyOn(window, 'open').mockImplementation(() => null)
|
||||
|
||||
@@ -250,11 +290,18 @@ describe('performSubscriptionCheckout', () => {
|
||||
json: async () => ({ checkout_url: checkoutUrl })
|
||||
} as Response)
|
||||
|
||||
await performSubscriptionCheckout('pro', 'monthly', true)
|
||||
await performSubscriptionCheckout('pro', 'monthly')
|
||||
|
||||
expect(openSpy).toHaveBeenCalledWith(checkoutUrl, '_blank')
|
||||
expect(
|
||||
window.localStorage.getItem(PENDING_SUBSCRIPTION_CHECKOUT_STORAGE_KEY)
|
||||
).toBeNull()
|
||||
const storedAttempt = window.localStorage.getItem(
|
||||
PENDING_SUBSCRIPTION_CHECKOUT_STORAGE_KEY
|
||||
)
|
||||
expect(storedAttempt).toBeNull()
|
||||
expect(mockLocalStorage.setItem).not.toHaveBeenCalled()
|
||||
expect(mockTelemetry.trackBeginCheckout).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
checkout_attempt_id: expect.any(String)
|
||||
})
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -4,12 +4,19 @@ import { useFeatureFlags } from '@/composables/useFeatureFlags'
|
||||
import { getComfyApiBaseUrl } from '@/config/comfyApi'
|
||||
import { t } from '@/i18n'
|
||||
import { fetchWithUnifiedRemint } from '@/platform/auth/unified/remintRetry'
|
||||
import type { TierKey } from '@/platform/cloud/subscription/constants/tierPricing'
|
||||
import {
|
||||
createPendingSubscriptionCheckoutAttempt,
|
||||
persistPendingSubscriptionCheckoutAttempt,
|
||||
withPendingCheckoutAttemptId
|
||||
} from '@/platform/cloud/subscription/utils/subscriptionCheckoutTracker'
|
||||
import { isCloud } from '@/platform/distribution/types'
|
||||
import { useTelemetry } from '@/platform/telemetry'
|
||||
import type {
|
||||
CheckoutAttributionMetadata,
|
||||
PaymentIntentSource
|
||||
} from '@/platform/telemetry/types'
|
||||
import { AuthStoreError, useAuthStore } from '@/stores/authStore'
|
||||
import type { CheckoutAttributionMetadata } from '@/platform/telemetry/types'
|
||||
import type { TierKey } from '@/platform/cloud/subscription/constants/tierPricing'
|
||||
import { recordPendingSubscriptionCheckoutAttempt } from '@/platform/cloud/subscription/utils/subscriptionCheckoutTracker'
|
||||
import type { BillingCycle } from './subscriptionTierRank'
|
||||
|
||||
type CheckoutTier = TierKey | `${TierKey}-yearly`
|
||||
@@ -31,6 +38,11 @@ const getCheckoutAttributionForCloud =
|
||||
return getCheckoutAttribution()
|
||||
}
|
||||
|
||||
interface PerformSubscriptionCheckoutOptions {
|
||||
openInNewTab?: boolean
|
||||
paymentIntentSource?: PaymentIntentSource
|
||||
}
|
||||
|
||||
/**
|
||||
* Core subscription checkout logic shared between PricingTable and
|
||||
* SubscriptionRedirectView. Handles:
|
||||
@@ -47,10 +59,12 @@ const getCheckoutAttributionForCloud =
|
||||
export async function performSubscriptionCheckout(
|
||||
tierKey: TierKey,
|
||||
currentBillingCycle: BillingCycle,
|
||||
openInNewTab: boolean = true
|
||||
options: PerformSubscriptionCheckoutOptions = {}
|
||||
): Promise<void> {
|
||||
if (!isCloud) return
|
||||
|
||||
const { openInNewTab = true, paymentIntentSource } = options
|
||||
|
||||
const authStore = useAuthStore()
|
||||
const { userId } = storeToRefs(authStore)
|
||||
const telemetry = useTelemetry()
|
||||
@@ -108,14 +122,29 @@ export async function performSubscriptionCheckout(
|
||||
const data = await response.json()
|
||||
|
||||
if (data.checkout_url) {
|
||||
const pendingAttempt = createPendingSubscriptionCheckoutAttempt({
|
||||
tier: tierKey,
|
||||
cycle: currentBillingCycle,
|
||||
checkout_type: 'new',
|
||||
payment_intent_source: paymentIntentSource
|
||||
})
|
||||
|
||||
if (userId.value) {
|
||||
telemetry?.trackBeginCheckout({
|
||||
user_id: userId.value,
|
||||
tier: tierKey,
|
||||
cycle: currentBillingCycle,
|
||||
checkout_type: 'new',
|
||||
...checkoutAttribution
|
||||
})
|
||||
telemetry?.trackBeginCheckout(
|
||||
withPendingCheckoutAttemptId(
|
||||
{
|
||||
user_id: userId.value,
|
||||
tier: tierKey,
|
||||
cycle: currentBillingCycle,
|
||||
checkout_type: 'new',
|
||||
...(paymentIntentSource
|
||||
? { payment_intent_source: paymentIntentSource }
|
||||
: {}),
|
||||
...checkoutAttribution
|
||||
},
|
||||
pendingAttempt
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
if (openInNewTab) {
|
||||
@@ -123,18 +152,9 @@ export async function performSubscriptionCheckout(
|
||||
if (!checkoutWindow) {
|
||||
return
|
||||
}
|
||||
|
||||
recordPendingSubscriptionCheckoutAttempt({
|
||||
tier: tierKey,
|
||||
cycle: currentBillingCycle,
|
||||
checkout_type: 'new'
|
||||
})
|
||||
persistPendingSubscriptionCheckoutAttempt(pendingAttempt)
|
||||
} else {
|
||||
recordPendingSubscriptionCheckoutAttempt({
|
||||
tier: tierKey,
|
||||
cycle: currentBillingCycle,
|
||||
checkout_type: 'new'
|
||||
})
|
||||
persistPendingSubscriptionCheckoutAttempt(pendingAttempt)
|
||||
globalThis.location.href = data.checkout_url
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,9 +1,13 @@
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { computed, reactive } from 'vue'
|
||||
|
||||
const { mockIsCloud, mockSubscribe } = vi.hoisted(() => ({
|
||||
mockIsCloud: { value: true },
|
||||
mockSubscribe: vi.fn()
|
||||
}))
|
||||
const { mockIsCloud, mockSubscribe, mockTrackBeginCheckout, mockUserId } =
|
||||
vi.hoisted(() => ({
|
||||
mockIsCloud: { value: true },
|
||||
mockSubscribe: vi.fn(),
|
||||
mockTrackBeginCheckout: vi.fn(),
|
||||
mockUserId: { value: 'user-1' as string | null }
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/distribution/types', () => ({
|
||||
get isCloud() {
|
||||
@@ -16,6 +20,12 @@ vi.mock('@/config/comfyApi', () => ({
|
||||
vi.mock('@/platform/workspace/api/workspaceApi', () => ({
|
||||
workspaceApi: { subscribe: mockSubscribe }
|
||||
}))
|
||||
vi.mock('@/platform/telemetry', () => ({
|
||||
useTelemetry: () => ({ trackBeginCheckout: mockTrackBeginCheckout })
|
||||
}))
|
||||
vi.mock('@/stores/authStore', () => ({
|
||||
useAuthStore: () => reactive({ userId: computed(() => mockUserId.value) })
|
||||
}))
|
||||
|
||||
import { performTeamSubscriptionCheckout } from './teamSubscriptionCheckoutUtil'
|
||||
|
||||
@@ -43,7 +53,9 @@ describe('performTeamSubscriptionCheckout', () => {
|
||||
billing_op_id: 'op_1'
|
||||
})
|
||||
|
||||
await performTeamSubscriptionCheckout('team_700', 'yearly')
|
||||
await performTeamSubscriptionCheckout('team_700', 'yearly', {
|
||||
paymentIntentSource: 'deep_link'
|
||||
})
|
||||
|
||||
expect(mockSubscribe).toHaveBeenCalledWith('team_per_credit_annual', {
|
||||
returnUrl: 'https://app.test/payment/success',
|
||||
@@ -51,6 +63,14 @@ describe('performTeamSubscriptionCheckout', () => {
|
||||
teamCreditStopId: 'team_700'
|
||||
})
|
||||
expect(assignedHref).toBe('https://stripe.test/pay')
|
||||
expect(mockTrackBeginCheckout).toHaveBeenCalledWith({
|
||||
user_id: 'user-1',
|
||||
tier: 'team',
|
||||
cycle: 'yearly',
|
||||
checkout_type: 'new',
|
||||
billing_op_id: 'op_1',
|
||||
payment_intent_source: 'deep_link'
|
||||
})
|
||||
})
|
||||
|
||||
it('uses the monthly slug and lands in the app when no Stripe step is needed', async () => {
|
||||
@@ -82,6 +102,16 @@ describe('performTeamSubscriptionCheckout', () => {
|
||||
expect(assignedHref).toBeUndefined()
|
||||
})
|
||||
|
||||
it('does not track begin_checkout when subscribe fails', async () => {
|
||||
mockSubscribe.mockRejectedValueOnce(new Error('subscribe failed'))
|
||||
|
||||
await expect(
|
||||
performTeamSubscriptionCheckout('team_700', 'yearly')
|
||||
).rejects.toThrow('subscribe failed')
|
||||
|
||||
expect(mockTrackBeginCheckout).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('does nothing off cloud', async () => {
|
||||
mockIsCloud.value = false
|
||||
|
||||
|
||||
@@ -1,10 +1,16 @@
|
||||
import { getComfyPlatformBaseUrl } from '@/config/comfyApi'
|
||||
import { getTeamPlanSlug } from '@/platform/cloud/subscription/constants/teamPlanCreditStops'
|
||||
import { isCloud } from '@/platform/distribution/types'
|
||||
import type { PaymentIntentSource } from '@/platform/telemetry/types'
|
||||
import { workspaceApi } from '@/platform/workspace/api/workspaceApi'
|
||||
import { trackWorkspaceCheckoutStarted } from '@/platform/workspace/utils/workspaceCheckoutTelemetry'
|
||||
|
||||
import type { BillingCycle } from './subscriptionTierRank'
|
||||
|
||||
interface PerformTeamSubscriptionCheckoutOptions {
|
||||
paymentIntentSource?: PaymentIntentSource
|
||||
}
|
||||
|
||||
/**
|
||||
* Direct team-plan checkout for the marketing `/cloud/subscribe?tier=team` deep
|
||||
* link: subscribes to the per-credit Team plan at the chosen slider stop and
|
||||
@@ -22,7 +28,8 @@ import type { BillingCycle } from './subscriptionTierRank'
|
||||
*/
|
||||
export async function performTeamSubscriptionCheckout(
|
||||
teamCreditStopId: string,
|
||||
billingCycle: BillingCycle
|
||||
billingCycle: BillingCycle,
|
||||
options: PerformTeamSubscriptionCheckoutOptions = {}
|
||||
): Promise<void> {
|
||||
if (!isCloud) return
|
||||
|
||||
@@ -33,6 +40,14 @@ export async function performTeamSubscriptionCheckout(
|
||||
teamCreditStopId
|
||||
})
|
||||
|
||||
trackWorkspaceCheckoutStarted({
|
||||
tier: 'team',
|
||||
cycle: billingCycle,
|
||||
checkoutType: 'new',
|
||||
billingOpId: response.billing_op_id,
|
||||
paymentIntentSource: options.paymentIntentSource
|
||||
})
|
||||
|
||||
if (response.status === 'needs_payment_method') {
|
||||
// A needs_payment_method response without a URL is unusable: surface it to
|
||||
// the caller's error handling rather than silently dropping the user home
|
||||
|
||||
@@ -30,6 +30,39 @@ describe('TelemetryRegistry', () => {
|
||||
expect(b.trackSearchQuery).toHaveBeenCalledExactlyOnceWith(payload)
|
||||
})
|
||||
|
||||
it('dispatches trackBeginCheckout with intent metadata to every provider', () => {
|
||||
const a: TelemetryProvider = { trackBeginCheckout: vi.fn() }
|
||||
const b: TelemetryProvider = {}
|
||||
const registry = new TelemetryRegistry()
|
||||
registry.registerProvider(a)
|
||||
registry.registerProvider(b)
|
||||
|
||||
const metadata = {
|
||||
user_id: 'user-1',
|
||||
tier: 'pro' as const,
|
||||
cycle: 'monthly' as const,
|
||||
checkout_type: 'new' as const,
|
||||
payment_intent_source: 'subscribe_to_run' as const
|
||||
}
|
||||
registry.trackBeginCheckout(metadata)
|
||||
|
||||
expect(a.trackBeginCheckout).toHaveBeenCalledExactlyOnceWith(metadata)
|
||||
})
|
||||
|
||||
it('dispatches trackAddApiCreditButtonClicked with its source', () => {
|
||||
const provider: TelemetryProvider = {
|
||||
trackAddApiCreditButtonClicked: vi.fn()
|
||||
}
|
||||
const registry = new TelemetryRegistry()
|
||||
registry.registerProvider(provider)
|
||||
|
||||
registry.trackAddApiCreditButtonClicked({ source: 'credits_panel' })
|
||||
|
||||
expect(
|
||||
provider.trackAddApiCreditButtonClicked
|
||||
).toHaveBeenCalledExactlyOnceWith({ source: 'credits_panel' })
|
||||
})
|
||||
|
||||
it('skips providers that do not implement trackSearchQuery', () => {
|
||||
const empty: TelemetryProvider = {}
|
||||
const registry = new TelemetryRegistry()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import type { AuditLog } from '@/services/customerEventsService'
|
||||
|
||||
import type {
|
||||
AddCreditsClickMetadata,
|
||||
AuthMetadata,
|
||||
BeginCheckoutMetadata,
|
||||
DefaultViewSetMetadata,
|
||||
@@ -99,8 +100,10 @@ export class TelemetryRegistry implements TelemetryDispatcher {
|
||||
this.dispatch((provider) => provider.trackMonthlySubscriptionCancelled?.())
|
||||
}
|
||||
|
||||
trackAddApiCreditButtonClicked(): void {
|
||||
this.dispatch((provider) => provider.trackAddApiCreditButtonClicked?.())
|
||||
trackAddApiCreditButtonClicked(metadata?: AddCreditsClickMetadata): void {
|
||||
this.dispatch((provider) =>
|
||||
provider.trackAddApiCreditButtonClicked?.(metadata)
|
||||
)
|
||||
}
|
||||
|
||||
trackApiCreditTopupButtonPurchaseClicked(amount: number): void {
|
||||
|
||||
@@ -313,6 +313,42 @@ describe('PostHogTelemetryProvider', () => {
|
||||
)
|
||||
})
|
||||
|
||||
it('captures begin_checkout with intent metadata', async () => {
|
||||
const provider = createProvider()
|
||||
await vi.dynamicImportSettled()
|
||||
|
||||
provider.trackBeginCheckout({
|
||||
user_id: 'user-1',
|
||||
tier: 'pro',
|
||||
cycle: 'monthly',
|
||||
checkout_type: 'new',
|
||||
payment_intent_source: 'subscribe_to_run'
|
||||
})
|
||||
|
||||
expect(hoisted.mockCapture).toHaveBeenCalledWith(
|
||||
TelemetryEvents.BEGIN_CHECKOUT,
|
||||
{
|
||||
user_id: 'user-1',
|
||||
tier: 'pro',
|
||||
cycle: 'monthly',
|
||||
checkout_type: 'new',
|
||||
payment_intent_source: 'subscribe_to_run'
|
||||
}
|
||||
)
|
||||
})
|
||||
|
||||
it('captures add-credit clicks with their source', async () => {
|
||||
const provider = createProvider()
|
||||
await vi.dynamicImportSettled()
|
||||
|
||||
provider.trackAddApiCreditButtonClicked({ source: 'credits_panel' })
|
||||
|
||||
expect(hoisted.mockCapture).toHaveBeenCalledWith(
|
||||
TelemetryEvents.ADD_API_CREDIT_BUTTON_CLICKED,
|
||||
{ source: 'credits_panel' }
|
||||
)
|
||||
})
|
||||
|
||||
it('captures share attribution events', async () => {
|
||||
const provider = createProvider()
|
||||
await vi.dynamicImportSettled()
|
||||
|
||||
@@ -10,7 +10,9 @@ import { remoteConfig } from '@/platform/remoteConfig/remoteConfig'
|
||||
import type { RemoteConfig } from '@/platform/remoteConfig/types'
|
||||
|
||||
import type {
|
||||
AddCreditsClickMetadata,
|
||||
AuthMetadata,
|
||||
BeginCheckoutMetadata,
|
||||
DefaultViewSetMetadata,
|
||||
EnterLinearMetadata,
|
||||
ShareFlowMetadata,
|
||||
@@ -350,8 +352,12 @@ export class PostHogTelemetryProvider implements TelemetryProvider {
|
||||
this.trackEvent(eventName, metadata)
|
||||
}
|
||||
|
||||
trackAddApiCreditButtonClicked(): void {
|
||||
this.trackEvent(TelemetryEvents.ADD_API_CREDIT_BUTTON_CLICKED)
|
||||
trackAddApiCreditButtonClicked(metadata?: AddCreditsClickMetadata): void {
|
||||
this.trackEvent(TelemetryEvents.ADD_API_CREDIT_BUTTON_CLICKED, metadata)
|
||||
}
|
||||
|
||||
trackBeginCheckout(metadata: BeginCheckoutMetadata): void {
|
||||
this.trackEvent(TelemetryEvents.BEGIN_CHECKOUT, metadata)
|
||||
}
|
||||
|
||||
trackMonthlySubscriptionSucceeded(
|
||||
|
||||
@@ -115,6 +115,17 @@ describe('HostTelemetrySink', () => {
|
||||
)
|
||||
})
|
||||
|
||||
it('forwards add-credit clicks with their source', () => {
|
||||
new HostTelemetrySink().trackAddApiCreditButtonClicked({
|
||||
source: 'avatar_menu'
|
||||
})
|
||||
|
||||
expect(state.capture).toHaveBeenCalledExactlyOnceWith(
|
||||
TelemetryEvents.ADD_API_CREDIT_BUTTON_CLICKED,
|
||||
{ source: 'avatar_menu' }
|
||||
)
|
||||
})
|
||||
|
||||
it('does nothing when the host bridge is absent', () => {
|
||||
delete window.__comfyDesktop2
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ import {
|
||||
import type { AuditLog } from '@/services/customerEventsService'
|
||||
|
||||
import type {
|
||||
AddCreditsClickMetadata,
|
||||
AuthMetadata,
|
||||
BeginCheckoutMetadata,
|
||||
DefaultViewSetMetadata,
|
||||
@@ -126,8 +127,8 @@ export class HostTelemetrySink implements TelemetryProvider {
|
||||
this.capture(TelemetryEvents.MONTHLY_SUBSCRIPTION_CANCELLED)
|
||||
}
|
||||
|
||||
trackAddApiCreditButtonClicked(): void {
|
||||
this.capture(TelemetryEvents.ADD_API_CREDIT_BUTTON_CLICKED)
|
||||
trackAddApiCreditButtonClicked(metadata?: AddCreditsClickMetadata): void {
|
||||
this.capture(TelemetryEvents.ADD_API_CREDIT_BUTTON_CLICKED, metadata)
|
||||
}
|
||||
|
||||
trackApiCreditTopupButtonPurchaseClicked(amount: number): void {
|
||||
|
||||
@@ -12,12 +12,29 @@
|
||||
* 3. Check dist/assets/*.js files contain no tracking code
|
||||
*/
|
||||
|
||||
import type { SubscriptionDialogReason } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type { TierKey } from '@/platform/cloud/subscription/constants/tierPricing'
|
||||
import type { BillingCycle } from '@/platform/cloud/subscription/utils/subscriptionTierRank'
|
||||
import type { AuditLog } from '@/services/customerEventsService'
|
||||
import type { AppMode } from '@/utils/appMode'
|
||||
|
||||
export type PaymentIntentSource =
|
||||
| 'subscription_required'
|
||||
| 'out_of_credits'
|
||||
| 'top_up_blocked'
|
||||
| 'deep_link'
|
||||
| 'subscribe_to_run'
|
||||
| 'subscribe_now_button'
|
||||
| 'upgrade_to_add_credits'
|
||||
| 'settings_billing_panel'
|
||||
| 'avatar_menu_plans'
|
||||
| 'team_members_panel'
|
||||
| 'invite_member_upsell'
|
||||
| 'upload_model_upgrade'
|
||||
| 'team_upgrade_resume'
|
||||
|
||||
export type SubscriptionCheckoutType = 'new' | 'change'
|
||||
export type SubscriptionCheckoutTier = TierKey | 'team'
|
||||
|
||||
/**
|
||||
* Authentication metadata for sign-up tracking
|
||||
*/
|
||||
@@ -426,16 +443,23 @@ export interface CheckoutAttributionMetadata {
|
||||
|
||||
export interface SubscriptionMetadata {
|
||||
current_tier?: string
|
||||
reason?: SubscriptionDialogReason
|
||||
reason?: PaymentIntentSource
|
||||
}
|
||||
|
||||
export interface AddCreditsClickMetadata {
|
||||
source: 'credits_panel' | 'avatar_menu' | 'settings_billing_panel'
|
||||
}
|
||||
|
||||
export interface BeginCheckoutMetadata
|
||||
extends Record<string, unknown>, CheckoutAttributionMetadata {
|
||||
user_id: string
|
||||
tier: TierKey
|
||||
tier: SubscriptionCheckoutTier
|
||||
cycle: BillingCycle
|
||||
checkout_type: 'new' | 'change'
|
||||
checkout_type: SubscriptionCheckoutType
|
||||
checkout_attempt_id?: string
|
||||
billing_op_id?: string
|
||||
previous_tier?: TierKey
|
||||
payment_intent_source?: PaymentIntentSource
|
||||
}
|
||||
|
||||
interface EcommerceItemMetadata {
|
||||
@@ -457,8 +481,9 @@ export interface SubscriptionSuccessMetadata extends Record<string, unknown> {
|
||||
checkout_attempt_id: string
|
||||
tier: TierKey
|
||||
cycle: BillingCycle
|
||||
checkout_type: 'new' | 'change'
|
||||
checkout_type: SubscriptionCheckoutType
|
||||
previous_tier?: TierKey
|
||||
payment_intent_source?: PaymentIntentSource
|
||||
value: number
|
||||
currency: string
|
||||
ecommerce: EcommerceMetadata
|
||||
@@ -489,7 +514,7 @@ export interface TelemetryProvider {
|
||||
metadata?: SubscriptionSuccessMetadata
|
||||
): void
|
||||
trackMonthlySubscriptionCancelled?(): void
|
||||
trackAddApiCreditButtonClicked?(): void
|
||||
trackAddApiCreditButtonClicked?(metadata?: AddCreditsClickMetadata): void
|
||||
trackApiCreditTopupButtonPurchaseClicked?(amount: number): void
|
||||
trackApiCreditTopupSucceeded?(): void
|
||||
trackWorkspaceInviteSent?(metadata: WorkspaceInviteMetadata): void
|
||||
|
||||
@@ -18,12 +18,14 @@ vi.mock('@/scripts/utils', () => ({
|
||||
downloadBlob: mockDownloadBlob
|
||||
}))
|
||||
|
||||
const mockCreateTemporary = vi.hoisted(() => vi.fn())
|
||||
vi.mock('@/platform/workflow/management/stores/workflowStore', () => ({
|
||||
useWorkflowStore: () => ({ createTemporary: vi.fn() })
|
||||
useWorkflowStore: () => ({ createTemporary: mockCreateTemporary })
|
||||
}))
|
||||
|
||||
const mockOpenWorkflow = vi.hoisted(() => vi.fn())
|
||||
vi.mock('@/platform/workflow/core/services/workflowService', () => ({
|
||||
useWorkflowService: () => ({ openWorkflow: vi.fn() })
|
||||
useWorkflowService: () => ({ openWorkflow: mockOpenWorkflow })
|
||||
}))
|
||||
|
||||
const minimalWorkflow: ComfyWorkflowJSON = {
|
||||
@@ -37,6 +39,8 @@ const minimalWorkflow: ComfyWorkflowJSON = {
|
||||
describe('workflowActionsService.exportWorkflowAction', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockCreateTemporary.mockReturnValue({ path: 'temporary.json' })
|
||||
mockOpenWorkflow.mockResolvedValue({ path: 'temporary.json' })
|
||||
})
|
||||
|
||||
it('returns { cancelled: true } when the user dismisses the filename prompt', async () => {
|
||||
@@ -89,4 +93,73 @@ describe('workflowActionsService.exportWorkflowAction', () => {
|
||||
})
|
||||
expect(mockDownloadBlob).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('returns a fallback error when export throws a non-error value', async () => {
|
||||
mockGetSetting.mockReturnValue(false)
|
||||
mockDownloadBlob.mockImplementationOnce(() => {
|
||||
throw 'download failed'
|
||||
})
|
||||
const { exportWorkflowAction } = useWorkflowActionsService()
|
||||
|
||||
const result = await exportWorkflowAction(minimalWorkflow, 'wf.json')
|
||||
|
||||
expect(result).toEqual({
|
||||
success: false,
|
||||
error: 'Failed to export workflow'
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('workflowActionsService.openWorkflowAction', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockCreateTemporary.mockReturnValue({ path: 'temporary.json' })
|
||||
mockOpenWorkflow.mockResolvedValue({ path: 'temporary.json' })
|
||||
})
|
||||
|
||||
it('opens a temporary workflow and returns success', async () => {
|
||||
const { openWorkflowAction } = useWorkflowActionsService()
|
||||
|
||||
const result = await openWorkflowAction(minimalWorkflow, 'wf.json')
|
||||
|
||||
expect(result).toEqual({ success: true })
|
||||
expect(mockCreateTemporary).toHaveBeenCalledWith('wf.json', minimalWorkflow)
|
||||
expect(mockOpenWorkflow).toHaveBeenCalledWith({ path: 'temporary.json' })
|
||||
})
|
||||
|
||||
it('returns the no-workflow error when opening null', async () => {
|
||||
const { openWorkflowAction } = useWorkflowActionsService()
|
||||
|
||||
const result = await openWorkflowAction(null, 'wf.json')
|
||||
|
||||
expect(result).toEqual({
|
||||
success: false,
|
||||
error: 'No workflow data available'
|
||||
})
|
||||
expect(mockCreateTemporary).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('returns thrown error messages from failed opens', async () => {
|
||||
mockOpenWorkflow.mockRejectedValueOnce(new Error('Open failed'))
|
||||
const { openWorkflowAction } = useWorkflowActionsService()
|
||||
|
||||
const result = await openWorkflowAction(minimalWorkflow, 'wf.json')
|
||||
|
||||
expect(result).toEqual({
|
||||
success: false,
|
||||
error: 'Open failed'
|
||||
})
|
||||
})
|
||||
|
||||
it('returns a fallback error when opening throws a non-error value', async () => {
|
||||
mockOpenWorkflow.mockRejectedValueOnce('Open failed')
|
||||
const { openWorkflowAction } = useWorkflowActionsService()
|
||||
|
||||
const result = await openWorkflowAction(minimalWorkflow, 'wf.json')
|
||||
|
||||
expect(result).toEqual({
|
||||
success: false,
|
||||
error: 'Failed to open workflow'
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -17,9 +17,13 @@ import { useExecutionErrorStore } from '@/stores/executionErrorStore'
|
||||
import { useMissingModelStore } from '@/platform/missingModel/missingModelStore'
|
||||
import { useMissingMediaStore } from '@/platform/missingMedia/missingMediaStore'
|
||||
import { app } from '@/scripts/app'
|
||||
import type { ChangeTracker } from '@/scripts/changeTracker'
|
||||
import { useAppMode } from '@/composables/useAppMode'
|
||||
import type { ComfyWorkflowJSON } from '@/platform/workflow/validation/schemas/workflowSchema'
|
||||
import { createMockChangeTracker } from '@/utils/__tests__/litegraphTestUtils'
|
||||
import {
|
||||
createMockCanvasRenderingContext2D,
|
||||
createMockChangeTracker
|
||||
} from '@/utils/__tests__/litegraphTestUtils'
|
||||
import type { AppMode } from '@/utils/appMode'
|
||||
import { t } from '@/i18n'
|
||||
|
||||
@@ -61,10 +65,13 @@ function makeWorkflowData(
|
||||
}
|
||||
}
|
||||
|
||||
const { mockConfirm, mockTrackWorkflowSaved } = vi.hoisted(() => ({
|
||||
mockConfirm: vi.fn(),
|
||||
mockTrackWorkflowSaved: vi.fn()
|
||||
}))
|
||||
const { mockConfirm, mockPrompt, mockTrackWorkflowSaved, mockDownloadBlob } =
|
||||
vi.hoisted(() => ({
|
||||
mockConfirm: vi.fn(),
|
||||
mockPrompt: vi.fn(),
|
||||
mockTrackWorkflowSaved: vi.fn(),
|
||||
mockDownloadBlob: vi.fn()
|
||||
}))
|
||||
|
||||
const draftStoreMocks = vi.hoisted(() => ({
|
||||
saveDraft: vi.fn(() => true),
|
||||
@@ -75,16 +82,21 @@ const draftStoreMocks = vi.hoisted(() => ({
|
||||
|
||||
vi.mock('@/services/dialogService', () => ({
|
||||
useDialogService: () => ({
|
||||
prompt: vi.fn(),
|
||||
prompt: mockPrompt,
|
||||
confirm: mockConfirm
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('@/base/common/downloadUtil', () => ({
|
||||
downloadBlob: mockDownloadBlob
|
||||
}))
|
||||
|
||||
vi.mock('@/scripts/app', () => ({
|
||||
app: {
|
||||
canvas: { ds: { offset: [0, 0], scale: 1 } },
|
||||
rootGraph: { serialize: vi.fn(() => ({})), extra: {} },
|
||||
loadGraphData: vi.fn()
|
||||
loadGraphData: vi.fn(),
|
||||
graphToPrompt: vi.fn()
|
||||
}
|
||||
}))
|
||||
|
||||
@@ -166,6 +178,74 @@ describe('useWorkflowService', () => {
|
||||
setActivePinia(createTestingPinia({ stubActions: false }))
|
||||
vi.clearAllMocks()
|
||||
draftStoreMocks.saveDraft.mockReturnValue(true)
|
||||
mockPrompt.mockResolvedValue(null)
|
||||
})
|
||||
|
||||
describe('exportWorkflow', () => {
|
||||
beforeEach(() => {
|
||||
vi.mocked(app.graphToPrompt).mockResolvedValue({
|
||||
workflow: makeWorkflowData(),
|
||||
output: { prompt: true }
|
||||
} as never)
|
||||
})
|
||||
|
||||
it('uses the active workflow filename and adds view restore data', async () => {
|
||||
const workflowStore = useWorkflowStore()
|
||||
workflowStore.activeWorkflow = createModeTestWorkflow({
|
||||
path: 'workflows/current.json'
|
||||
})
|
||||
vi.spyOn(useSettingStore(), 'get').mockImplementation(
|
||||
(key: string): boolean => {
|
||||
if (key === 'Comfy.EnableWorkflowViewRestore') return true
|
||||
return false
|
||||
}
|
||||
)
|
||||
app.canvas.ds.offset = [25, 50]
|
||||
app.canvas.ds.scale = 0.5
|
||||
|
||||
await useWorkflowService().exportWorkflow('fallback.json', 'workflow')
|
||||
|
||||
expect(mockDownloadBlob.mock.calls[0][0]).toBe('current')
|
||||
const blob = mockDownloadBlob.mock.calls[0][1] as Blob
|
||||
const exported = JSON.parse(await blob.text()) as ComfyWorkflowJSON
|
||||
expect(exported.extra?.ds).toEqual({
|
||||
scale: 0.5,
|
||||
offset: [25, 50]
|
||||
})
|
||||
})
|
||||
|
||||
it('cancels prompted exports when the user dismisses the filename dialog', async () => {
|
||||
vi.spyOn(useSettingStore(), 'get').mockImplementation(
|
||||
(key: string): boolean => key === 'Comfy.PromptFilename'
|
||||
)
|
||||
mockPrompt.mockResolvedValue(null)
|
||||
|
||||
await useWorkflowService().exportWorkflow('workflow.json', 'output')
|
||||
|
||||
expect(mockDownloadBlob).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('appends json to prompted export filenames', async () => {
|
||||
vi.spyOn(useSettingStore(), 'get').mockImplementation(
|
||||
(key: string): boolean => key === 'Comfy.PromptFilename'
|
||||
)
|
||||
mockPrompt.mockResolvedValue('custom-name')
|
||||
|
||||
await useWorkflowService().exportWorkflow('workflow.json', 'output')
|
||||
|
||||
expect(mockDownloadBlob.mock.calls[0][0]).toBe('custom-name.json')
|
||||
})
|
||||
|
||||
it('keeps prompted export filenames that already end in json', async () => {
|
||||
vi.spyOn(useSettingStore(), 'get').mockImplementation(
|
||||
(key: string): boolean => key === 'Comfy.PromptFilename'
|
||||
)
|
||||
mockPrompt.mockResolvedValue('custom-name.JSON')
|
||||
|
||||
await useWorkflowService().exportWorkflow('workflow.json', 'output')
|
||||
|
||||
expect(mockDownloadBlob.mock.calls[0][0]).toBe('custom-name.JSON')
|
||||
})
|
||||
})
|
||||
|
||||
describe('showPendingWarnings', () => {
|
||||
@@ -227,6 +307,47 @@ describe('useWorkflowService', () => {
|
||||
).toHaveBeenCalledTimes(2)
|
||||
})
|
||||
|
||||
it('restores cached missing model and media warnings', () => {
|
||||
const modelCandidates = [
|
||||
{
|
||||
nodeId: '1',
|
||||
nodeType: 'CheckpointLoaderSimple',
|
||||
widgetName: 'ckpt_name',
|
||||
isAssetSupported: false,
|
||||
name: 'missing.safetensors',
|
||||
isMissing: true
|
||||
}
|
||||
]
|
||||
const mediaCandidates = [
|
||||
{
|
||||
nodeId: '2',
|
||||
nodeType: 'LoadImage',
|
||||
widgetName: 'image',
|
||||
mediaType: 'image' as const,
|
||||
name: 'missing.png',
|
||||
isMissing: true
|
||||
}
|
||||
]
|
||||
const workflow = createWorkflow({
|
||||
missingModelCandidates: modelCandidates,
|
||||
missingMediaCandidates: mediaCandidates
|
||||
})
|
||||
|
||||
useWorkflowService().showPendingWarnings(workflow)
|
||||
|
||||
expect(useMissingModelStore().setMissingModels).toHaveBeenCalledWith(
|
||||
modelCandidates
|
||||
)
|
||||
expect(useMissingMediaStore().setMissingMedia).toHaveBeenCalledWith(
|
||||
mediaCandidates
|
||||
)
|
||||
expect(workflow.pendingWarnings).toEqual({
|
||||
missingNodeTypes: undefined,
|
||||
missingModelCandidates: modelCandidates,
|
||||
missingMediaCandidates: mediaCandidates
|
||||
})
|
||||
})
|
||||
|
||||
it('should NOT call showErrorOverlay when silent is true even with missing nodes', () => {
|
||||
vi.spyOn(useSettingStore(), 'get').mockImplementation(
|
||||
(key: string): boolean => {
|
||||
@@ -394,6 +515,29 @@ describe('useWorkflowService', () => {
|
||||
consoleErrorSpy.mockRestore()
|
||||
}
|
||||
})
|
||||
|
||||
it('does nothing when no workflow is active', () => {
|
||||
workflowStore.activeWorkflow = null
|
||||
|
||||
useWorkflowService().beforeLoadNewGraph()
|
||||
|
||||
expect(draftStoreMocks.saveDraft).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('does not persist a draft when the active workflow has no active state', () => {
|
||||
vi.spyOn(useSettingStore(), 'get').mockImplementation((key: string) => {
|
||||
return key === 'Comfy.Workflow.Persist'
|
||||
})
|
||||
const activeWorkflow = createModeTestWorkflow({
|
||||
path: 'workflows/test.json'
|
||||
})
|
||||
activeWorkflow.changeTracker = undefined as unknown as ChangeTracker
|
||||
workflowStore.activeWorkflow = activeWorkflow
|
||||
|
||||
useWorkflowService().beforeLoadNewGraph()
|
||||
|
||||
expect(draftStoreMocks.saveDraft).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('openWorkflow deferred warnings', () => {
|
||||
@@ -485,6 +629,157 @@ describe('useWorkflowService', () => {
|
||||
useMissingNodesErrorStore().surfaceMissingNodes
|
||||
).toHaveBeenCalledTimes(2)
|
||||
})
|
||||
|
||||
it('does not reload the already active workflow unless forced', async () => {
|
||||
const workflow = createWorkflow(null, { loadable: true })
|
||||
vi.mocked(workflowStore.isActive).mockReturnValue(true)
|
||||
|
||||
await useWorkflowService().openWorkflow(workflow)
|
||||
|
||||
expect(app.loadGraphData).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('loads remote workflow data before opening unloaded workflows', async () => {
|
||||
const workflow = createWorkflow(null, { loadable: true })
|
||||
Object.assign(workflow, { isLoaded: false })
|
||||
workflow.load = vi.fn().mockResolvedValue(workflow)
|
||||
|
||||
await useWorkflowService().openWorkflow(workflow)
|
||||
|
||||
expect(workflow.load).toHaveBeenCalledOnce()
|
||||
expect(app.loadGraphData).toHaveBeenCalledWith(
|
||||
expect.anything(),
|
||||
true,
|
||||
true,
|
||||
workflow,
|
||||
expect.objectContaining({
|
||||
skipAssetScans: false
|
||||
})
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('workflow navigation helpers', () => {
|
||||
let workflowStore: ReturnType<typeof useWorkflowStore>
|
||||
|
||||
beforeEach(() => {
|
||||
workflowStore = useWorkflowStore()
|
||||
vi.mocked(app.loadGraphData).mockResolvedValue(undefined)
|
||||
})
|
||||
|
||||
it('reloads the active workflow with force', async () => {
|
||||
const active = createWorkflow(null, { loadable: true })
|
||||
workflowStore.activeWorkflow = active as LoadedComfyWorkflow
|
||||
|
||||
await useWorkflowService().reloadCurrentWorkflow()
|
||||
|
||||
expect(app.loadGraphData).toHaveBeenCalledWith(
|
||||
expect.anything(),
|
||||
true,
|
||||
true,
|
||||
active,
|
||||
expect.objectContaining({
|
||||
skipAssetScans: false
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('does nothing when reloading without an active workflow', async () => {
|
||||
workflowStore.activeWorkflow = null
|
||||
|
||||
await useWorkflowService().reloadCurrentWorkflow()
|
||||
|
||||
expect(app.loadGraphData).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('loads default and blank workflows through app loadGraphData', async () => {
|
||||
const service = useWorkflowService()
|
||||
|
||||
await service.loadDefaultWorkflow()
|
||||
await service.loadBlankWorkflow()
|
||||
|
||||
expect(app.loadGraphData).toHaveBeenCalledTimes(2)
|
||||
})
|
||||
|
||||
it('opens neighboring workflows when available', async () => {
|
||||
const next = createWorkflow(null, { loadable: true })
|
||||
const previous = createWorkflow(null, { loadable: true })
|
||||
vi.mocked(workflowStore.openedWorkflowIndexShift)
|
||||
.mockReturnValueOnce(next)
|
||||
.mockReturnValueOnce(previous)
|
||||
|
||||
const service = useWorkflowService()
|
||||
await service.loadNextOpenedWorkflow()
|
||||
await service.loadPreviousOpenedWorkflow()
|
||||
|
||||
expect(app.loadGraphData).toHaveBeenCalledTimes(2)
|
||||
})
|
||||
|
||||
it('does nothing when no neighboring workflow is available', async () => {
|
||||
vi.mocked(workflowStore.openedWorkflowIndexShift).mockReturnValue(null)
|
||||
|
||||
const service = useWorkflowService()
|
||||
await service.loadNextOpenedWorkflow()
|
||||
await service.loadPreviousOpenedWorkflow()
|
||||
|
||||
expect(app.loadGraphData).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('deleteWorkflow', () => {
|
||||
let workflowStore: ReturnType<typeof useWorkflowStore>
|
||||
|
||||
beforeEach(() => {
|
||||
setActivePinia(createTestingPinia())
|
||||
workflowStore = useWorkflowStore()
|
||||
})
|
||||
|
||||
it('returns false when delete confirmation is declined', async () => {
|
||||
vi.spyOn(useSettingStore(), 'get').mockImplementation(
|
||||
(key: string): boolean => key === 'Comfy.Workflow.ConfirmDelete'
|
||||
)
|
||||
mockConfirm.mockResolvedValue(false)
|
||||
const workflow = createModeTestWorkflow({
|
||||
path: 'workflows/delete-me.json'
|
||||
})
|
||||
|
||||
const deleted = await useWorkflowService().deleteWorkflow(workflow)
|
||||
|
||||
expect(deleted).toBe(false)
|
||||
expect(workflowStore.deleteWorkflow).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('deletes silently without showing a toast', async () => {
|
||||
const workflow = createModeTestWorkflow({
|
||||
path: 'workflows/delete-silent.json'
|
||||
})
|
||||
|
||||
const deleted = await useWorkflowService().deleteWorkflow(workflow, true)
|
||||
|
||||
expect(deleted).toBe(true)
|
||||
expect(workflowStore.deleteWorkflow).toHaveBeenCalledWith(workflow)
|
||||
expect(useToastStore().add).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('shows a toast after confirmed visible deletion', async () => {
|
||||
vi.spyOn(useSettingStore(), 'get').mockImplementation(
|
||||
(key: string): boolean => key === 'Comfy.Workflow.ConfirmDelete'
|
||||
)
|
||||
mockConfirm.mockResolvedValue(true)
|
||||
const workflow = createModeTestWorkflow({
|
||||
path: 'workflows/delete-visible.json'
|
||||
})
|
||||
|
||||
const deleted = await useWorkflowService().deleteWorkflow(workflow)
|
||||
|
||||
expect(deleted).toBe(true)
|
||||
expect(workflowStore.deleteWorkflow).toHaveBeenCalledWith(workflow)
|
||||
expect(useToastStore().add).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
severity: 'info'
|
||||
})
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('saveWorkflow', () => {
|
||||
@@ -521,6 +816,50 @@ describe('useWorkflowService', () => {
|
||||
})
|
||||
})
|
||||
|
||||
describe('duplicateWorkflow', () => {
|
||||
it('loads unloaded workflows and assigns a new id to duplicated state', async () => {
|
||||
const workflow = createModeTestWorkflow({
|
||||
path: 'workflows/source.json',
|
||||
loaded: false
|
||||
})
|
||||
workflow.load = vi.fn().mockImplementation(async () => {
|
||||
workflow.changeTracker = createMockChangeTracker({
|
||||
activeState: {
|
||||
...makeWorkflowData({ duplicated: true }),
|
||||
id: 'old-id'
|
||||
}
|
||||
})
|
||||
return workflow
|
||||
})
|
||||
vi.mocked(app.loadGraphData).mockResolvedValue(undefined)
|
||||
|
||||
await useWorkflowService().duplicateWorkflow(workflow)
|
||||
|
||||
expect(workflow.load).toHaveBeenCalledOnce()
|
||||
const duplicatedState = vi.mocked(app.loadGraphData).mock
|
||||
.calls[0][0] as ComfyWorkflowJSON
|
||||
expect(duplicatedState.id).not.toBe('old-id')
|
||||
expect(vi.mocked(app.loadGraphData).mock.calls[0][3]).toBe(
|
||||
'source (Copy)'
|
||||
)
|
||||
})
|
||||
|
||||
it('duplicates empty workflow state without assigning an id', async () => {
|
||||
const workflow = {
|
||||
isLoaded: true,
|
||||
activeState: null,
|
||||
isPersisted: false,
|
||||
filename: 'source (2)'
|
||||
} as ComfyWorkflow
|
||||
vi.mocked(app.loadGraphData).mockResolvedValue(undefined)
|
||||
|
||||
await useWorkflowService().duplicateWorkflow(workflow)
|
||||
|
||||
expect(vi.mocked(app.loadGraphData).mock.calls[0][0]).toBeNull()
|
||||
expect(vi.mocked(app.loadGraphData).mock.calls[0][3]).toBe('source')
|
||||
})
|
||||
})
|
||||
|
||||
describe('closeWorkflow', () => {
|
||||
let workflowStore: ReturnType<typeof useWorkflowStore>
|
||||
let service: ReturnType<typeof useWorkflowService>
|
||||
@@ -544,6 +883,48 @@ describe('useWorkflowService', () => {
|
||||
expect(closed).toBe(false)
|
||||
expect(workflowStore.closeWorkflow).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('returns false when dirty close confirmation is cancelled', async () => {
|
||||
const workflow = createModeTestWorkflow({
|
||||
path: 'workflows/dirty.json'
|
||||
})
|
||||
workflow.isModified = true
|
||||
mockConfirm.mockResolvedValue(null)
|
||||
|
||||
const closed = await service.closeWorkflow(workflow)
|
||||
|
||||
expect(closed).toBe(false)
|
||||
expect(workflowStore.closeWorkflow).not.toHaveBeenCalled()
|
||||
expect(draftStoreMocks.removeDraft).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('opens the most recent workflow after closing the active workflow', async () => {
|
||||
const workflow = createModeTestWorkflow({
|
||||
path: 'workflows/active.json'
|
||||
})
|
||||
const recent = createWorkflow(null, {
|
||||
loadable: true,
|
||||
path: 'workflows/recent.json'
|
||||
})
|
||||
Object.assign(workflowStore, { openWorkflows: [workflow, recent] })
|
||||
vi.mocked(workflowStore.isActive).mockImplementation(
|
||||
(candidate) => candidate === workflow
|
||||
)
|
||||
vi.mocked(workflowStore.getMostRecentWorkflow).mockReturnValue(recent)
|
||||
vi.mocked(app.loadGraphData).mockResolvedValue(undefined)
|
||||
|
||||
const closed = await service.closeWorkflow(workflow)
|
||||
|
||||
expect(closed).toBe(true)
|
||||
expect(app.loadGraphData).toHaveBeenCalledWith(
|
||||
expect.anything(),
|
||||
true,
|
||||
true,
|
||||
recent,
|
||||
expect.any(Object)
|
||||
)
|
||||
expect(workflowStore.closeWorkflow).toHaveBeenCalledWith(workflow)
|
||||
})
|
||||
})
|
||||
|
||||
describe('afterLoadNewGraph', () => {
|
||||
@@ -654,6 +1035,34 @@ describe('useWorkflowService', () => {
|
||||
expect(tempWorkflow.shareId).toBe('share-1')
|
||||
})
|
||||
|
||||
it('creates unnamed temporary workflows for null loads', async () => {
|
||||
vi.mocked(workflowStore.getWorkflowByPath).mockReturnValue(null)
|
||||
const tempWorkflow = createModeTestWorkflow({
|
||||
path: 'workflows/unsaved.json'
|
||||
})
|
||||
vi.mocked(workflowStore.createNewTemporary).mockReturnValue(tempWorkflow)
|
||||
vi.mocked(workflowStore.openWorkflow).mockResolvedValue(tempWorkflow)
|
||||
|
||||
await useWorkflowService().afterLoadNewGraph(null, makeWorkflowData())
|
||||
|
||||
expect(workflowStore.createNewTemporary).toHaveBeenCalledWith(
|
||||
undefined,
|
||||
expect.any(Object)
|
||||
)
|
||||
expect(tempWorkflow.initialMode).toBeNull()
|
||||
})
|
||||
|
||||
it('keeps existing initialMode when reusing a loaded workflow', async () => {
|
||||
existingWorkflow.initialMode = 'graph'
|
||||
|
||||
await useWorkflowService().afterLoadNewGraph(
|
||||
'repeat',
|
||||
makeWorkflowData({ linearMode: true })
|
||||
)
|
||||
|
||||
expect(existingWorkflow.initialMode).toBe('graph')
|
||||
})
|
||||
|
||||
it('preserves share attribution on repeated same-path loads', async () => {
|
||||
existingWorkflow.shareId = 'share-1'
|
||||
|
||||
@@ -1270,6 +1679,34 @@ describe('useWorkflowService', () => {
|
||||
})
|
||||
})
|
||||
|
||||
describe('insertWorkflow', () => {
|
||||
it('pastes loaded workflow data and restores the previous clipboard', async () => {
|
||||
const service = useWorkflowService()
|
||||
const workflow = createModeTestWorkflow({
|
||||
path: 'workflows/insert.json'
|
||||
})
|
||||
workflow.load = vi.fn().mockResolvedValue({
|
||||
initialState: makeWorkflowData()
|
||||
})
|
||||
const pasteFromClipboard = vi.fn()
|
||||
Object.assign(app.canvas, { pasteFromClipboard })
|
||||
vi.spyOn(HTMLCanvasElement.prototype, 'getContext').mockReturnValue(
|
||||
createMockCanvasRenderingContext2D() as unknown as ReturnType<
|
||||
HTMLCanvasElement['getContext']
|
||||
>
|
||||
)
|
||||
localStorage.setItem('litegrapheditor_clipboard', 'previous')
|
||||
|
||||
await service.insertWorkflow(workflow, { position: [10, 20] })
|
||||
|
||||
expect(workflow.load).toHaveBeenCalled()
|
||||
expect(pasteFromClipboard).toHaveBeenCalledWith({
|
||||
position: [10, 20]
|
||||
})
|
||||
expect(localStorage.getItem('litegrapheditor_clipboard')).toBe('previous')
|
||||
})
|
||||
})
|
||||
|
||||
describe('saveWorkflow', () => {
|
||||
let workflowStore: ReturnType<typeof useWorkflowStore>
|
||||
let toastStore: ReturnType<typeof useToastStore>
|
||||
|
||||
@@ -0,0 +1,56 @@
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
const { settings, workflows } = vi.hoisted(() => ({
|
||||
settings: { tabsPosition: 'Sidebar' },
|
||||
workflows: { openWorkflows: [] as unknown[] }
|
||||
}))
|
||||
|
||||
vi.mock('@/components/sidebar/tabs/WorkflowsSidebarTab.vue', () => ({
|
||||
default: {}
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/settings/settingStore', () => ({
|
||||
useSettingStore: () => ({
|
||||
get: vi.fn(() => settings.tabsPosition)
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/workflow/management/stores/workflowStore', () => ({
|
||||
useWorkflowStore: () => ({
|
||||
openWorkflows: workflows.openWorkflows
|
||||
})
|
||||
}))
|
||||
|
||||
describe('useWorkflowsSidebarTab', () => {
|
||||
beforeEach(() => {
|
||||
settings.tabsPosition = 'Sidebar'
|
||||
workflows.openWorkflows = []
|
||||
})
|
||||
|
||||
it('hides the badge when workflow tabs are not in the sidebar', async () => {
|
||||
settings.tabsPosition = 'Topbar'
|
||||
workflows.openWorkflows = [{ path: 'a' }]
|
||||
const { useWorkflowsSidebarTab } = await import('./useWorkflowsSidebarTab')
|
||||
|
||||
const sidebarTab = useWorkflowsSidebarTab()
|
||||
|
||||
expect((sidebarTab.iconBadge as () => string | null)()).toBeNull()
|
||||
})
|
||||
|
||||
it('hides the badge when no workflows are open', async () => {
|
||||
const { useWorkflowsSidebarTab } = await import('./useWorkflowsSidebarTab')
|
||||
|
||||
const sidebarTab = useWorkflowsSidebarTab()
|
||||
|
||||
expect((sidebarTab.iconBadge as () => string | null)()).toBeNull()
|
||||
})
|
||||
|
||||
it('shows the open workflow count for sidebar tabs', async () => {
|
||||
workflows.openWorkflows = [{ path: 'a' }, { path: 'b' }]
|
||||
const { useWorkflowsSidebarTab } = await import('./useWorkflowsSidebarTab')
|
||||
|
||||
const sidebarTab = useWorkflowsSidebarTab()
|
||||
|
||||
expect((sidebarTab.iconBadge as () => string | null)()).toBe('2')
|
||||
})
|
||||
})
|
||||
@@ -0,0 +1,81 @@
|
||||
import { createPinia, setActivePinia } from 'pinia'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import type { ComfyWorkflow } from '@/platform/workflow/management/stores/workflowStore'
|
||||
import { useWorkflowStore } from '@/platform/workflow/management/stores/workflowStore'
|
||||
|
||||
vi.mock('@/scripts/app', () => ({ app: {} }))
|
||||
|
||||
vi.mock('@/scripts/api', () => ({
|
||||
api: {
|
||||
addEventListener: () => {},
|
||||
getUserData: async () => ({ status: 404 }),
|
||||
storeUserData: async () => {}
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('@/renderer/core/thumbnail/useWorkflowThumbnail', () => ({
|
||||
useWorkflowThumbnail: () => ({
|
||||
moveWorkflowThumbnail: () => {},
|
||||
clearThumbnail: () => {}
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/workflow/persistence/stores/workflowDraftStoreV2', () => ({
|
||||
useWorkflowDraftStoreV2: () => ({
|
||||
getDraft: () => null,
|
||||
saveDraft: () => {},
|
||||
deleteDraft: () => {}
|
||||
})
|
||||
}))
|
||||
|
||||
interface WorkflowFlags {
|
||||
path: string
|
||||
isPersisted?: boolean
|
||||
isModified?: boolean
|
||||
}
|
||||
|
||||
function wf(flags: WorkflowFlags): ComfyWorkflow {
|
||||
return flags as unknown as ComfyWorkflow
|
||||
}
|
||||
|
||||
function paths(workflows: ComfyWorkflow[]) {
|
||||
return workflows.map((w) => w.path)
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
setActivePinia(createPinia())
|
||||
})
|
||||
|
||||
describe('workflowStore workflow lists', () => {
|
||||
it('persistedWorkflows excludes unpersisted and subgraph entries', () => {
|
||||
const store = useWorkflowStore()
|
||||
store.attachWorkflow(wf({ path: 'a.json', isPersisted: true }))
|
||||
store.attachWorkflow(wf({ path: 'b.json', isPersisted: false }))
|
||||
store.attachWorkflow(wf({ path: 'subgraphs/c.json', isPersisted: true }))
|
||||
|
||||
expect(paths(store.persistedWorkflows)).toEqual(['a.json'])
|
||||
})
|
||||
|
||||
it('modifiedWorkflows includes only modified workflows', () => {
|
||||
const store = useWorkflowStore()
|
||||
store.attachWorkflow(wf({ path: 'a.json', isModified: true }))
|
||||
store.attachWorkflow(wf({ path: 'b.json', isModified: false }))
|
||||
|
||||
expect(paths(store.modifiedWorkflows)).toEqual(['a.json'])
|
||||
})
|
||||
|
||||
it('bookmarkedWorkflows is empty when nothing is bookmarked', () => {
|
||||
const store = useWorkflowStore()
|
||||
store.attachWorkflow(wf({ path: 'a.json' }))
|
||||
|
||||
expect(store.bookmarkedWorkflows).toEqual([])
|
||||
})
|
||||
|
||||
it('openedWorkflowIndexShift returns null when no workflow is active', () => {
|
||||
const store = useWorkflowStore()
|
||||
store.attachWorkflow(wf({ path: 'a.json' }), 0)
|
||||
|
||||
expect(store.openedWorkflowIndexShift(1)).toBeNull()
|
||||
})
|
||||
})
|
||||
@@ -0,0 +1,87 @@
|
||||
import { createPinia, setActivePinia } from 'pinia'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import type { Subgraph, LGraphNode } from '@/lib/litegraph/src/litegraph'
|
||||
import { useWorkflowStore } from '@/platform/workflow/management/stores/workflowStore'
|
||||
import { createNodeLocatorId } from '@/types/nodeIdentification'
|
||||
import { toNodeId } from '@/types/nodeId'
|
||||
|
||||
vi.mock('@/scripts/app', () => ({ app: {} }))
|
||||
|
||||
vi.mock('@/scripts/api', () => ({
|
||||
api: {
|
||||
addEventListener: () => {},
|
||||
getUserData: async () => ({ status: 404 }),
|
||||
storeUserData: async () => {}
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('@/renderer/core/thumbnail/useWorkflowThumbnail', () => ({
|
||||
useWorkflowThumbnail: () => ({
|
||||
moveWorkflowThumbnail: () => {},
|
||||
clearThumbnail: () => {}
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/workflow/persistence/stores/workflowDraftStoreV2', () => ({
|
||||
useWorkflowDraftStoreV2: () => ({
|
||||
getDraft: () => null,
|
||||
saveDraft: () => {},
|
||||
deleteDraft: () => {}
|
||||
})
|
||||
}))
|
||||
|
||||
const SUBGRAPH_UUID = 'a1b2c3d4-e5f6-7890-abcd-ef1234567890'
|
||||
|
||||
beforeEach(() => {
|
||||
setActivePinia(createPinia())
|
||||
})
|
||||
|
||||
describe('workflowStore node locator translation', () => {
|
||||
it('treats a node as a root-graph node when no subgraph is active', () => {
|
||||
const store = useWorkflowStore()
|
||||
expect(store.nodeIdToNodeLocatorId(toNodeId(5))).toBe('5')
|
||||
})
|
||||
|
||||
it('prefixes the locator with an explicit subgraph uuid', () => {
|
||||
const store = useWorkflowStore()
|
||||
const subgraph = { id: SUBGRAPH_UUID } as unknown as Subgraph
|
||||
|
||||
expect(store.nodeIdToNodeLocatorId(toNodeId(5), subgraph)).toBe(
|
||||
`${SUBGRAPH_UUID}:5`
|
||||
)
|
||||
})
|
||||
|
||||
it('derives a locator from a node based on whether its graph is a subgraph', () => {
|
||||
const store = useWorkflowStore()
|
||||
const rootNode = { id: toNodeId(7), graph: {} } as unknown as LGraphNode
|
||||
expect(store.nodeToNodeLocatorId(rootNode)).toBe('7')
|
||||
})
|
||||
|
||||
it('extracts the local node id from a locator', () => {
|
||||
const store = useWorkflowStore()
|
||||
expect(
|
||||
store.nodeLocatorIdToNodeId(
|
||||
createNodeLocatorId(SUBGRAPH_UUID, toNodeId(5))
|
||||
)
|
||||
).toBe(toNodeId(5))
|
||||
expect(
|
||||
store.nodeLocatorIdToNodeId(createNodeLocatorId(null, toNodeId(9)))
|
||||
).toBe(toNodeId(9))
|
||||
})
|
||||
|
||||
it('round-trips a root node id through locator translation', () => {
|
||||
const store = useWorkflowStore()
|
||||
const locator = store.nodeIdToNodeLocatorId(toNodeId(42))
|
||||
expect(store.nodeLocatorIdToNodeId(locator)).toBe(toNodeId(42))
|
||||
})
|
||||
|
||||
it('maps a root locator to a single-segment execution id', () => {
|
||||
const store = useWorkflowStore()
|
||||
expect(
|
||||
store.nodeLocatorIdToNodeExecutionId(
|
||||
createNodeLocatorId(null, toNodeId(5))
|
||||
)
|
||||
).toBe('5')
|
||||
})
|
||||
})
|
||||
@@ -21,6 +21,7 @@ import { defaultGraph, defaultGraphJSON } from '@/scripts/defaultGraph'
|
||||
import { toNodeId } from '@/types/nodeId'
|
||||
import type { NodeId } from '@/types/nodeId'
|
||||
import { createNodeLocatorId } from '@/types/nodeIdentification'
|
||||
import type { NodeLocatorId } from '@/types/nodeIdentification'
|
||||
import { isSubgraph } from '@/utils/typeGuardUtil'
|
||||
import {
|
||||
createMockCanvas,
|
||||
@@ -205,6 +206,21 @@ describe('useWorkflowStore', () => {
|
||||
expect(workflow.content).toBeNull()
|
||||
expect(workflow.originalContent).toBeNull()
|
||||
})
|
||||
|
||||
it('should sync workflows from a nested directory', async () => {
|
||||
await syncRemoteWorkflowsWithMeta([
|
||||
{ path: 'nested.json', modified: 100, size: 1 }
|
||||
])
|
||||
|
||||
await store.syncWorkflows('subdir')
|
||||
|
||||
expect(api.listUserDataFullInfo).toHaveBeenLastCalledWith(
|
||||
'workflows/subdir'
|
||||
)
|
||||
expect(
|
||||
store.getWorkflowByPath('workflows/subdir/nested.json')
|
||||
).not.toBeNull()
|
||||
})
|
||||
})
|
||||
|
||||
describe('createTemporary', () => {
|
||||
@@ -246,6 +262,12 @@ describe('useWorkflowStore', () => {
|
||||
expect(state.id.length).toBeGreaterThan(0)
|
||||
expect(workflowDataWithoutId.id).toBeUndefined()
|
||||
})
|
||||
|
||||
it('should create a new temporary workflow with the default path', () => {
|
||||
const workflow = store.createNewTemporary()
|
||||
|
||||
expect(workflow.path).toBe('workflows/Unsaved Workflow.json')
|
||||
})
|
||||
})
|
||||
|
||||
describe('openWorkflow', () => {
|
||||
@@ -484,6 +506,28 @@ describe('useWorkflowStore', () => {
|
||||
})
|
||||
})
|
||||
|
||||
describe('openedWorkflowIndexShift', () => {
|
||||
it('returns null when there is no active workflow', () => {
|
||||
expect(store.openedWorkflowIndexShift(1)).toBeNull()
|
||||
})
|
||||
|
||||
it('wraps around open workflow tabs', async () => {
|
||||
await syncRemoteWorkflows(['a.json', 'b.json'])
|
||||
const workflowA = store.getWorkflowByPath('workflows/a.json')!
|
||||
const workflowB = store.getWorkflowByPath('workflows/b.json')!
|
||||
vi.mocked(api.getUserData).mockResolvedValue({
|
||||
status: 200,
|
||||
text: () => Promise.resolve(defaultGraphJSON)
|
||||
} as Response)
|
||||
|
||||
await store.openWorkflow(workflowA)
|
||||
await store.openWorkflow(workflowB)
|
||||
|
||||
expect(store.openedWorkflowIndexShift(1)?.path).toBe(workflowA.path)
|
||||
expect(store.openedWorkflowIndexShift(-1)?.path).toBe(workflowA.path)
|
||||
})
|
||||
})
|
||||
|
||||
describe('renameWorkflow', () => {
|
||||
it('should rename workflow and update bookmarks', async () => {
|
||||
const workflow = store.createTemporary('dir/test.json')
|
||||
@@ -556,6 +600,17 @@ describe('useWorkflowStore', () => {
|
||||
expect(bookmarkStore.isBookmarked(workflow.path)).toBe(false)
|
||||
expect(bookmarkStore.isBookmarked('test.json')).toBe(false)
|
||||
})
|
||||
|
||||
it('should reset busy state when rename fails', async () => {
|
||||
const workflow = store.createTemporary('test.json')
|
||||
vi.spyOn(workflow, 'rename').mockRejectedValue(new Error('rename failed'))
|
||||
|
||||
await expect(
|
||||
store.renameWorkflow(workflow, 'workflows/renamed.json')
|
||||
).rejects.toThrow('rename failed')
|
||||
|
||||
expect(store.isBusy).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('closeWorkflow', () => {
|
||||
@@ -568,6 +623,17 @@ describe('useWorkflowStore', () => {
|
||||
expect(store.isOpen(workflow)).toBe(false)
|
||||
expect(store.getWorkflowByPath(workflow.path)).toBeNull()
|
||||
})
|
||||
|
||||
it('should unload persisted workflows on close', async () => {
|
||||
await syncRemoteWorkflows(['a.json'])
|
||||
const workflow = store.getWorkflowByPath('workflows/a.json')!
|
||||
const unloadSpy = vi.spyOn(workflow, 'unload')
|
||||
|
||||
await store.closeWorkflow(workflow)
|
||||
|
||||
expect(unloadSpy).toHaveBeenCalled()
|
||||
expect(store.getWorkflowByPath(workflow.path)).toBe(workflow)
|
||||
})
|
||||
})
|
||||
|
||||
describe('deleteWorkflow', () => {
|
||||
@@ -603,6 +669,17 @@ describe('useWorkflowStore', () => {
|
||||
// Verify bookmark was removed
|
||||
expect(bookmarkStore.isBookmarked(workflow.path)).toBe(false)
|
||||
})
|
||||
|
||||
it('should reset busy state when delete fails', async () => {
|
||||
const workflow = store.createTemporary('test.json')
|
||||
vi.spyOn(workflow, 'delete').mockRejectedValue(new Error('delete failed'))
|
||||
|
||||
await expect(store.deleteWorkflow(workflow)).rejects.toThrow(
|
||||
'delete failed'
|
||||
)
|
||||
|
||||
expect(store.isBusy).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('save', () => {
|
||||
@@ -662,6 +739,15 @@ describe('useWorkflowStore', () => {
|
||||
expect(workflow.changeTracker!.reset).toHaveBeenCalled()
|
||||
expect(workflow.isModified).toBe(false)
|
||||
})
|
||||
|
||||
it('should reset busy state when save fails', async () => {
|
||||
const workflow = store.createTemporary('test.json')
|
||||
vi.spyOn(workflow, 'save').mockRejectedValue(new Error('save failed'))
|
||||
|
||||
await expect(store.saveWorkflow(workflow)).rejects.toThrow('save failed')
|
||||
|
||||
expect(store.isBusy).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('saveAs', () => {
|
||||
@@ -899,6 +985,33 @@ describe('useWorkflowStore', () => {
|
||||
})
|
||||
})
|
||||
|
||||
describe('nodeToNodeLocatorId', () => {
|
||||
it('should include subgraph IDs for nodes inside subgraphs', () => {
|
||||
const subgraph = fromPartial<Subgraph>({
|
||||
id: '22222222-3333-4444-8555-666666666666'
|
||||
})
|
||||
vi.mocked(isSubgraph).mockImplementation(
|
||||
(obj): obj is Subgraph => obj === subgraph
|
||||
)
|
||||
|
||||
const node = createMockLGraphNode({
|
||||
id: toNodeId(77),
|
||||
graph: subgraph
|
||||
})
|
||||
|
||||
expect(store.nodeToNodeLocatorId(node)).toBe(
|
||||
'22222222-3333-4444-8555-666666666666:77'
|
||||
)
|
||||
})
|
||||
|
||||
it('should return root locators for nodes outside subgraphs', () => {
|
||||
vi.mocked(isSubgraph).mockImplementation(() => false)
|
||||
const node = createMockLGraphNode({ id: toNodeId(77) })
|
||||
|
||||
expect(store.nodeToNodeLocatorId(node)).toBe('77')
|
||||
})
|
||||
})
|
||||
|
||||
describe('executionIdToCurrentId', () => {
|
||||
it('should convert an execution ID to the active subgraph node ID', () => {
|
||||
const result = store.executionIdToCurrentId('123:456')
|
||||
@@ -914,6 +1027,16 @@ describe('useWorkflowStore', () => {
|
||||
expect(() => store.executionIdToCurrentId('123::456')).not.toThrow()
|
||||
expect(store.executionIdToCurrentId('123::456')).toBeUndefined()
|
||||
})
|
||||
|
||||
it('should return a root-level ID as-is only when no subgraph is active', () => {
|
||||
store.activeSubgraph = undefined
|
||||
expect(store.executionIdToCurrentId('42')).toBe('42')
|
||||
expect(store.executionIdToCurrentId('123:456')).toBeUndefined()
|
||||
})
|
||||
|
||||
it('should return undefined for a root-level ID while a subgraph is active', () => {
|
||||
expect(store.executionIdToCurrentId('42')).toBeUndefined()
|
||||
})
|
||||
})
|
||||
describe('nodeLocatorIdToNodeId', () => {
|
||||
it('should extract node ID from NodeLocatorId', () => {
|
||||
@@ -950,6 +1073,14 @@ describe('useWorkflowStore', () => {
|
||||
})
|
||||
|
||||
describe('nodeLocatorIdToNodeExecutionId', () => {
|
||||
it('should return null for invalid locator IDs', () => {
|
||||
const result = store.nodeLocatorIdToNodeExecutionId(
|
||||
fromAny<NodeLocatorId, string>('bad:123')
|
||||
)
|
||||
|
||||
expect(result).toBeNull()
|
||||
})
|
||||
|
||||
it('should convert NodeLocatorId to execution ID', () => {
|
||||
vi.mocked(isSubgraph).mockImplementation((obj): obj is Subgraph => {
|
||||
return obj === store.activeSubgraph
|
||||
@@ -980,6 +1111,27 @@ describe('useWorkflowStore', () => {
|
||||
)
|
||||
expect(result).toBeNull()
|
||||
})
|
||||
|
||||
it('should return null when the target subgraph is not on the path', () => {
|
||||
vi.mocked(isSubgraph).mockImplementation((obj): obj is Subgraph => {
|
||||
return obj === store.activeSubgraph
|
||||
})
|
||||
const unrelatedSubgraph = fromPartial<Subgraph>({
|
||||
id: '33333333-4444-4555-8666-777777777777',
|
||||
_nodes: [],
|
||||
nodes: []
|
||||
})
|
||||
|
||||
const result = store.nodeLocatorIdToNodeExecutionId(
|
||||
createNodeLocatorId(
|
||||
'a1b2c3d4-e5f6-7890-abcd-ef1234567890',
|
||||
toNodeId(456)
|
||||
),
|
||||
unrelatedSubgraph
|
||||
)
|
||||
|
||||
expect(result).toBeNull()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1070,6 +1222,18 @@ describe('useWorkflowStore', () => {
|
||||
const mostRecent = store.getMostRecentWorkflow()
|
||||
expect(mostRecent).toBeNull()
|
||||
})
|
||||
|
||||
it('should trim activation history to the most recent entries', async () => {
|
||||
const workflows = Array.from({ length: 34 }, (_, index) =>
|
||||
store.createTemporary(`history-${index}.json`)
|
||||
)
|
||||
|
||||
for (const workflow of workflows) {
|
||||
await store.openWorkflow(workflow)
|
||||
}
|
||||
|
||||
expect(store.getMostRecentWorkflow()?.path).toBe(workflows[32].path)
|
||||
})
|
||||
})
|
||||
|
||||
describe('closeWorkflow draft cleanup', () => {
|
||||
@@ -1100,4 +1264,42 @@ describe('useWorkflowStore', () => {
|
||||
expect(draftStore.getDraft(workflow.path)).toBeNull()
|
||||
})
|
||||
})
|
||||
|
||||
describe('workflow bookmarks', () => {
|
||||
it('loads no bookmarks when the index response is not found', async () => {
|
||||
vi.mocked(api.getUserData).mockResolvedValueOnce({
|
||||
status: 404,
|
||||
json: () => Promise.resolve({})
|
||||
} as Response)
|
||||
|
||||
await bookmarkStore.loadBookmarks()
|
||||
|
||||
expect(bookmarkStore.isBookmarked('workflows/a.json')).toBe(false)
|
||||
})
|
||||
|
||||
it('loads an empty bookmark set from a sparse index response', async () => {
|
||||
vi.mocked(api.getUserData).mockResolvedValueOnce({
|
||||
status: 200,
|
||||
json: () => Promise.resolve(null)
|
||||
} as Response)
|
||||
|
||||
await bookmarkStore.loadBookmarks()
|
||||
|
||||
expect(bookmarkStore.isBookmarked('workflows/a.json')).toBe(false)
|
||||
})
|
||||
|
||||
it('does not save when setting an existing bookmark state', async () => {
|
||||
await bookmarkStore.setBookmarked('workflows/a.json', false)
|
||||
|
||||
expect(api.storeUserData).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('toggles bookmarks on and off', async () => {
|
||||
await bookmarkStore.toggleBookmarked('workflows/a.json')
|
||||
expect(bookmarkStore.isBookmarked('workflows/a.json')).toBe(true)
|
||||
|
||||
await bookmarkStore.toggleBookmarked('workflows/a.json')
|
||||
expect(bookmarkStore.isBookmarked('workflows/a.json')).toBe(false)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
99
src/platform/workflow/management/stores/workflowTabs.test.ts
Normal file
99
src/platform/workflow/management/stores/workflowTabs.test.ts
Normal file
@@ -0,0 +1,99 @@
|
||||
import { createPinia, setActivePinia } from 'pinia'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import type { ComfyWorkflow } from '@/platform/workflow/management/stores/workflowStore'
|
||||
import { useWorkflowStore } from '@/platform/workflow/management/stores/workflowStore'
|
||||
|
||||
vi.mock('@/scripts/app', () => ({ app: {} }))
|
||||
|
||||
vi.mock('@/scripts/api', () => ({
|
||||
api: {
|
||||
addEventListener: () => {},
|
||||
getUserData: async () => ({ status: 404 }),
|
||||
storeUserData: async () => {}
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('@/renderer/core/thumbnail/useWorkflowThumbnail', () => ({
|
||||
useWorkflowThumbnail: () => ({
|
||||
moveWorkflowThumbnail: () => {},
|
||||
clearThumbnail: () => {}
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/workflow/persistence/stores/workflowDraftStoreV2', () => ({
|
||||
useWorkflowDraftStoreV2: () => ({
|
||||
getDraft: () => null,
|
||||
saveDraft: () => {},
|
||||
deleteDraft: () => {}
|
||||
})
|
||||
}))
|
||||
|
||||
function wf(path: string): ComfyWorkflow {
|
||||
return { path } as unknown as ComfyWorkflow
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
setActivePinia(createPinia())
|
||||
})
|
||||
|
||||
describe('workflowStore tab management', () => {
|
||||
it('attaches workflows into the lookup and finds them by path', () => {
|
||||
const store = useWorkflowStore()
|
||||
const a = wf('a.json')
|
||||
store.attachWorkflow(a)
|
||||
|
||||
// Pinia wraps stored objects in reactive proxies, so compare structurally.
|
||||
expect(store.getWorkflowByPath('a.json')).toEqual(a)
|
||||
expect(store.getWorkflowByPath('missing.json')).toBeNull()
|
||||
expect(store.workflows).toContainEqual(a)
|
||||
})
|
||||
|
||||
it('tracks which workflows are open', () => {
|
||||
const store = useWorkflowStore()
|
||||
const open = wf('open.json')
|
||||
const closed = wf('closed.json')
|
||||
store.attachWorkflow(open, 0)
|
||||
store.attachWorkflow(closed)
|
||||
|
||||
expect(store.isOpen(open)).toBe(true)
|
||||
expect(store.isOpen(closed)).toBe(false)
|
||||
expect(store.openWorkflows).toEqual([open])
|
||||
})
|
||||
|
||||
it('reorders open workflow tabs', () => {
|
||||
const store = useWorkflowStore()
|
||||
const a = wf('a.json')
|
||||
const b = wf('b.json')
|
||||
const c = wf('c.json')
|
||||
store.attachWorkflow(a, 0)
|
||||
store.attachWorkflow(b, 1)
|
||||
store.attachWorkflow(c, 2)
|
||||
|
||||
store.reorderWorkflows(0, 2)
|
||||
|
||||
expect(store.openWorkflows).toEqual([b, c, a])
|
||||
})
|
||||
|
||||
it('opens background workflows on the requested side, ignoring unknown paths', () => {
|
||||
const store = useWorkflowStore()
|
||||
const left = wf('left.json')
|
||||
const mid = wf('mid.json')
|
||||
const right = wf('right.json')
|
||||
store.attachWorkflow(left)
|
||||
store.attachWorkflow(mid, 0)
|
||||
store.attachWorkflow(right)
|
||||
|
||||
store.openWorkflowsInBackground({
|
||||
left: ['left.json', 'unknown.json'],
|
||||
right: ['right.json']
|
||||
})
|
||||
|
||||
expect(store.openWorkflows).toEqual([left, mid, right])
|
||||
})
|
||||
|
||||
it('reports no active workflow before one is opened', () => {
|
||||
const store = useWorkflowStore()
|
||||
expect(store.isActive(wf('a.json'))).toBe(false)
|
||||
})
|
||||
})
|
||||
@@ -1,3 +1,4 @@
|
||||
import { fromAny } from '@total-typescript/shoehorn'
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import type { DraftIndexV2, DraftPayloadV2 } from './draftTypes'
|
||||
@@ -17,6 +18,18 @@ import {
|
||||
writePayload
|
||||
} from './storageIO'
|
||||
|
||||
function createStorageStub(overrides: Partial<Storage> = {}): Storage {
|
||||
return fromAny<Storage, unknown>({
|
||||
length: 0,
|
||||
clear: vi.fn(),
|
||||
getItem: vi.fn(() => null),
|
||||
key: vi.fn(() => null),
|
||||
removeItem: vi.fn(),
|
||||
setItem: vi.fn(),
|
||||
...overrides
|
||||
})
|
||||
}
|
||||
|
||||
describe('storageIO', () => {
|
||||
beforeEach(() => {
|
||||
localStorage.clear()
|
||||
@@ -25,8 +38,11 @@ describe('storageIO', () => {
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks()
|
||||
vi.unstubAllGlobals()
|
||||
localStorage.clear()
|
||||
sessionStorage.clear()
|
||||
vi.resetModules()
|
||||
})
|
||||
|
||||
describe('index operations', () => {
|
||||
@@ -74,6 +90,60 @@ describe('storageIO', () => {
|
||||
)
|
||||
expect(readIndex(workspaceId)).toBeNull()
|
||||
})
|
||||
|
||||
it('returns null for malformed index shapes', () => {
|
||||
for (const value of [
|
||||
null,
|
||||
42,
|
||||
{ v: 2, updatedAt: 'now', order: [], entries: {} },
|
||||
{ v: 2, updatedAt: 1, order: {}, entries: {} },
|
||||
{ v: 2, updatedAt: 1, order: [], entries: null }
|
||||
]) {
|
||||
localStorage.setItem(
|
||||
'Comfy.Workflow.DraftIndex.v2:test-workspace',
|
||||
JSON.stringify(value)
|
||||
)
|
||||
expect(readIndex(workspaceId)).toBeNull()
|
||||
}
|
||||
})
|
||||
|
||||
it('returns false for quota errors when writing an index', () => {
|
||||
vi.stubGlobal(
|
||||
'localStorage',
|
||||
createStorageStub({
|
||||
setItem: vi.fn(() => {
|
||||
throw new DOMException('full', 'QuotaExceededError')
|
||||
})
|
||||
})
|
||||
)
|
||||
const index: DraftIndexV2 = {
|
||||
v: 2,
|
||||
updatedAt: 1,
|
||||
order: [],
|
||||
entries: {}
|
||||
}
|
||||
|
||||
expect(writeIndex(workspaceId, index)).toBe(false)
|
||||
})
|
||||
|
||||
it('rethrows non-quota errors when writing an index', () => {
|
||||
vi.stubGlobal(
|
||||
'localStorage',
|
||||
createStorageStub({
|
||||
setItem: vi.fn(() => {
|
||||
throw new Error('storage failed')
|
||||
})
|
||||
})
|
||||
)
|
||||
const index: DraftIndexV2 = {
|
||||
v: 2,
|
||||
updatedAt: 1,
|
||||
order: [],
|
||||
entries: {}
|
||||
}
|
||||
|
||||
expect(() => writeIndex(workspaceId, index)).toThrow('storage failed')
|
||||
})
|
||||
})
|
||||
|
||||
describe('payload operations', () => {
|
||||
@@ -97,6 +167,45 @@ describe('storageIO', () => {
|
||||
expect(readPayload(workspaceId, 'missing')).toBeNull()
|
||||
})
|
||||
|
||||
it('returns null for invalid payload JSON', () => {
|
||||
localStorage.setItem(
|
||||
'Comfy.Workflow.Draft.v2:test-workspace:abc12345',
|
||||
'invalid'
|
||||
)
|
||||
|
||||
expect(readPayload(workspaceId, draftKey)).toBeNull()
|
||||
})
|
||||
|
||||
it('returns false for quota errors when writing payloads', () => {
|
||||
vi.stubGlobal(
|
||||
'localStorage',
|
||||
createStorageStub({
|
||||
setItem: vi.fn(() => {
|
||||
throw new DOMException('full', 'NS_ERROR_DOM_QUOTA_REACHED')
|
||||
})
|
||||
})
|
||||
)
|
||||
|
||||
expect(
|
||||
writePayload(workspaceId, draftKey, { data: '{}', updatedAt: 1 })
|
||||
).toBe(false)
|
||||
})
|
||||
|
||||
it('rethrows non-quota errors when writing payloads', () => {
|
||||
vi.stubGlobal(
|
||||
'localStorage',
|
||||
createStorageStub({
|
||||
setItem: vi.fn(() => {
|
||||
throw new Error('storage failed')
|
||||
})
|
||||
})
|
||||
)
|
||||
|
||||
expect(() =>
|
||||
writePayload(workspaceId, draftKey, { data: '{}', updatedAt: 1 })
|
||||
).toThrow('storage failed')
|
||||
})
|
||||
|
||||
it('deletes payload', () => {
|
||||
const payload: DraftPayloadV2 = {
|
||||
data: '{}',
|
||||
@@ -109,6 +218,19 @@ describe('storageIO', () => {
|
||||
expect(readPayload(workspaceId, draftKey)).toBeNull()
|
||||
})
|
||||
|
||||
it('ignores delete errors', () => {
|
||||
vi.stubGlobal(
|
||||
'localStorage',
|
||||
createStorageStub({
|
||||
removeItem: vi.fn(() => {
|
||||
throw new Error('remove failed')
|
||||
})
|
||||
})
|
||||
)
|
||||
|
||||
expect(() => deletePayload(workspaceId, draftKey)).not.toThrow()
|
||||
})
|
||||
|
||||
it('deletes multiple payloads', () => {
|
||||
writePayload(workspaceId, 'key1', { data: '{}', updatedAt: 1 })
|
||||
writePayload(workspaceId, 'key2', { data: '{}', updatedAt: 2 })
|
||||
@@ -134,6 +256,20 @@ describe('storageIO', () => {
|
||||
expect(keys).toContain('abc')
|
||||
expect(keys).toContain('def')
|
||||
})
|
||||
|
||||
it('returns an empty list when key enumeration fails', () => {
|
||||
vi.stubGlobal(
|
||||
'localStorage',
|
||||
createStorageStub({
|
||||
length: 1,
|
||||
key: vi.fn(() => {
|
||||
throw new Error('key failed')
|
||||
})
|
||||
})
|
||||
)
|
||||
|
||||
expect(getPayloadKeys('ws-1')).toEqual([])
|
||||
})
|
||||
})
|
||||
|
||||
describe('deleteOrphanPayloads', () => {
|
||||
@@ -279,6 +415,69 @@ describe('storageIO', () => {
|
||||
)
|
||||
expect(JSON.parse(raw!).workspaceId).toBe('ws-B')
|
||||
})
|
||||
|
||||
it('falls back to the last active path in localStorage', () => {
|
||||
const pointer = { workspaceId: 'ws-1', path: 'workflows/last.json' }
|
||||
localStorage.setItem(
|
||||
'Comfy.Workflow.LastActivePath:ws-1',
|
||||
JSON.stringify(pointer)
|
||||
)
|
||||
|
||||
expect(readActivePath('missing-client', 'ws-1')).toEqual(pointer)
|
||||
})
|
||||
|
||||
it('ignores invalid last active path pointers', () => {
|
||||
localStorage.setItem(
|
||||
'Comfy.Workflow.LastActivePath:ws-1',
|
||||
JSON.stringify({ workspaceId: 'ws-1', paths: [] })
|
||||
)
|
||||
|
||||
expect(readActivePath('missing-client', 'ws-1')).toBeNull()
|
||||
})
|
||||
|
||||
it('falls back to the last open paths in localStorage', () => {
|
||||
const pointer = {
|
||||
workspaceId: 'ws-1',
|
||||
paths: ['workflows/last.json'],
|
||||
activeIndex: 0
|
||||
}
|
||||
localStorage.setItem(
|
||||
'Comfy.Workflow.LastOpenPaths:ws-1',
|
||||
JSON.stringify(pointer)
|
||||
)
|
||||
|
||||
expect(readOpenPaths('missing-client', 'ws-1')).toEqual(pointer)
|
||||
})
|
||||
|
||||
it('ignores invalid migrated session pointers', () => {
|
||||
sessionStorage.setItem('Comfy.Workflow.OpenPaths:old-client', 'invalid')
|
||||
|
||||
expect(readOpenPaths('new-client', 'ws-1')).toBeNull()
|
||||
})
|
||||
|
||||
it('silently ignores pointer write failures', () => {
|
||||
const storage = createStorageStub({
|
||||
setItem: vi.fn(() => {
|
||||
throw new Error('write failed')
|
||||
})
|
||||
})
|
||||
vi.stubGlobal('localStorage', storage)
|
||||
vi.stubGlobal('sessionStorage', storage)
|
||||
|
||||
expect(() =>
|
||||
writeActivePath('client', {
|
||||
workspaceId: 'ws-1',
|
||||
path: 'workflows/a.json'
|
||||
})
|
||||
).not.toThrow()
|
||||
expect(() =>
|
||||
writeOpenPaths('client', {
|
||||
workspaceId: 'ws-1',
|
||||
paths: ['workflows/a.json'],
|
||||
activeIndex: 0
|
||||
})
|
||||
).not.toThrow()
|
||||
})
|
||||
})
|
||||
|
||||
describe('clearAllV2Storage', () => {
|
||||
@@ -317,5 +516,57 @@ describe('storageIO', () => {
|
||||
).toBeNull()
|
||||
expect(sessionStorage.getItem('unrelated')).toBe('keep')
|
||||
})
|
||||
|
||||
it('ignores storage cleanup failures', () => {
|
||||
vi.stubGlobal(
|
||||
'localStorage',
|
||||
createStorageStub({
|
||||
length: 1,
|
||||
key: vi.fn(() => 'Comfy.Workflow.Draft.v2:ws-1:abc'),
|
||||
removeItem: vi.fn(() => {
|
||||
throw new Error('remove failed')
|
||||
})
|
||||
})
|
||||
)
|
||||
vi.stubGlobal(
|
||||
'sessionStorage',
|
||||
createStorageStub({
|
||||
length: 1,
|
||||
key: vi.fn(() => 'Comfy.Workflow.ActivePath:client-1'),
|
||||
removeItem: vi.fn(() => {
|
||||
throw new Error('remove failed')
|
||||
})
|
||||
})
|
||||
)
|
||||
|
||||
expect(() => clearAllV2Storage()).not.toThrow()
|
||||
})
|
||||
})
|
||||
|
||||
describe('storage availability', () => {
|
||||
it('returns empty results and rejects writes after storage is marked unavailable', async () => {
|
||||
const storage = await import('./storageIO')
|
||||
|
||||
expect(storage.isStorageAvailable()).toBe(true)
|
||||
|
||||
storage.markStorageUnavailable()
|
||||
|
||||
expect(storage.isStorageAvailable()).toBe(false)
|
||||
expect(storage.readIndex('ws-1')).toBeNull()
|
||||
expect(storage.readPayload('ws-1', 'draft')).toBeNull()
|
||||
expect(storage.getPayloadKeys('ws-1')).toEqual([])
|
||||
expect(
|
||||
storage.writeIndex('ws-1', {
|
||||
v: 2,
|
||||
updatedAt: 1,
|
||||
order: [],
|
||||
entries: {}
|
||||
})
|
||||
).toBe(false)
|
||||
expect(
|
||||
storage.writePayload('ws-1', 'draft', { data: '{}', updatedAt: 1 })
|
||||
).toBe(false)
|
||||
expect(() => storage.clearAllV2Storage()).not.toThrow()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -5,6 +5,8 @@ import { createApp, defineComponent, nextTick } from 'vue'
|
||||
import { createI18n } from 'vue-i18n'
|
||||
|
||||
import { useWorkflowStore } from '@/platform/workflow/management/stores/workflowStore'
|
||||
import { PERSIST_DEBOUNCE_MS } from '../base/draftTypes'
|
||||
import { migrateV1toV2 } from '../migration/migrateV1toV2'
|
||||
import { useWorkflowDraftStoreV2 } from '../stores/workflowDraftStoreV2'
|
||||
import { useWorkflowPersistenceV2 } from './useWorkflowPersistenceV2'
|
||||
|
||||
@@ -40,11 +42,15 @@ vi.mock('primevue/usetoast', () => ({
|
||||
})
|
||||
}))
|
||||
|
||||
const sharedWorkflowLoaderMocks = vi.hoisted(() => ({
|
||||
load: vi.fn().mockResolvedValue('not-present')
|
||||
}))
|
||||
|
||||
vi.mock(
|
||||
'@/platform/workflow/sharing/composables/useSharedWorkflowUrlLoader',
|
||||
() => ({
|
||||
useSharedWorkflowUrlLoader: () => ({
|
||||
loadSharedWorkflowFromUrl: vi.fn().mockResolvedValue('not-present')
|
||||
loadSharedWorkflowFromUrl: sharedWorkflowLoaderMocks.load
|
||||
})
|
||||
})
|
||||
)
|
||||
@@ -58,11 +64,15 @@ vi.mock('@/platform/workflow/core/services/workflowService', () => ({
|
||||
})
|
||||
}))
|
||||
|
||||
const templateLoaderMocks = vi.hoisted(() => ({
|
||||
load: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock(
|
||||
'@/platform/workflow/templates/composables/useTemplateUrlLoader',
|
||||
() => ({
|
||||
useTemplateUrlLoader: () => ({
|
||||
loadTemplateFromUrl: vi.fn()
|
||||
loadTemplateFromUrl: templateLoaderMocks.load
|
||||
})
|
||||
})
|
||||
)
|
||||
@@ -78,7 +88,8 @@ vi.mock('@/stores/commandStore', () => ({
|
||||
}))
|
||||
|
||||
const routeMocks = vi.hoisted(() => ({
|
||||
query: {} as Record<string, unknown>
|
||||
query: {} as Record<string, unknown>,
|
||||
replace: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('vue-router', () => ({
|
||||
@@ -88,7 +99,7 @@ vi.mock('vue-router', () => ({
|
||||
}
|
||||
}),
|
||||
useRouter: () => ({
|
||||
replace: vi.fn()
|
||||
replace: routeMocks.replace
|
||||
})
|
||||
}))
|
||||
|
||||
@@ -203,8 +214,12 @@ describe('useWorkflowPersistenceV2', () => {
|
||||
mocks.apiMock.removeEventListener.mockImplementation(() => {})
|
||||
openWorkflowMock.mockReset()
|
||||
loadBlankWorkflowMock.mockReset()
|
||||
sharedWorkflowLoaderMocks.load.mockReset()
|
||||
sharedWorkflowLoaderMocks.load.mockResolvedValue('not-present')
|
||||
templateLoaderMocks.load.mockReset()
|
||||
commandStoreMocks.execute.mockReset()
|
||||
routeMocks.query = {}
|
||||
routeMocks.replace.mockReset()
|
||||
preservedQueryMocks.payloads = {}
|
||||
})
|
||||
|
||||
@@ -283,6 +298,26 @@ describe('useWorkflowPersistenceV2', () => {
|
||||
return { promise, resolve }
|
||||
}
|
||||
|
||||
describe('migration', () => {
|
||||
it('falls back to initialClientId when clientId is unavailable', () => {
|
||||
mocks.apiMock.clientId = undefined as unknown as string
|
||||
mocks.apiMock.initialClientId = 'initial-client'
|
||||
|
||||
mountWorkflowPersistence()
|
||||
|
||||
expect(migrateV1toV2).toHaveBeenCalledWith(undefined, 'initial-client')
|
||||
})
|
||||
|
||||
it('passes undefined when no API client id is available', () => {
|
||||
mocks.apiMock.clientId = undefined as unknown as string
|
||||
mocks.apiMock.initialClientId = undefined as unknown as string
|
||||
|
||||
mountWorkflowPersistence()
|
||||
|
||||
expect(migrateV1toV2).toHaveBeenCalledWith(undefined, undefined)
|
||||
})
|
||||
})
|
||||
|
||||
describe('persistence toggle', () => {
|
||||
it('resets the V2 draft store only after workflow persistence is disabled', async () => {
|
||||
const draftStore = useWorkflowDraftStoreV2()
|
||||
@@ -298,6 +333,83 @@ describe('useWorkflowPersistenceV2', () => {
|
||||
})
|
||||
})
|
||||
|
||||
describe('graph change persistence', () => {
|
||||
it('saves the active workflow draft after graphChanged debounce', async () => {
|
||||
const workflowStore = useWorkflowStore()
|
||||
const draftStore = useWorkflowDraftStoreV2()
|
||||
const workflow = workflowStore.createTemporary('ActiveWorkflow.json')
|
||||
await workflowStore.openWorkflow(workflow)
|
||||
mocks.state.currentGraph = { nodes: [{ id: 1 }] }
|
||||
|
||||
mountWorkflowPersistence()
|
||||
mocks.state.graphChangedHandler?.()
|
||||
vi.advanceTimersByTime(PERSIST_DEBOUNCE_MS)
|
||||
|
||||
const draft = draftStore.getDraft(workflow.path)
|
||||
expect(draft?.data).toBe(JSON.stringify(mocks.state.currentGraph))
|
||||
expect(draft?.name).toBe(workflow.key)
|
||||
})
|
||||
|
||||
it('shows a toast when saving the active workflow draft fails', async () => {
|
||||
const workflowStore = useWorkflowStore()
|
||||
const draftStore = useWorkflowDraftStoreV2()
|
||||
const workflow = workflowStore.createTemporary('FailingWorkflow.json')
|
||||
await workflowStore.openWorkflow(workflow)
|
||||
vi.spyOn(draftStore, 'saveDraft').mockReturnValue(false)
|
||||
|
||||
mountWorkflowPersistence()
|
||||
mocks.state.graphChangedHandler?.()
|
||||
vi.advanceTimersByTime(PERSIST_DEBOUNCE_MS)
|
||||
|
||||
expect(mockToastAdd).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
severity: 'error',
|
||||
detail: 'toastMessages.failedToSaveDraft'
|
||||
})
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('url workflow loaders', () => {
|
||||
it('loads a template from the current route query', async () => {
|
||||
routeMocks.query = { template: 'template-id' }
|
||||
const { loadTemplateFromUrlIfPresent } = mountWorkflowPersistence()
|
||||
|
||||
await loadTemplateFromUrlIfPresent()
|
||||
|
||||
expect(routeMocks.replace).not.toHaveBeenCalled()
|
||||
expect(templateLoaderMocks.load).toHaveBeenCalledOnce()
|
||||
})
|
||||
|
||||
it('hydrates preserved template intent back into the route before loading', async () => {
|
||||
preservedQueryMocks.payloads.template = { template: 'template-id' }
|
||||
const { loadTemplateFromUrlIfPresent } = mountWorkflowPersistence()
|
||||
|
||||
await loadTemplateFromUrlIfPresent()
|
||||
|
||||
expect(routeMocks.replace).toHaveBeenCalledWith({
|
||||
query: { template: 'template-id' }
|
||||
})
|
||||
expect(templateLoaderMocks.load).toHaveBeenCalledOnce()
|
||||
})
|
||||
|
||||
it('does not load a template when no template intent is present', async () => {
|
||||
const { loadTemplateFromUrlIfPresent } = mountWorkflowPersistence()
|
||||
|
||||
await loadTemplateFromUrlIfPresent()
|
||||
|
||||
expect(routeMocks.replace).not.toHaveBeenCalled()
|
||||
expect(templateLoaderMocks.load).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('returns the shared workflow loader result', async () => {
|
||||
sharedWorkflowLoaderMocks.load.mockResolvedValueOnce('loaded')
|
||||
const { loadSharedWorkflowFromUrlIfPresent } = mountWorkflowPersistence()
|
||||
|
||||
await expect(loadSharedWorkflowFromUrlIfPresent()).resolves.toBe('loaded')
|
||||
})
|
||||
})
|
||||
|
||||
describe('loadPreviousWorkflowFromStorage', () => {
|
||||
it('does not restore the active workflow early when open tab state exists', async () => {
|
||||
const workflowStore = useWorkflowStore()
|
||||
@@ -543,6 +655,49 @@ describe('useWorkflowPersistenceV2', () => {
|
||||
expect(workflowStore.openWorkflows.map((w) => w?.path)).toContain(path)
|
||||
})
|
||||
|
||||
it('recovers malformed temporary drafts with a default temporary workflow', async () => {
|
||||
const workflowStore = useWorkflowStore()
|
||||
vi.spyOn(workflowStore, 'loadWorkflows').mockResolvedValue()
|
||||
const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {})
|
||||
const draftStore = useWorkflowDraftStoreV2()
|
||||
const path = 'workflows/Broken.json'
|
||||
draftStore.saveDraft(path, '{bad json', {
|
||||
name: 'Broken.json',
|
||||
isTemporary: true
|
||||
})
|
||||
writeTabState([path], 0)
|
||||
|
||||
const { restoreWorkflowTabsState } = mountWorkflowPersistence()
|
||||
await restoreWorkflowTabsState()
|
||||
|
||||
expect(warnSpy).toHaveBeenCalledWith(
|
||||
'Failed to parse workflow draft, creating with default',
|
||||
expect.any(Error)
|
||||
)
|
||||
expect(draftStore.getDraft(path)).toBeNull()
|
||||
expect(workflowStore.getWorkflowByPath(path)?.isTemporary).toBe(true)
|
||||
|
||||
warnSpy.mockRestore()
|
||||
})
|
||||
|
||||
it('does not recreate a missing saved workflow from a non-temporary draft', async () => {
|
||||
const workflowStore = useWorkflowStore()
|
||||
vi.spyOn(workflowStore, 'loadWorkflows').mockResolvedValue()
|
||||
const draftStore = useWorkflowDraftStoreV2()
|
||||
const path = 'workflows/Saved.json'
|
||||
draftStore.saveDraft(path, JSON.stringify({ title: 'saved' }), {
|
||||
name: 'Saved.json',
|
||||
isTemporary: false
|
||||
})
|
||||
writeTabState([path], 0)
|
||||
|
||||
const { restoreWorkflowTabsState } = mountWorkflowPersistence()
|
||||
await restoreWorkflowTabsState()
|
||||
|
||||
expect(workflowStore.getWorkflowByPath(path)).toBeNull()
|
||||
expect(openWorkflowMock).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('skips activation when persistence is disabled', async () => {
|
||||
settingMocks.persistRef!.value = false
|
||||
vi.spyOn(useWorkflowStore(), 'loadWorkflows').mockResolvedValue()
|
||||
|
||||
@@ -1,15 +1,21 @@
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
vi.mock('@/scripts/api', () => ({
|
||||
const mockApi = vi.hoisted(() => ({
|
||||
api: {
|
||||
clientId: 'test-client-id',
|
||||
initialClientId: 'test-client-id'
|
||||
clientId: 'test-client-id' as string | null,
|
||||
initialClientId: 'test-client-id' as string | null
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('@/scripts/api', () => ({
|
||||
api: mockApi.api
|
||||
}))
|
||||
|
||||
describe('useWorkflowTabState', () => {
|
||||
beforeEach(() => {
|
||||
vi.resetModules()
|
||||
mockApi.api.clientId = 'test-client-id'
|
||||
mockApi.api.initialClientId = 'test-client-id'
|
||||
sessionStorage.clear()
|
||||
})
|
||||
|
||||
@@ -29,6 +35,29 @@ describe('useWorkflowTabState', () => {
|
||||
expect(getActivePath()).toBe('workflows/test.json')
|
||||
})
|
||||
|
||||
it('falls back to initial client ID before client ID is set', async () => {
|
||||
mockApi.api.clientId = null
|
||||
mockApi.api.initialClientId = 'initial-client-id'
|
||||
const { useWorkflowTabState } = await import('./useWorkflowTabState')
|
||||
const { getActivePath, setActivePath } = useWorkflowTabState()
|
||||
|
||||
setActivePath('workflows/from-initial.json')
|
||||
|
||||
expect(getActivePath()).toBe('workflows/from-initial.json')
|
||||
})
|
||||
|
||||
it('does not read or write active path without any client ID', async () => {
|
||||
mockApi.api.clientId = null
|
||||
mockApi.api.initialClientId = null
|
||||
const { useWorkflowTabState } = await import('./useWorkflowTabState')
|
||||
const { getActivePath, setActivePath } = useWorkflowTabState()
|
||||
|
||||
setActivePath('workflows/ignored.json')
|
||||
|
||||
expect(getActivePath()).toBeNull()
|
||||
expect(sessionStorage.length).toBe(0)
|
||||
})
|
||||
|
||||
it('ignores pointer from different workspace', async () => {
|
||||
sessionStorage.setItem(
|
||||
'Comfy.Workspace.Current',
|
||||
@@ -73,6 +102,18 @@ describe('useWorkflowTabState', () => {
|
||||
expect(result!.activeIndex).toBe(1)
|
||||
})
|
||||
|
||||
it('does not read or write open paths without any client ID', async () => {
|
||||
mockApi.api.clientId = null
|
||||
mockApi.api.initialClientId = null
|
||||
const { useWorkflowTabState } = await import('./useWorkflowTabState')
|
||||
const { getOpenPaths, setOpenPaths } = useWorkflowTabState()
|
||||
|
||||
setOpenPaths(['workflows/ignored.json'], 0)
|
||||
|
||||
expect(getOpenPaths()).toBeNull()
|
||||
expect(sessionStorage.length).toBe(0)
|
||||
})
|
||||
|
||||
it('ignores pointer from different workspace', async () => {
|
||||
sessionStorage.setItem(
|
||||
'Comfy.Workspace.Current',
|
||||
|
||||
@@ -19,6 +19,7 @@ describe('migrateV1toV2', () => {
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks()
|
||||
localStorage.clear()
|
||||
sessionStorage.clear()
|
||||
})
|
||||
@@ -77,6 +78,44 @@ describe('migrateV1toV2', () => {
|
||||
expect(index.order).toEqual([])
|
||||
})
|
||||
|
||||
it('creates empty V2 index when V1 draft JSON is invalid', () => {
|
||||
localStorage.setItem(`Comfy.Workflow.Drafts:${workspaceId}`, '{not-json')
|
||||
|
||||
expect(migrateV1toV2(workspaceId)).toBe(0)
|
||||
})
|
||||
|
||||
it('migrates zero drafts when V1 order is missing', () => {
|
||||
localStorage.setItem(
|
||||
`Comfy.Workflow.Drafts:${workspaceId}`,
|
||||
JSON.stringify({
|
||||
'workflows/a.json': {
|
||||
data: '{}',
|
||||
updatedAt: 1000,
|
||||
name: 'a',
|
||||
isTemporary: true
|
||||
}
|
||||
})
|
||||
)
|
||||
|
||||
expect(migrateV1toV2(workspaceId)).toBe(0)
|
||||
})
|
||||
|
||||
it('skips paths that no longer exist in V1 drafts', () => {
|
||||
setV1Data(
|
||||
{
|
||||
'workflows/a.json': {
|
||||
data: '{}',
|
||||
updatedAt: 1000,
|
||||
name: 'a',
|
||||
isTemporary: true
|
||||
}
|
||||
},
|
||||
['workflows/a.json', 'workflows/missing.json']
|
||||
)
|
||||
|
||||
expect(migrateV1toV2(workspaceId)).toBe(1)
|
||||
})
|
||||
|
||||
it('migrates V1 drafts to V2 format', () => {
|
||||
const v1Drafts = {
|
||||
'workflows/a.json': {
|
||||
@@ -211,6 +250,14 @@ describe('migrateV1toV2', () => {
|
||||
localStorage.getItem(`Comfy.Workflow.DraftOrder:${workspaceId}`)
|
||||
).toBeNull()
|
||||
})
|
||||
|
||||
it('ignores storage errors during cleanup', () => {
|
||||
vi.spyOn(Storage.prototype, 'removeItem').mockImplementation(() => {
|
||||
throw new Error('blocked')
|
||||
})
|
||||
|
||||
expect(() => cleanupV1Data(workspaceId)).not.toThrow()
|
||||
})
|
||||
})
|
||||
|
||||
describe('V1 tab state migration', () => {
|
||||
@@ -290,6 +337,71 @@ describe('migrateV1toV2', () => {
|
||||
// No tab state to migrate — should remain null
|
||||
expect(openPaths).toBeNull()
|
||||
})
|
||||
|
||||
it('clamps out-of-range V1 active tab index', () => {
|
||||
setV1Data(
|
||||
{
|
||||
'workflows/a.json': {
|
||||
data: '{}',
|
||||
updatedAt: 1000,
|
||||
name: 'a',
|
||||
isTemporary: true
|
||||
}
|
||||
},
|
||||
['workflows/a.json']
|
||||
)
|
||||
localStorage.setItem(
|
||||
'Comfy.OpenWorkflowsPaths',
|
||||
JSON.stringify(['workflows/a.json'])
|
||||
)
|
||||
localStorage.setItem('Comfy.ActiveWorkflowIndex', JSON.stringify(10))
|
||||
|
||||
migrateV1toV2(workspaceId, 'client-123')
|
||||
|
||||
expect(readOpenPaths('client-123', workspaceId)?.activeIndex).toBe(0)
|
||||
})
|
||||
|
||||
it('defaults V1 tab index when active index is invalid', () => {
|
||||
setV1Data(
|
||||
{
|
||||
'workflows/a.json': {
|
||||
data: '{}',
|
||||
updatedAt: 1000,
|
||||
name: 'a',
|
||||
isTemporary: true
|
||||
}
|
||||
},
|
||||
['workflows/a.json']
|
||||
)
|
||||
localStorage.setItem(
|
||||
'Comfy.OpenWorkflowsPaths',
|
||||
JSON.stringify(['workflows/a.json'])
|
||||
)
|
||||
localStorage.setItem('Comfy.ActiveWorkflowIndex', JSON.stringify('bad'))
|
||||
|
||||
migrateV1toV2(workspaceId, 'client-123')
|
||||
|
||||
expect(readOpenPaths('client-123', workspaceId)?.activeIndex).toBe(0)
|
||||
})
|
||||
|
||||
it('ignores invalid V1 tab state paths', () => {
|
||||
setV1Data(
|
||||
{
|
||||
'workflows/a.json': {
|
||||
data: '{}',
|
||||
updatedAt: 1000,
|
||||
name: 'a',
|
||||
isTemporary: true
|
||||
}
|
||||
},
|
||||
['workflows/a.json']
|
||||
)
|
||||
localStorage.setItem('Comfy.OpenWorkflowsPaths', JSON.stringify([]))
|
||||
|
||||
migrateV1toV2(workspaceId, 'client-123')
|
||||
|
||||
expect(readOpenPaths('client-123', workspaceId)).toBeNull()
|
||||
})
|
||||
})
|
||||
|
||||
describe('getMigrationStatus', () => {
|
||||
|
||||
@@ -5,6 +5,8 @@ import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { MAX_DRAFTS } from '../base/draftTypes'
|
||||
import { StorageKeys } from '../base/storageKeys'
|
||||
import { useWorkflowDraftStoreV2 } from './workflowDraftStoreV2'
|
||||
import { WORKSPACE_STORAGE_KEYS } from '@/platform/workspace/workspaceConstants'
|
||||
import { app as comfyApp } from '@/scripts/app'
|
||||
|
||||
vi.mock('@/scripts/api', () => ({
|
||||
api: {
|
||||
@@ -195,6 +197,12 @@ describe('workflowDraftStoreV2', () => {
|
||||
)
|
||||
expect(payloadKeys).toHaveLength(0)
|
||||
})
|
||||
|
||||
it('ignores missing drafts', () => {
|
||||
const store = useWorkflowDraftStoreV2()
|
||||
|
||||
expect(() => store.removeDraft('workflows/missing.json')).not.toThrow()
|
||||
})
|
||||
})
|
||||
|
||||
describe('moveDraft', () => {
|
||||
@@ -240,6 +248,121 @@ describe('workflowDraftStoreV2', () => {
|
||||
vi.useRealTimers()
|
||||
}
|
||||
})
|
||||
|
||||
it('ignores missing source drafts', () => {
|
||||
const store = useWorkflowDraftStoreV2()
|
||||
|
||||
store.moveDraft('workflows/missing.json', 'workflows/new.json', 'new')
|
||||
|
||||
expect(store.getDraft('workflows/new.json')).toBeNull()
|
||||
})
|
||||
|
||||
it('does not move when the old payload is missing', () => {
|
||||
const store = useWorkflowDraftStoreV2()
|
||||
|
||||
store.saveDraft('workflows/old.json', '{"data":"test"}', {
|
||||
name: 'old',
|
||||
isTemporary: true
|
||||
})
|
||||
localStorage.removeItem(
|
||||
StorageKeys.draftPayload('workflows/old.json', 'personal')
|
||||
)
|
||||
store.reset()
|
||||
|
||||
store.moveDraft('workflows/old.json', 'workflows/new.json', 'new')
|
||||
|
||||
expect(store.getDraft('workflows/new.json')).toBeNull()
|
||||
})
|
||||
|
||||
it('keeps the original draft when writing the moved payload fails', () => {
|
||||
const store = useWorkflowDraftStoreV2()
|
||||
|
||||
store.saveDraft('workflows/old.json', '{"data":"test"}', {
|
||||
name: 'old',
|
||||
isTemporary: true
|
||||
})
|
||||
|
||||
const originalSetItem = localStorage.setItem.bind(localStorage)
|
||||
const newPayloadKey = StorageKeys.draftPayload(
|
||||
'workflows/new.json',
|
||||
'personal'
|
||||
)
|
||||
const setItemSpy = vi
|
||||
.spyOn(localStorage, 'setItem')
|
||||
.mockImplementation((key: string, value: string) => {
|
||||
if (key === newPayloadKey) {
|
||||
throw new DOMException('Quota exceeded', 'QuotaExceededError')
|
||||
}
|
||||
return originalSetItem(key, value)
|
||||
})
|
||||
|
||||
try {
|
||||
store.moveDraft('workflows/old.json', 'workflows/new.json', 'new')
|
||||
|
||||
expect(store.getDraft('workflows/old.json')).not.toBeNull()
|
||||
expect(store.getDraft('workflows/new.json')).toBeNull()
|
||||
} finally {
|
||||
setItemSpy.mockRestore()
|
||||
}
|
||||
})
|
||||
|
||||
it('removes the moved payload when persisting the moved index fails', () => {
|
||||
const store = useWorkflowDraftStoreV2()
|
||||
|
||||
store.saveDraft('workflows/old.json', '{"data":"test"}', {
|
||||
name: 'old',
|
||||
isTemporary: true
|
||||
})
|
||||
|
||||
const originalSetItem = localStorage.setItem.bind(localStorage)
|
||||
const indexKey = StorageKeys.draftIndex('personal')
|
||||
const newPayloadKey = StorageKeys.draftPayload(
|
||||
'workflows/new.json',
|
||||
'personal'
|
||||
)
|
||||
const setItemSpy = vi
|
||||
.spyOn(localStorage, 'setItem')
|
||||
.mockImplementation((key: string, value: string) => {
|
||||
if (key === indexKey) {
|
||||
throw new DOMException('Quota exceeded', 'QuotaExceededError')
|
||||
}
|
||||
return originalSetItem(key, value)
|
||||
})
|
||||
|
||||
try {
|
||||
store.moveDraft('workflows/old.json', 'workflows/new.json', 'new')
|
||||
|
||||
expect(localStorage.getItem(newPayloadKey)).toBeNull()
|
||||
} finally {
|
||||
setItemSpy.mockRestore()
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
describe('getDraft', () => {
|
||||
it('removes stale index entries when the payload is missing', () => {
|
||||
const store = useWorkflowDraftStoreV2()
|
||||
|
||||
store.saveDraft('workflows/test.json', '{"nodes":[]}', {
|
||||
name: 'test',
|
||||
isTemporary: true
|
||||
})
|
||||
localStorage.removeItem(
|
||||
StorageKeys.draftPayload('workflows/test.json', 'personal')
|
||||
)
|
||||
store.reset()
|
||||
|
||||
expect(store.getDraft('workflows/test.json')).toBeNull()
|
||||
expect(store.getMostRecentPath()).toBeNull()
|
||||
})
|
||||
})
|
||||
|
||||
describe('markDraftUsed', () => {
|
||||
it('ignores unknown draft paths', () => {
|
||||
const store = useWorkflowDraftStoreV2()
|
||||
|
||||
expect(() => store.markDraftUsed('workflows/missing.json')).not.toThrow()
|
||||
})
|
||||
})
|
||||
|
||||
describe('getMostRecentPath', () => {
|
||||
@@ -262,6 +385,22 @@ describe('workflowDraftStoreV2', () => {
|
||||
const store = useWorkflowDraftStoreV2()
|
||||
expect(store.getMostRecentPath()).toBeNull()
|
||||
})
|
||||
|
||||
it('returns null when the newest index key has no entry', () => {
|
||||
const indexKey = StorageKeys.draftIndex('personal')
|
||||
localStorage.setItem(
|
||||
indexKey,
|
||||
JSON.stringify({
|
||||
v: 2,
|
||||
updatedAt: Date.now(),
|
||||
order: ['missing'],
|
||||
entries: {}
|
||||
})
|
||||
)
|
||||
const store = useWorkflowDraftStoreV2()
|
||||
|
||||
expect(store.getMostRecentPath()).toBeNull()
|
||||
})
|
||||
})
|
||||
|
||||
describe('loadPersistedWorkflow', () => {
|
||||
@@ -308,6 +447,57 @@ describe('workflowDraftStoreV2', () => {
|
||||
|
||||
expect(result).toBe(false)
|
||||
})
|
||||
|
||||
it('loads legacy session workflow payloads in personal workspace', async () => {
|
||||
const store = useWorkflowDraftStoreV2()
|
||||
sessionStorage.setItem('workflow:test-client', '{"nodes":[]}')
|
||||
|
||||
const result = await store.loadPersistedWorkflow({
|
||||
workflowName: 'legacy'
|
||||
})
|
||||
|
||||
expect(result).toBe(true)
|
||||
expect(comfyApp.loadGraphData).toHaveBeenCalledWith(
|
||||
{ nodes: [] },
|
||||
true,
|
||||
true,
|
||||
'legacy'
|
||||
)
|
||||
})
|
||||
|
||||
it('falls back to legacy local workflow payloads in personal workspace', async () => {
|
||||
const store = useWorkflowDraftStoreV2()
|
||||
localStorage.setItem('workflow', '{"nodes":[1]}')
|
||||
|
||||
const result = await store.loadPersistedWorkflow({
|
||||
workflowName: null
|
||||
})
|
||||
|
||||
expect(result).toBe(true)
|
||||
expect(comfyApp.loadGraphData).toHaveBeenCalledWith(
|
||||
{ nodes: [1] },
|
||||
true,
|
||||
true,
|
||||
null
|
||||
)
|
||||
})
|
||||
|
||||
it('does not load legacy payloads for non-personal workspaces', async () => {
|
||||
sessionStorage.setItem(
|
||||
WORKSPACE_STORAGE_KEYS.CURRENT_WORKSPACE,
|
||||
JSON.stringify({ id: 'team-1', type: 'organization' })
|
||||
)
|
||||
sessionStorage.setItem('workflow:test-client', '{"nodes":[]}')
|
||||
localStorage.setItem('workflow', '{"nodes":[]}')
|
||||
const store = useWorkflowDraftStoreV2()
|
||||
|
||||
const result = await store.loadPersistedWorkflow({
|
||||
workflowName: 'team'
|
||||
})
|
||||
|
||||
expect(result).toBe(false)
|
||||
expect(comfyApp.loadGraphData).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('reset', () => {
|
||||
|
||||
@@ -15,8 +15,10 @@ const mockWorkflowStore = reactive<{
|
||||
isModified: boolean
|
||||
lastModified: number
|
||||
} | null
|
||||
saveWorkflow: ReturnType<typeof vi.fn>
|
||||
}>({
|
||||
activeWorkflow: null
|
||||
activeWorkflow: null,
|
||||
saveWorkflow: vi.fn()
|
||||
})
|
||||
|
||||
vi.mock('@/platform/workflow/management/stores/workflowStore', () => ({
|
||||
@@ -63,13 +65,18 @@ vi.mock(
|
||||
})
|
||||
)
|
||||
|
||||
vi.mock('@/platform/workflow/core/services/workflowService', () => ({
|
||||
useWorkflowService: () => ({
|
||||
saveWorkflow: vi.fn(),
|
||||
renameWorkflow: vi.fn()
|
||||
})
|
||||
const mockWorkflowService = vi.hoisted(() => ({
|
||||
saveWorkflow: vi.fn(),
|
||||
renameWorkflow: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/workflow/core/services/workflowService', () => ({
|
||||
useWorkflowService: () => mockWorkflowService
|
||||
}))
|
||||
|
||||
const mockInputFocus = vi.hoisted(() => vi.fn())
|
||||
const mockInputSelect = vi.hoisted(() => vi.fn())
|
||||
|
||||
const mockShareServiceData = vi.hoisted(() => ({
|
||||
items: [
|
||||
{
|
||||
@@ -113,6 +120,8 @@ const i18n = createI18n({
|
||||
g: { close: 'Close', error: 'Error' },
|
||||
shareWorkflow: {
|
||||
unsavedDescription: 'You must save your workflow before sharing.',
|
||||
saveFailedTitle: 'Save failed',
|
||||
saveFailedDescription: 'Unable to save workflow',
|
||||
shareLinkTab: 'Share',
|
||||
publishToHubTab: 'Publish',
|
||||
workflowNameLabel: 'Workflow name',
|
||||
@@ -138,6 +147,9 @@ const i18n = createI18n({
|
||||
introTitle: 'Introducing ComfyHub',
|
||||
createProfileButton: 'Create my profile',
|
||||
startPublishingButton: 'Start publishing'
|
||||
},
|
||||
comfyHubPublish: {
|
||||
unsavedDescription: 'Save before publishing.'
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -152,6 +164,9 @@ describe('ShareWorkflowDialogContent', () => {
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockWorkflowStore.saveWorkflow.mockReset()
|
||||
mockWorkflowService.saveWorkflow.mockReset()
|
||||
mockWorkflowService.renameWorkflow.mockReset()
|
||||
mockPublishWorkflow.mockReset()
|
||||
mockGetShareableAssets.mockReset()
|
||||
mockWorkflowStore.activeWorkflow = {
|
||||
@@ -214,8 +229,14 @@ describe('ShareWorkflowDialogContent', () => {
|
||||
props: ['onCreateProfile']
|
||||
},
|
||||
Input: {
|
||||
template: '<input v-bind="$attrs" />',
|
||||
methods: { focus() {}, select() {} }
|
||||
template:
|
||||
'<input v-bind="$attrs" :value="modelValue" @input="$emit(\'update:modelValue\', $event.target.value)" />',
|
||||
props: ['modelValue'],
|
||||
emits: ['update:modelValue'],
|
||||
methods: {
|
||||
focus: mockInputFocus,
|
||||
select: mockInputSelect
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -556,5 +577,150 @@ describe('ShareWorkflowDialogContent', () => {
|
||||
expect(screen.queryByTestId('publish-tab-panel')).not.toBeInTheDocument()
|
||||
expect(container.textContent).not.toContain('Publish')
|
||||
})
|
||||
|
||||
it('focuses the temporary workflow name when switching to publish mode', async () => {
|
||||
mockFlags.comfyHubUploadEnabled = true
|
||||
mockWorkflowStore.activeWorkflow = {
|
||||
path: 'Unsaved Workflow.json',
|
||||
directory: '',
|
||||
filename: 'Unsaved Workflow.json',
|
||||
isTemporary: true,
|
||||
isModified: false,
|
||||
lastModified: 1000
|
||||
}
|
||||
renderComponent()
|
||||
await flushPromises()
|
||||
mockInputFocus.mockClear()
|
||||
mockInputSelect.mockClear()
|
||||
|
||||
await userEvent.click(screen.getByRole('tab', { name: /Publish/ }))
|
||||
await nextTick()
|
||||
|
||||
expect(mockInputFocus).toHaveBeenCalled()
|
||||
expect(mockInputSelect).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('renames and saves a temporary workflow from publish mode', async () => {
|
||||
mockFlags.comfyHubUploadEnabled = true
|
||||
const workflow = {
|
||||
path: 'Unsaved Workflow.json',
|
||||
directory: '',
|
||||
filename: 'Unsaved Workflow.json',
|
||||
isTemporary: true,
|
||||
isModified: false,
|
||||
lastModified: 1000
|
||||
}
|
||||
mockWorkflowStore.activeWorkflow = workflow
|
||||
renderComponent()
|
||||
await flushPromises()
|
||||
|
||||
await userEvent.click(screen.getByRole('tab', { name: /Publish/ }))
|
||||
await userEvent.clear(screen.getByRole('textbox'))
|
||||
await userEvent.type(screen.getByRole('textbox'), ' Better name ')
|
||||
await userEvent.click(
|
||||
screen.getByRole('button', { name: /Save workflow/ })
|
||||
)
|
||||
await flushPromises()
|
||||
|
||||
expect(mockWorkflowService.renameWorkflow).toHaveBeenCalledWith(
|
||||
workflow,
|
||||
'Better name.json'
|
||||
)
|
||||
expect(mockWorkflowStore.saveWorkflow).toHaveBeenCalledWith(workflow)
|
||||
})
|
||||
|
||||
it('does not save a temporary workflow with a blank name', async () => {
|
||||
mockWorkflowStore.activeWorkflow = {
|
||||
path: 'Unsaved Workflow.json',
|
||||
directory: 'workflows',
|
||||
filename: 'Unsaved Workflow.json',
|
||||
isTemporary: true,
|
||||
isModified: false,
|
||||
lastModified: 1000
|
||||
}
|
||||
renderComponent()
|
||||
await flushPromises()
|
||||
|
||||
await userEvent.clear(screen.getByRole('textbox'))
|
||||
await userEvent.click(
|
||||
screen.getByRole('button', { name: /Save workflow/ })
|
||||
)
|
||||
await flushPromises()
|
||||
|
||||
expect(mockWorkflowService.renameWorkflow).not.toHaveBeenCalled()
|
||||
expect(mockWorkflowStore.saveWorkflow).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('saves a modified persisted workflow without renaming it', async () => {
|
||||
const workflow = {
|
||||
path: 'workflows/test.json',
|
||||
directory: 'workflows',
|
||||
filename: 'test.json',
|
||||
isTemporary: false,
|
||||
isModified: true,
|
||||
lastModified: 1000
|
||||
}
|
||||
mockWorkflowStore.activeWorkflow = workflow
|
||||
renderComponent()
|
||||
await flushPromises()
|
||||
|
||||
await userEvent.click(
|
||||
screen.getByRole('button', { name: /Save workflow/ })
|
||||
)
|
||||
await flushPromises()
|
||||
|
||||
expect(mockWorkflowService.saveWorkflow).toHaveBeenCalledWith(workflow)
|
||||
expect(mockWorkflowService.renameWorkflow).not.toHaveBeenCalled()
|
||||
expect(mockWorkflowStore.saveWorkflow).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('shows an error toast when saving fails', async () => {
|
||||
mockWorkflowStore.activeWorkflow = {
|
||||
path: 'workflows/test.json',
|
||||
directory: 'workflows',
|
||||
filename: 'test.json',
|
||||
isTemporary: false,
|
||||
isModified: true,
|
||||
lastModified: 1000
|
||||
}
|
||||
mockWorkflowService.saveWorkflow.mockRejectedValue(new Error('disk full'))
|
||||
const errorSpy = vi
|
||||
.spyOn(console, 'error')
|
||||
.mockImplementation(() => undefined)
|
||||
renderComponent()
|
||||
await flushPromises()
|
||||
|
||||
await userEvent.click(
|
||||
screen.getByRole('button', { name: /Save workflow/ })
|
||||
)
|
||||
await flushPromises()
|
||||
|
||||
expect(mockToast.add).toHaveBeenCalledWith({
|
||||
severity: 'error',
|
||||
summary: 'Save failed',
|
||||
detail: 'Unable to save workflow'
|
||||
})
|
||||
errorSpy.mockRestore()
|
||||
})
|
||||
|
||||
it('uses the generic error copy for non-error publish failures', async () => {
|
||||
mockGetShareableAssets.mockResolvedValue([])
|
||||
mockPublishWorkflow.mockRejectedValue('offline')
|
||||
const errorSpy = vi
|
||||
.spyOn(console, 'error')
|
||||
.mockImplementation(() => undefined)
|
||||
renderComponent()
|
||||
await flushPromises()
|
||||
|
||||
await userEvent.click(screen.getByRole('button', { name: /Create link/ }))
|
||||
await flushPromises()
|
||||
|
||||
expect(mockToast.add).toHaveBeenCalledWith({
|
||||
severity: 'error',
|
||||
summary: 'Error',
|
||||
detail: 'Error'
|
||||
})
|
||||
errorSpy.mockRestore()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -0,0 +1,282 @@
|
||||
import { render, screen } from '@testing-library/vue'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { nextTick } from 'vue'
|
||||
|
||||
import ComfyHubCreateProfileForm from './ComfyHubCreateProfileForm.vue'
|
||||
import type { ComponentProps } from 'vue-component-type-helpers'
|
||||
|
||||
const mockCreateProfile = vi.hoisted(() => vi.fn())
|
||||
const mockToast = vi.hoisted(() => ({
|
||||
add: vi.fn()
|
||||
}))
|
||||
const mockIsFileTooLarge = vi.hoisted(() => vi.fn())
|
||||
|
||||
vi.mock('vue-i18n', () => ({
|
||||
useI18n: () => ({
|
||||
t: (key: string) => key
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('primevue/usetoast', () => ({
|
||||
useToast: () => mockToast
|
||||
}))
|
||||
|
||||
vi.mock('@vueuse/core', async () => {
|
||||
const { computed } = await import('vue')
|
||||
return {
|
||||
useObjectUrl: (file: { value: File | null }) =>
|
||||
computed(() => (file.value ? `blob:${file.value.name}` : undefined))
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock(
|
||||
'@/platform/workflow/sharing/composables/useComfyHubProfileGate',
|
||||
() => ({
|
||||
useComfyHubProfileGate: () => ({
|
||||
createProfile: mockCreateProfile
|
||||
})
|
||||
})
|
||||
)
|
||||
|
||||
vi.mock('@/platform/workflow/sharing/utils/validateFileSize', () => ({
|
||||
MAX_IMAGE_SIZE_MB: 10,
|
||||
isFileTooLarge: mockIsFileTooLarge
|
||||
}))
|
||||
|
||||
vi.mock('@/components/ui/button/Button.vue', () => ({
|
||||
default: {
|
||||
props: ['disabled', 'ariaLabel'],
|
||||
emits: ['click'],
|
||||
template: `
|
||||
<button
|
||||
type="button"
|
||||
:disabled="disabled"
|
||||
:aria-label="ariaLabel"
|
||||
@click="$emit('click')"
|
||||
>
|
||||
<slot />
|
||||
</button>
|
||||
`
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('@/components/ui/input/Input.vue', () => ({
|
||||
default: {
|
||||
props: ['modelValue', 'id', 'placeholder'],
|
||||
emits: ['update:modelValue'],
|
||||
template: `
|
||||
<input
|
||||
:id="id"
|
||||
:value="modelValue"
|
||||
:placeholder="placeholder"
|
||||
v-bind="$attrs"
|
||||
@input="$emit('update:modelValue', $event.target.value)"
|
||||
/>
|
||||
`
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('@/components/ui/textarea/Textarea.vue', () => ({
|
||||
default: {
|
||||
props: ['modelValue', 'id', 'placeholder'],
|
||||
emits: ['update:modelValue'],
|
||||
template: `
|
||||
<textarea
|
||||
:id="id"
|
||||
:value="modelValue"
|
||||
:placeholder="placeholder"
|
||||
v-bind="$attrs"
|
||||
@input="$emit('update:modelValue', $event.target.value)"
|
||||
/>
|
||||
`
|
||||
}
|
||||
}))
|
||||
|
||||
function renderForm(
|
||||
props: Partial<ComponentProps<typeof ComfyHubCreateProfileForm>> = {}
|
||||
) {
|
||||
return render(ComfyHubCreateProfileForm, {
|
||||
props: {
|
||||
onProfileCreated: vi.fn(),
|
||||
onClose: vi.fn(),
|
||||
...props
|
||||
},
|
||||
global: {
|
||||
mocks: { $t: (key: string) => key }
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
function profileFile(name = 'avatar.png') {
|
||||
return new File(['image'], name, { type: 'image/png' })
|
||||
}
|
||||
|
||||
async function flushPromises() {
|
||||
await Promise.resolve()
|
||||
await nextTick()
|
||||
}
|
||||
|
||||
describe('ComfyHubCreateProfileForm', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockCreateProfile.mockResolvedValue({
|
||||
username: 'valid-user',
|
||||
name: 'Valid User'
|
||||
})
|
||||
mockIsFileTooLarge.mockReturnValue(false)
|
||||
})
|
||||
|
||||
it('renders close and cancel actions and can hide the close header button', async () => {
|
||||
const user = userEvent.setup()
|
||||
const onClose = vi.fn()
|
||||
const { unmount } = renderForm({ onClose })
|
||||
|
||||
await user.click(screen.getByRole('button', { name: 'g.close' }))
|
||||
expect(onClose).toHaveBeenCalledTimes(1)
|
||||
|
||||
await user.click(screen.getByRole('button', { name: 'g.cancel' }))
|
||||
expect(onClose).toHaveBeenCalledTimes(2)
|
||||
|
||||
unmount()
|
||||
renderForm({ onClose, showCloseButton: false })
|
||||
|
||||
expect(
|
||||
screen.queryByRole('button', { name: 'g.close' })
|
||||
).not.toBeInTheDocument()
|
||||
expect(screen.getByText('C')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('validates usernames and derives the profile initial from name or username', async () => {
|
||||
const user = userEvent.setup()
|
||||
renderForm()
|
||||
|
||||
expect(screen.getByText('C')).toBeInTheDocument()
|
||||
|
||||
await user.type(screen.getByLabelText('comfyHubProfile.nameLabel'), 'Ada')
|
||||
expect(screen.getByText('A')).toBeInTheDocument()
|
||||
|
||||
await user.clear(screen.getByLabelText('comfyHubProfile.nameLabel'))
|
||||
await user.type(
|
||||
screen.getByLabelText('comfyHubProfile.usernameLabel'),
|
||||
'bad_name'
|
||||
)
|
||||
|
||||
expect(screen.getByText('B')).toBeInTheDocument()
|
||||
expect(
|
||||
screen.getByText('comfyHubProfile.usernameError')
|
||||
).toBeInTheDocument()
|
||||
expect(
|
||||
screen.getByRole('button', { name: 'comfyHubProfile.createProfile' })
|
||||
).toBeDisabled()
|
||||
})
|
||||
|
||||
it('ignores oversized images and previews an accepted profile image', async () => {
|
||||
const user = userEvent.setup()
|
||||
renderForm()
|
||||
|
||||
const input = screen.getByLabelText('comfyHubProfile.chooseProfilePicture')
|
||||
|
||||
mockIsFileTooLarge.mockReturnValueOnce(true)
|
||||
await user.upload(input, profileFile('large.png'))
|
||||
expect(
|
||||
screen.queryByAltText('comfyHubProfile.chooseProfilePicture')
|
||||
).not.toBeInTheDocument()
|
||||
|
||||
const acceptedFile = profileFile()
|
||||
await user.upload(input, acceptedFile)
|
||||
|
||||
expect(mockIsFileTooLarge).toHaveBeenLastCalledWith(acceptedFile, 10)
|
||||
expect(
|
||||
screen.getByAltText('comfyHubProfile.chooseProfilePicture')
|
||||
).toHaveAttribute('src', 'blob:avatar.png')
|
||||
})
|
||||
|
||||
it('creates a trimmed profile and reports it to the parent', async () => {
|
||||
const user = userEvent.setup()
|
||||
const onProfileCreated = vi.fn()
|
||||
renderForm({ onProfileCreated })
|
||||
|
||||
const file = profileFile()
|
||||
const input = screen.getByLabelText('comfyHubProfile.chooseProfilePicture')
|
||||
|
||||
await user.upload(input, file)
|
||||
await user.type(
|
||||
screen.getByLabelText('comfyHubProfile.usernameLabel'),
|
||||
'valid-user'
|
||||
)
|
||||
await user.type(
|
||||
screen.getByLabelText('comfyHubProfile.nameLabel'),
|
||||
' Ada Lovelace '
|
||||
)
|
||||
await user.type(
|
||||
screen.getByLabelText('comfyHubProfile.descriptionLabel'),
|
||||
' '
|
||||
)
|
||||
await user.click(
|
||||
screen.getByRole('button', { name: 'comfyHubProfile.createProfile' })
|
||||
)
|
||||
|
||||
expect(mockCreateProfile).toHaveBeenCalledWith({
|
||||
username: 'valid-user',
|
||||
name: 'Ada Lovelace',
|
||||
description: undefined,
|
||||
profilePicture: file
|
||||
})
|
||||
expect(onProfileCreated).toHaveBeenCalledWith({
|
||||
username: 'valid-user',
|
||||
name: 'Valid User'
|
||||
})
|
||||
})
|
||||
|
||||
it('shows loading text while creating and surfaces creation errors', async () => {
|
||||
const user = userEvent.setup()
|
||||
let resolveCreate: (profile: { username: string }) => void
|
||||
mockCreateProfile.mockReturnValueOnce(
|
||||
new Promise<{ username: string }>((resolve) => {
|
||||
resolveCreate = resolve
|
||||
})
|
||||
)
|
||||
renderForm()
|
||||
|
||||
await user.type(
|
||||
screen.getByLabelText('comfyHubProfile.usernameLabel'),
|
||||
'valid-user'
|
||||
)
|
||||
await user.click(
|
||||
screen.getByRole('button', { name: 'comfyHubProfile.createProfile' })
|
||||
)
|
||||
await nextTick()
|
||||
|
||||
expect(
|
||||
screen.getByRole('button', { name: 'comfyHubProfile.creatingProfile' })
|
||||
).toBeDisabled()
|
||||
|
||||
resolveCreate!({ username: 'valid-user' })
|
||||
await flushPromises()
|
||||
|
||||
mockCreateProfile.mockRejectedValueOnce(new Error('already taken'))
|
||||
await user.click(
|
||||
screen.getByRole('button', { name: 'comfyHubProfile.createProfile' })
|
||||
)
|
||||
await nextTick()
|
||||
|
||||
expect(mockToast.add).toHaveBeenCalledWith({
|
||||
severity: 'error',
|
||||
summary: 'g.error',
|
||||
detail: 'already taken'
|
||||
})
|
||||
|
||||
mockCreateProfile.mockRejectedValueOnce('unknown')
|
||||
await user.click(
|
||||
screen.getByRole('button', { name: 'comfyHubProfile.createProfile' })
|
||||
)
|
||||
await nextTick()
|
||||
|
||||
expect(mockToast.add).toHaveBeenLastCalledWith({
|
||||
severity: 'error',
|
||||
summary: 'g.error',
|
||||
detail: 'g.error'
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -0,0 +1,75 @@
|
||||
import { render, screen } from '@testing-library/vue'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import ComfyHubPublishIntroPanel from './ComfyHubPublishIntroPanel.vue'
|
||||
|
||||
vi.mock('@/components/ui/button/Button.vue', () => ({
|
||||
default: {
|
||||
props: ['ariaLabel'],
|
||||
emits: ['click'],
|
||||
template: `
|
||||
<button type="button" :aria-label="ariaLabel" @click="$emit('click')">
|
||||
<slot />
|
||||
</button>
|
||||
`
|
||||
}
|
||||
}))
|
||||
|
||||
function renderPanel(
|
||||
props: Partial<InstanceType<typeof ComfyHubPublishIntroPanel>['$props']> = {}
|
||||
) {
|
||||
return render(ComfyHubPublishIntroPanel, {
|
||||
props: {
|
||||
onCreateProfile: vi.fn(),
|
||||
onClose: vi.fn(),
|
||||
...props
|
||||
},
|
||||
global: {
|
||||
mocks: { $t: (key: string) => key }
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
describe('ComfyHubPublishIntroPanel', () => {
|
||||
it('renders the publish intro and handles close and create actions', async () => {
|
||||
const user = userEvent.setup()
|
||||
const onClose = vi.fn()
|
||||
const onCreateProfile = vi.fn()
|
||||
renderPanel({ onClose, onCreateProfile })
|
||||
|
||||
expect(screen.getByText('comfyHubProfile.introTitle')).toBeInTheDocument()
|
||||
expect(
|
||||
screen.getByText('comfyHubProfile.introDescription')
|
||||
).toBeInTheDocument()
|
||||
|
||||
await user.click(screen.getByRole('button', { name: 'g.close' }))
|
||||
await user.click(
|
||||
screen.getByRole('button', {
|
||||
name: 'comfyHubProfile.startPublishingButton'
|
||||
})
|
||||
)
|
||||
|
||||
expect(onClose).toHaveBeenCalledTimes(1)
|
||||
expect(onCreateProfile).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('renders the update variant without the close button', () => {
|
||||
renderPanel({ showCloseButton: false, isUpdate: true })
|
||||
|
||||
expect(
|
||||
screen.queryByRole('button', { name: 'g.close' })
|
||||
).not.toBeInTheDocument()
|
||||
expect(
|
||||
screen.getByText('comfyHubProfile.updateIntroTitle')
|
||||
).toBeInTheDocument()
|
||||
expect(
|
||||
screen.getByText('comfyHubProfile.updateIntroDescription')
|
||||
).toBeInTheDocument()
|
||||
expect(
|
||||
screen.getByRole('button', {
|
||||
name: 'comfyHubProfile.startUpdatingButton'
|
||||
})
|
||||
).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
@@ -1,15 +1,48 @@
|
||||
import { fireEvent, render, screen } from '@testing-library/vue'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import { createTestingPinia } from '@pinia/testing'
|
||||
import { setActivePinia } from 'pinia'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import type { ExampleImage } from '@/platform/workflow/sharing/types/comfyHubTypes'
|
||||
import { MAX_IMAGE_SIZE_MB } from '@/platform/workflow/sharing/utils/validateFileSize'
|
||||
|
||||
import ComfyHubExamplesStep from './ComfyHubExamplesStep.vue'
|
||||
|
||||
type DragData = Record<string, unknown>
|
||||
|
||||
type DraggableOptions = {
|
||||
getInitialData?: () => DragData
|
||||
}
|
||||
|
||||
type MonitorOptions = {
|
||||
canMonitor: (args: { source: { data: DragData } }) => boolean
|
||||
onDrop: (args: {
|
||||
source: { data: DragData }
|
||||
location: {
|
||||
current: {
|
||||
dropTargets: Array<{ data: DragData }>
|
||||
}
|
||||
}
|
||||
}) => void
|
||||
}
|
||||
|
||||
const pragmatic = vi.hoisted(() => ({
|
||||
draggables: [] as DraggableOptions[],
|
||||
monitor: undefined as MonitorOptions | undefined,
|
||||
cleanupMonitor: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('@atlaskit/pragmatic-drag-and-drop/element/adapter', () => ({
|
||||
draggable: vi.fn(() => vi.fn()),
|
||||
draggable: vi.fn((options: DraggableOptions) => {
|
||||
pragmatic.draggables.push(options)
|
||||
return vi.fn()
|
||||
}),
|
||||
dropTargetForElements: vi.fn(() => vi.fn()),
|
||||
monitorForElements: vi.fn(() => vi.fn())
|
||||
monitorForElements: vi.fn((options: MonitorOptions) => {
|
||||
pragmatic.monitor = options
|
||||
return pragmatic.cleanupMonitor
|
||||
})
|
||||
}))
|
||||
|
||||
function createImages(count: number): ExampleImage[] {
|
||||
@@ -19,6 +52,14 @@ function createImages(count: number): ExampleImage[] {
|
||||
}))
|
||||
}
|
||||
|
||||
function createImageFile(name: string, size = 1): File {
|
||||
return new File([new Uint8Array(size)], name, { type: 'image/png' })
|
||||
}
|
||||
|
||||
function createTextFile(name: string): File {
|
||||
return new File(['text'], name, { type: 'text/plain' })
|
||||
}
|
||||
|
||||
function renderStep(
|
||||
images: ExampleImage[],
|
||||
callbacks: Record<string, ReturnType<typeof vi.fn>> = {}
|
||||
@@ -31,9 +72,33 @@ function renderStep(
|
||||
})
|
||||
}
|
||||
|
||||
function getUploadInput() {
|
||||
const labelContent = screen.getByText('comfyHubPublish.uploadExampleImage')
|
||||
// eslint-disable-next-line testing-library/no-node-access
|
||||
const label = labelContent.closest('label')
|
||||
// eslint-disable-next-line testing-library/no-node-access
|
||||
const input = label?.querySelector('input[type="file"]')
|
||||
if (!(input instanceof HTMLInputElement)) {
|
||||
throw new Error('Missing file input')
|
||||
}
|
||||
return input
|
||||
}
|
||||
|
||||
describe('ComfyHubExamplesStep', () => {
|
||||
beforeEach(() => {
|
||||
setActivePinia(createTestingPinia())
|
||||
vi.clearAllMocks()
|
||||
pragmatic.draggables = []
|
||||
pragmatic.monitor = undefined
|
||||
pragmatic.cleanupMonitor.mockClear()
|
||||
Object.defineProperty(URL, 'createObjectURL', {
|
||||
configurable: true,
|
||||
value: vi.fn((file: File) => `blob:${file.name}`)
|
||||
})
|
||||
Object.defineProperty(URL, 'revokeObjectURL', {
|
||||
configurable: true,
|
||||
value: vi.fn()
|
||||
})
|
||||
})
|
||||
|
||||
it('renders all example images', () => {
|
||||
@@ -111,4 +176,209 @@ describe('ComfyHubExamplesStep', () => {
|
||||
expect(onUpdateExampleImages).toHaveBeenCalled()
|
||||
expect(onUpdateExampleImages.mock.calls[0][0]).toHaveLength(1)
|
||||
})
|
||||
|
||||
it('hides the upload tile when the example limit is reached', () => {
|
||||
renderStep(createImages(8))
|
||||
|
||||
expect(
|
||||
screen.queryByRole('button', {
|
||||
name: 'comfyHubPublish.uploadExampleImage'
|
||||
})
|
||||
).toBeNull()
|
||||
})
|
||||
|
||||
it('prepends selected image files and filters invalid uploads', async () => {
|
||||
const onUpdateExampleImages = vi.fn()
|
||||
renderStep(createImages(1), {
|
||||
'onUpdate:exampleImages': onUpdateExampleImages
|
||||
})
|
||||
|
||||
await userEvent.upload(getUploadInput(), [
|
||||
createImageFile('valid.png'),
|
||||
createTextFile('notes.txt'),
|
||||
createImageFile('too-large.png', MAX_IMAGE_SIZE_MB * 1024 * 1024 + 1)
|
||||
])
|
||||
|
||||
const updated = onUpdateExampleImages.mock.calls[0][0] as ExampleImage[]
|
||||
expect(updated.map((image) => image.url)).toEqual([
|
||||
'blob:valid.png',
|
||||
'blob:http://localhost/img-0'
|
||||
])
|
||||
})
|
||||
|
||||
it('revokes overflow uploads when only one example slot remains', async () => {
|
||||
const onUpdateExampleImages = vi.fn()
|
||||
renderStep(createImages(7), {
|
||||
'onUpdate:exampleImages': onUpdateExampleImages
|
||||
})
|
||||
|
||||
await fireEvent.drop(
|
||||
screen.getByRole('button', {
|
||||
name: 'comfyHubPublish.uploadExampleImage'
|
||||
}),
|
||||
{
|
||||
dataTransfer: {
|
||||
files: [createImageFile('first.png'), createImageFile('second.png')]
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
const updated = onUpdateExampleImages.mock.calls[0][0] as ExampleImage[]
|
||||
expect(updated).toHaveLength(8)
|
||||
expect(updated[0].url).toBe('blob:first.png')
|
||||
expect(URL.revokeObjectURL).toHaveBeenCalledWith('blob:second.png')
|
||||
})
|
||||
|
||||
it('revokes object URLs when removing uploaded images', async () => {
|
||||
const onUpdateExampleImages = vi.fn()
|
||||
const uploaded = createImageFile('uploaded.png')
|
||||
renderStep(
|
||||
[
|
||||
{
|
||||
id: 'uploaded',
|
||||
url: 'blob:uploaded.png',
|
||||
file: uploaded
|
||||
}
|
||||
],
|
||||
{
|
||||
'onUpdate:exampleImages': onUpdateExampleImages
|
||||
}
|
||||
)
|
||||
|
||||
await userEvent.click(
|
||||
screen.getByRole('button', {
|
||||
name: 'comfyHubPublish.removeExampleImage'
|
||||
})
|
||||
)
|
||||
|
||||
expect(URL.revokeObjectURL).toHaveBeenCalledWith('blob:uploaded.png')
|
||||
expect(onUpdateExampleImages.mock.calls[0][0]).toEqual([])
|
||||
})
|
||||
|
||||
it('monitors drags from its own image grid only', () => {
|
||||
renderStep(createImages(1))
|
||||
const monitor = pragmatic.monitor
|
||||
const dragData = pragmatic.draggables[0]?.getInitialData?.()
|
||||
if (!monitor || !dragData) {
|
||||
throw new Error('Missing drag monitor setup')
|
||||
}
|
||||
|
||||
expect(monitor.canMonitor({ source: { data: dragData } })).toBe(true)
|
||||
expect(
|
||||
monitor.canMonitor({
|
||||
source: {
|
||||
data: {
|
||||
...dragData,
|
||||
instanceId: Symbol('other-grid')
|
||||
}
|
||||
}
|
||||
})
|
||||
).toBe(false)
|
||||
})
|
||||
|
||||
it('reorders images when the drag monitor drops on another image', () => {
|
||||
const onUpdateExampleImages = vi.fn()
|
||||
renderStep(createImages(3), {
|
||||
'onUpdate:exampleImages': onUpdateExampleImages
|
||||
})
|
||||
const monitor = pragmatic.monitor
|
||||
const dragData = pragmatic.draggables[0]?.getInitialData?.()
|
||||
if (!monitor || !dragData) {
|
||||
throw new Error('Missing drag monitor setup')
|
||||
}
|
||||
|
||||
monitor.onDrop({
|
||||
source: { data: dragData },
|
||||
location: {
|
||||
current: {
|
||||
dropTargets: [{ data: { imageId: 'img-2' } }]
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
const reordered = onUpdateExampleImages.mock.calls[0][0] as ExampleImage[]
|
||||
expect(reordered.map((img) => img.id)).toEqual(['img-1', 'img-2', 'img-0'])
|
||||
})
|
||||
|
||||
it('ignores monitor drops without a destination image', () => {
|
||||
const onUpdateExampleImages = vi.fn()
|
||||
renderStep(createImages(2), {
|
||||
'onUpdate:exampleImages': onUpdateExampleImages
|
||||
})
|
||||
const monitor = pragmatic.monitor
|
||||
const dragData = pragmatic.draggables[0]?.getInitialData?.()
|
||||
if (!monitor || !dragData) {
|
||||
throw new Error('Missing drag monitor setup')
|
||||
}
|
||||
|
||||
monitor.onDrop({
|
||||
source: { data: dragData },
|
||||
location: { current: { dropTargets: [] } }
|
||||
})
|
||||
monitor.onDrop({
|
||||
source: { data: { ...dragData, imageId: 1 } },
|
||||
location: {
|
||||
current: {
|
||||
dropTargets: [{ data: { imageId: 'img-1' } }]
|
||||
}
|
||||
}
|
||||
})
|
||||
monitor.onDrop({
|
||||
source: { data: dragData },
|
||||
location: {
|
||||
current: {
|
||||
dropTargets: [{ data: { imageId: 1 } }]
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
expect(onUpdateExampleImages).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('inserts files from an image tile drop', async () => {
|
||||
const onUpdateExampleImages = vi.fn()
|
||||
renderStep(createImages(2), {
|
||||
'onUpdate:exampleImages': onUpdateExampleImages
|
||||
})
|
||||
|
||||
await fireEvent.drop(screen.getAllByRole('listitem')[1], {
|
||||
dataTransfer: {
|
||||
files: [createImageFile('inserted.png')]
|
||||
}
|
||||
})
|
||||
|
||||
const updated = onUpdateExampleImages.mock.calls[0][0] as ExampleImage[]
|
||||
expect(updated.map((image) => image.url)).toEqual([
|
||||
'blob:http://localhost/img-0',
|
||||
'blob:inserted.png',
|
||||
'blob:http://localhost/img-1'
|
||||
])
|
||||
})
|
||||
|
||||
it('replaces existing images when inserting into a full grid', async () => {
|
||||
const onUpdateExampleImages = vi.fn()
|
||||
const original = createImages(8).map((image, index) => ({
|
||||
...image,
|
||||
file: index === 1 ? createImageFile('old.png') : undefined
|
||||
}))
|
||||
renderStep(original, {
|
||||
'onUpdate:exampleImages': onUpdateExampleImages
|
||||
})
|
||||
|
||||
await fireEvent.drop(screen.getAllByRole('listitem')[1], {
|
||||
dataTransfer: {
|
||||
files: [createImageFile('replacement.png')]
|
||||
}
|
||||
})
|
||||
|
||||
const updated = onUpdateExampleImages.mock.calls[0][0] as ExampleImage[]
|
||||
expect(updated.map((image) => image.url).slice(0, 3)).toEqual([
|
||||
'blob:http://localhost/img-0',
|
||||
'blob:replacement.png',
|
||||
'blob:http://localhost/img-2'
|
||||
])
|
||||
expect(URL.revokeObjectURL).toHaveBeenCalledWith(
|
||||
'blob:http://localhost/img-1'
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -0,0 +1,165 @@
|
||||
import { render, screen } from '@testing-library/vue'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { nextTick } from 'vue'
|
||||
|
||||
import type { AssetInfo, ComfyHubProfile } from '@/schemas/apiSchema'
|
||||
|
||||
import ComfyHubFinishStep from './ComfyHubFinishStep.vue'
|
||||
|
||||
const mockAsyncState = vi.hoisted(() => ({
|
||||
refs: null as null | {
|
||||
state: { value: AssetInfo[] }
|
||||
isLoading: { value: boolean }
|
||||
error: { value: Error | null }
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('@vueuse/core', async () => {
|
||||
const { ref } = await import('vue')
|
||||
|
||||
const state = ref<AssetInfo[]>([])
|
||||
const isLoading = ref(false)
|
||||
const error = ref<Error | null>(null)
|
||||
|
||||
mockAsyncState.refs = {
|
||||
state,
|
||||
isLoading,
|
||||
error
|
||||
}
|
||||
|
||||
return {
|
||||
useAsyncState: () => ({
|
||||
state,
|
||||
isLoading,
|
||||
error
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock('@/platform/workflow/sharing/services/workflowShareService', () => ({
|
||||
useWorkflowShareService: () => ({
|
||||
getShareableAssets: vi.fn()
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock(
|
||||
'@/platform/workflow/sharing/components/ShareAssetWarningBox.vue',
|
||||
() => ({
|
||||
default: {
|
||||
props: ['items', 'acknowledged'],
|
||||
emits: ['update:acknowledged'],
|
||||
template: `
|
||||
<section data-testid="asset-warning">
|
||||
<span v-for="item in items" :key="item.id">{{ item.name }}</span>
|
||||
<button type="button" @click="$emit('update:acknowledged', true)">
|
||||
acknowledge
|
||||
</button>
|
||||
</section>
|
||||
`
|
||||
}
|
||||
})
|
||||
)
|
||||
|
||||
const profile: ComfyHubProfile = {
|
||||
username: 'ada',
|
||||
name: 'Ada Lovelace',
|
||||
description: 'First programmer'
|
||||
}
|
||||
|
||||
function setAsyncState({
|
||||
assets = [],
|
||||
loading = false,
|
||||
error = null
|
||||
}: {
|
||||
assets?: AssetInfo[]
|
||||
loading?: boolean
|
||||
error?: Error | null
|
||||
} = {}) {
|
||||
if (!mockAsyncState.refs)
|
||||
throw new Error('async state refs were not initialized')
|
||||
mockAsyncState.refs.state.value = assets
|
||||
mockAsyncState.refs.isLoading.value = loading
|
||||
mockAsyncState.refs.error.value = error
|
||||
}
|
||||
|
||||
function renderStep(
|
||||
props: Partial<InstanceType<typeof ComfyHubFinishStep>['$props']> = {}
|
||||
) {
|
||||
return render(ComfyHubFinishStep, {
|
||||
props: {
|
||||
profile,
|
||||
...props
|
||||
},
|
||||
global: {
|
||||
mocks: { $t: (key: string) => key }
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
describe('ComfyHubFinishStep', () => {
|
||||
beforeEach(() => {
|
||||
setAsyncState()
|
||||
})
|
||||
|
||||
it('renders profile pictures while assets are loading', () => {
|
||||
setAsyncState({ loading: true })
|
||||
|
||||
renderStep({
|
||||
profile: {
|
||||
...profile,
|
||||
profilePictureUrl: 'https://cdn.example.com/ada.png'
|
||||
}
|
||||
})
|
||||
|
||||
expect(screen.getByAltText('ada')).toHaveAttribute(
|
||||
'src',
|
||||
'https://cdn.example.com/ada.png'
|
||||
)
|
||||
expect(screen.getByText('shareWorkflow.checkingAssets')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('requires acknowledging private assets before it becomes ready', async () => {
|
||||
const user = userEvent.setup()
|
||||
setAsyncState({
|
||||
assets: [
|
||||
{
|
||||
id: 'asset-1',
|
||||
name: 'private.png',
|
||||
preview_url: '',
|
||||
storage_url: '',
|
||||
model: false,
|
||||
public: false,
|
||||
in_library: false
|
||||
}
|
||||
]
|
||||
})
|
||||
|
||||
renderStep()
|
||||
|
||||
expect(screen.getByText('A')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('asset-warning')).toHaveTextContent('private.png')
|
||||
|
||||
await user.click(screen.getByRole('button', { name: 'acknowledge' }))
|
||||
await nextTick()
|
||||
|
||||
expect(screen.getByTestId('asset-warning')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('is ready when no assets are private', () => {
|
||||
renderStep()
|
||||
|
||||
expect(
|
||||
screen.queryByText('comfyHubPublish.additionalInfo')
|
||||
).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('is not ready when asset loading fails', () => {
|
||||
setAsyncState({ error: new Error('failed') })
|
||||
renderStep()
|
||||
|
||||
expect(
|
||||
screen.queryByText('shareWorkflow.checkingAssets')
|
||||
).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
@@ -0,0 +1,33 @@
|
||||
import { render, screen } from '@testing-library/vue'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import { describe, expect, it } from 'vitest'
|
||||
|
||||
import ComfyHubProfilePromptPanel from './ComfyHubProfilePromptPanel.vue'
|
||||
|
||||
function renderProfilePrompt() {
|
||||
return render(ComfyHubProfilePromptPanel, {
|
||||
global: {
|
||||
mocks: { $t: (key: string) => key },
|
||||
stubs: {
|
||||
Button: {
|
||||
emits: ['click'],
|
||||
template: '<button @click="$emit(\'click\')"><slot /></button>'
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
describe('ComfyHubProfilePromptPanel', () => {
|
||||
it('emits a profile request from the create profile CTA', async () => {
|
||||
const user = userEvent.setup()
|
||||
const { emitted } = renderProfilePrompt()
|
||||
|
||||
await user.click(screen.getByText('comfyHubPublish.createProfileCta'))
|
||||
|
||||
expect(
|
||||
screen.getByText('comfyHubPublish.createProfileToPublish')
|
||||
).toBeTruthy()
|
||||
expect(emitted().requestProfile).toHaveLength(1)
|
||||
})
|
||||
})
|
||||
@@ -1,7 +1,7 @@
|
||||
import { render, screen } from '@testing-library/vue'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { nextTick, ref } from 'vue'
|
||||
import { defineComponent, h, nextTick, ref } from 'vue'
|
||||
|
||||
vi.mock('vue-i18n', async (importOriginal) => {
|
||||
const actual = await importOriginal()
|
||||
@@ -31,9 +31,19 @@ const mockGetCachedPrefill = vi.hoisted(() => vi.fn())
|
||||
const mockSubmitToComfyHub = vi.hoisted(() => vi.fn())
|
||||
const mockGetPublishStatus = vi.hoisted(() => vi.fn())
|
||||
const mockRenameWorkflow = vi.hoisted(() => vi.fn())
|
||||
const mockWorkflowServiceSaveWorkflow = vi.hoisted(() => vi.fn())
|
||||
const mockWorkflowStoreSaveWorkflow = vi.hoisted(() => vi.fn())
|
||||
const mockInputFocus = vi.hoisted(() => vi.fn())
|
||||
const mockInputSelect = vi.hoisted(() => vi.fn())
|
||||
const mockFormDataHolder = vi.hoisted(
|
||||
() => ({ value: null }) as { value: Record<string, unknown> | null }
|
||||
)
|
||||
const mockFormDataRefHolder = vi.hoisted(
|
||||
() =>
|
||||
({ value: null }) as {
|
||||
value: null | { value: Record<string, unknown> | null }
|
||||
}
|
||||
)
|
||||
|
||||
vi.mock(
|
||||
'@/platform/workflow/sharing/composables/useComfyHubProfileGate',
|
||||
@@ -64,10 +74,12 @@ vi.mock(
|
||||
tutorialUrl: '',
|
||||
metadata: {}
|
||||
}
|
||||
const formData = ref(mockFormDataHolder.value)
|
||||
mockFormDataRefHolder.value = formData
|
||||
return {
|
||||
useComfyHubPublishWizard: () => ({
|
||||
currentStep: ref('finish'),
|
||||
formData: ref(mockFormDataHolder.value),
|
||||
formData,
|
||||
isFirstStep: ref(false),
|
||||
isLastStep: ref(true),
|
||||
goToStep: mockGoToStep,
|
||||
@@ -101,7 +113,7 @@ vi.mock('@/platform/workflow/sharing/services/workflowShareService', () => ({
|
||||
vi.mock('@/platform/workflow/core/services/workflowService', () => ({
|
||||
useWorkflowService: () => ({
|
||||
renameWorkflow: mockRenameWorkflow,
|
||||
saveWorkflow: vi.fn()
|
||||
saveWorkflow: mockWorkflowServiceSaveWorkflow
|
||||
})
|
||||
}))
|
||||
|
||||
@@ -128,7 +140,7 @@ vi.mock('@/platform/workflow/management/stores/workflowStore', async () => {
|
||||
get activeWorkflow() {
|
||||
return mockWorkflowStore.instance?.activeWorkflow ?? null
|
||||
},
|
||||
saveWorkflow: vi.fn()
|
||||
saveWorkflow: mockWorkflowStoreSaveWorkflow
|
||||
})
|
||||
}
|
||||
})
|
||||
@@ -158,7 +170,15 @@ describe('ComfyHubPublishDialog', () => {
|
||||
mockFetchProfile.mockResolvedValue(null)
|
||||
mockSubmitToComfyHub.mockResolvedValue(undefined)
|
||||
mockRenameWorkflow.mockResolvedValue(undefined)
|
||||
mockWorkflowServiceSaveWorkflow.mockResolvedValue(undefined)
|
||||
mockWorkflowStoreSaveWorkflow.mockResolvedValue(undefined)
|
||||
mockInputFocus.mockClear()
|
||||
mockInputSelect.mockClear()
|
||||
if (mockFormDataHolder.value) mockFormDataHolder.value.name = ''
|
||||
if (mockFormDataHolder.value) mockFormDataHolder.value.exampleImages = []
|
||||
if (mockFormDataRefHolder.value) {
|
||||
mockFormDataRefHolder.value.value = mockFormDataHolder.value
|
||||
}
|
||||
mockGetCachedPrefill.mockReturnValue(null)
|
||||
mockGetPublishStatus.mockResolvedValue({
|
||||
isPublished: false,
|
||||
@@ -191,13 +211,15 @@ describe('ComfyHubPublishDialog', () => {
|
||||
},
|
||||
ComfyHubPublishWizardContent: {
|
||||
template:
|
||||
'<div :data-is-publishing="$props.isPublishing"><button data-testid="require-profile" @click="$props.onRequireProfile()" /><button data-testid="gate-complete" @click="$props.onGateComplete()" /><button data-testid="gate-close" @click="$props.onGateClose()" /><button data-testid="publish" @click="$props.onPublish()" /></div>',
|
||||
'<div data-testid="publish-wizard-content" :data-is-publishing="$props.isPublishing" :data-is-update="$props.isUpdate"><button data-testid="require-profile" @click="$props.onRequireProfile()" /><button data-testid="gate-complete" @click="$props.onGateComplete()" /><button data-testid="gate-close" @click="$props.onGateClose()" /><button data-testid="publish" @click="$props.onPublish()" /><button data-testid="patch-form" @click="$props.onUpdateFormData({ description: \'patched\' })" /></div>',
|
||||
props: [
|
||||
'currentStep',
|
||||
'formData',
|
||||
'isFirstStep',
|
||||
'isLastStep',
|
||||
'isPublishing',
|
||||
'isUpdate',
|
||||
'onUpdateFormData',
|
||||
'onGoNext',
|
||||
'onGoBack',
|
||||
'onPublish',
|
||||
@@ -205,7 +227,44 @@ describe('ComfyHubPublishDialog', () => {
|
||||
'onGateComplete',
|
||||
'onGateClose'
|
||||
]
|
||||
}
|
||||
},
|
||||
Button: {
|
||||
template:
|
||||
'<button data-testid="save-workflow" :data-loading="loading" @click="$emit(\'click\')"><slot /></button>',
|
||||
props: ['loading'],
|
||||
emits: ['click']
|
||||
},
|
||||
Input: defineComponent({
|
||||
props: {
|
||||
modelValue: {
|
||||
type: String,
|
||||
default: ''
|
||||
},
|
||||
disabled: {
|
||||
type: Boolean,
|
||||
default: false
|
||||
}
|
||||
},
|
||||
emits: ['update:modelValue'],
|
||||
setup(props, { emit, expose }) {
|
||||
expose({
|
||||
focus: mockInputFocus,
|
||||
select: mockInputSelect
|
||||
})
|
||||
return () =>
|
||||
h('input', {
|
||||
'data-testid': 'workflow-name',
|
||||
disabled: props.disabled,
|
||||
value: props.modelValue,
|
||||
onInput: (event: Event) => {
|
||||
emit(
|
||||
'update:modelValue',
|
||||
(event.target as HTMLInputElement).value
|
||||
)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -484,6 +543,215 @@ describe('ComfyHubPublishDialog', () => {
|
||||
expect(mockGetPublishStatus).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('shows the save prompt and focuses the temporary workflow name', async () => {
|
||||
setActiveWorkflow({
|
||||
path: null,
|
||||
filename: 'draft.json',
|
||||
directory: '',
|
||||
isTemporary: true,
|
||||
isModified: false
|
||||
})
|
||||
|
||||
renderComponent()
|
||||
await flushPromises()
|
||||
|
||||
expect(screen.getByTestId('publish-save-prompt')).toBeInTheDocument()
|
||||
expect(screen.queryByTestId('publish-nav')).not.toBeInTheDocument()
|
||||
expect(mockInputFocus).toHaveBeenCalledOnce()
|
||||
expect(mockInputSelect).toHaveBeenCalledOnce()
|
||||
})
|
||||
|
||||
it('renames and saves a temporary workflow before showing the wizard', async () => {
|
||||
const workflow = {
|
||||
path: null,
|
||||
filename: 'draft.json',
|
||||
directory: '',
|
||||
isTemporary: true,
|
||||
isModified: false
|
||||
}
|
||||
setActiveWorkflow(workflow)
|
||||
|
||||
renderComponent()
|
||||
await flushPromises()
|
||||
await userEvent.clear(screen.getByTestId('workflow-name'))
|
||||
await userEvent.type(screen.getByTestId('workflow-name'), 'Saved Name.json')
|
||||
await userEvent.click(screen.getByTestId('save-workflow'))
|
||||
await flushPromises()
|
||||
|
||||
expect(mockRenameWorkflow).toHaveBeenCalledWith(workflow, 'Saved Name.json')
|
||||
expect(mockWorkflowStoreSaveWorkflow).toHaveBeenCalledWith(workflow)
|
||||
})
|
||||
|
||||
it('does not save a temporary workflow with a blank name', async () => {
|
||||
setActiveWorkflow({
|
||||
path: null,
|
||||
filename: 'draft.json',
|
||||
directory: '',
|
||||
isTemporary: true,
|
||||
isModified: false
|
||||
})
|
||||
|
||||
renderComponent()
|
||||
await flushPromises()
|
||||
await userEvent.clear(screen.getByTestId('workflow-name'))
|
||||
await userEvent.click(screen.getByTestId('save-workflow'))
|
||||
await flushPromises()
|
||||
|
||||
expect(mockRenameWorkflow).not.toHaveBeenCalled()
|
||||
expect(mockWorkflowStoreSaveWorkflow).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('saves a modified workflow without renaming it', async () => {
|
||||
const workflow = {
|
||||
path: 'workflows/test.json',
|
||||
filename: 'test.json',
|
||||
directory: 'workflows',
|
||||
isTemporary: false,
|
||||
isModified: true
|
||||
}
|
||||
setActiveWorkflow(workflow)
|
||||
|
||||
renderComponent()
|
||||
await flushPromises()
|
||||
await userEvent.click(screen.getByTestId('save-workflow'))
|
||||
await flushPromises()
|
||||
|
||||
expect(mockWorkflowServiceSaveWorkflow).toHaveBeenCalledWith(workflow)
|
||||
expect(mockRenameWorkflow).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('shows a save error toast when saving fails', async () => {
|
||||
setActiveWorkflow({
|
||||
path: 'workflows/test.json',
|
||||
filename: 'test.json',
|
||||
directory: 'workflows',
|
||||
isTemporary: false,
|
||||
isModified: true
|
||||
})
|
||||
mockWorkflowServiceSaveWorkflow.mockRejectedValueOnce(
|
||||
new Error('save failed')
|
||||
)
|
||||
|
||||
renderComponent()
|
||||
await flushPromises()
|
||||
await userEvent.click(screen.getByTestId('save-workflow'))
|
||||
await flushPromises()
|
||||
|
||||
expect(mockToastAdd).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ severity: 'error' })
|
||||
)
|
||||
})
|
||||
|
||||
it('does not save when there is no active workflow', async () => {
|
||||
setActiveWorkflow(null)
|
||||
|
||||
renderComponent()
|
||||
await flushPromises()
|
||||
await userEvent.click(screen.getByTestId('save-workflow'))
|
||||
await flushPromises()
|
||||
|
||||
expect(mockWorkflowServiceSaveWorkflow).not.toHaveBeenCalled()
|
||||
expect(mockWorkflowStoreSaveWorkflow).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('uses cached prefill when an already published workflow has no server prefill', async () => {
|
||||
const cached = { description: 'cached published data' }
|
||||
mockGetCachedPrefill.mockReturnValue(cached)
|
||||
mockGetPublishStatus.mockResolvedValue({
|
||||
isPublished: true,
|
||||
shareId: 'abc123',
|
||||
shareUrl: 'http://localhost/?share=abc123',
|
||||
publishedAt: new Date(),
|
||||
prefill: null
|
||||
})
|
||||
|
||||
renderComponent()
|
||||
await flushPromises()
|
||||
|
||||
expect(mockApplyPrefill).toHaveBeenCalledWith(cached)
|
||||
expect(screen.getByTestId('publish-wizard-content')).toHaveAttribute(
|
||||
'data-is-update',
|
||||
'true'
|
||||
)
|
||||
})
|
||||
|
||||
it('ignores stale prefill errors after the workflow path changes', async () => {
|
||||
let rejectStale: (error: unknown) => void = () => {}
|
||||
mockGetPublishStatus.mockImplementation((path: string) => {
|
||||
if (path === 'workflows/test.json') {
|
||||
return new Promise((_resolve, reject) => {
|
||||
rejectStale = reject
|
||||
})
|
||||
}
|
||||
return Promise.resolve({
|
||||
isPublished: false,
|
||||
shareId: null,
|
||||
shareUrl: null,
|
||||
publishedAt: null,
|
||||
prefill: null
|
||||
})
|
||||
})
|
||||
|
||||
renderComponent()
|
||||
await nextTick()
|
||||
setActiveWorkflow({
|
||||
path: 'workflows/renamed.json',
|
||||
filename: 'renamed.json',
|
||||
directory: 'workflows',
|
||||
isTemporary: false,
|
||||
isModified: false
|
||||
})
|
||||
await nextTick()
|
||||
await flushPromises()
|
||||
|
||||
rejectStale(new Error('stale failure'))
|
||||
await flushPromises()
|
||||
|
||||
expect(mockGetCachedPrefill).not.toHaveBeenCalledWith('workflows/test.json')
|
||||
})
|
||||
|
||||
it('updates form data patches from wizard content', async () => {
|
||||
renderComponent()
|
||||
await flushPromises()
|
||||
await userEvent.click(screen.getByTestId('patch-form'))
|
||||
|
||||
expect(mockFormDataRefHolder.value?.value).toMatchObject({
|
||||
description: 'patched'
|
||||
})
|
||||
})
|
||||
|
||||
it('revokes uploaded example image object URLs on unmount', async () => {
|
||||
const revokeObjectURL = vi.fn()
|
||||
vi.stubGlobal('URL', {
|
||||
...URL,
|
||||
revokeObjectURL
|
||||
})
|
||||
const file = new File(['image'], 'example.png', { type: 'image/png' })
|
||||
if (mockFormDataHolder.value) {
|
||||
mockFormDataHolder.value.exampleImages = [
|
||||
{
|
||||
id: 'uploaded',
|
||||
url: 'blob:uploaded',
|
||||
file
|
||||
},
|
||||
{
|
||||
id: 'remote',
|
||||
url: 'https://example.com/remote.png',
|
||||
file: null
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
const { unmount } = renderComponent()
|
||||
await flushPromises()
|
||||
unmount()
|
||||
|
||||
expect(revokeObjectURL).toHaveBeenCalledWith('blob:uploaded')
|
||||
expect(revokeObjectURL).not.toHaveBeenCalledWith(
|
||||
'https://example.com/remote.png'
|
||||
)
|
||||
})
|
||||
|
||||
it('ignores a stale prefill response after the workflow path changes', async () => {
|
||||
const stalePrefill = { description: 'stale' }
|
||||
let resolveStale: (value: unknown) => void = () => {}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import { render, screen } from '@testing-library/vue'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import { describe, expect, it } from 'vitest'
|
||||
|
||||
import ComfyHubPublishFooter from './ComfyHubPublishFooter.vue'
|
||||
@@ -10,7 +11,10 @@ function renderFooter(props: Record<string, unknown> = {}) {
|
||||
mocks: { $t: (key: string) => key },
|
||||
stubs: {
|
||||
Button: {
|
||||
template: '<button><slot /></button>'
|
||||
props: ['disabled', 'loading'],
|
||||
emits: ['click'],
|
||||
template:
|
||||
'<button :disabled="disabled" @click="$emit(\'click\')"><slot /></button>'
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -27,4 +31,28 @@ describe('ComfyHubPublishFooter', () => {
|
||||
renderFooter({ isUpdate: true })
|
||||
expect(screen.getByText('comfyHubPublish.updateButton')).toBeTruthy()
|
||||
})
|
||||
|
||||
it('shows only the next action on the first non-final step', () => {
|
||||
renderFooter({ isFirstStep: true, isLastStep: false })
|
||||
|
||||
expect(screen.queryByText('comfyHubPublish.back')).toBeNull()
|
||||
expect(screen.getByText('comfyHubPublish.next')).toBeTruthy()
|
||||
})
|
||||
|
||||
it('emits back and next from middle steps', async () => {
|
||||
const user = userEvent.setup()
|
||||
const { emitted } = renderFooter({ isFirstStep: false, isLastStep: false })
|
||||
|
||||
await user.click(screen.getByText('comfyHubPublish.back'))
|
||||
await user.click(screen.getByText('comfyHubPublish.next'))
|
||||
|
||||
expect(emitted().back).toHaveLength(1)
|
||||
expect(emitted().next).toHaveLength(1)
|
||||
})
|
||||
|
||||
it('disables publish while publishing', () => {
|
||||
renderFooter({ isPublishDisabled: false, isPublishing: true })
|
||||
|
||||
expect(screen.getByText('comfyHubPublish.publishButton')).toBeDisabled()
|
||||
})
|
||||
})
|
||||
|
||||
@@ -0,0 +1,88 @@
|
||||
import { render, screen, within } from '@testing-library/vue'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import ComfyHubPublishNav from './ComfyHubPublishNav.vue'
|
||||
|
||||
vi.mock('vue-i18n', () => ({
|
||||
useI18n: () => ({
|
||||
t: (key: string) => key
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('@formkit/auto-animate/vue', () => ({
|
||||
vAutoAnimate: {}
|
||||
}))
|
||||
|
||||
vi.mock('@/components/ui/button/Button.vue', () => ({
|
||||
default: {
|
||||
emits: ['click'],
|
||||
template: `
|
||||
<button type="button" v-bind="$attrs" @click="$emit('click')">
|
||||
<slot />
|
||||
</button>
|
||||
`
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('@/components/common/StatusBadge.vue', () => ({
|
||||
default: {
|
||||
props: ['label'],
|
||||
template:
|
||||
'<span data-testid="step-badge" v-bind="$attrs">{{ label }}</span>'
|
||||
}
|
||||
}))
|
||||
|
||||
describe('ComfyHubPublishNav', () => {
|
||||
it('marks current and completed steps and emits clicked steps', async () => {
|
||||
const user = userEvent.setup()
|
||||
const { emitted } = render(ComfyHubPublishNav, {
|
||||
props: {
|
||||
currentStep: 'examples'
|
||||
},
|
||||
global: {
|
||||
mocks: { $t: (key: string) => key }
|
||||
}
|
||||
})
|
||||
|
||||
const nav = screen.getByTestId('publish-nav')
|
||||
expect(
|
||||
within(nav).getByText('comfyHubPublish.stepExamples')
|
||||
).toBeInTheDocument()
|
||||
expect(screen.getByRole('listitem', { current: 'step' })).toHaveTextContent(
|
||||
'comfyHubPublish.stepExamples'
|
||||
)
|
||||
expect(screen.getByRole('listitem', { current: 'step' })).toHaveAttribute(
|
||||
'aria-current',
|
||||
'step'
|
||||
)
|
||||
expect(screen.getAllByTestId('step-badge')[0]).toHaveClass(
|
||||
'border-base-foreground'
|
||||
)
|
||||
expect(screen.getAllByTestId('step-badge')[2]).toHaveClass(
|
||||
'border-muted-foreground'
|
||||
)
|
||||
|
||||
await user.click(screen.getByText('comfyHubPublish.stepFinish'))
|
||||
|
||||
expect(emitted('stepClick')).toEqual([['finish']])
|
||||
})
|
||||
|
||||
it('renders the profile creation sub-step as part of the finish step', () => {
|
||||
render(ComfyHubPublishNav, {
|
||||
props: {
|
||||
currentStep: 'profileCreation'
|
||||
},
|
||||
global: {
|
||||
mocks: { $t: (key: string) => key }
|
||||
}
|
||||
})
|
||||
|
||||
expect(
|
||||
screen.getByText('comfyHubProfile.profileCreationNav')
|
||||
).toBeInTheDocument()
|
||||
const finishStep = screen.getAllByRole('listitem')[2]
|
||||
expect(finishStep).toHaveClass('bg-secondary-background-hover')
|
||||
expect(finishStep).not.toHaveAttribute('aria-current')
|
||||
})
|
||||
})
|
||||
@@ -120,7 +120,7 @@ describe('ComfyHubPublishWizardContent', () => {
|
||||
props: ['onProfileCreated', 'onClose', 'showCloseButton']
|
||||
},
|
||||
Skeleton: {
|
||||
template: '<div class="skeleton" />'
|
||||
template: '<div data-testid="skeleton" class="skeleton" />'
|
||||
},
|
||||
ComfyHubDescribeStep: {
|
||||
template: '<div data-testid="describe-step" />'
|
||||
@@ -333,4 +333,50 @@ describe('ComfyHubPublishWizardContent', () => {
|
||||
expect(screen.getByTestId('publish-footer')).toBeTruthy()
|
||||
})
|
||||
})
|
||||
|
||||
describe('step rendering and footer routing', () => {
|
||||
it('renders the describe step', () => {
|
||||
renderComponent({ currentStep: 'describe', isFirstStep: true })
|
||||
|
||||
expect(screen.getByTestId('describe-step')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('publish-footer')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('renders thumbnail and examples content on the examples step', () => {
|
||||
renderComponent({ currentStep: 'examples' })
|
||||
|
||||
expect(screen.getByTestId('thumbnail-step')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('examples-step')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('renders a loading state while profile data is resolving', () => {
|
||||
mockHasProfile.value = null
|
||||
|
||||
renderComponent({ currentStep: 'finish' })
|
||||
|
||||
expect(screen.getAllByTestId('skeleton')).toHaveLength(2)
|
||||
expect(screen.queryByTestId('finish-step')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('renders profile prompt when the finish step lacks a profile', async () => {
|
||||
mockHasProfile.value = false
|
||||
mockProfile.value = null
|
||||
|
||||
renderComponent({ currentStep: 'finish' })
|
||||
await userEvent.click(screen.getByTestId('request-profile'))
|
||||
|
||||
expect(screen.getByTestId('profile-prompt')).toBeInTheDocument()
|
||||
expect(onRequireProfile).toHaveBeenCalledOnce()
|
||||
})
|
||||
|
||||
it('routes footer next and back events', async () => {
|
||||
renderComponent({ currentStep: 'describe', isFirstStep: true })
|
||||
|
||||
await userEvent.click(screen.getByTestId('next-btn'))
|
||||
await userEvent.click(screen.getByTestId('back-btn'))
|
||||
|
||||
expect(onGoNext).toHaveBeenCalledOnce()
|
||||
expect(onGoBack).toHaveBeenCalledOnce()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,6 +1,18 @@
|
||||
import { render, screen } from '@testing-library/vue'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
import { createPinia, setActivePinia } from 'pinia'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import type * as VueUse from '@vueuse/core'
|
||||
|
||||
type TestDropZoneOptions = {
|
||||
dataTypes?: (types: readonly string[]) => boolean
|
||||
onDrop?: (files: File[] | null | undefined) => void
|
||||
}
|
||||
|
||||
const vueUseMocks = vi.hoisted(() => ({
|
||||
dropZoneOptions: [] as TestDropZoneOptions[]
|
||||
}))
|
||||
|
||||
vi.mock('vue-i18n', async (importOriginal) => {
|
||||
const actual = await importOriginal()
|
||||
@@ -10,10 +22,64 @@ vi.mock('vue-i18n', async (importOriginal) => {
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock('@vueuse/core', async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof VueUse>()
|
||||
const { ref } = await import('vue')
|
||||
|
||||
return {
|
||||
...actual,
|
||||
useDropZone: vi.fn((_target: unknown, options: TestDropZoneOptions) => {
|
||||
vueUseMocks.dropZoneOptions.push(options)
|
||||
return { isOverDropZone: ref(false) }
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
import type { ThumbnailType } from '@/platform/workflow/sharing/types/comfyHubTypes'
|
||||
import {
|
||||
MAX_IMAGE_SIZE_MB,
|
||||
MAX_VIDEO_SIZE_MB
|
||||
} from '@/platform/workflow/sharing/utils/validateFileSize'
|
||||
|
||||
import ComfyHubThumbnailStep from './ComfyHubThumbnailStep.vue'
|
||||
|
||||
function createFile(name: string, type: string, size = 7): File {
|
||||
const file = new File(['content'], name, { type })
|
||||
Object.defineProperty(file, 'size', {
|
||||
configurable: true,
|
||||
value: size
|
||||
})
|
||||
return file
|
||||
}
|
||||
|
||||
function getDropZoneOptions(index: number): TestDropZoneOptions {
|
||||
const options = vueUseMocks.dropZoneOptions[index]
|
||||
if (!options) {
|
||||
throw new Error(`Missing drop zone options at index ${index}`)
|
||||
}
|
||||
return options
|
||||
}
|
||||
|
||||
function getDropDataTypes(
|
||||
index: number
|
||||
): (types: readonly string[]) => boolean {
|
||||
const dataTypes = getDropZoneOptions(index).dataTypes
|
||||
if (!dataTypes) {
|
||||
throw new Error(`Missing dataTypes handler at index ${index}`)
|
||||
}
|
||||
return dataTypes
|
||||
}
|
||||
|
||||
function getDropHandler(
|
||||
index: number
|
||||
): (files: File[] | null | undefined) => void {
|
||||
const onDrop = getDropZoneOptions(index).onDrop
|
||||
if (!onDrop) {
|
||||
throw new Error(`Missing drop handler at index ${index}`)
|
||||
}
|
||||
return onDrop
|
||||
}
|
||||
|
||||
function renderStep(
|
||||
props: Record<string, unknown> = {},
|
||||
callbacks: Record<string, ReturnType<typeof vi.fn>> = {}
|
||||
@@ -25,7 +91,7 @@ function renderStep(
|
||||
stubs: {
|
||||
ToggleGroup: {
|
||||
template:
|
||||
'<div><button data-testid="type-image" @click="$emit(\'update:modelValue\', \'image\')" /><button data-testid="type-video" @click="$emit(\'update:modelValue\', \'video\')" /><button data-testid="type-comparison" @click="$emit(\'update:modelValue\', \'imageComparison\')" /><slot /></div>'
|
||||
'<div><button data-testid="type-image" @click="$emit(\'update:modelValue\', \'image\')" /><button data-testid="type-video" @click="$emit(\'update:modelValue\', \'video\')" /><button data-testid="type-comparison" @click="$emit(\'update:modelValue\', \'imageComparison\')" /><button data-testid="type-invalid" @click="$emit(\'update:modelValue\', \'audio\')" /><slot /></div>'
|
||||
},
|
||||
ToggleGroupItem: { template: '<div><slot /></div>', props: ['value'] },
|
||||
Button: {
|
||||
@@ -37,7 +103,33 @@ function renderStep(
|
||||
})
|
||||
}
|
||||
|
||||
function getFileInput(name: string | RegExp) {
|
||||
const labelContent = [
|
||||
...screen.queryAllByText(name),
|
||||
...screen.queryAllByAltText(name)
|
||||
// eslint-disable-next-line testing-library/no-node-access
|
||||
].find((element) => element.closest('label'))
|
||||
// eslint-disable-next-line testing-library/no-node-access
|
||||
const label = labelContent?.closest('label')
|
||||
// eslint-disable-next-line testing-library/no-node-access
|
||||
const input = label?.querySelector('input[type="file"]')
|
||||
if (!(input instanceof HTMLInputElement)) {
|
||||
throw new Error(`Missing file input for ${String(name)}`)
|
||||
}
|
||||
return input
|
||||
}
|
||||
|
||||
describe('ComfyHubThumbnailStep', () => {
|
||||
beforeEach(() => {
|
||||
setActivePinia(createPinia())
|
||||
vi.clearAllMocks()
|
||||
vueUseMocks.dropZoneOptions.length = 0
|
||||
Object.defineProperty(URL, 'createObjectURL', {
|
||||
configurable: true,
|
||||
value: vi.fn((file: File) => `blob:${file.name}`)
|
||||
})
|
||||
})
|
||||
|
||||
it('shows the existing image thumbnail on the image tab', () => {
|
||||
renderStep({
|
||||
thumbnailType: 'image',
|
||||
@@ -51,6 +143,19 @@ describe('ComfyHubThumbnailStep', () => {
|
||||
)
|
||||
})
|
||||
|
||||
it('shows the upload prompt when the restored image URL is empty', () => {
|
||||
renderStep({
|
||||
thumbnailType: 'image',
|
||||
thumbnailUrl: null,
|
||||
existingThumbnailType: 'image'
|
||||
})
|
||||
|
||||
expect(screen.queryByRole('img')).toBeNull()
|
||||
expect(
|
||||
screen.getByText('comfyHubPublish.uploadPromptClickToBrowse')
|
||||
).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('does not show an existing image thumbnail on the video tab', () => {
|
||||
renderStep({
|
||||
thumbnailType: 'video',
|
||||
@@ -180,21 +285,35 @@ describe('ComfyHubThumbnailStep', () => {
|
||||
})
|
||||
|
||||
it('restores both comparison images on the comparison tab', () => {
|
||||
const { container } = renderStep({
|
||||
renderStep({
|
||||
thumbnailType: 'imageComparison',
|
||||
thumbnailUrl: 'https://cdn.example.com/before.png',
|
||||
comparisonAfterUrl: 'https://cdn.example.com/after.png',
|
||||
existingThumbnailType: 'imageComparison'
|
||||
})
|
||||
|
||||
// eslint-disable-next-line testing-library/no-node-access, testing-library/no-container
|
||||
const srcs = Array.from(container.querySelectorAll('img')).map((el) =>
|
||||
el.getAttribute('src')
|
||||
)
|
||||
const srcs = screen.getAllByRole('img').map((el) => el.getAttribute('src'))
|
||||
expect(srcs).toContain('https://cdn.example.com/before.png')
|
||||
expect(srcs).toContain('https://cdn.example.com/after.png')
|
||||
})
|
||||
|
||||
it('shows comparison prompts when restored comparison URLs are empty', () => {
|
||||
renderStep({
|
||||
thumbnailType: 'imageComparison',
|
||||
thumbnailUrl: null,
|
||||
comparisonAfterUrl: null,
|
||||
existingThumbnailType: 'imageComparison'
|
||||
})
|
||||
|
||||
expect(
|
||||
screen.getByText('comfyHubPublish.uploadComparisonBeforePrompt')
|
||||
).toBeInTheDocument()
|
||||
expect(
|
||||
screen.getByText('comfyHubPublish.uploadComparisonAfterPrompt')
|
||||
).toBeInTheDocument()
|
||||
expect(screen.queryByTestId('clear-button')).toBeNull()
|
||||
})
|
||||
|
||||
it('clears a restored image thumbnail when removed', async () => {
|
||||
const user = userEvent.setup()
|
||||
const onUpdateThumbnailFile = vi.fn()
|
||||
@@ -245,4 +364,367 @@ describe('ComfyHubThumbnailStep', () => {
|
||||
expect(onUpdateComparisonBeforeFile).toHaveBeenCalledWith(null)
|
||||
expect(onUpdateComparisonAfterFile).toHaveBeenCalledWith(null)
|
||||
})
|
||||
|
||||
it('does not show a clear button when the active thumbnail mode is empty', () => {
|
||||
renderStep()
|
||||
|
||||
expect(screen.queryByTestId('clear-button')).toBeNull()
|
||||
})
|
||||
|
||||
it('shows video-mode upload copy', () => {
|
||||
renderStep({
|
||||
thumbnailType: 'video'
|
||||
})
|
||||
|
||||
expect(screen.getByText('comfyHubPublish.uploadVideo')).toBeInTheDocument()
|
||||
expect(
|
||||
screen.getByText('comfyHubPublish.uploadPromptDropVideo')
|
||||
).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('shows comparison upload prompts before images are selected', () => {
|
||||
renderStep({
|
||||
thumbnailType: 'imageComparison'
|
||||
})
|
||||
|
||||
expect(
|
||||
screen.getByText('comfyHubPublish.uploadComparison')
|
||||
).toBeInTheDocument()
|
||||
expect(
|
||||
screen.getByText('comfyHubPublish.uploadComparisonBeforePrompt')
|
||||
).toBeInTheDocument()
|
||||
expect(
|
||||
screen.getByText('comfyHubPublish.uploadComparisonAfterPrompt')
|
||||
).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('renders selected image files from object URLs', () => {
|
||||
renderStep({
|
||||
thumbnailType: 'image',
|
||||
thumbnailFile: createFile('selected.png', 'image/png')
|
||||
})
|
||||
|
||||
expect(screen.getByRole('img')).toHaveAttribute('src', 'blob:selected.png')
|
||||
})
|
||||
|
||||
it('renders selected video files as a video preview', () => {
|
||||
renderStep({
|
||||
thumbnailType: 'video',
|
||||
thumbnailFile: createFile('selected.mp4', 'video/mp4')
|
||||
})
|
||||
|
||||
expect(
|
||||
screen.getByLabelText('comfyHubPublish.videoPreview')
|
||||
).toHaveAttribute('src', 'blob:selected.mp4')
|
||||
expect(screen.queryByRole('img')).toBeNull()
|
||||
})
|
||||
|
||||
it('ignores invalid thumbnail type updates', async () => {
|
||||
const user = userEvent.setup()
|
||||
const onUpdateThumbnailType = vi.fn()
|
||||
const onUpdateThumbnailFile = vi.fn()
|
||||
renderStep(
|
||||
{},
|
||||
{
|
||||
'onUpdate:thumbnailType': onUpdateThumbnailType,
|
||||
'onUpdate:thumbnailFile': onUpdateThumbnailFile
|
||||
}
|
||||
)
|
||||
|
||||
await user.click(screen.getByTestId('type-invalid'))
|
||||
|
||||
expect(onUpdateThumbnailType).not.toHaveBeenCalled()
|
||||
expect(onUpdateThumbnailFile).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('selects an image thumbnail file and clears the restored URL', async () => {
|
||||
const user = userEvent.setup()
|
||||
const onUpdateThumbnailFile = vi.fn()
|
||||
const onUpdateThumbnailUrl = vi.fn()
|
||||
renderStep(
|
||||
{
|
||||
thumbnailType: 'image',
|
||||
thumbnailUrl: 'https://cdn.example.com/thumb.png',
|
||||
existingThumbnailType: 'image'
|
||||
},
|
||||
{
|
||||
'onUpdate:thumbnailFile': onUpdateThumbnailFile,
|
||||
'onUpdate:thumbnailUrl': onUpdateThumbnailUrl
|
||||
}
|
||||
)
|
||||
const file = createFile('thumb.png', 'image/png')
|
||||
|
||||
await user.upload(getFileInput('comfyHubPublish.thumbnailPreview'), file)
|
||||
|
||||
expect(onUpdateThumbnailFile).toHaveBeenCalledWith(file)
|
||||
expect(onUpdateThumbnailUrl).toHaveBeenCalledWith(null)
|
||||
})
|
||||
|
||||
it('ignores an empty image thumbnail file selection', () => {
|
||||
const onUpdateThumbnailFile = vi.fn()
|
||||
const onUpdateThumbnailUrl = vi.fn()
|
||||
renderStep(
|
||||
{},
|
||||
{
|
||||
'onUpdate:thumbnailFile': onUpdateThumbnailFile,
|
||||
'onUpdate:thumbnailUrl': onUpdateThumbnailUrl
|
||||
}
|
||||
)
|
||||
|
||||
getFileInput('comfyHubPublish.uploadPromptClickToBrowse').dispatchEvent(
|
||||
new Event('change', { bubbles: true })
|
||||
)
|
||||
|
||||
expect(onUpdateThumbnailFile).not.toHaveBeenCalled()
|
||||
expect(onUpdateThumbnailUrl).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('selects a video thumbnail file in video mode', async () => {
|
||||
const user = userEvent.setup()
|
||||
const onUpdateThumbnailFile = vi.fn()
|
||||
const onUpdateThumbnailUrl = vi.fn()
|
||||
renderStep(
|
||||
{
|
||||
thumbnailType: 'video'
|
||||
},
|
||||
{
|
||||
'onUpdate:thumbnailFile': onUpdateThumbnailFile,
|
||||
'onUpdate:thumbnailUrl': onUpdateThumbnailUrl
|
||||
}
|
||||
)
|
||||
const file = createFile('clip.mp4', 'video/mp4')
|
||||
|
||||
await user.upload(
|
||||
getFileInput('comfyHubPublish.uploadPromptClickToBrowse'),
|
||||
file
|
||||
)
|
||||
|
||||
expect(onUpdateThumbnailFile).toHaveBeenCalledWith(file)
|
||||
expect(onUpdateThumbnailUrl).toHaveBeenCalledWith(null)
|
||||
})
|
||||
|
||||
it('accepts only image drops for image thumbnails', () => {
|
||||
const onUpdateThumbnailFile = vi.fn()
|
||||
const onUpdateThumbnailUrl = vi.fn()
|
||||
renderStep(
|
||||
{},
|
||||
{
|
||||
'onUpdate:thumbnailFile': onUpdateThumbnailFile,
|
||||
'onUpdate:thumbnailUrl': onUpdateThumbnailUrl
|
||||
}
|
||||
)
|
||||
const acceptsSingleDrop = getDropDataTypes(0)
|
||||
|
||||
expect(acceptsSingleDrop(['image/png'])).toBe(true)
|
||||
expect(acceptsSingleDrop(['video/mp4'])).toBe(false)
|
||||
|
||||
getDropHandler(0)([createFile('dropped.png', 'image/png')])
|
||||
getDropHandler(0)([])
|
||||
|
||||
expect(onUpdateThumbnailFile).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ name: 'dropped.png' })
|
||||
)
|
||||
expect(onUpdateThumbnailUrl).toHaveBeenCalledWith(null)
|
||||
expect(onUpdateThumbnailFile).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('accepts video-mode drops for videos and animated images', () => {
|
||||
const onUpdateThumbnailFile = vi.fn()
|
||||
const onUpdateThumbnailUrl = vi.fn()
|
||||
renderStep(
|
||||
{
|
||||
thumbnailType: 'video'
|
||||
},
|
||||
{
|
||||
'onUpdate:thumbnailFile': onUpdateThumbnailFile,
|
||||
'onUpdate:thumbnailUrl': onUpdateThumbnailUrl
|
||||
}
|
||||
)
|
||||
const acceptsSingleDrop = getDropDataTypes(0)
|
||||
|
||||
expect(acceptsSingleDrop(['video/mp4'])).toBe(true)
|
||||
expect(acceptsSingleDrop(['image/gif'])).toBe(true)
|
||||
expect(acceptsSingleDrop(['image/webp'])).toBe(true)
|
||||
expect(acceptsSingleDrop(['image/png'])).toBe(false)
|
||||
|
||||
getDropHandler(0)([createFile('clip.mp4', 'video/mp4')])
|
||||
|
||||
expect(onUpdateThumbnailFile).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ name: 'clip.mp4' })
|
||||
)
|
||||
expect(onUpdateThumbnailUrl).toHaveBeenCalledWith(null)
|
||||
})
|
||||
|
||||
it('ignores oversized image thumbnail files', async () => {
|
||||
const user = userEvent.setup()
|
||||
const onUpdateThumbnailFile = vi.fn()
|
||||
const onUpdateThumbnailUrl = vi.fn()
|
||||
renderStep(
|
||||
{},
|
||||
{
|
||||
'onUpdate:thumbnailFile': onUpdateThumbnailFile,
|
||||
'onUpdate:thumbnailUrl': onUpdateThumbnailUrl
|
||||
}
|
||||
)
|
||||
|
||||
await user.upload(
|
||||
getFileInput('comfyHubPublish.uploadPromptClickToBrowse'),
|
||||
createFile(
|
||||
'too-large.png',
|
||||
'image/png',
|
||||
MAX_IMAGE_SIZE_MB * 1024 * 1024 + 1
|
||||
)
|
||||
)
|
||||
|
||||
expect(onUpdateThumbnailFile).not.toHaveBeenCalled()
|
||||
expect(onUpdateThumbnailUrl).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('ignores oversized video thumbnail files', async () => {
|
||||
const user = userEvent.setup()
|
||||
const onUpdateThumbnailFile = vi.fn()
|
||||
const onUpdateThumbnailUrl = vi.fn()
|
||||
renderStep(
|
||||
{
|
||||
thumbnailType: 'video'
|
||||
},
|
||||
{
|
||||
'onUpdate:thumbnailFile': onUpdateThumbnailFile,
|
||||
'onUpdate:thumbnailUrl': onUpdateThumbnailUrl
|
||||
}
|
||||
)
|
||||
|
||||
await user.upload(
|
||||
getFileInput('comfyHubPublish.uploadPromptClickToBrowse'),
|
||||
createFile(
|
||||
'too-large.mp4',
|
||||
'video/mp4',
|
||||
MAX_VIDEO_SIZE_MB * 1024 * 1024 + 1
|
||||
)
|
||||
)
|
||||
|
||||
expect(onUpdateThumbnailFile).not.toHaveBeenCalled()
|
||||
expect(onUpdateThumbnailUrl).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('selects comparison files independently by slot', async () => {
|
||||
const user = userEvent.setup()
|
||||
const onUpdateThumbnailUrl = vi.fn()
|
||||
const onUpdateComparisonAfterUrl = vi.fn()
|
||||
const onUpdateComparisonBeforeFile = vi.fn()
|
||||
const onUpdateComparisonAfterFile = vi.fn()
|
||||
renderStep(
|
||||
{
|
||||
thumbnailType: 'imageComparison',
|
||||
thumbnailUrl: 'https://cdn.example.com/before.png',
|
||||
comparisonAfterUrl: 'https://cdn.example.com/after.png',
|
||||
existingThumbnailType: 'imageComparison'
|
||||
},
|
||||
{
|
||||
'onUpdate:thumbnailUrl': onUpdateThumbnailUrl,
|
||||
'onUpdate:comparisonAfterUrl': onUpdateComparisonAfterUrl,
|
||||
'onUpdate:comparisonBeforeFile': onUpdateComparisonBeforeFile,
|
||||
'onUpdate:comparisonAfterFile': onUpdateComparisonAfterFile
|
||||
}
|
||||
)
|
||||
const before = createFile('before.png', 'image/png')
|
||||
const after = createFile('after.png', 'image/png')
|
||||
|
||||
await user.upload(
|
||||
getFileInput('comfyHubPublish.uploadComparisonBeforePrompt'),
|
||||
before
|
||||
)
|
||||
await user.upload(
|
||||
getFileInput('comfyHubPublish.uploadComparisonAfterPrompt'),
|
||||
after
|
||||
)
|
||||
|
||||
expect(onUpdateComparisonBeforeFile).toHaveBeenCalledWith(before)
|
||||
expect(onUpdateThumbnailUrl).toHaveBeenCalledWith(null)
|
||||
expect(onUpdateComparisonAfterFile).toHaveBeenCalledWith(after)
|
||||
expect(onUpdateComparisonAfterUrl).toHaveBeenCalledWith(null)
|
||||
})
|
||||
|
||||
it('ignores an empty comparison file selection', () => {
|
||||
const onUpdateThumbnailUrl = vi.fn()
|
||||
const onUpdateComparisonBeforeFile = vi.fn()
|
||||
renderStep(
|
||||
{
|
||||
thumbnailType: 'imageComparison'
|
||||
},
|
||||
{
|
||||
'onUpdate:thumbnailUrl': onUpdateThumbnailUrl,
|
||||
'onUpdate:comparisonBeforeFile': onUpdateComparisonBeforeFile
|
||||
}
|
||||
)
|
||||
|
||||
getFileInput('comfyHubPublish.uploadComparisonBeforePrompt').dispatchEvent(
|
||||
new Event('change', { bubbles: true })
|
||||
)
|
||||
|
||||
expect(onUpdateComparisonBeforeFile).not.toHaveBeenCalled()
|
||||
expect(onUpdateThumbnailUrl).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('selects comparison images from drop handlers', () => {
|
||||
const onUpdateThumbnailUrl = vi.fn()
|
||||
const onUpdateComparisonAfterUrl = vi.fn()
|
||||
const onUpdateComparisonBeforeFile = vi.fn()
|
||||
const onUpdateComparisonAfterFile = vi.fn()
|
||||
renderStep(
|
||||
{
|
||||
thumbnailType: 'imageComparison'
|
||||
},
|
||||
{
|
||||
'onUpdate:thumbnailUrl': onUpdateThumbnailUrl,
|
||||
'onUpdate:comparisonAfterUrl': onUpdateComparisonAfterUrl,
|
||||
'onUpdate:comparisonBeforeFile': onUpdateComparisonBeforeFile,
|
||||
'onUpdate:comparisonAfterFile': onUpdateComparisonAfterFile
|
||||
}
|
||||
)
|
||||
|
||||
expect(getDropDataTypes(1)(['image/png'])).toBe(true)
|
||||
expect(getDropDataTypes(1)(['video/mp4'])).toBe(false)
|
||||
|
||||
getDropHandler(1)([createFile('before.png', 'image/png')])
|
||||
getDropHandler(2)([createFile('after.png', 'image/png')])
|
||||
getDropHandler(2)(null)
|
||||
|
||||
expect(onUpdateComparisonBeforeFile).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ name: 'before.png' })
|
||||
)
|
||||
expect(onUpdateThumbnailUrl).toHaveBeenCalledWith(null)
|
||||
expect(onUpdateComparisonAfterFile).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ name: 'after.png' })
|
||||
)
|
||||
expect(onUpdateComparisonAfterUrl).toHaveBeenCalledWith(null)
|
||||
expect(onUpdateComparisonAfterFile).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('ignores oversized comparison files', async () => {
|
||||
const user = userEvent.setup()
|
||||
const onUpdateThumbnailUrl = vi.fn()
|
||||
const onUpdateComparisonBeforeFile = vi.fn()
|
||||
renderStep(
|
||||
{
|
||||
thumbnailType: 'imageComparison'
|
||||
},
|
||||
{
|
||||
'onUpdate:thumbnailUrl': onUpdateThumbnailUrl,
|
||||
'onUpdate:comparisonBeforeFile': onUpdateComparisonBeforeFile
|
||||
}
|
||||
)
|
||||
|
||||
await user.upload(
|
||||
getFileInput('comfyHubPublish.uploadComparisonBeforePrompt'),
|
||||
createFile(
|
||||
'too-large.png',
|
||||
'image/png',
|
||||
MAX_IMAGE_SIZE_MB * 1024 * 1024 + 1
|
||||
)
|
||||
)
|
||||
|
||||
expect(onUpdateComparisonBeforeFile).not.toHaveBeenCalled()
|
||||
expect(onUpdateThumbnailUrl).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
@@ -0,0 +1,309 @@
|
||||
import { fireEvent, render, screen, within } from '@testing-library/vue'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import ReorderableExampleImage from './ReorderableExampleImage.vue'
|
||||
import type { ExampleImage } from '@/platform/workflow/sharing/types/comfyHubTypes'
|
||||
|
||||
type DragPreviewOptions = {
|
||||
nativeSetDragImage: DataTransfer['setDragImage']
|
||||
render: (args: { container: HTMLElement }) => void
|
||||
}
|
||||
|
||||
type DragSource = {
|
||||
data: {
|
||||
imageId?: string
|
||||
instanceId?: symbol
|
||||
type?: string
|
||||
}
|
||||
}
|
||||
|
||||
type DraggableOptions = {
|
||||
getInitialData: () => DragSource['data']
|
||||
onGenerateDragPreview: (args: {
|
||||
nativeSetDragImage: DataTransfer['setDragImage']
|
||||
}) => void
|
||||
onDragStart: () => void
|
||||
onDrop: () => void
|
||||
}
|
||||
|
||||
type DroppableOptions = {
|
||||
getData: () => { imageId: string }
|
||||
canDrop: (args: { source: DragSource }) => boolean
|
||||
onDragEnter: () => void
|
||||
onDragLeave: () => void
|
||||
onDrop: () => void
|
||||
}
|
||||
|
||||
const pragmatic = vi.hoisted(() => {
|
||||
const captured = {
|
||||
draggable: undefined as DraggableOptions | undefined,
|
||||
droppable: undefined as DroppableOptions | undefined,
|
||||
preview: undefined as DragPreviewOptions | undefined
|
||||
}
|
||||
|
||||
return {
|
||||
captured,
|
||||
setCustomNativeDragPreview: vi.fn((options: DragPreviewOptions) => {
|
||||
captured.preview = options
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock(
|
||||
'@atlaskit/pragmatic-drag-and-drop/element/set-custom-native-drag-preview',
|
||||
() => ({
|
||||
setCustomNativeDragPreview: pragmatic.setCustomNativeDragPreview
|
||||
})
|
||||
)
|
||||
|
||||
vi.mock('@/composables/usePragmaticDragAndDrop', () => ({
|
||||
usePragmaticDraggable: vi.fn(
|
||||
(_target: () => HTMLElement, options: DraggableOptions) => {
|
||||
pragmatic.captured.draggable = options
|
||||
}
|
||||
),
|
||||
usePragmaticDroppable: vi.fn(
|
||||
(_target: () => HTMLElement, options: DroppableOptions) => {
|
||||
pragmatic.captured.droppable = options
|
||||
}
|
||||
)
|
||||
}))
|
||||
|
||||
function createFileList(files: File[]): FileList {
|
||||
return Object.assign(files, {
|
||||
item: (index: number) => files[index] ?? null
|
||||
}) as unknown as FileList
|
||||
}
|
||||
|
||||
function renderImage(overrides: Partial<ExampleImage> = {}) {
|
||||
const instanceId = Symbol('grid')
|
||||
const image: ExampleImage = {
|
||||
id: 'image-1',
|
||||
url: 'blob:image-1',
|
||||
file: new File(['image'], 'image.png', { type: 'image/png' }),
|
||||
...overrides
|
||||
}
|
||||
|
||||
const result = render(ReorderableExampleImage, {
|
||||
props: {
|
||||
image,
|
||||
index: 1,
|
||||
total: 3,
|
||||
instanceId
|
||||
},
|
||||
global: {
|
||||
mocks: {
|
||||
$t: (key: string, params?: Record<string, number>) =>
|
||||
params && params.total
|
||||
? `${key}:${params.index}/${params.total}`
|
||||
: params
|
||||
? `${key}:${params.index}`
|
||||
: key
|
||||
},
|
||||
stubs: {
|
||||
Button: {
|
||||
template:
|
||||
'<button data-testid="remove-button" @click="$emit(\'click\')"><slot /></button>',
|
||||
emits: ['click']
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
return { ...result, image, instanceId }
|
||||
}
|
||||
|
||||
function draggableOptions() {
|
||||
const options = pragmatic.captured.draggable
|
||||
if (!options) throw new Error('draggable options were not registered')
|
||||
return options
|
||||
}
|
||||
|
||||
function droppableOptions() {
|
||||
const options = pragmatic.captured.droppable
|
||||
if (!options) throw new Error('droppable options were not registered')
|
||||
return options
|
||||
}
|
||||
|
||||
describe('ReorderableExampleImage', () => {
|
||||
beforeEach(() => {
|
||||
pragmatic.captured.draggable = undefined
|
||||
pragmatic.captured.droppable = undefined
|
||||
pragmatic.captured.preview = undefined
|
||||
pragmatic.setCustomNativeDragPreview.mockClear()
|
||||
})
|
||||
|
||||
it('labels the image position for assistive technology', () => {
|
||||
renderImage()
|
||||
|
||||
expect(
|
||||
screen.getByRole('listitem', {
|
||||
name: 'comfyHubPublish.exampleImagePosition:2/3'
|
||||
})
|
||||
).toBeInTheDocument()
|
||||
expect(
|
||||
screen.getByAltText('comfyHubPublish.exampleImage:2')
|
||||
).toHaveAttribute('src', 'blob:image-1')
|
||||
})
|
||||
|
||||
it('emits move for shifted arrow keys', async () => {
|
||||
const { emitted } = renderImage()
|
||||
|
||||
screen.getByRole('listitem').focus()
|
||||
await userEvent.keyboard('{Shift>}{ArrowRight}{/Shift}')
|
||||
|
||||
expect(emitted('move')).toEqual([['image-1', 1]])
|
||||
})
|
||||
|
||||
it('focuses siblings for unshifted arrow keys', async () => {
|
||||
const focus = vi.spyOn(HTMLElement.prototype, 'focus')
|
||||
renderImage()
|
||||
const sibling = document.createElement('button')
|
||||
screen.getByRole('listitem').after(sibling)
|
||||
|
||||
screen.getByRole('listitem').focus()
|
||||
await userEvent.keyboard('{ArrowRight}')
|
||||
|
||||
expect(focus).toHaveBeenCalledWith({ focusVisible: true })
|
||||
expect(sibling).toHaveFocus()
|
||||
focus.mockRestore()
|
||||
})
|
||||
|
||||
it('emits remove from keyboard and focuses the next sibling', async () => {
|
||||
const focus = vi.spyOn(HTMLElement.prototype, 'focus')
|
||||
const { emitted } = renderImage()
|
||||
const sibling = document.createElement('button')
|
||||
screen.getByRole('listitem').after(sibling)
|
||||
|
||||
screen.getByRole('listitem').focus()
|
||||
await userEvent.keyboard('{Delete}')
|
||||
|
||||
expect(emitted('remove')).toEqual([['image-1']])
|
||||
expect(sibling).toHaveFocus()
|
||||
focus.mockRestore()
|
||||
})
|
||||
|
||||
it('emits remove from the remove button', async () => {
|
||||
const { emitted } = renderImage()
|
||||
|
||||
await userEvent.click(screen.getByTestId('remove-button'))
|
||||
|
||||
expect(emitted('remove')).toEqual([['image-1']])
|
||||
})
|
||||
|
||||
it('emits inserted files from file drops', async () => {
|
||||
const { emitted } = renderImage()
|
||||
const files = createFileList([
|
||||
new File(['image'], 'dropped.png', { type: 'image/png' })
|
||||
])
|
||||
|
||||
await fireEvent.drop(screen.getByRole('listitem'), {
|
||||
dataTransfer: { files }
|
||||
})
|
||||
|
||||
expect(emitted('insertFiles')).toEqual([[1, files]])
|
||||
})
|
||||
|
||||
it('ignores drops without files', async () => {
|
||||
const { emitted } = renderImage()
|
||||
|
||||
await fireEvent.drop(screen.getByRole('listitem'), {
|
||||
dataTransfer: { files: createFileList([]) }
|
||||
})
|
||||
|
||||
expect(emitted('insertFiles')).toBeUndefined()
|
||||
})
|
||||
|
||||
it('registers drag data and a cloned preview image', () => {
|
||||
const { instanceId } = renderImage()
|
||||
const draggable = draggableOptions()
|
||||
|
||||
expect(draggable.getInitialData()).toEqual({
|
||||
type: 'example-image',
|
||||
imageId: 'image-1',
|
||||
instanceId
|
||||
})
|
||||
|
||||
const nativeSetDragImage = vi.fn()
|
||||
draggable.onGenerateDragPreview({ nativeSetDragImage })
|
||||
const container = document.createElement('div')
|
||||
pragmatic.captured.preview?.render({ container })
|
||||
|
||||
expect(pragmatic.setCustomNativeDragPreview).toHaveBeenCalledWith({
|
||||
nativeSetDragImage,
|
||||
render: expect.any(Function)
|
||||
})
|
||||
expect(within(container).getByRole('img')).toHaveAttribute(
|
||||
'src',
|
||||
'blob:image-1'
|
||||
)
|
||||
})
|
||||
|
||||
it('accepts drops from other images in the same grid instance only', () => {
|
||||
const { instanceId } = renderImage()
|
||||
const droppable = droppableOptions()
|
||||
|
||||
expect(droppable.getData()).toEqual({ imageId: 'image-1' })
|
||||
expect(
|
||||
droppable.canDrop({
|
||||
source: {
|
||||
data: {
|
||||
type: 'example-image',
|
||||
imageId: 'image-2',
|
||||
instanceId
|
||||
}
|
||||
}
|
||||
})
|
||||
).toBe(true)
|
||||
expect(
|
||||
droppable.canDrop({
|
||||
source: {
|
||||
data: {
|
||||
type: 'example-image',
|
||||
imageId: 'image-1',
|
||||
instanceId
|
||||
}
|
||||
}
|
||||
})
|
||||
).toBe(false)
|
||||
expect(
|
||||
droppable.canDrop({
|
||||
source: {
|
||||
data: {
|
||||
type: 'example-image',
|
||||
imageId: 'image-2',
|
||||
instanceId: Symbol('other')
|
||||
}
|
||||
}
|
||||
})
|
||||
).toBe(false)
|
||||
expect(
|
||||
droppable.canDrop({
|
||||
source: {
|
||||
data: {
|
||||
type: 'other',
|
||||
imageId: 'image-2',
|
||||
instanceId
|
||||
}
|
||||
}
|
||||
})
|
||||
).toBe(false)
|
||||
})
|
||||
|
||||
it('handles drag lifecycle callbacks without changing emitted actions', () => {
|
||||
const { emitted } = renderImage()
|
||||
const draggable = draggableOptions()
|
||||
const droppable = droppableOptions()
|
||||
|
||||
draggable.onDragStart()
|
||||
draggable.onDrop()
|
||||
droppable.onDragEnter()
|
||||
droppable.onDragLeave()
|
||||
droppable.onDrop()
|
||||
|
||||
expect(emitted('remove')).toBeUndefined()
|
||||
expect(emitted('move')).toBeUndefined()
|
||||
expect(emitted('insertFiles')).toBeUndefined()
|
||||
})
|
||||
})
|
||||
@@ -150,6 +150,28 @@ describe('useComfyHubPublishSubmission', () => {
|
||||
)
|
||||
})
|
||||
|
||||
it('uses octet-stream content type when file type is missing', async () => {
|
||||
const thumbnailFile = new File(['thumbnail'], 'thumb.bin')
|
||||
|
||||
const { submitToComfyHub } = useComfyHubPublishSubmission()
|
||||
await submitToComfyHub(
|
||||
createFormData({
|
||||
thumbnailType: 'image',
|
||||
thumbnailFile
|
||||
})
|
||||
)
|
||||
|
||||
expect(mockRequestAssetUploadUrl).toHaveBeenCalledWith({
|
||||
filename: 'thumb.bin',
|
||||
contentType: 'application/octet-stream'
|
||||
})
|
||||
expect(mockUploadFileToPresignedUrl).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
contentType: 'application/octet-stream'
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('sends the existing thumbnail URL when no new file is attached', async () => {
|
||||
const { submitToComfyHub } = useComfyHubPublishSubmission()
|
||||
await submitToComfyHub(
|
||||
@@ -278,6 +300,27 @@ describe('useComfyHubPublishSubmission', () => {
|
||||
)
|
||||
})
|
||||
|
||||
it('keeps existing example image URLs without uploading them', async () => {
|
||||
const { submitToComfyHub } = useComfyHubPublishSubmission()
|
||||
await submitToComfyHub(
|
||||
createFormData({
|
||||
exampleImages: [
|
||||
{
|
||||
id: 'existing',
|
||||
url: 'https://cdn.example.com/existing.png'
|
||||
}
|
||||
]
|
||||
})
|
||||
)
|
||||
|
||||
expect(mockRequestAssetUploadUrl).not.toHaveBeenCalled()
|
||||
expect(mockPublishWorkflow).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
sampleImageTokensOrUrls: ['https://cdn.example.com/existing.png']
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('builds publish request with workflow filename + asset ids', async () => {
|
||||
const { submitToComfyHub } = useComfyHubPublishSubmission()
|
||||
await submitToComfyHub(createFormData())
|
||||
@@ -294,6 +337,72 @@ describe('useComfyHubPublishSubmission', () => {
|
||||
)
|
||||
})
|
||||
|
||||
it('omits optional publish fields when form values are empty', async () => {
|
||||
const { submitToComfyHub } = useComfyHubPublishSubmission()
|
||||
await submitToComfyHub(
|
||||
createFormData({
|
||||
description: '',
|
||||
tags: [],
|
||||
models: [],
|
||||
customNodes: [],
|
||||
tutorialUrl: '',
|
||||
metadata: {}
|
||||
})
|
||||
)
|
||||
|
||||
expect(mockPublishWorkflow).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
description: undefined,
|
||||
tags: undefined,
|
||||
models: undefined,
|
||||
customNodes: undefined,
|
||||
tutorialUrl: undefined,
|
||||
metadata: undefined
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('passes optional publish fields when form values are present', async () => {
|
||||
const metadata = { license: 'cc-by' }
|
||||
const models = ['model']
|
||||
const customNodes = ['custom-node']
|
||||
|
||||
const { submitToComfyHub } = useComfyHubPublishSubmission()
|
||||
await submitToComfyHub(
|
||||
createFormData({
|
||||
models,
|
||||
customNodes,
|
||||
tutorialUrl: 'https://example.com/tutorial',
|
||||
metadata
|
||||
})
|
||||
)
|
||||
|
||||
expect(mockPublishWorkflow).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
models,
|
||||
customNodes,
|
||||
tutorialUrl: 'https://example.com/tutorial',
|
||||
metadata
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('trims the profile username before publishing', async () => {
|
||||
mockProfile.value = {
|
||||
username: ' builder ',
|
||||
name: 'Builder'
|
||||
}
|
||||
|
||||
const { submitToComfyHub } = useComfyHubPublishSubmission()
|
||||
await submitToComfyHub(createFormData())
|
||||
|
||||
expect(mockPublishWorkflow).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
username: 'builder'
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('throws when profile username is unavailable', async () => {
|
||||
mockProfile.value = null
|
||||
|
||||
@@ -302,4 +411,13 @@ describe('useComfyHubPublishSubmission', () => {
|
||||
'ComfyHub profile is required before publishing'
|
||||
)
|
||||
})
|
||||
|
||||
it('throws when active workflow path is unavailable', async () => {
|
||||
mockWorkflowStore.activeWorkflow.path = ' '
|
||||
|
||||
const { submitToComfyHub } = useComfyHubPublishSubmission()
|
||||
await expect(submitToComfyHub(createFormData())).rejects.toThrow(
|
||||
'No active workflow file available for publishing'
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -12,7 +12,8 @@ vi.mock('@/platform/workflow/management/stores/workflowStore', () => ({
|
||||
})
|
||||
}))
|
||||
|
||||
const { useComfyHubPublishWizard } = await import('./useComfyHubPublishWizard')
|
||||
const { cachePublishPrefill, getCachedPrefill, useComfyHubPublishWizard } =
|
||||
await import('./useComfyHubPublishWizard')
|
||||
|
||||
describe('useComfyHubPublishWizard', () => {
|
||||
beforeEach(() => {
|
||||
@@ -179,6 +180,38 @@ describe('useComfyHubPublishWizard', () => {
|
||||
expect(formData.value.thumbnailUrl).toBeNull()
|
||||
})
|
||||
|
||||
it('preserves edited fields when applying a prefill', () => {
|
||||
const { applyPrefill, formData } = useComfyHubPublishWizard()
|
||||
const afterFile = new File(['x'], 'after.png', { type: 'image/png' })
|
||||
const existingExample = {
|
||||
id: 'existing',
|
||||
url: 'https://cdn.example.com/existing.png'
|
||||
}
|
||||
formData.value = {
|
||||
...formData.value,
|
||||
description: 'Edited description',
|
||||
tags: ['edited'],
|
||||
thumbnailType: 'video',
|
||||
comparisonAfterFile: afterFile,
|
||||
exampleImages: [existingExample]
|
||||
}
|
||||
|
||||
applyPrefill({
|
||||
description: 'Restored description',
|
||||
tags: ['restored'],
|
||||
thumbnailType: 'imageComparison',
|
||||
thumbnailComparisonUrl: 'https://cdn.example.com/after.png',
|
||||
sampleImageUrls: ['https://cdn.example.com/sample.png']
|
||||
})
|
||||
|
||||
expect(formData.value.description).toBe('Edited description')
|
||||
expect(formData.value.tags).toEqual(['edited'])
|
||||
expect(formData.value.thumbnailType).toBe('video')
|
||||
expect(formData.value.comparisonAfterFile?.name).toBe(afterFile.name)
|
||||
expect(formData.value.comparisonAfterUrl).toBeNull()
|
||||
expect(formData.value.exampleImages).toEqual([existingExample])
|
||||
})
|
||||
|
||||
it('restores description, tags, and sample images alongside the thumbnail', () => {
|
||||
const { applyPrefill, formData } = useComfyHubPublishWizard()
|
||||
applyPrefill({
|
||||
@@ -198,4 +231,52 @@ describe('useComfyHubPublishWizard', () => {
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('cachePublishPrefill', () => {
|
||||
it('caches normalized prefill data and skips local blob example URLs', () => {
|
||||
const { formData } = useComfyHubPublishWizard()
|
||||
formData.value = {
|
||||
...formData.value,
|
||||
description: 'Saved description',
|
||||
tags: ['Text to Image', ' text to image ', ''],
|
||||
thumbnailType: 'imageComparison',
|
||||
thumbnailUrl: 'https://cdn.example.com/before.png',
|
||||
comparisonAfterUrl: 'https://cdn.example.com/after.png',
|
||||
exampleImages: [
|
||||
{ id: 'local', url: 'blob:local-image' },
|
||||
{ id: 'remote', url: 'https://cdn.example.com/sample.png' }
|
||||
]
|
||||
}
|
||||
|
||||
cachePublishPrefill('/workflows/full', formData.value)
|
||||
|
||||
expect(getCachedPrefill('/workflows/full')).toEqual({
|
||||
description: 'Saved description',
|
||||
tags: ['text-to-image'],
|
||||
thumbnailType: 'imageComparison',
|
||||
thumbnailUrl: 'https://cdn.example.com/before.png',
|
||||
thumbnailComparisonUrl: 'https://cdn.example.com/after.png',
|
||||
sampleImageUrls: ['https://cdn.example.com/sample.png']
|
||||
})
|
||||
})
|
||||
|
||||
it('caches undefined optional fields for empty form data', () => {
|
||||
const { formData } = useComfyHubPublishWizard()
|
||||
|
||||
cachePublishPrefill('/workflows/empty', formData.value)
|
||||
|
||||
expect(getCachedPrefill('/workflows/empty')).toEqual({
|
||||
description: undefined,
|
||||
tags: undefined,
|
||||
thumbnailType: 'image',
|
||||
thumbnailUrl: undefined,
|
||||
thumbnailComparisonUrl: undefined,
|
||||
sampleImageUrls: []
|
||||
})
|
||||
})
|
||||
|
||||
it('returns null when no cached prefill exists', () => {
|
||||
expect(getCachedPrefill('/workflows/missing')).toBeNull()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
250
src/platform/workflow/sharing/composables/useShareDialog.test.ts
Normal file
250
src/platform/workflow/sharing/composables/useShareDialog.test.ts
Normal file
@@ -0,0 +1,250 @@
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
type ActiveWorkflow = {
|
||||
initialMode: string
|
||||
changeTracker?: {
|
||||
activeState?: {
|
||||
extra?: {
|
||||
linearData?: unknown
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type ConfirmDialogOptions = {
|
||||
footerProps?: {
|
||||
onCancel?: () => void
|
||||
onConfirm?: () => void
|
||||
}
|
||||
}
|
||||
|
||||
type ShareDialogOptions = {
|
||||
props?: {
|
||||
onClose?: () => void
|
||||
}
|
||||
}
|
||||
|
||||
const mocks = vi.hoisted(() => ({
|
||||
closeDialog: vi.fn(),
|
||||
pruneLinearData: vi.fn(),
|
||||
shareFlowContext: {
|
||||
value: {
|
||||
source: 'graph_mode',
|
||||
view_mode: 'default',
|
||||
is_app_mode: false
|
||||
}
|
||||
},
|
||||
showConfirmDialog: vi.fn((..._args: unknown[]) => 'confirm-dialog'),
|
||||
showLayoutDialog: vi.fn(),
|
||||
telemetry: {
|
||||
value: {
|
||||
trackShareFlow: vi.fn()
|
||||
}
|
||||
},
|
||||
workflowStore: {
|
||||
activeWorkflow: null as ActiveWorkflow | null
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock(
|
||||
'@/platform/workflow/sharing/components/ShareWorkflowDialogContent.vue',
|
||||
() => ({
|
||||
default: {
|
||||
name: 'ShareWorkflowDialogContent'
|
||||
}
|
||||
})
|
||||
)
|
||||
|
||||
vi.mock('@/platform/workflow/sharing/composables/useShareFlowContext', () => ({
|
||||
useShareFlowContext: () => mocks.shareFlowContext
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/telemetry', () => ({
|
||||
useTelemetry: () => mocks.telemetry.value
|
||||
}))
|
||||
|
||||
vi.mock('@/services/dialogService', () => ({
|
||||
useDialogService: () => ({
|
||||
showLayoutDialog: mocks.showLayoutDialog
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('@/stores/dialogStore', () => ({
|
||||
useDialogStore: () => ({
|
||||
closeDialog: mocks.closeDialog
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('@/stores/appModeStore', () => ({
|
||||
useAppModeStore: () => ({
|
||||
pruneLinearData: mocks.pruneLinearData
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('../../management/stores/workflowStore', () => ({
|
||||
useWorkflowStore: () => mocks.workflowStore
|
||||
}))
|
||||
|
||||
vi.mock('@/components/dialog/confirm/confirmDialog', () => ({
|
||||
showConfirmDialog: mocks.showConfirmDialog
|
||||
}))
|
||||
|
||||
vi.mock('@/i18n', () => ({
|
||||
t: (key: string) => key
|
||||
}))
|
||||
|
||||
const { useShareDialog } = await import('./useShareDialog')
|
||||
|
||||
describe('useShareDialog', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mocks.workflowStore.activeWorkflow = null
|
||||
mocks.pruneLinearData.mockReturnValue({
|
||||
inputs: [],
|
||||
outputs: ['output-node']
|
||||
})
|
||||
mocks.telemetry.value = {
|
||||
trackShareFlow: vi.fn()
|
||||
}
|
||||
})
|
||||
|
||||
it('closes the global share dialog', () => {
|
||||
const { hide } = useShareDialog()
|
||||
|
||||
hide()
|
||||
|
||||
expect(mocks.closeDialog).toHaveBeenCalledWith({
|
||||
key: 'global-share-workflow'
|
||||
})
|
||||
})
|
||||
|
||||
it('opens the share dialog when there is no active workflow', () => {
|
||||
const { show } = useShareDialog()
|
||||
|
||||
show()
|
||||
|
||||
expect(mocks.telemetry.value.trackShareFlow).toHaveBeenCalledWith({
|
||||
step: 'dialog_opened',
|
||||
source: 'graph_mode',
|
||||
view_mode: 'default',
|
||||
is_app_mode: false
|
||||
})
|
||||
expect(mocks.showLayoutDialog).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
key: 'global-share-workflow',
|
||||
dialogComponentProps: {
|
||||
contentClass: 'sm:max-w-144 rounded-2xl overflow-hidden'
|
||||
}
|
||||
})
|
||||
)
|
||||
|
||||
const options = mocks.showLayoutDialog.mock.calls[0]?.[0] as
|
||||
| ShareDialogOptions
|
||||
| undefined
|
||||
options?.props?.onClose?.()
|
||||
|
||||
expect(mocks.closeDialog).toHaveBeenCalledWith({
|
||||
key: 'global-share-workflow'
|
||||
})
|
||||
})
|
||||
|
||||
it('asks for confirmation before sharing an app workflow without outputs', () => {
|
||||
mocks.workflowStore.activeWorkflow = {
|
||||
initialMode: 'app',
|
||||
changeTracker: {
|
||||
activeState: {
|
||||
extra: {
|
||||
linearData: {
|
||||
nodes: []
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
mocks.pruneLinearData.mockReturnValue({
|
||||
inputs: [],
|
||||
outputs: []
|
||||
})
|
||||
|
||||
const { show } = useShareDialog()
|
||||
|
||||
show()
|
||||
|
||||
expect(mocks.showConfirmDialog).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
headerProps: {
|
||||
title: 'shareNoOutputs.title'
|
||||
},
|
||||
props: {
|
||||
promptText: 'shareNoOutputs.message',
|
||||
preserveNewlines: true
|
||||
}
|
||||
})
|
||||
)
|
||||
expect(mocks.showLayoutDialog).not.toHaveBeenCalled()
|
||||
|
||||
const options = mocks.showConfirmDialog.mock.calls[0]?.[0] as
|
||||
| ConfirmDialogOptions
|
||||
| undefined
|
||||
options?.footerProps?.onConfirm?.()
|
||||
|
||||
expect(mocks.closeDialog).toHaveBeenCalledWith('confirm-dialog')
|
||||
expect(mocks.showLayoutDialog).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('keeps the share dialog closed when the no-output confirmation is cancelled', () => {
|
||||
mocks.workflowStore.activeWorkflow = {
|
||||
initialMode: 'app'
|
||||
}
|
||||
mocks.pruneLinearData.mockReturnValue({
|
||||
inputs: [],
|
||||
outputs: []
|
||||
})
|
||||
|
||||
const { show } = useShareDialog()
|
||||
|
||||
show()
|
||||
|
||||
const options = mocks.showConfirmDialog.mock.calls[0]?.[0] as
|
||||
| ConfirmDialogOptions
|
||||
| undefined
|
||||
options?.footerProps?.onCancel?.()
|
||||
|
||||
expect(mocks.closeDialog).toHaveBeenCalledWith('confirm-dialog')
|
||||
expect(mocks.showLayoutDialog).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('opens immediately when app workflow outputs are present', () => {
|
||||
mocks.workflowStore.activeWorkflow = {
|
||||
initialMode: 'app'
|
||||
}
|
||||
mocks.pruneLinearData.mockReturnValue({
|
||||
inputs: [],
|
||||
outputs: ['output-node']
|
||||
})
|
||||
|
||||
const { show } = useShareDialog()
|
||||
|
||||
show()
|
||||
|
||||
expect(mocks.showConfirmDialog).not.toHaveBeenCalled()
|
||||
expect(mocks.showLayoutDialog).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('opens immediately for graph workflows without outputs', () => {
|
||||
mocks.workflowStore.activeWorkflow = {
|
||||
initialMode: 'graph'
|
||||
}
|
||||
mocks.pruneLinearData.mockReturnValue({
|
||||
inputs: [],
|
||||
outputs: []
|
||||
})
|
||||
|
||||
const { show } = useShareDialog()
|
||||
|
||||
show()
|
||||
|
||||
expect(mocks.showConfirmDialog).not.toHaveBeenCalled()
|
||||
expect(mocks.showLayoutDialog).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
})
|
||||
@@ -0,0 +1,71 @@
|
||||
import { fromPartial } from '@total-typescript/shoehorn'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import type { UseMouseSourceType } from '@vueuse/core'
|
||||
import { useMouseInElement } from '@vueuse/core'
|
||||
import { nextTick, ref } from 'vue'
|
||||
|
||||
import { useSliderFromMouse } from './useSliderFromMouse'
|
||||
|
||||
vi.mock('@vueuse/core', () => ({
|
||||
useMouseInElement: vi.fn()
|
||||
}))
|
||||
|
||||
const elementX = ref(0)
|
||||
const elementWidth = ref(100)
|
||||
const isOutside = ref(true)
|
||||
|
||||
vi.mocked(useMouseInElement).mockReturnValue(
|
||||
fromPartial({
|
||||
elementX,
|
||||
elementY: ref(0),
|
||||
elementPositionX: ref(0),
|
||||
elementPositionY: ref(0),
|
||||
elementHeight: ref(0),
|
||||
elementWidth,
|
||||
isOutside,
|
||||
sourceType: ref<UseMouseSourceType>(null)
|
||||
})
|
||||
)
|
||||
|
||||
describe('useSliderFromMouse', () => {
|
||||
beforeEach(() => {
|
||||
elementX.value = 0
|
||||
elementWidth.value = 100
|
||||
isOutside.value = true
|
||||
})
|
||||
|
||||
it('starts at the midpoint', () => {
|
||||
const target = ref(document.createElement('div'))
|
||||
|
||||
expect(useSliderFromMouse(target).value).toBe(50)
|
||||
})
|
||||
|
||||
it('updates from mouse position while pointer is inside the target', async () => {
|
||||
const target = ref(document.createElement('div'))
|
||||
const position = useSliderFromMouse(target)
|
||||
|
||||
isOutside.value = false
|
||||
elementX.value = 25
|
||||
elementWidth.value = 100
|
||||
await nextTick()
|
||||
|
||||
expect(position.value).toBe(25)
|
||||
})
|
||||
|
||||
it('ignores pointer updates outside the target or without width', async () => {
|
||||
const target = ref(document.createElement('div'))
|
||||
const position = useSliderFromMouse(target)
|
||||
|
||||
isOutside.value = true
|
||||
elementX.value = 10
|
||||
elementWidth.value = 100
|
||||
await nextTick()
|
||||
expect(position.value).toBe(50)
|
||||
|
||||
isOutside.value = false
|
||||
elementWidth.value = 0
|
||||
await nextTick()
|
||||
expect(position.value).toBe(50)
|
||||
})
|
||||
})
|
||||
@@ -27,6 +27,16 @@ function mockUploadResponse(ok = true, status = 200): Response {
|
||||
} as Response
|
||||
}
|
||||
|
||||
function mockJsonFailure(ok = false, status = 500): Response {
|
||||
return {
|
||||
ok,
|
||||
status,
|
||||
json: async () => {
|
||||
throw new Error('invalid json')
|
||||
}
|
||||
} as unknown as Response
|
||||
}
|
||||
|
||||
describe('useComfyHubService', () => {
|
||||
beforeEach(() => {
|
||||
vi.resetAllMocks()
|
||||
@@ -171,6 +181,29 @@ describe('useComfyHubService', () => {
|
||||
})
|
||||
})
|
||||
|
||||
it('publishes workflow with hub-native thumbnail type', async () => {
|
||||
mockFetchApi.mockResolvedValue(
|
||||
mockJsonResponse({
|
||||
share_id: 'share-1',
|
||||
workflow_id: 'workflow-1',
|
||||
thumbnail_type: 'video'
|
||||
})
|
||||
)
|
||||
|
||||
const service = useComfyHubService()
|
||||
await service.publishWorkflow({
|
||||
username: 'builder',
|
||||
name: 'Video Flow',
|
||||
workflowFilename: 'workflows/video-flow.json',
|
||||
assetIds: ['asset-1'],
|
||||
thumbnailType: 'video'
|
||||
})
|
||||
|
||||
const [, options] = mockFetchApi.mock.calls[0]
|
||||
const body = JSON.parse(options.body as string)
|
||||
expect(body.thumbnail_type).toBe('video')
|
||||
})
|
||||
|
||||
it('fetches tag labels from /hub/labels?type=tag', async () => {
|
||||
mockFetchApi.mockResolvedValue(
|
||||
mockJsonResponse({
|
||||
@@ -212,4 +245,107 @@ describe('useComfyHubService', () => {
|
||||
coverImageUrl: undefined
|
||||
})
|
||||
})
|
||||
|
||||
it('returns null when current profile is missing', async () => {
|
||||
mockFetchApi.mockResolvedValue(mockJsonResponse({}, false, 404))
|
||||
|
||||
const service = useComfyHubService()
|
||||
|
||||
await expect(service.getMyProfile()).resolves.toBeNull()
|
||||
})
|
||||
|
||||
it('uses server error messages when requests fail', async () => {
|
||||
mockFetchApi.mockResolvedValue(
|
||||
mockJsonResponse({ message: 'No upload for you' }, false, 400)
|
||||
)
|
||||
|
||||
const service = useComfyHubService()
|
||||
|
||||
await expect(
|
||||
service.requestAssetUploadUrl({
|
||||
filename: 'thumb.png',
|
||||
contentType: 'image/png'
|
||||
})
|
||||
).rejects.toThrow('No upload for you')
|
||||
})
|
||||
|
||||
it('uses fallback error messages when error bodies are missing or malformed', async () => {
|
||||
mockFetchApi.mockResolvedValueOnce(mockJsonFailure())
|
||||
const service = useComfyHubService()
|
||||
|
||||
await expect(
|
||||
service.requestAssetUploadUrl({
|
||||
filename: 'thumb.png',
|
||||
contentType: 'image/png'
|
||||
})
|
||||
).rejects.toThrow('Failed to request upload URL')
|
||||
|
||||
mockFetchApi.mockResolvedValueOnce(mockJsonResponse({}, false, 500))
|
||||
|
||||
await expect(service.getMyProfile()).rejects.toThrow(
|
||||
'Failed to load ComfyHub profile'
|
||||
)
|
||||
})
|
||||
|
||||
it('throws on invalid success payloads', async () => {
|
||||
mockFetchApi.mockResolvedValue(mockJsonResponse({ invalid: true }))
|
||||
|
||||
const service = useComfyHubService()
|
||||
|
||||
await expect(service.fetchTagLabels()).rejects.toThrow(
|
||||
'Invalid label list response from server'
|
||||
)
|
||||
})
|
||||
|
||||
it('throws upload errors from presigned URL uploads', async () => {
|
||||
mockGlobalFetch.mockResolvedValue(
|
||||
mockJsonResponse({ message: 'Upload rejected' }, false, 403)
|
||||
)
|
||||
|
||||
const service = useComfyHubService()
|
||||
const file = new File(['payload'], 'avatar.png', { type: 'image/png' })
|
||||
|
||||
await expect(
|
||||
service.uploadFileToPresignedUrl({
|
||||
uploadUrl: 'https://upload.example.com/object',
|
||||
file,
|
||||
contentType: 'image/png'
|
||||
})
|
||||
).rejects.toThrow('Upload rejected')
|
||||
})
|
||||
|
||||
it('throws create and publish failures with parsed fallback messages', async () => {
|
||||
const service = useComfyHubService()
|
||||
mockFetchApi.mockResolvedValueOnce(mockJsonResponse({}, false, 500))
|
||||
|
||||
await expect(
|
||||
service.createProfile({
|
||||
workspaceId: 'workspace-1',
|
||||
username: 'builder'
|
||||
})
|
||||
).rejects.toThrow('Failed to create ComfyHub profile')
|
||||
|
||||
mockFetchApi.mockResolvedValueOnce(
|
||||
mockJsonResponse({ message: 'Publish rejected' }, false, 400)
|
||||
)
|
||||
|
||||
await expect(
|
||||
service.publishWorkflow({
|
||||
username: 'builder',
|
||||
name: 'My Flow',
|
||||
workflowFilename: 'workflows/my-flow.json',
|
||||
assetIds: ['asset-1']
|
||||
})
|
||||
).rejects.toThrow('Publish rejected')
|
||||
})
|
||||
|
||||
it('throws label fetch failures', async () => {
|
||||
mockFetchApi.mockResolvedValue(
|
||||
mockJsonResponse({ message: 'Labels unavailable' }, false, 503)
|
||||
)
|
||||
|
||||
const service = useComfyHubService()
|
||||
|
||||
await expect(service.fetchTagLabels()).rejects.toThrow('Labels unavailable')
|
||||
})
|
||||
})
|
||||
|
||||
@@ -16,6 +16,11 @@ const mockGetShareableAssets = vi.fn()
|
||||
const mockFetchApi = vi.fn()
|
||||
const mockInvalidateInputAssetsIncludingPublic = vi.hoisted(() => vi.fn())
|
||||
|
||||
type RetryableLoadError = Error & {
|
||||
status: number | null
|
||||
isRetryable: boolean
|
||||
}
|
||||
|
||||
vi.mock('@/scripts/api', () => ({
|
||||
api: {
|
||||
getShareableAssets: (...args: unknown[]) => mockGetShareableAssets(...args),
|
||||
@@ -54,14 +59,43 @@ describe(useWorkflowShareService, () => {
|
||||
}
|
||||
]
|
||||
|
||||
function mockJsonResponse(payload: unknown, ok = true, status = 200) {
|
||||
function mockJsonResponse(
|
||||
payload: unknown,
|
||||
ok = true,
|
||||
status = 200,
|
||||
statusText = ''
|
||||
) {
|
||||
return {
|
||||
ok,
|
||||
status,
|
||||
statusText,
|
||||
json: async () => payload
|
||||
} as Response
|
||||
}
|
||||
|
||||
async function expectLoadError(
|
||||
promise: Promise<unknown>,
|
||||
expected: {
|
||||
message: string
|
||||
status: number | null
|
||||
isRetryable: boolean
|
||||
}
|
||||
) {
|
||||
let caught: unknown
|
||||
|
||||
try {
|
||||
await promise
|
||||
} catch (error) {
|
||||
caught = error
|
||||
}
|
||||
|
||||
expect(caught).toBeInstanceOf(Error)
|
||||
const loadError = caught as RetryableLoadError
|
||||
expect(loadError.message).toBe(expected.message)
|
||||
expect(loadError.status).toBe(expected.status)
|
||||
expect(loadError.isRetryable).toBe(expected.isRetryable)
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
vi.resetAllMocks()
|
||||
mockApp.rootGraph = {}
|
||||
@@ -87,6 +121,33 @@ describe(useWorkflowShareService, () => {
|
||||
expect(status.publishedAt).toBeNull()
|
||||
})
|
||||
|
||||
it('returns unpublished when publish status does not exist', async () => {
|
||||
mockFetchApi.mockResolvedValue(mockJsonResponse({}, false, 404))
|
||||
|
||||
const service = useWorkflowShareService()
|
||||
const status = await service.getPublishStatus('missing')
|
||||
|
||||
expect(status).toEqual({
|
||||
isPublished: false,
|
||||
shareId: null,
|
||||
shareUrl: null,
|
||||
publishedAt: null,
|
||||
prefill: null
|
||||
})
|
||||
})
|
||||
|
||||
it('throws when publish status request fails', async () => {
|
||||
mockFetchApi.mockResolvedValue(
|
||||
mockJsonResponse({}, false, 503, 'Service Unavailable')
|
||||
)
|
||||
|
||||
const service = useWorkflowShareService()
|
||||
|
||||
await expect(service.getPublishStatus('wf-error')).rejects.toThrow(
|
||||
'Failed to fetch publish status: 503 Service Unavailable'
|
||||
)
|
||||
})
|
||||
|
||||
it('publishes a workflow and returns a share URL', async () => {
|
||||
mockFetchApi.mockResolvedValue(
|
||||
mockJsonResponse({
|
||||
@@ -120,6 +181,56 @@ describe(useWorkflowShareService, () => {
|
||||
)
|
||||
})
|
||||
|
||||
it('throws when publish request fails', async () => {
|
||||
mockFetchApi.mockResolvedValue(mockJsonResponse({}, false, 500))
|
||||
|
||||
const service = useWorkflowShareService()
|
||||
|
||||
await expect(
|
||||
service.publishWorkflow('wf-error', mockShareableAssets)
|
||||
).rejects.toThrow('Failed to publish workflow: 500')
|
||||
})
|
||||
|
||||
it('throws when publish response is missing required publish data', async () => {
|
||||
mockFetchApi
|
||||
.mockResolvedValueOnce(
|
||||
mockJsonResponse({
|
||||
workflow_id: 'wf-no-share',
|
||||
share_id: null,
|
||||
publish_time: '2026-02-23T00:00:00Z',
|
||||
listed: false
|
||||
})
|
||||
)
|
||||
.mockResolvedValueOnce(
|
||||
mockJsonResponse({
|
||||
workflow_id: 'wf-no-date',
|
||||
share_id: 'wf-no-date',
|
||||
publish_time: null,
|
||||
listed: false
|
||||
})
|
||||
)
|
||||
.mockResolvedValueOnce(
|
||||
mockJsonResponse({
|
||||
workflow_id: 'wf-invalid-date',
|
||||
share_id: 'wf-invalid-date',
|
||||
publish_time: 'invalid-date',
|
||||
listed: false
|
||||
})
|
||||
)
|
||||
|
||||
const service = useWorkflowShareService()
|
||||
|
||||
await expect(
|
||||
service.publishWorkflow('wf-no-share', mockShareableAssets)
|
||||
).rejects.toThrow('Failed to publish workflow: invalid response')
|
||||
await expect(
|
||||
service.publishWorkflow('wf-no-date', mockShareableAssets)
|
||||
).rejects.toThrow('Failed to publish workflow: invalid response')
|
||||
await expect(
|
||||
service.publishWorkflow('wf-invalid-date', mockShareableAssets)
|
||||
).rejects.toThrow('Failed to publish workflow: invalid response')
|
||||
})
|
||||
|
||||
it('preserves app subpath when normalizing published share URLs', async () => {
|
||||
window.history.replaceState({}, '', '/comfy/subpath/?foo=bar#section')
|
||||
mockFetchApi.mockResolvedValue(
|
||||
@@ -202,6 +313,73 @@ describe(useWorkflowShareService, () => {
|
||||
expect(mockFetchApi).toHaveBeenNthCalledWith(2, '/hub/workflows/wf-prefill')
|
||||
})
|
||||
|
||||
it('maps listed hub workflow media prefill fields', async () => {
|
||||
mockFetchApi.mockImplementation(async (path: string) => {
|
||||
if (path === '/userdata/wf-media/publish') {
|
||||
return mockJsonResponse({
|
||||
workflow_id: 'wf-media',
|
||||
share_id: 'wf-media',
|
||||
publish_time: '2026-02-23T00:00:00Z',
|
||||
listed: true
|
||||
})
|
||||
}
|
||||
|
||||
if (path === '/hub/workflows/wf-media') {
|
||||
return mockJsonResponse({
|
||||
tags: ['motion'],
|
||||
thumbnail_type: 'video',
|
||||
thumbnail_url: 'https://example.com/thumb.mp4',
|
||||
thumbnail_comparison_url: 'https://example.com/compare.png'
|
||||
})
|
||||
}
|
||||
|
||||
return mockJsonResponse({}, false, 404)
|
||||
})
|
||||
|
||||
const service = useWorkflowShareService()
|
||||
const status = await service.getPublishStatus('wf-media')
|
||||
|
||||
expect(status.isPublished).toBe(true)
|
||||
expect(status.prefill).toEqual({
|
||||
tags: ['motion'],
|
||||
thumbnailType: 'video',
|
||||
thumbnailUrl: 'https://example.com/thumb.mp4',
|
||||
thumbnailComparisonUrl: 'https://example.com/compare.png'
|
||||
})
|
||||
})
|
||||
|
||||
it('returns null listed prefill when hub metadata has no fields', async () => {
|
||||
mockFetchApi.mockImplementation(async (path: string) => {
|
||||
if (path === '/userdata/wf-empty-prefill/publish') {
|
||||
return mockJsonResponse({
|
||||
workflow_id: 'wf-empty-prefill',
|
||||
share_id: 'wf-empty-prefill',
|
||||
publish_time: '2026-02-23T00:00:00Z',
|
||||
listed: true
|
||||
})
|
||||
}
|
||||
|
||||
if (path === '/hub/workflows/wf-empty-prefill') {
|
||||
return mockJsonResponse({
|
||||
description: null,
|
||||
tags: [],
|
||||
thumbnail_type: null,
|
||||
thumbnail_url: null,
|
||||
thumbnail_comparison_url: null,
|
||||
sample_image_urls: []
|
||||
})
|
||||
}
|
||||
|
||||
return mockJsonResponse({}, false, 404)
|
||||
})
|
||||
|
||||
const service = useWorkflowShareService()
|
||||
const status = await service.getPublishStatus('wf-empty-prefill')
|
||||
|
||||
expect(status.isPublished).toBe(true)
|
||||
expect(status.prefill).toBeNull()
|
||||
})
|
||||
|
||||
it('returns null prefill when hub workflow details are unavailable', async () => {
|
||||
mockFetchApi.mockImplementation(async (path: string) => {
|
||||
if (path === '/userdata/wf-no-meta/publish') {
|
||||
@@ -279,6 +457,23 @@ describe(useWorkflowShareService, () => {
|
||||
expect(status.shareId).toBeNull()
|
||||
})
|
||||
|
||||
it('returns unpublished when publish record has invalid publish time', async () => {
|
||||
mockFetchApi.mockResolvedValue(
|
||||
mockJsonResponse({
|
||||
workflow_id: 'wf-invalid-time',
|
||||
share_id: 'wf-invalid-time',
|
||||
publish_time: 'invalid-date',
|
||||
listed: false
|
||||
})
|
||||
)
|
||||
|
||||
const service = useWorkflowShareService()
|
||||
const status = await service.getPublishStatus('wf-invalid-time')
|
||||
|
||||
expect(status.isPublished).toBe(false)
|
||||
expect(status.shareId).toBeNull()
|
||||
})
|
||||
|
||||
it('fetches and maps shared workflow payload', async () => {
|
||||
mockFetchApi.mockResolvedValue(
|
||||
mockJsonResponse({
|
||||
@@ -332,9 +527,48 @@ describe(useWorkflowShareService, () => {
|
||||
|
||||
const service = useWorkflowShareService()
|
||||
|
||||
await expect(service.getSharedWorkflow('missing')).rejects.toThrow(
|
||||
'Failed to load shared workflow: 404'
|
||||
)
|
||||
await expectLoadError(service.getSharedWorkflow('missing'), {
|
||||
message: 'Failed to load shared workflow: 404',
|
||||
status: 404,
|
||||
isRetryable: false
|
||||
})
|
||||
})
|
||||
|
||||
it('marks retryable status errors on shared workflow failures', async () => {
|
||||
const service = useWorkflowShareService()
|
||||
|
||||
mockFetchApi.mockResolvedValueOnce(mockJsonResponse({}, false, 500))
|
||||
await expectLoadError(service.getSharedWorkflow('server-error'), {
|
||||
message: 'Failed to load shared workflow: 500',
|
||||
status: 500,
|
||||
isRetryable: true
|
||||
})
|
||||
|
||||
mockFetchApi.mockResolvedValueOnce(mockJsonResponse({}, false, 408))
|
||||
await expectLoadError(service.getSharedWorkflow('timeout'), {
|
||||
message: 'Failed to load shared workflow: 408',
|
||||
status: 408,
|
||||
isRetryable: true
|
||||
})
|
||||
|
||||
mockFetchApi.mockResolvedValueOnce(mockJsonResponse({}, false, 429))
|
||||
await expectLoadError(service.getSharedWorkflow('rate-limited'), {
|
||||
message: 'Failed to load shared workflow: 429',
|
||||
status: 429,
|
||||
isRetryable: true
|
||||
})
|
||||
})
|
||||
|
||||
it('marks shared workflow network failures as retryable', async () => {
|
||||
mockFetchApi.mockRejectedValue(new TypeError('network down'))
|
||||
|
||||
const service = useWorkflowShareService()
|
||||
|
||||
await expectLoadError(service.getSharedWorkflow('network-error'), {
|
||||
message: 'Failed to load shared workflow: network error',
|
||||
status: null,
|
||||
isRetryable: true
|
||||
})
|
||||
})
|
||||
|
||||
it('imports published assets via POST /assets/import with share_id', async () => {
|
||||
@@ -468,6 +702,66 @@ describe(useWorkflowShareService, () => {
|
||||
expect(mockGetShareableAssets).toHaveBeenCalledWith({ '1': {} })
|
||||
})
|
||||
|
||||
it('filters public shareable assets unless explicitly included', async () => {
|
||||
mockApp.graphToPrompt.mockResolvedValue({ output: {} })
|
||||
mockGetShareableAssets.mockResolvedValue({
|
||||
assets: [
|
||||
{
|
||||
id: 'private-asset',
|
||||
name: 'private.png',
|
||||
preview_url: '',
|
||||
storage_url: '',
|
||||
model: false,
|
||||
public: false,
|
||||
in_library: false
|
||||
},
|
||||
{
|
||||
id: 'public-asset',
|
||||
name: 'public.png',
|
||||
preview_url: '',
|
||||
storage_url: '',
|
||||
model: false,
|
||||
public: true,
|
||||
in_library: false
|
||||
}
|
||||
]
|
||||
})
|
||||
|
||||
const service = useWorkflowShareService()
|
||||
|
||||
await expect(service.getShareableAssets()).resolves.toEqual([
|
||||
{
|
||||
id: 'private-asset',
|
||||
name: 'private.png',
|
||||
preview_url: '',
|
||||
storage_url: '',
|
||||
model: false,
|
||||
public: false,
|
||||
in_library: false
|
||||
}
|
||||
])
|
||||
await expect(service.getShareableAssets(true)).resolves.toEqual([
|
||||
{
|
||||
id: 'private-asset',
|
||||
name: 'private.png',
|
||||
preview_url: '',
|
||||
storage_url: '',
|
||||
model: false,
|
||||
public: false,
|
||||
in_library: false
|
||||
},
|
||||
{
|
||||
id: 'public-asset',
|
||||
name: 'public.png',
|
||||
preview_url: '',
|
||||
storage_url: '',
|
||||
model: false,
|
||||
public: true,
|
||||
in_library: false
|
||||
}
|
||||
])
|
||||
})
|
||||
|
||||
it('propagates error when graphToPrompt fails', async () => {
|
||||
mockApp.graphToPrompt.mockRejectedValue(new Error('prompt failed'))
|
||||
|
||||
|
||||
@@ -0,0 +1,454 @@
|
||||
import { createPinia, setActivePinia } from 'pinia'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { nextTick } from 'vue'
|
||||
|
||||
import { useWorkflowTemplatesStore } from '@/platform/workflow/templates/repositories/workflowTemplatesStore'
|
||||
import type { WorkflowTemplates } from '@/platform/workflow/templates/types/template'
|
||||
import type { NavGroupData, NavItemData } from '@/types/navTypes'
|
||||
|
||||
const {
|
||||
coreByLocale,
|
||||
coreErrorLocales,
|
||||
coreResult,
|
||||
customResult,
|
||||
dist,
|
||||
locale
|
||||
} = vi.hoisted(() => ({
|
||||
coreByLocale: { value: {} as Record<string, unknown[]> },
|
||||
coreErrorLocales: { value: new Set<string>() },
|
||||
coreResult: { value: [] as unknown[] },
|
||||
customResult: { value: {} as Record<string, string[]> },
|
||||
dist: { isCloud: false },
|
||||
locale: { value: 'en' }
|
||||
}))
|
||||
|
||||
const baseTemplate = {
|
||||
name: 'default',
|
||||
title: 'Default',
|
||||
description: 'A basic template',
|
||||
mediaType: 'image',
|
||||
mediaSubtype: 'webp'
|
||||
}
|
||||
|
||||
vi.mock('@/scripts/api', () => ({
|
||||
api: {
|
||||
getWorkflowTemplates: async () => customResult.value,
|
||||
getCoreWorkflowTemplates: async (locale: string) => {
|
||||
if (coreErrorLocales.value.has(locale)) throw new Error('core failed')
|
||||
return coreByLocale.value[locale] ?? coreResult.value
|
||||
},
|
||||
fileURL: (p: string) => p
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('@/i18n', async () => {
|
||||
const { ref } = await import('vue')
|
||||
const localeRef = ref(locale.value)
|
||||
Object.defineProperty(locale, 'value', {
|
||||
get: () => localeRef.value,
|
||||
set: (value: string) => {
|
||||
localeRef.value = value
|
||||
}
|
||||
})
|
||||
return {
|
||||
i18n: { global: { locale } },
|
||||
st: (_key: string, fallback: string) => fallback
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock('@/platform/distribution/types', () => ({
|
||||
get isCloud() {
|
||||
return dist.isCloud
|
||||
}
|
||||
}))
|
||||
|
||||
function coreCategory(
|
||||
overrides: Partial<WorkflowTemplates> = {}
|
||||
): WorkflowTemplates {
|
||||
return {
|
||||
moduleName: 'default',
|
||||
title: 'Basics',
|
||||
type: 'image',
|
||||
templates: [baseTemplate],
|
||||
...overrides
|
||||
}
|
||||
}
|
||||
|
||||
function navItems(items: (NavItemData | NavGroupData)[]) {
|
||||
return items.flatMap((item) => ('items' in item ? item.items : [item]))
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
setActivePinia(createPinia())
|
||||
coreByLocale.value = {}
|
||||
coreErrorLocales.value = new Set()
|
||||
coreResult.value = [coreCategory()]
|
||||
customResult.value = {}
|
||||
dist.isCloud = false
|
||||
locale.value = 'en'
|
||||
vi.stubGlobal(
|
||||
'fetch',
|
||||
vi.fn(
|
||||
async () => new Response('', { headers: { 'content-type': 'text/html' } })
|
||||
)
|
||||
)
|
||||
})
|
||||
|
||||
describe('workflowTemplatesStore', () => {
|
||||
it('returns empty navigation before templates are loaded', () => {
|
||||
const store = useWorkflowTemplatesStore()
|
||||
|
||||
expect(store.navGroupedTemplates).toEqual([])
|
||||
})
|
||||
|
||||
it('loads core templates and indexes their names', async () => {
|
||||
const store = useWorkflowTemplatesStore()
|
||||
expect(store.isLoaded).toBe(false)
|
||||
|
||||
await store.loadWorkflowTemplates()
|
||||
|
||||
expect(store.isLoaded).toBe(true)
|
||||
expect(store.knownTemplateNames.has('default')).toBe(true)
|
||||
expect(store.getTemplateByName('default')?.name).toBe('default')
|
||||
expect(store.getTemplateByName('missing')).toBeUndefined()
|
||||
})
|
||||
|
||||
it('exposes grouped templates with localized titles', async () => {
|
||||
const store = useWorkflowTemplatesStore()
|
||||
await store.loadWorkflowTemplates()
|
||||
|
||||
expect(store.groupedTemplates.length).toBeGreaterThan(0)
|
||||
const allNames = store.groupedTemplates.flatMap((g) =>
|
||||
(g.modules ?? []).flatMap((m) => (m.templates ?? []).map((t) => t.name))
|
||||
)
|
||||
expect(allNames).toContain('default')
|
||||
})
|
||||
|
||||
it('filters nav categories from loaded template metadata', async () => {
|
||||
coreResult.value = [
|
||||
coreCategory({
|
||||
title: 'Getting Started',
|
||||
isEssential: true,
|
||||
templates: [{ ...baseTemplate, name: 'starter', title: 'Starter' }]
|
||||
}),
|
||||
coreCategory({
|
||||
title: 'Image Tools',
|
||||
category: 'GENERATION TYPE',
|
||||
templates: [
|
||||
{
|
||||
...baseTemplate,
|
||||
name: 'partner-upscale',
|
||||
title: 'Partner Upscale',
|
||||
openSource: false
|
||||
},
|
||||
{
|
||||
...baseTemplate,
|
||||
name: 'local-only',
|
||||
requiresCustomNodes: ['custom-node']
|
||||
}
|
||||
]
|
||||
}),
|
||||
coreCategory({
|
||||
title: 'Image Tools',
|
||||
category: 'OTHER GROUP',
|
||||
templates: [
|
||||
{
|
||||
...baseTemplate,
|
||||
name: 'other-image',
|
||||
title: 'Other Image'
|
||||
}
|
||||
]
|
||||
}),
|
||||
coreCategory({
|
||||
title: 'Video Tools',
|
||||
category: 'GENERATION TYPE',
|
||||
icon: 'icon-custom',
|
||||
type: undefined,
|
||||
templates: [
|
||||
{
|
||||
...baseTemplate,
|
||||
name: 'video-tool',
|
||||
title: 'Video Tool'
|
||||
}
|
||||
]
|
||||
})
|
||||
]
|
||||
customResult.value = { CustomPack: ['custom-flow'] }
|
||||
const store = useWorkflowTemplatesStore()
|
||||
|
||||
await store.loadWorkflowTemplates()
|
||||
|
||||
const allItems = navItems(store.navGroupedTemplates)
|
||||
const basicsId = allItems.find(
|
||||
(item) => item.label === 'Getting Started'
|
||||
)?.id
|
||||
const categoryId = allItems.find((item) => item.label === 'Image Tools')?.id
|
||||
|
||||
expect(store.filterTemplatesByCategory('all').map((t) => t.name)).toEqual([
|
||||
'starter',
|
||||
'partner-upscale',
|
||||
'other-image',
|
||||
'video-tool',
|
||||
'custom-flow'
|
||||
])
|
||||
expect(
|
||||
store.filterTemplatesByCategory('popular').map((t) => t.name)
|
||||
).toEqual([
|
||||
'starter',
|
||||
'partner-upscale',
|
||||
'other-image',
|
||||
'video-tool',
|
||||
'custom-flow'
|
||||
])
|
||||
expect(
|
||||
store.filterTemplatesByCategory(basicsId ?? '').map((t) => t.name)
|
||||
).toEqual(['starter'])
|
||||
expect(
|
||||
store.filterTemplatesByCategory(categoryId ?? '').map((t) => t.name)
|
||||
).toEqual(['partner-upscale'])
|
||||
expect(
|
||||
store.filterTemplatesByCategory('partner-nodes').map((t) => t.name)
|
||||
).toEqual(['partner-upscale'])
|
||||
expect(
|
||||
store.filterTemplatesByCategory('extension-CustomPack').map((t) => t.name)
|
||||
).toEqual(['custom-flow'])
|
||||
expect(
|
||||
store.filterTemplatesByCategory('unknown').map((t) => t.name)
|
||||
).toEqual([
|
||||
'starter',
|
||||
'partner-upscale',
|
||||
'other-image',
|
||||
'video-tool',
|
||||
'custom-flow'
|
||||
])
|
||||
})
|
||||
|
||||
it('loads logo indexes and rejects unsafe logo paths', async () => {
|
||||
vi.mocked(fetch).mockResolvedValueOnce(
|
||||
new Response(
|
||||
JSON.stringify({
|
||||
valid: 'logos/valid.svg',
|
||||
missingExtension: 'logos/valid',
|
||||
parent: '../secret.svg',
|
||||
rooted: '/logos/rooted.svg'
|
||||
}),
|
||||
{ headers: { 'content-type': 'application/json' } }
|
||||
)
|
||||
)
|
||||
const store = useWorkflowTemplatesStore()
|
||||
|
||||
await store.loadWorkflowTemplates()
|
||||
|
||||
expect(store.getLogoUrl('valid')).toBe('/templates/logos/valid.svg')
|
||||
expect(store.getLogoUrl('missing')).toBe('')
|
||||
expect(store.getLogoUrl('missingExtension')).toBe('')
|
||||
expect(store.getLogoUrl('parent')).toBe('')
|
||||
expect(store.getLogoUrl('rooted')).toBe('')
|
||||
})
|
||||
|
||||
it('ignores invalid and failed logo indexes', async () => {
|
||||
vi.mocked(fetch).mockResolvedValueOnce(
|
||||
new Response(JSON.stringify({ valid: 1 }), {
|
||||
headers: { 'content-type': 'application/json' }
|
||||
})
|
||||
)
|
||||
const invalidStore = useWorkflowTemplatesStore()
|
||||
|
||||
await invalidStore.loadWorkflowTemplates()
|
||||
|
||||
expect(invalidStore.getLogoUrl('valid')).toBe('')
|
||||
|
||||
setActivePinia(createPinia())
|
||||
vi.mocked(fetch).mockRejectedValueOnce(new Error('logo failed'))
|
||||
const failedStore = useWorkflowTemplatesStore()
|
||||
|
||||
await failedStore.loadWorkflowTemplates()
|
||||
|
||||
expect(failedStore.getLogoUrl('valid')).toBe('')
|
||||
})
|
||||
|
||||
it('includes cloud-only templates and custom groups when requested', async () => {
|
||||
dist.isCloud = true
|
||||
coreResult.value = [
|
||||
coreCategory({
|
||||
title: 'Cloud Templates',
|
||||
templates: [
|
||||
{
|
||||
name: 'metadata-light',
|
||||
description: '',
|
||||
mediaType: 'image',
|
||||
mediaSubtype: 'webp',
|
||||
requiresCustomNodes: ['custom-node']
|
||||
}
|
||||
]
|
||||
})
|
||||
]
|
||||
customResult.value = { CustomPack: ['custom-flow'] }
|
||||
const store = useWorkflowTemplatesStore()
|
||||
|
||||
await store.loadWorkflowTemplates()
|
||||
|
||||
expect(store.enhancedTemplates.map((t) => t.name)).toEqual([
|
||||
'metadata-light',
|
||||
'custom-flow'
|
||||
])
|
||||
expect(
|
||||
store.groupedTemplates.find((group) => group.label === 'Custom Nodes')
|
||||
).toBeDefined()
|
||||
expect(store.getTemplateByName('metadata-light')?.searchableText).toBe(
|
||||
'metadata-light Cloud Templates'
|
||||
)
|
||||
})
|
||||
|
||||
it('omits optional nav sections when templates do not need them', async () => {
|
||||
coreResult.value = [
|
||||
coreCategory({
|
||||
templates: [{ ...baseTemplate, openSource: true }]
|
||||
})
|
||||
]
|
||||
const store = useWorkflowTemplatesStore()
|
||||
|
||||
await store.loadWorkflowTemplates()
|
||||
|
||||
const items = store.navGroupedTemplates
|
||||
const flatItems = navItems(items)
|
||||
|
||||
expect(flatItems.map((item) => item.id)).toEqual(['all', 'popular'])
|
||||
expect(
|
||||
items.some((item) => 'title' in item && item.title === 'Extensions')
|
||||
).toBe(false)
|
||||
})
|
||||
|
||||
it('uses fallback icons for essential and grouped nav entries', async () => {
|
||||
coreResult.value = [
|
||||
coreCategory({
|
||||
title: 'Getting Started',
|
||||
isEssential: true,
|
||||
type: undefined
|
||||
}),
|
||||
coreCategory({
|
||||
title: 'Model Tools',
|
||||
category: 'MODEL TYPE',
|
||||
type: undefined
|
||||
})
|
||||
]
|
||||
const store = useWorkflowTemplatesStore()
|
||||
|
||||
await store.loadWorkflowTemplates()
|
||||
|
||||
const flatItems = navItems(store.navGroupedTemplates)
|
||||
|
||||
expect(
|
||||
flatItems.find((item) => item.label === 'Getting Started')?.icon
|
||||
).toBe('icon-[lucide--graduation-cap]')
|
||||
expect(flatItems.find((item) => item.label === 'Model Tools')?.icon).toBe(
|
||||
'icon-[lucide--folder]'
|
||||
)
|
||||
})
|
||||
|
||||
it('returns english metadata when cloud loads a non-english locale', async () => {
|
||||
dist.isCloud = true
|
||||
locale.value = 'fr'
|
||||
coreByLocale.value = {
|
||||
fr: [
|
||||
coreCategory({
|
||||
templates: [{ ...baseTemplate, name: 'localized', title: 'Localise' }]
|
||||
})
|
||||
],
|
||||
en: [
|
||||
coreCategory({
|
||||
title: 'English Category',
|
||||
templates: [
|
||||
{
|
||||
...baseTemplate,
|
||||
name: 'localized',
|
||||
tags: ['tag'],
|
||||
useCase: 'test',
|
||||
models: ['model'],
|
||||
license: 'MIT'
|
||||
}
|
||||
]
|
||||
})
|
||||
]
|
||||
}
|
||||
const store = useWorkflowTemplatesStore()
|
||||
|
||||
await store.loadWorkflowTemplates()
|
||||
|
||||
expect(store.getEnglishMetadata('localized')).toEqual({
|
||||
tags: ['tag'],
|
||||
category: 'English Category',
|
||||
useCase: 'test',
|
||||
models: ['model'],
|
||||
license: 'MIT'
|
||||
})
|
||||
expect(store.getEnglishMetadata('missing')).toBeNull()
|
||||
})
|
||||
|
||||
it('does not refetch once loaded', async () => {
|
||||
const store = useWorkflowTemplatesStore()
|
||||
await store.loadWorkflowTemplates()
|
||||
|
||||
coreResult.value = []
|
||||
await store.loadWorkflowTemplates()
|
||||
|
||||
expect(store.knownTemplateNames.has('default')).toBe(true)
|
||||
})
|
||||
|
||||
it('returns null english metadata when no english templates are loaded', async () => {
|
||||
const store = useWorkflowTemplatesStore()
|
||||
await store.loadWorkflowTemplates()
|
||||
|
||||
expect(store.getEnglishMetadata('default')).toBeNull()
|
||||
})
|
||||
|
||||
it('reloads loaded templates when locale changes', async () => {
|
||||
coreByLocale.value = {
|
||||
en: [
|
||||
coreCategory({
|
||||
templates: [{ ...baseTemplate, name: 'english' }]
|
||||
})
|
||||
],
|
||||
fr: [
|
||||
coreCategory({
|
||||
templates: [{ ...baseTemplate, name: 'french' }]
|
||||
})
|
||||
]
|
||||
}
|
||||
const store = useWorkflowTemplatesStore()
|
||||
|
||||
locale.value = 'fr'
|
||||
await nextTick()
|
||||
await store.loadWorkflowTemplates()
|
||||
|
||||
expect(store.knownTemplateNames.has('french')).toBe(true)
|
||||
|
||||
coreByLocale.value.es = [
|
||||
coreCategory({
|
||||
templates: [{ ...baseTemplate, name: 'spanish' }]
|
||||
})
|
||||
]
|
||||
locale.value = 'es'
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(store.knownTemplateNames.has('spanish')).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
it('keeps existing templates when locale reload fails', async () => {
|
||||
const errorSpy = vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
const store = useWorkflowTemplatesStore()
|
||||
await store.loadWorkflowTemplates()
|
||||
|
||||
coreErrorLocales.value.add('fr')
|
||||
locale.value = 'fr'
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(errorSpy).toHaveBeenCalledWith(
|
||||
'Error reloading templates for new locale:',
|
||||
expect.any(Error)
|
||||
)
|
||||
})
|
||||
expect(store.knownTemplateNames.has('default')).toBe(true)
|
||||
})
|
||||
})
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user