mirror of
https://github.com/Comfy-Org/ComfyUI_frontend.git
synced 2026-07-03 13:48:49 +00:00
Compare commits
6 Commits
test/fixme
...
codex/cove
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7141bda563 | ||
|
|
74d4366994 | ||
|
|
c85c0585ab | ||
|
|
804630fa02 | ||
|
|
7417864353 | ||
|
|
1f26dc2b57 |
2
.github/workflows/ci-tests-e2e-coverage.yaml
vendored
2
.github/workflows/ci-tests-e2e-coverage.yaml
vendored
@@ -121,7 +121,7 @@ jobs:
|
||||
--title "ComfyUI E2E Coverage" \
|
||||
--no-function-coverage \
|
||||
--precision 1 \
|
||||
--ignore-errors source,unmapped \
|
||||
--ignore-errors source,unmapped,range \
|
||||
--synthesize-missing
|
||||
|
||||
- name: Upload HTML report artifact
|
||||
|
||||
1
.github/workflows/ci-tests-storybook.yaml
vendored
1
.github/workflows/ci-tests-storybook.yaml
vendored
@@ -95,7 +95,6 @@ jobs:
|
||||
if: |
|
||||
github.event_name == 'workflow_dispatch'
|
||||
|| (github.event_name == 'pull_request'
|
||||
&& github.event.pull_request.head.repo.fork == false
|
||||
&& startsWith(github.head_ref, 'version-bump-')
|
||||
&& (needs.changes.outputs.storybook-changes == 'true'
|
||||
|| needs.changes.outputs.app-frontend-changes == 'true'
|
||||
|
||||
@@ -30,7 +30,7 @@ concurrency:
|
||||
|
||||
jobs:
|
||||
deploy-preview:
|
||||
if: github.event_name == 'pull_request' && github.event.pull_request.head.repo.fork == false
|
||||
if: github.event_name == 'pull_request'
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
10
.github/workflows/ci-website-e2e.yaml
vendored
10
.github/workflows/ci-website-e2e.yaml
vendored
@@ -67,15 +67,7 @@ jobs:
|
||||
|
||||
- name: Deploy report to Cloudflare
|
||||
id: deploy
|
||||
if: >-
|
||||
${{
|
||||
always() &&
|
||||
!cancelled() &&
|
||||
(
|
||||
github.event_name != 'pull_request' ||
|
||||
github.event.pull_request.head.repo.fork == false
|
||||
)
|
||||
}}
|
||||
if: always() && !cancelled()
|
||||
env:
|
||||
CLOUDFLARE_API_TOKEN: ${{ secrets.CLOUDFLARE_API_TOKEN }}
|
||||
CLOUDFLARE_ACCOUNT_ID: ${{ secrets.CLOUDFLARE_ACCOUNT_ID }}
|
||||
|
||||
13
.github/workflows/cloud-dispatch-build.yaml
vendored
13
.github/workflows/cloud-dispatch-build.yaml
vendored
@@ -32,13 +32,12 @@ jobs:
|
||||
if: >
|
||||
github.repository == 'Comfy-Org/ComfyUI_frontend' &&
|
||||
(github.event_name != 'pull_request' ||
|
||||
(github.event.pull_request.head.repo.fork == false &&
|
||||
((github.event.action == 'labeled' &&
|
||||
contains(fromJSON('["preview","preview-cpu","preview-gpu"]'), github.event.label.name)) ||
|
||||
(github.event.action == 'synchronize' &&
|
||||
(contains(github.event.pull_request.labels.*.name, 'preview') ||
|
||||
contains(github.event.pull_request.labels.*.name, 'preview-cpu') ||
|
||||
contains(github.event.pull_request.labels.*.name, 'preview-gpu'))))))
|
||||
(github.event.action == 'labeled' &&
|
||||
contains(fromJSON('["preview","preview-cpu","preview-gpu"]'), github.event.label.name)) ||
|
||||
(github.event.action == 'synchronize' &&
|
||||
(contains(github.event.pull_request.labels.*.name, 'preview') ||
|
||||
contains(github.event.pull_request.labels.*.name, 'preview-cpu') ||
|
||||
contains(github.event.pull_request.labels.*.name, 'preview-gpu'))))
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Build client payload
|
||||
|
||||
@@ -21,7 +21,6 @@ jobs:
|
||||
# - Preview label specifically removed
|
||||
if: >
|
||||
github.repository == 'Comfy-Org/ComfyUI_frontend' &&
|
||||
github.event.pull_request.head.repo.fork == false &&
|
||||
((github.event.action == 'closed' &&
|
||||
(contains(github.event.pull_request.labels.*.name, 'preview') ||
|
||||
contains(github.event.pull_request.labels.*.name, 'preview-cpu') ||
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 59 KiB After Width: | Height: | Size: 59 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 58 KiB After Width: | Height: | Size: 58 KiB |
@@ -1,3 +1,3 @@
|
||||
<svg width="20" height="32" viewBox="0 0 20 32" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M20 32V0C20 5.39616 15.5172 9.78053 10 9.78053C4.48276 9.78053 0 5.416 0 0V32C0 26.6038 4.48276 22.2195 10 22.2195C15.5172 22.2195 20 26.6038 20 32Z" fill="#F2FF59"/>
|
||||
<svg preserveAspectRatio="none" width="100%" height="100%" overflow="visible" style="display: block;" viewBox="0 0 20 32" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path id="Vector" d="M20 32V0C20 5.39616 15.5172 9.78053 10 9.78053C4.48276 9.78053 0 5.416 0 0V32C0 26.6038 4.48276 22.2195 10 22.2195C15.5172 22.2195 20 26.6038 20 32Z" fill="var(--fill-0, #F2FF59)"/>
|
||||
</svg>
|
||||
|
||||
|
Before Width: | Height: | Size: 279 B After Width: | Height: | Size: 380 B |
@@ -28,12 +28,7 @@ const APP_URL = process.env.PLAYWRIGHT_TEST_URL || 'http://localhost:8188'
|
||||
// matches it against the members self-row.
|
||||
const SELF_EMAIL = 'e2e@test.comfy.org'
|
||||
|
||||
// consolidated_billing_enabled routes personal workspaces to the unified
|
||||
// pricing table asserted here; without it they fall back to the legacy table.
|
||||
const BOOT_FEATURES = {
|
||||
team_workspaces_enabled: true,
|
||||
consolidated_billing_enabled: true
|
||||
} satisfies RemoteConfig
|
||||
const BOOT_FEATURES = { team_workspaces_enabled: true } satisfies RemoteConfig
|
||||
// Disable the experimental Asset API: with it on (cloud default) the unmocked
|
||||
// asset endpoints 403 and workflow restore throws uncaught, aborting the
|
||||
// GraphCanvas onMounted chain before the deep-link loader.
|
||||
|
||||
@@ -101,8 +101,7 @@ test.describe('Shared workflow missing media', { tag: '@cloud' }, () => {
|
||||
})
|
||||
})
|
||||
|
||||
// FIXME: flaky (burst-fails all CI retries some runs; more frequent on cloud/1.45's heavier init). Needs de-flaking.
|
||||
test.fixme('imports shared media before loading workflow so missing media is not surfaced', async ({
|
||||
test('imports shared media before loading workflow so missing media is not surfaced', async ({
|
||||
comfyPage,
|
||||
sharedWorkflowImportMocks
|
||||
}) => {
|
||||
|
||||
@@ -166,8 +166,7 @@ test.describe('Assets sidebar - media type filter', { tag: '@cloud' }, () => {
|
||||
await expect(tab.getAssetCardByName(videoCardName)).toBeVisible()
|
||||
})
|
||||
|
||||
// FIXME: flaky (burst-fails all CI retries some runs; more frequent on cloud/1.45's heavier init). Needs de-flaking.
|
||||
test.fixme('Selecting only "Audio" hides non-audio assets', async ({
|
||||
test('Selecting only "Audio" hides non-audio assets', async ({
|
||||
comfyPage
|
||||
}) => {
|
||||
const tab = comfyPage.menu.assetsTab
|
||||
@@ -193,8 +192,7 @@ test.describe('Assets sidebar - media type filter', { tag: '@cloud' }, () => {
|
||||
await expect(tab.getAssetCardByName(threeDCardName)).toBeVisible()
|
||||
})
|
||||
|
||||
// FIXME: flaky (burst-fails all CI retries some runs; more frequent on cloud/1.45's heavier init). Needs de-flaking.
|
||||
test.fixme('Multiple filters combine via OR (image + video)', async ({
|
||||
test('Multiple filters combine via OR (image + video)', async ({
|
||||
comfyPage
|
||||
}) => {
|
||||
const tab = comfyPage.menu.assetsTab
|
||||
@@ -212,8 +210,7 @@ test.describe('Assets sidebar - media type filter', { tag: '@cloud' }, () => {
|
||||
await expect(tab.getAssetCardByName(threeDCardName)).toHaveCount(0)
|
||||
})
|
||||
|
||||
// FIXME: flaky (burst-fails all CI retries some runs; more frequent on cloud/1.45's heavier init). Needs de-flaking.
|
||||
test.fixme('Unchecking the active filter restores previously hidden cards', async ({
|
||||
test('Unchecking the active filter restores previously hidden cards', async ({
|
||||
comfyPage
|
||||
}) => {
|
||||
const tab = comfyPage.menu.assetsTab
|
||||
|
||||
@@ -70,8 +70,7 @@ test.describe(
|
||||
await comfyPage.assets.clearMocks()
|
||||
})
|
||||
|
||||
// FIXME: flaky (burst-fails all CI retries some runs; more frequent on cloud/1.45's heavier init). Needs de-flaking.
|
||||
test.fixme('renders one tile per unique composite key', async ({
|
||||
test('renders one tile per unique composite key', async ({
|
||||
comfyPage
|
||||
}, testInfo) => {
|
||||
const tab = comfyPage.menu.assetsTab
|
||||
|
||||
@@ -409,8 +409,7 @@ test.describe('Vue Node Moving', { tag: '@vue-nodes' }, () => {
|
||||
await expect.poll(getGroupPos).not.toEqual(initialGroupPos)
|
||||
})
|
||||
|
||||
// FIXME: flaky (burst-fails all CI retries some runs; more frequent on cloud/1.45's heavier init). Needs de-flaking.
|
||||
test.fixme(
|
||||
test(
|
||||
'@mobile should allow moving nodes by dragging on touch devices',
|
||||
{ tag: '@screenshot' },
|
||||
async ({ comfyPage }) => {
|
||||
|
||||
75
src/base/common/async.test.ts
Normal file
75
src/base/common/async.test.ts
Normal file
@@ -0,0 +1,75 @@
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
describe('runWhenGlobalIdle', () => {
|
||||
beforeEach(() => {
|
||||
vi.resetModules()
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
vi.useRealTimers()
|
||||
vi.unstubAllGlobals()
|
||||
})
|
||||
|
||||
it('falls back to a timeout when idle callbacks are unavailable', async () => {
|
||||
vi.useFakeTimers()
|
||||
vi.stubGlobal('requestIdleCallback', undefined)
|
||||
vi.stubGlobal('cancelIdleCallback', undefined)
|
||||
const { runWhenGlobalIdle } = await import('./async')
|
||||
const runner = vi.fn()
|
||||
|
||||
const disposable = runWhenGlobalIdle(runner)
|
||||
await vi.runAllTimersAsync()
|
||||
|
||||
expect(runner).toHaveBeenCalledOnce()
|
||||
const deadline = runner.mock.calls[0][0]
|
||||
expect(deadline.didTimeout).toBe(true)
|
||||
expect(deadline.timeRemaining()).toBeGreaterThanOrEqual(0)
|
||||
|
||||
disposable.dispose()
|
||||
disposable.dispose()
|
||||
})
|
||||
|
||||
it('cancels fallback idle work before it runs', async () => {
|
||||
vi.useFakeTimers()
|
||||
vi.stubGlobal('requestIdleCallback', undefined)
|
||||
vi.stubGlobal('cancelIdleCallback', undefined)
|
||||
const { runWhenGlobalIdle } = await import('./async')
|
||||
const runner = vi.fn()
|
||||
|
||||
runWhenGlobalIdle(runner).dispose()
|
||||
await vi.runAllTimersAsync()
|
||||
|
||||
expect(runner).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('uses native idle callbacks when available', async () => {
|
||||
const requestIdleCallback = vi.fn(() => 42)
|
||||
const cancelIdleCallback = vi.fn()
|
||||
vi.stubGlobal('requestIdleCallback', requestIdleCallback)
|
||||
vi.stubGlobal('cancelIdleCallback', cancelIdleCallback)
|
||||
const { runWhenGlobalIdle } = await import('./async')
|
||||
const runner = vi.fn()
|
||||
|
||||
const disposable = runWhenGlobalIdle(runner, 250)
|
||||
|
||||
expect(requestIdleCallback).toHaveBeenCalledWith(runner, { timeout: 250 })
|
||||
|
||||
disposable.dispose()
|
||||
disposable.dispose()
|
||||
|
||||
expect(cancelIdleCallback).toHaveBeenCalledOnce()
|
||||
expect(cancelIdleCallback).toHaveBeenCalledWith(42)
|
||||
})
|
||||
|
||||
it('omits native idle timeout options when no timeout is supplied', async () => {
|
||||
const requestIdleCallback = vi.fn(() => 7)
|
||||
vi.stubGlobal('requestIdleCallback', requestIdleCallback)
|
||||
vi.stubGlobal('cancelIdleCallback', vi.fn())
|
||||
const { runWhenGlobalIdle } = await import('./async')
|
||||
const runner = vi.fn()
|
||||
|
||||
runWhenGlobalIdle(runner)
|
||||
|
||||
expect(requestIdleCallback).toHaveBeenCalledWith(runner, undefined)
|
||||
})
|
||||
})
|
||||
@@ -122,6 +122,22 @@ describe('downloadUtil', () => {
|
||||
expect(createObjectURLSpy).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('throws for an empty URL', () => {
|
||||
expect(() => downloadFile('')).toThrow(
|
||||
'Invalid URL provided for download'
|
||||
)
|
||||
expect(fetchMock).not.toHaveBeenCalled()
|
||||
expect(createObjectURLSpy).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('throws for a whitespace URL', () => {
|
||||
expect(() => downloadFile(' ')).toThrow(
|
||||
'Invalid URL provided for download'
|
||||
)
|
||||
expect(fetchMock).not.toHaveBeenCalled()
|
||||
expect(createObjectURLSpy).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should prefer custom filename over extracted filename', () => {
|
||||
const testUrl =
|
||||
'https://example.com/api/file?filename=extracted-image.jpg'
|
||||
|
||||
@@ -4,6 +4,7 @@ import {
|
||||
CREDITS_PER_USD,
|
||||
COMFY_CREDIT_RATE_CENTS,
|
||||
centsToCredits,
|
||||
clampUsd,
|
||||
creditsToCents,
|
||||
creditsToUsd,
|
||||
formatCredits,
|
||||
@@ -43,4 +44,21 @@ describe('comfyCredits helpers', () => {
|
||||
expect(formatCreditsFromUsd({ usd: 1, locale })).toBe('211.00')
|
||||
expect(formatUsd({ value: 4.2, locale })).toBe('4.20')
|
||||
})
|
||||
|
||||
test('formats with compatible fraction digit bounds', () => {
|
||||
expect(
|
||||
formatCredits({
|
||||
value: 12.345,
|
||||
locale: 'en-US',
|
||||
numberOptions: { minimumFractionDigits: 4, maximumFractionDigits: 2 }
|
||||
})
|
||||
).toBe('12.35')
|
||||
})
|
||||
|
||||
test('clamps USD purchase values into the supported range', () => {
|
||||
expect(clampUsd(Number.NaN)).toBe(0)
|
||||
expect(clampUsd(-5)).toBe(1)
|
||||
expect(clampUsd(42)).toBe(42)
|
||||
expect(clampUsd(5000)).toBe(1000)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -158,8 +158,8 @@ import { creditsToUsd, usdToCredits } from '@/base/credits/comfyCredits'
|
||||
import Button from '@/components/ui/button/Button.vue'
|
||||
import FormattedNumberStepper from '@/components/ui/stepper/FormattedNumberStepper.vue'
|
||||
import { useAuthActions } from '@/composables/auth/useAuthActions'
|
||||
import { useBillingRouting } from '@/composables/billing/useBillingRouting'
|
||||
import { useExternalLink } from '@/composables/useExternalLink'
|
||||
import { useFeatureFlags } from '@/composables/useFeatureFlags'
|
||||
import { useSubscription } from '@/platform/cloud/subscription/composables/useSubscription'
|
||||
import { useTelemetry } from '@/platform/telemetry'
|
||||
import { clearTopupTracking } from '@/platform/telemetry/topupTracker'
|
||||
@@ -178,7 +178,7 @@ const settingsDialog = useSettingsDialog()
|
||||
const telemetry = useTelemetry()
|
||||
const toast = useToast()
|
||||
const { buildDocsUrl, docsPaths } = useExternalLink()
|
||||
const { shouldUseWorkspaceBilling } = useBillingRouting()
|
||||
const { flags } = useFeatureFlags()
|
||||
|
||||
const { isSubscriptionEnabled } = useSubscription()
|
||||
// Constants
|
||||
@@ -260,9 +260,9 @@ async function handleBuy() {
|
||||
// Close top-up dialog (keep tracking) and open credits panel to show updated balance
|
||||
handleClose(false)
|
||||
|
||||
// On the consolidated (workspace) billing flow, show the workspace settings
|
||||
// panel; otherwise show the legacy subscription/credits panel.
|
||||
const settingsPanel = shouldUseWorkspaceBilling.value
|
||||
// In workspace mode (personal workspace), show workspace settings panel
|
||||
// Otherwise, show legacy subscription/credits panel
|
||||
const settingsPanel = flags.teamWorkspacesEnabled
|
||||
? 'workspace'
|
||||
: isSubscriptionEnabled()
|
||||
? 'subscription'
|
||||
|
||||
@@ -2,11 +2,12 @@ import { createTestingPinia } from '@pinia/testing'
|
||||
import PrimeVue from 'primevue/config'
|
||||
import Tooltip from 'primevue/tooltip'
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { defineComponent, nextTick, onMounted, ref } from 'vue'
|
||||
import { defineComponent, onMounted, ref } from 'vue'
|
||||
import { createI18n } from 'vue-i18n'
|
||||
|
||||
import { render, screen, waitFor } from '@testing-library/vue'
|
||||
|
||||
import type * as DistributionTypes from '@/platform/distribution/types'
|
||||
import type { AuditLog } from '@/services/customerEventsService'
|
||||
import { EventType } from '@/services/customerEventsService'
|
||||
|
||||
@@ -34,29 +35,19 @@ vi.mock('@/services/customerEventsService', () => ({
|
||||
}
|
||||
}))
|
||||
|
||||
const mockTelemetry = vi.hoisted(() => ({
|
||||
checkForCompletedTopup: vi.fn()
|
||||
}))
|
||||
vi.mock('@/platform/telemetry', () => ({
|
||||
useTelemetry: () => mockTelemetry
|
||||
useTelemetry: () => null
|
||||
}))
|
||||
|
||||
const mockBillingRouting = vi.hoisted(() => ({
|
||||
shouldUseWorkspaceBilling: false
|
||||
const mockFlags = vi.hoisted(() => ({ teamWorkspacesEnabled: false }))
|
||||
vi.mock('@/composables/useFeatureFlags', () => ({
|
||||
useFeatureFlags: () => ({ flags: mockFlags })
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/distribution/types', async (importOriginal) => ({
|
||||
...(await importOriginal<typeof DistributionTypes>()),
|
||||
isCloud: true
|
||||
}))
|
||||
vi.mock('@/composables/billing/useBillingRouting', async () => {
|
||||
const { ref } = await import('vue')
|
||||
const shouldUseWorkspaceBilling = ref(false)
|
||||
Object.defineProperty(mockBillingRouting, 'shouldUseWorkspaceBilling', {
|
||||
get: () => shouldUseWorkspaceBilling.value,
|
||||
set: (value: boolean) => {
|
||||
shouldUseWorkspaceBilling.value = value
|
||||
}
|
||||
})
|
||||
return {
|
||||
useBillingRouting: () => ({ shouldUseWorkspaceBilling })
|
||||
}
|
||||
})
|
||||
|
||||
const mockWorkspaceApi = vi.hoisted(() => ({
|
||||
getBillingEvents: vi.fn()
|
||||
@@ -77,10 +68,7 @@ const i18n = createI18n({
|
||||
additionalInfo: 'Additional Info',
|
||||
added: 'Added',
|
||||
accountInitialized: 'Account initialized',
|
||||
model: 'Model',
|
||||
loadEventsError: 'Failed to load activity. Please try again.',
|
||||
loadEventsUnknownError:
|
||||
'Something went wrong while loading activity. Please refresh and try again.'
|
||||
model: 'Model'
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -107,11 +95,6 @@ const AutoRefreshWrapper = defineComponent({
|
||||
template: '<UsageLogsTable ref="tableRef" />'
|
||||
})
|
||||
|
||||
async function flushMicrotasks() {
|
||||
await new Promise((resolve) => setTimeout(resolve, 0))
|
||||
await nextTick()
|
||||
}
|
||||
|
||||
function makeEventsResponse(
|
||||
events: Partial<AuditLog>[],
|
||||
overrides: Record<string, unknown> = {}
|
||||
@@ -154,7 +137,7 @@ describe('UsageLogsTable', () => {
|
||||
|
||||
mockCustomerEventsService.getMyEvents.mockResolvedValue(mockEventsResponse)
|
||||
mockWorkspaceApi.getBillingEvents.mockResolvedValue(mockEventsResponse)
|
||||
mockBillingRouting.shouldUseWorkspaceBilling = false
|
||||
mockFlags.teamWorkspacesEnabled = false
|
||||
mockCustomerEventsService.formatEventType.mockImplementation(
|
||||
(type: string) => {
|
||||
switch (type) {
|
||||
@@ -245,7 +228,7 @@ describe('UsageLogsTable', () => {
|
||||
})
|
||||
})
|
||||
|
||||
it('shows a localized fallback instead of a raw Error message', async () => {
|
||||
it('shows error message when service throws', async () => {
|
||||
mockCustomerEventsService.getMyEvents.mockRejectedValue(
|
||||
new Error('Network error')
|
||||
)
|
||||
@@ -253,25 +236,7 @@ describe('UsageLogsTable', () => {
|
||||
renderWithAutoRefresh()
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText(
|
||||
'Something went wrong while loading activity. Please refresh and try again.'
|
||||
)
|
||||
).toBeInTheDocument()
|
||||
})
|
||||
expect(screen.queryByText('Network error')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('shows a localized fallback when the service reports no message', async () => {
|
||||
mockCustomerEventsService.getMyEvents.mockResolvedValue(null)
|
||||
mockCustomerEventsService.error.value = null
|
||||
|
||||
renderWithAutoRefresh()
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText('Failed to load activity. Please try again.')
|
||||
).toBeInTheDocument()
|
||||
expect(screen.getByText('Network error')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
@@ -376,8 +341,8 @@ describe('UsageLogsTable', () => {
|
||||
})
|
||||
|
||||
describe('billing events source', () => {
|
||||
it('uses workspaceApi.getBillingEvents on the workspace billing flow', async () => {
|
||||
mockBillingRouting.shouldUseWorkspaceBilling = true
|
||||
it('uses workspaceApi.getBillingEvents when teamWorkspacesEnabled is on', async () => {
|
||||
mockFlags.teamWorkspacesEnabled = true
|
||||
|
||||
await renderLoaded()
|
||||
|
||||
@@ -387,90 +352,6 @@ describe('UsageLogsTable', () => {
|
||||
})
|
||||
expect(mockCustomerEventsService.getMyEvents).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('discards a stale legacy response when routing flips mid-fetch', async () => {
|
||||
let resolveLegacy!: (value: ReturnType<typeof makeEventsResponse>) => void
|
||||
mockCustomerEventsService.getMyEvents.mockReturnValue(
|
||||
new Promise((resolve) => {
|
||||
resolveLegacy = resolve
|
||||
})
|
||||
)
|
||||
mockWorkspaceApi.getBillingEvents.mockResolvedValue(
|
||||
makeEventsResponse([
|
||||
{
|
||||
event_id: 'workspace-1',
|
||||
event_type: EventType.API_USAGE_COMPLETED,
|
||||
params: { api_name: 'WorkspaceAPI', model: 'workspace-model' },
|
||||
createdAt: '2024-02-01T10:00:00Z'
|
||||
}
|
||||
])
|
||||
)
|
||||
|
||||
renderWithAutoRefresh()
|
||||
|
||||
mockBillingRouting.shouldUseWorkspaceBilling = true
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('WorkspaceAPI')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
resolveLegacy(
|
||||
makeEventsResponse([
|
||||
{
|
||||
event_id: 'legacy-1',
|
||||
event_type: EventType.API_USAGE_COMPLETED,
|
||||
params: { api_name: 'LegacyAPI', model: 'legacy-model' },
|
||||
createdAt: '2024-01-01T10:00:00Z'
|
||||
}
|
||||
])
|
||||
)
|
||||
|
||||
await flushMicrotasks()
|
||||
|
||||
expect(screen.getByText('WorkspaceAPI')).toBeInTheDocument()
|
||||
expect(screen.queryByText('LegacyAPI')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('runs top-up completion telemetry for a superseded response', async () => {
|
||||
let resolveLegacy!: (value: ReturnType<typeof makeEventsResponse>) => void
|
||||
mockCustomerEventsService.getMyEvents.mockReturnValue(
|
||||
new Promise((resolve) => {
|
||||
resolveLegacy = resolve
|
||||
})
|
||||
)
|
||||
mockWorkspaceApi.getBillingEvents.mockResolvedValue(
|
||||
makeEventsResponse([
|
||||
{
|
||||
event_id: 'workspace-1',
|
||||
event_type: EventType.API_USAGE_COMPLETED,
|
||||
params: { api_name: 'WorkspaceAPI', model: 'workspace-model' },
|
||||
createdAt: '2024-02-01T10:00:00Z'
|
||||
}
|
||||
])
|
||||
)
|
||||
|
||||
renderWithAutoRefresh()
|
||||
|
||||
mockBillingRouting.shouldUseWorkspaceBilling = true
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('WorkspaceAPI')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
const legacyResponse = makeEventsResponse([
|
||||
{
|
||||
event_id: 'legacy-1',
|
||||
event_type: EventType.CREDIT_ADDED,
|
||||
params: { amount: 1000 },
|
||||
createdAt: '2024-01-01T10:00:00Z'
|
||||
}
|
||||
])
|
||||
resolveLegacy(legacyResponse)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockTelemetry.checkForCompletedTopup).toHaveBeenCalledWith(
|
||||
legacyResponse.events
|
||||
)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('EventType integration', () => {
|
||||
|
||||
@@ -96,11 +96,11 @@ import Column from 'primevue/column'
|
||||
import DataTable from 'primevue/datatable'
|
||||
import Message from 'primevue/message'
|
||||
import ProgressSpinner from 'primevue/progressspinner'
|
||||
import { computed, ref, watch } from 'vue'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
import { computed, ref } from 'vue'
|
||||
|
||||
import Button from '@/components/ui/button/Button.vue'
|
||||
import { useBillingRouting } from '@/composables/billing/useBillingRouting'
|
||||
import { useFeatureFlags } from '@/composables/useFeatureFlags'
|
||||
import { isCloud } from '@/platform/distribution/types'
|
||||
import { useTelemetry } from '@/platform/telemetry'
|
||||
import { workspaceApi } from '@/platform/workspace/api/workspaceApi'
|
||||
import type { AuditLog } from '@/services/customerEventsService'
|
||||
@@ -109,15 +109,14 @@ import {
|
||||
useCustomerEventsService
|
||||
} from '@/services/customerEventsService'
|
||||
|
||||
const { t } = useI18n()
|
||||
|
||||
const events = ref<AuditLog[]>([])
|
||||
const loading = ref(true)
|
||||
const error = ref<string | null>(null)
|
||||
|
||||
const customerEventService = useCustomerEventsService()
|
||||
|
||||
const { shouldUseWorkspaceBilling } = useBillingRouting()
|
||||
const { flags } = useFeatureFlags()
|
||||
const useBillingApi = computed(() => isCloud && flags.teamWorkspacesEnabled)
|
||||
|
||||
const pagination = ref({
|
||||
page: 1,
|
||||
@@ -140,12 +139,7 @@ const tooltipContentMap = computed(() => {
|
||||
return map
|
||||
})
|
||||
|
||||
// A billing-route flip can overlap two loads against different backends; only
|
||||
// the latest may mutate state, so a superseded response is discarded.
|
||||
let latestLoadToken = 0
|
||||
|
||||
const loadEvents = async () => {
|
||||
const loadToken = ++latestLoadToken
|
||||
loading.value = true
|
||||
error.value = null
|
||||
|
||||
@@ -154,17 +148,10 @@ const loadEvents = async () => {
|
||||
page: pagination.value.page,
|
||||
limit: pagination.value.limit
|
||||
}
|
||||
const response = shouldUseWorkspaceBilling.value
|
||||
const response = useBillingApi.value
|
||||
? await workspaceApi.getBillingEvents(params)
|
||||
: await customerEventService.getMyEvents(params)
|
||||
|
||||
// Completion telemetry must run even when a mid-checkout route flip
|
||||
// supersedes this load, since legacy and workspace backends emit different
|
||||
// top-up events and the winning fetch may not carry the completion yet.
|
||||
useTelemetry()?.checkForCompletedTopup(response?.events)
|
||||
|
||||
if (loadToken !== latestLoadToken) return
|
||||
|
||||
if (response) {
|
||||
if (response.events) {
|
||||
events.value = response.events
|
||||
@@ -178,25 +165,24 @@ const loadEvents = async () => {
|
||||
pagination.value.limit = response.limit
|
||||
}
|
||||
|
||||
if (response.total != null) {
|
||||
if (response.total) {
|
||||
pagination.value.total = response.total
|
||||
}
|
||||
|
||||
if (response.totalPages != null) {
|
||||
if (response.totalPages) {
|
||||
pagination.value.totalPages = response.totalPages
|
||||
}
|
||||
|
||||
// Check if a pending top-up has completed
|
||||
useTelemetry()?.checkForCompletedTopup(response.events)
|
||||
} else {
|
||||
const legacyError = shouldUseWorkspaceBilling.value
|
||||
? null
|
||||
: customerEventService.error.value
|
||||
error.value = legacyError || t('credits.loadEventsError')
|
||||
error.value = customerEventService.error.value || 'Failed to load events'
|
||||
}
|
||||
} catch (err) {
|
||||
if (loadToken !== latestLoadToken) return
|
||||
error.value = t('credits.loadEventsUnknownError')
|
||||
error.value = err instanceof Error ? err.message : 'Unknown error'
|
||||
console.error('Error loading events:', err)
|
||||
} finally {
|
||||
if (loadToken === latestLoadToken) loading.value = false
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -212,12 +198,6 @@ const refresh = async () => {
|
||||
await loadEvents()
|
||||
}
|
||||
|
||||
watch(shouldUseWorkspaceBilling, () => {
|
||||
refresh().catch((error) => {
|
||||
console.error('Error loading events:', error)
|
||||
})
|
||||
})
|
||||
|
||||
defineExpose({
|
||||
refresh
|
||||
})
|
||||
|
||||
@@ -42,34 +42,22 @@ function withStrictMillisecondParser<T>(run: () => T): T {
|
||||
}
|
||||
|
||||
const mockSubscription = vi.hoisted(() => ({
|
||||
value: null as {
|
||||
endDate: string | null
|
||||
duration?: 'ANNUAL' | 'MONTHLY' | null
|
||||
} | null
|
||||
value: null as { endDate: string | null } | null
|
||||
}))
|
||||
|
||||
const mockCancelSubscription = vi.hoisted(() => vi.fn())
|
||||
const mockFetchStatus = vi.hoisted(() => vi.fn())
|
||||
const mockCloseDialog = vi.hoisted(() => vi.fn())
|
||||
const mockToastAdd = vi.hoisted(() => vi.fn())
|
||||
const mockTier = vi.hoisted(() => ({ value: 'STANDARD' as string | null }))
|
||||
const mockTrackCancellation = vi.hoisted(() => vi.fn())
|
||||
|
||||
vi.mock('@/composables/billing/useBillingContext', () => ({
|
||||
useBillingContext: vi.fn(() => ({
|
||||
cancelSubscription: mockCancelSubscription,
|
||||
fetchStatus: mockFetchStatus,
|
||||
subscription: mockSubscription,
|
||||
tier: mockTier
|
||||
subscription: mockSubscription
|
||||
}))
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/telemetry', () => ({
|
||||
useTelemetry: () => ({
|
||||
trackSubscriptionCancellation: mockTrackCancellation
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('@/stores/dialogStore', () => ({
|
||||
useDialogStore: vi.fn(() => ({
|
||||
closeDialog: mockCloseDialog
|
||||
@@ -106,95 +94,6 @@ function renderComponent(props: { cancelAt?: string } = {}) {
|
||||
describe('CancelSubscriptionDialogContent', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockTier.value = 'STANDARD'
|
||||
})
|
||||
|
||||
describe('cancellation telemetry', () => {
|
||||
it('tracks flow_opened with tier and end date when the dialog mounts', () => {
|
||||
mockSubscription.value = { endDate: '2026-08-01T00:00:00.000Z' }
|
||||
|
||||
renderComponent()
|
||||
|
||||
expect(mockTrackCancellation).toHaveBeenCalledWith('flow_opened', {
|
||||
source: 'cancel_plan_menu',
|
||||
current_tier: 'standard',
|
||||
end_date: '2026-08-01T00:00:00.000Z'
|
||||
})
|
||||
})
|
||||
|
||||
it('tracks confirmed before the cancel request and no abandoned on success', async () => {
|
||||
mockSubscription.value = null
|
||||
mockCancelSubscription.mockResolvedValueOnce(undefined)
|
||||
|
||||
const { unmount } = renderComponent()
|
||||
await userEvent.click(
|
||||
screen.getByRole('button', { name: /^cancel subscription$/i })
|
||||
)
|
||||
|
||||
await waitFor(() => expect(mockCloseDialog).toHaveBeenCalled())
|
||||
unmount()
|
||||
expect(mockTrackCancellation).toHaveBeenCalledWith(
|
||||
'confirmed',
|
||||
expect.objectContaining({ current_tier: 'standard' })
|
||||
)
|
||||
expect(mockTrackCancellation).not.toHaveBeenCalledWith(
|
||||
'abandoned',
|
||||
expect.anything()
|
||||
)
|
||||
})
|
||||
|
||||
it('tracks confirmed and failed with message-carrying rejection values', async () => {
|
||||
mockSubscription.value = null
|
||||
mockCancelSubscription.mockRejectedValueOnce({ message: 'timed out' })
|
||||
|
||||
renderComponent()
|
||||
await userEvent.click(
|
||||
screen.getByRole('button', { name: /^cancel subscription$/i })
|
||||
)
|
||||
|
||||
await waitFor(() =>
|
||||
expect(mockTrackCancellation).toHaveBeenCalledWith(
|
||||
'failed',
|
||||
expect.objectContaining({ error_message: 'timed out' })
|
||||
)
|
||||
)
|
||||
expect(mockTrackCancellation).toHaveBeenCalledWith(
|
||||
'confirmed',
|
||||
expect.anything()
|
||||
)
|
||||
})
|
||||
|
||||
it('tracks abandoned when the user keeps the subscription', async () => {
|
||||
mockSubscription.value = null
|
||||
|
||||
const { unmount } = renderComponent()
|
||||
await userEvent.click(
|
||||
screen.getByRole('button', { name: /keep subscription/i })
|
||||
)
|
||||
|
||||
expect(mockCloseDialog).toHaveBeenCalledWith({
|
||||
key: 'cancel-subscription'
|
||||
})
|
||||
unmount()
|
||||
expect(mockTrackCancellation).toHaveBeenCalledWith(
|
||||
'abandoned',
|
||||
expect.objectContaining({ current_tier: 'standard' })
|
||||
)
|
||||
expect(mockCancelSubscription).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('tracks abandoned when the dialog is dismissed by the shell', () => {
|
||||
mockSubscription.value = null
|
||||
|
||||
const { unmount } = renderComponent()
|
||||
mockTrackCancellation.mockClear()
|
||||
unmount()
|
||||
|
||||
expect(mockTrackCancellation).toHaveBeenCalledWith(
|
||||
'abandoned',
|
||||
expect.objectContaining({ current_tier: 'standard' })
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('cancel flow', () => {
|
||||
@@ -239,35 +138,6 @@ describe('CancelSubscriptionDialogContent', () => {
|
||||
expect.objectContaining({ severity: 'success' })
|
||||
)
|
||||
})
|
||||
|
||||
it('does not track cancellation failure when status refresh fails after cancellation succeeds', async () => {
|
||||
mockSubscription.value = null
|
||||
mockCancelSubscription.mockResolvedValueOnce(undefined)
|
||||
mockFetchStatus.mockRejectedValueOnce(new Error('Refresh failed'))
|
||||
|
||||
const { unmount } = renderComponent()
|
||||
await userEvent.click(
|
||||
screen.getByRole('button', { name: /^cancel subscription$/i })
|
||||
)
|
||||
|
||||
await waitFor(() =>
|
||||
expect(mockToastAdd).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ severity: 'success' })
|
||||
)
|
||||
)
|
||||
expect(mockCloseDialog).toHaveBeenCalledWith({
|
||||
key: 'cancel-subscription'
|
||||
})
|
||||
expect(
|
||||
mockTrackCancellation.mock.calls.some(([stage]) => stage === 'failed')
|
||||
).toBe(false)
|
||||
|
||||
unmount()
|
||||
expect(mockTrackCancellation).not.toHaveBeenCalledWith(
|
||||
'abandoned',
|
||||
expect.anything()
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('formattedEndDate fallbacks', () => {
|
||||
|
||||
@@ -45,16 +45,13 @@
|
||||
|
||||
<script setup lang="ts">
|
||||
import { useToast } from 'primevue/usetoast'
|
||||
import { computed, onMounted, onUnmounted, ref } from 'vue'
|
||||
import { computed, ref } from 'vue'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
|
||||
import Button from '@/components/ui/button/Button.vue'
|
||||
import { useBillingContext } from '@/composables/billing/useBillingContext'
|
||||
import { useTelemetry } from '@/platform/telemetry'
|
||||
import type { SubscriptionCancellationMetadata } from '@/platform/telemetry/types'
|
||||
import { useDialogStore } from '@/stores/dialogStore'
|
||||
import { parseIsoDateSafe } from '@/utils/dateTimeUtil'
|
||||
import { getErrorMessage } from '@/utils/errorUtil'
|
||||
|
||||
const props = defineProps<{
|
||||
cancelAt?: string
|
||||
@@ -63,41 +60,9 @@ const props = defineProps<{
|
||||
const { t } = useI18n()
|
||||
const dialogStore = useDialogStore()
|
||||
const toast = useToast()
|
||||
const { cancelSubscription, fetchStatus, subscription, tier } =
|
||||
useBillingContext()
|
||||
const telemetry = useTelemetry()
|
||||
const { cancelSubscription, fetchStatus, subscription } = useBillingContext()
|
||||
|
||||
const isLoading = ref(false)
|
||||
const didCancelSucceed = ref(false)
|
||||
|
||||
function cancellationMetadata(): SubscriptionCancellationMetadata {
|
||||
const endDate = props.cancelAt ?? subscription.value?.endDate
|
||||
return {
|
||||
source: 'cancel_plan_menu' as const,
|
||||
current_tier: tier.value?.toLowerCase(),
|
||||
...(subscription.value?.duration
|
||||
? {
|
||||
cycle:
|
||||
subscription.value.duration === 'ANNUAL'
|
||||
? ('yearly' as const)
|
||||
: ('monthly' as const)
|
||||
}
|
||||
: {}),
|
||||
...(endDate ? { end_date: endDate } : {})
|
||||
}
|
||||
}
|
||||
|
||||
onMounted(() => {
|
||||
telemetry?.trackSubscriptionCancellation(
|
||||
'flow_opened',
|
||||
cancellationMetadata()
|
||||
)
|
||||
})
|
||||
|
||||
onUnmounted(() => {
|
||||
if (didCancelSucceed.value || isLoading.value) return
|
||||
telemetry?.trackSubscriptionCancellation('abandoned', cancellationMetadata())
|
||||
})
|
||||
|
||||
const formattedEndDate = computed(() => {
|
||||
const date = parseIsoDateSafe(props.cancelAt ?? subscription.value?.endDate)
|
||||
@@ -119,37 +84,24 @@ function onClose() {
|
||||
}
|
||||
|
||||
async function onConfirmCancel() {
|
||||
telemetry?.trackSubscriptionCancellation('confirmed', cancellationMetadata())
|
||||
isLoading.value = true
|
||||
try {
|
||||
await cancelSubscription()
|
||||
} catch (error) {
|
||||
const errorMessage = getErrorMessage(error)
|
||||
telemetry?.trackSubscriptionCancellation('failed', {
|
||||
...cancellationMetadata(),
|
||||
error_message: errorMessage ?? String(error)
|
||||
await fetchStatus()
|
||||
dialogStore.closeDialog({ key: 'cancel-subscription' })
|
||||
toast.add({
|
||||
severity: 'success',
|
||||
summary: t('subscription.cancelSuccess'),
|
||||
life: 5000
|
||||
})
|
||||
} catch (error) {
|
||||
toast.add({
|
||||
severity: 'error',
|
||||
summary: t('subscription.cancelDialog.failed'),
|
||||
detail: errorMessage ?? t('g.unknownError')
|
||||
detail: error instanceof Error ? error.message : t('g.unknownError')
|
||||
})
|
||||
} finally {
|
||||
isLoading.value = false
|
||||
return
|
||||
}
|
||||
|
||||
didCancelSucceed.value = true
|
||||
try {
|
||||
await fetchStatus()
|
||||
} catch {
|
||||
// Cancellation already succeeded; stale local subscription status should not report failure.
|
||||
}
|
||||
dialogStore.closeDialog({ key: 'cancel-subscription' })
|
||||
toast.add({
|
||||
severity: 'success',
|
||||
summary: t('subscription.cancelSuccess'),
|
||||
life: 5000
|
||||
})
|
||||
isLoading.value = false
|
||||
}
|
||||
</script>
|
||||
|
||||
@@ -19,7 +19,6 @@ const DEFAULT_BILLING_STATUS: BillingStatusResponse = {
|
||||
|
||||
const {
|
||||
mockTeamWorkspacesEnabled,
|
||||
mockConsolidatedBillingEnabled,
|
||||
mockIsPersonal,
|
||||
mockPlans,
|
||||
mockPurchaseCredits,
|
||||
@@ -27,7 +26,6 @@ const {
|
||||
mockBillingStatus
|
||||
} = vi.hoisted(() => ({
|
||||
mockTeamWorkspacesEnabled: { value: false },
|
||||
mockConsolidatedBillingEnabled: { value: false },
|
||||
mockIsPersonal: { value: true },
|
||||
mockPlans: { value: [] as Plan[] },
|
||||
mockPurchaseCredits: vi.fn(),
|
||||
@@ -59,23 +57,11 @@ vi.mock('@/composables/useFeatureFlags', async () => {
|
||||
teamWorkspacesEnabledRef.value = value
|
||||
}
|
||||
})
|
||||
const consolidatedBillingEnabledRef = ref(
|
||||
mockConsolidatedBillingEnabled.value
|
||||
)
|
||||
Object.defineProperty(mockConsolidatedBillingEnabled, 'value', {
|
||||
get: () => consolidatedBillingEnabledRef.value,
|
||||
set: (value: boolean) => {
|
||||
consolidatedBillingEnabledRef.value = value
|
||||
}
|
||||
})
|
||||
return {
|
||||
useFeatureFlags: () => ({
|
||||
flags: {
|
||||
get teamWorkspacesEnabled() {
|
||||
return mockTeamWorkspacesEnabled.value
|
||||
},
|
||||
get consolidatedBillingEnabled() {
|
||||
return mockConsolidatedBillingEnabled.value
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -165,7 +151,6 @@ describe('useBillingContext', () => {
|
||||
setActivePinia(createPinia())
|
||||
vi.clearAllMocks()
|
||||
mockTeamWorkspacesEnabled.value = false
|
||||
mockConsolidatedBillingEnabled.value = false
|
||||
mockIsPersonal.value = true
|
||||
mockPlans.value = []
|
||||
mockBillingStatus.value = { ...DEFAULT_BILLING_STATUS }
|
||||
@@ -177,27 +162,16 @@ describe('useBillingContext', () => {
|
||||
expect(type.value).toBe('legacy')
|
||||
})
|
||||
|
||||
it('keeps personal on legacy when consolidated billing is disabled', () => {
|
||||
it('selects workspace type for personal when team workspaces are enabled', () => {
|
||||
mockTeamWorkspacesEnabled.value = true
|
||||
mockConsolidatedBillingEnabled.value = false
|
||||
mockIsPersonal.value = true
|
||||
|
||||
const { type } = useBillingContext()
|
||||
expect(type.value).toBe('legacy')
|
||||
})
|
||||
|
||||
it('selects workspace type for personal when consolidated billing is enabled', () => {
|
||||
mockTeamWorkspacesEnabled.value = true
|
||||
mockConsolidatedBillingEnabled.value = true
|
||||
mockIsPersonal.value = true
|
||||
|
||||
const { type } = useBillingContext()
|
||||
expect(type.value).toBe('workspace')
|
||||
})
|
||||
|
||||
it('selects workspace type for team regardless of consolidated billing', () => {
|
||||
it('selects workspace type for team when team workspaces are enabled', () => {
|
||||
mockTeamWorkspacesEnabled.value = true
|
||||
mockConsolidatedBillingEnabled.value = false
|
||||
mockIsPersonal.value = false
|
||||
|
||||
const { type } = useBillingContext()
|
||||
@@ -298,7 +272,6 @@ describe('useBillingContext', () => {
|
||||
expect(workspaceApi.getBillingStatus).not.toHaveBeenCalled()
|
||||
|
||||
// Authenticated remote config resolves the flag on for the same workspace
|
||||
mockConsolidatedBillingEnabled.value = true
|
||||
mockTeamWorkspacesEnabled.value = true
|
||||
|
||||
await vi.waitFor(() => {
|
||||
@@ -307,27 +280,9 @@ describe('useBillingContext', () => {
|
||||
})
|
||||
})
|
||||
|
||||
it('moves a personal workspace to workspace billing when consolidated billing flips on', async () => {
|
||||
mockTeamWorkspacesEnabled.value = true
|
||||
mockConsolidatedBillingEnabled.value = false
|
||||
mockIsPersonal.value = true
|
||||
|
||||
const { type } = useBillingContext()
|
||||
await nextTick()
|
||||
expect(type.value).toBe('legacy')
|
||||
|
||||
mockConsolidatedBillingEnabled.value = true
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(type.value).toBe('workspace')
|
||||
expect(workspaceApi.getBillingStatus).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('subscription mirror to workspace store', () => {
|
||||
it('mirrors subscription for personal workspaces on the consolidated billing flow', async () => {
|
||||
it('mirrors subscription for personal workspaces when team workspaces are enabled', async () => {
|
||||
mockTeamWorkspacesEnabled.value = true
|
||||
mockConsolidatedBillingEnabled.value = true
|
||||
mockIsPersonal.value = true
|
||||
|
||||
const { initialize } = useBillingContext()
|
||||
@@ -339,20 +294,6 @@ describe('useBillingContext', () => {
|
||||
subscriptionPlan: null
|
||||
})
|
||||
})
|
||||
|
||||
it('never clobbers the list-derived store when a subscription is absent', async () => {
|
||||
mockTeamWorkspacesEnabled.value = true
|
||||
mockIsPersonal.value = false
|
||||
|
||||
const { initialize } = useBillingContext()
|
||||
await initialize()
|
||||
await nextTick()
|
||||
|
||||
expect(mockUpdateActiveWorkspace).not.toHaveBeenCalledWith({
|
||||
isSubscribed: false,
|
||||
subscriptionPlan: null
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('getMaxSeats', () => {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import { computed, ref, shallowRef, toValue, watch } from 'vue'
|
||||
import { createSharedComposable } from '@vueuse/core'
|
||||
|
||||
import { useFeatureFlags } from '@/composables/useFeatureFlags'
|
||||
import {
|
||||
KEY_TO_TIER,
|
||||
getTierFeatures
|
||||
@@ -17,10 +18,10 @@ import type {
|
||||
BalanceInfo,
|
||||
BillingActions,
|
||||
BillingContext,
|
||||
BillingType,
|
||||
BillingState,
|
||||
SubscriptionInfo
|
||||
} from './types'
|
||||
import { useBillingRouting } from './useBillingRouting'
|
||||
import { useLegacyBilling } from './useLegacyBilling'
|
||||
import { useWorkspaceBilling } from '@/platform/workspace/composables/useWorkspaceBilling'
|
||||
|
||||
@@ -34,9 +35,8 @@ const LEGACY_TEAM_PLAN_SLUG_PREFIX = 'team-'
|
||||
* Unified billing context that selects the billing implementation by build/flag.
|
||||
*
|
||||
* - Team workspaces disabled (OSS/Desktop): legacy billing via /customers/*
|
||||
* - Team workspaces enabled: workspace billing via /api/billing/* for team
|
||||
* workspaces, and for personal workspaces once consolidated billing is
|
||||
* enabled; personal workspaces otherwise stay on legacy billing
|
||||
* - Team workspaces enabled: workspace billing via /api/billing/* for both
|
||||
* personal (single-seat workspace) and team workspaces
|
||||
*
|
||||
* The context automatically initializes when the workspace changes and provides
|
||||
* a unified interface for subscription status, balance, and billing actions.
|
||||
@@ -69,7 +69,7 @@ const LEGACY_TEAM_PLAN_SLUG_PREFIX = 'team-'
|
||||
*/
|
||||
function useBillingContextInternal(): BillingContext {
|
||||
const store = useTeamWorkspaceStore()
|
||||
const { type } = useBillingRouting()
|
||||
const { flags } = useFeatureFlags()
|
||||
|
||||
const legacyBillingRef = shallowRef<(BillingState & BillingActions) | null>(
|
||||
null
|
||||
@@ -96,6 +96,16 @@ function useBillingContextInternal(): BillingContext {
|
||||
const isLoading = ref(false)
|
||||
const error = ref<string | null>(null)
|
||||
|
||||
/**
|
||||
* Determines which billing type to use, keyed only on the build/flag:
|
||||
* - Team workspaces feature disabled (OSS/Desktop): legacy (/customers)
|
||||
* - Team workspaces feature enabled: workspace (/api/billing), for both
|
||||
* personal (single-seat workspace) and team workspaces
|
||||
*/
|
||||
const type = computed<BillingType>(() =>
|
||||
flags.teamWorkspacesEnabled ? 'workspace' : 'legacy'
|
||||
)
|
||||
|
||||
const activeContext = computed(() =>
|
||||
type.value === 'legacy' ? getLegacyBilling() : getWorkspaceBilling()
|
||||
)
|
||||
@@ -160,12 +170,9 @@ function useBillingContextInternal(): BillingContext {
|
||||
return plan?.max_seats ?? getTierFeatures(tierKey).maxMembers
|
||||
}
|
||||
|
||||
// Sync subscription info to workspace store for display in workspace switcher.
|
||||
// Subscribed means active AND not cancelled, so the delete button enables
|
||||
// after cancellation, even before the period ends. A null subscription means
|
||||
// "not loaded yet" (adapters are discarded on every workspace/type switch);
|
||||
// skip it so the transient reinit gap can't clobber the list-derived baseline
|
||||
// (personal workspaces and subscribed teams already read subscribed there).
|
||||
// Sync subscription info to workspace store for display in workspace switcher
|
||||
// A subscription is considered "subscribed" for workspace purposes if it's active AND not cancelled
|
||||
// This ensures the delete button is enabled after cancellation, even before the period ends
|
||||
watch(
|
||||
subscription,
|
||||
(sub) => {
|
||||
@@ -179,27 +186,24 @@ function useBillingContextInternal(): BillingContext {
|
||||
{ immediate: true }
|
||||
)
|
||||
|
||||
// Discarding the adapter instances forces a fresh fetch and lets an in-flight
|
||||
// init detect that it was superseded (its captured adapter is no longer the
|
||||
// active one), so a stale response can't resolve into a ready state for the
|
||||
// wrong workspace.
|
||||
function resetBillingState() {
|
||||
legacyBillingRef.value = null
|
||||
workspaceBillingRef.value = null
|
||||
isInitialized.value = false
|
||||
isLoading.value = false
|
||||
error.value = null
|
||||
}
|
||||
|
||||
// type flips when the team-workspaces or consolidated-billing flag resolves
|
||||
// from authenticated config, swapping the active backend. Reset then reinit
|
||||
// on every workspace-id or type change.
|
||||
// type can flip after setup when the team-workspaces flag resolves from
|
||||
// authenticated config, swapping the active backend; a fresh init is needed.
|
||||
// The watch fires only when id or type actually changes, so any fire with a
|
||||
// workspace selected warrants a reinit.
|
||||
watch(
|
||||
[() => store.activeWorkspace?.id, () => type.value],
|
||||
async ([newWorkspaceId]) => {
|
||||
resetBillingState()
|
||||
if (!newWorkspaceId) return
|
||||
if (!newWorkspaceId) {
|
||||
resetBillingState()
|
||||
return
|
||||
}
|
||||
|
||||
isInitialized.value = false
|
||||
try {
|
||||
await initialize()
|
||||
} catch (err) {
|
||||
@@ -212,20 +216,17 @@ function useBillingContextInternal(): BillingContext {
|
||||
async function initialize(): Promise<void> {
|
||||
if (isInitialized.value) return
|
||||
|
||||
const adapter = activeContext.value
|
||||
isLoading.value = true
|
||||
error.value = null
|
||||
try {
|
||||
await adapter.initialize()
|
||||
if (activeContext.value !== adapter) return
|
||||
await activeContext.value.initialize()
|
||||
isInitialized.value = true
|
||||
} catch (err) {
|
||||
if (activeContext.value !== adapter) return
|
||||
error.value =
|
||||
err instanceof Error ? err.message : 'Failed to initialize billing'
|
||||
throw err
|
||||
} finally {
|
||||
if (activeContext.value === adapter) isLoading.value = false
|
||||
isLoading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,99 +0,0 @@
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { useBillingRouting } from './useBillingRouting'
|
||||
|
||||
const { mockFlags, mockActiveWorkspace } = vi.hoisted(() => ({
|
||||
mockFlags: {
|
||||
teamWorkspacesEnabled: false,
|
||||
consolidatedBillingEnabled: false
|
||||
},
|
||||
mockActiveWorkspace: {
|
||||
value: null as { id: string; type: 'personal' | 'team' } | null
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('@/composables/useFeatureFlags', () => ({
|
||||
useFeatureFlags: () => ({ flags: mockFlags })
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/workspace/stores/teamWorkspaceStore', () => ({
|
||||
useTeamWorkspaceStore: () => ({
|
||||
get activeWorkspace() {
|
||||
return mockActiveWorkspace.value
|
||||
}
|
||||
})
|
||||
}))
|
||||
|
||||
const personal = { id: 'w-personal', type: 'personal' as const }
|
||||
const team = { id: 'w-team', type: 'team' as const }
|
||||
|
||||
describe('useBillingRouting', () => {
|
||||
beforeEach(() => {
|
||||
mockFlags.teamWorkspacesEnabled = false
|
||||
mockFlags.consolidatedBillingEnabled = false
|
||||
mockActiveWorkspace.value = personal
|
||||
})
|
||||
|
||||
it('uses legacy billing when team workspaces are disabled', () => {
|
||||
mockFlags.teamWorkspacesEnabled = false
|
||||
mockActiveWorkspace.value = team
|
||||
|
||||
const { type, shouldUseWorkspaceBilling } = useBillingRouting()
|
||||
|
||||
expect(type.value).toBe('legacy')
|
||||
expect(shouldUseWorkspaceBilling.value).toBe(false)
|
||||
})
|
||||
|
||||
it('keeps personal on legacy when consolidated billing is disabled', () => {
|
||||
mockFlags.teamWorkspacesEnabled = true
|
||||
mockFlags.consolidatedBillingEnabled = false
|
||||
mockActiveWorkspace.value = personal
|
||||
|
||||
const { type } = useBillingRouting()
|
||||
|
||||
expect(type.value).toBe('legacy')
|
||||
})
|
||||
|
||||
it('moves personal to workspace billing when consolidated billing is enabled', () => {
|
||||
mockFlags.teamWorkspacesEnabled = true
|
||||
mockFlags.consolidatedBillingEnabled = true
|
||||
mockActiveWorkspace.value = personal
|
||||
|
||||
const { type, shouldUseWorkspaceBilling } = useBillingRouting()
|
||||
|
||||
expect(type.value).toBe('workspace')
|
||||
expect(shouldUseWorkspaceBilling.value).toBe(true)
|
||||
})
|
||||
|
||||
it('uses workspace billing for team workspaces regardless of consolidated billing', () => {
|
||||
mockFlags.teamWorkspacesEnabled = true
|
||||
mockFlags.consolidatedBillingEnabled = false
|
||||
mockActiveWorkspace.value = team
|
||||
|
||||
const { type, shouldUseWorkspaceBilling } = useBillingRouting()
|
||||
|
||||
expect(type.value).toBe('workspace')
|
||||
expect(shouldUseWorkspaceBilling.value).toBe(true)
|
||||
})
|
||||
|
||||
it('uses workspace billing for team workspaces with consolidated billing enabled', () => {
|
||||
mockFlags.teamWorkspacesEnabled = true
|
||||
mockFlags.consolidatedBillingEnabled = true
|
||||
mockActiveWorkspace.value = team
|
||||
|
||||
const { type, shouldUseWorkspaceBilling } = useBillingRouting()
|
||||
|
||||
expect(type.value).toBe('workspace')
|
||||
expect(shouldUseWorkspaceBilling.value).toBe(true)
|
||||
})
|
||||
|
||||
it('defaults to legacy while the workspace has not loaded', () => {
|
||||
mockFlags.teamWorkspacesEnabled = true
|
||||
mockFlags.consolidatedBillingEnabled = true
|
||||
mockActiveWorkspace.value = null
|
||||
|
||||
const { type } = useBillingRouting()
|
||||
|
||||
expect(type.value).toBe('legacy')
|
||||
})
|
||||
})
|
||||
@@ -1,36 +0,0 @@
|
||||
import { computed } from 'vue'
|
||||
|
||||
import { useFeatureFlags } from '@/composables/useFeatureFlags'
|
||||
import { useTeamWorkspaceStore } from '@/platform/workspace/stores/teamWorkspaceStore'
|
||||
|
||||
import type { BillingType } from './types'
|
||||
|
||||
/**
|
||||
* Selects the billing backend for the active workspace: legacy user-scoped
|
||||
* (`/customers/*`) or workspace-scoped (`/api/billing/*`). Personal workspaces
|
||||
* stay legacy until `consolidatedBillingEnabled`; team workspaces are always
|
||||
* workspace-scoped. The routing matrix is covered in useBillingRouting.test.ts.
|
||||
*/
|
||||
export function useBillingRouting() {
|
||||
const { flags } = useFeatureFlags()
|
||||
const workspaceStore = useTeamWorkspaceStore()
|
||||
|
||||
const type = computed<BillingType>(() => {
|
||||
if (!flags.teamWorkspacesEnabled) return 'legacy'
|
||||
|
||||
// An unloaded workspace has no type yet; stay legacy so bootstrap never
|
||||
// eagerly routes to workspace billing.
|
||||
const workspaceType = workspaceStore.activeWorkspace?.type
|
||||
if (!workspaceType) return 'legacy'
|
||||
|
||||
if (workspaceType === 'personal' && !flags.consolidatedBillingEnabled) {
|
||||
return 'legacy'
|
||||
}
|
||||
|
||||
return 'workspace'
|
||||
})
|
||||
|
||||
const shouldUseWorkspaceBilling = computed(() => type.value === 'workspace')
|
||||
|
||||
return { type, shouldUseWorkspaceBilling }
|
||||
}
|
||||
@@ -6,12 +6,6 @@ import {
|
||||
useFeatureFlags
|
||||
} from '@/composables/useFeatureFlags'
|
||||
import * as distributionTypes from '@/platform/distribution/types'
|
||||
import {
|
||||
cachedConsolidatedBillingEnabled,
|
||||
cachedTeamWorkspacesEnabled,
|
||||
remoteConfig,
|
||||
remoteConfigState
|
||||
} from '@/platform/remoteConfig/remoteConfig'
|
||||
import { api } from '@/scripts/api'
|
||||
|
||||
// Mock the API module
|
||||
@@ -225,86 +219,6 @@ describe('useFeatureFlags', () => {
|
||||
const { flags } = useFeatureFlags()
|
||||
expect(flags.teamWorkspacesEnabled).toBe(true)
|
||||
})
|
||||
|
||||
it('consolidatedBillingEnabled override bypasses isCloud and isAuthenticatedConfigLoaded guards', () => {
|
||||
vi.mocked(distributionTypes).isCloud = false
|
||||
localStorage.setItem('ff:consolidated_billing_enabled', 'true')
|
||||
|
||||
const { flags } = useFeatureFlags()
|
||||
expect(flags.consolidatedBillingEnabled).toBe(true)
|
||||
})
|
||||
|
||||
it('consolidatedBillingEnabled is false off-cloud even without an override', () => {
|
||||
vi.mocked(distributionTypes).isCloud = false
|
||||
|
||||
const { flags } = useFeatureFlags()
|
||||
expect(flags.consolidatedBillingEnabled).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('auth-gated flags on cloud', () => {
|
||||
beforeEach(() => {
|
||||
vi.mocked(distributionTypes).isCloud = true
|
||||
remoteConfigState.value = 'unloaded'
|
||||
remoteConfig.value = {}
|
||||
cachedTeamWorkspacesEnabled.value = undefined
|
||||
cachedConsolidatedBillingEnabled.value = undefined
|
||||
localStorage.clear()
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
vi.mocked(distributionTypes).isCloud = false
|
||||
remoteConfigState.value = 'unloaded'
|
||||
remoteConfig.value = {}
|
||||
cachedTeamWorkspacesEnabled.value = undefined
|
||||
cachedConsolidatedBillingEnabled.value = undefined
|
||||
localStorage.clear()
|
||||
})
|
||||
|
||||
it('returns the cached session value during the auth window', () => {
|
||||
cachedTeamWorkspacesEnabled.value = false
|
||||
cachedConsolidatedBillingEnabled.value = true
|
||||
|
||||
const { flags } = useFeatureFlags()
|
||||
expect(flags.teamWorkspacesEnabled).toBe(false)
|
||||
expect(flags.consolidatedBillingEnabled).toBe(true)
|
||||
})
|
||||
|
||||
it('defaults to false during the auth window when nothing is cached', () => {
|
||||
const { flags } = useFeatureFlags()
|
||||
expect(flags.teamWorkspacesEnabled).toBe(false)
|
||||
expect(flags.consolidatedBillingEnabled).toBe(false)
|
||||
})
|
||||
|
||||
it('prefers authenticated remoteConfig over the server feature fallback', () => {
|
||||
remoteConfigState.value = 'authenticated'
|
||||
remoteConfig.value = {
|
||||
team_workspaces_enabled: true,
|
||||
consolidated_billing_enabled: true
|
||||
}
|
||||
vi.mocked(api.getServerFeature).mockReturnValue(false)
|
||||
|
||||
const { flags } = useFeatureFlags()
|
||||
expect(flags.teamWorkspacesEnabled).toBe(true)
|
||||
expect(flags.consolidatedBillingEnabled).toBe(true)
|
||||
})
|
||||
|
||||
it('falls back to api.getServerFeature when authenticated config omits the flag', () => {
|
||||
remoteConfigState.value = 'authenticated'
|
||||
remoteConfig.value = {}
|
||||
vi.mocked(api.getServerFeature).mockImplementation(
|
||||
(path, defaultValue) => {
|
||||
if (path === ServerFeatureFlag.TEAM_WORKSPACES_ENABLED) return true
|
||||
if (path === ServerFeatureFlag.CONSOLIDATED_BILLING_ENABLED)
|
||||
return true
|
||||
return defaultValue
|
||||
}
|
||||
)
|
||||
|
||||
const { flags } = useFeatureFlags()
|
||||
expect(flags.teamWorkspacesEnabled).toBe(true)
|
||||
expect(flags.consolidatedBillingEnabled).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
describe('signupTurnstileMode', () => {
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
import { computed, reactive, readonly } from 'vue'
|
||||
import type { Ref } from 'vue'
|
||||
|
||||
import { isCloud, isNightly } from '@/platform/distribution/types'
|
||||
import {
|
||||
cachedConsolidatedBillingEnabled,
|
||||
cachedTeamWorkspacesEnabled,
|
||||
isAuthenticatedConfigLoaded,
|
||||
remoteConfig
|
||||
@@ -32,7 +30,6 @@ export enum ServerFeatureFlag {
|
||||
COMFYHUB_PROFILE_GATE_ENABLED = 'comfyhub_profile_gate_enabled',
|
||||
SHOW_SIGNIN_BUTTON = 'show_signin_button',
|
||||
UNIFIED_CLOUD_AUTH = 'unified_cloud_auth',
|
||||
CONSOLIDATED_BILLING_ENABLED = 'consolidated_billing_enabled',
|
||||
SIGNUP_TURNSTILE = 'signup_turnstile'
|
||||
}
|
||||
|
||||
@@ -49,26 +46,6 @@ function resolveFlag<T>(
|
||||
return remoteConfigValue ?? api.getServerFeature(flagKey, defaultValue)
|
||||
}
|
||||
|
||||
/**
|
||||
* Resolves a per-user, Cloud-only flag that selects backend behavior. Off the
|
||||
* Cloud build it is always false; during the auth window it falls back to the
|
||||
* cached session value so anonymous bootstrap config cannot route the user to
|
||||
* the wrong backend before authenticated config confirms the flag.
|
||||
*/
|
||||
function resolveAuthGatedFlag(
|
||||
flagKey: string,
|
||||
remoteConfigValue: boolean | undefined,
|
||||
cachedValue: Ref<boolean | undefined>
|
||||
): boolean {
|
||||
const override = getDevOverride<boolean>(flagKey)
|
||||
if (override !== undefined) return override
|
||||
|
||||
if (!isCloud) return false
|
||||
if (!isAuthenticatedConfigLoaded.value) return cachedValue.value ?? false
|
||||
|
||||
return remoteConfigValue ?? api.getServerFeature(flagKey, false)
|
||||
}
|
||||
|
||||
/**
|
||||
* Composable for reactive access to server-side feature flags
|
||||
*/
|
||||
@@ -127,10 +104,18 @@ export function useFeatureFlags() {
|
||||
* and prevents race conditions during initialization.
|
||||
*/
|
||||
get teamWorkspacesEnabled() {
|
||||
return resolveAuthGatedFlag(
|
||||
ServerFeatureFlag.TEAM_WORKSPACES_ENABLED,
|
||||
remoteConfig.value.team_workspaces_enabled,
|
||||
cachedTeamWorkspacesEnabled
|
||||
const override = getDevOverride<boolean>(
|
||||
ServerFeatureFlag.TEAM_WORKSPACES_ENABLED
|
||||
)
|
||||
if (override !== undefined) return override
|
||||
|
||||
if (!isCloud) return false
|
||||
if (!isAuthenticatedConfigLoaded.value)
|
||||
return cachedTeamWorkspacesEnabled.value ?? false
|
||||
|
||||
return (
|
||||
remoteConfig.value.team_workspaces_enabled ??
|
||||
api.getServerFeature(ServerFeatureFlag.TEAM_WORKSPACES_ENABLED, false)
|
||||
)
|
||||
},
|
||||
get userSecretsEnabled() {
|
||||
@@ -190,18 +175,6 @@ export function useFeatureFlags() {
|
||||
false
|
||||
)
|
||||
},
|
||||
/**
|
||||
* Whether personal workspaces use the consolidated (workspace-scoped)
|
||||
* billing flow. While false (default), personal workspaces stay on the
|
||||
* legacy per-user billing flow; team workspaces are unaffected.
|
||||
*/
|
||||
get consolidatedBillingEnabled() {
|
||||
return resolveAuthGatedFlag(
|
||||
ServerFeatureFlag.CONSOLIDATED_BILLING_ENABLED,
|
||||
remoteConfig.value.consolidated_billing_enabled,
|
||||
cachedConsolidatedBillingEnabled
|
||||
)
|
||||
},
|
||||
get signupTurnstileMode() {
|
||||
return resolveFlag(
|
||||
ServerFeatureFlag.SIGNUP_TURNSTILE,
|
||||
|
||||
@@ -2484,8 +2484,6 @@
|
||||
"model": "Model",
|
||||
"added": "Added",
|
||||
"accountInitialized": "Account initialized",
|
||||
"loadEventsError": "Failed to load activity. Please try again.",
|
||||
"loadEventsUnknownError": "Something went wrong while loading activity. Please refresh and try again.",
|
||||
"eventTypes": {
|
||||
"creditAdded": "Credits Added",
|
||||
"accountCreated": "Account Created",
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
</div>
|
||||
|
||||
<!-- Workspace mode: workspace-aware subscription content (renders its own footer) -->
|
||||
<SubscriptionPanelContentWorkspace v-if="shouldUseWorkspaceBilling" />
|
||||
<SubscriptionPanelContentWorkspace v-if="teamWorkspacesEnabled" />
|
||||
<!-- Legacy mode: user-level subscription content -->
|
||||
<template v-else>
|
||||
<SubscriptionPanelContentLegacy />
|
||||
@@ -29,20 +29,24 @@
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { defineAsyncComponent } from 'vue'
|
||||
import { computed, defineAsyncComponent } from 'vue'
|
||||
|
||||
import CloudBadge from '@/components/topbar/CloudBadge.vue'
|
||||
import { useBillingContext } from '@/composables/billing/useBillingContext'
|
||||
import { useBillingRouting } from '@/composables/billing/useBillingRouting'
|
||||
import { useFeatureFlags } from '@/composables/useFeatureFlags'
|
||||
import SubscriptionFooterLinks from '@/platform/cloud/subscription/components/SubscriptionFooterLinks.vue'
|
||||
import SubscriptionPanelContentLegacy from '@/platform/cloud/subscription/components/SubscriptionPanelContentLegacy.vue'
|
||||
import { isCloud } from '@/platform/distribution/types'
|
||||
|
||||
const SubscriptionPanelContentWorkspace = defineAsyncComponent(
|
||||
() =>
|
||||
import('@/platform/workspace/components/SubscriptionPanelContentWorkspace.vue')
|
||||
)
|
||||
|
||||
const { shouldUseWorkspaceBilling } = useBillingRouting()
|
||||
const { flags } = useFeatureFlags()
|
||||
const teamWorkspacesEnabled = computed(
|
||||
() => isCloud && flags.teamWorkspacesEnabled
|
||||
)
|
||||
|
||||
const { isActiveSubscription } = useBillingContext()
|
||||
</script>
|
||||
|
||||
@@ -1,136 +0,0 @@
|
||||
import { render, screen } from '@testing-library/vue'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { computed, ref } from 'vue'
|
||||
import { createI18n } from 'vue-i18n'
|
||||
|
||||
import SubscriptionPanelContentLegacy from './SubscriptionPanelContentLegacy.vue'
|
||||
|
||||
const mockAccessBillingPortal = vi.fn()
|
||||
const mockTrackSubscriptionCancellation = vi.fn()
|
||||
const mockShowSubscriptionDialog = vi.fn()
|
||||
const mockHandleRefresh = vi.fn()
|
||||
|
||||
const mockIsActiveSubscription = ref(true)
|
||||
const mockIsCancelled = ref(false)
|
||||
const mockIsFreeTier = ref(false)
|
||||
const mockSubscriptionTier = ref<'STANDARD' | 'CREATOR' | 'PRO' | null>(
|
||||
'STANDARD'
|
||||
)
|
||||
const mockIsYearlySubscription = ref(true)
|
||||
|
||||
vi.mock('@/composables/auth/useAuthActions', () => ({
|
||||
useAuthActions: () => ({
|
||||
accessBillingPortal: mockAccessBillingPortal
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/telemetry', () => ({
|
||||
useTelemetry: () => ({
|
||||
trackSubscriptionCancellation: mockTrackSubscriptionCancellation
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/cloud/subscription/composables/useSubscription', () => ({
|
||||
useSubscription: () => ({
|
||||
isActiveSubscription: computed(() => mockIsActiveSubscription.value),
|
||||
isCancelled: computed(() => mockIsCancelled.value),
|
||||
isFreeTier: computed(() => mockIsFreeTier.value),
|
||||
formattedRenewalDate: computed(() => '2026-08-01'),
|
||||
formattedEndDate: computed(() => '2026-08-01'),
|
||||
subscriptionTier: computed(() => mockSubscriptionTier.value),
|
||||
subscriptionTierName: computed(() => 'Standard'),
|
||||
isYearlySubscription: computed(() => mockIsYearlySubscription.value)
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock(
|
||||
'@/platform/cloud/subscription/composables/useSubscriptionActions',
|
||||
() => ({
|
||||
useSubscriptionActions: () => ({
|
||||
handleRefresh: mockHandleRefresh
|
||||
})
|
||||
})
|
||||
)
|
||||
|
||||
vi.mock(
|
||||
'@/platform/cloud/subscription/composables/useSubscriptionDialog',
|
||||
() => ({
|
||||
useSubscriptionDialog: () => ({
|
||||
show: mockShowSubscriptionDialog
|
||||
})
|
||||
})
|
||||
)
|
||||
|
||||
const i18n = createI18n({
|
||||
legacy: false,
|
||||
locale: 'en',
|
||||
messages: {
|
||||
en: {
|
||||
subscription: {
|
||||
perMonth: '/ month',
|
||||
manageSubscription: 'Manage subscription',
|
||||
upgradePlan: 'Upgrade plan',
|
||||
subscribeNow: 'Subscribe now',
|
||||
yourPlanIncludes: 'Your plan includes',
|
||||
viewMoreDetailsPlans: 'View more details',
|
||||
renewsDate: 'Renews {date}',
|
||||
expiresDate: 'Expires {date}',
|
||||
monthlyCreditsLabel: 'monthly credits',
|
||||
maxDurationLabel: 'max duration',
|
||||
gpuLabel: 'GPU access',
|
||||
addCreditsLabel: 'Add credits',
|
||||
customLoRAsLabel: 'Custom LoRAs',
|
||||
maxDuration: {
|
||||
standard: '30 min'
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
function renderComponent() {
|
||||
return render(SubscriptionPanelContentLegacy, {
|
||||
global: {
|
||||
plugins: [i18n],
|
||||
stubs: {
|
||||
CreditsTile: true,
|
||||
SubscribeButton: true,
|
||||
Button: {
|
||||
template: '<button @click="$emit(\'click\')"><slot /></button>',
|
||||
emits: ['click']
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
describe('SubscriptionPanelContentLegacy', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockAccessBillingPortal.mockResolvedValue(undefined)
|
||||
mockIsActiveSubscription.value = true
|
||||
mockIsCancelled.value = false
|
||||
mockIsFreeTier.value = false
|
||||
mockSubscriptionTier.value = 'STANDARD'
|
||||
mockIsYearlySubscription.value = true
|
||||
})
|
||||
|
||||
it('tracks cancel intent before opening the billing portal', async () => {
|
||||
renderComponent()
|
||||
|
||||
await userEvent.click(
|
||||
screen.getByRole('button', { name: /manage subscription/i })
|
||||
)
|
||||
|
||||
expect(mockTrackSubscriptionCancellation).toHaveBeenCalledExactlyOnceWith(
|
||||
'flow_opened',
|
||||
{
|
||||
source: 'manage_subscription_button',
|
||||
current_tier: 'standard',
|
||||
cycle: 'yearly'
|
||||
}
|
||||
)
|
||||
expect(mockAccessBillingPortal).toHaveBeenCalledOnce()
|
||||
})
|
||||
})
|
||||
@@ -36,7 +36,11 @@
|
||||
v-if="isActiveSubscription && !isFreeTier"
|
||||
variant="secondary"
|
||||
class="ml-auto rounded-lg bg-interface-menu-component-surface-selected px-4 py-2 text-sm font-normal text-text-primary"
|
||||
@click="handleManageSubscription"
|
||||
@click="
|
||||
async () => {
|
||||
await authActions.accessBillingPortal()
|
||||
}
|
||||
"
|
||||
>
|
||||
{{ $t('subscription.manageSubscription') }}
|
||||
</Button>
|
||||
@@ -121,7 +125,6 @@ import { useAuthActions } from '@/composables/auth/useAuthActions'
|
||||
import CreditsTile from '@/platform/cloud/subscription/components/CreditsTile.vue'
|
||||
import SubscribeButton from '@/platform/cloud/subscription/components/SubscribeButton.vue'
|
||||
import { useSubscription } from '@/platform/cloud/subscription/composables/useSubscription'
|
||||
import { useTelemetry } from '@/platform/telemetry'
|
||||
import { useSubscriptionActions } from '@/platform/cloud/subscription/composables/useSubscriptionActions'
|
||||
import { useSubscriptionDialog } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import {
|
||||
@@ -157,18 +160,6 @@ const tierPrice = computed(() =>
|
||||
getTierPrice(tierKey.value, isYearlySubscription.value)
|
||||
)
|
||||
|
||||
// The portal is the only place a legacy user can cancel (in-app UI already
|
||||
// covers plan changes), so this click is the closest observable cancel-intent
|
||||
// signal on the mainline path.
|
||||
async function handleManageSubscription() {
|
||||
useTelemetry()?.trackSubscriptionCancellation('flow_opened', {
|
||||
source: 'manage_subscription_button',
|
||||
current_tier: subscriptionTier.value?.toLowerCase(),
|
||||
cycle: isYearlySubscription.value ? 'yearly' : 'monthly'
|
||||
})
|
||||
await authActions.accessBillingPortal()
|
||||
}
|
||||
|
||||
const tierBenefits = computed((): TierBenefit[] =>
|
||||
getCommonTierBenefits(tierKey.value, t, n)
|
||||
)
|
||||
|
||||
@@ -9,7 +9,7 @@ const mockTrackSubscription = vi.hoisted(() => vi.fn())
|
||||
const mockIsInPersonalWorkspace = vi.hoisted(() => ({ value: true }))
|
||||
const mockIsFreeTier = vi.hoisted(() => ({ value: false }))
|
||||
const mockTier = vi.hoisted(() => ({ value: 'FREE' as string | null }))
|
||||
const mockShouldUseWorkspaceBilling = vi.hoisted(() => ({ value: false }))
|
||||
const mockTeamWorkspacesEnabled = vi.hoisted(() => ({ value: false }))
|
||||
const mockIsCloud = vi.hoisted(() => ({ value: true }))
|
||||
const mockIsLegacyTeamPlan = vi.hoisted(() => ({ value: false }))
|
||||
const mockCanManageSubscription = vi.hoisted(() => ({ value: true }))
|
||||
@@ -35,10 +35,12 @@ vi.mock('@/services/dialogService', () => ({
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('@/composables/billing/useBillingRouting', () => ({
|
||||
useBillingRouting: () => ({
|
||||
get shouldUseWorkspaceBilling() {
|
||||
return mockShouldUseWorkspaceBilling
|
||||
vi.mock('@/composables/useFeatureFlags', () => ({
|
||||
useFeatureFlags: () => ({
|
||||
flags: {
|
||||
get teamWorkspacesEnabled() {
|
||||
return mockTeamWorkspacesEnabled.value
|
||||
}
|
||||
}
|
||||
})
|
||||
}))
|
||||
@@ -86,7 +88,7 @@ describe('useSubscriptionDialog', () => {
|
||||
mockIsInPersonalWorkspace.value = true
|
||||
mockIsFreeTier.value = false
|
||||
mockTier.value = 'FREE'
|
||||
mockShouldUseWorkspaceBilling.value = false
|
||||
mockTeamWorkspacesEnabled.value = false
|
||||
mockIsLegacyTeamPlan.value = false
|
||||
mockCanManageSubscription.value = true
|
||||
|
||||
@@ -117,7 +119,7 @@ describe('useSubscriptionDialog', () => {
|
||||
})
|
||||
|
||||
it('does not wire onChooseTeam on the unified table (personal subscribes directly)', () => {
|
||||
mockShouldUseWorkspaceBilling.value = true
|
||||
mockTeamWorkspacesEnabled.value = true
|
||||
mockIsInPersonalWorkspace.value = true
|
||||
const { showPricingTable } = useSubscriptionDialog()
|
||||
|
||||
@@ -129,7 +131,7 @@ describe('useSubscriptionDialog', () => {
|
||||
})
|
||||
|
||||
it('sizes the unified pricing dialog via the Reka contentClass, not the ignored PrimeVue style', () => {
|
||||
mockShouldUseWorkspaceBilling.value = true
|
||||
mockTeamWorkspacesEnabled.value = true
|
||||
mockIsInPersonalWorkspace.value = true
|
||||
const { showPricingTable } = useSubscriptionDialog()
|
||||
|
||||
@@ -144,7 +146,7 @@ describe('useSubscriptionDialog', () => {
|
||||
})
|
||||
|
||||
it('defaults to the personal tab in a personal workspace', () => {
|
||||
mockShouldUseWorkspaceBilling.value = true
|
||||
mockTeamWorkspacesEnabled.value = true
|
||||
mockIsInPersonalWorkspace.value = true
|
||||
const { showPricingTable } = useSubscriptionDialog()
|
||||
|
||||
@@ -155,7 +157,7 @@ describe('useSubscriptionDialog', () => {
|
||||
})
|
||||
|
||||
it('opens the team tab when planMode is forced from a personal workspace', () => {
|
||||
mockShouldUseWorkspaceBilling.value = true
|
||||
mockTeamWorkspacesEnabled.value = true
|
||||
mockIsInPersonalWorkspace.value = true
|
||||
const { showPricingTable } = useSubscriptionDialog()
|
||||
|
||||
@@ -165,9 +167,8 @@ describe('useSubscriptionDialog', () => {
|
||||
expect(props.initialPlanMode).toBe('team')
|
||||
})
|
||||
|
||||
it('uses the legacy table (with onChooseTeam) on the legacy billing flow', () => {
|
||||
mockShouldUseWorkspaceBilling.value = false
|
||||
mockIsInPersonalWorkspace.value = true
|
||||
it('uses the legacy table (with onChooseTeam) when team workspaces are disabled', () => {
|
||||
mockTeamWorkspacesEnabled.value = false
|
||||
const { showPricingTable } = useSubscriptionDialog()
|
||||
|
||||
showPricingTable()
|
||||
@@ -177,7 +178,7 @@ describe('useSubscriptionDialog', () => {
|
||||
})
|
||||
|
||||
it('routes an existing per-member (legacy) team subscriber to the old team table', () => {
|
||||
mockShouldUseWorkspaceBilling.value = true
|
||||
mockTeamWorkspacesEnabled.value = true
|
||||
mockIsInPersonalWorkspace.value = false
|
||||
mockIsLegacyTeamPlan.value = true
|
||||
const { showPricingTable } = useSubscriptionDialog()
|
||||
@@ -195,7 +196,7 @@ describe('useSubscriptionDialog', () => {
|
||||
})
|
||||
|
||||
it('keeps a non-legacy (credit-slider) team subscriber on the unified table', () => {
|
||||
mockShouldUseWorkspaceBilling.value = true
|
||||
mockTeamWorkspacesEnabled.value = true
|
||||
mockIsInPersonalWorkspace.value = false
|
||||
mockIsLegacyTeamPlan.value = false
|
||||
const { showPricingTable } = useSubscriptionDialog()
|
||||
@@ -219,7 +220,7 @@ describe('useSubscriptionDialog', () => {
|
||||
})
|
||||
|
||||
it('tracks modal_opened on the workspace (unified) path too', () => {
|
||||
mockShouldUseWorkspaceBilling.value = true
|
||||
mockTeamWorkspacesEnabled.value = true
|
||||
const { showPricingTable } = useSubscriptionDialog()
|
||||
|
||||
showPricingTable({ reason: 'subscribe_to_run' })
|
||||
@@ -231,7 +232,7 @@ describe('useSubscriptionDialog', () => {
|
||||
})
|
||||
|
||||
it('does not track modal_opened for the inactive member dialog', () => {
|
||||
mockShouldUseWorkspaceBilling.value = true
|
||||
mockTeamWorkspacesEnabled.value = true
|
||||
mockIsInPersonalWorkspace.value = false
|
||||
mockCanManageSubscription.value = false
|
||||
const { showPricingTable } = useSubscriptionDialog()
|
||||
|
||||
@@ -2,7 +2,7 @@ import { defineAsyncComponent } from 'vue'
|
||||
import { useDialogService } from '@/services/dialogService'
|
||||
import { useDialogStore } from '@/stores/dialogStore'
|
||||
import { useBillingContext } from '@/composables/billing/useBillingContext'
|
||||
import { useBillingRouting } from '@/composables/billing/useBillingRouting'
|
||||
import { useFeatureFlags } from '@/composables/useFeatureFlags'
|
||||
import { isCloud } from '@/platform/distribution/types'
|
||||
import { useTelemetry } from '@/platform/telemetry'
|
||||
import type { PaymentIntentSource } from '@/platform/telemetry/types'
|
||||
@@ -24,7 +24,7 @@ export interface SubscriptionDialogOptions {
|
||||
}
|
||||
|
||||
export const useSubscriptionDialog = () => {
|
||||
const { shouldUseWorkspaceBilling } = useBillingRouting()
|
||||
const { flags } = useFeatureFlags()
|
||||
const dialogService = useDialogService()
|
||||
const dialogStore = useDialogStore()
|
||||
const workspaceStore = useTeamWorkspaceStore()
|
||||
@@ -57,7 +57,7 @@ export const useSubscriptionDialog = () => {
|
||||
// small read-only "ask your owner to reactivate" modal instead of the
|
||||
// pricing table. Out-of-credits still routes everyone to the credits flow.
|
||||
if (
|
||||
shouldUseWorkspaceBilling.value &&
|
||||
flags.teamWorkspacesEnabled &&
|
||||
!workspaceStore.isInPersonalWorkspace &&
|
||||
!permissions.value.canManageSubscription &&
|
||||
options?.reason !== 'out_of_credits'
|
||||
@@ -95,10 +95,9 @@ export const useSubscriptionDialog = () => {
|
||||
}
|
||||
|
||||
// Jun-5 model: a single unified pricing table (personal/team plan toggle on
|
||||
// one workspace) for workspaces on the consolidated billing flow. Replaces
|
||||
// the old personal-vs-team workspace fork. Personal workspaces still on the
|
||||
// legacy flow (consolidated billing disabled) get the legacy table.
|
||||
if (shouldUseWorkspaceBilling.value) {
|
||||
// one workspace) when team workspaces are enabled. Replaces the old
|
||||
// personal-vs-team workspace fork. Flag-off keeps the legacy table.
|
||||
if (flags.teamWorkspacesEnabled) {
|
||||
// Existing per-member (legacy) team subscribers keep the old tier-based
|
||||
// team table; the unified credit-slider table is for everyone else.
|
||||
// Resolved lazily (not at composable setup): these three composables form
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import {
|
||||
cachedConsolidatedBillingEnabled,
|
||||
cachedTeamWorkspacesEnabled,
|
||||
remoteConfig,
|
||||
remoteConfigState
|
||||
@@ -56,14 +55,10 @@ export async function refreshRemoteConfig(
|
||||
window.__CONFIG__ = config
|
||||
remoteConfig.value = config
|
||||
remoteConfigState.value = useAuth ? 'authenticated' : 'anonymous'
|
||||
if (useAuth) {
|
||||
if (useAuth)
|
||||
cachedTeamWorkspacesEnabled.value = Boolean(
|
||||
config.team_workspaces_enabled
|
||||
)
|
||||
cachedConsolidatedBillingEnabled.value = Boolean(
|
||||
config.consolidated_billing_enabled
|
||||
)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -59,8 +59,3 @@ export const cachedTeamWorkspacesEnabled = useStorage<boolean | undefined>(
|
||||
'team_workspaces_enabled' satisfies `${ServerFeatureFlag.TEAM_WORKSPACES_ENABLED}`,
|
||||
undefined
|
||||
)
|
||||
|
||||
export const cachedConsolidatedBillingEnabled = useStorage<boolean | undefined>(
|
||||
'consolidated_billing_enabled' satisfies `${ServerFeatureFlag.CONSOLIDATED_BILLING_ENABLED}`,
|
||||
undefined
|
||||
)
|
||||
|
||||
@@ -111,7 +111,6 @@ export type RemoteConfig = {
|
||||
comfyhub_upload_enabled?: boolean
|
||||
comfyhub_profile_gate_enabled?: boolean
|
||||
unified_cloud_auth?: boolean
|
||||
consolidated_billing_enabled?: boolean
|
||||
sentry_dsn?: string
|
||||
turnstile_sitekey?: string
|
||||
// Raw, unvalidated wire value (a server typo like 'enfroce' is possible).
|
||||
|
||||
@@ -11,49 +11,21 @@ import type { SettingTreeNode } from '@/platform/settings/settingStore'
|
||||
|
||||
import { useSettingUI } from './useSettingUI'
|
||||
|
||||
const env = vi.hoisted(() => {
|
||||
const state = {
|
||||
isCloud: false,
|
||||
isDesktop: false,
|
||||
isLoggedIn: false,
|
||||
teamWorkspacesEnabled: false,
|
||||
userSecretsEnabled: false,
|
||||
isActiveSubscription: false,
|
||||
billingType: 'legacy' as 'legacy' | 'workspace'
|
||||
}
|
||||
const fakeRef = <K extends keyof typeof state>(key: K) => ({
|
||||
get value() {
|
||||
return state[key]
|
||||
}
|
||||
})
|
||||
return { state, fakeRef }
|
||||
})
|
||||
|
||||
vi.mock('vue-i18n', () => ({
|
||||
useI18n: () => ({ t: (_: string, fallback: string) => fallback })
|
||||
}))
|
||||
|
||||
vi.mock('@/composables/auth/useCurrentUser', () => ({
|
||||
useCurrentUser: () => ({ isLoggedIn: env.fakeRef('isLoggedIn') })
|
||||
useCurrentUser: () => ({ isLoggedIn: ref(false) })
|
||||
}))
|
||||
|
||||
vi.mock('@/composables/billing/useBillingContext', () => ({
|
||||
useBillingContext: () => ({
|
||||
isActiveSubscription: env.fakeRef('isActiveSubscription'),
|
||||
type: env.fakeRef('billingType')
|
||||
})
|
||||
useBillingContext: () => ({ isActiveSubscription: ref(false) })
|
||||
}))
|
||||
|
||||
vi.mock('@/composables/useFeatureFlags', () => ({
|
||||
useFeatureFlags: () => ({
|
||||
flags: {
|
||||
get teamWorkspacesEnabled() {
|
||||
return env.state.teamWorkspacesEnabled
|
||||
},
|
||||
get userSecretsEnabled() {
|
||||
return env.state.userSecretsEnabled
|
||||
}
|
||||
}
|
||||
flags: { teamWorkspacesEnabled: false, userSecretsEnabled: false }
|
||||
})
|
||||
}))
|
||||
|
||||
@@ -62,12 +34,8 @@ vi.mock('@/composables/useVueFeatureFlags', () => ({
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/distribution/types', () => ({
|
||||
get isCloud() {
|
||||
return env.state.isCloud
|
||||
},
|
||||
get isDesktop() {
|
||||
return env.state.isDesktop
|
||||
}
|
||||
isCloud: false,
|
||||
isDesktop: false
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/settings/settingStore', () => ({
|
||||
@@ -109,16 +77,6 @@ describe('useSettingUI', () => {
|
||||
setActivePinia(createTestingPinia())
|
||||
vi.clearAllMocks()
|
||||
|
||||
Object.assign(env.state, {
|
||||
isCloud: false,
|
||||
isDesktop: false,
|
||||
isLoggedIn: false,
|
||||
teamWorkspacesEnabled: false,
|
||||
userSecretsEnabled: false,
|
||||
isActiveSubscription: false,
|
||||
billingType: 'legacy'
|
||||
})
|
||||
|
||||
vi.mocked(useSettingStore).mockReturnValue({
|
||||
settingsById: mockSettings
|
||||
} as ReturnType<typeof useSettingStore>)
|
||||
@@ -179,59 +137,4 @@ describe('useSettingUI', () => {
|
||||
const { defaultCategory } = useSettingUI('about', 'Comfy.Locale')
|
||||
expect(defaultCategory.value.key).toBe('about')
|
||||
})
|
||||
|
||||
describe('legacy billing in the workspace layout', () => {
|
||||
const navKeys = (groups: { items: { id: string }[] }[]) =>
|
||||
groups.flatMap((group) => group.items.map((item) => item.id))
|
||||
|
||||
beforeEach(() => {
|
||||
Object.assign(env.state, {
|
||||
isCloud: true,
|
||||
isLoggedIn: true,
|
||||
teamWorkspacesEnabled: true,
|
||||
isActiveSubscription: true
|
||||
})
|
||||
window.__CONFIG__ = {
|
||||
subscription_required: true
|
||||
} as typeof window.__CONFIG__
|
||||
})
|
||||
|
||||
it('exposes the legacy plan panel when billing is legacy', () => {
|
||||
env.state.billingType = 'legacy'
|
||||
const { defaultCategory, navGroups } = useSettingUI('subscription')
|
||||
|
||||
expect(defaultCategory.value.key).toBe('subscription')
|
||||
expect(navKeys(navGroups.value)).toContain('subscription')
|
||||
expect(navKeys(navGroups.value)).toContain('workspace')
|
||||
})
|
||||
|
||||
it('hides the legacy plan panel when billing is workspace', () => {
|
||||
env.state.billingType = 'workspace'
|
||||
const { navGroups } = useSettingUI()
|
||||
|
||||
expect(navKeys(navGroups.value)).not.toContain('subscription')
|
||||
expect(navKeys(navGroups.value)).toContain('workspace')
|
||||
})
|
||||
|
||||
it('never renders the plan panel in more than one tab', () => {
|
||||
const countSubscription = () => {
|
||||
const { navGroups } = useSettingUI()
|
||||
return navKeys(navGroups.value).filter((id) => id === 'subscription')
|
||||
.length
|
||||
}
|
||||
|
||||
for (const teamWorkspacesEnabled of [true, false]) {
|
||||
for (const billingType of ['legacy', 'workspace'] as const) {
|
||||
for (const isLoggedIn of [true, false]) {
|
||||
Object.assign(env.state, {
|
||||
teamWorkspacesEnabled,
|
||||
billingType,
|
||||
isLoggedIn
|
||||
})
|
||||
expect(countSubscription()).toBeLessThanOrEqual(1)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -53,7 +53,7 @@ export function useSettingUI(
|
||||
|
||||
const { flags } = useFeatureFlags()
|
||||
const { shouldRenderVueNodes } = useVueFeatureFlags()
|
||||
const { isActiveSubscription, type: billingType } = useBillingContext()
|
||||
const { isActiveSubscription } = useBillingContext()
|
||||
|
||||
const teamWorkspacesEnabled = computed(
|
||||
() => isCloud && flags.teamWorkspacesEnabled
|
||||
@@ -157,13 +157,6 @@ export function useSettingUI(
|
||||
return isActiveSubscription.value
|
||||
})
|
||||
|
||||
const shouldShowLegacyPlanCreditsPanel = computed(
|
||||
() =>
|
||||
isLoggedIn.value &&
|
||||
billingType.value === 'legacy' &&
|
||||
shouldShowPlanCreditsPanel.value
|
||||
)
|
||||
|
||||
const userPanel: SettingPanelItem = {
|
||||
node: {
|
||||
key: 'user',
|
||||
@@ -308,9 +301,6 @@ export function useSettingUI(
|
||||
label: 'General',
|
||||
children: [
|
||||
translateCategory(userPanel.node),
|
||||
...(shouldShowLegacyPlanCreditsPanel.value && subscriptionPanel
|
||||
? [translateCategory(subscriptionPanel.node)]
|
||||
: []),
|
||||
...coreSettingCategories.value.slice(0, 1).map(translateCategory),
|
||||
...(shouldShowSecretsPanel.value
|
||||
? [translateCategory(secretsPanel.node)]
|
||||
@@ -342,7 +332,9 @@ export function useSettingUI(
|
||||
label: 'Account',
|
||||
children: [
|
||||
userPanel.node,
|
||||
...(shouldShowLegacyPlanCreditsPanel.value && subscriptionPanel
|
||||
...(isLoggedIn.value &&
|
||||
shouldShowPlanCreditsPanel.value &&
|
||||
subscriptionPanel
|
||||
? [subscriptionPanel.node]
|
||||
: []),
|
||||
...(shouldShowSecretsPanel.value ? [secretsPanel.node] : []),
|
||||
|
||||
@@ -78,43 +78,4 @@ describe('TelemetryRegistry', () => {
|
||||
})
|
||||
).not.toThrow()
|
||||
})
|
||||
|
||||
it('dispatches subscription cancellation telemetry to every registered provider', () => {
|
||||
const a: TelemetryProvider = { trackSubscriptionCancellation: vi.fn() }
|
||||
const b: TelemetryProvider = { trackSubscriptionCancellation: vi.fn() }
|
||||
const registry = new TelemetryRegistry()
|
||||
registry.registerProvider(a)
|
||||
registry.registerProvider(b)
|
||||
|
||||
const payload = {
|
||||
source: 'cancel_plan_menu' as const,
|
||||
current_tier: 'standard',
|
||||
cycle: 'monthly' as const,
|
||||
end_date: '2026-08-01T00:00:00.000Z'
|
||||
}
|
||||
registry.trackSubscriptionCancellation('flow_opened', payload)
|
||||
|
||||
expect(a.trackSubscriptionCancellation).toHaveBeenCalledExactlyOnceWith(
|
||||
'flow_opened',
|
||||
payload
|
||||
)
|
||||
expect(b.trackSubscriptionCancellation).toHaveBeenCalledExactlyOnceWith(
|
||||
'flow_opened',
|
||||
payload
|
||||
)
|
||||
})
|
||||
|
||||
it('dispatches resubscribe click telemetry to every registered provider', () => {
|
||||
const a: TelemetryProvider = { trackResubscribeClicked: vi.fn() }
|
||||
const b: TelemetryProvider = { trackResubscribeClicked: vi.fn() }
|
||||
const registry = new TelemetryRegistry()
|
||||
registry.registerProvider(a)
|
||||
registry.registerProvider(b)
|
||||
|
||||
const payload = { source: 'settings_billing_panel' as const }
|
||||
registry.trackResubscribeClicked(payload)
|
||||
|
||||
expect(a.trackResubscribeClicked).toHaveBeenCalledExactlyOnceWith(payload)
|
||||
expect(b.trackResubscribeClicked).toHaveBeenCalledExactlyOnceWith(payload)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -19,12 +19,10 @@ import type {
|
||||
SearchQueryMetadata,
|
||||
PageViewMetadata,
|
||||
PageVisibilityMetadata,
|
||||
ResubscribeClickMetadata,
|
||||
RunButtonProperties,
|
||||
SettingChangedMetadata,
|
||||
SharedWorkflowRunMetadata,
|
||||
ShellLayoutMetadata,
|
||||
SubscriptionCancellationMetadata,
|
||||
SubscriptionMetadata,
|
||||
SubscriptionSuccessMetadata,
|
||||
SurveyResponses,
|
||||
@@ -102,19 +100,6 @@ export class TelemetryRegistry implements TelemetryDispatcher {
|
||||
this.dispatch((provider) => provider.trackMonthlySubscriptionCancelled?.())
|
||||
}
|
||||
|
||||
trackSubscriptionCancellation(
|
||||
event: 'flow_opened' | 'confirmed' | 'abandoned' | 'failed',
|
||||
metadata?: SubscriptionCancellationMetadata
|
||||
): void {
|
||||
this.dispatch((provider) =>
|
||||
provider.trackSubscriptionCancellation?.(event, metadata)
|
||||
)
|
||||
}
|
||||
|
||||
trackResubscribeClicked(metadata: ResubscribeClickMetadata): void {
|
||||
this.dispatch((provider) => provider.trackResubscribeClicked?.(metadata))
|
||||
}
|
||||
|
||||
trackAddApiCreditButtonClicked(metadata?: AddCreditsClickMetadata): void {
|
||||
this.dispatch((provider) =>
|
||||
provider.trackAddApiCreditButtonClicked?.(metadata)
|
||||
|
||||
@@ -313,45 +313,6 @@ describe('PostHogTelemetryProvider', () => {
|
||||
)
|
||||
})
|
||||
|
||||
it.for([
|
||||
['flow_opened', TelemetryEvents.SUBSCRIPTION_CANCEL_FLOW_OPENED, {}],
|
||||
['confirmed', TelemetryEvents.SUBSCRIPTION_CANCEL_CONFIRMED, {}],
|
||||
['abandoned', TelemetryEvents.SUBSCRIPTION_CANCEL_ABANDONED, {}],
|
||||
[
|
||||
'failed',
|
||||
TelemetryEvents.SUBSCRIPTION_CANCEL_FAILED,
|
||||
{ error_message: 'timed out' }
|
||||
]
|
||||
] as const)(
|
||||
'captures %s cancellation stage',
|
||||
async ([stage, event, extra]) => {
|
||||
const provider = createProvider()
|
||||
await vi.dynamicImportSettled()
|
||||
|
||||
provider.trackSubscriptionCancellation(stage, {
|
||||
current_tier: 'standard',
|
||||
...extra
|
||||
})
|
||||
|
||||
expect(hoisted.mockCapture).toHaveBeenCalledWith(event, {
|
||||
current_tier: 'standard',
|
||||
...extra
|
||||
})
|
||||
}
|
||||
)
|
||||
|
||||
it('captures resubscribe clicks with their source', async () => {
|
||||
const provider = createProvider()
|
||||
await vi.dynamicImportSettled()
|
||||
|
||||
provider.trackResubscribeClicked({ source: 'settings_billing_panel' })
|
||||
|
||||
expect(hoisted.mockCapture).toHaveBeenCalledWith(
|
||||
TelemetryEvents.RESUBSCRIBE_BUTTON_CLICKED,
|
||||
{ source: 'settings_billing_panel' }
|
||||
)
|
||||
})
|
||||
|
||||
it('captures begin_checkout with intent metadata', async () => {
|
||||
const provider = createProvider()
|
||||
await vi.dynamicImportSettled()
|
||||
|
||||
@@ -26,12 +26,10 @@ import type {
|
||||
SearchQueryMetadata,
|
||||
PageViewMetadata,
|
||||
PageVisibilityMetadata,
|
||||
ResubscribeClickMetadata,
|
||||
RunButtonProperties,
|
||||
SettingChangedMetadata,
|
||||
SharedWorkflowRunMetadata,
|
||||
ShellLayoutMetadata,
|
||||
SubscriptionCancellationMetadata,
|
||||
SubscriptionMetadata,
|
||||
SubscriptionSuccessMetadata,
|
||||
SurveyResponses,
|
||||
@@ -49,7 +47,7 @@ import type {
|
||||
WorkflowSavedMetadata,
|
||||
WorkspaceInviteMetadata
|
||||
} from '../../types'
|
||||
import { CANCELLATION_STAGE_EVENTS, TelemetryEvents } from '../../types'
|
||||
import { TelemetryEvents } from '../../types'
|
||||
import { normalizeSurveyResponses } from '../../utils/surveyNormalization'
|
||||
|
||||
const DEFAULT_DISABLED_EVENTS = [
|
||||
@@ -372,17 +370,6 @@ export class PostHogTelemetryProvider implements TelemetryProvider {
|
||||
this.trackEvent(TelemetryEvents.MONTHLY_SUBSCRIPTION_CANCELLED)
|
||||
}
|
||||
|
||||
trackSubscriptionCancellation(
|
||||
event: 'flow_opened' | 'confirmed' | 'abandoned' | 'failed',
|
||||
metadata?: SubscriptionCancellationMetadata
|
||||
): void {
|
||||
this.trackEvent(CANCELLATION_STAGE_EVENTS[event], metadata)
|
||||
}
|
||||
|
||||
trackResubscribeClicked(metadata: ResubscribeClickMetadata): void {
|
||||
this.trackEvent(TelemetryEvents.RESUBSCRIBE_BUTTON_CLICKED, metadata)
|
||||
}
|
||||
|
||||
trackApiCreditTopupButtonPurchaseClicked(amount: number): void {
|
||||
this.trackEvent(TelemetryEvents.API_CREDIT_TOPUP_BUTTON_PURCHASE_CLICKED, {
|
||||
credit_amount: amount
|
||||
|
||||
@@ -115,36 +115,6 @@ describe('HostTelemetrySink', () => {
|
||||
)
|
||||
})
|
||||
|
||||
it('forwards subscription cancellation telemetry to the host bridge', () => {
|
||||
new HostTelemetrySink().trackSubscriptionCancellation('confirmed', {
|
||||
source: 'cancel_plan_menu',
|
||||
current_tier: 'standard',
|
||||
cycle: 'yearly',
|
||||
end_date: '2026-08-01T00:00:00.000Z'
|
||||
})
|
||||
|
||||
expect(state.capture).toHaveBeenCalledExactlyOnceWith(
|
||||
TelemetryEvents.SUBSCRIPTION_CANCEL_CONFIRMED,
|
||||
{
|
||||
source: 'cancel_plan_menu',
|
||||
current_tier: 'standard',
|
||||
cycle: 'yearly',
|
||||
end_date: '2026-08-01T00:00:00.000Z'
|
||||
}
|
||||
)
|
||||
})
|
||||
|
||||
it('forwards resubscribe click telemetry to the host bridge', () => {
|
||||
new HostTelemetrySink().trackResubscribeClicked({
|
||||
source: 'pricing_dialog'
|
||||
})
|
||||
|
||||
expect(state.capture).toHaveBeenCalledExactlyOnceWith(
|
||||
TelemetryEvents.RESUBSCRIBE_BUTTON_CLICKED,
|
||||
{ source: 'pricing_dialog' }
|
||||
)
|
||||
})
|
||||
|
||||
it('forwards add-credit clicks with their source', () => {
|
||||
new HostTelemetrySink().trackAddApiCreditButtonClicked({
|
||||
source: 'avatar_menu'
|
||||
|
||||
@@ -31,8 +31,6 @@ import type {
|
||||
ShareFlowMetadata,
|
||||
ShareLinkOpenedMetadata,
|
||||
SharedWorkflowRunMetadata,
|
||||
ResubscribeClickMetadata,
|
||||
SubscriptionCancellationMetadata,
|
||||
SubscriptionMetadata,
|
||||
SubscriptionSuccessMetadata,
|
||||
SurveyResponses,
|
||||
@@ -48,7 +46,7 @@ import type {
|
||||
WorkflowImportMetadata,
|
||||
WorkflowSavedMetadata
|
||||
} from '../../types'
|
||||
import { CANCELLATION_STAGE_EVENTS, TelemetryEvents } from '../../types'
|
||||
import { TelemetryEvents } from '../../types'
|
||||
import { normalizeSurveyResponses } from '../../utils/surveyNormalization'
|
||||
|
||||
type HostTelemetryProperties = Parameters<
|
||||
@@ -129,17 +127,6 @@ export class HostTelemetrySink implements TelemetryProvider {
|
||||
this.capture(TelemetryEvents.MONTHLY_SUBSCRIPTION_CANCELLED)
|
||||
}
|
||||
|
||||
trackSubscriptionCancellation(
|
||||
event: 'flow_opened' | 'confirmed' | 'abandoned' | 'failed',
|
||||
metadata?: SubscriptionCancellationMetadata
|
||||
): void {
|
||||
this.capture(CANCELLATION_STAGE_EVENTS[event], metadata)
|
||||
}
|
||||
|
||||
trackResubscribeClicked(metadata: ResubscribeClickMetadata): void {
|
||||
this.capture(TelemetryEvents.RESUBSCRIBE_BUTTON_CLICKED, metadata)
|
||||
}
|
||||
|
||||
trackAddApiCreditButtonClicked(metadata?: AddCreditsClickMetadata): void {
|
||||
this.capture(TelemetryEvents.ADD_API_CREDIT_BUTTON_CLICKED, metadata)
|
||||
}
|
||||
|
||||
@@ -450,27 +450,6 @@ export interface AddCreditsClickMetadata {
|
||||
source: 'credits_panel' | 'avatar_menu' | 'settings_billing_panel'
|
||||
}
|
||||
|
||||
export interface SubscriptionCancellationMetadata {
|
||||
current_tier?: string
|
||||
cycle?: BillingCycle
|
||||
/**
|
||||
* `manage_subscription_button` opens the external billing portal, where
|
||||
* cancellation is one of the few possible actions but not the only one —
|
||||
* treat it as probable, not certain, cancel intent.
|
||||
*/
|
||||
source?: 'cancel_plan_menu' | 'manage_subscription_button'
|
||||
/** ISO date the subscription runs until if the cancel goes through. */
|
||||
end_date?: string
|
||||
/** Present only on the `failed` stage. */
|
||||
error_message?: string
|
||||
}
|
||||
|
||||
export interface ResubscribeClickMetadata {
|
||||
source: 'pricing_dialog' | 'settings_billing_panel'
|
||||
/** Why the pricing dialog was opened, when the click came from one. */
|
||||
payment_intent_source?: PaymentIntentSource
|
||||
}
|
||||
|
||||
export interface BeginCheckoutMetadata
|
||||
extends Record<string, unknown>, CheckoutAttributionMetadata {
|
||||
user_id: string
|
||||
@@ -535,11 +514,6 @@ export interface TelemetryProvider {
|
||||
metadata?: SubscriptionSuccessMetadata
|
||||
): void
|
||||
trackMonthlySubscriptionCancelled?(): void
|
||||
trackSubscriptionCancellation?(
|
||||
event: 'flow_opened' | 'confirmed' | 'abandoned' | 'failed',
|
||||
metadata?: SubscriptionCancellationMetadata
|
||||
): void
|
||||
trackResubscribeClicked?(metadata: ResubscribeClickMetadata): void
|
||||
trackAddApiCreditButtonClicked?(metadata?: AddCreditsClickMetadata): void
|
||||
trackApiCreditTopupButtonPurchaseClicked?(amount: number): void
|
||||
trackApiCreditTopupSucceeded?(): void
|
||||
@@ -643,11 +617,6 @@ export const TelemetryEvents = {
|
||||
SUBSCRIBE_NOW_BUTTON_CLICKED: 'app:subscribe_now_button_clicked',
|
||||
MONTHLY_SUBSCRIPTION_SUCCEEDED: 'app:monthly_subscription_succeeded',
|
||||
MONTHLY_SUBSCRIPTION_CANCELLED: 'app:monthly_subscription_cancelled',
|
||||
SUBSCRIPTION_CANCEL_FLOW_OPENED: 'app:subscription_cancel_flow_opened',
|
||||
SUBSCRIPTION_CANCEL_CONFIRMED: 'app:subscription_cancel_confirmed',
|
||||
SUBSCRIPTION_CANCEL_ABANDONED: 'app:subscription_cancel_abandoned',
|
||||
SUBSCRIPTION_CANCEL_FAILED: 'app:subscription_cancel_failed',
|
||||
RESUBSCRIBE_BUTTON_CLICKED: 'app:resubscribe_button_clicked',
|
||||
ADD_API_CREDIT_BUTTON_CLICKED: 'app:add_api_credit_button_clicked',
|
||||
API_CREDIT_TOPUP_BUTTON_PURCHASE_CLICKED:
|
||||
'app:api_credit_topup_button_purchase_clicked',
|
||||
@@ -722,13 +691,6 @@ export const TelemetryEvents = {
|
||||
export type TelemetryEventName =
|
||||
(typeof TelemetryEvents)[keyof typeof TelemetryEvents]
|
||||
|
||||
export const CANCELLATION_STAGE_EVENTS = {
|
||||
flow_opened: TelemetryEvents.SUBSCRIPTION_CANCEL_FLOW_OPENED,
|
||||
confirmed: TelemetryEvents.SUBSCRIPTION_CANCEL_CONFIRMED,
|
||||
abandoned: TelemetryEvents.SUBSCRIPTION_CANCEL_ABANDONED,
|
||||
failed: TelemetryEvents.SUBSCRIPTION_CANCEL_FAILED
|
||||
} as const
|
||||
|
||||
export type ExecutionTriggerSource =
|
||||
| 'button'
|
||||
| 'keybinding'
|
||||
|
||||
@@ -3,7 +3,6 @@ import { ref } from 'vue'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
|
||||
import { useBillingContext } from '@/composables/billing/useBillingContext'
|
||||
import { useTelemetry } from '@/platform/telemetry'
|
||||
|
||||
/**
|
||||
* Reactivates a cancelled-but-still-active subscription and surfaces success or
|
||||
@@ -17,9 +16,6 @@ export function useResubscribe() {
|
||||
const isResubscribing = ref(false)
|
||||
|
||||
async function handleResubscribe() {
|
||||
useTelemetry()?.trackResubscribeClicked({
|
||||
source: 'settings_billing_panel'
|
||||
})
|
||||
isResubscribing.value = true
|
||||
try {
|
||||
await resubscribe()
|
||||
|
||||
@@ -123,12 +123,9 @@ vi.mock('primevue/usetoast', () => ({
|
||||
useToast: () => ({ add: mockToastAdd })
|
||||
}))
|
||||
|
||||
const mockTrackResubscribeClicked = vi.hoisted(() => vi.fn())
|
||||
|
||||
vi.mock('@/platform/telemetry', () => ({
|
||||
useTelemetry: () => ({
|
||||
trackMonthlySubscriptionSucceeded: vi.fn(),
|
||||
trackResubscribeClicked: mockTrackResubscribeClicked,
|
||||
trackBeginCheckout: mockTrackBeginCheckout
|
||||
})
|
||||
}))
|
||||
@@ -857,7 +854,7 @@ describe('useSubscriptionCheckout', () => {
|
||||
|
||||
describe('handleResubscribe', () => {
|
||||
it('emits close on success', async () => {
|
||||
const checkout = await setup('subscribe_to_run')
|
||||
const checkout = await setup()
|
||||
mockResubscribe.mockResolvedValueOnce({
|
||||
billing_op_id: 'op-4',
|
||||
status: 'active'
|
||||
@@ -869,10 +866,6 @@ describe('useSubscriptionCheckout', () => {
|
||||
|
||||
expect(mockResubscribe).toHaveBeenCalled()
|
||||
expect(emit).toHaveBeenCalledWith('close', true)
|
||||
expect(mockTrackResubscribeClicked).toHaveBeenCalledWith({
|
||||
source: 'pricing_dialog',
|
||||
payment_intent_source: 'subscribe_to_run'
|
||||
})
|
||||
})
|
||||
|
||||
it('shows error toast on failure', async () => {
|
||||
|
||||
@@ -343,10 +343,6 @@ export function useSubscriptionCheckout(
|
||||
}
|
||||
|
||||
async function handleResubscribe() {
|
||||
telemetry?.trackResubscribeClicked({
|
||||
source: 'pricing_dialog',
|
||||
payment_intent_source: paymentIntentSource
|
||||
})
|
||||
isResubscribing.value = true
|
||||
try {
|
||||
await resubscribe()
|
||||
|
||||
57
src/schemas/nodeDefSchema.test.ts
Normal file
57
src/schemas/nodeDefSchema.test.ts
Normal file
@@ -0,0 +1,57 @@
|
||||
import { describe, expect, it } from 'vitest'
|
||||
|
||||
import {
|
||||
getComboSpecComboOptions,
|
||||
getInputSpecType,
|
||||
isComboInputSpec,
|
||||
isComboInputSpecV1,
|
||||
isComboInputSpecV2,
|
||||
isFloatInputSpec,
|
||||
isIntInputSpec,
|
||||
isMediaUploadComboInput
|
||||
} from './nodeDefSchema'
|
||||
import type {
|
||||
ComboInputSpec,
|
||||
ComboInputSpecV2,
|
||||
InputSpec
|
||||
} from './nodeDefSchema'
|
||||
|
||||
describe('node definition schema helpers', () => {
|
||||
it('identifies input spec variants', () => {
|
||||
const intSpec: InputSpec = ['INT', {}]
|
||||
const floatSpec: InputSpec = ['FLOAT', {}]
|
||||
const comboV1: ComboInputSpec = [['a', 'b'], {}]
|
||||
const comboV2: ComboInputSpecV2 = ['COMBO', { options: ['a', 'b'] }]
|
||||
|
||||
expect(isIntInputSpec(intSpec)).toBe(true)
|
||||
expect(isFloatInputSpec(floatSpec)).toBe(true)
|
||||
expect(isComboInputSpecV1(comboV1)).toBe(true)
|
||||
expect(isComboInputSpecV2(comboV2)).toBe(true)
|
||||
expect(isComboInputSpec(comboV1)).toBe(true)
|
||||
expect(isComboInputSpec(comboV2)).toBe(true)
|
||||
expect(getInputSpecType(comboV1)).toBe('COMBO')
|
||||
expect(getInputSpecType(intSpec)).toBe('INT')
|
||||
})
|
||||
|
||||
it('reads combo options from legacy and v2 combo specs', () => {
|
||||
expect(getComboSpecComboOptions([['a', 1], {}])).toEqual(['a', 1])
|
||||
expect(
|
||||
getComboSpecComboOptions(['COMBO', { options: ['x', 'y'] }])
|
||||
).toEqual(['x', 'y'])
|
||||
expect(getComboSpecComboOptions(['COMBO', {}])).toEqual([])
|
||||
})
|
||||
|
||||
it('detects media upload combo inputs', () => {
|
||||
expect(isMediaUploadComboInput([['a'], { image_upload: true }])).toBe(true)
|
||||
expect(
|
||||
isMediaUploadComboInput(['COMBO', { animated_image_upload: true }])
|
||||
).toBe(true)
|
||||
expect(isMediaUploadComboInput(['COMBO', { video_upload: true }])).toBe(
|
||||
true
|
||||
)
|
||||
expect(isMediaUploadComboInput(['STRING', { image_upload: true }])).toBe(
|
||||
false
|
||||
)
|
||||
expect(isMediaUploadComboInput(['COMBO', undefined])).toBe(false)
|
||||
})
|
||||
})
|
||||
149
src/scripts/api.cloud.test.ts
Normal file
149
src/scripts/api.cloud.test.ts
Normal file
@@ -0,0 +1,149 @@
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import {
|
||||
fetchWithUnifiedRemint,
|
||||
shouldRemintCloudRequest
|
||||
} from '@/platform/auth/unified/remintRetry'
|
||||
|
||||
const { mockAuthStore } = vi.hoisted(() => ({
|
||||
mockAuthStore: {
|
||||
isInitialized: true,
|
||||
getAuthHeader: vi.fn(),
|
||||
getAuthToken: vi.fn()
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/distribution/types', () => ({ isCloud: true }))
|
||||
|
||||
vi.mock('@/stores/authStore', () => ({
|
||||
useAuthStore: vi.fn(() => mockAuthStore)
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/auth/unified/remintRetry', () => ({
|
||||
fetchWithUnifiedRemint: vi.fn(),
|
||||
shouldRemintCloudRequest: vi.fn()
|
||||
}))
|
||||
|
||||
class FakeWebSocket extends EventTarget {
|
||||
static instances: FakeWebSocket[] = []
|
||||
|
||||
binaryType = ''
|
||||
sent: string[] = []
|
||||
|
||||
constructor(readonly url: string) {
|
||||
super()
|
||||
FakeWebSocket.instances.push(this)
|
||||
}
|
||||
|
||||
send(data: string) {
|
||||
this.sent.push(data)
|
||||
}
|
||||
|
||||
close() {
|
||||
this.dispatchEvent(new Event('close'))
|
||||
}
|
||||
}
|
||||
|
||||
const { ComfyApi } = await import('./api')
|
||||
|
||||
describe('ComfyApi cloud mode', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
vi.unstubAllGlobals()
|
||||
FakeWebSocket.instances = []
|
||||
window.name = ''
|
||||
sessionStorage.clear()
|
||||
mockAuthStore.isInitialized = true
|
||||
mockAuthStore.getAuthHeader.mockResolvedValue(null)
|
||||
mockAuthStore.getAuthToken.mockResolvedValue(null)
|
||||
vi.mocked(shouldRemintCloudRequest).mockResolvedValue(false)
|
||||
vi.mocked(fetchWithUnifiedRemint).mockResolvedValue(
|
||||
new Response(JSON.stringify({ ok: true }), {
|
||||
headers: { 'Content-Type': 'application/json' }
|
||||
})
|
||||
)
|
||||
vi.stubGlobal('WebSocket', FakeWebSocket as unknown as typeof WebSocket)
|
||||
})
|
||||
|
||||
it('adds cloud auth headers and enables unified retry for authenticated requests', async () => {
|
||||
mockAuthStore.getAuthHeader.mockResolvedValue({
|
||||
Authorization: 'Bearer firebase-token'
|
||||
})
|
||||
vi.mocked(shouldRemintCloudRequest).mockResolvedValue(true)
|
||||
const api = new ComfyApi()
|
||||
api.user = 'cloud-user'
|
||||
|
||||
await api.fetchApi('/queue')
|
||||
|
||||
expect(api.api_base).toBe('')
|
||||
expect(fetchWithUnifiedRemint).toHaveBeenCalledWith(
|
||||
'/api/queue',
|
||||
expect.objectContaining({
|
||||
cache: 'no-cache',
|
||||
headers: {
|
||||
Authorization: 'Bearer firebase-token',
|
||||
'Comfy-User': 'cloud-user'
|
||||
}
|
||||
}),
|
||||
true
|
||||
)
|
||||
})
|
||||
|
||||
it('continues cloud fetches when auth header lookup fails', async () => {
|
||||
mockAuthStore.getAuthHeader.mockRejectedValue(new Error('auth unavailable'))
|
||||
const warn = vi.spyOn(console, 'warn').mockImplementation(() => {})
|
||||
const api = new ComfyApi()
|
||||
|
||||
await api.fetchApi('/history', {
|
||||
headers: [['X-Test', '1']]
|
||||
})
|
||||
|
||||
const [, options, retryOn401] = vi.mocked(fetchWithUnifiedRemint).mock
|
||||
.calls[0]
|
||||
expect(options.headers).toEqual([
|
||||
['X-Test', '1'],
|
||||
['Comfy-User', '']
|
||||
])
|
||||
expect(retryOn401).toBe(false)
|
||||
expect(shouldRemintCloudRequest).not.toHaveBeenCalled()
|
||||
expect(warn).toHaveBeenCalledWith(
|
||||
'Failed to get auth header:',
|
||||
expect.any(Error)
|
||||
)
|
||||
})
|
||||
|
||||
it('adds the cloud auth token to websocket URLs', async () => {
|
||||
mockAuthStore.getAuthToken.mockResolvedValue('socket-token')
|
||||
window.name = 'client-1'
|
||||
const api = new ComfyApi()
|
||||
|
||||
api.init()
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(FakeWebSocket.instances).toHaveLength(1)
|
||||
})
|
||||
const socket = FakeWebSocket.instances[0]
|
||||
|
||||
expect(socket.url).toContain('clientId=client-1')
|
||||
expect(socket.url).toContain('token=socket-token')
|
||||
})
|
||||
|
||||
it('opens a cloud websocket without a token when token lookup fails', async () => {
|
||||
mockAuthStore.getAuthToken.mockRejectedValue(new Error('token unavailable'))
|
||||
const warn = vi.spyOn(console, 'warn').mockImplementation(() => {})
|
||||
const api = new ComfyApi()
|
||||
|
||||
api.init()
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(FakeWebSocket.instances).toHaveLength(1)
|
||||
})
|
||||
const socket = FakeWebSocket.instances[0]
|
||||
|
||||
expect(socket.url).not.toContain('token=')
|
||||
expect(warn).toHaveBeenCalledWith(
|
||||
'Could not get auth token for WebSocket connection:',
|
||||
expect.any(Error)
|
||||
)
|
||||
})
|
||||
})
|
||||
460
src/scripts/api.core.test.ts
Normal file
460
src/scripts/api.core.test.ts
Normal file
@@ -0,0 +1,460 @@
|
||||
import axios from 'axios'
|
||||
import { fromPartial } from '@total-typescript/shoehorn'
|
||||
import { afterEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import type { JobListItem } from '@/platform/remote/comfyui/jobs/jobTypes'
|
||||
import type {
|
||||
ComfyApiWorkflow,
|
||||
ComfyWorkflowJSON
|
||||
} from '@/platform/workflow/validation/schemas/workflowSchema'
|
||||
import type { PromptResponse } from '@/schemas/apiSchema'
|
||||
import { api as sharedApi, ComfyApi, PromptExecutionError } from '@/scripts/api'
|
||||
import type { NodeExecutionId } from '@/types/nodeIdentification'
|
||||
|
||||
const fetchJobs = vi.hoisted(() => ({
|
||||
fetchHistory: vi.fn(),
|
||||
fetchJobDetail: vi.fn(),
|
||||
fetchQueue: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('axios')
|
||||
vi.mock('@/platform/remote/comfyui/jobs/fetchJobs', () => fetchJobs)
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks()
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
function jsonResponse(data: unknown, init: ResponseInit = {}) {
|
||||
return new Response(JSON.stringify(data), {
|
||||
status: 200,
|
||||
headers: { 'content-type': 'application/json' },
|
||||
...init
|
||||
})
|
||||
}
|
||||
|
||||
function promptData(): {
|
||||
output: ComfyApiWorkflow
|
||||
workflow: ComfyWorkflowJSON
|
||||
} {
|
||||
return {
|
||||
output: fromPartial<ComfyApiWorkflow>({
|
||||
1: {
|
||||
inputs: {},
|
||||
class_type: 'KSampler',
|
||||
_meta: { title: 'KSampler' }
|
||||
}
|
||||
}),
|
||||
workflow: fromPartial<ComfyWorkflowJSON>({
|
||||
version: 0.4,
|
||||
nodes: [],
|
||||
links: []
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
function requestBody(fetchApi: ReturnType<typeof vi.spyOn>, call = 0) {
|
||||
const init = fetchApi.mock.calls[call][1]
|
||||
return JSON.parse(String(init?.body)) as Record<string, unknown>
|
||||
}
|
||||
|
||||
describe('PromptExecutionError', () => {
|
||||
it('formats string and node-specific prompt errors', () => {
|
||||
const response = fromPartial<PromptResponse>({
|
||||
error: 'invalid prompt',
|
||||
node_errors: {
|
||||
7: {
|
||||
class_type: 'KSampler',
|
||||
dependent_outputs: [],
|
||||
errors: [{ message: 'bad seed', details: 'seed must be numeric' }]
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
expect(new PromptExecutionError(response, 400).toString()).toBe(
|
||||
'invalid prompt\nKSampler:\n - bad seed: seed must be numeric'
|
||||
)
|
||||
})
|
||||
|
||||
it('formats structured prompt errors without node errors', () => {
|
||||
const response = fromPartial<PromptResponse>({
|
||||
error: {
|
||||
type: 'prompt_outputs_failed_validation',
|
||||
message: 'Validation failed',
|
||||
details: 'missing node'
|
||||
}
|
||||
})
|
||||
|
||||
expect(new PromptExecutionError(response).toString()).toBe(
|
||||
'Validation failed: missing node'
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('ComfyApi queuePrompt', () => {
|
||||
it('sends front queue requests with auth and execution options', async () => {
|
||||
const api = new ComfyApi()
|
||||
api.clientId = 'client-1'
|
||||
api.authToken = 'auth-token'
|
||||
api.apiKey = 'api-key'
|
||||
const fetchApi = vi
|
||||
.spyOn(api, 'fetchApi')
|
||||
.mockResolvedValue(jsonResponse({ prompt_id: 'queued' }))
|
||||
|
||||
await api.queuePrompt(-1, promptData(), {
|
||||
partialExecutionTargets: ['9:10' as NodeExecutionId],
|
||||
previewMethod: 'auto'
|
||||
})
|
||||
|
||||
expect(fetchApi).toHaveBeenCalledWith('/prompt', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: fetchApi.mock.calls[0][1]?.body
|
||||
})
|
||||
expect(typeof fetchApi.mock.calls[0][1]?.body).toBe('string')
|
||||
expect(requestBody(fetchApi)).toMatchObject({
|
||||
client_id: 'client-1',
|
||||
front: true,
|
||||
partial_execution_targets: ['9:10'],
|
||||
extra_data: {
|
||||
auth_token_comfy_org: 'auth-token',
|
||||
api_key_comfy_org: 'api-key',
|
||||
comfy_usage_source: 'comfyui-frontend',
|
||||
preview_method: 'auto'
|
||||
}
|
||||
})
|
||||
expect(requestBody(fetchApi)).not.toHaveProperty('number')
|
||||
})
|
||||
|
||||
it('omits default-only queue options and sets explicit queue numbers', async () => {
|
||||
const api = new ComfyApi()
|
||||
const fetchApi = vi
|
||||
.spyOn(api, 'fetchApi')
|
||||
.mockImplementation(() =>
|
||||
Promise.resolve(jsonResponse({ prompt_id: 'queued' }))
|
||||
)
|
||||
|
||||
await api.queuePrompt(0, promptData(), { previewMethod: 'default' })
|
||||
await api.queuePrompt(4, promptData())
|
||||
|
||||
expect(requestBody(fetchApi, 0)).toMatchObject({ client_id: '' })
|
||||
expect(requestBody(fetchApi, 0)).not.toHaveProperty('front')
|
||||
expect(requestBody(fetchApi, 0)).not.toHaveProperty('number')
|
||||
expect(
|
||||
requestBody(fetchApi, 0).extra_data as Record<string, unknown>
|
||||
).not.toHaveProperty('preview_method')
|
||||
expect(requestBody(fetchApi, 1)).toMatchObject({ number: 4 })
|
||||
})
|
||||
|
||||
it('throws parsed prompt errors from non-200 responses', async () => {
|
||||
const api = new ComfyApi()
|
||||
vi.spyOn(api, 'fetchApi').mockResolvedValue(
|
||||
jsonResponse(
|
||||
{
|
||||
error: {
|
||||
type: 'server_error',
|
||||
message: 'Server rejected prompt',
|
||||
details: 'bad output'
|
||||
}
|
||||
},
|
||||
{ status: 400, statusText: 'Bad Request' }
|
||||
)
|
||||
)
|
||||
|
||||
await expect(api.queuePrompt(0, promptData())).rejects.toThrow(
|
||||
'Prompt execution failed'
|
||||
)
|
||||
})
|
||||
|
||||
it('wraps non-json prompt errors with status details', async () => {
|
||||
const api = new ComfyApi()
|
||||
vi.spyOn(api, 'fetchApi').mockResolvedValue(
|
||||
new Response('backend exploded', {
|
||||
status: 500,
|
||||
statusText: 'Server Error'
|
||||
})
|
||||
)
|
||||
|
||||
await expect(api.queuePrompt(0, promptData())).rejects.toMatchObject({
|
||||
status: 500,
|
||||
response: {
|
||||
error: {
|
||||
message: '500 Server Error',
|
||||
details: 'backend exploded'
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('ComfyApi read helpers', () => {
|
||||
it('returns localized templates, default templates, and empty non-json responses', async () => {
|
||||
const api = new ComfyApi()
|
||||
vi.mocked(axios.get)
|
||||
.mockResolvedValueOnce({
|
||||
headers: { 'content-type': 'application/json' },
|
||||
data: [{ name: 'localized' }]
|
||||
})
|
||||
.mockResolvedValueOnce({
|
||||
headers: { 'content-type': 'text/html' },
|
||||
data: '<html></html>'
|
||||
})
|
||||
|
||||
await expect(api.getCoreWorkflowTemplates('fr')).resolves.toEqual([
|
||||
{ name: 'localized' }
|
||||
])
|
||||
await expect(api.getCoreWorkflowTemplates()).resolves.toEqual([])
|
||||
expect(vi.mocked(axios.get).mock.calls[0][0]).toContain(
|
||||
'/templates/index.fr.json'
|
||||
)
|
||||
expect(vi.mocked(axios.get).mock.calls[1][0]).toContain(
|
||||
'/templates/index.json'
|
||||
)
|
||||
})
|
||||
|
||||
it('falls back from missing localized templates to the default index', async () => {
|
||||
const api = new ComfyApi()
|
||||
vi.spyOn(console, 'warn').mockImplementation(() => {})
|
||||
vi.mocked(axios.get)
|
||||
.mockRejectedValueOnce(new Error('missing locale'))
|
||||
.mockResolvedValueOnce({
|
||||
headers: { 'content-type': 'application/json' },
|
||||
data: [{ name: 'default' }]
|
||||
})
|
||||
|
||||
await expect(api.getCoreWorkflowTemplates('ja')).resolves.toEqual([
|
||||
{ name: 'default' }
|
||||
])
|
||||
expect(vi.mocked(axios.get).mock.calls[1][0]).toContain(
|
||||
'/templates/index.json'
|
||||
)
|
||||
})
|
||||
|
||||
it('returns empty model lists for 404s and filters internal folders', async () => {
|
||||
const api = new ComfyApi()
|
||||
const fetchApi = vi
|
||||
.spyOn(api, 'fetchApi')
|
||||
.mockResolvedValueOnce(jsonResponse([], { status: 404 }))
|
||||
.mockResolvedValueOnce(
|
||||
jsonResponse([
|
||||
{ name: 'checkpoints' },
|
||||
{ name: 'configs' },
|
||||
{ name: 'custom_nodes' }
|
||||
])
|
||||
)
|
||||
.mockResolvedValueOnce(jsonResponse([], { status: 404 }))
|
||||
.mockResolvedValueOnce(jsonResponse(['model.safetensors']))
|
||||
|
||||
await expect(api.getModelFolders()).resolves.toEqual([])
|
||||
await expect(api.getModelFolders()).resolves.toEqual([
|
||||
{ name: 'checkpoints' }
|
||||
])
|
||||
await expect(api.getModels('checkpoints')).resolves.toEqual([])
|
||||
await expect(api.getModels('checkpoints')).resolves.toEqual([
|
||||
'model.safetensors'
|
||||
])
|
||||
expect(fetchApi).toHaveBeenCalledTimes(4)
|
||||
})
|
||||
|
||||
it('handles model metadata text responses', async () => {
|
||||
const api = new ComfyApi()
|
||||
vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
vi.spyOn(api, 'fetchApi')
|
||||
.mockResolvedValueOnce(new Response(''))
|
||||
.mockResolvedValueOnce(new Response('{"format":"safetensors"}'))
|
||||
.mockResolvedValueOnce(
|
||||
new Response('not json', { status: 200, statusText: 'OK' })
|
||||
)
|
||||
|
||||
await expect(
|
||||
api.viewMetadata('checkpoints', 'a.safetensors')
|
||||
).resolves.toBe(null)
|
||||
await expect(
|
||||
api.viewMetadata('checkpoints', 'a.safetensors')
|
||||
).resolves.toEqual({ format: 'safetensors' })
|
||||
await expect(
|
||||
api.viewMetadata('checkpoints', 'a.safetensors')
|
||||
).resolves.toBe(null)
|
||||
})
|
||||
|
||||
it('gets fuse options only from json responses', async () => {
|
||||
const api = new ComfyApi()
|
||||
vi.mocked(axios.get)
|
||||
.mockResolvedValueOnce({
|
||||
headers: { 'content-type': 'application/json' },
|
||||
data: { keys: ['name'] }
|
||||
})
|
||||
.mockResolvedValueOnce({
|
||||
headers: { 'content-type': 'text/plain' },
|
||||
data: 'nope'
|
||||
})
|
||||
vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
|
||||
await expect(api.getFuseOptions()).resolves.toEqual({ keys: ['name'] })
|
||||
await expect(api.getFuseOptions()).resolves.toBeNull()
|
||||
vi.mocked(axios.get).mockRejectedValueOnce(new Error('missing'))
|
||||
await expect(api.getFuseOptions()).resolves.toBeNull()
|
||||
})
|
||||
})
|
||||
|
||||
describe('ComfyApi queue and data helpers', () => {
|
||||
it('routes item collection requests to queue or history', async () => {
|
||||
const api = new ComfyApi()
|
||||
const queue = vi.spyOn(api, 'getQueue').mockResolvedValue({
|
||||
Running: [],
|
||||
Pending: []
|
||||
})
|
||||
const historyItem = fromPartial<JobListItem>({
|
||||
id: 'history-1',
|
||||
status: 'completed',
|
||||
create_time: 1,
|
||||
priority: 0
|
||||
})
|
||||
const history = vi.spyOn(api, 'getHistory').mockResolvedValue([historyItem])
|
||||
|
||||
await expect(api.getItems('queue')).resolves.toEqual({
|
||||
Running: [],
|
||||
Pending: []
|
||||
})
|
||||
await expect(api.getItems('history')).resolves.toEqual([historyItem])
|
||||
expect(queue).toHaveBeenCalledOnce()
|
||||
expect(history).toHaveBeenCalledOnce()
|
||||
})
|
||||
|
||||
it('returns queue fallbacks unless errors are requested', async () => {
|
||||
const api = new ComfyApi()
|
||||
vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
fetchJobs.fetchQueue.mockRejectedValue(new Error('network'))
|
||||
|
||||
await expect(api.getQueue()).resolves.toEqual({ Running: [], Pending: [] })
|
||||
await expect(api.getQueue({ throwOnError: true })).rejects.toThrow(
|
||||
'network'
|
||||
)
|
||||
})
|
||||
|
||||
it('returns empty history when fetchHistory fails', async () => {
|
||||
const api = new ComfyApi()
|
||||
vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
fetchJobs.fetchHistory.mockRejectedValue(new Error('history down'))
|
||||
|
||||
await expect(api.getHistory()).resolves.toEqual([])
|
||||
})
|
||||
|
||||
it('posts item mutations with and without request bodies', async () => {
|
||||
const api = new ComfyApi()
|
||||
const fetchApi = vi.spyOn(api, 'fetchApi').mockResolvedValue(new Response())
|
||||
|
||||
await api.deleteItem('history', 'job-1')
|
||||
await api.clearItems('queue')
|
||||
await api.interrupt(null)
|
||||
await api.interrupt('running-1')
|
||||
|
||||
expect(fetchApi.mock.calls.map((call) => call[0])).toEqual([
|
||||
'/history',
|
||||
'/queue',
|
||||
'/interrupt',
|
||||
'/interrupt'
|
||||
])
|
||||
expect(fetchApi.mock.calls[0][1]?.body).toBe(
|
||||
JSON.stringify({ delete: ['job-1'] })
|
||||
)
|
||||
expect(fetchApi.mock.calls[1][1]?.body).toBe(
|
||||
JSON.stringify({ clear: true })
|
||||
)
|
||||
expect(fetchApi.mock.calls[2][1]?.body).toBeUndefined()
|
||||
expect(fetchApi.mock.calls[3][1]?.body).toBe(
|
||||
JSON.stringify({ prompt_id: 'running-1' })
|
||||
)
|
||||
})
|
||||
|
||||
it('throws unauthorized settings responses', async () => {
|
||||
const api = new ComfyApi()
|
||||
vi.spyOn(api, 'fetchApi').mockResolvedValue(
|
||||
new Response('', { status: 401, statusText: 'Unauthorized' })
|
||||
)
|
||||
|
||||
await expect(api.getSettings()).rejects.toThrow('Unauthorized')
|
||||
})
|
||||
|
||||
it('stores user data with default and raw-body options', async () => {
|
||||
const api = new ComfyApi()
|
||||
const fetchApi = vi
|
||||
.spyOn(api, 'fetchApi')
|
||||
.mockResolvedValue(new Response('', { status: 200 }))
|
||||
const raw = new Blob(['raw'])
|
||||
|
||||
await api.storeUserData('a/b.json', { ok: true })
|
||||
await api.storeUserData('raw.bin', raw, {
|
||||
overwrite: false,
|
||||
stringify: false,
|
||||
throwOnError: false,
|
||||
full_info: true
|
||||
})
|
||||
|
||||
expect(fetchApi.mock.calls[0][0]).toBe(
|
||||
'/userdata/a%2Fb.json?overwrite=true&full_info=false'
|
||||
)
|
||||
expect(fetchApi.mock.calls[0][1]?.body).toBe(JSON.stringify({ ok: true }))
|
||||
expect(fetchApi.mock.calls[1][0]).toBe(
|
||||
'/userdata/raw.bin?overwrite=false&full_info=true'
|
||||
)
|
||||
expect(fetchApi.mock.calls[1][1]?.body).toBe(raw)
|
||||
})
|
||||
|
||||
it('honors storeUserData throwOnError', async () => {
|
||||
const api = new ComfyApi()
|
||||
vi.spyOn(api, 'fetchApi')
|
||||
.mockResolvedValueOnce(
|
||||
new Response('', { status: 500, statusText: 'Server Error' })
|
||||
)
|
||||
.mockResolvedValueOnce(
|
||||
new Response('', { status: 500, statusText: 'Server Error' })
|
||||
)
|
||||
|
||||
await expect(api.storeUserData('bad.json', {})).rejects.toThrow(
|
||||
"Error storing user data file 'bad.json': 500 Server Error"
|
||||
)
|
||||
await expect(
|
||||
api.storeUserData('bad.json', {}, { throwOnError: false })
|
||||
).resolves.toHaveProperty('status', 500)
|
||||
})
|
||||
|
||||
it('lists full user data info by status', async () => {
|
||||
const api = new ComfyApi()
|
||||
vi.spyOn(api, 'fetchApi')
|
||||
.mockResolvedValueOnce(jsonResponse([], { status: 404 }))
|
||||
.mockResolvedValueOnce(
|
||||
new Response('', { status: 500, statusText: 'Server Error' })
|
||||
)
|
||||
.mockResolvedValueOnce(jsonResponse([{ path: 'x' }]))
|
||||
|
||||
await expect(api.listUserDataFullInfo('models/')).resolves.toEqual([])
|
||||
await expect(api.listUserDataFullInfo('models/')).rejects.toThrow(
|
||||
"Error getting user data list 'models': 500 Server Error"
|
||||
)
|
||||
await expect(api.listUserDataFullInfo('models/')).resolves.toEqual([
|
||||
{ path: 'x' }
|
||||
])
|
||||
})
|
||||
|
||||
it('loads global subgraph records and deferred data', async () => {
|
||||
const fetchApi = vi
|
||||
.spyOn(sharedApi, 'fetchApi')
|
||||
.mockResolvedValueOnce(jsonResponse({}, { status: 500 }))
|
||||
.mockResolvedValueOnce(
|
||||
jsonResponse({
|
||||
ready: { name: 'Ready', info: { node_pack: 'core' }, data: '{}' },
|
||||
lazy: { name: 'Lazy', info: { node_pack: 'core' } }
|
||||
})
|
||||
)
|
||||
.mockResolvedValueOnce(jsonResponse({ data: '{"lazy":true}' }))
|
||||
|
||||
await expect(sharedApi.getGlobalSubgraphs()).resolves.toEqual({})
|
||||
const subgraphs = await sharedApi.getGlobalSubgraphs()
|
||||
|
||||
expect(subgraphs.ready.data).toBe('{}')
|
||||
expect(subgraphs.lazy.data).toBeInstanceOf(Promise)
|
||||
await expect(subgraphs.lazy.data).resolves.toBe('{"lazy":true}')
|
||||
expect(fetchApi).toHaveBeenCalledTimes(3)
|
||||
})
|
||||
})
|
||||
829
src/scripts/api.test.ts
Normal file
829
src/scripts/api.test.ts
Normal file
@@ -0,0 +1,829 @@
|
||||
import { fromAny } from '@total-typescript/shoehorn'
|
||||
import axios from 'axios'
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import {
|
||||
api as singletonApi,
|
||||
ComfyApi,
|
||||
PromptExecutionError,
|
||||
UnauthorizedError
|
||||
} from '@/scripts/api'
|
||||
import type { ComfyApiWorkflow } from '@/platform/workflow/validation/schemas/workflowSchema'
|
||||
import type { NodeExecutionId } from '@/types/nodeIdentification'
|
||||
import { fetchWithUnifiedRemint } from '@/platform/auth/unified/remintRetry'
|
||||
|
||||
const { mockToastStore } = vi.hoisted(() => ({
|
||||
mockToastStore: {
|
||||
add: vi.fn()
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/auth/unified/remintRetry', () => ({
|
||||
fetchWithUnifiedRemint: vi.fn(),
|
||||
shouldRemintCloudRequest: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/updates/common/toastStore', () => ({
|
||||
useToastStore: vi.fn(() => mockToastStore)
|
||||
}))
|
||||
|
||||
vi.mock('axios', () => ({
|
||||
default: {
|
||||
get: vi.fn(),
|
||||
patch: vi.fn()
|
||||
}
|
||||
}))
|
||||
|
||||
class FakeWebSocket extends EventTarget {
|
||||
static instances: FakeWebSocket[] = []
|
||||
|
||||
binaryType = ''
|
||||
sent: string[] = []
|
||||
|
||||
constructor(readonly url: string) {
|
||||
super()
|
||||
FakeWebSocket.instances.push(this)
|
||||
}
|
||||
|
||||
send(data: string) {
|
||||
this.sent.push(data)
|
||||
}
|
||||
|
||||
close() {
|
||||
this.dispatchEvent(new Event('close'))
|
||||
}
|
||||
}
|
||||
|
||||
function jsonResponse(data: unknown, init: ResponseInit = {}) {
|
||||
return new Response(JSON.stringify(data), {
|
||||
status: 200,
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
...init
|
||||
})
|
||||
}
|
||||
|
||||
function createWorkflow() {
|
||||
return {
|
||||
last_node_id: 0,
|
||||
last_link_id: 0,
|
||||
nodes: [],
|
||||
links: [],
|
||||
groups: [],
|
||||
config: {},
|
||||
extra: {},
|
||||
version: 0.4
|
||||
}
|
||||
}
|
||||
|
||||
function binaryMessage(type: number, payload: Uint8Array) {
|
||||
const bytes = new Uint8Array(4 + payload.length)
|
||||
new DataView(bytes.buffer).setUint32(0, type)
|
||||
bytes.set(payload, 4)
|
||||
return bytes.buffer
|
||||
}
|
||||
|
||||
function uint32(value: number) {
|
||||
const bytes = new Uint8Array(4)
|
||||
new DataView(bytes.buffer).setUint32(0, value)
|
||||
return bytes
|
||||
}
|
||||
|
||||
function concatBytes(...chunks: Uint8Array[]) {
|
||||
const totalLength = chunks.reduce((sum, chunk) => sum + chunk.length, 0)
|
||||
const result = new Uint8Array(totalLength)
|
||||
let offset = 0
|
||||
for (const chunk of chunks) {
|
||||
result.set(chunk, offset)
|
||||
offset += chunk.length
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
describe('PromptExecutionError', () => {
|
||||
it('formats string, structured, and node-level errors', () => {
|
||||
expect(
|
||||
new PromptExecutionError({
|
||||
error: 'Queue rejected',
|
||||
node_errors: {}
|
||||
}).toString()
|
||||
).toBe('Queue rejected')
|
||||
|
||||
expect(
|
||||
new PromptExecutionError({
|
||||
error: {
|
||||
type: 'invalid_prompt',
|
||||
message: 'Invalid prompt',
|
||||
details: 'missing input'
|
||||
},
|
||||
node_errors: {
|
||||
1: {
|
||||
class_type: 'PreviewAny',
|
||||
dependent_outputs: ['1'],
|
||||
errors: [
|
||||
{
|
||||
type: 'required_input_missing',
|
||||
message: 'Required input',
|
||||
details: 'source'
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}).toString()
|
||||
).toContain('Invalid prompt: missing input\nPreviewAny:')
|
||||
})
|
||||
})
|
||||
|
||||
describe('ComfyApi', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
FakeWebSocket.instances = []
|
||||
window.name = ''
|
||||
sessionStorage.clear()
|
||||
vi.stubGlobal('WebSocket', FakeWebSocket as unknown as typeof WebSocket)
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks()
|
||||
vi.unstubAllGlobals()
|
||||
})
|
||||
|
||||
it('builds API, internal, file, and fetch URLs with user headers', async () => {
|
||||
const api = new ComfyApi()
|
||||
api.user = 'reviewer'
|
||||
vi.mocked(fetchWithUnifiedRemint).mockResolvedValue(
|
||||
jsonResponse({ ok: true })
|
||||
)
|
||||
|
||||
await api.fetchApi('/queue', {
|
||||
headers: new Headers([['X-Test', '1']])
|
||||
})
|
||||
|
||||
expect(api.apiURL('/api/custom')).toBe(`${api.api_base}/api/custom`)
|
||||
expect(api.apiURL('/queue')).toBe(`${api.api_base}/api/queue`)
|
||||
expect(api.internalURL('/logs')).toBe(`${api.api_base}/internal/logs`)
|
||||
expect(api.fileURL('/view')).toBe(`${api.api_base}/view`)
|
||||
const [, options] = vi.mocked(fetchWithUnifiedRemint).mock.calls[0]
|
||||
expect(options.headers).toBeInstanceOf(Headers)
|
||||
expect((options.headers as Headers).get('Comfy-User')).toBe('reviewer')
|
||||
expect((options.headers as Headers).get('X-Test')).toBe('1')
|
||||
})
|
||||
|
||||
it('guards event listeners and still allows removing them', async () => {
|
||||
const api = new ComfyApi()
|
||||
const warn = vi.spyOn(console, 'warn').mockImplementation(() => undefined)
|
||||
const listener = vi.fn()
|
||||
const throwingListener = vi.fn(() => {
|
||||
throw new Error('listener failed')
|
||||
})
|
||||
const asyncListener = vi.fn(() => Promise.reject(new Error('async failed')))
|
||||
const objectListener = { handleEvent: vi.fn() }
|
||||
|
||||
api.addEventListener('status', null)
|
||||
api.removeEventListener('status', null)
|
||||
api.addEventListener('status', listener)
|
||||
api.addEventListener('status', throwingListener)
|
||||
api.addEventListener('status', asyncListener)
|
||||
api.addEventListener('status', fromAny(objectListener))
|
||||
|
||||
api.dispatchCustomEvent('status', { exec_info: { queue_remaining: 1 } })
|
||||
await Promise.resolve()
|
||||
|
||||
expect(listener).toHaveBeenCalled()
|
||||
expect(throwingListener).toHaveBeenCalled()
|
||||
expect(asyncListener).toHaveBeenCalled()
|
||||
expect(objectListener.handleEvent).toHaveBeenCalled()
|
||||
expect(warn).toHaveBeenCalledTimes(2)
|
||||
|
||||
api.removeEventListener('status', listener)
|
||||
api.dispatchCustomEvent('status', null)
|
||||
expect(listener).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('reuses guarded listener wrappers and ignores unknown removals', () => {
|
||||
const api = new ComfyApi()
|
||||
const listener = vi.fn()
|
||||
const neverRegistered = vi.fn()
|
||||
|
||||
api.addEventListener('status', listener)
|
||||
api.addEventListener('status', listener)
|
||||
api.removeEventListener('status', neverRegistered)
|
||||
api.dispatchCustomEvent('status', null)
|
||||
|
||||
expect(listener).toHaveBeenCalledTimes(1)
|
||||
expect(neverRegistered).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('supports guarded custom event listeners', () => {
|
||||
const api = new ComfyApi()
|
||||
const listener = vi.fn()
|
||||
|
||||
api.addCustomEventListener('custom-node-event', listener)
|
||||
;(api as EventTarget).dispatchEvent(
|
||||
new CustomEvent('custom-node-event', { detail: { ok: true } })
|
||||
)
|
||||
api.removeCustomEventListener('custom-node-event', listener)
|
||||
;(api as EventTarget).dispatchEvent(
|
||||
new CustomEvent('custom-node-event', { detail: { ok: false } })
|
||||
)
|
||||
|
||||
expect(listener).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('routes websocket JSON messages and custom registered messages', () => {
|
||||
window.name = 'existing-client'
|
||||
const api = new ComfyApi()
|
||||
const status = vi.fn()
|
||||
const executing = vi.fn()
|
||||
const featureFlags = vi.fn()
|
||||
const custom = vi.fn()
|
||||
const warn = vi.spyOn(console, 'warn').mockImplementation(() => undefined)
|
||||
vi.spyOn(console, 'log').mockImplementation(() => undefined)
|
||||
|
||||
api.addEventListener('status', status)
|
||||
api.addEventListener('executing', executing)
|
||||
api.addEventListener('feature_flags', featureFlags)
|
||||
api.addCustomEventListener('custom-message', custom)
|
||||
api.init()
|
||||
const socket = FakeWebSocket.instances[0]
|
||||
socket.dispatchEvent(new Event('open'))
|
||||
|
||||
expect(socket.url).toContain('clientId=existing-client')
|
||||
expect(JSON.parse(socket.sent[0])).toMatchObject({
|
||||
type: 'feature_flags'
|
||||
})
|
||||
|
||||
socket.dispatchEvent(
|
||||
new MessageEvent('message', {
|
||||
data: JSON.stringify({
|
||||
type: 'status',
|
||||
data: {
|
||||
sid: 'fresh-client',
|
||||
status: { exec_info: { queue_remaining: 2 } }
|
||||
}
|
||||
})
|
||||
})
|
||||
)
|
||||
socket.dispatchEvent(
|
||||
new MessageEvent('message', {
|
||||
data: JSON.stringify({
|
||||
type: 'executing',
|
||||
data: { node: '12' }
|
||||
})
|
||||
})
|
||||
)
|
||||
socket.dispatchEvent(
|
||||
new MessageEvent('message', {
|
||||
data: JSON.stringify({
|
||||
type: 'feature_flags',
|
||||
data: { supports_progress_text_metadata: true }
|
||||
})
|
||||
})
|
||||
)
|
||||
socket.dispatchEvent(
|
||||
new MessageEvent('message', {
|
||||
data: JSON.stringify({
|
||||
type: 'custom-message',
|
||||
data: { from: 'extension' }
|
||||
})
|
||||
})
|
||||
)
|
||||
socket.dispatchEvent(
|
||||
new MessageEvent('message', {
|
||||
data: JSON.stringify({
|
||||
type: 'unknown-message',
|
||||
data: {}
|
||||
})
|
||||
})
|
||||
)
|
||||
socket.dispatchEvent(
|
||||
new MessageEvent('message', {
|
||||
data: JSON.stringify({
|
||||
type: 'unknown-message',
|
||||
data: {}
|
||||
})
|
||||
})
|
||||
)
|
||||
|
||||
expect(api.clientId).toBe('fresh-client')
|
||||
expect(window.name).toBe('fresh-client')
|
||||
expect(sessionStorage.getItem('clientId')).toBe('fresh-client')
|
||||
expect(status).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
detail: { exec_info: { queue_remaining: 2 } }
|
||||
})
|
||||
)
|
||||
expect(executing).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ detail: '12' })
|
||||
)
|
||||
expect(featureFlags).toHaveBeenCalled()
|
||||
expect(api.serverSupportsFeature('supports_progress_text_metadata')).toBe(
|
||||
true
|
||||
)
|
||||
expect(custom).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ detail: { from: 'extension' } })
|
||||
)
|
||||
expect(api.reportedUnknownMessageTypes.has('unknown-message')).toBe(true)
|
||||
expect(warn).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('polls status when the initial websocket connection fails', async () => {
|
||||
vi.useFakeTimers()
|
||||
const api = new ComfyApi()
|
||||
const status = vi.fn()
|
||||
vi.spyOn(api, 'fetchApi')
|
||||
.mockResolvedValueOnce(
|
||||
jsonResponse({ exec_info: { queue_remaining: 4 } })
|
||||
)
|
||||
.mockRejectedValueOnce(new Error('poll failed'))
|
||||
api.addEventListener('status', status)
|
||||
|
||||
api.init()
|
||||
FakeWebSocket.instances[0].dispatchEvent(new Event('error'))
|
||||
await vi.advanceTimersByTimeAsync(1000)
|
||||
await vi.advanceTimersByTimeAsync(1000)
|
||||
|
||||
expect(status).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
detail: { exec_info: { queue_remaining: 4 } }
|
||||
})
|
||||
)
|
||||
expect(status).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ detail: null })
|
||||
)
|
||||
vi.useRealTimers()
|
||||
})
|
||||
|
||||
it('emits reconnect lifecycle events after an opened websocket closes', async () => {
|
||||
vi.useFakeTimers()
|
||||
const api = new ComfyApi()
|
||||
const status = vi.fn()
|
||||
const reconnecting = vi.fn()
|
||||
const reconnected = vi.fn()
|
||||
api.addEventListener('status', status)
|
||||
api.addEventListener('reconnecting', reconnecting)
|
||||
api.addEventListener('reconnected', reconnected)
|
||||
|
||||
api.init()
|
||||
const socket = FakeWebSocket.instances[0]
|
||||
socket.dispatchEvent(new Event('open'))
|
||||
socket.close()
|
||||
|
||||
expect(status).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ detail: null })
|
||||
)
|
||||
expect(reconnecting).toHaveBeenCalledOnce()
|
||||
|
||||
await vi.advanceTimersByTimeAsync(300)
|
||||
const reconnectSocket = FakeWebSocket.instances[1]
|
||||
reconnectSocket.dispatchEvent(new Event('open'))
|
||||
|
||||
expect(reconnected).toHaveBeenCalledOnce()
|
||||
vi.useRealTimers()
|
||||
})
|
||||
|
||||
it('routes websocket variants without session ids and display-node fallbacks', () => {
|
||||
const api = new ComfyApi()
|
||||
const status = vi.fn()
|
||||
const executing = vi.fn()
|
||||
api.addEventListener('status', status)
|
||||
api.addEventListener('executing', executing)
|
||||
api.init()
|
||||
const socket = FakeWebSocket.instances[0]
|
||||
|
||||
socket.dispatchEvent(
|
||||
new MessageEvent('message', {
|
||||
data: JSON.stringify({
|
||||
type: 'status',
|
||||
data: { status: undefined }
|
||||
})
|
||||
})
|
||||
)
|
||||
socket.dispatchEvent(
|
||||
new MessageEvent('message', {
|
||||
data: JSON.stringify({
|
||||
type: 'executing',
|
||||
data: { node: 'real', display_node: 'display' }
|
||||
})
|
||||
})
|
||||
)
|
||||
|
||||
expect(status).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ detail: null })
|
||||
)
|
||||
expect(executing).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ detail: 'display' })
|
||||
)
|
||||
})
|
||||
|
||||
it('routes binary preview and progress websocket messages', () => {
|
||||
const api = new ComfyApi()
|
||||
const preview = vi.fn()
|
||||
const previewWithMetadata = vi.fn()
|
||||
const progressText = vi.fn()
|
||||
const encoder = new TextEncoder()
|
||||
api.serverFeatureFlags.value = {
|
||||
supports_progress_text_metadata: true
|
||||
}
|
||||
api.addEventListener('b_preview', preview)
|
||||
api.addEventListener('b_preview_with_metadata', previewWithMetadata)
|
||||
api.addEventListener('progress_text', progressText)
|
||||
api.init()
|
||||
const socket = FakeWebSocket.instances[0]
|
||||
|
||||
socket.dispatchEvent(
|
||||
new MessageEvent('message', {
|
||||
data: binaryMessage(1, concatBytes(uint32(2), new Uint8Array([1, 2])))
|
||||
})
|
||||
)
|
||||
|
||||
const promptId = encoder.encode('prompt-1')
|
||||
const nodeId = encoder.encode('7')
|
||||
const progressPayload = concatBytes(
|
||||
uint32(promptId.length),
|
||||
promptId,
|
||||
uint32(nodeId.length),
|
||||
nodeId,
|
||||
encoder.encode('loading')
|
||||
)
|
||||
socket.dispatchEvent(
|
||||
new MessageEvent('message', {
|
||||
data: binaryMessage(3, progressPayload)
|
||||
})
|
||||
)
|
||||
|
||||
const metadata = encoder.encode(
|
||||
JSON.stringify({
|
||||
image_type: 'image/webp',
|
||||
node_id: '7',
|
||||
display_node_id: '7',
|
||||
parent_node_id: '4',
|
||||
real_node_id: '7',
|
||||
prompt_id: 'prompt-1'
|
||||
})
|
||||
)
|
||||
socket.dispatchEvent(
|
||||
new MessageEvent('message', {
|
||||
data: binaryMessage(
|
||||
4,
|
||||
concatBytes(uint32(metadata.length), metadata, new Uint8Array([9]))
|
||||
)
|
||||
})
|
||||
)
|
||||
|
||||
expect(preview).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
detail: expect.objectContaining({ type: 'image/png' })
|
||||
})
|
||||
)
|
||||
expect(progressText).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
detail: {
|
||||
nodeId: '7',
|
||||
text: 'loading',
|
||||
prompt_id: 'prompt-1'
|
||||
}
|
||||
})
|
||||
)
|
||||
expect(previewWithMetadata).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
detail: expect.objectContaining({
|
||||
nodeId: '7',
|
||||
parentNodeId: '4',
|
||||
jobId: 'prompt-1'
|
||||
})
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('routes binary jpeg/default previews and malformed binary messages defensively', () => {
|
||||
const api = new ComfyApi()
|
||||
const preview = vi.fn()
|
||||
const progressText = vi.fn()
|
||||
const warn = vi.spyOn(console, 'warn').mockImplementation(() => undefined)
|
||||
const encoder = new TextEncoder()
|
||||
api.addEventListener('b_preview', preview)
|
||||
api.addEventListener('progress_text', progressText)
|
||||
api.init()
|
||||
const socket = FakeWebSocket.instances[0]
|
||||
|
||||
socket.dispatchEvent(
|
||||
new MessageEvent('message', {
|
||||
data: binaryMessage(1, concatBytes(uint32(1), new Uint8Array([1])))
|
||||
})
|
||||
)
|
||||
socket.dispatchEvent(
|
||||
new MessageEvent('message', {
|
||||
data: binaryMessage(1, concatBytes(uint32(99), new Uint8Array([2])))
|
||||
})
|
||||
)
|
||||
|
||||
const nodeId = encoder.encode('node')
|
||||
socket.dispatchEvent(
|
||||
new MessageEvent('message', {
|
||||
data: binaryMessage(
|
||||
3,
|
||||
concatBytes(uint32(nodeId.length), nodeId, encoder.encode('ready'))
|
||||
)
|
||||
})
|
||||
)
|
||||
socket.dispatchEvent(
|
||||
new MessageEvent('message', {
|
||||
data: binaryMessage(3, new Uint8Array([1]))
|
||||
})
|
||||
)
|
||||
socket.dispatchEvent(
|
||||
new MessageEvent('message', {
|
||||
data: binaryMessage(99, new Uint8Array())
|
||||
})
|
||||
)
|
||||
|
||||
expect(preview).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
detail: expect.objectContaining({ type: 'image/jpeg' })
|
||||
})
|
||||
)
|
||||
expect(progressText).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
detail: { nodeId: 'node', text: 'ready' }
|
||||
})
|
||||
)
|
||||
expect(warn).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('serializes prompt queue options and surfaces non-200 errors', async () => {
|
||||
const api = new ComfyApi()
|
||||
api.clientId = 'client-1'
|
||||
api.authToken = 'token-1'
|
||||
api.apiKey = 'key-1'
|
||||
const prompt: ComfyApiWorkflow = {
|
||||
1: {
|
||||
class_type: 'PreviewAny',
|
||||
inputs: {},
|
||||
_meta: { title: 'PreviewAny' }
|
||||
}
|
||||
}
|
||||
const fetchApi = vi
|
||||
.spyOn(api, 'fetchApi')
|
||||
.mockResolvedValueOnce(jsonResponse({ prompt_id: 'queued' }))
|
||||
.mockResolvedValueOnce(
|
||||
new Response('backend exploded', {
|
||||
status: 500,
|
||||
statusText: 'Server Error'
|
||||
})
|
||||
)
|
||||
|
||||
await expect(
|
||||
api.queuePrompt(
|
||||
-1,
|
||||
{ output: prompt, workflow: createWorkflow() },
|
||||
{
|
||||
partialExecutionTargets: ['7' as NodeExecutionId],
|
||||
previewMethod: 'latent2rgb'
|
||||
}
|
||||
)
|
||||
).resolves.toEqual({ prompt_id: 'queued' })
|
||||
|
||||
const body = JSON.parse(fetchApi.mock.calls[0][1]?.body as string)
|
||||
expect(body).toMatchObject({
|
||||
client_id: 'client-1',
|
||||
prompt,
|
||||
partial_execution_targets: ['7'],
|
||||
front: true,
|
||||
extra_data: {
|
||||
auth_token_comfy_org: 'token-1',
|
||||
api_key_comfy_org: 'key-1',
|
||||
comfy_usage_source: 'comfyui-frontend',
|
||||
preview_method: 'latent2rgb'
|
||||
}
|
||||
})
|
||||
expect(body.number).toBeUndefined()
|
||||
|
||||
await expect(
|
||||
api.queuePrompt(3, { output: prompt, workflow: createWorkflow() })
|
||||
).rejects.toMatchObject({ status: 500 })
|
||||
})
|
||||
|
||||
it('omits queue position and default preview method for normal queueing', async () => {
|
||||
const api = new ComfyApi()
|
||||
const prompt: ComfyApiWorkflow = {
|
||||
1: {
|
||||
class_type: 'PreviewAny',
|
||||
inputs: {},
|
||||
_meta: { title: 'PreviewAny' }
|
||||
}
|
||||
}
|
||||
const fetchApi = vi
|
||||
.spyOn(api, 'fetchApi')
|
||||
.mockResolvedValue(jsonResponse({ prompt_id: 'queued' }))
|
||||
|
||||
await api.queuePrompt(
|
||||
0,
|
||||
{ output: prompt, workflow: createWorkflow() },
|
||||
{
|
||||
previewMethod: 'default'
|
||||
}
|
||||
)
|
||||
|
||||
const body = JSON.parse(fetchApi.mock.calls[0][1]?.body as string)
|
||||
expect(body.front).toBeUndefined()
|
||||
expect(body.number).toBeUndefined()
|
||||
expect(body.extra_data.preview_method).toBeUndefined()
|
||||
})
|
||||
|
||||
it('handles shareable assets, settings, userdata, subgraphs, and memory APIs', async () => {
|
||||
const api = new ComfyApi()
|
||||
const prompt: ComfyApiWorkflow = {
|
||||
1: {
|
||||
class_type: 'PreviewAny',
|
||||
inputs: {},
|
||||
_meta: { title: 'PreviewAny' }
|
||||
}
|
||||
}
|
||||
vi.spyOn(api, 'fetchApi')
|
||||
.mockResolvedValueOnce(jsonResponse({ assets: [] }))
|
||||
.mockResolvedValueOnce(jsonResponse({}, { status: 500 }))
|
||||
.mockResolvedValueOnce(
|
||||
new Response('', { status: 401, statusText: 'Unauthorized' })
|
||||
)
|
||||
.mockResolvedValueOnce(jsonResponse({}, { status: 204 }))
|
||||
.mockResolvedValueOnce(jsonResponse([], { status: 404 }))
|
||||
.mockResolvedValueOnce(
|
||||
jsonResponse({}, { status: 500, statusText: 'Server Error' })
|
||||
)
|
||||
.mockResolvedValueOnce(jsonResponse({}, { status: 200 }))
|
||||
.mockResolvedValueOnce(jsonResponse({}, { status: 200 }))
|
||||
.mockResolvedValueOnce(jsonResponse({}, { status: 500 }))
|
||||
vi.spyOn(singletonApi, 'fetchApi')
|
||||
.mockResolvedValueOnce(jsonResponse({ data: 'subgraph-data' }))
|
||||
.mockResolvedValueOnce(jsonResponse({ missing: {} }))
|
||||
.mockResolvedValueOnce(jsonResponse({ one: { data: 'inline' } }))
|
||||
|
||||
await expect(
|
||||
api.getShareableAssets(prompt, { owned: false })
|
||||
).resolves.toEqual({ assets: [] })
|
||||
await expect(api.getShareableAssets(prompt)).rejects.toThrow(
|
||||
'Failed to fetch shareable assets'
|
||||
)
|
||||
await expect(api.getSettings()).rejects.toBeInstanceOf(UnauthorizedError)
|
||||
await expect(
|
||||
api.storeUserData('plain.txt', 'raw', {
|
||||
overwrite: false,
|
||||
stringify: false,
|
||||
throwOnError: false,
|
||||
full_info: true
|
||||
})
|
||||
).resolves.toHaveProperty('status', 204)
|
||||
await expect(api.listUserDataFullInfo('/missing/')).resolves.toEqual([])
|
||||
await expect(api.listUserDataFullInfo('/broken/')).rejects.toThrow(
|
||||
"Error getting user data list '/broken'"
|
||||
)
|
||||
await expect(api.getGlobalSubgraphData('one')).resolves.toBe(
|
||||
'subgraph-data'
|
||||
)
|
||||
await expect(api.getGlobalSubgraphData('missing')).rejects.toThrow(
|
||||
"Global subgraph 'missing' returned empty data"
|
||||
)
|
||||
await expect(api.getGlobalSubgraphs()).resolves.toEqual({
|
||||
one: { data: 'inline' }
|
||||
})
|
||||
await api.freeMemory({ freeExecutionCache: true })
|
||||
await api.freeMemory({ freeExecutionCache: false })
|
||||
await api.freeMemory({ freeExecutionCache: true })
|
||||
expect(mockToastStore.add).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
summary: 'Models and Execution Cache have been cleared.'
|
||||
})
|
||||
)
|
||||
expect(mockToastStore.add).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ summary: 'Models have been unloaded.' })
|
||||
)
|
||||
expect(mockToastStore.add).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
summary:
|
||||
'Unloading of models failed. Installed ComfyUI may be an outdated version.'
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('rejects non-success global subgraph data responses', async () => {
|
||||
const api = new ComfyApi()
|
||||
vi.spyOn(singletonApi, 'fetchApi').mockResolvedValueOnce(
|
||||
jsonResponse({}, { status: 404, statusText: 'Not Found' })
|
||||
)
|
||||
|
||||
await expect(api.getGlobalSubgraphData('missing')).rejects.toThrow(
|
||||
"Failed to fetch global subgraph 'missing': 404 Not Found"
|
||||
)
|
||||
})
|
||||
|
||||
it('handles successful settings and userdata helper request shapes', async () => {
|
||||
const api = new ComfyApi()
|
||||
const fetchApi = vi
|
||||
.spyOn(api, 'fetchApi')
|
||||
.mockResolvedValueOnce(jsonResponse({ theme: 'dark' }))
|
||||
.mockResolvedValueOnce(jsonResponse({}, { status: 500 }))
|
||||
.mockResolvedValueOnce(jsonResponse({}, { status: 200 }))
|
||||
.mockResolvedValueOnce(jsonResponse({}, { status: 200 }))
|
||||
|
||||
await expect(api.getSettings()).resolves.toEqual({ theme: 'dark' })
|
||||
await expect(api.storeUserData('bad.json', { a: 1 })).rejects.toThrow(
|
||||
"Error storing user data file 'bad.json'"
|
||||
)
|
||||
await api.moveUserData('old/path.json', 'new path.json')
|
||||
await api.deleteUserData('old/path.json')
|
||||
|
||||
expect(fetchApi.mock.calls[2]).toEqual([
|
||||
'/userdata/old%2Fpath.json/move/new%20path.json?overwrite=false',
|
||||
{ method: 'POST' }
|
||||
])
|
||||
expect(fetchApi.mock.calls[3]).toEqual([
|
||||
'/userdata/old%2Fpath.json',
|
||||
{ method: 'DELETE' }
|
||||
])
|
||||
})
|
||||
|
||||
it('handles global subgraph fallbacks and log endpoints', async () => {
|
||||
const api = new ComfyApi()
|
||||
vi.spyOn(singletonApi, 'fetchApi')
|
||||
.mockResolvedValueOnce(jsonResponse({}, { status: 500 }))
|
||||
.mockResolvedValueOnce(
|
||||
jsonResponse({
|
||||
missing: {
|
||||
name: 'Missing data',
|
||||
info: { node_pack: 'core' }
|
||||
}
|
||||
})
|
||||
)
|
||||
.mockResolvedValueOnce(jsonResponse({ data: 'lazy-data' }))
|
||||
vi.mocked(axios.get)
|
||||
.mockResolvedValueOnce({ data: 'log text', headers: {} })
|
||||
.mockResolvedValueOnce({ data: { logs: [] }, headers: {} })
|
||||
.mockRejectedValueOnce(new Error('no folders'))
|
||||
.mockResolvedValueOnce({
|
||||
data: { checkpoints: ['/models'] },
|
||||
headers: {}
|
||||
})
|
||||
vi.mocked(axios.patch).mockResolvedValue({ data: undefined })
|
||||
|
||||
await expect(api.getGlobalSubgraphs()).resolves.toEqual({})
|
||||
const subgraphs = await api.getGlobalSubgraphs()
|
||||
await expect(subgraphs.missing.data).resolves.toBe('lazy-data')
|
||||
await expect(api.getLogs()).resolves.toBe('log text')
|
||||
await expect(api.getRawLogs()).resolves.toEqual({ logs: [] })
|
||||
await api.subscribeLogs(true)
|
||||
await expect(api.getFolderPaths()).resolves.toEqual({})
|
||||
await expect(api.getFolderPaths()).resolves.toEqual({
|
||||
checkpoints: ['/models']
|
||||
})
|
||||
|
||||
expect(axios.patch).toHaveBeenCalledWith(
|
||||
api.internalURL('/logs/subscribe'),
|
||||
{ enabled: true, clientId: undefined }
|
||||
)
|
||||
})
|
||||
|
||||
it('loads localized template indexes and fuse options defensively', async () => {
|
||||
const api = new ComfyApi()
|
||||
vi.mocked(axios.get)
|
||||
.mockResolvedValueOnce({
|
||||
data: [{ name: 'template' }],
|
||||
headers: { 'content-type': 'application/json' }
|
||||
})
|
||||
.mockResolvedValueOnce({
|
||||
data: '<html></html>',
|
||||
headers: { 'content-type': 'text/html' }
|
||||
})
|
||||
.mockRejectedValueOnce(new Error('missing locale'))
|
||||
.mockResolvedValueOnce({
|
||||
data: [{ name: 'fallback' }],
|
||||
headers: { 'content-type': 'application/json' }
|
||||
})
|
||||
.mockRejectedValueOnce(new Error('default missing'))
|
||||
.mockResolvedValueOnce({
|
||||
data: { keys: ['name'] },
|
||||
headers: { 'content-type': 'application/json' }
|
||||
})
|
||||
.mockResolvedValueOnce({
|
||||
data: '<html></html>',
|
||||
headers: { 'content-type': 'text/html' }
|
||||
})
|
||||
.mockResolvedValueOnce({
|
||||
data: { ignored: true },
|
||||
headers: {}
|
||||
})
|
||||
vi.spyOn(console, 'warn').mockImplementation(() => undefined)
|
||||
vi.spyOn(console, 'error').mockImplementation(() => undefined)
|
||||
|
||||
await expect(api.getCoreWorkflowTemplates('fr')).resolves.toEqual([
|
||||
{ name: 'template' }
|
||||
])
|
||||
await expect(api.getCoreWorkflowTemplates()).resolves.toEqual([])
|
||||
await expect(api.getCoreWorkflowTemplates('zh')).resolves.toEqual([
|
||||
{ name: 'fallback' }
|
||||
])
|
||||
await expect(api.getCoreWorkflowTemplates('en')).resolves.toEqual([])
|
||||
await expect(api.getFuseOptions()).resolves.toEqual({ keys: ['name'] })
|
||||
await expect(api.getFuseOptions()).resolves.toBeNull()
|
||||
await expect(api.getFuseOptions()).resolves.toBeNull()
|
||||
})
|
||||
})
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,6 +1,12 @@
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import type {
|
||||
CanvasPointerEvent,
|
||||
Subgraph
|
||||
} from '@/lib/litegraph/src/litegraph'
|
||||
import { LGraphCanvas, LiteGraph } from '@/lib/litegraph/src/litegraph'
|
||||
import type { ComfyWorkflowJSON } from '@/platform/workflow/validation/schemas/workflowSchema'
|
||||
import type { ExecutedWsMessage } from '@/schemas/apiSchema'
|
||||
|
||||
const mockAssert = vi.hoisted(() => vi.fn())
|
||||
|
||||
@@ -14,7 +20,7 @@ const mockNodeOutputStore = vi.hoisted(() => ({
|
||||
}))
|
||||
|
||||
const mockSubgraphNavigationStore = vi.hoisted(() => ({
|
||||
exportState: vi.fn(() => []),
|
||||
exportState: vi.fn((): string[] => []),
|
||||
restoreState: vi.fn()
|
||||
}))
|
||||
|
||||
@@ -23,10 +29,24 @@ const mockWorkflowStore = vi.hoisted(() => ({
|
||||
getWorkflowByPath: vi.fn()
|
||||
}))
|
||||
|
||||
const mockExecutionStore = vi.hoisted(() => ({
|
||||
queuedJobs: {} as Record<string, { workflow: { changeTracker: unknown } }>
|
||||
}))
|
||||
|
||||
const mockMaskEditorIsOpened = vi.hoisted(() => vi.fn(() => false))
|
||||
|
||||
vi.mock('@/scripts/app', () => ({
|
||||
app: {
|
||||
constructor: {
|
||||
maskeditor_is_opended: mockMaskEditorIsOpened
|
||||
},
|
||||
graph: {},
|
||||
ui: {
|
||||
autoQueueEnabled: false,
|
||||
autoQueueMode: 'instant'
|
||||
},
|
||||
rootGraph: {
|
||||
subgraphs: new Map(),
|
||||
serialize: vi.fn(() => ({
|
||||
nodes: [],
|
||||
links: [],
|
||||
@@ -39,8 +59,10 @@ vi.mock('@/scripts/app', () => ({
|
||||
}))
|
||||
},
|
||||
canvas: {
|
||||
ds: { scale: 1, offset: [0, 0] }
|
||||
}
|
||||
ds: { scale: 1, offset: [0, 0] },
|
||||
setGraph: vi.fn()
|
||||
},
|
||||
loadGraphData: vi.fn()
|
||||
}
|
||||
}))
|
||||
|
||||
@@ -65,6 +87,10 @@ vi.mock('@/platform/workflow/management/stores/workflowStore', () => ({
|
||||
useWorkflowStore: vi.fn(() => mockWorkflowStore)
|
||||
}))
|
||||
|
||||
vi.mock('@/stores/executionStore', () => ({
|
||||
useExecutionStore: vi.fn(() => mockExecutionStore)
|
||||
}))
|
||||
|
||||
import { app } from '@/scripts/app'
|
||||
import { api } from '@/scripts/api'
|
||||
import { ChangeTracker } from '@/scripts/changeTracker'
|
||||
@@ -107,13 +133,120 @@ function mockCanvasState(state: ComfyWorkflowJSON) {
|
||||
vi.mocked(app.rootGraph.serialize).mockReturnValue(state as never)
|
||||
}
|
||||
|
||||
type ListenerMap = Record<string, EventListener[]>
|
||||
|
||||
function storeListener(
|
||||
listeners: ListenerMap,
|
||||
type: string,
|
||||
listener: EventListenerOrEventListenerObject
|
||||
) {
|
||||
if (typeof listener === 'function') {
|
||||
listeners[type] ??= []
|
||||
listeners[type].push(listener)
|
||||
}
|
||||
}
|
||||
|
||||
function dispatchStored(listeners: ListenerMap, type: string, event: Event) {
|
||||
for (const listener of listeners[type] ?? []) {
|
||||
listener(event)
|
||||
}
|
||||
}
|
||||
|
||||
async function flushAsyncFrame() {
|
||||
await Promise.resolve()
|
||||
await Promise.resolve()
|
||||
}
|
||||
|
||||
function getApiListener(name: string) {
|
||||
const call = vi
|
||||
.mocked(api.addEventListener)
|
||||
.mock.calls.find(([eventName]) => eventName === name)
|
||||
expect(call).toBeDefined()
|
||||
return call?.[1] as (event: CustomEvent<ExecutedWsMessage>) => void
|
||||
}
|
||||
|
||||
describe('ChangeTracker', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
nodeIdCounter = 0
|
||||
ChangeTracker.isLoadingGraph = false
|
||||
Reflect.set(ChangeTracker, '_checkStateWarned', false)
|
||||
mockWorkflowStore.activeWorkflow = null
|
||||
mockWorkflowStore.getWorkflowByPath.mockReturnValue(null)
|
||||
mockExecutionStore.queuedJobs = {}
|
||||
mockMaskEditorIsOpened.mockReturnValue(false)
|
||||
app.ui.autoQueueEnabled = false
|
||||
app.ui.autoQueueMode = 'instant'
|
||||
vi.mocked(app.canvas.setGraph).mockClear()
|
||||
vi.mocked(app.loadGraphData).mockResolvedValue(undefined)
|
||||
app.rootGraph.subgraphs.clear()
|
||||
app.canvas.ds.scale = 1
|
||||
app.canvas.ds.offset = [0, 0]
|
||||
})
|
||||
|
||||
describe('reset', () => {
|
||||
it('updates initialState from activeState or an explicit state', () => {
|
||||
const tracker = createTracker(createState(1))
|
||||
const changed = createState(2)
|
||||
|
||||
tracker.activeState = changed
|
||||
tracker.reset()
|
||||
|
||||
expect(tracker.initialState).toEqual(changed)
|
||||
expect(tracker.initialState).not.toBe(changed)
|
||||
|
||||
const explicit = createState(3)
|
||||
tracker.reset(explicit)
|
||||
|
||||
expect(tracker.activeState).toEqual(explicit)
|
||||
expect(tracker.activeState).not.toBe(explicit)
|
||||
expect(tracker.initialState).toEqual(explicit)
|
||||
})
|
||||
|
||||
it('does not reset while restoring state', () => {
|
||||
const tracker = createTracker(createState(1))
|
||||
const original = tracker.initialState
|
||||
tracker._restoringState = true
|
||||
|
||||
tracker.reset(createState(2))
|
||||
|
||||
expect(tracker.initialState).toBe(original)
|
||||
})
|
||||
})
|
||||
|
||||
describe('restore', () => {
|
||||
it('restores viewport, outputs, and root graph navigation', () => {
|
||||
const tracker = createTracker()
|
||||
app.canvas.ds.scale = 2
|
||||
app.canvas.ds.offset = [10, 20]
|
||||
mockNodeOutputStore.snapshotOutputs.mockReturnValue({ 1: { images: [] } })
|
||||
mockSubgraphNavigationStore.exportState.mockReturnValue([])
|
||||
|
||||
tracker.store()
|
||||
app.canvas.ds.scale = 1
|
||||
app.canvas.ds.offset = [0, 0]
|
||||
tracker.restore()
|
||||
|
||||
expect(app.canvas.ds.scale).toBe(2)
|
||||
expect(app.canvas.ds.offset).toEqual([10, 20])
|
||||
expect(mockNodeOutputStore.restoreOutputs).toHaveBeenCalledWith({
|
||||
1: { images: [] }
|
||||
})
|
||||
expect(mockSubgraphNavigationStore.restoreState).toHaveBeenCalledWith([])
|
||||
expect(app.canvas.setGraph).toHaveBeenCalledWith(app.rootGraph)
|
||||
})
|
||||
|
||||
it('restores saved subgraph navigation when the subgraph exists', () => {
|
||||
const tracker = createTracker()
|
||||
const subgraph = { id: 'subgraph-1' } as unknown as Subgraph
|
||||
app.rootGraph.subgraphs.set('subgraph-1', subgraph)
|
||||
mockSubgraphNavigationStore.exportState.mockReturnValue(['subgraph-1'])
|
||||
|
||||
tracker.store()
|
||||
tracker.restore()
|
||||
|
||||
expect(app.canvas.setGraph).toHaveBeenCalledWith(subgraph)
|
||||
})
|
||||
})
|
||||
|
||||
describe('captureCanvasState', () => {
|
||||
@@ -169,9 +302,32 @@ describe('ChangeTracker', () => {
|
||||
expect.stringContaining('captureCanvasState')
|
||||
)
|
||||
})
|
||||
|
||||
it('reports inactive tracker calls only once for the same workflow', () => {
|
||||
const tracker = createTracker()
|
||||
tracker.workflow.path = '/test/dedupe-workflow.json'
|
||||
mockWorkflowStore.activeWorkflow = { changeTracker: {} }
|
||||
|
||||
tracker.captureCanvasState()
|
||||
tracker.captureCanvasState()
|
||||
|
||||
expect(mockAssert).toHaveBeenCalledOnce()
|
||||
})
|
||||
})
|
||||
|
||||
describe('state capture', () => {
|
||||
it('sets the active state without pushing undo when none exists yet', () => {
|
||||
const tracker = createTracker(createState(1))
|
||||
const changed = createState(2)
|
||||
tracker.activeState = undefined as never
|
||||
mockCanvasState(changed)
|
||||
|
||||
tracker.captureCanvasState()
|
||||
|
||||
expect(tracker.activeState).toEqual(changed)
|
||||
expect(tracker.undoQueue).toHaveLength(0)
|
||||
})
|
||||
|
||||
it('pushes to undoQueue, updates activeState, and calls updateModified', () => {
|
||||
const initial = createState(1)
|
||||
const tracker = createTracker(initial)
|
||||
@@ -238,6 +394,19 @@ describe('ChangeTracker', () => {
|
||||
|
||||
expect(tracker.undoQueue).toHaveLength(ChangeTracker.MAX_HISTORY)
|
||||
})
|
||||
|
||||
it('does not capture until the outer change transaction finishes', () => {
|
||||
const tracker = createTracker(createState(1))
|
||||
tracker.beforeChange()
|
||||
tracker.beforeChange()
|
||||
mockCanvasState(createState(2))
|
||||
|
||||
tracker.afterChange()
|
||||
expect(app.rootGraph.serialize).not.toHaveBeenCalled()
|
||||
|
||||
tracker.afterChange()
|
||||
expect(app.rootGraph.serialize).toHaveBeenCalledOnce()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -302,6 +471,105 @@ describe('ChangeTracker', () => {
|
||||
})
|
||||
})
|
||||
|
||||
describe('updateModified', () => {
|
||||
it('updates workflow modified state when the store can find it', () => {
|
||||
const state = createState(1)
|
||||
const tracker = createTracker(state)
|
||||
const workflow = { isModified: true }
|
||||
mockWorkflowStore.getWorkflowByPath.mockReturnValue(workflow)
|
||||
|
||||
tracker.updateModified()
|
||||
expect(workflow.isModified).toBe(false)
|
||||
|
||||
tracker.activeState = createState(2)
|
||||
tracker.updateModified()
|
||||
expect(workflow.isModified).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
describe('undo and redo', () => {
|
||||
it('restores previous state and moves the current state to the target queue', async () => {
|
||||
const initial = createState(1)
|
||||
const changed = createState(2)
|
||||
const tracker = createTracker(changed)
|
||||
tracker.undoQueue.push(initial)
|
||||
|
||||
await tracker.undo()
|
||||
|
||||
expect(app.loadGraphData).toHaveBeenCalledWith(
|
||||
initial,
|
||||
false,
|
||||
false,
|
||||
tracker.workflow,
|
||||
{
|
||||
checkForRerouteMigration: false,
|
||||
silentAssetErrors: true
|
||||
}
|
||||
)
|
||||
expect(tracker.activeState).toBe(initial)
|
||||
expect(tracker.redoQueue).toEqual([changed])
|
||||
expect(tracker._restoringState).toBe(false)
|
||||
})
|
||||
|
||||
it('clears restoring state when loading fails', async () => {
|
||||
const tracker = createTracker(createState(2))
|
||||
tracker.undoQueue.push(createState(1))
|
||||
vi.mocked(app.loadGraphData).mockRejectedValueOnce(
|
||||
new Error('load failed')
|
||||
)
|
||||
|
||||
await expect(tracker.undo()).rejects.toThrow('load failed')
|
||||
|
||||
expect(tracker._restoringState).toBe(false)
|
||||
})
|
||||
|
||||
it('does nothing when no previous state exists', async () => {
|
||||
const tracker = createTracker(createState(1))
|
||||
|
||||
await tracker.undo()
|
||||
|
||||
expect(app.loadGraphData).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('handles keyboard undo and redo shortcuts', async () => {
|
||||
const tracker = createTracker()
|
||||
const undo = vi.spyOn(tracker, 'undo').mockResolvedValue()
|
||||
const redo = vi.spyOn(tracker, 'redo').mockResolvedValue()
|
||||
|
||||
await expect(
|
||||
tracker.undoRedo(
|
||||
new KeyboardEvent('keydown', { key: 'z', ctrlKey: true })
|
||||
)
|
||||
).resolves.toBe(true)
|
||||
await expect(
|
||||
tracker.undoRedo(
|
||||
new KeyboardEvent('keydown', {
|
||||
key: 'z',
|
||||
ctrlKey: true,
|
||||
shiftKey: true
|
||||
})
|
||||
)
|
||||
).resolves.toBe(true)
|
||||
await expect(
|
||||
tracker.undoRedo(
|
||||
new KeyboardEvent('keydown', { key: 'y', metaKey: true })
|
||||
)
|
||||
).resolves.toBe(true)
|
||||
await expect(
|
||||
tracker.undoRedo(
|
||||
new KeyboardEvent('keydown', {
|
||||
key: 'z',
|
||||
ctrlKey: true,
|
||||
altKey: true
|
||||
})
|
||||
)
|
||||
).resolves.toBeUndefined()
|
||||
|
||||
expect(undo).toHaveBeenCalledOnce()
|
||||
expect(redo).toHaveBeenCalledTimes(2)
|
||||
})
|
||||
})
|
||||
|
||||
describe('checkState (deprecated)', () => {
|
||||
it('delegates to captureCanvasState', () => {
|
||||
const tracker = createTracker(createState(1))
|
||||
@@ -312,5 +580,389 @@ describe('ChangeTracker', () => {
|
||||
|
||||
expect(tracker.activeState).toEqual(changed)
|
||||
})
|
||||
|
||||
it('warns only once before delegating', () => {
|
||||
const tracker = createTracker(createState(1))
|
||||
const warn = vi.spyOn(console, 'warn').mockImplementation(() => undefined)
|
||||
|
||||
tracker.checkState()
|
||||
tracker.checkState()
|
||||
|
||||
expect(warn).toHaveBeenCalledOnce()
|
||||
warn.mockRestore()
|
||||
})
|
||||
})
|
||||
|
||||
describe('bindInput', () => {
|
||||
it('returns false for missing canvas or body elements', () => {
|
||||
expect(ChangeTracker.bindInput(null)).toBe(false)
|
||||
expect(ChangeTracker.bindInput(document.createElement('canvas'))).toBe(
|
||||
false
|
||||
)
|
||||
expect(ChangeTracker.bindInput(document.body)).toBe(false)
|
||||
})
|
||||
|
||||
it('captures state once when an input-like element changes', () => {
|
||||
const tracker = createTracker()
|
||||
const capture = vi.spyOn(tracker, 'captureCanvasState')
|
||||
const input = document.createElement('input')
|
||||
|
||||
expect(ChangeTracker.bindInput(input)).toBe(true)
|
||||
input.dispatchEvent(new Event('change'))
|
||||
input.dispatchEvent(new Event('change'))
|
||||
|
||||
expect(capture).toHaveBeenCalledOnce()
|
||||
})
|
||||
|
||||
it('binds textarea-like elements that expose an input handler slot', () => {
|
||||
const tracker = createTracker()
|
||||
const capture = vi.spyOn(tracker, 'captureCanvasState')
|
||||
const element = document.createElement('div') as HTMLElement & {
|
||||
oninput: unknown
|
||||
}
|
||||
element.oninput = null
|
||||
|
||||
expect(ChangeTracker.bindInput(element)).toBe(true)
|
||||
element.dispatchEvent(new Event('change'))
|
||||
|
||||
expect(capture).toHaveBeenCalledOnce()
|
||||
})
|
||||
})
|
||||
|
||||
describe('init', () => {
|
||||
it('captures changes from registered browser, graph, and API events', async () => {
|
||||
const windowListeners: ListenerMap = {}
|
||||
const documentListeners: ListenerMap = {}
|
||||
const windowAddSpy = vi
|
||||
.spyOn(window, 'addEventListener')
|
||||
.mockImplementation((type, listener) => {
|
||||
storeListener(windowListeners, type, listener)
|
||||
})
|
||||
const documentAddSpy = vi
|
||||
.spyOn(document, 'addEventListener')
|
||||
.mockImplementation((type, listener) => {
|
||||
storeListener(documentListeners, type, listener)
|
||||
})
|
||||
const rafSpy = vi
|
||||
.spyOn(window, 'requestAnimationFrame')
|
||||
.mockImplementation((callback) => {
|
||||
callback(0)
|
||||
return 1
|
||||
})
|
||||
|
||||
const processMouseUp = vi.fn(() => true)
|
||||
const prompt = vi.fn()
|
||||
const close = vi.fn(() => true)
|
||||
const originalProcessMouseUp = LGraphCanvas.prototype.processMouseUp
|
||||
const originalPrompt = LGraphCanvas.prototype.prompt
|
||||
const originalClose = LiteGraph.ContextMenu.prototype.close
|
||||
|
||||
LGraphCanvas.prototype.processMouseUp = processMouseUp
|
||||
LGraphCanvas.prototype.prompt = prompt
|
||||
LiteGraph.ContextMenu.prototype.close = close
|
||||
|
||||
try {
|
||||
ChangeTracker.init()
|
||||
const tracker = createTracker()
|
||||
const capture = vi.spyOn(tracker, 'captureCanvasState')
|
||||
|
||||
dispatchStored(windowListeners, 'mouseup', new MouseEvent('mouseup'))
|
||||
getApiListener('promptQueued')(
|
||||
new CustomEvent('promptQueued', {
|
||||
detail: {} as unknown as ExecutedWsMessage
|
||||
})
|
||||
)
|
||||
getApiListener('graphCleared')(
|
||||
new CustomEvent('graphCleared', {
|
||||
detail: {} as unknown as ExecutedWsMessage
|
||||
})
|
||||
)
|
||||
dispatchStored(
|
||||
documentListeners,
|
||||
'litegraph:canvas',
|
||||
new CustomEvent('litegraph:canvas', {
|
||||
detail: { subType: 'before-change' }
|
||||
})
|
||||
)
|
||||
dispatchStored(
|
||||
documentListeners,
|
||||
'litegraph:canvas',
|
||||
new CustomEvent('litegraph:canvas', {
|
||||
detail: { subType: 'after-change' }
|
||||
})
|
||||
)
|
||||
|
||||
expect(capture).toHaveBeenCalledTimes(4)
|
||||
|
||||
dispatchStored(
|
||||
windowListeners,
|
||||
'keydown',
|
||||
new KeyboardEvent('keydown', { key: 'Control' })
|
||||
)
|
||||
await flushAsyncFrame()
|
||||
dispatchStored(windowListeners, 'keyup', new KeyboardEvent('keyup'))
|
||||
|
||||
expect(capture).toHaveBeenCalledTimes(5)
|
||||
|
||||
const undoRedo = vi.spyOn(tracker, 'undoRedo').mockResolvedValue(true)
|
||||
dispatchStored(
|
||||
windowListeners,
|
||||
'keydown',
|
||||
new KeyboardEvent('keydown', { key: 'z', ctrlKey: true })
|
||||
)
|
||||
await flushAsyncFrame()
|
||||
|
||||
expect(undoRedo).toHaveBeenCalledOnce()
|
||||
expect(capture).toHaveBeenCalledTimes(5)
|
||||
|
||||
undoRedo.mockResolvedValue(undefined)
|
||||
dispatchStored(
|
||||
windowListeners,
|
||||
'keydown',
|
||||
new KeyboardEvent('keydown', { key: 'a' })
|
||||
)
|
||||
await flushAsyncFrame()
|
||||
|
||||
expect(capture).toHaveBeenCalledTimes(6)
|
||||
|
||||
const input = document.createElement('input')
|
||||
document.body.append(input)
|
||||
input.focus()
|
||||
dispatchStored(
|
||||
windowListeners,
|
||||
'keydown',
|
||||
new KeyboardEvent('keydown', { key: 'b' })
|
||||
)
|
||||
await flushAsyncFrame()
|
||||
input.remove()
|
||||
|
||||
expect(capture).toHaveBeenCalledTimes(6)
|
||||
|
||||
mockMaskEditorIsOpened.mockReturnValue(true)
|
||||
dispatchStored(
|
||||
windowListeners,
|
||||
'keydown',
|
||||
new KeyboardEvent('keydown', { key: 'c' })
|
||||
)
|
||||
await flushAsyncFrame()
|
||||
|
||||
expect(capture).toHaveBeenCalledTimes(6)
|
||||
|
||||
const canvas = {} as LGraphCanvas
|
||||
LGraphCanvas.prototype.processMouseUp.call(
|
||||
canvas,
|
||||
new MouseEvent('mouseup') as CanvasPointerEvent
|
||||
)
|
||||
|
||||
expect(processMouseUp).toHaveBeenCalledOnce()
|
||||
expect(capture).toHaveBeenCalledTimes(7)
|
||||
|
||||
const promptCallback = vi.fn()
|
||||
LGraphCanvas.prototype.prompt.call(
|
||||
canvas,
|
||||
'title',
|
||||
'value',
|
||||
promptCallback,
|
||||
new MouseEvent('mouseup') as CanvasPointerEvent
|
||||
)
|
||||
const extendedCallback = prompt.mock.calls[0]?.[2] as
|
||||
| ((value: string) => void)
|
||||
| undefined
|
||||
extendedCallback?.('updated')
|
||||
|
||||
expect(promptCallback).toHaveBeenCalledWith('updated')
|
||||
expect(capture).toHaveBeenCalledTimes(8)
|
||||
|
||||
LiteGraph.ContextMenu.prototype.close.call(
|
||||
{} as InstanceType<typeof LiteGraph.ContextMenu>,
|
||||
new MouseEvent('mouseup')
|
||||
)
|
||||
|
||||
expect(close).toHaveBeenCalledOnce()
|
||||
expect(capture).toHaveBeenCalledTimes(9)
|
||||
} finally {
|
||||
LGraphCanvas.prototype.processMouseUp = originalProcessMouseUp
|
||||
LGraphCanvas.prototype.prompt = originalPrompt
|
||||
LiteGraph.ContextMenu.prototype.close = originalClose
|
||||
windowAddSpy.mockRestore()
|
||||
documentAddSpy.mockRestore()
|
||||
rafSpy.mockRestore()
|
||||
}
|
||||
})
|
||||
|
||||
it('ignores repeat keydowns and missing active trackers', async () => {
|
||||
const windowListeners: ListenerMap = {}
|
||||
const windowAddSpy = vi
|
||||
.spyOn(window, 'addEventListener')
|
||||
.mockImplementation((type, listener) => {
|
||||
storeListener(windowListeners, type, listener)
|
||||
})
|
||||
const documentAddSpy = vi
|
||||
.spyOn(document, 'addEventListener')
|
||||
.mockImplementation(() => undefined)
|
||||
const rafSpy = vi
|
||||
.spyOn(window, 'requestAnimationFrame')
|
||||
.mockImplementation((callback) => {
|
||||
callback(0)
|
||||
return 1
|
||||
})
|
||||
|
||||
try {
|
||||
ChangeTracker.init()
|
||||
const tracker = createTracker()
|
||||
const capture = vi.spyOn(tracker, 'captureCanvasState')
|
||||
|
||||
dispatchStored(
|
||||
windowListeners,
|
||||
'keydown',
|
||||
new KeyboardEvent('keydown', { key: 'x', repeat: true })
|
||||
)
|
||||
await flushAsyncFrame()
|
||||
expect(capture).not.toHaveBeenCalled()
|
||||
|
||||
mockWorkflowStore.activeWorkflow = null
|
||||
dispatchStored(
|
||||
windowListeners,
|
||||
'keydown',
|
||||
new KeyboardEvent('keydown', { key: 'x' })
|
||||
)
|
||||
await flushAsyncFrame()
|
||||
expect(capture).not.toHaveBeenCalled()
|
||||
} finally {
|
||||
windowAddSpy.mockRestore()
|
||||
documentAddSpy.mockRestore()
|
||||
rafSpy.mockRestore()
|
||||
}
|
||||
})
|
||||
|
||||
it('stores executed outputs for the workflow that owns the prompt', () => {
|
||||
ChangeTracker.init()
|
||||
const tracker = createTracker()
|
||||
const executed = getApiListener('executed')
|
||||
mockExecutionStore.queuedJobs = {
|
||||
promptA: { workflow: { changeTracker: tracker } }
|
||||
}
|
||||
|
||||
executed(
|
||||
new CustomEvent('executed', {
|
||||
detail: {
|
||||
prompt_id: 'promptA',
|
||||
node: '1',
|
||||
output: { images: ['first'] }
|
||||
} as unknown as ExecutedWsMessage
|
||||
})
|
||||
)
|
||||
executed(
|
||||
new CustomEvent('executed', {
|
||||
detail: {
|
||||
prompt_id: 'promptA',
|
||||
node: '1',
|
||||
merge: true,
|
||||
output: { images: ['second'], text: ['caption'] }
|
||||
} as unknown as ExecutedWsMessage
|
||||
})
|
||||
)
|
||||
executed(
|
||||
new CustomEvent('executed', {
|
||||
detail: {
|
||||
prompt_id: 'missing',
|
||||
node: '2',
|
||||
output: { images: ['ignored'] }
|
||||
} as unknown as ExecutedWsMessage
|
||||
})
|
||||
)
|
||||
|
||||
expect(tracker.nodeOutputs).toEqual({
|
||||
1: { images: ['first', 'second'], text: ['caption'] }
|
||||
})
|
||||
})
|
||||
|
||||
it('replaces non-array executed outputs during merge updates', () => {
|
||||
ChangeTracker.init()
|
||||
const tracker = createTracker()
|
||||
const executed = getApiListener('executed')
|
||||
mockExecutionStore.queuedJobs = {
|
||||
promptA: { workflow: { changeTracker: tracker } }
|
||||
}
|
||||
|
||||
executed(
|
||||
new CustomEvent('executed', {
|
||||
detail: {
|
||||
prompt_id: 'promptA',
|
||||
node: '1',
|
||||
output: { value: 'old' }
|
||||
} as unknown as ExecutedWsMessage
|
||||
})
|
||||
)
|
||||
executed(
|
||||
new CustomEvent('executed', {
|
||||
detail: {
|
||||
prompt_id: 'promptA',
|
||||
node: '1',
|
||||
merge: true,
|
||||
output: { value: 'new' }
|
||||
} as unknown as ExecutedWsMessage
|
||||
})
|
||||
)
|
||||
|
||||
expect(tracker.nodeOutputs).toEqual({
|
||||
1: { value: 'new' }
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('graphEqual', () => {
|
||||
it('compares workflow nodes as an unordered set and ignores extra.ds', () => {
|
||||
const first = createState(2)
|
||||
const second = {
|
||||
...createState(),
|
||||
nodes: [...first.nodes].reverse(),
|
||||
links: first.links,
|
||||
groups: first.groups,
|
||||
extra: { ds: { scale: 2 } }
|
||||
} as unknown as ComfyWorkflowJSON
|
||||
|
||||
expect(ChangeTracker.graphEqual(first, first)).toBe(true)
|
||||
expect(ChangeTracker.graphEqual(first, second)).toBe(true)
|
||||
})
|
||||
|
||||
it('returns false for non-object values and meaningful graph differences', () => {
|
||||
const first = createState(1)
|
||||
const differentNodes = createState(2)
|
||||
const differentLinks = {
|
||||
...first,
|
||||
links: [[1, 1, 0, 2, 0, 'MODEL']]
|
||||
} as unknown as ComfyWorkflowJSON
|
||||
|
||||
expect(ChangeTracker.graphEqual(first, null as never)).toBe(false)
|
||||
expect(ChangeTracker.graphEqual(first, differentNodes)).toBe(false)
|
||||
expect(ChangeTracker.graphEqual(first, differentLinks)).toBe(false)
|
||||
})
|
||||
|
||||
it('returns false for extra properties other than viewport state', () => {
|
||||
const first = createState()
|
||||
const second = {
|
||||
...first,
|
||||
extra: { custom: true }
|
||||
} as unknown as ComfyWorkflowJSON
|
||||
|
||||
expect(ChangeTracker.graphEqual(first, second)).toBe(false)
|
||||
})
|
||||
|
||||
it.each([
|
||||
'floatingLinks',
|
||||
'reroutes',
|
||||
'groups',
|
||||
'definitions',
|
||||
'subgraphs'
|
||||
] as const)('returns false when %s differs', (key) => {
|
||||
const first = createState()
|
||||
const second = {
|
||||
...first,
|
||||
[key]: [{ id: 1 }]
|
||||
} as unknown as ComfyWorkflowJSON
|
||||
|
||||
expect(ChangeTracker.graphEqual(first, second)).toBe(false)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,11 +1,24 @@
|
||||
import { describe, expect, test, vi } from 'vitest'
|
||||
|
||||
import type { LGraph } from '@/lib/litegraph/src/litegraph'
|
||||
import { LGraphNode } from '@/lib/litegraph/src/litegraph'
|
||||
import { ComponentWidgetImpl, DOMWidgetImpl } from '@/scripts/domWidget'
|
||||
import {
|
||||
addWidget,
|
||||
ComponentWidgetImpl,
|
||||
DOMWidgetImpl,
|
||||
isComponentWidget,
|
||||
isDOMWidget
|
||||
} from '@/scripts/domWidget'
|
||||
|
||||
const { registerWidget, unregisterWidget } = vi.hoisted(() => ({
|
||||
registerWidget: vi.fn(),
|
||||
unregisterWidget: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('@/stores/domWidgetStore', () => ({
|
||||
useDomWidgetStore: () => ({
|
||||
unregisterWidget: vi.fn()
|
||||
registerWidget,
|
||||
unregisterWidget
|
||||
})
|
||||
}))
|
||||
|
||||
@@ -24,13 +37,11 @@ describe('DOMWidget Y Position Preservation', () => {
|
||||
options: {}
|
||||
})
|
||||
|
||||
// Set a specific Y position
|
||||
originalWidget.y = 66
|
||||
|
||||
const newNode = new LGraphNode('new-node')
|
||||
const clonedWidget = originalWidget.createCopyForNode(newNode)
|
||||
|
||||
// Verify Y position is preserved
|
||||
expect(clonedWidget.y).toBe(66)
|
||||
expect(clonedWidget.node).toBe(newNode)
|
||||
expect(clonedWidget.name).toBe('test-widget')
|
||||
@@ -48,13 +59,11 @@ describe('DOMWidget Y Position Preservation', () => {
|
||||
options: {}
|
||||
})
|
||||
|
||||
// Set a specific Y position
|
||||
originalWidget.y = 42
|
||||
|
||||
const newNode = new LGraphNode('new-node')
|
||||
const clonedWidget = originalWidget.createCopyForNode(newNode)
|
||||
|
||||
// Verify Y position is preserved
|
||||
expect(clonedWidget.y).toBe(42)
|
||||
expect(clonedWidget.node).toBe(newNode)
|
||||
expect(clonedWidget.element).toBe(mockElement)
|
||||
@@ -71,11 +80,9 @@ describe('DOMWidget Y Position Preservation', () => {
|
||||
options: {}
|
||||
})
|
||||
|
||||
// Don't explicitly set Y (should be 0 by default)
|
||||
const newNode = new LGraphNode('new-node')
|
||||
const clonedWidget = originalWidget.createCopyForNode(newNode)
|
||||
|
||||
// Verify Y position is preserved (should be 0)
|
||||
expect(clonedWidget.y).toBe(0)
|
||||
})
|
||||
})
|
||||
@@ -96,3 +103,271 @@ describe('BaseDOMWidgetImpl.isVisible', () => {
|
||||
expect(widget.isVisible()).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('DOMWidgetImpl', () => {
|
||||
test('identifies DOM and component widgets', () => {
|
||||
const node = new LGraphNode('test-node')
|
||||
const domWidget = new DOMWidgetImpl({
|
||||
node,
|
||||
name: 'dom',
|
||||
type: 'text',
|
||||
element: document.createElement('textarea'),
|
||||
options: {}
|
||||
})
|
||||
const componentWidget = new ComponentWidgetImpl({
|
||||
node,
|
||||
name: 'component',
|
||||
component: { template: '<div />' },
|
||||
inputSpec: { name: 'component', type: 'STRING' },
|
||||
options: {}
|
||||
})
|
||||
|
||||
expect(isDOMWidget(domWidget)).toBe(true)
|
||||
expect(isDOMWidget(componentWidget)).toBe(false)
|
||||
expect(isComponentWidget(componentWidget)).toBe(true)
|
||||
expect(isComponentWidget(domWidget)).toBe(false)
|
||||
})
|
||||
|
||||
test('uses option-backed values, callbacks, and margins', () => {
|
||||
const node = new LGraphNode('test-node')
|
||||
let value = 'initial'
|
||||
const setValue = vi.fn((next: string) => {
|
||||
value = next
|
||||
})
|
||||
const callback = vi.fn()
|
||||
const widget = new DOMWidgetImpl({
|
||||
node,
|
||||
name: 'text',
|
||||
type: 'text',
|
||||
element: document.createElement('textarea'),
|
||||
options: {
|
||||
getValue: () => value,
|
||||
setValue,
|
||||
margin: 4
|
||||
}
|
||||
})
|
||||
widget.callback = callback
|
||||
|
||||
widget.value = 'next'
|
||||
|
||||
expect(widget.value).toBe('next')
|
||||
expect(widget.margin).toBe(4)
|
||||
expect(setValue).toHaveBeenCalledWith('next')
|
||||
expect(callback).toHaveBeenCalledWith('next')
|
||||
})
|
||||
|
||||
test('uses default value and margin when options do not provide them', () => {
|
||||
const node = new LGraphNode('test-node')
|
||||
const widget = new DOMWidgetImpl({
|
||||
node,
|
||||
name: 'text',
|
||||
type: 'text',
|
||||
element: document.createElement('textarea'),
|
||||
options: {}
|
||||
})
|
||||
|
||||
expect(widget.value).toBe('')
|
||||
expect(widget.margin).toBe(10)
|
||||
})
|
||||
|
||||
test('draws zoom placeholders and delegates visible draws', () => {
|
||||
const node = new LGraphNode('test-node')
|
||||
vi.spyOn(node, 'isWidgetVisible').mockReturnValue(true)
|
||||
const onDraw = vi.fn()
|
||||
const widget = new DOMWidgetImpl({
|
||||
node,
|
||||
name: 'text',
|
||||
type: 'text',
|
||||
element: document.createElement('textarea'),
|
||||
options: {
|
||||
hideOnZoom: true,
|
||||
margin: 5,
|
||||
onDraw
|
||||
}
|
||||
})
|
||||
const ctx = {
|
||||
beginPath: vi.fn(),
|
||||
fill: vi.fn(),
|
||||
fillStyle: '#000',
|
||||
rect: vi.fn()
|
||||
} as unknown as CanvasRenderingContext2D
|
||||
|
||||
widget.draw(ctx, node, 100, 10, 40, true)
|
||||
|
||||
expect(ctx.rect).toHaveBeenCalledWith(5, 15, 90, 30)
|
||||
expect(ctx.fill).toHaveBeenCalledOnce()
|
||||
expect(ctx.fillStyle).toBe('#000')
|
||||
expect(onDraw).toHaveBeenCalledWith(widget)
|
||||
})
|
||||
|
||||
test('skips placeholder drawing when hidden', () => {
|
||||
const node = new LGraphNode('test-node')
|
||||
vi.spyOn(node, 'isWidgetVisible').mockReturnValue(false)
|
||||
const onDraw = vi.fn()
|
||||
const widget = new DOMWidgetImpl({
|
||||
node,
|
||||
name: 'text',
|
||||
type: 'text',
|
||||
element: document.createElement('textarea'),
|
||||
options: {
|
||||
hideOnZoom: true,
|
||||
onDraw
|
||||
}
|
||||
})
|
||||
const ctx = {
|
||||
beginPath: vi.fn(),
|
||||
fill: vi.fn(),
|
||||
fillStyle: '#000',
|
||||
rect: vi.fn()
|
||||
} as unknown as CanvasRenderingContext2D
|
||||
|
||||
widget.draw(ctx, node, 100, 10, 40, true)
|
||||
|
||||
expect(ctx.rect).not.toHaveBeenCalled()
|
||||
expect(onDraw).toHaveBeenCalledWith(widget)
|
||||
})
|
||||
|
||||
test('computes hidden, option, percent, and fallback layout sizes', () => {
|
||||
const node = new LGraphNode('test-node')
|
||||
node.size = [100, 200]
|
||||
const hiddenWidget = new DOMWidgetImpl({
|
||||
node,
|
||||
name: 'hidden',
|
||||
type: 'hidden',
|
||||
element: document.createElement('textarea'),
|
||||
options: {}
|
||||
})
|
||||
const optionWidget = new DOMWidgetImpl({
|
||||
node,
|
||||
name: 'option',
|
||||
type: 'text',
|
||||
element: document.createElement('textarea'),
|
||||
options: {
|
||||
getMinHeight: () => 11,
|
||||
getMaxHeight: () => 88,
|
||||
getHeight: () => 44
|
||||
}
|
||||
})
|
||||
const percentWidget = new DOMWidgetImpl({
|
||||
node,
|
||||
name: 'percent',
|
||||
type: 'text',
|
||||
element: document.createElement('textarea'),
|
||||
options: {
|
||||
getMinHeight: () => 10,
|
||||
getMaxHeight: () => 60,
|
||||
getHeight: () => '25%'
|
||||
}
|
||||
})
|
||||
const fallbackWidget = new DOMWidgetImpl({
|
||||
node,
|
||||
name: 'fallback',
|
||||
type: 'text',
|
||||
element: document.createElement('textarea'),
|
||||
options: {
|
||||
getHeight: () => 40
|
||||
}
|
||||
})
|
||||
|
||||
expect(hiddenWidget.computeLayoutSize(node)).toEqual({
|
||||
minHeight: 0,
|
||||
maxHeight: 0,
|
||||
minWidth: 0
|
||||
})
|
||||
expect(optionWidget.computeLayoutSize(node)).toEqual({
|
||||
minHeight: 11,
|
||||
maxHeight: 88,
|
||||
minWidth: 0
|
||||
})
|
||||
expect(percentWidget.computeLayoutSize(node)).toEqual({
|
||||
minHeight: 10,
|
||||
maxHeight: 60,
|
||||
minWidth: 0
|
||||
})
|
||||
expect(fallbackWidget.computeLayoutSize(node)).toEqual({
|
||||
minHeight: 40,
|
||||
maxHeight: undefined,
|
||||
minWidth: 0
|
||||
})
|
||||
})
|
||||
|
||||
test('registers widgets immediately and through node lifecycle callbacks', () => {
|
||||
registerWidget.mockClear()
|
||||
unregisterWidget.mockClear()
|
||||
const node = new LGraphNode('test-node')
|
||||
node.graph = {} as LGraph
|
||||
const beforeResize = vi.fn()
|
||||
const afterResize = vi.fn()
|
||||
const widget = new DOMWidgetImpl({
|
||||
node,
|
||||
name: 'text',
|
||||
type: 'text',
|
||||
element: document.createElement('textarea'),
|
||||
options: {
|
||||
beforeResize,
|
||||
afterResize
|
||||
}
|
||||
})
|
||||
vi.spyOn(node, 'addCustomWidget')
|
||||
|
||||
addWidget(node, widget)
|
||||
node.onAdded?.(node.graph)
|
||||
node.onResize?.([0, 0])
|
||||
node.onRemoved?.()
|
||||
|
||||
expect(node.addCustomWidget).toHaveBeenCalledWith(widget)
|
||||
expect(registerWidget).toHaveBeenCalledWith(widget)
|
||||
expect(registerWidget).toHaveBeenCalledTimes(2)
|
||||
expect(beforeResize).toHaveBeenCalledWith(node)
|
||||
expect(afterResize).toHaveBeenCalledWith(node)
|
||||
expect(unregisterWidget).toHaveBeenCalledWith(widget.id)
|
||||
})
|
||||
|
||||
test('computes component layout and serializes raw values', () => {
|
||||
const node = new LGraphNode('test-node')
|
||||
const value = { nested: true }
|
||||
const widget = new ComponentWidgetImpl({
|
||||
node,
|
||||
name: 'component',
|
||||
component: { template: '<div />' },
|
||||
inputSpec: { name: 'component', type: 'STRING' },
|
||||
options: {
|
||||
getValue: () => value,
|
||||
getMinHeight: () => 12,
|
||||
getMaxHeight: () => 48
|
||||
}
|
||||
})
|
||||
|
||||
expect(widget.computeLayoutSize()).toEqual({
|
||||
minHeight: 12,
|
||||
maxHeight: 48,
|
||||
minWidth: 0
|
||||
})
|
||||
expect(widget.serializeValue()).toEqual({ nested: true })
|
||||
})
|
||||
|
||||
test('adds DOM widgets through LGraphNode prototype helper', () => {
|
||||
const node = new LGraphNode('test-node')
|
||||
const element = document.createElement('textarea')
|
||||
let value = 'initial'
|
||||
const setValue = vi.fn((next: string) => {
|
||||
value = next
|
||||
})
|
||||
vi.spyOn(node, 'addCustomWidget')
|
||||
|
||||
const widget = node.addDOMWidget('text', 'textarea', element, {
|
||||
getValue: () => value,
|
||||
setValue
|
||||
})
|
||||
const callback = vi.fn()
|
||||
widget.callback = callback
|
||||
widget.value = 'next'
|
||||
|
||||
expect(node.addCustomWidget).toHaveBeenCalledWith(widget)
|
||||
expect(widget.element).toBe(element)
|
||||
expect(widget.options.hideOnZoom).toBe(true)
|
||||
expect(widget.value).toBe('next')
|
||||
expect(setValue).toHaveBeenCalledWith('next')
|
||||
expect(callback).toHaveBeenCalledWith('next')
|
||||
})
|
||||
})
|
||||
|
||||
106
src/scripts/errorNodeWidgets.test.ts
Normal file
106
src/scripts/errorNodeWidgets.test.ts
Normal file
@@ -0,0 +1,106 @@
|
||||
import { fromAny, fromPartial } from '@total-typescript/shoehorn'
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { LGraphNode } from '@/lib/litegraph/src/litegraph'
|
||||
import type { IBaseWidget } from '@/lib/litegraph/src/types/widgets'
|
||||
|
||||
interface TestWidgetOptions {
|
||||
name: string
|
||||
type: string
|
||||
default?: unknown
|
||||
multiline?: boolean
|
||||
}
|
||||
|
||||
function createWidgetFactory() {
|
||||
return (node: LGraphNode, options: TestWidgetOptions): IBaseWidget => {
|
||||
const widget: IBaseWidget = fromAny({
|
||||
name: options.name,
|
||||
type: options.type,
|
||||
value: options.default,
|
||||
options
|
||||
})
|
||||
node.widgets = [...(node.widgets ?? []), widget]
|
||||
return widget
|
||||
}
|
||||
}
|
||||
|
||||
vi.mock(
|
||||
'@/renderer/extensions/vueNodes/widgets/composables/useStringWidget',
|
||||
() => ({
|
||||
useStringWidget: () => createWidgetFactory()
|
||||
})
|
||||
)
|
||||
|
||||
vi.mock(
|
||||
'@/renderer/extensions/vueNodes/widgets/composables/useFloatWidget',
|
||||
() => ({
|
||||
useFloatWidget: () => createWidgetFactory()
|
||||
})
|
||||
)
|
||||
|
||||
vi.mock(
|
||||
'@/renderer/extensions/vueNodes/widgets/composables/useBooleanWidget',
|
||||
() => ({
|
||||
useBooleanWidget: () => createWidgetFactory()
|
||||
})
|
||||
)
|
||||
|
||||
import './errorNodeWidgets'
|
||||
|
||||
describe('errorNodeWidgets', () => {
|
||||
it('restores widgets from serialized values on error nodes', () => {
|
||||
const node = new LGraphNode('BrokenNode')
|
||||
const longText = 'serialized value with more than twenty chars'
|
||||
node.has_errors = true
|
||||
|
||||
node.onConfigure?.(
|
||||
fromAny({
|
||||
widgets_values: {
|
||||
length: 5,
|
||||
0: 'short text',
|
||||
1: longText,
|
||||
2: 12,
|
||||
3: true,
|
||||
4: { nested: 'value' }
|
||||
}
|
||||
})
|
||||
)
|
||||
|
||||
expect(node.widgets).toHaveLength(5)
|
||||
expect(node.widgets?.map((widget) => widget.name)).toEqual([
|
||||
'UNKNOWN',
|
||||
'UNKNOWN_1',
|
||||
'UNKNOWN_2',
|
||||
'UNKNOWN_3',
|
||||
'UNKNOWN_4'
|
||||
])
|
||||
expect(node.widgets?.map((widget) => widget.label)).toEqual([
|
||||
'UNKNOWN',
|
||||
'UNKNOWN',
|
||||
'UNKNOWN',
|
||||
'UNKNOWN',
|
||||
'UNKNOWN'
|
||||
])
|
||||
expect(node.widgets?.map((widget) => widget.value)).toEqual([
|
||||
'short text',
|
||||
longText,
|
||||
12,
|
||||
true,
|
||||
'{"nested":"value"}'
|
||||
])
|
||||
expect(node.serialize_widgets).toBe(true)
|
||||
})
|
||||
|
||||
it('leaves normal nodes unchanged', () => {
|
||||
const node = new LGraphNode('HealthyNode')
|
||||
|
||||
node.onConfigure?.(
|
||||
fromPartial({
|
||||
widgets_values: ['ignored']
|
||||
})
|
||||
)
|
||||
|
||||
expect(node.widgets).toBeUndefined()
|
||||
expect(node.serialize_widgets).toBeUndefined()
|
||||
})
|
||||
})
|
||||
@@ -7,7 +7,8 @@ import {
|
||||
EXPECTED_PROMPT_NAN_COERCED,
|
||||
EXPECTED_WORKFLOW,
|
||||
mockFileReaderAbort,
|
||||
mockFileReaderError
|
||||
mockFileReaderError,
|
||||
mockFileReaderResult
|
||||
} from './__fixtures__/helpers'
|
||||
import { getFromAvifFile } from './avif'
|
||||
|
||||
@@ -83,6 +84,11 @@ describe('AVIF metadata', () => {
|
||||
mockFileReaderAbort('readAsArrayBuffer')
|
||||
expect(await getFromAvifFile(file)).toEqual({})
|
||||
})
|
||||
|
||||
it('resolves empty when the FileReader load has no result', async () => {
|
||||
mockFileReaderResult('readAsArrayBuffer', null)
|
||||
expect(await getFromAvifFile(file)).toEqual({})
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -139,7 +145,8 @@ const buildInfeBox = (
|
||||
itemType: string,
|
||||
version = 2
|
||||
): Uint8Array => {
|
||||
const bodySize = 4 + 2 + 2 + 4 + 1 + 1
|
||||
const itemIdSize = version === 2 ? 2 : version >= 3 ? 4 : 0
|
||||
const bodySize = 4 + itemIdSize + (version >= 2 ? 2 + 4 + 1 + 1 : 0)
|
||||
const totalSize = 8 + bodySize
|
||||
const buf = new Uint8Array(totalSize)
|
||||
const dv = new DataView(buf.buffer)
|
||||
@@ -147,22 +154,36 @@ const buildInfeBox = (
|
||||
buf.set(new TextEncoder().encode('infe'), 4)
|
||||
buf[8] = version
|
||||
if (version >= 2) {
|
||||
setU16BE(dv, 12, itemId)
|
||||
setU16BE(dv, 14, 0)
|
||||
buf.set(new TextEncoder().encode(itemType.padEnd(4).slice(0, 4)), 16)
|
||||
let p = 12
|
||||
if (version === 2) {
|
||||
setU16BE(dv, p, itemId)
|
||||
p += 2
|
||||
} else {
|
||||
setU32BE(dv, p, itemId)
|
||||
p += 4
|
||||
}
|
||||
setU16BE(dv, p, 0)
|
||||
p += 2
|
||||
buf.set(new TextEncoder().encode(itemType.padEnd(4).slice(0, 4)), p)
|
||||
}
|
||||
return buf
|
||||
}
|
||||
|
||||
const buildIinfBox = (infeBoxes: Uint8Array[]): Uint8Array => {
|
||||
const bodySize = 4 + 2 + infeBoxes.reduce((s, b) => s + b.length, 0)
|
||||
const buildIinfBox = (infeBoxes: Uint8Array[], version = 0): Uint8Array => {
|
||||
const countSize = version === 0 ? 2 : 4
|
||||
const bodySize = 4 + countSize + infeBoxes.reduce((s, b) => s + b.length, 0)
|
||||
const totalSize = 8 + bodySize
|
||||
const buf = new Uint8Array(totalSize)
|
||||
const dv = new DataView(buf.buffer)
|
||||
setU32BE(dv, 0, totalSize)
|
||||
buf.set(new TextEncoder().encode('iinf'), 4)
|
||||
setU16BE(dv, 12, infeBoxes.length)
|
||||
let off = 14
|
||||
buf[8] = version
|
||||
if (version === 0) {
|
||||
setU16BE(dv, 12, infeBoxes.length)
|
||||
} else {
|
||||
setU32BE(dv, 12, infeBoxes.length)
|
||||
}
|
||||
let off = 8 + 4 + countSize
|
||||
for (const ib of infeBoxes) {
|
||||
buf.set(ib, off)
|
||||
off += ib.length
|
||||
@@ -170,31 +191,91 @@ const buildIinfBox = (infeBoxes: Uint8Array[]): Uint8Array => {
|
||||
return buf
|
||||
}
|
||||
|
||||
interface IlocItem {
|
||||
itemId: number
|
||||
extentOffset: number
|
||||
extentLength: number
|
||||
extents?: Array<{ extentOffset: number; extentLength: number }>
|
||||
}
|
||||
|
||||
const buildIlocBox = (
|
||||
items: { itemId: number; extentOffset: number; extentLength: number }[]
|
||||
items: IlocItem[],
|
||||
{
|
||||
version = 0,
|
||||
baseOffsetSize = 0,
|
||||
indexSize = 0
|
||||
}: { version?: number; baseOffsetSize?: number; indexSize?: number } = {}
|
||||
): Uint8Array => {
|
||||
const perItemSize = 2 + 2 + 0 + 2 + (4 + 4)
|
||||
const bodySize = 4 + 1 + 1 + 2 + items.length * perItemSize
|
||||
const itemCountSize = version < 2 ? 2 : 4
|
||||
const itemIdSize = version < 2 ? 2 : 4
|
||||
const constructionMethodSize = version === 1 || version === 2 ? 2 : 0
|
||||
const itemSizes = items.map((item) => {
|
||||
const extents = item.extents ?? [
|
||||
{
|
||||
extentOffset: item.extentOffset,
|
||||
extentLength: item.extentLength
|
||||
}
|
||||
]
|
||||
return (
|
||||
itemIdSize +
|
||||
constructionMethodSize +
|
||||
2 +
|
||||
baseOffsetSize +
|
||||
2 +
|
||||
extents.length * (indexSize + 4 + 4)
|
||||
)
|
||||
})
|
||||
const bodySize =
|
||||
4 + 1 + 1 + itemCountSize + itemSizes.reduce((sum, size) => sum + size, 0)
|
||||
const totalSize = 8 + bodySize
|
||||
const buf = new Uint8Array(totalSize)
|
||||
const dv = new DataView(buf.buffer)
|
||||
setU32BE(dv, 0, totalSize)
|
||||
buf.set(new TextEncoder().encode('iloc'), 4)
|
||||
buf[8] = version
|
||||
buf[12] = 0x44
|
||||
buf[13] = 0x00
|
||||
setU16BE(dv, 14, items.length)
|
||||
let p = 16
|
||||
for (const it of items) {
|
||||
setU16BE(dv, p, it.itemId)
|
||||
buf[13] = (baseOffsetSize << 4) | indexSize
|
||||
let p = 14
|
||||
if (version < 2) {
|
||||
setU16BE(dv, p, items.length)
|
||||
p += 2
|
||||
} else {
|
||||
setU32BE(dv, p, items.length)
|
||||
p += 4
|
||||
}
|
||||
for (const it of items) {
|
||||
if (version < 2) {
|
||||
setU16BE(dv, p, it.itemId)
|
||||
p += 2
|
||||
} else {
|
||||
setU32BE(dv, p, it.itemId)
|
||||
p += 4
|
||||
}
|
||||
if (version === 1 || version === 2) {
|
||||
setU16BE(dv, p, 0)
|
||||
p += 2
|
||||
}
|
||||
setU16BE(dv, p, 0)
|
||||
p += 2
|
||||
setU16BE(dv, p, 1)
|
||||
if (baseOffsetSize > 0) {
|
||||
setU32BE(dv, p, 0)
|
||||
p += baseOffsetSize
|
||||
}
|
||||
const extents = it.extents ?? [
|
||||
{
|
||||
extentOffset: it.extentOffset,
|
||||
extentLength: it.extentLength
|
||||
}
|
||||
]
|
||||
setU16BE(dv, p, extents.length)
|
||||
p += 2
|
||||
setU32BE(dv, p, it.extentOffset)
|
||||
p += 4
|
||||
setU32BE(dv, p, it.extentLength)
|
||||
p += 4
|
||||
for (const extent of extents) {
|
||||
p += indexSize
|
||||
setU32BE(dv, p, extent.extentOffset)
|
||||
p += 4
|
||||
setU32BE(dv, p, extent.extentLength)
|
||||
p += 4
|
||||
}
|
||||
}
|
||||
return buf
|
||||
}
|
||||
@@ -231,7 +312,13 @@ interface BuildAvifOpts {
|
||||
ftypBrand?: string
|
||||
omitMeta?: boolean
|
||||
omitIloc?: boolean
|
||||
iinfVersion?: number
|
||||
infeVersion?: number
|
||||
ilocVersion?: number
|
||||
ilocBaseOffsetSize?: number
|
||||
ilocIndexSize?: number
|
||||
ilocExtents?: Array<{ extentOffset: number; extentLength: number }>
|
||||
rawExifData?: Uint8Array
|
||||
}
|
||||
|
||||
const buildAvifFile = (opts: BuildAvifOpts = {}): ArrayBuffer => {
|
||||
@@ -242,7 +329,13 @@ const buildAvifFile = (opts: BuildAvifOpts = {}): ArrayBuffer => {
|
||||
ftypBrand = 'avif',
|
||||
omitMeta = false,
|
||||
omitIloc = false,
|
||||
infeVersion = 2
|
||||
iinfVersion = 0,
|
||||
infeVersion = 2,
|
||||
ilocVersion = 0,
|
||||
ilocBaseOffsetSize = 0,
|
||||
ilocIndexSize = 0,
|
||||
ilocExtents,
|
||||
rawExifData
|
||||
} = opts
|
||||
|
||||
const ftyp = buildFtypBox(ftypBrand)
|
||||
@@ -250,19 +343,43 @@ const buildAvifFile = (opts: BuildAvifOpts = {}): ArrayBuffer => {
|
||||
return ftyp.slice().buffer as ArrayBuffer
|
||||
}
|
||||
|
||||
const exifData = buildExifBlob(exifEntries, endian)
|
||||
const exifData = rawExifData ?? buildExifBlob(exifEntries, endian)
|
||||
const infe = buildInfeBox(1, itemType, infeVersion)
|
||||
const iinf = buildIinfBox([infe])
|
||||
const iinf = buildIinfBox([infe], iinfVersion)
|
||||
|
||||
const realIloc = buildIlocBox([
|
||||
{ itemId: 1, extentOffset: 0, extentLength: exifData.length }
|
||||
])
|
||||
const realIloc = buildIlocBox(
|
||||
[
|
||||
{
|
||||
itemId: 1,
|
||||
extentOffset: 0,
|
||||
extentLength: exifData.length,
|
||||
extents: ilocExtents
|
||||
}
|
||||
],
|
||||
{
|
||||
version: ilocVersion,
|
||||
baseOffsetSize: ilocBaseOffsetSize,
|
||||
indexSize: ilocIndexSize
|
||||
}
|
||||
)
|
||||
const metaSize = 8 + 4 + iinf.length + (omitIloc ? 0 : realIloc.length)
|
||||
const exifOffset = ftyp.length + metaSize
|
||||
|
||||
const finalIloc = buildIlocBox([
|
||||
{ itemId: 1, extentOffset: exifOffset, extentLength: exifData.length }
|
||||
])
|
||||
const finalIloc = buildIlocBox(
|
||||
[
|
||||
{
|
||||
itemId: 1,
|
||||
extentOffset: exifOffset,
|
||||
extentLength: exifData.length,
|
||||
extents: ilocExtents
|
||||
}
|
||||
],
|
||||
{
|
||||
version: ilocVersion,
|
||||
baseOffsetSize: ilocBaseOffsetSize,
|
||||
indexSize: ilocIndexSize
|
||||
}
|
||||
)
|
||||
const finalInner = omitIloc ? [iinf] : [iinf, finalIloc]
|
||||
const meta = buildMetaBox(finalInner)
|
||||
|
||||
@@ -319,6 +436,52 @@ describe('getFromAvifFile', () => {
|
||||
expect(result.workflow).toBe(JSON.stringify(JSON.parse(workflow)))
|
||||
})
|
||||
|
||||
it('extracts EXIF metadata from versioned item info and location boxes', async () => {
|
||||
const workflow = '{"versioned":true}'
|
||||
const file = fileFromBuffer(
|
||||
buildAvifFile({
|
||||
exifEntries: [`workflow:${workflow}`],
|
||||
iinfVersion: 1,
|
||||
infeVersion: 3,
|
||||
ilocVersion: 2,
|
||||
ilocBaseOffsetSize: 4,
|
||||
ilocIndexSize: 4
|
||||
})
|
||||
)
|
||||
|
||||
const result = await getFromAvifFile(file)
|
||||
|
||||
expect(result.workflow).toBe(JSON.stringify(JSON.parse(workflow)))
|
||||
})
|
||||
|
||||
it('returns {} when the Exif item has no extents', async () => {
|
||||
const file = fileFromBuffer(
|
||||
buildAvifFile({
|
||||
exifEntries: ['workflow:{}'],
|
||||
ilocExtents: []
|
||||
})
|
||||
)
|
||||
|
||||
const result = await getFromAvifFile(file)
|
||||
|
||||
expect(result).toEqual({})
|
||||
})
|
||||
|
||||
it('returns {} when the Exif payload has no TIFF header', async () => {
|
||||
const file = fileFromBuffer(
|
||||
buildAvifFile({
|
||||
rawExifData: new TextEncoder().encode('not tiff data')
|
||||
})
|
||||
)
|
||||
|
||||
const result = await getFromAvifFile(file)
|
||||
|
||||
expect(result).toEqual({})
|
||||
expect(console.log).toHaveBeenCalledWith(
|
||||
'Warning: TIFF header not found in EXIF data.'
|
||||
)
|
||||
})
|
||||
|
||||
it('returns {} when AVIF major brand is not "avif"', async () => {
|
||||
const file = fileFromBuffer(
|
||||
buildAvifFile({ exifEntries: ['workflow:{}'], ftypBrand: 'heic' })
|
||||
|
||||
@@ -7,7 +7,8 @@ import {
|
||||
EXPECTED_PROMPT_NAN_COERCED,
|
||||
EXPECTED_WORKFLOW,
|
||||
mockFileReaderAbort,
|
||||
mockFileReaderError
|
||||
mockFileReaderError,
|
||||
mockFileReaderResult
|
||||
} from './__fixtures__/helpers'
|
||||
import { getFromWebmFile } from './ebml'
|
||||
|
||||
@@ -16,6 +17,56 @@ const nanFixturePath = path.resolve(
|
||||
__dirname,
|
||||
'__fixtures__/with_nan_metadata.webm'
|
||||
)
|
||||
const WEBM_SIGNATURE = new Uint8Array([0x1a, 0x45, 0xdf, 0xa3])
|
||||
const SIMPLE_TAG = new Uint8Array([0x67, 0xc8])
|
||||
const TAG_NAME = new Uint8Array([0x45, 0xa3])
|
||||
const TAG_VALUE = new Uint8Array([0x44, 0x87])
|
||||
const encoder = new TextEncoder()
|
||||
|
||||
function concatBytes(...chunks: Uint8Array[]) {
|
||||
const totalLength = chunks.reduce((sum, chunk) => sum + chunk.length, 0)
|
||||
const result = new Uint8Array(totalLength)
|
||||
let offset = 0
|
||||
|
||||
for (const chunk of chunks) {
|
||||
result.set(chunk, offset)
|
||||
offset += chunk.length
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
function bytes(...values: number[]) {
|
||||
return new Uint8Array(values)
|
||||
}
|
||||
|
||||
function vint(value: number) {
|
||||
return bytes(0x80 | value)
|
||||
}
|
||||
|
||||
function element(id: Uint8Array, value: string) {
|
||||
const encoded = encoder.encode(value)
|
||||
return concatBytes(id, vint(encoded.length), encoded)
|
||||
}
|
||||
|
||||
function simpleTag(name: string, value: string, useTwoByteSize = false) {
|
||||
const payload = concatBytes(
|
||||
element(TAG_NAME, name),
|
||||
element(TAG_VALUE, value)
|
||||
)
|
||||
const size = useTwoByteSize
|
||||
? bytes(0x40, payload.length)
|
||||
: vint(payload.length)
|
||||
return concatBytes(SIMPLE_TAG, size, payload)
|
||||
}
|
||||
|
||||
async function readWebm(bytes: Uint8Array) {
|
||||
return getFromWebmFile(
|
||||
new File([bytes as Uint8Array<ArrayBuffer>], 'test.webm', {
|
||||
type: 'video/webm'
|
||||
})
|
||||
)
|
||||
}
|
||||
|
||||
describe('WebM/EBML metadata', () => {
|
||||
it('extracts workflow and prompt from EBML SimpleTag elements', async () => {
|
||||
@@ -46,6 +97,89 @@ describe('WebM/EBML metadata', () => {
|
||||
expect(result).toEqual({})
|
||||
})
|
||||
|
||||
it('extracts plain string tags and trims null-terminated values', async () => {
|
||||
const result = await readWebm(
|
||||
concatBytes(
|
||||
WEBM_SIGNATURE,
|
||||
simpleTag('\0Comment', ' hello \0ignored', true)
|
||||
)
|
||||
)
|
||||
|
||||
expect(result.comment).toBe('hello')
|
||||
})
|
||||
|
||||
it('ignores prompt tags whose value is not complete JSON', async () => {
|
||||
const result = await readWebm(
|
||||
concatBytes(WEBM_SIGNATURE, simpleTag('PROMPT', '{not-json'))
|
||||
)
|
||||
|
||||
expect(result.prompt).toBeUndefined()
|
||||
})
|
||||
|
||||
it('ignores prompt tags whose value has no JSON object', async () => {
|
||||
const result = await readWebm(
|
||||
concatBytes(WEBM_SIGNATURE, simpleTag('PROMPT', 'not-json'))
|
||||
)
|
||||
|
||||
expect(result.prompt).toBeUndefined()
|
||||
})
|
||||
|
||||
it('parses the first complete JSON object from a prompt tag', async () => {
|
||||
const result = await readWebm(
|
||||
concatBytes(
|
||||
WEBM_SIGNATURE,
|
||||
simpleTag('PROMPT', 'prefix {"outer":{"inner":1}} trailing')
|
||||
)
|
||||
)
|
||||
|
||||
expect(result.prompt).toEqual({ outer: { inner: 1 } })
|
||||
})
|
||||
|
||||
it('ignores tags whose name has no readable text', async () => {
|
||||
const payload = concatBytes(
|
||||
concatBytes(TAG_NAME, vint(2), bytes(0, 1)),
|
||||
element(TAG_VALUE, 'value')
|
||||
)
|
||||
|
||||
const result = await readWebm(
|
||||
concatBytes(WEBM_SIGNATURE, SIMPLE_TAG, vint(payload.length), payload)
|
||||
)
|
||||
|
||||
expect(result).toEqual({})
|
||||
})
|
||||
|
||||
it('ignores tag elements with zero-sized names', async () => {
|
||||
const payload = concatBytes(TAG_NAME, vint(0), element(TAG_VALUE, 'value'))
|
||||
|
||||
const result = await readWebm(
|
||||
concatBytes(WEBM_SIGNATURE, SIMPLE_TAG, vint(payload.length), payload)
|
||||
)
|
||||
|
||||
expect(result).toEqual({})
|
||||
})
|
||||
|
||||
it('ignores malformed SimpleTag encodings', async () => {
|
||||
const nameOnly = element(TAG_NAME, 'comment')
|
||||
|
||||
await expect(
|
||||
readWebm(concatBytes(WEBM_SIGNATURE, SIMPLE_TAG))
|
||||
).resolves.toEqual({})
|
||||
await expect(
|
||||
readWebm(concatBytes(WEBM_SIGNATURE, SIMPLE_TAG, bytes(0x00)))
|
||||
).resolves.toEqual({})
|
||||
await expect(
|
||||
readWebm(concatBytes(WEBM_SIGNATURE, SIMPLE_TAG, bytes(0x7f, 0xff)))
|
||||
).resolves.toEqual({})
|
||||
await expect(
|
||||
readWebm(concatBytes(WEBM_SIGNATURE, SIMPLE_TAG, vint(10), TAG_NAME))
|
||||
).resolves.toEqual({})
|
||||
await expect(
|
||||
readWebm(
|
||||
concatBytes(WEBM_SIGNATURE, SIMPLE_TAG, vint(nameOnly.length), nameOnly)
|
||||
)
|
||||
).resolves.toEqual({})
|
||||
})
|
||||
|
||||
describe('FileReader failure modes', () => {
|
||||
afterEach(() => vi.restoreAllMocks())
|
||||
|
||||
@@ -60,5 +194,10 @@ describe('WebM/EBML metadata', () => {
|
||||
mockFileReaderAbort('readAsArrayBuffer')
|
||||
expect(await getFromWebmFile(file)).toEqual({})
|
||||
})
|
||||
|
||||
it('resolves empty when the FileReader load has no result', async () => {
|
||||
mockFileReaderResult('readAsArrayBuffer', null)
|
||||
expect(await getFromWebmFile(file)).toEqual({})
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import { fromAny } from '@total-typescript/shoehorn'
|
||||
import { afterEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { ASCII, GltfSizeBytes } from '@/types/metadataTypes'
|
||||
@@ -5,7 +6,8 @@ import { ASCII, GltfSizeBytes } from '@/types/metadataTypes'
|
||||
import {
|
||||
EXPECTED_PROMPT_NAN_COERCED,
|
||||
mockFileReaderAbort,
|
||||
mockFileReaderError
|
||||
mockFileReaderError,
|
||||
mockFileReaderResult
|
||||
} from './__fixtures__/helpers'
|
||||
import { getGltfBinaryMetadata } from './gltf'
|
||||
|
||||
@@ -16,11 +18,15 @@ describe('GLTF binary metadata parser', () => {
|
||||
return { header, headerView }
|
||||
}
|
||||
|
||||
const createJSONChunk = (jsonData: ArrayBuffer) => {
|
||||
const createJSONChunk = (
|
||||
jsonData: ArrayBuffer,
|
||||
chunkType = ASCII.JSON,
|
||||
chunkLength = jsonData.byteLength
|
||||
) => {
|
||||
const chunkHeader = new ArrayBuffer(GltfSizeBytes.CHUNK_HEADER)
|
||||
const chunkView = new DataView(chunkHeader)
|
||||
chunkView.setUint32(0, jsonData.byteLength, true)
|
||||
chunkView.setUint32(4, ASCII.JSON, true)
|
||||
chunkView.setUint32(0, chunkLength, true)
|
||||
chunkView.setUint32(4, chunkType, true)
|
||||
return chunkHeader
|
||||
}
|
||||
|
||||
@@ -52,13 +58,27 @@ describe('GLTF binary metadata parser', () => {
|
||||
// Builds a GLB whose JSON chunk is the literal text passed in - used to
|
||||
// embed Python generated bare NaN/Infinity tokens that JSON.stringify
|
||||
// would otherwise coerce to null.
|
||||
function createMockGltfFileFromText(jsonText: string): File {
|
||||
interface MockGltfFileOptions {
|
||||
chunkLength?: number
|
||||
chunkType?: number
|
||||
magicNumber?: number
|
||||
}
|
||||
|
||||
function createMockGltfFileFromText(
|
||||
jsonText: string,
|
||||
{
|
||||
chunkLength,
|
||||
chunkType = ASCII.JSON,
|
||||
magicNumber = ASCII.GLTF
|
||||
}: MockGltfFileOptions = {}
|
||||
): File {
|
||||
const jsonData = new TextEncoder().encode(jsonText)
|
||||
const { header, headerView } = createGLTFFileStructure()
|
||||
|
||||
setHeaders(headerView, jsonData.buffer)
|
||||
setTypeHeader(headerView, magicNumber)
|
||||
|
||||
const chunkHeader = createJSONChunk(jsonData.buffer)
|
||||
const chunkHeader = createJSONChunk(jsonData.buffer, chunkType, chunkLength)
|
||||
|
||||
const fileContent = new Uint8Array(
|
||||
header.byteLength + chunkHeader.byteLength + jsonData.byteLength
|
||||
@@ -130,7 +150,9 @@ describe('GLTF binary metadata parser', () => {
|
||||
expect(metadata).toBeDefined()
|
||||
expect(metadata.prompt).toBeDefined()
|
||||
|
||||
const prompt = metadata.prompt as Record<string, any>
|
||||
const prompt: {
|
||||
node1: { class_type: string; inputs: { seed: number } }
|
||||
} = fromAny(metadata.prompt)
|
||||
expect(prompt.node1.class_type).toBe('TestNode')
|
||||
expect(prompt.node1.inputs.seed).toBe(123456)
|
||||
})
|
||||
@@ -179,6 +201,74 @@ describe('GLTF binary metadata parser', () => {
|
||||
expect(metadata).toEqual({})
|
||||
})
|
||||
|
||||
it('returns empty when the GLB magic number is invalid', async () => {
|
||||
const mockFile = createMockGltfFileFromText('{}', { magicNumber: 0 })
|
||||
|
||||
const metadata = await getGltfBinaryMetadata(mockFile)
|
||||
|
||||
expect(metadata).toEqual({})
|
||||
})
|
||||
|
||||
it('returns empty when the GLB is missing the first chunk header', async () => {
|
||||
const { header, headerView } = createGLTFFileStructure()
|
||||
setTypeHeader(headerView, ASCII.GLTF)
|
||||
setVersionHeader(headerView, 2)
|
||||
setTotalLengthHeader(headerView, GltfSizeBytes.HEADER)
|
||||
const mockFile = new File([header], 'header-only.glb')
|
||||
|
||||
const metadata = await getGltfBinaryMetadata(mockFile)
|
||||
|
||||
expect(metadata).toEqual({})
|
||||
})
|
||||
|
||||
it('returns empty when the first chunk is not JSON', async () => {
|
||||
const mockFile = createMockGltfFileFromText('{}', { chunkType: 0 })
|
||||
|
||||
const metadata = await getGltfBinaryMetadata(mockFile)
|
||||
|
||||
expect(metadata).toEqual({})
|
||||
})
|
||||
|
||||
it('returns empty when the declared JSON chunk exceeds the buffer', async () => {
|
||||
const mockFile = createMockGltfFileFromText('{}', { chunkLength: 1024 })
|
||||
|
||||
const metadata = await getGltfBinaryMetadata(mockFile)
|
||||
|
||||
expect(metadata).toEqual({})
|
||||
})
|
||||
|
||||
it('returns empty when the JSON chunk cannot be parsed', async () => {
|
||||
const mockFile = createMockGltfFileFromText('{not json')
|
||||
|
||||
const metadata = await getGltfBinaryMetadata(mockFile)
|
||||
|
||||
expect(metadata).toEqual({})
|
||||
})
|
||||
|
||||
it('returns empty when asset extras are missing', async () => {
|
||||
const mockFile = createMockGltfFile({ asset: { version: '2.0' } })
|
||||
|
||||
const metadata = await getGltfBinaryMetadata(mockFile)
|
||||
|
||||
expect(metadata).toEqual({})
|
||||
})
|
||||
|
||||
it('skips string metadata values that are not valid JSON', async () => {
|
||||
const mockFile = createMockGltfFile({
|
||||
asset: {
|
||||
version: '2.0',
|
||||
extras: {
|
||||
prompt: '{not json',
|
||||
workflow: '{not json'
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
const metadata = await getGltfBinaryMetadata(mockFile)
|
||||
|
||||
expect(metadata).toEqual({})
|
||||
})
|
||||
|
||||
describe('FileReader failure modes', () => {
|
||||
afterEach(() => vi.restoreAllMocks())
|
||||
|
||||
@@ -193,5 +283,15 @@ describe('GLTF binary metadata parser', () => {
|
||||
mockFileReaderAbort('readAsArrayBuffer')
|
||||
expect(await getGltfBinaryMetadata(file)).toEqual({})
|
||||
})
|
||||
|
||||
it('resolves empty when the FileReader load result is missing', async () => {
|
||||
mockFileReaderResult('readAsArrayBuffer', null)
|
||||
expect(await getGltfBinaryMetadata(file)).toEqual({})
|
||||
})
|
||||
|
||||
it('resolves empty when the FileReader result is not a buffer', async () => {
|
||||
mockFileReaderResult('readAsArrayBuffer', 'not a buffer')
|
||||
expect(await getGltfBinaryMetadata(file)).toEqual({})
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -7,7 +7,8 @@ import {
|
||||
EXPECTED_PROMPT_NAN_COERCED,
|
||||
EXPECTED_WORKFLOW,
|
||||
mockFileReaderAbort,
|
||||
mockFileReaderError
|
||||
mockFileReaderError,
|
||||
mockFileReaderResult
|
||||
} from './__fixtures__/helpers'
|
||||
import { getFromIsobmffFile } from './isobmff'
|
||||
|
||||
@@ -16,6 +17,82 @@ const nanFixturePath = path.resolve(
|
||||
__dirname,
|
||||
'__fixtures__/with_nan_metadata.mp4'
|
||||
)
|
||||
const encoder = new TextEncoder()
|
||||
|
||||
function uint32(value: number) {
|
||||
return new Uint8Array([
|
||||
(value >>> 24) & 0xff,
|
||||
(value >>> 16) & 0xff,
|
||||
(value >>> 8) & 0xff,
|
||||
value & 0xff
|
||||
])
|
||||
}
|
||||
|
||||
function concatBytes(...chunks: Uint8Array[]) {
|
||||
const totalLength = chunks.reduce((sum, chunk) => sum + chunk.length, 0)
|
||||
const result = new Uint8Array(totalLength)
|
||||
let offset = 0
|
||||
|
||||
for (const chunk of chunks) {
|
||||
result.set(chunk, offset)
|
||||
offset += chunk.length
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
function box(type: string, payload = new Uint8Array(), size?: number) {
|
||||
return concatBytes(
|
||||
uint32(size ?? 8 + payload.length),
|
||||
encoder.encode(type),
|
||||
payload
|
||||
)
|
||||
}
|
||||
|
||||
function keyEntry(name: string) {
|
||||
const encoded = encoder.encode(name)
|
||||
return concatBytes(
|
||||
uint32(8 + encoded.length),
|
||||
encoder.encode('mdta'),
|
||||
encoded
|
||||
)
|
||||
}
|
||||
|
||||
function keysBox(names: string[]) {
|
||||
return box(
|
||||
'keys',
|
||||
concatBytes(uint32(0), uint32(names.length), ...names.map(keyEntry))
|
||||
)
|
||||
}
|
||||
|
||||
function dataBox(value: string | Uint8Array) {
|
||||
const payload = typeof value === 'string' ? encoder.encode(value) : value
|
||||
return box('data', concatBytes(uint32(0), uint32(0), payload))
|
||||
}
|
||||
|
||||
function ilstItem(index: number, payload: Uint8Array) {
|
||||
return concatBytes(uint32(8 + payload.length), uint32(index), payload)
|
||||
}
|
||||
|
||||
function ilstBox(...items: Uint8Array[]) {
|
||||
return box('ilst', concatBytes(...items))
|
||||
}
|
||||
|
||||
function metaBox(...children: Uint8Array[]) {
|
||||
return box('meta', concatBytes(uint32(0), ...children))
|
||||
}
|
||||
|
||||
function udtaWithMeta(...children: Uint8Array[]) {
|
||||
return box('udta', metaBox(...children))
|
||||
}
|
||||
|
||||
async function readMp4(bytes: Uint8Array) {
|
||||
return getFromIsobmffFile(
|
||||
new File([bytes as Uint8Array<ArrayBuffer>], 'test.mp4', {
|
||||
type: 'video/mp4'
|
||||
})
|
||||
)
|
||||
}
|
||||
|
||||
describe('ISOBMFF (MP4) metadata', () => {
|
||||
it('extracts workflow and prompt from QuickTime keys/ilst boxes', async () => {
|
||||
@@ -48,6 +125,102 @@ describe('ISOBMFF (MP4) metadata', () => {
|
||||
expect(result).toEqual({})
|
||||
})
|
||||
|
||||
it('extracts metadata from udta nested inside moov', async () => {
|
||||
const bytes = box(
|
||||
'moov',
|
||||
udtaWithMeta(
|
||||
keysBox(['WORKFLOW']),
|
||||
ilstBox(ilstItem(1, dataBox('xxxx{"nodes":[]}')))
|
||||
)
|
||||
)
|
||||
|
||||
const result = await readMp4(bytes)
|
||||
|
||||
expect(result.workflow).toEqual({ nodes: [] })
|
||||
})
|
||||
|
||||
it('returns empty when a top-level box declares an impossible size', async () => {
|
||||
const result = await readMp4(box('free', new Uint8Array([1, 2]), 100))
|
||||
|
||||
expect(result).toEqual({})
|
||||
})
|
||||
|
||||
it('returns empty when the keys box cannot provide entries', async () => {
|
||||
const tooShortKeys = box('keys', uint32(0))
|
||||
const missingKeyEntry = box(
|
||||
'keys',
|
||||
concatBytes(uint32(0), uint32(1), uint32(8))
|
||||
)
|
||||
const malformedKey = box(
|
||||
'keys',
|
||||
concatBytes(uint32(0), uint32(1), uint32(7), encoder.encode('bad!'))
|
||||
)
|
||||
const oversizedKey = box(
|
||||
'keys',
|
||||
concatBytes(uint32(0), uint32(1), uint32(100), encoder.encode('bad!'))
|
||||
)
|
||||
|
||||
await expect(readMp4(udtaWithMeta(tooShortKeys))).resolves.toEqual({})
|
||||
await expect(readMp4(udtaWithMeta(missingKeyEntry))).resolves.toEqual({})
|
||||
await expect(readMp4(udtaWithMeta(malformedKey))).resolves.toEqual({})
|
||||
await expect(readMp4(udtaWithMeta(oversizedKey))).resolves.toEqual({})
|
||||
})
|
||||
|
||||
it('ignores item entries whose key is unknown or unsupported', async () => {
|
||||
const unknownIndex = udtaWithMeta(
|
||||
keysBox(['PROMPT']),
|
||||
ilstBox(ilstItem(2, dataBox('{"1":{}}')))
|
||||
)
|
||||
const unsupportedKey = udtaWithMeta(
|
||||
keysBox(['DESCRIPTION']),
|
||||
ilstBox(ilstItem(1, dataBox('{"ignored":true}')))
|
||||
)
|
||||
|
||||
await expect(readMp4(unknownIndex)).resolves.toEqual({})
|
||||
await expect(readMp4(unsupportedKey)).resolves.toEqual({})
|
||||
})
|
||||
|
||||
it('ignores metadata items without readable JSON data', async () => {
|
||||
const shortDataBox = box('data')
|
||||
const noJson = dataBox('not-json')
|
||||
const invalidJson = dataBox('prefix {not-json')
|
||||
const noDataBox = box('free', new Uint8Array([1]))
|
||||
const invalidItems = ilstBox(
|
||||
concatBytes(uint32(8), uint32(1)),
|
||||
concatBytes(uint32(100), uint32(1), invalidJson)
|
||||
)
|
||||
|
||||
await expect(
|
||||
readMp4(
|
||||
udtaWithMeta(keysBox(['PROMPT']), ilstBox(ilstItem(1, shortDataBox)))
|
||||
)
|
||||
).resolves.toEqual({})
|
||||
await expect(
|
||||
readMp4(udtaWithMeta(keysBox(['PROMPT']), ilstBox(ilstItem(1, noJson))))
|
||||
).resolves.toEqual({})
|
||||
await expect(
|
||||
readMp4(
|
||||
udtaWithMeta(keysBox(['PROMPT']), ilstBox(ilstItem(1, invalidJson)))
|
||||
)
|
||||
).resolves.toEqual({})
|
||||
await expect(
|
||||
readMp4(
|
||||
udtaWithMeta(keysBox(['PROMPT']), ilstBox(ilstItem(1, noDataBox)))
|
||||
)
|
||||
).resolves.toEqual({})
|
||||
await expect(
|
||||
readMp4(udtaWithMeta(keysBox(['PROMPT']), invalidItems))
|
||||
).resolves.toEqual({})
|
||||
})
|
||||
|
||||
it('returns empty when required metadata boxes are absent', async () => {
|
||||
await expect(readMp4(box('udta', box('free')))).resolves.toEqual({})
|
||||
await expect(readMp4(udtaWithMeta(ilstBox()))).resolves.toEqual({})
|
||||
await expect(readMp4(udtaWithMeta(keysBox(['PROMPT'])))).resolves.toEqual(
|
||||
{}
|
||||
)
|
||||
})
|
||||
|
||||
describe('FileReader failure modes', () => {
|
||||
afterEach(() => vi.restoreAllMocks())
|
||||
|
||||
@@ -63,5 +236,10 @@ describe('ISOBMFF (MP4) metadata', () => {
|
||||
mockFileReaderAbort('readAsArrayBuffer')
|
||||
expect(await getFromIsobmffFile(file)).toEqual({})
|
||||
})
|
||||
|
||||
it('resolves empty when the FileReader load has no result', async () => {
|
||||
mockFileReaderResult('readAsArrayBuffer', null)
|
||||
expect(await getFromIsobmffFile(file)).toEqual({})
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
127
src/scripts/metadata/parser.test.ts
Normal file
127
src/scripts/metadata/parser.test.ts
Normal file
@@ -0,0 +1,127 @@
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { getFromWebmFile } from '@/scripts/metadata/ebml'
|
||||
import { getGltfBinaryMetadata } from '@/scripts/metadata/gltf'
|
||||
import { getFromIsobmffFile } from '@/scripts/metadata/isobmff'
|
||||
import { getDataFromJSON } from '@/scripts/metadata/json'
|
||||
import { getMp3Metadata } from '@/scripts/metadata/mp3'
|
||||
import { getOggMetadata } from '@/scripts/metadata/ogg'
|
||||
import { getWorkflowDataFromFile } from '@/scripts/metadata/parser'
|
||||
import { getSvgMetadata } from '@/scripts/metadata/svg'
|
||||
import {
|
||||
getAvifMetadata,
|
||||
getFlacMetadata,
|
||||
getLatentMetadata,
|
||||
getPngMetadata,
|
||||
getWebpMetadata
|
||||
} from '@/scripts/pnginfo'
|
||||
|
||||
vi.mock('@/scripts/metadata/ebml', () => ({ getFromWebmFile: vi.fn() }))
|
||||
vi.mock('@/scripts/metadata/gltf', () => ({ getGltfBinaryMetadata: vi.fn() }))
|
||||
vi.mock('@/scripts/metadata/isobmff', () => ({ getFromIsobmffFile: vi.fn() }))
|
||||
vi.mock('@/scripts/metadata/json', () => ({ getDataFromJSON: vi.fn() }))
|
||||
vi.mock('@/scripts/metadata/mp3', () => ({ getMp3Metadata: vi.fn() }))
|
||||
vi.mock('@/scripts/metadata/ogg', () => ({ getOggMetadata: vi.fn() }))
|
||||
vi.mock('@/scripts/metadata/svg', () => ({ getSvgMetadata: vi.fn() }))
|
||||
vi.mock('@/scripts/pnginfo', () => ({
|
||||
getAvifMetadata: vi.fn(),
|
||||
getFlacMetadata: vi.fn(),
|
||||
getLatentMetadata: vi.fn(),
|
||||
getPngMetadata: vi.fn(),
|
||||
getWebpMetadata: vi.fn()
|
||||
}))
|
||||
|
||||
function file(type: string, name = 'file') {
|
||||
return new File(['data'], name, { type })
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
describe('getWorkflowDataFromFile', () => {
|
||||
it('routes png/avif/mp3/ogg/webm to their parsers and returns the result', async () => {
|
||||
vi.mocked(getPngMetadata).mockResolvedValue({ a: 1 } as never)
|
||||
expect(await getWorkflowDataFromFile(file('image/png'))).toEqual({ a: 1 })
|
||||
expect(getPngMetadata).toHaveBeenCalled()
|
||||
|
||||
await getWorkflowDataFromFile(file('image/avif'))
|
||||
expect(getAvifMetadata).toHaveBeenCalled()
|
||||
|
||||
await getWorkflowDataFromFile(file('audio/mpeg'))
|
||||
expect(getMp3Metadata).toHaveBeenCalled()
|
||||
|
||||
await getWorkflowDataFromFile(file('audio/ogg'))
|
||||
expect(getOggMetadata).toHaveBeenCalled()
|
||||
|
||||
await getWorkflowDataFromFile(file('video/webm'))
|
||||
expect(getFromWebmFile).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('extracts workflow/prompt from webp, preferring lowercase keys', async () => {
|
||||
vi.mocked(getWebpMetadata).mockResolvedValue({
|
||||
workflow: 'wf',
|
||||
prompt: 'pr'
|
||||
} as never)
|
||||
expect(await getWorkflowDataFromFile(file('image/webp'))).toEqual({
|
||||
workflow: 'wf',
|
||||
prompt: 'pr'
|
||||
})
|
||||
})
|
||||
|
||||
it('falls back to capitalized webp keys when lowercase are absent', async () => {
|
||||
vi.mocked(getWebpMetadata).mockResolvedValue({
|
||||
Workflow: 'WF',
|
||||
Prompt: 'PR'
|
||||
} as never)
|
||||
expect(await getWorkflowDataFromFile(file('image/webp'))).toEqual({
|
||||
workflow: 'WF',
|
||||
prompt: 'PR'
|
||||
})
|
||||
})
|
||||
|
||||
it('handles both flac mime types and extracts workflow/prompt', async () => {
|
||||
vi.mocked(getFlacMetadata).mockResolvedValue({ workflow: 'w' } as never)
|
||||
expect(await getWorkflowDataFromFile(file('audio/flac'))).toEqual({
|
||||
workflow: 'w',
|
||||
prompt: undefined
|
||||
})
|
||||
expect(await getWorkflowDataFromFile(file('audio/x-flac'))).toEqual({
|
||||
workflow: 'w',
|
||||
prompt: undefined
|
||||
})
|
||||
})
|
||||
|
||||
it('routes isobmff by mime type and by file extension', async () => {
|
||||
await getWorkflowDataFromFile(file('video/mp4'))
|
||||
await getWorkflowDataFromFile(file('', 'clip.mov'))
|
||||
await getWorkflowDataFromFile(file('', 'clip.m4v'))
|
||||
expect(getFromIsobmffFile).toHaveBeenCalledTimes(3)
|
||||
})
|
||||
|
||||
it('routes svg and gltf by mime type or extension', async () => {
|
||||
await getWorkflowDataFromFile(file('image/svg+xml'))
|
||||
await getWorkflowDataFromFile(file('', 'icon.svg'))
|
||||
expect(getSvgMetadata).toHaveBeenCalledTimes(2)
|
||||
|
||||
await getWorkflowDataFromFile(file('model/gltf-binary'))
|
||||
await getWorkflowDataFromFile(file('', 'model.glb'))
|
||||
expect(getGltfBinaryMetadata).toHaveBeenCalledTimes(2)
|
||||
})
|
||||
|
||||
it('routes latent/safetensors and json by extension or mime type', async () => {
|
||||
await getWorkflowDataFromFile(file('', 'x.latent'))
|
||||
await getWorkflowDataFromFile(file('', 'x.safetensors'))
|
||||
expect(getLatentMetadata).toHaveBeenCalledTimes(2)
|
||||
|
||||
await getWorkflowDataFromFile(file('application/json'))
|
||||
await getWorkflowDataFromFile(file('', 'x.json'))
|
||||
expect(getDataFromJSON).toHaveBeenCalledTimes(2)
|
||||
})
|
||||
|
||||
it('returns undefined for an unrecognized file', async () => {
|
||||
expect(
|
||||
await getWorkflowDataFromFile(file('application/zip', 'a.zip'))
|
||||
).toBe(undefined)
|
||||
})
|
||||
})
|
||||
@@ -2,7 +2,8 @@ import { afterEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import {
|
||||
mockFileReaderAbort,
|
||||
mockFileReaderError
|
||||
mockFileReaderError,
|
||||
mockFileReaderResult
|
||||
} from './__fixtures__/helpers'
|
||||
import { getFromPngBuffer, getFromPngFile } from './png'
|
||||
|
||||
@@ -191,6 +192,27 @@ describe('getFromPngBuffer', () => {
|
||||
const result = await getFromPngBuffer(buffer)
|
||||
expect(result['workflow']).toBe(workflow)
|
||||
})
|
||||
|
||||
it('logs error and skips compressed iTXt with invalid deflate data', async () => {
|
||||
vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
const buffer = createPngWithChunk(
|
||||
'iTXt',
|
||||
'workflow',
|
||||
new Uint8Array([1, 2, 3]),
|
||||
{
|
||||
compressionFlag: 1,
|
||||
compressionMethod: 0
|
||||
}
|
||||
)
|
||||
|
||||
const result = await getFromPngBuffer(buffer)
|
||||
|
||||
expect(result['workflow']).toBeUndefined()
|
||||
expect(console.error).toHaveBeenCalledWith(
|
||||
expect.stringContaining('Failed to decompress iTXt chunk "workflow"'),
|
||||
expect.anything()
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('getFromPngFile', () => {
|
||||
@@ -228,5 +250,12 @@ describe('getFromPngFile', () => {
|
||||
mockFileReaderAbort('readAsArrayBuffer')
|
||||
await expect(getFromPngFile(file)).rejects.toThrow('FileReader aborted')
|
||||
})
|
||||
|
||||
it('rejects when the FileReader load has no ArrayBuffer result', async () => {
|
||||
mockFileReaderResult('readAsArrayBuffer', null)
|
||||
await expect(getFromPngFile(file)).rejects.toThrow(
|
||||
'Failed to read file as ArrayBuffer'
|
||||
)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,13 +1,19 @@
|
||||
import fs from 'fs'
|
||||
import path from 'path'
|
||||
import { afterEach, describe, expect, it, vi } from 'vitest'
|
||||
import { fromPartial } from '@total-typescript/shoehorn'
|
||||
|
||||
import type { LGraph, LGraphNode } from '@/lib/litegraph/src/litegraph'
|
||||
import { LiteGraph } from '@/lib/litegraph/src/litegraph'
|
||||
import type { IBaseWidget } from '@/lib/litegraph/src/types/widgets'
|
||||
import { api } from '@/scripts/api'
|
||||
import { getFromAvifFile } from './metadata/avif'
|
||||
import { getFromFlacFile } from './metadata/flac'
|
||||
import { getFromPngFile } from './metadata/png'
|
||||
import {
|
||||
getAvifMetadata,
|
||||
getFlacMetadata,
|
||||
importA1111,
|
||||
getLatentMetadata,
|
||||
getPngMetadata,
|
||||
getWebpMetadata
|
||||
@@ -56,6 +62,20 @@ function encodeAsciiIfd(entries: AsciiIfdEntry[]): Uint8Array {
|
||||
return buf
|
||||
}
|
||||
|
||||
function encodeNonAsciiIfdEntry(tag: number): Uint8Array {
|
||||
const buf = new Uint8Array(22)
|
||||
const dv = new DataView(buf.buffer)
|
||||
buf.set([0x49, 0x49], 0)
|
||||
dv.setUint16(2, 0x002a, true)
|
||||
dv.setUint32(4, 8, true)
|
||||
dv.setUint16(8, 1, true)
|
||||
dv.setUint16(10, tag, true)
|
||||
dv.setUint16(12, 3, true)
|
||||
dv.setUint32(14, 1, true)
|
||||
dv.setUint32(18, 123, true)
|
||||
return buf
|
||||
}
|
||||
|
||||
type WebpChunk = { type: string; payload: Uint8Array }
|
||||
|
||||
function wrapInWebp(chunks: WebpChunk[]): File {
|
||||
@@ -157,6 +177,16 @@ describe('getWebpMetadata', () => {
|
||||
|
||||
expect(metadata).toEqual({ workflow: '{"a":1}' })
|
||||
})
|
||||
|
||||
it('ignores EXIF entries that are not ASCII strings', async () => {
|
||||
const file = wrapInWebp([
|
||||
{ type: 'EXIF', payload: encodeNonAsciiIfdEntry(270) }
|
||||
])
|
||||
|
||||
const metadata = await getWebpMetadata(file)
|
||||
|
||||
expect(metadata).toEqual({})
|
||||
})
|
||||
})
|
||||
|
||||
describe('getLatentMetadata', () => {
|
||||
@@ -234,3 +264,313 @@ describe('format-specific metadata wrappers', () => {
|
||||
expect(result).toEqual({ workflow: '{"avif":1}' })
|
||||
})
|
||||
})
|
||||
|
||||
describe('importA1111', () => {
|
||||
function widget(
|
||||
name: string,
|
||||
options: string[] = []
|
||||
): IBaseWidget & { value?: string | number } {
|
||||
return fromPartial<IBaseWidget & { value?: string | number }>({
|
||||
name,
|
||||
options: { values: options },
|
||||
value: undefined
|
||||
})
|
||||
}
|
||||
|
||||
function createNode(type: string): LGraphNode {
|
||||
const widgetsByType: Record<string, IBaseWidget[]> = {
|
||||
CheckpointLoaderSimple: [widget('ckpt_name', ['sd15.safetensors'])],
|
||||
CLIPSetLastLayer: [widget('stop_at_clip_layer')],
|
||||
CLIPTextEncode: [widget('text')],
|
||||
EmptyLatentImage: [widget('width'), widget('height')],
|
||||
ImageScale: [widget('width'), widget('height')],
|
||||
ImageUpscaleWithModel: [],
|
||||
KSampler: [
|
||||
widget('cfg'),
|
||||
widget('sampler_name', ['euler_a', 'dpmpp_2m']),
|
||||
widget('scheduler', ['normal', 'karras']),
|
||||
widget('seed'),
|
||||
widget('steps'),
|
||||
widget('denoise')
|
||||
],
|
||||
LatentUpscale: [
|
||||
widget('upscale_method', ['nearest-exact']),
|
||||
widget('width'),
|
||||
widget('height')
|
||||
],
|
||||
LoraLoader: [
|
||||
widget('lora_name', ['foo.safetensors']),
|
||||
widget('strength_model'),
|
||||
widget('strength_clip')
|
||||
],
|
||||
UpscaleModelLoader: [widget('model_name', ['ESRGAN'])],
|
||||
VAEEncodeTiled: [],
|
||||
VAEDecodeTiled: [],
|
||||
SaveImage: [],
|
||||
VAEDecode: []
|
||||
}
|
||||
return {
|
||||
type,
|
||||
widgets: widgetsByType[type] ?? [],
|
||||
connect: vi.fn()
|
||||
} as unknown as LGraphNode
|
||||
}
|
||||
|
||||
function createGraph(): LGraph {
|
||||
return fromPartial<LGraph>({
|
||||
add: vi.fn(),
|
||||
arrange: vi.fn(),
|
||||
clear: vi.fn()
|
||||
})
|
||||
}
|
||||
|
||||
function findWidget(node: LGraphNode, name: string) {
|
||||
return node.widgets?.find((widget) => widget.name === name)
|
||||
}
|
||||
|
||||
it('ignores text without parsed generation settings', async () => {
|
||||
const graph = createGraph()
|
||||
vi.spyOn(api, 'getEmbeddings').mockResolvedValue([])
|
||||
|
||||
await importA1111(graph, 'positive prompt only')
|
||||
await importA1111(graph, 'positive prompt\nSteps:\n')
|
||||
|
||||
expect(graph.clear).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('ignores text without a negative prompt section', async () => {
|
||||
const graph = createGraph()
|
||||
vi.spyOn(api, 'getEmbeddings').mockResolvedValue([])
|
||||
|
||||
await importA1111(
|
||||
graph,
|
||||
['positive prompt only', 'Steps: 20, Sampler: Euler a'].join('\n')
|
||||
)
|
||||
|
||||
expect(graph.clear).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('stops when a required base node cannot be created', async () => {
|
||||
const graph = createGraph()
|
||||
vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
vi.spyOn(api, 'getEmbeddings').mockResolvedValue([])
|
||||
vi.spyOn(LiteGraph, 'createNode').mockImplementation((type) =>
|
||||
type === 'KSampler' ? null : createNode(type)
|
||||
)
|
||||
|
||||
await importA1111(
|
||||
graph,
|
||||
[
|
||||
'prompt',
|
||||
'Negative prompt: blurry',
|
||||
'Steps: 20, Sampler: Euler a, Size: 512x512'
|
||||
].join('\n')
|
||||
)
|
||||
|
||||
expect(graph.clear).not.toHaveBeenCalled()
|
||||
expect(console.error).toHaveBeenCalledWith(
|
||||
'Failed to create required nodes for A1111 import'
|
||||
)
|
||||
})
|
||||
|
||||
it('builds a basic graph from A1111 parameters', async () => {
|
||||
const graph = createGraph()
|
||||
const nodes: LGraphNode[] = []
|
||||
vi.spyOn(console, 'warn').mockImplementation(() => {})
|
||||
vi.spyOn(api, 'getEmbeddings').mockResolvedValue(['easynegative'])
|
||||
vi.spyOn(LiteGraph, 'createNode').mockImplementation((type) => {
|
||||
const node = createNode(type)
|
||||
nodes.push(node)
|
||||
return node
|
||||
})
|
||||
|
||||
await importA1111(
|
||||
graph,
|
||||
[
|
||||
'<lora:foo:0.7> portrait easynegative',
|
||||
'Negative prompt: blurry <lora:bad:not-number>',
|
||||
'Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 42, Size: 512x512, Model: sd15, Clip skip: 2, Model hash: ignored'
|
||||
].join('\n')
|
||||
)
|
||||
|
||||
const checkpoint = nodes.find(
|
||||
(node) => node.type === 'CheckpointLoaderSimple'
|
||||
)
|
||||
const clipSkip = nodes.find((node) => node.type === 'CLIPSetLastLayer')
|
||||
const sampler = nodes.find((node) => node.type === 'KSampler')
|
||||
const image = nodes.find((node) => node.type === 'EmptyLatentImage')
|
||||
const lora = nodes.find((node) => node.type === 'LoraLoader')
|
||||
const textNodes = nodes.filter((node) => node.type === 'CLIPTextEncode')
|
||||
|
||||
expect(graph.clear).toHaveBeenCalledOnce()
|
||||
expect(graph.arrange).toHaveBeenCalledOnce()
|
||||
expect(findWidget(checkpoint!, 'ckpt_name')?.value).toBe('sd15.safetensors')
|
||||
expect(findWidget(clipSkip!, 'stop_at_clip_layer')?.value).toBe(-2)
|
||||
expect(findWidget(sampler!, 'cfg')?.value).toBe(7)
|
||||
expect(findWidget(sampler!, 'sampler_name')?.value).toBe('euler_a')
|
||||
expect(findWidget(sampler!, 'scheduler')?.value).toBe('normal')
|
||||
expect(findWidget(sampler!, 'seed')?.value).toBe(42)
|
||||
expect(findWidget(sampler!, 'steps')?.value).toBe(20)
|
||||
expect(findWidget(image!, 'width')?.value).toBe(512)
|
||||
expect(findWidget(image!, 'height')?.value).toBe(512)
|
||||
expect(findWidget(lora!, 'lora_name')?.value).toBe('foo.safetensors')
|
||||
expect(findWidget(lora!, 'strength_model')?.value).toBe(0.7)
|
||||
expect(findWidget(lora!, 'strength_clip')?.value).toBe(0.7)
|
||||
expect(findWidget(textNodes[0], 'text')?.value).toBe(
|
||||
' portrait embedding:easynegative'
|
||||
)
|
||||
expect(findWidget(textNodes[1], 'text')?.value).toBe('blurry ')
|
||||
})
|
||||
|
||||
it('keeps unknown option-prefix values and logs the mismatch', async () => {
|
||||
const graph = createGraph()
|
||||
const nodes: LGraphNode[] = []
|
||||
vi.spyOn(console, 'warn').mockImplementation(() => {})
|
||||
vi.spyOn(api, 'getEmbeddings').mockResolvedValue([])
|
||||
vi.spyOn(LiteGraph, 'createNode').mockImplementation((type) => {
|
||||
const node = createNode(type)
|
||||
nodes.push(node)
|
||||
return node
|
||||
})
|
||||
|
||||
await importA1111(
|
||||
graph,
|
||||
[
|
||||
'portrait',
|
||||
'Negative prompt: blurry',
|
||||
'Steps: 20, Sampler: Unknown Sampler, CFG scale: 7, Seed: 42, Size: 512x512, Model: unknown-model'
|
||||
].join('\n')
|
||||
)
|
||||
|
||||
const checkpoint = nodes.find(
|
||||
(node) => node.type === 'CheckpointLoaderSimple'
|
||||
)
|
||||
expect(findWidget(checkpoint!, 'ckpt_name')?.value).toBe('unknown-model')
|
||||
expect(console.warn).toHaveBeenCalledWith(
|
||||
"Unknown value 'unknown-model' for widget 'ckpt_name'",
|
||||
checkpoint
|
||||
)
|
||||
})
|
||||
|
||||
it('skips missing LoraLoader nodes while keeping prompt text cleaned', async () => {
|
||||
const graph = createGraph()
|
||||
const nodes: LGraphNode[] = []
|
||||
vi.spyOn(console, 'warn').mockImplementation(() => {})
|
||||
vi.spyOn(api, 'getEmbeddings').mockResolvedValue([])
|
||||
vi.spyOn(LiteGraph, 'createNode').mockImplementation((type) => {
|
||||
if (type === 'LoraLoader') return null
|
||||
const node = createNode(type)
|
||||
nodes.push(node)
|
||||
return node
|
||||
})
|
||||
|
||||
await importA1111(
|
||||
graph,
|
||||
[
|
||||
'<lora:missing:0.5> portrait',
|
||||
'Negative prompt: blurry',
|
||||
'Steps: 20, Sampler: Euler a, Size: 512x512'
|
||||
].join('\n')
|
||||
)
|
||||
|
||||
const textNodes = nodes.filter((node) => node.type === 'CLIPTextEncode')
|
||||
expect(findWidget(textNodes[0], 'text')?.value).toBe(' portrait')
|
||||
})
|
||||
|
||||
it('returns from latent hires setup when LatentUpscale cannot be created', async () => {
|
||||
const graph = createGraph()
|
||||
const nodes: LGraphNode[] = []
|
||||
vi.spyOn(api, 'getEmbeddings').mockResolvedValue([])
|
||||
vi.spyOn(LiteGraph, 'createNode').mockImplementation((type) => {
|
||||
if (type === 'LatentUpscale') return null
|
||||
const node = createNode(type)
|
||||
nodes.push(node)
|
||||
return node
|
||||
})
|
||||
|
||||
await importA1111(
|
||||
graph,
|
||||
[
|
||||
'portrait',
|
||||
'Negative prompt: blurry',
|
||||
'Steps: 8, Sampler: Euler a, Size: 512x512, Hires upscale: 2, Hires upscaler: Latent'
|
||||
].join('\n')
|
||||
)
|
||||
|
||||
expect(nodes.some((node) => node.type === 'KSampler')).toBe(true)
|
||||
expect(nodes.some((node) => node.type === 'LatentUpscale')).toBe(false)
|
||||
})
|
||||
|
||||
it('builds a latent hires pass with explicit resize and denoise settings', async () => {
|
||||
const graph = createGraph()
|
||||
const nodes: LGraphNode[] = []
|
||||
vi.spyOn(console, 'warn').mockImplementation(() => {})
|
||||
vi.spyOn(api, 'getEmbeddings').mockResolvedValue([])
|
||||
vi.spyOn(LiteGraph, 'createNode').mockImplementation((type) => {
|
||||
const node = createNode(type)
|
||||
nodes.push(node)
|
||||
return node
|
||||
})
|
||||
|
||||
await importA1111(
|
||||
graph,
|
||||
[
|
||||
'portrait',
|
||||
'Negative prompt: blurry',
|
||||
'Steps: 12, Sampler: DPM++ 2M Karras, CFG scale: 5, Seed: 1, Size: 513x577, Model: sd15, Hires resize: 1025x1089, Hires steps: 4, Hires upscaler: Latent (nearest-exact), Denoising strength: 0.35'
|
||||
].join('\n')
|
||||
)
|
||||
|
||||
const image = nodes.find((node) => node.type === 'EmptyLatentImage')
|
||||
const latentUpscale = nodes.find((node) => node.type === 'LatentUpscale')
|
||||
const samplers = nodes.filter((node) => node.type === 'KSampler')
|
||||
|
||||
expect(findWidget(image!, 'width')?.value).toBe(576)
|
||||
expect(findWidget(image!, 'height')?.value).toBe(640)
|
||||
expect(findWidget(latentUpscale!, 'upscale_method')?.value).toBe(
|
||||
'nearest-exact'
|
||||
)
|
||||
expect(findWidget(latentUpscale!, 'width')?.value).toBe(1088)
|
||||
expect(findWidget(latentUpscale!, 'height')?.value).toBe(1152)
|
||||
expect(findWidget(samplers[0], 'scheduler')?.value).toBe('karras')
|
||||
expect(findWidget(samplers[0], 'sampler_name')?.value).toBe('dpmpp_2m')
|
||||
expect(findWidget(samplers[1], 'steps')?.value).toBe(4)
|
||||
expect(findWidget(samplers[1], 'cfg')?.value).toBe(5)
|
||||
expect(findWidget(samplers[1], 'scheduler')?.value).toBe('karras')
|
||||
expect(findWidget(samplers[1], 'sampler_name')?.value).toBe('dpmpp_2m')
|
||||
expect(findWidget(samplers[1], 'denoise')?.value).toBe(0.35)
|
||||
})
|
||||
|
||||
it('builds an image upscaler hires pass with fallback steps and denoise', async () => {
|
||||
const graph = createGraph()
|
||||
const nodes: LGraphNode[] = []
|
||||
vi.spyOn(console, 'warn').mockImplementation(() => {})
|
||||
vi.spyOn(api, 'getEmbeddings').mockResolvedValue([])
|
||||
vi.spyOn(LiteGraph, 'createNode').mockImplementation((type) => {
|
||||
const node = createNode(type)
|
||||
nodes.push(node)
|
||||
return node
|
||||
})
|
||||
|
||||
await importA1111(
|
||||
graph,
|
||||
[
|
||||
'portrait',
|
||||
'Negative prompt: blurry',
|
||||
'Steps: 8, Sampler: Euler a, CFG scale: 6, Seed: 2, Size: 512x512, Model: sd15, Hires upscale: 1.5, Hires upscaler: ESRGAN'
|
||||
].join('\n')
|
||||
)
|
||||
|
||||
const upscaleLoader = nodes.find(
|
||||
(node) => node.type === 'UpscaleModelLoader'
|
||||
)
|
||||
const imageScale = nodes.find((node) => node.type === 'ImageScale')
|
||||
const samplers = nodes.filter((node) => node.type === 'KSampler')
|
||||
|
||||
expect(findWidget(upscaleLoader!, 'model_name')?.value).toBe('ESRGAN')
|
||||
expect(findWidget(imageScale!, 'width')?.value).toBe(768)
|
||||
expect(findWidget(imageScale!, 'height')?.value).toBe(768)
|
||||
expect(findWidget(samplers[1], 'steps')?.value).toBe(8)
|
||||
expect(findWidget(samplers[1], 'denoise')?.value).toBe(1)
|
||||
})
|
||||
})
|
||||
|
||||
450
src/scripts/ui.test.ts
Normal file
450
src/scripts/ui.test.ts
Normal file
@@ -0,0 +1,450 @@
|
||||
import { fromAny, fromPartial } from '@total-typescript/shoehorn'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
const { mockSettingStore } = vi.hoisted(() => ({
|
||||
mockSettingStore: {
|
||||
get: vi.fn()
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('@/composables/useRunButtonTelemetry', () => ({
|
||||
useRunButtonTelemetry: () => ({ trackRunButton: vi.fn() })
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/remote/comfyui/jobs/fetchJobs', () => ({
|
||||
extractWorkflow: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/settings/composables/useSettingsDialog', () => ({
|
||||
useSettingsDialog: () => ({ show: vi.fn() })
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/settings/settingStore', () => ({
|
||||
useSettingStore: () => mockSettingStore
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/telemetry', () => ({
|
||||
useTelemetry: () => ({ trackWorkflowExecution: vi.fn() })
|
||||
}))
|
||||
|
||||
vi.mock('@/services/litegraphService', () => ({
|
||||
useLitegraphService: () => ({ resetView: vi.fn() })
|
||||
}))
|
||||
|
||||
vi.mock('@/stores/commandStore', () => ({
|
||||
useCommandStore: () => ({ execute: vi.fn() })
|
||||
}))
|
||||
|
||||
vi.mock('@/stores/workspaceStore', () => ({
|
||||
useWorkspaceStore: () => ({ focusMode: false })
|
||||
}))
|
||||
|
||||
vi.mock('./api', () => ({
|
||||
api: {
|
||||
addEventListener: vi.fn(),
|
||||
clearItems: vi.fn(),
|
||||
deleteItem: vi.fn(),
|
||||
dispatchCustomEvent: vi.fn(),
|
||||
getHistory: vi.fn(),
|
||||
getJobDetail: vi.fn(),
|
||||
getQueue: vi.fn(),
|
||||
interrupt: vi.fn()
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('./app', () => ({
|
||||
app: {
|
||||
clean: vi.fn(),
|
||||
handleFile: vi.fn(),
|
||||
loadGraphData: vi.fn(),
|
||||
openClipspace: vi.fn(),
|
||||
queuePrompt: vi.fn(),
|
||||
refreshComboInNodes: vi.fn(() => Promise.resolve())
|
||||
},
|
||||
ComfyApp: class ComfyApp {}
|
||||
}))
|
||||
|
||||
vi.mock('./ui/dialog', () => ({
|
||||
ComfyDialog: class ComfyDialog {}
|
||||
}))
|
||||
|
||||
vi.mock('./ui/settings', () => ({
|
||||
ComfySettingsDialog: class ComfySettingsDialog {}
|
||||
}))
|
||||
|
||||
vi.mock('./ui/toggleSwitch', () => ({
|
||||
toggleSwitch: () => document.createElement('div')
|
||||
}))
|
||||
|
||||
import { extractWorkflow } from '@/platform/remote/comfyui/jobs/fetchJobs'
|
||||
|
||||
import { api } from './api'
|
||||
import { app } from './app'
|
||||
import { $el, ComfyUI } from './ui'
|
||||
|
||||
beforeEach(() => {
|
||||
document.body.replaceChildren()
|
||||
localStorage.clear()
|
||||
vi.clearAllMocks()
|
||||
mockSettingStore.get.mockReturnValue(undefined)
|
||||
Object.assign(app, {
|
||||
lastExecutionError: undefined,
|
||||
nodeOutputs: undefined
|
||||
})
|
||||
})
|
||||
|
||||
async function click(button: HTMLButtonElement) {
|
||||
const handler = button.onclick
|
||||
expect(handler).toBeTypeOf('function')
|
||||
await Promise.resolve(handler?.call(button, fromAny(new MouseEvent('click'))))
|
||||
}
|
||||
|
||||
function buttonByText(root: ParentNode, text: string): HTMLButtonElement {
|
||||
const button = [...root.querySelectorAll('button')].find(
|
||||
(candidate) => candidate.textContent === text
|
||||
)
|
||||
if (!button) throw new Error(`Missing button: ${text}`)
|
||||
return button
|
||||
}
|
||||
|
||||
describe('$el', () => {
|
||||
it('creates elements with classes, children, props, and callbacks', () => {
|
||||
const parent = document.createElement('section')
|
||||
const child = document.createElement('span')
|
||||
const callback = vi.fn()
|
||||
|
||||
const element = $el(
|
||||
'label.primary.secondary',
|
||||
{
|
||||
parent,
|
||||
$: callback,
|
||||
dataset: { role: 'name' },
|
||||
for: 'target-input',
|
||||
style: { display: 'block' },
|
||||
title: 'Label'
|
||||
},
|
||||
child
|
||||
)
|
||||
|
||||
expect(element.tagName).toBe('LABEL')
|
||||
expect(element.classList.contains('primary')).toBe(true)
|
||||
expect(element.classList.contains('secondary')).toBe(true)
|
||||
expect(element.dataset.role).toBe('name')
|
||||
expect(element.getAttribute('for')).toBe('target-input')
|
||||
expect(element.style.display).toBe('block')
|
||||
expect(element.title).toBe('Label')
|
||||
expect(element.firstElementChild).toBe(child)
|
||||
expect(parent.firstElementChild).toBe(element)
|
||||
expect(callback).toHaveBeenCalledWith(element)
|
||||
})
|
||||
|
||||
it('accepts string and single-element children shorthands', () => {
|
||||
const textElement = $el('button', 'Run')
|
||||
const child = document.createElement('strong')
|
||||
const wrapper = $el('div', child)
|
||||
|
||||
expect(textElement.textContent).toBe('Run')
|
||||
expect(wrapper.firstElementChild).toBe(child)
|
||||
})
|
||||
})
|
||||
|
||||
describe('ComfyUI legacy menu', () => {
|
||||
it('loads queue items and runs list actions', async () => {
|
||||
vi.mocked(api.getQueue).mockResolvedValue({
|
||||
Running: [{ id: 'running', priority: 1 }],
|
||||
Pending: [{ id: 'pending', priority: 2 }]
|
||||
} as never)
|
||||
vi.mocked(api.getJobDetail).mockResolvedValue({
|
||||
outputs: { node: { images: ['image.png'] } }
|
||||
} as never)
|
||||
vi.mocked(extractWorkflow).mockResolvedValue({ nodes: [] } as never)
|
||||
const ui = new ComfyUI(app)
|
||||
|
||||
await ui.queue.show()
|
||||
await click(buttonByText(ui.queue.element, 'Load'))
|
||||
|
||||
expect(api.getJobDetail).toHaveBeenCalledWith('running')
|
||||
expect(extractWorkflow).toHaveBeenCalled()
|
||||
expect(app.loadGraphData).toHaveBeenCalledWith({ nodes: [] }, true, false)
|
||||
expect(app.nodeOutputs).toEqual({ node: { images: ['image.png'] } })
|
||||
|
||||
await click(buttonByText(ui.queue.element, 'Cancel'))
|
||||
expect(api.interrupt).toHaveBeenCalledWith('running')
|
||||
|
||||
await click(buttonByText(ui.queue.element, 'Delete'))
|
||||
expect(api.deleteItem).toHaveBeenCalledWith('queue', 'pending')
|
||||
|
||||
await click(buttonByText(ui.queue.element, 'Clear Queue'))
|
||||
expect(api.clearItems).toHaveBeenCalledWith('queue')
|
||||
|
||||
await click(buttonByText(ui.queue.element, 'Refresh'))
|
||||
expect(api.getQueue).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('skips loading queue items when job details are unavailable', async () => {
|
||||
vi.mocked(api.getQueue).mockResolvedValue({
|
||||
Running: [{ id: 'running', priority: 1 }],
|
||||
Pending: []
|
||||
} as never)
|
||||
vi.mocked(api.getJobDetail).mockResolvedValue(null as never)
|
||||
const ui = new ComfyUI(app)
|
||||
|
||||
await ui.queue.show()
|
||||
await click(buttonByText(ui.queue.element, 'Load'))
|
||||
|
||||
expect(extractWorkflow).not.toHaveBeenCalled()
|
||||
expect(app.loadGraphData).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('loads queue item workflows without outputs', async () => {
|
||||
vi.mocked(api.getQueue).mockResolvedValue({
|
||||
Running: [{ id: 'running', priority: 1 }],
|
||||
Pending: []
|
||||
} as never)
|
||||
vi.mocked(api.getJobDetail).mockResolvedValue({ id: 'running' } as never)
|
||||
vi.mocked(extractWorkflow).mockResolvedValue({ nodes: [] } as never)
|
||||
const ui = new ComfyUI(app)
|
||||
|
||||
await ui.queue.show()
|
||||
await click(buttonByText(ui.queue.element, 'Load'))
|
||||
|
||||
expect(app.loadGraphData).toHaveBeenCalledWith({ nodes: [] }, true, false)
|
||||
expect(app.nodeOutputs).toBeUndefined()
|
||||
})
|
||||
|
||||
it('loads history in reverse order', async () => {
|
||||
vi.mocked(api.getHistory).mockResolvedValue([
|
||||
{ id: 'old', priority: 1 },
|
||||
{ id: 'new', priority: 2 }
|
||||
] as never)
|
||||
const ui = new ComfyUI(app)
|
||||
|
||||
await ui.history.show()
|
||||
|
||||
expect(ui.history.element.textContent).toContain('2: LoadDelete')
|
||||
expect(ui.history.element.textContent?.indexOf('2:')).toBeLessThan(
|
||||
ui.history.element.textContent?.indexOf('1:') ?? Number.MAX_SAFE_INTEGER
|
||||
)
|
||||
})
|
||||
|
||||
it('updates queue status and auto-queues when enabled', () => {
|
||||
const ui = new ComfyUI(app)
|
||||
ui.autoQueueEnabled = true
|
||||
ui.autoQueueMode = 'instant'
|
||||
ui.lastQueueSize = 1
|
||||
ui.batchCount = 3
|
||||
|
||||
ui.setStatus({ exec_info: { queue_remaining: 0 } })
|
||||
|
||||
expect(app.queuePrompt).toHaveBeenCalledWith(0, 3)
|
||||
expect(ui.queueSize.textContent).toBe('Queue size: 0')
|
||||
expect(ui.lastQueueSize).toBe(3)
|
||||
|
||||
ui.setStatus(null)
|
||||
expect(ui.queueSize.textContent).toBe('Queue size: ERR')
|
||||
})
|
||||
|
||||
it('does not auto-queue while a prior execution error is present', () => {
|
||||
const ui = new ComfyUI(app)
|
||||
ui.autoQueueEnabled = true
|
||||
ui.autoQueueMode = 'instant'
|
||||
ui.lastQueueSize = 1
|
||||
Object.assign(app, { lastExecutionError: new Error('failed') })
|
||||
|
||||
ui.setStatus({ exec_info: { queue_remaining: 0 } })
|
||||
|
||||
expect(app.queuePrompt).not.toHaveBeenCalled()
|
||||
expect(ui.lastQueueSize).toBe(0)
|
||||
})
|
||||
|
||||
it('tracks graph changes for change-mode auto queueing', () => {
|
||||
const ui = new ComfyUI(app)
|
||||
const graphChanged = vi
|
||||
.mocked(api.addEventListener)
|
||||
.mock.calls.find(([eventName]) => eventName === 'graphChanged')?.[1]
|
||||
if (!graphChanged) throw new Error('Missing graphChanged listener')
|
||||
|
||||
ui.autoQueueEnabled = true
|
||||
ui.autoQueueMode = 'change'
|
||||
ui.lastQueueSize = 1
|
||||
graphChanged(fromAny(new CustomEvent('graphChanged')))
|
||||
|
||||
expect(ui.graphHasChanged).toBe(true)
|
||||
|
||||
ui.lastQueueSize = 0
|
||||
graphChanged(fromAny(new CustomEvent('graphChanged')))
|
||||
|
||||
expect(app.queuePrompt).toHaveBeenCalledWith(0, 1)
|
||||
expect(ui.graphHasChanged).toBe(false)
|
||||
})
|
||||
|
||||
it('wires primary menu buttons to app and command actions', async () => {
|
||||
const ui = new ComfyUI(app)
|
||||
|
||||
await click(buttonByText(document, 'Queue Prompt'))
|
||||
expect(app.queuePrompt).toHaveBeenCalledWith(0, 1)
|
||||
|
||||
await click(buttonByText(document, 'Queue Front'))
|
||||
expect(app.queuePrompt).toHaveBeenCalledWith(-1, 1)
|
||||
|
||||
await click(buttonByText(document, 'Save'))
|
||||
await click(buttonByText(document, 'Save (API Format)'))
|
||||
await click(buttonByText(document, 'Refresh'))
|
||||
await click(buttonByText(document, 'Clipspace'))
|
||||
await click(buttonByText(document, 'Clear'))
|
||||
await click(buttonByText(document, 'Load Default'))
|
||||
await click(buttonByText(document, 'Reset View'))
|
||||
|
||||
expect(app.refreshComboInNodes).toHaveBeenCalled()
|
||||
expect(app.openClipspace).toHaveBeenCalled()
|
||||
expect(app.clean).toHaveBeenCalled()
|
||||
expect(app.loadGraphData).toHaveBeenCalledWith()
|
||||
expect(ui.menuContainer.style.display).toBe('none')
|
||||
})
|
||||
|
||||
it('wires file input and legacy option controls', async () => {
|
||||
const ui = new ComfyUI(app)
|
||||
const file = new File(['{}'], 'workflow.json', {
|
||||
type: 'application/json'
|
||||
})
|
||||
const fileInput = document.getElementById(
|
||||
'comfy-file-input'
|
||||
) as HTMLInputElement
|
||||
Object.defineProperty(fileInput, 'files', {
|
||||
configurable: true,
|
||||
value: [file]
|
||||
})
|
||||
|
||||
await Promise.resolve(
|
||||
fileInput.onchange?.call(fileInput, new Event('change'))
|
||||
)
|
||||
expect(app.handleFile).toHaveBeenCalledWith(file, 'file_button')
|
||||
expect(fileInput.value).toBe('')
|
||||
|
||||
const range = document.getElementById(
|
||||
'batchCountInputRange'
|
||||
) as HTMLInputElement
|
||||
const number = document.getElementById(
|
||||
'batchCountInputNumber'
|
||||
) as HTMLInputElement
|
||||
range.value = '4'
|
||||
const extraOptionsCheckbox = document.querySelector(
|
||||
'label input[type="checkbox"]'
|
||||
) as HTMLInputElement
|
||||
extraOptionsCheckbox.checked = true
|
||||
extraOptionsCheckbox.onchange?.call(
|
||||
extraOptionsCheckbox,
|
||||
fromPartial({ srcElement: extraOptionsCheckbox })
|
||||
)
|
||||
expect(ui.batchCount).toBe(4)
|
||||
expect(document.getElementById('extraOptions')?.style.display).toBe('block')
|
||||
|
||||
number.value = '7'
|
||||
number.oninput?.call(number, fromPartial({ target: number }))
|
||||
expect(range.value).toBe('7')
|
||||
|
||||
range.value = '9'
|
||||
range.oninput?.call(range, fromPartial({ srcElement: range }))
|
||||
expect(number.value).toBe('9')
|
||||
|
||||
const autoQueueCheckbox = document.getElementById(
|
||||
'autoQueueCheckbox'
|
||||
) as HTMLInputElement
|
||||
autoQueueCheckbox.checked = true
|
||||
autoQueueCheckbox.onchange?.call(
|
||||
autoQueueCheckbox,
|
||||
fromPartial({ target: autoQueueCheckbox })
|
||||
)
|
||||
expect(ui.autoQueueEnabled).toBe(true)
|
||||
|
||||
extraOptionsCheckbox.checked = false
|
||||
extraOptionsCheckbox.onchange?.call(
|
||||
extraOptionsCheckbox,
|
||||
fromPartial({ srcElement: extraOptionsCheckbox })
|
||||
)
|
||||
expect(ui.batchCount).toBe(1)
|
||||
expect(ui.autoQueueEnabled).toBe(false)
|
||||
expect(document.getElementById('extraOptions')?.style.display).toBe('none')
|
||||
})
|
||||
|
||||
it('toggles queue visibility through the menu button', async () => {
|
||||
vi.mocked(api.getQueue).mockResolvedValue({
|
||||
Running: [],
|
||||
Pending: []
|
||||
} as never)
|
||||
const ui = new ComfyUI(app)
|
||||
|
||||
await click(buttonByText(document, 'View Queue'))
|
||||
expect(ui.queue.visible).toBe(true)
|
||||
|
||||
await click(buttonByText(document, 'Close'))
|
||||
expect(ui.queue.visible).toBe(false)
|
||||
})
|
||||
|
||||
it('does not clear or load defaults when confirmation is declined', async () => {
|
||||
mockSettingStore.get.mockReturnValue(true)
|
||||
vi.stubGlobal(
|
||||
'confirm',
|
||||
vi.fn(() => false)
|
||||
)
|
||||
try {
|
||||
new ComfyUI(app)
|
||||
|
||||
await click(buttonByText(document, 'Clear'))
|
||||
await click(buttonByText(document, 'Load Default'))
|
||||
|
||||
expect(app.clean).not.toHaveBeenCalled()
|
||||
expect(app.loadGraphData).not.toHaveBeenCalled()
|
||||
} finally {
|
||||
vi.unstubAllGlobals()
|
||||
}
|
||||
})
|
||||
|
||||
it('persists manual menu dragging', () => {
|
||||
Object.defineProperty(document.body, 'clientWidth', {
|
||||
configurable: true,
|
||||
value: 1000
|
||||
})
|
||||
Object.defineProperty(document.body, 'clientHeight', {
|
||||
configurable: true,
|
||||
value: 800
|
||||
})
|
||||
const ui = new ComfyUI(app)
|
||||
ui.menuContainer.style.display = 'block'
|
||||
Object.defineProperty(ui.menuContainer, 'clientWidth', {
|
||||
configurable: true,
|
||||
value: 100
|
||||
})
|
||||
Object.defineProperty(ui.menuContainer, 'clientHeight', {
|
||||
configurable: true,
|
||||
value: 80
|
||||
})
|
||||
Object.defineProperty(ui.menuContainer, 'offsetLeft', {
|
||||
configurable: true,
|
||||
get: () => 700
|
||||
})
|
||||
Object.defineProperty(ui.menuContainer, 'offsetTop', {
|
||||
configurable: true,
|
||||
get: () => 20
|
||||
})
|
||||
const handle = ui.menuContainer.querySelector('.drag-handle') as HTMLElement
|
||||
|
||||
handle.onmousedown?.(
|
||||
new MouseEvent('mousedown', { clientX: 10, clientY: 10 })
|
||||
)
|
||||
document.onmousemove?.(
|
||||
new MouseEvent('mousemove', { clientX: 20, clientY: 30 })
|
||||
)
|
||||
document.onmouseup?.(new MouseEvent('mouseup'))
|
||||
|
||||
expect(ui.menuContainer.classList.contains('comfy-menu-manual-pos')).toBe(
|
||||
true
|
||||
)
|
||||
expect(ui.menuContainer.style.right).toBe('190px')
|
||||
expect(localStorage.getItem('Comfy.MenuPosition')).toBe(
|
||||
JSON.stringify({ x: 700, y: 20 })
|
||||
)
|
||||
expect(document.onmousemove).toBeNull()
|
||||
expect(document.onmouseup).toBeNull()
|
||||
})
|
||||
})
|
||||
199
src/scripts/ui/components/button.test.ts
Normal file
199
src/scripts/ui/components/button.test.ts
Normal file
@@ -0,0 +1,199 @@
|
||||
import { fromAny } from '@total-typescript/shoehorn'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import type { ComfyApp } from '@/scripts/app'
|
||||
|
||||
vi.mock('../../ui', () => ({
|
||||
$el: (tag: string, props?: Record<string, unknown>, children?: Node[]) => {
|
||||
const [tagName, ...classes] = tag.split('.')
|
||||
const element = document.createElement(tagName)
|
||||
if (classes.length) element.classList.add(...classes)
|
||||
if (props) {
|
||||
const listeners = Object.entries(props).filter(([key]) =>
|
||||
key.startsWith('on')
|
||||
)
|
||||
for (const [key, listener] of listeners) {
|
||||
if (typeof listener === 'function') {
|
||||
element.addEventListener(key.slice(2), listener as EventListener)
|
||||
}
|
||||
}
|
||||
}
|
||||
if (children) element.append(...children)
|
||||
return element
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('../../utils', () => ({
|
||||
prop: <T>(
|
||||
target: object,
|
||||
name: string,
|
||||
defaultValue: T,
|
||||
onChanged?: (currentValue: T, previousValue: T) => void
|
||||
) => {
|
||||
let currentValue: T
|
||||
Object.defineProperty(target, name, {
|
||||
get() {
|
||||
return currentValue
|
||||
},
|
||||
set(newValue: T) {
|
||||
const previousValue = currentValue
|
||||
currentValue = newValue
|
||||
onChanged?.(currentValue, previousValue)
|
||||
}
|
||||
})
|
||||
;(target as Record<string, T>)[name] = defaultValue
|
||||
return defaultValue
|
||||
}
|
||||
}))
|
||||
|
||||
import { ComfyButton } from './button'
|
||||
|
||||
class MockPopup extends EventTarget {
|
||||
element = document.createElement('div')
|
||||
open = false
|
||||
toggle = vi.fn(() => {
|
||||
this.open = !this.open
|
||||
this.dispatchEvent(new CustomEvent('change'))
|
||||
})
|
||||
}
|
||||
|
||||
function mockApp(settingValue: boolean) {
|
||||
let listener: (() => void) | undefined
|
||||
const app = {
|
||||
ui: {
|
||||
settings: {
|
||||
getSettingValue: vi.fn(() => settingValue),
|
||||
addEventListener: vi.fn((_event: string, callback: () => void) => {
|
||||
listener = callback
|
||||
})
|
||||
}
|
||||
}
|
||||
} as unknown as ComfyApp
|
||||
return {
|
||||
app,
|
||||
setSettingValue(value: boolean) {
|
||||
settingValue = value
|
||||
listener?.()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
describe('ComfyButton', () => {
|
||||
beforeEach(() => {
|
||||
document.body.replaceChildren()
|
||||
})
|
||||
|
||||
it('renders icon, content, tooltip, enabled state, and click action', () => {
|
||||
const action = vi.fn()
|
||||
const button = new ComfyButton({
|
||||
icon: 'play',
|
||||
overIcon: 'pause',
|
||||
iconSize: 18,
|
||||
content: 'Run',
|
||||
tooltip: 'Queue prompt',
|
||||
enabled: false,
|
||||
classList: { primary: true, hiddenClass: false },
|
||||
action
|
||||
})
|
||||
|
||||
expect(button.iconElement.className).toBe('mdi mdi-play mdi-18px')
|
||||
expect(button.contentElement.textContent).toBe('Run')
|
||||
expect(button.element.title).toBe('Queue prompt')
|
||||
expect(button.element.getAttribute('aria-label')).toBe('Queue prompt')
|
||||
expect(button.element.classList.contains('primary')).toBe(true)
|
||||
expect(button.element.classList.contains('disabled')).toBe(true)
|
||||
expect((button.element as HTMLButtonElement).disabled).toBe(true)
|
||||
|
||||
button.enabled = true
|
||||
button.element.dispatchEvent(new MouseEvent('mouseenter'))
|
||||
expect(button.iconElement.className).toBe('mdi mdi-pause mdi-18px')
|
||||
button.element.dispatchEvent(new MouseEvent('mouseleave'))
|
||||
expect(button.iconElement.className).toBe('mdi mdi-play mdi-18px')
|
||||
|
||||
button.element.dispatchEvent(new MouseEvent('click'))
|
||||
expect(action).toHaveBeenCalledWith(expect.any(MouseEvent), button)
|
||||
})
|
||||
|
||||
it('supports HTMLElement content and removing tooltip text', () => {
|
||||
const button = new ComfyButton({ content: 'Text', tooltip: 'Hint' })
|
||||
const content = document.createElement('strong')
|
||||
content.textContent = 'Element'
|
||||
|
||||
button.content = content
|
||||
button.tooltip = ''
|
||||
|
||||
expect(button.contentElement.firstElementChild).toBe(content)
|
||||
expect(button.element.hasAttribute('title')).toBe(false)
|
||||
})
|
||||
|
||||
it('updates the hover icon when overIcon changes while hovered', () => {
|
||||
const button = new ComfyButton({ icon: 'play' })
|
||||
|
||||
button.element.dispatchEvent(new MouseEvent('mouseenter'))
|
||||
button.overIcon = 'pause'
|
||||
|
||||
expect(button.iconElement.className).toBe('mdi mdi-pause')
|
||||
})
|
||||
|
||||
it('hides and shows from a visibility setting', () => {
|
||||
const settings = mockApp(false)
|
||||
const button = new ComfyButton({
|
||||
app: settings.app,
|
||||
visibilitySetting: {
|
||||
id: 'Comfy.UseNewMenu',
|
||||
showValue: true
|
||||
}
|
||||
})
|
||||
|
||||
expect(button.hidden).toBe(true)
|
||||
expect(button.element.classList.contains('hidden')).toBe(true)
|
||||
|
||||
settings.setSettingValue(true)
|
||||
|
||||
expect(button.hidden).toBe(false)
|
||||
expect(button.element.classList.contains('hidden')).toBe(false)
|
||||
})
|
||||
|
||||
it('toggles click popups and reflects popup open state in classes', () => {
|
||||
const popup = new MockPopup()
|
||||
const button = new ComfyButton({ icon: 'dots' }).withPopup(fromAny(popup))
|
||||
|
||||
button.element.dispatchEvent(new MouseEvent('click'))
|
||||
|
||||
expect(popup.toggle).toHaveBeenCalledOnce()
|
||||
expect(button.element.classList.contains('popup-open')).toBe(true)
|
||||
|
||||
popup.toggle()
|
||||
|
||||
expect(button.element.classList.contains('popup-closed')).toBe(true)
|
||||
})
|
||||
|
||||
it('opens hover popups while either the button or popup is hovered', () => {
|
||||
const popup = new MockPopup()
|
||||
const button = new ComfyButton({ icon: 'dots' }).withPopup(
|
||||
fromAny(popup),
|
||||
'hover'
|
||||
)
|
||||
|
||||
button.element.dispatchEvent(new MouseEvent('mouseenter'))
|
||||
expect(popup.open).toBe(true)
|
||||
popup.element.dispatchEvent(new MouseEvent('mouseenter'))
|
||||
button.element.dispatchEvent(new MouseEvent('mouseleave'))
|
||||
expect(popup.open).toBe(true)
|
||||
popup.element.dispatchEvent(new MouseEvent('mouseleave'))
|
||||
expect(popup.open).toBe(false)
|
||||
})
|
||||
|
||||
it('does not click-toggle a hover popup while hovered', () => {
|
||||
const popup = new MockPopup()
|
||||
const button = new ComfyButton({ icon: 'dots' }).withPopup(
|
||||
fromAny(popup),
|
||||
'hover'
|
||||
)
|
||||
|
||||
button.element.dispatchEvent(new MouseEvent('mouseenter'))
|
||||
button.element.dispatchEvent(new MouseEvent('click'))
|
||||
|
||||
expect(popup.toggle).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
284
src/scripts/ui/components/popup.test.ts
Normal file
284
src/scripts/ui/components/popup.test.ts
Normal file
@@ -0,0 +1,284 @@
|
||||
import { fromAny } from '@total-typescript/shoehorn'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
type ElChild = Node | string
|
||||
type ElInput = Record<string, unknown> | ElChild | ElChild[]
|
||||
|
||||
function appendChildren(element: HTMLElement, children: ElChild | ElChild[]) {
|
||||
const list = Array.isArray(children) ? children : [children]
|
||||
for (const child of list) {
|
||||
element.append(child)
|
||||
}
|
||||
}
|
||||
|
||||
vi.mock('../../ui', () => ({
|
||||
$el: (tag: string, propsOrChildren?: ElInput, children?: ElChild[]) => {
|
||||
const [tagName, ...classes] = tag.split('.')
|
||||
const element = document.createElement(tagName)
|
||||
if (classes.length) element.classList.add(...classes)
|
||||
|
||||
if (
|
||||
propsOrChildren instanceof Node ||
|
||||
typeof propsOrChildren === 'string' ||
|
||||
Array.isArray(propsOrChildren)
|
||||
) {
|
||||
appendChildren(element, propsOrChildren)
|
||||
} else if (propsOrChildren) {
|
||||
for (const [key, value] of Object.entries(propsOrChildren)) {
|
||||
if (key === '$' && typeof value === 'function') {
|
||||
value(element)
|
||||
} else if (key === 'parent' && value instanceof HTMLElement) {
|
||||
value.append(element)
|
||||
} else if (key === 'textContent') {
|
||||
element.textContent = String(value)
|
||||
} else if (key === 'ariaLabel') {
|
||||
element.setAttribute('aria-label', String(value))
|
||||
} else if (key === 'ariaHasPopup') {
|
||||
element.setAttribute('aria-haspopup', String(value))
|
||||
} else if (key === 'type') {
|
||||
element.setAttribute('type', String(value))
|
||||
} else if (
|
||||
key.toLowerCase().startsWith('on') &&
|
||||
typeof value === 'function'
|
||||
) {
|
||||
element.addEventListener(
|
||||
key.slice(2).toLowerCase(),
|
||||
value as EventListener
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (children) appendChildren(element, children)
|
||||
return element
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('../../utils', () => ({
|
||||
prop: <T>(
|
||||
target: object,
|
||||
name: string,
|
||||
defaultValue: T,
|
||||
onChanged?: (currentValue: T, previousValue: T) => void
|
||||
) => {
|
||||
let currentValue: T
|
||||
Object.defineProperty(target, name, {
|
||||
get() {
|
||||
return currentValue
|
||||
},
|
||||
set(newValue: T) {
|
||||
const previousValue = currentValue
|
||||
currentValue = newValue
|
||||
onChanged?.(currentValue, previousValue)
|
||||
}
|
||||
})
|
||||
;(target as Record<string, T>)[name] = defaultValue
|
||||
return defaultValue
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('../utils', () => ({
|
||||
applyClasses: (
|
||||
element: HTMLElement,
|
||||
classList: string | Record<string, boolean>,
|
||||
...baseClasses: string[]
|
||||
) => {
|
||||
element.className = baseClasses.join(' ')
|
||||
if (typeof classList === 'string') {
|
||||
element.classList.add(...classList.split(' ').filter(Boolean))
|
||||
} else {
|
||||
for (const [className, enabled] of Object.entries(classList)) {
|
||||
element.classList.toggle(className, enabled)
|
||||
}
|
||||
}
|
||||
},
|
||||
toggleElement:
|
||||
<T>(
|
||||
element: HTMLElement,
|
||||
{
|
||||
onShow
|
||||
}: {
|
||||
onShow?: (element: HTMLElement, value: T) => void
|
||||
} = {}
|
||||
) =>
|
||||
(value: T) => {
|
||||
element.hidden = !value
|
||||
if (value) onShow?.(element, value)
|
||||
}
|
||||
}))
|
||||
|
||||
import { ComfyAsyncDialog } from './asyncDialog'
|
||||
import { ComfyButton } from './button'
|
||||
import { ComfyButtonGroup } from './buttonGroup'
|
||||
import { ComfyPopup } from './popup'
|
||||
import { ComfySplitButton } from './splitButton'
|
||||
|
||||
function targetWithRect(rect: DOMRect) {
|
||||
const target = document.createElement('button')
|
||||
vi.spyOn(target, 'getBoundingClientRect').mockReturnValue(rect)
|
||||
document.body.append(target)
|
||||
return target
|
||||
}
|
||||
|
||||
describe('ComfyPopup and related UI components', () => {
|
||||
beforeEach(() => {
|
||||
document.body.replaceChildren()
|
||||
vi.restoreAllMocks()
|
||||
})
|
||||
|
||||
it('opens, positions, updates children and classes, and closes from escape', () => {
|
||||
const target = targetWithRect(
|
||||
DOMRect.fromRect({ x: 10, y: 20, width: 80, height: 30 })
|
||||
)
|
||||
const child = document.createElement('span')
|
||||
const popup = new ComfyPopup(
|
||||
{ target, classList: { menu: true, hidden: false } },
|
||||
child
|
||||
)
|
||||
const open = vi.fn()
|
||||
const close = vi.fn()
|
||||
const change = vi.fn()
|
||||
popup.addEventListener('open', open)
|
||||
popup.addEventListener('close', close)
|
||||
popup.addEventListener('change', change)
|
||||
vi.spyOn(popup.element, 'getBoundingClientRect').mockReturnValue(
|
||||
DOMRect.fromRect({ height: 20 })
|
||||
)
|
||||
|
||||
popup.open = true
|
||||
|
||||
expect(open).toHaveBeenCalledOnce()
|
||||
expect(change).toHaveBeenCalledOnce()
|
||||
expect(popup.element).toHaveClass('open')
|
||||
expect(popup.element.style.getPropertyValue('--left')).toBe('10px')
|
||||
expect(popup.element.style.getPropertyValue('--bottom')).toBe('35px')
|
||||
|
||||
const nextChild = document.createElement('strong')
|
||||
popup.children = [nextChild]
|
||||
popup.classList = 'extra'
|
||||
|
||||
expect(popup.element.firstElementChild).toBe(nextChild)
|
||||
expect(popup.element).toHaveClass('comfyui-popup', 'left', 'extra')
|
||||
|
||||
window.dispatchEvent(new KeyboardEvent('keydown', { key: 'Enter' }))
|
||||
expect(popup.open).toBe(true)
|
||||
|
||||
const escape = new KeyboardEvent('keydown', {
|
||||
key: 'Escape',
|
||||
cancelable: true
|
||||
})
|
||||
window.dispatchEvent(escape)
|
||||
|
||||
expect(close).toHaveBeenCalledOnce()
|
||||
expect(escape.defaultPrevented).toBe(true)
|
||||
expect(popup.open).toBe(false)
|
||||
})
|
||||
|
||||
it('handles outside clicks, target clicks, and relative right positioning', () => {
|
||||
const target = targetWithRect(
|
||||
DOMRect.fromRect({ x: 100, y: 40, width: 60, height: 25 })
|
||||
)
|
||||
const container = document.createElement('section')
|
||||
document.body.append(container)
|
||||
const popup = new ComfyPopup({
|
||||
target,
|
||||
container,
|
||||
position: 'relative',
|
||||
horizontal: 'right'
|
||||
})
|
||||
vi.spyOn(popup.element, 'getBoundingClientRect').mockReturnValue(
|
||||
DOMRect.fromRect({ height: 100 })
|
||||
)
|
||||
Object.defineProperty(popup.element, 'clientWidth', {
|
||||
configurable: true,
|
||||
value: 40
|
||||
})
|
||||
|
||||
popup.open = true
|
||||
target.dispatchEvent(new MouseEvent('click', { bubbles: true }))
|
||||
expect(popup.open).toBe(true)
|
||||
|
||||
const outside = document.createElement('div')
|
||||
document.body.append(outside)
|
||||
outside.dispatchEvent(new MouseEvent('click', { bubbles: true }))
|
||||
|
||||
expect(popup.open).toBe(false)
|
||||
expect(popup.element.style.getPropertyValue('--left')).toBe('0px')
|
||||
expect(popup.element.style.getPropertyValue('--top')).toBe('25px')
|
||||
})
|
||||
|
||||
it('keeps outside clicks open when target clicks are not ignored', () => {
|
||||
const target = targetWithRect(DOMRect.fromRect({ height: 10 }))
|
||||
const popup = new ComfyPopup({
|
||||
target,
|
||||
ignoreTarget: false,
|
||||
closeOnEscape: false
|
||||
})
|
||||
const outside = document.createElement('div')
|
||||
document.body.append(outside)
|
||||
|
||||
popup.toggle()
|
||||
outside.dispatchEvent(new MouseEvent('click', { bubbles: true }))
|
||||
window.dispatchEvent(new KeyboardEvent('keydown', { key: 'Escape' }))
|
||||
|
||||
expect(popup.open).toBe(true)
|
||||
})
|
||||
|
||||
it('renders split button popup items and updates button groups', () => {
|
||||
const primary = new ComfyButton({ content: 'Queue' })
|
||||
const itemButton = new ComfyButton({ content: 'Queue front' })
|
||||
const rawItem = document.createElement('button')
|
||||
rawItem.textContent = 'Queue back'
|
||||
const split = new ComfySplitButton(
|
||||
{
|
||||
primary,
|
||||
mode: 'hover',
|
||||
horizontal: 'right',
|
||||
position: 'absolute'
|
||||
},
|
||||
itemButton,
|
||||
rawItem
|
||||
)
|
||||
|
||||
expect(split.element).toHaveClass('comfyui-split-button', 'hover')
|
||||
expect(split.popup.element).toHaveTextContent('Queue front')
|
||||
expect(split.popup.element).toHaveTextContent('Queue back')
|
||||
expect(document.body).toContainElement(split.popup.element)
|
||||
|
||||
const group = new ComfyButtonGroup(primary)
|
||||
group.append(itemButton)
|
||||
group.insert(fromAny(rawItem), 1)
|
||||
|
||||
expect(group.element.children).toHaveLength(3)
|
||||
expect(group.remove(itemButton)).toEqual([itemButton])
|
||||
expect(group.remove(itemButton)).toBeUndefined()
|
||||
expect(group.element.children).toHaveLength(2)
|
||||
})
|
||||
|
||||
it('resolves async dialogs from buttons, close events, and prompt actions', async () => {
|
||||
const dialog = new ComfyAsyncDialog<number>([
|
||||
{ text: 'Seven', value: 7 },
|
||||
'Fallback'
|
||||
])
|
||||
|
||||
const promise = dialog.show('Pick one')
|
||||
dialog.element.querySelector<HTMLButtonElement>('button')?.click()
|
||||
await expect(promise).resolves.toBe(7)
|
||||
|
||||
const closePromise = dialog.show(document.createElement('em'))
|
||||
dialog.element.dispatchEvent(new Event('close'))
|
||||
await expect(closePromise).resolves.toBeNull()
|
||||
|
||||
const promptPromise = ComfyAsyncDialog.prompt({
|
||||
title: 'Confirm',
|
||||
message: 'Continue?',
|
||||
actions: [{ text: 'Yes', value: 'yes' }]
|
||||
})
|
||||
const prompt = Array.from(document.querySelectorAll('dialog')).at(-1)
|
||||
expect(prompt).toHaveTextContent('Confirm')
|
||||
prompt?.querySelector<HTMLButtonElement>('button')?.click()
|
||||
|
||||
await expect(promptPromise).resolves.toBe('yes')
|
||||
expect(document.body).not.toContainElement(prompt ?? null)
|
||||
})
|
||||
})
|
||||
69
src/scripts/ui/dialog.test.ts
Normal file
69
src/scripts/ui/dialog.test.ts
Normal file
@@ -0,0 +1,69 @@
|
||||
import { afterEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { ComfyDialog } from './dialog'
|
||||
|
||||
const mocks = vi.hoisted(() => ({
|
||||
$el: (
|
||||
selector: string,
|
||||
propsOrChildren?: Record<string, unknown> | Node[],
|
||||
maybeChildren?: Node[]
|
||||
) => {
|
||||
const [tag, ...classes] = selector.split('.')
|
||||
const element = document.createElement(tag || 'div')
|
||||
element.classList.add(...classes.filter(Boolean))
|
||||
const children = Array.isArray(propsOrChildren)
|
||||
? propsOrChildren
|
||||
: maybeChildren
|
||||
|
||||
if (propsOrChildren && !Array.isArray(propsOrChildren)) {
|
||||
for (const [key, value] of Object.entries(propsOrChildren)) {
|
||||
if (key === 'parent' && value instanceof Node) {
|
||||
value.appendChild(element)
|
||||
} else if (key === '$' && typeof value === 'function') {
|
||||
value(element)
|
||||
} else {
|
||||
Reflect.set(element, key, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
element.append(...(children ?? []))
|
||||
return element
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('../ui', () => ({
|
||||
$el: mocks.$el
|
||||
}))
|
||||
|
||||
describe('ComfyDialog', () => {
|
||||
afterEach(() => {
|
||||
document.body.replaceChildren()
|
||||
})
|
||||
|
||||
it('shows string and element content and closes through the default button', () => {
|
||||
const dialog = new ComfyDialog()
|
||||
|
||||
dialog.show('<strong>Hello</strong>')
|
||||
|
||||
expect(dialog.element.style.display).toBe('flex')
|
||||
expect(dialog.textElement.innerHTML).toBe('<strong>Hello</strong>')
|
||||
|
||||
dialog.element.querySelector('button')?.click()
|
||||
expect(dialog.element.style.display).toBe('none')
|
||||
|
||||
const first = document.createElement('span')
|
||||
const second = document.createElement('em')
|
||||
dialog.show([first, second])
|
||||
|
||||
expect([...dialog.textElement.children]).toEqual([first, second])
|
||||
})
|
||||
|
||||
it('uses supplied custom buttons', () => {
|
||||
const button = document.createElement('button')
|
||||
const dialog = new ComfyDialog('section', [button])
|
||||
|
||||
expect(dialog.element.tagName).toBe('SECTION')
|
||||
expect(dialog.element.querySelector('button')).toBe(button)
|
||||
})
|
||||
})
|
||||
216
src/scripts/ui/draggableList.test.ts
Normal file
216
src/scripts/ui/draggableList.test.ts
Normal file
@@ -0,0 +1,216 @@
|
||||
import { afterEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { DraggableList } from './draggableList'
|
||||
|
||||
function createList(itemCount: number) {
|
||||
const container = document.createElement('div')
|
||||
const items = Array.from({ length: itemCount }, (_, index) => {
|
||||
const item = document.createElement('div')
|
||||
item.className = 'item'
|
||||
item.dataset.index = String(index)
|
||||
|
||||
const handle = document.createElement('button')
|
||||
handle.className = 'drag-handle'
|
||||
item.append(handle)
|
||||
container.append(item)
|
||||
|
||||
return item
|
||||
})
|
||||
document.body.append(container)
|
||||
return { container, items }
|
||||
}
|
||||
|
||||
function setRect(element: Element, top: number, height = 20) {
|
||||
return vi
|
||||
.spyOn(element, 'getBoundingClientRect')
|
||||
.mockReturnValue(new DOMRect(0, top, 100, height))
|
||||
}
|
||||
|
||||
function defineScrollMetrics(
|
||||
container: HTMLElement,
|
||||
scrollHeight: number,
|
||||
clientHeight: number
|
||||
) {
|
||||
Object.defineProperty(container, 'scrollHeight', {
|
||||
configurable: true,
|
||||
value: scrollHeight
|
||||
})
|
||||
Object.defineProperty(container, 'clientHeight', {
|
||||
configurable: true,
|
||||
value: clientHeight
|
||||
})
|
||||
}
|
||||
|
||||
function mouseDragEvent(
|
||||
target: Element,
|
||||
overrides: Partial<MouseEvent> = {}
|
||||
): MouseEvent {
|
||||
return {
|
||||
button: 0,
|
||||
clientX: 0,
|
||||
clientY: 0,
|
||||
preventDefault: vi.fn(),
|
||||
target,
|
||||
...overrides
|
||||
} satisfies Partial<MouseEvent> as MouseEvent
|
||||
}
|
||||
|
||||
describe('DraggableList', () => {
|
||||
afterEach(() => {
|
||||
document.body.replaceChildren()
|
||||
vi.restoreAllMocks()
|
||||
vi.unstubAllGlobals()
|
||||
})
|
||||
|
||||
it('ignores missing containers and non-primary drag starts', () => {
|
||||
const listWithoutContainer = new DraggableList(null, '.item')
|
||||
const { container, items } = createList(1)
|
||||
const list = new DraggableList(container, '.item')
|
||||
|
||||
list.dragStart(
|
||||
mouseDragEvent(items[0].querySelector('.drag-handle')!, { button: 1 })
|
||||
)
|
||||
list.dragEnd()
|
||||
|
||||
expect(listWithoutContainer.listContainer).toBeNull()
|
||||
expect(list.draggableItem).toBeUndefined()
|
||||
})
|
||||
|
||||
it('starts from a handle, scrolls downward, and reorders upward', () => {
|
||||
vi.stubGlobal('requestAnimationFrame', (callback: FrameRequestCallback) => {
|
||||
callback(0)
|
||||
return 1
|
||||
})
|
||||
const { container, items } = createList(3)
|
||||
const scrollBy = vi.fn((_left: number, top: number) => {
|
||||
container.scrollTop += top
|
||||
})
|
||||
container.scrollBy = scrollBy as unknown as typeof container.scrollBy
|
||||
defineScrollMetrics(container, 120, 80)
|
||||
vi.spyOn(container, 'getBoundingClientRect').mockReturnValue(
|
||||
new DOMRect(0, 0, 100, 80)
|
||||
)
|
||||
setRect(items[0], 0)
|
||||
setRect(items[1], 30)
|
||||
setRect(items[2], 60)
|
||||
|
||||
const list = new DraggableList(container, '.item')
|
||||
const dragStart = vi.fn()
|
||||
const dragEnd = vi.fn()
|
||||
list.addEventListener('dragstart', dragStart)
|
||||
list.addEventListener('dragend', dragEnd)
|
||||
|
||||
list.dragStart(
|
||||
mouseDragEvent(items[2].querySelector('.drag-handle')!, {
|
||||
clientX: 10,
|
||||
clientY: 70
|
||||
})
|
||||
)
|
||||
list.drag(
|
||||
mouseDragEvent(items[2].querySelector('.drag-handle')!, {
|
||||
clientX: 20,
|
||||
clientY: 100
|
||||
})
|
||||
)
|
||||
items[1].dataset.isToggled = ''
|
||||
list.dragEnd()
|
||||
|
||||
expect(dragStart).toHaveBeenCalledOnce()
|
||||
expect(dragEnd).toHaveBeenCalledOnce()
|
||||
expect(scrollBy).toHaveBeenCalledWith(0, 10)
|
||||
expect([...container.children]).toEqual([items[0], items[2], items[1]])
|
||||
expect(items[0].classList.contains('is-idle')).toBe(true)
|
||||
expect(items[1].classList.contains('is-idle')).toBe(true)
|
||||
})
|
||||
|
||||
it('supports touch coordinates, upward scrolling, and downward reorder', () => {
|
||||
vi.stubGlobal('requestAnimationFrame', (callback: FrameRequestCallback) => {
|
||||
callback(0)
|
||||
return 1
|
||||
})
|
||||
const { container, items } = createList(3)
|
||||
const scrollBy = vi.fn((_left: number, top: number) => {
|
||||
container.scrollTop += top
|
||||
})
|
||||
container.scrollTop = 10
|
||||
container.scrollBy = scrollBy as unknown as typeof container.scrollBy
|
||||
defineScrollMetrics(container, 120, 80)
|
||||
vi.spyOn(container, 'getBoundingClientRect').mockReturnValue(
|
||||
new DOMRect(0, 20, 100, 80)
|
||||
)
|
||||
setRect(items[0], 0)
|
||||
setRect(items[1], 30)
|
||||
setRect(items[2], 60)
|
||||
|
||||
const list = new DraggableList(container, '.item')
|
||||
const touchStart = {
|
||||
button: 0,
|
||||
clientX: 0,
|
||||
clientY: 0,
|
||||
preventDefault: vi.fn(),
|
||||
target: items[0].querySelector('.drag-handle')!,
|
||||
touches: [{ clientX: 5, clientY: 30 }]
|
||||
} as unknown as TouchEvent
|
||||
const touchMove = {
|
||||
clientX: 0,
|
||||
clientY: 0,
|
||||
preventDefault: vi.fn(),
|
||||
target: items[0].querySelector('.drag-handle')!,
|
||||
touches: [{ clientX: 8, clientY: 10 }]
|
||||
} as unknown as TouchEvent
|
||||
|
||||
list.dragStart(touchStart)
|
||||
list.drag(touchMove)
|
||||
items[1].dataset.isToggled = ''
|
||||
list.dragEnd()
|
||||
|
||||
expect(scrollBy).toHaveBeenCalledWith(0, -10)
|
||||
expect([...container.children]).toEqual([items[1], items[0], items[2]])
|
||||
})
|
||||
|
||||
it('updates idle item state around the dragged item midpoint', () => {
|
||||
const { container, items } = createList(3)
|
||||
const list = new DraggableList(container, '.item')
|
||||
const state = list as unknown as {
|
||||
items: HTMLElement[]
|
||||
draggableItem: HTMLElement
|
||||
}
|
||||
state.items = items
|
||||
state.draggableItem = items[1]
|
||||
list.itemsGap = 5
|
||||
items[0].classList.add('is-idle')
|
||||
items[1].classList.add('is-idle')
|
||||
items[2].classList.add('is-idle')
|
||||
items[0].dataset.isAbove = ''
|
||||
const draggedRect = setRect(items[1], -10)
|
||||
setRect(items[0], 0)
|
||||
setRect(items[2], 60)
|
||||
|
||||
list.updateIdleItemsStateAndPosition()
|
||||
|
||||
expect(items[0].dataset.isToggled).toBe('')
|
||||
expect(items[0].style.transform).toBe('translateY(25px)')
|
||||
expect(items[2].style.transform).toBe('')
|
||||
|
||||
draggedRect.mockReturnValue(new DOMRect(0, 100, 100, 20))
|
||||
list.updateIdleItemsStateAndPosition()
|
||||
|
||||
expect(items[0].dataset.isToggled).toBeUndefined()
|
||||
expect(items[2].dataset.isToggled).toBe('')
|
||||
expect(items[2].style.transform).toBe('translateY(-25px)')
|
||||
})
|
||||
|
||||
it('uses zero gap for short lists and disposes listeners', () => {
|
||||
const { container } = createList(1)
|
||||
const list = new DraggableList(container, '.item')
|
||||
const off = vi.fn()
|
||||
const disposableList = list as unknown as { off: Array<() => void> }
|
||||
disposableList.off = [off]
|
||||
|
||||
list.setItemsGap()
|
||||
list.dispose()
|
||||
|
||||
expect(list.itemsGap).toBe(0)
|
||||
expect(off).toHaveBeenCalledOnce()
|
||||
})
|
||||
})
|
||||
33
src/scripts/ui/imagePreview.test.ts
Normal file
33
src/scripts/ui/imagePreview.test.ts
Normal file
@@ -0,0 +1,33 @@
|
||||
import { describe, expect, it } from 'vitest'
|
||||
|
||||
import { calculateImageGrid } from './imagePreview'
|
||||
|
||||
function createImage(width: number, height: number) {
|
||||
const img = document.createElement('img')
|
||||
Object.defineProperty(img, 'naturalWidth', {
|
||||
configurable: true,
|
||||
value: width
|
||||
})
|
||||
Object.defineProperty(img, 'naturalHeight', {
|
||||
configurable: true,
|
||||
value: height
|
||||
})
|
||||
return img
|
||||
}
|
||||
|
||||
describe('imagePreview', () => {
|
||||
it('calculates the highest-area grid', () => {
|
||||
const images = [
|
||||
createImage(100, 100),
|
||||
createImage(100, 100),
|
||||
createImage(100, 100)
|
||||
]
|
||||
|
||||
expect(calculateImageGrid(images, 300, 120)).toMatchObject({
|
||||
cellWidth: 100,
|
||||
cellHeight: 100,
|
||||
cols: 3,
|
||||
rows: 1
|
||||
})
|
||||
})
|
||||
})
|
||||
86
src/scripts/ui/toggleSwitch.test.ts
Normal file
86
src/scripts/ui/toggleSwitch.test.ts
Normal file
@@ -0,0 +1,86 @@
|
||||
import { afterEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { toggleSwitch } from './toggleSwitch'
|
||||
|
||||
const mocks = vi.hoisted(() => ({
|
||||
$el: (
|
||||
selector: string,
|
||||
propsOrChildren?: Record<string, unknown> | Node[] | Node,
|
||||
maybeChildren?: Node[] | Node
|
||||
) => {
|
||||
const [tag, ...classes] = selector.split('.')
|
||||
const element = document.createElement(tag || 'div')
|
||||
element.classList.add(...classes.filter(Boolean))
|
||||
const children = Array.isArray(propsOrChildren)
|
||||
? propsOrChildren
|
||||
: propsOrChildren instanceof Node
|
||||
? [propsOrChildren]
|
||||
: Array.isArray(maybeChildren)
|
||||
? maybeChildren
|
||||
: maybeChildren instanceof Node
|
||||
? [maybeChildren]
|
||||
: []
|
||||
|
||||
if (
|
||||
propsOrChildren &&
|
||||
!(propsOrChildren instanceof Node) &&
|
||||
!Array.isArray(propsOrChildren)
|
||||
) {
|
||||
for (const [key, value] of Object.entries(propsOrChildren)) {
|
||||
Reflect.set(element, key, value)
|
||||
}
|
||||
}
|
||||
|
||||
element.append(...children)
|
||||
return element
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('../ui', () => ({
|
||||
$el: mocks.$el
|
||||
}))
|
||||
|
||||
describe('toggleSwitch', () => {
|
||||
afterEach(() => {
|
||||
document.body.replaceChildren()
|
||||
})
|
||||
|
||||
it('selects the first item when none is preselected', () => {
|
||||
const onChange = vi.fn()
|
||||
const container = toggleSwitch('mode', ['first', 'second'], { onChange })
|
||||
const labels = [...container.querySelectorAll('label')]
|
||||
const inputs = [...container.querySelectorAll('input')]
|
||||
|
||||
expect(labels[0].classList.contains('comfy-toggle-selected')).toBe(true)
|
||||
expect((inputs[0] as HTMLInputElement).checked).toBe(true)
|
||||
expect(onChange).toHaveBeenCalledWith({
|
||||
item: 'first',
|
||||
prev: undefined
|
||||
})
|
||||
})
|
||||
|
||||
it('moves selection and reports the previous item', () => {
|
||||
const onChange = vi.fn()
|
||||
const container = toggleSwitch(
|
||||
'mode',
|
||||
[
|
||||
{ text: 'first', tooltip: 'First option' },
|
||||
{ text: 'second', value: '2' }
|
||||
],
|
||||
{ onChange }
|
||||
)
|
||||
const labels = [...container.querySelectorAll('label')]
|
||||
const secondInput = labels[1].querySelector('input') as HTMLInputElement
|
||||
|
||||
secondInput.onchange?.(new Event('change'))
|
||||
|
||||
expect(labels[0].classList.contains('comfy-toggle-selected')).toBe(false)
|
||||
expect(labels[1].classList.contains('comfy-toggle-selected')).toBe(true)
|
||||
expect(labels[0].title).toBe('First option')
|
||||
expect(secondInput.value).toBe('2')
|
||||
expect(onChange).toHaveBeenLastCalledWith({
|
||||
item: { text: 'second', value: '2' },
|
||||
prev: { text: 'first', tooltip: 'First option', value: 'first' }
|
||||
})
|
||||
})
|
||||
})
|
||||
45
src/scripts/ui/utils.test.ts
Normal file
45
src/scripts/ui/utils.test.ts
Normal file
@@ -0,0 +1,45 @@
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { applyClasses, toggleElement } from './utils'
|
||||
|
||||
describe('ui utils', () => {
|
||||
it('applies string, array, object, and required classes', () => {
|
||||
const element = document.createElement('div')
|
||||
|
||||
applyClasses(element, 'one two', 'required')
|
||||
expect([...element.classList]).toEqual(['one', 'two', 'required'])
|
||||
|
||||
applyClasses(element, ['three', 'four'])
|
||||
expect([...element.classList]).toEqual(['three', 'four'])
|
||||
|
||||
applyClasses(element, { five: true, six: false, seven: true })
|
||||
expect([...element.classList]).toEqual(['five', 'seven'])
|
||||
|
||||
applyClasses(element, null as unknown as string)
|
||||
expect(element.className).toBe('')
|
||||
})
|
||||
|
||||
it('toggles an element through a placeholder', () => {
|
||||
const parent = document.createElement('div')
|
||||
const element = document.createElement('span')
|
||||
const onHide = vi.fn()
|
||||
const onShow = vi.fn()
|
||||
parent.append(element)
|
||||
const toggle = toggleElement(element, { onHide, onShow })
|
||||
|
||||
toggle(false)
|
||||
expect(parent.firstChild).toBeInstanceOf(Comment)
|
||||
expect(onHide).toHaveBeenCalledWith(element)
|
||||
|
||||
toggle(true)
|
||||
expect(parent.firstChild).toBe(element)
|
||||
expect(onShow).toHaveBeenCalledWith(element, true)
|
||||
|
||||
toggle('visible')
|
||||
expect(onShow).toHaveBeenCalledWith(element, 'visible')
|
||||
|
||||
toggle(false)
|
||||
expect(parent.firstChild).toBeInstanceOf(Comment)
|
||||
expect(onHide).toHaveBeenCalledTimes(2)
|
||||
})
|
||||
})
|
||||
127
src/scripts/utils.test.ts
Normal file
127
src/scripts/utils.test.ts
Normal file
@@ -0,0 +1,127 @@
|
||||
import { afterEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import {
|
||||
addStylesheet,
|
||||
clone,
|
||||
getStorageValue,
|
||||
prop,
|
||||
setStorageValue
|
||||
} from './utils'
|
||||
|
||||
interface LinkAttrs {
|
||||
href: string
|
||||
onerror: (error: Event) => void
|
||||
onload: () => void
|
||||
parent: HTMLElement
|
||||
rel: string
|
||||
type: string
|
||||
}
|
||||
|
||||
const mocks = vi.hoisted(() => ({
|
||||
api: {
|
||||
clientId: null as string | null,
|
||||
initialClientId: null as string | null
|
||||
},
|
||||
$el: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('./api', () => ({
|
||||
api: mocks.api
|
||||
}))
|
||||
|
||||
vi.mock('./ui', () => ({
|
||||
$el: mocks.$el
|
||||
}))
|
||||
|
||||
function lastLinkAttrs() {
|
||||
return mocks.$el.mock.calls.at(-1)?.[1] as LinkAttrs
|
||||
}
|
||||
|
||||
describe('scripts utils', () => {
|
||||
afterEach(() => {
|
||||
localStorage.clear()
|
||||
sessionStorage.clear()
|
||||
mocks.api.clientId = null
|
||||
mocks.api.initialClientId = null
|
||||
mocks.$el.mockReset()
|
||||
vi.unstubAllGlobals()
|
||||
})
|
||||
|
||||
it('clones with structuredClone and falls back to JSON cloning', () => {
|
||||
const source = { nested: { value: 1 } }
|
||||
|
||||
expect(clone(source)).toEqual(source)
|
||||
|
||||
vi.stubGlobal(
|
||||
'structuredClone',
|
||||
vi.fn(() => {
|
||||
throw new Error('unsupported')
|
||||
})
|
||||
)
|
||||
|
||||
const cloned = clone(source)
|
||||
cloned.nested.value = 2
|
||||
|
||||
expect(cloned).toEqual({ nested: { value: 2 } })
|
||||
expect(source).toEqual({ nested: { value: 1 } })
|
||||
})
|
||||
|
||||
it('adds stylesheets from script and relative URLs', async () => {
|
||||
const scriptPromise = addStylesheet('/extensions/example.js')
|
||||
lastLinkAttrs().onload()
|
||||
|
||||
await expect(scriptPromise).resolves.toBeUndefined()
|
||||
expect(lastLinkAttrs()).toMatchObject({
|
||||
href: '/extensions/example.css',
|
||||
parent: document.head,
|
||||
rel: 'stylesheet',
|
||||
type: 'text/css'
|
||||
})
|
||||
|
||||
const cssPromise = addStylesheet('theme.css', 'https://example.com/base/')
|
||||
lastLinkAttrs().onload()
|
||||
|
||||
await expect(cssPromise).resolves.toBeUndefined()
|
||||
expect(lastLinkAttrs().href).toBe('https://example.com/base/theme.css')
|
||||
})
|
||||
|
||||
it('rejects when stylesheet loading fails', async () => {
|
||||
const promise = addStylesheet('missing.css', 'https://example.com/')
|
||||
const error = new Event('error')
|
||||
lastLinkAttrs().onerror(error)
|
||||
|
||||
await expect(promise).rejects.toBe(error)
|
||||
})
|
||||
|
||||
it('defines an observable property with the supplied default', () => {
|
||||
const target = {}
|
||||
const onChanged = vi.fn()
|
||||
|
||||
expect(prop(target, 'mode', 'initial', onChanged)).toBe('initial')
|
||||
Object.assign(target, { mode: 'next' })
|
||||
|
||||
expect((target as { mode: string }).mode).toBe('next')
|
||||
expect(onChanged).toHaveBeenCalledWith('next', undefined, target, 'mode')
|
||||
})
|
||||
|
||||
it('uses client-scoped storage before local fallback', () => {
|
||||
mocks.api.clientId = 'client-1'
|
||||
setStorageValue('setting', 'client-value')
|
||||
sessionStorage.removeItem('setting:client-1')
|
||||
|
||||
expect(getStorageValue('setting')).toBe('client-value')
|
||||
expect(localStorage.getItem('setting')).toBe('client-value')
|
||||
|
||||
sessionStorage.setItem('setting:client-1', 'session-value')
|
||||
|
||||
expect(getStorageValue('setting')).toBe('session-value')
|
||||
})
|
||||
|
||||
it('uses initial client id when the current client id is unavailable', () => {
|
||||
mocks.api.initialClientId = 'initial-1'
|
||||
setStorageValue('setting', 'initial-value')
|
||||
|
||||
expect(sessionStorage.getItem('setting:initial-1')).toBe('initial-value')
|
||||
expect(getStorageValue('setting')).toBe('initial-value')
|
||||
})
|
||||
})
|
||||
356
src/scripts/widgets.test.ts
Normal file
356
src/scripts/widgets.test.ts
Normal file
@@ -0,0 +1,356 @@
|
||||
import { fromAny } from '@total-typescript/shoehorn'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import type { LGraphNode } from '@/lib/litegraph/src/litegraph'
|
||||
import type {
|
||||
IBaseWidget,
|
||||
IComboWidget
|
||||
} from '@/lib/litegraph/src/types/widgets'
|
||||
import type { InputSpec } from '@/schemas/nodeDefSchema'
|
||||
|
||||
const mockSettingGet = vi.hoisted(() => vi.fn())
|
||||
const mockNextValueForLinkedTarget = vi.hoisted(() => vi.fn())
|
||||
const mockIsComboWidget = vi.hoisted(() => vi.fn())
|
||||
const mockTransformInputSpecV1ToV2 = vi.hoisted(() => vi.fn())
|
||||
|
||||
function v2WidgetConstructor(kind: string) {
|
||||
return () => (_node: LGraphNode, inputSpec: { name: string }) => ({
|
||||
name: `${kind}:${inputSpec.name}`,
|
||||
options: { minNodeSize: [20, 30] }
|
||||
})
|
||||
}
|
||||
|
||||
vi.mock('@/i18n', () => ({
|
||||
t: (key: string) => `translated:${key}`
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/settings/settingStore', () => ({
|
||||
useSettingStore: () => ({
|
||||
get: mockSettingGet
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('@/lib/litegraph/src/litegraph', () => ({
|
||||
isComboWidget: mockIsComboWidget
|
||||
}))
|
||||
|
||||
vi.mock('./valueControl', () => ({
|
||||
nextValueForLinkedTarget: mockNextValueForLinkedTarget
|
||||
}))
|
||||
|
||||
vi.mock('@/schemas/nodeDef/migration', () => ({
|
||||
transformInputSpecV1ToV2: mockTransformInputSpecV1ToV2
|
||||
}))
|
||||
|
||||
vi.mock('@/core/graph/widgets/dynamicWidgets', () => ({
|
||||
dynamicWidgets: {
|
||||
DYNAMIC: () => ({
|
||||
widget: { name: 'dynamic', options: {} },
|
||||
minWidth: 1,
|
||||
minHeight: 1
|
||||
})
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock(
|
||||
'@/renderer/extensions/vueNodes/widgets/composables/useBooleanWidget',
|
||||
() => ({ useBooleanWidget: v2WidgetConstructor('BOOLEAN') })
|
||||
)
|
||||
vi.mock(
|
||||
'@/renderer/extensions/vueNodes/widgets/composables/useBoundingBoxWidget',
|
||||
() => ({ useBoundingBoxWidget: v2WidgetConstructor('BOUNDING_BOX') })
|
||||
)
|
||||
vi.mock(
|
||||
'@/renderer/extensions/vueNodes/widgets/composables/useCurveWidget',
|
||||
() => ({ useCurveWidget: v2WidgetConstructor('CURVE') })
|
||||
)
|
||||
vi.mock(
|
||||
'@/renderer/extensions/vueNodes/widgets/composables/useChartWidget',
|
||||
() => ({ useChartWidget: v2WidgetConstructor('CHART') })
|
||||
)
|
||||
vi.mock(
|
||||
'@/renderer/extensions/vueNodes/widgets/composables/useColorWidget',
|
||||
() => ({ useColorWidget: v2WidgetConstructor('COLOR') })
|
||||
)
|
||||
vi.mock(
|
||||
'@/renderer/extensions/vueNodes/widgets/composables/useComboWidget',
|
||||
() => ({ useComboWidget: v2WidgetConstructor('COMBO') })
|
||||
)
|
||||
vi.mock(
|
||||
'@/renderer/extensions/vueNodes/widgets/composables/useFloatWidget',
|
||||
() => ({ useFloatWidget: v2WidgetConstructor('FLOAT') })
|
||||
)
|
||||
vi.mock(
|
||||
'@/renderer/extensions/vueNodes/widgets/composables/useGalleriaWidget',
|
||||
() => ({ useGalleriaWidget: v2WidgetConstructor('GALLERIA') })
|
||||
)
|
||||
vi.mock(
|
||||
'@/renderer/extensions/vueNodes/widgets/composables/useBoundingBoxesWidget',
|
||||
() => ({ useBoundingBoxesWidget: v2WidgetConstructor('BOUNDING_BOXES') })
|
||||
)
|
||||
vi.mock(
|
||||
'@/renderer/extensions/vueNodes/widgets/composables/useColorsWidget',
|
||||
() => ({ useColorsWidget: v2WidgetConstructor('COLORS') })
|
||||
)
|
||||
vi.mock(
|
||||
'@/renderer/extensions/vueNodes/widgets/composables/useImageCompareWidget',
|
||||
() => ({ useImageCompareWidget: v2WidgetConstructor('IMAGECOMPARE') })
|
||||
)
|
||||
vi.mock(
|
||||
'@/renderer/extensions/vueNodes/widgets/composables/useImageUploadWidget',
|
||||
() => ({
|
||||
useImageUploadWidget: () => (_node: LGraphNode, inputName: string) => ({
|
||||
widget: { name: `IMAGEUPLOAD:${inputName}`, options: {} },
|
||||
minWidth: 5,
|
||||
minHeight: 6
|
||||
})
|
||||
})
|
||||
)
|
||||
vi.mock(
|
||||
'@/renderer/extensions/vueNodes/widgets/composables/useIntWidget',
|
||||
() => ({ useIntWidget: v2WidgetConstructor('INT') })
|
||||
)
|
||||
vi.mock(
|
||||
'@/renderer/extensions/vueNodes/widgets/composables/useMarkdownWidget',
|
||||
() => ({ useMarkdownWidget: v2WidgetConstructor('MARKDOWN') })
|
||||
)
|
||||
vi.mock(
|
||||
'@/renderer/extensions/vueNodes/widgets/composables/usePainterWidget',
|
||||
() => ({ usePainterWidget: v2WidgetConstructor('PAINTER') })
|
||||
)
|
||||
vi.mock(
|
||||
'@/renderer/extensions/vueNodes/widgets/composables/useRangeWidget',
|
||||
() => ({ useRangeWidget: v2WidgetConstructor('RANGE') })
|
||||
)
|
||||
vi.mock(
|
||||
'@/renderer/extensions/vueNodes/widgets/composables/useStringWidget',
|
||||
() => ({ useStringWidget: v2WidgetConstructor('STRING') })
|
||||
)
|
||||
vi.mock(
|
||||
'@/renderer/extensions/vueNodes/widgets/composables/useTextareaWidget',
|
||||
() => ({ useTextareaWidget: v2WidgetConstructor('TEXTAREA') })
|
||||
)
|
||||
|
||||
vi.mock('./domWidget', () => ({}))
|
||||
vi.mock('./errorNodeWidgets', () => ({}))
|
||||
|
||||
import {
|
||||
ComfyWidgets,
|
||||
IS_CONTROL_WIDGET,
|
||||
addValueControlWidget,
|
||||
addValueControlWidgets,
|
||||
isValidWidgetType,
|
||||
updateControlWidgetLabel
|
||||
} from './widgets'
|
||||
|
||||
// `linkedWidgets`, `beforeQueued`, and `afterQueued` already exist on
|
||||
// IBaseWidget (via the litegraph augmentation), so no extra members needed.
|
||||
type MockWidget = IBaseWidget
|
||||
|
||||
function makeTargetWidget(overrides: Partial<MockWidget> = {}): MockWidget {
|
||||
return {
|
||||
name: 'seed',
|
||||
value: 1,
|
||||
callback: vi.fn(),
|
||||
options: {},
|
||||
linkedWidgets: [],
|
||||
computedDisabled: false,
|
||||
...overrides
|
||||
} as MockWidget
|
||||
}
|
||||
|
||||
function makeNode(inputs: LGraphNode['inputs'] = []) {
|
||||
const widgets: MockWidget[] = []
|
||||
const node = {
|
||||
id: 42,
|
||||
inputs,
|
||||
addWidget: vi.fn(
|
||||
(
|
||||
type: string,
|
||||
name: string,
|
||||
value: string,
|
||||
callback: () => void,
|
||||
options: Record<string, unknown>
|
||||
) => {
|
||||
const widget: MockWidget = fromAny({
|
||||
type,
|
||||
name,
|
||||
value,
|
||||
callback,
|
||||
options,
|
||||
linkedWidgets: [],
|
||||
computedDisabled: false
|
||||
})
|
||||
widgets.push(widget)
|
||||
return widget
|
||||
}
|
||||
)
|
||||
}
|
||||
return { node: node as unknown as LGraphNode, widgets }
|
||||
}
|
||||
|
||||
describe('widgets', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockSettingGet.mockReturnValue('after')
|
||||
mockNextValueForLinkedTarget.mockReturnValue('next')
|
||||
mockIsComboWidget.mockImplementation(
|
||||
(widget: MockWidget) => widget.type === 'combo'
|
||||
)
|
||||
mockTransformInputSpecV1ToV2.mockImplementation(
|
||||
(_inputData: InputSpec, options: { name: string }) => ({
|
||||
name: options.name
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('updates the control widget label from the configured run mode', () => {
|
||||
const widget = makeTargetWidget()
|
||||
|
||||
mockSettingGet.mockReturnValue('before')
|
||||
updateControlWidgetLabel(widget)
|
||||
expect(widget.label).toBe('translated:g.control_before_generate')
|
||||
|
||||
mockSettingGet.mockReturnValue('after')
|
||||
updateControlWidgetLabel(widget)
|
||||
expect(widget.label).toBe('translated:g.control_after_generate')
|
||||
})
|
||||
|
||||
it('adds control and filter widgets for combo targets', () => {
|
||||
const { node, widgets } = makeNode()
|
||||
const target = makeTargetWidget({ type: 'combo', computedDisabled: true })
|
||||
|
||||
const result = addValueControlWidgets(node, target, '', undefined, [
|
||||
'COMBO',
|
||||
{
|
||||
control_prefix: 'custom'
|
||||
}
|
||||
] as unknown as InputSpec)
|
||||
|
||||
expect(result).toHaveLength(2)
|
||||
expect(widgets[0].name).toBe('custom control_after_generate')
|
||||
expect(widgets[0].value).toBe('randomize')
|
||||
expect((widgets[0] as IComboWidget).options.values).toContain(
|
||||
'increment-wrap'
|
||||
)
|
||||
expect(widgets[0][IS_CONTROL_WIDGET]).toBe(true)
|
||||
expect(widgets[0].disabled).toBe(true)
|
||||
expect(widgets[1].name).toBe('custom control_filter_list')
|
||||
expect(widgets[1].disabled).toBe(true)
|
||||
})
|
||||
|
||||
it('uses explicit option names and can skip the combo filter widget', () => {
|
||||
const { node, widgets } = makeNode()
|
||||
const target = makeTargetWidget({ type: 'combo' })
|
||||
|
||||
addValueControlWidgets(
|
||||
node,
|
||||
target,
|
||||
'fixed',
|
||||
{
|
||||
addFilterList: false,
|
||||
controlAfterGenerateName: 'mode'
|
||||
},
|
||||
['COMBO', {}] as unknown as InputSpec
|
||||
)
|
||||
|
||||
expect(widgets).toHaveLength(1)
|
||||
expect(widgets[0].name).toBe('mode')
|
||||
})
|
||||
|
||||
it('applies linked target values after queueing in after mode', () => {
|
||||
const { node, widgets } = makeNode()
|
||||
const target = makeTargetWidget()
|
||||
|
||||
addValueControlWidgets(node, target)
|
||||
widgets[0].afterQueued?.({ isPartialExecution: true })
|
||||
|
||||
expect(mockNextValueForLinkedTarget).toHaveBeenCalledWith({
|
||||
target,
|
||||
linkedWidgets: target.linkedWidgets,
|
||||
nodeId: 42,
|
||||
isPartialExecution: true
|
||||
})
|
||||
expect(target.value).toBe('next')
|
||||
expect(target.callback).toHaveBeenCalledWith('next')
|
||||
})
|
||||
|
||||
it('waits until the second beforeQueued call in before mode', () => {
|
||||
mockSettingGet.mockReturnValue('before')
|
||||
const { node, widgets } = makeNode()
|
||||
const target = makeTargetWidget()
|
||||
|
||||
addValueControlWidgets(node, target)
|
||||
widgets[0].beforeQueued?.()
|
||||
expect(mockNextValueForLinkedTarget).not.toHaveBeenCalled()
|
||||
|
||||
widgets[0].beforeQueued?.({ isPartialExecution: false })
|
||||
expect(mockNextValueForLinkedTarget).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ isPartialExecution: false })
|
||||
)
|
||||
})
|
||||
|
||||
it('does not change the target when the target has a linked input or no next value', () => {
|
||||
const { node, widgets } = makeNode([
|
||||
{ widget: { name: 'seed' }, link: 1 }
|
||||
] as LGraphNode['inputs'])
|
||||
const target = makeTargetWidget()
|
||||
|
||||
addValueControlWidgets(node, target)
|
||||
widgets[0].afterQueued?.()
|
||||
expect(mockNextValueForLinkedTarget).not.toHaveBeenCalled()
|
||||
|
||||
const unlinked = makeNode()
|
||||
mockNextValueForLinkedTarget.mockReturnValue(undefined)
|
||||
addValueControlWidgets(unlinked.node, target)
|
||||
unlinked.widgets[0].afterQueued?.()
|
||||
expect(target.callback).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('uses the legacy single control widget name from input data before widgetName', () => {
|
||||
const { node, widgets } = makeNode()
|
||||
const target = makeTargetWidget()
|
||||
|
||||
const result = addValueControlWidget(
|
||||
node,
|
||||
target,
|
||||
'fixed',
|
||||
undefined,
|
||||
'fallback',
|
||||
[
|
||||
'INT',
|
||||
{
|
||||
control_after_generate: 'from_input_data'
|
||||
}
|
||||
] as unknown as InputSpec
|
||||
)
|
||||
|
||||
expect(result).toBe(widgets[0])
|
||||
expect(widgets[0].name).toBe('from_input_data')
|
||||
})
|
||||
|
||||
it('exposes transformed widget constructors and type validation', () => {
|
||||
const { node } = makeNode()
|
||||
|
||||
const intWidget = ComfyWidgets.INT(
|
||||
node,
|
||||
'value',
|
||||
['INT', {}] as unknown as InputSpec,
|
||||
{} as never
|
||||
)
|
||||
|
||||
expect(intWidget.widget.name).toBe('INT:value')
|
||||
expect(intWidget.minWidth).toBe(20)
|
||||
expect(intWidget.minHeight).toBe(30)
|
||||
expect(
|
||||
ComfyWidgets.IMAGEUPLOAD(node, 'image', ['IMAGE', {}], {} as never)
|
||||
).toMatchObject({
|
||||
widget: { name: 'IMAGEUPLOAD:image' },
|
||||
minWidth: 5,
|
||||
minHeight: 6
|
||||
})
|
||||
expect(isValidWidgetType('INT')).toBe(true)
|
||||
expect(isValidWidgetType('DYNAMIC')).toBe(true)
|
||||
expect(isValidWidgetType('missing')).toBe(false)
|
||||
})
|
||||
})
|
||||
@@ -38,6 +38,12 @@ describe('useAudioService', () => {
|
||||
name: 'test-audio-123.wav'
|
||||
}
|
||||
|
||||
async function freshService() {
|
||||
vi.resetModules()
|
||||
const audioServiceModule = await import('@/services/audioService')
|
||||
return audioServiceModule.useAudioService()
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
|
||||
@@ -90,12 +96,41 @@ describe('useAudioService', () => {
|
||||
)
|
||||
mockRegister.mockRejectedValueOnce(error)
|
||||
|
||||
await service.registerWavEncoder()
|
||||
const isolatedService = await freshService()
|
||||
await isolatedService.registerWavEncoder()
|
||||
await isolatedService.registerWavEncoder()
|
||||
|
||||
expect(mockConnect).toHaveBeenCalledTimes(0)
|
||||
expect(mockRegister).toHaveBeenCalledTimes(0)
|
||||
expect(mockConnect).toHaveBeenCalledTimes(1)
|
||||
expect(mockRegister).toHaveBeenCalledTimes(1)
|
||||
expect(console.error).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should log encoder registration errors', async () => {
|
||||
const error = new Error('Encoder failed')
|
||||
mockRegister.mockRejectedValueOnce(error)
|
||||
|
||||
const isolatedService = await freshService()
|
||||
await isolatedService.registerWavEncoder()
|
||||
|
||||
expect(console.error).toHaveBeenCalledWith(
|
||||
'Audio Service Error (encoder):',
|
||||
'Failed to register WAV encoder',
|
||||
error
|
||||
)
|
||||
})
|
||||
|
||||
it('should log non-Error encoder registration failures', async () => {
|
||||
mockRegister.mockRejectedValueOnce('Encoder failed')
|
||||
|
||||
const isolatedService = await freshService()
|
||||
await isolatedService.registerWavEncoder()
|
||||
|
||||
expect(console.error).toHaveBeenCalledWith(
|
||||
'Audio Service Error (encoder):',
|
||||
'Failed to register WAV encoder',
|
||||
'Encoder failed'
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('stopAllTracks', () => {
|
||||
|
||||
118
src/services/autoQueueService.test.ts
Normal file
118
src/services/autoQueueService.test.ts
Normal file
@@ -0,0 +1,118 @@
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { setupAutoQueueHandler } from '@/services/autoQueueService'
|
||||
|
||||
type ApiEvent = 'graphChanged'
|
||||
type ApiListener = () => void
|
||||
type Subscription = () => Promise<void> | void
|
||||
|
||||
const {
|
||||
listeners,
|
||||
queueCountStore,
|
||||
queueSettingsStore,
|
||||
appState,
|
||||
addEventListener,
|
||||
isInstantRunningMode
|
||||
} = vi.hoisted(() => ({
|
||||
listeners: new Map<ApiEvent, ApiListener>(),
|
||||
queueCountStore: {
|
||||
count: 0,
|
||||
subscription: undefined as Subscription | undefined,
|
||||
$subscribe: vi.fn((_callback: Subscription) => {
|
||||
queueCountStore.subscription = _callback
|
||||
})
|
||||
},
|
||||
queueSettingsStore: {
|
||||
mode: 'manual',
|
||||
batchCount: 1
|
||||
},
|
||||
appState: {
|
||||
lastExecutionError: null as unknown,
|
||||
queuePrompt: vi.fn()
|
||||
},
|
||||
addEventListener: vi.fn((event: ApiEvent, listener: ApiListener) => {
|
||||
listeners.set(event, listener)
|
||||
}),
|
||||
isInstantRunningMode: vi.fn((mode: string) => mode === 'instant')
|
||||
}))
|
||||
|
||||
vi.mock('@/scripts/api', () => ({
|
||||
api: { addEventListener }
|
||||
}))
|
||||
|
||||
vi.mock('@/scripts/app', () => ({
|
||||
app: appState
|
||||
}))
|
||||
|
||||
vi.mock('@/stores/queueStore', () => ({
|
||||
isInstantRunningMode,
|
||||
useQueuePendingTaskCountStore: () => queueCountStore,
|
||||
useQueueSettingsStore: () => queueSettingsStore
|
||||
}))
|
||||
|
||||
beforeEach(() => {
|
||||
listeners.clear()
|
||||
queueCountStore.count = 0
|
||||
queueCountStore.subscription = undefined
|
||||
queueCountStore.$subscribe.mockClear()
|
||||
queueSettingsStore.mode = 'manual'
|
||||
queueSettingsStore.batchCount = 1
|
||||
appState.lastExecutionError = null
|
||||
appState.queuePrompt.mockReset().mockResolvedValue(undefined)
|
||||
addEventListener.mockClear()
|
||||
isInstantRunningMode
|
||||
.mockClear()
|
||||
.mockImplementation((mode) => mode === 'instant')
|
||||
})
|
||||
|
||||
describe('setupAutoQueueHandler', () => {
|
||||
it('queues immediately on graph changes when change mode is idle', () => {
|
||||
queueSettingsStore.mode = 'change'
|
||||
queueSettingsStore.batchCount = 3
|
||||
|
||||
setupAutoQueueHandler()
|
||||
listeners.get('graphChanged')?.()
|
||||
|
||||
expect(appState.queuePrompt).toHaveBeenCalledWith(0, 3)
|
||||
})
|
||||
|
||||
it('queues after pending work drains in instant mode', async () => {
|
||||
queueSettingsStore.mode = 'instant'
|
||||
queueSettingsStore.batchCount = 2
|
||||
queueCountStore.count = 0
|
||||
|
||||
setupAutoQueueHandler()
|
||||
await queueCountStore.subscription?.()
|
||||
|
||||
expect(appState.queuePrompt).toHaveBeenCalledWith(0, 2)
|
||||
})
|
||||
|
||||
it('queues after a changed graph drains from an active queue', async () => {
|
||||
queueSettingsStore.mode = 'change'
|
||||
queueCountStore.count = 1
|
||||
|
||||
setupAutoQueueHandler()
|
||||
await queueCountStore.subscription?.()
|
||||
listeners.get('graphChanged')?.()
|
||||
expect(appState.queuePrompt).not.toHaveBeenCalled()
|
||||
|
||||
queueCountStore.count = 0
|
||||
await queueCountStore.subscription?.()
|
||||
|
||||
expect(appState.queuePrompt).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('does not requeue while work remains or the last run failed', async () => {
|
||||
queueSettingsStore.mode = 'instant'
|
||||
queueCountStore.count = 1
|
||||
|
||||
setupAutoQueueHandler()
|
||||
await queueCountStore.subscription?.()
|
||||
|
||||
appState.lastExecutionError = { message: 'failed' }
|
||||
queueCountStore.count = 0
|
||||
await queueCountStore.subscription?.()
|
||||
|
||||
expect(appState.queuePrompt).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
363
src/services/colorPaletteService.test.ts
Normal file
363
src/services/colorPaletteService.test.ts
Normal file
@@ -0,0 +1,363 @@
|
||||
import { fromAny } from '@total-typescript/shoehorn'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { DEFAULT_DARK_COLOR_PALETTE } from '@/constants/coreColorPalettes'
|
||||
import {
|
||||
LGraphCanvas,
|
||||
LiteGraph,
|
||||
RenderShape
|
||||
} from '@/lib/litegraph/src/litegraph'
|
||||
import type { CompletedPalette, Palette } from '@/schemas/colorPaletteSchema'
|
||||
|
||||
const mockCanvas = vi.hoisted(() => ({
|
||||
default_connection_color_byType: {} as Record<string, string>,
|
||||
node_title_color: '',
|
||||
default_link_color: '',
|
||||
background_image: '',
|
||||
clear_background_color: '',
|
||||
_pattern: 'pattern' as string | undefined,
|
||||
setDirty: vi.fn()
|
||||
}))
|
||||
|
||||
const mockColorPaletteStore = vi.hoisted(() => ({
|
||||
customPalettes: {} as Record<string, unknown>,
|
||||
palettesLookup: {} as Record<string, unknown>,
|
||||
completedActivePalette: undefined as unknown,
|
||||
activePaletteId: 'dark',
|
||||
addCustomPalette: vi.fn(),
|
||||
deleteCustomPalette: vi.fn(),
|
||||
completePalette: vi.fn()
|
||||
}))
|
||||
|
||||
const mockSettingStore = vi.hoisted(() => ({
|
||||
get: vi.fn(),
|
||||
set: vi.fn()
|
||||
}))
|
||||
|
||||
const mockNodeDefStore = vi.hoisted(() => ({
|
||||
nodeDataTypes: new Set(['IMAGE', 'MISSING'])
|
||||
}))
|
||||
|
||||
const mockDownloadBlob = vi.hoisted(() => vi.fn())
|
||||
const mockUploadFile = vi.hoisted(() => vi.fn())
|
||||
|
||||
vi.mock('@/scripts/app', () => ({
|
||||
app: { canvas: mockCanvas }
|
||||
}))
|
||||
|
||||
vi.mock('@/stores/workspace/colorPaletteStore', () => ({
|
||||
useColorPaletteStore: () => mockColorPaletteStore
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/settings/settingStore', () => ({
|
||||
useSettingStore: () => mockSettingStore
|
||||
}))
|
||||
|
||||
vi.mock('@/stores/nodeDefStore', () => ({
|
||||
useNodeDefStore: () => mockNodeDefStore
|
||||
}))
|
||||
|
||||
vi.mock('@/base/common/downloadUtil', () => ({
|
||||
downloadBlob: mockDownloadBlob
|
||||
}))
|
||||
|
||||
vi.mock('@/scripts/utils', () => ({
|
||||
uploadFile: mockUploadFile
|
||||
}))
|
||||
|
||||
vi.mock('@/composables/useErrorHandling', () => ({
|
||||
useErrorHandling: () => ({
|
||||
wrapWithErrorHandling: <T>(action: T) => action,
|
||||
wrapWithErrorHandlingAsync: <T>(action: T) => action
|
||||
})
|
||||
}))
|
||||
|
||||
import { useColorPaletteService } from './colorPaletteService'
|
||||
|
||||
const validCustomPalette = {
|
||||
id: 'custom',
|
||||
name: 'Custom',
|
||||
colors: {
|
||||
node_slot: {},
|
||||
litegraph_base: {},
|
||||
comfy_base: {}
|
||||
}
|
||||
} satisfies Palette
|
||||
|
||||
function makeCompletedPalette(id = 'custom'): CompletedPalette {
|
||||
const palette = structuredClone(
|
||||
DEFAULT_DARK_COLOR_PALETTE
|
||||
) as CompletedPalette
|
||||
palette.id = id
|
||||
palette.name = 'Custom'
|
||||
palette.colors.node_slot.IMAGE = '#123456'
|
||||
palette.colors.litegraph_base.NODE_TITLE_COLOR = '#abcdef'
|
||||
palette.colors.litegraph_base.LINK_COLOR = '#fedcba'
|
||||
palette.colors.litegraph_base.BACKGROUND_IMAGE = 'grid.png'
|
||||
palette.colors.litegraph_base.CLEAR_BACKGROUND_COLOR = '#010203'
|
||||
palette.colors.litegraph_base.NODE_DEFAULT_SHAPE = 'legacy'
|
||||
palette.colors.comfy_base['fg-color'] = '#111111'
|
||||
palette.colors.comfy_base['bg-color'] = '#222222'
|
||||
delete palette.colors.comfy_base['contrast-mix-color']
|
||||
return palette
|
||||
}
|
||||
|
||||
describe('useColorPaletteService', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockCanvas.default_connection_color_byType = {}
|
||||
mockCanvas.node_title_color = ''
|
||||
mockCanvas.default_link_color = ''
|
||||
mockCanvas.background_image = ''
|
||||
mockCanvas.clear_background_color = ''
|
||||
mockCanvas._pattern = 'pattern'
|
||||
LGraphCanvas.link_type_colors = {}
|
||||
mockSettingStore.get.mockReturnValue('')
|
||||
mockSettingStore.set.mockResolvedValue(undefined)
|
||||
mockColorPaletteStore.customPalettes = { custom: validCustomPalette }
|
||||
mockColorPaletteStore.palettesLookup = { custom: validCustomPalette }
|
||||
mockColorPaletteStore.completedActivePalette = makeCompletedPalette()
|
||||
mockColorPaletteStore.activePaletteId = 'dark'
|
||||
mockColorPaletteStore.completePalette.mockReturnValue(
|
||||
makeCompletedPalette()
|
||||
)
|
||||
document.documentElement.style.cssText = ''
|
||||
document.documentElement.style.setProperty(
|
||||
'--color-datatype-MISSING',
|
||||
'#ffffff'
|
||||
)
|
||||
})
|
||||
|
||||
it('adds valid custom palettes and persists the custom palette map', async () => {
|
||||
const service = useColorPaletteService()
|
||||
|
||||
await service.addCustomColorPalette(validCustomPalette)
|
||||
|
||||
expect(mockColorPaletteStore.addCustomPalette).toHaveBeenCalledWith(
|
||||
validCustomPalette
|
||||
)
|
||||
expect(mockSettingStore.set).toHaveBeenCalledWith(
|
||||
'Comfy.CustomColorPalettes',
|
||||
mockColorPaletteStore.customPalettes
|
||||
)
|
||||
})
|
||||
|
||||
it('rejects invalid custom palettes before mutating the store', async () => {
|
||||
const service = useColorPaletteService()
|
||||
|
||||
await expect(service.addCustomColorPalette({} as Palette)).rejects.toThrow(
|
||||
'Invalid color palette against zod schema'
|
||||
)
|
||||
expect(mockColorPaletteStore.addCustomPalette).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('deletes custom palettes and persists the custom palette map', async () => {
|
||||
const service = useColorPaletteService()
|
||||
|
||||
await service.deleteCustomColorPalette('custom')
|
||||
|
||||
expect(mockColorPaletteStore.deleteCustomPalette).toHaveBeenCalledWith(
|
||||
'custom'
|
||||
)
|
||||
expect(mockSettingStore.set).toHaveBeenCalledWith(
|
||||
'Comfy.CustomColorPalettes',
|
||||
mockColorPaletteStore.customPalettes
|
||||
)
|
||||
})
|
||||
|
||||
it('loads palette colors into litegraph, Vue CSS variables, and canvas state', async () => {
|
||||
const warn = vi.spyOn(console, 'warn').mockImplementation(() => {})
|
||||
const service = useColorPaletteService()
|
||||
|
||||
await service.loadColorPalette('custom')
|
||||
|
||||
expect(mockCanvas.default_connection_color_byType.IMAGE).toBe('#123456')
|
||||
expect(LGraphCanvas.link_type_colors.IMAGE).toBe('#123456')
|
||||
expect(
|
||||
document.documentElement.style.getPropertyValue('--color-datatype-IMAGE')
|
||||
).toBe('#123456')
|
||||
expect(
|
||||
document.documentElement.style.getPropertyValue(
|
||||
'--color-datatype-MISSING'
|
||||
)
|
||||
).toBe('')
|
||||
expect(mockCanvas.node_title_color).toBe('#abcdef')
|
||||
expect(mockCanvas.default_link_color).toBe('#fedcba')
|
||||
expect(mockCanvas.background_image).toBe('grid.png')
|
||||
expect(mockCanvas.clear_background_color).toBe('#010203')
|
||||
expect(mockCanvas._pattern).toBeUndefined()
|
||||
expect(LiteGraph.NODE_DEFAULT_SHAPE).toBe(RenderShape.ROUND)
|
||||
expect(warn).toHaveBeenCalledWith(
|
||||
`litegraph_base.NODE_DEFAULT_SHAPE only accepts [${[
|
||||
RenderShape.BOX,
|
||||
RenderShape.ROUND,
|
||||
RenderShape.CARD
|
||||
].join(', ')}] but got legacy`
|
||||
)
|
||||
expect(document.documentElement.style.getPropertyValue('--fg-color')).toBe(
|
||||
'#111111'
|
||||
)
|
||||
expect(
|
||||
document.documentElement.style.getPropertyValue('--contrast-mix-color')
|
||||
).toBe('var(--palette-contrast-mix-color)')
|
||||
expect(mockCanvas.setDirty).toHaveBeenCalledWith(true, true)
|
||||
expect(mockColorPaletteStore.activePaletteId).toBe('custom')
|
||||
})
|
||||
|
||||
it('skips absent palette sections while still activating the palette', async () => {
|
||||
const completedPalette = makeCompletedPalette()
|
||||
mockColorPaletteStore.completePalette.mockReturnValue(
|
||||
fromAny<CompletedPalette, unknown>({
|
||||
...completedPalette,
|
||||
colors: {
|
||||
node_slot: undefined,
|
||||
litegraph_base: completedPalette.colors.litegraph_base,
|
||||
comfy_base: undefined
|
||||
}
|
||||
})
|
||||
)
|
||||
const service = useColorPaletteService()
|
||||
|
||||
await service.loadColorPalette('custom')
|
||||
|
||||
expect(mockCanvas.node_title_color).toBe('#abcdef')
|
||||
expect(mockCanvas.setDirty).toHaveBeenCalledWith(true, true)
|
||||
expect(mockColorPaletteStore.activePaletteId).toBe('custom')
|
||||
})
|
||||
|
||||
it('removes Vue node theme overrides for built-in palettes', async () => {
|
||||
mockColorPaletteStore.palettesLookup = { dark: validCustomPalette }
|
||||
mockColorPaletteStore.completePalette.mockReturnValue(
|
||||
makeCompletedPalette('dark')
|
||||
)
|
||||
document.documentElement.style.setProperty(
|
||||
'--component-node-border',
|
||||
'#ffffff'
|
||||
)
|
||||
const service = useColorPaletteService()
|
||||
|
||||
await service.loadColorPalette('dark')
|
||||
|
||||
expect(
|
||||
document.documentElement.style.getPropertyValue('--component-node-border')
|
||||
).toBe('')
|
||||
})
|
||||
|
||||
it('removes Vue node theme variables when completed palette values are absent', async () => {
|
||||
const completedPalette = makeCompletedPalette()
|
||||
// NODE_BOX_OUTLINE_COLOR is required on the completed palette type; the
|
||||
// test needs it absent, so delete via Reflect to keep the type intact.
|
||||
Reflect.deleteProperty(
|
||||
completedPalette.colors.litegraph_base,
|
||||
'NODE_BOX_OUTLINE_COLOR'
|
||||
)
|
||||
mockColorPaletteStore.completePalette.mockReturnValue(completedPalette)
|
||||
document.documentElement.style.setProperty(
|
||||
'--component-node-border',
|
||||
'#ffffff'
|
||||
)
|
||||
const service = useColorPaletteService()
|
||||
|
||||
await service.loadColorPalette('custom')
|
||||
|
||||
expect(
|
||||
document.documentElement.style.getPropertyValue('--component-node-border')
|
||||
).toBe('')
|
||||
})
|
||||
|
||||
it('preserves numeric LiteGraph node shapes without warning', async () => {
|
||||
const warn = vi.spyOn(console, 'warn').mockImplementation(() => {})
|
||||
const completedPalette = makeCompletedPalette()
|
||||
completedPalette.colors.litegraph_base.NODE_DEFAULT_SHAPE = RenderShape.CARD
|
||||
mockColorPaletteStore.completePalette.mockReturnValue(completedPalette)
|
||||
const service = useColorPaletteService()
|
||||
|
||||
await service.loadColorPalette('custom')
|
||||
|
||||
expect(LiteGraph.NODE_DEFAULT_SHAPE).toBe(RenderShape.CARD)
|
||||
expect(warn).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('uses explicit optional comfy color values when present', async () => {
|
||||
const completedPalette = makeCompletedPalette()
|
||||
completedPalette.colors.comfy_base['contrast-mix-color'] = '#333333'
|
||||
mockColorPaletteStore.completePalette.mockReturnValue(completedPalette)
|
||||
const service = useColorPaletteService()
|
||||
|
||||
await service.loadColorPalette('custom')
|
||||
|
||||
expect(
|
||||
document.documentElement.style.getPropertyValue('--contrast-mix-color')
|
||||
).toBe('#333333')
|
||||
})
|
||||
|
||||
it('uses a white splash background for light themes', async () => {
|
||||
const completedPalette = makeCompletedPalette()
|
||||
completedPalette.light_theme = true
|
||||
mockColorPaletteStore.completePalette.mockReturnValue(completedPalette)
|
||||
const service = useColorPaletteService()
|
||||
|
||||
await service.loadColorPalette('custom')
|
||||
|
||||
expect(localStorage.getItem('comfy-splash-bg')).toBe('#FFFFFF')
|
||||
expect(localStorage.getItem('comfy-splash-fg')).toBe('#111111')
|
||||
})
|
||||
|
||||
it('uses transparent canvas background and bg image CSS when a background image setting exists', async () => {
|
||||
mockSettingStore.get.mockReturnValue('/custom/background.png')
|
||||
const service = useColorPaletteService()
|
||||
|
||||
await service.loadColorPalette('custom')
|
||||
|
||||
expect(mockCanvas.clear_background_color).toBe('transparent')
|
||||
expect(document.documentElement.style.getPropertyValue('--bg-img')).toBe(
|
||||
"url('/custom/background.png')"
|
||||
)
|
||||
})
|
||||
|
||||
it('throws when loading or exporting an unknown palette', async () => {
|
||||
mockColorPaletteStore.palettesLookup = {}
|
||||
const service = useColorPaletteService()
|
||||
|
||||
await expect(service.loadColorPalette('missing')).rejects.toThrow(
|
||||
'Color palette missing not found'
|
||||
)
|
||||
expect(() => service.exportColorPalette('missing')).toThrow(
|
||||
'Color palette missing not found'
|
||||
)
|
||||
})
|
||||
|
||||
it('exports palette JSON by id', async () => {
|
||||
const service = useColorPaletteService()
|
||||
|
||||
service.exportColorPalette('custom')
|
||||
|
||||
expect(mockDownloadBlob).toHaveBeenCalledOnce()
|
||||
const [filename, blob] = mockDownloadBlob.mock.calls[0] as [string, Blob]
|
||||
expect(filename).toBe('custom.json')
|
||||
await expect(blob.text()).resolves.toContain('"id": "custom"')
|
||||
})
|
||||
|
||||
it('imports palette JSON through the custom palette path', async () => {
|
||||
mockUploadFile.mockResolvedValue({
|
||||
text: () => Promise.resolve(JSON.stringify(validCustomPalette))
|
||||
})
|
||||
const service = useColorPaletteService()
|
||||
|
||||
const palette = await service.importColorPalette()
|
||||
|
||||
expect(mockUploadFile).toHaveBeenCalledWith('application/json')
|
||||
expect(palette).toEqual(validCustomPalette)
|
||||
expect(mockColorPaletteStore.addCustomPalette).toHaveBeenCalledWith(
|
||||
validCustomPalette
|
||||
)
|
||||
})
|
||||
|
||||
it('returns the completed active palette from the store', () => {
|
||||
const service = useColorPaletteService()
|
||||
|
||||
expect(service.getActiveColorPalette()).toBe(
|
||||
mockColorPaletteStore.completedActivePalette
|
||||
)
|
||||
})
|
||||
})
|
||||
324
src/services/extensionService.test.ts
Normal file
324
src/services/extensionService.test.ts
Normal file
@@ -0,0 +1,324 @@
|
||||
import { fromAny } from '@total-typescript/shoehorn'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import type { LGraphNode } from '@/lib/litegraph/src/litegraph'
|
||||
import type { AuthUserInfo } from '@/types/authTypes'
|
||||
import type { ComfyExtension } from '@/types/comfy'
|
||||
import { useExtensionService } from './extensionService'
|
||||
|
||||
const mockLoadDisabledExtensionNames = vi.hoisted(() => vi.fn())
|
||||
const mockRegisterExtension = vi.hoisted(() => vi.fn())
|
||||
const mockCaptureCoreExtensions = vi.hoisted(() => vi.fn())
|
||||
const mockEnabledExtensions = vi.hoisted(() => ({
|
||||
value: [] as ComfyExtension[]
|
||||
}))
|
||||
vi.mock('@/stores/extensionStore', () => ({
|
||||
useExtensionStore: () => ({
|
||||
loadDisabledExtensionNames: mockLoadDisabledExtensionNames,
|
||||
registerExtension: mockRegisterExtension,
|
||||
captureCoreExtensions: mockCaptureCoreExtensions,
|
||||
get enabledExtensions() {
|
||||
return mockEnabledExtensions.value
|
||||
}
|
||||
})
|
||||
}))
|
||||
|
||||
const mockGetSetting = vi.hoisted(() => vi.fn())
|
||||
const mockAddSetting = vi.hoisted(() => vi.fn())
|
||||
vi.mock('@/platform/settings/settingStore', () => ({
|
||||
useSettingStore: () => ({
|
||||
get: mockGetSetting,
|
||||
addSetting: mockAddSetting
|
||||
})
|
||||
}))
|
||||
|
||||
const mockAddDefaultKeybinding = vi.hoisted(() => vi.fn())
|
||||
vi.mock('@/platform/keybindings/keybindingStore', () => ({
|
||||
useKeybindingStore: () => ({
|
||||
addDefaultKeybinding: mockAddDefaultKeybinding
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/keybindings/keybinding', () => ({
|
||||
KeybindingImpl: class KeybindingImpl {
|
||||
constructor(readonly source: unknown) {}
|
||||
}
|
||||
}))
|
||||
|
||||
const mockLoadExtensionCommands = vi.hoisted(() => vi.fn())
|
||||
vi.mock('@/stores/commandStore', () => ({
|
||||
useCommandStore: () => ({
|
||||
loadExtensionCommands: mockLoadExtensionCommands
|
||||
})
|
||||
}))
|
||||
|
||||
const mockLoadExtensionMenuCommands = vi.hoisted(() => vi.fn())
|
||||
vi.mock('@/stores/menuItemStore', () => ({
|
||||
useMenuItemStore: () => ({
|
||||
loadExtensionMenuCommands: mockLoadExtensionMenuCommands
|
||||
})
|
||||
}))
|
||||
|
||||
const mockRegisterBottomPanelTabs = vi.hoisted(() => vi.fn())
|
||||
vi.mock('@/stores/workspace/bottomPanelStore', () => ({
|
||||
useBottomPanelStore: () => ({
|
||||
registerExtensionBottomPanelTabs: mockRegisterBottomPanelTabs
|
||||
})
|
||||
}))
|
||||
|
||||
const mockRegisterCustomWidgets = vi.hoisted(() => vi.fn())
|
||||
vi.mock('@/stores/widgetStore', () => ({
|
||||
useWidgetStore: () => ({
|
||||
registerCustomWidgets: mockRegisterCustomWidgets
|
||||
})
|
||||
}))
|
||||
|
||||
const mockToastErrorHandler = vi.hoisted(() => vi.fn())
|
||||
vi.mock('@/composables/useErrorHandling', () => ({
|
||||
useErrorHandling: () => ({
|
||||
wrapWithErrorHandling:
|
||||
<Args extends unknown[], Return>(fn: (...args: Args) => Return) =>
|
||||
(...args: Args) =>
|
||||
fn(...args),
|
||||
wrapWithErrorHandlingAsync:
|
||||
<Args extends unknown[], Return>(
|
||||
fn: (...args: Args) => Return | Promise<Return>,
|
||||
handler: (error: unknown) => void
|
||||
) =>
|
||||
async (...args: Args) => {
|
||||
try {
|
||||
return await fn(...args)
|
||||
} catch (error) {
|
||||
handler(error)
|
||||
}
|
||||
},
|
||||
toastErrorHandler: mockToastErrorHandler
|
||||
})
|
||||
}))
|
||||
|
||||
const mockUserResolvedCallbacks = vi.hoisted(() => ({
|
||||
values: [] as Array<(user: AuthUserInfo) => void>
|
||||
}))
|
||||
const mockTokenRefreshedCallbacks = vi.hoisted(() => ({
|
||||
values: [] as Array<() => void>
|
||||
}))
|
||||
const mockUserLogoutCallbacks = vi.hoisted(() => ({
|
||||
values: [] as Array<() => void>
|
||||
}))
|
||||
vi.mock('@/composables/auth/useCurrentUser', () => ({
|
||||
useCurrentUser: () => ({
|
||||
onUserResolved: (callback: (user: AuthUserInfo) => void) => {
|
||||
mockUserResolvedCallbacks.values.push(callback)
|
||||
},
|
||||
onTokenRefreshed: (callback: () => void) => {
|
||||
mockTokenRefreshedCallbacks.values.push(callback)
|
||||
},
|
||||
onUserLogout: (callback: () => void) => {
|
||||
mockUserLogoutCallbacks.values.push(callback)
|
||||
}
|
||||
})
|
||||
}))
|
||||
|
||||
const mockSetCurrentExtension = vi.hoisted(() => vi.fn())
|
||||
vi.mock('@/lib/litegraph/src/contextMenuCompat', () => ({
|
||||
legacyMenuCompat: {
|
||||
setCurrentExtension: mockSetCurrentExtension
|
||||
}
|
||||
}))
|
||||
|
||||
const mockApp = vi.hoisted(() => ({ value: { name: 'app' } }))
|
||||
vi.mock('@/scripts/app', () => ({
|
||||
app: mockApp.value
|
||||
}))
|
||||
|
||||
vi.mock('@/scripts/api', () => ({
|
||||
api: {
|
||||
getExtensions: vi.fn(),
|
||||
fileURL: vi.fn((path: string) => path)
|
||||
}
|
||||
}))
|
||||
|
||||
describe('useExtensionService', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockEnabledExtensions.value = []
|
||||
mockUserResolvedCallbacks.values = []
|
||||
mockTokenRefreshedCallbacks.values = []
|
||||
mockUserLogoutCallbacks.values = []
|
||||
})
|
||||
|
||||
it('registers extension contributions across stores', async () => {
|
||||
const widgets = { CustomWidget: vi.fn() }
|
||||
const extension = fromAny<ComfyExtension, unknown>({
|
||||
name: 'registration-extension',
|
||||
keybindings: [{ commandId: 'command.one', combo: { key: 'K' } }],
|
||||
commands: [{ id: 'command.one', label: 'Command One' }],
|
||||
menuCommands: [{ path: ['File'], commands: ['command.one'] }],
|
||||
settings: [{ id: 'setting.one', name: 'Setting One' }],
|
||||
bottomPanelTabs: [{ id: 'tab.one', title: 'Tab One' }],
|
||||
getCustomWidgets: vi.fn().mockResolvedValue(widgets)
|
||||
})
|
||||
const service = useExtensionService()
|
||||
|
||||
service.registerExtension(extension)
|
||||
|
||||
expect(mockRegisterExtension).toHaveBeenCalledWith(extension)
|
||||
expect(mockAddDefaultKeybinding).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
source: { commandId: 'command.one', combo: { key: 'K' } }
|
||||
})
|
||||
)
|
||||
expect(mockLoadExtensionCommands).toHaveBeenCalledWith(extension)
|
||||
expect(mockLoadExtensionMenuCommands).toHaveBeenCalledWith(extension)
|
||||
expect(mockAddSetting.mock.calls[0][0]).toEqual({
|
||||
id: 'setting.one',
|
||||
name: 'Setting One'
|
||||
})
|
||||
expect(mockRegisterBottomPanelTabs).toHaveBeenCalledWith(extension)
|
||||
await vi.waitFor(() => {
|
||||
expect(mockRegisterCustomWidgets).toHaveBeenCalledWith(widgets)
|
||||
})
|
||||
})
|
||||
|
||||
it('invokes auth lifecycle hooks through registered callbacks', async () => {
|
||||
const onAuthUserResolved = vi.fn()
|
||||
const onAuthTokenRefreshed = vi.fn()
|
||||
const onAuthUserLogout = vi.fn()
|
||||
const extension = fromAny<ComfyExtension, unknown>({
|
||||
name: 'auth-extension',
|
||||
onAuthUserResolved,
|
||||
onAuthTokenRefreshed,
|
||||
onAuthUserLogout
|
||||
})
|
||||
const user = fromAny<AuthUserInfo, unknown>({ id: 'user-1' })
|
||||
const service = useExtensionService()
|
||||
|
||||
service.registerExtension(extension)
|
||||
mockUserResolvedCallbacks.values[0](user)
|
||||
mockTokenRefreshedCallbacks.values[0]()
|
||||
mockUserLogoutCallbacks.values[0]()
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(onAuthUserResolved).toHaveBeenCalledWith(user, mockApp.value)
|
||||
expect(onAuthTokenRefreshed).toHaveBeenCalled()
|
||||
expect(onAuthUserLogout).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
it('reports auth hook errors through the toast handler', async () => {
|
||||
const error = new Error('auth failed')
|
||||
const extension = fromAny<ComfyExtension, unknown>({
|
||||
name: 'failing-auth-extension',
|
||||
onAuthUserResolved: vi.fn(() => {
|
||||
throw error
|
||||
})
|
||||
})
|
||||
const consoleError = vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
const service = useExtensionService()
|
||||
|
||||
service.registerExtension(extension)
|
||||
mockUserResolvedCallbacks.values[0](
|
||||
fromAny<AuthUserInfo, unknown>({ id: 'user-1' })
|
||||
)
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(mockToastErrorHandler).toHaveBeenCalledWith(error)
|
||||
})
|
||||
expect(consoleError).toHaveBeenCalledWith(
|
||||
'[Extension Auth Hook Error]',
|
||||
expect.objectContaining({
|
||||
extension: 'failing-auth-extension',
|
||||
hook: 'onAuthUserResolved',
|
||||
error
|
||||
})
|
||||
)
|
||||
consoleError.mockRestore()
|
||||
})
|
||||
|
||||
it('invokes synchronous extension methods and keeps failures isolated', () => {
|
||||
const getSelectionToolboxCommands = vi.fn(() => ['command.one'])
|
||||
const failingGetSelectionToolboxCommands = vi.fn(() => {
|
||||
throw new Error('menu failed')
|
||||
})
|
||||
const consoleError = vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
mockEnabledExtensions.value = [
|
||||
fromAny<ComfyExtension, unknown>({
|
||||
name: 'working-extension',
|
||||
getSelectionToolboxCommands
|
||||
}),
|
||||
fromAny<ComfyExtension, unknown>({
|
||||
name: 'non-function-extension',
|
||||
getSelectionToolboxCommands: ['not callable']
|
||||
}),
|
||||
fromAny<ComfyExtension, unknown>({
|
||||
name: 'failing-extension',
|
||||
getSelectionToolboxCommands: failingGetSelectionToolboxCommands
|
||||
}),
|
||||
{ name: 'missing-method-extension' }
|
||||
]
|
||||
const service = useExtensionService()
|
||||
|
||||
const results = service.invokeExtensions(
|
||||
'getSelectionToolboxCommands',
|
||||
fromAny<LGraphNode, unknown>({ id: 1 })
|
||||
)
|
||||
|
||||
expect(results).toEqual([['command.one']])
|
||||
expect(getSelectionToolboxCommands).toHaveBeenCalledWith(
|
||||
fromAny<LGraphNode, unknown>({ id: 1 }),
|
||||
mockApp.value
|
||||
)
|
||||
expect(consoleError).toHaveBeenCalledWith(
|
||||
"Error calling extension 'failing-extension' method 'getSelectionToolboxCommands'",
|
||||
expect.objectContaining({ error: expect.any(Error) }),
|
||||
expect.objectContaining({
|
||||
extension: expect.objectContaining({ name: 'failing-extension' })
|
||||
}),
|
||||
expect.objectContaining({
|
||||
args: [fromAny<LGraphNode, unknown>({ id: 1 })]
|
||||
})
|
||||
)
|
||||
consoleError.mockRestore()
|
||||
})
|
||||
|
||||
it('tracks current extension around async setup callbacks', async () => {
|
||||
const setup = vi.fn().mockResolvedValue('setup-result')
|
||||
const failingSetup = vi.fn().mockRejectedValue(new Error('setup failed'))
|
||||
const consoleError = vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
mockEnabledExtensions.value = [
|
||||
fromAny<ComfyExtension, unknown>({
|
||||
name: 'setup-extension',
|
||||
setup
|
||||
}),
|
||||
fromAny<ComfyExtension, unknown>({
|
||||
name: 'non-function-extension',
|
||||
setup: true
|
||||
}),
|
||||
fromAny<ComfyExtension, unknown>({
|
||||
name: 'failing-setup-extension',
|
||||
setup: failingSetup
|
||||
}),
|
||||
{ name: 'missing-method-extension' }
|
||||
]
|
||||
const service = useExtensionService()
|
||||
|
||||
const results = await service.invokeExtensionsAsync('setup')
|
||||
|
||||
expect(results).toEqual(['setup-result', undefined, undefined, undefined])
|
||||
expect(mockSetCurrentExtension.mock.calls.map((call) => call[0])).toEqual([
|
||||
'setup-extension',
|
||||
'failing-setup-extension',
|
||||
null,
|
||||
null
|
||||
])
|
||||
expect(consoleError).toHaveBeenCalledWith(
|
||||
"Error calling extension 'failing-setup-extension' method 'setup'",
|
||||
expect.objectContaining({ error: expect.any(Error) }),
|
||||
expect.objectContaining({
|
||||
extension: expect.objectContaining({ name: 'failing-setup-extension' })
|
||||
}),
|
||||
expect.objectContaining({ args: [] })
|
||||
)
|
||||
consoleError.mockRestore()
|
||||
})
|
||||
})
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,35 +1,181 @@
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { useMediaCache } from './mediaCacheService'
|
||||
import type { useMediaCache } from './mediaCacheService'
|
||||
|
||||
// Mock fetch
|
||||
global.fetch = vi.fn()
|
||||
global.URL = {
|
||||
createObjectURL: vi.fn(() => 'blob:mock-url'),
|
||||
revokeObjectURL: vi.fn()
|
||||
} as Partial<typeof URL> as typeof URL
|
||||
type MediaCache = ReturnType<typeof useMediaCache>
|
||||
|
||||
describe('mediaCacheService', () => {
|
||||
describe('URL reference counting', () => {
|
||||
it('should handle URL acquisition for non-existent cache entry', () => {
|
||||
const { acquireUrl } = useMediaCache()
|
||||
const mockFetch = vi.fn()
|
||||
const mockCreateObjectURL = vi.fn()
|
||||
const mockRevokeObjectURL = vi.fn()
|
||||
const NativeURL = URL
|
||||
|
||||
const url = acquireUrl('non-existent.jpg')
|
||||
expect(url).toBeUndefined()
|
||||
class MockURL extends NativeURL {
|
||||
static override createObjectURL(blob: Blob): string {
|
||||
return mockCreateObjectURL(blob)
|
||||
}
|
||||
|
||||
static override revokeObjectURL(url: string): void {
|
||||
mockRevokeObjectURL(url)
|
||||
}
|
||||
}
|
||||
|
||||
function response(ok: boolean, blob = new Blob(['image'])): Response {
|
||||
return {
|
||||
ok,
|
||||
status: ok ? 200 : 404,
|
||||
blob: () => Promise.resolve(blob)
|
||||
} as Response
|
||||
}
|
||||
|
||||
async function freshCache(options?: {
|
||||
maxSize?: number
|
||||
maxAge?: number
|
||||
}): Promise<MediaCache> {
|
||||
vi.resetModules()
|
||||
const { useMediaCache } = await import('./mediaCacheService')
|
||||
return useMediaCache(options)
|
||||
}
|
||||
|
||||
describe('useMediaCache', () => {
|
||||
beforeEach(() => {
|
||||
vi.useFakeTimers()
|
||||
vi.setSystemTime(0)
|
||||
mockFetch.mockReset()
|
||||
mockCreateObjectURL.mockReset()
|
||||
mockRevokeObjectURL.mockReset()
|
||||
mockCreateObjectURL.mockImplementation(
|
||||
(_blob: Blob) => `blob:${mockCreateObjectURL.mock.calls.length}`
|
||||
)
|
||||
vi.stubGlobal('fetch', mockFetch)
|
||||
vi.stubGlobal('URL', MockURL)
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
window.dispatchEvent(new Event('beforeunload'))
|
||||
vi.useRealTimers()
|
||||
vi.unstubAllGlobals()
|
||||
})
|
||||
|
||||
it('fetches media once and returns cached entries on later reads', async () => {
|
||||
mockFetch.mockResolvedValue(response(true))
|
||||
const cache = await freshCache()
|
||||
|
||||
const first = await cache.getCachedMedia('/image.png')
|
||||
vi.setSystemTime(100)
|
||||
const second = await cache.getCachedMedia('/image.png')
|
||||
|
||||
expect(first).toMatchObject({
|
||||
src: '/image.png',
|
||||
objectUrl: 'blob:1',
|
||||
isLoading: false
|
||||
})
|
||||
|
||||
it('should handle URL release for non-existent cache entry', () => {
|
||||
const { releaseUrl } = useMediaCache()
|
||||
|
||||
// Should not throw error
|
||||
expect(() => releaseUrl('non-existent.jpg')).not.toThrow()
|
||||
})
|
||||
|
||||
it('should provide acquireUrl and releaseUrl methods', () => {
|
||||
const cache = useMediaCache()
|
||||
|
||||
expect(typeof cache.acquireUrl).toBe('function')
|
||||
expect(typeof cache.releaseUrl).toBe('function')
|
||||
expect(second).toEqual(first)
|
||||
expect(second.lastAccessed).toBe(100)
|
||||
expect(mockFetch).toHaveBeenCalledOnce()
|
||||
expect(mockFetch).toHaveBeenCalledWith('/image.png', {
|
||||
cache: 'force-cache'
|
||||
})
|
||||
})
|
||||
|
||||
it('stores an error entry when fetch fails', async () => {
|
||||
const warn = vi.spyOn(console, 'warn').mockImplementation(() => {})
|
||||
mockFetch.mockResolvedValue(response(false))
|
||||
const cache = await freshCache()
|
||||
|
||||
const entry = await cache.getCachedMedia('/missing.png')
|
||||
|
||||
expect(entry).toMatchObject({
|
||||
src: '/missing.png',
|
||||
error: true,
|
||||
isLoading: false
|
||||
})
|
||||
expect(warn).toHaveBeenCalledWith(
|
||||
'Failed to cache media:',
|
||||
'/missing.png',
|
||||
expect.any(Error)
|
||||
)
|
||||
})
|
||||
|
||||
it('ref-counts acquired object URLs and removes the cache entry on final release', async () => {
|
||||
mockFetch.mockResolvedValue(response(true))
|
||||
const cache = await freshCache()
|
||||
await cache.getCachedMedia('/image.png')
|
||||
|
||||
expect(cache.acquireUrl('/image.png')).toBe('blob:1')
|
||||
expect(cache.acquireUrl('/image.png')).toBe('blob:1')
|
||||
cache.releaseUrl('/image.png')
|
||||
expect(mockRevokeObjectURL).not.toHaveBeenCalled()
|
||||
expect(cache.cache.has('/image.png')).toBe(true)
|
||||
|
||||
cache.releaseUrl('/image.png')
|
||||
|
||||
expect(mockRevokeObjectURL).toHaveBeenCalledWith('blob:1')
|
||||
expect(cache.cache.has('/image.png')).toBe(false)
|
||||
})
|
||||
|
||||
it('returns undefined when acquiring a URL that is not cached', async () => {
|
||||
const cache = await freshCache()
|
||||
|
||||
expect(cache.acquireUrl('/missing.png')).toBeUndefined()
|
||||
cache.releaseUrl('/missing.png')
|
||||
expect(mockRevokeObjectURL).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('expires old cache entries during scheduled cleanup', async () => {
|
||||
mockFetch.mockResolvedValue(response(true))
|
||||
const cache = await freshCache({ maxAge: 100 })
|
||||
await cache.getCachedMedia('/old.png')
|
||||
|
||||
vi.setSystemTime(200)
|
||||
vi.advanceTimersByTime(5 * 60 * 1000)
|
||||
|
||||
expect(cache.cache.has('/old.png')).toBe(false)
|
||||
expect(mockRevokeObjectURL).toHaveBeenCalledWith('blob:1')
|
||||
})
|
||||
|
||||
it('keeps expired entries while their object URL is still acquired', async () => {
|
||||
mockFetch.mockResolvedValue(response(true))
|
||||
const cache = await freshCache({ maxAge: 100 })
|
||||
await cache.getCachedMedia('/held.png')
|
||||
cache.acquireUrl('/held.png')
|
||||
|
||||
vi.setSystemTime(200)
|
||||
vi.advanceTimersByTime(5 * 60 * 1000)
|
||||
|
||||
expect(cache.cache.has('/held.png')).toBe(true)
|
||||
expect(mockRevokeObjectURL).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('removes the oldest unused entries when the cache is over size', async () => {
|
||||
mockFetch.mockResolvedValue(response(true))
|
||||
const cache = await freshCache({ maxSize: 1, maxAge: 1_000_000 })
|
||||
await cache.getCachedMedia('/old.png')
|
||||
vi.setSystemTime(1)
|
||||
await cache.getCachedMedia('/new.png')
|
||||
|
||||
vi.advanceTimersByTime(5 * 60 * 1000)
|
||||
|
||||
expect(cache.cache.has('/old.png')).toBe(false)
|
||||
expect(cache.cache.has('/new.png')).toBe(true)
|
||||
expect(mockRevokeObjectURL).toHaveBeenCalledWith('blob:1')
|
||||
})
|
||||
|
||||
it('clears all cached URLs on demand and before unload', async () => {
|
||||
mockFetch.mockResolvedValue(response(true))
|
||||
const cache = await freshCache()
|
||||
await cache.getCachedMedia('/first.png')
|
||||
await cache.getCachedMedia('/second.png')
|
||||
|
||||
cache.clearCache()
|
||||
|
||||
expect(cache.cache.size).toBe(0)
|
||||
expect(mockRevokeObjectURL).toHaveBeenCalledWith('blob:1')
|
||||
expect(mockRevokeObjectURL).toHaveBeenCalledWith('blob:2')
|
||||
|
||||
await cache.getCachedMedia('/third.png')
|
||||
window.dispatchEvent(new Event('beforeunload'))
|
||||
|
||||
expect(cache.cache.size).toBe(0)
|
||||
expect(mockRevokeObjectURL).toHaveBeenCalledWith('blob:3')
|
||||
})
|
||||
})
|
||||
|
||||
162
src/services/nodeHelpService.test.ts
Normal file
162
src/services/nodeHelpService.test.ts
Normal file
@@ -0,0 +1,162 @@
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { api } from '@/scripts/api'
|
||||
import { nodeHelpService } from '@/services/nodeHelpService'
|
||||
import type { ComfyNodeDefImpl } from '@/stores/nodeDefStore'
|
||||
|
||||
vi.mock('@/scripts/api', () => ({
|
||||
api: {
|
||||
fileURL: vi.fn((path: string) => `/files${path}`)
|
||||
}
|
||||
}))
|
||||
|
||||
describe('nodeHelpService', () => {
|
||||
const mockFetch = vi.fn<typeof fetch>()
|
||||
|
||||
function nodeDef(options: {
|
||||
name?: string
|
||||
python_module?: string
|
||||
description?: string
|
||||
}): ComfyNodeDefImpl {
|
||||
return {
|
||||
name: options.name ?? 'TestNode',
|
||||
display_name: options.name ?? 'Test Node',
|
||||
category: 'test',
|
||||
python_module: options.python_module ?? 'nodes',
|
||||
description: options.description ?? ''
|
||||
} as ComfyNodeDefImpl
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockFetch.mockReset()
|
||||
vi.stubGlobal('fetch', mockFetch)
|
||||
})
|
||||
|
||||
it('returns blueprint descriptions without fetching markdown', async () => {
|
||||
const help = await nodeHelpService.fetchNodeHelp(
|
||||
nodeDef({
|
||||
python_module: 'blueprint',
|
||||
description: 'Saved workflow help'
|
||||
}),
|
||||
'en'
|
||||
)
|
||||
|
||||
expect(help).toBe('Saved workflow help')
|
||||
expect(mockFetch).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('fetches localized core node markdown', async () => {
|
||||
mockFetch.mockResolvedValueOnce(
|
||||
new Response('Core help', {
|
||||
headers: { 'content-type': 'text/markdown' }
|
||||
})
|
||||
)
|
||||
|
||||
const help = await nodeHelpService.fetchNodeHelp(
|
||||
nodeDef({ name: 'LoadImage' }),
|
||||
'zh'
|
||||
)
|
||||
|
||||
expect(api.fileURL).toHaveBeenCalledWith('/docs/LoadImage/zh.md')
|
||||
expect(mockFetch).toHaveBeenCalledWith('/files/docs/LoadImage/zh.md')
|
||||
expect(help).toBe('Core help')
|
||||
})
|
||||
|
||||
it('rejects core node HTML fallbacks', async () => {
|
||||
mockFetch.mockResolvedValueOnce(
|
||||
new Response('<html></html>', {
|
||||
headers: { 'content-type': 'text/html' },
|
||||
statusText: 'OK'
|
||||
})
|
||||
)
|
||||
|
||||
await expect(
|
||||
nodeHelpService.fetchNodeHelp(nodeDef({ name: 'PreviewImage' }), 'en')
|
||||
).rejects.toThrow('OK')
|
||||
})
|
||||
|
||||
it('uses the default missing-help error for empty markdown responses', async () => {
|
||||
mockFetch.mockResolvedValueOnce(new Response(''))
|
||||
|
||||
await expect(
|
||||
nodeHelpService.fetchNodeHelp(nodeDef({ name: 'EmptyHelp' }), 'en')
|
||||
).rejects.toThrow('Help not found')
|
||||
})
|
||||
|
||||
it('fetches custom node localized markdown before fallback markdown', async () => {
|
||||
mockFetch.mockResolvedValueOnce(
|
||||
new Response('Custom localized help', {
|
||||
headers: { 'content-type': 'text/markdown' }
|
||||
})
|
||||
)
|
||||
|
||||
const help = await nodeHelpService.fetchNodeHelp(
|
||||
nodeDef({
|
||||
name: 'CustomNode',
|
||||
python_module: 'custom_nodes.ComfyUI-TestPack@1.2.3.nodes'
|
||||
}),
|
||||
'ja'
|
||||
)
|
||||
|
||||
expect(api.fileURL).toHaveBeenCalledWith(
|
||||
'/extensions/ComfyUI-TestPack/docs/CustomNode/ja.md'
|
||||
)
|
||||
expect(mockFetch).toHaveBeenCalledTimes(1)
|
||||
expect(help).toBe('Custom localized help')
|
||||
})
|
||||
|
||||
it('falls back to custom node default markdown after locale miss', async () => {
|
||||
mockFetch
|
||||
.mockResolvedValueOnce(
|
||||
new Response('Not found', {
|
||||
status: 404,
|
||||
statusText: 'Locale missing'
|
||||
})
|
||||
)
|
||||
.mockResolvedValueOnce(
|
||||
new Response('Custom fallback help', {
|
||||
headers: { 'content-type': 'text/markdown' }
|
||||
})
|
||||
)
|
||||
|
||||
const help = await nodeHelpService.fetchNodeHelp(
|
||||
nodeDef({
|
||||
name: 'CustomNode',
|
||||
python_module: 'custom_nodes.TestPack.nodes'
|
||||
}),
|
||||
'fr'
|
||||
)
|
||||
|
||||
expect(api.fileURL).toHaveBeenNthCalledWith(
|
||||
1,
|
||||
'/extensions/TestPack/docs/CustomNode/fr.md'
|
||||
)
|
||||
expect(api.fileURL).toHaveBeenNthCalledWith(
|
||||
2,
|
||||
'/extensions/TestPack/docs/CustomNode.md'
|
||||
)
|
||||
expect(help).toBe('Custom fallback help')
|
||||
})
|
||||
|
||||
it('reports the locale miss when the custom fallback is empty', async () => {
|
||||
mockFetch
|
||||
.mockResolvedValueOnce(
|
||||
new Response('Not found', {
|
||||
status: 404,
|
||||
statusText: 'Locale missing'
|
||||
})
|
||||
)
|
||||
.mockResolvedValueOnce(new Response(''))
|
||||
|
||||
await expect(
|
||||
nodeHelpService.fetchNodeHelp(
|
||||
nodeDef({
|
||||
name: 'CustomNode',
|
||||
python_module: 'custom_nodes.TestPack.nodes'
|
||||
}),
|
||||
'de'
|
||||
)
|
||||
).rejects.toThrow('Locale missing')
|
||||
})
|
||||
})
|
||||
@@ -124,6 +124,80 @@ describe('nodeOrganizationService', () => {
|
||||
})
|
||||
})
|
||||
|
||||
describe('organizeNodesTab', () => {
|
||||
it('returns no sections for an empty node list', () => {
|
||||
expect(nodeOrganizationService.organizeNodesTab([])).toEqual([])
|
||||
})
|
||||
|
||||
it('classifies blueprints, partner nodes, Comfy nodes, and extensions', () => {
|
||||
const sections = nodeOrganizationService.organizeNodesTab([
|
||||
createMockNodeDef({
|
||||
name: 'MyBlueprint',
|
||||
nodeSource: {
|
||||
type: NodeSourceType.Blueprint,
|
||||
className: 'blueprint',
|
||||
displayText: 'Blueprint',
|
||||
badgeText: 'B'
|
||||
},
|
||||
isGlobal: false
|
||||
}),
|
||||
createMockNodeDef({
|
||||
name: 'ComfyBlueprint',
|
||||
python_module: 'blueprint.comfy',
|
||||
isGlobal: true
|
||||
}),
|
||||
createMockNodeDef({
|
||||
name: 'PartnerApi',
|
||||
category: 'api node/image'
|
||||
}),
|
||||
createMockNodeDef({
|
||||
name: 'CoreNode',
|
||||
nodeSource: {
|
||||
type: NodeSourceType.Core,
|
||||
className: 'core',
|
||||
displayText: 'Core',
|
||||
badgeText: 'C'
|
||||
}
|
||||
}),
|
||||
createMockNodeDef({
|
||||
name: 'EssentialNode',
|
||||
nodeSource: {
|
||||
type: NodeSourceType.Essentials,
|
||||
className: 'essentials',
|
||||
displayText: 'Essentials',
|
||||
badgeText: 'E'
|
||||
}
|
||||
}),
|
||||
createMockNodeDef({
|
||||
name: 'ExtensionNode'
|
||||
})
|
||||
])
|
||||
|
||||
expect(sections.map((section) => section.category)).toEqual([
|
||||
'blueprints',
|
||||
'comfyNodes',
|
||||
'partnerNodes',
|
||||
'extensions'
|
||||
])
|
||||
expect(sections[0].tree.children?.map((child) => child.key)).toEqual([
|
||||
'root/my-blueprints',
|
||||
'root/comfy-blueprints'
|
||||
])
|
||||
})
|
||||
|
||||
it('omits sections that have no matching nodes', () => {
|
||||
const sections = nodeOrganizationService.organizeNodesTab([
|
||||
createMockNodeDef({
|
||||
name: 'OnlyExtension'
|
||||
})
|
||||
])
|
||||
|
||||
expect(sections.map((section) => section.category)).toEqual([
|
||||
'extensions'
|
||||
])
|
||||
})
|
||||
})
|
||||
|
||||
describe('getGroupingIcon', () => {
|
||||
it('should return strategy icon', () => {
|
||||
const icon = nodeOrganizationService.getGroupingIcon('category')
|
||||
@@ -216,6 +290,12 @@ describe('nodeOrganizationService', () => {
|
||||
expect(path).toEqual(['core', 'TestNode'])
|
||||
})
|
||||
|
||||
it('should handle custom_nodes without a package segment', () => {
|
||||
const nodeDef = createMockNodeDef({ python_module: 'custom_nodes' })
|
||||
const path = strategy?.getNodePath(nodeDef)
|
||||
expect(path).toEqual(['custom_nodes', 'TestNode'])
|
||||
})
|
||||
|
||||
it('should handle non-standard module paths', () => {
|
||||
const nodeDef = createMockNodeDef({
|
||||
python_module: 'some.other.module.path'
|
||||
@@ -282,6 +362,19 @@ describe('nodeOrganizationService', () => {
|
||||
const path = strategy?.getNodePath(nodeDef)
|
||||
expect(path).toEqual(['Unknown', 'TestNode'])
|
||||
})
|
||||
|
||||
it('should handle core source type', () => {
|
||||
const nodeDef = createMockNodeDef({
|
||||
nodeSource: {
|
||||
type: NodeSourceType.Core,
|
||||
className: 'core',
|
||||
displayText: 'Core',
|
||||
badgeText: 'C'
|
||||
}
|
||||
})
|
||||
const path = strategy?.getNodePath(nodeDef)
|
||||
expect(path).toEqual(['Core', 'TestNode'])
|
||||
})
|
||||
})
|
||||
|
||||
describe('node name edge cases', () => {
|
||||
@@ -326,5 +419,14 @@ describe('nodeOrganizationService', () => {
|
||||
expect(strategy?.compare(nodeA, nodeB)).toBeGreaterThan(0)
|
||||
expect(strategy?.compare(nodeB, nodeA)).toBeLessThan(0)
|
||||
})
|
||||
|
||||
it('alphabetical sort handles missing display names', () => {
|
||||
const strategy =
|
||||
nodeOrganizationService.getSortingStrategy('alphabetical')
|
||||
const nodeA = createMockNodeDef({ display_name: undefined })
|
||||
const nodeB = createMockNodeDef({ display_name: undefined })
|
||||
|
||||
expect(strategy?.compare(nodeA, nodeB)).toBe(0)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
172
src/services/subgraphService.test.ts
Normal file
172
src/services/subgraphService.test.ts
Normal file
@@ -0,0 +1,172 @@
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import type {
|
||||
ExportedSubgraph,
|
||||
ExportedSubgraphInstance,
|
||||
Subgraph
|
||||
} from '@/lib/litegraph/src/litegraph'
|
||||
import type { ComfyWorkflowJSON } from '@/platform/workflow/validation/schemas/workflowSchema'
|
||||
import type { ComfyNodeDef as ComfyNodeDefV1 } from '@/schemas/nodeDefSchema'
|
||||
|
||||
const mocks = vi.hoisted(() => ({
|
||||
addNodeDef: vi.fn(),
|
||||
createSubgraph: vi.fn((subgraph: unknown) => ({
|
||||
createdFrom: subgraph
|
||||
})),
|
||||
registerSubgraphNodeDef: vi.fn(),
|
||||
subgraphs: new Map<string, unknown>()
|
||||
}))
|
||||
|
||||
vi.mock('@/stores/nodeDefStore', () => ({
|
||||
useNodeDefStore: () => ({
|
||||
addNodeDef: mocks.addNodeDef
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('@/scripts/app', () => ({
|
||||
app: {
|
||||
rootGraph: {
|
||||
subgraphs: mocks.subgraphs,
|
||||
createSubgraph: mocks.createSubgraph
|
||||
}
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('./litegraphService', () => ({
|
||||
useLitegraphService: () => ({
|
||||
registerSubgraphNodeDef: mocks.registerSubgraphNodeDef
|
||||
})
|
||||
}))
|
||||
|
||||
const { useSubgraphService } = await import('./subgraphService')
|
||||
|
||||
function createExportedSubgraph(
|
||||
overrides: Partial<ExportedSubgraph> = {}
|
||||
): ExportedSubgraph {
|
||||
return {
|
||||
id: 'subgraph-1',
|
||||
name: 'Test Subgraph',
|
||||
...overrides
|
||||
} as ExportedSubgraph
|
||||
}
|
||||
|
||||
function createWorkflow(subgraphs?: ExportedSubgraph[]): ComfyWorkflowJSON {
|
||||
return {
|
||||
definitions: subgraphs
|
||||
? {
|
||||
subgraphs
|
||||
}
|
||||
: undefined
|
||||
} as unknown as ComfyWorkflowJSON
|
||||
}
|
||||
|
||||
describe('useSubgraphService', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mocks.subgraphs.clear()
|
||||
})
|
||||
|
||||
it('registers a new subgraph node definition', () => {
|
||||
const service = useSubgraphService()
|
||||
const subgraph = { id: 'runtime-subgraph' } as unknown as Subgraph
|
||||
const exportedSubgraph = createExportedSubgraph()
|
||||
|
||||
service.registerNewSubgraph(subgraph, exportedSubgraph)
|
||||
|
||||
expect(mocks.addNodeDef).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
name: 'subgraph-1',
|
||||
display_name: 'Test Subgraph',
|
||||
description: 'Subgraph node for Test Subgraph',
|
||||
category: 'subgraph',
|
||||
output_node: false,
|
||||
python_module: 'nodes'
|
||||
})
|
||||
)
|
||||
|
||||
const [nodeDef, registeredSubgraph, instanceData] = mocks
|
||||
.registerSubgraphNodeDef.mock.calls[0] as [
|
||||
ComfyNodeDefV1,
|
||||
Subgraph,
|
||||
ExportedSubgraphInstance
|
||||
]
|
||||
|
||||
expect(nodeDef.name).toBe('subgraph-1')
|
||||
expect(registeredSubgraph).toBe(subgraph)
|
||||
expect(instanceData).toMatchObject({
|
||||
id: -1,
|
||||
type: 'subgraph-1',
|
||||
pos: [0, 0],
|
||||
size: [100, 100],
|
||||
inputs: [],
|
||||
outputs: []
|
||||
})
|
||||
})
|
||||
|
||||
it('uses an exported description when present', () => {
|
||||
const service = useSubgraphService()
|
||||
const subgraph = { id: 'runtime-subgraph' } as unknown as Subgraph
|
||||
|
||||
service.registerNewSubgraph(
|
||||
subgraph,
|
||||
createExportedSubgraph({
|
||||
description: 'Reusable workflow section'
|
||||
})
|
||||
)
|
||||
|
||||
expect(mocks.addNodeDef).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
description: 'Reusable workflow section'
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('does nothing when workflow data has no subgraph definitions', () => {
|
||||
const service = useSubgraphService()
|
||||
|
||||
service.loadSubgraphs(createWorkflow())
|
||||
|
||||
expect(mocks.addNodeDef).not.toHaveBeenCalled()
|
||||
expect(mocks.registerSubgraphNodeDef).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('registers existing root graph subgraphs from workflow data', () => {
|
||||
const service = useSubgraphService()
|
||||
const subgraph = { id: 'existing-subgraph' } as unknown as Subgraph
|
||||
const exportedSubgraph = createExportedSubgraph()
|
||||
mocks.subgraphs.set('subgraph-1', subgraph)
|
||||
|
||||
service.loadSubgraphs(createWorkflow([exportedSubgraph]))
|
||||
|
||||
expect(mocks.createSubgraph).not.toHaveBeenCalled()
|
||||
expect(mocks.registerSubgraphNodeDef).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
name: 'subgraph-1'
|
||||
}),
|
||||
subgraph,
|
||||
expect.objectContaining({
|
||||
type: 'subgraph-1'
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('creates missing root graph subgraphs from workflow data', () => {
|
||||
const service = useSubgraphService()
|
||||
const exportedSubgraph = createExportedSubgraph()
|
||||
|
||||
service.loadSubgraphs(createWorkflow([exportedSubgraph]))
|
||||
|
||||
expect(mocks.createSubgraph).toHaveBeenCalledWith(exportedSubgraph)
|
||||
expect(mocks.registerSubgraphNodeDef).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
name: 'subgraph-1'
|
||||
}),
|
||||
expect.objectContaining({
|
||||
createdFrom: exportedSubgraph
|
||||
}),
|
||||
expect.objectContaining({
|
||||
type: 'subgraph-1'
|
||||
})
|
||||
)
|
||||
})
|
||||
})
|
||||
300
src/stores/assetExportStore.test.ts
Normal file
300
src/stores/assetExportStore.test.ts
Normal file
@@ -0,0 +1,300 @@
|
||||
import { createPinia, setActivePinia } from 'pinia'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import type * as VueUse from '@vueuse/core'
|
||||
|
||||
import type { AssetExportWsMessage } from '@/schemas/apiSchema'
|
||||
import { api } from '@/scripts/api'
|
||||
import type { TaskId } from '@/platform/tasks/services/taskService'
|
||||
import { useAssetExportStore } from '@/stores/assetExportStore'
|
||||
|
||||
const { getExportDownloadUrl, getTask, toastAdd, intervalState } = vi.hoisted(
|
||||
() => ({
|
||||
getExportDownloadUrl: vi.fn(),
|
||||
getTask: vi.fn(),
|
||||
toastAdd: vi.fn(),
|
||||
intervalState: { cb: null as null | (() => void) }
|
||||
})
|
||||
)
|
||||
|
||||
vi.mock('@vueuse/core', async (importOriginal) => ({
|
||||
...(await importOriginal<typeof VueUse>()),
|
||||
useIntervalFn: (cb: () => void) => {
|
||||
intervalState.cb = cb
|
||||
return { pause: vi.fn(), resume: vi.fn() }
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('@/scripts/api', () => ({
|
||||
api: { addEventListener: vi.fn() }
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/assets/services/assetService', () => ({
|
||||
assetService: { getExportDownloadUrl }
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/tasks/services/taskService', () => ({
|
||||
taskService: { getTask }
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/updates/common/toastStore', () => ({
|
||||
useToastStore: () => ({ add: toastAdd })
|
||||
}))
|
||||
|
||||
vi.mock('@/i18n', () => ({
|
||||
t: (key: string) => key
|
||||
}))
|
||||
|
||||
function wsMessage(
|
||||
over: Partial<AssetExportWsMessage> = {}
|
||||
): AssetExportWsMessage {
|
||||
return {
|
||||
task_id: 'task-1',
|
||||
export_name: 'export.zip',
|
||||
assets_total: 10,
|
||||
assets_attempted: 5,
|
||||
assets_failed: 0,
|
||||
bytes_total: 1000,
|
||||
bytes_processed: 500,
|
||||
progress: 0.5,
|
||||
status: 'running',
|
||||
...over
|
||||
}
|
||||
}
|
||||
|
||||
const taskId = (id: string) => id as TaskId
|
||||
|
||||
/**
|
||||
* Build a store and an `emit` bound to the real `asset_export` listener the
|
||||
* store registers on `api`, so tests drive the state machine through its
|
||||
* actual entry point rather than a private method.
|
||||
*/
|
||||
function setup() {
|
||||
const store = useAssetExportStore()
|
||||
const entry = vi
|
||||
.mocked(api.addEventListener)
|
||||
.mock.calls.find((c) => c[0] === 'asset_export')
|
||||
const handler = entry![1] as (e: { detail: AssetExportWsMessage }) => void
|
||||
const emit = (msg: AssetExportWsMessage) => handler({ detail: msg })
|
||||
// Run the polling tick that `useIntervalFn` would normally fire, and let its
|
||||
// async work settle.
|
||||
const runPoll = async () => {
|
||||
intervalState.cb?.()
|
||||
await new Promise((resolve) => setTimeout(resolve, 0))
|
||||
}
|
||||
return { store, emit, runPoll }
|
||||
}
|
||||
|
||||
const STALE_AGO_MS = 20_000
|
||||
|
||||
beforeEach(() => {
|
||||
setActivePinia(createPinia())
|
||||
vi.mocked(api.addEventListener).mockClear()
|
||||
getExportDownloadUrl
|
||||
.mockReset()
|
||||
.mockResolvedValue({ url: 'https://example.com/export.zip' })
|
||||
getTask.mockReset()
|
||||
toastAdd.mockReset()
|
||||
})
|
||||
|
||||
describe('assetExportStore', () => {
|
||||
it('tracks a new export as created and is idempotent', () => {
|
||||
const { store } = setup()
|
||||
|
||||
store.trackExport(taskId('t1'))
|
||||
store.trackExport(taskId('t1'))
|
||||
|
||||
expect(store.exportList).toHaveLength(1)
|
||||
expect(store.exportList[0].status).toBe('created')
|
||||
expect(store.hasExports).toBe(true)
|
||||
expect(store.hasActiveExports).toBe(true)
|
||||
})
|
||||
|
||||
it('separates active from finished exports by status', () => {
|
||||
const { store, emit } = setup()
|
||||
|
||||
emit(wsMessage({ task_id: 'running', status: 'running' }))
|
||||
emit(
|
||||
wsMessage({ task_id: 'failed', status: 'failed', export_name: 'f.zip' })
|
||||
)
|
||||
|
||||
expect(store.activeExports.map((e) => e.taskId)).toEqual(['running'])
|
||||
expect(store.finishedExports.map((e) => e.taskId)).toEqual(['failed'])
|
||||
})
|
||||
|
||||
it('updates an export from successive websocket messages', () => {
|
||||
const { store, emit } = setup()
|
||||
|
||||
emit(wsMessage({ progress: 0.5, status: 'running' }))
|
||||
emit(wsMessage({ progress: 0.9, status: 'running' }))
|
||||
|
||||
expect(store.exportList).toHaveLength(1)
|
||||
expect(store.exportList[0].progress).toBe(0.9)
|
||||
})
|
||||
|
||||
it('ignores updates for an export already completed and downloaded', async () => {
|
||||
const { store, emit } = setup()
|
||||
|
||||
emit(wsMessage({ status: 'completed' }))
|
||||
await Promise.resolve()
|
||||
const triggeredCalls = getExportDownloadUrl.mock.calls.length
|
||||
|
||||
// A late 'running' message must not revive a completed+downloaded export
|
||||
emit(wsMessage({ status: 'running', progress: 0.1 }))
|
||||
|
||||
expect(store.exportList[0].status).toBe('completed')
|
||||
expect(getExportDownloadUrl).toHaveBeenCalledTimes(triggeredCalls)
|
||||
})
|
||||
|
||||
it('falls back to the prior export name when a message omits it', async () => {
|
||||
const { store, emit } = setup()
|
||||
|
||||
emit(wsMessage({ status: 'running', progress: 0.4 }))
|
||||
emit(
|
||||
wsMessage({ status: 'running', export_name: undefined, progress: 0.6 })
|
||||
)
|
||||
|
||||
expect(store.exportList[0].exportName).toBe('export.zip')
|
||||
})
|
||||
|
||||
it('falls back to a blank export name when no message has named it', () => {
|
||||
const { store, emit } = setup()
|
||||
|
||||
emit(wsMessage({ export_name: undefined, status: 'running' }))
|
||||
|
||||
expect(store.exportList[0].exportName).toBe('')
|
||||
})
|
||||
|
||||
it('triggers a download for a named export and clears prior errors', async () => {
|
||||
const { store, emit } = setup()
|
||||
emit(wsMessage({ status: 'running' }))
|
||||
const [exp] = store.exportList
|
||||
|
||||
await store.triggerDownload(exp)
|
||||
|
||||
expect(getExportDownloadUrl).toHaveBeenCalledWith('export.zip')
|
||||
expect(exp.downloadTriggered).toBe(true)
|
||||
expect(exp.downloadError).toBeUndefined()
|
||||
})
|
||||
|
||||
it('does not re-trigger a download unless forced', async () => {
|
||||
const { store, emit } = setup()
|
||||
emit(wsMessage({ status: 'running' }))
|
||||
const [exp] = store.exportList
|
||||
exp.downloadTriggered = true
|
||||
|
||||
await store.triggerDownload(exp)
|
||||
expect(getExportDownloadUrl).not.toHaveBeenCalled()
|
||||
|
||||
await store.triggerDownload(exp, true)
|
||||
expect(getExportDownloadUrl).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('records a download error and surfaces a toast on failure', async () => {
|
||||
getExportDownloadUrl.mockRejectedValueOnce(new Error('network down'))
|
||||
const { store, emit } = setup()
|
||||
emit(wsMessage({ status: 'running' }))
|
||||
const [exp] = store.exportList
|
||||
|
||||
await store.triggerDownload(exp)
|
||||
|
||||
expect(exp.downloadError).toBe('network down')
|
||||
expect(exp.downloadTriggered).toBe(false)
|
||||
expect(toastAdd).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ severity: 'error' })
|
||||
)
|
||||
})
|
||||
|
||||
it('records a string download error', async () => {
|
||||
getExportDownloadUrl.mockRejectedValueOnce('offline')
|
||||
const { store, emit } = setup()
|
||||
emit(wsMessage({ status: 'running' }))
|
||||
const [exp] = store.exportList
|
||||
|
||||
await store.triggerDownload(exp)
|
||||
|
||||
expect(exp.downloadError).toBe('offline')
|
||||
})
|
||||
|
||||
it('clears finished exports while keeping active ones', () => {
|
||||
const { store, emit } = setup()
|
||||
emit(wsMessage({ task_id: 'a', status: 'running' }))
|
||||
emit(wsMessage({ task_id: 'b', status: 'failed', export_name: 'b.zip' }))
|
||||
|
||||
store.clearFinishedExports()
|
||||
|
||||
expect(store.exportList.map((e) => e.taskId)).toEqual(['a'])
|
||||
})
|
||||
|
||||
it('does not poll when no active export is stale', async () => {
|
||||
const { emit, runPoll } = setup()
|
||||
emit(wsMessage({ status: 'running' }))
|
||||
|
||||
await runPoll()
|
||||
|
||||
expect(getTask).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('reconciles a stale export from the task service result', async () => {
|
||||
const { store, emit, runPoll } = setup()
|
||||
emit(wsMessage({ status: 'running' }))
|
||||
store.exportList[0].lastUpdate = Date.now() - STALE_AGO_MS
|
||||
getTask.mockResolvedValue({
|
||||
status: 'completed',
|
||||
result: { export_name: 'reconciled.zip', assets_total: 10 }
|
||||
})
|
||||
|
||||
await runPoll()
|
||||
|
||||
expect(getTask).toHaveBeenCalledWith('task-1')
|
||||
expect(store.exportList[0].status).toBe('completed')
|
||||
expect(store.exportList[0].exportName).toBe('reconciled.zip')
|
||||
})
|
||||
|
||||
it('leaves a stale export active when the task is still running', async () => {
|
||||
const { store, emit, runPoll } = setup()
|
||||
emit(wsMessage({ status: 'running' }))
|
||||
store.exportList[0].lastUpdate = Date.now() - STALE_AGO_MS
|
||||
getTask.mockResolvedValue({ status: 'running' })
|
||||
|
||||
await runPoll()
|
||||
|
||||
expect(store.exportList[0].status).toBe('running')
|
||||
})
|
||||
|
||||
it('reconciles a stale failed export using existing counters', async () => {
|
||||
const { store, emit, runPoll } = setup()
|
||||
emit(
|
||||
wsMessage({
|
||||
assets_attempted: 4,
|
||||
assets_failed: 1,
|
||||
status: 'running'
|
||||
})
|
||||
)
|
||||
store.exportList[0].lastUpdate = Date.now() - STALE_AGO_MS
|
||||
getTask.mockResolvedValue({
|
||||
status: 'failed',
|
||||
result: { error: 'failed in result' }
|
||||
})
|
||||
|
||||
await runPoll()
|
||||
|
||||
expect(store.exportList[0]).toMatchObject({
|
||||
assetsAttempted: 4,
|
||||
assetsFailed: 1,
|
||||
error: 'failed in result',
|
||||
status: 'failed'
|
||||
})
|
||||
})
|
||||
|
||||
it('leaves a stale export untouched when the task lookup fails', async () => {
|
||||
const { store, emit, runPoll } = setup()
|
||||
emit(wsMessage({ status: 'running' }))
|
||||
store.exportList[0].lastUpdate = Date.now() - STALE_AGO_MS
|
||||
getTask.mockRejectedValue(new Error('task not found'))
|
||||
|
||||
await runPoll()
|
||||
|
||||
expect(store.exportList[0].status).toBe('running')
|
||||
})
|
||||
})
|
||||
@@ -2,8 +2,10 @@ import { fromAny } from '@total-typescript/shoehorn'
|
||||
import { createPinia, setActivePinia } from 'pinia'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import type { LGraphNode } from '@/lib/litegraph/src/litegraph'
|
||||
import type { MissingNodeType } from '@/types/comfy'
|
||||
import { createNodeExecutionId } from '@/types/nodeIdentification'
|
||||
import type { NodeLocatorId } from '@/types/nodeIdentification'
|
||||
|
||||
// Mock dependencies
|
||||
vi.mock('@/i18n', () => ({
|
||||
@@ -15,6 +17,53 @@ vi.mock('@/platform/distribution/types', () => ({
|
||||
}))
|
||||
|
||||
const mockShowErrorsTab = vi.hoisted(() => ({ value: false }))
|
||||
const {
|
||||
mockApp,
|
||||
mockCanvasStore,
|
||||
mockExecutionIdToNodeLocatorId,
|
||||
mockGetExecutionIdByNode,
|
||||
mockGetNodeByExecutionId,
|
||||
mockWorkflowStore
|
||||
} = vi.hoisted(() => ({
|
||||
mockApp: {
|
||||
isGraphReady: true,
|
||||
rootGraph: {}
|
||||
},
|
||||
mockCanvasStore: {
|
||||
currentGraph: undefined as object | undefined
|
||||
},
|
||||
mockExecutionIdToNodeLocatorId: vi.fn(
|
||||
(_rootGraph: unknown, id: string) => id as NodeLocatorId
|
||||
),
|
||||
mockGetExecutionIdByNode: vi.fn(),
|
||||
mockGetNodeByExecutionId: vi.fn(),
|
||||
mockWorkflowStore: {
|
||||
nodeLocatorIdToNodeId: vi.fn()
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('@/scripts/app', () => ({ app: mockApp }))
|
||||
|
||||
vi.mock('@/renderer/core/canvas/canvasStore', () => ({
|
||||
useCanvasStore: () => mockCanvasStore
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/workflow/management/stores/workflowStore', () => ({
|
||||
useWorkflowStore: () => mockWorkflowStore
|
||||
}))
|
||||
|
||||
vi.mock('@/utils/graphTraversalUtil', () => ({
|
||||
executionIdToNodeLocatorId: (
|
||||
...args: Parameters<typeof mockExecutionIdToNodeLocatorId>
|
||||
) => mockExecutionIdToNodeLocatorId(...args),
|
||||
forEachNode: vi.fn(),
|
||||
getExecutionIdByNode: (
|
||||
...args: Parameters<typeof mockGetExecutionIdByNode>
|
||||
) => mockGetExecutionIdByNode(...args),
|
||||
getNodeByExecutionId: (
|
||||
...args: Parameters<typeof mockGetNodeByExecutionId>
|
||||
) => mockGetNodeByExecutionId(...args)
|
||||
}))
|
||||
|
||||
vi.mock('@/stores/settingStore', () => ({
|
||||
useSettingStore: vi.fn(() => ({
|
||||
@@ -39,6 +88,21 @@ import { useExecutionErrorStore } from './executionErrorStore'
|
||||
import { useMissingNodesErrorStore } from '@/platform/nodeReplacement/missingNodesErrorStore'
|
||||
import { toNodeId } from '@/types/nodeId'
|
||||
|
||||
beforeEach(() => {
|
||||
mockShowErrorsTab.value = false
|
||||
mockApp.isGraphReady = true
|
||||
mockCanvasStore.currentGraph = undefined
|
||||
mockExecutionIdToNodeLocatorId.mockImplementation(
|
||||
(_rootGraph: unknown, id: string) => id as NodeLocatorId
|
||||
)
|
||||
mockGetExecutionIdByNode.mockReset()
|
||||
mockGetNodeByExecutionId.mockReset()
|
||||
mockWorkflowStore.nodeLocatorIdToNodeId.mockImplementation(
|
||||
(locator: NodeLocatorId) =>
|
||||
toNodeId(String(locator).split(':').at(-1) ?? locator)
|
||||
)
|
||||
})
|
||||
|
||||
describe('executionErrorStore — node error operations', () => {
|
||||
beforeEach(() => {
|
||||
setActivePinia(createPinia())
|
||||
@@ -144,6 +208,31 @@ describe('executionErrorStore — node error operations', () => {
|
||||
expect(store.lastNodeErrors?.['123'].errors).toHaveLength(1)
|
||||
})
|
||||
|
||||
it('does nothing when the requested slot has no errors', () => {
|
||||
const store = useExecutionErrorStore()
|
||||
store.lastNodeErrors = {
|
||||
'123': {
|
||||
errors: [
|
||||
{
|
||||
type: 'value_bigger_than_max',
|
||||
message: 'Max exceeded',
|
||||
details: '',
|
||||
extra_info: { input_name: 'otherSlot' }
|
||||
}
|
||||
],
|
||||
dependent_outputs: [],
|
||||
class_type: 'TestNode'
|
||||
}
|
||||
}
|
||||
|
||||
store.clearSimpleNodeErrors(
|
||||
createNodeExecutionId([toNodeId(123)]),
|
||||
'testSlot'
|
||||
)
|
||||
|
||||
expect(store.lastNodeErrors?.['123'].errors).toHaveLength(1)
|
||||
})
|
||||
|
||||
it('preserves complex errors when slot has both simple and complex errors', () => {
|
||||
const store = useExecutionErrorStore()
|
||||
store.lastNodeErrors = {
|
||||
@@ -388,6 +477,359 @@ describe('executionErrorStore — node error operations', () => {
|
||||
expect(store.lastNodeErrors).not.toBeNull()
|
||||
expect(store.lastNodeErrors?.['123'].errors).toHaveLength(1)
|
||||
})
|
||||
|
||||
it('keeps numeric range errors when no range options prove them valid', () => {
|
||||
const store = useExecutionErrorStore()
|
||||
store.lastNodeErrors = {
|
||||
'123': {
|
||||
errors: [
|
||||
{
|
||||
type: 'value_bigger_than_max',
|
||||
message: '...',
|
||||
details: '',
|
||||
extra_info: { input_name: 'testWidget' }
|
||||
}
|
||||
],
|
||||
dependent_outputs: [],
|
||||
class_type: 'TestNode'
|
||||
}
|
||||
}
|
||||
|
||||
store.clearWidgetRelatedErrors(
|
||||
createNodeExecutionId([toNodeId(123)]),
|
||||
'testWidget',
|
||||
'testWidget',
|
||||
15
|
||||
)
|
||||
|
||||
expect(store.lastNodeErrors?.['123'].errors).toHaveLength(1)
|
||||
})
|
||||
|
||||
it('is a no-op when the target execution id has no node error entry', () => {
|
||||
const store = useExecutionErrorStore()
|
||||
store.lastNodeErrors = {
|
||||
'999': {
|
||||
errors: [
|
||||
{
|
||||
type: 'value_bigger_than_max',
|
||||
message: '...',
|
||||
details: '',
|
||||
extra_info: { input_name: 'testWidget' }
|
||||
}
|
||||
],
|
||||
dependent_outputs: [],
|
||||
class_type: 'TestNode'
|
||||
}
|
||||
}
|
||||
|
||||
store.clearWidgetRelatedErrors(
|
||||
createNodeExecutionId([toNodeId(123)]),
|
||||
'testWidget',
|
||||
'testWidget',
|
||||
15,
|
||||
{ max: 10 }
|
||||
)
|
||||
|
||||
expect(store.lastNodeErrors?.['123']).toBeUndefined()
|
||||
expect(store.lastNodeErrors?.['999'].errors).toHaveLength(1)
|
||||
})
|
||||
})
|
||||
|
||||
describe('startup clearing', () => {
|
||||
it('clears execution-start errors and closes the overlay when node errors are empty', () => {
|
||||
const store = useExecutionErrorStore()
|
||||
store.lastExecutionError = fromAny({ node_id: '1' })
|
||||
store.lastPromptError = fromAny({ message: 'prompt failed' })
|
||||
store.lastNodeErrors = {}
|
||||
store.showErrorOverlay()
|
||||
|
||||
store.clearExecutionStartErrors()
|
||||
|
||||
expect(store.lastExecutionError).toBeNull()
|
||||
expect(store.lastPromptError).toBeNull()
|
||||
expect(store.isErrorOverlayOpen).toBe(false)
|
||||
})
|
||||
|
||||
it('keeps the overlay open when node errors remain after execution start', () => {
|
||||
const store = useExecutionErrorStore()
|
||||
store.lastExecutionError = fromAny({ node_id: '1' })
|
||||
store.lastPromptError = fromAny({ message: 'prompt failed' })
|
||||
store.lastNodeErrors = {
|
||||
'1': {
|
||||
errors: [
|
||||
{
|
||||
type: 'required_input_missing',
|
||||
message: 'Missing',
|
||||
details: '',
|
||||
extra_info: { input_name: 'x' }
|
||||
}
|
||||
],
|
||||
dependent_outputs: [],
|
||||
class_type: 'Test'
|
||||
}
|
||||
}
|
||||
store.showErrorOverlay()
|
||||
|
||||
store.clearExecutionStartErrors()
|
||||
|
||||
expect(store.isErrorOverlayOpen).toBe(true)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('executionErrorStore derived graph state', () => {
|
||||
beforeEach(() => {
|
||||
setActivePinia(createPinia())
|
||||
})
|
||||
|
||||
it('derives execution error node ids through locator mapping', () => {
|
||||
const store = useExecutionErrorStore()
|
||||
mockExecutionIdToNodeLocatorId.mockReturnValue(
|
||||
fromAny<NodeLocatorId, string>('graph:7')
|
||||
)
|
||||
store.lastExecutionError = fromAny({ node_id: '7' })
|
||||
|
||||
expect(store.lastExecutionErrorNodeId).toBe(toNodeId(7))
|
||||
})
|
||||
|
||||
it('returns null when there is no execution error locator', () => {
|
||||
const store = useExecutionErrorStore()
|
||||
store.lastExecutionError = fromAny({ node_id: '7' })
|
||||
mockExecutionIdToNodeLocatorId.mockReturnValue(
|
||||
fromAny<NodeLocatorId, undefined>(undefined)
|
||||
)
|
||||
|
||||
expect(store.lastExecutionErrorNodeId).toBeNull()
|
||||
})
|
||||
|
||||
it('returns null when there is no execution error', () => {
|
||||
const store = useExecutionErrorStore()
|
||||
|
||||
expect(store.lastExecutionErrorNodeId).toBeNull()
|
||||
})
|
||||
|
||||
it('combines prompt, node, execution, and missing-node error counts', () => {
|
||||
const store = useExecutionErrorStore()
|
||||
const missingNodesStore = useMissingNodesErrorStore()
|
||||
store.lastPromptError = fromAny({ message: 'prompt failed' })
|
||||
store.lastExecutionError = fromAny({ node_id: null })
|
||||
store.lastNodeErrors = {
|
||||
'1': {
|
||||
errors: [
|
||||
{
|
||||
type: 'required_input_missing',
|
||||
message: 'Missing',
|
||||
details: '',
|
||||
extra_info: { input_name: 'x' }
|
||||
},
|
||||
{
|
||||
type: 'value_bigger_than_max',
|
||||
message: 'Too large',
|
||||
details: '',
|
||||
extra_info: { input_name: 'y' }
|
||||
}
|
||||
],
|
||||
dependent_outputs: [],
|
||||
class_type: 'Test'
|
||||
}
|
||||
}
|
||||
missingNodesStore.setMissingNodeTypes(
|
||||
fromAny<MissingNodeType[], unknown>([{ type: 'MissingNode', hint: '' }])
|
||||
)
|
||||
|
||||
expect(store.hasPromptError).toBe(true)
|
||||
expect(store.hasNodeError).toBe(true)
|
||||
expect(store.hasExecutionError).toBe(true)
|
||||
expect(store.hasAnyError).toBe(true)
|
||||
expect(store.allErrorExecutionIds).toEqual(['1'])
|
||||
expect(store.totalErrorCount).toBe(5)
|
||||
})
|
||||
|
||||
it('reports empty derived state when there are no errors', () => {
|
||||
const store = useExecutionErrorStore()
|
||||
|
||||
expect(store.hasNodeError).toBe(false)
|
||||
expect(store.allErrorExecutionIds).toEqual([])
|
||||
expect(store.totalErrorCount).toBe(0)
|
||||
})
|
||||
|
||||
it('includes defined execution node ids in the error id list', () => {
|
||||
const store = useExecutionErrorStore()
|
||||
store.lastExecutionError = fromAny({ node_id: '2' })
|
||||
|
||||
expect(store.allErrorExecutionIds).toEqual(['2'])
|
||||
})
|
||||
|
||||
it('excludes undefined execution node ids from the error id list', () => {
|
||||
const store = useExecutionErrorStore()
|
||||
store.lastExecutionError = fromAny({ node_id: undefined })
|
||||
|
||||
expect(store.allErrorExecutionIds).toEqual([])
|
||||
})
|
||||
|
||||
it('collects active graph node ids for validation and execution errors', () => {
|
||||
const store = useExecutionErrorStore()
|
||||
const activeGraph = {}
|
||||
mockCanvasStore.currentGraph = activeGraph
|
||||
mockGetNodeByExecutionId.mockImplementation((_rootGraph, id: string) => ({
|
||||
id: toNodeId(id),
|
||||
graph: activeGraph
|
||||
}))
|
||||
store.lastNodeErrors = {
|
||||
'1': {
|
||||
errors: [
|
||||
{
|
||||
type: 'required_input_missing',
|
||||
message: 'Missing',
|
||||
details: '',
|
||||
extra_info: { input_name: 'x' }
|
||||
}
|
||||
],
|
||||
dependent_outputs: [],
|
||||
class_type: 'Test'
|
||||
}
|
||||
}
|
||||
store.lastExecutionError = fromAny({ node_id: '2' })
|
||||
|
||||
expect([...store.activeGraphErrorNodeIds].sort()).toEqual(['1', '2'])
|
||||
})
|
||||
|
||||
it('falls back to the root graph when there is no current canvas graph', () => {
|
||||
const store = useExecutionErrorStore()
|
||||
mockCanvasStore.currentGraph = undefined
|
||||
mockGetNodeByExecutionId.mockReturnValue({
|
||||
id: toNodeId(1),
|
||||
graph: mockApp.rootGraph
|
||||
})
|
||||
store.lastNodeErrors = {
|
||||
'1': {
|
||||
errors: [
|
||||
{
|
||||
type: 'required_input_missing',
|
||||
message: 'Missing',
|
||||
details: '',
|
||||
extra_info: { input_name: 'x' }
|
||||
}
|
||||
],
|
||||
dependent_outputs: [],
|
||||
class_type: 'Test'
|
||||
}
|
||||
}
|
||||
|
||||
expect([...store.activeGraphErrorNodeIds]).toEqual(['1'])
|
||||
})
|
||||
|
||||
it('ignores graph errors outside the active graph', () => {
|
||||
const store = useExecutionErrorStore()
|
||||
const activeGraph = {}
|
||||
mockCanvasStore.currentGraph = activeGraph
|
||||
mockGetNodeByExecutionId.mockReturnValue({
|
||||
id: toNodeId(1),
|
||||
graph: {}
|
||||
})
|
||||
store.lastNodeErrors = {
|
||||
'1': {
|
||||
errors: [
|
||||
{
|
||||
type: 'required_input_missing',
|
||||
message: 'Missing',
|
||||
details: '',
|
||||
extra_info: { input_name: 'x' }
|
||||
}
|
||||
],
|
||||
dependent_outputs: [],
|
||||
class_type: 'Test'
|
||||
}
|
||||
}
|
||||
store.lastExecutionError = fromAny({ node_id: '1' })
|
||||
|
||||
expect(store.activeGraphErrorNodeIds.size).toBe(0)
|
||||
})
|
||||
|
||||
it('returns no active graph node ids before the graph is ready', () => {
|
||||
const store = useExecutionErrorStore()
|
||||
mockApp.isGraphReady = false
|
||||
store.lastExecutionError = fromAny({ node_id: '2' })
|
||||
|
||||
expect(store.activeGraphErrorNodeIds.size).toBe(0)
|
||||
})
|
||||
|
||||
it('maps node errors by locator and checks slots', () => {
|
||||
const store = useExecutionErrorStore()
|
||||
const nodeError = {
|
||||
errors: [
|
||||
{
|
||||
type: 'required_input_missing',
|
||||
message: 'Missing',
|
||||
details: '',
|
||||
extra_info: { input_name: 'x' }
|
||||
}
|
||||
],
|
||||
dependent_outputs: [],
|
||||
class_type: 'Test'
|
||||
}
|
||||
mockExecutionIdToNodeLocatorId.mockImplementation((_rootGraph, id) =>
|
||||
id === 'missing'
|
||||
? fromAny<NodeLocatorId, undefined>(undefined)
|
||||
: fromAny<NodeLocatorId, string>(`locator:${id}`)
|
||||
)
|
||||
store.lastNodeErrors = {
|
||||
'1': nodeError,
|
||||
missing: nodeError
|
||||
}
|
||||
|
||||
const locator = fromAny<NodeLocatorId, string>('locator:1')
|
||||
expect(store.getNodeErrors(locator)).toEqual(nodeError)
|
||||
expect(store.slotHasError(locator, 'x')).toBe(true)
|
||||
expect(store.slotHasError(locator, 'y')).toBe(false)
|
||||
expect(
|
||||
store.getNodeErrors(fromAny<NodeLocatorId, string>('locator:missing'))
|
||||
).toBeUndefined()
|
||||
})
|
||||
|
||||
it('returns no slot error when there are no node errors', () => {
|
||||
const store = useExecutionErrorStore()
|
||||
|
||||
expect(
|
||||
store.slotHasError(fromAny<NodeLocatorId, string>('locator:1'), 'x')
|
||||
).toBe(false)
|
||||
})
|
||||
|
||||
it('detects container nodes with internal errors', () => {
|
||||
const store = useExecutionErrorStore()
|
||||
const node = fromAny<LGraphNode, unknown>({})
|
||||
mockGetExecutionIdByNode.mockReturnValueOnce(undefined)
|
||||
|
||||
expect(store.isContainerWithInternalError(node)).toBe(false)
|
||||
|
||||
store.lastNodeErrors = {
|
||||
'1:2': {
|
||||
errors: [
|
||||
{
|
||||
type: 'required_input_missing',
|
||||
message: 'Missing',
|
||||
details: '',
|
||||
extra_info: { input_name: 'x' }
|
||||
}
|
||||
],
|
||||
dependent_outputs: [],
|
||||
class_type: 'Test'
|
||||
}
|
||||
}
|
||||
mockGetExecutionIdByNode.mockReturnValue(
|
||||
createNodeExecutionId([toNodeId(1)])
|
||||
)
|
||||
|
||||
expect(store.isContainerWithInternalError(node)).toBe(true)
|
||||
})
|
||||
|
||||
it('does not report container errors before the graph is ready', () => {
|
||||
const store = useExecutionErrorStore()
|
||||
mockApp.isGraphReady = false
|
||||
|
||||
expect(
|
||||
store.isContainerWithInternalError(fromAny<LGraphNode, unknown>({}))
|
||||
).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -457,6 +899,23 @@ describe('surfaceMissingModels — silent option', () => {
|
||||
|
||||
expect(store.isErrorOverlayOpen).toBe(false)
|
||||
})
|
||||
|
||||
it('does NOT open error overlay when the setting is disabled', () => {
|
||||
const store = useExecutionErrorStore()
|
||||
mockShowErrorsTab.value = false
|
||||
store.surfaceMissingModels([
|
||||
fromAny({
|
||||
name: 'model.safetensors',
|
||||
nodeId: toNodeId('1'),
|
||||
nodeType: 'Loader',
|
||||
widgetName: 'ckpt',
|
||||
isMissing: true,
|
||||
isAssetSupported: false
|
||||
})
|
||||
])
|
||||
|
||||
expect(store.isErrorOverlayOpen).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('surfaceMissingMedia — silent option', () => {
|
||||
@@ -525,6 +984,23 @@ describe('surfaceMissingMedia — silent option', () => {
|
||||
|
||||
expect(store.isErrorOverlayOpen).toBe(false)
|
||||
})
|
||||
|
||||
it('does NOT open error overlay when the setting is disabled', () => {
|
||||
const store = useExecutionErrorStore()
|
||||
mockShowErrorsTab.value = false
|
||||
store.surfaceMissingMedia([
|
||||
fromAny({
|
||||
name: 'photo.png',
|
||||
nodeId: toNodeId('1'),
|
||||
nodeType: 'LoadImage',
|
||||
widgetName: 'image',
|
||||
mediaType: 'image',
|
||||
isMissing: true
|
||||
})
|
||||
])
|
||||
|
||||
expect(store.isErrorOverlayOpen).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('clearAllErrors', () => {
|
||||
|
||||
@@ -3,6 +3,7 @@ import { fromPartial } from '@total-typescript/shoehorn'
|
||||
import { setActivePinia } from 'pinia'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { nextTick, ref } from 'vue'
|
||||
import { createMemoryHistory, createRouter } from 'vue-router'
|
||||
|
||||
import type * as VueRouter from 'vue-router'
|
||||
|
||||
@@ -102,12 +103,24 @@ vi.mock('@/platform/workflow/management/stores/workflowStore', () => ({
|
||||
function makeSubgraph(id: string): Subgraph {
|
||||
return fromPartial<Subgraph>({
|
||||
id,
|
||||
isRootGraph: false,
|
||||
rootGraph: app.rootGraph,
|
||||
_nodes: [],
|
||||
nodes: []
|
||||
})
|
||||
}
|
||||
|
||||
async function makeDuplicatedNavigationFailure(): Promise<Error> {
|
||||
const router = createRouter({
|
||||
history: createMemoryHistory(),
|
||||
routes: [{ path: '/', component: {} }]
|
||||
})
|
||||
await router.push('/')
|
||||
const failure = await router.push('/')
|
||||
if (!failure) throw new Error('Expected duplicated navigation failure')
|
||||
return failure
|
||||
}
|
||||
|
||||
async function flushHashWatcher() {
|
||||
await nextTick()
|
||||
await Promise.resolve()
|
||||
@@ -118,6 +131,7 @@ describe('useSubgraphNavigationStore - navigateToHash validation', () => {
|
||||
beforeEach(() => {
|
||||
setActivePinia(createTestingPinia({ stubActions: false }))
|
||||
vi.clearAllMocks()
|
||||
vi.mocked(app.canvas.setGraph).mockReset()
|
||||
app.rootGraph.id = ids.root
|
||||
app.rootGraph.subgraphs.clear()
|
||||
app.canvas.subgraph = undefined
|
||||
@@ -230,6 +244,44 @@ describe('useSubgraphNavigationStore - navigateToHash validation', () => {
|
||||
warnSpy.mockRestore()
|
||||
})
|
||||
|
||||
it('does not warn when recovery redirect hits a duplicated navigation', async () => {
|
||||
routerMocks.replace.mockRejectedValueOnce(
|
||||
await makeDuplicatedNavigationFailure()
|
||||
)
|
||||
const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {})
|
||||
app.canvas.graph = makeSubgraph(ids.deletedSubgraph)
|
||||
useSubgraphNavigationStore()
|
||||
|
||||
routeHashRef.value = `#${ids.deletedSubgraph}`
|
||||
await vi.waitFor(() =>
|
||||
expect(routerMocks.replace).toHaveBeenCalledWith(`#${app.rootGraph.id}`)
|
||||
)
|
||||
|
||||
expect(warnSpy).not.toHaveBeenCalledWith(
|
||||
'[subgraphNavigation] router.replace rejected during recovery',
|
||||
expect.any(Error)
|
||||
)
|
||||
warnSpy.mockRestore()
|
||||
})
|
||||
|
||||
it('recovers to root when canvas is unavailable during redirect cleanup', async () => {
|
||||
const appWithOptionalCanvas = app as unknown as {
|
||||
canvas: typeof app.canvas | undefined
|
||||
}
|
||||
const canvas = appWithOptionalCanvas.canvas
|
||||
appWithOptionalCanvas.canvas = undefined
|
||||
useSubgraphNavigationStore()
|
||||
|
||||
try {
|
||||
routeHashRef.value = '#not-a-valid-uuid'
|
||||
await vi.waitFor(() =>
|
||||
expect(routerMocks.replace).toHaveBeenCalledWith(`#${app.rootGraph.id}`)
|
||||
)
|
||||
} finally {
|
||||
appWithOptionalCanvas.canvas = canvas
|
||||
}
|
||||
})
|
||||
|
||||
it('redirects when a workflow load resolves but the subgraph is still missing', async () => {
|
||||
const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {})
|
||||
workflowStoreState.openWorkflows = [
|
||||
@@ -304,4 +356,196 @@ describe('useSubgraphNavigationStore - navigateToHash validation', () => {
|
||||
expect(app.canvas.setGraph).toHaveBeenCalledWith(app.rootGraph)
|
||||
warnSpy.mockRestore()
|
||||
})
|
||||
|
||||
it('updateHash does nothing on initial load with an empty hash', async () => {
|
||||
const store = useSubgraphNavigationStore()
|
||||
|
||||
await store.updateHash()
|
||||
|
||||
expect(routerMocks.replace).not.toHaveBeenCalled()
|
||||
expect(routerMocks.push).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('updateHash follows a non-empty initial subgraph hash', async () => {
|
||||
const subgraph = makeSubgraph(ids.validSubgraph)
|
||||
app.rootGraph.subgraphs.set(subgraph.id, subgraph)
|
||||
vi.mocked(app.canvas.setGraph).mockImplementation((graph) => {
|
||||
app.canvas.graph = graph
|
||||
})
|
||||
routeHashRef.value = `#${ids.validSubgraph}`
|
||||
const store = useSubgraphNavigationStore()
|
||||
|
||||
await store.updateHash()
|
||||
|
||||
expect(app.canvas.setGraph).toHaveBeenCalledWith(subgraph)
|
||||
})
|
||||
|
||||
it('updateHash does not treat the initial root hash as a subgraph', async () => {
|
||||
routeHashRef.value = `#${ids.root}`
|
||||
app.canvas.graph = app.rootGraph
|
||||
const store = useSubgraphNavigationStore()
|
||||
|
||||
await store.updateHash()
|
||||
|
||||
expect(workflowStoreState.activeSubgraph).toBeUndefined()
|
||||
})
|
||||
|
||||
it('updateHash replaces an empty hash and pushes the active graph id', async () => {
|
||||
const store = useSubgraphNavigationStore()
|
||||
await store.updateHash()
|
||||
app.canvas.graph = fromPartial<LGraph>({ id: ids.validSubgraph })
|
||||
|
||||
await store.updateHash()
|
||||
|
||||
expect(routerMocks.replace).toHaveBeenCalledWith(`#${app.rootGraph.id}`)
|
||||
expect(routerMocks.push).toHaveBeenCalledWith(`#${ids.validSubgraph}`)
|
||||
})
|
||||
|
||||
it('updateHash skips router push when hash already matches the active graph', async () => {
|
||||
const store = useSubgraphNavigationStore()
|
||||
await store.updateHash()
|
||||
routeHashRef.value = `#${ids.validSubgraph}`
|
||||
app.canvas.graph = fromPartial<LGraph>({ id: ids.validSubgraph })
|
||||
|
||||
await store.updateHash()
|
||||
|
||||
expect(routerMocks.push).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('updateHash skips router push when the active graph has no id', async () => {
|
||||
const store = useSubgraphNavigationStore()
|
||||
await store.updateHash()
|
||||
routeHashRef.value = '#old'
|
||||
app.canvas.graph = fromPartial<LGraph>({})
|
||||
|
||||
await store.updateHash()
|
||||
|
||||
expect(routerMocks.push).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('updateHash warns when router push rejects', async () => {
|
||||
const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {})
|
||||
routerMocks.push.mockRejectedValueOnce(new Error('push failed'))
|
||||
const store = useSubgraphNavigationStore()
|
||||
await store.updateHash()
|
||||
routeHashRef.value = '#old'
|
||||
app.canvas.graph = fromPartial<LGraph>({ id: ids.validSubgraph })
|
||||
|
||||
await store.updateHash()
|
||||
|
||||
expect(warnSpy).toHaveBeenCalledWith(
|
||||
'[subgraphNavigation] router.push rejected',
|
||||
expect.any(Error)
|
||||
)
|
||||
warnSpy.mockRestore()
|
||||
})
|
||||
|
||||
it('updateHash ignores duplicated router push failures', async () => {
|
||||
const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {})
|
||||
routerMocks.push.mockRejectedValueOnce(
|
||||
await makeDuplicatedNavigationFailure()
|
||||
)
|
||||
const store = useSubgraphNavigationStore()
|
||||
await store.updateHash()
|
||||
routeHashRef.value = `#${ids.root}`
|
||||
app.canvas.graph = fromPartial<LGraph>({ id: ids.validSubgraph })
|
||||
|
||||
await store.updateHash()
|
||||
|
||||
expect(warnSpy).not.toHaveBeenCalled()
|
||||
warnSpy.mockRestore()
|
||||
})
|
||||
|
||||
it('skips workflows without active state during hash recovery', async () => {
|
||||
workflowStoreState.openWorkflows = [
|
||||
fromPartial<ComfyWorkflow>({ path: 'inactive.json' })
|
||||
]
|
||||
useSubgraphNavigationStore()
|
||||
|
||||
routeHashRef.value = `#${ids.deletedSubgraph}`
|
||||
await vi.waitFor(() =>
|
||||
expect(routerMocks.replace).toHaveBeenCalledWith(`#${app.rootGraph.id}`)
|
||||
)
|
||||
})
|
||||
|
||||
it('skips workflow states and subgraphs that do not match the hash', async () => {
|
||||
workflowStoreState.openWorkflows = [
|
||||
fromPartial<ComfyWorkflow>({
|
||||
path: 'other-workflow.json',
|
||||
activeState: {
|
||||
id: ids.validSubgraph,
|
||||
definitions: {
|
||||
subgraphs: [{ id: ids.validSubgraph }]
|
||||
}
|
||||
}
|
||||
})
|
||||
]
|
||||
useSubgraphNavigationStore()
|
||||
|
||||
routeHashRef.value = `#${ids.deletedSubgraph}`
|
||||
await vi.waitFor(() =>
|
||||
expect(routerMocks.replace).toHaveBeenCalledWith(`#${app.rootGraph.id}`)
|
||||
)
|
||||
})
|
||||
|
||||
it('handles workflow states with no subgraph definitions during recovery', async () => {
|
||||
workflowStoreState.openWorkflows = [
|
||||
fromPartial<ComfyWorkflow>({
|
||||
path: 'no-definitions.json',
|
||||
activeState: { id: ids.validSubgraph }
|
||||
})
|
||||
]
|
||||
useSubgraphNavigationStore()
|
||||
|
||||
routeHashRef.value = `#${ids.deletedSubgraph}`
|
||||
await vi.waitFor(() =>
|
||||
expect(routerMocks.replace).toHaveBeenCalledWith(`#${app.rootGraph.id}`)
|
||||
)
|
||||
})
|
||||
|
||||
it('opens a workflow and navigates to the loaded root graph', async () => {
|
||||
workflowStoreState.openWorkflows = [
|
||||
fromPartial<ComfyWorkflow>({
|
||||
path: 'root-workflow.json',
|
||||
activeState: {
|
||||
id: ids.deletedSubgraph,
|
||||
definitions: { subgraphs: [] }
|
||||
}
|
||||
})
|
||||
]
|
||||
workflowServiceMocks.openWorkflow.mockImplementation(async () => {
|
||||
app.rootGraph.id = ids.deletedSubgraph
|
||||
app.canvas.graph = fromPartial<LGraph>({ id: ids.root })
|
||||
})
|
||||
useSubgraphNavigationStore()
|
||||
|
||||
routeHashRef.value = `#${ids.deletedSubgraph}`
|
||||
await vi.waitFor(() =>
|
||||
expect(app.canvas.setGraph).toHaveBeenCalledWith(app.rootGraph)
|
||||
)
|
||||
})
|
||||
|
||||
it('does not reset the graph when loaded workflow is already active', async () => {
|
||||
workflowStoreState.openWorkflows = [
|
||||
fromPartial<ComfyWorkflow>({
|
||||
path: 'already-active.json',
|
||||
activeState: {
|
||||
id: ids.deletedSubgraph,
|
||||
definitions: { subgraphs: [] }
|
||||
}
|
||||
})
|
||||
]
|
||||
workflowServiceMocks.openWorkflow.mockImplementation(async () => {
|
||||
app.rootGraph.id = ids.deletedSubgraph
|
||||
app.canvas.graph = fromPartial<LGraph>({ id: ids.deletedSubgraph })
|
||||
})
|
||||
useSubgraphNavigationStore()
|
||||
|
||||
routeHashRef.value = `#${ids.deletedSubgraph}`
|
||||
await vi.waitFor(() =>
|
||||
expect(workflowServiceMocks.openWorkflow).toHaveBeenCalled()
|
||||
)
|
||||
|
||||
expect(app.canvas.setGraph).not.toHaveBeenCalledWith(app.rootGraph)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import { createTestingPinia } from '@pinia/testing'
|
||||
import { fromPartial } from '@total-typescript/shoehorn'
|
||||
import { setActivePinia } from 'pinia'
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { nextTick } from 'vue'
|
||||
@@ -136,6 +137,23 @@ describe('useSubgraphNavigationStore - Viewport Persistence', () => {
|
||||
})
|
||||
|
||||
describe('saveViewport', () => {
|
||||
it('does not save when canvas is unavailable', () => {
|
||||
const store = useSubgraphNavigationStore()
|
||||
const canvas = app.canvas
|
||||
const appWithOptionalCanvas = app as unknown as {
|
||||
canvas: typeof app.canvas | undefined
|
||||
}
|
||||
appWithOptionalCanvas.canvas = undefined
|
||||
|
||||
try {
|
||||
store.saveViewport('root')
|
||||
|
||||
expect(store.viewportCache.has(':root')).toBe(false)
|
||||
} finally {
|
||||
appWithOptionalCanvas.canvas = canvas
|
||||
}
|
||||
})
|
||||
|
||||
it('saves viewport state for root graph', () => {
|
||||
const store = useSubgraphNavigationStore()
|
||||
mockCanvas.ds.state.scale = 2
|
||||
@@ -164,6 +182,42 @@ describe('useSubgraphNavigationStore - Viewport Persistence', () => {
|
||||
})
|
||||
|
||||
describe('restoreViewport', () => {
|
||||
it('does nothing when canvas is unavailable', () => {
|
||||
const store = useSubgraphNavigationStore()
|
||||
const canvas = app.canvas
|
||||
const appWithOptionalCanvas = app as unknown as {
|
||||
canvas: typeof app.canvas | undefined
|
||||
}
|
||||
appWithOptionalCanvas.canvas = undefined
|
||||
|
||||
try {
|
||||
store.restoreViewport('root')
|
||||
|
||||
expect(mockSetDirty).not.toHaveBeenCalled()
|
||||
expect(rafCallbacks).toHaveLength(0)
|
||||
} finally {
|
||||
appWithOptionalCanvas.canvas = canvas
|
||||
}
|
||||
})
|
||||
|
||||
it('does not apply cached viewport when canvas disappears', () => {
|
||||
const store = useSubgraphNavigationStore()
|
||||
const canvas = app.canvas
|
||||
const appWithOptionalCanvas = app as unknown as {
|
||||
canvas: typeof app.canvas | undefined
|
||||
}
|
||||
store.viewportCache.set(':root', { scale: 2.5, offset: [150, 250] })
|
||||
appWithOptionalCanvas.canvas = undefined
|
||||
|
||||
try {
|
||||
store.restoreViewport('root')
|
||||
|
||||
expect(mockSetDirty).not.toHaveBeenCalled()
|
||||
} finally {
|
||||
appWithOptionalCanvas.canvas = canvas
|
||||
}
|
||||
})
|
||||
|
||||
it('restores cached viewport', () => {
|
||||
const store = useSubgraphNavigationStore()
|
||||
store.viewportCache.set(':root', { scale: 2.5, offset: [150, 250] })
|
||||
@@ -266,7 +320,10 @@ describe('useSubgraphNavigationStore - Viewport Persistence', () => {
|
||||
expect(mockFitView).toHaveBeenCalledOnce()
|
||||
|
||||
// User navigated away before the inner RAF fired
|
||||
mockCanvas.subgraph = { id: 'different-graph' } as never
|
||||
mockCanvas.subgraph = fromPartial<Subgraph>({
|
||||
id: 'different-graph',
|
||||
isRootGraph: false
|
||||
})
|
||||
rafCallbacks[1](performance.now())
|
||||
|
||||
expect(mockRequestSlotSyncAll).not.toHaveBeenCalled()
|
||||
@@ -283,7 +340,10 @@ describe('useSubgraphNavigationStore - Viewport Persistence', () => {
|
||||
expect(rafCallbacks).toHaveLength(1)
|
||||
|
||||
// Simulate graph switching away before rAF fires
|
||||
mockCanvas.subgraph = { id: 'different-graph' } as never
|
||||
mockCanvas.subgraph = fromPartial<Subgraph>({
|
||||
id: 'different-graph',
|
||||
isRootGraph: false
|
||||
})
|
||||
|
||||
rafCallbacks[0](performance.now())
|
||||
|
||||
@@ -341,6 +401,23 @@ describe('useSubgraphNavigationStore - Viewport Persistence', () => {
|
||||
expect(mockCanvas.ds.offset).toEqual([100, 100])
|
||||
})
|
||||
|
||||
it('does not save the outgoing viewport while a workflow switch is blocked', async () => {
|
||||
const store = useSubgraphNavigationStore()
|
||||
const workflowStore = useWorkflowStore()
|
||||
const subgraph = fromPartial<Subgraph>({
|
||||
id: 'sub1',
|
||||
isRootGraph: false,
|
||||
rootGraph: app.rootGraph
|
||||
})
|
||||
|
||||
store.saveCurrentViewport()
|
||||
store.viewportCache.clear()
|
||||
workflowStore.activeSubgraph = subgraph
|
||||
await nextTick()
|
||||
|
||||
expect(store.viewportCache.has(':root')).toBe(false)
|
||||
})
|
||||
|
||||
it('preserves pre-existing cache entries across workflow switches', async () => {
|
||||
const store = useSubgraphNavigationStore()
|
||||
const workflowStore = useWorkflowStore()
|
||||
|
||||
@@ -10,9 +10,14 @@ import {
|
||||
import type { ExportedSubgraph } from '@/lib/litegraph/src/types/serialisation'
|
||||
import { TemplateIncludeOnDistributionEnum } from '@/platform/workflow/templates/types/template'
|
||||
import type { ComfyNodeDef as ComfyNodeDefV1 } from '@/schemas/nodeDefSchema'
|
||||
import type { ComfyWorkflowJSON } from '@/platform/workflow/validation/schemas/workflowSchema'
|
||||
import { useSettingStore } from '@/platform/settings/settingStore'
|
||||
import { useToastStore } from '@/platform/updates/common/toastStore'
|
||||
import { useWorkflowStore } from '@/platform/workflow/management/stores/workflowStore'
|
||||
import type { GlobalSubgraphData } from '@/scripts/api'
|
||||
import { api } from '@/scripts/api'
|
||||
import { app as comfyApp } from '@/scripts/app'
|
||||
import { useDialogService } from '@/services/dialogService'
|
||||
import { useLitegraphService } from '@/services/litegraphService'
|
||||
import { useNodeDefStore } from '@/stores/nodeDefStore'
|
||||
import { useSubgraphStore } from '@/stores/subgraphStore'
|
||||
@@ -36,6 +41,7 @@ vi.mock('@/scripts/api', () => ({
|
||||
storeUserData: vi.fn(),
|
||||
listUserDataFullInfo: vi.fn(),
|
||||
getGlobalSubgraphs: vi.fn(),
|
||||
deleteUserData: vi.fn(),
|
||||
apiURL: vi.fn(),
|
||||
addEventListener: vi.fn()
|
||||
}
|
||||
@@ -98,6 +104,12 @@ describe('useSubgraphStore', () => {
|
||||
setActivePinia(createTestingPinia({ stubActions: false }))
|
||||
store = useSubgraphStore()
|
||||
vi.clearAllMocks()
|
||||
vi.mocked(useDialogService).mockReturnValue(
|
||||
fromPartial<ReturnType<typeof useDialogService>>({
|
||||
prompt: vi.fn(() => 'testname'),
|
||||
confirm: vi.fn(() => true)
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('should allow publishing of a subgraph', async () => {
|
||||
@@ -134,6 +146,86 @@ describe('useSubgraphStore', () => {
|
||||
await store.publishSubgraph()
|
||||
expect(api.storeUserData).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('rejects publishing when a single subgraph node is not selected', async () => {
|
||||
vi.mocked(comfyApp.canvas).selectedItems = new Set()
|
||||
|
||||
await expect(store.publishSubgraph()).rejects.toThrow(
|
||||
'Must have single SubgraphNode selected to publish'
|
||||
)
|
||||
})
|
||||
|
||||
it('rejects publishing when serialization produces multiple nodes', async () => {
|
||||
const subgraph = createTestSubgraph()
|
||||
const subgraphNode = createTestSubgraphNode(subgraph)
|
||||
vi.mocked(comfyApp.canvas).selectedItems = new Set([subgraphNode])
|
||||
vi.mocked(comfyApp.canvas)._serializeItems = vi.fn(() => ({
|
||||
nodes: [subgraphNode.serialize(), subgraphNode.serialize()],
|
||||
subgraphs: []
|
||||
}))
|
||||
|
||||
await expect(store.publishSubgraph()).rejects.toThrow(
|
||||
'Must have single SubgraphNode selected to publish'
|
||||
)
|
||||
})
|
||||
|
||||
it('rejects publishing when the serialized node is not a subgraph node', async () => {
|
||||
const subgraph = createTestSubgraph()
|
||||
const subgraphNode = createTestSubgraphNode(subgraph)
|
||||
vi.mocked(comfyApp.canvas).selectedItems = new Set([subgraphNode])
|
||||
vi.mocked(comfyApp.canvas).draw = vi.fn()
|
||||
vi.mocked(comfyApp.canvas)._serializeItems = vi.fn(() => ({
|
||||
nodes: [{ ...subgraphNode.serialize(), type: 'missing' }],
|
||||
subgraphs: [fromAny<ExportedSubgraph, unknown>(subgraph.serialize())]
|
||||
}))
|
||||
|
||||
await expect(store.publishSubgraph('invalid')).rejects.toThrow(
|
||||
'Loaded subgraph blueprint does not contain valid subgraph'
|
||||
)
|
||||
expect(api.storeUserData).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('does not publish when the name prompt is cancelled', async () => {
|
||||
const subgraph = createTestSubgraph()
|
||||
const subgraphNode = createTestSubgraphNode(subgraph)
|
||||
vi.mocked(comfyApp.canvas).selectedItems = new Set([subgraphNode])
|
||||
vi.mocked(comfyApp.canvas)._serializeItems = vi.fn(() => ({
|
||||
nodes: [subgraphNode.serialize()],
|
||||
subgraphs: [fromAny<ExportedSubgraph, unknown>(subgraph.serialize())]
|
||||
}))
|
||||
vi.mocked(useDialogService).mockReturnValue(
|
||||
fromPartial<ReturnType<typeof useDialogService>>({
|
||||
prompt: vi.fn(() => null),
|
||||
confirm: vi.fn(() => true)
|
||||
})
|
||||
)
|
||||
|
||||
await store.publishSubgraph()
|
||||
|
||||
expect(api.storeUserData).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('does not overwrite an existing blueprint when confirmation is cancelled', async () => {
|
||||
const subgraph = createTestSubgraph()
|
||||
const subgraphNode = createTestSubgraphNode(subgraph)
|
||||
vi.mocked(comfyApp.canvas).selectedItems = new Set([subgraphNode])
|
||||
vi.mocked(comfyApp.canvas)._serializeItems = vi.fn(() => ({
|
||||
nodes: [subgraphNode.serialize()],
|
||||
subgraphs: [fromAny<ExportedSubgraph, unknown>(subgraph.serialize())]
|
||||
}))
|
||||
vi.mocked(useDialogService).mockReturnValue(
|
||||
fromPartial<ReturnType<typeof useDialogService>>({
|
||||
prompt: vi.fn(() => 'test'),
|
||||
confirm: vi.fn(() => false)
|
||||
})
|
||||
)
|
||||
await mockFetch({ 'test.json': mockGraph })
|
||||
|
||||
await store.publishSubgraph('test')
|
||||
|
||||
expect(api.storeUserData).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should display published nodes in the node library', async () => {
|
||||
await mockFetch({ 'test.json': mockGraph })
|
||||
expect(
|
||||
@@ -148,6 +240,30 @@ describe('useSubgraphStore', () => {
|
||||
//check active graph
|
||||
expect(comfyApp.loadGraphData).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('switches into the nested subgraph when editing opens a wrapper graph', async () => {
|
||||
await mockFetch({ 'test.json': mockGraph })
|
||||
const setGraph = vi.fn()
|
||||
const nested = { id: 'nested' }
|
||||
vi.mocked(comfyApp.canvas).graph = fromAny<
|
||||
NonNullable<typeof comfyApp.canvas.graph>,
|
||||
unknown
|
||||
>({
|
||||
nodes: [{ subgraph: nested }],
|
||||
setGraph
|
||||
})
|
||||
vi.mocked(comfyApp.canvas).setGraph = setGraph
|
||||
|
||||
await store.editBlueprint(BLUEPRINT_TYPE_PREFIX + 'test')
|
||||
|
||||
expect(setGraph).toHaveBeenCalledWith(nested)
|
||||
})
|
||||
|
||||
it('throws when editing an unloaded blueprint', async () => {
|
||||
await expect(
|
||||
store.editBlueprint(BLUEPRINT_TYPE_PREFIX + 'missing')
|
||||
).rejects.toThrow('not yet loaded')
|
||||
})
|
||||
it('should allow subgraphs to be added to graph', async () => {
|
||||
//mock
|
||||
await mockFetch({ 'test.json': mockGraph })
|
||||
@@ -166,6 +282,12 @@ describe('useSubgraphStore', () => {
|
||||
expect(second.nodes[0].id).not.toBe(-1)
|
||||
expect(second.definitions!.subgraphs![0].id).toBe('123')
|
||||
})
|
||||
|
||||
it('throws when getting an unloaded blueprint', () => {
|
||||
expect(() => store.getBlueprint(BLUEPRINT_TYPE_PREFIX + 'missing')).toThrow(
|
||||
'not yet loaded'
|
||||
)
|
||||
})
|
||||
it('should identify user blueprints as non-global', async () => {
|
||||
await mockFetch({ 'test.json': mockGraph })
|
||||
expect(store.isGlobalBlueprint('test')).toBe(false)
|
||||
@@ -188,6 +310,57 @@ describe('useSubgraphStore', () => {
|
||||
expect(store.isGlobalBlueprint('nonexistent')).toBe(false)
|
||||
})
|
||||
|
||||
describe('deleteBlueprint', () => {
|
||||
it('throws for unloaded blueprints', async () => {
|
||||
await expect(
|
||||
store.deleteBlueprint(BLUEPRINT_TYPE_PREFIX + 'missing')
|
||||
).rejects.toThrow('not yet loaded')
|
||||
})
|
||||
|
||||
it('does not delete global blueprints', async () => {
|
||||
await mockFetch(
|
||||
{},
|
||||
{
|
||||
global_bp: {
|
||||
name: 'Global Blueprint',
|
||||
info: { node_pack: 'comfy_essentials' },
|
||||
data: JSON.stringify(mockGraph)
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
await store.deleteBlueprint(BLUEPRINT_TYPE_PREFIX + 'global_bp')
|
||||
|
||||
expect(api.deleteUserData).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('does not delete when confirmation is cancelled', async () => {
|
||||
await mockFetch({ 'test.json': mockGraph })
|
||||
vi.mocked(useDialogService).mockReturnValue(
|
||||
fromPartial<ReturnType<typeof useDialogService>>({
|
||||
prompt: vi.fn(() => 'testname'),
|
||||
confirm: vi.fn(() => false)
|
||||
})
|
||||
)
|
||||
|
||||
await store.deleteBlueprint(BLUEPRINT_TYPE_PREFIX + 'test')
|
||||
|
||||
expect(api.deleteUserData).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('deletes user blueprints after confirmation', async () => {
|
||||
await mockFetch({ 'test.json': mockGraph })
|
||||
vi.mocked(api.deleteUserData).mockResolvedValue({
|
||||
status: 204
|
||||
} as Response)
|
||||
|
||||
await store.deleteBlueprint(BLUEPRINT_TYPE_PREFIX + 'test')
|
||||
|
||||
expect(api.deleteUserData).toHaveBeenCalledWith('subgraphs/test.json')
|
||||
expect(store.isUserBlueprint(BLUEPRINT_TYPE_PREFIX + 'test')).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('isUserBlueprint', () => {
|
||||
it('should return true for user blueprints', async () => {
|
||||
await mockFetch({ 'test.json': mockGraph })
|
||||
@@ -285,6 +458,206 @@ describe('useSubgraphStore', () => {
|
||||
consoleSpy.mockRestore()
|
||||
})
|
||||
|
||||
it('continues when global blueprint discovery rejects', async () => {
|
||||
vi.mocked(api.listUserDataFullInfo).mockResolvedValue([])
|
||||
vi.mocked(api.getGlobalSubgraphs).mockRejectedValue(
|
||||
new Error('global down')
|
||||
)
|
||||
|
||||
await store.fetchSubgraphs()
|
||||
|
||||
expect(store.subgraphBlueprints).toEqual([])
|
||||
})
|
||||
|
||||
it('reports compact detail when more than three blueprints fail', async () => {
|
||||
const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
await mockFetch(
|
||||
{},
|
||||
{
|
||||
a: { name: 'A', info: { node_pack: 'test' }, data: '' },
|
||||
b: { name: 'B', info: { node_pack: 'test' }, data: '' },
|
||||
c: { name: 'C', info: { node_pack: 'test' }, data: '' },
|
||||
d: { name: 'D', info: { node_pack: 'test' }, data: '' }
|
||||
}
|
||||
)
|
||||
|
||||
expect(consoleSpy).toHaveBeenCalledTimes(4)
|
||||
expect(useToastStore().add).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ detail: 'x4' })
|
||||
)
|
||||
consoleSpy.mockRestore()
|
||||
})
|
||||
|
||||
it('ignores invalid user blueprint files during fetch', async () => {
|
||||
const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
|
||||
await mockFetch({
|
||||
'invalid.json': {
|
||||
nodes: [],
|
||||
definitions: { subgraphs: [] }
|
||||
}
|
||||
})
|
||||
|
||||
expect(consoleSpy).toHaveBeenCalledWith(
|
||||
'Failed to load subgraph blueprint',
|
||||
expect.any(Error)
|
||||
)
|
||||
expect(store.subgraphBlueprints).toHaveLength(0)
|
||||
consoleSpy.mockRestore()
|
||||
})
|
||||
|
||||
it('rejects loaded blueprints whose wrapper node does not reference a subgraph', async () => {
|
||||
const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
|
||||
await mockFetch({
|
||||
'invalid-ref.json': {
|
||||
nodes: [{ id: 1, type: 'missing' }],
|
||||
definitions: { subgraphs: [{ id: 'present' }] }
|
||||
}
|
||||
})
|
||||
|
||||
expect(consoleSpy).toHaveBeenCalledWith(
|
||||
'Failed to load subgraph blueprint',
|
||||
expect.any(Error)
|
||||
)
|
||||
expect(store.subgraphBlueprints).toHaveLength(0)
|
||||
consoleSpy.mockRestore()
|
||||
})
|
||||
|
||||
it('rejects loaded blueprints without subgraph definitions', async () => {
|
||||
const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
|
||||
await mockFetch({
|
||||
'missing-definitions.json': {
|
||||
nodes: [{ id: 1, type: 'missing' }]
|
||||
}
|
||||
})
|
||||
|
||||
expect(consoleSpy).toHaveBeenCalledWith(
|
||||
'Failed to load subgraph blueprint',
|
||||
expect.any(Error)
|
||||
)
|
||||
expect(store.subgraphBlueprints).toHaveLength(0)
|
||||
consoleSpy.mockRestore()
|
||||
})
|
||||
|
||||
it('rejects saving a blueprint whose active state has no subgraph definitions', async () => {
|
||||
await mockFetch({ 'test.json': mockGraph })
|
||||
const blueprint = useWorkflowStore().getWorkflowByPath(
|
||||
'subgraphs/test.json'
|
||||
)
|
||||
if (!blueprint?.changeTracker) throw new Error('Blueprint was not loaded')
|
||||
blueprint.changeTracker!.activeState = fromAny<ComfyWorkflowJSON, unknown>({
|
||||
nodes: [{ id: 1, type: '123' }]
|
||||
})
|
||||
|
||||
await expect(blueprint.save()).rejects.toThrow(
|
||||
'The root graph of a subgraph blueprint must consist of only a single subgraph node'
|
||||
)
|
||||
})
|
||||
|
||||
it('marks non-blueprint root nodes when saving an invalid blueprint', async () => {
|
||||
vi.mocked(comfyApp.canvas).draw = vi.fn()
|
||||
await mockFetch({ 'test.json': mockGraph })
|
||||
const blueprint = useWorkflowStore().getWorkflowByPath(
|
||||
'subgraphs/test.json'
|
||||
)
|
||||
if (!blueprint?.changeTracker) throw new Error('Blueprint was not loaded')
|
||||
blueprint.changeTracker!.activeState = fromAny<ComfyWorkflowJSON, unknown>({
|
||||
nodes: [
|
||||
{ id: 1, type: '123' },
|
||||
{ id: 2, type: 'OtherNode' }
|
||||
],
|
||||
definitions: { subgraphs: [{ id: '123' }] }
|
||||
})
|
||||
|
||||
await expect(blueprint.save()).rejects.toThrow(
|
||||
'The root graph of a subgraph blueprint must consist of only a single subgraph node'
|
||||
)
|
||||
expect(comfyApp.canvas.draw).toHaveBeenCalledWith(true, true)
|
||||
})
|
||||
|
||||
it('does not save a loaded blueprint when first-save confirmation is cancelled', async () => {
|
||||
const confirm = vi.fn(() => false)
|
||||
vi.mocked(useDialogService).mockReturnValue(
|
||||
fromPartial<ReturnType<typeof useDialogService>>({
|
||||
prompt: vi.fn(() => 'testname'),
|
||||
confirm
|
||||
})
|
||||
)
|
||||
useSettingStore().settingValues['Comfy.Workflow.WarnBlueprintOverwrite'] =
|
||||
true
|
||||
await mockFetch({ 'test.json': mockGraph })
|
||||
const blueprint = useWorkflowStore().getWorkflowByPath(
|
||||
'subgraphs/test.json'
|
||||
)
|
||||
if (!blueprint) throw new Error('Blueprint was not loaded')
|
||||
|
||||
const result = await blueprint.save()
|
||||
|
||||
expect(result).toBe(blueprint)
|
||||
expect(confirm).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
type: 'overwriteBlueprint',
|
||||
itemList: ['test']
|
||||
})
|
||||
)
|
||||
expect(api.storeUserData).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('saves a loaded blueprint after first-save confirmation', async () => {
|
||||
const confirm = vi.fn(() => true)
|
||||
vi.mocked(useDialogService).mockReturnValue(
|
||||
fromPartial<ReturnType<typeof useDialogService>>({
|
||||
prompt: vi.fn(() => 'testname'),
|
||||
confirm
|
||||
})
|
||||
)
|
||||
useSettingStore().settingValues['Comfy.Workflow.WarnBlueprintOverwrite'] =
|
||||
true
|
||||
vi.mocked(api.storeUserData).mockResolvedValue({
|
||||
status: 200,
|
||||
json: () =>
|
||||
Promise.resolve({
|
||||
path: 'subgraphs/test.json',
|
||||
modified: Date.now(),
|
||||
size: 2
|
||||
})
|
||||
} as Response)
|
||||
await mockFetch({ 'test.json': mockGraph })
|
||||
const blueprint = useWorkflowStore().getWorkflowByPath(
|
||||
'subgraphs/test.json'
|
||||
)
|
||||
if (!blueprint) throw new Error('Blueprint was not loaded')
|
||||
|
||||
await blueprint.save()
|
||||
|
||||
const [path, data, options] = vi.mocked(api.storeUserData).mock.calls[0]
|
||||
if (typeof data !== 'string') throw new Error('Expected saved JSON')
|
||||
expect(path).toBe('subgraphs/test.json')
|
||||
expect(JSON.parse(data)).toMatchObject({
|
||||
nodes: [{ type: '123', title: 'test' }],
|
||||
definitions: { subgraphs: [{ id: '123', name: 'test' }] }
|
||||
})
|
||||
expect(options).toEqual({
|
||||
overwrite: true,
|
||||
throwOnError: true,
|
||||
full_info: true
|
||||
})
|
||||
})
|
||||
|
||||
it('returns an already-loaded blueprint when loading without force', async () => {
|
||||
await mockFetch({ 'test.json': mockGraph })
|
||||
const blueprint = useWorkflowStore().getWorkflowByPath(
|
||||
'subgraphs/test.json'
|
||||
)
|
||||
if (!blueprint) throw new Error('Blueprint was not loaded')
|
||||
|
||||
await blueprint.load()
|
||||
|
||||
expect(api.getUserData).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should handle global blueprint with rejected data promise gracefully', async () => {
|
||||
const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
await mockFetch(
|
||||
@@ -406,6 +779,29 @@ describe('useSubgraphStore', () => {
|
||||
expect(nodeDef?.description).toBe('This is a test blueprint')
|
||||
})
|
||||
|
||||
it('does not copy workflowRendererVersion into subgraph metadata on load', async () => {
|
||||
await mockFetch({
|
||||
'metadata-load.json': {
|
||||
nodes: [{ type: '123' }],
|
||||
definitions: {
|
||||
subgraphs: [{ id: '123', extra: {} }]
|
||||
},
|
||||
extra: {
|
||||
BlueprintDescription: 'Loaded description',
|
||||
workflowRendererVersion: 'Vue'
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
const blueprint = store.getBlueprint(
|
||||
BLUEPRINT_TYPE_PREFIX + 'metadata-load'
|
||||
)
|
||||
|
||||
expect(blueprint.definitions!.subgraphs![0].extra).toEqual({
|
||||
BlueprintDescription: 'Loaded description'
|
||||
})
|
||||
})
|
||||
|
||||
it('should not duplicate metadata in both workflow extra and subgraph extra when publishing', async () => {
|
||||
const subgraph = createTestSubgraph()
|
||||
const subgraphNode = createTestSubgraphNode(subgraph)
|
||||
@@ -465,6 +861,59 @@ describe('useSubgraphStore', () => {
|
||||
expect(subgraphExtra?.BlueprintDescription).toBeUndefined()
|
||||
expect(subgraphExtra?.BlueprintSearchAliases).toBeUndefined()
|
||||
})
|
||||
|
||||
it('keeps workflowRendererVersion in subgraph extra when publishing', async () => {
|
||||
const subgraph = createTestSubgraph()
|
||||
const subgraphNode = createTestSubgraphNode(subgraph)
|
||||
subgraphNode.graph!.add(subgraphNode)
|
||||
|
||||
subgraph.extra = {
|
||||
BlueprintDescription: 'Test description',
|
||||
workflowRendererVersion: 'Vue'
|
||||
}
|
||||
|
||||
vi.mocked(comfyApp.canvas).selectedItems = new Set([subgraphNode])
|
||||
vi.mocked(comfyApp.canvas)._serializeItems = vi.fn(() => {
|
||||
const serializedSubgraph = fromPartial<ExportedSubgraph>({
|
||||
...subgraph.serialize(),
|
||||
links: [],
|
||||
groups: [],
|
||||
version: 1
|
||||
})
|
||||
return {
|
||||
nodes: [subgraphNode.serialize()],
|
||||
subgraphs: [serializedSubgraph]
|
||||
}
|
||||
})
|
||||
|
||||
let savedWorkflowData: Record<string, unknown> | null = null
|
||||
vi.mocked(api.storeUserData).mockImplementation(async (_path, data) => {
|
||||
savedWorkflowData = JSON.parse(data as string)
|
||||
return {
|
||||
status: 200,
|
||||
json: () =>
|
||||
Promise.resolve({
|
||||
path: 'subgraphs/testname.json',
|
||||
modified: Date.now(),
|
||||
size: 2
|
||||
})
|
||||
} as Response
|
||||
})
|
||||
|
||||
await mockFetch({ 'testname.json': mockGraph })
|
||||
await store.publishSubgraph()
|
||||
|
||||
expect(savedWorkflowData).not.toBeNull()
|
||||
expect(savedWorkflowData!.extra).toEqual({
|
||||
BlueprintDescription: 'Test description'
|
||||
})
|
||||
const definitions = savedWorkflowData!.definitions as {
|
||||
subgraphs: Array<{ extra?: Record<string, unknown> }>
|
||||
}
|
||||
expect(definitions.subgraphs[0]?.extra?.workflowRendererVersion).toBe(
|
||||
'Vue'
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('subgraph definition category', () => {
|
||||
|
||||
@@ -25,15 +25,13 @@ export enum ServerFeatureFlag {
|
||||
COMFYHUB_UPLOAD_ENABLED = 'comfyhub_upload_enabled',
|
||||
COMFYHUB_PROFILE_GATE_ENABLED = 'comfyhub_profile_gate_enabled',
|
||||
SHOW_SIGNIN_BUTTON = 'show_signin_button',
|
||||
UNIFIED_CLOUD_AUTH = 'unified_cloud_auth',
|
||||
CONSOLIDATED_BILLING_ENABLED = 'consolidated_billing_enabled'
|
||||
UNIFIED_CLOUD_AUTH = 'unified_cloud_auth'
|
||||
}
|
||||
|
||||
export function useFeatureFlags() {
|
||||
return {
|
||||
flags: {
|
||||
teamWorkspacesEnabled: true,
|
||||
consolidatedBillingEnabled: true
|
||||
teamWorkspacesEnabled: true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,10 +8,12 @@ import {
|
||||
hexToRgb,
|
||||
hsbToRgb,
|
||||
hsvaToHex,
|
||||
isColorFormat,
|
||||
isTransparent,
|
||||
luminance,
|
||||
parseToRgb,
|
||||
readableTextColor,
|
||||
rgbToHsl,
|
||||
rgbToHex,
|
||||
textOnColor,
|
||||
toHexFromFormat
|
||||
@@ -236,6 +238,24 @@ describe('colorUtil conversions', () => {
|
||||
it('parses 8-digit hex and ignores the alpha channel in RGB output', () => {
|
||||
expect(parseToRgb('#ff000080')).toEqual({ r: 255, g: 0, b: 0 })
|
||||
})
|
||||
|
||||
it('covers HSL hue sectors outside primary colors', () => {
|
||||
expect(parseToRgb('hsl(60, 100%, 50%)')).toEqual({
|
||||
r: 255,
|
||||
g: 255,
|
||||
b: 0
|
||||
})
|
||||
expect(parseToRgb('hsl(180, 100%, 50%)')).toEqual({
|
||||
r: 0,
|
||||
g: 255,
|
||||
b: 255
|
||||
})
|
||||
expect(parseToRgb('hsl(300, 100%, 50%)')).toEqual({
|
||||
r: 255,
|
||||
g: 0,
|
||||
b: 255
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('hsbToRgb normalization', () => {
|
||||
@@ -247,6 +267,24 @@ describe('colorUtil conversions', () => {
|
||||
b: 255
|
||||
})
|
||||
})
|
||||
|
||||
it('covers all hue ranges', () => {
|
||||
expect(hsbToRgb({ h: 90, s: 100, b: 100 })).toEqual({
|
||||
r: 127,
|
||||
g: 255,
|
||||
b: 0
|
||||
})
|
||||
expect(hsbToRgb({ h: 180, s: 100, b: 100 })).toEqual({
|
||||
r: 0,
|
||||
g: 255,
|
||||
b: 255
|
||||
})
|
||||
expect(hsbToRgb({ h: 300, s: 100, b: 100 })).toEqual({
|
||||
r: 255,
|
||||
g: 0,
|
||||
b: 255
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('isTransparent', () => {
|
||||
@@ -263,12 +301,27 @@ describe('colorUtil conversions', () => {
|
||||
})
|
||||
|
||||
it('returns false for fully opaque hex colors', () => {
|
||||
expect(isTransparent('red')).toBe(false)
|
||||
expect(isTransparent('#ff0000')).toBe(false)
|
||||
expect(isTransparent('#ff0000ff')).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('toHexFromFormat', () => {
|
||||
it('normalizes supported hex spellings', () => {
|
||||
expect(toHexFromFormat('', 'hex')).toBe('#000000')
|
||||
expect(toHexFromFormat('abc', 'hex')).toBe('#abc')
|
||||
expect(toHexFromFormat('#abc', 'hex')).toBe('#abc')
|
||||
expect(toHexFromFormat('#abcdef', 'hex')).toBe('#abcdef')
|
||||
expect(toHexFromFormat('abcdef12', 'hex')).toBe('#abcdef12')
|
||||
expect(toHexFromFormat('#abcdef12', 'hex')).toBe('#abcdef12')
|
||||
})
|
||||
|
||||
it('converts rgb strings and rejects non-string rgb input', () => {
|
||||
expect(toHexFromFormat('rgb(255, 128, 0)', 'rgb')).toBe('#ff8000')
|
||||
expect(toHexFromFormat({ r: 255, g: 128, b: 0 }, 'rgb')).toBe('#000000')
|
||||
})
|
||||
|
||||
it('treats an HSV object (with v field) the same as an HSB object', () => {
|
||||
const hsbObject = { h: 120, s: 100, b: 100 }
|
||||
const hsvObject = { h: 120, s: 100, v: 100 }
|
||||
@@ -281,6 +334,7 @@ describe('colorUtil conversions', () => {
|
||||
|
||||
it('returns #000000 for unparseable hsb input', () => {
|
||||
expect(toHexFromFormat({ h: 0 }, 'hsb')).toBe('#000000')
|
||||
expect(toHexFromFormat('hsb()', 'hsb')).toBe('#000000')
|
||||
})
|
||||
|
||||
it('prefixes a bare 6-digit hex with #', () => {
|
||||
@@ -295,6 +349,22 @@ describe('colorUtil conversions', () => {
|
||||
})
|
||||
})
|
||||
|
||||
describe('rgbToHsl', () => {
|
||||
it('handles light colors and wrapped red hue', () => {
|
||||
expect(rgbToHsl({ r: 255, g: 128, b: 128 }).s).toBeCloseTo(1)
|
||||
expect(rgbToHsl({ r: 255, g: 0, b: 128 }).h).toBeCloseTo(0.916, 2)
|
||||
})
|
||||
})
|
||||
|
||||
describe('isColorFormat', () => {
|
||||
it('recognizes public color format ids', () => {
|
||||
expect(isColorFormat('hex')).toBe(true)
|
||||
expect(isColorFormat('rgb')).toBe(true)
|
||||
expect(isColorFormat('hsb')).toBe(true)
|
||||
expect(isColorFormat('hsl')).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('readableTextColor / textOnColor', () => {
|
||||
it('lightens dark colors', () => {
|
||||
expect(readableTextColor('#000000')).not.toBe('rgb(0,0,0)')
|
||||
@@ -359,6 +429,10 @@ describe('colorUtil - adjustColor', () => {
|
||||
})
|
||||
})
|
||||
|
||||
it('returns the original value when no adjustments are requested', () => {
|
||||
expect(adjustColor('#123456', {})).toBe('#123456')
|
||||
})
|
||||
|
||||
it('treats 5-char hex as valid color with alpha', () => {
|
||||
const result = adjustColor('#f008', {
|
||||
lightness: targetLightness,
|
||||
|
||||
56
src/utils/createAnnotatedPath.test.ts
Normal file
56
src/utils/createAnnotatedPath.test.ts
Normal file
@@ -0,0 +1,56 @@
|
||||
import { describe, expect, it } from 'vitest'
|
||||
|
||||
import { createAnnotatedPath } from './createAnnotatedPath'
|
||||
|
||||
describe('createAnnotatedPath', () => {
|
||||
it('returns bare input paths for string items by default', () => {
|
||||
expect(createAnnotatedPath('image.png')).toBe('image.png')
|
||||
})
|
||||
|
||||
it('prepends the supplied subfolder for string items', () => {
|
||||
expect(createAnnotatedPath('image.png', { subfolder: 'uploads' })).toBe(
|
||||
'uploads/image.png'
|
||||
)
|
||||
})
|
||||
|
||||
it('annotates string items when the root folder is not input', () => {
|
||||
expect(createAnnotatedPath('image.png', { rootFolder: 'output' })).toBe(
|
||||
'image.png [output]'
|
||||
)
|
||||
})
|
||||
|
||||
it('does not duplicate an existing annotation', () => {
|
||||
expect(
|
||||
createAnnotatedPath('image.png [temp]', { rootFolder: 'temp' })
|
||||
).toBe('image.png [temp]')
|
||||
})
|
||||
|
||||
it('formats result items with their own subfolder', () => {
|
||||
expect(
|
||||
createAnnotatedPath({
|
||||
filename: 'render.png',
|
||||
subfolder: 'final',
|
||||
type: 'output'
|
||||
})
|
||||
).toBe('final/render.png [output]')
|
||||
})
|
||||
|
||||
it('omits the result-item annotation when type matches the root folder', () => {
|
||||
expect(
|
||||
createAnnotatedPath(
|
||||
{
|
||||
filename: 'render.png',
|
||||
subfolder: '',
|
||||
type: 'output'
|
||||
},
|
||||
{ rootFolder: 'output' }
|
||||
)
|
||||
).toBe('render.png')
|
||||
})
|
||||
|
||||
it('handles missing result-item filenames', () => {
|
||||
expect(createAnnotatedPath({ subfolder: 'folder', type: 'temp' })).toBe(
|
||||
'folder/ [temp]'
|
||||
)
|
||||
})
|
||||
})
|
||||
60
src/utils/envUtil.test.ts
Normal file
60
src/utils/envUtil.test.ts
Normal file
@@ -0,0 +1,60 @@
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
const distributionState = vi.hoisted(() => ({
|
||||
isDesktop: false
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/distribution/types', () => ({
|
||||
get isDesktop() {
|
||||
return distributionState.isDesktop
|
||||
}
|
||||
}))
|
||||
|
||||
beforeEach(() => {
|
||||
vi.resetModules()
|
||||
Reflect.deleteProperty(window, 'electronAPI')
|
||||
Object.defineProperty(window.navigator, 'windowControlsOverlay', {
|
||||
configurable: true,
|
||||
value: undefined
|
||||
})
|
||||
})
|
||||
|
||||
async function importEnvUtil(isDesktop: boolean) {
|
||||
distributionState.isDesktop = isDesktop
|
||||
return await import('@/utils/envUtil')
|
||||
}
|
||||
|
||||
describe('envUtil', () => {
|
||||
it('returns and uses the Electron API when present', async () => {
|
||||
const showContextMenu = vi.fn()
|
||||
Object.defineProperty(window, 'electronAPI', {
|
||||
configurable: true,
|
||||
value: { showContextMenu }
|
||||
})
|
||||
const envUtil = await importEnvUtil(true)
|
||||
|
||||
expect(envUtil.electronAPI()).toEqual({ showContextMenu })
|
||||
|
||||
envUtil.showNativeSystemMenu()
|
||||
|
||||
expect(showContextMenu).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('detects native windows only in desktop window-control overlays', async () => {
|
||||
Object.defineProperty(window.navigator, 'windowControlsOverlay', {
|
||||
configurable: true,
|
||||
value: { visible: true }
|
||||
})
|
||||
|
||||
expect((await importEnvUtil(true)).isNativeWindow()).toBe(true)
|
||||
|
||||
vi.resetModules()
|
||||
expect((await importEnvUtil(false)).isNativeWindow()).toBe(false)
|
||||
})
|
||||
|
||||
it('tolerates a missing Electron API for native menu calls', async () => {
|
||||
const envUtil = await importEnvUtil(true)
|
||||
|
||||
expect(() => envUtil.showNativeSystemMenu()).not.toThrow()
|
||||
})
|
||||
})
|
||||
207
src/utils/executionUtil.test.ts
Normal file
207
src/utils/executionUtil.test.ts
Normal file
@@ -0,0 +1,207 @@
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { LGraphEventMode } from '@/lib/litegraph/src/litegraph'
|
||||
|
||||
import { graphToPrompt } from './executionUtil'
|
||||
|
||||
const mocks = vi.hoisted(() => ({
|
||||
compressWidgetInputSlots: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('./litegraphUtil', () => ({
|
||||
compressWidgetInputSlots: mocks.compressWidgetInputSlots
|
||||
}))
|
||||
|
||||
vi.mock('@/lib/litegraph/src/litegraph', () => ({
|
||||
LGraphEventMode: {
|
||||
NEVER: 1,
|
||||
BYPASS: 2
|
||||
},
|
||||
ExecutableNodeDTO: vi.fn(function (node: {
|
||||
id: number | string
|
||||
executionId?: number | string
|
||||
isVirtualNode?: boolean
|
||||
mode?: number
|
||||
widgets?: unknown[]
|
||||
inputs?: unknown[]
|
||||
resolveInput?: (index: number) => unknown
|
||||
comfyClass?: string
|
||||
title?: string
|
||||
dtoInnerNodes?: unknown[]
|
||||
}) {
|
||||
return {
|
||||
id: node.executionId ?? node.id,
|
||||
isVirtualNode: node.isVirtualNode ?? false,
|
||||
mode: node.mode ?? 0,
|
||||
widgets: node.widgets,
|
||||
inputs: node.inputs ?? [],
|
||||
resolveInput: node.resolveInput ?? (() => null),
|
||||
comfyClass: node.comfyClass,
|
||||
title: node.title,
|
||||
getInnerNodes: () => node.dtoInnerNodes ?? []
|
||||
}
|
||||
})
|
||||
}))
|
||||
|
||||
function graphWith(
|
||||
nodes: unknown[],
|
||||
workflowExtra?: Record<string, unknown>,
|
||||
workflowNodes: Array<Record<string, unknown>> = [
|
||||
{
|
||||
inputs: [{ name: 'in', localized_name: 'Input' }],
|
||||
outputs: [{ name: 'out', localized_name: 'Output' }]
|
||||
}
|
||||
]
|
||||
) {
|
||||
return {
|
||||
computeExecutionOrder: vi.fn(() => nodes),
|
||||
serialize: vi.fn(() => ({
|
||||
nodes: workflowNodes,
|
||||
extra: workflowExtra
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
describe('graphToPrompt', () => {
|
||||
it('serializes widget values, links, virtual setup, and workflow metadata', async () => {
|
||||
const virtualApply = vi.fn()
|
||||
const virtualInner = {
|
||||
id: 'virtual-inner',
|
||||
isVirtualNode: true,
|
||||
applyToGraph: virtualApply
|
||||
}
|
||||
const innerOutputNode = {
|
||||
id: 'inner-output',
|
||||
inputs: [],
|
||||
widgets: [],
|
||||
comfyClass: 'InnerClass',
|
||||
title: 'Inner'
|
||||
}
|
||||
const node = {
|
||||
id: 1,
|
||||
getInnerNodes: vi.fn(() => [virtualInner]),
|
||||
dtoInnerNodes: [innerOutputNode],
|
||||
inputs: [
|
||||
{ name: 'missing' },
|
||||
{ name: 'widget-array' },
|
||||
{ name: 'link' },
|
||||
{ name: 'removed' }
|
||||
],
|
||||
widgets: [
|
||||
{ name: '', value: 'ignored' },
|
||||
{ name: 'skipped', value: 'ignored', options: { serialize: false } },
|
||||
{
|
||||
name: 'curve',
|
||||
type: 'curve',
|
||||
serializeValue: vi.fn(async () => [1, 2])
|
||||
},
|
||||
{ name: 'array', value: [3, 4] },
|
||||
{ name: 'plain', value: 'value' }
|
||||
],
|
||||
resolveInput: vi.fn((index: number) => {
|
||||
if (index === 1) return { widgetInfo: { value: [5, 6] } }
|
||||
if (index === 2) return { origin_id: 'inner-output', origin_slot: '7' }
|
||||
if (index === 3) return { origin_id: 'removed-node', origin_slot: '1' }
|
||||
return null
|
||||
}),
|
||||
comfyClass: 'TestClass',
|
||||
title: 'Test'
|
||||
}
|
||||
const graph = graphWith([node])
|
||||
|
||||
const { workflow, output } = await graphToPrompt(
|
||||
graph as unknown as Parameters<typeof graphToPrompt>[0],
|
||||
{ sortNodes: true }
|
||||
)
|
||||
|
||||
expect(virtualApply).toHaveBeenCalledTimes(1)
|
||||
expect(graph.serialize).toHaveBeenCalledWith({ sortNodes: true })
|
||||
expect(mocks.compressWidgetInputSlots).toHaveBeenCalledWith(workflow)
|
||||
expect(workflow.nodes[0].inputs?.[0]).toEqual({ name: 'in' })
|
||||
expect(workflow.nodes[0].outputs?.[0]).toEqual({ name: 'out' })
|
||||
expect(workflow.extra?.frontendVersion).toBeDefined()
|
||||
expect(output['1']).toEqual({
|
||||
inputs: {
|
||||
curve: { __type__: 'CURVE', __value__: [1, 2] },
|
||||
array: { __value__: [3, 4] },
|
||||
plain: 'value',
|
||||
'widget-array': { __value__: [5, 6] },
|
||||
link: ['inner-output', 7]
|
||||
},
|
||||
class_type: 'TestClass',
|
||||
_meta: { title: 'Test' }
|
||||
})
|
||||
expect(output['inner-output']).toEqual({
|
||||
inputs: {},
|
||||
class_type: 'InnerClass',
|
||||
_meta: { title: 'Inner' }
|
||||
})
|
||||
})
|
||||
|
||||
it('skips muted and virtual executable nodes', async () => {
|
||||
const normalNode = {
|
||||
id: 'normal',
|
||||
inputs: [],
|
||||
comfyClass: 'Normal',
|
||||
title: 'Normal'
|
||||
}
|
||||
const mutedNode = {
|
||||
id: 'muted',
|
||||
mode: LGraphEventMode.NEVER,
|
||||
inputs: [],
|
||||
widgets: [{ name: 'value', value: 'ignored' }],
|
||||
comfyClass: 'Muted',
|
||||
title: 'Muted',
|
||||
dtoInnerNodes: [
|
||||
{
|
||||
id: 'muted-inner',
|
||||
inputs: [],
|
||||
comfyClass: 'MutedInner',
|
||||
title: 'MutedInner'
|
||||
}
|
||||
]
|
||||
}
|
||||
const bypassedNode = {
|
||||
id: 'bypassed',
|
||||
mode: LGraphEventMode.BYPASS,
|
||||
inputs: [],
|
||||
comfyClass: 'Bypassed',
|
||||
title: 'Bypassed'
|
||||
}
|
||||
const virtualNode = {
|
||||
id: 'virtual',
|
||||
isVirtualNode: true,
|
||||
inputs: [],
|
||||
comfyClass: 'Virtual',
|
||||
title: 'Virtual'
|
||||
}
|
||||
const graph = graphWith(
|
||||
[normalNode, mutedNode, bypassedNode, virtualNode],
|
||||
{}
|
||||
)
|
||||
|
||||
const { workflow, output } = await graphToPrompt(
|
||||
graph as unknown as Parameters<typeof graphToPrompt>[0]
|
||||
)
|
||||
|
||||
expect(graph.serialize).toHaveBeenCalledWith({ sortNodes: false })
|
||||
expect(workflow.extra?.frontendVersion).toBeDefined()
|
||||
expect(Object.keys(output)).toEqual(['normal'])
|
||||
})
|
||||
|
||||
it('preserves serialized workflow nodes without slot arrays', async () => {
|
||||
const node = {
|
||||
id: 1,
|
||||
inputs: [],
|
||||
comfyClass: 'NodeClass',
|
||||
title: 'Node'
|
||||
}
|
||||
const graph = graphWith([node], {}, [{ id: 1 }])
|
||||
|
||||
const { workflow } = await graphToPrompt(
|
||||
graph as unknown as Parameters<typeof graphToPrompt>[0]
|
||||
)
|
||||
|
||||
expect(workflow.nodes[0]).toEqual({ id: 1 })
|
||||
})
|
||||
})
|
||||
@@ -11,8 +11,8 @@ interface FilterItem {
|
||||
options: string[]
|
||||
}
|
||||
|
||||
const makeSearch = <T>(data: T[] = []) =>
|
||||
new FuseSearch<T>(data, {
|
||||
function makeSearch<T>(data: T[] = []) {
|
||||
return new FuseSearch<T>(data, {
|
||||
fuseOptions: {
|
||||
keys: ['name'],
|
||||
includeScore: true,
|
||||
@@ -21,6 +21,7 @@ const makeSearch = <T>(data: T[] = []) =>
|
||||
},
|
||||
advancedScoring: true
|
||||
})
|
||||
}
|
||||
|
||||
describe('FuseSearch', () => {
|
||||
it('assigns stable ranking tiers for exact, prefix, word, substring, and multi-part matches', () => {
|
||||
|
||||
@@ -9,21 +9,27 @@ import type {
|
||||
import {
|
||||
collectAllNodes,
|
||||
collectFromNodes,
|
||||
executionIdToNodeLocatorId,
|
||||
findNodeInHierarchy,
|
||||
findSubgraphByUuid,
|
||||
findSubgraphPathById,
|
||||
getActiveGraphNodeIds,
|
||||
forEachNode,
|
||||
forEachSubgraphNode,
|
||||
getAllNonIoNodesInSubgraph,
|
||||
getExecutionIdsForSelectedNodes,
|
||||
getLocalNodeIdFromExecutionId,
|
||||
getLocatorIdFromNodeData,
|
||||
getNodeByExecutionId,
|
||||
getNodeByLocatorId,
|
||||
getRootGraph,
|
||||
getRootParentNode,
|
||||
getSubgraphPathFromExecutionId,
|
||||
getExecutionIdFromNodeData,
|
||||
mapAllNodes,
|
||||
mapSubgraphNodes,
|
||||
parseExecutionId,
|
||||
reduceAllNodes,
|
||||
traverseNodesDepthFirst,
|
||||
traverseSubgraphPath,
|
||||
triggerCallbackOnAllNodes,
|
||||
@@ -36,6 +42,7 @@ import {
|
||||
isMissingCandidateActive
|
||||
} from '@/utils/graphTraversalUtil'
|
||||
import { LGraphEventMode } from '@/lib/litegraph/src/types/globalEnums'
|
||||
import { createNodeExecutionId } from '@/types/nodeIdentification'
|
||||
import { toNodeId } from '@/types/nodeId'
|
||||
|
||||
import { createMockLGraphNode } from './__tests__/litegraphTestUtils'
|
||||
@@ -97,6 +104,8 @@ describe('graphTraversalUtil', () => {
|
||||
expect(findNodeInHierarchy(graph, '')).toBeNull()
|
||||
expect(getExecutionIdForNodeInGraph(graph, graph, '')).toBeNull()
|
||||
expect(getExecutionIdFromNodeData(graph, { id: '' })).toBeNull()
|
||||
expect(getLocatorIdFromNodeData({ id: '' })).toBeNull()
|
||||
expect(executionIdToNodeLocatorId(graph, '')).toBeUndefined()
|
||||
})
|
||||
})
|
||||
|
||||
@@ -215,6 +224,12 @@ describe('graphTraversalUtil', () => {
|
||||
const result = traverseSubgraphPath(graph, ['999'])
|
||||
expect(result).toBeNull()
|
||||
})
|
||||
|
||||
it('returns null for an unparseable path segment', () => {
|
||||
const graph = createMockGraph([])
|
||||
|
||||
expect(traverseSubgraphPath(graph, [''])).toBeNull()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -501,6 +516,17 @@ describe('graphTraversalUtil', () => {
|
||||
})
|
||||
|
||||
describe('findSubgraphByUuid', () => {
|
||||
it('uses the root graph subgraph registry when available', () => {
|
||||
const subgraph = createMockSubgraph('registered-uuid', [])
|
||||
const graph = {
|
||||
...createMockGraph([]),
|
||||
subgraphs: new Map([[subgraph.id, subgraph]])
|
||||
} satisfies Partial<LGraph> as LGraph
|
||||
|
||||
expect(findSubgraphByUuid(graph, subgraph.id)).toBe(subgraph)
|
||||
expect(findSubgraphByUuid(graph, 'missing-uuid')).toBeNull()
|
||||
})
|
||||
|
||||
it('should find subgraph by UUID', () => {
|
||||
const targetUuid = 'target-uuid'
|
||||
const subgraph = createMockSubgraph(targetUuid, [])
|
||||
@@ -546,6 +572,60 @@ describe('graphTraversalUtil', () => {
|
||||
})
|
||||
})
|
||||
|
||||
describe('findSubgraphPathById', () => {
|
||||
it('returns the path to a nested subgraph', () => {
|
||||
const targetSubgraph = createMockSubgraph('target-uuid', [])
|
||||
const middleNode = createMockNode('20', {
|
||||
isSubgraph: true,
|
||||
subgraph: targetSubgraph
|
||||
})
|
||||
const middleSubgraph = createMockSubgraph('middle-uuid', [middleNode])
|
||||
const rootNode = createMockNode('10', {
|
||||
isSubgraph: true,
|
||||
subgraph: middleSubgraph
|
||||
})
|
||||
const graph = createMockGraph([rootNode])
|
||||
|
||||
expect(findSubgraphPathById(graph, 'target-uuid')).toEqual([
|
||||
'middle-uuid',
|
||||
'target-uuid'
|
||||
])
|
||||
})
|
||||
|
||||
it('skips malformed graph entries while searching', () => {
|
||||
const malformedSubgraph = {
|
||||
id: 'malformed',
|
||||
nodes: []
|
||||
} as unknown as Subgraph
|
||||
const graph = createMockGraph([
|
||||
createMockNode('10', {
|
||||
isSubgraph: true,
|
||||
subgraph: malformedSubgraph
|
||||
})
|
||||
])
|
||||
|
||||
expect(findSubgraphPathById(graph, 'missing-uuid')).toBeNull()
|
||||
})
|
||||
})
|
||||
|
||||
describe('getRootParentNode', () => {
|
||||
it('returns null for root-level or invalid execution IDs', () => {
|
||||
const graph = createMockGraph([createMockNode('10')])
|
||||
|
||||
expect(getRootParentNode(graph, '10')).toBeNull()
|
||||
expect(getRootParentNode(graph, '')).toBeNull()
|
||||
expect(getRootParentNode(graph, 'invalid:20')).toBeNull()
|
||||
})
|
||||
|
||||
it('returns the root parent for a nested execution ID', () => {
|
||||
const parent = createMockNode('10')
|
||||
const graph = createMockGraph([parent])
|
||||
|
||||
expect(getRootParentNode(graph, '10:20:30')).toBe(parent)
|
||||
expect(getRootParentNode(graph, '10:20:30')?.id).toBe('10')
|
||||
})
|
||||
})
|
||||
|
||||
describe('getNodeByExecutionId', () => {
|
||||
it('should find node in root graph', () => {
|
||||
const nodes = [createMockNode('123'), createMockNode('456')]
|
||||
@@ -616,6 +696,14 @@ describe('graphTraversalUtil', () => {
|
||||
})
|
||||
|
||||
describe('getExecutionIdByNode', () => {
|
||||
it('returns null when the node has no graph', () => {
|
||||
const node = createMockNode('123')
|
||||
node.graph = null as unknown as LGraph
|
||||
const graph = createMockGraph([])
|
||||
|
||||
expect(getExecutionIdByNode(graph, node)).toBeNull()
|
||||
})
|
||||
|
||||
it('should return node id if graph is rootGraph', () => {
|
||||
const node = createMockNode('123')
|
||||
const graph = createMockGraph([node])
|
||||
@@ -736,6 +824,21 @@ describe('graphTraversalUtil', () => {
|
||||
getExecutionIdForNodeInGraph(rootGraph, orphanSubgraph, '63')
|
||||
).toBe('63')
|
||||
})
|
||||
|
||||
it('falls back to local id when the parent path is not parseable', () => {
|
||||
const interior = createMockNode('63')
|
||||
const subgraph = createMockSubgraph('sub-uuid', [interior])
|
||||
const subgraphNode = createMockLGraphNode({
|
||||
id: toNodeId(''),
|
||||
isSubgraphNode: () => true,
|
||||
subgraph
|
||||
}) satisfies Partial<LGraphNode> as LGraphNode
|
||||
const rootGraph = createMockGraph([subgraphNode])
|
||||
|
||||
expect(getExecutionIdForNodeInGraph(rootGraph, subgraph, '63')).toBe(
|
||||
'63'
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('isAncestorPathActive', () => {
|
||||
@@ -808,6 +911,22 @@ describe('graphTraversalUtil', () => {
|
||||
})
|
||||
|
||||
describe('isExecutionPathActive', () => {
|
||||
it('returns true when the graph is unavailable', () => {
|
||||
expect(isExecutionPathActive(null, '42')).toBe(true)
|
||||
expect(isExecutionPathActive(undefined, '42')).toBe(true)
|
||||
})
|
||||
|
||||
it('returns false when the target node cannot be resolved', () => {
|
||||
expect(isExecutionPathActive(createMockGraph([]), '42')).toBe(false)
|
||||
})
|
||||
|
||||
it('returns true for an active target with active ancestors', () => {
|
||||
const node = createMockNode('42')
|
||||
const rootGraph = createMockGraph([node])
|
||||
|
||||
expect(isExecutionPathActive(rootGraph, '42')).toBe(true)
|
||||
})
|
||||
|
||||
it('returns false when the target node itself is bypassed', () => {
|
||||
const node = createMockLGraphNode({
|
||||
id: 42,
|
||||
@@ -914,6 +1033,10 @@ describe('graphTraversalUtil', () => {
|
||||
})
|
||||
|
||||
describe('isCandidateScopeActive', () => {
|
||||
it('treats candidates without a scoped node as active', () => {
|
||||
expect(isCandidateScopeActive(createMockGraph([]), {})).toBe(true)
|
||||
})
|
||||
|
||||
it('uses sourceExecutionId before nodeId', () => {
|
||||
const rootNode = createMockNode('65')
|
||||
const sourceNode = createMockLGraphNode({
|
||||
@@ -1018,6 +1141,51 @@ describe('graphTraversalUtil', () => {
|
||||
const found = getNodeByLocatorId(graph, locatorId)
|
||||
expect(found).toBeNull()
|
||||
})
|
||||
|
||||
it('returns null when the root graph is unavailable', () => {
|
||||
expect(getNodeByLocatorId(null as unknown as LGraph, '123')).toBeNull()
|
||||
})
|
||||
})
|
||||
|
||||
describe('executionIdToNodeLocatorId', () => {
|
||||
it('returns a root locator for root execution IDs', () => {
|
||||
const graph = createMockGraph([])
|
||||
|
||||
expect(executionIdToNodeLocatorId(graph, '123')).toBe('123')
|
||||
})
|
||||
|
||||
it('returns a subgraph locator for nested execution IDs', () => {
|
||||
const targetNode = createMockNode('789')
|
||||
const subgraph = createMockSubgraph(
|
||||
'a1b2c3d4-e5f6-7890-abcd-ef1234567890',
|
||||
[targetNode]
|
||||
)
|
||||
const graph = createMockGraph([
|
||||
createMockNode('456', { isSubgraph: true, subgraph })
|
||||
])
|
||||
|
||||
expect(executionIdToNodeLocatorId(graph, '456:789')).toBe(
|
||||
'a1b2c3d4-e5f6-7890-abcd-ef1234567890:789'
|
||||
)
|
||||
})
|
||||
|
||||
it('returns undefined when the nested path cannot be resolved', () => {
|
||||
const graph = createMockGraph([createMockNode('456')])
|
||||
|
||||
expect(executionIdToNodeLocatorId(graph, '456:789')).toBeUndefined()
|
||||
})
|
||||
|
||||
it('returns undefined when the nested local id is invalid', () => {
|
||||
const subgraph = createMockSubgraph(
|
||||
'a1b2c3d4-e5f6-7890-abcd-ef1234567890',
|
||||
[]
|
||||
)
|
||||
const graph = createMockGraph([
|
||||
createMockNode('456', { isSubgraph: true, subgraph })
|
||||
])
|
||||
|
||||
expect(executionIdToNodeLocatorId(graph, '456:')).toBeUndefined()
|
||||
})
|
||||
})
|
||||
|
||||
describe('getRootGraph', () => {
|
||||
@@ -1207,7 +1375,27 @@ describe('graphTraversalUtil', () => {
|
||||
})
|
||||
})
|
||||
|
||||
describe('reduceAllNodes', () => {
|
||||
it('reduces nodes across nested subgraphs', () => {
|
||||
const subgraph = createMockSubgraph('sub-uuid', [createMockNode('10')])
|
||||
const graph = createMockGraph([
|
||||
createMockNode('1'),
|
||||
createMockNode('2', { isSubgraph: true, subgraph })
|
||||
])
|
||||
|
||||
expect(
|
||||
reduceAllNodes(graph, (sum, node) => sum + Number(node.id), 0)
|
||||
).toBe(13)
|
||||
})
|
||||
})
|
||||
|
||||
describe('traverseNodesDepthFirst', () => {
|
||||
it('handles missing traversal options', () => {
|
||||
const node = createMockNode('1')
|
||||
|
||||
expect(() => traverseNodesDepthFirst([node])).not.toThrow()
|
||||
})
|
||||
|
||||
it('should traverse nodes in depth-first order', () => {
|
||||
const visited: string[] = []
|
||||
const nodes = [
|
||||
@@ -1299,6 +1487,15 @@ describe('graphTraversalUtil', () => {
|
||||
})
|
||||
|
||||
describe('collectFromNodes', () => {
|
||||
it('collects nodes with default options', () => {
|
||||
const nodes = [createMockNode('1'), createMockNode('2')]
|
||||
|
||||
expect(collectFromNodes(nodes).map((node) => String(node.id))).toEqual([
|
||||
'2',
|
||||
'1'
|
||||
])
|
||||
})
|
||||
|
||||
it('should collect data from all nodes', () => {
|
||||
const nodes = [
|
||||
createMockNode('1'),
|
||||
@@ -1583,6 +1780,44 @@ describe('graphTraversalUtil', () => {
|
||||
|
||||
expect(executionIds).toEqual(['2:10', '2:11'])
|
||||
})
|
||||
|
||||
it('returns an empty list when the starting subgraph is unreachable', () => {
|
||||
const rootGraph = createMockGraph([])
|
||||
const orphanSubgraph = createMockSubgraph('orphan', [], rootGraph)
|
||||
createMockNode('10', { graph: orphanSubgraph })
|
||||
|
||||
expect(
|
||||
getExecutionIdsForSelectedNodes(orphanSubgraph.nodes, orphanSubgraph)
|
||||
).toEqual([])
|
||||
})
|
||||
})
|
||||
|
||||
describe('getActiveGraphNodeIds', () => {
|
||||
it('returns local ids for execution ids in the active graph', () => {
|
||||
const rootNode = createMockNode('1')
|
||||
const subNode = createMockNode('10')
|
||||
const subgraph = createMockSubgraph('sub-uuid', [subNode])
|
||||
const subgraphNode = createMockNode('2', {
|
||||
isSubgraph: true,
|
||||
subgraph
|
||||
})
|
||||
const graph = createMockGraph([rootNode, subgraphNode])
|
||||
rootNode.graph = graph
|
||||
subgraphNode.graph = graph
|
||||
subNode.graph = subgraph
|
||||
const executionIds = new Set([
|
||||
createNodeExecutionId([toNodeId(1)]),
|
||||
createNodeExecutionId([toNodeId(2), toNodeId(10)]),
|
||||
createNodeExecutionId([toNodeId(999)])
|
||||
])
|
||||
|
||||
expect(getActiveGraphNodeIds(graph, graph, executionIds)).toEqual(
|
||||
new Set(['1'])
|
||||
)
|
||||
expect(getActiveGraphNodeIds(graph, subgraph, executionIds)).toEqual(
|
||||
new Set(['10'])
|
||||
)
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
44
src/utils/gridUtil.test.ts
Normal file
44
src/utils/gridUtil.test.ts
Normal file
@@ -0,0 +1,44 @@
|
||||
import { afterEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { createGridStyle } from '@/utils/gridUtil'
|
||||
|
||||
describe('createGridStyle', () => {
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks()
|
||||
})
|
||||
|
||||
it('uses auto-fill columns by default', () => {
|
||||
expect(createGridStyle()).toEqual({
|
||||
display: 'grid',
|
||||
gridTemplateColumns: 'repeat(auto-fill, minmax(15rem, 1fr))',
|
||||
padding: '0',
|
||||
gap: '1rem'
|
||||
})
|
||||
})
|
||||
|
||||
it('uses fixed columns when provided', () => {
|
||||
expect(
|
||||
createGridStyle({
|
||||
columns: 3,
|
||||
padding: '8px',
|
||||
gap: '4px'
|
||||
})
|
||||
).toEqual({
|
||||
display: 'grid',
|
||||
gridTemplateColumns: 'repeat(3, 1fr)',
|
||||
padding: '8px',
|
||||
gap: '4px'
|
||||
})
|
||||
})
|
||||
|
||||
it('warns and clamps invalid fixed columns', () => {
|
||||
const warn = vi.spyOn(console, 'warn').mockImplementation(() => undefined)
|
||||
|
||||
expect(createGridStyle({ columns: -1 }).gridTemplateColumns).toBe(
|
||||
'repeat(1, 1fr)'
|
||||
)
|
||||
expect(warn).toHaveBeenCalledWith(
|
||||
'createGridStyle: columns must be >= 1, defaulting to 1'
|
||||
)
|
||||
})
|
||||
})
|
||||
@@ -1,6 +1,11 @@
|
||||
import { describe, expect, it } from 'vitest'
|
||||
|
||||
import { getGridThumbnailUrl, parseImageWidgetValue } from './imageUtil'
|
||||
import {
|
||||
fitDimensionsToNodeWidth,
|
||||
getGridThumbnailUrl,
|
||||
is_all_same_aspect_ratio,
|
||||
parseImageWidgetValue
|
||||
} from './imageUtil'
|
||||
|
||||
describe('getGridThumbnailUrl', () => {
|
||||
it('adds a compact preview format to a full-resolution view URL', () => {
|
||||
@@ -28,6 +33,20 @@ describe('getGridThumbnailUrl', () => {
|
||||
const blob = 'blob:http://localhost/abc-123'
|
||||
expect(getGridThumbnailUrl(blob)).toBe(blob)
|
||||
})
|
||||
|
||||
it('leaves empty and malformed URLs untouched', () => {
|
||||
expect(getGridThumbnailUrl('')).toBe('')
|
||||
expect(getGridThumbnailUrl('http://[bad-url')).toBe('http://[bad-url')
|
||||
})
|
||||
|
||||
it('preserves absolute URL shape when adding preview params', () => {
|
||||
const result = getGridThumbnailUrl(
|
||||
'https://comfy.local/api/view?filename=image.png&type=output'
|
||||
)
|
||||
|
||||
expect(result).toMatch(/^https:\/\/comfy\.local\/api\/view\?/)
|
||||
expect(result).toContain('preview=webp')
|
||||
})
|
||||
})
|
||||
|
||||
describe('parseImageWidgetValue', () => {
|
||||
@@ -83,3 +102,52 @@ describe('parseImageWidgetValue', () => {
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('is_all_same_aspect_ratio', () => {
|
||||
function image(width: number, height: number) {
|
||||
const img = document.createElement('img')
|
||||
Object.defineProperty(img, 'naturalWidth', {
|
||||
configurable: true,
|
||||
value: width
|
||||
})
|
||||
Object.defineProperty(img, 'naturalHeight', {
|
||||
configurable: true,
|
||||
value: height
|
||||
})
|
||||
return img
|
||||
}
|
||||
|
||||
it('accepts empty, single, and matching aspect ratio image sets', () => {
|
||||
expect(is_all_same_aspect_ratio([])).toBe(true)
|
||||
expect(is_all_same_aspect_ratio([image(10, 20)])).toBe(true)
|
||||
expect(is_all_same_aspect_ratio([image(10, 20), image(30, 60)])).toBe(true)
|
||||
})
|
||||
|
||||
it('rejects mismatched aspect ratios', () => {
|
||||
expect(is_all_same_aspect_ratio([image(10, 20), image(20, 20)])).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('fitDimensionsToNodeWidth', () => {
|
||||
it('scales image dimensions to the node width with a minimum height', () => {
|
||||
expect(fitDimensionsToNodeWidth(400, 200, 100)).toEqual({
|
||||
minWidth: 100,
|
||||
minHeight: 64
|
||||
})
|
||||
expect(fitDimensionsToNodeWidth(200, 400, 100)).toEqual({
|
||||
minWidth: 100,
|
||||
minHeight: 200
|
||||
})
|
||||
})
|
||||
|
||||
it('returns zero dimensions when the aspect ratio is zero or NaN', () => {
|
||||
expect(fitDimensionsToNodeWidth(0, 100, 200)).toEqual({
|
||||
minWidth: 0,
|
||||
minHeight: 0
|
||||
})
|
||||
expect(fitDimensionsToNodeWidth(0, 0, 200)).toEqual({
|
||||
minWidth: 0,
|
||||
minHeight: 0
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -162,6 +162,26 @@ describe('fixBadLinks', () => {
|
||||
expect(graph.nodes[1]?.inputs?.[0]?.link).toBe(1)
|
||||
})
|
||||
|
||||
it('reports a missing target input link during a dry run', () => {
|
||||
const graph = createGraph({
|
||||
nodes: [
|
||||
createNode({ id: 1, outputs: [createOutput([1])] }),
|
||||
createNode({ id: 2, inputs: [createInput(null)] })
|
||||
],
|
||||
links: [[1, 1, 0, 2, 0, '*']]
|
||||
})
|
||||
|
||||
const result = fixBadLinks(graph)
|
||||
|
||||
expect(result).toMatchObject({
|
||||
hasBadLinks: true,
|
||||
fixed: false,
|
||||
patched: 1,
|
||||
deleted: 0
|
||||
})
|
||||
expect(graph.nodes[1]?.inputs?.[0]?.link).toBeNull()
|
||||
})
|
||||
|
||||
it('removes the origin reference when the target input slot is missing', () => {
|
||||
const graph = createGraph({
|
||||
nodes: [
|
||||
@@ -206,6 +226,53 @@ describe('fixBadLinks', () => {
|
||||
expect(graph.links).toEqual([])
|
||||
})
|
||||
|
||||
it('keeps the later target link when two links target the same input slot', () => {
|
||||
const graph = createGraph({
|
||||
nodes: [
|
||||
createNode({ id: 1, outputs: [createOutput([1])] }),
|
||||
createNode({ id: 2, outputs: [createOutput([2])] }),
|
||||
createNode({ id: 3, inputs: [createInput(null)] })
|
||||
],
|
||||
links: [
|
||||
[1, 1, 0, 3, 0, '*'],
|
||||
[2, 2, 0, 3, 0, '*']
|
||||
]
|
||||
})
|
||||
|
||||
const result = fixBadLinks(graph, { fix: true })
|
||||
|
||||
expect(result).toMatchObject({
|
||||
hasBadLinks: false,
|
||||
fixed: true,
|
||||
deleted: 1
|
||||
})
|
||||
expect(graph.nodes[0]?.outputs?.[0]?.links).toEqual([])
|
||||
expect(graph.nodes[1]?.outputs?.[0]?.links).toEqual([2])
|
||||
expect(graph.nodes[2]?.inputs?.[0]?.link).toBe(2)
|
||||
expect(graph.links).toEqual([[2, 2, 0, 3, 0, '*']])
|
||||
})
|
||||
|
||||
it('reports stale origin references during a dry run', () => {
|
||||
const graph = createGraph({
|
||||
nodes: [
|
||||
createNode({ id: 1, outputs: [createOutput([1])] }),
|
||||
createNode({ id: 2, inputs: [createInput(2)] })
|
||||
],
|
||||
links: [[1, 1, 0, 2, 0, '*']]
|
||||
})
|
||||
|
||||
const result = fixBadLinks(graph)
|
||||
|
||||
expect(result).toMatchObject({
|
||||
hasBadLinks: true,
|
||||
fixed: false,
|
||||
patched: 1,
|
||||
deleted: 1
|
||||
})
|
||||
expect(graph.nodes[0]?.outputs?.[0]?.links).toEqual([1])
|
||||
expect(graph.links).toEqual([[1, 1, 0, 2, 0, '*']])
|
||||
})
|
||||
|
||||
it('cleans dangling references when a linked node is missing', () => {
|
||||
const graph = createGraph({
|
||||
nodes: [createNode({ id: 2, inputs: [createInput(1)] })],
|
||||
@@ -224,6 +291,24 @@ describe('fixBadLinks', () => {
|
||||
expect(graph.links).toEqual([])
|
||||
})
|
||||
|
||||
it('deletes missing-origin links when the target does not reference them', () => {
|
||||
const graph = createGraph({
|
||||
nodes: [createNode({ id: 2, inputs: [createInput(null)] })],
|
||||
links: [[1, 1, 0, 2, 0, '*']]
|
||||
})
|
||||
|
||||
const result = fixBadLinks(graph, { fix: true })
|
||||
|
||||
expect(result).toMatchObject({
|
||||
hasBadLinks: false,
|
||||
fixed: true,
|
||||
patched: 0,
|
||||
deleted: 1
|
||||
})
|
||||
expect(graph.nodes[0]?.inputs?.[0]?.link).toBeNull()
|
||||
expect(graph.links).toEqual([])
|
||||
})
|
||||
|
||||
it('cleans dangling origin references when the target node is missing', () => {
|
||||
const graph = createGraph({
|
||||
nodes: [createNode({ id: 1, outputs: [createOutput([1])] })],
|
||||
@@ -242,6 +327,24 @@ describe('fixBadLinks', () => {
|
||||
expect(graph.links).toEqual([])
|
||||
})
|
||||
|
||||
it('deletes missing-target links when the origin does not reference them', () => {
|
||||
const graph = createGraph({
|
||||
nodes: [createNode({ id: 1, outputs: [createOutput([])] })],
|
||||
links: [[1, 1, 0, 2, 0, '*']]
|
||||
})
|
||||
|
||||
const result = fixBadLinks(graph, { fix: true })
|
||||
|
||||
expect(result).toMatchObject({
|
||||
hasBadLinks: false,
|
||||
fixed: true,
|
||||
patched: 0,
|
||||
deleted: 1
|
||||
})
|
||||
expect(graph.nodes[0]?.outputs?.[0]?.links).toEqual([])
|
||||
expect(graph.links).toEqual([])
|
||||
})
|
||||
|
||||
it('deletes a stale link that neither endpoint references', () => {
|
||||
const graph = createGraph({
|
||||
nodes: [
|
||||
@@ -317,4 +420,118 @@ describe('fixBadLinks', () => {
|
||||
expect(graph.nodes[0]?.outputs?.[0]?.links).toEqual([1])
|
||||
expect(logger.log).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('creates missing origin output slots in fix mode', () => {
|
||||
const graph = createGraph({
|
||||
nodes: [
|
||||
createNode({ id: 1 }),
|
||||
createNode({ id: 2, inputs: [createInput(1)] })
|
||||
],
|
||||
links: [[1, 1, 0, 2, 0, '*']]
|
||||
})
|
||||
|
||||
const result = fixBadLinks(graph, { fix: true })
|
||||
|
||||
expect(result).toMatchObject({
|
||||
hasBadLinks: false,
|
||||
fixed: true,
|
||||
patched: 1,
|
||||
deleted: 0
|
||||
})
|
||||
expect(graph.nodes[0]?.outputs?.[0]?.links).toEqual([1])
|
||||
})
|
||||
|
||||
it('deletes links whose serialized endpoints are both missing', () => {
|
||||
const graph = createGraph({
|
||||
nodes: [],
|
||||
links: [[1, 1, 0, 2, 0, '*']]
|
||||
})
|
||||
|
||||
const result = fixBadLinks(graph, { fix: true })
|
||||
|
||||
expect(result).toMatchObject({
|
||||
hasBadLinks: false,
|
||||
fixed: true,
|
||||
patched: 0,
|
||||
deleted: 1
|
||||
})
|
||||
expect(graph.links).toEqual([])
|
||||
})
|
||||
|
||||
it('ignores null serialized link entries', () => {
|
||||
const graph = {
|
||||
...createGraph({
|
||||
nodes: [createNode({ id: 1 })],
|
||||
links: []
|
||||
}),
|
||||
links: [null as unknown as SerialisedLLinkArray]
|
||||
}
|
||||
|
||||
const result = fixBadLinks(graph, { fix: true })
|
||||
|
||||
expect(result).toMatchObject({
|
||||
hasBadLinks: false,
|
||||
fixed: false,
|
||||
patched: 0,
|
||||
deleted: 0
|
||||
})
|
||||
expect(graph.links).toEqual([])
|
||||
})
|
||||
|
||||
it('deletes object-shaped serialized links', () => {
|
||||
const graph = {
|
||||
...createGraph({
|
||||
nodes: [],
|
||||
links: []
|
||||
}),
|
||||
links: [
|
||||
{
|
||||
id: 1,
|
||||
origin_id: 1,
|
||||
origin_slot: 0,
|
||||
target_id: 2,
|
||||
target_slot: 0,
|
||||
type: '*'
|
||||
}
|
||||
]
|
||||
} as unknown as ISerialisedGraph
|
||||
|
||||
const result = fixBadLinks(graph, { fix: true })
|
||||
|
||||
expect(result).toMatchObject({
|
||||
hasBadLinks: false,
|
||||
fixed: true,
|
||||
patched: 0,
|
||||
deleted: 1
|
||||
})
|
||||
expect(graph.links).toEqual([])
|
||||
})
|
||||
|
||||
it('treats invalid live graph endpoint ids as missing', () => {
|
||||
const linkId = toLinkId(1)
|
||||
const link = fromPartial<LLink>({
|
||||
id: linkId,
|
||||
origin_id: toNodeId(''),
|
||||
origin_slot: 0,
|
||||
target_id: toNodeId(''),
|
||||
target_slot: 0,
|
||||
type: '*'
|
||||
})
|
||||
const links = new Map([[linkId, link]])
|
||||
const graph = fromAny<LGraph, unknown>({
|
||||
links,
|
||||
getNodeById: vi.fn()
|
||||
})
|
||||
|
||||
const result = fixBadLinks(graph, { fix: true, silent: true })
|
||||
|
||||
expect(result).toMatchObject({
|
||||
hasBadLinks: false,
|
||||
fixed: true,
|
||||
patched: 0,
|
||||
deleted: 1
|
||||
})
|
||||
expect(graph.getNodeById).not.toHaveBeenCalled()
|
||||
expect(links.has(linkId)).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,14 +1,56 @@
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { fromAny, fromPartial } from '@total-typescript/shoehorn'
|
||||
import { createTestingPinia } from '@pinia/testing'
|
||||
import { setActivePinia } from 'pinia'
|
||||
|
||||
import { LLink } from '@/lib/litegraph/src/LLink'
|
||||
import type { LGraphCanvas } from '@/lib/litegraph/src/litegraph'
|
||||
import { LGraph, LGraphNode, LiteGraph } from '@/lib/litegraph/src/litegraph'
|
||||
import {
|
||||
LGraph,
|
||||
LGraphGroup,
|
||||
LGraphNode,
|
||||
LiteGraph,
|
||||
Reroute
|
||||
} from '@/lib/litegraph/src/litegraph'
|
||||
import { createTestSubgraph } from '@/lib/litegraph/src/subgraph/__fixtures__/subgraphHelpers'
|
||||
import type {
|
||||
ExportedSubgraph,
|
||||
ISerialisedGraph
|
||||
} from '@/lib/litegraph/src/types/serialisation'
|
||||
import type {
|
||||
IBaseWidget,
|
||||
IComboWidget
|
||||
} from '@/lib/litegraph/src/types/widgets'
|
||||
import { createNodeLocatorId } from '@/types/nodeIdentification'
|
||||
import { toLinkId } from '@/types/linkId'
|
||||
import { toNodeId } from '@/types/nodeId'
|
||||
import { toRerouteId } from '@/types/rerouteId'
|
||||
import { widgetId } from '@/types/widgetId'
|
||||
import { createMockLGraphNode } from '@/utils/__tests__/litegraphTestUtils'
|
||||
|
||||
import { createNode, getWidgetIdForNode, resolveNode } from './litegraphUtil'
|
||||
import {
|
||||
addToComboValues,
|
||||
compressWidgetInputSlots,
|
||||
createNode,
|
||||
executeWidgetsCallback,
|
||||
fixLinkInputSlots,
|
||||
getItemsColorOption,
|
||||
getLinkTypeColor,
|
||||
getWidgetIdForNode,
|
||||
isAnimatedOutput,
|
||||
isAudioNode,
|
||||
isImageNode,
|
||||
isLGraphGroup,
|
||||
isLGraphNode,
|
||||
isLoad3dNode,
|
||||
isReroute,
|
||||
isVideoNode,
|
||||
isVideoOutput,
|
||||
migrateWidgetsValues,
|
||||
resolveComboValues,
|
||||
resolveNode,
|
||||
resolveNodeWidget
|
||||
} from './litegraphUtil'
|
||||
|
||||
const mockBringNodeToFront = vi.fn()
|
||||
|
||||
@@ -191,3 +233,387 @@ describe('getWidgetIdForNode', () => {
|
||||
expect(getWidgetIdForNode(node, { name: 'x' })).toBeUndefined()
|
||||
})
|
||||
})
|
||||
|
||||
describe('media helpers', () => {
|
||||
it('classifies preview media nodes', () => {
|
||||
expect(isImageNode(undefined)).toBe(false)
|
||||
expect(isVideoNode(undefined)).toBe(false)
|
||||
expect(isAudioNode(undefined)).toBe(false)
|
||||
|
||||
const imageNode = new LGraphNode('Image')
|
||||
imageNode.previewMediaType = 'image'
|
||||
const imageWithImgs = Object.assign(new LGraphNode('Image'), {
|
||||
previewMediaType: 'model' as const,
|
||||
imgs: [document.createElement('img')]
|
||||
})
|
||||
const videoWithImgs = Object.assign(new LGraphNode('Video'), {
|
||||
previewMediaType: 'video' as const,
|
||||
imgs: [document.createElement('img')]
|
||||
})
|
||||
const videoNode = new LGraphNode('Video')
|
||||
videoNode.previewMediaType = 'video'
|
||||
const videoContainerNode = Object.assign(new LGraphNode('Video'), {
|
||||
videoContainer: document.body
|
||||
})
|
||||
const audioNode = new LGraphNode('Audio')
|
||||
audioNode.previewMediaType = 'audio'
|
||||
|
||||
expect(isImageNode(imageNode)).toBe(true)
|
||||
expect(isImageNode(imageWithImgs)).toBe(true)
|
||||
expect(isImageNode(videoWithImgs)).toBe(false)
|
||||
expect(isVideoNode(videoNode)).toBe(true)
|
||||
expect(isVideoNode(videoContainerNode)).toBe(true)
|
||||
expect(isAudioNode(audioNode)).toBe(true)
|
||||
})
|
||||
|
||||
it('distinguishes animated images from video outputs', () => {
|
||||
expect(isAnimatedOutput(undefined)).toBe(false)
|
||||
expect(isVideoOutput(undefined)).toBe(false)
|
||||
expect(isAnimatedOutput({ animated: [false, true] })).toBe(true)
|
||||
expect(isVideoOutput({ animated: [true] })).toBe(true)
|
||||
expect(
|
||||
isVideoOutput({
|
||||
animated: [true],
|
||||
images: [{ filename: 'clip.mp4' }]
|
||||
})
|
||||
).toBe(true)
|
||||
expect(
|
||||
isVideoOutput({
|
||||
animated: [true],
|
||||
images: [{ filename: 'preview.webp' }]
|
||||
})
|
||||
).toBe(false)
|
||||
expect(
|
||||
isVideoOutput({
|
||||
animated: [true],
|
||||
images: [{ filename: 'preview.png' }]
|
||||
})
|
||||
).toBe(false)
|
||||
})
|
||||
|
||||
it('detects 3d loader nodes', () => {
|
||||
const modelNode = new LGraphNode('Load3D')
|
||||
modelNode.type = 'Load3D'
|
||||
const animationNode = new LGraphNode('Load3DAnimation')
|
||||
animationNode.type = 'Load3DAnimation'
|
||||
const imageNode = new LGraphNode('LoadImage')
|
||||
imageNode.type = 'LoadImage'
|
||||
|
||||
expect(isLoad3dNode(modelNode)).toBe(true)
|
||||
expect(isLoad3dNode(animationNode)).toBe(true)
|
||||
expect(isLoad3dNode(imageNode)).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('combo widget helpers', () => {
|
||||
function combo(values: IComboWidget['options']['values']): IComboWidget {
|
||||
return fromPartial<IComboWidget>({
|
||||
name: 'mode',
|
||||
type: 'combo',
|
||||
value: 'a',
|
||||
options: { values }
|
||||
})
|
||||
}
|
||||
|
||||
it('resolves combo values from arrays, records, functions, and missing options', () => {
|
||||
expect(resolveComboValues(combo(['a', 'b']))).toEqual(['a', 'b'])
|
||||
expect(resolveComboValues(combo({ a: 'A', b: 'B' }))).toEqual(['a', 'b'])
|
||||
expect(resolveComboValues(combo(() => ['x']))).toEqual(['x'])
|
||||
expect(
|
||||
resolveComboValues(fromPartial<IComboWidget>({ options: {} }))
|
||||
).toEqual([])
|
||||
})
|
||||
|
||||
it('adds only missing array combo values', () => {
|
||||
const widget = combo(['a'])
|
||||
|
||||
addToComboValues(widget, 'b')
|
||||
addToComboValues(widget, 'b')
|
||||
|
||||
expect(widget.options.values).toEqual(['a', 'b'])
|
||||
})
|
||||
|
||||
it('initializes missing combo options before adding values', () => {
|
||||
const missingOptions = fromPartial<IComboWidget>({})
|
||||
const missingValues = fromPartial<IComboWidget>({ options: {} })
|
||||
|
||||
addToComboValues(missingOptions, 'first')
|
||||
addToComboValues(missingValues, 'second')
|
||||
|
||||
expect(missingOptions.options.values).toEqual(['first'])
|
||||
expect(missingValues.options.values).toEqual(['second'])
|
||||
})
|
||||
})
|
||||
|
||||
describe('node utility helpers', () => {
|
||||
it('classifies litegraph canvas item types', () => {
|
||||
const graph = new LGraph()
|
||||
const node = new LGraphNode('Node')
|
||||
const group = new LGraphGroup('Group')
|
||||
const reroute = new Reroute(toRerouteId(1), graph)
|
||||
|
||||
expect(isLGraphNode(node)).toBe(true)
|
||||
expect(isLGraphNode(group)).toBe(false)
|
||||
expect(isLGraphGroup(group)).toBe(true)
|
||||
expect(isLGraphGroup(node)).toBe(false)
|
||||
expect(isReroute(reroute)).toBe(true)
|
||||
expect(isReroute(node)).toBe(false)
|
||||
})
|
||||
|
||||
it('returns a shared color option only when all colorable items match', () => {
|
||||
const red = { getColorOption: () => 'red', setColorOption: vi.fn() }
|
||||
const redAgain = { getColorOption: () => 'red', setColorOption: vi.fn() }
|
||||
const blue = { getColorOption: () => 'blue', setColorOption: vi.fn() }
|
||||
|
||||
expect(getItemsColorOption([red, redAgain, {}])).toBe('red')
|
||||
expect(getItemsColorOption([red, blue])).toBeNull()
|
||||
expect(getItemsColorOption([{}])).toBeNull()
|
||||
})
|
||||
|
||||
it('executes matching callbacks on node widgets', () => {
|
||||
const onRemove = vi.fn()
|
||||
const afterQueued = vi.fn()
|
||||
const node = new LGraphNode('Callbacks')
|
||||
node.widgets = [
|
||||
fromPartial<IBaseWidget>({ onRemove }),
|
||||
fromPartial<IBaseWidget>({ afterQueued })
|
||||
]
|
||||
|
||||
executeWidgetsCallback([node], 'onRemove')
|
||||
|
||||
expect(onRemove).toHaveBeenCalledOnce()
|
||||
expect(afterQueued).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('returns configured link colors with the default fallback', () => {
|
||||
expect(getLinkTypeColor('missing-type')).toBe(LiteGraph.LINK_COLOR)
|
||||
})
|
||||
})
|
||||
|
||||
describe('legacy workflow migration helpers', () => {
|
||||
it('repairs root and subgraph link input slots from current input order', () => {
|
||||
const graph = new LGraph()
|
||||
const source = new LGraphNode('Source')
|
||||
source.id = toNodeId(1)
|
||||
const target = new LGraphNode('Target')
|
||||
target.id = toNodeId(2)
|
||||
target.inputs = [
|
||||
fromPartial({ name: 'unlinked', link: null }),
|
||||
fromPartial({ name: 'missing', link: toLinkId(99) }),
|
||||
fromPartial({ name: 'linked', link: toLinkId(7) })
|
||||
]
|
||||
const link = new LLink(toLinkId(7), 'STRING', source.id, 0, target.id, 10)
|
||||
graph.add(source)
|
||||
graph.add(target)
|
||||
graph.links.set(link.id, link)
|
||||
|
||||
const subgraph = createTestSubgraph({ nodeCount: 0 })
|
||||
const innerSource = new LGraphNode('InnerSource')
|
||||
innerSource.id = toNodeId(3)
|
||||
const innerTarget = new LGraphNode('InnerTarget')
|
||||
innerTarget.id = toNodeId(4)
|
||||
innerTarget.inputs = [
|
||||
fromPartial({ name: 'inner-unlinked', link: null }),
|
||||
fromPartial({ name: 'inner-linked', link: toLinkId(8) })
|
||||
]
|
||||
const innerLink = new LLink(
|
||||
toLinkId(8),
|
||||
'STRING',
|
||||
innerSource.id,
|
||||
0,
|
||||
innerTarget.id,
|
||||
12
|
||||
)
|
||||
subgraph.add(innerSource)
|
||||
subgraph.add(innerTarget)
|
||||
subgraph.links.set(innerLink.id, innerLink)
|
||||
|
||||
const host = new LGraphNode('Host')
|
||||
host.id = toNodeId(5)
|
||||
vi.spyOn(host, 'isSubgraphNode').mockReturnValue(true)
|
||||
Object.assign(host, { subgraph })
|
||||
graph.add(host)
|
||||
|
||||
fixLinkInputSlots(graph)
|
||||
|
||||
expect(link.target_slot).toBe(2)
|
||||
expect(innerLink.target_slot).toBe(1)
|
||||
})
|
||||
|
||||
it('drops legacy force-input widget values only when lengths match', () => {
|
||||
const inputDefs = {
|
||||
seed: { name: 'seed', type: 'INT', forceInput: true },
|
||||
mode: { name: 'mode', type: 'STRING' },
|
||||
batch: {
|
||||
name: 'batch',
|
||||
type: 'INT',
|
||||
control_after_generate: true
|
||||
}
|
||||
}
|
||||
const widgets = [
|
||||
fromPartial<IBaseWidget>({ name: 'mode' }),
|
||||
fromPartial<IBaseWidget>({ name: 'batch' })
|
||||
]
|
||||
|
||||
expect(migrateWidgetsValues(inputDefs, widgets, [1, 2, 3, 4])).toEqual([
|
||||
2, 3, 4
|
||||
])
|
||||
expect(migrateWidgetsValues(inputDefs, widgets, [1, 2])).toEqual([1, 2])
|
||||
})
|
||||
|
||||
it('compresses root and subgraph widget input slots', () => {
|
||||
const graph = fromPartial<ISerialisedGraph>({
|
||||
nodes: [
|
||||
{
|
||||
id: 1,
|
||||
type: 'Node',
|
||||
inputs: [
|
||||
{
|
||||
name: 'widget',
|
||||
type: 'STRING',
|
||||
link: null,
|
||||
widget: { name: 'w' }
|
||||
},
|
||||
{ name: 'kept', type: 'STRING', link: 7 }
|
||||
]
|
||||
}
|
||||
],
|
||||
links: [[7, 2, 0, 1, 99, 'STRING']],
|
||||
definitions: {
|
||||
subgraphs: [
|
||||
{
|
||||
name: 'Subgraph',
|
||||
nodes: [
|
||||
{
|
||||
id: 3,
|
||||
type: 'Inner',
|
||||
inputs: [
|
||||
{
|
||||
name: 'legacy',
|
||||
type: 'STRING',
|
||||
link: null,
|
||||
widget: { name: 'legacy' }
|
||||
},
|
||||
{ name: 'inner', type: 'STRING', link: 8 }
|
||||
]
|
||||
}
|
||||
],
|
||||
links: [
|
||||
{
|
||||
id: 8,
|
||||
origin_id: 4,
|
||||
origin_slot: 0,
|
||||
target_id: 3,
|
||||
target_slot: 42,
|
||||
type: 'STRING'
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
})
|
||||
|
||||
compressWidgetInputSlots(graph)
|
||||
|
||||
expect(graph.nodes[0].inputs?.map((input) => input.name)).toEqual(['kept'])
|
||||
expect(graph.links[0][4]).toBe(0)
|
||||
const subgraph = graph.definitions?.subgraphs?.[0]
|
||||
expect(subgraph?.nodes?.[0].inputs?.map((input) => input.name)).toEqual([
|
||||
'inner'
|
||||
])
|
||||
expect(subgraph?.links?.[0].target_slot).toBe(0)
|
||||
})
|
||||
|
||||
it('keeps labeled widget inputs and tolerates missing links', () => {
|
||||
const graph = fromPartial<ISerialisedGraph>({
|
||||
nodes: [
|
||||
{
|
||||
id: 1,
|
||||
type: 'Node',
|
||||
inputs: [
|
||||
{
|
||||
name: 'labeled',
|
||||
type: 'STRING',
|
||||
link: null,
|
||||
label: 'Shown',
|
||||
widget: { name: 'shown' }
|
||||
},
|
||||
{ name: 'stale', type: 'STRING', link: 99 }
|
||||
]
|
||||
},
|
||||
{ id: 2, type: 'NoInputs' }
|
||||
],
|
||||
links: []
|
||||
})
|
||||
|
||||
compressWidgetInputSlots(graph)
|
||||
|
||||
expect(graph.nodes[0].inputs?.map((input) => input.name)).toEqual([
|
||||
'labeled',
|
||||
'stale'
|
||||
])
|
||||
})
|
||||
|
||||
it('handles subgraphs without nodes or links and detects cycles', () => {
|
||||
const cyclic = fromPartial<ISerialisedGraph>({
|
||||
nodes: [],
|
||||
links: [],
|
||||
definitions: { subgraphs: [] }
|
||||
})
|
||||
const child = fromPartial<ExportedSubgraph>({
|
||||
name: 'child',
|
||||
nodes: [{ id: 1, type: 'Inner' }]
|
||||
})
|
||||
cyclic.definitions?.subgraphs?.push(child)
|
||||
|
||||
expect(() => compressWidgetInputSlots(cyclic)).not.toThrow()
|
||||
|
||||
const loop = fromPartial<ExportedSubgraph>({
|
||||
name: 'loop',
|
||||
nodes: [],
|
||||
definitions: { subgraphs: [] }
|
||||
})
|
||||
loop.definitions?.subgraphs?.push(loop)
|
||||
const graph = fromPartial<ISerialisedGraph>({
|
||||
nodes: [],
|
||||
links: [],
|
||||
definitions: { subgraphs: [loop] }
|
||||
})
|
||||
|
||||
expect(() => compressWidgetInputSlots(graph)).toThrow(
|
||||
'Infinite loop detected'
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('resolveNodeWidget', () => {
|
||||
it('resolves root graph nodes and widgets', () => {
|
||||
setActivePinia(createTestingPinia({ stubActions: false }))
|
||||
const graph = new LGraph()
|
||||
const node = new LGraphNode('TestNode')
|
||||
const widget = node.addWidget('text', 'prompt', 'hello', () => {})
|
||||
graph.add(node)
|
||||
|
||||
expect(resolveNodeWidget(node.id, undefined, graph)).toEqual([node])
|
||||
expect(resolveNodeWidget(node.id, 'prompt', graph)).toEqual([node, widget])
|
||||
expect(resolveNodeWidget(node.id, 'missing', graph)).toEqual([])
|
||||
expect(resolveNodeWidget('not-a-node-id', 'prompt', graph)).toEqual([])
|
||||
})
|
||||
|
||||
it('resolves widgets exposed by subgraph host locators', () => {
|
||||
setActivePinia(createTestingPinia({ stubActions: false }))
|
||||
const graph = new LGraph()
|
||||
const host = new LGraphNode('Host')
|
||||
host.id = toNodeId(8)
|
||||
const widget = host.addWidget('text', 'mode', 'fast', () => {})
|
||||
graph.add(host)
|
||||
vi.spyOn(host, 'isSubgraphNode').mockReturnValue(true)
|
||||
const locator = createNodeLocatorId(
|
||||
'00000000-0000-0000-0000-000000000001',
|
||||
host.id
|
||||
)
|
||||
|
||||
expect(resolveNodeWidget(locator, 'mode', graph)).toEqual([host, widget])
|
||||
expect(resolveNodeWidget(locator, 'missing', graph)).toEqual([])
|
||||
})
|
||||
})
|
||||
|
||||
15
src/utils/loaderNodeUtil.test.ts
Normal file
15
src/utils/loaderNodeUtil.test.ts
Normal file
@@ -0,0 +1,15 @@
|
||||
import { describe, expect, it } from 'vitest'
|
||||
|
||||
import { detectNodeTypeFromFilename } from './loaderNodeUtil'
|
||||
|
||||
describe('detectNodeTypeFromFilename', () => {
|
||||
it.for([
|
||||
['image.png', { nodeType: 'LoadImage', widgetName: 'image' }],
|
||||
['clip.mp4', { nodeType: 'LoadVideo', widgetName: 'file' }],
|
||||
['sound.mp3', { nodeType: 'LoadAudio', widgetName: 'audio' }],
|
||||
['mesh.glb', { nodeType: null, widgetName: null }],
|
||||
['notes.txt', { nodeType: null, widgetName: null }]
|
||||
] as const)('maps %s to its loader node', ([filename, expected]) => {
|
||||
expect(detectNodeTypeFromFilename(filename)).toEqual(expected)
|
||||
})
|
||||
})
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user