Compare commits

..

9 Commits

Author SHA1 Message Date
huang47
7141bda563 test: consolidate coverage suites and clean up superseded test scaffolding
Fold sibling suites back into their canonical files (app.test.ts,
litegraphService.test.ts), replace module-level global.fetch/global.URL
mocks with vi.stubGlobal, apply helper refactors (avif/gltf box
builders, searchAndReplace createGraph, treeUtil createNode hoist), and
drop tests superseded by rewritten variants.
2026-07-02 16:36:20 -07:00
huang47
74d4366994 test: cover the root-graph branch in widget promotion options test 2026-07-02 16:35:51 -07:00
huang47
c85c0585ab test: address review findings on new coverage tests
Exercise the real drain path in autoQueueService test, align
executionErrorStore no-op test with its fixture, reset ComfyApp
clipspace state between litegraphService tests, restore app.canvas and
global stubs in finally blocks, strengthen subgraphStore compact-detail
assertion, fix imprecise imageUtil test title.
2026-07-02 15:53:21 -07:00
huang47
804630fa02 test: restructure coverage additions to be purely additive vs main
Restore all pre-existing test lines; new coverage lands as pure
additions. Suites whose mock architecture conflicts with the original
files move to sibling files (app.core.test.ts,
litegraphService.core.test.ts). Rewrites/refactors of existing tests
and helpers are deferred to a follow-up PR.
2026-07-02 14:20:13 -07:00
huang47
7417864353 fix(tests): resolve typecheck errors in coverage tests
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
2026-07-02 11:27:13 -07:00
huang47
1f26dc2b57 test: cover critical missed branches 2026-07-02 02:23:01 -07:00
Benjamin Lu
2ec2a0e091 feat: attribute payment intent through paywall, checkout, and top-up telemetry (#13363)
## Summary

Answers "why did this user want to pay?" by capturing the triggering
product moment at every paywall/upsell entry point and carrying it
through checkout and success telemetry.

## Changes

- **What**:
- Widen `SubscriptionDialogReason` from 4 coarse values to 13 grounded
intent sources (`subscribe_to_run`, `upgrade_to_add_credits`,
`invite_member_upsell`, `settings_billing_panel`, etc.)
- Fire `app:subscription_required_modal_opened` from
`useSubscriptionDialog` (the choke point all dialog variants pass
through) — the workspace/unified path previously emitted nothing; remove
the now-duplicate emitters in `useSubscription` and
`usePricingTableUrlLoader`
- Add `payment_intent_source` to
`BeginCheckoutMetadata`/`SubscriptionSuccessMetadata`, threaded via the
existing `reason` prop: dialog → `PricingTable` →
`performSubscriptionCheckout` → pending-attempt record, so legacy
`app:monthly_subscription_succeeded` carries intent alongside
`checkout_attempt_id`
- Fire `begin_checkout` on the workspace checkout path
(`useSubscriptionCheckout`, personal + team confirm) and the team
deep-link util — both previously emitted nothing; `tier` widened to
`TierKey | 'team'`
- Implement `trackBeginCheckout` in `PostHogTelemetryProvider` (was
GTM/host-only, so `begin_checkout` never reached PostHog)
- Thread `showSubscriptionDialog(options)` through the billing-context
adapters and pass a reason at ~14 call sites; add `source` to
`app:add_api_credit_button_clicked`

## Review Focus

- `modal_opened` now fires once per dialog actually shown, so a
free-tier user clicking Upgrade emits two events (free-tier dialog, then
pricing table) where the legacy path emitted one
- Intent is threaded explicitly via props/params rather than shared
state; `useSubscriptionCheckout` gained an optional second parameter
2026-07-02 03:11:21 +00:00
Mobeen Abdullah
9cf5c9a93f refactor(website): tidy customer story review nits (#13324)
## Summary

Small follow-up to #13289 applying two non-blocking review nits from
Alex's review.

## Changes

- **What**: drop the redundant `before:content-['']` on the
customer-story list bullet (Tailwind emits the empty `content`
automatically once another `before:` utility is present), and rename
`HEADER_OFFSET` to `HEADER_OFFSET_PX` in `ArticleNav` so the scroll
constants use consistent unit suffixes.

## Review Focus

Both changes are cosmetic with no behavior change. Confirmed in the
browser that the list bullet still renders identically (6px yellow dot)
without the explicit `content` utility.

## Notes from the #13289 review (left as-is here, open to discussion)

Three other comments from the review are intentionally not changed in
this PR; reasoning below so the decisions are on record:

- **`Category` type in `ArticleNav`**: kept the `ComponentProps<typeof
CategoryNav>` derivation. AGENTS.md says to derive component types via
`vue-component-type-helpers` rather than redefining them, so the current
form follows the styleguide. Happy to switch to a plain named type if
preferred.
- **Section ids in frontmatter vs the body `<Section>`**: kept the
`customers.content.test.ts` parity test. The short TOC labels live only
in frontmatter and Astro can't introspect the rendered MDX body to build
the nav, so the frontmatter `sections` list and the body anchor ids
can't be trivially deduplicated. A real fix would need a remark plugin
(larger, separate change). The test guards against silent drift in the
meantime.
- **`nextStory` throw**: left as a fail-loud, build-time invariant. The
slug always comes from the same `getStaticPaths` collection, so the
throw is effectively unreachable; it surfaces a future-refactor bug
loudly instead of linking to the wrong story.
2026-07-01 12:45:24 +00:00
jaeone94
9e5fb67b76 Show app mode run validation warning (#12557)
## Summary
Adds an app mode validation warning so users can see when a workflow has
errors before running and jump directly back to graph mode to review
them.

## Changes
- **What**: Adds a reusable app mode warning banner above the Run button
when the execution error store reports workflow errors, including
validation and missing asset states.
- **What**: Reuses the existing graph-error navigation flow so the
warning action switches out of app mode and opens the Errors panel in
graph mode.
- **What**: Updates the app mode Run button icon and accessible label in
the warning state while keeping the Run action non-blocking.
- **What**: Adds unit coverage for the warning render/accessibility
state and an E2E flow that triggers a validation failure, dismisses the
overlay, and opens graph errors from the app mode warning.
- **Breaking**: None.
- **Dependencies**: None.

## Review Focus
The warning intentionally mirrors graph mode behavior: it surfaces the
error state but does not prevent the user from clicking Run. This avoids
turning display-level validation signals into hard execution blockers.

The warning is driven by the existing `hasAnyError` aggregate, so
missing nodes, missing models, and missing media are included alongside
prompt/node/execution errors.

## Tests
- `pnpm format`
- `pnpm lint`
- `pnpm typecheck`
- `pnpm test:unit`
- `pnpm knip`
- `pnpm test:browser:local
browser_tests/tests/appModeValidationWarning.spec.ts`

## Screenshots

<img width="461" height="994" alt="스크린샷 2026-06-25 오후 7 00 55"
src="https://github.com/user-attachments/assets/f8fc20bf-d572-46b5-9fa4-312e7c4c8076"
/>
2026-07-01 15:24:45 +09:00
139 changed files with 17169 additions and 1867 deletions

View File

@@ -15,7 +15,7 @@ const { categories } = defineProps<{
const activeSection = ref(categories[0]?.value ?? '')
const HEADER_OFFSET = -144
const HEADER_OFFSET_PX = -144
const BOTTOM_THRESHOLD_PX = 4
const SCROLL_SAFETY_MS = 1500
@@ -52,7 +52,7 @@ function scrollToSection(id: string) {
const el = document.getElementById(id)
if (el) {
scrollTo(el, {
offset: HEADER_OFFSET,
offset: HEADER_OFFSET_PX,
duration: 0.8,
immediate: prefersReducedMotion(),
onComplete: clearScrollLock

View File

@@ -1,5 +1,5 @@
<li
class="flex items-start gap-2 text-primary-comfy-canvas before:mt-1.5 before:size-1.5 before:shrink-0 before:rounded-full before:bg-primary-comfy-yellow before:content-['']"
class="flex items-start gap-2 text-primary-comfy-canvas before:mt-1.5 before:size-1.5 before:shrink-0 before:rounded-full before:bg-primary-comfy-yellow"
>
<slot />
</li>

View File

@@ -0,0 +1,45 @@
{
"last_node_id": 9,
"last_link_id": 9,
"nodes": [
{
"id": 9,
"type": "SaveImage",
"pos": {
"0": 64,
"1": 104
},
"size": {
"0": 210,
"1": 58
},
"flags": {},
"order": 0,
"mode": 0,
"inputs": [
{
"name": "images",
"type": "IMAGE",
"link": null
}
],
"outputs": [],
"properties": {},
"widgets_values": ["ComfyUI"]
}
],
"links": [],
"groups": [],
"config": {},
"extra": {
"ds": {
"scale": 1,
"offset": [0, 0]
},
"linearData": {
"inputs": [],
"outputs": ["9"]
}
},
"version": 0.4
}

View File

@@ -34,6 +34,10 @@ export class AppModeHelper {
public readonly outputPlaceholder: Locator
/** The linear-mode widget list container (visible in app mode). */
public readonly linearWidgets: Locator
/** The validation warning shown above the app mode run button. */
public readonly validationWarning: Locator
/** The action that opens graph mode errors from the validation warning. */
public readonly viewErrorsInGraphButton: Locator
/** The PrimeVue Popover for the image picker (renders with role="dialog"). */
public readonly imagePickerPopover: Locator
/** The Run button in the app mode footer. */
@@ -92,13 +96,19 @@ export class AppModeHelper {
this.outputPlaceholder = this.page.getByTestId(
TestIds.builder.outputPlaceholder
)
this.linearWidgets = this.page.getByTestId('linear-widgets')
this.linearWidgets = this.page.getByTestId(TestIds.linear.widgetContainer)
this.validationWarning = this.page.getByTestId(
TestIds.linear.validationWarning
)
this.viewErrorsInGraphButton = this.validationWarning.getByTestId(
TestIds.linear.viewErrorsInGraph
)
this.imagePickerPopover = this.page
.getByRole('dialog')
.filter({ has: this.page.getByRole('button', { name: 'All' }) })
.first()
this.runButton = this.page
.getByTestId('linear-run-button')
.getByTestId(TestIds.linear.runButton)
.getByRole('button', { name: /run/i })
this.welcome = this.page.getByTestId(TestIds.appMode.welcome)
this.emptyWorkflowText = this.page.getByTestId(

View File

@@ -172,6 +172,9 @@ export const TestIds = {
mobileNavigation: 'linear-mobile-navigation',
mobileWorkflows: 'linear-mobile-workflows',
outputInfo: 'linear-output-info',
runButton: 'linear-run-button',
validationWarning: 'linear-validation-warning',
viewErrorsInGraph: 'linear-view-errors',
widgetContainer: 'linear-widgets'
},
builder: {

View File

@@ -0,0 +1,106 @@
import {
comfyExpect as expect,
comfyPageFixture as test
} from '@e2e/fixtures/ComfyPage'
import type { NodeError, PromptResponse } from '@/schemas/apiSchema'
import { ExecutionHelper } from '@e2e/fixtures/helpers/ExecutionHelper'
import { enableErrorsOverlay } from '@e2e/fixtures/helpers/ErrorsTabHelper'
import { TestIds } from '@e2e/fixtures/selectors'
const SAVE_IMAGE_NODE_ID = '9'
function buildSaveImageRequiredInputError(): NodeError {
return {
class_type: 'SaveImage',
dependent_outputs: [],
errors: [
{
type: 'required_input_missing',
message: 'Required input is missing: images',
details: '',
extra_info: { input_name: 'images' }
}
]
}
}
test.describe(
'App mode validation warning',
{ tag: ['@ui', '@workflow'] },
() => {
test.beforeEach(async ({ comfyPage }) => {
await enableErrorsOverlay(comfyPage)
await comfyPage.workflow.loadWorkflow('linear-validation-warning')
await comfyPage.appMode.toggleAppMode()
await expect(comfyPage.appMode.linearWidgets).toBeVisible()
})
test('opens graph errors from the app mode validation warning', async ({
comfyPage
}) => {
await expect(comfyPage.appMode.validationWarning).toBeHidden()
const exec = new ExecutionHelper(comfyPage)
await exec.mockValidationFailure({
[SAVE_IMAGE_NODE_ID]: buildSaveImageRequiredInputError()
})
await comfyPage.appMode.runButton.click()
const appModeOverlay = comfyPage.appMode.centerPanel.getByTestId(
TestIds.dialogs.errorOverlay
)
await expect(appModeOverlay).toBeHidden()
await expect(comfyPage.appMode.validationWarning).toBeVisible()
await expect(comfyPage.appMode.validationWarning).toContainText(
/Required input missing/i
)
await expect(comfyPage.appMode.viewErrorsInGraphButton).toBeVisible()
await comfyPage.appMode.viewErrorsInGraphButton.click()
await expect(comfyPage.appMode.linearWidgets).toBeHidden()
await expect(
comfyPage.page.getByTestId(TestIds.propertiesPanel.root)
).toBeVisible()
await expect(
comfyPage.page.getByTestId(TestIds.propertiesPanel.errorsTab)
).toBeVisible()
})
test('keeps the app mode run button enabled when the warning is visible', async ({
comfyPage
}) => {
const exec = new ExecutionHelper(comfyPage)
await exec.mockValidationFailure({
[SAVE_IMAGE_NODE_ID]: buildSaveImageRequiredInputError()
})
await comfyPage.appMode.runButton.click()
await expect(comfyPage.appMode.validationWarning).toBeVisible()
await expect(comfyPage.appMode.runButton).toBeEnabled()
let promptQueued = false
const mockResponse: PromptResponse = {
prompt_id: 'test-id',
node_errors: {},
error: ''
}
await comfyPage.page.route(
'**/api/prompt',
async (route) => {
promptQueued = true
await route.fulfill({
status: 200,
body: JSON.stringify(mockResponse)
})
},
{ times: 1 }
)
await comfyPage.appMode.runButton.click()
await expect.poll(() => promptQueued).toBe(true)
})
}
)

View File

@@ -1,5 +1,6 @@
import { expect } from '@playwright/test'
import { toLinkId } from '@/types/linkId'
import { toNodeId } from '@/types/nodeId'
import { comfyPageFixture as test } from '@e2e/fixtures/ComfyPage'
@@ -15,9 +16,10 @@ test.describe('Graph', { tag: ['@smoke', '@canvas'] }, () => {
await comfyPage.workflow.loadWorkflow('inputs/input_order_swap')
await expect
.poll(() =>
comfyPage.page.evaluate(() => {
return window.app!.graph!.links.get(1)?.target_slot
})
comfyPage.page.evaluate(
(linkId) => window.app!.graph!.links.get(linkId)?.target_slot,
toLinkId(1)
)
)
.toBe(1)
})

View File

@@ -3,6 +3,7 @@ import {
comfyPageFixture as test,
comfyExpect as expect
} from '@e2e/fixtures/ComfyPage'
import { TestIds } from '@e2e/fixtures/selectors'
test.describe('Linear Mode', { tag: '@ui' }, () => {
test('Displays linear controls when app mode active', async ({
@@ -16,7 +17,9 @@ test.describe('Linear Mode', { tag: '@ui' }, () => {
test('Run button visible in linear mode', async ({ comfyPage }) => {
await comfyPage.appMode.enterAppModeWithInputs([])
await expect(comfyPage.page.getByTestId('linear-run-button')).toBeVisible()
await expect(
comfyPage.page.getByTestId(TestIds.linear.runButton)
).toBeVisible()
})
test('Workflow info section visible', async ({ comfyPage }) => {

View File

@@ -0,0 +1,75 @@
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
describe('runWhenGlobalIdle', () => {
beforeEach(() => {
vi.resetModules()
})
afterEach(() => {
vi.useRealTimers()
vi.unstubAllGlobals()
})
it('falls back to a timeout when idle callbacks are unavailable', async () => {
vi.useFakeTimers()
vi.stubGlobal('requestIdleCallback', undefined)
vi.stubGlobal('cancelIdleCallback', undefined)
const { runWhenGlobalIdle } = await import('./async')
const runner = vi.fn()
const disposable = runWhenGlobalIdle(runner)
await vi.runAllTimersAsync()
expect(runner).toHaveBeenCalledOnce()
const deadline = runner.mock.calls[0][0]
expect(deadline.didTimeout).toBe(true)
expect(deadline.timeRemaining()).toBeGreaterThanOrEqual(0)
disposable.dispose()
disposable.dispose()
})
it('cancels fallback idle work before it runs', async () => {
vi.useFakeTimers()
vi.stubGlobal('requestIdleCallback', undefined)
vi.stubGlobal('cancelIdleCallback', undefined)
const { runWhenGlobalIdle } = await import('./async')
const runner = vi.fn()
runWhenGlobalIdle(runner).dispose()
await vi.runAllTimersAsync()
expect(runner).not.toHaveBeenCalled()
})
it('uses native idle callbacks when available', async () => {
const requestIdleCallback = vi.fn(() => 42)
const cancelIdleCallback = vi.fn()
vi.stubGlobal('requestIdleCallback', requestIdleCallback)
vi.stubGlobal('cancelIdleCallback', cancelIdleCallback)
const { runWhenGlobalIdle } = await import('./async')
const runner = vi.fn()
const disposable = runWhenGlobalIdle(runner, 250)
expect(requestIdleCallback).toHaveBeenCalledWith(runner, { timeout: 250 })
disposable.dispose()
disposable.dispose()
expect(cancelIdleCallback).toHaveBeenCalledOnce()
expect(cancelIdleCallback).toHaveBeenCalledWith(42)
})
it('omits native idle timeout options when no timeout is supplied', async () => {
const requestIdleCallback = vi.fn(() => 7)
vi.stubGlobal('requestIdleCallback', requestIdleCallback)
vi.stubGlobal('cancelIdleCallback', vi.fn())
const { runWhenGlobalIdle } = await import('./async')
const runner = vi.fn()
runWhenGlobalIdle(runner)
expect(requestIdleCallback).toHaveBeenCalledWith(runner, undefined)
})
})

View File

@@ -122,6 +122,22 @@ describe('downloadUtil', () => {
expect(createObjectURLSpy).not.toHaveBeenCalled()
})
it('throws for an empty URL', () => {
expect(() => downloadFile('')).toThrow(
'Invalid URL provided for download'
)
expect(fetchMock).not.toHaveBeenCalled()
expect(createObjectURLSpy).not.toHaveBeenCalled()
})
it('throws for a whitespace URL', () => {
expect(() => downloadFile(' ')).toThrow(
'Invalid URL provided for download'
)
expect(fetchMock).not.toHaveBeenCalled()
expect(createObjectURLSpy).not.toHaveBeenCalled()
})
it('should prefer custom filename over extracted filename', () => {
const testUrl =
'https://example.com/api/file?filename=extracted-image.jpg'

View File

@@ -4,6 +4,7 @@ import {
CREDITS_PER_USD,
COMFY_CREDIT_RATE_CENTS,
centsToCredits,
clampUsd,
creditsToCents,
creditsToUsd,
formatCredits,
@@ -43,4 +44,21 @@ describe('comfyCredits helpers', () => {
expect(formatCreditsFromUsd({ usd: 1, locale })).toBe('211.00')
expect(formatUsd({ value: 4.2, locale })).toBe('4.20')
})
test('formats with compatible fraction digit bounds', () => {
expect(
formatCredits({
value: 12.345,
locale: 'en-US',
numberOptions: { minimumFractionDigits: 4, maximumFractionDigits: 2 }
})
).toBe('12.35')
})
test('clamps USD purchase values into the supported range', () => {
expect(clampUsd(Number.NaN)).toBe(0)
expect(clampUsd(-5)).toBe(1)
expect(clampUsd(42)).toBe(42)
expect(clampUsd(5000)).toBe(1000)
})
})

View File

@@ -37,7 +37,7 @@
size="unset"
class="min-h-8 rounded-lg px-3 py-2 text-xs font-normal"
data-testid="error-overlay-see-errors"
@click="seeErrors"
@click="viewErrorsInGraph"
>
{{
appMode
@@ -67,31 +67,18 @@ import { useI18n } from 'vue-i18n'
import Button from '@/components/ui/button/Button.vue'
import { useExecutionErrorStore } from '@/stores/executionErrorStore'
import { useRightSidePanelStore } from '@/stores/workspace/rightSidePanelStore'
import { useCanvasStore } from '@/renderer/core/canvas/canvasStore'
import { useErrorOverlayState } from '@/components/error/useErrorOverlayState'
import { useViewErrorsInGraph } from '@/composables/useViewErrorsInGraph'
const { appMode = false } = defineProps<{ appMode?: boolean }>()
const { t } = useI18n()
const executionErrorStore = useExecutionErrorStore()
const rightSidePanelStore = useRightSidePanelStore()
const canvasStore = useCanvasStore()
const { viewErrorsInGraph } = useViewErrorsInGraph()
const { isVisible, overlayMessage, overlayTitle } = useErrorOverlayState()
function dismiss() {
executionErrorStore.dismissErrorOverlay()
}
function seeErrors() {
canvasStore.linearMode = false
if (canvasStore.canvas) {
canvasStore.canvas.deselectAll()
canvasStore.updateSelectedItems()
}
rightSidePanelStore.openPanel('errors')
executionErrorStore.dismissErrorOverlay()
}
</script>

View File

@@ -224,7 +224,7 @@ const handleOpenUserSettings = () => {
}
const handleOpenPlansAndPricing = () => {
subscriptionDialog.showPricingTable()
subscriptionDialog.showPricingTable({ reason: 'avatar_menu_plans' })
emit('close')
}
@@ -239,8 +239,7 @@ const handleOpenPlanAndCreditsSettings = () => {
}
const handleTopUp = () => {
// Track purchase credits entry from avatar popover
useTelemetry()?.trackAddApiCreditButtonClicked()
useTelemetry()?.trackAddApiCreditButtonClicked({ source: 'avatar_menu' })
dialogService.showTopUpCreditsDialog()
emit('close')
}
@@ -254,7 +253,7 @@ const handleOpenPartnerNodesInfo = () => {
}
const handleUpgradeToAddCredits = () => {
subscriptionDialog.showPricingTable()
subscriptionDialog.showPricingTable({ reason: 'upgrade_to_add_credits' })
emit('close')
}

View File

@@ -21,6 +21,6 @@ const { isFreeTier } = useBillingContext()
const subscriptionDialog = useSubscriptionDialog()
function handleClick() {
subscriptionDialog.showPricingTable()
subscriptionDialog.showPricingTable({ reason: 'subscribe_now_button' })
}
</script>

View File

@@ -1,5 +1,6 @@
import type { ComputedRef, Ref } from 'vue'
import type { SubscriptionDialogOptions } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
import type { TierKey } from '@/platform/cloud/subscription/constants/tierPricing'
import type {
BillingStatus,
@@ -75,9 +76,10 @@ export interface BillingActions {
*/
requireActiveSubscription: () => Promise<void>
/**
* Shows the subscription dialog.
* Shows the subscription dialog. Pass a reason so the paywall open and any
* downstream checkout stay attributed to the triggering product moment.
*/
showSubscriptionDialog: () => void
showSubscriptionDialog: (options?: SubscriptionDialogOptions) => void
}
export interface BillingState {

View File

@@ -7,6 +7,7 @@ import {
getTierFeatures
} from '@/platform/cloud/subscription/constants/tierPricing'
import type { TierKey } from '@/platform/cloud/subscription/constants/tierPricing'
import type { SubscriptionDialogOptions } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
import type {
PreviewSubscribeOptions,
SubscribeOptions
@@ -281,8 +282,8 @@ function useBillingContextInternal(): BillingContext {
return activeContext.value.requireActiveSubscription()
}
function showSubscriptionDialog() {
return activeContext.value.showSubscriptionDialog()
function showSubscriptionDialog(options?: SubscriptionDialogOptions) {
return activeContext.value.showSubscriptionDialog(options)
}
return {

View File

@@ -2,6 +2,7 @@ import { computed, ref } from 'vue'
import { useAuthActions } from '@/composables/auth/useAuthActions'
import { useSubscription } from '@/platform/cloud/subscription/composables/useSubscription'
import type { SubscriptionDialogOptions } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
import type {
BillingStatus,
BillingSubscriptionStatus,
@@ -189,12 +190,12 @@ export function useLegacyBilling(): BillingState & BillingActions {
async function requireActiveSubscription(): Promise<void> {
await fetchStatus()
if (!isActiveSubscription.value) {
legacyShowSubscriptionDialog()
legacyShowSubscriptionDialog({ reason: 'subscription_required' })
}
}
function showSubscriptionDialog(): void {
legacyShowSubscriptionDialog()
function showSubscriptionDialog(options?: SubscriptionDialogOptions): void {
legacyShowSubscriptionDialog(options)
}
return {

View File

@@ -34,22 +34,17 @@ describe('useSelectionToolboxPosition', () => {
canvasStore = useCanvasStore()
})
function renderToolboxForSelection(
items: Iterable<Positionable>,
state: Partial<LGraphCanvas['state']> = {},
ds: Partial<LGraphCanvas['ds']> = {}
) {
function renderToolboxForSelection(item: Positionable) {
canvasStore.canvas = markRaw({
canvas: document.createElement('canvas'),
ds: {
offset: ds.offset ?? [0, 0],
scale: ds.scale ?? 1
offset: [0, 0],
scale: 1
},
selectedItems: new Set(items),
selectedItems: new Set([item]),
state: {
draggingItems: false,
selectionChanged: true,
...state
selectionChanged: true
}
} as Partial<LGraphCanvas> as LGraphCanvas)
@@ -74,7 +69,7 @@ describe('useSelectionToolboxPosition', () => {
group.pos = [100, 200]
group.size = [160, 80]
const { toolbox, unmount } = renderToolboxForSelection([group])
const { toolbox, unmount } = renderToolboxForSelection(group)
expect(toolbox.style.getPropertyValue('--tb-y')).toBe('190px')
unmount()
@@ -86,64 +81,11 @@ describe('useSelectionToolboxPosition', () => {
node.pos = [100, 200]
node.size = [160, 80]
const { toolbox, unmount } = renderToolboxForSelection([node])
const { toolbox, unmount } = renderToolboxForSelection(node)
expect(toolbox.style.getPropertyValue('--tb-y')).toBe(
`${190 - LiteGraph.NODE_TITLE_HEIGHT}px`
)
unmount()
})
it('does not set coordinates when selection is empty', () => {
const { toolbox, unmount } = renderToolboxForSelection([])
expect(toolbox.style.getPropertyValue('--tb-x')).toBe('')
expect(toolbox.style.getPropertyValue('--tb-y')).toBe('')
unmount()
})
it('does not set coordinates while selected items are being dragged', () => {
const group = new LGraphGroup('Group', 1)
group.pos = [100, 200]
group.size = [160, 80]
const { toolbox, unmount } = renderToolboxForSelection([group], {
draggingItems: true
})
expect(toolbox.style.getPropertyValue('--tb-x')).toBe('')
expect(toolbox.style.getPropertyValue('--tb-y')).toBe('')
unmount()
})
it('positions multiple selected items from their union bounds', () => {
const first = new LGraphGroup('First', 1)
first.pos = [100, 200]
first.size = [100, 40]
const second = new LGraphGroup('Second', 2)
second.pos = [300, 260]
second.size = [50, 40]
const { toolbox, unmount } = renderToolboxForSelection([first, second])
expect(toolbox.style.getPropertyValue('--tb-x')).toBe('270px')
expect(toolbox.style.getPropertyValue('--tb-y')).toBe('190px')
unmount()
})
it('applies canvas scale and offset to screen coordinates', () => {
const group = new LGraphGroup('Group', 1)
group.pos = [100, 200]
group.size = [100, 40]
const { toolbox, unmount } = renderToolboxForSelection(
[group],
{},
{ offset: [10, 20], scale: 2 }
)
expect(toolbox.style.getPropertyValue('--tb-x')).toBe('360px')
expect(toolbox.style.getPropertyValue('--tb-y')).toBe('420px')
unmount()
})
})

View File

@@ -1,7 +1,6 @@
import { fromPartial } from '@total-typescript/shoehorn'
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
import { afterEach, describe, expect, it, vi } from 'vitest'
import { downloadFile, openFileInNewTab } from '@/base/common/downloadUtil'
import type { LGraphNode } from '@/lib/litegraph/src/LGraphNode'
import { createMockLGraphNode } from '@/utils/__tests__/litegraphTestUtils'
import { useImageMenuOptions } from './useImageMenuOptions'
@@ -20,11 +19,6 @@ vi.mock('@/stores/commandStore', () => ({
useCommandStore: () => ({ execute: vi.fn() })
}))
vi.mock('@/base/common/downloadUtil', () => ({
downloadFile: vi.fn(),
openFileInNewTab: vi.fn()
}))
function mockClipboard(clipboard: Partial<Clipboard> | undefined) {
Object.defineProperty(navigator, 'clipboard', {
value: clipboard,
@@ -33,15 +27,6 @@ function mockClipboard(clipboard: Partial<Clipboard> | undefined) {
})
}
function stubClipboardItem() {
vi.stubGlobal(
'ClipboardItem',
class ClipboardItemStub {
constructor(public readonly items: Record<string, Blob>) {}
}
)
}
function createImageNode(
overrides: Partial<LGraphNode> | Record<string, unknown> = {}
): LGraphNode {
@@ -60,13 +45,8 @@ function createImageNode(
}
describe('useImageMenuOptions', () => {
beforeEach(() => {
vi.clearAllMocks()
})
afterEach(() => {
vi.restoreAllMocks()
vi.unstubAllGlobals()
})
describe('getImageMenuOptions', () => {
@@ -202,147 +182,4 @@ describe('useImageMenuOptions', () => {
expect(node.pasteFiles).not.toHaveBeenCalled()
})
})
describe('image actions', () => {
it('opens the selected image without preview query params', () => {
const node = createImageNode()
node.imgs![0].src = 'http://localhost/test.png?preview=1&foo=bar'
const { getImageMenuOptions } = useImageMenuOptions()
const openOption = getImageMenuOptions(node).find(
(o) => o.label === 'Open Image'
)
openOption?.action?.()
expect(openFileInNewTab).toHaveBeenCalledWith(
'http://localhost/test.png?foo=bar'
)
})
it('saves the selected image without preview query params', () => {
const node = createImageNode()
node.imgs![0].src = 'http://localhost/test.png?preview=1&foo=bar'
const { getImageMenuOptions } = useImageMenuOptions()
const saveOption = getImageMenuOptions(node).find(
(o) => o.label === 'Save Image'
)
saveOption?.action?.()
expect(downloadFile).toHaveBeenCalledWith(
'http://localhost/test.png?foo=bar'
)
})
it('does not open or save when the active image is missing', () => {
const node = createImageNode({ imageIndex: 1 })
const { getImageMenuOptions } = useImageMenuOptions()
const options = getImageMenuOptions(node)
const openOption = options.find((o) => o.label === 'Open Image')
const saveOption = options.find((o) => o.label === 'Save Image')
expect(openOption?.action).toEqual(expect.any(Function))
expect(saveOption?.action).toEqual(expect.any(Function))
openOption?.action?.()
saveOption?.action?.()
expect(openFileInNewTab).not.toHaveBeenCalled()
expect(downloadFile).not.toHaveBeenCalled()
})
it('logs save failures for invalid image URLs', () => {
const errorSpy = vi.spyOn(console, 'error').mockImplementation(() => {})
const node = createImageNode()
Object.defineProperty(node.imgs![0], 'src', {
value: 'http://[',
configurable: true
})
const { getImageMenuOptions } = useImageMenuOptions()
getImageMenuOptions(node)
.find((o) => o.label === 'Save Image')
?.action?.()
expect(errorSpy).toHaveBeenCalledWith(
'Failed to save image:',
expect.any(TypeError)
)
expect(downloadFile).not.toHaveBeenCalled()
})
it('copies the selected image to clipboard', async () => {
const node = createImageNode()
const drawImage = vi.fn()
const write = vi.fn().mockResolvedValue(undefined)
stubClipboardItem()
mockClipboard(fromPartial<Clipboard>({ write }))
vi.spyOn(HTMLCanvasElement.prototype, 'getContext').mockImplementation(
(() =>
fromPartial<CanvasRenderingContext2D>({
drawImage
})) as unknown as HTMLCanvasElement['getContext']
)
vi.spyOn(HTMLCanvasElement.prototype, 'toBlob').mockImplementation(
(callback: BlobCallback) => {
callback(new Blob(['image'], { type: 'image/png' }))
}
)
const { getImageMenuOptions } = useImageMenuOptions()
await getImageMenuOptions(node)
.find((o) => o.label === 'Copy Image')
?.action?.()
expect(drawImage).toHaveBeenCalledWith(node.imgs![0], 0, 0)
expect(write).toHaveBeenCalledWith([
expect.objectContaining({
items: { 'image/png': expect.any(Blob) }
})
])
})
it('does not copy when canvas context is unavailable', async () => {
const node = createImageNode()
const write = vi.fn()
mockClipboard(fromPartial<Clipboard>({ write }))
vi.spyOn(HTMLCanvasElement.prototype, 'getContext').mockImplementation(
(() => null) as HTMLCanvasElement['getContext']
)
const { getImageMenuOptions } = useImageMenuOptions()
await getImageMenuOptions(node)
.find((o) => o.label === 'Copy Image')
?.action?.()
expect(write).not.toHaveBeenCalled()
})
it('does not copy when canvas blob creation fails', async () => {
const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {})
const node = createImageNode()
const write = vi.fn()
mockClipboard(fromPartial<Clipboard>({ write }))
vi.spyOn(HTMLCanvasElement.prototype, 'getContext').mockImplementation(
(() =>
fromPartial<CanvasRenderingContext2D>({
drawImage: vi.fn()
})) as unknown as HTMLCanvasElement['getContext']
)
vi.spyOn(HTMLCanvasElement.prototype, 'toBlob').mockImplementation(
(callback: BlobCallback) => {
callback(null)
}
)
const { getImageMenuOptions } = useImageMenuOptions()
await getImageMenuOptions(node)
.find((o) => o.label === 'Copy Image')
?.action?.()
expect(warnSpy).toHaveBeenCalledWith('Failed to create image blob')
expect(write).not.toHaveBeenCalled()
})
})
})

View File

@@ -1,315 +0,0 @@
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
import { createApp, defineComponent, h, nextTick } from 'vue'
import type { App as VueApp } from 'vue'
import { useNodeBadge } from '@/composables/node/useNodeBadge'
import { BadgePosition, LGraphNode } from '@/lib/litegraph/src/litegraph'
import type { LGraphBadge } from '@/lib/litegraph/src/litegraph'
import type { ComfyExtension } from '@/types/comfy'
import { toNodeId } from '@/types/nodeId'
import { NodeBadgeMode } from '@/types/nodeSource'
const {
settings,
appState,
extensionState,
nodeDefState,
pricingState,
setDirtyMock,
addEventListenerMock,
registerExtensionMock,
getCreditsBadgeMock,
updateSubgraphCreditsMock,
getNodePricingConfigMock,
getNodeDisplayPriceMock,
getRelevantWidgetNamesMock,
triggerPriceRecalculationMock,
useComputedWithWidgetWatchMock
} = vi.hoisted(() => ({
settings: {} as Record<string, unknown>,
appState: {
graph: {
nodes: [] as unknown[]
}
},
extensionState: {
installed: false,
registered: undefined as ComfyExtension | undefined
},
nodeDefState: {
value: null as Record<string, unknown> | null
},
pricingState: {
revision: { value: 0 },
config: undefined as
| {
depends_on?: {
widgets?: string[]
inputs?: string[]
input_groups?: string[]
}
}
| undefined,
label: '1 credit'
},
setDirtyMock: vi.fn(),
addEventListenerMock: vi.fn(),
registerExtensionMock: vi.fn((extension: ComfyExtension) => {
extensionState.registered = extension
}),
getCreditsBadgeMock: vi.fn((text: string) => ({ text })),
updateSubgraphCreditsMock: vi.fn(),
getNodePricingConfigMock: vi.fn(() => pricingState.config),
getNodeDisplayPriceMock: vi.fn(() => pricingState.label),
getRelevantWidgetNamesMock: vi.fn(() => ['seed']),
triggerPriceRecalculationMock: vi.fn(),
useComputedWithWidgetWatchMock: vi.fn(() => vi.fn())
}))
vi.mock('@/scripts/app', () => ({
app: {
canvas: {
setDirty: setDirtyMock,
canvas: {
addEventListener: addEventListenerMock
},
graph: appState.graph
}
}
}))
vi.mock('@/platform/settings/settingStore', () => ({
useSettingStore: () => ({
get: (key: string) => settings[key]
})
}))
vi.mock('@/stores/extensionStore', () => ({
useExtensionStore: () => ({
isExtensionInstalled: () => extensionState.installed,
registerExtension: registerExtensionMock
})
}))
vi.mock('@/stores/nodeDefStore', () => ({
useNodeDefStore: () => ({
fromLGraphNode: () => nodeDefState.value
})
}))
vi.mock('@/stores/workspace/colorPaletteStore', () => ({
useColorPaletteStore: () => ({
completedActivePalette: {
colors: {
litegraph_base: {
BADGE_FG_COLOR: '#fff',
BADGE_BG_COLOR: '#000'
}
}
}
})
}))
vi.mock('@/composables/node/useNodePricing', () => ({
useNodePricing: () => ({
pricingRevision: pricingState.revision,
getNodePricingConfig: getNodePricingConfigMock,
getNodeDisplayPrice: getNodeDisplayPriceMock,
getRelevantWidgetNames: getRelevantWidgetNamesMock,
triggerPriceRecalculation: triggerPriceRecalculationMock
})
}))
vi.mock('@/composables/node/usePriceBadge', () => ({
usePriceBadge: () => ({
getCreditsBadge: getCreditsBadgeMock,
updateSubgraphCredits: updateSubgraphCreditsMock
})
}))
vi.mock('@/composables/node/useWatchWidget', () => ({
useComputedWithWidgetWatch: useComputedWithWidgetWatchMock
}))
class ApiNode extends LGraphNode {
static override nodeData = { name: 'ApiNode', api_node: true }
}
function mountBadge(): VueApp {
const app = createApp(
defineComponent({
setup() {
useNodeBadge()
return () => h('div')
}
})
)
app.mount(document.createElement('div'))
return app
}
function registeredExtension(): ComfyExtension {
if (!extensionState.registered)
throw new Error('Missing registered extension')
return extensionState.registered
}
function comfyApp(): Parameters<NonNullable<ComfyExtension['init']>>[0] {
return {} as Parameters<NonNullable<ComfyExtension['init']>>[0]
}
function callNodeCreated(node: LGraphNode) {
registeredExtension().nodeCreated?.(node, comfyApp())
}
function inputSlot(name: string) {
return new LGraphNode('slot').addInput(name, '*')
}
function defaultSettings() {
settings['Comfy.NodeBadge.NodeSourceBadgeMode'] = NodeBadgeMode.None
settings['Comfy.NodeBadge.NodeIdBadgeMode'] = NodeBadgeMode.None
settings['Comfy.NodeBadge.NodeLifeCycleBadgeMode'] = NodeBadgeMode.None
settings['Comfy.NodeBadge.ShowApiPricing'] = false
}
describe('useNodeBadge', () => {
let mountedApp: VueApp | undefined
beforeEach(() => {
defaultSettings()
extensionState.installed = false
extensionState.registered = undefined
appState.graph.nodes = []
nodeDefState.value = null
pricingState.revision.value = 0
pricingState.config = undefined
pricingState.label = '1 credit'
setDirtyMock.mockClear()
addEventListenerMock.mockClear()
registerExtensionMock.mockClear()
getCreditsBadgeMock.mockClear()
updateSubgraphCreditsMock.mockClear()
getNodePricingConfigMock.mockClear()
getNodeDisplayPriceMock.mockClear()
getRelevantWidgetNamesMock.mockClear()
triggerPriceRecalculationMock.mockClear()
useComputedWithWidgetWatchMock.mockClear()
})
afterEach(() => {
mountedApp?.unmount()
mountedApp = undefined
})
it('does not register the badge extension twice', async () => {
extensionState.installed = true
mountedApp = mountBadge()
await nextTick()
expect(registerExtensionMock).not.toHaveBeenCalled()
})
it('adds the configured node identity badge', async () => {
settings['Comfy.NodeBadge.NodeSourceBadgeMode'] = NodeBadgeMode.ShowAll
settings['Comfy.NodeBadge.NodeIdBadgeMode'] = NodeBadgeMode.ShowAll
settings['Comfy.NodeBadge.NodeLifeCycleBadgeMode'] =
NodeBadgeMode.HideBuiltIn
nodeDefState.value = {
isCoreNode: false,
nodeLifeCycleBadgeText: 'Beta',
nodeSource: { badgeText: 'Pack' }
}
const node = new LGraphNode('Test')
node.id = toNodeId('7')
mountedApp = mountBadge()
await nextTick()
callNodeCreated(node)
const badge = node.badges[0] as () => LGraphBadge
expect(node.badgePosition).toBe(BadgePosition.TopRight)
expect(badge().text).toBe('#7 Beta Pack')
})
it('hides built-in badge text when the mode excludes core nodes', async () => {
settings['Comfy.NodeBadge.NodeSourceBadgeMode'] = NodeBadgeMode.HideBuiltIn
settings['Comfy.NodeBadge.NodeIdBadgeMode'] = NodeBadgeMode.ShowAll
settings['Comfy.NodeBadge.NodeLifeCycleBadgeMode'] =
NodeBadgeMode.HideBuiltIn
nodeDefState.value = {
isCoreNode: true,
nodeLifeCycleBadgeText: 'Core',
nodeSource: { badgeText: 'Built-in' }
}
const node = new LGraphNode('Core')
node.id = toNodeId('11')
mountedApp = mountBadge()
await nextTick()
callNodeCreated(node)
const badge = node.badges[0] as () => LGraphBadge
expect(badge().text).toBe('#11')
})
it('adds dynamic API pricing badges and refreshes relevant input changes', async () => {
settings['Comfy.NodeBadge.ShowApiPricing'] = true
pricingState.config = {
depends_on: {
widgets: ['seed'],
inputs: ['image'],
input_groups: ['lora']
}
}
const originalOnConnectionsChange = vi.fn()
const node = new ApiNode('API')
node.onConnectionsChange = originalOnConnectionsChange
mountedApp = mountBadge()
await nextTick()
callNodeCreated(node)
expect(useComputedWithWidgetWatchMock).toHaveBeenCalledWith(node, {
widgetNames: ['seed'],
triggerCanvasRedraw: true
})
expect(getCreditsBadgeMock).toHaveBeenCalledWith('1 credit')
const priceBadge = node.badges[1] as () => { text: string }
expect(priceBadge().text).toBe('1 credit')
pricingState.label = '2 credits'
expect(priceBadge().text).toBe('2 credits')
node.onConnectionsChange?.(1, 0, true, undefined, inputSlot('image'))
node.onConnectionsChange?.(1, 0, true, undefined, inputSlot('lora.0'))
node.onConnectionsChange?.(1, 0, true, undefined, inputSlot('clip'))
node.onConnectionsChange?.(1, 0, true, undefined, inputSlot(''))
expect(originalOnConnectionsChange).toHaveBeenCalledTimes(4)
expect(triggerPriceRecalculationMock).toHaveBeenCalledTimes(2)
expect(triggerPriceRecalculationMock).toHaveBeenCalledWith(node)
})
it('updates subgraph credit badges from registered extension hooks', async () => {
const nodes = [new LGraphNode('one'), new LGraphNode('two')]
appState.graph.nodes = nodes
mountedApp = mountBadge()
await nextTick()
await registeredExtension().init?.(comfyApp())
await registeredExtension().afterConfigureGraph?.([], comfyApp())
const setGraphHandler = addEventListenerMock.mock.calls.find(
([event]) => event === 'litegraph:set-graph'
)?.[1]
const convertedHandler = addEventListenerMock.mock.calls.find(
([event]) => event === 'subgraph-converted'
)?.[1]
setGraphHandler?.()
convertedHandler?.({ detail: { subgraphNode: nodes[0] } })
expect(updateSubgraphCreditsMock).toHaveBeenCalledWith(nodes[0])
expect(updateSubgraphCreditsMock).toHaveBeenCalledWith(nodes[1])
})
})

View File

@@ -503,7 +503,7 @@ export function useCoreCommands(): ComfyCommand[] {
}) => {
trackRunButton(metadata)
if (!isActiveSubscription.value) {
showSubscriptionDialog()
showSubscriptionDialog({ reason: 'subscribe_to_run' })
return
}
@@ -526,7 +526,7 @@ export function useCoreCommands(): ComfyCommand[] {
}) => {
trackRunButton(metadata)
if (!isActiveSubscription.value) {
showSubscriptionDialog()
showSubscriptionDialog({ reason: 'subscribe_to_run' })
return
}
@@ -548,7 +548,7 @@ export function useCoreCommands(): ComfyCommand[] {
}) => {
trackRunButton(metadata)
if (!isActiveSubscription.value) {
showSubscriptionDialog()
showSubscriptionDialog({ reason: 'subscribe_to_run' })
return
}

View File

@@ -1,102 +0,0 @@
import { ref } from 'vue'
import { describe, expect, it } from 'vitest'
import { useTreeExpansion } from '@/composables/useTreeExpansion'
import type { TreeNode } from '@/types/treeExplorerTypes'
function node(over: Partial<TreeNode>): TreeNode {
return over as TreeNode
}
// root ─┬─ a ── a1 (leaf)
// └─ b (leaf)
function sampleTree() {
const a1 = node({ key: 'a1', leaf: true })
const a = node({ key: 'a', leaf: false, children: [a1] })
const b = node({ key: 'b', leaf: true })
const root = node({ key: 'root', leaf: false, children: [a, b] })
return { root, a, a1, b }
}
describe('useTreeExpansion', () => {
it('toggleNode adds then removes a node key', () => {
const expandedKeys = ref<Record<string, boolean>>({})
const { toggleNode } = useTreeExpansion(expandedKeys)
const n = node({ key: 'x' })
toggleNode(n)
expect(expandedKeys.value).toEqual({ x: true })
toggleNode(n)
expect(expandedKeys.value).toEqual({})
})
it('toggleNode ignores nodes without a string key', () => {
const expandedKeys = ref<Record<string, boolean>>({})
const { toggleNode } = useTreeExpansion(expandedKeys)
toggleNode(node({ key: undefined }))
toggleNode(node({ key: 42 as unknown as string }))
expect(expandedKeys.value).toEqual({})
})
it('expandNode expands the node and all non-leaf descendants only', () => {
const expandedKeys = ref<Record<string, boolean>>({})
const { expandNode } = useTreeExpansion(expandedKeys)
const { root } = sampleTree()
expandNode(root)
// root and a are folders; a1 and b are leaves and must be skipped
expect(expandedKeys.value).toEqual({ root: true, a: true })
})
it('expandNode does nothing for a leaf node', () => {
const expandedKeys = ref<Record<string, boolean>>({})
const { expandNode } = useTreeExpansion(expandedKeys)
expandNode(node({ key: 'leaf', leaf: true }))
expect(expandedKeys.value).toEqual({})
})
it('collapseNode removes the node and its non-leaf descendants', () => {
const expandedKeys = ref<Record<string, boolean>>({
root: true,
a: true,
stray: true
})
const { collapseNode } = useTreeExpansion(expandedKeys)
const { root } = sampleTree()
collapseNode(root)
expect(expandedKeys.value).toEqual({ stray: true })
})
it('toggleNodeRecursive expands when collapsed and collapses when expanded', () => {
const expandedKeys = ref<Record<string, boolean>>({})
const { toggleNodeRecursive } = useTreeExpansion(expandedKeys)
const { root } = sampleTree()
toggleNodeRecursive(root)
expect(expandedKeys.value).toEqual({ root: true, a: true })
toggleNodeRecursive(root)
expect(expandedKeys.value).toEqual({})
})
it('toggleNodeOnEvent toggles recursively with ctrl and singly without', () => {
const expandedKeys = ref<Record<string, boolean>>({})
const { toggleNodeOnEvent } = useTreeExpansion(expandedKeys)
const { root } = sampleTree()
toggleNodeOnEvent(new KeyboardEvent('keydown', { ctrlKey: true }), root)
expect(expandedKeys.value).toEqual({ root: true, a: true })
// Plain toggle removes only the node's own key, leaving descendants
toggleNodeOnEvent(new MouseEvent('click'), root)
expect(expandedKeys.value).toEqual({ a: true })
})
})

View File

@@ -0,0 +1,105 @@
import { createPinia, setActivePinia } from 'pinia'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { useCanvasStore } from '@/renderer/core/canvas/canvasStore'
import { useExecutionErrorStore } from '@/stores/executionErrorStore'
import { useRightSidePanelStore } from '@/stores/workspace/rightSidePanelStore'
import { LGraph, LGraphCanvas, LGraphNode } from '@/lib/litegraph/src/litegraph'
import { createMockCanvasRenderingContext2D } from '@/utils/__tests__/litegraphTestUtils'
import { useWorkflowStore } from '@/platform/workflow/management/stores/workflowStore'
import { useViewErrorsInGraph } from './useViewErrorsInGraph'
const apiMock = vi.hoisted(() => ({
getSettings: vi.fn(),
storeSetting: vi.fn(),
storeSettings: vi.fn()
}))
vi.mock('@/scripts/api', () => ({
api: apiMock
}))
const appMock = vi.hoisted(() => ({
ui: {
settings: {
dispatchChange: vi.fn()
}
},
rootGraph: {
events: new EventTarget(),
nodes: []
}
}))
vi.mock('@/scripts/app', () => ({
app: appMock
}))
function createSelectedCanvas() {
const graph = new LGraph()
const canvasElement = document.createElement('canvas')
canvasElement.width = 800
canvasElement.height = 600
canvasElement.getContext = vi
.fn()
.mockReturnValue(createMockCanvasRenderingContext2D())
const canvas = new LGraphCanvas(canvasElement, graph, {
skip_events: true,
skip_render: true
})
const node = new LGraphNode('Selected Node')
graph.add(node)
canvas.selectedItems.add(node)
node.selected = true
return { canvas, node }
}
describe('useViewErrorsInGraph', () => {
beforeEach(() => {
vi.clearAllMocks()
setActivePinia(createPinia())
apiMock.getSettings.mockResolvedValue({})
apiMock.storeSetting.mockResolvedValue(undefined)
apiMock.storeSettings.mockResolvedValue(undefined)
})
it('opens graph errors and clears app-mode error UI state', () => {
const canvasStore = useCanvasStore()
const executionErrorStore = useExecutionErrorStore()
const rightSidePanelStore = useRightSidePanelStore()
const workflowStore = useWorkflowStore()
const { canvas, node } = createSelectedCanvas()
workflowStore.activeWorkflow = {
activeMode: 'app'
} as typeof workflowStore.activeWorkflow
canvasStore.canvas = canvas
canvasStore.selectedItems = [node]
executionErrorStore.showErrorOverlay()
useViewErrorsInGraph().viewErrorsInGraph()
expect(node.selected).toBe(false)
expect(canvasStore.linearMode).toBe(false)
expect(canvasStore.selectedItems).toEqual([])
expect(rightSidePanelStore.activeTab).toBe('errors')
expect(rightSidePanelStore.isOpen).toBe(true)
expect(executionErrorStore.isErrorOverlayOpen).toBe(false)
})
it('opens graph errors when the canvas is not initialized', () => {
const canvasStore = useCanvasStore()
const executionErrorStore = useExecutionErrorStore()
const rightSidePanelStore = useRightSidePanelStore()
canvasStore.canvas = null
executionErrorStore.showErrorOverlay()
expect(() => useViewErrorsInGraph().viewErrorsInGraph()).not.toThrow()
expect(rightSidePanelStore.activeTab).toBe('errors')
expect(rightSidePanelStore.isOpen).toBe(true)
expect(executionErrorStore.isErrorOverlayOpen).toBe(false)
})
})

View File

@@ -0,0 +1,22 @@
import { useCanvasStore } from '@/renderer/core/canvas/canvasStore'
import { useExecutionErrorStore } from '@/stores/executionErrorStore'
import { useRightSidePanelStore } from '@/stores/workspace/rightSidePanelStore'
export function useViewErrorsInGraph() {
const canvasStore = useCanvasStore()
const executionErrorStore = useExecutionErrorStore()
const rightSidePanelStore = useRightSidePanelStore()
function viewErrorsInGraph() {
canvasStore.linearMode = false
if (canvasStore.canvas) {
canvasStore.canvas.deselectAll()
canvasStore.updateSelectedItems()
}
rightSidePanelStore.openPanel('errors')
executionErrorStore.dismissErrorOverlay()
}
return { viewErrorsInGraph }
}

View File

@@ -1,47 +1,14 @@
import { createTestingPinia } from '@pinia/testing'
import { setActivePinia } from 'pinia'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { describe, expect, it, vi } from 'vitest'
import type { SerialisedLLinkArray } from '@/lib/litegraph/src/LLink'
import { LGraphNode, LiteGraph } from '@/lib/litegraph/src/litegraph'
import type { ComfyNode } from '@/platform/workflow/validation/schemas/workflowSchema'
import type { ComfyNodeDef } from '@/schemas/nodeDefSchema'
import type { ComfyApp } from '@/scripts/app'
import { useNodeDefStore } from '@/stores/nodeDefStore'
import type { ComfyExtension } from '@/types/comfy'
import type { GroupNodeWorkflowData } from './groupNode'
const appMock = vi.hoisted(() => ({
canvas: {
emitAfterChange: vi.fn(),
emitBeforeChange: vi.fn(),
selected_nodes: {}
},
registerExtension: vi.fn(),
registerNodeDef: vi.fn(),
rootGraph: {
convertToSubgraph: vi.fn(),
extra: {},
getNodeById: vi.fn(),
links: {},
nodes: [],
remove: vi.fn()
}
}))
const widgetStoreMock = vi.hoisted(() => ({
inputIsWidget: vi.fn((spec: unknown[]) =>
['BOOLEAN', 'COMBO', 'FLOAT', 'INT', 'STRING'].includes(String(spec[0]))
)
}))
vi.mock('@/scripts/app', () => ({
app: appMock
}))
vi.mock('@/stores/widgetStore', () => ({
useWidgetStore: () => widgetStoreMock
app: {
registerExtension: vi.fn()
}
}))
import { GroupNodeConfig, replaceLegacySeparators } from './groupNode'
@@ -59,42 +26,6 @@ function makeNode(type: string): ComfyNode {
}
}
function makeNodeDef(overrides: Partial<ComfyNodeDef> = {}): ComfyNodeDef {
return {
name: 'TestNode',
display_name: 'Test Node',
description: '',
category: 'test',
input: { required: {}, optional: {} },
output: [],
output_name: [],
output_is_list: [],
output_node: false,
python_module: 'test',
...overrides
} as ComfyNodeDef
}
function extension(): ComfyExtension {
const groupExtension = appMock.registerExtension.mock.calls.find(
([registered]) => registered.name === 'Comfy.GroupNode'
)?.[0]
if (!groupExtension) throw new Error('GroupNode extension was not registered')
return groupExtension as ComfyExtension
}
function addCustomNodeDefs(defs: Record<string, ComfyNodeDef>) {
extension().addCustomNodeDefs?.(defs, appMock as unknown as ComfyApp)
}
beforeEach(() => {
setActivePinia(createTestingPinia({ stubActions: false }))
appMock.registerNodeDef.mockReset()
widgetStoreMock.inputIsWidget.mockClear()
LiteGraph.registered_node_types = {}
addCustomNodeDefs({})
})
describe('replaceLegacySeparators', () => {
it('rewrites the legacy "workflow/" prefix to "workflow>"', () => {
const nodes = [makeNode('workflow/My Group')]
@@ -173,398 +104,4 @@ describe('GroupNodeConfig.getLinks', () => {
const config = configFrom([], [[0, 1, 'IMAGE']])
expect(config.externalFrom[0][1]).toBe('IMAGE')
})
it('ignores external links without a type and accumulates multiple slots', () => {
const config = configFrom(
[],
[
[0, 1, null as unknown as string],
[0, 2, 'LATENT'],
[0, 3, 'IMAGE']
]
)
expect(config.externalFrom[0]).toEqual({ 2: 'LATENT', 3: 'IMAGE' })
})
})
describe('GroupNodeConfig.getNodeDef', () => {
const imageNodeDef = makeNodeDef({
name: 'ImageNode',
input: {
required: {
image: ['IMAGE', {}],
mode: [['fast', 'slow'], {}]
},
optional: {
strength: ['FLOAT', { default: 1 }]
}
},
output: ['IMAGE'],
output_name: ['image'],
output_is_list: [false]
})
beforeEach(() => {
addCustomNodeDefs({ ImageNode: imageNodeDef })
})
it('returns registered definitions for normal node types', () => {
const config = new GroupNodeConfig('group', {
nodes: [{ index: 0, type: 'ImageNode' }],
links: [],
external: []
})
expect(config.getNodeDef({ index: 0, type: 'ImageNode' })).toBe(
imageNodeDef
)
})
it('returns undefined for nodes without an index or a known type', () => {
const config = new GroupNodeConfig('group', {
nodes: [{ type: 'UnknownNode' }],
links: [],
external: []
})
expect(config.getNodeDef({ type: 'UnknownNode' })).toBeUndefined()
})
it('skips unlinked primitive nodes', () => {
const config = new GroupNodeConfig('group', {
nodes: [{ index: 0, type: 'PrimitiveNode' }],
links: [],
external: []
})
expect(
config.getNodeDef({ index: 0, type: 'PrimitiveNode' })
).toBeUndefined()
})
it('derives primitive node type from the outgoing link type', () => {
const config = new GroupNodeConfig('group', {
nodes: [
{ index: 0, type: 'PrimitiveNode' },
{ index: 1, type: 'ImageNode' }
],
links: [[0, 0, 1, 0, 1, 'IMAGE'] as SerialisedLLinkArray],
external: []
})
expect(
config.getNodeDef({ index: 0, type: 'PrimitiveNode' })
).toMatchObject({
input: { required: { value: ['IMAGE', {}] } },
output: ['IMAGE']
})
})
it('falls back to null when primitive combo target spec is not primitive', () => {
const config = new GroupNodeConfig('group', {
nodes: [
{
index: 0,
type: 'PrimitiveNode',
outputs: [{ name: 'mode', widget: { name: 'mode' } }]
},
{ index: 1, type: 'ImageNode' }
],
links: [[0, 0, 1, 0, 1, 'COMBO'] as SerialisedLLinkArray],
external: []
})
expect(config.getNodeDef(config.nodeData.nodes[0])).toMatchObject({
input: { required: { value: [null, {}] } },
output: [null]
})
})
it('returns null for reroutes used only inside the group', () => {
const config = new GroupNodeConfig('group', {
nodes: [
{ index: 0, type: 'ImageNode' },
{ index: 1, type: 'Reroute' },
{ index: 2, type: 'ImageNode' }
],
links: [
[0, 0, 1, 0, 1, 'IMAGE'],
[1, 0, 2, 0, 2, 'IMAGE']
] as SerialisedLLinkArray[],
external: []
})
expect(config.getNodeDef({ index: 1, type: 'Reroute' })).toBeNull()
})
it('derives reroute type from outgoing target inputs', () => {
const config = new GroupNodeConfig('group', {
nodes: [
{ index: 0, type: 'Reroute' },
{
index: 1,
type: 'ImageNode',
inputs: [{ name: 'image', type: 'IMAGE' }]
}
],
links: [[0, 0, 1, 0, 1, 'IMAGE'] as SerialisedLLinkArray],
external: [[0, 0, 'IMAGE']]
})
expect(config.getNodeDef({ index: 0, type: 'Reroute' })).toMatchObject({
input: { required: { IMAGE: ['IMAGE', { forceInput: true }] } },
output: ['IMAGE']
})
})
it('derives reroute type from incoming output metadata', () => {
const config = new GroupNodeConfig('group', {
nodes: [
{ index: 0, type: 'ImageNode', outputs: [{ type: 'LATENT' }] },
{ index: 1, type: 'Reroute' }
],
links: [[0, 0, 1, 0, 1, 'LATENT'] as SerialisedLLinkArray],
external: [[1, 0, 'LATENT']]
})
expect(config.getNodeDef({ index: 1, type: 'Reroute' })).toMatchObject({
input: { required: { LATENT: ['LATENT', { forceInput: true }] } },
output: ['LATENT']
})
})
it('derives pipe reroute type from external metadata when links omit it', () => {
const config = new GroupNodeConfig('group', {
nodes: [{ index: 0, type: 'Reroute' }],
links: [],
external: [[0, 0, 'MASK']]
})
expect(config.getNodeDef({ index: 0, type: 'Reroute' })).toMatchObject({
input: { required: { MASK: ['MASK', { forceInput: true }] } },
output: ['MASK']
})
})
})
describe('GroupNodeConfig input and output mapping', () => {
function configWithNode(node: GroupNodeWorkflowData['nodes'][number]) {
const config = new GroupNodeConfig('group', {
nodes: [node],
links: [],
external: [],
config: {
0: {
input: {
hidden: { visible: false },
renamed: { name: 'Custom Name' }
},
output: {
1: { name: 'Custom Output' },
2: { visible: false }
}
}
}
})
config.nodeDef = makeNodeDef({
input: { required: {} },
output: [],
output_name: [],
output_is_list: []
})
return config
}
it('renames duplicate inputs and adds seed control metadata', () => {
const config = configWithNode({
index: 0,
type: 'Sampler',
title: 'Sampler A',
inputs: [{ name: 'seed', label: 'Seed Label' }]
})
const seenInputs = { seed: 1, 'Sampler A seed': 1 }
const result = config.getInputConfig(
{ index: 0, type: 'Sampler', title: 'Sampler A' },
'seed',
seenInputs,
['INT', {}]
)
expect(result.name).toBe('Sampler A 1 seed')
expect(result.config).toEqual([
'INT',
{ control_after_generate: 'Sampler A control_after_generate' }
])
})
it('maps image upload widget aliases through converted widget names', () => {
const config = configWithNode({ index: 0, type: 'LoadImage' })
config.oldToNewWidgetMap[0] = { customImage: 'Uploaded Image' }
expect(
config.getInputConfig({ index: 0, type: 'LoadImage' }, 'renamed', {}, [
'IMAGEUPLOAD',
{ widget: 'customImage' }
])
).toMatchObject({
name: 'Custom Name',
config: ['IMAGEUPLOAD', { widget: 'Uploaded Image' }]
})
})
it('splits widget inputs, socket inputs, and converted widget slots', () => {
const config = configWithNode({
index: 0,
type: 'MixedNode',
inputs: [{ name: 'mode', widget: { name: 'mode' } }]
})
const result = config.processWidgetInputs(
{
mode: ['COMBO', {}],
image: ['IMAGE', {}]
},
{
index: 0,
type: 'MixedNode',
inputs: [{ name: 'mode', widget: { name: 'mode' } }]
},
['mode', 'image'],
{}
)
expect(result.slots).toEqual(['image'])
expect(result.converted.get(0)).toBe('mode')
expect(config.oldToNewWidgetMap[0].mode).toBeNull()
})
it('adds visible unlinked input slots and skips hidden configured inputs', () => {
const config = configWithNode({
index: 0,
type: 'InputNode'
})
const inputMap: Record<number, number> = {}
config.processInputSlots(
{
image: ['IMAGE', {}],
hidden: ['LATENT', {}]
},
{ index: 0, type: 'InputNode' },
['image', 'hidden'],
{},
inputMap,
{}
)
expect(config.nodeDef?.input?.required).toEqual({ image: ['IMAGE', {}] })
expect(inputMap).toEqual({ 0: 0 })
})
it('adds output metadata, hides linked/internal outputs, and dedupes labels', () => {
const config = configWithNode({
index: 0,
type: 'OutputNode',
title: 'Output A',
outputs: [{ name: 'image', label: 'Rendered' }]
})
config.linksFrom[0] = {
0: [[0, 0, 1, 0, 1, 'IMAGE'] as SerialisedLLinkArray]
}
config.processNodeOutputs(
{ index: 0, type: 'OutputNode', title: 'Output A' },
{ Rendered: 1 },
{
input: { required: {} },
output: ['IMAGE', 'LATENT', 'MASK'],
output_name: ['image', 'latent', 'mask'],
output_is_list: [false, true, false]
}
)
expect(config.outputVisibility).toEqual([false, true, false])
expect(config.nodeDef?.output).toEqual(['LATENT'])
expect(config.nodeDef?.output_is_list).toEqual([true])
expect(config.nodeDef?.output_name).toEqual(['Custom Output'])
})
})
describe('GroupNodeConfig.registerFromWorkflow', () => {
it('adds missing type actions and skips registration for incomplete groups', async () => {
const groupNodes: Record<string, GroupNodeWorkflowData> = {
Broken: {
nodes: [{ index: 0, type: 'MissingNode' }],
links: [],
external: []
}
}
const missingNodeTypes: Parameters<
typeof GroupNodeConfig.registerFromWorkflow
>[1] = []
await GroupNodeConfig.registerFromWorkflow(groupNodes, missingNodeTypes)
expect(appMock.registerNodeDef).not.toHaveBeenCalled()
expect(missingNodeTypes).toHaveLength(2)
expect(missingNodeTypes[0]).toMatchObject({
type: 'MissingNode',
hint: " (In group node 'workflow>Broken')"
})
const action = missingNodeTypes[1]
if (typeof action === 'string') {
throw new Error('Expected a missing-node action entry, not a string')
}
const target = document.createElement('button')
const { callback } = action.action as {
callback: (event: MouseEvent) => void
}
const event = new MouseEvent('click')
Object.defineProperty(event, 'target', { value: target })
callback(event)
expect(groupNodes.Broken).toBeUndefined()
expect(target.textContent).toBe('Removed')
expect(target.style.pointerEvents).toBe('none')
})
it('registers complete group node types and stores their generated node defs', async () => {
addCustomNodeDefs({
ImageNode: makeNodeDef({
name: 'ImageNode',
input: { required: { image: ['IMAGE', {}] } },
output: ['IMAGE'],
output_name: ['image'],
output_is_list: [false]
})
})
LiteGraph.registered_node_types.ImageNode = class extends LGraphNode {}
await GroupNodeConfig.registerFromWorkflow(
{
Complete: {
nodes: [{ index: 0, type: 'ImageNode' }],
links: [],
external: [[0, 0, 'IMAGE']]
}
},
[]
)
expect(appMock.registerNodeDef).toHaveBeenCalledWith(
'workflow>Complete',
expect.objectContaining({
category: 'group nodes>workflow',
display_name: 'Complete',
name: 'workflow>Complete'
})
)
expect(useNodeDefStore().nodeDefsByName['workflow>Complete']).toEqual(
expect.objectContaining({
category: 'group nodes>workflow',
display_name: 'Complete',
name: 'workflow>Complete'
})
)
})
})

View File

@@ -25,6 +25,6 @@ function handleClose() {
}
function handleSubscribe() {
showSubscriptionDialog()
showSubscriptionDialog({ reason: 'upload_model_upgrade' })
}
</script>

View File

@@ -140,7 +140,10 @@ describe('CloudSubscriptionRedirectView', () => {
expect(mockPerformSubscriptionCheckout).toHaveBeenCalledWith(
'creator',
'monthly',
false
{
openInNewTab: false,
paymentIntentSource: 'deep_link'
}
)
// Shows loading affordances
@@ -169,7 +172,10 @@ describe('CloudSubscriptionRedirectView', () => {
expect(mockPerformSubscriptionCheckout).toHaveBeenCalledWith(
'creator',
'monthly',
false
{
openInNewTab: false,
paymentIntentSource: 'deep_link'
}
)
})
@@ -180,7 +186,8 @@ describe('CloudSubscriptionRedirectView', () => {
expect(screen.getByText('Subscribe to Team Plan')).toBeInTheDocument()
expect(mockPerformTeamSubscriptionCheckout).toHaveBeenCalledWith(
'team_700',
'yearly'
'yearly',
{ paymentIntentSource: 'deep_link' }
)
// Team never goes through the personal checkout path
expect(mockPerformSubscriptionCheckout).not.toHaveBeenCalled()

View File

@@ -94,7 +94,9 @@ const runRedirect = wrapWithErrorHandlingAsync(async () => {
return
}
isTeamCheckout.value = true
await performTeamSubscriptionCheckout(stopId, billingCycle)
await performTeamSubscriptionCheckout(stopId, billingCycle, {
paymentIntentSource: 'deep_link'
})
return
}
@@ -112,7 +114,10 @@ const runRedirect = wrapWithErrorHandlingAsync(async () => {
if (isActiveSubscription.value) {
await accessBillingPortal(undefined, false)
} else {
await performSubscriptionCheckout(tierKeyParam, billingCycle, false)
await performSubscriptionCheckout(tierKeyParam, billingCycle, {
openInNewTab: false,
paymentIntentSource: 'deep_link'
})
}
}, reportError)

View File

@@ -351,12 +351,12 @@ const handleRefresh = wrapWithErrorHandlingAsync(async () => {
})
function handleAddCredits() {
telemetry?.trackAddApiCreditButtonClicked()
telemetry?.trackAddApiCreditButtonClicked({ source: 'credits_panel' })
void dialogService.showTopUpCreditsDialog()
}
function handleUpgradeToAddCredits() {
showPricingTable()
showPricingTable({ reason: 'upgrade_to_add_credits' })
}
async function handleWindowFocus() {

View File

@@ -5,6 +5,8 @@ import { render, screen } from '@testing-library/vue'
import enMessages from '@/locales/en/main.json' with { type: 'json' }
import type { PaymentIntentSource } from '@/platform/telemetry/types'
import FreeTierDialogContent from './FreeTierDialogContent.vue'
const mockRenewalDate = vi.hoisted(() => ({ value: null as string | null }))
@@ -15,7 +17,7 @@ vi.mock('@/composables/billing/useBillingContext', () => ({
}))
}))
function renderComponent() {
function renderComponent(props?: { reason?: PaymentIntentSource }) {
const i18n = createI18n({
legacy: false,
locale: 'en',
@@ -23,6 +25,7 @@ function renderComponent() {
})
return render(FreeTierDialogContent, {
props,
global: {
plugins: [i18n]
}
@@ -43,4 +46,18 @@ describe('FreeTierDialogContent', () => {
renderComponent()
expect(screen.queryByText(/credits refresh on/)).not.toBeInTheDocument()
})
it('keeps the generic copy for intent reasons outside the credits variants', () => {
mockRenewalDate.value = '2026-07-15T10:00:00Z'
renderComponent({ reason: 'subscribe_to_run' })
expect(
screen.getByText('Your credits refresh on Jul 15, 2026.')
).toBeInTheDocument()
})
it('swaps to the out-of-credits copy without the refresh line', () => {
mockRenewalDate.value = '2026-07-15T10:00:00Z'
renderComponent({ reason: 'out_of_credits' })
expect(screen.queryByText(/credits refresh on/)).not.toBeInTheDocument()
})
})

View File

@@ -52,7 +52,7 @@
</p>
<p
v-if="!reason || reason === 'subscription_required'"
v-if="!isCreditsBlockedVariant"
class="m-0 text-sm text-text-secondary"
>
{{
@@ -65,10 +65,7 @@
</p>
<p
v-if="
(!reason || reason === 'subscription_required') &&
formattedRenewalDate
"
v-if="!isCreditsBlockedVariant && formattedRenewalDate"
class="m-0 text-sm text-text-secondary"
>
{{
@@ -88,7 +85,7 @@
@click="$emit('upgrade')"
>
{{
reason === 'out_of_credits' || reason === 'top_up_blocked'
isCreditsBlockedVariant
? $t('subscription.freeTier.upgradeCta')
: $t('subscription.freeTier.subscribeCta')
}}
@@ -103,12 +100,12 @@ import { computed } from 'vue'
import Button from '@/components/ui/button/Button.vue'
import { useBillingContext } from '@/composables/billing/useBillingContext'
import type { SubscriptionDialogReason } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
import type { PaymentIntentSource } from '@/platform/telemetry/types'
import SubscriptionBenefits from '@/platform/cloud/subscription/components/SubscriptionBenefits.vue'
import { getTierCredits } from '@/platform/cloud/subscription/constants/tierPricing'
defineProps<{
reason?: SubscriptionDialogReason
const { reason } = defineProps<{
reason?: PaymentIntentSource
}>()
defineEmits<{
@@ -129,4 +126,10 @@ const formattedRenewalDate = computed(() => {
})
const freeTierCredits = computed(() => getTierCredits('free'))
// Only these two variants replace the generic free-tier copy; any other
// intent reason (subscribe_to_run, deep_link, ...) keeps the default pitch.
const isCreditsBlockedVariant = computed(
() => reason === 'out_of_credits' || reason === 'top_up_blocked'
)
</script>

View File

@@ -261,6 +261,7 @@ describe('PricingTable', () => {
tier: 'creator',
cycle: 'yearly',
checkout_type: 'change',
checkout_attempt_id: expect.any(String),
previous_tier: 'standard'
})
expect(mockAccessBillingPortal).toHaveBeenCalledWith('creator-yearly')
@@ -341,6 +342,7 @@ describe('PricingTable', () => {
expect(
window.localStorage.getItem(PENDING_SUBSCRIPTION_CHECKOUT_STORAGE_KEY)
).toBeNull()
expect(mockTrackBeginCheckout).not.toHaveBeenCalled()
})
it('should use the latest userId value when it changes after mount', async () => {
@@ -366,6 +368,7 @@ describe('PricingTable', () => {
tier: 'creator',
cycle: 'yearly',
checkout_type: 'change',
checkout_attempt_id: expect.any(String),
previous_tier: 'standard'
})
})

View File

@@ -277,13 +277,19 @@ import type {
TierKey,
TierPricing
} from '@/platform/cloud/subscription/constants/tierPricing'
import { recordPendingSubscriptionCheckoutAttempt } from '@/platform/cloud/subscription/utils/subscriptionCheckoutTracker'
import {
recordPendingSubscriptionCheckoutAttempt,
withPendingCheckoutAttemptId
} from '@/platform/cloud/subscription/utils/subscriptionCheckoutTracker'
import { performSubscriptionCheckout } from '@/platform/cloud/subscription/utils/subscriptionCheckoutUtil'
import { isPlanDowngrade } from '@/platform/cloud/subscription/utils/subscriptionTierRank'
import type { BillingCycle } from '@/platform/cloud/subscription/utils/subscriptionTierRank'
import { isCloud } from '@/platform/distribution/types'
import { useTelemetry } from '@/platform/telemetry'
import type { CheckoutAttributionMetadata } from '@/platform/telemetry/types'
import type {
CheckoutAttributionMetadata,
PaymentIntentSource
} from '@/platform/telemetry/types'
import { useAuthStore } from '@/stores/authStore'
type CheckoutTierKey = Exclude<TierKey, 'free' | 'founder'>
@@ -321,6 +327,10 @@ interface PricingTierConfig {
isPopular?: boolean
}
const { reason } = defineProps<{
reason?: PaymentIntentSource
}>()
const emit = defineEmits<{
chooseTeamWorkspace: []
}>()
@@ -463,16 +473,17 @@ const handleSubscribe = wrapWithErrorHandlingAsync(
} as const
const previousPlan = currentPlanDescriptor.value
const checkoutAttribution = await getCheckoutAttributionForCloud()
if (userId.value) {
telemetry?.trackBeginCheckout({
user_id: userId.value,
tier: targetPlan.tierKey,
cycle: targetPlan.billingCycle,
checkout_type: 'change',
...checkoutAttribution,
...(previousPlan ? { previous_tier: previousPlan.tierKey } : {})
})
}
const beginCheckoutMetadata = userId.value
? {
user_id: userId.value,
tier: targetPlan.tierKey,
cycle: targetPlan.billingCycle,
checkout_type: 'change' as const,
...(reason ? { payment_intent_source: reason } : {}),
...checkoutAttribution,
...(previousPlan ? { previous_tier: previousPlan.tierKey } : {})
}
: null
// Pass the target tier to create a deep link to subscription update confirmation
const checkoutTier = getCheckoutTier(
targetPlan.tierKey,
@@ -487,29 +498,39 @@ const handleSubscribe = wrapWithErrorHandlingAsync(
if (downgrade) {
// TODO(COMFY-StripeProration): Remove once backend checkout creation mirrors portal proration ("change at billing end")
await accessBillingPortal()
const didOpenPortal = await accessBillingPortal()
if (didOpenPortal && beginCheckoutMetadata) {
telemetry?.trackBeginCheckout(beginCheckoutMetadata)
}
} else {
const didOpenPortal = await accessBillingPortal(checkoutTier)
if (!didOpenPortal) {
return
}
recordPendingSubscriptionCheckoutAttempt({
const pendingAttempt = recordPendingSubscriptionCheckoutAttempt({
tier: targetPlan.tierKey,
cycle: targetPlan.billingCycle,
checkout_type: 'change',
payment_intent_source: reason,
...(previousPlan ? { previous_tier: previousPlan.tierKey } : {}),
...(previousPlan
? { previous_cycle: previousPlan.billingCycle }
: {})
})
if (beginCheckoutMetadata) {
telemetry?.trackBeginCheckout(
withPendingCheckoutAttemptId(
beginCheckoutMetadata,
pendingAttempt
)
)
}
}
} else {
await performSubscriptionCheckout(
tierKey,
currentBillingCycle.value,
true
)
await performSubscriptionCheckout(tierKey, currentBillingCycle.value, {
paymentIntentSource: reason
})
}
} finally {
isLoading.value = false

View File

@@ -56,7 +56,7 @@ const handleSubscribe = () => {
current_tier: tier.value?.toLowerCase()
})
isAwaitingStripeSubscription.value = true
showSubscriptionDialog()
showSubscriptionDialog({ reason: 'subscribe_now_button' })
}
onBeforeUnmount(() => {

View File

@@ -54,6 +54,6 @@ function handleSubscribeToRun() {
trackRunButton({ subscribe_to_run: true })
}
showSubscriptionDialog()
showSubscriptionDialog({ reason: 'subscribe_to_run' })
}
</script>

View File

@@ -48,7 +48,9 @@
v-if="isActiveSubscription"
variant="primary"
class="rounded-lg px-4 py-2 text-sm font-normal text-text-primary"
@click="showSubscriptionDialog"
@click="
showSubscriptionDialog({ reason: 'settings_billing_panel' })
"
>
{{ $t('subscription.upgradePlan') }}
</Button>

View File

@@ -33,7 +33,11 @@
</i18n-t>
</div>
<PricingTable class="flex-1" @choose-team-workspace="handleChooseTeam" />
<PricingTable
:reason
class="flex-1"
@choose-team-workspace="handleChooseTeam"
/>
<!-- Contact and Enterprise Links -->
<div class="flex flex-col items-center gap-2">
@@ -157,11 +161,11 @@ import { useBillingContext } from '@/composables/billing/useBillingContext'
import { isCloud } from '@/platform/distribution/types'
import { useTelemetry } from '@/platform/telemetry'
import { useCommandStore } from '@/stores/commandStore'
import type { SubscriptionDialogReason } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
import type { PaymentIntentSource } from '@/platform/telemetry/types'
const { onClose, reason, onChooseTeam } = defineProps<{
onClose: () => void
reason?: SubscriptionDialogReason
reason?: PaymentIntentSource
onChooseTeam?: () => void
}>()

View File

@@ -24,7 +24,9 @@ export function useAccountPreconditionDialog() {
)
return
case 'subscription':
void dialogService.showSubscriptionRequiredDialog()
void dialogService.showSubscriptionRequiredDialog({
reason: 'subscription_required'
})
return
case 'credits':
void dialogService.showTopUpCreditsDialog({

View File

@@ -55,12 +55,6 @@ vi.mock('@/platform/workspace/stores/teamWorkspaceStore', () => ({
})
}))
const mockTrackSubscription = vi.hoisted(() => vi.fn())
vi.mock('@/platform/telemetry', () => ({
useTelemetry: () => ({ trackSubscription: mockTrackSubscription })
}))
describe('usePricingTableUrlLoader', () => {
beforeEach(() => {
vi.clearAllMocks()
@@ -96,9 +90,6 @@ describe('usePricingTableUrlLoader', () => {
reason: 'deep_link',
planMode: undefined
})
expect(mockTrackSubscription).toHaveBeenCalledWith('modal_opened', {
reason: 'deep_link'
})
expect(mockRouterReplace).toHaveBeenCalledWith({ query: {} })
})
@@ -150,7 +141,6 @@ describe('usePricingTableUrlLoader', () => {
await loadPricingTableFromUrl()
expect(mockShowPricingTable).not.toHaveBeenCalled()
expect(mockTrackSubscription).not.toHaveBeenCalled()
})
it('denies, strips, and clears together when the user is not eligible', async () => {
@@ -161,7 +151,6 @@ describe('usePricingTableUrlLoader', () => {
await loadPricingTableFromUrl()
expect(mockShowPricingTable).not.toHaveBeenCalled()
expect(mockTrackSubscription).not.toHaveBeenCalled()
expect(mockRouterReplace).toHaveBeenCalledWith({
query: { other: 'param' }
})
@@ -230,7 +219,6 @@ describe('usePricingTableUrlLoader', () => {
)
expect(mockShowPricingTable).not.toHaveBeenCalled()
expect(mockTrackSubscription).not.toHaveBeenCalled()
expect(mockRouterReplace).toHaveBeenCalledWith({ query: {} })
expect(preservedQueryMocks.clearPreservedQuery).toHaveBeenCalledWith(
'pricing'

View File

@@ -7,7 +7,6 @@ import {
mergePreservedQueryIntoQuery
} from '@/platform/navigation/preservedQueryManager'
import { PRESERVED_QUERY_NAMESPACES } from '@/platform/navigation/preservedQueryNamespaces'
import { useTelemetry } from '@/platform/telemetry'
import { useWorkspaceUI } from '@/platform/workspace/composables/useWorkspaceUI'
import { useTeamWorkspaceStore } from '@/platform/workspace/stores/teamWorkspaceStore'
@@ -62,7 +61,6 @@ export function usePricingTableUrlLoader() {
const planMode =
param === 'team' || param === 'personal' ? param : undefined
useTelemetry()?.trackSubscription('modal_opened', { reason: 'deep_link' })
subscriptionDialog.showPricingTable({ reason: 'deep_link', planMode })
}

View File

@@ -15,7 +15,7 @@ import { t } from '@/i18n'
import { fetchWithUnifiedRemint } from '@/platform/auth/unified/remintRetry'
import { isCloud } from '@/platform/distribution/types'
import { useTelemetry } from '@/platform/telemetry'
import type { SubscriptionDialogReason } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
import type { SubscriptionDialogOptions } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
import type { CheckoutAttributionMetadata } from '@/platform/telemetry/types'
import { AuthStoreError, useAuthStore } from '@/stores/authStore'
import { useDialogService } from '@/services/dialogService'
@@ -237,14 +237,7 @@ function useSubscriptionInternal() {
})
}, reportError)
const showSubscriptionDialog = (options?: {
reason?: SubscriptionDialogReason
}) => {
useTelemetry()?.trackSubscription('modal_opened', {
current_tier: subscriptionTier.value?.toLowerCase(),
reason: options?.reason
})
const showSubscriptionDialog = (options?: SubscriptionDialogOptions) => {
void showSubscriptionRequiredDialog(options)
}
@@ -277,7 +270,7 @@ function useSubscriptionInternal() {
await fetchSubscriptionStatus()
if (!isSubscribedOrIsNotCloud.value) {
showSubscriptionDialog()
showSubscriptionDialog({ reason: 'subscription_required' })
}
}

View File

@@ -39,15 +39,23 @@ vi.mock('@/stores/commandStore', () => ({
}))
// useTelemetry() returns null in OSS, a dispatcher in cloud — toggle via mockIsCloud.
const { mockIsCloud, mockTrackHelpResourceClicked } = vi.hoisted(() => ({
const {
mockIsCloud,
mockTrackHelpResourceClicked,
mockTrackAddApiCreditButtonClicked
} = vi.hoisted(() => ({
mockIsCloud: { value: true },
mockTrackHelpResourceClicked: vi.fn()
mockTrackHelpResourceClicked: vi.fn(),
mockTrackAddApiCreditButtonClicked: vi.fn()
}))
vi.mock('@/platform/telemetry', () => ({
useTelemetry: () =>
mockIsCloud.value
? { trackHelpResourceClicked: mockTrackHelpResourceClicked }
? {
trackHelpResourceClicked: mockTrackHelpResourceClicked,
trackAddApiCreditButtonClicked: mockTrackAddApiCreditButtonClicked
}
: null
}))
@@ -69,6 +77,9 @@ describe('useSubscriptionActions', () => {
const { handleAddApiCredits } = useSubscriptionActions()
handleAddApiCredits()
expect(mockShowTopUpCreditsDialog).toHaveBeenCalledOnce()
expect(mockTrackAddApiCreditButtonClicked).toHaveBeenCalledWith({
source: 'settings_billing_panel'
})
})
})

View File

@@ -21,6 +21,9 @@ export function useSubscriptionActions() {
})
const handleAddApiCredits = () => {
telemetry?.trackAddApiCreditButtonClicked({
source: 'settings_billing_panel'
})
void dialogService.showTopUpCreditsDialog()
}

View File

@@ -5,8 +5,10 @@ import { useSubscriptionDialog } from './useSubscriptionDialog'
const mockCloseDialog = vi.fn()
const mockShowLayoutDialog = vi.fn()
const mockShowTeamWorkspacesDialog = vi.fn()
const mockTrackSubscription = vi.hoisted(() => vi.fn())
const mockIsInPersonalWorkspace = vi.hoisted(() => ({ value: true }))
const mockIsFreeTier = vi.hoisted(() => ({ value: false }))
const mockTier = vi.hoisted(() => ({ value: 'FREE' as string | null }))
const mockTeamWorkspacesEnabled = vi.hoisted(() => ({ value: false }))
const mockIsCloud = vi.hoisted(() => ({ value: true }))
const mockIsLegacyTeamPlan = vi.hoisted(() => ({ value: false }))
@@ -60,10 +62,15 @@ vi.mock('@/platform/workspace/stores/teamWorkspaceStore', () => ({
vi.mock('@/composables/billing/useBillingContext', () => ({
useBillingContext: () => ({
isFreeTier: mockIsFreeTier,
isLegacyTeamPlan: mockIsLegacyTeamPlan
isLegacyTeamPlan: mockIsLegacyTeamPlan,
tier: mockTier
})
}))
vi.mock('@/platform/telemetry', () => ({
useTelemetry: () => ({ trackSubscription: mockTrackSubscription })
}))
vi.mock('@/platform/workspace/composables/useWorkspaceUI', () => ({
useWorkspaceUI: () => ({
permissions: {
@@ -80,6 +87,7 @@ describe('useSubscriptionDialog', () => {
mockIsCloud.value = true
mockIsInPersonalWorkspace.value = true
mockIsFreeTier.value = false
mockTier.value = 'FREE'
mockTeamWorkspacesEnabled.value = false
mockIsLegacyTeamPlan.value = false
mockCanManageSubscription.value = true
@@ -198,6 +206,51 @@ describe('useSubscriptionDialog', () => {
const props = mockShowLayoutDialog.mock.calls[0][0].props
expect(props.initialPlanMode).toBe('team')
})
it('tracks modal_opened with the caller reason and current tier', () => {
mockTier.value = 'STANDARD'
const { showPricingTable } = useSubscriptionDialog()
showPricingTable({ reason: 'upgrade_to_add_credits' })
expect(mockTrackSubscription).toHaveBeenCalledWith('modal_opened', {
current_tier: 'standard',
reason: 'upgrade_to_add_credits'
})
})
it('tracks modal_opened on the workspace (unified) path too', () => {
mockTeamWorkspacesEnabled.value = true
const { showPricingTable } = useSubscriptionDialog()
showPricingTable({ reason: 'subscribe_to_run' })
expect(mockTrackSubscription).toHaveBeenCalledWith(
'modal_opened',
expect.objectContaining({ reason: 'subscribe_to_run' })
)
})
it('does not track modal_opened for the inactive member dialog', () => {
mockTeamWorkspacesEnabled.value = true
mockIsInPersonalWorkspace.value = false
mockCanManageSubscription.value = false
const { showPricingTable } = useSubscriptionDialog()
showPricingTable({ reason: 'subscribe_to_run' })
expect(mockShowLayoutDialog).toHaveBeenCalledTimes(1)
expect(mockTrackSubscription).not.toHaveBeenCalled()
})
it('does not track on non-cloud', () => {
mockIsCloud.value = false
const { showPricingTable } = useSubscriptionDialog()
showPricingTable({ reason: 'subscribe_to_run' })
expect(mockTrackSubscription).not.toHaveBeenCalled()
})
})
describe('show', () => {
@@ -235,6 +288,20 @@ describe('useSubscriptionDialog', () => {
expect.objectContaining({ key: 'subscription-required' })
)
})
it('tracks modal_opened with the reason for the free-tier dialog', () => {
mockIsFreeTier.value = true
mockIsInPersonalWorkspace.value = true
const { show } = useSubscriptionDialog()
show({ reason: 'out_of_credits' })
expect(mockTrackSubscription).toHaveBeenCalledTimes(1)
expect(mockTrackSubscription).toHaveBeenCalledWith(
'modal_opened',
expect.objectContaining({ reason: 'out_of_credits' })
)
})
})
describe('startTeamWorkspaceUpgradeFlow', () => {

View File

@@ -4,6 +4,8 @@ import { useDialogStore } from '@/stores/dialogStore'
import { useBillingContext } from '@/composables/billing/useBillingContext'
import { useFeatureFlags } from '@/composables/useFeatureFlags'
import { isCloud } from '@/platform/distribution/types'
import { useTelemetry } from '@/platform/telemetry'
import type { PaymentIntentSource } from '@/platform/telemetry/types'
import { useWorkspaceUI } from '@/platform/workspace/composables/useWorkspaceUI'
import { useTeamWorkspaceStore } from '@/platform/workspace/stores/teamWorkspaceStore'
@@ -11,14 +13,8 @@ const DIALOG_KEY = 'subscription-required'
const FREE_TIER_DIALOG_KEY = 'free-tier-info'
const RESUME_PRICING_KEY = 'comfy:resume-team-pricing'
export type SubscriptionDialogReason =
| 'subscription_required'
| 'out_of_credits'
| 'top_up_blocked'
| 'deep_link'
interface SubscriptionDialogOptions {
reason?: SubscriptionDialogReason
export interface SubscriptionDialogOptions {
reason?: PaymentIntentSource
/**
* Forces the unified pricing dialog to open on a specific plan tab,
* overriding the workspace-derived default (e.g. an "Upgrade to Team" CTA
@@ -38,6 +34,17 @@ export const useSubscriptionDialog = () => {
dialogStore.closeDialog({ key: FREE_TIER_DIALOG_KEY })
}
// Fired here — the choke point every paywall/pricing dialog variant passes
// through — so both the legacy and workspace billing paths emit it.
function trackModalOpened(reason?: PaymentIntentSource) {
// Resolved lazily to avoid the useBillingContext import cycle (see below).
const { tier } = useBillingContext()
useTelemetry()?.trackSubscription('modal_opened', {
current_tier: tier.value?.toLowerCase(),
reason
})
}
function showPricingTable(options?: SubscriptionDialogOptions) {
if (!isCloud) return
@@ -71,6 +78,8 @@ export const useSubscriptionDialog = () => {
return
}
trackModalOpened(options?.reason)
// Shared dialog shell styling for both variants.
const dialogComponentProps = {
style: 'width: min(1328px, 95vw); max-height: 958px;',
@@ -167,6 +176,8 @@ export const useSubscriptionDialog = () => {
// (not at composable setup) to avoid the useBillingContext import cycle.
const { isFreeTier } = useBillingContext()
if (isFreeTier.value && workspaceStore.isInPersonalWorkspace) {
trackModalOpened(options?.reason)
const component = defineAsyncComponent(
() =>
import('@/platform/cloud/subscription/components/FreeTierDialogContent.vue')
@@ -236,7 +247,7 @@ export const useSubscriptionDialog = () => {
sessionStorage.removeItem(RESUME_PRICING_KEY)
if (!workspaceStore.isInPersonalWorkspace) {
showPricingTable()
showPricingTable({ reason: 'team_upgrade_resume' })
}
} catch {
// sessionStorage may be unavailable

View File

@@ -0,0 +1,49 @@
import { beforeEach, describe, expect, it } from 'vitest'
import {
clearPendingSubscriptionCheckoutAttempt,
consumePendingSubscriptionCheckoutSuccess,
recordPendingSubscriptionCheckoutAttempt
} from './subscriptionCheckoutTracker'
const activeProStatus = {
is_active: true,
subscription_tier: 'PRO',
subscription_duration: 'MONTHLY'
} as const
describe('subscriptionCheckoutTracker', () => {
beforeEach(() => {
clearPendingSubscriptionCheckoutAttempt()
})
it('round-trips payment_intent_source from attempt to success metadata', () => {
recordPendingSubscriptionCheckoutAttempt({
tier: 'pro',
cycle: 'monthly',
checkout_type: 'new',
payment_intent_source: 'subscribe_to_run'
})
const metadata = consumePendingSubscriptionCheckoutSuccess(activeProStatus)
expect(metadata).toMatchObject({
tier: 'pro',
checkout_type: 'new',
payment_intent_source: 'subscribe_to_run'
})
})
it('omits payment_intent_source when the attempt had none', () => {
recordPendingSubscriptionCheckoutAttempt({
tier: 'pro',
cycle: 'monthly',
checkout_type: 'new'
})
const metadata = consumePendingSubscriptionCheckoutSuccess(activeProStatus)
expect(metadata).not.toBeNull()
expect(metadata).not.toHaveProperty('payment_intent_source')
})
})

View File

@@ -7,7 +7,12 @@ import type {
TierKey
} from '@/platform/cloud/subscription/constants/tierPricing'
import type { BillingCycle } from '@/platform/cloud/subscription/utils/subscriptionTierRank'
import type { SubscriptionSuccessMetadata } from '@/platform/telemetry/types'
import type {
BeginCheckoutMetadata,
PaymentIntentSource,
SubscriptionCheckoutType,
SubscriptionSuccessMetadata
} from '@/platform/telemetry/types'
const PENDING_SUBSCRIPTION_CHECKOUT_MAX_AGE_MS = 6 * 60 * 60 * 1000
const VALID_TIER_KEYS = new Set<TierKey>([
@@ -23,7 +28,6 @@ export const PENDING_SUBSCRIPTION_CHECKOUT_STORAGE_KEY =
export const PENDING_SUBSCRIPTION_CHECKOUT_EVENT =
'comfy:subscription-checkout-attempt-changed'
type CheckoutType = 'new' | 'change'
type SubscriptionDuration = 'MONTHLY' | 'ANNUAL'
interface SubscriptionStatusSnapshot {
@@ -32,22 +36,24 @@ interface SubscriptionStatusSnapshot {
subscription_duration?: SubscriptionDuration | null
}
interface PendingSubscriptionCheckoutAttempt {
export interface PendingSubscriptionCheckoutAttempt {
attempt_id: string
started_at_ms: number
tier: TierKey
cycle: BillingCycle
checkout_type: CheckoutType
checkout_type: SubscriptionCheckoutType
previous_tier?: TierKey
previous_cycle?: BillingCycle
payment_intent_source?: PaymentIntentSource
}
interface RecordPendingSubscriptionCheckoutAttemptInput {
interface PendingSubscriptionCheckoutAttemptInput {
tier: TierKey
cycle: BillingCycle
checkout_type: CheckoutType
checkout_type: SubscriptionCheckoutType
previous_tier?: TierKey
previous_cycle?: BillingCycle
payment_intent_source?: PaymentIntentSource
}
const dispatchPendingCheckoutChangeEvent = () => {
@@ -168,6 +174,9 @@ const normalizeAttempt = (
...(candidate.previous_cycle === 'monthly' ||
candidate.previous_cycle === 'yearly'
? { previous_cycle: candidate.previous_cycle }
: {}),
...(typeof candidate.payment_intent_source === 'string'
? { payment_intent_source: candidate.payment_intent_source }
: {})
}
}
@@ -224,20 +233,27 @@ const getPendingSubscriptionCheckoutAttempt =
export const hasPendingSubscriptionCheckoutAttempt = (): boolean =>
getPendingSubscriptionCheckoutAttempt() !== null
export const recordPendingSubscriptionCheckoutAttempt = (
input: RecordPendingSubscriptionCheckoutAttemptInput
export const createPendingSubscriptionCheckoutAttempt = (
input: PendingSubscriptionCheckoutAttemptInput
): PendingSubscriptionCheckoutAttempt => {
const storage = getStorage()
const attempt: PendingSubscriptionCheckoutAttempt = {
return {
attempt_id: createAttemptId(),
started_at_ms: Date.now(),
tier: input.tier,
cycle: input.cycle,
checkout_type: input.checkout_type,
...(input.previous_tier ? { previous_tier: input.previous_tier } : {}),
...(input.previous_cycle ? { previous_cycle: input.previous_cycle } : {})
...(input.previous_cycle ? { previous_cycle: input.previous_cycle } : {}),
...(input.payment_intent_source
? { payment_intent_source: input.payment_intent_source }
: {})
}
}
export const persistPendingSubscriptionCheckoutAttempt = (
attempt: PendingSubscriptionCheckoutAttempt
): PendingSubscriptionCheckoutAttempt => {
const storage = getStorage()
if (!storage) {
return attempt
}
@@ -255,6 +271,21 @@ export const recordPendingSubscriptionCheckoutAttempt = (
return attempt
}
export const recordPendingSubscriptionCheckoutAttempt = (
input: PendingSubscriptionCheckoutAttemptInput
): PendingSubscriptionCheckoutAttempt =>
persistPendingSubscriptionCheckoutAttempt(
createPendingSubscriptionCheckoutAttempt(input)
)
export const withPendingCheckoutAttemptId = (
metadata: BeginCheckoutMetadata,
attempt: PendingSubscriptionCheckoutAttempt
): BeginCheckoutMetadata => ({
...metadata,
checkout_attempt_id: attempt.attempt_id
})
const didAttemptSucceed = (
attempt: PendingSubscriptionCheckoutAttempt,
status: SubscriptionStatusSnapshot
@@ -287,6 +318,9 @@ export const consumePendingSubscriptionCheckoutSuccess = (
cycle: attempt.cycle,
checkout_type: attempt.checkout_type,
...(attempt.previous_tier ? { previous_tier: attempt.previous_tier } : {}),
...(attempt.payment_intent_source
? { payment_intent_source: attempt.payment_intent_source }
: {}),
value,
currency: 'USD',
ecommerce: {

View File

@@ -132,13 +132,14 @@ describe('performSubscriptionCheckout', () => {
json: async () => ({ checkout_url: checkoutUrl })
} as Response)
await performSubscriptionCheckout('pro', 'yearly', true)
await performSubscriptionCheckout('pro', 'yearly')
expect(mockTelemetry.trackBeginCheckout).toHaveBeenCalledWith({
user_id: 'user-123',
tier: 'pro',
cycle: 'yearly',
checkout_type: 'new',
checkout_attempt_id: expect.any(String),
ga_client_id: 'ga-client-id',
ga_session_id: 'ga-session-id',
ga_session_number: 'ga-session-number',
@@ -150,6 +151,12 @@ describe('performSubscriptionCheckout', () => {
gbraid: 'gbraid-456',
wbraid: 'wbraid-789'
})
const beginCheckoutMetadata =
mockTelemetry.trackBeginCheckout.mock.calls[0][0]
const [, storedAttempt] = mockLocalStorage.setItem.mock.calls[0]
expect(beginCheckoutMetadata.checkout_attempt_id).toBe(
JSON.parse(storedAttempt).attempt_id
)
expect(global.fetch).toHaveBeenCalledWith(
expect.stringContaining(
'/customers/cloud-subscription-checkout/pro-yearly'
@@ -186,7 +193,7 @@ describe('performSubscriptionCheckout', () => {
json: async () => ({ checkout_url: checkoutUrl })
} as Response)
await performSubscriptionCheckout('pro', 'monthly', true)
await performSubscriptionCheckout('pro', 'monthly')
expect(warnSpy).toHaveBeenCalledWith(
'[SubscriptionCheckout] Failed to collect checkout attribution',
@@ -203,11 +210,43 @@ describe('performSubscriptionCheckout', () => {
user_id: 'user-123',
tier: 'pro',
cycle: 'monthly',
checkout_type: 'new'
checkout_type: 'new',
checkout_attempt_id: expect.any(String)
})
expect(openSpy).toHaveBeenCalledWith(checkoutUrl, '_blank')
})
it('carries the payment intent source into begin_checkout and the pending attempt', async () => {
const checkoutUrl = 'https://checkout.stripe.com/test'
const openSpy = vi
.spyOn(window, 'open')
.mockImplementation(() => window as unknown as Window)
vi.mocked(global.fetch).mockResolvedValue({
ok: true,
json: async () => ({ checkout_url: checkoutUrl })
} as Response)
await performSubscriptionCheckout('pro', 'monthly', {
paymentIntentSource: 'out_of_credits'
})
expect(mockTelemetry.trackBeginCheckout).toHaveBeenCalledWith(
expect.objectContaining({ payment_intent_source: 'out_of_credits' })
)
const beginCheckoutMetadata =
mockTelemetry.trackBeginCheckout.mock.calls[0][0]
const [, storedAttempt] = mockLocalStorage.setItem.mock.calls[0]
const pendingAttempt = JSON.parse(storedAttempt)
expect(pendingAttempt).toMatchObject({
payment_intent_source: 'out_of_credits'
})
expect(beginCheckoutMetadata.checkout_attempt_id).toBe(
pendingAttempt.attempt_id
)
openSpy.mockRestore()
})
it('uses the latest userId when it changes after checkout starts', async () => {
const checkoutUrl = 'https://checkout.stripe.com/test'
const openSpy = vi
@@ -222,7 +261,7 @@ describe('performSubscriptionCheckout', () => {
json: async () => ({ checkout_url: checkoutUrl })
} as Response)
const checkoutPromise = performSubscriptionCheckout('pro', 'yearly', true)
const checkoutPromise = performSubscriptionCheckout('pro', 'yearly')
mockUserId.value = 'user-late'
authHeader.resolve({ Authorization: 'Bearer test-token' })
@@ -235,13 +274,14 @@ describe('performSubscriptionCheckout', () => {
user_id: 'user-late',
tier: 'pro',
cycle: 'yearly',
checkout_type: 'new'
checkout_type: 'new',
checkout_attempt_id: expect.any(String)
})
)
expect(openSpy).toHaveBeenCalledWith(checkoutUrl, '_blank')
})
it('does not persist a pending attempt when the checkout popup is blocked', async () => {
it('does not persist the pending attempt when the checkout popup is blocked', async () => {
const checkoutUrl = 'https://checkout.stripe.com/test'
const openSpy = vi.spyOn(window, 'open').mockImplementation(() => null)
@@ -250,11 +290,18 @@ describe('performSubscriptionCheckout', () => {
json: async () => ({ checkout_url: checkoutUrl })
} as Response)
await performSubscriptionCheckout('pro', 'monthly', true)
await performSubscriptionCheckout('pro', 'monthly')
expect(openSpy).toHaveBeenCalledWith(checkoutUrl, '_blank')
expect(
window.localStorage.getItem(PENDING_SUBSCRIPTION_CHECKOUT_STORAGE_KEY)
).toBeNull()
const storedAttempt = window.localStorage.getItem(
PENDING_SUBSCRIPTION_CHECKOUT_STORAGE_KEY
)
expect(storedAttempt).toBeNull()
expect(mockLocalStorage.setItem).not.toHaveBeenCalled()
expect(mockTelemetry.trackBeginCheckout).toHaveBeenCalledWith(
expect.objectContaining({
checkout_attempt_id: expect.any(String)
})
)
})
})

View File

@@ -4,12 +4,19 @@ import { useFeatureFlags } from '@/composables/useFeatureFlags'
import { getComfyApiBaseUrl } from '@/config/comfyApi'
import { t } from '@/i18n'
import { fetchWithUnifiedRemint } from '@/platform/auth/unified/remintRetry'
import type { TierKey } from '@/platform/cloud/subscription/constants/tierPricing'
import {
createPendingSubscriptionCheckoutAttempt,
persistPendingSubscriptionCheckoutAttempt,
withPendingCheckoutAttemptId
} from '@/platform/cloud/subscription/utils/subscriptionCheckoutTracker'
import { isCloud } from '@/platform/distribution/types'
import { useTelemetry } from '@/platform/telemetry'
import type {
CheckoutAttributionMetadata,
PaymentIntentSource
} from '@/platform/telemetry/types'
import { AuthStoreError, useAuthStore } from '@/stores/authStore'
import type { CheckoutAttributionMetadata } from '@/platform/telemetry/types'
import type { TierKey } from '@/platform/cloud/subscription/constants/tierPricing'
import { recordPendingSubscriptionCheckoutAttempt } from '@/platform/cloud/subscription/utils/subscriptionCheckoutTracker'
import type { BillingCycle } from './subscriptionTierRank'
type CheckoutTier = TierKey | `${TierKey}-yearly`
@@ -31,6 +38,11 @@ const getCheckoutAttributionForCloud =
return getCheckoutAttribution()
}
interface PerformSubscriptionCheckoutOptions {
openInNewTab?: boolean
paymentIntentSource?: PaymentIntentSource
}
/**
* Core subscription checkout logic shared between PricingTable and
* SubscriptionRedirectView. Handles:
@@ -47,10 +59,12 @@ const getCheckoutAttributionForCloud =
export async function performSubscriptionCheckout(
tierKey: TierKey,
currentBillingCycle: BillingCycle,
openInNewTab: boolean = true
options: PerformSubscriptionCheckoutOptions = {}
): Promise<void> {
if (!isCloud) return
const { openInNewTab = true, paymentIntentSource } = options
const authStore = useAuthStore()
const { userId } = storeToRefs(authStore)
const telemetry = useTelemetry()
@@ -108,14 +122,29 @@ export async function performSubscriptionCheckout(
const data = await response.json()
if (data.checkout_url) {
const pendingAttempt = createPendingSubscriptionCheckoutAttempt({
tier: tierKey,
cycle: currentBillingCycle,
checkout_type: 'new',
payment_intent_source: paymentIntentSource
})
if (userId.value) {
telemetry?.trackBeginCheckout({
user_id: userId.value,
tier: tierKey,
cycle: currentBillingCycle,
checkout_type: 'new',
...checkoutAttribution
})
telemetry?.trackBeginCheckout(
withPendingCheckoutAttemptId(
{
user_id: userId.value,
tier: tierKey,
cycle: currentBillingCycle,
checkout_type: 'new',
...(paymentIntentSource
? { payment_intent_source: paymentIntentSource }
: {}),
...checkoutAttribution
},
pendingAttempt
)
)
}
if (openInNewTab) {
@@ -123,18 +152,9 @@ export async function performSubscriptionCheckout(
if (!checkoutWindow) {
return
}
recordPendingSubscriptionCheckoutAttempt({
tier: tierKey,
cycle: currentBillingCycle,
checkout_type: 'new'
})
persistPendingSubscriptionCheckoutAttempt(pendingAttempt)
} else {
recordPendingSubscriptionCheckoutAttempt({
tier: tierKey,
cycle: currentBillingCycle,
checkout_type: 'new'
})
persistPendingSubscriptionCheckoutAttempt(pendingAttempt)
globalThis.location.href = data.checkout_url
}
}

View File

@@ -1,9 +1,13 @@
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { computed, reactive } from 'vue'
const { mockIsCloud, mockSubscribe } = vi.hoisted(() => ({
mockIsCloud: { value: true },
mockSubscribe: vi.fn()
}))
const { mockIsCloud, mockSubscribe, mockTrackBeginCheckout, mockUserId } =
vi.hoisted(() => ({
mockIsCloud: { value: true },
mockSubscribe: vi.fn(),
mockTrackBeginCheckout: vi.fn(),
mockUserId: { value: 'user-1' as string | null }
}))
vi.mock('@/platform/distribution/types', () => ({
get isCloud() {
@@ -16,6 +20,12 @@ vi.mock('@/config/comfyApi', () => ({
vi.mock('@/platform/workspace/api/workspaceApi', () => ({
workspaceApi: { subscribe: mockSubscribe }
}))
vi.mock('@/platform/telemetry', () => ({
useTelemetry: () => ({ trackBeginCheckout: mockTrackBeginCheckout })
}))
vi.mock('@/stores/authStore', () => ({
useAuthStore: () => reactive({ userId: computed(() => mockUserId.value) })
}))
import { performTeamSubscriptionCheckout } from './teamSubscriptionCheckoutUtil'
@@ -43,7 +53,9 @@ describe('performTeamSubscriptionCheckout', () => {
billing_op_id: 'op_1'
})
await performTeamSubscriptionCheckout('team_700', 'yearly')
await performTeamSubscriptionCheckout('team_700', 'yearly', {
paymentIntentSource: 'deep_link'
})
expect(mockSubscribe).toHaveBeenCalledWith('team_per_credit_annual', {
returnUrl: 'https://app.test/payment/success',
@@ -51,6 +63,14 @@ describe('performTeamSubscriptionCheckout', () => {
teamCreditStopId: 'team_700'
})
expect(assignedHref).toBe('https://stripe.test/pay')
expect(mockTrackBeginCheckout).toHaveBeenCalledWith({
user_id: 'user-1',
tier: 'team',
cycle: 'yearly',
checkout_type: 'new',
billing_op_id: 'op_1',
payment_intent_source: 'deep_link'
})
})
it('uses the monthly slug and lands in the app when no Stripe step is needed', async () => {
@@ -82,6 +102,16 @@ describe('performTeamSubscriptionCheckout', () => {
expect(assignedHref).toBeUndefined()
})
it('does not track begin_checkout when subscribe fails', async () => {
mockSubscribe.mockRejectedValueOnce(new Error('subscribe failed'))
await expect(
performTeamSubscriptionCheckout('team_700', 'yearly')
).rejects.toThrow('subscribe failed')
expect(mockTrackBeginCheckout).not.toHaveBeenCalled()
})
it('does nothing off cloud', async () => {
mockIsCloud.value = false

View File

@@ -1,10 +1,16 @@
import { getComfyPlatformBaseUrl } from '@/config/comfyApi'
import { getTeamPlanSlug } from '@/platform/cloud/subscription/constants/teamPlanCreditStops'
import { isCloud } from '@/platform/distribution/types'
import type { PaymentIntentSource } from '@/platform/telemetry/types'
import { workspaceApi } from '@/platform/workspace/api/workspaceApi'
import { trackWorkspaceCheckoutStarted } from '@/platform/workspace/utils/workspaceCheckoutTelemetry'
import type { BillingCycle } from './subscriptionTierRank'
interface PerformTeamSubscriptionCheckoutOptions {
paymentIntentSource?: PaymentIntentSource
}
/**
* Direct team-plan checkout for the marketing `/cloud/subscribe?tier=team` deep
* link: subscribes to the per-credit Team plan at the chosen slider stop and
@@ -22,7 +28,8 @@ import type { BillingCycle } from './subscriptionTierRank'
*/
export async function performTeamSubscriptionCheckout(
teamCreditStopId: string,
billingCycle: BillingCycle
billingCycle: BillingCycle,
options: PerformTeamSubscriptionCheckoutOptions = {}
): Promise<void> {
if (!isCloud) return
@@ -33,6 +40,14 @@ export async function performTeamSubscriptionCheckout(
teamCreditStopId
})
trackWorkspaceCheckoutStarted({
tier: 'team',
cycle: billingCycle,
checkoutType: 'new',
billingOpId: response.billing_op_id,
paymentIntentSource: options.paymentIntentSource
})
if (response.status === 'needs_payment_method') {
// A needs_payment_method response without a URL is unusable: surface it to
// the caller's error handling rather than silently dropping the user home

View File

@@ -1,6 +1,7 @@
import { createTestingPinia } from '@pinia/testing'
import { setActivePinia } from 'pinia'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { ref } from 'vue'
import {
getSettingInfo,
@@ -10,47 +11,31 @@ import type { SettingTreeNode } from '@/platform/settings/settingStore'
import { useSettingUI } from './useSettingUI'
const { auth, billing, dist, featureFlags, vueFlags } = vi.hoisted(() => ({
auth: { isLoggedIn: { value: false } },
billing: { isActiveSubscription: { value: false } },
dist: { isCloud: false, isDesktop: false },
featureFlags: { teamWorkspacesEnabled: false, userSecretsEnabled: false },
vueFlags: { shouldRenderVueNodes: { value: false } }
}))
vi.mock('vue-i18n', () => ({
useI18n: () => ({ t: (_: string, fallback: string) => fallback })
}))
vi.mock('@/composables/auth/useCurrentUser', () => ({
useCurrentUser: () => ({ isLoggedIn: auth.isLoggedIn })
useCurrentUser: () => ({ isLoggedIn: ref(false) })
}))
vi.mock('@/composables/billing/useBillingContext', () => ({
useBillingContext: () => ({
isActiveSubscription: billing.isActiveSubscription
})
useBillingContext: () => ({ isActiveSubscription: ref(false) })
}))
vi.mock('@/composables/useFeatureFlags', () => ({
useFeatureFlags: () => ({
flags: featureFlags
flags: { teamWorkspacesEnabled: false, userSecretsEnabled: false }
})
}))
vi.mock('@/composables/useVueFeatureFlags', () => ({
useVueFeatureFlags: () => ({
shouldRenderVueNodes: vueFlags.shouldRenderVueNodes
})
useVueFeatureFlags: () => ({ shouldRenderVueNodes: ref(false) })
}))
vi.mock('@/platform/distribution/types', () => ({
get isCloud() {
return dist.isCloud
},
get isDesktop() {
return dist.isDesktop
}
isCloud: false,
isDesktop: false
}))
vi.mock('@/platform/settings/settingStore', () => ({
@@ -64,7 +49,6 @@ interface MockSettingParams {
type: string
defaultValue: unknown
category?: string[]
hideInVueNodes?: boolean
}
describe('useSettingUI', () => {
@@ -88,23 +72,13 @@ describe('useSettingUI', () => {
defaultValue: 'dark'
}
}
let settingsById: Record<string, MockSettingParams>
beforeEach(() => {
setActivePinia(createTestingPinia())
vi.clearAllMocks()
auth.isLoggedIn.value = false
billing.isActiveSubscription.value = false
dist.isCloud = false
dist.isDesktop = false
featureFlags.teamWorkspacesEnabled = false
featureFlags.userSecretsEnabled = false
vueFlags.shouldRenderVueNodes.value = false
Object.assign(window, { __CONFIG__: {} })
settingsById = mockSettings
vi.mocked(useSettingStore).mockReturnValue({
settingsById
settingsById: mockSettings
} as ReturnType<typeof useSettingStore>)
vi.mocked(getSettingInfo).mockImplementation((setting) => {
@@ -133,9 +107,9 @@ describe('useSettingUI', () => {
undefined,
'Comfy.Locale'
)
expect(defaultCategory.value).toBe(
findCategory(settingCategories.value, 'Comfy')
)
const comfyCategory = findCategory(settingCategories.value, 'Comfy')
expect(comfyCategory).toBeDefined()
expect(defaultCategory.value).toBe(comfyCategory)
})
it('resolves different category from scrollToSettingId', () => {
@@ -147,6 +121,7 @@ describe('useSettingUI', () => {
settingCategories.value,
'Appearance'
)
expect(appearanceCategory).toBeDefined()
expect(defaultCategory.value).toBe(appearanceCategory)
})
@@ -162,82 +137,4 @@ describe('useSettingUI', () => {
const { defaultCategory } = useSettingUI('about', 'Comfy.Locale')
expect(defaultCategory.value.key).toBe('about')
})
it('falls back when defaultPanel is not in the menu', () => {
const missingPanel = 'missing' as unknown as Parameters<
typeof useSettingUI
>[0]
const { defaultCategory, settingCategories } = useSettingUI(missingPanel)
expect(defaultCategory.value).toBe(settingCategories.value[0])
})
it('moves floating settings into Other and hides Vue-node-only settings', () => {
settingsById = {
Floating: {
id: 'Floating',
name: 'Floating',
type: 'boolean',
defaultValue: false
},
'Hidden.Setting': {
id: 'Hidden.Setting',
name: 'Hidden',
type: 'hidden',
defaultValue: false
},
'Vue.Hidden': {
id: 'Vue.Hidden',
name: 'Vue Hidden',
type: 'boolean',
defaultValue: false,
hideInVueNodes: true
}
}
vi.mocked(useSettingStore).mockReturnValue({
settingsById
} as ReturnType<typeof useSettingStore>)
vueFlags.shouldRenderVueNodes.value = true
const { settingCategories } = useSettingUI()
expect(settingCategories.value.map((category) => category.label)).toEqual([
'Other'
])
expect(
settingCategories.value[0].children?.map((node) => node.key)
).toEqual(['root/Floating'])
})
it('adds gated cloud, desktop, workspace, and secrets panels', () => {
auth.isLoggedIn.value = true
billing.isActiveSubscription.value = true
dist.isCloud = true
dist.isDesktop = true
featureFlags.teamWorkspacesEnabled = true
featureFlags.userSecretsEnabled = true
Object.assign(window, { __CONFIG__: { subscription_required: true } })
const { findCategoryByKey, findPanelByKey, navGroups, panels } =
useSettingUI()
expect(panels.value.map((panel) => panel.node.key)).toEqual([
'about',
'credits',
'user',
'workspace',
'keybinding',
'extension',
'server-config',
'subscription',
'secrets'
])
expect(navGroups.value.map((group) => group.title)).toEqual([
'Workspace',
'General'
])
expect(findCategoryByKey('secrets')?.key).toBe('secrets')
expect(findCategoryByKey('missing')).toBeNull()
expect(findPanelByKey('subscription')?.node.key).toBe('subscription')
expect(findPanelByKey('missing')).toBeNull()
})
})

View File

@@ -30,6 +30,39 @@ describe('TelemetryRegistry', () => {
expect(b.trackSearchQuery).toHaveBeenCalledExactlyOnceWith(payload)
})
it('dispatches trackBeginCheckout with intent metadata to every provider', () => {
const a: TelemetryProvider = { trackBeginCheckout: vi.fn() }
const b: TelemetryProvider = {}
const registry = new TelemetryRegistry()
registry.registerProvider(a)
registry.registerProvider(b)
const metadata = {
user_id: 'user-1',
tier: 'pro' as const,
cycle: 'monthly' as const,
checkout_type: 'new' as const,
payment_intent_source: 'subscribe_to_run' as const
}
registry.trackBeginCheckout(metadata)
expect(a.trackBeginCheckout).toHaveBeenCalledExactlyOnceWith(metadata)
})
it('dispatches trackAddApiCreditButtonClicked with its source', () => {
const provider: TelemetryProvider = {
trackAddApiCreditButtonClicked: vi.fn()
}
const registry = new TelemetryRegistry()
registry.registerProvider(provider)
registry.trackAddApiCreditButtonClicked({ source: 'credits_panel' })
expect(
provider.trackAddApiCreditButtonClicked
).toHaveBeenCalledExactlyOnceWith({ source: 'credits_panel' })
})
it('skips providers that do not implement trackSearchQuery', () => {
const empty: TelemetryProvider = {}
const registry = new TelemetryRegistry()

View File

@@ -1,6 +1,7 @@
import type { AuditLog } from '@/services/customerEventsService'
import type {
AddCreditsClickMetadata,
AuthMetadata,
BeginCheckoutMetadata,
DefaultViewSetMetadata,
@@ -99,8 +100,10 @@ export class TelemetryRegistry implements TelemetryDispatcher {
this.dispatch((provider) => provider.trackMonthlySubscriptionCancelled?.())
}
trackAddApiCreditButtonClicked(): void {
this.dispatch((provider) => provider.trackAddApiCreditButtonClicked?.())
trackAddApiCreditButtonClicked(metadata?: AddCreditsClickMetadata): void {
this.dispatch((provider) =>
provider.trackAddApiCreditButtonClicked?.(metadata)
)
}
trackApiCreditTopupButtonPurchaseClicked(amount: number): void {

View File

@@ -313,6 +313,42 @@ describe('PostHogTelemetryProvider', () => {
)
})
it('captures begin_checkout with intent metadata', async () => {
const provider = createProvider()
await vi.dynamicImportSettled()
provider.trackBeginCheckout({
user_id: 'user-1',
tier: 'pro',
cycle: 'monthly',
checkout_type: 'new',
payment_intent_source: 'subscribe_to_run'
})
expect(hoisted.mockCapture).toHaveBeenCalledWith(
TelemetryEvents.BEGIN_CHECKOUT,
{
user_id: 'user-1',
tier: 'pro',
cycle: 'monthly',
checkout_type: 'new',
payment_intent_source: 'subscribe_to_run'
}
)
})
it('captures add-credit clicks with their source', async () => {
const provider = createProvider()
await vi.dynamicImportSettled()
provider.trackAddApiCreditButtonClicked({ source: 'credits_panel' })
expect(hoisted.mockCapture).toHaveBeenCalledWith(
TelemetryEvents.ADD_API_CREDIT_BUTTON_CLICKED,
{ source: 'credits_panel' }
)
})
it('captures share attribution events', async () => {
const provider = createProvider()
await vi.dynamicImportSettled()

View File

@@ -10,7 +10,9 @@ import { remoteConfig } from '@/platform/remoteConfig/remoteConfig'
import type { RemoteConfig } from '@/platform/remoteConfig/types'
import type {
AddCreditsClickMetadata,
AuthMetadata,
BeginCheckoutMetadata,
DefaultViewSetMetadata,
EnterLinearMetadata,
ShareFlowMetadata,
@@ -350,8 +352,12 @@ export class PostHogTelemetryProvider implements TelemetryProvider {
this.trackEvent(eventName, metadata)
}
trackAddApiCreditButtonClicked(): void {
this.trackEvent(TelemetryEvents.ADD_API_CREDIT_BUTTON_CLICKED)
trackAddApiCreditButtonClicked(metadata?: AddCreditsClickMetadata): void {
this.trackEvent(TelemetryEvents.ADD_API_CREDIT_BUTTON_CLICKED, metadata)
}
trackBeginCheckout(metadata: BeginCheckoutMetadata): void {
this.trackEvent(TelemetryEvents.BEGIN_CHECKOUT, metadata)
}
trackMonthlySubscriptionSucceeded(

View File

@@ -115,6 +115,17 @@ describe('HostTelemetrySink', () => {
)
})
it('forwards add-credit clicks with their source', () => {
new HostTelemetrySink().trackAddApiCreditButtonClicked({
source: 'avatar_menu'
})
expect(state.capture).toHaveBeenCalledExactlyOnceWith(
TelemetryEvents.ADD_API_CREDIT_BUTTON_CLICKED,
{ source: 'avatar_menu' }
)
})
it('does nothing when the host bridge is absent', () => {
delete window.__comfyDesktop2

View File

@@ -10,6 +10,7 @@ import {
import type { AuditLog } from '@/services/customerEventsService'
import type {
AddCreditsClickMetadata,
AuthMetadata,
BeginCheckoutMetadata,
DefaultViewSetMetadata,
@@ -126,8 +127,8 @@ export class HostTelemetrySink implements TelemetryProvider {
this.capture(TelemetryEvents.MONTHLY_SUBSCRIPTION_CANCELLED)
}
trackAddApiCreditButtonClicked(): void {
this.capture(TelemetryEvents.ADD_API_CREDIT_BUTTON_CLICKED)
trackAddApiCreditButtonClicked(metadata?: AddCreditsClickMetadata): void {
this.capture(TelemetryEvents.ADD_API_CREDIT_BUTTON_CLICKED, metadata)
}
trackApiCreditTopupButtonPurchaseClicked(amount: number): void {

View File

@@ -12,12 +12,29 @@
* 3. Check dist/assets/*.js files contain no tracking code
*/
import type { SubscriptionDialogReason } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
import type { TierKey } from '@/platform/cloud/subscription/constants/tierPricing'
import type { BillingCycle } from '@/platform/cloud/subscription/utils/subscriptionTierRank'
import type { AuditLog } from '@/services/customerEventsService'
import type { AppMode } from '@/utils/appMode'
export type PaymentIntentSource =
| 'subscription_required'
| 'out_of_credits'
| 'top_up_blocked'
| 'deep_link'
| 'subscribe_to_run'
| 'subscribe_now_button'
| 'upgrade_to_add_credits'
| 'settings_billing_panel'
| 'avatar_menu_plans'
| 'team_members_panel'
| 'invite_member_upsell'
| 'upload_model_upgrade'
| 'team_upgrade_resume'
export type SubscriptionCheckoutType = 'new' | 'change'
export type SubscriptionCheckoutTier = TierKey | 'team'
/**
* Authentication metadata for sign-up tracking
*/
@@ -426,16 +443,23 @@ export interface CheckoutAttributionMetadata {
export interface SubscriptionMetadata {
current_tier?: string
reason?: SubscriptionDialogReason
reason?: PaymentIntentSource
}
export interface AddCreditsClickMetadata {
source: 'credits_panel' | 'avatar_menu' | 'settings_billing_panel'
}
export interface BeginCheckoutMetadata
extends Record<string, unknown>, CheckoutAttributionMetadata {
user_id: string
tier: TierKey
tier: SubscriptionCheckoutTier
cycle: BillingCycle
checkout_type: 'new' | 'change'
checkout_type: SubscriptionCheckoutType
checkout_attempt_id?: string
billing_op_id?: string
previous_tier?: TierKey
payment_intent_source?: PaymentIntentSource
}
interface EcommerceItemMetadata {
@@ -457,8 +481,9 @@ export interface SubscriptionSuccessMetadata extends Record<string, unknown> {
checkout_attempt_id: string
tier: TierKey
cycle: BillingCycle
checkout_type: 'new' | 'change'
checkout_type: SubscriptionCheckoutType
previous_tier?: TierKey
payment_intent_source?: PaymentIntentSource
value: number
currency: string
ecommerce: EcommerceMetadata
@@ -489,7 +514,7 @@ export interface TelemetryProvider {
metadata?: SubscriptionSuccessMetadata
): void
trackMonthlySubscriptionCancelled?(): void
trackAddApiCreditButtonClicked?(): void
trackAddApiCreditButtonClicked?(metadata?: AddCreditsClickMetadata): void
trackApiCreditTopupButtonPurchaseClicked?(amount: number): void
trackApiCreditTopupSucceeded?(): void
trackWorkspaceInviteSent?(metadata: WorkspaceInviteMetadata): void

View File

@@ -321,7 +321,7 @@ const handleOpenWorkspaceSettings = () => {
}
const handleOpenPlansAndPricing = () => {
subscriptionDialog.showPricingTable()
subscriptionDialog.showPricingTable({ reason: 'avatar_menu_plans' })
emit('close')
}
@@ -336,13 +336,12 @@ const handleOpenPlanAndCreditsSettings = () => {
}
const handleUpgradeToAddCredits = () => {
subscriptionDialog.showPricingTable()
subscriptionDialog.showPricingTable({ reason: 'upgrade_to_add_credits' })
emit('close')
}
const handleTopUp = () => {
// Track purchase credits entry from avatar popover
useTelemetry()?.trackAddApiCreditButtonClicked()
useTelemetry()?.trackAddApiCreditButtonClicked({ source: 'avatar_menu' })
dialogService.showTopUpCreditsDialog()
emit('close')
}

View File

@@ -391,12 +391,13 @@ const showZeroState = computed(
)
function handleSubscribeWorkspace() {
showSubscriptionDialog()
showSubscriptionDialog({ reason: 'settings_billing_panel' })
}
function handleUpgrade() {
if (isFreeTierPlan.value) showPricingTable()
else showSubscriptionDialog()
if (isFreeTierPlan.value)
showPricingTable({ reason: 'settings_billing_panel' })
else showSubscriptionDialog({ reason: 'settings_billing_panel' })
}
function handleViewMoreDetails() {

View File

@@ -113,7 +113,7 @@ import { cn } from '@comfyorg/tailwind-utils'
import { useEventListener } from '@vueuse/core'
import Button from '@/components/ui/button/Button.vue'
import type { SubscriptionDialogReason } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
import type { PaymentIntentSource } from '@/platform/telemetry/types'
import { useSubscriptionCheckout } from '@/platform/workspace/composables/useSubscriptionCheckout'
import SubscriptionAddPaymentPreviewWorkspace from './SubscriptionAddPaymentPreviewWorkspace.vue'
@@ -123,7 +123,7 @@ import UnifiedPricingTable from './UnifiedPricingTable.vue'
const { onClose, reason, initialPlanMode } = defineProps<{
onClose: () => void
reason?: SubscriptionDialogReason
reason?: PaymentIntentSource
initialPlanMode?: 'personal' | 'team'
}>()
@@ -152,7 +152,7 @@ const {
handleConfirmTransition,
handleTeamSubscribe,
handleResubscribe
} = useSubscriptionCheckout(emit)
} = useSubscriptionCheckout(emit, reason)
// Backspace mirrors the back arrow on the confirm step, but never while an
// editable element is focused (let it delete text there).

View File

@@ -5,7 +5,7 @@ import { beforeEach, describe, expect, it, vi } from 'vitest'
import { ref } from 'vue'
import { createI18n } from 'vue-i18n'
import type { SubscriptionDialogReason } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
import type { PaymentIntentSource } from '@/platform/telemetry/types'
import SubscriptionRequiredDialogContentWorkspace from './SubscriptionRequiredDialogContentWorkspace.vue'
@@ -17,25 +17,10 @@ const mockHandleResubscribe = vi.fn()
const mockHandleSuccessClose = vi.fn()
const mockCheckoutStep = ref<'pricing' | 'preview' | 'success'>('pricing')
const mockPreviewData = ref<{ transition_type: string } | null>(null)
const mockUseSubscriptionCheckout = vi.hoisted(() => vi.fn())
vi.mock('@/platform/workspace/composables/useSubscriptionCheckout', () => ({
useSubscriptionCheckout: () => ({
checkoutStep: mockCheckoutStep,
isLoadingPreview: ref(false),
loadingTier: ref(null),
isSubscribing: ref(false),
isResubscribing: ref(false),
previewData: mockPreviewData,
selectedTierKey: ref('standard'),
selectedBillingCycle: ref('yearly'),
isPolling: ref(false),
handleSubscribeClick: mockHandleSubscribeClick,
handleBackToPricing: mockHandleBackToPricing,
handleAddCreditCard: mockHandleAddCreditCard,
handleConfirmTransition: mockHandleConfirmTransition,
handleResubscribe: mockHandleResubscribe,
handleSuccessClose: mockHandleSuccessClose
})
useSubscriptionCheckout: mockUseSubscriptionCheckout
}))
const i18n = createI18n({
@@ -91,7 +76,7 @@ const SuccessStub = {
function renderComponent(
props: {
onClose?: () => void
reason?: SubscriptionDialogReason
reason?: PaymentIntentSource
isPersonal?: boolean
} = {}
) {
@@ -121,6 +106,23 @@ function renderComponent(
describe('SubscriptionRequiredDialogContentWorkspace', () => {
beforeEach(() => {
vi.clearAllMocks()
mockUseSubscriptionCheckout.mockReturnValue({
checkoutStep: mockCheckoutStep,
isLoadingPreview: ref(false),
loadingTier: ref(null),
isSubscribing: ref(false),
isResubscribing: ref(false),
previewData: mockPreviewData,
selectedTierKey: ref('standard'),
selectedBillingCycle: ref('yearly'),
isPolling: ref(false),
handleSubscribeClick: mockHandleSubscribeClick,
handleBackToPricing: mockHandleBackToPricing,
handleAddCreditCard: mockHandleAddCreditCard,
handleConfirmTransition: mockHandleConfirmTransition,
handleResubscribe: mockHandleResubscribe,
handleSuccessClose: mockHandleSuccessClose
})
mockCheckoutStep.value = 'pricing'
mockPreviewData.value = null
})
@@ -132,6 +134,15 @@ describe('SubscriptionRequiredDialogContentWorkspace', () => {
expect(screen.queryByTestId('transition-preview')).not.toBeInTheDocument()
})
it('passes the reason into subscription checkout', () => {
renderComponent({ reason: 'out_of_credits' })
expect(mockUseSubscriptionCheckout).toHaveBeenCalledWith(
expect.any(Function),
'out_of_credits'
)
})
it('shows the team workspace header by default', () => {
renderComponent()
expect(screen.getByText('Team Workspace')).toBeInTheDocument()

View File

@@ -116,7 +116,7 @@
import { cn } from '@comfyorg/tailwind-utils'
import Button from '@/components/ui/button/Button.vue'
import type { SubscriptionDialogReason } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
import type { PaymentIntentSource } from '@/platform/telemetry/types'
import { useSubscriptionCheckout } from '@/platform/workspace/composables/useSubscriptionCheckout'
import PricingTableWorkspace from './PricingTableWorkspace.vue'
@@ -130,7 +130,7 @@ const {
isPersonal = false
} = defineProps<{
onClose: () => void
reason?: SubscriptionDialogReason
reason?: PaymentIntentSource
isPersonal?: boolean
}>()
@@ -154,7 +154,7 @@ const {
handleConfirmTransition,
handleResubscribe,
handleSuccessClose
} = useSubscriptionCheckout(emit)
} = useSubscriptionCheckout(emit, reason)
</script>
<style scoped>

View File

@@ -61,6 +61,9 @@ function onDismiss() {
function onUpgrade() {
dialogStore.closeDialog({ key: 'invite-member-upsell' })
subscriptionDialog.show({ planMode: 'team' })
subscriptionDialog.show({
planMode: 'team',
reason: 'invite_member_upsell'
})
}
</script>

View File

@@ -277,7 +277,7 @@ export function useMembersPanel() {
}
function showTeamPlans() {
subscriptionDialog.show({ planMode: 'team' })
subscriptionDialog.show({ planMode: 'team', reason: 'team_members_panel' })
}
return {

View File

@@ -1,8 +1,9 @@
import { createTestingPinia } from '@pinia/testing'
import { setActivePinia } from 'pinia'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { computed } from 'vue'
import { computed, reactive } from 'vue'
import type { PaymentIntentSource } from '@/platform/telemetry/types'
import type { Plan } from '@/platform/workspace/api/workspaceApi'
import { findPlanSlug } from './useSubscriptionCheckout'
@@ -75,7 +76,9 @@ const {
mockPlans,
mockResubscribe,
mockToastAdd,
mockStartOperation
mockStartOperation,
mockTrackBeginCheckout,
mockUserId
} = vi.hoisted(() => ({
mockSubscribe: vi.fn(),
mockPreviewSubscribe: vi.fn(),
@@ -84,7 +87,9 @@ const {
mockPlans: { value: [] as Plan[] },
mockResubscribe: vi.fn(),
mockToastAdd: vi.fn(),
mockStartOperation: vi.fn()
mockStartOperation: vi.fn(),
mockTrackBeginCheckout: vi.fn(),
mockUserId: { value: 'user-1' as string | null }
}))
vi.mock('@/composables/billing/useBillingContext', () => ({
@@ -119,7 +124,14 @@ vi.mock('primevue/usetoast', () => ({
}))
vi.mock('@/platform/telemetry', () => ({
useTelemetry: () => ({ trackMonthlySubscriptionSucceeded: vi.fn() })
useTelemetry: () => ({
trackMonthlySubscriptionSucceeded: vi.fn(),
trackBeginCheckout: mockTrackBeginCheckout
})
}))
vi.mock('@/stores/authStore', () => ({
useAuthStore: () => reactive({ userId: computed(() => mockUserId.value) })
}))
vi.mock('vue-i18n', async (importOriginal) => {
@@ -135,10 +147,10 @@ vi.mock('vue-i18n', async (importOriginal) => {
describe('useSubscriptionCheckout', () => {
let emit: ReturnType<typeof vi.fn>
async function setup() {
async function setup(paymentIntentSource?: PaymentIntentSource) {
const { useSubscriptionCheckout } =
await import('./useSubscriptionCheckout')
return useSubscriptionCheckout(emit as never)
return useSubscriptionCheckout(emit as never, paymentIntentSource)
}
beforeEach(() => {
@@ -146,6 +158,7 @@ describe('useSubscriptionCheckout', () => {
vi.clearAllMocks()
mockPlans.value = allPlans()
mockStartOperation.mockResolvedValue({ status: 'succeeded' })
mockUserId.value = 'user-1'
emit = vi.fn()
})
@@ -459,6 +472,13 @@ describe('useSubscriptionCheckout', () => {
cancelUrl: 'https://platform.comfy.org/payment/failed'
})
expect(checkout.checkoutStep.value).toBe('success')
expect(mockTrackBeginCheckout).toHaveBeenCalledWith(
expect.objectContaining({
tier: 'team',
checkout_type: 'new',
billing_op_id: 'op-team-1'
})
)
})
it('uses the annual plan slug for the yearly cycle', async () => {
@@ -553,6 +573,39 @@ describe('useSubscriptionCheckout', () => {
detail: 'Team payment failed'
})
)
expect(mockTrackBeginCheckout).not.toHaveBeenCalled()
})
it('keeps team checkout_type as change when the preview request fails', async () => {
const checkout = await setup()
mockPreviewSubscribe.mockRejectedValueOnce(new Error('not supported'))
await checkout.handleSubscribeTeamClick({
stop: {
id: 'team_1400',
usd: 1400,
credits: 295_400,
discountedUsd: 1295
},
billingCycle: 'monthly',
isChange: true
})
mockSubscribe.mockResolvedValueOnce({
status: 'subscribed',
billing_op_id: 'op-team-change'
})
mockFetchStatus.mockResolvedValueOnce(undefined)
mockFetchBalance.mockResolvedValueOnce(undefined)
await checkout.handleTeamSubscribe()
expect(mockTrackBeginCheckout).toHaveBeenCalledWith(
expect.objectContaining({
tier: 'team',
cycle: 'monthly',
checkout_type: 'change',
billing_op_id: 'op-team-change'
})
)
})
})
@@ -603,6 +656,47 @@ describe('useSubscriptionCheckout', () => {
expect(checkout.checkoutStep.value).toBe('success')
})
it('skips begin_checkout when no user id is available', async () => {
mockUserId.value = null
const checkout = await setup('subscribe_to_run')
checkout.selectedTierKey.value = 'standard'
checkout.selectedBillingCycle.value = 'yearly'
mockSubscribe.mockResolvedValueOnce({
status: 'subscribed',
billing_op_id: 'op-1'
})
mockFetchStatus.mockResolvedValueOnce(undefined)
mockFetchBalance.mockResolvedValueOnce(undefined)
await checkout.handleAddCreditCard()
expect(mockTrackBeginCheckout).not.toHaveBeenCalled()
mockUserId.value = 'user-1'
})
it('fires begin_checkout carrying the payment intent source', async () => {
const checkout = await setup('subscribe_to_run')
checkout.selectedTierKey.value = 'standard'
checkout.selectedBillingCycle.value = 'yearly'
mockSubscribe.mockResolvedValueOnce({
status: 'subscribed',
billing_op_id: 'op-1'
})
mockFetchStatus.mockResolvedValueOnce(undefined)
mockFetchBalance.mockResolvedValueOnce(undefined)
await checkout.handleAddCreditCard()
expect(mockTrackBeginCheckout).toHaveBeenCalledWith({
user_id: 'user-1',
tier: 'standard',
cycle: 'yearly',
checkout_type: 'new',
billing_op_id: 'op-1',
payment_intent_source: 'subscribe_to_run'
})
})
it('opens payment URL when needs_payment_method', async () => {
const checkout = await setup()
checkout.selectedTierKey.value = 'standard'
@@ -720,6 +814,7 @@ describe('useSubscriptionCheckout', () => {
detail: 'Payment failed'
})
)
expect(mockTrackBeginCheckout).not.toHaveBeenCalled()
})
})

View File

@@ -9,16 +9,26 @@ import type { TeamPlanSelection } from '@/platform/cloud/subscription/constants/
import type { TierKey } from '@/platform/cloud/subscription/constants/tierPricing'
import type { BillingCycle } from '@/platform/cloud/subscription/utils/subscriptionTierRank'
import { useTelemetry } from '@/platform/telemetry'
import type {
PaymentIntentSource,
SubscriptionCheckoutType
} from '@/platform/telemetry/types'
import type {
Plan,
PreviewSubscribeResponse,
SubscribeResponse
} from '@/platform/workspace/api/workspaceApi'
import { useBillingOperationStore } from '@/platform/workspace/stores/billingOperationStore'
import { trackWorkspaceCheckoutStarted } from '@/platform/workspace/utils/workspaceCheckoutTelemetry'
type CheckoutStep = 'pricing' | 'preview' | 'success'
type CheckoutTierKey = Exclude<TierKey, 'free' | 'founder'>
interface SelectedTeamCheckout {
stop: TeamPlanSelection
checkoutType: SubscriptionCheckoutType
}
/**
* Which screen the `preview` step shows. Only a change prorates: a team change
* carries `previewData` (handleSubscribeTeamClick sets it solely for an immediate
@@ -45,9 +55,12 @@ export function findPlanSlug(
return plan?.slug ?? null
}
export function useSubscriptionCheckout(emit: {
(e: 'close', subscribed: boolean): void
}) {
export function useSubscriptionCheckout(
emit: {
(e: 'close', subscribed: boolean): void
},
paymentIntentSource?: PaymentIntentSource
) {
const { t } = useI18n()
const toast = useToast()
const {
@@ -68,13 +81,16 @@ export function useSubscriptionCheckout(emit: {
const isResubscribing = ref(false)
const previewData = ref<PreviewSubscribeResponse | null>(null)
const selectedTierKey = ref<CheckoutTierKey | null>(null)
const selectedTeamStop = ref<TeamPlanSelection | null>(null)
const selectedTeamCheckout = ref<SelectedTeamCheckout | null>(null)
const selectedBillingCycle = ref<BillingCycle>('yearly')
const isPolling = computed(() => billingOperationStore.hasPendingOperations)
const isTeamCheckout = computed(() => selectedTeamStop.value !== null)
const selectedTeamStop = computed(
() => selectedTeamCheckout.value?.stop ?? null
)
const isTeamCheckout = computed(() => selectedTeamCheckout.value !== null)
const previewVariant = computed<PreviewVariant>(() => {
if (selectedTeamStop.value) {
if (selectedTeamCheckout.value) {
return previewData.value ? 'team-change' : 'team-new'
}
if (previewData.value) {
@@ -154,7 +170,10 @@ export function useSubscriptionCheckout(emit: {
billingCycle: BillingCycle
isChange?: boolean
}) {
selectedTeamStop.value = payload.stop
selectedTeamCheckout.value = {
stop: payload.stop,
checkoutType: payload.isChange ? 'change' : 'new'
}
selectedBillingCycle.value = payload.billingCycle
selectedTierKey.value = null
previewData.value = null
@@ -182,7 +201,7 @@ export function useSubscriptionCheckout(emit: {
function handleBackToPricing() {
checkoutStep.value = 'pricing'
previewData.value = null
selectedTeamStop.value = null
selectedTeamCheckout.value = null
}
function handleSuccessClose() {
@@ -190,20 +209,34 @@ export function useSubscriptionCheckout(emit: {
}
async function handleSubscription() {
if (!selectedTierKey.value) return
const tierKey = selectedTierKey.value
if (!tierKey) return
const billingCycle = selectedBillingCycle.value
const checkoutType =
previewData.value &&
previewData.value.transition_type !== 'new_subscription'
? 'change'
: 'new'
isSubscribing.value = true
try {
const planSlug = getApiPlanSlug(
selectedTierKey.value,
selectedBillingCycle.value
)
const planSlug = getApiPlanSlug(tierKey, billingCycle)
if (!planSlug) return
const response = await subscribe(planSlug, {
returnUrl: `${getComfyPlatformBaseUrl()}/payment/success`,
cancelUrl: `${getComfyPlatformBaseUrl()}/payment/failed`
})
if (response) {
trackWorkspaceCheckoutStarted({
tier: tierKey,
cycle: billingCycle,
checkoutType,
billingOpId: response.billing_op_id,
paymentIntentSource
})
}
await handleSubscribeResponse(response)
} catch (error) {
showSubscribeError(error)
@@ -269,8 +302,8 @@ export function useSubscriptionCheckout(emit: {
}
async function handleTeamSubscription() {
const stop = selectedTeamStop.value
if (!stop?.id) {
const teamCheckout = selectedTeamCheckout.value
if (!teamCheckout?.stop.id) {
toast.add({
severity: 'error',
summary: t('subscription.teamPlan.name'),
@@ -279,16 +312,28 @@ export function useSubscriptionCheckout(emit: {
return
}
const { stop, checkoutType } = teamCheckout
const billingCycle = selectedBillingCycle.value
isSubscribing.value = true
try {
const planSlug = getTeamPlanSlug(selectedBillingCycle.value)
const planSlug = getTeamPlanSlug(billingCycle)
const response = await subscribe(planSlug, {
teamCreditStopId: stop.id,
billingCycle: selectedBillingCycle.value,
billingCycle,
returnUrl: `${getComfyPlatformBaseUrl()}/payment/success`,
cancelUrl: `${getComfyPlatformBaseUrl()}/payment/failed`
})
if (response) {
trackWorkspaceCheckoutStarted({
tier: 'team',
cycle: billingCycle,
checkoutType,
billingOpId: response.billing_op_id,
paymentIntentSource
})
}
await handleSubscribeResponse(response)
} catch (error) {
showSubscribeError(error)

View File

@@ -2,6 +2,7 @@ import { computed, ref, shallowRef } from 'vue'
import { useBillingPlans } from '@/platform/cloud/subscription/composables/useBillingPlans'
import { useSubscriptionDialog } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
import type { SubscriptionDialogOptions } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
import type {
BillingBalanceResponse,
BillingStatusResponse,
@@ -275,12 +276,12 @@ export function useWorkspaceBilling(): BillingState & BillingActions {
async function requireActiveSubscription(): Promise<void> {
await fetchStatus()
if (!isActiveSubscription.value) {
subscriptionDialog.show()
subscriptionDialog.show({ reason: 'subscription_required' })
}
}
function showSubscriptionDialog(): void {
subscriptionDialog.show()
function showSubscriptionDialog(options?: SubscriptionDialogOptions): void {
subscriptionDialog.show(options)
}
return {

View File

@@ -0,0 +1,38 @@
import { useTelemetry } from '@/platform/telemetry'
import type {
PaymentIntentSource,
SubscriptionCheckoutTier,
SubscriptionCheckoutType
} from '@/platform/telemetry/types'
import type { BillingCycle } from '@/platform/cloud/subscription/utils/subscriptionTierRank'
import { useAuthStore } from '@/stores/authStore'
interface TrackWorkspaceCheckoutStartedOptions {
tier: SubscriptionCheckoutTier
cycle: BillingCycle
checkoutType: SubscriptionCheckoutType
billingOpId: string
paymentIntentSource?: PaymentIntentSource
}
export function trackWorkspaceCheckoutStarted({
tier,
cycle,
checkoutType,
billingOpId,
paymentIntentSource
}: TrackWorkspaceCheckoutStartedOptions) {
const { userId } = useAuthStore()
if (!userId) return
useTelemetry()?.trackBeginCheckout({
user_id: userId,
tier,
cycle,
checkout_type: checkoutType,
billing_op_id: billingOpId,
...(paymentIntentSource
? { payment_intent_source: paymentIntentSource }
: {})
})
}

View File

@@ -0,0 +1,208 @@
import { createTestingPinia } from '@pinia/testing'
import { render, screen, within } from '@testing-library/vue'
import { setActivePinia } from 'pinia'
import { createI18n } from 'vue-i18n'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import type { NodeError } from '@/schemas/apiSchema'
import LinearControls from '@/renderer/extensions/linearMode/LinearControls.vue'
import { LINEAR_RUN_ERROR_WARNING_DESCRIPTION_ID } from '@/renderer/extensions/linearMode/linearRunErrorWarningIds'
import { useAppModeStore } from '@/stores/appModeStore'
import { useExecutionErrorStore } from '@/stores/executionErrorStore'
import { toNodeId } from '@/types/nodeId'
const billingMock = vi.hoisted(() => ({
isActiveSubscription: true
}))
const overlayMock = vi.hoisted(() => ({
overlayMessage: 'KSampler is missing a required input: model',
overlayTitle: 'Required input missing'
}))
vi.mock('@/composables/billing/useBillingContext', () => ({
useBillingContext: () => ({
isActiveSubscription: billingMock.isActiveSubscription
})
}))
vi.mock('@/components/error/useErrorOverlayState', () => ({
useErrorOverlayState: () => ({
overlayMessage: overlayMock.overlayMessage,
overlayTitle: overlayMock.overlayTitle
})
}))
const i18n = createI18n({
legacy: false,
locale: 'en',
messages: {
en: {
linearMode: {
error: {
goto: 'Show errors in graph'
},
mobileNoWorkflow: 'No workflow',
runCount: 'Run count',
viewJob: 'View job'
},
menu: {
run: 'Run'
},
menuLabels: {
publish: 'Publish'
},
queue: {
jobAddedToQueue: 'Job added to queue',
jobQueueing: 'Queueing'
}
}
}
})
const nodeErrors: Record<string, NodeError> = {
'1': {
class_type: 'TestNode',
dependent_outputs: [],
errors: [
{
type: 'required_input_missing',
message: 'Missing input',
details: '',
extra_info: { input_name: 'prompt' }
}
]
}
}
function renderControls({
hasError = false,
isActiveSubscription = true,
mobile = false
}: {
hasError?: boolean
isActiveSubscription?: boolean
mobile?: boolean
} = {}) {
billingMock.isActiveSubscription = isActiveSubscription
const pinia = createTestingPinia({
createSpy: vi.fn,
stubActions: false
})
setActivePinia(pinia)
useAppModeStore().selectedOutputs = [toNodeId(1)]
if (hasError) {
useExecutionErrorStore().lastNodeErrors = nodeErrors
}
const toastTarget = document.createElement('div')
return render(LinearControls, {
props: { mobile, toastTo: toastTarget },
global: {
plugins: [pinia, i18n],
stubs: {
AppModeWidgetList: true,
Loader: true,
PartnerNodesList: true,
Popover: {
template: '<div><slot name="button" /><slot /></div>'
},
ScrubableNumberInput: true,
SubscribeToRunButton: true
}
}
})
}
describe('LinearControls', () => {
beforeEach(() => {
vi.clearAllMocks()
billingMock.isActiveSubscription = true
overlayMock.overlayMessage = 'KSampler is missing a required input: model'
overlayMock.overlayTitle = 'Required input missing'
})
it.for([
{ label: 'desktop', mobile: false },
{ label: 'mobile', mobile: true }
])('shows a workflow error warning in $label controls', ({ mobile }) => {
renderControls({ hasError: true, mobile })
const warning = screen.getByRole('status')
expect(
within(warning).getByText('Required input missing')
).toBeInTheDocument()
expect(
within(warning).getByText('KSampler is missing a required input: model')
).toBeInTheDocument()
expect(
within(warning).getByRole('button', { name: 'Show errors in graph' })
).toBeInTheDocument()
expect(within(warning).queryByLabelText('Close')).not.toBeInTheDocument()
const runButton = screen.getByRole('button', { name: 'Run' })
expect(runButton).toHaveAttribute(
'aria-describedby',
LINEAR_RUN_ERROR_WARNING_DESCRIPTION_ID
)
const description = screen.getByTestId(
'linear-validation-warning-description'
)
expect(description).toHaveAttribute(
'id',
LINEAR_RUN_ERROR_WARNING_DESCRIPTION_ID
)
expect(description).toHaveTextContent('Required input missing')
expect(description).toHaveTextContent(
'KSampler is missing a required input: model'
)
expect(description).not.toHaveTextContent('Show errors in graph')
})
it.for([
{ label: 'desktop', mobile: false },
{ label: 'mobile', mobile: true }
])(
'does not show the workflow error warning in $label controls without graph errors',
({ mobile }) => {
renderControls({ mobile })
expect(screen.queryByRole('status')).not.toBeInTheDocument()
expect(
screen.queryByRole('button', { name: 'Show errors in graph' })
).not.toBeInTheDocument()
expect(screen.getByRole('button', { name: 'Run' })).not.toHaveAttribute(
'aria-describedby'
)
}
)
it.for([
{ label: 'desktop', mobile: false },
{ label: 'mobile', mobile: true }
])(
'does not show the workflow error warning in $label controls without an active subscription',
({ mobile }) => {
renderControls({
hasError: true,
isActiveSubscription: false,
mobile
})
expect(screen.queryByRole('status')).not.toBeInTheDocument()
}
)
it('does not show the warning when the error copy is empty', () => {
overlayMock.overlayMessage = ''
renderControls({ hasError: true })
expect(screen.queryByRole('status')).not.toBeInTheDocument()
expect(screen.getByRole('button', { name: 'Run' })).not.toHaveAttribute(
'aria-describedby'
)
})
})

View File

@@ -1,10 +1,11 @@
<script setup lang="ts">
import { useTimeout } from '@vueuse/core'
import { storeToRefs } from 'pinia'
import { ref, useTemplateRef } from 'vue'
import { computed, ref, toValue, useTemplateRef } from 'vue'
import { useI18n } from 'vue-i18n'
import AppModeWidgetList from '@/components/builder/AppModeWidgetList.vue'
import { useErrorOverlayState } from '@/components/error/useErrorOverlayState'
import Loader from '@/components/loader/Loader.vue'
import ScrubableNumberInput from '@/components/common/ScrubableNumberInput.vue'
import Popover from '@/components/ui/Popover.vue'
@@ -14,11 +15,15 @@ import SubscribeToRunButton from '@/platform/cloud/subscription/components/Subsc
import { useSettingStore } from '@/platform/settings/settingStore'
import { useTelemetry } from '@/platform/telemetry'
import { useWorkflowStore } from '@/platform/workflow/management/stores/workflowStore'
import LinearRunErrorWarning from '@/renderer/extensions/linearMode/LinearRunErrorWarning.vue'
import { LINEAR_RUN_ERROR_WARNING_DESCRIPTION_ID } from '@/renderer/extensions/linearMode/linearRunErrorWarningIds'
import PartnerNodesList from '@/renderer/extensions/linearMode/PartnerNodesList.vue'
import { useCommandStore } from '@/stores/commandStore'
import { useQueueSettingsStore } from '@/stores/queueStore'
import { useAppMode } from '@/composables/useAppMode'
import { useAppModeStore } from '@/stores/appModeStore'
import { useExecutionErrorStore } from '@/stores/executionErrorStore'
const { t } = useI18n()
const commandStore = useCommandStore()
const { batchCount } = storeToRefs(useQueueSettingsStore())
@@ -28,6 +33,8 @@ const workflowStore = useWorkflowStore()
const { isBuilderMode } = useAppMode()
const appModeStore = useAppModeStore()
const { hasOutputs } = storeToRefs(appModeStore)
const { hasAnyError } = storeToRefs(useExecutionErrorStore())
const { overlayMessage } = useErrorOverlayState()
const { toastTo, mobile } = defineProps<{
toastTo?: string | HTMLElement
@@ -43,6 +50,13 @@ const { ready: jobToastTimeout, start: resetJobToastTimeout } = useTimeout(
{ controls: true, immediate: false }
)
const widgetListRef = useTemplateRef('widgetListRef')
const linearRunButtonTestId = 'linear-run-button'
const showRunErrorWarning = computed(
() =>
hasAnyError.value &&
toValue(isActiveSubscription) &&
toValue(overlayMessage).trim().length > 0
)
//TODO: refactor out of this file.
//code length is small, but changes should propagate
@@ -134,9 +148,10 @@ function handleDragDrop() {
<PartnerNodesList v-if="!mobile" />
<section
v-if="mobile"
data-testid="linear-run-button"
:data-testid="linearRunButtonTestId"
class="border-t border-node-component-border p-4 pb-6"
>
<LinearRunErrorWarning v-if="showRunErrorWarning" />
<SubscribeToRunButton
v-if="!isActiveSubscription"
class="mt-4 w-full"
@@ -166,18 +181,24 @@ function handleDragDrop() {
variant="primary"
class="grow"
size="lg"
:aria-describedby="
showRunErrorWarning
? LINEAR_RUN_ERROR_WARNING_DESCRIPTION_ID
: undefined
"
@click="runButtonClick"
>
<i class="icon-[lucide--play]" />
<i aria-hidden="true" class="icon-[lucide--play]" />
{{ t('menu.run') }}
</Button>
</div>
</section>
<section
v-else
data-testid="linear-run-button"
:data-testid="linearRunButtonTestId"
class="border-t border-node-component-border p-4 pb-6"
>
<LinearRunErrorWarning v-if="showRunErrorWarning" />
<div
class="m-1 mb-2 text-node-component-slot-text"
v-text="t('linearMode.runCount')"
@@ -198,9 +219,14 @@ function handleDragDrop() {
variant="primary"
class="mt-4 w-full text-sm"
size="lg"
:aria-describedby="
showRunErrorWarning
? LINEAR_RUN_ERROR_WARNING_DESCRIPTION_ID
: undefined
"
@click="runButtonClick"
>
<i class="icon-[lucide--play]" />
<i aria-hidden="true" class="icon-[lucide--play]" />
{{ t('menu.run') }}
</Button>
</section>

View File

@@ -0,0 +1,92 @@
import { render, screen } from '@testing-library/vue'
import userEvent from '@testing-library/user-event'
import { createI18n } from 'vue-i18n'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import LinearRunErrorWarning from '@/renderer/extensions/linearMode/LinearRunErrorWarning.vue'
import { LINEAR_RUN_ERROR_WARNING_DESCRIPTION_ID } from '@/renderer/extensions/linearMode/linearRunErrorWarningIds'
const mocks = vi.hoisted(() => ({
overlayMessage: 'KSampler is missing a required input: model',
overlayTitle: 'Required input missing',
viewErrorsInGraph: vi.fn()
}))
vi.mock('@/components/error/useErrorOverlayState', () => ({
useErrorOverlayState: () => ({
overlayMessage: mocks.overlayMessage,
overlayTitle: mocks.overlayTitle
})
}))
vi.mock('@/composables/useViewErrorsInGraph', () => ({
useViewErrorsInGraph: () => ({
viewErrorsInGraph: mocks.viewErrorsInGraph
})
}))
const i18n = createI18n({
legacy: false,
locale: 'en',
messages: {
en: {
linearMode: {
error: {
goto: 'Show errors in graph'
}
}
}
}
})
function renderWarning() {
const user = userEvent.setup()
const result = render(LinearRunErrorWarning, {
global: { plugins: [i18n] }
})
return { ...result, user }
}
describe('LinearRunErrorWarning', () => {
beforeEach(() => {
mocks.viewErrorsInGraph.mockReset()
})
it('shows the current error overlay title and message without a close action', () => {
renderWarning()
const warning = screen.getByRole('status')
expect(warning).toHaveTextContent('Required input missing')
expect(warning).toHaveTextContent(
'KSampler is missing a required input: model'
)
expect(screen.getByText('Required input missing')).toHaveAttribute(
'title',
'Required input missing'
)
const description = screen.getByTestId(
'linear-validation-warning-description'
)
expect(description).toHaveAttribute(
'id',
LINEAR_RUN_ERROR_WARNING_DESCRIPTION_ID
)
expect(description).toHaveTextContent('Required input missing')
expect(description).toHaveTextContent(
'KSampler is missing a required input: model'
)
expect(description).not.toHaveTextContent('Show errors in graph')
expect(screen.queryByLabelText('Close')).not.toBeInTheDocument()
})
it('opens graph errors when the action is clicked', async () => {
const { user } = renderWarning()
await user.click(
screen.getByRole('button', { name: 'Show errors in graph' })
)
expect(mocks.viewErrorsInGraph).toHaveBeenCalledOnce()
})
})

View File

@@ -0,0 +1,63 @@
<script setup lang="ts">
import { useI18n } from 'vue-i18n'
import Button from '@/components/ui/button/Button.vue'
import { useErrorOverlayState } from '@/components/error/useErrorOverlayState'
import { useViewErrorsInGraph } from '@/composables/useViewErrorsInGraph'
import { LINEAR_RUN_ERROR_WARNING_DESCRIPTION_ID } from '@/renderer/extensions/linearMode/linearRunErrorWarningIds'
const { t } = useI18n()
const { viewErrorsInGraph } = useViewErrorsInGraph()
const { overlayMessage, overlayTitle } = useErrorOverlayState()
</script>
<template>
<div
role="status"
data-testid="linear-validation-warning"
class="mb-3 flex w-full flex-col gap-2 overflow-hidden rounded-lg border border-l-4 border-border-default border-l-destructive-background bg-base-background p-3 shadow-interface transition-colors duration-200 ease-in-out"
>
<div
:id="LINEAR_RUN_ERROR_WARNING_DESCRIPTION_ID"
data-testid="linear-validation-warning-description"
class="flex flex-col gap-2"
>
<div class="flex w-full items-start gap-2">
<i
aria-hidden="true"
class="mt-0.5 icon-[lucide--circle-x] size-4 shrink-0 text-destructive-background"
/>
<span
class="min-w-0 flex-1 truncate text-sm text-base-foreground"
:title="overlayTitle"
>
{{ overlayTitle }}
</span>
</div>
<div
class="flex w-full items-start gap-2"
data-testid="linear-validation-warning-message"
>
<span class="size-4 shrink-0" aria-hidden="true" />
<p
class="m-0 line-clamp-3 min-w-0 flex-1 text-sm/snug wrap-break-word whitespace-pre-wrap text-muted-foreground"
>
{{ overlayMessage }}
</p>
</div>
</div>
<div class="flex w-full items-center justify-end pt-2">
<Button
variant="secondary"
size="unset"
class="min-h-8 rounded-lg px-3 py-2 text-xs font-normal"
data-testid="linear-view-errors"
@click="viewErrorsInGraph"
>
{{ t('linearMode.error.goto') }}
</Button>
</div>
</div>
</template>

View File

@@ -0,0 +1,2 @@
export const LINEAR_RUN_ERROR_WARNING_DESCRIPTION_ID =
'linear-run-error-warning'

View File

@@ -1,225 +0,0 @@
import { beforeEach, describe, expect, it, vi } from 'vitest'
import type { VueNodeData } from '@/composables/graph/useGraphNodeManager'
import { LGraphBadge } from '@/lib/litegraph/src/litegraph'
import type { INodeInputSlot } from '@/lib/litegraph/src/interfaces'
import {
trackNodePrice,
usePartitionedBadges
} from '@/renderer/extensions/vueNodes/composables/usePartitionedBadges'
import { toNodeId } from '@/types/nodeId'
import { NodeBadgeMode } from '@/types/nodeSource'
const { settings, nodeDefs, pricing, getNodeRevisionRefMock, getWidgetMock } =
vi.hoisted(() => ({
settings: {} as Record<string, unknown>,
nodeDefs: {} as Record<string, unknown>,
pricing: {
dynamic: false,
widgets: [] as string[],
inputs: [] as string[],
groups: [] as string[]
},
getNodeRevisionRefMock: vi.fn(() => ({ value: 0 })),
getWidgetMock: vi.fn(() => ({ value: 'widget-value' }))
}))
vi.mock('@/scripts/app', () => ({
app: {
canvas: { graph: { getNodeById: () => null, rootGraph: { id: 'g1' } } }
}
}))
vi.mock('@/composables/node/useNodePricing', () => ({
useNodePricing: () => ({
getRelevantWidgetNames: () => pricing.widgets,
hasDynamicPricing: () => pricing.dynamic,
getInputGroupPrefixes: () => pricing.groups,
getInputNames: () => pricing.inputs,
getNodeRevisionRef: getNodeRevisionRefMock
})
}))
vi.mock('@/composables/node/usePriceBadge', () => ({
usePriceBadge: () => ({
isCreditsBadge: (b: { text?: string }) => b.text?.startsWith('$') ?? false
})
}))
vi.mock('@/platform/settings/settingStore', () => ({
useSettingStore: () => ({ get: (key: string) => settings[key] })
}))
vi.mock('@/stores/nodeDefStore', () => ({
useNodeDefStore: () => ({ nodeDefsByName: nodeDefs })
}))
vi.mock('@/stores/widgetValueStore', () => ({
useWidgetValueStore: () => ({ getWidget: getWidgetMock })
}))
function nodeData(overrides: Partial<VueNodeData> = {}): VueNodeData {
return {
executing: false,
id: toNodeId(1),
mode: 0,
selected: false,
title: 'Test node',
type: 'TestNode',
apiNode: false,
badges: [],
inputs: [],
...overrides
} satisfies VueNodeData
}
function inputSlot(
name: string,
readLink: () => number | null
): INodeInputSlot {
return {
name,
type: '*',
boundingRect: [0, 0, 0, 0],
get link() {
return readLink()
},
set link(_value: number | null) {}
} as INodeInputSlot
}
function badge(text: string): LGraphBadge {
return new LGraphBadge({ text })
}
beforeEach(() => {
settings['Comfy.NodeBadge.NodeSourceBadgeMode'] = NodeBadgeMode.None
settings['Comfy.NodeBadge.NodeLifeCycleBadgeMode'] = NodeBadgeMode.None
settings['Comfy.NodeBadge.NodeIdBadgeMode'] = NodeBadgeMode.None
for (const k of Object.keys(nodeDefs)) delete nodeDefs[k]
nodeDefs['TestNode'] = { isCoreNode: false }
pricing.dynamic = false
pricing.widgets = []
pricing.inputs = []
pricing.groups = []
getNodeRevisionRefMock.mockClear()
getWidgetMock.mockClear()
})
describe('usePartitionedBadges', () => {
it('emits no core badges when every badge mode is None', () => {
const result = usePartitionedBadges(nodeData()).value
expect(result.core).toEqual([])
})
it('tracks dynamic-pricing dependencies for an api node without throwing', () => {
pricing.dynamic = true
pricing.widgets = ['seed']
pricing.inputs = ['model']
pricing.groups = ['lora']
const result = usePartitionedBadges(
nodeData({
apiNode: true,
inputs: [
inputSlot('model', () => 1),
inputSlot('lora.0', () => 2),
inputSlot('unrelated', () => null)
]
})
).value
expect(result).toHaveProperty('core')
expect(result).toHaveProperty('extension')
})
it('adds an id badge when the id mode is enabled', () => {
settings['Comfy.NodeBadge.NodeIdBadgeMode'] = NodeBadgeMode.ShowAll
const result = usePartitionedBadges(nodeData({ id: toNodeId(7) })).value
expect(result.core).toContainEqual({ text: '#7' })
})
it('adds a lifecycle badge, trimmed of brackets', () => {
settings['Comfy.NodeBadge.NodeLifeCycleBadgeMode'] = NodeBadgeMode.ShowAll
nodeDefs['TestNode'] = {
isCoreNode: false,
nodeLifeCycleBadgeText: '[BETA]'
}
const result = usePartitionedBadges(nodeData()).value
expect(result.core).toContainEqual({ text: 'BETA' })
})
it('adds a source badge for non-core nodes when source mode is on', () => {
settings['Comfy.NodeBadge.NodeSourceBadgeMode'] = NodeBadgeMode.ShowAll
nodeDefs['TestNode'] = {
isCoreNode: false,
nodeSource: { badgeText: 'my-pack' }
}
const result = usePartitionedBadges(nodeData()).value
expect(result.core).toContainEqual({ text: 'my-pack' })
})
it('partitions extension badges (skipping the first) from credits badges', () => {
const result = usePartitionedBadges(
nodeData({
badges: [badge('skipped'), badge('ext-badge'), badge('$5 per run')]
})
).value
expect(result.extension.map((badge) => badge.text)).toEqual(['ext-badge'])
expect(result.pricing).toEqual([{ required: '$5', rest: 'per run' }])
})
it('flags hasComfyBadge for a core node with source ShowAll and no pricing', () => {
settings['Comfy.NodeBadge.NodeSourceBadgeMode'] = NodeBadgeMode.ShowAll
nodeDefs['TestNode'] = { isCoreNode: true }
const result = usePartitionedBadges(
nodeData({ badges: [badge('x')] })
).value
expect(result.hasComfyBadge).toBe(true)
})
})
describe('trackNodePrice', () => {
it('no-ops for a node without dynamic pricing', () => {
pricing.dynamic = false
trackNodePrice({ id: '1', type: 'Static', inputs: [] })
expect(getNodeRevisionRefMock).toHaveBeenCalledWith(toNodeId('1'))
expect(getWidgetMock).not.toHaveBeenCalled()
})
it('touches widget, input, and input-group pricing dependencies', () => {
pricing.dynamic = true
pricing.widgets = ['seed']
pricing.inputs = ['model']
pricing.groups = ['lora']
let modelReads = 0
let groupReads = 0
let unrelatedReads = 0
trackNodePrice({
id: '2',
type: 'Dynamic',
inputs: [
inputSlot('model', () => {
modelReads += 1
return 1
}),
inputSlot('lora.0', () => {
groupReads += 1
return 2
}),
inputSlot('unrelated', () => {
unrelatedReads += 1
return null
})
]
})
expect(getNodeRevisionRefMock).toHaveBeenCalledWith(toNodeId('2'))
expect(getWidgetMock).toHaveBeenCalled()
expect(modelReads).toBe(1)
expect(groupReads).toBe(1)
expect(unrelatedReads).toBe(0)
})
})

View File

@@ -0,0 +1,57 @@
import { describe, expect, it } from 'vitest'
import {
getComboSpecComboOptions,
getInputSpecType,
isComboInputSpec,
isComboInputSpecV1,
isComboInputSpecV2,
isFloatInputSpec,
isIntInputSpec,
isMediaUploadComboInput
} from './nodeDefSchema'
import type {
ComboInputSpec,
ComboInputSpecV2,
InputSpec
} from './nodeDefSchema'
describe('node definition schema helpers', () => {
it('identifies input spec variants', () => {
const intSpec: InputSpec = ['INT', {}]
const floatSpec: InputSpec = ['FLOAT', {}]
const comboV1: ComboInputSpec = [['a', 'b'], {}]
const comboV2: ComboInputSpecV2 = ['COMBO', { options: ['a', 'b'] }]
expect(isIntInputSpec(intSpec)).toBe(true)
expect(isFloatInputSpec(floatSpec)).toBe(true)
expect(isComboInputSpecV1(comboV1)).toBe(true)
expect(isComboInputSpecV2(comboV2)).toBe(true)
expect(isComboInputSpec(comboV1)).toBe(true)
expect(isComboInputSpec(comboV2)).toBe(true)
expect(getInputSpecType(comboV1)).toBe('COMBO')
expect(getInputSpecType(intSpec)).toBe('INT')
})
it('reads combo options from legacy and v2 combo specs', () => {
expect(getComboSpecComboOptions([['a', 1], {}])).toEqual(['a', 1])
expect(
getComboSpecComboOptions(['COMBO', { options: ['x', 'y'] }])
).toEqual(['x', 'y'])
expect(getComboSpecComboOptions(['COMBO', {}])).toEqual([])
})
it('detects media upload combo inputs', () => {
expect(isMediaUploadComboInput([['a'], { image_upload: true }])).toBe(true)
expect(
isMediaUploadComboInput(['COMBO', { animated_image_upload: true }])
).toBe(true)
expect(isMediaUploadComboInput(['COMBO', { video_upload: true }])).toBe(
true
)
expect(isMediaUploadComboInput(['STRING', { image_upload: true }])).toBe(
false
)
expect(isMediaUploadComboInput(['COMBO', undefined])).toBe(false)
})
})

View File

@@ -0,0 +1,149 @@
import { beforeEach, describe, expect, it, vi } from 'vitest'
import {
fetchWithUnifiedRemint,
shouldRemintCloudRequest
} from '@/platform/auth/unified/remintRetry'
const { mockAuthStore } = vi.hoisted(() => ({
mockAuthStore: {
isInitialized: true,
getAuthHeader: vi.fn(),
getAuthToken: vi.fn()
}
}))
vi.mock('@/platform/distribution/types', () => ({ isCloud: true }))
vi.mock('@/stores/authStore', () => ({
useAuthStore: vi.fn(() => mockAuthStore)
}))
vi.mock('@/platform/auth/unified/remintRetry', () => ({
fetchWithUnifiedRemint: vi.fn(),
shouldRemintCloudRequest: vi.fn()
}))
class FakeWebSocket extends EventTarget {
static instances: FakeWebSocket[] = []
binaryType = ''
sent: string[] = []
constructor(readonly url: string) {
super()
FakeWebSocket.instances.push(this)
}
send(data: string) {
this.sent.push(data)
}
close() {
this.dispatchEvent(new Event('close'))
}
}
const { ComfyApi } = await import('./api')
describe('ComfyApi cloud mode', () => {
beforeEach(() => {
vi.clearAllMocks()
vi.unstubAllGlobals()
FakeWebSocket.instances = []
window.name = ''
sessionStorage.clear()
mockAuthStore.isInitialized = true
mockAuthStore.getAuthHeader.mockResolvedValue(null)
mockAuthStore.getAuthToken.mockResolvedValue(null)
vi.mocked(shouldRemintCloudRequest).mockResolvedValue(false)
vi.mocked(fetchWithUnifiedRemint).mockResolvedValue(
new Response(JSON.stringify({ ok: true }), {
headers: { 'Content-Type': 'application/json' }
})
)
vi.stubGlobal('WebSocket', FakeWebSocket as unknown as typeof WebSocket)
})
it('adds cloud auth headers and enables unified retry for authenticated requests', async () => {
mockAuthStore.getAuthHeader.mockResolvedValue({
Authorization: 'Bearer firebase-token'
})
vi.mocked(shouldRemintCloudRequest).mockResolvedValue(true)
const api = new ComfyApi()
api.user = 'cloud-user'
await api.fetchApi('/queue')
expect(api.api_base).toBe('')
expect(fetchWithUnifiedRemint).toHaveBeenCalledWith(
'/api/queue',
expect.objectContaining({
cache: 'no-cache',
headers: {
Authorization: 'Bearer firebase-token',
'Comfy-User': 'cloud-user'
}
}),
true
)
})
it('continues cloud fetches when auth header lookup fails', async () => {
mockAuthStore.getAuthHeader.mockRejectedValue(new Error('auth unavailable'))
const warn = vi.spyOn(console, 'warn').mockImplementation(() => {})
const api = new ComfyApi()
await api.fetchApi('/history', {
headers: [['X-Test', '1']]
})
const [, options, retryOn401] = vi.mocked(fetchWithUnifiedRemint).mock
.calls[0]
expect(options.headers).toEqual([
['X-Test', '1'],
['Comfy-User', '']
])
expect(retryOn401).toBe(false)
expect(shouldRemintCloudRequest).not.toHaveBeenCalled()
expect(warn).toHaveBeenCalledWith(
'Failed to get auth header:',
expect.any(Error)
)
})
it('adds the cloud auth token to websocket URLs', async () => {
mockAuthStore.getAuthToken.mockResolvedValue('socket-token')
window.name = 'client-1'
const api = new ComfyApi()
api.init()
await vi.waitFor(() => {
expect(FakeWebSocket.instances).toHaveLength(1)
})
const socket = FakeWebSocket.instances[0]
expect(socket.url).toContain('clientId=client-1')
expect(socket.url).toContain('token=socket-token')
})
it('opens a cloud websocket without a token when token lookup fails', async () => {
mockAuthStore.getAuthToken.mockRejectedValue(new Error('token unavailable'))
const warn = vi.spyOn(console, 'warn').mockImplementation(() => {})
const api = new ComfyApi()
api.init()
await vi.waitFor(() => {
expect(FakeWebSocket.instances).toHaveLength(1)
})
const socket = FakeWebSocket.instances[0]
expect(socket.url).not.toContain('token=')
expect(warn).toHaveBeenCalledWith(
'Could not get auth token for WebSocket connection:',
expect.any(Error)
)
})
})

View File

@@ -0,0 +1,460 @@
import axios from 'axios'
import { fromPartial } from '@total-typescript/shoehorn'
import { afterEach, describe, expect, it, vi } from 'vitest'
import type { JobListItem } from '@/platform/remote/comfyui/jobs/jobTypes'
import type {
ComfyApiWorkflow,
ComfyWorkflowJSON
} from '@/platform/workflow/validation/schemas/workflowSchema'
import type { PromptResponse } from '@/schemas/apiSchema'
import { api as sharedApi, ComfyApi, PromptExecutionError } from '@/scripts/api'
import type { NodeExecutionId } from '@/types/nodeIdentification'
const fetchJobs = vi.hoisted(() => ({
fetchHistory: vi.fn(),
fetchJobDetail: vi.fn(),
fetchQueue: vi.fn()
}))
vi.mock('axios')
vi.mock('@/platform/remote/comfyui/jobs/fetchJobs', () => fetchJobs)
afterEach(() => {
vi.restoreAllMocks()
vi.clearAllMocks()
})
function jsonResponse(data: unknown, init: ResponseInit = {}) {
return new Response(JSON.stringify(data), {
status: 200,
headers: { 'content-type': 'application/json' },
...init
})
}
function promptData(): {
output: ComfyApiWorkflow
workflow: ComfyWorkflowJSON
} {
return {
output: fromPartial<ComfyApiWorkflow>({
1: {
inputs: {},
class_type: 'KSampler',
_meta: { title: 'KSampler' }
}
}),
workflow: fromPartial<ComfyWorkflowJSON>({
version: 0.4,
nodes: [],
links: []
})
}
}
function requestBody(fetchApi: ReturnType<typeof vi.spyOn>, call = 0) {
const init = fetchApi.mock.calls[call][1]
return JSON.parse(String(init?.body)) as Record<string, unknown>
}
describe('PromptExecutionError', () => {
it('formats string and node-specific prompt errors', () => {
const response = fromPartial<PromptResponse>({
error: 'invalid prompt',
node_errors: {
7: {
class_type: 'KSampler',
dependent_outputs: [],
errors: [{ message: 'bad seed', details: 'seed must be numeric' }]
}
}
})
expect(new PromptExecutionError(response, 400).toString()).toBe(
'invalid prompt\nKSampler:\n - bad seed: seed must be numeric'
)
})
it('formats structured prompt errors without node errors', () => {
const response = fromPartial<PromptResponse>({
error: {
type: 'prompt_outputs_failed_validation',
message: 'Validation failed',
details: 'missing node'
}
})
expect(new PromptExecutionError(response).toString()).toBe(
'Validation failed: missing node'
)
})
})
describe('ComfyApi queuePrompt', () => {
it('sends front queue requests with auth and execution options', async () => {
const api = new ComfyApi()
api.clientId = 'client-1'
api.authToken = 'auth-token'
api.apiKey = 'api-key'
const fetchApi = vi
.spyOn(api, 'fetchApi')
.mockResolvedValue(jsonResponse({ prompt_id: 'queued' }))
await api.queuePrompt(-1, promptData(), {
partialExecutionTargets: ['9:10' as NodeExecutionId],
previewMethod: 'auto'
})
expect(fetchApi).toHaveBeenCalledWith('/prompt', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: fetchApi.mock.calls[0][1]?.body
})
expect(typeof fetchApi.mock.calls[0][1]?.body).toBe('string')
expect(requestBody(fetchApi)).toMatchObject({
client_id: 'client-1',
front: true,
partial_execution_targets: ['9:10'],
extra_data: {
auth_token_comfy_org: 'auth-token',
api_key_comfy_org: 'api-key',
comfy_usage_source: 'comfyui-frontend',
preview_method: 'auto'
}
})
expect(requestBody(fetchApi)).not.toHaveProperty('number')
})
it('omits default-only queue options and sets explicit queue numbers', async () => {
const api = new ComfyApi()
const fetchApi = vi
.spyOn(api, 'fetchApi')
.mockImplementation(() =>
Promise.resolve(jsonResponse({ prompt_id: 'queued' }))
)
await api.queuePrompt(0, promptData(), { previewMethod: 'default' })
await api.queuePrompt(4, promptData())
expect(requestBody(fetchApi, 0)).toMatchObject({ client_id: '' })
expect(requestBody(fetchApi, 0)).not.toHaveProperty('front')
expect(requestBody(fetchApi, 0)).not.toHaveProperty('number')
expect(
requestBody(fetchApi, 0).extra_data as Record<string, unknown>
).not.toHaveProperty('preview_method')
expect(requestBody(fetchApi, 1)).toMatchObject({ number: 4 })
})
it('throws parsed prompt errors from non-200 responses', async () => {
const api = new ComfyApi()
vi.spyOn(api, 'fetchApi').mockResolvedValue(
jsonResponse(
{
error: {
type: 'server_error',
message: 'Server rejected prompt',
details: 'bad output'
}
},
{ status: 400, statusText: 'Bad Request' }
)
)
await expect(api.queuePrompt(0, promptData())).rejects.toThrow(
'Prompt execution failed'
)
})
it('wraps non-json prompt errors with status details', async () => {
const api = new ComfyApi()
vi.spyOn(api, 'fetchApi').mockResolvedValue(
new Response('backend exploded', {
status: 500,
statusText: 'Server Error'
})
)
await expect(api.queuePrompt(0, promptData())).rejects.toMatchObject({
status: 500,
response: {
error: {
message: '500 Server Error',
details: 'backend exploded'
}
}
})
})
})
describe('ComfyApi read helpers', () => {
it('returns localized templates, default templates, and empty non-json responses', async () => {
const api = new ComfyApi()
vi.mocked(axios.get)
.mockResolvedValueOnce({
headers: { 'content-type': 'application/json' },
data: [{ name: 'localized' }]
})
.mockResolvedValueOnce({
headers: { 'content-type': 'text/html' },
data: '<html></html>'
})
await expect(api.getCoreWorkflowTemplates('fr')).resolves.toEqual([
{ name: 'localized' }
])
await expect(api.getCoreWorkflowTemplates()).resolves.toEqual([])
expect(vi.mocked(axios.get).mock.calls[0][0]).toContain(
'/templates/index.fr.json'
)
expect(vi.mocked(axios.get).mock.calls[1][0]).toContain(
'/templates/index.json'
)
})
it('falls back from missing localized templates to the default index', async () => {
const api = new ComfyApi()
vi.spyOn(console, 'warn').mockImplementation(() => {})
vi.mocked(axios.get)
.mockRejectedValueOnce(new Error('missing locale'))
.mockResolvedValueOnce({
headers: { 'content-type': 'application/json' },
data: [{ name: 'default' }]
})
await expect(api.getCoreWorkflowTemplates('ja')).resolves.toEqual([
{ name: 'default' }
])
expect(vi.mocked(axios.get).mock.calls[1][0]).toContain(
'/templates/index.json'
)
})
it('returns empty model lists for 404s and filters internal folders', async () => {
const api = new ComfyApi()
const fetchApi = vi
.spyOn(api, 'fetchApi')
.mockResolvedValueOnce(jsonResponse([], { status: 404 }))
.mockResolvedValueOnce(
jsonResponse([
{ name: 'checkpoints' },
{ name: 'configs' },
{ name: 'custom_nodes' }
])
)
.mockResolvedValueOnce(jsonResponse([], { status: 404 }))
.mockResolvedValueOnce(jsonResponse(['model.safetensors']))
await expect(api.getModelFolders()).resolves.toEqual([])
await expect(api.getModelFolders()).resolves.toEqual([
{ name: 'checkpoints' }
])
await expect(api.getModels('checkpoints')).resolves.toEqual([])
await expect(api.getModels('checkpoints')).resolves.toEqual([
'model.safetensors'
])
expect(fetchApi).toHaveBeenCalledTimes(4)
})
it('handles model metadata text responses', async () => {
const api = new ComfyApi()
vi.spyOn(console, 'error').mockImplementation(() => {})
vi.spyOn(api, 'fetchApi')
.mockResolvedValueOnce(new Response(''))
.mockResolvedValueOnce(new Response('{"format":"safetensors"}'))
.mockResolvedValueOnce(
new Response('not json', { status: 200, statusText: 'OK' })
)
await expect(
api.viewMetadata('checkpoints', 'a.safetensors')
).resolves.toBe(null)
await expect(
api.viewMetadata('checkpoints', 'a.safetensors')
).resolves.toEqual({ format: 'safetensors' })
await expect(
api.viewMetadata('checkpoints', 'a.safetensors')
).resolves.toBe(null)
})
it('gets fuse options only from json responses', async () => {
const api = new ComfyApi()
vi.mocked(axios.get)
.mockResolvedValueOnce({
headers: { 'content-type': 'application/json' },
data: { keys: ['name'] }
})
.mockResolvedValueOnce({
headers: { 'content-type': 'text/plain' },
data: 'nope'
})
vi.spyOn(console, 'error').mockImplementation(() => {})
await expect(api.getFuseOptions()).resolves.toEqual({ keys: ['name'] })
await expect(api.getFuseOptions()).resolves.toBeNull()
vi.mocked(axios.get).mockRejectedValueOnce(new Error('missing'))
await expect(api.getFuseOptions()).resolves.toBeNull()
})
})
describe('ComfyApi queue and data helpers', () => {
it('routes item collection requests to queue or history', async () => {
const api = new ComfyApi()
const queue = vi.spyOn(api, 'getQueue').mockResolvedValue({
Running: [],
Pending: []
})
const historyItem = fromPartial<JobListItem>({
id: 'history-1',
status: 'completed',
create_time: 1,
priority: 0
})
const history = vi.spyOn(api, 'getHistory').mockResolvedValue([historyItem])
await expect(api.getItems('queue')).resolves.toEqual({
Running: [],
Pending: []
})
await expect(api.getItems('history')).resolves.toEqual([historyItem])
expect(queue).toHaveBeenCalledOnce()
expect(history).toHaveBeenCalledOnce()
})
it('returns queue fallbacks unless errors are requested', async () => {
const api = new ComfyApi()
vi.spyOn(console, 'error').mockImplementation(() => {})
fetchJobs.fetchQueue.mockRejectedValue(new Error('network'))
await expect(api.getQueue()).resolves.toEqual({ Running: [], Pending: [] })
await expect(api.getQueue({ throwOnError: true })).rejects.toThrow(
'network'
)
})
it('returns empty history when fetchHistory fails', async () => {
const api = new ComfyApi()
vi.spyOn(console, 'error').mockImplementation(() => {})
fetchJobs.fetchHistory.mockRejectedValue(new Error('history down'))
await expect(api.getHistory()).resolves.toEqual([])
})
it('posts item mutations with and without request bodies', async () => {
const api = new ComfyApi()
const fetchApi = vi.spyOn(api, 'fetchApi').mockResolvedValue(new Response())
await api.deleteItem('history', 'job-1')
await api.clearItems('queue')
await api.interrupt(null)
await api.interrupt('running-1')
expect(fetchApi.mock.calls.map((call) => call[0])).toEqual([
'/history',
'/queue',
'/interrupt',
'/interrupt'
])
expect(fetchApi.mock.calls[0][1]?.body).toBe(
JSON.stringify({ delete: ['job-1'] })
)
expect(fetchApi.mock.calls[1][1]?.body).toBe(
JSON.stringify({ clear: true })
)
expect(fetchApi.mock.calls[2][1]?.body).toBeUndefined()
expect(fetchApi.mock.calls[3][1]?.body).toBe(
JSON.stringify({ prompt_id: 'running-1' })
)
})
it('throws unauthorized settings responses', async () => {
const api = new ComfyApi()
vi.spyOn(api, 'fetchApi').mockResolvedValue(
new Response('', { status: 401, statusText: 'Unauthorized' })
)
await expect(api.getSettings()).rejects.toThrow('Unauthorized')
})
it('stores user data with default and raw-body options', async () => {
const api = new ComfyApi()
const fetchApi = vi
.spyOn(api, 'fetchApi')
.mockResolvedValue(new Response('', { status: 200 }))
const raw = new Blob(['raw'])
await api.storeUserData('a/b.json', { ok: true })
await api.storeUserData('raw.bin', raw, {
overwrite: false,
stringify: false,
throwOnError: false,
full_info: true
})
expect(fetchApi.mock.calls[0][0]).toBe(
'/userdata/a%2Fb.json?overwrite=true&full_info=false'
)
expect(fetchApi.mock.calls[0][1]?.body).toBe(JSON.stringify({ ok: true }))
expect(fetchApi.mock.calls[1][0]).toBe(
'/userdata/raw.bin?overwrite=false&full_info=true'
)
expect(fetchApi.mock.calls[1][1]?.body).toBe(raw)
})
it('honors storeUserData throwOnError', async () => {
const api = new ComfyApi()
vi.spyOn(api, 'fetchApi')
.mockResolvedValueOnce(
new Response('', { status: 500, statusText: 'Server Error' })
)
.mockResolvedValueOnce(
new Response('', { status: 500, statusText: 'Server Error' })
)
await expect(api.storeUserData('bad.json', {})).rejects.toThrow(
"Error storing user data file 'bad.json': 500 Server Error"
)
await expect(
api.storeUserData('bad.json', {}, { throwOnError: false })
).resolves.toHaveProperty('status', 500)
})
it('lists full user data info by status', async () => {
const api = new ComfyApi()
vi.spyOn(api, 'fetchApi')
.mockResolvedValueOnce(jsonResponse([], { status: 404 }))
.mockResolvedValueOnce(
new Response('', { status: 500, statusText: 'Server Error' })
)
.mockResolvedValueOnce(jsonResponse([{ path: 'x' }]))
await expect(api.listUserDataFullInfo('models/')).resolves.toEqual([])
await expect(api.listUserDataFullInfo('models/')).rejects.toThrow(
"Error getting user data list 'models': 500 Server Error"
)
await expect(api.listUserDataFullInfo('models/')).resolves.toEqual([
{ path: 'x' }
])
})
it('loads global subgraph records and deferred data', async () => {
const fetchApi = vi
.spyOn(sharedApi, 'fetchApi')
.mockResolvedValueOnce(jsonResponse({}, { status: 500 }))
.mockResolvedValueOnce(
jsonResponse({
ready: { name: 'Ready', info: { node_pack: 'core' }, data: '{}' },
lazy: { name: 'Lazy', info: { node_pack: 'core' } }
})
)
.mockResolvedValueOnce(jsonResponse({ data: '{"lazy":true}' }))
await expect(sharedApi.getGlobalSubgraphs()).resolves.toEqual({})
const subgraphs = await sharedApi.getGlobalSubgraphs()
expect(subgraphs.ready.data).toBe('{}')
expect(subgraphs.lazy.data).toBeInstanceOf(Promise)
await expect(subgraphs.lazy.data).resolves.toBe('{"lazy":true}')
expect(fetchApi).toHaveBeenCalledTimes(3)
})
})

829
src/scripts/api.test.ts Normal file
View File

@@ -0,0 +1,829 @@
import { fromAny } from '@total-typescript/shoehorn'
import axios from 'axios'
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
import {
api as singletonApi,
ComfyApi,
PromptExecutionError,
UnauthorizedError
} from '@/scripts/api'
import type { ComfyApiWorkflow } from '@/platform/workflow/validation/schemas/workflowSchema'
import type { NodeExecutionId } from '@/types/nodeIdentification'
import { fetchWithUnifiedRemint } from '@/platform/auth/unified/remintRetry'
const { mockToastStore } = vi.hoisted(() => ({
mockToastStore: {
add: vi.fn()
}
}))
vi.mock('@/platform/auth/unified/remintRetry', () => ({
fetchWithUnifiedRemint: vi.fn(),
shouldRemintCloudRequest: vi.fn()
}))
vi.mock('@/platform/updates/common/toastStore', () => ({
useToastStore: vi.fn(() => mockToastStore)
}))
vi.mock('axios', () => ({
default: {
get: vi.fn(),
patch: vi.fn()
}
}))
class FakeWebSocket extends EventTarget {
static instances: FakeWebSocket[] = []
binaryType = ''
sent: string[] = []
constructor(readonly url: string) {
super()
FakeWebSocket.instances.push(this)
}
send(data: string) {
this.sent.push(data)
}
close() {
this.dispatchEvent(new Event('close'))
}
}
function jsonResponse(data: unknown, init: ResponseInit = {}) {
return new Response(JSON.stringify(data), {
status: 200,
headers: { 'Content-Type': 'application/json' },
...init
})
}
function createWorkflow() {
return {
last_node_id: 0,
last_link_id: 0,
nodes: [],
links: [],
groups: [],
config: {},
extra: {},
version: 0.4
}
}
function binaryMessage(type: number, payload: Uint8Array) {
const bytes = new Uint8Array(4 + payload.length)
new DataView(bytes.buffer).setUint32(0, type)
bytes.set(payload, 4)
return bytes.buffer
}
function uint32(value: number) {
const bytes = new Uint8Array(4)
new DataView(bytes.buffer).setUint32(0, value)
return bytes
}
function concatBytes(...chunks: Uint8Array[]) {
const totalLength = chunks.reduce((sum, chunk) => sum + chunk.length, 0)
const result = new Uint8Array(totalLength)
let offset = 0
for (const chunk of chunks) {
result.set(chunk, offset)
offset += chunk.length
}
return result
}
describe('PromptExecutionError', () => {
it('formats string, structured, and node-level errors', () => {
expect(
new PromptExecutionError({
error: 'Queue rejected',
node_errors: {}
}).toString()
).toBe('Queue rejected')
expect(
new PromptExecutionError({
error: {
type: 'invalid_prompt',
message: 'Invalid prompt',
details: 'missing input'
},
node_errors: {
1: {
class_type: 'PreviewAny',
dependent_outputs: ['1'],
errors: [
{
type: 'required_input_missing',
message: 'Required input',
details: 'source'
}
]
}
}
}).toString()
).toContain('Invalid prompt: missing input\nPreviewAny:')
})
})
describe('ComfyApi', () => {
beforeEach(() => {
vi.clearAllMocks()
FakeWebSocket.instances = []
window.name = ''
sessionStorage.clear()
vi.stubGlobal('WebSocket', FakeWebSocket as unknown as typeof WebSocket)
})
afterEach(() => {
vi.restoreAllMocks()
vi.unstubAllGlobals()
})
it('builds API, internal, file, and fetch URLs with user headers', async () => {
const api = new ComfyApi()
api.user = 'reviewer'
vi.mocked(fetchWithUnifiedRemint).mockResolvedValue(
jsonResponse({ ok: true })
)
await api.fetchApi('/queue', {
headers: new Headers([['X-Test', '1']])
})
expect(api.apiURL('/api/custom')).toBe(`${api.api_base}/api/custom`)
expect(api.apiURL('/queue')).toBe(`${api.api_base}/api/queue`)
expect(api.internalURL('/logs')).toBe(`${api.api_base}/internal/logs`)
expect(api.fileURL('/view')).toBe(`${api.api_base}/view`)
const [, options] = vi.mocked(fetchWithUnifiedRemint).mock.calls[0]
expect(options.headers).toBeInstanceOf(Headers)
expect((options.headers as Headers).get('Comfy-User')).toBe('reviewer')
expect((options.headers as Headers).get('X-Test')).toBe('1')
})
it('guards event listeners and still allows removing them', async () => {
const api = new ComfyApi()
const warn = vi.spyOn(console, 'warn').mockImplementation(() => undefined)
const listener = vi.fn()
const throwingListener = vi.fn(() => {
throw new Error('listener failed')
})
const asyncListener = vi.fn(() => Promise.reject(new Error('async failed')))
const objectListener = { handleEvent: vi.fn() }
api.addEventListener('status', null)
api.removeEventListener('status', null)
api.addEventListener('status', listener)
api.addEventListener('status', throwingListener)
api.addEventListener('status', asyncListener)
api.addEventListener('status', fromAny(objectListener))
api.dispatchCustomEvent('status', { exec_info: { queue_remaining: 1 } })
await Promise.resolve()
expect(listener).toHaveBeenCalled()
expect(throwingListener).toHaveBeenCalled()
expect(asyncListener).toHaveBeenCalled()
expect(objectListener.handleEvent).toHaveBeenCalled()
expect(warn).toHaveBeenCalledTimes(2)
api.removeEventListener('status', listener)
api.dispatchCustomEvent('status', null)
expect(listener).toHaveBeenCalledTimes(1)
})
it('reuses guarded listener wrappers and ignores unknown removals', () => {
const api = new ComfyApi()
const listener = vi.fn()
const neverRegistered = vi.fn()
api.addEventListener('status', listener)
api.addEventListener('status', listener)
api.removeEventListener('status', neverRegistered)
api.dispatchCustomEvent('status', null)
expect(listener).toHaveBeenCalledTimes(1)
expect(neverRegistered).not.toHaveBeenCalled()
})
it('supports guarded custom event listeners', () => {
const api = new ComfyApi()
const listener = vi.fn()
api.addCustomEventListener('custom-node-event', listener)
;(api as EventTarget).dispatchEvent(
new CustomEvent('custom-node-event', { detail: { ok: true } })
)
api.removeCustomEventListener('custom-node-event', listener)
;(api as EventTarget).dispatchEvent(
new CustomEvent('custom-node-event', { detail: { ok: false } })
)
expect(listener).toHaveBeenCalledTimes(1)
})
it('routes websocket JSON messages and custom registered messages', () => {
window.name = 'existing-client'
const api = new ComfyApi()
const status = vi.fn()
const executing = vi.fn()
const featureFlags = vi.fn()
const custom = vi.fn()
const warn = vi.spyOn(console, 'warn').mockImplementation(() => undefined)
vi.spyOn(console, 'log').mockImplementation(() => undefined)
api.addEventListener('status', status)
api.addEventListener('executing', executing)
api.addEventListener('feature_flags', featureFlags)
api.addCustomEventListener('custom-message', custom)
api.init()
const socket = FakeWebSocket.instances[0]
socket.dispatchEvent(new Event('open'))
expect(socket.url).toContain('clientId=existing-client')
expect(JSON.parse(socket.sent[0])).toMatchObject({
type: 'feature_flags'
})
socket.dispatchEvent(
new MessageEvent('message', {
data: JSON.stringify({
type: 'status',
data: {
sid: 'fresh-client',
status: { exec_info: { queue_remaining: 2 } }
}
})
})
)
socket.dispatchEvent(
new MessageEvent('message', {
data: JSON.stringify({
type: 'executing',
data: { node: '12' }
})
})
)
socket.dispatchEvent(
new MessageEvent('message', {
data: JSON.stringify({
type: 'feature_flags',
data: { supports_progress_text_metadata: true }
})
})
)
socket.dispatchEvent(
new MessageEvent('message', {
data: JSON.stringify({
type: 'custom-message',
data: { from: 'extension' }
})
})
)
socket.dispatchEvent(
new MessageEvent('message', {
data: JSON.stringify({
type: 'unknown-message',
data: {}
})
})
)
socket.dispatchEvent(
new MessageEvent('message', {
data: JSON.stringify({
type: 'unknown-message',
data: {}
})
})
)
expect(api.clientId).toBe('fresh-client')
expect(window.name).toBe('fresh-client')
expect(sessionStorage.getItem('clientId')).toBe('fresh-client')
expect(status).toHaveBeenCalledWith(
expect.objectContaining({
detail: { exec_info: { queue_remaining: 2 } }
})
)
expect(executing).toHaveBeenCalledWith(
expect.objectContaining({ detail: '12' })
)
expect(featureFlags).toHaveBeenCalled()
expect(api.serverSupportsFeature('supports_progress_text_metadata')).toBe(
true
)
expect(custom).toHaveBeenCalledWith(
expect.objectContaining({ detail: { from: 'extension' } })
)
expect(api.reportedUnknownMessageTypes.has('unknown-message')).toBe(true)
expect(warn).toHaveBeenCalledTimes(1)
})
it('polls status when the initial websocket connection fails', async () => {
vi.useFakeTimers()
const api = new ComfyApi()
const status = vi.fn()
vi.spyOn(api, 'fetchApi')
.mockResolvedValueOnce(
jsonResponse({ exec_info: { queue_remaining: 4 } })
)
.mockRejectedValueOnce(new Error('poll failed'))
api.addEventListener('status', status)
api.init()
FakeWebSocket.instances[0].dispatchEvent(new Event('error'))
await vi.advanceTimersByTimeAsync(1000)
await vi.advanceTimersByTimeAsync(1000)
expect(status).toHaveBeenCalledWith(
expect.objectContaining({
detail: { exec_info: { queue_remaining: 4 } }
})
)
expect(status).toHaveBeenCalledWith(
expect.objectContaining({ detail: null })
)
vi.useRealTimers()
})
it('emits reconnect lifecycle events after an opened websocket closes', async () => {
vi.useFakeTimers()
const api = new ComfyApi()
const status = vi.fn()
const reconnecting = vi.fn()
const reconnected = vi.fn()
api.addEventListener('status', status)
api.addEventListener('reconnecting', reconnecting)
api.addEventListener('reconnected', reconnected)
api.init()
const socket = FakeWebSocket.instances[0]
socket.dispatchEvent(new Event('open'))
socket.close()
expect(status).toHaveBeenCalledWith(
expect.objectContaining({ detail: null })
)
expect(reconnecting).toHaveBeenCalledOnce()
await vi.advanceTimersByTimeAsync(300)
const reconnectSocket = FakeWebSocket.instances[1]
reconnectSocket.dispatchEvent(new Event('open'))
expect(reconnected).toHaveBeenCalledOnce()
vi.useRealTimers()
})
it('routes websocket variants without session ids and display-node fallbacks', () => {
const api = new ComfyApi()
const status = vi.fn()
const executing = vi.fn()
api.addEventListener('status', status)
api.addEventListener('executing', executing)
api.init()
const socket = FakeWebSocket.instances[0]
socket.dispatchEvent(
new MessageEvent('message', {
data: JSON.stringify({
type: 'status',
data: { status: undefined }
})
})
)
socket.dispatchEvent(
new MessageEvent('message', {
data: JSON.stringify({
type: 'executing',
data: { node: 'real', display_node: 'display' }
})
})
)
expect(status).toHaveBeenCalledWith(
expect.objectContaining({ detail: null })
)
expect(executing).toHaveBeenCalledWith(
expect.objectContaining({ detail: 'display' })
)
})
it('routes binary preview and progress websocket messages', () => {
const api = new ComfyApi()
const preview = vi.fn()
const previewWithMetadata = vi.fn()
const progressText = vi.fn()
const encoder = new TextEncoder()
api.serverFeatureFlags.value = {
supports_progress_text_metadata: true
}
api.addEventListener('b_preview', preview)
api.addEventListener('b_preview_with_metadata', previewWithMetadata)
api.addEventListener('progress_text', progressText)
api.init()
const socket = FakeWebSocket.instances[0]
socket.dispatchEvent(
new MessageEvent('message', {
data: binaryMessage(1, concatBytes(uint32(2), new Uint8Array([1, 2])))
})
)
const promptId = encoder.encode('prompt-1')
const nodeId = encoder.encode('7')
const progressPayload = concatBytes(
uint32(promptId.length),
promptId,
uint32(nodeId.length),
nodeId,
encoder.encode('loading')
)
socket.dispatchEvent(
new MessageEvent('message', {
data: binaryMessage(3, progressPayload)
})
)
const metadata = encoder.encode(
JSON.stringify({
image_type: 'image/webp',
node_id: '7',
display_node_id: '7',
parent_node_id: '4',
real_node_id: '7',
prompt_id: 'prompt-1'
})
)
socket.dispatchEvent(
new MessageEvent('message', {
data: binaryMessage(
4,
concatBytes(uint32(metadata.length), metadata, new Uint8Array([9]))
)
})
)
expect(preview).toHaveBeenCalledWith(
expect.objectContaining({
detail: expect.objectContaining({ type: 'image/png' })
})
)
expect(progressText).toHaveBeenCalledWith(
expect.objectContaining({
detail: {
nodeId: '7',
text: 'loading',
prompt_id: 'prompt-1'
}
})
)
expect(previewWithMetadata).toHaveBeenCalledWith(
expect.objectContaining({
detail: expect.objectContaining({
nodeId: '7',
parentNodeId: '4',
jobId: 'prompt-1'
})
})
)
})
it('routes binary jpeg/default previews and malformed binary messages defensively', () => {
const api = new ComfyApi()
const preview = vi.fn()
const progressText = vi.fn()
const warn = vi.spyOn(console, 'warn').mockImplementation(() => undefined)
const encoder = new TextEncoder()
api.addEventListener('b_preview', preview)
api.addEventListener('progress_text', progressText)
api.init()
const socket = FakeWebSocket.instances[0]
socket.dispatchEvent(
new MessageEvent('message', {
data: binaryMessage(1, concatBytes(uint32(1), new Uint8Array([1])))
})
)
socket.dispatchEvent(
new MessageEvent('message', {
data: binaryMessage(1, concatBytes(uint32(99), new Uint8Array([2])))
})
)
const nodeId = encoder.encode('node')
socket.dispatchEvent(
new MessageEvent('message', {
data: binaryMessage(
3,
concatBytes(uint32(nodeId.length), nodeId, encoder.encode('ready'))
)
})
)
socket.dispatchEvent(
new MessageEvent('message', {
data: binaryMessage(3, new Uint8Array([1]))
})
)
socket.dispatchEvent(
new MessageEvent('message', {
data: binaryMessage(99, new Uint8Array())
})
)
expect(preview).toHaveBeenCalledWith(
expect.objectContaining({
detail: expect.objectContaining({ type: 'image/jpeg' })
})
)
expect(progressText).toHaveBeenCalledWith(
expect.objectContaining({
detail: { nodeId: 'node', text: 'ready' }
})
)
expect(warn).toHaveBeenCalled()
})
it('serializes prompt queue options and surfaces non-200 errors', async () => {
const api = new ComfyApi()
api.clientId = 'client-1'
api.authToken = 'token-1'
api.apiKey = 'key-1'
const prompt: ComfyApiWorkflow = {
1: {
class_type: 'PreviewAny',
inputs: {},
_meta: { title: 'PreviewAny' }
}
}
const fetchApi = vi
.spyOn(api, 'fetchApi')
.mockResolvedValueOnce(jsonResponse({ prompt_id: 'queued' }))
.mockResolvedValueOnce(
new Response('backend exploded', {
status: 500,
statusText: 'Server Error'
})
)
await expect(
api.queuePrompt(
-1,
{ output: prompt, workflow: createWorkflow() },
{
partialExecutionTargets: ['7' as NodeExecutionId],
previewMethod: 'latent2rgb'
}
)
).resolves.toEqual({ prompt_id: 'queued' })
const body = JSON.parse(fetchApi.mock.calls[0][1]?.body as string)
expect(body).toMatchObject({
client_id: 'client-1',
prompt,
partial_execution_targets: ['7'],
front: true,
extra_data: {
auth_token_comfy_org: 'token-1',
api_key_comfy_org: 'key-1',
comfy_usage_source: 'comfyui-frontend',
preview_method: 'latent2rgb'
}
})
expect(body.number).toBeUndefined()
await expect(
api.queuePrompt(3, { output: prompt, workflow: createWorkflow() })
).rejects.toMatchObject({ status: 500 })
})
it('omits queue position and default preview method for normal queueing', async () => {
const api = new ComfyApi()
const prompt: ComfyApiWorkflow = {
1: {
class_type: 'PreviewAny',
inputs: {},
_meta: { title: 'PreviewAny' }
}
}
const fetchApi = vi
.spyOn(api, 'fetchApi')
.mockResolvedValue(jsonResponse({ prompt_id: 'queued' }))
await api.queuePrompt(
0,
{ output: prompt, workflow: createWorkflow() },
{
previewMethod: 'default'
}
)
const body = JSON.parse(fetchApi.mock.calls[0][1]?.body as string)
expect(body.front).toBeUndefined()
expect(body.number).toBeUndefined()
expect(body.extra_data.preview_method).toBeUndefined()
})
it('handles shareable assets, settings, userdata, subgraphs, and memory APIs', async () => {
const api = new ComfyApi()
const prompt: ComfyApiWorkflow = {
1: {
class_type: 'PreviewAny',
inputs: {},
_meta: { title: 'PreviewAny' }
}
}
vi.spyOn(api, 'fetchApi')
.mockResolvedValueOnce(jsonResponse({ assets: [] }))
.mockResolvedValueOnce(jsonResponse({}, { status: 500 }))
.mockResolvedValueOnce(
new Response('', { status: 401, statusText: 'Unauthorized' })
)
.mockResolvedValueOnce(jsonResponse({}, { status: 204 }))
.mockResolvedValueOnce(jsonResponse([], { status: 404 }))
.mockResolvedValueOnce(
jsonResponse({}, { status: 500, statusText: 'Server Error' })
)
.mockResolvedValueOnce(jsonResponse({}, { status: 200 }))
.mockResolvedValueOnce(jsonResponse({}, { status: 200 }))
.mockResolvedValueOnce(jsonResponse({}, { status: 500 }))
vi.spyOn(singletonApi, 'fetchApi')
.mockResolvedValueOnce(jsonResponse({ data: 'subgraph-data' }))
.mockResolvedValueOnce(jsonResponse({ missing: {} }))
.mockResolvedValueOnce(jsonResponse({ one: { data: 'inline' } }))
await expect(
api.getShareableAssets(prompt, { owned: false })
).resolves.toEqual({ assets: [] })
await expect(api.getShareableAssets(prompt)).rejects.toThrow(
'Failed to fetch shareable assets'
)
await expect(api.getSettings()).rejects.toBeInstanceOf(UnauthorizedError)
await expect(
api.storeUserData('plain.txt', 'raw', {
overwrite: false,
stringify: false,
throwOnError: false,
full_info: true
})
).resolves.toHaveProperty('status', 204)
await expect(api.listUserDataFullInfo('/missing/')).resolves.toEqual([])
await expect(api.listUserDataFullInfo('/broken/')).rejects.toThrow(
"Error getting user data list '/broken'"
)
await expect(api.getGlobalSubgraphData('one')).resolves.toBe(
'subgraph-data'
)
await expect(api.getGlobalSubgraphData('missing')).rejects.toThrow(
"Global subgraph 'missing' returned empty data"
)
await expect(api.getGlobalSubgraphs()).resolves.toEqual({
one: { data: 'inline' }
})
await api.freeMemory({ freeExecutionCache: true })
await api.freeMemory({ freeExecutionCache: false })
await api.freeMemory({ freeExecutionCache: true })
expect(mockToastStore.add).toHaveBeenCalledWith(
expect.objectContaining({
summary: 'Models and Execution Cache have been cleared.'
})
)
expect(mockToastStore.add).toHaveBeenCalledWith(
expect.objectContaining({ summary: 'Models have been unloaded.' })
)
expect(mockToastStore.add).toHaveBeenCalledWith(
expect.objectContaining({
summary:
'Unloading of models failed. Installed ComfyUI may be an outdated version.'
})
)
})
it('rejects non-success global subgraph data responses', async () => {
const api = new ComfyApi()
vi.spyOn(singletonApi, 'fetchApi').mockResolvedValueOnce(
jsonResponse({}, { status: 404, statusText: 'Not Found' })
)
await expect(api.getGlobalSubgraphData('missing')).rejects.toThrow(
"Failed to fetch global subgraph 'missing': 404 Not Found"
)
})
it('handles successful settings and userdata helper request shapes', async () => {
const api = new ComfyApi()
const fetchApi = vi
.spyOn(api, 'fetchApi')
.mockResolvedValueOnce(jsonResponse({ theme: 'dark' }))
.mockResolvedValueOnce(jsonResponse({}, { status: 500 }))
.mockResolvedValueOnce(jsonResponse({}, { status: 200 }))
.mockResolvedValueOnce(jsonResponse({}, { status: 200 }))
await expect(api.getSettings()).resolves.toEqual({ theme: 'dark' })
await expect(api.storeUserData('bad.json', { a: 1 })).rejects.toThrow(
"Error storing user data file 'bad.json'"
)
await api.moveUserData('old/path.json', 'new path.json')
await api.deleteUserData('old/path.json')
expect(fetchApi.mock.calls[2]).toEqual([
'/userdata/old%2Fpath.json/move/new%20path.json?overwrite=false',
{ method: 'POST' }
])
expect(fetchApi.mock.calls[3]).toEqual([
'/userdata/old%2Fpath.json',
{ method: 'DELETE' }
])
})
it('handles global subgraph fallbacks and log endpoints', async () => {
const api = new ComfyApi()
vi.spyOn(singletonApi, 'fetchApi')
.mockResolvedValueOnce(jsonResponse({}, { status: 500 }))
.mockResolvedValueOnce(
jsonResponse({
missing: {
name: 'Missing data',
info: { node_pack: 'core' }
}
})
)
.mockResolvedValueOnce(jsonResponse({ data: 'lazy-data' }))
vi.mocked(axios.get)
.mockResolvedValueOnce({ data: 'log text', headers: {} })
.mockResolvedValueOnce({ data: { logs: [] }, headers: {} })
.mockRejectedValueOnce(new Error('no folders'))
.mockResolvedValueOnce({
data: { checkpoints: ['/models'] },
headers: {}
})
vi.mocked(axios.patch).mockResolvedValue({ data: undefined })
await expect(api.getGlobalSubgraphs()).resolves.toEqual({})
const subgraphs = await api.getGlobalSubgraphs()
await expect(subgraphs.missing.data).resolves.toBe('lazy-data')
await expect(api.getLogs()).resolves.toBe('log text')
await expect(api.getRawLogs()).resolves.toEqual({ logs: [] })
await api.subscribeLogs(true)
await expect(api.getFolderPaths()).resolves.toEqual({})
await expect(api.getFolderPaths()).resolves.toEqual({
checkpoints: ['/models']
})
expect(axios.patch).toHaveBeenCalledWith(
api.internalURL('/logs/subscribe'),
{ enabled: true, clientId: undefined }
)
})
it('loads localized template indexes and fuse options defensively', async () => {
const api = new ComfyApi()
vi.mocked(axios.get)
.mockResolvedValueOnce({
data: [{ name: 'template' }],
headers: { 'content-type': 'application/json' }
})
.mockResolvedValueOnce({
data: '<html></html>',
headers: { 'content-type': 'text/html' }
})
.mockRejectedValueOnce(new Error('missing locale'))
.mockResolvedValueOnce({
data: [{ name: 'fallback' }],
headers: { 'content-type': 'application/json' }
})
.mockRejectedValueOnce(new Error('default missing'))
.mockResolvedValueOnce({
data: { keys: ['name'] },
headers: { 'content-type': 'application/json' }
})
.mockResolvedValueOnce({
data: '<html></html>',
headers: { 'content-type': 'text/html' }
})
.mockResolvedValueOnce({
data: { ignored: true },
headers: {}
})
vi.spyOn(console, 'warn').mockImplementation(() => undefined)
vi.spyOn(console, 'error').mockImplementation(() => undefined)
await expect(api.getCoreWorkflowTemplates('fr')).resolves.toEqual([
{ name: 'template' }
])
await expect(api.getCoreWorkflowTemplates()).resolves.toEqual([])
await expect(api.getCoreWorkflowTemplates('zh')).resolves.toEqual([
{ name: 'fallback' }
])
await expect(api.getCoreWorkflowTemplates('en')).resolves.toEqual([])
await expect(api.getFuseOptions()).resolves.toEqual({ keys: ['name'] })
await expect(api.getFuseOptions()).resolves.toBeNull()
await expect(api.getFuseOptions()).resolves.toBeNull()
})
})

File diff suppressed because it is too large Load Diff

View File

@@ -1,6 +1,12 @@
import { beforeEach, describe, expect, it, vi } from 'vitest'
import type {
CanvasPointerEvent,
Subgraph
} from '@/lib/litegraph/src/litegraph'
import { LGraphCanvas, LiteGraph } from '@/lib/litegraph/src/litegraph'
import type { ComfyWorkflowJSON } from '@/platform/workflow/validation/schemas/workflowSchema'
import type { ExecutedWsMessage } from '@/schemas/apiSchema'
const mockAssert = vi.hoisted(() => vi.fn())
@@ -14,7 +20,7 @@ const mockNodeOutputStore = vi.hoisted(() => ({
}))
const mockSubgraphNavigationStore = vi.hoisted(() => ({
exportState: vi.fn(() => []),
exportState: vi.fn((): string[] => []),
restoreState: vi.fn()
}))
@@ -23,10 +29,24 @@ const mockWorkflowStore = vi.hoisted(() => ({
getWorkflowByPath: vi.fn()
}))
const mockExecutionStore = vi.hoisted(() => ({
queuedJobs: {} as Record<string, { workflow: { changeTracker: unknown } }>
}))
const mockMaskEditorIsOpened = vi.hoisted(() => vi.fn(() => false))
vi.mock('@/scripts/app', () => ({
app: {
constructor: {
maskeditor_is_opended: mockMaskEditorIsOpened
},
graph: {},
ui: {
autoQueueEnabled: false,
autoQueueMode: 'instant'
},
rootGraph: {
subgraphs: new Map(),
serialize: vi.fn(() => ({
nodes: [],
links: [],
@@ -39,8 +59,10 @@ vi.mock('@/scripts/app', () => ({
}))
},
canvas: {
ds: { scale: 1, offset: [0, 0] }
}
ds: { scale: 1, offset: [0, 0] },
setGraph: vi.fn()
},
loadGraphData: vi.fn()
}
}))
@@ -65,6 +87,10 @@ vi.mock('@/platform/workflow/management/stores/workflowStore', () => ({
useWorkflowStore: vi.fn(() => mockWorkflowStore)
}))
vi.mock('@/stores/executionStore', () => ({
useExecutionStore: vi.fn(() => mockExecutionStore)
}))
import { app } from '@/scripts/app'
import { api } from '@/scripts/api'
import { ChangeTracker } from '@/scripts/changeTracker'
@@ -107,13 +133,120 @@ function mockCanvasState(state: ComfyWorkflowJSON) {
vi.mocked(app.rootGraph.serialize).mockReturnValue(state as never)
}
type ListenerMap = Record<string, EventListener[]>
function storeListener(
listeners: ListenerMap,
type: string,
listener: EventListenerOrEventListenerObject
) {
if (typeof listener === 'function') {
listeners[type] ??= []
listeners[type].push(listener)
}
}
function dispatchStored(listeners: ListenerMap, type: string, event: Event) {
for (const listener of listeners[type] ?? []) {
listener(event)
}
}
async function flushAsyncFrame() {
await Promise.resolve()
await Promise.resolve()
}
function getApiListener(name: string) {
const call = vi
.mocked(api.addEventListener)
.mock.calls.find(([eventName]) => eventName === name)
expect(call).toBeDefined()
return call?.[1] as (event: CustomEvent<ExecutedWsMessage>) => void
}
describe('ChangeTracker', () => {
beforeEach(() => {
vi.clearAllMocks()
nodeIdCounter = 0
ChangeTracker.isLoadingGraph = false
Reflect.set(ChangeTracker, '_checkStateWarned', false)
mockWorkflowStore.activeWorkflow = null
mockWorkflowStore.getWorkflowByPath.mockReturnValue(null)
mockExecutionStore.queuedJobs = {}
mockMaskEditorIsOpened.mockReturnValue(false)
app.ui.autoQueueEnabled = false
app.ui.autoQueueMode = 'instant'
vi.mocked(app.canvas.setGraph).mockClear()
vi.mocked(app.loadGraphData).mockResolvedValue(undefined)
app.rootGraph.subgraphs.clear()
app.canvas.ds.scale = 1
app.canvas.ds.offset = [0, 0]
})
describe('reset', () => {
it('updates initialState from activeState or an explicit state', () => {
const tracker = createTracker(createState(1))
const changed = createState(2)
tracker.activeState = changed
tracker.reset()
expect(tracker.initialState).toEqual(changed)
expect(tracker.initialState).not.toBe(changed)
const explicit = createState(3)
tracker.reset(explicit)
expect(tracker.activeState).toEqual(explicit)
expect(tracker.activeState).not.toBe(explicit)
expect(tracker.initialState).toEqual(explicit)
})
it('does not reset while restoring state', () => {
const tracker = createTracker(createState(1))
const original = tracker.initialState
tracker._restoringState = true
tracker.reset(createState(2))
expect(tracker.initialState).toBe(original)
})
})
describe('restore', () => {
it('restores viewport, outputs, and root graph navigation', () => {
const tracker = createTracker()
app.canvas.ds.scale = 2
app.canvas.ds.offset = [10, 20]
mockNodeOutputStore.snapshotOutputs.mockReturnValue({ 1: { images: [] } })
mockSubgraphNavigationStore.exportState.mockReturnValue([])
tracker.store()
app.canvas.ds.scale = 1
app.canvas.ds.offset = [0, 0]
tracker.restore()
expect(app.canvas.ds.scale).toBe(2)
expect(app.canvas.ds.offset).toEqual([10, 20])
expect(mockNodeOutputStore.restoreOutputs).toHaveBeenCalledWith({
1: { images: [] }
})
expect(mockSubgraphNavigationStore.restoreState).toHaveBeenCalledWith([])
expect(app.canvas.setGraph).toHaveBeenCalledWith(app.rootGraph)
})
it('restores saved subgraph navigation when the subgraph exists', () => {
const tracker = createTracker()
const subgraph = { id: 'subgraph-1' } as unknown as Subgraph
app.rootGraph.subgraphs.set('subgraph-1', subgraph)
mockSubgraphNavigationStore.exportState.mockReturnValue(['subgraph-1'])
tracker.store()
tracker.restore()
expect(app.canvas.setGraph).toHaveBeenCalledWith(subgraph)
})
})
describe('captureCanvasState', () => {
@@ -169,9 +302,32 @@ describe('ChangeTracker', () => {
expect.stringContaining('captureCanvasState')
)
})
it('reports inactive tracker calls only once for the same workflow', () => {
const tracker = createTracker()
tracker.workflow.path = '/test/dedupe-workflow.json'
mockWorkflowStore.activeWorkflow = { changeTracker: {} }
tracker.captureCanvasState()
tracker.captureCanvasState()
expect(mockAssert).toHaveBeenCalledOnce()
})
})
describe('state capture', () => {
it('sets the active state without pushing undo when none exists yet', () => {
const tracker = createTracker(createState(1))
const changed = createState(2)
tracker.activeState = undefined as never
mockCanvasState(changed)
tracker.captureCanvasState()
expect(tracker.activeState).toEqual(changed)
expect(tracker.undoQueue).toHaveLength(0)
})
it('pushes to undoQueue, updates activeState, and calls updateModified', () => {
const initial = createState(1)
const tracker = createTracker(initial)
@@ -238,6 +394,19 @@ describe('ChangeTracker', () => {
expect(tracker.undoQueue).toHaveLength(ChangeTracker.MAX_HISTORY)
})
it('does not capture until the outer change transaction finishes', () => {
const tracker = createTracker(createState(1))
tracker.beforeChange()
tracker.beforeChange()
mockCanvasState(createState(2))
tracker.afterChange()
expect(app.rootGraph.serialize).not.toHaveBeenCalled()
tracker.afterChange()
expect(app.rootGraph.serialize).toHaveBeenCalledOnce()
})
})
})
@@ -302,6 +471,105 @@ describe('ChangeTracker', () => {
})
})
describe('updateModified', () => {
it('updates workflow modified state when the store can find it', () => {
const state = createState(1)
const tracker = createTracker(state)
const workflow = { isModified: true }
mockWorkflowStore.getWorkflowByPath.mockReturnValue(workflow)
tracker.updateModified()
expect(workflow.isModified).toBe(false)
tracker.activeState = createState(2)
tracker.updateModified()
expect(workflow.isModified).toBe(true)
})
})
describe('undo and redo', () => {
it('restores previous state and moves the current state to the target queue', async () => {
const initial = createState(1)
const changed = createState(2)
const tracker = createTracker(changed)
tracker.undoQueue.push(initial)
await tracker.undo()
expect(app.loadGraphData).toHaveBeenCalledWith(
initial,
false,
false,
tracker.workflow,
{
checkForRerouteMigration: false,
silentAssetErrors: true
}
)
expect(tracker.activeState).toBe(initial)
expect(tracker.redoQueue).toEqual([changed])
expect(tracker._restoringState).toBe(false)
})
it('clears restoring state when loading fails', async () => {
const tracker = createTracker(createState(2))
tracker.undoQueue.push(createState(1))
vi.mocked(app.loadGraphData).mockRejectedValueOnce(
new Error('load failed')
)
await expect(tracker.undo()).rejects.toThrow('load failed')
expect(tracker._restoringState).toBe(false)
})
it('does nothing when no previous state exists', async () => {
const tracker = createTracker(createState(1))
await tracker.undo()
expect(app.loadGraphData).not.toHaveBeenCalled()
})
it('handles keyboard undo and redo shortcuts', async () => {
const tracker = createTracker()
const undo = vi.spyOn(tracker, 'undo').mockResolvedValue()
const redo = vi.spyOn(tracker, 'redo').mockResolvedValue()
await expect(
tracker.undoRedo(
new KeyboardEvent('keydown', { key: 'z', ctrlKey: true })
)
).resolves.toBe(true)
await expect(
tracker.undoRedo(
new KeyboardEvent('keydown', {
key: 'z',
ctrlKey: true,
shiftKey: true
})
)
).resolves.toBe(true)
await expect(
tracker.undoRedo(
new KeyboardEvent('keydown', { key: 'y', metaKey: true })
)
).resolves.toBe(true)
await expect(
tracker.undoRedo(
new KeyboardEvent('keydown', {
key: 'z',
ctrlKey: true,
altKey: true
})
)
).resolves.toBeUndefined()
expect(undo).toHaveBeenCalledOnce()
expect(redo).toHaveBeenCalledTimes(2)
})
})
describe('checkState (deprecated)', () => {
it('delegates to captureCanvasState', () => {
const tracker = createTracker(createState(1))
@@ -312,5 +580,389 @@ describe('ChangeTracker', () => {
expect(tracker.activeState).toEqual(changed)
})
it('warns only once before delegating', () => {
const tracker = createTracker(createState(1))
const warn = vi.spyOn(console, 'warn').mockImplementation(() => undefined)
tracker.checkState()
tracker.checkState()
expect(warn).toHaveBeenCalledOnce()
warn.mockRestore()
})
})
describe('bindInput', () => {
it('returns false for missing canvas or body elements', () => {
expect(ChangeTracker.bindInput(null)).toBe(false)
expect(ChangeTracker.bindInput(document.createElement('canvas'))).toBe(
false
)
expect(ChangeTracker.bindInput(document.body)).toBe(false)
})
it('captures state once when an input-like element changes', () => {
const tracker = createTracker()
const capture = vi.spyOn(tracker, 'captureCanvasState')
const input = document.createElement('input')
expect(ChangeTracker.bindInput(input)).toBe(true)
input.dispatchEvent(new Event('change'))
input.dispatchEvent(new Event('change'))
expect(capture).toHaveBeenCalledOnce()
})
it('binds textarea-like elements that expose an input handler slot', () => {
const tracker = createTracker()
const capture = vi.spyOn(tracker, 'captureCanvasState')
const element = document.createElement('div') as HTMLElement & {
oninput: unknown
}
element.oninput = null
expect(ChangeTracker.bindInput(element)).toBe(true)
element.dispatchEvent(new Event('change'))
expect(capture).toHaveBeenCalledOnce()
})
})
describe('init', () => {
it('captures changes from registered browser, graph, and API events', async () => {
const windowListeners: ListenerMap = {}
const documentListeners: ListenerMap = {}
const windowAddSpy = vi
.spyOn(window, 'addEventListener')
.mockImplementation((type, listener) => {
storeListener(windowListeners, type, listener)
})
const documentAddSpy = vi
.spyOn(document, 'addEventListener')
.mockImplementation((type, listener) => {
storeListener(documentListeners, type, listener)
})
const rafSpy = vi
.spyOn(window, 'requestAnimationFrame')
.mockImplementation((callback) => {
callback(0)
return 1
})
const processMouseUp = vi.fn(() => true)
const prompt = vi.fn()
const close = vi.fn(() => true)
const originalProcessMouseUp = LGraphCanvas.prototype.processMouseUp
const originalPrompt = LGraphCanvas.prototype.prompt
const originalClose = LiteGraph.ContextMenu.prototype.close
LGraphCanvas.prototype.processMouseUp = processMouseUp
LGraphCanvas.prototype.prompt = prompt
LiteGraph.ContextMenu.prototype.close = close
try {
ChangeTracker.init()
const tracker = createTracker()
const capture = vi.spyOn(tracker, 'captureCanvasState')
dispatchStored(windowListeners, 'mouseup', new MouseEvent('mouseup'))
getApiListener('promptQueued')(
new CustomEvent('promptQueued', {
detail: {} as unknown as ExecutedWsMessage
})
)
getApiListener('graphCleared')(
new CustomEvent('graphCleared', {
detail: {} as unknown as ExecutedWsMessage
})
)
dispatchStored(
documentListeners,
'litegraph:canvas',
new CustomEvent('litegraph:canvas', {
detail: { subType: 'before-change' }
})
)
dispatchStored(
documentListeners,
'litegraph:canvas',
new CustomEvent('litegraph:canvas', {
detail: { subType: 'after-change' }
})
)
expect(capture).toHaveBeenCalledTimes(4)
dispatchStored(
windowListeners,
'keydown',
new KeyboardEvent('keydown', { key: 'Control' })
)
await flushAsyncFrame()
dispatchStored(windowListeners, 'keyup', new KeyboardEvent('keyup'))
expect(capture).toHaveBeenCalledTimes(5)
const undoRedo = vi.spyOn(tracker, 'undoRedo').mockResolvedValue(true)
dispatchStored(
windowListeners,
'keydown',
new KeyboardEvent('keydown', { key: 'z', ctrlKey: true })
)
await flushAsyncFrame()
expect(undoRedo).toHaveBeenCalledOnce()
expect(capture).toHaveBeenCalledTimes(5)
undoRedo.mockResolvedValue(undefined)
dispatchStored(
windowListeners,
'keydown',
new KeyboardEvent('keydown', { key: 'a' })
)
await flushAsyncFrame()
expect(capture).toHaveBeenCalledTimes(6)
const input = document.createElement('input')
document.body.append(input)
input.focus()
dispatchStored(
windowListeners,
'keydown',
new KeyboardEvent('keydown', { key: 'b' })
)
await flushAsyncFrame()
input.remove()
expect(capture).toHaveBeenCalledTimes(6)
mockMaskEditorIsOpened.mockReturnValue(true)
dispatchStored(
windowListeners,
'keydown',
new KeyboardEvent('keydown', { key: 'c' })
)
await flushAsyncFrame()
expect(capture).toHaveBeenCalledTimes(6)
const canvas = {} as LGraphCanvas
LGraphCanvas.prototype.processMouseUp.call(
canvas,
new MouseEvent('mouseup') as CanvasPointerEvent
)
expect(processMouseUp).toHaveBeenCalledOnce()
expect(capture).toHaveBeenCalledTimes(7)
const promptCallback = vi.fn()
LGraphCanvas.prototype.prompt.call(
canvas,
'title',
'value',
promptCallback,
new MouseEvent('mouseup') as CanvasPointerEvent
)
const extendedCallback = prompt.mock.calls[0]?.[2] as
| ((value: string) => void)
| undefined
extendedCallback?.('updated')
expect(promptCallback).toHaveBeenCalledWith('updated')
expect(capture).toHaveBeenCalledTimes(8)
LiteGraph.ContextMenu.prototype.close.call(
{} as InstanceType<typeof LiteGraph.ContextMenu>,
new MouseEvent('mouseup')
)
expect(close).toHaveBeenCalledOnce()
expect(capture).toHaveBeenCalledTimes(9)
} finally {
LGraphCanvas.prototype.processMouseUp = originalProcessMouseUp
LGraphCanvas.prototype.prompt = originalPrompt
LiteGraph.ContextMenu.prototype.close = originalClose
windowAddSpy.mockRestore()
documentAddSpy.mockRestore()
rafSpy.mockRestore()
}
})
it('ignores repeat keydowns and missing active trackers', async () => {
const windowListeners: ListenerMap = {}
const windowAddSpy = vi
.spyOn(window, 'addEventListener')
.mockImplementation((type, listener) => {
storeListener(windowListeners, type, listener)
})
const documentAddSpy = vi
.spyOn(document, 'addEventListener')
.mockImplementation(() => undefined)
const rafSpy = vi
.spyOn(window, 'requestAnimationFrame')
.mockImplementation((callback) => {
callback(0)
return 1
})
try {
ChangeTracker.init()
const tracker = createTracker()
const capture = vi.spyOn(tracker, 'captureCanvasState')
dispatchStored(
windowListeners,
'keydown',
new KeyboardEvent('keydown', { key: 'x', repeat: true })
)
await flushAsyncFrame()
expect(capture).not.toHaveBeenCalled()
mockWorkflowStore.activeWorkflow = null
dispatchStored(
windowListeners,
'keydown',
new KeyboardEvent('keydown', { key: 'x' })
)
await flushAsyncFrame()
expect(capture).not.toHaveBeenCalled()
} finally {
windowAddSpy.mockRestore()
documentAddSpy.mockRestore()
rafSpy.mockRestore()
}
})
it('stores executed outputs for the workflow that owns the prompt', () => {
ChangeTracker.init()
const tracker = createTracker()
const executed = getApiListener('executed')
mockExecutionStore.queuedJobs = {
promptA: { workflow: { changeTracker: tracker } }
}
executed(
new CustomEvent('executed', {
detail: {
prompt_id: 'promptA',
node: '1',
output: { images: ['first'] }
} as unknown as ExecutedWsMessage
})
)
executed(
new CustomEvent('executed', {
detail: {
prompt_id: 'promptA',
node: '1',
merge: true,
output: { images: ['second'], text: ['caption'] }
} as unknown as ExecutedWsMessage
})
)
executed(
new CustomEvent('executed', {
detail: {
prompt_id: 'missing',
node: '2',
output: { images: ['ignored'] }
} as unknown as ExecutedWsMessage
})
)
expect(tracker.nodeOutputs).toEqual({
1: { images: ['first', 'second'], text: ['caption'] }
})
})
it('replaces non-array executed outputs during merge updates', () => {
ChangeTracker.init()
const tracker = createTracker()
const executed = getApiListener('executed')
mockExecutionStore.queuedJobs = {
promptA: { workflow: { changeTracker: tracker } }
}
executed(
new CustomEvent('executed', {
detail: {
prompt_id: 'promptA',
node: '1',
output: { value: 'old' }
} as unknown as ExecutedWsMessage
})
)
executed(
new CustomEvent('executed', {
detail: {
prompt_id: 'promptA',
node: '1',
merge: true,
output: { value: 'new' }
} as unknown as ExecutedWsMessage
})
)
expect(tracker.nodeOutputs).toEqual({
1: { value: 'new' }
})
})
})
describe('graphEqual', () => {
it('compares workflow nodes as an unordered set and ignores extra.ds', () => {
const first = createState(2)
const second = {
...createState(),
nodes: [...first.nodes].reverse(),
links: first.links,
groups: first.groups,
extra: { ds: { scale: 2 } }
} as unknown as ComfyWorkflowJSON
expect(ChangeTracker.graphEqual(first, first)).toBe(true)
expect(ChangeTracker.graphEqual(first, second)).toBe(true)
})
it('returns false for non-object values and meaningful graph differences', () => {
const first = createState(1)
const differentNodes = createState(2)
const differentLinks = {
...first,
links: [[1, 1, 0, 2, 0, 'MODEL']]
} as unknown as ComfyWorkflowJSON
expect(ChangeTracker.graphEqual(first, null as never)).toBe(false)
expect(ChangeTracker.graphEqual(first, differentNodes)).toBe(false)
expect(ChangeTracker.graphEqual(first, differentLinks)).toBe(false)
})
it('returns false for extra properties other than viewport state', () => {
const first = createState()
const second = {
...first,
extra: { custom: true }
} as unknown as ComfyWorkflowJSON
expect(ChangeTracker.graphEqual(first, second)).toBe(false)
})
it.each([
'floatingLinks',
'reroutes',
'groups',
'definitions',
'subgraphs'
] as const)('returns false when %s differs', (key) => {
const first = createState()
const second = {
...first,
[key]: [{ id: 1 }]
} as unknown as ComfyWorkflowJSON
expect(ChangeTracker.graphEqual(first, second)).toBe(false)
})
})
})

View File

@@ -1,11 +1,24 @@
import { describe, expect, test, vi } from 'vitest'
import type { LGraph } from '@/lib/litegraph/src/litegraph'
import { LGraphNode } from '@/lib/litegraph/src/litegraph'
import { ComponentWidgetImpl, DOMWidgetImpl } from '@/scripts/domWidget'
import {
addWidget,
ComponentWidgetImpl,
DOMWidgetImpl,
isComponentWidget,
isDOMWidget
} from '@/scripts/domWidget'
const { registerWidget, unregisterWidget } = vi.hoisted(() => ({
registerWidget: vi.fn(),
unregisterWidget: vi.fn()
}))
vi.mock('@/stores/domWidgetStore', () => ({
useDomWidgetStore: () => ({
unregisterWidget: vi.fn()
registerWidget,
unregisterWidget
})
}))
@@ -24,13 +37,11 @@ describe('DOMWidget Y Position Preservation', () => {
options: {}
})
// Set a specific Y position
originalWidget.y = 66
const newNode = new LGraphNode('new-node')
const clonedWidget = originalWidget.createCopyForNode(newNode)
// Verify Y position is preserved
expect(clonedWidget.y).toBe(66)
expect(clonedWidget.node).toBe(newNode)
expect(clonedWidget.name).toBe('test-widget')
@@ -48,13 +59,11 @@ describe('DOMWidget Y Position Preservation', () => {
options: {}
})
// Set a specific Y position
originalWidget.y = 42
const newNode = new LGraphNode('new-node')
const clonedWidget = originalWidget.createCopyForNode(newNode)
// Verify Y position is preserved
expect(clonedWidget.y).toBe(42)
expect(clonedWidget.node).toBe(newNode)
expect(clonedWidget.element).toBe(mockElement)
@@ -71,11 +80,9 @@ describe('DOMWidget Y Position Preservation', () => {
options: {}
})
// Don't explicitly set Y (should be 0 by default)
const newNode = new LGraphNode('new-node')
const clonedWidget = originalWidget.createCopyForNode(newNode)
// Verify Y position is preserved (should be 0)
expect(clonedWidget.y).toBe(0)
})
})
@@ -96,3 +103,271 @@ describe('BaseDOMWidgetImpl.isVisible', () => {
expect(widget.isVisible()).toBe(false)
})
})
describe('DOMWidgetImpl', () => {
test('identifies DOM and component widgets', () => {
const node = new LGraphNode('test-node')
const domWidget = new DOMWidgetImpl({
node,
name: 'dom',
type: 'text',
element: document.createElement('textarea'),
options: {}
})
const componentWidget = new ComponentWidgetImpl({
node,
name: 'component',
component: { template: '<div />' },
inputSpec: { name: 'component', type: 'STRING' },
options: {}
})
expect(isDOMWidget(domWidget)).toBe(true)
expect(isDOMWidget(componentWidget)).toBe(false)
expect(isComponentWidget(componentWidget)).toBe(true)
expect(isComponentWidget(domWidget)).toBe(false)
})
test('uses option-backed values, callbacks, and margins', () => {
const node = new LGraphNode('test-node')
let value = 'initial'
const setValue = vi.fn((next: string) => {
value = next
})
const callback = vi.fn()
const widget = new DOMWidgetImpl({
node,
name: 'text',
type: 'text',
element: document.createElement('textarea'),
options: {
getValue: () => value,
setValue,
margin: 4
}
})
widget.callback = callback
widget.value = 'next'
expect(widget.value).toBe('next')
expect(widget.margin).toBe(4)
expect(setValue).toHaveBeenCalledWith('next')
expect(callback).toHaveBeenCalledWith('next')
})
test('uses default value and margin when options do not provide them', () => {
const node = new LGraphNode('test-node')
const widget = new DOMWidgetImpl({
node,
name: 'text',
type: 'text',
element: document.createElement('textarea'),
options: {}
})
expect(widget.value).toBe('')
expect(widget.margin).toBe(10)
})
test('draws zoom placeholders and delegates visible draws', () => {
const node = new LGraphNode('test-node')
vi.spyOn(node, 'isWidgetVisible').mockReturnValue(true)
const onDraw = vi.fn()
const widget = new DOMWidgetImpl({
node,
name: 'text',
type: 'text',
element: document.createElement('textarea'),
options: {
hideOnZoom: true,
margin: 5,
onDraw
}
})
const ctx = {
beginPath: vi.fn(),
fill: vi.fn(),
fillStyle: '#000',
rect: vi.fn()
} as unknown as CanvasRenderingContext2D
widget.draw(ctx, node, 100, 10, 40, true)
expect(ctx.rect).toHaveBeenCalledWith(5, 15, 90, 30)
expect(ctx.fill).toHaveBeenCalledOnce()
expect(ctx.fillStyle).toBe('#000')
expect(onDraw).toHaveBeenCalledWith(widget)
})
test('skips placeholder drawing when hidden', () => {
const node = new LGraphNode('test-node')
vi.spyOn(node, 'isWidgetVisible').mockReturnValue(false)
const onDraw = vi.fn()
const widget = new DOMWidgetImpl({
node,
name: 'text',
type: 'text',
element: document.createElement('textarea'),
options: {
hideOnZoom: true,
onDraw
}
})
const ctx = {
beginPath: vi.fn(),
fill: vi.fn(),
fillStyle: '#000',
rect: vi.fn()
} as unknown as CanvasRenderingContext2D
widget.draw(ctx, node, 100, 10, 40, true)
expect(ctx.rect).not.toHaveBeenCalled()
expect(onDraw).toHaveBeenCalledWith(widget)
})
test('computes hidden, option, percent, and fallback layout sizes', () => {
const node = new LGraphNode('test-node')
node.size = [100, 200]
const hiddenWidget = new DOMWidgetImpl({
node,
name: 'hidden',
type: 'hidden',
element: document.createElement('textarea'),
options: {}
})
const optionWidget = new DOMWidgetImpl({
node,
name: 'option',
type: 'text',
element: document.createElement('textarea'),
options: {
getMinHeight: () => 11,
getMaxHeight: () => 88,
getHeight: () => 44
}
})
const percentWidget = new DOMWidgetImpl({
node,
name: 'percent',
type: 'text',
element: document.createElement('textarea'),
options: {
getMinHeight: () => 10,
getMaxHeight: () => 60,
getHeight: () => '25%'
}
})
const fallbackWidget = new DOMWidgetImpl({
node,
name: 'fallback',
type: 'text',
element: document.createElement('textarea'),
options: {
getHeight: () => 40
}
})
expect(hiddenWidget.computeLayoutSize(node)).toEqual({
minHeight: 0,
maxHeight: 0,
minWidth: 0
})
expect(optionWidget.computeLayoutSize(node)).toEqual({
minHeight: 11,
maxHeight: 88,
minWidth: 0
})
expect(percentWidget.computeLayoutSize(node)).toEqual({
minHeight: 10,
maxHeight: 60,
minWidth: 0
})
expect(fallbackWidget.computeLayoutSize(node)).toEqual({
minHeight: 40,
maxHeight: undefined,
minWidth: 0
})
})
test('registers widgets immediately and through node lifecycle callbacks', () => {
registerWidget.mockClear()
unregisterWidget.mockClear()
const node = new LGraphNode('test-node')
node.graph = {} as LGraph
const beforeResize = vi.fn()
const afterResize = vi.fn()
const widget = new DOMWidgetImpl({
node,
name: 'text',
type: 'text',
element: document.createElement('textarea'),
options: {
beforeResize,
afterResize
}
})
vi.spyOn(node, 'addCustomWidget')
addWidget(node, widget)
node.onAdded?.(node.graph)
node.onResize?.([0, 0])
node.onRemoved?.()
expect(node.addCustomWidget).toHaveBeenCalledWith(widget)
expect(registerWidget).toHaveBeenCalledWith(widget)
expect(registerWidget).toHaveBeenCalledTimes(2)
expect(beforeResize).toHaveBeenCalledWith(node)
expect(afterResize).toHaveBeenCalledWith(node)
expect(unregisterWidget).toHaveBeenCalledWith(widget.id)
})
test('computes component layout and serializes raw values', () => {
const node = new LGraphNode('test-node')
const value = { nested: true }
const widget = new ComponentWidgetImpl({
node,
name: 'component',
component: { template: '<div />' },
inputSpec: { name: 'component', type: 'STRING' },
options: {
getValue: () => value,
getMinHeight: () => 12,
getMaxHeight: () => 48
}
})
expect(widget.computeLayoutSize()).toEqual({
minHeight: 12,
maxHeight: 48,
minWidth: 0
})
expect(widget.serializeValue()).toEqual({ nested: true })
})
test('adds DOM widgets through LGraphNode prototype helper', () => {
const node = new LGraphNode('test-node')
const element = document.createElement('textarea')
let value = 'initial'
const setValue = vi.fn((next: string) => {
value = next
})
vi.spyOn(node, 'addCustomWidget')
const widget = node.addDOMWidget('text', 'textarea', element, {
getValue: () => value,
setValue
})
const callback = vi.fn()
widget.callback = callback
widget.value = 'next'
expect(node.addCustomWidget).toHaveBeenCalledWith(widget)
expect(widget.element).toBe(element)
expect(widget.options.hideOnZoom).toBe(true)
expect(widget.value).toBe('next')
expect(setValue).toHaveBeenCalledWith('next')
expect(callback).toHaveBeenCalledWith('next')
})
})

View File

@@ -0,0 +1,106 @@
import { fromAny, fromPartial } from '@total-typescript/shoehorn'
import { describe, expect, it, vi } from 'vitest'
import { LGraphNode } from '@/lib/litegraph/src/litegraph'
import type { IBaseWidget } from '@/lib/litegraph/src/types/widgets'
interface TestWidgetOptions {
name: string
type: string
default?: unknown
multiline?: boolean
}
function createWidgetFactory() {
return (node: LGraphNode, options: TestWidgetOptions): IBaseWidget => {
const widget: IBaseWidget = fromAny({
name: options.name,
type: options.type,
value: options.default,
options
})
node.widgets = [...(node.widgets ?? []), widget]
return widget
}
}
vi.mock(
'@/renderer/extensions/vueNodes/widgets/composables/useStringWidget',
() => ({
useStringWidget: () => createWidgetFactory()
})
)
vi.mock(
'@/renderer/extensions/vueNodes/widgets/composables/useFloatWidget',
() => ({
useFloatWidget: () => createWidgetFactory()
})
)
vi.mock(
'@/renderer/extensions/vueNodes/widgets/composables/useBooleanWidget',
() => ({
useBooleanWidget: () => createWidgetFactory()
})
)
import './errorNodeWidgets'
describe('errorNodeWidgets', () => {
it('restores widgets from serialized values on error nodes', () => {
const node = new LGraphNode('BrokenNode')
const longText = 'serialized value with more than twenty chars'
node.has_errors = true
node.onConfigure?.(
fromAny({
widgets_values: {
length: 5,
0: 'short text',
1: longText,
2: 12,
3: true,
4: { nested: 'value' }
}
})
)
expect(node.widgets).toHaveLength(5)
expect(node.widgets?.map((widget) => widget.name)).toEqual([
'UNKNOWN',
'UNKNOWN_1',
'UNKNOWN_2',
'UNKNOWN_3',
'UNKNOWN_4'
])
expect(node.widgets?.map((widget) => widget.label)).toEqual([
'UNKNOWN',
'UNKNOWN',
'UNKNOWN',
'UNKNOWN',
'UNKNOWN'
])
expect(node.widgets?.map((widget) => widget.value)).toEqual([
'short text',
longText,
12,
true,
'{"nested":"value"}'
])
expect(node.serialize_widgets).toBe(true)
})
it('leaves normal nodes unchanged', () => {
const node = new LGraphNode('HealthyNode')
node.onConfigure?.(
fromPartial({
widgets_values: ['ignored']
})
)
expect(node.widgets).toBeUndefined()
expect(node.serialize_widgets).toBeUndefined()
})
})

View File

@@ -7,7 +7,8 @@ import {
EXPECTED_PROMPT_NAN_COERCED,
EXPECTED_WORKFLOW,
mockFileReaderAbort,
mockFileReaderError
mockFileReaderError,
mockFileReaderResult
} from './__fixtures__/helpers'
import { getFromAvifFile } from './avif'
@@ -83,6 +84,11 @@ describe('AVIF metadata', () => {
mockFileReaderAbort('readAsArrayBuffer')
expect(await getFromAvifFile(file)).toEqual({})
})
it('resolves empty when the FileReader load has no result', async () => {
mockFileReaderResult('readAsArrayBuffer', null)
expect(await getFromAvifFile(file)).toEqual({})
})
})
})
@@ -139,7 +145,8 @@ const buildInfeBox = (
itemType: string,
version = 2
): Uint8Array => {
const bodySize = 4 + 2 + 2 + 4 + 1 + 1
const itemIdSize = version === 2 ? 2 : version >= 3 ? 4 : 0
const bodySize = 4 + itemIdSize + (version >= 2 ? 2 + 4 + 1 + 1 : 0)
const totalSize = 8 + bodySize
const buf = new Uint8Array(totalSize)
const dv = new DataView(buf.buffer)
@@ -147,22 +154,36 @@ const buildInfeBox = (
buf.set(new TextEncoder().encode('infe'), 4)
buf[8] = version
if (version >= 2) {
setU16BE(dv, 12, itemId)
setU16BE(dv, 14, 0)
buf.set(new TextEncoder().encode(itemType.padEnd(4).slice(0, 4)), 16)
let p = 12
if (version === 2) {
setU16BE(dv, p, itemId)
p += 2
} else {
setU32BE(dv, p, itemId)
p += 4
}
setU16BE(dv, p, 0)
p += 2
buf.set(new TextEncoder().encode(itemType.padEnd(4).slice(0, 4)), p)
}
return buf
}
const buildIinfBox = (infeBoxes: Uint8Array[]): Uint8Array => {
const bodySize = 4 + 2 + infeBoxes.reduce((s, b) => s + b.length, 0)
const buildIinfBox = (infeBoxes: Uint8Array[], version = 0): Uint8Array => {
const countSize = version === 0 ? 2 : 4
const bodySize = 4 + countSize + infeBoxes.reduce((s, b) => s + b.length, 0)
const totalSize = 8 + bodySize
const buf = new Uint8Array(totalSize)
const dv = new DataView(buf.buffer)
setU32BE(dv, 0, totalSize)
buf.set(new TextEncoder().encode('iinf'), 4)
setU16BE(dv, 12, infeBoxes.length)
let off = 14
buf[8] = version
if (version === 0) {
setU16BE(dv, 12, infeBoxes.length)
} else {
setU32BE(dv, 12, infeBoxes.length)
}
let off = 8 + 4 + countSize
for (const ib of infeBoxes) {
buf.set(ib, off)
off += ib.length
@@ -170,31 +191,91 @@ const buildIinfBox = (infeBoxes: Uint8Array[]): Uint8Array => {
return buf
}
interface IlocItem {
itemId: number
extentOffset: number
extentLength: number
extents?: Array<{ extentOffset: number; extentLength: number }>
}
const buildIlocBox = (
items: { itemId: number; extentOffset: number; extentLength: number }[]
items: IlocItem[],
{
version = 0,
baseOffsetSize = 0,
indexSize = 0
}: { version?: number; baseOffsetSize?: number; indexSize?: number } = {}
): Uint8Array => {
const perItemSize = 2 + 2 + 0 + 2 + (4 + 4)
const bodySize = 4 + 1 + 1 + 2 + items.length * perItemSize
const itemCountSize = version < 2 ? 2 : 4
const itemIdSize = version < 2 ? 2 : 4
const constructionMethodSize = version === 1 || version === 2 ? 2 : 0
const itemSizes = items.map((item) => {
const extents = item.extents ?? [
{
extentOffset: item.extentOffset,
extentLength: item.extentLength
}
]
return (
itemIdSize +
constructionMethodSize +
2 +
baseOffsetSize +
2 +
extents.length * (indexSize + 4 + 4)
)
})
const bodySize =
4 + 1 + 1 + itemCountSize + itemSizes.reduce((sum, size) => sum + size, 0)
const totalSize = 8 + bodySize
const buf = new Uint8Array(totalSize)
const dv = new DataView(buf.buffer)
setU32BE(dv, 0, totalSize)
buf.set(new TextEncoder().encode('iloc'), 4)
buf[8] = version
buf[12] = 0x44
buf[13] = 0x00
setU16BE(dv, 14, items.length)
let p = 16
for (const it of items) {
setU16BE(dv, p, it.itemId)
buf[13] = (baseOffsetSize << 4) | indexSize
let p = 14
if (version < 2) {
setU16BE(dv, p, items.length)
p += 2
} else {
setU32BE(dv, p, items.length)
p += 4
}
for (const it of items) {
if (version < 2) {
setU16BE(dv, p, it.itemId)
p += 2
} else {
setU32BE(dv, p, it.itemId)
p += 4
}
if (version === 1 || version === 2) {
setU16BE(dv, p, 0)
p += 2
}
setU16BE(dv, p, 0)
p += 2
setU16BE(dv, p, 1)
if (baseOffsetSize > 0) {
setU32BE(dv, p, 0)
p += baseOffsetSize
}
const extents = it.extents ?? [
{
extentOffset: it.extentOffset,
extentLength: it.extentLength
}
]
setU16BE(dv, p, extents.length)
p += 2
setU32BE(dv, p, it.extentOffset)
p += 4
setU32BE(dv, p, it.extentLength)
p += 4
for (const extent of extents) {
p += indexSize
setU32BE(dv, p, extent.extentOffset)
p += 4
setU32BE(dv, p, extent.extentLength)
p += 4
}
}
return buf
}
@@ -231,7 +312,13 @@ interface BuildAvifOpts {
ftypBrand?: string
omitMeta?: boolean
omitIloc?: boolean
iinfVersion?: number
infeVersion?: number
ilocVersion?: number
ilocBaseOffsetSize?: number
ilocIndexSize?: number
ilocExtents?: Array<{ extentOffset: number; extentLength: number }>
rawExifData?: Uint8Array
}
const buildAvifFile = (opts: BuildAvifOpts = {}): ArrayBuffer => {
@@ -242,7 +329,13 @@ const buildAvifFile = (opts: BuildAvifOpts = {}): ArrayBuffer => {
ftypBrand = 'avif',
omitMeta = false,
omitIloc = false,
infeVersion = 2
iinfVersion = 0,
infeVersion = 2,
ilocVersion = 0,
ilocBaseOffsetSize = 0,
ilocIndexSize = 0,
ilocExtents,
rawExifData
} = opts
const ftyp = buildFtypBox(ftypBrand)
@@ -250,19 +343,43 @@ const buildAvifFile = (opts: BuildAvifOpts = {}): ArrayBuffer => {
return ftyp.slice().buffer as ArrayBuffer
}
const exifData = buildExifBlob(exifEntries, endian)
const exifData = rawExifData ?? buildExifBlob(exifEntries, endian)
const infe = buildInfeBox(1, itemType, infeVersion)
const iinf = buildIinfBox([infe])
const iinf = buildIinfBox([infe], iinfVersion)
const realIloc = buildIlocBox([
{ itemId: 1, extentOffset: 0, extentLength: exifData.length }
])
const realIloc = buildIlocBox(
[
{
itemId: 1,
extentOffset: 0,
extentLength: exifData.length,
extents: ilocExtents
}
],
{
version: ilocVersion,
baseOffsetSize: ilocBaseOffsetSize,
indexSize: ilocIndexSize
}
)
const metaSize = 8 + 4 + iinf.length + (omitIloc ? 0 : realIloc.length)
const exifOffset = ftyp.length + metaSize
const finalIloc = buildIlocBox([
{ itemId: 1, extentOffset: exifOffset, extentLength: exifData.length }
])
const finalIloc = buildIlocBox(
[
{
itemId: 1,
extentOffset: exifOffset,
extentLength: exifData.length,
extents: ilocExtents
}
],
{
version: ilocVersion,
baseOffsetSize: ilocBaseOffsetSize,
indexSize: ilocIndexSize
}
)
const finalInner = omitIloc ? [iinf] : [iinf, finalIloc]
const meta = buildMetaBox(finalInner)
@@ -319,6 +436,52 @@ describe('getFromAvifFile', () => {
expect(result.workflow).toBe(JSON.stringify(JSON.parse(workflow)))
})
it('extracts EXIF metadata from versioned item info and location boxes', async () => {
const workflow = '{"versioned":true}'
const file = fileFromBuffer(
buildAvifFile({
exifEntries: [`workflow:${workflow}`],
iinfVersion: 1,
infeVersion: 3,
ilocVersion: 2,
ilocBaseOffsetSize: 4,
ilocIndexSize: 4
})
)
const result = await getFromAvifFile(file)
expect(result.workflow).toBe(JSON.stringify(JSON.parse(workflow)))
})
it('returns {} when the Exif item has no extents', async () => {
const file = fileFromBuffer(
buildAvifFile({
exifEntries: ['workflow:{}'],
ilocExtents: []
})
)
const result = await getFromAvifFile(file)
expect(result).toEqual({})
})
it('returns {} when the Exif payload has no TIFF header', async () => {
const file = fileFromBuffer(
buildAvifFile({
rawExifData: new TextEncoder().encode('not tiff data')
})
)
const result = await getFromAvifFile(file)
expect(result).toEqual({})
expect(console.log).toHaveBeenCalledWith(
'Warning: TIFF header not found in EXIF data.'
)
})
it('returns {} when AVIF major brand is not "avif"', async () => {
const file = fileFromBuffer(
buildAvifFile({ exifEntries: ['workflow:{}'], ftypBrand: 'heic' })

View File

@@ -7,7 +7,8 @@ import {
EXPECTED_PROMPT_NAN_COERCED,
EXPECTED_WORKFLOW,
mockFileReaderAbort,
mockFileReaderError
mockFileReaderError,
mockFileReaderResult
} from './__fixtures__/helpers'
import { getFromWebmFile } from './ebml'
@@ -16,6 +17,56 @@ const nanFixturePath = path.resolve(
__dirname,
'__fixtures__/with_nan_metadata.webm'
)
const WEBM_SIGNATURE = new Uint8Array([0x1a, 0x45, 0xdf, 0xa3])
const SIMPLE_TAG = new Uint8Array([0x67, 0xc8])
const TAG_NAME = new Uint8Array([0x45, 0xa3])
const TAG_VALUE = new Uint8Array([0x44, 0x87])
const encoder = new TextEncoder()
function concatBytes(...chunks: Uint8Array[]) {
const totalLength = chunks.reduce((sum, chunk) => sum + chunk.length, 0)
const result = new Uint8Array(totalLength)
let offset = 0
for (const chunk of chunks) {
result.set(chunk, offset)
offset += chunk.length
}
return result
}
function bytes(...values: number[]) {
return new Uint8Array(values)
}
function vint(value: number) {
return bytes(0x80 | value)
}
function element(id: Uint8Array, value: string) {
const encoded = encoder.encode(value)
return concatBytes(id, vint(encoded.length), encoded)
}
function simpleTag(name: string, value: string, useTwoByteSize = false) {
const payload = concatBytes(
element(TAG_NAME, name),
element(TAG_VALUE, value)
)
const size = useTwoByteSize
? bytes(0x40, payload.length)
: vint(payload.length)
return concatBytes(SIMPLE_TAG, size, payload)
}
async function readWebm(bytes: Uint8Array) {
return getFromWebmFile(
new File([bytes as Uint8Array<ArrayBuffer>], 'test.webm', {
type: 'video/webm'
})
)
}
describe('WebM/EBML metadata', () => {
it('extracts workflow and prompt from EBML SimpleTag elements', async () => {
@@ -46,6 +97,89 @@ describe('WebM/EBML metadata', () => {
expect(result).toEqual({})
})
it('extracts plain string tags and trims null-terminated values', async () => {
const result = await readWebm(
concatBytes(
WEBM_SIGNATURE,
simpleTag('\0Comment', ' hello \0ignored', true)
)
)
expect(result.comment).toBe('hello')
})
it('ignores prompt tags whose value is not complete JSON', async () => {
const result = await readWebm(
concatBytes(WEBM_SIGNATURE, simpleTag('PROMPT', '{not-json'))
)
expect(result.prompt).toBeUndefined()
})
it('ignores prompt tags whose value has no JSON object', async () => {
const result = await readWebm(
concatBytes(WEBM_SIGNATURE, simpleTag('PROMPT', 'not-json'))
)
expect(result.prompt).toBeUndefined()
})
it('parses the first complete JSON object from a prompt tag', async () => {
const result = await readWebm(
concatBytes(
WEBM_SIGNATURE,
simpleTag('PROMPT', 'prefix {"outer":{"inner":1}} trailing')
)
)
expect(result.prompt).toEqual({ outer: { inner: 1 } })
})
it('ignores tags whose name has no readable text', async () => {
const payload = concatBytes(
concatBytes(TAG_NAME, vint(2), bytes(0, 1)),
element(TAG_VALUE, 'value')
)
const result = await readWebm(
concatBytes(WEBM_SIGNATURE, SIMPLE_TAG, vint(payload.length), payload)
)
expect(result).toEqual({})
})
it('ignores tag elements with zero-sized names', async () => {
const payload = concatBytes(TAG_NAME, vint(0), element(TAG_VALUE, 'value'))
const result = await readWebm(
concatBytes(WEBM_SIGNATURE, SIMPLE_TAG, vint(payload.length), payload)
)
expect(result).toEqual({})
})
it('ignores malformed SimpleTag encodings', async () => {
const nameOnly = element(TAG_NAME, 'comment')
await expect(
readWebm(concatBytes(WEBM_SIGNATURE, SIMPLE_TAG))
).resolves.toEqual({})
await expect(
readWebm(concatBytes(WEBM_SIGNATURE, SIMPLE_TAG, bytes(0x00)))
).resolves.toEqual({})
await expect(
readWebm(concatBytes(WEBM_SIGNATURE, SIMPLE_TAG, bytes(0x7f, 0xff)))
).resolves.toEqual({})
await expect(
readWebm(concatBytes(WEBM_SIGNATURE, SIMPLE_TAG, vint(10), TAG_NAME))
).resolves.toEqual({})
await expect(
readWebm(
concatBytes(WEBM_SIGNATURE, SIMPLE_TAG, vint(nameOnly.length), nameOnly)
)
).resolves.toEqual({})
})
describe('FileReader failure modes', () => {
afterEach(() => vi.restoreAllMocks())
@@ -60,5 +194,10 @@ describe('WebM/EBML metadata', () => {
mockFileReaderAbort('readAsArrayBuffer')
expect(await getFromWebmFile(file)).toEqual({})
})
it('resolves empty when the FileReader load has no result', async () => {
mockFileReaderResult('readAsArrayBuffer', null)
expect(await getFromWebmFile(file)).toEqual({})
})
})
})

View File

@@ -1,3 +1,4 @@
import { fromAny } from '@total-typescript/shoehorn'
import { afterEach, describe, expect, it, vi } from 'vitest'
import { ASCII, GltfSizeBytes } from '@/types/metadataTypes'
@@ -5,7 +6,8 @@ import { ASCII, GltfSizeBytes } from '@/types/metadataTypes'
import {
EXPECTED_PROMPT_NAN_COERCED,
mockFileReaderAbort,
mockFileReaderError
mockFileReaderError,
mockFileReaderResult
} from './__fixtures__/helpers'
import { getGltfBinaryMetadata } from './gltf'
@@ -16,11 +18,15 @@ describe('GLTF binary metadata parser', () => {
return { header, headerView }
}
const createJSONChunk = (jsonData: ArrayBuffer) => {
const createJSONChunk = (
jsonData: ArrayBuffer,
chunkType = ASCII.JSON,
chunkLength = jsonData.byteLength
) => {
const chunkHeader = new ArrayBuffer(GltfSizeBytes.CHUNK_HEADER)
const chunkView = new DataView(chunkHeader)
chunkView.setUint32(0, jsonData.byteLength, true)
chunkView.setUint32(4, ASCII.JSON, true)
chunkView.setUint32(0, chunkLength, true)
chunkView.setUint32(4, chunkType, true)
return chunkHeader
}
@@ -52,13 +58,27 @@ describe('GLTF binary metadata parser', () => {
// Builds a GLB whose JSON chunk is the literal text passed in - used to
// embed Python generated bare NaN/Infinity tokens that JSON.stringify
// would otherwise coerce to null.
function createMockGltfFileFromText(jsonText: string): File {
interface MockGltfFileOptions {
chunkLength?: number
chunkType?: number
magicNumber?: number
}
function createMockGltfFileFromText(
jsonText: string,
{
chunkLength,
chunkType = ASCII.JSON,
magicNumber = ASCII.GLTF
}: MockGltfFileOptions = {}
): File {
const jsonData = new TextEncoder().encode(jsonText)
const { header, headerView } = createGLTFFileStructure()
setHeaders(headerView, jsonData.buffer)
setTypeHeader(headerView, magicNumber)
const chunkHeader = createJSONChunk(jsonData.buffer)
const chunkHeader = createJSONChunk(jsonData.buffer, chunkType, chunkLength)
const fileContent = new Uint8Array(
header.byteLength + chunkHeader.byteLength + jsonData.byteLength
@@ -130,7 +150,9 @@ describe('GLTF binary metadata parser', () => {
expect(metadata).toBeDefined()
expect(metadata.prompt).toBeDefined()
const prompt = metadata.prompt as Record<string, any>
const prompt: {
node1: { class_type: string; inputs: { seed: number } }
} = fromAny(metadata.prompt)
expect(prompt.node1.class_type).toBe('TestNode')
expect(prompt.node1.inputs.seed).toBe(123456)
})
@@ -179,6 +201,74 @@ describe('GLTF binary metadata parser', () => {
expect(metadata).toEqual({})
})
it('returns empty when the GLB magic number is invalid', async () => {
const mockFile = createMockGltfFileFromText('{}', { magicNumber: 0 })
const metadata = await getGltfBinaryMetadata(mockFile)
expect(metadata).toEqual({})
})
it('returns empty when the GLB is missing the first chunk header', async () => {
const { header, headerView } = createGLTFFileStructure()
setTypeHeader(headerView, ASCII.GLTF)
setVersionHeader(headerView, 2)
setTotalLengthHeader(headerView, GltfSizeBytes.HEADER)
const mockFile = new File([header], 'header-only.glb')
const metadata = await getGltfBinaryMetadata(mockFile)
expect(metadata).toEqual({})
})
it('returns empty when the first chunk is not JSON', async () => {
const mockFile = createMockGltfFileFromText('{}', { chunkType: 0 })
const metadata = await getGltfBinaryMetadata(mockFile)
expect(metadata).toEqual({})
})
it('returns empty when the declared JSON chunk exceeds the buffer', async () => {
const mockFile = createMockGltfFileFromText('{}', { chunkLength: 1024 })
const metadata = await getGltfBinaryMetadata(mockFile)
expect(metadata).toEqual({})
})
it('returns empty when the JSON chunk cannot be parsed', async () => {
const mockFile = createMockGltfFileFromText('{not json')
const metadata = await getGltfBinaryMetadata(mockFile)
expect(metadata).toEqual({})
})
it('returns empty when asset extras are missing', async () => {
const mockFile = createMockGltfFile({ asset: { version: '2.0' } })
const metadata = await getGltfBinaryMetadata(mockFile)
expect(metadata).toEqual({})
})
it('skips string metadata values that are not valid JSON', async () => {
const mockFile = createMockGltfFile({
asset: {
version: '2.0',
extras: {
prompt: '{not json',
workflow: '{not json'
}
}
})
const metadata = await getGltfBinaryMetadata(mockFile)
expect(metadata).toEqual({})
})
describe('FileReader failure modes', () => {
afterEach(() => vi.restoreAllMocks())
@@ -193,5 +283,15 @@ describe('GLTF binary metadata parser', () => {
mockFileReaderAbort('readAsArrayBuffer')
expect(await getGltfBinaryMetadata(file)).toEqual({})
})
it('resolves empty when the FileReader load result is missing', async () => {
mockFileReaderResult('readAsArrayBuffer', null)
expect(await getGltfBinaryMetadata(file)).toEqual({})
})
it('resolves empty when the FileReader result is not a buffer', async () => {
mockFileReaderResult('readAsArrayBuffer', 'not a buffer')
expect(await getGltfBinaryMetadata(file)).toEqual({})
})
})
})

View File

@@ -7,7 +7,8 @@ import {
EXPECTED_PROMPT_NAN_COERCED,
EXPECTED_WORKFLOW,
mockFileReaderAbort,
mockFileReaderError
mockFileReaderError,
mockFileReaderResult
} from './__fixtures__/helpers'
import { getFromIsobmffFile } from './isobmff'
@@ -16,6 +17,82 @@ const nanFixturePath = path.resolve(
__dirname,
'__fixtures__/with_nan_metadata.mp4'
)
const encoder = new TextEncoder()
function uint32(value: number) {
return new Uint8Array([
(value >>> 24) & 0xff,
(value >>> 16) & 0xff,
(value >>> 8) & 0xff,
value & 0xff
])
}
function concatBytes(...chunks: Uint8Array[]) {
const totalLength = chunks.reduce((sum, chunk) => sum + chunk.length, 0)
const result = new Uint8Array(totalLength)
let offset = 0
for (const chunk of chunks) {
result.set(chunk, offset)
offset += chunk.length
}
return result
}
function box(type: string, payload = new Uint8Array(), size?: number) {
return concatBytes(
uint32(size ?? 8 + payload.length),
encoder.encode(type),
payload
)
}
function keyEntry(name: string) {
const encoded = encoder.encode(name)
return concatBytes(
uint32(8 + encoded.length),
encoder.encode('mdta'),
encoded
)
}
function keysBox(names: string[]) {
return box(
'keys',
concatBytes(uint32(0), uint32(names.length), ...names.map(keyEntry))
)
}
function dataBox(value: string | Uint8Array) {
const payload = typeof value === 'string' ? encoder.encode(value) : value
return box('data', concatBytes(uint32(0), uint32(0), payload))
}
function ilstItem(index: number, payload: Uint8Array) {
return concatBytes(uint32(8 + payload.length), uint32(index), payload)
}
function ilstBox(...items: Uint8Array[]) {
return box('ilst', concatBytes(...items))
}
function metaBox(...children: Uint8Array[]) {
return box('meta', concatBytes(uint32(0), ...children))
}
function udtaWithMeta(...children: Uint8Array[]) {
return box('udta', metaBox(...children))
}
async function readMp4(bytes: Uint8Array) {
return getFromIsobmffFile(
new File([bytes as Uint8Array<ArrayBuffer>], 'test.mp4', {
type: 'video/mp4'
})
)
}
describe('ISOBMFF (MP4) metadata', () => {
it('extracts workflow and prompt from QuickTime keys/ilst boxes', async () => {
@@ -48,6 +125,102 @@ describe('ISOBMFF (MP4) metadata', () => {
expect(result).toEqual({})
})
it('extracts metadata from udta nested inside moov', async () => {
const bytes = box(
'moov',
udtaWithMeta(
keysBox(['WORKFLOW']),
ilstBox(ilstItem(1, dataBox('xxxx{"nodes":[]}')))
)
)
const result = await readMp4(bytes)
expect(result.workflow).toEqual({ nodes: [] })
})
it('returns empty when a top-level box declares an impossible size', async () => {
const result = await readMp4(box('free', new Uint8Array([1, 2]), 100))
expect(result).toEqual({})
})
it('returns empty when the keys box cannot provide entries', async () => {
const tooShortKeys = box('keys', uint32(0))
const missingKeyEntry = box(
'keys',
concatBytes(uint32(0), uint32(1), uint32(8))
)
const malformedKey = box(
'keys',
concatBytes(uint32(0), uint32(1), uint32(7), encoder.encode('bad!'))
)
const oversizedKey = box(
'keys',
concatBytes(uint32(0), uint32(1), uint32(100), encoder.encode('bad!'))
)
await expect(readMp4(udtaWithMeta(tooShortKeys))).resolves.toEqual({})
await expect(readMp4(udtaWithMeta(missingKeyEntry))).resolves.toEqual({})
await expect(readMp4(udtaWithMeta(malformedKey))).resolves.toEqual({})
await expect(readMp4(udtaWithMeta(oversizedKey))).resolves.toEqual({})
})
it('ignores item entries whose key is unknown or unsupported', async () => {
const unknownIndex = udtaWithMeta(
keysBox(['PROMPT']),
ilstBox(ilstItem(2, dataBox('{"1":{}}')))
)
const unsupportedKey = udtaWithMeta(
keysBox(['DESCRIPTION']),
ilstBox(ilstItem(1, dataBox('{"ignored":true}')))
)
await expect(readMp4(unknownIndex)).resolves.toEqual({})
await expect(readMp4(unsupportedKey)).resolves.toEqual({})
})
it('ignores metadata items without readable JSON data', async () => {
const shortDataBox = box('data')
const noJson = dataBox('not-json')
const invalidJson = dataBox('prefix {not-json')
const noDataBox = box('free', new Uint8Array([1]))
const invalidItems = ilstBox(
concatBytes(uint32(8), uint32(1)),
concatBytes(uint32(100), uint32(1), invalidJson)
)
await expect(
readMp4(
udtaWithMeta(keysBox(['PROMPT']), ilstBox(ilstItem(1, shortDataBox)))
)
).resolves.toEqual({})
await expect(
readMp4(udtaWithMeta(keysBox(['PROMPT']), ilstBox(ilstItem(1, noJson))))
).resolves.toEqual({})
await expect(
readMp4(
udtaWithMeta(keysBox(['PROMPT']), ilstBox(ilstItem(1, invalidJson)))
)
).resolves.toEqual({})
await expect(
readMp4(
udtaWithMeta(keysBox(['PROMPT']), ilstBox(ilstItem(1, noDataBox)))
)
).resolves.toEqual({})
await expect(
readMp4(udtaWithMeta(keysBox(['PROMPT']), invalidItems))
).resolves.toEqual({})
})
it('returns empty when required metadata boxes are absent', async () => {
await expect(readMp4(box('udta', box('free')))).resolves.toEqual({})
await expect(readMp4(udtaWithMeta(ilstBox()))).resolves.toEqual({})
await expect(readMp4(udtaWithMeta(keysBox(['PROMPT'])))).resolves.toEqual(
{}
)
})
describe('FileReader failure modes', () => {
afterEach(() => vi.restoreAllMocks())
@@ -63,5 +236,10 @@ describe('ISOBMFF (MP4) metadata', () => {
mockFileReaderAbort('readAsArrayBuffer')
expect(await getFromIsobmffFile(file)).toEqual({})
})
it('resolves empty when the FileReader load has no result', async () => {
mockFileReaderResult('readAsArrayBuffer', null)
expect(await getFromIsobmffFile(file)).toEqual({})
})
})
})

View File

@@ -0,0 +1,127 @@
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { getFromWebmFile } from '@/scripts/metadata/ebml'
import { getGltfBinaryMetadata } from '@/scripts/metadata/gltf'
import { getFromIsobmffFile } from '@/scripts/metadata/isobmff'
import { getDataFromJSON } from '@/scripts/metadata/json'
import { getMp3Metadata } from '@/scripts/metadata/mp3'
import { getOggMetadata } from '@/scripts/metadata/ogg'
import { getWorkflowDataFromFile } from '@/scripts/metadata/parser'
import { getSvgMetadata } from '@/scripts/metadata/svg'
import {
getAvifMetadata,
getFlacMetadata,
getLatentMetadata,
getPngMetadata,
getWebpMetadata
} from '@/scripts/pnginfo'
vi.mock('@/scripts/metadata/ebml', () => ({ getFromWebmFile: vi.fn() }))
vi.mock('@/scripts/metadata/gltf', () => ({ getGltfBinaryMetadata: vi.fn() }))
vi.mock('@/scripts/metadata/isobmff', () => ({ getFromIsobmffFile: vi.fn() }))
vi.mock('@/scripts/metadata/json', () => ({ getDataFromJSON: vi.fn() }))
vi.mock('@/scripts/metadata/mp3', () => ({ getMp3Metadata: vi.fn() }))
vi.mock('@/scripts/metadata/ogg', () => ({ getOggMetadata: vi.fn() }))
vi.mock('@/scripts/metadata/svg', () => ({ getSvgMetadata: vi.fn() }))
vi.mock('@/scripts/pnginfo', () => ({
getAvifMetadata: vi.fn(),
getFlacMetadata: vi.fn(),
getLatentMetadata: vi.fn(),
getPngMetadata: vi.fn(),
getWebpMetadata: vi.fn()
}))
function file(type: string, name = 'file') {
return new File(['data'], name, { type })
}
beforeEach(() => {
vi.clearAllMocks()
})
describe('getWorkflowDataFromFile', () => {
it('routes png/avif/mp3/ogg/webm to their parsers and returns the result', async () => {
vi.mocked(getPngMetadata).mockResolvedValue({ a: 1 } as never)
expect(await getWorkflowDataFromFile(file('image/png'))).toEqual({ a: 1 })
expect(getPngMetadata).toHaveBeenCalled()
await getWorkflowDataFromFile(file('image/avif'))
expect(getAvifMetadata).toHaveBeenCalled()
await getWorkflowDataFromFile(file('audio/mpeg'))
expect(getMp3Metadata).toHaveBeenCalled()
await getWorkflowDataFromFile(file('audio/ogg'))
expect(getOggMetadata).toHaveBeenCalled()
await getWorkflowDataFromFile(file('video/webm'))
expect(getFromWebmFile).toHaveBeenCalled()
})
it('extracts workflow/prompt from webp, preferring lowercase keys', async () => {
vi.mocked(getWebpMetadata).mockResolvedValue({
workflow: 'wf',
prompt: 'pr'
} as never)
expect(await getWorkflowDataFromFile(file('image/webp'))).toEqual({
workflow: 'wf',
prompt: 'pr'
})
})
it('falls back to capitalized webp keys when lowercase are absent', async () => {
vi.mocked(getWebpMetadata).mockResolvedValue({
Workflow: 'WF',
Prompt: 'PR'
} as never)
expect(await getWorkflowDataFromFile(file('image/webp'))).toEqual({
workflow: 'WF',
prompt: 'PR'
})
})
it('handles both flac mime types and extracts workflow/prompt', async () => {
vi.mocked(getFlacMetadata).mockResolvedValue({ workflow: 'w' } as never)
expect(await getWorkflowDataFromFile(file('audio/flac'))).toEqual({
workflow: 'w',
prompt: undefined
})
expect(await getWorkflowDataFromFile(file('audio/x-flac'))).toEqual({
workflow: 'w',
prompt: undefined
})
})
it('routes isobmff by mime type and by file extension', async () => {
await getWorkflowDataFromFile(file('video/mp4'))
await getWorkflowDataFromFile(file('', 'clip.mov'))
await getWorkflowDataFromFile(file('', 'clip.m4v'))
expect(getFromIsobmffFile).toHaveBeenCalledTimes(3)
})
it('routes svg and gltf by mime type or extension', async () => {
await getWorkflowDataFromFile(file('image/svg+xml'))
await getWorkflowDataFromFile(file('', 'icon.svg'))
expect(getSvgMetadata).toHaveBeenCalledTimes(2)
await getWorkflowDataFromFile(file('model/gltf-binary'))
await getWorkflowDataFromFile(file('', 'model.glb'))
expect(getGltfBinaryMetadata).toHaveBeenCalledTimes(2)
})
it('routes latent/safetensors and json by extension or mime type', async () => {
await getWorkflowDataFromFile(file('', 'x.latent'))
await getWorkflowDataFromFile(file('', 'x.safetensors'))
expect(getLatentMetadata).toHaveBeenCalledTimes(2)
await getWorkflowDataFromFile(file('application/json'))
await getWorkflowDataFromFile(file('', 'x.json'))
expect(getDataFromJSON).toHaveBeenCalledTimes(2)
})
it('returns undefined for an unrecognized file', async () => {
expect(
await getWorkflowDataFromFile(file('application/zip', 'a.zip'))
).toBe(undefined)
})
})

View File

@@ -2,7 +2,8 @@ import { afterEach, describe, expect, it, vi } from 'vitest'
import {
mockFileReaderAbort,
mockFileReaderError
mockFileReaderError,
mockFileReaderResult
} from './__fixtures__/helpers'
import { getFromPngBuffer, getFromPngFile } from './png'
@@ -191,6 +192,27 @@ describe('getFromPngBuffer', () => {
const result = await getFromPngBuffer(buffer)
expect(result['workflow']).toBe(workflow)
})
it('logs error and skips compressed iTXt with invalid deflate data', async () => {
vi.spyOn(console, 'error').mockImplementation(() => {})
const buffer = createPngWithChunk(
'iTXt',
'workflow',
new Uint8Array([1, 2, 3]),
{
compressionFlag: 1,
compressionMethod: 0
}
)
const result = await getFromPngBuffer(buffer)
expect(result['workflow']).toBeUndefined()
expect(console.error).toHaveBeenCalledWith(
expect.stringContaining('Failed to decompress iTXt chunk "workflow"'),
expect.anything()
)
})
})
describe('getFromPngFile', () => {
@@ -228,5 +250,12 @@ describe('getFromPngFile', () => {
mockFileReaderAbort('readAsArrayBuffer')
await expect(getFromPngFile(file)).rejects.toThrow('FileReader aborted')
})
it('rejects when the FileReader load has no ArrayBuffer result', async () => {
mockFileReaderResult('readAsArrayBuffer', null)
await expect(getFromPngFile(file)).rejects.toThrow(
'Failed to read file as ArrayBuffer'
)
})
})
})

View File

@@ -1,13 +1,19 @@
import fs from 'fs'
import path from 'path'
import { afterEach, describe, expect, it, vi } from 'vitest'
import { fromPartial } from '@total-typescript/shoehorn'
import type { LGraph, LGraphNode } from '@/lib/litegraph/src/litegraph'
import { LiteGraph } from '@/lib/litegraph/src/litegraph'
import type { IBaseWidget } from '@/lib/litegraph/src/types/widgets'
import { api } from '@/scripts/api'
import { getFromAvifFile } from './metadata/avif'
import { getFromFlacFile } from './metadata/flac'
import { getFromPngFile } from './metadata/png'
import {
getAvifMetadata,
getFlacMetadata,
importA1111,
getLatentMetadata,
getPngMetadata,
getWebpMetadata
@@ -56,6 +62,20 @@ function encodeAsciiIfd(entries: AsciiIfdEntry[]): Uint8Array {
return buf
}
function encodeNonAsciiIfdEntry(tag: number): Uint8Array {
const buf = new Uint8Array(22)
const dv = new DataView(buf.buffer)
buf.set([0x49, 0x49], 0)
dv.setUint16(2, 0x002a, true)
dv.setUint32(4, 8, true)
dv.setUint16(8, 1, true)
dv.setUint16(10, tag, true)
dv.setUint16(12, 3, true)
dv.setUint32(14, 1, true)
dv.setUint32(18, 123, true)
return buf
}
type WebpChunk = { type: string; payload: Uint8Array }
function wrapInWebp(chunks: WebpChunk[]): File {
@@ -157,6 +177,16 @@ describe('getWebpMetadata', () => {
expect(metadata).toEqual({ workflow: '{"a":1}' })
})
it('ignores EXIF entries that are not ASCII strings', async () => {
const file = wrapInWebp([
{ type: 'EXIF', payload: encodeNonAsciiIfdEntry(270) }
])
const metadata = await getWebpMetadata(file)
expect(metadata).toEqual({})
})
})
describe('getLatentMetadata', () => {
@@ -234,3 +264,313 @@ describe('format-specific metadata wrappers', () => {
expect(result).toEqual({ workflow: '{"avif":1}' })
})
})
describe('importA1111', () => {
function widget(
name: string,
options: string[] = []
): IBaseWidget & { value?: string | number } {
return fromPartial<IBaseWidget & { value?: string | number }>({
name,
options: { values: options },
value: undefined
})
}
function createNode(type: string): LGraphNode {
const widgetsByType: Record<string, IBaseWidget[]> = {
CheckpointLoaderSimple: [widget('ckpt_name', ['sd15.safetensors'])],
CLIPSetLastLayer: [widget('stop_at_clip_layer')],
CLIPTextEncode: [widget('text')],
EmptyLatentImage: [widget('width'), widget('height')],
ImageScale: [widget('width'), widget('height')],
ImageUpscaleWithModel: [],
KSampler: [
widget('cfg'),
widget('sampler_name', ['euler_a', 'dpmpp_2m']),
widget('scheduler', ['normal', 'karras']),
widget('seed'),
widget('steps'),
widget('denoise')
],
LatentUpscale: [
widget('upscale_method', ['nearest-exact']),
widget('width'),
widget('height')
],
LoraLoader: [
widget('lora_name', ['foo.safetensors']),
widget('strength_model'),
widget('strength_clip')
],
UpscaleModelLoader: [widget('model_name', ['ESRGAN'])],
VAEEncodeTiled: [],
VAEDecodeTiled: [],
SaveImage: [],
VAEDecode: []
}
return {
type,
widgets: widgetsByType[type] ?? [],
connect: vi.fn()
} as unknown as LGraphNode
}
function createGraph(): LGraph {
return fromPartial<LGraph>({
add: vi.fn(),
arrange: vi.fn(),
clear: vi.fn()
})
}
function findWidget(node: LGraphNode, name: string) {
return node.widgets?.find((widget) => widget.name === name)
}
it('ignores text without parsed generation settings', async () => {
const graph = createGraph()
vi.spyOn(api, 'getEmbeddings').mockResolvedValue([])
await importA1111(graph, 'positive prompt only')
await importA1111(graph, 'positive prompt\nSteps:\n')
expect(graph.clear).not.toHaveBeenCalled()
})
it('ignores text without a negative prompt section', async () => {
const graph = createGraph()
vi.spyOn(api, 'getEmbeddings').mockResolvedValue([])
await importA1111(
graph,
['positive prompt only', 'Steps: 20, Sampler: Euler a'].join('\n')
)
expect(graph.clear).not.toHaveBeenCalled()
})
it('stops when a required base node cannot be created', async () => {
const graph = createGraph()
vi.spyOn(console, 'error').mockImplementation(() => {})
vi.spyOn(api, 'getEmbeddings').mockResolvedValue([])
vi.spyOn(LiteGraph, 'createNode').mockImplementation((type) =>
type === 'KSampler' ? null : createNode(type)
)
await importA1111(
graph,
[
'prompt',
'Negative prompt: blurry',
'Steps: 20, Sampler: Euler a, Size: 512x512'
].join('\n')
)
expect(graph.clear).not.toHaveBeenCalled()
expect(console.error).toHaveBeenCalledWith(
'Failed to create required nodes for A1111 import'
)
})
it('builds a basic graph from A1111 parameters', async () => {
const graph = createGraph()
const nodes: LGraphNode[] = []
vi.spyOn(console, 'warn').mockImplementation(() => {})
vi.spyOn(api, 'getEmbeddings').mockResolvedValue(['easynegative'])
vi.spyOn(LiteGraph, 'createNode').mockImplementation((type) => {
const node = createNode(type)
nodes.push(node)
return node
})
await importA1111(
graph,
[
'<lora:foo:0.7> portrait easynegative',
'Negative prompt: blurry <lora:bad:not-number>',
'Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 42, Size: 512x512, Model: sd15, Clip skip: 2, Model hash: ignored'
].join('\n')
)
const checkpoint = nodes.find(
(node) => node.type === 'CheckpointLoaderSimple'
)
const clipSkip = nodes.find((node) => node.type === 'CLIPSetLastLayer')
const sampler = nodes.find((node) => node.type === 'KSampler')
const image = nodes.find((node) => node.type === 'EmptyLatentImage')
const lora = nodes.find((node) => node.type === 'LoraLoader')
const textNodes = nodes.filter((node) => node.type === 'CLIPTextEncode')
expect(graph.clear).toHaveBeenCalledOnce()
expect(graph.arrange).toHaveBeenCalledOnce()
expect(findWidget(checkpoint!, 'ckpt_name')?.value).toBe('sd15.safetensors')
expect(findWidget(clipSkip!, 'stop_at_clip_layer')?.value).toBe(-2)
expect(findWidget(sampler!, 'cfg')?.value).toBe(7)
expect(findWidget(sampler!, 'sampler_name')?.value).toBe('euler_a')
expect(findWidget(sampler!, 'scheduler')?.value).toBe('normal')
expect(findWidget(sampler!, 'seed')?.value).toBe(42)
expect(findWidget(sampler!, 'steps')?.value).toBe(20)
expect(findWidget(image!, 'width')?.value).toBe(512)
expect(findWidget(image!, 'height')?.value).toBe(512)
expect(findWidget(lora!, 'lora_name')?.value).toBe('foo.safetensors')
expect(findWidget(lora!, 'strength_model')?.value).toBe(0.7)
expect(findWidget(lora!, 'strength_clip')?.value).toBe(0.7)
expect(findWidget(textNodes[0], 'text')?.value).toBe(
' portrait embedding:easynegative'
)
expect(findWidget(textNodes[1], 'text')?.value).toBe('blurry ')
})
it('keeps unknown option-prefix values and logs the mismatch', async () => {
const graph = createGraph()
const nodes: LGraphNode[] = []
vi.spyOn(console, 'warn').mockImplementation(() => {})
vi.spyOn(api, 'getEmbeddings').mockResolvedValue([])
vi.spyOn(LiteGraph, 'createNode').mockImplementation((type) => {
const node = createNode(type)
nodes.push(node)
return node
})
await importA1111(
graph,
[
'portrait',
'Negative prompt: blurry',
'Steps: 20, Sampler: Unknown Sampler, CFG scale: 7, Seed: 42, Size: 512x512, Model: unknown-model'
].join('\n')
)
const checkpoint = nodes.find(
(node) => node.type === 'CheckpointLoaderSimple'
)
expect(findWidget(checkpoint!, 'ckpt_name')?.value).toBe('unknown-model')
expect(console.warn).toHaveBeenCalledWith(
"Unknown value 'unknown-model' for widget 'ckpt_name'",
checkpoint
)
})
it('skips missing LoraLoader nodes while keeping prompt text cleaned', async () => {
const graph = createGraph()
const nodes: LGraphNode[] = []
vi.spyOn(console, 'warn').mockImplementation(() => {})
vi.spyOn(api, 'getEmbeddings').mockResolvedValue([])
vi.spyOn(LiteGraph, 'createNode').mockImplementation((type) => {
if (type === 'LoraLoader') return null
const node = createNode(type)
nodes.push(node)
return node
})
await importA1111(
graph,
[
'<lora:missing:0.5> portrait',
'Negative prompt: blurry',
'Steps: 20, Sampler: Euler a, Size: 512x512'
].join('\n')
)
const textNodes = nodes.filter((node) => node.type === 'CLIPTextEncode')
expect(findWidget(textNodes[0], 'text')?.value).toBe(' portrait')
})
it('returns from latent hires setup when LatentUpscale cannot be created', async () => {
const graph = createGraph()
const nodes: LGraphNode[] = []
vi.spyOn(api, 'getEmbeddings').mockResolvedValue([])
vi.spyOn(LiteGraph, 'createNode').mockImplementation((type) => {
if (type === 'LatentUpscale') return null
const node = createNode(type)
nodes.push(node)
return node
})
await importA1111(
graph,
[
'portrait',
'Negative prompt: blurry',
'Steps: 8, Sampler: Euler a, Size: 512x512, Hires upscale: 2, Hires upscaler: Latent'
].join('\n')
)
expect(nodes.some((node) => node.type === 'KSampler')).toBe(true)
expect(nodes.some((node) => node.type === 'LatentUpscale')).toBe(false)
})
it('builds a latent hires pass with explicit resize and denoise settings', async () => {
const graph = createGraph()
const nodes: LGraphNode[] = []
vi.spyOn(console, 'warn').mockImplementation(() => {})
vi.spyOn(api, 'getEmbeddings').mockResolvedValue([])
vi.spyOn(LiteGraph, 'createNode').mockImplementation((type) => {
const node = createNode(type)
nodes.push(node)
return node
})
await importA1111(
graph,
[
'portrait',
'Negative prompt: blurry',
'Steps: 12, Sampler: DPM++ 2M Karras, CFG scale: 5, Seed: 1, Size: 513x577, Model: sd15, Hires resize: 1025x1089, Hires steps: 4, Hires upscaler: Latent (nearest-exact), Denoising strength: 0.35'
].join('\n')
)
const image = nodes.find((node) => node.type === 'EmptyLatentImage')
const latentUpscale = nodes.find((node) => node.type === 'LatentUpscale')
const samplers = nodes.filter((node) => node.type === 'KSampler')
expect(findWidget(image!, 'width')?.value).toBe(576)
expect(findWidget(image!, 'height')?.value).toBe(640)
expect(findWidget(latentUpscale!, 'upscale_method')?.value).toBe(
'nearest-exact'
)
expect(findWidget(latentUpscale!, 'width')?.value).toBe(1088)
expect(findWidget(latentUpscale!, 'height')?.value).toBe(1152)
expect(findWidget(samplers[0], 'scheduler')?.value).toBe('karras')
expect(findWidget(samplers[0], 'sampler_name')?.value).toBe('dpmpp_2m')
expect(findWidget(samplers[1], 'steps')?.value).toBe(4)
expect(findWidget(samplers[1], 'cfg')?.value).toBe(5)
expect(findWidget(samplers[1], 'scheduler')?.value).toBe('karras')
expect(findWidget(samplers[1], 'sampler_name')?.value).toBe('dpmpp_2m')
expect(findWidget(samplers[1], 'denoise')?.value).toBe(0.35)
})
it('builds an image upscaler hires pass with fallback steps and denoise', async () => {
const graph = createGraph()
const nodes: LGraphNode[] = []
vi.spyOn(console, 'warn').mockImplementation(() => {})
vi.spyOn(api, 'getEmbeddings').mockResolvedValue([])
vi.spyOn(LiteGraph, 'createNode').mockImplementation((type) => {
const node = createNode(type)
nodes.push(node)
return node
})
await importA1111(
graph,
[
'portrait',
'Negative prompt: blurry',
'Steps: 8, Sampler: Euler a, CFG scale: 6, Seed: 2, Size: 512x512, Model: sd15, Hires upscale: 1.5, Hires upscaler: ESRGAN'
].join('\n')
)
const upscaleLoader = nodes.find(
(node) => node.type === 'UpscaleModelLoader'
)
const imageScale = nodes.find((node) => node.type === 'ImageScale')
const samplers = nodes.filter((node) => node.type === 'KSampler')
expect(findWidget(upscaleLoader!, 'model_name')?.value).toBe('ESRGAN')
expect(findWidget(imageScale!, 'width')?.value).toBe(768)
expect(findWidget(imageScale!, 'height')?.value).toBe(768)
expect(findWidget(samplers[1], 'steps')?.value).toBe(8)
expect(findWidget(samplers[1], 'denoise')?.value).toBe(1)
})
})

450
src/scripts/ui.test.ts Normal file
View File

@@ -0,0 +1,450 @@
import { fromAny, fromPartial } from '@total-typescript/shoehorn'
import { beforeEach, describe, expect, it, vi } from 'vitest'
const { mockSettingStore } = vi.hoisted(() => ({
mockSettingStore: {
get: vi.fn()
}
}))
vi.mock('@/composables/useRunButtonTelemetry', () => ({
useRunButtonTelemetry: () => ({ trackRunButton: vi.fn() })
}))
vi.mock('@/platform/remote/comfyui/jobs/fetchJobs', () => ({
extractWorkflow: vi.fn()
}))
vi.mock('@/platform/settings/composables/useSettingsDialog', () => ({
useSettingsDialog: () => ({ show: vi.fn() })
}))
vi.mock('@/platform/settings/settingStore', () => ({
useSettingStore: () => mockSettingStore
}))
vi.mock('@/platform/telemetry', () => ({
useTelemetry: () => ({ trackWorkflowExecution: vi.fn() })
}))
vi.mock('@/services/litegraphService', () => ({
useLitegraphService: () => ({ resetView: vi.fn() })
}))
vi.mock('@/stores/commandStore', () => ({
useCommandStore: () => ({ execute: vi.fn() })
}))
vi.mock('@/stores/workspaceStore', () => ({
useWorkspaceStore: () => ({ focusMode: false })
}))
vi.mock('./api', () => ({
api: {
addEventListener: vi.fn(),
clearItems: vi.fn(),
deleteItem: vi.fn(),
dispatchCustomEvent: vi.fn(),
getHistory: vi.fn(),
getJobDetail: vi.fn(),
getQueue: vi.fn(),
interrupt: vi.fn()
}
}))
vi.mock('./app', () => ({
app: {
clean: vi.fn(),
handleFile: vi.fn(),
loadGraphData: vi.fn(),
openClipspace: vi.fn(),
queuePrompt: vi.fn(),
refreshComboInNodes: vi.fn(() => Promise.resolve())
},
ComfyApp: class ComfyApp {}
}))
vi.mock('./ui/dialog', () => ({
ComfyDialog: class ComfyDialog {}
}))
vi.mock('./ui/settings', () => ({
ComfySettingsDialog: class ComfySettingsDialog {}
}))
vi.mock('./ui/toggleSwitch', () => ({
toggleSwitch: () => document.createElement('div')
}))
import { extractWorkflow } from '@/platform/remote/comfyui/jobs/fetchJobs'
import { api } from './api'
import { app } from './app'
import { $el, ComfyUI } from './ui'
beforeEach(() => {
document.body.replaceChildren()
localStorage.clear()
vi.clearAllMocks()
mockSettingStore.get.mockReturnValue(undefined)
Object.assign(app, {
lastExecutionError: undefined,
nodeOutputs: undefined
})
})
async function click(button: HTMLButtonElement) {
const handler = button.onclick
expect(handler).toBeTypeOf('function')
await Promise.resolve(handler?.call(button, fromAny(new MouseEvent('click'))))
}
function buttonByText(root: ParentNode, text: string): HTMLButtonElement {
const button = [...root.querySelectorAll('button')].find(
(candidate) => candidate.textContent === text
)
if (!button) throw new Error(`Missing button: ${text}`)
return button
}
describe('$el', () => {
it('creates elements with classes, children, props, and callbacks', () => {
const parent = document.createElement('section')
const child = document.createElement('span')
const callback = vi.fn()
const element = $el(
'label.primary.secondary',
{
parent,
$: callback,
dataset: { role: 'name' },
for: 'target-input',
style: { display: 'block' },
title: 'Label'
},
child
)
expect(element.tagName).toBe('LABEL')
expect(element.classList.contains('primary')).toBe(true)
expect(element.classList.contains('secondary')).toBe(true)
expect(element.dataset.role).toBe('name')
expect(element.getAttribute('for')).toBe('target-input')
expect(element.style.display).toBe('block')
expect(element.title).toBe('Label')
expect(element.firstElementChild).toBe(child)
expect(parent.firstElementChild).toBe(element)
expect(callback).toHaveBeenCalledWith(element)
})
it('accepts string and single-element children shorthands', () => {
const textElement = $el('button', 'Run')
const child = document.createElement('strong')
const wrapper = $el('div', child)
expect(textElement.textContent).toBe('Run')
expect(wrapper.firstElementChild).toBe(child)
})
})
describe('ComfyUI legacy menu', () => {
it('loads queue items and runs list actions', async () => {
vi.mocked(api.getQueue).mockResolvedValue({
Running: [{ id: 'running', priority: 1 }],
Pending: [{ id: 'pending', priority: 2 }]
} as never)
vi.mocked(api.getJobDetail).mockResolvedValue({
outputs: { node: { images: ['image.png'] } }
} as never)
vi.mocked(extractWorkflow).mockResolvedValue({ nodes: [] } as never)
const ui = new ComfyUI(app)
await ui.queue.show()
await click(buttonByText(ui.queue.element, 'Load'))
expect(api.getJobDetail).toHaveBeenCalledWith('running')
expect(extractWorkflow).toHaveBeenCalled()
expect(app.loadGraphData).toHaveBeenCalledWith({ nodes: [] }, true, false)
expect(app.nodeOutputs).toEqual({ node: { images: ['image.png'] } })
await click(buttonByText(ui.queue.element, 'Cancel'))
expect(api.interrupt).toHaveBeenCalledWith('running')
await click(buttonByText(ui.queue.element, 'Delete'))
expect(api.deleteItem).toHaveBeenCalledWith('queue', 'pending')
await click(buttonByText(ui.queue.element, 'Clear Queue'))
expect(api.clearItems).toHaveBeenCalledWith('queue')
await click(buttonByText(ui.queue.element, 'Refresh'))
expect(api.getQueue).toHaveBeenCalled()
})
it('skips loading queue items when job details are unavailable', async () => {
vi.mocked(api.getQueue).mockResolvedValue({
Running: [{ id: 'running', priority: 1 }],
Pending: []
} as never)
vi.mocked(api.getJobDetail).mockResolvedValue(null as never)
const ui = new ComfyUI(app)
await ui.queue.show()
await click(buttonByText(ui.queue.element, 'Load'))
expect(extractWorkflow).not.toHaveBeenCalled()
expect(app.loadGraphData).not.toHaveBeenCalled()
})
it('loads queue item workflows without outputs', async () => {
vi.mocked(api.getQueue).mockResolvedValue({
Running: [{ id: 'running', priority: 1 }],
Pending: []
} as never)
vi.mocked(api.getJobDetail).mockResolvedValue({ id: 'running' } as never)
vi.mocked(extractWorkflow).mockResolvedValue({ nodes: [] } as never)
const ui = new ComfyUI(app)
await ui.queue.show()
await click(buttonByText(ui.queue.element, 'Load'))
expect(app.loadGraphData).toHaveBeenCalledWith({ nodes: [] }, true, false)
expect(app.nodeOutputs).toBeUndefined()
})
it('loads history in reverse order', async () => {
vi.mocked(api.getHistory).mockResolvedValue([
{ id: 'old', priority: 1 },
{ id: 'new', priority: 2 }
] as never)
const ui = new ComfyUI(app)
await ui.history.show()
expect(ui.history.element.textContent).toContain('2: LoadDelete')
expect(ui.history.element.textContent?.indexOf('2:')).toBeLessThan(
ui.history.element.textContent?.indexOf('1:') ?? Number.MAX_SAFE_INTEGER
)
})
it('updates queue status and auto-queues when enabled', () => {
const ui = new ComfyUI(app)
ui.autoQueueEnabled = true
ui.autoQueueMode = 'instant'
ui.lastQueueSize = 1
ui.batchCount = 3
ui.setStatus({ exec_info: { queue_remaining: 0 } })
expect(app.queuePrompt).toHaveBeenCalledWith(0, 3)
expect(ui.queueSize.textContent).toBe('Queue size: 0')
expect(ui.lastQueueSize).toBe(3)
ui.setStatus(null)
expect(ui.queueSize.textContent).toBe('Queue size: ERR')
})
it('does not auto-queue while a prior execution error is present', () => {
const ui = new ComfyUI(app)
ui.autoQueueEnabled = true
ui.autoQueueMode = 'instant'
ui.lastQueueSize = 1
Object.assign(app, { lastExecutionError: new Error('failed') })
ui.setStatus({ exec_info: { queue_remaining: 0 } })
expect(app.queuePrompt).not.toHaveBeenCalled()
expect(ui.lastQueueSize).toBe(0)
})
it('tracks graph changes for change-mode auto queueing', () => {
const ui = new ComfyUI(app)
const graphChanged = vi
.mocked(api.addEventListener)
.mock.calls.find(([eventName]) => eventName === 'graphChanged')?.[1]
if (!graphChanged) throw new Error('Missing graphChanged listener')
ui.autoQueueEnabled = true
ui.autoQueueMode = 'change'
ui.lastQueueSize = 1
graphChanged(fromAny(new CustomEvent('graphChanged')))
expect(ui.graphHasChanged).toBe(true)
ui.lastQueueSize = 0
graphChanged(fromAny(new CustomEvent('graphChanged')))
expect(app.queuePrompt).toHaveBeenCalledWith(0, 1)
expect(ui.graphHasChanged).toBe(false)
})
it('wires primary menu buttons to app and command actions', async () => {
const ui = new ComfyUI(app)
await click(buttonByText(document, 'Queue Prompt'))
expect(app.queuePrompt).toHaveBeenCalledWith(0, 1)
await click(buttonByText(document, 'Queue Front'))
expect(app.queuePrompt).toHaveBeenCalledWith(-1, 1)
await click(buttonByText(document, 'Save'))
await click(buttonByText(document, 'Save (API Format)'))
await click(buttonByText(document, 'Refresh'))
await click(buttonByText(document, 'Clipspace'))
await click(buttonByText(document, 'Clear'))
await click(buttonByText(document, 'Load Default'))
await click(buttonByText(document, 'Reset View'))
expect(app.refreshComboInNodes).toHaveBeenCalled()
expect(app.openClipspace).toHaveBeenCalled()
expect(app.clean).toHaveBeenCalled()
expect(app.loadGraphData).toHaveBeenCalledWith()
expect(ui.menuContainer.style.display).toBe('none')
})
it('wires file input and legacy option controls', async () => {
const ui = new ComfyUI(app)
const file = new File(['{}'], 'workflow.json', {
type: 'application/json'
})
const fileInput = document.getElementById(
'comfy-file-input'
) as HTMLInputElement
Object.defineProperty(fileInput, 'files', {
configurable: true,
value: [file]
})
await Promise.resolve(
fileInput.onchange?.call(fileInput, new Event('change'))
)
expect(app.handleFile).toHaveBeenCalledWith(file, 'file_button')
expect(fileInput.value).toBe('')
const range = document.getElementById(
'batchCountInputRange'
) as HTMLInputElement
const number = document.getElementById(
'batchCountInputNumber'
) as HTMLInputElement
range.value = '4'
const extraOptionsCheckbox = document.querySelector(
'label input[type="checkbox"]'
) as HTMLInputElement
extraOptionsCheckbox.checked = true
extraOptionsCheckbox.onchange?.call(
extraOptionsCheckbox,
fromPartial({ srcElement: extraOptionsCheckbox })
)
expect(ui.batchCount).toBe(4)
expect(document.getElementById('extraOptions')?.style.display).toBe('block')
number.value = '7'
number.oninput?.call(number, fromPartial({ target: number }))
expect(range.value).toBe('7')
range.value = '9'
range.oninput?.call(range, fromPartial({ srcElement: range }))
expect(number.value).toBe('9')
const autoQueueCheckbox = document.getElementById(
'autoQueueCheckbox'
) as HTMLInputElement
autoQueueCheckbox.checked = true
autoQueueCheckbox.onchange?.call(
autoQueueCheckbox,
fromPartial({ target: autoQueueCheckbox })
)
expect(ui.autoQueueEnabled).toBe(true)
extraOptionsCheckbox.checked = false
extraOptionsCheckbox.onchange?.call(
extraOptionsCheckbox,
fromPartial({ srcElement: extraOptionsCheckbox })
)
expect(ui.batchCount).toBe(1)
expect(ui.autoQueueEnabled).toBe(false)
expect(document.getElementById('extraOptions')?.style.display).toBe('none')
})
it('toggles queue visibility through the menu button', async () => {
vi.mocked(api.getQueue).mockResolvedValue({
Running: [],
Pending: []
} as never)
const ui = new ComfyUI(app)
await click(buttonByText(document, 'View Queue'))
expect(ui.queue.visible).toBe(true)
await click(buttonByText(document, 'Close'))
expect(ui.queue.visible).toBe(false)
})
it('does not clear or load defaults when confirmation is declined', async () => {
mockSettingStore.get.mockReturnValue(true)
vi.stubGlobal(
'confirm',
vi.fn(() => false)
)
try {
new ComfyUI(app)
await click(buttonByText(document, 'Clear'))
await click(buttonByText(document, 'Load Default'))
expect(app.clean).not.toHaveBeenCalled()
expect(app.loadGraphData).not.toHaveBeenCalled()
} finally {
vi.unstubAllGlobals()
}
})
it('persists manual menu dragging', () => {
Object.defineProperty(document.body, 'clientWidth', {
configurable: true,
value: 1000
})
Object.defineProperty(document.body, 'clientHeight', {
configurable: true,
value: 800
})
const ui = new ComfyUI(app)
ui.menuContainer.style.display = 'block'
Object.defineProperty(ui.menuContainer, 'clientWidth', {
configurable: true,
value: 100
})
Object.defineProperty(ui.menuContainer, 'clientHeight', {
configurable: true,
value: 80
})
Object.defineProperty(ui.menuContainer, 'offsetLeft', {
configurable: true,
get: () => 700
})
Object.defineProperty(ui.menuContainer, 'offsetTop', {
configurable: true,
get: () => 20
})
const handle = ui.menuContainer.querySelector('.drag-handle') as HTMLElement
handle.onmousedown?.(
new MouseEvent('mousedown', { clientX: 10, clientY: 10 })
)
document.onmousemove?.(
new MouseEvent('mousemove', { clientX: 20, clientY: 30 })
)
document.onmouseup?.(new MouseEvent('mouseup'))
expect(ui.menuContainer.classList.contains('comfy-menu-manual-pos')).toBe(
true
)
expect(ui.menuContainer.style.right).toBe('190px')
expect(localStorage.getItem('Comfy.MenuPosition')).toBe(
JSON.stringify({ x: 700, y: 20 })
)
expect(document.onmousemove).toBeNull()
expect(document.onmouseup).toBeNull()
})
})

View File

@@ -0,0 +1,199 @@
import { fromAny } from '@total-typescript/shoehorn'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import type { ComfyApp } from '@/scripts/app'
vi.mock('../../ui', () => ({
$el: (tag: string, props?: Record<string, unknown>, children?: Node[]) => {
const [tagName, ...classes] = tag.split('.')
const element = document.createElement(tagName)
if (classes.length) element.classList.add(...classes)
if (props) {
const listeners = Object.entries(props).filter(([key]) =>
key.startsWith('on')
)
for (const [key, listener] of listeners) {
if (typeof listener === 'function') {
element.addEventListener(key.slice(2), listener as EventListener)
}
}
}
if (children) element.append(...children)
return element
}
}))
vi.mock('../../utils', () => ({
prop: <T>(
target: object,
name: string,
defaultValue: T,
onChanged?: (currentValue: T, previousValue: T) => void
) => {
let currentValue: T
Object.defineProperty(target, name, {
get() {
return currentValue
},
set(newValue: T) {
const previousValue = currentValue
currentValue = newValue
onChanged?.(currentValue, previousValue)
}
})
;(target as Record<string, T>)[name] = defaultValue
return defaultValue
}
}))
import { ComfyButton } from './button'
class MockPopup extends EventTarget {
element = document.createElement('div')
open = false
toggle = vi.fn(() => {
this.open = !this.open
this.dispatchEvent(new CustomEvent('change'))
})
}
function mockApp(settingValue: boolean) {
let listener: (() => void) | undefined
const app = {
ui: {
settings: {
getSettingValue: vi.fn(() => settingValue),
addEventListener: vi.fn((_event: string, callback: () => void) => {
listener = callback
})
}
}
} as unknown as ComfyApp
return {
app,
setSettingValue(value: boolean) {
settingValue = value
listener?.()
}
}
}
describe('ComfyButton', () => {
beforeEach(() => {
document.body.replaceChildren()
})
it('renders icon, content, tooltip, enabled state, and click action', () => {
const action = vi.fn()
const button = new ComfyButton({
icon: 'play',
overIcon: 'pause',
iconSize: 18,
content: 'Run',
tooltip: 'Queue prompt',
enabled: false,
classList: { primary: true, hiddenClass: false },
action
})
expect(button.iconElement.className).toBe('mdi mdi-play mdi-18px')
expect(button.contentElement.textContent).toBe('Run')
expect(button.element.title).toBe('Queue prompt')
expect(button.element.getAttribute('aria-label')).toBe('Queue prompt')
expect(button.element.classList.contains('primary')).toBe(true)
expect(button.element.classList.contains('disabled')).toBe(true)
expect((button.element as HTMLButtonElement).disabled).toBe(true)
button.enabled = true
button.element.dispatchEvent(new MouseEvent('mouseenter'))
expect(button.iconElement.className).toBe('mdi mdi-pause mdi-18px')
button.element.dispatchEvent(new MouseEvent('mouseleave'))
expect(button.iconElement.className).toBe('mdi mdi-play mdi-18px')
button.element.dispatchEvent(new MouseEvent('click'))
expect(action).toHaveBeenCalledWith(expect.any(MouseEvent), button)
})
it('supports HTMLElement content and removing tooltip text', () => {
const button = new ComfyButton({ content: 'Text', tooltip: 'Hint' })
const content = document.createElement('strong')
content.textContent = 'Element'
button.content = content
button.tooltip = ''
expect(button.contentElement.firstElementChild).toBe(content)
expect(button.element.hasAttribute('title')).toBe(false)
})
it('updates the hover icon when overIcon changes while hovered', () => {
const button = new ComfyButton({ icon: 'play' })
button.element.dispatchEvent(new MouseEvent('mouseenter'))
button.overIcon = 'pause'
expect(button.iconElement.className).toBe('mdi mdi-pause')
})
it('hides and shows from a visibility setting', () => {
const settings = mockApp(false)
const button = new ComfyButton({
app: settings.app,
visibilitySetting: {
id: 'Comfy.UseNewMenu',
showValue: true
}
})
expect(button.hidden).toBe(true)
expect(button.element.classList.contains('hidden')).toBe(true)
settings.setSettingValue(true)
expect(button.hidden).toBe(false)
expect(button.element.classList.contains('hidden')).toBe(false)
})
it('toggles click popups and reflects popup open state in classes', () => {
const popup = new MockPopup()
const button = new ComfyButton({ icon: 'dots' }).withPopup(fromAny(popup))
button.element.dispatchEvent(new MouseEvent('click'))
expect(popup.toggle).toHaveBeenCalledOnce()
expect(button.element.classList.contains('popup-open')).toBe(true)
popup.toggle()
expect(button.element.classList.contains('popup-closed')).toBe(true)
})
it('opens hover popups while either the button or popup is hovered', () => {
const popup = new MockPopup()
const button = new ComfyButton({ icon: 'dots' }).withPopup(
fromAny(popup),
'hover'
)
button.element.dispatchEvent(new MouseEvent('mouseenter'))
expect(popup.open).toBe(true)
popup.element.dispatchEvent(new MouseEvent('mouseenter'))
button.element.dispatchEvent(new MouseEvent('mouseleave'))
expect(popup.open).toBe(true)
popup.element.dispatchEvent(new MouseEvent('mouseleave'))
expect(popup.open).toBe(false)
})
it('does not click-toggle a hover popup while hovered', () => {
const popup = new MockPopup()
const button = new ComfyButton({ icon: 'dots' }).withPopup(
fromAny(popup),
'hover'
)
button.element.dispatchEvent(new MouseEvent('mouseenter'))
button.element.dispatchEvent(new MouseEvent('click'))
expect(popup.toggle).not.toHaveBeenCalled()
})
})

View File

@@ -0,0 +1,284 @@
import { fromAny } from '@total-typescript/shoehorn'
import { beforeEach, describe, expect, it, vi } from 'vitest'
type ElChild = Node | string
type ElInput = Record<string, unknown> | ElChild | ElChild[]
function appendChildren(element: HTMLElement, children: ElChild | ElChild[]) {
const list = Array.isArray(children) ? children : [children]
for (const child of list) {
element.append(child)
}
}
vi.mock('../../ui', () => ({
$el: (tag: string, propsOrChildren?: ElInput, children?: ElChild[]) => {
const [tagName, ...classes] = tag.split('.')
const element = document.createElement(tagName)
if (classes.length) element.classList.add(...classes)
if (
propsOrChildren instanceof Node ||
typeof propsOrChildren === 'string' ||
Array.isArray(propsOrChildren)
) {
appendChildren(element, propsOrChildren)
} else if (propsOrChildren) {
for (const [key, value] of Object.entries(propsOrChildren)) {
if (key === '$' && typeof value === 'function') {
value(element)
} else if (key === 'parent' && value instanceof HTMLElement) {
value.append(element)
} else if (key === 'textContent') {
element.textContent = String(value)
} else if (key === 'ariaLabel') {
element.setAttribute('aria-label', String(value))
} else if (key === 'ariaHasPopup') {
element.setAttribute('aria-haspopup', String(value))
} else if (key === 'type') {
element.setAttribute('type', String(value))
} else if (
key.toLowerCase().startsWith('on') &&
typeof value === 'function'
) {
element.addEventListener(
key.slice(2).toLowerCase(),
value as EventListener
)
}
}
}
if (children) appendChildren(element, children)
return element
}
}))
vi.mock('../../utils', () => ({
prop: <T>(
target: object,
name: string,
defaultValue: T,
onChanged?: (currentValue: T, previousValue: T) => void
) => {
let currentValue: T
Object.defineProperty(target, name, {
get() {
return currentValue
},
set(newValue: T) {
const previousValue = currentValue
currentValue = newValue
onChanged?.(currentValue, previousValue)
}
})
;(target as Record<string, T>)[name] = defaultValue
return defaultValue
}
}))
vi.mock('../utils', () => ({
applyClasses: (
element: HTMLElement,
classList: string | Record<string, boolean>,
...baseClasses: string[]
) => {
element.className = baseClasses.join(' ')
if (typeof classList === 'string') {
element.classList.add(...classList.split(' ').filter(Boolean))
} else {
for (const [className, enabled] of Object.entries(classList)) {
element.classList.toggle(className, enabled)
}
}
},
toggleElement:
<T>(
element: HTMLElement,
{
onShow
}: {
onShow?: (element: HTMLElement, value: T) => void
} = {}
) =>
(value: T) => {
element.hidden = !value
if (value) onShow?.(element, value)
}
}))
import { ComfyAsyncDialog } from './asyncDialog'
import { ComfyButton } from './button'
import { ComfyButtonGroup } from './buttonGroup'
import { ComfyPopup } from './popup'
import { ComfySplitButton } from './splitButton'
function targetWithRect(rect: DOMRect) {
const target = document.createElement('button')
vi.spyOn(target, 'getBoundingClientRect').mockReturnValue(rect)
document.body.append(target)
return target
}
describe('ComfyPopup and related UI components', () => {
beforeEach(() => {
document.body.replaceChildren()
vi.restoreAllMocks()
})
it('opens, positions, updates children and classes, and closes from escape', () => {
const target = targetWithRect(
DOMRect.fromRect({ x: 10, y: 20, width: 80, height: 30 })
)
const child = document.createElement('span')
const popup = new ComfyPopup(
{ target, classList: { menu: true, hidden: false } },
child
)
const open = vi.fn()
const close = vi.fn()
const change = vi.fn()
popup.addEventListener('open', open)
popup.addEventListener('close', close)
popup.addEventListener('change', change)
vi.spyOn(popup.element, 'getBoundingClientRect').mockReturnValue(
DOMRect.fromRect({ height: 20 })
)
popup.open = true
expect(open).toHaveBeenCalledOnce()
expect(change).toHaveBeenCalledOnce()
expect(popup.element).toHaveClass('open')
expect(popup.element.style.getPropertyValue('--left')).toBe('10px')
expect(popup.element.style.getPropertyValue('--bottom')).toBe('35px')
const nextChild = document.createElement('strong')
popup.children = [nextChild]
popup.classList = 'extra'
expect(popup.element.firstElementChild).toBe(nextChild)
expect(popup.element).toHaveClass('comfyui-popup', 'left', 'extra')
window.dispatchEvent(new KeyboardEvent('keydown', { key: 'Enter' }))
expect(popup.open).toBe(true)
const escape = new KeyboardEvent('keydown', {
key: 'Escape',
cancelable: true
})
window.dispatchEvent(escape)
expect(close).toHaveBeenCalledOnce()
expect(escape.defaultPrevented).toBe(true)
expect(popup.open).toBe(false)
})
it('handles outside clicks, target clicks, and relative right positioning', () => {
const target = targetWithRect(
DOMRect.fromRect({ x: 100, y: 40, width: 60, height: 25 })
)
const container = document.createElement('section')
document.body.append(container)
const popup = new ComfyPopup({
target,
container,
position: 'relative',
horizontal: 'right'
})
vi.spyOn(popup.element, 'getBoundingClientRect').mockReturnValue(
DOMRect.fromRect({ height: 100 })
)
Object.defineProperty(popup.element, 'clientWidth', {
configurable: true,
value: 40
})
popup.open = true
target.dispatchEvent(new MouseEvent('click', { bubbles: true }))
expect(popup.open).toBe(true)
const outside = document.createElement('div')
document.body.append(outside)
outside.dispatchEvent(new MouseEvent('click', { bubbles: true }))
expect(popup.open).toBe(false)
expect(popup.element.style.getPropertyValue('--left')).toBe('0px')
expect(popup.element.style.getPropertyValue('--top')).toBe('25px')
})
it('keeps outside clicks open when target clicks are not ignored', () => {
const target = targetWithRect(DOMRect.fromRect({ height: 10 }))
const popup = new ComfyPopup({
target,
ignoreTarget: false,
closeOnEscape: false
})
const outside = document.createElement('div')
document.body.append(outside)
popup.toggle()
outside.dispatchEvent(new MouseEvent('click', { bubbles: true }))
window.dispatchEvent(new KeyboardEvent('keydown', { key: 'Escape' }))
expect(popup.open).toBe(true)
})
it('renders split button popup items and updates button groups', () => {
const primary = new ComfyButton({ content: 'Queue' })
const itemButton = new ComfyButton({ content: 'Queue front' })
const rawItem = document.createElement('button')
rawItem.textContent = 'Queue back'
const split = new ComfySplitButton(
{
primary,
mode: 'hover',
horizontal: 'right',
position: 'absolute'
},
itemButton,
rawItem
)
expect(split.element).toHaveClass('comfyui-split-button', 'hover')
expect(split.popup.element).toHaveTextContent('Queue front')
expect(split.popup.element).toHaveTextContent('Queue back')
expect(document.body).toContainElement(split.popup.element)
const group = new ComfyButtonGroup(primary)
group.append(itemButton)
group.insert(fromAny(rawItem), 1)
expect(group.element.children).toHaveLength(3)
expect(group.remove(itemButton)).toEqual([itemButton])
expect(group.remove(itemButton)).toBeUndefined()
expect(group.element.children).toHaveLength(2)
})
it('resolves async dialogs from buttons, close events, and prompt actions', async () => {
const dialog = new ComfyAsyncDialog<number>([
{ text: 'Seven', value: 7 },
'Fallback'
])
const promise = dialog.show('Pick one')
dialog.element.querySelector<HTMLButtonElement>('button')?.click()
await expect(promise).resolves.toBe(7)
const closePromise = dialog.show(document.createElement('em'))
dialog.element.dispatchEvent(new Event('close'))
await expect(closePromise).resolves.toBeNull()
const promptPromise = ComfyAsyncDialog.prompt({
title: 'Confirm',
message: 'Continue?',
actions: [{ text: 'Yes', value: 'yes' }]
})
const prompt = Array.from(document.querySelectorAll('dialog')).at(-1)
expect(prompt).toHaveTextContent('Confirm')
prompt?.querySelector<HTMLButtonElement>('button')?.click()
await expect(promptPromise).resolves.toBe('yes')
expect(document.body).not.toContainElement(prompt ?? null)
})
})

View File

@@ -0,0 +1,69 @@
import { afterEach, describe, expect, it, vi } from 'vitest'
import { ComfyDialog } from './dialog'
const mocks = vi.hoisted(() => ({
$el: (
selector: string,
propsOrChildren?: Record<string, unknown> | Node[],
maybeChildren?: Node[]
) => {
const [tag, ...classes] = selector.split('.')
const element = document.createElement(tag || 'div')
element.classList.add(...classes.filter(Boolean))
const children = Array.isArray(propsOrChildren)
? propsOrChildren
: maybeChildren
if (propsOrChildren && !Array.isArray(propsOrChildren)) {
for (const [key, value] of Object.entries(propsOrChildren)) {
if (key === 'parent' && value instanceof Node) {
value.appendChild(element)
} else if (key === '$' && typeof value === 'function') {
value(element)
} else {
Reflect.set(element, key, value)
}
}
}
element.append(...(children ?? []))
return element
}
}))
vi.mock('../ui', () => ({
$el: mocks.$el
}))
describe('ComfyDialog', () => {
afterEach(() => {
document.body.replaceChildren()
})
it('shows string and element content and closes through the default button', () => {
const dialog = new ComfyDialog()
dialog.show('<strong>Hello</strong>')
expect(dialog.element.style.display).toBe('flex')
expect(dialog.textElement.innerHTML).toBe('<strong>Hello</strong>')
dialog.element.querySelector('button')?.click()
expect(dialog.element.style.display).toBe('none')
const first = document.createElement('span')
const second = document.createElement('em')
dialog.show([first, second])
expect([...dialog.textElement.children]).toEqual([first, second])
})
it('uses supplied custom buttons', () => {
const button = document.createElement('button')
const dialog = new ComfyDialog('section', [button])
expect(dialog.element.tagName).toBe('SECTION')
expect(dialog.element.querySelector('button')).toBe(button)
})
})

View File

@@ -0,0 +1,216 @@
import { afterEach, describe, expect, it, vi } from 'vitest'
import { DraggableList } from './draggableList'
function createList(itemCount: number) {
const container = document.createElement('div')
const items = Array.from({ length: itemCount }, (_, index) => {
const item = document.createElement('div')
item.className = 'item'
item.dataset.index = String(index)
const handle = document.createElement('button')
handle.className = 'drag-handle'
item.append(handle)
container.append(item)
return item
})
document.body.append(container)
return { container, items }
}
function setRect(element: Element, top: number, height = 20) {
return vi
.spyOn(element, 'getBoundingClientRect')
.mockReturnValue(new DOMRect(0, top, 100, height))
}
function defineScrollMetrics(
container: HTMLElement,
scrollHeight: number,
clientHeight: number
) {
Object.defineProperty(container, 'scrollHeight', {
configurable: true,
value: scrollHeight
})
Object.defineProperty(container, 'clientHeight', {
configurable: true,
value: clientHeight
})
}
function mouseDragEvent(
target: Element,
overrides: Partial<MouseEvent> = {}
): MouseEvent {
return {
button: 0,
clientX: 0,
clientY: 0,
preventDefault: vi.fn(),
target,
...overrides
} satisfies Partial<MouseEvent> as MouseEvent
}
describe('DraggableList', () => {
afterEach(() => {
document.body.replaceChildren()
vi.restoreAllMocks()
vi.unstubAllGlobals()
})
it('ignores missing containers and non-primary drag starts', () => {
const listWithoutContainer = new DraggableList(null, '.item')
const { container, items } = createList(1)
const list = new DraggableList(container, '.item')
list.dragStart(
mouseDragEvent(items[0].querySelector('.drag-handle')!, { button: 1 })
)
list.dragEnd()
expect(listWithoutContainer.listContainer).toBeNull()
expect(list.draggableItem).toBeUndefined()
})
it('starts from a handle, scrolls downward, and reorders upward', () => {
vi.stubGlobal('requestAnimationFrame', (callback: FrameRequestCallback) => {
callback(0)
return 1
})
const { container, items } = createList(3)
const scrollBy = vi.fn((_left: number, top: number) => {
container.scrollTop += top
})
container.scrollBy = scrollBy as unknown as typeof container.scrollBy
defineScrollMetrics(container, 120, 80)
vi.spyOn(container, 'getBoundingClientRect').mockReturnValue(
new DOMRect(0, 0, 100, 80)
)
setRect(items[0], 0)
setRect(items[1], 30)
setRect(items[2], 60)
const list = new DraggableList(container, '.item')
const dragStart = vi.fn()
const dragEnd = vi.fn()
list.addEventListener('dragstart', dragStart)
list.addEventListener('dragend', dragEnd)
list.dragStart(
mouseDragEvent(items[2].querySelector('.drag-handle')!, {
clientX: 10,
clientY: 70
})
)
list.drag(
mouseDragEvent(items[2].querySelector('.drag-handle')!, {
clientX: 20,
clientY: 100
})
)
items[1].dataset.isToggled = ''
list.dragEnd()
expect(dragStart).toHaveBeenCalledOnce()
expect(dragEnd).toHaveBeenCalledOnce()
expect(scrollBy).toHaveBeenCalledWith(0, 10)
expect([...container.children]).toEqual([items[0], items[2], items[1]])
expect(items[0].classList.contains('is-idle')).toBe(true)
expect(items[1].classList.contains('is-idle')).toBe(true)
})
it('supports touch coordinates, upward scrolling, and downward reorder', () => {
vi.stubGlobal('requestAnimationFrame', (callback: FrameRequestCallback) => {
callback(0)
return 1
})
const { container, items } = createList(3)
const scrollBy = vi.fn((_left: number, top: number) => {
container.scrollTop += top
})
container.scrollTop = 10
container.scrollBy = scrollBy as unknown as typeof container.scrollBy
defineScrollMetrics(container, 120, 80)
vi.spyOn(container, 'getBoundingClientRect').mockReturnValue(
new DOMRect(0, 20, 100, 80)
)
setRect(items[0], 0)
setRect(items[1], 30)
setRect(items[2], 60)
const list = new DraggableList(container, '.item')
const touchStart = {
button: 0,
clientX: 0,
clientY: 0,
preventDefault: vi.fn(),
target: items[0].querySelector('.drag-handle')!,
touches: [{ clientX: 5, clientY: 30 }]
} as unknown as TouchEvent
const touchMove = {
clientX: 0,
clientY: 0,
preventDefault: vi.fn(),
target: items[0].querySelector('.drag-handle')!,
touches: [{ clientX: 8, clientY: 10 }]
} as unknown as TouchEvent
list.dragStart(touchStart)
list.drag(touchMove)
items[1].dataset.isToggled = ''
list.dragEnd()
expect(scrollBy).toHaveBeenCalledWith(0, -10)
expect([...container.children]).toEqual([items[1], items[0], items[2]])
})
it('updates idle item state around the dragged item midpoint', () => {
const { container, items } = createList(3)
const list = new DraggableList(container, '.item')
const state = list as unknown as {
items: HTMLElement[]
draggableItem: HTMLElement
}
state.items = items
state.draggableItem = items[1]
list.itemsGap = 5
items[0].classList.add('is-idle')
items[1].classList.add('is-idle')
items[2].classList.add('is-idle')
items[0].dataset.isAbove = ''
const draggedRect = setRect(items[1], -10)
setRect(items[0], 0)
setRect(items[2], 60)
list.updateIdleItemsStateAndPosition()
expect(items[0].dataset.isToggled).toBe('')
expect(items[0].style.transform).toBe('translateY(25px)')
expect(items[2].style.transform).toBe('')
draggedRect.mockReturnValue(new DOMRect(0, 100, 100, 20))
list.updateIdleItemsStateAndPosition()
expect(items[0].dataset.isToggled).toBeUndefined()
expect(items[2].dataset.isToggled).toBe('')
expect(items[2].style.transform).toBe('translateY(-25px)')
})
it('uses zero gap for short lists and disposes listeners', () => {
const { container } = createList(1)
const list = new DraggableList(container, '.item')
const off = vi.fn()
const disposableList = list as unknown as { off: Array<() => void> }
disposableList.off = [off]
list.setItemsGap()
list.dispose()
expect(list.itemsGap).toBe(0)
expect(off).toHaveBeenCalledOnce()
})
})

View File

@@ -0,0 +1,33 @@
import { describe, expect, it } from 'vitest'
import { calculateImageGrid } from './imagePreview'
function createImage(width: number, height: number) {
const img = document.createElement('img')
Object.defineProperty(img, 'naturalWidth', {
configurable: true,
value: width
})
Object.defineProperty(img, 'naturalHeight', {
configurable: true,
value: height
})
return img
}
describe('imagePreview', () => {
it('calculates the highest-area grid', () => {
const images = [
createImage(100, 100),
createImage(100, 100),
createImage(100, 100)
]
expect(calculateImageGrid(images, 300, 120)).toMatchObject({
cellWidth: 100,
cellHeight: 100,
cols: 3,
rows: 1
})
})
})

View File

@@ -0,0 +1,86 @@
import { afterEach, describe, expect, it, vi } from 'vitest'
import { toggleSwitch } from './toggleSwitch'
const mocks = vi.hoisted(() => ({
$el: (
selector: string,
propsOrChildren?: Record<string, unknown> | Node[] | Node,
maybeChildren?: Node[] | Node
) => {
const [tag, ...classes] = selector.split('.')
const element = document.createElement(tag || 'div')
element.classList.add(...classes.filter(Boolean))
const children = Array.isArray(propsOrChildren)
? propsOrChildren
: propsOrChildren instanceof Node
? [propsOrChildren]
: Array.isArray(maybeChildren)
? maybeChildren
: maybeChildren instanceof Node
? [maybeChildren]
: []
if (
propsOrChildren &&
!(propsOrChildren instanceof Node) &&
!Array.isArray(propsOrChildren)
) {
for (const [key, value] of Object.entries(propsOrChildren)) {
Reflect.set(element, key, value)
}
}
element.append(...children)
return element
}
}))
vi.mock('../ui', () => ({
$el: mocks.$el
}))
describe('toggleSwitch', () => {
afterEach(() => {
document.body.replaceChildren()
})
it('selects the first item when none is preselected', () => {
const onChange = vi.fn()
const container = toggleSwitch('mode', ['first', 'second'], { onChange })
const labels = [...container.querySelectorAll('label')]
const inputs = [...container.querySelectorAll('input')]
expect(labels[0].classList.contains('comfy-toggle-selected')).toBe(true)
expect((inputs[0] as HTMLInputElement).checked).toBe(true)
expect(onChange).toHaveBeenCalledWith({
item: 'first',
prev: undefined
})
})
it('moves selection and reports the previous item', () => {
const onChange = vi.fn()
const container = toggleSwitch(
'mode',
[
{ text: 'first', tooltip: 'First option' },
{ text: 'second', value: '2' }
],
{ onChange }
)
const labels = [...container.querySelectorAll('label')]
const secondInput = labels[1].querySelector('input') as HTMLInputElement
secondInput.onchange?.(new Event('change'))
expect(labels[0].classList.contains('comfy-toggle-selected')).toBe(false)
expect(labels[1].classList.contains('comfy-toggle-selected')).toBe(true)
expect(labels[0].title).toBe('First option')
expect(secondInput.value).toBe('2')
expect(onChange).toHaveBeenLastCalledWith({
item: { text: 'second', value: '2' },
prev: { text: 'first', tooltip: 'First option', value: 'first' }
})
})
})

View File

@@ -0,0 +1,45 @@
import { describe, expect, it, vi } from 'vitest'
import { applyClasses, toggleElement } from './utils'
describe('ui utils', () => {
it('applies string, array, object, and required classes', () => {
const element = document.createElement('div')
applyClasses(element, 'one two', 'required')
expect([...element.classList]).toEqual(['one', 'two', 'required'])
applyClasses(element, ['three', 'four'])
expect([...element.classList]).toEqual(['three', 'four'])
applyClasses(element, { five: true, six: false, seven: true })
expect([...element.classList]).toEqual(['five', 'seven'])
applyClasses(element, null as unknown as string)
expect(element.className).toBe('')
})
it('toggles an element through a placeholder', () => {
const parent = document.createElement('div')
const element = document.createElement('span')
const onHide = vi.fn()
const onShow = vi.fn()
parent.append(element)
const toggle = toggleElement(element, { onHide, onShow })
toggle(false)
expect(parent.firstChild).toBeInstanceOf(Comment)
expect(onHide).toHaveBeenCalledWith(element)
toggle(true)
expect(parent.firstChild).toBe(element)
expect(onShow).toHaveBeenCalledWith(element, true)
toggle('visible')
expect(onShow).toHaveBeenCalledWith(element, 'visible')
toggle(false)
expect(parent.firstChild).toBeInstanceOf(Comment)
expect(onHide).toHaveBeenCalledTimes(2)
})
})

127
src/scripts/utils.test.ts Normal file
View File

@@ -0,0 +1,127 @@
import { afterEach, describe, expect, it, vi } from 'vitest'
import {
addStylesheet,
clone,
getStorageValue,
prop,
setStorageValue
} from './utils'
interface LinkAttrs {
href: string
onerror: (error: Event) => void
onload: () => void
parent: HTMLElement
rel: string
type: string
}
const mocks = vi.hoisted(() => ({
api: {
clientId: null as string | null,
initialClientId: null as string | null
},
$el: vi.fn()
}))
vi.mock('./api', () => ({
api: mocks.api
}))
vi.mock('./ui', () => ({
$el: mocks.$el
}))
function lastLinkAttrs() {
return mocks.$el.mock.calls.at(-1)?.[1] as LinkAttrs
}
describe('scripts utils', () => {
afterEach(() => {
localStorage.clear()
sessionStorage.clear()
mocks.api.clientId = null
mocks.api.initialClientId = null
mocks.$el.mockReset()
vi.unstubAllGlobals()
})
it('clones with structuredClone and falls back to JSON cloning', () => {
const source = { nested: { value: 1 } }
expect(clone(source)).toEqual(source)
vi.stubGlobal(
'structuredClone',
vi.fn(() => {
throw new Error('unsupported')
})
)
const cloned = clone(source)
cloned.nested.value = 2
expect(cloned).toEqual({ nested: { value: 2 } })
expect(source).toEqual({ nested: { value: 1 } })
})
it('adds stylesheets from script and relative URLs', async () => {
const scriptPromise = addStylesheet('/extensions/example.js')
lastLinkAttrs().onload()
await expect(scriptPromise).resolves.toBeUndefined()
expect(lastLinkAttrs()).toMatchObject({
href: '/extensions/example.css',
parent: document.head,
rel: 'stylesheet',
type: 'text/css'
})
const cssPromise = addStylesheet('theme.css', 'https://example.com/base/')
lastLinkAttrs().onload()
await expect(cssPromise).resolves.toBeUndefined()
expect(lastLinkAttrs().href).toBe('https://example.com/base/theme.css')
})
it('rejects when stylesheet loading fails', async () => {
const promise = addStylesheet('missing.css', 'https://example.com/')
const error = new Event('error')
lastLinkAttrs().onerror(error)
await expect(promise).rejects.toBe(error)
})
it('defines an observable property with the supplied default', () => {
const target = {}
const onChanged = vi.fn()
expect(prop(target, 'mode', 'initial', onChanged)).toBe('initial')
Object.assign(target, { mode: 'next' })
expect((target as { mode: string }).mode).toBe('next')
expect(onChanged).toHaveBeenCalledWith('next', undefined, target, 'mode')
})
it('uses client-scoped storage before local fallback', () => {
mocks.api.clientId = 'client-1'
setStorageValue('setting', 'client-value')
sessionStorage.removeItem('setting:client-1')
expect(getStorageValue('setting')).toBe('client-value')
expect(localStorage.getItem('setting')).toBe('client-value')
sessionStorage.setItem('setting:client-1', 'session-value')
expect(getStorageValue('setting')).toBe('session-value')
})
it('uses initial client id when the current client id is unavailable', () => {
mocks.api.initialClientId = 'initial-1'
setStorageValue('setting', 'initial-value')
expect(sessionStorage.getItem('setting:initial-1')).toBe('initial-value')
expect(getStorageValue('setting')).toBe('initial-value')
})
})

Some files were not shown because too many files have changed in this diff Show More