mirror of
https://github.com/Comfy-Org/ComfyUI_frontend.git
synced 2026-07-03 13:48:49 +00:00
Compare commits
9 Commits
shihchi/co
...
codex/cove
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7141bda563 | ||
|
|
74d4366994 | ||
|
|
c85c0585ab | ||
|
|
804630fa02 | ||
|
|
7417864353 | ||
|
|
1f26dc2b57 | ||
|
|
2ec2a0e091 | ||
|
|
9cf5c9a93f | ||
|
|
9e5fb67b76 |
@@ -15,7 +15,7 @@ const { categories } = defineProps<{
|
||||
|
||||
const activeSection = ref(categories[0]?.value ?? '')
|
||||
|
||||
const HEADER_OFFSET = -144
|
||||
const HEADER_OFFSET_PX = -144
|
||||
const BOTTOM_THRESHOLD_PX = 4
|
||||
const SCROLL_SAFETY_MS = 1500
|
||||
|
||||
@@ -52,7 +52,7 @@ function scrollToSection(id: string) {
|
||||
const el = document.getElementById(id)
|
||||
if (el) {
|
||||
scrollTo(el, {
|
||||
offset: HEADER_OFFSET,
|
||||
offset: HEADER_OFFSET_PX,
|
||||
duration: 0.8,
|
||||
immediate: prefersReducedMotion(),
|
||||
onComplete: clearScrollLock
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
<li
|
||||
class="flex items-start gap-2 text-primary-comfy-canvas before:mt-1.5 before:size-1.5 before:shrink-0 before:rounded-full before:bg-primary-comfy-yellow before:content-['']"
|
||||
class="flex items-start gap-2 text-primary-comfy-canvas before:mt-1.5 before:size-1.5 before:shrink-0 before:rounded-full before:bg-primary-comfy-yellow"
|
||||
>
|
||||
<slot />
|
||||
</li>
|
||||
|
||||
45
browser_tests/assets/linear-validation-warning.json
Normal file
45
browser_tests/assets/linear-validation-warning.json
Normal file
@@ -0,0 +1,45 @@
|
||||
{
|
||||
"last_node_id": 9,
|
||||
"last_link_id": 9,
|
||||
"nodes": [
|
||||
{
|
||||
"id": 9,
|
||||
"type": "SaveImage",
|
||||
"pos": {
|
||||
"0": 64,
|
||||
"1": 104
|
||||
},
|
||||
"size": {
|
||||
"0": 210,
|
||||
"1": 58
|
||||
},
|
||||
"flags": {},
|
||||
"order": 0,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"name": "images",
|
||||
"type": "IMAGE",
|
||||
"link": null
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"properties": {},
|
||||
"widgets_values": ["ComfyUI"]
|
||||
}
|
||||
],
|
||||
"links": [],
|
||||
"groups": [],
|
||||
"config": {},
|
||||
"extra": {
|
||||
"ds": {
|
||||
"scale": 1,
|
||||
"offset": [0, 0]
|
||||
},
|
||||
"linearData": {
|
||||
"inputs": [],
|
||||
"outputs": ["9"]
|
||||
}
|
||||
},
|
||||
"version": 0.4
|
||||
}
|
||||
@@ -34,6 +34,10 @@ export class AppModeHelper {
|
||||
public readonly outputPlaceholder: Locator
|
||||
/** The linear-mode widget list container (visible in app mode). */
|
||||
public readonly linearWidgets: Locator
|
||||
/** The validation warning shown above the app mode run button. */
|
||||
public readonly validationWarning: Locator
|
||||
/** The action that opens graph mode errors from the validation warning. */
|
||||
public readonly viewErrorsInGraphButton: Locator
|
||||
/** The PrimeVue Popover for the image picker (renders with role="dialog"). */
|
||||
public readonly imagePickerPopover: Locator
|
||||
/** The Run button in the app mode footer. */
|
||||
@@ -92,13 +96,19 @@ export class AppModeHelper {
|
||||
this.outputPlaceholder = this.page.getByTestId(
|
||||
TestIds.builder.outputPlaceholder
|
||||
)
|
||||
this.linearWidgets = this.page.getByTestId('linear-widgets')
|
||||
this.linearWidgets = this.page.getByTestId(TestIds.linear.widgetContainer)
|
||||
this.validationWarning = this.page.getByTestId(
|
||||
TestIds.linear.validationWarning
|
||||
)
|
||||
this.viewErrorsInGraphButton = this.validationWarning.getByTestId(
|
||||
TestIds.linear.viewErrorsInGraph
|
||||
)
|
||||
this.imagePickerPopover = this.page
|
||||
.getByRole('dialog')
|
||||
.filter({ has: this.page.getByRole('button', { name: 'All' }) })
|
||||
.first()
|
||||
this.runButton = this.page
|
||||
.getByTestId('linear-run-button')
|
||||
.getByTestId(TestIds.linear.runButton)
|
||||
.getByRole('button', { name: /run/i })
|
||||
this.welcome = this.page.getByTestId(TestIds.appMode.welcome)
|
||||
this.emptyWorkflowText = this.page.getByTestId(
|
||||
|
||||
@@ -172,6 +172,9 @@ export const TestIds = {
|
||||
mobileNavigation: 'linear-mobile-navigation',
|
||||
mobileWorkflows: 'linear-mobile-workflows',
|
||||
outputInfo: 'linear-output-info',
|
||||
runButton: 'linear-run-button',
|
||||
validationWarning: 'linear-validation-warning',
|
||||
viewErrorsInGraph: 'linear-view-errors',
|
||||
widgetContainer: 'linear-widgets'
|
||||
},
|
||||
builder: {
|
||||
|
||||
106
browser_tests/tests/appModeValidationWarning.spec.ts
Normal file
106
browser_tests/tests/appModeValidationWarning.spec.ts
Normal file
@@ -0,0 +1,106 @@
|
||||
import {
|
||||
comfyExpect as expect,
|
||||
comfyPageFixture as test
|
||||
} from '@e2e/fixtures/ComfyPage'
|
||||
import type { NodeError, PromptResponse } from '@/schemas/apiSchema'
|
||||
import { ExecutionHelper } from '@e2e/fixtures/helpers/ExecutionHelper'
|
||||
import { enableErrorsOverlay } from '@e2e/fixtures/helpers/ErrorsTabHelper'
|
||||
import { TestIds } from '@e2e/fixtures/selectors'
|
||||
|
||||
const SAVE_IMAGE_NODE_ID = '9'
|
||||
|
||||
function buildSaveImageRequiredInputError(): NodeError {
|
||||
return {
|
||||
class_type: 'SaveImage',
|
||||
dependent_outputs: [],
|
||||
errors: [
|
||||
{
|
||||
type: 'required_input_missing',
|
||||
message: 'Required input is missing: images',
|
||||
details: '',
|
||||
extra_info: { input_name: 'images' }
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
test.describe(
|
||||
'App mode validation warning',
|
||||
{ tag: ['@ui', '@workflow'] },
|
||||
() => {
|
||||
test.beforeEach(async ({ comfyPage }) => {
|
||||
await enableErrorsOverlay(comfyPage)
|
||||
await comfyPage.workflow.loadWorkflow('linear-validation-warning')
|
||||
await comfyPage.appMode.toggleAppMode()
|
||||
await expect(comfyPage.appMode.linearWidgets).toBeVisible()
|
||||
})
|
||||
|
||||
test('opens graph errors from the app mode validation warning', async ({
|
||||
comfyPage
|
||||
}) => {
|
||||
await expect(comfyPage.appMode.validationWarning).toBeHidden()
|
||||
|
||||
const exec = new ExecutionHelper(comfyPage)
|
||||
await exec.mockValidationFailure({
|
||||
[SAVE_IMAGE_NODE_ID]: buildSaveImageRequiredInputError()
|
||||
})
|
||||
|
||||
await comfyPage.appMode.runButton.click()
|
||||
const appModeOverlay = comfyPage.appMode.centerPanel.getByTestId(
|
||||
TestIds.dialogs.errorOverlay
|
||||
)
|
||||
await expect(appModeOverlay).toBeHidden()
|
||||
|
||||
await expect(comfyPage.appMode.validationWarning).toBeVisible()
|
||||
await expect(comfyPage.appMode.validationWarning).toContainText(
|
||||
/Required input missing/i
|
||||
)
|
||||
await expect(comfyPage.appMode.viewErrorsInGraphButton).toBeVisible()
|
||||
|
||||
await comfyPage.appMode.viewErrorsInGraphButton.click()
|
||||
|
||||
await expect(comfyPage.appMode.linearWidgets).toBeHidden()
|
||||
await expect(
|
||||
comfyPage.page.getByTestId(TestIds.propertiesPanel.root)
|
||||
).toBeVisible()
|
||||
await expect(
|
||||
comfyPage.page.getByTestId(TestIds.propertiesPanel.errorsTab)
|
||||
).toBeVisible()
|
||||
})
|
||||
|
||||
test('keeps the app mode run button enabled when the warning is visible', async ({
|
||||
comfyPage
|
||||
}) => {
|
||||
const exec = new ExecutionHelper(comfyPage)
|
||||
await exec.mockValidationFailure({
|
||||
[SAVE_IMAGE_NODE_ID]: buildSaveImageRequiredInputError()
|
||||
})
|
||||
|
||||
await comfyPage.appMode.runButton.click()
|
||||
await expect(comfyPage.appMode.validationWarning).toBeVisible()
|
||||
await expect(comfyPage.appMode.runButton).toBeEnabled()
|
||||
|
||||
let promptQueued = false
|
||||
const mockResponse: PromptResponse = {
|
||||
prompt_id: 'test-id',
|
||||
node_errors: {},
|
||||
error: ''
|
||||
}
|
||||
await comfyPage.page.route(
|
||||
'**/api/prompt',
|
||||
async (route) => {
|
||||
promptQueued = true
|
||||
await route.fulfill({
|
||||
status: 200,
|
||||
body: JSON.stringify(mockResponse)
|
||||
})
|
||||
},
|
||||
{ times: 1 }
|
||||
)
|
||||
|
||||
await comfyPage.appMode.runButton.click()
|
||||
|
||||
await expect.poll(() => promptQueued).toBe(true)
|
||||
})
|
||||
}
|
||||
)
|
||||
@@ -1,5 +1,6 @@
|
||||
import { expect } from '@playwright/test'
|
||||
|
||||
import { toLinkId } from '@/types/linkId'
|
||||
import { toNodeId } from '@/types/nodeId'
|
||||
|
||||
import { comfyPageFixture as test } from '@e2e/fixtures/ComfyPage'
|
||||
@@ -15,9 +16,10 @@ test.describe('Graph', { tag: ['@smoke', '@canvas'] }, () => {
|
||||
await comfyPage.workflow.loadWorkflow('inputs/input_order_swap')
|
||||
await expect
|
||||
.poll(() =>
|
||||
comfyPage.page.evaluate(() => {
|
||||
return window.app!.graph!.links.get(1)?.target_slot
|
||||
})
|
||||
comfyPage.page.evaluate(
|
||||
(linkId) => window.app!.graph!.links.get(linkId)?.target_slot,
|
||||
toLinkId(1)
|
||||
)
|
||||
)
|
||||
.toBe(1)
|
||||
})
|
||||
|
||||
@@ -3,6 +3,7 @@ import {
|
||||
comfyPageFixture as test,
|
||||
comfyExpect as expect
|
||||
} from '@e2e/fixtures/ComfyPage'
|
||||
import { TestIds } from '@e2e/fixtures/selectors'
|
||||
|
||||
test.describe('Linear Mode', { tag: '@ui' }, () => {
|
||||
test('Displays linear controls when app mode active', async ({
|
||||
@@ -16,7 +17,9 @@ test.describe('Linear Mode', { tag: '@ui' }, () => {
|
||||
test('Run button visible in linear mode', async ({ comfyPage }) => {
|
||||
await comfyPage.appMode.enterAppModeWithInputs([])
|
||||
|
||||
await expect(comfyPage.page.getByTestId('linear-run-button')).toBeVisible()
|
||||
await expect(
|
||||
comfyPage.page.getByTestId(TestIds.linear.runButton)
|
||||
).toBeVisible()
|
||||
})
|
||||
|
||||
test('Workflow info section visible', async ({ comfyPage }) => {
|
||||
|
||||
75
src/base/common/async.test.ts
Normal file
75
src/base/common/async.test.ts
Normal file
@@ -0,0 +1,75 @@
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
describe('runWhenGlobalIdle', () => {
|
||||
beforeEach(() => {
|
||||
vi.resetModules()
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
vi.useRealTimers()
|
||||
vi.unstubAllGlobals()
|
||||
})
|
||||
|
||||
it('falls back to a timeout when idle callbacks are unavailable', async () => {
|
||||
vi.useFakeTimers()
|
||||
vi.stubGlobal('requestIdleCallback', undefined)
|
||||
vi.stubGlobal('cancelIdleCallback', undefined)
|
||||
const { runWhenGlobalIdle } = await import('./async')
|
||||
const runner = vi.fn()
|
||||
|
||||
const disposable = runWhenGlobalIdle(runner)
|
||||
await vi.runAllTimersAsync()
|
||||
|
||||
expect(runner).toHaveBeenCalledOnce()
|
||||
const deadline = runner.mock.calls[0][0]
|
||||
expect(deadline.didTimeout).toBe(true)
|
||||
expect(deadline.timeRemaining()).toBeGreaterThanOrEqual(0)
|
||||
|
||||
disposable.dispose()
|
||||
disposable.dispose()
|
||||
})
|
||||
|
||||
it('cancels fallback idle work before it runs', async () => {
|
||||
vi.useFakeTimers()
|
||||
vi.stubGlobal('requestIdleCallback', undefined)
|
||||
vi.stubGlobal('cancelIdleCallback', undefined)
|
||||
const { runWhenGlobalIdle } = await import('./async')
|
||||
const runner = vi.fn()
|
||||
|
||||
runWhenGlobalIdle(runner).dispose()
|
||||
await vi.runAllTimersAsync()
|
||||
|
||||
expect(runner).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('uses native idle callbacks when available', async () => {
|
||||
const requestIdleCallback = vi.fn(() => 42)
|
||||
const cancelIdleCallback = vi.fn()
|
||||
vi.stubGlobal('requestIdleCallback', requestIdleCallback)
|
||||
vi.stubGlobal('cancelIdleCallback', cancelIdleCallback)
|
||||
const { runWhenGlobalIdle } = await import('./async')
|
||||
const runner = vi.fn()
|
||||
|
||||
const disposable = runWhenGlobalIdle(runner, 250)
|
||||
|
||||
expect(requestIdleCallback).toHaveBeenCalledWith(runner, { timeout: 250 })
|
||||
|
||||
disposable.dispose()
|
||||
disposable.dispose()
|
||||
|
||||
expect(cancelIdleCallback).toHaveBeenCalledOnce()
|
||||
expect(cancelIdleCallback).toHaveBeenCalledWith(42)
|
||||
})
|
||||
|
||||
it('omits native idle timeout options when no timeout is supplied', async () => {
|
||||
const requestIdleCallback = vi.fn(() => 7)
|
||||
vi.stubGlobal('requestIdleCallback', requestIdleCallback)
|
||||
vi.stubGlobal('cancelIdleCallback', vi.fn())
|
||||
const { runWhenGlobalIdle } = await import('./async')
|
||||
const runner = vi.fn()
|
||||
|
||||
runWhenGlobalIdle(runner)
|
||||
|
||||
expect(requestIdleCallback).toHaveBeenCalledWith(runner, undefined)
|
||||
})
|
||||
})
|
||||
@@ -122,6 +122,22 @@ describe('downloadUtil', () => {
|
||||
expect(createObjectURLSpy).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('throws for an empty URL', () => {
|
||||
expect(() => downloadFile('')).toThrow(
|
||||
'Invalid URL provided for download'
|
||||
)
|
||||
expect(fetchMock).not.toHaveBeenCalled()
|
||||
expect(createObjectURLSpy).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('throws for a whitespace URL', () => {
|
||||
expect(() => downloadFile(' ')).toThrow(
|
||||
'Invalid URL provided for download'
|
||||
)
|
||||
expect(fetchMock).not.toHaveBeenCalled()
|
||||
expect(createObjectURLSpy).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should prefer custom filename over extracted filename', () => {
|
||||
const testUrl =
|
||||
'https://example.com/api/file?filename=extracted-image.jpg'
|
||||
|
||||
@@ -4,6 +4,7 @@ import {
|
||||
CREDITS_PER_USD,
|
||||
COMFY_CREDIT_RATE_CENTS,
|
||||
centsToCredits,
|
||||
clampUsd,
|
||||
creditsToCents,
|
||||
creditsToUsd,
|
||||
formatCredits,
|
||||
@@ -43,4 +44,21 @@ describe('comfyCredits helpers', () => {
|
||||
expect(formatCreditsFromUsd({ usd: 1, locale })).toBe('211.00')
|
||||
expect(formatUsd({ value: 4.2, locale })).toBe('4.20')
|
||||
})
|
||||
|
||||
test('formats with compatible fraction digit bounds', () => {
|
||||
expect(
|
||||
formatCredits({
|
||||
value: 12.345,
|
||||
locale: 'en-US',
|
||||
numberOptions: { minimumFractionDigits: 4, maximumFractionDigits: 2 }
|
||||
})
|
||||
).toBe('12.35')
|
||||
})
|
||||
|
||||
test('clamps USD purchase values into the supported range', () => {
|
||||
expect(clampUsd(Number.NaN)).toBe(0)
|
||||
expect(clampUsd(-5)).toBe(1)
|
||||
expect(clampUsd(42)).toBe(42)
|
||||
expect(clampUsd(5000)).toBe(1000)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -37,7 +37,7 @@
|
||||
size="unset"
|
||||
class="min-h-8 rounded-lg px-3 py-2 text-xs font-normal"
|
||||
data-testid="error-overlay-see-errors"
|
||||
@click="seeErrors"
|
||||
@click="viewErrorsInGraph"
|
||||
>
|
||||
{{
|
||||
appMode
|
||||
@@ -67,31 +67,18 @@ import { useI18n } from 'vue-i18n'
|
||||
|
||||
import Button from '@/components/ui/button/Button.vue'
|
||||
import { useExecutionErrorStore } from '@/stores/executionErrorStore'
|
||||
import { useRightSidePanelStore } from '@/stores/workspace/rightSidePanelStore'
|
||||
import { useCanvasStore } from '@/renderer/core/canvas/canvasStore'
|
||||
import { useErrorOverlayState } from '@/components/error/useErrorOverlayState'
|
||||
import { useViewErrorsInGraph } from '@/composables/useViewErrorsInGraph'
|
||||
|
||||
const { appMode = false } = defineProps<{ appMode?: boolean }>()
|
||||
|
||||
const { t } = useI18n()
|
||||
const executionErrorStore = useExecutionErrorStore()
|
||||
const rightSidePanelStore = useRightSidePanelStore()
|
||||
const canvasStore = useCanvasStore()
|
||||
const { viewErrorsInGraph } = useViewErrorsInGraph()
|
||||
|
||||
const { isVisible, overlayMessage, overlayTitle } = useErrorOverlayState()
|
||||
|
||||
function dismiss() {
|
||||
executionErrorStore.dismissErrorOverlay()
|
||||
}
|
||||
|
||||
function seeErrors() {
|
||||
canvasStore.linearMode = false
|
||||
if (canvasStore.canvas) {
|
||||
canvasStore.canvas.deselectAll()
|
||||
canvasStore.updateSelectedItems()
|
||||
}
|
||||
|
||||
rightSidePanelStore.openPanel('errors')
|
||||
executionErrorStore.dismissErrorOverlay()
|
||||
}
|
||||
</script>
|
||||
|
||||
@@ -224,7 +224,7 @@ const handleOpenUserSettings = () => {
|
||||
}
|
||||
|
||||
const handleOpenPlansAndPricing = () => {
|
||||
subscriptionDialog.showPricingTable()
|
||||
subscriptionDialog.showPricingTable({ reason: 'avatar_menu_plans' })
|
||||
emit('close')
|
||||
}
|
||||
|
||||
@@ -239,8 +239,7 @@ const handleOpenPlanAndCreditsSettings = () => {
|
||||
}
|
||||
|
||||
const handleTopUp = () => {
|
||||
// Track purchase credits entry from avatar popover
|
||||
useTelemetry()?.trackAddApiCreditButtonClicked()
|
||||
useTelemetry()?.trackAddApiCreditButtonClicked({ source: 'avatar_menu' })
|
||||
dialogService.showTopUpCreditsDialog()
|
||||
emit('close')
|
||||
}
|
||||
@@ -254,7 +253,7 @@ const handleOpenPartnerNodesInfo = () => {
|
||||
}
|
||||
|
||||
const handleUpgradeToAddCredits = () => {
|
||||
subscriptionDialog.showPricingTable()
|
||||
subscriptionDialog.showPricingTable({ reason: 'upgrade_to_add_credits' })
|
||||
emit('close')
|
||||
}
|
||||
|
||||
|
||||
@@ -21,6 +21,6 @@ const { isFreeTier } = useBillingContext()
|
||||
const subscriptionDialog = useSubscriptionDialog()
|
||||
|
||||
function handleClick() {
|
||||
subscriptionDialog.showPricingTable()
|
||||
subscriptionDialog.showPricingTable({ reason: 'subscribe_now_button' })
|
||||
}
|
||||
</script>
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import type { ComputedRef, Ref } from 'vue'
|
||||
|
||||
import type { SubscriptionDialogOptions } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type { TierKey } from '@/platform/cloud/subscription/constants/tierPricing'
|
||||
import type {
|
||||
BillingStatus,
|
||||
@@ -75,9 +76,10 @@ export interface BillingActions {
|
||||
*/
|
||||
requireActiveSubscription: () => Promise<void>
|
||||
/**
|
||||
* Shows the subscription dialog.
|
||||
* Shows the subscription dialog. Pass a reason so the paywall open and any
|
||||
* downstream checkout stay attributed to the triggering product moment.
|
||||
*/
|
||||
showSubscriptionDialog: () => void
|
||||
showSubscriptionDialog: (options?: SubscriptionDialogOptions) => void
|
||||
}
|
||||
|
||||
export interface BillingState {
|
||||
|
||||
@@ -7,6 +7,7 @@ import {
|
||||
getTierFeatures
|
||||
} from '@/platform/cloud/subscription/constants/tierPricing'
|
||||
import type { TierKey } from '@/platform/cloud/subscription/constants/tierPricing'
|
||||
import type { SubscriptionDialogOptions } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type {
|
||||
PreviewSubscribeOptions,
|
||||
SubscribeOptions
|
||||
@@ -281,8 +282,8 @@ function useBillingContextInternal(): BillingContext {
|
||||
return activeContext.value.requireActiveSubscription()
|
||||
}
|
||||
|
||||
function showSubscriptionDialog() {
|
||||
return activeContext.value.showSubscriptionDialog()
|
||||
function showSubscriptionDialog(options?: SubscriptionDialogOptions) {
|
||||
return activeContext.value.showSubscriptionDialog(options)
|
||||
}
|
||||
|
||||
return {
|
||||
|
||||
@@ -2,6 +2,7 @@ import { computed, ref } from 'vue'
|
||||
|
||||
import { useAuthActions } from '@/composables/auth/useAuthActions'
|
||||
import { useSubscription } from '@/platform/cloud/subscription/composables/useSubscription'
|
||||
import type { SubscriptionDialogOptions } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type {
|
||||
BillingStatus,
|
||||
BillingSubscriptionStatus,
|
||||
@@ -189,12 +190,12 @@ export function useLegacyBilling(): BillingState & BillingActions {
|
||||
async function requireActiveSubscription(): Promise<void> {
|
||||
await fetchStatus()
|
||||
if (!isActiveSubscription.value) {
|
||||
legacyShowSubscriptionDialog()
|
||||
legacyShowSubscriptionDialog({ reason: 'subscription_required' })
|
||||
}
|
||||
}
|
||||
|
||||
function showSubscriptionDialog(): void {
|
||||
legacyShowSubscriptionDialog()
|
||||
function showSubscriptionDialog(options?: SubscriptionDialogOptions): void {
|
||||
legacyShowSubscriptionDialog(options)
|
||||
}
|
||||
|
||||
return {
|
||||
|
||||
@@ -503,7 +503,7 @@ export function useCoreCommands(): ComfyCommand[] {
|
||||
}) => {
|
||||
trackRunButton(metadata)
|
||||
if (!isActiveSubscription.value) {
|
||||
showSubscriptionDialog()
|
||||
showSubscriptionDialog({ reason: 'subscribe_to_run' })
|
||||
return
|
||||
}
|
||||
|
||||
@@ -526,7 +526,7 @@ export function useCoreCommands(): ComfyCommand[] {
|
||||
}) => {
|
||||
trackRunButton(metadata)
|
||||
if (!isActiveSubscription.value) {
|
||||
showSubscriptionDialog()
|
||||
showSubscriptionDialog({ reason: 'subscribe_to_run' })
|
||||
return
|
||||
}
|
||||
|
||||
@@ -548,7 +548,7 @@ export function useCoreCommands(): ComfyCommand[] {
|
||||
}) => {
|
||||
trackRunButton(metadata)
|
||||
if (!isActiveSubscription.value) {
|
||||
showSubscriptionDialog()
|
||||
showSubscriptionDialog({ reason: 'subscribe_to_run' })
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
105
src/composables/useViewErrorsInGraph.test.ts
Normal file
105
src/composables/useViewErrorsInGraph.test.ts
Normal file
@@ -0,0 +1,105 @@
|
||||
import { createPinia, setActivePinia } from 'pinia'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { useCanvasStore } from '@/renderer/core/canvas/canvasStore'
|
||||
import { useExecutionErrorStore } from '@/stores/executionErrorStore'
|
||||
import { useRightSidePanelStore } from '@/stores/workspace/rightSidePanelStore'
|
||||
import { LGraph, LGraphCanvas, LGraphNode } from '@/lib/litegraph/src/litegraph'
|
||||
import { createMockCanvasRenderingContext2D } from '@/utils/__tests__/litegraphTestUtils'
|
||||
import { useWorkflowStore } from '@/platform/workflow/management/stores/workflowStore'
|
||||
|
||||
import { useViewErrorsInGraph } from './useViewErrorsInGraph'
|
||||
|
||||
const apiMock = vi.hoisted(() => ({
|
||||
getSettings: vi.fn(),
|
||||
storeSetting: vi.fn(),
|
||||
storeSettings: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('@/scripts/api', () => ({
|
||||
api: apiMock
|
||||
}))
|
||||
|
||||
const appMock = vi.hoisted(() => ({
|
||||
ui: {
|
||||
settings: {
|
||||
dispatchChange: vi.fn()
|
||||
}
|
||||
},
|
||||
rootGraph: {
|
||||
events: new EventTarget(),
|
||||
nodes: []
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('@/scripts/app', () => ({
|
||||
app: appMock
|
||||
}))
|
||||
|
||||
function createSelectedCanvas() {
|
||||
const graph = new LGraph()
|
||||
const canvasElement = document.createElement('canvas')
|
||||
canvasElement.width = 800
|
||||
canvasElement.height = 600
|
||||
canvasElement.getContext = vi
|
||||
.fn()
|
||||
.mockReturnValue(createMockCanvasRenderingContext2D())
|
||||
|
||||
const canvas = new LGraphCanvas(canvasElement, graph, {
|
||||
skip_events: true,
|
||||
skip_render: true
|
||||
})
|
||||
const node = new LGraphNode('Selected Node')
|
||||
graph.add(node)
|
||||
canvas.selectedItems.add(node)
|
||||
node.selected = true
|
||||
|
||||
return { canvas, node }
|
||||
}
|
||||
|
||||
describe('useViewErrorsInGraph', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
setActivePinia(createPinia())
|
||||
apiMock.getSettings.mockResolvedValue({})
|
||||
apiMock.storeSetting.mockResolvedValue(undefined)
|
||||
apiMock.storeSettings.mockResolvedValue(undefined)
|
||||
})
|
||||
|
||||
it('opens graph errors and clears app-mode error UI state', () => {
|
||||
const canvasStore = useCanvasStore()
|
||||
const executionErrorStore = useExecutionErrorStore()
|
||||
const rightSidePanelStore = useRightSidePanelStore()
|
||||
const workflowStore = useWorkflowStore()
|
||||
const { canvas, node } = createSelectedCanvas()
|
||||
workflowStore.activeWorkflow = {
|
||||
activeMode: 'app'
|
||||
} as typeof workflowStore.activeWorkflow
|
||||
canvasStore.canvas = canvas
|
||||
canvasStore.selectedItems = [node]
|
||||
executionErrorStore.showErrorOverlay()
|
||||
|
||||
useViewErrorsInGraph().viewErrorsInGraph()
|
||||
|
||||
expect(node.selected).toBe(false)
|
||||
expect(canvasStore.linearMode).toBe(false)
|
||||
expect(canvasStore.selectedItems).toEqual([])
|
||||
expect(rightSidePanelStore.activeTab).toBe('errors')
|
||||
expect(rightSidePanelStore.isOpen).toBe(true)
|
||||
expect(executionErrorStore.isErrorOverlayOpen).toBe(false)
|
||||
})
|
||||
|
||||
it('opens graph errors when the canvas is not initialized', () => {
|
||||
const canvasStore = useCanvasStore()
|
||||
const executionErrorStore = useExecutionErrorStore()
|
||||
const rightSidePanelStore = useRightSidePanelStore()
|
||||
canvasStore.canvas = null
|
||||
executionErrorStore.showErrorOverlay()
|
||||
|
||||
expect(() => useViewErrorsInGraph().viewErrorsInGraph()).not.toThrow()
|
||||
|
||||
expect(rightSidePanelStore.activeTab).toBe('errors')
|
||||
expect(rightSidePanelStore.isOpen).toBe(true)
|
||||
expect(executionErrorStore.isErrorOverlayOpen).toBe(false)
|
||||
})
|
||||
})
|
||||
22
src/composables/useViewErrorsInGraph.ts
Normal file
22
src/composables/useViewErrorsInGraph.ts
Normal file
@@ -0,0 +1,22 @@
|
||||
import { useCanvasStore } from '@/renderer/core/canvas/canvasStore'
|
||||
import { useExecutionErrorStore } from '@/stores/executionErrorStore'
|
||||
import { useRightSidePanelStore } from '@/stores/workspace/rightSidePanelStore'
|
||||
|
||||
export function useViewErrorsInGraph() {
|
||||
const canvasStore = useCanvasStore()
|
||||
const executionErrorStore = useExecutionErrorStore()
|
||||
const rightSidePanelStore = useRightSidePanelStore()
|
||||
|
||||
function viewErrorsInGraph() {
|
||||
canvasStore.linearMode = false
|
||||
if (canvasStore.canvas) {
|
||||
canvasStore.canvas.deselectAll()
|
||||
canvasStore.updateSelectedItems()
|
||||
}
|
||||
|
||||
rightSidePanelStore.openPanel('errors')
|
||||
executionErrorStore.dismissErrorOverlay()
|
||||
}
|
||||
|
||||
return { viewErrorsInGraph }
|
||||
}
|
||||
@@ -25,6 +25,6 @@ function handleClose() {
|
||||
}
|
||||
|
||||
function handleSubscribe() {
|
||||
showSubscriptionDialog()
|
||||
showSubscriptionDialog({ reason: 'upload_model_upgrade' })
|
||||
}
|
||||
</script>
|
||||
|
||||
@@ -140,7 +140,10 @@ describe('CloudSubscriptionRedirectView', () => {
|
||||
expect(mockPerformSubscriptionCheckout).toHaveBeenCalledWith(
|
||||
'creator',
|
||||
'monthly',
|
||||
false
|
||||
{
|
||||
openInNewTab: false,
|
||||
paymentIntentSource: 'deep_link'
|
||||
}
|
||||
)
|
||||
|
||||
// Shows loading affordances
|
||||
@@ -169,7 +172,10 @@ describe('CloudSubscriptionRedirectView', () => {
|
||||
expect(mockPerformSubscriptionCheckout).toHaveBeenCalledWith(
|
||||
'creator',
|
||||
'monthly',
|
||||
false
|
||||
{
|
||||
openInNewTab: false,
|
||||
paymentIntentSource: 'deep_link'
|
||||
}
|
||||
)
|
||||
})
|
||||
|
||||
@@ -180,7 +186,8 @@ describe('CloudSubscriptionRedirectView', () => {
|
||||
expect(screen.getByText('Subscribe to Team Plan')).toBeInTheDocument()
|
||||
expect(mockPerformTeamSubscriptionCheckout).toHaveBeenCalledWith(
|
||||
'team_700',
|
||||
'yearly'
|
||||
'yearly',
|
||||
{ paymentIntentSource: 'deep_link' }
|
||||
)
|
||||
// Team never goes through the personal checkout path
|
||||
expect(mockPerformSubscriptionCheckout).not.toHaveBeenCalled()
|
||||
|
||||
@@ -94,7 +94,9 @@ const runRedirect = wrapWithErrorHandlingAsync(async () => {
|
||||
return
|
||||
}
|
||||
isTeamCheckout.value = true
|
||||
await performTeamSubscriptionCheckout(stopId, billingCycle)
|
||||
await performTeamSubscriptionCheckout(stopId, billingCycle, {
|
||||
paymentIntentSource: 'deep_link'
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -112,7 +114,10 @@ const runRedirect = wrapWithErrorHandlingAsync(async () => {
|
||||
if (isActiveSubscription.value) {
|
||||
await accessBillingPortal(undefined, false)
|
||||
} else {
|
||||
await performSubscriptionCheckout(tierKeyParam, billingCycle, false)
|
||||
await performSubscriptionCheckout(tierKeyParam, billingCycle, {
|
||||
openInNewTab: false,
|
||||
paymentIntentSource: 'deep_link'
|
||||
})
|
||||
}
|
||||
}, reportError)
|
||||
|
||||
|
||||
@@ -351,12 +351,12 @@ const handleRefresh = wrapWithErrorHandlingAsync(async () => {
|
||||
})
|
||||
|
||||
function handleAddCredits() {
|
||||
telemetry?.trackAddApiCreditButtonClicked()
|
||||
telemetry?.trackAddApiCreditButtonClicked({ source: 'credits_panel' })
|
||||
void dialogService.showTopUpCreditsDialog()
|
||||
}
|
||||
|
||||
function handleUpgradeToAddCredits() {
|
||||
showPricingTable()
|
||||
showPricingTable({ reason: 'upgrade_to_add_credits' })
|
||||
}
|
||||
|
||||
async function handleWindowFocus() {
|
||||
|
||||
@@ -5,6 +5,8 @@ import { render, screen } from '@testing-library/vue'
|
||||
|
||||
import enMessages from '@/locales/en/main.json' with { type: 'json' }
|
||||
|
||||
import type { PaymentIntentSource } from '@/platform/telemetry/types'
|
||||
|
||||
import FreeTierDialogContent from './FreeTierDialogContent.vue'
|
||||
|
||||
const mockRenewalDate = vi.hoisted(() => ({ value: null as string | null }))
|
||||
@@ -15,7 +17,7 @@ vi.mock('@/composables/billing/useBillingContext', () => ({
|
||||
}))
|
||||
}))
|
||||
|
||||
function renderComponent() {
|
||||
function renderComponent(props?: { reason?: PaymentIntentSource }) {
|
||||
const i18n = createI18n({
|
||||
legacy: false,
|
||||
locale: 'en',
|
||||
@@ -23,6 +25,7 @@ function renderComponent() {
|
||||
})
|
||||
|
||||
return render(FreeTierDialogContent, {
|
||||
props,
|
||||
global: {
|
||||
plugins: [i18n]
|
||||
}
|
||||
@@ -43,4 +46,18 @@ describe('FreeTierDialogContent', () => {
|
||||
renderComponent()
|
||||
expect(screen.queryByText(/credits refresh on/)).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('keeps the generic copy for intent reasons outside the credits variants', () => {
|
||||
mockRenewalDate.value = '2026-07-15T10:00:00Z'
|
||||
renderComponent({ reason: 'subscribe_to_run' })
|
||||
expect(
|
||||
screen.getByText('Your credits refresh on Jul 15, 2026.')
|
||||
).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('swaps to the out-of-credits copy without the refresh line', () => {
|
||||
mockRenewalDate.value = '2026-07-15T10:00:00Z'
|
||||
renderComponent({ reason: 'out_of_credits' })
|
||||
expect(screen.queryByText(/credits refresh on/)).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
@@ -52,7 +52,7 @@
|
||||
</p>
|
||||
|
||||
<p
|
||||
v-if="!reason || reason === 'subscription_required'"
|
||||
v-if="!isCreditsBlockedVariant"
|
||||
class="m-0 text-sm text-text-secondary"
|
||||
>
|
||||
{{
|
||||
@@ -65,10 +65,7 @@
|
||||
</p>
|
||||
|
||||
<p
|
||||
v-if="
|
||||
(!reason || reason === 'subscription_required') &&
|
||||
formattedRenewalDate
|
||||
"
|
||||
v-if="!isCreditsBlockedVariant && formattedRenewalDate"
|
||||
class="m-0 text-sm text-text-secondary"
|
||||
>
|
||||
{{
|
||||
@@ -88,7 +85,7 @@
|
||||
@click="$emit('upgrade')"
|
||||
>
|
||||
{{
|
||||
reason === 'out_of_credits' || reason === 'top_up_blocked'
|
||||
isCreditsBlockedVariant
|
||||
? $t('subscription.freeTier.upgradeCta')
|
||||
: $t('subscription.freeTier.subscribeCta')
|
||||
}}
|
||||
@@ -103,12 +100,12 @@ import { computed } from 'vue'
|
||||
|
||||
import Button from '@/components/ui/button/Button.vue'
|
||||
import { useBillingContext } from '@/composables/billing/useBillingContext'
|
||||
import type { SubscriptionDialogReason } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type { PaymentIntentSource } from '@/platform/telemetry/types'
|
||||
import SubscriptionBenefits from '@/platform/cloud/subscription/components/SubscriptionBenefits.vue'
|
||||
import { getTierCredits } from '@/platform/cloud/subscription/constants/tierPricing'
|
||||
|
||||
defineProps<{
|
||||
reason?: SubscriptionDialogReason
|
||||
const { reason } = defineProps<{
|
||||
reason?: PaymentIntentSource
|
||||
}>()
|
||||
|
||||
defineEmits<{
|
||||
@@ -129,4 +126,10 @@ const formattedRenewalDate = computed(() => {
|
||||
})
|
||||
|
||||
const freeTierCredits = computed(() => getTierCredits('free'))
|
||||
|
||||
// Only these two variants replace the generic free-tier copy; any other
|
||||
// intent reason (subscribe_to_run, deep_link, ...) keeps the default pitch.
|
||||
const isCreditsBlockedVariant = computed(
|
||||
() => reason === 'out_of_credits' || reason === 'top_up_blocked'
|
||||
)
|
||||
</script>
|
||||
|
||||
@@ -261,6 +261,7 @@ describe('PricingTable', () => {
|
||||
tier: 'creator',
|
||||
cycle: 'yearly',
|
||||
checkout_type: 'change',
|
||||
checkout_attempt_id: expect.any(String),
|
||||
previous_tier: 'standard'
|
||||
})
|
||||
expect(mockAccessBillingPortal).toHaveBeenCalledWith('creator-yearly')
|
||||
@@ -341,6 +342,7 @@ describe('PricingTable', () => {
|
||||
expect(
|
||||
window.localStorage.getItem(PENDING_SUBSCRIPTION_CHECKOUT_STORAGE_KEY)
|
||||
).toBeNull()
|
||||
expect(mockTrackBeginCheckout).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should use the latest userId value when it changes after mount', async () => {
|
||||
@@ -366,6 +368,7 @@ describe('PricingTable', () => {
|
||||
tier: 'creator',
|
||||
cycle: 'yearly',
|
||||
checkout_type: 'change',
|
||||
checkout_attempt_id: expect.any(String),
|
||||
previous_tier: 'standard'
|
||||
})
|
||||
})
|
||||
|
||||
@@ -277,13 +277,19 @@ import type {
|
||||
TierKey,
|
||||
TierPricing
|
||||
} from '@/platform/cloud/subscription/constants/tierPricing'
|
||||
import { recordPendingSubscriptionCheckoutAttempt } from '@/platform/cloud/subscription/utils/subscriptionCheckoutTracker'
|
||||
import {
|
||||
recordPendingSubscriptionCheckoutAttempt,
|
||||
withPendingCheckoutAttemptId
|
||||
} from '@/platform/cloud/subscription/utils/subscriptionCheckoutTracker'
|
||||
import { performSubscriptionCheckout } from '@/platform/cloud/subscription/utils/subscriptionCheckoutUtil'
|
||||
import { isPlanDowngrade } from '@/platform/cloud/subscription/utils/subscriptionTierRank'
|
||||
import type { BillingCycle } from '@/platform/cloud/subscription/utils/subscriptionTierRank'
|
||||
import { isCloud } from '@/platform/distribution/types'
|
||||
import { useTelemetry } from '@/platform/telemetry'
|
||||
import type { CheckoutAttributionMetadata } from '@/platform/telemetry/types'
|
||||
import type {
|
||||
CheckoutAttributionMetadata,
|
||||
PaymentIntentSource
|
||||
} from '@/platform/telemetry/types'
|
||||
import { useAuthStore } from '@/stores/authStore'
|
||||
|
||||
type CheckoutTierKey = Exclude<TierKey, 'free' | 'founder'>
|
||||
@@ -321,6 +327,10 @@ interface PricingTierConfig {
|
||||
isPopular?: boolean
|
||||
}
|
||||
|
||||
const { reason } = defineProps<{
|
||||
reason?: PaymentIntentSource
|
||||
}>()
|
||||
|
||||
const emit = defineEmits<{
|
||||
chooseTeamWorkspace: []
|
||||
}>()
|
||||
@@ -463,16 +473,17 @@ const handleSubscribe = wrapWithErrorHandlingAsync(
|
||||
} as const
|
||||
const previousPlan = currentPlanDescriptor.value
|
||||
const checkoutAttribution = await getCheckoutAttributionForCloud()
|
||||
if (userId.value) {
|
||||
telemetry?.trackBeginCheckout({
|
||||
user_id: userId.value,
|
||||
tier: targetPlan.tierKey,
|
||||
cycle: targetPlan.billingCycle,
|
||||
checkout_type: 'change',
|
||||
...checkoutAttribution,
|
||||
...(previousPlan ? { previous_tier: previousPlan.tierKey } : {})
|
||||
})
|
||||
}
|
||||
const beginCheckoutMetadata = userId.value
|
||||
? {
|
||||
user_id: userId.value,
|
||||
tier: targetPlan.tierKey,
|
||||
cycle: targetPlan.billingCycle,
|
||||
checkout_type: 'change' as const,
|
||||
...(reason ? { payment_intent_source: reason } : {}),
|
||||
...checkoutAttribution,
|
||||
...(previousPlan ? { previous_tier: previousPlan.tierKey } : {})
|
||||
}
|
||||
: null
|
||||
// Pass the target tier to create a deep link to subscription update confirmation
|
||||
const checkoutTier = getCheckoutTier(
|
||||
targetPlan.tierKey,
|
||||
@@ -487,29 +498,39 @@ const handleSubscribe = wrapWithErrorHandlingAsync(
|
||||
|
||||
if (downgrade) {
|
||||
// TODO(COMFY-StripeProration): Remove once backend checkout creation mirrors portal proration ("change at billing end")
|
||||
await accessBillingPortal()
|
||||
const didOpenPortal = await accessBillingPortal()
|
||||
if (didOpenPortal && beginCheckoutMetadata) {
|
||||
telemetry?.trackBeginCheckout(beginCheckoutMetadata)
|
||||
}
|
||||
} else {
|
||||
const didOpenPortal = await accessBillingPortal(checkoutTier)
|
||||
if (!didOpenPortal) {
|
||||
return
|
||||
}
|
||||
|
||||
recordPendingSubscriptionCheckoutAttempt({
|
||||
const pendingAttempt = recordPendingSubscriptionCheckoutAttempt({
|
||||
tier: targetPlan.tierKey,
|
||||
cycle: targetPlan.billingCycle,
|
||||
checkout_type: 'change',
|
||||
payment_intent_source: reason,
|
||||
...(previousPlan ? { previous_tier: previousPlan.tierKey } : {}),
|
||||
...(previousPlan
|
||||
? { previous_cycle: previousPlan.billingCycle }
|
||||
: {})
|
||||
})
|
||||
if (beginCheckoutMetadata) {
|
||||
telemetry?.trackBeginCheckout(
|
||||
withPendingCheckoutAttemptId(
|
||||
beginCheckoutMetadata,
|
||||
pendingAttempt
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
await performSubscriptionCheckout(
|
||||
tierKey,
|
||||
currentBillingCycle.value,
|
||||
true
|
||||
)
|
||||
await performSubscriptionCheckout(tierKey, currentBillingCycle.value, {
|
||||
paymentIntentSource: reason
|
||||
})
|
||||
}
|
||||
} finally {
|
||||
isLoading.value = false
|
||||
|
||||
@@ -56,7 +56,7 @@ const handleSubscribe = () => {
|
||||
current_tier: tier.value?.toLowerCase()
|
||||
})
|
||||
isAwaitingStripeSubscription.value = true
|
||||
showSubscriptionDialog()
|
||||
showSubscriptionDialog({ reason: 'subscribe_now_button' })
|
||||
}
|
||||
|
||||
onBeforeUnmount(() => {
|
||||
|
||||
@@ -54,6 +54,6 @@ function handleSubscribeToRun() {
|
||||
trackRunButton({ subscribe_to_run: true })
|
||||
}
|
||||
|
||||
showSubscriptionDialog()
|
||||
showSubscriptionDialog({ reason: 'subscribe_to_run' })
|
||||
}
|
||||
</script>
|
||||
|
||||
@@ -48,7 +48,9 @@
|
||||
v-if="isActiveSubscription"
|
||||
variant="primary"
|
||||
class="rounded-lg px-4 py-2 text-sm font-normal text-text-primary"
|
||||
@click="showSubscriptionDialog"
|
||||
@click="
|
||||
showSubscriptionDialog({ reason: 'settings_billing_panel' })
|
||||
"
|
||||
>
|
||||
{{ $t('subscription.upgradePlan') }}
|
||||
</Button>
|
||||
|
||||
@@ -33,7 +33,11 @@
|
||||
</i18n-t>
|
||||
</div>
|
||||
|
||||
<PricingTable class="flex-1" @choose-team-workspace="handleChooseTeam" />
|
||||
<PricingTable
|
||||
:reason
|
||||
class="flex-1"
|
||||
@choose-team-workspace="handleChooseTeam"
|
||||
/>
|
||||
|
||||
<!-- Contact and Enterprise Links -->
|
||||
<div class="flex flex-col items-center gap-2">
|
||||
@@ -157,11 +161,11 @@ import { useBillingContext } from '@/composables/billing/useBillingContext'
|
||||
import { isCloud } from '@/platform/distribution/types'
|
||||
import { useTelemetry } from '@/platform/telemetry'
|
||||
import { useCommandStore } from '@/stores/commandStore'
|
||||
import type { SubscriptionDialogReason } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type { PaymentIntentSource } from '@/platform/telemetry/types'
|
||||
|
||||
const { onClose, reason, onChooseTeam } = defineProps<{
|
||||
onClose: () => void
|
||||
reason?: SubscriptionDialogReason
|
||||
reason?: PaymentIntentSource
|
||||
onChooseTeam?: () => void
|
||||
}>()
|
||||
|
||||
|
||||
@@ -24,7 +24,9 @@ export function useAccountPreconditionDialog() {
|
||||
)
|
||||
return
|
||||
case 'subscription':
|
||||
void dialogService.showSubscriptionRequiredDialog()
|
||||
void dialogService.showSubscriptionRequiredDialog({
|
||||
reason: 'subscription_required'
|
||||
})
|
||||
return
|
||||
case 'credits':
|
||||
void dialogService.showTopUpCreditsDialog({
|
||||
|
||||
@@ -55,12 +55,6 @@ vi.mock('@/platform/workspace/stores/teamWorkspaceStore', () => ({
|
||||
})
|
||||
}))
|
||||
|
||||
const mockTrackSubscription = vi.hoisted(() => vi.fn())
|
||||
|
||||
vi.mock('@/platform/telemetry', () => ({
|
||||
useTelemetry: () => ({ trackSubscription: mockTrackSubscription })
|
||||
}))
|
||||
|
||||
describe('usePricingTableUrlLoader', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
@@ -96,9 +90,6 @@ describe('usePricingTableUrlLoader', () => {
|
||||
reason: 'deep_link',
|
||||
planMode: undefined
|
||||
})
|
||||
expect(mockTrackSubscription).toHaveBeenCalledWith('modal_opened', {
|
||||
reason: 'deep_link'
|
||||
})
|
||||
expect(mockRouterReplace).toHaveBeenCalledWith({ query: {} })
|
||||
})
|
||||
|
||||
@@ -150,7 +141,6 @@ describe('usePricingTableUrlLoader', () => {
|
||||
await loadPricingTableFromUrl()
|
||||
|
||||
expect(mockShowPricingTable).not.toHaveBeenCalled()
|
||||
expect(mockTrackSubscription).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('denies, strips, and clears together when the user is not eligible', async () => {
|
||||
@@ -161,7 +151,6 @@ describe('usePricingTableUrlLoader', () => {
|
||||
await loadPricingTableFromUrl()
|
||||
|
||||
expect(mockShowPricingTable).not.toHaveBeenCalled()
|
||||
expect(mockTrackSubscription).not.toHaveBeenCalled()
|
||||
expect(mockRouterReplace).toHaveBeenCalledWith({
|
||||
query: { other: 'param' }
|
||||
})
|
||||
@@ -230,7 +219,6 @@ describe('usePricingTableUrlLoader', () => {
|
||||
)
|
||||
|
||||
expect(mockShowPricingTable).not.toHaveBeenCalled()
|
||||
expect(mockTrackSubscription).not.toHaveBeenCalled()
|
||||
expect(mockRouterReplace).toHaveBeenCalledWith({ query: {} })
|
||||
expect(preservedQueryMocks.clearPreservedQuery).toHaveBeenCalledWith(
|
||||
'pricing'
|
||||
|
||||
@@ -7,7 +7,6 @@ import {
|
||||
mergePreservedQueryIntoQuery
|
||||
} from '@/platform/navigation/preservedQueryManager'
|
||||
import { PRESERVED_QUERY_NAMESPACES } from '@/platform/navigation/preservedQueryNamespaces'
|
||||
import { useTelemetry } from '@/platform/telemetry'
|
||||
import { useWorkspaceUI } from '@/platform/workspace/composables/useWorkspaceUI'
|
||||
import { useTeamWorkspaceStore } from '@/platform/workspace/stores/teamWorkspaceStore'
|
||||
|
||||
@@ -62,7 +61,6 @@ export function usePricingTableUrlLoader() {
|
||||
const planMode =
|
||||
param === 'team' || param === 'personal' ? param : undefined
|
||||
|
||||
useTelemetry()?.trackSubscription('modal_opened', { reason: 'deep_link' })
|
||||
subscriptionDialog.showPricingTable({ reason: 'deep_link', planMode })
|
||||
}
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ import { t } from '@/i18n'
|
||||
import { fetchWithUnifiedRemint } from '@/platform/auth/unified/remintRetry'
|
||||
import { isCloud } from '@/platform/distribution/types'
|
||||
import { useTelemetry } from '@/platform/telemetry'
|
||||
import type { SubscriptionDialogReason } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type { SubscriptionDialogOptions } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type { CheckoutAttributionMetadata } from '@/platform/telemetry/types'
|
||||
import { AuthStoreError, useAuthStore } from '@/stores/authStore'
|
||||
import { useDialogService } from '@/services/dialogService'
|
||||
@@ -237,14 +237,7 @@ function useSubscriptionInternal() {
|
||||
})
|
||||
}, reportError)
|
||||
|
||||
const showSubscriptionDialog = (options?: {
|
||||
reason?: SubscriptionDialogReason
|
||||
}) => {
|
||||
useTelemetry()?.trackSubscription('modal_opened', {
|
||||
current_tier: subscriptionTier.value?.toLowerCase(),
|
||||
reason: options?.reason
|
||||
})
|
||||
|
||||
const showSubscriptionDialog = (options?: SubscriptionDialogOptions) => {
|
||||
void showSubscriptionRequiredDialog(options)
|
||||
}
|
||||
|
||||
@@ -277,7 +270,7 @@ function useSubscriptionInternal() {
|
||||
await fetchSubscriptionStatus()
|
||||
|
||||
if (!isSubscribedOrIsNotCloud.value) {
|
||||
showSubscriptionDialog()
|
||||
showSubscriptionDialog({ reason: 'subscription_required' })
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -39,15 +39,23 @@ vi.mock('@/stores/commandStore', () => ({
|
||||
}))
|
||||
|
||||
// useTelemetry() returns null in OSS, a dispatcher in cloud — toggle via mockIsCloud.
|
||||
const { mockIsCloud, mockTrackHelpResourceClicked } = vi.hoisted(() => ({
|
||||
const {
|
||||
mockIsCloud,
|
||||
mockTrackHelpResourceClicked,
|
||||
mockTrackAddApiCreditButtonClicked
|
||||
} = vi.hoisted(() => ({
|
||||
mockIsCloud: { value: true },
|
||||
mockTrackHelpResourceClicked: vi.fn()
|
||||
mockTrackHelpResourceClicked: vi.fn(),
|
||||
mockTrackAddApiCreditButtonClicked: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/telemetry', () => ({
|
||||
useTelemetry: () =>
|
||||
mockIsCloud.value
|
||||
? { trackHelpResourceClicked: mockTrackHelpResourceClicked }
|
||||
? {
|
||||
trackHelpResourceClicked: mockTrackHelpResourceClicked,
|
||||
trackAddApiCreditButtonClicked: mockTrackAddApiCreditButtonClicked
|
||||
}
|
||||
: null
|
||||
}))
|
||||
|
||||
@@ -69,6 +77,9 @@ describe('useSubscriptionActions', () => {
|
||||
const { handleAddApiCredits } = useSubscriptionActions()
|
||||
handleAddApiCredits()
|
||||
expect(mockShowTopUpCreditsDialog).toHaveBeenCalledOnce()
|
||||
expect(mockTrackAddApiCreditButtonClicked).toHaveBeenCalledWith({
|
||||
source: 'settings_billing_panel'
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
|
||||
@@ -21,6 +21,9 @@ export function useSubscriptionActions() {
|
||||
})
|
||||
|
||||
const handleAddApiCredits = () => {
|
||||
telemetry?.trackAddApiCreditButtonClicked({
|
||||
source: 'settings_billing_panel'
|
||||
})
|
||||
void dialogService.showTopUpCreditsDialog()
|
||||
}
|
||||
|
||||
|
||||
@@ -5,8 +5,10 @@ import { useSubscriptionDialog } from './useSubscriptionDialog'
|
||||
const mockCloseDialog = vi.fn()
|
||||
const mockShowLayoutDialog = vi.fn()
|
||||
const mockShowTeamWorkspacesDialog = vi.fn()
|
||||
const mockTrackSubscription = vi.hoisted(() => vi.fn())
|
||||
const mockIsInPersonalWorkspace = vi.hoisted(() => ({ value: true }))
|
||||
const mockIsFreeTier = vi.hoisted(() => ({ value: false }))
|
||||
const mockTier = vi.hoisted(() => ({ value: 'FREE' as string | null }))
|
||||
const mockTeamWorkspacesEnabled = vi.hoisted(() => ({ value: false }))
|
||||
const mockIsCloud = vi.hoisted(() => ({ value: true }))
|
||||
const mockIsLegacyTeamPlan = vi.hoisted(() => ({ value: false }))
|
||||
@@ -60,10 +62,15 @@ vi.mock('@/platform/workspace/stores/teamWorkspaceStore', () => ({
|
||||
vi.mock('@/composables/billing/useBillingContext', () => ({
|
||||
useBillingContext: () => ({
|
||||
isFreeTier: mockIsFreeTier,
|
||||
isLegacyTeamPlan: mockIsLegacyTeamPlan
|
||||
isLegacyTeamPlan: mockIsLegacyTeamPlan,
|
||||
tier: mockTier
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/telemetry', () => ({
|
||||
useTelemetry: () => ({ trackSubscription: mockTrackSubscription })
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/workspace/composables/useWorkspaceUI', () => ({
|
||||
useWorkspaceUI: () => ({
|
||||
permissions: {
|
||||
@@ -80,6 +87,7 @@ describe('useSubscriptionDialog', () => {
|
||||
mockIsCloud.value = true
|
||||
mockIsInPersonalWorkspace.value = true
|
||||
mockIsFreeTier.value = false
|
||||
mockTier.value = 'FREE'
|
||||
mockTeamWorkspacesEnabled.value = false
|
||||
mockIsLegacyTeamPlan.value = false
|
||||
mockCanManageSubscription.value = true
|
||||
@@ -198,6 +206,51 @@ describe('useSubscriptionDialog', () => {
|
||||
const props = mockShowLayoutDialog.mock.calls[0][0].props
|
||||
expect(props.initialPlanMode).toBe('team')
|
||||
})
|
||||
|
||||
it('tracks modal_opened with the caller reason and current tier', () => {
|
||||
mockTier.value = 'STANDARD'
|
||||
const { showPricingTable } = useSubscriptionDialog()
|
||||
|
||||
showPricingTable({ reason: 'upgrade_to_add_credits' })
|
||||
|
||||
expect(mockTrackSubscription).toHaveBeenCalledWith('modal_opened', {
|
||||
current_tier: 'standard',
|
||||
reason: 'upgrade_to_add_credits'
|
||||
})
|
||||
})
|
||||
|
||||
it('tracks modal_opened on the workspace (unified) path too', () => {
|
||||
mockTeamWorkspacesEnabled.value = true
|
||||
const { showPricingTable } = useSubscriptionDialog()
|
||||
|
||||
showPricingTable({ reason: 'subscribe_to_run' })
|
||||
|
||||
expect(mockTrackSubscription).toHaveBeenCalledWith(
|
||||
'modal_opened',
|
||||
expect.objectContaining({ reason: 'subscribe_to_run' })
|
||||
)
|
||||
})
|
||||
|
||||
it('does not track modal_opened for the inactive member dialog', () => {
|
||||
mockTeamWorkspacesEnabled.value = true
|
||||
mockIsInPersonalWorkspace.value = false
|
||||
mockCanManageSubscription.value = false
|
||||
const { showPricingTable } = useSubscriptionDialog()
|
||||
|
||||
showPricingTable({ reason: 'subscribe_to_run' })
|
||||
|
||||
expect(mockShowLayoutDialog).toHaveBeenCalledTimes(1)
|
||||
expect(mockTrackSubscription).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('does not track on non-cloud', () => {
|
||||
mockIsCloud.value = false
|
||||
const { showPricingTable } = useSubscriptionDialog()
|
||||
|
||||
showPricingTable({ reason: 'subscribe_to_run' })
|
||||
|
||||
expect(mockTrackSubscription).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('show', () => {
|
||||
@@ -235,6 +288,20 @@ describe('useSubscriptionDialog', () => {
|
||||
expect.objectContaining({ key: 'subscription-required' })
|
||||
)
|
||||
})
|
||||
|
||||
it('tracks modal_opened with the reason for the free-tier dialog', () => {
|
||||
mockIsFreeTier.value = true
|
||||
mockIsInPersonalWorkspace.value = true
|
||||
const { show } = useSubscriptionDialog()
|
||||
|
||||
show({ reason: 'out_of_credits' })
|
||||
|
||||
expect(mockTrackSubscription).toHaveBeenCalledTimes(1)
|
||||
expect(mockTrackSubscription).toHaveBeenCalledWith(
|
||||
'modal_opened',
|
||||
expect.objectContaining({ reason: 'out_of_credits' })
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('startTeamWorkspaceUpgradeFlow', () => {
|
||||
|
||||
@@ -4,6 +4,8 @@ import { useDialogStore } from '@/stores/dialogStore'
|
||||
import { useBillingContext } from '@/composables/billing/useBillingContext'
|
||||
import { useFeatureFlags } from '@/composables/useFeatureFlags'
|
||||
import { isCloud } from '@/platform/distribution/types'
|
||||
import { useTelemetry } from '@/platform/telemetry'
|
||||
import type { PaymentIntentSource } from '@/platform/telemetry/types'
|
||||
import { useWorkspaceUI } from '@/platform/workspace/composables/useWorkspaceUI'
|
||||
import { useTeamWorkspaceStore } from '@/platform/workspace/stores/teamWorkspaceStore'
|
||||
|
||||
@@ -11,14 +13,8 @@ const DIALOG_KEY = 'subscription-required'
|
||||
const FREE_TIER_DIALOG_KEY = 'free-tier-info'
|
||||
const RESUME_PRICING_KEY = 'comfy:resume-team-pricing'
|
||||
|
||||
export type SubscriptionDialogReason =
|
||||
| 'subscription_required'
|
||||
| 'out_of_credits'
|
||||
| 'top_up_blocked'
|
||||
| 'deep_link'
|
||||
|
||||
interface SubscriptionDialogOptions {
|
||||
reason?: SubscriptionDialogReason
|
||||
export interface SubscriptionDialogOptions {
|
||||
reason?: PaymentIntentSource
|
||||
/**
|
||||
* Forces the unified pricing dialog to open on a specific plan tab,
|
||||
* overriding the workspace-derived default (e.g. an "Upgrade to Team" CTA
|
||||
@@ -38,6 +34,17 @@ export const useSubscriptionDialog = () => {
|
||||
dialogStore.closeDialog({ key: FREE_TIER_DIALOG_KEY })
|
||||
}
|
||||
|
||||
// Fired here — the choke point every paywall/pricing dialog variant passes
|
||||
// through — so both the legacy and workspace billing paths emit it.
|
||||
function trackModalOpened(reason?: PaymentIntentSource) {
|
||||
// Resolved lazily to avoid the useBillingContext import cycle (see below).
|
||||
const { tier } = useBillingContext()
|
||||
useTelemetry()?.trackSubscription('modal_opened', {
|
||||
current_tier: tier.value?.toLowerCase(),
|
||||
reason
|
||||
})
|
||||
}
|
||||
|
||||
function showPricingTable(options?: SubscriptionDialogOptions) {
|
||||
if (!isCloud) return
|
||||
|
||||
@@ -71,6 +78,8 @@ export const useSubscriptionDialog = () => {
|
||||
return
|
||||
}
|
||||
|
||||
trackModalOpened(options?.reason)
|
||||
|
||||
// Shared dialog shell styling for both variants.
|
||||
const dialogComponentProps = {
|
||||
style: 'width: min(1328px, 95vw); max-height: 958px;',
|
||||
@@ -167,6 +176,8 @@ export const useSubscriptionDialog = () => {
|
||||
// (not at composable setup) to avoid the useBillingContext import cycle.
|
||||
const { isFreeTier } = useBillingContext()
|
||||
if (isFreeTier.value && workspaceStore.isInPersonalWorkspace) {
|
||||
trackModalOpened(options?.reason)
|
||||
|
||||
const component = defineAsyncComponent(
|
||||
() =>
|
||||
import('@/platform/cloud/subscription/components/FreeTierDialogContent.vue')
|
||||
@@ -236,7 +247,7 @@ export const useSubscriptionDialog = () => {
|
||||
sessionStorage.removeItem(RESUME_PRICING_KEY)
|
||||
|
||||
if (!workspaceStore.isInPersonalWorkspace) {
|
||||
showPricingTable()
|
||||
showPricingTable({ reason: 'team_upgrade_resume' })
|
||||
}
|
||||
} catch {
|
||||
// sessionStorage may be unavailable
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
import { beforeEach, describe, expect, it } from 'vitest'
|
||||
|
||||
import {
|
||||
clearPendingSubscriptionCheckoutAttempt,
|
||||
consumePendingSubscriptionCheckoutSuccess,
|
||||
recordPendingSubscriptionCheckoutAttempt
|
||||
} from './subscriptionCheckoutTracker'
|
||||
|
||||
const activeProStatus = {
|
||||
is_active: true,
|
||||
subscription_tier: 'PRO',
|
||||
subscription_duration: 'MONTHLY'
|
||||
} as const
|
||||
|
||||
describe('subscriptionCheckoutTracker', () => {
|
||||
beforeEach(() => {
|
||||
clearPendingSubscriptionCheckoutAttempt()
|
||||
})
|
||||
|
||||
it('round-trips payment_intent_source from attempt to success metadata', () => {
|
||||
recordPendingSubscriptionCheckoutAttempt({
|
||||
tier: 'pro',
|
||||
cycle: 'monthly',
|
||||
checkout_type: 'new',
|
||||
payment_intent_source: 'subscribe_to_run'
|
||||
})
|
||||
|
||||
const metadata = consumePendingSubscriptionCheckoutSuccess(activeProStatus)
|
||||
|
||||
expect(metadata).toMatchObject({
|
||||
tier: 'pro',
|
||||
checkout_type: 'new',
|
||||
payment_intent_source: 'subscribe_to_run'
|
||||
})
|
||||
})
|
||||
|
||||
it('omits payment_intent_source when the attempt had none', () => {
|
||||
recordPendingSubscriptionCheckoutAttempt({
|
||||
tier: 'pro',
|
||||
cycle: 'monthly',
|
||||
checkout_type: 'new'
|
||||
})
|
||||
|
||||
const metadata = consumePendingSubscriptionCheckoutSuccess(activeProStatus)
|
||||
|
||||
expect(metadata).not.toBeNull()
|
||||
expect(metadata).not.toHaveProperty('payment_intent_source')
|
||||
})
|
||||
})
|
||||
@@ -7,7 +7,12 @@ import type {
|
||||
TierKey
|
||||
} from '@/platform/cloud/subscription/constants/tierPricing'
|
||||
import type { BillingCycle } from '@/platform/cloud/subscription/utils/subscriptionTierRank'
|
||||
import type { SubscriptionSuccessMetadata } from '@/platform/telemetry/types'
|
||||
import type {
|
||||
BeginCheckoutMetadata,
|
||||
PaymentIntentSource,
|
||||
SubscriptionCheckoutType,
|
||||
SubscriptionSuccessMetadata
|
||||
} from '@/platform/telemetry/types'
|
||||
|
||||
const PENDING_SUBSCRIPTION_CHECKOUT_MAX_AGE_MS = 6 * 60 * 60 * 1000
|
||||
const VALID_TIER_KEYS = new Set<TierKey>([
|
||||
@@ -23,7 +28,6 @@ export const PENDING_SUBSCRIPTION_CHECKOUT_STORAGE_KEY =
|
||||
export const PENDING_SUBSCRIPTION_CHECKOUT_EVENT =
|
||||
'comfy:subscription-checkout-attempt-changed'
|
||||
|
||||
type CheckoutType = 'new' | 'change'
|
||||
type SubscriptionDuration = 'MONTHLY' | 'ANNUAL'
|
||||
|
||||
interface SubscriptionStatusSnapshot {
|
||||
@@ -32,22 +36,24 @@ interface SubscriptionStatusSnapshot {
|
||||
subscription_duration?: SubscriptionDuration | null
|
||||
}
|
||||
|
||||
interface PendingSubscriptionCheckoutAttempt {
|
||||
export interface PendingSubscriptionCheckoutAttempt {
|
||||
attempt_id: string
|
||||
started_at_ms: number
|
||||
tier: TierKey
|
||||
cycle: BillingCycle
|
||||
checkout_type: CheckoutType
|
||||
checkout_type: SubscriptionCheckoutType
|
||||
previous_tier?: TierKey
|
||||
previous_cycle?: BillingCycle
|
||||
payment_intent_source?: PaymentIntentSource
|
||||
}
|
||||
|
||||
interface RecordPendingSubscriptionCheckoutAttemptInput {
|
||||
interface PendingSubscriptionCheckoutAttemptInput {
|
||||
tier: TierKey
|
||||
cycle: BillingCycle
|
||||
checkout_type: CheckoutType
|
||||
checkout_type: SubscriptionCheckoutType
|
||||
previous_tier?: TierKey
|
||||
previous_cycle?: BillingCycle
|
||||
payment_intent_source?: PaymentIntentSource
|
||||
}
|
||||
|
||||
const dispatchPendingCheckoutChangeEvent = () => {
|
||||
@@ -168,6 +174,9 @@ const normalizeAttempt = (
|
||||
...(candidate.previous_cycle === 'monthly' ||
|
||||
candidate.previous_cycle === 'yearly'
|
||||
? { previous_cycle: candidate.previous_cycle }
|
||||
: {}),
|
||||
...(typeof candidate.payment_intent_source === 'string'
|
||||
? { payment_intent_source: candidate.payment_intent_source }
|
||||
: {})
|
||||
}
|
||||
}
|
||||
@@ -224,20 +233,27 @@ const getPendingSubscriptionCheckoutAttempt =
|
||||
export const hasPendingSubscriptionCheckoutAttempt = (): boolean =>
|
||||
getPendingSubscriptionCheckoutAttempt() !== null
|
||||
|
||||
export const recordPendingSubscriptionCheckoutAttempt = (
|
||||
input: RecordPendingSubscriptionCheckoutAttemptInput
|
||||
export const createPendingSubscriptionCheckoutAttempt = (
|
||||
input: PendingSubscriptionCheckoutAttemptInput
|
||||
): PendingSubscriptionCheckoutAttempt => {
|
||||
const storage = getStorage()
|
||||
const attempt: PendingSubscriptionCheckoutAttempt = {
|
||||
return {
|
||||
attempt_id: createAttemptId(),
|
||||
started_at_ms: Date.now(),
|
||||
tier: input.tier,
|
||||
cycle: input.cycle,
|
||||
checkout_type: input.checkout_type,
|
||||
...(input.previous_tier ? { previous_tier: input.previous_tier } : {}),
|
||||
...(input.previous_cycle ? { previous_cycle: input.previous_cycle } : {})
|
||||
...(input.previous_cycle ? { previous_cycle: input.previous_cycle } : {}),
|
||||
...(input.payment_intent_source
|
||||
? { payment_intent_source: input.payment_intent_source }
|
||||
: {})
|
||||
}
|
||||
}
|
||||
|
||||
export const persistPendingSubscriptionCheckoutAttempt = (
|
||||
attempt: PendingSubscriptionCheckoutAttempt
|
||||
): PendingSubscriptionCheckoutAttempt => {
|
||||
const storage = getStorage()
|
||||
if (!storage) {
|
||||
return attempt
|
||||
}
|
||||
@@ -255,6 +271,21 @@ export const recordPendingSubscriptionCheckoutAttempt = (
|
||||
return attempt
|
||||
}
|
||||
|
||||
export const recordPendingSubscriptionCheckoutAttempt = (
|
||||
input: PendingSubscriptionCheckoutAttemptInput
|
||||
): PendingSubscriptionCheckoutAttempt =>
|
||||
persistPendingSubscriptionCheckoutAttempt(
|
||||
createPendingSubscriptionCheckoutAttempt(input)
|
||||
)
|
||||
|
||||
export const withPendingCheckoutAttemptId = (
|
||||
metadata: BeginCheckoutMetadata,
|
||||
attempt: PendingSubscriptionCheckoutAttempt
|
||||
): BeginCheckoutMetadata => ({
|
||||
...metadata,
|
||||
checkout_attempt_id: attempt.attempt_id
|
||||
})
|
||||
|
||||
const didAttemptSucceed = (
|
||||
attempt: PendingSubscriptionCheckoutAttempt,
|
||||
status: SubscriptionStatusSnapshot
|
||||
@@ -287,6 +318,9 @@ export const consumePendingSubscriptionCheckoutSuccess = (
|
||||
cycle: attempt.cycle,
|
||||
checkout_type: attempt.checkout_type,
|
||||
...(attempt.previous_tier ? { previous_tier: attempt.previous_tier } : {}),
|
||||
...(attempt.payment_intent_source
|
||||
? { payment_intent_source: attempt.payment_intent_source }
|
||||
: {}),
|
||||
value,
|
||||
currency: 'USD',
|
||||
ecommerce: {
|
||||
|
||||
@@ -132,13 +132,14 @@ describe('performSubscriptionCheckout', () => {
|
||||
json: async () => ({ checkout_url: checkoutUrl })
|
||||
} as Response)
|
||||
|
||||
await performSubscriptionCheckout('pro', 'yearly', true)
|
||||
await performSubscriptionCheckout('pro', 'yearly')
|
||||
|
||||
expect(mockTelemetry.trackBeginCheckout).toHaveBeenCalledWith({
|
||||
user_id: 'user-123',
|
||||
tier: 'pro',
|
||||
cycle: 'yearly',
|
||||
checkout_type: 'new',
|
||||
checkout_attempt_id: expect.any(String),
|
||||
ga_client_id: 'ga-client-id',
|
||||
ga_session_id: 'ga-session-id',
|
||||
ga_session_number: 'ga-session-number',
|
||||
@@ -150,6 +151,12 @@ describe('performSubscriptionCheckout', () => {
|
||||
gbraid: 'gbraid-456',
|
||||
wbraid: 'wbraid-789'
|
||||
})
|
||||
const beginCheckoutMetadata =
|
||||
mockTelemetry.trackBeginCheckout.mock.calls[0][0]
|
||||
const [, storedAttempt] = mockLocalStorage.setItem.mock.calls[0]
|
||||
expect(beginCheckoutMetadata.checkout_attempt_id).toBe(
|
||||
JSON.parse(storedAttempt).attempt_id
|
||||
)
|
||||
expect(global.fetch).toHaveBeenCalledWith(
|
||||
expect.stringContaining(
|
||||
'/customers/cloud-subscription-checkout/pro-yearly'
|
||||
@@ -186,7 +193,7 @@ describe('performSubscriptionCheckout', () => {
|
||||
json: async () => ({ checkout_url: checkoutUrl })
|
||||
} as Response)
|
||||
|
||||
await performSubscriptionCheckout('pro', 'monthly', true)
|
||||
await performSubscriptionCheckout('pro', 'monthly')
|
||||
|
||||
expect(warnSpy).toHaveBeenCalledWith(
|
||||
'[SubscriptionCheckout] Failed to collect checkout attribution',
|
||||
@@ -203,11 +210,43 @@ describe('performSubscriptionCheckout', () => {
|
||||
user_id: 'user-123',
|
||||
tier: 'pro',
|
||||
cycle: 'monthly',
|
||||
checkout_type: 'new'
|
||||
checkout_type: 'new',
|
||||
checkout_attempt_id: expect.any(String)
|
||||
})
|
||||
expect(openSpy).toHaveBeenCalledWith(checkoutUrl, '_blank')
|
||||
})
|
||||
|
||||
it('carries the payment intent source into begin_checkout and the pending attempt', async () => {
|
||||
const checkoutUrl = 'https://checkout.stripe.com/test'
|
||||
const openSpy = vi
|
||||
.spyOn(window, 'open')
|
||||
.mockImplementation(() => window as unknown as Window)
|
||||
|
||||
vi.mocked(global.fetch).mockResolvedValue({
|
||||
ok: true,
|
||||
json: async () => ({ checkout_url: checkoutUrl })
|
||||
} as Response)
|
||||
|
||||
await performSubscriptionCheckout('pro', 'monthly', {
|
||||
paymentIntentSource: 'out_of_credits'
|
||||
})
|
||||
|
||||
expect(mockTelemetry.trackBeginCheckout).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ payment_intent_source: 'out_of_credits' })
|
||||
)
|
||||
const beginCheckoutMetadata =
|
||||
mockTelemetry.trackBeginCheckout.mock.calls[0][0]
|
||||
const [, storedAttempt] = mockLocalStorage.setItem.mock.calls[0]
|
||||
const pendingAttempt = JSON.parse(storedAttempt)
|
||||
expect(pendingAttempt).toMatchObject({
|
||||
payment_intent_source: 'out_of_credits'
|
||||
})
|
||||
expect(beginCheckoutMetadata.checkout_attempt_id).toBe(
|
||||
pendingAttempt.attempt_id
|
||||
)
|
||||
openSpy.mockRestore()
|
||||
})
|
||||
|
||||
it('uses the latest userId when it changes after checkout starts', async () => {
|
||||
const checkoutUrl = 'https://checkout.stripe.com/test'
|
||||
const openSpy = vi
|
||||
@@ -222,7 +261,7 @@ describe('performSubscriptionCheckout', () => {
|
||||
json: async () => ({ checkout_url: checkoutUrl })
|
||||
} as Response)
|
||||
|
||||
const checkoutPromise = performSubscriptionCheckout('pro', 'yearly', true)
|
||||
const checkoutPromise = performSubscriptionCheckout('pro', 'yearly')
|
||||
|
||||
mockUserId.value = 'user-late'
|
||||
authHeader.resolve({ Authorization: 'Bearer test-token' })
|
||||
@@ -235,13 +274,14 @@ describe('performSubscriptionCheckout', () => {
|
||||
user_id: 'user-late',
|
||||
tier: 'pro',
|
||||
cycle: 'yearly',
|
||||
checkout_type: 'new'
|
||||
checkout_type: 'new',
|
||||
checkout_attempt_id: expect.any(String)
|
||||
})
|
||||
)
|
||||
expect(openSpy).toHaveBeenCalledWith(checkoutUrl, '_blank')
|
||||
})
|
||||
|
||||
it('does not persist a pending attempt when the checkout popup is blocked', async () => {
|
||||
it('does not persist the pending attempt when the checkout popup is blocked', async () => {
|
||||
const checkoutUrl = 'https://checkout.stripe.com/test'
|
||||
const openSpy = vi.spyOn(window, 'open').mockImplementation(() => null)
|
||||
|
||||
@@ -250,11 +290,18 @@ describe('performSubscriptionCheckout', () => {
|
||||
json: async () => ({ checkout_url: checkoutUrl })
|
||||
} as Response)
|
||||
|
||||
await performSubscriptionCheckout('pro', 'monthly', true)
|
||||
await performSubscriptionCheckout('pro', 'monthly')
|
||||
|
||||
expect(openSpy).toHaveBeenCalledWith(checkoutUrl, '_blank')
|
||||
expect(
|
||||
window.localStorage.getItem(PENDING_SUBSCRIPTION_CHECKOUT_STORAGE_KEY)
|
||||
).toBeNull()
|
||||
const storedAttempt = window.localStorage.getItem(
|
||||
PENDING_SUBSCRIPTION_CHECKOUT_STORAGE_KEY
|
||||
)
|
||||
expect(storedAttempt).toBeNull()
|
||||
expect(mockLocalStorage.setItem).not.toHaveBeenCalled()
|
||||
expect(mockTelemetry.trackBeginCheckout).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
checkout_attempt_id: expect.any(String)
|
||||
})
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -4,12 +4,19 @@ import { useFeatureFlags } from '@/composables/useFeatureFlags'
|
||||
import { getComfyApiBaseUrl } from '@/config/comfyApi'
|
||||
import { t } from '@/i18n'
|
||||
import { fetchWithUnifiedRemint } from '@/platform/auth/unified/remintRetry'
|
||||
import type { TierKey } from '@/platform/cloud/subscription/constants/tierPricing'
|
||||
import {
|
||||
createPendingSubscriptionCheckoutAttempt,
|
||||
persistPendingSubscriptionCheckoutAttempt,
|
||||
withPendingCheckoutAttemptId
|
||||
} from '@/platform/cloud/subscription/utils/subscriptionCheckoutTracker'
|
||||
import { isCloud } from '@/platform/distribution/types'
|
||||
import { useTelemetry } from '@/platform/telemetry'
|
||||
import type {
|
||||
CheckoutAttributionMetadata,
|
||||
PaymentIntentSource
|
||||
} from '@/platform/telemetry/types'
|
||||
import { AuthStoreError, useAuthStore } from '@/stores/authStore'
|
||||
import type { CheckoutAttributionMetadata } from '@/platform/telemetry/types'
|
||||
import type { TierKey } from '@/platform/cloud/subscription/constants/tierPricing'
|
||||
import { recordPendingSubscriptionCheckoutAttempt } from '@/platform/cloud/subscription/utils/subscriptionCheckoutTracker'
|
||||
import type { BillingCycle } from './subscriptionTierRank'
|
||||
|
||||
type CheckoutTier = TierKey | `${TierKey}-yearly`
|
||||
@@ -31,6 +38,11 @@ const getCheckoutAttributionForCloud =
|
||||
return getCheckoutAttribution()
|
||||
}
|
||||
|
||||
interface PerformSubscriptionCheckoutOptions {
|
||||
openInNewTab?: boolean
|
||||
paymentIntentSource?: PaymentIntentSource
|
||||
}
|
||||
|
||||
/**
|
||||
* Core subscription checkout logic shared between PricingTable and
|
||||
* SubscriptionRedirectView. Handles:
|
||||
@@ -47,10 +59,12 @@ const getCheckoutAttributionForCloud =
|
||||
export async function performSubscriptionCheckout(
|
||||
tierKey: TierKey,
|
||||
currentBillingCycle: BillingCycle,
|
||||
openInNewTab: boolean = true
|
||||
options: PerformSubscriptionCheckoutOptions = {}
|
||||
): Promise<void> {
|
||||
if (!isCloud) return
|
||||
|
||||
const { openInNewTab = true, paymentIntentSource } = options
|
||||
|
||||
const authStore = useAuthStore()
|
||||
const { userId } = storeToRefs(authStore)
|
||||
const telemetry = useTelemetry()
|
||||
@@ -108,14 +122,29 @@ export async function performSubscriptionCheckout(
|
||||
const data = await response.json()
|
||||
|
||||
if (data.checkout_url) {
|
||||
const pendingAttempt = createPendingSubscriptionCheckoutAttempt({
|
||||
tier: tierKey,
|
||||
cycle: currentBillingCycle,
|
||||
checkout_type: 'new',
|
||||
payment_intent_source: paymentIntentSource
|
||||
})
|
||||
|
||||
if (userId.value) {
|
||||
telemetry?.trackBeginCheckout({
|
||||
user_id: userId.value,
|
||||
tier: tierKey,
|
||||
cycle: currentBillingCycle,
|
||||
checkout_type: 'new',
|
||||
...checkoutAttribution
|
||||
})
|
||||
telemetry?.trackBeginCheckout(
|
||||
withPendingCheckoutAttemptId(
|
||||
{
|
||||
user_id: userId.value,
|
||||
tier: tierKey,
|
||||
cycle: currentBillingCycle,
|
||||
checkout_type: 'new',
|
||||
...(paymentIntentSource
|
||||
? { payment_intent_source: paymentIntentSource }
|
||||
: {}),
|
||||
...checkoutAttribution
|
||||
},
|
||||
pendingAttempt
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
if (openInNewTab) {
|
||||
@@ -123,18 +152,9 @@ export async function performSubscriptionCheckout(
|
||||
if (!checkoutWindow) {
|
||||
return
|
||||
}
|
||||
|
||||
recordPendingSubscriptionCheckoutAttempt({
|
||||
tier: tierKey,
|
||||
cycle: currentBillingCycle,
|
||||
checkout_type: 'new'
|
||||
})
|
||||
persistPendingSubscriptionCheckoutAttempt(pendingAttempt)
|
||||
} else {
|
||||
recordPendingSubscriptionCheckoutAttempt({
|
||||
tier: tierKey,
|
||||
cycle: currentBillingCycle,
|
||||
checkout_type: 'new'
|
||||
})
|
||||
persistPendingSubscriptionCheckoutAttempt(pendingAttempt)
|
||||
globalThis.location.href = data.checkout_url
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,9 +1,13 @@
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { computed, reactive } from 'vue'
|
||||
|
||||
const { mockIsCloud, mockSubscribe } = vi.hoisted(() => ({
|
||||
mockIsCloud: { value: true },
|
||||
mockSubscribe: vi.fn()
|
||||
}))
|
||||
const { mockIsCloud, mockSubscribe, mockTrackBeginCheckout, mockUserId } =
|
||||
vi.hoisted(() => ({
|
||||
mockIsCloud: { value: true },
|
||||
mockSubscribe: vi.fn(),
|
||||
mockTrackBeginCheckout: vi.fn(),
|
||||
mockUserId: { value: 'user-1' as string | null }
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/distribution/types', () => ({
|
||||
get isCloud() {
|
||||
@@ -16,6 +20,12 @@ vi.mock('@/config/comfyApi', () => ({
|
||||
vi.mock('@/platform/workspace/api/workspaceApi', () => ({
|
||||
workspaceApi: { subscribe: mockSubscribe }
|
||||
}))
|
||||
vi.mock('@/platform/telemetry', () => ({
|
||||
useTelemetry: () => ({ trackBeginCheckout: mockTrackBeginCheckout })
|
||||
}))
|
||||
vi.mock('@/stores/authStore', () => ({
|
||||
useAuthStore: () => reactive({ userId: computed(() => mockUserId.value) })
|
||||
}))
|
||||
|
||||
import { performTeamSubscriptionCheckout } from './teamSubscriptionCheckoutUtil'
|
||||
|
||||
@@ -43,7 +53,9 @@ describe('performTeamSubscriptionCheckout', () => {
|
||||
billing_op_id: 'op_1'
|
||||
})
|
||||
|
||||
await performTeamSubscriptionCheckout('team_700', 'yearly')
|
||||
await performTeamSubscriptionCheckout('team_700', 'yearly', {
|
||||
paymentIntentSource: 'deep_link'
|
||||
})
|
||||
|
||||
expect(mockSubscribe).toHaveBeenCalledWith('team_per_credit_annual', {
|
||||
returnUrl: 'https://app.test/payment/success',
|
||||
@@ -51,6 +63,14 @@ describe('performTeamSubscriptionCheckout', () => {
|
||||
teamCreditStopId: 'team_700'
|
||||
})
|
||||
expect(assignedHref).toBe('https://stripe.test/pay')
|
||||
expect(mockTrackBeginCheckout).toHaveBeenCalledWith({
|
||||
user_id: 'user-1',
|
||||
tier: 'team',
|
||||
cycle: 'yearly',
|
||||
checkout_type: 'new',
|
||||
billing_op_id: 'op_1',
|
||||
payment_intent_source: 'deep_link'
|
||||
})
|
||||
})
|
||||
|
||||
it('uses the monthly slug and lands in the app when no Stripe step is needed', async () => {
|
||||
@@ -82,6 +102,16 @@ describe('performTeamSubscriptionCheckout', () => {
|
||||
expect(assignedHref).toBeUndefined()
|
||||
})
|
||||
|
||||
it('does not track begin_checkout when subscribe fails', async () => {
|
||||
mockSubscribe.mockRejectedValueOnce(new Error('subscribe failed'))
|
||||
|
||||
await expect(
|
||||
performTeamSubscriptionCheckout('team_700', 'yearly')
|
||||
).rejects.toThrow('subscribe failed')
|
||||
|
||||
expect(mockTrackBeginCheckout).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('does nothing off cloud', async () => {
|
||||
mockIsCloud.value = false
|
||||
|
||||
|
||||
@@ -1,10 +1,16 @@
|
||||
import { getComfyPlatformBaseUrl } from '@/config/comfyApi'
|
||||
import { getTeamPlanSlug } from '@/platform/cloud/subscription/constants/teamPlanCreditStops'
|
||||
import { isCloud } from '@/platform/distribution/types'
|
||||
import type { PaymentIntentSource } from '@/platform/telemetry/types'
|
||||
import { workspaceApi } from '@/platform/workspace/api/workspaceApi'
|
||||
import { trackWorkspaceCheckoutStarted } from '@/platform/workspace/utils/workspaceCheckoutTelemetry'
|
||||
|
||||
import type { BillingCycle } from './subscriptionTierRank'
|
||||
|
||||
interface PerformTeamSubscriptionCheckoutOptions {
|
||||
paymentIntentSource?: PaymentIntentSource
|
||||
}
|
||||
|
||||
/**
|
||||
* Direct team-plan checkout for the marketing `/cloud/subscribe?tier=team` deep
|
||||
* link: subscribes to the per-credit Team plan at the chosen slider stop and
|
||||
@@ -22,7 +28,8 @@ import type { BillingCycle } from './subscriptionTierRank'
|
||||
*/
|
||||
export async function performTeamSubscriptionCheckout(
|
||||
teamCreditStopId: string,
|
||||
billingCycle: BillingCycle
|
||||
billingCycle: BillingCycle,
|
||||
options: PerformTeamSubscriptionCheckoutOptions = {}
|
||||
): Promise<void> {
|
||||
if (!isCloud) return
|
||||
|
||||
@@ -33,6 +40,14 @@ export async function performTeamSubscriptionCheckout(
|
||||
teamCreditStopId
|
||||
})
|
||||
|
||||
trackWorkspaceCheckoutStarted({
|
||||
tier: 'team',
|
||||
cycle: billingCycle,
|
||||
checkoutType: 'new',
|
||||
billingOpId: response.billing_op_id,
|
||||
paymentIntentSource: options.paymentIntentSource
|
||||
})
|
||||
|
||||
if (response.status === 'needs_payment_method') {
|
||||
// A needs_payment_method response without a URL is unusable: surface it to
|
||||
// the caller's error handling rather than silently dropping the user home
|
||||
|
||||
@@ -30,6 +30,39 @@ describe('TelemetryRegistry', () => {
|
||||
expect(b.trackSearchQuery).toHaveBeenCalledExactlyOnceWith(payload)
|
||||
})
|
||||
|
||||
it('dispatches trackBeginCheckout with intent metadata to every provider', () => {
|
||||
const a: TelemetryProvider = { trackBeginCheckout: vi.fn() }
|
||||
const b: TelemetryProvider = {}
|
||||
const registry = new TelemetryRegistry()
|
||||
registry.registerProvider(a)
|
||||
registry.registerProvider(b)
|
||||
|
||||
const metadata = {
|
||||
user_id: 'user-1',
|
||||
tier: 'pro' as const,
|
||||
cycle: 'monthly' as const,
|
||||
checkout_type: 'new' as const,
|
||||
payment_intent_source: 'subscribe_to_run' as const
|
||||
}
|
||||
registry.trackBeginCheckout(metadata)
|
||||
|
||||
expect(a.trackBeginCheckout).toHaveBeenCalledExactlyOnceWith(metadata)
|
||||
})
|
||||
|
||||
it('dispatches trackAddApiCreditButtonClicked with its source', () => {
|
||||
const provider: TelemetryProvider = {
|
||||
trackAddApiCreditButtonClicked: vi.fn()
|
||||
}
|
||||
const registry = new TelemetryRegistry()
|
||||
registry.registerProvider(provider)
|
||||
|
||||
registry.trackAddApiCreditButtonClicked({ source: 'credits_panel' })
|
||||
|
||||
expect(
|
||||
provider.trackAddApiCreditButtonClicked
|
||||
).toHaveBeenCalledExactlyOnceWith({ source: 'credits_panel' })
|
||||
})
|
||||
|
||||
it('skips providers that do not implement trackSearchQuery', () => {
|
||||
const empty: TelemetryProvider = {}
|
||||
const registry = new TelemetryRegistry()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import type { AuditLog } from '@/services/customerEventsService'
|
||||
|
||||
import type {
|
||||
AddCreditsClickMetadata,
|
||||
AuthMetadata,
|
||||
BeginCheckoutMetadata,
|
||||
DefaultViewSetMetadata,
|
||||
@@ -99,8 +100,10 @@ export class TelemetryRegistry implements TelemetryDispatcher {
|
||||
this.dispatch((provider) => provider.trackMonthlySubscriptionCancelled?.())
|
||||
}
|
||||
|
||||
trackAddApiCreditButtonClicked(): void {
|
||||
this.dispatch((provider) => provider.trackAddApiCreditButtonClicked?.())
|
||||
trackAddApiCreditButtonClicked(metadata?: AddCreditsClickMetadata): void {
|
||||
this.dispatch((provider) =>
|
||||
provider.trackAddApiCreditButtonClicked?.(metadata)
|
||||
)
|
||||
}
|
||||
|
||||
trackApiCreditTopupButtonPurchaseClicked(amount: number): void {
|
||||
|
||||
@@ -313,6 +313,42 @@ describe('PostHogTelemetryProvider', () => {
|
||||
)
|
||||
})
|
||||
|
||||
it('captures begin_checkout with intent metadata', async () => {
|
||||
const provider = createProvider()
|
||||
await vi.dynamicImportSettled()
|
||||
|
||||
provider.trackBeginCheckout({
|
||||
user_id: 'user-1',
|
||||
tier: 'pro',
|
||||
cycle: 'monthly',
|
||||
checkout_type: 'new',
|
||||
payment_intent_source: 'subscribe_to_run'
|
||||
})
|
||||
|
||||
expect(hoisted.mockCapture).toHaveBeenCalledWith(
|
||||
TelemetryEvents.BEGIN_CHECKOUT,
|
||||
{
|
||||
user_id: 'user-1',
|
||||
tier: 'pro',
|
||||
cycle: 'monthly',
|
||||
checkout_type: 'new',
|
||||
payment_intent_source: 'subscribe_to_run'
|
||||
}
|
||||
)
|
||||
})
|
||||
|
||||
it('captures add-credit clicks with their source', async () => {
|
||||
const provider = createProvider()
|
||||
await vi.dynamicImportSettled()
|
||||
|
||||
provider.trackAddApiCreditButtonClicked({ source: 'credits_panel' })
|
||||
|
||||
expect(hoisted.mockCapture).toHaveBeenCalledWith(
|
||||
TelemetryEvents.ADD_API_CREDIT_BUTTON_CLICKED,
|
||||
{ source: 'credits_panel' }
|
||||
)
|
||||
})
|
||||
|
||||
it('captures share attribution events', async () => {
|
||||
const provider = createProvider()
|
||||
await vi.dynamicImportSettled()
|
||||
|
||||
@@ -10,7 +10,9 @@ import { remoteConfig } from '@/platform/remoteConfig/remoteConfig'
|
||||
import type { RemoteConfig } from '@/platform/remoteConfig/types'
|
||||
|
||||
import type {
|
||||
AddCreditsClickMetadata,
|
||||
AuthMetadata,
|
||||
BeginCheckoutMetadata,
|
||||
DefaultViewSetMetadata,
|
||||
EnterLinearMetadata,
|
||||
ShareFlowMetadata,
|
||||
@@ -350,8 +352,12 @@ export class PostHogTelemetryProvider implements TelemetryProvider {
|
||||
this.trackEvent(eventName, metadata)
|
||||
}
|
||||
|
||||
trackAddApiCreditButtonClicked(): void {
|
||||
this.trackEvent(TelemetryEvents.ADD_API_CREDIT_BUTTON_CLICKED)
|
||||
trackAddApiCreditButtonClicked(metadata?: AddCreditsClickMetadata): void {
|
||||
this.trackEvent(TelemetryEvents.ADD_API_CREDIT_BUTTON_CLICKED, metadata)
|
||||
}
|
||||
|
||||
trackBeginCheckout(metadata: BeginCheckoutMetadata): void {
|
||||
this.trackEvent(TelemetryEvents.BEGIN_CHECKOUT, metadata)
|
||||
}
|
||||
|
||||
trackMonthlySubscriptionSucceeded(
|
||||
|
||||
@@ -115,6 +115,17 @@ describe('HostTelemetrySink', () => {
|
||||
)
|
||||
})
|
||||
|
||||
it('forwards add-credit clicks with their source', () => {
|
||||
new HostTelemetrySink().trackAddApiCreditButtonClicked({
|
||||
source: 'avatar_menu'
|
||||
})
|
||||
|
||||
expect(state.capture).toHaveBeenCalledExactlyOnceWith(
|
||||
TelemetryEvents.ADD_API_CREDIT_BUTTON_CLICKED,
|
||||
{ source: 'avatar_menu' }
|
||||
)
|
||||
})
|
||||
|
||||
it('does nothing when the host bridge is absent', () => {
|
||||
delete window.__comfyDesktop2
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ import {
|
||||
import type { AuditLog } from '@/services/customerEventsService'
|
||||
|
||||
import type {
|
||||
AddCreditsClickMetadata,
|
||||
AuthMetadata,
|
||||
BeginCheckoutMetadata,
|
||||
DefaultViewSetMetadata,
|
||||
@@ -126,8 +127,8 @@ export class HostTelemetrySink implements TelemetryProvider {
|
||||
this.capture(TelemetryEvents.MONTHLY_SUBSCRIPTION_CANCELLED)
|
||||
}
|
||||
|
||||
trackAddApiCreditButtonClicked(): void {
|
||||
this.capture(TelemetryEvents.ADD_API_CREDIT_BUTTON_CLICKED)
|
||||
trackAddApiCreditButtonClicked(metadata?: AddCreditsClickMetadata): void {
|
||||
this.capture(TelemetryEvents.ADD_API_CREDIT_BUTTON_CLICKED, metadata)
|
||||
}
|
||||
|
||||
trackApiCreditTopupButtonPurchaseClicked(amount: number): void {
|
||||
|
||||
@@ -12,12 +12,29 @@
|
||||
* 3. Check dist/assets/*.js files contain no tracking code
|
||||
*/
|
||||
|
||||
import type { SubscriptionDialogReason } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type { TierKey } from '@/platform/cloud/subscription/constants/tierPricing'
|
||||
import type { BillingCycle } from '@/platform/cloud/subscription/utils/subscriptionTierRank'
|
||||
import type { AuditLog } from '@/services/customerEventsService'
|
||||
import type { AppMode } from '@/utils/appMode'
|
||||
|
||||
export type PaymentIntentSource =
|
||||
| 'subscription_required'
|
||||
| 'out_of_credits'
|
||||
| 'top_up_blocked'
|
||||
| 'deep_link'
|
||||
| 'subscribe_to_run'
|
||||
| 'subscribe_now_button'
|
||||
| 'upgrade_to_add_credits'
|
||||
| 'settings_billing_panel'
|
||||
| 'avatar_menu_plans'
|
||||
| 'team_members_panel'
|
||||
| 'invite_member_upsell'
|
||||
| 'upload_model_upgrade'
|
||||
| 'team_upgrade_resume'
|
||||
|
||||
export type SubscriptionCheckoutType = 'new' | 'change'
|
||||
export type SubscriptionCheckoutTier = TierKey | 'team'
|
||||
|
||||
/**
|
||||
* Authentication metadata for sign-up tracking
|
||||
*/
|
||||
@@ -426,16 +443,23 @@ export interface CheckoutAttributionMetadata {
|
||||
|
||||
export interface SubscriptionMetadata {
|
||||
current_tier?: string
|
||||
reason?: SubscriptionDialogReason
|
||||
reason?: PaymentIntentSource
|
||||
}
|
||||
|
||||
export interface AddCreditsClickMetadata {
|
||||
source: 'credits_panel' | 'avatar_menu' | 'settings_billing_panel'
|
||||
}
|
||||
|
||||
export interface BeginCheckoutMetadata
|
||||
extends Record<string, unknown>, CheckoutAttributionMetadata {
|
||||
user_id: string
|
||||
tier: TierKey
|
||||
tier: SubscriptionCheckoutTier
|
||||
cycle: BillingCycle
|
||||
checkout_type: 'new' | 'change'
|
||||
checkout_type: SubscriptionCheckoutType
|
||||
checkout_attempt_id?: string
|
||||
billing_op_id?: string
|
||||
previous_tier?: TierKey
|
||||
payment_intent_source?: PaymentIntentSource
|
||||
}
|
||||
|
||||
interface EcommerceItemMetadata {
|
||||
@@ -457,8 +481,9 @@ export interface SubscriptionSuccessMetadata extends Record<string, unknown> {
|
||||
checkout_attempt_id: string
|
||||
tier: TierKey
|
||||
cycle: BillingCycle
|
||||
checkout_type: 'new' | 'change'
|
||||
checkout_type: SubscriptionCheckoutType
|
||||
previous_tier?: TierKey
|
||||
payment_intent_source?: PaymentIntentSource
|
||||
value: number
|
||||
currency: string
|
||||
ecommerce: EcommerceMetadata
|
||||
@@ -489,7 +514,7 @@ export interface TelemetryProvider {
|
||||
metadata?: SubscriptionSuccessMetadata
|
||||
): void
|
||||
trackMonthlySubscriptionCancelled?(): void
|
||||
trackAddApiCreditButtonClicked?(): void
|
||||
trackAddApiCreditButtonClicked?(metadata?: AddCreditsClickMetadata): void
|
||||
trackApiCreditTopupButtonPurchaseClicked?(amount: number): void
|
||||
trackApiCreditTopupSucceeded?(): void
|
||||
trackWorkspaceInviteSent?(metadata: WorkspaceInviteMetadata): void
|
||||
|
||||
@@ -321,7 +321,7 @@ const handleOpenWorkspaceSettings = () => {
|
||||
}
|
||||
|
||||
const handleOpenPlansAndPricing = () => {
|
||||
subscriptionDialog.showPricingTable()
|
||||
subscriptionDialog.showPricingTable({ reason: 'avatar_menu_plans' })
|
||||
emit('close')
|
||||
}
|
||||
|
||||
@@ -336,13 +336,12 @@ const handleOpenPlanAndCreditsSettings = () => {
|
||||
}
|
||||
|
||||
const handleUpgradeToAddCredits = () => {
|
||||
subscriptionDialog.showPricingTable()
|
||||
subscriptionDialog.showPricingTable({ reason: 'upgrade_to_add_credits' })
|
||||
emit('close')
|
||||
}
|
||||
|
||||
const handleTopUp = () => {
|
||||
// Track purchase credits entry from avatar popover
|
||||
useTelemetry()?.trackAddApiCreditButtonClicked()
|
||||
useTelemetry()?.trackAddApiCreditButtonClicked({ source: 'avatar_menu' })
|
||||
dialogService.showTopUpCreditsDialog()
|
||||
emit('close')
|
||||
}
|
||||
|
||||
@@ -391,12 +391,13 @@ const showZeroState = computed(
|
||||
)
|
||||
|
||||
function handleSubscribeWorkspace() {
|
||||
showSubscriptionDialog()
|
||||
showSubscriptionDialog({ reason: 'settings_billing_panel' })
|
||||
}
|
||||
|
||||
function handleUpgrade() {
|
||||
if (isFreeTierPlan.value) showPricingTable()
|
||||
else showSubscriptionDialog()
|
||||
if (isFreeTierPlan.value)
|
||||
showPricingTable({ reason: 'settings_billing_panel' })
|
||||
else showSubscriptionDialog({ reason: 'settings_billing_panel' })
|
||||
}
|
||||
|
||||
function handleViewMoreDetails() {
|
||||
|
||||
@@ -113,7 +113,7 @@ import { cn } from '@comfyorg/tailwind-utils'
|
||||
import { useEventListener } from '@vueuse/core'
|
||||
|
||||
import Button from '@/components/ui/button/Button.vue'
|
||||
import type { SubscriptionDialogReason } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type { PaymentIntentSource } from '@/platform/telemetry/types'
|
||||
import { useSubscriptionCheckout } from '@/platform/workspace/composables/useSubscriptionCheckout'
|
||||
|
||||
import SubscriptionAddPaymentPreviewWorkspace from './SubscriptionAddPaymentPreviewWorkspace.vue'
|
||||
@@ -123,7 +123,7 @@ import UnifiedPricingTable from './UnifiedPricingTable.vue'
|
||||
|
||||
const { onClose, reason, initialPlanMode } = defineProps<{
|
||||
onClose: () => void
|
||||
reason?: SubscriptionDialogReason
|
||||
reason?: PaymentIntentSource
|
||||
initialPlanMode?: 'personal' | 'team'
|
||||
}>()
|
||||
|
||||
@@ -152,7 +152,7 @@ const {
|
||||
handleConfirmTransition,
|
||||
handleTeamSubscribe,
|
||||
handleResubscribe
|
||||
} = useSubscriptionCheckout(emit)
|
||||
} = useSubscriptionCheckout(emit, reason)
|
||||
|
||||
// Backspace mirrors the back arrow on the confirm step, but never while an
|
||||
// editable element is focused (let it delete text there).
|
||||
|
||||
@@ -5,7 +5,7 @@ import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { ref } from 'vue'
|
||||
import { createI18n } from 'vue-i18n'
|
||||
|
||||
import type { SubscriptionDialogReason } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type { PaymentIntentSource } from '@/platform/telemetry/types'
|
||||
|
||||
import SubscriptionRequiredDialogContentWorkspace from './SubscriptionRequiredDialogContentWorkspace.vue'
|
||||
|
||||
@@ -17,25 +17,10 @@ const mockHandleResubscribe = vi.fn()
|
||||
const mockHandleSuccessClose = vi.fn()
|
||||
const mockCheckoutStep = ref<'pricing' | 'preview' | 'success'>('pricing')
|
||||
const mockPreviewData = ref<{ transition_type: string } | null>(null)
|
||||
const mockUseSubscriptionCheckout = vi.hoisted(() => vi.fn())
|
||||
|
||||
vi.mock('@/platform/workspace/composables/useSubscriptionCheckout', () => ({
|
||||
useSubscriptionCheckout: () => ({
|
||||
checkoutStep: mockCheckoutStep,
|
||||
isLoadingPreview: ref(false),
|
||||
loadingTier: ref(null),
|
||||
isSubscribing: ref(false),
|
||||
isResubscribing: ref(false),
|
||||
previewData: mockPreviewData,
|
||||
selectedTierKey: ref('standard'),
|
||||
selectedBillingCycle: ref('yearly'),
|
||||
isPolling: ref(false),
|
||||
handleSubscribeClick: mockHandleSubscribeClick,
|
||||
handleBackToPricing: mockHandleBackToPricing,
|
||||
handleAddCreditCard: mockHandleAddCreditCard,
|
||||
handleConfirmTransition: mockHandleConfirmTransition,
|
||||
handleResubscribe: mockHandleResubscribe,
|
||||
handleSuccessClose: mockHandleSuccessClose
|
||||
})
|
||||
useSubscriptionCheckout: mockUseSubscriptionCheckout
|
||||
}))
|
||||
|
||||
const i18n = createI18n({
|
||||
@@ -91,7 +76,7 @@ const SuccessStub = {
|
||||
function renderComponent(
|
||||
props: {
|
||||
onClose?: () => void
|
||||
reason?: SubscriptionDialogReason
|
||||
reason?: PaymentIntentSource
|
||||
isPersonal?: boolean
|
||||
} = {}
|
||||
) {
|
||||
@@ -121,6 +106,23 @@ function renderComponent(
|
||||
describe('SubscriptionRequiredDialogContentWorkspace', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockUseSubscriptionCheckout.mockReturnValue({
|
||||
checkoutStep: mockCheckoutStep,
|
||||
isLoadingPreview: ref(false),
|
||||
loadingTier: ref(null),
|
||||
isSubscribing: ref(false),
|
||||
isResubscribing: ref(false),
|
||||
previewData: mockPreviewData,
|
||||
selectedTierKey: ref('standard'),
|
||||
selectedBillingCycle: ref('yearly'),
|
||||
isPolling: ref(false),
|
||||
handleSubscribeClick: mockHandleSubscribeClick,
|
||||
handleBackToPricing: mockHandleBackToPricing,
|
||||
handleAddCreditCard: mockHandleAddCreditCard,
|
||||
handleConfirmTransition: mockHandleConfirmTransition,
|
||||
handleResubscribe: mockHandleResubscribe,
|
||||
handleSuccessClose: mockHandleSuccessClose
|
||||
})
|
||||
mockCheckoutStep.value = 'pricing'
|
||||
mockPreviewData.value = null
|
||||
})
|
||||
@@ -132,6 +134,15 @@ describe('SubscriptionRequiredDialogContentWorkspace', () => {
|
||||
expect(screen.queryByTestId('transition-preview')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('passes the reason into subscription checkout', () => {
|
||||
renderComponent({ reason: 'out_of_credits' })
|
||||
|
||||
expect(mockUseSubscriptionCheckout).toHaveBeenCalledWith(
|
||||
expect.any(Function),
|
||||
'out_of_credits'
|
||||
)
|
||||
})
|
||||
|
||||
it('shows the team workspace header by default', () => {
|
||||
renderComponent()
|
||||
expect(screen.getByText('Team Workspace')).toBeInTheDocument()
|
||||
|
||||
@@ -116,7 +116,7 @@
|
||||
import { cn } from '@comfyorg/tailwind-utils'
|
||||
|
||||
import Button from '@/components/ui/button/Button.vue'
|
||||
import type { SubscriptionDialogReason } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type { PaymentIntentSource } from '@/platform/telemetry/types'
|
||||
import { useSubscriptionCheckout } from '@/platform/workspace/composables/useSubscriptionCheckout'
|
||||
|
||||
import PricingTableWorkspace from './PricingTableWorkspace.vue'
|
||||
@@ -130,7 +130,7 @@ const {
|
||||
isPersonal = false
|
||||
} = defineProps<{
|
||||
onClose: () => void
|
||||
reason?: SubscriptionDialogReason
|
||||
reason?: PaymentIntentSource
|
||||
isPersonal?: boolean
|
||||
}>()
|
||||
|
||||
@@ -154,7 +154,7 @@ const {
|
||||
handleConfirmTransition,
|
||||
handleResubscribe,
|
||||
handleSuccessClose
|
||||
} = useSubscriptionCheckout(emit)
|
||||
} = useSubscriptionCheckout(emit, reason)
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
|
||||
@@ -61,6 +61,9 @@ function onDismiss() {
|
||||
|
||||
function onUpgrade() {
|
||||
dialogStore.closeDialog({ key: 'invite-member-upsell' })
|
||||
subscriptionDialog.show({ planMode: 'team' })
|
||||
subscriptionDialog.show({
|
||||
planMode: 'team',
|
||||
reason: 'invite_member_upsell'
|
||||
})
|
||||
}
|
||||
</script>
|
||||
|
||||
@@ -277,7 +277,7 @@ export function useMembersPanel() {
|
||||
}
|
||||
|
||||
function showTeamPlans() {
|
||||
subscriptionDialog.show({ planMode: 'team' })
|
||||
subscriptionDialog.show({ planMode: 'team', reason: 'team_members_panel' })
|
||||
}
|
||||
|
||||
return {
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import { createTestingPinia } from '@pinia/testing'
|
||||
import { setActivePinia } from 'pinia'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { computed } from 'vue'
|
||||
import { computed, reactive } from 'vue'
|
||||
|
||||
import type { PaymentIntentSource } from '@/platform/telemetry/types'
|
||||
import type { Plan } from '@/platform/workspace/api/workspaceApi'
|
||||
|
||||
import { findPlanSlug } from './useSubscriptionCheckout'
|
||||
@@ -75,7 +76,9 @@ const {
|
||||
mockPlans,
|
||||
mockResubscribe,
|
||||
mockToastAdd,
|
||||
mockStartOperation
|
||||
mockStartOperation,
|
||||
mockTrackBeginCheckout,
|
||||
mockUserId
|
||||
} = vi.hoisted(() => ({
|
||||
mockSubscribe: vi.fn(),
|
||||
mockPreviewSubscribe: vi.fn(),
|
||||
@@ -84,7 +87,9 @@ const {
|
||||
mockPlans: { value: [] as Plan[] },
|
||||
mockResubscribe: vi.fn(),
|
||||
mockToastAdd: vi.fn(),
|
||||
mockStartOperation: vi.fn()
|
||||
mockStartOperation: vi.fn(),
|
||||
mockTrackBeginCheckout: vi.fn(),
|
||||
mockUserId: { value: 'user-1' as string | null }
|
||||
}))
|
||||
|
||||
vi.mock('@/composables/billing/useBillingContext', () => ({
|
||||
@@ -119,7 +124,14 @@ vi.mock('primevue/usetoast', () => ({
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/telemetry', () => ({
|
||||
useTelemetry: () => ({ trackMonthlySubscriptionSucceeded: vi.fn() })
|
||||
useTelemetry: () => ({
|
||||
trackMonthlySubscriptionSucceeded: vi.fn(),
|
||||
trackBeginCheckout: mockTrackBeginCheckout
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('@/stores/authStore', () => ({
|
||||
useAuthStore: () => reactive({ userId: computed(() => mockUserId.value) })
|
||||
}))
|
||||
|
||||
vi.mock('vue-i18n', async (importOriginal) => {
|
||||
@@ -135,10 +147,10 @@ vi.mock('vue-i18n', async (importOriginal) => {
|
||||
describe('useSubscriptionCheckout', () => {
|
||||
let emit: ReturnType<typeof vi.fn>
|
||||
|
||||
async function setup() {
|
||||
async function setup(paymentIntentSource?: PaymentIntentSource) {
|
||||
const { useSubscriptionCheckout } =
|
||||
await import('./useSubscriptionCheckout')
|
||||
return useSubscriptionCheckout(emit as never)
|
||||
return useSubscriptionCheckout(emit as never, paymentIntentSource)
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
@@ -146,6 +158,7 @@ describe('useSubscriptionCheckout', () => {
|
||||
vi.clearAllMocks()
|
||||
mockPlans.value = allPlans()
|
||||
mockStartOperation.mockResolvedValue({ status: 'succeeded' })
|
||||
mockUserId.value = 'user-1'
|
||||
emit = vi.fn()
|
||||
})
|
||||
|
||||
@@ -459,6 +472,13 @@ describe('useSubscriptionCheckout', () => {
|
||||
cancelUrl: 'https://platform.comfy.org/payment/failed'
|
||||
})
|
||||
expect(checkout.checkoutStep.value).toBe('success')
|
||||
expect(mockTrackBeginCheckout).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
tier: 'team',
|
||||
checkout_type: 'new',
|
||||
billing_op_id: 'op-team-1'
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('uses the annual plan slug for the yearly cycle', async () => {
|
||||
@@ -553,6 +573,39 @@ describe('useSubscriptionCheckout', () => {
|
||||
detail: 'Team payment failed'
|
||||
})
|
||||
)
|
||||
expect(mockTrackBeginCheckout).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('keeps team checkout_type as change when the preview request fails', async () => {
|
||||
const checkout = await setup()
|
||||
mockPreviewSubscribe.mockRejectedValueOnce(new Error('not supported'))
|
||||
await checkout.handleSubscribeTeamClick({
|
||||
stop: {
|
||||
id: 'team_1400',
|
||||
usd: 1400,
|
||||
credits: 295_400,
|
||||
discountedUsd: 1295
|
||||
},
|
||||
billingCycle: 'monthly',
|
||||
isChange: true
|
||||
})
|
||||
mockSubscribe.mockResolvedValueOnce({
|
||||
status: 'subscribed',
|
||||
billing_op_id: 'op-team-change'
|
||||
})
|
||||
mockFetchStatus.mockResolvedValueOnce(undefined)
|
||||
mockFetchBalance.mockResolvedValueOnce(undefined)
|
||||
|
||||
await checkout.handleTeamSubscribe()
|
||||
|
||||
expect(mockTrackBeginCheckout).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
tier: 'team',
|
||||
cycle: 'monthly',
|
||||
checkout_type: 'change',
|
||||
billing_op_id: 'op-team-change'
|
||||
})
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -603,6 +656,47 @@ describe('useSubscriptionCheckout', () => {
|
||||
expect(checkout.checkoutStep.value).toBe('success')
|
||||
})
|
||||
|
||||
it('skips begin_checkout when no user id is available', async () => {
|
||||
mockUserId.value = null
|
||||
const checkout = await setup('subscribe_to_run')
|
||||
checkout.selectedTierKey.value = 'standard'
|
||||
checkout.selectedBillingCycle.value = 'yearly'
|
||||
mockSubscribe.mockResolvedValueOnce({
|
||||
status: 'subscribed',
|
||||
billing_op_id: 'op-1'
|
||||
})
|
||||
mockFetchStatus.mockResolvedValueOnce(undefined)
|
||||
mockFetchBalance.mockResolvedValueOnce(undefined)
|
||||
|
||||
await checkout.handleAddCreditCard()
|
||||
|
||||
expect(mockTrackBeginCheckout).not.toHaveBeenCalled()
|
||||
mockUserId.value = 'user-1'
|
||||
})
|
||||
|
||||
it('fires begin_checkout carrying the payment intent source', async () => {
|
||||
const checkout = await setup('subscribe_to_run')
|
||||
checkout.selectedTierKey.value = 'standard'
|
||||
checkout.selectedBillingCycle.value = 'yearly'
|
||||
mockSubscribe.mockResolvedValueOnce({
|
||||
status: 'subscribed',
|
||||
billing_op_id: 'op-1'
|
||||
})
|
||||
mockFetchStatus.mockResolvedValueOnce(undefined)
|
||||
mockFetchBalance.mockResolvedValueOnce(undefined)
|
||||
|
||||
await checkout.handleAddCreditCard()
|
||||
|
||||
expect(mockTrackBeginCheckout).toHaveBeenCalledWith({
|
||||
user_id: 'user-1',
|
||||
tier: 'standard',
|
||||
cycle: 'yearly',
|
||||
checkout_type: 'new',
|
||||
billing_op_id: 'op-1',
|
||||
payment_intent_source: 'subscribe_to_run'
|
||||
})
|
||||
})
|
||||
|
||||
it('opens payment URL when needs_payment_method', async () => {
|
||||
const checkout = await setup()
|
||||
checkout.selectedTierKey.value = 'standard'
|
||||
@@ -720,6 +814,7 @@ describe('useSubscriptionCheckout', () => {
|
||||
detail: 'Payment failed'
|
||||
})
|
||||
)
|
||||
expect(mockTrackBeginCheckout).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
|
||||
@@ -9,16 +9,26 @@ import type { TeamPlanSelection } from '@/platform/cloud/subscription/constants/
|
||||
import type { TierKey } from '@/platform/cloud/subscription/constants/tierPricing'
|
||||
import type { BillingCycle } from '@/platform/cloud/subscription/utils/subscriptionTierRank'
|
||||
import { useTelemetry } from '@/platform/telemetry'
|
||||
import type {
|
||||
PaymentIntentSource,
|
||||
SubscriptionCheckoutType
|
||||
} from '@/platform/telemetry/types'
|
||||
import type {
|
||||
Plan,
|
||||
PreviewSubscribeResponse,
|
||||
SubscribeResponse
|
||||
} from '@/platform/workspace/api/workspaceApi'
|
||||
import { useBillingOperationStore } from '@/platform/workspace/stores/billingOperationStore'
|
||||
import { trackWorkspaceCheckoutStarted } from '@/platform/workspace/utils/workspaceCheckoutTelemetry'
|
||||
|
||||
type CheckoutStep = 'pricing' | 'preview' | 'success'
|
||||
type CheckoutTierKey = Exclude<TierKey, 'free' | 'founder'>
|
||||
|
||||
interface SelectedTeamCheckout {
|
||||
stop: TeamPlanSelection
|
||||
checkoutType: SubscriptionCheckoutType
|
||||
}
|
||||
|
||||
/**
|
||||
* Which screen the `preview` step shows. Only a change prorates: a team change
|
||||
* carries `previewData` (handleSubscribeTeamClick sets it solely for an immediate
|
||||
@@ -45,9 +55,12 @@ export function findPlanSlug(
|
||||
return plan?.slug ?? null
|
||||
}
|
||||
|
||||
export function useSubscriptionCheckout(emit: {
|
||||
(e: 'close', subscribed: boolean): void
|
||||
}) {
|
||||
export function useSubscriptionCheckout(
|
||||
emit: {
|
||||
(e: 'close', subscribed: boolean): void
|
||||
},
|
||||
paymentIntentSource?: PaymentIntentSource
|
||||
) {
|
||||
const { t } = useI18n()
|
||||
const toast = useToast()
|
||||
const {
|
||||
@@ -68,13 +81,16 @@ export function useSubscriptionCheckout(emit: {
|
||||
const isResubscribing = ref(false)
|
||||
const previewData = ref<PreviewSubscribeResponse | null>(null)
|
||||
const selectedTierKey = ref<CheckoutTierKey | null>(null)
|
||||
const selectedTeamStop = ref<TeamPlanSelection | null>(null)
|
||||
const selectedTeamCheckout = ref<SelectedTeamCheckout | null>(null)
|
||||
const selectedBillingCycle = ref<BillingCycle>('yearly')
|
||||
const isPolling = computed(() => billingOperationStore.hasPendingOperations)
|
||||
const isTeamCheckout = computed(() => selectedTeamStop.value !== null)
|
||||
const selectedTeamStop = computed(
|
||||
() => selectedTeamCheckout.value?.stop ?? null
|
||||
)
|
||||
const isTeamCheckout = computed(() => selectedTeamCheckout.value !== null)
|
||||
|
||||
const previewVariant = computed<PreviewVariant>(() => {
|
||||
if (selectedTeamStop.value) {
|
||||
if (selectedTeamCheckout.value) {
|
||||
return previewData.value ? 'team-change' : 'team-new'
|
||||
}
|
||||
if (previewData.value) {
|
||||
@@ -154,7 +170,10 @@ export function useSubscriptionCheckout(emit: {
|
||||
billingCycle: BillingCycle
|
||||
isChange?: boolean
|
||||
}) {
|
||||
selectedTeamStop.value = payload.stop
|
||||
selectedTeamCheckout.value = {
|
||||
stop: payload.stop,
|
||||
checkoutType: payload.isChange ? 'change' : 'new'
|
||||
}
|
||||
selectedBillingCycle.value = payload.billingCycle
|
||||
selectedTierKey.value = null
|
||||
previewData.value = null
|
||||
@@ -182,7 +201,7 @@ export function useSubscriptionCheckout(emit: {
|
||||
function handleBackToPricing() {
|
||||
checkoutStep.value = 'pricing'
|
||||
previewData.value = null
|
||||
selectedTeamStop.value = null
|
||||
selectedTeamCheckout.value = null
|
||||
}
|
||||
|
||||
function handleSuccessClose() {
|
||||
@@ -190,20 +209,34 @@ export function useSubscriptionCheckout(emit: {
|
||||
}
|
||||
|
||||
async function handleSubscription() {
|
||||
if (!selectedTierKey.value) return
|
||||
const tierKey = selectedTierKey.value
|
||||
if (!tierKey) return
|
||||
|
||||
const billingCycle = selectedBillingCycle.value
|
||||
const checkoutType =
|
||||
previewData.value &&
|
||||
previewData.value.transition_type !== 'new_subscription'
|
||||
? 'change'
|
||||
: 'new'
|
||||
|
||||
isSubscribing.value = true
|
||||
try {
|
||||
const planSlug = getApiPlanSlug(
|
||||
selectedTierKey.value,
|
||||
selectedBillingCycle.value
|
||||
)
|
||||
const planSlug = getApiPlanSlug(tierKey, billingCycle)
|
||||
if (!planSlug) return
|
||||
const response = await subscribe(planSlug, {
|
||||
returnUrl: `${getComfyPlatformBaseUrl()}/payment/success`,
|
||||
cancelUrl: `${getComfyPlatformBaseUrl()}/payment/failed`
|
||||
})
|
||||
|
||||
if (response) {
|
||||
trackWorkspaceCheckoutStarted({
|
||||
tier: tierKey,
|
||||
cycle: billingCycle,
|
||||
checkoutType,
|
||||
billingOpId: response.billing_op_id,
|
||||
paymentIntentSource
|
||||
})
|
||||
}
|
||||
await handleSubscribeResponse(response)
|
||||
} catch (error) {
|
||||
showSubscribeError(error)
|
||||
@@ -269,8 +302,8 @@ export function useSubscriptionCheckout(emit: {
|
||||
}
|
||||
|
||||
async function handleTeamSubscription() {
|
||||
const stop = selectedTeamStop.value
|
||||
if (!stop?.id) {
|
||||
const teamCheckout = selectedTeamCheckout.value
|
||||
if (!teamCheckout?.stop.id) {
|
||||
toast.add({
|
||||
severity: 'error',
|
||||
summary: t('subscription.teamPlan.name'),
|
||||
@@ -279,16 +312,28 @@ export function useSubscriptionCheckout(emit: {
|
||||
return
|
||||
}
|
||||
|
||||
const { stop, checkoutType } = teamCheckout
|
||||
const billingCycle = selectedBillingCycle.value
|
||||
|
||||
isSubscribing.value = true
|
||||
try {
|
||||
const planSlug = getTeamPlanSlug(selectedBillingCycle.value)
|
||||
const planSlug = getTeamPlanSlug(billingCycle)
|
||||
const response = await subscribe(planSlug, {
|
||||
teamCreditStopId: stop.id,
|
||||
billingCycle: selectedBillingCycle.value,
|
||||
billingCycle,
|
||||
returnUrl: `${getComfyPlatformBaseUrl()}/payment/success`,
|
||||
cancelUrl: `${getComfyPlatformBaseUrl()}/payment/failed`
|
||||
})
|
||||
|
||||
if (response) {
|
||||
trackWorkspaceCheckoutStarted({
|
||||
tier: 'team',
|
||||
cycle: billingCycle,
|
||||
checkoutType,
|
||||
billingOpId: response.billing_op_id,
|
||||
paymentIntentSource
|
||||
})
|
||||
}
|
||||
await handleSubscribeResponse(response)
|
||||
} catch (error) {
|
||||
showSubscribeError(error)
|
||||
|
||||
@@ -2,6 +2,7 @@ import { computed, ref, shallowRef } from 'vue'
|
||||
|
||||
import { useBillingPlans } from '@/platform/cloud/subscription/composables/useBillingPlans'
|
||||
import { useSubscriptionDialog } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type { SubscriptionDialogOptions } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type {
|
||||
BillingBalanceResponse,
|
||||
BillingStatusResponse,
|
||||
@@ -275,12 +276,12 @@ export function useWorkspaceBilling(): BillingState & BillingActions {
|
||||
async function requireActiveSubscription(): Promise<void> {
|
||||
await fetchStatus()
|
||||
if (!isActiveSubscription.value) {
|
||||
subscriptionDialog.show()
|
||||
subscriptionDialog.show({ reason: 'subscription_required' })
|
||||
}
|
||||
}
|
||||
|
||||
function showSubscriptionDialog(): void {
|
||||
subscriptionDialog.show()
|
||||
function showSubscriptionDialog(options?: SubscriptionDialogOptions): void {
|
||||
subscriptionDialog.show(options)
|
||||
}
|
||||
|
||||
return {
|
||||
|
||||
38
src/platform/workspace/utils/workspaceCheckoutTelemetry.ts
Normal file
38
src/platform/workspace/utils/workspaceCheckoutTelemetry.ts
Normal file
@@ -0,0 +1,38 @@
|
||||
import { useTelemetry } from '@/platform/telemetry'
|
||||
import type {
|
||||
PaymentIntentSource,
|
||||
SubscriptionCheckoutTier,
|
||||
SubscriptionCheckoutType
|
||||
} from '@/platform/telemetry/types'
|
||||
import type { BillingCycle } from '@/platform/cloud/subscription/utils/subscriptionTierRank'
|
||||
import { useAuthStore } from '@/stores/authStore'
|
||||
|
||||
interface TrackWorkspaceCheckoutStartedOptions {
|
||||
tier: SubscriptionCheckoutTier
|
||||
cycle: BillingCycle
|
||||
checkoutType: SubscriptionCheckoutType
|
||||
billingOpId: string
|
||||
paymentIntentSource?: PaymentIntentSource
|
||||
}
|
||||
|
||||
export function trackWorkspaceCheckoutStarted({
|
||||
tier,
|
||||
cycle,
|
||||
checkoutType,
|
||||
billingOpId,
|
||||
paymentIntentSource
|
||||
}: TrackWorkspaceCheckoutStartedOptions) {
|
||||
const { userId } = useAuthStore()
|
||||
if (!userId) return
|
||||
|
||||
useTelemetry()?.trackBeginCheckout({
|
||||
user_id: userId,
|
||||
tier,
|
||||
cycle,
|
||||
checkout_type: checkoutType,
|
||||
billing_op_id: billingOpId,
|
||||
...(paymentIntentSource
|
||||
? { payment_intent_source: paymentIntentSource }
|
||||
: {})
|
||||
})
|
||||
}
|
||||
208
src/renderer/extensions/linearMode/LinearControls.test.ts
Normal file
208
src/renderer/extensions/linearMode/LinearControls.test.ts
Normal file
@@ -0,0 +1,208 @@
|
||||
import { createTestingPinia } from '@pinia/testing'
|
||||
import { render, screen, within } from '@testing-library/vue'
|
||||
import { setActivePinia } from 'pinia'
|
||||
import { createI18n } from 'vue-i18n'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import type { NodeError } from '@/schemas/apiSchema'
|
||||
import LinearControls from '@/renderer/extensions/linearMode/LinearControls.vue'
|
||||
import { LINEAR_RUN_ERROR_WARNING_DESCRIPTION_ID } from '@/renderer/extensions/linearMode/linearRunErrorWarningIds'
|
||||
import { useAppModeStore } from '@/stores/appModeStore'
|
||||
import { useExecutionErrorStore } from '@/stores/executionErrorStore'
|
||||
import { toNodeId } from '@/types/nodeId'
|
||||
|
||||
const billingMock = vi.hoisted(() => ({
|
||||
isActiveSubscription: true
|
||||
}))
|
||||
|
||||
const overlayMock = vi.hoisted(() => ({
|
||||
overlayMessage: 'KSampler is missing a required input: model',
|
||||
overlayTitle: 'Required input missing'
|
||||
}))
|
||||
|
||||
vi.mock('@/composables/billing/useBillingContext', () => ({
|
||||
useBillingContext: () => ({
|
||||
isActiveSubscription: billingMock.isActiveSubscription
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('@/components/error/useErrorOverlayState', () => ({
|
||||
useErrorOverlayState: () => ({
|
||||
overlayMessage: overlayMock.overlayMessage,
|
||||
overlayTitle: overlayMock.overlayTitle
|
||||
})
|
||||
}))
|
||||
|
||||
const i18n = createI18n({
|
||||
legacy: false,
|
||||
locale: 'en',
|
||||
messages: {
|
||||
en: {
|
||||
linearMode: {
|
||||
error: {
|
||||
goto: 'Show errors in graph'
|
||||
},
|
||||
mobileNoWorkflow: 'No workflow',
|
||||
runCount: 'Run count',
|
||||
viewJob: 'View job'
|
||||
},
|
||||
menu: {
|
||||
run: 'Run'
|
||||
},
|
||||
menuLabels: {
|
||||
publish: 'Publish'
|
||||
},
|
||||
queue: {
|
||||
jobAddedToQueue: 'Job added to queue',
|
||||
jobQueueing: 'Queueing'
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
const nodeErrors: Record<string, NodeError> = {
|
||||
'1': {
|
||||
class_type: 'TestNode',
|
||||
dependent_outputs: [],
|
||||
errors: [
|
||||
{
|
||||
type: 'required_input_missing',
|
||||
message: 'Missing input',
|
||||
details: '',
|
||||
extra_info: { input_name: 'prompt' }
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
function renderControls({
|
||||
hasError = false,
|
||||
isActiveSubscription = true,
|
||||
mobile = false
|
||||
}: {
|
||||
hasError?: boolean
|
||||
isActiveSubscription?: boolean
|
||||
mobile?: boolean
|
||||
} = {}) {
|
||||
billingMock.isActiveSubscription = isActiveSubscription
|
||||
|
||||
const pinia = createTestingPinia({
|
||||
createSpy: vi.fn,
|
||||
stubActions: false
|
||||
})
|
||||
setActivePinia(pinia)
|
||||
|
||||
useAppModeStore().selectedOutputs = [toNodeId(1)]
|
||||
if (hasError) {
|
||||
useExecutionErrorStore().lastNodeErrors = nodeErrors
|
||||
}
|
||||
|
||||
const toastTarget = document.createElement('div')
|
||||
|
||||
return render(LinearControls, {
|
||||
props: { mobile, toastTo: toastTarget },
|
||||
global: {
|
||||
plugins: [pinia, i18n],
|
||||
stubs: {
|
||||
AppModeWidgetList: true,
|
||||
Loader: true,
|
||||
PartnerNodesList: true,
|
||||
Popover: {
|
||||
template: '<div><slot name="button" /><slot /></div>'
|
||||
},
|
||||
ScrubableNumberInput: true,
|
||||
SubscribeToRunButton: true
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
describe('LinearControls', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
billingMock.isActiveSubscription = true
|
||||
overlayMock.overlayMessage = 'KSampler is missing a required input: model'
|
||||
overlayMock.overlayTitle = 'Required input missing'
|
||||
})
|
||||
|
||||
it.for([
|
||||
{ label: 'desktop', mobile: false },
|
||||
{ label: 'mobile', mobile: true }
|
||||
])('shows a workflow error warning in $label controls', ({ mobile }) => {
|
||||
renderControls({ hasError: true, mobile })
|
||||
|
||||
const warning = screen.getByRole('status')
|
||||
expect(
|
||||
within(warning).getByText('Required input missing')
|
||||
).toBeInTheDocument()
|
||||
expect(
|
||||
within(warning).getByText('KSampler is missing a required input: model')
|
||||
).toBeInTheDocument()
|
||||
expect(
|
||||
within(warning).getByRole('button', { name: 'Show errors in graph' })
|
||||
).toBeInTheDocument()
|
||||
expect(within(warning).queryByLabelText('Close')).not.toBeInTheDocument()
|
||||
const runButton = screen.getByRole('button', { name: 'Run' })
|
||||
expect(runButton).toHaveAttribute(
|
||||
'aria-describedby',
|
||||
LINEAR_RUN_ERROR_WARNING_DESCRIPTION_ID
|
||||
)
|
||||
const description = screen.getByTestId(
|
||||
'linear-validation-warning-description'
|
||||
)
|
||||
expect(description).toHaveAttribute(
|
||||
'id',
|
||||
LINEAR_RUN_ERROR_WARNING_DESCRIPTION_ID
|
||||
)
|
||||
expect(description).toHaveTextContent('Required input missing')
|
||||
expect(description).toHaveTextContent(
|
||||
'KSampler is missing a required input: model'
|
||||
)
|
||||
expect(description).not.toHaveTextContent('Show errors in graph')
|
||||
})
|
||||
|
||||
it.for([
|
||||
{ label: 'desktop', mobile: false },
|
||||
{ label: 'mobile', mobile: true }
|
||||
])(
|
||||
'does not show the workflow error warning in $label controls without graph errors',
|
||||
({ mobile }) => {
|
||||
renderControls({ mobile })
|
||||
|
||||
expect(screen.queryByRole('status')).not.toBeInTheDocument()
|
||||
expect(
|
||||
screen.queryByRole('button', { name: 'Show errors in graph' })
|
||||
).not.toBeInTheDocument()
|
||||
expect(screen.getByRole('button', { name: 'Run' })).not.toHaveAttribute(
|
||||
'aria-describedby'
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
it.for([
|
||||
{ label: 'desktop', mobile: false },
|
||||
{ label: 'mobile', mobile: true }
|
||||
])(
|
||||
'does not show the workflow error warning in $label controls without an active subscription',
|
||||
({ mobile }) => {
|
||||
renderControls({
|
||||
hasError: true,
|
||||
isActiveSubscription: false,
|
||||
mobile
|
||||
})
|
||||
|
||||
expect(screen.queryByRole('status')).not.toBeInTheDocument()
|
||||
}
|
||||
)
|
||||
|
||||
it('does not show the warning when the error copy is empty', () => {
|
||||
overlayMock.overlayMessage = ''
|
||||
|
||||
renderControls({ hasError: true })
|
||||
|
||||
expect(screen.queryByRole('status')).not.toBeInTheDocument()
|
||||
expect(screen.getByRole('button', { name: 'Run' })).not.toHaveAttribute(
|
||||
'aria-describedby'
|
||||
)
|
||||
})
|
||||
})
|
||||
@@ -1,10 +1,11 @@
|
||||
<script setup lang="ts">
|
||||
import { useTimeout } from '@vueuse/core'
|
||||
import { storeToRefs } from 'pinia'
|
||||
import { ref, useTemplateRef } from 'vue'
|
||||
import { computed, ref, toValue, useTemplateRef } from 'vue'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
|
||||
import AppModeWidgetList from '@/components/builder/AppModeWidgetList.vue'
|
||||
import { useErrorOverlayState } from '@/components/error/useErrorOverlayState'
|
||||
import Loader from '@/components/loader/Loader.vue'
|
||||
import ScrubableNumberInput from '@/components/common/ScrubableNumberInput.vue'
|
||||
import Popover from '@/components/ui/Popover.vue'
|
||||
@@ -14,11 +15,15 @@ import SubscribeToRunButton from '@/platform/cloud/subscription/components/Subsc
|
||||
import { useSettingStore } from '@/platform/settings/settingStore'
|
||||
import { useTelemetry } from '@/platform/telemetry'
|
||||
import { useWorkflowStore } from '@/platform/workflow/management/stores/workflowStore'
|
||||
import LinearRunErrorWarning from '@/renderer/extensions/linearMode/LinearRunErrorWarning.vue'
|
||||
import { LINEAR_RUN_ERROR_WARNING_DESCRIPTION_ID } from '@/renderer/extensions/linearMode/linearRunErrorWarningIds'
|
||||
import PartnerNodesList from '@/renderer/extensions/linearMode/PartnerNodesList.vue'
|
||||
import { useCommandStore } from '@/stores/commandStore'
|
||||
import { useQueueSettingsStore } from '@/stores/queueStore'
|
||||
import { useAppMode } from '@/composables/useAppMode'
|
||||
import { useAppModeStore } from '@/stores/appModeStore'
|
||||
import { useExecutionErrorStore } from '@/stores/executionErrorStore'
|
||||
|
||||
const { t } = useI18n()
|
||||
const commandStore = useCommandStore()
|
||||
const { batchCount } = storeToRefs(useQueueSettingsStore())
|
||||
@@ -28,6 +33,8 @@ const workflowStore = useWorkflowStore()
|
||||
const { isBuilderMode } = useAppMode()
|
||||
const appModeStore = useAppModeStore()
|
||||
const { hasOutputs } = storeToRefs(appModeStore)
|
||||
const { hasAnyError } = storeToRefs(useExecutionErrorStore())
|
||||
const { overlayMessage } = useErrorOverlayState()
|
||||
|
||||
const { toastTo, mobile } = defineProps<{
|
||||
toastTo?: string | HTMLElement
|
||||
@@ -43,6 +50,13 @@ const { ready: jobToastTimeout, start: resetJobToastTimeout } = useTimeout(
|
||||
{ controls: true, immediate: false }
|
||||
)
|
||||
const widgetListRef = useTemplateRef('widgetListRef')
|
||||
const linearRunButtonTestId = 'linear-run-button'
|
||||
const showRunErrorWarning = computed(
|
||||
() =>
|
||||
hasAnyError.value &&
|
||||
toValue(isActiveSubscription) &&
|
||||
toValue(overlayMessage).trim().length > 0
|
||||
)
|
||||
|
||||
//TODO: refactor out of this file.
|
||||
//code length is small, but changes should propagate
|
||||
@@ -134,9 +148,10 @@ function handleDragDrop() {
|
||||
<PartnerNodesList v-if="!mobile" />
|
||||
<section
|
||||
v-if="mobile"
|
||||
data-testid="linear-run-button"
|
||||
:data-testid="linearRunButtonTestId"
|
||||
class="border-t border-node-component-border p-4 pb-6"
|
||||
>
|
||||
<LinearRunErrorWarning v-if="showRunErrorWarning" />
|
||||
<SubscribeToRunButton
|
||||
v-if="!isActiveSubscription"
|
||||
class="mt-4 w-full"
|
||||
@@ -166,18 +181,24 @@ function handleDragDrop() {
|
||||
variant="primary"
|
||||
class="grow"
|
||||
size="lg"
|
||||
:aria-describedby="
|
||||
showRunErrorWarning
|
||||
? LINEAR_RUN_ERROR_WARNING_DESCRIPTION_ID
|
||||
: undefined
|
||||
"
|
||||
@click="runButtonClick"
|
||||
>
|
||||
<i class="icon-[lucide--play]" />
|
||||
<i aria-hidden="true" class="icon-[lucide--play]" />
|
||||
{{ t('menu.run') }}
|
||||
</Button>
|
||||
</div>
|
||||
</section>
|
||||
<section
|
||||
v-else
|
||||
data-testid="linear-run-button"
|
||||
:data-testid="linearRunButtonTestId"
|
||||
class="border-t border-node-component-border p-4 pb-6"
|
||||
>
|
||||
<LinearRunErrorWarning v-if="showRunErrorWarning" />
|
||||
<div
|
||||
class="m-1 mb-2 text-node-component-slot-text"
|
||||
v-text="t('linearMode.runCount')"
|
||||
@@ -198,9 +219,14 @@ function handleDragDrop() {
|
||||
variant="primary"
|
||||
class="mt-4 w-full text-sm"
|
||||
size="lg"
|
||||
:aria-describedby="
|
||||
showRunErrorWarning
|
||||
? LINEAR_RUN_ERROR_WARNING_DESCRIPTION_ID
|
||||
: undefined
|
||||
"
|
||||
@click="runButtonClick"
|
||||
>
|
||||
<i class="icon-[lucide--play]" />
|
||||
<i aria-hidden="true" class="icon-[lucide--play]" />
|
||||
{{ t('menu.run') }}
|
||||
</Button>
|
||||
</section>
|
||||
|
||||
@@ -0,0 +1,92 @@
|
||||
import { render, screen } from '@testing-library/vue'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import { createI18n } from 'vue-i18n'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import LinearRunErrorWarning from '@/renderer/extensions/linearMode/LinearRunErrorWarning.vue'
|
||||
import { LINEAR_RUN_ERROR_WARNING_DESCRIPTION_ID } from '@/renderer/extensions/linearMode/linearRunErrorWarningIds'
|
||||
|
||||
const mocks = vi.hoisted(() => ({
|
||||
overlayMessage: 'KSampler is missing a required input: model',
|
||||
overlayTitle: 'Required input missing',
|
||||
viewErrorsInGraph: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('@/components/error/useErrorOverlayState', () => ({
|
||||
useErrorOverlayState: () => ({
|
||||
overlayMessage: mocks.overlayMessage,
|
||||
overlayTitle: mocks.overlayTitle
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('@/composables/useViewErrorsInGraph', () => ({
|
||||
useViewErrorsInGraph: () => ({
|
||||
viewErrorsInGraph: mocks.viewErrorsInGraph
|
||||
})
|
||||
}))
|
||||
|
||||
const i18n = createI18n({
|
||||
legacy: false,
|
||||
locale: 'en',
|
||||
messages: {
|
||||
en: {
|
||||
linearMode: {
|
||||
error: {
|
||||
goto: 'Show errors in graph'
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
function renderWarning() {
|
||||
const user = userEvent.setup()
|
||||
const result = render(LinearRunErrorWarning, {
|
||||
global: { plugins: [i18n] }
|
||||
})
|
||||
|
||||
return { ...result, user }
|
||||
}
|
||||
|
||||
describe('LinearRunErrorWarning', () => {
|
||||
beforeEach(() => {
|
||||
mocks.viewErrorsInGraph.mockReset()
|
||||
})
|
||||
|
||||
it('shows the current error overlay title and message without a close action', () => {
|
||||
renderWarning()
|
||||
|
||||
const warning = screen.getByRole('status')
|
||||
expect(warning).toHaveTextContent('Required input missing')
|
||||
expect(warning).toHaveTextContent(
|
||||
'KSampler is missing a required input: model'
|
||||
)
|
||||
expect(screen.getByText('Required input missing')).toHaveAttribute(
|
||||
'title',
|
||||
'Required input missing'
|
||||
)
|
||||
const description = screen.getByTestId(
|
||||
'linear-validation-warning-description'
|
||||
)
|
||||
expect(description).toHaveAttribute(
|
||||
'id',
|
||||
LINEAR_RUN_ERROR_WARNING_DESCRIPTION_ID
|
||||
)
|
||||
expect(description).toHaveTextContent('Required input missing')
|
||||
expect(description).toHaveTextContent(
|
||||
'KSampler is missing a required input: model'
|
||||
)
|
||||
expect(description).not.toHaveTextContent('Show errors in graph')
|
||||
expect(screen.queryByLabelText('Close')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('opens graph errors when the action is clicked', async () => {
|
||||
const { user } = renderWarning()
|
||||
|
||||
await user.click(
|
||||
screen.getByRole('button', { name: 'Show errors in graph' })
|
||||
)
|
||||
|
||||
expect(mocks.viewErrorsInGraph).toHaveBeenCalledOnce()
|
||||
})
|
||||
})
|
||||
63
src/renderer/extensions/linearMode/LinearRunErrorWarning.vue
Normal file
63
src/renderer/extensions/linearMode/LinearRunErrorWarning.vue
Normal file
@@ -0,0 +1,63 @@
|
||||
<script setup lang="ts">
|
||||
import { useI18n } from 'vue-i18n'
|
||||
|
||||
import Button from '@/components/ui/button/Button.vue'
|
||||
import { useErrorOverlayState } from '@/components/error/useErrorOverlayState'
|
||||
import { useViewErrorsInGraph } from '@/composables/useViewErrorsInGraph'
|
||||
import { LINEAR_RUN_ERROR_WARNING_DESCRIPTION_ID } from '@/renderer/extensions/linearMode/linearRunErrorWarningIds'
|
||||
|
||||
const { t } = useI18n()
|
||||
const { viewErrorsInGraph } = useViewErrorsInGraph()
|
||||
const { overlayMessage, overlayTitle } = useErrorOverlayState()
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<div
|
||||
role="status"
|
||||
data-testid="linear-validation-warning"
|
||||
class="mb-3 flex w-full flex-col gap-2 overflow-hidden rounded-lg border border-l-4 border-border-default border-l-destructive-background bg-base-background p-3 shadow-interface transition-colors duration-200 ease-in-out"
|
||||
>
|
||||
<div
|
||||
:id="LINEAR_RUN_ERROR_WARNING_DESCRIPTION_ID"
|
||||
data-testid="linear-validation-warning-description"
|
||||
class="flex flex-col gap-2"
|
||||
>
|
||||
<div class="flex w-full items-start gap-2">
|
||||
<i
|
||||
aria-hidden="true"
|
||||
class="mt-0.5 icon-[lucide--circle-x] size-4 shrink-0 text-destructive-background"
|
||||
/>
|
||||
<span
|
||||
class="min-w-0 flex-1 truncate text-sm text-base-foreground"
|
||||
:title="overlayTitle"
|
||||
>
|
||||
{{ overlayTitle }}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<div
|
||||
class="flex w-full items-start gap-2"
|
||||
data-testid="linear-validation-warning-message"
|
||||
>
|
||||
<span class="size-4 shrink-0" aria-hidden="true" />
|
||||
<p
|
||||
class="m-0 line-clamp-3 min-w-0 flex-1 text-sm/snug wrap-break-word whitespace-pre-wrap text-muted-foreground"
|
||||
>
|
||||
{{ overlayMessage }}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="flex w-full items-center justify-end pt-2">
|
||||
<Button
|
||||
variant="secondary"
|
||||
size="unset"
|
||||
class="min-h-8 rounded-lg px-3 py-2 text-xs font-normal"
|
||||
data-testid="linear-view-errors"
|
||||
@click="viewErrorsInGraph"
|
||||
>
|
||||
{{ t('linearMode.error.goto') }}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
@@ -0,0 +1,2 @@
|
||||
export const LINEAR_RUN_ERROR_WARNING_DESCRIPTION_ID =
|
||||
'linear-run-error-warning'
|
||||
57
src/schemas/nodeDefSchema.test.ts
Normal file
57
src/schemas/nodeDefSchema.test.ts
Normal file
@@ -0,0 +1,57 @@
|
||||
import { describe, expect, it } from 'vitest'
|
||||
|
||||
import {
|
||||
getComboSpecComboOptions,
|
||||
getInputSpecType,
|
||||
isComboInputSpec,
|
||||
isComboInputSpecV1,
|
||||
isComboInputSpecV2,
|
||||
isFloatInputSpec,
|
||||
isIntInputSpec,
|
||||
isMediaUploadComboInput
|
||||
} from './nodeDefSchema'
|
||||
import type {
|
||||
ComboInputSpec,
|
||||
ComboInputSpecV2,
|
||||
InputSpec
|
||||
} from './nodeDefSchema'
|
||||
|
||||
describe('node definition schema helpers', () => {
|
||||
it('identifies input spec variants', () => {
|
||||
const intSpec: InputSpec = ['INT', {}]
|
||||
const floatSpec: InputSpec = ['FLOAT', {}]
|
||||
const comboV1: ComboInputSpec = [['a', 'b'], {}]
|
||||
const comboV2: ComboInputSpecV2 = ['COMBO', { options: ['a', 'b'] }]
|
||||
|
||||
expect(isIntInputSpec(intSpec)).toBe(true)
|
||||
expect(isFloatInputSpec(floatSpec)).toBe(true)
|
||||
expect(isComboInputSpecV1(comboV1)).toBe(true)
|
||||
expect(isComboInputSpecV2(comboV2)).toBe(true)
|
||||
expect(isComboInputSpec(comboV1)).toBe(true)
|
||||
expect(isComboInputSpec(comboV2)).toBe(true)
|
||||
expect(getInputSpecType(comboV1)).toBe('COMBO')
|
||||
expect(getInputSpecType(intSpec)).toBe('INT')
|
||||
})
|
||||
|
||||
it('reads combo options from legacy and v2 combo specs', () => {
|
||||
expect(getComboSpecComboOptions([['a', 1], {}])).toEqual(['a', 1])
|
||||
expect(
|
||||
getComboSpecComboOptions(['COMBO', { options: ['x', 'y'] }])
|
||||
).toEqual(['x', 'y'])
|
||||
expect(getComboSpecComboOptions(['COMBO', {}])).toEqual([])
|
||||
})
|
||||
|
||||
it('detects media upload combo inputs', () => {
|
||||
expect(isMediaUploadComboInput([['a'], { image_upload: true }])).toBe(true)
|
||||
expect(
|
||||
isMediaUploadComboInput(['COMBO', { animated_image_upload: true }])
|
||||
).toBe(true)
|
||||
expect(isMediaUploadComboInput(['COMBO', { video_upload: true }])).toBe(
|
||||
true
|
||||
)
|
||||
expect(isMediaUploadComboInput(['STRING', { image_upload: true }])).toBe(
|
||||
false
|
||||
)
|
||||
expect(isMediaUploadComboInput(['COMBO', undefined])).toBe(false)
|
||||
})
|
||||
})
|
||||
149
src/scripts/api.cloud.test.ts
Normal file
149
src/scripts/api.cloud.test.ts
Normal file
@@ -0,0 +1,149 @@
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import {
|
||||
fetchWithUnifiedRemint,
|
||||
shouldRemintCloudRequest
|
||||
} from '@/platform/auth/unified/remintRetry'
|
||||
|
||||
const { mockAuthStore } = vi.hoisted(() => ({
|
||||
mockAuthStore: {
|
||||
isInitialized: true,
|
||||
getAuthHeader: vi.fn(),
|
||||
getAuthToken: vi.fn()
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/distribution/types', () => ({ isCloud: true }))
|
||||
|
||||
vi.mock('@/stores/authStore', () => ({
|
||||
useAuthStore: vi.fn(() => mockAuthStore)
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/auth/unified/remintRetry', () => ({
|
||||
fetchWithUnifiedRemint: vi.fn(),
|
||||
shouldRemintCloudRequest: vi.fn()
|
||||
}))
|
||||
|
||||
class FakeWebSocket extends EventTarget {
|
||||
static instances: FakeWebSocket[] = []
|
||||
|
||||
binaryType = ''
|
||||
sent: string[] = []
|
||||
|
||||
constructor(readonly url: string) {
|
||||
super()
|
||||
FakeWebSocket.instances.push(this)
|
||||
}
|
||||
|
||||
send(data: string) {
|
||||
this.sent.push(data)
|
||||
}
|
||||
|
||||
close() {
|
||||
this.dispatchEvent(new Event('close'))
|
||||
}
|
||||
}
|
||||
|
||||
const { ComfyApi } = await import('./api')
|
||||
|
||||
describe('ComfyApi cloud mode', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
vi.unstubAllGlobals()
|
||||
FakeWebSocket.instances = []
|
||||
window.name = ''
|
||||
sessionStorage.clear()
|
||||
mockAuthStore.isInitialized = true
|
||||
mockAuthStore.getAuthHeader.mockResolvedValue(null)
|
||||
mockAuthStore.getAuthToken.mockResolvedValue(null)
|
||||
vi.mocked(shouldRemintCloudRequest).mockResolvedValue(false)
|
||||
vi.mocked(fetchWithUnifiedRemint).mockResolvedValue(
|
||||
new Response(JSON.stringify({ ok: true }), {
|
||||
headers: { 'Content-Type': 'application/json' }
|
||||
})
|
||||
)
|
||||
vi.stubGlobal('WebSocket', FakeWebSocket as unknown as typeof WebSocket)
|
||||
})
|
||||
|
||||
it('adds cloud auth headers and enables unified retry for authenticated requests', async () => {
|
||||
mockAuthStore.getAuthHeader.mockResolvedValue({
|
||||
Authorization: 'Bearer firebase-token'
|
||||
})
|
||||
vi.mocked(shouldRemintCloudRequest).mockResolvedValue(true)
|
||||
const api = new ComfyApi()
|
||||
api.user = 'cloud-user'
|
||||
|
||||
await api.fetchApi('/queue')
|
||||
|
||||
expect(api.api_base).toBe('')
|
||||
expect(fetchWithUnifiedRemint).toHaveBeenCalledWith(
|
||||
'/api/queue',
|
||||
expect.objectContaining({
|
||||
cache: 'no-cache',
|
||||
headers: {
|
||||
Authorization: 'Bearer firebase-token',
|
||||
'Comfy-User': 'cloud-user'
|
||||
}
|
||||
}),
|
||||
true
|
||||
)
|
||||
})
|
||||
|
||||
it('continues cloud fetches when auth header lookup fails', async () => {
|
||||
mockAuthStore.getAuthHeader.mockRejectedValue(new Error('auth unavailable'))
|
||||
const warn = vi.spyOn(console, 'warn').mockImplementation(() => {})
|
||||
const api = new ComfyApi()
|
||||
|
||||
await api.fetchApi('/history', {
|
||||
headers: [['X-Test', '1']]
|
||||
})
|
||||
|
||||
const [, options, retryOn401] = vi.mocked(fetchWithUnifiedRemint).mock
|
||||
.calls[0]
|
||||
expect(options.headers).toEqual([
|
||||
['X-Test', '1'],
|
||||
['Comfy-User', '']
|
||||
])
|
||||
expect(retryOn401).toBe(false)
|
||||
expect(shouldRemintCloudRequest).not.toHaveBeenCalled()
|
||||
expect(warn).toHaveBeenCalledWith(
|
||||
'Failed to get auth header:',
|
||||
expect.any(Error)
|
||||
)
|
||||
})
|
||||
|
||||
it('adds the cloud auth token to websocket URLs', async () => {
|
||||
mockAuthStore.getAuthToken.mockResolvedValue('socket-token')
|
||||
window.name = 'client-1'
|
||||
const api = new ComfyApi()
|
||||
|
||||
api.init()
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(FakeWebSocket.instances).toHaveLength(1)
|
||||
})
|
||||
const socket = FakeWebSocket.instances[0]
|
||||
|
||||
expect(socket.url).toContain('clientId=client-1')
|
||||
expect(socket.url).toContain('token=socket-token')
|
||||
})
|
||||
|
||||
it('opens a cloud websocket without a token when token lookup fails', async () => {
|
||||
mockAuthStore.getAuthToken.mockRejectedValue(new Error('token unavailable'))
|
||||
const warn = vi.spyOn(console, 'warn').mockImplementation(() => {})
|
||||
const api = new ComfyApi()
|
||||
|
||||
api.init()
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(FakeWebSocket.instances).toHaveLength(1)
|
||||
})
|
||||
const socket = FakeWebSocket.instances[0]
|
||||
|
||||
expect(socket.url).not.toContain('token=')
|
||||
expect(warn).toHaveBeenCalledWith(
|
||||
'Could not get auth token for WebSocket connection:',
|
||||
expect.any(Error)
|
||||
)
|
||||
})
|
||||
})
|
||||
460
src/scripts/api.core.test.ts
Normal file
460
src/scripts/api.core.test.ts
Normal file
@@ -0,0 +1,460 @@
|
||||
import axios from 'axios'
|
||||
import { fromPartial } from '@total-typescript/shoehorn'
|
||||
import { afterEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import type { JobListItem } from '@/platform/remote/comfyui/jobs/jobTypes'
|
||||
import type {
|
||||
ComfyApiWorkflow,
|
||||
ComfyWorkflowJSON
|
||||
} from '@/platform/workflow/validation/schemas/workflowSchema'
|
||||
import type { PromptResponse } from '@/schemas/apiSchema'
|
||||
import { api as sharedApi, ComfyApi, PromptExecutionError } from '@/scripts/api'
|
||||
import type { NodeExecutionId } from '@/types/nodeIdentification'
|
||||
|
||||
const fetchJobs = vi.hoisted(() => ({
|
||||
fetchHistory: vi.fn(),
|
||||
fetchJobDetail: vi.fn(),
|
||||
fetchQueue: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('axios')
|
||||
vi.mock('@/platform/remote/comfyui/jobs/fetchJobs', () => fetchJobs)
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks()
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
function jsonResponse(data: unknown, init: ResponseInit = {}) {
|
||||
return new Response(JSON.stringify(data), {
|
||||
status: 200,
|
||||
headers: { 'content-type': 'application/json' },
|
||||
...init
|
||||
})
|
||||
}
|
||||
|
||||
function promptData(): {
|
||||
output: ComfyApiWorkflow
|
||||
workflow: ComfyWorkflowJSON
|
||||
} {
|
||||
return {
|
||||
output: fromPartial<ComfyApiWorkflow>({
|
||||
1: {
|
||||
inputs: {},
|
||||
class_type: 'KSampler',
|
||||
_meta: { title: 'KSampler' }
|
||||
}
|
||||
}),
|
||||
workflow: fromPartial<ComfyWorkflowJSON>({
|
||||
version: 0.4,
|
||||
nodes: [],
|
||||
links: []
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
function requestBody(fetchApi: ReturnType<typeof vi.spyOn>, call = 0) {
|
||||
const init = fetchApi.mock.calls[call][1]
|
||||
return JSON.parse(String(init?.body)) as Record<string, unknown>
|
||||
}
|
||||
|
||||
describe('PromptExecutionError', () => {
|
||||
it('formats string and node-specific prompt errors', () => {
|
||||
const response = fromPartial<PromptResponse>({
|
||||
error: 'invalid prompt',
|
||||
node_errors: {
|
||||
7: {
|
||||
class_type: 'KSampler',
|
||||
dependent_outputs: [],
|
||||
errors: [{ message: 'bad seed', details: 'seed must be numeric' }]
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
expect(new PromptExecutionError(response, 400).toString()).toBe(
|
||||
'invalid prompt\nKSampler:\n - bad seed: seed must be numeric'
|
||||
)
|
||||
})
|
||||
|
||||
it('formats structured prompt errors without node errors', () => {
|
||||
const response = fromPartial<PromptResponse>({
|
||||
error: {
|
||||
type: 'prompt_outputs_failed_validation',
|
||||
message: 'Validation failed',
|
||||
details: 'missing node'
|
||||
}
|
||||
})
|
||||
|
||||
expect(new PromptExecutionError(response).toString()).toBe(
|
||||
'Validation failed: missing node'
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('ComfyApi queuePrompt', () => {
|
||||
it('sends front queue requests with auth and execution options', async () => {
|
||||
const api = new ComfyApi()
|
||||
api.clientId = 'client-1'
|
||||
api.authToken = 'auth-token'
|
||||
api.apiKey = 'api-key'
|
||||
const fetchApi = vi
|
||||
.spyOn(api, 'fetchApi')
|
||||
.mockResolvedValue(jsonResponse({ prompt_id: 'queued' }))
|
||||
|
||||
await api.queuePrompt(-1, promptData(), {
|
||||
partialExecutionTargets: ['9:10' as NodeExecutionId],
|
||||
previewMethod: 'auto'
|
||||
})
|
||||
|
||||
expect(fetchApi).toHaveBeenCalledWith('/prompt', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: fetchApi.mock.calls[0][1]?.body
|
||||
})
|
||||
expect(typeof fetchApi.mock.calls[0][1]?.body).toBe('string')
|
||||
expect(requestBody(fetchApi)).toMatchObject({
|
||||
client_id: 'client-1',
|
||||
front: true,
|
||||
partial_execution_targets: ['9:10'],
|
||||
extra_data: {
|
||||
auth_token_comfy_org: 'auth-token',
|
||||
api_key_comfy_org: 'api-key',
|
||||
comfy_usage_source: 'comfyui-frontend',
|
||||
preview_method: 'auto'
|
||||
}
|
||||
})
|
||||
expect(requestBody(fetchApi)).not.toHaveProperty('number')
|
||||
})
|
||||
|
||||
it('omits default-only queue options and sets explicit queue numbers', async () => {
|
||||
const api = new ComfyApi()
|
||||
const fetchApi = vi
|
||||
.spyOn(api, 'fetchApi')
|
||||
.mockImplementation(() =>
|
||||
Promise.resolve(jsonResponse({ prompt_id: 'queued' }))
|
||||
)
|
||||
|
||||
await api.queuePrompt(0, promptData(), { previewMethod: 'default' })
|
||||
await api.queuePrompt(4, promptData())
|
||||
|
||||
expect(requestBody(fetchApi, 0)).toMatchObject({ client_id: '' })
|
||||
expect(requestBody(fetchApi, 0)).not.toHaveProperty('front')
|
||||
expect(requestBody(fetchApi, 0)).not.toHaveProperty('number')
|
||||
expect(
|
||||
requestBody(fetchApi, 0).extra_data as Record<string, unknown>
|
||||
).not.toHaveProperty('preview_method')
|
||||
expect(requestBody(fetchApi, 1)).toMatchObject({ number: 4 })
|
||||
})
|
||||
|
||||
it('throws parsed prompt errors from non-200 responses', async () => {
|
||||
const api = new ComfyApi()
|
||||
vi.spyOn(api, 'fetchApi').mockResolvedValue(
|
||||
jsonResponse(
|
||||
{
|
||||
error: {
|
||||
type: 'server_error',
|
||||
message: 'Server rejected prompt',
|
||||
details: 'bad output'
|
||||
}
|
||||
},
|
||||
{ status: 400, statusText: 'Bad Request' }
|
||||
)
|
||||
)
|
||||
|
||||
await expect(api.queuePrompt(0, promptData())).rejects.toThrow(
|
||||
'Prompt execution failed'
|
||||
)
|
||||
})
|
||||
|
||||
it('wraps non-json prompt errors with status details', async () => {
|
||||
const api = new ComfyApi()
|
||||
vi.spyOn(api, 'fetchApi').mockResolvedValue(
|
||||
new Response('backend exploded', {
|
||||
status: 500,
|
||||
statusText: 'Server Error'
|
||||
})
|
||||
)
|
||||
|
||||
await expect(api.queuePrompt(0, promptData())).rejects.toMatchObject({
|
||||
status: 500,
|
||||
response: {
|
||||
error: {
|
||||
message: '500 Server Error',
|
||||
details: 'backend exploded'
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('ComfyApi read helpers', () => {
|
||||
it('returns localized templates, default templates, and empty non-json responses', async () => {
|
||||
const api = new ComfyApi()
|
||||
vi.mocked(axios.get)
|
||||
.mockResolvedValueOnce({
|
||||
headers: { 'content-type': 'application/json' },
|
||||
data: [{ name: 'localized' }]
|
||||
})
|
||||
.mockResolvedValueOnce({
|
||||
headers: { 'content-type': 'text/html' },
|
||||
data: '<html></html>'
|
||||
})
|
||||
|
||||
await expect(api.getCoreWorkflowTemplates('fr')).resolves.toEqual([
|
||||
{ name: 'localized' }
|
||||
])
|
||||
await expect(api.getCoreWorkflowTemplates()).resolves.toEqual([])
|
||||
expect(vi.mocked(axios.get).mock.calls[0][0]).toContain(
|
||||
'/templates/index.fr.json'
|
||||
)
|
||||
expect(vi.mocked(axios.get).mock.calls[1][0]).toContain(
|
||||
'/templates/index.json'
|
||||
)
|
||||
})
|
||||
|
||||
it('falls back from missing localized templates to the default index', async () => {
|
||||
const api = new ComfyApi()
|
||||
vi.spyOn(console, 'warn').mockImplementation(() => {})
|
||||
vi.mocked(axios.get)
|
||||
.mockRejectedValueOnce(new Error('missing locale'))
|
||||
.mockResolvedValueOnce({
|
||||
headers: { 'content-type': 'application/json' },
|
||||
data: [{ name: 'default' }]
|
||||
})
|
||||
|
||||
await expect(api.getCoreWorkflowTemplates('ja')).resolves.toEqual([
|
||||
{ name: 'default' }
|
||||
])
|
||||
expect(vi.mocked(axios.get).mock.calls[1][0]).toContain(
|
||||
'/templates/index.json'
|
||||
)
|
||||
})
|
||||
|
||||
it('returns empty model lists for 404s and filters internal folders', async () => {
|
||||
const api = new ComfyApi()
|
||||
const fetchApi = vi
|
||||
.spyOn(api, 'fetchApi')
|
||||
.mockResolvedValueOnce(jsonResponse([], { status: 404 }))
|
||||
.mockResolvedValueOnce(
|
||||
jsonResponse([
|
||||
{ name: 'checkpoints' },
|
||||
{ name: 'configs' },
|
||||
{ name: 'custom_nodes' }
|
||||
])
|
||||
)
|
||||
.mockResolvedValueOnce(jsonResponse([], { status: 404 }))
|
||||
.mockResolvedValueOnce(jsonResponse(['model.safetensors']))
|
||||
|
||||
await expect(api.getModelFolders()).resolves.toEqual([])
|
||||
await expect(api.getModelFolders()).resolves.toEqual([
|
||||
{ name: 'checkpoints' }
|
||||
])
|
||||
await expect(api.getModels('checkpoints')).resolves.toEqual([])
|
||||
await expect(api.getModels('checkpoints')).resolves.toEqual([
|
||||
'model.safetensors'
|
||||
])
|
||||
expect(fetchApi).toHaveBeenCalledTimes(4)
|
||||
})
|
||||
|
||||
it('handles model metadata text responses', async () => {
|
||||
const api = new ComfyApi()
|
||||
vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
vi.spyOn(api, 'fetchApi')
|
||||
.mockResolvedValueOnce(new Response(''))
|
||||
.mockResolvedValueOnce(new Response('{"format":"safetensors"}'))
|
||||
.mockResolvedValueOnce(
|
||||
new Response('not json', { status: 200, statusText: 'OK' })
|
||||
)
|
||||
|
||||
await expect(
|
||||
api.viewMetadata('checkpoints', 'a.safetensors')
|
||||
).resolves.toBe(null)
|
||||
await expect(
|
||||
api.viewMetadata('checkpoints', 'a.safetensors')
|
||||
).resolves.toEqual({ format: 'safetensors' })
|
||||
await expect(
|
||||
api.viewMetadata('checkpoints', 'a.safetensors')
|
||||
).resolves.toBe(null)
|
||||
})
|
||||
|
||||
it('gets fuse options only from json responses', async () => {
|
||||
const api = new ComfyApi()
|
||||
vi.mocked(axios.get)
|
||||
.mockResolvedValueOnce({
|
||||
headers: { 'content-type': 'application/json' },
|
||||
data: { keys: ['name'] }
|
||||
})
|
||||
.mockResolvedValueOnce({
|
||||
headers: { 'content-type': 'text/plain' },
|
||||
data: 'nope'
|
||||
})
|
||||
vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
|
||||
await expect(api.getFuseOptions()).resolves.toEqual({ keys: ['name'] })
|
||||
await expect(api.getFuseOptions()).resolves.toBeNull()
|
||||
vi.mocked(axios.get).mockRejectedValueOnce(new Error('missing'))
|
||||
await expect(api.getFuseOptions()).resolves.toBeNull()
|
||||
})
|
||||
})
|
||||
|
||||
describe('ComfyApi queue and data helpers', () => {
|
||||
it('routes item collection requests to queue or history', async () => {
|
||||
const api = new ComfyApi()
|
||||
const queue = vi.spyOn(api, 'getQueue').mockResolvedValue({
|
||||
Running: [],
|
||||
Pending: []
|
||||
})
|
||||
const historyItem = fromPartial<JobListItem>({
|
||||
id: 'history-1',
|
||||
status: 'completed',
|
||||
create_time: 1,
|
||||
priority: 0
|
||||
})
|
||||
const history = vi.spyOn(api, 'getHistory').mockResolvedValue([historyItem])
|
||||
|
||||
await expect(api.getItems('queue')).resolves.toEqual({
|
||||
Running: [],
|
||||
Pending: []
|
||||
})
|
||||
await expect(api.getItems('history')).resolves.toEqual([historyItem])
|
||||
expect(queue).toHaveBeenCalledOnce()
|
||||
expect(history).toHaveBeenCalledOnce()
|
||||
})
|
||||
|
||||
it('returns queue fallbacks unless errors are requested', async () => {
|
||||
const api = new ComfyApi()
|
||||
vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
fetchJobs.fetchQueue.mockRejectedValue(new Error('network'))
|
||||
|
||||
await expect(api.getQueue()).resolves.toEqual({ Running: [], Pending: [] })
|
||||
await expect(api.getQueue({ throwOnError: true })).rejects.toThrow(
|
||||
'network'
|
||||
)
|
||||
})
|
||||
|
||||
it('returns empty history when fetchHistory fails', async () => {
|
||||
const api = new ComfyApi()
|
||||
vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
fetchJobs.fetchHistory.mockRejectedValue(new Error('history down'))
|
||||
|
||||
await expect(api.getHistory()).resolves.toEqual([])
|
||||
})
|
||||
|
||||
it('posts item mutations with and without request bodies', async () => {
|
||||
const api = new ComfyApi()
|
||||
const fetchApi = vi.spyOn(api, 'fetchApi').mockResolvedValue(new Response())
|
||||
|
||||
await api.deleteItem('history', 'job-1')
|
||||
await api.clearItems('queue')
|
||||
await api.interrupt(null)
|
||||
await api.interrupt('running-1')
|
||||
|
||||
expect(fetchApi.mock.calls.map((call) => call[0])).toEqual([
|
||||
'/history',
|
||||
'/queue',
|
||||
'/interrupt',
|
||||
'/interrupt'
|
||||
])
|
||||
expect(fetchApi.mock.calls[0][1]?.body).toBe(
|
||||
JSON.stringify({ delete: ['job-1'] })
|
||||
)
|
||||
expect(fetchApi.mock.calls[1][1]?.body).toBe(
|
||||
JSON.stringify({ clear: true })
|
||||
)
|
||||
expect(fetchApi.mock.calls[2][1]?.body).toBeUndefined()
|
||||
expect(fetchApi.mock.calls[3][1]?.body).toBe(
|
||||
JSON.stringify({ prompt_id: 'running-1' })
|
||||
)
|
||||
})
|
||||
|
||||
it('throws unauthorized settings responses', async () => {
|
||||
const api = new ComfyApi()
|
||||
vi.spyOn(api, 'fetchApi').mockResolvedValue(
|
||||
new Response('', { status: 401, statusText: 'Unauthorized' })
|
||||
)
|
||||
|
||||
await expect(api.getSettings()).rejects.toThrow('Unauthorized')
|
||||
})
|
||||
|
||||
it('stores user data with default and raw-body options', async () => {
|
||||
const api = new ComfyApi()
|
||||
const fetchApi = vi
|
||||
.spyOn(api, 'fetchApi')
|
||||
.mockResolvedValue(new Response('', { status: 200 }))
|
||||
const raw = new Blob(['raw'])
|
||||
|
||||
await api.storeUserData('a/b.json', { ok: true })
|
||||
await api.storeUserData('raw.bin', raw, {
|
||||
overwrite: false,
|
||||
stringify: false,
|
||||
throwOnError: false,
|
||||
full_info: true
|
||||
})
|
||||
|
||||
expect(fetchApi.mock.calls[0][0]).toBe(
|
||||
'/userdata/a%2Fb.json?overwrite=true&full_info=false'
|
||||
)
|
||||
expect(fetchApi.mock.calls[0][1]?.body).toBe(JSON.stringify({ ok: true }))
|
||||
expect(fetchApi.mock.calls[1][0]).toBe(
|
||||
'/userdata/raw.bin?overwrite=false&full_info=true'
|
||||
)
|
||||
expect(fetchApi.mock.calls[1][1]?.body).toBe(raw)
|
||||
})
|
||||
|
||||
it('honors storeUserData throwOnError', async () => {
|
||||
const api = new ComfyApi()
|
||||
vi.spyOn(api, 'fetchApi')
|
||||
.mockResolvedValueOnce(
|
||||
new Response('', { status: 500, statusText: 'Server Error' })
|
||||
)
|
||||
.mockResolvedValueOnce(
|
||||
new Response('', { status: 500, statusText: 'Server Error' })
|
||||
)
|
||||
|
||||
await expect(api.storeUserData('bad.json', {})).rejects.toThrow(
|
||||
"Error storing user data file 'bad.json': 500 Server Error"
|
||||
)
|
||||
await expect(
|
||||
api.storeUserData('bad.json', {}, { throwOnError: false })
|
||||
).resolves.toHaveProperty('status', 500)
|
||||
})
|
||||
|
||||
it('lists full user data info by status', async () => {
|
||||
const api = new ComfyApi()
|
||||
vi.spyOn(api, 'fetchApi')
|
||||
.mockResolvedValueOnce(jsonResponse([], { status: 404 }))
|
||||
.mockResolvedValueOnce(
|
||||
new Response('', { status: 500, statusText: 'Server Error' })
|
||||
)
|
||||
.mockResolvedValueOnce(jsonResponse([{ path: 'x' }]))
|
||||
|
||||
await expect(api.listUserDataFullInfo('models/')).resolves.toEqual([])
|
||||
await expect(api.listUserDataFullInfo('models/')).rejects.toThrow(
|
||||
"Error getting user data list 'models': 500 Server Error"
|
||||
)
|
||||
await expect(api.listUserDataFullInfo('models/')).resolves.toEqual([
|
||||
{ path: 'x' }
|
||||
])
|
||||
})
|
||||
|
||||
it('loads global subgraph records and deferred data', async () => {
|
||||
const fetchApi = vi
|
||||
.spyOn(sharedApi, 'fetchApi')
|
||||
.mockResolvedValueOnce(jsonResponse({}, { status: 500 }))
|
||||
.mockResolvedValueOnce(
|
||||
jsonResponse({
|
||||
ready: { name: 'Ready', info: { node_pack: 'core' }, data: '{}' },
|
||||
lazy: { name: 'Lazy', info: { node_pack: 'core' } }
|
||||
})
|
||||
)
|
||||
.mockResolvedValueOnce(jsonResponse({ data: '{"lazy":true}' }))
|
||||
|
||||
await expect(sharedApi.getGlobalSubgraphs()).resolves.toEqual({})
|
||||
const subgraphs = await sharedApi.getGlobalSubgraphs()
|
||||
|
||||
expect(subgraphs.ready.data).toBe('{}')
|
||||
expect(subgraphs.lazy.data).toBeInstanceOf(Promise)
|
||||
await expect(subgraphs.lazy.data).resolves.toBe('{"lazy":true}')
|
||||
expect(fetchApi).toHaveBeenCalledTimes(3)
|
||||
})
|
||||
})
|
||||
829
src/scripts/api.test.ts
Normal file
829
src/scripts/api.test.ts
Normal file
@@ -0,0 +1,829 @@
|
||||
import { fromAny } from '@total-typescript/shoehorn'
|
||||
import axios from 'axios'
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import {
|
||||
api as singletonApi,
|
||||
ComfyApi,
|
||||
PromptExecutionError,
|
||||
UnauthorizedError
|
||||
} from '@/scripts/api'
|
||||
import type { ComfyApiWorkflow } from '@/platform/workflow/validation/schemas/workflowSchema'
|
||||
import type { NodeExecutionId } from '@/types/nodeIdentification'
|
||||
import { fetchWithUnifiedRemint } from '@/platform/auth/unified/remintRetry'
|
||||
|
||||
const { mockToastStore } = vi.hoisted(() => ({
|
||||
mockToastStore: {
|
||||
add: vi.fn()
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/auth/unified/remintRetry', () => ({
|
||||
fetchWithUnifiedRemint: vi.fn(),
|
||||
shouldRemintCloudRequest: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/updates/common/toastStore', () => ({
|
||||
useToastStore: vi.fn(() => mockToastStore)
|
||||
}))
|
||||
|
||||
vi.mock('axios', () => ({
|
||||
default: {
|
||||
get: vi.fn(),
|
||||
patch: vi.fn()
|
||||
}
|
||||
}))
|
||||
|
||||
class FakeWebSocket extends EventTarget {
|
||||
static instances: FakeWebSocket[] = []
|
||||
|
||||
binaryType = ''
|
||||
sent: string[] = []
|
||||
|
||||
constructor(readonly url: string) {
|
||||
super()
|
||||
FakeWebSocket.instances.push(this)
|
||||
}
|
||||
|
||||
send(data: string) {
|
||||
this.sent.push(data)
|
||||
}
|
||||
|
||||
close() {
|
||||
this.dispatchEvent(new Event('close'))
|
||||
}
|
||||
}
|
||||
|
||||
function jsonResponse(data: unknown, init: ResponseInit = {}) {
|
||||
return new Response(JSON.stringify(data), {
|
||||
status: 200,
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
...init
|
||||
})
|
||||
}
|
||||
|
||||
function createWorkflow() {
|
||||
return {
|
||||
last_node_id: 0,
|
||||
last_link_id: 0,
|
||||
nodes: [],
|
||||
links: [],
|
||||
groups: [],
|
||||
config: {},
|
||||
extra: {},
|
||||
version: 0.4
|
||||
}
|
||||
}
|
||||
|
||||
function binaryMessage(type: number, payload: Uint8Array) {
|
||||
const bytes = new Uint8Array(4 + payload.length)
|
||||
new DataView(bytes.buffer).setUint32(0, type)
|
||||
bytes.set(payload, 4)
|
||||
return bytes.buffer
|
||||
}
|
||||
|
||||
function uint32(value: number) {
|
||||
const bytes = new Uint8Array(4)
|
||||
new DataView(bytes.buffer).setUint32(0, value)
|
||||
return bytes
|
||||
}
|
||||
|
||||
function concatBytes(...chunks: Uint8Array[]) {
|
||||
const totalLength = chunks.reduce((sum, chunk) => sum + chunk.length, 0)
|
||||
const result = new Uint8Array(totalLength)
|
||||
let offset = 0
|
||||
for (const chunk of chunks) {
|
||||
result.set(chunk, offset)
|
||||
offset += chunk.length
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
describe('PromptExecutionError', () => {
|
||||
it('formats string, structured, and node-level errors', () => {
|
||||
expect(
|
||||
new PromptExecutionError({
|
||||
error: 'Queue rejected',
|
||||
node_errors: {}
|
||||
}).toString()
|
||||
).toBe('Queue rejected')
|
||||
|
||||
expect(
|
||||
new PromptExecutionError({
|
||||
error: {
|
||||
type: 'invalid_prompt',
|
||||
message: 'Invalid prompt',
|
||||
details: 'missing input'
|
||||
},
|
||||
node_errors: {
|
||||
1: {
|
||||
class_type: 'PreviewAny',
|
||||
dependent_outputs: ['1'],
|
||||
errors: [
|
||||
{
|
||||
type: 'required_input_missing',
|
||||
message: 'Required input',
|
||||
details: 'source'
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}).toString()
|
||||
).toContain('Invalid prompt: missing input\nPreviewAny:')
|
||||
})
|
||||
})
|
||||
|
||||
describe('ComfyApi', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
FakeWebSocket.instances = []
|
||||
window.name = ''
|
||||
sessionStorage.clear()
|
||||
vi.stubGlobal('WebSocket', FakeWebSocket as unknown as typeof WebSocket)
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks()
|
||||
vi.unstubAllGlobals()
|
||||
})
|
||||
|
||||
it('builds API, internal, file, and fetch URLs with user headers', async () => {
|
||||
const api = new ComfyApi()
|
||||
api.user = 'reviewer'
|
||||
vi.mocked(fetchWithUnifiedRemint).mockResolvedValue(
|
||||
jsonResponse({ ok: true })
|
||||
)
|
||||
|
||||
await api.fetchApi('/queue', {
|
||||
headers: new Headers([['X-Test', '1']])
|
||||
})
|
||||
|
||||
expect(api.apiURL('/api/custom')).toBe(`${api.api_base}/api/custom`)
|
||||
expect(api.apiURL('/queue')).toBe(`${api.api_base}/api/queue`)
|
||||
expect(api.internalURL('/logs')).toBe(`${api.api_base}/internal/logs`)
|
||||
expect(api.fileURL('/view')).toBe(`${api.api_base}/view`)
|
||||
const [, options] = vi.mocked(fetchWithUnifiedRemint).mock.calls[0]
|
||||
expect(options.headers).toBeInstanceOf(Headers)
|
||||
expect((options.headers as Headers).get('Comfy-User')).toBe('reviewer')
|
||||
expect((options.headers as Headers).get('X-Test')).toBe('1')
|
||||
})
|
||||
|
||||
it('guards event listeners and still allows removing them', async () => {
|
||||
const api = new ComfyApi()
|
||||
const warn = vi.spyOn(console, 'warn').mockImplementation(() => undefined)
|
||||
const listener = vi.fn()
|
||||
const throwingListener = vi.fn(() => {
|
||||
throw new Error('listener failed')
|
||||
})
|
||||
const asyncListener = vi.fn(() => Promise.reject(new Error('async failed')))
|
||||
const objectListener = { handleEvent: vi.fn() }
|
||||
|
||||
api.addEventListener('status', null)
|
||||
api.removeEventListener('status', null)
|
||||
api.addEventListener('status', listener)
|
||||
api.addEventListener('status', throwingListener)
|
||||
api.addEventListener('status', asyncListener)
|
||||
api.addEventListener('status', fromAny(objectListener))
|
||||
|
||||
api.dispatchCustomEvent('status', { exec_info: { queue_remaining: 1 } })
|
||||
await Promise.resolve()
|
||||
|
||||
expect(listener).toHaveBeenCalled()
|
||||
expect(throwingListener).toHaveBeenCalled()
|
||||
expect(asyncListener).toHaveBeenCalled()
|
||||
expect(objectListener.handleEvent).toHaveBeenCalled()
|
||||
expect(warn).toHaveBeenCalledTimes(2)
|
||||
|
||||
api.removeEventListener('status', listener)
|
||||
api.dispatchCustomEvent('status', null)
|
||||
expect(listener).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('reuses guarded listener wrappers and ignores unknown removals', () => {
|
||||
const api = new ComfyApi()
|
||||
const listener = vi.fn()
|
||||
const neverRegistered = vi.fn()
|
||||
|
||||
api.addEventListener('status', listener)
|
||||
api.addEventListener('status', listener)
|
||||
api.removeEventListener('status', neverRegistered)
|
||||
api.dispatchCustomEvent('status', null)
|
||||
|
||||
expect(listener).toHaveBeenCalledTimes(1)
|
||||
expect(neverRegistered).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('supports guarded custom event listeners', () => {
|
||||
const api = new ComfyApi()
|
||||
const listener = vi.fn()
|
||||
|
||||
api.addCustomEventListener('custom-node-event', listener)
|
||||
;(api as EventTarget).dispatchEvent(
|
||||
new CustomEvent('custom-node-event', { detail: { ok: true } })
|
||||
)
|
||||
api.removeCustomEventListener('custom-node-event', listener)
|
||||
;(api as EventTarget).dispatchEvent(
|
||||
new CustomEvent('custom-node-event', { detail: { ok: false } })
|
||||
)
|
||||
|
||||
expect(listener).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('routes websocket JSON messages and custom registered messages', () => {
|
||||
window.name = 'existing-client'
|
||||
const api = new ComfyApi()
|
||||
const status = vi.fn()
|
||||
const executing = vi.fn()
|
||||
const featureFlags = vi.fn()
|
||||
const custom = vi.fn()
|
||||
const warn = vi.spyOn(console, 'warn').mockImplementation(() => undefined)
|
||||
vi.spyOn(console, 'log').mockImplementation(() => undefined)
|
||||
|
||||
api.addEventListener('status', status)
|
||||
api.addEventListener('executing', executing)
|
||||
api.addEventListener('feature_flags', featureFlags)
|
||||
api.addCustomEventListener('custom-message', custom)
|
||||
api.init()
|
||||
const socket = FakeWebSocket.instances[0]
|
||||
socket.dispatchEvent(new Event('open'))
|
||||
|
||||
expect(socket.url).toContain('clientId=existing-client')
|
||||
expect(JSON.parse(socket.sent[0])).toMatchObject({
|
||||
type: 'feature_flags'
|
||||
})
|
||||
|
||||
socket.dispatchEvent(
|
||||
new MessageEvent('message', {
|
||||
data: JSON.stringify({
|
||||
type: 'status',
|
||||
data: {
|
||||
sid: 'fresh-client',
|
||||
status: { exec_info: { queue_remaining: 2 } }
|
||||
}
|
||||
})
|
||||
})
|
||||
)
|
||||
socket.dispatchEvent(
|
||||
new MessageEvent('message', {
|
||||
data: JSON.stringify({
|
||||
type: 'executing',
|
||||
data: { node: '12' }
|
||||
})
|
||||
})
|
||||
)
|
||||
socket.dispatchEvent(
|
||||
new MessageEvent('message', {
|
||||
data: JSON.stringify({
|
||||
type: 'feature_flags',
|
||||
data: { supports_progress_text_metadata: true }
|
||||
})
|
||||
})
|
||||
)
|
||||
socket.dispatchEvent(
|
||||
new MessageEvent('message', {
|
||||
data: JSON.stringify({
|
||||
type: 'custom-message',
|
||||
data: { from: 'extension' }
|
||||
})
|
||||
})
|
||||
)
|
||||
socket.dispatchEvent(
|
||||
new MessageEvent('message', {
|
||||
data: JSON.stringify({
|
||||
type: 'unknown-message',
|
||||
data: {}
|
||||
})
|
||||
})
|
||||
)
|
||||
socket.dispatchEvent(
|
||||
new MessageEvent('message', {
|
||||
data: JSON.stringify({
|
||||
type: 'unknown-message',
|
||||
data: {}
|
||||
})
|
||||
})
|
||||
)
|
||||
|
||||
expect(api.clientId).toBe('fresh-client')
|
||||
expect(window.name).toBe('fresh-client')
|
||||
expect(sessionStorage.getItem('clientId')).toBe('fresh-client')
|
||||
expect(status).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
detail: { exec_info: { queue_remaining: 2 } }
|
||||
})
|
||||
)
|
||||
expect(executing).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ detail: '12' })
|
||||
)
|
||||
expect(featureFlags).toHaveBeenCalled()
|
||||
expect(api.serverSupportsFeature('supports_progress_text_metadata')).toBe(
|
||||
true
|
||||
)
|
||||
expect(custom).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ detail: { from: 'extension' } })
|
||||
)
|
||||
expect(api.reportedUnknownMessageTypes.has('unknown-message')).toBe(true)
|
||||
expect(warn).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('polls status when the initial websocket connection fails', async () => {
|
||||
vi.useFakeTimers()
|
||||
const api = new ComfyApi()
|
||||
const status = vi.fn()
|
||||
vi.spyOn(api, 'fetchApi')
|
||||
.mockResolvedValueOnce(
|
||||
jsonResponse({ exec_info: { queue_remaining: 4 } })
|
||||
)
|
||||
.mockRejectedValueOnce(new Error('poll failed'))
|
||||
api.addEventListener('status', status)
|
||||
|
||||
api.init()
|
||||
FakeWebSocket.instances[0].dispatchEvent(new Event('error'))
|
||||
await vi.advanceTimersByTimeAsync(1000)
|
||||
await vi.advanceTimersByTimeAsync(1000)
|
||||
|
||||
expect(status).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
detail: { exec_info: { queue_remaining: 4 } }
|
||||
})
|
||||
)
|
||||
expect(status).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ detail: null })
|
||||
)
|
||||
vi.useRealTimers()
|
||||
})
|
||||
|
||||
it('emits reconnect lifecycle events after an opened websocket closes', async () => {
|
||||
vi.useFakeTimers()
|
||||
const api = new ComfyApi()
|
||||
const status = vi.fn()
|
||||
const reconnecting = vi.fn()
|
||||
const reconnected = vi.fn()
|
||||
api.addEventListener('status', status)
|
||||
api.addEventListener('reconnecting', reconnecting)
|
||||
api.addEventListener('reconnected', reconnected)
|
||||
|
||||
api.init()
|
||||
const socket = FakeWebSocket.instances[0]
|
||||
socket.dispatchEvent(new Event('open'))
|
||||
socket.close()
|
||||
|
||||
expect(status).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ detail: null })
|
||||
)
|
||||
expect(reconnecting).toHaveBeenCalledOnce()
|
||||
|
||||
await vi.advanceTimersByTimeAsync(300)
|
||||
const reconnectSocket = FakeWebSocket.instances[1]
|
||||
reconnectSocket.dispatchEvent(new Event('open'))
|
||||
|
||||
expect(reconnected).toHaveBeenCalledOnce()
|
||||
vi.useRealTimers()
|
||||
})
|
||||
|
||||
it('routes websocket variants without session ids and display-node fallbacks', () => {
|
||||
const api = new ComfyApi()
|
||||
const status = vi.fn()
|
||||
const executing = vi.fn()
|
||||
api.addEventListener('status', status)
|
||||
api.addEventListener('executing', executing)
|
||||
api.init()
|
||||
const socket = FakeWebSocket.instances[0]
|
||||
|
||||
socket.dispatchEvent(
|
||||
new MessageEvent('message', {
|
||||
data: JSON.stringify({
|
||||
type: 'status',
|
||||
data: { status: undefined }
|
||||
})
|
||||
})
|
||||
)
|
||||
socket.dispatchEvent(
|
||||
new MessageEvent('message', {
|
||||
data: JSON.stringify({
|
||||
type: 'executing',
|
||||
data: { node: 'real', display_node: 'display' }
|
||||
})
|
||||
})
|
||||
)
|
||||
|
||||
expect(status).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ detail: null })
|
||||
)
|
||||
expect(executing).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ detail: 'display' })
|
||||
)
|
||||
})
|
||||
|
||||
it('routes binary preview and progress websocket messages', () => {
|
||||
const api = new ComfyApi()
|
||||
const preview = vi.fn()
|
||||
const previewWithMetadata = vi.fn()
|
||||
const progressText = vi.fn()
|
||||
const encoder = new TextEncoder()
|
||||
api.serverFeatureFlags.value = {
|
||||
supports_progress_text_metadata: true
|
||||
}
|
||||
api.addEventListener('b_preview', preview)
|
||||
api.addEventListener('b_preview_with_metadata', previewWithMetadata)
|
||||
api.addEventListener('progress_text', progressText)
|
||||
api.init()
|
||||
const socket = FakeWebSocket.instances[0]
|
||||
|
||||
socket.dispatchEvent(
|
||||
new MessageEvent('message', {
|
||||
data: binaryMessage(1, concatBytes(uint32(2), new Uint8Array([1, 2])))
|
||||
})
|
||||
)
|
||||
|
||||
const promptId = encoder.encode('prompt-1')
|
||||
const nodeId = encoder.encode('7')
|
||||
const progressPayload = concatBytes(
|
||||
uint32(promptId.length),
|
||||
promptId,
|
||||
uint32(nodeId.length),
|
||||
nodeId,
|
||||
encoder.encode('loading')
|
||||
)
|
||||
socket.dispatchEvent(
|
||||
new MessageEvent('message', {
|
||||
data: binaryMessage(3, progressPayload)
|
||||
})
|
||||
)
|
||||
|
||||
const metadata = encoder.encode(
|
||||
JSON.stringify({
|
||||
image_type: 'image/webp',
|
||||
node_id: '7',
|
||||
display_node_id: '7',
|
||||
parent_node_id: '4',
|
||||
real_node_id: '7',
|
||||
prompt_id: 'prompt-1'
|
||||
})
|
||||
)
|
||||
socket.dispatchEvent(
|
||||
new MessageEvent('message', {
|
||||
data: binaryMessage(
|
||||
4,
|
||||
concatBytes(uint32(metadata.length), metadata, new Uint8Array([9]))
|
||||
)
|
||||
})
|
||||
)
|
||||
|
||||
expect(preview).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
detail: expect.objectContaining({ type: 'image/png' })
|
||||
})
|
||||
)
|
||||
expect(progressText).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
detail: {
|
||||
nodeId: '7',
|
||||
text: 'loading',
|
||||
prompt_id: 'prompt-1'
|
||||
}
|
||||
})
|
||||
)
|
||||
expect(previewWithMetadata).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
detail: expect.objectContaining({
|
||||
nodeId: '7',
|
||||
parentNodeId: '4',
|
||||
jobId: 'prompt-1'
|
||||
})
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('routes binary jpeg/default previews and malformed binary messages defensively', () => {
|
||||
const api = new ComfyApi()
|
||||
const preview = vi.fn()
|
||||
const progressText = vi.fn()
|
||||
const warn = vi.spyOn(console, 'warn').mockImplementation(() => undefined)
|
||||
const encoder = new TextEncoder()
|
||||
api.addEventListener('b_preview', preview)
|
||||
api.addEventListener('progress_text', progressText)
|
||||
api.init()
|
||||
const socket = FakeWebSocket.instances[0]
|
||||
|
||||
socket.dispatchEvent(
|
||||
new MessageEvent('message', {
|
||||
data: binaryMessage(1, concatBytes(uint32(1), new Uint8Array([1])))
|
||||
})
|
||||
)
|
||||
socket.dispatchEvent(
|
||||
new MessageEvent('message', {
|
||||
data: binaryMessage(1, concatBytes(uint32(99), new Uint8Array([2])))
|
||||
})
|
||||
)
|
||||
|
||||
const nodeId = encoder.encode('node')
|
||||
socket.dispatchEvent(
|
||||
new MessageEvent('message', {
|
||||
data: binaryMessage(
|
||||
3,
|
||||
concatBytes(uint32(nodeId.length), nodeId, encoder.encode('ready'))
|
||||
)
|
||||
})
|
||||
)
|
||||
socket.dispatchEvent(
|
||||
new MessageEvent('message', {
|
||||
data: binaryMessage(3, new Uint8Array([1]))
|
||||
})
|
||||
)
|
||||
socket.dispatchEvent(
|
||||
new MessageEvent('message', {
|
||||
data: binaryMessage(99, new Uint8Array())
|
||||
})
|
||||
)
|
||||
|
||||
expect(preview).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
detail: expect.objectContaining({ type: 'image/jpeg' })
|
||||
})
|
||||
)
|
||||
expect(progressText).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
detail: { nodeId: 'node', text: 'ready' }
|
||||
})
|
||||
)
|
||||
expect(warn).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('serializes prompt queue options and surfaces non-200 errors', async () => {
|
||||
const api = new ComfyApi()
|
||||
api.clientId = 'client-1'
|
||||
api.authToken = 'token-1'
|
||||
api.apiKey = 'key-1'
|
||||
const prompt: ComfyApiWorkflow = {
|
||||
1: {
|
||||
class_type: 'PreviewAny',
|
||||
inputs: {},
|
||||
_meta: { title: 'PreviewAny' }
|
||||
}
|
||||
}
|
||||
const fetchApi = vi
|
||||
.spyOn(api, 'fetchApi')
|
||||
.mockResolvedValueOnce(jsonResponse({ prompt_id: 'queued' }))
|
||||
.mockResolvedValueOnce(
|
||||
new Response('backend exploded', {
|
||||
status: 500,
|
||||
statusText: 'Server Error'
|
||||
})
|
||||
)
|
||||
|
||||
await expect(
|
||||
api.queuePrompt(
|
||||
-1,
|
||||
{ output: prompt, workflow: createWorkflow() },
|
||||
{
|
||||
partialExecutionTargets: ['7' as NodeExecutionId],
|
||||
previewMethod: 'latent2rgb'
|
||||
}
|
||||
)
|
||||
).resolves.toEqual({ prompt_id: 'queued' })
|
||||
|
||||
const body = JSON.parse(fetchApi.mock.calls[0][1]?.body as string)
|
||||
expect(body).toMatchObject({
|
||||
client_id: 'client-1',
|
||||
prompt,
|
||||
partial_execution_targets: ['7'],
|
||||
front: true,
|
||||
extra_data: {
|
||||
auth_token_comfy_org: 'token-1',
|
||||
api_key_comfy_org: 'key-1',
|
||||
comfy_usage_source: 'comfyui-frontend',
|
||||
preview_method: 'latent2rgb'
|
||||
}
|
||||
})
|
||||
expect(body.number).toBeUndefined()
|
||||
|
||||
await expect(
|
||||
api.queuePrompt(3, { output: prompt, workflow: createWorkflow() })
|
||||
).rejects.toMatchObject({ status: 500 })
|
||||
})
|
||||
|
||||
it('omits queue position and default preview method for normal queueing', async () => {
|
||||
const api = new ComfyApi()
|
||||
const prompt: ComfyApiWorkflow = {
|
||||
1: {
|
||||
class_type: 'PreviewAny',
|
||||
inputs: {},
|
||||
_meta: { title: 'PreviewAny' }
|
||||
}
|
||||
}
|
||||
const fetchApi = vi
|
||||
.spyOn(api, 'fetchApi')
|
||||
.mockResolvedValue(jsonResponse({ prompt_id: 'queued' }))
|
||||
|
||||
await api.queuePrompt(
|
||||
0,
|
||||
{ output: prompt, workflow: createWorkflow() },
|
||||
{
|
||||
previewMethod: 'default'
|
||||
}
|
||||
)
|
||||
|
||||
const body = JSON.parse(fetchApi.mock.calls[0][1]?.body as string)
|
||||
expect(body.front).toBeUndefined()
|
||||
expect(body.number).toBeUndefined()
|
||||
expect(body.extra_data.preview_method).toBeUndefined()
|
||||
})
|
||||
|
||||
it('handles shareable assets, settings, userdata, subgraphs, and memory APIs', async () => {
|
||||
const api = new ComfyApi()
|
||||
const prompt: ComfyApiWorkflow = {
|
||||
1: {
|
||||
class_type: 'PreviewAny',
|
||||
inputs: {},
|
||||
_meta: { title: 'PreviewAny' }
|
||||
}
|
||||
}
|
||||
vi.spyOn(api, 'fetchApi')
|
||||
.mockResolvedValueOnce(jsonResponse({ assets: [] }))
|
||||
.mockResolvedValueOnce(jsonResponse({}, { status: 500 }))
|
||||
.mockResolvedValueOnce(
|
||||
new Response('', { status: 401, statusText: 'Unauthorized' })
|
||||
)
|
||||
.mockResolvedValueOnce(jsonResponse({}, { status: 204 }))
|
||||
.mockResolvedValueOnce(jsonResponse([], { status: 404 }))
|
||||
.mockResolvedValueOnce(
|
||||
jsonResponse({}, { status: 500, statusText: 'Server Error' })
|
||||
)
|
||||
.mockResolvedValueOnce(jsonResponse({}, { status: 200 }))
|
||||
.mockResolvedValueOnce(jsonResponse({}, { status: 200 }))
|
||||
.mockResolvedValueOnce(jsonResponse({}, { status: 500 }))
|
||||
vi.spyOn(singletonApi, 'fetchApi')
|
||||
.mockResolvedValueOnce(jsonResponse({ data: 'subgraph-data' }))
|
||||
.mockResolvedValueOnce(jsonResponse({ missing: {} }))
|
||||
.mockResolvedValueOnce(jsonResponse({ one: { data: 'inline' } }))
|
||||
|
||||
await expect(
|
||||
api.getShareableAssets(prompt, { owned: false })
|
||||
).resolves.toEqual({ assets: [] })
|
||||
await expect(api.getShareableAssets(prompt)).rejects.toThrow(
|
||||
'Failed to fetch shareable assets'
|
||||
)
|
||||
await expect(api.getSettings()).rejects.toBeInstanceOf(UnauthorizedError)
|
||||
await expect(
|
||||
api.storeUserData('plain.txt', 'raw', {
|
||||
overwrite: false,
|
||||
stringify: false,
|
||||
throwOnError: false,
|
||||
full_info: true
|
||||
})
|
||||
).resolves.toHaveProperty('status', 204)
|
||||
await expect(api.listUserDataFullInfo('/missing/')).resolves.toEqual([])
|
||||
await expect(api.listUserDataFullInfo('/broken/')).rejects.toThrow(
|
||||
"Error getting user data list '/broken'"
|
||||
)
|
||||
await expect(api.getGlobalSubgraphData('one')).resolves.toBe(
|
||||
'subgraph-data'
|
||||
)
|
||||
await expect(api.getGlobalSubgraphData('missing')).rejects.toThrow(
|
||||
"Global subgraph 'missing' returned empty data"
|
||||
)
|
||||
await expect(api.getGlobalSubgraphs()).resolves.toEqual({
|
||||
one: { data: 'inline' }
|
||||
})
|
||||
await api.freeMemory({ freeExecutionCache: true })
|
||||
await api.freeMemory({ freeExecutionCache: false })
|
||||
await api.freeMemory({ freeExecutionCache: true })
|
||||
expect(mockToastStore.add).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
summary: 'Models and Execution Cache have been cleared.'
|
||||
})
|
||||
)
|
||||
expect(mockToastStore.add).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ summary: 'Models have been unloaded.' })
|
||||
)
|
||||
expect(mockToastStore.add).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
summary:
|
||||
'Unloading of models failed. Installed ComfyUI may be an outdated version.'
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('rejects non-success global subgraph data responses', async () => {
|
||||
const api = new ComfyApi()
|
||||
vi.spyOn(singletonApi, 'fetchApi').mockResolvedValueOnce(
|
||||
jsonResponse({}, { status: 404, statusText: 'Not Found' })
|
||||
)
|
||||
|
||||
await expect(api.getGlobalSubgraphData('missing')).rejects.toThrow(
|
||||
"Failed to fetch global subgraph 'missing': 404 Not Found"
|
||||
)
|
||||
})
|
||||
|
||||
it('handles successful settings and userdata helper request shapes', async () => {
|
||||
const api = new ComfyApi()
|
||||
const fetchApi = vi
|
||||
.spyOn(api, 'fetchApi')
|
||||
.mockResolvedValueOnce(jsonResponse({ theme: 'dark' }))
|
||||
.mockResolvedValueOnce(jsonResponse({}, { status: 500 }))
|
||||
.mockResolvedValueOnce(jsonResponse({}, { status: 200 }))
|
||||
.mockResolvedValueOnce(jsonResponse({}, { status: 200 }))
|
||||
|
||||
await expect(api.getSettings()).resolves.toEqual({ theme: 'dark' })
|
||||
await expect(api.storeUserData('bad.json', { a: 1 })).rejects.toThrow(
|
||||
"Error storing user data file 'bad.json'"
|
||||
)
|
||||
await api.moveUserData('old/path.json', 'new path.json')
|
||||
await api.deleteUserData('old/path.json')
|
||||
|
||||
expect(fetchApi.mock.calls[2]).toEqual([
|
||||
'/userdata/old%2Fpath.json/move/new%20path.json?overwrite=false',
|
||||
{ method: 'POST' }
|
||||
])
|
||||
expect(fetchApi.mock.calls[3]).toEqual([
|
||||
'/userdata/old%2Fpath.json',
|
||||
{ method: 'DELETE' }
|
||||
])
|
||||
})
|
||||
|
||||
it('handles global subgraph fallbacks and log endpoints', async () => {
|
||||
const api = new ComfyApi()
|
||||
vi.spyOn(singletonApi, 'fetchApi')
|
||||
.mockResolvedValueOnce(jsonResponse({}, { status: 500 }))
|
||||
.mockResolvedValueOnce(
|
||||
jsonResponse({
|
||||
missing: {
|
||||
name: 'Missing data',
|
||||
info: { node_pack: 'core' }
|
||||
}
|
||||
})
|
||||
)
|
||||
.mockResolvedValueOnce(jsonResponse({ data: 'lazy-data' }))
|
||||
vi.mocked(axios.get)
|
||||
.mockResolvedValueOnce({ data: 'log text', headers: {} })
|
||||
.mockResolvedValueOnce({ data: { logs: [] }, headers: {} })
|
||||
.mockRejectedValueOnce(new Error('no folders'))
|
||||
.mockResolvedValueOnce({
|
||||
data: { checkpoints: ['/models'] },
|
||||
headers: {}
|
||||
})
|
||||
vi.mocked(axios.patch).mockResolvedValue({ data: undefined })
|
||||
|
||||
await expect(api.getGlobalSubgraphs()).resolves.toEqual({})
|
||||
const subgraphs = await api.getGlobalSubgraphs()
|
||||
await expect(subgraphs.missing.data).resolves.toBe('lazy-data')
|
||||
await expect(api.getLogs()).resolves.toBe('log text')
|
||||
await expect(api.getRawLogs()).resolves.toEqual({ logs: [] })
|
||||
await api.subscribeLogs(true)
|
||||
await expect(api.getFolderPaths()).resolves.toEqual({})
|
||||
await expect(api.getFolderPaths()).resolves.toEqual({
|
||||
checkpoints: ['/models']
|
||||
})
|
||||
|
||||
expect(axios.patch).toHaveBeenCalledWith(
|
||||
api.internalURL('/logs/subscribe'),
|
||||
{ enabled: true, clientId: undefined }
|
||||
)
|
||||
})
|
||||
|
||||
it('loads localized template indexes and fuse options defensively', async () => {
|
||||
const api = new ComfyApi()
|
||||
vi.mocked(axios.get)
|
||||
.mockResolvedValueOnce({
|
||||
data: [{ name: 'template' }],
|
||||
headers: { 'content-type': 'application/json' }
|
||||
})
|
||||
.mockResolvedValueOnce({
|
||||
data: '<html></html>',
|
||||
headers: { 'content-type': 'text/html' }
|
||||
})
|
||||
.mockRejectedValueOnce(new Error('missing locale'))
|
||||
.mockResolvedValueOnce({
|
||||
data: [{ name: 'fallback' }],
|
||||
headers: { 'content-type': 'application/json' }
|
||||
})
|
||||
.mockRejectedValueOnce(new Error('default missing'))
|
||||
.mockResolvedValueOnce({
|
||||
data: { keys: ['name'] },
|
||||
headers: { 'content-type': 'application/json' }
|
||||
})
|
||||
.mockResolvedValueOnce({
|
||||
data: '<html></html>',
|
||||
headers: { 'content-type': 'text/html' }
|
||||
})
|
||||
.mockResolvedValueOnce({
|
||||
data: { ignored: true },
|
||||
headers: {}
|
||||
})
|
||||
vi.spyOn(console, 'warn').mockImplementation(() => undefined)
|
||||
vi.spyOn(console, 'error').mockImplementation(() => undefined)
|
||||
|
||||
await expect(api.getCoreWorkflowTemplates('fr')).resolves.toEqual([
|
||||
{ name: 'template' }
|
||||
])
|
||||
await expect(api.getCoreWorkflowTemplates()).resolves.toEqual([])
|
||||
await expect(api.getCoreWorkflowTemplates('zh')).resolves.toEqual([
|
||||
{ name: 'fallback' }
|
||||
])
|
||||
await expect(api.getCoreWorkflowTemplates('en')).resolves.toEqual([])
|
||||
await expect(api.getFuseOptions()).resolves.toEqual({ keys: ['name'] })
|
||||
await expect(api.getFuseOptions()).resolves.toBeNull()
|
||||
await expect(api.getFuseOptions()).resolves.toBeNull()
|
||||
})
|
||||
})
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,6 +1,12 @@
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import type {
|
||||
CanvasPointerEvent,
|
||||
Subgraph
|
||||
} from '@/lib/litegraph/src/litegraph'
|
||||
import { LGraphCanvas, LiteGraph } from '@/lib/litegraph/src/litegraph'
|
||||
import type { ComfyWorkflowJSON } from '@/platform/workflow/validation/schemas/workflowSchema'
|
||||
import type { ExecutedWsMessage } from '@/schemas/apiSchema'
|
||||
|
||||
const mockAssert = vi.hoisted(() => vi.fn())
|
||||
|
||||
@@ -14,7 +20,7 @@ const mockNodeOutputStore = vi.hoisted(() => ({
|
||||
}))
|
||||
|
||||
const mockSubgraphNavigationStore = vi.hoisted(() => ({
|
||||
exportState: vi.fn(() => []),
|
||||
exportState: vi.fn((): string[] => []),
|
||||
restoreState: vi.fn()
|
||||
}))
|
||||
|
||||
@@ -23,10 +29,24 @@ const mockWorkflowStore = vi.hoisted(() => ({
|
||||
getWorkflowByPath: vi.fn()
|
||||
}))
|
||||
|
||||
const mockExecutionStore = vi.hoisted(() => ({
|
||||
queuedJobs: {} as Record<string, { workflow: { changeTracker: unknown } }>
|
||||
}))
|
||||
|
||||
const mockMaskEditorIsOpened = vi.hoisted(() => vi.fn(() => false))
|
||||
|
||||
vi.mock('@/scripts/app', () => ({
|
||||
app: {
|
||||
constructor: {
|
||||
maskeditor_is_opended: mockMaskEditorIsOpened
|
||||
},
|
||||
graph: {},
|
||||
ui: {
|
||||
autoQueueEnabled: false,
|
||||
autoQueueMode: 'instant'
|
||||
},
|
||||
rootGraph: {
|
||||
subgraphs: new Map(),
|
||||
serialize: vi.fn(() => ({
|
||||
nodes: [],
|
||||
links: [],
|
||||
@@ -39,8 +59,10 @@ vi.mock('@/scripts/app', () => ({
|
||||
}))
|
||||
},
|
||||
canvas: {
|
||||
ds: { scale: 1, offset: [0, 0] }
|
||||
}
|
||||
ds: { scale: 1, offset: [0, 0] },
|
||||
setGraph: vi.fn()
|
||||
},
|
||||
loadGraphData: vi.fn()
|
||||
}
|
||||
}))
|
||||
|
||||
@@ -65,6 +87,10 @@ vi.mock('@/platform/workflow/management/stores/workflowStore', () => ({
|
||||
useWorkflowStore: vi.fn(() => mockWorkflowStore)
|
||||
}))
|
||||
|
||||
vi.mock('@/stores/executionStore', () => ({
|
||||
useExecutionStore: vi.fn(() => mockExecutionStore)
|
||||
}))
|
||||
|
||||
import { app } from '@/scripts/app'
|
||||
import { api } from '@/scripts/api'
|
||||
import { ChangeTracker } from '@/scripts/changeTracker'
|
||||
@@ -107,13 +133,120 @@ function mockCanvasState(state: ComfyWorkflowJSON) {
|
||||
vi.mocked(app.rootGraph.serialize).mockReturnValue(state as never)
|
||||
}
|
||||
|
||||
type ListenerMap = Record<string, EventListener[]>
|
||||
|
||||
function storeListener(
|
||||
listeners: ListenerMap,
|
||||
type: string,
|
||||
listener: EventListenerOrEventListenerObject
|
||||
) {
|
||||
if (typeof listener === 'function') {
|
||||
listeners[type] ??= []
|
||||
listeners[type].push(listener)
|
||||
}
|
||||
}
|
||||
|
||||
function dispatchStored(listeners: ListenerMap, type: string, event: Event) {
|
||||
for (const listener of listeners[type] ?? []) {
|
||||
listener(event)
|
||||
}
|
||||
}
|
||||
|
||||
async function flushAsyncFrame() {
|
||||
await Promise.resolve()
|
||||
await Promise.resolve()
|
||||
}
|
||||
|
||||
function getApiListener(name: string) {
|
||||
const call = vi
|
||||
.mocked(api.addEventListener)
|
||||
.mock.calls.find(([eventName]) => eventName === name)
|
||||
expect(call).toBeDefined()
|
||||
return call?.[1] as (event: CustomEvent<ExecutedWsMessage>) => void
|
||||
}
|
||||
|
||||
describe('ChangeTracker', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
nodeIdCounter = 0
|
||||
ChangeTracker.isLoadingGraph = false
|
||||
Reflect.set(ChangeTracker, '_checkStateWarned', false)
|
||||
mockWorkflowStore.activeWorkflow = null
|
||||
mockWorkflowStore.getWorkflowByPath.mockReturnValue(null)
|
||||
mockExecutionStore.queuedJobs = {}
|
||||
mockMaskEditorIsOpened.mockReturnValue(false)
|
||||
app.ui.autoQueueEnabled = false
|
||||
app.ui.autoQueueMode = 'instant'
|
||||
vi.mocked(app.canvas.setGraph).mockClear()
|
||||
vi.mocked(app.loadGraphData).mockResolvedValue(undefined)
|
||||
app.rootGraph.subgraphs.clear()
|
||||
app.canvas.ds.scale = 1
|
||||
app.canvas.ds.offset = [0, 0]
|
||||
})
|
||||
|
||||
describe('reset', () => {
|
||||
it('updates initialState from activeState or an explicit state', () => {
|
||||
const tracker = createTracker(createState(1))
|
||||
const changed = createState(2)
|
||||
|
||||
tracker.activeState = changed
|
||||
tracker.reset()
|
||||
|
||||
expect(tracker.initialState).toEqual(changed)
|
||||
expect(tracker.initialState).not.toBe(changed)
|
||||
|
||||
const explicit = createState(3)
|
||||
tracker.reset(explicit)
|
||||
|
||||
expect(tracker.activeState).toEqual(explicit)
|
||||
expect(tracker.activeState).not.toBe(explicit)
|
||||
expect(tracker.initialState).toEqual(explicit)
|
||||
})
|
||||
|
||||
it('does not reset while restoring state', () => {
|
||||
const tracker = createTracker(createState(1))
|
||||
const original = tracker.initialState
|
||||
tracker._restoringState = true
|
||||
|
||||
tracker.reset(createState(2))
|
||||
|
||||
expect(tracker.initialState).toBe(original)
|
||||
})
|
||||
})
|
||||
|
||||
describe('restore', () => {
|
||||
it('restores viewport, outputs, and root graph navigation', () => {
|
||||
const tracker = createTracker()
|
||||
app.canvas.ds.scale = 2
|
||||
app.canvas.ds.offset = [10, 20]
|
||||
mockNodeOutputStore.snapshotOutputs.mockReturnValue({ 1: { images: [] } })
|
||||
mockSubgraphNavigationStore.exportState.mockReturnValue([])
|
||||
|
||||
tracker.store()
|
||||
app.canvas.ds.scale = 1
|
||||
app.canvas.ds.offset = [0, 0]
|
||||
tracker.restore()
|
||||
|
||||
expect(app.canvas.ds.scale).toBe(2)
|
||||
expect(app.canvas.ds.offset).toEqual([10, 20])
|
||||
expect(mockNodeOutputStore.restoreOutputs).toHaveBeenCalledWith({
|
||||
1: { images: [] }
|
||||
})
|
||||
expect(mockSubgraphNavigationStore.restoreState).toHaveBeenCalledWith([])
|
||||
expect(app.canvas.setGraph).toHaveBeenCalledWith(app.rootGraph)
|
||||
})
|
||||
|
||||
it('restores saved subgraph navigation when the subgraph exists', () => {
|
||||
const tracker = createTracker()
|
||||
const subgraph = { id: 'subgraph-1' } as unknown as Subgraph
|
||||
app.rootGraph.subgraphs.set('subgraph-1', subgraph)
|
||||
mockSubgraphNavigationStore.exportState.mockReturnValue(['subgraph-1'])
|
||||
|
||||
tracker.store()
|
||||
tracker.restore()
|
||||
|
||||
expect(app.canvas.setGraph).toHaveBeenCalledWith(subgraph)
|
||||
})
|
||||
})
|
||||
|
||||
describe('captureCanvasState', () => {
|
||||
@@ -169,9 +302,32 @@ describe('ChangeTracker', () => {
|
||||
expect.stringContaining('captureCanvasState')
|
||||
)
|
||||
})
|
||||
|
||||
it('reports inactive tracker calls only once for the same workflow', () => {
|
||||
const tracker = createTracker()
|
||||
tracker.workflow.path = '/test/dedupe-workflow.json'
|
||||
mockWorkflowStore.activeWorkflow = { changeTracker: {} }
|
||||
|
||||
tracker.captureCanvasState()
|
||||
tracker.captureCanvasState()
|
||||
|
||||
expect(mockAssert).toHaveBeenCalledOnce()
|
||||
})
|
||||
})
|
||||
|
||||
describe('state capture', () => {
|
||||
it('sets the active state without pushing undo when none exists yet', () => {
|
||||
const tracker = createTracker(createState(1))
|
||||
const changed = createState(2)
|
||||
tracker.activeState = undefined as never
|
||||
mockCanvasState(changed)
|
||||
|
||||
tracker.captureCanvasState()
|
||||
|
||||
expect(tracker.activeState).toEqual(changed)
|
||||
expect(tracker.undoQueue).toHaveLength(0)
|
||||
})
|
||||
|
||||
it('pushes to undoQueue, updates activeState, and calls updateModified', () => {
|
||||
const initial = createState(1)
|
||||
const tracker = createTracker(initial)
|
||||
@@ -238,6 +394,19 @@ describe('ChangeTracker', () => {
|
||||
|
||||
expect(tracker.undoQueue).toHaveLength(ChangeTracker.MAX_HISTORY)
|
||||
})
|
||||
|
||||
it('does not capture until the outer change transaction finishes', () => {
|
||||
const tracker = createTracker(createState(1))
|
||||
tracker.beforeChange()
|
||||
tracker.beforeChange()
|
||||
mockCanvasState(createState(2))
|
||||
|
||||
tracker.afterChange()
|
||||
expect(app.rootGraph.serialize).not.toHaveBeenCalled()
|
||||
|
||||
tracker.afterChange()
|
||||
expect(app.rootGraph.serialize).toHaveBeenCalledOnce()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -302,6 +471,105 @@ describe('ChangeTracker', () => {
|
||||
})
|
||||
})
|
||||
|
||||
describe('updateModified', () => {
|
||||
it('updates workflow modified state when the store can find it', () => {
|
||||
const state = createState(1)
|
||||
const tracker = createTracker(state)
|
||||
const workflow = { isModified: true }
|
||||
mockWorkflowStore.getWorkflowByPath.mockReturnValue(workflow)
|
||||
|
||||
tracker.updateModified()
|
||||
expect(workflow.isModified).toBe(false)
|
||||
|
||||
tracker.activeState = createState(2)
|
||||
tracker.updateModified()
|
||||
expect(workflow.isModified).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
describe('undo and redo', () => {
|
||||
it('restores previous state and moves the current state to the target queue', async () => {
|
||||
const initial = createState(1)
|
||||
const changed = createState(2)
|
||||
const tracker = createTracker(changed)
|
||||
tracker.undoQueue.push(initial)
|
||||
|
||||
await tracker.undo()
|
||||
|
||||
expect(app.loadGraphData).toHaveBeenCalledWith(
|
||||
initial,
|
||||
false,
|
||||
false,
|
||||
tracker.workflow,
|
||||
{
|
||||
checkForRerouteMigration: false,
|
||||
silentAssetErrors: true
|
||||
}
|
||||
)
|
||||
expect(tracker.activeState).toBe(initial)
|
||||
expect(tracker.redoQueue).toEqual([changed])
|
||||
expect(tracker._restoringState).toBe(false)
|
||||
})
|
||||
|
||||
it('clears restoring state when loading fails', async () => {
|
||||
const tracker = createTracker(createState(2))
|
||||
tracker.undoQueue.push(createState(1))
|
||||
vi.mocked(app.loadGraphData).mockRejectedValueOnce(
|
||||
new Error('load failed')
|
||||
)
|
||||
|
||||
await expect(tracker.undo()).rejects.toThrow('load failed')
|
||||
|
||||
expect(tracker._restoringState).toBe(false)
|
||||
})
|
||||
|
||||
it('does nothing when no previous state exists', async () => {
|
||||
const tracker = createTracker(createState(1))
|
||||
|
||||
await tracker.undo()
|
||||
|
||||
expect(app.loadGraphData).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('handles keyboard undo and redo shortcuts', async () => {
|
||||
const tracker = createTracker()
|
||||
const undo = vi.spyOn(tracker, 'undo').mockResolvedValue()
|
||||
const redo = vi.spyOn(tracker, 'redo').mockResolvedValue()
|
||||
|
||||
await expect(
|
||||
tracker.undoRedo(
|
||||
new KeyboardEvent('keydown', { key: 'z', ctrlKey: true })
|
||||
)
|
||||
).resolves.toBe(true)
|
||||
await expect(
|
||||
tracker.undoRedo(
|
||||
new KeyboardEvent('keydown', {
|
||||
key: 'z',
|
||||
ctrlKey: true,
|
||||
shiftKey: true
|
||||
})
|
||||
)
|
||||
).resolves.toBe(true)
|
||||
await expect(
|
||||
tracker.undoRedo(
|
||||
new KeyboardEvent('keydown', { key: 'y', metaKey: true })
|
||||
)
|
||||
).resolves.toBe(true)
|
||||
await expect(
|
||||
tracker.undoRedo(
|
||||
new KeyboardEvent('keydown', {
|
||||
key: 'z',
|
||||
ctrlKey: true,
|
||||
altKey: true
|
||||
})
|
||||
)
|
||||
).resolves.toBeUndefined()
|
||||
|
||||
expect(undo).toHaveBeenCalledOnce()
|
||||
expect(redo).toHaveBeenCalledTimes(2)
|
||||
})
|
||||
})
|
||||
|
||||
describe('checkState (deprecated)', () => {
|
||||
it('delegates to captureCanvasState', () => {
|
||||
const tracker = createTracker(createState(1))
|
||||
@@ -312,5 +580,389 @@ describe('ChangeTracker', () => {
|
||||
|
||||
expect(tracker.activeState).toEqual(changed)
|
||||
})
|
||||
|
||||
it('warns only once before delegating', () => {
|
||||
const tracker = createTracker(createState(1))
|
||||
const warn = vi.spyOn(console, 'warn').mockImplementation(() => undefined)
|
||||
|
||||
tracker.checkState()
|
||||
tracker.checkState()
|
||||
|
||||
expect(warn).toHaveBeenCalledOnce()
|
||||
warn.mockRestore()
|
||||
})
|
||||
})
|
||||
|
||||
describe('bindInput', () => {
|
||||
it('returns false for missing canvas or body elements', () => {
|
||||
expect(ChangeTracker.bindInput(null)).toBe(false)
|
||||
expect(ChangeTracker.bindInput(document.createElement('canvas'))).toBe(
|
||||
false
|
||||
)
|
||||
expect(ChangeTracker.bindInput(document.body)).toBe(false)
|
||||
})
|
||||
|
||||
it('captures state once when an input-like element changes', () => {
|
||||
const tracker = createTracker()
|
||||
const capture = vi.spyOn(tracker, 'captureCanvasState')
|
||||
const input = document.createElement('input')
|
||||
|
||||
expect(ChangeTracker.bindInput(input)).toBe(true)
|
||||
input.dispatchEvent(new Event('change'))
|
||||
input.dispatchEvent(new Event('change'))
|
||||
|
||||
expect(capture).toHaveBeenCalledOnce()
|
||||
})
|
||||
|
||||
it('binds textarea-like elements that expose an input handler slot', () => {
|
||||
const tracker = createTracker()
|
||||
const capture = vi.spyOn(tracker, 'captureCanvasState')
|
||||
const element = document.createElement('div') as HTMLElement & {
|
||||
oninput: unknown
|
||||
}
|
||||
element.oninput = null
|
||||
|
||||
expect(ChangeTracker.bindInput(element)).toBe(true)
|
||||
element.dispatchEvent(new Event('change'))
|
||||
|
||||
expect(capture).toHaveBeenCalledOnce()
|
||||
})
|
||||
})
|
||||
|
||||
describe('init', () => {
|
||||
it('captures changes from registered browser, graph, and API events', async () => {
|
||||
const windowListeners: ListenerMap = {}
|
||||
const documentListeners: ListenerMap = {}
|
||||
const windowAddSpy = vi
|
||||
.spyOn(window, 'addEventListener')
|
||||
.mockImplementation((type, listener) => {
|
||||
storeListener(windowListeners, type, listener)
|
||||
})
|
||||
const documentAddSpy = vi
|
||||
.spyOn(document, 'addEventListener')
|
||||
.mockImplementation((type, listener) => {
|
||||
storeListener(documentListeners, type, listener)
|
||||
})
|
||||
const rafSpy = vi
|
||||
.spyOn(window, 'requestAnimationFrame')
|
||||
.mockImplementation((callback) => {
|
||||
callback(0)
|
||||
return 1
|
||||
})
|
||||
|
||||
const processMouseUp = vi.fn(() => true)
|
||||
const prompt = vi.fn()
|
||||
const close = vi.fn(() => true)
|
||||
const originalProcessMouseUp = LGraphCanvas.prototype.processMouseUp
|
||||
const originalPrompt = LGraphCanvas.prototype.prompt
|
||||
const originalClose = LiteGraph.ContextMenu.prototype.close
|
||||
|
||||
LGraphCanvas.prototype.processMouseUp = processMouseUp
|
||||
LGraphCanvas.prototype.prompt = prompt
|
||||
LiteGraph.ContextMenu.prototype.close = close
|
||||
|
||||
try {
|
||||
ChangeTracker.init()
|
||||
const tracker = createTracker()
|
||||
const capture = vi.spyOn(tracker, 'captureCanvasState')
|
||||
|
||||
dispatchStored(windowListeners, 'mouseup', new MouseEvent('mouseup'))
|
||||
getApiListener('promptQueued')(
|
||||
new CustomEvent('promptQueued', {
|
||||
detail: {} as unknown as ExecutedWsMessage
|
||||
})
|
||||
)
|
||||
getApiListener('graphCleared')(
|
||||
new CustomEvent('graphCleared', {
|
||||
detail: {} as unknown as ExecutedWsMessage
|
||||
})
|
||||
)
|
||||
dispatchStored(
|
||||
documentListeners,
|
||||
'litegraph:canvas',
|
||||
new CustomEvent('litegraph:canvas', {
|
||||
detail: { subType: 'before-change' }
|
||||
})
|
||||
)
|
||||
dispatchStored(
|
||||
documentListeners,
|
||||
'litegraph:canvas',
|
||||
new CustomEvent('litegraph:canvas', {
|
||||
detail: { subType: 'after-change' }
|
||||
})
|
||||
)
|
||||
|
||||
expect(capture).toHaveBeenCalledTimes(4)
|
||||
|
||||
dispatchStored(
|
||||
windowListeners,
|
||||
'keydown',
|
||||
new KeyboardEvent('keydown', { key: 'Control' })
|
||||
)
|
||||
await flushAsyncFrame()
|
||||
dispatchStored(windowListeners, 'keyup', new KeyboardEvent('keyup'))
|
||||
|
||||
expect(capture).toHaveBeenCalledTimes(5)
|
||||
|
||||
const undoRedo = vi.spyOn(tracker, 'undoRedo').mockResolvedValue(true)
|
||||
dispatchStored(
|
||||
windowListeners,
|
||||
'keydown',
|
||||
new KeyboardEvent('keydown', { key: 'z', ctrlKey: true })
|
||||
)
|
||||
await flushAsyncFrame()
|
||||
|
||||
expect(undoRedo).toHaveBeenCalledOnce()
|
||||
expect(capture).toHaveBeenCalledTimes(5)
|
||||
|
||||
undoRedo.mockResolvedValue(undefined)
|
||||
dispatchStored(
|
||||
windowListeners,
|
||||
'keydown',
|
||||
new KeyboardEvent('keydown', { key: 'a' })
|
||||
)
|
||||
await flushAsyncFrame()
|
||||
|
||||
expect(capture).toHaveBeenCalledTimes(6)
|
||||
|
||||
const input = document.createElement('input')
|
||||
document.body.append(input)
|
||||
input.focus()
|
||||
dispatchStored(
|
||||
windowListeners,
|
||||
'keydown',
|
||||
new KeyboardEvent('keydown', { key: 'b' })
|
||||
)
|
||||
await flushAsyncFrame()
|
||||
input.remove()
|
||||
|
||||
expect(capture).toHaveBeenCalledTimes(6)
|
||||
|
||||
mockMaskEditorIsOpened.mockReturnValue(true)
|
||||
dispatchStored(
|
||||
windowListeners,
|
||||
'keydown',
|
||||
new KeyboardEvent('keydown', { key: 'c' })
|
||||
)
|
||||
await flushAsyncFrame()
|
||||
|
||||
expect(capture).toHaveBeenCalledTimes(6)
|
||||
|
||||
const canvas = {} as LGraphCanvas
|
||||
LGraphCanvas.prototype.processMouseUp.call(
|
||||
canvas,
|
||||
new MouseEvent('mouseup') as CanvasPointerEvent
|
||||
)
|
||||
|
||||
expect(processMouseUp).toHaveBeenCalledOnce()
|
||||
expect(capture).toHaveBeenCalledTimes(7)
|
||||
|
||||
const promptCallback = vi.fn()
|
||||
LGraphCanvas.prototype.prompt.call(
|
||||
canvas,
|
||||
'title',
|
||||
'value',
|
||||
promptCallback,
|
||||
new MouseEvent('mouseup') as CanvasPointerEvent
|
||||
)
|
||||
const extendedCallback = prompt.mock.calls[0]?.[2] as
|
||||
| ((value: string) => void)
|
||||
| undefined
|
||||
extendedCallback?.('updated')
|
||||
|
||||
expect(promptCallback).toHaveBeenCalledWith('updated')
|
||||
expect(capture).toHaveBeenCalledTimes(8)
|
||||
|
||||
LiteGraph.ContextMenu.prototype.close.call(
|
||||
{} as InstanceType<typeof LiteGraph.ContextMenu>,
|
||||
new MouseEvent('mouseup')
|
||||
)
|
||||
|
||||
expect(close).toHaveBeenCalledOnce()
|
||||
expect(capture).toHaveBeenCalledTimes(9)
|
||||
} finally {
|
||||
LGraphCanvas.prototype.processMouseUp = originalProcessMouseUp
|
||||
LGraphCanvas.prototype.prompt = originalPrompt
|
||||
LiteGraph.ContextMenu.prototype.close = originalClose
|
||||
windowAddSpy.mockRestore()
|
||||
documentAddSpy.mockRestore()
|
||||
rafSpy.mockRestore()
|
||||
}
|
||||
})
|
||||
|
||||
it('ignores repeat keydowns and missing active trackers', async () => {
|
||||
const windowListeners: ListenerMap = {}
|
||||
const windowAddSpy = vi
|
||||
.spyOn(window, 'addEventListener')
|
||||
.mockImplementation((type, listener) => {
|
||||
storeListener(windowListeners, type, listener)
|
||||
})
|
||||
const documentAddSpy = vi
|
||||
.spyOn(document, 'addEventListener')
|
||||
.mockImplementation(() => undefined)
|
||||
const rafSpy = vi
|
||||
.spyOn(window, 'requestAnimationFrame')
|
||||
.mockImplementation((callback) => {
|
||||
callback(0)
|
||||
return 1
|
||||
})
|
||||
|
||||
try {
|
||||
ChangeTracker.init()
|
||||
const tracker = createTracker()
|
||||
const capture = vi.spyOn(tracker, 'captureCanvasState')
|
||||
|
||||
dispatchStored(
|
||||
windowListeners,
|
||||
'keydown',
|
||||
new KeyboardEvent('keydown', { key: 'x', repeat: true })
|
||||
)
|
||||
await flushAsyncFrame()
|
||||
expect(capture).not.toHaveBeenCalled()
|
||||
|
||||
mockWorkflowStore.activeWorkflow = null
|
||||
dispatchStored(
|
||||
windowListeners,
|
||||
'keydown',
|
||||
new KeyboardEvent('keydown', { key: 'x' })
|
||||
)
|
||||
await flushAsyncFrame()
|
||||
expect(capture).not.toHaveBeenCalled()
|
||||
} finally {
|
||||
windowAddSpy.mockRestore()
|
||||
documentAddSpy.mockRestore()
|
||||
rafSpy.mockRestore()
|
||||
}
|
||||
})
|
||||
|
||||
it('stores executed outputs for the workflow that owns the prompt', () => {
|
||||
ChangeTracker.init()
|
||||
const tracker = createTracker()
|
||||
const executed = getApiListener('executed')
|
||||
mockExecutionStore.queuedJobs = {
|
||||
promptA: { workflow: { changeTracker: tracker } }
|
||||
}
|
||||
|
||||
executed(
|
||||
new CustomEvent('executed', {
|
||||
detail: {
|
||||
prompt_id: 'promptA',
|
||||
node: '1',
|
||||
output: { images: ['first'] }
|
||||
} as unknown as ExecutedWsMessage
|
||||
})
|
||||
)
|
||||
executed(
|
||||
new CustomEvent('executed', {
|
||||
detail: {
|
||||
prompt_id: 'promptA',
|
||||
node: '1',
|
||||
merge: true,
|
||||
output: { images: ['second'], text: ['caption'] }
|
||||
} as unknown as ExecutedWsMessage
|
||||
})
|
||||
)
|
||||
executed(
|
||||
new CustomEvent('executed', {
|
||||
detail: {
|
||||
prompt_id: 'missing',
|
||||
node: '2',
|
||||
output: { images: ['ignored'] }
|
||||
} as unknown as ExecutedWsMessage
|
||||
})
|
||||
)
|
||||
|
||||
expect(tracker.nodeOutputs).toEqual({
|
||||
1: { images: ['first', 'second'], text: ['caption'] }
|
||||
})
|
||||
})
|
||||
|
||||
it('replaces non-array executed outputs during merge updates', () => {
|
||||
ChangeTracker.init()
|
||||
const tracker = createTracker()
|
||||
const executed = getApiListener('executed')
|
||||
mockExecutionStore.queuedJobs = {
|
||||
promptA: { workflow: { changeTracker: tracker } }
|
||||
}
|
||||
|
||||
executed(
|
||||
new CustomEvent('executed', {
|
||||
detail: {
|
||||
prompt_id: 'promptA',
|
||||
node: '1',
|
||||
output: { value: 'old' }
|
||||
} as unknown as ExecutedWsMessage
|
||||
})
|
||||
)
|
||||
executed(
|
||||
new CustomEvent('executed', {
|
||||
detail: {
|
||||
prompt_id: 'promptA',
|
||||
node: '1',
|
||||
merge: true,
|
||||
output: { value: 'new' }
|
||||
} as unknown as ExecutedWsMessage
|
||||
})
|
||||
)
|
||||
|
||||
expect(tracker.nodeOutputs).toEqual({
|
||||
1: { value: 'new' }
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('graphEqual', () => {
|
||||
it('compares workflow nodes as an unordered set and ignores extra.ds', () => {
|
||||
const first = createState(2)
|
||||
const second = {
|
||||
...createState(),
|
||||
nodes: [...first.nodes].reverse(),
|
||||
links: first.links,
|
||||
groups: first.groups,
|
||||
extra: { ds: { scale: 2 } }
|
||||
} as unknown as ComfyWorkflowJSON
|
||||
|
||||
expect(ChangeTracker.graphEqual(first, first)).toBe(true)
|
||||
expect(ChangeTracker.graphEqual(first, second)).toBe(true)
|
||||
})
|
||||
|
||||
it('returns false for non-object values and meaningful graph differences', () => {
|
||||
const first = createState(1)
|
||||
const differentNodes = createState(2)
|
||||
const differentLinks = {
|
||||
...first,
|
||||
links: [[1, 1, 0, 2, 0, 'MODEL']]
|
||||
} as unknown as ComfyWorkflowJSON
|
||||
|
||||
expect(ChangeTracker.graphEqual(first, null as never)).toBe(false)
|
||||
expect(ChangeTracker.graphEqual(first, differentNodes)).toBe(false)
|
||||
expect(ChangeTracker.graphEqual(first, differentLinks)).toBe(false)
|
||||
})
|
||||
|
||||
it('returns false for extra properties other than viewport state', () => {
|
||||
const first = createState()
|
||||
const second = {
|
||||
...first,
|
||||
extra: { custom: true }
|
||||
} as unknown as ComfyWorkflowJSON
|
||||
|
||||
expect(ChangeTracker.graphEqual(first, second)).toBe(false)
|
||||
})
|
||||
|
||||
it.each([
|
||||
'floatingLinks',
|
||||
'reroutes',
|
||||
'groups',
|
||||
'definitions',
|
||||
'subgraphs'
|
||||
] as const)('returns false when %s differs', (key) => {
|
||||
const first = createState()
|
||||
const second = {
|
||||
...first,
|
||||
[key]: [{ id: 1 }]
|
||||
} as unknown as ComfyWorkflowJSON
|
||||
|
||||
expect(ChangeTracker.graphEqual(first, second)).toBe(false)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,11 +1,24 @@
|
||||
import { describe, expect, test, vi } from 'vitest'
|
||||
|
||||
import type { LGraph } from '@/lib/litegraph/src/litegraph'
|
||||
import { LGraphNode } from '@/lib/litegraph/src/litegraph'
|
||||
import { ComponentWidgetImpl, DOMWidgetImpl } from '@/scripts/domWidget'
|
||||
import {
|
||||
addWidget,
|
||||
ComponentWidgetImpl,
|
||||
DOMWidgetImpl,
|
||||
isComponentWidget,
|
||||
isDOMWidget
|
||||
} from '@/scripts/domWidget'
|
||||
|
||||
const { registerWidget, unregisterWidget } = vi.hoisted(() => ({
|
||||
registerWidget: vi.fn(),
|
||||
unregisterWidget: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('@/stores/domWidgetStore', () => ({
|
||||
useDomWidgetStore: () => ({
|
||||
unregisterWidget: vi.fn()
|
||||
registerWidget,
|
||||
unregisterWidget
|
||||
})
|
||||
}))
|
||||
|
||||
@@ -24,13 +37,11 @@ describe('DOMWidget Y Position Preservation', () => {
|
||||
options: {}
|
||||
})
|
||||
|
||||
// Set a specific Y position
|
||||
originalWidget.y = 66
|
||||
|
||||
const newNode = new LGraphNode('new-node')
|
||||
const clonedWidget = originalWidget.createCopyForNode(newNode)
|
||||
|
||||
// Verify Y position is preserved
|
||||
expect(clonedWidget.y).toBe(66)
|
||||
expect(clonedWidget.node).toBe(newNode)
|
||||
expect(clonedWidget.name).toBe('test-widget')
|
||||
@@ -48,13 +59,11 @@ describe('DOMWidget Y Position Preservation', () => {
|
||||
options: {}
|
||||
})
|
||||
|
||||
// Set a specific Y position
|
||||
originalWidget.y = 42
|
||||
|
||||
const newNode = new LGraphNode('new-node')
|
||||
const clonedWidget = originalWidget.createCopyForNode(newNode)
|
||||
|
||||
// Verify Y position is preserved
|
||||
expect(clonedWidget.y).toBe(42)
|
||||
expect(clonedWidget.node).toBe(newNode)
|
||||
expect(clonedWidget.element).toBe(mockElement)
|
||||
@@ -71,11 +80,9 @@ describe('DOMWidget Y Position Preservation', () => {
|
||||
options: {}
|
||||
})
|
||||
|
||||
// Don't explicitly set Y (should be 0 by default)
|
||||
const newNode = new LGraphNode('new-node')
|
||||
const clonedWidget = originalWidget.createCopyForNode(newNode)
|
||||
|
||||
// Verify Y position is preserved (should be 0)
|
||||
expect(clonedWidget.y).toBe(0)
|
||||
})
|
||||
})
|
||||
@@ -96,3 +103,271 @@ describe('BaseDOMWidgetImpl.isVisible', () => {
|
||||
expect(widget.isVisible()).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('DOMWidgetImpl', () => {
|
||||
test('identifies DOM and component widgets', () => {
|
||||
const node = new LGraphNode('test-node')
|
||||
const domWidget = new DOMWidgetImpl({
|
||||
node,
|
||||
name: 'dom',
|
||||
type: 'text',
|
||||
element: document.createElement('textarea'),
|
||||
options: {}
|
||||
})
|
||||
const componentWidget = new ComponentWidgetImpl({
|
||||
node,
|
||||
name: 'component',
|
||||
component: { template: '<div />' },
|
||||
inputSpec: { name: 'component', type: 'STRING' },
|
||||
options: {}
|
||||
})
|
||||
|
||||
expect(isDOMWidget(domWidget)).toBe(true)
|
||||
expect(isDOMWidget(componentWidget)).toBe(false)
|
||||
expect(isComponentWidget(componentWidget)).toBe(true)
|
||||
expect(isComponentWidget(domWidget)).toBe(false)
|
||||
})
|
||||
|
||||
test('uses option-backed values, callbacks, and margins', () => {
|
||||
const node = new LGraphNode('test-node')
|
||||
let value = 'initial'
|
||||
const setValue = vi.fn((next: string) => {
|
||||
value = next
|
||||
})
|
||||
const callback = vi.fn()
|
||||
const widget = new DOMWidgetImpl({
|
||||
node,
|
||||
name: 'text',
|
||||
type: 'text',
|
||||
element: document.createElement('textarea'),
|
||||
options: {
|
||||
getValue: () => value,
|
||||
setValue,
|
||||
margin: 4
|
||||
}
|
||||
})
|
||||
widget.callback = callback
|
||||
|
||||
widget.value = 'next'
|
||||
|
||||
expect(widget.value).toBe('next')
|
||||
expect(widget.margin).toBe(4)
|
||||
expect(setValue).toHaveBeenCalledWith('next')
|
||||
expect(callback).toHaveBeenCalledWith('next')
|
||||
})
|
||||
|
||||
test('uses default value and margin when options do not provide them', () => {
|
||||
const node = new LGraphNode('test-node')
|
||||
const widget = new DOMWidgetImpl({
|
||||
node,
|
||||
name: 'text',
|
||||
type: 'text',
|
||||
element: document.createElement('textarea'),
|
||||
options: {}
|
||||
})
|
||||
|
||||
expect(widget.value).toBe('')
|
||||
expect(widget.margin).toBe(10)
|
||||
})
|
||||
|
||||
test('draws zoom placeholders and delegates visible draws', () => {
|
||||
const node = new LGraphNode('test-node')
|
||||
vi.spyOn(node, 'isWidgetVisible').mockReturnValue(true)
|
||||
const onDraw = vi.fn()
|
||||
const widget = new DOMWidgetImpl({
|
||||
node,
|
||||
name: 'text',
|
||||
type: 'text',
|
||||
element: document.createElement('textarea'),
|
||||
options: {
|
||||
hideOnZoom: true,
|
||||
margin: 5,
|
||||
onDraw
|
||||
}
|
||||
})
|
||||
const ctx = {
|
||||
beginPath: vi.fn(),
|
||||
fill: vi.fn(),
|
||||
fillStyle: '#000',
|
||||
rect: vi.fn()
|
||||
} as unknown as CanvasRenderingContext2D
|
||||
|
||||
widget.draw(ctx, node, 100, 10, 40, true)
|
||||
|
||||
expect(ctx.rect).toHaveBeenCalledWith(5, 15, 90, 30)
|
||||
expect(ctx.fill).toHaveBeenCalledOnce()
|
||||
expect(ctx.fillStyle).toBe('#000')
|
||||
expect(onDraw).toHaveBeenCalledWith(widget)
|
||||
})
|
||||
|
||||
test('skips placeholder drawing when hidden', () => {
|
||||
const node = new LGraphNode('test-node')
|
||||
vi.spyOn(node, 'isWidgetVisible').mockReturnValue(false)
|
||||
const onDraw = vi.fn()
|
||||
const widget = new DOMWidgetImpl({
|
||||
node,
|
||||
name: 'text',
|
||||
type: 'text',
|
||||
element: document.createElement('textarea'),
|
||||
options: {
|
||||
hideOnZoom: true,
|
||||
onDraw
|
||||
}
|
||||
})
|
||||
const ctx = {
|
||||
beginPath: vi.fn(),
|
||||
fill: vi.fn(),
|
||||
fillStyle: '#000',
|
||||
rect: vi.fn()
|
||||
} as unknown as CanvasRenderingContext2D
|
||||
|
||||
widget.draw(ctx, node, 100, 10, 40, true)
|
||||
|
||||
expect(ctx.rect).not.toHaveBeenCalled()
|
||||
expect(onDraw).toHaveBeenCalledWith(widget)
|
||||
})
|
||||
|
||||
test('computes hidden, option, percent, and fallback layout sizes', () => {
|
||||
const node = new LGraphNode('test-node')
|
||||
node.size = [100, 200]
|
||||
const hiddenWidget = new DOMWidgetImpl({
|
||||
node,
|
||||
name: 'hidden',
|
||||
type: 'hidden',
|
||||
element: document.createElement('textarea'),
|
||||
options: {}
|
||||
})
|
||||
const optionWidget = new DOMWidgetImpl({
|
||||
node,
|
||||
name: 'option',
|
||||
type: 'text',
|
||||
element: document.createElement('textarea'),
|
||||
options: {
|
||||
getMinHeight: () => 11,
|
||||
getMaxHeight: () => 88,
|
||||
getHeight: () => 44
|
||||
}
|
||||
})
|
||||
const percentWidget = new DOMWidgetImpl({
|
||||
node,
|
||||
name: 'percent',
|
||||
type: 'text',
|
||||
element: document.createElement('textarea'),
|
||||
options: {
|
||||
getMinHeight: () => 10,
|
||||
getMaxHeight: () => 60,
|
||||
getHeight: () => '25%'
|
||||
}
|
||||
})
|
||||
const fallbackWidget = new DOMWidgetImpl({
|
||||
node,
|
||||
name: 'fallback',
|
||||
type: 'text',
|
||||
element: document.createElement('textarea'),
|
||||
options: {
|
||||
getHeight: () => 40
|
||||
}
|
||||
})
|
||||
|
||||
expect(hiddenWidget.computeLayoutSize(node)).toEqual({
|
||||
minHeight: 0,
|
||||
maxHeight: 0,
|
||||
minWidth: 0
|
||||
})
|
||||
expect(optionWidget.computeLayoutSize(node)).toEqual({
|
||||
minHeight: 11,
|
||||
maxHeight: 88,
|
||||
minWidth: 0
|
||||
})
|
||||
expect(percentWidget.computeLayoutSize(node)).toEqual({
|
||||
minHeight: 10,
|
||||
maxHeight: 60,
|
||||
minWidth: 0
|
||||
})
|
||||
expect(fallbackWidget.computeLayoutSize(node)).toEqual({
|
||||
minHeight: 40,
|
||||
maxHeight: undefined,
|
||||
minWidth: 0
|
||||
})
|
||||
})
|
||||
|
||||
test('registers widgets immediately and through node lifecycle callbacks', () => {
|
||||
registerWidget.mockClear()
|
||||
unregisterWidget.mockClear()
|
||||
const node = new LGraphNode('test-node')
|
||||
node.graph = {} as LGraph
|
||||
const beforeResize = vi.fn()
|
||||
const afterResize = vi.fn()
|
||||
const widget = new DOMWidgetImpl({
|
||||
node,
|
||||
name: 'text',
|
||||
type: 'text',
|
||||
element: document.createElement('textarea'),
|
||||
options: {
|
||||
beforeResize,
|
||||
afterResize
|
||||
}
|
||||
})
|
||||
vi.spyOn(node, 'addCustomWidget')
|
||||
|
||||
addWidget(node, widget)
|
||||
node.onAdded?.(node.graph)
|
||||
node.onResize?.([0, 0])
|
||||
node.onRemoved?.()
|
||||
|
||||
expect(node.addCustomWidget).toHaveBeenCalledWith(widget)
|
||||
expect(registerWidget).toHaveBeenCalledWith(widget)
|
||||
expect(registerWidget).toHaveBeenCalledTimes(2)
|
||||
expect(beforeResize).toHaveBeenCalledWith(node)
|
||||
expect(afterResize).toHaveBeenCalledWith(node)
|
||||
expect(unregisterWidget).toHaveBeenCalledWith(widget.id)
|
||||
})
|
||||
|
||||
test('computes component layout and serializes raw values', () => {
|
||||
const node = new LGraphNode('test-node')
|
||||
const value = { nested: true }
|
||||
const widget = new ComponentWidgetImpl({
|
||||
node,
|
||||
name: 'component',
|
||||
component: { template: '<div />' },
|
||||
inputSpec: { name: 'component', type: 'STRING' },
|
||||
options: {
|
||||
getValue: () => value,
|
||||
getMinHeight: () => 12,
|
||||
getMaxHeight: () => 48
|
||||
}
|
||||
})
|
||||
|
||||
expect(widget.computeLayoutSize()).toEqual({
|
||||
minHeight: 12,
|
||||
maxHeight: 48,
|
||||
minWidth: 0
|
||||
})
|
||||
expect(widget.serializeValue()).toEqual({ nested: true })
|
||||
})
|
||||
|
||||
test('adds DOM widgets through LGraphNode prototype helper', () => {
|
||||
const node = new LGraphNode('test-node')
|
||||
const element = document.createElement('textarea')
|
||||
let value = 'initial'
|
||||
const setValue = vi.fn((next: string) => {
|
||||
value = next
|
||||
})
|
||||
vi.spyOn(node, 'addCustomWidget')
|
||||
|
||||
const widget = node.addDOMWidget('text', 'textarea', element, {
|
||||
getValue: () => value,
|
||||
setValue
|
||||
})
|
||||
const callback = vi.fn()
|
||||
widget.callback = callback
|
||||
widget.value = 'next'
|
||||
|
||||
expect(node.addCustomWidget).toHaveBeenCalledWith(widget)
|
||||
expect(widget.element).toBe(element)
|
||||
expect(widget.options.hideOnZoom).toBe(true)
|
||||
expect(widget.value).toBe('next')
|
||||
expect(setValue).toHaveBeenCalledWith('next')
|
||||
expect(callback).toHaveBeenCalledWith('next')
|
||||
})
|
||||
})
|
||||
|
||||
106
src/scripts/errorNodeWidgets.test.ts
Normal file
106
src/scripts/errorNodeWidgets.test.ts
Normal file
@@ -0,0 +1,106 @@
|
||||
import { fromAny, fromPartial } from '@total-typescript/shoehorn'
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { LGraphNode } from '@/lib/litegraph/src/litegraph'
|
||||
import type { IBaseWidget } from '@/lib/litegraph/src/types/widgets'
|
||||
|
||||
interface TestWidgetOptions {
|
||||
name: string
|
||||
type: string
|
||||
default?: unknown
|
||||
multiline?: boolean
|
||||
}
|
||||
|
||||
function createWidgetFactory() {
|
||||
return (node: LGraphNode, options: TestWidgetOptions): IBaseWidget => {
|
||||
const widget: IBaseWidget = fromAny({
|
||||
name: options.name,
|
||||
type: options.type,
|
||||
value: options.default,
|
||||
options
|
||||
})
|
||||
node.widgets = [...(node.widgets ?? []), widget]
|
||||
return widget
|
||||
}
|
||||
}
|
||||
|
||||
vi.mock(
|
||||
'@/renderer/extensions/vueNodes/widgets/composables/useStringWidget',
|
||||
() => ({
|
||||
useStringWidget: () => createWidgetFactory()
|
||||
})
|
||||
)
|
||||
|
||||
vi.mock(
|
||||
'@/renderer/extensions/vueNodes/widgets/composables/useFloatWidget',
|
||||
() => ({
|
||||
useFloatWidget: () => createWidgetFactory()
|
||||
})
|
||||
)
|
||||
|
||||
vi.mock(
|
||||
'@/renderer/extensions/vueNodes/widgets/composables/useBooleanWidget',
|
||||
() => ({
|
||||
useBooleanWidget: () => createWidgetFactory()
|
||||
})
|
||||
)
|
||||
|
||||
import './errorNodeWidgets'
|
||||
|
||||
describe('errorNodeWidgets', () => {
|
||||
it('restores widgets from serialized values on error nodes', () => {
|
||||
const node = new LGraphNode('BrokenNode')
|
||||
const longText = 'serialized value with more than twenty chars'
|
||||
node.has_errors = true
|
||||
|
||||
node.onConfigure?.(
|
||||
fromAny({
|
||||
widgets_values: {
|
||||
length: 5,
|
||||
0: 'short text',
|
||||
1: longText,
|
||||
2: 12,
|
||||
3: true,
|
||||
4: { nested: 'value' }
|
||||
}
|
||||
})
|
||||
)
|
||||
|
||||
expect(node.widgets).toHaveLength(5)
|
||||
expect(node.widgets?.map((widget) => widget.name)).toEqual([
|
||||
'UNKNOWN',
|
||||
'UNKNOWN_1',
|
||||
'UNKNOWN_2',
|
||||
'UNKNOWN_3',
|
||||
'UNKNOWN_4'
|
||||
])
|
||||
expect(node.widgets?.map((widget) => widget.label)).toEqual([
|
||||
'UNKNOWN',
|
||||
'UNKNOWN',
|
||||
'UNKNOWN',
|
||||
'UNKNOWN',
|
||||
'UNKNOWN'
|
||||
])
|
||||
expect(node.widgets?.map((widget) => widget.value)).toEqual([
|
||||
'short text',
|
||||
longText,
|
||||
12,
|
||||
true,
|
||||
'{"nested":"value"}'
|
||||
])
|
||||
expect(node.serialize_widgets).toBe(true)
|
||||
})
|
||||
|
||||
it('leaves normal nodes unchanged', () => {
|
||||
const node = new LGraphNode('HealthyNode')
|
||||
|
||||
node.onConfigure?.(
|
||||
fromPartial({
|
||||
widgets_values: ['ignored']
|
||||
})
|
||||
)
|
||||
|
||||
expect(node.widgets).toBeUndefined()
|
||||
expect(node.serialize_widgets).toBeUndefined()
|
||||
})
|
||||
})
|
||||
@@ -7,7 +7,8 @@ import {
|
||||
EXPECTED_PROMPT_NAN_COERCED,
|
||||
EXPECTED_WORKFLOW,
|
||||
mockFileReaderAbort,
|
||||
mockFileReaderError
|
||||
mockFileReaderError,
|
||||
mockFileReaderResult
|
||||
} from './__fixtures__/helpers'
|
||||
import { getFromAvifFile } from './avif'
|
||||
|
||||
@@ -83,6 +84,11 @@ describe('AVIF metadata', () => {
|
||||
mockFileReaderAbort('readAsArrayBuffer')
|
||||
expect(await getFromAvifFile(file)).toEqual({})
|
||||
})
|
||||
|
||||
it('resolves empty when the FileReader load has no result', async () => {
|
||||
mockFileReaderResult('readAsArrayBuffer', null)
|
||||
expect(await getFromAvifFile(file)).toEqual({})
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -139,7 +145,8 @@ const buildInfeBox = (
|
||||
itemType: string,
|
||||
version = 2
|
||||
): Uint8Array => {
|
||||
const bodySize = 4 + 2 + 2 + 4 + 1 + 1
|
||||
const itemIdSize = version === 2 ? 2 : version >= 3 ? 4 : 0
|
||||
const bodySize = 4 + itemIdSize + (version >= 2 ? 2 + 4 + 1 + 1 : 0)
|
||||
const totalSize = 8 + bodySize
|
||||
const buf = new Uint8Array(totalSize)
|
||||
const dv = new DataView(buf.buffer)
|
||||
@@ -147,22 +154,36 @@ const buildInfeBox = (
|
||||
buf.set(new TextEncoder().encode('infe'), 4)
|
||||
buf[8] = version
|
||||
if (version >= 2) {
|
||||
setU16BE(dv, 12, itemId)
|
||||
setU16BE(dv, 14, 0)
|
||||
buf.set(new TextEncoder().encode(itemType.padEnd(4).slice(0, 4)), 16)
|
||||
let p = 12
|
||||
if (version === 2) {
|
||||
setU16BE(dv, p, itemId)
|
||||
p += 2
|
||||
} else {
|
||||
setU32BE(dv, p, itemId)
|
||||
p += 4
|
||||
}
|
||||
setU16BE(dv, p, 0)
|
||||
p += 2
|
||||
buf.set(new TextEncoder().encode(itemType.padEnd(4).slice(0, 4)), p)
|
||||
}
|
||||
return buf
|
||||
}
|
||||
|
||||
const buildIinfBox = (infeBoxes: Uint8Array[]): Uint8Array => {
|
||||
const bodySize = 4 + 2 + infeBoxes.reduce((s, b) => s + b.length, 0)
|
||||
const buildIinfBox = (infeBoxes: Uint8Array[], version = 0): Uint8Array => {
|
||||
const countSize = version === 0 ? 2 : 4
|
||||
const bodySize = 4 + countSize + infeBoxes.reduce((s, b) => s + b.length, 0)
|
||||
const totalSize = 8 + bodySize
|
||||
const buf = new Uint8Array(totalSize)
|
||||
const dv = new DataView(buf.buffer)
|
||||
setU32BE(dv, 0, totalSize)
|
||||
buf.set(new TextEncoder().encode('iinf'), 4)
|
||||
setU16BE(dv, 12, infeBoxes.length)
|
||||
let off = 14
|
||||
buf[8] = version
|
||||
if (version === 0) {
|
||||
setU16BE(dv, 12, infeBoxes.length)
|
||||
} else {
|
||||
setU32BE(dv, 12, infeBoxes.length)
|
||||
}
|
||||
let off = 8 + 4 + countSize
|
||||
for (const ib of infeBoxes) {
|
||||
buf.set(ib, off)
|
||||
off += ib.length
|
||||
@@ -170,31 +191,91 @@ const buildIinfBox = (infeBoxes: Uint8Array[]): Uint8Array => {
|
||||
return buf
|
||||
}
|
||||
|
||||
interface IlocItem {
|
||||
itemId: number
|
||||
extentOffset: number
|
||||
extentLength: number
|
||||
extents?: Array<{ extentOffset: number; extentLength: number }>
|
||||
}
|
||||
|
||||
const buildIlocBox = (
|
||||
items: { itemId: number; extentOffset: number; extentLength: number }[]
|
||||
items: IlocItem[],
|
||||
{
|
||||
version = 0,
|
||||
baseOffsetSize = 0,
|
||||
indexSize = 0
|
||||
}: { version?: number; baseOffsetSize?: number; indexSize?: number } = {}
|
||||
): Uint8Array => {
|
||||
const perItemSize = 2 + 2 + 0 + 2 + (4 + 4)
|
||||
const bodySize = 4 + 1 + 1 + 2 + items.length * perItemSize
|
||||
const itemCountSize = version < 2 ? 2 : 4
|
||||
const itemIdSize = version < 2 ? 2 : 4
|
||||
const constructionMethodSize = version === 1 || version === 2 ? 2 : 0
|
||||
const itemSizes = items.map((item) => {
|
||||
const extents = item.extents ?? [
|
||||
{
|
||||
extentOffset: item.extentOffset,
|
||||
extentLength: item.extentLength
|
||||
}
|
||||
]
|
||||
return (
|
||||
itemIdSize +
|
||||
constructionMethodSize +
|
||||
2 +
|
||||
baseOffsetSize +
|
||||
2 +
|
||||
extents.length * (indexSize + 4 + 4)
|
||||
)
|
||||
})
|
||||
const bodySize =
|
||||
4 + 1 + 1 + itemCountSize + itemSizes.reduce((sum, size) => sum + size, 0)
|
||||
const totalSize = 8 + bodySize
|
||||
const buf = new Uint8Array(totalSize)
|
||||
const dv = new DataView(buf.buffer)
|
||||
setU32BE(dv, 0, totalSize)
|
||||
buf.set(new TextEncoder().encode('iloc'), 4)
|
||||
buf[8] = version
|
||||
buf[12] = 0x44
|
||||
buf[13] = 0x00
|
||||
setU16BE(dv, 14, items.length)
|
||||
let p = 16
|
||||
for (const it of items) {
|
||||
setU16BE(dv, p, it.itemId)
|
||||
buf[13] = (baseOffsetSize << 4) | indexSize
|
||||
let p = 14
|
||||
if (version < 2) {
|
||||
setU16BE(dv, p, items.length)
|
||||
p += 2
|
||||
} else {
|
||||
setU32BE(dv, p, items.length)
|
||||
p += 4
|
||||
}
|
||||
for (const it of items) {
|
||||
if (version < 2) {
|
||||
setU16BE(dv, p, it.itemId)
|
||||
p += 2
|
||||
} else {
|
||||
setU32BE(dv, p, it.itemId)
|
||||
p += 4
|
||||
}
|
||||
if (version === 1 || version === 2) {
|
||||
setU16BE(dv, p, 0)
|
||||
p += 2
|
||||
}
|
||||
setU16BE(dv, p, 0)
|
||||
p += 2
|
||||
setU16BE(dv, p, 1)
|
||||
if (baseOffsetSize > 0) {
|
||||
setU32BE(dv, p, 0)
|
||||
p += baseOffsetSize
|
||||
}
|
||||
const extents = it.extents ?? [
|
||||
{
|
||||
extentOffset: it.extentOffset,
|
||||
extentLength: it.extentLength
|
||||
}
|
||||
]
|
||||
setU16BE(dv, p, extents.length)
|
||||
p += 2
|
||||
setU32BE(dv, p, it.extentOffset)
|
||||
p += 4
|
||||
setU32BE(dv, p, it.extentLength)
|
||||
p += 4
|
||||
for (const extent of extents) {
|
||||
p += indexSize
|
||||
setU32BE(dv, p, extent.extentOffset)
|
||||
p += 4
|
||||
setU32BE(dv, p, extent.extentLength)
|
||||
p += 4
|
||||
}
|
||||
}
|
||||
return buf
|
||||
}
|
||||
@@ -231,7 +312,13 @@ interface BuildAvifOpts {
|
||||
ftypBrand?: string
|
||||
omitMeta?: boolean
|
||||
omitIloc?: boolean
|
||||
iinfVersion?: number
|
||||
infeVersion?: number
|
||||
ilocVersion?: number
|
||||
ilocBaseOffsetSize?: number
|
||||
ilocIndexSize?: number
|
||||
ilocExtents?: Array<{ extentOffset: number; extentLength: number }>
|
||||
rawExifData?: Uint8Array
|
||||
}
|
||||
|
||||
const buildAvifFile = (opts: BuildAvifOpts = {}): ArrayBuffer => {
|
||||
@@ -242,7 +329,13 @@ const buildAvifFile = (opts: BuildAvifOpts = {}): ArrayBuffer => {
|
||||
ftypBrand = 'avif',
|
||||
omitMeta = false,
|
||||
omitIloc = false,
|
||||
infeVersion = 2
|
||||
iinfVersion = 0,
|
||||
infeVersion = 2,
|
||||
ilocVersion = 0,
|
||||
ilocBaseOffsetSize = 0,
|
||||
ilocIndexSize = 0,
|
||||
ilocExtents,
|
||||
rawExifData
|
||||
} = opts
|
||||
|
||||
const ftyp = buildFtypBox(ftypBrand)
|
||||
@@ -250,19 +343,43 @@ const buildAvifFile = (opts: BuildAvifOpts = {}): ArrayBuffer => {
|
||||
return ftyp.slice().buffer as ArrayBuffer
|
||||
}
|
||||
|
||||
const exifData = buildExifBlob(exifEntries, endian)
|
||||
const exifData = rawExifData ?? buildExifBlob(exifEntries, endian)
|
||||
const infe = buildInfeBox(1, itemType, infeVersion)
|
||||
const iinf = buildIinfBox([infe])
|
||||
const iinf = buildIinfBox([infe], iinfVersion)
|
||||
|
||||
const realIloc = buildIlocBox([
|
||||
{ itemId: 1, extentOffset: 0, extentLength: exifData.length }
|
||||
])
|
||||
const realIloc = buildIlocBox(
|
||||
[
|
||||
{
|
||||
itemId: 1,
|
||||
extentOffset: 0,
|
||||
extentLength: exifData.length,
|
||||
extents: ilocExtents
|
||||
}
|
||||
],
|
||||
{
|
||||
version: ilocVersion,
|
||||
baseOffsetSize: ilocBaseOffsetSize,
|
||||
indexSize: ilocIndexSize
|
||||
}
|
||||
)
|
||||
const metaSize = 8 + 4 + iinf.length + (omitIloc ? 0 : realIloc.length)
|
||||
const exifOffset = ftyp.length + metaSize
|
||||
|
||||
const finalIloc = buildIlocBox([
|
||||
{ itemId: 1, extentOffset: exifOffset, extentLength: exifData.length }
|
||||
])
|
||||
const finalIloc = buildIlocBox(
|
||||
[
|
||||
{
|
||||
itemId: 1,
|
||||
extentOffset: exifOffset,
|
||||
extentLength: exifData.length,
|
||||
extents: ilocExtents
|
||||
}
|
||||
],
|
||||
{
|
||||
version: ilocVersion,
|
||||
baseOffsetSize: ilocBaseOffsetSize,
|
||||
indexSize: ilocIndexSize
|
||||
}
|
||||
)
|
||||
const finalInner = omitIloc ? [iinf] : [iinf, finalIloc]
|
||||
const meta = buildMetaBox(finalInner)
|
||||
|
||||
@@ -319,6 +436,52 @@ describe('getFromAvifFile', () => {
|
||||
expect(result.workflow).toBe(JSON.stringify(JSON.parse(workflow)))
|
||||
})
|
||||
|
||||
it('extracts EXIF metadata from versioned item info and location boxes', async () => {
|
||||
const workflow = '{"versioned":true}'
|
||||
const file = fileFromBuffer(
|
||||
buildAvifFile({
|
||||
exifEntries: [`workflow:${workflow}`],
|
||||
iinfVersion: 1,
|
||||
infeVersion: 3,
|
||||
ilocVersion: 2,
|
||||
ilocBaseOffsetSize: 4,
|
||||
ilocIndexSize: 4
|
||||
})
|
||||
)
|
||||
|
||||
const result = await getFromAvifFile(file)
|
||||
|
||||
expect(result.workflow).toBe(JSON.stringify(JSON.parse(workflow)))
|
||||
})
|
||||
|
||||
it('returns {} when the Exif item has no extents', async () => {
|
||||
const file = fileFromBuffer(
|
||||
buildAvifFile({
|
||||
exifEntries: ['workflow:{}'],
|
||||
ilocExtents: []
|
||||
})
|
||||
)
|
||||
|
||||
const result = await getFromAvifFile(file)
|
||||
|
||||
expect(result).toEqual({})
|
||||
})
|
||||
|
||||
it('returns {} when the Exif payload has no TIFF header', async () => {
|
||||
const file = fileFromBuffer(
|
||||
buildAvifFile({
|
||||
rawExifData: new TextEncoder().encode('not tiff data')
|
||||
})
|
||||
)
|
||||
|
||||
const result = await getFromAvifFile(file)
|
||||
|
||||
expect(result).toEqual({})
|
||||
expect(console.log).toHaveBeenCalledWith(
|
||||
'Warning: TIFF header not found in EXIF data.'
|
||||
)
|
||||
})
|
||||
|
||||
it('returns {} when AVIF major brand is not "avif"', async () => {
|
||||
const file = fileFromBuffer(
|
||||
buildAvifFile({ exifEntries: ['workflow:{}'], ftypBrand: 'heic' })
|
||||
|
||||
@@ -7,7 +7,8 @@ import {
|
||||
EXPECTED_PROMPT_NAN_COERCED,
|
||||
EXPECTED_WORKFLOW,
|
||||
mockFileReaderAbort,
|
||||
mockFileReaderError
|
||||
mockFileReaderError,
|
||||
mockFileReaderResult
|
||||
} from './__fixtures__/helpers'
|
||||
import { getFromWebmFile } from './ebml'
|
||||
|
||||
@@ -16,6 +17,56 @@ const nanFixturePath = path.resolve(
|
||||
__dirname,
|
||||
'__fixtures__/with_nan_metadata.webm'
|
||||
)
|
||||
const WEBM_SIGNATURE = new Uint8Array([0x1a, 0x45, 0xdf, 0xa3])
|
||||
const SIMPLE_TAG = new Uint8Array([0x67, 0xc8])
|
||||
const TAG_NAME = new Uint8Array([0x45, 0xa3])
|
||||
const TAG_VALUE = new Uint8Array([0x44, 0x87])
|
||||
const encoder = new TextEncoder()
|
||||
|
||||
function concatBytes(...chunks: Uint8Array[]) {
|
||||
const totalLength = chunks.reduce((sum, chunk) => sum + chunk.length, 0)
|
||||
const result = new Uint8Array(totalLength)
|
||||
let offset = 0
|
||||
|
||||
for (const chunk of chunks) {
|
||||
result.set(chunk, offset)
|
||||
offset += chunk.length
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
function bytes(...values: number[]) {
|
||||
return new Uint8Array(values)
|
||||
}
|
||||
|
||||
function vint(value: number) {
|
||||
return bytes(0x80 | value)
|
||||
}
|
||||
|
||||
function element(id: Uint8Array, value: string) {
|
||||
const encoded = encoder.encode(value)
|
||||
return concatBytes(id, vint(encoded.length), encoded)
|
||||
}
|
||||
|
||||
function simpleTag(name: string, value: string, useTwoByteSize = false) {
|
||||
const payload = concatBytes(
|
||||
element(TAG_NAME, name),
|
||||
element(TAG_VALUE, value)
|
||||
)
|
||||
const size = useTwoByteSize
|
||||
? bytes(0x40, payload.length)
|
||||
: vint(payload.length)
|
||||
return concatBytes(SIMPLE_TAG, size, payload)
|
||||
}
|
||||
|
||||
async function readWebm(bytes: Uint8Array) {
|
||||
return getFromWebmFile(
|
||||
new File([bytes as Uint8Array<ArrayBuffer>], 'test.webm', {
|
||||
type: 'video/webm'
|
||||
})
|
||||
)
|
||||
}
|
||||
|
||||
describe('WebM/EBML metadata', () => {
|
||||
it('extracts workflow and prompt from EBML SimpleTag elements', async () => {
|
||||
@@ -46,6 +97,89 @@ describe('WebM/EBML metadata', () => {
|
||||
expect(result).toEqual({})
|
||||
})
|
||||
|
||||
it('extracts plain string tags and trims null-terminated values', async () => {
|
||||
const result = await readWebm(
|
||||
concatBytes(
|
||||
WEBM_SIGNATURE,
|
||||
simpleTag('\0Comment', ' hello \0ignored', true)
|
||||
)
|
||||
)
|
||||
|
||||
expect(result.comment).toBe('hello')
|
||||
})
|
||||
|
||||
it('ignores prompt tags whose value is not complete JSON', async () => {
|
||||
const result = await readWebm(
|
||||
concatBytes(WEBM_SIGNATURE, simpleTag('PROMPT', '{not-json'))
|
||||
)
|
||||
|
||||
expect(result.prompt).toBeUndefined()
|
||||
})
|
||||
|
||||
it('ignores prompt tags whose value has no JSON object', async () => {
|
||||
const result = await readWebm(
|
||||
concatBytes(WEBM_SIGNATURE, simpleTag('PROMPT', 'not-json'))
|
||||
)
|
||||
|
||||
expect(result.prompt).toBeUndefined()
|
||||
})
|
||||
|
||||
it('parses the first complete JSON object from a prompt tag', async () => {
|
||||
const result = await readWebm(
|
||||
concatBytes(
|
||||
WEBM_SIGNATURE,
|
||||
simpleTag('PROMPT', 'prefix {"outer":{"inner":1}} trailing')
|
||||
)
|
||||
)
|
||||
|
||||
expect(result.prompt).toEqual({ outer: { inner: 1 } })
|
||||
})
|
||||
|
||||
it('ignores tags whose name has no readable text', async () => {
|
||||
const payload = concatBytes(
|
||||
concatBytes(TAG_NAME, vint(2), bytes(0, 1)),
|
||||
element(TAG_VALUE, 'value')
|
||||
)
|
||||
|
||||
const result = await readWebm(
|
||||
concatBytes(WEBM_SIGNATURE, SIMPLE_TAG, vint(payload.length), payload)
|
||||
)
|
||||
|
||||
expect(result).toEqual({})
|
||||
})
|
||||
|
||||
it('ignores tag elements with zero-sized names', async () => {
|
||||
const payload = concatBytes(TAG_NAME, vint(0), element(TAG_VALUE, 'value'))
|
||||
|
||||
const result = await readWebm(
|
||||
concatBytes(WEBM_SIGNATURE, SIMPLE_TAG, vint(payload.length), payload)
|
||||
)
|
||||
|
||||
expect(result).toEqual({})
|
||||
})
|
||||
|
||||
it('ignores malformed SimpleTag encodings', async () => {
|
||||
const nameOnly = element(TAG_NAME, 'comment')
|
||||
|
||||
await expect(
|
||||
readWebm(concatBytes(WEBM_SIGNATURE, SIMPLE_TAG))
|
||||
).resolves.toEqual({})
|
||||
await expect(
|
||||
readWebm(concatBytes(WEBM_SIGNATURE, SIMPLE_TAG, bytes(0x00)))
|
||||
).resolves.toEqual({})
|
||||
await expect(
|
||||
readWebm(concatBytes(WEBM_SIGNATURE, SIMPLE_TAG, bytes(0x7f, 0xff)))
|
||||
).resolves.toEqual({})
|
||||
await expect(
|
||||
readWebm(concatBytes(WEBM_SIGNATURE, SIMPLE_TAG, vint(10), TAG_NAME))
|
||||
).resolves.toEqual({})
|
||||
await expect(
|
||||
readWebm(
|
||||
concatBytes(WEBM_SIGNATURE, SIMPLE_TAG, vint(nameOnly.length), nameOnly)
|
||||
)
|
||||
).resolves.toEqual({})
|
||||
})
|
||||
|
||||
describe('FileReader failure modes', () => {
|
||||
afterEach(() => vi.restoreAllMocks())
|
||||
|
||||
@@ -60,5 +194,10 @@ describe('WebM/EBML metadata', () => {
|
||||
mockFileReaderAbort('readAsArrayBuffer')
|
||||
expect(await getFromWebmFile(file)).toEqual({})
|
||||
})
|
||||
|
||||
it('resolves empty when the FileReader load has no result', async () => {
|
||||
mockFileReaderResult('readAsArrayBuffer', null)
|
||||
expect(await getFromWebmFile(file)).toEqual({})
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import { fromAny } from '@total-typescript/shoehorn'
|
||||
import { afterEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { ASCII, GltfSizeBytes } from '@/types/metadataTypes'
|
||||
@@ -5,7 +6,8 @@ import { ASCII, GltfSizeBytes } from '@/types/metadataTypes'
|
||||
import {
|
||||
EXPECTED_PROMPT_NAN_COERCED,
|
||||
mockFileReaderAbort,
|
||||
mockFileReaderError
|
||||
mockFileReaderError,
|
||||
mockFileReaderResult
|
||||
} from './__fixtures__/helpers'
|
||||
import { getGltfBinaryMetadata } from './gltf'
|
||||
|
||||
@@ -16,11 +18,15 @@ describe('GLTF binary metadata parser', () => {
|
||||
return { header, headerView }
|
||||
}
|
||||
|
||||
const createJSONChunk = (jsonData: ArrayBuffer) => {
|
||||
const createJSONChunk = (
|
||||
jsonData: ArrayBuffer,
|
||||
chunkType = ASCII.JSON,
|
||||
chunkLength = jsonData.byteLength
|
||||
) => {
|
||||
const chunkHeader = new ArrayBuffer(GltfSizeBytes.CHUNK_HEADER)
|
||||
const chunkView = new DataView(chunkHeader)
|
||||
chunkView.setUint32(0, jsonData.byteLength, true)
|
||||
chunkView.setUint32(4, ASCII.JSON, true)
|
||||
chunkView.setUint32(0, chunkLength, true)
|
||||
chunkView.setUint32(4, chunkType, true)
|
||||
return chunkHeader
|
||||
}
|
||||
|
||||
@@ -52,13 +58,27 @@ describe('GLTF binary metadata parser', () => {
|
||||
// Builds a GLB whose JSON chunk is the literal text passed in - used to
|
||||
// embed Python generated bare NaN/Infinity tokens that JSON.stringify
|
||||
// would otherwise coerce to null.
|
||||
function createMockGltfFileFromText(jsonText: string): File {
|
||||
interface MockGltfFileOptions {
|
||||
chunkLength?: number
|
||||
chunkType?: number
|
||||
magicNumber?: number
|
||||
}
|
||||
|
||||
function createMockGltfFileFromText(
|
||||
jsonText: string,
|
||||
{
|
||||
chunkLength,
|
||||
chunkType = ASCII.JSON,
|
||||
magicNumber = ASCII.GLTF
|
||||
}: MockGltfFileOptions = {}
|
||||
): File {
|
||||
const jsonData = new TextEncoder().encode(jsonText)
|
||||
const { header, headerView } = createGLTFFileStructure()
|
||||
|
||||
setHeaders(headerView, jsonData.buffer)
|
||||
setTypeHeader(headerView, magicNumber)
|
||||
|
||||
const chunkHeader = createJSONChunk(jsonData.buffer)
|
||||
const chunkHeader = createJSONChunk(jsonData.buffer, chunkType, chunkLength)
|
||||
|
||||
const fileContent = new Uint8Array(
|
||||
header.byteLength + chunkHeader.byteLength + jsonData.byteLength
|
||||
@@ -130,7 +150,9 @@ describe('GLTF binary metadata parser', () => {
|
||||
expect(metadata).toBeDefined()
|
||||
expect(metadata.prompt).toBeDefined()
|
||||
|
||||
const prompt = metadata.prompt as Record<string, any>
|
||||
const prompt: {
|
||||
node1: { class_type: string; inputs: { seed: number } }
|
||||
} = fromAny(metadata.prompt)
|
||||
expect(prompt.node1.class_type).toBe('TestNode')
|
||||
expect(prompt.node1.inputs.seed).toBe(123456)
|
||||
})
|
||||
@@ -179,6 +201,74 @@ describe('GLTF binary metadata parser', () => {
|
||||
expect(metadata).toEqual({})
|
||||
})
|
||||
|
||||
it('returns empty when the GLB magic number is invalid', async () => {
|
||||
const mockFile = createMockGltfFileFromText('{}', { magicNumber: 0 })
|
||||
|
||||
const metadata = await getGltfBinaryMetadata(mockFile)
|
||||
|
||||
expect(metadata).toEqual({})
|
||||
})
|
||||
|
||||
it('returns empty when the GLB is missing the first chunk header', async () => {
|
||||
const { header, headerView } = createGLTFFileStructure()
|
||||
setTypeHeader(headerView, ASCII.GLTF)
|
||||
setVersionHeader(headerView, 2)
|
||||
setTotalLengthHeader(headerView, GltfSizeBytes.HEADER)
|
||||
const mockFile = new File([header], 'header-only.glb')
|
||||
|
||||
const metadata = await getGltfBinaryMetadata(mockFile)
|
||||
|
||||
expect(metadata).toEqual({})
|
||||
})
|
||||
|
||||
it('returns empty when the first chunk is not JSON', async () => {
|
||||
const mockFile = createMockGltfFileFromText('{}', { chunkType: 0 })
|
||||
|
||||
const metadata = await getGltfBinaryMetadata(mockFile)
|
||||
|
||||
expect(metadata).toEqual({})
|
||||
})
|
||||
|
||||
it('returns empty when the declared JSON chunk exceeds the buffer', async () => {
|
||||
const mockFile = createMockGltfFileFromText('{}', { chunkLength: 1024 })
|
||||
|
||||
const metadata = await getGltfBinaryMetadata(mockFile)
|
||||
|
||||
expect(metadata).toEqual({})
|
||||
})
|
||||
|
||||
it('returns empty when the JSON chunk cannot be parsed', async () => {
|
||||
const mockFile = createMockGltfFileFromText('{not json')
|
||||
|
||||
const metadata = await getGltfBinaryMetadata(mockFile)
|
||||
|
||||
expect(metadata).toEqual({})
|
||||
})
|
||||
|
||||
it('returns empty when asset extras are missing', async () => {
|
||||
const mockFile = createMockGltfFile({ asset: { version: '2.0' } })
|
||||
|
||||
const metadata = await getGltfBinaryMetadata(mockFile)
|
||||
|
||||
expect(metadata).toEqual({})
|
||||
})
|
||||
|
||||
it('skips string metadata values that are not valid JSON', async () => {
|
||||
const mockFile = createMockGltfFile({
|
||||
asset: {
|
||||
version: '2.0',
|
||||
extras: {
|
||||
prompt: '{not json',
|
||||
workflow: '{not json'
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
const metadata = await getGltfBinaryMetadata(mockFile)
|
||||
|
||||
expect(metadata).toEqual({})
|
||||
})
|
||||
|
||||
describe('FileReader failure modes', () => {
|
||||
afterEach(() => vi.restoreAllMocks())
|
||||
|
||||
@@ -193,5 +283,15 @@ describe('GLTF binary metadata parser', () => {
|
||||
mockFileReaderAbort('readAsArrayBuffer')
|
||||
expect(await getGltfBinaryMetadata(file)).toEqual({})
|
||||
})
|
||||
|
||||
it('resolves empty when the FileReader load result is missing', async () => {
|
||||
mockFileReaderResult('readAsArrayBuffer', null)
|
||||
expect(await getGltfBinaryMetadata(file)).toEqual({})
|
||||
})
|
||||
|
||||
it('resolves empty when the FileReader result is not a buffer', async () => {
|
||||
mockFileReaderResult('readAsArrayBuffer', 'not a buffer')
|
||||
expect(await getGltfBinaryMetadata(file)).toEqual({})
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -7,7 +7,8 @@ import {
|
||||
EXPECTED_PROMPT_NAN_COERCED,
|
||||
EXPECTED_WORKFLOW,
|
||||
mockFileReaderAbort,
|
||||
mockFileReaderError
|
||||
mockFileReaderError,
|
||||
mockFileReaderResult
|
||||
} from './__fixtures__/helpers'
|
||||
import { getFromIsobmffFile } from './isobmff'
|
||||
|
||||
@@ -16,6 +17,82 @@ const nanFixturePath = path.resolve(
|
||||
__dirname,
|
||||
'__fixtures__/with_nan_metadata.mp4'
|
||||
)
|
||||
const encoder = new TextEncoder()
|
||||
|
||||
function uint32(value: number) {
|
||||
return new Uint8Array([
|
||||
(value >>> 24) & 0xff,
|
||||
(value >>> 16) & 0xff,
|
||||
(value >>> 8) & 0xff,
|
||||
value & 0xff
|
||||
])
|
||||
}
|
||||
|
||||
function concatBytes(...chunks: Uint8Array[]) {
|
||||
const totalLength = chunks.reduce((sum, chunk) => sum + chunk.length, 0)
|
||||
const result = new Uint8Array(totalLength)
|
||||
let offset = 0
|
||||
|
||||
for (const chunk of chunks) {
|
||||
result.set(chunk, offset)
|
||||
offset += chunk.length
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
function box(type: string, payload = new Uint8Array(), size?: number) {
|
||||
return concatBytes(
|
||||
uint32(size ?? 8 + payload.length),
|
||||
encoder.encode(type),
|
||||
payload
|
||||
)
|
||||
}
|
||||
|
||||
function keyEntry(name: string) {
|
||||
const encoded = encoder.encode(name)
|
||||
return concatBytes(
|
||||
uint32(8 + encoded.length),
|
||||
encoder.encode('mdta'),
|
||||
encoded
|
||||
)
|
||||
}
|
||||
|
||||
function keysBox(names: string[]) {
|
||||
return box(
|
||||
'keys',
|
||||
concatBytes(uint32(0), uint32(names.length), ...names.map(keyEntry))
|
||||
)
|
||||
}
|
||||
|
||||
function dataBox(value: string | Uint8Array) {
|
||||
const payload = typeof value === 'string' ? encoder.encode(value) : value
|
||||
return box('data', concatBytes(uint32(0), uint32(0), payload))
|
||||
}
|
||||
|
||||
function ilstItem(index: number, payload: Uint8Array) {
|
||||
return concatBytes(uint32(8 + payload.length), uint32(index), payload)
|
||||
}
|
||||
|
||||
function ilstBox(...items: Uint8Array[]) {
|
||||
return box('ilst', concatBytes(...items))
|
||||
}
|
||||
|
||||
function metaBox(...children: Uint8Array[]) {
|
||||
return box('meta', concatBytes(uint32(0), ...children))
|
||||
}
|
||||
|
||||
function udtaWithMeta(...children: Uint8Array[]) {
|
||||
return box('udta', metaBox(...children))
|
||||
}
|
||||
|
||||
async function readMp4(bytes: Uint8Array) {
|
||||
return getFromIsobmffFile(
|
||||
new File([bytes as Uint8Array<ArrayBuffer>], 'test.mp4', {
|
||||
type: 'video/mp4'
|
||||
})
|
||||
)
|
||||
}
|
||||
|
||||
describe('ISOBMFF (MP4) metadata', () => {
|
||||
it('extracts workflow and prompt from QuickTime keys/ilst boxes', async () => {
|
||||
@@ -48,6 +125,102 @@ describe('ISOBMFF (MP4) metadata', () => {
|
||||
expect(result).toEqual({})
|
||||
})
|
||||
|
||||
it('extracts metadata from udta nested inside moov', async () => {
|
||||
const bytes = box(
|
||||
'moov',
|
||||
udtaWithMeta(
|
||||
keysBox(['WORKFLOW']),
|
||||
ilstBox(ilstItem(1, dataBox('xxxx{"nodes":[]}')))
|
||||
)
|
||||
)
|
||||
|
||||
const result = await readMp4(bytes)
|
||||
|
||||
expect(result.workflow).toEqual({ nodes: [] })
|
||||
})
|
||||
|
||||
it('returns empty when a top-level box declares an impossible size', async () => {
|
||||
const result = await readMp4(box('free', new Uint8Array([1, 2]), 100))
|
||||
|
||||
expect(result).toEqual({})
|
||||
})
|
||||
|
||||
it('returns empty when the keys box cannot provide entries', async () => {
|
||||
const tooShortKeys = box('keys', uint32(0))
|
||||
const missingKeyEntry = box(
|
||||
'keys',
|
||||
concatBytes(uint32(0), uint32(1), uint32(8))
|
||||
)
|
||||
const malformedKey = box(
|
||||
'keys',
|
||||
concatBytes(uint32(0), uint32(1), uint32(7), encoder.encode('bad!'))
|
||||
)
|
||||
const oversizedKey = box(
|
||||
'keys',
|
||||
concatBytes(uint32(0), uint32(1), uint32(100), encoder.encode('bad!'))
|
||||
)
|
||||
|
||||
await expect(readMp4(udtaWithMeta(tooShortKeys))).resolves.toEqual({})
|
||||
await expect(readMp4(udtaWithMeta(missingKeyEntry))).resolves.toEqual({})
|
||||
await expect(readMp4(udtaWithMeta(malformedKey))).resolves.toEqual({})
|
||||
await expect(readMp4(udtaWithMeta(oversizedKey))).resolves.toEqual({})
|
||||
})
|
||||
|
||||
it('ignores item entries whose key is unknown or unsupported', async () => {
|
||||
const unknownIndex = udtaWithMeta(
|
||||
keysBox(['PROMPT']),
|
||||
ilstBox(ilstItem(2, dataBox('{"1":{}}')))
|
||||
)
|
||||
const unsupportedKey = udtaWithMeta(
|
||||
keysBox(['DESCRIPTION']),
|
||||
ilstBox(ilstItem(1, dataBox('{"ignored":true}')))
|
||||
)
|
||||
|
||||
await expect(readMp4(unknownIndex)).resolves.toEqual({})
|
||||
await expect(readMp4(unsupportedKey)).resolves.toEqual({})
|
||||
})
|
||||
|
||||
it('ignores metadata items without readable JSON data', async () => {
|
||||
const shortDataBox = box('data')
|
||||
const noJson = dataBox('not-json')
|
||||
const invalidJson = dataBox('prefix {not-json')
|
||||
const noDataBox = box('free', new Uint8Array([1]))
|
||||
const invalidItems = ilstBox(
|
||||
concatBytes(uint32(8), uint32(1)),
|
||||
concatBytes(uint32(100), uint32(1), invalidJson)
|
||||
)
|
||||
|
||||
await expect(
|
||||
readMp4(
|
||||
udtaWithMeta(keysBox(['PROMPT']), ilstBox(ilstItem(1, shortDataBox)))
|
||||
)
|
||||
).resolves.toEqual({})
|
||||
await expect(
|
||||
readMp4(udtaWithMeta(keysBox(['PROMPT']), ilstBox(ilstItem(1, noJson))))
|
||||
).resolves.toEqual({})
|
||||
await expect(
|
||||
readMp4(
|
||||
udtaWithMeta(keysBox(['PROMPT']), ilstBox(ilstItem(1, invalidJson)))
|
||||
)
|
||||
).resolves.toEqual({})
|
||||
await expect(
|
||||
readMp4(
|
||||
udtaWithMeta(keysBox(['PROMPT']), ilstBox(ilstItem(1, noDataBox)))
|
||||
)
|
||||
).resolves.toEqual({})
|
||||
await expect(
|
||||
readMp4(udtaWithMeta(keysBox(['PROMPT']), invalidItems))
|
||||
).resolves.toEqual({})
|
||||
})
|
||||
|
||||
it('returns empty when required metadata boxes are absent', async () => {
|
||||
await expect(readMp4(box('udta', box('free')))).resolves.toEqual({})
|
||||
await expect(readMp4(udtaWithMeta(ilstBox()))).resolves.toEqual({})
|
||||
await expect(readMp4(udtaWithMeta(keysBox(['PROMPT'])))).resolves.toEqual(
|
||||
{}
|
||||
)
|
||||
})
|
||||
|
||||
describe('FileReader failure modes', () => {
|
||||
afterEach(() => vi.restoreAllMocks())
|
||||
|
||||
@@ -63,5 +236,10 @@ describe('ISOBMFF (MP4) metadata', () => {
|
||||
mockFileReaderAbort('readAsArrayBuffer')
|
||||
expect(await getFromIsobmffFile(file)).toEqual({})
|
||||
})
|
||||
|
||||
it('resolves empty when the FileReader load has no result', async () => {
|
||||
mockFileReaderResult('readAsArrayBuffer', null)
|
||||
expect(await getFromIsobmffFile(file)).toEqual({})
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
127
src/scripts/metadata/parser.test.ts
Normal file
127
src/scripts/metadata/parser.test.ts
Normal file
@@ -0,0 +1,127 @@
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { getFromWebmFile } from '@/scripts/metadata/ebml'
|
||||
import { getGltfBinaryMetadata } from '@/scripts/metadata/gltf'
|
||||
import { getFromIsobmffFile } from '@/scripts/metadata/isobmff'
|
||||
import { getDataFromJSON } from '@/scripts/metadata/json'
|
||||
import { getMp3Metadata } from '@/scripts/metadata/mp3'
|
||||
import { getOggMetadata } from '@/scripts/metadata/ogg'
|
||||
import { getWorkflowDataFromFile } from '@/scripts/metadata/parser'
|
||||
import { getSvgMetadata } from '@/scripts/metadata/svg'
|
||||
import {
|
||||
getAvifMetadata,
|
||||
getFlacMetadata,
|
||||
getLatentMetadata,
|
||||
getPngMetadata,
|
||||
getWebpMetadata
|
||||
} from '@/scripts/pnginfo'
|
||||
|
||||
vi.mock('@/scripts/metadata/ebml', () => ({ getFromWebmFile: vi.fn() }))
|
||||
vi.mock('@/scripts/metadata/gltf', () => ({ getGltfBinaryMetadata: vi.fn() }))
|
||||
vi.mock('@/scripts/metadata/isobmff', () => ({ getFromIsobmffFile: vi.fn() }))
|
||||
vi.mock('@/scripts/metadata/json', () => ({ getDataFromJSON: vi.fn() }))
|
||||
vi.mock('@/scripts/metadata/mp3', () => ({ getMp3Metadata: vi.fn() }))
|
||||
vi.mock('@/scripts/metadata/ogg', () => ({ getOggMetadata: vi.fn() }))
|
||||
vi.mock('@/scripts/metadata/svg', () => ({ getSvgMetadata: vi.fn() }))
|
||||
vi.mock('@/scripts/pnginfo', () => ({
|
||||
getAvifMetadata: vi.fn(),
|
||||
getFlacMetadata: vi.fn(),
|
||||
getLatentMetadata: vi.fn(),
|
||||
getPngMetadata: vi.fn(),
|
||||
getWebpMetadata: vi.fn()
|
||||
}))
|
||||
|
||||
function file(type: string, name = 'file') {
|
||||
return new File(['data'], name, { type })
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
describe('getWorkflowDataFromFile', () => {
|
||||
it('routes png/avif/mp3/ogg/webm to their parsers and returns the result', async () => {
|
||||
vi.mocked(getPngMetadata).mockResolvedValue({ a: 1 } as never)
|
||||
expect(await getWorkflowDataFromFile(file('image/png'))).toEqual({ a: 1 })
|
||||
expect(getPngMetadata).toHaveBeenCalled()
|
||||
|
||||
await getWorkflowDataFromFile(file('image/avif'))
|
||||
expect(getAvifMetadata).toHaveBeenCalled()
|
||||
|
||||
await getWorkflowDataFromFile(file('audio/mpeg'))
|
||||
expect(getMp3Metadata).toHaveBeenCalled()
|
||||
|
||||
await getWorkflowDataFromFile(file('audio/ogg'))
|
||||
expect(getOggMetadata).toHaveBeenCalled()
|
||||
|
||||
await getWorkflowDataFromFile(file('video/webm'))
|
||||
expect(getFromWebmFile).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('extracts workflow/prompt from webp, preferring lowercase keys', async () => {
|
||||
vi.mocked(getWebpMetadata).mockResolvedValue({
|
||||
workflow: 'wf',
|
||||
prompt: 'pr'
|
||||
} as never)
|
||||
expect(await getWorkflowDataFromFile(file('image/webp'))).toEqual({
|
||||
workflow: 'wf',
|
||||
prompt: 'pr'
|
||||
})
|
||||
})
|
||||
|
||||
it('falls back to capitalized webp keys when lowercase are absent', async () => {
|
||||
vi.mocked(getWebpMetadata).mockResolvedValue({
|
||||
Workflow: 'WF',
|
||||
Prompt: 'PR'
|
||||
} as never)
|
||||
expect(await getWorkflowDataFromFile(file('image/webp'))).toEqual({
|
||||
workflow: 'WF',
|
||||
prompt: 'PR'
|
||||
})
|
||||
})
|
||||
|
||||
it('handles both flac mime types and extracts workflow/prompt', async () => {
|
||||
vi.mocked(getFlacMetadata).mockResolvedValue({ workflow: 'w' } as never)
|
||||
expect(await getWorkflowDataFromFile(file('audio/flac'))).toEqual({
|
||||
workflow: 'w',
|
||||
prompt: undefined
|
||||
})
|
||||
expect(await getWorkflowDataFromFile(file('audio/x-flac'))).toEqual({
|
||||
workflow: 'w',
|
||||
prompt: undefined
|
||||
})
|
||||
})
|
||||
|
||||
it('routes isobmff by mime type and by file extension', async () => {
|
||||
await getWorkflowDataFromFile(file('video/mp4'))
|
||||
await getWorkflowDataFromFile(file('', 'clip.mov'))
|
||||
await getWorkflowDataFromFile(file('', 'clip.m4v'))
|
||||
expect(getFromIsobmffFile).toHaveBeenCalledTimes(3)
|
||||
})
|
||||
|
||||
it('routes svg and gltf by mime type or extension', async () => {
|
||||
await getWorkflowDataFromFile(file('image/svg+xml'))
|
||||
await getWorkflowDataFromFile(file('', 'icon.svg'))
|
||||
expect(getSvgMetadata).toHaveBeenCalledTimes(2)
|
||||
|
||||
await getWorkflowDataFromFile(file('model/gltf-binary'))
|
||||
await getWorkflowDataFromFile(file('', 'model.glb'))
|
||||
expect(getGltfBinaryMetadata).toHaveBeenCalledTimes(2)
|
||||
})
|
||||
|
||||
it('routes latent/safetensors and json by extension or mime type', async () => {
|
||||
await getWorkflowDataFromFile(file('', 'x.latent'))
|
||||
await getWorkflowDataFromFile(file('', 'x.safetensors'))
|
||||
expect(getLatentMetadata).toHaveBeenCalledTimes(2)
|
||||
|
||||
await getWorkflowDataFromFile(file('application/json'))
|
||||
await getWorkflowDataFromFile(file('', 'x.json'))
|
||||
expect(getDataFromJSON).toHaveBeenCalledTimes(2)
|
||||
})
|
||||
|
||||
it('returns undefined for an unrecognized file', async () => {
|
||||
expect(
|
||||
await getWorkflowDataFromFile(file('application/zip', 'a.zip'))
|
||||
).toBe(undefined)
|
||||
})
|
||||
})
|
||||
@@ -2,7 +2,8 @@ import { afterEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import {
|
||||
mockFileReaderAbort,
|
||||
mockFileReaderError
|
||||
mockFileReaderError,
|
||||
mockFileReaderResult
|
||||
} from './__fixtures__/helpers'
|
||||
import { getFromPngBuffer, getFromPngFile } from './png'
|
||||
|
||||
@@ -191,6 +192,27 @@ describe('getFromPngBuffer', () => {
|
||||
const result = await getFromPngBuffer(buffer)
|
||||
expect(result['workflow']).toBe(workflow)
|
||||
})
|
||||
|
||||
it('logs error and skips compressed iTXt with invalid deflate data', async () => {
|
||||
vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
const buffer = createPngWithChunk(
|
||||
'iTXt',
|
||||
'workflow',
|
||||
new Uint8Array([1, 2, 3]),
|
||||
{
|
||||
compressionFlag: 1,
|
||||
compressionMethod: 0
|
||||
}
|
||||
)
|
||||
|
||||
const result = await getFromPngBuffer(buffer)
|
||||
|
||||
expect(result['workflow']).toBeUndefined()
|
||||
expect(console.error).toHaveBeenCalledWith(
|
||||
expect.stringContaining('Failed to decompress iTXt chunk "workflow"'),
|
||||
expect.anything()
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('getFromPngFile', () => {
|
||||
@@ -228,5 +250,12 @@ describe('getFromPngFile', () => {
|
||||
mockFileReaderAbort('readAsArrayBuffer')
|
||||
await expect(getFromPngFile(file)).rejects.toThrow('FileReader aborted')
|
||||
})
|
||||
|
||||
it('rejects when the FileReader load has no ArrayBuffer result', async () => {
|
||||
mockFileReaderResult('readAsArrayBuffer', null)
|
||||
await expect(getFromPngFile(file)).rejects.toThrow(
|
||||
'Failed to read file as ArrayBuffer'
|
||||
)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,13 +1,19 @@
|
||||
import fs from 'fs'
|
||||
import path from 'path'
|
||||
import { afterEach, describe, expect, it, vi } from 'vitest'
|
||||
import { fromPartial } from '@total-typescript/shoehorn'
|
||||
|
||||
import type { LGraph, LGraphNode } from '@/lib/litegraph/src/litegraph'
|
||||
import { LiteGraph } from '@/lib/litegraph/src/litegraph'
|
||||
import type { IBaseWidget } from '@/lib/litegraph/src/types/widgets'
|
||||
import { api } from '@/scripts/api'
|
||||
import { getFromAvifFile } from './metadata/avif'
|
||||
import { getFromFlacFile } from './metadata/flac'
|
||||
import { getFromPngFile } from './metadata/png'
|
||||
import {
|
||||
getAvifMetadata,
|
||||
getFlacMetadata,
|
||||
importA1111,
|
||||
getLatentMetadata,
|
||||
getPngMetadata,
|
||||
getWebpMetadata
|
||||
@@ -56,6 +62,20 @@ function encodeAsciiIfd(entries: AsciiIfdEntry[]): Uint8Array {
|
||||
return buf
|
||||
}
|
||||
|
||||
function encodeNonAsciiIfdEntry(tag: number): Uint8Array {
|
||||
const buf = new Uint8Array(22)
|
||||
const dv = new DataView(buf.buffer)
|
||||
buf.set([0x49, 0x49], 0)
|
||||
dv.setUint16(2, 0x002a, true)
|
||||
dv.setUint32(4, 8, true)
|
||||
dv.setUint16(8, 1, true)
|
||||
dv.setUint16(10, tag, true)
|
||||
dv.setUint16(12, 3, true)
|
||||
dv.setUint32(14, 1, true)
|
||||
dv.setUint32(18, 123, true)
|
||||
return buf
|
||||
}
|
||||
|
||||
type WebpChunk = { type: string; payload: Uint8Array }
|
||||
|
||||
function wrapInWebp(chunks: WebpChunk[]): File {
|
||||
@@ -157,6 +177,16 @@ describe('getWebpMetadata', () => {
|
||||
|
||||
expect(metadata).toEqual({ workflow: '{"a":1}' })
|
||||
})
|
||||
|
||||
it('ignores EXIF entries that are not ASCII strings', async () => {
|
||||
const file = wrapInWebp([
|
||||
{ type: 'EXIF', payload: encodeNonAsciiIfdEntry(270) }
|
||||
])
|
||||
|
||||
const metadata = await getWebpMetadata(file)
|
||||
|
||||
expect(metadata).toEqual({})
|
||||
})
|
||||
})
|
||||
|
||||
describe('getLatentMetadata', () => {
|
||||
@@ -234,3 +264,313 @@ describe('format-specific metadata wrappers', () => {
|
||||
expect(result).toEqual({ workflow: '{"avif":1}' })
|
||||
})
|
||||
})
|
||||
|
||||
describe('importA1111', () => {
|
||||
function widget(
|
||||
name: string,
|
||||
options: string[] = []
|
||||
): IBaseWidget & { value?: string | number } {
|
||||
return fromPartial<IBaseWidget & { value?: string | number }>({
|
||||
name,
|
||||
options: { values: options },
|
||||
value: undefined
|
||||
})
|
||||
}
|
||||
|
||||
function createNode(type: string): LGraphNode {
|
||||
const widgetsByType: Record<string, IBaseWidget[]> = {
|
||||
CheckpointLoaderSimple: [widget('ckpt_name', ['sd15.safetensors'])],
|
||||
CLIPSetLastLayer: [widget('stop_at_clip_layer')],
|
||||
CLIPTextEncode: [widget('text')],
|
||||
EmptyLatentImage: [widget('width'), widget('height')],
|
||||
ImageScale: [widget('width'), widget('height')],
|
||||
ImageUpscaleWithModel: [],
|
||||
KSampler: [
|
||||
widget('cfg'),
|
||||
widget('sampler_name', ['euler_a', 'dpmpp_2m']),
|
||||
widget('scheduler', ['normal', 'karras']),
|
||||
widget('seed'),
|
||||
widget('steps'),
|
||||
widget('denoise')
|
||||
],
|
||||
LatentUpscale: [
|
||||
widget('upscale_method', ['nearest-exact']),
|
||||
widget('width'),
|
||||
widget('height')
|
||||
],
|
||||
LoraLoader: [
|
||||
widget('lora_name', ['foo.safetensors']),
|
||||
widget('strength_model'),
|
||||
widget('strength_clip')
|
||||
],
|
||||
UpscaleModelLoader: [widget('model_name', ['ESRGAN'])],
|
||||
VAEEncodeTiled: [],
|
||||
VAEDecodeTiled: [],
|
||||
SaveImage: [],
|
||||
VAEDecode: []
|
||||
}
|
||||
return {
|
||||
type,
|
||||
widgets: widgetsByType[type] ?? [],
|
||||
connect: vi.fn()
|
||||
} as unknown as LGraphNode
|
||||
}
|
||||
|
||||
function createGraph(): LGraph {
|
||||
return fromPartial<LGraph>({
|
||||
add: vi.fn(),
|
||||
arrange: vi.fn(),
|
||||
clear: vi.fn()
|
||||
})
|
||||
}
|
||||
|
||||
function findWidget(node: LGraphNode, name: string) {
|
||||
return node.widgets?.find((widget) => widget.name === name)
|
||||
}
|
||||
|
||||
it('ignores text without parsed generation settings', async () => {
|
||||
const graph = createGraph()
|
||||
vi.spyOn(api, 'getEmbeddings').mockResolvedValue([])
|
||||
|
||||
await importA1111(graph, 'positive prompt only')
|
||||
await importA1111(graph, 'positive prompt\nSteps:\n')
|
||||
|
||||
expect(graph.clear).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('ignores text without a negative prompt section', async () => {
|
||||
const graph = createGraph()
|
||||
vi.spyOn(api, 'getEmbeddings').mockResolvedValue([])
|
||||
|
||||
await importA1111(
|
||||
graph,
|
||||
['positive prompt only', 'Steps: 20, Sampler: Euler a'].join('\n')
|
||||
)
|
||||
|
||||
expect(graph.clear).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('stops when a required base node cannot be created', async () => {
|
||||
const graph = createGraph()
|
||||
vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
vi.spyOn(api, 'getEmbeddings').mockResolvedValue([])
|
||||
vi.spyOn(LiteGraph, 'createNode').mockImplementation((type) =>
|
||||
type === 'KSampler' ? null : createNode(type)
|
||||
)
|
||||
|
||||
await importA1111(
|
||||
graph,
|
||||
[
|
||||
'prompt',
|
||||
'Negative prompt: blurry',
|
||||
'Steps: 20, Sampler: Euler a, Size: 512x512'
|
||||
].join('\n')
|
||||
)
|
||||
|
||||
expect(graph.clear).not.toHaveBeenCalled()
|
||||
expect(console.error).toHaveBeenCalledWith(
|
||||
'Failed to create required nodes for A1111 import'
|
||||
)
|
||||
})
|
||||
|
||||
it('builds a basic graph from A1111 parameters', async () => {
|
||||
const graph = createGraph()
|
||||
const nodes: LGraphNode[] = []
|
||||
vi.spyOn(console, 'warn').mockImplementation(() => {})
|
||||
vi.spyOn(api, 'getEmbeddings').mockResolvedValue(['easynegative'])
|
||||
vi.spyOn(LiteGraph, 'createNode').mockImplementation((type) => {
|
||||
const node = createNode(type)
|
||||
nodes.push(node)
|
||||
return node
|
||||
})
|
||||
|
||||
await importA1111(
|
||||
graph,
|
||||
[
|
||||
'<lora:foo:0.7> portrait easynegative',
|
||||
'Negative prompt: blurry <lora:bad:not-number>',
|
||||
'Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 42, Size: 512x512, Model: sd15, Clip skip: 2, Model hash: ignored'
|
||||
].join('\n')
|
||||
)
|
||||
|
||||
const checkpoint = nodes.find(
|
||||
(node) => node.type === 'CheckpointLoaderSimple'
|
||||
)
|
||||
const clipSkip = nodes.find((node) => node.type === 'CLIPSetLastLayer')
|
||||
const sampler = nodes.find((node) => node.type === 'KSampler')
|
||||
const image = nodes.find((node) => node.type === 'EmptyLatentImage')
|
||||
const lora = nodes.find((node) => node.type === 'LoraLoader')
|
||||
const textNodes = nodes.filter((node) => node.type === 'CLIPTextEncode')
|
||||
|
||||
expect(graph.clear).toHaveBeenCalledOnce()
|
||||
expect(graph.arrange).toHaveBeenCalledOnce()
|
||||
expect(findWidget(checkpoint!, 'ckpt_name')?.value).toBe('sd15.safetensors')
|
||||
expect(findWidget(clipSkip!, 'stop_at_clip_layer')?.value).toBe(-2)
|
||||
expect(findWidget(sampler!, 'cfg')?.value).toBe(7)
|
||||
expect(findWidget(sampler!, 'sampler_name')?.value).toBe('euler_a')
|
||||
expect(findWidget(sampler!, 'scheduler')?.value).toBe('normal')
|
||||
expect(findWidget(sampler!, 'seed')?.value).toBe(42)
|
||||
expect(findWidget(sampler!, 'steps')?.value).toBe(20)
|
||||
expect(findWidget(image!, 'width')?.value).toBe(512)
|
||||
expect(findWidget(image!, 'height')?.value).toBe(512)
|
||||
expect(findWidget(lora!, 'lora_name')?.value).toBe('foo.safetensors')
|
||||
expect(findWidget(lora!, 'strength_model')?.value).toBe(0.7)
|
||||
expect(findWidget(lora!, 'strength_clip')?.value).toBe(0.7)
|
||||
expect(findWidget(textNodes[0], 'text')?.value).toBe(
|
||||
' portrait embedding:easynegative'
|
||||
)
|
||||
expect(findWidget(textNodes[1], 'text')?.value).toBe('blurry ')
|
||||
})
|
||||
|
||||
it('keeps unknown option-prefix values and logs the mismatch', async () => {
|
||||
const graph = createGraph()
|
||||
const nodes: LGraphNode[] = []
|
||||
vi.spyOn(console, 'warn').mockImplementation(() => {})
|
||||
vi.spyOn(api, 'getEmbeddings').mockResolvedValue([])
|
||||
vi.spyOn(LiteGraph, 'createNode').mockImplementation((type) => {
|
||||
const node = createNode(type)
|
||||
nodes.push(node)
|
||||
return node
|
||||
})
|
||||
|
||||
await importA1111(
|
||||
graph,
|
||||
[
|
||||
'portrait',
|
||||
'Negative prompt: blurry',
|
||||
'Steps: 20, Sampler: Unknown Sampler, CFG scale: 7, Seed: 42, Size: 512x512, Model: unknown-model'
|
||||
].join('\n')
|
||||
)
|
||||
|
||||
const checkpoint = nodes.find(
|
||||
(node) => node.type === 'CheckpointLoaderSimple'
|
||||
)
|
||||
expect(findWidget(checkpoint!, 'ckpt_name')?.value).toBe('unknown-model')
|
||||
expect(console.warn).toHaveBeenCalledWith(
|
||||
"Unknown value 'unknown-model' for widget 'ckpt_name'",
|
||||
checkpoint
|
||||
)
|
||||
})
|
||||
|
||||
it('skips missing LoraLoader nodes while keeping prompt text cleaned', async () => {
|
||||
const graph = createGraph()
|
||||
const nodes: LGraphNode[] = []
|
||||
vi.spyOn(console, 'warn').mockImplementation(() => {})
|
||||
vi.spyOn(api, 'getEmbeddings').mockResolvedValue([])
|
||||
vi.spyOn(LiteGraph, 'createNode').mockImplementation((type) => {
|
||||
if (type === 'LoraLoader') return null
|
||||
const node = createNode(type)
|
||||
nodes.push(node)
|
||||
return node
|
||||
})
|
||||
|
||||
await importA1111(
|
||||
graph,
|
||||
[
|
||||
'<lora:missing:0.5> portrait',
|
||||
'Negative prompt: blurry',
|
||||
'Steps: 20, Sampler: Euler a, Size: 512x512'
|
||||
].join('\n')
|
||||
)
|
||||
|
||||
const textNodes = nodes.filter((node) => node.type === 'CLIPTextEncode')
|
||||
expect(findWidget(textNodes[0], 'text')?.value).toBe(' portrait')
|
||||
})
|
||||
|
||||
it('returns from latent hires setup when LatentUpscale cannot be created', async () => {
|
||||
const graph = createGraph()
|
||||
const nodes: LGraphNode[] = []
|
||||
vi.spyOn(api, 'getEmbeddings').mockResolvedValue([])
|
||||
vi.spyOn(LiteGraph, 'createNode').mockImplementation((type) => {
|
||||
if (type === 'LatentUpscale') return null
|
||||
const node = createNode(type)
|
||||
nodes.push(node)
|
||||
return node
|
||||
})
|
||||
|
||||
await importA1111(
|
||||
graph,
|
||||
[
|
||||
'portrait',
|
||||
'Negative prompt: blurry',
|
||||
'Steps: 8, Sampler: Euler a, Size: 512x512, Hires upscale: 2, Hires upscaler: Latent'
|
||||
].join('\n')
|
||||
)
|
||||
|
||||
expect(nodes.some((node) => node.type === 'KSampler')).toBe(true)
|
||||
expect(nodes.some((node) => node.type === 'LatentUpscale')).toBe(false)
|
||||
})
|
||||
|
||||
it('builds a latent hires pass with explicit resize and denoise settings', async () => {
|
||||
const graph = createGraph()
|
||||
const nodes: LGraphNode[] = []
|
||||
vi.spyOn(console, 'warn').mockImplementation(() => {})
|
||||
vi.spyOn(api, 'getEmbeddings').mockResolvedValue([])
|
||||
vi.spyOn(LiteGraph, 'createNode').mockImplementation((type) => {
|
||||
const node = createNode(type)
|
||||
nodes.push(node)
|
||||
return node
|
||||
})
|
||||
|
||||
await importA1111(
|
||||
graph,
|
||||
[
|
||||
'portrait',
|
||||
'Negative prompt: blurry',
|
||||
'Steps: 12, Sampler: DPM++ 2M Karras, CFG scale: 5, Seed: 1, Size: 513x577, Model: sd15, Hires resize: 1025x1089, Hires steps: 4, Hires upscaler: Latent (nearest-exact), Denoising strength: 0.35'
|
||||
].join('\n')
|
||||
)
|
||||
|
||||
const image = nodes.find((node) => node.type === 'EmptyLatentImage')
|
||||
const latentUpscale = nodes.find((node) => node.type === 'LatentUpscale')
|
||||
const samplers = nodes.filter((node) => node.type === 'KSampler')
|
||||
|
||||
expect(findWidget(image!, 'width')?.value).toBe(576)
|
||||
expect(findWidget(image!, 'height')?.value).toBe(640)
|
||||
expect(findWidget(latentUpscale!, 'upscale_method')?.value).toBe(
|
||||
'nearest-exact'
|
||||
)
|
||||
expect(findWidget(latentUpscale!, 'width')?.value).toBe(1088)
|
||||
expect(findWidget(latentUpscale!, 'height')?.value).toBe(1152)
|
||||
expect(findWidget(samplers[0], 'scheduler')?.value).toBe('karras')
|
||||
expect(findWidget(samplers[0], 'sampler_name')?.value).toBe('dpmpp_2m')
|
||||
expect(findWidget(samplers[1], 'steps')?.value).toBe(4)
|
||||
expect(findWidget(samplers[1], 'cfg')?.value).toBe(5)
|
||||
expect(findWidget(samplers[1], 'scheduler')?.value).toBe('karras')
|
||||
expect(findWidget(samplers[1], 'sampler_name')?.value).toBe('dpmpp_2m')
|
||||
expect(findWidget(samplers[1], 'denoise')?.value).toBe(0.35)
|
||||
})
|
||||
|
||||
it('builds an image upscaler hires pass with fallback steps and denoise', async () => {
|
||||
const graph = createGraph()
|
||||
const nodes: LGraphNode[] = []
|
||||
vi.spyOn(console, 'warn').mockImplementation(() => {})
|
||||
vi.spyOn(api, 'getEmbeddings').mockResolvedValue([])
|
||||
vi.spyOn(LiteGraph, 'createNode').mockImplementation((type) => {
|
||||
const node = createNode(type)
|
||||
nodes.push(node)
|
||||
return node
|
||||
})
|
||||
|
||||
await importA1111(
|
||||
graph,
|
||||
[
|
||||
'portrait',
|
||||
'Negative prompt: blurry',
|
||||
'Steps: 8, Sampler: Euler a, CFG scale: 6, Seed: 2, Size: 512x512, Model: sd15, Hires upscale: 1.5, Hires upscaler: ESRGAN'
|
||||
].join('\n')
|
||||
)
|
||||
|
||||
const upscaleLoader = nodes.find(
|
||||
(node) => node.type === 'UpscaleModelLoader'
|
||||
)
|
||||
const imageScale = nodes.find((node) => node.type === 'ImageScale')
|
||||
const samplers = nodes.filter((node) => node.type === 'KSampler')
|
||||
|
||||
expect(findWidget(upscaleLoader!, 'model_name')?.value).toBe('ESRGAN')
|
||||
expect(findWidget(imageScale!, 'width')?.value).toBe(768)
|
||||
expect(findWidget(imageScale!, 'height')?.value).toBe(768)
|
||||
expect(findWidget(samplers[1], 'steps')?.value).toBe(8)
|
||||
expect(findWidget(samplers[1], 'denoise')?.value).toBe(1)
|
||||
})
|
||||
})
|
||||
|
||||
450
src/scripts/ui.test.ts
Normal file
450
src/scripts/ui.test.ts
Normal file
@@ -0,0 +1,450 @@
|
||||
import { fromAny, fromPartial } from '@total-typescript/shoehorn'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
const { mockSettingStore } = vi.hoisted(() => ({
|
||||
mockSettingStore: {
|
||||
get: vi.fn()
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('@/composables/useRunButtonTelemetry', () => ({
|
||||
useRunButtonTelemetry: () => ({ trackRunButton: vi.fn() })
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/remote/comfyui/jobs/fetchJobs', () => ({
|
||||
extractWorkflow: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/settings/composables/useSettingsDialog', () => ({
|
||||
useSettingsDialog: () => ({ show: vi.fn() })
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/settings/settingStore', () => ({
|
||||
useSettingStore: () => mockSettingStore
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/telemetry', () => ({
|
||||
useTelemetry: () => ({ trackWorkflowExecution: vi.fn() })
|
||||
}))
|
||||
|
||||
vi.mock('@/services/litegraphService', () => ({
|
||||
useLitegraphService: () => ({ resetView: vi.fn() })
|
||||
}))
|
||||
|
||||
vi.mock('@/stores/commandStore', () => ({
|
||||
useCommandStore: () => ({ execute: vi.fn() })
|
||||
}))
|
||||
|
||||
vi.mock('@/stores/workspaceStore', () => ({
|
||||
useWorkspaceStore: () => ({ focusMode: false })
|
||||
}))
|
||||
|
||||
vi.mock('./api', () => ({
|
||||
api: {
|
||||
addEventListener: vi.fn(),
|
||||
clearItems: vi.fn(),
|
||||
deleteItem: vi.fn(),
|
||||
dispatchCustomEvent: vi.fn(),
|
||||
getHistory: vi.fn(),
|
||||
getJobDetail: vi.fn(),
|
||||
getQueue: vi.fn(),
|
||||
interrupt: vi.fn()
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('./app', () => ({
|
||||
app: {
|
||||
clean: vi.fn(),
|
||||
handleFile: vi.fn(),
|
||||
loadGraphData: vi.fn(),
|
||||
openClipspace: vi.fn(),
|
||||
queuePrompt: vi.fn(),
|
||||
refreshComboInNodes: vi.fn(() => Promise.resolve())
|
||||
},
|
||||
ComfyApp: class ComfyApp {}
|
||||
}))
|
||||
|
||||
vi.mock('./ui/dialog', () => ({
|
||||
ComfyDialog: class ComfyDialog {}
|
||||
}))
|
||||
|
||||
vi.mock('./ui/settings', () => ({
|
||||
ComfySettingsDialog: class ComfySettingsDialog {}
|
||||
}))
|
||||
|
||||
vi.mock('./ui/toggleSwitch', () => ({
|
||||
toggleSwitch: () => document.createElement('div')
|
||||
}))
|
||||
|
||||
import { extractWorkflow } from '@/platform/remote/comfyui/jobs/fetchJobs'
|
||||
|
||||
import { api } from './api'
|
||||
import { app } from './app'
|
||||
import { $el, ComfyUI } from './ui'
|
||||
|
||||
beforeEach(() => {
|
||||
document.body.replaceChildren()
|
||||
localStorage.clear()
|
||||
vi.clearAllMocks()
|
||||
mockSettingStore.get.mockReturnValue(undefined)
|
||||
Object.assign(app, {
|
||||
lastExecutionError: undefined,
|
||||
nodeOutputs: undefined
|
||||
})
|
||||
})
|
||||
|
||||
async function click(button: HTMLButtonElement) {
|
||||
const handler = button.onclick
|
||||
expect(handler).toBeTypeOf('function')
|
||||
await Promise.resolve(handler?.call(button, fromAny(new MouseEvent('click'))))
|
||||
}
|
||||
|
||||
function buttonByText(root: ParentNode, text: string): HTMLButtonElement {
|
||||
const button = [...root.querySelectorAll('button')].find(
|
||||
(candidate) => candidate.textContent === text
|
||||
)
|
||||
if (!button) throw new Error(`Missing button: ${text}`)
|
||||
return button
|
||||
}
|
||||
|
||||
describe('$el', () => {
|
||||
it('creates elements with classes, children, props, and callbacks', () => {
|
||||
const parent = document.createElement('section')
|
||||
const child = document.createElement('span')
|
||||
const callback = vi.fn()
|
||||
|
||||
const element = $el(
|
||||
'label.primary.secondary',
|
||||
{
|
||||
parent,
|
||||
$: callback,
|
||||
dataset: { role: 'name' },
|
||||
for: 'target-input',
|
||||
style: { display: 'block' },
|
||||
title: 'Label'
|
||||
},
|
||||
child
|
||||
)
|
||||
|
||||
expect(element.tagName).toBe('LABEL')
|
||||
expect(element.classList.contains('primary')).toBe(true)
|
||||
expect(element.classList.contains('secondary')).toBe(true)
|
||||
expect(element.dataset.role).toBe('name')
|
||||
expect(element.getAttribute('for')).toBe('target-input')
|
||||
expect(element.style.display).toBe('block')
|
||||
expect(element.title).toBe('Label')
|
||||
expect(element.firstElementChild).toBe(child)
|
||||
expect(parent.firstElementChild).toBe(element)
|
||||
expect(callback).toHaveBeenCalledWith(element)
|
||||
})
|
||||
|
||||
it('accepts string and single-element children shorthands', () => {
|
||||
const textElement = $el('button', 'Run')
|
||||
const child = document.createElement('strong')
|
||||
const wrapper = $el('div', child)
|
||||
|
||||
expect(textElement.textContent).toBe('Run')
|
||||
expect(wrapper.firstElementChild).toBe(child)
|
||||
})
|
||||
})
|
||||
|
||||
describe('ComfyUI legacy menu', () => {
|
||||
it('loads queue items and runs list actions', async () => {
|
||||
vi.mocked(api.getQueue).mockResolvedValue({
|
||||
Running: [{ id: 'running', priority: 1 }],
|
||||
Pending: [{ id: 'pending', priority: 2 }]
|
||||
} as never)
|
||||
vi.mocked(api.getJobDetail).mockResolvedValue({
|
||||
outputs: { node: { images: ['image.png'] } }
|
||||
} as never)
|
||||
vi.mocked(extractWorkflow).mockResolvedValue({ nodes: [] } as never)
|
||||
const ui = new ComfyUI(app)
|
||||
|
||||
await ui.queue.show()
|
||||
await click(buttonByText(ui.queue.element, 'Load'))
|
||||
|
||||
expect(api.getJobDetail).toHaveBeenCalledWith('running')
|
||||
expect(extractWorkflow).toHaveBeenCalled()
|
||||
expect(app.loadGraphData).toHaveBeenCalledWith({ nodes: [] }, true, false)
|
||||
expect(app.nodeOutputs).toEqual({ node: { images: ['image.png'] } })
|
||||
|
||||
await click(buttonByText(ui.queue.element, 'Cancel'))
|
||||
expect(api.interrupt).toHaveBeenCalledWith('running')
|
||||
|
||||
await click(buttonByText(ui.queue.element, 'Delete'))
|
||||
expect(api.deleteItem).toHaveBeenCalledWith('queue', 'pending')
|
||||
|
||||
await click(buttonByText(ui.queue.element, 'Clear Queue'))
|
||||
expect(api.clearItems).toHaveBeenCalledWith('queue')
|
||||
|
||||
await click(buttonByText(ui.queue.element, 'Refresh'))
|
||||
expect(api.getQueue).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('skips loading queue items when job details are unavailable', async () => {
|
||||
vi.mocked(api.getQueue).mockResolvedValue({
|
||||
Running: [{ id: 'running', priority: 1 }],
|
||||
Pending: []
|
||||
} as never)
|
||||
vi.mocked(api.getJobDetail).mockResolvedValue(null as never)
|
||||
const ui = new ComfyUI(app)
|
||||
|
||||
await ui.queue.show()
|
||||
await click(buttonByText(ui.queue.element, 'Load'))
|
||||
|
||||
expect(extractWorkflow).not.toHaveBeenCalled()
|
||||
expect(app.loadGraphData).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('loads queue item workflows without outputs', async () => {
|
||||
vi.mocked(api.getQueue).mockResolvedValue({
|
||||
Running: [{ id: 'running', priority: 1 }],
|
||||
Pending: []
|
||||
} as never)
|
||||
vi.mocked(api.getJobDetail).mockResolvedValue({ id: 'running' } as never)
|
||||
vi.mocked(extractWorkflow).mockResolvedValue({ nodes: [] } as never)
|
||||
const ui = new ComfyUI(app)
|
||||
|
||||
await ui.queue.show()
|
||||
await click(buttonByText(ui.queue.element, 'Load'))
|
||||
|
||||
expect(app.loadGraphData).toHaveBeenCalledWith({ nodes: [] }, true, false)
|
||||
expect(app.nodeOutputs).toBeUndefined()
|
||||
})
|
||||
|
||||
it('loads history in reverse order', async () => {
|
||||
vi.mocked(api.getHistory).mockResolvedValue([
|
||||
{ id: 'old', priority: 1 },
|
||||
{ id: 'new', priority: 2 }
|
||||
] as never)
|
||||
const ui = new ComfyUI(app)
|
||||
|
||||
await ui.history.show()
|
||||
|
||||
expect(ui.history.element.textContent).toContain('2: LoadDelete')
|
||||
expect(ui.history.element.textContent?.indexOf('2:')).toBeLessThan(
|
||||
ui.history.element.textContent?.indexOf('1:') ?? Number.MAX_SAFE_INTEGER
|
||||
)
|
||||
})
|
||||
|
||||
it('updates queue status and auto-queues when enabled', () => {
|
||||
const ui = new ComfyUI(app)
|
||||
ui.autoQueueEnabled = true
|
||||
ui.autoQueueMode = 'instant'
|
||||
ui.lastQueueSize = 1
|
||||
ui.batchCount = 3
|
||||
|
||||
ui.setStatus({ exec_info: { queue_remaining: 0 } })
|
||||
|
||||
expect(app.queuePrompt).toHaveBeenCalledWith(0, 3)
|
||||
expect(ui.queueSize.textContent).toBe('Queue size: 0')
|
||||
expect(ui.lastQueueSize).toBe(3)
|
||||
|
||||
ui.setStatus(null)
|
||||
expect(ui.queueSize.textContent).toBe('Queue size: ERR')
|
||||
})
|
||||
|
||||
it('does not auto-queue while a prior execution error is present', () => {
|
||||
const ui = new ComfyUI(app)
|
||||
ui.autoQueueEnabled = true
|
||||
ui.autoQueueMode = 'instant'
|
||||
ui.lastQueueSize = 1
|
||||
Object.assign(app, { lastExecutionError: new Error('failed') })
|
||||
|
||||
ui.setStatus({ exec_info: { queue_remaining: 0 } })
|
||||
|
||||
expect(app.queuePrompt).not.toHaveBeenCalled()
|
||||
expect(ui.lastQueueSize).toBe(0)
|
||||
})
|
||||
|
||||
it('tracks graph changes for change-mode auto queueing', () => {
|
||||
const ui = new ComfyUI(app)
|
||||
const graphChanged = vi
|
||||
.mocked(api.addEventListener)
|
||||
.mock.calls.find(([eventName]) => eventName === 'graphChanged')?.[1]
|
||||
if (!graphChanged) throw new Error('Missing graphChanged listener')
|
||||
|
||||
ui.autoQueueEnabled = true
|
||||
ui.autoQueueMode = 'change'
|
||||
ui.lastQueueSize = 1
|
||||
graphChanged(fromAny(new CustomEvent('graphChanged')))
|
||||
|
||||
expect(ui.graphHasChanged).toBe(true)
|
||||
|
||||
ui.lastQueueSize = 0
|
||||
graphChanged(fromAny(new CustomEvent('graphChanged')))
|
||||
|
||||
expect(app.queuePrompt).toHaveBeenCalledWith(0, 1)
|
||||
expect(ui.graphHasChanged).toBe(false)
|
||||
})
|
||||
|
||||
it('wires primary menu buttons to app and command actions', async () => {
|
||||
const ui = new ComfyUI(app)
|
||||
|
||||
await click(buttonByText(document, 'Queue Prompt'))
|
||||
expect(app.queuePrompt).toHaveBeenCalledWith(0, 1)
|
||||
|
||||
await click(buttonByText(document, 'Queue Front'))
|
||||
expect(app.queuePrompt).toHaveBeenCalledWith(-1, 1)
|
||||
|
||||
await click(buttonByText(document, 'Save'))
|
||||
await click(buttonByText(document, 'Save (API Format)'))
|
||||
await click(buttonByText(document, 'Refresh'))
|
||||
await click(buttonByText(document, 'Clipspace'))
|
||||
await click(buttonByText(document, 'Clear'))
|
||||
await click(buttonByText(document, 'Load Default'))
|
||||
await click(buttonByText(document, 'Reset View'))
|
||||
|
||||
expect(app.refreshComboInNodes).toHaveBeenCalled()
|
||||
expect(app.openClipspace).toHaveBeenCalled()
|
||||
expect(app.clean).toHaveBeenCalled()
|
||||
expect(app.loadGraphData).toHaveBeenCalledWith()
|
||||
expect(ui.menuContainer.style.display).toBe('none')
|
||||
})
|
||||
|
||||
it('wires file input and legacy option controls', async () => {
|
||||
const ui = new ComfyUI(app)
|
||||
const file = new File(['{}'], 'workflow.json', {
|
||||
type: 'application/json'
|
||||
})
|
||||
const fileInput = document.getElementById(
|
||||
'comfy-file-input'
|
||||
) as HTMLInputElement
|
||||
Object.defineProperty(fileInput, 'files', {
|
||||
configurable: true,
|
||||
value: [file]
|
||||
})
|
||||
|
||||
await Promise.resolve(
|
||||
fileInput.onchange?.call(fileInput, new Event('change'))
|
||||
)
|
||||
expect(app.handleFile).toHaveBeenCalledWith(file, 'file_button')
|
||||
expect(fileInput.value).toBe('')
|
||||
|
||||
const range = document.getElementById(
|
||||
'batchCountInputRange'
|
||||
) as HTMLInputElement
|
||||
const number = document.getElementById(
|
||||
'batchCountInputNumber'
|
||||
) as HTMLInputElement
|
||||
range.value = '4'
|
||||
const extraOptionsCheckbox = document.querySelector(
|
||||
'label input[type="checkbox"]'
|
||||
) as HTMLInputElement
|
||||
extraOptionsCheckbox.checked = true
|
||||
extraOptionsCheckbox.onchange?.call(
|
||||
extraOptionsCheckbox,
|
||||
fromPartial({ srcElement: extraOptionsCheckbox })
|
||||
)
|
||||
expect(ui.batchCount).toBe(4)
|
||||
expect(document.getElementById('extraOptions')?.style.display).toBe('block')
|
||||
|
||||
number.value = '7'
|
||||
number.oninput?.call(number, fromPartial({ target: number }))
|
||||
expect(range.value).toBe('7')
|
||||
|
||||
range.value = '9'
|
||||
range.oninput?.call(range, fromPartial({ srcElement: range }))
|
||||
expect(number.value).toBe('9')
|
||||
|
||||
const autoQueueCheckbox = document.getElementById(
|
||||
'autoQueueCheckbox'
|
||||
) as HTMLInputElement
|
||||
autoQueueCheckbox.checked = true
|
||||
autoQueueCheckbox.onchange?.call(
|
||||
autoQueueCheckbox,
|
||||
fromPartial({ target: autoQueueCheckbox })
|
||||
)
|
||||
expect(ui.autoQueueEnabled).toBe(true)
|
||||
|
||||
extraOptionsCheckbox.checked = false
|
||||
extraOptionsCheckbox.onchange?.call(
|
||||
extraOptionsCheckbox,
|
||||
fromPartial({ srcElement: extraOptionsCheckbox })
|
||||
)
|
||||
expect(ui.batchCount).toBe(1)
|
||||
expect(ui.autoQueueEnabled).toBe(false)
|
||||
expect(document.getElementById('extraOptions')?.style.display).toBe('none')
|
||||
})
|
||||
|
||||
it('toggles queue visibility through the menu button', async () => {
|
||||
vi.mocked(api.getQueue).mockResolvedValue({
|
||||
Running: [],
|
||||
Pending: []
|
||||
} as never)
|
||||
const ui = new ComfyUI(app)
|
||||
|
||||
await click(buttonByText(document, 'View Queue'))
|
||||
expect(ui.queue.visible).toBe(true)
|
||||
|
||||
await click(buttonByText(document, 'Close'))
|
||||
expect(ui.queue.visible).toBe(false)
|
||||
})
|
||||
|
||||
it('does not clear or load defaults when confirmation is declined', async () => {
|
||||
mockSettingStore.get.mockReturnValue(true)
|
||||
vi.stubGlobal(
|
||||
'confirm',
|
||||
vi.fn(() => false)
|
||||
)
|
||||
try {
|
||||
new ComfyUI(app)
|
||||
|
||||
await click(buttonByText(document, 'Clear'))
|
||||
await click(buttonByText(document, 'Load Default'))
|
||||
|
||||
expect(app.clean).not.toHaveBeenCalled()
|
||||
expect(app.loadGraphData).not.toHaveBeenCalled()
|
||||
} finally {
|
||||
vi.unstubAllGlobals()
|
||||
}
|
||||
})
|
||||
|
||||
it('persists manual menu dragging', () => {
|
||||
Object.defineProperty(document.body, 'clientWidth', {
|
||||
configurable: true,
|
||||
value: 1000
|
||||
})
|
||||
Object.defineProperty(document.body, 'clientHeight', {
|
||||
configurable: true,
|
||||
value: 800
|
||||
})
|
||||
const ui = new ComfyUI(app)
|
||||
ui.menuContainer.style.display = 'block'
|
||||
Object.defineProperty(ui.menuContainer, 'clientWidth', {
|
||||
configurable: true,
|
||||
value: 100
|
||||
})
|
||||
Object.defineProperty(ui.menuContainer, 'clientHeight', {
|
||||
configurable: true,
|
||||
value: 80
|
||||
})
|
||||
Object.defineProperty(ui.menuContainer, 'offsetLeft', {
|
||||
configurable: true,
|
||||
get: () => 700
|
||||
})
|
||||
Object.defineProperty(ui.menuContainer, 'offsetTop', {
|
||||
configurable: true,
|
||||
get: () => 20
|
||||
})
|
||||
const handle = ui.menuContainer.querySelector('.drag-handle') as HTMLElement
|
||||
|
||||
handle.onmousedown?.(
|
||||
new MouseEvent('mousedown', { clientX: 10, clientY: 10 })
|
||||
)
|
||||
document.onmousemove?.(
|
||||
new MouseEvent('mousemove', { clientX: 20, clientY: 30 })
|
||||
)
|
||||
document.onmouseup?.(new MouseEvent('mouseup'))
|
||||
|
||||
expect(ui.menuContainer.classList.contains('comfy-menu-manual-pos')).toBe(
|
||||
true
|
||||
)
|
||||
expect(ui.menuContainer.style.right).toBe('190px')
|
||||
expect(localStorage.getItem('Comfy.MenuPosition')).toBe(
|
||||
JSON.stringify({ x: 700, y: 20 })
|
||||
)
|
||||
expect(document.onmousemove).toBeNull()
|
||||
expect(document.onmouseup).toBeNull()
|
||||
})
|
||||
})
|
||||
199
src/scripts/ui/components/button.test.ts
Normal file
199
src/scripts/ui/components/button.test.ts
Normal file
@@ -0,0 +1,199 @@
|
||||
import { fromAny } from '@total-typescript/shoehorn'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import type { ComfyApp } from '@/scripts/app'
|
||||
|
||||
vi.mock('../../ui', () => ({
|
||||
$el: (tag: string, props?: Record<string, unknown>, children?: Node[]) => {
|
||||
const [tagName, ...classes] = tag.split('.')
|
||||
const element = document.createElement(tagName)
|
||||
if (classes.length) element.classList.add(...classes)
|
||||
if (props) {
|
||||
const listeners = Object.entries(props).filter(([key]) =>
|
||||
key.startsWith('on')
|
||||
)
|
||||
for (const [key, listener] of listeners) {
|
||||
if (typeof listener === 'function') {
|
||||
element.addEventListener(key.slice(2), listener as EventListener)
|
||||
}
|
||||
}
|
||||
}
|
||||
if (children) element.append(...children)
|
||||
return element
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('../../utils', () => ({
|
||||
prop: <T>(
|
||||
target: object,
|
||||
name: string,
|
||||
defaultValue: T,
|
||||
onChanged?: (currentValue: T, previousValue: T) => void
|
||||
) => {
|
||||
let currentValue: T
|
||||
Object.defineProperty(target, name, {
|
||||
get() {
|
||||
return currentValue
|
||||
},
|
||||
set(newValue: T) {
|
||||
const previousValue = currentValue
|
||||
currentValue = newValue
|
||||
onChanged?.(currentValue, previousValue)
|
||||
}
|
||||
})
|
||||
;(target as Record<string, T>)[name] = defaultValue
|
||||
return defaultValue
|
||||
}
|
||||
}))
|
||||
|
||||
import { ComfyButton } from './button'
|
||||
|
||||
class MockPopup extends EventTarget {
|
||||
element = document.createElement('div')
|
||||
open = false
|
||||
toggle = vi.fn(() => {
|
||||
this.open = !this.open
|
||||
this.dispatchEvent(new CustomEvent('change'))
|
||||
})
|
||||
}
|
||||
|
||||
function mockApp(settingValue: boolean) {
|
||||
let listener: (() => void) | undefined
|
||||
const app = {
|
||||
ui: {
|
||||
settings: {
|
||||
getSettingValue: vi.fn(() => settingValue),
|
||||
addEventListener: vi.fn((_event: string, callback: () => void) => {
|
||||
listener = callback
|
||||
})
|
||||
}
|
||||
}
|
||||
} as unknown as ComfyApp
|
||||
return {
|
||||
app,
|
||||
setSettingValue(value: boolean) {
|
||||
settingValue = value
|
||||
listener?.()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
describe('ComfyButton', () => {
|
||||
beforeEach(() => {
|
||||
document.body.replaceChildren()
|
||||
})
|
||||
|
||||
it('renders icon, content, tooltip, enabled state, and click action', () => {
|
||||
const action = vi.fn()
|
||||
const button = new ComfyButton({
|
||||
icon: 'play',
|
||||
overIcon: 'pause',
|
||||
iconSize: 18,
|
||||
content: 'Run',
|
||||
tooltip: 'Queue prompt',
|
||||
enabled: false,
|
||||
classList: { primary: true, hiddenClass: false },
|
||||
action
|
||||
})
|
||||
|
||||
expect(button.iconElement.className).toBe('mdi mdi-play mdi-18px')
|
||||
expect(button.contentElement.textContent).toBe('Run')
|
||||
expect(button.element.title).toBe('Queue prompt')
|
||||
expect(button.element.getAttribute('aria-label')).toBe('Queue prompt')
|
||||
expect(button.element.classList.contains('primary')).toBe(true)
|
||||
expect(button.element.classList.contains('disabled')).toBe(true)
|
||||
expect((button.element as HTMLButtonElement).disabled).toBe(true)
|
||||
|
||||
button.enabled = true
|
||||
button.element.dispatchEvent(new MouseEvent('mouseenter'))
|
||||
expect(button.iconElement.className).toBe('mdi mdi-pause mdi-18px')
|
||||
button.element.dispatchEvent(new MouseEvent('mouseleave'))
|
||||
expect(button.iconElement.className).toBe('mdi mdi-play mdi-18px')
|
||||
|
||||
button.element.dispatchEvent(new MouseEvent('click'))
|
||||
expect(action).toHaveBeenCalledWith(expect.any(MouseEvent), button)
|
||||
})
|
||||
|
||||
it('supports HTMLElement content and removing tooltip text', () => {
|
||||
const button = new ComfyButton({ content: 'Text', tooltip: 'Hint' })
|
||||
const content = document.createElement('strong')
|
||||
content.textContent = 'Element'
|
||||
|
||||
button.content = content
|
||||
button.tooltip = ''
|
||||
|
||||
expect(button.contentElement.firstElementChild).toBe(content)
|
||||
expect(button.element.hasAttribute('title')).toBe(false)
|
||||
})
|
||||
|
||||
it('updates the hover icon when overIcon changes while hovered', () => {
|
||||
const button = new ComfyButton({ icon: 'play' })
|
||||
|
||||
button.element.dispatchEvent(new MouseEvent('mouseenter'))
|
||||
button.overIcon = 'pause'
|
||||
|
||||
expect(button.iconElement.className).toBe('mdi mdi-pause')
|
||||
})
|
||||
|
||||
it('hides and shows from a visibility setting', () => {
|
||||
const settings = mockApp(false)
|
||||
const button = new ComfyButton({
|
||||
app: settings.app,
|
||||
visibilitySetting: {
|
||||
id: 'Comfy.UseNewMenu',
|
||||
showValue: true
|
||||
}
|
||||
})
|
||||
|
||||
expect(button.hidden).toBe(true)
|
||||
expect(button.element.classList.contains('hidden')).toBe(true)
|
||||
|
||||
settings.setSettingValue(true)
|
||||
|
||||
expect(button.hidden).toBe(false)
|
||||
expect(button.element.classList.contains('hidden')).toBe(false)
|
||||
})
|
||||
|
||||
it('toggles click popups and reflects popup open state in classes', () => {
|
||||
const popup = new MockPopup()
|
||||
const button = new ComfyButton({ icon: 'dots' }).withPopup(fromAny(popup))
|
||||
|
||||
button.element.dispatchEvent(new MouseEvent('click'))
|
||||
|
||||
expect(popup.toggle).toHaveBeenCalledOnce()
|
||||
expect(button.element.classList.contains('popup-open')).toBe(true)
|
||||
|
||||
popup.toggle()
|
||||
|
||||
expect(button.element.classList.contains('popup-closed')).toBe(true)
|
||||
})
|
||||
|
||||
it('opens hover popups while either the button or popup is hovered', () => {
|
||||
const popup = new MockPopup()
|
||||
const button = new ComfyButton({ icon: 'dots' }).withPopup(
|
||||
fromAny(popup),
|
||||
'hover'
|
||||
)
|
||||
|
||||
button.element.dispatchEvent(new MouseEvent('mouseenter'))
|
||||
expect(popup.open).toBe(true)
|
||||
popup.element.dispatchEvent(new MouseEvent('mouseenter'))
|
||||
button.element.dispatchEvent(new MouseEvent('mouseleave'))
|
||||
expect(popup.open).toBe(true)
|
||||
popup.element.dispatchEvent(new MouseEvent('mouseleave'))
|
||||
expect(popup.open).toBe(false)
|
||||
})
|
||||
|
||||
it('does not click-toggle a hover popup while hovered', () => {
|
||||
const popup = new MockPopup()
|
||||
const button = new ComfyButton({ icon: 'dots' }).withPopup(
|
||||
fromAny(popup),
|
||||
'hover'
|
||||
)
|
||||
|
||||
button.element.dispatchEvent(new MouseEvent('mouseenter'))
|
||||
button.element.dispatchEvent(new MouseEvent('click'))
|
||||
|
||||
expect(popup.toggle).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
284
src/scripts/ui/components/popup.test.ts
Normal file
284
src/scripts/ui/components/popup.test.ts
Normal file
@@ -0,0 +1,284 @@
|
||||
import { fromAny } from '@total-typescript/shoehorn'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
type ElChild = Node | string
|
||||
type ElInput = Record<string, unknown> | ElChild | ElChild[]
|
||||
|
||||
function appendChildren(element: HTMLElement, children: ElChild | ElChild[]) {
|
||||
const list = Array.isArray(children) ? children : [children]
|
||||
for (const child of list) {
|
||||
element.append(child)
|
||||
}
|
||||
}
|
||||
|
||||
vi.mock('../../ui', () => ({
|
||||
$el: (tag: string, propsOrChildren?: ElInput, children?: ElChild[]) => {
|
||||
const [tagName, ...classes] = tag.split('.')
|
||||
const element = document.createElement(tagName)
|
||||
if (classes.length) element.classList.add(...classes)
|
||||
|
||||
if (
|
||||
propsOrChildren instanceof Node ||
|
||||
typeof propsOrChildren === 'string' ||
|
||||
Array.isArray(propsOrChildren)
|
||||
) {
|
||||
appendChildren(element, propsOrChildren)
|
||||
} else if (propsOrChildren) {
|
||||
for (const [key, value] of Object.entries(propsOrChildren)) {
|
||||
if (key === '$' && typeof value === 'function') {
|
||||
value(element)
|
||||
} else if (key === 'parent' && value instanceof HTMLElement) {
|
||||
value.append(element)
|
||||
} else if (key === 'textContent') {
|
||||
element.textContent = String(value)
|
||||
} else if (key === 'ariaLabel') {
|
||||
element.setAttribute('aria-label', String(value))
|
||||
} else if (key === 'ariaHasPopup') {
|
||||
element.setAttribute('aria-haspopup', String(value))
|
||||
} else if (key === 'type') {
|
||||
element.setAttribute('type', String(value))
|
||||
} else if (
|
||||
key.toLowerCase().startsWith('on') &&
|
||||
typeof value === 'function'
|
||||
) {
|
||||
element.addEventListener(
|
||||
key.slice(2).toLowerCase(),
|
||||
value as EventListener
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (children) appendChildren(element, children)
|
||||
return element
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('../../utils', () => ({
|
||||
prop: <T>(
|
||||
target: object,
|
||||
name: string,
|
||||
defaultValue: T,
|
||||
onChanged?: (currentValue: T, previousValue: T) => void
|
||||
) => {
|
||||
let currentValue: T
|
||||
Object.defineProperty(target, name, {
|
||||
get() {
|
||||
return currentValue
|
||||
},
|
||||
set(newValue: T) {
|
||||
const previousValue = currentValue
|
||||
currentValue = newValue
|
||||
onChanged?.(currentValue, previousValue)
|
||||
}
|
||||
})
|
||||
;(target as Record<string, T>)[name] = defaultValue
|
||||
return defaultValue
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('../utils', () => ({
|
||||
applyClasses: (
|
||||
element: HTMLElement,
|
||||
classList: string | Record<string, boolean>,
|
||||
...baseClasses: string[]
|
||||
) => {
|
||||
element.className = baseClasses.join(' ')
|
||||
if (typeof classList === 'string') {
|
||||
element.classList.add(...classList.split(' ').filter(Boolean))
|
||||
} else {
|
||||
for (const [className, enabled] of Object.entries(classList)) {
|
||||
element.classList.toggle(className, enabled)
|
||||
}
|
||||
}
|
||||
},
|
||||
toggleElement:
|
||||
<T>(
|
||||
element: HTMLElement,
|
||||
{
|
||||
onShow
|
||||
}: {
|
||||
onShow?: (element: HTMLElement, value: T) => void
|
||||
} = {}
|
||||
) =>
|
||||
(value: T) => {
|
||||
element.hidden = !value
|
||||
if (value) onShow?.(element, value)
|
||||
}
|
||||
}))
|
||||
|
||||
import { ComfyAsyncDialog } from './asyncDialog'
|
||||
import { ComfyButton } from './button'
|
||||
import { ComfyButtonGroup } from './buttonGroup'
|
||||
import { ComfyPopup } from './popup'
|
||||
import { ComfySplitButton } from './splitButton'
|
||||
|
||||
function targetWithRect(rect: DOMRect) {
|
||||
const target = document.createElement('button')
|
||||
vi.spyOn(target, 'getBoundingClientRect').mockReturnValue(rect)
|
||||
document.body.append(target)
|
||||
return target
|
||||
}
|
||||
|
||||
describe('ComfyPopup and related UI components', () => {
|
||||
beforeEach(() => {
|
||||
document.body.replaceChildren()
|
||||
vi.restoreAllMocks()
|
||||
})
|
||||
|
||||
it('opens, positions, updates children and classes, and closes from escape', () => {
|
||||
const target = targetWithRect(
|
||||
DOMRect.fromRect({ x: 10, y: 20, width: 80, height: 30 })
|
||||
)
|
||||
const child = document.createElement('span')
|
||||
const popup = new ComfyPopup(
|
||||
{ target, classList: { menu: true, hidden: false } },
|
||||
child
|
||||
)
|
||||
const open = vi.fn()
|
||||
const close = vi.fn()
|
||||
const change = vi.fn()
|
||||
popup.addEventListener('open', open)
|
||||
popup.addEventListener('close', close)
|
||||
popup.addEventListener('change', change)
|
||||
vi.spyOn(popup.element, 'getBoundingClientRect').mockReturnValue(
|
||||
DOMRect.fromRect({ height: 20 })
|
||||
)
|
||||
|
||||
popup.open = true
|
||||
|
||||
expect(open).toHaveBeenCalledOnce()
|
||||
expect(change).toHaveBeenCalledOnce()
|
||||
expect(popup.element).toHaveClass('open')
|
||||
expect(popup.element.style.getPropertyValue('--left')).toBe('10px')
|
||||
expect(popup.element.style.getPropertyValue('--bottom')).toBe('35px')
|
||||
|
||||
const nextChild = document.createElement('strong')
|
||||
popup.children = [nextChild]
|
||||
popup.classList = 'extra'
|
||||
|
||||
expect(popup.element.firstElementChild).toBe(nextChild)
|
||||
expect(popup.element).toHaveClass('comfyui-popup', 'left', 'extra')
|
||||
|
||||
window.dispatchEvent(new KeyboardEvent('keydown', { key: 'Enter' }))
|
||||
expect(popup.open).toBe(true)
|
||||
|
||||
const escape = new KeyboardEvent('keydown', {
|
||||
key: 'Escape',
|
||||
cancelable: true
|
||||
})
|
||||
window.dispatchEvent(escape)
|
||||
|
||||
expect(close).toHaveBeenCalledOnce()
|
||||
expect(escape.defaultPrevented).toBe(true)
|
||||
expect(popup.open).toBe(false)
|
||||
})
|
||||
|
||||
it('handles outside clicks, target clicks, and relative right positioning', () => {
|
||||
const target = targetWithRect(
|
||||
DOMRect.fromRect({ x: 100, y: 40, width: 60, height: 25 })
|
||||
)
|
||||
const container = document.createElement('section')
|
||||
document.body.append(container)
|
||||
const popup = new ComfyPopup({
|
||||
target,
|
||||
container,
|
||||
position: 'relative',
|
||||
horizontal: 'right'
|
||||
})
|
||||
vi.spyOn(popup.element, 'getBoundingClientRect').mockReturnValue(
|
||||
DOMRect.fromRect({ height: 100 })
|
||||
)
|
||||
Object.defineProperty(popup.element, 'clientWidth', {
|
||||
configurable: true,
|
||||
value: 40
|
||||
})
|
||||
|
||||
popup.open = true
|
||||
target.dispatchEvent(new MouseEvent('click', { bubbles: true }))
|
||||
expect(popup.open).toBe(true)
|
||||
|
||||
const outside = document.createElement('div')
|
||||
document.body.append(outside)
|
||||
outside.dispatchEvent(new MouseEvent('click', { bubbles: true }))
|
||||
|
||||
expect(popup.open).toBe(false)
|
||||
expect(popup.element.style.getPropertyValue('--left')).toBe('0px')
|
||||
expect(popup.element.style.getPropertyValue('--top')).toBe('25px')
|
||||
})
|
||||
|
||||
it('keeps outside clicks open when target clicks are not ignored', () => {
|
||||
const target = targetWithRect(DOMRect.fromRect({ height: 10 }))
|
||||
const popup = new ComfyPopup({
|
||||
target,
|
||||
ignoreTarget: false,
|
||||
closeOnEscape: false
|
||||
})
|
||||
const outside = document.createElement('div')
|
||||
document.body.append(outside)
|
||||
|
||||
popup.toggle()
|
||||
outside.dispatchEvent(new MouseEvent('click', { bubbles: true }))
|
||||
window.dispatchEvent(new KeyboardEvent('keydown', { key: 'Escape' }))
|
||||
|
||||
expect(popup.open).toBe(true)
|
||||
})
|
||||
|
||||
it('renders split button popup items and updates button groups', () => {
|
||||
const primary = new ComfyButton({ content: 'Queue' })
|
||||
const itemButton = new ComfyButton({ content: 'Queue front' })
|
||||
const rawItem = document.createElement('button')
|
||||
rawItem.textContent = 'Queue back'
|
||||
const split = new ComfySplitButton(
|
||||
{
|
||||
primary,
|
||||
mode: 'hover',
|
||||
horizontal: 'right',
|
||||
position: 'absolute'
|
||||
},
|
||||
itemButton,
|
||||
rawItem
|
||||
)
|
||||
|
||||
expect(split.element).toHaveClass('comfyui-split-button', 'hover')
|
||||
expect(split.popup.element).toHaveTextContent('Queue front')
|
||||
expect(split.popup.element).toHaveTextContent('Queue back')
|
||||
expect(document.body).toContainElement(split.popup.element)
|
||||
|
||||
const group = new ComfyButtonGroup(primary)
|
||||
group.append(itemButton)
|
||||
group.insert(fromAny(rawItem), 1)
|
||||
|
||||
expect(group.element.children).toHaveLength(3)
|
||||
expect(group.remove(itemButton)).toEqual([itemButton])
|
||||
expect(group.remove(itemButton)).toBeUndefined()
|
||||
expect(group.element.children).toHaveLength(2)
|
||||
})
|
||||
|
||||
it('resolves async dialogs from buttons, close events, and prompt actions', async () => {
|
||||
const dialog = new ComfyAsyncDialog<number>([
|
||||
{ text: 'Seven', value: 7 },
|
||||
'Fallback'
|
||||
])
|
||||
|
||||
const promise = dialog.show('Pick one')
|
||||
dialog.element.querySelector<HTMLButtonElement>('button')?.click()
|
||||
await expect(promise).resolves.toBe(7)
|
||||
|
||||
const closePromise = dialog.show(document.createElement('em'))
|
||||
dialog.element.dispatchEvent(new Event('close'))
|
||||
await expect(closePromise).resolves.toBeNull()
|
||||
|
||||
const promptPromise = ComfyAsyncDialog.prompt({
|
||||
title: 'Confirm',
|
||||
message: 'Continue?',
|
||||
actions: [{ text: 'Yes', value: 'yes' }]
|
||||
})
|
||||
const prompt = Array.from(document.querySelectorAll('dialog')).at(-1)
|
||||
expect(prompt).toHaveTextContent('Confirm')
|
||||
prompt?.querySelector<HTMLButtonElement>('button')?.click()
|
||||
|
||||
await expect(promptPromise).resolves.toBe('yes')
|
||||
expect(document.body).not.toContainElement(prompt ?? null)
|
||||
})
|
||||
})
|
||||
69
src/scripts/ui/dialog.test.ts
Normal file
69
src/scripts/ui/dialog.test.ts
Normal file
@@ -0,0 +1,69 @@
|
||||
import { afterEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { ComfyDialog } from './dialog'
|
||||
|
||||
const mocks = vi.hoisted(() => ({
|
||||
$el: (
|
||||
selector: string,
|
||||
propsOrChildren?: Record<string, unknown> | Node[],
|
||||
maybeChildren?: Node[]
|
||||
) => {
|
||||
const [tag, ...classes] = selector.split('.')
|
||||
const element = document.createElement(tag || 'div')
|
||||
element.classList.add(...classes.filter(Boolean))
|
||||
const children = Array.isArray(propsOrChildren)
|
||||
? propsOrChildren
|
||||
: maybeChildren
|
||||
|
||||
if (propsOrChildren && !Array.isArray(propsOrChildren)) {
|
||||
for (const [key, value] of Object.entries(propsOrChildren)) {
|
||||
if (key === 'parent' && value instanceof Node) {
|
||||
value.appendChild(element)
|
||||
} else if (key === '$' && typeof value === 'function') {
|
||||
value(element)
|
||||
} else {
|
||||
Reflect.set(element, key, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
element.append(...(children ?? []))
|
||||
return element
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('../ui', () => ({
|
||||
$el: mocks.$el
|
||||
}))
|
||||
|
||||
describe('ComfyDialog', () => {
|
||||
afterEach(() => {
|
||||
document.body.replaceChildren()
|
||||
})
|
||||
|
||||
it('shows string and element content and closes through the default button', () => {
|
||||
const dialog = new ComfyDialog()
|
||||
|
||||
dialog.show('<strong>Hello</strong>')
|
||||
|
||||
expect(dialog.element.style.display).toBe('flex')
|
||||
expect(dialog.textElement.innerHTML).toBe('<strong>Hello</strong>')
|
||||
|
||||
dialog.element.querySelector('button')?.click()
|
||||
expect(dialog.element.style.display).toBe('none')
|
||||
|
||||
const first = document.createElement('span')
|
||||
const second = document.createElement('em')
|
||||
dialog.show([first, second])
|
||||
|
||||
expect([...dialog.textElement.children]).toEqual([first, second])
|
||||
})
|
||||
|
||||
it('uses supplied custom buttons', () => {
|
||||
const button = document.createElement('button')
|
||||
const dialog = new ComfyDialog('section', [button])
|
||||
|
||||
expect(dialog.element.tagName).toBe('SECTION')
|
||||
expect(dialog.element.querySelector('button')).toBe(button)
|
||||
})
|
||||
})
|
||||
216
src/scripts/ui/draggableList.test.ts
Normal file
216
src/scripts/ui/draggableList.test.ts
Normal file
@@ -0,0 +1,216 @@
|
||||
import { afterEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { DraggableList } from './draggableList'
|
||||
|
||||
function createList(itemCount: number) {
|
||||
const container = document.createElement('div')
|
||||
const items = Array.from({ length: itemCount }, (_, index) => {
|
||||
const item = document.createElement('div')
|
||||
item.className = 'item'
|
||||
item.dataset.index = String(index)
|
||||
|
||||
const handle = document.createElement('button')
|
||||
handle.className = 'drag-handle'
|
||||
item.append(handle)
|
||||
container.append(item)
|
||||
|
||||
return item
|
||||
})
|
||||
document.body.append(container)
|
||||
return { container, items }
|
||||
}
|
||||
|
||||
function setRect(element: Element, top: number, height = 20) {
|
||||
return vi
|
||||
.spyOn(element, 'getBoundingClientRect')
|
||||
.mockReturnValue(new DOMRect(0, top, 100, height))
|
||||
}
|
||||
|
||||
function defineScrollMetrics(
|
||||
container: HTMLElement,
|
||||
scrollHeight: number,
|
||||
clientHeight: number
|
||||
) {
|
||||
Object.defineProperty(container, 'scrollHeight', {
|
||||
configurable: true,
|
||||
value: scrollHeight
|
||||
})
|
||||
Object.defineProperty(container, 'clientHeight', {
|
||||
configurable: true,
|
||||
value: clientHeight
|
||||
})
|
||||
}
|
||||
|
||||
function mouseDragEvent(
|
||||
target: Element,
|
||||
overrides: Partial<MouseEvent> = {}
|
||||
): MouseEvent {
|
||||
return {
|
||||
button: 0,
|
||||
clientX: 0,
|
||||
clientY: 0,
|
||||
preventDefault: vi.fn(),
|
||||
target,
|
||||
...overrides
|
||||
} satisfies Partial<MouseEvent> as MouseEvent
|
||||
}
|
||||
|
||||
describe('DraggableList', () => {
|
||||
afterEach(() => {
|
||||
document.body.replaceChildren()
|
||||
vi.restoreAllMocks()
|
||||
vi.unstubAllGlobals()
|
||||
})
|
||||
|
||||
it('ignores missing containers and non-primary drag starts', () => {
|
||||
const listWithoutContainer = new DraggableList(null, '.item')
|
||||
const { container, items } = createList(1)
|
||||
const list = new DraggableList(container, '.item')
|
||||
|
||||
list.dragStart(
|
||||
mouseDragEvent(items[0].querySelector('.drag-handle')!, { button: 1 })
|
||||
)
|
||||
list.dragEnd()
|
||||
|
||||
expect(listWithoutContainer.listContainer).toBeNull()
|
||||
expect(list.draggableItem).toBeUndefined()
|
||||
})
|
||||
|
||||
it('starts from a handle, scrolls downward, and reorders upward', () => {
|
||||
vi.stubGlobal('requestAnimationFrame', (callback: FrameRequestCallback) => {
|
||||
callback(0)
|
||||
return 1
|
||||
})
|
||||
const { container, items } = createList(3)
|
||||
const scrollBy = vi.fn((_left: number, top: number) => {
|
||||
container.scrollTop += top
|
||||
})
|
||||
container.scrollBy = scrollBy as unknown as typeof container.scrollBy
|
||||
defineScrollMetrics(container, 120, 80)
|
||||
vi.spyOn(container, 'getBoundingClientRect').mockReturnValue(
|
||||
new DOMRect(0, 0, 100, 80)
|
||||
)
|
||||
setRect(items[0], 0)
|
||||
setRect(items[1], 30)
|
||||
setRect(items[2], 60)
|
||||
|
||||
const list = new DraggableList(container, '.item')
|
||||
const dragStart = vi.fn()
|
||||
const dragEnd = vi.fn()
|
||||
list.addEventListener('dragstart', dragStart)
|
||||
list.addEventListener('dragend', dragEnd)
|
||||
|
||||
list.dragStart(
|
||||
mouseDragEvent(items[2].querySelector('.drag-handle')!, {
|
||||
clientX: 10,
|
||||
clientY: 70
|
||||
})
|
||||
)
|
||||
list.drag(
|
||||
mouseDragEvent(items[2].querySelector('.drag-handle')!, {
|
||||
clientX: 20,
|
||||
clientY: 100
|
||||
})
|
||||
)
|
||||
items[1].dataset.isToggled = ''
|
||||
list.dragEnd()
|
||||
|
||||
expect(dragStart).toHaveBeenCalledOnce()
|
||||
expect(dragEnd).toHaveBeenCalledOnce()
|
||||
expect(scrollBy).toHaveBeenCalledWith(0, 10)
|
||||
expect([...container.children]).toEqual([items[0], items[2], items[1]])
|
||||
expect(items[0].classList.contains('is-idle')).toBe(true)
|
||||
expect(items[1].classList.contains('is-idle')).toBe(true)
|
||||
})
|
||||
|
||||
it('supports touch coordinates, upward scrolling, and downward reorder', () => {
|
||||
vi.stubGlobal('requestAnimationFrame', (callback: FrameRequestCallback) => {
|
||||
callback(0)
|
||||
return 1
|
||||
})
|
||||
const { container, items } = createList(3)
|
||||
const scrollBy = vi.fn((_left: number, top: number) => {
|
||||
container.scrollTop += top
|
||||
})
|
||||
container.scrollTop = 10
|
||||
container.scrollBy = scrollBy as unknown as typeof container.scrollBy
|
||||
defineScrollMetrics(container, 120, 80)
|
||||
vi.spyOn(container, 'getBoundingClientRect').mockReturnValue(
|
||||
new DOMRect(0, 20, 100, 80)
|
||||
)
|
||||
setRect(items[0], 0)
|
||||
setRect(items[1], 30)
|
||||
setRect(items[2], 60)
|
||||
|
||||
const list = new DraggableList(container, '.item')
|
||||
const touchStart = {
|
||||
button: 0,
|
||||
clientX: 0,
|
||||
clientY: 0,
|
||||
preventDefault: vi.fn(),
|
||||
target: items[0].querySelector('.drag-handle')!,
|
||||
touches: [{ clientX: 5, clientY: 30 }]
|
||||
} as unknown as TouchEvent
|
||||
const touchMove = {
|
||||
clientX: 0,
|
||||
clientY: 0,
|
||||
preventDefault: vi.fn(),
|
||||
target: items[0].querySelector('.drag-handle')!,
|
||||
touches: [{ clientX: 8, clientY: 10 }]
|
||||
} as unknown as TouchEvent
|
||||
|
||||
list.dragStart(touchStart)
|
||||
list.drag(touchMove)
|
||||
items[1].dataset.isToggled = ''
|
||||
list.dragEnd()
|
||||
|
||||
expect(scrollBy).toHaveBeenCalledWith(0, -10)
|
||||
expect([...container.children]).toEqual([items[1], items[0], items[2]])
|
||||
})
|
||||
|
||||
it('updates idle item state around the dragged item midpoint', () => {
|
||||
const { container, items } = createList(3)
|
||||
const list = new DraggableList(container, '.item')
|
||||
const state = list as unknown as {
|
||||
items: HTMLElement[]
|
||||
draggableItem: HTMLElement
|
||||
}
|
||||
state.items = items
|
||||
state.draggableItem = items[1]
|
||||
list.itemsGap = 5
|
||||
items[0].classList.add('is-idle')
|
||||
items[1].classList.add('is-idle')
|
||||
items[2].classList.add('is-idle')
|
||||
items[0].dataset.isAbove = ''
|
||||
const draggedRect = setRect(items[1], -10)
|
||||
setRect(items[0], 0)
|
||||
setRect(items[2], 60)
|
||||
|
||||
list.updateIdleItemsStateAndPosition()
|
||||
|
||||
expect(items[0].dataset.isToggled).toBe('')
|
||||
expect(items[0].style.transform).toBe('translateY(25px)')
|
||||
expect(items[2].style.transform).toBe('')
|
||||
|
||||
draggedRect.mockReturnValue(new DOMRect(0, 100, 100, 20))
|
||||
list.updateIdleItemsStateAndPosition()
|
||||
|
||||
expect(items[0].dataset.isToggled).toBeUndefined()
|
||||
expect(items[2].dataset.isToggled).toBe('')
|
||||
expect(items[2].style.transform).toBe('translateY(-25px)')
|
||||
})
|
||||
|
||||
it('uses zero gap for short lists and disposes listeners', () => {
|
||||
const { container } = createList(1)
|
||||
const list = new DraggableList(container, '.item')
|
||||
const off = vi.fn()
|
||||
const disposableList = list as unknown as { off: Array<() => void> }
|
||||
disposableList.off = [off]
|
||||
|
||||
list.setItemsGap()
|
||||
list.dispose()
|
||||
|
||||
expect(list.itemsGap).toBe(0)
|
||||
expect(off).toHaveBeenCalledOnce()
|
||||
})
|
||||
})
|
||||
33
src/scripts/ui/imagePreview.test.ts
Normal file
33
src/scripts/ui/imagePreview.test.ts
Normal file
@@ -0,0 +1,33 @@
|
||||
import { describe, expect, it } from 'vitest'
|
||||
|
||||
import { calculateImageGrid } from './imagePreview'
|
||||
|
||||
function createImage(width: number, height: number) {
|
||||
const img = document.createElement('img')
|
||||
Object.defineProperty(img, 'naturalWidth', {
|
||||
configurable: true,
|
||||
value: width
|
||||
})
|
||||
Object.defineProperty(img, 'naturalHeight', {
|
||||
configurable: true,
|
||||
value: height
|
||||
})
|
||||
return img
|
||||
}
|
||||
|
||||
describe('imagePreview', () => {
|
||||
it('calculates the highest-area grid', () => {
|
||||
const images = [
|
||||
createImage(100, 100),
|
||||
createImage(100, 100),
|
||||
createImage(100, 100)
|
||||
]
|
||||
|
||||
expect(calculateImageGrid(images, 300, 120)).toMatchObject({
|
||||
cellWidth: 100,
|
||||
cellHeight: 100,
|
||||
cols: 3,
|
||||
rows: 1
|
||||
})
|
||||
})
|
||||
})
|
||||
86
src/scripts/ui/toggleSwitch.test.ts
Normal file
86
src/scripts/ui/toggleSwitch.test.ts
Normal file
@@ -0,0 +1,86 @@
|
||||
import { afterEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { toggleSwitch } from './toggleSwitch'
|
||||
|
||||
const mocks = vi.hoisted(() => ({
|
||||
$el: (
|
||||
selector: string,
|
||||
propsOrChildren?: Record<string, unknown> | Node[] | Node,
|
||||
maybeChildren?: Node[] | Node
|
||||
) => {
|
||||
const [tag, ...classes] = selector.split('.')
|
||||
const element = document.createElement(tag || 'div')
|
||||
element.classList.add(...classes.filter(Boolean))
|
||||
const children = Array.isArray(propsOrChildren)
|
||||
? propsOrChildren
|
||||
: propsOrChildren instanceof Node
|
||||
? [propsOrChildren]
|
||||
: Array.isArray(maybeChildren)
|
||||
? maybeChildren
|
||||
: maybeChildren instanceof Node
|
||||
? [maybeChildren]
|
||||
: []
|
||||
|
||||
if (
|
||||
propsOrChildren &&
|
||||
!(propsOrChildren instanceof Node) &&
|
||||
!Array.isArray(propsOrChildren)
|
||||
) {
|
||||
for (const [key, value] of Object.entries(propsOrChildren)) {
|
||||
Reflect.set(element, key, value)
|
||||
}
|
||||
}
|
||||
|
||||
element.append(...children)
|
||||
return element
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('../ui', () => ({
|
||||
$el: mocks.$el
|
||||
}))
|
||||
|
||||
describe('toggleSwitch', () => {
|
||||
afterEach(() => {
|
||||
document.body.replaceChildren()
|
||||
})
|
||||
|
||||
it('selects the first item when none is preselected', () => {
|
||||
const onChange = vi.fn()
|
||||
const container = toggleSwitch('mode', ['first', 'second'], { onChange })
|
||||
const labels = [...container.querySelectorAll('label')]
|
||||
const inputs = [...container.querySelectorAll('input')]
|
||||
|
||||
expect(labels[0].classList.contains('comfy-toggle-selected')).toBe(true)
|
||||
expect((inputs[0] as HTMLInputElement).checked).toBe(true)
|
||||
expect(onChange).toHaveBeenCalledWith({
|
||||
item: 'first',
|
||||
prev: undefined
|
||||
})
|
||||
})
|
||||
|
||||
it('moves selection and reports the previous item', () => {
|
||||
const onChange = vi.fn()
|
||||
const container = toggleSwitch(
|
||||
'mode',
|
||||
[
|
||||
{ text: 'first', tooltip: 'First option' },
|
||||
{ text: 'second', value: '2' }
|
||||
],
|
||||
{ onChange }
|
||||
)
|
||||
const labels = [...container.querySelectorAll('label')]
|
||||
const secondInput = labels[1].querySelector('input') as HTMLInputElement
|
||||
|
||||
secondInput.onchange?.(new Event('change'))
|
||||
|
||||
expect(labels[0].classList.contains('comfy-toggle-selected')).toBe(false)
|
||||
expect(labels[1].classList.contains('comfy-toggle-selected')).toBe(true)
|
||||
expect(labels[0].title).toBe('First option')
|
||||
expect(secondInput.value).toBe('2')
|
||||
expect(onChange).toHaveBeenLastCalledWith({
|
||||
item: { text: 'second', value: '2' },
|
||||
prev: { text: 'first', tooltip: 'First option', value: 'first' }
|
||||
})
|
||||
})
|
||||
})
|
||||
45
src/scripts/ui/utils.test.ts
Normal file
45
src/scripts/ui/utils.test.ts
Normal file
@@ -0,0 +1,45 @@
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { applyClasses, toggleElement } from './utils'
|
||||
|
||||
describe('ui utils', () => {
|
||||
it('applies string, array, object, and required classes', () => {
|
||||
const element = document.createElement('div')
|
||||
|
||||
applyClasses(element, 'one two', 'required')
|
||||
expect([...element.classList]).toEqual(['one', 'two', 'required'])
|
||||
|
||||
applyClasses(element, ['three', 'four'])
|
||||
expect([...element.classList]).toEqual(['three', 'four'])
|
||||
|
||||
applyClasses(element, { five: true, six: false, seven: true })
|
||||
expect([...element.classList]).toEqual(['five', 'seven'])
|
||||
|
||||
applyClasses(element, null as unknown as string)
|
||||
expect(element.className).toBe('')
|
||||
})
|
||||
|
||||
it('toggles an element through a placeholder', () => {
|
||||
const parent = document.createElement('div')
|
||||
const element = document.createElement('span')
|
||||
const onHide = vi.fn()
|
||||
const onShow = vi.fn()
|
||||
parent.append(element)
|
||||
const toggle = toggleElement(element, { onHide, onShow })
|
||||
|
||||
toggle(false)
|
||||
expect(parent.firstChild).toBeInstanceOf(Comment)
|
||||
expect(onHide).toHaveBeenCalledWith(element)
|
||||
|
||||
toggle(true)
|
||||
expect(parent.firstChild).toBe(element)
|
||||
expect(onShow).toHaveBeenCalledWith(element, true)
|
||||
|
||||
toggle('visible')
|
||||
expect(onShow).toHaveBeenCalledWith(element, 'visible')
|
||||
|
||||
toggle(false)
|
||||
expect(parent.firstChild).toBeInstanceOf(Comment)
|
||||
expect(onHide).toHaveBeenCalledTimes(2)
|
||||
})
|
||||
})
|
||||
127
src/scripts/utils.test.ts
Normal file
127
src/scripts/utils.test.ts
Normal file
@@ -0,0 +1,127 @@
|
||||
import { afterEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import {
|
||||
addStylesheet,
|
||||
clone,
|
||||
getStorageValue,
|
||||
prop,
|
||||
setStorageValue
|
||||
} from './utils'
|
||||
|
||||
interface LinkAttrs {
|
||||
href: string
|
||||
onerror: (error: Event) => void
|
||||
onload: () => void
|
||||
parent: HTMLElement
|
||||
rel: string
|
||||
type: string
|
||||
}
|
||||
|
||||
const mocks = vi.hoisted(() => ({
|
||||
api: {
|
||||
clientId: null as string | null,
|
||||
initialClientId: null as string | null
|
||||
},
|
||||
$el: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('./api', () => ({
|
||||
api: mocks.api
|
||||
}))
|
||||
|
||||
vi.mock('./ui', () => ({
|
||||
$el: mocks.$el
|
||||
}))
|
||||
|
||||
function lastLinkAttrs() {
|
||||
return mocks.$el.mock.calls.at(-1)?.[1] as LinkAttrs
|
||||
}
|
||||
|
||||
describe('scripts utils', () => {
|
||||
afterEach(() => {
|
||||
localStorage.clear()
|
||||
sessionStorage.clear()
|
||||
mocks.api.clientId = null
|
||||
mocks.api.initialClientId = null
|
||||
mocks.$el.mockReset()
|
||||
vi.unstubAllGlobals()
|
||||
})
|
||||
|
||||
it('clones with structuredClone and falls back to JSON cloning', () => {
|
||||
const source = { nested: { value: 1 } }
|
||||
|
||||
expect(clone(source)).toEqual(source)
|
||||
|
||||
vi.stubGlobal(
|
||||
'structuredClone',
|
||||
vi.fn(() => {
|
||||
throw new Error('unsupported')
|
||||
})
|
||||
)
|
||||
|
||||
const cloned = clone(source)
|
||||
cloned.nested.value = 2
|
||||
|
||||
expect(cloned).toEqual({ nested: { value: 2 } })
|
||||
expect(source).toEqual({ nested: { value: 1 } })
|
||||
})
|
||||
|
||||
it('adds stylesheets from script and relative URLs', async () => {
|
||||
const scriptPromise = addStylesheet('/extensions/example.js')
|
||||
lastLinkAttrs().onload()
|
||||
|
||||
await expect(scriptPromise).resolves.toBeUndefined()
|
||||
expect(lastLinkAttrs()).toMatchObject({
|
||||
href: '/extensions/example.css',
|
||||
parent: document.head,
|
||||
rel: 'stylesheet',
|
||||
type: 'text/css'
|
||||
})
|
||||
|
||||
const cssPromise = addStylesheet('theme.css', 'https://example.com/base/')
|
||||
lastLinkAttrs().onload()
|
||||
|
||||
await expect(cssPromise).resolves.toBeUndefined()
|
||||
expect(lastLinkAttrs().href).toBe('https://example.com/base/theme.css')
|
||||
})
|
||||
|
||||
it('rejects when stylesheet loading fails', async () => {
|
||||
const promise = addStylesheet('missing.css', 'https://example.com/')
|
||||
const error = new Event('error')
|
||||
lastLinkAttrs().onerror(error)
|
||||
|
||||
await expect(promise).rejects.toBe(error)
|
||||
})
|
||||
|
||||
it('defines an observable property with the supplied default', () => {
|
||||
const target = {}
|
||||
const onChanged = vi.fn()
|
||||
|
||||
expect(prop(target, 'mode', 'initial', onChanged)).toBe('initial')
|
||||
Object.assign(target, { mode: 'next' })
|
||||
|
||||
expect((target as { mode: string }).mode).toBe('next')
|
||||
expect(onChanged).toHaveBeenCalledWith('next', undefined, target, 'mode')
|
||||
})
|
||||
|
||||
it('uses client-scoped storage before local fallback', () => {
|
||||
mocks.api.clientId = 'client-1'
|
||||
setStorageValue('setting', 'client-value')
|
||||
sessionStorage.removeItem('setting:client-1')
|
||||
|
||||
expect(getStorageValue('setting')).toBe('client-value')
|
||||
expect(localStorage.getItem('setting')).toBe('client-value')
|
||||
|
||||
sessionStorage.setItem('setting:client-1', 'session-value')
|
||||
|
||||
expect(getStorageValue('setting')).toBe('session-value')
|
||||
})
|
||||
|
||||
it('uses initial client id when the current client id is unavailable', () => {
|
||||
mocks.api.initialClientId = 'initial-1'
|
||||
setStorageValue('setting', 'initial-value')
|
||||
|
||||
expect(sessionStorage.getItem('setting:initial-1')).toBe('initial-value')
|
||||
expect(getStorageValue('setting')).toBe('initial-value')
|
||||
})
|
||||
})
|
||||
356
src/scripts/widgets.test.ts
Normal file
356
src/scripts/widgets.test.ts
Normal file
@@ -0,0 +1,356 @@
|
||||
import { fromAny } from '@total-typescript/shoehorn'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import type { LGraphNode } from '@/lib/litegraph/src/litegraph'
|
||||
import type {
|
||||
IBaseWidget,
|
||||
IComboWidget
|
||||
} from '@/lib/litegraph/src/types/widgets'
|
||||
import type { InputSpec } from '@/schemas/nodeDefSchema'
|
||||
|
||||
const mockSettingGet = vi.hoisted(() => vi.fn())
|
||||
const mockNextValueForLinkedTarget = vi.hoisted(() => vi.fn())
|
||||
const mockIsComboWidget = vi.hoisted(() => vi.fn())
|
||||
const mockTransformInputSpecV1ToV2 = vi.hoisted(() => vi.fn())
|
||||
|
||||
function v2WidgetConstructor(kind: string) {
|
||||
return () => (_node: LGraphNode, inputSpec: { name: string }) => ({
|
||||
name: `${kind}:${inputSpec.name}`,
|
||||
options: { minNodeSize: [20, 30] }
|
||||
})
|
||||
}
|
||||
|
||||
vi.mock('@/i18n', () => ({
|
||||
t: (key: string) => `translated:${key}`
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/settings/settingStore', () => ({
|
||||
useSettingStore: () => ({
|
||||
get: mockSettingGet
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('@/lib/litegraph/src/litegraph', () => ({
|
||||
isComboWidget: mockIsComboWidget
|
||||
}))
|
||||
|
||||
vi.mock('./valueControl', () => ({
|
||||
nextValueForLinkedTarget: mockNextValueForLinkedTarget
|
||||
}))
|
||||
|
||||
vi.mock('@/schemas/nodeDef/migration', () => ({
|
||||
transformInputSpecV1ToV2: mockTransformInputSpecV1ToV2
|
||||
}))
|
||||
|
||||
vi.mock('@/core/graph/widgets/dynamicWidgets', () => ({
|
||||
dynamicWidgets: {
|
||||
DYNAMIC: () => ({
|
||||
widget: { name: 'dynamic', options: {} },
|
||||
minWidth: 1,
|
||||
minHeight: 1
|
||||
})
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock(
|
||||
'@/renderer/extensions/vueNodes/widgets/composables/useBooleanWidget',
|
||||
() => ({ useBooleanWidget: v2WidgetConstructor('BOOLEAN') })
|
||||
)
|
||||
vi.mock(
|
||||
'@/renderer/extensions/vueNodes/widgets/composables/useBoundingBoxWidget',
|
||||
() => ({ useBoundingBoxWidget: v2WidgetConstructor('BOUNDING_BOX') })
|
||||
)
|
||||
vi.mock(
|
||||
'@/renderer/extensions/vueNodes/widgets/composables/useCurveWidget',
|
||||
() => ({ useCurveWidget: v2WidgetConstructor('CURVE') })
|
||||
)
|
||||
vi.mock(
|
||||
'@/renderer/extensions/vueNodes/widgets/composables/useChartWidget',
|
||||
() => ({ useChartWidget: v2WidgetConstructor('CHART') })
|
||||
)
|
||||
vi.mock(
|
||||
'@/renderer/extensions/vueNodes/widgets/composables/useColorWidget',
|
||||
() => ({ useColorWidget: v2WidgetConstructor('COLOR') })
|
||||
)
|
||||
vi.mock(
|
||||
'@/renderer/extensions/vueNodes/widgets/composables/useComboWidget',
|
||||
() => ({ useComboWidget: v2WidgetConstructor('COMBO') })
|
||||
)
|
||||
vi.mock(
|
||||
'@/renderer/extensions/vueNodes/widgets/composables/useFloatWidget',
|
||||
() => ({ useFloatWidget: v2WidgetConstructor('FLOAT') })
|
||||
)
|
||||
vi.mock(
|
||||
'@/renderer/extensions/vueNodes/widgets/composables/useGalleriaWidget',
|
||||
() => ({ useGalleriaWidget: v2WidgetConstructor('GALLERIA') })
|
||||
)
|
||||
vi.mock(
|
||||
'@/renderer/extensions/vueNodes/widgets/composables/useBoundingBoxesWidget',
|
||||
() => ({ useBoundingBoxesWidget: v2WidgetConstructor('BOUNDING_BOXES') })
|
||||
)
|
||||
vi.mock(
|
||||
'@/renderer/extensions/vueNodes/widgets/composables/useColorsWidget',
|
||||
() => ({ useColorsWidget: v2WidgetConstructor('COLORS') })
|
||||
)
|
||||
vi.mock(
|
||||
'@/renderer/extensions/vueNodes/widgets/composables/useImageCompareWidget',
|
||||
() => ({ useImageCompareWidget: v2WidgetConstructor('IMAGECOMPARE') })
|
||||
)
|
||||
vi.mock(
|
||||
'@/renderer/extensions/vueNodes/widgets/composables/useImageUploadWidget',
|
||||
() => ({
|
||||
useImageUploadWidget: () => (_node: LGraphNode, inputName: string) => ({
|
||||
widget: { name: `IMAGEUPLOAD:${inputName}`, options: {} },
|
||||
minWidth: 5,
|
||||
minHeight: 6
|
||||
})
|
||||
})
|
||||
)
|
||||
vi.mock(
|
||||
'@/renderer/extensions/vueNodes/widgets/composables/useIntWidget',
|
||||
() => ({ useIntWidget: v2WidgetConstructor('INT') })
|
||||
)
|
||||
vi.mock(
|
||||
'@/renderer/extensions/vueNodes/widgets/composables/useMarkdownWidget',
|
||||
() => ({ useMarkdownWidget: v2WidgetConstructor('MARKDOWN') })
|
||||
)
|
||||
vi.mock(
|
||||
'@/renderer/extensions/vueNodes/widgets/composables/usePainterWidget',
|
||||
() => ({ usePainterWidget: v2WidgetConstructor('PAINTER') })
|
||||
)
|
||||
vi.mock(
|
||||
'@/renderer/extensions/vueNodes/widgets/composables/useRangeWidget',
|
||||
() => ({ useRangeWidget: v2WidgetConstructor('RANGE') })
|
||||
)
|
||||
vi.mock(
|
||||
'@/renderer/extensions/vueNodes/widgets/composables/useStringWidget',
|
||||
() => ({ useStringWidget: v2WidgetConstructor('STRING') })
|
||||
)
|
||||
vi.mock(
|
||||
'@/renderer/extensions/vueNodes/widgets/composables/useTextareaWidget',
|
||||
() => ({ useTextareaWidget: v2WidgetConstructor('TEXTAREA') })
|
||||
)
|
||||
|
||||
vi.mock('./domWidget', () => ({}))
|
||||
vi.mock('./errorNodeWidgets', () => ({}))
|
||||
|
||||
import {
|
||||
ComfyWidgets,
|
||||
IS_CONTROL_WIDGET,
|
||||
addValueControlWidget,
|
||||
addValueControlWidgets,
|
||||
isValidWidgetType,
|
||||
updateControlWidgetLabel
|
||||
} from './widgets'
|
||||
|
||||
// `linkedWidgets`, `beforeQueued`, and `afterQueued` already exist on
|
||||
// IBaseWidget (via the litegraph augmentation), so no extra members needed.
|
||||
type MockWidget = IBaseWidget
|
||||
|
||||
function makeTargetWidget(overrides: Partial<MockWidget> = {}): MockWidget {
|
||||
return {
|
||||
name: 'seed',
|
||||
value: 1,
|
||||
callback: vi.fn(),
|
||||
options: {},
|
||||
linkedWidgets: [],
|
||||
computedDisabled: false,
|
||||
...overrides
|
||||
} as MockWidget
|
||||
}
|
||||
|
||||
function makeNode(inputs: LGraphNode['inputs'] = []) {
|
||||
const widgets: MockWidget[] = []
|
||||
const node = {
|
||||
id: 42,
|
||||
inputs,
|
||||
addWidget: vi.fn(
|
||||
(
|
||||
type: string,
|
||||
name: string,
|
||||
value: string,
|
||||
callback: () => void,
|
||||
options: Record<string, unknown>
|
||||
) => {
|
||||
const widget: MockWidget = fromAny({
|
||||
type,
|
||||
name,
|
||||
value,
|
||||
callback,
|
||||
options,
|
||||
linkedWidgets: [],
|
||||
computedDisabled: false
|
||||
})
|
||||
widgets.push(widget)
|
||||
return widget
|
||||
}
|
||||
)
|
||||
}
|
||||
return { node: node as unknown as LGraphNode, widgets }
|
||||
}
|
||||
|
||||
describe('widgets', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockSettingGet.mockReturnValue('after')
|
||||
mockNextValueForLinkedTarget.mockReturnValue('next')
|
||||
mockIsComboWidget.mockImplementation(
|
||||
(widget: MockWidget) => widget.type === 'combo'
|
||||
)
|
||||
mockTransformInputSpecV1ToV2.mockImplementation(
|
||||
(_inputData: InputSpec, options: { name: string }) => ({
|
||||
name: options.name
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('updates the control widget label from the configured run mode', () => {
|
||||
const widget = makeTargetWidget()
|
||||
|
||||
mockSettingGet.mockReturnValue('before')
|
||||
updateControlWidgetLabel(widget)
|
||||
expect(widget.label).toBe('translated:g.control_before_generate')
|
||||
|
||||
mockSettingGet.mockReturnValue('after')
|
||||
updateControlWidgetLabel(widget)
|
||||
expect(widget.label).toBe('translated:g.control_after_generate')
|
||||
})
|
||||
|
||||
it('adds control and filter widgets for combo targets', () => {
|
||||
const { node, widgets } = makeNode()
|
||||
const target = makeTargetWidget({ type: 'combo', computedDisabled: true })
|
||||
|
||||
const result = addValueControlWidgets(node, target, '', undefined, [
|
||||
'COMBO',
|
||||
{
|
||||
control_prefix: 'custom'
|
||||
}
|
||||
] as unknown as InputSpec)
|
||||
|
||||
expect(result).toHaveLength(2)
|
||||
expect(widgets[0].name).toBe('custom control_after_generate')
|
||||
expect(widgets[0].value).toBe('randomize')
|
||||
expect((widgets[0] as IComboWidget).options.values).toContain(
|
||||
'increment-wrap'
|
||||
)
|
||||
expect(widgets[0][IS_CONTROL_WIDGET]).toBe(true)
|
||||
expect(widgets[0].disabled).toBe(true)
|
||||
expect(widgets[1].name).toBe('custom control_filter_list')
|
||||
expect(widgets[1].disabled).toBe(true)
|
||||
})
|
||||
|
||||
it('uses explicit option names and can skip the combo filter widget', () => {
|
||||
const { node, widgets } = makeNode()
|
||||
const target = makeTargetWidget({ type: 'combo' })
|
||||
|
||||
addValueControlWidgets(
|
||||
node,
|
||||
target,
|
||||
'fixed',
|
||||
{
|
||||
addFilterList: false,
|
||||
controlAfterGenerateName: 'mode'
|
||||
},
|
||||
['COMBO', {}] as unknown as InputSpec
|
||||
)
|
||||
|
||||
expect(widgets).toHaveLength(1)
|
||||
expect(widgets[0].name).toBe('mode')
|
||||
})
|
||||
|
||||
it('applies linked target values after queueing in after mode', () => {
|
||||
const { node, widgets } = makeNode()
|
||||
const target = makeTargetWidget()
|
||||
|
||||
addValueControlWidgets(node, target)
|
||||
widgets[0].afterQueued?.({ isPartialExecution: true })
|
||||
|
||||
expect(mockNextValueForLinkedTarget).toHaveBeenCalledWith({
|
||||
target,
|
||||
linkedWidgets: target.linkedWidgets,
|
||||
nodeId: 42,
|
||||
isPartialExecution: true
|
||||
})
|
||||
expect(target.value).toBe('next')
|
||||
expect(target.callback).toHaveBeenCalledWith('next')
|
||||
})
|
||||
|
||||
it('waits until the second beforeQueued call in before mode', () => {
|
||||
mockSettingGet.mockReturnValue('before')
|
||||
const { node, widgets } = makeNode()
|
||||
const target = makeTargetWidget()
|
||||
|
||||
addValueControlWidgets(node, target)
|
||||
widgets[0].beforeQueued?.()
|
||||
expect(mockNextValueForLinkedTarget).not.toHaveBeenCalled()
|
||||
|
||||
widgets[0].beforeQueued?.({ isPartialExecution: false })
|
||||
expect(mockNextValueForLinkedTarget).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ isPartialExecution: false })
|
||||
)
|
||||
})
|
||||
|
||||
it('does not change the target when the target has a linked input or no next value', () => {
|
||||
const { node, widgets } = makeNode([
|
||||
{ widget: { name: 'seed' }, link: 1 }
|
||||
] as LGraphNode['inputs'])
|
||||
const target = makeTargetWidget()
|
||||
|
||||
addValueControlWidgets(node, target)
|
||||
widgets[0].afterQueued?.()
|
||||
expect(mockNextValueForLinkedTarget).not.toHaveBeenCalled()
|
||||
|
||||
const unlinked = makeNode()
|
||||
mockNextValueForLinkedTarget.mockReturnValue(undefined)
|
||||
addValueControlWidgets(unlinked.node, target)
|
||||
unlinked.widgets[0].afterQueued?.()
|
||||
expect(target.callback).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('uses the legacy single control widget name from input data before widgetName', () => {
|
||||
const { node, widgets } = makeNode()
|
||||
const target = makeTargetWidget()
|
||||
|
||||
const result = addValueControlWidget(
|
||||
node,
|
||||
target,
|
||||
'fixed',
|
||||
undefined,
|
||||
'fallback',
|
||||
[
|
||||
'INT',
|
||||
{
|
||||
control_after_generate: 'from_input_data'
|
||||
}
|
||||
] as unknown as InputSpec
|
||||
)
|
||||
|
||||
expect(result).toBe(widgets[0])
|
||||
expect(widgets[0].name).toBe('from_input_data')
|
||||
})
|
||||
|
||||
it('exposes transformed widget constructors and type validation', () => {
|
||||
const { node } = makeNode()
|
||||
|
||||
const intWidget = ComfyWidgets.INT(
|
||||
node,
|
||||
'value',
|
||||
['INT', {}] as unknown as InputSpec,
|
||||
{} as never
|
||||
)
|
||||
|
||||
expect(intWidget.widget.name).toBe('INT:value')
|
||||
expect(intWidget.minWidth).toBe(20)
|
||||
expect(intWidget.minHeight).toBe(30)
|
||||
expect(
|
||||
ComfyWidgets.IMAGEUPLOAD(node, 'image', ['IMAGE', {}], {} as never)
|
||||
).toMatchObject({
|
||||
widget: { name: 'IMAGEUPLOAD:image' },
|
||||
minWidth: 5,
|
||||
minHeight: 6
|
||||
})
|
||||
expect(isValidWidgetType('INT')).toBe(true)
|
||||
expect(isValidWidgetType('DYNAMIC')).toBe(true)
|
||||
expect(isValidWidgetType('missing')).toBe(false)
|
||||
})
|
||||
})
|
||||
@@ -38,6 +38,12 @@ describe('useAudioService', () => {
|
||||
name: 'test-audio-123.wav'
|
||||
}
|
||||
|
||||
async function freshService() {
|
||||
vi.resetModules()
|
||||
const audioServiceModule = await import('@/services/audioService')
|
||||
return audioServiceModule.useAudioService()
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
|
||||
@@ -90,12 +96,41 @@ describe('useAudioService', () => {
|
||||
)
|
||||
mockRegister.mockRejectedValueOnce(error)
|
||||
|
||||
await service.registerWavEncoder()
|
||||
const isolatedService = await freshService()
|
||||
await isolatedService.registerWavEncoder()
|
||||
await isolatedService.registerWavEncoder()
|
||||
|
||||
expect(mockConnect).toHaveBeenCalledTimes(0)
|
||||
expect(mockRegister).toHaveBeenCalledTimes(0)
|
||||
expect(mockConnect).toHaveBeenCalledTimes(1)
|
||||
expect(mockRegister).toHaveBeenCalledTimes(1)
|
||||
expect(console.error).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should log encoder registration errors', async () => {
|
||||
const error = new Error('Encoder failed')
|
||||
mockRegister.mockRejectedValueOnce(error)
|
||||
|
||||
const isolatedService = await freshService()
|
||||
await isolatedService.registerWavEncoder()
|
||||
|
||||
expect(console.error).toHaveBeenCalledWith(
|
||||
'Audio Service Error (encoder):',
|
||||
'Failed to register WAV encoder',
|
||||
error
|
||||
)
|
||||
})
|
||||
|
||||
it('should log non-Error encoder registration failures', async () => {
|
||||
mockRegister.mockRejectedValueOnce('Encoder failed')
|
||||
|
||||
const isolatedService = await freshService()
|
||||
await isolatedService.registerWavEncoder()
|
||||
|
||||
expect(console.error).toHaveBeenCalledWith(
|
||||
'Audio Service Error (encoder):',
|
||||
'Failed to register WAV encoder',
|
||||
'Encoder failed'
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('stopAllTracks', () => {
|
||||
|
||||
118
src/services/autoQueueService.test.ts
Normal file
118
src/services/autoQueueService.test.ts
Normal file
@@ -0,0 +1,118 @@
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { setupAutoQueueHandler } from '@/services/autoQueueService'
|
||||
|
||||
type ApiEvent = 'graphChanged'
|
||||
type ApiListener = () => void
|
||||
type Subscription = () => Promise<void> | void
|
||||
|
||||
const {
|
||||
listeners,
|
||||
queueCountStore,
|
||||
queueSettingsStore,
|
||||
appState,
|
||||
addEventListener,
|
||||
isInstantRunningMode
|
||||
} = vi.hoisted(() => ({
|
||||
listeners: new Map<ApiEvent, ApiListener>(),
|
||||
queueCountStore: {
|
||||
count: 0,
|
||||
subscription: undefined as Subscription | undefined,
|
||||
$subscribe: vi.fn((_callback: Subscription) => {
|
||||
queueCountStore.subscription = _callback
|
||||
})
|
||||
},
|
||||
queueSettingsStore: {
|
||||
mode: 'manual',
|
||||
batchCount: 1
|
||||
},
|
||||
appState: {
|
||||
lastExecutionError: null as unknown,
|
||||
queuePrompt: vi.fn()
|
||||
},
|
||||
addEventListener: vi.fn((event: ApiEvent, listener: ApiListener) => {
|
||||
listeners.set(event, listener)
|
||||
}),
|
||||
isInstantRunningMode: vi.fn((mode: string) => mode === 'instant')
|
||||
}))
|
||||
|
||||
vi.mock('@/scripts/api', () => ({
|
||||
api: { addEventListener }
|
||||
}))
|
||||
|
||||
vi.mock('@/scripts/app', () => ({
|
||||
app: appState
|
||||
}))
|
||||
|
||||
vi.mock('@/stores/queueStore', () => ({
|
||||
isInstantRunningMode,
|
||||
useQueuePendingTaskCountStore: () => queueCountStore,
|
||||
useQueueSettingsStore: () => queueSettingsStore
|
||||
}))
|
||||
|
||||
beforeEach(() => {
|
||||
listeners.clear()
|
||||
queueCountStore.count = 0
|
||||
queueCountStore.subscription = undefined
|
||||
queueCountStore.$subscribe.mockClear()
|
||||
queueSettingsStore.mode = 'manual'
|
||||
queueSettingsStore.batchCount = 1
|
||||
appState.lastExecutionError = null
|
||||
appState.queuePrompt.mockReset().mockResolvedValue(undefined)
|
||||
addEventListener.mockClear()
|
||||
isInstantRunningMode
|
||||
.mockClear()
|
||||
.mockImplementation((mode) => mode === 'instant')
|
||||
})
|
||||
|
||||
describe('setupAutoQueueHandler', () => {
|
||||
it('queues immediately on graph changes when change mode is idle', () => {
|
||||
queueSettingsStore.mode = 'change'
|
||||
queueSettingsStore.batchCount = 3
|
||||
|
||||
setupAutoQueueHandler()
|
||||
listeners.get('graphChanged')?.()
|
||||
|
||||
expect(appState.queuePrompt).toHaveBeenCalledWith(0, 3)
|
||||
})
|
||||
|
||||
it('queues after pending work drains in instant mode', async () => {
|
||||
queueSettingsStore.mode = 'instant'
|
||||
queueSettingsStore.batchCount = 2
|
||||
queueCountStore.count = 0
|
||||
|
||||
setupAutoQueueHandler()
|
||||
await queueCountStore.subscription?.()
|
||||
|
||||
expect(appState.queuePrompt).toHaveBeenCalledWith(0, 2)
|
||||
})
|
||||
|
||||
it('queues after a changed graph drains from an active queue', async () => {
|
||||
queueSettingsStore.mode = 'change'
|
||||
queueCountStore.count = 1
|
||||
|
||||
setupAutoQueueHandler()
|
||||
await queueCountStore.subscription?.()
|
||||
listeners.get('graphChanged')?.()
|
||||
expect(appState.queuePrompt).not.toHaveBeenCalled()
|
||||
|
||||
queueCountStore.count = 0
|
||||
await queueCountStore.subscription?.()
|
||||
|
||||
expect(appState.queuePrompt).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('does not requeue while work remains or the last run failed', async () => {
|
||||
queueSettingsStore.mode = 'instant'
|
||||
queueCountStore.count = 1
|
||||
|
||||
setupAutoQueueHandler()
|
||||
await queueCountStore.subscription?.()
|
||||
|
||||
appState.lastExecutionError = { message: 'failed' }
|
||||
queueCountStore.count = 0
|
||||
await queueCountStore.subscription?.()
|
||||
|
||||
expect(appState.queuePrompt).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
363
src/services/colorPaletteService.test.ts
Normal file
363
src/services/colorPaletteService.test.ts
Normal file
@@ -0,0 +1,363 @@
|
||||
import { fromAny } from '@total-typescript/shoehorn'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { DEFAULT_DARK_COLOR_PALETTE } from '@/constants/coreColorPalettes'
|
||||
import {
|
||||
LGraphCanvas,
|
||||
LiteGraph,
|
||||
RenderShape
|
||||
} from '@/lib/litegraph/src/litegraph'
|
||||
import type { CompletedPalette, Palette } from '@/schemas/colorPaletteSchema'
|
||||
|
||||
const mockCanvas = vi.hoisted(() => ({
|
||||
default_connection_color_byType: {} as Record<string, string>,
|
||||
node_title_color: '',
|
||||
default_link_color: '',
|
||||
background_image: '',
|
||||
clear_background_color: '',
|
||||
_pattern: 'pattern' as string | undefined,
|
||||
setDirty: vi.fn()
|
||||
}))
|
||||
|
||||
const mockColorPaletteStore = vi.hoisted(() => ({
|
||||
customPalettes: {} as Record<string, unknown>,
|
||||
palettesLookup: {} as Record<string, unknown>,
|
||||
completedActivePalette: undefined as unknown,
|
||||
activePaletteId: 'dark',
|
||||
addCustomPalette: vi.fn(),
|
||||
deleteCustomPalette: vi.fn(),
|
||||
completePalette: vi.fn()
|
||||
}))
|
||||
|
||||
const mockSettingStore = vi.hoisted(() => ({
|
||||
get: vi.fn(),
|
||||
set: vi.fn()
|
||||
}))
|
||||
|
||||
const mockNodeDefStore = vi.hoisted(() => ({
|
||||
nodeDataTypes: new Set(['IMAGE', 'MISSING'])
|
||||
}))
|
||||
|
||||
const mockDownloadBlob = vi.hoisted(() => vi.fn())
|
||||
const mockUploadFile = vi.hoisted(() => vi.fn())
|
||||
|
||||
vi.mock('@/scripts/app', () => ({
|
||||
app: { canvas: mockCanvas }
|
||||
}))
|
||||
|
||||
vi.mock('@/stores/workspace/colorPaletteStore', () => ({
|
||||
useColorPaletteStore: () => mockColorPaletteStore
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/settings/settingStore', () => ({
|
||||
useSettingStore: () => mockSettingStore
|
||||
}))
|
||||
|
||||
vi.mock('@/stores/nodeDefStore', () => ({
|
||||
useNodeDefStore: () => mockNodeDefStore
|
||||
}))
|
||||
|
||||
vi.mock('@/base/common/downloadUtil', () => ({
|
||||
downloadBlob: mockDownloadBlob
|
||||
}))
|
||||
|
||||
vi.mock('@/scripts/utils', () => ({
|
||||
uploadFile: mockUploadFile
|
||||
}))
|
||||
|
||||
vi.mock('@/composables/useErrorHandling', () => ({
|
||||
useErrorHandling: () => ({
|
||||
wrapWithErrorHandling: <T>(action: T) => action,
|
||||
wrapWithErrorHandlingAsync: <T>(action: T) => action
|
||||
})
|
||||
}))
|
||||
|
||||
import { useColorPaletteService } from './colorPaletteService'
|
||||
|
||||
const validCustomPalette = {
|
||||
id: 'custom',
|
||||
name: 'Custom',
|
||||
colors: {
|
||||
node_slot: {},
|
||||
litegraph_base: {},
|
||||
comfy_base: {}
|
||||
}
|
||||
} satisfies Palette
|
||||
|
||||
function makeCompletedPalette(id = 'custom'): CompletedPalette {
|
||||
const palette = structuredClone(
|
||||
DEFAULT_DARK_COLOR_PALETTE
|
||||
) as CompletedPalette
|
||||
palette.id = id
|
||||
palette.name = 'Custom'
|
||||
palette.colors.node_slot.IMAGE = '#123456'
|
||||
palette.colors.litegraph_base.NODE_TITLE_COLOR = '#abcdef'
|
||||
palette.colors.litegraph_base.LINK_COLOR = '#fedcba'
|
||||
palette.colors.litegraph_base.BACKGROUND_IMAGE = 'grid.png'
|
||||
palette.colors.litegraph_base.CLEAR_BACKGROUND_COLOR = '#010203'
|
||||
palette.colors.litegraph_base.NODE_DEFAULT_SHAPE = 'legacy'
|
||||
palette.colors.comfy_base['fg-color'] = '#111111'
|
||||
palette.colors.comfy_base['bg-color'] = '#222222'
|
||||
delete palette.colors.comfy_base['contrast-mix-color']
|
||||
return palette
|
||||
}
|
||||
|
||||
describe('useColorPaletteService', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockCanvas.default_connection_color_byType = {}
|
||||
mockCanvas.node_title_color = ''
|
||||
mockCanvas.default_link_color = ''
|
||||
mockCanvas.background_image = ''
|
||||
mockCanvas.clear_background_color = ''
|
||||
mockCanvas._pattern = 'pattern'
|
||||
LGraphCanvas.link_type_colors = {}
|
||||
mockSettingStore.get.mockReturnValue('')
|
||||
mockSettingStore.set.mockResolvedValue(undefined)
|
||||
mockColorPaletteStore.customPalettes = { custom: validCustomPalette }
|
||||
mockColorPaletteStore.palettesLookup = { custom: validCustomPalette }
|
||||
mockColorPaletteStore.completedActivePalette = makeCompletedPalette()
|
||||
mockColorPaletteStore.activePaletteId = 'dark'
|
||||
mockColorPaletteStore.completePalette.mockReturnValue(
|
||||
makeCompletedPalette()
|
||||
)
|
||||
document.documentElement.style.cssText = ''
|
||||
document.documentElement.style.setProperty(
|
||||
'--color-datatype-MISSING',
|
||||
'#ffffff'
|
||||
)
|
||||
})
|
||||
|
||||
it('adds valid custom palettes and persists the custom palette map', async () => {
|
||||
const service = useColorPaletteService()
|
||||
|
||||
await service.addCustomColorPalette(validCustomPalette)
|
||||
|
||||
expect(mockColorPaletteStore.addCustomPalette).toHaveBeenCalledWith(
|
||||
validCustomPalette
|
||||
)
|
||||
expect(mockSettingStore.set).toHaveBeenCalledWith(
|
||||
'Comfy.CustomColorPalettes',
|
||||
mockColorPaletteStore.customPalettes
|
||||
)
|
||||
})
|
||||
|
||||
it('rejects invalid custom palettes before mutating the store', async () => {
|
||||
const service = useColorPaletteService()
|
||||
|
||||
await expect(service.addCustomColorPalette({} as Palette)).rejects.toThrow(
|
||||
'Invalid color palette against zod schema'
|
||||
)
|
||||
expect(mockColorPaletteStore.addCustomPalette).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('deletes custom palettes and persists the custom palette map', async () => {
|
||||
const service = useColorPaletteService()
|
||||
|
||||
await service.deleteCustomColorPalette('custom')
|
||||
|
||||
expect(mockColorPaletteStore.deleteCustomPalette).toHaveBeenCalledWith(
|
||||
'custom'
|
||||
)
|
||||
expect(mockSettingStore.set).toHaveBeenCalledWith(
|
||||
'Comfy.CustomColorPalettes',
|
||||
mockColorPaletteStore.customPalettes
|
||||
)
|
||||
})
|
||||
|
||||
it('loads palette colors into litegraph, Vue CSS variables, and canvas state', async () => {
|
||||
const warn = vi.spyOn(console, 'warn').mockImplementation(() => {})
|
||||
const service = useColorPaletteService()
|
||||
|
||||
await service.loadColorPalette('custom')
|
||||
|
||||
expect(mockCanvas.default_connection_color_byType.IMAGE).toBe('#123456')
|
||||
expect(LGraphCanvas.link_type_colors.IMAGE).toBe('#123456')
|
||||
expect(
|
||||
document.documentElement.style.getPropertyValue('--color-datatype-IMAGE')
|
||||
).toBe('#123456')
|
||||
expect(
|
||||
document.documentElement.style.getPropertyValue(
|
||||
'--color-datatype-MISSING'
|
||||
)
|
||||
).toBe('')
|
||||
expect(mockCanvas.node_title_color).toBe('#abcdef')
|
||||
expect(mockCanvas.default_link_color).toBe('#fedcba')
|
||||
expect(mockCanvas.background_image).toBe('grid.png')
|
||||
expect(mockCanvas.clear_background_color).toBe('#010203')
|
||||
expect(mockCanvas._pattern).toBeUndefined()
|
||||
expect(LiteGraph.NODE_DEFAULT_SHAPE).toBe(RenderShape.ROUND)
|
||||
expect(warn).toHaveBeenCalledWith(
|
||||
`litegraph_base.NODE_DEFAULT_SHAPE only accepts [${[
|
||||
RenderShape.BOX,
|
||||
RenderShape.ROUND,
|
||||
RenderShape.CARD
|
||||
].join(', ')}] but got legacy`
|
||||
)
|
||||
expect(document.documentElement.style.getPropertyValue('--fg-color')).toBe(
|
||||
'#111111'
|
||||
)
|
||||
expect(
|
||||
document.documentElement.style.getPropertyValue('--contrast-mix-color')
|
||||
).toBe('var(--palette-contrast-mix-color)')
|
||||
expect(mockCanvas.setDirty).toHaveBeenCalledWith(true, true)
|
||||
expect(mockColorPaletteStore.activePaletteId).toBe('custom')
|
||||
})
|
||||
|
||||
it('skips absent palette sections while still activating the palette', async () => {
|
||||
const completedPalette = makeCompletedPalette()
|
||||
mockColorPaletteStore.completePalette.mockReturnValue(
|
||||
fromAny<CompletedPalette, unknown>({
|
||||
...completedPalette,
|
||||
colors: {
|
||||
node_slot: undefined,
|
||||
litegraph_base: completedPalette.colors.litegraph_base,
|
||||
comfy_base: undefined
|
||||
}
|
||||
})
|
||||
)
|
||||
const service = useColorPaletteService()
|
||||
|
||||
await service.loadColorPalette('custom')
|
||||
|
||||
expect(mockCanvas.node_title_color).toBe('#abcdef')
|
||||
expect(mockCanvas.setDirty).toHaveBeenCalledWith(true, true)
|
||||
expect(mockColorPaletteStore.activePaletteId).toBe('custom')
|
||||
})
|
||||
|
||||
it('removes Vue node theme overrides for built-in palettes', async () => {
|
||||
mockColorPaletteStore.palettesLookup = { dark: validCustomPalette }
|
||||
mockColorPaletteStore.completePalette.mockReturnValue(
|
||||
makeCompletedPalette('dark')
|
||||
)
|
||||
document.documentElement.style.setProperty(
|
||||
'--component-node-border',
|
||||
'#ffffff'
|
||||
)
|
||||
const service = useColorPaletteService()
|
||||
|
||||
await service.loadColorPalette('dark')
|
||||
|
||||
expect(
|
||||
document.documentElement.style.getPropertyValue('--component-node-border')
|
||||
).toBe('')
|
||||
})
|
||||
|
||||
it('removes Vue node theme variables when completed palette values are absent', async () => {
|
||||
const completedPalette = makeCompletedPalette()
|
||||
// NODE_BOX_OUTLINE_COLOR is required on the completed palette type; the
|
||||
// test needs it absent, so delete via Reflect to keep the type intact.
|
||||
Reflect.deleteProperty(
|
||||
completedPalette.colors.litegraph_base,
|
||||
'NODE_BOX_OUTLINE_COLOR'
|
||||
)
|
||||
mockColorPaletteStore.completePalette.mockReturnValue(completedPalette)
|
||||
document.documentElement.style.setProperty(
|
||||
'--component-node-border',
|
||||
'#ffffff'
|
||||
)
|
||||
const service = useColorPaletteService()
|
||||
|
||||
await service.loadColorPalette('custom')
|
||||
|
||||
expect(
|
||||
document.documentElement.style.getPropertyValue('--component-node-border')
|
||||
).toBe('')
|
||||
})
|
||||
|
||||
it('preserves numeric LiteGraph node shapes without warning', async () => {
|
||||
const warn = vi.spyOn(console, 'warn').mockImplementation(() => {})
|
||||
const completedPalette = makeCompletedPalette()
|
||||
completedPalette.colors.litegraph_base.NODE_DEFAULT_SHAPE = RenderShape.CARD
|
||||
mockColorPaletteStore.completePalette.mockReturnValue(completedPalette)
|
||||
const service = useColorPaletteService()
|
||||
|
||||
await service.loadColorPalette('custom')
|
||||
|
||||
expect(LiteGraph.NODE_DEFAULT_SHAPE).toBe(RenderShape.CARD)
|
||||
expect(warn).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('uses explicit optional comfy color values when present', async () => {
|
||||
const completedPalette = makeCompletedPalette()
|
||||
completedPalette.colors.comfy_base['contrast-mix-color'] = '#333333'
|
||||
mockColorPaletteStore.completePalette.mockReturnValue(completedPalette)
|
||||
const service = useColorPaletteService()
|
||||
|
||||
await service.loadColorPalette('custom')
|
||||
|
||||
expect(
|
||||
document.documentElement.style.getPropertyValue('--contrast-mix-color')
|
||||
).toBe('#333333')
|
||||
})
|
||||
|
||||
it('uses a white splash background for light themes', async () => {
|
||||
const completedPalette = makeCompletedPalette()
|
||||
completedPalette.light_theme = true
|
||||
mockColorPaletteStore.completePalette.mockReturnValue(completedPalette)
|
||||
const service = useColorPaletteService()
|
||||
|
||||
await service.loadColorPalette('custom')
|
||||
|
||||
expect(localStorage.getItem('comfy-splash-bg')).toBe('#FFFFFF')
|
||||
expect(localStorage.getItem('comfy-splash-fg')).toBe('#111111')
|
||||
})
|
||||
|
||||
it('uses transparent canvas background and bg image CSS when a background image setting exists', async () => {
|
||||
mockSettingStore.get.mockReturnValue('/custom/background.png')
|
||||
const service = useColorPaletteService()
|
||||
|
||||
await service.loadColorPalette('custom')
|
||||
|
||||
expect(mockCanvas.clear_background_color).toBe('transparent')
|
||||
expect(document.documentElement.style.getPropertyValue('--bg-img')).toBe(
|
||||
"url('/custom/background.png')"
|
||||
)
|
||||
})
|
||||
|
||||
it('throws when loading or exporting an unknown palette', async () => {
|
||||
mockColorPaletteStore.palettesLookup = {}
|
||||
const service = useColorPaletteService()
|
||||
|
||||
await expect(service.loadColorPalette('missing')).rejects.toThrow(
|
||||
'Color palette missing not found'
|
||||
)
|
||||
expect(() => service.exportColorPalette('missing')).toThrow(
|
||||
'Color palette missing not found'
|
||||
)
|
||||
})
|
||||
|
||||
it('exports palette JSON by id', async () => {
|
||||
const service = useColorPaletteService()
|
||||
|
||||
service.exportColorPalette('custom')
|
||||
|
||||
expect(mockDownloadBlob).toHaveBeenCalledOnce()
|
||||
const [filename, blob] = mockDownloadBlob.mock.calls[0] as [string, Blob]
|
||||
expect(filename).toBe('custom.json')
|
||||
await expect(blob.text()).resolves.toContain('"id": "custom"')
|
||||
})
|
||||
|
||||
it('imports palette JSON through the custom palette path', async () => {
|
||||
mockUploadFile.mockResolvedValue({
|
||||
text: () => Promise.resolve(JSON.stringify(validCustomPalette))
|
||||
})
|
||||
const service = useColorPaletteService()
|
||||
|
||||
const palette = await service.importColorPalette()
|
||||
|
||||
expect(mockUploadFile).toHaveBeenCalledWith('application/json')
|
||||
expect(palette).toEqual(validCustomPalette)
|
||||
expect(mockColorPaletteStore.addCustomPalette).toHaveBeenCalledWith(
|
||||
validCustomPalette
|
||||
)
|
||||
})
|
||||
|
||||
it('returns the completed active palette from the store', () => {
|
||||
const service = useColorPaletteService()
|
||||
|
||||
expect(service.getActiveColorPalette()).toBe(
|
||||
mockColorPaletteStore.completedActivePalette
|
||||
)
|
||||
})
|
||||
})
|
||||
@@ -18,7 +18,7 @@ import type {
|
||||
} from '@/stores/dialogStore'
|
||||
|
||||
import type { ComponentAttrs } from 'vue-component-type-helpers'
|
||||
import type { SubscriptionDialogReason } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type { SubscriptionDialogOptions } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type { WorkspaceRole } from '@/platform/workspace/api/workspaceApi'
|
||||
|
||||
// Lazy loaders for dialogs - components are loaded on first use
|
||||
@@ -442,9 +442,9 @@ export const useDialogService = () => {
|
||||
})
|
||||
}
|
||||
|
||||
async function showSubscriptionRequiredDialog(options?: {
|
||||
reason?: SubscriptionDialogReason
|
||||
}) {
|
||||
async function showSubscriptionRequiredDialog(
|
||||
options?: SubscriptionDialogOptions
|
||||
) {
|
||||
if (!isCloud || !window.__CONFIG__?.subscription_required) {
|
||||
return
|
||||
}
|
||||
|
||||
324
src/services/extensionService.test.ts
Normal file
324
src/services/extensionService.test.ts
Normal file
@@ -0,0 +1,324 @@
|
||||
import { fromAny } from '@total-typescript/shoehorn'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import type { LGraphNode } from '@/lib/litegraph/src/litegraph'
|
||||
import type { AuthUserInfo } from '@/types/authTypes'
|
||||
import type { ComfyExtension } from '@/types/comfy'
|
||||
import { useExtensionService } from './extensionService'
|
||||
|
||||
const mockLoadDisabledExtensionNames = vi.hoisted(() => vi.fn())
|
||||
const mockRegisterExtension = vi.hoisted(() => vi.fn())
|
||||
const mockCaptureCoreExtensions = vi.hoisted(() => vi.fn())
|
||||
const mockEnabledExtensions = vi.hoisted(() => ({
|
||||
value: [] as ComfyExtension[]
|
||||
}))
|
||||
vi.mock('@/stores/extensionStore', () => ({
|
||||
useExtensionStore: () => ({
|
||||
loadDisabledExtensionNames: mockLoadDisabledExtensionNames,
|
||||
registerExtension: mockRegisterExtension,
|
||||
captureCoreExtensions: mockCaptureCoreExtensions,
|
||||
get enabledExtensions() {
|
||||
return mockEnabledExtensions.value
|
||||
}
|
||||
})
|
||||
}))
|
||||
|
||||
const mockGetSetting = vi.hoisted(() => vi.fn())
|
||||
const mockAddSetting = vi.hoisted(() => vi.fn())
|
||||
vi.mock('@/platform/settings/settingStore', () => ({
|
||||
useSettingStore: () => ({
|
||||
get: mockGetSetting,
|
||||
addSetting: mockAddSetting
|
||||
})
|
||||
}))
|
||||
|
||||
const mockAddDefaultKeybinding = vi.hoisted(() => vi.fn())
|
||||
vi.mock('@/platform/keybindings/keybindingStore', () => ({
|
||||
useKeybindingStore: () => ({
|
||||
addDefaultKeybinding: mockAddDefaultKeybinding
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/keybindings/keybinding', () => ({
|
||||
KeybindingImpl: class KeybindingImpl {
|
||||
constructor(readonly source: unknown) {}
|
||||
}
|
||||
}))
|
||||
|
||||
const mockLoadExtensionCommands = vi.hoisted(() => vi.fn())
|
||||
vi.mock('@/stores/commandStore', () => ({
|
||||
useCommandStore: () => ({
|
||||
loadExtensionCommands: mockLoadExtensionCommands
|
||||
})
|
||||
}))
|
||||
|
||||
const mockLoadExtensionMenuCommands = vi.hoisted(() => vi.fn())
|
||||
vi.mock('@/stores/menuItemStore', () => ({
|
||||
useMenuItemStore: () => ({
|
||||
loadExtensionMenuCommands: mockLoadExtensionMenuCommands
|
||||
})
|
||||
}))
|
||||
|
||||
const mockRegisterBottomPanelTabs = vi.hoisted(() => vi.fn())
|
||||
vi.mock('@/stores/workspace/bottomPanelStore', () => ({
|
||||
useBottomPanelStore: () => ({
|
||||
registerExtensionBottomPanelTabs: mockRegisterBottomPanelTabs
|
||||
})
|
||||
}))
|
||||
|
||||
const mockRegisterCustomWidgets = vi.hoisted(() => vi.fn())
|
||||
vi.mock('@/stores/widgetStore', () => ({
|
||||
useWidgetStore: () => ({
|
||||
registerCustomWidgets: mockRegisterCustomWidgets
|
||||
})
|
||||
}))
|
||||
|
||||
const mockToastErrorHandler = vi.hoisted(() => vi.fn())
|
||||
vi.mock('@/composables/useErrorHandling', () => ({
|
||||
useErrorHandling: () => ({
|
||||
wrapWithErrorHandling:
|
||||
<Args extends unknown[], Return>(fn: (...args: Args) => Return) =>
|
||||
(...args: Args) =>
|
||||
fn(...args),
|
||||
wrapWithErrorHandlingAsync:
|
||||
<Args extends unknown[], Return>(
|
||||
fn: (...args: Args) => Return | Promise<Return>,
|
||||
handler: (error: unknown) => void
|
||||
) =>
|
||||
async (...args: Args) => {
|
||||
try {
|
||||
return await fn(...args)
|
||||
} catch (error) {
|
||||
handler(error)
|
||||
}
|
||||
},
|
||||
toastErrorHandler: mockToastErrorHandler
|
||||
})
|
||||
}))
|
||||
|
||||
const mockUserResolvedCallbacks = vi.hoisted(() => ({
|
||||
values: [] as Array<(user: AuthUserInfo) => void>
|
||||
}))
|
||||
const mockTokenRefreshedCallbacks = vi.hoisted(() => ({
|
||||
values: [] as Array<() => void>
|
||||
}))
|
||||
const mockUserLogoutCallbacks = vi.hoisted(() => ({
|
||||
values: [] as Array<() => void>
|
||||
}))
|
||||
vi.mock('@/composables/auth/useCurrentUser', () => ({
|
||||
useCurrentUser: () => ({
|
||||
onUserResolved: (callback: (user: AuthUserInfo) => void) => {
|
||||
mockUserResolvedCallbacks.values.push(callback)
|
||||
},
|
||||
onTokenRefreshed: (callback: () => void) => {
|
||||
mockTokenRefreshedCallbacks.values.push(callback)
|
||||
},
|
||||
onUserLogout: (callback: () => void) => {
|
||||
mockUserLogoutCallbacks.values.push(callback)
|
||||
}
|
||||
})
|
||||
}))
|
||||
|
||||
const mockSetCurrentExtension = vi.hoisted(() => vi.fn())
|
||||
vi.mock('@/lib/litegraph/src/contextMenuCompat', () => ({
|
||||
legacyMenuCompat: {
|
||||
setCurrentExtension: mockSetCurrentExtension
|
||||
}
|
||||
}))
|
||||
|
||||
const mockApp = vi.hoisted(() => ({ value: { name: 'app' } }))
|
||||
vi.mock('@/scripts/app', () => ({
|
||||
app: mockApp.value
|
||||
}))
|
||||
|
||||
vi.mock('@/scripts/api', () => ({
|
||||
api: {
|
||||
getExtensions: vi.fn(),
|
||||
fileURL: vi.fn((path: string) => path)
|
||||
}
|
||||
}))
|
||||
|
||||
describe('useExtensionService', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockEnabledExtensions.value = []
|
||||
mockUserResolvedCallbacks.values = []
|
||||
mockTokenRefreshedCallbacks.values = []
|
||||
mockUserLogoutCallbacks.values = []
|
||||
})
|
||||
|
||||
it('registers extension contributions across stores', async () => {
|
||||
const widgets = { CustomWidget: vi.fn() }
|
||||
const extension = fromAny<ComfyExtension, unknown>({
|
||||
name: 'registration-extension',
|
||||
keybindings: [{ commandId: 'command.one', combo: { key: 'K' } }],
|
||||
commands: [{ id: 'command.one', label: 'Command One' }],
|
||||
menuCommands: [{ path: ['File'], commands: ['command.one'] }],
|
||||
settings: [{ id: 'setting.one', name: 'Setting One' }],
|
||||
bottomPanelTabs: [{ id: 'tab.one', title: 'Tab One' }],
|
||||
getCustomWidgets: vi.fn().mockResolvedValue(widgets)
|
||||
})
|
||||
const service = useExtensionService()
|
||||
|
||||
service.registerExtension(extension)
|
||||
|
||||
expect(mockRegisterExtension).toHaveBeenCalledWith(extension)
|
||||
expect(mockAddDefaultKeybinding).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
source: { commandId: 'command.one', combo: { key: 'K' } }
|
||||
})
|
||||
)
|
||||
expect(mockLoadExtensionCommands).toHaveBeenCalledWith(extension)
|
||||
expect(mockLoadExtensionMenuCommands).toHaveBeenCalledWith(extension)
|
||||
expect(mockAddSetting.mock.calls[0][0]).toEqual({
|
||||
id: 'setting.one',
|
||||
name: 'Setting One'
|
||||
})
|
||||
expect(mockRegisterBottomPanelTabs).toHaveBeenCalledWith(extension)
|
||||
await vi.waitFor(() => {
|
||||
expect(mockRegisterCustomWidgets).toHaveBeenCalledWith(widgets)
|
||||
})
|
||||
})
|
||||
|
||||
it('invokes auth lifecycle hooks through registered callbacks', async () => {
|
||||
const onAuthUserResolved = vi.fn()
|
||||
const onAuthTokenRefreshed = vi.fn()
|
||||
const onAuthUserLogout = vi.fn()
|
||||
const extension = fromAny<ComfyExtension, unknown>({
|
||||
name: 'auth-extension',
|
||||
onAuthUserResolved,
|
||||
onAuthTokenRefreshed,
|
||||
onAuthUserLogout
|
||||
})
|
||||
const user = fromAny<AuthUserInfo, unknown>({ id: 'user-1' })
|
||||
const service = useExtensionService()
|
||||
|
||||
service.registerExtension(extension)
|
||||
mockUserResolvedCallbacks.values[0](user)
|
||||
mockTokenRefreshedCallbacks.values[0]()
|
||||
mockUserLogoutCallbacks.values[0]()
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(onAuthUserResolved).toHaveBeenCalledWith(user, mockApp.value)
|
||||
expect(onAuthTokenRefreshed).toHaveBeenCalled()
|
||||
expect(onAuthUserLogout).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
it('reports auth hook errors through the toast handler', async () => {
|
||||
const error = new Error('auth failed')
|
||||
const extension = fromAny<ComfyExtension, unknown>({
|
||||
name: 'failing-auth-extension',
|
||||
onAuthUserResolved: vi.fn(() => {
|
||||
throw error
|
||||
})
|
||||
})
|
||||
const consoleError = vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
const service = useExtensionService()
|
||||
|
||||
service.registerExtension(extension)
|
||||
mockUserResolvedCallbacks.values[0](
|
||||
fromAny<AuthUserInfo, unknown>({ id: 'user-1' })
|
||||
)
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(mockToastErrorHandler).toHaveBeenCalledWith(error)
|
||||
})
|
||||
expect(consoleError).toHaveBeenCalledWith(
|
||||
'[Extension Auth Hook Error]',
|
||||
expect.objectContaining({
|
||||
extension: 'failing-auth-extension',
|
||||
hook: 'onAuthUserResolved',
|
||||
error
|
||||
})
|
||||
)
|
||||
consoleError.mockRestore()
|
||||
})
|
||||
|
||||
it('invokes synchronous extension methods and keeps failures isolated', () => {
|
||||
const getSelectionToolboxCommands = vi.fn(() => ['command.one'])
|
||||
const failingGetSelectionToolboxCommands = vi.fn(() => {
|
||||
throw new Error('menu failed')
|
||||
})
|
||||
const consoleError = vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
mockEnabledExtensions.value = [
|
||||
fromAny<ComfyExtension, unknown>({
|
||||
name: 'working-extension',
|
||||
getSelectionToolboxCommands
|
||||
}),
|
||||
fromAny<ComfyExtension, unknown>({
|
||||
name: 'non-function-extension',
|
||||
getSelectionToolboxCommands: ['not callable']
|
||||
}),
|
||||
fromAny<ComfyExtension, unknown>({
|
||||
name: 'failing-extension',
|
||||
getSelectionToolboxCommands: failingGetSelectionToolboxCommands
|
||||
}),
|
||||
{ name: 'missing-method-extension' }
|
||||
]
|
||||
const service = useExtensionService()
|
||||
|
||||
const results = service.invokeExtensions(
|
||||
'getSelectionToolboxCommands',
|
||||
fromAny<LGraphNode, unknown>({ id: 1 })
|
||||
)
|
||||
|
||||
expect(results).toEqual([['command.one']])
|
||||
expect(getSelectionToolboxCommands).toHaveBeenCalledWith(
|
||||
fromAny<LGraphNode, unknown>({ id: 1 }),
|
||||
mockApp.value
|
||||
)
|
||||
expect(consoleError).toHaveBeenCalledWith(
|
||||
"Error calling extension 'failing-extension' method 'getSelectionToolboxCommands'",
|
||||
expect.objectContaining({ error: expect.any(Error) }),
|
||||
expect.objectContaining({
|
||||
extension: expect.objectContaining({ name: 'failing-extension' })
|
||||
}),
|
||||
expect.objectContaining({
|
||||
args: [fromAny<LGraphNode, unknown>({ id: 1 })]
|
||||
})
|
||||
)
|
||||
consoleError.mockRestore()
|
||||
})
|
||||
|
||||
it('tracks current extension around async setup callbacks', async () => {
|
||||
const setup = vi.fn().mockResolvedValue('setup-result')
|
||||
const failingSetup = vi.fn().mockRejectedValue(new Error('setup failed'))
|
||||
const consoleError = vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
mockEnabledExtensions.value = [
|
||||
fromAny<ComfyExtension, unknown>({
|
||||
name: 'setup-extension',
|
||||
setup
|
||||
}),
|
||||
fromAny<ComfyExtension, unknown>({
|
||||
name: 'non-function-extension',
|
||||
setup: true
|
||||
}),
|
||||
fromAny<ComfyExtension, unknown>({
|
||||
name: 'failing-setup-extension',
|
||||
setup: failingSetup
|
||||
}),
|
||||
{ name: 'missing-method-extension' }
|
||||
]
|
||||
const service = useExtensionService()
|
||||
|
||||
const results = await service.invokeExtensionsAsync('setup')
|
||||
|
||||
expect(results).toEqual(['setup-result', undefined, undefined, undefined])
|
||||
expect(mockSetCurrentExtension.mock.calls.map((call) => call[0])).toEqual([
|
||||
'setup-extension',
|
||||
'failing-setup-extension',
|
||||
null,
|
||||
null
|
||||
])
|
||||
expect(consoleError).toHaveBeenCalledWith(
|
||||
"Error calling extension 'failing-setup-extension' method 'setup'",
|
||||
expect.objectContaining({ error: expect.any(Error) }),
|
||||
expect.objectContaining({
|
||||
extension: expect.objectContaining({ name: 'failing-setup-extension' })
|
||||
}),
|
||||
expect.objectContaining({ args: [] })
|
||||
)
|
||||
consoleError.mockRestore()
|
||||
})
|
||||
})
|
||||
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user