Compare commits

...

9 Commits

Author SHA1 Message Date
huang47
7141bda563 test: consolidate coverage suites and clean up superseded test scaffolding
Fold sibling suites back into their canonical files (app.test.ts,
litegraphService.test.ts), replace module-level global.fetch/global.URL
mocks with vi.stubGlobal, apply helper refactors (avif/gltf box
builders, searchAndReplace createGraph, treeUtil createNode hoist), and
drop tests superseded by rewritten variants.
2026-07-02 16:36:20 -07:00
huang47
74d4366994 test: cover the root-graph branch in widget promotion options test 2026-07-02 16:35:51 -07:00
huang47
c85c0585ab test: address review findings on new coverage tests
Exercise the real drain path in autoQueueService test, align
executionErrorStore no-op test with its fixture, reset ComfyApp
clipspace state between litegraphService tests, restore app.canvas and
global stubs in finally blocks, strengthen subgraphStore compact-detail
assertion, fix imprecise imageUtil test title.
2026-07-02 15:53:21 -07:00
huang47
804630fa02 test: restructure coverage additions to be purely additive vs main
Restore all pre-existing test lines; new coverage lands as pure
additions. Suites whose mock architecture conflicts with the original
files move to sibling files (app.core.test.ts,
litegraphService.core.test.ts). Rewrites/refactors of existing tests
and helpers are deferred to a follow-up PR.
2026-07-02 14:20:13 -07:00
huang47
7417864353 fix(tests): resolve typecheck errors in coverage tests
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
2026-07-02 11:27:13 -07:00
huang47
1f26dc2b57 test: cover critical missed branches 2026-07-02 02:23:01 -07:00
Benjamin Lu
2ec2a0e091 feat: attribute payment intent through paywall, checkout, and top-up telemetry (#13363)
## Summary

Answers "why did this user want to pay?" by capturing the triggering
product moment at every paywall/upsell entry point and carrying it
through checkout and success telemetry.

## Changes

- **What**:
- Widen `SubscriptionDialogReason` from 4 coarse values to 13 grounded
intent sources (`subscribe_to_run`, `upgrade_to_add_credits`,
`invite_member_upsell`, `settings_billing_panel`, etc.)
- Fire `app:subscription_required_modal_opened` from
`useSubscriptionDialog` (the choke point all dialog variants pass
through) — the workspace/unified path previously emitted nothing; remove
the now-duplicate emitters in `useSubscription` and
`usePricingTableUrlLoader`
- Add `payment_intent_source` to
`BeginCheckoutMetadata`/`SubscriptionSuccessMetadata`, threaded via the
existing `reason` prop: dialog → `PricingTable` →
`performSubscriptionCheckout` → pending-attempt record, so legacy
`app:monthly_subscription_succeeded` carries intent alongside
`checkout_attempt_id`
- Fire `begin_checkout` on the workspace checkout path
(`useSubscriptionCheckout`, personal + team confirm) and the team
deep-link util — both previously emitted nothing; `tier` widened to
`TierKey | 'team'`
- Implement `trackBeginCheckout` in `PostHogTelemetryProvider` (was
GTM/host-only, so `begin_checkout` never reached PostHog)
- Thread `showSubscriptionDialog(options)` through the billing-context
adapters and pass a reason at ~14 call sites; add `source` to
`app:add_api_credit_button_clicked`

## Review Focus

- `modal_opened` now fires once per dialog actually shown, so a
free-tier user clicking Upgrade emits two events (free-tier dialog, then
pricing table) where the legacy path emitted one
- Intent is threaded explicitly via props/params rather than shared
state; `useSubscriptionCheckout` gained an optional second parameter
2026-07-02 03:11:21 +00:00
Mobeen Abdullah
9cf5c9a93f refactor(website): tidy customer story review nits (#13324)
## Summary

Small follow-up to #13289 applying two non-blocking review nits from
Alex's review.

## Changes

- **What**: drop the redundant `before:content-['']` on the
customer-story list bullet (Tailwind emits the empty `content`
automatically once another `before:` utility is present), and rename
`HEADER_OFFSET` to `HEADER_OFFSET_PX` in `ArticleNav` so the scroll
constants use consistent unit suffixes.

## Review Focus

Both changes are cosmetic with no behavior change. Confirmed in the
browser that the list bullet still renders identically (6px yellow dot)
without the explicit `content` utility.

## Notes from the #13289 review (left as-is here, open to discussion)

Three other comments from the review are intentionally not changed in
this PR; reasoning below so the decisions are on record:

- **`Category` type in `ArticleNav`**: kept the `ComponentProps<typeof
CategoryNav>` derivation. AGENTS.md says to derive component types via
`vue-component-type-helpers` rather than redefining them, so the current
form follows the styleguide. Happy to switch to a plain named type if
preferred.
- **Section ids in frontmatter vs the body `<Section>`**: kept the
`customers.content.test.ts` parity test. The short TOC labels live only
in frontmatter and Astro can't introspect the rendered MDX body to build
the nav, so the frontmatter `sections` list and the body anchor ids
can't be trivially deduplicated. A real fix would need a remark plugin
(larger, separate change). The test guards against silent drift in the
meantime.
- **`nextStory` throw**: left as a fail-loud, build-time invariant. The
slug always comes from the same `getStaticPaths` collection, so the
throw is effectively unreachable; it surfaces a future-refactor bug
loudly instead of linking to the wrong story.
2026-07-01 12:45:24 +00:00
jaeone94
9e5fb67b76 Show app mode run validation warning (#12557)
## Summary
Adds an app mode validation warning so users can see when a workflow has
errors before running and jump directly back to graph mode to review
them.

## Changes
- **What**: Adds a reusable app mode warning banner above the Run button
when the execution error store reports workflow errors, including
validation and missing asset states.
- **What**: Reuses the existing graph-error navigation flow so the
warning action switches out of app mode and opens the Errors panel in
graph mode.
- **What**: Updates the app mode Run button icon and accessible label in
the warning state while keeping the Run action non-blocking.
- **What**: Adds unit coverage for the warning render/accessibility
state and an E2E flow that triggers a validation failure, dismisses the
overlay, and opens graph errors from the app mode warning.
- **Breaking**: None.
- **Dependencies**: None.

## Review Focus
The warning intentionally mirrors graph mode behavior: it surfaces the
error state but does not prevent the user from clicking Run. This avoids
turning display-level validation signals into hard execution blockers.

The warning is driven by the existing `hasAnyError` aggregate, so
missing nodes, missing models, and missing media are included alongside
prompt/node/execution errors.

## Tests
- `pnpm format`
- `pnpm lint`
- `pnpm typecheck`
- `pnpm test:unit`
- `pnpm knip`
- `pnpm test:browser:local
browser_tests/tests/appModeValidationWarning.spec.ts`

## Screenshots

<img width="461" height="994" alt="스크린샷 2026-06-25 오후 7 00 55"
src="https://github.com/user-attachments/assets/f8fc20bf-d572-46b5-9fa4-312e7c4c8076"
/>
2026-07-01 15:24:45 +09:00
132 changed files with 17481 additions and 410 deletions

View File

@@ -15,7 +15,7 @@ const { categories } = defineProps<{
const activeSection = ref(categories[0]?.value ?? '')
const HEADER_OFFSET = -144
const HEADER_OFFSET_PX = -144
const BOTTOM_THRESHOLD_PX = 4
const SCROLL_SAFETY_MS = 1500
@@ -52,7 +52,7 @@ function scrollToSection(id: string) {
const el = document.getElementById(id)
if (el) {
scrollTo(el, {
offset: HEADER_OFFSET,
offset: HEADER_OFFSET_PX,
duration: 0.8,
immediate: prefersReducedMotion(),
onComplete: clearScrollLock

View File

@@ -1,5 +1,5 @@
<li
class="flex items-start gap-2 text-primary-comfy-canvas before:mt-1.5 before:size-1.5 before:shrink-0 before:rounded-full before:bg-primary-comfy-yellow before:content-['']"
class="flex items-start gap-2 text-primary-comfy-canvas before:mt-1.5 before:size-1.5 before:shrink-0 before:rounded-full before:bg-primary-comfy-yellow"
>
<slot />
</li>

View File

@@ -0,0 +1,45 @@
{
"last_node_id": 9,
"last_link_id": 9,
"nodes": [
{
"id": 9,
"type": "SaveImage",
"pos": {
"0": 64,
"1": 104
},
"size": {
"0": 210,
"1": 58
},
"flags": {},
"order": 0,
"mode": 0,
"inputs": [
{
"name": "images",
"type": "IMAGE",
"link": null
}
],
"outputs": [],
"properties": {},
"widgets_values": ["ComfyUI"]
}
],
"links": [],
"groups": [],
"config": {},
"extra": {
"ds": {
"scale": 1,
"offset": [0, 0]
},
"linearData": {
"inputs": [],
"outputs": ["9"]
}
},
"version": 0.4
}

View File

@@ -34,6 +34,10 @@ export class AppModeHelper {
public readonly outputPlaceholder: Locator
/** The linear-mode widget list container (visible in app mode). */
public readonly linearWidgets: Locator
/** The validation warning shown above the app mode run button. */
public readonly validationWarning: Locator
/** The action that opens graph mode errors from the validation warning. */
public readonly viewErrorsInGraphButton: Locator
/** The PrimeVue Popover for the image picker (renders with role="dialog"). */
public readonly imagePickerPopover: Locator
/** The Run button in the app mode footer. */
@@ -92,13 +96,19 @@ export class AppModeHelper {
this.outputPlaceholder = this.page.getByTestId(
TestIds.builder.outputPlaceholder
)
this.linearWidgets = this.page.getByTestId('linear-widgets')
this.linearWidgets = this.page.getByTestId(TestIds.linear.widgetContainer)
this.validationWarning = this.page.getByTestId(
TestIds.linear.validationWarning
)
this.viewErrorsInGraphButton = this.validationWarning.getByTestId(
TestIds.linear.viewErrorsInGraph
)
this.imagePickerPopover = this.page
.getByRole('dialog')
.filter({ has: this.page.getByRole('button', { name: 'All' }) })
.first()
this.runButton = this.page
.getByTestId('linear-run-button')
.getByTestId(TestIds.linear.runButton)
.getByRole('button', { name: /run/i })
this.welcome = this.page.getByTestId(TestIds.appMode.welcome)
this.emptyWorkflowText = this.page.getByTestId(

View File

@@ -172,6 +172,9 @@ export const TestIds = {
mobileNavigation: 'linear-mobile-navigation',
mobileWorkflows: 'linear-mobile-workflows',
outputInfo: 'linear-output-info',
runButton: 'linear-run-button',
validationWarning: 'linear-validation-warning',
viewErrorsInGraph: 'linear-view-errors',
widgetContainer: 'linear-widgets'
},
builder: {

View File

@@ -0,0 +1,106 @@
import {
comfyExpect as expect,
comfyPageFixture as test
} from '@e2e/fixtures/ComfyPage'
import type { NodeError, PromptResponse } from '@/schemas/apiSchema'
import { ExecutionHelper } from '@e2e/fixtures/helpers/ExecutionHelper'
import { enableErrorsOverlay } from '@e2e/fixtures/helpers/ErrorsTabHelper'
import { TestIds } from '@e2e/fixtures/selectors'
const SAVE_IMAGE_NODE_ID = '9'
function buildSaveImageRequiredInputError(): NodeError {
return {
class_type: 'SaveImage',
dependent_outputs: [],
errors: [
{
type: 'required_input_missing',
message: 'Required input is missing: images',
details: '',
extra_info: { input_name: 'images' }
}
]
}
}
test.describe(
'App mode validation warning',
{ tag: ['@ui', '@workflow'] },
() => {
test.beforeEach(async ({ comfyPage }) => {
await enableErrorsOverlay(comfyPage)
await comfyPage.workflow.loadWorkflow('linear-validation-warning')
await comfyPage.appMode.toggleAppMode()
await expect(comfyPage.appMode.linearWidgets).toBeVisible()
})
test('opens graph errors from the app mode validation warning', async ({
comfyPage
}) => {
await expect(comfyPage.appMode.validationWarning).toBeHidden()
const exec = new ExecutionHelper(comfyPage)
await exec.mockValidationFailure({
[SAVE_IMAGE_NODE_ID]: buildSaveImageRequiredInputError()
})
await comfyPage.appMode.runButton.click()
const appModeOverlay = comfyPage.appMode.centerPanel.getByTestId(
TestIds.dialogs.errorOverlay
)
await expect(appModeOverlay).toBeHidden()
await expect(comfyPage.appMode.validationWarning).toBeVisible()
await expect(comfyPage.appMode.validationWarning).toContainText(
/Required input missing/i
)
await expect(comfyPage.appMode.viewErrorsInGraphButton).toBeVisible()
await comfyPage.appMode.viewErrorsInGraphButton.click()
await expect(comfyPage.appMode.linearWidgets).toBeHidden()
await expect(
comfyPage.page.getByTestId(TestIds.propertiesPanel.root)
).toBeVisible()
await expect(
comfyPage.page.getByTestId(TestIds.propertiesPanel.errorsTab)
).toBeVisible()
})
test('keeps the app mode run button enabled when the warning is visible', async ({
comfyPage
}) => {
const exec = new ExecutionHelper(comfyPage)
await exec.mockValidationFailure({
[SAVE_IMAGE_NODE_ID]: buildSaveImageRequiredInputError()
})
await comfyPage.appMode.runButton.click()
await expect(comfyPage.appMode.validationWarning).toBeVisible()
await expect(comfyPage.appMode.runButton).toBeEnabled()
let promptQueued = false
const mockResponse: PromptResponse = {
prompt_id: 'test-id',
node_errors: {},
error: ''
}
await comfyPage.page.route(
'**/api/prompt',
async (route) => {
promptQueued = true
await route.fulfill({
status: 200,
body: JSON.stringify(mockResponse)
})
},
{ times: 1 }
)
await comfyPage.appMode.runButton.click()
await expect.poll(() => promptQueued).toBe(true)
})
}
)

View File

@@ -1,5 +1,6 @@
import { expect } from '@playwright/test'
import { toLinkId } from '@/types/linkId'
import { toNodeId } from '@/types/nodeId'
import { comfyPageFixture as test } from '@e2e/fixtures/ComfyPage'
@@ -15,9 +16,10 @@ test.describe('Graph', { tag: ['@smoke', '@canvas'] }, () => {
await comfyPage.workflow.loadWorkflow('inputs/input_order_swap')
await expect
.poll(() =>
comfyPage.page.evaluate(() => {
return window.app!.graph!.links.get(1)?.target_slot
})
comfyPage.page.evaluate(
(linkId) => window.app!.graph!.links.get(linkId)?.target_slot,
toLinkId(1)
)
)
.toBe(1)
})

View File

@@ -3,6 +3,7 @@ import {
comfyPageFixture as test,
comfyExpect as expect
} from '@e2e/fixtures/ComfyPage'
import { TestIds } from '@e2e/fixtures/selectors'
test.describe('Linear Mode', { tag: '@ui' }, () => {
test('Displays linear controls when app mode active', async ({
@@ -16,7 +17,9 @@ test.describe('Linear Mode', { tag: '@ui' }, () => {
test('Run button visible in linear mode', async ({ comfyPage }) => {
await comfyPage.appMode.enterAppModeWithInputs([])
await expect(comfyPage.page.getByTestId('linear-run-button')).toBeVisible()
await expect(
comfyPage.page.getByTestId(TestIds.linear.runButton)
).toBeVisible()
})
test('Workflow info section visible', async ({ comfyPage }) => {

View File

@@ -0,0 +1,75 @@
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
describe('runWhenGlobalIdle', () => {
beforeEach(() => {
vi.resetModules()
})
afterEach(() => {
vi.useRealTimers()
vi.unstubAllGlobals()
})
it('falls back to a timeout when idle callbacks are unavailable', async () => {
vi.useFakeTimers()
vi.stubGlobal('requestIdleCallback', undefined)
vi.stubGlobal('cancelIdleCallback', undefined)
const { runWhenGlobalIdle } = await import('./async')
const runner = vi.fn()
const disposable = runWhenGlobalIdle(runner)
await vi.runAllTimersAsync()
expect(runner).toHaveBeenCalledOnce()
const deadline = runner.mock.calls[0][0]
expect(deadline.didTimeout).toBe(true)
expect(deadline.timeRemaining()).toBeGreaterThanOrEqual(0)
disposable.dispose()
disposable.dispose()
})
it('cancels fallback idle work before it runs', async () => {
vi.useFakeTimers()
vi.stubGlobal('requestIdleCallback', undefined)
vi.stubGlobal('cancelIdleCallback', undefined)
const { runWhenGlobalIdle } = await import('./async')
const runner = vi.fn()
runWhenGlobalIdle(runner).dispose()
await vi.runAllTimersAsync()
expect(runner).not.toHaveBeenCalled()
})
it('uses native idle callbacks when available', async () => {
const requestIdleCallback = vi.fn(() => 42)
const cancelIdleCallback = vi.fn()
vi.stubGlobal('requestIdleCallback', requestIdleCallback)
vi.stubGlobal('cancelIdleCallback', cancelIdleCallback)
const { runWhenGlobalIdle } = await import('./async')
const runner = vi.fn()
const disposable = runWhenGlobalIdle(runner, 250)
expect(requestIdleCallback).toHaveBeenCalledWith(runner, { timeout: 250 })
disposable.dispose()
disposable.dispose()
expect(cancelIdleCallback).toHaveBeenCalledOnce()
expect(cancelIdleCallback).toHaveBeenCalledWith(42)
})
it('omits native idle timeout options when no timeout is supplied', async () => {
const requestIdleCallback = vi.fn(() => 7)
vi.stubGlobal('requestIdleCallback', requestIdleCallback)
vi.stubGlobal('cancelIdleCallback', vi.fn())
const { runWhenGlobalIdle } = await import('./async')
const runner = vi.fn()
runWhenGlobalIdle(runner)
expect(requestIdleCallback).toHaveBeenCalledWith(runner, undefined)
})
})

View File

@@ -122,6 +122,22 @@ describe('downloadUtil', () => {
expect(createObjectURLSpy).not.toHaveBeenCalled()
})
it('throws for an empty URL', () => {
expect(() => downloadFile('')).toThrow(
'Invalid URL provided for download'
)
expect(fetchMock).not.toHaveBeenCalled()
expect(createObjectURLSpy).not.toHaveBeenCalled()
})
it('throws for a whitespace URL', () => {
expect(() => downloadFile(' ')).toThrow(
'Invalid URL provided for download'
)
expect(fetchMock).not.toHaveBeenCalled()
expect(createObjectURLSpy).not.toHaveBeenCalled()
})
it('should prefer custom filename over extracted filename', () => {
const testUrl =
'https://example.com/api/file?filename=extracted-image.jpg'

View File

@@ -4,6 +4,7 @@ import {
CREDITS_PER_USD,
COMFY_CREDIT_RATE_CENTS,
centsToCredits,
clampUsd,
creditsToCents,
creditsToUsd,
formatCredits,
@@ -43,4 +44,21 @@ describe('comfyCredits helpers', () => {
expect(formatCreditsFromUsd({ usd: 1, locale })).toBe('211.00')
expect(formatUsd({ value: 4.2, locale })).toBe('4.20')
})
test('formats with compatible fraction digit bounds', () => {
expect(
formatCredits({
value: 12.345,
locale: 'en-US',
numberOptions: { minimumFractionDigits: 4, maximumFractionDigits: 2 }
})
).toBe('12.35')
})
test('clamps USD purchase values into the supported range', () => {
expect(clampUsd(Number.NaN)).toBe(0)
expect(clampUsd(-5)).toBe(1)
expect(clampUsd(42)).toBe(42)
expect(clampUsd(5000)).toBe(1000)
})
})

View File

@@ -37,7 +37,7 @@
size="unset"
class="min-h-8 rounded-lg px-3 py-2 text-xs font-normal"
data-testid="error-overlay-see-errors"
@click="seeErrors"
@click="viewErrorsInGraph"
>
{{
appMode
@@ -67,31 +67,18 @@ import { useI18n } from 'vue-i18n'
import Button from '@/components/ui/button/Button.vue'
import { useExecutionErrorStore } from '@/stores/executionErrorStore'
import { useRightSidePanelStore } from '@/stores/workspace/rightSidePanelStore'
import { useCanvasStore } from '@/renderer/core/canvas/canvasStore'
import { useErrorOverlayState } from '@/components/error/useErrorOverlayState'
import { useViewErrorsInGraph } from '@/composables/useViewErrorsInGraph'
const { appMode = false } = defineProps<{ appMode?: boolean }>()
const { t } = useI18n()
const executionErrorStore = useExecutionErrorStore()
const rightSidePanelStore = useRightSidePanelStore()
const canvasStore = useCanvasStore()
const { viewErrorsInGraph } = useViewErrorsInGraph()
const { isVisible, overlayMessage, overlayTitle } = useErrorOverlayState()
function dismiss() {
executionErrorStore.dismissErrorOverlay()
}
function seeErrors() {
canvasStore.linearMode = false
if (canvasStore.canvas) {
canvasStore.canvas.deselectAll()
canvasStore.updateSelectedItems()
}
rightSidePanelStore.openPanel('errors')
executionErrorStore.dismissErrorOverlay()
}
</script>

View File

@@ -224,7 +224,7 @@ const handleOpenUserSettings = () => {
}
const handleOpenPlansAndPricing = () => {
subscriptionDialog.showPricingTable()
subscriptionDialog.showPricingTable({ reason: 'avatar_menu_plans' })
emit('close')
}
@@ -239,8 +239,7 @@ const handleOpenPlanAndCreditsSettings = () => {
}
const handleTopUp = () => {
// Track purchase credits entry from avatar popover
useTelemetry()?.trackAddApiCreditButtonClicked()
useTelemetry()?.trackAddApiCreditButtonClicked({ source: 'avatar_menu' })
dialogService.showTopUpCreditsDialog()
emit('close')
}
@@ -254,7 +253,7 @@ const handleOpenPartnerNodesInfo = () => {
}
const handleUpgradeToAddCredits = () => {
subscriptionDialog.showPricingTable()
subscriptionDialog.showPricingTable({ reason: 'upgrade_to_add_credits' })
emit('close')
}

View File

@@ -21,6 +21,6 @@ const { isFreeTier } = useBillingContext()
const subscriptionDialog = useSubscriptionDialog()
function handleClick() {
subscriptionDialog.showPricingTable()
subscriptionDialog.showPricingTable({ reason: 'subscribe_now_button' })
}
</script>

View File

@@ -1,5 +1,6 @@
import type { ComputedRef, Ref } from 'vue'
import type { SubscriptionDialogOptions } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
import type { TierKey } from '@/platform/cloud/subscription/constants/tierPricing'
import type {
BillingStatus,
@@ -75,9 +76,10 @@ export interface BillingActions {
*/
requireActiveSubscription: () => Promise<void>
/**
* Shows the subscription dialog.
* Shows the subscription dialog. Pass a reason so the paywall open and any
* downstream checkout stay attributed to the triggering product moment.
*/
showSubscriptionDialog: () => void
showSubscriptionDialog: (options?: SubscriptionDialogOptions) => void
}
export interface BillingState {

View File

@@ -7,6 +7,7 @@ import {
getTierFeatures
} from '@/platform/cloud/subscription/constants/tierPricing'
import type { TierKey } from '@/platform/cloud/subscription/constants/tierPricing'
import type { SubscriptionDialogOptions } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
import type {
PreviewSubscribeOptions,
SubscribeOptions
@@ -281,8 +282,8 @@ function useBillingContextInternal(): BillingContext {
return activeContext.value.requireActiveSubscription()
}
function showSubscriptionDialog() {
return activeContext.value.showSubscriptionDialog()
function showSubscriptionDialog(options?: SubscriptionDialogOptions) {
return activeContext.value.showSubscriptionDialog(options)
}
return {

View File

@@ -2,6 +2,7 @@ import { computed, ref } from 'vue'
import { useAuthActions } from '@/composables/auth/useAuthActions'
import { useSubscription } from '@/platform/cloud/subscription/composables/useSubscription'
import type { SubscriptionDialogOptions } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
import type {
BillingStatus,
BillingSubscriptionStatus,
@@ -189,12 +190,12 @@ export function useLegacyBilling(): BillingState & BillingActions {
async function requireActiveSubscription(): Promise<void> {
await fetchStatus()
if (!isActiveSubscription.value) {
legacyShowSubscriptionDialog()
legacyShowSubscriptionDialog({ reason: 'subscription_required' })
}
}
function showSubscriptionDialog(): void {
legacyShowSubscriptionDialog()
function showSubscriptionDialog(options?: SubscriptionDialogOptions): void {
legacyShowSubscriptionDialog(options)
}
return {

View File

@@ -503,7 +503,7 @@ export function useCoreCommands(): ComfyCommand[] {
}) => {
trackRunButton(metadata)
if (!isActiveSubscription.value) {
showSubscriptionDialog()
showSubscriptionDialog({ reason: 'subscribe_to_run' })
return
}
@@ -526,7 +526,7 @@ export function useCoreCommands(): ComfyCommand[] {
}) => {
trackRunButton(metadata)
if (!isActiveSubscription.value) {
showSubscriptionDialog()
showSubscriptionDialog({ reason: 'subscribe_to_run' })
return
}
@@ -548,7 +548,7 @@ export function useCoreCommands(): ComfyCommand[] {
}) => {
trackRunButton(metadata)
if (!isActiveSubscription.value) {
showSubscriptionDialog()
showSubscriptionDialog({ reason: 'subscribe_to_run' })
return
}

View File

@@ -0,0 +1,105 @@
import { createPinia, setActivePinia } from 'pinia'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { useCanvasStore } from '@/renderer/core/canvas/canvasStore'
import { useExecutionErrorStore } from '@/stores/executionErrorStore'
import { useRightSidePanelStore } from '@/stores/workspace/rightSidePanelStore'
import { LGraph, LGraphCanvas, LGraphNode } from '@/lib/litegraph/src/litegraph'
import { createMockCanvasRenderingContext2D } from '@/utils/__tests__/litegraphTestUtils'
import { useWorkflowStore } from '@/platform/workflow/management/stores/workflowStore'
import { useViewErrorsInGraph } from './useViewErrorsInGraph'
const apiMock = vi.hoisted(() => ({
getSettings: vi.fn(),
storeSetting: vi.fn(),
storeSettings: vi.fn()
}))
vi.mock('@/scripts/api', () => ({
api: apiMock
}))
const appMock = vi.hoisted(() => ({
ui: {
settings: {
dispatchChange: vi.fn()
}
},
rootGraph: {
events: new EventTarget(),
nodes: []
}
}))
vi.mock('@/scripts/app', () => ({
app: appMock
}))
function createSelectedCanvas() {
const graph = new LGraph()
const canvasElement = document.createElement('canvas')
canvasElement.width = 800
canvasElement.height = 600
canvasElement.getContext = vi
.fn()
.mockReturnValue(createMockCanvasRenderingContext2D())
const canvas = new LGraphCanvas(canvasElement, graph, {
skip_events: true,
skip_render: true
})
const node = new LGraphNode('Selected Node')
graph.add(node)
canvas.selectedItems.add(node)
node.selected = true
return { canvas, node }
}
describe('useViewErrorsInGraph', () => {
beforeEach(() => {
vi.clearAllMocks()
setActivePinia(createPinia())
apiMock.getSettings.mockResolvedValue({})
apiMock.storeSetting.mockResolvedValue(undefined)
apiMock.storeSettings.mockResolvedValue(undefined)
})
it('opens graph errors and clears app-mode error UI state', () => {
const canvasStore = useCanvasStore()
const executionErrorStore = useExecutionErrorStore()
const rightSidePanelStore = useRightSidePanelStore()
const workflowStore = useWorkflowStore()
const { canvas, node } = createSelectedCanvas()
workflowStore.activeWorkflow = {
activeMode: 'app'
} as typeof workflowStore.activeWorkflow
canvasStore.canvas = canvas
canvasStore.selectedItems = [node]
executionErrorStore.showErrorOverlay()
useViewErrorsInGraph().viewErrorsInGraph()
expect(node.selected).toBe(false)
expect(canvasStore.linearMode).toBe(false)
expect(canvasStore.selectedItems).toEqual([])
expect(rightSidePanelStore.activeTab).toBe('errors')
expect(rightSidePanelStore.isOpen).toBe(true)
expect(executionErrorStore.isErrorOverlayOpen).toBe(false)
})
it('opens graph errors when the canvas is not initialized', () => {
const canvasStore = useCanvasStore()
const executionErrorStore = useExecutionErrorStore()
const rightSidePanelStore = useRightSidePanelStore()
canvasStore.canvas = null
executionErrorStore.showErrorOverlay()
expect(() => useViewErrorsInGraph().viewErrorsInGraph()).not.toThrow()
expect(rightSidePanelStore.activeTab).toBe('errors')
expect(rightSidePanelStore.isOpen).toBe(true)
expect(executionErrorStore.isErrorOverlayOpen).toBe(false)
})
})

View File

@@ -0,0 +1,22 @@
import { useCanvasStore } from '@/renderer/core/canvas/canvasStore'
import { useExecutionErrorStore } from '@/stores/executionErrorStore'
import { useRightSidePanelStore } from '@/stores/workspace/rightSidePanelStore'
export function useViewErrorsInGraph() {
const canvasStore = useCanvasStore()
const executionErrorStore = useExecutionErrorStore()
const rightSidePanelStore = useRightSidePanelStore()
function viewErrorsInGraph() {
canvasStore.linearMode = false
if (canvasStore.canvas) {
canvasStore.canvas.deselectAll()
canvasStore.updateSelectedItems()
}
rightSidePanelStore.openPanel('errors')
executionErrorStore.dismissErrorOverlay()
}
return { viewErrorsInGraph }
}

View File

@@ -25,6 +25,6 @@ function handleClose() {
}
function handleSubscribe() {
showSubscriptionDialog()
showSubscriptionDialog({ reason: 'upload_model_upgrade' })
}
</script>

View File

@@ -140,7 +140,10 @@ describe('CloudSubscriptionRedirectView', () => {
expect(mockPerformSubscriptionCheckout).toHaveBeenCalledWith(
'creator',
'monthly',
false
{
openInNewTab: false,
paymentIntentSource: 'deep_link'
}
)
// Shows loading affordances
@@ -169,7 +172,10 @@ describe('CloudSubscriptionRedirectView', () => {
expect(mockPerformSubscriptionCheckout).toHaveBeenCalledWith(
'creator',
'monthly',
false
{
openInNewTab: false,
paymentIntentSource: 'deep_link'
}
)
})
@@ -180,7 +186,8 @@ describe('CloudSubscriptionRedirectView', () => {
expect(screen.getByText('Subscribe to Team Plan')).toBeInTheDocument()
expect(mockPerformTeamSubscriptionCheckout).toHaveBeenCalledWith(
'team_700',
'yearly'
'yearly',
{ paymentIntentSource: 'deep_link' }
)
// Team never goes through the personal checkout path
expect(mockPerformSubscriptionCheckout).not.toHaveBeenCalled()

View File

@@ -94,7 +94,9 @@ const runRedirect = wrapWithErrorHandlingAsync(async () => {
return
}
isTeamCheckout.value = true
await performTeamSubscriptionCheckout(stopId, billingCycle)
await performTeamSubscriptionCheckout(stopId, billingCycle, {
paymentIntentSource: 'deep_link'
})
return
}
@@ -112,7 +114,10 @@ const runRedirect = wrapWithErrorHandlingAsync(async () => {
if (isActiveSubscription.value) {
await accessBillingPortal(undefined, false)
} else {
await performSubscriptionCheckout(tierKeyParam, billingCycle, false)
await performSubscriptionCheckout(tierKeyParam, billingCycle, {
openInNewTab: false,
paymentIntentSource: 'deep_link'
})
}
}, reportError)

View File

@@ -351,12 +351,12 @@ const handleRefresh = wrapWithErrorHandlingAsync(async () => {
})
function handleAddCredits() {
telemetry?.trackAddApiCreditButtonClicked()
telemetry?.trackAddApiCreditButtonClicked({ source: 'credits_panel' })
void dialogService.showTopUpCreditsDialog()
}
function handleUpgradeToAddCredits() {
showPricingTable()
showPricingTable({ reason: 'upgrade_to_add_credits' })
}
async function handleWindowFocus() {

View File

@@ -5,6 +5,8 @@ import { render, screen } from '@testing-library/vue'
import enMessages from '@/locales/en/main.json' with { type: 'json' }
import type { PaymentIntentSource } from '@/platform/telemetry/types'
import FreeTierDialogContent from './FreeTierDialogContent.vue'
const mockRenewalDate = vi.hoisted(() => ({ value: null as string | null }))
@@ -15,7 +17,7 @@ vi.mock('@/composables/billing/useBillingContext', () => ({
}))
}))
function renderComponent() {
function renderComponent(props?: { reason?: PaymentIntentSource }) {
const i18n = createI18n({
legacy: false,
locale: 'en',
@@ -23,6 +25,7 @@ function renderComponent() {
})
return render(FreeTierDialogContent, {
props,
global: {
plugins: [i18n]
}
@@ -43,4 +46,18 @@ describe('FreeTierDialogContent', () => {
renderComponent()
expect(screen.queryByText(/credits refresh on/)).not.toBeInTheDocument()
})
it('keeps the generic copy for intent reasons outside the credits variants', () => {
mockRenewalDate.value = '2026-07-15T10:00:00Z'
renderComponent({ reason: 'subscribe_to_run' })
expect(
screen.getByText('Your credits refresh on Jul 15, 2026.')
).toBeInTheDocument()
})
it('swaps to the out-of-credits copy without the refresh line', () => {
mockRenewalDate.value = '2026-07-15T10:00:00Z'
renderComponent({ reason: 'out_of_credits' })
expect(screen.queryByText(/credits refresh on/)).not.toBeInTheDocument()
})
})

View File

@@ -52,7 +52,7 @@
</p>
<p
v-if="!reason || reason === 'subscription_required'"
v-if="!isCreditsBlockedVariant"
class="m-0 text-sm text-text-secondary"
>
{{
@@ -65,10 +65,7 @@
</p>
<p
v-if="
(!reason || reason === 'subscription_required') &&
formattedRenewalDate
"
v-if="!isCreditsBlockedVariant && formattedRenewalDate"
class="m-0 text-sm text-text-secondary"
>
{{
@@ -88,7 +85,7 @@
@click="$emit('upgrade')"
>
{{
reason === 'out_of_credits' || reason === 'top_up_blocked'
isCreditsBlockedVariant
? $t('subscription.freeTier.upgradeCta')
: $t('subscription.freeTier.subscribeCta')
}}
@@ -103,12 +100,12 @@ import { computed } from 'vue'
import Button from '@/components/ui/button/Button.vue'
import { useBillingContext } from '@/composables/billing/useBillingContext'
import type { SubscriptionDialogReason } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
import type { PaymentIntentSource } from '@/platform/telemetry/types'
import SubscriptionBenefits from '@/platform/cloud/subscription/components/SubscriptionBenefits.vue'
import { getTierCredits } from '@/platform/cloud/subscription/constants/tierPricing'
defineProps<{
reason?: SubscriptionDialogReason
const { reason } = defineProps<{
reason?: PaymentIntentSource
}>()
defineEmits<{
@@ -129,4 +126,10 @@ const formattedRenewalDate = computed(() => {
})
const freeTierCredits = computed(() => getTierCredits('free'))
// Only these two variants replace the generic free-tier copy; any other
// intent reason (subscribe_to_run, deep_link, ...) keeps the default pitch.
const isCreditsBlockedVariant = computed(
() => reason === 'out_of_credits' || reason === 'top_up_blocked'
)
</script>

View File

@@ -261,6 +261,7 @@ describe('PricingTable', () => {
tier: 'creator',
cycle: 'yearly',
checkout_type: 'change',
checkout_attempt_id: expect.any(String),
previous_tier: 'standard'
})
expect(mockAccessBillingPortal).toHaveBeenCalledWith('creator-yearly')
@@ -341,6 +342,7 @@ describe('PricingTable', () => {
expect(
window.localStorage.getItem(PENDING_SUBSCRIPTION_CHECKOUT_STORAGE_KEY)
).toBeNull()
expect(mockTrackBeginCheckout).not.toHaveBeenCalled()
})
it('should use the latest userId value when it changes after mount', async () => {
@@ -366,6 +368,7 @@ describe('PricingTable', () => {
tier: 'creator',
cycle: 'yearly',
checkout_type: 'change',
checkout_attempt_id: expect.any(String),
previous_tier: 'standard'
})
})

View File

@@ -277,13 +277,19 @@ import type {
TierKey,
TierPricing
} from '@/platform/cloud/subscription/constants/tierPricing'
import { recordPendingSubscriptionCheckoutAttempt } from '@/platform/cloud/subscription/utils/subscriptionCheckoutTracker'
import {
recordPendingSubscriptionCheckoutAttempt,
withPendingCheckoutAttemptId
} from '@/platform/cloud/subscription/utils/subscriptionCheckoutTracker'
import { performSubscriptionCheckout } from '@/platform/cloud/subscription/utils/subscriptionCheckoutUtil'
import { isPlanDowngrade } from '@/platform/cloud/subscription/utils/subscriptionTierRank'
import type { BillingCycle } from '@/platform/cloud/subscription/utils/subscriptionTierRank'
import { isCloud } from '@/platform/distribution/types'
import { useTelemetry } from '@/platform/telemetry'
import type { CheckoutAttributionMetadata } from '@/platform/telemetry/types'
import type {
CheckoutAttributionMetadata,
PaymentIntentSource
} from '@/platform/telemetry/types'
import { useAuthStore } from '@/stores/authStore'
type CheckoutTierKey = Exclude<TierKey, 'free' | 'founder'>
@@ -321,6 +327,10 @@ interface PricingTierConfig {
isPopular?: boolean
}
const { reason } = defineProps<{
reason?: PaymentIntentSource
}>()
const emit = defineEmits<{
chooseTeamWorkspace: []
}>()
@@ -463,16 +473,17 @@ const handleSubscribe = wrapWithErrorHandlingAsync(
} as const
const previousPlan = currentPlanDescriptor.value
const checkoutAttribution = await getCheckoutAttributionForCloud()
if (userId.value) {
telemetry?.trackBeginCheckout({
user_id: userId.value,
tier: targetPlan.tierKey,
cycle: targetPlan.billingCycle,
checkout_type: 'change',
...checkoutAttribution,
...(previousPlan ? { previous_tier: previousPlan.tierKey } : {})
})
}
const beginCheckoutMetadata = userId.value
? {
user_id: userId.value,
tier: targetPlan.tierKey,
cycle: targetPlan.billingCycle,
checkout_type: 'change' as const,
...(reason ? { payment_intent_source: reason } : {}),
...checkoutAttribution,
...(previousPlan ? { previous_tier: previousPlan.tierKey } : {})
}
: null
// Pass the target tier to create a deep link to subscription update confirmation
const checkoutTier = getCheckoutTier(
targetPlan.tierKey,
@@ -487,29 +498,39 @@ const handleSubscribe = wrapWithErrorHandlingAsync(
if (downgrade) {
// TODO(COMFY-StripeProration): Remove once backend checkout creation mirrors portal proration ("change at billing end")
await accessBillingPortal()
const didOpenPortal = await accessBillingPortal()
if (didOpenPortal && beginCheckoutMetadata) {
telemetry?.trackBeginCheckout(beginCheckoutMetadata)
}
} else {
const didOpenPortal = await accessBillingPortal(checkoutTier)
if (!didOpenPortal) {
return
}
recordPendingSubscriptionCheckoutAttempt({
const pendingAttempt = recordPendingSubscriptionCheckoutAttempt({
tier: targetPlan.tierKey,
cycle: targetPlan.billingCycle,
checkout_type: 'change',
payment_intent_source: reason,
...(previousPlan ? { previous_tier: previousPlan.tierKey } : {}),
...(previousPlan
? { previous_cycle: previousPlan.billingCycle }
: {})
})
if (beginCheckoutMetadata) {
telemetry?.trackBeginCheckout(
withPendingCheckoutAttemptId(
beginCheckoutMetadata,
pendingAttempt
)
)
}
}
} else {
await performSubscriptionCheckout(
tierKey,
currentBillingCycle.value,
true
)
await performSubscriptionCheckout(tierKey, currentBillingCycle.value, {
paymentIntentSource: reason
})
}
} finally {
isLoading.value = false

View File

@@ -56,7 +56,7 @@ const handleSubscribe = () => {
current_tier: tier.value?.toLowerCase()
})
isAwaitingStripeSubscription.value = true
showSubscriptionDialog()
showSubscriptionDialog({ reason: 'subscribe_now_button' })
}
onBeforeUnmount(() => {

View File

@@ -54,6 +54,6 @@ function handleSubscribeToRun() {
trackRunButton({ subscribe_to_run: true })
}
showSubscriptionDialog()
showSubscriptionDialog({ reason: 'subscribe_to_run' })
}
</script>

View File

@@ -48,7 +48,9 @@
v-if="isActiveSubscription"
variant="primary"
class="rounded-lg px-4 py-2 text-sm font-normal text-text-primary"
@click="showSubscriptionDialog"
@click="
showSubscriptionDialog({ reason: 'settings_billing_panel' })
"
>
{{ $t('subscription.upgradePlan') }}
</Button>

View File

@@ -33,7 +33,11 @@
</i18n-t>
</div>
<PricingTable class="flex-1" @choose-team-workspace="handleChooseTeam" />
<PricingTable
:reason
class="flex-1"
@choose-team-workspace="handleChooseTeam"
/>
<!-- Contact and Enterprise Links -->
<div class="flex flex-col items-center gap-2">
@@ -157,11 +161,11 @@ import { useBillingContext } from '@/composables/billing/useBillingContext'
import { isCloud } from '@/platform/distribution/types'
import { useTelemetry } from '@/platform/telemetry'
import { useCommandStore } from '@/stores/commandStore'
import type { SubscriptionDialogReason } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
import type { PaymentIntentSource } from '@/platform/telemetry/types'
const { onClose, reason, onChooseTeam } = defineProps<{
onClose: () => void
reason?: SubscriptionDialogReason
reason?: PaymentIntentSource
onChooseTeam?: () => void
}>()

View File

@@ -24,7 +24,9 @@ export function useAccountPreconditionDialog() {
)
return
case 'subscription':
void dialogService.showSubscriptionRequiredDialog()
void dialogService.showSubscriptionRequiredDialog({
reason: 'subscription_required'
})
return
case 'credits':
void dialogService.showTopUpCreditsDialog({

View File

@@ -55,12 +55,6 @@ vi.mock('@/platform/workspace/stores/teamWorkspaceStore', () => ({
})
}))
const mockTrackSubscription = vi.hoisted(() => vi.fn())
vi.mock('@/platform/telemetry', () => ({
useTelemetry: () => ({ trackSubscription: mockTrackSubscription })
}))
describe('usePricingTableUrlLoader', () => {
beforeEach(() => {
vi.clearAllMocks()
@@ -96,9 +90,6 @@ describe('usePricingTableUrlLoader', () => {
reason: 'deep_link',
planMode: undefined
})
expect(mockTrackSubscription).toHaveBeenCalledWith('modal_opened', {
reason: 'deep_link'
})
expect(mockRouterReplace).toHaveBeenCalledWith({ query: {} })
})
@@ -150,7 +141,6 @@ describe('usePricingTableUrlLoader', () => {
await loadPricingTableFromUrl()
expect(mockShowPricingTable).not.toHaveBeenCalled()
expect(mockTrackSubscription).not.toHaveBeenCalled()
})
it('denies, strips, and clears together when the user is not eligible', async () => {
@@ -161,7 +151,6 @@ describe('usePricingTableUrlLoader', () => {
await loadPricingTableFromUrl()
expect(mockShowPricingTable).not.toHaveBeenCalled()
expect(mockTrackSubscription).not.toHaveBeenCalled()
expect(mockRouterReplace).toHaveBeenCalledWith({
query: { other: 'param' }
})
@@ -230,7 +219,6 @@ describe('usePricingTableUrlLoader', () => {
)
expect(mockShowPricingTable).not.toHaveBeenCalled()
expect(mockTrackSubscription).not.toHaveBeenCalled()
expect(mockRouterReplace).toHaveBeenCalledWith({ query: {} })
expect(preservedQueryMocks.clearPreservedQuery).toHaveBeenCalledWith(
'pricing'

View File

@@ -7,7 +7,6 @@ import {
mergePreservedQueryIntoQuery
} from '@/platform/navigation/preservedQueryManager'
import { PRESERVED_QUERY_NAMESPACES } from '@/platform/navigation/preservedQueryNamespaces'
import { useTelemetry } from '@/platform/telemetry'
import { useWorkspaceUI } from '@/platform/workspace/composables/useWorkspaceUI'
import { useTeamWorkspaceStore } from '@/platform/workspace/stores/teamWorkspaceStore'
@@ -62,7 +61,6 @@ export function usePricingTableUrlLoader() {
const planMode =
param === 'team' || param === 'personal' ? param : undefined
useTelemetry()?.trackSubscription('modal_opened', { reason: 'deep_link' })
subscriptionDialog.showPricingTable({ reason: 'deep_link', planMode })
}

View File

@@ -15,7 +15,7 @@ import { t } from '@/i18n'
import { fetchWithUnifiedRemint } from '@/platform/auth/unified/remintRetry'
import { isCloud } from '@/platform/distribution/types'
import { useTelemetry } from '@/platform/telemetry'
import type { SubscriptionDialogReason } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
import type { SubscriptionDialogOptions } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
import type { CheckoutAttributionMetadata } from '@/platform/telemetry/types'
import { AuthStoreError, useAuthStore } from '@/stores/authStore'
import { useDialogService } from '@/services/dialogService'
@@ -237,14 +237,7 @@ function useSubscriptionInternal() {
})
}, reportError)
const showSubscriptionDialog = (options?: {
reason?: SubscriptionDialogReason
}) => {
useTelemetry()?.trackSubscription('modal_opened', {
current_tier: subscriptionTier.value?.toLowerCase(),
reason: options?.reason
})
const showSubscriptionDialog = (options?: SubscriptionDialogOptions) => {
void showSubscriptionRequiredDialog(options)
}
@@ -277,7 +270,7 @@ function useSubscriptionInternal() {
await fetchSubscriptionStatus()
if (!isSubscribedOrIsNotCloud.value) {
showSubscriptionDialog()
showSubscriptionDialog({ reason: 'subscription_required' })
}
}

View File

@@ -39,15 +39,23 @@ vi.mock('@/stores/commandStore', () => ({
}))
// useTelemetry() returns null in OSS, a dispatcher in cloud — toggle via mockIsCloud.
const { mockIsCloud, mockTrackHelpResourceClicked } = vi.hoisted(() => ({
const {
mockIsCloud,
mockTrackHelpResourceClicked,
mockTrackAddApiCreditButtonClicked
} = vi.hoisted(() => ({
mockIsCloud: { value: true },
mockTrackHelpResourceClicked: vi.fn()
mockTrackHelpResourceClicked: vi.fn(),
mockTrackAddApiCreditButtonClicked: vi.fn()
}))
vi.mock('@/platform/telemetry', () => ({
useTelemetry: () =>
mockIsCloud.value
? { trackHelpResourceClicked: mockTrackHelpResourceClicked }
? {
trackHelpResourceClicked: mockTrackHelpResourceClicked,
trackAddApiCreditButtonClicked: mockTrackAddApiCreditButtonClicked
}
: null
}))
@@ -69,6 +77,9 @@ describe('useSubscriptionActions', () => {
const { handleAddApiCredits } = useSubscriptionActions()
handleAddApiCredits()
expect(mockShowTopUpCreditsDialog).toHaveBeenCalledOnce()
expect(mockTrackAddApiCreditButtonClicked).toHaveBeenCalledWith({
source: 'settings_billing_panel'
})
})
})

View File

@@ -21,6 +21,9 @@ export function useSubscriptionActions() {
})
const handleAddApiCredits = () => {
telemetry?.trackAddApiCreditButtonClicked({
source: 'settings_billing_panel'
})
void dialogService.showTopUpCreditsDialog()
}

View File

@@ -5,8 +5,10 @@ import { useSubscriptionDialog } from './useSubscriptionDialog'
const mockCloseDialog = vi.fn()
const mockShowLayoutDialog = vi.fn()
const mockShowTeamWorkspacesDialog = vi.fn()
const mockTrackSubscription = vi.hoisted(() => vi.fn())
const mockIsInPersonalWorkspace = vi.hoisted(() => ({ value: true }))
const mockIsFreeTier = vi.hoisted(() => ({ value: false }))
const mockTier = vi.hoisted(() => ({ value: 'FREE' as string | null }))
const mockTeamWorkspacesEnabled = vi.hoisted(() => ({ value: false }))
const mockIsCloud = vi.hoisted(() => ({ value: true }))
const mockIsLegacyTeamPlan = vi.hoisted(() => ({ value: false }))
@@ -60,10 +62,15 @@ vi.mock('@/platform/workspace/stores/teamWorkspaceStore', () => ({
vi.mock('@/composables/billing/useBillingContext', () => ({
useBillingContext: () => ({
isFreeTier: mockIsFreeTier,
isLegacyTeamPlan: mockIsLegacyTeamPlan
isLegacyTeamPlan: mockIsLegacyTeamPlan,
tier: mockTier
})
}))
vi.mock('@/platform/telemetry', () => ({
useTelemetry: () => ({ trackSubscription: mockTrackSubscription })
}))
vi.mock('@/platform/workspace/composables/useWorkspaceUI', () => ({
useWorkspaceUI: () => ({
permissions: {
@@ -80,6 +87,7 @@ describe('useSubscriptionDialog', () => {
mockIsCloud.value = true
mockIsInPersonalWorkspace.value = true
mockIsFreeTier.value = false
mockTier.value = 'FREE'
mockTeamWorkspacesEnabled.value = false
mockIsLegacyTeamPlan.value = false
mockCanManageSubscription.value = true
@@ -198,6 +206,51 @@ describe('useSubscriptionDialog', () => {
const props = mockShowLayoutDialog.mock.calls[0][0].props
expect(props.initialPlanMode).toBe('team')
})
it('tracks modal_opened with the caller reason and current tier', () => {
mockTier.value = 'STANDARD'
const { showPricingTable } = useSubscriptionDialog()
showPricingTable({ reason: 'upgrade_to_add_credits' })
expect(mockTrackSubscription).toHaveBeenCalledWith('modal_opened', {
current_tier: 'standard',
reason: 'upgrade_to_add_credits'
})
})
it('tracks modal_opened on the workspace (unified) path too', () => {
mockTeamWorkspacesEnabled.value = true
const { showPricingTable } = useSubscriptionDialog()
showPricingTable({ reason: 'subscribe_to_run' })
expect(mockTrackSubscription).toHaveBeenCalledWith(
'modal_opened',
expect.objectContaining({ reason: 'subscribe_to_run' })
)
})
it('does not track modal_opened for the inactive member dialog', () => {
mockTeamWorkspacesEnabled.value = true
mockIsInPersonalWorkspace.value = false
mockCanManageSubscription.value = false
const { showPricingTable } = useSubscriptionDialog()
showPricingTable({ reason: 'subscribe_to_run' })
expect(mockShowLayoutDialog).toHaveBeenCalledTimes(1)
expect(mockTrackSubscription).not.toHaveBeenCalled()
})
it('does not track on non-cloud', () => {
mockIsCloud.value = false
const { showPricingTable } = useSubscriptionDialog()
showPricingTable({ reason: 'subscribe_to_run' })
expect(mockTrackSubscription).not.toHaveBeenCalled()
})
})
describe('show', () => {
@@ -235,6 +288,20 @@ describe('useSubscriptionDialog', () => {
expect.objectContaining({ key: 'subscription-required' })
)
})
it('tracks modal_opened with the reason for the free-tier dialog', () => {
mockIsFreeTier.value = true
mockIsInPersonalWorkspace.value = true
const { show } = useSubscriptionDialog()
show({ reason: 'out_of_credits' })
expect(mockTrackSubscription).toHaveBeenCalledTimes(1)
expect(mockTrackSubscription).toHaveBeenCalledWith(
'modal_opened',
expect.objectContaining({ reason: 'out_of_credits' })
)
})
})
describe('startTeamWorkspaceUpgradeFlow', () => {

View File

@@ -4,6 +4,8 @@ import { useDialogStore } from '@/stores/dialogStore'
import { useBillingContext } from '@/composables/billing/useBillingContext'
import { useFeatureFlags } from '@/composables/useFeatureFlags'
import { isCloud } from '@/platform/distribution/types'
import { useTelemetry } from '@/platform/telemetry'
import type { PaymentIntentSource } from '@/platform/telemetry/types'
import { useWorkspaceUI } from '@/platform/workspace/composables/useWorkspaceUI'
import { useTeamWorkspaceStore } from '@/platform/workspace/stores/teamWorkspaceStore'
@@ -11,14 +13,8 @@ const DIALOG_KEY = 'subscription-required'
const FREE_TIER_DIALOG_KEY = 'free-tier-info'
const RESUME_PRICING_KEY = 'comfy:resume-team-pricing'
export type SubscriptionDialogReason =
| 'subscription_required'
| 'out_of_credits'
| 'top_up_blocked'
| 'deep_link'
interface SubscriptionDialogOptions {
reason?: SubscriptionDialogReason
export interface SubscriptionDialogOptions {
reason?: PaymentIntentSource
/**
* Forces the unified pricing dialog to open on a specific plan tab,
* overriding the workspace-derived default (e.g. an "Upgrade to Team" CTA
@@ -38,6 +34,17 @@ export const useSubscriptionDialog = () => {
dialogStore.closeDialog({ key: FREE_TIER_DIALOG_KEY })
}
// Fired here — the choke point every paywall/pricing dialog variant passes
// through — so both the legacy and workspace billing paths emit it.
function trackModalOpened(reason?: PaymentIntentSource) {
// Resolved lazily to avoid the useBillingContext import cycle (see below).
const { tier } = useBillingContext()
useTelemetry()?.trackSubscription('modal_opened', {
current_tier: tier.value?.toLowerCase(),
reason
})
}
function showPricingTable(options?: SubscriptionDialogOptions) {
if (!isCloud) return
@@ -71,6 +78,8 @@ export const useSubscriptionDialog = () => {
return
}
trackModalOpened(options?.reason)
// Shared dialog shell styling for both variants.
const dialogComponentProps = {
style: 'width: min(1328px, 95vw); max-height: 958px;',
@@ -167,6 +176,8 @@ export const useSubscriptionDialog = () => {
// (not at composable setup) to avoid the useBillingContext import cycle.
const { isFreeTier } = useBillingContext()
if (isFreeTier.value && workspaceStore.isInPersonalWorkspace) {
trackModalOpened(options?.reason)
const component = defineAsyncComponent(
() =>
import('@/platform/cloud/subscription/components/FreeTierDialogContent.vue')
@@ -236,7 +247,7 @@ export const useSubscriptionDialog = () => {
sessionStorage.removeItem(RESUME_PRICING_KEY)
if (!workspaceStore.isInPersonalWorkspace) {
showPricingTable()
showPricingTable({ reason: 'team_upgrade_resume' })
}
} catch {
// sessionStorage may be unavailable

View File

@@ -0,0 +1,49 @@
import { beforeEach, describe, expect, it } from 'vitest'
import {
clearPendingSubscriptionCheckoutAttempt,
consumePendingSubscriptionCheckoutSuccess,
recordPendingSubscriptionCheckoutAttempt
} from './subscriptionCheckoutTracker'
const activeProStatus = {
is_active: true,
subscription_tier: 'PRO',
subscription_duration: 'MONTHLY'
} as const
describe('subscriptionCheckoutTracker', () => {
beforeEach(() => {
clearPendingSubscriptionCheckoutAttempt()
})
it('round-trips payment_intent_source from attempt to success metadata', () => {
recordPendingSubscriptionCheckoutAttempt({
tier: 'pro',
cycle: 'monthly',
checkout_type: 'new',
payment_intent_source: 'subscribe_to_run'
})
const metadata = consumePendingSubscriptionCheckoutSuccess(activeProStatus)
expect(metadata).toMatchObject({
tier: 'pro',
checkout_type: 'new',
payment_intent_source: 'subscribe_to_run'
})
})
it('omits payment_intent_source when the attempt had none', () => {
recordPendingSubscriptionCheckoutAttempt({
tier: 'pro',
cycle: 'monthly',
checkout_type: 'new'
})
const metadata = consumePendingSubscriptionCheckoutSuccess(activeProStatus)
expect(metadata).not.toBeNull()
expect(metadata).not.toHaveProperty('payment_intent_source')
})
})

View File

@@ -7,7 +7,12 @@ import type {
TierKey
} from '@/platform/cloud/subscription/constants/tierPricing'
import type { BillingCycle } from '@/platform/cloud/subscription/utils/subscriptionTierRank'
import type { SubscriptionSuccessMetadata } from '@/platform/telemetry/types'
import type {
BeginCheckoutMetadata,
PaymentIntentSource,
SubscriptionCheckoutType,
SubscriptionSuccessMetadata
} from '@/platform/telemetry/types'
const PENDING_SUBSCRIPTION_CHECKOUT_MAX_AGE_MS = 6 * 60 * 60 * 1000
const VALID_TIER_KEYS = new Set<TierKey>([
@@ -23,7 +28,6 @@ export const PENDING_SUBSCRIPTION_CHECKOUT_STORAGE_KEY =
export const PENDING_SUBSCRIPTION_CHECKOUT_EVENT =
'comfy:subscription-checkout-attempt-changed'
type CheckoutType = 'new' | 'change'
type SubscriptionDuration = 'MONTHLY' | 'ANNUAL'
interface SubscriptionStatusSnapshot {
@@ -32,22 +36,24 @@ interface SubscriptionStatusSnapshot {
subscription_duration?: SubscriptionDuration | null
}
interface PendingSubscriptionCheckoutAttempt {
export interface PendingSubscriptionCheckoutAttempt {
attempt_id: string
started_at_ms: number
tier: TierKey
cycle: BillingCycle
checkout_type: CheckoutType
checkout_type: SubscriptionCheckoutType
previous_tier?: TierKey
previous_cycle?: BillingCycle
payment_intent_source?: PaymentIntentSource
}
interface RecordPendingSubscriptionCheckoutAttemptInput {
interface PendingSubscriptionCheckoutAttemptInput {
tier: TierKey
cycle: BillingCycle
checkout_type: CheckoutType
checkout_type: SubscriptionCheckoutType
previous_tier?: TierKey
previous_cycle?: BillingCycle
payment_intent_source?: PaymentIntentSource
}
const dispatchPendingCheckoutChangeEvent = () => {
@@ -168,6 +174,9 @@ const normalizeAttempt = (
...(candidate.previous_cycle === 'monthly' ||
candidate.previous_cycle === 'yearly'
? { previous_cycle: candidate.previous_cycle }
: {}),
...(typeof candidate.payment_intent_source === 'string'
? { payment_intent_source: candidate.payment_intent_source }
: {})
}
}
@@ -224,20 +233,27 @@ const getPendingSubscriptionCheckoutAttempt =
export const hasPendingSubscriptionCheckoutAttempt = (): boolean =>
getPendingSubscriptionCheckoutAttempt() !== null
export const recordPendingSubscriptionCheckoutAttempt = (
input: RecordPendingSubscriptionCheckoutAttemptInput
export const createPendingSubscriptionCheckoutAttempt = (
input: PendingSubscriptionCheckoutAttemptInput
): PendingSubscriptionCheckoutAttempt => {
const storage = getStorage()
const attempt: PendingSubscriptionCheckoutAttempt = {
return {
attempt_id: createAttemptId(),
started_at_ms: Date.now(),
tier: input.tier,
cycle: input.cycle,
checkout_type: input.checkout_type,
...(input.previous_tier ? { previous_tier: input.previous_tier } : {}),
...(input.previous_cycle ? { previous_cycle: input.previous_cycle } : {})
...(input.previous_cycle ? { previous_cycle: input.previous_cycle } : {}),
...(input.payment_intent_source
? { payment_intent_source: input.payment_intent_source }
: {})
}
}
export const persistPendingSubscriptionCheckoutAttempt = (
attempt: PendingSubscriptionCheckoutAttempt
): PendingSubscriptionCheckoutAttempt => {
const storage = getStorage()
if (!storage) {
return attempt
}
@@ -255,6 +271,21 @@ export const recordPendingSubscriptionCheckoutAttempt = (
return attempt
}
export const recordPendingSubscriptionCheckoutAttempt = (
input: PendingSubscriptionCheckoutAttemptInput
): PendingSubscriptionCheckoutAttempt =>
persistPendingSubscriptionCheckoutAttempt(
createPendingSubscriptionCheckoutAttempt(input)
)
export const withPendingCheckoutAttemptId = (
metadata: BeginCheckoutMetadata,
attempt: PendingSubscriptionCheckoutAttempt
): BeginCheckoutMetadata => ({
...metadata,
checkout_attempt_id: attempt.attempt_id
})
const didAttemptSucceed = (
attempt: PendingSubscriptionCheckoutAttempt,
status: SubscriptionStatusSnapshot
@@ -287,6 +318,9 @@ export const consumePendingSubscriptionCheckoutSuccess = (
cycle: attempt.cycle,
checkout_type: attempt.checkout_type,
...(attempt.previous_tier ? { previous_tier: attempt.previous_tier } : {}),
...(attempt.payment_intent_source
? { payment_intent_source: attempt.payment_intent_source }
: {}),
value,
currency: 'USD',
ecommerce: {

View File

@@ -132,13 +132,14 @@ describe('performSubscriptionCheckout', () => {
json: async () => ({ checkout_url: checkoutUrl })
} as Response)
await performSubscriptionCheckout('pro', 'yearly', true)
await performSubscriptionCheckout('pro', 'yearly')
expect(mockTelemetry.trackBeginCheckout).toHaveBeenCalledWith({
user_id: 'user-123',
tier: 'pro',
cycle: 'yearly',
checkout_type: 'new',
checkout_attempt_id: expect.any(String),
ga_client_id: 'ga-client-id',
ga_session_id: 'ga-session-id',
ga_session_number: 'ga-session-number',
@@ -150,6 +151,12 @@ describe('performSubscriptionCheckout', () => {
gbraid: 'gbraid-456',
wbraid: 'wbraid-789'
})
const beginCheckoutMetadata =
mockTelemetry.trackBeginCheckout.mock.calls[0][0]
const [, storedAttempt] = mockLocalStorage.setItem.mock.calls[0]
expect(beginCheckoutMetadata.checkout_attempt_id).toBe(
JSON.parse(storedAttempt).attempt_id
)
expect(global.fetch).toHaveBeenCalledWith(
expect.stringContaining(
'/customers/cloud-subscription-checkout/pro-yearly'
@@ -186,7 +193,7 @@ describe('performSubscriptionCheckout', () => {
json: async () => ({ checkout_url: checkoutUrl })
} as Response)
await performSubscriptionCheckout('pro', 'monthly', true)
await performSubscriptionCheckout('pro', 'monthly')
expect(warnSpy).toHaveBeenCalledWith(
'[SubscriptionCheckout] Failed to collect checkout attribution',
@@ -203,11 +210,43 @@ describe('performSubscriptionCheckout', () => {
user_id: 'user-123',
tier: 'pro',
cycle: 'monthly',
checkout_type: 'new'
checkout_type: 'new',
checkout_attempt_id: expect.any(String)
})
expect(openSpy).toHaveBeenCalledWith(checkoutUrl, '_blank')
})
it('carries the payment intent source into begin_checkout and the pending attempt', async () => {
const checkoutUrl = 'https://checkout.stripe.com/test'
const openSpy = vi
.spyOn(window, 'open')
.mockImplementation(() => window as unknown as Window)
vi.mocked(global.fetch).mockResolvedValue({
ok: true,
json: async () => ({ checkout_url: checkoutUrl })
} as Response)
await performSubscriptionCheckout('pro', 'monthly', {
paymentIntentSource: 'out_of_credits'
})
expect(mockTelemetry.trackBeginCheckout).toHaveBeenCalledWith(
expect.objectContaining({ payment_intent_source: 'out_of_credits' })
)
const beginCheckoutMetadata =
mockTelemetry.trackBeginCheckout.mock.calls[0][0]
const [, storedAttempt] = mockLocalStorage.setItem.mock.calls[0]
const pendingAttempt = JSON.parse(storedAttempt)
expect(pendingAttempt).toMatchObject({
payment_intent_source: 'out_of_credits'
})
expect(beginCheckoutMetadata.checkout_attempt_id).toBe(
pendingAttempt.attempt_id
)
openSpy.mockRestore()
})
it('uses the latest userId when it changes after checkout starts', async () => {
const checkoutUrl = 'https://checkout.stripe.com/test'
const openSpy = vi
@@ -222,7 +261,7 @@ describe('performSubscriptionCheckout', () => {
json: async () => ({ checkout_url: checkoutUrl })
} as Response)
const checkoutPromise = performSubscriptionCheckout('pro', 'yearly', true)
const checkoutPromise = performSubscriptionCheckout('pro', 'yearly')
mockUserId.value = 'user-late'
authHeader.resolve({ Authorization: 'Bearer test-token' })
@@ -235,13 +274,14 @@ describe('performSubscriptionCheckout', () => {
user_id: 'user-late',
tier: 'pro',
cycle: 'yearly',
checkout_type: 'new'
checkout_type: 'new',
checkout_attempt_id: expect.any(String)
})
)
expect(openSpy).toHaveBeenCalledWith(checkoutUrl, '_blank')
})
it('does not persist a pending attempt when the checkout popup is blocked', async () => {
it('does not persist the pending attempt when the checkout popup is blocked', async () => {
const checkoutUrl = 'https://checkout.stripe.com/test'
const openSpy = vi.spyOn(window, 'open').mockImplementation(() => null)
@@ -250,11 +290,18 @@ describe('performSubscriptionCheckout', () => {
json: async () => ({ checkout_url: checkoutUrl })
} as Response)
await performSubscriptionCheckout('pro', 'monthly', true)
await performSubscriptionCheckout('pro', 'monthly')
expect(openSpy).toHaveBeenCalledWith(checkoutUrl, '_blank')
expect(
window.localStorage.getItem(PENDING_SUBSCRIPTION_CHECKOUT_STORAGE_KEY)
).toBeNull()
const storedAttempt = window.localStorage.getItem(
PENDING_SUBSCRIPTION_CHECKOUT_STORAGE_KEY
)
expect(storedAttempt).toBeNull()
expect(mockLocalStorage.setItem).not.toHaveBeenCalled()
expect(mockTelemetry.trackBeginCheckout).toHaveBeenCalledWith(
expect.objectContaining({
checkout_attempt_id: expect.any(String)
})
)
})
})

View File

@@ -4,12 +4,19 @@ import { useFeatureFlags } from '@/composables/useFeatureFlags'
import { getComfyApiBaseUrl } from '@/config/comfyApi'
import { t } from '@/i18n'
import { fetchWithUnifiedRemint } from '@/platform/auth/unified/remintRetry'
import type { TierKey } from '@/platform/cloud/subscription/constants/tierPricing'
import {
createPendingSubscriptionCheckoutAttempt,
persistPendingSubscriptionCheckoutAttempt,
withPendingCheckoutAttemptId
} from '@/platform/cloud/subscription/utils/subscriptionCheckoutTracker'
import { isCloud } from '@/platform/distribution/types'
import { useTelemetry } from '@/platform/telemetry'
import type {
CheckoutAttributionMetadata,
PaymentIntentSource
} from '@/platform/telemetry/types'
import { AuthStoreError, useAuthStore } from '@/stores/authStore'
import type { CheckoutAttributionMetadata } from '@/platform/telemetry/types'
import type { TierKey } from '@/platform/cloud/subscription/constants/tierPricing'
import { recordPendingSubscriptionCheckoutAttempt } from '@/platform/cloud/subscription/utils/subscriptionCheckoutTracker'
import type { BillingCycle } from './subscriptionTierRank'
type CheckoutTier = TierKey | `${TierKey}-yearly`
@@ -31,6 +38,11 @@ const getCheckoutAttributionForCloud =
return getCheckoutAttribution()
}
interface PerformSubscriptionCheckoutOptions {
openInNewTab?: boolean
paymentIntentSource?: PaymentIntentSource
}
/**
* Core subscription checkout logic shared between PricingTable and
* SubscriptionRedirectView. Handles:
@@ -47,10 +59,12 @@ const getCheckoutAttributionForCloud =
export async function performSubscriptionCheckout(
tierKey: TierKey,
currentBillingCycle: BillingCycle,
openInNewTab: boolean = true
options: PerformSubscriptionCheckoutOptions = {}
): Promise<void> {
if (!isCloud) return
const { openInNewTab = true, paymentIntentSource } = options
const authStore = useAuthStore()
const { userId } = storeToRefs(authStore)
const telemetry = useTelemetry()
@@ -108,14 +122,29 @@ export async function performSubscriptionCheckout(
const data = await response.json()
if (data.checkout_url) {
const pendingAttempt = createPendingSubscriptionCheckoutAttempt({
tier: tierKey,
cycle: currentBillingCycle,
checkout_type: 'new',
payment_intent_source: paymentIntentSource
})
if (userId.value) {
telemetry?.trackBeginCheckout({
user_id: userId.value,
tier: tierKey,
cycle: currentBillingCycle,
checkout_type: 'new',
...checkoutAttribution
})
telemetry?.trackBeginCheckout(
withPendingCheckoutAttemptId(
{
user_id: userId.value,
tier: tierKey,
cycle: currentBillingCycle,
checkout_type: 'new',
...(paymentIntentSource
? { payment_intent_source: paymentIntentSource }
: {}),
...checkoutAttribution
},
pendingAttempt
)
)
}
if (openInNewTab) {
@@ -123,18 +152,9 @@ export async function performSubscriptionCheckout(
if (!checkoutWindow) {
return
}
recordPendingSubscriptionCheckoutAttempt({
tier: tierKey,
cycle: currentBillingCycle,
checkout_type: 'new'
})
persistPendingSubscriptionCheckoutAttempt(pendingAttempt)
} else {
recordPendingSubscriptionCheckoutAttempt({
tier: tierKey,
cycle: currentBillingCycle,
checkout_type: 'new'
})
persistPendingSubscriptionCheckoutAttempt(pendingAttempt)
globalThis.location.href = data.checkout_url
}
}

View File

@@ -1,9 +1,13 @@
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { computed, reactive } from 'vue'
const { mockIsCloud, mockSubscribe } = vi.hoisted(() => ({
mockIsCloud: { value: true },
mockSubscribe: vi.fn()
}))
const { mockIsCloud, mockSubscribe, mockTrackBeginCheckout, mockUserId } =
vi.hoisted(() => ({
mockIsCloud: { value: true },
mockSubscribe: vi.fn(),
mockTrackBeginCheckout: vi.fn(),
mockUserId: { value: 'user-1' as string | null }
}))
vi.mock('@/platform/distribution/types', () => ({
get isCloud() {
@@ -16,6 +20,12 @@ vi.mock('@/config/comfyApi', () => ({
vi.mock('@/platform/workspace/api/workspaceApi', () => ({
workspaceApi: { subscribe: mockSubscribe }
}))
vi.mock('@/platform/telemetry', () => ({
useTelemetry: () => ({ trackBeginCheckout: mockTrackBeginCheckout })
}))
vi.mock('@/stores/authStore', () => ({
useAuthStore: () => reactive({ userId: computed(() => mockUserId.value) })
}))
import { performTeamSubscriptionCheckout } from './teamSubscriptionCheckoutUtil'
@@ -43,7 +53,9 @@ describe('performTeamSubscriptionCheckout', () => {
billing_op_id: 'op_1'
})
await performTeamSubscriptionCheckout('team_700', 'yearly')
await performTeamSubscriptionCheckout('team_700', 'yearly', {
paymentIntentSource: 'deep_link'
})
expect(mockSubscribe).toHaveBeenCalledWith('team_per_credit_annual', {
returnUrl: 'https://app.test/payment/success',
@@ -51,6 +63,14 @@ describe('performTeamSubscriptionCheckout', () => {
teamCreditStopId: 'team_700'
})
expect(assignedHref).toBe('https://stripe.test/pay')
expect(mockTrackBeginCheckout).toHaveBeenCalledWith({
user_id: 'user-1',
tier: 'team',
cycle: 'yearly',
checkout_type: 'new',
billing_op_id: 'op_1',
payment_intent_source: 'deep_link'
})
})
it('uses the monthly slug and lands in the app when no Stripe step is needed', async () => {
@@ -82,6 +102,16 @@ describe('performTeamSubscriptionCheckout', () => {
expect(assignedHref).toBeUndefined()
})
it('does not track begin_checkout when subscribe fails', async () => {
mockSubscribe.mockRejectedValueOnce(new Error('subscribe failed'))
await expect(
performTeamSubscriptionCheckout('team_700', 'yearly')
).rejects.toThrow('subscribe failed')
expect(mockTrackBeginCheckout).not.toHaveBeenCalled()
})
it('does nothing off cloud', async () => {
mockIsCloud.value = false

View File

@@ -1,10 +1,16 @@
import { getComfyPlatformBaseUrl } from '@/config/comfyApi'
import { getTeamPlanSlug } from '@/platform/cloud/subscription/constants/teamPlanCreditStops'
import { isCloud } from '@/platform/distribution/types'
import type { PaymentIntentSource } from '@/platform/telemetry/types'
import { workspaceApi } from '@/platform/workspace/api/workspaceApi'
import { trackWorkspaceCheckoutStarted } from '@/platform/workspace/utils/workspaceCheckoutTelemetry'
import type { BillingCycle } from './subscriptionTierRank'
interface PerformTeamSubscriptionCheckoutOptions {
paymentIntentSource?: PaymentIntentSource
}
/**
* Direct team-plan checkout for the marketing `/cloud/subscribe?tier=team` deep
* link: subscribes to the per-credit Team plan at the chosen slider stop and
@@ -22,7 +28,8 @@ import type { BillingCycle } from './subscriptionTierRank'
*/
export async function performTeamSubscriptionCheckout(
teamCreditStopId: string,
billingCycle: BillingCycle
billingCycle: BillingCycle,
options: PerformTeamSubscriptionCheckoutOptions = {}
): Promise<void> {
if (!isCloud) return
@@ -33,6 +40,14 @@ export async function performTeamSubscriptionCheckout(
teamCreditStopId
})
trackWorkspaceCheckoutStarted({
tier: 'team',
cycle: billingCycle,
checkoutType: 'new',
billingOpId: response.billing_op_id,
paymentIntentSource: options.paymentIntentSource
})
if (response.status === 'needs_payment_method') {
// A needs_payment_method response without a URL is unusable: surface it to
// the caller's error handling rather than silently dropping the user home

View File

@@ -30,6 +30,39 @@ describe('TelemetryRegistry', () => {
expect(b.trackSearchQuery).toHaveBeenCalledExactlyOnceWith(payload)
})
it('dispatches trackBeginCheckout with intent metadata to every provider', () => {
const a: TelemetryProvider = { trackBeginCheckout: vi.fn() }
const b: TelemetryProvider = {}
const registry = new TelemetryRegistry()
registry.registerProvider(a)
registry.registerProvider(b)
const metadata = {
user_id: 'user-1',
tier: 'pro' as const,
cycle: 'monthly' as const,
checkout_type: 'new' as const,
payment_intent_source: 'subscribe_to_run' as const
}
registry.trackBeginCheckout(metadata)
expect(a.trackBeginCheckout).toHaveBeenCalledExactlyOnceWith(metadata)
})
it('dispatches trackAddApiCreditButtonClicked with its source', () => {
const provider: TelemetryProvider = {
trackAddApiCreditButtonClicked: vi.fn()
}
const registry = new TelemetryRegistry()
registry.registerProvider(provider)
registry.trackAddApiCreditButtonClicked({ source: 'credits_panel' })
expect(
provider.trackAddApiCreditButtonClicked
).toHaveBeenCalledExactlyOnceWith({ source: 'credits_panel' })
})
it('skips providers that do not implement trackSearchQuery', () => {
const empty: TelemetryProvider = {}
const registry = new TelemetryRegistry()

View File

@@ -1,6 +1,7 @@
import type { AuditLog } from '@/services/customerEventsService'
import type {
AddCreditsClickMetadata,
AuthMetadata,
BeginCheckoutMetadata,
DefaultViewSetMetadata,
@@ -99,8 +100,10 @@ export class TelemetryRegistry implements TelemetryDispatcher {
this.dispatch((provider) => provider.trackMonthlySubscriptionCancelled?.())
}
trackAddApiCreditButtonClicked(): void {
this.dispatch((provider) => provider.trackAddApiCreditButtonClicked?.())
trackAddApiCreditButtonClicked(metadata?: AddCreditsClickMetadata): void {
this.dispatch((provider) =>
provider.trackAddApiCreditButtonClicked?.(metadata)
)
}
trackApiCreditTopupButtonPurchaseClicked(amount: number): void {

View File

@@ -313,6 +313,42 @@ describe('PostHogTelemetryProvider', () => {
)
})
it('captures begin_checkout with intent metadata', async () => {
const provider = createProvider()
await vi.dynamicImportSettled()
provider.trackBeginCheckout({
user_id: 'user-1',
tier: 'pro',
cycle: 'monthly',
checkout_type: 'new',
payment_intent_source: 'subscribe_to_run'
})
expect(hoisted.mockCapture).toHaveBeenCalledWith(
TelemetryEvents.BEGIN_CHECKOUT,
{
user_id: 'user-1',
tier: 'pro',
cycle: 'monthly',
checkout_type: 'new',
payment_intent_source: 'subscribe_to_run'
}
)
})
it('captures add-credit clicks with their source', async () => {
const provider = createProvider()
await vi.dynamicImportSettled()
provider.trackAddApiCreditButtonClicked({ source: 'credits_panel' })
expect(hoisted.mockCapture).toHaveBeenCalledWith(
TelemetryEvents.ADD_API_CREDIT_BUTTON_CLICKED,
{ source: 'credits_panel' }
)
})
it('captures share attribution events', async () => {
const provider = createProvider()
await vi.dynamicImportSettled()

View File

@@ -10,7 +10,9 @@ import { remoteConfig } from '@/platform/remoteConfig/remoteConfig'
import type { RemoteConfig } from '@/platform/remoteConfig/types'
import type {
AddCreditsClickMetadata,
AuthMetadata,
BeginCheckoutMetadata,
DefaultViewSetMetadata,
EnterLinearMetadata,
ShareFlowMetadata,
@@ -350,8 +352,12 @@ export class PostHogTelemetryProvider implements TelemetryProvider {
this.trackEvent(eventName, metadata)
}
trackAddApiCreditButtonClicked(): void {
this.trackEvent(TelemetryEvents.ADD_API_CREDIT_BUTTON_CLICKED)
trackAddApiCreditButtonClicked(metadata?: AddCreditsClickMetadata): void {
this.trackEvent(TelemetryEvents.ADD_API_CREDIT_BUTTON_CLICKED, metadata)
}
trackBeginCheckout(metadata: BeginCheckoutMetadata): void {
this.trackEvent(TelemetryEvents.BEGIN_CHECKOUT, metadata)
}
trackMonthlySubscriptionSucceeded(

View File

@@ -115,6 +115,17 @@ describe('HostTelemetrySink', () => {
)
})
it('forwards add-credit clicks with their source', () => {
new HostTelemetrySink().trackAddApiCreditButtonClicked({
source: 'avatar_menu'
})
expect(state.capture).toHaveBeenCalledExactlyOnceWith(
TelemetryEvents.ADD_API_CREDIT_BUTTON_CLICKED,
{ source: 'avatar_menu' }
)
})
it('does nothing when the host bridge is absent', () => {
delete window.__comfyDesktop2

View File

@@ -10,6 +10,7 @@ import {
import type { AuditLog } from '@/services/customerEventsService'
import type {
AddCreditsClickMetadata,
AuthMetadata,
BeginCheckoutMetadata,
DefaultViewSetMetadata,
@@ -126,8 +127,8 @@ export class HostTelemetrySink implements TelemetryProvider {
this.capture(TelemetryEvents.MONTHLY_SUBSCRIPTION_CANCELLED)
}
trackAddApiCreditButtonClicked(): void {
this.capture(TelemetryEvents.ADD_API_CREDIT_BUTTON_CLICKED)
trackAddApiCreditButtonClicked(metadata?: AddCreditsClickMetadata): void {
this.capture(TelemetryEvents.ADD_API_CREDIT_BUTTON_CLICKED, metadata)
}
trackApiCreditTopupButtonPurchaseClicked(amount: number): void {

View File

@@ -12,12 +12,29 @@
* 3. Check dist/assets/*.js files contain no tracking code
*/
import type { SubscriptionDialogReason } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
import type { TierKey } from '@/platform/cloud/subscription/constants/tierPricing'
import type { BillingCycle } from '@/platform/cloud/subscription/utils/subscriptionTierRank'
import type { AuditLog } from '@/services/customerEventsService'
import type { AppMode } from '@/utils/appMode'
export type PaymentIntentSource =
| 'subscription_required'
| 'out_of_credits'
| 'top_up_blocked'
| 'deep_link'
| 'subscribe_to_run'
| 'subscribe_now_button'
| 'upgrade_to_add_credits'
| 'settings_billing_panel'
| 'avatar_menu_plans'
| 'team_members_panel'
| 'invite_member_upsell'
| 'upload_model_upgrade'
| 'team_upgrade_resume'
export type SubscriptionCheckoutType = 'new' | 'change'
export type SubscriptionCheckoutTier = TierKey | 'team'
/**
* Authentication metadata for sign-up tracking
*/
@@ -426,16 +443,23 @@ export interface CheckoutAttributionMetadata {
export interface SubscriptionMetadata {
current_tier?: string
reason?: SubscriptionDialogReason
reason?: PaymentIntentSource
}
export interface AddCreditsClickMetadata {
source: 'credits_panel' | 'avatar_menu' | 'settings_billing_panel'
}
export interface BeginCheckoutMetadata
extends Record<string, unknown>, CheckoutAttributionMetadata {
user_id: string
tier: TierKey
tier: SubscriptionCheckoutTier
cycle: BillingCycle
checkout_type: 'new' | 'change'
checkout_type: SubscriptionCheckoutType
checkout_attempt_id?: string
billing_op_id?: string
previous_tier?: TierKey
payment_intent_source?: PaymentIntentSource
}
interface EcommerceItemMetadata {
@@ -457,8 +481,9 @@ export interface SubscriptionSuccessMetadata extends Record<string, unknown> {
checkout_attempt_id: string
tier: TierKey
cycle: BillingCycle
checkout_type: 'new' | 'change'
checkout_type: SubscriptionCheckoutType
previous_tier?: TierKey
payment_intent_source?: PaymentIntentSource
value: number
currency: string
ecommerce: EcommerceMetadata
@@ -489,7 +514,7 @@ export interface TelemetryProvider {
metadata?: SubscriptionSuccessMetadata
): void
trackMonthlySubscriptionCancelled?(): void
trackAddApiCreditButtonClicked?(): void
trackAddApiCreditButtonClicked?(metadata?: AddCreditsClickMetadata): void
trackApiCreditTopupButtonPurchaseClicked?(amount: number): void
trackApiCreditTopupSucceeded?(): void
trackWorkspaceInviteSent?(metadata: WorkspaceInviteMetadata): void

View File

@@ -321,7 +321,7 @@ const handleOpenWorkspaceSettings = () => {
}
const handleOpenPlansAndPricing = () => {
subscriptionDialog.showPricingTable()
subscriptionDialog.showPricingTable({ reason: 'avatar_menu_plans' })
emit('close')
}
@@ -336,13 +336,12 @@ const handleOpenPlanAndCreditsSettings = () => {
}
const handleUpgradeToAddCredits = () => {
subscriptionDialog.showPricingTable()
subscriptionDialog.showPricingTable({ reason: 'upgrade_to_add_credits' })
emit('close')
}
const handleTopUp = () => {
// Track purchase credits entry from avatar popover
useTelemetry()?.trackAddApiCreditButtonClicked()
useTelemetry()?.trackAddApiCreditButtonClicked({ source: 'avatar_menu' })
dialogService.showTopUpCreditsDialog()
emit('close')
}

View File

@@ -391,12 +391,13 @@ const showZeroState = computed(
)
function handleSubscribeWorkspace() {
showSubscriptionDialog()
showSubscriptionDialog({ reason: 'settings_billing_panel' })
}
function handleUpgrade() {
if (isFreeTierPlan.value) showPricingTable()
else showSubscriptionDialog()
if (isFreeTierPlan.value)
showPricingTable({ reason: 'settings_billing_panel' })
else showSubscriptionDialog({ reason: 'settings_billing_panel' })
}
function handleViewMoreDetails() {

View File

@@ -113,7 +113,7 @@ import { cn } from '@comfyorg/tailwind-utils'
import { useEventListener } from '@vueuse/core'
import Button from '@/components/ui/button/Button.vue'
import type { SubscriptionDialogReason } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
import type { PaymentIntentSource } from '@/platform/telemetry/types'
import { useSubscriptionCheckout } from '@/platform/workspace/composables/useSubscriptionCheckout'
import SubscriptionAddPaymentPreviewWorkspace from './SubscriptionAddPaymentPreviewWorkspace.vue'
@@ -123,7 +123,7 @@ import UnifiedPricingTable from './UnifiedPricingTable.vue'
const { onClose, reason, initialPlanMode } = defineProps<{
onClose: () => void
reason?: SubscriptionDialogReason
reason?: PaymentIntentSource
initialPlanMode?: 'personal' | 'team'
}>()
@@ -152,7 +152,7 @@ const {
handleConfirmTransition,
handleTeamSubscribe,
handleResubscribe
} = useSubscriptionCheckout(emit)
} = useSubscriptionCheckout(emit, reason)
// Backspace mirrors the back arrow on the confirm step, but never while an
// editable element is focused (let it delete text there).

View File

@@ -5,7 +5,7 @@ import { beforeEach, describe, expect, it, vi } from 'vitest'
import { ref } from 'vue'
import { createI18n } from 'vue-i18n'
import type { SubscriptionDialogReason } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
import type { PaymentIntentSource } from '@/platform/telemetry/types'
import SubscriptionRequiredDialogContentWorkspace from './SubscriptionRequiredDialogContentWorkspace.vue'
@@ -17,25 +17,10 @@ const mockHandleResubscribe = vi.fn()
const mockHandleSuccessClose = vi.fn()
const mockCheckoutStep = ref<'pricing' | 'preview' | 'success'>('pricing')
const mockPreviewData = ref<{ transition_type: string } | null>(null)
const mockUseSubscriptionCheckout = vi.hoisted(() => vi.fn())
vi.mock('@/platform/workspace/composables/useSubscriptionCheckout', () => ({
useSubscriptionCheckout: () => ({
checkoutStep: mockCheckoutStep,
isLoadingPreview: ref(false),
loadingTier: ref(null),
isSubscribing: ref(false),
isResubscribing: ref(false),
previewData: mockPreviewData,
selectedTierKey: ref('standard'),
selectedBillingCycle: ref('yearly'),
isPolling: ref(false),
handleSubscribeClick: mockHandleSubscribeClick,
handleBackToPricing: mockHandleBackToPricing,
handleAddCreditCard: mockHandleAddCreditCard,
handleConfirmTransition: mockHandleConfirmTransition,
handleResubscribe: mockHandleResubscribe,
handleSuccessClose: mockHandleSuccessClose
})
useSubscriptionCheckout: mockUseSubscriptionCheckout
}))
const i18n = createI18n({
@@ -91,7 +76,7 @@ const SuccessStub = {
function renderComponent(
props: {
onClose?: () => void
reason?: SubscriptionDialogReason
reason?: PaymentIntentSource
isPersonal?: boolean
} = {}
) {
@@ -121,6 +106,23 @@ function renderComponent(
describe('SubscriptionRequiredDialogContentWorkspace', () => {
beforeEach(() => {
vi.clearAllMocks()
mockUseSubscriptionCheckout.mockReturnValue({
checkoutStep: mockCheckoutStep,
isLoadingPreview: ref(false),
loadingTier: ref(null),
isSubscribing: ref(false),
isResubscribing: ref(false),
previewData: mockPreviewData,
selectedTierKey: ref('standard'),
selectedBillingCycle: ref('yearly'),
isPolling: ref(false),
handleSubscribeClick: mockHandleSubscribeClick,
handleBackToPricing: mockHandleBackToPricing,
handleAddCreditCard: mockHandleAddCreditCard,
handleConfirmTransition: mockHandleConfirmTransition,
handleResubscribe: mockHandleResubscribe,
handleSuccessClose: mockHandleSuccessClose
})
mockCheckoutStep.value = 'pricing'
mockPreviewData.value = null
})
@@ -132,6 +134,15 @@ describe('SubscriptionRequiredDialogContentWorkspace', () => {
expect(screen.queryByTestId('transition-preview')).not.toBeInTheDocument()
})
it('passes the reason into subscription checkout', () => {
renderComponent({ reason: 'out_of_credits' })
expect(mockUseSubscriptionCheckout).toHaveBeenCalledWith(
expect.any(Function),
'out_of_credits'
)
})
it('shows the team workspace header by default', () => {
renderComponent()
expect(screen.getByText('Team Workspace')).toBeInTheDocument()

View File

@@ -116,7 +116,7 @@
import { cn } from '@comfyorg/tailwind-utils'
import Button from '@/components/ui/button/Button.vue'
import type { SubscriptionDialogReason } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
import type { PaymentIntentSource } from '@/platform/telemetry/types'
import { useSubscriptionCheckout } from '@/platform/workspace/composables/useSubscriptionCheckout'
import PricingTableWorkspace from './PricingTableWorkspace.vue'
@@ -130,7 +130,7 @@ const {
isPersonal = false
} = defineProps<{
onClose: () => void
reason?: SubscriptionDialogReason
reason?: PaymentIntentSource
isPersonal?: boolean
}>()
@@ -154,7 +154,7 @@ const {
handleConfirmTransition,
handleResubscribe,
handleSuccessClose
} = useSubscriptionCheckout(emit)
} = useSubscriptionCheckout(emit, reason)
</script>
<style scoped>

View File

@@ -61,6 +61,9 @@ function onDismiss() {
function onUpgrade() {
dialogStore.closeDialog({ key: 'invite-member-upsell' })
subscriptionDialog.show({ planMode: 'team' })
subscriptionDialog.show({
planMode: 'team',
reason: 'invite_member_upsell'
})
}
</script>

View File

@@ -277,7 +277,7 @@ export function useMembersPanel() {
}
function showTeamPlans() {
subscriptionDialog.show({ planMode: 'team' })
subscriptionDialog.show({ planMode: 'team', reason: 'team_members_panel' })
}
return {

View File

@@ -1,8 +1,9 @@
import { createTestingPinia } from '@pinia/testing'
import { setActivePinia } from 'pinia'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { computed } from 'vue'
import { computed, reactive } from 'vue'
import type { PaymentIntentSource } from '@/platform/telemetry/types'
import type { Plan } from '@/platform/workspace/api/workspaceApi'
import { findPlanSlug } from './useSubscriptionCheckout'
@@ -75,7 +76,9 @@ const {
mockPlans,
mockResubscribe,
mockToastAdd,
mockStartOperation
mockStartOperation,
mockTrackBeginCheckout,
mockUserId
} = vi.hoisted(() => ({
mockSubscribe: vi.fn(),
mockPreviewSubscribe: vi.fn(),
@@ -84,7 +87,9 @@ const {
mockPlans: { value: [] as Plan[] },
mockResubscribe: vi.fn(),
mockToastAdd: vi.fn(),
mockStartOperation: vi.fn()
mockStartOperation: vi.fn(),
mockTrackBeginCheckout: vi.fn(),
mockUserId: { value: 'user-1' as string | null }
}))
vi.mock('@/composables/billing/useBillingContext', () => ({
@@ -119,7 +124,14 @@ vi.mock('primevue/usetoast', () => ({
}))
vi.mock('@/platform/telemetry', () => ({
useTelemetry: () => ({ trackMonthlySubscriptionSucceeded: vi.fn() })
useTelemetry: () => ({
trackMonthlySubscriptionSucceeded: vi.fn(),
trackBeginCheckout: mockTrackBeginCheckout
})
}))
vi.mock('@/stores/authStore', () => ({
useAuthStore: () => reactive({ userId: computed(() => mockUserId.value) })
}))
vi.mock('vue-i18n', async (importOriginal) => {
@@ -135,10 +147,10 @@ vi.mock('vue-i18n', async (importOriginal) => {
describe('useSubscriptionCheckout', () => {
let emit: ReturnType<typeof vi.fn>
async function setup() {
async function setup(paymentIntentSource?: PaymentIntentSource) {
const { useSubscriptionCheckout } =
await import('./useSubscriptionCheckout')
return useSubscriptionCheckout(emit as never)
return useSubscriptionCheckout(emit as never, paymentIntentSource)
}
beforeEach(() => {
@@ -146,6 +158,7 @@ describe('useSubscriptionCheckout', () => {
vi.clearAllMocks()
mockPlans.value = allPlans()
mockStartOperation.mockResolvedValue({ status: 'succeeded' })
mockUserId.value = 'user-1'
emit = vi.fn()
})
@@ -459,6 +472,13 @@ describe('useSubscriptionCheckout', () => {
cancelUrl: 'https://platform.comfy.org/payment/failed'
})
expect(checkout.checkoutStep.value).toBe('success')
expect(mockTrackBeginCheckout).toHaveBeenCalledWith(
expect.objectContaining({
tier: 'team',
checkout_type: 'new',
billing_op_id: 'op-team-1'
})
)
})
it('uses the annual plan slug for the yearly cycle', async () => {
@@ -553,6 +573,39 @@ describe('useSubscriptionCheckout', () => {
detail: 'Team payment failed'
})
)
expect(mockTrackBeginCheckout).not.toHaveBeenCalled()
})
it('keeps team checkout_type as change when the preview request fails', async () => {
const checkout = await setup()
mockPreviewSubscribe.mockRejectedValueOnce(new Error('not supported'))
await checkout.handleSubscribeTeamClick({
stop: {
id: 'team_1400',
usd: 1400,
credits: 295_400,
discountedUsd: 1295
},
billingCycle: 'monthly',
isChange: true
})
mockSubscribe.mockResolvedValueOnce({
status: 'subscribed',
billing_op_id: 'op-team-change'
})
mockFetchStatus.mockResolvedValueOnce(undefined)
mockFetchBalance.mockResolvedValueOnce(undefined)
await checkout.handleTeamSubscribe()
expect(mockTrackBeginCheckout).toHaveBeenCalledWith(
expect.objectContaining({
tier: 'team',
cycle: 'monthly',
checkout_type: 'change',
billing_op_id: 'op-team-change'
})
)
})
})
@@ -603,6 +656,47 @@ describe('useSubscriptionCheckout', () => {
expect(checkout.checkoutStep.value).toBe('success')
})
it('skips begin_checkout when no user id is available', async () => {
mockUserId.value = null
const checkout = await setup('subscribe_to_run')
checkout.selectedTierKey.value = 'standard'
checkout.selectedBillingCycle.value = 'yearly'
mockSubscribe.mockResolvedValueOnce({
status: 'subscribed',
billing_op_id: 'op-1'
})
mockFetchStatus.mockResolvedValueOnce(undefined)
mockFetchBalance.mockResolvedValueOnce(undefined)
await checkout.handleAddCreditCard()
expect(mockTrackBeginCheckout).not.toHaveBeenCalled()
mockUserId.value = 'user-1'
})
it('fires begin_checkout carrying the payment intent source', async () => {
const checkout = await setup('subscribe_to_run')
checkout.selectedTierKey.value = 'standard'
checkout.selectedBillingCycle.value = 'yearly'
mockSubscribe.mockResolvedValueOnce({
status: 'subscribed',
billing_op_id: 'op-1'
})
mockFetchStatus.mockResolvedValueOnce(undefined)
mockFetchBalance.mockResolvedValueOnce(undefined)
await checkout.handleAddCreditCard()
expect(mockTrackBeginCheckout).toHaveBeenCalledWith({
user_id: 'user-1',
tier: 'standard',
cycle: 'yearly',
checkout_type: 'new',
billing_op_id: 'op-1',
payment_intent_source: 'subscribe_to_run'
})
})
it('opens payment URL when needs_payment_method', async () => {
const checkout = await setup()
checkout.selectedTierKey.value = 'standard'
@@ -720,6 +814,7 @@ describe('useSubscriptionCheckout', () => {
detail: 'Payment failed'
})
)
expect(mockTrackBeginCheckout).not.toHaveBeenCalled()
})
})

View File

@@ -9,16 +9,26 @@ import type { TeamPlanSelection } from '@/platform/cloud/subscription/constants/
import type { TierKey } from '@/platform/cloud/subscription/constants/tierPricing'
import type { BillingCycle } from '@/platform/cloud/subscription/utils/subscriptionTierRank'
import { useTelemetry } from '@/platform/telemetry'
import type {
PaymentIntentSource,
SubscriptionCheckoutType
} from '@/platform/telemetry/types'
import type {
Plan,
PreviewSubscribeResponse,
SubscribeResponse
} from '@/platform/workspace/api/workspaceApi'
import { useBillingOperationStore } from '@/platform/workspace/stores/billingOperationStore'
import { trackWorkspaceCheckoutStarted } from '@/platform/workspace/utils/workspaceCheckoutTelemetry'
type CheckoutStep = 'pricing' | 'preview' | 'success'
type CheckoutTierKey = Exclude<TierKey, 'free' | 'founder'>
interface SelectedTeamCheckout {
stop: TeamPlanSelection
checkoutType: SubscriptionCheckoutType
}
/**
* Which screen the `preview` step shows. Only a change prorates: a team change
* carries `previewData` (handleSubscribeTeamClick sets it solely for an immediate
@@ -45,9 +55,12 @@ export function findPlanSlug(
return plan?.slug ?? null
}
export function useSubscriptionCheckout(emit: {
(e: 'close', subscribed: boolean): void
}) {
export function useSubscriptionCheckout(
emit: {
(e: 'close', subscribed: boolean): void
},
paymentIntentSource?: PaymentIntentSource
) {
const { t } = useI18n()
const toast = useToast()
const {
@@ -68,13 +81,16 @@ export function useSubscriptionCheckout(emit: {
const isResubscribing = ref(false)
const previewData = ref<PreviewSubscribeResponse | null>(null)
const selectedTierKey = ref<CheckoutTierKey | null>(null)
const selectedTeamStop = ref<TeamPlanSelection | null>(null)
const selectedTeamCheckout = ref<SelectedTeamCheckout | null>(null)
const selectedBillingCycle = ref<BillingCycle>('yearly')
const isPolling = computed(() => billingOperationStore.hasPendingOperations)
const isTeamCheckout = computed(() => selectedTeamStop.value !== null)
const selectedTeamStop = computed(
() => selectedTeamCheckout.value?.stop ?? null
)
const isTeamCheckout = computed(() => selectedTeamCheckout.value !== null)
const previewVariant = computed<PreviewVariant>(() => {
if (selectedTeamStop.value) {
if (selectedTeamCheckout.value) {
return previewData.value ? 'team-change' : 'team-new'
}
if (previewData.value) {
@@ -154,7 +170,10 @@ export function useSubscriptionCheckout(emit: {
billingCycle: BillingCycle
isChange?: boolean
}) {
selectedTeamStop.value = payload.stop
selectedTeamCheckout.value = {
stop: payload.stop,
checkoutType: payload.isChange ? 'change' : 'new'
}
selectedBillingCycle.value = payload.billingCycle
selectedTierKey.value = null
previewData.value = null
@@ -182,7 +201,7 @@ export function useSubscriptionCheckout(emit: {
function handleBackToPricing() {
checkoutStep.value = 'pricing'
previewData.value = null
selectedTeamStop.value = null
selectedTeamCheckout.value = null
}
function handleSuccessClose() {
@@ -190,20 +209,34 @@ export function useSubscriptionCheckout(emit: {
}
async function handleSubscription() {
if (!selectedTierKey.value) return
const tierKey = selectedTierKey.value
if (!tierKey) return
const billingCycle = selectedBillingCycle.value
const checkoutType =
previewData.value &&
previewData.value.transition_type !== 'new_subscription'
? 'change'
: 'new'
isSubscribing.value = true
try {
const planSlug = getApiPlanSlug(
selectedTierKey.value,
selectedBillingCycle.value
)
const planSlug = getApiPlanSlug(tierKey, billingCycle)
if (!planSlug) return
const response = await subscribe(planSlug, {
returnUrl: `${getComfyPlatformBaseUrl()}/payment/success`,
cancelUrl: `${getComfyPlatformBaseUrl()}/payment/failed`
})
if (response) {
trackWorkspaceCheckoutStarted({
tier: tierKey,
cycle: billingCycle,
checkoutType,
billingOpId: response.billing_op_id,
paymentIntentSource
})
}
await handleSubscribeResponse(response)
} catch (error) {
showSubscribeError(error)
@@ -269,8 +302,8 @@ export function useSubscriptionCheckout(emit: {
}
async function handleTeamSubscription() {
const stop = selectedTeamStop.value
if (!stop?.id) {
const teamCheckout = selectedTeamCheckout.value
if (!teamCheckout?.stop.id) {
toast.add({
severity: 'error',
summary: t('subscription.teamPlan.name'),
@@ -279,16 +312,28 @@ export function useSubscriptionCheckout(emit: {
return
}
const { stop, checkoutType } = teamCheckout
const billingCycle = selectedBillingCycle.value
isSubscribing.value = true
try {
const planSlug = getTeamPlanSlug(selectedBillingCycle.value)
const planSlug = getTeamPlanSlug(billingCycle)
const response = await subscribe(planSlug, {
teamCreditStopId: stop.id,
billingCycle: selectedBillingCycle.value,
billingCycle,
returnUrl: `${getComfyPlatformBaseUrl()}/payment/success`,
cancelUrl: `${getComfyPlatformBaseUrl()}/payment/failed`
})
if (response) {
trackWorkspaceCheckoutStarted({
tier: 'team',
cycle: billingCycle,
checkoutType,
billingOpId: response.billing_op_id,
paymentIntentSource
})
}
await handleSubscribeResponse(response)
} catch (error) {
showSubscribeError(error)

View File

@@ -2,6 +2,7 @@ import { computed, ref, shallowRef } from 'vue'
import { useBillingPlans } from '@/platform/cloud/subscription/composables/useBillingPlans'
import { useSubscriptionDialog } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
import type { SubscriptionDialogOptions } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
import type {
BillingBalanceResponse,
BillingStatusResponse,
@@ -275,12 +276,12 @@ export function useWorkspaceBilling(): BillingState & BillingActions {
async function requireActiveSubscription(): Promise<void> {
await fetchStatus()
if (!isActiveSubscription.value) {
subscriptionDialog.show()
subscriptionDialog.show({ reason: 'subscription_required' })
}
}
function showSubscriptionDialog(): void {
subscriptionDialog.show()
function showSubscriptionDialog(options?: SubscriptionDialogOptions): void {
subscriptionDialog.show(options)
}
return {

View File

@@ -0,0 +1,38 @@
import { useTelemetry } from '@/platform/telemetry'
import type {
PaymentIntentSource,
SubscriptionCheckoutTier,
SubscriptionCheckoutType
} from '@/platform/telemetry/types'
import type { BillingCycle } from '@/platform/cloud/subscription/utils/subscriptionTierRank'
import { useAuthStore } from '@/stores/authStore'
interface TrackWorkspaceCheckoutStartedOptions {
tier: SubscriptionCheckoutTier
cycle: BillingCycle
checkoutType: SubscriptionCheckoutType
billingOpId: string
paymentIntentSource?: PaymentIntentSource
}
export function trackWorkspaceCheckoutStarted({
tier,
cycle,
checkoutType,
billingOpId,
paymentIntentSource
}: TrackWorkspaceCheckoutStartedOptions) {
const { userId } = useAuthStore()
if (!userId) return
useTelemetry()?.trackBeginCheckout({
user_id: userId,
tier,
cycle,
checkout_type: checkoutType,
billing_op_id: billingOpId,
...(paymentIntentSource
? { payment_intent_source: paymentIntentSource }
: {})
})
}

View File

@@ -0,0 +1,208 @@
import { createTestingPinia } from '@pinia/testing'
import { render, screen, within } from '@testing-library/vue'
import { setActivePinia } from 'pinia'
import { createI18n } from 'vue-i18n'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import type { NodeError } from '@/schemas/apiSchema'
import LinearControls from '@/renderer/extensions/linearMode/LinearControls.vue'
import { LINEAR_RUN_ERROR_WARNING_DESCRIPTION_ID } from '@/renderer/extensions/linearMode/linearRunErrorWarningIds'
import { useAppModeStore } from '@/stores/appModeStore'
import { useExecutionErrorStore } from '@/stores/executionErrorStore'
import { toNodeId } from '@/types/nodeId'
const billingMock = vi.hoisted(() => ({
isActiveSubscription: true
}))
const overlayMock = vi.hoisted(() => ({
overlayMessage: 'KSampler is missing a required input: model',
overlayTitle: 'Required input missing'
}))
vi.mock('@/composables/billing/useBillingContext', () => ({
useBillingContext: () => ({
isActiveSubscription: billingMock.isActiveSubscription
})
}))
vi.mock('@/components/error/useErrorOverlayState', () => ({
useErrorOverlayState: () => ({
overlayMessage: overlayMock.overlayMessage,
overlayTitle: overlayMock.overlayTitle
})
}))
const i18n = createI18n({
legacy: false,
locale: 'en',
messages: {
en: {
linearMode: {
error: {
goto: 'Show errors in graph'
},
mobileNoWorkflow: 'No workflow',
runCount: 'Run count',
viewJob: 'View job'
},
menu: {
run: 'Run'
},
menuLabels: {
publish: 'Publish'
},
queue: {
jobAddedToQueue: 'Job added to queue',
jobQueueing: 'Queueing'
}
}
}
})
const nodeErrors: Record<string, NodeError> = {
'1': {
class_type: 'TestNode',
dependent_outputs: [],
errors: [
{
type: 'required_input_missing',
message: 'Missing input',
details: '',
extra_info: { input_name: 'prompt' }
}
]
}
}
function renderControls({
hasError = false,
isActiveSubscription = true,
mobile = false
}: {
hasError?: boolean
isActiveSubscription?: boolean
mobile?: boolean
} = {}) {
billingMock.isActiveSubscription = isActiveSubscription
const pinia = createTestingPinia({
createSpy: vi.fn,
stubActions: false
})
setActivePinia(pinia)
useAppModeStore().selectedOutputs = [toNodeId(1)]
if (hasError) {
useExecutionErrorStore().lastNodeErrors = nodeErrors
}
const toastTarget = document.createElement('div')
return render(LinearControls, {
props: { mobile, toastTo: toastTarget },
global: {
plugins: [pinia, i18n],
stubs: {
AppModeWidgetList: true,
Loader: true,
PartnerNodesList: true,
Popover: {
template: '<div><slot name="button" /><slot /></div>'
},
ScrubableNumberInput: true,
SubscribeToRunButton: true
}
}
})
}
describe('LinearControls', () => {
beforeEach(() => {
vi.clearAllMocks()
billingMock.isActiveSubscription = true
overlayMock.overlayMessage = 'KSampler is missing a required input: model'
overlayMock.overlayTitle = 'Required input missing'
})
it.for([
{ label: 'desktop', mobile: false },
{ label: 'mobile', mobile: true }
])('shows a workflow error warning in $label controls', ({ mobile }) => {
renderControls({ hasError: true, mobile })
const warning = screen.getByRole('status')
expect(
within(warning).getByText('Required input missing')
).toBeInTheDocument()
expect(
within(warning).getByText('KSampler is missing a required input: model')
).toBeInTheDocument()
expect(
within(warning).getByRole('button', { name: 'Show errors in graph' })
).toBeInTheDocument()
expect(within(warning).queryByLabelText('Close')).not.toBeInTheDocument()
const runButton = screen.getByRole('button', { name: 'Run' })
expect(runButton).toHaveAttribute(
'aria-describedby',
LINEAR_RUN_ERROR_WARNING_DESCRIPTION_ID
)
const description = screen.getByTestId(
'linear-validation-warning-description'
)
expect(description).toHaveAttribute(
'id',
LINEAR_RUN_ERROR_WARNING_DESCRIPTION_ID
)
expect(description).toHaveTextContent('Required input missing')
expect(description).toHaveTextContent(
'KSampler is missing a required input: model'
)
expect(description).not.toHaveTextContent('Show errors in graph')
})
it.for([
{ label: 'desktop', mobile: false },
{ label: 'mobile', mobile: true }
])(
'does not show the workflow error warning in $label controls without graph errors',
({ mobile }) => {
renderControls({ mobile })
expect(screen.queryByRole('status')).not.toBeInTheDocument()
expect(
screen.queryByRole('button', { name: 'Show errors in graph' })
).not.toBeInTheDocument()
expect(screen.getByRole('button', { name: 'Run' })).not.toHaveAttribute(
'aria-describedby'
)
}
)
it.for([
{ label: 'desktop', mobile: false },
{ label: 'mobile', mobile: true }
])(
'does not show the workflow error warning in $label controls without an active subscription',
({ mobile }) => {
renderControls({
hasError: true,
isActiveSubscription: false,
mobile
})
expect(screen.queryByRole('status')).not.toBeInTheDocument()
}
)
it('does not show the warning when the error copy is empty', () => {
overlayMock.overlayMessage = ''
renderControls({ hasError: true })
expect(screen.queryByRole('status')).not.toBeInTheDocument()
expect(screen.getByRole('button', { name: 'Run' })).not.toHaveAttribute(
'aria-describedby'
)
})
})

View File

@@ -1,10 +1,11 @@
<script setup lang="ts">
import { useTimeout } from '@vueuse/core'
import { storeToRefs } from 'pinia'
import { ref, useTemplateRef } from 'vue'
import { computed, ref, toValue, useTemplateRef } from 'vue'
import { useI18n } from 'vue-i18n'
import AppModeWidgetList from '@/components/builder/AppModeWidgetList.vue'
import { useErrorOverlayState } from '@/components/error/useErrorOverlayState'
import Loader from '@/components/loader/Loader.vue'
import ScrubableNumberInput from '@/components/common/ScrubableNumberInput.vue'
import Popover from '@/components/ui/Popover.vue'
@@ -14,11 +15,15 @@ import SubscribeToRunButton from '@/platform/cloud/subscription/components/Subsc
import { useSettingStore } from '@/platform/settings/settingStore'
import { useTelemetry } from '@/platform/telemetry'
import { useWorkflowStore } from '@/platform/workflow/management/stores/workflowStore'
import LinearRunErrorWarning from '@/renderer/extensions/linearMode/LinearRunErrorWarning.vue'
import { LINEAR_RUN_ERROR_WARNING_DESCRIPTION_ID } from '@/renderer/extensions/linearMode/linearRunErrorWarningIds'
import PartnerNodesList from '@/renderer/extensions/linearMode/PartnerNodesList.vue'
import { useCommandStore } from '@/stores/commandStore'
import { useQueueSettingsStore } from '@/stores/queueStore'
import { useAppMode } from '@/composables/useAppMode'
import { useAppModeStore } from '@/stores/appModeStore'
import { useExecutionErrorStore } from '@/stores/executionErrorStore'
const { t } = useI18n()
const commandStore = useCommandStore()
const { batchCount } = storeToRefs(useQueueSettingsStore())
@@ -28,6 +33,8 @@ const workflowStore = useWorkflowStore()
const { isBuilderMode } = useAppMode()
const appModeStore = useAppModeStore()
const { hasOutputs } = storeToRefs(appModeStore)
const { hasAnyError } = storeToRefs(useExecutionErrorStore())
const { overlayMessage } = useErrorOverlayState()
const { toastTo, mobile } = defineProps<{
toastTo?: string | HTMLElement
@@ -43,6 +50,13 @@ const { ready: jobToastTimeout, start: resetJobToastTimeout } = useTimeout(
{ controls: true, immediate: false }
)
const widgetListRef = useTemplateRef('widgetListRef')
const linearRunButtonTestId = 'linear-run-button'
const showRunErrorWarning = computed(
() =>
hasAnyError.value &&
toValue(isActiveSubscription) &&
toValue(overlayMessage).trim().length > 0
)
//TODO: refactor out of this file.
//code length is small, but changes should propagate
@@ -134,9 +148,10 @@ function handleDragDrop() {
<PartnerNodesList v-if="!mobile" />
<section
v-if="mobile"
data-testid="linear-run-button"
:data-testid="linearRunButtonTestId"
class="border-t border-node-component-border p-4 pb-6"
>
<LinearRunErrorWarning v-if="showRunErrorWarning" />
<SubscribeToRunButton
v-if="!isActiveSubscription"
class="mt-4 w-full"
@@ -166,18 +181,24 @@ function handleDragDrop() {
variant="primary"
class="grow"
size="lg"
:aria-describedby="
showRunErrorWarning
? LINEAR_RUN_ERROR_WARNING_DESCRIPTION_ID
: undefined
"
@click="runButtonClick"
>
<i class="icon-[lucide--play]" />
<i aria-hidden="true" class="icon-[lucide--play]" />
{{ t('menu.run') }}
</Button>
</div>
</section>
<section
v-else
data-testid="linear-run-button"
:data-testid="linearRunButtonTestId"
class="border-t border-node-component-border p-4 pb-6"
>
<LinearRunErrorWarning v-if="showRunErrorWarning" />
<div
class="m-1 mb-2 text-node-component-slot-text"
v-text="t('linearMode.runCount')"
@@ -198,9 +219,14 @@ function handleDragDrop() {
variant="primary"
class="mt-4 w-full text-sm"
size="lg"
:aria-describedby="
showRunErrorWarning
? LINEAR_RUN_ERROR_WARNING_DESCRIPTION_ID
: undefined
"
@click="runButtonClick"
>
<i class="icon-[lucide--play]" />
<i aria-hidden="true" class="icon-[lucide--play]" />
{{ t('menu.run') }}
</Button>
</section>

View File

@@ -0,0 +1,92 @@
import { render, screen } from '@testing-library/vue'
import userEvent from '@testing-library/user-event'
import { createI18n } from 'vue-i18n'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import LinearRunErrorWarning from '@/renderer/extensions/linearMode/LinearRunErrorWarning.vue'
import { LINEAR_RUN_ERROR_WARNING_DESCRIPTION_ID } from '@/renderer/extensions/linearMode/linearRunErrorWarningIds'
const mocks = vi.hoisted(() => ({
overlayMessage: 'KSampler is missing a required input: model',
overlayTitle: 'Required input missing',
viewErrorsInGraph: vi.fn()
}))
vi.mock('@/components/error/useErrorOverlayState', () => ({
useErrorOverlayState: () => ({
overlayMessage: mocks.overlayMessage,
overlayTitle: mocks.overlayTitle
})
}))
vi.mock('@/composables/useViewErrorsInGraph', () => ({
useViewErrorsInGraph: () => ({
viewErrorsInGraph: mocks.viewErrorsInGraph
})
}))
const i18n = createI18n({
legacy: false,
locale: 'en',
messages: {
en: {
linearMode: {
error: {
goto: 'Show errors in graph'
}
}
}
}
})
function renderWarning() {
const user = userEvent.setup()
const result = render(LinearRunErrorWarning, {
global: { plugins: [i18n] }
})
return { ...result, user }
}
describe('LinearRunErrorWarning', () => {
beforeEach(() => {
mocks.viewErrorsInGraph.mockReset()
})
it('shows the current error overlay title and message without a close action', () => {
renderWarning()
const warning = screen.getByRole('status')
expect(warning).toHaveTextContent('Required input missing')
expect(warning).toHaveTextContent(
'KSampler is missing a required input: model'
)
expect(screen.getByText('Required input missing')).toHaveAttribute(
'title',
'Required input missing'
)
const description = screen.getByTestId(
'linear-validation-warning-description'
)
expect(description).toHaveAttribute(
'id',
LINEAR_RUN_ERROR_WARNING_DESCRIPTION_ID
)
expect(description).toHaveTextContent('Required input missing')
expect(description).toHaveTextContent(
'KSampler is missing a required input: model'
)
expect(description).not.toHaveTextContent('Show errors in graph')
expect(screen.queryByLabelText('Close')).not.toBeInTheDocument()
})
it('opens graph errors when the action is clicked', async () => {
const { user } = renderWarning()
await user.click(
screen.getByRole('button', { name: 'Show errors in graph' })
)
expect(mocks.viewErrorsInGraph).toHaveBeenCalledOnce()
})
})

View File

@@ -0,0 +1,63 @@
<script setup lang="ts">
import { useI18n } from 'vue-i18n'
import Button from '@/components/ui/button/Button.vue'
import { useErrorOverlayState } from '@/components/error/useErrorOverlayState'
import { useViewErrorsInGraph } from '@/composables/useViewErrorsInGraph'
import { LINEAR_RUN_ERROR_WARNING_DESCRIPTION_ID } from '@/renderer/extensions/linearMode/linearRunErrorWarningIds'
const { t } = useI18n()
const { viewErrorsInGraph } = useViewErrorsInGraph()
const { overlayMessage, overlayTitle } = useErrorOverlayState()
</script>
<template>
<div
role="status"
data-testid="linear-validation-warning"
class="mb-3 flex w-full flex-col gap-2 overflow-hidden rounded-lg border border-l-4 border-border-default border-l-destructive-background bg-base-background p-3 shadow-interface transition-colors duration-200 ease-in-out"
>
<div
:id="LINEAR_RUN_ERROR_WARNING_DESCRIPTION_ID"
data-testid="linear-validation-warning-description"
class="flex flex-col gap-2"
>
<div class="flex w-full items-start gap-2">
<i
aria-hidden="true"
class="mt-0.5 icon-[lucide--circle-x] size-4 shrink-0 text-destructive-background"
/>
<span
class="min-w-0 flex-1 truncate text-sm text-base-foreground"
:title="overlayTitle"
>
{{ overlayTitle }}
</span>
</div>
<div
class="flex w-full items-start gap-2"
data-testid="linear-validation-warning-message"
>
<span class="size-4 shrink-0" aria-hidden="true" />
<p
class="m-0 line-clamp-3 min-w-0 flex-1 text-sm/snug wrap-break-word whitespace-pre-wrap text-muted-foreground"
>
{{ overlayMessage }}
</p>
</div>
</div>
<div class="flex w-full items-center justify-end pt-2">
<Button
variant="secondary"
size="unset"
class="min-h-8 rounded-lg px-3 py-2 text-xs font-normal"
data-testid="linear-view-errors"
@click="viewErrorsInGraph"
>
{{ t('linearMode.error.goto') }}
</Button>
</div>
</div>
</template>

View File

@@ -0,0 +1,2 @@
export const LINEAR_RUN_ERROR_WARNING_DESCRIPTION_ID =
'linear-run-error-warning'

View File

@@ -0,0 +1,57 @@
import { describe, expect, it } from 'vitest'
import {
getComboSpecComboOptions,
getInputSpecType,
isComboInputSpec,
isComboInputSpecV1,
isComboInputSpecV2,
isFloatInputSpec,
isIntInputSpec,
isMediaUploadComboInput
} from './nodeDefSchema'
import type {
ComboInputSpec,
ComboInputSpecV2,
InputSpec
} from './nodeDefSchema'
describe('node definition schema helpers', () => {
it('identifies input spec variants', () => {
const intSpec: InputSpec = ['INT', {}]
const floatSpec: InputSpec = ['FLOAT', {}]
const comboV1: ComboInputSpec = [['a', 'b'], {}]
const comboV2: ComboInputSpecV2 = ['COMBO', { options: ['a', 'b'] }]
expect(isIntInputSpec(intSpec)).toBe(true)
expect(isFloatInputSpec(floatSpec)).toBe(true)
expect(isComboInputSpecV1(comboV1)).toBe(true)
expect(isComboInputSpecV2(comboV2)).toBe(true)
expect(isComboInputSpec(comboV1)).toBe(true)
expect(isComboInputSpec(comboV2)).toBe(true)
expect(getInputSpecType(comboV1)).toBe('COMBO')
expect(getInputSpecType(intSpec)).toBe('INT')
})
it('reads combo options from legacy and v2 combo specs', () => {
expect(getComboSpecComboOptions([['a', 1], {}])).toEqual(['a', 1])
expect(
getComboSpecComboOptions(['COMBO', { options: ['x', 'y'] }])
).toEqual(['x', 'y'])
expect(getComboSpecComboOptions(['COMBO', {}])).toEqual([])
})
it('detects media upload combo inputs', () => {
expect(isMediaUploadComboInput([['a'], { image_upload: true }])).toBe(true)
expect(
isMediaUploadComboInput(['COMBO', { animated_image_upload: true }])
).toBe(true)
expect(isMediaUploadComboInput(['COMBO', { video_upload: true }])).toBe(
true
)
expect(isMediaUploadComboInput(['STRING', { image_upload: true }])).toBe(
false
)
expect(isMediaUploadComboInput(['COMBO', undefined])).toBe(false)
})
})

View File

@@ -0,0 +1,149 @@
import { beforeEach, describe, expect, it, vi } from 'vitest'
import {
fetchWithUnifiedRemint,
shouldRemintCloudRequest
} from '@/platform/auth/unified/remintRetry'
const { mockAuthStore } = vi.hoisted(() => ({
mockAuthStore: {
isInitialized: true,
getAuthHeader: vi.fn(),
getAuthToken: vi.fn()
}
}))
vi.mock('@/platform/distribution/types', () => ({ isCloud: true }))
vi.mock('@/stores/authStore', () => ({
useAuthStore: vi.fn(() => mockAuthStore)
}))
vi.mock('@/platform/auth/unified/remintRetry', () => ({
fetchWithUnifiedRemint: vi.fn(),
shouldRemintCloudRequest: vi.fn()
}))
class FakeWebSocket extends EventTarget {
static instances: FakeWebSocket[] = []
binaryType = ''
sent: string[] = []
constructor(readonly url: string) {
super()
FakeWebSocket.instances.push(this)
}
send(data: string) {
this.sent.push(data)
}
close() {
this.dispatchEvent(new Event('close'))
}
}
const { ComfyApi } = await import('./api')
describe('ComfyApi cloud mode', () => {
beforeEach(() => {
vi.clearAllMocks()
vi.unstubAllGlobals()
FakeWebSocket.instances = []
window.name = ''
sessionStorage.clear()
mockAuthStore.isInitialized = true
mockAuthStore.getAuthHeader.mockResolvedValue(null)
mockAuthStore.getAuthToken.mockResolvedValue(null)
vi.mocked(shouldRemintCloudRequest).mockResolvedValue(false)
vi.mocked(fetchWithUnifiedRemint).mockResolvedValue(
new Response(JSON.stringify({ ok: true }), {
headers: { 'Content-Type': 'application/json' }
})
)
vi.stubGlobal('WebSocket', FakeWebSocket as unknown as typeof WebSocket)
})
it('adds cloud auth headers and enables unified retry for authenticated requests', async () => {
mockAuthStore.getAuthHeader.mockResolvedValue({
Authorization: 'Bearer firebase-token'
})
vi.mocked(shouldRemintCloudRequest).mockResolvedValue(true)
const api = new ComfyApi()
api.user = 'cloud-user'
await api.fetchApi('/queue')
expect(api.api_base).toBe('')
expect(fetchWithUnifiedRemint).toHaveBeenCalledWith(
'/api/queue',
expect.objectContaining({
cache: 'no-cache',
headers: {
Authorization: 'Bearer firebase-token',
'Comfy-User': 'cloud-user'
}
}),
true
)
})
it('continues cloud fetches when auth header lookup fails', async () => {
mockAuthStore.getAuthHeader.mockRejectedValue(new Error('auth unavailable'))
const warn = vi.spyOn(console, 'warn').mockImplementation(() => {})
const api = new ComfyApi()
await api.fetchApi('/history', {
headers: [['X-Test', '1']]
})
const [, options, retryOn401] = vi.mocked(fetchWithUnifiedRemint).mock
.calls[0]
expect(options.headers).toEqual([
['X-Test', '1'],
['Comfy-User', '']
])
expect(retryOn401).toBe(false)
expect(shouldRemintCloudRequest).not.toHaveBeenCalled()
expect(warn).toHaveBeenCalledWith(
'Failed to get auth header:',
expect.any(Error)
)
})
it('adds the cloud auth token to websocket URLs', async () => {
mockAuthStore.getAuthToken.mockResolvedValue('socket-token')
window.name = 'client-1'
const api = new ComfyApi()
api.init()
await vi.waitFor(() => {
expect(FakeWebSocket.instances).toHaveLength(1)
})
const socket = FakeWebSocket.instances[0]
expect(socket.url).toContain('clientId=client-1')
expect(socket.url).toContain('token=socket-token')
})
it('opens a cloud websocket without a token when token lookup fails', async () => {
mockAuthStore.getAuthToken.mockRejectedValue(new Error('token unavailable'))
const warn = vi.spyOn(console, 'warn').mockImplementation(() => {})
const api = new ComfyApi()
api.init()
await vi.waitFor(() => {
expect(FakeWebSocket.instances).toHaveLength(1)
})
const socket = FakeWebSocket.instances[0]
expect(socket.url).not.toContain('token=')
expect(warn).toHaveBeenCalledWith(
'Could not get auth token for WebSocket connection:',
expect.any(Error)
)
})
})

View File

@@ -0,0 +1,460 @@
import axios from 'axios'
import { fromPartial } from '@total-typescript/shoehorn'
import { afterEach, describe, expect, it, vi } from 'vitest'
import type { JobListItem } from '@/platform/remote/comfyui/jobs/jobTypes'
import type {
ComfyApiWorkflow,
ComfyWorkflowJSON
} from '@/platform/workflow/validation/schemas/workflowSchema'
import type { PromptResponse } from '@/schemas/apiSchema'
import { api as sharedApi, ComfyApi, PromptExecutionError } from '@/scripts/api'
import type { NodeExecutionId } from '@/types/nodeIdentification'
const fetchJobs = vi.hoisted(() => ({
fetchHistory: vi.fn(),
fetchJobDetail: vi.fn(),
fetchQueue: vi.fn()
}))
vi.mock('axios')
vi.mock('@/platform/remote/comfyui/jobs/fetchJobs', () => fetchJobs)
afterEach(() => {
vi.restoreAllMocks()
vi.clearAllMocks()
})
function jsonResponse(data: unknown, init: ResponseInit = {}) {
return new Response(JSON.stringify(data), {
status: 200,
headers: { 'content-type': 'application/json' },
...init
})
}
function promptData(): {
output: ComfyApiWorkflow
workflow: ComfyWorkflowJSON
} {
return {
output: fromPartial<ComfyApiWorkflow>({
1: {
inputs: {},
class_type: 'KSampler',
_meta: { title: 'KSampler' }
}
}),
workflow: fromPartial<ComfyWorkflowJSON>({
version: 0.4,
nodes: [],
links: []
})
}
}
function requestBody(fetchApi: ReturnType<typeof vi.spyOn>, call = 0) {
const init = fetchApi.mock.calls[call][1]
return JSON.parse(String(init?.body)) as Record<string, unknown>
}
describe('PromptExecutionError', () => {
it('formats string and node-specific prompt errors', () => {
const response = fromPartial<PromptResponse>({
error: 'invalid prompt',
node_errors: {
7: {
class_type: 'KSampler',
dependent_outputs: [],
errors: [{ message: 'bad seed', details: 'seed must be numeric' }]
}
}
})
expect(new PromptExecutionError(response, 400).toString()).toBe(
'invalid prompt\nKSampler:\n - bad seed: seed must be numeric'
)
})
it('formats structured prompt errors without node errors', () => {
const response = fromPartial<PromptResponse>({
error: {
type: 'prompt_outputs_failed_validation',
message: 'Validation failed',
details: 'missing node'
}
})
expect(new PromptExecutionError(response).toString()).toBe(
'Validation failed: missing node'
)
})
})
describe('ComfyApi queuePrompt', () => {
it('sends front queue requests with auth and execution options', async () => {
const api = new ComfyApi()
api.clientId = 'client-1'
api.authToken = 'auth-token'
api.apiKey = 'api-key'
const fetchApi = vi
.spyOn(api, 'fetchApi')
.mockResolvedValue(jsonResponse({ prompt_id: 'queued' }))
await api.queuePrompt(-1, promptData(), {
partialExecutionTargets: ['9:10' as NodeExecutionId],
previewMethod: 'auto'
})
expect(fetchApi).toHaveBeenCalledWith('/prompt', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: fetchApi.mock.calls[0][1]?.body
})
expect(typeof fetchApi.mock.calls[0][1]?.body).toBe('string')
expect(requestBody(fetchApi)).toMatchObject({
client_id: 'client-1',
front: true,
partial_execution_targets: ['9:10'],
extra_data: {
auth_token_comfy_org: 'auth-token',
api_key_comfy_org: 'api-key',
comfy_usage_source: 'comfyui-frontend',
preview_method: 'auto'
}
})
expect(requestBody(fetchApi)).not.toHaveProperty('number')
})
it('omits default-only queue options and sets explicit queue numbers', async () => {
const api = new ComfyApi()
const fetchApi = vi
.spyOn(api, 'fetchApi')
.mockImplementation(() =>
Promise.resolve(jsonResponse({ prompt_id: 'queued' }))
)
await api.queuePrompt(0, promptData(), { previewMethod: 'default' })
await api.queuePrompt(4, promptData())
expect(requestBody(fetchApi, 0)).toMatchObject({ client_id: '' })
expect(requestBody(fetchApi, 0)).not.toHaveProperty('front')
expect(requestBody(fetchApi, 0)).not.toHaveProperty('number')
expect(
requestBody(fetchApi, 0).extra_data as Record<string, unknown>
).not.toHaveProperty('preview_method')
expect(requestBody(fetchApi, 1)).toMatchObject({ number: 4 })
})
it('throws parsed prompt errors from non-200 responses', async () => {
const api = new ComfyApi()
vi.spyOn(api, 'fetchApi').mockResolvedValue(
jsonResponse(
{
error: {
type: 'server_error',
message: 'Server rejected prompt',
details: 'bad output'
}
},
{ status: 400, statusText: 'Bad Request' }
)
)
await expect(api.queuePrompt(0, promptData())).rejects.toThrow(
'Prompt execution failed'
)
})
it('wraps non-json prompt errors with status details', async () => {
const api = new ComfyApi()
vi.spyOn(api, 'fetchApi').mockResolvedValue(
new Response('backend exploded', {
status: 500,
statusText: 'Server Error'
})
)
await expect(api.queuePrompt(0, promptData())).rejects.toMatchObject({
status: 500,
response: {
error: {
message: '500 Server Error',
details: 'backend exploded'
}
}
})
})
})
describe('ComfyApi read helpers', () => {
it('returns localized templates, default templates, and empty non-json responses', async () => {
const api = new ComfyApi()
vi.mocked(axios.get)
.mockResolvedValueOnce({
headers: { 'content-type': 'application/json' },
data: [{ name: 'localized' }]
})
.mockResolvedValueOnce({
headers: { 'content-type': 'text/html' },
data: '<html></html>'
})
await expect(api.getCoreWorkflowTemplates('fr')).resolves.toEqual([
{ name: 'localized' }
])
await expect(api.getCoreWorkflowTemplates()).resolves.toEqual([])
expect(vi.mocked(axios.get).mock.calls[0][0]).toContain(
'/templates/index.fr.json'
)
expect(vi.mocked(axios.get).mock.calls[1][0]).toContain(
'/templates/index.json'
)
})
it('falls back from missing localized templates to the default index', async () => {
const api = new ComfyApi()
vi.spyOn(console, 'warn').mockImplementation(() => {})
vi.mocked(axios.get)
.mockRejectedValueOnce(new Error('missing locale'))
.mockResolvedValueOnce({
headers: { 'content-type': 'application/json' },
data: [{ name: 'default' }]
})
await expect(api.getCoreWorkflowTemplates('ja')).resolves.toEqual([
{ name: 'default' }
])
expect(vi.mocked(axios.get).mock.calls[1][0]).toContain(
'/templates/index.json'
)
})
it('returns empty model lists for 404s and filters internal folders', async () => {
const api = new ComfyApi()
const fetchApi = vi
.spyOn(api, 'fetchApi')
.mockResolvedValueOnce(jsonResponse([], { status: 404 }))
.mockResolvedValueOnce(
jsonResponse([
{ name: 'checkpoints' },
{ name: 'configs' },
{ name: 'custom_nodes' }
])
)
.mockResolvedValueOnce(jsonResponse([], { status: 404 }))
.mockResolvedValueOnce(jsonResponse(['model.safetensors']))
await expect(api.getModelFolders()).resolves.toEqual([])
await expect(api.getModelFolders()).resolves.toEqual([
{ name: 'checkpoints' }
])
await expect(api.getModels('checkpoints')).resolves.toEqual([])
await expect(api.getModels('checkpoints')).resolves.toEqual([
'model.safetensors'
])
expect(fetchApi).toHaveBeenCalledTimes(4)
})
it('handles model metadata text responses', async () => {
const api = new ComfyApi()
vi.spyOn(console, 'error').mockImplementation(() => {})
vi.spyOn(api, 'fetchApi')
.mockResolvedValueOnce(new Response(''))
.mockResolvedValueOnce(new Response('{"format":"safetensors"}'))
.mockResolvedValueOnce(
new Response('not json', { status: 200, statusText: 'OK' })
)
await expect(
api.viewMetadata('checkpoints', 'a.safetensors')
).resolves.toBe(null)
await expect(
api.viewMetadata('checkpoints', 'a.safetensors')
).resolves.toEqual({ format: 'safetensors' })
await expect(
api.viewMetadata('checkpoints', 'a.safetensors')
).resolves.toBe(null)
})
it('gets fuse options only from json responses', async () => {
const api = new ComfyApi()
vi.mocked(axios.get)
.mockResolvedValueOnce({
headers: { 'content-type': 'application/json' },
data: { keys: ['name'] }
})
.mockResolvedValueOnce({
headers: { 'content-type': 'text/plain' },
data: 'nope'
})
vi.spyOn(console, 'error').mockImplementation(() => {})
await expect(api.getFuseOptions()).resolves.toEqual({ keys: ['name'] })
await expect(api.getFuseOptions()).resolves.toBeNull()
vi.mocked(axios.get).mockRejectedValueOnce(new Error('missing'))
await expect(api.getFuseOptions()).resolves.toBeNull()
})
})
describe('ComfyApi queue and data helpers', () => {
it('routes item collection requests to queue or history', async () => {
const api = new ComfyApi()
const queue = vi.spyOn(api, 'getQueue').mockResolvedValue({
Running: [],
Pending: []
})
const historyItem = fromPartial<JobListItem>({
id: 'history-1',
status: 'completed',
create_time: 1,
priority: 0
})
const history = vi.spyOn(api, 'getHistory').mockResolvedValue([historyItem])
await expect(api.getItems('queue')).resolves.toEqual({
Running: [],
Pending: []
})
await expect(api.getItems('history')).resolves.toEqual([historyItem])
expect(queue).toHaveBeenCalledOnce()
expect(history).toHaveBeenCalledOnce()
})
it('returns queue fallbacks unless errors are requested', async () => {
const api = new ComfyApi()
vi.spyOn(console, 'error').mockImplementation(() => {})
fetchJobs.fetchQueue.mockRejectedValue(new Error('network'))
await expect(api.getQueue()).resolves.toEqual({ Running: [], Pending: [] })
await expect(api.getQueue({ throwOnError: true })).rejects.toThrow(
'network'
)
})
it('returns empty history when fetchHistory fails', async () => {
const api = new ComfyApi()
vi.spyOn(console, 'error').mockImplementation(() => {})
fetchJobs.fetchHistory.mockRejectedValue(new Error('history down'))
await expect(api.getHistory()).resolves.toEqual([])
})
it('posts item mutations with and without request bodies', async () => {
const api = new ComfyApi()
const fetchApi = vi.spyOn(api, 'fetchApi').mockResolvedValue(new Response())
await api.deleteItem('history', 'job-1')
await api.clearItems('queue')
await api.interrupt(null)
await api.interrupt('running-1')
expect(fetchApi.mock.calls.map((call) => call[0])).toEqual([
'/history',
'/queue',
'/interrupt',
'/interrupt'
])
expect(fetchApi.mock.calls[0][1]?.body).toBe(
JSON.stringify({ delete: ['job-1'] })
)
expect(fetchApi.mock.calls[1][1]?.body).toBe(
JSON.stringify({ clear: true })
)
expect(fetchApi.mock.calls[2][1]?.body).toBeUndefined()
expect(fetchApi.mock.calls[3][1]?.body).toBe(
JSON.stringify({ prompt_id: 'running-1' })
)
})
it('throws unauthorized settings responses', async () => {
const api = new ComfyApi()
vi.spyOn(api, 'fetchApi').mockResolvedValue(
new Response('', { status: 401, statusText: 'Unauthorized' })
)
await expect(api.getSettings()).rejects.toThrow('Unauthorized')
})
it('stores user data with default and raw-body options', async () => {
const api = new ComfyApi()
const fetchApi = vi
.spyOn(api, 'fetchApi')
.mockResolvedValue(new Response('', { status: 200 }))
const raw = new Blob(['raw'])
await api.storeUserData('a/b.json', { ok: true })
await api.storeUserData('raw.bin', raw, {
overwrite: false,
stringify: false,
throwOnError: false,
full_info: true
})
expect(fetchApi.mock.calls[0][0]).toBe(
'/userdata/a%2Fb.json?overwrite=true&full_info=false'
)
expect(fetchApi.mock.calls[0][1]?.body).toBe(JSON.stringify({ ok: true }))
expect(fetchApi.mock.calls[1][0]).toBe(
'/userdata/raw.bin?overwrite=false&full_info=true'
)
expect(fetchApi.mock.calls[1][1]?.body).toBe(raw)
})
it('honors storeUserData throwOnError', async () => {
const api = new ComfyApi()
vi.spyOn(api, 'fetchApi')
.mockResolvedValueOnce(
new Response('', { status: 500, statusText: 'Server Error' })
)
.mockResolvedValueOnce(
new Response('', { status: 500, statusText: 'Server Error' })
)
await expect(api.storeUserData('bad.json', {})).rejects.toThrow(
"Error storing user data file 'bad.json': 500 Server Error"
)
await expect(
api.storeUserData('bad.json', {}, { throwOnError: false })
).resolves.toHaveProperty('status', 500)
})
it('lists full user data info by status', async () => {
const api = new ComfyApi()
vi.spyOn(api, 'fetchApi')
.mockResolvedValueOnce(jsonResponse([], { status: 404 }))
.mockResolvedValueOnce(
new Response('', { status: 500, statusText: 'Server Error' })
)
.mockResolvedValueOnce(jsonResponse([{ path: 'x' }]))
await expect(api.listUserDataFullInfo('models/')).resolves.toEqual([])
await expect(api.listUserDataFullInfo('models/')).rejects.toThrow(
"Error getting user data list 'models': 500 Server Error"
)
await expect(api.listUserDataFullInfo('models/')).resolves.toEqual([
{ path: 'x' }
])
})
it('loads global subgraph records and deferred data', async () => {
const fetchApi = vi
.spyOn(sharedApi, 'fetchApi')
.mockResolvedValueOnce(jsonResponse({}, { status: 500 }))
.mockResolvedValueOnce(
jsonResponse({
ready: { name: 'Ready', info: { node_pack: 'core' }, data: '{}' },
lazy: { name: 'Lazy', info: { node_pack: 'core' } }
})
)
.mockResolvedValueOnce(jsonResponse({ data: '{"lazy":true}' }))
await expect(sharedApi.getGlobalSubgraphs()).resolves.toEqual({})
const subgraphs = await sharedApi.getGlobalSubgraphs()
expect(subgraphs.ready.data).toBe('{}')
expect(subgraphs.lazy.data).toBeInstanceOf(Promise)
await expect(subgraphs.lazy.data).resolves.toBe('{"lazy":true}')
expect(fetchApi).toHaveBeenCalledTimes(3)
})
})

829
src/scripts/api.test.ts Normal file
View File

@@ -0,0 +1,829 @@
import { fromAny } from '@total-typescript/shoehorn'
import axios from 'axios'
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
import {
api as singletonApi,
ComfyApi,
PromptExecutionError,
UnauthorizedError
} from '@/scripts/api'
import type { ComfyApiWorkflow } from '@/platform/workflow/validation/schemas/workflowSchema'
import type { NodeExecutionId } from '@/types/nodeIdentification'
import { fetchWithUnifiedRemint } from '@/platform/auth/unified/remintRetry'
const { mockToastStore } = vi.hoisted(() => ({
mockToastStore: {
add: vi.fn()
}
}))
vi.mock('@/platform/auth/unified/remintRetry', () => ({
fetchWithUnifiedRemint: vi.fn(),
shouldRemintCloudRequest: vi.fn()
}))
vi.mock('@/platform/updates/common/toastStore', () => ({
useToastStore: vi.fn(() => mockToastStore)
}))
vi.mock('axios', () => ({
default: {
get: vi.fn(),
patch: vi.fn()
}
}))
class FakeWebSocket extends EventTarget {
static instances: FakeWebSocket[] = []
binaryType = ''
sent: string[] = []
constructor(readonly url: string) {
super()
FakeWebSocket.instances.push(this)
}
send(data: string) {
this.sent.push(data)
}
close() {
this.dispatchEvent(new Event('close'))
}
}
function jsonResponse(data: unknown, init: ResponseInit = {}) {
return new Response(JSON.stringify(data), {
status: 200,
headers: { 'Content-Type': 'application/json' },
...init
})
}
function createWorkflow() {
return {
last_node_id: 0,
last_link_id: 0,
nodes: [],
links: [],
groups: [],
config: {},
extra: {},
version: 0.4
}
}
function binaryMessage(type: number, payload: Uint8Array) {
const bytes = new Uint8Array(4 + payload.length)
new DataView(bytes.buffer).setUint32(0, type)
bytes.set(payload, 4)
return bytes.buffer
}
function uint32(value: number) {
const bytes = new Uint8Array(4)
new DataView(bytes.buffer).setUint32(0, value)
return bytes
}
function concatBytes(...chunks: Uint8Array[]) {
const totalLength = chunks.reduce((sum, chunk) => sum + chunk.length, 0)
const result = new Uint8Array(totalLength)
let offset = 0
for (const chunk of chunks) {
result.set(chunk, offset)
offset += chunk.length
}
return result
}
describe('PromptExecutionError', () => {
it('formats string, structured, and node-level errors', () => {
expect(
new PromptExecutionError({
error: 'Queue rejected',
node_errors: {}
}).toString()
).toBe('Queue rejected')
expect(
new PromptExecutionError({
error: {
type: 'invalid_prompt',
message: 'Invalid prompt',
details: 'missing input'
},
node_errors: {
1: {
class_type: 'PreviewAny',
dependent_outputs: ['1'],
errors: [
{
type: 'required_input_missing',
message: 'Required input',
details: 'source'
}
]
}
}
}).toString()
).toContain('Invalid prompt: missing input\nPreviewAny:')
})
})
describe('ComfyApi', () => {
beforeEach(() => {
vi.clearAllMocks()
FakeWebSocket.instances = []
window.name = ''
sessionStorage.clear()
vi.stubGlobal('WebSocket', FakeWebSocket as unknown as typeof WebSocket)
})
afterEach(() => {
vi.restoreAllMocks()
vi.unstubAllGlobals()
})
it('builds API, internal, file, and fetch URLs with user headers', async () => {
const api = new ComfyApi()
api.user = 'reviewer'
vi.mocked(fetchWithUnifiedRemint).mockResolvedValue(
jsonResponse({ ok: true })
)
await api.fetchApi('/queue', {
headers: new Headers([['X-Test', '1']])
})
expect(api.apiURL('/api/custom')).toBe(`${api.api_base}/api/custom`)
expect(api.apiURL('/queue')).toBe(`${api.api_base}/api/queue`)
expect(api.internalURL('/logs')).toBe(`${api.api_base}/internal/logs`)
expect(api.fileURL('/view')).toBe(`${api.api_base}/view`)
const [, options] = vi.mocked(fetchWithUnifiedRemint).mock.calls[0]
expect(options.headers).toBeInstanceOf(Headers)
expect((options.headers as Headers).get('Comfy-User')).toBe('reviewer')
expect((options.headers as Headers).get('X-Test')).toBe('1')
})
it('guards event listeners and still allows removing them', async () => {
const api = new ComfyApi()
const warn = vi.spyOn(console, 'warn').mockImplementation(() => undefined)
const listener = vi.fn()
const throwingListener = vi.fn(() => {
throw new Error('listener failed')
})
const asyncListener = vi.fn(() => Promise.reject(new Error('async failed')))
const objectListener = { handleEvent: vi.fn() }
api.addEventListener('status', null)
api.removeEventListener('status', null)
api.addEventListener('status', listener)
api.addEventListener('status', throwingListener)
api.addEventListener('status', asyncListener)
api.addEventListener('status', fromAny(objectListener))
api.dispatchCustomEvent('status', { exec_info: { queue_remaining: 1 } })
await Promise.resolve()
expect(listener).toHaveBeenCalled()
expect(throwingListener).toHaveBeenCalled()
expect(asyncListener).toHaveBeenCalled()
expect(objectListener.handleEvent).toHaveBeenCalled()
expect(warn).toHaveBeenCalledTimes(2)
api.removeEventListener('status', listener)
api.dispatchCustomEvent('status', null)
expect(listener).toHaveBeenCalledTimes(1)
})
it('reuses guarded listener wrappers and ignores unknown removals', () => {
const api = new ComfyApi()
const listener = vi.fn()
const neverRegistered = vi.fn()
api.addEventListener('status', listener)
api.addEventListener('status', listener)
api.removeEventListener('status', neverRegistered)
api.dispatchCustomEvent('status', null)
expect(listener).toHaveBeenCalledTimes(1)
expect(neverRegistered).not.toHaveBeenCalled()
})
it('supports guarded custom event listeners', () => {
const api = new ComfyApi()
const listener = vi.fn()
api.addCustomEventListener('custom-node-event', listener)
;(api as EventTarget).dispatchEvent(
new CustomEvent('custom-node-event', { detail: { ok: true } })
)
api.removeCustomEventListener('custom-node-event', listener)
;(api as EventTarget).dispatchEvent(
new CustomEvent('custom-node-event', { detail: { ok: false } })
)
expect(listener).toHaveBeenCalledTimes(1)
})
it('routes websocket JSON messages and custom registered messages', () => {
window.name = 'existing-client'
const api = new ComfyApi()
const status = vi.fn()
const executing = vi.fn()
const featureFlags = vi.fn()
const custom = vi.fn()
const warn = vi.spyOn(console, 'warn').mockImplementation(() => undefined)
vi.spyOn(console, 'log').mockImplementation(() => undefined)
api.addEventListener('status', status)
api.addEventListener('executing', executing)
api.addEventListener('feature_flags', featureFlags)
api.addCustomEventListener('custom-message', custom)
api.init()
const socket = FakeWebSocket.instances[0]
socket.dispatchEvent(new Event('open'))
expect(socket.url).toContain('clientId=existing-client')
expect(JSON.parse(socket.sent[0])).toMatchObject({
type: 'feature_flags'
})
socket.dispatchEvent(
new MessageEvent('message', {
data: JSON.stringify({
type: 'status',
data: {
sid: 'fresh-client',
status: { exec_info: { queue_remaining: 2 } }
}
})
})
)
socket.dispatchEvent(
new MessageEvent('message', {
data: JSON.stringify({
type: 'executing',
data: { node: '12' }
})
})
)
socket.dispatchEvent(
new MessageEvent('message', {
data: JSON.stringify({
type: 'feature_flags',
data: { supports_progress_text_metadata: true }
})
})
)
socket.dispatchEvent(
new MessageEvent('message', {
data: JSON.stringify({
type: 'custom-message',
data: { from: 'extension' }
})
})
)
socket.dispatchEvent(
new MessageEvent('message', {
data: JSON.stringify({
type: 'unknown-message',
data: {}
})
})
)
socket.dispatchEvent(
new MessageEvent('message', {
data: JSON.stringify({
type: 'unknown-message',
data: {}
})
})
)
expect(api.clientId).toBe('fresh-client')
expect(window.name).toBe('fresh-client')
expect(sessionStorage.getItem('clientId')).toBe('fresh-client')
expect(status).toHaveBeenCalledWith(
expect.objectContaining({
detail: { exec_info: { queue_remaining: 2 } }
})
)
expect(executing).toHaveBeenCalledWith(
expect.objectContaining({ detail: '12' })
)
expect(featureFlags).toHaveBeenCalled()
expect(api.serverSupportsFeature('supports_progress_text_metadata')).toBe(
true
)
expect(custom).toHaveBeenCalledWith(
expect.objectContaining({ detail: { from: 'extension' } })
)
expect(api.reportedUnknownMessageTypes.has('unknown-message')).toBe(true)
expect(warn).toHaveBeenCalledTimes(1)
})
it('polls status when the initial websocket connection fails', async () => {
vi.useFakeTimers()
const api = new ComfyApi()
const status = vi.fn()
vi.spyOn(api, 'fetchApi')
.mockResolvedValueOnce(
jsonResponse({ exec_info: { queue_remaining: 4 } })
)
.mockRejectedValueOnce(new Error('poll failed'))
api.addEventListener('status', status)
api.init()
FakeWebSocket.instances[0].dispatchEvent(new Event('error'))
await vi.advanceTimersByTimeAsync(1000)
await vi.advanceTimersByTimeAsync(1000)
expect(status).toHaveBeenCalledWith(
expect.objectContaining({
detail: { exec_info: { queue_remaining: 4 } }
})
)
expect(status).toHaveBeenCalledWith(
expect.objectContaining({ detail: null })
)
vi.useRealTimers()
})
it('emits reconnect lifecycle events after an opened websocket closes', async () => {
vi.useFakeTimers()
const api = new ComfyApi()
const status = vi.fn()
const reconnecting = vi.fn()
const reconnected = vi.fn()
api.addEventListener('status', status)
api.addEventListener('reconnecting', reconnecting)
api.addEventListener('reconnected', reconnected)
api.init()
const socket = FakeWebSocket.instances[0]
socket.dispatchEvent(new Event('open'))
socket.close()
expect(status).toHaveBeenCalledWith(
expect.objectContaining({ detail: null })
)
expect(reconnecting).toHaveBeenCalledOnce()
await vi.advanceTimersByTimeAsync(300)
const reconnectSocket = FakeWebSocket.instances[1]
reconnectSocket.dispatchEvent(new Event('open'))
expect(reconnected).toHaveBeenCalledOnce()
vi.useRealTimers()
})
it('routes websocket variants without session ids and display-node fallbacks', () => {
const api = new ComfyApi()
const status = vi.fn()
const executing = vi.fn()
api.addEventListener('status', status)
api.addEventListener('executing', executing)
api.init()
const socket = FakeWebSocket.instances[0]
socket.dispatchEvent(
new MessageEvent('message', {
data: JSON.stringify({
type: 'status',
data: { status: undefined }
})
})
)
socket.dispatchEvent(
new MessageEvent('message', {
data: JSON.stringify({
type: 'executing',
data: { node: 'real', display_node: 'display' }
})
})
)
expect(status).toHaveBeenCalledWith(
expect.objectContaining({ detail: null })
)
expect(executing).toHaveBeenCalledWith(
expect.objectContaining({ detail: 'display' })
)
})
it('routes binary preview and progress websocket messages', () => {
const api = new ComfyApi()
const preview = vi.fn()
const previewWithMetadata = vi.fn()
const progressText = vi.fn()
const encoder = new TextEncoder()
api.serverFeatureFlags.value = {
supports_progress_text_metadata: true
}
api.addEventListener('b_preview', preview)
api.addEventListener('b_preview_with_metadata', previewWithMetadata)
api.addEventListener('progress_text', progressText)
api.init()
const socket = FakeWebSocket.instances[0]
socket.dispatchEvent(
new MessageEvent('message', {
data: binaryMessage(1, concatBytes(uint32(2), new Uint8Array([1, 2])))
})
)
const promptId = encoder.encode('prompt-1')
const nodeId = encoder.encode('7')
const progressPayload = concatBytes(
uint32(promptId.length),
promptId,
uint32(nodeId.length),
nodeId,
encoder.encode('loading')
)
socket.dispatchEvent(
new MessageEvent('message', {
data: binaryMessage(3, progressPayload)
})
)
const metadata = encoder.encode(
JSON.stringify({
image_type: 'image/webp',
node_id: '7',
display_node_id: '7',
parent_node_id: '4',
real_node_id: '7',
prompt_id: 'prompt-1'
})
)
socket.dispatchEvent(
new MessageEvent('message', {
data: binaryMessage(
4,
concatBytes(uint32(metadata.length), metadata, new Uint8Array([9]))
)
})
)
expect(preview).toHaveBeenCalledWith(
expect.objectContaining({
detail: expect.objectContaining({ type: 'image/png' })
})
)
expect(progressText).toHaveBeenCalledWith(
expect.objectContaining({
detail: {
nodeId: '7',
text: 'loading',
prompt_id: 'prompt-1'
}
})
)
expect(previewWithMetadata).toHaveBeenCalledWith(
expect.objectContaining({
detail: expect.objectContaining({
nodeId: '7',
parentNodeId: '4',
jobId: 'prompt-1'
})
})
)
})
it('routes binary jpeg/default previews and malformed binary messages defensively', () => {
const api = new ComfyApi()
const preview = vi.fn()
const progressText = vi.fn()
const warn = vi.spyOn(console, 'warn').mockImplementation(() => undefined)
const encoder = new TextEncoder()
api.addEventListener('b_preview', preview)
api.addEventListener('progress_text', progressText)
api.init()
const socket = FakeWebSocket.instances[0]
socket.dispatchEvent(
new MessageEvent('message', {
data: binaryMessage(1, concatBytes(uint32(1), new Uint8Array([1])))
})
)
socket.dispatchEvent(
new MessageEvent('message', {
data: binaryMessage(1, concatBytes(uint32(99), new Uint8Array([2])))
})
)
const nodeId = encoder.encode('node')
socket.dispatchEvent(
new MessageEvent('message', {
data: binaryMessage(
3,
concatBytes(uint32(nodeId.length), nodeId, encoder.encode('ready'))
)
})
)
socket.dispatchEvent(
new MessageEvent('message', {
data: binaryMessage(3, new Uint8Array([1]))
})
)
socket.dispatchEvent(
new MessageEvent('message', {
data: binaryMessage(99, new Uint8Array())
})
)
expect(preview).toHaveBeenCalledWith(
expect.objectContaining({
detail: expect.objectContaining({ type: 'image/jpeg' })
})
)
expect(progressText).toHaveBeenCalledWith(
expect.objectContaining({
detail: { nodeId: 'node', text: 'ready' }
})
)
expect(warn).toHaveBeenCalled()
})
it('serializes prompt queue options and surfaces non-200 errors', async () => {
const api = new ComfyApi()
api.clientId = 'client-1'
api.authToken = 'token-1'
api.apiKey = 'key-1'
const prompt: ComfyApiWorkflow = {
1: {
class_type: 'PreviewAny',
inputs: {},
_meta: { title: 'PreviewAny' }
}
}
const fetchApi = vi
.spyOn(api, 'fetchApi')
.mockResolvedValueOnce(jsonResponse({ prompt_id: 'queued' }))
.mockResolvedValueOnce(
new Response('backend exploded', {
status: 500,
statusText: 'Server Error'
})
)
await expect(
api.queuePrompt(
-1,
{ output: prompt, workflow: createWorkflow() },
{
partialExecutionTargets: ['7' as NodeExecutionId],
previewMethod: 'latent2rgb'
}
)
).resolves.toEqual({ prompt_id: 'queued' })
const body = JSON.parse(fetchApi.mock.calls[0][1]?.body as string)
expect(body).toMatchObject({
client_id: 'client-1',
prompt,
partial_execution_targets: ['7'],
front: true,
extra_data: {
auth_token_comfy_org: 'token-1',
api_key_comfy_org: 'key-1',
comfy_usage_source: 'comfyui-frontend',
preview_method: 'latent2rgb'
}
})
expect(body.number).toBeUndefined()
await expect(
api.queuePrompt(3, { output: prompt, workflow: createWorkflow() })
).rejects.toMatchObject({ status: 500 })
})
it('omits queue position and default preview method for normal queueing', async () => {
const api = new ComfyApi()
const prompt: ComfyApiWorkflow = {
1: {
class_type: 'PreviewAny',
inputs: {},
_meta: { title: 'PreviewAny' }
}
}
const fetchApi = vi
.spyOn(api, 'fetchApi')
.mockResolvedValue(jsonResponse({ prompt_id: 'queued' }))
await api.queuePrompt(
0,
{ output: prompt, workflow: createWorkflow() },
{
previewMethod: 'default'
}
)
const body = JSON.parse(fetchApi.mock.calls[0][1]?.body as string)
expect(body.front).toBeUndefined()
expect(body.number).toBeUndefined()
expect(body.extra_data.preview_method).toBeUndefined()
})
it('handles shareable assets, settings, userdata, subgraphs, and memory APIs', async () => {
const api = new ComfyApi()
const prompt: ComfyApiWorkflow = {
1: {
class_type: 'PreviewAny',
inputs: {},
_meta: { title: 'PreviewAny' }
}
}
vi.spyOn(api, 'fetchApi')
.mockResolvedValueOnce(jsonResponse({ assets: [] }))
.mockResolvedValueOnce(jsonResponse({}, { status: 500 }))
.mockResolvedValueOnce(
new Response('', { status: 401, statusText: 'Unauthorized' })
)
.mockResolvedValueOnce(jsonResponse({}, { status: 204 }))
.mockResolvedValueOnce(jsonResponse([], { status: 404 }))
.mockResolvedValueOnce(
jsonResponse({}, { status: 500, statusText: 'Server Error' })
)
.mockResolvedValueOnce(jsonResponse({}, { status: 200 }))
.mockResolvedValueOnce(jsonResponse({}, { status: 200 }))
.mockResolvedValueOnce(jsonResponse({}, { status: 500 }))
vi.spyOn(singletonApi, 'fetchApi')
.mockResolvedValueOnce(jsonResponse({ data: 'subgraph-data' }))
.mockResolvedValueOnce(jsonResponse({ missing: {} }))
.mockResolvedValueOnce(jsonResponse({ one: { data: 'inline' } }))
await expect(
api.getShareableAssets(prompt, { owned: false })
).resolves.toEqual({ assets: [] })
await expect(api.getShareableAssets(prompt)).rejects.toThrow(
'Failed to fetch shareable assets'
)
await expect(api.getSettings()).rejects.toBeInstanceOf(UnauthorizedError)
await expect(
api.storeUserData('plain.txt', 'raw', {
overwrite: false,
stringify: false,
throwOnError: false,
full_info: true
})
).resolves.toHaveProperty('status', 204)
await expect(api.listUserDataFullInfo('/missing/')).resolves.toEqual([])
await expect(api.listUserDataFullInfo('/broken/')).rejects.toThrow(
"Error getting user data list '/broken'"
)
await expect(api.getGlobalSubgraphData('one')).resolves.toBe(
'subgraph-data'
)
await expect(api.getGlobalSubgraphData('missing')).rejects.toThrow(
"Global subgraph 'missing' returned empty data"
)
await expect(api.getGlobalSubgraphs()).resolves.toEqual({
one: { data: 'inline' }
})
await api.freeMemory({ freeExecutionCache: true })
await api.freeMemory({ freeExecutionCache: false })
await api.freeMemory({ freeExecutionCache: true })
expect(mockToastStore.add).toHaveBeenCalledWith(
expect.objectContaining({
summary: 'Models and Execution Cache have been cleared.'
})
)
expect(mockToastStore.add).toHaveBeenCalledWith(
expect.objectContaining({ summary: 'Models have been unloaded.' })
)
expect(mockToastStore.add).toHaveBeenCalledWith(
expect.objectContaining({
summary:
'Unloading of models failed. Installed ComfyUI may be an outdated version.'
})
)
})
it('rejects non-success global subgraph data responses', async () => {
const api = new ComfyApi()
vi.spyOn(singletonApi, 'fetchApi').mockResolvedValueOnce(
jsonResponse({}, { status: 404, statusText: 'Not Found' })
)
await expect(api.getGlobalSubgraphData('missing')).rejects.toThrow(
"Failed to fetch global subgraph 'missing': 404 Not Found"
)
})
it('handles successful settings and userdata helper request shapes', async () => {
const api = new ComfyApi()
const fetchApi = vi
.spyOn(api, 'fetchApi')
.mockResolvedValueOnce(jsonResponse({ theme: 'dark' }))
.mockResolvedValueOnce(jsonResponse({}, { status: 500 }))
.mockResolvedValueOnce(jsonResponse({}, { status: 200 }))
.mockResolvedValueOnce(jsonResponse({}, { status: 200 }))
await expect(api.getSettings()).resolves.toEqual({ theme: 'dark' })
await expect(api.storeUserData('bad.json', { a: 1 })).rejects.toThrow(
"Error storing user data file 'bad.json'"
)
await api.moveUserData('old/path.json', 'new path.json')
await api.deleteUserData('old/path.json')
expect(fetchApi.mock.calls[2]).toEqual([
'/userdata/old%2Fpath.json/move/new%20path.json?overwrite=false',
{ method: 'POST' }
])
expect(fetchApi.mock.calls[3]).toEqual([
'/userdata/old%2Fpath.json',
{ method: 'DELETE' }
])
})
it('handles global subgraph fallbacks and log endpoints', async () => {
const api = new ComfyApi()
vi.spyOn(singletonApi, 'fetchApi')
.mockResolvedValueOnce(jsonResponse({}, { status: 500 }))
.mockResolvedValueOnce(
jsonResponse({
missing: {
name: 'Missing data',
info: { node_pack: 'core' }
}
})
)
.mockResolvedValueOnce(jsonResponse({ data: 'lazy-data' }))
vi.mocked(axios.get)
.mockResolvedValueOnce({ data: 'log text', headers: {} })
.mockResolvedValueOnce({ data: { logs: [] }, headers: {} })
.mockRejectedValueOnce(new Error('no folders'))
.mockResolvedValueOnce({
data: { checkpoints: ['/models'] },
headers: {}
})
vi.mocked(axios.patch).mockResolvedValue({ data: undefined })
await expect(api.getGlobalSubgraphs()).resolves.toEqual({})
const subgraphs = await api.getGlobalSubgraphs()
await expect(subgraphs.missing.data).resolves.toBe('lazy-data')
await expect(api.getLogs()).resolves.toBe('log text')
await expect(api.getRawLogs()).resolves.toEqual({ logs: [] })
await api.subscribeLogs(true)
await expect(api.getFolderPaths()).resolves.toEqual({})
await expect(api.getFolderPaths()).resolves.toEqual({
checkpoints: ['/models']
})
expect(axios.patch).toHaveBeenCalledWith(
api.internalURL('/logs/subscribe'),
{ enabled: true, clientId: undefined }
)
})
it('loads localized template indexes and fuse options defensively', async () => {
const api = new ComfyApi()
vi.mocked(axios.get)
.mockResolvedValueOnce({
data: [{ name: 'template' }],
headers: { 'content-type': 'application/json' }
})
.mockResolvedValueOnce({
data: '<html></html>',
headers: { 'content-type': 'text/html' }
})
.mockRejectedValueOnce(new Error('missing locale'))
.mockResolvedValueOnce({
data: [{ name: 'fallback' }],
headers: { 'content-type': 'application/json' }
})
.mockRejectedValueOnce(new Error('default missing'))
.mockResolvedValueOnce({
data: { keys: ['name'] },
headers: { 'content-type': 'application/json' }
})
.mockResolvedValueOnce({
data: '<html></html>',
headers: { 'content-type': 'text/html' }
})
.mockResolvedValueOnce({
data: { ignored: true },
headers: {}
})
vi.spyOn(console, 'warn').mockImplementation(() => undefined)
vi.spyOn(console, 'error').mockImplementation(() => undefined)
await expect(api.getCoreWorkflowTemplates('fr')).resolves.toEqual([
{ name: 'template' }
])
await expect(api.getCoreWorkflowTemplates()).resolves.toEqual([])
await expect(api.getCoreWorkflowTemplates('zh')).resolves.toEqual([
{ name: 'fallback' }
])
await expect(api.getCoreWorkflowTemplates('en')).resolves.toEqual([])
await expect(api.getFuseOptions()).resolves.toEqual({ keys: ['name'] })
await expect(api.getFuseOptions()).resolves.toBeNull()
await expect(api.getFuseOptions()).resolves.toBeNull()
})
})

File diff suppressed because it is too large Load Diff

View File

@@ -1,6 +1,12 @@
import { beforeEach, describe, expect, it, vi } from 'vitest'
import type {
CanvasPointerEvent,
Subgraph
} from '@/lib/litegraph/src/litegraph'
import { LGraphCanvas, LiteGraph } from '@/lib/litegraph/src/litegraph'
import type { ComfyWorkflowJSON } from '@/platform/workflow/validation/schemas/workflowSchema'
import type { ExecutedWsMessage } from '@/schemas/apiSchema'
const mockAssert = vi.hoisted(() => vi.fn())
@@ -14,7 +20,7 @@ const mockNodeOutputStore = vi.hoisted(() => ({
}))
const mockSubgraphNavigationStore = vi.hoisted(() => ({
exportState: vi.fn(() => []),
exportState: vi.fn((): string[] => []),
restoreState: vi.fn()
}))
@@ -23,10 +29,24 @@ const mockWorkflowStore = vi.hoisted(() => ({
getWorkflowByPath: vi.fn()
}))
const mockExecutionStore = vi.hoisted(() => ({
queuedJobs: {} as Record<string, { workflow: { changeTracker: unknown } }>
}))
const mockMaskEditorIsOpened = vi.hoisted(() => vi.fn(() => false))
vi.mock('@/scripts/app', () => ({
app: {
constructor: {
maskeditor_is_opended: mockMaskEditorIsOpened
},
graph: {},
ui: {
autoQueueEnabled: false,
autoQueueMode: 'instant'
},
rootGraph: {
subgraphs: new Map(),
serialize: vi.fn(() => ({
nodes: [],
links: [],
@@ -39,8 +59,10 @@ vi.mock('@/scripts/app', () => ({
}))
},
canvas: {
ds: { scale: 1, offset: [0, 0] }
}
ds: { scale: 1, offset: [0, 0] },
setGraph: vi.fn()
},
loadGraphData: vi.fn()
}
}))
@@ -65,6 +87,10 @@ vi.mock('@/platform/workflow/management/stores/workflowStore', () => ({
useWorkflowStore: vi.fn(() => mockWorkflowStore)
}))
vi.mock('@/stores/executionStore', () => ({
useExecutionStore: vi.fn(() => mockExecutionStore)
}))
import { app } from '@/scripts/app'
import { api } from '@/scripts/api'
import { ChangeTracker } from '@/scripts/changeTracker'
@@ -107,13 +133,120 @@ function mockCanvasState(state: ComfyWorkflowJSON) {
vi.mocked(app.rootGraph.serialize).mockReturnValue(state as never)
}
type ListenerMap = Record<string, EventListener[]>
function storeListener(
listeners: ListenerMap,
type: string,
listener: EventListenerOrEventListenerObject
) {
if (typeof listener === 'function') {
listeners[type] ??= []
listeners[type].push(listener)
}
}
function dispatchStored(listeners: ListenerMap, type: string, event: Event) {
for (const listener of listeners[type] ?? []) {
listener(event)
}
}
async function flushAsyncFrame() {
await Promise.resolve()
await Promise.resolve()
}
function getApiListener(name: string) {
const call = vi
.mocked(api.addEventListener)
.mock.calls.find(([eventName]) => eventName === name)
expect(call).toBeDefined()
return call?.[1] as (event: CustomEvent<ExecutedWsMessage>) => void
}
describe('ChangeTracker', () => {
beforeEach(() => {
vi.clearAllMocks()
nodeIdCounter = 0
ChangeTracker.isLoadingGraph = false
Reflect.set(ChangeTracker, '_checkStateWarned', false)
mockWorkflowStore.activeWorkflow = null
mockWorkflowStore.getWorkflowByPath.mockReturnValue(null)
mockExecutionStore.queuedJobs = {}
mockMaskEditorIsOpened.mockReturnValue(false)
app.ui.autoQueueEnabled = false
app.ui.autoQueueMode = 'instant'
vi.mocked(app.canvas.setGraph).mockClear()
vi.mocked(app.loadGraphData).mockResolvedValue(undefined)
app.rootGraph.subgraphs.clear()
app.canvas.ds.scale = 1
app.canvas.ds.offset = [0, 0]
})
describe('reset', () => {
it('updates initialState from activeState or an explicit state', () => {
const tracker = createTracker(createState(1))
const changed = createState(2)
tracker.activeState = changed
tracker.reset()
expect(tracker.initialState).toEqual(changed)
expect(tracker.initialState).not.toBe(changed)
const explicit = createState(3)
tracker.reset(explicit)
expect(tracker.activeState).toEqual(explicit)
expect(tracker.activeState).not.toBe(explicit)
expect(tracker.initialState).toEqual(explicit)
})
it('does not reset while restoring state', () => {
const tracker = createTracker(createState(1))
const original = tracker.initialState
tracker._restoringState = true
tracker.reset(createState(2))
expect(tracker.initialState).toBe(original)
})
})
describe('restore', () => {
it('restores viewport, outputs, and root graph navigation', () => {
const tracker = createTracker()
app.canvas.ds.scale = 2
app.canvas.ds.offset = [10, 20]
mockNodeOutputStore.snapshotOutputs.mockReturnValue({ 1: { images: [] } })
mockSubgraphNavigationStore.exportState.mockReturnValue([])
tracker.store()
app.canvas.ds.scale = 1
app.canvas.ds.offset = [0, 0]
tracker.restore()
expect(app.canvas.ds.scale).toBe(2)
expect(app.canvas.ds.offset).toEqual([10, 20])
expect(mockNodeOutputStore.restoreOutputs).toHaveBeenCalledWith({
1: { images: [] }
})
expect(mockSubgraphNavigationStore.restoreState).toHaveBeenCalledWith([])
expect(app.canvas.setGraph).toHaveBeenCalledWith(app.rootGraph)
})
it('restores saved subgraph navigation when the subgraph exists', () => {
const tracker = createTracker()
const subgraph = { id: 'subgraph-1' } as unknown as Subgraph
app.rootGraph.subgraphs.set('subgraph-1', subgraph)
mockSubgraphNavigationStore.exportState.mockReturnValue(['subgraph-1'])
tracker.store()
tracker.restore()
expect(app.canvas.setGraph).toHaveBeenCalledWith(subgraph)
})
})
describe('captureCanvasState', () => {
@@ -169,9 +302,32 @@ describe('ChangeTracker', () => {
expect.stringContaining('captureCanvasState')
)
})
it('reports inactive tracker calls only once for the same workflow', () => {
const tracker = createTracker()
tracker.workflow.path = '/test/dedupe-workflow.json'
mockWorkflowStore.activeWorkflow = { changeTracker: {} }
tracker.captureCanvasState()
tracker.captureCanvasState()
expect(mockAssert).toHaveBeenCalledOnce()
})
})
describe('state capture', () => {
it('sets the active state without pushing undo when none exists yet', () => {
const tracker = createTracker(createState(1))
const changed = createState(2)
tracker.activeState = undefined as never
mockCanvasState(changed)
tracker.captureCanvasState()
expect(tracker.activeState).toEqual(changed)
expect(tracker.undoQueue).toHaveLength(0)
})
it('pushes to undoQueue, updates activeState, and calls updateModified', () => {
const initial = createState(1)
const tracker = createTracker(initial)
@@ -238,6 +394,19 @@ describe('ChangeTracker', () => {
expect(tracker.undoQueue).toHaveLength(ChangeTracker.MAX_HISTORY)
})
it('does not capture until the outer change transaction finishes', () => {
const tracker = createTracker(createState(1))
tracker.beforeChange()
tracker.beforeChange()
mockCanvasState(createState(2))
tracker.afterChange()
expect(app.rootGraph.serialize).not.toHaveBeenCalled()
tracker.afterChange()
expect(app.rootGraph.serialize).toHaveBeenCalledOnce()
})
})
})
@@ -302,6 +471,105 @@ describe('ChangeTracker', () => {
})
})
describe('updateModified', () => {
it('updates workflow modified state when the store can find it', () => {
const state = createState(1)
const tracker = createTracker(state)
const workflow = { isModified: true }
mockWorkflowStore.getWorkflowByPath.mockReturnValue(workflow)
tracker.updateModified()
expect(workflow.isModified).toBe(false)
tracker.activeState = createState(2)
tracker.updateModified()
expect(workflow.isModified).toBe(true)
})
})
describe('undo and redo', () => {
it('restores previous state and moves the current state to the target queue', async () => {
const initial = createState(1)
const changed = createState(2)
const tracker = createTracker(changed)
tracker.undoQueue.push(initial)
await tracker.undo()
expect(app.loadGraphData).toHaveBeenCalledWith(
initial,
false,
false,
tracker.workflow,
{
checkForRerouteMigration: false,
silentAssetErrors: true
}
)
expect(tracker.activeState).toBe(initial)
expect(tracker.redoQueue).toEqual([changed])
expect(tracker._restoringState).toBe(false)
})
it('clears restoring state when loading fails', async () => {
const tracker = createTracker(createState(2))
tracker.undoQueue.push(createState(1))
vi.mocked(app.loadGraphData).mockRejectedValueOnce(
new Error('load failed')
)
await expect(tracker.undo()).rejects.toThrow('load failed')
expect(tracker._restoringState).toBe(false)
})
it('does nothing when no previous state exists', async () => {
const tracker = createTracker(createState(1))
await tracker.undo()
expect(app.loadGraphData).not.toHaveBeenCalled()
})
it('handles keyboard undo and redo shortcuts', async () => {
const tracker = createTracker()
const undo = vi.spyOn(tracker, 'undo').mockResolvedValue()
const redo = vi.spyOn(tracker, 'redo').mockResolvedValue()
await expect(
tracker.undoRedo(
new KeyboardEvent('keydown', { key: 'z', ctrlKey: true })
)
).resolves.toBe(true)
await expect(
tracker.undoRedo(
new KeyboardEvent('keydown', {
key: 'z',
ctrlKey: true,
shiftKey: true
})
)
).resolves.toBe(true)
await expect(
tracker.undoRedo(
new KeyboardEvent('keydown', { key: 'y', metaKey: true })
)
).resolves.toBe(true)
await expect(
tracker.undoRedo(
new KeyboardEvent('keydown', {
key: 'z',
ctrlKey: true,
altKey: true
})
)
).resolves.toBeUndefined()
expect(undo).toHaveBeenCalledOnce()
expect(redo).toHaveBeenCalledTimes(2)
})
})
describe('checkState (deprecated)', () => {
it('delegates to captureCanvasState', () => {
const tracker = createTracker(createState(1))
@@ -312,5 +580,389 @@ describe('ChangeTracker', () => {
expect(tracker.activeState).toEqual(changed)
})
it('warns only once before delegating', () => {
const tracker = createTracker(createState(1))
const warn = vi.spyOn(console, 'warn').mockImplementation(() => undefined)
tracker.checkState()
tracker.checkState()
expect(warn).toHaveBeenCalledOnce()
warn.mockRestore()
})
})
describe('bindInput', () => {
it('returns false for missing canvas or body elements', () => {
expect(ChangeTracker.bindInput(null)).toBe(false)
expect(ChangeTracker.bindInput(document.createElement('canvas'))).toBe(
false
)
expect(ChangeTracker.bindInput(document.body)).toBe(false)
})
it('captures state once when an input-like element changes', () => {
const tracker = createTracker()
const capture = vi.spyOn(tracker, 'captureCanvasState')
const input = document.createElement('input')
expect(ChangeTracker.bindInput(input)).toBe(true)
input.dispatchEvent(new Event('change'))
input.dispatchEvent(new Event('change'))
expect(capture).toHaveBeenCalledOnce()
})
it('binds textarea-like elements that expose an input handler slot', () => {
const tracker = createTracker()
const capture = vi.spyOn(tracker, 'captureCanvasState')
const element = document.createElement('div') as HTMLElement & {
oninput: unknown
}
element.oninput = null
expect(ChangeTracker.bindInput(element)).toBe(true)
element.dispatchEvent(new Event('change'))
expect(capture).toHaveBeenCalledOnce()
})
})
describe('init', () => {
it('captures changes from registered browser, graph, and API events', async () => {
const windowListeners: ListenerMap = {}
const documentListeners: ListenerMap = {}
const windowAddSpy = vi
.spyOn(window, 'addEventListener')
.mockImplementation((type, listener) => {
storeListener(windowListeners, type, listener)
})
const documentAddSpy = vi
.spyOn(document, 'addEventListener')
.mockImplementation((type, listener) => {
storeListener(documentListeners, type, listener)
})
const rafSpy = vi
.spyOn(window, 'requestAnimationFrame')
.mockImplementation((callback) => {
callback(0)
return 1
})
const processMouseUp = vi.fn(() => true)
const prompt = vi.fn()
const close = vi.fn(() => true)
const originalProcessMouseUp = LGraphCanvas.prototype.processMouseUp
const originalPrompt = LGraphCanvas.prototype.prompt
const originalClose = LiteGraph.ContextMenu.prototype.close
LGraphCanvas.prototype.processMouseUp = processMouseUp
LGraphCanvas.prototype.prompt = prompt
LiteGraph.ContextMenu.prototype.close = close
try {
ChangeTracker.init()
const tracker = createTracker()
const capture = vi.spyOn(tracker, 'captureCanvasState')
dispatchStored(windowListeners, 'mouseup', new MouseEvent('mouseup'))
getApiListener('promptQueued')(
new CustomEvent('promptQueued', {
detail: {} as unknown as ExecutedWsMessage
})
)
getApiListener('graphCleared')(
new CustomEvent('graphCleared', {
detail: {} as unknown as ExecutedWsMessage
})
)
dispatchStored(
documentListeners,
'litegraph:canvas',
new CustomEvent('litegraph:canvas', {
detail: { subType: 'before-change' }
})
)
dispatchStored(
documentListeners,
'litegraph:canvas',
new CustomEvent('litegraph:canvas', {
detail: { subType: 'after-change' }
})
)
expect(capture).toHaveBeenCalledTimes(4)
dispatchStored(
windowListeners,
'keydown',
new KeyboardEvent('keydown', { key: 'Control' })
)
await flushAsyncFrame()
dispatchStored(windowListeners, 'keyup', new KeyboardEvent('keyup'))
expect(capture).toHaveBeenCalledTimes(5)
const undoRedo = vi.spyOn(tracker, 'undoRedo').mockResolvedValue(true)
dispatchStored(
windowListeners,
'keydown',
new KeyboardEvent('keydown', { key: 'z', ctrlKey: true })
)
await flushAsyncFrame()
expect(undoRedo).toHaveBeenCalledOnce()
expect(capture).toHaveBeenCalledTimes(5)
undoRedo.mockResolvedValue(undefined)
dispatchStored(
windowListeners,
'keydown',
new KeyboardEvent('keydown', { key: 'a' })
)
await flushAsyncFrame()
expect(capture).toHaveBeenCalledTimes(6)
const input = document.createElement('input')
document.body.append(input)
input.focus()
dispatchStored(
windowListeners,
'keydown',
new KeyboardEvent('keydown', { key: 'b' })
)
await flushAsyncFrame()
input.remove()
expect(capture).toHaveBeenCalledTimes(6)
mockMaskEditorIsOpened.mockReturnValue(true)
dispatchStored(
windowListeners,
'keydown',
new KeyboardEvent('keydown', { key: 'c' })
)
await flushAsyncFrame()
expect(capture).toHaveBeenCalledTimes(6)
const canvas = {} as LGraphCanvas
LGraphCanvas.prototype.processMouseUp.call(
canvas,
new MouseEvent('mouseup') as CanvasPointerEvent
)
expect(processMouseUp).toHaveBeenCalledOnce()
expect(capture).toHaveBeenCalledTimes(7)
const promptCallback = vi.fn()
LGraphCanvas.prototype.prompt.call(
canvas,
'title',
'value',
promptCallback,
new MouseEvent('mouseup') as CanvasPointerEvent
)
const extendedCallback = prompt.mock.calls[0]?.[2] as
| ((value: string) => void)
| undefined
extendedCallback?.('updated')
expect(promptCallback).toHaveBeenCalledWith('updated')
expect(capture).toHaveBeenCalledTimes(8)
LiteGraph.ContextMenu.prototype.close.call(
{} as InstanceType<typeof LiteGraph.ContextMenu>,
new MouseEvent('mouseup')
)
expect(close).toHaveBeenCalledOnce()
expect(capture).toHaveBeenCalledTimes(9)
} finally {
LGraphCanvas.prototype.processMouseUp = originalProcessMouseUp
LGraphCanvas.prototype.prompt = originalPrompt
LiteGraph.ContextMenu.prototype.close = originalClose
windowAddSpy.mockRestore()
documentAddSpy.mockRestore()
rafSpy.mockRestore()
}
})
it('ignores repeat keydowns and missing active trackers', async () => {
const windowListeners: ListenerMap = {}
const windowAddSpy = vi
.spyOn(window, 'addEventListener')
.mockImplementation((type, listener) => {
storeListener(windowListeners, type, listener)
})
const documentAddSpy = vi
.spyOn(document, 'addEventListener')
.mockImplementation(() => undefined)
const rafSpy = vi
.spyOn(window, 'requestAnimationFrame')
.mockImplementation((callback) => {
callback(0)
return 1
})
try {
ChangeTracker.init()
const tracker = createTracker()
const capture = vi.spyOn(tracker, 'captureCanvasState')
dispatchStored(
windowListeners,
'keydown',
new KeyboardEvent('keydown', { key: 'x', repeat: true })
)
await flushAsyncFrame()
expect(capture).not.toHaveBeenCalled()
mockWorkflowStore.activeWorkflow = null
dispatchStored(
windowListeners,
'keydown',
new KeyboardEvent('keydown', { key: 'x' })
)
await flushAsyncFrame()
expect(capture).not.toHaveBeenCalled()
} finally {
windowAddSpy.mockRestore()
documentAddSpy.mockRestore()
rafSpy.mockRestore()
}
})
it('stores executed outputs for the workflow that owns the prompt', () => {
ChangeTracker.init()
const tracker = createTracker()
const executed = getApiListener('executed')
mockExecutionStore.queuedJobs = {
promptA: { workflow: { changeTracker: tracker } }
}
executed(
new CustomEvent('executed', {
detail: {
prompt_id: 'promptA',
node: '1',
output: { images: ['first'] }
} as unknown as ExecutedWsMessage
})
)
executed(
new CustomEvent('executed', {
detail: {
prompt_id: 'promptA',
node: '1',
merge: true,
output: { images: ['second'], text: ['caption'] }
} as unknown as ExecutedWsMessage
})
)
executed(
new CustomEvent('executed', {
detail: {
prompt_id: 'missing',
node: '2',
output: { images: ['ignored'] }
} as unknown as ExecutedWsMessage
})
)
expect(tracker.nodeOutputs).toEqual({
1: { images: ['first', 'second'], text: ['caption'] }
})
})
it('replaces non-array executed outputs during merge updates', () => {
ChangeTracker.init()
const tracker = createTracker()
const executed = getApiListener('executed')
mockExecutionStore.queuedJobs = {
promptA: { workflow: { changeTracker: tracker } }
}
executed(
new CustomEvent('executed', {
detail: {
prompt_id: 'promptA',
node: '1',
output: { value: 'old' }
} as unknown as ExecutedWsMessage
})
)
executed(
new CustomEvent('executed', {
detail: {
prompt_id: 'promptA',
node: '1',
merge: true,
output: { value: 'new' }
} as unknown as ExecutedWsMessage
})
)
expect(tracker.nodeOutputs).toEqual({
1: { value: 'new' }
})
})
})
describe('graphEqual', () => {
it('compares workflow nodes as an unordered set and ignores extra.ds', () => {
const first = createState(2)
const second = {
...createState(),
nodes: [...first.nodes].reverse(),
links: first.links,
groups: first.groups,
extra: { ds: { scale: 2 } }
} as unknown as ComfyWorkflowJSON
expect(ChangeTracker.graphEqual(first, first)).toBe(true)
expect(ChangeTracker.graphEqual(first, second)).toBe(true)
})
it('returns false for non-object values and meaningful graph differences', () => {
const first = createState(1)
const differentNodes = createState(2)
const differentLinks = {
...first,
links: [[1, 1, 0, 2, 0, 'MODEL']]
} as unknown as ComfyWorkflowJSON
expect(ChangeTracker.graphEqual(first, null as never)).toBe(false)
expect(ChangeTracker.graphEqual(first, differentNodes)).toBe(false)
expect(ChangeTracker.graphEqual(first, differentLinks)).toBe(false)
})
it('returns false for extra properties other than viewport state', () => {
const first = createState()
const second = {
...first,
extra: { custom: true }
} as unknown as ComfyWorkflowJSON
expect(ChangeTracker.graphEqual(first, second)).toBe(false)
})
it.each([
'floatingLinks',
'reroutes',
'groups',
'definitions',
'subgraphs'
] as const)('returns false when %s differs', (key) => {
const first = createState()
const second = {
...first,
[key]: [{ id: 1 }]
} as unknown as ComfyWorkflowJSON
expect(ChangeTracker.graphEqual(first, second)).toBe(false)
})
})
})

View File

@@ -1,11 +1,24 @@
import { describe, expect, test, vi } from 'vitest'
import type { LGraph } from '@/lib/litegraph/src/litegraph'
import { LGraphNode } from '@/lib/litegraph/src/litegraph'
import { ComponentWidgetImpl, DOMWidgetImpl } from '@/scripts/domWidget'
import {
addWidget,
ComponentWidgetImpl,
DOMWidgetImpl,
isComponentWidget,
isDOMWidget
} from '@/scripts/domWidget'
const { registerWidget, unregisterWidget } = vi.hoisted(() => ({
registerWidget: vi.fn(),
unregisterWidget: vi.fn()
}))
vi.mock('@/stores/domWidgetStore', () => ({
useDomWidgetStore: () => ({
unregisterWidget: vi.fn()
registerWidget,
unregisterWidget
})
}))
@@ -24,13 +37,11 @@ describe('DOMWidget Y Position Preservation', () => {
options: {}
})
// Set a specific Y position
originalWidget.y = 66
const newNode = new LGraphNode('new-node')
const clonedWidget = originalWidget.createCopyForNode(newNode)
// Verify Y position is preserved
expect(clonedWidget.y).toBe(66)
expect(clonedWidget.node).toBe(newNode)
expect(clonedWidget.name).toBe('test-widget')
@@ -48,13 +59,11 @@ describe('DOMWidget Y Position Preservation', () => {
options: {}
})
// Set a specific Y position
originalWidget.y = 42
const newNode = new LGraphNode('new-node')
const clonedWidget = originalWidget.createCopyForNode(newNode)
// Verify Y position is preserved
expect(clonedWidget.y).toBe(42)
expect(clonedWidget.node).toBe(newNode)
expect(clonedWidget.element).toBe(mockElement)
@@ -71,11 +80,9 @@ describe('DOMWidget Y Position Preservation', () => {
options: {}
})
// Don't explicitly set Y (should be 0 by default)
const newNode = new LGraphNode('new-node')
const clonedWidget = originalWidget.createCopyForNode(newNode)
// Verify Y position is preserved (should be 0)
expect(clonedWidget.y).toBe(0)
})
})
@@ -96,3 +103,271 @@ describe('BaseDOMWidgetImpl.isVisible', () => {
expect(widget.isVisible()).toBe(false)
})
})
describe('DOMWidgetImpl', () => {
test('identifies DOM and component widgets', () => {
const node = new LGraphNode('test-node')
const domWidget = new DOMWidgetImpl({
node,
name: 'dom',
type: 'text',
element: document.createElement('textarea'),
options: {}
})
const componentWidget = new ComponentWidgetImpl({
node,
name: 'component',
component: { template: '<div />' },
inputSpec: { name: 'component', type: 'STRING' },
options: {}
})
expect(isDOMWidget(domWidget)).toBe(true)
expect(isDOMWidget(componentWidget)).toBe(false)
expect(isComponentWidget(componentWidget)).toBe(true)
expect(isComponentWidget(domWidget)).toBe(false)
})
test('uses option-backed values, callbacks, and margins', () => {
const node = new LGraphNode('test-node')
let value = 'initial'
const setValue = vi.fn((next: string) => {
value = next
})
const callback = vi.fn()
const widget = new DOMWidgetImpl({
node,
name: 'text',
type: 'text',
element: document.createElement('textarea'),
options: {
getValue: () => value,
setValue,
margin: 4
}
})
widget.callback = callback
widget.value = 'next'
expect(widget.value).toBe('next')
expect(widget.margin).toBe(4)
expect(setValue).toHaveBeenCalledWith('next')
expect(callback).toHaveBeenCalledWith('next')
})
test('uses default value and margin when options do not provide them', () => {
const node = new LGraphNode('test-node')
const widget = new DOMWidgetImpl({
node,
name: 'text',
type: 'text',
element: document.createElement('textarea'),
options: {}
})
expect(widget.value).toBe('')
expect(widget.margin).toBe(10)
})
test('draws zoom placeholders and delegates visible draws', () => {
const node = new LGraphNode('test-node')
vi.spyOn(node, 'isWidgetVisible').mockReturnValue(true)
const onDraw = vi.fn()
const widget = new DOMWidgetImpl({
node,
name: 'text',
type: 'text',
element: document.createElement('textarea'),
options: {
hideOnZoom: true,
margin: 5,
onDraw
}
})
const ctx = {
beginPath: vi.fn(),
fill: vi.fn(),
fillStyle: '#000',
rect: vi.fn()
} as unknown as CanvasRenderingContext2D
widget.draw(ctx, node, 100, 10, 40, true)
expect(ctx.rect).toHaveBeenCalledWith(5, 15, 90, 30)
expect(ctx.fill).toHaveBeenCalledOnce()
expect(ctx.fillStyle).toBe('#000')
expect(onDraw).toHaveBeenCalledWith(widget)
})
test('skips placeholder drawing when hidden', () => {
const node = new LGraphNode('test-node')
vi.spyOn(node, 'isWidgetVisible').mockReturnValue(false)
const onDraw = vi.fn()
const widget = new DOMWidgetImpl({
node,
name: 'text',
type: 'text',
element: document.createElement('textarea'),
options: {
hideOnZoom: true,
onDraw
}
})
const ctx = {
beginPath: vi.fn(),
fill: vi.fn(),
fillStyle: '#000',
rect: vi.fn()
} as unknown as CanvasRenderingContext2D
widget.draw(ctx, node, 100, 10, 40, true)
expect(ctx.rect).not.toHaveBeenCalled()
expect(onDraw).toHaveBeenCalledWith(widget)
})
test('computes hidden, option, percent, and fallback layout sizes', () => {
const node = new LGraphNode('test-node')
node.size = [100, 200]
const hiddenWidget = new DOMWidgetImpl({
node,
name: 'hidden',
type: 'hidden',
element: document.createElement('textarea'),
options: {}
})
const optionWidget = new DOMWidgetImpl({
node,
name: 'option',
type: 'text',
element: document.createElement('textarea'),
options: {
getMinHeight: () => 11,
getMaxHeight: () => 88,
getHeight: () => 44
}
})
const percentWidget = new DOMWidgetImpl({
node,
name: 'percent',
type: 'text',
element: document.createElement('textarea'),
options: {
getMinHeight: () => 10,
getMaxHeight: () => 60,
getHeight: () => '25%'
}
})
const fallbackWidget = new DOMWidgetImpl({
node,
name: 'fallback',
type: 'text',
element: document.createElement('textarea'),
options: {
getHeight: () => 40
}
})
expect(hiddenWidget.computeLayoutSize(node)).toEqual({
minHeight: 0,
maxHeight: 0,
minWidth: 0
})
expect(optionWidget.computeLayoutSize(node)).toEqual({
minHeight: 11,
maxHeight: 88,
minWidth: 0
})
expect(percentWidget.computeLayoutSize(node)).toEqual({
minHeight: 10,
maxHeight: 60,
minWidth: 0
})
expect(fallbackWidget.computeLayoutSize(node)).toEqual({
minHeight: 40,
maxHeight: undefined,
minWidth: 0
})
})
test('registers widgets immediately and through node lifecycle callbacks', () => {
registerWidget.mockClear()
unregisterWidget.mockClear()
const node = new LGraphNode('test-node')
node.graph = {} as LGraph
const beforeResize = vi.fn()
const afterResize = vi.fn()
const widget = new DOMWidgetImpl({
node,
name: 'text',
type: 'text',
element: document.createElement('textarea'),
options: {
beforeResize,
afterResize
}
})
vi.spyOn(node, 'addCustomWidget')
addWidget(node, widget)
node.onAdded?.(node.graph)
node.onResize?.([0, 0])
node.onRemoved?.()
expect(node.addCustomWidget).toHaveBeenCalledWith(widget)
expect(registerWidget).toHaveBeenCalledWith(widget)
expect(registerWidget).toHaveBeenCalledTimes(2)
expect(beforeResize).toHaveBeenCalledWith(node)
expect(afterResize).toHaveBeenCalledWith(node)
expect(unregisterWidget).toHaveBeenCalledWith(widget.id)
})
test('computes component layout and serializes raw values', () => {
const node = new LGraphNode('test-node')
const value = { nested: true }
const widget = new ComponentWidgetImpl({
node,
name: 'component',
component: { template: '<div />' },
inputSpec: { name: 'component', type: 'STRING' },
options: {
getValue: () => value,
getMinHeight: () => 12,
getMaxHeight: () => 48
}
})
expect(widget.computeLayoutSize()).toEqual({
minHeight: 12,
maxHeight: 48,
minWidth: 0
})
expect(widget.serializeValue()).toEqual({ nested: true })
})
test('adds DOM widgets through LGraphNode prototype helper', () => {
const node = new LGraphNode('test-node')
const element = document.createElement('textarea')
let value = 'initial'
const setValue = vi.fn((next: string) => {
value = next
})
vi.spyOn(node, 'addCustomWidget')
const widget = node.addDOMWidget('text', 'textarea', element, {
getValue: () => value,
setValue
})
const callback = vi.fn()
widget.callback = callback
widget.value = 'next'
expect(node.addCustomWidget).toHaveBeenCalledWith(widget)
expect(widget.element).toBe(element)
expect(widget.options.hideOnZoom).toBe(true)
expect(widget.value).toBe('next')
expect(setValue).toHaveBeenCalledWith('next')
expect(callback).toHaveBeenCalledWith('next')
})
})

View File

@@ -0,0 +1,106 @@
import { fromAny, fromPartial } from '@total-typescript/shoehorn'
import { describe, expect, it, vi } from 'vitest'
import { LGraphNode } from '@/lib/litegraph/src/litegraph'
import type { IBaseWidget } from '@/lib/litegraph/src/types/widgets'
interface TestWidgetOptions {
name: string
type: string
default?: unknown
multiline?: boolean
}
function createWidgetFactory() {
return (node: LGraphNode, options: TestWidgetOptions): IBaseWidget => {
const widget: IBaseWidget = fromAny({
name: options.name,
type: options.type,
value: options.default,
options
})
node.widgets = [...(node.widgets ?? []), widget]
return widget
}
}
vi.mock(
'@/renderer/extensions/vueNodes/widgets/composables/useStringWidget',
() => ({
useStringWidget: () => createWidgetFactory()
})
)
vi.mock(
'@/renderer/extensions/vueNodes/widgets/composables/useFloatWidget',
() => ({
useFloatWidget: () => createWidgetFactory()
})
)
vi.mock(
'@/renderer/extensions/vueNodes/widgets/composables/useBooleanWidget',
() => ({
useBooleanWidget: () => createWidgetFactory()
})
)
import './errorNodeWidgets'
describe('errorNodeWidgets', () => {
it('restores widgets from serialized values on error nodes', () => {
const node = new LGraphNode('BrokenNode')
const longText = 'serialized value with more than twenty chars'
node.has_errors = true
node.onConfigure?.(
fromAny({
widgets_values: {
length: 5,
0: 'short text',
1: longText,
2: 12,
3: true,
4: { nested: 'value' }
}
})
)
expect(node.widgets).toHaveLength(5)
expect(node.widgets?.map((widget) => widget.name)).toEqual([
'UNKNOWN',
'UNKNOWN_1',
'UNKNOWN_2',
'UNKNOWN_3',
'UNKNOWN_4'
])
expect(node.widgets?.map((widget) => widget.label)).toEqual([
'UNKNOWN',
'UNKNOWN',
'UNKNOWN',
'UNKNOWN',
'UNKNOWN'
])
expect(node.widgets?.map((widget) => widget.value)).toEqual([
'short text',
longText,
12,
true,
'{"nested":"value"}'
])
expect(node.serialize_widgets).toBe(true)
})
it('leaves normal nodes unchanged', () => {
const node = new LGraphNode('HealthyNode')
node.onConfigure?.(
fromPartial({
widgets_values: ['ignored']
})
)
expect(node.widgets).toBeUndefined()
expect(node.serialize_widgets).toBeUndefined()
})
})

View File

@@ -7,7 +7,8 @@ import {
EXPECTED_PROMPT_NAN_COERCED,
EXPECTED_WORKFLOW,
mockFileReaderAbort,
mockFileReaderError
mockFileReaderError,
mockFileReaderResult
} from './__fixtures__/helpers'
import { getFromAvifFile } from './avif'
@@ -83,6 +84,11 @@ describe('AVIF metadata', () => {
mockFileReaderAbort('readAsArrayBuffer')
expect(await getFromAvifFile(file)).toEqual({})
})
it('resolves empty when the FileReader load has no result', async () => {
mockFileReaderResult('readAsArrayBuffer', null)
expect(await getFromAvifFile(file)).toEqual({})
})
})
})
@@ -139,7 +145,8 @@ const buildInfeBox = (
itemType: string,
version = 2
): Uint8Array => {
const bodySize = 4 + 2 + 2 + 4 + 1 + 1
const itemIdSize = version === 2 ? 2 : version >= 3 ? 4 : 0
const bodySize = 4 + itemIdSize + (version >= 2 ? 2 + 4 + 1 + 1 : 0)
const totalSize = 8 + bodySize
const buf = new Uint8Array(totalSize)
const dv = new DataView(buf.buffer)
@@ -147,22 +154,36 @@ const buildInfeBox = (
buf.set(new TextEncoder().encode('infe'), 4)
buf[8] = version
if (version >= 2) {
setU16BE(dv, 12, itemId)
setU16BE(dv, 14, 0)
buf.set(new TextEncoder().encode(itemType.padEnd(4).slice(0, 4)), 16)
let p = 12
if (version === 2) {
setU16BE(dv, p, itemId)
p += 2
} else {
setU32BE(dv, p, itemId)
p += 4
}
setU16BE(dv, p, 0)
p += 2
buf.set(new TextEncoder().encode(itemType.padEnd(4).slice(0, 4)), p)
}
return buf
}
const buildIinfBox = (infeBoxes: Uint8Array[]): Uint8Array => {
const bodySize = 4 + 2 + infeBoxes.reduce((s, b) => s + b.length, 0)
const buildIinfBox = (infeBoxes: Uint8Array[], version = 0): Uint8Array => {
const countSize = version === 0 ? 2 : 4
const bodySize = 4 + countSize + infeBoxes.reduce((s, b) => s + b.length, 0)
const totalSize = 8 + bodySize
const buf = new Uint8Array(totalSize)
const dv = new DataView(buf.buffer)
setU32BE(dv, 0, totalSize)
buf.set(new TextEncoder().encode('iinf'), 4)
setU16BE(dv, 12, infeBoxes.length)
let off = 14
buf[8] = version
if (version === 0) {
setU16BE(dv, 12, infeBoxes.length)
} else {
setU32BE(dv, 12, infeBoxes.length)
}
let off = 8 + 4 + countSize
for (const ib of infeBoxes) {
buf.set(ib, off)
off += ib.length
@@ -170,31 +191,91 @@ const buildIinfBox = (infeBoxes: Uint8Array[]): Uint8Array => {
return buf
}
interface IlocItem {
itemId: number
extentOffset: number
extentLength: number
extents?: Array<{ extentOffset: number; extentLength: number }>
}
const buildIlocBox = (
items: { itemId: number; extentOffset: number; extentLength: number }[]
items: IlocItem[],
{
version = 0,
baseOffsetSize = 0,
indexSize = 0
}: { version?: number; baseOffsetSize?: number; indexSize?: number } = {}
): Uint8Array => {
const perItemSize = 2 + 2 + 0 + 2 + (4 + 4)
const bodySize = 4 + 1 + 1 + 2 + items.length * perItemSize
const itemCountSize = version < 2 ? 2 : 4
const itemIdSize = version < 2 ? 2 : 4
const constructionMethodSize = version === 1 || version === 2 ? 2 : 0
const itemSizes = items.map((item) => {
const extents = item.extents ?? [
{
extentOffset: item.extentOffset,
extentLength: item.extentLength
}
]
return (
itemIdSize +
constructionMethodSize +
2 +
baseOffsetSize +
2 +
extents.length * (indexSize + 4 + 4)
)
})
const bodySize =
4 + 1 + 1 + itemCountSize + itemSizes.reduce((sum, size) => sum + size, 0)
const totalSize = 8 + bodySize
const buf = new Uint8Array(totalSize)
const dv = new DataView(buf.buffer)
setU32BE(dv, 0, totalSize)
buf.set(new TextEncoder().encode('iloc'), 4)
buf[8] = version
buf[12] = 0x44
buf[13] = 0x00
setU16BE(dv, 14, items.length)
let p = 16
for (const it of items) {
setU16BE(dv, p, it.itemId)
buf[13] = (baseOffsetSize << 4) | indexSize
let p = 14
if (version < 2) {
setU16BE(dv, p, items.length)
p += 2
} else {
setU32BE(dv, p, items.length)
p += 4
}
for (const it of items) {
if (version < 2) {
setU16BE(dv, p, it.itemId)
p += 2
} else {
setU32BE(dv, p, it.itemId)
p += 4
}
if (version === 1 || version === 2) {
setU16BE(dv, p, 0)
p += 2
}
setU16BE(dv, p, 0)
p += 2
setU16BE(dv, p, 1)
if (baseOffsetSize > 0) {
setU32BE(dv, p, 0)
p += baseOffsetSize
}
const extents = it.extents ?? [
{
extentOffset: it.extentOffset,
extentLength: it.extentLength
}
]
setU16BE(dv, p, extents.length)
p += 2
setU32BE(dv, p, it.extentOffset)
p += 4
setU32BE(dv, p, it.extentLength)
p += 4
for (const extent of extents) {
p += indexSize
setU32BE(dv, p, extent.extentOffset)
p += 4
setU32BE(dv, p, extent.extentLength)
p += 4
}
}
return buf
}
@@ -231,7 +312,13 @@ interface BuildAvifOpts {
ftypBrand?: string
omitMeta?: boolean
omitIloc?: boolean
iinfVersion?: number
infeVersion?: number
ilocVersion?: number
ilocBaseOffsetSize?: number
ilocIndexSize?: number
ilocExtents?: Array<{ extentOffset: number; extentLength: number }>
rawExifData?: Uint8Array
}
const buildAvifFile = (opts: BuildAvifOpts = {}): ArrayBuffer => {
@@ -242,7 +329,13 @@ const buildAvifFile = (opts: BuildAvifOpts = {}): ArrayBuffer => {
ftypBrand = 'avif',
omitMeta = false,
omitIloc = false,
infeVersion = 2
iinfVersion = 0,
infeVersion = 2,
ilocVersion = 0,
ilocBaseOffsetSize = 0,
ilocIndexSize = 0,
ilocExtents,
rawExifData
} = opts
const ftyp = buildFtypBox(ftypBrand)
@@ -250,19 +343,43 @@ const buildAvifFile = (opts: BuildAvifOpts = {}): ArrayBuffer => {
return ftyp.slice().buffer as ArrayBuffer
}
const exifData = buildExifBlob(exifEntries, endian)
const exifData = rawExifData ?? buildExifBlob(exifEntries, endian)
const infe = buildInfeBox(1, itemType, infeVersion)
const iinf = buildIinfBox([infe])
const iinf = buildIinfBox([infe], iinfVersion)
const realIloc = buildIlocBox([
{ itemId: 1, extentOffset: 0, extentLength: exifData.length }
])
const realIloc = buildIlocBox(
[
{
itemId: 1,
extentOffset: 0,
extentLength: exifData.length,
extents: ilocExtents
}
],
{
version: ilocVersion,
baseOffsetSize: ilocBaseOffsetSize,
indexSize: ilocIndexSize
}
)
const metaSize = 8 + 4 + iinf.length + (omitIloc ? 0 : realIloc.length)
const exifOffset = ftyp.length + metaSize
const finalIloc = buildIlocBox([
{ itemId: 1, extentOffset: exifOffset, extentLength: exifData.length }
])
const finalIloc = buildIlocBox(
[
{
itemId: 1,
extentOffset: exifOffset,
extentLength: exifData.length,
extents: ilocExtents
}
],
{
version: ilocVersion,
baseOffsetSize: ilocBaseOffsetSize,
indexSize: ilocIndexSize
}
)
const finalInner = omitIloc ? [iinf] : [iinf, finalIloc]
const meta = buildMetaBox(finalInner)
@@ -319,6 +436,52 @@ describe('getFromAvifFile', () => {
expect(result.workflow).toBe(JSON.stringify(JSON.parse(workflow)))
})
it('extracts EXIF metadata from versioned item info and location boxes', async () => {
const workflow = '{"versioned":true}'
const file = fileFromBuffer(
buildAvifFile({
exifEntries: [`workflow:${workflow}`],
iinfVersion: 1,
infeVersion: 3,
ilocVersion: 2,
ilocBaseOffsetSize: 4,
ilocIndexSize: 4
})
)
const result = await getFromAvifFile(file)
expect(result.workflow).toBe(JSON.stringify(JSON.parse(workflow)))
})
it('returns {} when the Exif item has no extents', async () => {
const file = fileFromBuffer(
buildAvifFile({
exifEntries: ['workflow:{}'],
ilocExtents: []
})
)
const result = await getFromAvifFile(file)
expect(result).toEqual({})
})
it('returns {} when the Exif payload has no TIFF header', async () => {
const file = fileFromBuffer(
buildAvifFile({
rawExifData: new TextEncoder().encode('not tiff data')
})
)
const result = await getFromAvifFile(file)
expect(result).toEqual({})
expect(console.log).toHaveBeenCalledWith(
'Warning: TIFF header not found in EXIF data.'
)
})
it('returns {} when AVIF major brand is not "avif"', async () => {
const file = fileFromBuffer(
buildAvifFile({ exifEntries: ['workflow:{}'], ftypBrand: 'heic' })

View File

@@ -7,7 +7,8 @@ import {
EXPECTED_PROMPT_NAN_COERCED,
EXPECTED_WORKFLOW,
mockFileReaderAbort,
mockFileReaderError
mockFileReaderError,
mockFileReaderResult
} from './__fixtures__/helpers'
import { getFromWebmFile } from './ebml'
@@ -16,6 +17,56 @@ const nanFixturePath = path.resolve(
__dirname,
'__fixtures__/with_nan_metadata.webm'
)
const WEBM_SIGNATURE = new Uint8Array([0x1a, 0x45, 0xdf, 0xa3])
const SIMPLE_TAG = new Uint8Array([0x67, 0xc8])
const TAG_NAME = new Uint8Array([0x45, 0xa3])
const TAG_VALUE = new Uint8Array([0x44, 0x87])
const encoder = new TextEncoder()
function concatBytes(...chunks: Uint8Array[]) {
const totalLength = chunks.reduce((sum, chunk) => sum + chunk.length, 0)
const result = new Uint8Array(totalLength)
let offset = 0
for (const chunk of chunks) {
result.set(chunk, offset)
offset += chunk.length
}
return result
}
function bytes(...values: number[]) {
return new Uint8Array(values)
}
function vint(value: number) {
return bytes(0x80 | value)
}
function element(id: Uint8Array, value: string) {
const encoded = encoder.encode(value)
return concatBytes(id, vint(encoded.length), encoded)
}
function simpleTag(name: string, value: string, useTwoByteSize = false) {
const payload = concatBytes(
element(TAG_NAME, name),
element(TAG_VALUE, value)
)
const size = useTwoByteSize
? bytes(0x40, payload.length)
: vint(payload.length)
return concatBytes(SIMPLE_TAG, size, payload)
}
async function readWebm(bytes: Uint8Array) {
return getFromWebmFile(
new File([bytes as Uint8Array<ArrayBuffer>], 'test.webm', {
type: 'video/webm'
})
)
}
describe('WebM/EBML metadata', () => {
it('extracts workflow and prompt from EBML SimpleTag elements', async () => {
@@ -46,6 +97,89 @@ describe('WebM/EBML metadata', () => {
expect(result).toEqual({})
})
it('extracts plain string tags and trims null-terminated values', async () => {
const result = await readWebm(
concatBytes(
WEBM_SIGNATURE,
simpleTag('\0Comment', ' hello \0ignored', true)
)
)
expect(result.comment).toBe('hello')
})
it('ignores prompt tags whose value is not complete JSON', async () => {
const result = await readWebm(
concatBytes(WEBM_SIGNATURE, simpleTag('PROMPT', '{not-json'))
)
expect(result.prompt).toBeUndefined()
})
it('ignores prompt tags whose value has no JSON object', async () => {
const result = await readWebm(
concatBytes(WEBM_SIGNATURE, simpleTag('PROMPT', 'not-json'))
)
expect(result.prompt).toBeUndefined()
})
it('parses the first complete JSON object from a prompt tag', async () => {
const result = await readWebm(
concatBytes(
WEBM_SIGNATURE,
simpleTag('PROMPT', 'prefix {"outer":{"inner":1}} trailing')
)
)
expect(result.prompt).toEqual({ outer: { inner: 1 } })
})
it('ignores tags whose name has no readable text', async () => {
const payload = concatBytes(
concatBytes(TAG_NAME, vint(2), bytes(0, 1)),
element(TAG_VALUE, 'value')
)
const result = await readWebm(
concatBytes(WEBM_SIGNATURE, SIMPLE_TAG, vint(payload.length), payload)
)
expect(result).toEqual({})
})
it('ignores tag elements with zero-sized names', async () => {
const payload = concatBytes(TAG_NAME, vint(0), element(TAG_VALUE, 'value'))
const result = await readWebm(
concatBytes(WEBM_SIGNATURE, SIMPLE_TAG, vint(payload.length), payload)
)
expect(result).toEqual({})
})
it('ignores malformed SimpleTag encodings', async () => {
const nameOnly = element(TAG_NAME, 'comment')
await expect(
readWebm(concatBytes(WEBM_SIGNATURE, SIMPLE_TAG))
).resolves.toEqual({})
await expect(
readWebm(concatBytes(WEBM_SIGNATURE, SIMPLE_TAG, bytes(0x00)))
).resolves.toEqual({})
await expect(
readWebm(concatBytes(WEBM_SIGNATURE, SIMPLE_TAG, bytes(0x7f, 0xff)))
).resolves.toEqual({})
await expect(
readWebm(concatBytes(WEBM_SIGNATURE, SIMPLE_TAG, vint(10), TAG_NAME))
).resolves.toEqual({})
await expect(
readWebm(
concatBytes(WEBM_SIGNATURE, SIMPLE_TAG, vint(nameOnly.length), nameOnly)
)
).resolves.toEqual({})
})
describe('FileReader failure modes', () => {
afterEach(() => vi.restoreAllMocks())
@@ -60,5 +194,10 @@ describe('WebM/EBML metadata', () => {
mockFileReaderAbort('readAsArrayBuffer')
expect(await getFromWebmFile(file)).toEqual({})
})
it('resolves empty when the FileReader load has no result', async () => {
mockFileReaderResult('readAsArrayBuffer', null)
expect(await getFromWebmFile(file)).toEqual({})
})
})
})

View File

@@ -1,3 +1,4 @@
import { fromAny } from '@total-typescript/shoehorn'
import { afterEach, describe, expect, it, vi } from 'vitest'
import { ASCII, GltfSizeBytes } from '@/types/metadataTypes'
@@ -5,7 +6,8 @@ import { ASCII, GltfSizeBytes } from '@/types/metadataTypes'
import {
EXPECTED_PROMPT_NAN_COERCED,
mockFileReaderAbort,
mockFileReaderError
mockFileReaderError,
mockFileReaderResult
} from './__fixtures__/helpers'
import { getGltfBinaryMetadata } from './gltf'
@@ -16,11 +18,15 @@ describe('GLTF binary metadata parser', () => {
return { header, headerView }
}
const createJSONChunk = (jsonData: ArrayBuffer) => {
const createJSONChunk = (
jsonData: ArrayBuffer,
chunkType = ASCII.JSON,
chunkLength = jsonData.byteLength
) => {
const chunkHeader = new ArrayBuffer(GltfSizeBytes.CHUNK_HEADER)
const chunkView = new DataView(chunkHeader)
chunkView.setUint32(0, jsonData.byteLength, true)
chunkView.setUint32(4, ASCII.JSON, true)
chunkView.setUint32(0, chunkLength, true)
chunkView.setUint32(4, chunkType, true)
return chunkHeader
}
@@ -52,13 +58,27 @@ describe('GLTF binary metadata parser', () => {
// Builds a GLB whose JSON chunk is the literal text passed in - used to
// embed Python generated bare NaN/Infinity tokens that JSON.stringify
// would otherwise coerce to null.
function createMockGltfFileFromText(jsonText: string): File {
interface MockGltfFileOptions {
chunkLength?: number
chunkType?: number
magicNumber?: number
}
function createMockGltfFileFromText(
jsonText: string,
{
chunkLength,
chunkType = ASCII.JSON,
magicNumber = ASCII.GLTF
}: MockGltfFileOptions = {}
): File {
const jsonData = new TextEncoder().encode(jsonText)
const { header, headerView } = createGLTFFileStructure()
setHeaders(headerView, jsonData.buffer)
setTypeHeader(headerView, magicNumber)
const chunkHeader = createJSONChunk(jsonData.buffer)
const chunkHeader = createJSONChunk(jsonData.buffer, chunkType, chunkLength)
const fileContent = new Uint8Array(
header.byteLength + chunkHeader.byteLength + jsonData.byteLength
@@ -130,7 +150,9 @@ describe('GLTF binary metadata parser', () => {
expect(metadata).toBeDefined()
expect(metadata.prompt).toBeDefined()
const prompt = metadata.prompt as Record<string, any>
const prompt: {
node1: { class_type: string; inputs: { seed: number } }
} = fromAny(metadata.prompt)
expect(prompt.node1.class_type).toBe('TestNode')
expect(prompt.node1.inputs.seed).toBe(123456)
})
@@ -179,6 +201,74 @@ describe('GLTF binary metadata parser', () => {
expect(metadata).toEqual({})
})
it('returns empty when the GLB magic number is invalid', async () => {
const mockFile = createMockGltfFileFromText('{}', { magicNumber: 0 })
const metadata = await getGltfBinaryMetadata(mockFile)
expect(metadata).toEqual({})
})
it('returns empty when the GLB is missing the first chunk header', async () => {
const { header, headerView } = createGLTFFileStructure()
setTypeHeader(headerView, ASCII.GLTF)
setVersionHeader(headerView, 2)
setTotalLengthHeader(headerView, GltfSizeBytes.HEADER)
const mockFile = new File([header], 'header-only.glb')
const metadata = await getGltfBinaryMetadata(mockFile)
expect(metadata).toEqual({})
})
it('returns empty when the first chunk is not JSON', async () => {
const mockFile = createMockGltfFileFromText('{}', { chunkType: 0 })
const metadata = await getGltfBinaryMetadata(mockFile)
expect(metadata).toEqual({})
})
it('returns empty when the declared JSON chunk exceeds the buffer', async () => {
const mockFile = createMockGltfFileFromText('{}', { chunkLength: 1024 })
const metadata = await getGltfBinaryMetadata(mockFile)
expect(metadata).toEqual({})
})
it('returns empty when the JSON chunk cannot be parsed', async () => {
const mockFile = createMockGltfFileFromText('{not json')
const metadata = await getGltfBinaryMetadata(mockFile)
expect(metadata).toEqual({})
})
it('returns empty when asset extras are missing', async () => {
const mockFile = createMockGltfFile({ asset: { version: '2.0' } })
const metadata = await getGltfBinaryMetadata(mockFile)
expect(metadata).toEqual({})
})
it('skips string metadata values that are not valid JSON', async () => {
const mockFile = createMockGltfFile({
asset: {
version: '2.0',
extras: {
prompt: '{not json',
workflow: '{not json'
}
}
})
const metadata = await getGltfBinaryMetadata(mockFile)
expect(metadata).toEqual({})
})
describe('FileReader failure modes', () => {
afterEach(() => vi.restoreAllMocks())
@@ -193,5 +283,15 @@ describe('GLTF binary metadata parser', () => {
mockFileReaderAbort('readAsArrayBuffer')
expect(await getGltfBinaryMetadata(file)).toEqual({})
})
it('resolves empty when the FileReader load result is missing', async () => {
mockFileReaderResult('readAsArrayBuffer', null)
expect(await getGltfBinaryMetadata(file)).toEqual({})
})
it('resolves empty when the FileReader result is not a buffer', async () => {
mockFileReaderResult('readAsArrayBuffer', 'not a buffer')
expect(await getGltfBinaryMetadata(file)).toEqual({})
})
})
})

View File

@@ -7,7 +7,8 @@ import {
EXPECTED_PROMPT_NAN_COERCED,
EXPECTED_WORKFLOW,
mockFileReaderAbort,
mockFileReaderError
mockFileReaderError,
mockFileReaderResult
} from './__fixtures__/helpers'
import { getFromIsobmffFile } from './isobmff'
@@ -16,6 +17,82 @@ const nanFixturePath = path.resolve(
__dirname,
'__fixtures__/with_nan_metadata.mp4'
)
const encoder = new TextEncoder()
function uint32(value: number) {
return new Uint8Array([
(value >>> 24) & 0xff,
(value >>> 16) & 0xff,
(value >>> 8) & 0xff,
value & 0xff
])
}
function concatBytes(...chunks: Uint8Array[]) {
const totalLength = chunks.reduce((sum, chunk) => sum + chunk.length, 0)
const result = new Uint8Array(totalLength)
let offset = 0
for (const chunk of chunks) {
result.set(chunk, offset)
offset += chunk.length
}
return result
}
function box(type: string, payload = new Uint8Array(), size?: number) {
return concatBytes(
uint32(size ?? 8 + payload.length),
encoder.encode(type),
payload
)
}
function keyEntry(name: string) {
const encoded = encoder.encode(name)
return concatBytes(
uint32(8 + encoded.length),
encoder.encode('mdta'),
encoded
)
}
function keysBox(names: string[]) {
return box(
'keys',
concatBytes(uint32(0), uint32(names.length), ...names.map(keyEntry))
)
}
function dataBox(value: string | Uint8Array) {
const payload = typeof value === 'string' ? encoder.encode(value) : value
return box('data', concatBytes(uint32(0), uint32(0), payload))
}
function ilstItem(index: number, payload: Uint8Array) {
return concatBytes(uint32(8 + payload.length), uint32(index), payload)
}
function ilstBox(...items: Uint8Array[]) {
return box('ilst', concatBytes(...items))
}
function metaBox(...children: Uint8Array[]) {
return box('meta', concatBytes(uint32(0), ...children))
}
function udtaWithMeta(...children: Uint8Array[]) {
return box('udta', metaBox(...children))
}
async function readMp4(bytes: Uint8Array) {
return getFromIsobmffFile(
new File([bytes as Uint8Array<ArrayBuffer>], 'test.mp4', {
type: 'video/mp4'
})
)
}
describe('ISOBMFF (MP4) metadata', () => {
it('extracts workflow and prompt from QuickTime keys/ilst boxes', async () => {
@@ -48,6 +125,102 @@ describe('ISOBMFF (MP4) metadata', () => {
expect(result).toEqual({})
})
it('extracts metadata from udta nested inside moov', async () => {
const bytes = box(
'moov',
udtaWithMeta(
keysBox(['WORKFLOW']),
ilstBox(ilstItem(1, dataBox('xxxx{"nodes":[]}')))
)
)
const result = await readMp4(bytes)
expect(result.workflow).toEqual({ nodes: [] })
})
it('returns empty when a top-level box declares an impossible size', async () => {
const result = await readMp4(box('free', new Uint8Array([1, 2]), 100))
expect(result).toEqual({})
})
it('returns empty when the keys box cannot provide entries', async () => {
const tooShortKeys = box('keys', uint32(0))
const missingKeyEntry = box(
'keys',
concatBytes(uint32(0), uint32(1), uint32(8))
)
const malformedKey = box(
'keys',
concatBytes(uint32(0), uint32(1), uint32(7), encoder.encode('bad!'))
)
const oversizedKey = box(
'keys',
concatBytes(uint32(0), uint32(1), uint32(100), encoder.encode('bad!'))
)
await expect(readMp4(udtaWithMeta(tooShortKeys))).resolves.toEqual({})
await expect(readMp4(udtaWithMeta(missingKeyEntry))).resolves.toEqual({})
await expect(readMp4(udtaWithMeta(malformedKey))).resolves.toEqual({})
await expect(readMp4(udtaWithMeta(oversizedKey))).resolves.toEqual({})
})
it('ignores item entries whose key is unknown or unsupported', async () => {
const unknownIndex = udtaWithMeta(
keysBox(['PROMPT']),
ilstBox(ilstItem(2, dataBox('{"1":{}}')))
)
const unsupportedKey = udtaWithMeta(
keysBox(['DESCRIPTION']),
ilstBox(ilstItem(1, dataBox('{"ignored":true}')))
)
await expect(readMp4(unknownIndex)).resolves.toEqual({})
await expect(readMp4(unsupportedKey)).resolves.toEqual({})
})
it('ignores metadata items without readable JSON data', async () => {
const shortDataBox = box('data')
const noJson = dataBox('not-json')
const invalidJson = dataBox('prefix {not-json')
const noDataBox = box('free', new Uint8Array([1]))
const invalidItems = ilstBox(
concatBytes(uint32(8), uint32(1)),
concatBytes(uint32(100), uint32(1), invalidJson)
)
await expect(
readMp4(
udtaWithMeta(keysBox(['PROMPT']), ilstBox(ilstItem(1, shortDataBox)))
)
).resolves.toEqual({})
await expect(
readMp4(udtaWithMeta(keysBox(['PROMPT']), ilstBox(ilstItem(1, noJson))))
).resolves.toEqual({})
await expect(
readMp4(
udtaWithMeta(keysBox(['PROMPT']), ilstBox(ilstItem(1, invalidJson)))
)
).resolves.toEqual({})
await expect(
readMp4(
udtaWithMeta(keysBox(['PROMPT']), ilstBox(ilstItem(1, noDataBox)))
)
).resolves.toEqual({})
await expect(
readMp4(udtaWithMeta(keysBox(['PROMPT']), invalidItems))
).resolves.toEqual({})
})
it('returns empty when required metadata boxes are absent', async () => {
await expect(readMp4(box('udta', box('free')))).resolves.toEqual({})
await expect(readMp4(udtaWithMeta(ilstBox()))).resolves.toEqual({})
await expect(readMp4(udtaWithMeta(keysBox(['PROMPT'])))).resolves.toEqual(
{}
)
})
describe('FileReader failure modes', () => {
afterEach(() => vi.restoreAllMocks())
@@ -63,5 +236,10 @@ describe('ISOBMFF (MP4) metadata', () => {
mockFileReaderAbort('readAsArrayBuffer')
expect(await getFromIsobmffFile(file)).toEqual({})
})
it('resolves empty when the FileReader load has no result', async () => {
mockFileReaderResult('readAsArrayBuffer', null)
expect(await getFromIsobmffFile(file)).toEqual({})
})
})
})

View File

@@ -0,0 +1,127 @@
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { getFromWebmFile } from '@/scripts/metadata/ebml'
import { getGltfBinaryMetadata } from '@/scripts/metadata/gltf'
import { getFromIsobmffFile } from '@/scripts/metadata/isobmff'
import { getDataFromJSON } from '@/scripts/metadata/json'
import { getMp3Metadata } from '@/scripts/metadata/mp3'
import { getOggMetadata } from '@/scripts/metadata/ogg'
import { getWorkflowDataFromFile } from '@/scripts/metadata/parser'
import { getSvgMetadata } from '@/scripts/metadata/svg'
import {
getAvifMetadata,
getFlacMetadata,
getLatentMetadata,
getPngMetadata,
getWebpMetadata
} from '@/scripts/pnginfo'
vi.mock('@/scripts/metadata/ebml', () => ({ getFromWebmFile: vi.fn() }))
vi.mock('@/scripts/metadata/gltf', () => ({ getGltfBinaryMetadata: vi.fn() }))
vi.mock('@/scripts/metadata/isobmff', () => ({ getFromIsobmffFile: vi.fn() }))
vi.mock('@/scripts/metadata/json', () => ({ getDataFromJSON: vi.fn() }))
vi.mock('@/scripts/metadata/mp3', () => ({ getMp3Metadata: vi.fn() }))
vi.mock('@/scripts/metadata/ogg', () => ({ getOggMetadata: vi.fn() }))
vi.mock('@/scripts/metadata/svg', () => ({ getSvgMetadata: vi.fn() }))
vi.mock('@/scripts/pnginfo', () => ({
getAvifMetadata: vi.fn(),
getFlacMetadata: vi.fn(),
getLatentMetadata: vi.fn(),
getPngMetadata: vi.fn(),
getWebpMetadata: vi.fn()
}))
function file(type: string, name = 'file') {
return new File(['data'], name, { type })
}
beforeEach(() => {
vi.clearAllMocks()
})
describe('getWorkflowDataFromFile', () => {
it('routes png/avif/mp3/ogg/webm to their parsers and returns the result', async () => {
vi.mocked(getPngMetadata).mockResolvedValue({ a: 1 } as never)
expect(await getWorkflowDataFromFile(file('image/png'))).toEqual({ a: 1 })
expect(getPngMetadata).toHaveBeenCalled()
await getWorkflowDataFromFile(file('image/avif'))
expect(getAvifMetadata).toHaveBeenCalled()
await getWorkflowDataFromFile(file('audio/mpeg'))
expect(getMp3Metadata).toHaveBeenCalled()
await getWorkflowDataFromFile(file('audio/ogg'))
expect(getOggMetadata).toHaveBeenCalled()
await getWorkflowDataFromFile(file('video/webm'))
expect(getFromWebmFile).toHaveBeenCalled()
})
it('extracts workflow/prompt from webp, preferring lowercase keys', async () => {
vi.mocked(getWebpMetadata).mockResolvedValue({
workflow: 'wf',
prompt: 'pr'
} as never)
expect(await getWorkflowDataFromFile(file('image/webp'))).toEqual({
workflow: 'wf',
prompt: 'pr'
})
})
it('falls back to capitalized webp keys when lowercase are absent', async () => {
vi.mocked(getWebpMetadata).mockResolvedValue({
Workflow: 'WF',
Prompt: 'PR'
} as never)
expect(await getWorkflowDataFromFile(file('image/webp'))).toEqual({
workflow: 'WF',
prompt: 'PR'
})
})
it('handles both flac mime types and extracts workflow/prompt', async () => {
vi.mocked(getFlacMetadata).mockResolvedValue({ workflow: 'w' } as never)
expect(await getWorkflowDataFromFile(file('audio/flac'))).toEqual({
workflow: 'w',
prompt: undefined
})
expect(await getWorkflowDataFromFile(file('audio/x-flac'))).toEqual({
workflow: 'w',
prompt: undefined
})
})
it('routes isobmff by mime type and by file extension', async () => {
await getWorkflowDataFromFile(file('video/mp4'))
await getWorkflowDataFromFile(file('', 'clip.mov'))
await getWorkflowDataFromFile(file('', 'clip.m4v'))
expect(getFromIsobmffFile).toHaveBeenCalledTimes(3)
})
it('routes svg and gltf by mime type or extension', async () => {
await getWorkflowDataFromFile(file('image/svg+xml'))
await getWorkflowDataFromFile(file('', 'icon.svg'))
expect(getSvgMetadata).toHaveBeenCalledTimes(2)
await getWorkflowDataFromFile(file('model/gltf-binary'))
await getWorkflowDataFromFile(file('', 'model.glb'))
expect(getGltfBinaryMetadata).toHaveBeenCalledTimes(2)
})
it('routes latent/safetensors and json by extension or mime type', async () => {
await getWorkflowDataFromFile(file('', 'x.latent'))
await getWorkflowDataFromFile(file('', 'x.safetensors'))
expect(getLatentMetadata).toHaveBeenCalledTimes(2)
await getWorkflowDataFromFile(file('application/json'))
await getWorkflowDataFromFile(file('', 'x.json'))
expect(getDataFromJSON).toHaveBeenCalledTimes(2)
})
it('returns undefined for an unrecognized file', async () => {
expect(
await getWorkflowDataFromFile(file('application/zip', 'a.zip'))
).toBe(undefined)
})
})

View File

@@ -2,7 +2,8 @@ import { afterEach, describe, expect, it, vi } from 'vitest'
import {
mockFileReaderAbort,
mockFileReaderError
mockFileReaderError,
mockFileReaderResult
} from './__fixtures__/helpers'
import { getFromPngBuffer, getFromPngFile } from './png'
@@ -191,6 +192,27 @@ describe('getFromPngBuffer', () => {
const result = await getFromPngBuffer(buffer)
expect(result['workflow']).toBe(workflow)
})
it('logs error and skips compressed iTXt with invalid deflate data', async () => {
vi.spyOn(console, 'error').mockImplementation(() => {})
const buffer = createPngWithChunk(
'iTXt',
'workflow',
new Uint8Array([1, 2, 3]),
{
compressionFlag: 1,
compressionMethod: 0
}
)
const result = await getFromPngBuffer(buffer)
expect(result['workflow']).toBeUndefined()
expect(console.error).toHaveBeenCalledWith(
expect.stringContaining('Failed to decompress iTXt chunk "workflow"'),
expect.anything()
)
})
})
describe('getFromPngFile', () => {
@@ -228,5 +250,12 @@ describe('getFromPngFile', () => {
mockFileReaderAbort('readAsArrayBuffer')
await expect(getFromPngFile(file)).rejects.toThrow('FileReader aborted')
})
it('rejects when the FileReader load has no ArrayBuffer result', async () => {
mockFileReaderResult('readAsArrayBuffer', null)
await expect(getFromPngFile(file)).rejects.toThrow(
'Failed to read file as ArrayBuffer'
)
})
})
})

View File

@@ -1,13 +1,19 @@
import fs from 'fs'
import path from 'path'
import { afterEach, describe, expect, it, vi } from 'vitest'
import { fromPartial } from '@total-typescript/shoehorn'
import type { LGraph, LGraphNode } from '@/lib/litegraph/src/litegraph'
import { LiteGraph } from '@/lib/litegraph/src/litegraph'
import type { IBaseWidget } from '@/lib/litegraph/src/types/widgets'
import { api } from '@/scripts/api'
import { getFromAvifFile } from './metadata/avif'
import { getFromFlacFile } from './metadata/flac'
import { getFromPngFile } from './metadata/png'
import {
getAvifMetadata,
getFlacMetadata,
importA1111,
getLatentMetadata,
getPngMetadata,
getWebpMetadata
@@ -56,6 +62,20 @@ function encodeAsciiIfd(entries: AsciiIfdEntry[]): Uint8Array {
return buf
}
function encodeNonAsciiIfdEntry(tag: number): Uint8Array {
const buf = new Uint8Array(22)
const dv = new DataView(buf.buffer)
buf.set([0x49, 0x49], 0)
dv.setUint16(2, 0x002a, true)
dv.setUint32(4, 8, true)
dv.setUint16(8, 1, true)
dv.setUint16(10, tag, true)
dv.setUint16(12, 3, true)
dv.setUint32(14, 1, true)
dv.setUint32(18, 123, true)
return buf
}
type WebpChunk = { type: string; payload: Uint8Array }
function wrapInWebp(chunks: WebpChunk[]): File {
@@ -157,6 +177,16 @@ describe('getWebpMetadata', () => {
expect(metadata).toEqual({ workflow: '{"a":1}' })
})
it('ignores EXIF entries that are not ASCII strings', async () => {
const file = wrapInWebp([
{ type: 'EXIF', payload: encodeNonAsciiIfdEntry(270) }
])
const metadata = await getWebpMetadata(file)
expect(metadata).toEqual({})
})
})
describe('getLatentMetadata', () => {
@@ -234,3 +264,313 @@ describe('format-specific metadata wrappers', () => {
expect(result).toEqual({ workflow: '{"avif":1}' })
})
})
describe('importA1111', () => {
function widget(
name: string,
options: string[] = []
): IBaseWidget & { value?: string | number } {
return fromPartial<IBaseWidget & { value?: string | number }>({
name,
options: { values: options },
value: undefined
})
}
function createNode(type: string): LGraphNode {
const widgetsByType: Record<string, IBaseWidget[]> = {
CheckpointLoaderSimple: [widget('ckpt_name', ['sd15.safetensors'])],
CLIPSetLastLayer: [widget('stop_at_clip_layer')],
CLIPTextEncode: [widget('text')],
EmptyLatentImage: [widget('width'), widget('height')],
ImageScale: [widget('width'), widget('height')],
ImageUpscaleWithModel: [],
KSampler: [
widget('cfg'),
widget('sampler_name', ['euler_a', 'dpmpp_2m']),
widget('scheduler', ['normal', 'karras']),
widget('seed'),
widget('steps'),
widget('denoise')
],
LatentUpscale: [
widget('upscale_method', ['nearest-exact']),
widget('width'),
widget('height')
],
LoraLoader: [
widget('lora_name', ['foo.safetensors']),
widget('strength_model'),
widget('strength_clip')
],
UpscaleModelLoader: [widget('model_name', ['ESRGAN'])],
VAEEncodeTiled: [],
VAEDecodeTiled: [],
SaveImage: [],
VAEDecode: []
}
return {
type,
widgets: widgetsByType[type] ?? [],
connect: vi.fn()
} as unknown as LGraphNode
}
function createGraph(): LGraph {
return fromPartial<LGraph>({
add: vi.fn(),
arrange: vi.fn(),
clear: vi.fn()
})
}
function findWidget(node: LGraphNode, name: string) {
return node.widgets?.find((widget) => widget.name === name)
}
it('ignores text without parsed generation settings', async () => {
const graph = createGraph()
vi.spyOn(api, 'getEmbeddings').mockResolvedValue([])
await importA1111(graph, 'positive prompt only')
await importA1111(graph, 'positive prompt\nSteps:\n')
expect(graph.clear).not.toHaveBeenCalled()
})
it('ignores text without a negative prompt section', async () => {
const graph = createGraph()
vi.spyOn(api, 'getEmbeddings').mockResolvedValue([])
await importA1111(
graph,
['positive prompt only', 'Steps: 20, Sampler: Euler a'].join('\n')
)
expect(graph.clear).not.toHaveBeenCalled()
})
it('stops when a required base node cannot be created', async () => {
const graph = createGraph()
vi.spyOn(console, 'error').mockImplementation(() => {})
vi.spyOn(api, 'getEmbeddings').mockResolvedValue([])
vi.spyOn(LiteGraph, 'createNode').mockImplementation((type) =>
type === 'KSampler' ? null : createNode(type)
)
await importA1111(
graph,
[
'prompt',
'Negative prompt: blurry',
'Steps: 20, Sampler: Euler a, Size: 512x512'
].join('\n')
)
expect(graph.clear).not.toHaveBeenCalled()
expect(console.error).toHaveBeenCalledWith(
'Failed to create required nodes for A1111 import'
)
})
it('builds a basic graph from A1111 parameters', async () => {
const graph = createGraph()
const nodes: LGraphNode[] = []
vi.spyOn(console, 'warn').mockImplementation(() => {})
vi.spyOn(api, 'getEmbeddings').mockResolvedValue(['easynegative'])
vi.spyOn(LiteGraph, 'createNode').mockImplementation((type) => {
const node = createNode(type)
nodes.push(node)
return node
})
await importA1111(
graph,
[
'<lora:foo:0.7> portrait easynegative',
'Negative prompt: blurry <lora:bad:not-number>',
'Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 42, Size: 512x512, Model: sd15, Clip skip: 2, Model hash: ignored'
].join('\n')
)
const checkpoint = nodes.find(
(node) => node.type === 'CheckpointLoaderSimple'
)
const clipSkip = nodes.find((node) => node.type === 'CLIPSetLastLayer')
const sampler = nodes.find((node) => node.type === 'KSampler')
const image = nodes.find((node) => node.type === 'EmptyLatentImage')
const lora = nodes.find((node) => node.type === 'LoraLoader')
const textNodes = nodes.filter((node) => node.type === 'CLIPTextEncode')
expect(graph.clear).toHaveBeenCalledOnce()
expect(graph.arrange).toHaveBeenCalledOnce()
expect(findWidget(checkpoint!, 'ckpt_name')?.value).toBe('sd15.safetensors')
expect(findWidget(clipSkip!, 'stop_at_clip_layer')?.value).toBe(-2)
expect(findWidget(sampler!, 'cfg')?.value).toBe(7)
expect(findWidget(sampler!, 'sampler_name')?.value).toBe('euler_a')
expect(findWidget(sampler!, 'scheduler')?.value).toBe('normal')
expect(findWidget(sampler!, 'seed')?.value).toBe(42)
expect(findWidget(sampler!, 'steps')?.value).toBe(20)
expect(findWidget(image!, 'width')?.value).toBe(512)
expect(findWidget(image!, 'height')?.value).toBe(512)
expect(findWidget(lora!, 'lora_name')?.value).toBe('foo.safetensors')
expect(findWidget(lora!, 'strength_model')?.value).toBe(0.7)
expect(findWidget(lora!, 'strength_clip')?.value).toBe(0.7)
expect(findWidget(textNodes[0], 'text')?.value).toBe(
' portrait embedding:easynegative'
)
expect(findWidget(textNodes[1], 'text')?.value).toBe('blurry ')
})
it('keeps unknown option-prefix values and logs the mismatch', async () => {
const graph = createGraph()
const nodes: LGraphNode[] = []
vi.spyOn(console, 'warn').mockImplementation(() => {})
vi.spyOn(api, 'getEmbeddings').mockResolvedValue([])
vi.spyOn(LiteGraph, 'createNode').mockImplementation((type) => {
const node = createNode(type)
nodes.push(node)
return node
})
await importA1111(
graph,
[
'portrait',
'Negative prompt: blurry',
'Steps: 20, Sampler: Unknown Sampler, CFG scale: 7, Seed: 42, Size: 512x512, Model: unknown-model'
].join('\n')
)
const checkpoint = nodes.find(
(node) => node.type === 'CheckpointLoaderSimple'
)
expect(findWidget(checkpoint!, 'ckpt_name')?.value).toBe('unknown-model')
expect(console.warn).toHaveBeenCalledWith(
"Unknown value 'unknown-model' for widget 'ckpt_name'",
checkpoint
)
})
it('skips missing LoraLoader nodes while keeping prompt text cleaned', async () => {
const graph = createGraph()
const nodes: LGraphNode[] = []
vi.spyOn(console, 'warn').mockImplementation(() => {})
vi.spyOn(api, 'getEmbeddings').mockResolvedValue([])
vi.spyOn(LiteGraph, 'createNode').mockImplementation((type) => {
if (type === 'LoraLoader') return null
const node = createNode(type)
nodes.push(node)
return node
})
await importA1111(
graph,
[
'<lora:missing:0.5> portrait',
'Negative prompt: blurry',
'Steps: 20, Sampler: Euler a, Size: 512x512'
].join('\n')
)
const textNodes = nodes.filter((node) => node.type === 'CLIPTextEncode')
expect(findWidget(textNodes[0], 'text')?.value).toBe(' portrait')
})
it('returns from latent hires setup when LatentUpscale cannot be created', async () => {
const graph = createGraph()
const nodes: LGraphNode[] = []
vi.spyOn(api, 'getEmbeddings').mockResolvedValue([])
vi.spyOn(LiteGraph, 'createNode').mockImplementation((type) => {
if (type === 'LatentUpscale') return null
const node = createNode(type)
nodes.push(node)
return node
})
await importA1111(
graph,
[
'portrait',
'Negative prompt: blurry',
'Steps: 8, Sampler: Euler a, Size: 512x512, Hires upscale: 2, Hires upscaler: Latent'
].join('\n')
)
expect(nodes.some((node) => node.type === 'KSampler')).toBe(true)
expect(nodes.some((node) => node.type === 'LatentUpscale')).toBe(false)
})
it('builds a latent hires pass with explicit resize and denoise settings', async () => {
const graph = createGraph()
const nodes: LGraphNode[] = []
vi.spyOn(console, 'warn').mockImplementation(() => {})
vi.spyOn(api, 'getEmbeddings').mockResolvedValue([])
vi.spyOn(LiteGraph, 'createNode').mockImplementation((type) => {
const node = createNode(type)
nodes.push(node)
return node
})
await importA1111(
graph,
[
'portrait',
'Negative prompt: blurry',
'Steps: 12, Sampler: DPM++ 2M Karras, CFG scale: 5, Seed: 1, Size: 513x577, Model: sd15, Hires resize: 1025x1089, Hires steps: 4, Hires upscaler: Latent (nearest-exact), Denoising strength: 0.35'
].join('\n')
)
const image = nodes.find((node) => node.type === 'EmptyLatentImage')
const latentUpscale = nodes.find((node) => node.type === 'LatentUpscale')
const samplers = nodes.filter((node) => node.type === 'KSampler')
expect(findWidget(image!, 'width')?.value).toBe(576)
expect(findWidget(image!, 'height')?.value).toBe(640)
expect(findWidget(latentUpscale!, 'upscale_method')?.value).toBe(
'nearest-exact'
)
expect(findWidget(latentUpscale!, 'width')?.value).toBe(1088)
expect(findWidget(latentUpscale!, 'height')?.value).toBe(1152)
expect(findWidget(samplers[0], 'scheduler')?.value).toBe('karras')
expect(findWidget(samplers[0], 'sampler_name')?.value).toBe('dpmpp_2m')
expect(findWidget(samplers[1], 'steps')?.value).toBe(4)
expect(findWidget(samplers[1], 'cfg')?.value).toBe(5)
expect(findWidget(samplers[1], 'scheduler')?.value).toBe('karras')
expect(findWidget(samplers[1], 'sampler_name')?.value).toBe('dpmpp_2m')
expect(findWidget(samplers[1], 'denoise')?.value).toBe(0.35)
})
it('builds an image upscaler hires pass with fallback steps and denoise', async () => {
const graph = createGraph()
const nodes: LGraphNode[] = []
vi.spyOn(console, 'warn').mockImplementation(() => {})
vi.spyOn(api, 'getEmbeddings').mockResolvedValue([])
vi.spyOn(LiteGraph, 'createNode').mockImplementation((type) => {
const node = createNode(type)
nodes.push(node)
return node
})
await importA1111(
graph,
[
'portrait',
'Negative prompt: blurry',
'Steps: 8, Sampler: Euler a, CFG scale: 6, Seed: 2, Size: 512x512, Model: sd15, Hires upscale: 1.5, Hires upscaler: ESRGAN'
].join('\n')
)
const upscaleLoader = nodes.find(
(node) => node.type === 'UpscaleModelLoader'
)
const imageScale = nodes.find((node) => node.type === 'ImageScale')
const samplers = nodes.filter((node) => node.type === 'KSampler')
expect(findWidget(upscaleLoader!, 'model_name')?.value).toBe('ESRGAN')
expect(findWidget(imageScale!, 'width')?.value).toBe(768)
expect(findWidget(imageScale!, 'height')?.value).toBe(768)
expect(findWidget(samplers[1], 'steps')?.value).toBe(8)
expect(findWidget(samplers[1], 'denoise')?.value).toBe(1)
})
})

450
src/scripts/ui.test.ts Normal file
View File

@@ -0,0 +1,450 @@
import { fromAny, fromPartial } from '@total-typescript/shoehorn'
import { beforeEach, describe, expect, it, vi } from 'vitest'
const { mockSettingStore } = vi.hoisted(() => ({
mockSettingStore: {
get: vi.fn()
}
}))
vi.mock('@/composables/useRunButtonTelemetry', () => ({
useRunButtonTelemetry: () => ({ trackRunButton: vi.fn() })
}))
vi.mock('@/platform/remote/comfyui/jobs/fetchJobs', () => ({
extractWorkflow: vi.fn()
}))
vi.mock('@/platform/settings/composables/useSettingsDialog', () => ({
useSettingsDialog: () => ({ show: vi.fn() })
}))
vi.mock('@/platform/settings/settingStore', () => ({
useSettingStore: () => mockSettingStore
}))
vi.mock('@/platform/telemetry', () => ({
useTelemetry: () => ({ trackWorkflowExecution: vi.fn() })
}))
vi.mock('@/services/litegraphService', () => ({
useLitegraphService: () => ({ resetView: vi.fn() })
}))
vi.mock('@/stores/commandStore', () => ({
useCommandStore: () => ({ execute: vi.fn() })
}))
vi.mock('@/stores/workspaceStore', () => ({
useWorkspaceStore: () => ({ focusMode: false })
}))
vi.mock('./api', () => ({
api: {
addEventListener: vi.fn(),
clearItems: vi.fn(),
deleteItem: vi.fn(),
dispatchCustomEvent: vi.fn(),
getHistory: vi.fn(),
getJobDetail: vi.fn(),
getQueue: vi.fn(),
interrupt: vi.fn()
}
}))
vi.mock('./app', () => ({
app: {
clean: vi.fn(),
handleFile: vi.fn(),
loadGraphData: vi.fn(),
openClipspace: vi.fn(),
queuePrompt: vi.fn(),
refreshComboInNodes: vi.fn(() => Promise.resolve())
},
ComfyApp: class ComfyApp {}
}))
vi.mock('./ui/dialog', () => ({
ComfyDialog: class ComfyDialog {}
}))
vi.mock('./ui/settings', () => ({
ComfySettingsDialog: class ComfySettingsDialog {}
}))
vi.mock('./ui/toggleSwitch', () => ({
toggleSwitch: () => document.createElement('div')
}))
import { extractWorkflow } from '@/platform/remote/comfyui/jobs/fetchJobs'
import { api } from './api'
import { app } from './app'
import { $el, ComfyUI } from './ui'
beforeEach(() => {
document.body.replaceChildren()
localStorage.clear()
vi.clearAllMocks()
mockSettingStore.get.mockReturnValue(undefined)
Object.assign(app, {
lastExecutionError: undefined,
nodeOutputs: undefined
})
})
async function click(button: HTMLButtonElement) {
const handler = button.onclick
expect(handler).toBeTypeOf('function')
await Promise.resolve(handler?.call(button, fromAny(new MouseEvent('click'))))
}
function buttonByText(root: ParentNode, text: string): HTMLButtonElement {
const button = [...root.querySelectorAll('button')].find(
(candidate) => candidate.textContent === text
)
if (!button) throw new Error(`Missing button: ${text}`)
return button
}
describe('$el', () => {
it('creates elements with classes, children, props, and callbacks', () => {
const parent = document.createElement('section')
const child = document.createElement('span')
const callback = vi.fn()
const element = $el(
'label.primary.secondary',
{
parent,
$: callback,
dataset: { role: 'name' },
for: 'target-input',
style: { display: 'block' },
title: 'Label'
},
child
)
expect(element.tagName).toBe('LABEL')
expect(element.classList.contains('primary')).toBe(true)
expect(element.classList.contains('secondary')).toBe(true)
expect(element.dataset.role).toBe('name')
expect(element.getAttribute('for')).toBe('target-input')
expect(element.style.display).toBe('block')
expect(element.title).toBe('Label')
expect(element.firstElementChild).toBe(child)
expect(parent.firstElementChild).toBe(element)
expect(callback).toHaveBeenCalledWith(element)
})
it('accepts string and single-element children shorthands', () => {
const textElement = $el('button', 'Run')
const child = document.createElement('strong')
const wrapper = $el('div', child)
expect(textElement.textContent).toBe('Run')
expect(wrapper.firstElementChild).toBe(child)
})
})
describe('ComfyUI legacy menu', () => {
it('loads queue items and runs list actions', async () => {
vi.mocked(api.getQueue).mockResolvedValue({
Running: [{ id: 'running', priority: 1 }],
Pending: [{ id: 'pending', priority: 2 }]
} as never)
vi.mocked(api.getJobDetail).mockResolvedValue({
outputs: { node: { images: ['image.png'] } }
} as never)
vi.mocked(extractWorkflow).mockResolvedValue({ nodes: [] } as never)
const ui = new ComfyUI(app)
await ui.queue.show()
await click(buttonByText(ui.queue.element, 'Load'))
expect(api.getJobDetail).toHaveBeenCalledWith('running')
expect(extractWorkflow).toHaveBeenCalled()
expect(app.loadGraphData).toHaveBeenCalledWith({ nodes: [] }, true, false)
expect(app.nodeOutputs).toEqual({ node: { images: ['image.png'] } })
await click(buttonByText(ui.queue.element, 'Cancel'))
expect(api.interrupt).toHaveBeenCalledWith('running')
await click(buttonByText(ui.queue.element, 'Delete'))
expect(api.deleteItem).toHaveBeenCalledWith('queue', 'pending')
await click(buttonByText(ui.queue.element, 'Clear Queue'))
expect(api.clearItems).toHaveBeenCalledWith('queue')
await click(buttonByText(ui.queue.element, 'Refresh'))
expect(api.getQueue).toHaveBeenCalled()
})
it('skips loading queue items when job details are unavailable', async () => {
vi.mocked(api.getQueue).mockResolvedValue({
Running: [{ id: 'running', priority: 1 }],
Pending: []
} as never)
vi.mocked(api.getJobDetail).mockResolvedValue(null as never)
const ui = new ComfyUI(app)
await ui.queue.show()
await click(buttonByText(ui.queue.element, 'Load'))
expect(extractWorkflow).not.toHaveBeenCalled()
expect(app.loadGraphData).not.toHaveBeenCalled()
})
it('loads queue item workflows without outputs', async () => {
vi.mocked(api.getQueue).mockResolvedValue({
Running: [{ id: 'running', priority: 1 }],
Pending: []
} as never)
vi.mocked(api.getJobDetail).mockResolvedValue({ id: 'running' } as never)
vi.mocked(extractWorkflow).mockResolvedValue({ nodes: [] } as never)
const ui = new ComfyUI(app)
await ui.queue.show()
await click(buttonByText(ui.queue.element, 'Load'))
expect(app.loadGraphData).toHaveBeenCalledWith({ nodes: [] }, true, false)
expect(app.nodeOutputs).toBeUndefined()
})
it('loads history in reverse order', async () => {
vi.mocked(api.getHistory).mockResolvedValue([
{ id: 'old', priority: 1 },
{ id: 'new', priority: 2 }
] as never)
const ui = new ComfyUI(app)
await ui.history.show()
expect(ui.history.element.textContent).toContain('2: LoadDelete')
expect(ui.history.element.textContent?.indexOf('2:')).toBeLessThan(
ui.history.element.textContent?.indexOf('1:') ?? Number.MAX_SAFE_INTEGER
)
})
it('updates queue status and auto-queues when enabled', () => {
const ui = new ComfyUI(app)
ui.autoQueueEnabled = true
ui.autoQueueMode = 'instant'
ui.lastQueueSize = 1
ui.batchCount = 3
ui.setStatus({ exec_info: { queue_remaining: 0 } })
expect(app.queuePrompt).toHaveBeenCalledWith(0, 3)
expect(ui.queueSize.textContent).toBe('Queue size: 0')
expect(ui.lastQueueSize).toBe(3)
ui.setStatus(null)
expect(ui.queueSize.textContent).toBe('Queue size: ERR')
})
it('does not auto-queue while a prior execution error is present', () => {
const ui = new ComfyUI(app)
ui.autoQueueEnabled = true
ui.autoQueueMode = 'instant'
ui.lastQueueSize = 1
Object.assign(app, { lastExecutionError: new Error('failed') })
ui.setStatus({ exec_info: { queue_remaining: 0 } })
expect(app.queuePrompt).not.toHaveBeenCalled()
expect(ui.lastQueueSize).toBe(0)
})
it('tracks graph changes for change-mode auto queueing', () => {
const ui = new ComfyUI(app)
const graphChanged = vi
.mocked(api.addEventListener)
.mock.calls.find(([eventName]) => eventName === 'graphChanged')?.[1]
if (!graphChanged) throw new Error('Missing graphChanged listener')
ui.autoQueueEnabled = true
ui.autoQueueMode = 'change'
ui.lastQueueSize = 1
graphChanged(fromAny(new CustomEvent('graphChanged')))
expect(ui.graphHasChanged).toBe(true)
ui.lastQueueSize = 0
graphChanged(fromAny(new CustomEvent('graphChanged')))
expect(app.queuePrompt).toHaveBeenCalledWith(0, 1)
expect(ui.graphHasChanged).toBe(false)
})
it('wires primary menu buttons to app and command actions', async () => {
const ui = new ComfyUI(app)
await click(buttonByText(document, 'Queue Prompt'))
expect(app.queuePrompt).toHaveBeenCalledWith(0, 1)
await click(buttonByText(document, 'Queue Front'))
expect(app.queuePrompt).toHaveBeenCalledWith(-1, 1)
await click(buttonByText(document, 'Save'))
await click(buttonByText(document, 'Save (API Format)'))
await click(buttonByText(document, 'Refresh'))
await click(buttonByText(document, 'Clipspace'))
await click(buttonByText(document, 'Clear'))
await click(buttonByText(document, 'Load Default'))
await click(buttonByText(document, 'Reset View'))
expect(app.refreshComboInNodes).toHaveBeenCalled()
expect(app.openClipspace).toHaveBeenCalled()
expect(app.clean).toHaveBeenCalled()
expect(app.loadGraphData).toHaveBeenCalledWith()
expect(ui.menuContainer.style.display).toBe('none')
})
it('wires file input and legacy option controls', async () => {
const ui = new ComfyUI(app)
const file = new File(['{}'], 'workflow.json', {
type: 'application/json'
})
const fileInput = document.getElementById(
'comfy-file-input'
) as HTMLInputElement
Object.defineProperty(fileInput, 'files', {
configurable: true,
value: [file]
})
await Promise.resolve(
fileInput.onchange?.call(fileInput, new Event('change'))
)
expect(app.handleFile).toHaveBeenCalledWith(file, 'file_button')
expect(fileInput.value).toBe('')
const range = document.getElementById(
'batchCountInputRange'
) as HTMLInputElement
const number = document.getElementById(
'batchCountInputNumber'
) as HTMLInputElement
range.value = '4'
const extraOptionsCheckbox = document.querySelector(
'label input[type="checkbox"]'
) as HTMLInputElement
extraOptionsCheckbox.checked = true
extraOptionsCheckbox.onchange?.call(
extraOptionsCheckbox,
fromPartial({ srcElement: extraOptionsCheckbox })
)
expect(ui.batchCount).toBe(4)
expect(document.getElementById('extraOptions')?.style.display).toBe('block')
number.value = '7'
number.oninput?.call(number, fromPartial({ target: number }))
expect(range.value).toBe('7')
range.value = '9'
range.oninput?.call(range, fromPartial({ srcElement: range }))
expect(number.value).toBe('9')
const autoQueueCheckbox = document.getElementById(
'autoQueueCheckbox'
) as HTMLInputElement
autoQueueCheckbox.checked = true
autoQueueCheckbox.onchange?.call(
autoQueueCheckbox,
fromPartial({ target: autoQueueCheckbox })
)
expect(ui.autoQueueEnabled).toBe(true)
extraOptionsCheckbox.checked = false
extraOptionsCheckbox.onchange?.call(
extraOptionsCheckbox,
fromPartial({ srcElement: extraOptionsCheckbox })
)
expect(ui.batchCount).toBe(1)
expect(ui.autoQueueEnabled).toBe(false)
expect(document.getElementById('extraOptions')?.style.display).toBe('none')
})
it('toggles queue visibility through the menu button', async () => {
vi.mocked(api.getQueue).mockResolvedValue({
Running: [],
Pending: []
} as never)
const ui = new ComfyUI(app)
await click(buttonByText(document, 'View Queue'))
expect(ui.queue.visible).toBe(true)
await click(buttonByText(document, 'Close'))
expect(ui.queue.visible).toBe(false)
})
it('does not clear or load defaults when confirmation is declined', async () => {
mockSettingStore.get.mockReturnValue(true)
vi.stubGlobal(
'confirm',
vi.fn(() => false)
)
try {
new ComfyUI(app)
await click(buttonByText(document, 'Clear'))
await click(buttonByText(document, 'Load Default'))
expect(app.clean).not.toHaveBeenCalled()
expect(app.loadGraphData).not.toHaveBeenCalled()
} finally {
vi.unstubAllGlobals()
}
})
it('persists manual menu dragging', () => {
Object.defineProperty(document.body, 'clientWidth', {
configurable: true,
value: 1000
})
Object.defineProperty(document.body, 'clientHeight', {
configurable: true,
value: 800
})
const ui = new ComfyUI(app)
ui.menuContainer.style.display = 'block'
Object.defineProperty(ui.menuContainer, 'clientWidth', {
configurable: true,
value: 100
})
Object.defineProperty(ui.menuContainer, 'clientHeight', {
configurable: true,
value: 80
})
Object.defineProperty(ui.menuContainer, 'offsetLeft', {
configurable: true,
get: () => 700
})
Object.defineProperty(ui.menuContainer, 'offsetTop', {
configurable: true,
get: () => 20
})
const handle = ui.menuContainer.querySelector('.drag-handle') as HTMLElement
handle.onmousedown?.(
new MouseEvent('mousedown', { clientX: 10, clientY: 10 })
)
document.onmousemove?.(
new MouseEvent('mousemove', { clientX: 20, clientY: 30 })
)
document.onmouseup?.(new MouseEvent('mouseup'))
expect(ui.menuContainer.classList.contains('comfy-menu-manual-pos')).toBe(
true
)
expect(ui.menuContainer.style.right).toBe('190px')
expect(localStorage.getItem('Comfy.MenuPosition')).toBe(
JSON.stringify({ x: 700, y: 20 })
)
expect(document.onmousemove).toBeNull()
expect(document.onmouseup).toBeNull()
})
})

View File

@@ -0,0 +1,199 @@
import { fromAny } from '@total-typescript/shoehorn'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import type { ComfyApp } from '@/scripts/app'
vi.mock('../../ui', () => ({
$el: (tag: string, props?: Record<string, unknown>, children?: Node[]) => {
const [tagName, ...classes] = tag.split('.')
const element = document.createElement(tagName)
if (classes.length) element.classList.add(...classes)
if (props) {
const listeners = Object.entries(props).filter(([key]) =>
key.startsWith('on')
)
for (const [key, listener] of listeners) {
if (typeof listener === 'function') {
element.addEventListener(key.slice(2), listener as EventListener)
}
}
}
if (children) element.append(...children)
return element
}
}))
vi.mock('../../utils', () => ({
prop: <T>(
target: object,
name: string,
defaultValue: T,
onChanged?: (currentValue: T, previousValue: T) => void
) => {
let currentValue: T
Object.defineProperty(target, name, {
get() {
return currentValue
},
set(newValue: T) {
const previousValue = currentValue
currentValue = newValue
onChanged?.(currentValue, previousValue)
}
})
;(target as Record<string, T>)[name] = defaultValue
return defaultValue
}
}))
import { ComfyButton } from './button'
class MockPopup extends EventTarget {
element = document.createElement('div')
open = false
toggle = vi.fn(() => {
this.open = !this.open
this.dispatchEvent(new CustomEvent('change'))
})
}
function mockApp(settingValue: boolean) {
let listener: (() => void) | undefined
const app = {
ui: {
settings: {
getSettingValue: vi.fn(() => settingValue),
addEventListener: vi.fn((_event: string, callback: () => void) => {
listener = callback
})
}
}
} as unknown as ComfyApp
return {
app,
setSettingValue(value: boolean) {
settingValue = value
listener?.()
}
}
}
describe('ComfyButton', () => {
beforeEach(() => {
document.body.replaceChildren()
})
it('renders icon, content, tooltip, enabled state, and click action', () => {
const action = vi.fn()
const button = new ComfyButton({
icon: 'play',
overIcon: 'pause',
iconSize: 18,
content: 'Run',
tooltip: 'Queue prompt',
enabled: false,
classList: { primary: true, hiddenClass: false },
action
})
expect(button.iconElement.className).toBe('mdi mdi-play mdi-18px')
expect(button.contentElement.textContent).toBe('Run')
expect(button.element.title).toBe('Queue prompt')
expect(button.element.getAttribute('aria-label')).toBe('Queue prompt')
expect(button.element.classList.contains('primary')).toBe(true)
expect(button.element.classList.contains('disabled')).toBe(true)
expect((button.element as HTMLButtonElement).disabled).toBe(true)
button.enabled = true
button.element.dispatchEvent(new MouseEvent('mouseenter'))
expect(button.iconElement.className).toBe('mdi mdi-pause mdi-18px')
button.element.dispatchEvent(new MouseEvent('mouseleave'))
expect(button.iconElement.className).toBe('mdi mdi-play mdi-18px')
button.element.dispatchEvent(new MouseEvent('click'))
expect(action).toHaveBeenCalledWith(expect.any(MouseEvent), button)
})
it('supports HTMLElement content and removing tooltip text', () => {
const button = new ComfyButton({ content: 'Text', tooltip: 'Hint' })
const content = document.createElement('strong')
content.textContent = 'Element'
button.content = content
button.tooltip = ''
expect(button.contentElement.firstElementChild).toBe(content)
expect(button.element.hasAttribute('title')).toBe(false)
})
it('updates the hover icon when overIcon changes while hovered', () => {
const button = new ComfyButton({ icon: 'play' })
button.element.dispatchEvent(new MouseEvent('mouseenter'))
button.overIcon = 'pause'
expect(button.iconElement.className).toBe('mdi mdi-pause')
})
it('hides and shows from a visibility setting', () => {
const settings = mockApp(false)
const button = new ComfyButton({
app: settings.app,
visibilitySetting: {
id: 'Comfy.UseNewMenu',
showValue: true
}
})
expect(button.hidden).toBe(true)
expect(button.element.classList.contains('hidden')).toBe(true)
settings.setSettingValue(true)
expect(button.hidden).toBe(false)
expect(button.element.classList.contains('hidden')).toBe(false)
})
it('toggles click popups and reflects popup open state in classes', () => {
const popup = new MockPopup()
const button = new ComfyButton({ icon: 'dots' }).withPopup(fromAny(popup))
button.element.dispatchEvent(new MouseEvent('click'))
expect(popup.toggle).toHaveBeenCalledOnce()
expect(button.element.classList.contains('popup-open')).toBe(true)
popup.toggle()
expect(button.element.classList.contains('popup-closed')).toBe(true)
})
it('opens hover popups while either the button or popup is hovered', () => {
const popup = new MockPopup()
const button = new ComfyButton({ icon: 'dots' }).withPopup(
fromAny(popup),
'hover'
)
button.element.dispatchEvent(new MouseEvent('mouseenter'))
expect(popup.open).toBe(true)
popup.element.dispatchEvent(new MouseEvent('mouseenter'))
button.element.dispatchEvent(new MouseEvent('mouseleave'))
expect(popup.open).toBe(true)
popup.element.dispatchEvent(new MouseEvent('mouseleave'))
expect(popup.open).toBe(false)
})
it('does not click-toggle a hover popup while hovered', () => {
const popup = new MockPopup()
const button = new ComfyButton({ icon: 'dots' }).withPopup(
fromAny(popup),
'hover'
)
button.element.dispatchEvent(new MouseEvent('mouseenter'))
button.element.dispatchEvent(new MouseEvent('click'))
expect(popup.toggle).not.toHaveBeenCalled()
})
})

View File

@@ -0,0 +1,284 @@
import { fromAny } from '@total-typescript/shoehorn'
import { beforeEach, describe, expect, it, vi } from 'vitest'
type ElChild = Node | string
type ElInput = Record<string, unknown> | ElChild | ElChild[]
function appendChildren(element: HTMLElement, children: ElChild | ElChild[]) {
const list = Array.isArray(children) ? children : [children]
for (const child of list) {
element.append(child)
}
}
vi.mock('../../ui', () => ({
$el: (tag: string, propsOrChildren?: ElInput, children?: ElChild[]) => {
const [tagName, ...classes] = tag.split('.')
const element = document.createElement(tagName)
if (classes.length) element.classList.add(...classes)
if (
propsOrChildren instanceof Node ||
typeof propsOrChildren === 'string' ||
Array.isArray(propsOrChildren)
) {
appendChildren(element, propsOrChildren)
} else if (propsOrChildren) {
for (const [key, value] of Object.entries(propsOrChildren)) {
if (key === '$' && typeof value === 'function') {
value(element)
} else if (key === 'parent' && value instanceof HTMLElement) {
value.append(element)
} else if (key === 'textContent') {
element.textContent = String(value)
} else if (key === 'ariaLabel') {
element.setAttribute('aria-label', String(value))
} else if (key === 'ariaHasPopup') {
element.setAttribute('aria-haspopup', String(value))
} else if (key === 'type') {
element.setAttribute('type', String(value))
} else if (
key.toLowerCase().startsWith('on') &&
typeof value === 'function'
) {
element.addEventListener(
key.slice(2).toLowerCase(),
value as EventListener
)
}
}
}
if (children) appendChildren(element, children)
return element
}
}))
vi.mock('../../utils', () => ({
prop: <T>(
target: object,
name: string,
defaultValue: T,
onChanged?: (currentValue: T, previousValue: T) => void
) => {
let currentValue: T
Object.defineProperty(target, name, {
get() {
return currentValue
},
set(newValue: T) {
const previousValue = currentValue
currentValue = newValue
onChanged?.(currentValue, previousValue)
}
})
;(target as Record<string, T>)[name] = defaultValue
return defaultValue
}
}))
vi.mock('../utils', () => ({
applyClasses: (
element: HTMLElement,
classList: string | Record<string, boolean>,
...baseClasses: string[]
) => {
element.className = baseClasses.join(' ')
if (typeof classList === 'string') {
element.classList.add(...classList.split(' ').filter(Boolean))
} else {
for (const [className, enabled] of Object.entries(classList)) {
element.classList.toggle(className, enabled)
}
}
},
toggleElement:
<T>(
element: HTMLElement,
{
onShow
}: {
onShow?: (element: HTMLElement, value: T) => void
} = {}
) =>
(value: T) => {
element.hidden = !value
if (value) onShow?.(element, value)
}
}))
import { ComfyAsyncDialog } from './asyncDialog'
import { ComfyButton } from './button'
import { ComfyButtonGroup } from './buttonGroup'
import { ComfyPopup } from './popup'
import { ComfySplitButton } from './splitButton'
function targetWithRect(rect: DOMRect) {
const target = document.createElement('button')
vi.spyOn(target, 'getBoundingClientRect').mockReturnValue(rect)
document.body.append(target)
return target
}
describe('ComfyPopup and related UI components', () => {
beforeEach(() => {
document.body.replaceChildren()
vi.restoreAllMocks()
})
it('opens, positions, updates children and classes, and closes from escape', () => {
const target = targetWithRect(
DOMRect.fromRect({ x: 10, y: 20, width: 80, height: 30 })
)
const child = document.createElement('span')
const popup = new ComfyPopup(
{ target, classList: { menu: true, hidden: false } },
child
)
const open = vi.fn()
const close = vi.fn()
const change = vi.fn()
popup.addEventListener('open', open)
popup.addEventListener('close', close)
popup.addEventListener('change', change)
vi.spyOn(popup.element, 'getBoundingClientRect').mockReturnValue(
DOMRect.fromRect({ height: 20 })
)
popup.open = true
expect(open).toHaveBeenCalledOnce()
expect(change).toHaveBeenCalledOnce()
expect(popup.element).toHaveClass('open')
expect(popup.element.style.getPropertyValue('--left')).toBe('10px')
expect(popup.element.style.getPropertyValue('--bottom')).toBe('35px')
const nextChild = document.createElement('strong')
popup.children = [nextChild]
popup.classList = 'extra'
expect(popup.element.firstElementChild).toBe(nextChild)
expect(popup.element).toHaveClass('comfyui-popup', 'left', 'extra')
window.dispatchEvent(new KeyboardEvent('keydown', { key: 'Enter' }))
expect(popup.open).toBe(true)
const escape = new KeyboardEvent('keydown', {
key: 'Escape',
cancelable: true
})
window.dispatchEvent(escape)
expect(close).toHaveBeenCalledOnce()
expect(escape.defaultPrevented).toBe(true)
expect(popup.open).toBe(false)
})
it('handles outside clicks, target clicks, and relative right positioning', () => {
const target = targetWithRect(
DOMRect.fromRect({ x: 100, y: 40, width: 60, height: 25 })
)
const container = document.createElement('section')
document.body.append(container)
const popup = new ComfyPopup({
target,
container,
position: 'relative',
horizontal: 'right'
})
vi.spyOn(popup.element, 'getBoundingClientRect').mockReturnValue(
DOMRect.fromRect({ height: 100 })
)
Object.defineProperty(popup.element, 'clientWidth', {
configurable: true,
value: 40
})
popup.open = true
target.dispatchEvent(new MouseEvent('click', { bubbles: true }))
expect(popup.open).toBe(true)
const outside = document.createElement('div')
document.body.append(outside)
outside.dispatchEvent(new MouseEvent('click', { bubbles: true }))
expect(popup.open).toBe(false)
expect(popup.element.style.getPropertyValue('--left')).toBe('0px')
expect(popup.element.style.getPropertyValue('--top')).toBe('25px')
})
it('keeps outside clicks open when target clicks are not ignored', () => {
const target = targetWithRect(DOMRect.fromRect({ height: 10 }))
const popup = new ComfyPopup({
target,
ignoreTarget: false,
closeOnEscape: false
})
const outside = document.createElement('div')
document.body.append(outside)
popup.toggle()
outside.dispatchEvent(new MouseEvent('click', { bubbles: true }))
window.dispatchEvent(new KeyboardEvent('keydown', { key: 'Escape' }))
expect(popup.open).toBe(true)
})
it('renders split button popup items and updates button groups', () => {
const primary = new ComfyButton({ content: 'Queue' })
const itemButton = new ComfyButton({ content: 'Queue front' })
const rawItem = document.createElement('button')
rawItem.textContent = 'Queue back'
const split = new ComfySplitButton(
{
primary,
mode: 'hover',
horizontal: 'right',
position: 'absolute'
},
itemButton,
rawItem
)
expect(split.element).toHaveClass('comfyui-split-button', 'hover')
expect(split.popup.element).toHaveTextContent('Queue front')
expect(split.popup.element).toHaveTextContent('Queue back')
expect(document.body).toContainElement(split.popup.element)
const group = new ComfyButtonGroup(primary)
group.append(itemButton)
group.insert(fromAny(rawItem), 1)
expect(group.element.children).toHaveLength(3)
expect(group.remove(itemButton)).toEqual([itemButton])
expect(group.remove(itemButton)).toBeUndefined()
expect(group.element.children).toHaveLength(2)
})
it('resolves async dialogs from buttons, close events, and prompt actions', async () => {
const dialog = new ComfyAsyncDialog<number>([
{ text: 'Seven', value: 7 },
'Fallback'
])
const promise = dialog.show('Pick one')
dialog.element.querySelector<HTMLButtonElement>('button')?.click()
await expect(promise).resolves.toBe(7)
const closePromise = dialog.show(document.createElement('em'))
dialog.element.dispatchEvent(new Event('close'))
await expect(closePromise).resolves.toBeNull()
const promptPromise = ComfyAsyncDialog.prompt({
title: 'Confirm',
message: 'Continue?',
actions: [{ text: 'Yes', value: 'yes' }]
})
const prompt = Array.from(document.querySelectorAll('dialog')).at(-1)
expect(prompt).toHaveTextContent('Confirm')
prompt?.querySelector<HTMLButtonElement>('button')?.click()
await expect(promptPromise).resolves.toBe('yes')
expect(document.body).not.toContainElement(prompt ?? null)
})
})

View File

@@ -0,0 +1,69 @@
import { afterEach, describe, expect, it, vi } from 'vitest'
import { ComfyDialog } from './dialog'
const mocks = vi.hoisted(() => ({
$el: (
selector: string,
propsOrChildren?: Record<string, unknown> | Node[],
maybeChildren?: Node[]
) => {
const [tag, ...classes] = selector.split('.')
const element = document.createElement(tag || 'div')
element.classList.add(...classes.filter(Boolean))
const children = Array.isArray(propsOrChildren)
? propsOrChildren
: maybeChildren
if (propsOrChildren && !Array.isArray(propsOrChildren)) {
for (const [key, value] of Object.entries(propsOrChildren)) {
if (key === 'parent' && value instanceof Node) {
value.appendChild(element)
} else if (key === '$' && typeof value === 'function') {
value(element)
} else {
Reflect.set(element, key, value)
}
}
}
element.append(...(children ?? []))
return element
}
}))
vi.mock('../ui', () => ({
$el: mocks.$el
}))
describe('ComfyDialog', () => {
afterEach(() => {
document.body.replaceChildren()
})
it('shows string and element content and closes through the default button', () => {
const dialog = new ComfyDialog()
dialog.show('<strong>Hello</strong>')
expect(dialog.element.style.display).toBe('flex')
expect(dialog.textElement.innerHTML).toBe('<strong>Hello</strong>')
dialog.element.querySelector('button')?.click()
expect(dialog.element.style.display).toBe('none')
const first = document.createElement('span')
const second = document.createElement('em')
dialog.show([first, second])
expect([...dialog.textElement.children]).toEqual([first, second])
})
it('uses supplied custom buttons', () => {
const button = document.createElement('button')
const dialog = new ComfyDialog('section', [button])
expect(dialog.element.tagName).toBe('SECTION')
expect(dialog.element.querySelector('button')).toBe(button)
})
})

View File

@@ -0,0 +1,216 @@
import { afterEach, describe, expect, it, vi } from 'vitest'
import { DraggableList } from './draggableList'
function createList(itemCount: number) {
const container = document.createElement('div')
const items = Array.from({ length: itemCount }, (_, index) => {
const item = document.createElement('div')
item.className = 'item'
item.dataset.index = String(index)
const handle = document.createElement('button')
handle.className = 'drag-handle'
item.append(handle)
container.append(item)
return item
})
document.body.append(container)
return { container, items }
}
function setRect(element: Element, top: number, height = 20) {
return vi
.spyOn(element, 'getBoundingClientRect')
.mockReturnValue(new DOMRect(0, top, 100, height))
}
function defineScrollMetrics(
container: HTMLElement,
scrollHeight: number,
clientHeight: number
) {
Object.defineProperty(container, 'scrollHeight', {
configurable: true,
value: scrollHeight
})
Object.defineProperty(container, 'clientHeight', {
configurable: true,
value: clientHeight
})
}
function mouseDragEvent(
target: Element,
overrides: Partial<MouseEvent> = {}
): MouseEvent {
return {
button: 0,
clientX: 0,
clientY: 0,
preventDefault: vi.fn(),
target,
...overrides
} satisfies Partial<MouseEvent> as MouseEvent
}
describe('DraggableList', () => {
afterEach(() => {
document.body.replaceChildren()
vi.restoreAllMocks()
vi.unstubAllGlobals()
})
it('ignores missing containers and non-primary drag starts', () => {
const listWithoutContainer = new DraggableList(null, '.item')
const { container, items } = createList(1)
const list = new DraggableList(container, '.item')
list.dragStart(
mouseDragEvent(items[0].querySelector('.drag-handle')!, { button: 1 })
)
list.dragEnd()
expect(listWithoutContainer.listContainer).toBeNull()
expect(list.draggableItem).toBeUndefined()
})
it('starts from a handle, scrolls downward, and reorders upward', () => {
vi.stubGlobal('requestAnimationFrame', (callback: FrameRequestCallback) => {
callback(0)
return 1
})
const { container, items } = createList(3)
const scrollBy = vi.fn((_left: number, top: number) => {
container.scrollTop += top
})
container.scrollBy = scrollBy as unknown as typeof container.scrollBy
defineScrollMetrics(container, 120, 80)
vi.spyOn(container, 'getBoundingClientRect').mockReturnValue(
new DOMRect(0, 0, 100, 80)
)
setRect(items[0], 0)
setRect(items[1], 30)
setRect(items[2], 60)
const list = new DraggableList(container, '.item')
const dragStart = vi.fn()
const dragEnd = vi.fn()
list.addEventListener('dragstart', dragStart)
list.addEventListener('dragend', dragEnd)
list.dragStart(
mouseDragEvent(items[2].querySelector('.drag-handle')!, {
clientX: 10,
clientY: 70
})
)
list.drag(
mouseDragEvent(items[2].querySelector('.drag-handle')!, {
clientX: 20,
clientY: 100
})
)
items[1].dataset.isToggled = ''
list.dragEnd()
expect(dragStart).toHaveBeenCalledOnce()
expect(dragEnd).toHaveBeenCalledOnce()
expect(scrollBy).toHaveBeenCalledWith(0, 10)
expect([...container.children]).toEqual([items[0], items[2], items[1]])
expect(items[0].classList.contains('is-idle')).toBe(true)
expect(items[1].classList.contains('is-idle')).toBe(true)
})
it('supports touch coordinates, upward scrolling, and downward reorder', () => {
vi.stubGlobal('requestAnimationFrame', (callback: FrameRequestCallback) => {
callback(0)
return 1
})
const { container, items } = createList(3)
const scrollBy = vi.fn((_left: number, top: number) => {
container.scrollTop += top
})
container.scrollTop = 10
container.scrollBy = scrollBy as unknown as typeof container.scrollBy
defineScrollMetrics(container, 120, 80)
vi.spyOn(container, 'getBoundingClientRect').mockReturnValue(
new DOMRect(0, 20, 100, 80)
)
setRect(items[0], 0)
setRect(items[1], 30)
setRect(items[2], 60)
const list = new DraggableList(container, '.item')
const touchStart = {
button: 0,
clientX: 0,
clientY: 0,
preventDefault: vi.fn(),
target: items[0].querySelector('.drag-handle')!,
touches: [{ clientX: 5, clientY: 30 }]
} as unknown as TouchEvent
const touchMove = {
clientX: 0,
clientY: 0,
preventDefault: vi.fn(),
target: items[0].querySelector('.drag-handle')!,
touches: [{ clientX: 8, clientY: 10 }]
} as unknown as TouchEvent
list.dragStart(touchStart)
list.drag(touchMove)
items[1].dataset.isToggled = ''
list.dragEnd()
expect(scrollBy).toHaveBeenCalledWith(0, -10)
expect([...container.children]).toEqual([items[1], items[0], items[2]])
})
it('updates idle item state around the dragged item midpoint', () => {
const { container, items } = createList(3)
const list = new DraggableList(container, '.item')
const state = list as unknown as {
items: HTMLElement[]
draggableItem: HTMLElement
}
state.items = items
state.draggableItem = items[1]
list.itemsGap = 5
items[0].classList.add('is-idle')
items[1].classList.add('is-idle')
items[2].classList.add('is-idle')
items[0].dataset.isAbove = ''
const draggedRect = setRect(items[1], -10)
setRect(items[0], 0)
setRect(items[2], 60)
list.updateIdleItemsStateAndPosition()
expect(items[0].dataset.isToggled).toBe('')
expect(items[0].style.transform).toBe('translateY(25px)')
expect(items[2].style.transform).toBe('')
draggedRect.mockReturnValue(new DOMRect(0, 100, 100, 20))
list.updateIdleItemsStateAndPosition()
expect(items[0].dataset.isToggled).toBeUndefined()
expect(items[2].dataset.isToggled).toBe('')
expect(items[2].style.transform).toBe('translateY(-25px)')
})
it('uses zero gap for short lists and disposes listeners', () => {
const { container } = createList(1)
const list = new DraggableList(container, '.item')
const off = vi.fn()
const disposableList = list as unknown as { off: Array<() => void> }
disposableList.off = [off]
list.setItemsGap()
list.dispose()
expect(list.itemsGap).toBe(0)
expect(off).toHaveBeenCalledOnce()
})
})

View File

@@ -0,0 +1,33 @@
import { describe, expect, it } from 'vitest'
import { calculateImageGrid } from './imagePreview'
function createImage(width: number, height: number) {
const img = document.createElement('img')
Object.defineProperty(img, 'naturalWidth', {
configurable: true,
value: width
})
Object.defineProperty(img, 'naturalHeight', {
configurable: true,
value: height
})
return img
}
describe('imagePreview', () => {
it('calculates the highest-area grid', () => {
const images = [
createImage(100, 100),
createImage(100, 100),
createImage(100, 100)
]
expect(calculateImageGrid(images, 300, 120)).toMatchObject({
cellWidth: 100,
cellHeight: 100,
cols: 3,
rows: 1
})
})
})

View File

@@ -0,0 +1,86 @@
import { afterEach, describe, expect, it, vi } from 'vitest'
import { toggleSwitch } from './toggleSwitch'
const mocks = vi.hoisted(() => ({
$el: (
selector: string,
propsOrChildren?: Record<string, unknown> | Node[] | Node,
maybeChildren?: Node[] | Node
) => {
const [tag, ...classes] = selector.split('.')
const element = document.createElement(tag || 'div')
element.classList.add(...classes.filter(Boolean))
const children = Array.isArray(propsOrChildren)
? propsOrChildren
: propsOrChildren instanceof Node
? [propsOrChildren]
: Array.isArray(maybeChildren)
? maybeChildren
: maybeChildren instanceof Node
? [maybeChildren]
: []
if (
propsOrChildren &&
!(propsOrChildren instanceof Node) &&
!Array.isArray(propsOrChildren)
) {
for (const [key, value] of Object.entries(propsOrChildren)) {
Reflect.set(element, key, value)
}
}
element.append(...children)
return element
}
}))
vi.mock('../ui', () => ({
$el: mocks.$el
}))
describe('toggleSwitch', () => {
afterEach(() => {
document.body.replaceChildren()
})
it('selects the first item when none is preselected', () => {
const onChange = vi.fn()
const container = toggleSwitch('mode', ['first', 'second'], { onChange })
const labels = [...container.querySelectorAll('label')]
const inputs = [...container.querySelectorAll('input')]
expect(labels[0].classList.contains('comfy-toggle-selected')).toBe(true)
expect((inputs[0] as HTMLInputElement).checked).toBe(true)
expect(onChange).toHaveBeenCalledWith({
item: 'first',
prev: undefined
})
})
it('moves selection and reports the previous item', () => {
const onChange = vi.fn()
const container = toggleSwitch(
'mode',
[
{ text: 'first', tooltip: 'First option' },
{ text: 'second', value: '2' }
],
{ onChange }
)
const labels = [...container.querySelectorAll('label')]
const secondInput = labels[1].querySelector('input') as HTMLInputElement
secondInput.onchange?.(new Event('change'))
expect(labels[0].classList.contains('comfy-toggle-selected')).toBe(false)
expect(labels[1].classList.contains('comfy-toggle-selected')).toBe(true)
expect(labels[0].title).toBe('First option')
expect(secondInput.value).toBe('2')
expect(onChange).toHaveBeenLastCalledWith({
item: { text: 'second', value: '2' },
prev: { text: 'first', tooltip: 'First option', value: 'first' }
})
})
})

View File

@@ -0,0 +1,45 @@
import { describe, expect, it, vi } from 'vitest'
import { applyClasses, toggleElement } from './utils'
describe('ui utils', () => {
it('applies string, array, object, and required classes', () => {
const element = document.createElement('div')
applyClasses(element, 'one two', 'required')
expect([...element.classList]).toEqual(['one', 'two', 'required'])
applyClasses(element, ['three', 'four'])
expect([...element.classList]).toEqual(['three', 'four'])
applyClasses(element, { five: true, six: false, seven: true })
expect([...element.classList]).toEqual(['five', 'seven'])
applyClasses(element, null as unknown as string)
expect(element.className).toBe('')
})
it('toggles an element through a placeholder', () => {
const parent = document.createElement('div')
const element = document.createElement('span')
const onHide = vi.fn()
const onShow = vi.fn()
parent.append(element)
const toggle = toggleElement(element, { onHide, onShow })
toggle(false)
expect(parent.firstChild).toBeInstanceOf(Comment)
expect(onHide).toHaveBeenCalledWith(element)
toggle(true)
expect(parent.firstChild).toBe(element)
expect(onShow).toHaveBeenCalledWith(element, true)
toggle('visible')
expect(onShow).toHaveBeenCalledWith(element, 'visible')
toggle(false)
expect(parent.firstChild).toBeInstanceOf(Comment)
expect(onHide).toHaveBeenCalledTimes(2)
})
})

127
src/scripts/utils.test.ts Normal file
View File

@@ -0,0 +1,127 @@
import { afterEach, describe, expect, it, vi } from 'vitest'
import {
addStylesheet,
clone,
getStorageValue,
prop,
setStorageValue
} from './utils'
interface LinkAttrs {
href: string
onerror: (error: Event) => void
onload: () => void
parent: HTMLElement
rel: string
type: string
}
const mocks = vi.hoisted(() => ({
api: {
clientId: null as string | null,
initialClientId: null as string | null
},
$el: vi.fn()
}))
vi.mock('./api', () => ({
api: mocks.api
}))
vi.mock('./ui', () => ({
$el: mocks.$el
}))
function lastLinkAttrs() {
return mocks.$el.mock.calls.at(-1)?.[1] as LinkAttrs
}
describe('scripts utils', () => {
afterEach(() => {
localStorage.clear()
sessionStorage.clear()
mocks.api.clientId = null
mocks.api.initialClientId = null
mocks.$el.mockReset()
vi.unstubAllGlobals()
})
it('clones with structuredClone and falls back to JSON cloning', () => {
const source = { nested: { value: 1 } }
expect(clone(source)).toEqual(source)
vi.stubGlobal(
'structuredClone',
vi.fn(() => {
throw new Error('unsupported')
})
)
const cloned = clone(source)
cloned.nested.value = 2
expect(cloned).toEqual({ nested: { value: 2 } })
expect(source).toEqual({ nested: { value: 1 } })
})
it('adds stylesheets from script and relative URLs', async () => {
const scriptPromise = addStylesheet('/extensions/example.js')
lastLinkAttrs().onload()
await expect(scriptPromise).resolves.toBeUndefined()
expect(lastLinkAttrs()).toMatchObject({
href: '/extensions/example.css',
parent: document.head,
rel: 'stylesheet',
type: 'text/css'
})
const cssPromise = addStylesheet('theme.css', 'https://example.com/base/')
lastLinkAttrs().onload()
await expect(cssPromise).resolves.toBeUndefined()
expect(lastLinkAttrs().href).toBe('https://example.com/base/theme.css')
})
it('rejects when stylesheet loading fails', async () => {
const promise = addStylesheet('missing.css', 'https://example.com/')
const error = new Event('error')
lastLinkAttrs().onerror(error)
await expect(promise).rejects.toBe(error)
})
it('defines an observable property with the supplied default', () => {
const target = {}
const onChanged = vi.fn()
expect(prop(target, 'mode', 'initial', onChanged)).toBe('initial')
Object.assign(target, { mode: 'next' })
expect((target as { mode: string }).mode).toBe('next')
expect(onChanged).toHaveBeenCalledWith('next', undefined, target, 'mode')
})
it('uses client-scoped storage before local fallback', () => {
mocks.api.clientId = 'client-1'
setStorageValue('setting', 'client-value')
sessionStorage.removeItem('setting:client-1')
expect(getStorageValue('setting')).toBe('client-value')
expect(localStorage.getItem('setting')).toBe('client-value')
sessionStorage.setItem('setting:client-1', 'session-value')
expect(getStorageValue('setting')).toBe('session-value')
})
it('uses initial client id when the current client id is unavailable', () => {
mocks.api.initialClientId = 'initial-1'
setStorageValue('setting', 'initial-value')
expect(sessionStorage.getItem('setting:initial-1')).toBe('initial-value')
expect(getStorageValue('setting')).toBe('initial-value')
})
})

356
src/scripts/widgets.test.ts Normal file
View File

@@ -0,0 +1,356 @@
import { fromAny } from '@total-typescript/shoehorn'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import type { LGraphNode } from '@/lib/litegraph/src/litegraph'
import type {
IBaseWidget,
IComboWidget
} from '@/lib/litegraph/src/types/widgets'
import type { InputSpec } from '@/schemas/nodeDefSchema'
const mockSettingGet = vi.hoisted(() => vi.fn())
const mockNextValueForLinkedTarget = vi.hoisted(() => vi.fn())
const mockIsComboWidget = vi.hoisted(() => vi.fn())
const mockTransformInputSpecV1ToV2 = vi.hoisted(() => vi.fn())
function v2WidgetConstructor(kind: string) {
return () => (_node: LGraphNode, inputSpec: { name: string }) => ({
name: `${kind}:${inputSpec.name}`,
options: { minNodeSize: [20, 30] }
})
}
vi.mock('@/i18n', () => ({
t: (key: string) => `translated:${key}`
}))
vi.mock('@/platform/settings/settingStore', () => ({
useSettingStore: () => ({
get: mockSettingGet
})
}))
vi.mock('@/lib/litegraph/src/litegraph', () => ({
isComboWidget: mockIsComboWidget
}))
vi.mock('./valueControl', () => ({
nextValueForLinkedTarget: mockNextValueForLinkedTarget
}))
vi.mock('@/schemas/nodeDef/migration', () => ({
transformInputSpecV1ToV2: mockTransformInputSpecV1ToV2
}))
vi.mock('@/core/graph/widgets/dynamicWidgets', () => ({
dynamicWidgets: {
DYNAMIC: () => ({
widget: { name: 'dynamic', options: {} },
minWidth: 1,
minHeight: 1
})
}
}))
vi.mock(
'@/renderer/extensions/vueNodes/widgets/composables/useBooleanWidget',
() => ({ useBooleanWidget: v2WidgetConstructor('BOOLEAN') })
)
vi.mock(
'@/renderer/extensions/vueNodes/widgets/composables/useBoundingBoxWidget',
() => ({ useBoundingBoxWidget: v2WidgetConstructor('BOUNDING_BOX') })
)
vi.mock(
'@/renderer/extensions/vueNodes/widgets/composables/useCurveWidget',
() => ({ useCurveWidget: v2WidgetConstructor('CURVE') })
)
vi.mock(
'@/renderer/extensions/vueNodes/widgets/composables/useChartWidget',
() => ({ useChartWidget: v2WidgetConstructor('CHART') })
)
vi.mock(
'@/renderer/extensions/vueNodes/widgets/composables/useColorWidget',
() => ({ useColorWidget: v2WidgetConstructor('COLOR') })
)
vi.mock(
'@/renderer/extensions/vueNodes/widgets/composables/useComboWidget',
() => ({ useComboWidget: v2WidgetConstructor('COMBO') })
)
vi.mock(
'@/renderer/extensions/vueNodes/widgets/composables/useFloatWidget',
() => ({ useFloatWidget: v2WidgetConstructor('FLOAT') })
)
vi.mock(
'@/renderer/extensions/vueNodes/widgets/composables/useGalleriaWidget',
() => ({ useGalleriaWidget: v2WidgetConstructor('GALLERIA') })
)
vi.mock(
'@/renderer/extensions/vueNodes/widgets/composables/useBoundingBoxesWidget',
() => ({ useBoundingBoxesWidget: v2WidgetConstructor('BOUNDING_BOXES') })
)
vi.mock(
'@/renderer/extensions/vueNodes/widgets/composables/useColorsWidget',
() => ({ useColorsWidget: v2WidgetConstructor('COLORS') })
)
vi.mock(
'@/renderer/extensions/vueNodes/widgets/composables/useImageCompareWidget',
() => ({ useImageCompareWidget: v2WidgetConstructor('IMAGECOMPARE') })
)
vi.mock(
'@/renderer/extensions/vueNodes/widgets/composables/useImageUploadWidget',
() => ({
useImageUploadWidget: () => (_node: LGraphNode, inputName: string) => ({
widget: { name: `IMAGEUPLOAD:${inputName}`, options: {} },
minWidth: 5,
minHeight: 6
})
})
)
vi.mock(
'@/renderer/extensions/vueNodes/widgets/composables/useIntWidget',
() => ({ useIntWidget: v2WidgetConstructor('INT') })
)
vi.mock(
'@/renderer/extensions/vueNodes/widgets/composables/useMarkdownWidget',
() => ({ useMarkdownWidget: v2WidgetConstructor('MARKDOWN') })
)
vi.mock(
'@/renderer/extensions/vueNodes/widgets/composables/usePainterWidget',
() => ({ usePainterWidget: v2WidgetConstructor('PAINTER') })
)
vi.mock(
'@/renderer/extensions/vueNodes/widgets/composables/useRangeWidget',
() => ({ useRangeWidget: v2WidgetConstructor('RANGE') })
)
vi.mock(
'@/renderer/extensions/vueNodes/widgets/composables/useStringWidget',
() => ({ useStringWidget: v2WidgetConstructor('STRING') })
)
vi.mock(
'@/renderer/extensions/vueNodes/widgets/composables/useTextareaWidget',
() => ({ useTextareaWidget: v2WidgetConstructor('TEXTAREA') })
)
vi.mock('./domWidget', () => ({}))
vi.mock('./errorNodeWidgets', () => ({}))
import {
ComfyWidgets,
IS_CONTROL_WIDGET,
addValueControlWidget,
addValueControlWidgets,
isValidWidgetType,
updateControlWidgetLabel
} from './widgets'
// `linkedWidgets`, `beforeQueued`, and `afterQueued` already exist on
// IBaseWidget (via the litegraph augmentation), so no extra members needed.
type MockWidget = IBaseWidget
function makeTargetWidget(overrides: Partial<MockWidget> = {}): MockWidget {
return {
name: 'seed',
value: 1,
callback: vi.fn(),
options: {},
linkedWidgets: [],
computedDisabled: false,
...overrides
} as MockWidget
}
function makeNode(inputs: LGraphNode['inputs'] = []) {
const widgets: MockWidget[] = []
const node = {
id: 42,
inputs,
addWidget: vi.fn(
(
type: string,
name: string,
value: string,
callback: () => void,
options: Record<string, unknown>
) => {
const widget: MockWidget = fromAny({
type,
name,
value,
callback,
options,
linkedWidgets: [],
computedDisabled: false
})
widgets.push(widget)
return widget
}
)
}
return { node: node as unknown as LGraphNode, widgets }
}
describe('widgets', () => {
beforeEach(() => {
vi.clearAllMocks()
mockSettingGet.mockReturnValue('after')
mockNextValueForLinkedTarget.mockReturnValue('next')
mockIsComboWidget.mockImplementation(
(widget: MockWidget) => widget.type === 'combo'
)
mockTransformInputSpecV1ToV2.mockImplementation(
(_inputData: InputSpec, options: { name: string }) => ({
name: options.name
})
)
})
it('updates the control widget label from the configured run mode', () => {
const widget = makeTargetWidget()
mockSettingGet.mockReturnValue('before')
updateControlWidgetLabel(widget)
expect(widget.label).toBe('translated:g.control_before_generate')
mockSettingGet.mockReturnValue('after')
updateControlWidgetLabel(widget)
expect(widget.label).toBe('translated:g.control_after_generate')
})
it('adds control and filter widgets for combo targets', () => {
const { node, widgets } = makeNode()
const target = makeTargetWidget({ type: 'combo', computedDisabled: true })
const result = addValueControlWidgets(node, target, '', undefined, [
'COMBO',
{
control_prefix: 'custom'
}
] as unknown as InputSpec)
expect(result).toHaveLength(2)
expect(widgets[0].name).toBe('custom control_after_generate')
expect(widgets[0].value).toBe('randomize')
expect((widgets[0] as IComboWidget).options.values).toContain(
'increment-wrap'
)
expect(widgets[0][IS_CONTROL_WIDGET]).toBe(true)
expect(widgets[0].disabled).toBe(true)
expect(widgets[1].name).toBe('custom control_filter_list')
expect(widgets[1].disabled).toBe(true)
})
it('uses explicit option names and can skip the combo filter widget', () => {
const { node, widgets } = makeNode()
const target = makeTargetWidget({ type: 'combo' })
addValueControlWidgets(
node,
target,
'fixed',
{
addFilterList: false,
controlAfterGenerateName: 'mode'
},
['COMBO', {}] as unknown as InputSpec
)
expect(widgets).toHaveLength(1)
expect(widgets[0].name).toBe('mode')
})
it('applies linked target values after queueing in after mode', () => {
const { node, widgets } = makeNode()
const target = makeTargetWidget()
addValueControlWidgets(node, target)
widgets[0].afterQueued?.({ isPartialExecution: true })
expect(mockNextValueForLinkedTarget).toHaveBeenCalledWith({
target,
linkedWidgets: target.linkedWidgets,
nodeId: 42,
isPartialExecution: true
})
expect(target.value).toBe('next')
expect(target.callback).toHaveBeenCalledWith('next')
})
it('waits until the second beforeQueued call in before mode', () => {
mockSettingGet.mockReturnValue('before')
const { node, widgets } = makeNode()
const target = makeTargetWidget()
addValueControlWidgets(node, target)
widgets[0].beforeQueued?.()
expect(mockNextValueForLinkedTarget).not.toHaveBeenCalled()
widgets[0].beforeQueued?.({ isPartialExecution: false })
expect(mockNextValueForLinkedTarget).toHaveBeenCalledWith(
expect.objectContaining({ isPartialExecution: false })
)
})
it('does not change the target when the target has a linked input or no next value', () => {
const { node, widgets } = makeNode([
{ widget: { name: 'seed' }, link: 1 }
] as LGraphNode['inputs'])
const target = makeTargetWidget()
addValueControlWidgets(node, target)
widgets[0].afterQueued?.()
expect(mockNextValueForLinkedTarget).not.toHaveBeenCalled()
const unlinked = makeNode()
mockNextValueForLinkedTarget.mockReturnValue(undefined)
addValueControlWidgets(unlinked.node, target)
unlinked.widgets[0].afterQueued?.()
expect(target.callback).not.toHaveBeenCalled()
})
it('uses the legacy single control widget name from input data before widgetName', () => {
const { node, widgets } = makeNode()
const target = makeTargetWidget()
const result = addValueControlWidget(
node,
target,
'fixed',
undefined,
'fallback',
[
'INT',
{
control_after_generate: 'from_input_data'
}
] as unknown as InputSpec
)
expect(result).toBe(widgets[0])
expect(widgets[0].name).toBe('from_input_data')
})
it('exposes transformed widget constructors and type validation', () => {
const { node } = makeNode()
const intWidget = ComfyWidgets.INT(
node,
'value',
['INT', {}] as unknown as InputSpec,
{} as never
)
expect(intWidget.widget.name).toBe('INT:value')
expect(intWidget.minWidth).toBe(20)
expect(intWidget.minHeight).toBe(30)
expect(
ComfyWidgets.IMAGEUPLOAD(node, 'image', ['IMAGE', {}], {} as never)
).toMatchObject({
widget: { name: 'IMAGEUPLOAD:image' },
minWidth: 5,
minHeight: 6
})
expect(isValidWidgetType('INT')).toBe(true)
expect(isValidWidgetType('DYNAMIC')).toBe(true)
expect(isValidWidgetType('missing')).toBe(false)
})
})

View File

@@ -38,6 +38,12 @@ describe('useAudioService', () => {
name: 'test-audio-123.wav'
}
async function freshService() {
vi.resetModules()
const audioServiceModule = await import('@/services/audioService')
return audioServiceModule.useAudioService()
}
beforeEach(() => {
vi.clearAllMocks()
@@ -90,12 +96,41 @@ describe('useAudioService', () => {
)
mockRegister.mockRejectedValueOnce(error)
await service.registerWavEncoder()
const isolatedService = await freshService()
await isolatedService.registerWavEncoder()
await isolatedService.registerWavEncoder()
expect(mockConnect).toHaveBeenCalledTimes(0)
expect(mockRegister).toHaveBeenCalledTimes(0)
expect(mockConnect).toHaveBeenCalledTimes(1)
expect(mockRegister).toHaveBeenCalledTimes(1)
expect(console.error).not.toHaveBeenCalled()
})
it('should log encoder registration errors', async () => {
const error = new Error('Encoder failed')
mockRegister.mockRejectedValueOnce(error)
const isolatedService = await freshService()
await isolatedService.registerWavEncoder()
expect(console.error).toHaveBeenCalledWith(
'Audio Service Error (encoder):',
'Failed to register WAV encoder',
error
)
})
it('should log non-Error encoder registration failures', async () => {
mockRegister.mockRejectedValueOnce('Encoder failed')
const isolatedService = await freshService()
await isolatedService.registerWavEncoder()
expect(console.error).toHaveBeenCalledWith(
'Audio Service Error (encoder):',
'Failed to register WAV encoder',
'Encoder failed'
)
})
})
describe('stopAllTracks', () => {

View File

@@ -0,0 +1,118 @@
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { setupAutoQueueHandler } from '@/services/autoQueueService'
type ApiEvent = 'graphChanged'
type ApiListener = () => void
type Subscription = () => Promise<void> | void
const {
listeners,
queueCountStore,
queueSettingsStore,
appState,
addEventListener,
isInstantRunningMode
} = vi.hoisted(() => ({
listeners: new Map<ApiEvent, ApiListener>(),
queueCountStore: {
count: 0,
subscription: undefined as Subscription | undefined,
$subscribe: vi.fn((_callback: Subscription) => {
queueCountStore.subscription = _callback
})
},
queueSettingsStore: {
mode: 'manual',
batchCount: 1
},
appState: {
lastExecutionError: null as unknown,
queuePrompt: vi.fn()
},
addEventListener: vi.fn((event: ApiEvent, listener: ApiListener) => {
listeners.set(event, listener)
}),
isInstantRunningMode: vi.fn((mode: string) => mode === 'instant')
}))
vi.mock('@/scripts/api', () => ({
api: { addEventListener }
}))
vi.mock('@/scripts/app', () => ({
app: appState
}))
vi.mock('@/stores/queueStore', () => ({
isInstantRunningMode,
useQueuePendingTaskCountStore: () => queueCountStore,
useQueueSettingsStore: () => queueSettingsStore
}))
beforeEach(() => {
listeners.clear()
queueCountStore.count = 0
queueCountStore.subscription = undefined
queueCountStore.$subscribe.mockClear()
queueSettingsStore.mode = 'manual'
queueSettingsStore.batchCount = 1
appState.lastExecutionError = null
appState.queuePrompt.mockReset().mockResolvedValue(undefined)
addEventListener.mockClear()
isInstantRunningMode
.mockClear()
.mockImplementation((mode) => mode === 'instant')
})
describe('setupAutoQueueHandler', () => {
it('queues immediately on graph changes when change mode is idle', () => {
queueSettingsStore.mode = 'change'
queueSettingsStore.batchCount = 3
setupAutoQueueHandler()
listeners.get('graphChanged')?.()
expect(appState.queuePrompt).toHaveBeenCalledWith(0, 3)
})
it('queues after pending work drains in instant mode', async () => {
queueSettingsStore.mode = 'instant'
queueSettingsStore.batchCount = 2
queueCountStore.count = 0
setupAutoQueueHandler()
await queueCountStore.subscription?.()
expect(appState.queuePrompt).toHaveBeenCalledWith(0, 2)
})
it('queues after a changed graph drains from an active queue', async () => {
queueSettingsStore.mode = 'change'
queueCountStore.count = 1
setupAutoQueueHandler()
await queueCountStore.subscription?.()
listeners.get('graphChanged')?.()
expect(appState.queuePrompt).not.toHaveBeenCalled()
queueCountStore.count = 0
await queueCountStore.subscription?.()
expect(appState.queuePrompt).toHaveBeenCalledTimes(1)
})
it('does not requeue while work remains or the last run failed', async () => {
queueSettingsStore.mode = 'instant'
queueCountStore.count = 1
setupAutoQueueHandler()
await queueCountStore.subscription?.()
appState.lastExecutionError = { message: 'failed' }
queueCountStore.count = 0
await queueCountStore.subscription?.()
expect(appState.queuePrompt).not.toHaveBeenCalled()
})
})

View File

@@ -0,0 +1,363 @@
import { fromAny } from '@total-typescript/shoehorn'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { DEFAULT_DARK_COLOR_PALETTE } from '@/constants/coreColorPalettes'
import {
LGraphCanvas,
LiteGraph,
RenderShape
} from '@/lib/litegraph/src/litegraph'
import type { CompletedPalette, Palette } from '@/schemas/colorPaletteSchema'
const mockCanvas = vi.hoisted(() => ({
default_connection_color_byType: {} as Record<string, string>,
node_title_color: '',
default_link_color: '',
background_image: '',
clear_background_color: '',
_pattern: 'pattern' as string | undefined,
setDirty: vi.fn()
}))
const mockColorPaletteStore = vi.hoisted(() => ({
customPalettes: {} as Record<string, unknown>,
palettesLookup: {} as Record<string, unknown>,
completedActivePalette: undefined as unknown,
activePaletteId: 'dark',
addCustomPalette: vi.fn(),
deleteCustomPalette: vi.fn(),
completePalette: vi.fn()
}))
const mockSettingStore = vi.hoisted(() => ({
get: vi.fn(),
set: vi.fn()
}))
const mockNodeDefStore = vi.hoisted(() => ({
nodeDataTypes: new Set(['IMAGE', 'MISSING'])
}))
const mockDownloadBlob = vi.hoisted(() => vi.fn())
const mockUploadFile = vi.hoisted(() => vi.fn())
vi.mock('@/scripts/app', () => ({
app: { canvas: mockCanvas }
}))
vi.mock('@/stores/workspace/colorPaletteStore', () => ({
useColorPaletteStore: () => mockColorPaletteStore
}))
vi.mock('@/platform/settings/settingStore', () => ({
useSettingStore: () => mockSettingStore
}))
vi.mock('@/stores/nodeDefStore', () => ({
useNodeDefStore: () => mockNodeDefStore
}))
vi.mock('@/base/common/downloadUtil', () => ({
downloadBlob: mockDownloadBlob
}))
vi.mock('@/scripts/utils', () => ({
uploadFile: mockUploadFile
}))
vi.mock('@/composables/useErrorHandling', () => ({
useErrorHandling: () => ({
wrapWithErrorHandling: <T>(action: T) => action,
wrapWithErrorHandlingAsync: <T>(action: T) => action
})
}))
import { useColorPaletteService } from './colorPaletteService'
const validCustomPalette = {
id: 'custom',
name: 'Custom',
colors: {
node_slot: {},
litegraph_base: {},
comfy_base: {}
}
} satisfies Palette
function makeCompletedPalette(id = 'custom'): CompletedPalette {
const palette = structuredClone(
DEFAULT_DARK_COLOR_PALETTE
) as CompletedPalette
palette.id = id
palette.name = 'Custom'
palette.colors.node_slot.IMAGE = '#123456'
palette.colors.litegraph_base.NODE_TITLE_COLOR = '#abcdef'
palette.colors.litegraph_base.LINK_COLOR = '#fedcba'
palette.colors.litegraph_base.BACKGROUND_IMAGE = 'grid.png'
palette.colors.litegraph_base.CLEAR_BACKGROUND_COLOR = '#010203'
palette.colors.litegraph_base.NODE_DEFAULT_SHAPE = 'legacy'
palette.colors.comfy_base['fg-color'] = '#111111'
palette.colors.comfy_base['bg-color'] = '#222222'
delete palette.colors.comfy_base['contrast-mix-color']
return palette
}
describe('useColorPaletteService', () => {
beforeEach(() => {
vi.clearAllMocks()
mockCanvas.default_connection_color_byType = {}
mockCanvas.node_title_color = ''
mockCanvas.default_link_color = ''
mockCanvas.background_image = ''
mockCanvas.clear_background_color = ''
mockCanvas._pattern = 'pattern'
LGraphCanvas.link_type_colors = {}
mockSettingStore.get.mockReturnValue('')
mockSettingStore.set.mockResolvedValue(undefined)
mockColorPaletteStore.customPalettes = { custom: validCustomPalette }
mockColorPaletteStore.palettesLookup = { custom: validCustomPalette }
mockColorPaletteStore.completedActivePalette = makeCompletedPalette()
mockColorPaletteStore.activePaletteId = 'dark'
mockColorPaletteStore.completePalette.mockReturnValue(
makeCompletedPalette()
)
document.documentElement.style.cssText = ''
document.documentElement.style.setProperty(
'--color-datatype-MISSING',
'#ffffff'
)
})
it('adds valid custom palettes and persists the custom palette map', async () => {
const service = useColorPaletteService()
await service.addCustomColorPalette(validCustomPalette)
expect(mockColorPaletteStore.addCustomPalette).toHaveBeenCalledWith(
validCustomPalette
)
expect(mockSettingStore.set).toHaveBeenCalledWith(
'Comfy.CustomColorPalettes',
mockColorPaletteStore.customPalettes
)
})
it('rejects invalid custom palettes before mutating the store', async () => {
const service = useColorPaletteService()
await expect(service.addCustomColorPalette({} as Palette)).rejects.toThrow(
'Invalid color palette against zod schema'
)
expect(mockColorPaletteStore.addCustomPalette).not.toHaveBeenCalled()
})
it('deletes custom palettes and persists the custom palette map', async () => {
const service = useColorPaletteService()
await service.deleteCustomColorPalette('custom')
expect(mockColorPaletteStore.deleteCustomPalette).toHaveBeenCalledWith(
'custom'
)
expect(mockSettingStore.set).toHaveBeenCalledWith(
'Comfy.CustomColorPalettes',
mockColorPaletteStore.customPalettes
)
})
it('loads palette colors into litegraph, Vue CSS variables, and canvas state', async () => {
const warn = vi.spyOn(console, 'warn').mockImplementation(() => {})
const service = useColorPaletteService()
await service.loadColorPalette('custom')
expect(mockCanvas.default_connection_color_byType.IMAGE).toBe('#123456')
expect(LGraphCanvas.link_type_colors.IMAGE).toBe('#123456')
expect(
document.documentElement.style.getPropertyValue('--color-datatype-IMAGE')
).toBe('#123456')
expect(
document.documentElement.style.getPropertyValue(
'--color-datatype-MISSING'
)
).toBe('')
expect(mockCanvas.node_title_color).toBe('#abcdef')
expect(mockCanvas.default_link_color).toBe('#fedcba')
expect(mockCanvas.background_image).toBe('grid.png')
expect(mockCanvas.clear_background_color).toBe('#010203')
expect(mockCanvas._pattern).toBeUndefined()
expect(LiteGraph.NODE_DEFAULT_SHAPE).toBe(RenderShape.ROUND)
expect(warn).toHaveBeenCalledWith(
`litegraph_base.NODE_DEFAULT_SHAPE only accepts [${[
RenderShape.BOX,
RenderShape.ROUND,
RenderShape.CARD
].join(', ')}] but got legacy`
)
expect(document.documentElement.style.getPropertyValue('--fg-color')).toBe(
'#111111'
)
expect(
document.documentElement.style.getPropertyValue('--contrast-mix-color')
).toBe('var(--palette-contrast-mix-color)')
expect(mockCanvas.setDirty).toHaveBeenCalledWith(true, true)
expect(mockColorPaletteStore.activePaletteId).toBe('custom')
})
it('skips absent palette sections while still activating the palette', async () => {
const completedPalette = makeCompletedPalette()
mockColorPaletteStore.completePalette.mockReturnValue(
fromAny<CompletedPalette, unknown>({
...completedPalette,
colors: {
node_slot: undefined,
litegraph_base: completedPalette.colors.litegraph_base,
comfy_base: undefined
}
})
)
const service = useColorPaletteService()
await service.loadColorPalette('custom')
expect(mockCanvas.node_title_color).toBe('#abcdef')
expect(mockCanvas.setDirty).toHaveBeenCalledWith(true, true)
expect(mockColorPaletteStore.activePaletteId).toBe('custom')
})
it('removes Vue node theme overrides for built-in palettes', async () => {
mockColorPaletteStore.palettesLookup = { dark: validCustomPalette }
mockColorPaletteStore.completePalette.mockReturnValue(
makeCompletedPalette('dark')
)
document.documentElement.style.setProperty(
'--component-node-border',
'#ffffff'
)
const service = useColorPaletteService()
await service.loadColorPalette('dark')
expect(
document.documentElement.style.getPropertyValue('--component-node-border')
).toBe('')
})
it('removes Vue node theme variables when completed palette values are absent', async () => {
const completedPalette = makeCompletedPalette()
// NODE_BOX_OUTLINE_COLOR is required on the completed palette type; the
// test needs it absent, so delete via Reflect to keep the type intact.
Reflect.deleteProperty(
completedPalette.colors.litegraph_base,
'NODE_BOX_OUTLINE_COLOR'
)
mockColorPaletteStore.completePalette.mockReturnValue(completedPalette)
document.documentElement.style.setProperty(
'--component-node-border',
'#ffffff'
)
const service = useColorPaletteService()
await service.loadColorPalette('custom')
expect(
document.documentElement.style.getPropertyValue('--component-node-border')
).toBe('')
})
it('preserves numeric LiteGraph node shapes without warning', async () => {
const warn = vi.spyOn(console, 'warn').mockImplementation(() => {})
const completedPalette = makeCompletedPalette()
completedPalette.colors.litegraph_base.NODE_DEFAULT_SHAPE = RenderShape.CARD
mockColorPaletteStore.completePalette.mockReturnValue(completedPalette)
const service = useColorPaletteService()
await service.loadColorPalette('custom')
expect(LiteGraph.NODE_DEFAULT_SHAPE).toBe(RenderShape.CARD)
expect(warn).not.toHaveBeenCalled()
})
it('uses explicit optional comfy color values when present', async () => {
const completedPalette = makeCompletedPalette()
completedPalette.colors.comfy_base['contrast-mix-color'] = '#333333'
mockColorPaletteStore.completePalette.mockReturnValue(completedPalette)
const service = useColorPaletteService()
await service.loadColorPalette('custom')
expect(
document.documentElement.style.getPropertyValue('--contrast-mix-color')
).toBe('#333333')
})
it('uses a white splash background for light themes', async () => {
const completedPalette = makeCompletedPalette()
completedPalette.light_theme = true
mockColorPaletteStore.completePalette.mockReturnValue(completedPalette)
const service = useColorPaletteService()
await service.loadColorPalette('custom')
expect(localStorage.getItem('comfy-splash-bg')).toBe('#FFFFFF')
expect(localStorage.getItem('comfy-splash-fg')).toBe('#111111')
})
it('uses transparent canvas background and bg image CSS when a background image setting exists', async () => {
mockSettingStore.get.mockReturnValue('/custom/background.png')
const service = useColorPaletteService()
await service.loadColorPalette('custom')
expect(mockCanvas.clear_background_color).toBe('transparent')
expect(document.documentElement.style.getPropertyValue('--bg-img')).toBe(
"url('/custom/background.png')"
)
})
it('throws when loading or exporting an unknown palette', async () => {
mockColorPaletteStore.palettesLookup = {}
const service = useColorPaletteService()
await expect(service.loadColorPalette('missing')).rejects.toThrow(
'Color palette missing not found'
)
expect(() => service.exportColorPalette('missing')).toThrow(
'Color palette missing not found'
)
})
it('exports palette JSON by id', async () => {
const service = useColorPaletteService()
service.exportColorPalette('custom')
expect(mockDownloadBlob).toHaveBeenCalledOnce()
const [filename, blob] = mockDownloadBlob.mock.calls[0] as [string, Blob]
expect(filename).toBe('custom.json')
await expect(blob.text()).resolves.toContain('"id": "custom"')
})
it('imports palette JSON through the custom palette path', async () => {
mockUploadFile.mockResolvedValue({
text: () => Promise.resolve(JSON.stringify(validCustomPalette))
})
const service = useColorPaletteService()
const palette = await service.importColorPalette()
expect(mockUploadFile).toHaveBeenCalledWith('application/json')
expect(palette).toEqual(validCustomPalette)
expect(mockColorPaletteStore.addCustomPalette).toHaveBeenCalledWith(
validCustomPalette
)
})
it('returns the completed active palette from the store', () => {
const service = useColorPaletteService()
expect(service.getActiveColorPalette()).toBe(
mockColorPaletteStore.completedActivePalette
)
})
})

View File

@@ -18,7 +18,7 @@ import type {
} from '@/stores/dialogStore'
import type { ComponentAttrs } from 'vue-component-type-helpers'
import type { SubscriptionDialogReason } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
import type { SubscriptionDialogOptions } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
import type { WorkspaceRole } from '@/platform/workspace/api/workspaceApi'
// Lazy loaders for dialogs - components are loaded on first use
@@ -442,9 +442,9 @@ export const useDialogService = () => {
})
}
async function showSubscriptionRequiredDialog(options?: {
reason?: SubscriptionDialogReason
}) {
async function showSubscriptionRequiredDialog(
options?: SubscriptionDialogOptions
) {
if (!isCloud || !window.__CONFIG__?.subscription_required) {
return
}

View File

@@ -0,0 +1,324 @@
import { fromAny } from '@total-typescript/shoehorn'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import type { LGraphNode } from '@/lib/litegraph/src/litegraph'
import type { AuthUserInfo } from '@/types/authTypes'
import type { ComfyExtension } from '@/types/comfy'
import { useExtensionService } from './extensionService'
const mockLoadDisabledExtensionNames = vi.hoisted(() => vi.fn())
const mockRegisterExtension = vi.hoisted(() => vi.fn())
const mockCaptureCoreExtensions = vi.hoisted(() => vi.fn())
const mockEnabledExtensions = vi.hoisted(() => ({
value: [] as ComfyExtension[]
}))
vi.mock('@/stores/extensionStore', () => ({
useExtensionStore: () => ({
loadDisabledExtensionNames: mockLoadDisabledExtensionNames,
registerExtension: mockRegisterExtension,
captureCoreExtensions: mockCaptureCoreExtensions,
get enabledExtensions() {
return mockEnabledExtensions.value
}
})
}))
const mockGetSetting = vi.hoisted(() => vi.fn())
const mockAddSetting = vi.hoisted(() => vi.fn())
vi.mock('@/platform/settings/settingStore', () => ({
useSettingStore: () => ({
get: mockGetSetting,
addSetting: mockAddSetting
})
}))
const mockAddDefaultKeybinding = vi.hoisted(() => vi.fn())
vi.mock('@/platform/keybindings/keybindingStore', () => ({
useKeybindingStore: () => ({
addDefaultKeybinding: mockAddDefaultKeybinding
})
}))
vi.mock('@/platform/keybindings/keybinding', () => ({
KeybindingImpl: class KeybindingImpl {
constructor(readonly source: unknown) {}
}
}))
const mockLoadExtensionCommands = vi.hoisted(() => vi.fn())
vi.mock('@/stores/commandStore', () => ({
useCommandStore: () => ({
loadExtensionCommands: mockLoadExtensionCommands
})
}))
const mockLoadExtensionMenuCommands = vi.hoisted(() => vi.fn())
vi.mock('@/stores/menuItemStore', () => ({
useMenuItemStore: () => ({
loadExtensionMenuCommands: mockLoadExtensionMenuCommands
})
}))
const mockRegisterBottomPanelTabs = vi.hoisted(() => vi.fn())
vi.mock('@/stores/workspace/bottomPanelStore', () => ({
useBottomPanelStore: () => ({
registerExtensionBottomPanelTabs: mockRegisterBottomPanelTabs
})
}))
const mockRegisterCustomWidgets = vi.hoisted(() => vi.fn())
vi.mock('@/stores/widgetStore', () => ({
useWidgetStore: () => ({
registerCustomWidgets: mockRegisterCustomWidgets
})
}))
const mockToastErrorHandler = vi.hoisted(() => vi.fn())
vi.mock('@/composables/useErrorHandling', () => ({
useErrorHandling: () => ({
wrapWithErrorHandling:
<Args extends unknown[], Return>(fn: (...args: Args) => Return) =>
(...args: Args) =>
fn(...args),
wrapWithErrorHandlingAsync:
<Args extends unknown[], Return>(
fn: (...args: Args) => Return | Promise<Return>,
handler: (error: unknown) => void
) =>
async (...args: Args) => {
try {
return await fn(...args)
} catch (error) {
handler(error)
}
},
toastErrorHandler: mockToastErrorHandler
})
}))
const mockUserResolvedCallbacks = vi.hoisted(() => ({
values: [] as Array<(user: AuthUserInfo) => void>
}))
const mockTokenRefreshedCallbacks = vi.hoisted(() => ({
values: [] as Array<() => void>
}))
const mockUserLogoutCallbacks = vi.hoisted(() => ({
values: [] as Array<() => void>
}))
vi.mock('@/composables/auth/useCurrentUser', () => ({
useCurrentUser: () => ({
onUserResolved: (callback: (user: AuthUserInfo) => void) => {
mockUserResolvedCallbacks.values.push(callback)
},
onTokenRefreshed: (callback: () => void) => {
mockTokenRefreshedCallbacks.values.push(callback)
},
onUserLogout: (callback: () => void) => {
mockUserLogoutCallbacks.values.push(callback)
}
})
}))
const mockSetCurrentExtension = vi.hoisted(() => vi.fn())
vi.mock('@/lib/litegraph/src/contextMenuCompat', () => ({
legacyMenuCompat: {
setCurrentExtension: mockSetCurrentExtension
}
}))
const mockApp = vi.hoisted(() => ({ value: { name: 'app' } }))
vi.mock('@/scripts/app', () => ({
app: mockApp.value
}))
vi.mock('@/scripts/api', () => ({
api: {
getExtensions: vi.fn(),
fileURL: vi.fn((path: string) => path)
}
}))
describe('useExtensionService', () => {
beforeEach(() => {
vi.clearAllMocks()
mockEnabledExtensions.value = []
mockUserResolvedCallbacks.values = []
mockTokenRefreshedCallbacks.values = []
mockUserLogoutCallbacks.values = []
})
it('registers extension contributions across stores', async () => {
const widgets = { CustomWidget: vi.fn() }
const extension = fromAny<ComfyExtension, unknown>({
name: 'registration-extension',
keybindings: [{ commandId: 'command.one', combo: { key: 'K' } }],
commands: [{ id: 'command.one', label: 'Command One' }],
menuCommands: [{ path: ['File'], commands: ['command.one'] }],
settings: [{ id: 'setting.one', name: 'Setting One' }],
bottomPanelTabs: [{ id: 'tab.one', title: 'Tab One' }],
getCustomWidgets: vi.fn().mockResolvedValue(widgets)
})
const service = useExtensionService()
service.registerExtension(extension)
expect(mockRegisterExtension).toHaveBeenCalledWith(extension)
expect(mockAddDefaultKeybinding).toHaveBeenCalledWith(
expect.objectContaining({
source: { commandId: 'command.one', combo: { key: 'K' } }
})
)
expect(mockLoadExtensionCommands).toHaveBeenCalledWith(extension)
expect(mockLoadExtensionMenuCommands).toHaveBeenCalledWith(extension)
expect(mockAddSetting.mock.calls[0][0]).toEqual({
id: 'setting.one',
name: 'Setting One'
})
expect(mockRegisterBottomPanelTabs).toHaveBeenCalledWith(extension)
await vi.waitFor(() => {
expect(mockRegisterCustomWidgets).toHaveBeenCalledWith(widgets)
})
})
it('invokes auth lifecycle hooks through registered callbacks', async () => {
const onAuthUserResolved = vi.fn()
const onAuthTokenRefreshed = vi.fn()
const onAuthUserLogout = vi.fn()
const extension = fromAny<ComfyExtension, unknown>({
name: 'auth-extension',
onAuthUserResolved,
onAuthTokenRefreshed,
onAuthUserLogout
})
const user = fromAny<AuthUserInfo, unknown>({ id: 'user-1' })
const service = useExtensionService()
service.registerExtension(extension)
mockUserResolvedCallbacks.values[0](user)
mockTokenRefreshedCallbacks.values[0]()
mockUserLogoutCallbacks.values[0]()
await vi.waitFor(() => {
expect(onAuthUserResolved).toHaveBeenCalledWith(user, mockApp.value)
expect(onAuthTokenRefreshed).toHaveBeenCalled()
expect(onAuthUserLogout).toHaveBeenCalled()
})
})
it('reports auth hook errors through the toast handler', async () => {
const error = new Error('auth failed')
const extension = fromAny<ComfyExtension, unknown>({
name: 'failing-auth-extension',
onAuthUserResolved: vi.fn(() => {
throw error
})
})
const consoleError = vi.spyOn(console, 'error').mockImplementation(() => {})
const service = useExtensionService()
service.registerExtension(extension)
mockUserResolvedCallbacks.values[0](
fromAny<AuthUserInfo, unknown>({ id: 'user-1' })
)
await vi.waitFor(() => {
expect(mockToastErrorHandler).toHaveBeenCalledWith(error)
})
expect(consoleError).toHaveBeenCalledWith(
'[Extension Auth Hook Error]',
expect.objectContaining({
extension: 'failing-auth-extension',
hook: 'onAuthUserResolved',
error
})
)
consoleError.mockRestore()
})
it('invokes synchronous extension methods and keeps failures isolated', () => {
const getSelectionToolboxCommands = vi.fn(() => ['command.one'])
const failingGetSelectionToolboxCommands = vi.fn(() => {
throw new Error('menu failed')
})
const consoleError = vi.spyOn(console, 'error').mockImplementation(() => {})
mockEnabledExtensions.value = [
fromAny<ComfyExtension, unknown>({
name: 'working-extension',
getSelectionToolboxCommands
}),
fromAny<ComfyExtension, unknown>({
name: 'non-function-extension',
getSelectionToolboxCommands: ['not callable']
}),
fromAny<ComfyExtension, unknown>({
name: 'failing-extension',
getSelectionToolboxCommands: failingGetSelectionToolboxCommands
}),
{ name: 'missing-method-extension' }
]
const service = useExtensionService()
const results = service.invokeExtensions(
'getSelectionToolboxCommands',
fromAny<LGraphNode, unknown>({ id: 1 })
)
expect(results).toEqual([['command.one']])
expect(getSelectionToolboxCommands).toHaveBeenCalledWith(
fromAny<LGraphNode, unknown>({ id: 1 }),
mockApp.value
)
expect(consoleError).toHaveBeenCalledWith(
"Error calling extension 'failing-extension' method 'getSelectionToolboxCommands'",
expect.objectContaining({ error: expect.any(Error) }),
expect.objectContaining({
extension: expect.objectContaining({ name: 'failing-extension' })
}),
expect.objectContaining({
args: [fromAny<LGraphNode, unknown>({ id: 1 })]
})
)
consoleError.mockRestore()
})
it('tracks current extension around async setup callbacks', async () => {
const setup = vi.fn().mockResolvedValue('setup-result')
const failingSetup = vi.fn().mockRejectedValue(new Error('setup failed'))
const consoleError = vi.spyOn(console, 'error').mockImplementation(() => {})
mockEnabledExtensions.value = [
fromAny<ComfyExtension, unknown>({
name: 'setup-extension',
setup
}),
fromAny<ComfyExtension, unknown>({
name: 'non-function-extension',
setup: true
}),
fromAny<ComfyExtension, unknown>({
name: 'failing-setup-extension',
setup: failingSetup
}),
{ name: 'missing-method-extension' }
]
const service = useExtensionService()
const results = await service.invokeExtensionsAsync('setup')
expect(results).toEqual(['setup-result', undefined, undefined, undefined])
expect(mockSetCurrentExtension.mock.calls.map((call) => call[0])).toEqual([
'setup-extension',
'failing-setup-extension',
null,
null
])
expect(consoleError).toHaveBeenCalledWith(
"Error calling extension 'failing-setup-extension' method 'setup'",
expect.objectContaining({ error: expect.any(Error) }),
expect.objectContaining({
extension: expect.objectContaining({ name: 'failing-setup-extension' })
}),
expect.objectContaining({ args: [] })
)
consoleError.mockRestore()
})
})

File diff suppressed because it is too large Load Diff

Some files were not shown because too many files have changed in this diff Show More